package mitiv.cost;

import java.util.Arrays;

import mitiv.array.ArrayFactory;
import mitiv.array.ArrayUtils;
import mitiv.array.ByteArray;
import mitiv.array.DoubleArray;
import mitiv.array.FloatArray;
import mitiv.array.IntArray;
import mitiv.array.LongArray;
import mitiv.array.ShapedArray;
import mitiv.array.ShortArray;
import mitiv.base.Shape;
import mitiv.deconv.Convolution;
import mitiv.deconv.WeightedConvolutionCost;
import mitiv.linalg.Vector;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.DoubleShapedVectorSpace;
import mitiv.linalg.shaped.FloatShapedVector;
import mitiv.linalg.shaped.FloatShapedVectorSpace;
import mitiv.linalg.shaped.ShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;
import mitiv.optim.OptimTask;
import mitiv.utils.FFTUtils;

/* loaded from: input_file:mitiv/cost/EdgePreservingDeconvolution.class */
public class EdgePreservingDeconvolution extends SmoothInverseProblem {
    private boolean single;
    private boolean updatePending = true;
    private ShapedVectorSpace dataSpace = null;
    private ShapedVectorSpace objectSpace = null;
    private Vector x = null;
    private ShapedArray data = null;
    private boolean writableData = false;
    private ShapedArray weights = null;
    private boolean writableWeights = false;
    private double sigma = Double.NaN;
    private double gamma = Double.NaN;
    private ShapedArray bads = null;
    private WeightedData weightedData = null;
    private ShapedArray psf = null;
    private boolean normalizePSF = false;
    private ShapedArray object = null;
    private Shape objectShape = null;
    private double padValue = Double.NaN;
    private double epsilon = 1.0d;
    private double[] scale = {1.0d};
    private boolean useNewCode = false;

    private void forceRestart() {
        this.weightedData = null;
        this.updatePending = true;
    }

    public boolean getUseNewCode() {
        return this.useNewCode;
    }

    public void setUseNewCode(boolean z) {
        if (this.useNewCode != z) {
            this.useNewCode = z;
            forceRestart();
        }
    }

    public boolean getForceSinglePrecision() {
        return this.single;
    }

    public void setForceSinglePrecision(boolean z) {
        if (this.single != z) {
            this.single = z;
            forceRestart();
        }
    }

    public ShapedArray getData() {
        return this.data;
    }

    public void setData(ShapedArray shapedArray, boolean z) {
        if (this.data != shapedArray) {
            this.data = shapedArray;
            this.writableData = z;
            forceRestart();
        }
    }

    public void setData(ShapedArray shapedArray) {
        setData(shapedArray, false);
    }

    public ShapedArray getWeights() {
        return this.weights;
    }

    public void setWeights(ShapedArray shapedArray, boolean z) {
        if (this.weights != shapedArray) {
            this.weights = shapedArray;
            this.writableWeights = z;
            forceRestart();
        }
    }

    public void setWeights(ShapedArray shapedArray) {
        setWeights(shapedArray, false);
    }

    public ShapedArray getBads() {
        return this.bads;
    }

    public void setBads(ShapedArray shapedArray) {
        if (this.bads != shapedArray) {
            this.bads = shapedArray;
            forceRestart();
        }
    }

    public ShapedArray getPSF() {
        return this.psf;
    }

    public void setPSF(ShapedArray shapedArray) {
        setPSF(shapedArray, false);
    }

    public void setPSF(ShapedArray shapedArray, boolean z) {
        if (this.psf != shapedArray) {
            this.psf = shapedArray;
            this.normalizePSF = z;
            forceRestart();
        }
    }

    public double getEdgeThreshold() {
        return this.epsilon;
    }

    public void setEdgeThreshold(double d) {
        if (nonfinite(d) || d <= 0.0d) {
            error("Edge threshold must be strictly positive");
        }
        if (this.epsilon != d) {
            this.epsilon = d;
            forceRestart();
        }
    }

    public void setScale(double... dArr) {
        this.scale = dArr;
    }

    public double[] getScale() {
        return this.scale;
    }

    public ShapedArray getSolution() {
        return this.object;
    }

    @Override // mitiv.optim.IterativeDifferentiableSolver
    public ShapedVector getBestSolution() {
        return (ShapedVector) super.getBestSolution();
    }

    public void setInitialSolution(ShapedArray shapedArray) {
        if (this.object != shapedArray) {
            this.object = shapedArray;
            forceRestart();
            resetIteration();
        }
    }

    public Shape getObjectShape() {
        return this.objectShape;
    }

    public void setObjectShape(Shape shape) {
        if ((shape == null) == (this.objectShape == null) && (shape == null || this.objectShape == null || shape.equals(this.objectShape))) {
            return;
        }
        this.objectShape = shape;
        forceRestart();
    }

    public double getFillValue() {
        return this.padValue;
    }

    public void setFillValue(double d) {
        this.padValue = d;
    }

    public void setObjectShape(int[] iArr) {
        setObjectShape(new Shape(iArr));
    }

    private void update() {
        if (this.data == null) {
            error("No data specified");
        }
        int rank = this.data.getRank();
        Shape shape = this.data.getShape();
        if (this.weights != null && !this.weights.getShape().equals(shape)) {
            error("Weights and data must have the same dimensions");
        }
        if (this.bads != null && !this.bads.getShape().equals(shape)) {
            error("Mask of invalid data must have the same dimensions as the data");
        }
        if (this.psf != null && this.psf.getRank() != rank) {
            error("PSF and data must have the same number of dimensions");
        }
        if (this.object != null && this.object.getRank() != rank) {
            error("Object and data must have the same number of dimensions");
        }
        if (this.objectShape != null && this.objectShape.rank() != rank) {
            error("Given object shape must the same number of dimensions as the data");
        }
        if (this.debug) {
            System.out.format("mu: %.2g, epsilon: %.2g\n", Double.valueOf(getRegularizationLevel()), Double.valueOf(getEdgeThreshold()));
        }
        int i = this.single ? 4 : (this.data.getType() == 5 || (this.psf != null && this.psf.getType() == 5) || ((this.weights != null && this.weights.getType() == 5) || (this.object != null && this.object.getType() == 5))) ? 5 : 4;
        if (this.psf == null) {
            this.objectShape = shape;
        } else {
            Shape shape2 = this.psf.getShape();
            if (this.objectShape != null) {
                for (int i2 = 0; i2 < rank; i2++) {
                    if (this.objectShape.dimension(i2) < shape.dimension(i2)) {
                        error("Given object dimensions must be at least those of the data");
                    }
                    if (shape2 != null && this.objectShape.dimension(i2) < shape2.dimension(i2)) {
                        error("Given object dimensions must be at least those of the PSF");
                    }
                }
            } else {
                int[] iArr = new int[rank];
                for (int i3 = 0; i3 < rank; i3++) {
                    int dimension = (shape.dimension(i3) + shape2.dimension(i3)) - 1;
                    if (this.object != null) {
                        dimension = Math.max(dimension, this.object.getDimension(i3));
                    }
                    iArr[i3] = FFTUtils.bestDimension(dimension);
                }
                this.objectShape = new Shape(iArr);
            }
        }
        if (i == 4) {
            if (this.dataSpace == null) {
                this.dataSpace = new FloatShapedVectorSpace(shape);
            }
            if (this.objectSpace == null) {
                this.objectSpace = new FloatShapedVectorSpace(this.objectShape);
            }
        } else {
            if (this.dataSpace == null) {
                this.dataSpace = new DoubleShapedVectorSpace(shape);
            }
            if (this.objectSpace == null) {
                this.objectSpace = new DoubleShapedVectorSpace(this.objectShape);
            }
        }
        if (this.psf == null) {
            this.weightedData = new WeightedData(this.dataSpace);
            setWeightsAndData(this.weightedData);
            setLikelihood(this.weightedData);
        } else if (this.useNewCode) {
            this.weightedData = new WeightedData(this.dataSpace);
            setWeightsAndData(this.weightedData);
            Convolution build = Convolution.build(this.objectSpace, this.dataSpace);
            build.setPSF(this.psf, this.normalizePSF);
            setLikelihood(new DifferentiableGaussianLikelihood(this.weightedData, build));
        } else {
            WeightedConvolutionCost build2 = WeightedConvolutionCost.build(this.objectSpace, this.dataSpace);
            setWeightsAndData(build2);
            build2.setPSF(this.psf, this.normalizePSF);
            setLikelihood(build2);
            this.weightedData = build2;
        }
        if (this.object == null) {
            double computePadValue = computePadValue();
            this.object = ArrayFactory.create(i, this.objectShape);
            if (this.single) {
                ((FloatArray) this.object).fill((float) computePadValue);
            } else {
                ((DoubleArray) this.object).fill(computePadValue);
            }
            if (this.debug) {
                System.err.format("Create initial array with value %g\n", Double.valueOf(computePadValue));
            }
        } else {
            double d = 0.0d;
            int i4 = 0;
            while (true) {
                if (i4 >= rank) {
                    break;
                }
                if (this.objectShape.dimension(i4) > this.object.getDimension(i4)) {
                    d = computePadValue();
                    break;
                }
                i4++;
            }
            if (this.debug) {
                System.err.format("Pad initial array with value %g\n", Double.valueOf(d));
            }
            this.object = ArrayUtils.extract(this.object, this.objectShape, d);
        }
        HyperbolicTotalVariation hyperbolicTotalVariation = new HyperbolicTotalVariation(this.objectSpace, this.epsilon);
        if (this.scale.length == 1) {
            hyperbolicTotalVariation.setScale(this.scale[0]);
        } else {
            hyperbolicTotalVariation.setScale(this.scale);
        }
        setRegularization(hyperbolicTotalVariation);
        boolean z = (this.object.getType() == i && this.object.isFlat()) ? false : true;
        this.x = this.objectSpace.create(this.object, false);
        if (z) {
            if (i == 4) {
                this.object = ArrayFactory.wrap(((FloatShapedVector) this.x).getData(), this.objectShape);
            } else {
                this.object = ArrayFactory.wrap(((DoubleShapedVector) this.x).getData(), this.objectShape);
            }
        }
        this.updatePending = false;
    }

    public OptimTask start() {
        return start(false);
    }

    public OptimTask start(boolean z) {
        if (this.updatePending) {
            update();
        }
        return super.start(this.x, z);
    }

    public OptimTask iterate() {
        return this.updatePending ? start() : super.iterate(this.x);
    }

    private static void error(String str) {
        throw new IllegalArgumentException(str);
    }

    private boolean nonfinite(double d) {
        return Double.isInfinite(d) || Double.isNaN(d);
    }

    public double getDetectorNoise() {
        return this.sigma;
    }

    public void setDetectorNoise(double d) {
        this.sigma = d;
    }

    public double getDetectorGain() {
        return this.gamma;
    }

    public void setDetectorGain(double d) {
        this.gamma = d;
    }

    private void setWeightsAndData(WeightedData weightedData) {
        double d;
        double abs2;
        weightedData.setData(this.data, this.writableData);
        if (this.weights != null) {
            if (!isnan(this.sigma) || !isnan(this.gamma)) {
                System.err.println("Warning: noise model parameters are ignored when weights are specified.");
            }
            weightedData.setWeights(this.weights, this.writableWeights);
        } else {
            if (isnan(this.sigma)) {
                if (!isnan(this.gamma)) {
                    System.err.println("Warning: linear noise model parameter is ignored if affine noise model parameter is not specified");
                }
                d = 0.0d;
                abs2 = 1.0d;
            } else if (isnan(this.gamma)) {
                d = 0.0d;
                abs2 = abs2(this.sigma);
            } else {
                d = 1.0d / this.gamma;
                abs2 = abs2(this.sigma / this.gamma);
            }
            System.err.format("alpha = %g, beta = %g\n", Double.valueOf(d), Double.valueOf(abs2));
            weightedData.computeWeightsFromData(d, abs2);
        }
        if (this.bads != null) {
            weightedData.markBadData(this.bads);
        }
    }

    private double computePadValue() {
        double d;
        if (isnan(this.padValue)) {
            d = this.weightedData.getWeightedMean();
            if (this.psf != null && !this.normalizePSF) {
                d /= sum(this.psf);
            }
        } else {
            d = this.padValue;
        }
        return d;
    }

    private static double sum(ShapedArray shapedArray) {
        double d = 0.0d;
        if (shapedArray != null) {
            switch (shapedArray.getType()) {
                case 0:
                    d = ((ByteArray) shapedArray).sum();
                    break;
                case 1:
                    d = ((ShortArray) shapedArray).sum();
                    break;
                case 2:
                    d = ((IntArray) shapedArray).sum();
                    break;
                case 3:
                    d = ((LongArray) shapedArray).sum();
                    break;
                case 4:
                    d = ((FloatArray) shapedArray).sum();
                    break;
                case 5:
                    d = ((DoubleArray) shapedArray).sum();
                    break;
            }
        }
        return d;
    }

    private static final boolean isnan(double d) {
        return Double.isNaN(d);
    }

    private static final double abs2(double d) {
        return d * d;
    }
}
