package mitiv.cost;

import mitiv.exception.IncorrectSpaceException;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;

/**
 * Weighted sum of differentiable cost functions that all share the same input space.
 * <p>
 * Each term {@code i} has a finite, nonnegative weight {@code w[i]}. When the composite is
 * evaluated with a multiplier {@code alpha}, every term whose weight is nonzero is called with the
 * multiplier {@code alpha*w[i]} and the values it returns are summed. Terms with a zero weight are
 * skipped entirely, and nothing is called at all when {@code alpha} is zero.
 * <p>
 * The composite counts how many times it has been evaluated. The counters are plain {@code int}
 * fields updated without any synchronization.
 */
public class CompositeDifferentiableCostFunction implements DifferentiableCostFunction {
    /** Input space shared by all terms (compared by reference identity at construction). */
    private final VectorSpace inputSpace;
    /** The terms of the composite; private copy, same length as {@link #wght}. */
    private final DifferentiableCostFunction[] func;
    /** The weight of each term; private copy, same length as {@link #func}. */
    private final double[] wght;
    /** Number of calls to {@link #evaluate} and {@link #computeCostAndGradient}. */
    protected int nfx = 0;
    /** Number of calls to {@link #computeCostAndGradient}. */
    protected int ngx = 0;

    /**
     * Builds a composite with a single weighted term.
     *
     * @param d  weight of the term (finite and nonnegative).
     * @param differentiableCostFunction  the term.
     * @throws IllegalArgumentException if the weight is NaN, infinite or negative.
     * @throws NullPointerException if the function is {@code null}.
     */
    public CompositeDifferentiableCostFunction(double d,
            DifferentiableCostFunction differentiableCostFunction) {
        this(new double[]{d},
             new DifferentiableCostFunction[]{differentiableCostFunction});
    }

    /**
     * Builds a composite with two weighted terms.
     *
     * @param d  weight of the first term (finite and nonnegative).
     * @param differentiableCostFunction  the first term.
     * @param d2  weight of the second term (finite and nonnegative).
     * @param differentiableCostFunction2  the second term.
     * @throws IllegalArgumentException if a weight is NaN, infinite or negative.
     * @throws IncorrectSpaceException if the terms do not share the same input space.
     * @throws NullPointerException if a function is {@code null}.
     */
    public CompositeDifferentiableCostFunction(double d,
            DifferentiableCostFunction differentiableCostFunction, double d2,
            DifferentiableCostFunction differentiableCostFunction2) {
        this(new double[]{d, d2},
             new DifferentiableCostFunction[]{differentiableCostFunction,
                                              differentiableCostFunction2});
    }

    /**
     * Builds a composite with three weighted terms.
     *
     * @param d  weight of the first term (finite and nonnegative).
     * @param differentiableCostFunction  the first term.
     * @param d2  weight of the second term (finite and nonnegative).
     * @param differentiableCostFunction2  the second term.
     * @param d3  weight of the third term (finite and nonnegative).
     * @param differentiableCostFunction3  the third term.
     * @throws IllegalArgumentException if a weight is NaN, infinite or negative.
     * @throws IncorrectSpaceException if the terms do not share the same input space.
     * @throws NullPointerException if a function is {@code null}.
     */
    public CompositeDifferentiableCostFunction(double d,
            DifferentiableCostFunction differentiableCostFunction, double d2,
            DifferentiableCostFunction differentiableCostFunction2, double d3,
            DifferentiableCostFunction differentiableCostFunction3) {
        this(new double[]{d, d2, d3},
             new DifferentiableCostFunction[]{differentiableCostFunction,
                                              differentiableCostFunction2,
                                              differentiableCostFunction3});
    }

    /**
     * Builds a composite with an arbitrary number of weighted terms.
     * <p>
     * Both arrays are copied, so later modifications by the caller do not affect the composite.
     *
     * @param weights  the weight of each term (each finite and nonnegative).
     * @param functions  the terms; {@code functions[i]} is weighted by {@code weights[i]}.
     * @throws IllegalArgumentException if the arrays have different lengths or are empty, or if
     *         a weight is NaN, infinite or negative.
     * @throws IncorrectSpaceException if the terms do not share the same input space.
     * @throws NullPointerException if an array or one of the functions is {@code null}.
     */
    public CompositeDifferentiableCostFunction(double[] weights,
            DifferentiableCostFunction[] functions) {
        if (weights.length != functions.length) {
            throw new IllegalArgumentException("There must be as many weights as cost functions.");
        }
        if (functions.length < 1) {
            throw new IllegalArgumentException("At least one cost function is required.");
        }
        // Copy first so that validation and storage see the same contents.
        final double[] w = weights.clone();
        final DifferentiableCostFunction[] f = functions.clone();
        for (double wi : w) {
            checkWeight(wi);
        }
        final VectorSpace space = f[0].getInputSpace();
        for (int i = 1; i < f.length; ++i) {
            // Identity comparison, as in the original design: terms must share the very same space.
            if (f[i].getInputSpace() != space) {
                throw new IncorrectSpaceException("All functions must have the same input space.");
            }
        }
        this.inputSpace = space;
        this.func = f;
        this.wght = w;
    }

    /**
     * Computes the weighted cost and its gradient.
     * <p>
     * Each term with a nonzero weight is called with the multiplier {@code alpha*w[i]}, and the
     * values it returns are summed. The clear flag is passed only to the first term actually
     * called; later terms receive {@code false} so they add to the gradient that term produced.
     * If the flag is set but no term is called (because {@code alpha} is zero or all weights are
     * zero), {@code gx} is filled with zeros so the caller always gets a properly cleared gradient.
     *
     * @param alpha  global multiplier; when zero, no term is called and 0 is returned.
     * @param x  the variables (presumably an element of {@link #getInputSpace()}).
     * @param gx  the gradient destination.
     * @param clr  whether {@code gx} must be overwritten rather than accumulated into.
     * @return the sum of the values returned by the called terms (0 if none is called).
     */
    @Override // mitiv.cost.DifferentiableCostFunction
    public double computeCostAndGradient(double alpha, Vector x, Vector gx, boolean clr) {
        double cost = 0.0;
        boolean clearPending = clr;
        if (alpha != 0.0) {
            for (int i = 0; i < func.length; ++i) {
                if (wght[i] != 0.0) {
                    cost += func[i].computeCostAndGradient(alpha * wght[i], x, gx, clearPending);
                    clearPending = false;
                }
            }
        }
        if (clearPending) {
            // No term consumed the clear request: honor it here.
            gx.fill(0.0);
        }
        ++nfx;
        ++ngx;
        return cost;
    }

    /**
     * Returns the input space shared by all terms.
     *
     * @return the input space of the first term (identical to the input space of every term).
     */
    @Override // mitiv.cost.CostFunction
    public VectorSpace getInputSpace() {
        return inputSpace;
    }

    /**
     * Computes the weighted cost without its gradient.
     * <p>
     * Each term with a nonzero weight is called with the multiplier {@code alpha*w[i]}, and the
     * values it returns are summed.
     *
     * @param alpha  global multiplier; when zero, no term is called and 0 is returned.
     * @param x  the variables (presumably an element of {@link #getInputSpace()}).
     * @return the sum of the values returned by the called terms (0 if none is called).
     */
    @Override // mitiv.cost.CostFunction
    public double evaluate(double alpha, Vector x) {
        double cost = 0.0;
        if (alpha != 0.0) {
            for (int i = 0; i < func.length; ++i) {
                if (wght[i] != 0.0) {
                    cost += func[i].evaluate(alpha * wght[i], x);
                }
            }
        }
        ++nfx;
        return cost;
    }

    /**
     * Returns the number of cost evaluations.
     *
     * @return the number of calls to {@link #evaluate} plus calls to
     *         {@link #computeCostAndGradient}.
     */
    public int getNumberOfFunctionCalls() {
        return nfx;
    }

    /**
     * Returns the number of gradient evaluations.
     *
     * @return the number of calls to {@link #computeCostAndGradient}.
     */
    public int getNumberOfGradientCalls() {
        return ngx;
    }

    /**
     * Checks that a term weight is usable.
     *
     * @param d  the weight to check.
     * @throws IllegalArgumentException if {@code d} is NaN, infinite or negative.
     */
    private static void checkWeight(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d) || d < 0.0) {
            throw new IllegalArgumentException("Cost function weight must be finite and nonnegative");
        }
    }
}
