package io.bioimage.modelrunner.runmode;

import io.bioimage.modelrunner.apposed.appose.Appose;
import io.bioimage.modelrunner.apposed.appose.Environment;
import io.bioimage.modelrunner.apposed.appose.Service;
import io.bioimage.modelrunner.numpy.DecodeNumpy;
import io.bioimage.modelrunner.runmode.ops.OpInterface;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryFile;
import io.bioimage.modelrunner.utils.CommonUtils;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.FileAlreadyExistsException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.IntStream;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

/* loaded from: input_file:io/bioimage/modelrunner/runmode/RunMode.class */
/**
 * Executes a Python operation described by an {@link OpInterface} in a separate Python process
 * managed by Appose.
 * <p>
 * The constructor generates a Python script ({@code opCode}) that:
 * <ol>
 * <li>attaches to the shared-memory segments holding the {@link Tensor} /
 * {@link RandomAccessibleInterval} inputs,</li>
 * <li>imports the op module found in {@link OpInterface#getOpDir()},</li>
 * <li>rebuilds the image inputs as numpy arrays or xarray DataArrays,</li>
 * <li>calls {@link OpInterface#getMethodName()} and passes each output through the conversion
 * helpers of {@code RunModeScripts.TYPE_CONVERSION_METHODS_SCRIPT}.</li>
 * </ol>
 * {@link #runOP()} runs that script and rebuilds the returned values as Java objects.
 * <p>
 * NOTE(review): the generated Python unlinks each input segment right after attaching to it, and
 * {@link #runOP()} closes the Java side of the segments when it finishes. An instance is therefore
 * meant to be run once.
 */
public class RunMode {
    private static final String IMPORT_XARRAY = "t = time()" + System.lineSeparator() + "import xarray as xr" + System.lineSeparator() + "task.update('xr imported: ' + str(time() - t))" + System.lineSeparator();
    private static final String IMPORT_NUMPY = "t = time()" + System.lineSeparator() + "import numpy as np" + System.lineSeparator() + "task.update('numpy imported: ' + str(time() - t))" + System.lineSeparator();
    private static final String IMPORT_SHM = "t = time()" + System.lineSeparator() + "from multiprocessing import shared_memory" + System.lineSeparator() + "task.update('multiproc imported: ' + str(time() - t))" + System.lineSeparator();
    /** Suffix appended to an input name to build the Python variable holding its SharedMemory handle. */
    protected static final String APPOSE_SHM_KEY = ("_shm_" + UUID.randomUUID().toString()).replace("-", "_");
    /**
     * Python snippet that converts one op output into an Appose-serializable object. Every
     * {@code %s} placeholder is the output variable name. The first branch is deliberately
     * disabled with {@code and False}.
     */
    private static final String OUTPUT_REFORMATING = "if str(type(%s)) == \"<class 'xarray.core.dataarray.DataArray'>\" and False:" + System.lineSeparator() + "  %s = convertXrIntoDic_file(%s)" + System.lineSeparator() + "elif str(type(%s)) == \"<class 'xarray.core.dataarray.DataArray'>\":" + System.lineSeparator() + "  %s = convertXrIntoDic(%s)" + System.lineSeparator() + "elif isinstance(%s, np.ndarray):" + System.lineSeparator() + "  %s = convertNpIntoDic(%s)" + System.lineSeparator() + "elif isinstance(%s, list):" + System.lineSeparator() + "  %s = convertListIntoSupportedList(%s)" + System.lineSeparator() + "elif isinstance(%s, dict):" + System.lineSeparator() + "  %s = convertDicIntoDic(%s)" + System.lineSeparator();
    private static final String DEFAULT_IMPORT = "t = time()" + System.lineSeparator() + "import sys" + System.lineSeparator() + "task.update('sys imported: ' + str(time() - t))" + System.lineSeparator() + "t = time()" + System.lineSeparator() + "import os" + System.lineSeparator() + "task.update('os imported: ' + str(time() - t))" + System.lineSeparator() + IMPORT_NUMPY;
    private Environment env;
    /** NOTE(review): never assigned in this class, so {@link #envCreation()} fails as written. */
    private String envFileName;
    private String opCode;
    private OpInterface op;
    private LinkedHashMap<String, Object> apposeInputMap;
    private String moduleName;
    private String tensorRecreationCode = "";
    private String importsCode = "";
    private String opMethodCode = "";
    private String retrieveResultsCode = "";
    private String taskOutputCode = "";
    private String shmInstancesCode = "task.update('just started')" + System.lineSeparator() + "from time import time" + System.lineSeparator() + "shm_out_list = []" + System.lineSeparator() + "globals()['shm_out_list'] = shm_out_list" + System.lineSeparator() + "task.update('time imported')" + System.lineSeparator();
    private String closeShmCode = "";
    /** Shared-memory segments created for the image inputs; closed when {@link #runOP()} ends. */
    private final List<SharedMemoryArray> shmaList = new ArrayList<>();
    private final List<String> outputNames = new ArrayList<>();
    private final List<String> filesToDestroy = new ArrayList<>();

    /**
     * Generates the full Python script for the given op and prints it to stdout.
     *
     * @param opInterface the op to run; its Python file name is expected to end in ".py"
     */
    private RunMode(OpInterface opInterface) {
        this.op = opInterface;
        String opFileName = opInterface.getOpPythonFilename();
        // Strip the ".py" extension to obtain the importable module name.
        this.moduleName = opFileName.substring(0, opFileName.length() - 3);
        IntStream.range(0, opInterface.getNumberOfOutputs()).forEach(i -> this.outputNames.add("output" + i));
        addImports();
        convertInputMap();
        opExecutionCode();
        retrieveResultsCode();
        this.opCode = this.shmInstancesCode + System.lineSeparator() + this.importsCode + System.lineSeparator() + RunModeScripts.TYPE_CONVERSION_METHODS_SCRIPT + System.lineSeparator() + this.tensorRecreationCode + System.lineSeparator() + this.opMethodCode + System.lineSeparator() + this.retrieveResultsCode + System.lineSeparator() + this.taskOutputCode;
        System.out.println(this.opCode);
    }

    /**
     * Creates a runner for the given op. The Python script is generated eagerly.
     *
     * @param opInterface the op to run
     * @return a new {@code RunMode}
     */
    public static RunMode createRunMode(OpInterface opInterface) {
        return new RunMode(opInterface);
    }

    /**
     * Runs the generated script in the op's conda environment and converts its outputs into Java
     * objects.
     * <p>
     * The input shared-memory segments are closed when this method returns, whether it succeeds
     * or fails.
     *
     * @return the op outputs, keyed by output name ("output0", "output1", ...)
     * @throws IOException if the Python service cannot be started or an output cannot be read
     * @throws InterruptedException if interrupted while waiting for the Python task
     * @throws RuntimeException if the Python task does not complete successfully
     */
    public Map<String, Object> runOP() throws IOException, InterruptedException {
        this.env = new Environment() {
            @Override
            public String base() {
                return RunMode.this.op.getCondaEnv();
            }

            @Override
            public boolean useSystemPath() {
                return false;
            }
        };
        try (Service python = this.env.python()) {
            python.debug(str -> System.err.println(str));
            Service.Task task = python.task(this.opCode, this.apposeInputMap);
            task.waitFor();
            if (task.status != Service.TaskStatus.COMPLETE) {
                throw new RuntimeException("Error running Python code: " + task.error);
            }
            return recreateOutputObjects(task.outputs);
        } catch (IOException | InterruptedException | RuntimeException e) {
            // Already of a type callers expect: rethrow as-is instead of wrapping.
            throw e;
        } catch (Exception e) {
            // Any other checked exception (e.g. from closing the service).
            throw new RuntimeException(e);
        } finally {
            closeShmaList();
        }
    }

    /** Closes every input shared-memory segment, logging failures, then forgets them. */
    private void closeShmaList() {
        for (SharedMemoryArray shma : this.shmaList) {
            closeQuietly(shma);
        }
        this.shmaList.clear();
    }

    /** Closes a shared-memory array, printing, rather than propagating, any I/O failure. */
    private static void closeQuietly(SharedMemoryArray shma) {
        try {
            shma.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Rebuilds every entry of an Appose output map as a Java object. Insertion order is preserved.
     *
     * @param map raw outputs received from Python
     * @return a new map with tensors, arrays, nested maps and lists recreated
     * @throws IOException if a file-backed tensor cannot be read
     * @throws IllegalArgumentException if a value has an unsupported type
     */
    private static Map<String, Object> recreateOutputObjects(Map<String, Object> map) throws FileNotFoundException, IOException {
        Map<String, Object> recreated = new LinkedHashMap<>();
        for (Map.Entry<String, Object> entry : map.entrySet()) {
            recreated.put(entry.getKey(), recreateOutputObject(entry.getValue(), entry.getKey()));
        }
        return recreated;
    }

    /**
     * Rebuilds a single Appose output value.
     * <p>
     * Maps tagged with {@code RunModeScripts.APPOSE_DT_KEY} become:
     * <ul>
     * <li>{@link Tensor}s, for {@code "tensor"} and {@code "tensor_file"};</li>
     * <li>{@link RandomAccessibleInterval}s, for {@code "np_arr"}.</li>
     * </ul>
     * Other maps and lists are processed recursively. {@code null} is passed through, and other
     * values must be directly supported types.
     *
     * @param value the raw value
     * @param key name of the output owning the value, or {@code null} for list elements. Also
     *            used as the default tensor name.
     */
    private static Object recreateOutputObject(Object value, String key) throws FileNotFoundException, IOException {
        if (value == null) {
            return null;
        }
        if (value instanceof Map) {
            @SuppressWarnings("unchecked")
            Map<String, Object> valueMap = (Map<String, Object>) value;
            Object dataType = valueMap.get(RunModeScripts.APPOSE_DT_KEY);
            if ("tensor".equals(dataType) || "tensor_file".equals(dataType)) {
                if (key != null && valueMap.get("name") == null) {
                    valueMap.put("name", key);
                }
                return "tensor".equals(dataType)
                        ? createTensorFromApposeOutput(valueMap)
                        : createTensorFromApposeOutputFile(valueMap);
            }
            if ("np_arr".equals(dataType)) {
                return createImgLib2ArrFromApposeOutput(valueMap);
            }
            return recreateOutputObjects(valueMap);
        }
        if (value instanceof List) {
            @SuppressWarnings("unchecked")
            List<Object> valueList = (List<Object>) value;
            return createListFromApposeOutput(valueList);
        }
        if (!isTypeDirectlySupported(value.getClass())) {
            throw new IllegalArgumentException(key == null
                    ? "Type of output not supported (" + value.getClass() + ")."
                    : "Type of output named: '" + key + "' not supported (" + value.getClass() + ").");
        }
        return value;
    }

    /**
     * Builds the Appose environment from {@code envFileName}.
     * <p>
     * NOTE(review): {@code envFileName} is never set in this class, and
     * {@link #checkRequiredEnvExists()} always returns false. As written, this calls
     * {@code new File(null)}.
     */
    public void envCreation() {
        if (checkRequiredEnvExists()) {
            this.env = Appose.base(new File(this.envFileName)).build();
        } else {
            this.env = Appose.conda(new File(this.envFileName)).build();
        }
    }

    /** Placeholder check; currently always reports that the environment does not exist. */
    public boolean checkRequiredEnvExists() {
        return false;
    }

    /** Generates the import section: default modules, the op directory on sys.path, and the op module. */
    private void addImports() {
        String ls = System.lineSeparator();
        this.importsCode = DEFAULT_IMPORT
                + "t = time()" + ls
                + "sys.path.append(" + pythonStringLiteral(String.valueOf(this.op.getOpDir())) + ")" + ls
                + "task.update('extra file imported: ' + str(time() - t))" + ls
                + "t = time()" + ls
                + "t2 = time()" + ls
                + "import " + this.moduleName + ls
                + "task.update('extra module imported: ' + str(time() - t))" + ls
                + "task.update('Imports')" + ls;
    }

    /** Appends an import snippet to the import section unless it is already present. */
    private void addImportIfMissing(String importSnippet) {
        if (!this.importsCode.contains(importSnippet)) {
            this.importsCode += importSnippet;
        }
    }

    /**
     * Translates the op inputs into the map sent to Appose.
     * <p>
     * Images ({@link Tensor} / {@link RandomAccessibleInterval}) are copied into shared memory and
     * sent as {@code null} placeholders; the generated Python rebuilds them. Strings, numbers,
     * booleans, arrays of those, {@code null}, and lists whose first element is supported are
     * forwarded directly.
     *
     * @throws IllegalArgumentException if an input has an unsupported type
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private <T extends RealType<T> & NativeType<T>> void convertInputMap() {
        this.apposeInputMap = new LinkedHashMap<>();
        Map<String, Object> opInputs = this.op.getOpInputs();
        if (opInputs == null) {
            return;
        }
        for (Map.Entry<String, Object> entry : opInputs.entrySet()) {
            String inputName = entry.getKey();
            Object value = entry.getValue();
            if (value == null || value instanceof String) {
                // null is forwarded unchanged (presumably arriving as None in Python).
                this.apposeInputMap.put(inputName, value);
            } else if (value instanceof Tensor) {
                Tensor tensor = (Tensor) value;
                SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tensor.getData(), false, false);
                this.shmaList.add(shma);
                this.apposeInputMap.put(inputName, null);
                addCodeToRecreateTensor(inputName, shma, tensor);
            } else if (value instanceof RandomAccessibleInterval) {
                RandomAccessibleInterval rai = (RandomAccessibleInterval) value;
                SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(rai, false, false);
                this.shmaList.add(shma);
                this.apposeInputMap.put(inputName, null);
                addCodeToRecreateNumpyArray(inputName, shma, rai);
            } else if (isTypeDirectlySupported(value.getClass())
                    || (value.getClass().isArray() && isTypeDirectlySupported(value.getClass().getComponentType()))) {
                this.apposeInputMap.put(inputName, value);
            } else if (value instanceof List && ((List<?>) value).isEmpty()) {
                this.apposeInputMap.put(inputName, new Object[0]);
            } else if (value instanceof List && ((List<?>) value).get(0) != null
                    && isTypeDirectlySupported(((List<?>) value).get(0).getClass())) {
                // Only the first element's type is checked.
                this.apposeInputMap.put(inputName, ((List<?>) value).toArray());
            } else {
                throw new IllegalArgumentException("The type of the input argument: '" + inputName + "' is not supported (" + value.getClass() + ")");
            }
        }
    }

    /** Whether values of this class can be exchanged with Python without conversion. */
    private static boolean isTypeDirectlySupported(Class<?> cls) {
        return Number.class.isAssignableFrom(cls)
                || cls.isPrimitive()
                || String.class.isAssignableFrom(cls)
                || Boolean.class.isAssignableFrom(cls);
    }

    /**
     * Adds Python code rebuilding a tensor from a .npy file as an xarray DataArray.
     * NOTE(review): not called anywhere in this class.
     */
    private <T extends RealType<T> & NativeType<T>> void addCodeToRecreateTensorFile(String str, Tensor<T> tensor, String str2) {
        addImportIfMissing(IMPORT_XARRAY);
        addImportIfMissing(IMPORT_NUMPY);
        this.tensorRecreationCode += str + " = xr.DataArray(np.load(" + pythonStringLiteral(str2) + "), dims=["
                + pythonAxesList(tensor.getAxesOrderString()) + "], name="
                + pythonStringLiteral(String.valueOf(tensor.getName())) + ")" + System.lineSeparator();
    }

    /**
     * Adds Python code reshaping the variable {@code str} into an xarray DataArray with the
     * tensor's shape and axes. NOTE(review): not called anywhere in this class.
     */
    private <T extends RealType<T> & NativeType<T>> void addCodeToRecreateTensor(String str, Tensor<T> tensor) {
        addImportIfMissing(IMPORT_XARRAY);
        addImportIfMissing(IMPORT_NUMPY);
        StringBuilder shape = new StringBuilder();
        for (int dim : tensor.getShape()) {
            if (shape.length() > 0) {
                shape.append(", ");
            }
            shape.append(dim);
        }
        this.tensorRecreationCode += str + " = xr.DataArray(np.array(" + str + ").reshape([" + shape + "]), dims=["
                + pythonAxesList(tensor.getAxesOrderString()) + "], name="
                + pythonStringLiteral(String.valueOf(tensor.getName())) + ")" + System.lineSeparator();
    }

    /**
     * Adds Python code reshaping the variable {@code str} into a numpy array with the RAI's
     * dimensions. NOTE(review): not called anywhere in this class.
     */
    private <T extends RealType<T> & NativeType<T>> void addCodeToRecreateNumpyArray(String str, RandomAccessibleInterval<T> randomAccessibleInterval) {
        addImportIfMissing(IMPORT_NUMPY);
        this.tensorRecreationCode += str + " = np.array(" + str + ").reshape(["
                + joinDims(randomAccessibleInterval.dimensionsAsLongArray()) + "])" + System.lineSeparator();
    }

    /**
     * Adds Python code that opens the tensor's shared-memory segment and wraps it as an xarray
     * DataArray named {@code str}, with the tensor's shape, dtype, axes and name.
     */
    private <T extends RealType<T> & NativeType<T>> void addCodeToRecreateTensor(String str, SharedMemoryArray sharedMemoryArray, Tensor<T> tensor) {
        addImportIfMissing(IMPORT_XARRAY);
        addShmOpeningCode(str, sharedMemoryArray);
        long[] dims = tensor.getData().dimensionsAsLongArray();
        this.tensorRecreationCode += str + " = xr.DataArray(np.ndarray(" + numberOfElements(dims)
                + ", dtype='" + CommonUtils.getDataTypeFromRAI(tensor.getData()) + "', buffer=" + str + APPOSE_SHM_KEY
                + ".buf).reshape([" + joinDims(dims) + "]), dims=[" + pythonAxesList(tensor.getAxesOrderString())
                + "], name=" + pythonStringLiteral(String.valueOf(tensor.getName())) + ")" + System.lineSeparator();
    }

    /**
     * Adds Python code that opens the RAI's shared-memory segment and wraps it as a numpy array
     * named {@code str}, with the RAI's shape and dtype.
     */
    private <T extends RealType<T> & NativeType<T>> void addCodeToRecreateNumpyArray(String str, SharedMemoryArray sharedMemoryArray, RandomAccessibleInterval<T> randomAccessibleInterval) {
        addImportIfMissing(IMPORT_NUMPY);
        addShmOpeningCode(str, sharedMemoryArray);
        long[] dims = randomAccessibleInterval.dimensionsAsLongArray();
        this.tensorRecreationCode += str + " = np.ndarray(" + numberOfElements(dims)
                + ", dtype='" + CommonUtils.getDataTypeFromRAI(randomAccessibleInterval) + "', buffer=" + str + APPOSE_SHM_KEY
                + ".buf).reshape([" + joinDims(dims) + "])" + System.lineSeparator();
    }

    /**
     * Adds Python code attaching to a shared-memory segment as {@code varName + APPOSE_SHM_KEY}.
     * The handle is kept in {@code shm_out_list} and its name is unlinked immediately.
     */
    private void addShmOpeningCode(String varName, SharedMemoryArray shma) {
        if (!this.shmInstancesCode.contains(IMPORT_SHM)) {
            this.shmInstancesCode += IMPORT_SHM;
        }
        String ls = System.lineSeparator();
        String shmVar = varName + APPOSE_SHM_KEY;
        this.shmInstancesCode += shmVar + " = shared_memory.SharedMemory(name="
                + pythonStringLiteral(String.valueOf(shma.getNameForPython())) + ", size=" + shma.getSize() + ")" + ls
                + "shm_out_list.append(" + shmVar + ")" + ls
                + shmVar + ".unlink()" + ls;
    }

    /**
     * Generates the op call, e.g. {@code output0, output1 = module.method(a,b,)}. When the op has no
     * outputs, the call is emitted without an assignment.
     */
    private void opExecutionCode() {
        String ls = System.lineSeparator();
        StringBuilder code = new StringBuilder("task.update('method')").append(ls);
        if (!this.outputNames.isEmpty()) {
            code.append(String.join(", ", this.outputNames)).append(" = ");
        }
        code.append(this.moduleName).append('.').append(this.op.getMethodName()).append('(');
        for (String inputName : this.apposeInputMap.keySet()) {
            // A trailing comma is valid Python call syntax.
            code.append(inputName).append(',');
        }
        code.append(')').append(ls);
        this.opMethodCode = code.toString();
    }

    /** Generates the output conversion code and the assignments to {@code task.outputs}. */
    private void retrieveResultsCode() {
        String ls = System.lineSeparator();
        StringBuilder reformatting = new StringBuilder("task.update('Preparing outputs')").append(ls);
        StringBuilder taskOutputs = new StringBuilder();
        for (String outputName : this.outputNames) {
            // Every "%s" in the template is the output name; it contains no other format specifiers.
            reformatting.append(OUTPUT_REFORMATING.replace("%s", outputName));
            taskOutputs.append(String.format("task.outputs['%s'] = %s", outputName, outputName)).append(ls);
        }
        this.retrieveResultsCode = reformatting.toString();
        this.taskOutputCode = taskOutputs.toString();
    }

    /**
     * Builds a {@link Tensor} from a shared-memory output description.
     *
     * @param map description with keys "data", shape, dtype, "is_fortran", "name" and "axes"
     */
    private static <T extends RealType<T> & NativeType<T>> Tensor<T> createTensorFromApposeOutput(Map<String, Object> map) {
        String name = (String) map.get("name");
        String axes = (String) map.get("axes");
        RandomAccessibleInterval<T> data = createImgLib2ArrFromApposeOutput(map);
        return Tensor.build(name, axes, data);
    }

    /** Builds a {@link Tensor} from a file-backed output description ("name", "axes", "file_path"). */
    private static <T extends RealType<T> & NativeType<T>> Tensor<T> createTensorFromApposeOutputFile(Map<String, Object> map) throws FileNotFoundException, IOException {
        return Tensor.build((String) map.get("name"), (String) map.get("axes"), SharedMemoryFile.buildRaiFromFile((String) map.get("file_path")));
    }

    /**
     * Copies the shared-memory image described by {@code map} into a Java-owned RAI.
     * The shared-memory handle is always closed afterwards.
     *
     * @param map description with keys "data" (segment name), shape, dtype and "is_fortran"
     * @throws RuntimeException if the shared-memory segment cannot be opened
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> createImgLib2ArrFromApposeOutput(Map<String, Object> map) {
        String shmName = (String) map.get("data");
        List<?> shapeList = (List<?>) map.get(DecodeNumpy.SHAPE_KEY);
        long[] shape = new long[shapeList.size()];
        for (int i = 0; i < shape.length; i++) {
            shape[i] = ((Number) shapeList.get(i)).longValue();
        }
        String dtype = (String) map.get(DecodeNumpy.DTYPE_KEY);
        boolean isFortran = ((Boolean) map.get("is_fortran")).booleanValue();
        try {
            SharedMemoryArray shma = SharedMemoryArray.readOrCreate(shmName, shape, (RealType) Cast.unchecked(CommonUtils.getImgLib2DataType(dtype)), isFortran, false);
            try {
                RandomAccessibleInterval<T> sharedRAI = shma.getSharedRAI();
                // Copy before the segment is closed so the result does not depend on shared memory.
                return Tensor.createCopyOfRaiInWantedDataType(sharedRAI, (RealType) Util.getTypeFromInterval(sharedRAI));
            } finally {
                closeQuietly(shma);
            }
        } catch (FileAlreadyExistsException e) {
            throw new RuntimeException("Error retrieving the image from Python", e);
        }
    }

    /**
     * Rebuilds every element of an Appose output list, in order (see {@code recreateOutputObject}).
     *
     * @throws IllegalArgumentException if an element has an unsupported type
     */
    private static List<Object> createListFromApposeOutput(List<Object> list) throws FileNotFoundException, IOException {
        List<Object> recreated = new ArrayList<>(list.size());
        for (Object item : list) {
            recreated.add(recreateOutputObject(item, null));
        }
        return recreated;
    }

    /** Quotes {@code s} as a single-quoted Python string literal, escaping backslashes, quotes and newlines. */
    private static String pythonStringLiteral(String s) {
        StringBuilder sb = new StringBuilder("'");
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            switch (c) {
                case '\\':
                    sb.append("\\\\");
                    break;
                case '\'':
                    sb.append("\\'");
                    break;
                case '\n':
                    sb.append("\\n");
                    break;
                case '\r':
                    sb.append("\\r");
                    break;
                default:
                    sb.append(c);
            }
        }
        return sb.append('\'').toString();
    }

    /** Joins dimensions as "d0, d1, ..."; returns an empty string for a 0-d shape. */
    private static String joinDims(long[] dims) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < dims.length; i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(dims[i]);
        }
        return sb.toString();
    }

    /** Renders each axis character as a double-quoted Python string, comma separated. */
    private static String pythonAxesList(String axes) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < axes.length(); i++) {
            if (i > 0) {
                sb.append(", ");
            }
            sb.append('"').append(axes.charAt(i)).append('"');
        }
        return sb.toString();
    }

    /** Product of all dimensions, computed as long to avoid int overflow on large images. */
    private static long numberOfElements(long[] dims) {
        long count = 1;
        for (long dim : dims) {
            count *= dim;
        }
        return count;
    }
}
