package io.bioimage.modelrunner.model.python;

import io.bioimage.modelrunner.apposed.appose.Environment;
import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.apposed.appose.Service;
import io.bioimage.modelrunner.apposed.appose.Types;
import io.bioimage.modelrunner.bioimageio.tiling.TileInfo;
import io.bioimage.modelrunner.bioimageio.tiling.TileMaker;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.BaseModel;
import io.bioimage.modelrunner.model.java.DLModelJava;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.CommonUtils;
import java.io.File;
import java.io.IOException;
import java.io.PrintStream;
import java.net.URISyntaxException;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import net.imglib2.Interval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;
import org.apache.commons.compress.archivers.ArchiveException;

/* loaded from: input_file:io/bioimage/modelrunner/model/python/DLModelPytorchProtected.class */
public class DLModelPytorchProtected extends BaseModel {
    /** Absolute path of the {@code .py} file that defines the model, or {@code null} when no file is used. */
    protected final String modelFile;
    /** Name of the Python callable that is imported and invoked to build the model. */
    protected final String callable;
    /** Python module the callable is imported from when {@link #modelFile} is {@code null}. */
    protected final String importModule;
    /** Absolute path of the weights file passed to {@code torch.load}. */
    protected final String weightsPath;
    /** Keyword arguments rendered into the Python call of {@link #callable}. */
    protected final Map<String, Object> kwargs;
    /** Base directory of the Python environment used to create the service. */
    protected String envPath;
    /** Python service on which every generated script is executed. */
    private Service python;
    /** Shared-memory arrays created for the inputs of inference calls. */
    protected List<SharedMemoryArray> inShmaList;
    /** Names of the output shared-memory segments, read from the last task outputs. */
    private List<String> outShmNames;
    /** Data types of the outputs, read from the last task outputs. */
    private List<String> outShmDTypes;
    /** Dimensions of the outputs, read from the last task outputs. */
    private List<long[]> outShmDims;
    /** Tiling description of the inputs; set through {@code setTileInfo}. */
    protected List<TileInfo> inputTiles;
    /** Tiling description of the outputs; set through {@code setTileInfo}. */
    protected List<TileInfo> outputTiles;
    /** Whether the tiled code path is used by the {@code run} methods. */
    protected boolean tiling;
    /** Consumer stored by {@code setTilingCounter}; it is not read in the visible part of this file. */
    protected DLModelJava.TilingConsumer tileCounter;
    /** Name of the shared Python environment used by default. */
    public static final String COMMON_PYTORCH_ENV_NAME = "biapy";
    // The constants below are assigned in the static initializer at the end of the class.
    /** Selects whether the generated load script probes the MPS backend. */
    protected static final boolean IS_ARM;
    /** Conda dependencies of the environment. */
    private static final List<String> BIAPY_CONDA_DEPS;
    /** Platform-specific torch packages installed with pip. */
    private static final List<String> BIAPY_PIP_DEPS_TORCH;
    /** Remaining pip packages of the environment. */
    private static final List<String> BIAPY_PIP_DEPS;
    /** Extra pip arguments used when installing the torch packages. */
    private static final List<String> BIAPY_PIP_ARGS;
    /** Directory handed to {@link Mamba}; changeable through {@code setInstallationDir}. */
    protected static String INSTALLATION_DIR;
    /** Name of the Python variable that holds the loaded model. */
    protected static final String MODEL_VAR_NAME;
    /** Format template (five {@code %s} slots) for the import preamble of the load script. */
    protected static final String LOAD_MODEL_CODE_ABSTRACT;
    /** Python variable name that receives the raw model outputs. */
    protected static final String OUTPUT_LIST_KEY;
    /** Python variable name of the list initialised before {@code handle_output_list} is called. */
    protected static final String SHMS_KEY;
    /** Key under which the output shared-memory names are returned by a task. */
    protected static final String SHM_NAMES_KEY;
    /** Key under which the output data types are returned by a task. */
    protected static final String DTYPES_KEY;
    /** Key under which the output dimensions are returned by a task. */
    protected static final String DIMS_KEY;
    /** Python code appended to the load script (presumably defines {@code handle_output_list} - TODO confirm). */
    protected static final String RECOVER_OUTPUTS_CODE;
    /** Python code executed on Windows after each inference to clean up shared memory. */
    private static final String CLEAN_SHM_CODE;
    /** Replacement used for {@code '+'} characters in model file paths. */
    private static final String JDLL_UUID;

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates a model wrapper with the {@code z} flag of the six-argument constructor set to
     * {@code false}.
     *
     * @param str path to the {@code .py} file that defines the model
     * @param str2 name of the Python callable that builds the model
     * @param str3 Python module to import the callable from
     * @param str4 path to the weights file
     * @param map keyword arguments passed to the callable
     * @throws IOException if the Python service cannot be created
     */
    public DLModelPytorchProtected(String str, String str2, String str3, String str4, Map<String, Object> map) throws IOException {
        this(str, str2, str3, str4, map, false);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public DLModelPytorchProtected(String str, String str2, String str3, String str4, Map<String, Object> map, boolean z) throws IOException {
        this.inShmaList = new ArrayList();
        this.tiling = false;
        if (!z && ((!new File(str).isFile() || !str.endsWith(".py")) && str3 == null)) {
            throw new IllegalArgumentException("The model file does not correspond to an existing .py file.");
        }
        if (!new File(str4).isFile() || (!z && !str4.endsWith(".pt") && !str4.endsWith(".pth"))) {
            throw new IllegalArgumentException("The weights file does not correspond to an existing .pt/.pth file.");
        }
        this.callable = str2;
        if (z || str == null || !new File(str).isFile()) {
            this.modelFile = null;
        } else {
            this.modelFile = new File(str).getAbsolutePath();
        }
        if (z || str3 == null) {
            this.importModule = null;
        } else {
            this.importModule = str3;
        }
        if (new File(str4).isFile()) {
            this.modelFolder = new File(str4).getParentFile().getAbsolutePath();
        } else if (new File(str).isFile()) {
            this.modelFolder = new File(str).getParentFile().getAbsolutePath();
        }
        this.weightsPath = new File(str4).getAbsolutePath();
        this.kwargs = map;
        this.envPath = INSTALLATION_DIR + File.separator + Mamba.ENVS_NAME + File.separator + COMMON_PYTORCH_ENV_NAME;
        createPythonService();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Creates the Python service for the environment located at {@link #envPath} and forwards
     * its debug output to {@code System.err}.
     *
     * @throws IOException if the service cannot be created
     */
    public void createPythonService() throws IOException {
        final Environment environment = new Environment() {
            @Override
            public String base() {
                return DLModelPytorchProtected.this.envPath;
            }
        };
        this.python = environment.python();
        this.python.debug(System.err::println);
    }

    /**
     * @return the base directory of the Python environment currently in use
     */
    public String getEnvPath() {
        return envPath;
    }

    /**
     * Switches to another Python environment: the current service is closed and a new one is
     * created for {@code str}.
     *
     * <p>NOTE(review): the {@code loaded} flag is not touched here, so a model that was loaded
     * in the previous service is not reloaded into the new one - confirm that callers invoke
     * this method before {@code loadModel()}.
     *
     * @param str base directory of the Python environment to use
     * @throws IOException if the new service cannot be created
     */
    public void setCustomEnvPath(String str) throws IOException {
        envPath = str;
        python.close();
        createPythonService();
    }

    /**
     * @return {@code true} if the {@code run} methods use the tiled code path
     */
    public boolean isTiling() {
        return tiling;
    }

    /**
     * Enables or disables the tiled code path of the {@code run} methods.
     *
     * @param z {@code true} to enable tiling
     */
    public void setTiling(boolean z) {
        tiling = z;
    }

    /**
     * Stores the tile descriptions and, as a side effect, enables tiling.
     *
     * @param list tile information for the input tensors
     * @param list2 tile information for the output tensors
     */
    public void setTileInfo(List<TileInfo> list, List<TileInfo> list2) {
        tiling = true;
        inputTiles = list;
        outputTiles = list2;
    }

    /**
     * Stores the consumer intended to track tiling progress.
     *
     * @param tilingConsumer consumer to store; replaces any previous one
     */
    public void setTilingCounter(DLModelJava.TilingConsumer tilingConsumer) {
        tileCounter = tilingConsumer;
    }

    /**
     * Loads the model in the Python service by running the script produced by
     * {@link #buildModelCode()} followed by {@code RECOVER_OUTPUTS_CODE}. Does nothing if the
     * model is already loaded.
     *
     * @throws LoadModelException if building the script fails with an I/O error or the wait for
     *         the task is interrupted (the thread's interrupt flag is restored in that case)
     * @throws RuntimeException if the model was already closed, or the Python task was canceled,
     *         failed or crashed
     */
    @Override // io.bioimage.modelrunner.model.BaseModel
    public void loadModel() throws LoadModelException {
        if (this.loaded) {
            return;
        }
        if (this.closed) {
            throw new RuntimeException("Cannot load model after it has been closed");
        }
        try {
            Service.Task task = this.python.task(buildModelCode() + RECOVER_OUTPUTS_CODE);
            task.waitFor();
            if (task.status == Service.TaskStatus.CANCELED) {
                throw new RuntimeException("Task canceled");
            }
            if (task.status == Service.TaskStatus.FAILED || task.status == Service.TaskStatus.CRASHED) {
                throw new RuntimeException(task.error);
            }
            this.loaded = true;
        } catch (IOException e) {
            throw new LoadModelException(Types.stackTrace(e));
        } catch (InterruptedException e) {
            // Keep the interruption visible to callers further up the stack.
            Thread.currentThread().interrupt();
            throw new LoadModelException(Types.stackTrace(e));
        }
    }

    /**
     * Copies the bytes of file {@code str} into {@code str2}, unless {@code str2} already exists
     * as a regular file, in which case nothing is done.
     *
     * @param str path of the file to read
     * @param str2 path of the file to create
     * @throws IOException if reading or writing fails
     */
    private static void copyAndReplace(String str, String str2) throws IOException {
        final boolean alreadyCopied = new File(str2).isFile();
        if (!alreadyCopied) {
            Files.write(Paths.get(str2), Files.readAllBytes(Paths.get(str)));
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Builds the Python script that imports torch, selects the device, imports the model
     * callable, instantiates it with the keyword arguments, loads the weights and stores the
     * model in the interpreter globals under {@code MODEL_VAR_NAME}.
     *
     * <p>If the model file name contains {@code '+'}, the file is first copied to a path where
     * every {@code '+'} is replaced by {@code JDLL_UUID}, and that copy is imported instead.
     *
     * <p>The MPS probe is emitted at top level. It was previously emitted with a deeper
     * indentation right after the {@code import torch} block, which Python rejects as an
     * unexpected indent.
     *
     * @return the Python source to execute
     * @throws IOException if copying the model file fails
     */
    public String buildModelCode() throws IOException {
        final String nl = System.lineSeparator();
        StringBuilder code = new StringBuilder();
        code.append("device = 'cpu'").append(nl)
                .append("if 'torch' not in globals().keys():").append(nl)
                .append("  import torch").append(nl)
                .append("  globals()['torch'] = torch").append(nl);
        if (IS_ARM) {
            // Runs on every load, whether torch was imported just now or by an earlier task.
            code.append("if torch.backends.mps.is_built() and torch.backends.mps.is_available():").append(nl)
                    .append("  device = 'mps'").append(nl);
        }
        code.append("globals()['device'] = device").append(nl);

        String addPath = "";
        String importLine;
        if (this.modelFile != null) {
            String sourcePath = this.modelFile;
            String originalName = new File(this.modelFile).getName();
            // Strip the 3-character ".py" suffix to obtain the module name.
            if (originalName.substring(0, originalName.length() - 3).contains("+")) {
                sourcePath = this.modelFile.replaceAll("\\+", JDLL_UUID);
                copyAndReplace(this.modelFile, sourcePath);
            }
            File sourceFile = new File(sourcePath);
            String fileName = sourceFile.getName();
            String moduleName = fileName.substring(0, fileName.length() - 3);
            addPath = String.format("sys.path.append(os.path.abspath(r'%s'))", sourceFile.getParentFile().getAbsolutePath());
            importLine = String.format("from %s import %s", moduleName, this.callable);
        } else {
            importLine = String.format("from %s import %s", this.importModule, this.callable);
        }

        code.append(String.format(LOAD_MODEL_CODE_ABSTRACT, addPath, importLine, this.callable, this.callable, this.callable));
        code.append(MODEL_VAR_NAME).append("=").append(this.callable).append("(").append(codeForKwargs()).append(").to(device)").append(nl);
        // First try the model's own device attribute; fall back to the selected device.
        code.append("try:").append(nl)
                .append("  ").append(MODEL_VAR_NAME).append(".load_state_dict(torch.load(r'").append(this.weightsPath)
                .append("', map_location=").append(MODEL_VAR_NAME).append(".device))").append(nl)
                .append("except:").append(nl)
                .append("  ").append(MODEL_VAR_NAME).append(".load_state_dict(torch.load(r'").append(this.weightsPath)
                .append("', map_location=torch.device(device)))").append(nl);
        code.append("globals()['").append(MODEL_VAR_NAME).append("'] = ").append(MODEL_VAR_NAME).append(nl);
        return code.toString();
    }

    /**
     * Renders a Java list as a Python list literal. {@code null} becomes {@code None}; booleans
     * and the strings {@code "true"}/{@code "false"} become {@code True}/{@code False}; other
     * strings are wrapped in double quotes; nested lists and maps are rendered recursively;
     * anything else is appended through its string representation. Every element is followed
     * by a comma.
     *
     * @param list elements to render
     * @return the Python list literal
     */
    private String codeForKwargsList(List<Object> list) {
        StringBuilder rendered = new StringBuilder("[");
        for (Object element : list) {
            if (element == null) {
                rendered.append("None");
            } else if (Boolean.TRUE.equals(element) || element.equals("true")) {
                rendered.append("True");
            } else if (Boolean.FALSE.equals(element) || element.equals("false")) {
                rendered.append("False");
            } else if (element instanceof String) {
                rendered.append("\"").append(element).append("\"");
            } else if (element instanceof List) {
                rendered.append(codeForKwargsList((List) element));
            } else if (element instanceof Map) {
                rendered.append(codeForKwargsMap((Map) element));
            } else {
                rendered.append(element);
            }
            rendered.append(",");
        }
        return rendered.append("]").toString();
    }

    /**
     * Renders a Java map as a Python dict literal with single-quoted keys. Values follow the
     * same rules as list elements: {@code null} becomes {@code None}; booleans and the strings
     * {@code "true"}/{@code "false"} become {@code True}/{@code False}; other strings are
     * wrapped in double quotes; nested lists and maps are rendered recursively. Every entry is
     * followed by a comma.
     *
     * @param map entries to render
     * @return the Python dict literal
     */
    private String codeForKwargsMap(Map<String, Object> map) {
        StringBuilder rendered = new StringBuilder("{");
        for (Map.Entry<String, Object> item : map.entrySet()) {
            Object element = item.getValue();
            rendered.append("'").append(item.getKey()).append("':");
            if (element == null) {
                rendered.append("None");
            } else if (Boolean.TRUE.equals(element) || element.equals("true")) {
                rendered.append("True");
            } else if (Boolean.FALSE.equals(element) || element.equals("false")) {
                rendered.append("False");
            } else if (element instanceof String) {
                rendered.append("\"").append(element).append("\"");
            } else if (element instanceof List) {
                rendered.append(codeForKwargsList((List) element));
            } else if (element instanceof Map) {
                rendered.append(codeForKwargsMap((Map) element));
            } else {
                rendered.append(element);
            }
            rendered.append(",");
        }
        return rendered.append("}").toString();
    }

    /**
     * Renders {@link #kwargs} as Python keyword arguments ({@code key=value,} for every entry).
     * {@code null} becomes {@code None}; booleans and the strings {@code "true"}/{@code "false"}
     * become {@code True}/{@code False}; other strings are wrapped in double quotes; lists and
     * maps are rendered as Python literals.
     *
     * @return the argument list, or an empty string when there are no keyword arguments
     *         (including when {@link #kwargs} is {@code null})
     */
    private String codeForKwargs() {
        // The constructor stores the map unchecked, so a model built without kwargs may hold null.
        if (this.kwargs == null) {
            return "";
        }
        StringBuilder args = new StringBuilder();
        for (Map.Entry<String, Object> entry : this.kwargs.entrySet()) {
            Object value = entry.getValue();
            if (value == null) {
                value = "None";
            } else if (Boolean.TRUE.equals(value) || value.equals("true")) {
                value = "True";
            } else if (Boolean.FALSE.equals(value) || value.equals("false")) {
                value = "False";
            } else if (value instanceof String) {
                value = "\"" + value + "\"";
            } else if (value instanceof List) {
                value = codeForKwargsList((List) value);
            } else if (value instanceof Map) {
                value = codeForKwargsMap((Map) value);
            }
            args.append(entry.getKey()).append("=").append(value).append(",");
        }
        return args.toString();
    }

    /**
     * Closes the Python service and marks the model as closed. Nothing happens when the model
     * is not loaded.
     */
    @Override // io.bioimage.modelrunner.model.BaseModel, java.io.Closeable, java.lang.AutoCloseable
    public void close() {
        if (!this.loaded) {
            return;
        }
        this.python.close();
        this.loaded = false;
        this.closed = true;
    }

    /**
     * Runs the model on the data of the given tensors. The Python variable of each input is
     * named after its tensor with the suffix {@code _np}.
     *
     * @param list input tensors
     * @return the outputs reconstructed from the Python task
     * @throws RunModelException if the inference fails
     * @throws RuntimeException if the model has not been loaded
     */
    private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> Map<String, RandomAccessibleInterval<R>> predictForInputTensors(List<Tensor<T>> list) throws RunModelException {
        if (!this.loaded) {
            throw new RuntimeException("Please load the model first.");
        }
        List<RandomAccessibleInterval<T>> images = list.stream()
                .map(tensor -> tensor.getData())
                .collect(Collectors.toList());
        List<String> variableNames = list.stream()
                .map(tensor -> tensor.getName() + "_np")
                .collect(Collectors.toList());
        return executeCode(createInputsCode(images, variableNames));
    }

    /**
     * Executes an inference script on the Python service, rebuilds the outputs and releases
     * the input shared memory.
     *
     * <p>The shared memory is released on every exit path, including when the task is canceled,
     * fails or crashes. If the thread is interrupted, its interrupt flag is restored after the
     * cleanup.
     *
     * @param str Python code to run
     * @return the outputs reconstructed from the task
     * @throws RunModelException on I/O errors or interruption; if the cleanup itself fails,
     *         the cleanup error is the one reported
     * @throws RuntimeException if the task is canceled, fails or crashes
     */
    private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> Map<String, RandomAccessibleInterval<R>> executeCode(String str) throws RunModelException {
        try {
            Service.Task task = this.python.task(str);
            task.waitFor();
            if (task.status == Service.TaskStatus.CANCELED) {
                throw new RuntimeException("Task canceled");
            }
            if (task.status == Service.TaskStatus.FAILED || task.status == Service.TaskStatus.CRASHED) {
                throw new RuntimeException(task.error);
            }
            this.loaded = true;
            Map<String, RandomAccessibleInterval<R>> outputs = reconstructOutputs(task);
            cleanShm();
            return outputs;
        } catch (IOException | InterruptedException e) {
            boolean interrupted = e instanceof InterruptedException;
            RunModelException failure = new RunModelException(Types.stackTrace(e));
            try {
                cleanShm();
            } catch (IOException | InterruptedException cleanupError) {
                interrupted = interrupted || cleanupError instanceof InterruptedException;
                failure = new RunModelException(Types.stackTrace(cleanupError));
            }
            // Restore the flag only after the cleanup, which may itself wait on a Python task.
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
            throw failure;
        } catch (RuntimeException e) {
            // Task-level failures bypass the catch above; the input shared memory must still go.
            try {
                cleanShm();
            } catch (IOException cleanupError) {
                e.addSuppressed(cleanupError);
            } catch (InterruptedException cleanupError) {
                e.addSuppressed(cleanupError);
                Thread.currentThread().interrupt();
            }
            throw e;
        }
    }

    /**
     * Runs the model directly on images. Each input gets a random Python variable name of the
     * form {@code var_<uuid>}; the outputs are returned in the order of the output map.
     *
     * @param list input images
     * @return the output images
     * @throws RunModelException if the inference fails
     * @throws RuntimeException if the model has not been loaded
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<RandomAccessibleInterval<R>> inference(List<RandomAccessibleInterval<T>> list) throws RunModelException {
        if (!this.loaded) {
            throw new RuntimeException("Please load the model first.");
        }
        List<String> variableNames = new ArrayList<>();
        for (int k = 0; k < list.size(); k++) {
            variableNames.add("var_" + UUID.randomUUID().toString().replace("-", "_"));
        }
        Map<String, RandomAccessibleInterval<R>> outputs = executeCode(createInputsCode(list, variableNames));
        return new ArrayList<>(outputs.values());
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Copies every input image into shared memory and builds the Python script that wraps the
     * segments as numpy arrays, calls the model on them and publishes the outputs.
     *
     * <p>Side effect: each created shared-memory array is appended to {@link #inShmaList}.
     *
     * @param list input images
     * @param list2 Python variable names, one per input image
     * @return the Python source to execute
     */
    public <T extends RealType<T> & NativeType<T>> String createInputsCode(List<RandomAccessibleInterval<T>> list, List<String> list2) {
        final String nl = System.lineSeparator();
        final StringBuilder script = new StringBuilder();
        script.append("created_shms = []").append(nl).append("try:").append(nl);
        for (int idx = 0; idx < list.size(); idx++) {
            SharedMemoryArray inputShm = SharedMemoryArray.createSHMAFromRAI(list.get(idx), false, false);
            script.append(codeToConvertShmaToPython(inputShm, list2.get(idx)));
            this.inShmaList.add(inputShm);
        }

        // Model call: every input is converted to a torch tensor on the selected device.
        script.append("  ").append(OUTPUT_LIST_KEY).append(" = ").append(MODEL_VAR_NAME).append("(");
        for (int idx = 0; idx < list.size(); idx++) {
            script.append("torch.from_numpy(").append(list2.get(idx)).append(").to(device), ");
        }
        // Drop the last two characters (the trailing ", " separator) before closing the call.
        script.setLength(script.length() - 2);
        script.append(")").append(nl);

        final String[] resultKeys = {SHMS_KEY, SHM_NAMES_KEY, DTYPES_KEY, DIMS_KEY};
        for (String key : resultKeys) {
            script.append("  ").append(key).append(" = []").append(nl);
        }
        for (String key : resultKeys) {
            script.append("  globals()['").append(key).append("'] = ").append(key).append(nl);
        }
        script.append("  handle_output_list(").append(OUTPUT_LIST_KEY).append(")").append(nl);

        // On Windows the Python side closes its shared-memory handles on success and on failure.
        final String winCleanup = closeSHMWin();
        script.append("  ").append(winCleanup).append(nl)
                .append("except Exception as e:").append(nl)
                .append("  ").append(winCleanup).append(nl)
                .append("  raise e").append(nl)
                .append(taskOutputsCode());
        return script.toString();
    }

    /**
     * @return on Windows, the Python statement that closes and unlinks every segment in
     *         {@code created_shms}; an empty string on other platforms
     */
    private static String closeSHMWin() {
        if (PlatformDetection.isWindows()) {
            return "[(shm_i.close(), shm_i.unlink()) for shm_i in created_shms]";
        }
        return "";
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Builds the Python lines that copy the shared-memory names, data types and dimensions of
     * the outputs into {@code task.outputs}.
     *
     * @return the Python source, one assignment per line
     */
    public String taskOutputsCode() {
        final StringBuilder lines = new StringBuilder();
        for (String key : new String[] {SHM_NAMES_KEY, DTYPES_KEY, DIMS_KEY}) {
            lines.append("task.outputs['").append(key).append("'] = ").append(key).append(System.lineSeparator());
        }
        return lines.toString();
    }

    /**
     * Runs the model tile by tile and returns newly created output tensors. Requires tiling
     * information, because the outputs are allocated from the output tile descriptions.
     *
     * @param list input tensors
     * @return output tensors, one per output tile description
     * @throws RunModelException if the model is not loaded or the inference fails
     * @throws UnsupportedOperationException if tiling is disabled, or the input or output
     *         tiles are missing or empty
     */
    @Override // io.bioimage.modelrunner.model.BaseModel
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> run(List<Tensor<R>> list) throws RunModelException {
        if (!isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            throw new UnsupportedOperationException("Cannot run a DLModel if no information about the outputs is provided. Either try with 'run( List< Tensor < T > > inTensors, List< Tensor < R > > outTensors )' or set the tiling information with 'setTileInfo(List<TileInfo> inputTiles, List<TileInfo> outputTiles)'. Another option is to run simple inference over an ImgLib2 RandomAccessibleInterval with 'inference(List<RandomAccessibleInteral<T>> input)'");
        }
        // The input check must reject null or empty tiles, mirroring the output check below.
        if (isTiling() && (this.inputTiles == null || this.inputTiles.size() == 0)) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the input tiles are not well defined");
        }
        if (isTiling() && (this.outputTiles == null || this.outputTiles.size() == 0)) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the output tiles are not well defined");
        }
        TileMaker build = TileMaker.build(this.inputTiles, this.outputTiles);
        List<Tensor<T>> createOutputTensors = createOutputTensors();
        runTiling(list, createOutputTensors, build);
        return createOutputTensors;
    }

    /**
     * Creates one blank {@code FloatType} tensor per output tile description, using the name,
     * axes order and image dimensions of the description.
     *
     * @return the blank output tensors, in the order of {@link #outputTiles}
     */
    private <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputTensors() {
        // Raw list on purpose: the tensors are FloatType while the declared element type is generic.
        ArrayList blankTensors = new ArrayList();
        Iterator<TileInfo> tiles = this.outputTiles.iterator();
        while (tiles.hasNext()) {
            TileInfo tile = tiles.next();
            blankTensors.add(Tensor.buildBlankTensor(tile.getName(), tile.getImageAxesOrder(), tile.getImageDims(), new FloatType()));
        }
        return blankTensors;
    }

    /**
     * Runs the model and writes the results into the given output tensors. Without tiling the
     * inputs are processed in a single pass; with tiling every output tensor is first checked
     * against the image size expected by the tiling, then the tiles are processed one by one.
     *
     * @param list input tensors
     * @param list2 output tensors to fill
     * @throws RunModelException if the model is not loaded or the inference fails
     * @throws UnsupportedOperationException if tiling is enabled but the input or output tiles
     *         are missing or empty
     * @throws IllegalArgumentException if an output tensor is unknown to the tiling, or is
     *         non-empty and its size differs from the expected output size
     */
    @Override // io.bioimage.modelrunner.model.BaseModel
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void run(List<Tensor<T>> list, List<Tensor<R>> list2) throws RunModelException {
        if (!isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            runNoTiles(list, list2);
            return;
        }
        // The input check must reject null or empty tiles, mirroring the output check below.
        if (isTiling() && (this.inputTiles == null || this.inputTiles.size() == 0)) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the input tiles are not well defined");
        }
        if (isTiling() && (this.outputTiles == null || this.outputTiles.size() == 0)) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the output tiles are not well defined");
        }
        TileMaker build = TileMaker.build(this.inputTiles, this.outputTiles);
        // Validate each output tensor; the number of tiles is unrelated to the number of tensors.
        for (Tensor<R> tensor : list2) {
            long[] outputImageSize = build.getOutputImageSize(tensor.getName());
            if (outputImageSize == null) {
                throw new IllegalArgumentException("Tensor '" + tensor.getName() + "' is missing in the outputs.");
            }
            if (!tensor.isEmpty() && !Arrays.equals(outputImageSize, tensor.getData().dimensionsAsLongArray())) {
                throw new IllegalArgumentException("Tensor '" + tensor.getName() + "' size is different than the expected size defined for the output image: " + Arrays.toString(tensor.getData().dimensionsAsLongArray()) + " vs " + Arrays.toString(outputImageSize) + ".");
            }
        }
        runTiling(list, list2, build);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Runs the model once per tile: for every tile index the matching input and output tiles
     * are obtained from the tile maker and processed with {@code runNoTiles}.
     *
     * @param list input tensors
     * @param list2 output tensors
     * @param tileMaker provides the number of tiles and the per-tile tensors
     * @throws RunModelException if the inference on a tile fails
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void runTiling(List<Tensor<R>> list, List<Tensor<T>> list2, TileMaker tileMaker) throws RunModelException {
        for (int tile = 0; tile < tileMaker.getNumberOfTiles(); tile++) {
            final int tileIndex = tile;
            List<Tensor<R>> tileInputs = list.stream()
                    .map(tensor -> tileMaker.getNthTileInput(tensor, tileIndex))
                    .collect(Collectors.toList());
            List<Tensor<T>> tileOutputs = list2.stream()
                    .map(tensor -> tileMaker.getNthTileOutput(tensor, tileIndex))
                    .collect(Collectors.toList());
            runNoTiles(tileInputs, tileOutputs);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Runs the model on the inputs in one pass and stores the results in the output tensors,
     * in iteration order of the result map.
     *
     * <p>Assignment is best effort: when storing a result fails, the failure is ignored and
     * the same output tensor is tried again with the next result.
     *
     * @param list input tensors
     * @param list2 output tensors to fill
     * @throws RunModelException if the inference fails
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void runNoTiles(List<Tensor<T>> list, List<Tensor<R>> list2) throws RunModelException {
        Map<String, RandomAccessibleInterval<R>> predictions = predictForInputTensors(list);
        int target = 0;
        for (RandomAccessibleInterval<R> result : predictions.values()) {
            try {
                list2.get(target).setData(result);
                target++;
            } catch (Exception ignored) {
                // Deliberately ignored: the target index is not advanced on failure.
            }
        }
    }

    /**
     * Closes the shared-memory arrays created for the inputs and forgets them.
     *
     * <p>Each array is removed from {@link #inShmaList} right after it is closed, so the list
     * no longer grows with every inference and segments are not closed again on later calls.
     * If a close fails, that array and the remaining ones stay in the list.
     *
     * @throws IOException if closing an array fails
     */
    private void closeShm() throws IOException {
        Iterator<SharedMemoryArray> it = this.inShmaList.iterator();
        while (it.hasNext()) {
            it.next().close();
            it.remove();
        }
    }

    /**
     * Closes the input shared memory on the Java side and, on Windows only, runs
     * {@code CLEAN_SHM_CODE} in the Python service and waits for it to finish.
     *
     * @throws InterruptedException if waiting for the Python task is interrupted
     * @throws IOException if closing the shared memory or creating the task fails
     */
    private void cleanShm() throws InterruptedException, IOException {
        closeShm();
        if (!PlatformDetection.isWindows()) {
            return;
        }
        Service.Task cleanup = this.python.task(CLEAN_SHM_CODE);
        cleanup.waitFor();
    }

    /**
     * Reads the shared-memory names, data types and dimensions from the task outputs and
     * rebuilds every output image. The results are keyed {@code output_0}, {@code output_1},
     * ... in insertion order.
     *
     * @param task finished Python task
     * @return the reconstructed outputs
     * @throws IOException if reading a shared-memory segment fails
     */
    protected <T extends RealType<T> & NativeType<T>> Map<String, RandomAccessibleInterval<T>> reconstructOutputs(Service.Task task) throws IOException {
        buildOutShmList(task);
        buildOutDTypesList(task);
        buildOutDimsList(task);
        Map<String, RandomAccessibleInterval<T>> outputs = new LinkedHashMap<>();
        int index = 0;
        for (String shmName : this.outShmNames) {
            RandomAccessibleInterval<T> image = reconstruct(shmName, this.outShmDTypes.get(index), this.outShmDims.get(index));
            outputs.put("output_" + index, image);
            index++;
        }
        return outputs;
    }

    /**
     * Resets {@link #outShmNames} and fills it from the {@code SHM_NAMES_KEY} entry of the task
     * outputs.
     *
     * @param task finished Python task
     * @throws RuntimeException if the entry is not a list of strings
     */
    private void buildOutShmList(Service.Task task) {
        this.outShmNames = new ArrayList<>();
        Object reported = task.outputs.get(SHM_NAMES_KEY);
        if (!(reported instanceof List)) {
            throw new RuntimeException("Unexpected type for '" + SHM_NAMES_KEY + "'.");
        }
        for (Object name : (List<?>) reported) {
            if (name instanceof String) {
                this.outShmNames.add((String) name);
            } else {
                throw new RuntimeException("Unexpected type for element of  '" + SHM_NAMES_KEY + "' list.");
            }
        }
    }

    /**
     * Resets {@link #outShmDTypes} and fills it from the {@code DTYPES_KEY} entry of the task
     * outputs.
     *
     * @param task finished Python task
     * @throws RuntimeException if the entry is not a list of strings
     */
    private void buildOutDTypesList(Service.Task task) {
        this.outShmDTypes = new ArrayList<>();
        Object reported = task.outputs.get(DTYPES_KEY);
        if (!(reported instanceof List)) {
            throw new RuntimeException("Unexpected type for '" + DTYPES_KEY + "'.");
        }
        for (Object dtype : (List<?>) reported) {
            if (dtype instanceof String) {
                this.outShmDTypes.add((String) dtype);
            } else {
                throw new RuntimeException("Unexpected type for element of  '" + DTYPES_KEY + "' list.");
            }
        }
    }

    /**
     * Resets {@link #outShmDims} and fills it from the {@code DIMS_KEY} entry of the task
     * outputs. Every element may be an {@code Object[]} or a {@code List} of numbers and is
     * converted to a {@code long[]}.
     *
     * @param task finished Python task
     * @throws RuntimeException if the entry is not a list, an element is neither an array nor
     *         a list, or a dimension is not a number
     */
    private void buildOutDimsList(Service.Task task) {
        this.outShmDims = new ArrayList<>();
        Object reported = task.outputs.get(DIMS_KEY);
        if (!(reported instanceof List)) {
            throw new RuntimeException("Unexpected type for '" + DIMS_KEY + "'.");
        }
        for (Object shape : (List<?>) reported) {
            // Arrays and lists are handled through the same list view.
            List<?> dims;
            if (shape instanceof Object[]) {
                dims = Arrays.asList((Object[]) shape);
            } else if (shape instanceof List) {
                dims = (List<?>) shape;
            } else {
                throw new RuntimeException("Unexpected type for element of  '" + DIMS_KEY + "' list.");
            }
            long[] sizes = new long[dims.size()];
            for (int axis = 0; axis < sizes.length; axis++) {
                Object size = dims.get(axis);
                if (!(size instanceof Number)) {
                    throw new RuntimeException("Unexpected type for array of element of  '" + DIMS_KEY + "' list.");
                }
                sizes[axis] = ((Number) size).longValue();
            }
            this.outShmDims.add(sizes);
        }
    }

    /* JADX WARN: Type inference failed for: r1v2, types: [net.imglib2.Interval, java.lang.Object] */
    /**
     * Opens the shared-memory segment {@code str}, copies its content into a new image of the
     * same type and closes the segment again.
     *
     * <p>The segment is closed in a {@code finally} block, so it is also released when the
     * copy fails.
     *
     * @param str name of the shared-memory segment
     * @param str2 data type of the stored array, as reported by Python
     * @param jArr dimensions of the stored array
     * @return a copy of the stored image
     * @throws IOException if the segment cannot be opened or closed
     */
    private <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> reconstruct(String str, String str2, long[] jArr) throws IOException {
        SharedMemoryArray readOrCreate = SharedMemoryArray.readOrCreate(str, jArr, (RealType) Cast.unchecked(CommonUtils.getImgLib2DataType(str2)), false, false);
        try {
            RandomAccessibleInterval<T> sharedRAI = readOrCreate.getSharedRAI();
            RandomAccessibleInterval<T> createCopyOfRaiInWantedDataType = Tensor.createCopyOfRaiInWantedDataType((RandomAccessibleInterval) Cast.unchecked(sharedRAI), (RealType) Util.getTypeFromInterval((Interval) Cast.unchecked(sharedRAI)));
            return createCopyOfRaiInWantedDataType;
        } finally {
            readOrCreate.close();
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Builds the Python lines (indented by two spaces) that attach to a shared-memory segment,
     * register it in {@code created_shms} and expose it as a numpy array named {@code str},
     * reshaped to the original shape of the array.
     *
     * @param sharedMemoryArray shared-memory array holding the data
     * @param str name of the Python variable to create
     * @return the Python source
     */
    public static String codeToConvertShmaToPython(SharedMemoryArray sharedMemoryArray, String str) {
        final String nl = System.lineSeparator();
        final StringBuilder py = new StringBuilder();
        py.append("  ").append(str).append("_shm = shared_memory.SharedMemory(name='")
                .append(sharedMemoryArray.getNameForPython())
                .append("', size=").append(sharedMemoryArray.getSize()).append(")").append(nl);
        py.append("  created_shms.append(").append(str).append("_shm)").append(nl);

        // The flat buffer holds the product of all dimensions.
        long elementCount = 1;
        for (long dim : sharedMemoryArray.getOriginalShape()) {
            elementCount *= dim;
        }
        py.append("  ").append(str).append(" = np.ndarray(").append(elementCount).append(", dtype='")
                .append(CommonUtils.getDataTypeFromRAI((RandomAccessibleInterval) Cast.unchecked(sharedMemoryArray.getSharedRAI())))
                .append("', buffer=").append(str).append("_shm.buf).reshape([");
        for (long dim : sharedMemoryArray.getOriginalShape()) {
            py.append(dim).append(", ");
        }
        return py.append("])").append(nl).toString();
    }

    /**
     * Checks whether the default environment ({@link #COMMON_PYTORCH_ENV_NAME}) contains all
     * required dependencies.
     *
     * @return {@code true} if every dependency is present
     */
    public static boolean isInstalled() {
        return isInstalled(null);
    }

    /**
     * Checks whether an environment contains the conda, torch and remaining pip dependencies.
     * The checks run in that order and stop at the first one that fails.
     *
     * @param str environment to check; {@code null} selects {@link #COMMON_PYTORCH_ENV_NAME}
     * @return {@code true} if every dependency is present; {@code false} if one is missing or
     *         a check throws {@link MambaInstallException}
     */
    public static boolean isInstalled(String str) {
        final String environment = (str != null) ? str : COMMON_PYTORCH_ENV_NAME;
        final Mamba mamba = new Mamba(INSTALLATION_DIR);
        try {
            return mamba.checkAllDependenciesInEnv(environment, BIAPY_CONDA_DEPS)
                    && mamba.checkAllDependenciesInEnv(environment, BIAPY_PIP_DEPS_TORCH)
                    && mamba.checkAllDependenciesInEnv(environment, BIAPY_PIP_DEPS);
        } catch (MambaInstallException e) {
            return false;
        }
    }

    /**
     * Installs the Python environment without redirecting the installer output; delegates to
     * {@link #installRequirements(Consumer)} with a {@code null} consumer.
     */
    public static void installRequirements() throws IOException, InterruptedException, RuntimeException, MambaInstallException, ArchiveException, URISyntaxException {
        installRequirements(null);
    }

    /**
     * Installs the default Python environment ({@link #COMMON_PYTORCH_ENV_NAME}) unless all of
     * its dependencies are already present.
     *
     * <p>The result of every dependency check now counts. Previously only the last check
     * decided, so a missing conda or torch dependency did not trigger a reinstall. After the
     * installation the environment that was just populated is verified.
     *
     * @param consumer receives console and error output of the installer; may be {@code null}
     * @throws RuntimeException if the environment is still incomplete after the installation
     */
    public static void installRequirements(Consumer<String> consumer) throws IOException, InterruptedException, RuntimeException, MambaInstallException, ArchiveException, URISyntaxException {
        Mamba mamba = new Mamba(INSTALLATION_DIR);
        if (consumer != null) {
            mamba.setConsoleOutputConsumer(consumer);
            mamba.setErrorOutputConsumer(consumer);
        }
        boolean alreadyInstalled = false;
        try {
            // Run every check so that any MambaInstallException still surfaces.
            boolean condaOk = mamba.checkAllDependenciesInEnv(COMMON_PYTORCH_ENV_NAME, BIAPY_CONDA_DEPS);
            boolean torchOk = mamba.checkAllDependenciesInEnv(COMMON_PYTORCH_ENV_NAME, BIAPY_PIP_DEPS_TORCH);
            boolean pipOk = mamba.checkAllDependenciesInEnv(COMMON_PYTORCH_ENV_NAME, BIAPY_PIP_DEPS);
            boolean biapyOk = true;
            if (PlatformDetection.isMacOS() && PlatformDetection.getOSVersion().getMajor() < 14) {
                // On these systems biapy is installed separately (see below), so check it separately.
                biapyOk = mamba.checkDependencyInEnv(COMMON_PYTORCH_ENV_NAME, "biapy==3.5.10");
            }
            alreadyInstalled = condaOk && torchOk && pipOk && biapyOk;
        } catch (MambaInstallException e) {
            mamba.installMicromamba();
        }
        if (!alreadyInstalled) {
            mamba.create(COMMON_PYTORCH_ENV_NAME, true, new ArrayList<>(), BIAPY_CONDA_DEPS);
            List<String> torchArgs = new ArrayList<>(BIAPY_PIP_ARGS);
            torchArgs.addAll(BIAPY_PIP_DEPS_TORCH);
            mamba.pipInstallIn(COMMON_PYTORCH_ENV_NAME, torchArgs.toArray(new String[0]));
            mamba.pipInstallIn(COMMON_PYTORCH_ENV_NAME, BIAPY_PIP_DEPS.toArray(new String[0]));
            if (PlatformDetection.isMacOS() && PlatformDetection.getOSVersion().getMajor() < 14) {
                mamba.pipInstallIn(COMMON_PYTORCH_ENV_NAME, "biapy==3.5.10", "--no-deps");
            }
        }
        // Verify the default environment; the installation directory is not an environment name.
        if (!isInstalled()) {
            throw new RuntimeException("Not all the required packages were installed correctly. Please try again. If the error persists, please post an issue at: https://github.com/bioimage-io/JDLL/issues");
        }
    }

    /**
     * Sets the directory handed to {@link Mamba} by the installation methods and used to
     * derive the default environment path of newly created models.
     *
     * @param str new installation directory
     */
    public static void setInstallationDir(String str) {
        DLModelPytorchProtected.INSTALLATION_DIR = str;
    }

    /**
     * @return the installation directory currently configured
     */
    public static String getInstallationDir() {
        return DLModelPytorchProtected.INSTALLATION_DIR;
    }

    static {
        IS_ARM = PlatformDetection.isMacOS() && (PlatformDetection.getArch().equals(PlatformDetection.ARCH_AARCH64) || PlatformDetection.isUsingRosseta());
        BIAPY_CONDA_DEPS = Arrays.asList("python=3.10");
        if (PlatformDetection.isMacOS() && PlatformDetection.getArch().equals(PlatformDetection.ARCH_X86_64) && !PlatformDetection.isUsingRosseta()) {
            BIAPY_PIP_DEPS_TORCH = Arrays.asList("torch==2.2.2", "torchvision==0.17.2", "torchaudio==2.2.2");
        } else if (PlatformDetection.isWindows()) {
            BIAPY_PIP_DEPS_TORCH = Arrays.asList("torch==2.4.1", "torchvision==0.19.1", "torchaudio==2.4.1");
        } else {
            BIAPY_PIP_DEPS_TORCH = Arrays.asList("torch==2.4.0", "torchvision==0.19.0", "torchaudio==2.4.0");
        }
        if (PlatformDetection.isWindows()) {
            BIAPY_PIP_DEPS = Arrays.asList("timm==1.0.14", "pytorch-msssim==1.0.0", "torchmetrics==1.4.3", "cellpose==3.1.1.1", "scipy==1.15.2", "torch-fidelity==0.3.0", "careamics", "biapy==3.5.10", "appose");
        } else if (!PlatformDetection.isMacOS() || PlatformDetection.getOSVersion().getMajor() >= 14) {
            BIAPY_PIP_DEPS = Arrays.asList("timm==1.0.14", "pytorch-msssim==1.0.0", "torchmetrics==1.4.3", "cellpose==3.1.1.1", "scipy==1.15.2", "torch-fidelity==0.3.0", "careamics", "biapy==3.5.10", "appose");
        } else {
            BIAPY_PIP_DEPS = Arrays.asList("timm==1.0.14", "pytorch-msssim==1.0.0", "torchmetrics==1.4.3", "cellpose==3.1.1.1", "torch-fidelity==0.3.0", "careamics", "pooch>=1.8.1", "numpy<2", "imagecodecs>=2024.1.1", "bioimageio.core==0.7.0", "h5py>=3.9.0", "torchinfo>=1.8.0", "pandas>=1.5.3", "xarray==2025.1.2", "fill-voids>=2.0.6", "edt>=2.3.2", "tqdm>=4.66.1", "yacs>=0.1.8", "zarr>=2.16.1", "pydot>=1.4.2", "matplotlib>=3.7.1", "imgaug>=0.4.0", "scipy==1.15.2", "tensorboardX>=2.6.2.2", "scikit-learn>=1.4.0", "opencv-python>=4.8.0.76", "scikit-image>=0.21.0", "appose");
        }
        BIAPY_PIP_ARGS = Arrays.asList("--index-url", "https://download.pytorch.org/whl/cpu");
        INSTALLATION_DIR = Mamba.BASE_PATH;
        MODEL_VAR_NAME = "model_" + UUID.randomUUID().toString().replace("-", "_");
        LOAD_MODEL_CODE_ABSTRACT = "if 'sys' not in globals().keys():" + System.lineSeparator() + "  import sys" + System.lineSeparator() + "  globals()['sys'] = sys" + System.lineSeparator() + "if 'np' not in globals().keys():" + System.lineSeparator() + "  import numpy as np" + System.lineSeparator() + "  globals()['np'] = np" + System.lineSeparator() + "if 'os' not in globals().keys():" + System.lineSeparator() + "  import os" + System.lineSeparator() + "  globals()['os'] = os" + System.lineSeparator() + "if 'shared_memory' not in globals().keys():" + System.lineSeparator() + "  from multiprocessing import shared_memory" + System.lineSeparator() + "  globals()['shared_memory'] = shared_memory" + System.lineSeparator() + "%s" + System.lineSeparator() + "%s" + System.lineSeparator() + "if '%s' not in globals().keys():" + System.lineSeparator() + "  globals()['%s'] = %s" + System.lineSeparator();
        OUTPUT_LIST_KEY = "out_list" + UUID.randomUUID().toString().replace("-", "_");
        SHMS_KEY = "shms_" + UUID.randomUUID().toString().replace("-", "_");
        SHM_NAMES_KEY = "shm_names_" + UUID.randomUUID().toString().replace("-", "_");
        DTYPES_KEY = "dtypes_" + UUID.randomUUID().toString().replace("-", "_");
        DIMS_KEY = "dims_" + UUID.randomUUID().toString().replace("-", "_");
        RECOVER_OUTPUTS_CODE = "def handle_output(outs_i):" + System.lineSeparator() + "    if type(outs_i) == np.ndarray:" + System.lineSeparator() + "      shm = shared_memory.SharedMemory(create=True, size=outs_i.nbytes)" + System.lineSeparator() + "      sh_np_array = np.ndarray(outs_i.shape, dtype=outs_i.dtype, buffer=shm.buf)" + System.lineSeparator() + "      np.copyto(sh_np_array, outs_i)" + System.lineSeparator() + "      " + SHMS_KEY + ".append(shm)" + System.lineSeparator() + "      " + SHM_NAMES_KEY + ".append(shm.name)" + System.lineSeparator() + "      " + DTYPES_KEY + ".append(str(outs_i.dtype))" + System.lineSeparator() + "      " + DIMS_KEY + ".append(outs_i.shape)" + System.lineSeparator() + "    elif str(type(outs_i)) == \"<class 'torch.Tensor'>\":" + System.lineSeparator() + "      if 'torch' not in globals().keys():" + System.lineSeparator() + "        import torch" + System.lineSeparator() + "        globals()['torch'] = torch" + System.lineSeparator() + (!IS_ARM ? 
"" : "        if torch.backends.mps.is_built() and torch.backends.mps.is_available():" + System.lineSeparator() + "          device = 'mps'" + System.lineSeparator()) + "      else:" + System.lineSeparator() + "        torch = globals()['torch']" + System.lineSeparator() + "      shm = shared_memory.SharedMemory(create=True, size=outs_i.numel() * outs_i.element_size())" + System.lineSeparator() + "      np_arr = np.ndarray(outs_i.shape, dtype=str(outs_i.dtype).split('.')[-1], buffer=shm.buf)" + System.lineSeparator() + "      tensor_np_view = torch.from_numpy(np_arr)" + System.lineSeparator() + "      tensor_np_view.copy_(outs_i)" + System.lineSeparator() + "      " + SHMS_KEY + ".append(shm)" + System.lineSeparator() + "      " + SHM_NAMES_KEY + ".append(shm.name)" + System.lineSeparator() + "      " + DTYPES_KEY + ".append(str(outs_i.dtype).split('.')[-1])" + System.lineSeparator() + "      " + DIMS_KEY + ".append(outs_i.shape)" + System.lineSeparator() + "    elif type(outs_i) == int:" + System.lineSeparator() + "      shm = shared_memory.SharedMemory(create=True, size=8)" + System.lineSeparator() + "      shm.buf[:8] = outs_i.to_bytes(8, byteorder='little', signed=True)" + System.lineSeparator() + "      " + SHMS_KEY + ".append(shm)" + System.lineSeparator() + "      " + SHM_NAMES_KEY + ".append(shm.name)" + System.lineSeparator() + "      " + DTYPES_KEY + ".append('int64')" + System.lineSeparator() + "      " + DIMS_KEY + ".append((1))" + System.lineSeparator() + "    elif type(outs_i) == float:" + System.lineSeparator() + "      shm = shared_memory.SharedMemory(create=True, size=8)" + System.lineSeparator() + "      shm.buf[:8] = outs_i.to_bytes(8, byteorder='little', signed=True)" + System.lineSeparator() + "      " + SHMS_KEY + ".append(shm)" + System.lineSeparator() + "      " + SHM_NAMES_KEY + ".append(shm.name)" + System.lineSeparator() + "      " + DTYPES_KEY + ".append('float64')" + System.lineSeparator() + "      " + DIMS_KEY + ".append((1))" + 
System.lineSeparator() + "    elif type(outs_i) == tuple or type(outs_i) == list:" + System.lineSeparator() + "      handle_output_list(outs_i)" + System.lineSeparator() + "    else:" + System.lineSeparator() + "      task.update('output type : ' + str(type(outs_i)) + ' not supported. Only supported output types are: np.ndarray, torch.tensor, int and float, or a list or tuple of any of those.')" + System.lineSeparator() + System.lineSeparator() + System.lineSeparator() + "def handle_output_list(out_list):" + System.lineSeparator() + "  if type(out_list) == tuple or type(out_list) == list:" + System.lineSeparator() + "    for outs_i in out_list:" + System.lineSeparator() + "      handle_output(outs_i)" + System.lineSeparator() + "  else:" + System.lineSeparator() + "    handle_output(out_list)" + System.lineSeparator() + "" + System.lineSeparator() + "" + System.lineSeparator() + "globals()['handle_output_list'] = handle_output_list" + System.lineSeparator() + "globals()['handle_output'] = handle_output" + System.lineSeparator() + "" + System.lineSeparator() + "" + System.lineSeparator();
        CLEAN_SHM_CODE = "if '" + SHMS_KEY + "' in globals().keys():" + System.lineSeparator() + "  for s in " + SHMS_KEY + ":" + System.lineSeparator() + "    s.close()" + System.lineSeparator() + "    s.unlink()" + System.lineSeparator() + "    del s" + System.lineSeparator();
        JDLL_UUID = UUID.randomUUID().toString().replaceAll("-", "_");
    }
}
