package io.bioimage.modelrunner.example;

import icy.file.FileUtil;
import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.description.weights.ModelWeight;
import io.bioimage.modelrunner.bioimageio.description.weights.WeightFormat;
import io.bioimage.modelrunner.engine.installation.EngineInstall;
import io.bioimage.modelrunner.model.Model;
import io.bioimage.modelrunner.tensor.Tensor;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.LongStream;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;

/* loaded from: input_file:io/bioimage/modelrunner/example/ExampleLoadAndRunAllBmzModels.class */
/**
 * Example that installs the basic engines, lists every model of the BioImage.io repository,
 * downloads the ones whose weights can be run by the supported frameworks, runs each of them on
 * synthetic input tensors and finally prints how many models ran without issues together with the
 * names of the models that failed.
 */
public class ExampleLoadAndRunAllBmzModels {
    private static final String CWD = System.getProperty("user.dir");
    /** Directory where the engines are installed (engines/ under the working directory). */
    private static final String ENGINES_DIR = new File(CWD, "engines").getAbsolutePath();
    /** Directory where the models are downloaded (models/ under the working directory). */
    private static final String MODELS_DIR = new File(CWD, "models").getAbsolutePath();
    /** Weight framework identifiers that this example is able to run. */
    private static final List<String> SUPPORTED_FRAMEWORKS = new ArrayList<>();

    static {
        SUPPORTED_FRAMEWORKS.add(ModelWeight.getTensorflowID());
        SUPPORTED_FRAMEWORKS.add(ModelWeight.getOnnxID());
        SUPPORTED_FRAMEWORKS.add(ModelWeight.getTorchscriptID());
    }

    /** Runs the installer's basic engine installation into {@link #ENGINES_DIR}. */
    private static void installAllValidEngines() {
        EngineInstall.createInstaller(ENGINES_DIR).basicEngineInstallation();
    }

    /**
     * Installs the engines, then tries to download, load and run every model of the repository.
     *
     * @param strArr unused command line arguments
     * @throws InterruptedException declared for compatibility; an interruption during the run is
     *     recorded, the interrupt flag is restored and the loop stops early
     */
    public static void main(String[] strArr) throws InterruptedException {
        List<String> failedModels = new ArrayList<>();
        installAllValidEngines();
        BioimageioRepo repo = BioimageioRepo.connect();
        Map<String, ModelDescriptor> allModels = repo.listAllModels(false);
        int successCount = 0;
        for (ModelDescriptor descriptor : allModels.values()) {
            String modelName = descriptor.getName();
            try {
                checkModelCompatibleWithEngines(descriptor);
                loadAndRunModel(repo.downloadByName(modelName, MODELS_DIR), descriptor);
                successCount++;
            } catch (InterruptedException e) {
                // Honour the cancellation request: restore the flag and stop processing models,
                // otherwise every following blocking call would fail immediately anyway.
                Thread.currentThread().interrupt();
                System.out.println(modelName + ": Error downloading model." + e.toString());
                failedModels.add(modelName);
                break;
            } catch (IOException e) {
                System.out.println(modelName + ": Error downloading model." + e.toString());
                failedModels.add(modelName);
            } catch (IllegalArgumentException e) {
                // Report why the model was rejected instead of dropping the reason silently.
                System.out.println(e.getMessage());
                failedModels.add(modelName);
            } catch (Exception e) {
                System.out.println(modelName + ": Error loading/running model." + e.toString());
                failedModels.add(modelName);
            }
        }
        // A plain "/" is used: this is a ratio, not a file path.
        System.out.println("Models run without any issue: " + successCount + "/" + allModels.size());
        System.out.println(failedModels);
    }

    /**
     * Loads the model found in {@code str}, runs it on synthetic float inputs built from the
     * descriptor and checks that every output tensor was filled.
     *
     * <p>The model is closed exactly once (try-with-resources) and the input and output tensors
     * are always closed afterwards, also when loading or inference fails.
     *
     * @param str directory of the downloaded model
     * @param modelDescriptor descriptor of the model, used to build inputs and outputs
     * @throws Exception if the model cannot be loaded or run, or an output tensor is empty
     */
    public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void loadAndRunModel(String str, ModelDescriptor modelDescriptor) throws Exception {
        List<Tensor<FloatType>> inputs = new ArrayList<>();
        List<Tensor<R>> outputs = new ArrayList<>();
        try (Model model = Model.createBioimageioModel(str, ENGINES_DIR)) {
            model.loadModel();
            inputs = createInputs(modelDescriptor);
            outputs = createOutputs(modelDescriptor);
            model.runModel(inputs, outputs);
            for (Tensor<R> output : outputs) {
                if (output.isEmpty()) {
                    throw new IllegalStateException(modelDescriptor.getName() + ": Output tensor is empty");
                }
            }
        } finally {
            // Runs after the model has been closed, matching the original close order.
            inputs.forEach(tensor -> tensor.close());
            outputs.forEach(tensor -> tensor.close());
        }
    }

    /**
     * Builds one zero-initialised float input tensor per input spec of the descriptor. Every
     * dimension is the spec's minimum tile size plus one tile step for that axis.
     *
     * @param modelDescriptor descriptor providing the input tensor specs
     * @return the created input tensors, in spec order
     */
    private static List<Tensor<FloatType>> createInputs(ModelDescriptor modelDescriptor) {
        List<Tensor<FloatType>> inputs = new ArrayList<>();
        ArrayImgFactory<FloatType> imgFactory = new ArrayImgFactory<>(new FloatType());
        for (TensorSpec spec : modelDescriptor.getInputTensors()) {
            int[] minTileSize = spec.getMinTileSizeArr();
            int[] tileStep = spec.getTileStepArr();
            long[] dims = LongStream.range(0L, tileStep.length)
                    .map(j -> minTileSize[(int) j] + tileStep[(int) j])
                    .toArray();
            inputs.add(Tensor.build(spec.getName(), spec.getAxesOrder(), imgFactory.create(dims)));
        }
        return inputs;
    }

    /**
     * Builds one empty output tensor per output spec of the descriptor; the model run is expected
     * to fill them.
     *
     * @param modelDescriptor descriptor providing the output tensor specs
     * @return the created empty output tensors, in spec order
     */
    private static <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputs(ModelDescriptor modelDescriptor) {
        List<Tensor<T>> outputs = new ArrayList<>();
        for (TensorSpec spec : modelDescriptor.getOutputTensors()) {
            outputs.add(Tensor.buildEmptyTensor(spec.getName(), spec.getAxesOrder()));
        }
        return outputs;
    }

    /**
     * Checks that at least one of the model's weight formats uses a framework listed in
     * {@link #SUPPORTED_FRAMEWORKS}.
     *
     * @param modelDescriptor descriptor of the model to check
     * @throws IllegalArgumentException if none of the model's weights is supported
     */
    private static void checkModelCompatibleWithEngines(ModelDescriptor modelDescriptor) {
        for (WeightFormat weights : modelDescriptor.getWeights().gettAllSupportedWeightObjects()) {
            if (SUPPORTED_FRAMEWORKS.contains(weights.getFramework())) {
                return;
            }
        }
        throw new IllegalArgumentException(modelDescriptor.getName()
                + ": weights not supported (supported frameworks: " + SUPPORTED_FRAMEWORKS + ")");
    }
}
