/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops;

import java.nio.Buffer;
import java.util.Arrays;
import java.util.Map;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseReduceOp;
import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.GridOp;
import org.nd4j.linalg.api.ops.IndexAccumulation;
import org.nd4j.linalg.api.ops.MetaOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.ScalarOp;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Common base class for ND4J's legacy (non-custom) operations.
 *
 * <p>Holds the op's input arrays {@code x} and (optionally) {@code y}, its output array
 * {@code z}, the SameDiff vertex ids for those arrays, a cached buffer built from the op's
 * extra arguments, and an optional array of dimensions.
 */
public abstract class BaseOp
extends DifferentialFunction
implements Op {
    /** First input array. */
    protected INDArray x;
    /** Second input array; {@code null} when the op has a single input. */
    protected INDArray y;
    /** Output array. */
    protected INDArray z;
    protected String xVertexId;
    protected String yVertexId;
    protected String zVertexId;
    /** Buffer cached by {@link #extraArgsDataBuff(DataType)}. */
    protected DataBuffer extraArgz;
    /** Dimensions as an array, set by {@link #defineDimensions(int...)} or in eager execution. */
    protected INDArray dimensionz;

    public BaseOp() {
    }

    public BaseOp(SameDiff sameDiff, boolean inPlace, Object[] extraArgs) {
        super(sameDiff, inPlace, extraArgs);
    }

    public BaseOp(SameDiff sameDiff, Object[] extraArgs) {
        super(sameDiff, extraArgs);
    }

    /** Creates a single-input op writing into {@code z}. */
    public BaseOp(INDArray x, INDArray z) {
        this(x, null, z);
    }

    /** Creates an op with inputs {@code x}, {@code y} (may be null) and output {@code z}. */
    public BaseOp(INDArray x, INDArray y, INDArray z) {
        super(false);
        this.x = x;
        this.y = y;
        this.z = z;
    }

    /** Creates a single-input op whose output array is its input array (in-place). */
    public BaseOp(INDArray x) {
        this(x, null, x);
    }

    /**
     * Classifies {@code op} by the op interfaces it implements.
     *
     * @param op the op to classify
     * @return {@code CUSTOM} for custom ops; for transform ops {@code TRANSFORM_FLOAT} without a
     *     {@code y} input, otherwise {@code PAIRWISE}; for reduce ops the op's own type without a
     *     {@code y} input, otherwise {@code REDUCE3}; the matching type for scalar, broadcast,
     *     index-reduce, meta and grid ops; {@code null} if nothing matches
     */
    public static Op.Type getOpType(Op op) {
        Op.Type type = null;
        if (op instanceof CustomOp) {
            return Op.Type.CUSTOM;
        }
        if (op instanceof TransformOp) {
            type = op.y() == null ? Op.Type.TRANSFORM_FLOAT : Op.Type.PAIRWISE;
        } else if (op instanceof ReduceOp) {
            type = op.y() == null ? ((ReduceOp)op).getOpType() : Op.Type.REDUCE3;
        } else if (op instanceof ScalarOp) {
            type = Op.Type.SCALAR;
        } else if (op instanceof BroadcastOp) {
            type = Op.Type.BROADCAST;
        } else if (op instanceof IndexAccumulation) {
            type = Op.Type.INDEXREDUCE;
        } else if (op instanceof MetaOp) {
            type = Op.Type.META;
        } else if (op instanceof GridOp) {
            type = Op.Type.GRID;
        }
        return type;
    }

    /** No-op by default; subclasses that support TensorFlow import override this. */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    }

    /** No-op by default; subclasses that support ONNX import override this. */
    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
    }

    /**
     * Converts {@code extraArgs} into a constant buffer of the requested type and caches it.
     *
     * <p>Integer and boolean types produce a {@code long[]}-backed buffer; real types produce a
     * {@code double[]}-backed buffer. Entries that are not {@link Number}s are left as 0.
     *
     * @param dtype target data type of the buffer
     * @return the cached buffer, or {@code null} if there are no extra args or {@code dtype} is
     *     neither integer, boolean nor real
     */
    @Override
    public DataBuffer extraArgsDataBuff(DataType dtype) {
        if (this.extraArgz != null) {
            // NOTE(review): the cached buffer is returned even if a different dtype is requested.
            return this.extraArgz;
        }
        if (this.extraArgs != null) {
            if (Shape.isZ(dtype) || Shape.isB(dtype)) {
                long[] extraz = new long[this.extraArgs.length];
                for (int i = 0; i < this.extraArgs.length; ++i) {
                    if (this.extraArgs[i] instanceof Number) {
                        extraz[i] = ((Number)this.extraArgs[i]).longValue();
                    }
                }
                this.extraArgz = Nd4j.getConstantHandler().getConstantBuffer(extraz, dtype);
                return this.extraArgz;
            }
            if (Shape.isR(dtype)) {
                double[] extraz = new double[this.extraArgs.length];
                for (int i = 0; i < this.extraArgs.length; ++i) {
                    if (this.extraArgs[i] instanceof Number) {
                        extraz[i] = ((Number)this.extraArgs[i]).doubleValue();
                    }
                }
                this.extraArgz = Nd4j.getConstantHandler().getConstantBuffer(extraz, dtype);
                return this.extraArgz;
            }
        }
        return null;
    }

    /**
     * Copies {@code extraArgs} into a new NIO buffer: float when {@code x} holds FLOAT data,
     * double otherwise. Every extra arg is cast to {@link Number}, so a non-numeric entry throws
     * {@link ClassCastException}.
     *
     * @return the NIO buffer, or {@code null} if there are no extra args
     */
    @Override
    public Buffer extraArgsBuff() {
        if (this.extraArgs != null) {
            if (this.x.data().dataType() == DataType.FLOAT) {
                DataBuffer retBuff = Nd4j.createBuffer(new float[this.extraArgs.length]);
                for (int i = 0; i < this.extraArgs.length; ++i) {
                    Number val = (Number)this.extraArgs[i];
                    retBuff.put((long)i, val.floatValue());
                }
                return retBuff.asNioFloat();
            }
            DataBuffer retBuff = Nd4j.createBuffer(new double[this.extraArgs.length]);
            for (int i = 0; i < this.extraArgs.length; ++i) {
                Number val = (Number)this.extraArgs[i];
                retBuff.put((long)i, val.doubleValue());
            }
            return retBuff.asNioDouble();
        }
        return null;
    }

    @Override
    public void setX(INDArray x) {
        this.x = x;
    }

    @Override
    public void setZ(INDArray z) {
        this.z = z;
    }

    @Override
    public void setY(INDArray y) {
        this.y = y;
    }

    @Override
    public Object[] extraArgs() {
        return this.extraArgs;
    }

    @Override
    public INDArray x() {
        return this.x;
    }

    @Override
    public INDArray y() {
        return this.y;
    }

    @Override
    public INDArray z() {
        return this.z;
    }

    /**
     * Returns {@code x} for index 0 and {@code y} for index 1.
     *
     * @throws IllegalStateException if {@code index} is not 0 or 1
     */
    @Override
    public INDArray getInputArgument(int index) {
        Preconditions.checkState(index >= 0 && index < 2, "Input argument index must be 0 or 1, got %s", index);
        return index == 0 ? this.x : this.y;
    }

    /**
     * Returns this op's single output variable, creating and registering it with SameDiff on
     * first use.
     *
     * <p>If SameDiff already records outputs for this op, the first one is returned and its name
     * is remembered as {@code zVertexId}. For in-place ops with a known {@code x}, the new output
     * variable is backed by the {@code x} array, which also becomes {@code z}.
     *
     * @param baseName base name for a newly generated output variable (not used for in-place ops)
     * @return a one-element array holding the output variable, or the freshly generated variables
     */
    @Override
    public SDVariable[] outputVariables(String baseName) {
        if (this.zVertexId == null) {
            String[] outputNames = this.sameDiff.getOutputsForOp(this);
            if (outputNames != null) {
                this.zVertexId = this.sameDiff.getVariable(outputNames[0]).name();
                SDVariable[] ret = new SDVariable[]{this.sameDiff.getVariable(outputNames[0])};
                return ret;
            }
            if (this.isInPlace()) {
                SDVariable[] newVars = this.sameDiff.generateOutputVariableForOp(this, null, false);
                INDArray inputArr = this.x();
                if (inputArr == null) {
                    this.computeVariables(newVars);
                    return newVars;
                }
                // In-place: the output variable shares the input array.
                this.sameDiff.setArrayForVariable(newVars[0].name(), inputArr);
                this.z = inputArr;
                if (this.sameDiff.getOutputsForOp(this) == null) {
                    this.sameDiff.addOutgoingFor(newVars, (DifferentialFunction)this);
                }
                this.computeVariables(newVars);
                return newVars;
            }
            SDVariable[] newVars = this.sameDiff.generateOutputVariableForOp(this, baseName, false);
            this.computeVariables(newVars);
            if (this.sameDiff.getOutputsForOp(this) == null) {
                this.sameDiff.addOutgoingFor(newVars, (DifferentialFunction)this);
            }
            return newVars;
        }
        return new SDVariable[]{this.sameDiff.getVariable(this.zVertexId)};
    }

    /**
     * In SameDiff eager mode, executes this op immediately and stores the result on
     * {@code newVars}; does nothing otherwise.
     *
     * <p>Inputs are pulled from {@link #args()}. For REDUCE3, PAIRWISE_BOOL, TRANSFORM_SAME and
     * REDUCE_SAME ops, the second arg becomes {@code y}. For REDUCE_FLOAT, REDUCE_LONG and
     * REDUCE_BOOL ops, it supplies the reduction dimensions. If {@code z} is unset, a zero-filled
     * output of the appropriate (possibly reduced) shape is allocated.
     *
     * @param newVars output variables that receive the result shape and array
     * @throws IllegalArgumentException if no {@code x} array is available
     */
    public void computeVariables(SDVariable[] newVars) {
        if (this.sameDiff.isEagerMode()) {
            SDVariable[] args = this.args();
            if (args.length == 1) {
                this.x = args[0].getArr();
            } else if (args.length > 1) {
                this.x = args[0].getArr();
                Op.Type type = this.opType();
                if (type == Op.Type.REDUCE3 || type == Op.Type.PAIRWISE_BOOL || type == Op.Type.TRANSFORM_SAME || type == Op.Type.REDUCE_SAME) {
                    this.y = args[1].getArr();
                } else if (type == Op.Type.REDUCE_FLOAT || type == Op.Type.REDUCE_LONG || type == Op.Type.REDUCE_BOOL) {
                    this.dimensionz = args[1].getArr();
                    this.dimensions = args[1].getArr().toIntVector();
                }
            }
            if (this.x == null) {
                throw new IllegalArgumentException("No variable found for the given input variables of " + args[0].name() + " At least one input required.");
            }
            if (args.length > 0 && args[0].dataType() != null) {
                this.x = this.x.castTo(args[0].dataType());
            }
            if (args.length > 1 && args[1].dataType() != null && this.y != null) {
                this.y = this.y.castTo(args[1].dataType());
            }
            if (this.z == null) {
                if (this.dimensions == null) {
                    this.setZ(Nd4j.zeros(this.x.shape()).castTo(newVars[0].dataType()));
                } else if (this instanceof BaseReduceOp) {
                    BaseReduceOp baseReduceOp = (BaseReduceOp)this;
                    this.setZ(Nd4j.create(Shape.reductionShape(this.x, this.dimensions, true, baseReduceOp.keepDims)).castTo(newVars[0].dataType()));
                } else {
                    this.setZ(Nd4j.create(Shape.reductionShape(this.x, this.dimensions, true, false)).castTo(newVars[0].dataType()));
                }
            }
            if (this instanceof BaseScalarOp) {
                // Align the scalar operand's data type with x before executing.
                BaseScalarOp baseScalarOp = (BaseScalarOp)this;
                if (baseScalarOp.scalar() != null && baseScalarOp.scalar().dataType() != baseScalarOp.x().dataType()) {
                    baseScalarOp.setScalar(baseScalarOp.scalar().castTo(this.x().dataType()));
                }
            }
            INDArray exec = Nd4j.getExecutioner().exec(this);
            for (int i = 0; i < newVars.length; ++i) {
                newVars[i].setShape(exec.shape());
                this.sameDiff.setEagerArrForVarName(newVars[i].name(), exec);
            }
        }
    }

    @Override
    public String toString() {
        return this.opName();
    }

    /**
     * Builds a {@link DynamicCustomOp} equivalent of this op.
     *
     * <p>Inputs are {@code x} (and {@code y} if present); the output is {@code z}; the op is
     * marked in-place when {@code x == z}. {@code Integer} extra args become integer arguments,
     * {@code Double} and {@code Float} extra args become floating-point arguments, and all other
     * extra args are skipped.
     *
     * @return the built custom op
     */
    @Override
    public CustomOp toCustomOp() {
        DynamicCustomOp.DynamicCustomOpsBuilder customOpBuilder = DynamicCustomOp.builder(this.opName());
        customOpBuilder.callInplace(this.x() == this.z());
        if (this.y() != null) {
            customOpBuilder.addInputs(this.x(), this.y());
        } else {
            customOpBuilder.addInputs(this.x());
        }
        customOpBuilder.addOutputs(this.z());
        if (this.extraArgs != null) {
            for (int i = 0; i < this.extraArgs.length; ++i) {
                Object arg = this.extraArgs[i];
                if (arg instanceof Integer) {
                    customOpBuilder.addIntegerArguments((long)((Integer)arg).intValue());
                    continue;
                }
                if (!(arg instanceof Double) && !(arg instanceof Float)) continue;
                // Convert via Number: a direct (Double) cast would throw for Float arguments.
                Double num = ((Number)arg).doubleValue();
                customOpBuilder.addFloatingPointArguments(num);
            }
        }
        return customOpBuilder.build();
    }

    /** Two ops are equal when they have the same class, arrays, extra args and extra-arg buffer. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        BaseOp baseOp = (BaseOp)o;
        if (this.x != null ? !this.x.equals(baseOp.x) : baseOp.x != null) {
            return false;
        }
        if (this.y != null ? !this.y.equals(baseOp.y) : baseOp.y != null) {
            return false;
        }
        if (this.z != null ? !this.z.equals(baseOp.z) : baseOp.z != null) {
            return false;
        }
        if (!Arrays.equals(this.extraArgs, baseOp.extraArgs)) {
            return false;
        }
        return this.extraArgz != null ? this.extraArgz.equals(baseOp.extraArgz) : baseOp.extraArgz == null;
    }

    @Override
    public int hashCode() {
        // NOTE(review): mixes in super.hashCode() although equals() does not consult super.equals();
        // confirm DifferentialFunction.hashCode() is consistent with this equals().
        int result = super.hashCode();
        result = 31 * result + (this.x != null ? this.x.hashCode() : 0);
        result = 31 * result + (this.y != null ? this.y.hashCode() : 0);
        result = 31 * result + (this.z != null ? this.z.hashCode() : 0);
        result = 31 * result + Arrays.hashCode(this.extraArgs);
        result = 31 * result + (this.extraArgz != null ? this.extraArgz.hashCode() : 0);
        return result;
    }

    /**
     * Stores {@code dimensions} as {@link #dimensionz}.
     *
     * <p>Dimensions are normalized against the rank of {@code x} when {@code x} is known. A null
     * or empty array is replaced by {@code {Integer.MAX_VALUE}}. The array is created outside any
     * active workspace.
     */
    protected void defineDimensions(int ... dimensions) {
        if (dimensions != null && dimensions.length > 0 && this.x != null) {
            dimensions = Shape.normalizeAxis(this.x.rank(), dimensions);
        }
        if (dimensions == null || dimensions.length == 0) {
            dimensions = new int[]{Integer.MAX_VALUE};
        }
        try (MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces();){
            this.dimensionz = Shape.ndArrayDimFromInt(dimensions);
        }
    }

    public INDArray dimensions() {
        return this.dimensionz;
    }

    /**
     * Returns the scalar held in {@code z} after execution.
     *
     * @return a {@link Double} for real types, a {@link Long} for integer types, an
     *     {@link Integer} for boolean types
     * @throws ND4JIllegalStateException if {@code z} is null, empty, not a scalar, or of an
     *     unsupported type
     */
    public Number getFinalResult() {
        if (this.z == null) {
            throw new ND4JIllegalStateException("Op.Z is null. Op wasn't executed yet?");
        }
        if (this.z.isEmpty()) {
            throw new ND4JIllegalStateException("Can't get number from empty array");
        }
        if (!this.z.isScalar()) {
            throw new ND4JIllegalStateException("Can't get final result scalar out of N-dim tensor");
        }
        if (this.z.isR()) {
            return Double.valueOf(this.z.getDouble(0L));
        }
        if (this.z.isZ()) {
            // Read as long: getInt would truncate values outside the int range.
            return Long.valueOf(this.z.getLong(0L));
        }
        if (this.z.isB()) {
            return Integer.valueOf(this.z.getInt(0));
        }
        throw new ND4JIllegalStateException("Can't get final result for data type " + this.z.dataType());
    }

    @Override
    public int getNumOutputs() {
        return 1;
    }

    /** Drops references to {@code x}, {@code y} and {@code z}. */
    @Override
    public void clearArrays() {
        this.x = null;
        this.y = null;
        this.z = null;
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + this.opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + this.opName());
    }

    public INDArray getX() {
        return this.x;
    }

    public INDArray getY() {
        return this.y;
    }

    public INDArray getZ() {
        return this.z;
    }

    public DataBuffer getExtraArgz() {
        return this.extraArgz;
    }

    public INDArray getDimensionz() {
        return this.dimensionz;
    }

    public void setExtraArgz(DataBuffer extraArgz) {
        this.extraArgz = extraArgz;
    }

    public void setDimensionz(INDArray dimensionz) {
        this.dimensionz = dimensionz;
    }

    public String getXVertexId() {
        return this.xVertexId;
    }

    public String getYVertexId() {
        return this.yVertexId;
    }

    public String getZVertexId() {
        return this.zVertexId;
    }

    public void setXVertexId(String xVertexId) {
        this.xVertexId = xVertexId;
    }

    public void setYVertexId(String yVertexId) {
        this.yVertexId = yVertexId;
    }

    public void setZVertexId(String zVertexId) {
        this.zVertexId = zVertexId;
    }
}

