package mitiv.cost;

import mitiv.array.ByteArray;
import mitiv.array.DoubleArray;
import mitiv.array.FloatArray;
import mitiv.array.IntArray;
import mitiv.array.LongArray;
import mitiv.array.ShapedArray;
import mitiv.array.ShortArray;
import mitiv.base.ArrayDescriptor;
import mitiv.base.Shape;
import mitiv.exception.NonConformableArrayException;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.DoubleShapedVectorSpace;
import mitiv.linalg.shaped.FloatShapedVector;
import mitiv.linalg.shaped.FloatShapedVectorSpace;
import mitiv.linalg.shaped.ShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;
import mitiv.utils.WeightFactory;

/* loaded from: input_file:mitiv/cost/WeightedData.class */
/**
 * Differentiable cost function for weighted data.
 * <p>
 * Data {@code y} and weights {@code w} belong to the same single- or
 * double-precision shaped vector space.  For a multiplier {@code alpha} and a
 * vector {@code x} of that space the cost is
 * <pre>
 *     f(x) = (alpha/2) * sum_i w[i] * (x[i] - y[i])^2
 * </pre>
 * and its gradient is {@code alpha * w[i] * (x[i] - y[i])}.
 * <p>
 * Consistency of data and weights is checked lazily.  Any change marks an
 * update as pending, and the next accessor or cost evaluation does three things:
 * <ul>
 *   <li>checks that the weights are finite and nonnegative;</li>
 *   <li>replaces non-finite data (which must have zero weight) by zero;</li>
 *   <li>counts the data with a strictly positive weight.</li>
 * </ul>
 * If no weights were given, unit weights are used for finite data and zero
 * weights for non-finite data.
 * <p>
 * A vector that was not declared writable is copied before any in-place change.
 */
public class WeightedData implements DifferentiableCostFunction {
    /* Element type codes as returned by getType() (mapping as used by toBoolean). */
    private static final int TYPE_BYTE = 0;
    private static final int TYPE_SHORT = 1;
    private static final int TYPE_INT = 2;
    private static final int TYPE_LONG = 3;
    private static final int TYPE_FLOAT = 4;
    private static final int TYPE_DOUBLE = 5;

    private final ShapedVectorSpace dataSpace;
    private final boolean single;
    private ShapedVector data = null;
    private ShapedVector weights = null;
    private boolean updatePending = true;
    private boolean writableData = false;
    private boolean writableWeights = false;
    private int validDataNumber = 0;

    /**
     * Creates an empty instance for data of the given element type and shape.
     *
     * @param arrayDescriptor  Descriptor providing the element type and shape.
     * @throws IllegalArgumentException If the type is not float or double.
     */
    public WeightedData(ArrayDescriptor arrayDescriptor) {
        this(arrayDescriptor.getType(), arrayDescriptor.getShape());
    }

    /**
     * Creates an empty instance whose data live in the given vector space.
     *
     * @param shapedVectorSpace  Space of the data (float or double).
     * @throws IllegalArgumentException If the space is not float or double.
     */
    public WeightedData(ShapedVectorSpace shapedVectorSpace) {
        this.single = isSingleType(shapedVectorSpace.getType());
        this.dataSpace = shapedVectorSpace;
    }

    /**
     * Creates an empty instance for data of the given element type and shape.
     *
     * @param type   Element type code (float or double only).
     * @param shape  Shape of the data.
     * @throws IllegalArgumentException If the type is not float or double.
     */
    public WeightedData(int type, Shape shape) {
        this.single = isSingleType(type);
        if (this.single) {
            this.dataSpace = new FloatShapedVectorSpace(shape);
        } else {
            this.dataSpace = new DoubleShapedVectorSpace(shape);
        }
    }

    /**
     * Creates an empty instance for data of the given element type and dimensions.
     *
     * @param type  Element type code (float or double only).
     * @param dims  Dimensions of the data.
     */
    public WeightedData(int type, int... dims) {
        this(type, new Shape(dims));
    }

    /**
     * Creates an instance with given data.
     *
     * @param shapedArray  The data; its type must be float or double.
     * @param writable     Whether the data may be modified in place.
     */
    public WeightedData(ShapedArray shapedArray, boolean writable) {
        this(shapedArray.getType(), shapedArray.getShape());
        setData(shapedArray, writable);
    }

    /**
     * Creates an instance with given (read-only) data.
     *
     * @param shapedArray  The data; its type must be float or double.
     */
    public WeightedData(ShapedArray shapedArray) {
        this(shapedArray, false);
    }

    /**
     * Creates an instance with given data.
     *
     * @param shapedVector  The data; its space becomes the data space.
     * @param writable      Whether the data may be modified in place.
     */
    public WeightedData(ShapedVector shapedVector, boolean writable) {
        this(shapedVector.getSpace());
        setData(shapedVector, writable);
    }

    /**
     * Creates an instance with given (read-only) data.
     *
     * @param shapedVector  The data; its space becomes the data space.
     */
    public WeightedData(ShapedVector shapedVector) {
        this(shapedVector.getSpace());
        setData(shapedVector);
    }

    /**
     * Creates an instance with given data and weights.  The element type is the
     * larger of the two array type codes; the shape is that of the data.
     *
     * @param dat  The data.
     * @param wgt  The weights (must conform to the data shape).
     */
    public WeightedData(ShapedArray dat, ShapedArray wgt) {
        this(Math.max(dat.getType(), wgt.getType()), dat.getShape());
        setData(dat);
        setWeights(wgt);
    }

    /**
     * Creates an instance with given data and weights.
     *
     * @param dat          The data; its space becomes the data space.
     * @param writableDat  Whether the data may be modified in place.
     * @param wgt          The weights (must belong to the data space).
     * @param writableWgt  Whether the weights may be modified in place.
     */
    public WeightedData(ShapedVector dat, boolean writableDat, ShapedVector wgt, boolean writableWgt) {
        this(dat.getSpace());
        setData(dat, writableDat);
        setWeights(wgt, writableWgt);
    }

    /**
     * Creates an instance with given (read-only) data and weights.
     *
     * @param dat  The data; its space becomes the data space.
     * @param wgt  The weights (must belong to the data space).
     */
    public WeightedData(ShapedVector dat, ShapedVector wgt) {
        this(dat.getSpace());
        setData(dat);
        setWeights(wgt);
    }

    /**
     * Maps an element type code to the precision flag.
     *
     * @return {@code true} for float, {@code false} for double.
     * @throws IllegalArgumentException For any other type.
     */
    private static boolean isSingleType(int type) {
        switch (type) {
            case TYPE_FLOAT:
                return true;
            case TYPE_DOUBLE:
                return false;
            default:
                throw mustBeFloatingPoint();
        }
    }

    /** Returns whether the data are stored in single precision. */
    public final boolean isSinglePrecision() {
        return this.single;
    }

    /** Returns the vector space of the data (and weights). */
    public final ShapedVectorSpace getDataSpace() {
        return this.dataSpace;
    }

    /**
     * Returns the (validated) data, with non-finite values replaced by zero.
     *
     * @throws IllegalArgumentException If no data have been specified.
     */
    public final ShapedVector getData() {
        if (this.updatePending) {
            update();
        }
        return this.data;
    }

    /**
     * Returns the (validated) weights, creating default ones if needed.
     *
     * @throws IllegalArgumentException If no data have been specified or the
     *         weights are inconsistent.
     */
    public final ShapedVector getWeights() {
        if (this.updatePending) {
            update();
        }
        return this.weights;
    }

    /** Returns the number of data with a strictly positive weight. */
    public final int getValidDataNumber() {
        if (this.updatePending) {
            update();
        }
        return this.validDataNumber;
    }

    /**
     * Computes the weighted mean of the data: {@code sum(w*y)/sum(w)}.
     *
     * @return The weighted mean, or NaN if there are no valid data.
     */
    public double getWeightedMean() {
        if (this.updatePending) {
            update();
        }
        if (this.validDataNumber < 1) {
            return Double.NaN;
        }
        double sumWY = 0.0;
        double sumW = 0.0;
        if (this.single) {
            final float[] w = ((FloatShapedVector) this.weights).getData();
            final float[] y = ((FloatShapedVector) this.data).getData();
            for (int i = 0; i < y.length; ++i) {
                sumW += w[i];
                sumWY += w[i] * y[i];
            }
        } else {
            final double[] w = ((DoubleShapedVector) this.weights).getData();
            final double[] y = ((DoubleShapedVector) this.data).getData();
            for (int i = 0; i < y.length; ++i) {
                sumW += w[i];
                sumWY += w[i] * y[i];
            }
        }
        return sumWY / sumW;
    }

    /** Sets the data from an array which must not be modified. */
    public void setData(ShapedArray shapedArray) {
        setData(shapedArray, false);
    }

    /**
     * Sets the data from an array.  If the array has to be converted or
     * flattened, the resulting copy is owned by this instance and hence writable.
     *
     * @param shapedArray  The data; must have the shape of the data space.
     * @param writable     Whether the array storage may be modified in place.
     * @throws NonConformableArrayException If the shape does not match.
     */
    public void setData(ShapedArray shapedArray, boolean writable) {
        if (!shapedArray.getShape().equals(this.dataSpace.getShape())) {
            throw nonconformableData();
        }
        boolean owned = makesPrivateCopy(shapedArray);
        setData(toDataVector(shapedArray), writable || owned);
    }

    /** Sets the data from a vector which must not be modified. */
    public void setData(ShapedVector shapedVector) {
        setData(shapedVector, false);
    }

    /**
     * Sets the data.  Data can only be specified once.
     *
     * @param shapedVector  The data; must belong to the data space.
     * @param writable      Whether the vector may be modified in place.
     * @throws IllegalArgumentException If data were already specified.
     */
    public void setData(ShapedVector shapedVector, boolean writable) {
        shapedVector.assertBelongsTo(this.dataSpace);
        if (this.data != null) {
            throw dataCanOnlyBeSpecifiedOnce();
        }
        this.data = shapedVector;
        this.writableData = writable;
        this.updatePending = true;
    }

    /** Sets the weights from an array which must not be modified. */
    public void setWeights(ShapedArray shapedArray) {
        setWeights(shapedArray, false);
    }

    /**
     * Sets the weights from an array.  If the array has to be converted or
     * flattened, the resulting copy is owned by this instance and hence writable.
     *
     * @param shapedArray  The weights; must have the shape of the data space.
     * @param writable     Whether the array storage may be modified in place.
     * @throws NonConformableArrayException If the shape does not match.
     */
    public void setWeights(ShapedArray shapedArray, boolean writable) {
        if (!shapedArray.getShape().equals(this.dataSpace.getShape())) {
            throw nonconformableData();
        }
        boolean owned = makesPrivateCopy(shapedArray);
        setWeights(toDataVector(shapedArray), writable || owned);
    }

    /** Sets the weights from a vector which must not be modified. */
    public void setWeights(ShapedVector shapedVector) {
        setWeights(shapedVector, false);
    }

    /**
     * Sets the weights.  Existing weights can only be replaced if they are
     * writable (for instance when they were created by this instance).
     *
     * @param shapedVector  The weights; must belong to the data space.
     * @param writable      Whether the vector may be modified in place.
     * @throws IllegalArgumentException If read-only weights were already set.
     */
    public void setWeights(ShapedVector shapedVector, boolean writable) {
        shapedVector.assertBelongsTo(this.dataSpace);
        if (this.weights != null && !this.writableWeights) {
            throw weightsCanOnlyBeSpecifiedOnce();
        }
        this.weights = shapedVector;
        this.writableWeights = writable;
        this.updatePending = true;
    }

    /**
     * Tells whether converting an array to a data-space vector yields storage
     * private to this instance (type conversion or non-flat storage).
     */
    private boolean makesPrivateCopy(ShapedArray arr) {
        return arr.getType() != this.dataSpace.getType() || !arr.isFlat();
    }

    /** Converts a conformable array into a vector of the data space. */
    private ShapedVector toDataVector(ShapedArray arr) {
        ShapedArray src = arr;
        if (src.getType() != this.dataSpace.getType()) {
            if (this.single) {
                src = src.toFloat();
            } else {
                src = src.toDouble();
            }
        }
        if (this.single) {
            return new FloatShapedVector((FloatShapedVectorSpace) this.dataSpace,
                    ((FloatArray) src).flatten());
        }
        return new DoubleShapedVector((DoubleShapedVectorSpace) this.dataSpace,
                ((DoubleArray) src).flatten());
    }

    /**
     * Computes the weights from the data with {@code WeightFactory}.
     *
     * @param alpha  First coefficient forwarded to
     *               {@code WeightFactory.computeWeightsFromData} (meaning defined there).
     * @param beta   Second coefficient forwarded to
     *               {@code WeightFactory.computeWeightsFromData} (meaning defined there).
     * @throws IllegalArgumentException If no data have been specified or weights
     *         already exist.
     */
    public void computeWeightsFromData(double alpha, double beta) {
        if (this.data == null) {
            throw noDataHasBeenSpecified();
        }
        if (this.weights != null) {
            throw weightsCanOnlyBeSpecifiedOnce();
        }
        this.weights = this.data.create();
        this.writableWeights = true;
        if (this.single) {
            this.validDataNumber = WeightFactory.computeWeightsFromData(
                    ((FloatShapedVector) this.weights).getData(),
                    ((FloatShapedVector) this.data).getData(),
                    (float) alpha, (float) beta);
        } else {
            this.validDataNumber = WeightFactory.computeWeightsFromData(
                    ((DoubleShapedVector) this.weights).getData(),
                    ((DoubleShapedVector) this.data).getData(),
                    alpha, beta);
        }
        this.updatePending = true;
    }

    /**
     * Marks bad data given a mask vector (nonzero means bad).
     *
     * @throws IllegalArgumentException If the mask shape differs from the data shape.
     */
    public void markBadData(ShapedVector shapedVector) {
        if (!shapedVector.getShape().equals(this.dataSpace.getShape())) {
            throw badMaskShape();
        }
        markBadData(toBoolean(shapedVector));
    }

    /**
     * Marks bad data given a mask array (nonzero means bad).
     *
     * @throws IllegalArgumentException If the mask shape differs from the data shape.
     */
    public void markBadData(ShapedArray shapedArray) {
        if (!shapedArray.getShape().equals(this.dataSpace.getShape())) {
            throw badMaskShape();
        }
        markBadData(toBoolean(shapedArray));
    }

    /**
     * Marks bad data by giving them a zero weight.  If no weights exist yet,
     * weights are created with 0 for bad and 1 for good data.  Read-only weights
     * are copied before being modified.
     *
     * @param bad  Flattened mask, {@code true} for bad data.
     * @throws IllegalArgumentException If no data were specified or the mask
     *         length differs from the number of data.
     */
    public final void markBadData(boolean[] bad) {
        if (this.data == null) {
            throw noDataHasBeenSpecified();
        }
        final int n = this.data.getNumber();
        if (bad.length != n) {
            throw badMaskLength();
        }
        if (this.weights == null) {
            this.weights = this.data.create();
            this.writableWeights = true;
            if (this.single) {
                final float[] w = ((FloatShapedVector) this.weights).getData();
                for (int i = 0; i < n; ++i) {
                    w[i] = (bad[i] ? 0.0f : 1.0f);
                }
            } else {
                final double[] w = ((DoubleShapedVector) this.weights).getData();
                for (int i = 0; i < n; ++i) {
                    w[i] = (bad[i] ? 0.0 : 1.0);
                }
            }
            this.updatePending = true;
            return;
        }
        if (this.single) {
            float[] w = ((FloatShapedVector) this.weights).getData();
            for (int i = 0; i < n; ++i) {
                if (bad[i] && w[i] != 0.0f) {
                    if (!this.writableWeights) {
                        // Never modify caller-owned weights: work on a private copy.
                        cloneWeights();
                        w = ((FloatShapedVector) this.weights).getData();
                    }
                    w[i] = 0.0f;
                    this.updatePending = true;
                }
            }
        } else {
            double[] w = ((DoubleShapedVector) this.weights).getData();
            for (int i = 0; i < n; ++i) {
                if (bad[i] && w[i] != 0.0) {
                    if (!this.writableWeights) {
                        // Never modify caller-owned weights: work on a private copy.
                        cloneWeights();
                        w = ((DoubleShapedVector) this.weights).getData();
                    }
                    w[i] = 0.0;
                    this.updatePending = true;
                }
            }
        }
    }

    /** Replaces read-only data by a private, writable copy. */
    private final void cloneData() {
        if (this.writableData) {
            return;
        }
        this.data = this.data.mo7clone();
        this.writableData = true;
    }

    /** Replaces read-only weights by a private, writable copy. */
    private final void cloneWeights() {
        if (this.writableWeights) {
            return;
        }
        this.weights = this.weights.mo7clone();
        this.writableWeights = true;
    }

    /**
     * Validates data and weights, creating default weights if none exist, and
     * recomputes the number of valid data.
     *
     * @throws IllegalArgumentException If no data were specified, a weight is
     *         negative or non-finite, or a non-finite datum has a nonzero weight.
     */
    private void update() {
        if (this.data == null) {
            throw noDataHasBeenSpecified();
        }
        int count;
        if (this.weights == null) {
            this.weights = this.dataSpace.create();
            this.writableWeights = true;
            count = (this.single ? initFloatWeights() : initDoubleWeights());
        } else {
            count = (this.single ? checkFloatWeights() : checkDoubleWeights());
        }
        this.updatePending = false;
        this.validDataNumber = count;
    }

    /** Fills default single-precision weights (1 if finite, else 0 and datum zeroed); returns the count of finite data. */
    private int initFloatWeights() {
        float[] y = ((FloatShapedVector) this.data).getData();
        final float[] w = ((FloatShapedVector) this.weights).getData();
        int count = 0;
        for (int i = 0; i < y.length; ++i) {
            if (nonfinite(y[i])) {
                if (!this.writableData) {
                    cloneData();
                    y = ((FloatShapedVector) this.data).getData();
                }
                y[i] = 0.0f;
                w[i] = 0.0f;
            } else {
                w[i] = 1.0f;
                ++count;
            }
        }
        return count;
    }

    /** Fills default double-precision weights (1 if finite, else 0 and datum zeroed); returns the count of finite data. */
    private int initDoubleWeights() {
        double[] y = ((DoubleShapedVector) this.data).getData();
        final double[] w = ((DoubleShapedVector) this.weights).getData();
        int count = 0;
        for (int i = 0; i < y.length; ++i) {
            if (nonfinite(y[i])) {
                if (!this.writableData) {
                    cloneData();
                    y = ((DoubleShapedVector) this.data).getData();
                }
                y[i] = 0.0;
                w[i] = 0.0;
            } else {
                w[i] = 1.0;
                ++count;
            }
        }
        return count;
    }

    /** Validates single-precision weights, zeroes non-finite data; returns the count of positive weights. */
    private int checkFloatWeights() {
        float[] y = ((FloatShapedVector) this.data).getData();
        final float[] w = ((FloatShapedVector) this.weights).getData();
        int count = 0;
        for (int i = 0; i < y.length; ++i) {
            if (nonfinite(w[i]) || w[i] < 0.0f) {
                throw weightsMustBeFiniteAndNonnegative();
            }
            if (nonfinite(y[i])) {
                if (w[i] > 0.0f) {
                    throw nonfiniteDataMustHaveZeroWeight();
                }
                if (!this.writableData) {
                    cloneData();
                    y = ((FloatShapedVector) this.data).getData();
                }
                y[i] = 0.0f;
            }
            if (w[i] > 0.0f) {
                ++count;
            }
        }
        return count;
    }

    /** Validates double-precision weights, zeroes non-finite data; returns the count of positive weights. */
    private int checkDoubleWeights() {
        double[] y = ((DoubleShapedVector) this.data).getData();
        final double[] w = ((DoubleShapedVector) this.weights).getData();
        int count = 0;
        for (int i = 0; i < y.length; ++i) {
            if (nonfinite(w[i]) || w[i] < 0.0) {
                throw weightsMustBeFiniteAndNonnegative();
            }
            if (nonfinite(y[i])) {
                if (w[i] > 0.0) {
                    throw nonfiniteDataMustHaveZeroWeight();
                }
                if (!this.writableData) {
                    cloneData();
                    y = ((DoubleShapedVector) this.data).getData();
                }
                y[i] = 0.0;
            }
            if (w[i] > 0.0) {
                ++count;
            }
        }
        return count;
    }

    /** Returns whether a float is infinite or NaN. */
    private static final boolean nonfinite(float f) {
        return Float.isInfinite(f) || Float.isNaN(f);
    }

    /** Returns whether a double is infinite or NaN. */
    private static final boolean nonfinite(double d) {
        return Double.isInfinite(d) || Double.isNaN(d);
    }

    /** Converts a shaped array of any primitive type to a flat "is nonzero" mask. */
    private boolean[] toBoolean(ShapedArray shapedArray) {
        switch (shapedArray.getType()) {
            case TYPE_BYTE:
                return toBoolean(((ByteArray) shapedArray).flatten());
            case TYPE_SHORT:
                return toBoolean(((ShortArray) shapedArray).flatten());
            case TYPE_INT:
                return toBoolean(((IntArray) shapedArray).flatten());
            case TYPE_LONG:
                return toBoolean(((LongArray) shapedArray).flatten());
            case TYPE_FLOAT:
                return toBoolean(((FloatArray) shapedArray).flatten());
            case TYPE_DOUBLE:
                return toBoolean(((DoubleArray) shapedArray).flatten());
            default:
                throw unsupportedDataType();
        }
    }

    /** Converts a float or double shaped vector to an "is nonzero" mask. */
    private boolean[] toBoolean(ShapedVector shapedVector) {
        switch (shapedVector.getType()) {
            case TYPE_FLOAT:
                return toBoolean(((FloatShapedVector) shapedVector).getData());
            case TYPE_DOUBLE:
                return toBoolean(((DoubleShapedVector) shapedVector).getData());
            default:
                throw unsupportedDataType();
        }
    }

    private static final boolean[] toBoolean(byte[] arr) {
        boolean[] mask = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            mask[i] = (arr[i] != 0);
        }
        return mask;
    }

    private static final boolean[] toBoolean(short[] arr) {
        boolean[] mask = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            mask[i] = (arr[i] != 0);
        }
        return mask;
    }

    private static final boolean[] toBoolean(int[] arr) {
        boolean[] mask = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            mask[i] = (arr[i] != 0);
        }
        return mask;
    }

    private static final boolean[] toBoolean(long[] arr) {
        boolean[] mask = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            mask[i] = (arr[i] != 0L);
        }
        return mask;
    }

    private static final boolean[] toBoolean(float[] arr) {
        boolean[] mask = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            mask[i] = (arr[i] != 0.0f);
        }
        return mask;
    }

    private static final boolean[] toBoolean(double[] arr) {
        boolean[] mask = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            mask[i] = (arr[i] != 0.0);
        }
        return mask;
    }

    @Override // mitiv.cost.CostFunction
    public VectorSpace getInputSpace() {
        return this.dataSpace;
    }

    /**
     * Evaluates {@code (alpha/2) * sum_i w[i]*(x[i] - y[i])^2}.
     *
     * @param d       The multiplier {@code alpha}; 0 short-circuits to 0.
     * @param vector  The variables {@code x}, which must belong to the data space.
     * @return The cost.
     */
    @Override // mitiv.cost.CostFunction
    public double evaluate(double d, Vector vector) {
        this.dataSpace.check(vector);
        if (d == 0.0) {
            return 0.0;
        }
        // Ensure weights exist and non-finite data have been zeroed.
        if (this.updatePending) {
            update();
        }
        double sum = 0.0;
        if (this.single) {
            final float[] w = ((FloatShapedVector) this.weights).getData();
            final float[] x = ((FloatShapedVector) vector).getData();
            final float[] y = ((FloatShapedVector) this.data).getData();
            for (int i = 0; i < y.length; ++i) {
                float r = x[i] - y[i];
                sum += r * w[i] * r;
            }
        } else {
            final double[] w = ((DoubleShapedVector) this.weights).getData();
            final double[] x = ((DoubleShapedVector) vector).getData();
            final double[] y = ((DoubleShapedVector) this.data).getData();
            for (int i = 0; i < y.length; ++i) {
                double r = x[i] - y[i];
                sum += r * w[i] * r;
            }
        }
        return (d / 2.0) * sum;
    }

    /**
     * Computes the cost and stores or accumulates its gradient
     * {@code alpha * w[i]*(x[i] - y[i])}.
     *
     * @param d        The multiplier {@code alpha}.
     * @param vector   The variables {@code x}.
     * @param vector2  The gradient destination.
     * @param z        If {@code true}, overwrite the gradient; otherwise add to it.
     * @return The cost.
     */
    @Override // mitiv.cost.DifferentiableCostFunction
    public double computeCostAndGradient(double d, Vector vector, Vector vector2, boolean z) {
        this.dataSpace.check(vector);
        this.dataSpace.check(vector2);
        if (d == 0.0) {
            if (z) {
                vector2.zero();
            }
            return 0.0;
        }
        // Ensure weights exist and non-finite data have been zeroed.
        if (this.updatePending) {
            update();
        }
        double sum = 0.0;
        if (this.single) {
            final float alpha = (float) d;
            final float[] g = ((FloatShapedVector) vector2).getData();
            final float[] w = ((FloatShapedVector) this.weights).getData();
            final float[] x = ((FloatShapedVector) vector).getData();
            final float[] y = ((FloatShapedVector) this.data).getData();
            if (z) {
                for (int i = 0; i < y.length; ++i) {
                    float r = x[i] - y[i];
                    float wr = w[i] * r;
                    sum += r * wr;
                    g[i] = alpha * wr;
                }
            } else {
                for (int i = 0; i < y.length; ++i) {
                    float r = x[i] - y[i];
                    float wr = w[i] * r;
                    sum += r * wr;
                    g[i] += alpha * wr;
                }
            }
        } else {
            final double[] g = ((DoubleShapedVector) vector2).getData();
            final double[] w = ((DoubleShapedVector) this.weights).getData();
            final double[] x = ((DoubleShapedVector) vector).getData();
            final double[] y = ((DoubleShapedVector) this.data).getData();
            if (z) {
                for (int i = 0; i < y.length; ++i) {
                    double r = x[i] - y[i];
                    double wr = w[i] * r;
                    sum += r * wr;
                    g[i] = d * wr;
                }
            } else {
                for (int i = 0; i < y.length; ++i) {
                    double r = x[i] - y[i];
                    double wr = w[i] * r;
                    sum += r * wr;
                    g[i] += d * wr;
                }
            }
        }
        return (d / 2.0) * sum;
    }

    private static final NonConformableArrayException nonconformableData() {
        return new NonConformableArrayException("Data array has non conformable dimensions");
    }

    private static final IllegalArgumentException badMaskShape() {
        return new IllegalArgumentException("Mask of bad data must have the same shape as the data");
    }

    private static final IllegalArgumentException badMaskLength() {
        return new IllegalArgumentException("Mask of bad data must have the same length as the data");
    }

    private static final IllegalArgumentException noDataHasBeenSpecified() {
        return new IllegalArgumentException("Data has not yet been specified");
    }

    private static final IllegalArgumentException dataCanOnlyBeSpecifiedOnce() {
        return new IllegalArgumentException("Data can only be specified once");
    }

    private static final IllegalArgumentException weightsCanOnlyBeSpecifiedOnce() {
        return new IllegalArgumentException("Weights can only be specified or computed once");
    }

    private static final IllegalArgumentException nonfiniteDataMustHaveZeroWeight() {
        return new IllegalArgumentException("Non-finite data must have zero weight");
    }

    private static final IllegalArgumentException weightsMustBeFiniteAndNonnegative() {
        return new IllegalArgumentException("Weights must be finite and nonnegative");
    }

    private static final IllegalArgumentException unsupportedDataType() {
        return new IllegalArgumentException("Unsupported data type");
    }

    private static final IllegalArgumentException mustBeFloatingPoint() {
        return new IllegalArgumentException("Weighted data type must be 'float' or 'double'");
    }
}
