/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.validation;

import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.validation.functions.EqualityFn;
import org.nd4j.autodiff.validation.functions.RelErrorFn;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.function.Function;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Fluent configuration object describing a single SameDiff validation test case.
 *
 * <p>A test case holds the {@link SameDiff} instance under test, optional placeholder values,
 * optional forward-pass validation functions (keyed by variable name), and gradient-check
 * settings. All setters return {@code this} so calls can be chained.
 *
 * <p>This class is mutable and has no synchronization.
 */
public class TestCase {
    /** Default for {@link #gradCheckPrint()}. */
    public static final boolean GC_DEFAULT_PRINT = false;
    /** Default for {@link #gradCheckDefaultExitFirstFailure()}. */
    public static final boolean GC_DEFAULT_EXIT_FIRST_FAILURE = false;
    /** Default for {@link #gradCheckDebugMode()}. */
    public static final boolean GC_DEFAULT_DEBUG_MODE = false;
    /** Default for {@link #gradCheckEpsilon()}. */
    public static final double GC_DEFAULT_EPS = 1.0E-5;
    /** Default for {@link #gradCheckMaxRelativeError()}. */
    public static final double GC_DEFAULT_MAX_REL_ERROR = 1.0E-5;
    /** Default for {@link #gradCheckMinAbsError()}. */
    public static final double GC_DEFAULT_MIN_ABS_ERROR = 1.0E-6;

    /** Tolerance passed to {@link EqualityFn} by {@link #expectedOutput(String, INDArray)}. */
    private static final double DEFAULT_OUTPUT_EPS = 0.001;

    private SameDiff sameDiff;
    private String testName;
    private Map<String, Function<INDArray, String>> fwdTestFns;
    private Map<String, INDArray> placeholderValues;
    private boolean gradientCheck = true;
    // Initialize from the public GC_DEFAULT_* constants so the advertised defaults
    // and the actual field defaults can never drift apart.
    private boolean gradCheckPrint = GC_DEFAULT_PRINT;
    private boolean gradCheckDefaultExitFirstFailure = GC_DEFAULT_EXIT_FIRST_FAILURE;
    private boolean gradCheckDebugMode = GC_DEFAULT_DEBUG_MODE;
    private double gradCheckEpsilon = GC_DEFAULT_EPS;
    private double gradCheckMaxRelativeError = GC_DEFAULT_MAX_REL_ERROR;
    private double gradCheckMinAbsError = GC_DEFAULT_MIN_ABS_ERROR;
    private Set<String> gradCheckSkipVariables;
    private Map<String, INDArray> gradCheckMask;
    private TestSerialization testFlatBufferSerialization = TestSerialization.BOTH;

    /**
     * Creates a test case for the given SameDiff instance.
     *
     * @param sameDiff the graph under test; validated later by {@link #assertConfigValid()}
     */
    public TestCase(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
    }

    /**
     * Registers an {@link EqualityFn} built from {@code expected} and {@code eps} as the
     * forward-pass validation function for variable {@code name}.
     *
     * @param name     variable name
     * @param expected expected output array
     * @param eps      tolerance handed to {@link EqualityFn}
     * @return this test case
     * @throws NullPointerException if {@code name} or {@code expected} is null
     */
    public TestCase expectedOutput(@NonNull String name, @NonNull INDArray expected, double eps) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (expected == null) {
            throw new NullPointerException("expected is marked non-null but is null");
        }
        return this.expected(name, (Function<INDArray, String>) new EqualityFn(expected, eps));
    }

    /**
     * Same as {@link #expectedOutput(String, INDArray, double)} with a tolerance of {@code 0.001}.
     *
     * @param name     variable name
     * @param expected expected output array
     * @return this test case
     * @throws NullPointerException if {@code name} or {@code expected} is null
     */
    public TestCase expectedOutput(@NonNull String name, @NonNull INDArray expected) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (expected == null) {
            throw new NullPointerException("expected is marked non-null but is null");
        }
        return this.expectedOutput(name, expected, DEFAULT_OUTPUT_EPS);
    }

    /**
     * Registers a {@link RelErrorFn} built from {@code expected}, {@code maxRelError} and
     * {@code minAbsError} as the forward-pass validation function for variable {@code name}.
     *
     * @param name        variable name
     * @param expected    expected output array
     * @param maxRelError maximum relative error handed to {@link RelErrorFn}
     * @param minAbsError minimum absolute error handed to {@link RelErrorFn}
     * @return this test case
     * @throws NullPointerException if {@code name} or {@code expected} is null
     */
    public TestCase expectedOutputRelError(@NonNull String name, @NonNull INDArray expected, double maxRelError, double minAbsError) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (expected == null) {
            throw new NullPointerException("expected is marked non-null but is null");
        }
        return this.expected(name, (Function<INDArray, String>) new RelErrorFn(expected, maxRelError, minAbsError));
    }

    /**
     * Registers an expected output for {@code var}, using its name; delegates to
     * {@link #expected(String, INDArray)}.
     *
     * @param var    variable whose output is checked
     * @param output expected output array
     * @return this test case
     * @throws NullPointerException if {@code var} or {@code output} is null
     */
    public TestCase expected(@NonNull SDVariable var, @NonNull INDArray output) {
        if (var == null) {
            throw new NullPointerException("var is marked non-null but is null");
        }
        if (output == null) {
            throw new NullPointerException("output is marked non-null but is null");
        }
        return this.expected(var.name(), output);
    }

    /**
     * Registers an expected output for variable {@code name}; delegates to
     * {@link #expectedOutput(String, INDArray)}.
     *
     * @param name   variable name
     * @param output expected output array
     * @return this test case
     * @throws NullPointerException if {@code name} or {@code output} is null
     */
    public TestCase expected(@NonNull String name, @NonNull INDArray output) {
        if (name == null) {
            throw new NullPointerException("name is marked non-null but is null");
        }
        if (output == null) {
            throw new NullPointerException("output is marked non-null but is null");
        }
        return this.expectedOutput(name, output);
    }

    /**
     * Registers a custom validation function for {@code var}, using its name.
     *
     * @param var          variable whose output is checked (dereferenced immediately)
     * @param validationFn function applied to the variable's output
     * @return this test case
     */
    public TestCase expected(SDVariable var, Function<INDArray, String> validationFn) {
        return this.expected(var.name(), validationFn);
    }

    /**
     * Registers a custom validation function for variable {@code name}. A later registration for
     * the same name replaces the earlier one; registration order is preserved.
     *
     * <p>NOTE(review): the meaning of the returned {@code String} is defined by the validator that
     * consumes {@link #fwdTestFns()} (presumably null on success, an error message otherwise) —
     * confirm there.
     *
     * @param name         variable name
     * @param validationFn function applied to the variable's output
     * @return this test case
     */
    public TestCase expected(String name, Function<INDArray, String> validationFn) {
        if (this.fwdTestFns == null) {
            // LinkedHashMap keeps the order in which expectations were registered.
            this.fwdTestFns = new LinkedHashMap<String, Function<INDArray, String>>();
        }
        this.fwdTestFns.put(name, validationFn);
        return this;
    }

    /** @return the variable names excluded from gradient checking, or null if none were set */
    public Set<String> gradCheckSkipVariables() {
        return this.gradCheckSkipVariables;
    }

    /** @return the gradient-check mask map, or null if none was set */
    public Map<String, INDArray> gradCheckMask() {
        return this.gradCheckMask;
    }

    /**
     * Adds variable names to exclude from gradient checking. Names accumulate across calls in
     * insertion order; duplicates are ignored.
     *
     * @param toSkip variable names to skip
     * @return this test case
     */
    public TestCase gradCheckSkipVariables(String... toSkip) {
        if (this.gradCheckSkipVariables == null) {
            this.gradCheckSkipVariables = new LinkedHashSet<String>();
        }
        Collections.addAll(this.gradCheckSkipVariables, toSkip);
        return this;
    }

    /**
     * Replaces the placeholder value map (the map is stored, not copied).
     *
     * @param placeholderValues placeholder name to value
     * @return this test case
     */
    public TestCase placeholderValues(Map<String, INDArray> placeholderValues) {
        this.placeholderValues = placeholderValues;
        return this;
    }

    /**
     * Sets a single placeholder value, creating the placeholder map if needed.
     *
     * @param variable placeholder name
     * @param value    placeholder value
     * @return this test case
     */
    public TestCase placeholderValue(String variable, INDArray value) {
        if (this.placeholderValues == null) {
            this.placeholderValues = new HashMap<String, INDArray>();
        }
        this.placeholderValues.put(variable, value);
        return this;
    }

    /**
     * Validates that this test case is runnable: a SameDiff instance is present, and either
     * gradient checking is enabled or at least one forward validation function is registered.
     *
     * @throws RuntimeException as raised by {@link Preconditions} when a check fails
     */
    public void assertConfigValid() {
        Preconditions.checkNotNull((Object) this.sameDiff, "SameDiff instance cannot be null%s", (Object) this.testNameErrMsg());
        boolean hasForwardChecks = this.fwdTestFns != null && !this.fwdTestFns.isEmpty();
        Preconditions.checkState(this.gradientCheck || hasForwardChecks, "Test case is empty: nothing to test (gradientCheck == false and no expected results available)%s", (Object) this.testNameErrMsg());
    }

    /**
     * @return an error-message suffix naming this test, or the empty string if no test name is set
     */
    public String testNameErrMsg() {
        if (this.testName == null) {
            return "";
        }
        return " - Test name: \"" + this.testName + "\"";
    }

    /** @return this test case, with the SameDiff instance replaced */
    public TestCase sameDiff(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
        return this;
    }

    /** @return this test case, with the test name set */
    public TestCase testName(String testName) {
        this.testName = testName;
        return this;
    }

    /** @return this test case, with the forward validation map replaced (stored, not copied) */
    public TestCase fwdTestFns(Map<String, Function<INDArray, String>> fwdTestFns) {
        this.fwdTestFns = fwdTestFns;
        return this;
    }

    /** @return this test case, with gradient checking enabled or disabled */
    public TestCase gradientCheck(boolean gradientCheck) {
        this.gradientCheck = gradientCheck;
        return this;
    }

    /** @return this test case, with the gradient-check print flag set */
    public TestCase gradCheckPrint(boolean gradCheckPrint) {
        this.gradCheckPrint = gradCheckPrint;
        return this;
    }

    /** @return this test case, with the gradient-check exit-on-first-failure flag set */
    public TestCase gradCheckDefaultExitFirstFailure(boolean gradCheckDefaultExitFirstFailure) {
        this.gradCheckDefaultExitFirstFailure = gradCheckDefaultExitFirstFailure;
        return this;
    }

    /** @return this test case, with the gradient-check debug flag set */
    public TestCase gradCheckDebugMode(boolean gradCheckDebugMode) {
        this.gradCheckDebugMode = gradCheckDebugMode;
        return this;
    }

    /** @return this test case, with the gradient-check epsilon set */
    public TestCase gradCheckEpsilon(double gradCheckEpsilon) {
        this.gradCheckEpsilon = gradCheckEpsilon;
        return this;
    }

    /** @return this test case, with the gradient-check maximum relative error set */
    public TestCase gradCheckMaxRelativeError(double gradCheckMaxRelativeError) {
        this.gradCheckMaxRelativeError = gradCheckMaxRelativeError;
        return this;
    }

    /** @return this test case, with the gradient-check minimum absolute error set */
    public TestCase gradCheckMinAbsError(double gradCheckMinAbsError) {
        this.gradCheckMinAbsError = gradCheckMinAbsError;
        return this;
    }

    /** @return this test case, with the gradient-check mask map replaced (stored, not copied) */
    public TestCase gradCheckMask(Map<String, INDArray> gradCheckMask) {
        this.gradCheckMask = gradCheckMask;
        return this;
    }

    /** @return this test case, with the FlatBuffers serialization test mode set */
    public TestCase testFlatBufferSerialization(TestSerialization testFlatBufferSerialization) {
        this.testFlatBufferSerialization = testFlatBufferSerialization;
        return this;
    }

    /**
     * Field-wise equality over every configuration field. Doubles are compared with
     * {@link Double#compare}, so NaN equals NaN and 0.0 differs from -0.0.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof TestCase)) {
            return false;
        }
        TestCase other = (TestCase) o;
        if (!other.canEqual(this)) {
            return false;
        }
        // Primitive fields first: cheap and most likely to differ.
        if (this.gradientCheck() != other.gradientCheck()
                || this.gradCheckPrint() != other.gradCheckPrint()
                || this.gradCheckDefaultExitFirstFailure() != other.gradCheckDefaultExitFirstFailure()
                || this.gradCheckDebugMode() != other.gradCheckDebugMode()) {
            return false;
        }
        if (Double.compare(this.gradCheckEpsilon(), other.gradCheckEpsilon()) != 0
                || Double.compare(this.gradCheckMaxRelativeError(), other.gradCheckMaxRelativeError()) != 0
                || Double.compare(this.gradCheckMinAbsError(), other.gradCheckMinAbsError()) != 0) {
            return false;
        }
        return Objects.equals(this.sameDiff(), other.sameDiff())
                && Objects.equals(this.testName(), other.testName())
                && Objects.equals(this.fwdTestFns(), other.fwdTestFns())
                && Objects.equals(this.placeholderValues(), other.placeholderValues())
                && Objects.equals(this.gradCheckSkipVariables(), other.gradCheckSkipVariables())
                && Objects.equals(this.gradCheckMask(), other.gradCheckMask())
                && Objects.equals(this.testFlatBufferSerialization(), other.testFlatBufferSerialization());
    }

    /**
     * Symmetry hook for subclasses overriding {@link #equals(Object)}.
     *
     * @param other candidate object
     * @return true if {@code other} is a {@code TestCase}
     */
    protected boolean canEqual(Object other) {
        return other instanceof TestCase;
    }

    /** Hash over the same fields as {@link #equals(Object)}; values match the original scheme. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + (this.gradientCheck() ? 79 : 97);
        result = result * prime + (this.gradCheckPrint() ? 79 : 97);
        result = result * prime + (this.gradCheckDefaultExitFirstFailure() ? 79 : 97);
        result = result * prime + (this.gradCheckDebugMode() ? 79 : 97);
        // Double.hashCode(d) == (int)(bits ^ (bits >>> 32)) with bits = doubleToLongBits(d).
        result = result * prime + Double.hashCode(this.gradCheckEpsilon());
        result = result * prime + Double.hashCode(this.gradCheckMaxRelativeError());
        result = result * prime + Double.hashCode(this.gradCheckMinAbsError());
        result = result * prime + nullableHash(this.sameDiff());
        result = result * prime + nullableHash(this.testName());
        result = result * prime + nullableHash(this.fwdTestFns());
        result = result * prime + nullableHash(this.placeholderValues());
        result = result * prime + nullableHash(this.gradCheckSkipVariables());
        result = result * prime + nullableHash(this.gradCheckMask());
        result = result * prime + nullableHash(this.testFlatBufferSerialization());
        return result;
    }

    /** Returns {@code o.hashCode()}, or 43 for null (the constant the original scheme used). */
    private static int nullableHash(Object o) {
        return o == null ? 43 : o.hashCode();
    }

    @Override
    public String toString() {
        return "TestCase(sameDiff=" + this.sameDiff()
                + ", testName=" + this.testName()
                + ", fwdTestFns=" + this.fwdTestFns()
                + ", placeholderValues=" + this.placeholderValues()
                + ", gradientCheck=" + this.gradientCheck()
                + ", gradCheckPrint=" + this.gradCheckPrint()
                + ", gradCheckDefaultExitFirstFailure=" + this.gradCheckDefaultExitFirstFailure()
                + ", gradCheckDebugMode=" + this.gradCheckDebugMode()
                + ", gradCheckEpsilon=" + this.gradCheckEpsilon()
                + ", gradCheckMaxRelativeError=" + this.gradCheckMaxRelativeError()
                + ", gradCheckMinAbsError=" + this.gradCheckMinAbsError()
                + ", gradCheckSkipVariables=" + this.gradCheckSkipVariables()
                + ", gradCheckMask=" + this.gradCheckMask()
                + ", testFlatBufferSerialization=" + this.testFlatBufferSerialization()
                + ")";
    }

    /** @return the SameDiff instance under test */
    public SameDiff sameDiff() {
        return this.sameDiff;
    }

    /** @return the test name, or null if unset */
    public String testName() {
        return this.testName;
    }

    /** @return forward validation functions keyed by variable name, or null if none registered */
    public Map<String, Function<INDArray, String>> fwdTestFns() {
        return this.fwdTestFns;
    }

    /** @return placeholder values keyed by placeholder name, or null if none were set */
    public Map<String, INDArray> placeholderValues() {
        return this.placeholderValues;
    }

    /** @return whether gradient checking is enabled (default true) */
    public boolean gradientCheck() {
        return this.gradientCheck;
    }

    /** @return the gradient-check print flag (default {@link #GC_DEFAULT_PRINT}) */
    public boolean gradCheckPrint() {
        return this.gradCheckPrint;
    }

    /** @return the exit-on-first-failure flag (default {@link #GC_DEFAULT_EXIT_FIRST_FAILURE}) */
    public boolean gradCheckDefaultExitFirstFailure() {
        return this.gradCheckDefaultExitFirstFailure;
    }

    /** @return the gradient-check debug flag (default {@link #GC_DEFAULT_DEBUG_MODE}) */
    public boolean gradCheckDebugMode() {
        return this.gradCheckDebugMode;
    }

    /** @return the gradient-check epsilon (default {@link #GC_DEFAULT_EPS}) */
    public double gradCheckEpsilon() {
        return this.gradCheckEpsilon;
    }

    /** @return the gradient-check max relative error (default {@link #GC_DEFAULT_MAX_REL_ERROR}) */
    public double gradCheckMaxRelativeError() {
        return this.gradCheckMaxRelativeError;
    }

    /** @return the gradient-check min absolute error (default {@link #GC_DEFAULT_MIN_ABS_ERROR}) */
    public double gradCheckMinAbsError() {
        return this.gradCheckMinAbsError;
    }

    /** @return the FlatBuffers serialization test mode (default {@link TestSerialization#BOTH}) */
    public TestSerialization testFlatBufferSerialization() {
        return this.testFlatBufferSerialization;
    }

    /**
     * When FlatBuffers serialization should be exercised relative to execution.
     * NOTE(review): semantics inferred from constant names only — confirm against the validator.
     */
    public static enum TestSerialization {
        BEFORE_EXEC,
        AFTER_EXEC,
        BOTH,
        NONE;

    }
}

