/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.model.java;

import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.tiling.TileInfo;
import io.bioimage.modelrunner.bioimageio.tiling.TileMaker;
import io.bioimage.modelrunner.engine.DeepLearningEngineInterface;
import io.bioimage.modelrunner.engine.EngineInfo;
import io.bioimage.modelrunner.engine.EngineLoader;
import io.bioimage.modelrunner.exceptions.LoadEngineException;
import io.bioimage.modelrunner.exceptions.LoadModelException;
import io.bioimage.modelrunner.exceptions.RunModelException;
import io.bioimage.modelrunner.model.BaseModel;
import io.bioimage.modelrunner.model.java.BioimageIoModelJava;
import io.bioimage.modelrunner.tensor.Tensor;
import java.io.IOException;
import java.net.MalformedURLException;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;

/**
 * Model wrapper that runs inference through a Java deep-learning engine.
 * <p>
 * The engine instance is obtained from an {@link EngineLoader} created for the
 * given {@link EngineInfo}. Every call into the engine is bracketed by switching
 * the current thread to the engine class loader and back to the base class
 * loader. The switch back happens in a {@code finally} block, so an engine
 * failure does not leave the caller's thread on the engine class loader.
 * <p>
 * Inference can run on whole tensors or tile by tile. Tile-by-tile execution
 * requires tiling information supplied with {@link #setTileInfo(List, List)}.
 */
public class DLModelJava
extends BaseModel {
    /** Name of the bioimage.io description file that {@code createModel} looks for in the model folder. */
    private static final String RDF_FILE_NAME = "rdf.yaml";

    /** {@code true} once {@link #loadModel()} succeeded; reset by {@link #close()}. */
    protected boolean loaded = false;
    /** Provides the engine instance and the class-loader switching; {@code null} after {@link #close()}. */
    protected EngineLoader engineClassLoader;
    protected EngineInfo engineInfo;
    protected String modelFolder;
    protected String modelSource;
    /** NOTE(review): never assigned in this class; presumably set by subclasses. */
    protected String modelName;
    protected List<TileInfo> inputTiles;
    protected List<TileInfo> outputTiles;
    protected boolean tiling = false;
    /** Optional progress listener updated by {@link #runTiling}; may be {@code null}. */
    protected TilingConsumer tileCounter;

    /**
     * Creates the model and its engine loader. The model itself is not loaded;
     * call {@link #loadModel()} for that.
     *
     * @param engineInfo  engine description; must not be {@code null}
     * @param modelFolder folder containing the model files
     * @param modelSource source handed to the engine's {@code loadModel}. It may be
     *                    {@code null} only for the Tensorflow and bioimage.io
     *                    Tensorflow frameworks.
     * @param classLoader class loader handed to {@link EngineLoader#createEngine}, or
     *                    {@code null} to use the current thread's context class loader
     * @throws LoadEngineException   propagated from the engine loader creation
     * @throws MalformedURLException propagated from the engine loader creation
     * @throws IllegalStateException propagated from the engine loader creation
     * @throws IOException           propagated from the engine loader creation
     */
    protected DLModelJava(EngineInfo engineInfo, String modelFolder, String modelSource, ClassLoader classLoader) throws LoadEngineException, MalformedURLException, IllegalStateException, IOException {
        if (requiresModelSource(engineInfo)) {
            Objects.requireNonNull(modelSource, "modelSource");
        }
        this.engineInfo = engineInfo;
        this.modelFolder = modelFolder;
        this.modelSource = modelSource;
        this.setEngineClassLoader(classLoader);
    }

    /**
     * Returns whether the framework of {@code engineInfo} needs a non-null model
     * source. Every framework except Tensorflow and bioimage.io Tensorflow needs one.
     */
    private static boolean requiresModelSource(EngineInfo engineInfo) {
        String framework = engineInfo.getFramework();
        return !framework.equals(EngineInfo.getTensorflowKey())
                && !framework.equals(EngineInfo.getBioimageioTfKey());
    }

    /**
     * Loads the model from {@link #modelFolder} / {@link #modelSource} into the engine.
     * The model is marked as loaded only if the engine call returns normally.
     *
     * @throws LoadModelException if the engine fails to load the model
     */
    @Override
    public void loadModel() throws LoadModelException {
        DeepLearningEngineInterface engineInstance = this.engineClassLoader.getEngineInstance();
        this.engineClassLoader.setEngineClassLoader();
        try {
            engineInstance.loadModel(this.modelFolder, this.modelSource);
        } finally {
            // Always hand the thread back to the base class loader, even on failure.
            this.engineClassLoader.setBaseClassLoader();
        }
        this.loaded = true;
    }

    /**
     * Closes the model in the engine and closes the engine loader.
     * Does nothing if the model has already been closed.
     */
    @Override
    public void close() {
        EngineLoader loader = this.getEngineClassLoader();
        if (loader == null) {
            return;
        }
        DeepLearningEngineInterface engineInstance = loader.getEngineInstance();
        loader.setEngineClassLoader();
        try {
            engineInstance.closeModel();
            loader.close();
        } finally {
            loader.setBaseClassLoader();
        }
        this.engineClassLoader = null;
        this.loaded = false;
        this.closed = true;
    }

    /**
     * Runs the model, writing results into the supplied output tensors.
     * <p>
     * Without tiling, the tensors are passed to the engine as a whole. With tiling
     * enabled, each output tensor must be named in the output tile information.
     * Each non-empty output tensor must also have exactly the expected image size.
     *
     * @param inTensors  input tensors
     * @param outTensors output tensors to be filled
     * @throws RunModelException             if the model is not loaded or the engine fails
     * @throws UnsupportedOperationException if tiling is enabled but the tile information is missing or empty
     * @throws IllegalArgumentException      if an output tensor is unknown to the tiling
     *                                       information or has the wrong size
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void run(List<Tensor<T>> inTensors, List<Tensor<R>> outTensors) throws RunModelException {
        if (!this.isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            this.runNoTiles(inTensors, outTensors);
            return;
        }
        this.checkTileInfoDefined();
        TileMaker tiles = TileMaker.build(this.inputTiles, this.outputTiles);
        // Validate every provided output tensor (not "one per tile").
        for (Tensor<R> tt : outTensors) {
            long[] expectedSize = tiles.getOutputImageSize(tt.getName());
            if (expectedSize == null) {
                throw new IllegalArgumentException("Tensor '" + tt.getName() + "' is missing in the outputs.");
            }
            if (tt.isEmpty()) {
                continue;
            }
            long[] actualSize = tt.getData().dimensionsAsLongArray();
            if (!Arrays.equals(expectedSize, actualSize)) {
                throw new IllegalArgumentException("Tensor '" + tt.getName() + "' size is different than the expected size defined for the output image: " + Arrays.toString(actualSize) + " vs " + Arrays.toString(expectedSize) + ".");
            }
        }
        this.runTiling(inTensors, outTensors, tiles);
    }

    /**
     * Runs the model tile by tile on output tensors allocated from the output tile
     * information. This overload requires tiling information to be set.
     *
     * @param inputTensors input tensors
     * @return newly allocated output tensors holding the results
     * @throws RunModelException             if the model is not loaded or the engine fails
     * @throws UnsupportedOperationException if tiling is disabled, or the tile information is missing or empty
     */
    @Override
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<Tensor<T>> run(List<Tensor<R>> inputTensors) throws RunModelException {
        if (!this.isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        if (!this.tiling) {
            throw new UnsupportedOperationException("Cannot run a DLModel if no information about the outputs is provided. Either try with 'run( List< Tensor < T > > inTensors, List< Tensor < R > > outTensors )' or set the tiling information with 'setTileInfo(List<TileInfo> inputTiles, List<TileInfo> outputTiles)'. Another option is to run simple inference over an ImgLib2 RandomAccessibleInterval with 'inference(List<RandomAccessibleInterval<T>> input)'");
        }
        this.checkTileInfoDefined();
        TileMaker maker = TileMaker.build(this.inputTiles, this.outputTiles);
        List<Tensor<T>> outTensors = this.createOutputTensors();
        this.runTiling(inputTensors, outTensors, maker);
        return outTensors;
    }

    /** Throws if either the input or the output tile information is missing or empty. */
    private void checkTileInfoDefined() {
        if (this.inputTiles == null || this.inputTiles.isEmpty()) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the input tiles are not well defined");
        }
        if (this.outputTiles == null || this.outputTiles.isEmpty()) {
            throw new UnsupportedOperationException("Tiling is set to 'true' but the output tiles are not well defined");
        }
    }

    /** Allocates one blank {@link FloatType} tensor per output tile description. */
    private <T extends RealType<T> & NativeType<T>> List<Tensor<T>> createOutputTensors() {
        List<Tensor<T>> outputTensors = new ArrayList<>(this.outputTiles.size());
        for (TileInfo tt : this.outputTiles) {
            outputTensors.add(Cast.unchecked(Tensor.buildBlankTensor(tt.getName(), tt.getImageAxesOrder(), tt.getImageDims(), new FloatType())));
        }
        return outputTensors;
    }

    /**
     * Runs the engine once on whole tensors. Inputs that are not already
     * {@link FloatType} are first copied into {@link FloatType} tensors.
     *
     * @throws RunModelException if the engine fails
     */
    protected <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void runNoTiles(List<Tensor<T>> inTensors, List<Tensor<R>> outTensors) throws RunModelException {
        DeepLearningEngineInterface engineInstance = this.engineClassLoader.getEngineInstance();
        this.engineClassLoader.setEngineClassLoader();
        try {
            List<Tensor<FloatType>> inTensorsFloat = new ArrayList<>(inTensors.size());
            for (Tensor<T> tt : inTensors) {
                if (Util.getTypeFromInterval(tt.getData()) instanceof FloatType) {
                    inTensorsFloat.add(Cast.unchecked(tt));
                } else {
                    inTensorsFloat.add(Tensor.createCopyOfTensorInWantedDataType(tt, new FloatType()));
                }
            }
            engineInstance.run(inTensorsFloat, outTensors);
        } finally {
            this.engineClassLoader.setBaseClassLoader();
        }
    }

    /**
     * Runs the engine once per tile. For each tile, it extracts the nth input and
     * output tile from {@code tiles} and passes them to
     * {@link #runNoTiles(List, List)}. If a {@link TilingConsumer} is set, it
     * receives the total number of tiles first, then the count of processed tiles
     * after each one.
     *
     * @throws RunModelException if the engine fails on any tile
     */
    protected <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> void runTiling(List<Tensor<R>> inputTensors, List<Tensor<T>> outputTensors, TileMaker tiles) throws RunModelException {
        long nTiles = tiles.getNumberOfTiles();
        if (this.tileCounter != null) {
            this.tileCounter.acceptTotal(nTiles);
        }
        for (int i = 0; i < nTiles; ++i) {
            final int nTile = i;
            List<Tensor<R>> inputTiles = inputTensors.stream().map(tt -> tiles.getNthTileInput(tt, nTile)).collect(Collectors.toList());
            List<Tensor<T>> outputTiles = outputTensors.stream().map(tt -> tiles.getNthTileOutput(tt, nTile)).collect(Collectors.toList());
            this.runNoTiles(inputTiles, outputTiles);
            if (this.tileCounter != null) {
                this.tileCounter.acceptProgress((long) (i + 1));
            }
        }
    }

    /**
     * Runs inference directly on ImgLib2 images through the engine.
     *
     * @param inputs input images
     * @return the images returned by the engine
     * @throws RunModelException if the model is not loaded or the engine fails
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <T extends RealType<T> & NativeType<T>, R extends RealType<R> & NativeType<R>> List<RandomAccessibleInterval<R>> inference(List<RandomAccessibleInterval<T>> inputs) throws RunModelException {
        if (!this.isLoaded()) {
            throw new RunModelException("Please first load the model.");
        }
        DeepLearningEngineInterface engineInstance = this.engineClassLoader.getEngineInstance();
        this.engineClassLoader.setEngineClassLoader();
        try {
            // NOTE(review): the engine's declared return type is not visible here;
            // the raw assignment keeps the original unchecked conversion.
            List outs = engineInstance.inference(inputs);
            return outs;
        } finally {
            this.engineClassLoader.setBaseClassLoader();
        }
    }

    /** @return whether tiled execution is enabled */
    public boolean isTiling() {
        return this.tiling;
    }

    /** Enables or disables tiled execution. */
    public void setTiling(boolean doTiling) {
        this.tiling = doTiling;
    }

    /** Sets the input and output tile descriptions and enables tiled execution. */
    public void setTileInfo(List<TileInfo> inputTiles, List<TileInfo> outputTiles) {
        this.inputTiles = inputTiles;
        this.outputTiles = outputTiles;
        this.tiling = true;
    }

    /** Sets the listener that receives tiling progress; {@code null} disables reporting. */
    public void setTilingCounter(TilingConsumer tileCounter) {
        this.tileCounter = tileCounter;
    }

    public EngineInfo getEngineInfo() {
        return this.engineInfo;
    }

    @Override
    public boolean isLoaded() {
        return this.loaded;
    }

    /** @return the engine loader, or {@code null} once the model has been closed */
    public EngineLoader getEngineClassLoader() {
        return this.engineClassLoader;
    }

    @Override
    public String getModelFolder() {
        return this.modelFolder;
    }

    public String getModelSource() {
        return this.modelSource;
    }

    public String getModelName() {
        return this.modelName;
    }

    /**
     * Creates a model using the current thread's context class loader for the engine.
     * This is equivalent to
     * {@code createModel(modelFolder, modelSource, engineInfo, null)}.
     */
    public static DLModelJava createModel(String modelFolder, String modelSource, EngineInfo engineInfo) throws LoadEngineException, MalformedURLException, IllegalStateException, IOException {
        return createModel(modelFolder, modelSource, engineInfo, null);
    }

    /**
     * Creates a model for the given folder and engine.
     * <p>
     * If the folder contains {@code rdf.yaml}, this returns a
     * {@link BioimageIoModelJava} whose descriptor is read from that file. If
     * building that model fails with an {@link IOException}, a plain
     * {@link DLModelJava} is returned instead.
     *
     * @param modelFolder folder containing the model files; must not be {@code null}
     * @param modelSource model source. It may be {@code null} only for the Tensorflow
     *                    and bioimage.io Tensorflow frameworks.
     * @param engineInfo  engine description; must not be {@code null}
     * @param classLoader class loader for the engine, or {@code null} for the thread's
     *                    context class loader
     * @return the created (not yet loaded) model
     */
    public static DLModelJava createModel(String modelFolder, String modelSource, EngineInfo engineInfo, ClassLoader classLoader) throws LoadEngineException, MalformedURLException, IllegalStateException, IOException {
        Objects.requireNonNull(modelFolder, "modelFolder");
        Objects.requireNonNull(engineInfo, "engineInfo");
        if (requiresModelSource(engineInfo)) {
            Objects.requireNonNull(modelSource, "modelSource");
        }
        if (Paths.get(modelFolder, RDF_FILE_NAME).toFile().isFile()) {
            BioimageIoModelJava model = null;
            try {
                model = new BioimageIoModelJava(engineInfo, modelFolder, modelSource, classLoader);
                model.descriptor = ModelDescriptorFactory.readFromLocalFile(Paths.get(modelFolder, RDF_FILE_NAME).toAbsolutePath().toString());
                return model;
            }
            catch (IOException ignored) {
                // Best effort: if rdf.yaml cannot be used, fall back to a plain model below.
                // A partially built model already owns an engine loader, so release it
                // instead of leaking it alongside the fallback's loader.
                if (model != null) {
                    model.close();
                }
            }
        }
        return new DLModelJava(engineInfo, modelFolder, modelSource, classLoader);
    }

    /**
     * Creates the engine loader for {@link #engineInfo}. If {@code classLoader} is
     * {@code null}, the current thread's context class loader is used instead.
     */
    protected void setEngineClassLoader(ClassLoader classLoader) throws LoadEngineException, MalformedURLException, IllegalStateException, IOException {
        this.engineClassLoader = EngineLoader.createEngine(classLoader == null ? Thread.currentThread().getContextClassLoader() : classLoader, this.engineInfo);
    }

    /** @return a new, empty tiling progress holder */
    public static TilingConsumer createTilingConsumer() {
        return new TilingConsumer();
    }

    /**
     * Mutable holder for tiling progress: the total number of tiles and the
     * number processed so far. Both values are {@code null} until first set.
     */
    public static class TilingConsumer {
        private Long totalTiles;
        private Long tilesProcessed;

        public void acceptTotal(Long totalTiles) {
            this.totalTiles = totalTiles;
        }

        public void acceptProgress(Long tilesProcessed) {
            this.tilesProcessed = tilesProcessed;
        }

        public Long getTotalTiles() {
            return this.totalTiles;
        }

        public Long getTilesProcessed() {
            return this.tilesProcessed;
        }
    }
}

