/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.optimize.optimizations;

import java.util.List;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.optimize.OptimizationHelper;
import org.nd4j.autodiff.samediff.optimize.Optimizer;
import org.nd4j.autodiff.samediff.optimize.optimizations.BaseOptimizerSet;
import org.nd4j.autodiff.samediff.optimize.optimizations.OptimizationUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Optimizer set containing constant-function folding: an op whose inputs are all constants is
 * executed once during optimization and its outputs are stored as new constants, after which the
 * op is removed from the graph.
 */
public class ConstantFunctionOptimizations
extends BaseOptimizerSet {
    /**
     * Optimizer property key for the maximum total size of an op's numerical outputs for which
     * folding is applied. Size is measured as {@code length() * dataType().width()} summed over
     * outputs (presumably bytes, given {@code width()}).
     */
    public static final String CONSTANT_FN_FOLDING_MAX_SIZE = "optimizer.constants.function.max.output.size";
    /** Default value for {@link #CONSTANT_FN_FOLDING_MAX_SIZE}: 0x400000 (4 MiB if measured in bytes). */
    public static final long CONSTANT_FN_FOLDING_MAX_SIZE_DEFAULT = 0x400000L;

    /**
     * Folds a single op into constants when every one of its inputs is a constant and the total
     * numerical output size does not exceed the configured limit.
     *
     * <p>NOTE(review): no check is made for non-deterministic (e.g. random) ops; such an op would be
     * frozen to the single value produced here. Confirm whether callers filter these out.
     */
    public static class FoldConstantFunctions
    implements Optimizer {
        /**
         * Attempts to fold {@code op} into constants.
         *
         * @param sd             graph containing the op; its variables and ops are mutated on success
         * @param helper         supplies optimizer properties (reads {@link #CONSTANT_FN_FOLDING_MAX_SIZE})
         * @param op             the op to consider for folding
         * @param constantArrays receives the computed output arrays, keyed by output variable name
         * @param variablesArrays not used by this optimizer
         * @return {@code true} if the op was executed, its outputs converted to constants and the op
         *         removed; {@code false} if the graph was left unchanged
         */
        @Override
        public boolean checkAndApply(SameDiff sd, OptimizationHelper helper, SameDiffOp op, ArrayHolder constantArrays, ArrayHolder variablesArrays) {
            List<String> in = op.getInputsToOp();
            if (in == null || in.isEmpty()) {
                return false;
            }
            for (String s : in) {
                if (!sd.getVariable(s).isConstant()) {
                    return false;
                }
            }

            long maxSizeToApply = Long.parseLong(helper.getProperties().getProperty(
                    CONSTANT_FN_FOLDING_MAX_SIZE, String.valueOf(CONSTANT_FN_FOLDING_MAX_SIZE_DEFAULT)));

            DifferentialFunction df = op.getOp();
            // Start from a clean op so only the constant inputs bound below are used.
            df.clearArrays();
            bindInputs(sd, df, in);
            INDArray[] outputs = execute(df);

            List<String> outputNames = op.getOutputsOfOp();
            // Validate everything before touching the graph so a rejected fold never leaves it
            // partially converted.
            if (outputNames == null
                    || outputNames.size() != outputs.length
                    || containsNull(outputs)
                    || numericalOutputSize(outputs) > maxSizeToApply) {
                // Drop the bound inputs and computed outputs so the op does not keep them alive.
                df.clearArrays();
                return false;
            }

            // Convert every output into a constant that is no longer produced by any op.
            for (int i = 0; i < outputNames.size(); ++i) {
                String n = outputNames.get(i);
                sd.getVariable(n).setVariableType(VariableType.CONSTANT);
                constantArrays.setArray(n, outputs[i]);
                ((Variable) sd.getVariables().get(n)).setOutputOfOp(null);
            }
            OptimizationUtils.removeOp(sd, df.getOwnName());
            return true;
        }

        /**
         * Binds the arrays of the named input variables to {@code df}. For a {@link CustomOp} every
         * input is appended in order; for a plain {@link Op} input 0 becomes X and every later input
         * is set as Y (so with more than two inputs the last one wins).
         */
        private static void bindInputs(SameDiff sd, DifferentialFunction df, List<String> in) {
            for (int i = 0; i < in.size(); ++i) {
                INDArray arr = sd.getVariable(in.get(i)).getArr();
                if (df instanceof CustomOp) {
                    ((CustomOp) df).addInputArgument(arr);
                } else if (i == 0) {
                    ((Op) df).setX(arr);
                } else {
                    ((Op) df).setY(arr);
                }
            }
        }

        /** Executes {@code df} through {@link Nd4j#exec} and returns its output arrays in order. */
        private static INDArray[] execute(DifferentialFunction df) {
            if (df instanceof CustomOp) {
                CustomOp customOp = (CustomOp) df;
                Nd4j.exec(customOp);
                INDArray[] outputs = new INDArray[customOp.numOutputArguments()];
                for (int j = 0; j < outputs.length; ++j) {
                    outputs[j] = customOp.getOutputArgument(j);
                }
                return outputs;
            }
            Op legacyOp = (Op) df;
            Nd4j.exec(legacyOp);
            return new INDArray[]{legacyOp.z()};
        }

        /** Returns {@code true} if any element of {@code arrays} is {@code null}. */
        private static boolean containsNull(INDArray[] arrays) {
            for (INDArray a : arrays) {
                if (a == null) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Sums {@code length() * dataType().width()} over the numerical arrays; non-numerical
         * outputs are excluded from the size limit.
         */
        private static long numericalOutputSize(INDArray[] arrays) {
            long sizeCount = 0L;
            for (INDArray a : arrays) {
                if (a.dataType().isNumerical()) {
                    sizeCount += a.length() * (long) a.dataType().width();
                }
            }
            return sizeCount;
        }
    }
}

