package org.nd4j.autodiff.samediff.optimize.optimizations;

import java.util.List;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.optimize.OptimizationHelper;
import org.nd4j.autodiff.samediff.optimize.Optimizer;
import org.nd4j.linalg.api.environment.Nd4jEnvironment;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Graph optimizations intended for the CUDA backend with cuDNN.
 *
 * <p>Each nested {@link Optimizer} inspects a single {@link SameDiffOp} and rewrites it when it
 * applies. The rewrites are attempted only when the active ND4J backend reports itself as
 * {@code CUDA}.
 */
public class CuDNNFunctionOptimizations extends BaseOptimizerSet {

    /**
     * {@code true} when the backend name reported by the executioner's environment information
     * (property {@link Nd4jEnvironment#BACKEND_KEY}) equals {@code "CUDA"}, ignoring case.
     * A missing property yields {@code false}. Evaluated once, at class initialisation.
     */
    protected static final boolean isCudaBackend = "CUDA".equalsIgnoreCase(
            Nd4j.getExecutioner().getEnvironmentInformation().getProperty(Nd4jEnvironment.BACKEND_KEY));

    /**
     * Rewrites a {@link Conv2D} op to consume NHWC activations and permuted weights.
     * Presumably this favours cuDNN's NHWC convolution kernels.
     *
     * <p>The rewrite has two parts:
     * <ul>
     *   <li>If the op's input is NCHW, an NCHW-to-NHWC permute is inserted in front of it. The
     *       op's output is permuted back to NCHW for downstream consumers, and the op config is
     *       flagged as NHWC.</li>
     *   <li>If the weights have not been converted yet, they are permuted with axes
     *       {@code (3, 0, 1, 2)} (YXIO to OYXI, per the name suffix used).</li>
     * </ul>
     *
     * <p>Converted weights are recognised by the {@link #WEIGHTS_SUFFIX} name suffix. This lets
     * the rewrite report "no change" for an op it has already processed, instead of re-permuting
     * the weights on every invocation.
     */
    public static class CudnnConv2dNCHWtoNHWCConversion implements Optimizer {

        /** Name suffix for the NHWC (permuted) copy of the conv input. */
        private static final String INPUT_SUFFIX = "_cudnn_nchw_to_nhwc";
        /** Name suffix for the NCHW copy of the conv output handed to downstream consumers. */
        private static final String OUTPUT_SUFFIX = "_cudnn_nhwc_to_nchw";
        /** Name suffix for the permuted (YXIO to OYXI) copy of the conv weights. */
        private static final String WEIGHTS_SUFFIX = "_cudnn_yxio_to_oyxi";

        /**
         * Applies the NCHW-to-NHWC rewrite to {@code sameDiffOp} if it is a {@link Conv2D} that
         * still needs it.
         *
         * @param sameDiff           graph being optimized; modified in place
         * @param optimizationHelper unused by this optimizer
         * @param sameDiffOp         candidate op
         * @param arrayHolder        unused by this optimizer
         * @param arrayHolder2       unused by this optimizer
         * @return {@code true} if the graph was modified. Returns {@code false} when not on the
         *         CUDA backend, when the op is not a Conv2D, or when the op is already NHWC with
         *         converted weights.
         */
        @Override // org.nd4j.autodiff.samediff.optimize.Optimizer
        public boolean checkAndApply(SameDiff sameDiff, OptimizationHelper optimizationHelper, SameDiffOp sameDiffOp,
                                     ArrayHolder arrayHolder, ArrayHolder arrayHolder2) {
            // cuDNN-specific rewrite: leave graphs on other backends untouched.
            if (!isCudaBackend || !(sameDiffOp.getOp() instanceof Conv2D)) {
                return false;
            }
            Conv2D conv2D = (Conv2D) sameDiffOp.getOp();
            boolean inputIsNhwc = conv2D.getConfig().isNHWC();

            // Capture the weights name before the input/output rewiring below runs.
            List<String> opInputs = sameDiffOp.getInputsToOp();
            String weightsName = opInputs.get(1);
            // After a previous conversion, this op's weights input names the permuted copy
            // (assumes replaceOpInputsWith rewired this op's input list; confirm).
            boolean weightsConverted = weightsName.endsWith(WEIGHTS_SUFFIX);

            if (inputIsNhwc && weightsConverted) {
                return false; // Already rewritten: nothing to do.
            }
            if (!inputIsNhwc) {
                convertInputAndOutputToNhwc(sameDiff, sameDiffOp, conv2D, opInputs.get(0));
            }
            if (!weightsConverted) {
                permuteWeights(sameDiff, weightsName);
            }
            return true;
        }

        /**
         * Inserts an NCHW-to-NHWC permute before the conv input and an NHWC-to-NCHW permute
         * after its output, then marks the conv config as NHWC.
         */
        private static void convertInputAndOutputToNhwc(SameDiff sameDiff, SameDiffOp sameDiffOp,
                                                        Conv2D conv2D, String inputName) {
            SDVariable nchwInput = sameDiff.getVariable(inputName);
            String nhwcInputName = nchwInput.name() + INPUT_SUFFIX;
            // NOTE(review): replaceOpInputsWith appears to rewire every consumer of the input,
            // not only this conv2d. Confirm no other consumer still expects NCHW data.
            OptimizationUtils.replaceOpInputsWith(sameDiff, nchwInput.name(), nhwcInputName);
            nchwInput.permute(0, 2, 3, 1).rename(nhwcInputName); // NCHW -> NHWC

            SDVariable nhwcOutput = sameDiff.getVariable(sameDiffOp.getOutputsOfOp().get(0));
            String nhwcOutputName = nhwcOutput.name();
            SDVariable nchwOutput = nhwcOutput.permute(0, 3, 1, 2).rename(nhwcOutputName + OUTPUT_SUFFIX); // NHWC -> NCHW
            // NOTE(review): unlike the input side, the permute op above already exists when the
            // rewiring runs. Confirm replaceOpInputsWith does not redirect that op's own input.
            OptimizationUtils.replaceOpInputsWith(sameDiff, nhwcOutputName, nchwOutput.name());

            conv2D.getConfig().isNHWC(true);
        }

        /**
         * Replaces the conv weights with a copy permuted by axes (3, 0, 1, 2), i.e. YXIO to OYXI.
         */
        private static void permuteWeights(SameDiff sameDiff, String weightsName) {
            SDVariable weights = sameDiff.getVariable(weightsName);
            String permutedName = weights.name() + WEIGHTS_SUFFIX;
            OptimizationUtils.replaceOpInputsWith(sameDiff, weights.name(), permutedName);
            // Presumably [kH, kW, iC, oC] -> [oC, kH, kW, iC]; confirm the source layout is YXIO.
            weights.permute(3, 0, 1, 2).rename(permutedName);
            // NOTE(review): the Conv2D config is not told about the new weights layout here.
            // Confirm the op interprets the permuted weights correctly.
        }
    }
}
