/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.model.special.stardist;

import io.bioimage.modelrunner.apposed.appose.Environment;
import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.apposed.appose.Service;
import io.bioimage.modelrunner.apposed.appose.Types;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.BaseModel;
import io.bioimage.modelrunner.model.processing.Processing;
import io.bioimage.modelrunner.model.special.stardist.Stardist2D;
import io.bioimage.modelrunner.model.special.stardist.Stardist3D;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.transformations.ScaleRangeTransformation;
import io.bioimage.modelrunner.utils.CommonUtils;
import io.bioimage.modelrunner.utils.JSONUtils;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import net.imglib2.Interval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;
import org.apache.commons.compress.archivers.ArchiveException;

public abstract class StardistAbstract
extends BaseModel {
    private final String modelDir;
    protected final String name;
    protected final String basedir;
    protected final int nChannels;
    protected Map<String, Object> config;
    protected SharedMemoryArray shma;
    private ModelDescriptor descriptor;
    private Service python;
    public double scaleRangeMaxPercentile = 99.8;
    public double scaleRangeMinPercentile = 1.0;
    public String scaleRangeAxes = null;
    private static String INSTALLATION_DIR = Mamba.BASE_PATH;
    private static final List<String> STARDIST_DEPS = Arrays.asList("python=3.10", "stardist", "numpy", "appose");
    private static final List<String> STARDIST_DEPS_PIP = PlatformDetection.isMacOS() && (PlatformDetection.getArch().equals("arm64") || PlatformDetection.isUsingRosseta()) ? Arrays.asList("tensorflow-macos<2.11") : Arrays.asList("tensorflow<2.11");
    private static final List<String> STARDIST_CHANNELS = Arrays.asList("conda-forge", "default");
    private static final String OUTPUT_MASK_KEY = "mask";
    private static final String SHM_NAME_KEY = "_shm_name";
    private static final String DTYPE_KEY = "_dtype";
    private static final String SHAPE_KEY = "_shape";
    private static final String KEYS_KEY = "keys";
    protected static final String LOAD_MODEL_CODE_ABSTRACT = "if '%s' not in globals().keys():" + System.lineSeparator() + "  from stardist.models import %s" + System.lineSeparator() + "  globals()['%s'] = %s" + System.lineSeparator() + "if 'np' not in globals().keys():" + System.lineSeparator() + "  import numpy as np" + System.lineSeparator() + "  globals()['np'] = np" + System.lineSeparator() + "if 'os' not in globals().keys():" + System.lineSeparator() + "  import os" + System.lineSeparator() + "  globals()['os'] = os" + System.lineSeparator() + "if 'shared_memory' not in globals().keys():" + System.lineSeparator() + "  from multiprocessing import shared_memory" + System.lineSeparator() + "  globals()['shared_memory'] = shared_memory" + System.lineSeparator() + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"-1\"" + System.lineSeparator() + "model = %s(None, name='%s', basedir=r\"%s\")" + System.lineSeparator() + "globals()['model'] = model" + System.lineSeparator();
    private static final String RUN_MODEL_CODE = "output = model.predict_instances(im, return_predict=False)" + System.lineSeparator() + "if type(output) == np.ndarray:" + System.lineSeparator() + "  im[:] = output" + System.lineSeparator() + "  im[:] = output" + System.lineSeparator() + "  if os.name == 'nt':" + System.lineSeparator() + "    im_shm.close()" + System.lineSeparator() + "    im_shm.unlink()" + System.lineSeparator() + "if type(output) != list and type(output) != tuple:" + System.lineSeparator() + "  raise TypeError('StarDist output should be a list of a np.ndarray')" + System.lineSeparator() + "if type(output[0]) != np.ndarray:" + System.lineSeparator() + "  raise TypeError('If the StarDist output is a list, the first entry should be a np.ndarray')" + System.lineSeparator() + "if len(im.shape) == 3 and len(output[0].shape) == 2:" + System.lineSeparator() + "  im[:, :, 0] = output[0]" + System.lineSeparator() + "elif len(im.shape) == 4 and len(output[0].shape) == 3:" + System.lineSeparator() + "  im[:, :, :, 0] = output[0]" + System.lineSeparator() + "else:" + System.lineSeparator() + "  im[:] = output[0]" + System.lineSeparator() + "if len(output) > 1 and type(output[1]) != dict:" + System.lineSeparator() + "  raise TypeError('If the StarDist output is a list, the second entry needs to be a dict.')" + System.lineSeparator() + "task.outputs['" + "keys" + "'] = list(output[1].keys())" + System.lineSeparator() + "shm_list = []" + System.lineSeparator() + "np_list = []" + System.lineSeparator() + "for kk, vv in output[1].items():" + System.lineSeparator() + "  if type(vv) != np.ndarray:" + System.lineSeparator() + "    task.update('Output ' + kk + ' is not a np.ndarray. Only np.ndarrays supported.')" + System.lineSeparator() + "    continue" + System.lineSeparator() + "  if output[1][kk].nbytes == 0:" + System.lineSeparator() + "    task.outputs[kk] = None" + System.lineSeparator() + "  else:" + System.lineSeparator() + "    task.outputs[kk + '" + "_shape" + "'] = output[1][kk].shape" + System.lineSeparator() + "    task.outputs[kk + '" + "_dtype" + "'] = str(output[1][kk].dtype)" + System.lineSeparator() + "    shm = shared_memory.SharedMemory(create=True, size=output[1][kk].nbytes)" + System.lineSeparator() + "    task.outputs[kk + '" + "_shm_name" + "'] = shm.name" + System.lineSeparator() + "    shm_list.append(shm)" + System.lineSeparator() + "    aa = np.ndarray(output[1][kk].shape, dtype=output[1][kk].dtype, buffer=shm.buf)" + System.lineSeparator() + "    aa[:] = output[1][kk]" + System.lineSeparator() + "    np_list.append(aa)" + System.lineSeparator() + "globals()['shm_list'] = shm_list" + System.lineSeparator() + "globals()['np_list'] = np_list" + System.lineSeparator() + "if os.name == 'nt':" + System.lineSeparator() + "  im_shm.close()" + System.lineSeparator() + "  im_shm.unlink()" + System.lineSeparator();
    private static final String CLOSE_SHM_CODE = "if 'np_list' in globals().keys():" + System.lineSeparator() + "  for a in np_list:" + System.lineSeparator() + "    del a" + System.lineSeparator() + "if 'shm_list' in globals().keys():" + System.lineSeparator() + "  for s in shm_list:" + System.lineSeparator() + "    s.unlink()" + System.lineSeparator() + "    del s" + System.lineSeparator();

    protected abstract String createImportsCode();

    protected abstract <T extends RealType<T> & NativeType<T>> void checkInput(RandomAccessibleInterval<T> var1);

    protected abstract <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> reconstructMask() throws IOException;

    public abstract boolean is2D();

    public abstract boolean is3D();

    protected StardistAbstract(String modelName, String baseDir, Map<String, Object> config) throws IOException {
        this.name = modelName;
        this.basedir = baseDir;
        this.modelDir = new File(baseDir, modelName).getAbsolutePath();
        StardistAbstract.checkFilesPresent(this.modelDir);
        this.nChannels = ((Number)config.get("n_channel_in")).intValue();
        this.createPythonService();
    }

    protected StardistAbstract(String modelName, String baseDir) throws IOException {
        this.name = modelName;
        this.basedir = baseDir;
        this.modelDir = new File(baseDir, modelName).getAbsolutePath();
        StardistAbstract.checkFilesPresent(this.modelDir);
        this.config = JSONUtils.load(new File(this.modelDir, "config.json").getAbsolutePath());
        this.nChannels = ((Number)this.config.get("n_channel_in")).intValue();
        this.createPythonService();
    }

    public static void checkFilesPresent(String modelDir) throws IOException {
        if (!new File(modelDir, "config.json").isFile() && !new File(modelDir, "rdf.yaml").isFile()) {
            throw new IllegalArgumentException("No 'config.json' file found in the model directory");
        }
        if (!new File(modelDir, "config.json").isFile()) {
            StardistAbstract.createConfigFromBioimageio(null, modelDir);
        }
        if (!new File(modelDir, "thresholds.json").isFile() && !new File(modelDir, "rdf.yaml").isFile()) {
            throw new IllegalArgumentException("No 'thresholds.json' file found in the model directory");
        }
        if (!new File(modelDir, "thresholds.json").isFile()) {
            StardistAbstract.createThresholdsFromBioimageio(null, modelDir);
        }
    }

    protected StardistAbstract(ModelDescriptor descriptor) throws IOException {
        this.descriptor = descriptor;
        this.name = new File(descriptor.getModelPath()).getName();
        this.basedir = new File(descriptor.getModelPath()).getParentFile().getAbsolutePath();
        this.modelDir = descriptor.getModelPath();
        if (!new File(this.modelDir, "config.json").isFile()) {
            StardistAbstract.createConfigFromBioimageio(descriptor, this.modelDir);
        }
        if (!new File(this.modelDir, "thresholds.json").isFile()) {
            StardistAbstract.createThresholdsFromBioimageio(descriptor, this.modelDir);
        }
        this.config = JSONUtils.load(new File(this.modelDir, "config.json").getAbsolutePath());
        this.nChannels = ((Number)this.config.get("n_channel_in")).intValue();
        this.createPythonService();
    }

    private static void createConfigFromBioimageio(ModelDescriptor descriptor, String modelDir) throws IOException {
        if (descriptor == null) {
            descriptor = ModelDescriptorFactory.readFromLocalFile(modelDir + File.separator + "rdf.yaml");
        }
        Map stardistMap = (Map)descriptor.getConfig().getSpecMap().get("stardist");
        Map stardistConfig = (Map)stardistMap.get("config");
        JSONUtils.writeJSONFile(new File(modelDir, "config.json").getAbsolutePath(), stardistConfig);
    }

    private static void createThresholdsFromBioimageio(ModelDescriptor descriptor, String modelDir) throws IOException {
        if (descriptor == null) {
            descriptor = ModelDescriptorFactory.readFromLocalFile(modelDir + File.separator + "rdf.yaml");
        }
        Map stardistMap = (Map)descriptor.getConfig().getSpecMap().get("stardist");
        Map stardistThres = (Map)stardistMap.get("thresholds");
        JSONUtils.writeJSONFile(new File(modelDir, "thresholds.json").getAbsolutePath(), stardistThres);
    }

    private void createPythonService() throws IOException {
        Environment env = new Environment(){

            @Override
            public String base() {
                return new Mamba(INSTALLATION_DIR).getEnvsDir() + File.separator + "stardist";
            }
        };
        this.python = env.python();
        this.python.debug(System.err::println);
    }

    protected String createEncodeImageScript() {
        String code = "";
        code = code + "im_shm = shared_memory.SharedMemory(name='" + this.shma.getNameForPython() + "', size=" + this.shma.getSize() + ")" + System.lineSeparator();
        long nElems = 1L;
        for (long elem : this.shma.getOriginalShape()) {
            nElems *= elem;
        }
        code = code + "im = np.ndarray(" + nElems + ", dtype='" + CommonUtils.getDataTypeFromRAI((RandomAccessibleInterval)Cast.unchecked(this.shma.getSharedRAI())) + "', buffer=im_shm.buf).reshape([";
        for (int i = 0; i < this.shma.getOriginalShape().length; ++i) {
            code = code + this.shma.getOriginalShape()[i] + ", ";
        }
        code = code + "])" + System.lineSeparator();
        return code;
    }

    protected <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void preprocess(List<Tensor<T>> inputs) {
        if (this.descriptor != null) {
            Processing processing = Processing.init(this.descriptor);
            List inputsProcessed = processing.preprocess(inputs, false);
            inputs.set(0, inputsProcessed.get(0));
        } else {
            ScaleRangeTransformation transform = new ScaleRangeTransformation();
            transform.setAxes(transform);
            transform.setMaxPercentile(this.scaleRangeMaxPercentile);
            transform.setMinPercentile(this.scaleRangeMinPercentile);
            transform.setAxes(this.scaleRangeAxes);
            inputs.set(0, (Tensor)Cast.unchecked(transform.apply(inputs.get(0))));
        }
    }

    @Override
    public void close() {
        if (!this.loaded) {
            return;
        }
        this.python.close();
        this.loaded = false;
        this.closed = true;
    }

    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void run(List<Tensor<T>> inTensors, List<Tensor<R>> outTensors) throws RunModelException {
        if (inTensors.size() > 1) {
            throw new RunModelException("Stardist needs just one input image");
        }
        this.preprocess(inTensors);
        try {
            Map<String, RandomAccessibleInterval<R>> outputs = this.run(inTensors.get(0).getData());
            for (Tensor tensor : outTensors) {
                Map.Entry entry = outputs.entrySet().stream().filter(ee -> tensor.getName().equals(ee.getKey()) && Arrays.equals(tensor.getData().dimensionsAsLongArray(), ((RandomAccessibleInterval)ee.getValue()).dimensionsAsLongArray())).findFirst().orElse(null);
                if (entry == null && Arrays.equals(tensor.getData().dimensionsAsLongArray(), outputs.get(OUTPUT_MASK_KEY).dimensionsAsLongArray())) {
                    tensor.setData(outputs.get(OUTPUT_MASK_KEY));
                    continue;
                }
                if (entry == null) continue;
                tensor.setData((RandomAccessibleInterval)entry.getValue());
            }
        }
        catch (IOException | InterruptedException e) {
            throw new RunModelException(Types.stackTrace(e));
        }
    }

    @Override
    public void loadModel() throws LoadModelException {
        if (this.closed) {
            throw new RuntimeException("Cannot load model after it has been closed");
        }
        String code = "";
        if (!this.loaded) {
            code = code + this.createImportsCode() + System.lineSeparator();
        }
        try {
            Service.Task task = this.python.task(code);
            task.waitFor();
            if (task.status == Service.TaskStatus.CANCELED) {
                throw new RuntimeException("Task canceled");
            }
            if (task.status == Service.TaskStatus.FAILED) {
                throw new RuntimeException(task.error);
            }
            if (task.status == Service.TaskStatus.CRASHED) {
                throw new RuntimeException(task.error);
            }
            this.loaded = true;
        }
        catch (IOException | InterruptedException e) {
            throw new LoadModelException(Types.stackTrace(e));
        }
    }

    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> run(List<Tensor<R>> inputTensors) throws RunModelException {
        if (inputTensors.size() > 1) {
            throw new RunModelException("Stardist needs just one input image");
        }
        this.preprocess(inputTensors);
        try {
            Map<String, RandomAccessibleInterval<R>> outputs = this.run(inputTensors.get(0).getData());
            ArrayList<Tensor<T>> outTensors = new ArrayList<Tensor<T>>();
            for (Map.Entry<String, RandomAccessibleInterval<R>> entry : outputs.entrySet()) {
                if (entry.getValue() == null) continue;
                String axesOrder = "xy";
                if (entry.getValue().dimensionsAsLongArray().length > 2 && this.is2D()) {
                    axesOrder = axesOrder + "c";
                } else if (entry.getValue().dimensionsAsLongArray().length == 3 && this.is3D()) {
                    axesOrder = axesOrder + "z";
                } else if (entry.getValue().dimensionsAsLongArray().length > 3 && this.is3D()) {
                    axesOrder = axesOrder + "zc";
                } else if (entry.getValue().dimensionsAsLongArray().length == 1) {
                    axesOrder = "i";
                }
                Tensor<R> tt = Tensor.build(entry.getKey(), axesOrder, entry.getValue());
                if (tt.getName() != OUTPUT_MASK_KEY) continue;
                outTensors.add(tt);
            }
            return outTensors;
        }
        catch (IOException | InterruptedException e) {
            throw new RunModelException(Types.stackTrace(e));
        }
    }

    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> Map<String, RandomAccessibleInterval<R>> run(RandomAccessibleInterval<T> img) throws IOException, InterruptedException {
        this.checkInput(img);
        this.shma = SharedMemoryArray.createSHMAFromRAI(img, false, false);
        String code = "";
        if (!this.loaded) {
            code = code + this.createImportsCode() + System.lineSeparator();
        }
        code = code + this.createEncodeImageScript() + System.lineSeparator();
        code = code + RUN_MODEL_CODE + System.lineSeparator();
        Service.Task task = this.python.task(code);
        task.waitFor();
        if (task.status == Service.TaskStatus.CANCELED) {
            throw new RuntimeException("Task canceled");
        }
        if (task.status == Service.TaskStatus.FAILED) {
            throw new RuntimeException(task.error);
        }
        if (task.status == Service.TaskStatus.CRASHED) {
            throw new RuntimeException(task.error);
        }
        this.loaded = true;
        return this.reconstructOutputs(task);
    }

    private <T extends RealType<T> & NativeType<T>> Map<String, RandomAccessibleInterval<T>> reconstructOutputs(Service.Task task) throws IOException, InterruptedException {
        LinkedHashMap<String, RandomAccessibleInterval<T>> outs = new LinkedHashMap<String, RandomAccessibleInterval<T>>();
        outs.put(OUTPUT_MASK_KEY, this.reconstructMask());
        if (task.outputs.get(KEYS_KEY) != null) {
            for (String kk : (List)task.outputs.get(KEYS_KEY)) {
                outs.put(kk, this.reconstruct(task, kk));
            }
        }
        if (PlatformDetection.isWindows()) {
            Service.Task closeSHMTask = this.python.task(CLOSE_SHM_CODE);
            closeSHMTask.waitFor();
        }
        return outs;
    }

    private <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> reconstruct(Service.Task task, String key) throws IOException {
        String shm_name = (String)task.outputs.get(key + SHM_NAME_KEY);
        String coords_dtype = (String)task.outputs.get(key + DTYPE_KEY);
        List coords_shape = (List)task.outputs.get(key + SHAPE_KEY);
        if (coords_shape == null) {
            return null;
        }
        long[] coordsSh = new long[coords_shape.size()];
        for (int i = 0; i < coordsSh.length; ++i) {
            coordsSh[i] = ((Number)coords_shape.get(i)).longValue();
        }
        SharedMemoryArray shmCoords = SharedMemoryArray.readOrCreate(shm_name, coordsSh, (RealType)Cast.unchecked(CommonUtils.getImgLib2DataType(coords_dtype)), false, false);
        RandomAccessibleInterval coordsRAI = shmCoords.getSharedRAI();
        RandomAccessibleInterval<RealType> coordsCopy = Tensor.createCopyOfRaiInWantedDataType((RandomAccessibleInterval)Cast.unchecked(coordsRAI), (RealType)Util.getTypeFromInterval((Interval)Cast.unchecked(coordsRAI)));
        shmCoords.close();
        return coordsCopy;
    }

    public static StardistAbstract init(String modelDir) throws IOException {
        File modelDirFile = new File(modelDir);
        String modelName = modelDirFile.getName();
        String baseDir = modelDirFile.getParentFile().getAbsolutePath();
        StardistAbstract.checkFilesPresent(modelDir);
        Map<String, Object> configMap = JSONUtils.load(new File(modelDir, "config.json").getAbsolutePath());
        String axes = ((String)configMap.get("axes")).toUpperCase();
        if (axes.contains("Z")) {
            return new Stardist3D(modelName, baseDir, configMap);
        }
        return new Stardist2D(modelName, baseDir, configMap);
    }

    public static StardistAbstract init(String modelName, String baseDir) throws IOException {
        String modelDir = new File(baseDir, modelName).getAbsolutePath();
        StardistAbstract.checkFilesPresent(modelDir);
        Map<String, Object> configMap = JSONUtils.load(new File(modelDir, "config.json").getAbsolutePath());
        String axes = ((String)configMap.get("axes")).toUpperCase();
        if (axes.contains("Z")) {
            return new Stardist3D(modelName, baseDir, configMap);
        }
        return new Stardist2D(modelName, baseDir, configMap);
    }

    public static StardistAbstract fromBioimageioModel(ModelDescriptor descriptor) throws IOException {
        if (!descriptor.getConfig().getSpecMap().keySet().contains("stardist")) {
            throw new IllegalArgumentException("This Bioimage.io model does not correspond to a StarDist model.");
        }
        if (!descriptor.getModelFamily().equals("stardist")) {
            throw new RuntimeException("Please first install StarDist with 'StardistAbstract.installRequirements()'");
        }
        if (descriptor.getInputTensors().get(0).getAxesOrder().contains("z")) {
            return Stardist3D.fromBioimageioModel(descriptor);
        }
        return Stardist2D.fromBioimageioModel(descriptor);
    }

    public static boolean isInstalled() {
        Mamba mamba = new Mamba(INSTALLATION_DIR);
        try {
            return mamba.checkAllDependenciesInEnv("stardist", STARDIST_DEPS);
        }
        catch (MambaInstallException e) {
            return false;
        }
    }

    public static boolean isInstalled(String envPath) {
        Mamba mamba = new Mamba(INSTALLATION_DIR);
        try {
            return mamba.checkAllDependenciesInEnv(envPath, STARDIST_DEPS);
        }
        catch (MambaInstallException e) {
            return false;
        }
    }

    public static void installRequirements() throws IOException, InterruptedException, RuntimeException, MambaInstallException, ArchiveException, URISyntaxException {
        StardistAbstract.installRequirements(null);
    }

    public static void installRequirements(Consumer<String> consumer) throws IOException, InterruptedException, RuntimeException, MambaInstallException, ArchiveException, URISyntaxException {
        Mamba mamba = new Mamba(INSTALLATION_DIR);
        if (consumer != null) {
            mamba.setConsoleOutputConsumer(consumer);
            mamba.setErrorOutputConsumer(consumer);
        }
        boolean stardistPythonInstalled = false;
        try {
            ArrayList<String> deps = new ArrayList<String>(STARDIST_DEPS);
            for (String dd2 : deps) {
                deps.add(dd2.equals("tensorflow-macos<2.11") ? dd2.replace("-macos", "") : dd2);
            }
            stardistPythonInstalled = mamba.checkAllDependenciesInEnv("stardist", deps);
        }
        catch (MambaInstallException e) {
            mamba.installMicromamba();
        }
        if (!stardistPythonInstalled) {
            mamba.create("stardist", true, STARDIST_CHANNELS, STARDIST_DEPS.stream().map(dd -> dd.contains("<") | dd.contains(">") ? "\"" + dd + "\"" : dd).collect(Collectors.toList()));
            mamba.pipInstallIn("stardist", STARDIST_DEPS_PIP.stream().collect(Collectors.toList()).toArray(new String[STARDIST_DEPS_PIP.size()]));
        }
    }

    public static void setInstallationDir(String installationDir) {
        INSTALLATION_DIR = installationDir;
    }

    public static String getInstallationDir() {
        return INSTALLATION_DIR;
    }
}

