/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.evaluation.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Triple;
import org.nd4j.evaluation.BaseEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.evaluation.curves.Histogram;
import org.nd4j.evaluation.curves.ReliabilityDiagram;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.serde.jackson.shaded.NDArrayDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArraySerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

/**
 * Evaluation of how well a classifier's predicted probabilities are calibrated.
 * <p>
 * Each call to {@link #eval(INDArray, INDArray, INDArray)} accumulates:
 * <ul>
 *   <li>reliability-diagram statistics: per probability bin and class, the number of predictions falling in the bin,
 *       the sum of labels for those predictions (the positive count for 0/1 labels), and the sum of the predicted
 *       probabilities;</li>
 *   <li>the per-class sum of labels and the per-class number of argmax predictions;</li>
 *   <li>histograms of the residuals {@code |label - prediction|}, overall and weighted by label class;</li>
 *   <li>histograms of the predicted probabilities, overall and weighted by label class.</li>
 * </ul>
 * All bins are equal-width over [0, 1]. Each bin is half-open {@code [lower, upper)} except the last, which is closed
 * so that a value of exactly 1.0 is counted.
 */
public class EvaluationCalibration
extends BaseEvaluation<EvaluationCalibration> {
    public static final int DEFAULT_RELIABILITY_DIAG_NUM_BINS = 10;
    public static final int DEFAULT_HISTOGRAM_NUM_BINS = 50;
    private final int reliabilityDiagNumBins;
    private final int histogramNumBins;
    private final boolean excludeEmptyBins;
    protected int axis = 1;
    @JsonSerialize(using=NDArraySerializer.class)
    @JsonDeserialize(using=NDArrayDeSerializer.class)
    private INDArray rDiagBinPosCount;
    @JsonSerialize(using=NDArraySerializer.class)
    @JsonDeserialize(using=NDArrayDeSerializer.class)
    private INDArray rDiagBinTotalCount;
    @JsonSerialize(using=NDArraySerializer.class)
    @JsonDeserialize(using=NDArrayDeSerializer.class)
    private INDArray rDiagBinSumPredictions;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray labelCountsEachClass;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray predictionCountsEachClass;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray residualPlotOverall;
    @JsonSerialize(using=NDArraySerializer.class)
    @JsonDeserialize(using=NDArrayDeSerializer.class)
    private INDArray residualPlotByLabelClass;
    @JsonSerialize(using=NDArrayTextSerializer.class)
    @JsonDeserialize(using=NDArrayTextDeSerializer.class)
    private INDArray probHistogramOverall;
    @JsonSerialize(using=NDArraySerializer.class)
    @JsonDeserialize(using=NDArrayDeSerializer.class)
    private INDArray probHistogramByLabelClass;

    /**
     * Creates an empty evaluation with an explicit axis; used by {@link #newInstance()} to copy configuration.
     */
    protected EvaluationCalibration(int axis, int reliabilityDiagNumBins, int histogramNumBins, boolean excludeEmptyBins) {
        this.axis = axis;
        this.reliabilityDiagNumBins = reliabilityDiagNumBins;
        this.histogramNumBins = histogramNumBins;
        this.excludeEmptyBins = excludeEmptyBins;
    }

    /**
     * Creates an empty evaluation with {@value #DEFAULT_RELIABILITY_DIAG_NUM_BINS} reliability-diagram bins,
     * {@value #DEFAULT_HISTOGRAM_NUM_BINS} histogram bins, and empty reliability bins excluded.
     */
    public EvaluationCalibration() {
        this(DEFAULT_RELIABILITY_DIAG_NUM_BINS, DEFAULT_HISTOGRAM_NUM_BINS, true);
    }

    /**
     * Creates an empty evaluation with the given bin counts; empty reliability bins are excluded.
     */
    public EvaluationCalibration(int reliabilityDiagNumBins, int histogramNumBins) {
        this(reliabilityDiagNumBins, histogramNumBins, true);
    }

    /**
     * Creates an empty evaluation.
     *
     * @param reliabilityDiagNumBins number of equal-width bins over [0, 1] for the reliability diagram
     * @param histogramNumBins       number of equal-width bins over [0, 1] for residual/probability histograms
     * @param excludeEmptyBins       if true, {@link #getReliabilityDiagram(int)} omits bins with no predictions
     */
    public EvaluationCalibration(@JsonProperty(value="reliabilityDiagNumBins") int reliabilityDiagNumBins, @JsonProperty(value="histogramNumBins") int histogramNumBins, @JsonProperty(value="excludeEmptyBins") boolean excludeEmptyBins) {
        this.reliabilityDiagNumBins = reliabilityDiagNumBins;
        this.histogramNumBins = histogramNumBins;
        this.excludeEmptyBins = excludeEmptyBins;
    }

    /** Sets the axis passed to {@code reshapeAndExtractNotMasked} when evaluating. */
    public void setAxis(int axis) {
        this.axis = axis;
    }

    /** Returns the axis passed to {@code reshapeAndExtractNotMasked} when evaluating. */
    public int getAxis() {
        return this.axis;
    }

    /**
     * Accumulates calibration statistics for one batch.
     *
     * @param labels      label array
     * @param predictions predicted probabilities, same shape as {@code labels}
     * @param mask        optional mask; only masks that {@code reshapeAndExtractNotMasked} fully resolves are supported
     * @throws IllegalStateException if a per-output mask remains after extraction
     */
    @Override
    public void eval(INDArray labels, INDArray predictions, INDArray mask) {
        Triple<INDArray, INDArray, INDArray> triple = BaseEvaluation.reshapeAndExtractNotMasked(labels, predictions, mask, this.axis);
        if (triple == null) {
            // No data was extracted from this batch; nothing to accumulate.
            return;
        }
        INDArray labels2d = triple.getFirst();
        INDArray predictions2d = triple.getSecond();
        INDArray maskArray = triple.getThird();
        // After this check maskArray is guaranteed null, so no masking is applied below.
        Preconditions.checkState(maskArray == null, "Per-output masking for EvaluationCalibration is not supported");
        long nClasses = labels2d.size(1);
        if (this.rDiagBinPosCount == null) {
            this.initializeCounts(nClasses);
        }
        DataType predType = predictions2d.dataType();
        double histogramBinSize = 1.0 / this.histogramNumBins;
        double reliabilityBinSize = 1.0 / this.reliabilityDiagNumBins;

        // Reliability diagram: per bin and class, count predictions, sum labels, and sum predicted probabilities.
        for (int j = 0; j < this.reliabilityDiagNumBins; ++j) {
            boolean lastBin = j == this.reliabilityDiagNumBins - 1;
            INDArray inBin = binMask(predictions2d, j * reliabilityBinSize, (j + 1) * reliabilityBinSize, lastBin, predType);
            INDArray isPosLabelForBin = labels2d.mul(inBin);
            INDArray probsInBin = predictions2d.mul(inBin);
            this.rDiagBinSumPredictions.getRow(j).addi(probsInBin.sum(0).castTo(this.rDiagBinSumPredictions.dataType()));
            this.rDiagBinPosCount.getRow(j).addi(isPosLabelForBin.sum(0).castTo(this.rDiagBinPosCount.dataType()));
            this.rDiagBinTotalCount.getRow(j).addi(inBin.sum(0).castTo(this.rDiagBinTotalCount.dataType()));
        }

        // Per-class label sums and argmax-prediction counts.
        this.labelCountsEachClass.addi(labels2d.sum(0).castTo(this.labelCountsEachClass.dataType()));
        INDArray isPredictedClass = Nd4j.getExecutioner().exec(new IsMax(predictions2d, predictions2d.ulike(), 1))[0];
        this.predictionCountsEachClass.addi(isPredictedClass.sum(0).castTo(this.predictionCountsEachClass.dataType()));

        // Residual (|label - prediction|) and probability histograms.
        INDArray absResiduals = labels2d.sub(predictions2d);
        Transforms.abs(absResiduals, false); // in place
        for (int j = 0; j < this.histogramNumBins; ++j) {
            boolean lastBin = j == this.histogramNumBins - 1;
            double lower = j * histogramBinSize;
            double upper = (j + 1) * histogramBinSize;
            INDArray residualInBin = binMask(absResiduals, lower, upper, lastBin, predType);
            INDArray probInBin = binMask(predictions2d, lower, upper, lastBin, predType);

            int residualTotal = this.residualPlotOverall.getInt(0, j) + residualInBin.sumNumber().intValue();
            this.residualPlotOverall.putScalar(0L, j, residualTotal);
            INDArray residualByLabel = labels2d.mul(residualInBin).sum(0);
            this.residualPlotByLabelClass.getRow(j).addi(residualByLabel.castTo(this.residualPlotByLabelClass.dataType()));

            int probTotal = this.probHistogramOverall.getInt(0, j) + probInBin.sumNumber().intValue();
            this.probHistogramOverall.putScalar(0L, j, probTotal);
            INDArray probByLabel = labels2d.mul(probInBin).sum(0);
            this.probHistogramByLabelClass.getRow(j).addi(probByLabel.castTo(this.probHistogramByLabelClass.dataType()));
        }
    }

    /** Allocates every accumulator array for {@code nClasses} classes. */
    private void initializeCounts(long nClasses) {
        DataType dt = DataType.DOUBLE;
        this.rDiagBinPosCount = Nd4j.create(DataType.LONG, this.reliabilityDiagNumBins, nClasses);
        this.rDiagBinTotalCount = Nd4j.create(DataType.LONG, this.reliabilityDiagNumBins, nClasses);
        this.rDiagBinSumPredictions = Nd4j.create(dt, this.reliabilityDiagNumBins, nClasses);
        this.labelCountsEachClass = Nd4j.create(DataType.LONG, 1L, nClasses);
        this.predictionCountsEachClass = Nd4j.create(DataType.LONG, 1L, nClasses);
        this.residualPlotOverall = Nd4j.create(dt, 1L, this.histogramNumBins);
        this.residualPlotByLabelClass = Nd4j.create(dt, this.histogramNumBins, nClasses);
        this.probHistogramOverall = Nd4j.create(dt, 1L, this.histogramNumBins);
        this.probHistogramByLabelClass = Nd4j.create(dt, this.histogramNumBins, nClasses);
    }

    /**
     * Returns a 0/1 array (of {@code dataType}) marking entries of {@code values} in {@code [lower, upper)},
     * or in {@code [lower, 1.0]} when {@code lastBin} is true.
     */
    private static INDArray binMask(INDArray values, double lower, double upper, boolean lastBin, DataType dataType) {
        INDArray geqLower = values.gte(lower).castTo(dataType);
        INDArray ltUpper = lastBin ? values.lte(1.0).castTo(dataType) : values.lt(upper).castTo(dataType);
        return geqLower.muli(ltUpper);
    }

    @Override
    public void eval(INDArray labels, INDArray networkPredictions) {
        this.eval(labels, networkPredictions, (INDArray)null);
    }

    /** Same as {@link #eval(INDArray, INDArray, INDArray)}; record metadata is ignored. */
    @Override
    public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray, List<? extends Serializable> recordMetaData) {
        this.eval(labels, networkPredictions, maskArray);
    }

    /**
     * Adds all statistics accumulated by {@code other} into this instance. {@code other} is not modified.
     *
     * @throws UnsupportedOperationException if the two instances use different bin counts
     */
    @Override
    public void merge(EvaluationCalibration other) {
        if (this.reliabilityDiagNumBins != other.reliabilityDiagNumBins) {
            throw new UnsupportedOperationException("Cannot merge EvaluationCalibration instances with different numbers of bins");
        }
        if (this.histogramNumBins != other.histogramNumBins) {
            throw new UnsupportedOperationException("Cannot merge EvaluationCalibration instances with different numbers of histogram bins");
        }
        if (other.rDiagBinPosCount == null) {
            return;
        }
        this.rDiagBinPosCount = mergeCounts(this.rDiagBinPosCount, other.rDiagBinPosCount);
        this.rDiagBinTotalCount = mergeCounts(this.rDiagBinTotalCount, other.rDiagBinTotalCount);
        this.rDiagBinSumPredictions = mergeCounts(this.rDiagBinSumPredictions, other.rDiagBinSumPredictions);
        this.labelCountsEachClass = mergeCounts(this.labelCountsEachClass, other.labelCountsEachClass);
        this.predictionCountsEachClass = mergeCounts(this.predictionCountsEachClass, other.predictionCountsEachClass);
        this.residualPlotOverall = mergeCounts(this.residualPlotOverall, other.residualPlotOverall);
        this.residualPlotByLabelClass = mergeCounts(this.residualPlotByLabelClass, other.residualPlotByLabelClass);
        this.probHistogramOverall = mergeCounts(this.probHistogramOverall, other.probHistogramOverall);
        this.probHistogramByLabelClass = mergeCounts(this.probHistogramByLabelClass, other.probHistogramByLabelClass);
    }

    /**
     * Returns {@code mine + theirs}, adding in place into {@code mine} when it exists. A copy of {@code theirs} is
     * taken when {@code mine} is null so the two instances never share (and later double-update) the same array.
     */
    private static INDArray mergeCounts(INDArray mine, INDArray theirs) {
        if (theirs == null) {
            return mine;
        }
        if (mine == null) {
            return theirs.dup();
        }
        return mine.addi(theirs);
    }

    /** Discards all accumulated statistics; configuration (bins, axis) is kept. */
    @Override
    public void reset() {
        this.rDiagBinPosCount = null;
        this.rDiagBinTotalCount = null;
        this.rDiagBinSumPredictions = null;
        this.labelCountsEachClass = null;
        this.predictionCountsEachClass = null;
        this.residualPlotOverall = null;
        this.residualPlotByLabelClass = null;
        this.probHistogramOverall = null;
        this.probHistogramByLabelClass = null;
    }

    @Override
    public String stats() {
        return "EvaluationCalibration(nBins=" + this.reliabilityDiagNumBins + ")";
    }

    /** Returns the number of classes seen, or -1 if no data has been evaluated. */
    public int numClasses() {
        if (this.rDiagBinTotalCount == null) {
            return -1;
        }
        return (int)this.rDiagBinTotalCount.size(1);
    }

    /**
     * Builds the reliability diagram for one class: for each bin, the mean predicted probability and the fraction of
     * positive labels. If {@code excludeEmptyBins} is set, bins with no predictions are omitted; otherwise their
     * values come from a 0/0 division (presumably NaN).
     *
     * @throws IllegalStateException if no data has been evaluated
     */
    public ReliabilityDiagram getReliabilityDiagram(int classIdx) {
        Preconditions.checkState(this.rDiagBinPosCount != null, "Unable to get reliability diagram: no evaluation has been performed (no data)");
        INDArray totalCountBins = this.rDiagBinTotalCount.getColumn(classIdx);
        INDArray countPositiveBins = this.rDiagBinPosCount.getColumn(classIdx);
        double[] meanPredictionBins = this.rDiagBinSumPredictions.getColumn(classIdx).castTo(DataType.DOUBLE).div(totalCountBins.castTo(DataType.DOUBLE)).data().asDouble();
        double[] fracPositives = countPositiveBins.castTo(DataType.DOUBLE).div(totalCountBins.castTo(DataType.DOUBLE)).data().asDouble();
        if (this.excludeEmptyBins) {
            MatchCondition condition = new MatchCondition(totalCountBins, Conditions.equals(0), new int[0]);
            int numZeroBins = Nd4j.getExecutioner().exec(condition).getInt(0);
            if (numZeroBins != 0) {
                double[] mpb = meanPredictionBins;
                double[] fp = fracPositives;
                meanPredictionBins = new double[(int)(totalCountBins.length() - (long)numZeroBins)];
                fracPositives = new double[meanPredictionBins.length];
                int j = 0;
                for (int i = 0; i < mpb.length; ++i) {
                    if (totalCountBins.getDouble((long)i) == 0.0) continue;
                    meanPredictionBins[j] = mpb[i];
                    fracPositives[j] = fp[i];
                    ++j;
                }
            }
        }
        String title = "Reliability Diagram: Class " + classIdx;
        return new ReliabilityDiagram(title, meanPredictionBins, fracPositives);
    }

    /** Returns the per-class sum of labels, or null if no data has been evaluated. */
    public int[] getLabelCountsEachClass() {
        return this.labelCountsEachClass == null ? null : this.labelCountsEachClass.data().asInt();
    }

    /** Returns the per-class number of argmax predictions, or null if no data has been evaluated. */
    public int[] getPredictionCountsEachClass() {
        return this.predictionCountsEachClass == null ? null : this.predictionCountsEachClass.data().asInt();
    }

    /**
     * Returns the histogram of {@code |label - prediction|} over all examples and classes.
     *
     * @throws IllegalStateException if no data has been evaluated
     */
    public Histogram getResidualPlotAllClasses() {
        Preconditions.checkState(this.residualPlotOverall != null, "Unable to get residual plot: no evaluation has been performed (no data)");
        String title = "Residual Plot - All Predictions and Classes";
        int[] counts = this.residualPlotOverall.data().asInt();
        return new Histogram(title, 0.0, 1.0, counts);
    }

    /**
     * Returns the residual histogram for one class, with each entry weighted by that class's label.
     *
     * @throws IllegalStateException if no data has been evaluated
     */
    public Histogram getResidualPlot(int labelClassIdx) {
        Preconditions.checkState(this.rDiagBinPosCount != null, "Unable to get residual plot: no evaluation has been performed (no data)");
        String title = "Residual Plot - Predictions for Label Class " + labelClassIdx;
        int[] counts = this.residualPlotByLabelClass.getColumn(labelClassIdx).dup().data().asInt();
        return new Histogram(title, 0.0, 1.0, counts);
    }

    /**
     * Returns the histogram of predicted probabilities over all examples and classes.
     *
     * @throws IllegalStateException if no data has been evaluated
     */
    public Histogram getProbabilityHistogramAllClasses() {
        Preconditions.checkState(this.probHistogramOverall != null, "Unable to get probability histogram: no evaluation has been performed (no data)");
        String title = "Network Probabilities Histogram - All Predictions and Classes";
        int[] counts = this.probHistogramOverall.data().asInt();
        return new Histogram(title, 0.0, 1.0, counts);
    }

    /**
     * Returns the histogram of predicted probabilities for one class, with each entry weighted by that class's label.
     *
     * @throws IllegalStateException if no data has been evaluated
     */
    public Histogram getProbabilityHistogram(int labelClassIdx) {
        Preconditions.checkState(this.rDiagBinPosCount != null, "Unable to get probability histogram: no evaluation has been performed (no data)");
        String title = "Network Probabilities Histogram - P(class " + labelClassIdx + ") - Data Labelled Class " + labelClassIdx + " Only";
        int[] counts = this.probHistogramByLabelClass.getColumn(labelClassIdx).dup().data().asInt();
        return new Histogram(title, 0.0, 1.0, counts);
    }

    /** Deserializes an instance from JSON. */
    public static EvaluationCalibration fromJson(String json) {
        return EvaluationCalibration.fromJson(json, EvaluationCalibration.class);
    }

    /** Not supported: this evaluation exposes no scalar metrics. */
    @Override
    public double getValue(IMetric metric) {
        throw new IllegalStateException("Can't get value for non-calibration Metric " + metric);
    }

    /** Returns an empty instance with the same configuration (axis, bin counts, empty-bin handling). */
    @Override
    public EvaluationCalibration newInstance() {
        return new EvaluationCalibration(this.axis, this.reliabilityDiagNumBins, this.histogramNumBins, this.excludeEmptyBins);
    }

    public int getReliabilityDiagNumBins() {
        return this.reliabilityDiagNumBins;
    }

    public int getHistogramNumBins() {
        return this.histogramNumBins;
    }

    public boolean isExcludeEmptyBins() {
        return this.excludeEmptyBins;
    }

    public INDArray getRDiagBinPosCount() {
        return this.rDiagBinPosCount;
    }

    public INDArray getRDiagBinTotalCount() {
        return this.rDiagBinTotalCount;
    }

    public INDArray getRDiagBinSumPredictions() {
        return this.rDiagBinSumPredictions;
    }

    public INDArray getResidualPlotOverall() {
        return this.residualPlotOverall;
    }

    public INDArray getResidualPlotByLabelClass() {
        return this.residualPlotByLabelClass;
    }

    public INDArray getProbHistogramOverall() {
        return this.probHistogramOverall;
    }

    public INDArray getProbHistogramByLabelClass() {
        return this.probHistogramByLabelClass;
    }

    /** Null-safe equality: both null, or {@code a.equals(b)}. */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Null-safe hash: 43 for null, otherwise {@code o.hashCode()}. */
    private static int nullSafeHash(Object o) {
        return o == null ? 43 : o.hashCode();
    }

    /** Two instances are equal when their configuration and all accumulated statistics are equal. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof EvaluationCalibration)) {
            return false;
        }
        EvaluationCalibration other = (EvaluationCalibration)o;
        if (!other.canEqual(this)) {
            return false;
        }
        return this.getReliabilityDiagNumBins() == other.getReliabilityDiagNumBins()
                && this.getHistogramNumBins() == other.getHistogramNumBins()
                && this.isExcludeEmptyBins() == other.isExcludeEmptyBins()
                && nullSafeEquals(this.getRDiagBinPosCount(), other.getRDiagBinPosCount())
                && nullSafeEquals(this.getRDiagBinTotalCount(), other.getRDiagBinTotalCount())
                && nullSafeEquals(this.getRDiagBinSumPredictions(), other.getRDiagBinSumPredictions())
                && Arrays.equals(this.getLabelCountsEachClass(), other.getLabelCountsEachClass())
                && Arrays.equals(this.getPredictionCountsEachClass(), other.getPredictionCountsEachClass())
                && nullSafeEquals(this.getResidualPlotOverall(), other.getResidualPlotOverall())
                && nullSafeEquals(this.getResidualPlotByLabelClass(), other.getResidualPlotByLabelClass())
                && nullSafeEquals(this.getProbHistogramOverall(), other.getProbHistogramOverall())
                && nullSafeEquals(this.getProbHistogramByLabelClass(), other.getProbHistogramByLabelClass());
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof EvaluationCalibration;
    }

    /** Hash over the same fields compared by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + this.getReliabilityDiagNumBins();
        result = result * prime + this.getHistogramNumBins();
        result = result * prime + (this.isExcludeEmptyBins() ? 79 : 97);
        result = result * prime + nullSafeHash(this.getRDiagBinPosCount());
        result = result * prime + nullSafeHash(this.getRDiagBinTotalCount());
        result = result * prime + nullSafeHash(this.getRDiagBinSumPredictions());
        result = result * prime + Arrays.hashCode(this.getLabelCountsEachClass());
        result = result * prime + Arrays.hashCode(this.getPredictionCountsEachClass());
        result = result * prime + nullSafeHash(this.getResidualPlotOverall());
        result = result * prime + nullSafeHash(this.getResidualPlotByLabelClass());
        result = result * prime + nullSafeHash(this.getProbHistogramOverall());
        result = result * prime + nullSafeHash(this.getProbHistogramByLabelClass());
        return result;
    }
}

