/*
 * Decompiled with CFR 0.152.
 */
package mitiv.cost;

import mitiv.cost.DifferentiableCostFunction;
import mitiv.exception.IncorrectSpaceException;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;

/**
 * A differentiable cost function built as a weighted combination of other
 * differentiable cost functions which all share the same input space:
 * <pre>
 *     f(x) = w1*f1(x) + w2*f2(x) + ...
 * </pre>
 * Each term {@code fk} is evaluated with the multiplier {@code alpha*wk},
 * and the results are summed. Terms whose weight is exactly zero are skipped.
 */
public class CompositeDifferentiableCostFunction
implements DifferentiableCostFunction {
    /** Input space shared by every term. */
    private final VectorSpace inputSpace;
    /** The individual terms; {@code func[k]} is weighted by {@code wght[k]}. */
    private final DifferentiableCostFunction[] func;
    /** The weight of each term. */
    private final double[] wght;

    /**
     * Builds a composite cost function with a single weighted term.
     *
     * @param w1 weight of the term
     * @param f1 the term
     */
    public CompositeDifferentiableCostFunction(double w1, DifferentiableCostFunction f1) {
        this(new double[]{w1}, new DifferentiableCostFunction[]{f1});
    }

    /**
     * Builds a composite cost function with two weighted terms.
     *
     * @param w1 weight of the first term
     * @param f1 the first term
     * @param w2 weight of the second term
     * @param f2 the second term
     * @throws IncorrectSpaceException if the terms do not share the same
     *         input space instance
     */
    public CompositeDifferentiableCostFunction(double w1, DifferentiableCostFunction f1,
                                               double w2, DifferentiableCostFunction f2) {
        this(new double[]{w1, w2}, new DifferentiableCostFunction[]{f1, f2});
    }

    /**
     * Builds a composite cost function with three weighted terms.
     *
     * @param w1 weight of the first term
     * @param f1 the first term
     * @param w2 weight of the second term
     * @param f2 the second term
     * @param w3 weight of the third term
     * @param f3 the third term
     * @throws IncorrectSpaceException if the terms do not share the same
     *         input space instance
     */
    public CompositeDifferentiableCostFunction(double w1, DifferentiableCostFunction f1,
                                               double w2, DifferentiableCostFunction f2,
                                               double w3, DifferentiableCostFunction f3) {
        this(new double[]{w1, w2, w3}, new DifferentiableCostFunction[]{f1, f2, f3});
    }

    /**
     * Builds a composite cost function with an arbitrary number of weighted
     * terms. The arrays are copied, so later changes to them by the caller do
     * not affect this object.
     *
     * @param weights   the weight of each term; {@code weights[k]} applies to
     *                  {@code functions[k]}
     * @param functions the terms, at least one
     * @throws IllegalArgumentException if the arrays have different lengths or
     *         are empty
     * @throws IncorrectSpaceException if the terms do not share the same
     *         input space instance
     */
    public CompositeDifferentiableCostFunction(double[] weights, DifferentiableCostFunction[] functions) {
        if (weights.length != functions.length) {
            throw new IllegalArgumentException("There must be as many weights as functions.");
        }
        if (functions.length < 1) {
            throw new IllegalArgumentException("At least one function is required.");
        }
        this.inputSpace = functions[0].getInputSpace();
        for (int k = 1; k < functions.length; ++k) {
            // Identity comparison: the terms must operate on the very same space instance.
            if (functions[k].getInputSpace() != this.inputSpace) {
                throw new IncorrectSpaceException("All functions must have the same input space.");
            }
        }
        this.func = functions.clone();
        this.wght = weights.clone();
    }

    /**
     * Computes the weighted cost of all terms together with their gradient.
     * <p>
     * Each term with a non-zero weight {@code wk} is called with the
     * multiplier {@code alpha*wk}. Only the first such term receives the
     * caller's {@code clr} flag. All subsequent terms receive {@code false}, so
     * they combine their gradient with what is already in {@code gx}
     * (presumably by accumulation, per the {@link DifferentiableCostFunction}
     * contract). If no term is called and {@code clr} is set, {@code gx} is
     * filled with zeros. This happens when {@code alpha} is zero or when all
     * weights are zero.
     *
     * @param alpha global multiplier applied on top of each term's weight
     * @param x     the variables
     * @param gx    the gradient destination
     * @param clr   whether {@code gx} must be cleared before accumulation
     * @return the sum of the values returned by the called terms, or
     *         {@code 0.0} if no term is called
     */
    @Override
    public double computeCostAndGradient(double alpha, Vector x, Vector gx, boolean clr) {
        double cost = 0.0;
        if (alpha != 0.0) {
            for (int k = 0; k < func.length; ++k) {
                if (wght[k] != 0.0) {
                    cost += func[k].computeCostAndGradient(alpha*wght[k], x, gx, clr);
                    // The first contributing term has handled clearing; the
                    // others must not wipe what has already been computed.
                    clr = false;
                }
            }
        }
        if (clr) {
            // No term has written into gx (alpha == 0 or all weights are
            // zero), yet the caller requested a cleared gradient.
            gx.fill(0.0);
        }
        return cost;
    }

    /**
     * @return the input space shared by all the terms
     */
    @Override
    public VectorSpace getInputSpace() {
        return this.inputSpace;
    }

    /**
     * Evaluates the weighted cost of all terms, without computing a gradient.
     * Each term with a non-zero weight {@code wk} is evaluated with the
     * multiplier {@code alpha*wk}, and the results are summed.
     *
     * @param alpha global multiplier applied on top of each term's weight
     * @param x     the variables
     * @return the sum of the values returned by the evaluated terms, or
     *         {@code 0.0} if {@code alpha} is zero or all weights are zero
     */
    @Override
    public double evaluate(double alpha, Vector x) {
        double cost = 0.0;
        if (alpha != 0.0) {
            for (int k = 0; k < func.length; ++k) {
                if (wght[k] != 0.0) {
                    cost += func[k].evaluate(alpha*wght[k], x);
                }
            }
        }
        return cost;
    }
}

