package io.bioimage.modelrunner.bioimageio.tiling;

import io.bioimage.modelrunner.bioimageio.description.Axis;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

/* loaded from: input_file:io/bioimage/modelrunner/bioimageio/tiling/TileCalculator.class */
public class TileCalculator {
    private final ModelDescriptor descriptor;
    private static final long OPTIMAL_MAX_NUMBER_PIXELS = 50331648;

    private TileCalculator(ModelDescriptor modelDescriptor) {
        this.descriptor = modelDescriptor;
    }

    public static TileCalculator init(ModelDescriptor modelDescriptor) {
        return new TileCalculator(modelDescriptor);
    }

    private long[] getOptimalTileSize(TensorSpec tensorSpec, String str, long[] jArr) {
        boolean isTilingAllowed = this.descriptor.isTilingAllowed();
        int[] haloArr = tensorSpec.getAxesInfo().getHaloArr();
        int[] minTileSizeArr = tensorSpec.getMinTileSizeArr();
        int[] tileStepArr = tensorSpec.getTileStepArr();
        long[] jArr2 = new long[str.length()];
        String replace = str.toUpperCase().replace("T", "B");
        String[] split = tensorSpec.getAxesOrder().toUpperCase().split("");
        for (int i = 0; i < split.length; i++) {
            int i2 = (int) jArr[replace.indexOf(split[i])];
            if (tileStepArr[i] != 0 && isTilingAllowed) {
                jArr2[i] = ((int) Math.ceil((i2 + (2 * haloArr[i])) / tileStepArr[i])) * tileStepArr[i];
                if (jArr2[i] > 3 * i2 && jArr2[i] - tileStepArr[i] >= minTileSizeArr[i]) {
                    jArr2[i] = jArr2[i] - tileStepArr[i];
                }
                if (jArr2[i] > 3 * i2 && ((int) Math.ceil(i2 / tileStepArr[i])) * tileStepArr[i] >= minTileSizeArr[i]) {
                    jArr2[i] = ((int) Math.ceil(i2 / tileStepArr[i])) * tileStepArr[i];
                }
            } else if (tileStepArr[i] != 0 && !isTilingAllowed) {
                jArr2[i] = -1;
            } else if (tileStepArr[i] == 0 && minTileSizeArr[i] == -1) {
                jArr2[i] = i2;
            } else if (tileStepArr[i] == 0) {
                jArr2[i] = minTileSizeArr[i];
            }
        }
        return jArr2;
    }

    public List<TileInfo> getOptimalTileSize(List<ImageInfo> list) {
        boolean isTilingAllowed = this.descriptor.isTilingAllowed();
        ArrayList arrayList = new ArrayList();
        for (TensorSpec tensorSpec : this.descriptor.getInputTensors()) {
            ImageInfo orElse = list.stream().filter(imageInfo -> {
                return imageInfo.getTensorName().equals(tensorSpec.getName());
            }).findFirst().orElse(null);
            if (orElse == null) {
                throw new IllegalArgumentException("No data was provided for input tensor: " + tensorSpec.getName());
            }
            arrayList.add(TileInfo.build(tensorSpec.getName(), orElse.getDimensions(), orElse.getAxesOrder(), getOptimalTileSize(tensorSpec, orElse.getAxesOrder(), orElse.getDimensions()), orElse.getAxesOrder()));
        }
        if (!isTilingAllowed) {
            return arrayList;
        }
        List<TensorSpec> list2 = (List) this.descriptor.getOutputTensors().stream().filter(tensorSpec2 -> {
            return tensorSpec2.getAxesInfo().getAxesList().stream().filter(axis -> {
                return axis.getReferenceTensor() != null;
            }).findFirst().orElse(null) != null;
        }).collect(Collectors.toList());
        ArrayList arrayList2 = new ArrayList();
        for (int i = 0; i < arrayList.size(); i++) {
            if (Arrays.stream(this.descriptor.findInputTensor(arrayList.get(i).getName()).getTileStepArr()).allMatch(i2 -> {
                return i2 == 0;
            })) {
                arrayList2.add(arrayList.get(i));
            }
        }
        return arrayList.size() == arrayList2.size() ? arrayList2 : checkOutputSize(arrayList, list2, calculateByteSizeOfAffectedOutput(list2, arrayList));
    }

    private List<TileInfo> checkOutputSize(List<TileInfo> list, List<TensorSpec> list2, List<Long> list3) {
        List list4 = (List) list.stream().map(tileInfo -> {
            return Long.valueOf(Arrays.stream(tileInfo.getTileDims()).reduce(1L, (j, j2) -> {
                return j * j2;
            }));
        }).collect(Collectors.toList());
        if (list4.stream().filter(l -> {
            return l.longValue() > OPTIMAL_MAX_NUMBER_PIXELS;
        }).findFirst().orElse(null) == null && list3.stream().filter(l2 -> {
            return l2.longValue() > 2147483647L;
        }).findFirst().orElse(null) == null) {
            return list;
        }
        List list5 = (List) list4.stream().map(l3 -> {
            return Double.valueOf(5.0331648E7d / l3.longValue());
        }).collect(Collectors.toList());
        while (((Double) Collections.min(list5)).doubleValue() < 1.0d) {
            Integer num = null;
            Stream<Integer> boxed = IntStream.range(0, list5.size()).boxed();
            List list6 = list5;
            Objects.requireNonNull(list6);
            Iterator it = ((List) boxed.sorted(Comparator.comparing((v1) -> {
                return r1.get(v1);
            })).collect(Collectors.toList())).iterator();
            while (true) {
                if (!it.hasNext()) {
                    break;
                }
                Integer num2 = (Integer) it.next();
                if (!Arrays.stream(this.descriptor.findInputTensor(list.get(num2.intValue()).getName()).getTileStepArr()).allMatch(i -> {
                    return i == 0;
                })) {
                    num = num2;
                    break;
                }
            }
            if (num == null) {
                break;
            }
            Double d = (Double) list5.get(num.intValue());
            TileInfo tileInfo2 = list.get(num.intValue());
            TensorSpec findInputTensor = this.descriptor.findInputTensor(tileInfo2.getName());
            for (String str : tileInfo2.getTileAxesOrder().split("")) {
                Axis axis = findInputTensor.getAxesInfo().getAxis(str);
                if (axis.getStep() != 0) {
                    long longValue = ((Long) list4.get(num.intValue())).longValue() / tileInfo2.getTileDims()[0];
                    if (tileInfo2.getTileDims()[0] * ((Double) list5.get(num.intValue())).doubleValue() < axis.getMin() && axis.getMin() > 1) {
                        tileInfo2.getTileDims()[0] = ((int) Math.ceil(100.0d / axis.getStep())) * axis.getStep();
                    } else if (tileInfo2.getTileDims()[0] * ((Double) list5.get(num.intValue())).doubleValue() < axis.getMin()) {
                        tileInfo2.getTileDims()[0] = axis.getMin();
                    } else {
                        tileInfo2.getTileDims()[0] = (long) ((Math.floor(((tileInfo2.getTileDims()[0] * ((Double) list5.get(num.intValue())).doubleValue()) - axis.getMin()) / axis.getStep()) * axis.getStep()) + axis.getMin());
                    }
                    list4.set(num.intValue(), Long.valueOf(longValue * tileInfo2.getTileDims()[0]));
                    list5 = (List) list4.stream().map(l4 -> {
                        return Double.valueOf(5.0331648E7d / l4.longValue());
                    }).collect(Collectors.toList());
                    if (d == list5.get(num.intValue())) {
                        break;
                    }
                }
            }
        }
        List list7 = (List) calculateByteSizeOfAffectedOutput(list2, null).stream().map(l5 -> {
            return Double.valueOf(2.147483647E9d / l5.longValue());
        }).collect(Collectors.toList());
        if (((Double) Collections.min(list7)).doubleValue() < 1.0d && ((Double) Collections.min(list5)).doubleValue() < 1.0d) {
            throw new IllegalArgumentException("The input and/or ouput dimensions of the tensors specified by the current model are to big. JDLL is not able to run them.");
        }
        while (((Double) Collections.min(list7)).doubleValue() < 1.0d) {
            ArrayList arrayList = new ArrayList(list7);
            int asInt = IntStream.range(0, arrayList.size()).reduce((i2, i3) -> {
                return ((Double) arrayList.get(i2)).doubleValue() < ((Double) arrayList.get(i3)).doubleValue() ? i2 : i3;
            }).getAsInt();
            Double d2 = (Double) list7.get(asInt);
            for (Axis axis2 : this.descriptor.getOutputTensors().get(asInt).getAxesInfo().getAxesList()) {
                if (axis2.getReferenceTensor() != null) {
                    TensorSpec findInputTensor2 = this.descriptor.findInputTensor(axis2.getReferenceTensor());
                    TileInfo orElse = list.stream().filter(tileInfo3 -> {
                        return tileInfo3.getName().equals(findInputTensor2.getName());
                    }).findFirst().orElse(null);
                    String referenceAxis = axis2.getReferenceAxis();
                    int indexOf = orElse.getTileAxesOrder().indexOf(referenceAxis);
                    Axis axis3 = findInputTensor2.getAxesInfo().getAxis(referenceAxis);
                    long j = orElse.getTileDims()[indexOf];
                    if (j * ((Double) list7.get(asInt)).doubleValue() < axis3.getMin() && axis3.getMin() > 1) {
                        orElse.getTileDims()[indexOf] = ((int) Math.ceil(100.0d / axis3.getStep())) * axis3.getStep();
                    } else if (j * ((Double) list7.get(asInt)).doubleValue() < axis3.getMin()) {
                        orElse.getTileDims()[indexOf] = axis3.getMin();
                    } else {
                        orElse.getTileDims()[indexOf] = (long) ((Math.floor(((j * ((Double) list7.get(asInt)).doubleValue()) - axis3.getMin()) / axis3.getStep()) * axis3.getStep()) + axis3.getMin());
                    }
                    list7.set(asInt, Double.valueOf(((Double) list7.get(asInt)).doubleValue() * (((j * axis2.getScale()) + (2.0d * axis2.getOffset())) / ((orElse.getTileDims()[indexOf] * axis2.getScale()) + (2.0d * axis2.getOffset())))));
                    if (((Double) list7.get(asInt)).doubleValue() > 1.0d) {
                        break;
                    }
                }
            }
            if (list7.get(asInt) == d2) {
                break;
            }
        }
        if (((Double) Collections.min(list7)).doubleValue() < 1.0d) {
            throw new IllegalArgumentException("Due to the model specifications, the size of one of the output tensors exceeds the limit of tensor size in JDLL: 2147483647");
        }
        return null;
    }

    public void getTilesForNPixels(String str, long[] jArr, String str2) {
    }

    public void getForNTiles(int i, String str, long[] jArr, String str2) {
    }

    private List<Long> calculateByteSizeOfAffectedOutput(List<TensorSpec> list, List<TileInfo> list2) {
        if (list == null || list.size() == 0) {
            return new ArrayList();
        }
        List list3 = (List) list.stream().map(tensorSpec -> {
            return new long[tensorSpec.getAxesInfo().getAxesList().size()];
        }).collect(Collectors.toList());
        for (int i = 0; i < list.size(); i++) {
            TensorSpec tensorSpec2 = list.get(i);
            ArrayList arrayList = new ArrayList();
            for (int i2 = 0; i2 < list.get(i).getAxesInfo().getAxesList().size(); i2++) {
                Axis axis = tensorSpec2.getAxesInfo().getAxesList().get(i2);
                String referenceTensor = axis.getReferenceTensor();
                if (referenceTensor == null && axis.getMin() != 0) {
                    ((long[]) list3.get(i))[i2] = axis.getMin();
                } else if (referenceTensor == null) {
                    ((long[]) list3.get(i))[i2] = -1;
                } else {
                    arrayList.add(referenceTensor);
                    String referenceAxis = axis.getReferenceAxis();
                    TensorSpec findInputTensor = this.descriptor.findInputTensor(referenceTensor);
                    ((long[]) list3.get(i))[i2] = (long) ((list2.stream().filter(tileInfo -> {
                        return tileInfo.getName().equals(referenceTensor);
                    }).findFirst().orElse(null).getTileDims()[findInputTensor.getAxesOrder().indexOf(referenceAxis)] * axis.getScale()) + (2.0d * axis.getOffset()));
                }
            }
            if (arrayList.stream().distinct().count() != 1) {
                throw new IllegalArgumentException("Model specs too complex for JDLL. Please contact the team and create and issue attaching the rdf.yaml file so we can troubleshoot at: https://github.com/bioimage-io/JDLL/issues");
            }
            for (int i3 = 0; i3 < list.get(i).getAxesInfo().getAxesList().size(); i3++) {
                if (((long[]) list3.get(i))[i3] == -1) {
                    int indexOf = this.descriptor.findInputTensor((String) arrayList.get(0)).getAxesOrder().indexOf(list.get(i).getAxesInfo().getAxesList().get(i3).getAxis());
                    if (indexOf == -1) {
                        throw new IllegalArgumentException("Model specs too complex for JDLL. Please contact the team and create and issue attaching the rdf.yaml file so we can troubleshoot at: https://github.com/bioimage-io/JDLL/issues");
                    }
                    ((long[]) list3.get(i))[i3] = list2.stream().filter(tileInfo2 -> {
                        return tileInfo2.getName().equals(arrayList.get(0));
                    }).findFirst().orElse(null).getTileDims()[indexOf];
                }
            }
        }
        List<Long> list4 = (List) list3.stream().map(jArr -> {
            long j = 1;
            for (long j2 : jArr) {
                j *= j2;
            }
            return Long.valueOf(j);
        }).collect(Collectors.toList());
        for (int i4 = 0; i4 < list4.size(); i4++) {
            if (list.get(i4).getDataType().toLowerCase().equals("float32") || list.get(i4).getDataType().toLowerCase().equals("int32") || list.get(i4).getDataType().toLowerCase().equals("uint32")) {
                list4.set(i4, Long.valueOf(list4.get(i4).longValue() * 4));
            } else if (list.get(i4).getDataType().toLowerCase().equals("int16") || list.get(i4).getDataType().toLowerCase().equals("uint16")) {
                list4.set(i4, Long.valueOf(list4.get(i4).longValue() * 2));
            } else if (list.get(i4).getDataType().toLowerCase().equals("int64") || list.get(i4).getDataType().toLowerCase().equals("float64")) {
                list4.set(i4, Long.valueOf(list4.get(i4).longValue() * 8));
            }
        }
        return list4;
    }
}
