/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.listeners.debugging;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.BaseListener;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Debugging listener that writes every op output array produced during SameDiff execution to a
 * directory, one binary file per output, so that two runs can later be compared with
 * {@link #compare(File, File, double)}.
 *
 * <p>Files are named {@code <counter>_<outputName>.bin}. {@code <counter>} is a running index
 * incremented once per saved array. Every {@code '/'} in the output name is encoded as
 * {@code "__"} so that the name is a single path segment.
 */
public class ArraySavingListener
extends BaseListener {
    /** Directory that receives the saved arrays; required to be empty when the listener is created. */
    protected final File dir;
    /** Running file-name prefix; incremented once per saved array. */
    protected int count = 0;

    /**
     * Creates a listener that saves arrays into {@code dir}.
     *
     * @param dir directory to save arrays to; created (including missing parents) if it does not exist
     * @throws NullPointerException  if {@code dir} is null
     * @throws IllegalStateException if {@code dir} could not be created, is not a directory, or is not empty
     */
    public ArraySavingListener(@NonNull File dir) {
        if (dir == null) {
            throw new NullPointerException("dir is marked non-null but is null");
        }
        if (!dir.exists()) {
            // mkdirs (not mkdir) so a missing parent does not leave dir uncreated
            dir.mkdirs();
        }
        // Fail fast here rather than on the first save if the path is unusable
        if (!dir.isDirectory()) {
            throw new IllegalStateException("Could not create directory, or path is not a directory: " + dir.getAbsolutePath());
        }
        File[] existing = dir.listFiles();
        if (existing != null && existing.length > 0) {
            throw new IllegalStateException("Directory is not empty: " + dir.getAbsolutePath());
        }
        this.dir = dir;
    }

    /** Returns {@code true}: the listener is active for every {@link Operation}. */
    @Override
    public boolean isActive(Operation operation) {
        return true;
    }

    /**
     * Saves each output array of the executed op to its own file and prints the file's absolute path.
     *
     * @throws RuntimeException wrapping the {@link IOException} if an array cannot be written
     */
    @Override
    public void opExecution(SameDiff sd, At at, MultiDataSet batch, SameDiffOp op, OpContext opContext, INDArray[] outputs) {
        List<String> outNames = op.getOutputsOfOp();
        for (int i = 0; i < outputs.length; i++) {
            // '/' would be read as a path separator, so encode it; compare() decodes it for display
            String filename = this.count++ + "_" + outNames.get(i).replace("/", "__") + ".bin";
            File outFile = new File(this.dir, filename);
            try {
                Nd4j.saveBinary(outputs[i], outFile);
            } catch (IOException e) {
                throw new RuntimeException("Failed to save array to " + outFile.getAbsolutePath(), e);
            }
            System.out.println(outFile.getAbsolutePath());
        }
    }

    /**
     * Compares the arrays saved by two listener runs. For each file in {@code dir1}, prints either
     * {@code "Equals: <name>"} or a {@code "FAILS: ..."} line with details followed by both file paths.
     *
     * <p>A file in {@code dir1} is paired with the identically-named file in {@code dir2} when one
     * exists. Otherwise it is paired by variable name (the file name without its counter prefix and
     * {@code .bin} suffix). When a variable was saved several times, name-only pairing picks an
     * arbitrary occurrence.
     *
     * @param dir1 directory written by the first run
     * @param dir2 directory written by the second run
     * @param eps  tolerance passed to {@link INDArray#equalsWithEps}
     * @throws Exception if reading an array fails. A {@link Preconditions} check also fails when a
     *                   directory cannot be listed, the file counts differ, or a variable saved in
     *                   {@code dir1} has no counterpart in {@code dir2}.
     */
    public static void compare(File dir1, File dir2, double eps) throws Exception {
        File[] files1 = dir1.listFiles();
        File[] files2 = dir2.listFiles();
        Preconditions.checkNotNull((Object)files1, "No files in directory 1: %s", (Object)dir1);
        Preconditions.checkNotNull((Object)files2, "No files in directory 2: %s", (Object)dir2);
        Preconditions.checkState(files1.length == files2.length, "Different number of files: %s vs %s", files1.length, files2.length);
        Map<String, File> byVarName2 = ArraySavingListener.toMap(files2);
        for (File f : files1) {
            String varName = varNameOf(f);
            String displayName = varName.replace("__", "/");
            // An identical name (same counter and variable) pairs up repeated saves of one variable correctly
            File f2 = new File(dir2, f.getName());
            if (!f2.isFile()) {
                f2 = byVarName2.get(varName);
            }
            Preconditions.checkState(f2 != null, "No file for variable %s in directory 2: %s", displayName, dir2);

            INDArray arr1 = Nd4j.readBinary(f);
            INDArray arr2 = null;
            try {
                arr2 = Nd4j.readBinary(f2);
                compareArrays(displayName, f, f2, arr1, arr2, eps);
            } finally {
                // Release native memory even when reading or comparing throws
                arr1.close();
                if (arr2 != null) {
                    arr2.close();
                }
            }
        }
    }

    /** Compares one pair of loaded arrays and prints the Equals/FAILS report for it. */
    private static void compareArrays(String displayName, File f1, File f2, INDArray arr1, INDArray arr2, double eps) {
        if (arr1.equalsWithEps(arr2, eps)) {
            System.out.println("Equals: " + displayName);
        } else if (!Arrays.equals(arr1.shape(), arr2.shape())
                || (arr1.dataType() == DataType.BOOL) != (arr2.dataType() == DataType.BOOL)) {
            // The element-wise diffs below assume equal shapes and, for Xor, boolean inputs on both sides
            System.out.println("FAILS: " + displayName + " - shape/type mismatch: "
                    + Arrays.toString(arr1.shape()) + " " + arr1.dataType() + " vs "
                    + Arrays.toString(arr2.shape()) + " " + arr2.dataType());
            printPaths(f1, f2);
        } else if (arr1.dataType() == DataType.BOOL) {
            INDArray xor = Nd4j.exec(new Xor(arr1, arr2));
            // Cast BOOL -> INT so that summing counts the differing elements
            INDArray xorInt = xor.castTo(DataType.INT);
            int numDiffs = xorInt.sumNumber().intValue();
            System.out.println("FAILS: " + displayName + " - boolean, # differences = " + numDiffs);
            printPaths(f1, f2);
            xorInt.close();
            xor.close();
        } else {
            INDArray sub = arr1.sub(arr2);
            INDArray diff = Nd4j.math.abs(sub);
            double maxDiff = diff.maxNumber().doubleValue();
            System.out.println("FAILS: " + displayName + " - max difference = " + maxDiff);
            printPaths(f1, f2);
            sub.close();
            diff.close();
        }
    }

    /** Prints the absolute paths of a mismatching pair of files, tab-indented under the FAILS line. */
    private static void printPaths(File f1, File f2) {
        System.out.println("\t" + f1.getAbsolutePath());
        System.out.println("\t" + f2.getAbsolutePath());
    }

    /** Indexes files by variable name; if a name repeats, the file listed last wins. */
    private static Map<String, File> toMap(File[] files) {
        Map<String, File> ret = new HashMap<>();
        for (File f : files) {
            ret.put(varNameOf(f), f);
        }
        return ret;
    }

    /**
     * Extracts the variable name from a saved file's name. This is the text after the first
     * {@code '_'} (which ends the counter prefix), with a trailing {@code ".bin"} removed if present.
     * {@code '/'} remains encoded as {@code "__"}.
     */
    private static String varNameOf(File f) {
        String name = f.getName();
        int start = name.indexOf('_') + 1;
        int end = name.endsWith(".bin") ? name.length() - ".bin".length() : name.length();
        return name.substring(start, end);
    }
}

