package org.nd4j.imports.converters;

import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser;
import org.nd4j.imports.descriptors.onnx.OpDescriptor;
import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.OpDef;

/* loaded from: input_file:org/nd4j/imports/converters/DifferentialFunctionClassHolder.class */
/**
 * Singleton registry built once at class-initialization time from
 * {@link ImportClassMapping}, the TensorFlow / ONNX descriptor parsers and the
 * custom operations reported by the current ND4J executioner.
 *
 * <p>It answers three kinds of questions:
 * <ul>
 *   <li>which {@link DifferentialFunction} instance is registered under a given
 *       op name, TensorFlow name or ONNX name;</li>
 *   <li>which reflective {@link Field}s were collected for a given function class;</li>
 *   <li>which class corresponds to a custom-op hash (and name, when several
 *       registered ops share a hash).</li>
 * </ul>
 */
public class DifferentialFunctionClassHolder {
    private static final Logger log = LoggerFactory.getLogger(DifferentialFunctionClassHolder.class);

    /** Field names that are never collected, regardless of the declaring class. */
    private static final Set<String> fieldNamesOpsIgnore = Collections.unmodifiableSet(
            new LinkedHashSet<String>(Arrays.asList(
                    "extraArgs",
                    "arrayInitialized",
                    "log",
                    "inputArguments",
                    "outputArguments",
                    "outputShapes",
                    "outputVariables",
                    "tArguments",
                    "iArguments",
                    "bArguments",
                    "dArguments",
                    "hash",
                    "opName",
                    "sameDiff",
                    "ownName")));

    /** The class-hierarchy walk stops once the superclass of the current class is in this set. */
    private static final Set<Class<?>> classesToIgnore =
            new HashSet<Class<?>>(Collections.<Class<?>>singleton(Object.class));

    /** Per-class field names to skip; populated in the static initializer below. */
    private static final Map<Class<?>, Set<String>> classFieldsToIgnore = new HashMap<>();

    /** Assigned at the very end of static initialization, after every other static field is ready. */
    private static final DifferentialFunctionClassHolder INSTANCE;

    private Map<String, OpDescriptor> onnxOpDescriptors;
    private Map<String, OpDef> tensorflowOpDescriptors;
    private int countTotalTfOps;
    private int countTotalMappedOps;

    private final Map<String, DifferentialFunction> nodeConverters = ImportClassMapping.getOpNameMapping();
    private final Map<String, DifferentialFunction> tensorFlowNames = ImportClassMapping.getTFOpMappingFunctions();
    private final Map<String, DifferentialFunction> onnxNames = ImportClassMapping.getOnnxOpMappingFunctions();
    /** Custom-op hash -> class of the registered function (last one wins when hashes collide). */
    private final Map<Long, Class<?>> customOpHashToClass = new HashMap<>();
    /** Only for hashes seen more than once: custom-op hash -> (op name -> class). */
    private final Map<Long, Map<String, Class<?>>> customOpHashToClasses = new HashMap<>();
    private final List<String> missingOps = new ArrayList<>();
    /** Function class name -> (field name -> field), in collection order. */
    private final Map<String, Map<String, Field>> fieldsForFunction = new LinkedHashMap<>();

    static {
        classFieldsToIgnore.put(BaseOp.class, new HashSet<String>(Arrays.asList(
                "x", "y", "z", "n", "numProcessed", "xVertexId", "yVertexId", "zVertexId", "extraArgz")));
        // Must stay last: the constructor reads the static sets/maps initialised above.
        INSTANCE = new DifferentialFunctionClassHolder();
    }

    /**
     * Builds the registry: collects fields for every named function in the op-name
     * mapping, loads the TensorFlow and ONNX descriptors, computes the list of
     * custom ops without a converter, and indexes custom ops by hash.
     *
     * @throws RuntimeException wrapping any exception raised while collecting
     *         fields or loading descriptors
     */
    private DifferentialFunctionClassHolder() {
        for (DifferentialFunction function : ImportClassMapping.getOpNameMapping().values()) {
            if (function == null || function.opName() == null) {
                continue;
            }
            try {
                this.fieldsForFunction.put(function.getClass().getName(), collectFields(function));
            } catch (NoOpNameFoundException e) {
                log.trace("Skipping function  {}", function.getClass());
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        try {
            this.tensorflowOpDescriptors = TensorflowDescriptorParser.opDescs();
            this.onnxOpDescriptors = OnnxDescriptorParser.onnxOpDescriptors();

            Map<String, CustomOpDescriptor> customOperations = Nd4j.getExecutioner().getCustomOperations();

            // Custom ops known to the executioner that have no registered converter.
            Set<String> unmapped = new HashSet<String>(customOperations.keySet());
            unmapped.removeAll(this.nodeConverters.keySet());
            this.missingOps.addAll(unmapped);
            Collections.sort(this.missingOps);

            this.countTotalTfOps = this.tensorflowOpDescriptors.size();

            Set<String> mappedTensorflowNames = new HashSet<>();
            for (DifferentialFunction function : this.nodeConverters.values()) {
                try {
                    Collections.addAll(mappedTensorflowNames, function.tensorflowNames());
                } catch (NoOpNameFoundException ignored) {
                    // This function declares no TensorFlow names; it simply does not count as mapped.
                }
            }
            this.countTotalMappedOps = mappedTensorflowNames.size();

            indexCustomOpHashes(customOperations);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Collects the fields of {@code function} by walking its class hierarchy
     * upwards until the superclass is {@code null} or listed in {@link #classesToIgnore}.
     *
     * <p>When the function reports {@code isConfigProperties()}, the first level of
     * the walk contributes the fields of the config field's <em>type</em> instead of
     * the fields declared on the function class itself.
     *
     * @throws IllegalStateException if two collected fields share a name
     */
    private static Map<String, Field> collectFields(DifferentialFunction function) {
        Map<String, Field> fieldsByName = new LinkedHashMap<>();
        Class<?> current = function.getClass();
        boolean firstLevel = true;

        while (current.getSuperclass() != null && !classesToIgnore.contains(current.getSuperclass())) {
            Field[] candidates;
            if (function.isConfigProperties() && firstLevel) {
                Field configField = findConfigField(current, function.configFieldName());
                candidates = configField == null ? new Field[0] : configField.getType().getDeclaredFields();
            } else {
                candidates = current.getDeclaredFields();
            }

            for (Field candidate : candidates) {
                if (isCollectable(current, candidate)) {
                    registerField(fieldsByName, candidate);
                }
            }

            current = current.getSuperclass();
            firstLevel = false;
        }
        return fieldsByName;
    }

    /**
     * Looks up the config field on {@code start}, then on each ancestor that itself
     * has a superclass (so {@code Object} is never inspected).
     *
     * @param start class to begin with; callers guarantee it has a non-null superclass
     * @param configFieldName declared field name, or {@code null} to use {@code "config"}
     * @return the first matching field, or {@code null} if none is declared
     */
    private static Field findConfigField(Class<?> start, String configFieldName) {
        String name = configFieldName == null ? "config" : configFieldName;
        try {
            return start.getDeclaredField(name);
        } catch (NoSuchFieldException ignored) {
            // Not declared on the concrete class; fall back to the ancestors below.
        }
        for (Class<?> ancestor = start.getSuperclass();
                ancestor.getSuperclass() != null;
                ancestor = ancestor.getSuperclass()) {
            try {
                return ancestor.getDeclaredField(name);
            } catch (NoSuchFieldException ignored) {
                // Keep walking up the hierarchy.
            }
        }
        return null;
    }

    /**
     * Returns {@code true} when {@code field} is non-static, not globally ignored,
     * and not ignored for {@code owner} via {@link #classFieldsToIgnore}.
     */
    private static boolean isCollectable(Class<?> owner, Field field) {
        if (Modifier.isStatic(field.getModifiers()) || fieldNamesOpsIgnore.contains(field.getName())) {
            return false;
        }
        Set<String> ignoredForOwner = classFieldsToIgnore.get(owner);
        return ignoredForOwner == null || !ignoredForOwner.contains(field.getName());
    }

    /**
     * Makes {@code field} accessible and stores it under its name.
     *
     * @throws IllegalStateException if a field with the same name was already stored
     */
    private static void registerField(Map<String, Field> fieldsByName, Field field) {
        field.setAccessible(true);
        Field previous = fieldsByName.get(field.getName());
        if (previous != null) {
            throw new IllegalStateException("Field with name " + field.getName()
                    + " exists for multiple classes: " + previous.getDeclaringClass().getName()
                    + " and " + field.getDeclaringClass().getName());
        }
        fieldsByName.put(field.getName(), field);
    }

    /**
     * Fills {@link #customOpHashToClass} and, for hashes registered more than once,
     * {@link #customOpHashToClasses}.
     */
    private void indexCustomOpHashes(Map<String, CustomOpDescriptor> customOperations) {
        Set<Long> collidingHashes = new HashSet<>();
        for (Map.Entry<String, CustomOpDescriptor> entry : customOperations.entrySet()) {
            DifferentialFunction function = getInstance(entry.getKey());
            if (function instanceof CustomOp) {
                Long hash = Long.valueOf(entry.getValue().getHash());
                if (this.customOpHashToClass.containsKey(hash)) {
                    collidingHashes.add(hash);
                }
                this.customOpHashToClass.put(hash, function.getClass());
            }
        }

        // Second pass: for colliding hashes, disambiguate by op name.
        for (Map.Entry<String, CustomOpDescriptor> entry : customOperations.entrySet()) {
            Long hash = Long.valueOf(entry.getValue().getHash());
            if (!collidingHashes.contains(hash)) {
                continue;
            }
            Map<String, Class<?>> classesByName = this.customOpHashToClasses.get(hash);
            if (classesByName == null) {
                classesByName = new HashMap<>();
                this.customOpHashToClasses.put(hash, classesByName);
            }
            DifferentialFunction function = getInstance(entry.getKey());
            if (function != null) {
                classesByName.put(entry.getKey(), function.getClass());
            }
        }
    }

    /**
     * Returns the fields collected for the class of the given function.
     *
     * @param differentialFunction function whose class name is used as the lookup key
     * @return field name to field mapping, or an empty map if nothing was collected
     */
    public Map<String, Field> getFieldsForFunction(DifferentialFunction differentialFunction) {
        Map<String, Field> fields = this.fieldsForFunction.get(differentialFunction.getClass().getName());
        return fields == null ? Collections.<String, Field>emptyMap() : fields;
    }

    /**
     * Returns the TensorFlow op definition registered under the given name.
     *
     * @param str TensorFlow op name
     * @throws ND4JIllegalStateException if no descriptor has that name
     */
    public OpDef getOpDefByTensorflowName(String str) {
        if (!this.tensorflowOpDescriptors.containsKey(str)) {
            throw new ND4JIllegalStateException("No op found with name " + str);
        }
        return this.tensorflowOpDescriptors.get(str);
    }

    /**
     * Returns the ONNX op descriptor registered under the given name.
     *
     * @param str ONNX op name
     * @throws ND4JIllegalStateException if no descriptor has that name
     */
    public OpDescriptor getOpDescriptorForOnnx(String str) {
        if (!this.onnxOpDescriptors.containsKey(str)) {
            throw new ND4JIllegalStateException("No op found with name " + str);
        }
        return this.onnxOpDescriptors.get(str);
    }

    /**
     * @param str TensorFlow op name
     * @return the function mapped to that TensorFlow name, or {@code null} if none
     */
    public DifferentialFunction getOpWithTensorflowName(String str) {
        return this.tensorFlowNames.get(str);
    }

    /**
     * @param str ONNX op name
     * @return the function mapped to that ONNX name, or {@code null} if none
     */
    public DifferentialFunction getOpWithOnnxName(String str) {
        return this.onnxNames.get(str);
    }

    /**
     * @return a new set of ONNX descriptor names that have no entry in the ONNX
     *         name mapping
     */
    public Set<String> missingOnnxOps() {
        Set<String> missing = new HashSet<String>(this.onnxOpDescriptors.keySet());
        missing.removeAll(this.onnxNames.keySet());
        return missing;
    }

    /**
     * @return a new set of TensorFlow descriptor names that have no entry in the
     *         TensorFlow name mapping
     */
    public Set<String> missingTensorflowOps() {
        Set<String> missing = new HashSet<String>(this.tensorflowOpDescriptors.keySet());
        missing.removeAll(this.tensorFlowNames.keySet());
        return missing;
    }

    /**
     * @return the sorted names of executioner custom ops that have no entry in the
     *         op-name mapping; this is the internal list, not a copy
     */
    public List<String> missingOps() {
        return this.missingOps;
    }

    /**
     * @param str op name
     * @return {@code true} if the op-name mapping has an entry for {@code str}
     */
    public boolean hasName(String str) {
        return this.nodeConverters.containsKey(str);
    }

    /**
     * @return the key set of the op-name mapping (a view, not a copy)
     */
    public Set<String> opNames() {
        return this.nodeConverters.keySet();
    }

    /**
     * @param str op name
     * @return the function registered under that op name, or {@code null} if none
     */
    public DifferentialFunction getInstance(String str) {
        return this.nodeConverters.get(str);
    }

    /**
     * Resolves the class for a custom op.
     *
     * <p>The control-flow and external-errors op names are resolved by name alone.
     * Otherwise the lookup order is: the per-name map for hashes registered more
     * than once, then the single-class hash map, then the op-name mapping.
     *
     * @param j custom-op hash
     * @param str op name; must not be {@code null}
     * @return the resolved class; may be {@code null} when the hash is shared by
     *         several ops but {@code str} is not one of the names recorded for it
     * @throws IllegalStateException if neither the hash nor the name is known
     */
    public Class<?> customOpClassForHashAndName(long j, String str) {
        switch (str) {
            case Enter.OP_NAME:
                return Enter.class;
            case Exit.OP_NAME:
                return Exit.class;
            case NextIteration.OP_NAME:
                return NextIteration.class;
            case Merge.OP_NAME:
                return Merge.class;
            case Switch.OP_NAME:
                return Switch.class;
            case LoopCond.OP_NAME:
                return LoopCond.class;
            case ExternalErrorsFunction.OP_NAME:
                return ExternalErrorsFunction.class;
            default:
                break;
        }

        Long hash = Long.valueOf(j);
        if (this.customOpHashToClasses.containsKey(hash)) {
            return this.customOpHashToClasses.get(hash).get(str);
        }
        if (this.customOpHashToClass.containsKey(hash)) {
            return this.customOpHashToClass.get(hash);
        }
        Map<String, DifferentialFunction> opNameMapping = ImportClassMapping.getOpNameMapping();
        if (opNameMapping.containsKey(str)) {
            return opNameMapping.get(str).getClass();
        }
        throw new IllegalStateException("No op known for hash: " + j + " and name " + str);
    }

    /**
     * @return the singleton created during class initialization
     */
    public static DifferentialFunctionClassHolder getInstance() {
        return INSTANCE;
    }

    /**
     * @return a read-only view of the TensorFlow name to function mapping
     */
    public Map<String, DifferentialFunction> getTensorFlowNames() {
        return Collections.unmodifiableMap(this.tensorFlowNames);
    }

    /**
     * @return the number of TensorFlow op descriptors that were loaded
     */
    public int getCountTotalTfOps() {
        return this.countTotalTfOps;
    }

    /**
     * @return the number of distinct TensorFlow names declared by registered functions
     */
    public int getCountTotalMappedOps() {
        return this.countTotalMappedOps;
    }
}
