package commands;

import java.io.IOException;
import java.io.PrintStream;
import java.util.List;
import java.util.Locale;
import mitiv.array.ArrayFactory;
import mitiv.array.DoubleArray;
import mitiv.array.ShapedArray;
import mitiv.base.Shape;
import mitiv.cost.CompositeDifferentiableCostFunction;
import mitiv.cost.DifferentiableCostFunction;
import mitiv.cost.HyperbolicTotalVariation;
import mitiv.cost.QuadraticCost;
import mitiv.deconv.ConvolutionOperator;
import mitiv.deconv.WeightedConvolutionCost;
import mitiv.exception.IncorrectSpaceException;
import mitiv.invpb.ReconstructionJob;
import mitiv.invpb.ReconstructionSynchronizer;
import mitiv.invpb.ReconstructionViewer;
import mitiv.invpb.SimpleViewer;
import mitiv.io.ColorModel;
import mitiv.io.DataFormat;
import mitiv.linalg.ArrayOps;
import mitiv.linalg.LinearOperator;
import mitiv.linalg.Vector;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.DoubleShapedVectorSpace;
import mitiv.linalg.shaped.RealComplexFFT;
import mitiv.optim.BLMVM;
import mitiv.optim.BoundProjector;
import mitiv.optim.LBFGS;
import mitiv.optim.MoreThuenteLineSearch;
import mitiv.optim.NonLinearConjugateGradient;
import mitiv.optim.OptimTask;
import mitiv.optim.ReverseCommunicationOptimizer;
import mitiv.optim.SimpleBounds;
import mitiv.optim.SimpleLowerBound;
import mitiv.optim.SimpleUpperBound;
import mitiv.utils.FFTUtils;
import mitiv.utils.Timer;
import org.kohsuke.args4j.Argument;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;

/**
 * Deconvolution job minimizing a data-fidelity term plus a hyperbolic total
 * variation (TV) regularization.
 *
 * <p>The objective is {@code f(x) = cost(x) + mu * TV_epsilon(x)}, where
 * {@code cost} is a {@link WeightedConvolutionCost} (default) or, with
 * {@code --old}, a {@link QuadraticCost} built on an FFT-based
 * {@link ConvolutionOperator}. When {@code --xmin} and/or {@code --xmax} are
 * given the problem is solved with BLMVM; otherwise with L-BFGS
 * ({@code --lbfgs M}, {@code M > 0}) or non-linear conjugate gradients.
 */
public class TotalVariationDeconvolution implements ReconstructionJob {

    @Option(name = "--single", aliases = {"-s"}, usage = "Force single precision.")
    private boolean single;

    @Option(name = "--help", aliases = {"-h", "-?"}, usage = "Display help.")
    private boolean help;

    @Option(name = "--old", usage = "Use old convolution operator.")
    private boolean old;

    @Argument
    private List<String> arguments;

    @Option(name = "--output", aliases = {"-o"}, usage = "Name of output image.", metaVar = "OUTPUT")
    private String outName = "output.mda";

    @Option(name = "--init", aliases = {"-i"}, usage = "Name of initial image.", metaVar = "INIT")
    private String initName = null;

    @Option(name = "--weight", aliases = {"-w"}, usage = "Name of file with weights.", metaVar = "WEIGHT")
    private String weightName = null;

    // NOTE(review): eta is settable (option / setTargetError) but is not read
    // anywhere in this class.
    @Option(name = "--eta", aliases = {"-e"}, usage = "Mean data error.", metaVar = "ETA")
    private double eta = 1.0d;

    @Option(name = "--mu", aliases = {"-m"}, usage = "Regularization level.", metaVar = "MU")
    private double mu = 10.0d;

    @Option(name = "--epsilon", aliases = {"-t"}, usage = "Threshold level.", metaVar = "EPSILON")
    private double epsilon = 1.0d;

    @Option(name = "--gatol", usage = "Absolute gradient tolerance for the convergence.", metaVar = "GATOL")
    private double gatol = 0.0d;

    @Option(name = "--grtol", usage = "Relative gradient tolerance for the convergence.", metaVar = "GRTOL")
    private double grtol = 0.001d;

    @Option(name = "--lbfgs", usage = "Use LBFGS method with M saved steps.", metaVar = "M")
    private int limitedMemorySize = 0;

    @Option(name = "--xmin", usage = "Lower bound for the variables.", metaVar = "VALUE")
    private double lowerBound = Double.NEGATIVE_INFINITY;

    @Option(name = "--xmax", usage = "Upper bound for the variables.", metaVar = "VALUE")
    private double upperBound = Double.POSITIVE_INFINITY;

    @Option(name = "--verbose", aliases = {"-v"}, usage = "Verbose mode.")
    private boolean verbose = false;

    @Option(name = "--debug", aliases = {"-d"}, usage = "Debug mode.")
    private boolean debug = false;

    @Option(name = "--maxiter", aliases = {"-l"}, usage = "Maximum number of iterations, -1 for no limits.")
    private int maxiter = 200;

    @Option(name = "--pad", usage = "Padding method (auto|none).", metaVar = "VALUE")
    private String paddingMethod = "auto";

    private DoubleArray data = null;
    private DoubleArray psf = null;
    private DoubleArray result = null;
    private DoubleArray weight = null;

    /** Last computed objective value. */
    private double fcost = 0.0d;
    /** Last computed objective gradient (null before the first run). */
    private DoubleShapedVector gcost = null;
    private Timer timer = new Timer();
    private ReverseCommunicationOptimizer minimizer = null;
    private ReconstructionViewer viewer = null;
    private ReconstructionSynchronizer synchronizer = null;
    /** Exchange buffer with the synchronizer: {mu, epsilon}. */
    private double[] synchronizedParameters = {0.0d, 0.0d};
    private boolean[] change = new boolean[2];
    private String[] synchronizedParameterNames = {"Regularization Level", "Relaxation Threshold"};
    /**
     * Cleared by {@link #stop()}; polled by the optimization loop. Volatile so
     * that a stop request issued from another thread is observed. Note that it
     * is never set back to {@code true}.
     */
    private volatile boolean run = true;

    public DoubleArray getData() {
        return this.data;
    }

    public void setData(DoubleArray doubleArray) {
        this.data = doubleArray;
    }

    public DoubleArray getPsf() {
        return this.psf;
    }

    public void setPsf(DoubleArray doubleArray) {
        this.psf = doubleArray;
    }

    @Override
    public DoubleArray getResult() {
        return this.result;
    }

    /** Sets the initial guess (used only if its shape matches the result shape). */
    public void setResult(DoubleArray doubleArray) {
        this.result = doubleArray;
    }

    public ReconstructionViewer getViewer() {
        return this.viewer;
    }

    public void setViewer(ReconstructionViewer reconstructionViewer) {
        this.viewer = reconstructionViewer;
    }

    public ReconstructionSynchronizer getSynchronizer() {
        return this.synchronizer;
    }

    /** Creates a synchronizer seeded with the current {mu, epsilon}, unless one already exists. */
    public void createSynchronizer() {
        if (this.synchronizer == null) {
            this.synchronizedParameters[0] = this.mu;
            this.synchronizedParameters[1] = this.epsilon;
            this.synchronizer = new ReconstructionSynchronizer(this.synchronizedParameters);
        }
    }

    public void deleteSynchronizer() {
        this.synchronizer = null;
    }

    /** Returns the display name of synchronized parameter {@code i} (0: mu, 1: epsilon). */
    public String getSynchronizedParameterName(int i) {
        return this.synchronizedParameterNames[i];
    }

    public void setVerboseMode(boolean z) {
        this.verbose = z;
    }

    public void setDebugMode(boolean z) {
        this.debug = z;
    }

    public void setOutputName(String str) {
        this.outName = str;
    }

    public void setMaximumIterations(int i) {
        this.maxiter = i;
    }

    public void setLimitedMemorySize(int i) {
        this.limitedMemorySize = i;
    }

    public void setTargetError(double d) {
        this.eta = d;
    }

    public void setRegularizationWeight(double d) {
        this.mu = d;
    }

    public void setRegularizationThreshold(double d) {
        this.epsilon = d;
    }

    public void setAbsoluteTolerance(double d) {
        this.gatol = d;
    }

    public void setRelativeTolerance(double d) {
        this.grtol = d;
    }

    @Override
    public double getRelativeTolerance() {
        return this.grtol;
    }

    public void setLowerBound(double d) {
        this.lowerBound = d;
    }

    public void setUpperBound(double d) {
        this.upperBound = d;
    }

    /** Requests the optimization loop to terminate at its next check. */
    public void stop() {
        this.run = false;
    }

    public void setWeight(DoubleArray doubleArray) {
        this.weight = doubleArray;
    }

    /**
     * Loads an image file as a double array; images detected as color are
     * converted to gray levels.
     */
    public static DoubleArray loadData(String str) {
        ShapedArray load = DataFormat.load(str);
        return ColorModel.guessColorModel(load) == ColorModel.NONE ? load.toDouble() : ColorModel.filterImageAsDouble(load, ColorModel.GRAY);
    }

    /**
     * Command-line entry point: {@code tvdec [OPTIONS] INPUT_IMAGE PSF}.
     * Exits with status 0 on success and 1 on any error.
     */
    public static void main(String[] args) {
        Locale.setDefault(Locale.US);
        TotalVariationDeconvolution job = new TotalVariationDeconvolution();
        CmdLineParser parser = new CmdLineParser(job);
        try {
            parser.parseArgument(args);
        } catch (CmdLineException e) {
            System.err.format("Error: %s\n", e.getMessage());
            parser.setUsageWidth(80);
            parser.printUsage(System.err);
            // Do not go on with a partially parsed job.
            System.exit(1);
        }
        if (job.help) {
            PrintStream out = System.out;
            out.println("Usage: tvdec [OPTIONS] INPUT_IMAGE PSF");
            out.println("Options:");
            parser.setUsageWidth(80);
            parser.printUsage(out);
            System.exit(0);
        }
        // Negated comparisons so that NaN is rejected too.
        if (!(job.mu >= 0.0d)) {
            System.err.format("Regularization level MU must be nonnegative.\n");
            System.exit(1);
        }
        if (!(job.epsilon > 0.0d)) {
            System.err.format("Threshold level EPSILON must be strictly positive.\n");
            System.exit(1);
        }
        int size = (job.arguments == null ? 0 : job.arguments.size());
        if (size != 2) {
            System.err.format("Too %s arguments.\n", (size < 2 ? "few" : "many"));
            System.exit(1);
        }
        String dataName = job.arguments.get(0);
        String psfName = job.arguments.get(1);
        if (job.verbose) {
            job.setViewer(new SimpleViewer());
        }
        if (job.debug) {
            System.out.format("mu: %.2g, threshold: %.2g, output: %s\n", job.mu, job.epsilon, job.outName);
        }
        job.data = loadData(dataName);
        job.psf = loadData(psfName);
        if (job.initName != null) {
            job.result = loadData(job.initName);
        }
        try {
            job.deconvolve(job.paddingMethod);
        } catch (IllegalArgumentException e) {
            if (job.debug) {
                e.printStackTrace();
            }
            System.err.format("Error: %s\n", e.getMessage());
            System.exit(1);
        }
        try {
            DataFormat.save(job.result, job.outName);
        } catch (IOException e) {
            if (job.debug) {
                e.printStackTrace();
            }
            System.err.format("Failed to write output image.\n");
            System.exit(1);
        }
        if (job.verbose) {
            System.out.println("Done!");
        }
        System.exit(0);
    }

    /** Always throws an {@link IllegalArgumentException} with the given message. */
    private static void fatal(String str) {
        throw new IllegalArgumentException(str);
    }

    /**
     * Deconvolves using a result shape chosen by the padding strategy.
     *
     * <p>With {@code "auto"} each result dimension is the best FFT size for
     * {@code dataDim + psfDim - 1}; with {@code "none"} it is the best FFT size
     * for {@code max(dataDim, psfDim)}. In {@code --old} mode the data shape is
     * used and the strategy is ignored.
     *
     * @param str padding strategy, {@code "auto"} or {@code "none"}
     * @throws IllegalArgumentException if data/PSF are missing or inconsistent,
     *         or the strategy is unknown
     */
    public void deconvolve(String str) {
        // Validate before touching data/psf so users get a clear message
        // rather than a NullPointerException or an index error.
        if (this.data == null) {
            fatal("Input data not specified.");
        }
        if (this.psf == null) {
            fatal("PSF not specified.");
        }
        Shape dataShape = this.data.getShape();
        if (this.old) {
            deconvolve(dataShape);
            return;
        }
        int rank = this.data.getRank();
        if (this.psf.getRank() != rank) {
            fatal("PSF must have same rank as data.");
        }
        Shape psfShape = this.psf.getShape();
        int[] dims = new int[rank];
        if ("auto".equals(str)) {
            for (int k = 0; k < rank; ++k) {
                dims[k] = FFTUtils.bestDimension(dataShape.dimension(k) + psfShape.dimension(k) - 1);
            }
        } else if ("none".equals(str)) {
            for (int k = 0; k < rank; ++k) {
                dims[k] = FFTUtils.bestDimension(Math.max(dataShape.dimension(k), psfShape.dimension(k)));
            }
        } else {
            fatal("Unknown padding strategy.");
        }
        deconvolve(new Shape(dims));
    }

    /**
     * Deconvolves the data into a result of the given shape and stores it in
     * {@link #getResult()}.
     *
     * <p>A previously set result is used as the initial guess only when its
     * shape equals {@code shape}; otherwise it is discarded.
     *
     * @param shape shape of the result; must equal the data shape in
     *        {@code --old} mode, and be at least as large as both data and PSF
     *        otherwise
     * @throws IllegalArgumentException on missing or inconsistent inputs
     */
    public void deconvolve(Shape shape) {
        this.timer.start();
        if (this.data == null) {
            fatal("Input data not specified.");
        }
        if (this.psf == null) {
            fatal("PSF not specified.");
        }
        if (shape == null) {
            fatal("Result shape not specified.");
        }
        Shape dataShape = this.data.getShape();
        int rank = this.data.getRank();
        if (this.psf.getRank() != rank) {
            fatal("PSF must have same rank as data.");
        }
        if (shape.rank() != rank) {
            fatal("The result must have same rank as data.");
        }
        checkDimensions(shape, dataShape, this.psf.getShape(), rank);
        if (this.result != null && !hasShape(this.result, shape)) {
            // The initial guess must live in the result space, not the data space.
            this.result = null;
        }

        DoubleShapedVectorSpace dataSpace = new DoubleShapedVectorSpace(dataShape);
        DoubleShapedVectorSpace resultSpace = (this.old ? dataSpace : new DoubleShapedVectorSpace(shape));
        DoubleShapedVector y = dataSpace.create(this.data);
        DoubleShapedVector x = createInitialGuess(resultSpace, y);
        // The result array shares its storage with the optimized vector x.
        this.result = ArrayFactory.wrap(x.getData(), shape);

        DifferentiableCostFunction dataCost = buildDataCost(dataSpace, resultSpace, y);
        if (this.debug) {
            System.out.println("Vector space initialization complete.");
        }
        CompositeDifferentiableCostFunction objective = buildObjective(dataCost, resultSpace);
        this.fcost = 0.0d;
        this.gcost = resultSpace.create();
        this.timer.stop();
        if (this.debug) {
            System.out.format("Cost function initialization completed in %.3f sec.\n", this.timer.getElapsedTime());
        }
        this.timer.reset();

        this.timer.start();
        this.minimizer = buildOptimizer(resultSpace, x);
        this.timer.stop();
        if (this.debug) {
            System.out.format("Optimization method initialization complete in %.3f sec.\n", this.timer.getElapsedTime());
        }
        // From here on the timer only accumulates time spent in the objective.
        this.timer.reset();

        runOptimization(objective, dataCost, resultSpace, x);
        if (this.verbose) {
            reportStatistics(dataCost, x);
        }
    }

    /** Checks the result/PSF dimensions against the data dimensions for the current mode. */
    private void checkDimensions(Shape shape, Shape dataShape, Shape psfShape, int rank) {
        if (this.old) {
            for (int k = 0; k < rank; ++k) {
                if (psfShape.dimension(k) != dataShape.dimension(k)) {
                    fatal("The dimensions of the PSF must match those of the data.");
                }
            }
            for (int k = 0; k < rank; ++k) {
                if (shape.dimension(k) != dataShape.dimension(k)) {
                    fatal("The dimensions of the result must be equal to those of the data.");
                }
            }
        } else {
            for (int k = 0; k < rank; ++k) {
                if (shape.dimension(k) < dataShape.dimension(k)) {
                    fatal("The dimensions of the result must be at least those of the data.");
                }
                if (shape.dimension(k) < psfShape.dimension(k)) {
                    fatal("The dimensions of the result must be at least those of the PSF.");
                }
            }
        }
    }

    /** Returns true when {@code array} has exactly the rank and dimensions of {@code shape}. */
    private static boolean hasShape(DoubleArray array, Shape shape) {
        int rank = shape.rank();
        if (array.getRank() != rank) {
            return false;
        }
        for (int k = 0; k < rank; ++k) {
            if (array.getDimension(k) != shape.dimension(k)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds the starting point: the user's initial guess if any; otherwise, in
     * {@code --old} mode, the data divided by the PSF sum (zero if that sum is
     * zero); otherwise zero.
     */
    private DoubleShapedVector createInitialGuess(DoubleShapedVectorSpace resultSpace, DoubleShapedVector y) {
        if (this.result != null) {
            return resultSpace.create(this.result);
        }
        if (!this.old) {
            return resultSpace.create(0.0d);
        }
        // In old mode the result space is the data space, so y fits in x.
        DoubleShapedVector x = resultSpace.create();
        double sum = this.psf.sum();
        if (sum != 0.0d) {
            // x = y / sum (also for sum == 1, where the previous code left x at zero).
            x.combine(0.0d, x, 1.0d / sum, y);
        } else {
            x.fill(0.0d);
        }
        return x;
    }

    /** Builds the data-fidelity term for the current mode (weighted convolution or old quadratic). */
    private DifferentiableCostFunction buildDataCost(DoubleShapedVectorSpace dataSpace,
                                                     DoubleShapedVectorSpace resultSpace,
                                                     DoubleShapedVector y) {
        if (!this.old) {
            WeightedConvolutionCost cost = WeightedConvolutionCost.build(resultSpace, dataSpace);
            cost.setPSF(this.psf);
            cost.setWeightsAndData(this.weight, this.data);
            return cost;
        }
        RealComplexFFT fft = new RealComplexFFT(resultSpace);
        LinearOperator weighting = null;
        if (this.weight != null) {
            if (this.weight.getNumber() != this.data.getNumber()) {
                throw new IllegalArgumentException("Error weights and input data size don't match");
            }
            // Flatten once instead of on every operator application.
            final double[] w = this.weight.flatten();
            weighting = new LinearOperator(resultSpace) {
                @Override
                protected void _apply(Vector dst, Vector src, int job) throws IncorrectSpaceException {
                    // Element-wise product: dst[k] = src[k] * w[k].
                    double[] in = ((DoubleShapedVector) src).getData();
                    double[] out = ((DoubleShapedVector) dst).getData();
                    int n = src.getNumber();
                    for (int k = 0; k < n; ++k) {
                        out[k] = in[k] * w[k];
                    }
                }
            };
        }
        return new QuadraticCost(new ConvolutionOperator(fft, resultSpace.create(this.psf)), y, weighting);
    }

    /** Combines the data term with {@code mu} times the hyperbolic TV of threshold {@code epsilon}. */
    private CompositeDifferentiableCostFunction buildObjective(DifferentiableCostFunction dataCost,
                                                               DoubleShapedVectorSpace resultSpace) {
        return new CompositeDifferentiableCostFunction(1.0d, dataCost, this.mu,
                new HyperbolicTotalVariation(resultSpace, this.epsilon));
    }

    /**
     * Chooses the optimizer: BLMVM if any bound is finite (projecting {@code x}
     * onto the feasible set), else L-BFGS if {@code limitedMemorySize > 0},
     * else non-linear conjugate gradients.
     */
    private ReverseCommunicationOptimizer buildOptimizer(DoubleShapedVectorSpace space, DoubleShapedVector x) {
        boolean hasLower = (this.lowerBound != Double.NEGATIVE_INFINITY);
        boolean hasUpper = (this.upperBound != Double.POSITIVE_INFINITY);
        if (hasLower || hasUpper) {
            if (hasLower && hasUpper && this.lowerBound > this.upperBound) {
                fatal("Lower bound must not exceed upper bound.");
            }
            BoundProjector projector;
            if (hasLower && hasUpper) {
                projector = new SimpleBounds(space, this.lowerBound, this.upperBound);
            } else if (hasLower) {
                projector = new SimpleLowerBound(space, this.lowerBound);
            } else {
                projector = new SimpleUpperBound(space, this.upperBound);
            }
            BLMVM blmvm = new BLMVM(space, projector, (this.limitedMemorySize > 1 ? this.limitedMemorySize : 5));
            blmvm.setAbsoluteTolerance(this.gatol);
            blmvm.setRelativeTolerance(this.grtol);
            // Start from a feasible point.
            projector.projectVariables(x, x);
            return blmvm;
        }
        MoreThuenteLineSearch lineSearch = new MoreThuenteLineSearch(0.05d, 0.1d, 1.0E-17d);
        if (this.limitedMemorySize > 0) {
            LBFGS lbfgs = new LBFGS(space, this.limitedMemorySize, lineSearch);
            lbfgs.setAbsoluteTolerance(this.gatol);
            lbfgs.setRelativeTolerance(this.grtol);
            return lbfgs;
        }
        NonLinearConjugateGradient nlcg = new NonLinearConjugateGradient(space, NonLinearConjugateGradient.DEFAULT_METHOD, lineSearch);
        nlcg.setAbsoluteTolerance(this.gatol);
        nlcg.setRelativeTolerance(this.grtol);
        return nlcg;
    }

    /**
     * Drives the reverse-communication optimizer until convergence, the
     * iteration limit, an optimizer warning/error, or a stop request.
     *
     * <p>When the synchronizer delivers new mu/epsilon values, the objective is
     * rebuilt and the optimizer restarted from the current {@code x}, since a
     * changed objective invalidates the optimizer's internal state.
     * NOTE(review): whether {@code start()} resets the iteration counter used by
     * the maxiter check is not visible here — confirm.
     */
    private void runOptimization(CompositeDifferentiableCostFunction objective,
                                 DifferentiableCostFunction dataCost,
                                 DoubleShapedVectorSpace resultSpace,
                                 DoubleShapedVector x) {
        CompositeDifferentiableCostFunction cost = objective;
        OptimTask task = this.minimizer.start();
        while (this.run) {
            if (task == OptimTask.COMPUTE_FG) {
                this.timer.resume();
                this.fcost = cost.computeCostAndGradient(1.0d, x, this.gcost, true);
                this.timer.stop();
            } else if (task == OptimTask.NEW_X || task == OptimTask.FINAL_X) {
                if (this.viewer != null) {
                    this.viewer.display(this);
                }
                boolean finished = (task == OptimTask.FINAL_X);
                if (!finished && this.maxiter >= 0 && this.minimizer.getIterations() >= this.maxiter) {
                    System.err.format("Warning: too many iterations (%d).\n", this.maxiter);
                    finished = true;
                }
                if (finished) {
                    break;
                }
            } else {
                System.err.println("TiPi: TotalVariationDeconvolution, error/warning: " + this.minimizer.getReason());
                break;
            }
            if (this.synchronizer != null) {
                // NOTE(review): task code 1 is presumably a stop request — confirm
                // against ReconstructionSynchronizer.
                if (this.synchronizer.getTask() == 1) {
                    break;
                }
                if (pullSynchronizedParameters()) {
                    cost = buildObjective(dataCost, resultSpace);
                    task = this.minimizer.start();
                    continue;
                }
            }
            task = this.minimizer.iterate(x, this.fcost, this.gcost);
        }
    }

    /**
     * Exchanges {mu, epsilon} with the synchronizer and adopts flagged, valid
     * (mu >= 0, epsilon > 0) values.
     *
     * @return true if mu or epsilon actually changed
     */
    private boolean pullSynchronizedParameters() {
        this.synchronizedParameters[0] = this.mu;
        this.synchronizedParameters[1] = this.epsilon;
        if (!this.synchronizer.updateParameters(this.synchronizedParameters, this.change)) {
            return false;
        }
        boolean modified = false;
        double newMu = this.synchronizedParameters[0];
        if (this.change[0] && newMu >= 0.0d && newMu != this.mu) {
            this.mu = newMu;
            modified = true;
        }
        double newEpsilon = this.synchronizedParameters[1];
        if (this.change[1] && newEpsilon > 0.0d && newEpsilon != this.epsilon) {
            this.epsilon = newEpsilon;
            modified = true;
        }
        return modified;
    }

    /** Prints timing statistics and the range of the solution to standard output. */
    private void reportStatistics(DifferentiableCostFunction dataCost, DoubleShapedVector x) {
        this.timer.stop();
        double elapsed = this.timer.getElapsedTime();
        int evaluations = getEvaluations();
        System.out.format("Total time in cost function: %.3f s (%.3f ms/eval.)\n",
                elapsed, millisPerEvaluation(elapsed, evaluations));
        if (dataCost instanceof WeightedConvolutionCost) {
            WeightedConvolutionCost convolutionCost = (WeightedConvolutionCost) dataCost;
            double fftTime = convolutionCost.getElapsedTimeInFFT();
            System.out.format("Total time in FFT: %.3f s (%.3f ms/eval.)\n",
                    fftTime, millisPerEvaluation(fftTime, evaluations));
            double otherTime = convolutionCost.getElapsedTime() - fftTime;
            System.out.format("Total time in other parts of the convolution operator: %.3f s (%.3f ms/eval.)\n",
                    otherTime, millisPerEvaluation(otherTime, evaluations));
        }
        System.out.format("min(x) = %g\n", ArrayOps.getMin(x.getData()));
        System.out.format("max(x) = %g\n", ArrayOps.getMax(x.getData()));
    }

    /** Average milliseconds per evaluation, or 0 when there were no evaluations. */
    private static double millisPerEvaluation(double seconds, int evaluations) {
        return (evaluations > 0 ? (1000.0d * seconds) / evaluations : 0.0d);
    }

    @Override
    public int getIterations() {
        return (this.minimizer == null ? 0 : this.minimizer.getIterations());
    }

    @Override
    public int getEvaluations() {
        return (this.minimizer == null ? 0 : this.minimizer.getEvaluations());
    }

    @Override
    public double getCost() {
        return this.fcost;
    }

    @Override
    public double getGradientNorm2() {
        return (this.gcost == null ? 0.0d : this.gcost.norm2());
    }

    @Override
    public double getGradientNorm1() {
        return (this.gcost == null ? 0.0d : this.gcost.norm1());
    }

    @Override
    public double getGradientNormInf() {
        return (this.gcost == null ? 0.0d : this.gcost.normInf());
    }
}
