/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.example;

import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.engine.EngineInfo;
import io.bioimage.modelrunner.engine.installation.EngineInstall;
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.model.java.DLModelJava;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.versionmanagement.AvailableEngines;
import io.bioimage.modelrunner.versionmanagement.DeepLearningVersion;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;

/**
 * Example showing how to load and run a Tensorflow 2 model and a Tensorflow 1 model
 * one after the other.
 *
 * <p>For each model the example installs a CPU engine into {@code <user.dir>/engines},
 * downloads a Bioimage.io model by name into {@code <user.dir>/models}, loads the model,
 * runs it on a newly created {@link FloatType} image and prints the average of the first
 * output tensor before and after the run.
 */
public class ExampleLoadTensorflow1Tensorflow2 {
    /** Working directory of the JVM, taken from the {@code user.dir} system property. */
    private static final String CWD = System.getProperty("user.dir");
    /** Absolute path of the folder where the engines are installed. */
    private static final String ENGINES_DIR = new File(CWD, "engines").getAbsolutePath();
    /** Absolute path of the folder where the models are downloaded. */
    private static final String MODELS_DIR = new File(CWD, "models").getAbsolutePath();
    /** Framework identifier shared by the Tensorflow 1 and Tensorflow 2 examples. */
    private static final String TF_FRAMEWORK = "tensorflow_saved_model_bundle";

    /**
     * Runs the Tensorflow 2 example and, when the platform is not ARM64, the Tensorflow 1
     * example as well. Nothing is run when the JVM is executing under Rosetta.
     *
     * @param args command line arguments, not used
     * @throws LoadEngineException if an engine cannot be loaded
     * @throws Exception if anything else fails while installing, downloading, loading or running
     */
    public static void main(String[] args) throws LoadEngineException, Exception {
        if (PlatformDetection.isUsingRosseta()) {
            System.out.println("Tensorflow 1 cannot run on ARM64 chips (Apple Silicon) and Tensorflow 2 cannot run on ARM64 chips using Rosetta. In order to be able to run Tensorflow 2, please update the Java version to one compatible with ARM64 chips that does not require rosetta. Tensorflow 1 is not available for ARM64 based computers.");
            return;
        }
        // Rosetta has already been ruled out above, so only the architecture matters here.
        // The literal-first comparison also tolerates a null architecture string.
        boolean arm64 = "arm64".equals(PlatformDetection.getArch());
        if (arm64) {
            System.out.println("Tensorflow 1 cannot run on ARM64 chips (Apple Silicon). Only Tensorflow 2.7.0 is currently compatible with ARM64 chips, the execution of a Tensorflow 1 model wil be skipped.");
        }
        loadAndRunTf2();
        if (!arm64) {
            loadAndRunTf1();
        }
        System.out.println("Great success!");
    }

    /**
     * Installs the Tensorflow 2.7.0 CPU engine, downloads the example model and runs it once
     * on a 1x512x512x1 image, writing into a blank 1x512x512x3 output tensor.
     *
     * @param <T> unused, kept so the method signature stays unchanged
     * @param <R> unused, kept so the method signature stays unchanged
     * @throws LoadEngineException if the engine cannot be loaded
     * @throws Exception if installing, downloading, loading or running fails
     */
    public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void loadAndRunTf2() throws LoadEngineException, Exception {
        String engineVersion = "2.7.0";
        downloadCPUEngine(TF_FRAMEWORK, engineVersion, ENGINES_DIR);
        String bmzModelName = "B. Sutilist bacteria segmentation - Widefield microscopy - 2D UNet";
        String modelFolder = downloadBMZModel(bmzModelName, MODELS_DIR);
        EngineInfo engineInfo = resolveInstalledCpuEngine(engineVersion);
        // try-with-resources closes the model, so no explicit model.close() is needed.
        try (DLModelJava model = loadModel(modelFolder, null, engineInfo)) {
            ArrayImgFactory<FloatType> imgFactory = new ArrayImgFactory<>(new FloatType());
            Img<FloatType> img1 = imgFactory.create(new long[]{1L, 512L, 512L, 1L});
            Tensor<FloatType> inpTensor = Tensor.build("input_1", "bxyc", img1);
            List<Tensor<FloatType>> inputs = new ArrayList<>();
            inputs.add(inpTensor);
            Tensor<FloatType> outTensor0 = Tensor.buildBlankTensor("conv2d_19", "bxyc", new long[]{1L, 512L, 512L, 3L}, new FloatType());
            List<Tensor<FloatType>> outputs = new ArrayList<>();
            outputs.add(outTensor0);
            runAndRelease(model, inputs, outputs);
            System.out.println("Success running Tensorflow 2!!");
        }
    }

    /**
     * Installs the Tensorflow 1.15.0 CPU engine, downloads the example model and runs it once
     * on a 1x512x512x3 image, writing into a 1x512x512x33 output image.
     *
     * @param <T> unused, kept so the method signature stays unchanged
     * @param <R> unused, kept so the method signature stays unchanged
     * @throws LoadEngineException if the engine cannot be loaded
     * @throws Exception if installing, downloading, loading or running fails
     */
    public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void loadAndRunTf1() throws LoadEngineException, Exception {
        String engineVersion = "1.15.0";
        downloadCPUEngine(TF_FRAMEWORK, engineVersion, ENGINES_DIR);
        String bmzModelName = "StarDist H&E Nuclei Segmentation";
        String modelFolder = downloadBMZModel(bmzModelName, MODELS_DIR);
        EngineInfo engineInfo = resolveInstalledCpuEngine(engineVersion);
        // try-with-resources closes the model, so no explicit model.close() is needed.
        try (DLModelJava model = loadModel(modelFolder, null, engineInfo)) {
            ArrayImgFactory<FloatType> imgFactory = new ArrayImgFactory<>(new FloatType());
            Img<FloatType> img1 = imgFactory.create(new long[]{1L, 512L, 512L, 3L});
            Tensor<FloatType> inpTensor = Tensor.build("input", "byxc", img1);
            List<Tensor<FloatType>> inputs = new ArrayList<>();
            inputs.add(inpTensor);
            Img<FloatType> img2 = imgFactory.create(new long[]{1L, 512L, 512L, 33L});
            Tensor<FloatType> outTensor = Tensor.build("output", "byxc", img2);
            List<Tensor<FloatType>> outputs = new ArrayList<>();
            outputs.add(outTensor);
            runAndRelease(model, inputs, outputs);
            System.out.println("Success running Tensorflow 1!!");
        }
    }

    /**
     * Creates the engine description used to load a model.
     *
     * @param engine framework identifier
     * @param engineVersion version of the framework
     * @param enginesDir folder where the engines are installed
     * @param cpu whether the engine has to support CPU
     * @param gpu whether the engine has to support GPU
     * @return whatever {@link EngineInfo#defineCompatibleDLEngine} returns for these arguments
     */
    public static EngineInfo createEngineInfo(String engine, String engineVersion, String enginesDir, boolean cpu, boolean gpu) {
        return EngineInfo.defineCompatibleDLEngine(engine, engineVersion, cpu, gpu, enginesDir);
    }

    /**
     * Creates a model from the given folder and loads it.
     *
     * @param modelFolder folder that contains the model
     * @param modelSource model source passed on to {@link DLModelJava#createModel}; may be null
     * @param engineInfo description of the engine used to load the model
     * @return the loaded model; the caller is responsible for closing it
     * @throws LoadEngineException if the engine cannot be loaded
     * @throws Exception if creating or loading the model fails
     */
    public static DLModelJava loadModel(String modelFolder, String modelSource, EngineInfo engineInfo) throws LoadEngineException, Exception {
        DLModelJava model = DLModelJava.createModel(modelFolder, modelSource, engineInfo);
        model.loadModel();
        return model;
    }

    /**
     * Installs the first available CPU engine that matches the framework and version.
     *
     * @param framework framework identifier
     * @param engineVersion version of the framework
     * @param enginesDir folder where the engine is installed
     * @throws IOException declared by the original signature
     * @throws InterruptedException declared by the original signature
     * @throws ExecutionException declared by the original signature
     * @throws IllegalStateException if no engine matches the requested framework and version
     */
    public static void downloadCPUEngine(String framework, String engineVersion, String enginesDir) throws IOException, InterruptedException, ExecutionException {
        List<DeepLearningVersion> possibleEngines = AvailableEngines.getEnginesForOsByParams(framework, engineVersion, true, null);
        if (possibleEngines == null || possibleEngines.isEmpty()) {
            throw new IllegalStateException("No CPU engine available for this OS for framework '" + framework + "' version '" + engineVersion + "'");
        }
        EngineInstall.installEngineInDir(possibleEngines.get(0), enginesDir);
    }

    /**
     * Downloads a model from the Bioimage.io repository by its name.
     *
     * @param bmzModelName name of the model in the repository
     * @param modelsDir folder where the model is downloaded
     * @return the value returned by {@code BioimageioRepo.downloadByName}, used by this
     *         example as the model folder
     * @throws IOException declared by the original signature
     * @throws InterruptedException declared by the original signature
     */
    public static String downloadBMZModel(String bmzModelName, String modelsDir) throws IOException, InterruptedException {
        BioimageioRepo br = BioimageioRepo.connect();
        return br.downloadByName(bmzModelName, modelsDir);
    }

    /**
     * Looks up the installed CPU engine for the given version in {@link #ENGINES_DIR} and
     * builds the matching {@link EngineInfo}, reusing the GPU flag of the first match.
     */
    private static EngineInfo resolveInstalledCpuEngine(String engineVersion) throws Exception {
        boolean cpu = true;
        List<DeepLearningVersion> installedList = InstalledEngines.checkEngineWithArgsInstalledForOS(TF_FRAMEWORK, engineVersion, cpu, null, ENGINES_DIR);
        if (installedList == null || installedList.isEmpty()) {
            throw new IllegalStateException("No installed CPU engine found in '" + ENGINES_DIR + "' for framework '" + TF_FRAMEWORK + "' version '" + engineVersion + "'");
        }
        boolean gpu = installedList.get(0).getGPU();
        return createEngineInfo(TF_FRAMEWORK, engineVersion, ENGINES_DIR, cpu, gpu);
    }

    /**
     * Prints the output average, runs the model, prints the output average again and
     * finally closes every tensor, also when the run fails.
     */
    private static void runAndRelease(DLModelJava model, List<Tensor<FloatType>> inputs, List<Tensor<FloatType>> outputs) throws Exception {
        try {
            printFirstOutputAverage(outputs);
            model.run(inputs, outputs);
            printFirstOutputAverage(outputs);
        } finally {
            for (Tensor<FloatType> tensor : inputs) {
                tensor.close();
            }
            for (Tensor<FloatType> tensor : outputs) {
                tensor.close();
            }
        }
    }

    /** Prints the average value of the data held by the first tensor of the list. */
    private static void printFirstOutputAverage(List<Tensor<FloatType>> outputs) {
        System.out.println(Util.average(Util.asDoubleArray(outputs.get(0).getData())));
    }
}

