/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.transformations;

import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.transformations.AbstractTensorTransformation;
import io.bioimage.modelrunner.transformations.TensorTransformation;
import java.util.ArrayList;
import java.util.List;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.IntegerType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

/**
 * Tensor transformation computing {@code (value - mean) / (std + eps)} for every pixel.
 * <p>
 * In mode {@code FIXED} the {@code mean} and {@code std} parameters must be provided by the
 * caller, either as single values (applied to the whole tensor) or as arrays (one value per
 * slice, see below). In any other mode they must NOT be provided; they are computed from the
 * data instead.
 * <p>
 * {@code axes} names the axes that are normalised jointly. Every combination of indices along
 * the remaining non-batch axes of the tensor defines a separate slice with its own mean/std.
 * When {@code axes} is not set, or covers all non-batch axes, or none of them, the whole tensor
 * is normalised with a single mean/std.
 */
public class ZeroMeanUnitVarianceTransformation
extends AbstractTensorTransformation {
    // NOTE(review): spelling differs from "zero_mean_unit_variance"; kept unchanged because the
    // value is passed to super(...) and external code may match on it.
    private static final String NAME = "zero_mean_unit_variace";
    private Double meanDouble;
    private Double stdDouble;
    private double[] meanArr;
    private double[] stdArr;
    private String axes;
    private double eps = 1e-6;
    private static final String FIXED_MODE_ERR = "If the mode is 'fixed', the parameters 'mean' and 'std need to be specified";
    private static final String NOT_FIXED_MODE_ERR = "Only the mode 'fixed' requires providing the 'std' and 'mean parameters.";

    /** Creates the transformation in mode {@code PER_SAMPLE}. */
    public ZeroMeanUnitVarianceTransformation() {
        super(NAME);
        this.mode = TensorTransformation.Mode.PER_SAMPLE;
    }

    /**
     * Sets the small constant added to the standard deviation to avoid division by zero.
     *
     * @param eps any {@link Number}, or a {@link String} parseable as a double
     * @throws IllegalArgumentException if {@code eps} has an unsupported type (or is null)
     */
    public void setEps(Object eps) {
        if (eps instanceof Number) {
            this.eps = ((Number) eps).doubleValue();
        } else if (eps instanceof String) {
            this.eps = Double.parseDouble((String) eps);
        } else {
            throw new IllegalArgumentException("'eps' parameter has to be either an instance of " + Number.class + " or " + String.class + ". The provided argument is an instance of: " + typeOf(eps));
        }
    }

    /**
     * Sets the mean used in mode {@code FIXED}.
     *
     * @param mean a {@link Number} or numeric {@link String} (single value), or a
     *             {@code double[]} / {@link List} of numbers (one value per slice)
     * @throws IllegalArgumentException if {@code mean} has an unsupported type or content
     */
    public void setMean(Object mean) {
        Double scalar = parseScalar(mean);
        double[] array = scalar == null ? parseArray(mean, "mean") : null;
        if (scalar == null && array == null) {
            throw new IllegalArgumentException("'mean' parameter has to be either an instance of " + Number.class + ", " + String.class + ", double[] or " + List.class + ". The provided argument is an instance of: " + typeOf(mean));
        }
        // Exactly one representation is kept so a later call fully overrides an earlier one.
        this.meanDouble = scalar;
        this.meanArr = array;
    }

    /**
     * Sets the standard deviation used in mode {@code FIXED}.
     *
     * @param std a {@link Number} or numeric {@link String} (single value), or a
     *            {@code double[]} / {@link List} of numbers (one value per slice)
     * @throws IllegalArgumentException if {@code std} has an unsupported type or content
     */
    public void setStd(Object std) {
        Double scalar = parseScalar(std);
        double[] array = scalar == null ? parseArray(std, "std") : null;
        if (scalar == null && array == null) {
            throw new IllegalArgumentException("'std' parameter has to be either an instance of " + Number.class + ", " + String.class + ", double[] or " + List.class + ". The provided argument is an instance of: " + typeOf(std));
        }
        this.stdDouble = scalar;
        this.stdArr = array;
    }

    /** Returns {@code value} as a double if it is a Number or String, otherwise null. */
    private static Double parseScalar(Object value) {
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        if (value instanceof String) {
            return Double.valueOf((String) value);
        }
        return null;
    }

    /**
     * Returns a copy of {@code value} as a double array if it is a double[] or a List of
     * numbers, otherwise null. Throws if it is a List with non-numeric elements.
     */
    private static double[] parseArray(Object value, String param) {
        if (value instanceof double[]) {
            return ((double[]) value).clone();
        }
        if (!(value instanceof List)) {
            return null;
        }
        List<?> list = (List<?>) value;
        double[] result = new double[list.size()];
        int c = 0;
        for (Object elem : list) {
            if (elem instanceof Number) {
                result[c++] = ((Number) elem).doubleValue();
            } else if (elem instanceof List) {
                throw new IllegalArgumentException("'" + param + "' parameter cannot be a List containing another List. At the moment, only transformations of planes is allowed.");
            } else {
                throw new IllegalArgumentException("If the '" + param + "' parameter is an array, its elements have to be instances of " + Number.class + ". The provided List contains instances of: " + typeOf(elem));
            }
        }
        return result;
    }

    /** Class description used in error messages; tolerates null. */
    private static String typeOf(Object value) {
        return value == null ? "null" : String.valueOf(value.getClass());
    }

    /**
     * Sets the axes normalised jointly. The axis name "channel" is mapped to "c".
     *
     * @param axes a String (e.g. "xy"), a String[] or a List of Strings
     * @throws IllegalArgumentException if {@code axes} has an unsupported type or content
     */
    public void setAxes(Object axes) {
        if (axes instanceof String) {
            this.axes = toAxisLetters((String) axes);
        } else if (axes instanceof List || axes instanceof String[]) {
            Iterable<?> items = axes instanceof List ? (List<?>) axes : java.util.Arrays.asList((String[]) axes);
            StringBuilder joined = new StringBuilder();
            for (Object ax : items) {
                if (!(ax instanceof String)) {
                    throw new IllegalArgumentException("JDLL does not currently support this axes format. Please write an issue attaching the rdf.yaml file at: https://github.com/bioimage-io/JDLL/issues");
                }
                joined.append(toAxisLetters((String) ax));
            }
            this.axes = joined.toString();
        } else {
            throw new IllegalArgumentException("'axes' parameter has to be an instance of " + String.class + ", of a String array or of a List of Strings. The provided argument is " + typeOf(axes));
        }
    }

    /** Maps the verbose axis name "channel" to "c"; other names are returned unchanged. */
    private static String toAxisLetters(String ax) {
        return "channel".equals(ax) ? "c" : ax;
    }

    /** Alias of {@link #setAxes(Object)}. */
    public void setAxis(Object axes) {
        this.setAxes(axes);
    }

    /**
     * Validates the parameter combination for mode {@code FIXED}; other modes need no arguments.
     *
     * @throws IllegalArgumentException if mean/std are missing, of different kinds (scalar vs.
     *         array), of different lengths, or given as arrays without {@code axes}
     */
    public void checkRequiredArgs() {
        if (this.mode != TensorTransformation.Mode.FIXED) {
            return;
        }
        if (this.meanArr == null && this.meanDouble == null) {
            throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, NAME, "mean") + System.lineSeparator() + "If 'mode' parameter equals 'fixed', the 'mean' argument should be provided too.");
        }
        if (this.stdArr == null && this.stdDouble == null) {
            throw new IllegalArgumentException(String.format(DEFAULT_MISSING_ARG_ERR, NAME, "std") + System.lineSeparator() + "If 'mode' parameter equals 'fixed', the 'std' argument should be provided too.");
        }
        if ((this.stdDouble == null) != (this.meanDouble == null)) {
            throw new IllegalArgumentException("Both arguments 'mean' and 'std' need to be of the same type. Either a single value or an array.");
        }
        if (this.meanArr != null && this.axes == null) {
            throw new IllegalArgumentException("If 'mean' and 'std' are provided as arrays and 'mode' is 'fixed', the corresponding 'axes' argument should be provided too.");
        }
        if (this.meanArr != null && this.stdArr != null && this.meanArr.length != this.stdArr.length) {
            throw new IllegalArgumentException("The arrays 'mean' and 'std' must have the same length, but have lengths " + this.meanArr.length + " and " + this.stdArr.length + ".");
        }
    }

    @Override
    public <R extends RealType<R> & NativeType<R>> Tensor<FloatType> apply(Tensor<R> input) {
        this.checkRequiredArgs();
        Tensor<FloatType> output = this.makeOutput(input);
        this.applyInPlace(output);
        return output;
    }

    /**
     * Normalises {@code input} in place, choosing global or per-slice normalisation from
     * {@code axes} and fixed or data-derived statistics from the mode.
     *
     * @throws IllegalArgumentException on an invalid parameter combination, or when more than
     *         two axes are requested for joint per-slice normalisation
     */
    @Override
    public <R extends RealType<R> & NativeType<R>> void applyInPlace(Tensor<R> input) {
        this.checkRequiredArgs();
        String axesOrder = input.getAxesOrderString();
        // Non-batch axes that are NOT normalised jointly; each index combination along them is
        // its own slice with its own mean/std.
        StringBuilder selected = new StringBuilder();
        for (char ax : axesOrder.toCharArray()) {
            String lower = String.valueOf(ax).toLowerCase();
            boolean skip = this.axes == null || this.axes.toLowerCase().contains(lower) || lower.equals("b");
            if (!skip) {
                selected.append(ax);
            }
        }
        String selectedAxes = selected.toString();
        boolean fixed = this.mode == TensorTransformation.Mode.FIXED;
        // Global when no joint axes are given, they cover every non-batch axis, or none is present.
        boolean global = this.axes == null || selectedAxes.isEmpty() || axesOrder.replace("b", "").length() == selectedAxes.length();

        if (global && fixed) {
            if (this.meanDouble == null && this.meanArr == null) {
                throw new IllegalArgumentException(FIXED_MODE_ERR);
            }
            if (this.meanDouble == null) {
                throw new IllegalArgumentException("The parameters 'mean' and 'std' cannot be arrays with the introduced 'axes'.");
            }
            this.fixedModeGlobalMeanStd(input);
        } else if (global) {
            if (this.meanDouble != null || this.meanArr != null) {
                throw new IllegalArgumentException(NOT_FIXED_MODE_ERR);
            }
            this.notFixedModeGlobalMeanStd(input);
        } else if (this.axes.length() > 2) {
            throw new IllegalArgumentException("At the moment, only allowed scaling of planes.");
        } else if (fixed) {
            if (this.meanDouble == null && this.meanArr == null) {
                throw new IllegalArgumentException(FIXED_MODE_ERR);
            }
            if (this.meanDouble != null) {
                throw new IllegalArgumentException("The parameters 'mean' and ' std' have to be arrays with the introduced 'axes'.");
            }
            this.fixedAxesMeanStd(input, selectedAxes);
        } else {
            if (this.meanDouble != null || this.meanArr != null) {
                throw new IllegalArgumentException(NOT_FIXED_MODE_ERR);
            }
            this.notFixedAxesMeanStd(input, selectedAxes);
        }
    }

    /** Normalises the whole tensor with the user-provided scalar mean and std. */
    private <R extends RealType<R> & NativeType<R>> void fixedModeGlobalMeanStd(Tensor<R> output) {
        this.zeroMeanUnitVariance(output.getData(), this.meanDouble, this.stdDouble);
    }

    /** Normalises each slice along {@code axesOfInterest} with its own computed mean/std. */
    private <R extends RealType<R> & NativeType<R>> void notFixedAxesMeanStd(Tensor<R> output, String axesOfInterest) {
        this.normalizeSlices(output, axesOfInterest, false);
    }

    /** Normalises each slice along {@code axesOfInterest} with the user-provided mean/std arrays. */
    private <R extends RealType<R> & NativeType<R>> void fixedAxesMeanStd(Tensor<R> output, String axesOfInterest) {
        this.normalizeSlices(output, axesOfInterest, true);
    }

    /**
     * Splits the tensor into slices by fixing one index along each axis of {@code iteratedAxes}
     * (all other axes keep their full range) and normalises each slice independently.
     * Slices are enumerated with the first iterated axis varying fastest; in fixed mode the
     * k-th slice uses {@code meanArr[k]} / {@code stdArr[k]}.
     */
    private <R extends RealType<R> & NativeType<R>> void normalizeSlices(Tensor<R> output, String iteratedAxes, boolean useFixedStats) {
        RandomAccessibleInterval<R> data = output.getData();
        String axesOrder = output.getAxesOrderString();
        int nDims = data.numDimensions();
        long[] min = new long[nDims];
        long[] max = new long[nDims];
        for (int d = 0; d < nDims; ++d) {
            min[d] = data.min(d);
            max[d] = data.max(d);
        }
        int[] iterated = new int[iteratedAxes.length()];
        long[] sizes = new long[iterated.length];
        for (int k = 0; k < iterated.length; ++k) {
            iterated[k] = axesOrder.indexOf(iteratedAxes.charAt(k));
            sizes[k] = data.dimension(iterated[k]);
        }
        long[][] points = getAllCombinations(sizes);
        if (useFixedStats && (this.meanArr.length != points.length || this.stdArr.length != points.length)) {
            throw new IllegalArgumentException("Expected " + points.length + " values for 'mean' and 'std' (one per combination of the axes '" + iteratedAxes + "'), but got " + this.meanArr.length + " and " + this.stdArr.length + ".");
        }
        for (int c = 0; c < points.length; ++c) {
            long[] sliceMin = min.clone();
            long[] sliceMax = max.clone();
            for (int k = 0; k < iterated.length; ++k) {
                sliceMin[iterated[k]] = min[iterated[k]] + points[c][k];
                sliceMax[iterated[k]] = sliceMin[iterated[k]];
            }
            IntervalView<R> slice = Views.interval(data, sliceMin, sliceMax);
            double mean;
            double std;
            if (useFixedStats) {
                mean = this.meanArr[c];
                std = this.stdArr[c];
            } else {
                // NOTE(review): the batch axis is part of every slice, so per-sample statistics
                // pool all samples of a batch together — confirm this is intended for batch > 1.
                float[] meanStd = meanStd(slice);
                mean = meanStd[0];
                std = meanStd[1];
            }
            this.zeroMeanUnitVariance(slice, mean, std);
        }
    }

    /** Normalises the whole tensor with the mean/std computed over all of its pixels. */
    private <R extends RealType<R> & NativeType<R>> void notFixedModeGlobalMeanStd(Tensor<R> output) {
        float[] meanStd = ZeroMeanUnitVarianceTransformation.meanStd(output.getData());
        float mean = meanStd[0];
        float std = meanStd[1];
        this.zeroMeanUnitVariance(output.getData(), mean, std);
    }

    /**
     * Computes the mean and the population standard deviation (divides by n) of all pixels.
     *
     * @return {@code {mean, std}} as floats
     * @throws IllegalArgumentException if {@code rai} contains no pixels
     */
    public static <R extends RealType<R> & NativeType<R>> float[] meanStd(RandomAccessibleInterval<R> rai) {
        double sum = 0.0;
        long n = 0L;
        for (R p : Views.iterable(rai)) {
            sum += p.getRealDouble();
            ++n;
        }
        if (n < 1L) {
            throw new IllegalArgumentException("Tensor must contain at least 1 pixel, got " + n);
        }
        double mean = sum / (double) n;
        // Second pass over deviations is more stable than the sum-of-squares shortcut.
        double sumdx2 = 0.0;
        for (R p : Views.iterable(rai)) {
            double dx = p.getRealDouble() - mean;
            sumdx2 += dx * dx;
        }
        double std = Math.sqrt(sumdx2 / (double) n);
        return new float[]{(float) mean, (float) std};
    }

    /**
     * Enumerates every index tuple in the box {@code [0, arr[0]) x [0, arr[1]) x ...}, with the
     * first coordinate varying fastest. An empty {@code arr} yields one empty tuple; any zero
     * extent yields no tuples.
     *
     * @throws IllegalArgumentException if the number of tuples does not fit in an array
     */
    private static long[][] getAllCombinations(long[] arr) {
        long total = 1L;
        for (long size : arr) {
            total = Math.multiplyExact(total, size);
        }
        if (total > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("Too many slices to enumerate: " + total);
        }
        long[][] allPoints = new long[(int) total][];
        long[] current = new long[arr.length];
        for (int i = 0; i < total; ++i) {
            allPoints[i] = current.clone();
            // Odometer increment: bump the fastest coordinate, carrying into the next on overflow.
            for (int j = 0; j < arr.length; ++j) {
                if (++current[j] < arr[j]) {
                    break;
                }
                current[j] = 0L;
            }
        }
        return allPoints;
    }

    public static void main(String[] args) {
        ZeroMeanUnitVarianceTransformation.test2();
        ZeroMeanUnitVarianceTransformation.test3();
    }

    public static void test1() {
        float[] arr = new float[9];
        for (int i = 0; i < arr.length; ++i) {
            arr[i] = i;
        }
        ZeroMeanUnitVarianceTransformation preprocessing = new ZeroMeanUnitVarianceTransformation();
        preprocessing.setMean(4);
        preprocessing.setStd(4);
        preprocessing.setMode("fixed");
        ArrayImg rai = ArrayImgs.floats((float[])arr, (long[])new long[]{3L, 3L});
        Tensor tt = Tensor.build("name", "xy", rai);
        preprocessing.applyInPlace(tt);
        System.out.print(true);
    }

    public static void test2() {
        float[] arr = new float[18];
        for (int i = 0; i < arr.length; ++i) {
            arr[i] = i;
        }
        ArrayImg rai = ArrayImgs.floats((float[])arr, (long[])new long[]{3L, 3L, 2L});
        ZeroMeanUnitVarianceTransformation preprocessing = new ZeroMeanUnitVarianceTransformation();
        preprocessing.setAxes("xy");
        preprocessing.setMode("per_sample");
        Tensor tt = Tensor.build("name", "xyc", rai);
        preprocessing.applyInPlace(tt);
        System.out.print(true);
    }

    public static void test3() {
        float[] arr = new float[9];
        for (int i = 0; i < arr.length; ++i) {
            arr[i] = i;
        }
        ArrayImg rai = ArrayImgs.floats((float[])arr, (long[])new long[]{1L, 1L, 3L, 3L});
        ZeroMeanUnitVarianceTransformation preprocessing = new ZeroMeanUnitVarianceTransformation();
        preprocessing.setAxes("y");
        preprocessing.setMode("fixed");
        preprocessing.setMean(new double[]{1.0, 4.0, 7.0});
        preprocessing.setStd(new double[]{0.8165, 0.8165, 0.8165});
        Tensor tt = Tensor.build("name", "bcyx", rai);
        preprocessing.applyInPlace(tt);
        System.out.print(true);
    }

    /**
     * Replaces every pixel of {@code rai} in place with {@code (v - mean) / (std + eps)},
     * rounded down with {@link Math#floor} when the pixel type is an {@link IntegerType}.
     */
    public <R extends RealType<R> & NativeType<R>> void zeroMeanUnitVariance(RandomAccessibleInterval<R> rai, double mean, double std) {
        final double denominator = std + this.eps;
        R type = Util.getTypeFromInterval(rai);
        if (type instanceof IntegerType) {
            LoopBuilder.setImages(rai).multiThreaded().forEachPixel(px -> px.setReal(Math.floor((px.getRealDouble() - mean) / denominator)));
        } else {
            LoopBuilder.setImages(rai).multiThreaded().forEachPixel(px -> px.setReal((px.getRealDouble() - mean) / denominator));
        }
    }
}

