package io.bioimage.modelrunner.model.java;

import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.description.weights.ModelWeight;
import io.bioimage.modelrunner.bioimageio.description.weights.WeightFormat;
import io.bioimage.modelrunner.bioimageio.tiling.ImageInfo;
import io.bioimage.modelrunner.bioimageio.tiling.TileCalculator;
import io.bioimage.modelrunner.bioimageio.tiling.TileInfo;
import io.bioimage.modelrunner.bioimageio.tiling.TileMaker;
import io.bioimage.modelrunner.engine.EngineInfo;
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.processing.Processing;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.utils.CommonUtils;
import io.bioimage.modelrunner.utils.Constants;
import io.bioimage.modelrunner.versionmanagement.InstalledEngines;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.MalformedURLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.util.Cast;

/* loaded from: input_file:io/bioimage/modelrunner/model/java/BioimageIoModelJava.class */
/**
 * {@link DLModelJava} specialisation for models stored in a Bioimage.io model folder, i.e. a
 * folder that contains the {@code rdf.yaml} descriptor ({@link Constants#RDF_FNAME}) next to the
 * weight files it references.
 *
 * <p>Instances are normally obtained through the static {@code createBioimageioModel...}
 * factories, which read the descriptor, pick an installed engine for one of the declared weight
 * formats and prepare a {@link TileCalculator} for tiled inference.
 */
public class BioimageIoModelJava extends DLModelJava {
    /** Always {@code false} for this class; exposed through {@link #isBioengine()}. */
    private boolean bioengine;
    /** Parsed {@code rdf.yaml} of the model; set by the factories or lazily by {@link #getBioimageioSpecs()}. */
    protected ModelDescriptor descriptor;
    /** Computes the tile sizes used when tiling is enabled; set by the factories. */
    protected TileCalculator tileCalculator;

    /**
     * Creates the model and enables tiling. The factories of this class set {@link #descriptor}
     * and {@link #tileCalculator} afterwards; they are not initialised here.
     *
     * @param engineInfo engine that will run the model
     * @param str model folder (as passed by the factories of this class)
     * @param str2 absolute path to the weights source file (as passed by the factories of this class)
     * @param classLoader class loader forwarded to the parent constructor, may be {@code null}
     * @throws LoadEngineException propagated from the parent constructor
     * @throws MalformedURLException propagated from the parent constructor
     * @throws IllegalStateException propagated from the parent constructor
     * @throws IOException propagated from the parent constructor
     */
    public BioimageIoModelJava(EngineInfo engineInfo, String str, String str2, ClassLoader classLoader) throws LoadEngineException, MalformedURLException, IllegalStateException, IOException {
        super(engineInfo, str, str2, classLoader);
        this.bioengine = false;
        this.tiling = true;
    }

    /**
     * Same as {@link #createBioimageioModel(String, String, ClassLoader)} using the directory
     * returned by {@link InstalledEngines#getEnginesDir()} as engines folder.
     *
     * @param str model folder containing the {@code rdf.yaml} file
     * @param classLoader class loader handed to the model, may be {@code null}
     * @return the created model
     * @throws LoadEngineException if the engine cannot be set up
     * @throws IOException if the descriptor is missing or no compatible engine is installed
     */
    public static BioimageIoModelJava createBioimageioModel(String str, ClassLoader classLoader) throws LoadEngineException, IOException {
        return createBioimageioModel(str, InstalledEngines.getEnginesDir(), classLoader);
    }

    /**
     * Same as {@link #createBioimageioModel(String, String)} using the directory returned by
     * {@link InstalledEngines#getEnginesDir()} as engines folder.
     *
     * @param str model folder containing the {@code rdf.yaml} file
     * @return the created model
     * @throws LoadEngineException if the engine cannot be set up
     * @throws IOException if the descriptor is missing or no compatible engine is installed
     */
    public static BioimageIoModelJava createBioimageioModel(String str) throws LoadEngineException, IOException {
        return createBioimageioModel(str, InstalledEngines.getEnginesDir());
    }

    /**
     * Same as {@link #createBioimageioModel(String, String, ClassLoader)} with a {@code null}
     * class loader.
     *
     * @param str model folder containing the {@code rdf.yaml} file
     * @param str2 folder where the engines are installed
     * @return the created model
     * @throws LoadEngineException if the engine cannot be set up
     * @throws IOException if the descriptor is missing or no compatible engine is installed
     */
    public static BioimageIoModelJava createBioimageioModel(String str, String str2) throws LoadEngineException, IOException {
        return createBioimageioModel(str, str2, null);
    }

    /**
     * Creates a model from a Bioimage.io model folder using the first declared weight format
     * whose files are present in the folder and for which
     * {@link EngineInfo#defineCompatibleDLEngineWithRdfYamlWeights} finds an engine.
     *
     * <p>Weights count as present when their source file exists in the model folder. TensorFlow
     * weights are additionally accepted when the folder contains {@code saved_model.pb} and a
     * {@code variables} directory.
     *
     * @param str model folder containing the {@code rdf.yaml} file, not {@code null}
     * @param str2 folder where the engines are installed, not {@code null}
     * @param classLoader class loader handed to the model, may be {@code null}
     * @return the created model, with descriptor and tile calculator initialised
     * @throws LoadEngineException if the engine cannot be set up
     * @throws IOException if the {@code rdf.yaml} file is missing or no compatible engine is found
     */
    public static BioimageIoModelJava createBioimageioModel(String str, String str2, ClassLoader classLoader) throws LoadEngineException, IOException {
        Objects.requireNonNull(str);
        Objects.requireNonNull(str2);
        ModelDescriptor modelDescriptor = readRdfDescriptor(str);
        String weightsPath = null;
        EngineInfo engineInfo = null;
        for (WeightFormat weights : modelDescriptor.getWeights().gettAllSupportedWeightObjects()) {
            String sourceFileName = weights.getSourceFileName();
            boolean sourcePresent = new File(str, sourceFileName).isFile();
            if (!sourcePresent && !hasTensorflowSavedModel(str, weights)) {
                continue;
            }
            engineInfo = EngineInfo.defineCompatibleDLEngineWithRdfYamlWeights(weights, str2);
            if (engineInfo != null) {
                weightsPath = new File(str, sourceFileName).getAbsolutePath();
                break;
            }
        }
        if (engineInfo == null) {
            throw new IOException("Please install a compatible engine with the model weights. To be compatible the engine has to be of the same framework and the major version needs to be the same. The model weights are: " + modelDescriptor.getWeights().getSupportedWeightNamesAndVersion());
        }
        return newBioimageIoModel(engineInfo, str, weightsPath, classLoader, modelDescriptor);
    }

    /**
     * Creates a model from a Bioimage.io model folder using the first declared weight format
     * whose source file exists in the folder and for which
     * {@link EngineInfo#defineExactDLEngineWithRdfYamlWeights} finds an engine.
     *
     * @param str model folder containing the {@code rdf.yaml} file, not {@code null}
     * @param str2 folder where the engines are installed, not {@code null}
     * @param classLoader class loader handed to the model, may be {@code null}
     * @return the created model, with descriptor and tile calculator initialised
     * @throws IOException if the {@code rdf.yaml} file is missing or no matching engine is found
     * @throws IllegalStateException propagated from the constructor
     * @throws LoadEngineException if the engine cannot be set up
     */
    public static BioimageIoModelJava createBioimageioModelWithExactWeigths(String str, String str2, ClassLoader classLoader) throws IOException, IllegalStateException, LoadEngineException {
        Objects.requireNonNull(str);
        Objects.requireNonNull(str2);
        ModelDescriptor modelDescriptor = readRdfDescriptor(str);
        String weightsPath = null;
        EngineInfo engineInfo = null;
        for (WeightFormat weights : modelDescriptor.getWeights().gettAllSupportedWeightObjects()) {
            String sourceFileName = weights.getSourceFileName();
            if (!new File(str, sourceFileName).isFile()) {
                continue;
            }
            engineInfo = EngineInfo.defineExactDLEngineWithRdfYamlWeights(weights, str2);
            if (engineInfo != null) {
                weightsPath = new File(str, sourceFileName).getAbsolutePath();
                break;
            }
        }
        if (engineInfo == null) {
            throw new IOException("Please install the engines defined by the model weights. The model weights are: " + modelDescriptor.getWeights().getSupportedWeightNamesAndVersion());
        }
        return newBioimageIoModel(engineInfo, str, weightsPath, classLoader, modelDescriptor);
    }

    /** Reads the {@code rdf.yaml} of a model folder, failing if the file is not there. */
    private static ModelDescriptor readRdfDescriptor(String modelDir) throws IOException {
        if (!new File(modelDir, Constants.RDF_FNAME).isFile()) {
            throw new IOException("A Bioimage.io model folder should contain its corresponding rdf.yaml file.");
        }
        return ModelDescriptorFactory.readFromLocalFile(modelDir + File.separator + Constants.RDF_FNAME);
    }

    /** Tells whether the weights are TensorFlow weights and the folder holds {@code saved_model.pb} plus a {@code variables} directory. */
    private static boolean hasTensorflowSavedModel(String modelDir, WeightFormat weights) {
        return weights.getFramework().equals(ModelWeight.getTensorflowID())
                && new File(modelDir, "saved_model.pb").isFile()
                && new File(modelDir, "variables").isDirectory();
    }

    /** Instantiates the model and attaches the descriptor and its tile calculator. */
    private static BioimageIoModelJava newBioimageIoModel(EngineInfo engineInfo, String modelDir, String weightsPath, ClassLoader classLoader, ModelDescriptor modelDescriptor) throws LoadEngineException, IOException {
        BioimageIoModelJava model = new BioimageIoModelJava(engineInfo, modelDir, weightsPath, classLoader);
        model.descriptor = modelDescriptor;
        model.tileCalculator = TileCalculator.init(modelDescriptor);
        return model;
    }

    /**
     * Runs the model on the given inputs and returns newly created output tensors.
     *
     * <p>Without tiling, empty output tensors are created and handed to {@code runNoTiles}.
     * With tiling, blank output tensors of the size computed by the {@link TileMaker} are
     * created and the model is run through pre-processing, tiled inference and post-processing.
     *
     * @param list input tensors
     * @return the output tensors
     * @throws RunModelException if the model is not loaded or inference fails
     */
    @Override // io.bioimage.modelrunner.model.java.DLModelJava, io.bioimage.modelrunner.model.BaseModel
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> run(List<Tensor<R>> list) throws RunModelException {
        if (!isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            List<Tensor<T>> outputs = createOutputTensors();
            runNoTiles(list, outputs);
            return outputs;
        }
        TileMaker tileMaker = tileMakerFor(list);
        List<Tensor<T>> outputs = createOutputTensors(tileMaker);
        return runBMZ(list, outputs, tileMaker);
    }

    /**
     * Builds the {@link TileMaker} for the given inputs from the tile sizes proposed by
     * {@link #tileCalculator}. {@code RunModelException} is declared only so that, should the
     * tiling helpers raise it, it propagates unchanged to the public {@code run} methods.
     */
    private <R extends RealType<R> & NativeType<R>> TileMaker tileMakerFor(List<Tensor<R>> inputs) throws RunModelException {
        List<ImageInfo> imageInfos = inputs.stream()
                .map(tensor -> new ImageInfo(tensor.getName(), tensor.getAxesOrderString(), tensor.getData().dimensionsAsLongArray()))
                .collect(Collectors.toList());
        return TileMaker.build(this.descriptor, this.tileCalculator.getOptimalTileSize(imageInfos));
    }

    /** Creates one blank tensor per output declared in the descriptor, sized as computed by the tile maker. */
    private <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputTensors(TileMaker tileMaker) {
        List<Tensor<T>> outputs = new ArrayList<>();
        for (TensorSpec spec : this.descriptor.getOutputTensors()) {
            long[] outputSize = tileMaker.getOutputImageSize(spec.getName());
            // The element type is only known at runtime (taken from the descriptor), hence the unchecked cast.
            outputs.add(Cast.unchecked(Tensor.buildBlankTensor(spec.getName(), spec.getAxesOrder(), outputSize, CommonUtils.getImgLib2DataType(spec.getDataType()))));
        }
        return outputs;
    }

    /** Creates one empty tensor (name and axes only) per output declared in the descriptor. */
    private <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputTensors() {
        List<Tensor<T>> outputs = new ArrayList<>();
        for (TensorSpec spec : this.descriptor.getOutputTensors()) {
            outputs.add(Cast.unchecked(Tensor.buildEmptyTensor(spec.getName(), spec.getAxesOrder())));
        }
        return outputs;
    }

    /**
     * Pre-processes the inputs, runs tiled inference into {@code outputs} and returns the
     * post-processed outputs.
     * NOTE(review): the meaning of the boolean flags passed to {@code preprocess}/{@code postprocess}
     * is defined by {@link Processing} and is not visible here.
     */
    private <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> runBMZ(List<Tensor<R>> inputs, List<Tensor<T>> outputs, TileMaker tileMaker) throws RunModelException {
        Processing processing = Processing.init(this.descriptor);
        runTiling(processing.preprocess(inputs, false), outputs, tileMaker);
        return processing.postprocess(outputs, true);
    }

    /**
     * Runs the model on the given inputs, writing the results into the provided output tensors.
     *
     * <p>With tiling enabled every provided output tensor is validated first: its name must be
     * known to the {@link TileMaker} and, unless the tensor is empty, its dimensions must match
     * the expected output size.
     *
     * @param list input tensors
     * @param list2 output tensors to fill
     * @throws RunModelException if the model is not loaded or inference fails
     * @throws IllegalArgumentException if an output tensor is unknown or has an unexpected size
     */
    @Override // io.bioimage.modelrunner.model.java.DLModelJava, io.bioimage.modelrunner.model.BaseModel
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void run(List<Tensor<T>> list, List<Tensor<R>> list2) throws RunModelException {
        if (!isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            runNoTiles(list, list2);
            return;
        }
        TileMaker tileMaker = tileMakerFor(list);
        // Validate every output tensor (not one per tile) and reject sizes that do NOT match.
        for (Tensor<R> output : list2) {
            long[] expectedSize = tileMaker.getOutputImageSize(output.getName());
            if (expectedSize == null) {
                throw new IllegalArgumentException("Tensor '" + output.getName() + "' is missing in the outputs.");
            }
            if (!output.isEmpty() && !Arrays.equals(expectedSize, output.getData().dimensionsAsLongArray())) {
                throw new IllegalArgumentException("Tensor '" + output.getName() + "' size is different than the expected size as defined by the rdf.yaml: " + Arrays.toString(output.getData().dimensionsAsLongArray()) + " vs " + Arrays.toString(expectedSize) + ".");
            }
        }
        runBMZ(list, list2, tileMaker);
    }

    /**
     * Developer smoke test: loads a model and runs it once on a blank 1x1x512x512 float image
     * named {@code input0} with axes {@code bcyx}.
     *
     * @param strArr optional; the first element, when present, is used as model folder instead
     *               of the built-in default path
     * @throws IOException if the model folder or engine cannot be resolved
     * @throws LoadEngineException if the engine cannot be set up
     * @throws RunModelException if inference fails
     * @throws LoadModelException if the model cannot be loaded
     */
    public static <T extends NativeType<T> & RealType<T>> void main(String[] strArr) throws IOException, LoadEngineException, RunModelException, LoadModelException {
        String modelDir = (strArr != null && strArr.length > 0)
                ? strArr[0]
                : "/home/carlos/git/JDLL/models/NucleiSegmentationBoundaryModel_17122023_143125";
        Img<T> img = Cast.unchecked(ArrayImgs.floats(new long[]{1, 1, 512, 512}));
        List<Tensor<T>> inputs = new ArrayList<>();
        inputs.add(Tensor.build("input0", "bcyx", img));
        BioimageIoModelJava model = createBioimageioModel(modelDir);
        model.loadModel();
        model.run(inputs);
        System.out.println(false);
    }

    /**
     * @return the value of the {@code bioengine} flag, which this class always sets to {@code false}
     */
    public boolean isBioengine() {
        return this.bioengine;
    }

    /**
     * Returns the model descriptor, reading it from the {@code rdf.yaml} in the model folder on
     * first use if it has not been set yet.
     *
     * @return the descriptor, or {@code null} if none is set and the model folder has no {@code rdf.yaml}
     * @throws FileNotFoundException propagated from reading the descriptor
     * @throws IOException propagated from reading the descriptor
     */
    public ModelDescriptor getBioimageioSpecs() throws FileNotFoundException, IOException {
        if (this.descriptor == null) {
            String rdfPath = this.modelFolder + File.separator + Constants.RDF_FNAME;
            if (new File(rdfPath).isFile()) {
                this.descriptor = ModelDescriptorFactory.readFromLocalFile(rdfPath);
            }
        }
        return this.descriptor;
    }
}
