package mitiv.deconv;

import mitiv.array.ShapedArray;
import mitiv.base.Shape;
import mitiv.cost.DifferentiableCostFunction;
import mitiv.deconv.impl.WeightedConvolutionDouble1D;
import mitiv.deconv.impl.WeightedConvolutionDouble2D;
import mitiv.deconv.impl.WeightedConvolutionDouble3D;
import mitiv.deconv.impl.WeightedConvolutionFloat1D;
import mitiv.deconv.impl.WeightedConvolutionFloat2D;
import mitiv.deconv.impl.WeightedConvolutionFloat3D;
import mitiv.exception.IllegalTypeException;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;
import mitiv.linalg.shaped.ShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;
import mitiv.utils.Timer;

/* loaded from: input_file:mitiv/deconv/WeightedConvolutionCost.class */
/**
 * Abstract base for cost functions built on a weighted convolution between an
 * object space and a data space.
 *
 * <p>This class stores the two vector spaces, validates the arguments of the
 * {@link #evaluate(double, Vector)} and
 * {@link #computeCostAndGradient(double, Vector, Vector, boolean)} entry points,
 * and delegates the actual computation to the abstract {@code cost} methods.
 * Concrete instances are obtained through the static {@code build} factories,
 * which pick an implementation from the element type and rank of the spaces.
 */
public abstract class WeightedConvolutionCost implements DifferentiableCostFunction {
    /**
     * Element-type code dispatched to the single-precision implementations.
     * NOTE(review): presumably mirrors the project's element-type identifier
     * for {@code float}; confirm against the type constants used by
     * {@code ShapedVectorSpace.getType()}.
     */
    private static final int TYPE_FLOAT = 4;

    /**
     * Element-type code dispatched to the double-precision implementations.
     * NOTE(review): presumably mirrors the project's element-type identifier
     * for {@code double}; confirm as above.
     */
    private static final int TYPE_DOUBLE = 5;

    /** Space of the variables; also returned as the input space of the cost. */
    protected final ShapedVectorSpace objectSpace;

    /** Space of the data. */
    protected final ShapedVectorSpace dataSpace;

    /**
     * Timer reported by {@link #getElapsedTimeInFFT()}.
     * NOTE(review): presumably started/stopped by subclasses around their FFT
     * calls; nothing in this class starts it.
     */
    protected Timer timerForFFT = new Timer();

    /**
     * Timer reported by {@link #getElapsedTime()}.
     * NOTE(review): presumably started/stopped by subclasses; nothing in this
     * class starts it.
     */
    protected Timer timer = new Timer();

    /**
     * Stores the object and data spaces.
     *
     * <p>NOTE(review): the decompiler reported that this constructor was
     * originally declared {@code protected}; it is kept {@code public} so that
     * existing callers keep compiling.
     *
     * @param shapedVectorSpace  the object space
     * @param shapedVectorSpace2 the data space
     */
    public WeightedConvolutionCost(ShapedVectorSpace shapedVectorSpace, ShapedVectorSpace shapedVectorSpace2) {
        this.objectSpace = shapedVectorSpace;
        this.dataSpace = shapedVectorSpace2;
    }

    /**
     * Returns the object space.
     *
     * @return the space given as first constructor argument
     */
    public VectorSpace getObjectSpace() {
        return this.objectSpace;
    }

    /**
     * Returns the input space of the cost function, which is the object space.
     *
     * @return the object space
     */
    @Override // mitiv.cost.CostFunction
    public VectorSpace getInputSpace() {
        return this.objectSpace;
    }

    /**
     * Returns the data space.
     *
     * @return the space given as second constructor argument
     */
    public VectorSpace getDataSpace() {
        return this.dataSpace;
    }

    /**
     * Builds a cost whose object and data spaces are the same space.
     *
     * @param shapedVectorSpace the space used for both the object and the data
     * @return an implementation matching the element type and rank of the space
     */
    public static WeightedConvolutionCost build(ShapedVectorSpace shapedVectorSpace) {
        return build(shapedVectorSpace, shapedVectorSpace);
    }

    /**
     * Builds a cost with default offsets: along each dimension the offset is
     * {@code inputDim/2 - outputDim/2} (integer divisions).
     *
     * <p>Offsets are only computed for the dimensions common to both spaces;
     * a rank mismatch is reported by
     * {@link #build(ShapedVectorSpace, ShapedVectorSpace, int[])}.
     *
     * @param shapedVectorSpace  the object (input) space
     * @param shapedVectorSpace2 the data (output) space
     * @return an implementation matching the element type and rank of the spaces
     */
    public static WeightedConvolutionCost build(ShapedVectorSpace shapedVectorSpace, ShapedVectorSpace shapedVectorSpace2) {
        final int commonRank = Math.min(shapedVectorSpace.getRank(), shapedVectorSpace2.getRank());
        final int[] offsets = new int[commonRank];
        for (int k = 0; k < commonRank; k++) {
            offsets[k] = (shapedVectorSpace.getDimension(k) / 2) - (shapedVectorSpace2.getDimension(k) / 2);
        }
        return build(shapedVectorSpace, shapedVectorSpace2, offsets);
    }

    /**
     * Builds a cost for the given spaces and offsets, selecting the
     * implementation from the element type and the rank of the spaces.
     *
     * @param shapedVectorSpace  the object (input) space
     * @param shapedVectorSpace2 the data (output) space
     * @param iArr               the offsets forwarded unchanged to the selected
     *                           implementation
     * @return a float or double implementation for 1, 2 or 3 dimensions
     * @throws IllegalTypeException     if the two spaces differ in element type
     *                                  or rank, or if the element type is
     *                                  neither float nor double
     * @throws IllegalArgumentException if the element type is supported but the
     *                                  rank is not 1, 2 or 3
     */
    public static WeightedConvolutionCost build(ShapedVectorSpace shapedVectorSpace, ShapedVectorSpace shapedVectorSpace2, int[] iArr) {
        final int type = shapedVectorSpace.getType();
        if (shapedVectorSpace2.getType() != type) {
            throw new IllegalTypeException("Input and output spaces must have same element type.");
        }
        final int rank = shapedVectorSpace.getRank();
        if (shapedVectorSpace2.getShape().rank() != rank) {
            throw new IllegalTypeException("Input and output spaces must have same rank.");
        }
        switch (type) {
            case TYPE_FLOAT:
                switch (rank) {
                    case 1:
                        return new WeightedConvolutionFloat1D(shapedVectorSpace, shapedVectorSpace2, iArr);
                    case 2:
                        return new WeightedConvolutionFloat2D(shapedVectorSpace, shapedVectorSpace2, iArr);
                    case 3:
                        return new WeightedConvolutionFloat3D(shapedVectorSpace, shapedVectorSpace2, iArr);
                    default:
                        break;
                }
                // Supported type but unsupported rank: leave the outer switch so
                // that the rank error below is raised instead of falling through
                // into the next case and reporting a bogus type error.
                break;
            case TYPE_DOUBLE:
                switch (rank) {
                    case 1:
                        return new WeightedConvolutionDouble1D(shapedVectorSpace, shapedVectorSpace2, iArr);
                    case 2:
                        return new WeightedConvolutionDouble2D(shapedVectorSpace, shapedVectorSpace2, iArr);
                    case 3:
                        return new WeightedConvolutionDouble3D(shapedVectorSpace, shapedVectorSpace2, iArr);
                    default:
                        break;
                }
                break;
            default:
                throw new IllegalTypeException("Only float and double types are implemented.");
        }
        throw new IllegalArgumentException("Only 1D, 2D and 3D convolution are implemented.");
    }

    /** Throws if {@code x} does not belong to the object space. */
    private final void checkObject(Vector x) {
        if (!x.belongsTo(this.objectSpace)) {
            throw new IllegalArgumentException("Variables X does not belong to the object space.");
        }
    }

    /** Throws if {@code gx} does not belong to the object space. */
    private final void checkGradient(Vector gx) {
        if (!gx.belongsTo(this.objectSpace)) {
            throw new IllegalArgumentException("Gradient GX does not belong to the object space.");
        }
    }

    /**
     * Evaluates the cost at the given variables.
     *
     * <p>The variables are validated first; when {@code d} is exactly zero the
     * result is {@code 0} and {@link #cost(double, Vector)} is not called.
     *
     * @param d      factor forwarded to {@link #cost(double, Vector)}
     *               (NOTE(review): presumably a multiplier of the cost; confirm
     *               against the {@code CostFunction} contract)
     * @param vector the variables, which must belong to the object space
     * @return {@code 0} if {@code d == 0}, otherwise the value returned by
     *         {@link #cost(double, Vector)}
     * @throws IllegalArgumentException if {@code vector} is not in the object space
     */
    @Override // mitiv.cost.CostFunction
    public double evaluate(double d, Vector vector) {
        checkObject(vector);
        if (d == 0.0d) {
            return 0.0d;
        }
        return cost(d, vector);
    }

    /**
     * Evaluates the cost and its gradient at the given variables.
     *
     * <p>Both vectors are validated first. When {@code d} is exactly zero the
     * result is {@code 0}; in that case the gradient is zero-filled if
     * {@code z} is true and left untouched otherwise.
     *
     * @param d       factor forwarded to
     *                {@link #cost(double, Vector, Vector, boolean)}
     *                (NOTE(review): presumably a multiplier of the cost; confirm)
     * @param vector  the variables, which must belong to the object space
     * @param vector2 the gradient, which must belong to the object space
     * @param z       forwarded to the implementation; when {@code d == 0} it
     *                decides whether the gradient is zero-filled
     *                (NOTE(review): presumably "clear the gradient before
     *                accumulating"; confirm)
     * @return {@code 0} if {@code d == 0}, otherwise the value returned by
     *         {@link #cost(double, Vector, Vector, boolean)}
     * @throws IllegalArgumentException if either vector is not in the object space
     */
    @Override // mitiv.cost.DifferentiableCostFunction
    public double computeCostAndGradient(double d, Vector vector, Vector vector2, boolean z) {
        checkObject(vector);
        checkGradient(vector2);
        if (d != 0.0d) {
            return cost(d, vector, vector2, z);
        }
        if (z) {
            vector2.zero();
        }
        return 0.0d;
    }

    /**
     * Computes the cost for a non-zero {@code d}; called by
     * {@link #evaluate(double, Vector)} after the variables have been checked.
     */
    protected abstract double cost(double d, Vector vector);

    /**
     * Computes the cost and the gradient for a non-zero {@code d}; called by
     * {@link #computeCostAndGradient(double, Vector, Vector, boolean)} after
     * both vectors have been checked.
     */
    protected abstract double cost(double d, Vector vector, Vector vector2, boolean z);

    /**
     * Sets the weights and the data from vectors.
     * NOTE(review): argument order is presumably (weights, data) as the method
     * name suggests; confirm in the implementations.
     */
    public abstract void setWeightsAndData(ShapedVector shapedVector, ShapedVector shapedVector2);

    /**
     * Sets the weights and the data from shaped arrays.
     * NOTE(review): argument order is presumably (weights, data) as the method
     * name suggests; confirm in the implementations.
     */
    public abstract void setWeightsAndData(ShapedArray shapedArray, ShapedArray shapedArray2);

    /**
     * Always throws to report invalid weights.
     *
     * <p>NOTE(review): the decompiler reported that this method was originally
     * declared {@code protected}; it is kept {@code public} so that existing
     * callers keep compiling.
     *
     * @throws IllegalArgumentException always
     */
    public static void badWeights() {
        throw new IllegalArgumentException("Weights must be finite and non-negative.");
    }

    /** Sets the point spread function from a vector. */
    public abstract void setPSF(ShapedVector shapedVector);

    /**
     * Sets the point spread function from an array, using the offsets returned
     * by {@code Convolution.center} for the shape of that array.
     *
     * @param shapedArray the point spread function
     */
    public void setPSF(ShapedArray shapedArray) {
        setPSF(shapedArray, Convolution.center(shapedArray.getShape()));
    }

    /**
     * Sets the point spread function from an array and explicit offsets.
     *
     * @param shapedArray the point spread function
     * @param iArr        the offsets associated with the point spread function
     */
    public abstract void setPSF(ShapedArray shapedArray, int[] iArr);

    /**
     * Computes the linear offset, inside the input shape, of the first element
     * of the output region.
     *
     * <p>The offset is accumulated as {@code sum(offset[k] * stride[k])} where
     * the stride of the first dimension is 1 and each following stride is the
     * previous one multiplied by the corresponding input dimension.
     *
     * @param i      the expected rank of both shapes
     * @param shape  the input shape
     * @param shape2 the output shape
     * @param iArr   the per-dimension offsets of the output region, or
     *               {@code null} to use {@code inputDim/2 - outputDim/2} along
     *               each dimension
     * @return the linear offset of the output region in the input
     * @throws IllegalArgumentException if a rank differs from {@code i}, if
     *         {@code iArr} is non-null with a length other than {@code i}, if an
     *         output dimension exceeds the input dimension, or if an explicit
     *         offset places the output region outside the input
     */
    protected static int outputOffset(int i, Shape shape, Shape shape2, int[] iArr) {
        if (shape.rank() != i) {
            throw new IllegalArgumentException("Bad rank for input space.");
        }
        if (shape2.rank() != i) {
            throw new IllegalArgumentException("Bad rank for output space.");
        }
        if (iArr != null && iArr.length != i) {
            throw new IllegalArgumentException("Bad number of coordinates for the first position");
        }
        int linearOffset = 0;
        int stride = 1;
        for (int k = 0; k < i; k++) {
            final int inputDim = shape.dimension(k);
            final int outputDim = shape2.dimension(k);
            if (outputDim > inputDim) {
                throw new IllegalArgumentException("Output dimensions must be at most as large as input dimensions");
            }
            final int offset;
            if (iArr == null) {
                // Default offset along this dimension.
                offset = (inputDim / 2) - (outputDim / 2);
            } else {
                offset = iArr[k];
                if (offset < 0 || offset + outputDim > inputDim) {
                    throw new IllegalArgumentException("Output region is outside bounds");
                }
            }
            linearOffset += stride * offset;
            stride *= inputDim;
        }
        return linearOffset;
    }

    /** Stops and resets both timers. */
    public void resetTimers() {
        this.timerForFFT.stop();
        this.timerForFFT.reset();
        this.timer.stop();
        this.timer.reset();
    }

    /**
     * Returns the elapsed time reported by the general timer.
     *
     * @return the value of {@code timer.getElapsedTime()}
     */
    public double getElapsedTime() {
        return this.timer.getElapsedTime();
    }

    /**
     * Returns the elapsed time reported by the FFT timer.
     *
     * @return the value of {@code timerForFFT.getElapsedTime()}
     */
    public double getElapsedTimeInFFT() {
        return this.timerForFFT.getElapsedTime();
    }
}
