/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.model.special.stardist;

import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.special.stardist.StardistAbstract;
import io.bioimage.modelrunner.tensor.Tensor;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
import net.imglib2.Interval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;
import org.apache.commons.compress.archivers.ArchiveException;

/**
 * StarDist wrapper for 3D models.
 * <p>
 * Instances are obtained through the static factories {@link #fromBioimageioModel(ModelDescriptor)}
 * and {@link #fromPretained(String, boolean)}. The private constructors used by those factories
 * reject configurations whose {@code axes} entry contains no {@code Z} axis.
 */
public class Stardist3D
extends StardistAbstract {

    /** Python module name substituted into the model-loading code template. */
    private static final String MODULE_NAME = "StarDist3D";

    /** Name of the only pretrained 3D model this class knows how to load or download. */
    private static final String PLANT_NUCLEI_3D_RESNET = "StarDist Plant Nuclei 3D ResNet";

    /** Axes string assigned to {@code scaleRangeAxes} by every constructor. */
    private static final String SCALE_RANGE_AXES = "xyzc";

    /**
     * Creates a 3D StarDist model from an explicit configuration map.
     * Unlike the private constructors, no axes validation is performed here.
     *
     * @param modelName name of the model
     * @param baseDir directory passed to the superclass
     * @param config model configuration passed to the superclass
     * @throws IOException if the superclass constructor fails
     */
    protected Stardist3D(String modelName, String baseDir, Map<String, Object> config) throws IOException {
        super(modelName, baseDir, config);
        this.scaleRangeAxes = SCALE_RANGE_AXES;
    }

    private Stardist3D(String modelName, String baseDir) throws IOException {
        super(modelName, baseDir);
        requireZAxis();
        this.scaleRangeAxes = SCALE_RANGE_AXES;
    }

    private Stardist3D(ModelDescriptor descriptor) throws IOException {
        super(descriptor);
        requireZAxis();
        this.scaleRangeAxes = SCALE_RANGE_AXES;
    }

    /**
     * Rejects configurations whose {@code axes} entry has no Z axis (i.e. 2D models).
     * Must be called after the superclass constructor has populated {@code config}.
     *
     * @throws IllegalArgumentException if the axes string does not contain {@code Z}
     */
    private void requireZAxis() {
        String axes = ((String) this.config.get("axes")).toUpperCase();
        if (!axes.contains("Z")) {
            throw new IllegalArgumentException("Trying to instantiate a StarDist2D model. Please use Stardist2D instead of Stardist3D.");
        }
    }

    @Override
    protected String createImportsCode() {
        // The template expects the module name five times, followed by the model name and base dir.
        return String.format(LOAD_MODEL_CODE_ABSTRACT, MODULE_NAME, MODULE_NAME, MODULE_NAME, MODULE_NAME, MODULE_NAME, this.name, this.basedir);
    }

    /**
     * Validates that {@code image} can be fed to this 3D model.
     * <p>
     * Accepted inputs are XYZC images whose last dimension equals {@code nChannels}, or,
     * for single-channel models only, 3D (XYZ) images.
     *
     * @param image input image
     * @throws IllegalArgumentException if the rank or channel count is incompatible
     */
    @Override
    protected <T extends RealType<T> & NativeType<T>> void checkInput(RandomAccessibleInterval<T> image) {
        long[] dims = image.dimensionsAsLongArray();
        // Only single-channel models may omit the explicit C axis.
        if (dims.length != 4 && this.nChannels != 1) {
            throw new IllegalArgumentException("Stardist3D needs an image with four dimensions: XYZC");
        }
        if (dims.length == 4 && dims[3] != (long) this.nChannels) {
            throw new IllegalArgumentException("This Stardist3D model requires " + this.nChannels + " channels.");
        }
        // A 3D model needs at least the X, Y and Z axes, so 2D images are rejected as well.
        if (dims.length > 4 || dims.length < 3) {
            throw new IllegalArgumentException("Stardist3D model requires an image with dimensions XYZC.");
        }
    }

    /**
     * Copies the mask held in {@code shma} into a new image of the same pixel type and
     * then closes {@code shma}. The shared memory is closed even if the copy fails.
     *
     * @return a copy of the shared mask
     * @throws IOException if closing the shared memory fails
     */
    @Override
    protected <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> reconstructMask() throws IOException {
        RandomAccessibleInterval<T> sharedMask = Cast.unchecked(this.shma.getSharedRAI());
        try {
            T type = Util.getTypeFromInterval(sharedMask);
            return Tensor.createCopyOfRaiInWantedDataType(sharedMask, type);
        } finally {
            // Release the shared memory whether or not the copy succeeded.
            this.shma.close();
        }
    }

    @Override
    public boolean is2D() {
        return false;
    }

    @Override
    public boolean is3D() {
        return true;
    }

    /**
     * Builds a {@code Stardist3D} from a Bioimage.io model descriptor.
     *
     * @param descriptor descriptor of a StarDist 3D Bioimage.io model
     * @return the model wrapper
     * @throws IllegalArgumentException if the descriptor has no {@code stardist} config entry,
     *         or its first input tensor has no {@code z} axis
     * @throws RuntimeException if the descriptor's model family is not {@code stardist}
     * @throws IOException if the model construction fails
     */
    public static Stardist3D fromBioimageioModel(ModelDescriptor descriptor) throws IOException {
        if (!descriptor.getConfig().getSpecMap().containsKey("stardist")) {
            throw new IllegalArgumentException("This Bioimage.io model does not correspond to a StarDist model.");
        }
        if (!"stardist".equals(descriptor.getModelFamily())) {
            throw new RuntimeException("Please first install StarDist with 'StardistAbstract.installRequirements()'");
        }
        if (!descriptor.getInputTensors().get(0).getAxesOrder().contains("z")) {
            throw new IllegalArgumentException("This StarDist model is not 3D");
        }
        return new Stardist3D(descriptor);
    }

    /**
     * Same as {@link #fromPretained(String, String, boolean)} using the absolute path of
     * {@code models} (relative to the working directory) as install directory.
     */
    public static Stardist3D fromPretained(String pretrainedModel, boolean install) throws IOException, InterruptedException {
        return Stardist3D.fromPretained(pretrainedModel, new File("models").getAbsolutePath(), install);
    }

    /**
     * Loads a pretrained StarDist 3D model by name.
     * <p>
     * When {@code install} is {@code true} the model is downloaded into {@code installDir}
     * and loaded through {@link #fromBioimageioModel(ModelDescriptor)}. Otherwise, the
     * descriptors returned by {@code ModelDescriptorFactory.getModelsAtLocalRepo()} are
     * searched by name and {@code installDir} is not used.
     *
     * @param pretrainedModel model name; currently only {@value #PLANT_NUCLEI_3D_RESNET} is supported
     * @param installDir download directory, used only when {@code install} is true
     * @param install whether to download the model
     * @return the model, or {@code null} if {@code install} is false and no local model matches
     * @throws IllegalArgumentException if {@code pretrainedModel} is not a known model name
     */
    public static Stardist3D fromPretained(String pretrainedModel, String installDir, boolean install) throws IOException, InterruptedException {
        if (!PLANT_NUCLEI_3D_RESNET.equals(pretrainedModel)) {
            throw new IllegalArgumentException("There is no Stardist3D model called: " + pretrainedModel);
        }
        if (install) {
            String path = BioimageioRepo.connect().downloadByName(PLANT_NUCLEI_3D_RESNET, installDir);
            return Stardist3D.fromBioimageioModel(ModelDescriptorFactory.readFromLocalFile(path));
        }
        ModelDescriptor md = ModelDescriptorFactory.getModelsAtLocalRepo().stream()
                .filter(mm -> PLANT_NUCLEI_3D_RESNET.equals(mm.getName()))
                .findFirst()
                .orElse(null);
        return md == null ? null : new Stardist3D(md);
    }

    /**
     * Same as {@link #donwloadPretrained(String, String, Consumer)} without progress reporting.
     */
    public static String donwloadPretrained(String modelName, String downloadDir) throws ExecutionException, InterruptedException, IOException {
        return Stardist3D.donwloadPretrained(modelName, downloadDir, null);
    }

    /**
     * Downloads a pretrained StarDist 3D model from the Bioimage.io repository.
     *
     * @param modelName model name or Bioimage.io ID
     * @param downloadDir directory to download into
     * @param progressConsumer optional progress callback (may be {@code null})
     * @return the value returned by {@code BioimageioRepo.downloadModel}
     *         (presumably the download location — confirm against that API)
     * @throws IllegalArgumentException if no model matches {@code modelName}
     */
    public static String donwloadPretrained(String modelName, String downloadDir, Consumer<Double> progressConsumer) throws InterruptedException, IOException {
        return Stardist3D.donwloadPretrainedBioimageio(modelName, downloadDir, progressConsumer);
    }

    /** Resolves {@code modelName} first as a name, then as an ID, and downloads the match. */
    private static String donwloadPretrainedBioimageio(String modelName, String downloadDir, Consumer<Double> progressConsumer) throws InterruptedException, IOException {
        BioimageioRepo br = BioimageioRepo.connect();
        ModelDescriptor descriptor = br.selectByName(modelName);
        if (descriptor == null) {
            descriptor = br.selectByID(modelName);
        }
        if (descriptor == null) {
            throw new IllegalArgumentException("The model does not correspond to one of the available pretrained StarDist3D models. To find a list of available StarDist3D models, please run StarDist3D.getPretrainedList()");
        }
        return BioimageioRepo.downloadModel(descriptor, downloadDir, progressConsumer);
    }

    /** Manual smoke test: runs the local plant-nuclei model on a blank 3D float image. */
    public static void main(String[] args) throws IOException, InterruptedException, RuntimeException, MambaInstallException, LoadEngineException, RunModelException, ArchiveException, URISyntaxException, LoadModelException {
        Stardist3D.installRequirements();
        Stardist3D model = Stardist3D.fromPretained(PLANT_NUCLEI_3D_RESNET, false);
        if (model == null) {
            throw new IllegalStateException("Model '" + PLANT_NUCLEI_3D_RESNET + "' was not found locally; call fromPretained(name, true) to download it.");
        }
        try {
            ArrayImg img = ArrayImgs.floats(116L, 120L, 66L);
            model.run(img);
        } finally {
            // Always release the model, even if inference fails.
            model.close();
        }
        System.out.println(true);
    }
}

