/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;

/**
 * A {@link DataSetPreProcessor} that reduces rank 3 (time series) labels to rank 2 labels by keeping only the
 * label values at the last time step of each example.
 * <p>
 * Labels are indexed as {@code [example, labelColumn, timeStep]}. The result has shape
 * {@code [labels.size(0), labels.size(1)]} and the same data type as the original labels.
 * <ul>
 *     <li>If the DataSet has no label mask array, every example uses the final time step
 *     ({@code labels.size(2) - 1}).</li>
 *     <li>If a label mask array is present, each example uses the last index along mask dimension 1 whose mask
 *     value is greater than 0. This presumably expects a {@code [minibatch, timeSeriesLength]} mask;
 *     TODO confirm behavior for per-output (rank 3) masks.</li>
 * </ul>
 * After processing, the DataSet's label mask array is set to {@code null}.
 */
public class LabelLastTimeStepPreProcessor
implements DataSetPreProcessor {
    /**
     * Replaces the labels of {@code toPreProcess} with the slice at each example's last time step, and
     * removes the label mask array.
     *
     * @param toPreProcess DataSet whose rank 3 labels are reduced in place
     * @throws IllegalStateException (via {@code Preconditions.checkState}) if any of the following holds:
     *         the labels are not rank 3; the label mask's example count differs from the labels' minibatch
     *         size; or an example is entirely masked out.
     */
    @Override
    public void preProcess(DataSet toPreProcess) {
        INDArray label3d = toPreProcess.getLabels();
        Preconditions.checkState(label3d.rank() == 3, "LabelLastTimeStepPreProcessor expects rank 3 labels, got rank %s labels with shape %ndShape", (Object)label3d.rank(), (Object)label3d);

        INDArray lMask = toPreProcess.getLabelsMaskArray();
        INDArray labels2d;
        if (lMask == null) {
            // No mask: all examples are treated as having the same length, so take the final time step.
            // dup() detaches the result from the original 3d label array.
            labels2d = label3d.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(label3d.size(2) - 1L)).dup();
        } else {
            // Per-example index of the last time step with mask > 0 (negative when the example is fully masked).
            INDArray lastIndex = BooleanIndexing.lastIndex(lMask, Conditions.greaterThan(0), 1);
            long[] idxs = lastIndex.data().asLong();

            // Without this check, a short mask would silently leave trailing output rows as zeros.
            Preconditions.checkState(idxs.length == label3d.size(0), "Label mask has %s examples but labels have minibatch size %s", (Object)idxs.length, (Object)label3d.size(0));

            // Use the labels' own data type so this branch agrees with the no-mask branch (which keeps it via
            // dup()); hard-coding FLOAT would silently convert e.g. DOUBLE labels.
            labels2d = Nd4j.create(label3d.dataType(), label3d.size(0), label3d.size(1));
            for (int i = 0; i < idxs.length; ++i) {
                long lastStepIdx = idxs[i];
                Preconditions.checkState(lastStepIdx >= 0L, "Invalid last time step index: example %s in minibatch is entirely masked out (label mask is all 0s, meaning no label data is present for this example)", i);
                labels2d.putRow(i, label3d.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.point(lastStepIdx)));
            }
        }

        toPreProcess.setLabels(labels2d);
        // The remaining labels are no longer time series, so a label mask no longer applies.
        toPreProcess.setLabelsMaskArray(null);
    }
}

