/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.functions;

import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.shade.jackson.annotation.JsonIgnore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

public abstract class DifferentialFunction {
    private static final Logger log = LoggerFactory.getLogger(DifferentialFunction.class);
    @JsonIgnore
    protected SameDiff sameDiff;
    @JsonIgnore
    protected boolean inPlace;
    @JsonIgnore
    protected INDArray scalarValue;
    @JsonIgnore
    protected int[] dimensions;
    @JsonIgnore
    protected Object[] extraArgs;
    @JsonIgnore
    protected String ownName;
    @JsonIgnore
    protected boolean ownNameSetWithDefault = false;

    public DifferentialFunction() {
        this(true);
    }

    public DifferentialFunction(boolean sameDiff) {
        if (sameDiff) {
            this.setInstanceId();
        }
    }

    public DifferentialFunction(SameDiff sameDiff, NodeDef nodeDef, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        this.sameDiff = sameDiff;
        this.setInstanceId();
        this.initFromTensorFlow(nodeDef, sameDiff, attributesForNode, graph);
    }

    public DifferentialFunction(SameDiff sameDiff, Onnx.NodeProto node, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
        this.sameDiff = sameDiff;
        this.setInstanceId();
        this.initFromOnnx(node, sameDiff, attributesForNode, graph);
    }

    public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
        return Collections.emptyMap();
    }

    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        return Collections.emptyMap();
    }

    public Map<String, Object> propertiesForFunction() {
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
        LinkedHashMap<String, Object> ret = new LinkedHashMap<String, Object>();
        Preconditions.checkNotNull(fields, "DifferentialFunctionClassHolder returned null fields for %s - op has not been added to ImportClassMapping?", this.getClass());
        for (Map.Entry<String, Field> entry : fields.entrySet()) {
            try {
                ret.put(entry.getKey(), fields.get(entry.getKey()).get(this));
            }
            catch (IllegalAccessException e) {
                throw new RuntimeException("Unable to get property for field: " + entry.getKey(), e);
            }
        }
        return ret;
    }

    public void configureWithSameDiff(SameDiff sameDiff) {
    }

    public void setPropertiesForFunction(Map<String, Object> properties) {
        Map<String, Field> fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
        for (String s : properties.keySet()) {
            Field f = fields.get(s);
            if (f == null) {
                log.warn("No fields found for property name {} for class {}", (Object)s, (Object)this.getClass().getName());
                continue;
            }
            this.setValueFor(f, properties.get(s));
        }
    }

    protected Boolean getBooleanFromProperty(String propertyName, Map<String, Object> properties) {
        if (properties.containsKey(propertyName)) {
            Boolean value = (Boolean)properties.get(propertyName);
            return value;
        }
        return null;
    }

    protected String getStringFromProperty(String propertyName, Map<String, Object> properties) {
        if (properties.containsKey(propertyName)) {
            String value = (String)properties.get(propertyName);
            return value;
        }
        return null;
    }

    protected Integer getIntValueFromProperty(String propertyName, Map<String, Object> properties) {
        if (properties.containsKey(propertyName)) {
            Number value = (Number)properties.get(propertyName);
            return value.intValue();
        }
        return null;
    }

    protected Long getLongValueFromProperty(String propertyName, Map<String, Object> properties) {
        if (properties.containsKey(propertyName)) {
            Number value = (Number)properties.get(propertyName);
            return value.longValue();
        }
        return null;
    }

    protected Double getDoubleValueFromProperty(String propertyName, Map<String, Object> properties) {
        if (properties.containsKey(propertyName)) {
            Number value = (Number)properties.get(propertyName);
            return value.doubleValue();
        }
        return null;
    }

    public Object getValue(Field property) {
        try {
            return property.get(this);
        }
        catch (IllegalAccessException e) {
            log.error("", (Throwable)e);
            return null;
        }
    }

    public void setValueFor(Field target, Object value) {
        block29: {
            if (value == null && target.getType().isPrimitive()) {
                throw new ND4JIllegalStateException("Unable to set primitive field " + target + " of type " + target.getClass() + " using null value!");
            }
            if (value != null) {
                value = this.ensureProperType(target, value);
            }
            if (this.isConfigProperties()) {
                String propertyName = this.configFieldName();
                if (propertyName == null) {
                    propertyName = "config";
                }
                Field f = null;
                Class<?> currClass = this.getClass();
                try {
                    f = currClass.getDeclaredField(propertyName);
                }
                catch (NoSuchFieldException noSuchFieldException) {
                    // empty catch block
                }
                while (f == null && currClass.getSuperclass() != null) {
                    currClass = currClass.getSuperclass();
                    try {
                        f = currClass.getDeclaredField(propertyName);
                    }
                    catch (NoSuchFieldException noSuchFieldException) {}
                }
                if (f == null) {
                    throw new IllegalStateException("Could not find field \"" + propertyName + "\" for class " + this.getClass().getName());
                }
                try {
                    f.setAccessible(true);
                    Object o = f.get(this);
                    if (o == null) {
                        Class<?> c = f.getType();
                        try {
                            o = c.newInstance();
                        }
                        catch (InstantiationException e) {
                            throw new RuntimeException("Error creating new instance of configuration object type " + c.getName(), e);
                        }
                        f.set(this, o);
                    }
                    target.set(o, value);
                    break block29;
                }
                catch (IllegalAccessException e) {
                    throw new RuntimeException("Error setting configuration field \"" + propertyName + "\" for config field \"" + propertyName + "\" on class " + this.getClass().getName());
                }
            }
            try {
                Number value2;
                if (target.getType() == Float.TYPE && value instanceof Double) {
                    value = Float.valueOf(((Double)value).floatValue());
                }
                if (target.getType() == Character.TYPE && value instanceof Integer) {
                    value = Character.valueOf((char)((Integer)value).intValue());
                }
                if (target.getType() == Character.TYPE && value instanceof Long) {
                    value = Character.valueOf((char)((Long)value).intValue());
                }
                if (target.getType() == Integer.TYPE && value instanceof Long) {
                    value2 = (Long)value;
                    value = ((Long)value2).intValue();
                }
                if (target.getType().equals(Integer.class) && value instanceof Long) {
                    value2 = (Long)value;
                    value = ((Long)value2).intValue();
                }
                if (target.getType().equals(Long.class) && value instanceof Integer) {
                    value2 = (Integer)value;
                    value = ((Integer)value2).longValue();
                }
                if (target.getType().equals(Double.class) && value instanceof Long) {
                    value2 = (Long)value;
                    value = ((Long)value2).doubleValue();
                }
                if (target.getType().equals(Boolean.class) || target.getType().equals(Boolean.TYPE) && value instanceof Number) {
                    value2 = (Number)value;
                    value = value2.doubleValue() > 0.0;
                }
                if (target.getType().equals(DataType.class) && value instanceof Double) {
                    value2 = (Double)value;
                    int idxConverted = ((Double)value2).intValue();
                    value = DataType.values()[idxConverted];
                }
                if (target.getType().isEnum() && value instanceof Long || value instanceof Integer && !target.getType().equals(Integer.TYPE) && !target.getType().equals(Long.TYPE)) {
                    Object get;
                    Class<?> enumType = target.getType();
                    Method method = enumType.getMethod("values", new Class[0]);
                    method.setAccessible(true);
                    Object[] invoke = (Object[])method.invoke(null, new Object[0]);
                    Number number = (Number)value;
                    int idx = number.intValue();
                    value = get = invoke[idx];
                }
                target.set(this, value);
            }
            catch (Exception e) {
                throw new RuntimeException("Error setting property for function " + this.getClass().getName(), e);
            }
        }
    }

    private Object ensureProperType(Field targetType, Object value) {
        Class<?> valueType;
        Class<?> firstClass = targetType.getType();
        if (!firstClass.equals(valueType = value.getClass())) {
            if (firstClass.isEnum()) {
                if (valueType.equals(String.class)) {
                    ?[] enumConstants = firstClass.getEnumConstants();
                    for (int i = 0; i < enumConstants.length; ++i) {
                        if (!enumConstants[i].toString().equalsIgnoreCase((String)value)) continue;
                        return enumConstants[i];
                    }
                    throw new IllegalStateException("Could not find enum constant value for value \"" + value + "\" for enum class " + firstClass.getName());
                }
            } else {
                if (firstClass.equals(int[].class)) {
                    if (value instanceof Number) {
                        Number number = (Number)value;
                        value = number.intValue();
                    }
                    int otherValue = (Integer)value;
                    int[] setValue = new int[]{otherValue};
                    return setValue;
                }
                if (firstClass.equals(Integer[].class)) {
                    if (value instanceof Number) {
                        Number number = (Number)value;
                        value = number.intValue();
                    }
                    Integer otherValue = (Integer)value;
                    Integer[] setValue = new Integer[]{otherValue};
                    return setValue;
                }
                if (firstClass.equals(long[].class)) {
                    if (value instanceof Number) {
                        Number number = (Number)value;
                        value = number.longValue();
                    }
                    long otherValue = (Long)value;
                    long[] setValue = new long[]{otherValue};
                    return setValue;
                }
                if (firstClass.equals(Long[].class)) {
                    if (value instanceof Number) {
                        Number number = (Number)value;
                        value = number.longValue();
                    }
                    Long otherValue = (Long)value;
                    Long[] setValue = new Long[]{otherValue};
                    return setValue;
                }
                if (firstClass.equals(double[].class)) {
                    if (value instanceof Number) {
                        Number number = (Number)value;
                        value = number.doubleValue();
                    }
                    double otherValue = (Double)value;
                    double[] setValue = new double[]{otherValue};
                    return setValue;
                }
                if (firstClass.equals(Double[].class)) {
                    if (value instanceof Number) {
                        Number number = (Number)value;
                        value = number.doubleValue();
                    }
                    Double otherValue = (Double)value;
                    Double[] setValue = new Double[]{otherValue};
                    return setValue;
                }
                if (firstClass.equals(float[].class)) {
                    if (value instanceof Number) {
                        Number number = (Number)value;
                        value = Float.valueOf(number.floatValue());
                    }
                    float otherValue = ((Float)value).floatValue();
                    float[] setValue = new float[]{otherValue};
                    return setValue;
                }
                if (firstClass.equals(Float[].class)) {
                    if (value instanceof Number) {
                        Number number = (Number)value;
                        value = Float.valueOf(number.floatValue());
                    }
                    Float otherValue = (Float)value;
                    Float[] setValue = new Float[]{otherValue};
                    return setValue;
                }
            }
        }
        return value;
    }

    public boolean isConfigProperties() {
        return false;
    }

    public String configFieldName() {
        return null;
    }

    public DifferentialFunction(SameDiff sameDiff, boolean inPlace, Object[] extraArgs) {
        this.sameDiff = sameDiff;
        this.inPlace = inPlace;
        this.setInstanceId();
        this.extraArgs = extraArgs;
    }

    public DifferentialFunction(SameDiff sameDiff, Object[] extraArgs) {
        this.sameDiff = sameDiff;
        this.setInstanceId();
        this.extraArgs = extraArgs;
    }

    public DifferentialFunction(SameDiff sameDiff, SDVariable[] args) {
        this(sameDiff, false, args);
    }

    public DifferentialFunction(SameDiff sameDiff, boolean inPlace, SDVariable[] args) {
        this.sameDiff = sameDiff;
        this.inPlace = inPlace;
        this.setInstanceId();
        if (sameDiff != null && args != null) {
            sameDiff.addArgsFor(args, this);
        }
    }

    public void replaceArg(int i, SDVariable newArg) {
        if (this.sameDiff != null) {
            this.sameDiff.replaceArgFor(i, newArg, this);
        }
    }

    public SDVariable[] outputVariables() {
        return this.outputVariables(this.getOwnName() != null ? this.getOwnName() : this.opName());
    }

    public SDVariable outputVariable() {
        return this.outputVariables()[0];
    }

    public List<SDVariable> outputs() {
        SDVariable[] out = this.outputVariables();
        return out == null ? null : Arrays.asList(out);
    }

    public String[] outputVariablesNames() {
        SDVariable[] outputVars = this.outputVariables();
        String[] out = new String[outputVars.length];
        for (int i = 0; i < out.length; ++i) {
            out[i] = outputVars[i].name();
        }
        return out;
    }

    public abstract SDVariable[] outputVariables(String var1);

    public abstract List<SDVariable> doDiff(List<SDVariable> var1);

    public SDVariable[] args() {
        return this.sameDiff == null ? null : this.sameDiff.getInputVariablesForOp(this);
    }

    public SDVariable arg(int num) {
        SDVariable[] args = this.args();
        Preconditions.checkNotNull((Object)args, "Arguments are null for function %s", (Object)this.getOwnName());
        Preconditions.checkArgument(num >= 0 && num < args.length, "Invalid index: must be 0 to numArgs (0 <= idx < %s), got %s", args.length, num);
        return args[num];
    }

    public String[] argNames() {
        SDVariable[] args = this.args();
        String[] out = new String[args.length];
        for (int i = 0; i < args.length; ++i) {
            out[i] = args[i].name();
        }
        return out;
    }

    public SDVariable arg() {
        if (this.args() == null || this.args().length == 0) {
            return null;
        }
        return this.args()[0];
    }

    public List<SDVariable> diff(List<SDVariable> i_v1) {
        List<SDVariable> vals = this.doDiff(i_v1);
        if (vals == null) {
            throw new IllegalStateException("Error executing diff operation: doDiff returned null for op: " + this.opName());
        }
        SDVariable[] outputVars = this.args();
        boolean copied = false;
        for (int i = 0; i < vals.size(); ++i) {
            SDVariable gradVar;
            SDVariable grad;
            SDVariable var = outputVars[i];
            SDVariable sDVariable = grad = var.hasGradient() ? var.getGradient() : null;
            if (grad != null) {
                if (!copied) {
                    vals = new ArrayList<SDVariable>(vals);
                    copied = true;
                }
                gradVar = var.getSameDiff().math.add(grad, vals.get(i));
                vals.set(i, gradVar);
                this.sameDiff.setGradientForVariableName(var.name(), gradVar);
                continue;
            }
            gradVar = vals.get(i);
            if (this.sameDiff.hasVariable(var.name() + "-grad")) {
                this.sameDiff.getVariable(var.name() + "-grad").add(gradVar);
                continue;
            }
            this.sameDiff.updateVariableNameAndReference(gradVar, var.name() + "-grad");
            this.sameDiff.setGradientForVariableName(var.name(), gradVar);
        }
        return vals;
    }

    protected void setInstanceId() {
        if (this.ownName == null) {
            String n;
            this.ownNameSetWithDefault = true;
            this.ownName = this.sameDiff == null ? UUID.randomUUID().toString() : (n = this.sameDiff.getOpName(this.opName()));
            if (this.sameDiff != null) {
                this.sameDiff.putOpForId(this.ownName, this);
            }
        }
    }

    public String opName() {
        throw new UnsupportedOperationException();
    }

    public Op.Type opType() {
        throw new UnsupportedOperationException();
    }

    public int opNum() {
        throw new UnsupportedOperationException();
    }

    @JsonIgnore
    public INDArray getInputArgument(int index) {
        throw new UnsupportedOperationException("Not implemented");
    }

    public abstract void initFromTensorFlow(NodeDef var1, SameDiff var2, Map<String, AttrValue> var3, GraphDef var4);

    public abstract void initFromOnnx(Onnx.NodeProto var1, SameDiff var2, Map<String, Onnx.AttributeProto> var3, Onnx.GraphProto var4);

    public SDVariable larg() {
        SDVariable[] args = this.args();
        if (args == null || args.length == 0) {
            throw new ND4JIllegalStateException("No arguments found.");
        }
        return this.args()[0];
    }

    public SDVariable rarg() {
        SDVariable[] args = this.args();
        if (args == null || args.length != 2) {
            throw new ND4JIllegalStateException("In order to use this function, the number of arguments for this function must be 2.");
        }
        return args[1];
    }

    public DifferentialFunction dup() {
        return FlatBuffersMapper.cloneViaSerialize(this.sameDiff, this);
    }

    public List<LongShapeDescriptor> calculateOutputShape() {
        throw new ND4JIllegalStateException("Op type of " + this.getClass().getName() + "did not override calculateOutputShape() method leaked out for [" + this.opName() + "]");
    }

    public List<LongShapeDescriptor> calculateOutputShape(OpContext oc) {
        throw new ND4JIllegalStateException("Op type of " + this.getClass().getName() + " did not override calculateOutputShape(OpContext) method leaked out for [" + this.opName() + "]");
    }

    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        throw new UnsupportedOperationException("Op type of " + this.getClass().getName() + " and name " + this.toString() + " did not override  calculateOutputDataTypes()! This function has not been implemented for " + this.getClass().getName());
    }

    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        DifferentialFunction that = (DifferentialFunction)o;
        if (this.inPlace != that.inPlace) {
            return false;
        }
        if (this.scalarValue != null ? !this.scalarValue.equals(that.scalarValue) : that.scalarValue != null) {
            return false;
        }
        if (!Arrays.equals(this.dimensions, that.dimensions)) {
            return false;
        }
        return this.ownName != null ? this.ownName.equals(that.ownName) : that.ownName == null;
    }

    public int hashCode() {
        int result = 31;
        result = 31 * result + (this.inPlace ? 1 : 0);
        result = 31 * result + (this.scalarValue != null ? this.scalarValue.hashCode() : 0);
        result = 31 * result + Arrays.hashCode(this.dimensions);
        result = 31 * result + (this.ownName != null ? this.ownName.hashCode() : 0);
        return result;
    }

    public String[] onnxNames() {
        return new String[]{this.onnxName()};
    }

    public String[] tensorflowNames() {
        return new String[]{this.tensorflowName()};
    }

    public abstract String onnxName();

    public abstract String tensorflowName();

    public int getNumOutputs() {
        return -1;
    }

    public abstract void clearArrays();

    public Object[] getExtraArgs() {
        return this.extraArgs;
    }

    public void setExtraArgs(Object[] extraArgs) {
        this.extraArgs = extraArgs;
    }

    public String toString() {
        return "DifferentialFunction(sameDiff=" + this.getSameDiff() + ", inPlace=" + this.isInPlace() + ", scalarValue=" + this.getScalarValue() + ", dimensions=" + Arrays.toString(this.getDimensions()) + ", extraArgs=" + Arrays.deepToString(this.getExtraArgs()) + ", ownName=" + this.getOwnName() + ", ownNameSetWithDefault=" + this.isOwnNameSetWithDefault() + ")";
    }

    public SameDiff getSameDiff() {
        return this.sameDiff;
    }

    public void setSameDiff(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
    }

    public boolean isInPlace() {
        return this.inPlace;
    }

    public void setInPlace(boolean inPlace) {
        this.inPlace = inPlace;
    }

    public INDArray getScalarValue() {
        return this.scalarValue;
    }

    public void setScalarValue(INDArray scalarValue) {
        this.scalarValue = scalarValue;
    }

    public int[] getDimensions() {
        return this.dimensions;
    }

    public void setDimensions(int[] dimensions) {
        this.dimensions = dimensions;
    }

    public String getOwnName() {
        return this.ownName;
    }

    public void setOwnName(String ownName) {
        this.ownName = ownName;
    }

    public boolean isOwnNameSetWithDefault() {
        return this.ownNameSetWithDefault;
    }

    public void setOwnNameSetWithDefault(boolean ownNameSetWithDefault) {
        this.ownNameSetWithDefault = ownNameSetWithDefault;
    }
}

