package io.bioimage.modelrunner.model.python;

import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.description.weights.ModelDependencies;
import io.bioimage.modelrunner.bioimageio.description.weights.ModelWeight;
import io.bioimage.modelrunner.bioimageio.tiling.ImageInfo;
import io.bioimage.modelrunner.bioimageio.tiling.TileCalculator;
import io.bioimage.modelrunner.bioimageio.tiling.TileMaker;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.processing.Processing;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.utils.CommonUtils;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;

/**
 * PyTorch-based model wrapper driven by a bioimage.io {@link ModelDescriptor}.
 *
 * <p>On top of {@link DLModelPytorchProtected} this class adds:
 * <ul>
 *   <li>tiled inference, with tile sizes computed by a {@link TileCalculator} from the descriptor
 *       (enabled by default);</li>
 *   <li>pre- and post-processing as declared in the descriptor (via {@link Processing}) when
 *       tiling is used;</li>
 *   <li>a check of the descriptor's PyTorch dependencies against the model's environment.</li>
 * </ul>
 */
public class BioimageIoModelPytorchProtected extends DLModelPytorchProtected {

    /** Descriptor of the model; drives tiling, processing and dependency checks. */
    protected ModelDescriptor descriptor;

    /** Computes the optimal tile size for a given set of input images. */
    protected TileCalculator tileCalculator;

    /**
     * Creates the wrapper with tiling enabled.
     *
     * <p>NOTE(review): the decompiler reports this constructor was originally {@code protected};
     * it is kept {@code public} here for compatibility with existing callers.
     *
     * @param str first argument forwarded unchanged to the superclass constructor
     * @param str2 second argument forwarded unchanged to the superclass constructor
     * @param str3 third argument forwarded unchanged to the superclass constructor
     * @param map keyword-style arguments forwarded unchanged to the superclass constructor
     * @param modelDescriptor descriptor of the model; used for tiling and processing
     * @param z boolean flag forwarded unchanged to the superclass constructor
     * @throws IOException if the superclass constructor fails
     */
    public BioimageIoModelPytorchProtected(String str, String str2, String str3, Map<String, Object> map, ModelDescriptor modelDescriptor, boolean z) throws IOException {
        super(str, str2, str3, map, z);
        this.tiling = true;
        this.descriptor = modelDescriptor;
        this.tileCalculator = TileCalculator.init(modelDescriptor);
    }

    /**
     * Same as the six-argument constructor with the trailing boolean flag set to {@code false}.
     *
     * @throws IOException if the superclass constructor fails
     */
    public BioimageIoModelPytorchProtected(String str, String str2, String str3, Map<String, Object> map, ModelDescriptor modelDescriptor) throws IOException {
        this(str, str2, str3, map, modelDescriptor, false);
    }

    /**
     * Runs the model and returns newly created output tensors.
     *
     * <p>Without tiling, empty output tensors are created from the descriptor and filled by
     * {@code runNoTiles}. With tiling, blank tensors of the size computed by the {@link TileMaker}
     * are created, and the inputs are pre-processed, run tile by tile and post-processed.
     *
     * @param list input tensors
     * @return the output tensors, one per output declared in the descriptor
     * @throws RunModelException if the model is not loaded or inference fails
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> run(List<Tensor<R>> list) throws RunModelException {
        if (!isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            List<Tensor<T>> outputs = createOutputTensors();
            runNoTiles(list, outputs);
            return outputs;
        }
        TileMaker tiles = createTileMaker(list);
        return runBMZ(list, createOutputTensors(tiles), tiles);
    }

    /**
     * Computes the optimal tiling for the given inputs.
     *
     * @param inputs input tensors whose names, axes and shapes determine the tile size
     * @return the tile maker built from the descriptor and the optimal tile size
     */
    private <R extends RealType<R> & NativeType<R>> TileMaker createTileMaker(List<Tensor<R>> inputs) {
        List<ImageInfo> imageInfos = inputs.stream()
                .map(tensor -> new ImageInfo(tensor.getName(), tensor.getAxesOrderString(),
                        tensor.getData().dimensionsAsLongArray()))
                .collect(Collectors.toList());
        return TileMaker.build(this.descriptor, this.tileCalculator.getOptimalTileSize(imageInfos));
    }

    /**
     * Creates one blank output tensor per descriptor output, sized as computed by the tile maker.
     *
     * @param tileMaker provides the full output image size for each output tensor name
     * @return the blank output tensors, in descriptor order
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputTensors(TileMaker tileMaker) {
        // Raw list: the element type produced from the descriptor's data-type string cannot be
        // checked against T by the compiler.
        List outputs = new ArrayList();
        for (TensorSpec spec : this.descriptor.getOutputTensors()) {
            outputs.add(Tensor.buildBlankTensor(spec.getName(), spec.getAxesOrder(),
                    tileMaker.getOutputImageSize(spec.getName()),
                    CommonUtils.getImgLib2DataType(spec.getDataType())));
        }
        return outputs;
    }

    /**
     * Creates one empty (data-less) output tensor per descriptor output.
     *
     * @return the empty output tensors, in descriptor order
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputTensors() {
        List outputs = new ArrayList();
        for (TensorSpec spec : this.descriptor.getOutputTensors()) {
            outputs.add(Tensor.buildEmptyTensor(spec.getName(), spec.getAxesOrder()));
        }
        return outputs;
    }

    /**
     * Pre-processes the inputs, runs tiled inference into {@code list2}, then post-processes.
     *
     * @param list input tensors
     * @param list2 output tensors filled by tiled inference
     * @param tileMaker tiling to use
     * @return the result of post-processing {@code list2}
     * @throws RunModelException if inference fails
     */
    private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> runBMZ(List<Tensor<R>> list, List<Tensor<T>> list2, TileMaker tileMaker) throws RunModelException {
        Processing processing = Processing.init(this.descriptor);
        runTiling(processing.preprocess(list, false), list2, tileMaker);
        return processing.postprocess(list2, true);
    }

    /**
     * Runs the model writing results into the caller-provided output tensors.
     *
     * <p>With tiling, every provided output tensor must be declared by the model. Any non-empty
     * output tensor must already have exactly the size computed from the descriptor.
     *
     * @param list input tensors
     * @param list2 output tensors to fill
     * @throws RunModelException if the model is not loaded or inference fails
     * @throws IllegalArgumentException if an output tensor is unknown to the model or a
     *     pre-allocated output tensor has the wrong size
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void run(List<Tensor<T>> list, List<Tensor<R>> list2) throws RunModelException {
        if (!isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            runNoTiles(list, list2);
            return;
        }
        TileMaker tiles = createTileMaker(list);
        // Validate every output tensor; the number of tiles is unrelated to the number of outputs.
        for (Tensor<R> output : list2) {
            long[] expectedSize = tiles.getOutputImageSize(output.getName());
            if (expectedSize == null) {
                throw new IllegalArgumentException("Tensor '" + output.getName() + "' is missing in the outputs.");
            }
            // Empty tensors are allocated later; only pre-allocated ones must match exactly.
            if (!output.isEmpty()) {
                long[] actualSize = output.getData().dimensionsAsLongArray();
                if (!Arrays.equals(expectedSize, actualSize)) {
                    throw new IllegalArgumentException("Tensor '" + output.getName() + "' size is different than the expected size as defined by the rdf.yaml: " + Arrays.toString(actualSize) + " vs " + Arrays.toString(expectedSize) + ".");
                }
            }
        }
        runBMZ(list, list2, tiles);
    }

    /**
     * Lists the descriptor's PyTorch-weight dependencies that are not installed in the model's
     * environment.
     *
     * <p>Best effort: if the Mamba check itself fails with a {@link MambaInstallException}, all
     * dependencies are reported as missing.
     *
     * @return the missing dependencies (empty if all are installed)
     */
    public List<String> findMissingDependencies() {
        // NOTE(review): assumes envPath sits two directory levels below the Mamba root
        // (e.g. <root>/envs/<name>) — confirm; a shallower path would make getParentFile() null.
        Mamba mamba = new Mamba(new File(this.envPath).getParentFile().getParentFile().getAbsolutePath());
        List<String> dependencies = ModelDependencies.getDependencies(this.descriptor,
                this.descriptor.getWeights().getModelWeights(ModelWeight.getPytorchID()));
        try {
            return mamba.checkUninstalledDependenciesInEnv(this.envPath, dependencies);
        } catch (MambaInstallException e) {
            // Could not inspect the environment: conservatively treat everything as missing.
            return dependencies;
        }
    }

    /**
     * @return {@code true} if {@link #findMissingDependencies()} reports nothing missing
     */
    public boolean allDependenciesInstalled() {
        return findMissingDependencies().isEmpty();
    }
}
