/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.ops;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.ops.SDOps;
import org.nd4j.autodiff.samediff.ops.SDValidation;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;

/**
 * Factory methods for recurrent ops (GRU, LSTM and SRU variants) attached to a {@link SameDiff} instance.
 *
 * <p>Every method follows the same pattern: the {@link SDVariable} inputs are checked with
 * {@code SDValidation.validateNumerical}, the matching op object is constructed against the
 * owning {@link SameDiff}, and the op's output variable(s) are returned. Overloads that take a
 * {@code name} / {@code names} argument additionally pass the outputs through
 * {@code SameDiff.updateVariableNameAndReference} / {@code updateVariableNamesAndReferences}.
 */
public class SDRNN extends SDOps {

    /**
     * Creates the RNN op namespace.
     *
     * @param sameDiff SameDiff instance handed to the {@link SDOps} superclass constructor
     */
    public SDRNN(SameDiff sameDiff) {
        super(sameDiff);
    }

    /**
     * Builds a {@link GRU} op and returns its output variable.
     *
     * @param x      input variable; validated as numerical
     * @param hLast  variable validated as numerical and passed to the op as {@code hLast}
     * @param Wx     variable validated as numerical and passed to the op as {@code Wx}
     * @param Wh     variable validated as numerical and passed to the op as {@code Wh}
     * @param biases variable validated as numerical and passed to the op as {@code biases}
     * @return the single output variable of the GRU op
     */
    public SDVariable gru(SDVariable x, SDVariable hLast, SDVariable Wx, SDVariable Wh, SDVariable biases) {
        SDValidation.validateNumerical("gru", "x", x);
        SDValidation.validateNumerical("gru", "hLast", hLast);
        SDValidation.validateNumerical("gru", "Wx", Wx);
        SDValidation.validateNumerical("gru", "Wh", Wh);
        SDValidation.validateNumerical("gru", "biases", biases);
        GRU op = new GRU(sd, x, hLast, Wx, Wh, biases);
        return op.outputVariable();
    }

    /**
     * Builds a {@link GRU} op and returns its output variable after renaming it.
     *
     * @param name   name forwarded to {@code SameDiff.updateVariableNameAndReference}
     * @param x      input variable; validated as numerical
     * @param hLast  variable validated as numerical and passed to the op as {@code hLast}
     * @param Wx     variable validated as numerical and passed to the op as {@code Wx}
     * @param Wh     variable validated as numerical and passed to the op as {@code Wh}
     * @param biases variable validated as numerical and passed to the op as {@code biases}
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable gru(String name, SDVariable x, SDVariable hLast, SDVariable Wx, SDVariable Wh, SDVariable biases) {
        SDValidation.validateNumerical("gru", "x", x);
        SDValidation.validateNumerical("gru", "hLast", hLast);
        SDValidation.validateNumerical("gru", "Wx", Wx);
        SDValidation.validateNumerical("gru", "Wh", Wh);
        SDValidation.validateNumerical("gru", "biases", biases);
        GRU op = new GRU(sd, x, hLast, Wx, Wh, biases);
        return sd.updateVariableNameAndReference(op.outputVariable(), name);
    }

    /**
     * Builds a {@link GRUCell} op and returns all of its output variables.
     *
     * @param x           input variable; validated as numerical
     * @param hLast       variable validated as numerical and passed to the op as {@code hLast}
     * @param GRUWeights2 weights object passed to the op unchanged
     * @return the output variables of the GRUCell op
     */
    public SDVariable[] gruCell(SDVariable x, SDVariable hLast, GRUWeights GRUWeights2) {
        SDValidation.validateNumerical("gruCell", "x", x);
        SDValidation.validateNumerical("gruCell", "hLast", hLast);
        GRUCell op = new GRUCell(sd, x, hLast, GRUWeights2);
        return op.outputVariables();
    }

    /**
     * Builds a {@link GRUCell} op and returns its output variables after renaming them.
     *
     * @param names       names forwarded to {@code SameDiff.updateVariableNamesAndReferences}
     * @param x           input variable; validated as numerical
     * @param hLast       variable validated as numerical and passed to the op as {@code hLast}
     * @param GRUWeights2 weights object passed to the op unchanged
     * @return the variables returned by {@code updateVariableNamesAndReferences}
     */
    public SDVariable[] gruCell(String[] names, SDVariable x, SDVariable hLast, GRUWeights GRUWeights2) {
        SDValidation.validateNumerical("gruCell", "x", x);
        SDValidation.validateNumerical("gruCell", "hLast", hLast);
        GRUCell op = new GRUCell(sd, x, hLast, GRUWeights2);
        return sd.updateVariableNamesAndReferences(op.outputVariables(), names);
    }

    /**
     * Builds an {@link LSTMBlockCell} op and returns all of its output variables.
     *
     * @param x                  input variable; validated as numerical
     * @param cLast              variable validated as numerical and passed to the op as {@code cLast}
     * @param yLast              variable validated as numerical and passed to the op as {@code yLast}
     * @param LSTMWeights2       weights object passed to the op unchanged
     * @param LSTMConfiguration2 configuration object passed to the op unchanged
     * @return the output variables of the LSTMBlockCell op
     */
    public SDVariable[] lstmCell(SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights2, LSTMConfiguration LSTMConfiguration2) {
        SDValidation.validateNumerical("lstmCell", "x", x);
        SDValidation.validateNumerical("lstmCell", "cLast", cLast);
        SDValidation.validateNumerical("lstmCell", "yLast", yLast);
        LSTMBlockCell op = new LSTMBlockCell(sd, x, cLast, yLast, LSTMWeights2, LSTMConfiguration2);
        return op.outputVariables();
    }

    /**
     * Builds an {@link LSTMBlockCell} op and returns its output variables after renaming them.
     *
     * @param names              names forwarded to {@code SameDiff.updateVariableNamesAndReferences}
     * @param x                  input variable; validated as numerical
     * @param cLast              variable validated as numerical and passed to the op as {@code cLast}
     * @param yLast              variable validated as numerical and passed to the op as {@code yLast}
     * @param LSTMWeights2       weights object passed to the op unchanged
     * @param LSTMConfiguration2 configuration object passed to the op unchanged
     * @return the variables returned by {@code updateVariableNamesAndReferences}
     */
    public SDVariable[] lstmCell(String[] names, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights2, LSTMConfiguration LSTMConfiguration2) {
        SDValidation.validateNumerical("lstmCell", "x", x);
        SDValidation.validateNumerical("lstmCell", "cLast", cLast);
        SDValidation.validateNumerical("lstmCell", "yLast", yLast);
        LSTMBlockCell op = new LSTMBlockCell(sd, x, cLast, yLast, LSTMWeights2, LSTMConfiguration2);
        return sd.updateVariableNamesAndReferences(op.outputVariables(), names);
    }

    /**
     * Builds an {@link LSTMLayer} op with explicit state and length inputs and returns all of its
     * output variables.
     *
     * @param x                 input variable; validated as numerical
     * @param cLast             variable validated as numerical and passed to the op as {@code cLast}
     * @param yLast             variable validated as numerical and passed to the op as {@code yLast}
     * @param maxTSLength       variable validated as numerical and passed to the op as {@code maxTSLength}
     * @param LSTMLayerWeights2 weights object passed to the op unchanged
     * @param LSTMLayerConfig2  configuration object passed to the op unchanged
     * @return the output variables of the LSTMLayer op
     */
    public SDVariable[] lstmLayer(SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights LSTMLayerWeights2, LSTMLayerConfig LSTMLayerConfig2) {
        SDValidation.validateNumerical("lstmLayer", "x", x);
        SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
        SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
        SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
        LSTMLayer op = new LSTMLayer(sd, x, cLast, yLast, maxTSLength, LSTMLayerWeights2, LSTMLayerConfig2);
        return op.outputVariables();
    }

    /**
     * Builds an {@link LSTMLayer} op with explicit state and length inputs and returns its output
     * variables after renaming them.
     *
     * @param names             names forwarded to {@code SameDiff.updateVariableNamesAndReferences}
     * @param x                 input variable; validated as numerical
     * @param cLast             variable validated as numerical and passed to the op as {@code cLast}
     * @param yLast             variable validated as numerical and passed to the op as {@code yLast}
     * @param maxTSLength       variable validated as numerical and passed to the op as {@code maxTSLength}
     * @param LSTMLayerWeights2 weights object passed to the op unchanged
     * @param LSTMLayerConfig2  configuration object passed to the op unchanged
     * @return the variables returned by {@code updateVariableNamesAndReferences}
     */
    public SDVariable[] lstmLayer(String[] names, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights LSTMLayerWeights2, LSTMLayerConfig LSTMLayerConfig2) {
        SDValidation.validateNumerical("lstmLayer", "x", x);
        SDValidation.validateNumerical("lstmLayer", "cLast", cLast);
        SDValidation.validateNumerical("lstmLayer", "yLast", yLast);
        SDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
        LSTMLayer op = new LSTMLayer(sd, x, cLast, yLast, maxTSLength, LSTMLayerWeights2, LSTMLayerConfig2);
        return sd.updateVariableNamesAndReferences(op.outputVariables(), names);
    }

    /**
     * Builds an {@link LSTMLayer} op from the input only; {@code null} is passed to the op for
     * {@code cLast}, {@code yLast} and {@code maxTSLength}.
     *
     * @param x                 input variable; validated as numerical
     * @param LSTMLayerWeights2 weights object passed to the op unchanged
     * @param LSTMLayerConfig2  configuration object passed to the op unchanged
     * @return the output variables of the LSTMLayer op
     */
    public SDVariable[] lstmLayer(SDVariable x, LSTMLayerWeights LSTMLayerWeights2, LSTMLayerConfig LSTMLayerConfig2) {
        SDValidation.validateNumerical("lstmLayer", "x", x);
        LSTMLayer op = new LSTMLayer(sd, x, null, null, null, LSTMLayerWeights2, LSTMLayerConfig2);
        return op.outputVariables();
    }

    /**
     * Builds an {@link LSTMLayer} op from the input only ({@code null} for {@code cLast},
     * {@code yLast} and {@code maxTSLength}) and returns its output variables after renaming them.
     *
     * @param names             names forwarded to {@code SameDiff.updateVariableNamesAndReferences}
     * @param x                 input variable; validated as numerical
     * @param LSTMLayerWeights2 weights object passed to the op unchanged
     * @param LSTMLayerConfig2  configuration object passed to the op unchanged
     * @return the variables returned by {@code updateVariableNamesAndReferences}
     */
    public SDVariable[] lstmLayer(String[] names, SDVariable x, LSTMLayerWeights LSTMLayerWeights2, LSTMLayerConfig LSTMLayerConfig2) {
        SDValidation.validateNumerical("lstmLayer", "x", x);
        LSTMLayer op = new LSTMLayer(sd, x, null, null, null, LSTMLayerWeights2, LSTMLayerConfig2);
        return sd.updateVariableNamesAndReferences(op.outputVariables(), names);
    }

    /**
     * Builds an {@link LSTMBlock} op with explicit length and state inputs and returns its output
     * variable.
     *
     * @param maxTSLength        variable validated as numerical and passed to the op as {@code maxTSLength}
     * @param x                  input variable; validated as numerical
     * @param cLast              variable validated as numerical and passed to the op as {@code cLast}
     * @param yLast              variable validated as numerical and passed to the op as {@code yLast}
     * @param LSTMWeights2       weights object passed to the op unchanged
     * @param LSTMConfiguration2 configuration object passed to the op unchanged
     * @return the single output variable of the LSTMBlock op
     */
    public SDVariable lstmblock(SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights2, LSTMConfiguration LSTMConfiguration2) {
        SDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength);
        SDValidation.validateNumerical("lstmblock", "x", x);
        SDValidation.validateNumerical("lstmblock", "cLast", cLast);
        SDValidation.validateNumerical("lstmblock", "yLast", yLast);
        LSTMBlock op = new LSTMBlock(sd, maxTSLength, x, cLast, yLast, LSTMWeights2, LSTMConfiguration2);
        return op.outputVariable();
    }

    /**
     * Builds an {@link LSTMBlock} op with explicit length and state inputs and returns its output
     * variable after renaming it.
     *
     * @param name               name forwarded to {@code SameDiff.updateVariableNameAndReference}
     * @param maxTSLength        variable validated as numerical and passed to the op as {@code maxTSLength}
     * @param x                  input variable; validated as numerical
     * @param cLast              variable validated as numerical and passed to the op as {@code cLast}
     * @param yLast              variable validated as numerical and passed to the op as {@code yLast}
     * @param LSTMWeights2       weights object passed to the op unchanged
     * @param LSTMConfiguration2 configuration object passed to the op unchanged
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable lstmblock(String name, SDVariable maxTSLength, SDVariable x, SDVariable cLast, SDVariable yLast, LSTMWeights LSTMWeights2, LSTMConfiguration LSTMConfiguration2) {
        SDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength);
        SDValidation.validateNumerical("lstmblock", "x", x);
        SDValidation.validateNumerical("lstmblock", "cLast", cLast);
        SDValidation.validateNumerical("lstmblock", "yLast", yLast);
        LSTMBlock op = new LSTMBlock(sd, maxTSLength, x, cLast, yLast, LSTMWeights2, LSTMConfiguration2);
        return sd.updateVariableNameAndReference(op.outputVariable(), name);
    }

    /**
     * Builds an {@link LSTMBlock} op from the input only; {@code null} is passed to the op for
     * {@code maxTSLength}, {@code cLast} and {@code yLast}.
     *
     * @param x                  input variable; validated as numerical
     * @param LSTMWeights2       weights object passed to the op unchanged
     * @param LSTMConfiguration2 configuration object passed to the op unchanged
     * @return the single output variable of the LSTMBlock op
     */
    public SDVariable lstmblock(SDVariable x, LSTMWeights LSTMWeights2, LSTMConfiguration LSTMConfiguration2) {
        SDValidation.validateNumerical("lstmblock", "x", x);
        LSTMBlock op = new LSTMBlock(sd, null, x, null, null, LSTMWeights2, LSTMConfiguration2);
        return op.outputVariable();
    }

    /**
     * Builds an {@link LSTMBlock} op from the input only ({@code null} for {@code maxTSLength},
     * {@code cLast} and {@code yLast}) and returns its output variable after renaming it.
     *
     * @param name               name forwarded to {@code SameDiff.updateVariableNameAndReference}
     * @param x                  input variable; validated as numerical
     * @param LSTMWeights2       weights object passed to the op unchanged
     * @param LSTMConfiguration2 configuration object passed to the op unchanged
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable lstmblock(String name, SDVariable x, LSTMWeights LSTMWeights2, LSTMConfiguration LSTMConfiguration2) {
        SDValidation.validateNumerical("lstmblock", "x", x);
        LSTMBlock op = new LSTMBlock(sd, null, x, null, null, LSTMWeights2, LSTMConfiguration2);
        return sd.updateVariableNameAndReference(op.outputVariable(), name);
    }

    /**
     * Builds an {@link SRU} op with a mask input and returns its output variable.
     *
     * @param x           input variable; validated as numerical
     * @param initialC    variable validated as numerical and passed to the op as {@code initialC}
     * @param mask        variable validated as numerical and passed to the op as {@code mask}
     * @param SRUWeights2 weights object passed to the op unchanged
     * @return the single output variable of the SRU op
     */
    public SDVariable sru(SDVariable x, SDVariable initialC, SDVariable mask, SRUWeights SRUWeights2) {
        SDValidation.validateNumerical("sru", "x", x);
        SDValidation.validateNumerical("sru", "initialC", initialC);
        SDValidation.validateNumerical("sru", "mask", mask);
        SRU op = new SRU(sd, x, initialC, mask, SRUWeights2);
        return op.outputVariable();
    }

    /**
     * Builds an {@link SRU} op with a mask input and returns its output variable after renaming it.
     *
     * @param name        name forwarded to {@code SameDiff.updateVariableNameAndReference}
     * @param x           input variable; validated as numerical
     * @param initialC    variable validated as numerical and passed to the op as {@code initialC}
     * @param mask        variable validated as numerical and passed to the op as {@code mask}
     * @param SRUWeights2 weights object passed to the op unchanged
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable sru(String name, SDVariable x, SDVariable initialC, SDVariable mask, SRUWeights SRUWeights2) {
        SDValidation.validateNumerical("sru", "x", x);
        SDValidation.validateNumerical("sru", "initialC", initialC);
        SDValidation.validateNumerical("sru", "mask", mask);
        SRU op = new SRU(sd, x, initialC, mask, SRUWeights2);
        return sd.updateVariableNameAndReference(op.outputVariable(), name);
    }

    /**
     * Builds an {@link SRU} op without a mask ({@code null} is passed in the mask position) and
     * returns its output variable.
     *
     * @param x           input variable; validated as numerical
     * @param initialC    variable validated as numerical and passed to the op as {@code initialC}
     * @param SRUWeights2 weights object passed to the op unchanged
     * @return the single output variable of the SRU op
     */
    public SDVariable sru(SDVariable x, SDVariable initialC, SRUWeights SRUWeights2) {
        SDValidation.validateNumerical("sru", "x", x);
        SDValidation.validateNumerical("sru", "initialC", initialC);
        SRU op = new SRU(sd, x, initialC, null, SRUWeights2);
        return op.outputVariable();
    }

    /**
     * Builds an {@link SRU} op without a mask ({@code null} is passed in the mask position) and
     * returns its output variable after renaming it.
     *
     * @param name        name forwarded to {@code SameDiff.updateVariableNameAndReference}
     * @param x           input variable; validated as numerical
     * @param initialC    variable validated as numerical and passed to the op as {@code initialC}
     * @param SRUWeights2 weights object passed to the op unchanged
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable sru(String name, SDVariable x, SDVariable initialC, SRUWeights SRUWeights2) {
        SDValidation.validateNumerical("sru", "x", x);
        SDValidation.validateNumerical("sru", "initialC", initialC);
        SRU op = new SRU(sd, x, initialC, null, SRUWeights2);
        return sd.updateVariableNameAndReference(op.outputVariable(), name);
    }

    /**
     * Builds an {@link SRUCell} op and returns its output variable.
     *
     * @param x           input variable; validated as numerical
     * @param cLast       variable validated as numerical and passed to the op as {@code cLast}
     * @param SRUWeights2 weights object passed to the op unchanged
     * @return the single output variable of the SRUCell op
     */
    public SDVariable sruCell(SDVariable x, SDVariable cLast, SRUWeights SRUWeights2) {
        SDValidation.validateNumerical("sruCell", "x", x);
        SDValidation.validateNumerical("sruCell", "cLast", cLast);
        SRUCell op = new SRUCell(sd, x, cLast, SRUWeights2);
        return op.outputVariable();
    }

    /**
     * Builds an {@link SRUCell} op and returns its output variable after renaming it.
     *
     * @param name        name forwarded to {@code SameDiff.updateVariableNameAndReference}
     * @param x           input variable; validated as numerical
     * @param cLast       variable validated as numerical and passed to the op as {@code cLast}
     * @param SRUWeights2 weights object passed to the op unchanged
     * @return the variable returned by {@code updateVariableNameAndReference}
     */
    public SDVariable sruCell(String name, SDVariable x, SDVariable cLast, SRUWeights SRUWeights2) {
        SDValidation.validateNumerical("sruCell", "x", x);
        SDValidation.validateNumerical("sruCell", "cLast", cLast);
        SRUCell op = new SRUCell(sd, x, cLast, SRUWeights2);
        return sd.updateVariableNameAndReference(op.outputVariable(), name);
    }
}

