/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.bioimageio.tiling;

import io.bioimage.modelrunner.bioimageio.description.Axis;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.tiling.ImageInfo;
import io.bioimage.modelrunner.bioimageio.tiling.TileInfo;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class TileCalculator {
    private final ModelDescriptor descriptor;
    private static final long OPTIMAL_MAX_NUMBER_PIXELS = 0x3000000L;

    private TileCalculator(ModelDescriptor descriptor) {
        this.descriptor = descriptor;
    }

    public static TileCalculator init(ModelDescriptor descriptor) {
        return new TileCalculator(descriptor);
    }

    private long[] getOptimalTileSize(TensorSpec tensor, String inputAxesOrder, long[] dims) {
        boolean tiling = this.descriptor.isTilingAllowed();
        int[] halo = tensor.getAxesInfo().getHaloArr();
        int[] min = tensor.getMinTileSizeArr();
        int[] step = tensor.getTileStepArr();
        long[] patch = new long[inputAxesOrder.length()];
        String seqSizeAxesUpper = inputAxesOrder.toUpperCase();
        seqSizeAxesUpper = seqSizeAxesUpper.replace("T", "B");
        String[] axesArr = tensor.getAxesOrder().toUpperCase().split("");
        for (int ii = 0; ii < axesArr.length; ++ii) {
            int ind = seqSizeAxesUpper.indexOf(axesArr[ii]);
            int size = (int)dims[ind];
            if (step[ii] != 0 && tiling) {
                patch[ii] = (int)Math.ceil((size + 2 * halo[ii]) / step[ii]) * step[ii];
                if (patch[ii] > (long)(3 * size) && patch[ii] - (long)step[ii] >= (long)min[ii]) {
                    patch[ii] = patch[ii] - (long)step[ii];
                }
                if (patch[ii] <= (long)(3 * size) || (int)Math.ceil((double)size / (double)step[ii]) * step[ii] < min[ii]) continue;
                patch[ii] = (int)Math.ceil((double)size / (double)step[ii]) * step[ii];
                continue;
            }
            if (step[ii] != 0 && !tiling) {
                patch[ii] = -1L;
                continue;
            }
            if (step[ii] == 0 && min[ii] == -1) {
                patch[ii] = size;
                continue;
            }
            if (step[ii] != 0) continue;
            patch[ii] = min[ii];
        }
        return patch;
    }

    public List<TileInfo> getOptimalTileSize(List<ImageInfo> inputInfo) {
        boolean tiling = this.descriptor.isTilingAllowed();
        ArrayList<TileInfo> firstIterationInputs = new ArrayList<TileInfo>();
        for (TensorSpec tt : this.descriptor.getInputTensors()) {
            ImageInfo im = inputInfo.stream().filter(ii -> ii.getTensorName().equals(tt.getName())).findFirst().orElse(null);
            if (im == null) {
                throw new IllegalArgumentException("No data was provided for input tensor: " + tt.getName());
            }
            long[] tileSize = this.getOptimalTileSize(tt, im.getAxesOrder(), im.getDimensions());
            firstIterationInputs.add(TileInfo.build(tt.getName(), im.getDimensions(), im.getAxesOrder(), tileSize, im.getAxesOrder()));
        }
        if (!tiling) {
            return firstIterationInputs;
        }
        List<TensorSpec> affectedTensors = this.descriptor.getOutputTensors().stream().filter(ot -> ot.getAxesInfo().getAxesList().stream().filter(ax -> ax.getReferenceTensor() != null).findFirst().orElse(null) != null).collect(Collectors.toList());
        ArrayList<TileInfo> secondIterationInputs = new ArrayList<TileInfo>();
        for (int i = 0; i < firstIterationInputs.size(); ++i) {
            TensorSpec tensor = this.descriptor.findInputTensor(((TileInfo)firstIterationInputs.get(i)).getName());
            if (!Arrays.stream(tensor.getTileStepArr()).allMatch(ii -> ii == 0)) continue;
            secondIterationInputs.add((TileInfo)firstIterationInputs.get(i));
        }
        if (firstIterationInputs.size() == secondIterationInputs.size()) {
            return secondIterationInputs;
        }
        List<Long> outputTotByteSizes = this.calculateByteSizeOfAffectedOutput(affectedTensors, firstIterationInputs);
        return this.checkOutputSize(firstIterationInputs, affectedTensors, outputTotByteSizes);
    }

    private List<TileInfo> checkOutputSize(List<TileInfo> inputs, List<TensorSpec> affected, List<Long> outByteSizes) {
        List totInPixels = inputs.stream().map(in -> Arrays.stream(in.getTileDims()).reduce(1L, (x, y) -> x * y)).collect(Collectors.toList());
        if (totInPixels.stream().filter(oo -> oo > 0x3000000L).findFirst().orElse(null) == null && outByteSizes.stream().filter(oo -> oo > Integer.MAX_VALUE).findFirst().orElse(null) == null) {
            return inputs;
        }
        List inRatio = totInPixels.stream().map(ss -> 5.0331648E7 / (double)ss.longValue()).collect(Collectors.toList());
        List outRatio = outByteSizes.stream().map(ss -> 2.147483647E9 / (double)ss.longValue()).collect(Collectors.toList());
        block0: while ((Double)Collections.min(inRatio) < 1.0) {
            TensorSpec tt;
            Integer argmin = null;
            List sortedIndices = IntStream.range(0, inRatio.size()).boxed().sorted(Comparator.comparing(inRatio::get)).collect(Collectors.toList());
            for (Integer ind : sortedIndices) {
                tt = this.descriptor.findInputTensor(inputs.get(ind).getName());
                if (Arrays.stream(tt.getTileStepArr()).allMatch(ii -> ii == 0)) continue;
                argmin = ind;
                break;
            }
            if (argmin == null) break;
            Double startingRatio = (Double)inRatio.get(argmin);
            TileInfo in2 = inputs.get(argmin);
            tt = this.descriptor.findInputTensor(in2.getName());
            int c = 0;
            for (String ax : in2.getTileAxesOrder().split("")) {
                Axis axis = tt.getAxesInfo().getAxis(ax);
                if (axis.getStep() == 0) continue;
                long nTot = (Long)totInPixels.get(argmin) / in2.getTileDims()[c];
                in2.getTileDims()[c] = (double)in2.getTileDims()[c] * (Double)inRatio.get(argmin) < (double)axis.getMin() && axis.getMin() > 1 ? (long)((int)Math.ceil(100.0 / (double)axis.getStep()) * axis.getStep()) : ((double)in2.getTileDims()[c] * (Double)inRatio.get(argmin) < (double)axis.getMin() ? (long)axis.getMin() : (long)(Math.floor(((double)in2.getTileDims()[c] * (Double)inRatio.get(argmin) - (double)axis.getMin()) / (double)axis.getStep()) * (double)axis.getStep() + (double)axis.getMin()));
                totInPixels.set(argmin, nTot * in2.getTileDims()[c]);
                inRatio = totInPixels.stream().map(ss -> 5.0331648E7 / (double)ss.longValue()).collect(Collectors.toList());
                if (startingRatio == inRatio.get(argmin)) continue block0;
            }
        }
        if ((Double)Collections.min(outRatio = (outByteSizes = this.calculateByteSizeOfAffectedOutput(affected, null)).stream().map(ss -> 2.147483647E9 / (double)ss.longValue()).collect(Collectors.toList())) < 1.0 && (Double)Collections.min(inRatio) < 1.0) {
            throw new IllegalArgumentException("The input and/or ouput dimensions of the tensors specified by the current model are to big. JDLL is not able to run them.");
        }
        while ((Double)Collections.min(outRatio) < 1.0) {
            ArrayList finalOutRatio = new ArrayList(outRatio);
            int argmin = IntStream.range(0, finalOutRatio.size()).reduce((i, j) -> (Double)finalOutRatio.get(i) < (Double)finalOutRatio.get(j) ? i : j).getAsInt();
            Double oldRatio = (Double)outRatio.get(argmin);
            TensorSpec tt = this.descriptor.getOutputTensors().get(argmin);
            for (Axis ax : tt.getAxesInfo().getAxesList()) {
                if (ax.getReferenceTensor() == null) continue;
                TensorSpec inputT = this.descriptor.findInputTensor(ax.getReferenceTensor());
                TileInfo im = inputs.stream().filter(in -> in.getName().equals(inputT.getName())).findFirst().orElse(null);
                String refAxis = ax.getReferenceAxis();
                int index = im.getTileAxesOrder().indexOf(refAxis);
                Axis inAx = inputT.getAxesInfo().getAxis(refAxis);
                long size = im.getTileDims()[index];
                im.getTileDims()[index] = (double)size * (Double)outRatio.get(argmin) < (double)inAx.getMin() && inAx.getMin() > 1 ? (long)((int)Math.ceil(100.0 / (double)inAx.getStep()) * inAx.getStep()) : ((double)size * (Double)outRatio.get(argmin) < (double)inAx.getMin() ? (long)inAx.getMin() : (long)(Math.floor(((double)size * (Double)outRatio.get(argmin) - (double)inAx.getMin()) / (double)inAx.getStep()) * (double)inAx.getStep() + (double)inAx.getMin()));
                double change = ((double)size * ax.getScale() + 2.0 * ax.getOffset()) / ((double)im.getTileDims()[index] * ax.getScale() + 2.0 * ax.getOffset());
                outRatio.set(argmin, (Double)outRatio.get(argmin) * change);
                if (!((Double)outRatio.get(argmin) > 1.0)) continue;
                break;
            }
            if (outRatio.get(argmin) != oldRatio) continue;
            break;
        }
        if ((Double)Collections.min(outRatio) < 1.0) {
            throw new IllegalArgumentException("Due to the model specifications, the size of one of the output tensors exceeds the limit of tensor size in JDLL: 2147483647");
        }
        return null;
    }

    public void getTilesForNPixels(String tensorName, long[] dims, String inputAxesOrder) {
    }

    public void getForNTiles(int nTiles, String tensorName, long[] dims, String inputAxesOrder) {
    }

    private List<Long> calculateByteSizeOfAffectedOutput(List<TensorSpec> outputTensors, List<TileInfo> inputSize) {
        if (outputTensors == null || outputTensors.size() == 0) {
            return new ArrayList<Long>();
        }
        List outTiles = outputTensors.stream().map(t -> new long[t.getAxesInfo().getAxesList().size()]).collect(Collectors.toList());
        for (int i = 0; i < outputTensors.size(); ++i) {
            int j;
            TensorSpec tt = outputTensors.get(i);
            ArrayList<String> referencesList = new ArrayList<String>();
            for (j = 0; j < outputTensors.get(i).getAxesInfo().getAxesList().size(); ++j) {
                Axis ax = tt.getAxesInfo().getAxesList().get(j);
                String refName = ax.getReferenceTensor();
                if (refName == null && ax.getMin() != 0) {
                    ((long[])outTiles.get((int)i))[j] = ax.getMin();
                    continue;
                }
                if (refName == null) {
                    ((long[])outTiles.get((int)i))[j] = -1L;
                    continue;
                }
                referencesList.add(refName);
                String refAxisStr = ax.getReferenceAxis();
                TensorSpec refTensor = this.descriptor.findInputTensor(refName);
                long[] refTileSize = ((TileInfo)inputSize.stream().filter(tile -> tile.getName().equals(refName)).findFirst().orElse(null)).getTileDims();
                String axesOrder = refTensor.getAxesOrder();
                ((long[])outTiles.get((int)i))[j] = (long)((double)refTileSize[axesOrder.indexOf(refAxisStr)] * ax.getScale() + 2.0 * ax.getOffset());
            }
            if (referencesList.stream().distinct().count() != 1L) {
                throw new IllegalArgumentException("Model specs too complex for JDLL. Please contact the team and create and issue attaching the rdf.yaml file so we can troubleshoot at: https://github.com/bioimage-io/JDLL/issues");
            }
            for (j = 0; j < outputTensors.get(i).getAxesInfo().getAxesList().size(); ++j) {
                if (((long[])outTiles.get(i))[j] != -1L) continue;
                TensorSpec refInput = this.descriptor.findInputTensor((String)referencesList.get(0));
                int ind = refInput.getAxesOrder().indexOf(outputTensors.get(i).getAxesInfo().getAxesList().get(j).getAxis());
                if (ind == -1) {
                    throw new IllegalArgumentException("Model specs too complex for JDLL. Please contact the team and create and issue attaching the rdf.yaml file so we can troubleshoot at: https://github.com/bioimage-io/JDLL/issues");
                }
                long[] refTileSize = ((TileInfo)inputSize.stream().filter(tile -> tile.getName().equals(referencesList.get(0))).findFirst().orElse(null)).getTileDims();
                ((long[])outTiles.get((int)i))[j] = refTileSize[ind];
            }
        }
        List<Long> flatSizes = outTiles.stream().map(arr -> {
            long a = 1L;
            for (long l : arr) {
                a *= l;
            }
            return a;
        }).collect(Collectors.toList());
        for (int i = 0; i < flatSizes.size(); ++i) {
            if (outputTensors.get(i).getDataType().toLowerCase().equals("float32") || outputTensors.get(i).getDataType().toLowerCase().equals("int32") || outputTensors.get(i).getDataType().toLowerCase().equals("uint32")) {
                flatSizes.set(i, flatSizes.get(i) * 4L);
                continue;
            }
            if (outputTensors.get(i).getDataType().toLowerCase().equals("int16") || outputTensors.get(i).getDataType().toLowerCase().equals("uint16")) {
                flatSizes.set(i, flatSizes.get(i) * 2L);
                continue;
            }
            if (!outputTensors.get(i).getDataType().toLowerCase().equals("int64") && !outputTensors.get(i).getDataType().toLowerCase().equals("float64")) continue;
            flatSizes.set(i, flatSizes.get(i) * 8L);
        }
        return flatSizes;
    }
}

