package org.nd4j.linalg.factory.ops;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.GRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMConfiguration;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.GRUWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMWeights;
import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.SRUWeights;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/factory/ops/NDRNN.class */
/**
 * Eager, {@link INDArray}-based entry points for the recurrent-layer ops in
 * {@code org.nd4j.linalg.api.ops.impl.layers.recurrent}.
 *
 * <p>Every method follows the same pattern:
 * <ol>
 *   <li>each array argument is passed to {@code NDValidation.validateNumerical} under this
 *       facade's op name and input name (presumably rejecting non-numeric data types —
 *       confirm against {@code NDValidation});</li>
 *   <li>the matching op object is constructed;</li>
 *   <li>the op is executed immediately via {@code Nd4j.exec}.</li>
 * </ol>
 *
 * <p>Return convention: methods declared to return a single {@link INDArray} return element
 * {@code [0]} of the executed op's outputs. Methods declared to return {@code INDArray[]}
 * return all outputs unchanged.
 *
 * <p>Weight and configuration objects are not validated here; they are handed straight to
 * the op constructors. The class declares no fields.
 */
public class NDRNN {

    /**
     * Validates the inputs, executes the {@code GRU} op and returns its first output.
     *
     * @param x      input array (validated as {@code "x"})
     * @param hLast  validated as {@code "hLast"}; presumably the initial hidden state
     *               (name-based — confirm against {@code GRU})
     * @param wx     validated as {@code "Wx"}; presumably the input-side weights (TODO confirm)
     * @param wh     validated as {@code "Wh"}; presumably the recurrent weights (TODO confirm)
     * @param biases bias array (validated as {@code "biases"})
     * @return output {@code [0]} of the executed {@code GRU} op
     */
    public INDArray gru(INDArray x, INDArray hLast, INDArray wx, INDArray wh, INDArray biases) {
        NDValidation.validateNumerical("gru", "x", x);
        NDValidation.validateNumerical("gru", "hLast", hLast);
        NDValidation.validateNumerical("gru", "Wx", wx);
        NDValidation.validateNumerical("gru", "Wh", wh);
        NDValidation.validateNumerical("gru", "biases", biases);
        return Nd4j.exec(new GRU(x, hLast, wx, wh, biases))[0];
    }

    /**
     * Validates the inputs and executes a single {@code GRUCell} step.
     *
     * @param x       input array (validated as {@code "x"})
     * @param hLast   validated as {@code "hLast"}; presumably the previous hidden state (TODO confirm)
     * @param weights cell weights; passed to {@code GRUCell} without validation
     * @return every output array produced by the executed {@code GRUCell} op
     */
    public INDArray[] gruCell(INDArray x, INDArray hLast, GRUWeights weights) {
        NDValidation.validateNumerical("gruCell", "x", x);
        NDValidation.validateNumerical("gruCell", "hLast", hLast);
        return Nd4j.exec(new GRUCell(x, hLast, weights));
    }

    /**
     * Validates the inputs and executes a single {@code LSTMBlockCell} step.
     *
     * @param x       input array (validated as {@code "x"})
     * @param cLast   validated as {@code "cLast"}; presumably the previous cell state (TODO confirm)
     * @param yLast   validated as {@code "yLast"}; presumably the previous output (TODO confirm)
     * @param weights cell weights; passed to {@code LSTMBlockCell} without validation
     * @param config  op configuration; passed to {@code LSTMBlockCell} without validation
     * @return every output array produced by the executed {@code LSTMBlockCell} op
     */
    public INDArray[] lstmCell(INDArray x, INDArray cLast, INDArray yLast, LSTMWeights weights, LSTMConfiguration config) {
        NDValidation.validateNumerical("lstmCell", "x", x);
        NDValidation.validateNumerical("lstmCell", "cLast", cLast);
        NDValidation.validateNumerical("lstmCell", "yLast", yLast);
        return Nd4j.exec(new LSTMBlockCell(x, cLast, yLast, weights, config));
    }

    /**
     * Validates the inputs and executes the {@code LSTMLayer} op.
     *
     * <p>Note the argument order: {@code maxTSLength} comes LAST here, unlike
     * {@link #lstmblock(INDArray, INDArray, INDArray, INDArray, LSTMWeights, LSTMConfiguration)},
     * where it comes first.
     *
     * @param x           input array (validated as {@code "x"})
     * @param cLast       validated as {@code "cLast"}; the 3-argument overload passes {@code null} here
     * @param yLast       validated as {@code "yLast"}; the 3-argument overload passes {@code null} here
     * @param maxTSLength validated as {@code "maxTSLength"}; the 3-argument overload passes {@code null} here
     * @param weights     layer weights; passed to {@code LSTMLayer} without validation
     * @param config      layer configuration; passed to {@code LSTMLayer} without validation
     * @return every output array produced by the executed {@code LSTMLayer} op
     */
    public INDArray[] lstmLayer(INDArray x, INDArray cLast, INDArray yLast, INDArray maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig config) {
        NDValidation.validateNumerical("lstmLayer", "x", x);
        NDValidation.validateNumerical("lstmLayer", "cLast", cLast);
        NDValidation.validateNumerical("lstmLayer", "yLast", yLast);
        NDValidation.validateNumerical("lstmLayer", "maxTSLength", maxTSLength);
        return Nd4j.exec(new LSTMLayer(x, cLast, yLast, maxTSLength, weights, config));
    }

    /**
     * Convenience overload of {@code lstmLayer} that supplies only the input sequence.
     *
     * <p>The {@code cLast}, {@code yLast} and {@code maxTSLength} positions of {@code LSTMLayer}
     * receive {@code null}. Only {@code x} is validated.
     *
     * @param x       input array (validated as {@code "x"})
     * @param weights layer weights; passed to {@code LSTMLayer} without validation
     * @param config  layer configuration; passed to {@code LSTMLayer} without validation
     * @return every output array produced by the executed {@code LSTMLayer} op
     */
    public INDArray[] lstmLayer(INDArray x, LSTMLayerWeights weights, LSTMLayerConfig config) {
        NDValidation.validateNumerical("lstmLayer", "x", x);
        // Bare nulls kept as-is so constructor overload resolution is unchanged.
        return Nd4j.exec(new LSTMLayer(x, null, null, null, weights, config));
    }

    /**
     * Validates the inputs, executes the {@code LSTMBlock} op and returns its first output.
     *
     * <p>Note the argument order: {@code maxTSLength} comes FIRST here, unlike
     * {@link #lstmLayer(INDArray, INDArray, INDArray, INDArray, LSTMLayerWeights, LSTMLayerConfig)},
     * where it comes last. The order mirrors the {@code LSTMBlock} constructor call below.
     *
     * @param maxTSLength validated as {@code "maxTSLength"}
     * @param x           input array (validated as {@code "x"})
     * @param cLast       validated as {@code "cLast"}; presumably the initial cell state (TODO confirm)
     * @param yLast       validated as {@code "yLast"}; presumably the initial output (TODO confirm)
     * @param weights     block weights; passed to {@code LSTMBlock} without validation
     * @param config      block configuration; passed to {@code LSTMBlock} without validation
     * @return output {@code [0]} of the executed {@code LSTMBlock} op
     */
    public INDArray lstmblock(INDArray maxTSLength, INDArray x, INDArray cLast, INDArray yLast, LSTMWeights weights, LSTMConfiguration config) {
        NDValidation.validateNumerical("lstmblock", "maxTSLength", maxTSLength);
        NDValidation.validateNumerical("lstmblock", "x", x);
        NDValidation.validateNumerical("lstmblock", "cLast", cLast);
        NDValidation.validateNumerical("lstmblock", "yLast", yLast);
        return Nd4j.exec(new LSTMBlock(maxTSLength, x, cLast, yLast, weights, config))[0];
    }

    /**
     * Convenience overload of {@code lstmblock} that supplies only the input sequence.
     *
     * <p>The {@code maxTSLength} (first), {@code cLast} and {@code yLast} positions of
     * {@code LSTMBlock} receive {@code null}. Only {@code x} is validated.
     *
     * @param x       input array (validated as {@code "x"})
     * @param weights block weights; passed to {@code LSTMBlock} without validation
     * @param config  block configuration; passed to {@code LSTMBlock} without validation
     * @return output {@code [0]} of the executed {@code LSTMBlock} op
     */
    public INDArray lstmblock(INDArray x, LSTMWeights weights, LSTMConfiguration config) {
        NDValidation.validateNumerical("lstmblock", "x", x);
        // x is the SECOND constructor argument; the leading null fills maxTSLength.
        return Nd4j.exec(new LSTMBlock(null, x, null, null, weights, config))[0];
    }

    /**
     * Validates the inputs, executes the {@code SRU} op and returns its first output.
     *
     * @param x        input array (validated as {@code "x"})
     * @param initialC validated as {@code "initialC"}; presumably the initial cell state (TODO confirm)
     * @param mask     validated as {@code "mask"}; the 3-argument overload passes {@code null} here
     * @param weights  SRU weights; passed to {@code SRU} without validation
     * @return output {@code [0]} of the executed {@code SRU} op
     */
    public INDArray sru(INDArray x, INDArray initialC, INDArray mask, SRUWeights weights) {
        NDValidation.validateNumerical("sru", "x", x);
        NDValidation.validateNumerical("sru", "initialC", initialC);
        NDValidation.validateNumerical("sru", "mask", mask);
        return Nd4j.exec(new SRU(x, initialC, mask, weights))[0];
    }

    /**
     * Convenience overload of {@code sru} without a mask.
     *
     * <p>{@code null} is passed in the mask position of {@code SRU}.
     *
     * @param x        input array (validated as {@code "x"})
     * @param initialC validated as {@code "initialC"}
     * @param weights  SRU weights; passed to {@code SRU} without validation
     * @return output {@code [0]} of the executed {@code SRU} op
     */
    public INDArray sru(INDArray x, INDArray initialC, SRUWeights weights) {
        NDValidation.validateNumerical("sru", "x", x);
        NDValidation.validateNumerical("sru", "initialC", initialC);
        return Nd4j.exec(new SRU(x, initialC, null, weights))[0];
    }

    /**
     * Validates the inputs, executes a single {@code SRUCell} step and returns its first output.
     *
     * @param x       input array (validated as {@code "x"})
     * @param cLast   validated as {@code "cLast"}; presumably the previous cell state (TODO confirm)
     * @param weights SRU weights; passed to {@code SRUCell} without validation
     * @return output {@code [0]} of the executed {@code SRUCell} op
     */
    public INDArray sruCell(INDArray x, INDArray cLast, SRUWeights weights) {
        NDValidation.validateNumerical("sruCell", "x", x);
        NDValidation.validateNumerical("sruCell", "cLast", cLast);
        return Nd4j.exec(new SRUCell(x, cLast, weights))[0];
    }
}
