/*
 * Decompiled with CFR 0.152.
 */
package commands;

import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.List;
import java.util.Locale;
import mitiv.array.ArrayFactory;
import mitiv.array.DoubleArray;
import mitiv.array.ShapedArray;
import mitiv.base.Shape;
import mitiv.cost.CompositeDifferentiableCostFunction;
import mitiv.cost.DifferentiableCostFunction;
import mitiv.cost.HyperbolicTotalVariation;
import mitiv.cost.QuadraticCost;
import mitiv.deconv.ConvolutionOperator;
import mitiv.deconv.WeightedConvolutionCost;
import mitiv.exception.IncorrectSpaceException;
import mitiv.invpb.ReconstructionJob;
import mitiv.invpb.ReconstructionSynchronizer;
import mitiv.invpb.ReconstructionViewer;
import mitiv.invpb.SimpleViewer;
import mitiv.io.ColorModel;
import mitiv.io.DataFormat;
import mitiv.linalg.ArrayOps;
import mitiv.linalg.LinearOperator;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.DoubleShapedVectorSpace;
import mitiv.linalg.shaped.RealComplexFFT;
import mitiv.optim.BLMVM;
import mitiv.optim.BoundProjector;
import mitiv.optim.LBFGS;
import mitiv.optim.LineSearch;
import mitiv.optim.MoreThuenteLineSearch;
import mitiv.optim.NonLinearConjugateGradient;
import mitiv.optim.OptimTask;
import mitiv.optim.ReverseCommunicationOptimizer;
import mitiv.optim.SimpleBounds;
import mitiv.optim.SimpleLowerBound;
import mitiv.optim.SimpleUpperBound;
import mitiv.utils.FFTUtils;
import mitiv.utils.Timer;
import org.kohsuke.args4j.Argument;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;

public class TotalVariationDeconvolution
implements ReconstructionJob {
    @Option(name="--output", aliases={"-o"}, usage="Name of output image.", metaVar="OUTPUT")
    private String outName = "output.mda";
    @Option(name="--init", aliases={"-i"}, usage="Name of initial image.", metaVar="INIT")
    private String initName = null;
    @Option(name="--weight", aliases={"-w"}, usage="Name of file with weights.", metaVar="WEIGHT")
    private String weightName = null;
    @Option(name="--eta", aliases={"-e"}, usage="Mean data error.", metaVar="ETA")
    private double eta = 1.0;
    @Option(name="--mu", aliases={"-m"}, usage="Regularization level.", metaVar="MU")
    private double mu = 10.0;
    @Option(name="--epsilon", aliases={"-t"}, usage="Threshold level.", metaVar="EPSILON")
    private double epsilon = 1.0;
    @Option(name="--gatol", usage="Absolute gradient tolerance for the convergence.", metaVar="GATOL")
    private double gatol = 0.0;
    @Option(name="--grtol", usage="Relative gradient tolerance for the convergence.", metaVar="GRTOL")
    private double grtol = 0.001;
    @Option(name="--lbfgs", usage="Use LBFGS method with M saved steps.", metaVar="M")
    private int limitedMemorySize = 0;
    @Option(name="--xmin", usage="Lower bound for the variables.", metaVar="VALUE")
    private double lowerBound = Double.NEGATIVE_INFINITY;
    @Option(name="--xmax", usage="Upper bound for the variables.", metaVar="VALUE")
    private double upperBound = Double.POSITIVE_INFINITY;
    @Option(name="--single", aliases={"-s"}, usage="Force single precision.")
    private boolean single;
    @Option(name="--help", aliases={"-h", "-?"}, usage="Display help.")
    private boolean help;
    @Option(name="--verbose", aliases={"-v"}, usage="Verbose mode.")
    private boolean verbose = false;
    @Option(name="--debug", aliases={"-d"}, usage="Debug mode.")
    private boolean debug = false;
    @Option(name="--maxiter", aliases={"-l"}, usage="Maximum number of iterations, -1 for no limits.")
    private int maxiter = 200;
    @Option(name="--pad", usage="Padding method (auto|none).", metaVar="VALUE")
    private String paddingMethod = "auto";
    @Option(name="--old", usage="Use old convolution operator.")
    private boolean old;
    @Argument
    private List<String> arguments;
    private DoubleArray data = null;
    private DoubleArray psf = null;
    private DoubleArray result = null;
    private DoubleArray weight = null;
    private double fcost = 0.0;
    private DoubleShapedVector gcost = null;
    private Timer timer = new Timer();
    private ReverseCommunicationOptimizer minimizer = null;
    private ReconstructionViewer viewer = null;
    private ReconstructionSynchronizer synchronizer = null;
    private double[] synchronizedParameters = new double[]{0.0, 0.0};
    private boolean[] change = new boolean[2];
    private String[] synchronizedParameterNames = new String[]{"Regularization Level", "Relaxation Threshold"};
    private boolean run = true;

    public DoubleArray getData() {
        return this.data;
    }

    public void setData(DoubleArray data) {
        this.data = data;
    }

    public DoubleArray getPsf() {
        return this.psf;
    }

    public void setPsf(DoubleArray psf) {
        this.psf = psf;
    }

    @Override
    public DoubleArray getResult() {
        return this.result;
    }

    public void setResult(DoubleArray result) {
        this.result = result;
    }

    public ReconstructionViewer getViewer() {
        return this.viewer;
    }

    public void setViewer(ReconstructionViewer rv) {
        this.viewer = rv;
    }

    public ReconstructionSynchronizer getSynchronizer() {
        return this.synchronizer;
    }

    public void createSynchronizer() {
        if (this.synchronizer == null) {
            this.synchronizedParameters[0] = this.mu;
            this.synchronizedParameters[1] = this.epsilon;
            this.synchronizer = new ReconstructionSynchronizer(this.synchronizedParameters);
        }
    }

    public void deleteSynchronizer() {
        this.synchronizer = null;
    }

    public String getSynchronizedParameterName(int i) {
        return this.synchronizedParameterNames[i];
    }

    public void setVerboseMode(boolean value) {
        this.verbose = value;
    }

    public void setDebugMode(boolean value) {
        this.debug = value;
    }

    public void setOutputName(String name) {
        this.outName = name;
    }

    public void setMaximumIterations(int value) {
        this.maxiter = value;
    }

    public void setLimitedMemorySize(int value) {
        this.limitedMemorySize = value;
    }

    public void setTargetError(double value) {
        this.eta = value;
    }

    public void setRegularizationWeight(double value) {
        this.mu = value;
    }

    public void setRegularizationThreshold(double value) {
        this.epsilon = value;
    }

    public void setAbsoluteTolerance(double value) {
        this.gatol = value;
    }

    public void setRelativeTolerance(double value) {
        this.grtol = value;
    }

    @Override
    public double getRelativeTolerance() {
        return this.grtol;
    }

    public void setLowerBound(double value) {
        this.lowerBound = value;
    }

    public void setUpperBound(double value) {
        this.upperBound = value;
    }

    public void stop() {
        this.run = false;
    }

    public void setWeight(DoubleArray W) {
        this.weight = W;
    }

    public static DoubleArray loadData(String name) {
        ShapedArray arr = DataFormat.load(name);
        ColorModel colorModel = ColorModel.guessColorModel(arr);
        if (colorModel == ColorModel.NONE) {
            return arr.toDouble();
        }
        return ColorModel.filterImageAsDouble(arr, ColorModel.GRAY);
    }

    public static void main(String[] args) {
        int size;
        Locale.setDefault(Locale.US);
        TotalVariationDeconvolution job = new TotalVariationDeconvolution();
        CmdLineParser parser = new CmdLineParser((Object)job);
        try {
            parser.parseArgument(args);
            if (job.mu < 0.0) {
                System.err.format("Regularization level MU must be strictly positive.\n", new Object[0]);
                System.exit(1);
            }
            if (job.epsilon <= 0.0) {
                System.err.format("Threshold level EPSILON must be strictly positive.\n", new Object[0]);
                System.exit(1);
            }
            if (job.help) {
                PrintStream stream = System.out;
                stream.println("Usage: tvdec [OPTIONS] INPUT_IMAGE PSF");
                stream.println("Options:");
                parser.setUsageWidth(80);
                parser.printUsage((OutputStream)stream);
                System.exit(0);
            }
        }
        catch (CmdLineException e) {
            System.err.format("Error: %s\n", e.getMessage());
            parser.setUsageWidth(80);
            parser.printUsage((OutputStream)System.err);
        }
        int n = size = job.arguments == null ? 0 : job.arguments.size();
        if (size != 2) {
            System.err.format("Too %s arguments.\n", size < 2 ? "few" : "many");
            System.exit(1);
        }
        String inputName = job.arguments.get(0);
        String psfName = job.arguments.get(1);
        if (job.verbose) {
            job.setViewer(new SimpleViewer());
        }
        if (job.debug) {
            System.out.format("mu: %.2g, threshold: %.2g, output: %s\n", job.mu, job.epsilon, job.outName);
        }
        job.data = TotalVariationDeconvolution.loadData(inputName);
        job.psf = TotalVariationDeconvolution.loadData(psfName);
        if (job.initName != null) {
            job.result = TotalVariationDeconvolution.loadData(job.initName);
        }
        job.deconvolve(job.paddingMethod);
        try {
            DataFormat.save(job.result, job.outName);
        }
        catch (IOException e) {
            if (job.debug) {
                e.printStackTrace();
            }
            System.err.format("Failed to write output image.\n", new Object[0]);
            System.exit(1);
        }
        if (job.verbose) {
            System.out.println("Done!");
        }
        System.exit(0);
    }

    private static void fatal(String reason) {
        throw new IllegalArgumentException(reason);
    }

    public void deconvolve(String padding) {
        Shape dataShape = this.data.getShape();
        Shape psfShape = this.psf.getShape();
        if (this.old) {
            this.deconvolve(dataShape);
        } else {
            int rank = this.data.getRank();
            int[] dims = new int[rank];
            if (padding.equals("auto")) {
                int k = 0;
                while (k < rank) {
                    int resultDim;
                    int dataDim = dataShape.dimension(k);
                    int psfDim = psfShape.dimension(k);
                    dims[k] = resultDim = FFTUtils.bestDimension(dataDim + psfDim - 1);
                    ++k;
                }
            } else if (padding.equals("none")) {
                int k = 0;
                while (k < rank) {
                    int resultDim;
                    int dataDim = dataShape.dimension(k);
                    int psfDim = psfShape.dimension(k);
                    dims[k] = resultDim = FFTUtils.bestDimension(Math.max(dataDim, psfDim));
                    ++k;
                }
            } else {
                TotalVariationDeconvolution.fatal("Unknown padding strategy.");
            }
            this.deconvolve(Shape.make(dims));
        }
    }

    public void deconvolve(Shape resultShape) {
        DifferentiableCostFunction fdata;
        int k;
        this.timer.start();
        if (this.data == null) {
            TotalVariationDeconvolution.fatal("Input data not specified.");
        }
        Shape dataShape = this.data.getShape();
        int rank = this.data.getRank();
        if (this.psf == null) {
            TotalVariationDeconvolution.fatal("PSF not specified.");
        }
        if (this.psf.getRank() != rank) {
            TotalVariationDeconvolution.fatal("PSF must have same rank as data.");
        }
        Shape psfShape = this.psf.getShape();
        if (this.old) {
            k = 0;
            while (k < rank) {
                if (this.psf.getDimension(k) != dataShape.dimension(k)) {
                    TotalVariationDeconvolution.fatal("The dimensions of the PSF must match those of the data.");
                }
                ++k;
            }
        }
        if (this.result != null) {
            k = 0;
            while (k < rank) {
                if (this.result.getDimension(k) != this.data.getDimension(k)) {
                    this.result = null;
                    break;
                }
                ++k;
            }
        }
        k = 0;
        while (k < rank) {
            if (this.old) {
                if (resultShape.dimension(k) != dataShape.dimension(k)) {
                    TotalVariationDeconvolution.fatal("The dimensions of the result must be equal to those of the data.");
                }
            } else {
                if (resultShape.dimension(k) < dataShape.dimension(k)) {
                    TotalVariationDeconvolution.fatal("The dimensions of the result must be at least those of the data.");
                }
                if (resultShape.dimension(k) < psfShape.dimension(k)) {
                    TotalVariationDeconvolution.fatal("The dimensions of the result must be at least those of the PSF.");
                }
            }
            ++k;
        }
        DoubleShapedVectorSpace dataSpace = new DoubleShapedVectorSpace(dataShape);
        DoubleShapedVectorSpace resultSpace = this.old ? dataSpace : new DoubleShapedVectorSpace(resultShape);
        LinearOperator W = null;
        DoubleShapedVector y = dataSpace.create(this.data);
        DoubleShapedVector x = null;
        if (this.result != null) {
            x = resultSpace.create(this.result);
        } else if (this.old) {
            double psf_sum = this.psf.sum();
            x = resultSpace.create();
            if (psf_sum != 1.0) {
                if (psf_sum != 0.0) {
                    x.axpby(0.0, x, 1.0 / psf_sum, y);
                } else {
                    x.fill(0.0);
                }
            }
        } else {
            x = resultSpace.create(0.0);
        }
        this.result = ArrayFactory.wrap(x.getData(), resultShape);
        ConvolutionOperator H = null;
        if (this.old) {
            RealComplexFFT FFT = new RealComplexFFT(resultSpace);
            if (this.weight != null) {
                if (this.weight.getNumber() != this.data.getNumber()) {
                    throw new IllegalArgumentException("Error weights and input data size don't match");
                }
                W = new LinearOperator((VectorSpace)resultSpace){

                    @Override
                    protected void privApply(Vector src, Vector dst, int job) throws IncorrectSpaceException {
                        double[] inp = ((DoubleShapedVector)src).getData();
                        double[] out = ((DoubleShapedVector)dst).getData();
                        double[] weights = TotalVariationDeconvolution.this.weight.flatten();
                        int number = src.getNumber();
                        int i = 0;
                        while (i < number) {
                            out[i] = inp[i] * weights[i];
                            ++i;
                        }
                    }
                };
            }
            DoubleShapedVector h = resultSpace.create(this.psf);
            H = new ConvolutionOperator(FFT, h);
            fdata = new QuadraticCost(H, y, W);
        } else {
            WeightedConvolutionCost cost = WeightedConvolutionCost.build(resultSpace, dataSpace);
            cost.setPSF(this.psf);
            cost.setWeightsAndData(this.weight, this.data);
            fdata = cost;
        }
        if (this.debug) {
            System.out.println("Vector space initialization complete.");
        }
        HyperbolicTotalVariation fprior = new HyperbolicTotalVariation(resultSpace, this.epsilon);
        CompositeDifferentiableCostFunction cost = new CompositeDifferentiableCostFunction(1.0, fdata, this.mu, fprior);
        this.fcost = 0.0;
        this.gcost = resultSpace.create();
        this.timer.stop();
        if (this.debug) {
            System.out.format("Cost function initialization completed in %.3f sec.\n", this.timer.getElapsedTime());
        }
        this.timer.reset();
        this.timer.start();
        MoreThuenteLineSearch lineSearch = null;
        LBFGS lbfgs = null;
        BLMVM blmvm = null;
        NonLinearConjugateGradient nlcg = null;
        BoundProjector projector = null;
        int bounded = 0;
        if (this.lowerBound != Double.NEGATIVE_INFINITY) {
            bounded |= 1;
        }
        if (this.upperBound != Double.POSITIVE_INFINITY) {
            bounded |= 2;
        }
        if (bounded == 0) {
            lineSearch = new MoreThuenteLineSearch(0.05, 0.1, 1.0E-17);
            if (this.limitedMemorySize > 0) {
                lbfgs = new LBFGS(resultSpace, this.limitedMemorySize, (LineSearch)lineSearch);
                lbfgs.setAbsoluteTolerance(this.gatol);
                lbfgs.setRelativeTolerance(this.grtol);
                this.minimizer = lbfgs;
            } else {
                int method = 771;
                nlcg = new NonLinearConjugateGradient(resultSpace, method, lineSearch);
                nlcg.setAbsoluteTolerance(this.gatol);
                nlcg.setRelativeTolerance(this.grtol);
                this.minimizer = nlcg;
            }
        } else {
            projector = bounded == 1 ? new SimpleLowerBound(resultSpace, this.lowerBound) : (bounded == 2 ? new SimpleUpperBound(resultSpace, this.upperBound) : new SimpleBounds(resultSpace, this.lowerBound, this.upperBound));
            int m = this.limitedMemorySize > 1 ? this.limitedMemorySize : 5;
            blmvm = new BLMVM(resultSpace, projector, m);
            blmvm.setAbsoluteTolerance(this.gatol);
            blmvm.setRelativeTolerance(this.grtol);
            this.minimizer = blmvm;
            projector.projectVariables(x, x);
        }
        this.timer.stop();
        if (this.debug) {
            System.out.format("Optimization method initialization complete in %.3f sec.\n", this.timer.getElapsedTime());
        }
        this.timer.reset();
        OptimTask task = this.minimizer.start();
        while (this.run) {
            if (task == OptimTask.COMPUTE_FG) {
                this.timer.resume();
                this.fcost = cost.computeCostAndGradient(1.0, x, this.gcost, true);
                this.timer.stop();
            } else if (task == OptimTask.NEW_X || task == OptimTask.FINAL_X) {
                boolean stop;
                if (this.viewer != null) {
                    this.viewer.display(this);
                }
                boolean bl = stop = task == OptimTask.FINAL_X;
                if (!stop && this.maxiter >= 0 && this.minimizer.getIterations() >= this.maxiter) {
                    System.err.format("Warning: too many iterations (%d).\n", this.maxiter);
                    stop = true;
                }
                if (stop) {
                    break;
                }
            } else {
                System.err.println("TiPi: TotalVariationDeconvolution, error/warning: " + this.minimizer.getReason());
                break;
            }
            if (this.synchronizer != null) {
                if (this.synchronizer.getTask() == 1) break;
                this.synchronizedParameters[0] = this.mu;
                this.synchronizedParameters[1] = this.epsilon;
                if (this.synchronizer.updateParameters(this.synchronizedParameters, this.change)) {
                    if (this.change[0]) {
                        this.mu = this.synchronizedParameters[0];
                    }
                    if (this.change[1]) {
                        this.epsilon = this.synchronizedParameters[1];
                    }
                }
            }
            task = this.minimizer.iterate(x, this.fcost, this.gcost);
        }
        if (this.verbose) {
            this.timer.stop();
            double elapsed = this.timer.getElapsedTime();
            int nevals = this.getEvaluations();
            System.out.format("Total time in cost function: %.3f s (%.3f ms/eval.)\n", elapsed, nevals > 0 ? 1000.0 * elapsed / (double)nevals : 0.0);
            if (fdata instanceof WeightedConvolutionCost) {
                DifferentiableCostFunction f = fdata;
                elapsed = ((WeightedConvolutionCost)f).getElapsedTimeInFFT();
                System.out.format("Total time in FFT: %.3f s (%.3f ms/eval.)\n", elapsed, nevals > 0 ? 1000.0 * elapsed / (double)nevals : 0.0);
                elapsed = ((WeightedConvolutionCost)f).getElapsedTime() - elapsed;
                System.out.format("Total time in other parts of the convolution operator: %.3f s (%.3f ms/eval.)\n", elapsed, nevals > 0 ? 1000.0 * elapsed / (double)nevals : 0.0);
            }
            System.out.format("min(x) = %g\n", ArrayOps.getMin(x.getData()));
            System.out.format("max(x) = %g\n", ArrayOps.getMax(x.getData()));
        }
    }

    @Override
    public int getIterations() {
        return this.minimizer == null ? 0 : this.minimizer.getIterations();
    }

    @Override
    public int getEvaluations() {
        return this.minimizer == null ? 0 : this.minimizer.getEvaluations();
    }

    @Override
    public double getCost() {
        return this.fcost;
    }

    @Override
    public double getGradientNorm2() {
        return this.gcost == null ? 0.0 : this.gcost.norm2();
    }

    @Override
    public double getGradientNorm1() {
        return this.gcost == null ? 0.0 : this.gcost.norm1();
    }

    @Override
    public double getGradientNormInf() {
        return this.gcost == null ? 0.0 : this.gcost.normInf();
    }
}

