/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.engine.engines;

import io.bioimage.modelrunner.apposed.appose.Environment;
import io.bioimage.modelrunner.apposed.appose.Mamba;
import io.bioimage.modelrunner.apposed.appose.MambaInstallException;
import io.bioimage.modelrunner.apposed.appose.Service;
import io.bioimage.modelrunner.engine.AbstractEngine;
import io.bioimage.modelrunner.system.PlatformDetection;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.tensor.shm.SharedMemoryArray;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import org.apache.commons.compress.archivers.ArchiveException;

/**
 * JDLL engine descriptor for running Keras models through a Python environment managed by
 * micromamba.
 * <p>
 * NOTE(review): this engine appears to be work in progress. Both supported-version lists are
 * empty, so every constructor call is rejected. Most Python scripts are blank, and the script
 * maps are not yet keyed by version (see the static initializer).
 */
public class KerasEngine extends AbstractEngine {
    /** Helper used to locate, check and create the micromamba environment of this engine. */
    private final Mamba mamba;
    /** Keras version served by this engine. */
    private final String version;
    /** Whether this engine instance targets a GPU build. */
    private final boolean gpu;
    /** Always {@code true} after construction: validation rejects non-Python Keras engines. */
    private final boolean isPython;
    /** Cached result of {@link #isInstalled()}; {@code null} until a check has produced a value. */
    private Boolean installed;
    /** Python environment rooted at {@link #getDir()}; created lazily by {@link #loadModel}. */
    private Environment env;
    /** Python service that runs the load/unload/inference scripts; {@code null} until loaded. */
    private Service python;
    public static final String NAME = "keras";
    private static final List<String> SUPPORTED_KERAS_GPU_VERSIONS = Arrays.stream(new String[0]).collect(Collectors.toList());
    private static final List<String> SUPPORTED_KERAS_VERSION_NUMBERS = Arrays.stream(new String[0]).collect(Collectors.toList());
    private static final String LOAD_SCRIPT_KERAS_2 = "";
    /**
     * Keras 3 load script, formatted with {@code (modelFolder, modelSource)}.
     * <p>
     * NOTE(review): the paths are embedded in single-quoted Python literals without escaping.
     * Paths containing quotes or Windows backslashes may therefore break the script. The
     * {@code .keras} check is also applied to the first argument (the model folder). Confirm
     * whether the model source was meant instead.
     */
    private static final String LOAD_SCRIPT_KERAS_3 = "keras_model = None" + System.lineSeparator()
            + "if '%s'.endswith('.keras'):" + System.lineSeparator()
            + "  keras_model = keras.saving.load_model('%s')" + System.lineSeparator()
            + System.lineSeparator()
            + System.lineSeparator();
    private static final Map<String, String> LOAD_SCRIPT_MAP = new HashMap<String, String>();
    private static final String UNLOAD_SCRIPT_KERAS_2 = "";
    private static final String UNLOAD_SCRIPT_KERAS_3 = "";
    private static final Map<String, String> UNLOAD_SCRIPT_MAP = new HashMap<String, String>();
    private static final String IS_MODEL_LOADED_SCRIPT_KERAS_2 = "";
    private static final String IS_MODEL_LOADED_SCRIPT_KERAS_3 = "";
    private static final Map<String, String> IS_MODEL_LOADED_SCRIPT_MAP = new HashMap<String, String>();
    private static final String RUN_SCRIPT_KERAS_2 = "";
    private static final String RUN_SCRIPT_KERAS_3 = "";
    private static final Map<String, String> RUN_SCRIPT_MAP = new HashMap<String, String>();

    private KerasEngine(String version, boolean gpu, boolean isPython) {
        validate(version, gpu, isPython);
        this.mamba = new Mamba();
        this.version = version;
        this.gpu = gpu;
        this.isPython = isPython;
    }

    /**
     * Creates an engine descriptor for the given Keras configuration.
     *
     * @param version Keras version; must be one of the supported versions
     * @param gpu whether a GPU build is requested
     * @param isPython must be {@code true}; Keras is only supported through Python
     * @return the engine descriptor (not installed or loaded by this call)
     * @throws IllegalArgumentException if the configuration is not supported
     */
    public static KerasEngine initialize(String version, boolean gpu, boolean isPython) {
        return new KerasEngine(version, gpu, isPython);
    }

    /**
     * Returns the environment folder name for a Keras configuration, e.g. {@code keras_<v>_gpu}.
     *
     * @throws IllegalArgumentException if the configuration is not supported
     */
    public static String getFolderName(String version, boolean gpu, boolean isPython) {
        validate(version, gpu, isPython);
        return folderName(version, gpu);
    }

    /** Builds the folder name without validating; shared by {@link #getFolderName} and {@link #toString}. */
    private static String folderName(String version, boolean gpu) {
        return "keras_" + version + (gpu ? "_gpu" : "");
    }

    /** Rejects non-Python engines, unknown versions and GPU requests for CPU-only versions. */
    private static void validate(String version, boolean gpu, boolean isPython) {
        if (!isPython) {
            throw new IllegalArgumentException("JDLL only has support for Keras through a Python engine.");
        }
        if (!SUPPORTED_KERAS_VERSION_NUMBERS.contains(version)) {
            throw new IllegalArgumentException("The provided Keras version is not supported by JDLL: " + version + ". The supported versions are: " + SUPPORTED_KERAS_VERSION_NUMBERS);
        }
        if (gpu && !SUPPORTED_KERAS_GPU_VERSIONS.contains(version)) {
            throw new IllegalArgumentException("The provided Keras version has no GPU support in JDLL: " + version + ". GPU supported versions are: " + SUPPORTED_KERAS_GPU_VERSIONS);
        }
    }

    /**
     * Lists every supported Keras configuration whose environment is installed. CPU
     * configurations come first, followed by GPU configurations.
     *
     * @return a mutable list of installed engines (possibly empty)
     */
    public static List<KerasEngine> getInstalledVersions() {
        List<KerasEngine> installedEngines = SUPPORTED_KERAS_VERSION_NUMBERS.stream()
                .map(v -> new KerasEngine(v, false, true))
                .filter(KerasEngine::isInstalled)
                .collect(Collectors.toCollection(ArrayList::new));
        // Only versions that are both supported and GPU-capable pass validation with gpu=true.
        List<KerasEngine> gpuEngines = SUPPORTED_KERAS_VERSION_NUMBERS.stream()
                .filter(SUPPORTED_KERAS_GPU_VERSIONS::contains)
                .map(v -> new KerasEngine(v, true, true))
                .filter(KerasEngine::isInstalled)
                .collect(Collectors.toList());
        installedEngines.addAll(gpuEngines);
        return installedEngines;
    }

    @Override
    public String getName() {
        return NAME;
    }

    /** Returns the path of this engine's environment inside the micromamba envs directory. */
    @Override
    public String getDir() {
        return this.mamba.getEnvsDir() + File.separator + KerasEngine.getFolderName(this.version, this.gpu, this.isPython);
    }

    @Override
    public boolean isPython() {
        return this.isPython;
    }

    @Override
    public String getVersion() {
        return this.version;
    }

    @Override
    public boolean supportsGPU() {
        return this.gpu;
    }

    /**
     * Checks whether the engine environment exists and contains its dependencies.
     * <p>
     * A successful check (or a {@link MambaInstallException}, treated as "not installed") is
     * cached. A missing environment directory is not cached, so a later install is detected.
     */
    @Override
    public boolean isInstalled() {
        if (this.installed != null) {
            return this.installed;
        }
        if (!new File(this.getDir()).exists()) {
            return false;
        }
        // TODO(review): the dependency list is empty, so only the environment itself is checked.
        List<String> dependencies = new ArrayList<String>();
        try {
            this.installed = this.mamba.checkAllDependenciesInEnv(this.getDir(), dependencies);
        }
        catch (MambaInstallException e) {
            this.installed = false;
        }
        return this.installed;
    }

    /**
     * Installs micromamba if needed, then creates the engine environment at {@link #getDir()}.
     */
    @Override
    public void install() throws IOException, InterruptedException, MambaInstallException, ArchiveException, URISyntaxException {
        if (!this.mamba.checkMambaInstalled()) {
            this.mamba.installMicromamba();
        }
        // TODO(review): no packages are requested yet; Keras itself is not listed.
        List<String> dependencies = new ArrayList<String>();
        this.mamba.create(this.getDir(), dependencies.toArray(new String[dependencies.size()]));
        this.installed = true;
    }

    /**
     * Starts the Python service (lazily) and runs the version's load script.
     *
     * @param modelFolder first format argument of the load script
     * @param modelSource second format argument of the load script
     * @throws IllegalArgumentException if the engine is not installed
     * @throws IllegalStateException if no load script is registered for this version
     * @throws RuntimeException if the Python task does not complete
     */
    @Override
    public void loadModel(String modelFolder, String modelSource) throws IOException, InterruptedException {
        if (!this.isInstalled()) {
            throw new IllegalArgumentException("Current engine '" + this.toString() + "' is not installed. Please install it first.");
        }
        if (this.env == null) {
            this.env = new Environment(){

                @Override
                public String base() {
                    return KerasEngine.this.getDir();
                }

                @Override
                public boolean useSystemPath() {
                    return false;
                }
            };
        }
        // Checked separately so that a failed service start is retried on the next call.
        if (this.python == null) {
            this.python = this.env.python();
        }
        String loadScript = String.format(this.scriptFor(LOAD_SCRIPT_MAP, "load"), modelFolder, modelSource);
        Service.Task task = this.python.task(loadScript);
        task.waitFor();
        if (task.status == Service.TaskStatus.COMPLETE) {
            return;
        }
        throw new RuntimeException("Error loading the model. " + task.error);
    }

    /**
     * Asks the Python service whether a model is loaded.
     *
     * @return {@code false} if the service was never started; otherwise whether the task's
     *     {@code isLoaded} output equals {@code "True"}
     * @throws RuntimeException if the Python task does not complete
     */
    @Override
    public boolean isModelLoaded(String modelFolder, String modelSource) throws IOException, InterruptedException {
        if (this.python == null) {
            return false;
        }
        String isLoadedScript = String.format(this.scriptFor(IS_MODEL_LOADED_SCRIPT_MAP, "is-model-loaded"));
        Service.Task task = this.python.task(isLoadedScript);
        task.waitFor();
        if (task.status == Service.TaskStatus.COMPLETE) {
            // Null-safe: a missing "isLoaded" output is reported as "not loaded".
            return "True".equals(task.outputs.get("isLoaded"));
        }
        throw new RuntimeException("Error checking whether the model is loaded. " + task.error);
    }

    /**
     * Copies the inputs to shared memory, pre-allocates the outputs, runs inference in Python,
     * then retrieves the outputs.
     * <p>
     * NOTE(review): the shared-memory segments created here are never released. Confirm
     * SharedMemoryArray's release API and free them once outputs have been retrieved.
     *
     * @throws RuntimeException if no model has been loaded or the Python task does not complete
     */
    @Override
    public <T extends RealType<T> & NativeType<T>> void runModel(List<Tensor<T>> inputTensors, List<Tensor<T>> outputTensors) throws IOException, InterruptedException {
        if (this.python == null) {
            throw new RuntimeException("Python Keras engine has not been loaded yet.");
        }
        List<SharedMemoryArray> inputShms = inputTensors.stream()
                .map(tt -> SharedMemoryArray.createSHMAFromRAI(tt.getData(), false, false))
                .collect(Collectors.toList());
        // Output buffers are sized from the OUTPUT tensors.
        List<Object> outputShms = outputTensors.stream()
                .<Object>map(KerasEngine::allocateOutput)
                .collect(Collectors.toList());
        String runScriptFormatted = this.createScriptForInference(inputShms, outputShms);
        Service.Task task = this.python.task(runScriptFormatted);
        task.waitFor();
        if (task.status != Service.TaskStatus.COMPLETE) {
            throw new RuntimeException("Error making inference with the model. " + task.error);
        }
        this.retrieveOutputs(outputShms, outputTensors);
    }

    /**
     * Prepares the shared-memory slot for one output tensor.
     * <p>
     * An empty tensor gets a zero-sized segment on Windows and only a segment name elsewhere.
     * NOTE(review): presumably the Python side creates the segment when only a name is given.
     * A non-empty tensor is copied into a new segment.
     */
    private static <T extends RealType<T> & NativeType<T>> Object allocateOutput(Tensor<T> tensor) {
        if (tensor.isEmpty() && PlatformDetection.isWindows()) {
            return SharedMemoryArray.create(0);
        }
        if (tensor.isEmpty()) {
            return SharedMemoryArray.createShmName();
        }
        return SharedMemoryArray.createSHMAFromRAI(tensor.getData(), false, false);
    }

    /**
     * Builds the Python inference script.
     * TODO(review): not implemented yet; always returns an empty script.
     */
    private String createScriptForInference(List<SharedMemoryArray> inputs, List<Object> outputs) {
        return "";
    }

    /**
     * Copies inference results back into the output tensors.
     * TODO(review): not implemented yet; currently a no-op.
     */
    private <T extends RealType<T> & NativeType<T>> void retrieveOutputs(List<Object> outputShms, List<Tensor<T>> outputTensors) {
    }

    /**
     * Runs the version's unload script if the Python service is running; otherwise does nothing.
     *
     * @throws RuntimeException if the Python task does not complete
     */
    @Override
    public void unloadModel() throws IOException, InterruptedException {
        if (this.python == null) {
            return;
        }
        String unloadScript = String.format(this.scriptFor(UNLOAD_SCRIPT_MAP, "unload"));
        Service.Task task = this.python.task(unloadScript);
        task.waitFor();
        if (task.status == Service.TaskStatus.COMPLETE) {
            return;
        }
        throw new RuntimeException("Error unloading the model. " + task.error);
    }

    /**
     * Unloads the model and closes the Python service. The service is closed and the references
     * are cleared even if unloading fails.
     */
    @Override
    public void close() throws Exception {
        if (this.python == null) {
            // Nothing to shut down. Possibly an environment was created but the service never started.
            this.env = null;
            return;
        }
        try {
            this.unloadModel();
        } finally {
            this.python.close();
            this.python = null;
            this.env = null;
        }
    }

    @Override
    public String toString() {
        return folderName(this.version, this.gpu);
    }

    /**
     * Looks up the script registered for this engine's version.
     *
     * @param scripts one of the per-purpose script maps
     * @param purpose short label used in the error message
     * @throws IllegalStateException if no script is registered for {@link #version}
     */
    private String scriptFor(Map<String, String> scripts, String purpose) {
        String script = scripts.get(this.version);
        if (script == null) {
            throw new IllegalStateException("No " + purpose + " script is registered for Keras version " + this.version + ".");
        }
        return script;
    }

    static {
        // TODO(review): the maps are looked up by Keras version (see scriptFor), but every entry
        // is keyed by the Keras 3 load-script text, so no lookup can succeed yet. Each second put
        // also overwrites the first. Key these by the matching entries of
        // SUPPORTED_KERAS_VERSION_NUMBERS once those are defined.
        LOAD_SCRIPT_MAP.put(LOAD_SCRIPT_KERAS_3, LOAD_SCRIPT_KERAS_2);
        LOAD_SCRIPT_MAP.put(LOAD_SCRIPT_KERAS_3, LOAD_SCRIPT_KERAS_3);
        UNLOAD_SCRIPT_MAP.put(LOAD_SCRIPT_KERAS_3, UNLOAD_SCRIPT_KERAS_2);
        UNLOAD_SCRIPT_MAP.put(LOAD_SCRIPT_KERAS_3, UNLOAD_SCRIPT_KERAS_3);
        IS_MODEL_LOADED_SCRIPT_MAP.put(LOAD_SCRIPT_KERAS_3, IS_MODEL_LOADED_SCRIPT_KERAS_2);
        IS_MODEL_LOADED_SCRIPT_MAP.put(LOAD_SCRIPT_KERAS_3, IS_MODEL_LOADED_SCRIPT_KERAS_3);
        RUN_SCRIPT_MAP.put(LOAD_SCRIPT_KERAS_3, RUN_SCRIPT_KERAS_2);
        RUN_SCRIPT_MAP.put(LOAD_SCRIPT_KERAS_3, RUN_SCRIPT_KERAS_3);
    }
}

