package io.bioimage.modelrunner.model.special.cellpose;

import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.description.weights.ModelWeight;
import io.bioimage.modelrunner.bioimageio.description.weights.WeightFormat;
import io.bioimage.modelrunner.bioimageio.tiling.TileCalculator;
import io.bioimage.modelrunner.download.MultiFileDownloader;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.python.BioimageIoModelPytorchProtected;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import io.bioimage.modelrunner.utils.Constants;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import net.imglib2.FinalInterval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.view.Views;

/* loaded from: input_file:io/bioimage/modelrunner/model/special/cellpose/Cellpose.class */
public class Cellpose extends BioimageIoModelPytorchProtected {
    /** True when this instance wraps a Bioimage.io packaged model (created via {@link #init(ModelDescriptor)}). */
    protected boolean isBMZ;
    /** Explicit value for Cellpose's {@code channels} argument; when null it is derived from the input shape. */
    protected int[] channels;
    /** Explicit value for Cellpose's {@code diameter} argument; when null, {@code None} is passed. */
    private Float diameter;
    /**
     * Contents of the bundled rdf.yaml template used to build a descriptor for raw Cellpose weights files.
     * It is filled via {@link String#format} with two channel-name lists and the weights file name.
     * Only set by {@link #init(String)}; null for Bioimage.io models.
     */
    private String rdfString;
    private boolean is3D;
    private static final String CELLPOSE_URL = "https://www.cellpose.org/models/%s";
    /** Maps user-facing pretrained model names to the weights file name used by the Cellpose server. */
    private static final Map<String, String> ALIAS;
    /** Expected byte size of each official weights file, used to recognise complete downloads. */
    private static final Map<String, Long> MODEL_SIZE;
    /** Python code that loads a raw Cellpose weights file; formatted with the weights path. */
    protected static final String LOAD_MODEL_CODE_ABSTRACT;
    protected static final String PATH_TO_RDF = "special_models/cellpose/rdf.yaml";
    /** Classpath location of the rdf template; null if the resource is not packaged. */
    protected static final URL RDF_URL;
    private static final String ONE_CHANNEL_STR = "ch_0";
    private static final String TWO_CHANNEL_STR = "ch_0, ch_1";
    // NOTE(review): the third name skips "ch_2"; looks like a typo for "ch_0, ch_1, ch_2" — confirm before changing,
    // since the string is substituted verbatim into the generated rdf.
    private static final String THREE_CHANNEL_STR = "ch_0, ch_1, ch_3";
    private static final List<String> PRETRAINED_CELLPOSE_MODELS = Arrays.asList("cyto", "cyto2", "cyto3", "nuclei");
    /** Files (weights first, then size model) that must be downloaded for each official pretrained model. */
    private static final Map<String, String[]> MODEL_REQ = new HashMap<>();

    protected Cellpose(String str, String str2, String str3, Map<String, Object> map, ModelDescriptor modelDescriptor) throws IOException {
        super(str, str2, null, str3, map, modelDescriptor, true);
        this.is3D = false;
        createPythonService();
    }

    /**
     * Sets the {@code channels} argument passed to Cellpose's {@code eval}, overriding the automatic choice.
     *
     * @param iArr channel configuration as expected by Cellpose
     */
    public void setChannels(int[] iArr) {
        this.channels = iArr;
    }

    /**
     * Sets the {@code diameter} argument passed to Cellpose's {@code eval}.
     *
     * @param f expected object diameter
     */
    public void setDiameter(float f) {
        this.diameter = Float.valueOf(f);
    }

    /** @return the diameter set with {@link #setDiameter(float)}, or null if none was set */
    public Float getDiameter() {
        return this.diameter;
    }

    /**
     * Placeholder for detecting an empty red channel in a 3-channel image.
     * NOTE(review): currently always returns true, so 3-channel inputs without explicit channels
     * always get {@code [2, 3]}; the {@code [2, 1]} branch in {@link #createChannelsArgCode} is unreachable.
     */
    private static <T extends RealType<T> & NativeType<T>> boolean isRedChannelEmpty(RandomAccessibleInterval<T> randomAccessibleInterval) {
        return true;
    }

    /**
     * Validates the inputs and normalises a 2D "xy" tensor to "xyc" with a singleton channel axis.
     * The caller's list is not modified; a (possibly modified) copy is returned.
     *
     * @param list input tensors; must contain exactly one tensor with axes "xy" or "xyc"
     * @return a list with the single validated input tensor, in "xyc" order
     * @throws IllegalArgumentException if the list does not hold exactly one tensor, the axes are not
     *         "xy"/"xyc", or the channel axis has a size other than 1 or 3
     */
    protected <R extends RealType<R> & NativeType<R>> List<Tensor<R>> checkInputTensors(List<Tensor<R>> list) {
        if (list == null || list.size() != 1) {
            throw new IllegalArgumentException("The input tensor list should contain just one tensor");
        }
        Tensor<R> input = list.get(0);
        String axes = input.getAxesOrderString();
        if (!axes.equals("xy") && !axes.equals("xyc")) {
            throw new IllegalArgumentException("The input axes should be 'xy' or 'xyc'");
        }
        long[] dims = input.getData().dimensionsAsLongArray();
        List<Tensor<R>> checked = new ArrayList<>(list);
        if (dims.length == 2) {
            // Append a channel axis of size 1 (index range [0, 0]) so Cellpose always receives xyc data.
            checked.set(0, Tensor.build(input.getName(), "xyc", Views.addDimension(input.getData(), 0, 0)));
        } else if (dims.length == 3 && dims[2] != 3 && dims[2] != 1) {
            throw new IllegalArgumentException("Only 1 and 3 channel images supported. The provided input has " + dims[2]);
        }
        return checked;
    }

    /** Hook for validating output tensors; currently returns the list unchanged. */
    protected <T extends RealType<T> & NativeType<T>> List<Tensor<T>> checkOutputTensors(List<Tensor<T>> list) {
        return list;
    }

    /**
     * Builds a model descriptor (and tile calculator) from the bundled rdf template, choosing the
     * 1-channel or 3-channel variant from the size of the input's "c" axis.
     *
     * @param inputs validated input tensors (only the first is inspected)
     * @throws IllegalStateException if no rdf template was loaded for this instance
     * @throws IllegalArgumentException if the channel axis has a size other than 1 or 3
     */
    private <R extends RealType<R> & NativeType<R>> void createCustomDescriptor(List<Tensor<R>> inputs) {
        if (this.rdfString == null) {
            throw new IllegalStateException("No rdf template available to describe the Cellpose model at: " + this.weightsPath);
        }
        Tensor<R> input = inputs.get(0);
        String axes = input.getAxesOrderString().toLowerCase();
        int channelAxis = axes.indexOf('c');
        long nChannels = channelAxis < 0 ? 1 : input.getData().dimensionsAsLongArray()[channelAxis];
        if (nChannels != 1 && nChannels != 3) {
            throw new IllegalArgumentException("Inputs to cellpose model can only have either 1 or 3 channels.");
        }
        File weights = new File(this.weightsPath).getAbsoluteFile();
        String rdf = nChannels == 1
                ? String.format(this.rdfString, ONE_CHANNEL_STR, ONE_CHANNEL_STR, weights.getName())
                : String.format(this.rdfString, THREE_CHANNEL_STR, TWO_CHANNEL_STR, weights.getName());
        this.descriptor = ModelDescriptorFactory.readFromYamlTextString(rdf);
        // getAbsoluteFile() above guarantees a non-null parent even for bare relative file names.
        this.descriptor.addModelPath(Paths.get(weights.getParentFile().getAbsolutePath()));
        this.tileCalculator = TileCalculator.init(this.descriptor);
    }

    /**
     * Runs the model on the given inputs after validating them; for raw weights files the descriptor is
     * (re)built to match the input's channel count.
     */
    @Override // io.bioimage.modelrunner.model.python.BioimageIoModelPytorchProtected, io.bioimage.modelrunner.model.python.DLModelPytorchProtected, io.bioimage.modelrunner.model.BaseModel
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> run(List<Tensor<R>> list) throws RunModelException {
        List<Tensor<R>> inputs = checkInputTensors(list);
        // Bioimage.io models already carry their own descriptor; only raw weights need the template one.
        if (!this.isBMZ) {
            createCustomDescriptor(inputs);
        }
        return super.run(inputs);
    }

    /**
     * Runs the model writing results into the provided output tensors; see {@link #run(List)}.
     */
    @Override // io.bioimage.modelrunner.model.python.BioimageIoModelPytorchProtected, io.bioimage.modelrunner.model.python.DLModelPytorchProtected, io.bioimage.modelrunner.model.BaseModel
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void run(List<Tensor<T>> list, List<Tensor<R>> list2) throws RunModelException {
        List<Tensor<T>> inputs = checkInputTensors(list);
        if (!this.isBMZ) {
            createCustomDescriptor(inputs);
        }
        super.run(inputs, checkOutputTensors(list2));
    }

    /**
     * Returns the Python code that loads the model: the generic Bioimage.io code for packaged models, or
     * {@link #LOAD_MODEL_CODE_ABSTRACT} formatted with the weights path for raw Cellpose weights.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // io.bioimage.modelrunner.model.python.DLModelPytorchProtected
    public String buildModelCode() throws IOException {
        return this.isBMZ ? super.buildModelCode() : String.format(LOAD_MODEL_CODE_ABSTRACT, this.weightsPath);
    }

    /**
     * Builds the Python code that copies the inputs into shared memory, calls Cellpose's {@code eval}
     * with the configured channels and diameter, releases the shared memory and hands the results to
     * {@code handle_output_list}. Bioimage.io models use the superclass implementation.
     *
     * @param list input images
     * @param list2 Python variable names for each input, in the same order
     * @return the generated Python code
     */
    /* JADX INFO: Access modifiers changed from: protected */
    @Override // io.bioimage.modelrunner.model.python.DLModelPytorchProtected
    public <T extends RealType<T> & NativeType<T>> String createInputsCode(List<RandomAccessibleInterval<T>> list, List<String> list2) {
        if (this.isBMZ) {
            return super.createInputsCode(list, list2);
        }
        String nl = System.lineSeparator();
        StringBuilder code = new StringBuilder();
        code.append("created_shms = []").append(nl);
        code.append("try:").append(nl);
        for (int i = 0; i < list.size(); i++) {
            SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(list.get(i), false, false);
            code.append(codeToConvertShmaToPython(shma, list2.get(i)));
            // Keep a handle so the Java side can release the shared memory later.
            this.inShmaList.add(shma);
        }
        code.append("  ").append(OUTPUT_LIST_KEY).append(" = ").append(MODEL_VAR_NAME).append(".eval(");
        for (int i = 0; i < list.size(); i++) {
            code.append(list2.get(i)).append(", channels=").append(createChannelsArgCode(list.get(i))).append(", ");
        }
        code.append("diameter=").append(createDiamCode()).append(")").append(nl);
        String closeShms = closeSHMWin();
        code.append("  ").append(closeShms).append(nl);
        code.append("except Exception as e:").append(nl);
        code.append("  ").append(closeShms).append(nl);
        code.append("  raise e").append(nl);
        // Reset the output bookkeeping variables and publish them as globals before collecting outputs.
        code.append(SHMS_KEY).append(" = []").append(nl);
        code.append(SHM_NAMES_KEY).append(" = []").append(nl);
        code.append(DTYPES_KEY).append(" = []").append(nl);
        code.append(DIMS_KEY).append(" = []").append(nl);
        code.append("globals()['").append(SHMS_KEY).append("'] = ").append(SHMS_KEY).append(nl);
        code.append("globals()['").append(SHM_NAMES_KEY).append("'] = ").append(SHM_NAMES_KEY).append(nl);
        code.append("globals()['").append(DTYPES_KEY).append("'] = ").append(DTYPES_KEY).append(nl);
        code.append("globals()['").append(DIMS_KEY).append("'] = ").append(DIMS_KEY).append(nl);
        code.append("handle_output_list(").append(OUTPUT_LIST_KEY).append(")").append(nl);
        code.append(taskOutputsCode());
        return code.toString();
    }

    /**
     * Returns the Python literal for Cellpose's {@code channels} argument.
     * Uses {@link #channels} when set; otherwise {@code [0, 0]} for single-channel (2D or c=1) inputs and
     * {@code [2, 3]} / {@code [2, 1]} for 3-channel inputs depending on {@link #isRedChannelEmpty}.
     *
     * @throws IllegalArgumentException if no channels were set and the input shape is unsupported
     */
    protected <T extends RealType<T> & NativeType<T>> String createChannelsArgCode(RandomAccessibleInterval<T> randomAccessibleInterval) {
        if (this.channels != null) {
            return Arrays.toString(this.channels);
        }
        long[] dims = randomAccessibleInterval.dimensionsAsLongArray();
        if (dims.length == 2 || (dims.length == 3 && dims[2] == 1)) {
            return "[0, 0]";
        }
        if (dims.length == 3 && dims[2] == 3) {
            return isRedChannelEmpty(randomAccessibleInterval) ? "[2, 3]" : "[2, 1]";
        }
        throw new IllegalArgumentException("Bad configuration, dims=" + Arrays.toString(dims) + ", channels=" + Arrays.toString(this.channels));
    }

    /** @return the Python literal for Cellpose's {@code diameter} argument ({@code None} when unset) */
    protected String createDiamCode() {
        return this.diameter == null ? "None" : "" + this.diameter;
    }

    /**
     * Creates a Cellpose model from a path.
     * If the path is a directory containing an rdf file, it is loaded as a Bioimage.io model; otherwise the
     * path must be a Cellpose weights file, described with the bundled rdf template.
     *
     * @param str path to a Bioimage.io model directory or to a Cellpose weights file
     * @return the created model
     * @throws IllegalArgumentException if the path is neither such a directory nor an existing file
     * @throws IOException if the bundled rdf template cannot be read or model creation fails
     */
    public static Cellpose init(String str) throws IOException {
        File file = new File(str);
        File rdfFile = new File(file, Constants.RDF_FNAME);
        if (file.isDirectory() && rdfFile.isFile()) {
            return init(ModelDescriptorFactory.readFromLocalFile(rdfFile.getAbsolutePath()));
        }
        if (!file.isFile()) {
            throw new IllegalArgumentException("The path provided does not correspond to an existing file: " + str);
        }
        // Read the template first so a missing resource fails before the Python service is spawned.
        String rdfTemplate = readRdfTemplate();
        Cellpose cellpose = new Cellpose(null, null, str, null, null);
        cellpose.rdfString = rdfTemplate;
        return cellpose;
    }

    /** Reads the bundled rdf template from {@link #RDF_URL} as UTF-8 text. */
    private static String readRdfTemplate() throws IOException {
        if (RDF_URL == null) {
            throw new IOException("Could not find the Cellpose rdf template on the classpath: " + PATH_TO_RDF);
        }
        try (InputStream in = RDF_URL.openStream();
             ByteArrayOutputStream out = new ByteArrayOutputStream()) {
            byte[] buffer = new byte[8192];
            int read;
            while ((read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
            }
            return out.toString(StandardCharsets.UTF_8.name());
        }
    }

    /**
     * Creates a Cellpose model from a Bioimage.io model descriptor.
     *
     * @param modelDescriptor descriptor of a Bioimage.io model tagged or named as a Cellpose model
     * @return the created model
     * @throws RuntimeException if the descriptor does not look like a Cellpose model
     * @throws IllegalArgumentException if the model has no weights in the required PyTorch format
     * @throws IOException if model creation fails
     */
    public static Cellpose init(ModelDescriptor modelDescriptor) throws IOException {
        boolean taggedCellpose = modelDescriptor.getTags().stream()
                .anyMatch(tag -> tag.toLowerCase().equals(ModelDescriptor.CELLPOSE));
        if (!taggedCellpose && !modelDescriptor.getName().toLowerCase().contains(ModelDescriptor.CELLPOSE)) {
            throw new RuntimeException("This model does not seem to be a cellpose model from the Bioimage.io");
        }
        WeightFormat modelWeights = modelDescriptor.getWeights().getModelWeights(ModelWeight.getPytorchID());
        if (modelWeights == null) {
            throw new IllegalArgumentException("The model provided does not have weights in the required format, " + ModelWeight.getPytorchID() + ".");
        }
        String modelDir = modelDescriptor.getModelPath() + File.separator;
        Cellpose cellpose = new Cellpose(
                modelDir + modelWeights.getArchitecture().getSource(),
                modelWeights.getArchitecture().getCallable(),
                modelDir + modelWeights.getSource(),
                modelWeights.getArchitecture().getKwargs(),
                modelDescriptor);
        cellpose.isBMZ = true;
        return cellpose;
    }

    /**
     * Equivalent to {@link #fromPretained(String, String, boolean)} using the "models" directory of the
     * current working directory.
     */
    public static Cellpose fromPretained(String str, boolean z) throws IOException, InterruptedException, ExecutionException {
        return fromPretained(str, new File("models").getAbsolutePath(), z);
    }

    /**
     * Creates a pretrained Cellpose model, either an official one ("cyto", "cyto2", "cyto3", "nuclei") or a
     * Bioimage.io model selected by name or ID.
     *
     * @param str model name or ID
     * @param str2 directory where official models are searched for and where downloads are written
     * @param z if true, download the model; if false, only look for an already installed copy
     * @return the model, or null when {@code z} is false and no installed copy was found
     * @throws IllegalArgumentException if {@code z} is true and the model is not found on Bioimage.io
     */
    public static Cellpose fromPretained(String str, String str2, boolean z) throws IOException, InterruptedException, ExecutionException {
        boolean official = PRETRAINED_CELLPOSE_MODELS.contains(str);
        if (official && !z) {
            String installed = fileIsCellpose(str, str2);
            return installed == null ? null : init(installed);
        }
        if (official) {
            return init(donwloadPretrainedOfficial(str, str2, null));
        }
        if (!z) {
            // NOTE(review): the local-repo lookup does not use str2.
            ModelDescriptor local = ModelDescriptorFactory.getModelsAtLocalRepo().stream()
                    .filter(d -> d.getModelID().equals(str) || d.getName().toLowerCase().equals(str.toLowerCase()))
                    .findFirst()
                    .orElse(null);
            return local == null ? null : init(local);
        }
        BioimageioRepo repo = BioimageioRepo.connect();
        ModelDescriptor remote = repo.selectByName(str);
        if (remote == null) {
            remote = repo.selectByID(str);
        }
        if (remote == null) {
            throw new IllegalArgumentException("The model does not correspond to on of the available pretrained cellpose models. To find a list of available cellpose models, please run Cellpose.getPretrainedList()");
        }
        remote.addModelPath(Paths.get(BioimageioRepo.downloadModel(remote, str2)));
        return init(remote);
    }

    /**
     * Looks for an installed official pretrained model in the timestamped download directories
     * ({@code <name>_YYYYMMDD_HHMMSS}) inside {@code str2}.
     *
     * @param str model alias or weights name, optionally with a ".pth"/".pt" extension
     * @param str2 directory to search
     * @return absolute path of the weights file, or null if no complete copy was found
     * @throws IllegalArgumentException if {@code str} is not a supported pretrained model
     */
    public static String findPretrainedModelInstalled(String str, String str2) {
        String name = str;
        if (name.endsWith(".pth")) {
            name = name.substring(0, name.length() - 4);
        } else if (name.endsWith(".pt")) {
            name = name.substring(0, name.length() - 3);
        }
        if (!ALIAS.containsKey(name) && !MODEL_SIZE.containsKey(name)) {
            throw new IllegalArgumentException("Only supported pretrained models are: " + ALIAS.keySet());
        }
        for (String dir : findDirectoriesWithPattern(str2, name)) {
            String found = lookForModelInDir(name, dir);
            if (found != null) {
                return found;
            }
        }
        return null;
    }

    /**
     * Resolves {@code modelName} to a weights file: the name itself if it is a recognised Cellpose weights
     * file, otherwise an installed copy in {@code modelsDir}. Returns null if nothing is found.
     */
    private static String fileIsCellpose(String modelName, String modelsDir) {
        File file = new File(modelName);
        if (file.isFile() && isCellposeFile(file)) {
            return file.getAbsolutePath();
        }
        String installed = findPretrainedModelInstalled(modelName, modelsDir);
        return installed != null ? installed : lookForModelInDir(modelName, modelsDir);
    }

    /** True if the file name and byte size match one of the official weights files. */
    private static boolean isCellposeFile(File file) {
        Long expectedSize = MODEL_SIZE.get(file.getName());
        return expectedSize != null && expectedSize.longValue() == file.length();
    }

    /**
     * Returns the absolute path of the weights file for {@code modelName} (alias or weights name) inside
     * {@code dirPath} if it exists with the expected size, otherwise null.
     */
    private static String lookForModelInDir(String modelName, String dirPath) {
        File dir = new File(dirPath);
        if (!dir.isDirectory()) {
            return null;
        }
        String alias = modelName;
        if (MODEL_SIZE.containsKey(modelName)) {
            // Translate a weights name back to its alias.
            alias = ALIAS.entrySet().stream()
                    .filter(entry -> entry.getValue().equals(modelName))
                    .map(Map.Entry::getKey)
                    .findFirst()
                    .orElse(null);
        }
        String weightsName = alias == null ? null : ALIAS.get(alias);
        if (weightsName == null) {
            return null;
        }
        File candidate = new File(dir, weightsName);
        if (candidate.isFile() && candidate.length() == MODEL_SIZE.get(weightsName).longValue()) {
            return candidate.getAbsolutePath();
        }
        return null;
    }

    /**
     * Lists the available pretrained Cellpose models: Bioimage.io models of the Cellpose family (when the
     * repository can be queried) followed by the official model names.
     */
    public static List<String> getPretrainedList() {
        List<String> names = new ArrayList<>();
        try {
            names.addAll(BioimageioRepo.connect().listAllModels(false).values().stream()
                    .filter(descriptor -> ModelDescriptor.CELLPOSE.equals(descriptor.getModelFamily()))
                    .map(ModelDescriptor::getName)
                    .collect(Collectors.toList()));
        } catch (InterruptedException e) {
            // Best effort: still return the official models, but preserve the interrupt status.
            Thread.currentThread().interrupt();
        }
        names.addAll(PRETRAINED_CELLPOSE_MODELS);
        return names;
    }

    /** Equivalent to {@link #donwloadPretrained(String, String, Consumer)} without progress reporting. */
    public static String donwloadPretrained(String str, String str2) throws ExecutionException, InterruptedException, IOException {
        return donwloadPretrained(str, str2, null);
    }

    /**
     * Downloads a pretrained model, trying the official Cellpose server first and then Bioimage.io.
     *
     * @param str model name or ID
     * @param str2 directory where the model is downloaded
     * @param consumer optional progress consumer (may be null)
     * @return path returned by the successful download
     * @throws IllegalArgumentException if the model is found in neither source
     */
    public static String donwloadPretrained(String str, String str2, Consumer<Double> consumer) throws ExecutionException, InterruptedException, IOException {
        String path = donwloadPretrainedOfficial(str, str2, consumer);
        if (path == null) {
            path = donwloadPretrainedBioimageio(str, str2, consumer);
        }
        if (path == null) {
            throw new IllegalArgumentException("The model does not correspond to on of the available pretrained cellpose models. To find a list of available cellpose models, please run Cellpose.getPretrainedList()");
        }
        return path;
    }

    /** Downloads a Bioimage.io model selected by name or ID; returns null if it does not exist. */
    private static String donwloadPretrainedBioimageio(String modelName, String targetDir, Consumer<Double> consumer) throws InterruptedException, IOException {
        BioimageioRepo repo = BioimageioRepo.connect();
        ModelDescriptor descriptor = repo.selectByName(modelName);
        if (descriptor == null) {
            descriptor = repo.selectByID(modelName);
        }
        if (descriptor == null) {
            return null;
        }
        return BioimageioRepo.downloadModel(descriptor, targetDir, consumer);
    }

    /**
     * Downloads the files of an official Cellpose model into a new timestamped directory inside
     * {@code targetDir}.
     *
     * @return path of the downloaded weights file, or null if {@code modelName} is not an official model
     */
    private static String donwloadPretrainedOfficial(String modelName, String targetDir, Consumer<Double> consumer) throws ExecutionException, InterruptedException {
        String[] requiredFiles = MODEL_REQ.get(modelName);
        if (requiredFiles == null) {
            return null;
        }
        List<URL> urls = new ArrayList<>();
        for (String fileName : requiredFiles) {
            try {
                urls.add(new URL(String.format(CELLPOSE_URL, fileName)));
            } catch (MalformedURLException ignored) {
                // Cannot happen: the base URL and the file names are compile-time constants.
            }
        }
        String downloadDir = targetDir + File.separator + MultiFileDownloader.addTimeStampToFileName(modelName, true);
        MultiFileDownloader downloader = new MultiFileDownloader(urls, new File(downloadDir));
        if (consumer != null) {
            downloader.setPartialProgressConsumer(consumer);
        }
        downloader.download();
        return downloadDir + File.separator + requiredFiles[0];
    }

    /**
     * Returns the absolute paths of the sub-directories of {@code parentDir} named
     * {@code <baseName>_<8 digits>_<6 digits>}; empty if {@code parentDir} does not exist or is not a directory.
     */
    private static List<String> findDirectoriesWithPattern(String parentDir, String baseName) {
        List<String> matches = new ArrayList<>();
        File[] children = new File(parentDir).listFiles();
        if (children == null) {
            return matches;
        }
        Pattern pattern = Pattern.compile("^" + Pattern.quote(baseName) + "_\\d{8}_\\d{6}$");
        for (File child : children) {
            if (child.isDirectory() && pattern.matcher(child.getName()).matches()) {
                matches.add(child.getAbsolutePath());
            }
        }
        return matches;
    }

    /** Manual smoke test: runs "cyto2" twice on an empty 512x512x3 image and prints the timings in ms. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    public static <T extends RealType<T> & NativeType<T>> void main(String[] strArr) throws IOException, InterruptedException, ExecutionException, LoadModelException, RunModelException {
        Cellpose model = fromPretained("cyto2", false);
        if (model == null) {
            System.out.println("The 'cyto2' model is not installed.");
            return;
        }
        model.loadModel();
        ArrayImg image = ArrayImgs.floats(new long[]{512, 512, 3});
        List inputs = new ArrayList();
        inputs.add(image);
        for (int i = 0; i < 2; i++) {
            long start = System.currentTimeMillis();
            model.inference(inputs);
            System.out.println(System.currentTimeMillis() - start);
        }
        model.close();
    }

    static {
        MODEL_REQ.put("cyto2", new String[]{"cyto2torch_0", "size_cyto2torch_0.npy"});
        MODEL_REQ.put("cyto3", new String[]{"cyto3", "size_cyto3.npy"});
        MODEL_REQ.put("cyto", new String[]{"cytotorch_0", "size_cytotorch_0.npy"});
        MODEL_REQ.put("nuclei", new String[]{"nucleitorch_0", "size_nucleitorch_0.npy"});
        ALIAS = new HashMap<>();
        ALIAS.put("cyto2", "cyto2torch_0");
        ALIAS.put("cyto3", "cyto3");
        ALIAS.put("cyto", "cytotorch_0");
        ALIAS.put("nuclei", "nucleitorch_0");
        MODEL_SIZE = new HashMap<>();
        MODEL_SIZE.put("cyto2torch_0", 26563614L);
        MODEL_SIZE.put("cyto3", 26566255L);
        MODEL_SIZE.put("cytotorch_0", 26563614L);
        MODEL_SIZE.put("nucleitorch_0", 26563614L);
        String nl = System.lineSeparator();
        // NOTE(review): the MPS (Apple GPU) check is skipped when IS_ARM is true and only emitted otherwise,
        // and CUDA is never checked — confirm this gating is intentional.
        String gpuCheck = IS_ARM ? "" : "from torch.backends import mps" + nl
                + "if mps.is_built() and mps.is_available():" + nl
                + "  gpu_available = True" + nl;
        LOAD_MODEL_CODE_ABSTRACT = "if 'denoise' not in globals().keys():" + nl
                + "  from cellpose import denoise" + nl
                + "  globals()['denoise'] = denoise" + nl
                + "if 'np' not in globals().keys():" + nl
                + "  import numpy as np" + nl
                + "  globals()['np'] = np" + nl
                + "if 'os' not in globals().keys():" + nl
                + "  import os" + nl
                + "  globals()['os'] = os" + nl
                + "if 'shared_memory' not in globals().keys():" + nl
                + "  from multiprocessing import shared_memory" + nl
                + "  globals()['shared_memory'] = shared_memory" + nl
                + "gpu_available = False" + nl
                + gpuCheck
                + MODEL_VAR_NAME + " = denoise.CellposeDenoiseModel(gpu=gpu_available, pretrained_model=r'%s')" + nl
                + "globals()['" + MODEL_VAR_NAME + "'] = " + MODEL_VAR_NAME + nl;
        RDF_URL = Cellpose.class.getClassLoader().getResource(PATH_TO_RDF);
    }
}
