package mitiv.cost;

import mitiv.base.Shape;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.FloatShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;

/* loaded from: input_file:mitiv/cost/HyperbolicTotalVariation.class */
/**
 * Hyperbolic approximation of the total variation of a 1-D, 2-D or 3-D
 * shaped vector, usable as a differentiable cost function.
 *
 * <p>For every elementary cell of the array (a pair of neighbours in 1-D, a
 * 2x2 cell in 2-D, a 2x2x2 cell in 3-D) the cost accumulates
 * {@code sqrt(weighted sum of squared finite differences + epsilon^2)}, then
 * subtracts {@code epsilon} once per cell, clamps the result to be
 * non-negative and multiplies it by the caller-supplied factor.  The finite
 * differences along dimension {@code k} are divided by {@code scale[k]}.
 *
 * <p>The element type is selected by the type code reported by the input
 * space: code {@code 4} is processed as {@code float} data and code
 * {@code 5} as {@code double} data; any other code is rejected.
 */
public class HyperbolicTotalVariation implements DifferentiableCostFunction {
    /** Vector space of the arguments of the cost function (may be null until set). */
    protected ShapedVectorSpace inputSpace;

    /** Number of dimensions of {@link #shape} (0 when there is no shape). */
    protected int rank;

    /** Element type code of the input space (4 = float path, 5 = double path). */
    protected int type;

    /** Shape of the vectors of the input space. */
    protected Shape shape;

    /** Threshold of the hyperbolic approximation. */
    protected double epsilon;

    /** Scale of the finite differences along each dimension ({@code rank} entries). */
    protected double[] scale;

    /**
     * Default scale given at construction time, applied by
     * {@link #defaultScale()} once an input space is known; null when unset.
     */
    protected double[] delta;

    /**
     * Creates a cost function for a given input space with unit scales.
     *
     * @param shapedVectorSpace the input space (must not be null)
     * @param d                 the threshold (finite and strictly positive)
     * @throws IllegalArgumentException if the threshold is invalid
     */
    public HyperbolicTotalVariation(ShapedVectorSpace shapedVectorSpace, double d) {
        this.delta = null;
        this.inputSpace = shapedVectorSpace;
        this.shape = shapedVectorSpace.getShape();
        this.rank = this.shape == null ? 0 : this.shape.rank();
        this.type = shapedVectorSpace.getType();
        setThreshold(d);
        this.scale = new double[this.rank];
        defaultScale();
    }

    /**
     * Creates a cost function for a given input space and given scales.
     *
     * @param shapedVectorSpace the input space (must not be null)
     * @param d                 the threshold (finite and strictly positive)
     * @param dArr              the scales, see {@link #setScale(double[])}
     * @throws IllegalArgumentException if the threshold or the scales are invalid
     */
    public HyperbolicTotalVariation(ShapedVectorSpace shapedVectorSpace, double d, double[] dArr) {
        this(shapedVectorSpace, d);
        setScale(dArr);
    }

    /**
     * Creates a cost function without an input space; the given scales are
     * remembered and applied when {@link #setInputSpace} is called.
     *
     * @param d    the threshold (finite and strictly positive)
     * @param dArr the default scales (non-null, all finite and strictly positive)
     * @throws IllegalArgumentException if the threshold or the scales are invalid
     */
    public HyperbolicTotalVariation(double d, double[] dArr) {
        this.delta = null;
        setThreshold(d);
        setDelta(dArr);
    }

    /**
     * Creates a cost function without an input space and without a threshold.
     *
     * <p>NOTE(review): this constructor leaves {@code epsilon} at 0.0; callers
     * presumably have to call {@link #setThreshold(double)} before use - confirm.
     *
     * @param dArr the default scales (non-null, all finite and strictly positive)
     * @throws IllegalArgumentException if the scales are invalid
     */
    public HyperbolicTotalVariation(double[] dArr) {
        this.delta = null;
        setDelta(dArr);
    }

    /**
     * Creates a cost function without an input space.
     *
     * @param d the threshold (finite and strictly positive)
     * @throws IllegalArgumentException if the threshold is invalid
     */
    public HyperbolicTotalVariation(double d) {
        this.delta = null;
        setThreshold(d);
    }

    /**
     * Sets the threshold of the hyperbolic approximation.
     *
     * @param d the new threshold
     * @throws IllegalArgumentException if {@code d} is NaN, infinite or not
     *                                  strictly positive
     */
    public void setThreshold(double d) {
        if (notFinite(d) || d <= 0.0d) {
            throw new IllegalArgumentException("Bad threshold value");
        }
        this.epsilon = d;
    }

    /**
     * Returns the threshold of the hyperbolic approximation.
     *
     * @return the current value of the threshold
     */
    public double getThreshold() {
        return this.epsilon;
    }

    /**
     * Resets the scales to their default: the scales remembered at
     * construction time if any, 1.0 along every dimension otherwise.
     *
     * <p>NOTE(review): applying the remembered scales goes through
     * {@link #setScale(double[])}, which clears them, so a later call falls
     * back to unit scales - confirm this is intended.
     */
    public void defaultScale() {
        if (this.delta == null) {
            setScale(1.0d);
        } else {
            setScale(this.delta);
        }
    }

    /**
     * Validates and stores a private copy of the default scales.
     *
     * @param dArr the default scales
     * @throws IllegalArgumentException if {@code dArr} is null or has an entry
     *                                  which is NaN, infinite or not strictly
     *                                  positive
     */
    private void setDelta(double[] dArr) {
        if (dArr == null) {
            throw new IllegalArgumentException("Bad delta value");
        }
        // Iterate over the argument: this.delta is still null at this point.
        for (int i = 0; i < dArr.length; i++) {
            if (notFinite(dArr[i]) || dArr[i] <= 0.0d) {
                throw new IllegalArgumentException("Bad delta value");
            }
        }
        this.delta = dArr.clone();
    }

    /**
     * Sets the same scale along every dimension and forgets any default
     * scales remembered at construction time.
     *
     * @param d the scale
     * @throws IllegalArgumentException if {@code d} is NaN, infinite or not
     *                                  strictly positive
     */
    public void setScale(double d) {
        if (notFinite(d) || d <= 0.0d) {
            throw new IllegalArgumentException("Bad delta value");
        }
        for (int i = 0; i < this.rank; i++) {
            this.scale[i] = d;
        }
        this.delta = null;
    }

    /**
     * Sets the scale along each dimension and forgets any default scales
     * remembered at construction time.
     *
     * @param dArr either a single value used for all dimensions, or exactly
     *             one value per dimension
     * @throws IllegalArgumentException if {@code dArr} is null, has a wrong
     *                                  length, or has an entry which is NaN,
     *                                  infinite or not strictly positive
     */
    public void setScale(double[] dArr) {
        if (dArr == null) {
            throw new IllegalArgumentException("Bad scale size");
        }
        if (dArr.length == 1) {
            // A single value applies to all dimensions (it is the element at index 0).
            setScale(dArr[0]);
        } else {
            if (dArr.length != this.rank) {
                throw new IllegalArgumentException("Bad scale size");
            }
            // Validate everything first so that a bad entry leaves the scales untouched.
            for (int i = 0; i < this.rank; i++) {
                if (notFinite(dArr[i]) || dArr[i] <= 0.0d) {
                    throw new IllegalArgumentException("Bad scale value");
                }
            }
            for (int i2 = 0; i2 < this.rank; i2++) {
                this.scale[i2] = dArr[i2];
            }
        }
        this.delta = null;
    }

    /**
     * Returns the scale along a given dimension.
     *
     * @param i the index of the dimension
     * @return the scale along dimension {@code i}, or 1.0 if {@code i} is not
     *         a valid dimension index
     */
    public double getScale(int i) {
        if (i < 0 || i >= this.rank) {
            return 1.0d;
        }
        return this.scale[i];
    }

    /**
     * Returns the input space of the cost function.
     *
     * @return the input space (null if none has been set)
     */
    @Override // mitiv.cost.CostFunction
    public VectorSpace getInputSpace() {
        return this.inputSpace;
    }

    /**
     * Sets the input space, re-allocates the scales for the new rank and
     * resets them with {@link #defaultScale()}.
     *
     * @param shapedVectorSpace the new input space (must not be null)
     */
    public final void setInputSpace(ShapedVectorSpace shapedVectorSpace) {
        this.inputSpace = shapedVectorSpace;
        this.shape = shapedVectorSpace.getShape();
        this.rank = this.shape == null ? 0 : this.shape.rank();
        this.type = shapedVectorSpace.getType();
        this.scale = new double[this.rank];
        defaultScale();
    }

    /**
     * Computes the cost and, optionally, accumulates its gradient.
     *
     * @param d       the multiplier of the cost
     * @param vector  the variables; cast to the shaped vector class matching
     *                the element type of the input space
     * @param vector2 the vector receiving the gradient, or null to compute
     *                the cost only; the gradient is added to its contents
     * @param z       true to zero-fill {@code vector2} before accumulating
     * @return the cost multiplied by {@code d}
     * @throws IllegalArgumentException if the element type or the number of
     *                                  dimensions is not supported
     */
    @Override // mitiv.cost.DifferentiableCostFunction
    public double computeCostAndGradient(double d, Vector vector, Vector vector2, boolean z) {
        if (vector2 != null && z) {
            vector2.fill(0.0d);
        }
        if (this.type == 4) {
            // Single precision data.
            float[] data = ((FloatShapedVector) vector).getData();
            float[] fArr = null;
            if (vector2 != null) {
                fArr = ((FloatShapedVector) vector2).getData();
            }
            if (this.rank == 1) {
                return computeFloat1D(d, data, fArr);
            }
            if (this.rank == 2) {
                return computeFloat2D(d, data, fArr);
            }
            if (this.rank == 3) {
                return computeFloat3D(d, data, fArr);
            }
            badRank();
            return 0.0d;
        }
        if (this.type != 5) {
            badType();
            return 0.0d;
        }
        // Double precision data.
        double[] data2 = ((DoubleShapedVector) vector).getData();
        double[] dArr = null;
        if (vector2 != null) {
            dArr = ((DoubleShapedVector) vector2).getData();
        }
        if (this.rank == 1) {
            return computeDouble1D(d, data2, dArr);
        }
        if (this.rank == 2) {
            return computeDouble2D(d, data2, dArr);
        }
        if (this.rank == 3) {
            return computeDouble3D(d, data2, dArr);
        }
        badRank();
        return 0.0d;
    }

    /**
     * Reports an unsupported number of dimensions.
     *
     * @throws IllegalArgumentException always
     */
    protected static void badRank() {
        throw new IllegalArgumentException("Unsupported number of dimensions for Total Variation");
    }

    /**
     * Reports an unsupported element type.
     *
     * @throws IllegalArgumentException always
     */
    protected static void badType() {
        throw new IllegalArgumentException("Unsupported data type for Total Variation");
    }

    /**
     * 1-D kernel, single precision.
     *
     * <p>Sums {@code sqrt((x[i] - x[i-1])^2 + (epsilon*scale[0])^2)} over all
     * pairs of neighbours, divides by {@code scale[0]} and subtracts
     * {@code epsilon} once per pair.
     *
     * @param d     the multiplier of the cost
     * @param fArr  the variables
     * @param fArr2 the gradient to increment, or null to skip the gradient
     * @return {@code d} times the cost if the cost is positive, 0 otherwise
     */
    private final double computeFloat1D(double d, float[] fArr, float[] fArr2) {
        boolean z = fArr2 != null;
        int dimension = this.shape.dimension(0);
        float square = (float) square(this.epsilon * this.scale[0]);
        double d2 = 0.0d;
        float f = (float) (d / this.scale[0]);
        for (int i = 1; i < dimension; i++) {
            float f2 = fArr[i] - fArr[i - 1];
            float sqrt = sqrt((f2 * f2) + square);
            d2 += sqrt;
            if (z) {
                // The difference depends on two elements, with opposite signs.
                float f3 = f * (f2 / sqrt);
                fArr2[i - 1] -= f3;
                fArr2[i] += f3;
            }
        }
        double d3 = (d2 / this.scale[0]) - ((dimension - 1) * this.epsilon);
        if (d3 > 0.0d) {
            return d * d3;
        }
        return 0.0d;
    }

    /**
     * 1-D kernel, double precision.
     *
     * <p>Sums {@code sqrt((x[i] - x[i-1])^2 + (epsilon*scale[0])^2)} over all
     * pairs of neighbours, divides by {@code scale[0]} and subtracts
     * {@code epsilon} once per pair.
     *
     * @param d     the multiplier of the cost
     * @param dArr  the variables
     * @param dArr2 the gradient to increment, or null to skip the gradient
     * @return {@code d} times the cost if the cost is positive, 0 otherwise
     */
    private final double computeDouble1D(double d, double[] dArr, double[] dArr2) {
        boolean z = dArr2 != null;
        int dimension = this.shape.dimension(0);
        double square = square(this.epsilon * this.scale[0]);
        double d2 = 0.0d;
        double d3 = d / this.scale[0];
        for (int i = 1; i < dimension; i++) {
            double d4 = dArr[i] - dArr[i - 1];
            double sqrt = sqrt((d4 * d4) + square);
            d2 += sqrt;
            if (z) {
                // The difference depends on two elements, with opposite signs.
                double d5 = d3 * (d4 / sqrt);
                dArr2[i - 1] -= d5;
                dArr2[i] += d5;
            }
        }
        double d6 = (d2 / this.scale[0]) - ((dimension - 1) * this.epsilon);
        if (d6 > 0.0d) {
            return d * d6;
        }
        return 0.0d;
    }

    /**
     * 2-D kernel, single precision.
     *
     * <p>Elements are addressed as {@code i0 + dim0*i1}.  For every 2x2 cell
     * the four finite differences along its edges are squared, weighted by
     * {@code 1/(2*scale[k]^2)} for the dimension {@code k} they run along,
     * summed with {@code epsilon^2} and square-rooted.  {@code epsilon} is
     * subtracted once per cell from the total.
     *
     * @param d     the multiplier of the cost
     * @param fArr  the variables
     * @param fArr2 the gradient to increment, or null to skip the gradient
     * @return {@code d} times the cost if the cost is positive, 0 otherwise
     */
    private final double computeFloat2D(double d, float[] fArr, float[] fArr2) {
        boolean z = fArr2 != null;
        int dimension = this.shape.dimension(0);
        int dimension2 = this.shape.dimension(1);
        float square = (float) (1.0d / (2.0d * square(this.scale[0])));
        float square2 = (float) (1.0d / (2.0d * square(this.scale[1])));
        float square3 = (float) square(this.epsilon);
        double d2 = 0.0d;
        float f = (float) d;
        if (square == square2) {
            // Same weight along both dimensions: factor it out of the sum.
            for (int i = 1; i < dimension2; i++) {
                int i2 = (i - 1) * dimension;
                int i3 = i * dimension;
                float f2 = fArr[i2];
                float f3 = fArr[i3];
                for (int i4 = 1; i4 < dimension; i4++) {
                    // Slide the 2x2 cell by one element along the first dimension;
                    // i5/i6 index the previous column, i2/i3 the current one.
                    int i5 = i2;
                    i2++;
                    int i6 = i3;
                    i3++;
                    float f4 = f2;
                    f2 = fArr[i2];
                    float f5 = f3;
                    f3 = fArr[i3];
                    float f6 = f2 - f4;
                    float f7 = f3 - f5;
                    float f8 = f5 - f4;
                    float f9 = f3 - f2;
                    float sqrt = sqrt(((square(f6) + square(f7) + square(f8) + square(f9)) * square) + square3);
                    d2 += sqrt;
                    if (z) {
                        float f10 = (f * square) / sqrt;
                        fArr2[i5] = fArr2[i5] - ((f6 + f8) * f10);
                        fArr2[i2] = fArr2[i2] + ((f6 - f9) * f10);
                        fArr2[i6] = fArr2[i6] - ((f7 - f8) * f10);
                        fArr2[i3] = fArr2[i3] + ((f7 + f9) * f10);
                    }
                }
            }
        } else {
            // Different weights along the two dimensions.
            for (int i7 = 1; i7 < dimension2; i7++) {
                int i8 = (i7 - 1) * dimension;
                int i9 = i7 * dimension;
                float f11 = fArr[i8];
                float f12 = fArr[i9];
                for (int i10 = 1; i10 < dimension; i10++) {
                    int i11 = i8;
                    i8++;
                    float f13 = f11;
                    f11 = fArr[i8];
                    int i12 = i9;
                    i9++;
                    float f14 = f12;
                    f12 = fArr[i9];
                    float f15 = f11 - f13;
                    float f16 = f12 - f14;
                    float f17 = f14 - f13;
                    float f18 = f12 - f11;
                    float sqrt2 = sqrt(((square(f15) + square(f16)) * square) + ((square(f17) + square(f18)) * square2) + square3);
                    d2 += sqrt2;
                    if (z) {
                        float f19 = f / sqrt2;
                        float f20 = square * f19;
                        float f21 = f15 * f20;
                        float f22 = f16 * f20;
                        float f23 = square2 * f19;
                        float f24 = f17 * f23;
                        float f25 = f18 * f23;
                        fArr2[i11] = fArr2[i11] - (f21 + f24);
                        fArr2[i8] = fArr2[i8] + (f21 - f25);
                        fArr2[i12] = fArr2[i12] - (f22 - f24);
                        fArr2[i9] = fArr2[i9] + f22 + f25;
                    }
                }
            }
        }
        double d3 = d2 - (((dimension - 1) * (dimension2 - 1)) * this.epsilon);
        if (d3 > 0.0d) {
            return d * d3;
        }
        return 0.0d;
    }

    /**
     * 2-D kernel, double precision.
     *
     * <p>Elements are addressed as {@code i0 + dim0*i1}.  For every 2x2 cell
     * the four finite differences along its edges are squared, weighted by
     * {@code 1/(2*scale[k]^2)} for the dimension {@code k} they run along,
     * summed with {@code epsilon^2} and square-rooted.  {@code epsilon} is
     * subtracted once per cell from the total.
     *
     * @param d     the multiplier of the cost
     * @param dArr  the variables
     * @param dArr2 the gradient to increment, or null to skip the gradient
     * @return {@code d} times the cost if the cost is positive, 0 otherwise
     */
    private final double computeDouble2D(double d, double[] dArr, double[] dArr2) {
        boolean z = dArr2 != null;
        int dimension = this.shape.dimension(0);
        int dimension2 = this.shape.dimension(1);
        double square = 1.0d / (2.0d * square(this.scale[0]));
        double square2 = 1.0d / (2.0d * square(this.scale[1]));
        double square3 = square(this.epsilon);
        double d2 = 0.0d;
        if (square == square2) {
            // Same weight along both dimensions: factor it out of the sum.
            for (int i = 1; i < dimension2; i++) {
                int i2 = (i - 1) * dimension;
                int i3 = i * dimension;
                double d3 = dArr[i2];
                double d4 = dArr[i3];
                for (int i4 = 1; i4 < dimension; i4++) {
                    // Slide the 2x2 cell by one element along the first dimension;
                    // i5/i6 index the previous column, i2/i3 the current one.
                    int i5 = i2;
                    i2++;
                    int i6 = i3;
                    i3++;
                    double d5 = d3;
                    d3 = dArr[i2];
                    double d6 = d4;
                    d4 = dArr[i3];
                    double d7 = d3 - d5;
                    double d8 = d4 - d6;
                    double d9 = d6 - d5;
                    double d10 = d4 - d3;
                    double sqrt = sqrt(((square(d7) + square(d8) + square(d9) + square(d10)) * square) + square3);
                    d2 += sqrt;
                    if (z) {
                        double d11 = (d * square) / sqrt;
                        dArr2[i5] = dArr2[i5] - ((d7 + d9) * d11);
                        dArr2[i2] = dArr2[i2] + ((d7 - d10) * d11);
                        dArr2[i6] = dArr2[i6] - ((d8 - d9) * d11);
                        dArr2[i3] = dArr2[i3] + ((d8 + d10) * d11);
                    }
                }
            }
        } else {
            // Different weights along the two dimensions.
            for (int i7 = 1; i7 < dimension2; i7++) {
                int i8 = (i7 - 1) * dimension;
                int i9 = i7 * dimension;
                double d12 = dArr[i8];
                double d13 = dArr[i9];
                for (int i10 = 1; i10 < dimension; i10++) {
                    int i11 = i8;
                    i8++;
                    double d14 = d12;
                    d12 = dArr[i8];
                    int i12 = i9;
                    i9++;
                    double d15 = d13;
                    d13 = dArr[i9];
                    double d16 = d12 - d14;
                    double d17 = d13 - d15;
                    double d18 = d15 - d14;
                    double d19 = d13 - d12;
                    double sqrt2 = sqrt(((square(d16) + square(d17)) * square) + ((square(d18) + square(d19)) * square2) + square3);
                    d2 += sqrt2;
                    if (z) {
                        double d20 = d / sqrt2;
                        double d21 = square * d20;
                        double d22 = d16 * d21;
                        double d23 = d17 * d21;
                        double d24 = square2 * d20;
                        double d25 = d18 * d24;
                        double d26 = d19 * d24;
                        dArr2[i11] = dArr2[i11] - (d22 + d25);
                        dArr2[i8] = dArr2[i8] + (d22 - d26);
                        dArr2[i12] = dArr2[i12] - (d23 - d25);
                        dArr2[i9] = dArr2[i9] + d23 + d26;
                    }
                }
            }
        }
        double d27 = d2 - (((dimension - 1) * (dimension2 - 1)) * this.epsilon);
        if (d27 > 0.0d) {
            return d * d27;
        }
        return 0.0d;
    }

    /**
     * 3-D kernel, single precision.
     *
     * <p>Elements are addressed as {@code i0 + dim0*(i1 + dim1*i2)}.  For
     * every 2x2x2 cell the twelve finite differences along its edges are
     * squared, weighted by {@code 1/(4*scale[k]^2)} for the dimension
     * {@code k} they run along, summed with {@code epsilon^2} and
     * square-rooted.  {@code epsilon} is subtracted once per cell from the
     * total.
     *
     * @param d     the multiplier of the cost
     * @param fArr  the variables
     * @param fArr2 the gradient to increment, or null to skip the gradient
     * @return {@code d} times the cost if the cost is positive, 0 otherwise
     */
    private final double computeFloat3D(double d, float[] fArr, float[] fArr2) {
        boolean z = fArr2 != null;
        int dimension = this.shape.dimension(0);
        int dimension2 = this.shape.dimension(1);
        int dimension3 = this.shape.dimension(2);
        float square = (float) (1.0d / (4.0d * square(this.scale[0])));
        float square2 = (float) (1.0d / (4.0d * square(this.scale[1])));
        float square3 = (float) (1.0d / (4.0d * square(this.scale[2])));
        float square4 = (float) square(this.epsilon);
        double d2 = 0.0d;
        float f = (float) d;
        for (int i = 1; i < dimension3; i++) {
            for (int i2 = 1; i2 < dimension2; i2++) {
                // Start offsets of the four rows forming the edges of the cell
                // which run along the first dimension.
                int i3 = ((i2 - 1) + ((i - 1) * dimension2)) * dimension;
                int i4 = (i2 + ((i - 1) * dimension2)) * dimension;
                int i5 = ((i2 - 1) + (i * dimension2)) * dimension;
                int i6 = (i2 + (i * dimension2)) * dimension;
                float f2 = fArr[i3];
                float f3 = fArr[i4];
                float f4 = fArr[i5];
                float f5 = fArr[i6];
                for (int i7 = 1; i7 < dimension; i7++) {
                    // Slide the cell by one element along the first dimension,
                    // keeping the previous index and value of each row.
                    int i8 = i3;
                    i3++;
                    float f6 = f2;
                    f2 = fArr[i3];
                    int i9 = i4;
                    i4++;
                    float f7 = f3;
                    f3 = fArr[i4];
                    int i10 = i5;
                    i5++;
                    float f8 = f4;
                    f4 = fArr[i5];
                    int i11 = i6;
                    i6++;
                    float f9 = f5;
                    f5 = fArr[i6];
                    // Differences along the first dimension.
                    float f10 = f2 - f6;
                    float f11 = f3 - f7;
                    float f12 = f4 - f8;
                    float f13 = f5 - f9;
                    float square5 = square(f10) + square(f11) + square(f12) + square(f13);
                    // Differences along the second dimension.
                    float f14 = f7 - f6;
                    float f15 = f3 - f2;
                    float f16 = f9 - f8;
                    float f17 = f5 - f4;
                    float square6 = square(f14) + square(f15) + square(f16) + square(f17);
                    // Differences along the third dimension.
                    float f18 = f8 - f6;
                    float f19 = f4 - f2;
                    float f20 = f9 - f7;
                    float f21 = f5 - f3;
                    float sqrt = sqrt((square * square5) + (square2 * square6) + (square3 * (square(f18) + square(f19) + square(f20) + square(f21))) + square4);
                    d2 += sqrt;
                    if (z) {
                        float f22 = f / sqrt;
                        float f23 = square * f22;
                        float f24 = f10 * f23;
                        float f25 = f11 * f23;
                        float f26 = f12 * f23;
                        float f27 = f13 * f23;
                        float f28 = square2 * f22;
                        float f29 = f14 * f28;
                        float f30 = f15 * f28;
                        float f31 = f16 * f28;
                        float f32 = f17 * f28;
                        float f33 = square3 * f22;
                        float f34 = f18 * f33;
                        float f35 = f19 * f33;
                        float f36 = f20 * f33;
                        float f37 = f21 * f33;
                        fArr2[i8] = fArr2[i8] - ((f24 + f29) + f34);
                        fArr2[i3] = fArr2[i3] + ((f24 - f30) - f35);
                        fArr2[i9] = fArr2[i9] - ((f25 - f29) + f36);
                        fArr2[i4] = fArr2[i4] + ((f25 + f30) - f37);
                        fArr2[i10] = fArr2[i10] - ((f26 + f31) - f34);
                        fArr2[i5] = fArr2[i5] + (f26 - f32) + f35;
                        fArr2[i11] = fArr2[i11] - ((f27 - f31) - f36);
                        fArr2[i6] = fArr2[i6] + f27 + f32 + f37;
                    }
                }
            }
        }
        double d3 = d2 - ((((dimension - 1) * (dimension2 - 1)) * (dimension3 - 1)) * this.epsilon);
        if (d3 > 0.0d) {
            return d * d3;
        }
        return 0.0d;
    }

    /**
     * 3-D kernel, double precision.
     *
     * <p>Elements are addressed as {@code i0 + dim0*(i1 + dim1*i2)}.  For
     * every 2x2x2 cell the twelve finite differences along its edges are
     * squared, weighted by {@code 1/(4*scale[k]^2)} for the dimension
     * {@code k} they run along, summed with {@code epsilon^2} and
     * square-rooted.  {@code epsilon} is subtracted once per cell from the
     * total.
     *
     * @param d     the multiplier of the cost
     * @param dArr  the variables
     * @param dArr2 the gradient to increment, or null to skip the gradient
     * @return {@code d} times the cost if the cost is positive, 0 otherwise
     */
    private final double computeDouble3D(double d, double[] dArr, double[] dArr2) {
        boolean z = dArr2 != null;
        int dimension = this.shape.dimension(0);
        int dimension2 = this.shape.dimension(1);
        int dimension3 = this.shape.dimension(2);
        double square = 1.0d / (4.0d * square(this.scale[0]));
        double square2 = 1.0d / (4.0d * square(this.scale[1]));
        double square3 = 1.0d / (4.0d * square(this.scale[2]));
        double square4 = square(this.epsilon);
        double d2 = 0.0d;
        for (int i = 1; i < dimension3; i++) {
            for (int i2 = 1; i2 < dimension2; i2++) {
                // Start offsets of the four rows forming the edges of the cell
                // which run along the first dimension.
                int i3 = ((i2 - 1) + ((i - 1) * dimension2)) * dimension;
                int i4 = (i2 + ((i - 1) * dimension2)) * dimension;
                int i5 = ((i2 - 1) + (i * dimension2)) * dimension;
                int i6 = (i2 + (i * dimension2)) * dimension;
                double d3 = dArr[i3];
                double d4 = dArr[i4];
                double d5 = dArr[i5];
                double d6 = dArr[i6];
                for (int i7 = 1; i7 < dimension; i7++) {
                    // Slide the cell by one element along the first dimension,
                    // keeping the previous index and value of each row.
                    int i8 = i3;
                    i3++;
                    double d7 = d3;
                    d3 = dArr[i3];
                    int i9 = i4;
                    i4++;
                    double d8 = d4;
                    d4 = dArr[i4];
                    int i10 = i5;
                    i5++;
                    double d9 = d5;
                    d5 = dArr[i5];
                    int i11 = i6;
                    i6++;
                    double d10 = d6;
                    d6 = dArr[i6];
                    // Differences along the first dimension.
                    double d11 = d3 - d7;
                    double d12 = d4 - d8;
                    double d13 = d5 - d9;
                    double d14 = d6 - d10;
                    double square5 = square(d11) + square(d12) + square(d13) + square(d14);
                    // Differences along the second dimension.
                    double d15 = d8 - d7;
                    double d16 = d4 - d3;
                    double d17 = d10 - d9;
                    double d18 = d6 - d5;
                    double square6 = square(d15) + square(d16) + square(d17) + square(d18);
                    // Differences along the third dimension.
                    double d19 = d9 - d7;
                    double d20 = d5 - d3;
                    double d21 = d10 - d8;
                    double d22 = d6 - d4;
                    double sqrt = sqrt((square * square5) + (square2 * square6) + (square3 * (square(d19) + square(d20) + square(d21) + square(d22))) + square4);
                    d2 += sqrt;
                    if (z) {
                        double d23 = d / sqrt;
                        double d24 = square * d23;
                        double d25 = d11 * d24;
                        double d26 = d12 * d24;
                        double d27 = d13 * d24;
                        double d28 = d14 * d24;
                        double d29 = square2 * d23;
                        double d30 = d15 * d29;
                        double d31 = d16 * d29;
                        double d32 = d17 * d29;
                        double d33 = d18 * d29;
                        double d34 = square3 * d23;
                        double d35 = d19 * d34;
                        double d36 = d20 * d34;
                        double d37 = d21 * d34;
                        double d38 = d22 * d34;
                        dArr2[i8] = dArr2[i8] - ((d25 + d30) + d35);
                        dArr2[i3] = dArr2[i3] + ((d25 - d31) - d36);
                        dArr2[i9] = dArr2[i9] - ((d26 - d30) + d37);
                        dArr2[i4] = dArr2[i4] + ((d26 + d31) - d38);
                        dArr2[i10] = dArr2[i10] - ((d27 + d32) - d35);
                        dArr2[i5] = dArr2[i5] + (d27 - d33) + d36;
                        dArr2[i11] = dArr2[i11] - ((d28 - d32) - d37);
                        dArr2[i6] = dArr2[i6] + d28 + d33 + d38;
                    }
                }
            }
        }
        double d39 = d2 - ((((dimension - 1) * (dimension2 - 1)) * (dimension3 - 1)) * this.epsilon);
        if (d39 > 0.0d) {
            return d * d39;
        }
        return 0.0d;
    }

    /** Returns true if {@code d} is NaN or infinite. */
    private static final boolean notFinite(double d) {
        return Double.isInfinite(d) || Double.isNaN(d);
    }

    /** Returns the square root of {@code f}, rounded to single precision. */
    private static final float sqrt(float f) {
        return (float) Math.sqrt(f);
    }

    /** Returns the square root of {@code d}. */
    private static final double sqrt(double d) {
        return Math.sqrt(d);
    }

    /** Returns {@code f} squared, in single precision. */
    private static final float square(float f) {
        return f * f;
    }

    /** Returns {@code d} squared. */
    private static final double square(double d) {
        return d * d;
    }

    /**
     * Computes the cost only, without any gradient.
     *
     * @param d      the multiplier of the cost
     * @param vector the variables
     * @return the cost multiplied by {@code d}
     */
    @Override // mitiv.cost.CostFunction
    public double evaluate(double d, Vector vector) {
        return computeCostAndGradient(d, vector, null, false);
    }
}
