/*
 * Decompiled with CFR 0.152.
 */
package mitiv.cost;

import mitiv.base.Shape;
import mitiv.cost.DifferentiableCostFunction;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.FloatShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;

/**
 * Smoothed ("hyperbolic") total-variation cost for 1-D, 2-D and 3-D shaped vectors
 * of {@code float} or {@code double} elements.
 *
 * <p>Arrays are indexed with the first dimension varying fastest: element
 * {@code (i1, i2, i3)} is stored at {@code i1 + dim1*(i2 + dim2*i3)}.
 *
 * <p>With threshold {@code ε} and per-axis scales {@code δ1, δ2, δ3}, the cost is:
 * <ul>
 *   <li>rank 1: {@code α [ Σ sqrt(((x[i]-x[i-1])/δ1)² + ε²) - (dim1-1) ε ]};</li>
 *   <li>rank 2: for every 2×2 cell with corners {@code x1=(i1-1,i2-1)}, {@code x2=(i1,i2-1)},
 *       {@code x3=(i1-1,i2)}, {@code x4=(i1,i2)}, the term
 *       {@code sqrt(((x2-x1)²+(x4-x3)²)/(2δ1²) + ((x3-x1)²+(x4-x2)²)/(2δ2²) + ε²)} is summed,
 *       then {@code (dim1-1)(dim2-1) ε} is subtracted and the result is multiplied by {@code α};</li>
 *   <li>rank 3: same construction over 2×2×2 cells, with the four edges along axis
 *       {@code k} weighted by {@code 1/(2δk²)}.</li>
 * </ul>
 * Each summed term is at least {@code ε}, so the bracketed cost is mathematically
 * non-negative; a negative total (from rounding) is clipped to zero.
 */
public class HyperbolicTotalVariation implements DifferentiableCostFunction {
    /**
     * Value of {@code inputSpace.getType()} handled as {@code float} data.
     * NOTE(review): presumably matches {@code mitiv.base.Traits.FLOAT}; confirm.
     */
    private static final int FLOAT_TYPE = 4;

    /**
     * Value of {@code inputSpace.getType()} handled as {@code double} data.
     * NOTE(review): presumably matches {@code mitiv.base.Traits.DOUBLE}; confirm.
     */
    private static final int DOUBLE_TYPE = 5;

    protected final ShapedVectorSpace inputSpace;
    protected int rank;
    protected int type;
    protected Shape shape;
    protected double epsilon;
    protected double[] delta;

    /**
     * Creates a total-variation cost with unit scale along every axis.
     *
     * @param inputSpace the space of the variables; its shape fixes the rank
     *                   (a {@code null} shape gives rank 0, which is rejected at evaluation time)
     * @param epsilon    the smoothing threshold, must be finite and strictly positive
     * @throws IllegalArgumentException if {@code epsilon} is not finite and strictly positive
     */
    public HyperbolicTotalVariation(ShapedVectorSpace inputSpace, double epsilon) {
        this.inputSpace = inputSpace;
        this.shape = inputSpace.getShape();
        this.rank = (this.shape == null) ? 0 : this.shape.rank();
        this.type = inputSpace.getType();
        setThreshold(epsilon);
        this.delta = new double[this.rank];
        defaultScale();
    }

    /**
     * Creates a total-variation cost with explicit per-axis scales.
     *
     * @param inputSpace the space of the variables
     * @param epsilon    the smoothing threshold, must be finite and strictly positive
     * @param delta      one finite, strictly positive scale per dimension
     * @throws IllegalArgumentException if {@code epsilon} or any scale is invalid,
     *                                  or {@code delta.length} differs from the rank
     */
    public HyperbolicTotalVariation(ShapedVectorSpace inputSpace, double epsilon, double[] delta) {
        this(inputSpace, epsilon);
        setScale(delta);
    }

    /**
     * Sets the smoothing threshold {@code ε}.
     *
     * @param epsilon the new threshold
     * @throws IllegalArgumentException if {@code epsilon} is NaN, infinite, zero or negative
     */
    public void setThreshold(double epsilon) {
        if (!isPositiveFinite(epsilon)) {
            throw new IllegalArgumentException("Bad threshold value.");
        }
        this.epsilon = epsilon;
    }

    /** Returns the current smoothing threshold {@code ε}. */
    public double getThreshold() {
        return this.epsilon;
    }

    /** Resets every per-axis scale to {@code 1.0}. */
    public void defaultScale() {
        setScale(1.0);
    }

    /**
     * Uses the same scale along every axis.
     *
     * @param value the common scale
     * @throws IllegalArgumentException if {@code value} is NaN, infinite, zero or negative
     */
    public void setScale(double value) {
        if (!isPositiveFinite(value)) {
            throw new IllegalArgumentException("Bad scale value.");
        }
        for (int k = 0; k < this.rank; ++k) {
            this.delta[k] = value;
        }
    }

    /**
     * Sets one scale per axis. All values are validated before any is stored, so a
     * rejected call leaves the current scales unchanged.
     *
     * @param delta the per-axis scales, one per dimension
     * @throws IllegalArgumentException if {@code delta} is {@code null}, has the wrong length,
     *                                  or contains a NaN, infinite, zero or negative value
     */
    public void setScale(double[] delta) {
        if (delta == null || delta.length != this.rank) {
            throw new IllegalArgumentException("Bad scale size.");
        }
        for (int k = 0; k < this.rank; ++k) {
            if (!isPositiveFinite(delta[k])) {
                throw new IllegalArgumentException("Bad scale value.");
            }
        }
        System.arraycopy(delta, 0, this.delta, 0, this.rank);
    }

    /**
     * Returns the scale along axis {@code k}, or {@code 1.0} when {@code k} is out of range.
     *
     * @param k the 0-based axis index
     * @return the scale along that axis, or {@code 1.0}
     */
    public double getScale(int k) {
        return (k >= 0 && k < this.rank) ? this.delta[k] : 1.0;
    }

    @Override
    public VectorSpace getInputSpace() {
        return this.inputSpace;
    }

    /**
     * Computes {@code α·f(x)} and, if {@code vgx} is non-null, adds {@code α·∇f(x)} to it.
     *
     * <p>The element type and rank are checked before the gradient is touched, so an
     * unsupported configuration throws without modifying {@code vgx}.
     *
     * @param alpha the multiplier applied to the cost and to the gradient
     * @param vx    the variables; must be a {@code FloatShapedVector} or {@code DoubleShapedVector}
     *              matching the input-space type
     * @param vgx   the gradient to update, or {@code null} to compute the cost only
     * @param clr   if {@code true}, {@code vgx} is zero-filled before accumulation
     * @return {@code α} times the (non-negative) cost
     * @throws IllegalArgumentException if the element type or the rank is unsupported
     */
    @Override
    public double computeCostAndGradient(double alpha, Vector vx, Vector vgx, boolean clr) {
        if (this.type != FLOAT_TYPE && this.type != DOUBLE_TYPE) {
            badType();
        }
        if (this.rank < 1 || this.rank > 3) {
            badRank();
        }
        if (vgx != null && clr) {
            vgx.fill(0.0);
        }
        if (this.type == FLOAT_TYPE) {
            float[] x = ((FloatShapedVector) vx).getData();
            float[] gx = (vgx == null) ? null : ((FloatShapedVector) vgx).getData();
            if (this.rank == 1) {
                return computeFloat1D(alpha, x, gx);
            }
            if (this.rank == 2) {
                return computeFloat2D(alpha, x, gx);
            }
            return computeFloat3D(alpha, x, gx); // rank == 3, validated above
        }
        double[] x = ((DoubleShapedVector) vx).getData();
        double[] gx = (vgx == null) ? null : ((DoubleShapedVector) vgx).getData();
        if (this.rank == 1) {
            return computeDouble1D(alpha, x, gx);
        }
        if (this.rank == 2) {
            return computeDouble2D(alpha, x, gx);
        }
        return computeDouble3D(alpha, x, gx); // rank == 3, validated above
    }

    /** Always throws: the rank is not 1, 2 or 3. */
    protected static void badRank() {
        throw new IllegalArgumentException("Unsupported number of dimensions for Total Variation.");
    }

    /** Always throws: the element type is neither float nor double. */
    protected static void badType() {
        throw new IllegalArgumentException("Unsupported data type for Total Variation.");
    }

    /** Returns {@code true} iff {@code v} is strictly positive and finite (NaN yields false). */
    private static boolean isPositiveFinite(double v) {
        return v > 0.0 && v < Double.POSITIVE_INFINITY;
    }

    /** Clips a rounding-induced negative cost to zero, then multiplies by {@code alpha}. */
    private static double scaledCost(double alpha, double fcost) {
        return alpha * (fcost < 0.0 ? 0.0 : fcost);
    }

    /** 1-D float kernel; {@code gx} may be {@code null} when no gradient is wanted. */
    private double computeFloat1D(double alpha, float[] x, float[] gx) {
        int dim1 = this.shape.dimension(0);
        // sqrt(d² + (ε δ)²) / δ == sqrt((d/δ)² + ε²): the division by δ is deferred to the end.
        float s = (float) square(this.epsilon * this.delta[0]);
        float beta = (float) (alpha / this.delta[0]);
        double fcost = 0.0;
        for (int i1 = 1; i1 < dim1; ++i1) {
            float d = x[i1] - x[i1 - 1];
            float r = sqrt(d * d + s);
            fcost += r;
            if (gx != null) {
                float p = beta * (d / r);
                gx[i1 - 1] -= p;
                gx[i1] += p;
            }
        }
        return scaledCost(alpha, fcost / this.delta[0] - (double) (dim1 - 1) * this.epsilon);
    }

    /** 1-D double kernel; {@code gx} may be {@code null} when no gradient is wanted. */
    private double computeDouble1D(double alpha, double[] x, double[] gx) {
        int dim1 = this.shape.dimension(0);
        // sqrt(d² + (ε δ)²) / δ == sqrt((d/δ)² + ε²): the division by δ is deferred to the end.
        double s = square(this.epsilon * this.delta[0]);
        double beta = alpha / this.delta[0];
        double fcost = 0.0;
        for (int i1 = 1; i1 < dim1; ++i1) {
            double d = x[i1] - x[i1 - 1];
            double r = sqrt(d * d + s);
            fcost += r;
            if (gx != null) {
                double p = beta * (d / r);
                gx[i1 - 1] -= p;
                gx[i1] += p;
            }
        }
        return scaledCost(alpha, fcost / this.delta[0] - (double) (dim1 - 1) * this.epsilon);
    }

    /** 2-D float kernel over 2×2 cells; {@code gx} may be {@code null}. */
    private double computeFloat2D(double alpha, float[] x, float[] gx) {
        int dim1 = this.shape.dimension(0);
        int dim2 = this.shape.dimension(1);
        float w1 = (float) (1.0 / (2.0 * square(this.delta[0])));
        float w2 = (float) (1.0 / (2.0 * square(this.delta[1])));
        float s = (float) square(this.epsilon);
        float a = (float) alpha;
        double fcost = 0.0;
        if (w1 == w2) {
            // Isotropic case: a single weight factors out of the sum of squares.
            float w = w1;
            for (int i2 = 1; i2 < dim2; ++i2) {
                // j2 walks row i2-1 and j4 walks row i2; both start at i1 = 0.
                int j2 = (i2 - 1) * dim1;
                int j4 = i2 * dim1;
                float x2 = x[j2];
                float x4 = x[j4];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    int j1 = j2++;
                    int j3 = j4++;
                    float x1 = x2;
                    float x3 = x4;
                    x2 = x[j2];
                    x4 = x[j4];
                    float y21 = x2 - x1;
                    float y43 = x4 - x3;
                    float y31 = x3 - x1;
                    float y42 = x4 - x2;
                    float r = sqrt((square(y21) + square(y43) + square(y31) + square(y42)) * w + s);
                    fcost += r;
                    if (gx != null) {
                        float p = a * w / r;
                        gx[j1] -= (y21 + y31) * p;
                        gx[j2] += (y21 - y42) * p;
                        gx[j3] -= (y43 - y31) * p;
                        gx[j4] += (y43 + y42) * p;
                    }
                }
            }
        } else {
            for (int i2 = 1; i2 < dim2; ++i2) {
                int j2 = (i2 - 1) * dim1;
                int j4 = i2 * dim1;
                float x2 = x[j2];
                float x4 = x[j4];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    int j1 = j2++;
                    int j3 = j4++;
                    float x1 = x2;
                    float x3 = x4;
                    x2 = x[j2];
                    x4 = x[j4];
                    float y21 = x2 - x1;
                    float y43 = x4 - x3;
                    float y31 = x3 - x1;
                    float y42 = x4 - x2;
                    float r = sqrt((square(y21) + square(y43)) * w1 + (square(y31) + square(y42)) * w2 + s);
                    fcost += r;
                    if (gx != null) {
                        float q = a / r;
                        float p1 = w1 * q;
                        float p2 = w2 * q;
                        // Scale differences in place: axis-1 ones by p1, axis-2 ones by p2.
                        y21 *= p1;
                        y43 *= p1;
                        y31 *= p2;
                        y42 *= p2;
                        gx[j1] -= y21 + y31;
                        gx[j2] += y21 - y42;
                        gx[j3] -= y43 - y31;
                        gx[j4] += y43 + y42;
                    }
                }
            }
        }
        double cells = (double) (dim1 - 1) * (double) (dim2 - 1);
        return scaledCost(alpha, fcost - cells * this.epsilon);
    }

    /** 2-D double kernel over 2×2 cells; {@code gx} may be {@code null}. */
    private double computeDouble2D(double alpha, double[] x, double[] gx) {
        int dim1 = this.shape.dimension(0);
        int dim2 = this.shape.dimension(1);
        double w1 = 1.0 / (2.0 * square(this.delta[0]));
        double w2 = 1.0 / (2.0 * square(this.delta[1]));
        double s = square(this.epsilon);
        double fcost = 0.0;
        if (w1 == w2) {
            // Isotropic case: a single weight factors out of the sum of squares.
            double w = w1;
            for (int i2 = 1; i2 < dim2; ++i2) {
                // j2 walks row i2-1 and j4 walks row i2; both start at i1 = 0.
                int j2 = (i2 - 1) * dim1;
                int j4 = i2 * dim1;
                double x2 = x[j2];
                double x4 = x[j4];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    int j1 = j2++;
                    int j3 = j4++;
                    double x1 = x2;
                    double x3 = x4;
                    x2 = x[j2];
                    x4 = x[j4];
                    double y21 = x2 - x1;
                    double y43 = x4 - x3;
                    double y31 = x3 - x1;
                    double y42 = x4 - x2;
                    double r = sqrt((square(y21) + square(y43) + square(y31) + square(y42)) * w + s);
                    fcost += r;
                    if (gx != null) {
                        double p = alpha * w / r;
                        gx[j1] -= (y21 + y31) * p;
                        gx[j2] += (y21 - y42) * p;
                        gx[j3] -= (y43 - y31) * p;
                        gx[j4] += (y43 + y42) * p;
                    }
                }
            }
        } else {
            for (int i2 = 1; i2 < dim2; ++i2) {
                int j2 = (i2 - 1) * dim1;
                int j4 = i2 * dim1;
                double x2 = x[j2];
                double x4 = x[j4];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    int j1 = j2++;
                    int j3 = j4++;
                    double x1 = x2;
                    double x3 = x4;
                    x2 = x[j2];
                    x4 = x[j4];
                    double y21 = x2 - x1;
                    double y43 = x4 - x3;
                    double y31 = x3 - x1;
                    double y42 = x4 - x2;
                    double r = sqrt((square(y21) + square(y43)) * w1 + (square(y31) + square(y42)) * w2 + s);
                    fcost += r;
                    if (gx != null) {
                        double q = alpha / r;
                        double p1 = w1 * q;
                        double p2 = w2 * q;
                        // Scale differences in place: axis-1 ones by p1, axis-2 ones by p2.
                        y21 *= p1;
                        y43 *= p1;
                        y31 *= p2;
                        y42 *= p2;
                        gx[j1] -= y21 + y31;
                        gx[j2] += y21 - y42;
                        gx[j3] -= y43 - y31;
                        gx[j4] += y43 + y42;
                    }
                }
            }
        }
        double cells = (double) (dim1 - 1) * (double) (dim2 - 1);
        return scaledCost(alpha, fcost - cells * this.epsilon);
    }

    /** 3-D float kernel over 2×2×2 cells; {@code gx} may be {@code null}. */
    private double computeFloat3D(double alpha, float[] x, float[] gx) {
        int dim1 = this.shape.dimension(0);
        int dim2 = this.shape.dimension(1);
        int dim3 = this.shape.dimension(2);
        float w1 = (float) (1.0 / (2.0 * square(this.delta[0])));
        float w2 = (float) (1.0 / (2.0 * square(this.delta[1])));
        float w3 = (float) (1.0 / (2.0 * square(this.delta[2])));
        float s = (float) square(this.epsilon);
        float a = (float) alpha;
        double fcost = 0.0;
        for (int i3 = 1; i3 < dim3; ++i3) {
            for (int i2 = 1; i2 < dim2; ++i2) {
                // Start offsets (at i1 = 0) of the four rows bounding the current cells.
                int j2 = (i2 - 1 + (i3 - 1) * dim2) * dim1;
                int j4 = (i2 + (i3 - 1) * dim2) * dim1;
                int j6 = (i2 - 1 + i3 * dim2) * dim1;
                int j8 = (i2 + i3 * dim2) * dim1;
                float x2 = x[j2];
                float x4 = x[j4];
                float x6 = x[j6];
                float x8 = x[j8];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    int j1 = j2++;
                    int j3 = j4++;
                    int j5 = j6++;
                    int j7 = j8++;
                    float x1 = x2;
                    float x3 = x4;
                    float x5 = x6;
                    float x7 = x8;
                    x2 = x[j2];
                    x4 = x[j4];
                    x6 = x[j6];
                    x8 = x[j8];
                    // Edges along axis 1.
                    float y21 = x2 - x1;
                    float y43 = x4 - x3;
                    float y65 = x6 - x5;
                    float y87 = x8 - x7;
                    // Edges along axis 2.
                    float y31 = x3 - x1;
                    float y42 = x4 - x2;
                    float y75 = x7 - x5;
                    float y86 = x8 - x6;
                    // Edges along axis 3.
                    float y51 = x5 - x1;
                    float y62 = x6 - x2;
                    float y73 = x7 - x3;
                    float y84 = x8 - x4;
                    float r1 = square(y21) + square(y43) + square(y65) + square(y87);
                    float r2 = square(y31) + square(y42) + square(y75) + square(y86);
                    float r3 = square(y51) + square(y62) + square(y73) + square(y84);
                    float r = sqrt(w1 * r1 + w2 * r2 + w3 * r3 + s);
                    fcost += r;
                    if (gx != null) {
                        float q = a / r;
                        float p1 = w1 * q;
                        float p2 = w2 * q;
                        float p3 = w3 * q;
                        y21 *= p1;
                        y43 *= p1;
                        y65 *= p1;
                        y87 *= p1;
                        y31 *= p2;
                        y42 *= p2;
                        y75 *= p2;
                        y86 *= p2;
                        y51 *= p3;
                        y62 *= p3;
                        y73 *= p3;
                        y84 *= p3;
                        gx[j1] -= y21 + y31 + y51;
                        gx[j2] += y21 - y42 - y62;
                        gx[j3] -= y43 - y31 + y73;
                        gx[j4] += y43 + y42 - y84;
                        gx[j5] -= y65 + y75 - y51;
                        gx[j6] += y65 - y86 + y62;
                        gx[j7] -= y87 - y75 - y73;
                        gx[j8] += y87 + y86 + y84;
                    }
                }
            }
        }
        double cells = (double) (dim1 - 1) * (double) (dim2 - 1) * (double) (dim3 - 1);
        return scaledCost(alpha, fcost - cells * this.epsilon);
    }

    /** 3-D double kernel over 2×2×2 cells; {@code gx} may be {@code null}. */
    private double computeDouble3D(double alpha, double[] x, double[] gx) {
        int dim1 = this.shape.dimension(0);
        int dim2 = this.shape.dimension(1);
        int dim3 = this.shape.dimension(2);
        double w1 = 1.0 / (2.0 * square(this.delta[0]));
        double w2 = 1.0 / (2.0 * square(this.delta[1]));
        double w3 = 1.0 / (2.0 * square(this.delta[2]));
        double s = square(this.epsilon);
        double fcost = 0.0;
        for (int i3 = 1; i3 < dim3; ++i3) {
            for (int i2 = 1; i2 < dim2; ++i2) {
                // Start offsets (at i1 = 0) of the four rows bounding the current cells.
                int j2 = (i2 - 1 + (i3 - 1) * dim2) * dim1;
                int j4 = (i2 + (i3 - 1) * dim2) * dim1;
                int j6 = (i2 - 1 + i3 * dim2) * dim1;
                int j8 = (i2 + i3 * dim2) * dim1;
                double x2 = x[j2];
                double x4 = x[j4];
                double x6 = x[j6];
                double x8 = x[j8];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    int j1 = j2++;
                    int j3 = j4++;
                    int j5 = j6++;
                    int j7 = j8++;
                    double x1 = x2;
                    double x3 = x4;
                    double x5 = x6;
                    double x7 = x8;
                    x2 = x[j2];
                    x4 = x[j4];
                    x6 = x[j6];
                    x8 = x[j8];
                    // Edges along axis 1.
                    double y21 = x2 - x1;
                    double y43 = x4 - x3;
                    double y65 = x6 - x5;
                    double y87 = x8 - x7;
                    // Edges along axis 2.
                    double y31 = x3 - x1;
                    double y42 = x4 - x2;
                    double y75 = x7 - x5;
                    double y86 = x8 - x6;
                    // Edges along axis 3.
                    double y51 = x5 - x1;
                    double y62 = x6 - x2;
                    double y73 = x7 - x3;
                    double y84 = x8 - x4;
                    double r1 = square(y21) + square(y43) + square(y65) + square(y87);
                    double r2 = square(y31) + square(y42) + square(y75) + square(y86);
                    double r3 = square(y51) + square(y62) + square(y73) + square(y84);
                    double r = sqrt(w1 * r1 + w2 * r2 + w3 * r3 + s);
                    fcost += r;
                    if (gx != null) {
                        double q = alpha / r;
                        double p1 = w1 * q;
                        double p2 = w2 * q;
                        double p3 = w3 * q;
                        y21 *= p1;
                        y43 *= p1;
                        y65 *= p1;
                        y87 *= p1;
                        y31 *= p2;
                        y42 *= p2;
                        y75 *= p2;
                        y86 *= p2;
                        y51 *= p3;
                        y62 *= p3;
                        y73 *= p3;
                        y84 *= p3;
                        gx[j1] -= y21 + y31 + y51;
                        gx[j2] += y21 - y42 - y62;
                        gx[j3] -= y43 - y31 + y73;
                        gx[j4] += y43 + y42 - y84;
                        gx[j5] -= y65 + y75 - y51;
                        gx[j6] += y65 - y86 + y62;
                        gx[j7] -= y87 - y75 - y73;
                        gx[j8] += y87 + y86 + y84;
                    }
                }
            }
        }
        double cells = (double) (dim1 - 1) * (double) (dim2 - 1) * (double) (dim3 - 1);
        return scaledCost(alpha, fcost - cells * this.epsilon);
    }

    /** Single-precision square root. */
    private static float sqrt(float a) {
        return (float) Math.sqrt(a);
    }

    /** Double-precision square root. */
    private static double sqrt(double a) {
        return Math.sqrt(a);
    }

    /** Returns {@code a*a} in single precision. */
    private static float square(float a) {
        return a * a;
    }

    /** Returns {@code a*a} in double precision. */
    private static double square(double a) {
        return a * a;
    }

    /** Returns {@code α·f(x)} without computing a gradient. */
    @Override
    public double evaluate(double alpha, Vector x) {
        return computeCostAndGradient(alpha, x, null, false);
    }
}

