/*
 * Decompiled with CFR 0.152.
 */
package mitiv.cost;

import mitiv.cost.DifferentiableCostFunction;
import mitiv.exception.IncorrectSpaceException;
import mitiv.linalg.LinearOperator;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;

/**
 * Quadratic cost function of the form
 * <pre>
 *     f(x) = alpha (H.x - y)'.W.(H.x - y)
 * </pre>
 * where {@code H} is a linear operator, {@code y} a vector and {@code W} an
 * endomorphism of the output space of {@code H}.
 * <p>
 * Any of {@code H}, {@code y} or {@code W} may be {@code null}. A null
 * {@code H} or {@code W} acts as the identity and a null {@code y} as zero;
 * the corresponding computations are skipped. The cost is evaluated as
 * {@code alpha * r.dot(W.r)} with {@code r = H.x - y}. The gradient is
 * computed as {@code 2 alpha H'.W.r}, which is the true gradient only when
 * {@code W} is self-adjoint; this is not checked here.
 * <p>
 * Intermediate vectors (residuals, weighted residuals and their
 * back-projection by {@code H'}) are kept in mutable fields and reused
 * between calls. A single instance must therefore not be evaluated
 * concurrently from several threads.
 */
public class QuadraticCost implements DifferentiableCostFunction {
    /** Space of the variables {@code x}. */
    protected VectorSpace inputSpace = null;
    /** Linear operator applied to the variables; {@code null} means identity. */
    protected LinearOperator H = null;
    /** Weighting operator (an endomorphism); {@code null} means identity. */
    protected LinearOperator W = null;
    /** Vector subtracted from {@code H.x}; {@code null} means zero. */
    protected Vector y = null;
    /** Workspace holding the residuals {@code r = H.x - y}. */
    protected Vector r = null;
    /** Workspace holding the weighted residuals {@code W.r}. */
    protected Vector Wr = null;
    /** Workspace holding {@code H'.W.r}. */
    protected Vector HtWr = null;
    /** True when the residuals are {@code x} itself ({@code H} and {@code y} both null). */
    protected boolean quickResiduals;
    /** True when {@code W} is null, so the weighted residuals are {@code r} itself. */
    protected boolean quickWeightedResiduals;
    /** True when {@code H} is null, so {@code H'.W.r} is {@code W.r} itself. */
    protected boolean quickQuasiGradient;
    /** True when the {@code H'.W.r} workspace may reuse the memory of {@code r}. */
    protected boolean shareMemory;

    /**
     * Creates the cost {@code alpha (H.x - y)'.W.(H.x - y)}.
     *
     * @param H the linear operator, may be {@code null} (identity)
     * @param y the vector subtracted from {@code H.x}, may be {@code null} (zero)
     * @param W the weighting endomorphism, may be {@code null} (identity)
     * @throws IncorrectSpaceException if the components have incompatible spaces
     * @throws IllegalArgumentException if {@code H}, {@code y} and {@code W} are all null
     */
    public QuadraticCost(LinearOperator H, Vector y, LinearOperator W) {
        this.setComponents(H, y, W);
    }

    /**
     * Creates the cost {@code alpha x'.x} on a given space, that is a
     * quadratic cost with {@code H}, {@code y} and {@code W} all omitted.
     *
     * @param inputSpace the space of the variables
     */
    public QuadraticCost(VectorSpace inputSpace) {
        this.inputSpace = inputSpace;
        // H, y and W are all null here, so use the same shortcuts that
        // setComponents() would select for that case. Otherwise the residual
        // code would dereference the null y and W.
        this.quickResiduals = true;
        this.quickWeightedResiduals = true;
        this.quickQuasiGradient = true;
        this.shareMemory = false;
    }

    /**
     * Creates the cost {@code alpha (H.x - y)'.(H.x - y)} (no weights).
     *
     * @param H the linear operator, may be {@code null} (identity)
     * @param y the vector subtracted from {@code H.x}, may be {@code null} (zero)
     */
    public QuadraticCost(LinearOperator H, Vector y) {
        this(H, y, null);
    }

    /**
     * Creates the cost {@code alpha (H.x)'.(H.x)}.
     *
     * @param H the linear operator
     */
    public QuadraticCost(LinearOperator H) {
        this(H, null, null);
    }

    /**
     * Sets (or replaces) the components of the quadratic cost.
     * <p>
     * The input space is taken from {@code H} when it is given. Otherwise it
     * is the space of {@code y}, or the space of {@code W}. Cached workspaces
     * that no longer fit the new spaces are dropped.
     *
     * @param H the linear operator, may be {@code null} (identity)
     * @param y the vector subtracted from {@code H.x}, may be {@code null} (zero)
     * @param W the weighting endomorphism, may be {@code null} (identity)
     * @throws IncorrectSpaceException if {@code y} does not belong to the output
     *         space of {@code H}, if {@code W} is not an endomorphism, or if the
     *         space of {@code W} is not the output space of {@code H} (or the
     *         space of {@code y})
     * @throws IllegalArgumentException if {@code H}, {@code y} and {@code W} are all null
     */
    public void setComponents(LinearOperator H, Vector y, LinearOperator W) {
        VectorSpace inSpace;
        VectorSpace outSpace;
        if (H != null) {
            inSpace = H.getInputSpace();
            outSpace = H.getOutputSpace();
        } else {
            inSpace = null;
            outSpace = null;
        }
        if (y != null) {
            if (outSpace == null) {
                // No H: the cost is defined on the space of y.
                inSpace = outSpace = y.getSpace();
            } else if (!y.belongsTo(outSpace)) {
                throw new IncorrectSpaceException("y must belong to the output space of H");
            }
        }
        if (W != null) {
            if (!W.isEndomorphism()) {
                throw new IncorrectSpaceException("W must be an endomorphism");
            }
            if (outSpace == null) {
                // Neither H nor y: the cost is defined on the space of W.
                inSpace = outSpace = W.getInputSpace();
            } else if (W.getInputSpace() != outSpace) {
                throw new IncorrectSpaceException("incompatible vector space for operator W");
            }
        }
        if (inSpace == null) {
            throw new IllegalArgumentException("one of H, y, or W must be non null");
        }

        this.H = H;
        this.y = y;
        this.W = W;
        this.inputSpace = inSpace;
        this.quickResiduals = (H == null && y == null);
        this.quickWeightedResiduals = (W == null);
        this.quickQuasiGradient = (H == null);
        // r is no longer needed once W.r has been formed, so H'.W.r can be
        // written into r when both live in the same space.
        // NOTE(review): when W is null, Wr is r itself, so with shared memory
        // H' is applied in place (destination == source). This relies on
        // LinearOperator.apply supporting that. TODO confirm.
        this.shareMemory = (!this.quickResiduals && inSpace == outSpace);

        // Drop cached workspaces that are no longer needed or no longer fit.
        if (this.r != null && (this.quickResiduals || !this.r.belongsTo(outSpace))) {
            this.r = null;
        }
        if (this.Wr != null && (this.quickWeightedResiduals || !this.Wr.belongsTo(outSpace))) {
            this.Wr = null;
        }
        if (this.shareMemory) {
            this.HtWr = this.r;
        } else if (this.HtWr != null && (this.quickQuasiGradient || !this.HtWr.belongsTo(inSpace))) {
            this.HtWr = null;
        }
    }

    /**
     * Returns the space of the variables of the cost.
     *
     * @return the input space (as set by the constructor or {@link #setComponents})
     */
    @Override
    public VectorSpace getInputSpace() {
        return this.inputSpace;
    }

    /**
     * Computes {@code alpha (H.x - y)'.W.(H.x - y)}.
     *
     * @param alpha the multiplier of the cost; when it is zero, nothing is
     *        computed and zero is returned
     * @param x the variables (expected to belong to the input space; not
     *        checked here)
     * @return the value of the cost
     */
    @Override
    public double evaluate(double alpha, Vector x) {
        if (alpha == 0.0) {
            return 0.0;
        }
        try {
            this.formResiduals(x);
            return alpha * this.r.dot(this.Wr);
        } finally {
            // Do not keep references to x (or aliases of it) between calls,
            // even if a step above threw.
            this.releaseTemporaries();
        }
    }

    /**
     * Computes the cost and stores or adds its gradient.
     * <p>
     * The gradient term is {@code 2 alpha H'.W.(H.x - y)}. This is the true
     * gradient only if {@code W} is self-adjoint.
     *
     * @param alpha the multiplier of the cost; when it is zero, the cost is zero
     *        and the gradient term is zero
     * @param x the variables (expected to belong to the input space; not
     *        checked here)
     * @param gx the gradient vector to update
     * @param clr if true, {@code gx} is overwritten by the gradient term;
     *        otherwise the gradient term is added to {@code gx}
     * @return the value of the cost
     */
    @Override
    public double computeCostAndGradient(double alpha, Vector x, Vector gx, boolean clr) {
        if (alpha == 0.0) {
            if (clr) {
                gx.zero();
            }
            return 0.0;
        }
        try {
            this.formResiduals(x);
            // The cost must be computed before applying H': with shared
            // memory, H'.W.r overwrites the residuals r.
            double cost = this.r.dot(this.Wr);
            if (this.quickQuasiGradient) {
                this.HtWr = this.Wr;
            } else {
                if (this.HtWr == null || !this.HtWr.belongsTo(this.inputSpace)) {
                    this.HtWr = (this.shareMemory ? this.r : this.inputSpace.create());
                }
                this.H.apply(this.HtWr, this.Wr, LinearOperator.ADJOINT);
            }
            gx.combine((clr ? 0.0 : 1.0), gx, 2.0 * alpha, this.HtWr);
            return alpha * cost;
        } finally {
            this.releaseTemporaries();
        }
    }

    /**
     * Forms {@code r = H.x - y} and {@code Wr = W.r} in the cached workspaces.
     * Workspaces are allocated on first use or when they do not fit. When a
     * component is omitted, the corresponding vector is an alias (of
     * {@code x} or of {@code r}) rather than a copy.
     */
    private void formResiduals(Vector x) {
        VectorSpace space = (this.H == null ? this.inputSpace : this.H.getOutputSpace());
        if (this.quickResiduals) {
            this.r = x;
        } else {
            if (this.r == null || !this.r.belongsTo(space)) {
                this.r = space.create();
            }
            if (this.H == null) {
                // Here y is non-null, otherwise quickResiduals would be set.
                this.r.combine(1.0, x, -1.0, this.y);
            } else {
                this.H.apply(this.r, x);
                if (this.y != null) {
                    this.r.add(-1.0, this.y);
                }
            }
        }
        if (this.quickWeightedResiduals) {
            this.Wr = this.r;
        } else {
            if (this.Wr == null || !this.Wr.belongsTo(space)) {
                this.Wr = space.create();
            }
            this.W.apply(this.Wr, this.r);
        }
    }

    /** Clears fields that only alias other vectors, so no caller-owned vector is retained. */
    private void releaseTemporaries() {
        if (this.quickResiduals) {
            this.r = null;
        }
        if (this.quickWeightedResiduals) {
            this.Wr = null;
        }
        if (this.quickQuasiGradient) {
            this.HtWr = null;
        }
    }
}

