package io.bioimage.modelrunner.model.java;

import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.tiling.TileInfo;
import io.bioimage.modelrunner.bioimageio.tiling.TileMaker;
import io.bioimage.modelrunner.engine.DeepLearningEngineInterface;
import io.bioimage.modelrunner.engine.EngineInfo;
import io.bioimage.modelrunner.engine.EngineLoader;
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.BaseModel;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.utils.Constants;
import java.io.IOException;
import java.net.MalformedURLException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

/**
 * A deep-learning model executed through an engine obtained from an {@link EngineLoader}.
 * <p>
 * Every call into the engine instance is bracketed by
 * {@link EngineLoader#setEngineClassLoader()} and {@link EngineLoader#setBaseClassLoader()}.
 * The base class loader is restored in a {@code finally} block, so it is restored even when
 * the engine call throws.
 * <p>
 * The model can optionally be run tile by tile: configure the tiles with
 * {@link #setTileInfo(List, List)}. That also enables tiling.
 */
public class DLModelJava extends BaseModel {
    /** Loader that provides the engine instance; set to {@code null} by {@link #close()}. */
    protected EngineLoader engineClassLoader;
    /** Engine this model runs on. */
    protected EngineInfo engineInfo;
    /** Folder passed to the engine when loading the model. */
    protected String modelFolder;
    /** Model source passed to the engine; may be {@code null} for TensorFlow frameworks. */
    protected String modelSource;
    /** Model name; not assigned anywhere in this class. */
    protected String modelName;
    /** Tiling specification for the inputs; required when {@link #tiling} is {@code true}. */
    protected List<TileInfo> inputTiles;
    /** Tiling specification for the outputs; required when {@link #tiling} is {@code true}. */
    protected List<TileInfo> outputTiles;
    /** Optional progress holder set via {@link #setTilingCounter}; not updated by this class. */
    protected TilingConsumer tileCounter;
    /** Set to {@code true} by {@link #loadModel()} and back to {@code false} by {@link #close()}. */
    protected boolean loaded = false;
    /** Whether {@code run} should split the tensors into tiles. */
    protected boolean tiling = false;

    /**
     * Simple mutable holder for tiling progress: the total number of tiles and the number of
     * tiles processed so far.
     */
    public static class TilingConsumer {
        private Long totalTiles;
        private Long tilesProcessed;

        /** Stores the total number of tiles. */
        public void acceptTotal(Long l) {
            this.totalTiles = l;
        }

        /** Stores the number of tiles processed so far. */
        public void acceptProgress(Long l) {
            this.tilesProcessed = l;
        }

        /** @return the last value passed to {@link #acceptTotal(Long)}, or {@code null} */
        public Long getTotalTiles() {
            return this.totalTiles;
        }

        /** @return the last value passed to {@link #acceptProgress(Long)}, or {@code null} */
        public Long getTilesProcessed() {
            return this.tilesProcessed;
        }
    }

    /**
     * Creates the model wrapper and its engine loader. The model itself is not loaded until
     * {@link #loadModel()} is called.
     *
     * @param engineInfo  engine to run the model with
     * @param modelFolder folder containing the model files
     * @param modelSource model source; required unless the framework is TensorFlow
     * @param classLoader parent class loader, or {@code null} to use the current thread's
     *                    context class loader
     * @throws NullPointerException if {@code modelSource} is {@code null} for a non-TensorFlow
     *                              framework
     */
    public DLModelJava(EngineInfo engineInfo, String modelFolder, String modelSource, ClassLoader classLoader)
            throws LoadEngineException, MalformedURLException, IllegalStateException, IOException {
        requireSourceUnlessTensorflow(engineInfo, modelSource);
        this.engineInfo = engineInfo;
        this.modelFolder = modelFolder;
        this.modelSource = modelSource;
        setEngineClassLoader(classLoader);
    }

    /**
     * Requires a non-null model source unless the engine framework is TensorFlow
     * (plain or bioimage.io flavour).
     */
    private static void requireSourceUnlessTensorflow(EngineInfo engineInfo, String modelSource) {
        String framework = engineInfo.getFramework();
        boolean isTensorflow = framework.equals(EngineInfo.getTensorflowKey())
                || framework.equals(EngineInfo.getBioimageioTfKey());
        if (!isTensorflow) {
            Objects.requireNonNull(modelSource,
                    "A model source is required for framework '" + framework + "'.");
        }
    }

    /**
     * Loads the model into the engine instance. The model is marked as loaded only if the
     * engine call returns normally.
     *
     * @throws LoadModelException if the engine fails to load the model
     */
    @Override
    public void loadModel() throws LoadModelException {
        DeepLearningEngineInterface engineInstance = this.engineClassLoader.getEngineInstance();
        this.engineClassLoader.setEngineClassLoader();
        try {
            engineInstance.loadModel(this.modelFolder, this.modelSource);
        } finally {
            // Always switch back, even if loading fails.
            this.engineClassLoader.setBaseClassLoader();
        }
        this.loaded = true;
    }

    /**
     * Closes the model in the engine, then closes the engine loader. Does nothing if the loader
     * is already {@code null}. State flags are updated only when both close calls succeed.
     */
    @Override
    public void close() {
        EngineLoader loader = getEngineClassLoader();
        if (loader == null) {
            return;
        }
        DeepLearningEngineInterface engineInstance = loader.getEngineInstance();
        loader.setEngineClassLoader();
        try {
            engineInstance.closeModel();
            loader.close();
        } finally {
            loader.setBaseClassLoader();
        }
        this.engineClassLoader = null;
        this.loaded = false;
        this.closed = true;
    }

    /**
     * Runs the model, writing results into the provided output tensors. When tiling is enabled,
     * every output tensor must be known to the tiling specification. Non-empty output tensors
     * must also already have the expected full output size.
     *
     * @param inTensors  input tensors
     * @param outTensors output tensors to fill
     * @throws RunModelException             if the model is not loaded or the engine fails
     * @throws UnsupportedOperationException if tiling is enabled without input/output tile info
     * @throws IllegalArgumentException      if an output tensor is unknown to the tiling
     *                                       specification or has the wrong size
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void run(
            List<Tensor<T>> inTensors, List<Tensor<R>> outTensors) throws RunModelException {
        if (!isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            runNoTiles(inTensors, outTensors);
            return;
        }
        TileMaker tileMaker = buildTileMaker();
        // Validate each provided output tensor against the size expected by the tile maker.
        for (Tensor<R> tensor : outTensors) {
            long[] expectedSize = tileMaker.getOutputImageSize(tensor.getName());
            if (expectedSize == null) {
                throw new IllegalArgumentException("Tensor '" + tensor.getName() + "' is missing in the outputs.");
            }
            long[] actualSize = tensor.isEmpty() ? null : tensor.getData().dimensionsAsLongArray();
            if (actualSize != null && !Arrays.equals(expectedSize, actualSize)) {
                throw new IllegalArgumentException("Tensor '" + tensor.getName()
                        + "' size is different than the expected size defined for the output image: "
                        + Arrays.toString(actualSize) + " vs " + Arrays.toString(expectedSize) + ".");
            }
        }
        runTiling(inTensors, outTensors, tileMaker);
    }

    /**
     * Runs the model with tiling and returns newly allocated output tensors. The outputs are
     * built from the output tile specification.
     *
     * @param inTensors input tensors
     * @return one blank-then-filled tensor per output {@link TileInfo}
     * @throws RunModelException             if the model is not loaded or the engine fails
     * @throws UnsupportedOperationException if tiling is disabled or the tile info is missing
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> run(
            List<Tensor<R>> inTensors) throws RunModelException {
        if (!isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            throw new UnsupportedOperationException("Cannot run a DLModel if no information about the outputs is provided. Either try with 'run( List< Tensor < T > > inTensors, List< Tensor < R > > outTensors )' or set the tiling information with 'setTileInfo(List<TileInfo> inputTiles, List<TileInfo> outputTiles)'. Another option is to run simple inference over an ImgLib2 RandomAccessibleInterval with 'inference(List<RandomAccessibleInterval<T>> input)'");
        }
        TileMaker tileMaker = buildTileMaker();
        List<Tensor<T>> outputs = createOutputTensors();
        runTiling(inTensors, outputs, tileMaker);
        return outputs;
    }

    /**
     * Validates the configured tile info and builds a {@link TileMaker} from it.
     *
     * @throws UnsupportedOperationException if input or output tile info is null or empty
     */
    private TileMaker buildTileMaker() {
        if (this.inputTiles == null || this.inputTiles.isEmpty()) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the input tiles are not well defined");
        }
        if (this.outputTiles == null || this.outputTiles.isEmpty()) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the output tiles are not well defined");
        }
        return TileMaker.build(this.inputTiles, this.outputTiles);
    }

    /**
     * Allocates one blank tensor per output {@link TileInfo}, using the tile's name, axes order
     * and image dimensions.
     * <p>
     * NOTE(review): the tensors are always {@link FloatType}; the {@code T} type parameter is
     * only an unchecked cast.
     */
    private <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputTensors() {
        List<Tensor<T>> outputs = new ArrayList<>(this.outputTiles.size());
        for (TileInfo tileInfo : this.outputTiles) {
            outputs.add(Cast.unchecked(Tensor.buildBlankTensor(tileInfo.getName(),
                    tileInfo.getImageAxesOrder(), tileInfo.getImageDims(), new FloatType())));
        }
        return outputs;
    }

    /**
     * Runs the engine once on the full tensors. Inputs that are not already {@link FloatType}
     * are copied into new {@link FloatType} tensors first.
     *
     * @param inTensors  input tensors
     * @param outTensors output tensors passed to the engine
     * @throws RunModelException if the engine fails
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void runNoTiles(
            List<Tensor<T>> inTensors, List<Tensor<R>> outTensors) throws RunModelException {
        DeepLearningEngineInterface engineInstance = this.engineClassLoader.getEngineInstance();
        this.engineClassLoader.setEngineClassLoader();
        try {
            List<Tensor<FloatType>> floatInputs = new ArrayList<>(inTensors.size());
            for (Tensor<T> tensor : inTensors) {
                if (Util.getTypeFromInterval(tensor.getData()) instanceof FloatType) {
                    floatInputs.add(Cast.unchecked(tensor));
                } else {
                    floatInputs.add(Cast.unchecked(
                            Tensor.createCopyOfTensorInWantedDataType(tensor, new FloatType())));
                }
            }
            engineInstance.run(floatInputs, outTensors);
        } finally {
            this.engineClassLoader.setBaseClassLoader();
        }
    }

    /**
     * Runs the engine once per tile. For each tile index, the n-th input and output tiles are
     * taken from {@code tileMaker} and passed to {@link #runNoTiles(List, List)}.
     *
     * @param inTensors  full-size input tensors
     * @param outTensors full-size output tensors
     * @param tileMaker  tiling helper built from the input/output tile info
     * @throws RunModelException if the engine fails on any tile
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void runTiling(
            List<Tensor<R>> inTensors, List<Tensor<T>> outTensors, TileMaker tileMaker) throws RunModelException {
        int numberOfTiles = tileMaker.getNumberOfTiles();
        for (int i = 0; i < numberOfTiles; i++) {
            final int tileIndex = i;
            List<Tensor<R>> inputTileTensors = Cast.unchecked(inTensors.stream()
                    .map(tensor -> tileMaker.getNthTileInput(tensor, tileIndex))
                    .collect(Collectors.toList()));
            List<Tensor<T>> outputTileTensors = Cast.unchecked(outTensors.stream()
                    .map(tensor -> tileMaker.getNthTileOutput(tensor, tileIndex))
                    .collect(Collectors.toList()));
            runNoTiles(inputTileTensors, outputTileTensors);
        }
    }

    /**
     * Runs the engine's inference directly on ImgLib2 images.
     *
     * @param inputs input images
     * @return the images returned by the engine
     * @throws RunModelException if the model is not loaded or the engine fails
     */
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<RandomAccessibleInterval<R>> inference(
            List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
        if (!isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        DeepLearningEngineInterface engineInstance = this.engineClassLoader.getEngineInstance();
        this.engineClassLoader.setEngineClassLoader();
        try {
            return engineInstance.inference(inputs);
        } finally {
            this.engineClassLoader.setBaseClassLoader();
        }
    }

    /** @return whether {@code run} splits the tensors into tiles */
    public boolean isTiling() {
        return this.tiling;
    }

    /**
     * Enables or disables tiling. Enabling it without calling {@link #setTileInfo} makes
     * {@code run} fail with {@link UnsupportedOperationException}.
     */
    public void setTiling(boolean z) {
        this.tiling = z;
    }

    /** Sets the input/output tile specifications and enables tiling. Lists are stored as-is. */
    public void setTileInfo(List<TileInfo> list, List<TileInfo> list2) {
        this.inputTiles = list;
        this.outputTiles = list2;
        this.tiling = true;
    }

    /** Stores a tiling progress holder; this class does not update it itself. */
    public void setTilingCounter(TilingConsumer tilingConsumer) {
        this.tileCounter = tilingConsumer;
    }

    /** @return the engine this model runs on */
    public EngineInfo getEngineInfo() {
        return this.engineInfo;
    }

    @Override
    public boolean isLoaded() {
        return this.loaded;
    }

    /** @return the engine loader, or {@code null} after {@link #close()} */
    public EngineLoader getEngineClassLoader() {
        return this.engineClassLoader;
    }

    @Override
    public String getModelFolder() {
        return this.modelFolder;
    }

    /** @return the model source passed to the engine (may be {@code null}) */
    public String getModelSource() {
        return this.modelSource;
    }

    /** @return the model name (not assigned by this class) */
    public String getModelName() {
        return this.modelName;
    }

    /**
     * Same as {@link #createModel(String, String, EngineInfo, ClassLoader)} with a {@code null}
     * class loader. In that case the thread's context class loader is used.
     */
    public static DLModelJava createModel(String modelFolder, String modelSource, EngineInfo engineInfo)
            throws LoadEngineException, MalformedURLException, IllegalStateException, IOException {
        return createModel(modelFolder, modelSource, engineInfo, null);
    }

    /**
     * Creates a model wrapper for {@code modelFolder}.
     * <p>
     * If the folder contains a {@link Constants#RDF_FNAME} file, a {@link BioimageIoModelJava}
     * with its descriptor read from that file is returned. If that attempt fails with an
     * {@link IOException}, a plain {@link DLModelJava} is returned instead.
     *
     * @param modelFolder folder containing the model files
     * @param modelSource model source; required unless the framework is TensorFlow
     * @param engineInfo  engine to run the model with
     * @param classLoader parent class loader, or {@code null} for the thread context class loader
     * @throws NullPointerException if {@code modelFolder} or {@code engineInfo} is {@code null},
     *                              or {@code modelSource} is {@code null} for a non-TensorFlow
     *                              framework
     */
    public static DLModelJava createModel(String modelFolder, String modelSource, EngineInfo engineInfo, ClassLoader classLoader)
            throws LoadEngineException, MalformedURLException, IllegalStateException, IOException {
        Objects.requireNonNull(modelFolder, "modelFolder");
        Objects.requireNonNull(engineInfo, "engineInfo");
        requireSourceUnlessTensorflow(engineInfo, modelSource);
        Path rdfPath = Paths.get(modelFolder, Constants.RDF_FNAME);
        if (rdfPath.toFile().isFile()) {
            try {
                BioimageIoModelJava bioimageIoModel = new BioimageIoModelJava(engineInfo, modelFolder, modelSource, classLoader);
                bioimageIoModel.descriptor = ModelDescriptorFactory.readFromLocalFile(rdfPath.toAbsolutePath().toString());
                return bioimageIoModel;
            } catch (IOException ignored) {
                // Deliberate fallback: an unreadable rdf file downgrades to a plain DLModelJava.
                // NOTE(review): if construction succeeded but reading the descriptor failed, the
                // discarded instance's engine loader is never closed.
            }
        }
        return new DLModelJava(engineInfo, modelFolder, modelSource, classLoader);
    }

    /**
     * Creates the engine loader for {@link #engineInfo}.
     *
     * @param classLoader parent class loader, or {@code null} for the thread context class loader
     */
    protected void setEngineClassLoader(ClassLoader classLoader)
            throws LoadEngineException, MalformedURLException, IllegalStateException, IOException {
        ClassLoader parent = classLoader == null ? Thread.currentThread().getContextClassLoader() : classLoader;
        this.engineClassLoader = EngineLoader.createEngine(parent, this.engineInfo);
    }

    /** @return a new, empty {@link TilingConsumer} */
    public static TilingConsumer createTilingConsumer() {
        return new TilingConsumer();
    }
}
