package org.nd4j.linalg.api.ops.impl.transforms.dtype;

import java.lang.reflect.Field;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.properties.adapters.DataTypeAdapter;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/* loaded from: input_file:org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.class */
public class Cast extends BaseDynamicTransformOp {
    private DataType typeDst;

    public Cast() {
    }

    public Cast(SameDiff sameDiff, SDVariable sDVariable, @NonNull DataType dataType) {
        super(sameDiff, new SDVariable[]{sDVariable}, false);
        if (dataType == null) {
            throw new NullPointerException("dst is marked non-null but is null");
        }
        this.typeDst = dataType;
        addArgs();
    }

    public Cast(@NonNull INDArray iNDArray, @NonNull DataType dataType) {
        super(new INDArray[]{iNDArray}, null);
        if (iNDArray == null) {
            throw new NullPointerException("arg is marked non-null but is null");
        }
        if (dataType == null) {
            throw new NullPointerException("dataType is marked non-null but is null");
        }
        this.typeDst = dataType;
        addArgs();
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff sameDiff, Map<String, AttrValue> map, GraphDef graphDef) {
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, map, nodeDef, graphDef);
        addArgs();
    }

    protected void addArgs() {
        addIArgument(FlatBuffersMapper.getDataTypeAsByte(this.typeDst));
    }

    @Override // org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
        LinkedHashMap linkedHashMap = new LinkedHashMap();
        LinkedHashMap linkedHashMap2 = new LinkedHashMap();
        DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this);
        linkedHashMap2.put("typeDst", new DataTypeAdapter());
        linkedHashMap.put(tensorflowName(), linkedHashMap2);
        return linkedHashMap;
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        HashMap hashMap = new HashMap();
        HashMap hashMap2 = new HashMap();
        for (PropertyMapping propertyMapping : new PropertyMapping[]{PropertyMapping.builder().tfAttrName("DstT").propertyNames(new String[]{"typeDst"}).build()}) {
            for (String str : propertyMapping.getPropertyNames()) {
                hashMap2.put(str, propertyMapping);
            }
        }
        hashMap.put(tensorflowName(), hashMap2);
        return hashMap;
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public void setValueFor(Field field, Object obj) {
        if (obj == null || (obj instanceof String) || (obj instanceof DataType)) {
            super.setValueFor(field, obj);
        }
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction, org.nd4j.linalg.api.ops.CustomOp
    public String opName() {
        return "cast";
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public String tensorflowName() {
        return "Cast";
    }

    @Override // org.nd4j.linalg.api.ops.DynamicCustomOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<SDVariable> doDiff(List<SDVariable> list) {
        return arg().dataType().isFPType() ? Collections.singletonList(list.get(0).castTo(arg().dataType())) : Collections.singletonList(this.sameDiff.zerosLike(arg()));
    }

    @Override // org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp, org.nd4j.autodiff.functions.DifferentialFunction
    public List<DataType> calculateOutputDataTypes(List<DataType> list) {
        Preconditions.checkState(list != null && list.size() == 1, "Expected exactly 1 input datatype for %s, got input %s", getClass(), list);
        return Collections.singletonList(this.typeDst);
    }
}
