/*
 * Decompiled with CFR 0.152.
 */
package mitiv.cost;

import mitiv.base.Shape;
import mitiv.cost.DifferentiableCostFunction;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.FloatShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;

/**
 * Differentiable cost function implementing a hyperbolic (smooth) approximation
 * of the total variation of a 1-D, 2-D or 3-D array of {@code float} or
 * {@code double} values.
 *
 * <p>Elements are addressed as {@code i1 + dim1*(i2 + dim2*i3)}, i.e. the first
 * dimension varies fastest.  The cost is built from finite differences between
 * neighbouring elements:
 * <ul>
 * <li>1-D: {@code sum_i sqrt(d_i^2 + (epsilon*scale[0])^2)/scale[0] - (dim1 - 1)*epsilon}
 *     where {@code d_i = x[i] - x[i-1]};</li>
 * <li>2-D: for every 2x2 cell, {@code sqrt(w1*(d1a^2 + d1b^2) + w2*(d2a^2 + d2b^2) + epsilon^2)}
 *     with {@code wk = 1/(2*scale[k]^2)}, minus {@code epsilon} per cell;</li>
 * <li>3-D: for every 2x2x2 cell, the same construction with four differences
 *     per direction and {@code wk = 1/(4*scale[k]^2)}, minus {@code epsilon} per cell.</li>
 * </ul>
 * The sum is clamped to zero from below and multiplied by the caller-supplied
 * factor {@code alpha}.
 *
 * <p>Instances are mutable and perform no synchronization.
 */
public class HyperbolicTotalVariation
implements DifferentiableCostFunction {
    /*
     * Element type codes as returned by ShapedVectorSpace.getType().  The
     * literal values come from the decompiled code; presumably they match the
     * library's type constants for float and double -- TODO confirm.
     */
    private static final int TYPE_FLOAT = 4;
    private static final int TYPE_DOUBLE = 5;

    /** Vector space of the variables; {@code null} until an input space is set. */
    protected ShapedVectorSpace inputSpace;
    /** Number of dimensions of {@link #shape} (0 when the shape is {@code null}). */
    protected int rank;
    /** Element type code of the input space. */
    protected int type;
    /** Shape of the input space. */
    protected Shape shape;
    /** Threshold of the hyperbolic approximation. */
    protected double epsilon;
    /** Scaling of the finite differences along each dimension ({@code rank} entries). */
    protected double[] scale;
    /**
     * Preferred default scaling remembered by the constructors that do not take
     * an input space, or {@code null} when the default scale is 1.
     */
    protected double[] delta = null;

    /**
     * Creates a cost function for a given input space with unit scaling.
     *
     * @param inputSpace the vector space of the variables (must not be {@code null})
     * @param epsilon    the threshold, finite and strictly positive
     * @throws IllegalArgumentException if {@code epsilon} is invalid
     */
    public HyperbolicTotalVariation(ShapedVectorSpace inputSpace, double epsilon) {
        this.inputSpace = inputSpace;
        this.shape = inputSpace.getShape();
        this.rank = this.shape == null ? 0 : this.shape.rank();
        this.type = inputSpace.getType();
        this.setThreshold(epsilon);
        this.scale = new double[this.rank];
        this.defaultScale();
    }

    /**
     * Creates a cost function for a given input space and scaling.
     *
     * @param inputSpace the vector space of the variables (must not be {@code null})
     * @param epsilon    the threshold, finite and strictly positive
     * @param delta      the scaling: either a single value applied to all
     *                   dimensions or one value per dimension
     * @throws IllegalArgumentException if {@code epsilon} or {@code delta} is invalid
     */
    public HyperbolicTotalVariation(ShapedVectorSpace inputSpace, double epsilon, double[] delta) {
        this(inputSpace, epsilon);
        this.setScale(delta);
    }

    /**
     * Creates a cost function whose input space is not yet known.  The given
     * scaling is remembered and applied when {@link #setInputSpace} is called.
     *
     * @param epsilon the threshold, finite and strictly positive
     * @param delta   the preferred scaling (finite, strictly positive values)
     * @throws IllegalArgumentException if {@code epsilon} or {@code delta} is invalid
     */
    public HyperbolicTotalVariation(double epsilon, double[] delta) {
        this.setThreshold(epsilon);
        this.setDelta(delta);
    }

    /**
     * Creates a cost function whose input space is not yet known, remembering
     * a preferred scaling.
     *
     * <p>NOTE(review): this constructor leaves the threshold at its field
     * default of 0.0, a value {@link #setThreshold} itself rejects; callers
     * should set a threshold before evaluating.
     *
     * @param delta the preferred scaling (finite, strictly positive values)
     * @throws IllegalArgumentException if {@code delta} is invalid
     */
    public HyperbolicTotalVariation(double[] delta) {
        this.setDelta(delta);
    }

    /**
     * Creates a cost function whose input space is not yet known.
     *
     * @param epsilon the threshold, finite and strictly positive
     * @throws IllegalArgumentException if {@code epsilon} is invalid
     */
    public HyperbolicTotalVariation(double epsilon) {
        this.setThreshold(epsilon);
    }

    /**
     * Sets the threshold of the hyperbolic approximation.
     *
     * @param epsilon the new threshold
     * @throws IllegalArgumentException if {@code epsilon} is NaN, infinite,
     *         zero or negative
     */
    public void setThreshold(double epsilon) {
        if (notFinite(epsilon) || epsilon <= 0.0) {
            throw new IllegalArgumentException("Bad threshold value");
        }
        this.epsilon = epsilon;
    }

    /**
     * Returns the threshold of the hyperbolic approximation.
     *
     * @return the current threshold
     */
    public double getThreshold() {
        return this.epsilon;
    }

    /**
     * Resets the scaling to its default: the preferred scaling remembered at
     * construction time if any, 1 along every dimension otherwise.
     *
     * @throws IllegalArgumentException if the remembered scaling does not fit
     *         the rank of the current input space
     */
    public void defaultScale() {
        final double[] preferred = this.delta;
        if (preferred == null) {
            this.setScale(1.0);
        } else {
            this.setScale(preferred);
            // setScale() forgets the preferred scaling (an explicit scale
            // normally overrides it); put it back so that the default survives
            // repeated resets and changes of input space.
            this.delta = preferred;
        }
    }

    /**
     * Remembers a copy of the preferred default scaling.
     *
     * @param value the preferred scaling
     * @throws IllegalArgumentException if {@code value} is {@code null} or
     *         contains a value which is NaN, infinite, zero or negative
     */
    private void setDelta(double[] value) {
        if (value == null) {
            throw new IllegalArgumentException("Bad delta value");
        }
        // Validate the argument itself: the delta field is still null here.
        for (int k = 0; k < value.length; ++k) {
            if (notFinite(value[k]) || value[k] <= 0.0) {
                throw new IllegalArgumentException("Bad delta value");
            }
        }
        this.delta = value.clone();
    }

    /**
     * Sets the same scaling along all dimensions and forgets any preferred
     * default scaling.
     *
     * @param value the scaling
     * @throws IllegalArgumentException if {@code value} is NaN, infinite,
     *         zero or negative
     */
    public void setScale(double value) {
        if (notFinite(value) || value <= 0.0) {
            throw new IllegalArgumentException("Bad scale value");
        }
        for (int k = 0; k < this.rank; ++k) {
            this.scale[k] = value;
        }
        this.delta = null;
    }

    /**
     * Sets the scaling along each dimension and forgets any preferred default
     * scaling.
     *
     * @param scale either a single value applied to all dimensions or exactly
     *              one value per dimension of the input space
     * @throws IllegalArgumentException if {@code scale} is {@code null}, has a
     *         wrong length, or contains a value which is NaN, infinite, zero
     *         or negative
     */
    public void setScale(double[] scale) {
        if (scale == null) {
            throw new IllegalArgumentException("Bad scale size");
        }
        if (scale.length == 1) {
            // A single value applies to every dimension.
            this.setScale(scale[0]);
        } else {
            if (scale.length != this.rank) {
                throw new IllegalArgumentException("Bad scale size");
            }
            // Validate everything first so that a bad entry leaves the
            // current scaling untouched.
            for (int k = 0; k < this.rank; ++k) {
                if (notFinite(scale[k]) || scale[k] <= 0.0) {
                    throw new IllegalArgumentException("Bad scale value");
                }
            }
            for (int k = 0; k < this.rank; ++k) {
                this.scale[k] = scale[k];
            }
        }
        this.delta = null;
    }

    /**
     * Returns the scaling along a given dimension.
     *
     * @param k the zero-based dimension index
     * @return the scaling along dimension {@code k}, or 1 if {@code k} is not
     *         a valid dimension index
     */
    public double getScale(int k) {
        return k >= 0 && k < this.rank ? this.scale[k] : 1.0;
    }

    @Override
    public VectorSpace getInputSpace() {
        return this.inputSpace;
    }

    /**
     * Sets the input space and resets the scaling to its default (see
     * {@link #defaultScale()}).
     *
     * @param inputSpace the vector space of the variables (must not be {@code null})
     * @throws IllegalArgumentException if a preferred scaling was given at
     *         construction time and does not fit the rank of {@code inputSpace}
     */
    public final void setInputSpace(ShapedVectorSpace inputSpace) {
        this.inputSpace = inputSpace;
        this.shape = inputSpace.getShape();
        this.rank = this.shape == null ? 0 : this.shape.rank();
        this.type = inputSpace.getType();
        this.scale = new double[this.rank];
        this.defaultScale();
    }

    /**
     * Computes the cost and, optionally, accumulates its gradient.
     *
     * <p>The gradient contributions are added to the contents of {@code vgx};
     * they are accumulated even when the returned cost is clamped to zero.
     *
     * @param alpha multiplier applied to the cost and to the gradient
     * @param vx    the variables; must be a {@link FloatShapedVector} or a
     *              {@link DoubleShapedVector} according to the type of the
     *              input space
     * @param vgx   the vector receiving the gradient, or {@code null} to
     *              compute the cost only
     * @param clr   whether to fill {@code vgx} with zeros before accumulating
     * @return the cost multiplied by {@code alpha}, or 0 if the un-multiplied
     *         cost is not strictly positive
     * @throws IllegalArgumentException if the element type or the number of
     *         dimensions of the input space is not supported
     */
    @Override
    public double computeCostAndGradient(double alpha, Vector vx, Vector vgx, boolean clr) {
        if (vgx != null && clr) {
            vgx.fill(0.0);
        }
        if (this.type == TYPE_FLOAT) {
            final float[] x = ((FloatShapedVector)vx).getData();
            final float[] gx = vgx == null ? null : ((FloatShapedVector)vgx).getData();
            switch (this.rank) {
            case 1:
                return this.computeFloat1D(alpha, x, gx);
            case 2:
                return this.computeFloat2D(alpha, x, gx);
            case 3:
                return this.computeFloat3D(alpha, x, gx);
            default:
                badRank();
            }
        } else if (this.type == TYPE_DOUBLE) {
            final double[] x = ((DoubleShapedVector)vx).getData();
            final double[] gx = vgx == null ? null : ((DoubleShapedVector)vgx).getData();
            switch (this.rank) {
            case 1:
                return this.computeDouble1D(alpha, x, gx);
            case 2:
                return this.computeDouble2D(alpha, x, gx);
            case 3:
                return this.computeDouble3D(alpha, x, gx);
            default:
                badRank();
            }
        } else {
            badType();
        }
        // Not reached: badRank() and badType() always throw.
        return 0.0;
    }

    /**
     * Reports an unsupported number of dimensions.
     *
     * @throws IllegalArgumentException always
     */
    protected static void badRank() {
        throw new IllegalArgumentException("Unsupported number of dimensions for Total Variation");
    }

    /**
     * Reports an unsupported element type.
     *
     * @throws IllegalArgumentException always
     */
    protected static void badType() {
        throw new IllegalArgumentException("Unsupported data type for Total Variation");
    }

    /**
     * 1-D cost (and gradient if {@code gx} is not {@code null}) for
     * single-precision data.
     */
    private final double computeFloat1D(double alpha, float[] x, float[] gx) {
        final boolean computeGradient = gx != null;
        final int dim1 = this.shape.dimension(0);
        final float s = (float)square(this.epsilon * this.scale[0]);
        final float beta = (float)(alpha / this.scale[0]);
        double fcost = 0.0;
        for (int i1 = 1; i1 < dim1; ++i1) {
            final float d = x[i1] - x[i1 - 1];
            final float r = sqrt(d * d + s);
            fcost += r;
            if (computeGradient) {
                final float p = beta * (d / r);
                gx[i1 - 1] -= p;
                gx[i1] += p;
            }
        }
        // Undo the scaling of the differences and subtract the value the sum
        // takes for a flat signal, so that the minimum cost is zero.
        fcost = fcost / this.scale[0] - (double)(dim1 - 1) * this.epsilon;
        return scaledCost(alpha, fcost);
    }

    /**
     * 1-D cost (and gradient if {@code gx} is not {@code null}) for
     * double-precision data.
     */
    private final double computeDouble1D(double alpha, double[] x, double[] gx) {
        final boolean computeGradient = gx != null;
        final int dim1 = this.shape.dimension(0);
        final double s = square(this.epsilon * this.scale[0]);
        final double beta = alpha / this.scale[0];
        double fcost = 0.0;
        for (int i1 = 1; i1 < dim1; ++i1) {
            final double d = x[i1] - x[i1 - 1];
            final double r = sqrt(d * d + s);
            fcost += r;
            if (computeGradient) {
                final double p = beta * (d / r);
                gx[i1 - 1] -= p;
                gx[i1] += p;
            }
        }
        // Undo the scaling of the differences and subtract the value the sum
        // takes for a flat signal, so that the minimum cost is zero.
        fcost = fcost / this.scale[0] - (double)(dim1 - 1) * this.epsilon;
        return scaledCost(alpha, fcost);
    }

    /**
     * 2-D cost (and gradient if {@code gx} is not {@code null}) for
     * single-precision data.
     *
     * <p>Each 2x2 cell has corners x1 = (i1-1,i2-1), x2 = (i1,i2-1),
     * x3 = (i1-1,i2) and x4 = (i1,i2); {@code yAB} denotes {@code xA - xB}.
     */
    private final double computeFloat2D(double alpha, float[] x, float[] gx) {
        final boolean computeGradient = gx != null;
        final int dim1 = this.shape.dimension(0);
        final int dim2 = this.shape.dimension(1);
        final float w1 = (float)(1.0 / (2.0 * square(this.scale[0])));
        final float w2 = (float)(1.0 / (2.0 * square(this.scale[1])));
        final float s = (float)square(this.epsilon);
        final float _alpha = (float)alpha;
        double fcost = 0.0;
        if (w1 == w2) {
            // Same weight along both dimensions: factor it out.
            final float w = w1;
            for (int i2 = 1; i2 < dim2; ++i2) {
                int j2 = (i2 - 1) * dim1;
                int j4 = i2 * dim1;
                float x2 = x[j2];
                float x4 = x[j4];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    // Slide the cell by one element along the first
                    // dimension, re-using the values already loaded.
                    final int j1 = j2++;
                    final int j3 = j4++;
                    final float x1 = x2;
                    x2 = x[j2];
                    final float x3 = x4;
                    x4 = x[j4];
                    final float y21 = x2 - x1;
                    final float y43 = x4 - x3;
                    final float y31 = x3 - x1;
                    final float y42 = x4 - x2;
                    final float r = sqrt((square(y21) + square(y43) + square(y31) + square(y42)) * w + s);
                    fcost += r;
                    if (computeGradient) {
                        final float p = _alpha * w / r;
                        gx[j1] -= (y21 + y31) * p;
                        gx[j2] += (y21 - y42) * p;
                        gx[j3] -= (y43 - y31) * p;
                        gx[j4] += (y43 + y42) * p;
                    }
                }
            }
        } else {
            for (int i2 = 1; i2 < dim2; ++i2) {
                int j2 = (i2 - 1) * dim1;
                int j4 = i2 * dim1;
                float x2 = x[j2];
                float x4 = x[j4];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    final int j1 = j2++;
                    final float x1 = x2;
                    x2 = x[j2];
                    final int j3 = j4++;
                    final float x3 = x4;
                    x4 = x[j4];
                    float y21 = x2 - x1;
                    float y43 = x4 - x3;
                    float y31 = x3 - x1;
                    float y42 = x4 - x2;
                    final float r = sqrt((square(y21) + square(y43)) * w1 + (square(y31) + square(y42)) * w2 + s);
                    fcost += r;
                    if (computeGradient) {
                        final float q = _alpha / r;
                        final float p1 = w1 * q;
                        final float p2 = w2 * q;
                        // Turn the differences into weighted gradient terms.
                        y21 *= p1;
                        y43 *= p1;
                        y31 *= p2;
                        y42 *= p2;
                        gx[j1] -= y21 + y31;
                        gx[j2] += y21 - y42;
                        gx[j3] -= y43 - y31;
                        gx[j4] += y43 + y42;
                    }
                }
            }
        }
        fcost -= (double)((dim1 - 1) * (dim2 - 1)) * this.epsilon;
        return scaledCost(alpha, fcost);
    }

    /**
     * 2-D cost (and gradient if {@code gx} is not {@code null}) for
     * double-precision data.
     *
     * <p>Each 2x2 cell has corners x1 = (i1-1,i2-1), x2 = (i1,i2-1),
     * x3 = (i1-1,i2) and x4 = (i1,i2); {@code yAB} denotes {@code xA - xB}.
     */
    private final double computeDouble2D(double alpha, double[] x, double[] gx) {
        final boolean computeGradient = gx != null;
        final int dim1 = this.shape.dimension(0);
        final int dim2 = this.shape.dimension(1);
        final double w1 = 1.0 / (2.0 * square(this.scale[0]));
        final double w2 = 1.0 / (2.0 * square(this.scale[1]));
        final double s = square(this.epsilon);
        double fcost = 0.0;
        if (w1 == w2) {
            // Same weight along both dimensions: factor it out.
            final double w = w1;
            for (int i2 = 1; i2 < dim2; ++i2) {
                int j2 = (i2 - 1) * dim1;
                int j4 = i2 * dim1;
                double x2 = x[j2];
                double x4 = x[j4];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    // Slide the cell by one element along the first
                    // dimension, re-using the values already loaded.
                    final int j1 = j2++;
                    final int j3 = j4++;
                    final double x1 = x2;
                    x2 = x[j2];
                    final double x3 = x4;
                    x4 = x[j4];
                    final double y21 = x2 - x1;
                    final double y43 = x4 - x3;
                    final double y31 = x3 - x1;
                    final double y42 = x4 - x2;
                    final double r = sqrt((square(y21) + square(y43) + square(y31) + square(y42)) * w + s);
                    fcost += r;
                    if (computeGradient) {
                        final double p = alpha * w / r;
                        gx[j1] -= (y21 + y31) * p;
                        gx[j2] += (y21 - y42) * p;
                        gx[j3] -= (y43 - y31) * p;
                        gx[j4] += (y43 + y42) * p;
                    }
                }
            }
        } else {
            for (int i2 = 1; i2 < dim2; ++i2) {
                int j2 = (i2 - 1) * dim1;
                int j4 = i2 * dim1;
                double x2 = x[j2];
                double x4 = x[j4];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    final int j1 = j2++;
                    final double x1 = x2;
                    x2 = x[j2];
                    final int j3 = j4++;
                    final double x3 = x4;
                    x4 = x[j4];
                    double y21 = x2 - x1;
                    double y43 = x4 - x3;
                    double y31 = x3 - x1;
                    double y42 = x4 - x2;
                    final double r = sqrt((square(y21) + square(y43)) * w1 + (square(y31) + square(y42)) * w2 + s);
                    fcost += r;
                    if (computeGradient) {
                        final double q = alpha / r;
                        final double p1 = w1 * q;
                        final double p2 = w2 * q;
                        // Turn the differences into weighted gradient terms.
                        y21 *= p1;
                        y43 *= p1;
                        y31 *= p2;
                        y42 *= p2;
                        gx[j1] -= y21 + y31;
                        gx[j2] += y21 - y42;
                        gx[j3] -= y43 - y31;
                        gx[j4] += y43 + y42;
                    }
                }
            }
        }
        fcost -= (double)((dim1 - 1) * (dim2 - 1)) * this.epsilon;
        return scaledCost(alpha, fcost);
    }

    /**
     * 3-D cost (and gradient if {@code gx} is not {@code null}) for
     * single-precision data.
     *
     * <p>Each 2x2x2 cell has corners x1..x4 in plane {@code i3-1} and x5..x8
     * in plane {@code i3}, numbered within a plane as in the 2-D case;
     * {@code yAB} denotes {@code xA - xB}.
     */
    private final double computeFloat3D(double alpha, float[] x, float[] gx) {
        final boolean computeGradient = gx != null;
        final int dim1 = this.shape.dimension(0);
        final int dim2 = this.shape.dimension(1);
        final int dim3 = this.shape.dimension(2);
        final float w1 = (float)(1.0 / (4.0 * square(this.scale[0])));
        final float w2 = (float)(1.0 / (4.0 * square(this.scale[1])));
        final float w3 = (float)(1.0 / (4.0 * square(this.scale[2])));
        final float s = (float)square(this.epsilon);
        final float _alpha = (float)alpha;
        double fcost = 0.0;
        for (int i3 = 1; i3 < dim3; ++i3) {
            for (int i2 = 1; i2 < dim2; ++i2) {
                int j2 = (i2 - 1 + (i3 - 1) * dim2) * dim1;
                int j4 = (i2 + (i3 - 1) * dim2) * dim1;
                int j6 = (i2 - 1 + i3 * dim2) * dim1;
                int j8 = (i2 + i3 * dim2) * dim1;
                float x2 = x[j2];
                float x4 = x[j4];
                float x6 = x[j6];
                float x8 = x[j8];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    // Slide the cell by one element along the first
                    // dimension, re-using the values already loaded.
                    final int j1 = j2++;
                    final float x1 = x2;
                    x2 = x[j2];
                    final int j3 = j4++;
                    final float x3 = x4;
                    x4 = x[j4];
                    final int j5 = j6++;
                    final float x5 = x6;
                    x6 = x[j6];
                    final int j7 = j8++;
                    final float x7 = x8;
                    x8 = x[j8];
                    // Differences along the first dimension.
                    float y21 = x2 - x1;
                    float y43 = x4 - x3;
                    float y65 = x6 - x5;
                    float y87 = x8 - x7;
                    final float r1 = square(y21) + square(y43) + square(y65) + square(y87);
                    // Differences along the second dimension.
                    float y31 = x3 - x1;
                    float y42 = x4 - x2;
                    float y75 = x7 - x5;
                    float y86 = x8 - x6;
                    final float r2 = square(y31) + square(y42) + square(y75) + square(y86);
                    // Differences along the third dimension.
                    float y51 = x5 - x1;
                    float y62 = x6 - x2;
                    float y73 = x7 - x3;
                    float y84 = x8 - x4;
                    final float r3 = square(y51) + square(y62) + square(y73) + square(y84);
                    final float r = sqrt(w1 * r1 + w2 * r2 + w3 * r3 + s);
                    fcost += r;
                    if (computeGradient) {
                        final float q = _alpha / r;
                        final float p1 = w1 * q;
                        final float p2 = w2 * q;
                        final float p3 = w3 * q;
                        // Turn the differences into weighted gradient terms.
                        y21 *= p1;
                        y43 *= p1;
                        y65 *= p1;
                        y87 *= p1;
                        y31 *= p2;
                        y42 *= p2;
                        y75 *= p2;
                        y86 *= p2;
                        y51 *= p3;
                        y62 *= p3;
                        y73 *= p3;
                        y84 *= p3;
                        gx[j1] -= y21 + y31 + y51;
                        gx[j2] += y21 - y42 - y62;
                        gx[j3] -= y43 - y31 + y73;
                        gx[j4] += y43 + y42 - y84;
                        gx[j5] -= y65 + y75 - y51;
                        gx[j6] += y65 - y86 + y62;
                        gx[j7] -= y87 - y75 - y73;
                        gx[j8] += y87 + y86 + y84;
                    }
                }
            }
        }
        fcost -= (double)((dim1 - 1) * (dim2 - 1) * (dim3 - 1)) * this.epsilon;
        return scaledCost(alpha, fcost);
    }

    /**
     * 3-D cost (and gradient if {@code gx} is not {@code null}) for
     * double-precision data.
     *
     * <p>Each 2x2x2 cell has corners x1..x4 in plane {@code i3-1} and x5..x8
     * in plane {@code i3}, numbered within a plane as in the 2-D case;
     * {@code yAB} denotes {@code xA - xB}.
     */
    private final double computeDouble3D(double alpha, double[] x, double[] gx) {
        final boolean computeGradient = gx != null;
        final int dim1 = this.shape.dimension(0);
        final int dim2 = this.shape.dimension(1);
        final int dim3 = this.shape.dimension(2);
        final double w1 = 1.0 / (4.0 * square(this.scale[0]));
        final double w2 = 1.0 / (4.0 * square(this.scale[1]));
        final double w3 = 1.0 / (4.0 * square(this.scale[2]));
        final double s = square(this.epsilon);
        double fcost = 0.0;
        for (int i3 = 1; i3 < dim3; ++i3) {
            for (int i2 = 1; i2 < dim2; ++i2) {
                int j2 = (i2 - 1 + (i3 - 1) * dim2) * dim1;
                int j4 = (i2 + (i3 - 1) * dim2) * dim1;
                int j6 = (i2 - 1 + i3 * dim2) * dim1;
                int j8 = (i2 + i3 * dim2) * dim1;
                double x2 = x[j2];
                double x4 = x[j4];
                double x6 = x[j6];
                double x8 = x[j8];
                for (int i1 = 1; i1 < dim1; ++i1) {
                    // Slide the cell by one element along the first
                    // dimension, re-using the values already loaded.
                    final int j1 = j2++;
                    final double x1 = x2;
                    x2 = x[j2];
                    final int j3 = j4++;
                    final double x3 = x4;
                    x4 = x[j4];
                    final int j5 = j6++;
                    final double x5 = x6;
                    x6 = x[j6];
                    final int j7 = j8++;
                    final double x7 = x8;
                    x8 = x[j8];
                    // Differences along the first dimension.
                    double y21 = x2 - x1;
                    double y43 = x4 - x3;
                    double y65 = x6 - x5;
                    double y87 = x8 - x7;
                    final double r1 = square(y21) + square(y43) + square(y65) + square(y87);
                    // Differences along the second dimension.
                    double y31 = x3 - x1;
                    double y42 = x4 - x2;
                    double y75 = x7 - x5;
                    double y86 = x8 - x6;
                    final double r2 = square(y31) + square(y42) + square(y75) + square(y86);
                    // Differences along the third dimension.
                    double y51 = x5 - x1;
                    double y62 = x6 - x2;
                    double y73 = x7 - x3;
                    double y84 = x8 - x4;
                    final double r3 = square(y51) + square(y62) + square(y73) + square(y84);
                    final double r = sqrt(w1 * r1 + w2 * r2 + w3 * r3 + s);
                    fcost += r;
                    if (computeGradient) {
                        final double q = alpha / r;
                        final double p1 = w1 * q;
                        final double p2 = w2 * q;
                        final double p3 = w3 * q;
                        // Turn the differences into weighted gradient terms.
                        y21 *= p1;
                        y43 *= p1;
                        y65 *= p1;
                        y87 *= p1;
                        y31 *= p2;
                        y42 *= p2;
                        y75 *= p2;
                        y86 *= p2;
                        y51 *= p3;
                        y62 *= p3;
                        y73 *= p3;
                        y84 *= p3;
                        gx[j1] -= y21 + y31 + y51;
                        gx[j2] += y21 - y42 - y62;
                        gx[j3] -= y43 - y31 + y73;
                        gx[j4] += y43 + y42 - y84;
                        gx[j5] -= y65 + y75 - y51;
                        gx[j6] += y65 - y86 + y62;
                        gx[j7] -= y87 - y75 - y73;
                        gx[j8] += y87 + y86 + y84;
                    }
                }
            }
        }
        fcost -= (double)((dim1 - 1) * (dim2 - 1) * (dim3 - 1)) * this.epsilon;
        return scaledCost(alpha, fcost);
    }

    /**
     * Applies the multiplier to a cost, clamping non-positive (or NaN) costs
     * to zero.
     */
    private static double scaledCost(double alpha, double cost) {
        return cost > 0.0 ? alpha * cost : 0.0;
    }

    /** Tells whether a value is infinite or NaN. */
    private static final boolean notFinite(double val) {
        return Double.isInfinite(val) || Double.isNaN(val);
    }

    /** Single-precision square root. */
    private static final float sqrt(float a) {
        return (float)Math.sqrt(a);
    }

    /** Double-precision square root. */
    private static final double sqrt(double a) {
        return Math.sqrt(a);
    }

    /** Single-precision square. */
    private static final float square(float a) {
        return a * a;
    }

    /** Double-precision square. */
    private static final double square(double a) {
        return a * a;
    }

    /**
     * Computes the cost without its gradient.
     *
     * @param alpha multiplier applied to the cost
     * @param x     the variables
     * @return the same value as {@link #computeCostAndGradient} called with a
     *         {@code null} gradient
     */
    @Override
    public double evaluate(double alpha, Vector x) {
        return this.computeCostAndGradient(alpha, x, null, false);
    }
}

