/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.runmode;

import io.bioimage.modelrunner.apposed.appose.Appose;
import io.bioimage.modelrunner.apposed.appose.Environment;
import io.bioimage.modelrunner.apposed.appose.Service;
import io.bioimage.modelrunner.runmode.RunModeScripts;
import io.bioimage.modelrunner.runmode.ops.OpInterface;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryFile;
import io.bioimage.modelrunner.utils.CommonUtils;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.FileAlreadyExistsException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

/**
 * Executes a JDLL {@link OpInterface} in a separate Python process through Appose.
 * <p>
 * The constructor generates a complete Python script for the op. The script:
 * <ul>
 * <li>attaches to the shared-memory segments that hold the image inputs;</li>
 * <li>imports the op module from {@link OpInterface#getOpDir()};</li>
 * <li>calls the op method with every input in insertion order;</li>
 * <li>converts each result with the helpers from
 * {@link RunModeScripts#TYPE_CONVERSION_METHODS_SCRIPT};</li>
 * <li>publishes each result in {@code task.outputs}.</li>
 * </ul>
 * {@link #runOP()} runs that script and rebuilds the returned values as
 * {@link Tensor}s, ImgLib2 images, maps, lists or plain values.
 */
public class RunMode {
    /** Line separator used for every line of generated Python code. */
    private static final String NL = System.lineSeparator();
    /** Imports xarray and reports the import time through the task. */
    private static final String IMPORT_XARRAY = "t = time()" + NL
            + "import xarray as xr" + NL
            + "task.update('xr imported: ' + str(time() - t))" + NL;
    /** Imports numpy and reports the import time through the task. */
    private static final String IMPORT_NUMPY = "t = time()" + NL
            + "import numpy as np" + NL
            + "task.update('numpy imported: ' + str(time() - t))" + NL;
    /** Imports Python's shared_memory module and reports the import time. */
    private static final String IMPORT_SHM = "t = time()" + NL
            + "from multiprocessing import shared_memory" + NL
            + "task.update('multiproc imported: ' + str(time() - t))" + NL;
    /** Random suffix appended to an input name to form the Python variable holding its shared-memory handle. */
    protected static final String APPOSE_SHM_KEY = ("_shm_" + UUID.randomUUID().toString()).replace("-", "_");
    /**
     * Python snippet that converts one output variable (every {@code %1$s}) into a
     * structure Appose can send back. The first branch is deliberately disabled by
     * its trailing {@code and False}, so xarray outputs go through {@code convertXrIntoDic}.
     */
    private static final String OUTPUT_REFORMATING =
            "if str(type(%1$s)) == \"<class 'xarray.core.dataarray.DataArray'>\" and False:" + NL
            + "  %1$s = convertXrIntoDic_file(%1$s)" + NL
            + "elif str(type(%1$s)) == \"<class 'xarray.core.dataarray.DataArray'>\":" + NL
            + "  %1$s = convertXrIntoDic(%1$s)" + NL
            + "elif isinstance(%1$s, np.ndarray):" + NL
            + "  %1$s = convertNpIntoDic(%1$s)" + NL
            + "elif isinstance(%1$s, list):" + NL
            + "  %1$s = convertListIntoSupportedList(%1$s)" + NL
            + "elif isinstance(%1$s, dict):" + NL
            + "  %1$s = convertDicIntoDic(%1$s)" + NL;
    /** Imports that every generated script starts with (sys, os and numpy). */
    private static final String DEFAULT_IMPORT = "t = time()" + NL
            + "import sys" + NL
            + "task.update('sys imported: ' + str(time() - t))" + NL
            + "t = time()" + NL
            + "import os" + NL
            + "task.update('os imported: ' + str(time() - t))" + NL
            + IMPORT_NUMPY;

    private Environment env;
    /** NOTE(review): never assigned in this class, so {@link #envCreation()} receives null. */
    private String envFileName;
    /** The complete Python script executed by {@link #runOP()}. */
    private String opCode;
    private OpInterface op;
    /** Inputs sent to Appose; shared-memory inputs map to null placeholders rebound by the script. */
    private LinkedHashMap<String, Object> apposeInputMap;
    private String tensorRecreationCode = "";
    private String importsCode = "";
    private String opMethodCode = "";
    private String retrieveResultsCode = "";
    private String taskOutputCode = "";
    private String shmInstancesCode = "task.update('just started')" + NL
            + "from time import time" + NL
            + "shm_out_list = []" + NL
            + "globals()['shm_out_list'] = shm_out_list" + NL
            + "task.update('time imported')" + NL;
    private String closeShmCode = "";
    private String moduleName;
    /** Shared-memory segments created on the Java side for the inputs. */
    private List<SharedMemoryArray> shmaList = new ArrayList<SharedMemoryArray>();
    /** Python variable names of the op outputs: output0, output1, ... */
    private List<String> outputNames = new ArrayList<String>();
    private List<String> filesToDestroy = new ArrayList<String>();

    private RunMode(OpInterface op) {
        this.op = op;
        String pythonFile = op.getOpPythonFilename();
        // Drop the ".py" extension to obtain the importable module name.
        this.moduleName = pythonFile.substring(0, pythonFile.length() - 3);
        for (int i = 0; i < op.getNumberOfOutputs(); i++) {
            this.outputNames.add("output" + i);
        }
        this.addImports();
        this.convertInputMap();
        this.opExecutionCode();
        this.retrieveResultsCode();
        this.opCode = String.join(NL,
                this.shmInstancesCode,
                this.importsCode,
                RunModeScripts.TYPE_CONVERSION_METHODS_SCRIPT,
                this.tensorRecreationCode,
                this.opMethodCode,
                this.retrieveResultsCode,
                this.taskOutputCode);
        // NOTE(review): debug output of the generated script.
        System.out.println(this.opCode);
    }

    /**
     * Creates a run mode for the given op, generating its Python script immediately.
     *
     * @param op the op to execute
     * @return the run mode wrapping {@code op}
     * @throws IllegalArgumentException if one of the op inputs has an unsupported type
     */
    public static RunMode createRunMode(OpInterface op) {
        return new RunMode(op);
    }

    /**
     * Runs the generated script in the op's conda environment and converts its outputs.
     * <p>
     * On failure the Java-side shared-memory segments of the inputs are closed.
     *
     * @return the op outputs keyed by name (output0, output1, ...)
     * @throws IOException if the Python service cannot be started or an output cannot be read
     * @throws InterruptedException if interrupted while waiting for the task
     * @throws RuntimeException if the Python task does not complete successfully
     */
    public Map<String, Object> runOP() throws IOException, InterruptedException {
        this.env = new Environment() {

            @Override
            public String base() {
                return RunMode.this.op.getCondaEnv();
            }

            @Override
            public boolean useSystemPath() {
                return false;
            }
        };
        try (Service python = this.env.python()) {
            python.debug(line -> System.err.println(line));
            Service.Task task = python.task(this.opCode, this.apposeInputMap);
            task.waitFor();
            if (task.status != Service.TaskStatus.COMPLETE) {
                throw new RuntimeException("Error running Python code: " + task.error);
            }
            return RunMode.recreateOutputObjects(task.outputs);
        } catch (IOException | InterruptedException | RuntimeException e) {
            // Rethrow unchanged so a failed task is not wrapped twice.
            this.closeShmaList();
            throw e;
        } catch (Exception e) {
            this.closeShmaList();
            throw new RuntimeException(e);
        }
    }

    /** Closes every input shared-memory segment (best effort) and forgets them. */
    private void closeShmaList() {
        for (SharedMemoryArray shma : this.shmaList) {
            RunMode.closeQuietly(shma);
        }
        this.shmaList.clear();
    }

    /** Closes one shared-memory segment, printing instead of propagating a failure. */
    private static void closeQuietly(SharedMemoryArray shma) {
        try {
            shma.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Converts the values returned by Appose into JDLL/ImgLib2 objects.
     *
     * @param apposeOuts map of outputs as received from Python
     * @return a new insertion-ordered map with converted values (null values are kept)
     * @throws IllegalArgumentException if a value has an unsupported type
     */
    private static Map<String, Object> recreateOutputObjects(Map<String, Object> apposeOuts) throws FileNotFoundException, IOException {
        LinkedHashMap<String, Object> jdllOuts = new LinkedHashMap<String, Object>();
        for (Map.Entry<String, Object> entry : apposeOuts.entrySet()) {
            String key = entry.getKey();
            Object value = entry.getValue();
            if (value instanceof Map) {
                jdllOuts.put(key, RunMode.recreateMapValue(key, Cast.unchecked(value)));
            } else if (value instanceof List) {
                jdllOuts.put(key, RunMode.createListFromApposeOutput(Cast.unchecked(value)));
            } else if (value == null || RunMode.isTypeDirectlySupported(value.getClass())) {
                // A Python None arrives as null and is passed through unchanged.
                jdllOuts.put(key, value);
            } else {
                throw new IllegalArgumentException("Type of output named: '" + key + "' not supported (" + value.getClass() + ").");
            }
        }
        return jdllOuts;
    }

    /**
     * Converts one dictionary received from Python, dispatching on its
     * {@link RunModeScripts#APPOSE_DT_KEY} marker.
     *
     * @param fallbackName name given to a tensor whose dictionary has no "name"
     * @param map the dictionary; it may be mutated to add the fallback name
     * @return a Tensor, an ImgLib2 image, or a recursively converted map
     */
    private static Object recreateMapValue(String fallbackName, Map<String, Object> map) throws FileNotFoundException, IOException {
        Object dataType = map.get(RunModeScripts.APPOSE_DT_KEY);
        if ("tensor".equals(dataType)) {
            map.putIfAbsent("name", fallbackName);
            return RunMode.createTensorFromApposeOutput(map);
        }
        if ("tensor_file".equals(dataType)) {
            map.putIfAbsent("name", fallbackName);
            return RunMode.createTensorFromApposeOutputFile(map);
        }
        if ("np_arr".equals(dataType)) {
            return RunMode.createImgLib2ArrFromApposeOutput(map);
        }
        return RunMode.recreateOutputObjects(map);
    }

    /**
     * Builds the environment from {@code envFileName}.
     * NOTE(review): {@code envFileName} is never assigned here, so this currently fails with a NullPointerException.
     */
    public void envCreation() {
        if (this.checkRequiredEnvExists()) {
            this.env = Appose.base(new File(this.envFileName)).build();
            return;
        }
        this.env = Appose.conda(new File(this.envFileName)).build();
    }

    /** Placeholder check; currently always reports that the environment does not exist. */
    public boolean checkRequiredEnvExists() {
        return false;
    }

    /** Generates the import section: default imports, the op directory on sys.path and the op module. */
    private void addImports() {
        this.importsCode = DEFAULT_IMPORT
                + "t = time()" + NL
                + "sys.path.append(r'" + this.op.getOpDir() + "')" + NL
                + "task.update('extra file imported: ' + str(time() - t))" + NL
                + "t = time()" + NL
                + "t2 = time()" + NL
                + "import " + this.moduleName + NL
                + "task.update('extra module imported: ' + str(time() - t))" + NL
                + "task.update('Imports')" + NL;
    }

    /**
     * Translates the op inputs into the Appose input map.
     * <p>
     * Tensors and ImgLib2 images are copied into shared memory and rebuilt by the
     * script. Scalars, strings, booleans, arrays of those, and lists whose
     * elements are all null or directly supported are passed as-is. Lists are
     * converted to arrays. A null value is passed as Python None.
     *
     * @throws IllegalArgumentException if an input has an unsupported type
     */
    private <T extends RealType<T> & NativeType<T>> void convertInputMap() {
        this.apposeInputMap = new LinkedHashMap<String, Object>();
        Map<String, Object> inputs = this.op.getOpInputs();
        if (inputs == null) {
            return;
        }
        for (Map.Entry<String, Object> entry : inputs.entrySet()) {
            String name = entry.getKey();
            Object value = entry.getValue();
            if (value == null || value instanceof String) {
                this.apposeInputMap.put(name, value);
            } else if (value instanceof Tensor) {
                Tensor<T> tensor = Cast.unchecked(value);
                SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(tensor.getData(), false, false);
                this.shmaList.add(shma);
                // Placeholder: the generated script rebinds this name from shared memory.
                this.apposeInputMap.put(name, null);
                this.addCodeToRecreateTensor(name, shma, tensor);
            } else if (value instanceof RandomAccessibleInterval) {
                RandomAccessibleInterval<T> rai = Cast.unchecked(value);
                SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(rai, false, false);
                this.shmaList.add(shma);
                this.apposeInputMap.put(name, null);
                this.addCodeToRecreateNumpyArray(name, shma, rai);
            } else if (!value.getClass().isArray() && RunMode.isTypeDirectlySupported(value.getClass())) {
                this.apposeInputMap.put(name, value);
            } else if (value.getClass().isArray() && RunMode.isTypeDirectlySupported(value.getClass().getComponentType())) {
                this.apposeInputMap.put(name, value);
            } else if (value instanceof List && RunMode.allElementsDirectlySupported((List<?>) value)) {
                // An empty list becomes an empty Object[] exactly as toArray() yields.
                this.apposeInputMap.put(name, ((List<?>) value).toArray());
            } else {
                throw new IllegalArgumentException("The type of the input argument: '" + name + "' is not supported (" + value.getClass() + ").");
            }
        }
    }

    /** True when every element of {@code list} is null or of a directly supported type. */
    private static boolean allElementsDirectlySupported(List<?> list) {
        for (Object element : list) {
            if (element != null && !RunMode.isTypeDirectlySupported(element.getClass())) {
                return false;
            }
        }
        return true;
    }

    /** True for numbers, booleans, primitives and strings, which Appose transfers without conversion. */
    private static boolean isTypeDirectlySupported(Class<?> cl) {
        return Number.class.isAssignableFrom(cl)
                || Boolean.class.isAssignableFrom(cl)
                || cl.isPrimitive()
                || String.class.isAssignableFrom(cl);
    }

    /** Appends {@code snippet} to the import section unless it is already present. */
    private void ensureImport(String snippet) {
        if (!this.importsCode.contains(snippet)) {
            this.importsCode = this.importsCode + snippet;
        }
    }

    /** Formats values as a Python list literal, e.g. {@code [1, 2, 3]}; empty input gives {@code []}. */
    private static String toPythonList(long[] values) {
        return Arrays.stream(values)
                .mapToObj(v -> Long.toString(v))
                .collect(Collectors.joining(", ", "[", "]"));
    }

    /** Formats values as a Python list literal, e.g. {@code [1, 2, 3]}; empty input gives {@code []}. */
    private static String toPythonList(int[] values) {
        return RunMode.toPythonList(Arrays.stream(values).asLongStream().toArray());
    }

    /** Formats each axis character as a quoted Python string, e.g. {@code ["b", "y", "x"]}. */
    private static String toPythonAxesList(String axes) {
        return axes.chars()
                .mapToObj(c -> "\"" + (char) c + "\"")
                .collect(Collectors.joining(", ", "[", "]"));
    }

    /** Wraps a numpy expression into an xarray DataArray expression with the given axes and name. */
    private static String xarrayExpression(String arrayExpression, String axes, String name) {
        return "xr.DataArray(" + arrayExpression + ", dims=" + RunMode.toPythonAxesList(axes) + ", name=\"" + name + "\")";
    }

    /**
     * Python expression viewing the shared-memory buffer of input {@code ogName}
     * as a numpy array of the RAI's dtype and dimensions.
     */
    private static <T extends RealType<T> & NativeType<T>> String numpyFromSharedMemory(String ogName, RandomAccessibleInterval<T> rai) {
        long[] dims = rai.dimensionsAsLongArray();
        // long, not int: large images would otherwise overflow into a negative size.
        long size = 1;
        for (long d : dims) {
            size *= d;
        }
        return "np.ndarray(" + size + ", dtype='" + CommonUtils.getDataTypeFromRAI(rai) + "', buffer="
                + ogName + APPOSE_SHM_KEY + ".buf).reshape(" + RunMode.toPythonList(dims) + ")";
    }

    /**
     * Emits code that attaches to {@code shma} from Python. The script keeps the
     * handle in {@code shm_out_list} and unlinks the segment right after attaching.
     */
    private void addCodeToAttachSharedMemory(String ogName, SharedMemoryArray shma) {
        if (!this.shmInstancesCode.contains(IMPORT_SHM)) {
            this.shmInstancesCode = this.shmInstancesCode + IMPORT_SHM;
        }
        String shmVariable = ogName + APPOSE_SHM_KEY;
        this.shmInstancesCode = this.shmInstancesCode
                + shmVariable + " = shared_memory.SharedMemory(name='" + shma.getNameForPython() + "', size=" + shma.getSize() + ")" + NL
                + "shm_out_list.append(" + shmVariable + ")" + NL
                + shmVariable + ".unlink()" + NL;
    }

    /** Emits code rebuilding a tensor as an xarray DataArray loaded from the .npy file {@code filename}. */
    private <T extends RealType<T> & NativeType<T>> void addCodeToRecreateTensorFile(String ogName, Tensor<T> tensor, String filename) {
        this.ensureImport(IMPORT_XARRAY);
        this.ensureImport(IMPORT_NUMPY);
        this.tensorRecreationCode = this.tensorRecreationCode + ogName + " = "
                + RunMode.xarrayExpression("np.load(r'" + filename + "')", tensor.getAxesOrderString(), tensor.getName()) + NL;
    }

    /** Emits code reshaping the Python value {@code ogName} into an xarray DataArray with the tensor's shape and axes. */
    private <T extends RealType<T> & NativeType<T>> void addCodeToRecreateTensor(String ogName, Tensor<T> tensor) {
        this.ensureImport(IMPORT_XARRAY);
        this.ensureImport(IMPORT_NUMPY);
        String array = "np.array(" + ogName + ").reshape(" + RunMode.toPythonList(tensor.getShape()) + ")";
        this.tensorRecreationCode = this.tensorRecreationCode + ogName + " = "
                + RunMode.xarrayExpression(array, tensor.getAxesOrderString(), tensor.getName()) + NL;
    }

    /** Emits code reshaping the Python value {@code ogName} into a numpy array with the RAI's dimensions. */
    private <T extends RealType<T> & NativeType<T>> void addCodeToRecreateNumpyArray(String ogName, RandomAccessibleInterval<T> rai) {
        this.ensureImport(IMPORT_NUMPY);
        this.tensorRecreationCode = this.tensorRecreationCode + ogName + " = np.array(" + ogName + ").reshape("
                + RunMode.toPythonList(rai.dimensionsAsLongArray()) + ")" + NL;
    }

    /** Emits code rebuilding tensor input {@code ogName} as an xarray DataArray backed by {@code shma}. */
    private <T extends RealType<T> & NativeType<T>> void addCodeToRecreateTensor(String ogName, SharedMemoryArray shma, Tensor<T> tensor) {
        this.ensureImport(IMPORT_XARRAY);
        this.addCodeToAttachSharedMemory(ogName, shma);
        String array = RunMode.numpyFromSharedMemory(ogName, tensor.getData());
        this.tensorRecreationCode = this.tensorRecreationCode + ogName + " = "
                + RunMode.xarrayExpression(array, tensor.getAxesOrderString(), tensor.getName()) + NL;
    }

    /** Emits code rebuilding image input {@code ogName} as a numpy array backed by {@code shma}. */
    private <T extends RealType<T> & NativeType<T>> void addCodeToRecreateNumpyArray(String ogName, SharedMemoryArray shma, RandomAccessibleInterval<T> rai) {
        this.ensureImport(IMPORT_NUMPY);
        this.addCodeToAttachSharedMemory(ogName, shma);
        this.tensorRecreationCode = this.tensorRecreationCode + ogName + " = " + RunMode.numpyFromSharedMemory(ogName, rai) + NL;
    }

    /**
     * Generates the call {@code out0, out1 = module.method(in0,in1,)}, passing the
     * inputs positionally in insertion order. With no outputs the result is not assigned.
     */
    private void opExecutionCode() {
        StringBuilder code = new StringBuilder("task.update('method')").append(NL);
        if (!this.outputNames.isEmpty()) {
            code.append(String.join(", ", this.outputNames)).append(" = ");
        }
        code.append(this.moduleName).append('.').append(this.op.getMethodName()).append('(');
        for (String key : this.apposeInputMap.keySet()) {
            code.append(key).append(',');
        }
        code.append(')').append(NL);
        this.opMethodCode = code.toString();
    }

    /** Generates the output conversion code and the assignments into {@code task.outputs}. */
    private void retrieveResultsCode() {
        StringBuilder conversion = new StringBuilder("task.update('Preparing outputs')").append(NL);
        StringBuilder publishing = new StringBuilder(this.taskOutputCode);
        for (String outN : this.outputNames) {
            conversion.append(String.format(OUTPUT_REFORMATING, outN));
            publishing.append(String.format("task.outputs['%s'] = %s", outN, outN)).append(NL);
        }
        this.retrieveResultsCode = conversion.toString();
        this.taskOutputCode = publishing.toString();
    }

    /**
     * Copies an image out of the shared-memory segment described by an Appose
     * dictionary (keys "data", "shape", "dtype", "is_fortran"). The segment is
     * closed afterwards, even if the copy fails.
     */
    private static <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> copyRaiFromSharedMemory(Map<String, Object> apposeArray) {
        String shmName = (String) apposeArray.get("data");
        List<?> shape = (List<?>) apposeArray.get("shape");
        long[] longShape = new long[shape.size()];
        for (int i = 0; i < longShape.length; i++) {
            longShape[i] = ((Number) shape.get(i)).longValue();
        }
        String dtype = (String) apposeArray.get("dtype");
        boolean isFortran = Boolean.TRUE.equals(apposeArray.get("is_fortran"));
        SharedMemoryArray shm;
        try {
            T dataType = Cast.unchecked(CommonUtils.getImgLib2DataType(dtype));
            shm = SharedMemoryArray.readOrCreate(shmName, longShape, dataType, isFortran, false);
        } catch (FileAlreadyExistsException e) {
            throw new RuntimeException("Error retrieving the image from Python", e);
        }
        try {
            RandomAccessibleInterval<T> shared = Cast.unchecked(shm.getSharedRAI());
            // Copy so the result stays valid after the shared segment is closed.
            return Cast.unchecked(Tensor.createCopyOfRaiInWantedDataType(shared, Util.getTypeFromInterval(shared)));
        } finally {
            RunMode.closeQuietly(shm);
        }
    }

    /** Rebuilds a Tensor from an Appose "tensor" dictionary (also uses keys "name" and "axes"). */
    private static <T extends RealType<T> & NativeType<T>> Tensor<T> createTensorFromApposeOutput(Map<String, Object> apposeTensor) {
        String tensorName = (String) apposeTensor.get("name");
        String axes = (String) apposeTensor.get("axes");
        RandomAccessibleInterval<T> rai = RunMode.copyRaiFromSharedMemory(apposeTensor);
        return Tensor.build(tensorName, axes, rai);
    }

    /** Rebuilds a Tensor from an Appose "tensor_file" dictionary (keys "name", "axes", "file_path"). */
    private static <T extends RealType<T> & NativeType<T>> Tensor<T> createTensorFromApposeOutputFile(Map<String, Object> apposeTensor) throws FileNotFoundException, IOException {
        String tensorName = (String) apposeTensor.get("name");
        String axes = (String) apposeTensor.get("axes");
        String fileName = (String) apposeTensor.get("file_path");
        RandomAccessibleInterval<T> rai = Cast.unchecked(SharedMemoryFile.buildRaiFromFile(fileName));
        return Tensor.build(tensorName, axes, rai);
    }

    /** Rebuilds an ImgLib2 image from an Appose "np_arr" dictionary. */
    private static <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> createImgLib2ArrFromApposeOutput(Map<String, Object> apposeTensor) {
        return RunMode.copyRaiFromSharedMemory(apposeTensor);
    }

    /**
     * Converts a list received from Python element by element. Tensors, arrays,
     * nested maps and nested lists are rebuilt; a tensor without a name is named
     * after its index in the list.
     *
     * @throws IllegalArgumentException if an element has an unsupported type
     */
    private static List<Object> createListFromApposeOutput(List<Object> list) throws FileNotFoundException, IOException {
        List<Object> nList = new ArrayList<Object>(list.size());
        int index = 0;
        for (Object value : list) {
            if (value instanceof Map) {
                nList.add(RunMode.recreateMapValue(String.valueOf(index), Cast.unchecked(value)));
            } else if (value instanceof List) {
                nList.add(RunMode.createListFromApposeOutput(Cast.unchecked(value)));
            } else if (value == null || RunMode.isTypeDirectlySupported(value.getClass())) {
                nList.add(value);
            } else {
                throw new IllegalArgumentException("Type of output not supported (" + value.getClass() + ").");
            }
            index++;
        }
        return nList;
    }
}

