/*
 * Decompiled with CFR 0.152.
 */
package mitiv.cost;

import mitiv.base.Shape;
import mitiv.cost.DifferentiableCostFunction;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.FloatShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;

/**
 * Hyperbolic (smoothed) total variation cost for 1-D, 2-D and 3-D shaped
 * vectors of float or double values.
 * <p>
 * For every elementary cell of {@code 2^rank} neighbouring samples the cost
 * adds {@code sqrt(g + epsilon^2) - epsilon}, where {@code g} is a weighted
 * sum of squared finite differences inside the cell:
 * <ul>
 * <li>1-D: {@code g = (x[i] - x[i-1])^2 / delta[0]^2};</li>
 * <li>2-D: the 2 differences along each dimension {@code k} are squared,
 * summed and weighted by {@code 1/(2*delta[k]^2)};</li>
 * <li>3-D: the 4 differences along each dimension {@code k} are squared,
 * summed and weighted by {@code 1/(4*delta[k]^2)}.</li>
 * </ul>
 * The sum is multiplied by the caller-supplied weight. It is returned as 0
 * when the unweighted sum is not positive. A constant input therefore costs 0.
 * <p>
 * The kernels index the data so that the first dimension varies fastest:
 * sample {@code (i0, i1, i2)} is at {@code i0 + n0*(i1 + n1*i2)}, where
 * {@code nk = shape.dimension(k)}.
 * <p>
 * Note: this source was produced by a decompiler (CFR), so local variable
 * names are synthetic.
 */
public class HyperbolicTotalVariation
implements DifferentiableCostFunction {
    /** Space of the variables, returned by {@link #getInputSpace()}. */
    protected final ShapedVectorSpace inputSpace;
    /** Number of dimensions of the space (0 when the space has no shape). */
    protected int rank;
    /**
     * Element type code of the space: 4 selects the float kernels, 5 the
     * double kernels. NOTE(review): presumably the inlined values of the
     * mitiv type-trait constants for float/double; confirm.
     */
    protected int type;
    /** Shape of the variables; its dimensions drive the kernel loops. */
    protected Shape shape;
    /** Threshold {@code epsilon > 0} of the hyperbolic approximation. */
    protected double epsilon;
    /** Per-dimension scale {@code delta[k] > 0} (presumably the sampling step). */
    protected double[] delta;

    /**
     * Creates a hyperbolic TV cost with unit scale along every dimension.
     *
     * @param shapedVectorSpace the input space; its shape and element type
     *        select the kernel used by {@link #computeCostAndGradient}
     * @param d the threshold epsilon; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code d} is not a valid threshold
     */
    public HyperbolicTotalVariation(ShapedVectorSpace shapedVectorSpace, double d) {
        this.inputSpace = shapedVectorSpace;
        this.shape = shapedVectorSpace.getShape();
        this.rank = this.shape == null ? 0 : this.shape.rank();
        this.type = shapedVectorSpace.getType();
        // NOTE(review): calls overridable public methods from the constructor.
        this.setThreshold(d);
        // delta must be allocated before defaultScale() fills it.
        this.delta = new double[this.rank];
        this.defaultScale();
    }

    /**
     * Creates a hyperbolic TV cost with explicit per-dimension scales.
     *
     * @param shapedVectorSpace the input space
     * @param d the threshold epsilon; must be finite and strictly positive
     * @param dArray one scale per dimension; each must be finite and strictly
     *        positive
     * @throws IllegalArgumentException if the threshold or the scales are invalid
     */
    public HyperbolicTotalVariation(ShapedVectorSpace shapedVectorSpace, double d, double[] dArray) {
        this(shapedVectorSpace, d);
        this.setScale(dArray);
    }

    /**
     * Sets the threshold epsilon.
     *
     * @param d new threshold; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code d} is NaN, infinite or {@code <= 0}
     */
    public void setThreshold(double d) {
        if (HyperbolicTotalVariation.notFinite(d) || d <= 0.0) {
            throw new IllegalArgumentException("Bad threshold value");
        }
        this.epsilon = d;
    }

    /**
     * Returns the threshold epsilon.
     *
     * @return the current threshold
     */
    public double getThreshold() {
        return this.epsilon;
    }

    /** Resets the scale along every dimension to 1. */
    public void defaultScale() {
        this.setScale(1.0);
    }

    /**
     * Sets the same scale along every dimension.
     *
     * @param d the scale; must be finite and strictly positive
     * @throws IllegalArgumentException if {@code d} is NaN, infinite or {@code <= 0}
     */
    public void setScale(double d) {
        if (HyperbolicTotalVariation.notFinite(d) || d <= 0.0) {
            throw new IllegalArgumentException("Bad scale value");
        }
        for (int i = 0; i < this.rank; ++i) {
            this.delta[i] = d;
        }
    }

    /**
     * Sets a scale per dimension. All values are validated before any is
     * stored, so an invalid array leaves the current scales unchanged.
     *
     * @param dArray the scales, one per dimension; each must be finite and
     *        strictly positive
     * @throws IllegalArgumentException if {@code dArray} is null, its length
     *         differs from the rank, or any value is invalid
     */
    public void setScale(double[] dArray) {
        int n;
        if (dArray == null || dArray.length != this.rank) {
            throw new IllegalArgumentException("Bad scale size");
        }
        // First pass: validate only.
        for (n = 0; n < this.rank; ++n) {
            if (!HyperbolicTotalVariation.notFinite(dArray[n]) && !(dArray[n] <= 0.0)) continue;
            throw new IllegalArgumentException("Bad scale value");
        }
        // Second pass: copy (reached only if every value is valid).
        for (n = 0; n < this.rank; ++n) {
            this.delta[n] = dArray[n];
        }
    }

    /**
     * Returns the scale along a dimension.
     *
     * @param n the dimension index
     * @return {@code delta[n]}, or 1.0 if {@code n} is outside {@code [0, rank)}
     */
    public double getScale(int n) {
        return n >= 0 && n < this.rank ? this.delta[n] : 1.0;
    }

    /**
     * Returns the space of the variables.
     *
     * @return the space given at construction
     */
    @Override
    public VectorSpace getInputSpace() {
        return this.inputSpace;
    }

    /**
     * Computes the weighted cost and, optionally, its gradient.
     *
     * @param d weight applied to both the returned cost and the gradient
     *        contribution
     * @param vector the variables. They are cast to {@link FloatShapedVector}
     *        or {@link DoubleShapedVector} according to {@link #type}; no
     *        membership check is made here, so a mismatched vector raises a
     *        ClassCastException.
     * @param vector2 gradient destination, or {@code null} to compute only the
     *        cost. The weighted gradient is added to its current contents.
     * @param bl if {@code true} and {@code vector2} is non-null,
     *        {@code vector2} is zero-filled before accumulation
     * @return the weighted cost
     * @throws IllegalArgumentException if the rank is not 1, 2 or 3, or the
     *         element type code is neither 4 nor 5
     */
    @Override
    public double computeCostAndGradient(double d, Vector vector, Vector vector2, boolean bl) {
        if (vector2 != null && bl) {
            vector2.fill(0.0);
        }
        // Type code 4: float data.
        if (this.type == 4) {
            float[] fArray = ((FloatShapedVector)vector).getData();
            float[] fArray2 = null;
            if (vector2 != null) {
                fArray2 = ((FloatShapedVector)vector2).getData();
            }
            if (this.rank == 1) {
                return this.computeFloat1D(d, fArray, fArray2);
            }
            if (this.rank == 2) {
                return this.computeFloat2D(d, fArray, fArray2);
            }
            if (this.rank == 3) {
                return this.computeFloat3D(d, fArray, fArray2);
            }
            HyperbolicTotalVariation.badRank();
        // Type code 5: double data.
        } else if (this.type == 5) {
            double[] dArray = ((DoubleShapedVector)vector).getData();
            double[] dArray2 = null;
            if (vector2 != null) {
                dArray2 = ((DoubleShapedVector)vector2).getData();
            }
            if (this.rank == 1) {
                return this.computeDouble1D(d, dArray, dArray2);
            }
            if (this.rank == 2) {
                return this.computeDouble2D(d, dArray, dArray2);
            }
            if (this.rank == 3) {
                return this.computeDouble3D(d, dArray, dArray2);
            }
            HyperbolicTotalVariation.badRank();
        } else {
            HyperbolicTotalVariation.badType();
        }
        // Unreachable: badRank()/badType() always throw. This return only
        // satisfies the compiler.
        return 0.0;
    }

    /**
     * Always throws; signals an unsupported rank.
     *
     * @throws IllegalArgumentException always
     */
    protected static void badRank() {
        throw new IllegalArgumentException("Unsupported number of dimensions for Total Variation");
    }

    /**
     * Always throws; signals an unsupported element type.
     *
     * @throws IllegalArgumentException always
     */
    protected static void badType() {
        throw new IllegalArgumentException("Unsupported data type for Total Variation");
    }

    /**
     * 1-D kernel for float data.
     *
     * @param d weight
     * @param fArray variables
     * @param fArray2 gradient accumulator, or {@code null}
     * @return weighted cost, clamped to 0 when the unweighted sum is not positive
     */
    private final double computeFloat1D(double d, float[] fArray, float[] fArray2) {
        boolean bl = fArray2 != null;
        int n = this.shape.dimension(0);
        // f = (epsilon*delta)^2, so sqrt(dx^2 + f)/delta == sqrt((dx/delta)^2 + epsilon^2).
        float f = (float)HyperbolicTotalVariation.square(this.epsilon * this.delta[0]);
        // The sum is kept in double (presumably to limit round-off); the division by delta is done once at the end.
        double d2 = 0.0;
        // Weight divided by delta: the factor of the gradient of sqrt(dx^2 + f)/delta.
        float f2 = (float)(d / this.delta[0]);
        for (int i = 1; i < n; ++i) {
            float f3 = fArray[i] - fArray[i - 1];
            float f4 = HyperbolicTotalVariation.sqrt(f3 * f3 + f);
            d2 += (double)f4;
            if (!bl) continue;
            // d(term)/d(x[i]) = +dx/r and d(term)/d(x[i-1]) = -dx/r, times weight/delta.
            float f5 = f2 * (f3 / f4);
            int n2 = i - 1;
            fArray2[n2] = fArray2[n2] - f5;
            int n3 = i;
            fArray2[n3] = fArray2[n3] + f5;
        }
        // Subtract one epsilon per difference so that a constant signal costs 0; clamp negatives to 0.
        return (d2 = d2 / this.delta[0] - (double)(n - 1) * this.epsilon) > 0.0 ? d * d2 : 0.0;
    }

    /**
     * 1-D kernel for double data; same formula as {@link #computeFloat1D}.
     *
     * @param d weight
     * @param dArray variables
     * @param dArray2 gradient accumulator, or {@code null}
     * @return weighted cost, clamped to 0 when the unweighted sum is not positive
     */
    private final double computeDouble1D(double d, double[] dArray, double[] dArray2) {
        boolean bl = dArray2 != null;
        int n = this.shape.dimension(0);
        // d2 = (epsilon*delta)^2, so sqrt(dx^2 + d2)/delta == sqrt((dx/delta)^2 + epsilon^2).
        double d2 = HyperbolicTotalVariation.square(this.epsilon * this.delta[0]);
        double d3 = 0.0;
        // Weight divided by delta.
        double d4 = d / this.delta[0];
        for (int i = 1; i < n; ++i) {
            double d5 = dArray[i] - dArray[i - 1];
            double d6 = HyperbolicTotalVariation.sqrt(d5 * d5 + d2);
            d3 += d6;
            if (!bl) continue;
            double d7 = d4 * (d5 / d6);
            int n2 = i - 1;
            dArray2[n2] = dArray2[n2] - d7;
            int n3 = i;
            dArray2[n3] = dArray2[n3] + d7;
        }
        // Subtract one epsilon per difference; clamp negatives to 0.
        return (d3 = d3 / this.delta[0] - (double)(n - 1) * this.epsilon) > 0.0 ? d * d3 : 0.0;
    }

    /**
     * 2-D kernel for float data: one term per 2x2 cell.
     *
     * @param d weight
     * @param fArray variables (dimension 0 varies fastest)
     * @param fArray2 gradient accumulator, or {@code null}
     * @return weighted cost, clamped to 0 when the unweighted sum is not positive
     */
    private final double computeFloat2D(double d, float[] fArray, float[] fArray2) {
        boolean bl = fArray2 != null;
        int n = this.shape.dimension(0);
        int n2 = this.shape.dimension(1);
        // Weights of the squared differences along dimension 0 (f) and 1 (f2).
        float f = (float)(1.0 / (2.0 * HyperbolicTotalVariation.square(this.delta[0])));
        float f2 = (float)(1.0 / (2.0 * HyperbolicTotalVariation.square(this.delta[1])));
        float f3 = (float)HyperbolicTotalVariation.square(this.epsilon);
        double d2 = 0.0;
        float f4 = (float)d;
        if (f == f2) {
            // Fast path: both dimensions share the same weight f5.
            float f5 = f;
            for (int i = 1; i < n2; ++i) {
                // n3 walks row i-1, n4 walks row i (rows along dimension 1).
                int n3 = (i - 1) * n;
                int n4 = i * n;
                float f6 = fArray[n3];
                float f7 = fArray[n4];
                for (int j = 1; j < n; ++j) {
                    // Slide the window along dimension 0. After these lines:
                    // f8 = x[j-1,i-1] (index n5), f6 = x[j,i-1] (index n3),
                    // f9 = x[j-1,i] (index n6),   f7 = x[j,i]   (index n4).
                    int n5 = n3++;
                    int n6 = n4++;
                    float f8 = f6;
                    f6 = fArray[n3];
                    float f9 = f7;
                    f7 = fArray[n4];
                    // f10, f11: differences along dimension 0; f12, f13: along dimension 1.
                    float f10 = f6 - f8;
                    float f11 = f7 - f9;
                    float f12 = f9 - f8;
                    float f13 = f7 - f6;
                    float f14 = HyperbolicTotalVariation.sqrt((HyperbolicTotalVariation.square(f10) + HyperbolicTotalVariation.square(f11) + HyperbolicTotalVariation.square(f12) + HyperbolicTotalVariation.square(f13)) * f5 + f3);
                    d2 += (double)f14;
                    if (!bl) continue;
                    // Common factor weight*f5/r; each sample gets the signed sum of the differences it enters.
                    float f15 = f4 * f5 / f14;
                    int n7 = n5;
                    fArray2[n7] = fArray2[n7] - (f10 + f12) * f15;
                    int n8 = n3;
                    fArray2[n8] = fArray2[n8] + (f10 - f13) * f15;
                    int n9 = n6;
                    fArray2[n9] = fArray2[n9] - (f11 - f12) * f15;
                    int n10 = n4;
                    fArray2[n10] = fArray2[n10] + (f11 + f13) * f15;
                }
            }
        } else {
            // General path: different weights along the two dimensions.
            for (int i = 1; i < n2; ++i) {
                int n11 = (i - 1) * n;
                int n12 = i * n;
                float f16 = fArray[n11];
                float f17 = fArray[n12];
                for (int j = 1; j < n; ++j) {
                    // After these lines:
                    // f18 = x[j-1,i-1] (index n13), f16 = x[j,i-1] (index n11),
                    // f19 = x[j-1,i] (index n14),   f17 = x[j,i]   (index n12).
                    int n13 = n11++;
                    float f18 = f16;
                    f16 = fArray[n11];
                    int n14 = n12++;
                    float f19 = f17;
                    f17 = fArray[n12];
                    // f20, f21: differences along dimension 0; f22, f23: along dimension 1.
                    float f20 = f16 - f18;
                    float f21 = f17 - f19;
                    float f22 = f19 - f18;
                    float f23 = f17 - f16;
                    float f24 = HyperbolicTotalVariation.sqrt((HyperbolicTotalVariation.square(f20) + HyperbolicTotalVariation.square(f21)) * f + (HyperbolicTotalVariation.square(f22) + HyperbolicTotalVariation.square(f23)) * f2 + f3);
                    d2 += (double)f24;
                    if (!bl) continue;
                    // The differences are rescaled in place (f26 for dimension 0, f27 for dimension 1).
                    // Some rescalings happen inside the update expressions below; each
                    // difference is scaled before its first use, so statement order matters.
                    float f25 = f4 / f24;
                    float f26 = f * f25;
                    f21 *= f26;
                    float f27 = f2 * f25;
                    int n15 = n13;
                    fArray2[n15] = fArray2[n15] - ((f20 *= f26) + (f22 *= f27));
                    int n16 = n11;
                    fArray2[n16] = fArray2[n16] + (f20 - (f23 *= f27));
                    int n17 = n14;
                    fArray2[n17] = fArray2[n17] - (f21 - f22);
                    int n18 = n12;
                    fArray2[n18] = fArray2[n18] + (f21 + f23);
                }
            }
        }
        // Subtract one epsilon per 2x2 cell; clamp negatives to 0.
        return (d2 -= (double)((n - 1) * (n2 - 1)) * this.epsilon) > 0.0 ? d * d2 : 0.0;
    }

    /**
     * 2-D kernel for double data; same formula and layout as
     * {@link #computeFloat2D}.
     *
     * @param d weight
     * @param dArray variables (dimension 0 varies fastest)
     * @param dArray2 gradient accumulator, or {@code null}
     * @return weighted cost, clamped to 0 when the unweighted sum is not positive
     */
    private final double computeDouble2D(double d, double[] dArray, double[] dArray2) {
        boolean bl = dArray2 != null;
        int n = this.shape.dimension(0);
        int n2 = this.shape.dimension(1);
        // Weights of the squared differences along dimension 0 (d2) and 1 (d3).
        double d2 = 1.0 / (2.0 * HyperbolicTotalVariation.square(this.delta[0]));
        double d3 = 1.0 / (2.0 * HyperbolicTotalVariation.square(this.delta[1]));
        double d4 = HyperbolicTotalVariation.square(this.epsilon);
        double d5 = 0.0;
        if (d2 == d3) {
            // Fast path: both dimensions share the same weight d6.
            double d6 = d2;
            for (int i = 1; i < n2; ++i) {
                int n3 = (i - 1) * n;
                int n4 = i * n;
                double d7 = dArray[n3];
                double d8 = dArray[n4];
                for (int j = 1; j < n; ++j) {
                    // After these lines:
                    // d9 = x[j-1,i-1] (index n5), d7 = x[j,i-1] (index n3),
                    // d10 = x[j-1,i] (index n6),  d8 = x[j,i]   (index n4).
                    int n5 = n3++;
                    int n6 = n4++;
                    double d9 = d7;
                    d7 = dArray[n3];
                    double d10 = d8;
                    d8 = dArray[n4];
                    // d11, d12: differences along dimension 0; d13, d14: along dimension 1.
                    double d11 = d7 - d9;
                    double d12 = d8 - d10;
                    double d13 = d10 - d9;
                    double d14 = d8 - d7;
                    double d15 = HyperbolicTotalVariation.sqrt((HyperbolicTotalVariation.square(d11) + HyperbolicTotalVariation.square(d12) + HyperbolicTotalVariation.square(d13) + HyperbolicTotalVariation.square(d14)) * d6 + d4);
                    d5 += d15;
                    if (!bl) continue;
                    double d16 = d * d6 / d15;
                    int n7 = n5;
                    dArray2[n7] = dArray2[n7] - (d11 + d13) * d16;
                    int n8 = n3;
                    dArray2[n8] = dArray2[n8] + (d11 - d14) * d16;
                    int n9 = n6;
                    dArray2[n9] = dArray2[n9] - (d12 - d13) * d16;
                    int n10 = n4;
                    dArray2[n10] = dArray2[n10] + (d12 + d14) * d16;
                }
            }
        } else {
            // General path: different weights along the two dimensions.
            for (int i = 1; i < n2; ++i) {
                int n11 = (i - 1) * n;
                int n12 = i * n;
                double d17 = dArray[n11];
                double d18 = dArray[n12];
                for (int j = 1; j < n; ++j) {
                    // After these lines:
                    // d19 = x[j-1,i-1] (index n13), d17 = x[j,i-1] (index n11),
                    // d20 = x[j-1,i] (index n14),   d18 = x[j,i]   (index n12).
                    int n13 = n11++;
                    double d19 = d17;
                    d17 = dArray[n11];
                    int n14 = n12++;
                    double d20 = d18;
                    d18 = dArray[n12];
                    // d21, d22: differences along dimension 0; d23, d24: along dimension 1.
                    double d21 = d17 - d19;
                    double d22 = d18 - d20;
                    double d23 = d20 - d19;
                    double d24 = d18 - d17;
                    double d25 = HyperbolicTotalVariation.sqrt((HyperbolicTotalVariation.square(d21) + HyperbolicTotalVariation.square(d22)) * d2 + (HyperbolicTotalVariation.square(d23) + HyperbolicTotalVariation.square(d24)) * d3 + d4);
                    d5 += d25;
                    if (!bl) continue;
                    // In-place rescaling (d27 for dimension 0, d28 for dimension 1); some of it
                    // happens inside the update expressions, so statement order matters.
                    double d26 = d / d25;
                    double d27 = d2 * d26;
                    d22 *= d27;
                    double d28 = d3 * d26;
                    int n15 = n13;
                    dArray2[n15] = dArray2[n15] - ((d21 *= d27) + (d23 *= d28));
                    int n16 = n11;
                    dArray2[n16] = dArray2[n16] + (d21 - (d24 *= d28));
                    int n17 = n14;
                    dArray2[n17] = dArray2[n17] - (d22 - d23);
                    int n18 = n12;
                    dArray2[n18] = dArray2[n18] + (d22 + d24);
                }
            }
        }
        // Subtract one epsilon per 2x2 cell; clamp negatives to 0.
        return (d5 -= (double)((n - 1) * (n2 - 1)) * this.epsilon) > 0.0 ? d * d5 : 0.0;
    }

    /**
     * 3-D kernel for float data: one term per 2x2x2 cell.
     *
     * @param d weight
     * @param fArray variables (dimension 0 varies fastest)
     * @param fArray2 gradient accumulator, or {@code null}
     * @return weighted cost, clamped to 0 when the unweighted sum is not positive
     */
    private final double computeFloat3D(double d, float[] fArray, float[] fArray2) {
        boolean bl = fArray2 != null;
        int n = this.shape.dimension(0);
        int n2 = this.shape.dimension(1);
        int n3 = this.shape.dimension(2);
        // Weights of the squared differences along dimensions 0, 1 and 2.
        float f = (float)(1.0 / (4.0 * HyperbolicTotalVariation.square(this.delta[0])));
        float f2 = (float)(1.0 / (4.0 * HyperbolicTotalVariation.square(this.delta[1])));
        float f3 = (float)(1.0 / (4.0 * HyperbolicTotalVariation.square(this.delta[2])));
        float f4 = (float)HyperbolicTotalVariation.square(this.epsilon);
        double d2 = 0.0;
        float f5 = (float)d;
        for (int i = 1; i < n3; ++i) {
            for (int j = 1; j < n2; ++j) {
                // Start offsets of the four lines (along dimension 0) at
                // (j-1,i-1), (j,i-1), (j-1,i) and (j,i).
                int n4 = (j - 1 + (i - 1) * n2) * n;
                int n5 = (j + (i - 1) * n2) * n;
                int n6 = (j - 1 + i * n2) * n;
                int n7 = (j + i * n2) * n;
                float f6 = fArray[n4];
                float f7 = fArray[n5];
                float f8 = fArray[n6];
                float f9 = fArray[n7];
                for (int k = 1; k < n; ++k) {
                    // After these lines (sample x[k0,k1,k2], index in parentheses):
                    // f10 = x[k-1,j-1,i-1] (n8), f6 = x[k,j-1,i-1] (n4),
                    // f11 = x[k-1,j,i-1]   (n9), f7 = x[k,j,i-1]   (n5),
                    // f12 = x[k-1,j-1,i]  (n10), f8 = x[k,j-1,i]   (n6),
                    // f13 = x[k-1,j,i]    (n11), f9 = x[k,j,i]     (n7).
                    int n8 = n4++;
                    float f10 = f6;
                    f6 = fArray[n4];
                    int n9 = n5++;
                    float f11 = f7;
                    f7 = fArray[n5];
                    int n10 = n6++;
                    float f12 = f8;
                    f8 = fArray[n6];
                    int n11 = n7++;
                    float f13 = f9;
                    f9 = fArray[n7];
                    // f14..f17: the 4 differences along dimension 0.
                    float f14 = f6 - f10;
                    float f15 = f7 - f11;
                    float f16 = f8 - f12;
                    float f17 = f9 - f13;
                    float f18 = HyperbolicTotalVariation.square(f14) + HyperbolicTotalVariation.square(f15) + HyperbolicTotalVariation.square(f16) + HyperbolicTotalVariation.square(f17);
                    // f19..f22: the 4 differences along dimension 1.
                    float f19 = f11 - f10;
                    float f20 = f7 - f6;
                    float f21 = f13 - f12;
                    float f22 = f9 - f8;
                    float f23 = HyperbolicTotalVariation.square(f19) + HyperbolicTotalVariation.square(f20) + HyperbolicTotalVariation.square(f21) + HyperbolicTotalVariation.square(f22);
                    // f24..f27: the 4 differences along dimension 2.
                    float f24 = f12 - f10;
                    float f25 = f8 - f6;
                    float f26 = f13 - f11;
                    float f27 = f9 - f7;
                    float f28 = HyperbolicTotalVariation.square(f24) + HyperbolicTotalVariation.square(f25) + HyperbolicTotalVariation.square(f26) + HyperbolicTotalVariation.square(f27);
                    float f29 = HyperbolicTotalVariation.sqrt(f * f18 + f2 * f23 + f3 * f28 + f4);
                    d2 += (double)f29;
                    if (!bl) continue;
                    // Rescale each difference in place by weight*w_k/r. f24 is scaled inside the
                    // first update expression, before its reuse in the n16 update.
                    float f30 = f5 / f29;
                    float f31 = f * f30;
                    f14 *= f31;
                    f15 *= f31;
                    f16 *= f31;
                    f17 *= f31;
                    float f32 = f2 * f30;
                    f19 *= f32;
                    f20 *= f32;
                    f21 *= f32;
                    f22 *= f32;
                    float f33 = f3 * f30;
                    f25 *= f33;
                    f26 *= f33;
                    f27 *= f33;
                    int n12 = n8;
                    fArray2[n12] = fArray2[n12] - (f14 + f19 + (f24 *= f33));
                    int n13 = n4;
                    fArray2[n13] = fArray2[n13] + (f14 - f20 - f25);
                    int n14 = n9;
                    fArray2[n14] = fArray2[n14] - (f15 - f19 + f26);
                    int n15 = n5;
                    fArray2[n15] = fArray2[n15] + (f15 + f20 - f27);
                    int n16 = n10;
                    fArray2[n16] = fArray2[n16] - (f16 + f21 - f24);
                    int n17 = n6;
                    fArray2[n17] = fArray2[n17] + (f16 - f22 + f25);
                    int n18 = n11;
                    fArray2[n18] = fArray2[n18] - (f17 - f21 - f26);
                    int n19 = n7;
                    fArray2[n19] = fArray2[n19] + (f17 + f22 + f27);
                }
            }
        }
        // Subtract one epsilon per 2x2x2 cell; clamp negatives to 0.
        return (d2 -= (double)((n - 1) * (n2 - 1) * (n3 - 1)) * this.epsilon) > 0.0 ? d * d2 : 0.0;
    }

    /**
     * 3-D kernel for double data; same formula and layout as
     * {@link #computeFloat3D}.
     *
     * @param d weight
     * @param dArray variables (dimension 0 varies fastest)
     * @param dArray2 gradient accumulator, or {@code null}
     * @return weighted cost, clamped to 0 when the unweighted sum is not positive
     */
    private final double computeDouble3D(double d, double[] dArray, double[] dArray2) {
        boolean bl = dArray2 != null;
        int n = this.shape.dimension(0);
        int n2 = this.shape.dimension(1);
        int n3 = this.shape.dimension(2);
        // Weights of the squared differences along dimensions 0, 1 and 2.
        double d2 = 1.0 / (4.0 * HyperbolicTotalVariation.square(this.delta[0]));
        double d3 = 1.0 / (4.0 * HyperbolicTotalVariation.square(this.delta[1]));
        double d4 = 1.0 / (4.0 * HyperbolicTotalVariation.square(this.delta[2]));
        double d5 = HyperbolicTotalVariation.square(this.epsilon);
        double d6 = 0.0;
        for (int i = 1; i < n3; ++i) {
            for (int j = 1; j < n2; ++j) {
                // Start offsets of the four lines (along dimension 0) at
                // (j-1,i-1), (j,i-1), (j-1,i) and (j,i).
                int n4 = (j - 1 + (i - 1) * n2) * n;
                int n5 = (j + (i - 1) * n2) * n;
                int n6 = (j - 1 + i * n2) * n;
                int n7 = (j + i * n2) * n;
                double d7 = dArray[n4];
                double d8 = dArray[n5];
                double d9 = dArray[n6];
                double d10 = dArray[n7];
                for (int k = 1; k < n; ++k) {
                    // After these lines (sample x[k0,k1,k2], index in parentheses):
                    // d11 = x[k-1,j-1,i-1] (n8), d7 = x[k,j-1,i-1] (n4),
                    // d12 = x[k-1,j,i-1]   (n9), d8 = x[k,j,i-1]   (n5),
                    // d13 = x[k-1,j-1,i]  (n10), d9 = x[k,j-1,i]   (n6),
                    // d14 = x[k-1,j,i]    (n11), d10 = x[k,j,i]    (n7).
                    int n8 = n4++;
                    double d11 = d7;
                    d7 = dArray[n4];
                    int n9 = n5++;
                    double d12 = d8;
                    d8 = dArray[n5];
                    int n10 = n6++;
                    double d13 = d9;
                    d9 = dArray[n6];
                    int n11 = n7++;
                    double d14 = d10;
                    d10 = dArray[n7];
                    // d15..d18: the 4 differences along dimension 0.
                    double d15 = d7 - d11;
                    double d16 = d8 - d12;
                    double d17 = d9 - d13;
                    double d18 = d10 - d14;
                    double d19 = HyperbolicTotalVariation.square(d15) + HyperbolicTotalVariation.square(d16) + HyperbolicTotalVariation.square(d17) + HyperbolicTotalVariation.square(d18);
                    // d20..d23: the 4 differences along dimension 1.
                    double d20 = d12 - d11;
                    double d21 = d8 - d7;
                    double d22 = d14 - d13;
                    double d23 = d10 - d9;
                    double d24 = HyperbolicTotalVariation.square(d20) + HyperbolicTotalVariation.square(d21) + HyperbolicTotalVariation.square(d22) + HyperbolicTotalVariation.square(d23);
                    // d25..d28: the 4 differences along dimension 2.
                    double d25 = d13 - d11;
                    double d26 = d9 - d7;
                    double d27 = d14 - d12;
                    double d28 = d10 - d8;
                    double d29 = HyperbolicTotalVariation.square(d25) + HyperbolicTotalVariation.square(d26) + HyperbolicTotalVariation.square(d27) + HyperbolicTotalVariation.square(d28);
                    double d30 = HyperbolicTotalVariation.sqrt(d2 * d19 + d3 * d24 + d4 * d29 + d5);
                    d6 += d30;
                    if (!bl) continue;
                    // Rescale each difference in place by weight*w_k/r. d25 is scaled inside the
                    // first update expression, before its reuse in the n16 update.
                    double d31 = d / d30;
                    double d32 = d2 * d31;
                    d15 *= d32;
                    d16 *= d32;
                    d17 *= d32;
                    d18 *= d32;
                    double d33 = d3 * d31;
                    d20 *= d33;
                    d21 *= d33;
                    d22 *= d33;
                    d23 *= d33;
                    double d34 = d4 * d31;
                    d26 *= d34;
                    d27 *= d34;
                    d28 *= d34;
                    int n12 = n8;
                    dArray2[n12] = dArray2[n12] - (d15 + d20 + (d25 *= d34));
                    int n13 = n4;
                    dArray2[n13] = dArray2[n13] + (d15 - d21 - d26);
                    int n14 = n9;
                    dArray2[n14] = dArray2[n14] - (d16 - d20 + d27);
                    int n15 = n5;
                    dArray2[n15] = dArray2[n15] + (d16 + d21 - d28);
                    int n16 = n10;
                    dArray2[n16] = dArray2[n16] - (d17 + d22 - d25);
                    int n17 = n6;
                    dArray2[n17] = dArray2[n17] + (d17 - d23 + d26);
                    int n18 = n11;
                    dArray2[n18] = dArray2[n18] - (d18 - d22 - d27);
                    int n19 = n7;
                    dArray2[n19] = dArray2[n19] + (d18 + d23 + d28);
                }
            }
        }
        // Subtract one epsilon per 2x2x2 cell; clamp negatives to 0.
        return (d6 -= (double)((n - 1) * (n2 - 1) * (n3 - 1)) * this.epsilon) > 0.0 ? d * d6 : 0.0;
    }

    /** Returns {@code true} if {@code d} is NaN or infinite. */
    private static final boolean notFinite(double d) {
        return Double.isInfinite(d) || Double.isNaN(d);
    }

    /** Float square root, computed in double and narrowed to float. */
    private static final float sqrt(float f) {
        return (float)Math.sqrt(f);
    }

    /** Double square root. */
    private static final double sqrt(double d) {
        return Math.sqrt(d);
    }

    /** Returns {@code f*f}. */
    private static final float square(float f) {
        return f * f;
    }

    /** Returns {@code d*d}. */
    private static final double square(double d) {
        return d * d;
    }

    /**
     * Computes the weighted cost only, without a gradient.
     *
     * @param d weight
     * @param vector the variables
     * @return the weighted cost
     */
    @Override
    public double evaluate(double d, Vector vector) {
        return this.computeCostAndGradient(d, vector, null, false);
    }
}

