/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.shape.tensorops;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.BaseTensorOp;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayRead;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayScatter;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayWrite;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * SameDiff representation of a TensorFlow {@code TensorArrayV3}, executed as the custom op
 * {@code create_list}.
 *
 * <p>The helper methods ({@link #read}, {@link #write}, {@link #gather}, {@link #scatter},
 * {@link #stack}, {@link #unstack}, {@link #concat}) each create a companion tensor-array op whose
 * first input is this op's output variable, and return that companion op's output variable.
 */
public class TensorArray extends BaseTensorOp {
    /** Element data type of the array: set by a constructor or by the TF {@code dtype} attribute. */
    protected DataType tensorArrayDataType;

    @Override
    public String tensorflowName() {
        return "TensorArrayV3";
    }

    /**
     * Creates a named tensor array with no inputs.
     *
     * @param name     op name
     * @param sameDiff owning SameDiff instance
     * @param dataType element data type
     */
    public TensorArray(String name, SameDiff sameDiff, DataType dataType) {
        super(name, sameDiff, new SDVariable[0]);
        this.tensorArrayDataType = dataType;
    }

    /**
     * Creates a tensor array with no inputs.
     *
     * @param sameDiff owning SameDiff instance
     * @param dataType element data type
     */
    public TensorArray(SameDiff sameDiff, DataType dataType) {
        super(sameDiff, new SDVariable[0]);
        this.tensorArrayDataType = dataType;
    }

    /**
     * Creates a new op, with no inputs, that shares {@code ta}'s SameDiff instance and element type.
     * Other state of {@code ta} (name, arguments) is not copied.
     *
     * @param ta tensor array to take the SameDiff instance and data type from
     */
    public TensorArray(TensorArray ta) {
        super(ta.sameDiff, new SDVariable[0]);
        this.tensorArrayDataType = ta.tensorArrayDataType;
    }

    /**
     * Creates a new op with the given inputs that shares {@code ta}'s SameDiff instance and element type.
     *
     * @param ta     tensor array to take the SameDiff instance and data type from
     * @param inputs inputs of the new op
     */
    public TensorArray(TensorArray ta, SDVariable[] inputs) {
        super(ta.sameDiff, inputs);
        this.tensorArrayDataType = ta.tensorArrayDataType;
    }

    /**
     * Initializes this op from a TensorFlow node.
     *
     * <p>The node's last input is looked up by name in {@code graph}. If it resolves to a node from
     * which {@link TFGraphMapper#getNDArrayFromTensor} yields a non-empty array, the array's first
     * element is added as an integer argument (presumably the static array size; confirm against
     * the {@code create_list} op). The element data type is read from the {@code dtype} attribute.
     *
     * @param nodeDef           the TensorFlow node being imported
     * @param initWith          the SameDiff instance being populated (not used here)
     * @param attributesForNode attributes of {@code nodeDef}; expected to contain {@code dtype}
     * @param graph             the TensorFlow graph containing {@code nodeDef}
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        if (nodeDef.getInputCount() > 0) {
            String lastInputName = nodeDef.getInput(nodeDef.getInputCount() - 1);
            NodeDef lastInputNode = findNode(graph, lastInputName);
            // A null array already means "no integer argument"; treat an input name that matches
            // no graph node the same way instead of handing null to getNDArrayFromTensor.
            INDArray arr = lastInputNode == null ? null : TFGraphMapper.getNDArrayFromTensor(lastInputNode);
            if (arr != null && arr.length() > 0) {
                this.addIArgument(arr.getInt(0));
            }
        }
        this.tensorArrayDataType = TFGraphMapper.convertType(attributesForNode.get("dtype").getType());
    }

    /**
     * Finds a node by name.
     *
     * @param graph graph to search
     * @param name  node name to match exactly
     * @return the first node in {@code graph} whose name equals {@code name}, or {@code null} if none
     */
    private static NodeDef findNode(GraphDef graph, String name) {
        for (int i = 0; i < graph.getNodeCount(); ++i) {
            NodeDef candidate = graph.getNode(i);
            if (candidate.getName().equals(name)) {
                return candidate;
            }
        }
        return null;
    }

    /** Creates a tensor array whose element type defaults to {@link DataType#FLOAT}. */
    public TensorArray() {
        this(DataType.FLOAT);
    }

    /**
     * Creates a tensor array with the given element type. The SameDiff instance is not set here.
     *
     * @param dataType element data type
     */
    public TensorArray(DataType dataType) {
        this.tensorArrayDataType = dataType;
    }

    /** @return the op name ({@code "create_list"}) */
    @Override
    public String toString() {
        return this.opName();
    }

    @Override
    public String opName() {
        return "create_list";
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CUSTOM;
    }

    /** @return this op's output variable, used as the first input of every companion op */
    private SDVariable getVar() {
        return this.outputVariable();
    }

    /**
     * Returns the SameDiff instance that companion ops are created in: the child of
     * {@link #sameDiff} when one exists, otherwise {@link #sameDiff} itself.
     */
    @Override
    public SameDiff getSameDiff() {
        SameDiff sd = this.sameDiff;
        if (sd.getChild() != null) {
            return sd.getChild();
        }
        return sd;
    }

    /**
     * Wraps integer values as a constant variable.
     *
     * <p>NOTE(review): the constant is created in {@code this.sameDiff}, while the ops consuming it
     * are created in {@link #getSameDiff()} (possibly a child instance). Confirm this is intended.
     *
     * @param index values to wrap
     * @return a constant variable holding {@code index}
     */
    private SDVariable intToVar(int ... index) {
        return this.sameDiff.constant(Nd4j.createFromArray(index));
    }

    /**
     * Reads the element at a fixed index.
     *
     * @param index element index
     * @return output of the created {@link TensorArrayRead} op
     */
    public SDVariable read(int index) {
        return new TensorArrayRead(this.getSameDiff(), new SDVariable[]{this.getVar(), this.intToVar(index)}).outputVariable();
    }

    /**
     * Reads the element at an index given by a variable.
     *
     * @param index variable holding the element index
     * @return output of the created {@link TensorArrayRead} op
     */
    public SDVariable read(SDVariable index) {
        return new TensorArrayRead(this.getSameDiff(), new SDVariable[]{this.getVar(), index}).outputVariable();
    }

    /**
     * Gathers the elements at fixed indices.
     *
     * @param flow    flow variable passed as the op's third input
     * @param indices element indices
     * @return output of the created {@link TensorArrayGather} op
     */
    public SDVariable gather(SDVariable flow, int ... indices) {
        return new TensorArrayGather(this.getSameDiff(), new SDVariable[]{this.getVar(), this.intToVar(indices), flow}).outputVariable();
    }

    /**
     * Gathers the elements at indices given by a variable.
     *
     * @param flow    flow variable passed as the op's third input
     * @param indices variable holding the element indices
     * @return output of the created {@link TensorArrayGather} op
     */
    public SDVariable gather(SDVariable flow, SDVariable indices) {
        return new TensorArrayGather(this.getSameDiff(), new SDVariable[]{this.getVar(), indices, flow}).outputVariable();
    }

    /**
     * Gathers with the index {@code -1} (presumably "all elements"; confirm against the gather op).
     *
     * @param flow flow variable passed as the op's third input
     * @return output of the created {@link TensorArrayGather} op
     */
    public SDVariable stack(SDVariable flow) {
        return new TensorArrayGather(this.getSameDiff(), new SDVariable[]{this.getVar(), this.intToVar(-1), flow}).outputVariable();
    }

    /**
     * Concatenates the array's elements.
     *
     * <p>NOTE(review): unlike the other helpers, {@code flow} is not passed to the created op;
     * confirm whether {@link TensorArrayConcat} should receive it.
     *
     * @param flow currently unused
     * @return output of the created {@link TensorArrayConcat} op
     */
    public SDVariable concat(SDVariable flow) {
        return new TensorArrayConcat(this.getSameDiff(), new SDVariable[]{this.getVar()}).outputVariable();
    }

    /**
     * Writes a value at a fixed index.
     *
     * @param flow  flow variable passed as the op's fourth input
     * @param index element index
     * @param value value to write
     * @return output of the created {@link TensorArrayWrite} op
     */
    public SDVariable write(SDVariable flow, int index, SDVariable value) {
        return this.write(flow, this.intToVar(index), value);
    }

    /**
     * Writes a value at an index given by a variable.
     *
     * @param flow  flow variable passed as the op's fourth input
     * @param index variable holding the element index
     * @param value value to write
     * @return output of the created {@link TensorArrayWrite} op
     */
    public SDVariable write(SDVariable flow, SDVariable index, SDVariable value) {
        return new TensorArrayWrite(this.getSameDiff(), new SDVariable[]{this.getVar(), index, value, flow}).outputVariable();
    }

    /**
     * Scatters {@code value} to fixed indices.
     *
     * @param flow    flow variable passed as the op's fourth input
     * @param value   value to scatter
     * @param indices element indices
     * @return output of the created {@link TensorArrayScatter} op
     */
    public SDVariable scatter(SDVariable flow, SDVariable value, int ... indices) {
        return new TensorArrayScatter(this.getSameDiff(), new SDVariable[]{this.getVar(), this.intToVar(indices), value, flow}).outputVariable();
    }

    /**
     * Scatters {@code value} to indices given by a variable.
     *
     * @param flow    flow variable passed as the op's fourth input
     * @param value   value to scatter
     * @param indices variable holding the element indices
     * @return output of the created {@link TensorArrayScatter} op
     */
    public SDVariable scatter(SDVariable flow, SDVariable value, SDVariable indices) {
        return new TensorArrayScatter(this.getSameDiff(), new SDVariable[]{this.getVar(), indices, value, flow}).outputVariable();
    }

    /**
     * Scatters {@code value} with the index {@code -1} (presumably "split along the first
     * dimension"; confirm against the scatter op).
     *
     * @param flow  flow variable passed as the op's fourth input
     * @param value value to unstack into the array
     * @return output of the created {@link TensorArrayScatter} op
     */
    public SDVariable unstack(SDVariable flow, SDVariable value) {
        return new TensorArrayScatter(this.getSameDiff(), new SDVariable[]{this.getVar(), this.intToVar(-1), value, flow}).outputVariable();
    }

    /**
     * Returns the fixed output types {@code [BOOL, FLOAT]}, independent of the inputs and of
     * {@link #tensorArrayDataType}. NOTE(review): the meaning of each output is not established here.
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataType) {
        return Arrays.asList(DataType.BOOL, DataType.FLOAT);
    }

    @Override
    public int getNumOutputs() {
        return 2;
    }

    /** @return the element data type of this tensor array */
    public DataType getTensorArrayDataType() {
        return this.tensorArrayDataType;
    }
}

