/*
 * Decompiled with CFR 0.152.
 */
package mitiv.cost;

import mitiv.array.ByteArray;
import mitiv.array.DoubleArray;
import mitiv.array.FloatArray;
import mitiv.array.IntArray;
import mitiv.array.LongArray;
import mitiv.array.ShapedArray;
import mitiv.array.ShortArray;
import mitiv.base.ArrayDescriptor;
import mitiv.base.Shape;
import mitiv.cost.DifferentiableCostFunction;
import mitiv.exception.NonConformableArrayException;
import mitiv.linalg.Vector;
import mitiv.linalg.VectorSpace;
import mitiv.linalg.shaped.DoubleShapedVector;
import mitiv.linalg.shaped.DoubleShapedVectorSpace;
import mitiv.linalg.shaped.FloatShapedVector;
import mitiv.linalg.shaped.FloatShapedVectorSpace;
import mitiv.linalg.shaped.ShapedVector;
import mitiv.linalg.shaped.ShapedVectorSpace;
import mitiv.utils.WeightFactory;

/**
 * Weighted data and the associated quadratic cost function.
 * <p>
 * An instance stores a data vector {@code y} and a vector of weights
 * {@code w}, both belonging to the same floating-point shaped vector space
 * (single or double precision), and implements the differentiable cost
 * function:
 * <pre>
 *     f(x) = (alpha/2) * sum_i w[i] * (x[i] - y[i])^2
 * </pre>
 * whose gradient is {@code alpha * w[i] * (x[i] - y[i])}.
 * <p>
 * The data and the weights can each be specified only once.  If no weights
 * are given, they default to 1 for finite data and 0 for non-finite data.
 * Before the data are used, non-finite data values are replaced by zero;
 * this is done on a private copy unless the data were declared writable (or
 * had to be copied anyway because of a type conversion or a non-flat
 * layout).  Explicit weights must be finite and nonnegative, and non-finite
 * data must have a zero weight.
 */
public class WeightedData
implements DifferentiableCostFunction {
    /* Element type codes (as returned by getType()) handled by this class. */
    private static final int TYPE_BYTE = 0;
    private static final int TYPE_SHORT = 1;
    private static final int TYPE_INT = 2;
    private static final int TYPE_LONG = 3;
    private static final int TYPE_FLOAT = 4;
    private static final int TYPE_DOUBLE = 5;

    /** Vector space of the data and of the weights. */
    private final ShapedVectorSpace dataSpace;
    /** True for single precision (float), false for double precision. */
    private final boolean single;
    /** The data, null until specified. */
    private ShapedVector data = null;
    /** The weights, null until specified, computed or defaulted. */
    private ShapedVector weights = null;
    /** True when data or weights changed and must be (re)checked. */
    private boolean updatePending = true;
    /** True when the data vector may be modified in place by this instance. */
    private boolean writableData = false;
    /** True when the weights vector may be modified in place by this instance. */
    private boolean writableWeights = false;
    /** Number of data with a strictly positive weight (valid after an update). */
    private int validDataNumber = 0;

    /**
     * Create an instance for data of given type and shape.
     *
     * @param descr  The descriptor providing the type and the shape of the data.
     * @throws IllegalArgumentException If the type is neither float nor double.
     */
    public WeightedData(ArrayDescriptor descr) {
        this(descr.getType(), descr.getShape());
    }

    /**
     * Create an instance for data belonging to a given vector space.
     *
     * @param space  The vector space of the data (and weights).
     * @throws IllegalArgumentException If the space is neither float nor double.
     */
    public WeightedData(ShapedVectorSpace space) {
        this.single = isSinglePrecisionType(space.getType());
        this.dataSpace = space;
    }

    /**
     * Create an instance for data of given type and shape.
     *
     * @param type   The type code of the data (float or double).
     * @param shape  The shape of the data.
     * @throws IllegalArgumentException If the type is neither float nor double.
     */
    public WeightedData(int type, Shape shape) {
        this.single = isSinglePrecisionType(type);
        if (this.single) {
            this.dataSpace = new FloatShapedVectorSpace(shape);
        } else {
            this.dataSpace = new DoubleShapedVectorSpace(shape);
        }
    }

    /**
     * Create an instance for data of given type and dimensions.
     *
     * @param type  The type code of the data (float or double).
     * @param dims  The dimensions of the data.
     * @throws IllegalArgumentException If the type is neither float nor double.
     */
    public WeightedData(int type, int ... dims) {
        this(type, new Shape(dims));
    }

    /**
     * Create an instance with given data.
     *
     * @param data      The data array (its type and shape define the space).
     * @param writable  Whether the array contents may be modified in place.
     */
    public WeightedData(ShapedArray data, boolean writable) {
        this(data.getType(), data.getShape());
        this.setData(data, writable);
    }

    /**
     * Create an instance with given, read-only, data.
     *
     * @param data  The data array (its type and shape define the space).
     */
    public WeightedData(ShapedArray data) {
        this(data, false);
    }

    /**
     * Create an instance with given data.
     *
     * @param data      The data vector (its space defines the data space).
     * @param writable  Whether the vector contents may be modified in place.
     */
    public WeightedData(ShapedVector data, boolean writable) {
        this(data.getSpace());
        this.setData(data, writable);
    }

    /**
     * Create an instance with given, read-only, data.
     *
     * @param data  The data vector (its space defines the data space).
     */
    public WeightedData(ShapedVector data) {
        this(data.getSpace());
        this.setData(data);
    }

    /**
     * Create an instance with given, read-only, data and weights.
     * <p>
     * The element type is the larger of the two type codes, which must be
     * float or double.
     *
     * @param data     The data array.
     * @param weights  The weights array (same shape as the data).
     */
    public WeightedData(ShapedArray data, ShapedArray weights) {
        this(Math.max(data.getType(), weights.getType()), data.getShape());
        this.setData(data);
        this.setWeights(weights);
    }

    /**
     * Create an instance with given data and weights.
     *
     * @param data             The data vector (its space defines the data space).
     * @param writableData     Whether the data may be modified in place.
     * @param weights          The weights vector.
     * @param writableWeights  Whether the weights may be modified in place.
     */
    public WeightedData(ShapedVector data, boolean writableData, ShapedVector weights, boolean writableWeights) {
        this(data.getSpace());
        this.setData(data, writableData);
        this.setWeights(weights, writableWeights);
    }

    /**
     * Create an instance with given, read-only, data and weights.
     *
     * @param data     The data vector (its space defines the data space).
     * @param weights  The weights vector.
     */
    public WeightedData(ShapedVector data, ShapedVector weights) {
        this(data.getSpace());
        this.setData(data);
        this.setWeights(weights);
    }

    /**
     * Map a type code to the precision flag.
     *
     * @return True for float, false for double.
     * @throws IllegalArgumentException For any other type.
     */
    private static boolean isSinglePrecisionType(int type) {
        switch (type) {
            case TYPE_FLOAT:
                return true;
            case TYPE_DOUBLE:
                return false;
            default:
                throw mustBeFloatingPoint();
        }
    }

    /** @return Whether data and weights are stored in single precision. */
    public final boolean isSinglePrecision() {
        return single;
    }

    /** @return The vector space of the data and weights. */
    public final ShapedVectorSpace getDataSpace() {
        return dataSpace;
    }

    /**
     * Get the data, after non-finite values have been replaced by zero.
     *
     * @return The data vector.
     * @throws IllegalArgumentException If no data have been specified, or if
     *         the weights are invalid.
     */
    public final ShapedVector getData() {
        updateIfNeeded();
        return data;
    }

    /**
     * Get the weights, creating default ones if none were specified.
     *
     * @return The weights vector.
     * @throws IllegalArgumentException If no data have been specified, or if
     *         the weights are invalid.
     */
    public final ShapedVector getWeights() {
        updateIfNeeded();
        return weights;
    }

    /**
     * Get the number of data with a strictly positive weight.
     *
     * @return The number of valid data.
     * @throws IllegalArgumentException If no data have been specified, or if
     *         the weights are invalid.
     */
    public final int getValidDataNumber() {
        updateIfNeeded();
        return validDataNumber;
    }

    /**
     * Compute the weighted mean of the data: {@code sum(w*y)/sum(w)}.
     *
     * @return The weighted mean, or NaN if there are no valid data.
     */
    public double getWeightedMean() {
        updateIfNeeded();
        if (validDataNumber < 1) {
            return Double.NaN;
        }
        double sw = 0.0;
        double swy = 0.0;
        if (single) {
            float[] w = ((FloatShapedVector)weights).getData();
            float[] y = ((FloatShapedVector)data).getData();
            for (int i = 0; i < y.length; ++i) {
                // Widen before multiplying so the product is not rounded to float.
                sw += w[i];
                swy += (double)w[i] * (double)y[i];
            }
        } else {
            double[] w = ((DoubleShapedVector)weights).getData();
            double[] y = ((DoubleShapedVector)data).getData();
            for (int i = 0; i < y.length; ++i) {
                sw += w[i];
                swy += w[i] * y[i];
            }
        }
        // sw > 0 here since at least one weight is strictly positive.
        return swy / sw;
    }

    /**
     * Specify read-only data.
     *
     * @param arr  The data array, with the same shape as the data space.
     */
    public void setData(ShapedArray arr) {
        this.setData(arr, false);
    }

    /**
     * Specify the data.
     *
     * @param arr       The data array, with the same shape as the data space.
     * @param writable  Whether the array contents may be modified in place.
     *                  Ignored (treated as true) when the array has to be
     *                  converted or is not flat, as a copy is then made.
     * @throws NonConformableArrayException If the shape does not match.
     * @throws IllegalArgumentException If data have already been specified.
     */
    public void setData(ShapedArray arr, boolean writable) {
        if (!arr.getShape().equals(dataSpace.getShape())) {
            throw nonconformableData();
        }
        boolean owned = writable || isCopiedByToVector(arr);
        this.setData(toVector(arr), owned);
    }

    /**
     * Specify read-only data.
     *
     * @param data  The data vector, which must belong to the data space.
     */
    public void setData(ShapedVector data) {
        this.setData(data, false);
    }

    /**
     * Specify the data.
     *
     * @param data      The data vector, which must belong to the data space.
     * @param writable  Whether the vector contents may be modified in place.
     * @throws IllegalArgumentException If data have already been specified.
     */
    public void setData(ShapedVector data, boolean writable) {
        data.assertBelongsTo(dataSpace);
        if (this.data != null) {
            throw dataCanOnlyBeSpecifiedOnce();
        }
        this.data = data;
        this.writableData = writable;
        this.updatePending = true;
    }

    /**
     * Specify read-only weights.
     *
     * @param arr  The weights array, with the same shape as the data space.
     */
    public void setWeights(ShapedArray arr) {
        this.setWeights(arr, false);
    }

    /**
     * Specify the weights.
     *
     * @param arr       The weights array, with the same shape as the data space.
     * @param writable  Whether the array contents may be modified in place.
     *                  Ignored (treated as true) when the array has to be
     *                  converted or is not flat, as a copy is then made.
     * @throws NonConformableArrayException If the shape does not match.
     * @throws IllegalArgumentException If weights have already been specified.
     */
    public void setWeights(ShapedArray arr, boolean writable) {
        if (!arr.getShape().equals(dataSpace.getShape())) {
            throw nonconformableData();
        }
        boolean owned = writable || isCopiedByToVector(arr);
        this.setWeights(toVector(arr), owned);
    }

    /**
     * Specify read-only weights.
     *
     * @param weights  The weights vector, which must belong to the data space.
     */
    public void setWeights(ShapedVector weights) {
        this.setWeights(weights, false);
    }

    /**
     * Specify the weights.
     *
     * @param weights   The weights vector, which must belong to the data space.
     * @param writable  Whether the vector contents may be modified in place.
     * @throws IllegalArgumentException If weights have already been specified
     *         or computed.
     */
    public void setWeights(ShapedVector weights, boolean writable) {
        weights.assertBelongsTo(dataSpace);
        if (this.weights != null) {
            throw weightsCanOnlyBeSpecifiedOnce();
        }
        this.weights = weights;
        this.writableWeights = writable;
        this.updatePending = true;
    }

    /**
     * Compute the weights from the data.
     * <p>
     * The computation is delegated to
     * {@code WeightFactory.computeWeightsFromData}. NOTE(review): alpha and
     * beta are presumably noise-model parameters; see WeightFactory.
     *
     * @param alpha  First parameter passed to WeightFactory.
     * @param beta   Second parameter passed to WeightFactory.
     * @throws IllegalArgumentException If no data have been specified or if
     *         weights have already been specified or computed.
     */
    public void computeWeightsFromData(double alpha, double beta) {
        if (data == null) {
            // The original code built this exception without throwing it.
            throw noDataHasBeenSpecified();
        }
        if (weights != null) {
            throw weightsCanOnlyBeSpecifiedOnce();
        }
        weights = data.create();
        writableWeights = true;
        if (single) {
            float[] wgt = ((FloatShapedVector)weights).getData();
            float[] dat = ((FloatShapedVector)data).getData();
            validDataNumber = WeightFactory.computeWeightsFromData(wgt, dat, (float)alpha, (float)beta);
        } else {
            double[] wgt = ((DoubleShapedVector)weights).getData();
            double[] dat = ((DoubleShapedVector)data).getData();
            validDataNumber = WeightFactory.computeWeightsFromData(wgt, dat, alpha, beta);
        }
        // The weights still have to be validated against the data.
        updatePending = true;
    }

    /**
     * Mark some data as bad (their weight is set to zero).
     *
     * @param bad  A vector with the same shape as the data; nonzero entries
     *             denote bad data.
     * @throws IllegalArgumentException If the shape does not match.
     */
    public void markBadData(ShapedVector bad) {
        if (!bad.getShape().equals(dataSpace.getShape())) {
            throw badMaskShape();
        }
        this.markBadData(toBoolean(bad));
    }

    /**
     * Mark some data as bad (their weight is set to zero).
     *
     * @param bad  An array with the same shape as the data; nonzero entries
     *             denote bad data.
     * @throws IllegalArgumentException If the shape does not match.
     */
    public void markBadData(ShapedArray bad) {
        if (!bad.getShape().equals(dataSpace.getShape())) {
            throw badMaskShape();
        }
        this.markBadData(toBoolean(bad));
    }

    /**
     * Mark some data as bad (their weight is set to zero).
     * <p>
     * If no weights exist yet, default weights are created: 0 for bad or
     * non-finite data and 1 otherwise (consistent with the default weights
     * created when the data are first used).  Otherwise the existing weights
     * of bad data are zeroed, on a private copy if they are not writable.
     *
     * @param bad  Flags (in flat order) of the bad data.
     * @throws IllegalArgumentException If no data have been specified or if
     *         the mask length does not match the number of data.
     */
    public final void markBadData(boolean[] bad) {
        if (data == null) {
            throw noDataHasBeenSpecified();
        }
        int len = data.getNumber();
        if (bad.length != len) {
            throw badMaskLength();
        }
        if (weights == null) {
            weights = data.create();
            writableWeights = true;
            if (single) {
                float[] dat = ((FloatShapedVector)data).getData();
                float[] wgt = ((FloatShapedVector)weights).getData();
                for (int i = 0; i < len; ++i) {
                    // Non-finite data must have zero weight, see update().
                    wgt[i] = (bad[i] || nonfinite(dat[i])) ? 0.0f : 1.0f;
                }
            } else {
                double[] dat = ((DoubleShapedVector)data).getData();
                double[] wgt = ((DoubleShapedVector)weights).getData();
                for (int i = 0; i < len; ++i) {
                    wgt[i] = (bad[i] || nonfinite(dat[i])) ? 0.0 : 1.0;
                }
            }
            updatePending = true;
        } else if (single) {
            float[] wgt = ((FloatShapedVector)weights).getData();
            for (int i = 0; i < len; ++i) {
                if (bad[i] && wgt[i] != 0.0f) {
                    if (!writableWeights) {
                        // Copy on first modification only.
                        cloneWeights();
                        wgt = ((FloatShapedVector)weights).getData();
                    }
                    wgt[i] = 0.0f;
                    updatePending = true;
                }
            }
        } else {
            double[] wgt = ((DoubleShapedVector)weights).getData();
            for (int i = 0; i < len; ++i) {
                if (bad[i] && wgt[i] != 0.0) {
                    if (!writableWeights) {
                        cloneWeights();
                        wgt = ((DoubleShapedVector)weights).getData();
                    }
                    wgt[i] = 0.0;
                    updatePending = true;
                }
            }
        }
    }

    /**
     * Tell whether {@link #toVector} makes a copy of an array (type
     * conversion or non-flat layout). NOTE(review): assumes flatten() does not
     * copy a flat array — confirm against ShapedArray.
     */
    private boolean isCopiedByToVector(ShapedArray arr) {
        return arr.getType() != dataSpace.getType() || !arr.isFlat();
    }

    /**
     * Convert an array (whose shape has been checked) into a vector of the
     * data space, converting its element type if needed.
     */
    private ShapedVector toVector(ShapedArray arr) {
        if (arr.getType() != dataSpace.getType()) {
            arr = (single ? arr.toFloat() : arr.toDouble());
        }
        if (single) {
            return new FloatShapedVector((FloatShapedVectorSpace)dataSpace, ((FloatArray)arr).flatten());
        } else {
            return new DoubleShapedVector((DoubleShapedVectorSpace)dataSpace, ((DoubleArray)arr).flatten());
        }
    }

    /** Replace the data by a private copy unless already writable. */
    private void cloneData() {
        if (!writableData) {
            data = data.clone();
            writableData = true;
        }
    }

    /** Replace the weights by a private copy unless already writable. */
    private void cloneWeights() {
        if (!writableWeights) {
            weights = weights.clone();
            writableWeights = true;
        }
    }

    /** Run {@link #update()} if data or weights changed since the last one. */
    private void updateIfNeeded() {
        if (updatePending) {
            update();
        }
    }

    /**
     * Validate data and weights, create default weights if needed, replace
     * non-finite data by zero and count the valid data.
     *
     * @throws IllegalArgumentException If no data have been specified, if a
     *         weight is non-finite or negative, or if a non-finite datum has
     *         a nonzero weight.
     */
    private void update() {
        if (data == null) {
            throw noDataHasBeenSpecified();
        }
        boolean defaultWeights = (weights == null);
        if (defaultWeights) {
            weights = dataSpace.create();
            writableWeights = true;
        }
        int count = (single ? updateSingle(defaultWeights) : updateDouble(defaultWeights));
        validDataNumber = count;
        updatePending = false;
    }

    /**
     * Single precision part of {@link #update()}.
     *
     * @param defaultWeights  Whether the weights must be filled with defaults
     *                        (otherwise they are validated).
     * @return The number of data with a strictly positive weight.
     */
    private int updateSingle(boolean defaultWeights) {
        float[] dat = ((FloatShapedVector)data).getData();
        float[] wgt = ((FloatShapedVector)weights).getData();
        int count = 0;
        for (int i = 0; i < dat.length; ++i) {
            boolean badValue = nonfinite(dat[i]);
            if (defaultWeights) {
                wgt[i] = (badValue ? 0.0f : 1.0f);
            } else {
                if (nonfinite(wgt[i]) || wgt[i] < 0.0f) {
                    throw weightsMustBeFiniteAndNonnegative();
                }
                if (badValue && wgt[i] > 0.0f) {
                    throw nonfiniteDataMustHaveZeroWeight();
                }
            }
            if (badValue) {
                if (!writableData) {
                    // Never modify caller-owned data: copy on first change.
                    cloneData();
                    dat = ((FloatShapedVector)data).getData();
                }
                dat[i] = 0.0f;
            }
            if (wgt[i] > 0.0f) {
                ++count;
            }
        }
        return count;
    }

    /**
     * Double precision part of {@link #update()}.
     *
     * @param defaultWeights  Whether the weights must be filled with defaults
     *                        (otherwise they are validated).
     * @return The number of data with a strictly positive weight.
     */
    private int updateDouble(boolean defaultWeights) {
        double[] dat = ((DoubleShapedVector)data).getData();
        double[] wgt = ((DoubleShapedVector)weights).getData();
        int count = 0;
        for (int i = 0; i < dat.length; ++i) {
            boolean badValue = nonfinite(dat[i]);
            if (defaultWeights) {
                wgt[i] = (badValue ? 0.0 : 1.0);
            } else {
                if (nonfinite(wgt[i]) || wgt[i] < 0.0) {
                    throw weightsMustBeFiniteAndNonnegative();
                }
                if (badValue && wgt[i] > 0.0) {
                    throw nonfiniteDataMustHaveZeroWeight();
                }
            }
            if (badValue) {
                if (!writableData) {
                    // Never modify caller-owned data: copy on first change.
                    cloneData();
                    dat = ((DoubleShapedVector)data).getData();
                }
                dat[i] = 0.0;
            }
            if (wgt[i] > 0.0) {
                ++count;
            }
        }
        return count;
    }

    /** @return Whether a float value is infinite or NaN. */
    private static boolean nonfinite(float val) {
        return Float.isInfinite(val) || Float.isNaN(val);
    }

    /** @return Whether a double value is infinite or NaN. */
    private static boolean nonfinite(double val) {
        return Double.isInfinite(val) || Double.isNaN(val);
    }

    /**
     * Convert an array into flags (in flat order): true where nonzero.
     *
     * @throws IllegalArgumentException For an unsupported element type.
     */
    private static boolean[] toBoolean(ShapedArray arr) {
        switch (arr.getType()) {
            case TYPE_BYTE:
                return toBoolean(((ByteArray)arr).flatten());
            case TYPE_SHORT:
                return toBoolean(((ShortArray)arr).flatten());
            case TYPE_INT:
                return toBoolean(((IntArray)arr).flatten());
            case TYPE_LONG:
                return toBoolean(((LongArray)arr).flatten());
            case TYPE_FLOAT:
                return toBoolean(((FloatArray)arr).flatten());
            case TYPE_DOUBLE:
                return toBoolean(((DoubleArray)arr).flatten());
            default:
                throw unsupportedDataType();
        }
    }

    /**
     * Convert a vector into flags: true where nonzero.
     *
     * @throws IllegalArgumentException For an unsupported element type.
     */
    private static boolean[] toBoolean(ShapedVector vec) {
        switch (vec.getType()) {
            case TYPE_FLOAT:
                return toBoolean(((FloatShapedVector)vec).getData());
            case TYPE_DOUBLE:
                return toBoolean(((DoubleShapedVector)vec).getData());
            default:
                throw unsupportedDataType();
        }
    }

    private static boolean[] toBoolean(byte[] arr) {
        boolean[] res = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            res[i] = (arr[i] != 0);
        }
        return res;
    }

    private static boolean[] toBoolean(short[] arr) {
        boolean[] res = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            res[i] = (arr[i] != 0);
        }
        return res;
    }

    private static boolean[] toBoolean(int[] arr) {
        boolean[] res = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            res[i] = (arr[i] != 0);
        }
        return res;
    }

    private static boolean[] toBoolean(long[] arr) {
        boolean[] res = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            res[i] = (arr[i] != 0L);
        }
        return res;
    }

    private static boolean[] toBoolean(float[] arr) {
        boolean[] res = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            res[i] = (arr[i] != 0.0f);
        }
        return res;
    }

    private static boolean[] toBoolean(double[] arr) {
        boolean[] res = new boolean[arr.length];
        for (int i = 0; i < arr.length; ++i) {
            res[i] = (arr[i] != 0.0);
        }
        return res;
    }

    @Override
    public VectorSpace getInputSpace() {
        return dataSpace;
    }

    /**
     * Evaluate {@code (alpha/2) * sum_i w[i]*(x[i] - y[i])^2}.
     *
     * @param alpha  The multiplier of the cost.
     * @param vx     The variables, which must belong to the data space.
     * @return The cost (0 when alpha is 0).
     */
    @Override
    public double evaluate(double alpha, Vector vx) {
        dataSpace.check(vx);
        if (alpha == 0.0) {
            return 0.0;
        }
        // Make sure weights exist and non-finite data have been zeroed.
        updateIfNeeded();
        double sum = 0.0;
        if (single) {
            float[] w = ((FloatShapedVector)weights).getData();
            float[] x = ((FloatShapedVector)vx).getData();
            float[] y = ((FloatShapedVector)data).getData();
            for (int i = 0; i < y.length; ++i) {
                float r = x[i] - y[i];
                sum += (double)(r * w[i] * r);
            }
        } else {
            double[] w = ((DoubleShapedVector)weights).getData();
            double[] x = ((DoubleShapedVector)vx).getData();
            double[] y = ((DoubleShapedVector)data).getData();
            for (int i = 0; i < y.length; ++i) {
                double r = x[i] - y[i];
                sum += r * w[i] * r;
            }
        }
        return alpha / 2.0 * sum;
    }

    /**
     * Evaluate the cost and its gradient {@code alpha * w[i]*(x[i] - y[i])}.
     *
     * @param alpha  The multiplier of the cost.
     * @param vx     The variables, which must belong to the data space.
     * @param vg     The gradient, which must belong to the data space.
     * @param clr    If true, the gradient is overwritten; otherwise the
     *               gradient of this cost is added to {@code vg}.
     * @return The cost (0 when alpha is 0, in which case {@code vg} is only
     *         zeroed if {@code clr} is true).
     */
    @Override
    public double computeCostAndGradient(double alpha, Vector vx, Vector vg, boolean clr) {
        dataSpace.check(vx);
        dataSpace.check(vg);
        if (alpha == 0.0) {
            if (clr) {
                vg.zero();
            }
            return 0.0;
        }
        // Make sure weights exist and non-finite data have been zeroed.
        updateIfNeeded();
        double sum = 0.0;
        if (single) {
            float q = (float)alpha;
            float[] g = ((FloatShapedVector)vg).getData();
            float[] w = ((FloatShapedVector)weights).getData();
            float[] x = ((FloatShapedVector)vx).getData();
            float[] y = ((FloatShapedVector)data).getData();
            if (clr) {
                for (int i = 0; i < y.length; ++i) {
                    float r = x[i] - y[i];
                    float wr = w[i] * r;
                    sum += (double)(r * wr);
                    g[i] = q * wr;
                }
            } else {
                for (int i = 0; i < y.length; ++i) {
                    float r = x[i] - y[i];
                    float wr = w[i] * r;
                    sum += (double)(r * wr);
                    g[i] += q * wr;
                }
            }
        } else {
            double q = alpha;
            double[] g = ((DoubleShapedVector)vg).getData();
            double[] w = ((DoubleShapedVector)weights).getData();
            double[] x = ((DoubleShapedVector)vx).getData();
            double[] y = ((DoubleShapedVector)data).getData();
            if (clr) {
                for (int i = 0; i < y.length; ++i) {
                    double r = x[i] - y[i];
                    double wr = w[i] * r;
                    sum += r * wr;
                    g[i] = q * wr;
                }
            } else {
                for (int i = 0; i < y.length; ++i) {
                    double r = x[i] - y[i];
                    double wr = w[i] * r;
                    sum += r * wr;
                    g[i] += q * wr;
                }
            }
        }
        return alpha / 2.0 * sum;
    }

    private static NonConformableArrayException nonconformableData() {
        return new NonConformableArrayException("Data array has non conformable dimensions");
    }

    private static IllegalArgumentException badMaskShape() {
        return new IllegalArgumentException("Mask of bad data must have the same shape as the data");
    }

    private static IllegalArgumentException badMaskLength() {
        return new IllegalArgumentException("Mask of bad data must have the same length as the data");
    }

    private static IllegalArgumentException noDataHasBeenSpecified() {
        return new IllegalArgumentException("Data has not yet been specified");
    }

    private static IllegalArgumentException dataCanOnlyBeSpecifiedOnce() {
        return new IllegalArgumentException("Data can only be specified once");
    }

    private static IllegalArgumentException weightsCanOnlyBeSpecifiedOnce() {
        return new IllegalArgumentException("Weights can only be specified or computed once");
    }

    private static IllegalArgumentException nonfiniteDataMustHaveZeroWeight() {
        return new IllegalArgumentException("Non-finite data must have zero weight");
    }

    private static IllegalArgumentException weightsMustBeFiniteAndNonnegative() {
        return new IllegalArgumentException("Weights must be finite and nonnegative");
    }

    private static IllegalArgumentException unsupportedDataType() {
        return new IllegalArgumentException("Unsupported data type");
    }

    private static IllegalArgumentException mustBeFloatingPoint() {
        return new IllegalArgumentException("Weighted data type must be 'float' or 'double'");
    }
}

