/*
 * Decompiled with CFR 0.152.
 */
package mitiv.cost;

import mitiv.array.ByteArray;
import mitiv.array.DoubleArray;
import mitiv.array.FloatArray;
import mitiv.array.IntArray;
import mitiv.array.LongArray;
import mitiv.array.ShapedArray;
import mitiv.array.ShortArray;
import mitiv.base.ArrayDescriptor;
import mitiv.base.Shape;
import mitiv.cost.DifferentiableCostFunction;
import mitiv.exception.NonConformableArrayException;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.DoubleShapedVectorSpace;
import mitiv.linalg.shaped.FloatShapedVector;
import mitiv.linalg.shaped.FloatShapedVectorSpace;
import mitiv.linalg.shaped.ShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;
import mitiv.utils.WeightFactory;

/**
 * Weighted data for a quadratic (weighted least-squares) cost function.
 *
 * <p>An instance stores a data vector and a vector of weights of the same
 * vector space, using either single or double precision floating-point
 * values.  As a {@link DifferentiableCostFunction}, it computes
 * {@code alpha/2 * sum_i w[i]*(x[i] - y[i])^2} where {@code y} is the data,
 * {@code w} the weights and {@code x} the argument.</p>
 *
 * <p>Data and weights can each be specified only once.  They are checked
 * lazily: the first accessor or cost evaluation following a change triggers
 * an internal update which validates the weights, replaces non-finite data
 * values by zero and counts the entries with a strictly positive weight.
 * Vectors which were not flagged as writable are cloned before being
 * modified.</p>
 */
public class WeightedData
implements DifferentiableCostFunction {
    /*
     * Element type identifiers, as returned by the getType() methods.  The
     * values are the ones this class dispatches on (0 is cast to ByteArray,
     * ..., 5 to DoubleArray).
     */
    private static final int TYPE_BYTE = 0;
    private static final int TYPE_SHORT = 1;
    private static final int TYPE_INT = 2;
    private static final int TYPE_LONG = 3;
    private static final int TYPE_FLOAT = 4;
    private static final int TYPE_DOUBLE = 5;

    /** Vector space of the data, of the weights and of the cost argument. */
    private final ShapedVectorSpace dataSpace;

    /** True for single precision values, false for double precision. */
    private final boolean single;

    /** The data, {@code null} until specified. */
    private ShapedVector data = null;

    /** The weights, {@code null} until specified or computed. */
    private ShapedVector weights = null;

    /** True when data/weights must be (re)validated before being used. */
    private boolean updatePending = true;

    /** True if {@link #data} may be modified in place (else clone first). */
    private boolean writableData = false;

    /** True if {@link #weights} may be modified in place (else clone first). */
    private boolean writableWeights = false;

    /** Number of entries with a strictly positive weight (after update). */
    private int validDataNumber = 0;

    /**
     * Creates an empty instance from an array descriptor.
     *
     * @param arrayDescriptor  provides the element type and the shape.
     * @throws IllegalArgumentException if the type is not floating-point.
     */
    public WeightedData(ArrayDescriptor arrayDescriptor) {
        this(arrayDescriptor.getType(), arrayDescriptor.getShape());
    }

    /**
     * Creates an empty instance operating on a given vector space.
     *
     * @param space  the vector space of the data.
     * @throws IllegalArgumentException if the type of the space is not
     *         floating-point.
     */
    public WeightedData(ShapedVectorSpace space) {
        switch (space.getType()) {
            case TYPE_FLOAT:
                this.single = true;
                break;
            case TYPE_DOUBLE:
                this.single = false;
                break;
            default:
                throw mustBeFloatingPoint();
        }
        this.dataSpace = space;
    }

    /**
     * Creates an empty instance for a given element type and shape.
     *
     * @param type   the element type identifier (float or double).
     * @param shape  the shape of the data.
     * @throws IllegalArgumentException if the type is not floating-point.
     */
    public WeightedData(int type, Shape shape) {
        switch (type) {
            case TYPE_FLOAT:
                this.single = true;
                this.dataSpace = new FloatShapedVectorSpace(shape);
                break;
            case TYPE_DOUBLE:
                this.single = false;
                this.dataSpace = new DoubleShapedVectorSpace(shape);
                break;
            default:
                throw mustBeFloatingPoint();
        }
    }

    /**
     * Creates an empty instance for a given element type and dimensions.
     *
     * @param type  the element type identifier (float or double).
     * @param dims  the dimensions, used to build the {@link Shape}.
     * @throws IllegalArgumentException if the type is not floating-point.
     */
    public WeightedData(int type, int ... dims) {
        this(type, new Shape(dims));
    }

    /**
     * Creates an instance holding the given data array.
     *
     * @param dataArray  the data; its type and shape define the data space.
     * @param writable   true if the contents may be modified in place.
     * @throws IllegalArgumentException if the array type is not
     *         floating-point.
     */
    public WeightedData(ShapedArray dataArray, boolean writable) {
        this(dataArray.getType(), dataArray.getShape());
        this.setData(dataArray, writable);
    }

    /**
     * Creates an instance holding the given (non-writable) data array.
     *
     * @param dataArray  the data; its type and shape define the data space.
     */
    public WeightedData(ShapedArray dataArray) {
        this(dataArray, false);
    }

    /**
     * Creates an instance holding the given data vector.
     *
     * @param dataVector  the data; its space becomes the data space.
     * @param writable    true if the contents may be modified in place.
     */
    public WeightedData(ShapedVector dataVector, boolean writable) {
        this(dataVector.getSpace());
        this.setData(dataVector, writable);
    }

    /**
     * Creates an instance holding the given (non-writable) data vector.
     *
     * @param dataVector  the data; its space becomes the data space.
     */
    public WeightedData(ShapedVector dataVector) {
        this(dataVector.getSpace());
        this.setData(dataVector);
    }

    /**
     * Creates an instance holding the given data and weights arrays.
     *
     * <p>The element type is the largest of the two array type identifiers
     * and the shape is that of the data.</p>
     *
     * @param dataArray     the data.
     * @param weightsArray  the weights.
     */
    public WeightedData(ShapedArray dataArray, ShapedArray weightsArray) {
        this(Math.max(dataArray.getType(), weightsArray.getType()), dataArray.getShape());
        this.setData(dataArray);
        this.setWeights(weightsArray);
    }

    /**
     * Creates an instance holding the given data and weights vectors.
     *
     * @param dataVector       the data; its space becomes the data space.
     * @param writableData     true if the data may be modified in place.
     * @param weightsVector    the weights.
     * @param writableWeights  true if the weights may be modified in place.
     */
    public WeightedData(ShapedVector dataVector, boolean writableData,
                        ShapedVector weightsVector, boolean writableWeights) {
        this(dataVector.getSpace());
        this.setData(dataVector, writableData);
        this.setWeights(weightsVector, writableWeights);
    }

    /**
     * Creates an instance holding the given (non-writable) data and weights
     * vectors.
     *
     * @param dataVector     the data; its space becomes the data space.
     * @param weightsVector  the weights.
     */
    public WeightedData(ShapedVector dataVector, ShapedVector weightsVector) {
        this(dataVector.getSpace());
        this.setData(dataVector);
        this.setWeights(weightsVector);
    }

    /**
     * Queries whether values are stored in single precision.
     *
     * @return true for {@code float} values, false for {@code double}.
     */
    public final boolean isSinglePrecision() {
        return this.single;
    }

    /**
     * Gets the vector space of the data.
     *
     * @return the data space.
     */
    public final ShapedVectorSpace getDataSpace() {
        return this.dataSpace;
    }

    /**
     * Gets the data, after applying any pending update.
     *
     * @return the data vector.
     * @throws IllegalArgumentException if no data has been specified or if
     *         the weights are invalid.
     */
    public final ShapedVector getData() {
        ensureUpdated();
        return this.data;
    }

    /**
     * Gets the weights, after applying any pending update (which creates
     * default weights if none were specified).
     *
     * @return the weights vector.
     * @throws IllegalArgumentException if no data has been specified or if
     *         the weights are invalid.
     */
    public final ShapedVector getWeights() {
        ensureUpdated();
        return this.weights;
    }

    /**
     * Gets the number of valid data, that is the number of entries whose
     * weight is strictly positive.
     *
     * @return the number of valid data.
     * @throws IllegalArgumentException if no data has been specified or if
     *         the weights are invalid.
     */
    public final int getValidDataNumber() {
        ensureUpdated();
        return this.validDataNumber;
    }

    /**
     * Computes the weighted mean of the data,
     * {@code sum(w[i]*y[i])/sum(w[i])}.
     *
     * @return the weighted mean, or NaN if there are no valid data.
     * @throws IllegalArgumentException if no data has been specified or if
     *         the weights are invalid.
     */
    public double getWeightedMean() {
        ensureUpdated();
        if (this.validDataNumber < 1) {
            return Double.NaN;
        }
        double sumWY = 0.0;
        double sumW = 0.0;
        if (this.single) {
            final float[] wgt = ((FloatShapedVector)this.weights).getData();
            final float[] dat = ((FloatShapedVector)this.data).getData();
            for (int i = 0; i < dat.length; ++i) {
                sumW += (double)wgt[i];
                // Product deliberately formed in single precision, as before.
                sumWY += (double)(wgt[i] * dat[i]);
            }
        } else {
            final double[] wgt = ((DoubleShapedVector)this.weights).getData();
            final double[] dat = ((DoubleShapedVector)this.data).getData();
            for (int i = 0; i < dat.length; ++i) {
                sumW += wgt[i];
                sumWY += wgt[i] * dat[i];
            }
        }
        return sumWY / sumW;
    }

    /**
     * Specifies the data from a (non-writable) shaped array.
     *
     * @param dataArray  the data.
     * @see #setData(ShapedArray, boolean)
     */
    public void setData(ShapedArray dataArray) {
        this.setData(dataArray, false);
    }

    /**
     * Specifies the data from a shaped array.
     *
     * <p>The array is converted to the floating-point type of this instance
     * if needed.  The result is considered writable whenever the array had
     * to be converted or is not flat (presumably because a private copy is
     * made in those cases).</p>
     *
     * @param dataArray  the data.
     * @param writable   true if the contents may be modified in place.
     * @throws NonConformableArrayException if the shape is not that of the
     *         data space.
     * @throws IllegalArgumentException if data has already been specified.
     */
    public void setData(ShapedArray dataArray, boolean writable) {
        if (!dataArray.getShape().equals(this.dataSpace.getShape())) {
            throw nonconformableData();
        }
        if (dataArray.getType() != this.dataSpace.getType()) {
            dataArray = this.single ? dataArray.toFloat() : dataArray.toDouble();
            writable = true;
        } else if (!dataArray.isFlat()) {
            writable = true;
        }
        this.setData(wrapAsVector(dataArray), writable);
    }

    /**
     * Specifies the data from a (non-writable) vector.
     *
     * @param dataVector  the data.
     * @see #setData(ShapedVector, boolean)
     */
    public void setData(ShapedVector dataVector) {
        this.setData(dataVector, false);
    }

    /**
     * Specifies the data from a vector of the data space.
     *
     * @param dataVector  the data; must belong to the data space.
     * @param writable    true if the contents may be modified in place;
     *                    otherwise the vector is cloned before any change.
     * @throws IllegalArgumentException if data has already been specified.
     */
    public void setData(ShapedVector dataVector, boolean writable) {
        dataVector.assertBelongsTo(this.dataSpace);
        if (this.data != null) {
            throw dataCanOnlyBeSpecifiedOnce();
        }
        this.data = dataVector;
        this.writableData = writable;
        this.updatePending = true;
    }

    /**
     * Specifies the weights from a (non-writable) shaped array.
     *
     * @param weightsArray  the weights.
     * @see #setWeights(ShapedArray, boolean)
     */
    public void setWeights(ShapedArray weightsArray) {
        this.setWeights(weightsArray, false);
    }

    /**
     * Specifies the weights from a shaped array.
     *
     * <p>The array is converted to the floating-point type of this instance
     * if needed.  The result is considered writable whenever the array had
     * to be converted or is not flat (presumably because a private copy is
     * made in those cases).</p>
     *
     * @param weightsArray  the weights.
     * @param writable      true if the contents may be modified in place.
     * @throws NonConformableArrayException if the shape is not that of the
     *         data space.
     * @throws IllegalArgumentException if weights have already been
     *         specified or computed.
     */
    public void setWeights(ShapedArray weightsArray, boolean writable) {
        if (!weightsArray.getShape().equals(this.dataSpace.getShape())) {
            throw nonconformableData();
        }
        if (weightsArray.getType() != this.dataSpace.getType()) {
            weightsArray = this.single ? weightsArray.toFloat() : weightsArray.toDouble();
            writable = true;
        } else if (!weightsArray.isFlat()) {
            writable = true;
        }
        this.setWeights(wrapAsVector(weightsArray), writable);
    }

    /**
     * Specifies the weights from a (non-writable) vector.
     *
     * @param weightsVector  the weights.
     * @see #setWeights(ShapedVector, boolean)
     */
    public void setWeights(ShapedVector weightsVector) {
        this.setWeights(weightsVector, false);
    }

    /**
     * Specifies the weights from a vector of the data space.
     *
     * @param weightsVector  the weights; must belong to the data space.
     * @param writable       true if the contents may be modified in place;
     *                       otherwise the vector is cloned before any change.
     * @throws IllegalArgumentException if weights have already been
     *         specified or computed.
     */
    public void setWeights(ShapedVector weightsVector, boolean writable) {
        weightsVector.assertBelongsTo(this.dataSpace);
        if (this.weights != null) {
            throw weightsCanOnlyBeSpecifiedOnce();
        }
        this.weights = weightsVector;
        this.writableWeights = writable;
        this.updatePending = true;
    }

    /**
     * Computes the weights from the data.
     *
     * <p>The actual computation is delegated to
     * {@code WeightFactory.computeWeightsFromData}; the two parameters are
     * forwarded unchanged (cast to {@code float} in single precision) and
     * their meaning is defined there.</p>
     *
     * @param alpha  first parameter forwarded to the weight factory.
     * @param beta   second parameter forwarded to the weight factory.
     * @throws IllegalArgumentException if no data has been specified or if
     *         weights have already been specified or computed.
     */
    public void computeWeightsFromData(double alpha, double beta) {
        if (this.data == null) {
            // The exception must be thrown, not merely created: otherwise
            // the call below fails with a NullPointerException.
            throw noDataHasBeenSpecified();
        }
        if (this.weights != null) {
            throw weightsCanOnlyBeSpecifiedOnce();
        }
        this.weights = this.data.create();
        this.writableWeights = true;
        if (this.single) {
            final float[] wgt = ((FloatShapedVector)this.weights).getData();
            final float[] dat = ((FloatShapedVector)this.data).getData();
            this.validDataNumber = WeightFactory.computeWeightsFromData(wgt, dat, (float)alpha, (float)beta);
        } else {
            final double[] wgt = ((DoubleShapedVector)this.weights).getData();
            final double[] dat = ((DoubleShapedVector)this.data).getData();
            this.validDataNumber = WeightFactory.computeWeightsFromData(wgt, dat, alpha, beta);
        }
        this.updatePending = true;
    }

    /**
     * Marks bad data according to a mask given as a vector.
     *
     * @param mask  the mask; non-zero entries indicate bad data.  Must have
     *              the same shape as the data and be of type float or
     *              double.
     * @throws IllegalArgumentException if the shape or the type of the mask
     *         is not supported, or if no data has been specified.
     */
    public void markBadData(ShapedVector mask) {
        if (!mask.getShape().equals(this.dataSpace.getShape())) {
            throw badMaskShape();
        }
        this.markBadData(toBoolean(mask));
    }

    /**
     * Marks bad data according to a mask given as a shaped array.
     *
     * @param mask  the mask; non-zero entries indicate bad data.  Must have
     *              the same shape as the data.
     * @throws IllegalArgumentException if the shape or the type of the mask
     *         is not supported, or if no data has been specified.
     */
    public void markBadData(ShapedArray mask) {
        if (!mask.getShape().equals(this.dataSpace.getShape())) {
            throw badMaskShape();
        }
        this.markBadData(toBoolean(mask));
    }

    /**
     * Marks bad data according to a boolean mask.
     *
     * <p>If there are no weights yet, new weights are created with 0 where
     * the mask is true and 1 elsewhere.  Otherwise, the weight of every
     * entry flagged as bad is set to zero; non-writable weights are cloned
     * just before the first actual change.</p>
     *
     * @param bad  the mask; true entries indicate bad data.
     * @throws IllegalArgumentException if no data has been specified or if
     *         the mask has not as many entries as the data.
     */
    public final void markBadData(boolean[] bad) {
        if (this.data == null) {
            throw noDataHasBeenSpecified();
        }
        final int number = this.data.getNumber();
        if (bad.length != number) {
            throw badMaskLength();
        }
        if (this.weights == null) {
            // No weights yet: build binary weights directly from the mask.
            this.weights = this.data.create();
            this.writableWeights = true;
            if (this.single) {
                final float[] wgt = ((FloatShapedVector)this.weights).getData();
                for (int i = 0; i < number; ++i) {
                    wgt[i] = bad[i] ? 0.0f : 1.0f;
                }
            } else {
                final double[] wgt = ((DoubleShapedVector)this.weights).getData();
                for (int i = 0; i < number; ++i) {
                    wgt[i] = bad[i] ? 0.0 : 1.0;
                }
            }
            this.updatePending = true;
            return;
        }
        if (this.single) {
            float[] wgt = ((FloatShapedVector)this.weights).getData();
            for (int i = 0; i < number; ++i) {
                if (bad[i] && wgt[i] != 0.0f) {
                    if (!this.writableWeights) {
                        // Copy on first write, then work on the copy.
                        this.cloneWeights();
                        wgt = ((FloatShapedVector)this.weights).getData();
                    }
                    wgt[i] = 0.0f;
                    this.updatePending = true;
                }
            }
        } else {
            double[] wgt = ((DoubleShapedVector)this.weights).getData();
            for (int i = 0; i < number; ++i) {
                if (bad[i] && wgt[i] != 0.0) {
                    if (!this.writableWeights) {
                        // Copy on first write, then work on the copy.
                        this.cloneWeights();
                        wgt = ((DoubleShapedVector)this.weights).getData();
                    }
                    wgt[i] = 0.0;
                    this.updatePending = true;
                }
            }
        }
    }

    /** Wraps a flattened array of the right type into a vector of the data space. */
    private ShapedVector wrapAsVector(ShapedArray array) {
        if (this.single) {
            return new FloatShapedVector((FloatShapedVectorSpace)this.dataSpace, ((FloatArray)array).flatten());
        }
        return new DoubleShapedVector((DoubleShapedVectorSpace)this.dataSpace, ((DoubleArray)array).flatten());
    }

    /** Replaces the data by a writable clone unless already writable. */
    private void cloneData() {
        if (!this.writableData) {
            this.data = this.data.clone();
            this.writableData = true;
        }
    }

    /** Replaces the weights by a writable clone unless already writable. */
    private void cloneWeights() {
        if (!this.writableWeights) {
            this.weights = this.weights.clone();
            this.writableWeights = true;
        }
    }

    /** Applies the pending update, if any. */
    private void ensureUpdated() {
        if (this.updatePending) {
            this.update();
        }
    }

    /**
     * Validates data and weights and counts the valid data.
     *
     * <p>If there are no weights, unit weights are created with zero weight
     * for non-finite data.  Otherwise the weights are checked to be finite
     * and non-negative and non-finite data are required to have a zero
     * weight.  In both cases non-finite data values are replaced by zero.</p>
     *
     * @throws IllegalArgumentException if no data has been specified or if
     *         the weights are invalid.
     */
    private void update() {
        if (this.data == null) {
            throw noDataHasBeenSpecified();
        }
        final int count;
        if (this.weights == null) {
            this.weights = this.dataSpace.create();
            this.writableWeights = true;
            count = this.single ? defaultWeightsFloat() : defaultWeightsDouble();
        } else {
            count = this.single ? checkWeightsFloat() : checkWeightsDouble();
        }
        this.updatePending = false;
        this.validDataNumber = count;
    }

    /** Fills single precision default weights, zeroing non-finite data; returns the number of valid data. */
    private int defaultWeightsFloat() {
        float[] dat = ((FloatShapedVector)this.data).getData();
        final float[] wgt = ((FloatShapedVector)this.weights).getData();
        int count = 0;
        for (int i = 0; i < dat.length; ++i) {
            if (nonfinite(dat[i])) {
                if (!this.writableData) {
                    this.cloneData();
                    dat = ((FloatShapedVector)this.data).getData();
                }
                dat[i] = 0.0f;
                wgt[i] = 0.0f;
            } else {
                wgt[i] = 1.0f;
                ++count;
            }
        }
        return count;
    }

    /** Fills double precision default weights, zeroing non-finite data; returns the number of valid data. */
    private int defaultWeightsDouble() {
        double[] dat = ((DoubleShapedVector)this.data).getData();
        final double[] wgt = ((DoubleShapedVector)this.weights).getData();
        int count = 0;
        for (int i = 0; i < dat.length; ++i) {
            if (nonfinite(dat[i])) {
                if (!this.writableData) {
                    this.cloneData();
                    dat = ((DoubleShapedVector)this.data).getData();
                }
                dat[i] = 0.0;
                wgt[i] = 0.0;
            } else {
                wgt[i] = 1.0;
                ++count;
            }
        }
        return count;
    }

    /** Checks single precision weights, zeroing non-finite data; returns the number of positive weights. */
    private int checkWeightsFloat() {
        float[] dat = ((FloatShapedVector)this.data).getData();
        final float[] wgt = ((FloatShapedVector)this.weights).getData();
        int count = 0;
        for (int i = 0; i < dat.length; ++i) {
            if (nonfinite(wgt[i]) || wgt[i] < 0.0f) {
                throw weightsMustBeFiniteAndNonnegative();
            }
            if (nonfinite(dat[i])) {
                if (wgt[i] > 0.0f) {
                    throw nonfiniteDataMustHaveZeroWeight();
                }
                if (!this.writableData) {
                    this.cloneData();
                    dat = ((FloatShapedVector)this.data).getData();
                }
                dat[i] = 0.0f;
            }
            if (wgt[i] > 0.0f) {
                ++count;
            }
        }
        return count;
    }

    /** Checks double precision weights, zeroing non-finite data; returns the number of positive weights. */
    private int checkWeightsDouble() {
        double[] dat = ((DoubleShapedVector)this.data).getData();
        final double[] wgt = ((DoubleShapedVector)this.weights).getData();
        int count = 0;
        for (int i = 0; i < dat.length; ++i) {
            if (nonfinite(wgt[i]) || wgt[i] < 0.0) {
                throw weightsMustBeFiniteAndNonnegative();
            }
            if (nonfinite(dat[i])) {
                if (wgt[i] > 0.0) {
                    throw nonfiniteDataMustHaveZeroWeight();
                }
                if (!this.writableData) {
                    this.cloneData();
                    dat = ((DoubleShapedVector)this.data).getData();
                }
                dat[i] = 0.0;
            }
            if (wgt[i] > 0.0) {
                ++count;
            }
        }
        return count;
    }

    /** Tells whether a single precision value is infinite or NaN. */
    private static boolean nonfinite(float value) {
        return Float.isInfinite(value) || Float.isNaN(value);
    }

    /** Tells whether a double precision value is infinite or NaN. */
    private static boolean nonfinite(double value) {
        return Double.isInfinite(value) || Double.isNaN(value);
    }

    /** Converts a shaped array of any supported type into a flat boolean mask (true where non-zero). */
    private static boolean[] toBoolean(ShapedArray array) {
        switch (array.getType()) {
            case TYPE_BYTE:
                return toBoolean(((ByteArray)array).flatten());
            case TYPE_SHORT:
                return toBoolean(((ShortArray)array).flatten());
            case TYPE_INT:
                return toBoolean(((IntArray)array).flatten());
            case TYPE_LONG:
                return toBoolean(((LongArray)array).flatten());
            case TYPE_FLOAT:
                return toBoolean(((FloatArray)array).flatten());
            case TYPE_DOUBLE:
                return toBoolean(((DoubleArray)array).flatten());
            default:
                throw unsupportedDataType();
        }
    }

    /** Converts a float or double vector into a boolean mask (true where non-zero). */
    private static boolean[] toBoolean(ShapedVector vector) {
        switch (vector.getType()) {
            case TYPE_FLOAT:
                return toBoolean(((FloatShapedVector)vector).getData());
            case TYPE_DOUBLE:
                return toBoolean(((DoubleShapedVector)vector).getData());
            default:
                throw unsupportedDataType();
        }
    }

    /** Builds a mask which is true where the input is non-zero. */
    private static boolean[] toBoolean(byte[] values) {
        final boolean[] mask = new boolean[values.length];
        for (int i = 0; i < values.length; ++i) {
            mask[i] = values[i] != 0;
        }
        return mask;
    }

    /** Builds a mask which is true where the input is non-zero. */
    private static boolean[] toBoolean(short[] values) {
        final boolean[] mask = new boolean[values.length];
        for (int i = 0; i < values.length; ++i) {
            mask[i] = values[i] != 0;
        }
        return mask;
    }

    /** Builds a mask which is true where the input is non-zero. */
    private static boolean[] toBoolean(int[] values) {
        final boolean[] mask = new boolean[values.length];
        for (int i = 0; i < values.length; ++i) {
            mask[i] = values[i] != 0;
        }
        return mask;
    }

    /** Builds a mask which is true where the input is non-zero. */
    private static boolean[] toBoolean(long[] values) {
        final boolean[] mask = new boolean[values.length];
        for (int i = 0; i < values.length; ++i) {
            mask[i] = values[i] != 0L;
        }
        return mask;
    }

    /** Builds a mask which is true where the input is non-zero (NaN counts as non-zero). */
    private static boolean[] toBoolean(float[] values) {
        final boolean[] mask = new boolean[values.length];
        for (int i = 0; i < values.length; ++i) {
            mask[i] = values[i] != 0.0f;
        }
        return mask;
    }

    /** Builds a mask which is true where the input is non-zero (NaN counts as non-zero). */
    private static boolean[] toBoolean(double[] values) {
        final boolean[] mask = new boolean[values.length];
        for (int i = 0; i < values.length; ++i) {
            mask[i] = values[i] != 0.0;
        }
        return mask;
    }

    /**
     * Gets the input space of the cost function.
     *
     * @return the data space.
     */
    @Override
    public VectorSpace getInputSpace() {
        return this.dataSpace;
    }

    /**
     * Evaluates the cost {@code alpha/2 * sum_i w[i]*(x[i] - y[i])^2}.
     *
     * @param alpha  the multiplier of the cost; if zero, the result is zero
     *               and neither data nor weights are accessed.
     * @param x      the variables; must belong to the data space.
     * @return the value of the cost.
     * @throws IllegalArgumentException if {@code alpha} is non-zero and no
     *         data has been specified or the weights are invalid.
     */
    @Override
    public double evaluate(double alpha, Vector x) {
        this.dataSpace.check(x);
        if (alpha == 0.0) {
            return 0.0;
        }
        // Make sure weights exist and data/weights have been validated.
        ensureUpdated();
        double sum = 0.0;
        if (this.single) {
            final float[] wgt = ((FloatShapedVector)this.weights).getData();
            final float[] arg = ((FloatShapedVector)x).getData();
            final float[] dat = ((FloatShapedVector)this.data).getData();
            for (int i = 0; i < dat.length; ++i) {
                final float r = arg[i] - dat[i];
                sum += (double)(r * wgt[i] * r);
            }
        } else {
            final double[] wgt = ((DoubleShapedVector)this.weights).getData();
            final double[] arg = ((DoubleShapedVector)x).getData();
            final double[] dat = ((DoubleShapedVector)this.data).getData();
            for (int i = 0; i < dat.length; ++i) {
                final double r = arg[i] - dat[i];
                sum += r * wgt[i] * r;
            }
        }
        return alpha / 2.0 * sum;
    }

    /**
     * Computes the cost and its gradient.
     *
     * <p>The cost is {@code alpha/2 * sum_i w[i]*(x[i] - y[i])^2} and the
     * i-th component of its gradient is {@code alpha*w[i]*(x[i] - y[i])}.</p>
     *
     * @param alpha  the multiplier of the cost; if zero, the result is zero
     *               and the gradient is only zero-filled when {@code clr}
     *               is true.
     * @param x      the variables; must belong to the data space.
     * @param gx     the gradient; must belong to the data space.
     * @param clr    if true, the gradient is overwritten; otherwise the
     *               gradient of this cost is added to the contents of
     *               {@code gx}.
     * @return the value of the cost.
     * @throws IllegalArgumentException if {@code alpha} is non-zero and no
     *         data has been specified or the weights are invalid.
     */
    @Override
    public double computeCostAndGradient(double alpha, Vector x, Vector gx, boolean clr) {
        this.dataSpace.check(x);
        this.dataSpace.check(gx);
        if (alpha == 0.0) {
            if (clr) {
                gx.zero();
            }
            return 0.0;
        }
        // Make sure weights exist and data/weights have been validated.
        ensureUpdated();
        double sum = 0.0;
        if (this.single) {
            final float scale = (float)alpha;
            final float[] grd = ((FloatShapedVector)gx).getData();
            final float[] wgt = ((FloatShapedVector)this.weights).getData();
            final float[] arg = ((FloatShapedVector)x).getData();
            final float[] dat = ((FloatShapedVector)this.data).getData();
            if (clr) {
                for (int i = 0; i < dat.length; ++i) {
                    final float r = arg[i] - dat[i];
                    final float wr = wgt[i] * r;
                    sum += (double)(r * wr);
                    grd[i] = scale * wr;
                }
            } else {
                for (int i = 0; i < dat.length; ++i) {
                    final float r = arg[i] - dat[i];
                    final float wr = wgt[i] * r;
                    sum += (double)(r * wr);
                    grd[i] += scale * wr;
                }
            }
        } else {
            final double scale = alpha;
            final double[] grd = ((DoubleShapedVector)gx).getData();
            final double[] wgt = ((DoubleShapedVector)this.weights).getData();
            final double[] arg = ((DoubleShapedVector)x).getData();
            final double[] dat = ((DoubleShapedVector)this.data).getData();
            if (clr) {
                for (int i = 0; i < dat.length; ++i) {
                    final double r = arg[i] - dat[i];
                    final double wr = wgt[i] * r;
                    sum += r * wr;
                    grd[i] = scale * wr;
                }
            } else {
                for (int i = 0; i < dat.length; ++i) {
                    final double r = arg[i] - dat[i];
                    final double wr = wgt[i] * r;
                    sum += r * wr;
                    grd[i] += scale * wr;
                }
            }
        }
        return alpha / 2.0 * sum;
    }

    /* Factories of the exceptions thrown by this class. */

    private static NonConformableArrayException nonconformableData() {
        return new NonConformableArrayException("Data array has non conformable dimensions");
    }

    private static IllegalArgumentException badMaskShape() {
        return new IllegalArgumentException("Mask of bad data must have the same shape as the data");
    }

    private static IllegalArgumentException badMaskLength() {
        return new IllegalArgumentException("Mask of bad data must have the same length as the data");
    }

    private static IllegalArgumentException noDataHasBeenSpecified() {
        return new IllegalArgumentException("Data has not yet been specified");
    }

    private static IllegalArgumentException dataCanOnlyBeSpecifiedOnce() {
        return new IllegalArgumentException("Data can only be specified once");
    }

    private static IllegalArgumentException weightsCanOnlyBeSpecifiedOnce() {
        return new IllegalArgumentException("Weights can only be specified or computed once");
    }

    private static IllegalArgumentException nonfiniteDataMustHaveZeroWeight() {
        return new IllegalArgumentException("Non-finite data must have zero weight");
    }

    private static IllegalArgumentException weightsMustBeFiniteAndNonnegative() {
        return new IllegalArgumentException("Weights must be finite and nonnegative");
    }

    private static IllegalArgumentException unsupportedDataType() {
        return new IllegalArgumentException("Unsupported data type");
    }

    private static IllegalArgumentException mustBeFloatingPoint() {
        return new IllegalArgumentException("Weighted data type must be 'float' or 'double'");
    }
}

