/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.model.special.cellpose;

import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.description.weights.ModelWeight;
import io.bioimage.modelrunner.bioimageio.description.weights.WeightFormat;
import io.bioimage.modelrunner.bioimageio.tiling.TileCalculator;
import io.bioimage.modelrunner.download.MultiFileDownloader;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.python.BioimageIoModelPytorchProtected;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import net.imglib2.FinalInterval;
import net.imglib2.Interval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

/**
 * Runs Cellpose segmentation models from Java through a Python service.
 * <p>
 * Two kinds of models are handled:
 * <ul>
 * <li>Official pretrained Cellpose weights ({@code cyto}, {@code cyto2}, {@code cyto3}, {@code nuclei}),
 * created with {@link #init(String)} or {@link #fromPretained(String, boolean)}. They are instantiated in
 * Python as {@code denoise.CellposeDenoiseModel} and, on every run, a descriptor is generated from the
 * rdf.yaml template bundled at {@value #PATH_TO_RDF}.</li>
 * <li>Cellpose models published on the Bioimage.io, created with {@link #init(ModelDescriptor)}
 * ({@link #isBMZ} is {@code true}); model loading and input-code generation are delegated to
 * {@link BioimageIoModelPytorchProtected}.</li>
 * </ul>
 */
public class Cellpose
extends BioimageIoModelPytorchProtected {
    /** {@code true} when the model was built from a Bioimage.io descriptor. */
    protected boolean isBMZ;
    /**
     * Value passed as Cellpose's {@code channels} argument. When {@code null} it is derived from the
     * shape of each input (see {@link #createChannelsArgCode(RandomAccessibleInterval)}).
     */
    protected int[] channels;
    /** Value passed as Cellpose's {@code diameter} argument; {@code null} is sent to Python as {@code None}. */
    private Float diameter;
    /**
     * rdf.yaml template (with {@code %s} placeholders) read from {@link #RDF_URL}. It stays {@code null}
     * for Bioimage.io models and when the resource could not be read.
     */
    private String rdfString;
    private static final List<String> PRETRAINED_CELLPOSE_MODELS = Arrays.asList("cyto", "cyto2", "cyto3", "nuclei");
    /** Download URL pattern for official weights; {@code %s} is the remote file name. */
    private static final String CELLPOSE_URL = "https://www.cellpose.org/models/%s";
    /** Files to download for each official model; the first entry is the weights file. */
    private static final Map<String, String[]> MODEL_REQ = new HashMap<String, String[]>();
    /** Maps official model names to the file name of their weights. */
    private static final Map<String, String> ALIAS;
    /** Expected size in bytes of each official weights file, used to recognise complete downloads. */
    private static final Map<String, Long> MODEL_SIZE;
    /** Python code that loads an official model; the placeholders are the {@code gpu} flag and the weights path. */
    protected static final String LOAD_MODEL_CODE_ABSTRACT;
    protected static final String PATH_TO_RDF = "special_models/cellpose/rdf.yaml";
    /** Classpath URL of {@link #PATH_TO_RDF}; {@code null} if the resource is not on the classpath. */
    protected static final URL RDF_URL;
    /** Channel-name lists substituted into the rdf.yaml template. */
    private static final String ONE_CHANNEL_STR = "ch_0";
    private static final String TWO_CHANNEL_STR = "ch_0, ch_1";
    private static final String THREE_CHANNEL_STR = "ch_0, ch_1, ch_2";

    /**
     * Creates the model and starts its Python service.
     *
     * @param modelFile   architecture source file ({@code null} for official pretrained weights)
     * @param callable    architecture callable ({@code null} for official pretrained weights)
     * @param weightsPath path to the weights file
     * @param kwargs      architecture keyword arguments, may be {@code null}
     * @param descriptor  Bioimage.io descriptor, {@code null} for official pretrained weights
     * @throws IOException propagated from the parent constructor or from starting the Python service
     */
    protected Cellpose(String modelFile, String callable, String weightsPath, Map<String, Object> kwargs, ModelDescriptor descriptor) throws IOException {
        super(modelFile, callable, null, weightsPath, kwargs, descriptor, true);
        this.createPythonService();
    }

    /**
     * Sets the Cellpose {@code channels} argument explicitly, overriding the shape-based default.
     *
     * @param channels channel indices, or {@code null} to go back to the shape-based default
     */
    public void setChannels(int[] channels) {
        this.channels = channels;
    }

    /**
     * Sets the expected object diameter passed to Cellpose.
     *
     * @param diameter object diameter forwarded as Cellpose's {@code diameter} argument
     */
    public void setDiameter(float diameter) {
        this.diameter = Float.valueOf(diameter);
    }

    /**
     * @return the configured diameter, or {@code null} if none was set (Cellpose then receives {@code None})
     */
    public Float getDiameter() {
        return this.diameter;
    }

    /**
     * Intended to report whether the red channel of a 3-channel image is empty.
     * <p>
     * TODO: this is a stub that always returns {@code true}, so 3-channel inputs without explicit channels
     * always get {@code [2, 3]} and the {@code [2, 1]} branch of {@link #createChannelsArgCode} is never taken.
     */
    private static <T extends RealType<T> & NativeType<T>> boolean isRedChannelEmpty(RandomAccessibleInterval<T> image) {
        return true;
    }

    /**
     * Validates the inputs and normalises them to a single {@code xyc} tensor.
     * <p>
     * A 2-D {@code xy} tensor is given a singleton channel axis. The caller's list is never modified; a new
     * list is returned when the tensor has to be adapted.
     *
     * @param inputTensors list that must contain exactly one tensor with axes {@code xy} or {@code xyc}
     * @return the (possibly adapted) list of input tensors
     * @throws IllegalArgumentException if the list does not hold exactly one tensor, the axes are not
     *                                  {@code xy}/{@code xyc}, or the channel axis is neither 1 nor 3 long
     */
    protected <R extends RealType<R> & NativeType<R>> List<Tensor<R>> checkInputTensors(List<Tensor<R>> inputTensors) {
        if (inputTensors == null || inputTensors.size() != 1) {
            throw new IllegalArgumentException("The input tensor list should contain just one tensor");
        }
        Tensor<R> input = inputTensors.get(0);
        String axes = input.getAxesOrderString();
        if (!axes.equals("xy") && !axes.equals("xyc")) {
            throw new IllegalArgumentException("The input axes should be 'xyc'");
        }
        long[] dims = input.getData().dimensionsAsLongArray();
        if (dims.length == 2) {
            // Append a singleton channel axis (min = max = 0) so Python always receives an 'xyc' array.
            IntervalView<R> withChannel = Views.addDimension(input.getData(), 0L, 0L);
            List<Tensor<R>> adapted = new ArrayList<Tensor<R>>(inputTensors);
            adapted.set(0, Tensor.build(input.getName(), "xyc", withChannel));
            return adapted;
        }
        if (dims.length == 3 && dims[2] != 3L && dims[2] != 1L) {
            throw new IllegalArgumentException("Only 1 and 3 channel images supported. The provided input has " + dims[2]);
        }
        return inputTensors;
    }

    /**
     * Hook for validating output tensors; currently returns them unchanged.
     *
     * @param outputTensors output tensors supplied by the caller
     * @return {@code outputTensors}
     */
    protected <T extends RealType<T> & NativeType<T>> List<Tensor<T>> checkOutputTensors(List<Tensor<T>> outputTensors) {
        return outputTensors;
    }

    /**
     * Builds the descriptor and tile calculator of an official pretrained model from the rdf.yaml template,
     * choosing the 1- or 3-channel variant according to the first input tensor.
     *
     * @throws IllegalStateException    if the rdf.yaml template could not be loaded
     * @throws IllegalArgumentException if the channel axis is neither 1 nor 3 long
     */
    private <R extends RealType<R> & NativeType<R>> void createCustomDescriptor(List<Tensor<R>> inputTensors) {
        if (this.rdfString == null) {
            throw new IllegalStateException("Unable to read the Cellpose rdf.yaml template from the classpath resource: " + PATH_TO_RDF);
        }
        Tensor<R> input = inputTensors.get(0);
        String axesOrder = input.getAxesOrderString().toLowerCase();
        int channelIndex = axesOrder.indexOf('c');
        int nChannels = 1;
        if (channelIndex >= 0) {
            long channelSize = input.getData().dimensionsAsLongArray()[channelIndex];
            if (channelSize == 3L) {
                nChannels = 3;
            } else if (channelSize != 1L) {
                throw new IllegalArgumentException("Inputs to cellpose model can only have either 1 or 3 channels.");
            }
        }
        // Absolute file so a bare file name still has a parent directory to use as model path.
        File weightsFile = new File(this.weightsPath).getAbsoluteFile();
        String weightsName = weightsFile.getName();
        String adaptedRdfString = nChannels == 1
                ? String.format(this.rdfString, ONE_CHANNEL_STR, ONE_CHANNEL_STR, weightsName)
                : String.format(this.rdfString, THREE_CHANNEL_STR, TWO_CHANNEL_STR, weightsName);
        this.descriptor = ModelDescriptorFactory.readFromYamlTextString(adaptedRdfString);
        this.descriptor.addModelPath(Paths.get(weightsFile.getParentFile().getAbsolutePath(), new String[0]));
        this.tileCalculator = TileCalculator.init(this.descriptor);
    }

    /**
     * Validates the single input tensor and runs the model on it.
     * <p>
     * For official pretrained models the descriptor is rebuilt to match the input's channel count first.
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> run(List<Tensor<R>> inputTensors) throws RunModelException {
        List<Tensor<R>> checked = this.checkInputTensors(inputTensors);
        // Bioimage.io models carry their own descriptor and have no rdf.yaml template to adapt.
        if (!this.isBMZ) {
            this.createCustomDescriptor(checked);
        }
        return super.run(checked);
    }

    /**
     * Validates the single input tensor and runs the model, writing into the provided output tensors.
     * <p>
     * For official pretrained models the descriptor is rebuilt to match the input's channel count first.
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws RunModelException {
        List<Tensor<T>> checked = this.checkInputTensors(inputTensors);
        if (!this.isBMZ) {
            this.createCustomDescriptor(checked);
        }
        super.run(checked, this.checkOutputTensors(outputTensors));
    }

    /**
     * Returns the Python code that loads the model.
     * <p>
     * For official weights, {@link #LOAD_MODEL_CODE_ABSTRACT} is used with {@code gpu=False}.
     */
    @Override
    protected String buildModelCode() {
        if (this.isBMZ) {
            return super.buildModelCode();
        }
        // The path goes into a single-quoted Python literal: escape backslashes (Windows separators such as
        // "\U" would otherwise be parsed as escape sequences) and single quotes.
        String pythonPath = this.weightsPath == null ? null : this.weightsPath.replace("\\", "\\\\").replace("'", "\\'");
        return String.format(LOAD_MODEL_CODE_ABSTRACT, "False", pythonPath);
    }

    /**
     * Generates the Python code for one inference call.
     * <p>
     * The generated code does the following:
     * <ul>
     * <li>copies every input into shared memory;</li>
     * <li>calls {@code eval} on the model with the {@code channels} and {@code diameter} arguments;</li>
     * <li>initialises the output bookkeeping lists as globals;</li>
     * <li>hands the results to {@code handle_output_list}.</li>
     * </ul>
     */
    @Override
    protected <T extends RealType<T> & NativeType<T>> String createInputsCode(List<RandomAccessibleInterval<T>> inRais, List<String> names) {
        if (this.isBMZ) {
            return super.createInputsCode(inRais, names);
        }
        String nl = System.lineSeparator();
        StringBuilder code = new StringBuilder();
        // Copy each input into a shared-memory array and emit the Python code that reads it back.
        for (int i = 0; i < inRais.size(); ++i) {
            SharedMemoryArray shma = SharedMemoryArray.createSHMAFromRAI(inRais.get(i), false, false);
            code.append(Cellpose.codeToConvertShmaToPython(shma, names.get(i)));
            this.inShmaList.add(shma);
        }
        // NOTE(review): these two prints look like leftover debugging output in the generated script.
        code.append("print(type(").append(names.get(0)).append("))").append(nl);
        code.append("print(").append(names.get(0)).append(".shape)").append(nl);
        code.append(OUTPUT_LIST_KEY).append(" = ").append(MODEL_VAR_NAME).append(".eval(");
        for (int i = 0; i < inRais.size(); ++i) {
            code.append(names.get(i)).append(", channels=").append(this.createChannelsArgCode(inRais.get(i))).append(", ");
        }
        code.append("diameter=").append(this.createDiamCode()).append(")").append(nl);
        // Declare empty output bookkeeping lists and publish them in globals().
        code.append(SHMS_KEY).append(" = []").append(nl);
        code.append(SHM_NAMES_KEY).append(" = []").append(nl);
        code.append(DTYPES_KEY).append(" = []").append(nl);
        code.append(DIMS_KEY).append(" = []").append(nl);
        code.append("globals()['").append(SHMS_KEY).append("'] = ").append(SHMS_KEY).append(nl);
        code.append("globals()['").append(SHM_NAMES_KEY).append("'] = ").append(SHM_NAMES_KEY).append(nl);
        code.append("globals()['").append(DTYPES_KEY).append("'] = ").append(DTYPES_KEY).append(nl);
        code.append("globals()['").append(DIMS_KEY).append("'] = ").append(DIMS_KEY).append(nl);
        code.append("handle_output_list(").append(OUTPUT_LIST_KEY).append(")").append(nl);
        code.append(this.taskOutputsCode());
        return code.toString();
    }

    /**
     * Returns the Python literal passed as Cellpose's {@code channels} argument for one input.
     * <p>
     * Explicit {@link #channels} win. Otherwise, 2-D and single-channel inputs get {@code [0, 0]}, and
     * 3-channel inputs get {@code [2, 3]} or {@code [2, 1]} depending on {@link #isRedChannelEmpty}.
     * Presumably these follow Cellpose's convention of 0=gray, 1=red, 2=green, 3=blue; confirm against
     * the installed Cellpose version.
     *
     * @throws IllegalArgumentException for shapes other than 2-D, or 3-D with 1 or 3 channels
     */
    protected <T extends RealType<T> & NativeType<T>> String createChannelsArgCode(RandomAccessibleInterval<T> rai) {
        if (this.channels != null) {
            return Arrays.toString(this.channels);
        }
        long[] dims = rai.dimensionsAsLongArray();
        if (dims.length == 2 || (dims.length == 3 && dims[2] == 1L)) {
            return "[0, 0]";
        }
        if (dims.length == 3 && dims[2] == 3L) {
            return Cellpose.isRedChannelEmpty(rai) ? "[2, 3]" : "[2, 1]";
        }
        throw new IllegalArgumentException("Bad configuration, dims=" + Arrays.toString(dims) + ", channels=" + Arrays.toString(this.channels));
    }

    /**
     * @return the Python literal for Cellpose's {@code diameter} argument ({@code None} when unset)
     */
    protected String createDiamCode() {
        if (this.diameter == null) {
            return "None";
        }
        return "" + this.diameter;
    }

    /**
     * Creates a Cellpose model from a weights file or a Bioimage.io model folder.
     * <p>
     * If {@code weightsPath} is a directory containing {@code rdf.yaml}, it is loaded as a Bioimage.io model.
     * Otherwise it must be a weights file, and the bundled rdf.yaml template is read for later use by
     * {@code run}.
     *
     * @param weightsPath weights file, or directory containing an {@code rdf.yaml}
     * @return the created model
     * @throws IllegalArgumentException if {@code weightsPath} is neither an existing file nor such a directory
     * @throws IOException              if the model cannot be created
     */
    public static Cellpose init(String weightsPath) throws IOException {
        File wFile = new File(weightsPath);
        File rdfFile = new File(wFile, "rdf.yaml");
        if (wFile.isDirectory() && rdfFile.isFile()) {
            return Cellpose.init(ModelDescriptorFactory.readFromLocalFile(rdfFile.getAbsolutePath()));
        }
        if (!wFile.isFile()) {
            throw new IllegalArgumentException("The path provided does not correspond to an existing file: " + weightsPath);
        }
        Cellpose cellpose = new Cellpose(null, null, weightsPath, null, null);
        cellpose.rdfString = Cellpose.readRdfTemplate();
        return cellpose;
    }

    /**
     * Reads the bundled rdf.yaml template as UTF-8.
     *
     * @return the template text, or {@code null} if the resource is missing or unreadable; this is a
     *         deliberate best effort, since only {@code run} needs it and that method reports the absence
     */
    private static String readRdfTemplate() {
        if (RDF_URL == null) {
            return null;
        }
        try (InputStream in = RDF_URL.openStream();
             ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            byte[] buffer = new byte[8192];
            int len;
            while ((len = in.read(buffer)) != -1) {
                baos.write(buffer, 0, len);
            }
            return baos.toString(StandardCharsets.UTF_8.name());
        }
        catch (IOException ignored) {
            return null;
        }
    }

    /**
     * Creates a Cellpose model from a Bioimage.io descriptor that provides PyTorch state-dict weights.
     *
     * @param descriptor descriptor of a model tagged or named "cellpose"
     * @return the created model, flagged as a Bioimage.io model
     * @throws IllegalArgumentException if the descriptor is not a Cellpose model or lacks PyTorch weights
     * @throws IOException              if the model cannot be created
     */
    public static Cellpose init(ModelDescriptor descriptor) throws IOException {
        boolean taggedCellpose = descriptor.getTags().stream().anyMatch(tag -> "cellpose".equalsIgnoreCase(tag));
        if (!taggedCellpose && !descriptor.getName().toLowerCase().contains("cellpose")) {
            throw new IllegalArgumentException("This model does not seem to be a cellpose model from the Bioimage.io");
        }
        WeightFormat pytorchWeights = descriptor.getWeights().getModelWeights(ModelWeight.getPytorchID());
        if (pytorchWeights == null) {
            throw new IllegalArgumentException("The model provided does not have weights in the required format, " + ModelWeight.getPytorchID() + ".");
        }
        String modelFile = descriptor.getModelPath() + File.separator + pytorchWeights.getArchitecture().getSource();
        String callable = pytorchWeights.getArchitecture().getCallable();
        String weightsFile = descriptor.getModelPath() + File.separator + pytorchWeights.getSource();
        Map<String, Object> kwargs = pytorchWeights.getArchitecture().getKwargs();
        Cellpose model = new Cellpose(modelFile, callable, weightsFile, kwargs, descriptor);
        model.isBMZ = true;
        return model;
    }

    /**
     * Same as {@link #fromPretained(String, String, boolean)} using the {@code models} folder of the working
     * directory.
     */
    public static Cellpose fromPretained(String pretrainedModel, boolean install) throws IOException, InterruptedException, ExecutionException {
        return Cellpose.fromPretained(pretrainedModel, new File("models").getAbsolutePath(), install);
    }

    /**
     * Creates a pretrained Cellpose model, either an official one or one from the Bioimage.io.
     *
     * @param pretrainedModel official model name, or Bioimage.io model name/ID
     * @param modelsDir       directory where models are searched for and downloaded into
     * @param install         {@code true} to download the model; {@code false} to only look for a local copy
     * @return the model, or {@code null} when {@code install} is {@code false} and no local copy was found
     * @throws IllegalArgumentException if {@code install} is {@code true} and the model is unknown
     */
    public static Cellpose fromPretained(String pretrainedModel, String modelsDir, boolean install) throws IOException, InterruptedException, ExecutionException {
        boolean official = PRETRAINED_CELLPOSE_MODELS.contains(pretrainedModel);
        if (official && !install) {
            String weightsPath = Cellpose.fileIsCellpose(pretrainedModel, modelsDir);
            return weightsPath == null ? null : Cellpose.init(weightsPath);
        }
        if (official) {
            String path = Cellpose.donwloadPretrainedOfficial(pretrainedModel, modelsDir, null);
            return Cellpose.init(path);
        }
        if (!install) {
            List<ModelDescriptor> localModels = ModelDescriptorFactory.getModelsAtLocalRepo();
            // Null-safe comparisons: local descriptors are not guaranteed to define both ID and name.
            ModelDescriptor model = localModels.stream()
                    .filter(md -> pretrainedModel.equals(md.getModelID()) || pretrainedModel.equalsIgnoreCase(md.getName()))
                    .findFirst().orElse(null);
            return model == null ? null : Cellpose.init(model);
        }
        BioimageioRepo br = BioimageioRepo.connect();
        ModelDescriptor descriptor = br.selectByName(pretrainedModel);
        if (descriptor == null) {
            descriptor = br.selectByID(pretrainedModel);
        }
        if (descriptor == null) {
            throw new IllegalArgumentException("The model does not correspond to one of the available pretrained cellpose models. To find a list of available cellpose models, please run Cellpose.getPretrainedList()");
        }
        String path = BioimageioRepo.downloadModel(descriptor, modelsDir);
        descriptor.addModelPath(Paths.get(path, new String[0]));
        return Cellpose.init(descriptor);
    }

    /**
     * Looks for an already downloaded official model under {@code modelsDir}.
     *
     * @param modelName official model name or weights file name (a {@code .pth}/{@code .pt} suffix is ignored)
     * @param modelsDir directory containing the timestamped download folders
     * @return the weights path, or {@code null} if no complete copy was found
     * @throws IllegalArgumentException if {@code modelName} is not an official model
     */
    public static String findPretrainedModelInstalled(String modelName, String modelsDir) {
        if (modelName.endsWith(".pth")) {
            modelName = modelName.substring(0, modelName.length() - 4);
        } else if (modelName.endsWith(".pt")) {
            modelName = modelName.substring(0, modelName.length() - 3);
        }
        if (!ALIAS.containsKey(modelName) && !MODEL_SIZE.containsKey(modelName)) {
            throw new IllegalArgumentException("Only supported pretrained models are: " + ALIAS.keySet());
        }
        for (String dir : Cellpose.findDirectoriesWithPattern(modelsDir, modelName)) {
            String path = Cellpose.lookForModelInDir(modelName, dir);
            if (path != null) {
                return path;
            }
        }
        return null;
    }

    /**
     * Resolves an official model name to a local weights file.
     * <p>
     * The lookup order is:
     * <ol>
     * <li>the name itself as a file path;</li>
     * <li>the timestamped download folders;</li>
     * <li>{@code modelsDir} directly.</li>
     * </ol>
     */
    private static String fileIsCellpose(String pretrainedModel, String modelsDir) {
        File pretrainedFile = new File(pretrainedModel);
        if (pretrainedFile.isFile() && Cellpose.isCellposeFile(pretrainedFile)) {
            return pretrainedFile.getAbsolutePath();
        }
        String path = Cellpose.findPretrainedModelInstalled(pretrainedModel, modelsDir);
        if (path != null) {
            return path;
        }
        return Cellpose.lookForModelInDir(pretrainedModel, modelsDir);
    }

    /** @return whether the file has an official weights name and the expected size */
    private static boolean isCellposeFile(File pretrainedFile) {
        Long expected = MODEL_SIZE.get(pretrainedFile.getName());
        return expected != null && expected.longValue() == pretrainedFile.length();
    }

    /**
     * Looks for a complete official weights file directly inside {@code modelsDir}.
     *
     * @param modelName official model name or weights file name
     * @return the weights path, or {@code null} if missing, incomplete, or {@code modelName} is unknown
     */
    private static String lookForModelInDir(String modelName, String modelsDir) {
        File dir = new File(modelsDir);
        if (!dir.isDirectory()) {
            return null;
        }
        String fileName = MODEL_SIZE.containsKey(modelName) ? modelName : ALIAS.get(modelName);
        Long expectedSize = fileName == null ? null : MODEL_SIZE.get(fileName);
        if (expectedSize == null) {
            return null;
        }
        String weightsPath = dir.getAbsolutePath() + File.separator + fileName;
        File weightsFile = new File(weightsPath);
        if (weightsFile.isFile() && weightsFile.length() == expectedSize.longValue()) {
            return weightsPath;
        }
        return null;
    }

    /**
     * Lists available pretrained models.
     * <p>
     * The list holds Bioimage.io models of the "cellpose" family followed by the official ones. If the
     * repository query is interrupted, only the official models are returned and the thread's interrupt
     * flag is restored.
     *
     * @return a mutable list of model names
     */
    public static List<String> getPretrainedList() {
        List<String> list = new ArrayList<String>();
        try {
            BioimageioRepo br = BioimageioRepo.connect();
            Map<String, ModelDescriptor> models = br.listAllModels(false);
            list = models.values().stream()
                    .filter(md -> "cellpose".equals(md.getModelFamily()))
                    .map(md -> md.getName())
                    .collect(Collectors.toCollection(ArrayList::new));
        }
        catch (InterruptedException interruptedException) {
            Thread.currentThread().interrupt();
        }
        list.addAll(PRETRAINED_CELLPOSE_MODELS);
        return list;
    }

    /** Same as {@link #donwloadPretrained(String, String, Consumer)} without progress reporting. */
    public static String donwloadPretrained(String modelName, String downloadDir) throws ExecutionException, InterruptedException, IOException {
        return Cellpose.donwloadPretrained(modelName, downloadDir, null);
    }

    /**
     * Downloads a pretrained model.
     * <p>
     * The official Cellpose server is tried first, then the Bioimage.io.
     *
     * @param progressConsumer optional progress callback, may be {@code null}
     * @return path of the downloaded model
     * @throws IllegalArgumentException if the model is unknown to both sources
     */
    public static String donwloadPretrained(String modelName, String downloadDir, Consumer<Double> progressConsumer) throws ExecutionException, InterruptedException, IOException {
        String path = Cellpose.donwloadPretrainedOfficial(modelName, downloadDir, progressConsumer);
        if (path == null) {
            path = Cellpose.donwloadPretrainedBioimageio(modelName, downloadDir, progressConsumer);
        }
        if (path == null) {
            throw new IllegalArgumentException("The model does not correspond to one of the available pretrained cellpose models. To find a list of available cellpose models, please run Cellpose.getPretrainedList()");
        }
        return path;
    }

    /** Downloads a model from the Bioimage.io by name or ID; returns {@code null} if not found. */
    private static String donwloadPretrainedBioimageio(String modelName, String downloadDir, Consumer<Double> progressConsumer) throws InterruptedException, IOException {
        BioimageioRepo br = BioimageioRepo.connect();
        ModelDescriptor descriptor = br.selectByName(modelName);
        if (descriptor == null) {
            descriptor = br.selectByID(modelName);
        }
        if (descriptor == null) {
            return null;
        }
        return BioimageioRepo.downloadModel(descriptor, downloadDir, progressConsumer);
    }

    /**
     * Downloads the files of an official model into a new timestamped folder under {@code downloadDir}.
     *
     * @return path of the downloaded weights file, or {@code null} if {@code modelName} is not official
     */
    private static String donwloadPretrainedOfficial(String modelName, String downloadDir, Consumer<Double> progressConsumer) throws ExecutionException, InterruptedException {
        String[] files = MODEL_REQ.get(modelName);
        if (files == null) {
            return null;
        }
        ArrayList<URL> urls = new ArrayList<URL>();
        for (String file : files) {
            try {
                urls.add(new URL(String.format(CELLPOSE_URL, file)));
            }
            catch (MalformedURLException e) {
                // Skipping a file silently would yield an incomplete model; fail loudly instead.
                throw new IllegalStateException("Invalid Cellpose download URL for file: " + file, e);
            }
        }
        String targetDir = downloadDir + File.separator + MultiFileDownloader.addTimeStampToFileName(modelName, true);
        MultiFileDownloader mfd = new MultiFileDownloader(urls, new File(targetDir));
        if (progressConsumer != null) {
            mfd.setPartialProgressConsumer(progressConsumer);
        }
        mfd.download();
        return targetDir + File.separator + files[0];
    }

    /**
     * Lists sub-directories of {@code folderPath} named {@code <keyword>_dddddddd_dddddd}.
     *
     * @return absolute paths of matching directories; empty if {@code folderPath} is not a readable directory
     */
    private static List<String> findDirectoriesWithPattern(String folderPath, String keyword) {
        Pattern pattern = Pattern.compile("^" + Pattern.quote(keyword) + "_\\d{8}_\\d{6}$");
        File[] children = new File(folderPath).listFiles();
        if (children == null) {
            return new ArrayList<String>();
        }
        return Arrays.stream(children)
                .filter(File::isDirectory)
                .filter(ff -> pattern.matcher(ff.getName()).matches())
                .map(File::getAbsolutePath)
                .collect(Collectors.toList());
    }

    /** Manual smoke test: runs {@code cyto2} twice on an empty 512x512x3 image and prints timings. */
    public static <T extends RealType<T> & NativeType<T>> void main(String[] args) throws IOException, InterruptedException, ExecutionException, LoadModelException, RunModelException {
        Cellpose model = Cellpose.fromPretained("cyto2", false);
        model.loadModel();
        ArrayImg rai = ArrayImgs.floats((long[])new long[]{512L, 512L, 3L});
        ArrayList rais = new ArrayList();
        rais.add(rai);
        long tt = System.currentTimeMillis();
        List res = model.inference(rais);
        System.out.println(System.currentTimeMillis() - tt);
        tt = System.currentTimeMillis();
        List rees = model.inference(rais);
        System.out.println(System.currentTimeMillis() - tt);
        model.close();
        System.out.println(false);
    }

    static {
        MODEL_REQ.put("cyto2", new String[]{"cyto2torch_0", "size_cyto2torch_0.npy"});
        MODEL_REQ.put("cyto3", new String[]{"cyto3", "size_cyto3.npy"});
        MODEL_REQ.put("cyto", new String[]{"cytotorch_0", "size_cytotorch_0.npy"});
        MODEL_REQ.put("nuclei", new String[]{"nucleitorch_0", "size_nucleitorch_0.npy"});
        ALIAS = new HashMap<String, String>();
        ALIAS.put("cyto2", "cyto2torch_0");
        ALIAS.put("cyto3", "cyto3");
        ALIAS.put("cyto", "cytotorch_0");
        ALIAS.put("nuclei", "nucleitorch_0");
        MODEL_SIZE = new HashMap<String, Long>();
        MODEL_SIZE.put("cyto2torch_0", 26563614L);
        MODEL_SIZE.put("cyto3", 26566255L);
        MODEL_SIZE.put("cytotorch_0", 26563614L);
        MODEL_SIZE.put("nucleitorch_0", 26563614L);
        LOAD_MODEL_CODE_ABSTRACT = "if 'denoise' not in globals().keys():" + System.lineSeparator() + "  from cellpose import denoise" + System.lineSeparator() + "  globals()['denoise'] = denoise" + System.lineSeparator() + "if 'np' not in globals().keys():" + System.lineSeparator() + "  import numpy as np" + System.lineSeparator() + "  globals()['np'] = np" + System.lineSeparator() + "if 'os' not in globals().keys():" + System.lineSeparator() + "  import os" + System.lineSeparator() + "  globals()['os'] = os" + System.lineSeparator() + "if 'shared_memory' not in globals().keys():" + System.lineSeparator() + "  from multiprocessing import shared_memory" + System.lineSeparator() + "  globals()['shared_memory'] = shared_memory" + System.lineSeparator() + MODEL_VAR_NAME + " = denoise.CellposeDenoiseModel(gpu=%s, pretrained_model='%s')" + System.lineSeparator() + "globals()['" + MODEL_VAR_NAME + "'] = " + MODEL_VAR_NAME + System.lineSeparator();
        RDF_URL = Cellpose.class.getClassLoader().getResource(PATH_TO_RDF);
    }
}

