/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.layers.recurrent;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayerBp;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMActivations;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDataFormat;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMDirectionMode;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.shade.guava.primitives.Booleans;

/**
 * ND4J/SameDiff wrapper around the native {@code lstmLayer} custom op.
 *
 * <p>The op is parameterized by an {@link LSTMLayerConfig} and an {@link LSTMLayerWeights} instance:
 * <ul>
 *   <li>The config is encoded into the integer, floating-point and boolean op arguments.</li>
 *   <li>The weights instance supplies the weight inputs and arranges them around the data inputs
 *       via {@code argsWithInputs}.</li>
 * </ul>
 * Optional inputs are the sequence-length array ({@code maxTSLength}) and the initial states
 * ({@code cLast}, {@code yLast}).
 *
 * <p>Argument layout, as written by {@link #iArgs()}, {@link #tArgs()} and {@link #bArgs}:
 * <ul>
 *   <li>iArgs: 0 data format, 1 direction mode, 2 gate activation, 3 output activation,
 *       4 cell activation (each encoded as the enum ordinal)</li>
 *   <li>tArgs: 0 cell clip</li>
 *   <li>bArgs: 0 has bias, 1 has sequence length, 2 has yLast, 3 has cLast, 4 has peephole,
 *       5 return full sequence, 6 return last H, 7 return last C</li>
 * </ul>
 */
public class LSTMLayer extends DynamicCustomOp {
    // Positions of the flags in the boolean argument list written by bArgs(...).
    private static final int B_HAS_BIAS = 0;
    private static final int B_HAS_PEEPHOLE = 4;
    private static final int B_RET_FULL_SEQUENCE = 5;
    private static final int B_RET_LAST_H = 6;
    private static final int B_RET_LAST_C = 7;

    private LSTMLayerConfig configuration;
    private LSTMLayerWeights weights;
    private SDVariable cLast;
    private SDVariable yLast;
    /** Variable names restored from properties; resolved to variables in configureWithSameDiff. */
    private String cLastName;
    private String yLastName;
    private SDVariable maxTSLength;

    /**
     * No-arg constructor. State is presumably filled in afterwards through
     * {@link #setPropertiesForFunction}, {@link #configureFromArguments()} and
     * {@link #configureWithSameDiff} (e.g. on deserialization) — confirm with callers.
     */
    public LSTMLayer() {
    }

    /**
     * Creates the op inside a SameDiff graph.
     *
     * @param sameDiff      graph the op belongs to; must not be null
     * @param x             input sequence
     * @param cLast         optional initial cell state (may be null)
     * @param yLast         optional initial output state (may be null)
     * @param maxTSLength   optional sequence lengths (may be null)
     * @param weights       layer weights; also determines how inputs are ordered for the op
     * @param configuration layer configuration; must request at least one output
     * @throws NullPointerException  if {@code sameDiff} is null
     * @throws IllegalStateException if {@code configuration} requests no outputs
     */
    public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) {
        super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast));
        if (sameDiff == null) {
            throw new NullPointerException("sameDiff is marked non-null but is null");
        }
        this.configuration = configuration;
        this.weights = weights;
        this.cLast = cLast;
        this.yLast = yLast;
        this.maxTSLength = maxTSLength;
        checkAtLeastOneOutput(this.configuration);
        this.addIArgument(this.iArgs());
        this.addTArgument(this.tArgs());
        this.addBArgument(this.bArgs(weights, maxTSLength, yLast, cLast));
    }

    /**
     * Creates the op for direct (eager) execution on arrays.
     *
     * @param x             input sequence
     * @param cLast         optional initial cell state (may be null)
     * @param yLast         optional initial output state (may be null)
     * @param maxTSLength   optional sequence lengths (may be null)
     * @param lstmWeights   layer weights; also determines how inputs are ordered for the op
     * @param configuration layer configuration; must request at least one output
     * @throws IllegalStateException if {@code configuration} requests no outputs
     */
    public LSTMLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMLayerWeights lstmWeights, LSTMLayerConfig configuration) {
        // Same input order as the SameDiff constructor. x must be the first data input; passing
        // maxTSLength first would place the sequence-length array (or nothing) in x's slot.
        // NOTE(review): both constructors pass cLast before yLast. Confirm this matches the order
        // LSTMLayerWeights.argsWithInputs expects.
        super(null, null, lstmWeights.argsWithInputs(x, maxTSLength, cLast, yLast));
        this.configuration = configuration;
        this.weights = lstmWeights;
        checkAtLeastOneOutput(this.configuration);
        this.addIArgument(this.iArgs());
        this.addTArgument(this.tArgs());
        this.addBArgument(this.bArgs(this.weights, maxTSLength, yLast, cLast));
    }

    /** Fails fast when the configuration would make the op produce no outputs at all. */
    private static void checkAtLeastOneOutput(LSTMLayerConfig configuration) {
        Preconditions.checkState(configuration.isRetLastH() || configuration.isRetLastC() || configuration.isRetFullSequence(),
                "You have to specify at least one output you want to return. Use isRetLastC, isRetLastH and isRetFullSequence methods in LSTMLayerConfig builder to specify them");
    }

    /**
     * Returns one data type per enabled output, all equal to the type of input 1.
     *
     * @throws IllegalStateException if there are not 3..8 inputs or input 1 is not floating point
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        Preconditions.checkState(inputDataTypes != null && 3 <= inputDataTypes.size() && inputDataTypes.size() <= 8, "Expected amount of inputs to LSTMLayer between 3 inputs minimum (input, Wx, Wr only) or 8 maximum, got %s", inputDataTypes);
        DataType dt = inputDataTypes.get(1);
        Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", (Object) dt);
        // Outputs in enabled order: full sequence, last H, last C. All share the same type.
        List<DataType> list = new ArrayList<DataType>(3);
        if (this.configuration.isRetFullSequence()) {
            list.add(dt);
        }
        if (this.configuration.isRetLastH()) {
            list.add(dt);
        }
        if (this.configuration.isRetLastC()) {
            list.add(dt);
        }
        return list;
    }

    /**
     * Builds the backprop op.
     *
     * <p>Incoming gradients line up with the enabled outputs in this order: full sequence, last H,
     * last C. Gradients for disabled outputs are passed to the backprop op as null.
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> grads) {
        int next = 0;
        SDVariable gradFullSeq = this.configuration.isRetFullSequence() ? grads.get(next++) : null;
        SDVariable gradLastH = this.configuration.isRetLastH() ? grads.get(next++) : null;
        SDVariable gradLastC = this.configuration.isRetLastC() ? grads.get(next++) : null;
        return Arrays.asList(new LSTMLayerBp(this.sameDiff, this.arg(0), this.cLast, this.yLast, this.maxTSLength, this.weights, this.configuration, gradFullSeq, gradLastH, gradLastC).outputVariables());
    }

    @Override
    public String opName() {
        return "lstmLayer";
    }

    /**
     * Exports the configuration properties, plus the initial-state variables when present.
     *
     * <p>NOTE(review): this stores SDVariable objects under "cLast"/"yLast".
     * {@link #setPropertiesForFunction} reads these keys back via getStringFromProperty. Confirm the
     * serializer converts them to variable names.
     */
    @Override
    public Map<String, Object> propertiesForFunction() {
        Map<String, Object> base = this.configuration.toProperties(true, true);
        if (this.cLast != null) {
            base.put("cLast", this.cLast);
        }
        if (this.yLast != null) {
            base.put("yLast", this.yLast);
        }
        return base;
    }

    /** Integer args: data format, direction mode, gate/output/cell activations (enum ordinals). */
    @Override
    public long[] iArgs() {
        return new long[]{this.configuration.getLstmdataformat().ordinal(), this.configuration.getDirectionMode().ordinal(), this.configuration.getGateAct().ordinal(), this.configuration.getOutAct().ordinal(), this.configuration.getCellAct().ordinal()};
    }

    /** Floating-point args: the cell clip value. */
    @Override
    public double[] tArgs() {
        return new double[]{this.configuration.getCellClip()};
    }

    /**
     * Boolean args: presence flags for the optional inputs followed by the requested outputs.
     * The order is: has bias, has seq length, has yLast, has cLast, has peephole, ret full
     * sequence, ret last H, ret last C.
     */
    protected <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) {
        return new boolean[]{weights.hasBias(), maxTSLength != null, yLast != null, cLast != null, weights.hasPH(), this.configuration.isRetFullSequence(), this.configuration.isRetLastH(), this.configuration.isRetLastC()};
    }

    /**
     * Rebuilds {@link #configuration} from the raw op arguments.
     *
     * <p>Only runs when no configuration is set yet and the boolean, integer and floating-point
     * argument lists are all non-empty.
     */
    @Override
    public void configureFromArguments() {
        if (this.configuration != null || this.bArguments.isEmpty() || this.iArguments.isEmpty() || this.tArguments.isEmpty()) {
            return;
        }
        LSTMLayerConfig.LSTMLayerConfigBuilder builder = LSTMLayerConfig.builder();
        builder.retFullSequence((Boolean) this.bArguments.get(B_RET_FULL_SEQUENCE));
        builder.retLastH((Boolean) this.bArguments.get(B_RET_LAST_H));
        // retLastC lives at index 7; index 4 is the peephole flag.
        builder.retLastC((Boolean) this.bArguments.get(B_RET_LAST_C));
        builder.cellClip((Double) this.tArguments.get(0));
        builder.lstmdataformat(LSTMDataFormat.values()[((Long) this.iArguments.get(0)).intValue()]);
        builder.directionMode(LSTMDirectionMode.values()[((Long) this.iArguments.get(1)).intValue()]);
        builder.gateAct(LSTMActivations.values()[((Long) this.iArguments.get(2)).intValue()]);
        builder.outAct(LSTMActivations.values()[((Long) this.iArguments.get(3)).intValue()]);
        builder.cellAct(LSTMActivations.values()[((Long) this.iArguments.get(4)).intValue()]);
        this.configuration = builder.build();
    }

    /**
     * Re-attaches the op to {@code sameDiff}.
     *
     * <p>Rebuilds {@link #weights} from the op's input variables and resolves the
     * initial-state variables by the names recorded in {@link #setPropertiesForFunction}.
     */
    @Override
    public void configureWithSameDiff(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
        String[] inputsForOp = sameDiff.getInputsForOp(this);
        boolean hasBiases = (Boolean) this.bArguments.get(B_HAS_BIAS);
        boolean hasPH = (Boolean) this.bArguments.get(B_HAS_PEEPHOLE);

        LSTMLayerWeights.LSTMLayerWeightsBuilder builder = LSTMLayerWeights.builder();
        // Input 0 is x; inputs 1 and 2 are the input and recurrent weights.
        builder.weights(sameDiff.getVariable(inputsForOp[1]));
        builder.rWeights(sameDiff.getVariable(inputsForOp[2]));
        if (hasBiases) {
            // The bias, when present, directly follows the recurrent weights.
            builder.bias(sameDiff.getVariable(inputsForOp[3]));
        }
        if (hasPH) {
            // Peephole weights, when present, are the last input.
            builder.peepholeWeights(sameDiff.getVariable(inputsForOp[inputsForOp.length - 1]));
        }
        this.weights = builder.build();

        if (this.yLastName != null) {
            this.yLast = sameDiff.getVariable(this.yLastName);
        }
        if (this.cLastName != null) {
            this.cLast = sameDiff.getVariable(this.cLastName);
        }
    }

    /**
     * Rebuilds {@link #configuration} from a property map if it is not set yet.
     *
     * <p>Also records the names of the initial-state variables ("cLast"/"yLast") for later
     * resolution. Properties that are absent leave the builder defaults in place.
     */
    @Override
    public void setPropertiesForFunction(Map<String, Object> properties) {
        if (this.configuration != null) {
            return;
        }
        LSTMLayerConfig.LSTMLayerConfigBuilder builder = LSTMLayerConfig.builder();

        Boolean retFullSequence = this.getBooleanFromProperty("retFullSequence", properties);
        if (retFullSequence != null) {
            builder.retFullSequence(retFullSequence);
        }
        Boolean retLastC = this.getBooleanFromProperty("retLastC", properties);
        if (retLastC != null) {
            builder.retLastC(retLastC);
        }
        Boolean retLastH = this.getBooleanFromProperty("retLastH", properties);
        if (retLastH != null) {
            builder.retLastH(retLastH);
        }
        Double cellClip = this.getDoubleValueFromProperty("cellClip", properties);
        if (cellClip != null) {
            builder.cellClip(cellClip);
        }
        String outAct = this.getStringFromProperty("outAct", properties);
        if (outAct != null) {
            builder.outAct(LSTMActivations.valueOf(outAct));
        }
        String cellAct = this.getStringFromProperty("cellAct", properties);
        if (cellAct != null) {
            builder.cellAct(LSTMActivations.valueOf(cellAct));
        }
        String gateAct = this.getStringFromProperty("gateAct", properties);
        if (gateAct != null) {
            builder.gateAct(LSTMActivations.valueOf(gateAct));
        }
        String directionMode = this.getStringFromProperty("directionMode", properties);
        if (directionMode != null) {
            builder.directionMode(LSTMDirectionMode.valueOf(directionMode));
        }
        String lstmdataformat = this.getStringFromProperty("lstmdataformat", properties);
        if (lstmdataformat != null) {
            builder.lstmdataformat(LSTMDataFormat.valueOf(lstmdataformat));
        }

        String cLastProperty = this.getStringFromProperty("cLast", properties);
        if (cLastProperty != null) {
            this.cLastName = cLastProperty;
        }
        // yLast has its own key; reading "cLast" here would alias yLast to the cell state.
        String yLastProperty = this.getStringFromProperty("yLast", properties);
        if (yLastProperty != null) {
            this.yLastName = yLastProperty;
        }
        this.configuration = builder.build();
    }

    @Override
    public boolean isConfigProperties() {
        return true;
    }

    @Override
    public String configFieldName() {
        return "configuration";
    }

    /** Number of outputs: one per enabled flag among full sequence, last H and last C. */
    @Override
    public int getNumOutputs() {
        return Booleans.countTrue(this.configuration.isRetFullSequence(), this.configuration.isRetLastH(), this.configuration.isRetLastC());
    }

    public LSTMLayerConfig getConfiguration() {
        return this.configuration;
    }

    public LSTMLayerWeights getWeights() {
        return this.weights;
    }
}

