package mitiv.invpb;

import mitiv.array.ByteArray;
import mitiv.array.DoubleArray;
import mitiv.array.FloatArray;
import mitiv.array.IntArray;
import mitiv.array.LongArray;
import mitiv.array.ShapedArray;
import mitiv.array.ShortArray;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.FloatShapedVector;
import mitiv.linalg.shaped.ShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;

/* loaded from: input_file:mitiv/invpb/WeightedData.class */
public class WeightedData {
    /*
     * Element type codes as returned by getType().  The mapping matches the casts
     * performed by the toBoolean dispatchers below (e.g. code 4 is handled as
     * FloatArray / FloatShapedVector, code 5 as DoubleArray / DoubleShapedVector).
     * NOTE(review): these presumably mirror constants defined elsewhere in the
     * library (e.g. mitiv.base.Traits) -- confirm and use those if available.
     */
    private static final int TYPE_BYTE = 0;
    private static final int TYPE_SHORT = 1;
    private static final int TYPE_INT = 2;
    private static final int TYPE_LONG = 3;
    private static final int TYPE_FLOAT = 4;
    private static final int TYPE_DOUBLE = 5;

    /** Vector space to which the data and the weights belong. */
    protected final ShapedVectorSpace dataSpace;

    /** True if elements are {@code float}, false if they are {@code double}. */
    private final boolean single;

    /** The data, {@code null} until set. */
    private ShapedVector data;

    /** The weights, {@code null} until set or created. */
    private ShapedVector weight;

    /** True when data and weights must be checked (and possibly fixed) before being returned. */
    private boolean updatePending;

    /** True if {@link #data} may be modified in place; otherwise it is cloned before any change. */
    private boolean writableData;

    /** True if {@link #weight} may be modified in place; otherwise it is cloned before any change. */
    private boolean writableWeight;

    /**
     * Creates an instance with no data and no weights yet.
     *
     * @param space  The vector space of the data; its element type must be
     *               {@code float} or {@code double}.
     * @throws IllegalArgumentException  If the element type of {@code space} is not supported.
     */
    public WeightedData(ShapedVectorSpace space) {
        this.data = null;
        this.weight = null;
        this.updatePending = true;
        this.writableData = false;
        this.writableWeight = false;
        switch (space.getType()) {
            case TYPE_FLOAT:
                this.single = true;
                break;
            case TYPE_DOUBLE:
                this.single = false;
                break;
            default:
                throw new IllegalArgumentException("Unsupported data type");
        }
        this.dataSpace = space;
    }

    /**
     * Creates an instance from some data.
     *
     * @param data      The data (also defines the data space).
     * @param writable  Whether the data may be modified in place by this object.
     */
    public WeightedData(ShapedVector data, boolean writable) {
        this(data.getSpace());
        setData(data, writable);
    }

    /**
     * Creates an instance from some data which will not be modified in place.
     *
     * @param data  The data (also defines the data space).
     */
    public WeightedData(ShapedVector data) {
        this(data.getSpace());
        setData(data);
    }

    /**
     * Creates an instance from some data and their weights.
     *
     * @param data            The data (also defines the data space).
     * @param dataWritable    Whether the data may be modified in place by this object.
     * @param weight          The weights, must belong to the same space as the data.
     * @param weightWritable  Whether the weights may be modified in place by this object.
     */
    public WeightedData(ShapedVector data, boolean dataWritable,
                        ShapedVector weight, boolean weightWritable) {
        this(data.getSpace());
        setData(data, dataWritable);
        setWeight(weight, weightWritable);
    }

    /**
     * Creates an instance from some data and their weights, neither of which
     * will be modified in place.
     *
     * @param data    The data (also defines the data space).
     * @param weight  The weights, must belong to the same space as the data.
     */
    public WeightedData(ShapedVector data, ShapedVector weight) {
        this(data.getSpace());
        setData(data);
        setWeight(weight);
    }

    /**
     * Tells whether the data are stored in single precision.
     *
     * @return {@code true} for {@code float} elements, {@code false} for {@code double}.
     */
    public final boolean singlePrecision() {
        return this.single;
    }

    /**
     * Gets the vector space of the data and of the weights.
     *
     * @return The data space given (directly or indirectly) at construction.
     */
    public final ShapedVectorSpace getDataSpace() {
        return this.dataSpace;
    }

    /**
     * Gets the data, after checking them against the weights if needed.
     * <p>
     * Non-finite data values are replaced by zero (in a private copy if the data
     * are not writable).
     *
     * @return The data.
     * @throws IllegalArgumentException  If no data have been set, if some weights
     *         are not finite or negative, or if non-finite data have nonzero weight.
     */
    public final ShapedVector getData() {
        if (this.updatePending) {
            update();
        }
        return this.data;
    }

    /**
     * Gets the weights, creating default ones if none were set.
     * <p>
     * Default weights are zero for non-finite data and one elsewhere.
     *
     * @return The weights.
     * @throws IllegalArgumentException  Same conditions as {@link #getData()}.
     */
    public final ShapedVector getWeight() {
        if (this.updatePending) {
            update();
        }
        return this.weight;
    }

    /**
     * Sets the data; they will not be modified in place.
     *
     * @param data  The data, must belong to the data space.
     * @throws IllegalArgumentException  If data have already been set.
     */
    public void setData(ShapedVector data) {
        setData(data, false);
    }

    /**
     * Sets the data.
     *
     * @param data      The data, must belong to the data space.
     * @param writable  Whether the data may be modified in place; if false, a copy
     *                  is made before any modification.
     * @throws IllegalArgumentException  If data have already been set.
     */
    public void setData(ShapedVector data, boolean writable) {
        data.assertBelongsTo(this.dataSpace);
        if (this.data != null) {
            throw new IllegalArgumentException("Data can only be set once");
        }
        this.data = data;
        this.writableData = writable;
        this.updatePending = true;
    }

    /**
     * Sets the weights; they will not be modified in place.
     *
     * @param weight  The weights, must belong to the data space.
     * @throws IllegalArgumentException  If weights have already been set.
     */
    public void setWeight(ShapedVector weight) {
        setWeight(weight, false);
    }

    /**
     * Sets the weights.
     *
     * @param weight    The weights, must belong to the data space.
     * @param writable  Whether the weights may be modified in place; if false, a
     *                  copy is made before any modification.
     * @throws IllegalArgumentException  If weights have already been set.
     */
    public void setWeight(ShapedVector weight, boolean writable) {
        weight.assertBelongsTo(this.dataSpace);
        if (this.weight != null) {
            throw new IllegalArgumentException("Weights can only be set once");
        }
        this.weight = weight;
        this.writableWeight = writable;
        this.updatePending = true;
    }

    /**
     * Presumably meant to compute the weights from the data given two parameters.
     * <p>
     * NOTE(review): not implemented -- this method only checks that data have been
     * set; the parameters are ignored and the weights are left unchanged.
     *
     * @throws IllegalArgumentException  If no data have been set.
     */
    public void computeWeightsFromData(double d, double d2) {
        if (this.data == null) {
            throw new IllegalArgumentException("No data has been set");
        }
        // TODO: compute the weights; currently a no-op.
    }

    /**
     * Presumably meant to compute the weights from the variance of the data.
     * <p>
     * NOTE(review): not implemented -- this method currently does nothing.
     */
    public void computeWeightsFromVariance(ShapedVector shapedVector) {
        // TODO: compute the weights; currently a no-op.
    }

    /**
     * Marks some data as bad, i.e. gives them zero weight.
     *
     * @param mask  A vector with the same shape as the data; nonzero entries mark bad data.
     * @throws IllegalArgumentException  If the shape does not match, the mask type is
     *         not supported, or no data have been set.
     */
    public void markBadData(ShapedVector mask) {
        if (!mask.getShape().equals(this.dataSpace.getShape())) {
            throw new IllegalArgumentException("Mask of bad data must have the same shape as the data");
        }
        markBadData(toBoolean(mask));
    }

    /**
     * Marks some data as bad, i.e. gives them zero weight.
     *
     * @param mask  An array with the same shape as the data; nonzero entries mark bad data.
     * @throws IllegalArgumentException  If the shape does not match, the mask type is
     *         not supported, or no data have been set.
     */
    public void markBadData(ShapedArray mask) {
        if (!mask.getShape().equals(this.dataSpace.getShape())) {
            throw new IllegalArgumentException("Mask of bad data must have the same shape as the data");
        }
        markBadData(toBoolean(mask));
    }

    /**
     * Marks some data as bad, i.e. gives them zero weight.
     * <p>
     * If no weights have been set yet, weights are created: zero for bad or
     * non-finite data, one elsewhere.  Otherwise, existing weights of bad data are
     * set to zero (in a private copy if the weights are not writable).
     *
     * @param bad  Flags in the same (flattened) order as the data; {@code true} marks bad data.
     * @throws IllegalArgumentException  If no data have been set or if the length of
     *         {@code bad} differs from the number of data.
     */
    public final void markBadData(boolean[] bad) {
        if (this.data == null) {
            throw new IllegalArgumentException("No data has been set");
        }
        if (bad.length != this.dataSpace.getNumber()) {
            throw new IllegalArgumentException("Mask of bad data must have the same length as the data");
        }
        if (this.weight == null) {
            createWeights(bad);
            this.updatePending = true;
        } else {
            zeroBadWeights(bad);
        }
    }

    /** Replaces the data by a private copy unless they are already writable. */
    private final void cloneData() {
        if (this.writableData) {
            return;
        }
        this.data = this.data.m11clone();
        this.writableData = true;
    }

    /** Replaces the weights by a private copy unless they are already writable. */
    private final void cloneWeight() {
        if (this.writableWeight) {
            return;
        }
        this.weight = this.weight.m11clone();
        this.writableWeight = true;
    }

    /**
     * Brings data and weights into a consistent state: creates default weights if
     * none were set, then validates the weights and zeroes non-finite data.
     */
    private void update() {
        if (this.data == null) {
            throw new IllegalArgumentException("No data has been set");
        }
        if (this.weight == null) {
            createWeights(null);
        }
        checkWeightsAndFixData();
        this.updatePending = false;
    }

    /**
     * Allocates new writable weights: zero where the datum is non-finite or
     * flagged in {@code bad}, one elsewhere.
     * <p>
     * Giving non-finite data zero weight here is required for the subsequent
     * validation in {@link #checkWeightsAndFixData()} to accept them.
     *
     * @param bad  Bad-data flags, or {@code null} if none.
     */
    private void createWeights(boolean[] bad) {
        this.weight = this.dataSpace.create();
        this.writableWeight = true;
        int number = this.dataSpace.getNumber();
        if (this.single) {
            float[] x = ((FloatShapedVector) this.data).getData();
            float[] w = ((FloatShapedVector) this.weight).getData();
            for (int i = 0; i < number; i++) {
                boolean zero = (bad != null && bad[i]) || nonfinite(x[i]);
                w[i] = zero ? 0.0f : 1.0f;
            }
        } else {
            double[] x = ((DoubleShapedVector) this.data).getData();
            double[] w = ((DoubleShapedVector) this.weight).getData();
            for (int i = 0; i < number; i++) {
                boolean zero = (bad != null && bad[i]) || nonfinite(x[i]);
                w[i] = zero ? 0.0d : 1.0d;
            }
        }
    }

    /**
     * Sets to zero the existing weights of data flagged in {@code bad}, cloning
     * the weights first if they are not writable.  Marks an update as pending
     * only if some weight actually changed.
     */
    private void zeroBadWeights(boolean[] bad) {
        if (this.single) {
            float[] w = ((FloatShapedVector) this.weight).getData();
            for (int i = 0; i < bad.length; i++) {
                if (bad[i] && w[i] != 0.0f) {
                    if (!this.writableWeight) {
                        cloneWeight();
                        w = ((FloatShapedVector) this.weight).getData();
                    }
                    w[i] = 0.0f;
                    this.updatePending = true;
                }
            }
        } else {
            double[] w = ((DoubleShapedVector) this.weight).getData();
            for (int i = 0; i < bad.length; i++) {
                if (bad[i] && w[i] != 0.0d) {
                    if (!this.writableWeight) {
                        cloneWeight();
                        w = ((DoubleShapedVector) this.weight).getData();
                    }
                    w[i] = 0.0d;
                    this.updatePending = true;
                }
            }
        }
    }

    /**
     * Checks that all weights are finite and nonnegative and that non-finite data
     * have zero weight; replaces non-finite data by zero (cloning the data first
     * if they are not writable).
     *
     * @throws IllegalArgumentException  If a check fails.
     */
    private void checkWeightsAndFixData() {
        if (this.single) {
            float[] x = ((FloatShapedVector) this.data).getData();
            float[] w = ((FloatShapedVector) this.weight).getData();
            for (int i = 0; i < x.length; i++) {
                if (nonfinite(w[i]) || w[i] < 0.0f) {
                    throw new IllegalArgumentException("Weights must be finite and nonnegative");
                }
                if (nonfinite(x[i])) {
                    if (w[i] > 0.0f) {
                        throw new IllegalArgumentException("Non-finite data must have zero weight");
                    }
                    if (!this.writableData) {
                        cloneData();
                        x = ((FloatShapedVector) this.data).getData();
                    }
                    x[i] = 0.0f;
                }
            }
        } else {
            double[] x = ((DoubleShapedVector) this.data).getData();
            double[] w = ((DoubleShapedVector) this.weight).getData();
            for (int i = 0; i < x.length; i++) {
                if (nonfinite(w[i]) || w[i] < 0.0d) {
                    throw new IllegalArgumentException("Weights must be finite and nonnegative");
                }
                if (nonfinite(x[i])) {
                    if (w[i] > 0.0d) {
                        throw new IllegalArgumentException("Non-finite data must have zero weight");
                    }
                    if (!this.writableData) {
                        cloneData();
                        x = ((DoubleShapedVector) this.data).getData();
                    }
                    x[i] = 0.0d;
                }
            }
        }
    }

    /** Returns true if {@code f} is NaN or infinite. */
    private static final boolean nonfinite(float f) {
        return Float.isInfinite(f) || Float.isNaN(f);
    }

    /** Returns true if {@code d} is NaN or infinite. */
    private static final boolean nonfinite(double d) {
        return Double.isInfinite(d) || Double.isNaN(d);
    }

    /**
     * Converts a shaped array into flattened boolean flags ({@code true} for nonzero).
     *
     * @throws IllegalArgumentException  If the element type is not supported.
     */
    private static boolean[] toBoolean(ShapedArray array) {
        switch (array.getType()) {
            case TYPE_BYTE:
                return toBoolean(((ByteArray) array).flatten());
            case TYPE_SHORT:
                return toBoolean(((ShortArray) array).flatten());
            case TYPE_INT:
                return toBoolean(((IntArray) array).flatten());
            case TYPE_LONG:
                return toBoolean(((LongArray) array).flatten());
            case TYPE_FLOAT:
                return toBoolean(((FloatArray) array).flatten());
            case TYPE_DOUBLE:
                return toBoolean(((DoubleArray) array).flatten());
            default:
                throw new IllegalArgumentException("Unsupported data type");
        }
    }

    /**
     * Converts a shaped vector into boolean flags ({@code true} for nonzero).
     *
     * @throws IllegalArgumentException  If the element type is not supported.
     */
    private static boolean[] toBoolean(ShapedVector vector) {
        switch (vector.getType()) {
            case TYPE_FLOAT:
                return toBoolean(((FloatShapedVector) vector).getData());
            case TYPE_DOUBLE:
                return toBoolean(((DoubleShapedVector) vector).getData());
            default:
                throw new IllegalArgumentException("Unsupported data type");
        }
    }

    /** Returns an array whose entries are {@code true} where {@code arr} is nonzero. */
    private static final boolean[] toBoolean(byte[] arr) {
        boolean[] flags = new boolean[arr.length];
        for (int i = 0; i < arr.length; i++) {
            flags[i] = arr[i] != 0;
        }
        return flags;
    }

    /** Returns an array whose entries are {@code true} where {@code arr} is nonzero. */
    private static final boolean[] toBoolean(short[] arr) {
        boolean[] flags = new boolean[arr.length];
        for (int i = 0; i < arr.length; i++) {
            flags[i] = arr[i] != 0;
        }
        return flags;
    }

    /** Returns an array whose entries are {@code true} where {@code arr} is nonzero. */
    private static final boolean[] toBoolean(int[] arr) {
        boolean[] flags = new boolean[arr.length];
        for (int i = 0; i < arr.length; i++) {
            flags[i] = arr[i] != 0;
        }
        return flags;
    }

    /** Returns an array whose entries are {@code true} where {@code arr} is nonzero. */
    private static final boolean[] toBoolean(long[] arr) {
        boolean[] flags = new boolean[arr.length];
        for (int i = 0; i < arr.length; i++) {
            flags[i] = arr[i] != 0;
        }
        return flags;
    }

    /**
     * Returns an array whose entries are {@code true} where {@code arr} is nonzero
     * (NaN counts as nonzero).
     */
    private static final boolean[] toBoolean(float[] arr) {
        boolean[] flags = new boolean[arr.length];
        for (int i = 0; i < arr.length; i++) {
            flags[i] = arr[i] != 0.0f;
        }
        return flags;
    }

    /**
     * Returns an array whose entries are {@code true} where {@code arr} is nonzero
     * (NaN counts as nonzero).
     */
    private static final boolean[] toBoolean(double[] arr) {
        boolean[] flags = new boolean[arr.length];
        for (int i = 0; i < arr.length; i++) {
            flags[i] = arr[i] != 0.0d;
        }
        return flags;
    }
}
