/*
 * Decompiled with CFR 0.152.
 */
package mitiv.cost;

import mitiv.cost.DifferentiableCostFunction;
import mitiv.exception.IncorrectSpaceException;
import mitiv.linalg.LinearOperator;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;

/**
 * Quadratic cost function assembled from three optional components: a linear
 * operator {@code H}, a vector {@code y} and a weighting operator {@code W}.
 *
 * <p>For an input {@code x}, the residuals {@code r} are formed from {@code x},
 * {@code H} and {@code y} (presumably {@code r = H.x - y}, with a missing
 * {@code H} treated as the identity and a missing {@code y} as zero), the
 * weighted residuals are {@code W.r} (or {@code r} itself when {@code W} is
 * {@code null}), and the cost is {@code alpha * r.dot(W.r)}.
 *
 * <p>Work vectors are cached in protected fields between calls; instances keep
 * mutable state and no synchronization is performed in this class.
 */
public class QuadraticCost implements DifferentiableCostFunction {
    /** Space of the variables {@code x}; returned by {@link #getInputSpace()}. */
    protected VectorSpace inputSpace = null;

    /** Optional linear operator applied to {@code x}; {@code null} means it is skipped. */
    protected LinearOperator H = null;

    /** Optional weighting operator applied to the residuals; {@code null} means it is skipped. */
    protected LinearOperator W = null;

    /** Optional vector combined with {@code H.x} (or {@code x}) to form the residuals. */
    protected Vector y = null;

    /** Residuals work vector (aliases {@code x} during a call when {@link #quickResiduals} is set). */
    protected Vector r = null;

    /** Weighted residuals work vector (aliases {@link #r} during a call when {@link #quickWeightedResiduals} is set). */
    protected Vector Wr = null;

    /** Work vector receiving the adjoint of {@code H} applied to {@link #Wr}. */
    protected Vector HtWr = null;

    /** True when {@code H} and {@code y} are both {@code null}: the residuals are {@code x} itself. */
    protected boolean quickResiduals;

    /** True when {@code W} is {@code null}: the weighted residuals are the residuals themselves. */
    protected boolean quickWeightedResiduals;

    /** True when {@code H} is {@code null}: no adjoint has to be applied to get the gradient direction. */
    protected boolean quickQuasiGradient;

    /** True when {@link #r} may be reused as storage for {@link #HtWr}. */
    protected boolean shareMemory;

    /**
     * Creates a cost function from the given components.
     *
     * @param H linear operator, may be {@code null}
     * @param y vector, may be {@code null}
     * @param W weighting operator, may be {@code null}
     * @throws IncorrectSpaceException  if the components live in incompatible spaces
     * @throws IllegalArgumentException if all three components are {@code null}
     * @see #setComponents(LinearOperator, Vector, LinearOperator)
     */
    public QuadraticCost(LinearOperator H, Vector y, LinearOperator W) {
        this.setComponents(H, y, W);
    }

    /**
     * Creates a cost function with no {@code H}, {@code y} or {@code W}, operating
     * on the given space; the cost is then {@code alpha * x.dot(x)}.
     *
     * @param space the space of the variables
     */
    public QuadraticCost(VectorSpace space) {
        this.inputSpace = space;
        // With no component at all, every shortcut applies.  Leaving these flags
        // at their default (false) would make formResiduals() go through the
        // generic paths and dereference the null y and W.
        this.quickResiduals = true;
        this.quickWeightedResiduals = true;
        this.quickQuasiGradient = true;
        this.shareMemory = false;
    }

    /**
     * Creates an unweighted cost function.
     *
     * @param H linear operator, may be {@code null}
     * @param y vector, may be {@code null}
     */
    public QuadraticCost(LinearOperator H, Vector y) {
        this(H, y, null);
    }

    /**
     * Creates an unweighted cost function with no {@code y} vector.
     *
     * @param H linear operator; must not be {@code null} since it is the only component
     */
    public QuadraticCost(LinearOperator H) {
        this(H, null, null);
    }

    /**
     * Replaces the components of the cost function and updates the cached
     * shortcuts and work vectors accordingly.
     *
     * <p>The input and output spaces are taken from {@code H} when it is given;
     * otherwise both are the space of {@code y}, or failing that the input space
     * of {@code W}.
     *
     * @param H linear operator, may be {@code null}
     * @param y vector, may be {@code null}; must belong to the output space when one is known
     * @param W weighting operator, may be {@code null}; must be an endomorphism of the output space
     * @throws IncorrectSpaceException  if {@code y} or {@code W} is incompatible with the output space,
     *                                  or {@code W} is not an endomorphism
     * @throws IllegalArgumentException if {@code H}, {@code y} and {@code W} are all {@code null}
     */
    public void setComponents(LinearOperator H, Vector y, LinearOperator W) {
        VectorSpace outputSpace;
        VectorSpace inputSpace;
        if (H != null) {
            inputSpace = H.getInputSpace();
            outputSpace = H.getOutputSpace();
        } else {
            inputSpace = null;
            outputSpace = null;
        }
        if (y != null) {
            if (outputSpace == null) {
                // No operator H: variables and residuals share the space of y.
                inputSpace = outputSpace = y.getSpace();
            } else if (!y.belongsTo(outputSpace)) {
                throw new IncorrectSpaceException("y must belong to the output space of H");
            }
        }
        if (W != null) {
            if (!W.isEndomorphism()) {
                throw new IncorrectSpaceException("W must be an endomorphism");
            }
            if (outputSpace == null) {
                // Neither H nor y: everything lives in the space W operates on.
                inputSpace = outputSpace = W.getInputSpace();
            } else if (W.getInputSpace() != outputSpace) {
                throw new IncorrectSpaceException("incompatible vector space for operator W");
            }
        }
        if (inputSpace == null) {
            throw new IllegalArgumentException("one of H, y, or W must be non null");
        }

        // All checks passed: only now is the instance modified.
        this.H = H;
        this.y = y;
        this.W = W;
        this.inputSpace = inputSpace;
        this.quickResiduals = H == null && y == null;
        this.quickWeightedResiduals = W == null;
        this.quickQuasiGradient = H == null;
        this.shareMemory = !this.quickResiduals && inputSpace == outputSpace;

        // Drop cached work vectors that are no longer needed or no longer fit.
        if (this.r != null && (this.quickResiduals || !this.r.belongsTo(outputSpace))) {
            this.r = null;
        }
        if (this.Wr != null && (this.quickWeightedResiduals || !this.Wr.belongsTo(outputSpace))) {
            this.Wr = null;
        }
        if (this.shareMemory) {
            this.HtWr = this.r;
        } else if (this.HtWr != null && (this.quickQuasiGradient || !this.HtWr.belongsTo(inputSpace))) {
            this.HtWr = null;
        }
    }

    /**
     * Returns the space of the variables.
     *
     * @return the input space of this cost function
     */
    @Override
    public VectorSpace getInputSpace() {
        return this.inputSpace;
    }

    /**
     * Computes the cost at {@code x}, scaled by {@code alpha}.
     *
     * @param alpha multiplier of the cost; when zero, nothing is computed
     * @param x     the variables
     * @return {@code alpha} times the dot product of the residuals and the weighted residuals,
     *         or {@code 0.0} when {@code alpha} is zero
     */
    @Override
    public double evaluate(double alpha, Vector x) {
        if (alpha == 0.0) {
            return 0.0;
        }
        this.formResiduals(x);
        double q = this.r.dot(this.Wr);
        // Do not keep references to vectors that merely alias caller data.
        if (this.quickResiduals) {
            this.r = null;
        }
        if (this.quickWeightedResiduals) {
            this.Wr = null;
        }
        return alpha * q;
    }

    /**
     * Computes the cost at {@code x}, scaled by {@code alpha}, and integrates its
     * gradient into {@code gx}.
     *
     * <p>The gradient contribution is {@code 2 * alpha * HtWr}, where {@code HtWr}
     * is the adjoint of {@code H} applied to the weighted residuals (or the weighted
     * residuals themselves when {@code H} is {@code null}).  It is combined with
     * {@code gx} through {@code axpby} using a factor of 0 for the previous content
     * of {@code gx} when {@code clr} is true and 1 otherwise (presumably
     * overwriting, respectively incrementing, {@code gx}).
     *
     * @param alpha multiplier of the cost; when zero, {@code gx} is zeroed if {@code clr}
     *              is true and otherwise left untouched
     * @param x     the variables
     * @param gx    the vector receiving the gradient
     * @param clr   whether the previous content of {@code gx} is discarded
     * @return {@code alpha} times the dot product of the residuals and the weighted residuals,
     *         or {@code 0.0} when {@code alpha} is zero
     */
    @Override
    public double computeCostAndGradient(double alpha, Vector x, Vector gx, boolean clr) {
        if (alpha == 0.0) {
            if (clr) {
                gx.zero();
            }
            return 0.0;
        }
        this.formResiduals(x);
        // The cost must be computed before r is possibly recycled as HtWr below.
        double q = this.r.dot(this.Wr);
        if (this.quickQuasiGradient) {
            this.HtWr = this.Wr;
        } else {
            if (this.HtWr == null || !this.HtWr.belongsTo(this.inputSpace)) {
                this.HtWr = this.shareMemory ? this.r : this.inputSpace.create();
            }
            // NOTE(review): when shareMemory is set and W is null, Wr and HtWr are
            // both the same vector as r, so H is applied in place -- confirm that
            // LinearOperator.apply supports identical source and destination.
            this.H.apply(this.Wr, this.HtWr, LinearOperator.ADJOINT);
        }
        this.inputSpace.axpby(clr ? 0.0 : 1.0, gx, 2.0 * alpha, this.HtWr, gx);
        // Do not keep references to vectors that merely alias other data.
        if (this.quickResiduals) {
            this.r = null;
        }
        if (this.quickWeightedResiduals) {
            this.Wr = null;
        }
        if (this.quickQuasiGradient) {
            this.HtWr = null;
        }
        return alpha * q;
    }

    /**
     * Fills {@link #r} and {@link #Wr} for the given variables, allocating the
     * work vectors when they are missing or belong to the wrong space.
     *
     * @param x the variables
     */
    private void formResiduals(Vector x) {
        // Space of the residuals: output space of H, or the input space without H.
        VectorSpace innerSpace = this.H == null ? this.inputSpace : this.H.getOutputSpace();
        if (this.quickResiduals) {
            this.r = x;
        } else {
            if (this.r == null || !this.r.belongsTo(innerSpace)) {
                this.r = innerSpace.create();
            }
            if (this.H == null) {
                // Here y is non-null (otherwise quickResiduals would be set);
                // presumably stores x - y into r.
                innerSpace.axpby(1.0, x, -1.0, this.y, this.r);
            } else {
                this.H.apply(x, this.r);
                if (this.y != null) {
                    // Presumably r = r - y, computed in place.
                    innerSpace.axpby(1.0, this.r, -1.0, this.y, this.r);
                }
            }
        }
        if (this.quickWeightedResiduals) {
            this.Wr = this.r;
        } else {
            if (this.Wr == null || !this.Wr.belongsTo(innerSpace)) {
                this.Wr = innerSpace.create();
            }
            this.W.apply(this.r, this.Wr);
        }
    }
}

