/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.example;

import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.description.weights.ModelWeight;
import io.bioimage.modelrunner.bioimageio.description.weights.WeightFormat;
import io.bioimage.modelrunner.engine.installation.EngineInstall;
import io.bioimage.modelrunner.model.java.BioimageIoModelJava;
import io.bioimage.modelrunner.tensor.Tensor;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.stream.LongStream;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;

/**
 * Example that tries to load and run every model listed by {@link BioimageioRepo}.
 *
 * <p>Engines are installed under {@code <user.dir>/engines} and models are downloaded under
 * {@code <user.dir>/models}. Only models that have weights for one of the frameworks in
 * {@link #SUPPORTED_FRAMEWORKS} are attempted. Each attempted model is run once on freshly
 * allocated input tensors sized from the descriptor's tiling information. At the end, the
 * number of successful runs and the names of the failing models are printed.
 */
public class ExampleLoadAndRunAllBmzModels {
    /** Working directory of the JVM; engines and models are stored beneath it. */
    private static final String CWD = System.getProperty("user.dir");
    /** Directory the engines are installed into. */
    private static final String ENGINES_DIR = new File(CWD, "engines").getAbsolutePath();
    /** Directory the models are downloaded into. */
    private static final String MODELS_DIR = new File(CWD, "models").getAbsolutePath();
    /** Framework identifiers whose weights this example attempts to run (filled below). */
    private static final List<String> SUPPORTED_FRAMEWORKS = new ArrayList<String>();

    static {
        SUPPORTED_FRAMEWORKS.add(ModelWeight.getTensorflowID());
        SUPPORTED_FRAMEWORKS.add(ModelWeight.getOnnxID());
        SUPPORTED_FRAMEWORKS.add(ModelWeight.getTorchscriptID());
    }

    /**
     * Installs the basic set of engines into {@link #ENGINES_DIR}.
     *
     * @throws InterruptedException propagated from the engine installer
     */
    private static void installAllValidEngines() throws InterruptedException {
        EngineInstall installer = EngineInstall.createInstaller(ENGINES_DIR);
        installer.basicEngineInstallation();
    }

    /**
     * Installs the engines, then downloads, loads and runs every compatible model in the
     * repository. Each failure is reported, and processing continues with the next model.
     *
     * @param args unused
     * @throws InterruptedException if the engine installation is interrupted
     */
    public static void main(String[] args) throws InterruptedException {
        List<String> modelsWithErrors = new ArrayList<String>();
        installAllValidEngines();
        BioimageioRepo br = BioimageioRepo.connect();
        // NOTE(review): meaning of the boolean flag comes from BioimageioRepo; confirm there.
        Map<String, ModelDescriptor> bmzModelList = br.listAllModels(false);
        int successModelCount = 0;
        for (Map.Entry<String, ModelDescriptor> modelEntry : bmzModelList.entrySet()) {
            ModelDescriptor descriptor = modelEntry.getValue();
            String modelName = descriptor.getName();
            try {
                checkModelCompatibleWithEngines(descriptor);
                String modelFolder = br.downloadByName(modelName, MODELS_DIR);
                loadAndRunModel(modelFolder, descriptor);
                successModelCount++;
            } catch (IllegalArgumentException e) {
                // Includes models rejected by checkModelCompatibleWithEngines; report instead of
                // silently dropping them.
                System.out.println(modelName + ": Model not supported or invalid. " + e);
                modelsWithErrors.add(modelName);
            } catch (InterruptedException e) {
                System.out.println(modelName + ": Interrupted. " + e);
                modelsWithErrors.add(modelName);
                // Restore the interrupt flag for the caller and stop processing further models.
                Thread.currentThread().interrupt();
                break;
            } catch (IOException e) {
                System.out.println(modelName + ": Error downloading model. " + e);
                modelsWithErrors.add(modelName);
            } catch (Exception e) {
                System.out.println(modelName + ": Error loading/running model. " + e);
                modelsWithErrors.add(modelName);
            }
        }
        System.out.println("Models run without any issue: " + successModelCount + "/" + bmzModelList.size());
        System.out.println(modelsWithErrors);
    }

    /**
     * Loads the model in {@code modelFolder}, runs it once and checks that every output tensor
     * was filled.
     *
     * <p>The input and output tensors are built from {@code descriptor}. They are always closed
     * before the model itself is closed, including when the run fails.
     *
     * @param modelFolder folder containing the downloaded model
     * @param descriptor  descriptor of the same model, used to build the input/output tensors
     * @param <T>         not used by the body; kept so existing callers still compile
     * @param <R>         pixel type of the output tensors
     * @throws Exception if the model cannot be loaded or run, or if an output tensor is empty
     */
    public static <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void loadAndRunModel(String modelFolder, ModelDescriptor descriptor) throws Exception {
        // try-with-resources closes the model exactly once; no explicit close() is needed.
        try (BioimageIoModelJava model = BioimageIoModelJava.createBioimageioModel(modelFolder, ENGINES_DIR)) {
            model.loadModel();
            List<Tensor<FloatType>> inputs = createInputs(descriptor);
            List<Tensor<R>> outputs = createOutputs(descriptor);
            try {
                model.run(inputs, outputs);
                for (Tensor<R> output : outputs) {
                    if (output.isEmpty()) {
                        throw new Exception(descriptor.getName() + ": Output tensor is empty");
                    }
                }
            } finally {
                // Release tensors on every path, not only on success.
                closeTensors(inputs);
                closeTensors(outputs);
            }
        }
    }

    /** Closes every tensor in {@code tensors}. */
    private static void closeTensors(List<? extends Tensor<?>> tensors) {
        for (Tensor<?> tensor : tensors) {
            tensor.close();
        }
    }

    /**
     * Builds one float input tensor per input spec of {@code descriptor}.
     *
     * <p>Along each axis the size is the spec's minimum tile size plus its tile step.
     *
     * @param descriptor model descriptor providing the input tensor specs
     * @return the newly created input tensors, in spec order
     */
    private static List<Tensor<FloatType>> createInputs(ModelDescriptor descriptor) {
        List<Tensor<FloatType>> inputs = new ArrayList<Tensor<FloatType>>();
        ArrayImgFactory<FloatType> imgFactory = new ArrayImgFactory<FloatType>(new FloatType());
        for (TensorSpec spec : descriptor.getInputTensors()) {
            int[] min = spec.getMinTileSizeArr();
            int[] step = spec.getTileStepArr();
            // Sum in long to rule out int overflow for very large tile sizes.
            long[] imSize = LongStream.range(0L, step.length)
                    .map(i -> (long) min[(int) i] + step[(int) i])
                    .toArray();
            Tensor<FloatType> tensor = Tensor.build(spec.getName(), spec.getAxesOrder(), imgFactory.create(imSize));
            inputs.add(tensor);
        }
        return inputs;
    }

    /**
     * Builds one empty output tensor per output spec of {@code descriptor}. The model run is
     * expected to fill these tensors.
     *
     * @param descriptor model descriptor providing the output tensor specs
     * @param <T>        pixel type of the output tensors
     * @return the newly created empty output tensors, in spec order
     */
    private static <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputs(ModelDescriptor descriptor) {
        List<Tensor<T>> outputs = new ArrayList<Tensor<T>>();
        for (TensorSpec spec : descriptor.getOutputTensors()) {
            Tensor<T> tensor = Tensor.buildEmptyTensor(spec.getName(), spec.getAxesOrder());
            outputs.add(tensor);
        }
        return outputs;
    }

    /**
     * Checks that the model has weights for at least one framework in
     * {@link #SUPPORTED_FRAMEWORKS}.
     *
     * @param descriptor model descriptor to check
     * @throws IllegalArgumentException if none of the model's weights use a supported framework
     */
    private static void checkModelCompatibleWithEngines(ModelDescriptor descriptor) {
        List<WeightFormat> weights = descriptor.getWeights().gettAllSupportedWeightObjects();
        for (WeightFormat weight : weights) {
            if (SUPPORTED_FRAMEWORKS.contains(weight.getFramework())) {
                return;
            }
        }
        throw new IllegalArgumentException(descriptor.getName() + ": weights not supported");
    }
}

