package io.bioimage.modelrunner.model;

import icy.preferences.RepositoryPreferences;
import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException;
import io.bioimage.modelrunner.engine.engines.TensorflowEngine;
import io.bioimage.modelrunner.engine.installation.EngineInstall;
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.runmode.RunMode;
import io.bioimage.modelrunner.runmode.ops.GenericOp;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.Utils;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;
import net.imglib2.view.Views;
import org.apache.commons.compress.archivers.ArchiveException;

/* loaded from: input_file:io/bioimage/modelrunner/model/Stardist2D.class */
/**
 * Runs a StarDist 2D model distributed through the Bioimage.io model zoo.
 *
 * <p>The network itself is executed through JDLL ({@link Model}) with a Tensorflow 1.15.0 engine; the
 * StarDist post-processing (star-convex polygon NMS) is delegated to a Python script that is run in a
 * micromamba environment called {@code "stardist"} (see {@link #installRequirements()}).
 */
public class Stardist2D {
    /** Descriptor of the Bioimage.io model; {@code null} only for the unused no-arg constructor. */
    private final ModelDescriptor descriptor;
    /** Number of input channels the network expects ({@code stardist.config.n_channel_in}). */
    private final int channels;
    /** NMS threshold read from {@code stardist.thresholds.nms}. */
    private final float nms_threshold;
    /** Probability threshold read from {@code stardist.thresholds.prob}. */
    private final float prob_threshold;
    /** Name of the micromamba environment that hosts the Python post-processing. */
    private static final String STARDIST_ENV_NAME = "stardist";
    /** Tensorflow engine version required to run the StarDist networks. */
    private static final String TF_VERSION = "1.15.0";
    private static final List<String> STARDIST_DEPS = Arrays.asList("python=3.10", "stardist", "numpy", "appose");
    // NOTE(review): this value was inlined from an icy constant by the decompiler; confirm it is the
    // intended conda channel name (conda's built-in channel is usually spelled "defaults").
    private static final List<String> STARDIST_CHANNELS = Arrays.asList("conda-forge", RepositoryPreferences.DEFAULT_REPOSITERY_NAME);
    private static final String STARDIST2D_PATH_IN_RESOURCES = "ops/stardist_postprocessing/";
    private static final String STARDIST2D_SCRIPT_NAME = "stardist_postprocessing.py";
    private static final String STARDIST2D_METHOD_NAME = "stardist_postprocessing";

    /** Canonical Bioimage.io name of the H&E StarDist model. */
    private static final String HE_MODEL_NAME = "StarDist H&E Nuclei Segmentation";
    /** Canonical Bioimage.io name of the fluorescence StarDist model. */
    private static final String FLUO_MODEL_NAME = "StarDist Fluorescence Nuclei Segmentation";

    private Stardist2D() {
        this.descriptor = null;
        this.channels = 1;
        this.nms_threshold = 0.0f;
        this.prob_threshold = 0.0f;
    }

    /**
     * Builds the wrapper from a model descriptor whose config contains a {@code stardist} section.
     *
     * @param modelDescriptor descriptor of a StarDist Bioimage.io model
     * @throws IllegalArgumentException if the {@code stardist} config section or one of its required
     *     entries ({@code config.n_channel_in}, {@code thresholds.nms}, {@code thresholds.prob}) is missing
     */
    private Stardist2D(ModelDescriptor modelDescriptor) {
        this.descriptor = modelDescriptor;
        Map<?, ?> stardist = requireMap(modelDescriptor.getConfig().getSpecMap().get("stardist"), "stardist");
        Map<?, ?> config = requireMap(stardist.get("config"), "stardist.config");
        Map<?, ?> thresholds = requireMap(stardist.get("thresholds"), "stardist.thresholds");
        // Accept any Number: YAML parsers may yield Integer/Long/Double depending on how the value is written.
        this.channels = requireNumber(config.get("n_channel_in"), "stardist.config.n_channel_in").intValue();
        this.nms_threshold = requireNumber(thresholds.get("nms"), "stardist.thresholds.nms").floatValue();
        this.prob_threshold = requireNumber(thresholds.get("prob"), "stardist.thresholds.prob").floatValue();
    }

    /** Returns {@code value} as a map, or fails with a message naming the missing rdf.yaml entry. */
    private static Map<?, ?> requireMap(Object value, String key) {
        if (!(value instanceof Map)) {
            throw new IllegalArgumentException("The model rdf.yaml does not contain the '" + key
                    + "' section required by Stardist2D.");
        }
        return (Map<?, ?>) value;
    }

    /** Returns {@code value} as a number, or fails with a message naming the missing rdf.yaml entry. */
    private static Number requireNumber(Object value, String key) {
        if (!(value instanceof Number)) {
            throw new IllegalArgumentException("The model rdf.yaml does not contain a numeric '" + key
                    + "' entry required by Stardist2D.");
        }
        return (Number) value;
    }

    /**
     * Loads a StarDist model from a local Bioimage.io model folder.
     *
     * @param modelPath folder containing the model's {@link Constants#RDF_FNAME}
     * @return the wrapper around that model
     */
    public static Stardist2D fromBioimageioModel(String modelPath) throws ModelSpecsException, FileNotFoundException, IOException {
        return new Stardist2D(ModelDescriptorFactory.readFromLocalFile(modelPath + File.separator + Constants.RDF_FNAME));
    }

    /**
     * Same as {@link #fromPretained(String, String, boolean)} downloading into {@code ./models}.
     */
    public static Stardist2D fromPretained(String pretrainedModel, boolean forceDownload) throws IOException, InterruptedException, ModelSpecsException {
        return fromPretained(pretrainedModel, new File("models").getAbsolutePath(), forceDownload);
    }

    /**
     * Gets one of the pretrained StarDist 2D models.
     *
     * <p>Unless {@code forceDownload} is set, a model with the same name already present in the local
     * repository is reused; otherwise the model is downloaded from Bioimage.io into {@code installDir}.
     *
     * @param pretrainedModel {@code "2D_versatile_he"}, {@code "2D_versatile_fluo"} or the full
     *     Bioimage.io names of those models
     * @param installDir folder where the model is downloaded if needed
     * @param forceDownload whether to skip the local lookup and always download
     * @return the wrapper around the requested model
     * @throws IllegalArgumentException if {@code pretrainedModel} is not a known StarDist 2D model
     */
    public static Stardist2D fromPretained(String pretrainedModel, String installDir, boolean forceDownload) throws IOException, InterruptedException, ModelSpecsException {
        String modelName = canonicalModelName(pretrainedModel);
        if (!forceDownload) {
            ModelDescriptor local = ModelDescriptorFactory.getModelsAtLocalRepo().stream()
                    .filter(candidate -> candidate.getName().equals(modelName))
                    .findFirst()
                    .orElse(null);
            if (local != null) {
                return new Stardist2D(local);
            }
        }
        return fromBioimageioModel(BioimageioRepo.connect().downloadByName(modelName, installDir));
    }

    /** Maps a short or full pretrained model name to its Bioimage.io name. */
    private static String canonicalModelName(String pretrainedModel) {
        if (pretrainedModel.equals(HE_MODEL_NAME) || pretrainedModel.equals("2D_versatile_he")) {
            return HE_MODEL_NAME;
        }
        if (pretrainedModel.equals(FLUO_MODEL_NAME) || pretrainedModel.equals("2D_versatile_fluo")) {
            return FLUO_MODEL_NAME;
        }
        throw new IllegalArgumentException("There is no Stardist2D model called: " + pretrainedModel);
    }

    /**
     * Validates that the image is XY (single-channel models only) or XYC with the expected channel count.
     * The rank is checked first so that no axis is indexed on images of the wrong dimensionality.
     */
    private <T extends RealType<T> & NativeType<T>> void checkInput(RandomAccessibleInterval<T> image) {
        long[] dims = image.dimensionsAsLongArray();
        if (dims.length < 2 || dims.length > 3) {
            throw new IllegalArgumentException("Stardist2D model requires an image with dimensions XYC.");
        }
        if (dims.length == 2 && this.channels != 1) {
            throw new IllegalArgumentException("Stardist2D needs an image with three dimensions: XYC");
        }
        if (dims.length == 3 && dims[2] != this.channels) {
            throw new IllegalArgumentException("This Stardist2D model requires " + this.channels + " channels.");
        }
    }

    /**
     * Runs the network and the StarDist post-processing on an XY or XYC image.
     *
     * <p>A fresh {@link Model} is loaded for each call and closed once the result has been produced.
     *
     * @param image XY (single-channel models) or XYC image
     * @return the post-processed output, transposed with {@link Utils#transpose}
     */
    public <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> predict(RandomAccessibleInterval<T> image) throws ModelSpecsException, LoadModelException, LoadEngineException, IOException, RunModelException, InterruptedException {
        checkInput(image);
        RandomAccessibleInterval<T> xyc = image;
        if (xyc.dimensionsAsLongArray().length == 2) {
            // XY -> XYC with a singleton channel axis.
            xyc = Views.addDimension(xyc, 0L, 0L);
        }
        // XYC -> CYX -> CYXB (singleton batch) -> BYXC, matching the "byxc" tensor axes.
        RandomAccessibleInterval<T> cyx = Views.permute(xyc, 0, 2);
        RandomAccessibleInterval<T> cyxb = Views.addDimension(cyx, 0L, 0L);
        RandomAccessibleInterval<T> byxc = Views.permute(cyxb, 0, 3);

        List<Tensor<T>> inputs = new ArrayList<>();
        inputs.add(Tensor.build("input", "byxc", byxc));
        List<Tensor<T>> outputs = new ArrayList<>();
        outputs.add(Tensor.buildEmptyTensor("output", "byxc"));

        Model model = Model.createBioimageioModel(this.descriptor.getModelPath());
        model.loadModel();
        try {
            model.runModel(inputs, outputs);
            // Post-process while the model is still open in case the output data is engine-backed.
            RandomAccessibleInterval<T> labels = postProcessing(outputs.get(0).getData());
            return Utils.transpose(labels);
        } finally {
            // Each call loads its own model; release it so repeated predictions do not leak engines.
            model.closeModel();
        }
    }

    /**
     * Runs the StarDist Python post-processing on the raw network output.
     *
     * <p>The script {@code stardist_postprocessing.py} is executed in the {@code "stardist"} micromamba
     * environment with the image and the model's NMS / probability thresholds as inputs.
     *
     * @param image raw network output
     * @return the first image returned by the script
     * @throws IllegalStateException if the script returns no image
     */
    public <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> postProcessing(RandomAccessibleInterval<T> image) throws IOException, InterruptedException {
        String envPath = stardistEnvPath(new Mamba());
        GenericOp op = GenericOp.create(envPath, envPath + File.separator + STARDIST2D_SCRIPT_NAME, STARDIST2D_METHOD_NAME, 1);
        // "yyyy" (calendar year), not "YYYY" (week-based year), which is wrong around New Year.
        String timestamp = new SimpleDateFormat("ddMMyyyy_HHmmss").format(Calendar.getInstance().getTime());
        LinkedHashMap<String, Object> opInputs = new LinkedHashMap<>();
        opInputs.put("input_" + timestamp, image);
        opInputs.put("nms_thresh", Float.valueOf(this.nms_threshold));
        opInputs.put("prob_thresh", Float.valueOf(this.prob_threshold));
        op.setInputs(opInputs);
        Object labels = RunMode.createRunMode(op).runOP().values().stream()
                .filter(value -> value instanceof RandomAccessibleInterval)
                .findFirst()
                .orElse(null);
        if (labels == null) {
            throw new IllegalStateException("The StarDist post-processing did not return any image.");
        }
        return Cast.unchecked(labels);
    }

    /** Path of the micromamba environment hosting the StarDist post-processing. */
    private static String stardistEnvPath(Mamba mamba) {
        return mamba.getEnvsDir() + File.separator + STARDIST_ENV_NAME;
    }

    /** Not implemented yet: currently a no-op. */
    public void checkRequirementsInstalled() {
    }

    /**
     * Installs everything needed by {@link #predict}: the Tensorflow 1.15.0 engine, micromamba (if
     * missing), the {@code "stardist"} environment and the post-processing script inside it.
     *
     * @throws FileNotFoundException if the post-processing script is not on the classpath
     */
    public static void installRequirements() throws IOException, InterruptedException, RuntimeException, MambaInstallException, ArchiveException, URISyntaxException {
        boolean tfInstalled = InstalledEngines.buildEnginesFinder()
                .checkEngineWithArgsInstalledForOS(TensorflowEngine.NAME, TF_VERSION, null, null).size() != 0;
        if (!tfInstalled) {
            EngineInstall.installEngineWithArgs(TensorflowEngine.NAME, TF_VERSION, true, true);
        }
        Mamba mamba = new Mamba();
        boolean envReady = false;
        try {
            envReady = mamba.checkAllDependenciesInEnv(STARDIST_ENV_NAME, STARDIST_DEPS);
        } catch (MambaInstallException e) {
            // micromamba itself is not available: install it; the environment is then created below.
            mamba.installMicromamba();
        }
        if (!envReady) {
            mamba.create(STARDIST_ENV_NAME, true, STARDIST_CHANNELS, STARDIST_DEPS);
        }
        File script = new File(stardistEnvPath(mamba) + File.separator + STARDIST2D_SCRIPT_NAME);
        if (script.isFile()) {
            return;
        }
        String resource = STARDIST2D_PATH_IN_RESOURCES + STARDIST2D_SCRIPT_NAME;
        try (InputStream in = Stardist2D.class.getClassLoader().getResourceAsStream(resource)) {
            if (in == null) {
                throw new FileNotFoundException("Resource not found in the classpath: " + resource);
            }
            Files.copy(in, script.toPath(), StandardCopyOption.REPLACE_EXISTING);
        }
    }

    /** Small manual check: installs the requirements and segments a blank 512x512 image. */
    public static void main(String[] args) throws IOException, InterruptedException, RuntimeException, MambaInstallException, ModelSpecsException, LoadEngineException, RunModelException, ArchiveException, URISyntaxException, LoadModelException {
        installRequirements();
        fromPretained("2D_versatile_fluo", false).predict(ArrayImgs.floats(new long[]{512, 512}));
        System.out.println(true);
    }
}
