package io.bioimage.modelrunner.example;

import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.engine.EngineInfo;
import io.bioimage.modelrunner.engine.installation.EngineInstall;
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.model.java.DLModelJava;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.versionmanagement.AvailableEngines;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;

/* loaded from: input_file:io/bioimage/modelrunner/example/ExampleLoadAndRunModel.class */
/**
 * End-to-end example: installs a TorchScript engine, downloads a Bioimage.io model, loads it,
 * runs it once on a freshly allocated input tensor and prints the mean of the output tensor
 * before and after inference.
 *
 * <p>Engines are installed under {@code <user.dir>/engines} and models are downloaded under
 * {@code <user.dir>/models}.
 */
public class ExampleLoadAndRunModel {
    private static final String CWD = System.getProperty("user.dir");
    private static final String ENGINES_DIR = new File(CWD, "engines").getAbsolutePath();
    private static final String MODELS_DIR = new File(CWD, "models").getAbsolutePath();

    /** Deep-learning framework used by the example. */
    private static final String ENGINE = "torchscript";
    /** Framework version used by the example. */
    private static final String ENGINE_VERSION = "1.13.1";
    /** Name of the Bioimage.io model that is downloaded. */
    private static final String MODEL_NAME = "EnhancerMitochondriaEM2D";
    /** Weights file expected inside the downloaded model directory. */
    private static final String WEIGHTS_FILE = "weights-torchscript.pt";

    /**
     * Runs the example.
     *
     * <p>The type parameters are kept only for signature compatibility; the example itself
     * works with {@link FloatType} tensors.
     *
     * @param strArr command-line arguments (unused)
     * @throws LoadEngineException if the engine cannot be loaded
     * @throws Exception if installation, download, loading or inference fails
     */
    public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>>
            void main(String[] strArr) throws LoadEngineException, Exception {
        String enginesDir = ENGINES_DIR;
        downloadCPUEngine(ENGINE, ENGINE_VERSION, enginesDir);
        String modelDir = downloadBMZModel(MODEL_NAME, MODELS_DIR);

        // Reuse the GPU flag reported by the engine that was just installed.
        boolean gpu = firstOrFail(
                InstalledEngines.checkEngineWithArgsInstalledForOS(
                        ENGINE, ENGINE_VERSION, true, null, enginesDir),
                "installed " + ENGINE + " " + ENGINE_VERSION + " engine in " + enginesDir)
                .getGPU();
        EngineInfo engineInfo = createEngineInfo(ENGINE, ENGINE_VERSION, enginesDir, true, gpu);
        String weightsPath = new File(modelDir, WEIGHTS_FILE).getAbsolutePath();
        DLModelJava model = loadModel(modelDir, weightsPath, engineInfo);

        List<Tensor<FloatType>> inputs = new ArrayList<>();
        List<Tensor<FloatType>> outputs = new ArrayList<>();
        try {
            ArrayImgFactory<FloatType> imgFactory = new ArrayImgFactory<>(new FloatType());
            // Tensor names and axes ("input0"/"output0", "bcyx") must match the model's spec.
            inputs.add(Tensor.build("input0", "bcyx", imgFactory.create(new long[]{1, 1, 512, 512})));
            outputs.add(Tensor.build("output0", "bcyx", imgFactory.create(new long[]{1, 2, 512, 512})));

            System.out.println(Util.average(Util.asDoubleArray(outputs.get(0).getData())));
            model.run(inputs, outputs);
            System.out.println(Util.average(Util.asDoubleArray(outputs.get(0).getData())));
        } finally {
            // Release native resources even if inference fails.
            try {
                model.close();
            } finally {
                inputs.forEach(Tensor::close);
                outputs.forEach(Tensor::close);
            }
        }
        System.out.println("Success!!");
    }

    /**
     * Installs into {@code engineDir} the first engine that
     * {@link AvailableEngines#getEnginesForOsByParams} reports for the given framework and
     * version.
     *
     * <p>The call passes {@code true, null} as the remaining arguments. These presumably request
     * a CPU-capable engine with no GPU constraint; confirm against {@code AvailableEngines}.
     *
     * @param framework deep-learning framework name, e.g. {@code "torchscript"}
     * @param version framework version, e.g. {@code "1.13.1"}
     * @param engineDir directory where the engine is installed
     * @throws IllegalStateException if no matching engine is available for this OS
     * @throws IOException if the installation fails
     * @throws InterruptedException if the installation is interrupted
     * @throws ExecutionException if the installation task fails
     */
    public static void downloadCPUEngine(String framework, String version, String engineDir)
            throws IOException, InterruptedException, ExecutionException {
        EngineInstall.installEngineInDir(
                firstOrFail(
                        AvailableEngines.getEnginesForOsByParams(framework, version, true, null),
                        "available " + framework + " " + version + " engine for this OS"),
                engineDir);
    }

    /**
     * Downloads a model from the Bioimage.io repository by name.
     *
     * @param modelName name of the model in the repository
     * @param modelsDir directory the model is downloaded into
     * @return the value returned by {@code BioimageioRepo#downloadByName}, which is used as the
     *     model directory by {@link #main}
     * @throws IOException if the download fails
     * @throws InterruptedException if the download is interrupted
     */
    public static String downloadBMZModel(String modelName, String modelsDir)
            throws IOException, InterruptedException {
        return BioimageioRepo.connect().downloadByName(modelName, modelsDir);
    }

    /**
     * Builds the {@link EngineInfo} describing which engine to load.
     *
     * @param framework deep-learning framework name
     * @param version framework version
     * @param engineDir directory where engines are installed
     * @param cpu flag forwarded as the third argument of {@code EngineInfo.defineDLEngine}
     *     (presumably CPU support; confirm)
     * @param gpu flag forwarded as the fourth argument of {@code EngineInfo.defineDLEngine}
     *     (presumably GPU support; confirm)
     * @return the engine description
     */
    public static EngineInfo createEngineInfo(
            String framework, String version, String engineDir, boolean cpu, boolean gpu) {
        return EngineInfo.defineDLEngine(framework, version, cpu, gpu, engineDir);
    }

    /**
     * Creates a {@link DLModelJava} and loads it before returning it.
     *
     * @param modelDir directory of the downloaded model
     * @param weightsPath absolute path of the weights file
     * @param engineInfo engine used to run the model
     * @return the loaded model; the caller is responsible for closing it
     * @throws LoadEngineException if the engine cannot be loaded
     * @throws Exception if model creation or loading fails
     */
    public static DLModelJava loadModel(String modelDir, String weightsPath, EngineInfo engineInfo)
            throws LoadEngineException, Exception {
        DLModelJava model = DLModelJava.createModel(modelDir, weightsPath, engineInfo);
        model.loadModel();
        return model;
    }

    /**
     * Returns the first element of {@code list}. If the list is {@code null} or empty, it
     * throws a descriptive {@link IllegalStateException} instead of an opaque
     * {@link IndexOutOfBoundsException}.
     *
     * @param list list to read from; may be {@code null}
     * @param what human-readable description of the expected element, used in the error message
     * @return the first element
     * @throws IllegalStateException if the list is {@code null} or empty
     */
    private static <E> E firstOrFail(List<E> list, String what) {
        if (list == null || list.isEmpty()) {
            throw new IllegalStateException("No " + what + " found");
        }
        return list.get(0);
    }
}
