/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.listeners.impl;

import java.text.DecimalFormat;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.records.LossCurve;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Training listener that logs the loss every {@code frequency} iterations and, optionally,
 * per-epoch summaries and per-report throughput (batches/sec, examples/sec).
 *
 * <p>Only active for {@link Operation#TRAINING}.
 */
public class ScoreListener extends BaseListener {
    private static final Logger log = LoggerFactory.getLogger(ScoreListener.class);
    /** Number of iterations between loss reports; validated to be > 0. */
    private final int frequency;
    /** Whether to log a summary line at the end of each epoch. */
    private final boolean reportEpochs;
    /** Whether iteration reports include batches/sec and examples/sec. */
    private final boolean reportIterPerformance;
    private long epochExampleCount;
    private int epochBatchCount;
    private long etlTotalTimeEpoch;
    private long lastIterTime;
    private long etlTimeSumSinceLastReport;
    // Accumulated and reset on each report, but not currently included in any log output.
    private long iterTimeSumSinceLastReport;
    // long (not int) so that the count and the derived rate cannot overflow on large windows.
    private long examplesSinceLastReportIter;
    private long lastReportTime = -1L;
    // Per-thread caches: DecimalFormat is not thread-safe, so each thread gets its own instance.
    protected static final ThreadLocal<DecimalFormat> DF_2DP = new ThreadLocal<DecimalFormat>();
    protected static final ThreadLocal<DecimalFormat> DF_2DP_SCI = new ThreadLocal<DecimalFormat>();
    protected static final ThreadLocal<DecimalFormat> DF_5DP = new ThreadLocal<DecimalFormat>();
    protected static final ThreadLocal<DecimalFormat> DF_5DP_SCI = new ThreadLocal<DecimalFormat>();

    /** Reports every 10 iterations, with epoch summaries and iteration performance. */
    public ScoreListener() {
        this(10, true);
    }

    /**
     * @param frequency number of iterations between loss reports (must be > 0)
     */
    public ScoreListener(int frequency) {
        this(frequency, true);
    }

    /**
     * @param frequency    number of iterations between loss reports (must be > 0)
     * @param reportEpochs whether to log a summary at the end of each epoch
     */
    public ScoreListener(int frequency, boolean reportEpochs) {
        this(frequency, reportEpochs, true);
    }

    /**
     * @param frequency             number of iterations between loss reports (must be > 0)
     * @param reportEpochs          whether to log a summary at the end of each epoch
     * @param reportIterPerformance whether iteration reports include throughput figures
     * @throws IllegalArgumentException if {@code frequency <= 0}
     */
    public ScoreListener(int frequency, boolean reportEpochs, boolean reportIterPerformance) {
        Preconditions.checkArgument(frequency > 0, "ScoreListener frequency must be > 0, got %s", frequency);
        this.frequency = frequency;
        this.reportEpochs = reportEpochs;
        this.reportIterPerformance = reportIterPerformance;
    }

    @Override
    public boolean isActive(Operation operation) {
        return operation == Operation.TRAINING;
    }

    @Override
    public void epochStart(SameDiff sd, At at) {
        if (this.reportEpochs) {
            this.epochExampleCount = 0L;
            this.epochBatchCount = 0;
            this.etlTotalTimeEpoch = 0L;
        }
        // The first report of a new epoch has no valid previous timestamp, so throughput is skipped.
        this.lastReportTime = -1L;
        this.examplesSinceLastReportIter = 0L;
    }

    @Override
    public ListenerResponse epochEnd(SameDiff sd, At at, LossCurve lossCurve, long epochTimeMillis) {
        if (this.reportEpochs) {
            double batchesPerSec = perSecond(this.epochBatchCount, epochTimeMillis);
            double examplesPerSec = perSecond(this.epochExampleCount, epochTimeMillis);
            // Clamp the denominator so a 0 ms epoch cannot produce Infinity/NaN.
            double pcEtl = 100.0 * (double) this.etlTotalTimeEpoch / (double) Math.max(1L, epochTimeMillis);
            String etl = this.formatDurationMs(this.etlTotalTimeEpoch) + " ETL time"
                    + (this.etlTotalTimeEpoch > 0L ? "(" + this.format2dp(pcEtl) + " %)" : "");
            log.info("Epoch {} complete on iteration {} - {} batches ({} examples) in {} - {} batches/sec, {} examples/sec, {}",
                    new Object[]{at.epoch(), at.iteration(), this.epochBatchCount, this.epochExampleCount,
                            this.formatDurationMs(epochTimeMillis), this.format2dp(batchesPerSec),
                            this.format2dp(examplesPerSec), etl});
        }
        return ListenerResponse.CONTINUE;
    }

    @Override
    public void iterationStart(SameDiff sd, At at, MultiDataSet data, long etlMs) {
        this.lastIterTime = System.currentTimeMillis();
        this.etlTimeSumSinceLastReport += etlMs;
        this.etlTotalTimeEpoch += etlMs;
    }

    @Override
    public void iterationDone(SameDiff sd, At at, MultiDataSet dataSet, Loss loss) {
        this.iterTimeSumSinceLastReport += System.currentTimeMillis() - this.lastIterTime;
        ++this.epochBatchCount;
        // Example count is taken from the leading dimension of the first features array, if present.
        if (dataSet != null && dataSet.numFeatureArrays() > 0 && dataSet.getFeatures(0) != null) {
            long n = dataSet.getFeatures(0).size(0);
            this.examplesSinceLastReportIter += n;
            this.epochExampleCount += n;
        }
        if (at.iteration() > 0 && at.iteration() % this.frequency == 0) {
            this.reportIteration(at, loss.totalLoss());
        }
    }

    /** Logs one loss report (plus throughput if enabled) and resets the since-last-report counters. */
    private void reportIteration(At at, double l) {
        String etl = "";
        if (this.etlTimeSumSinceLastReport > 0L) {
            etl = "(" + this.formatDurationMs(this.etlTimeSumSinceLastReport) + " ETL";
            etl = this.frequency == 1 ? etl + ")" : etl + " in " + this.frequency + " iter)";
        }
        if (!this.reportIterPerformance) {
            log.info("Loss at epoch {}, iteration {}: {}{}", new Object[]{at.epoch(), at.iteration(), this.format5dp(l), etl});
        } else {
            long time = System.currentTimeMillis();
            if (this.lastReportTime > 0L) {
                long elapsed = time - this.lastReportTime;
                // Floating-point arithmetic: the old int form (1000 * count) overflowed on large windows.
                double batchPerSec = perSecond(this.frequency, elapsed);
                double exPerSec = perSecond(this.examplesSinceLastReportIter, elapsed);
                log.info("Loss at epoch {}, iteration {}: {}{}, batches/sec: {}, examples/sec: {}",
                        new Object[]{at.epoch(), at.iteration(), this.format5dp(l), etl,
                                this.format5dp(batchPerSec), this.format5dp(exPerSec)});
            } else {
                log.info("Loss at epoch {}, iteration {}: {}{}", new Object[]{at.epoch(), at.iteration(), this.format5dp(l), etl});
            }
            this.lastReportTime = time;
        }
        this.iterTimeSumSinceLastReport = 0L;
        this.etlTimeSumSinceLastReport = 0L;
        this.examplesSinceLastReportIter = 0L;
    }

    /**
     * Converts a count observed over {@code millis} milliseconds into a per-second rate.
     * Elapsed times below 1 ms are clamped to 1 ms so the result is always finite.
     */
    private static double perSecond(double count, long millis) {
        return count * 1000.0 / (double) Math.max(1L, millis);
    }

    /**
     * Formats a duration for humans: milliseconds up to 100 ms, then seconds up to 1 minute,
     * minutes up to 1 hour, and hours beyond that.
     *
     * @param ms duration in milliseconds
     * @return formatted duration with unit suffix
     */
    protected String formatDurationMs(long ms) {
        if (ms <= 100L) {
            return ms + " ms";
        }
        if (ms <= 60000L) {
            double sec = (double) ms / 1000.0;
            return this.format2dp(sec) + " sec";
        }
        if (ms <= 3600000L) {
            double min = (double) ms / 60000.0;
            return this.format2dp(min) + " min";
        }
        // 1 hour = 3,600,000 ms (previously divided by 360,000, giving 10x too many hours).
        double hr = (double) ms / 3600000.0;
        return this.format2dp(hr) + " hr";
    }

    /**
     * Formats a value with 2 decimal places, switching to scientific notation below 0.01.
     */
    protected String format2dp(double d) {
        if (d < 0.01) {
            return getFormat(DF_2DP_SCI, "0.00E0").format(d);
        }
        // "0.00" (not "#.00") so values below 1 keep their leading zero, e.g. "0.50".
        return getFormat(DF_2DP, "0.00").format(d);
    }

    /**
     * Formats a value with 5 decimal places, switching to scientific notation outside [1e-4, 1e4].
     */
    protected String format5dp(double d) {
        if (d < 1.0E-4 || d > 10000.0) {
            return getFormat(DF_5DP_SCI, "0.00000E0").format(d);
        }
        return getFormat(DF_5DP, "0.00000").format(d);
    }

    /**
     * Returns the calling thread's cached format from {@code cache}, creating it from
     * {@code pattern} and storing it in that same cache on first use.
     */
    private static DecimalFormat getFormat(ThreadLocal<DecimalFormat> cache, String pattern) {
        DecimalFormat f = cache.get();
        if (f == null) {
            f = new DecimalFormat(pattern);
            cache.set(f);
        }
        return f;
    }
}

