package org.nd4j.autodiff.samediff.optimize;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.optimize.debug.OptimizationDebugger;
import org.nd4j.autodiff.samediff.optimize.optimizations.ConstantFunctionOptimizations;
import org.nd4j.autodiff.samediff.optimize.optimizations.CuDNNFunctionOptimizations;
import org.nd4j.autodiff.samediff.optimize.optimizations.IdentityFunctionOptimizations;
import org.nd4j.autodiff.samediff.optimize.optimizations.ShapeFunctionOptimizations;
import org.nd4j.autodiff.samediff.optimize.optimizations.UnusedFunctionOptimizations;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/autodiff/samediff/optimize/GraphOptimizer.class */
/**
 * Applies graph-level optimizations to a {@link SameDiff} instance.
 *
 * <p>The optimizers are run against a duplicate of the supplied graph (obtained via
 * {@link SameDiff#dup()}), and that duplicate is returned. Summary statistics comparing the
 * supplied graph with the optimized duplicate are logged at INFO level.
 */
public class GraphOptimizer {
    private static final Logger log = LoggerFactory.getLogger(GraphOptimizer.class);

    /**
     * Number of passes over the full optimizer list. Within a single pass, an optimizer does not
     * see changes made by optimizers later in the list, so repeating the passes gives those
     * follow-on optimizations a chance to apply.
     */
    private static final int NUM_PASSES = 3;

    /**
     * Returns the default optimizer sets, in the order they are applied.
     *
     * <p>{@link UnusedFunctionOptimizations} is listed twice: once first, and again after the
     * constant/identity/shape optimizations, so unused functions are also removed after those run.
     *
     * @return a fixed-size list (backed by an array) of the default optimizer sets
     */
    public static List<OptimizerSet> defaultOptimizations() {
        return Arrays.asList(
                new UnusedFunctionOptimizations(),
                new ConstantFunctionOptimizations(),
                new IdentityFunctionOptimizations(),
                new ShapeFunctionOptimizations(),
                new UnusedFunctionOptimizations(),
                new CuDNNFunctionOptimizations());
    }

    /**
     * Optimizes a duplicate of {@code graph} using {@link #defaultOptimizations()}.
     *
     * @param graph graph to optimize; the optimizers operate on {@code graph.dup()}
     * @param requiredOutputs names of required outputs (currently not read by the optimizer)
     * @return the optimized duplicate of {@code graph}
     */
    public static SameDiff optimize(SameDiff graph, String... requiredOutputs) {
        return optimize(graph, Arrays.asList(requiredOutputs));
    }

    /**
     * Optimizes a duplicate of {@code graph} using {@link #defaultOptimizations()}.
     *
     * @param graph graph to optimize; the optimizers operate on {@code graph.dup()}
     * @param requiredOutputs names of required outputs (currently not read by the optimizer)
     * @return the optimized duplicate of {@code graph}
     */
    public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs) {
        return optimize(graph, requiredOutputs, defaultOptimizations());
    }

    /**
     * Optimizes a duplicate of {@code graph} using the given optimizer sets, without a debugger.
     *
     * @param graph graph to optimize; the optimizers operate on {@code graph.dup()}
     * @param requiredOutputs names of required outputs (currently not read by the optimizer)
     * @param optimizations optimizer sets to apply, in order
     * @return the optimized duplicate of {@code graph}
     */
    public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs, List<OptimizerSet> optimizations) {
        return optimize(graph, requiredOutputs, optimizations, null);
    }

    /**
     * Applies the given optimizer sets to a duplicate of {@code graph} and returns the duplicate.
     *
     * <p>The full optimizer list is applied {@code NUM_PASSES} (3) times. For each optimizer, every
     * op present at the start of that optimizer's sweep is offered to it, skipping ops that have
     * since been removed from the graph. Before/after statistics are logged once all passes finish.
     *
     * @param graph graph to optimize; the optimizers operate on {@code graph.dup()}
     * @param requiredOutputs names of required outputs; NOTE(review): presumably outputs that must be
     *     preserved, but this parameter is currently not read by this method
     * @param optimizations optimizer sets to apply, in order; must not be null
     * @param debugger optional callback notified before and after each optimizer check; may be null
     * @return the optimized duplicate of {@code graph}
     * @throws NullPointerException if {@code graph} or {@code optimizations} is null
     */
    public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs,
                                    List<OptimizerSet> optimizations, OptimizationDebugger debugger) {
        // Fail fast, before paying for graph.dup().
        Objects.requireNonNull(graph, "graph");
        Objects.requireNonNull(optimizations, "optimizations");

        SameDiff optimized = graph.dup();
        ArrayHolder constantArrays = optimized.getConstantArrays();
        ArrayHolder variableArrays = optimized.getVariablesArrays();
        // The helper is constructed from the ORIGINAL graph, not the copy being optimized.
        // NOTE(review): presumably so optimizers can consult the unmodified graph; confirm intent.
        OptimizationHelper helper = new OptimizationHelper(graph, new OptimizationConfig());

        for (int pass = 0; pass < NUM_PASSES; pass++) {
            for (OptimizerSet optimizerSet : optimizations) {
                for (Optimizer optimizer : optimizerSet.getOptimizers()) {
                    applyToAllOps(optimized, helper, optimizer, constantArrays, variableArrays, debugger);
                }
            }
        }

        logSummary(graph, optimized);
        return optimized;
    }

    /** Offers every op currently in {@code sd} to a single optimizer, notifying the debugger if present. */
    private static void applyToAllOps(SameDiff sd, OptimizationHelper helper, Optimizer optimizer,
                                      ArrayHolder constantArrays, ArrayHolder variableArrays,
                                      OptimizationDebugger debugger) {
        // Iterate over a snapshot so optimizers may modify the op map while we loop.
        List<SameDiffOp> startingOps = new ArrayList<>(sd.getOps().values());
        for (SameDiffOp op : startingOps) {
            // An earlier optimization in this sweep may already have removed this op.
            if (!sd.getOps().containsKey(op.getName())) {
                continue;
            }
            if (debugger != null) {
                debugger.beforeOptimizationCheck(sd, op, optimizer);
            }
            boolean applied = optimizer.checkAndApply(sd, helper, op, constantArrays, variableArrays);
            if (applied) {
                log.info("Operation was applied: {}", optimizer);
            }
            if (debugger != null) {
                debugger.afterOptimizationsCheck(sd, op, optimizer, applied);
            }
        }
    }

    /** Logs variable and op counts of the original graph versus the optimized copy. */
    private static void logSummary(SameDiff before, SameDiff after) {
        VariableCounts beforeCounts = VariableCounts.of(before);
        VariableCounts afterCounts = VariableCounts.of(after);
        log.info("Total variables: {} before, {} after", before.getVariables().size(), after.getVariables().size());
        log.info("Constant variables: {} before, {} after", beforeCounts.constants, afterCounts.constants);
        log.info("Array type variables: {} before, {} after", beforeCounts.arrays, afterCounts.arrays);
        log.info("Variable type variables: {} before, {} after", beforeCounts.variables, afterCounts.variables);
        log.info("Ops: {} before, {} after", before.getOps().size(), after.getOps().size());
    }

    /** Counts of CONSTANT, ARRAY and VARIABLE type variables in one graph, used for summary logging. */
    private static final class VariableCounts {
        int constants;
        int arrays;
        int variables;

        /** Tallies the variables of {@code sd} by variable type. */
        static VariableCounts of(SameDiff sd) {
            VariableCounts counts = new VariableCounts();
            for (SDVariable variable : sd.variables()) {
                switch (variable.getVariableType()) {
                    case VARIABLE:
                        counts.variables++;
                        break;
                    case CONSTANT:
                        counts.constants++;
                        break;
                    case ARRAY:
                        counts.arrays++;
                        break;
                    default:
                        // Any other variable type is not included in the summary.
                        break;
                }
            }
            return counts;
        }
    }
}
