package org.nd4j.autodiff.samediff.internal;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.memory.ArrayCacheMemoryMgr;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.common.util.ArrayUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseOp;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.Concat;
import org.nd4j.linalg.api.ops.impl.shape.Stack;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.BaseTensorOp;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayRead;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayScatter;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySize;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArraySplit;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayWrite;
import org.nd4j.linalg.api.ops.impl.transforms.Assert;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/samediff/internal/InferenceSession.class */
public class InferenceSession extends AbstractSession<INDArray, Pair<SameDiffOp, OpContext>> {
    // SLF4J logger for this session (debug-mode op names, trace-level array bookkeeping).
    private static final Logger log = LoggerFactory.getLogger(InferenceSession.class);
    // Remediation hint appended to the "leaked"/"outdated" workspace-pointer errors raised for placeholders.
    private static final String SCOPE_PANIC_MSG = "If required, arrays in workspaces can be detached using INDArray.detach() before being passed to the SameDiff instance.\nAlternatively, arrays defined in a workspace must be replaced after the workspace has been closed.";
    // Graph inputs whose name ends with this suffix are auto-filled with a scalar BOOL holding the
    // current training phase when the caller does not supply them (TF/Keras import workaround).
    protected static final String KERAS_TRAIN_TEST = "keras_learning_phase";
    // Memory manager: allocates session-owned arrays (e.g. dtype-cast placeholders) and receives
    // arrays that are no longer needed via release().
    private SessionMemMgr mmgr;
    // Records, per array, the dependencies (Dep instances) that must be marked satisfied before the
    // array is handed back to mmgr.
    private AbstractDependencyTracker<INDArray, Dep> arrayUseTracker;
    // Op name -> OpContext map. NOTE(review): not referenced in this part of the file; confirm usage/keying.
    private Map<String, OpContext> opContexts;

    /**
     * Dependency key attached (in {@code preprocessPlaceholders}) to the array of every CONSTANT
     * variable in the graph, identified by the constant's name.
     *
     * <p>NOTE(review): no code in this part of the session marks a {@code ConstantDep} satisfied,
     * which presumably keeps constant arrays from ever being released by the session — confirm.
     */
    protected static class ConstantDep extends Dep {
        protected String constName;

        public ConstantDep(String constName) {
            this.constName = constName;
        }

        public String getConstName() {
            return this.constName;
        }

        public void setConstName(String constName) {
            this.constName = constName;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof ConstantDep)) {
                return false;
            }
            ConstantDep other = (ConstantDep) obj;
            // Symmetric canEqual handshake, then the inherited frame / parent-frame comparison.
            boolean baseMatches = other.canEqual(this) && super.equals(obj);
            if (!baseMatches) {
                return false;
            }
            final String mine = getConstName();
            final String theirs = other.getConstName();
            return mine == theirs || (mine != null && mine.equals(theirs));
        }

        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof ConstantDep;
        }

        @Override
        public int hashCode() {
            String name = getConstName();
            int nameHash = (name != null) ? name.hashCode() : 43;
            return super.hashCode() * 59 + nameHash;
        }

        @Override
        public String toString() {
            return "InferenceSession.ConstantDep(constName=" + getConstName() + ")";
        }
    }

    /**
     * Base type for the keys registered with the session's array-use tracker against an array.
     * Arrays are handed back to the memory manager once the tracker reports all of their
     * dependencies satisfied.
     *
     * <p>Equality uses the canEqual pattern. Two deps are equal only when both sides accept each
     * other's concrete type and share the same {@code frame} and {@code parentFrame}. Subclasses
     * extend the comparison with their own fields.
     */
    public static abstract class Dep {
        protected String frame;
        protected FrameIter parentFrame;

        public String getFrame() {
            return this.frame;
        }

        public void setFrame(String frame) {
            this.frame = frame;
        }

        public FrameIter getParentFrame() {
            return this.parentFrame;
        }

        public void setParentFrame(FrameIter parentFrame) {
            this.parentFrame = parentFrame;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof Dep)) {
                return false;
            }
            Dep other = (Dep) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            String myFrame = getFrame();
            String otherFrame = other.getFrame();
            boolean framesMatch = (myFrame == null) ? (otherFrame == null) : myFrame.equals(otherFrame);
            if (!framesMatch) {
                return false;
            }
            FrameIter myParent = getParentFrame();
            FrameIter otherParent = other.getParentFrame();
            return (myParent == null) ? (otherParent == null) : myParent.equals(otherParent);
        }

        /** Lets subclasses veto equality with instances of a different concrete dependency type. */
        protected boolean canEqual(Object obj) {
            return obj instanceof Dep;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = 1;
            String myFrame = getFrame();
            result = result * prime + (myFrame != null ? myFrame.hashCode() : 43);
            FrameIter myParent = getParentFrame();
            result = result * prime + (myParent != null ? myParent.hashCode() : 43);
            return result;
        }

        @Override
        public String toString() {
            return "InferenceSession.Dep(frame=" + getFrame() + ", parentFrame=" + getParentFrame() + ")";
        }
    }

    /**
     * Dependency satisfied once per execution: {@code postProcessOutput} marks a fresh
     * {@code ExecDoneDep} satisfied at the end of a run. Arrays tagged with it (session-allocated
     * placeholder copies, outputs feeding constant Enter ops) are therefore kept until then.
     *
     * <p>It has no fields of its own. Where it is created in this class the inherited frame fields
     * stay null, so all such instances compare equal.
     *
     * <p>Note: the decompiler reports this class was originally declared {@code protected}; it is
     * kept {@code public} here to preserve the current interface.
     */
    public static class ExecDoneDep extends Dep {
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof ExecDoneDep)) {
                return false;
            }
            ExecDoneDep other = (ExecDoneDep) obj;
            return other.canEqual(this) && super.equals(obj);
        }

        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof ExecDoneDep;
        }

        @Override
        public int hashCode() {
            // No extra state beyond the base class.
            return super.hashCode();
        }

        @Override
        public String toString() {
            return "InferenceSession.ExecDoneDep()";
        }
    }

    /**
     * Dependency on one execution of an op, identified by op name, frame, iteration and parent
     * frame. {@code getOutputs2} marks an op's {@code OpDep} satisfied right after the op runs.
     * {@code postProcessOutput} does the same for OP steps reported by the execution dependency
     * tracker.
     */
    public static class OpDep extends Dep {
        protected String opName;
        protected int iter;

        /**
         * @param opName      name of the op this dependency waits on; must not be null
         * @param frame       frame in which that op executes; must not be null
         * @param iter        iteration within {@code frame}
         * @param parentFrame enclosing frame/iteration, or null for the outermost frame
         */
        protected OpDep(@NonNull String opName, @NonNull String frame, int iter, FrameIter parentFrame) {
            if (opName == null) {
                throw new NullPointerException("opName is marked non-null but is null");
            }
            if (frame == null) {
                throw new NullPointerException("frame is marked non-null but is null");
            }
            this.opName = opName;
            this.frame = frame;
            this.iter = iter;
            this.parentFrame = parentFrame;
        }

        /** Creates a dependency with no frame information (frame and parent frame left null). */
        public OpDep(String opName, int iter) {
            this.opName = opName;
            this.iter = iter;
        }

        public String getOpName() {
            return this.opName;
        }

        public void setOpName(String opName) {
            this.opName = opName;
        }

        public int getIter() {
            return this.iter;
        }

        public void setIter(int iter) {
            this.iter = iter;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof OpDep)) {
                return false;
            }
            OpDep that = (OpDep) obj;
            boolean sameBase = that.canEqual(this) && super.equals(obj);
            if (!sameBase || this.getIter() != that.getIter()) {
                return false;
            }
            String a = getOpName();
            String b = that.getOpName();
            return (a != null) ? a.equals(b) : b == null;
        }

        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof OpDep;
        }

        @Override
        public int hashCode() {
            final int prime = 59;
            int result = super.hashCode();
            result = result * prime + getIter();
            String name = getOpName();
            result = result * prime + (name != null ? name.hashCode() : 43);
            return result;
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder("OpDep(")
                    .append(this.opName)
                    .append(",frame=").append(this.frame)
                    .append(",iter=").append(this.iter);
            if (this.parentFrame != null) {
                sb.append(",parent=").append(this.parentFrame);
            }
            return sb.append(')').toString();
        }
    }

    /**
     * Dependency key attached (in {@code preprocessPlaceholders}) to a caller-supplied placeholder
     * array whose data type already matches the declared variable type, identified by the
     * placeholder name.
     *
     * <p>NOTE(review): no code in this part of the session marks a {@code PlaceholderDep}
     * satisfied, which presumably prevents caller-owned arrays from being released — confirm.
     */
    protected static class PlaceholderDep extends Dep {
        protected String phName;

        public PlaceholderDep(String phName) {
            this.phName = phName;
        }

        public String getPhName() {
            return this.phName;
        }

        public void setPhName(String phName) {
            this.phName = phName;
        }

        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof PlaceholderDep;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof PlaceholderDep)) {
                return false;
            }
            PlaceholderDep other = (PlaceholderDep) obj;
            if (!other.canEqual(this)) {
                return false;
            }
            if (!super.equals(obj)) {
                return false;
            }
            String left = getPhName();
            String right = other.getPhName();
            if (left == null) {
                return right == null;
            }
            return left.equals(right);
        }

        @Override
        public int hashCode() {
            int base = super.hashCode();
            String name = getPhName();
            return base * 59 + (name == null ? 43 : name.hashCode());
        }

        @Override
        public String toString() {
            return "InferenceSession.PlaceholderDep(phName=" + getPhName() + ")";
        }
    }

    /**
     * Dependency key attached (in {@code getOutputs2}) to an outer-frame output array that the
     * caller requested, identified by the output variable name.
     *
     * <p>Note: the decompiler reports this class was originally declared {@code protected}; it is
     * kept {@code public} here to preserve the current interface.
     */
    public static class ReqOutputDep extends Dep {
        protected String outputName;

        public ReqOutputDep(String outputName) {
            this.outputName = outputName;
        }

        public String getOutputName() {
            return this.outputName;
        }

        public void setOutputName(String outputName) {
            this.outputName = outputName;
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (obj instanceof ReqOutputDep) {
                ReqOutputDep other = (ReqOutputDep) obj;
                if (other.canEqual(this) && super.equals(obj)) {
                    String mine = getOutputName();
                    String theirs = other.getOutputName();
                    return mine == null ? theirs == null : mine.equals(theirs);
                }
            }
            return false;
        }

        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof ReqOutputDep;
        }

        @Override
        public int hashCode() {
            String name = getOutputName();
            int nameHash = name == null ? 43 : name.hashCode();
            return super.hashCode() * 59 + nameHash;
        }

        @Override
        public String toString() {
            return "InferenceSession.ReqOutputDep(outputName=" + getOutputName() + ")";
        }
    }

    /**
     * Dependency key attached (in {@code preprocessPlaceholders}) to the array of every VARIABLE
     * (trainable parameter) in the graph, identified by the variable's name.
     *
     * <p>NOTE(review): no code in this part of the session marks a {@code VariableDep} satisfied,
     * which presumably keeps variable arrays from ever being released by the session — confirm.
     */
    protected static class VariableDep extends Dep {
        protected String varName;

        public VariableDep(String varName) {
            this.varName = varName;
        }

        public String getVarName() {
            return this.varName;
        }

        public void setVarName(String varName) {
            this.varName = varName;
        }

        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof VariableDep)) {
                return false;
            }
            VariableDep other = (VariableDep) obj;
            if (!(other.canEqual(this) && super.equals(obj))) {
                return false;
            }
            final String a = getVarName();
            final String b = other.getVarName();
            return a == b || (a != null && a.equals(b));
        }

        @Override
        protected boolean canEqual(Object obj) {
            return obj instanceof VariableDep;
        }

        @Override
        public int hashCode() {
            final String name = getVarName();
            final int nameHash = (name == null) ? 43 : name.hashCode();
            return (super.hashCode() * 59) + nameHash;
        }

        @Override
        public String toString() {
            return "InferenceSession.VariableDep(varName=" + getVarName() + ")";
        }
    }

    /**
     * Creates an inference session bound to {@code sameDiff}.
     *
     * <p>The session starts with an {@code IdentityDependencyTracker} for array-use tracking, an
     * empty op-context map and an {@link ArrayCacheMemoryMgr} as its memory manager.
     *
     * @param sameDiff the graph to execute; must not be null
     * @throws NullPointerException if {@code sameDiff} is null
     */
    public InferenceSession(@NonNull SameDiff sameDiff) {
        super(sameDiff);
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked non-null but is null");
        }
        this.arrayUseTracker = new IdentityDependencyTracker();
        this.opContexts = new HashMap<>();
        this.mmgr = new ArrayCacheMemoryMgr();
    }

    /**
     * Prepares caller-supplied placeholder arrays for an execution and resets array-use tracking.
     *
     * <p>In order, this method:
     * <ol>
     *   <li>clears the array-use tracker and registers the array of every CONSTANT variable with a
     *       {@link ConstantDep} and of every VARIABLE with a {@link VariableDep};</li>
     *   <li>for each graph input whose name ends with {@code "keras_learning_phase"} and that the
     *       caller did not supply, allocates a scalar BOOL array set to
     *       {@code at.operation().isTrainingPhase()} (workaround for TF/Keras-imported models);</li>
     *   <li>rejects placeholders attached to a non-CIRCULAR workspace that is no longer open, or that
     *       has been reopened since the array was created;</li>
     *   <li>copies any placeholder whose data type differs from its variable's declared type into a
     *       newly allocated array of the declared type.</li>
     * </ol>
     * Arrays allocated here (training-phase scalars and type-cast copies) are registered with an
     * {@link ExecDoneDep}, so they are released once {@code postProcessOutput} satisfies it.
     * Caller-owned arrays of the correct type get a {@link PlaceholderDep} instead.
     *
     * @param map placeholder name to array; may be {@code null}
     * @param at  execution context; only its operation's training phase is read here
     * @return {@code map} as-is when it is {@code null} or empty after step 2, otherwise a new map
     *         holding the validated (and, where needed, type-cast) arrays
     * @throws ND4JIllegalStateException if a placeholder refers to a closed or outdated workspace
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    protected Map<String, INDArray> preprocessPlaceholders(Map<String, INDArray> map, At at) {
        this.arrayUseTracker.clear();
        for (SDVariable sDVariable : this.sameDiff.variables()) {
            if (sDVariable.getVariableType() == VariableType.CONSTANT) {
                this.arrayUseTracker.addDependency(sDVariable.getArr(), new ConstantDep(sDVariable.name()));
            } else if (sDVariable.getVariableType() == VariableType.VARIABLE) {
                this.arrayUseTracker.addDependency(sDVariable.getArr(), new VariableDep(sDVariable.name()));
            }
        }

        // Names of the training-phase scalars allocated below. They are tracked per name rather than
        // with a single flag, so that a caller-supplied "*keras_learning_phase" array is never given an
        // ExecDoneDep (and later released) as if the session owned it.
        Set<String> autoFilledPhaseInputs = new LinkedHashSet<>();
        List<String> inputs = this.sameDiff.inputs();
        if (inputs != null && !inputs.isEmpty()) {
            for (String inputName : inputs) {
                // map may legitimately be null (no placeholders supplied); previously this threw an NPE.
                if (inputName.endsWith(KERAS_TRAIN_TEST) && (map == null || !map.containsKey(inputName))) {
                    INDArray trainingPhase = this.mmgr.allocate(false, DataType.BOOL, new long[0])
                            .assign(at.operation().isTrainingPhase());
                    // Copy instead of mutating the caller's (possibly unmodifiable) map.
                    Map<String, INDArray> copy = new HashMap<>();
                    if (map != null) {
                        copy.putAll(map);
                    }
                    map = copy;
                    map.put(inputName, trainingPhase);
                    autoFilledPhaseInputs.add(inputName);
                }
            }
        }
        if (map == null || map.isEmpty()) {
            return map;
        }

        Map<String, INDArray> result = new HashMap<>();
        for (Map.Entry<String, INDArray> entry : map.entrySet()) {
            String name = entry.getKey();
            Preconditions.checkState(this.sameDiff.hasVariable(name), "Invalid placeholder passed for execution: No variable/placeholder with name %s exists", name);
            INDArray value = entry.getValue();
            if (value.isAttached()) {
                // Memory in a non-circular workspace is only valid while that workspace is open and still
                // in the generation in which the array was created.
                MemoryWorkspace parentWorkspace = value.data() == null ? null : value.data().getParentWorkspace();
                if (parentWorkspace != null && parentWorkspace.getWorkspaceType() != MemoryWorkspace.Type.CIRCULAR) {
                    if (!parentWorkspace.isScopeActive()) {
                        throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses leaked workspace pointer from workspace [" + parentWorkspace.getId() + "]: Workspace the array was defined in is no longer open.\nAll open workspaces: " + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
                    }
                    if (parentWorkspace.getGenerationId() != value.data().getGenerationId()) {
                        throw new ND4JIllegalStateException("Placeholder \"" + name + "\" array uses outdated workspace pointer from workspace [" + parentWorkspace.getId() + "]: Workspace array was defined in has been closed and reopened at least once since array creation. Array WS iteration: " + value.data().getGenerationId() + ". Workspace current iteration: " + parentWorkspace.getGenerationId() + "\nAll open workspaces: " + DefaultOpExecutioner.allOpenWorkspaces() + "\n" + SCOPE_PANIC_MSG);
                    }
                }
            }
            DataType declaredType = this.sameDiff.getVariable(name).dataType();
            if (autoFilledPhaseInputs.contains(name)) {
                // Session-allocated training-phase scalar: release it once execution completes.
                this.arrayUseTracker.addDependency(value, new ExecDoneDep());
            } else if (value.dataType() == declaredType) {
                this.arrayUseTracker.addDependency(value, new PlaceholderDep(name));
            } else {
                // Cast to the declared type. The session owns the copy, so it is released at the end of execution.
                INDArray cast = this.mmgr.allocate(false, declaredType, value.shape());
                cast.assign(value);
                value = cast;
                this.arrayUseTracker.addDependency(value, new ExecDoneDep());
            }
            result.put(name, value);
        }
        return result;
    }

    /**
     * Finalizes memory bookkeeping after an execution and returns the outputs unchanged.
     *
     * <p>Op steps that the execution dependency tracker reports as newly satisfied are forwarded to
     * the array-use tracker as satisfied {@link OpDep}s. The execution-wide {@link ExecDoneDep} is
     * then satisfied, and every array whose dependencies are now all met is released to the memory
     * manager.
     *
     * @param map the outputs produced by the execution
     * @return {@code map}, unmodified
     */
    @Override
    protected Map<String, INDArray> postProcessOutput(Map<String, INDArray> map) {
        if (this.dt.hasNewAllSatisfied()) {
            for (AbstractSession.ExecStep step : this.dt.getNewAllSatisfiedList()) {
                if (step.getType() != AbstractSession.ExecType.OP) {
                    continue;
                }
                FrameIter stepFrame = step.getFrameIter();
                OpDep executed = new OpDep(step.getName(), stepFrame.getFrame(), stepFrame.getIteration(), stepFrame.getParentFrame());
                this.arrayUseTracker.markSatisfied(executed, true);
            }
        }

        this.arrayUseTracker.markSatisfied(new ExecDoneDep(), true);

        if (this.arrayUseTracker.hasNewAllSatisfied()) {
            for (INDArray releasable : this.arrayUseTracker.getNewAllSatisfiedList()) {
                this.mmgr.release(releasable);
            }
        }
        return map;
    }

    /* renamed from: getOutputs, reason: avoid collision after fix types in other method */
    /**
     * Executes a single op (via {@code doExec}) and performs the surrounding bookkeeping.
     *
     * <p>In order, this method:
     * <ol>
     *   <li>notifies active listeners before execution ({@code preOpExecution});</li>
     *   <li>runs the op through {@code doExec};</li>
     *   <li>notifies active listeners after execution ({@code opExecution}, then
     *       {@code activationAvailable} once per output variable);</li>
     *   <li>clears the op's array references and purges the {@link OpContext}, if any;</li>
     *   <li>for every output array, registers one dependency per consumer op that is in
     *       {@code subgraphOps}, protects requested outer-frame outputs with a {@link ReqOutputDep},
     *       and immediately releases outputs that nothing consumes or depends on;</li>
     *   <li>marks this op's own {@link OpDep} satisfied and releases every array whose
     *       dependencies are now all satisfied.</li>
     * </ol>
     *
     * @param pair         the op to execute and its (possibly null) op context
     * @param frameIter    frame/iteration in which the outputs are produced
     * @param set          input variable ids, forwarded to doExec
     * @param set2         further input variable ids forwarded to doExec (preferred over {@code set}
     *                     there for Enter/Exit/NextIteration when non-empty) — TODO confirm meaning
     * @param set3         input names that doExec may resolve in the outer frame at iteration 0
     * @param list         listeners to notify; may be null or empty
     * @param at           listener context; its frame iter is overwritten with {@code frameIter}
     * @param multiDataSet batch passed through to listener callbacks
     * @param set4         names of the variables requested as outputs; outer-frame outputs in this
     *                     set are never released here
     * @return the op's output arrays; for a Switch op the output of the untaken branch is null
     */
    public INDArray[] getOutputs2(Pair<SameDiffOp, OpContext> pair, FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3, List<Listener> list, At at, MultiDataSet multiDataSet, Set<String> set4) {
        SameDiffOp first = pair.getFirst();
        at.setFrameIter(frameIter);
        // Pre-execution listener callbacks.
        if (list != null && list.size() > 0) {
            SameDiffOp sameDiffOp = this.sameDiff.getOps().get(first.getOp().getOwnName());
            for (Listener listener : list) {
                if (listener.isActive(at.operation())) {
                    listener.preOpExecution(this.sameDiff, at, sameDiffOp, pair.getSecond());
                }
            }
        }
        if (this.sameDiff.isDebugMode()) {
            log.info("Executing samediff op: " + first.getName());
        }
        INDArray[] doExec = doExec(first.getOp(), pair.getRight(), frameIter, set, set2, set3);
        // Trace-level dump of each output variable name and the id of the array produced for it.
        if (log.isTraceEnabled()) {
            StringBuilder sb = new StringBuilder();
            sb.append(first.getName()).append(" - ").append(frameIter).append(" outputs: ");
            List<String> outputsOfOp = first.getOutputsOfOp();
            for (int i = 0; i < doExec.length; i++) {
                if (i > 0) {
                    sb.append(", ");
                }
                sb.append("(").append(i).append(" - ").append(outputsOfOp.get(i)).append(" = ").append(doExec[i] == null ? null : Long.valueOf(doExec[i].getId())).append(")");
            }
            log.trace(sb.toString());
        }
        // Post-execution listener callbacks. The name -> array map is built lazily, only once the first
        // active listener is found, and shared (read-only) between listeners.
        if (list != null && list.size() > 0) {
            Map map = null;
            for (Listener listener2 : list) {
                if (listener2.isActive(at.operation())) {
                    if (map == null) {
                        HashMap hashMap = new HashMap();
                        for (int i2 = 0; i2 < doExec.length; i2++) {
                            hashMap.put(first.outputsOfOp.get(i2), doExec[i2]);
                        }
                        map = Collections.unmodifiableMap(hashMap);
                    }
                    listener2.opExecution(this.sameDiff, at, multiDataSet, first, pair.getSecond(), doExec);
                    for (String str : map.keySet()) {
                        listener2.activationAvailable(this.sameDiff, at, multiDataSet, first, str, (INDArray) map.get(str));
                    }
                }
            }
        }
        // Drop this execution's array references from the op and its context.
        first.getOp().clearArrays();
        if (pair.getSecond() != null) {
            pair.getSecond().purge();
        }
        // Memory management: record, for each output array, which downstream op executions still need it.
        SameDiffOp sameDiffOp2 = this.sameDiff.getOps().get(first.getName());
        List<String> outputsOfOp2 = sameDiffOp2.getOutputsOfOp();
        for (int i3 = 0; i3 < doExec.length; i3++) {
            // A Switch yields exactly one non-null output (the taken branch); skip the null one.
            // NOTE(review): null outputs from non-Switch ops are not skipped and would reach
            // addDependency/release below — confirm doExec never returns them.
            if (doExec[i3] != null || !(sameDiffOp2.getOp() instanceof Switch)) {
                String str2 = outputsOfOp2.get(i3);
                List<String> inputsForOp = ((Variable) this.sameDiff.getVariables().get(str2)).getInputsForOp();
                if (inputsForOp != null) {
                    for (String str3 : inputsForOp) {
                        // Only consumers in subgraphOps are tracked; presumably others never execute and
                        // so would never satisfy the dependency.
                        if (this.subgraphOps.contains(str3)) {
                            SameDiffOp sameDiffOp3 = this.sameDiff.getOps().get(str3);
                            if (sameDiffOp3.getOp() instanceof Enter) {
                                Enter enter = (Enter) sameDiffOp3.getOp();
                                if (enter.isConstant()) {
                                    // Constant Enter: keep the array until the whole execution is done.
                                    this.arrayUseTracker.addDependency(doExec[i3], new ExecDoneDep());
                                } else {
                                    // Enter runs in its target frame at iteration 0, with this frame as parent.
                                    this.arrayUseTracker.addDependency(doExec[i3], new OpDep(str3, enter.getFrameName(), 0, frameIter));
                                }
                            } else if (sameDiffOp3.getOp() instanceof NextIteration) {
                                // NextIteration consumes the array in the next iteration of the current frame.
                                this.arrayUseTracker.addDependency(doExec[i3], new OpDep(str3, frameIter.getFrame(), frameIter.getIteration() + 1, frameIter.getParentFrame()));
                            } else if (sameDiffOp3.getOp() instanceof Exit) {
                                // Exit runs in the parent frame.
                                FrameIter parentFrame = frameIter.getParentFrame();
                                this.arrayUseTracker.addDependency(doExec[i3], new OpDep(str3, parentFrame.getFrame(), parentFrame.getIteration(), parentFrame.getParentFrame()));
                            } else {
                                // Ordinary consumer: same frame and iteration as this op.
                                this.arrayUseTracker.addDependency(doExec[i3], new OpDep(str3, frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame()));
                            }
                        }
                    }
                }
                if (AbstractSession.OUTER_FRAME.equals(frameIter.getFrame()) && set4.contains(str2)) {
                    // Requested outer-frame output: must survive until it is returned to the caller.
                    this.arrayUseTracker.addDependency(doExec[i3], new ReqOutputDep(str2));
                } else if ((inputsForOp == null || inputsForOp.isEmpty()) && !this.arrayUseTracker.hasDependency(doExec[i3])) {
                    // No consumers and no other dependency on this array: release it right away.
                    if (log.isTraceEnabled()) {
                        log.trace("Found array id {} (output of {}) not required anywhere, deallocating", Long.valueOf(doExec[i3].getId()), sameDiffOp2.getName());
                    }
                    this.mmgr.release(doExec[i3]);
                }
            }
        }
        // This op has now executed: satisfy dependencies waiting on it and release arrays no longer needed.
        this.arrayUseTracker.markSatisfied(new OpDep(first.getName(), frameIter.getFrame(), frameIter.getIteration(), frameIter.getParentFrame()), true);
        if (this.arrayUseTracker.hasNewAllSatisfied()) {
            for (INDArray iNDArray : this.arrayUseTracker.getNewAllSatisfiedList()) {
                if (log.isTraceEnabled()) {
                    log.trace("Closing array... id={}, {}", Long.valueOf(iNDArray.getId()), iNDArray.shapeInfoToString());
                }
                this.mmgr.release(iNDArray);
            }
        }
        return doExec;
    }

    /* JADX WARN: Multi-variable type inference failed */
    public INDArray[] doExec(DifferentialFunction differentialFunction, OpContext opContext, FrameIter frameIter, Set<AbstractSession.VarId> set, Set<AbstractSession.VarId> set2, Set<String> set3) {
        INDArray inputArray;
        int size = (set == null ? 0 : set.size()) + (set3 == null ? 0 : set3.size()) + (set2 == null ? 0 : set2.size());
        boolean z = (set == null || set.size() == 0) && (set2 == null || set2.size() == 0);
        if (differentialFunction instanceof Identity) {
            String[] argNames = ((Identity) differentialFunction).argNames();
            Preconditions.checkState(argNames.length == 1, "Expected only 1 arg name in identity op, got %s", (Object) argNames);
            return new INDArray[]{(INDArray) this.nodeOutputs.get(frameIter.toVarId(argNames[0]))};
        }
        if (differentialFunction instanceof Switch) {
            String[] argNames2 = ((Switch) differentialFunction).argNames();
            AbstractSession.VarId varId = frameIter.toVarId(argNames2[1]);
            INDArray iNDArray = (INDArray) this.nodeOutputs.get(varId);
            if (iNDArray == null && !set3.isEmpty() && set3.contains(argNames2[1])) {
                iNDArray = (INDArray) this.nodeOutputs.get(new AbstractSession.VarId(argNames2[1], AbstractSession.OUTER_FRAME, 0, null));
            }
            Preconditions.checkNotNull(iNDArray, "Error during graph execution: Predicate array was null. VarId=%s", varId);
            Preconditions.checkState(iNDArray.isScalar() && iNDArray.dataType() == DataType.BOOL, "Expected boolean predicate: got %ndSInfo", iNDArray);
            AbstractSession.VarId varId2 = frameIter.toVarId(argNames2[0]);
            return iNDArray.getDouble(0L) == 0.0d ? new INDArray[]{(INDArray) this.nodeOutputs.get(varId2), null} : new INDArray[]{null, (INDArray) this.nodeOutputs.get(varId2)};
        }
        if (differentialFunction instanceof Enter) {
            Enter enter = (Enter) differentialFunction;
            String[] argNames3 = enter.argNames();
            Preconditions.checkState(argNames3.length == 1, "Expected only 1 arg name for enter op: got %s", (Object) argNames3);
            Preconditions.checkState(size == 1, "Expected exactly 1 op input for Enter op \"%s\", got %s+%s", enter.getOwnName(), set, set3);
            INDArray iNDArray2 = (INDArray) this.nodeOutputs.get(z ? new AbstractSession.VarId(set3.iterator().next(), AbstractSession.OUTER_FRAME, 0, null) : (set2 == null || set2.size() <= 0) ? set.iterator().next() : set2.iterator().next());
            Preconditions.checkNotNull(iNDArray2, "Could not get enter op \"%s\" input: output variable %s - %s", enter.getOwnName(), enter.outputVariablesNames(), frameIter);
            return new INDArray[]{iNDArray2};
        }
        if (differentialFunction instanceof Exit) {
            return new INDArray[]{(INDArray) this.nodeOutputs.get(z ? new AbstractSession.VarId(set3.iterator().next(), AbstractSession.OUTER_FRAME, 0, null) : (set2 == null || set2.size() <= 0) ? set.iterator().next() : set2.iterator().next())};
        }
        if (differentialFunction instanceof NextIteration) {
            Preconditions.checkState(size == 1, "Expected exactly 1 op input for NextIteration: got %s+%s", set, set3);
            AbstractSession.VarId next = (set2 == null || set2.isEmpty()) ? set.iterator().next() : set2.iterator().next();
            Preconditions.checkState(frameIter.getFrame().equals(next.getFrame()), "Expected same frame for NextIteration input vs. output: got input %s, output %s", next, frameIter);
            Preconditions.checkState(frameIter.getIteration() == next.getIteration() + 1, "Expected output iteration for NextIteration output to be 1 larger than the input iteration. Input: %s, output %s", next, frameIter);
            INDArray iNDArray3 = (INDArray) this.nodeOutputs.get(next);
            if (iNDArray3 == null) {
                Preconditions.throwStateEx("Could not find array for NextIteration operation %s with output %s (frame=%s, iteration=%s)", differentialFunction.getOwnName(), this.sameDiff.getOps().get(differentialFunction.getOwnName()).getOutputsOfOp().get(0), frameIter.getFrame(), Integer.valueOf(frameIter.getIteration()));
            }
            return new INDArray[]{iNDArray3};
        }
        if (differentialFunction instanceof Merge) {
            Merge merge = (Merge) differentialFunction;
            String[] inputsForOp = this.sameDiff.getInputsForOp(differentialFunction);
            for (String str : inputsForOp) {
                AbstractSession.VarId varId3 = frameIter.toVarId(str);
                if (this.nodeOutputs.containsKey(varId3)) {
                    log.trace("Returning input \"{}\" for merge node \"{}\"", merge.getOwnName(), str);
                    INDArray iNDArray4 = (INDArray) this.nodeOutputs.get(varId3);
                    Preconditions.checkState(iNDArray4 != null, "Could not find output array for %s", varId3);
                    return new INDArray[]{iNDArray4};
                }
            }
            throw new IllegalStateException("Merge node " + merge.getOwnName() + " has no available inputs (all inputs: " + Arrays.toString(inputsForOp) + ") - should not be executed at this point");
        }
        if (differentialFunction instanceof LoopCond) {
            String[] argNames4 = ((LoopCond) differentialFunction).argNames();
            Preconditions.checkState(argNames4.length == 1, "Expected only 1 arg name in LoopCond op, got %s", (Object) argNames4);
            INDArray iNDArray5 = (INDArray) this.nodeOutputs.get(frameIter.toVarId(argNames4[0]));
            Preconditions.checkNotNull(iNDArray5, "Input to LoopCond op must not be null");
            Preconditions.checkState(iNDArray5.isScalar() && iNDArray5.dataType() == DataType.BOOL, "LoopCond input must be a scalar boolean, got %ndShape");
            return new INDArray[]{iNDArray5};
        }
        if (differentialFunction instanceof BaseTensorOp) {
            return getOutputsHelperTensorArrayOps(differentialFunction, frameIter, set, set2);
        }
        if (differentialFunction instanceof GradientBackwardsMarker) {
            return new INDArray[]{this.mmgr.allocate(false, DataType.FLOAT, new long[0]).assign(Float.valueOf(1.0f))};
        }
        if (differentialFunction instanceof ExternalErrorsFunction) {
            INDArray iNDArray6 = (INDArray) this.nodeOutputs.get(new AbstractSession.VarId(((ExternalErrorsFunction) differentialFunction).getGradPlaceholderName(), AbstractSession.OUTER_FRAME, 0, null));
            Preconditions.checkState(iNDArray6 != null, "Could not find external errors placeholder array: %s", iNDArray6);
            INDArray allocate = this.mmgr.allocate(false, iNDArray6.dataType(), iNDArray6.shape());
            allocate.assign(iNDArray6);
            return new INDArray[]{allocate};
        }
        if (!(differentialFunction instanceof Assert)) {
            if (differentialFunction instanceof CustomOp) {
                Nd4j.exec((CustomOp) differentialFunction, opContext);
                return (INDArray[]) opContext.getOutputArrays().toArray(new INDArray[0]);
            }
            if (!(differentialFunction instanceof Op)) {
                throw new UnsupportedOperationException("Execution not yet implemented for: " + differentialFunction.getClass().getName());
            }
            Nd4j.exec((Op) differentialFunction, opContext);
            return new INDArray[]{opContext.getOutputArray(0)};
        }
        Assert r0 = (Assert) differentialFunction;
        if (opContext.getInputArray(0).getDouble(0L) != 0.0d) {
            return (INDArray[]) opContext.getOutputArrays().toArray(new INDArray[0]);
        }
        String str2 = "Assertion failed for operation \"" + differentialFunction.getOwnName() + "\" during execution";
        if (r0.numInputArguments() >= 3 && (inputArray = opContext.getInputArray(2)) != null && inputArray.dataType() == DataType.UTF8) {
            str2 = str2 + ": " + inputArray.getString(0L);
        }
        if (r0.numInputArguments() >= 5) {
            str2 = str2 + "\n" + opContext.getInputArray(4);
        }
        throw new IllegalStateException(str2);
    }

    /**
     * Executes a TensorArray op ({@link BaseTensorOp} subclass) directly against this session's TensorArray
     * state, without using an {@link OpContext}.
     *
     * <p>Every TensorArray is stored as a {@code List<INDArray>} in {@code tensorArrays}, keyed by the VarId of
     * the creating TensorArray op's output in the frame/iteration it ran in. Ops that mutate a TensorArray
     * (write, scatter, split) update that list in place and return a dummy FLOAT scalar (0.0). Every array
     * stored into a TensorArray is registered with {@code arrayUseTracker} using an {@code ExecDoneDep}.
     *
     * @param differentialFunction the TensorArray op to execute
     * @param frameIter            frame and iteration the op is executed in
     * @param opInputs             VarIds of the op's inputs for this frame/iteration; may be null
     * @param allIterInputs        additional input VarIds searched when a lookup in {@code opInputs} fails; may be null
     * @return the op's output arrays
     * @throws IllegalStateException if a TensorArray or input cannot be found, an index is invalid, or the op type
     *                               is not supported
     */
    public INDArray[] getOutputsHelperTensorArrayOps(DifferentialFunction differentialFunction, FrameIter frameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        if (differentialFunction instanceof TensorArray) {
            // Create a new, empty TensorArray keyed by this op's output variable in the current frame/iteration
            AbstractSession.VarId varId = frameIter.toVarId(differentialFunction.outputVariable().name());
            Preconditions.checkState(!this.tensorArrays.containsKey(varId), "TensorArray already exists for %s when executing TensorArrayV3", varId);
            this.tensorArrays.put(varId, new ArrayList<>());
            return new INDArray[]{this.mmgr.allocate(false, DataType.BOOL, new long[0]).assign(true), allocateZeroFloatScalar()};
        }
        if (differentialFunction instanceof TensorArrayRead) {
            // Inputs: 0 = TensorArray handle, 1 = scalar read index
            INDArray indexArr = getArray(differentialFunction.arg(1), opInputs, allIterInputs);
            Preconditions.checkState(indexArr.isScalar(), "TensorArrayRead input argument 1 should be scalar - has shape %ndShape", indexArr);
            int index = indexArr.getInt(0);
            AbstractSession.VarId tensorArrayId = resolveTensorArrayThroughEnter(differentialFunction.arg(0), opInputs, allIterInputs);
            List<INDArray> list = tensorArrayFor(tensorArrayId);
            Preconditions.checkState(list != null, "Could not find TensorList for %s", tensorArrayId);
            Preconditions.checkState(index >= 0, "Index for TensorArrayRead must be >= 0, got %s", index);
            Preconditions.checkState(list.size() > index, "Cannot get index %s from TensorList of size %s (array not present?) - VarId=%s", Integer.valueOf(index), Integer.valueOf(list.size()), tensorArrayId);
            return new INDArray[]{list.get(index)};
        }
        if (differentialFunction instanceof TensorArrayWrite) {
            // Inputs: 0 = TensorArray handle, 1 = scalar write index, 2 = value to store
            AbstractSession.VarId tensorArrayId = resolveTensorArrayThroughEnter(differentialFunction.arg(0), opInputs, allIterInputs);
            INDArray indexArr = getArray(this.sameDiff.getVariable(differentialFunction.arg(1).name()), opInputs, allIterInputs);
            Preconditions.checkState(indexArr.isScalar(), "Index variable ID for TensorArrayWrite should be a scalar, got %ndShape", indexArr);
            int index = indexArr.getInt(0);
            Preconditions.checkState(index >= 0, "Index for TensorArrayWrite must be >= 0, got %s", index);
            String valueName = differentialFunction.arg(2).name();
            INDArray value = getArray(this.sameDiff.getVariable(valueName), opInputs, allIterInputs);
            Preconditions.checkState(value != null, "Could not find array for %s", valueName);
            Preconditions.checkState(this.tensorArrays.containsKey(tensorArrayId), "Tensor array does not exist for %s", tensorArrayId);
            setPadded(tensorArrayFor(tensorArrayId), index, value);
            this.arrayUseTracker.addDependency(value, new ExecDoneDep());
            return new INDArray[]{allocateZeroFloatScalar()};
        }
        if (differentialFunction instanceof TensorArraySize) {
            AbstractSession.VarId tensorArrayId = findTensorArrayInput(differentialFunction.arg(0).name(), opInputs, allIterInputs);
            List<INDArray> list = tensorArrayFor(tensorArrayId);
            Preconditions.checkState(list != null, "Could not find TensorArray: %s", tensorArrayId);
            return new INDArray[]{this.mmgr.allocate(false, DataType.INT, new long[0]).assign(Integer.valueOf(list.size()))};
        }
        if (differentialFunction instanceof TensorArrayConcat) {
            AbstractSession.VarId tensorArrayId = findTensorArrayInput(differentialFunction.arg(0).name(), opInputs, allIterInputs);
            List<INDArray> list = tensorArrayFor(tensorArrayId);
            Preconditions.checkState(list != null, "Could not find TensorArray: %s", tensorArrayId);
            // Concatenate all stored elements along dimension 0
            Concat concat = new Concat(0, list.toArray(new INDArray[0]));
            INDArray output = this.mmgr.allocate(false, concat.calculateOutputShape().get(0));
            concat.setOutputArgument(0, output);
            Nd4j.exec(concat);
            return new INDArray[]{output};
        }
        if (differentialFunction instanceof TensorArrayGather) {
            // Inputs: 0 = TensorArray handle, 1 = integer vector of indices ([-1] means "all elements")
            AbstractSession.VarId tensorArrayId = findTensorArrayInput(differentialFunction.arg(0).name(), opInputs, allIterInputs);
            List<INDArray> list = tensorArrayFor(tensorArrayId);
            Preconditions.checkState(list != null, "Could not find TensorArray: %s", tensorArrayId);
            String indicesName = differentialFunction.arg(1).name();
            INDArray indicesArr = getArray(this.sameDiff.getVariable(indicesName), opInputs, allIterInputs);
            Preconditions.checkState(indicesArr.isVector(), "Indices variable for TensorArrayGather should be a vector, got %ndShape for %s", indicesArr, indicesName);
            Preconditions.checkState(indicesArr.dataType().isIntType(), "Indices variable for TensorArrayGather should be an integer type, got %s for array %s", indicesArr.dataType(), indicesName);
            int[] indices = indicesArr.toIntVector();
            List<INDArray> gathered = new ArrayList<>(indices.length);
            if (indices.length == 1 && indices[0] == -1) {
                // Same "[-1] = all" sentinel as TensorArrayScatter
                gathered.addAll(list);
            } else {
                for (int idx : indices) {
                    Preconditions.checkState(idx >= 0, "Index for TensorArrayGather must be >= 0, got %s", idx);
                    Preconditions.checkState(idx < list.size(), "Index %s for TensorArrayGather is out of range for TensorArray of size %s", Integer.valueOf(idx), Integer.valueOf(list.size()));
                    gathered.add(list.get(idx));
                }
            }
            // Stack the selected elements along a new dimension 0
            Stack stack = new Stack(gathered.toArray(new INDArray[0]), (INDArray) null, 0);
            INDArray output = this.mmgr.allocate(false, stack.calculateOutputShape().get(0));
            stack.setOutputArgument(0, output);
            Nd4j.exec(stack);
            return new INDArray[]{output};
        }
        if (differentialFunction instanceof TensorArrayScatter) {
            // Inputs: 0 = TensorArray handle, 1 = integer vector of target indices ([-1] means 0..n-1),
            // 2 = values; row i of the values array is stored at TensorArray index indices[i]
            AbstractSession.VarId tensorArrayId = findTensorArrayInput(differentialFunction.arg(0).name(), opInputs, allIterInputs);
            List<INDArray> list = tensorArrayFor(tensorArrayId);
            Preconditions.checkState(list != null, "Could not find TensorArray: %s", tensorArrayId);
            String indicesName = differentialFunction.arg(1).name();
            INDArray indicesArr = getArray(this.sameDiff.getVariable(indicesName), opInputs, allIterInputs);
            Preconditions.checkState(indicesArr.isVector(), "Indices variable for TensorArrayScatter should be a vector, got %ndShape for %s", indicesArr, indicesName);
            Preconditions.checkState(indicesArr.dataType().isIntType(), "Indices variable for TensorArrayScatter should be an integer type, got %s for array %s", indicesArr.dataType(), indicesName);
            int[] indices = indicesArr.toIntVector();
            INDArray values = getArray(this.sameDiff.getVariable(differentialFunction.arg(2).name()), opInputs, allIterInputs);
            if (indices.length == 1 && indices[0] == -1) {
                indices = ArrayUtil.range(0, (int) values.size(0));
            }
            INDArrayIndex[] rowSelector = (INDArrayIndex[]) ArrayUtil.nTimes(values.rank(), NDArrayIndex.all(), INDArrayIndex.class);
            for (int i = 0; i < indices.length; i++) {
                int target = indices[i];
                Preconditions.checkState(target >= 0, "Index for TensorArrayScatter must be >= 0, got %s", target);
                rowSelector[0] = NDArrayIndex.point(i);
                INDArray row = this.mmgr.dup(values.get(rowSelector));
                if (values.rank() == 1 && row.rank() > 0) {
                    // Elements taken from a rank-1 values array are stored as rank-0 scalars
                    row = row.reshape(new long[0]);
                }
                // Grow the list only as far as needed, so no trailing null slots are left behind
                setPadded(list, target, row);
                this.arrayUseTracker.addDependency(row, new ExecDoneDep());
            }
            return new INDArray[]{allocateZeroFloatScalar()};
        }
        if (differentialFunction instanceof TensorArraySplit) {
            // Inputs: 0 = TensorArray handle, 1 = array to split along dimension 0,
            // 2 = integer vector of piece sizes; piece i is stored at TensorArray index i
            AbstractSession.VarId tensorArrayId = findTensorArrayInput(differentialFunction.arg(0).name(), opInputs, allIterInputs);
            List<INDArray> list = tensorArrayFor(tensorArrayId);
            Preconditions.checkState(list != null, "Could not find TensorArray: %s", tensorArrayId);
            INDArray source = getArray(this.sameDiff.getVariable(differentialFunction.arg(1).name()), opInputs, allIterInputs);
            String sizesName = differentialFunction.arg(2).name();
            INDArray sizesArr = getArray(this.sameDiff.getVariable(sizesName), opInputs, allIterInputs);
            Preconditions.checkState(sizesArr.isVector(), "Indices variable for TensorArraySplit should be a vector, got %ndShape for %s", sizesArr, sizesName);
            Preconditions.checkState(sizesArr.dataType().isIntType(), "Indices variable for TensorArraySplit should be an integer type, got %s for array %s", sizesArr.dataType(), sizesName);
            int[] sizes = sizesArr.toIntVector();
            INDArrayIndex[] rangeSelector = (INDArrayIndex[]) ArrayUtil.nTimes(source.rank(), NDArrayIndex.all(), INDArrayIndex.class);
            int offset = 0;
            for (int i = 0; i < sizes.length; i++) {
                rangeSelector[0] = NDArrayIndex.interval(offset, offset + sizes[i]);
                INDArray piece = this.mmgr.dup(source.get(rangeSelector));
                // Grow the list only as far as needed, so no trailing null slot is left behind
                setPadded(list, i, piece);
                offset += sizes[i];
                this.arrayUseTracker.addDependency(piece, new ExecDoneDep());
            }
            return new INDArray[]{allocateZeroFloatScalar()};
        }
        throw new IllegalStateException("Execution support not yet implemented for: " + differentialFunction.getClass().getName());
    }

    /** Finds the VarId of a TensorArray-handle input: first in {@code opInputs}, then in {@code allIterInputs}; null if absent. */
    private AbstractSession.VarId findTensorArrayInput(String name, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId varId = opInputs == null ? null : lookup(name, opInputs, false);
        if (varId == null && allIterInputs != null) {
            varId = lookup(name, allIterInputs, false);
        }
        return varId;
    }

    /** Finds a TensorArray-handle input, following any chain of Enter ops back to the frame that owns the TensorArray. */
    private AbstractSession.VarId resolveTensorArrayThroughEnter(SDVariable tensorArrayVar, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs) {
        AbstractSession.VarId varId = findTensorArrayInput(tensorArrayVar.name(), opInputs, allIterInputs);
        Preconditions.checkState(varId != null, "Could not find input %s", tensorArrayVar.name());
        SDVariable current = tensorArrayVar;
        while (this.sameDiff.getVariableOutputOp(current.name()) instanceof Enter) {
            current = this.sameDiff.getVariableOutputOp(current.name()).arg();
            varId = varId.getParentFrame().toVarId(current.name());
        }
        return varId;
    }

    /** Returns the element list backing the TensorArray with the given VarId, or null if none exists. */
    @SuppressWarnings("unchecked")
    private List<INDArray> tensorArrayFor(AbstractSession.VarId varId) {
        return (List<INDArray>) this.tensorArrays.get(varId);
    }

    /** Stores {@code value} at {@code index}, first padding the list with nulls up to exactly {@code index + 1} elements if needed. */
    private static void setPadded(List<INDArray> list, int index, INDArray value) {
        while (list.size() <= index) {
            list.add(null);
        }
        list.set(index, value);
    }

    /** Allocates the dummy FLOAT scalar 0.0 returned by TensorArray ops that have no meaningful output. */
    private INDArray allocateZeroFloatScalar() {
        return this.mmgr.allocate(false, DataType.FLOAT, new long[0]).assign(Double.valueOf(0.0d));
    }

    /**
     * Returns the array that the SameDiff instance holds for a constant or trainable (VARIABLE) variable.
     *
     * @param str name of the variable
     * @return the array returned by {@code sameDiff.getArrForVarName(str)}
     * @throws IllegalStateException if no variable with that name exists, or if it is neither a constant nor a VARIABLE
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public INDArray getConstantOrVariable(String str) {
        // Look the variable up once; an unknown name gets a descriptive error instead of a NullPointerException
        SDVariable variable = this.sameDiff.getVariable(str);
        Preconditions.checkState(variable != null, "No variable with name \"%s\" exists", str);
        Preconditions.checkState(variable.isConstant() || variable.getVariableType() == VariableType.VARIABLE, "Variable %s is not a constant or variable (type: %s)", str, variable.getVariableType());
        return this.sameDiff.getArrForVarName(str);
    }

    /**
     * Looks up the op with the given name and prepares an {@link OpContext} for executing it.
     *
     * <p>Control-flow ops (LoopCond, Enter, Exit, NextIteration, Merge, Switch) and TensorArray ops
     * ({@link BaseTensorOp}) are returned with a null context. For every other op:
     * <ul>
     *   <li>input arrays are resolved: constants/VARIABLEs from SameDiff, placeholders from {@code placeholderValues},
     *       everything else from this session's node outputs;</li>
     *   <li>arguments are copied into a per-op cached context ({@code opContexts});</li>
     *   <li>output arrays are allocated through {@code mmgr}.</li>
     * </ul>
     *
     * @param opName            name of the op to parameterize
     * @param frameIter         frame/iteration the op is executed in (not used by this implementation)
     * @param opInputs          VarIds of the op's inputs; may be null
     * @param allIterInputs     additional input VarIds; may be null
     * @param constAndPhInputs  names counted together with {@code opInputs} when validating the arg count; may be null
     * @param placeholderValues placeholder arrays by name; must contain every placeholder input of the op
     * @param allReqVariables   names whose membership is passed as the first argument of {@code mmgr.allocate}
     *                          when allocating outputs
     * @return the op paired with its parameterized context (null context for control-flow/TensorArray ops)
     * @throws IllegalStateException if an input array is missing or output shape calculation fails
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public Pair<SameDiffOp, OpContext> getAndParameterizeOp(String opName, FrameIter frameIter, Set<AbstractSession.VarId> opInputs, Set<AbstractSession.VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, INDArray> placeholderValues, Set<String> allReqVariables) {
        SameDiffOp sameDiffOp = this.sameDiff.getOps().get(opName);
        DifferentialFunction op = sameDiffOp.getOp();
        Preconditions.checkNotNull(op, "No differential function found with name \"%s\"", opName);
        // These ops are executed directly from session state and need no OpContext
        if ((op instanceof LoopCond) || (op instanceof Enter) || (op instanceof Exit) || (op instanceof NextIteration) || (op instanceof Merge) || (op instanceof Switch) || (op instanceof BaseTensorOp)) {
            return new Pair<>(sameDiffOp, null);
        }

        String[] argNames = op.argNames();
        int numArgs = argNames == null ? 0 : argNames.length;
        int numOpInputs = opInputs == null ? 0 : opInputs.size();
        int numAllIterInputs = allIterInputs == null ? 0 : allIterInputs.size();
        int numConstPhInputs = constAndPhInputs == null ? 0 : constAndPhInputs.size();
        // Only ops with at most one arg are validated here. NOTE(review): a mismatch for multi-arg ops is tolerated,
        // presumably because repeated arg names map to fewer distinct inputs - confirm.
        if (numArgs != numOpInputs + numConstPhInputs + numAllIterInputs && numArgs <= 1) {
            Preconditions.checkState(numArgs == numOpInputs + numConstPhInputs, "Different number of arg names as op inputs for op %s (%s): arg names %s vs. op inputs %s+%s", op.getClass().getSimpleName(), opName, argNames, opInputs, constAndPhInputs);
        }

        INDArray[] args = null;
        if (argNames != null && argNames.length > 0) {
            args = new INDArray[argNames.length];
            for (int i = 0; i < argNames.length; i++) {
                String argName = argNames[i];
                SDVariable variable = this.sameDiff.getVariable(argName);
                if (variable.isConstant() || variable.getVariableType() == VariableType.VARIABLE) {
                    args[i] = variable.getArr();
                } else if (variable.isPlaceHolder()) {
                    Preconditions.checkState(placeholderValues != null && placeholderValues.containsKey(argName), "No array was provided for required placeholder variable \"%s\"", argName);
                    args[i] = placeholderValues.get(argName);
                } else {
                    args[i] = (INDArray) this.nodeOutputs.get(lookup(argName, opInputs, allIterInputs, true));
                }
                Preconditions.checkNotNull(args[i], "Could not parameterize op %s: array %s (variable %s) is null", opName, Integer.valueOf(i), variable.name());
            }
        }

        // One OpContext per op name, created on first use and reused afterwards
        OpContext opContext = this.opContexts.get(opName);
        if (opContext == null) {
            opContext = Nd4j.getExecutioner().buildContext();
            this.opContexts.put(opName, opContext);
        }

        if (op instanceof CustomOp) {
            DynamicCustomOp customOp = (DynamicCustomOp) op;
            if (args != null) {
                opContext.setInputArrays(args);
            }
            if (op instanceof Identity) {
                // For Identity only the inputs are set; arguments and outputs are not configured here
                return new Pair<>(sameDiffOp, opContext);
            }
            if (customOp.numIArguments() > 0) {
                opContext.setIArguments(customOp.iArgs());
            }
            if (customOp.numDArguments() > 0) {
                opContext.setDArguments(customOp.dArgs());
            }
            if (customOp.numTArguments() > 0) {
                opContext.setTArguments(customOp.tArgs());
            }
            if (customOp.numBArguments() > 0) {
                opContext.setBArguments(customOp.bArgs());
            }
            List<LongShapeDescriptor> outputShapes = customOp.calculateOutputShape(opContext);
            Preconditions.checkState(outputShapes != null && outputShapes.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName());
            String[] outputNames = op.outputVariablesNames();
            Preconditions.checkState(outputNames.length == outputShapes.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation with %s outputs (number of shapes and outputs must be equal)", op.opName(), Integer.valueOf(outputShapes.size()), Integer.valueOf(outputNames.length));
            for (int i = 0; i < outputShapes.size(); i++) {
                LongShapeDescriptor shape = outputShapes.get(i);
                // The output variable's declared data type overrides the one from shape calculation
                DataType dataType = this.sameDiff.getVariable(outputNames[i]).dataType();
                if (dataType != shape.dataType()) {
                    shape = shape.asDataType(dataType);
                }
                INDArray output = this.mmgr.allocate(allReqVariables.contains(outputNames[i]), shape);
                if (shape.isEmpty() && !output.isEmpty()) {
                    throw new IllegalStateException("Output shape was empty, but created array was not.");
                }
                opContext.setOutputArray(i, output);
            }
        } else if (op instanceof Op) {
            Op legacyOp = (Op) op;
            // True when input 1 was consumed as the reduction axis, so it must not be passed as an input array
            boolean axisFromArg = false;
            // True when the axis array is empty; the output then gets the input's data type and shape
            boolean emptyReduce = false;
            if ((legacyOp instanceof ReduceOp) && ((ReduceOp) legacyOp).getOpType() != Op.Type.REDUCE3 && op.argNames().length == 2) {
                // Input 1 of a non-REDUCE3 reduction is the (integer) axis array
                SDVariable axisVar = op.arg(1);
                Preconditions.checkState(axisVar.dataType().isIntType(), "Legacy op %s input 1 (axis) was expected to be an integer type, is %s", op.getClass(), axisVar.dataType());
                INDArray axisArr = getArray(axisVar, opInputs, allIterInputs);
                Preconditions.checkState(axisArr != null, "Could not get axis argument for op %s: %s", op.getOwnName(), op.getClass());
                if (axisArr.isEmpty()) {
                    op.setDimensions(null);
                    emptyReduce = true;
                    ((BaseReduceOp) legacyOp).setEmptyReduce(true);
                } else {
                    op.setDimensions(Shape.normalizeAxis(args[0].rank(), axisArr.toIntVector()));
                    ((BaseReduceOp) legacyOp).setEmptyReduce(false);
                }
                axisFromArg = true;
            } else if ((legacyOp instanceof ScalarOp) && op.argNames().length == 2) {
                INDArray scalarArr = getArray(op.arg(1), opInputs, allIterInputs);
                Preconditions.checkState(scalarArr != null, "Could not get scalar argument for op %s: %s", op.getOwnName(), op.getClass());
                Preconditions.checkState(scalarArr.isScalar(), "Scalar argument for op %s (%s) is not a scalar: has shape %ndShape", op.getOwnName(), op.getClass(), scalarArr);
                ((ScalarOp) legacyOp).setScalar(scalarArr);
            }
            if (args != null && args.length > 0) {
                opContext.setInputArray(0, args[0]);
                if (args.length == 2 && !axisFromArg) {
                    opContext.setInputArray(1, args[1]);
                }
            }
            boolean outputRequested = allReqVariables.contains(((BaseOp) legacyOp).outputVariablesNames()[0]);
            if (emptyReduce) {
                opContext.setOutputArray(0, this.mmgr.allocate(false, opContext.getInputArray(0).dataType(), opContext.getInputArray(0).shape()));
            } else {
                List<LongShapeDescriptor> outputShapes = ((BaseOp) legacyOp).calculateOutputShape(opContext);
                Preconditions.checkState(outputShapes != null && outputShapes.size() == 1, "Could not calculate output shape for op: %s", legacyOp.getClass());
                opContext.setOutputArray(0, this.mmgr.allocate(outputRequested, outputShapes.get(0)));
            }
        }
        return new Pair<>(sameDiffOp, opContext);
    }

    /**
     * Resolves the current array for a variable.
     *
     * <p>Resolution order:
     * <ol>
     *   <li>CONSTANT or VARIABLE types are delegated to {@link #getConstantOrVariable(String)};</li>
     *   <li>an array already attached to the SDVariable is returned as-is;</li>
     *   <li>otherwise the variable's VarId is looked up across the two collections, and the stored node output is returned.</li>
     * </ol>
     *
     * @param sDVariable  the variable to resolve
     * @param collection  first collection of candidate VarIds passed to {@code lookup}
     * @param collection2 second collection of candidate VarIds passed to {@code lookup}
     * @return the resolved array; may be null if the VarId is found but no output is stored for it
     * @throws IllegalStateException if no matching VarId can be found
     */
    protected INDArray getArray(SDVariable sDVariable, Collection<AbstractSession.VarId> collection, Collection<AbstractSession.VarId> collection2) {
        final String varName = sDVariable.name();
        final VariableType type = sDVariable.getVariableType();
        if (type == VariableType.CONSTANT || type == VariableType.VARIABLE) {
            return getConstantOrVariable(varName);
        }
        final INDArray attached = sDVariable.getArr();
        if (attached != null) {
            return attached;
        }
        AbstractSession.VarId found = lookup(varName, collection, collection2, false);
        Preconditions.checkState(found != null, "Could not find array for variable %s", sDVariable.name());
        return (INDArray) this.nodeOutputs.get(found);
    }

    /**
     * Returns the memory manager this session uses to allocate and duplicate arrays.
     *
     * @return the current session memory manager
     */
    public SessionMemMgr getMmgr() {
        return mmgr;
    }

    /**
     * Replaces the memory manager this session uses to allocate and duplicate arrays.
     *
     * @param sessionMemMgr the memory manager to use from now on
     */
    public void setMmgr(SessionMemMgr sessionMemMgr) {
        mmgr = sessionMemMgr;
    }

    /**
     * Returns the tracker that records dependencies (such as {@code ExecDoneDep}) for arrays held by this session.
     *
     * @return the current array-use dependency tracker
     */
    public AbstractDependencyTracker<INDArray, Dep> getArrayUseTracker() {
        return arrayUseTracker;
    }

    /**
     * Replaces the tracker that records dependencies for arrays held by this session.
     *
     * @param abstractDependencyTracker the tracker to use from now on
     */
    public void setArrayUseTracker(AbstractDependencyTracker<INDArray, Dep> abstractDependencyTracker) {
        arrayUseTracker = abstractDependencyTracker;
    }

    /**
     * Raw-typed bridge for {@code AbstractSession.getOutputs}: casts each raw collection argument to its
     * generic type and delegates to {@code getOutputs2}.
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    @SuppressWarnings("unchecked")
    public /* bridge */ /* synthetic */ INDArray[] getOutputs(Pair<SameDiffOp, OpContext> pair, FrameIter frameIter, Set set, Set set2, Set set3, List list, At at, MultiDataSet multiDataSet, Set set4) {
        Set<AbstractSession.VarId> firstVarIds = (Set<AbstractSession.VarId>) set;
        Set<AbstractSession.VarId> secondVarIds = (Set<AbstractSession.VarId>) set2;
        Set<String> firstNames = (Set<String>) set3;
        List<Listener> listeners = (List<Listener>) list;
        Set<String> secondNames = (Set<String>) set4;
        return getOutputs2(pair, frameIter, firstVarIds, secondVarIds, firstNames, listeners, at, multiDataSet, secondNames);
    }

    /**
     * Raw-typed bridge overload of getAndParameterizeOp, emitted by the decompiler (marked bridge/synthetic).
     * It casts the raw Set arguments to their generic types and delegates to the generically-typed overload.
     *
     * NOTE(review): after erasure this signature is identical to the generic getAndParameterizeOp overload in
     * this class, which javac rejects as a name clash in hand-written source. It probably should be removed when
     * this file is compiled from source - confirm.
     */
    @Override // org.nd4j.autodiff.samediff.internal.AbstractSession
    public /* bridge */ /* synthetic */ Pair<SameDiffOp, OpContext> getAndParameterizeOp(String str, FrameIter frameIter, Set set, Set set2, Set set3, Map<String, INDArray> map, Set set4) {
        return getAndParameterizeOp(str, frameIter, (Set<AbstractSession.VarId>) set, (Set<AbstractSession.VarId>) set2, (Set<String>) set3, map, (Set<String>) set4);
    }
}
