package mitiv.cost;

import mitiv.exception.IncorrectSpaceException;
import mitiv.linalg.LinearOperator;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;

/* loaded from: input_file:mitiv/cost/QuadraticCost.class */
/**
 * Quadratic cost of the form {@code alpha * <r, W.r>} with residuals
 * {@code r = H.x - y}.
 *
 * <p>Each of the three components is optional:
 * <ul>
 *   <li>when {@code H} is {@code null} the variables are used directly
 *       ({@code r = x - y});</li>
 *   <li>when {@code y} is {@code null} nothing is subtracted
 *       ({@code r = H.x});</li>
 *   <li>when {@code W} is {@code null} the residuals are not weighted
 *       (the cost is {@code alpha * <r, r>}).</li>
 * </ul>
 *
 * <p>The gradient is accumulated as {@code 2 * alpha * H'.W.r}, where
 * {@code H'} is {@code H} applied with {@link LinearOperator#ADJOINT}.
 * NOTE(review): that expression is the exact gradient only if {@code W} is
 * self-adjoint; nothing in this class checks it - confirm with callers.
 *
 * <p>Work vectors ({@link #r}, {@link #Wr}, {@link #HtWr}) are cached between
 * calls, so an instance keeps mutable state across evaluations.
 */
public class QuadraticCost implements DifferentiableCostFunction {
    /** Space of the variables {@code x}. */
    protected VectorSpace inputSpace;

    /** Linear model; {@code null} means "identity" (the variables are used as is). */
    protected LinearOperator H;

    /** Weighting operator; {@code null} means the residuals are not weighted. */
    protected LinearOperator W;

    /** Vector subtracted from {@code H.x}; {@code null} means nothing is subtracted. */
    protected Vector y;

    /** Residuals {@code H.x - y} (an alias of the caller's vector when {@link #quickResiduals}). */
    protected Vector r;

    /** Weighted residuals {@code W.r} (an alias of {@link #r} when {@link #quickWeightedResiduals}). */
    protected Vector Wr;

    /** Quasi-gradient {@code H'.W.r} (an alias of {@link #Wr} when {@link #quickQuasiGradient}). */
    protected Vector HtWr;

    /** True when {@code H} and {@code y} are both null, i.e. {@code r} is simply {@code x}. */
    protected boolean quickResiduals;

    /** True when {@code W} is null, i.e. {@code W.r} is simply {@code r}. */
    protected boolean quickWeightedResiduals;

    /** True when {@code H} is null, i.e. {@code H'.W.r} is simply {@code W.r}. */
    protected boolean quickQuasiGradient;

    /**
     * True when the residuals are stored in a private vector and the input and
     * output spaces are the same object, so {@link #HtWr} reuses the storage of
     * {@link #r}.
     */
    protected boolean shareMemory;

    /**
     * Creates the cost {@code alpha * <H.x - y, W.(H.x - y)>}.
     *
     * @param linearOperator  the operator {@code H}, or {@code null}
     * @param vector          the vector {@code y}, or {@code null}
     * @param linearOperator2 the weighting operator {@code W}, or {@code null}
     * @throws IncorrectSpaceException  if the components are not compatible
     * @throws IllegalArgumentException if all three components are {@code null}
     * @see #setComponents(LinearOperator, Vector, LinearOperator)
     */
    public QuadraticCost(LinearOperator linearOperator, Vector vector, LinearOperator linearOperator2) {
        setComponents(linearOperator, vector, linearOperator2);
    }

    /**
     * Creates the cost {@code alpha * <x, x>} over the given space, with no
     * {@code H}, {@code y} or {@code W}.
     *
     * @param vectorSpace the space of the variables
     */
    public QuadraticCost(VectorSpace vectorSpace) {
        this.inputSpace = vectorSpace;
        // With no H, y or W every stage is a pass-through.  The flags must say
        // so explicitly: left at their default (false) the evaluation code
        // would dereference the null H / y / W components.
        this.quickResiduals = true;
        this.quickWeightedResiduals = true;
        this.quickQuasiGradient = true;
        this.shareMemory = false;
    }

    /**
     * Creates the unweighted cost {@code alpha * <H.x - y, H.x - y>}.
     *
     * @param linearOperator the operator {@code H}, or {@code null}
     * @param vector         the vector {@code y}, or {@code null}
     */
    public QuadraticCost(LinearOperator linearOperator, Vector vector) {
        this(linearOperator, vector, null);
    }

    /**
     * Creates the cost {@code alpha * <H.x, H.x>}.
     *
     * @param linearOperator the operator {@code H}; must not be {@code null}
     *                       since it is the only component given
     */
    public QuadraticCost(LinearOperator linearOperator) {
        this(linearOperator, null, null);
    }

    /**
     * Replaces the components of the cost and updates the cached state.
     *
     * <p>All checks are done before any field is modified, so the instance is
     * left unchanged when an exception is thrown.  Cached work vectors that do
     * not fit the new configuration are dropped.
     *
     * @param linearOperator  the operator {@code H}, or {@code null}
     * @param vector          the vector {@code y}, or {@code null}; when
     *                        {@code H} is given, must belong to its output space
     * @param linearOperator2 the weighting operator {@code W}, or {@code null};
     *                        must be an endomorphism of the output space
     * @throws IncorrectSpaceException  if {@code y} or {@code W} does not match
     *                                  the output space, or {@code W} is not an
     *                                  endomorphism
     * @throws IllegalArgumentException if no space can be deduced from the
     *                                  arguments (all of them {@code null})
     */
    public void setComponents(LinearOperator linearOperator, Vector vector, LinearOperator linearOperator2) {
        // Deduce the input and output spaces from the first non-null component.
        VectorSpace inSpace = null;
        VectorSpace outSpace = null;
        if (linearOperator != null) {
            inSpace = linearOperator.getInputSpace();
            outSpace = linearOperator.getOutputSpace();
        }
        if (vector != null) {
            if (outSpace == null) {
                // No H: variables and data live in the same space.
                outSpace = vector.getSpace();
                inSpace = outSpace;
            } else if (!vector.belongsTo(outSpace)) {
                throw new IncorrectSpaceException("y must belong to the output space of H");
            }
        }
        if (linearOperator2 != null) {
            if (!linearOperator2.isEndomorphism()) {
                throw new IncorrectSpaceException("W must be an endomorphism");
            }
            if (outSpace == null) {
                outSpace = linearOperator2.getInputSpace();
                inSpace = outSpace;
            } else if (linearOperator2.getInputSpace() != outSpace) {
                throw new IncorrectSpaceException("incompatible vector space for operator W");
            }
        }
        if (inSpace == null) {
            throw new IllegalArgumentException("one of H, y, or W must be non null");
        }

        // Everything is valid: commit the new configuration.
        this.H = linearOperator;
        this.y = vector;
        this.W = linearOperator2;
        this.inputSpace = inSpace;
        this.quickResiduals = (linearOperator == null && vector == null);
        this.quickWeightedResiduals = (linearOperator2 == null);
        this.quickQuasiGradient = (linearOperator == null);
        this.shareMemory = (!this.quickResiduals && inSpace == outSpace);

        // Drop cached work vectors that are aliases or belong to another space.
        if (this.r != null && (this.quickResiduals || !this.r.belongsTo(outSpace))) {
            this.r = null;
        }
        if (this.Wr != null && (this.quickWeightedResiduals || !this.Wr.belongsTo(outSpace))) {
            this.Wr = null;
        }
        if (this.shareMemory) {
            this.HtWr = this.r;
        } else if (this.HtWr != null
                && (this.quickQuasiGradient || !this.HtWr.belongsTo(inSpace))) {
            this.HtWr = null;
        }
    }

    /**
     * Returns the space of the variables of this cost function.
     *
     * @return the input space
     */
    @Override // mitiv.cost.CostFunction
    public VectorSpace getInputSpace() {
        return this.inputSpace;
    }

    /**
     * Computes {@code d * <r, W.r>} with {@code r = H.x - y}.
     *
     * @param d      the multiplier of the cost; when it is zero, zero is
     *               returned without touching {@code vector}
     * @param vector the variables {@code x}
     * @return the value of the cost times {@code d}
     */
    @Override // mitiv.cost.CostFunction
    public double evaluate(double d, Vector vector) {
        if (d == 0.0d) {
            return 0.0d;
        }
        formResiduals(vector);
        double cost = this.r.dot(this.Wr);
        releaseBorrowedResiduals();
        return d * cost;
    }

    /**
     * Computes the cost and combines {@code 2 * d * H'.W.r} into the gradient.
     *
     * @param d       the multiplier of the cost; when it is zero the cost is
     *                zero and the gradient is only zero-filled if {@code z}
     * @param vector  the variables {@code x}
     * @param vector2 the gradient vector, updated by
     *                {@code combine(z ? 0 : 1, vector2, 2 * d, HtWr)} -
     *                presumably overwritten when {@code z} is true and
     *                incremented otherwise
     * @param z       whether the previous contents of {@code vector2} are
     *                discarded
     * @return the value of the cost times {@code d}
     */
    @Override // mitiv.cost.DifferentiableCostFunction
    public double computeCostAndGradient(double d, Vector vector, Vector vector2, boolean z) {
        if (d == 0.0d) {
            if (z) {
                vector2.zero();
            }
            return 0.0d;
        }
        formResiduals(vector);
        // The cost must be computed before the adjoint is applied: with shared
        // memory the quasi-gradient overwrites the residuals.
        double cost = this.r.dot(this.Wr);
        if (this.quickQuasiGradient) {
            this.HtWr = this.Wr;
        } else {
            if (this.HtWr == null || !this.HtWr.belongsTo(this.inputSpace)) {
                this.HtWr = this.shareMemory ? this.r : this.inputSpace.create();
            }
            // NOTE(review): when shareMemory is set and W is null, HtWr, r and
            // Wr are the same object, so this is an in-place apply - confirm
            // that H supports identical source and destination.
            this.H.apply(this.HtWr, this.Wr, LinearOperator.ADJOINT);
        }
        vector2.combine(z ? 0.0d : 1.0d, vector2, 2.0d * d, this.HtWr);
        releaseBorrowedResiduals();
        if (this.quickQuasiGradient) {
            this.HtWr = null;
        }
        return d * cost;
    }

    /**
     * Forgets the references stored in {@link #r} and {@link #Wr} when they are
     * aliases rather than private work vectors (in the quick modes {@code r} is
     * the caller's vector and {@code Wr} is {@code r} itself).
     */
    private void releaseBorrowedResiduals() {
        if (this.quickResiduals) {
            this.r = null;
        }
        if (this.quickWeightedResiduals) {
            this.Wr = null;
        }
    }

    /**
     * Sets {@link #r} to {@code H.x - y} and {@link #Wr} to {@code W.r},
     * allocating the work vectors on demand and aliasing instead of computing
     * whenever a component is absent.
     *
     * @param vector the variables {@code x}
     */
    private void formResiduals(Vector vector) {
        // Without H the residuals live in the space of the variables.
        VectorSpace outputSpace = (this.H == null) ? this.inputSpace : this.H.getOutputSpace();

        if (this.quickResiduals) {
            // Neither H nor y: the residuals are the variables themselves.
            this.r = vector;
        } else {
            if (this.r == null || !this.r.belongsTo(outputSpace)) {
                this.r = outputSpace.create();
            }
            if (this.H == null) {
                // Not quick and no H, hence y is expected to be non-null here.
                this.r.combine(1.0d, vector, -1.0d, this.y);
            } else {
                this.H.apply(this.r, vector);
                if (this.y != null) {
                    this.r.add(-1.0d, this.y);
                }
            }
        }

        if (this.quickWeightedResiduals) {
            // No weights: W.r is r.
            this.Wr = this.r;
            return;
        }
        if (this.Wr == null || !this.Wr.belongsTo(outputSpace)) {
            this.Wr = outputSpace.create();
        }
        this.W.apply(this.Wr, this.r);
    }
}
