/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.transformations;

import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.transformations.AbstractTensorTransformation;
import io.bioimage.modelrunner.transformations.TensorTransformation;
import java.util.Arrays;
import java.util.List;
import net.imglib2.Cursor;
import net.imglib2.IterableInterval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

/**
 * Rescales tensor values using a pair of percentiles:
 * {@code out = (in - pMin) / (pMax - pMin + eps)}.
 *
 * <p>If no axes are configured (or none of the configured axes exist in the tensor), the
 * percentiles are computed over the whole tensor. Otherwise they are computed separately for
 * every index combination of the tensor axes that are NOT listed in {@link #setAxes(Object)}
 * (for example, axes {@code "xy"} on an {@code "xyc"} tensor rescales each channel with its own
 * percentiles). The {@code b} axis is never iterated separately.
 */
public class ScaleRangeTransformation
extends AbstractTensorTransformation {
    private static final String NAME = "scale_range";
    /** Lower percentile stored as a fraction (the setters divide the given percent by 100). */
    private double minPercentile = 0.0;
    /** Upper percentile stored as a fraction (the setters divide the given percent by 100). */
    private double maxPercentile = 1.0;
    /** Axes (single-letter codes) over which the percentiles are computed; null means all. */
    private String axes;
    // NOTE(review): stored but never read inside this class — confirm whether it is needed.
    private String tensorName;
    /** Added to the denominator to avoid division by zero. */
    private double eps = Math.pow(10.0, -6.0);

    /** Creates the transformation in {@code PER_SAMPLE} mode with percentiles 0 and 100. */
    public ScaleRangeTransformation() {
        super(NAME);
        this.mode = TensorTransformation.Mode.PER_SAMPLE;
    }

    /**
     * Sets the epsilon added to the denominator.
     *
     * @param eps any {@link Number}, or a {@link String} parseable as a double
     * @throws IllegalArgumentException if {@code eps} is null or of another type
     */
    public void setEps(Object eps) {
        if (eps instanceof Number) {
            this.eps = ((Number) eps).doubleValue();
        } else if (eps instanceof String) {
            this.eps = Double.parseDouble((String) eps);
        } else {
            throw new IllegalArgumentException("'eps' parameter has to be either an instance of "
                + Number.class + " or " + String.class
                + ". The provided argument is an instance of: " + describeType(eps));
        }
    }

    /**
     * Sets the lower percentile, given in percent (0-100).
     *
     * @param minPercentile any {@link Number}, or a numeric {@link String}
     * @throws IllegalArgumentException if the argument is null or of another type
     */
    public void setMinPercentile(Object minPercentile) {
        this.minPercentile = percentToFraction(minPercentile, "minPercentile");
    }

    /**
     * Sets the upper percentile, given in percent (0-100).
     *
     * @param maxPercentile any {@link Number}, or a numeric {@link String}
     * @throws IllegalArgumentException if the argument is null or of another type
     */
    public void setMaxPercentile(Object maxPercentile) {
        this.maxPercentile = percentToFraction(maxPercentile, "maxPercentile");
    }

    /** Converts a percent value (Number or numeric String) into a fraction (value / 100). */
    private static double percentToFraction(Object value, String paramName) {
        double percent;
        if (value instanceof Number) {
            percent = ((Number) value).doubleValue();
        } else if (value instanceof String) {
            percent = Double.parseDouble((String) value);
        } else {
            throw new IllegalArgumentException("'" + paramName + "' parameter has to be either an instance of "
                + Number.class + " or " + String.class
                + ". The provided argument is an instance of: " + describeType(value));
        }
        return percent / 100.0;
    }

    /** Reference tensors are not supported yet; only a warning is printed. */
    public void setReferenceTensor(Object refTensor) {
        System.err.println("JDLL still does not support this processing. Please create an issue at https://github.com/bioimage-io/JDLL/issues referencing this model.");
    }

    /**
     * Sets the axes over which the percentiles are computed. The name {@code "channel"} is
     * mapped to {@code "c"}.
     *
     * @param axes a {@link String}, a {@code String[]} or a {@link List} of Strings
     * @throws IllegalArgumentException if the argument (or any list element) has another type
     */
    public void setAxes(Object axes) {
        if (axes instanceof String) {
            this.axes = normalizeAxisName((String) axes);
        } else if (axes instanceof String[]) {
            setAxes(Arrays.asList((String[]) axes));
        } else if (axes instanceof List) {
            // Build locally so a rejected element leaves the previous value untouched.
            StringBuilder joined = new StringBuilder();
            for (Object ax : (List<?>) axes) {
                if (!(ax instanceof String)) {
                    throw new IllegalArgumentException("JDLL does not currently support this axes format. Please write an issue attaching the rdf.yaml file at: https://github.com/bioimage-io/JDLL/issues");
                }
                joined.append(normalizeAxisName((String) ax));
            }
            this.axes = joined.toString();
        } else {
            throw new IllegalArgumentException("'axes' parameter has to be an instance of " + String.class
                + ", of a String array or of a List of Strings. The provided argument is " + describeType(axes));
        }
    }

    /** Alias of {@link #setAxes(Object)}. */
    public void setAxis(Object axes) {
        setAxes(axes);
    }

    /** Maps the long axis name {@code "channel"} to its one-letter code {@code "c"}. */
    private static String normalizeAxisName(String axis) {
        return "channel".equals(axis) ? "c" : axis;
    }

    /** Null-safe description of an argument's runtime class, used in error messages. */
    private static String describeType(Object value) {
        return value == null ? "null" : String.valueOf(value.getClass());
    }

    /**
     * Sets the tensor name.
     *
     * @throws IllegalArgumentException if {@code tensorName} is not a {@link String}
     */
    public void setTensorName(Object tensorName) {
        if (!(tensorName instanceof String)) {
            throw new IllegalArgumentException("'tensorName' parameter has to be an instance of " + String.class
                + ". The provided argument is " + describeType(tensorName));
        }
        this.tensorName = (String) tensorName;
    }

    /** Copies the input into a new float tensor (via {@code makeOutput}) and rescales that copy. */
    @Override
    public <R extends RealType<R> & NativeType<R>> Tensor<FloatType> apply(Tensor<R> input) {
        Tensor<FloatType> output = this.makeOutput(input);
        this.applyInPlace(output);
        return output;
    }

    /** Rescales {@code input} in place, either globally or once per plane (see class docs). */
    @Override
    public <R extends RealType<R> & NativeType<R>> void applyInPlace(Tensor<R> input) {
        String axesOrder = input.getAxesOrderString();
        if (this.axes == null) {
            this.globalScale(input);
            return;
        }
        String iteratedAxes = axesToIterate(axesOrder);
        // Fall back to global scaling when nothing is left to iterate, or when none of the
        // configured axes exist in the tensor (every non-batch axis would be iterated).
        boolean global = iteratedAxes.isEmpty()
            || axesOrder.replace("b", "").length() == iteratedAxes.length();
        if (global) {
            this.globalScale(input);
        } else {
            this.axesScale(input, iteratedAxes);
        }
    }

    /**
     * Returns the tensor axes (in tensor order, original case) that are neither configured
     * reduction axes nor the batch axis {@code b}. Requires {@code this.axes != null}.
     */
    private String axesToIterate(String axesOrder) {
        String reduced = this.axes.toLowerCase();
        StringBuilder iterated = new StringBuilder();
        for (char c : axesOrder.toCharArray()) {
            String ax = String.valueOf(c).toLowerCase();
            // NOTE(review): 'b' is never iterated, so in PER_SAMPLE mode all batch samples share
            // one pair of percentiles — confirm this is intended.
            if (ax.equals("b") || reduced.contains(ax)) {
                continue;
            }
            iterated.append(c);
        }
        return iterated.toString();
    }

    /** Rescales the whole tensor with percentiles computed over all of its values. */
    private <R extends RealType<R> & NativeType<R>> void globalScale(Tensor<R> output) {
        this.scaleWithOwnPercentiles(output.getData());
    }

    /** Computes both percentiles of {@code rai} (one sort) and rescales it in place. Empty input is a no-op. */
    private <R extends RealType<R> & NativeType<R>> void scaleWithOwnPercentiles(RandomAccessibleInterval<R> rai) {
        double[] sorted = sortedValues(rai);
        if (sorted.length == 0) {
            return;
        }
        double minPercentileVal = valueAtPercentile(sorted, this.minPercentile);
        double maxPercentileVal = valueAtPercentile(sorted, this.maxPercentile);
        this.scaleRange(rai, maxPercentileVal, minPercentileVal);
    }

    /**
     * Copies every value of {@code rai} into an ascending-sorted double array.
     *
     * @throws IllegalArgumentException if the interval has more elements than a Java array can hold
     */
    private static <R extends RealType<R>> double[] sortedValues(RandomAccessibleInterval<R> rai) {
        IterableInterval<R> flatImage = Views.iterable(rai);
        long flatSize = flatImage.size();
        if (flatSize > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Cannot compute percentiles over " + flatSize
                + " elements; at most " + Integer.MAX_VALUE + " are supported.");
        }
        double[] values = new double[(int) flatSize];
        int count = 0;
        for (R pixel : flatImage) {
            values[count++] = pixel.getRealDouble();
        }
        Arrays.sort(values);
        return values;
    }

    /**
     * Picks the element at index {@code floor(length * fraction)} of a non-empty sorted array,
     * clamped to the valid index range (no interpolation).
     */
    private static double valueAtPercentile(double[] sorted, double fraction) {
        int pos = (int) ((double) sorted.length * fraction);
        pos = Math.max(0, Math.min(pos, sorted.length - 1));
        return sorted[pos];
    }

    /**
     * Rescales each plane obtained by fixing one index along every axis in
     * {@code axesOfInterest}, using percentiles computed within that plane only.
     */
    private <R extends RealType<R> & NativeType<R>> void axesScale(Tensor<R> output, String axesOfInterest) {
        RandomAccessibleInterval<R> data = output.getData();
        String axesOrder = output.getAxesOrderString();
        long[] dims = data.dimensionsAsLongArray();
        int[] iteratedDims = new int[axesOfInterest.length()];
        long[] iteratedSizes = new long[axesOfInterest.length()];
        for (int i = 0; i < iteratedDims.length; ++i) {
            iteratedDims[i] = axesOrder.indexOf(axesOfInterest.charAt(i));
            iteratedSizes[i] = dims[iteratedDims[i]];
        }
        // Each plane spans the full extent of the reduced axes and has size 1 along iterated axes.
        long[] planeSize = dims.clone();
        for (int d : iteratedDims) {
            planeSize[d] = 1L;
        }
        long[] offset = new long[dims.length];
        for (long[] point : getAllCombinations(iteratedSizes)) {
            for (int i = 0; i < point.length; ++i) {
                offset[iteratedDims[i]] = point[i];
            }
            RandomAccessibleInterval<R> plane = Views.offsetInterval(data, offset, planeSize);
            this.scaleWithOwnPercentiles(plane);
        }
    }

    /**
     * Enumerates every index tuple of a grid with the given per-axis sizes; the first axis varies
     * fastest. A zero size yields no tuples.
     */
    private static long[][] getAllCombinations(long[] sizes) {
        long total = 1L;
        for (long size : sizes) {
            total *= size;
        }
        long[][] allPoints = new long[(int) total][sizes.length];
        for (int i = 0; i < allPoints.length; ++i) {
            // Mixed-radix decoding of i, least significant digit = axis 0.
            long rest = i;
            for (int j = 0; j < sizes.length; ++j) {
                allPoints[i][j] = rest % sizes[j];
                rest /= sizes[j];
            }
        }
        return allPoints;
    }

    public static void main(String[] args) {
        ScaleRangeTransformation.test1();
        ScaleRangeTransformation.test2();
    }

    /** Manual smoke test: global scaling of a 3x3 "xy" float image. */
    public static void test1() {
        float[] arr = new float[9];
        for (int i = 0; i < arr.length; ++i) {
            arr[i] = i;
        }
        ArrayImg rai = ArrayImgs.floats((float[])arr, (long[])new long[]{3L, 3L});
        ScaleRangeTransformation preprocessing = new ScaleRangeTransformation();
        Tensor tt = Tensor.build("name", "xy", rai);
        preprocessing.applyInPlace(tt);
        System.out.print(true);
    }

    /** Manual smoke test: per-channel scaling of a 3x3x2 "xyc" float image. */
    public static void test2() {
        float[] arr = new float[18];
        for (int i = 0; i < arr.length; ++i) {
            arr[i] = i;
        }
        ArrayImg rai = ArrayImgs.floats((float[])arr, (long[])new long[]{3L, 3L, 2L});
        ScaleRangeTransformation preprocessing = new ScaleRangeTransformation();
        preprocessing.setAxes("xy");
        preprocessing.setMaxPercentile(99);
        preprocessing.setMinPercentile(1);
        Tensor tt = Tensor.build("name", "xyc", rai);
        preprocessing.applyInPlace(tt);
        System.out.print(true);
    }

    /**
     * Rescales {@code rai} in place as {@code (v - min) / (max - min + eps)}. For integer pixel
     * types the result is floored before being written back.
     *
     * @param rai              pixels to rescale
     * @param maxPercentileVal value of the upper percentile
     * @param minPercentileVal value of the lower percentile
     */
    public <R extends RealType<R> & NativeType<R>> void scaleRange(RandomAccessibleInterval<R> rai, double maxPercentileVal, double minPercentileVal) {
        final double denominator = maxPercentileVal - minPercentileVal + this.eps;
        final boolean integerPixels = Util.getTypeFromInterval(rai) instanceof IntegerType;
        LoopBuilder.setImages(rai).multiThreaded().forEachPixel(px -> {
            double scaled = (px.getRealDouble() - minPercentileVal) / denominator;
            px.setReal(integerPixels ? Math.floor(scaled) : scaled);
        });
    }
}

