package io.bioimage.modelrunner.gui.adapter;

import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.apposed.appose.Types;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException;
import io.bioimage.modelrunner.bioimageio.description.weights.ModelWeight;
import io.bioimage.modelrunner.engine.installation.EngineInstall;
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.BaseModel;
import io.bioimage.modelrunner.model.java.BioimageIoModelJava;
import io.bioimage.modelrunner.model.python.BioimageIoModelPytorch;
import io.bioimage.modelrunner.model.python.DLModelPytorchProtected;
import io.bioimage.modelrunner.model.special.stardist.Stardist2D;
import io.bioimage.modelrunner.model.special.stardist.StardistAbstract;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.utils.Constants;
import java.io.Closeable;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import org.apache.commons.compress.archivers.ArchiveException;

/* loaded from: input_file:io/bioimage/modelrunner/gui/adapter/RunnerAdapter.class */
/**
 * Base class that lets a GUI load, run and close the model described by a
 * {@link ModelDescriptor}.
 *
 * <p>Life cycle: construct, {@link #load()}, {@link #run(List)} (any number of times),
 * {@link #close()}. Once closed, an adapter cannot be loaded or run again.
 *
 * <p>Which concrete {@link BaseModel} is created depends on the descriptor's model family and
 * supported weights (see {@link #initModel(boolean)}).
 *
 * <p>NOTE(review): no synchronization is performed here; presumably instances are used from a
 * single thread at a time - confirm against callers.
 */
public abstract class RunnerAdapter implements Closeable {
    /** Error message used whenever an operation is attempted on a closed adapter. */
    private static final String CLOSED_MSG = "The model has already been closed";
    /** Error message used whenever the model is run before being loaded. */
    private static final String NOT_LOADED_MSG = "Please first load the model";

    /**
     * Descriptor of the model. Not final: while loading a StarDist model without the StarDist
     * requirements installed, it is replaced by a descriptor re-read from the local rdf file.
     */
    protected ModelDescriptor descriptor;
    /** Directory passed to the engine installer and to the Java model factory. */
    protected final String enginesPath;
    /** Optional class loader handed to the Java model factory; {@code null} if none was given. */
    protected final ClassLoader classLoader;
    /** The underlying model; {@code null} until {@link #load(boolean)} has created it. */
    protected BaseModel model;
    /** Whether {@link #close()} has been called. */
    protected boolean closed;
    /** Whether the last {@link #load(boolean)} completed successfully and the adapter is open. */
    protected boolean loaded;

    /**
     * Converts the test inputs of the model into images, one per input tensor.
     *
     * <p>NOTE(review): the name suggests implementations also show the images to the user, and
     * the {@code String} values are presumably file paths - neither is visible from this class.
     *
     * @param <T> pixel type of the returned images
     * @param linkedHashMap the map returned by {@link #getTestInputs()}
     * @return the image to feed to the model for each input tensor specification
     */
    protected abstract <T extends RealType<T> & NativeType<T>> LinkedHashMap<TensorSpec, RandomAccessibleInterval<T>> displayTestInputs(LinkedHashMap<TensorSpec, String> linkedHashMap);

    /**
     * @return for each input tensor specification, a {@code String} identifying its test input,
     *         as consumed by {@link #displayTestInputs(LinkedHashMap)}
     */
    protected abstract LinkedHashMap<TensorSpec, String> getTestInputs();

    /**
     * Creates an adapter that uses the {@code "engines"} directory (resolved against the
     * current working directory) and no custom class loader.
     *
     * @param modelDescriptor descriptor of the model to run
     */
    protected RunnerAdapter(ModelDescriptor modelDescriptor) {
        this(modelDescriptor, new File("engines").getAbsolutePath(), (ClassLoader) null);
    }

    /**
     * Creates an adapter that uses the {@code "engines"} directory (resolved against the
     * current working directory) and the given class loader.
     *
     * @param modelDescriptor descriptor of the model to run
     * @param classLoader class loader handed to the Java model factory; may be {@code null}
     */
    protected RunnerAdapter(ModelDescriptor modelDescriptor, ClassLoader classLoader) {
        this(modelDescriptor, new File("engines").getAbsolutePath(), classLoader);
    }

    /**
     * Creates an adapter that uses the given engines directory and no custom class loader.
     *
     * @param modelDescriptor descriptor of the model to run
     * @param enginesPath directory where the engines are (or will be) installed
     */
    protected RunnerAdapter(ModelDescriptor modelDescriptor, String enginesPath) {
        this(modelDescriptor, enginesPath, (ClassLoader) null);
    }

    /**
     * Creates an adapter that uses the given engines directory and class loader.
     *
     * @param modelDescriptor descriptor of the model to run
     * @param enginesPath directory where the engines are (or will be) installed
     * @param classLoader class loader handed to the Java model factory; may be {@code null}
     */
    protected RunnerAdapter(ModelDescriptor modelDescriptor, String enginesPath, ClassLoader classLoader) {
        this.descriptor = modelDescriptor;
        this.enginesPath = enginesPath;
        this.classLoader = classLoader;
    }

    /**
     * @return the descriptor currently used by this adapter
     */
    public ModelDescriptor getDescriptor() {
        return this.descriptor;
    }

    /**
     * Loads the model, installing any missing engine or requirement first.
     * Equivalent to {@code load(true)}.
     *
     * @throws LoadModelException if anything goes wrong while preparing or loading the model
     * @throws IllegalStateException if the adapter has already been closed
     */
    public void load() throws LoadModelException {
        load(true);
    }

    /**
     * Creates the underlying model and loads it.
     *
     * @param install whether missing engines / Python requirements may be installed
     * @throws LoadModelException if anything goes wrong while preparing or loading the model;
     *         its message is the stack trace of the original failure
     * @throws IllegalStateException if the adapter has already been closed
     */
    public void load(boolean install) throws LoadModelException {
        ensureNotClosed();
        try {
            initModel(install);
            this.model.loadModel();
            this.loaded = true;
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // The interruption is reported as a load failure; keep the flag set so that
                // code further up the stack can still observe it.
                Thread.currentThread().interrupt();
            }
            throw new LoadModelException(Types.stackTrace(e));
        }
    }

    /**
     * Runs the loaded model on the given input tensors.
     *
     * @param <T> pixel type of the inputs
     * @param <R> pixel type of the outputs
     * @param list input tensors
     * @return the output tensors produced by the model
     * @throws IllegalStateException if the adapter is closed or the model has not been loaded
     * @throws FileNotFoundException declared for callers; propagated from the model if thrown
     * @throws RunModelException if the model fails to run
     * @throws IOException declared for callers; propagated from the model if thrown
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<R>> run(List<Tensor<T>> list) throws FileNotFoundException, RunModelException, IOException {
        ensureReadyToRun();
        return this.model.run(list);
    }

    /**
     * Runs the loaded model on its own test inputs, as provided by {@link #getTestInputs()} and
     * {@link #displayTestInputs(LinkedHashMap)}.
     *
     * @param <T> pixel type of the inputs
     * @param <R> pixel type of the outputs
     * @return the output tensors produced by the model
     * @throws IllegalStateException if the adapter is closed or the model has not been loaded
     * @throws FileNotFoundException declared for callers; propagated if thrown
     * @throws ModelSpecsException declared for callers; propagated if thrown
     * @throws RunModelException if the model fails to run
     * @throws IOException declared for callers; propagated if thrown
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<R>> runOnTestImages() throws FileNotFoundException, ModelSpecsException, RunModelException, IOException {
        // Fail before the test images are prepared, not after.
        ensureReadyToRun();
        LinkedHashMap<TensorSpec, RandomAccessibleInterval<T>> testImages = displayTestInputs(getTestInputs());
        List<Tensor<T>> testTensors = createTestTensorList(testImages);
        return this.model.run(testTensors);
    }

    /**
     * Wraps every image into a {@link Tensor} named and laid out after its specification,
     * preserving the iteration order of the map.
     */
    private <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createTestTensorList(LinkedHashMap<TensorSpec, RandomAccessibleInterval<T>> images) {
        return images.entrySet().stream()
                .map(entry -> Tensor.build(entry.getKey().getName(), entry.getKey().getAxesOrder(), entry.getValue()))
                .collect(Collectors.toList());
    }

    /**
     * @return whether {@link #close()} has been called on this adapter
     */
    public boolean isClosed() {
        return this.closed;
    }

    /**
     * Creates {@link #model} according to the descriptor's model family and weights:
     * <ul>
     *   <li>StarDist family: a {@link Stardist2D} model. If the StarDist requirements are missing
     *       and may not be installed, the descriptor is re-read from the local rdf file and the
     *       selection is retried with it.</li>
     *   <li>bioimage.io family whose supported weights are anything other than "PyTorch only":
     *       a {@link BioimageIoModelJava}, created with {@link #classLoader} when one was
     *       given.</li>
     *   <li>bioimage.io family with PyTorch as the only supported weights: a
     *       {@link BioimageIoModelPytorch}.</li>
     * </ul>
     *
     * @param install whether missing engines / Python requirements may be installed
     * @throws IllegalArgumentException if the model family is none of the above
     */
    private void initModel(boolean install) throws IOException, LoadEngineException, InterruptedException, MambaInstallException, ArchiveException, URISyntaxException {
        List<String> weightNames = this.descriptor.getWeights().getAllSuportedWeightNames();

        if (this.descriptor.getModelFamily().equals(ModelDescriptor.STARDIST)) {
            boolean stardistInstalled = StardistAbstract.isInstalled();
            if (!stardistInstalled && !install) {
                // NOTE(review): the 'false' flag presumably makes the factory return a
                // descriptor that is no longer of the StarDist family (otherwise this retry
                // would never terminate) - confirm against ModelDescriptorFactory.
                this.descriptor = ModelDescriptorFactory.readFromLocalFile(this.descriptor.getModelPath() + File.separator + Constants.RDF_FNAME, false);
                // Retry through this same method so a configured class loader is not lost.
                initModel(install);
                return;
            }
            if (!stardistInstalled) {
                StardistAbstract.installRequirements();
            }
            this.model = Stardist2D.fromBioimageioModel(this.descriptor);
            return;
        }

        if (!this.descriptor.getModelFamily().equals(ModelDescriptor.BIOIMAGEIO)) {
            throw new IllegalArgumentException("Model not supported");
        }

        boolean pytorchOnly = weightNames.size() == 1 && weightNames.contains(ModelWeight.getPytorchID());
        if (pytorchOnly) {
            if (install && !DLModelPytorchProtected.isInstalled()) {
                DLModelPytorchProtected.installRequirements();
            }
            this.model = BioimageIoModelPytorch.create(this.descriptor);
            return;
        }

        if (install) {
            EngineInstall.installEnginesForModelInDir(this.descriptor, this.enginesPath, progress -> {
                // Math.round returns a long, so "/ 100" is an integer division: only whole
                // percents are printed. Kept as is to leave the console output unchanged.
                System.out.println("Downloading engines for " + this.descriptor.getName() + ":" + (Math.round(progress.doubleValue() * 10000.0d) / 100) + "%");
            });
        }
        if (this.classLoader == null) {
            this.model = BioimageIoModelJava.createBioimageioModel(this.descriptor.getModelPath(), this.enginesPath);
        } else {
            this.model = BioimageIoModelJava.createBioimageioModel(this.descriptor.getModelPath(), this.enginesPath, this.classLoader);
        }
    }

    /** Throws if {@link #close()} has already been called. */
    private void ensureNotClosed() {
        if (this.closed) {
            throw new IllegalStateException(CLOSED_MSG);
        }
    }

    /** Throws unless the adapter is open and its model exists and reports itself as loaded. */
    private void ensureReadyToRun() {
        ensureNotClosed();
        // model stays null until load() has been called, so it must be checked before use.
        if (this.model == null || !this.model.isLoaded()) {
            throw new IllegalStateException(NOT_LOADED_MSG);
        }
    }

    /**
     * @return {@code true} if the model was loaded successfully and the adapter is not closed
     */
    public boolean isLoaded() {
        return !isClosed() && this.loaded;
    }

    /**
     * Closes the underlying model, if one was created, and marks this adapter as closed.
     * Calling this method on an already closed adapter has no effect.
     *
     * @throws IOException if closing the underlying model fails; the adapter is nevertheless
     *         left in the closed state
     */
    @Override // java.io.Closeable, java.lang.AutoCloseable
    public void close() throws IOException {
        if (this.closed) {
            return;
        }
        // Flip the flags first so the adapter ends up closed even if the model's close() throws.
        this.closed = true;
        this.loaded = false;
        if (this.model != null) {
            this.model.close();
        }
    }
}
