/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.model.python;

import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.description.weights.ModelDependencies;
import io.bioimage.modelrunner.bioimageio.description.weights.ModelWeight;
import io.bioimage.modelrunner.bioimageio.tiling.ImageInfo;
import io.bioimage.modelrunner.bioimageio.tiling.TileCalculator;
import io.bioimage.modelrunner.bioimageio.tiling.TileInfo;
import io.bioimage.modelrunner.bioimageio.tiling.TileMaker;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.processing.Processing;
import io.bioimage.modelrunner.model.python.DLModelPytorchProtected;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.utils.CommonUtils;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;

/**
 * A {@link DLModelPytorchProtected} whose inference is driven by a Bioimage.io
 * {@link ModelDescriptor}.
 *
 * <p>The descriptor is used to:
 * <ul>
 *   <li>declare the output tensors (name, axes, data type);</li>
 *   <li>compute optimal tiles through a {@link TileCalculator};</li>
 *   <li>apply pre/post-processing through {@link Processing} around tiled inference.</li>
 * </ul>
 *
 * <p>Tiling is enabled by default in the constructor.
 */
public class BioimageIoModelPytorchProtected extends DLModelPytorchProtected {
    /** Bioimage.io description of the model (tensor specs, weights, processing). */
    protected ModelDescriptor descriptor;
    /** Computes optimal tile sizes for the inputs; built from {@link #descriptor}. */
    protected TileCalculator tileCalculator;

    /**
     * Creates the model wrapper, enables tiling and prepares a tile calculator for the descriptor.
     *
     * @param modelFile    forwarded to the superclass constructor
     * @param callable     forwarded to the superclass constructor
     * @param importModule forwarded to the superclass constructor
     * @param weightsPath  forwarded to the superclass constructor
     * @param kwargs       forwarded to the superclass constructor
     * @param descriptor   Bioimage.io descriptor used for tiling, output allocation and processing
     * @param custom       forwarded to the superclass constructor
     * @throws IOException as declared; presumably propagated from the superclass setup — confirm
     */
    protected BioimageIoModelPytorchProtected(String modelFile, String callable, String importModule, String weightsPath, Map<String, Object> kwargs, ModelDescriptor descriptor, boolean custom) throws IOException {
        super(modelFile, callable, importModule, weightsPath, kwargs, custom);
        this.tiling = true;
        this.descriptor = descriptor;
        this.tileCalculator = TileCalculator.init(descriptor);
    }

    /**
     * Same as the full constructor with {@code custom = false}.
     *
     * @throws IOException as declared by the full constructor
     */
    protected BioimageIoModelPytorchProtected(String modelFile, String callable, String importModule, String weightsPath, Map<String, Object> kwargs, ModelDescriptor descriptor) throws IOException {
        this(modelFile, callable, importModule, weightsPath, kwargs, descriptor, false);
    }

    /**
     * Runs the model on the given inputs and returns newly allocated output tensors.
     *
     * <p>Without tiling, the outputs are created as empty tensors (one per descriptor output) and
     * passed to {@code runNoTiles}. With tiling, they are allocated with the full output size
     * computed by the {@link TileMaker}, and inference runs with pre/post-processing.
     *
     * @param inputTensors input tensors, matched to the descriptor by name
     * @return the output tensors
     * @throws RunModelException if the model is not loaded, or if inference fails
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> run(List<Tensor<R>> inputTensors) throws RunModelException {
        if (!this.isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            List<Tensor<T>> outs = this.createOutputTensors();
            this.runNoTiles(inputTensors, outs);
            return outs;
        }
        TileMaker maker = this.buildTileMaker(inputTensors);
        List<Tensor<T>> outTensors = this.createOutputTensors(maker);
        return this.runBMZ(inputTensors, outTensors, maker);
    }

    /**
     * Computes the optimal tiling for the given inputs and builds the matching {@link TileMaker}.
     */
    private <R extends RealType<R> & NativeType<R>> TileMaker buildTileMaker(List<Tensor<R>> inputTensors) {
        List<ImageInfo> imageInfos = inputTensors.stream()
                .map(tt -> new ImageInfo(tt.getName(), tt.getAxesOrderString(), tt.getData().dimensionsAsLongArray()))
                .collect(Collectors.toList());
        List<TileInfo> inputTiles = this.tileCalculator.getOptimalTileSize(imageInfos);
        return TileMaker.build(this.descriptor, inputTiles);
    }

    /**
     * Allocates one blank tensor per descriptor output.
     *
     * <p>Each tensor is sized according to {@code maker}, with the data type declared in the
     * descriptor.
     */
    private <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputTensors(TileMaker maker) {
        List<Tensor<T>> outputTensors = new ArrayList<>();
        for (TensorSpec tt : this.descriptor.getOutputTensors()) {
            long[] dims = maker.getOutputImageSize(tt.getName());
            outputTensors.add(Tensor.buildBlankTensor(tt.getName(), tt.getAxesOrder(), dims, CommonUtils.getImgLib2DataType(tt.getDataType())));
        }
        return outputTensors;
    }

    /** Creates one empty (no data) tensor per descriptor output, with its declared axes. */
    private <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputTensors() {
        List<Tensor<T>> outputTensors = new ArrayList<>();
        for (TensorSpec tt : this.descriptor.getOutputTensors()) {
            outputTensors.add(Tensor.buildEmptyTensor(tt.getName(), tt.getAxesOrder()));
        }
        return outputTensors;
    }

    /**
     * Preprocesses the inputs, runs tiled inference into {@code outputTensors}, then postprocesses.
     *
     * <p>NOTE(review): the boolean flags passed to {@code preprocess}/{@code postprocess} are
     * defined by {@link Processing}. Presumably they select copy vs. in-place processing — confirm.
     *
     * @return the list returned by {@code postprocess}
     */
    private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> runBMZ(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors, TileMaker tiles) throws RunModelException {
        Processing processing = Processing.init(this.descriptor);
        inputTensors = processing.preprocess(inputTensors, false);
        this.runTiling(inputTensors, outputTensors, tiles);
        return processing.postprocess(outputTensors, true);
    }

    /**
     * Runs the model, writing results into caller-provided output tensors.
     *
     * <p>With tiling enabled, every non-empty output tensor must already have the size expected
     * by the tiling plan. Empty output tensors skip the size check.
     *
     * <p>NOTE(review): the list returned by postprocessing is discarded here. This is only correct
     * if postprocessing writes into {@code outputTensors} in place — confirm against
     * {@link Processing}.
     *
     * @param inputTensors  input tensors, matched to the descriptor by name
     * @param outputTensors output tensors to fill
     * @throws RunModelException        if the model is not loaded, or if inference fails
     * @throws IllegalArgumentException if an output tensor is not a model output, or has the wrong size
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void run(List<Tensor<T>> inputTensors, List<Tensor<R>> outputTensors) throws RunModelException {
        if (!this.isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            this.runNoTiles(inputTensors, outputTensors);
            return;
        }
        TileMaker maker = this.buildTileMaker(inputTensors);
        this.checkOutputSizes(outputTensors, maker);
        this.runBMZ(inputTensors, outputTensors, maker);
    }

    /**
     * Validates caller-provided output tensors against the sizes computed by {@code maker}.
     *
     * <p>Iterates over the output tensors themselves. The previous code looped over the number of
     * tiles, which is unrelated to the number of outputs.
     *
     * @throws IllegalArgumentException if a tensor name has no expected size, or if a non-empty
     *                                  tensor's dimensions differ from the expected ones
     */
    private <R extends RealType<R> & NativeType<R>> void checkOutputSizes(List<Tensor<R>> outputTensors, TileMaker maker) {
        for (Tensor<R> tensor : outputTensors) {
            long[] expectedSize = maker.getOutputImageSize(tensor.getName());
            if (expectedSize == null) {
                throw new IllegalArgumentException("Tensor '" + tensor.getName() + "' is missing in the outputs.");
            }
            // Empty tensors carry no data yet, so there is nothing to compare.
            if (tensor.isEmpty()) {
                continue;
            }
            long[] actualSize = tensor.getData().dimensionsAsLongArray();
            if (!Arrays.equals(expectedSize, actualSize)) {
                throw new IllegalArgumentException("Tensor '" + tensor.getName() + "' size is different than the expected size as defined by the rdf.yaml: " + Arrays.toString(actualSize) + " vs " + Arrays.toString(expectedSize) + ".");
            }
        }
    }

    /**
     * Lists the model's PyTorch dependencies that are not installed in {@code envPath}.
     *
     * <p>The Mamba root is taken as the grandparent directory of {@code envPath}. This presumably
     * assumes a {@code <root>/envs/<name>} layout — TODO confirm. If
     * {@code checkUninstalledDependenciesInEnv} throws {@link MambaInstallException}, all required
     * dependencies are reported as missing.
     *
     * @return the required dependencies that were not found in the environment
     */
    public List<String> findMissingDependencies() {
        Mamba mamba = new Mamba(new File(this.envPath).getParentFile().getParentFile().getAbsolutePath());
        List<String> reqDeps = ModelDependencies.getDependencies(this.descriptor, this.descriptor.getWeights().getModelWeights(ModelWeight.getPytorchID()));
        try {
            return mamba.checkUninstalledDependenciesInEnv(this.envPath, reqDeps);
        }
        catch (MambaInstallException e) {
            // Best effort: if the check itself fails, treat every dependency as missing.
            return reqDeps;
        }
    }

    /**
     * @return {@code true} if {@link #findMissingDependencies()} reports nothing missing
     */
    public boolean allDependenciesInstalled() {
        return this.findMissingDependencies().isEmpty();
    }
}

