package mitiv.invpb;

import mitiv.base.mapping.Mapping;
import mitiv.cost.CostFunction;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.FloatShapedVector;
import mitiv.linalg.shaped.ShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;

/**
 * Gaussian (weighted least-squares) likelihood cost function.
 * <p>
 * For variables {@code x}, direct model {@code H}, data {@code y} and weights
 * {@code w} (taken from a {@link WeightedData} instance), the cost is
 * <pre>
 *     f(x) = alpha * sum_i w_i * (H(x)_i - y_i)^2 / 2
 * </pre>
 * When {@link #ignoreWeights} is {@code true}, all weights are taken to be one.
 * <p>
 * Instances are not thread-safe: residuals are computed into a single work
 * vector ({@link #work1}) that is allocated on first use and reused by every
 * call.
 */
public class GaussianLikelihood implements CostFunction {
    /** Data and their associated weights. */
    protected final WeightedData weightedData;

    /** Direct model mapping the variables to the data space. */
    protected final Mapping directModel;

    /** Input space of the direct model (space of the variables). */
    protected final VectorSpace variableSpace;

    /** Data space; this is also the output space of the direct model. */
    protected final ShapedVectorSpace dataSpace;

    /** Work vector receiving the residuals; lazily allocated. */
    protected ShapedVector work1 = null;

    /** If {@code true}, the weights are ignored (every weight is one). */
    protected boolean ignoreWeights = false;

    /**
     * Creates a Gaussian likelihood for the given data and direct model.
     *
     * @param weightedData the data and their weights.
     * @param mapping      the direct model; its output space must be the very
     *                     same instance (reference equality) as the data space
     *                     of {@code weightedData}.
     * @throws IllegalArgumentException if the output space of {@code mapping}
     *         is not the data space.
     */
    public GaussianLikelihood(WeightedData weightedData, Mapping mapping) {
        if (mapping.getOutputSpace() != weightedData.getDataSpace()) {
            throw new IllegalArgumentException("Output space of the direct model must be the data space");
        }
        this.directModel = mapping;
        this.variableSpace = mapping.getInputSpace();
        this.dataSpace = weightedData.getDataSpace();
        this.weightedData = weightedData;
    }

    /**
     * Returns the input space of this cost function.
     *
     * @return the space of the variables (input space of the direct model).
     */
    @Override // mitiv.cost.CostFunction
    public final VectorSpace getInputSpace() {
        return variableSpace;
    }

    /**
     * Returns the space of the variables.
     *
     * @return the input space of the direct model.
     */
    public final VectorSpace getVariableSpace() {
        return variableSpace;
    }

    /**
     * Returns the data space.
     *
     * @return the data space, also the output space of the direct model.
     */
    public final ShapedVectorSpace getDataSpace() {
        return dataSpace;
    }

    /**
     * Returns the data.
     *
     * @return the data vector held by the weighted data.
     */
    public final ShapedVector getData() {
        return weightedData.getData();
    }

    /**
     * Returns the weights.
     *
     * @return the weight vector held by the weighted data.
     */
    public final ShapedVector getWeight() {
        return weightedData.getWeight();
    }

    /**
     * Tells whether the weighted data are stored in single precision.
     *
     * @return the value reported by the weighted data.
     */
    public final boolean singlePrecision() {
        return weightedData.singlePrecision();
    }

    /**
     * Computes the model for the given variables into a new data-space vector.
     *
     * @param shapedVector the variables.
     * @return a newly created vector of the data space holding the model.
     */
    public final ShapedVector computeModel(ShapedVector shapedVector) {
        ShapedVector model = dataSpace.create();
        computeModel(model, shapedVector);
        return model;
    }

    /**
     * Applies the direct model to the variables.
     *
     * @param shapedVector destination vector (in the data space).
     * @param vector       the variables.
     */
    public final void computeModel(ShapedVector shapedVector, Vector vector) {
        directModel.apply(shapedVector, vector);
    }

    /**
     * Computes the residuals {@code model - data} for the given variables and
     * stores them into {@link #work1}, allocating it if needed.
     *
     * @param vector the variables.
     */
    // NOTE(review): the decompiler reports this method was originally
    // protected; it is left public so that existing callers keep compiling.
    public final void computeResiduals(Vector vector) {
        if (work1 == null) {
            work1 = dataSpace.create();
        }
        computeModel(work1, vector);
        work1.combine(1.0, work1, -1.0, getData());
    }

    /**
     * Evaluates the cost for the given variables.
     *
     * @param alpha multiplier applied to the cost.
     * @param x     the variables.
     * @return {@code alpha * sum_i w_i * r_i^2 / 2}, where {@code r} are the
     *         residuals and {@code w_i = 1} if {@link #ignoreWeights} is set.
     *         Returns {@code 0} without evaluating the model when
     *         {@code alpha == 0}.
     */
    @Override // mitiv.cost.CostFunction
    public final double evaluate(double alpha, Vector x) {
        if (alpha == 0.0) {
            return 0.0;
        }
        computeResiduals(x);
        return (alpha * weightedSumOfSquares()) / 2.0;
    }

    /**
     * Sums the squared residuals currently stored in {@link #work1}, each
     * multiplied by its weight unless {@link #ignoreWeights} is set.
     * <p>
     * The sum of squares is computed explicitly in both cases. Previously,
     * {@code norm2()} (the Euclidean norm, not its square) was used when
     * ignoring weights. Products are formed in double precision even for
     * single-precision data, to limit rounding errors.
     *
     * @return the (weighted) sum of squared residuals.
     */
    private double weightedSumOfSquares() {
        double sum = 0.0;
        if (singlePrecision()) {
            final float[] r = ((FloatShapedVector) work1).getData();
            if (ignoreWeights) {
                for (int i = 0; i < r.length; i++) {
                    final double ri = r[i];
                    sum += ri * ri;
                }
            } else {
                final float[] w = ((FloatShapedVector) getWeight()).getData();
                for (int i = 0; i < r.length; i++) {
                    final double ri = r[i]; // promote before multiplying
                    sum += w[i] * ri * ri;
                }
            }
        } else {
            final double[] r = ((DoubleShapedVector) work1).getData();
            if (ignoreWeights) {
                for (int i = 0; i < r.length; i++) {
                    sum += r[i] * r[i];
                }
            } else {
                final double[] w = ((DoubleShapedVector) getWeight()).getData();
                for (int i = 0; i < r.length; i++) {
                    sum += w[i] * r[i] * r[i];
                }
            }
        }
        return sum;
    }
}
