/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.validation;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import org.apache.commons.collections4.trie.PatriciaTrie;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.validation.ActivationGradientCheckListener;
import org.nd4j.autodiff.validation.TestCase;
import org.nd4j.autodiff.validation.listeners.NonInplaceValidationListener;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GradCheckUtil {
    private static final Logger log = LoggerFactory.getLogger(GradCheckUtil.class);

    /** Default value for the {@code print} option (log every checked parameter). */
    public static final boolean DEFAULT_PRINT = false;
    /** Default value for the {@code exitOnFirstFailure} option. */
    public static final boolean DEFAULT_EXIT_FIRST_FAILURE = false;
    /** Default value for the {@code debugMode} option (SameDiff debug mode during the check). */
    public static final boolean DEFAULT_DEBUG_MODE = false;
    /** Default perturbation size for the central-difference numerical gradient. */
    public static final double DEFAULT_EPS = 1e-5;
    /** Default maximum relative error between analytic and numerical gradients. */
    public static final double DEFAULT_MAX_REL_ERROR = 1e-5;
    /** Default absolute error below which a parameter passes even when its relative error is too large. */
    public static final double DEFAULT_MIN_ABS_ERROR = 1e-6;

    /**
     * Runs a gradient check whose every setting is taken from the given {@link TestCase}.
     * Internal-state validation is always performed for test-case driven checks.
     *
     * @param t test case providing the SameDiff instance, placeholders, tolerances and options
     * @return true if every checked parameter passed
     */
    public static boolean checkGradients(TestCase t) {
        final boolean skipValidation = false;
        return checkGradients(
                t.sameDiff(),
                t.placeholderValues(),
                t.gradCheckEpsilon(),
                t.gradCheckMaxRelativeError(),
                t.gradCheckMinAbsError(),
                t.gradCheckPrint(),
                t.gradCheckDefaultExitFirstFailure(),
                skipValidation,
                t.gradCheckDebugMode(),
                t.gradCheckSkipVariables(),
                t.gradCheckMask());
    }

    /**
     * Runs a gradient check with all default tolerances and options, optionally skipping some variables.
     *
     * @param sd                SameDiff instance to check
     * @param placeholderValues placeholder values used for every forward/backward pass
     * @param skipVariables     names of variables that should not be checked; a null array skips nothing
     * @return true if every checked parameter passed
     */
    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, String ... skipVariables) {
        Set<String> skipSet = skipVariables == null ? null : new HashSet<String>(Arrays.asList(skipVariables));
        return checkGradients(sd, placeholderValues, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, DEFAULT_MIN_ABS_ERROR,
                DEFAULT_PRINT, DEFAULT_EXIT_FIRST_FAILURE, false, DEFAULT_DEBUG_MODE, skipSet, null);
    }

    /**
     * Runs a gradient check with the default epsilon and error tolerances.
     *
     * @param sd                 SameDiff instance to check
     * @param placeholderValues  placeholder values used for every forward/backward pass
     * @param print              log every checked parameter
     * @param exitOnFirstFailure stop at the first failing parameter
     * @return true if every checked parameter passed
     */
    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, boolean print, boolean exitOnFirstFailure) {
        double eps = DEFAULT_EPS;
        double maxRelError = DEFAULT_MAX_REL_ERROR;
        double minAbsError = DEFAULT_MIN_ABS_ERROR;
        return checkGradients(sd, placeholderValues, eps, maxRelError, minAbsError, print, exitOnFirstFailure);
    }

    /**
     * Runs a gradient check with explicit tolerances; internal-state validation is performed, debug mode
     * is left untouched, no variables are skipped and no mask is applied.
     *
     * @param sd                 SameDiff instance to check
     * @param placeholderValues  placeholder values used for every forward/backward pass
     * @param eps                perturbation size for the numerical gradient
     * @param maxRelError        maximum relative error for a parameter to pass
     * @param minAbsError        absolute error below which a parameter passes regardless of relative error
     * @param print              log every checked parameter
     * @param exitOnFirstFailure stop at the first failing parameter
     * @return true if every checked parameter passed
     */
    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure) {
        final boolean skipValidation = false;
        final boolean debugMode = false;
        final Set<String> noSkippedVariables = null;
        final Map<String, INDArray> noMask = null;
        return checkGradients(sd, placeholderValues, eps, maxRelError, minAbsError, print, exitOnFirstFailure,
                skipValidation, debugMode, noSkippedVariables, noMask);
    }

    /**
     * Runs a gradient check over every element of every eligible variable (no per-variable subsetting).
     *
     * @param sd                 SameDiff instance to check
     * @param placeholderValues  placeholder values used for every forward/backward pass
     * @param eps                perturbation size for the numerical gradient
     * @param maxRelError        maximum relative error for a parameter to pass
     * @param minAbsError        absolute error below which a parameter passes regardless of relative error
     * @param print              log every checked parameter
     * @param exitOnFirstFailure stop at the first failing parameter
     * @param skipValidation     skip the internal-state validation step
     * @param debugMode          enable SameDiff debug mode during the check
     * @param skipVariables      variables not to check; may be null
     * @param gradCheckMask      optional per-variable masks; may be null
     * @return true if every checked parameter passed
     */
    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure, boolean skipValidation, boolean debugMode, Set<String> skipVariables, Map<String, INDArray> gradCheckMask) {
        final int unlimitedPerParam = -1;
        final Subset allElements = null;
        return checkGradients(sd, placeholderValues, eps, maxRelError, minAbsError, print, exitOnFirstFailure,
                skipValidation, debugMode, skipVariables, gradCheckMask, unlimitedPerParam, allElements);
    }

    /**
     * Full gradient check: compares SameDiff's analytic gradients against central-difference numerical
     * gradients for the elements of every floating-point variable that is not an op output.
     *
     * <p>For every checked element the value is perturbed in place to {@code orig + eps} and
     * {@code orig - eps}. The outputs of all loss variables are summed for each perturbation, and the
     * numerical gradient {@code (scorePlus - scoreMinus) / (2 * eps)} is compared against the analytic
     * gradient. The original value is always restored, even if a forward pass throws.
     *
     * <p>If debug mode is enabled here, it is disabled again on every exit path. A
     * {@link NonInplaceValidationListener} is attached for the analytic gradient calculation only when the
     * caller has not already registered one. A listener registered by the caller is never removed.
     *
     * @param sd                 SameDiff instance; must define at least one loss variable
     * @param placeholderValues  placeholder values used for every forward/backward pass
     * @param eps                perturbation size for the central difference
     * @param maxRelError        maximum relative error for an element to pass
     * @param minAbsError        an element whose relative error exceeds {@code maxRelError} still passes when
     *                           its absolute error is below this value
     * @param print              log per-element results
     * @param exitOnFirstFailure return {@code false} as soon as one element fails
     * @param skipValidation     if false, {@link #validateInternalState(SameDiff, boolean)} runs first
     * @param debugMode          enable SameDiff debug mode for the duration of the check
     * @param skipVariables      names of variables not to check; may be null
     * @param gradCheckMask      optional per-variable BOOL masks with the variable's shape; elements whose mask
     *                           value is 0 are neither checked nor counted; may be null
     * @param maxPerParam        when &gt; 0, {@code subset != null} and smaller than a variable's length, limits
     *                           the elements checked for that variable
     * @param subset             {@code Subset.RANDOM} checks {@code maxPerParam} distinct random elements (fixed
     *                           seed); any other value checks every {@code (length / maxPerParam)}-th element
     * @return true if no checked element failed
     * @throws IllegalStateException if the global datatype is not DOUBLE, a variable lacks an array, a gradient
     *                               is missing/mis-shaped, or a gradient is NaN/infinite (precondition checks
     *                               may raise other runtime exceptions)
     */
    public static boolean checkGradients(SameDiff sd, Map<String, INDArray> placeholderValues, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure, boolean skipValidation, boolean debugMode, Set<String> skipVariables, Map<String, INDArray> gradCheckMask, int maxPerParam, Subset subset) {
        boolean debugBefore = sd.isDebugMode();
        if (debugMode) {
            sd.enableDebugMode();
        }
        try {
            if (!skipValidation) {
                GradCheckUtil.validateInternalState(sd, true);
            }
            if (Nd4j.dataType() != DataType.DOUBLE) {
                throw new IllegalStateException("Data type must be set to double");
            }

            // Op outputs are activations, not independent inputs: they are never perturbed directly
            Set<String> fnOutputs = new HashSet<String>();
            for (DifferentialFunction f : sd.ops()) {
                for (SDVariable s : f.outputVariables()) {
                    fnOutputs.add(s.name());
                }
            }

            for (Variable v : sd.getVariables().values()) {
                if (v.getVariable().getVariableType() != VariableType.ARRAY && v.getVariable().getArr(true) == null) {
                    throw new IllegalStateException("Variable \"" + v.getName() + "\" does not have array associated with it");
                }
            }

            List<String> lossFnVariables = sd.getLossVariables();
            Preconditions.checkState(lossFnVariables != null && !lossFnVariables.isEmpty(), "Expected 1 or more loss function variables for gradient check, got %s", lossFnVariables);

            Set<String> varsNeedingGrads = new HashSet<String>();
            for (Variable v : sd.getVariables().values()) {
                SDVariable var = v.getVariable();
                VariableType type = var.getVariableType();
                if (!var.dataType().isFPType() || (type != VariableType.VARIABLE && type != VariableType.PLACEHOLDER)) {
                    continue;
                }
                SDVariable g = var.getGradient();
                Preconditions.checkNotNull((Object) g, "No gradient variable found for variable %s", (Object) var);
                varsNeedingGrads.add(v.getName());
            }

            // Check that non-inplace ops do not modify their inputs during the backward pass. Only a listener
            // added here is removed again; one the caller registered is left in place.
            NonInplaceValidationListener addedListener = null;
            boolean alreadyValidating = false;
            for (Listener l : sd.getListeners()) {
                if (l instanceof NonInplaceValidationListener) {
                    alreadyValidating = true;
                    break;
                }
            }
            if (!alreadyValidating) {
                addedListener = new NonInplaceValidationListener();
                sd.addListeners(addedListener);
            }
            Map<String, INDArray> gm;
            try {
                gm = sd.calculateGradients(placeholderValues, varsNeedingGrads);
            } finally {
                if (addedListener != null) {
                    // remove(Object): identity-based removal of exactly the listener we added
                    sd.getListeners().remove(addedListener);
                }
            }

            Map<String, INDArray> grad = new HashMap<String, INDArray>();
            for (SDVariable v : sd.variables()) {
                if (fnOutputs.contains(v.name()) || !v.hasGradient()) {
                    continue;
                }
                SDVariable g = sd.grad(v.name());
                if (g == null) {
                    throw new IllegalStateException("Null gradient variable for \"" + v.name() + "\"");
                }
                INDArray ga = gm.get(v.name());
                if (ga == null) {
                    throw new IllegalStateException("Null gradient array encountered for variable: " + v.name());
                }
                if (!Arrays.equals(v.getArr().shape(), ga.shape())) {
                    throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" + v.name() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " + Arrays.toString(ga.shape()));
                }
                // Copy: later forward passes may reuse the gradient buffers
                grad.put(v.name(), ga.dup());
            }

            int totalNFailures = 0;
            int totalCount = 0;
            double maxError = 0.0;
            // Fixed seed so RANDOM subsets are reproducible between runs
            Random r = new Random(12345L);
            for (SDVariable s : sd.variables()) {
                String name = s.name();
                if (fnOutputs.contains(name) || !s.dataType().isFPType()) {
                    continue;
                }
                if (skipVariables != null && skipVariables.contains(name)) {
                    log.info("Grad check: skipping variable \"{}\"", (Object) name);
                    continue;
                }
                if (s.dataType() != DataType.DOUBLE) {
                    log.warn("DataType for variable {} is not double (is: {}) may cause precision issues in gradient checks", (Object) name, (Object) s.dataType());
                }
                INDArray a = s.getArr();
                long n = a.length();
                if (print) {
                    log.info("Starting test for variable \"{}\" with {} values", (Object) name, (Object) n);
                }

                Iterator<long[]> iter;
                if (maxPerParam > 0 && subset != null && (long) maxPerParam < n) {
                    long[] shape = a.shape();
                    List<long[]> positions = new ArrayList<long[]>();
                    if (subset == Subset.RANDOM) {
                        Set<Integer> chosen = new HashSet<Integer>();
                        while (chosen.size() < maxPerParam) {
                            chosen.add(r.nextInt((int) n));
                        }
                        List<Integer> sorted = new ArrayList<Integer>(chosen);
                        Collections.sort(sorted);
                        for (Integer flat : sorted) {
                            positions.add(Shape.ind2subC(shape, (long) flat.intValue()));
                        }
                    } else {
                        // maxPerParam < n, so the stride is at least 1
                        long everyN = n / (long) maxPerParam;
                        for (long curr = 0L; curr < n; curr += everyN) {
                            positions.add(Shape.ind2subC(shape, curr));
                        }
                    }
                    iter = positions.iterator();
                } else {
                    iter = new NdIndexIterator('c', a.shape());
                }

                INDArray varMask = gradCheckMask == null ? null : gradCheckMask.get(name);
                if (varMask != null) {
                    Preconditions.checkState(a.equalShapes(varMask), "Variable \"%s\": Gradient check mask and array shapes must be equal: got %s vs. mask shape %s", (Object) name, (Object) a.shape(), (Object) varMask.shape());
                    Preconditions.checkState(varMask.dataType() == DataType.BOOL, "Variable \"%s\": Gradient check mask must be BOOLEAN datatype, got %s", (Object) name, (Object) varMask.dataType());
                }

                // Looked up once per variable (after index selection, so the RANDOM sequence is unchanged):
                // without an analytic gradient there is nothing to compare, so skip all forward passes
                INDArray aGrad = grad.get(name);
                if (aGrad == null) {
                    log.warn("No gradient array for variable \"{}\" was found, skipping variable...", (Object) name);
                    continue;
                }

                int i = 0;
                while (iter.hasNext()) {
                    long[] idx = iter.next();
                    if (varMask != null && varMask.getDouble(idx) == 0.0) {
                        continue;
                    }
                    String strIdx = Arrays.toString(idx).replace(" ", "");
                    ++totalCount;

                    double orig = a.getDouble(idx);
                    double scorePlus = 0.0;
                    double scoreMinus = 0.0;
                    try {
                        a.putScalar(idx, orig + eps);
                        for (INDArray arr : sd.output(placeholderValues, lossFnVariables).values()) {
                            scorePlus += arr.sumNumber().doubleValue();
                        }
                        a.putScalar(idx, orig - eps);
                        for (INDArray arr : sd.output(placeholderValues, lossFnVariables).values()) {
                            scoreMinus += arr.sumNumber().doubleValue();
                        }
                    } finally {
                        // Never leave the parameter perturbed, even if a forward pass fails
                        a.putScalar(idx, orig);
                    }

                    double numericalGrad = (scorePlus - scoreMinus) / (2.0 * eps);
                    double analyticGrad = aGrad.getDouble(idx);
                    if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
                        throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
                    }
                    if (Double.isInfinite(analyticGrad) || Double.isNaN(analyticGrad)) {
                        throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + name + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
                    }

                    // Relative error is 0 only when BOTH gradients are exactly zero. If just one is zero the
                    // relative error is 1.0 and the element is judged by the absolute-error fallback below.
                    double relError = (numericalGrad == 0.0 && analyticGrad == 0.0)
                            ? 0.0
                            : Math.abs(analyticGrad - numericalGrad) / (Math.abs(analyticGrad) + Math.abs(numericalGrad));
                    if (relError > maxError) {
                        maxError = relError;
                    }
                    if (relError > maxRelError || Double.isNaN(relError)) {
                        double absError = Math.abs(analyticGrad - numericalGrad);
                        if (absError < minAbsError) {
                            if (print) {
                                log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
                            }
                        } else {
                            log.info("Param " + i + " (" + name + strIdx + ") FAILED: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + ", absError=" + absError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                            if (exitOnFirstFailure) {
                                return false;
                            }
                            ++totalNFailures;
                        }
                    } else if (print) {
                        log.info("Param " + i + " (" + name + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError);
                    }
                    ++i;
                }
            }

            int nPass = totalCount - totalNFailures;
            log.info("GradCheckUtil.checkGradients(): " + totalCount + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
            return totalNFailures == 0;
        } finally {
            // Restore the caller's debug setting on every exit path (early return, exceptions)
            if (debugMode && !debugBefore) {
                sd.disableDebugging();
            }
        }
    }

    /**
     * Gradient-checks activation (ARRAY-type) variables. Instead of perturbing a stored parameter, an
     * {@link ActivationGradientCheckListener} adds {@code +eps} / {@code -eps} at one element of the named
     * activation during the forward pass. The resulting central-difference gradient is compared against
     * the analytic activation gradient from {@code calculateGradients}.
     *
     * <p>The caller's listeners are replaced by the perturbing listener for the duration of the check and
     * restored afterwards. If debug mode is enabled here, it is disabled again on every exit path.
     *
     * @param config check configuration; {@code sd}, a non-empty {@code activationGradsToCheck}, a positive
     *               {@code eps} and a positive {@code maxRelError} are required
     * @return true if no checked element failed
     * @throws IllegalStateException if the configuration is invalid, a named variable does not exist or is
     *                               not a floating-point ARRAY variable, or a gradient is NaN/infinite
     */
    public static boolean checkActivationGradients(ActGradConfig config) {
        SameDiff sd = config.getSd();
        List<String> actGrads = config.getActivationGradsToCheck();
        double eps = config.getEps();
        double maxRelError = config.getMaxRelError();
        double minAbsError = config.getMinAbsError();
        Preconditions.checkState(sd != null, "SameDiff instance was not set in configuration");
        Preconditions.checkState(actGrads != null && !actGrads.isEmpty(), "No activation gradients were specified to gradient check");
        Preconditions.checkState(eps > 0.0, "Epsilon has not been set");
        Preconditions.checkState(maxRelError > 0.0, "Max relative error must be set (is 0.0)");
        for (String s : actGrads) {
            // Look up first, then null-check: an unknown name must fail the precondition, not throw an NPE
            Variable variable = sd.getVariables().get(s);
            Preconditions.checkState(variable != null && variable.getVariable() != null, "No variable with name \"%s\" was found", (Object) s);
            SDVariable v = variable.getVariable();
            Preconditions.checkState(v.getVariableType() == VariableType.ARRAY, "Only variables with type ARRAY may be gradient checked using this method. Variable \"%s\" has type %s", (Object) s, (Object) v.getVariableType());
            Preconditions.checkState(v.dataType().isFPType(), "Cannot gradient check activation variable \"%s\": must be floating point type. Is type: %s", (Object) s, (Object) v.dataType());
            if (v.dataType() != DataType.DOUBLE) {
                log.warn("Floating point variable {} is not double precision - this may result in spurious failures due to limited precision. Variable is type: {}", (Object) s, (Object) v.dataType());
            }
        }

        boolean debugBefore = sd.isDebugMode();
        if (config.isDebugMode()) {
            sd.enableDebugMode();
        }
        // Snapshot of the caller's listeners; they are swapped out below and restored in the finally block
        List<Listener> listenersBefore = new ArrayList<Listener>(sd.getListeners());
        try {
            if (!config.isSkipValidation()) {
                GradCheckUtil.validateInternalState(sd, true);
            }
            List<String> lossFnVariables = sd.getLossVariables();
            Preconditions.checkState(lossFnVariables != null && !lossFnVariables.isEmpty(), "Expected 1 or more loss function variables for gradient check, got %s", lossFnVariables);
            sd.createGradFunction();

            Set<String> varsRequiringGrads = new HashSet<String>();
            for (String s : actGrads) {
                SDVariable grad = sd.getVariable(s).gradient();
                Preconditions.checkState(grad != null, "Could not get gradient for activation \"%s\": gradient variable is null", (Object) s);
                varsRequiringGrads.add(s);
            }
            Map<String, INDArray> placeholderValues = config.getPlaceholderValues();
            Map<String, INDArray> grads = sd.calculateGradients(placeholderValues, new ArrayList<String>(varsRequiringGrads));
            Map<String, INDArray> gradientsForAct = new HashMap<String, INDArray>();
            for (String s : actGrads) {
                INDArray arr = grads.get(s);
                Preconditions.checkState(arr != null, "No activation gradient array for variable \"%s\"", (Object) s);
                // Copy: the perturbed forward passes below may reuse the gradient buffers
                gradientsForAct.put(s, arr.dup());
            }

            int totalNFailures = 0;
            int totalCount = 0;
            double maxError = 0.0;
            ActivationGradientCheckListener listener = new ActivationGradientCheckListener();
            sd.setListeners(listener);
            // Fixed seed so RANDOM subsets are reproducible between runs
            Random r = new Random(12345L);
            int maxPerParam = config.getMaxPerParam();
            Subset subset = config.getSubset();
            for (String s : actGrads) {
                INDArray gradArr = gradientsForAct.get(s);
                long n = gradArr.length();
                long[] shape = gradArr.shape();
                if (config.isPrint()) {
                    log.info("Starting test for variable \"{}\" with {} values", (Object) s, (Object) n);
                }

                Iterator<long[]> iter;
                if (maxPerParam > 0 && subset != null && (long) maxPerParam < n) {
                    List<long[]> positions = new ArrayList<long[]>();
                    if (subset == Subset.RANDOM) {
                        Set<Integer> chosen = new HashSet<Integer>();
                        while (chosen.size() < maxPerParam) {
                            chosen.add(r.nextInt((int) n));
                        }
                        List<Integer> sorted = new ArrayList<Integer>(chosen);
                        Collections.sort(sorted);
                        for (Integer flat : sorted) {
                            positions.add(Shape.ind2subC(shape, (long) flat.intValue()));
                        }
                    } else {
                        // maxPerParam < n, so the stride is at least 1
                        long everyN = n / (long) maxPerParam;
                        for (long curr = 0L; curr < n; curr += everyN) {
                            positions.add(Shape.ind2subC(shape, curr));
                        }
                    }
                    iter = positions.iterator();
                } else {
                    iter = new NdIndexIterator('c', shape);
                }

                INDArray varMask = config.getGradCheckMask() == null ? null : config.getGradCheckMask().get(s);
                listener.setVariableName(s);
                int i = 0;
                while (iter.hasNext()) {
                    long[] idx = iter.next();
                    if (varMask != null && varMask.getDouble(idx) == 0.0) {
                        continue;
                    }
                    String strIdx = Arrays.toString(idx).replace(" ", "");
                    ++totalCount;

                    listener.setIdx(idx);
                    listener.setEps(eps);
                    double scorePlus = 0.0;
                    for (INDArray arr : sd.output(placeholderValues, lossFnVariables).values()) {
                        scorePlus += arr.sumNumber().doubleValue();
                    }
                    listener.setEps(-eps);
                    double scoreMinus = 0.0;
                    for (INDArray arr : sd.output(placeholderValues, lossFnVariables).values()) {
                        scoreMinus += arr.sumNumber().doubleValue();
                    }

                    double numericalGrad = (scorePlus - scoreMinus) / (2.0 * eps);
                    double analyticGrad = gradArr.getDouble(idx);
                    if (Double.isInfinite(numericalGrad) || Double.isNaN(numericalGrad)) {
                        throw new IllegalStateException("Numerical gradient was " + numericalGrad + " for variable \"" + s + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
                    }
                    if (Double.isInfinite(analyticGrad) || Double.isNaN(analyticGrad)) {
                        throw new IllegalStateException("Analytic (SameDiff) gradient was " + analyticGrad + " for variable \"" + s + "\", parameter " + i + " of " + n + " (position: " + strIdx + ")");
                    }
                    double relError = (numericalGrad == 0.0 && analyticGrad == 0.0)
                            ? 0.0
                            : Math.abs(analyticGrad - numericalGrad) / (Math.abs(analyticGrad) + Math.abs(numericalGrad));
                    if (relError > maxError) {
                        maxError = relError;
                    }
                    if (relError > maxRelError || Double.isNaN(relError)) {
                        double absError = Math.abs(analyticGrad - numericalGrad);
                        if (absError < minAbsError) {
                            if (config.isPrint()) {
                                log.info("Param " + i + " (" + s + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + "; absolute error = " + absError + " < minAbsoluteError = " + minAbsError);
                            }
                        } else {
                            // Failures are always reported, consistent with checkGradients
                            log.info("Param " + i + " (" + s + strIdx + ") FAILED: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError + ", absError=" + absError + ", scorePlus=" + scorePlus + ", scoreMinus= " + scoreMinus);
                            if (config.isExitOnFirstFailure()) {
                                return false;
                            }
                            ++totalNFailures;
                        }
                    } else if (config.isPrint()) {
                        log.info("Param " + i + " (" + s + strIdx + ") passed: grad= " + analyticGrad + ", numericalGrad= " + numericalGrad + ", relError= " + relError);
                    }
                    ++i;
                }
            }

            int nPass = totalCount - totalNFailures;
            log.info("GradCheckUtil.checkActivationGradients(): " + totalCount + " params checked, " + nPass + " passed, " + totalNFailures + " failed. Largest relative error = " + maxError);
            return totalNFailures == 0;
        } finally {
            // Detach the perturbing listener (it would otherwise keep altering activations) and restore
            // the caller's listeners and debug setting on every exit path
            sd.getListeners().clear();
            sd.getListeners().addAll(listenersBefore);
            if (config.isDebugMode() && !debugBefore) {
                sd.disableDebugging();
            }
        }
    }

    /**
     * Sanity-checks the bookkeeping of a SameDiff instance. It checks that:
     * <ul>
     *   <li>variable names are unique;</li>
     *   <li>{@code ops()} and {@code getOps()} agree in size;</li>
     *   <li>every op is present in the ops map;</li>
     *   <li>op inputs and outputs are known variables;</li>
     *   <li>each variable is the output of at most one op;</li>
     *   <li>the variable map is keyed by variable name.</li>
     * </ul>
     *
     * @param sd                     instance to validate
     * @param generateAndCheckGradFn if true, also create the "grad" function when absent, validate it
     *                               (without recursing further), and check that it contains every op of
     *                               {@code sd}
     * @throws IllegalStateException if a check fails (precondition helpers may raise other runtime exceptions)
     */
    public static void validateInternalState(SameDiff sd, boolean generateAndCheckGradFn) {
        DifferentialFunction[] dfs = sd.ops();
        List<SDVariable> vars = sd.variables();

        Set<String> varSetStr = new HashSet<String>();
        for (SDVariable v : vars) {
            // add() returns false when the name was already present
            if (!varSetStr.add(v.name())) {
                throw new IllegalStateException("Variable with name " + v.name() + " already encountered");
            }
        }

        Map<String, SameDiffOp> ops = sd.getOps();
        Preconditions.checkState(dfs.length == ops.size(), "Number of ops returned by ops() (%s) does not match size of ops map (%s)", (Object) dfs.length, (Object) ops.size());
        for (DifferentialFunction df : dfs) {
            String opName = df.getOwnName();
            SameDiffOp sameDiffOp = ops.get(opName);
            Preconditions.checkState(sameDiffOp != null, "%s not present in ops map", (Object) opName);
            List<String> inputs = sameDiffOp.getInputsToOp();
            if (inputs != null) {
                for (String s : inputs) {
                    Preconditions.checkState(varSetStr.contains(s), "Variable %s in op inputs not a known variable name", (Object) s);
                }
            }
            List<String> outputs = sameDiffOp.getOutputsOfOp();
            if (outputs != null) {
                for (String s : outputs) {
                    Preconditions.checkState(varSetStr.contains(s), "Variable %s in op outputs not a known variable name", (Object) s);
                }
            }
        }

        // Each variable may be produced by at most one op
        Map<String, String> producerOf = new HashMap<String, String>();
        for (Map.Entry<String, SameDiffOp> e : ops.entrySet()) {
            List<String> varNames = e.getValue().getOutputsOfOp();
            if (varNames == null) {
                continue;
            }
            for (String s : varNames) {
                String previousOp = producerOf.get(s);
                if (previousOp != null) {
                    throw new IllegalStateException("Already saw variable \"" + s + "\" as output for op \"" + previousOp + "\": expected variables to be present as an output only once; also seen as output for op \"" + e.getKey() + "\"");
                }
                producerOf.put(s, e.getKey());
            }
        }

        PatriciaTrie<Variable> variableMap = sd.getVariables();
        Preconditions.checkState(vars.size() == variableMap.size(), "Variable map size check failed");
        for (Map.Entry<String, Variable> e : variableMap.entrySet()) {
            String key = e.getKey();
            String varName = e.getValue().getVariable().name();
            Preconditions.checkState(key.equals(varName), "Name not equal: variable map key \"%s\" vs. variable name \"%s\"", (Object) key, (Object) varName);
        }

        if (generateAndCheckGradFn) {
            if (sd.getFunction("grad") == null) {
                sd.createGradFunction();
            }
            SameDiff gradFn = sd.getFunction("grad");
            GradCheckUtil.validateInternalState(gradFn, false);
            for (DifferentialFunction dfOrig : dfs) {
                Preconditions.checkNotNull(gradFn.getOpById(dfOrig.getOwnName()), "DifferentialFunction %s from original SameDiff instance not present in grad fn", (Object) dfOrig.getOwnName());
            }
        }
    }

    /**
     * Reads the value of a declared (possibly private) field via reflection.
     *
     * @param fieldName name of a field declared directly on {@code fromClass} (inherited fields are not found)
     * @param from      instance to read from; may be null for a static field
     * @param fromClass class declaring the field
     * @param <T>       expected field type; the cast is unchecked, so a mismatch surfaces at the call site
     * @return the field value
     * @throws RuntimeException wrapping the cause if the field does not exist or cannot be accessed
     */
    @SuppressWarnings("unchecked")
    private static <T> T getObject(String fieldName, Object from, Class<?> fromClass) {
        try {
            Field f = fromClass.getDeclaredField(fieldName);
            f.setAccessible(true);
            return (T) f.get(from);
        } catch (NoSuchFieldException | IllegalAccessException e) {
            throw new RuntimeException("Could not read field \"" + fieldName + "\" of class " + fromClass.getName(), e);
        }
    }

    public static class ActGradConfig {
        // SameDiff instance whose activation gradients are checked; must be non-null
        private SameDiff sd;
        // Placeholder values passed to calculateGradients and to every perturbed forward pass
        private Map<String, INDArray> placeholderValues;
        // Names of the activations to check: must be non-empty, and each must be a floating-point ARRAY variable
        private List<String> activationGradsToCheck;
        // Perturbation size for the central difference; must be > 0
        private double eps;
        // Maximum relative error for an element to pass; must be > 0
        private double maxRelError;
        // Elements failing maxRelError still pass when their absolute error is below this value
        private double minAbsError;
        // Log per-element results
        private boolean print;
        // Return false at the first failing element
        // NOTE(review): package-private unlike most fields - confirm whether other classes rely on this
        boolean exitOnFirstFailure;
        // Skip GradCheckUtil.validateInternalState before checking
        private boolean skipValidation;
        // Enable SameDiff debug mode during the check
        private boolean debugMode;
        // NOTE(review): not read by checkActivationGradients - confirm whether it is used elsewhere
        private Set<String> skipVariables;
        // Optional per-activation masks; elements whose mask value is 0 are not checked
        private Map<String, INDArray> gradCheckMask;
        // When > 0 (and subset is set), limits how many elements per activation are checked
        // NOTE(review): package-private unlike most fields - confirm whether other classes rely on this
        int maxPerParam;
        // Element selection strategy used together with maxPerParam
        private Subset subset;

        /** Default {@code eps}; same value as {@link GradCheckUtil#DEFAULT_EPS}. */
        private static double $default$eps() {
            return DEFAULT_EPS;
        }

        /** Default {@code maxRelError}; same value as {@link GradCheckUtil#DEFAULT_MAX_REL_ERROR}. */
        private static double $default$maxRelError() {
            return DEFAULT_MAX_REL_ERROR;
        }

        /** Default {@code minAbsError}; same value as {@link GradCheckUtil#DEFAULT_MIN_ABS_ERROR}. */
        private static double $default$minAbsError() {
            return DEFAULT_MIN_ABS_ERROR;
        }

        /** Default {@code print}; same value as {@link GradCheckUtil#DEFAULT_PRINT}. */
        private static boolean $default$print() {
            return DEFAULT_PRINT;
        }

        /** Default {@code exitOnFirstFailure}; same value as {@link GradCheckUtil#DEFAULT_EXIT_FIRST_FAILURE}. */
        private static boolean $default$exitOnFirstFailure() {
            return DEFAULT_EXIT_FIRST_FAILURE;
        }

        /** Default {@code skipValidation}: validation is performed unless requested otherwise. */
        private static boolean $default$skipValidation() {
            return false;
        }

        /** Default {@code debugMode}; same value as {@link GradCheckUtil#DEFAULT_DEBUG_MODE}. */
        private static boolean $default$debugMode() {
            return DEFAULT_DEBUG_MODE;
        }

        /**
         * All-arguments constructor. Each argument is stored verbatim; no validation is performed here
         * (checkActivationGradients validates the configuration). Presumably invoked by
         * {@code ActGradConfigBuilder} - TODO confirm.
         */
        ActGradConfig(SameDiff sd, Map<String, INDArray> placeholderValues, List<String> activationGradsToCheck, double eps, double maxRelError, double minAbsError, boolean print, boolean exitOnFirstFailure, boolean skipValidation, boolean debugMode, Set<String> skipVariables, Map<String, INDArray> gradCheckMask, int maxPerParam, Subset subset) {
            // What to check
            this.sd = sd;
            this.activationGradsToCheck = activationGradsToCheck;
            this.placeholderValues = placeholderValues;
            this.gradCheckMask = gradCheckMask;
            this.skipVariables = skipVariables;
            // Numerical settings
            this.eps = eps;
            this.maxRelError = maxRelError;
            this.minAbsError = minAbsError;
            this.maxPerParam = maxPerParam;
            this.subset = subset;
            // Behaviour flags
            this.print = print;
            this.exitOnFirstFailure = exitOnFirstFailure;
            this.skipValidation = skipValidation;
            this.debugMode = debugMode;
        }

        /**
         * Creates a new, empty builder for an {@link ActGradConfig}.
         *
         * @return a fresh builder instance
         */
        public static ActGradConfigBuilder builder() {
            ActGradConfigBuilder b = new ActGradConfigBuilder();
            return b;
        }

        /** @return the SameDiff instance to check */
        public SameDiff getSd() {
            return sd;
        }

        /** @return placeholder values used for every pass */
        public Map<String, INDArray> getPlaceholderValues() {
            return placeholderValues;
        }

        /** @return names of the activation variables to check */
        public List<String> getActivationGradsToCheck() {
            return activationGradsToCheck;
        }

        /** @return perturbation size for the central difference */
        public double getEps() {
            return eps;
        }

        /** @return maximum relative error for an element to pass */
        public double getMaxRelError() {
            return maxRelError;
        }

        /** @return absolute error below which an element passes regardless of relative error */
        public double getMinAbsError() {
            return minAbsError;
        }

        /** @return whether per-element results are logged */
        public boolean isPrint() {
            return print;
        }

        /** @return whether checking stops at the first failure */
        public boolean isExitOnFirstFailure() {
            return exitOnFirstFailure;
        }

        /** @return whether internal-state validation is skipped */
        public boolean isSkipValidation() {
            return skipValidation;
        }

        /** @return whether SameDiff debug mode is enabled during the check */
        public boolean isDebugMode() {
            return debugMode;
        }

        /** @return variables configured to be skipped; may be null */
        public Set<String> getSkipVariables() {
            return skipVariables;
        }

        /** @return optional per-variable masks; may be null */
        public Map<String, INDArray> getGradCheckMask() {
            return gradCheckMask;
        }

        /** @return per-variable element limit (values &lt;= 0 mean no limit) */
        public int getMaxPerParam() {
            return maxPerParam;
        }

        /** @return element selection strategy used with maxPerParam; may be null */
        public Subset getSubset() {
            return subset;
        }

        /** @param sd SameDiff instance to check */
        public void setSd(SameDiff sd) { this.sd = sd; }

        /** @param placeholderValues placeholder values used for every pass */
        public void setPlaceholderValues(Map<String, INDArray> placeholderValues) { this.placeholderValues = placeholderValues; }

        /** @param activationGradsToCheck names of the activation variables to check */
        public void setActivationGradsToCheck(List<String> activationGradsToCheck) { this.activationGradsToCheck = activationGradsToCheck; }

        /** @param eps perturbation size for the central difference */
        public void setEps(double eps) { this.eps = eps; }

        /** @param maxRelError maximum relative error for an element to pass */
        public void setMaxRelError(double maxRelError) { this.maxRelError = maxRelError; }

        /** @param minAbsError absolute error below which an element passes regardless of relative error */
        public void setMinAbsError(double minAbsError) { this.minAbsError = minAbsError; }

        /** @param print whether to log per-element results */
        public void setPrint(boolean print) { this.print = print; }

        /** @param exitOnFirstFailure whether to stop at the first failure */
        public void setExitOnFirstFailure(boolean exitOnFirstFailure) { this.exitOnFirstFailure = exitOnFirstFailure; }

        /** @param skipValidation whether to skip internal-state validation */
        public void setSkipValidation(boolean skipValidation) { this.skipValidation = skipValidation; }

        /** @param debugMode whether to enable SameDiff debug mode during the check */
        public void setDebugMode(boolean debugMode) { this.debugMode = debugMode; }

        /** @param skipVariables variables to skip; may be null */
        public void setSkipVariables(Set<String> skipVariables) { this.skipVariables = skipVariables; }

        /** @param gradCheckMask optional per-variable masks; may be null */
        public void setGradCheckMask(Map<String, INDArray> gradCheckMask) { this.gradCheckMask = gradCheckMask; }

        /** @param maxPerParam per-variable element limit (values &lt;= 0 mean no limit) */
        public void setMaxPerParam(int maxPerParam) { this.maxPerParam = maxPerParam; }

        /** @param subset element selection strategy used with maxPerParam */
        public void setSubset(Subset subset) { this.subset = subset; }

        /**
         * Field-by-field value equality (Lombok {@code @Data} semantics).
         *
         * <p>Two configs are equal when {@code o} is an {@code ActGradConfig}, {@code o.canEqual(this)}
         * holds, and every field matches. The three double fields are compared with
         * {@link Double#compare}, so {@code NaN} equals {@code NaN} and {@code 0.0} differs from
         * {@code -0.0}. Reference fields are compared null-safely through their own {@code equals}.
         *
         * @param o the object to compare against; may be {@code null}
         * @return {@code true} if {@code o} is an equal {@code ActGradConfig}
         */
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof ActGradConfig)) {
                return false;
            }
            final ActGradConfig that = (ActGradConfig) o;
            if (!that.canEqual(this)) {
                return false;
            }

            // Cheap primitive comparisons first, in declaration order.
            final boolean scalarsMatch =
                    Double.compare(getEps(), that.getEps()) == 0
                    && Double.compare(getMaxRelError(), that.getMaxRelError()) == 0
                    && Double.compare(getMinAbsError(), that.getMinAbsError()) == 0
                    && isPrint() == that.isPrint()
                    && isExitOnFirstFailure() == that.isExitOnFirstFailure()
                    && isSkipValidation() == that.isSkipValidation()
                    && isDebugMode() == that.isDebugMode()
                    && getMaxPerParam() == that.getMaxPerParam();
            if (!scalarsMatch) {
                return false;
            }

            return nullSafeFieldEquals(getSd(), that.getSd())
                    && nullSafeFieldEquals(getPlaceholderValues(), that.getPlaceholderValues())
                    && nullSafeFieldEquals(getActivationGradsToCheck(), that.getActivationGradsToCheck())
                    && nullSafeFieldEquals(getSkipVariables(), that.getSkipVariables())
                    && nullSafeFieldEquals(getGradCheckMask(), that.getGradCheckMask())
                    && nullSafeFieldEquals(getSubset(), that.getSubset());
        }

        /**
         * Null-safe equality used by {@link #equals(Object)}.
         *
         * <p>Unlike {@code java.util.Objects.equals}, this never short-circuits on reference
         * identity: whenever {@code left} is non-null it always delegates to
         * {@code left.equals(right)}. This matters for values whose {@code equals} is not
         * reflexive (e.g. arrays holding NaN).
         */
        private static boolean nullSafeFieldEquals(Object left, Object right) {
            return left == null ? right == null : left.equals(right);
        }

        /**
         * Symmetry hook for {@link #equals(Object)}: reports whether {@code other} belongs to a type
         * that may be considered equal to an {@code ActGradConfig}.
         *
         * @param other the candidate object; {@code null} yields {@code false}
         * @return {@code true} if {@code other} is an {@code ActGradConfig} (or subclass) instance
         */
        protected boolean canEqual(Object other) {
            return ActGradConfig.class.isInstance(other);
        }

        /**
         * Hash code consistent with {@link #equals(Object)} (Lombok {@code @Data} scheme).
         *
         * <p>Starting from 1, each field is folded in with multiplier 59, in declaration order of
         * the comparisons in {@code equals}:
         * <ul>
         *   <li>doubles contribute {@link Double#hashCode(double)} (the high/low bit XOR fold);</li>
         *   <li>booleans contribute 79 for {@code true} and 97 for {@code false};</li>
         *   <li>{@code maxPerParam} contributes its raw value;</li>
         *   <li>reference fields contribute their {@code hashCode()}, or 43 when {@code null}.</li>
         * </ul>
         *
         * @return the combined hash
         */
        public int hashCode() {
            final int PRIME = 59;
            int acc = 1;
            acc = acc * PRIME + Double.hashCode(getEps());
            acc = acc * PRIME + Double.hashCode(getMaxRelError());
            acc = acc * PRIME + Double.hashCode(getMinAbsError());
            acc = acc * PRIME + (isPrint() ? 79 : 97);
            acc = acc * PRIME + (isExitOnFirstFailure() ? 79 : 97);
            acc = acc * PRIME + (isSkipValidation() ? 79 : 97);
            acc = acc * PRIME + (isDebugMode() ? 79 : 97);
            acc = acc * PRIME + getMaxPerParam();

            final Object[] refFields = {
                    getSd(),
                    getPlaceholderValues(),
                    getActivationGradsToCheck(),
                    getSkipVariables(),
                    getGradCheckMask(),
                    getSubset()
            };
            for (Object field : refFields) {
                acc = acc * PRIME + (field == null ? 43 : field.hashCode());
            }
            return acc;
        }

        /**
         * Renders every field as {@code GradCheckUtil.ActGradConfig(name=value, ...)}.
         *
         * <p>Reference fields are rendered through their own {@code toString()} (or {@code "null"}),
         * so the output includes the full contents of any placeholder/mask arrays.
         *
         * @return a human-readable description of this config
         */
        public String toString() {
            return new StringBuilder("GradCheckUtil.ActGradConfig(")
                    .append("sd=").append(getSd())
                    .append(", placeholderValues=").append(getPlaceholderValues())
                    .append(", activationGradsToCheck=").append(getActivationGradsToCheck())
                    .append(", eps=").append(getEps())
                    .append(", maxRelError=").append(getMaxRelError())
                    .append(", minAbsError=").append(getMinAbsError())
                    .append(", print=").append(isPrint())
                    .append(", exitOnFirstFailure=").append(isExitOnFirstFailure())
                    .append(", skipValidation=").append(isSkipValidation())
                    .append(", debugMode=").append(isDebugMode())
                    .append(", skipVariables=").append(getSkipVariables())
                    .append(", gradCheckMask=").append(getGradCheckMask())
                    .append(", maxPerParam=").append(getMaxPerParam())
                    .append(", subset=").append(getSubset())
                    .append(')')
                    .toString();
        }

        /**
         * Fluent builder for {@link ActGradConfig}.
         *
         * <p>The options {@code eps}, {@code maxRelError}, {@code minAbsError}, {@code print},
         * {@code exitOnFirstFailure}, {@code skipValidation} and {@code debugMode} each track a
         * {@code $set} flag. If an option was never set, {@link #build()} falls back to the
         * matching {@code ActGradConfig.$default$...()} method.
         *
         * <p>The remaining options have no {@code $default$} fallback and are passed through
         * unchanged. Unless set, they are {@code null}; {@code maxPerParam} is {@code 0}.
         */
        public static class ActGradConfigBuilder {
            private SameDiff sd;
            private Map<String, INDArray> placeholderValues;
            private List<String> activationGradsToCheck;

            // Options with a Lombok-style "$set" flag so build() can tell "unset" from "set to 0/false".
            private boolean eps$set;
            private double eps$value;
            private boolean maxRelError$set;
            private double maxRelError$value;
            private boolean minAbsError$set;
            private double minAbsError$value;
            private boolean print$set;
            private boolean print$value;
            private boolean exitOnFirstFailure$set;
            private boolean exitOnFirstFailure$value;
            private boolean skipValidation$set;
            private boolean skipValidation$value;
            private boolean debugMode$set;
            private boolean debugMode$value;

            private Set<String> skipVariables;
            private Map<String, INDArray> gradCheckMask;
            private int maxPerParam;
            private Subset subset;

            /** Package-private: obtain instances through the owning class rather than directly. */
            ActGradConfigBuilder() {
            }

            /** Sets {@code sd}; returns this builder for chaining. */
            public ActGradConfigBuilder sd(final SameDiff sd) {
                this.sd = sd;
                return this;
            }

            /** Sets {@code placeholderValues} (reference kept, not copied); returns this builder. */
            public ActGradConfigBuilder placeholderValues(final Map<String, INDArray> placeholderValues) {
                this.placeholderValues = placeholderValues;
                return this;
            }

            /** Sets {@code activationGradsToCheck} (reference kept, not copied); returns this builder. */
            public ActGradConfigBuilder activationGradsToCheck(final List<String> activationGradsToCheck) {
                this.activationGradsToCheck = activationGradsToCheck;
                return this;
            }

            /** Sets {@code eps} and marks it explicitly set; returns this builder. */
            public ActGradConfigBuilder eps(final double eps) {
                this.eps$set = true;
                this.eps$value = eps;
                return this;
            }

            /** Sets {@code maxRelError} and marks it explicitly set; returns this builder. */
            public ActGradConfigBuilder maxRelError(final double maxRelError) {
                this.maxRelError$set = true;
                this.maxRelError$value = maxRelError;
                return this;
            }

            /** Sets {@code minAbsError} and marks it explicitly set; returns this builder. */
            public ActGradConfigBuilder minAbsError(final double minAbsError) {
                this.minAbsError$set = true;
                this.minAbsError$value = minAbsError;
                return this;
            }

            /** Sets {@code print} and marks it explicitly set; returns this builder. */
            public ActGradConfigBuilder print(final boolean print) {
                this.print$set = true;
                this.print$value = print;
                return this;
            }

            /** Sets {@code exitOnFirstFailure} and marks it explicitly set; returns this builder. */
            public ActGradConfigBuilder exitOnFirstFailure(final boolean exitOnFirstFailure) {
                this.exitOnFirstFailure$set = true;
                this.exitOnFirstFailure$value = exitOnFirstFailure;
                return this;
            }

            /** Sets {@code skipValidation} and marks it explicitly set; returns this builder. */
            public ActGradConfigBuilder skipValidation(final boolean skipValidation) {
                this.skipValidation$set = true;
                this.skipValidation$value = skipValidation;
                return this;
            }

            /** Sets {@code debugMode} and marks it explicitly set; returns this builder. */
            public ActGradConfigBuilder debugMode(final boolean debugMode) {
                this.debugMode$set = true;
                this.debugMode$value = debugMode;
                return this;
            }

            /** Sets {@code skipVariables} (reference kept, not copied); returns this builder. */
            public ActGradConfigBuilder skipVariables(final Set<String> skipVariables) {
                this.skipVariables = skipVariables;
                return this;
            }

            /** Sets {@code gradCheckMask} (reference kept, not copied); returns this builder. */
            public ActGradConfigBuilder gradCheckMask(final Map<String, INDArray> gradCheckMask) {
                this.gradCheckMask = gradCheckMask;
                return this;
            }

            /** Sets {@code maxPerParam}; returns this builder. */
            public ActGradConfigBuilder maxPerParam(final int maxPerParam) {
                this.maxPerParam = maxPerParam;
                return this;
            }

            /** Sets {@code subset}; returns this builder. */
            public ActGradConfigBuilder subset(final Subset subset) {
                this.subset = subset;
                return this;
            }

            /**
             * Creates a new {@link ActGradConfig} from the current builder state.
             *
             * <p>Each flagged option that was never set is resolved through its
             * {@code ActGradConfig.$default$...()} method. These are evaluated in declaration order
             * and only for unset options.
             *
             * @return a freshly constructed config; the builder can be reused afterwards
             */
            public ActGradConfig build() {
                final double resolvedEps = eps$set ? eps$value : ActGradConfig.$default$eps();
                final double resolvedMaxRelError =
                        maxRelError$set ? maxRelError$value : ActGradConfig.$default$maxRelError();
                final double resolvedMinAbsError =
                        minAbsError$set ? minAbsError$value : ActGradConfig.$default$minAbsError();
                final boolean resolvedPrint = print$set ? print$value : ActGradConfig.$default$print();
                final boolean resolvedExitOnFirstFailure = exitOnFirstFailure$set
                        ? exitOnFirstFailure$value
                        : ActGradConfig.$default$exitOnFirstFailure();
                final boolean resolvedSkipValidation = skipValidation$set
                        ? skipValidation$value
                        : ActGradConfig.$default$skipValidation();
                final boolean resolvedDebugMode =
                        debugMode$set ? debugMode$value : ActGradConfig.$default$debugMode();

                return new ActGradConfig(
                        sd,
                        placeholderValues,
                        activationGradsToCheck,
                        resolvedEps,
                        resolvedMaxRelError,
                        resolvedMinAbsError,
                        resolvedPrint,
                        resolvedExitOnFirstFailure,
                        resolvedSkipValidation,
                        resolvedDebugMode,
                        skipVariables,
                        gradCheckMask,
                        maxPerParam,
                        subset);
            }

            /**
             * Describes the raw builder state. Flagged options are shown by their stored
             * {@code $value}, not by the default they would resolve to.
             */
            public String toString() {
                return new StringBuilder("GradCheckUtil.ActGradConfig.ActGradConfigBuilder(")
                        .append("sd=").append(sd)
                        .append(", placeholderValues=").append(placeholderValues)
                        .append(", activationGradsToCheck=").append(activationGradsToCheck)
                        .append(", eps$value=").append(eps$value)
                        .append(", maxRelError$value=").append(maxRelError$value)
                        .append(", minAbsError$value=").append(minAbsError$value)
                        .append(", print$value=").append(print$value)
                        .append(", exitOnFirstFailure$value=").append(exitOnFirstFailure$value)
                        .append(", skipValidation$value=").append(skipValidation$value)
                        .append(", debugMode$value=").append(debugMode$value)
                        .append(", skipVariables=").append(skipVariables)
                        .append(", gradCheckMask=").append(gradCheckMask)
                        .append(", maxPerParam=").append(maxPerParam)
                        .append(", subset=").append(subset)
                        .append(')')
                        .toString();
            }
        }
    }

    /**
     * Selection strategy with two constants, {@link #EVERY_N} and {@link #RANDOM}.
     *
     * <p>NOTE(review): presumably chooses how a subset of entries is picked when a per-parameter
     * limit ({@code maxPerParam}) applies — confirm against the usage site.
     */
    public enum Subset {
        EVERY_N,
        RANDOM
    }
}

