package org.nd4j.linalg.api.ops.impl.shape.tensorops;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/shape/tensorops/TensorArray.class */
/**
 * SameDiff op that creates a tensor array: libnd4j op {@code create_list}, imported from
 * TensorFlow's {@code TensorArrayV3}.
 *
 * <p>An instance also serves as a handle for building follow-up tensor-array ops. Each of
 * {@code read}, {@code write}, {@code gather}, {@code scatter}, {@code stack},
 * {@code unstack} and {@code concat} creates a new op whose first input is this op's
 * output variable, and returns that new op's output variable.
 */
public class TensorArray extends BaseTensorOp {

    /**
     * Element data type of the array. The no-arg constructor sets it to FLOAT. TensorFlow
     * import overwrites it from the node's {@code dtype} attribute.
     */
    protected DataType tensorArrayDataType;

    @Override
    public String tensorflowName() {
        return "TensorArrayV3";
    }

    /**
     * Creates a named tensor-array op with no inputs.
     *
     * @param name     op name, forwarded to the superclass
     * @param sameDiff graph the op belongs to
     * @param dataType element data type of the array
     */
    public TensorArray(String name, SameDiff sameDiff, DataType dataType) {
        super(name, sameDiff, new SDVariable[0]);
        this.tensorArrayDataType = dataType;
    }

    /**
     * Creates an unnamed tensor-array op with no inputs.
     *
     * @param sameDiff graph the op belongs to
     * @param dataType element data type of the array
     */
    public TensorArray(SameDiff sameDiff, DataType dataType) {
        super(sameDiff, new SDVariable[0]);
        this.tensorArrayDataType = dataType;
    }

    /**
     * Creates a new op with no inputs that shares {@code other}'s graph and element data type.
     *
     * @param other op to copy the graph and data type from
     */
    public TensorArray(TensorArray other) {
        super(other.sameDiff, new SDVariable[0]);
        this.tensorArrayDataType = other.tensorArrayDataType;
    }

    /**
     * Creates a new op with the given inputs that shares {@code other}'s graph and element
     * data type.
     *
     * @param other  op to copy the graph and data type from
     * @param inputs input variables of the new op
     */
    public TensorArray(TensorArray other, SDVariable[] inputs) {
        super(other.sameDiff, inputs);
        this.tensorArrayDataType = other.tensorArrayDataType;
    }

    /**
     * Configures this op from a TensorFlow {@code TensorArrayV3} node.
     *
     * <p>The node's last input is resolved by name within {@code graph}. If that node yields
     * a constant array, its first element is added as an integer argument. If the node cannot
     * be found, or yields no array, no integer argument is added. The element data type is
     * then read from the required {@code dtype} attribute.
     *
     * @param nodeDef           the TensorFlow node being imported
     * @param initWith          graph being built (not used here)
     * @param attributesForNode attributes of {@code nodeDef}
     * @param graph             full TensorFlow graph, used to resolve the input node
     * @throws IllegalStateException if the node has no {@code dtype} attribute
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith,
            Map<String, AttrValue> attributesForNode, GraphDef graph) {
        String sizeInputName = nodeDef.getInput(nodeDef.getInputCount() - 1);
        NodeDef sizeNode = findLastNodeNamed(graph, sizeInputName);
        // Passing a null node to the mapper would fail; treat "not found" like "no array".
        if (sizeNode != null) {
            INDArray sizeArr = TFGraphMapper.getNDArrayFromTensor(sizeNode);
            if (sizeArr != null) {
                addIArgument(sizeArr.getInt(0));
            }
        }

        AttrValue dtypeAttr = attributesForNode.get("dtype");
        if (dtypeAttr == null) {
            throw new IllegalStateException("TensorArrayV3 node \"" + nodeDef.getName()
                    + "\" is missing the required \"dtype\" attribute");
        }
        this.tensorArrayDataType = TFGraphMapper.convertType(dtypeAttr.getType());
    }

    /**
     * Returns the last node in {@code graph} whose name equals {@code name}, or null if none.
     * The scan runs from the end of the graph, so the first hit is the last match.
     */
    private static NodeDef findLastNodeNamed(GraphDef graph, String name) {
        for (int i = graph.getNodeCount() - 1; i >= 0; i--) {
            NodeDef candidate = graph.getNode(i);
            if (candidate.getName().equals(name)) {
                return candidate;
            }
        }
        return null;
    }

    /** Creates a detached op with FLOAT elements. */
    public TensorArray() {
        this(DataType.FLOAT);
    }

    /**
     * Creates a detached op (not attached to any graph) with the given element data type.
     *
     * @param dataType element data type of the array
     */
    public TensorArray(DataType dataType) {
        this.tensorArrayDataType = dataType;
    }

    @Override
    public String toString() {
        return opName();
    }

    @Override
    public String opName() {
        return "create_list";
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CUSTOM;
    }

    /** Returns this op's output variable; it is the first input of every op built below. */
    private SDVariable getVar() {
        return outputVariable();
    }

    /**
     * Returns the graph that follow-up ops are created in. This is the child graph of this
     * op's {@code sameDiff} if it has one, otherwise {@code sameDiff} itself.
     *
     * @return the graph, or null if this op is not attached to one
     */
    @Override
    public SameDiff getSameDiff() {
        SameDiff own = this.sameDiff;
        if (own == null) {
            // Detached op (no-arg / DataType constructors): report "no graph" instead of NPE.
            return null;
        }
        SameDiff child = own.getChild();
        return child != null ? child : own;
    }

    /**
     * Wraps the given ints in an int-array constant. The constant is created in
     * {@code this.sameDiff}, not in {@link #getSameDiff()}.
     */
    private SDVariable intToVar(int... values) {
        return this.sameDiff.constant(Nd4j.createFromArray(values));
    }

    /**
     * Builds a {@code TensorArrayRead} op at a constant index.
     *
     * @param index element index
     * @return output variable of the read op
     */
    public SDVariable read(int index) {
        return new TensorArrayRead(getSameDiff(), new SDVariable[]{getVar(), intToVar(index)}).outputVariable();
    }

    /**
     * Builds a {@code TensorArrayRead} op at an index given by a variable.
     *
     * @param index variable holding the element index
     * @return output variable of the read op
     */
    public SDVariable read(SDVariable index) {
        return new TensorArrayRead(getSameDiff(), new SDVariable[]{getVar(), index}).outputVariable();
    }

    /**
     * Builds a {@code TensorArrayGather} op over constant indices.
     *
     * @param flow    passed as the last input of the gather op
     * @param indices element indices
     * @return output variable of the gather op
     */
    public SDVariable gather(SDVariable flow, int... indices) {
        return new TensorArrayGather(getSameDiff(), new SDVariable[]{getVar(), intToVar(indices), flow}).outputVariable();
    }

    /**
     * Builds a {@code TensorArrayGather} op over indices given by a variable.
     *
     * @param flow    passed as the last input of the gather op
     * @param indices variable holding the element indices
     * @return output variable of the gather op
     */
    public SDVariable gather(SDVariable flow, SDVariable indices) {
        return new TensorArrayGather(getSameDiff(), new SDVariable[]{getVar(), indices, flow}).outputVariable();
    }

    /**
     * Builds a {@code TensorArrayGather} op with index {@code -1}. This presumably means
     * "all elements" (TODO confirm against the libnd4j op).
     *
     * @param flow passed as the last input of the gather op
     * @return output variable of the gather op
     */
    public SDVariable stack(SDVariable flow) {
        return new TensorArrayGather(getSameDiff(), new SDVariable[]{getVar(), intToVar(-1), flow}).outputVariable();
    }

    /**
     * Builds a {@code TensorArrayConcat} op over this array.
     *
     * @param flow currently ignored; only this op's output variable is passed to the new op
     * @return output variable of the concat op
     */
    public SDVariable concat(SDVariable flow) {
        return new TensorArrayConcat(getSameDiff(), new SDVariable[]{getVar()}).outputVariable();
    }

    /**
     * Builds a {@code TensorArrayWrite} op at a constant index.
     *
     * @param flow  passed as the last input of the write op
     * @param index element index
     * @param value value to write
     * @return output variable of the write op
     */
    public SDVariable write(SDVariable flow, int index, SDVariable value) {
        return write(flow, intToVar(index), value);
    }

    /**
     * Builds a {@code TensorArrayWrite} op at an index given by a variable.
     *
     * @param flow  passed as the last input of the write op
     * @param index variable holding the element index
     * @param value value to write
     * @return output variable of the write op
     */
    public SDVariable write(SDVariable flow, SDVariable index, SDVariable value) {
        return new TensorArrayWrite(getSameDiff(), new SDVariable[]{getVar(), index, value, flow}).outputVariable();
    }

    /**
     * Builds a {@code TensorArrayScatter} op at constant indices.
     *
     * @param flow    passed as the last input of the scatter op
     * @param value   value to scatter
     * @param indices element indices
     * @return output variable of the scatter op
     */
    public SDVariable scatter(SDVariable flow, SDVariable value, int... indices) {
        return new TensorArrayScatter(getSameDiff(), new SDVariable[]{getVar(), intToVar(indices), value, flow}).outputVariable();
    }

    /**
     * Builds a {@code TensorArrayScatter} op at indices given by a variable.
     *
     * @param flow    passed as the last input of the scatter op
     * @param value   value to scatter
     * @param indices variable holding the element indices
     * @return output variable of the scatter op
     */
    public SDVariable scatter(SDVariable flow, SDVariable value, SDVariable indices) {
        return new TensorArrayScatter(getSameDiff(), new SDVariable[]{getVar(), indices, value, flow}).outputVariable();
    }

    /**
     * Builds a {@code TensorArrayScatter} op with index {@code -1}. This presumably means
     * "all elements" (TODO confirm against the libnd4j op).
     *
     * @param flow  passed as the last input of the scatter op
     * @param value value to unstack into the array
     * @return output variable of the scatter op
     */
    public SDVariable unstack(SDVariable flow, SDVariable value) {
        return new TensorArrayScatter(getSameDiff(), new SDVariable[]{getVar(), intToVar(-1), value, flow}).outputVariable();
    }

    /**
     * Returns the output data types, which are always {@code [BOOL, FLOAT]} regardless of
     * the inputs or {@link #tensorArrayDataType}. The list has one entry per output
     * reported by {@link #getNumOutputs()}.
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        return Arrays.asList(DataType.BOOL, DataType.FLOAT);
    }

    @Override
    public int getNumOutputs() {
        return 2;
    }

    /** @return the element data type of the array */
    public DataType getTensorArrayDataType() {
        return this.tensorArrayDataType;
    }
}
