/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.model.python;

import io.bioimage.modelrunner.apposed.appose.Environment;
import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.apposed.appose.Service;
import io.bioimage.modelrunner.apposed.appose.Types;
import io.bioimage.modelrunner.bioimageio.tiling.TileInfo;
import io.bioimage.modelrunner.bioimageio.tiling.TileMaker;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.BaseModel;
import io.bioimage.modelrunner.model.java.DLModelJava;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import net.imglib2.Interval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;
import org.apache.commons.compress.archivers.ArchiveException;

public class DLModelPytorchProtected
extends BaseModel {
    /** Absolute path of the .py file defining the model; null in customJDLL mode or when no model file is used. */
    protected final String modelFile;
    /** Name of the Python class or function that builds the model. */
    protected final String callable;
    /** Python module {@link #callable} is imported from when no model file is used. */
    protected final String importModule;
    /** Absolute path of the weights file passed to {@code torch.load}. */
    protected final String weightsPath;
    /** Keyword arguments passed to {@link #callable} when the model is instantiated. */
    protected final Map<String, Object> kwargs;
    /** Root directory of the Python environment the Appose service runs in. */
    protected String envPath;
    private Service python;
    /** Shared-memory segments created for model inputs. */
    protected List<SharedMemoryArray> inShmaList = new ArrayList<>();
    // Names, dtypes and shapes of the output segments reported by the last Python task.
    private List<String> outShmNames;
    private List<String> outShmDTypes;
    private List<long[]> outShmDims;
    // Tiling configuration.
    protected List<TileInfo> inputTiles;
    protected List<TileInfo> outputTiles;
    protected boolean tiling = false;
    protected DLModelJava.TilingConsumer tileCounter;
    public static final String COMMON_PYTORCH_ENV_NAME = "biapy";
    private static final List<String> BIAPY_CONDA_DEPS = Arrays.asList("python=3.10");
    private static final List<String> BIAPY_PIP_DEPS_TORCH = selectTorchPipDeps();
    private static final List<String> BIAPY_PIP_DEPS = Arrays.asList("timm==1.0.14", "pytorch-msssim==1.0.0", "torchmetrics==1.4.3", "cellpose==3.1.1.1", "scipy==1.15.2", "torch-fidelity==0.3.0", "careamics", "biapy==3.5.10", "appose");
    private static final List<String> BIAPY_PIP_ARGS = Arrays.asList("--index-url", "https://download.pytorch.org/whl/cpu");
    protected static String INSTALLATION_DIR = Mamba.BASE_PATH;
    protected static final String MODEL_VAR_NAME = uniquePythonName("model_");
    /**
     * Python template that imports sys/numpy/os/shared_memory once, then runs the two snippets substituted for the
     * first two {@code %s} (sys.path extension and import statement) and finally registers the callable (the last
     * three {@code %s}) in {@code globals()}.
     */
    protected static final String LOAD_MODEL_CODE_ABSTRACT = String.join(System.lineSeparator(),
            "if 'sys' not in globals().keys():",
            "  import sys",
            "  globals()['sys'] = sys",
            "if 'np' not in globals().keys():",
            "  import numpy as np",
            "  globals()['np'] = np",
            "if 'os' not in globals().keys():",
            "  import os",
            "  globals()['os'] = os",
            "if 'shared_memory' not in globals().keys():",
            "  from multiprocessing import shared_memory",
            "  globals()['shared_memory'] = shared_memory",
            "%s",
            "%s",
            "if '%s' not in globals().keys():",
            "  globals()['%s'] = %s") + System.lineSeparator();
    // Randomised Python variable names, so they do not collide with anything else living in the worker's globals.
    protected static final String OUTPUT_LIST_KEY = uniquePythonName("out_list");
    protected static final String SHMS_KEY = uniquePythonName("shms_");
    protected static final String SHM_NAMES_KEY = uniquePythonName("shm_names_");
    protected static final String DTYPES_KEY = uniquePythonName("dtypes_");
    protected static final String DIMS_KEY = uniquePythonName("dims_");

    /** macOS on x86_64 outside Rosetta gets the torch 2.2.2 stack; every other platform gets torch 2.4.0. */
    private static List<String> selectTorchPipDeps() {
        boolean macX86Native = PlatformDetection.isMacOS()
                && PlatformDetection.getArch().equals("x86_64")
                && !PlatformDetection.isUsingRosseta();
        if (macX86Native) {
            return Arrays.asList("torch==2.2.2", "torchvision==0.17.2", "torchaudio==2.2.2");
        }
        return Arrays.asList("torch==2.4.0", "torchvision==0.19.0", "torchaudio==2.4.0");
    }

    /** Returns {@code prefix} followed by a random UUID whose dashes are turned into underscores. */
    private static String uniquePythonName(String prefix) {
        return prefix + UUID.randomUUID().toString().replace("-", "_");
    }
    protected static final String RECOVER_OUTPUTS_CODE = "def handle_output(outs_i):" + System.lineSeparator() + "    if type(outs_i) == np.ndarray:" + System.lineSeparator() + "      shm = shared_memory.SharedMemory(create=True, size=outs_i.nbytes)" + System.lineSeparator() + "      sh_np_array = np.ndarray(outs_i.shape, dtype=outs_i.dtype, buffer=shm.buf)" + System.lineSeparator() + "      np.copyto(sh_np_array, outs_i)" + System.lineSeparator() + "      " + SHMS_KEY + ".append(shm)" + System.lineSeparator() + "      " + SHM_NAMES_KEY + ".append(shm.name)" + System.lineSeparator() + "      " + DTYPES_KEY + ".append(str(outs_i.dtype))" + System.lineSeparator() + "      " + DIMS_KEY + ".append(outs_i.shape)" + System.lineSeparator() + "    elif str(type(outs_i)) == \"<class 'torch.Tensor'>\":" + System.lineSeparator() + "      if 'torch' not in globals().keys():" + System.lineSeparator() + "        import torch" + System.lineSeparator() + "        globals()['torch'] = torch" + System.lineSeparator() + "      else:" + System.lineSeparator() + "        torch = globals()['torch']" + System.lineSeparator() + "      shm = shared_memory.SharedMemory(create=True, size=outs_i.numel() * outs_i.element_size())" + System.lineSeparator() + "      np_arr = np.ndarray(outs_i.shape, dtype=str(outs_i.dtype).split('.')[-1], buffer=shm.buf)" + System.lineSeparator() + "      tensor_np_view = torch.from_numpy(np_arr)" + System.lineSeparator() + "      tensor_np_view.copy_(outs_i)" + System.lineSeparator() + "      " + SHMS_KEY + ".append(shm)" + System.lineSeparator() + "      " + SHM_NAMES_KEY + ".append(shm.name)" + System.lineSeparator() + "      " + DTYPES_KEY + ".append(str(outs_i.dtype).split('.')[-1])" + System.lineSeparator() + "      " + DIMS_KEY + ".append(outs_i.shape)" + System.lineSeparator() + "    elif type(outs_i) == int:" + System.lineSeparator() + "      shm = shared_memory.SharedMemory(create=True, size=8)" + System.lineSeparator() + "      shm.buf[:8] = 
outs_i.to_bytes(8, byteorder='little', signed=True)" + System.lineSeparator() + "      " + SHMS_KEY + ".append(shm)" + System.lineSeparator() + "      " + SHM_NAMES_KEY + ".append(shm.name)" + System.lineSeparator() + "      " + DTYPES_KEY + ".append('int64')" + System.lineSeparator() + "      " + DIMS_KEY + ".append((1))" + System.lineSeparator() + "    elif type(outs_i) == float:" + System.lineSeparator() + "      shm = shared_memory.SharedMemory(create=True, size=8)" + System.lineSeparator() + "      shm.buf[:8] = outs_i.to_bytes(8, byteorder='little', signed=True)" + System.lineSeparator() + "      " + SHMS_KEY + ".append(shm)" + System.lineSeparator() + "      " + SHM_NAMES_KEY + ".append(shm.name)" + System.lineSeparator() + "      " + DTYPES_KEY + ".append('float64')" + System.lineSeparator() + "      " + DIMS_KEY + ".append((1))" + System.lineSeparator() + "    elif type(outs_i) == tuple or type(outs_i) == list:" + System.lineSeparator() + "      handle_output_list(outs_i)" + System.lineSeparator() + "    else:" + System.lineSeparator() + "      task.update('output type : ' + str(type(outs_i)) + ' not supported. 
Only supported output types are: np.ndarray, torch.tensor, int and float, or a list or tuple of any of those.')" + System.lineSeparator() + System.lineSeparator() + System.lineSeparator() + "def handle_output_list(out_list):" + System.lineSeparator() + "  if type(out_list) == tuple or type(out_list) == list:" + System.lineSeparator() + "    for outs_i in out_list:" + System.lineSeparator() + "      handle_output(outs_i)" + System.lineSeparator() + "  else:" + System.lineSeparator() + "    handle_output(out_list)" + System.lineSeparator() + "" + System.lineSeparator() + "" + System.lineSeparator() + "globals()['handle_output_list'] = handle_output_list" + System.lineSeparator() + "globals()['handle_output'] = handle_output" + System.lineSeparator() + "" + System.lineSeparator() + "" + System.lineSeparator() + "print('should be done')" + System.lineSeparator();
    /** Python snippet that closes and unlinks every output segment recorded in {@link #SHMS_KEY}, if that list exists. */
    private static final String CLEAN_SHM_CODE = String.join(System.lineSeparator(),
            "if '" + SHMS_KEY + "' in globals().keys():",
            "  for s in " + SHMS_KEY + ":",
            "    s.close()",
            "    s.unlink()",
            "    del s") + System.lineSeparator();
    /** Random token that replaces '+' in model file names so they can be imported as Python modules. */
    private static final String JDLL_UUID = UUID.randomUUID().toString().replace('-', '_');

    /**
     * Creates a model wrapper with {@code customJDLL} disabled, so the model file / import module and the
     * .pt/.pth weights extension are validated.
     *
     * @throws IOException if the Python service cannot be created
     */
    protected DLModelPytorchProtected(String modelFile, String callable, String importModule, String weightsPath, Map<String, Object> kwargs) throws IOException {
        this(modelFile, callable, importModule, weightsPath, kwargs, /* customJDLL = */ false);
    }

    /**
     * Creates a PyTorch model wrapper whose inference runs in a separate Python process.
     *
     * @param modelFile    .py file defining {@code callable}; may be null when {@code importModule} is given or
     *                     {@code customJDLL} is true
     * @param callable     name of the Python class or function that builds the model
     * @param importModule Python module to import {@code callable} from when no model file is used
     * @param weightsPath  weights file; must end in .pt/.pth unless {@code customJDLL} is true
     * @param kwargs       keyword arguments passed to {@code callable}
     * @param customJDLL   when true, the model source and weights extension checks are skipped and
     *                     {@code modelFile} / {@code importModule} are ignored
     * @throws IllegalArgumentException if no valid model source is given or the weights file is missing/invalid
     * @throws IOException              if the Python service cannot be created
     */
    protected DLModelPytorchProtected(String modelFile, String callable, String importModule, String weightsPath, Map<String, Object> kwargs, boolean customJDLL) throws IOException {
        // Null-safe: a null modelFile is legitimate when the callable comes from importModule.
        boolean validModelFile = modelFile != null && new File(modelFile).isFile() && modelFile.endsWith(".py");
        if (!customJDLL && !validModelFile && importModule == null) {
            throw new IllegalArgumentException("The model file does not correspond to an existing .py file.");
        }
        boolean torchExtension = weightsPath != null && (weightsPath.endsWith(".pt") || weightsPath.endsWith(".pth"));
        if (weightsPath == null || !new File(weightsPath).isFile() || (!customJDLL && !torchExtension)) {
            throw new IllegalArgumentException("The weights file does not correspond to an existing .pt/.pth file.");
        }
        this.callable = callable;
        this.modelFile = !customJDLL && modelFile != null && new File(modelFile).isFile() ? new File(modelFile).getAbsolutePath() : null;
        this.importModule = !customJDLL && importModule != null ? importModule : null;
        this.weightsPath = new File(weightsPath).getAbsolutePath();
        this.kwargs = kwargs;
        this.envPath = INSTALLATION_DIR + File.separator + "envs" + File.separator + COMMON_PYTORCH_ENV_NAME;
        this.createPythonService();
    }

    /**
     * Starts a Python service for the environment at {@link #envPath} (read each time the environment asks for its
     * base directory) and forwards the service's debug output to standard error.
     *
     * @throws IOException if the Python service cannot be created
     */
    protected void createPythonService() throws IOException {
        final DLModelPytorchProtected owner = this;
        Environment environment = new Environment() {
            @Override
            public String base() {
                return owner.envPath;
            }
        };
        this.python = environment.python();
        this.python.debug(System.err::println);
    }

    /** @return the directory of the Python environment currently used to run the model */
    public String getEnvPath() {
        return envPath;
    }

    /**
     * Switches to another Python environment: the current service is closed and a new one is started for
     * {@code envPath}. The new process has no model in memory, so the model is marked as not loaded and
     * {@code loadModel()} must be called again before inference.
     *
     * @param envPath directory of the Python environment to use
     * @throws IOException if the new Python service cannot be created
     */
    public void setCustomEnvPath(String envPath) throws IOException {
        this.envPath = envPath;
        this.python.close();
        // Without this, loadModel() would return early and inference would hit a process without the model.
        this.loaded = false;
        this.createPythonService();
    }

    /** @return whether {@code run} processes the inputs tile by tile */
    public boolean isTiling() {
        return tiling;
    }

    /**
     * Enables or disables tiled execution.
     *
     * @param doTiling true to process inputs tile by tile
     */
    public void setTiling(boolean doTiling) {
        tiling = doTiling;
    }

    /**
     * Stores the tiling layout for inputs and outputs and switches tiling on.
     *
     * @param inputTiles  tile description of every input tensor
     * @param outputTiles tile description of every output tensor
     */
    public void setTileInfo(List<TileInfo> inputTiles, List<TileInfo> outputTiles) {
        tiling = true;
        this.outputTiles = outputTiles;
        this.inputTiles = inputTiles;
    }

    /**
     * Stores the consumer in {@link #tileCounter}.
     *
     * @param tileCounter consumer to store
     */
    public void setTilingCounter(DLModelJava.TilingConsumer tileCounter) {
        this.tileCounter = tileCounter;
    }

    /**
     * Loads the model into the Python worker by running the script from {@link #buildModelCode()} followed by
     * {@link #RECOVER_OUTPUTS_CODE}. Calling it again after a successful load does nothing.
     *
     * @throws LoadModelException if building or running the script fails, the task is cancelled/fails/crashes, or
     *                            the calling thread is interrupted (the interrupt flag is restored)
     * @throws RuntimeException   if the model has already been closed
     */
    @Override
    public void loadModel() throws LoadModelException {
        if (this.loaded) {
            return;
        }
        if (this.closed) {
            throw new RuntimeException("Cannot load model after it has been closed");
        }
        Service.Task task;
        try {
            String code = this.buildModelCode() + RECOVER_OUTPUTS_CODE;
            task = this.python.task(code);
            task.waitFor();
        } catch (IOException e) {
            throw new LoadModelException(Types.stackTrace(e));
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new LoadModelException(Types.stackTrace(e));
        }
        if (task.status == Service.TaskStatus.CANCELED) {
            throw new LoadModelException("Task canceled");
        }
        if (task.status == Service.TaskStatus.FAILED || task.status == Service.TaskStatus.CRASHED) {
            throw new LoadModelException(task.error);
        }
        this.loaded = true;
    }

    /**
     * Copies the bytes of {@code inputPath} to {@code outputPath}, unless a regular file already exists there.
     * Despite the name, the content is copied unchanged; only the destination file name differs.
     *
     * @throws IOException if reading the source or writing the destination fails
     */
    private static void copyAndReplace(String inputPath, String outputPath) throws IOException {
        boolean alreadyThere = new File(outputPath).isFile();
        if (!alreadyThere) {
            Files.write(Paths.get(outputPath), Files.readAllBytes(Paths.get(inputPath)));
        }
    }

    /**
     * Builds the Python script that imports the model callable, instantiates it with {@link #kwargs}, loads the
     * weights from {@link #weightsPath} and stores the instance in the global {@link #MODEL_VAR_NAME}.
     *
     * If the model file name contains '+', which is not allowed in a Python module name, the file is copied
     * next to the original with every '+' replaced by {@link #JDLL_UUID}, and that copy is imported instead.
     *
     * @return the Python script
     * @throws IOException if the renamed copy of the model file cannot be written
     */
    protected String buildModelCode() throws IOException {
        String nl = System.lineSeparator();
        String addPath = "";
        String importStr;
        if (this.modelFile != null) {
            File source = new File(this.modelFile);
            File importable = source;
            String moduleName = source.getName().substring(0, source.getName().length() - 3);
            if (moduleName.contains("+")) {
                // Only rename the file itself; directories containing '+' are fine on sys.path.
                importable = new File(source.getParentFile(), source.getName().replace("+", JDLL_UUID));
                DLModelPytorchProtected.copyAndReplace(source.getAbsolutePath(), importable.getAbsolutePath());
                moduleName = moduleName.replace("+", JDLL_UUID);
            }
            addPath = "sys.path.append(os.path.abspath("
                    + quotePythonPath(importable.getAbsoluteFile().getParentFile().getAbsolutePath()) + "))";
            importStr = String.format("from %s import %s", moduleName, this.callable);
        } else {
            importStr = String.format("from %s import %s", this.importModule, this.callable);
        }
        String weights = quotePythonPath(this.weightsPath);
        StringBuilder code = new StringBuilder();
        code.append("if 'torch' not in globals().keys():").append(nl)
            .append("  import torch").append(nl)
            .append("  globals()['torch'] = torch").append(nl);
        code.append(String.format(LOAD_MODEL_CODE_ABSTRACT, addPath, importStr, this.callable, this.callable, this.callable));
        code.append(MODEL_VAR_NAME).append("=").append(this.callable)
            .append("(").append(this.codeForKwargs()).append(")").append(nl);
        // Load onto the model's own device first; if that raises (e.g. the object has no 'device'
        // attribute), retry mapping the weights to the CPU.
        code.append("try:").append(nl)
            .append("  ").append(MODEL_VAR_NAME).append(".load_state_dict(torch.load(").append(weights)
            .append(", map_location=").append(MODEL_VAR_NAME).append(".device))").append(nl)
            .append("except Exception:").append(nl)
            .append("  ").append(MODEL_VAR_NAME).append(".load_state_dict(torch.load(").append(weights)
            .append(", map_location=torch.device('cpu')))").append(nl);
        code.append("globals()['").append(MODEL_VAR_NAME).append("'] = ").append(MODEL_VAR_NAME).append(nl);
        return code.toString();
    }

    /**
     * Quotes a filesystem path as a regular single-quoted Python string literal. Unlike a raw literal it
     * tolerates apostrophes and trailing backslashes (e.g. {@code C:\}).
     */
    private static String quotePythonPath(String path) {
        return "'" + path.replace("\\", "\\\\").replace("'", "\\'") + "'";
    }

    /**
     * Renders a Java list as a Python list literal; every element goes through {@link #toPythonLiteral(Object)}.
     */
    private String codeForKwargsList(List<Object> list) {
        StringBuilder code = new StringBuilder("[");
        for (Object value : list) {
            code.append(this.toPythonLiteral(value)).append(",");
        }
        return code.append("]").toString();
    }

    /**
     * Renders a Java map as a Python dict literal with string keys; values go through
     * {@link #toPythonLiteral(Object)}.
     */
    private String codeForKwargsMap(Map<String, Object> map) {
        StringBuilder code = new StringBuilder("{");
        for (Map.Entry<String, Object> entry : map.entrySet()) {
            code.append(escapePythonString(String.valueOf(entry.getKey())))
                .append(":")
                .append(this.toPythonLiteral(entry.getValue()))
                .append(",");
        }
        return code.append("}").toString();
    }

    /**
     * Renders {@link #kwargs} as a Python keyword-argument list ({@code a=1,b="x",}); the trailing comma is valid
     * inside a Python call. Returns an empty string when there are no kwargs.
     */
    private String codeForKwargs() {
        if (this.kwargs == null) {
            return "";
        }
        StringBuilder code = new StringBuilder();
        for (Map.Entry<String, Object> ee : this.kwargs.entrySet()) {
            code.append(ee.getKey()).append("=").append(this.toPythonLiteral(ee.getValue())).append(",");
        }
        return code.toString();
    }

    /**
     * Converts a Java value to Python source.
     * <ul>
     * <li>null becomes None.</li>
     * <li>Booleans and the strings "true"/"false" become True/False.</li>
     * <li>Other strings become escaped string literals.</li>
     * <li>Lists and maps are converted recursively.</li>
     * <li>Anything else is emitted via {@link String#valueOf(Object)}.</li>
     * </ul>
     */
    @SuppressWarnings("unchecked")
    private String toPythonLiteral(Object value) {
        if (value == null) {
            return "None";
        }
        if (Boolean.TRUE.equals(value) || "true".equals(value)) {
            return "True";
        }
        if (Boolean.FALSE.equals(value) || "false".equals(value)) {
            return "False";
        }
        if (value instanceof String) {
            return escapePythonString((String) value);
        }
        if (value instanceof List) {
            return this.codeForKwargsList((List<Object>) value);
        }
        if (value instanceof Map) {
            return this.codeForKwargsMap((Map<String, Object>) value);
        }
        return String.valueOf(value);
    }

    /**
     * Wraps {@code s} in double quotes, escaping backslashes, quotes and control whitespace, so that values such
     * as Windows paths ({@code C:\x}) produce valid Python.
     */
    private static String escapePythonString(String s) {
        StringBuilder out = new StringBuilder(s.length() + 2).append('"');
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            switch (c) {
                case '\\': out.append("\\\\"); break;
                case '"': out.append("\\\""); break;
                case '\n': out.append("\\n"); break;
                case '\r': out.append("\\r"); break;
                case '\t': out.append("\\t"); break;
                default: out.append(c);
            }
        }
        return out.append('"').toString();
    }

    /**
     * Shuts down the Python service and marks the model as closed. The service is started in the constructor, so it
     * is released even if the model was never loaded. Repeated calls have no effect.
     */
    @Override
    public void close() {
        if (this.closed) {
            return;
        }
        if (this.python != null) {
            this.python.close();
        }
        this.loaded = false;
        this.closed = true;
    }

    /**
     * Runs the model on the data of {@code inTensors}. Each tensor is exposed in Python as {@code <name>_np}.
     *
     * @return outputs keyed {@code output_<i>} in the order the model produced them
     * @throws RunModelException if the Python side fails
     */
    private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> Map<String, RandomAccessibleInterval<R>> predictForInputTensors(List<Tensor<T>> inTensors) throws RunModelException {
        if (!this.loaded) {
            throw new RuntimeException("Please load the model first.");
        }
        List<String> pythonNames = new ArrayList<>(inTensors.size());
        List<RandomAccessibleInterval<T>> data = new ArrayList<>(inTensors.size());
        for (Tensor<T> tensor : inTensors) {
            pythonNames.add(tensor.getName() + "_np");
            data.add(tensor.getData());
        }
        String script = this.createInputsCode(data, pythonNames);
        return this.executeCode(script);
    }

    /**
     * Runs an inference script in the Python worker and copies the produced outputs back to Java.
     * Input and output shared memory is cleaned up on every path, including task failure.
     *
     * @param code Python script produced by {@code createInputsCode}
     * @return outputs keyed {@code output_<i>}
     * @throws RunModelException if the task is cancelled, fails or crashes, if I/O fails, or if the thread is
     *                           interrupted (the interrupt flag is restored)
     */
    private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> Map<String, RandomAccessibleInterval<R>> executeCode(String code) throws RunModelException {
        Map<String, RandomAccessibleInterval<R>> outMap = null;
        RunModelException failure = null;
        try {
            Service.Task task = this.python.task(code);
            task.waitFor();
            if (task.status == Service.TaskStatus.CANCELED) {
                failure = new RunModelException("Task canceled");
            } else if (task.status == Service.TaskStatus.FAILED || task.status == Service.TaskStatus.CRASHED) {
                failure = new RunModelException(task.error);
            } else {
                outMap = this.reconstructOutputs(task);
            }
        } catch (IOException e) {
            failure = new RunModelException(Types.stackTrace(e));
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            failure = new RunModelException(Types.stackTrace(e));
        }
        try {
            this.cleanShm();
        } catch (IOException | InterruptedException e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            // Report the original failure first; a cleanup error only surfaces when everything else succeeded.
            if (failure == null) {
                failure = new RunModelException(Types.stackTrace(e));
            }
        }
        if (failure != null) {
            throw failure;
        }
        return outMap;
    }

    /**
     * Runs the model directly on ImgLib2 images, bypassing tensor names and tiling. Each input gets a random
     * Python variable name.
     *
     * @param inputs images passed positionally to the model
     * @return the model outputs in the order they were produced
     * @throws RunModelException if the Python side fails
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<RandomAccessibleInterval<R>> inference(List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
        if (!this.loaded) {
            throw new RuntimeException("Please load the model first.");
        }
        List<String> names = new ArrayList<>(inputs.size());
        for (int k = 0; k < inputs.size(); k++) {
            names.add("var_" + UUID.randomUUID().toString().replace("-", "_"));
        }
        Map<String, RandomAccessibleInterval<R>> produced = this.executeCode(this.createInputsCode(inputs, names));
        return new ArrayList<RandomAccessibleInterval<R>>(produced.values());
    }

    /**
     * Builds the Python inference script.
     * <ol>
     * <li>Exposes every input as a NumPy array backed by a new shared-memory segment.</li>
     * <li>Calls the model with the inputs converted to torch tensors.</li>
     * <li>Resets the output bookkeeping lists and feeds the result to {@code handle_output_list}.</li>
     * <li>Publishes the output names, dtypes and shapes in {@code task.outputs}.</li>
     * </ol>
     *
     * @param rais  input images, passed positionally to the model (may be empty)
     * @param names Python variable name for each input; must be at least as long as {@code rais}
     * @return the Python script
     */
    protected <T extends RealType<T> & NativeType<T>> String createInputsCode(List<RandomAccessibleInterval<T>> rais, List<String> names) {
        String nl = System.lineSeparator();
        StringBuilder code = new StringBuilder();
        List<String> args = new ArrayList<>(rais.size());
        for (int i = 0; i < rais.size(); ++i) {
            SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(rais.get(i), false, false);
            // Register first so cleanup still finds the segment if code generation fails.
            this.inShmaList.add(shma);
            code.append(DLModelPytorchProtected.codeToConvertShmaToPython(shma, names.get(i)));
            args.add("torch.from_numpy(" + names.get(i) + ")");
        }
        code.append(OUTPUT_LIST_KEY).append(" = ").append(MODEL_VAR_NAME)
            .append("(").append(String.join(", ", args)).append(")").append(nl);
        String[] bookkeeping = {SHMS_KEY, SHM_NAMES_KEY, DTYPES_KEY, DIMS_KEY};
        for (String key : bookkeeping) {
            code.append(key).append(" = []").append(nl);
        }
        for (String key : bookkeeping) {
            code.append("globals()['").append(key).append("'] = ").append(key).append(nl);
        }
        code.append("handle_output_list(").append(OUTPUT_LIST_KEY).append(")").append(nl);
        code.append(this.taskOutputsCode());
        return code.toString();
    }

    /**
     * Python lines that copy the output names, dtypes and shapes into {@code task.outputs}, where
     * {@code reconstructOutputs} reads them.
     */
    protected String taskOutputsCode() {
        StringBuilder code = new StringBuilder();
        for (String key : new String[] {SHM_NAMES_KEY, DTYPES_KEY, DIMS_KEY}) {
            code.append("task.outputs['").append(key).append("'] = ").append(key).append(System.lineSeparator());
        }
        return code.toString();
    }

    /**
     * Runs the model tile by tile and allocates the output tensors from {@link #outputTiles}.
     *
     * @param inputTensors input tensors, tiled according to {@link #inputTiles}
     * @return newly created output tensors
     * @throws RunModelException             if the model is not loaded or inference fails
     * @throws UnsupportedOperationException if tiling is off or the tile information is missing
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> run(List<Tensor<R>> inputTensors) throws RunModelException {
        if (!this.isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            throw new UnsupportedOperationException("Cannot run a DLModel if no information about the outputs is provided. Either try with 'run( List< Tensor < T > > inTensors, List< Tensor < R > > outTensors )' or set the tiling information with 'setTileInfo(List<TileInfo> inputTiles, List<TileInfo> outputTiles)'. Another option is to run simple inference over an ImgLib2 RandomAccessibleInterval with 'inference(List<RandomAccessibleInteral<T>> input)'");
        }
        // Previously '!= null ||', which rejected every valid configuration and NPE'd on null.
        if (this.inputTiles == null || this.inputTiles.isEmpty()) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the input tiles are not well defined");
        }
        if (this.outputTiles == null || this.outputTiles.isEmpty()) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the output tiles are not well defined");
        }
        TileMaker maker = TileMaker.build(this.inputTiles, this.outputTiles);
        List<Tensor<T>> outTensors = this.createOutputTensors();
        this.runTiling(inputTensors, outTensors, maker);
        return outTensors;
    }

    /**
     * Allocates one blank output tensor per entry of {@link #outputTiles}, using that entry's name, axes order
     * and image dimensions.
     *
     * The tensors are always created with {@link FloatType}. They are cast unchecked to {@code Tensor<T>}, so
     * callers must treat the element type as FloatType whatever T is.
     */
    private <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputTensors() {
        List<Tensor<T>> outputTensors = new ArrayList<>(this.outputTiles.size());
        for (TileInfo tt : this.outputTiles) {
            Tensor<T> blank = Cast.unchecked(Tensor.buildBlankTensor(tt.getName(), tt.getImageAxesOrder(), tt.getImageDims(), new FloatType()));
            outputTensors.add(blank);
        }
        return outputTensors;
    }

    /**
     * Runs the model writing into caller-provided output tensors. Without tiling the whole input is processed
     * at once. With tiling, every output tensor must appear in the output tile information, and a non-empty
     * output tensor must already have the expected full image size.
     *
     * @throws RunModelException             if the model is not loaded or inference fails
     * @throws UnsupportedOperationException if tiling is on but the tile information is missing
     * @throws IllegalArgumentException      if an output tensor is unknown or has the wrong size
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void run(List<Tensor<T>> inTensors, List<Tensor<R>> outTensors) throws RunModelException {
        if (!this.isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            this.runNoTiles(inTensors, outTensors);
            return;
        }
        if (this.inputTiles == null || this.inputTiles.isEmpty()) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the input tiles are not well defined");
        }
        if (this.outputTiles == null || this.outputTiles.isEmpty()) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the output tiles are not well defined");
        }
        TileMaker tiles = TileMaker.build(this.inputTiles, this.outputTiles);
        // Validate every provided output tensor (previously the loop ran over the tile count instead).
        for (Tensor<R> tt : outTensors) {
            long[] expectedSize = tiles.getOutputImageSize(tt.getName());
            if (expectedSize == null) {
                throw new IllegalArgumentException("Tensor '" + tt.getName() + "' is missing in the outputs.");
            }
            // Empty tensors get filled later; non-empty ones must already match the expected size.
            if (!tt.isEmpty() && !Arrays.equals(expectedSize, tt.getData().dimensionsAsLongArray())) {
                throw new IllegalArgumentException("Tensor '" + tt.getName() + "' size is different than the expected size defined for the output image: " + Arrays.toString(tt.getData().dimensionsAsLongArray()) + " vs " + Arrays.toString(expectedSize) + ".");
            }
        }
        this.runTiling(inTensors, outTensors, tiles);
    }

    /**
     * For each tile index, extracts the n-th tile of every input and output tensor from {@code tiles} and runs
     * the model on those tiles.
     *
     * @throws RunModelException if inference on any tile fails
     */
    protected <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void runTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors, TileMaker tiles) throws RunModelException {
        for (int tileIndex = 0; tileIndex < tiles.getNumberOfTiles(); tileIndex++) {
            List<Tensor<T>> inTiles = new ArrayList<>(inputTensors.size());
            for (Tensor<R> tensor : inputTensors) {
                inTiles.add(tiles.getNthTileInput(tensor, tileIndex));
            }
            List<Tensor<R>> outTiles = new ArrayList<>(outputTensors.size());
            for (Tensor<T> tensor : outputTensors) {
                outTiles.add(tiles.getNthTileOutput(tensor, tileIndex));
            }
            this.runNoTiles(inTiles, outTiles);
        }
    }

    /**
     * Runs the model on whole tensors and writes the i-th model output into the i-th output tensor. Outputs
     * beyond the number of provided tensors are ignored.
     *
     * @throws RunModelException if inference fails or an output cannot be written into its tensor
     */
    protected <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void runNoTiles(List<Tensor<T>> inTensors, List<Tensor<R>> outTensors) throws RunModelException {
        Map<String, RandomAccessibleInterval<R>> outMap = this.predictForInputTensors(inTensors);
        int index = 0;
        for (RandomAccessibleInterval<R> rai : outMap.values()) {
            if (index >= outTensors.size()) {
                // More model outputs than tensors supplied: the extras are intentionally dropped.
                break;
            }
            Tensor<R> target = outTensors.get(index);
            try {
                target.setData(rai);
            } catch (Exception e) {
                // Previously swallowed, which also shifted later outputs into the wrong tensors.
                throw new RunModelException("Could not write model output " + index + " into tensor '" + target.getName() + "': " + Types.stackTrace(e));
            }
            index++;
        }
    }

    /**
     * Closes every input shared-memory segment in {@link #inShmaList} and then empties the list. Later calls
     * therefore neither close a segment twice nor let the list grow across inferences.
     *
     * @throws IOException if closing a segment fails; the list is cleared regardless
     */
    private void closeShm() throws IOException {
        try {
            for (SharedMemoryArray shm : this.inShmaList) {
                shm.close();
            }
        } finally {
            this.inShmaList.clear();
        }
    }

    /**
     * Releases the Java-side input segments. On Windows only, it also runs {@link #CLEAN_SHM_CODE} so the Python
     * worker closes and unlinks the output segments it created.
     */
    private void cleanShm() throws InterruptedException, IOException {
        this.closeShm();
        if (!PlatformDetection.isWindows()) {
            return;
        }
        this.python.task(CLEAN_SHM_CODE).waitFor();
    }

    /**
     * Reads the output names, dtypes and shapes published by the finished task and copies every output
     * segment into a Java-owned image.
     *
     * @return images keyed {@code output_<i>}, in the order Python reported them
     * @throws IOException if a shared-memory segment cannot be opened or closed
     */
    protected <T extends RealType<T> & NativeType<T>> Map<String, RandomAccessibleInterval<T>> reconstructOutputs(Service.Task task) throws IOException {
        this.buildOutShmList(task);
        this.buildOutDTypesList(task);
        this.buildOutDimsList(task);
        Map<String, RandomAccessibleInterval<T>> outs = new LinkedHashMap<>();
        int index = 0;
        for (String shmName : this.outShmNames) {
            RandomAccessibleInterval<T> image = this.reconstruct(shmName, this.outShmDTypes.get(index), this.outShmDims.get(index));
            outs.put("output_" + index, image);
            index++;
        }
        return outs;
    }

    /**
     * Fills {@link #outShmNames} from {@code task.outputs[SHM_NAMES_KEY]}.
     *
     * @throws RuntimeException if the entry is not a list of strings
     */
    private void buildOutShmList(Service.Task task) {
        this.outShmNames = new ArrayList<String>();
        Object reported = task.outputs.get(SHM_NAMES_KEY);
        if (!(reported instanceof List)) {
            throw new RuntimeException("Unexpected type for '" + SHM_NAMES_KEY + "'.");
        }
        for (Object elem : (List<?>) reported) {
            if (elem instanceof String) {
                this.outShmNames.add((String) elem);
                continue;
            }
            throw new RuntimeException("Unexpected type for element of  '" + SHM_NAMES_KEY + "' list.");
        }
    }

    /**
     * Fills {@link #outShmDTypes} from {@code task.outputs[DTYPES_KEY]}.
     *
     * @throws RuntimeException if the entry is not a list of strings
     */
    private void buildOutDTypesList(Service.Task task) {
        this.outShmDTypes = new ArrayList<String>();
        Object reported = task.outputs.get(DTYPES_KEY);
        if (!(reported instanceof List)) {
            throw new RuntimeException("Unexpected type for '" + DTYPES_KEY + "'.");
        }
        for (Object elem : (List<?>) reported) {
            if (elem instanceof String) {
                this.outShmDTypes.add((String) elem);
                continue;
            }
            throw new RuntimeException("Unexpected type for element of  '" + DTYPES_KEY + "' list.");
        }
    }

    /**
     * Fills {@link #outShmDims} from {@code task.outputs[DIMS_KEY]}. Each element may be an array or a list of
     * numbers. A bare number is also accepted: it covers a scalar output whose shape arrived as a plain int
     * rather than a tuple, and is treated as a 1-D shape.
     *
     * @throws RuntimeException if the entry or any element has an unexpected type
     */
    private void buildOutDimsList(Service.Task task) {
        this.outShmDims = new ArrayList<long[]>();
        Object reported = task.outputs.get(DIMS_KEY);
        if (!(reported instanceof List)) {
            throw new RuntimeException("Unexpected type for '" + DIMS_KEY + "'.");
        }
        for (Object elem : (List<?>) reported) {
            List<?> dims;
            if (elem instanceof Object[]) {
                dims = Arrays.asList((Object[]) elem);
            } else if (elem instanceof List) {
                dims = (List<?>) elem;
            } else if (elem instanceof Number) {
                dims = Arrays.asList(elem);
            } else {
                throw new RuntimeException("Unexpected type for element of  '" + DIMS_KEY + "' list.");
            }
            long[] longArr = new long[dims.size()];
            for (int i = 0; i < longArr.length; ++i) {
                Object dim = dims.get(i);
                if (!(dim instanceof Number)) {
                    throw new RuntimeException("Unexpected type for array of element of  '" + DIMS_KEY + "' list.");
                }
                longArr[i] = ((Number) dim).longValue();
            }
            this.outShmDims.add(longArr);
        }
    }

    /**
     * Opens the shared-memory segment {@code key} and returns a copy of its contents.
     *
     * @param key   name of the shared-memory segment
     * @param dtype data-type string, converted with {@code CommonUtils.getImgLib2DataType}
     * @param dims  shape of the array stored in the segment
     * @return a copy of the shared data; it stays valid after the segment is closed here
     * @throws IOException if the segment cannot be opened or closed
     */
    private <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> reconstruct(String key, String dtype, long[] dims) throws IOException {
        T type = Cast.unchecked(CommonUtils.getImgLib2DataType(dtype));
        SharedMemoryArray shm = SharedMemoryArray.readOrCreate(key, dims, type, false, false);
        try {
            RandomAccessibleInterval<T> rai = Cast.unchecked(shm.getSharedRAI());
            // Copy before closing: the shared view is no longer usable once shm is closed.
            return Cast.unchecked(Tensor.createCopyOfRaiInWantedDataType(rai, Util.getTypeFromInterval(rai)));
        } finally {
            // Close the segment even if copying fails, so it is not leaked.
            shm.close();
        }
    }

    /**
     * Builds Python source that attaches to {@code shma} and exposes it as a NumPy array.
     * <p>
     * The snippet defines {@code <varName>_shm} (a {@code shared_memory.SharedMemory} handle)
     * and {@code <varName>} (an {@code np.ndarray} over that buffer, reshaped to
     * {@code shma.getOriginalShape()}). It references {@code shared_memory}, {@code np} and
     * {@code os}, so the surrounding Python script must already have imported them.
     *
     * @param shma    the shared-memory array to expose to Python
     * @param varName Python variable name for the array; {@code varName + "_shm"} names the handle
     * @return the generated Python code, one statement per line
     */
    protected static String codeToConvertShmaToPython(SharedMemoryArray shma, String varName) {
        String nl = System.lineSeparator();
        String shmVar = varName + "_shm";
        long nElems = 1L;
        for (long elem : shma.getOriginalShape()) {
            nElems *= elem;
        }
        StringBuilder code = new StringBuilder();
        code.append(shmVar).append(" = shared_memory.SharedMemory(name='").append(shma.getNameForPython())
            .append("', size=").append(shma.getSize()).append(")").append(nl);
        code.append(varName).append(" = np.ndarray(").append(nElems).append(", dtype='")
            .append(CommonUtils.getDataTypeFromRAI((RandomAccessibleInterval) Cast.unchecked(shma.getSharedRAI())))
            .append("', buffer=").append(shmVar).append(".buf).reshape([");
        for (long elem : shma.getOriginalShape()) {
            code.append(elem).append(", ");
        }
        code.append("])").append(nl);
        // Use the handle created above. The previous hard-coded 'im_shm' raised NameError for any other varName.
        // NOTE(review): closing while the ndarray still exports the buffer may raise BufferError in CPython — confirm intended.
        code.append("if os.name == 'nt':").append(nl)
            .append("  ").append(shmVar).append(".close()").append(nl)
            .append("  ").append(shmVar).append(".unlink()").append(nl);
        return code.toString();
    }

    /**
     * Checks whether the required dependencies are installed in the default environment.
     *
     * @return {@code true} if all conda and pip dependencies are found
     */
    public static boolean isInstalled() {
        String defaultEnv = null;
        return isInstalled(defaultEnv);
    }

    /**
     * Checks whether the required conda and pip dependencies are installed in an environment.
     *
     * @param envPath environment to inspect, or {@code null} for {@code COMMON_PYTORCH_ENV_NAME}
     * @return {@code true} only if both dependency sets are found; {@code false} otherwise,
     *         including when the dependency check throws {@link MambaInstallException}
     */
    public static boolean isInstalled(String envPath) {
        final String env = (envPath == null) ? COMMON_PYTORCH_ENV_NAME : envPath;
        final Mamba mamba = new Mamba(INSTALLATION_DIR);
        try {
            // Short-circuits: the pip check only runs when the conda check passed.
            return mamba.checkAllDependenciesInEnv(env, BIAPY_CONDA_DEPS)
                    && mamba.checkAllDependenciesInEnv(env, BIAPY_PIP_DEPS);
        } catch (MambaInstallException ignored) {
            return false;
        }
    }

    /**
     * Installs the required environment without forwarding console output anywhere.
     *
     * @throws IOException, InterruptedException, MambaInstallException, ArchiveException, URISyntaxException
     *         propagated from the installation steps
     */
    public static void installRequirements() throws IOException, InterruptedException, RuntimeException, MambaInstallException, ArchiveException, URISyntaxException {
        Consumer<String> noOutput = null;
        installRequirements(noOutput);
    }

    /**
     * Ensures micromamba and the {@code COMMON_PYTORCH_ENV_NAME} environment exist with all
     * required conda and pip dependencies, creating the environment if anything is missing.
     *
     * @param consumer receives console and error output from Mamba, or {@code null} to not forward it
     * @throws IOException, InterruptedException, MambaInstallException, ArchiveException, URISyntaxException
     *         propagated from the micromamba / environment installation steps
     */
    public static void installRequirements(Consumer<String> consumer) throws IOException, InterruptedException, RuntimeException, MambaInstallException, ArchiveException, URISyntaxException {
        Mamba mamba = new Mamba(INSTALLATION_DIR);
        if (consumer != null) {
            mamba.setConsoleOutputConsumer(consumer);
            mamba.setErrorOutputConsumer(consumer);
        }
        boolean biapyPythonInstalled = false;
        try {
            // Both dependency sets must be present. Previously the pip check overwrote the
            // conda result, so missing conda deps could go undetected.
            biapyPythonInstalled = mamba.checkAllDependenciesInEnv(COMMON_PYTORCH_ENV_NAME, BIAPY_CONDA_DEPS)
                    && mamba.checkAllDependenciesInEnv(COMMON_PYTORCH_ENV_NAME, BIAPY_PIP_DEPS);
        } catch (MambaInstallException e) {
            // The check failed with an install error: install micromamba, then build the env below.
            mamba.installMicromamba();
        }
        if (!biapyPythonInstalled) {
            mamba.create(COMMON_PYTORCH_ENV_NAME, true, new ArrayList<String>(), BIAPY_CONDA_DEPS);
            ArrayList<String> args = new ArrayList<String>(BIAPY_PIP_ARGS);
            args.addAll(BIAPY_PIP_DEPS_TORCH);
            mamba.pipInstallIn(COMMON_PYTORCH_ENV_NAME, args.toArray(new String[0]));
            mamba.pipInstallIn(COMMON_PYTORCH_ENV_NAME, BIAPY_PIP_DEPS.toArray(new String[0]));
        }
    }

    /**
     * Sets the directory used as the installation location for Mamba.
     *
     * @param installationDir the new installation directory
     */
    public static void setInstallationDir(String installationDir) {
        DLModelPytorchProtected.INSTALLATION_DIR = installationDir;
    }

    /**
     * Returns the directory currently used as the installation location for Mamba.
     *
     * @return the installation directory
     */
    public static String getInstallationDir() {
        final String dir = DLModelPytorchProtected.INSTALLATION_DIR;
        return dir;
    }
}

