/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.bioimageio.description.weights;

import io.bioimage.modelrunner.bioimageio.description.weights.KerasWeights;
import io.bioimage.modelrunner.bioimageio.description.weights.OnnxWeights;
import io.bioimage.modelrunner.bioimageio.description.weights.PytorchWeights;
import io.bioimage.modelrunner.bioimageio.description.weights.TfJsWeights;
import io.bioimage.modelrunner.bioimageio.description.weights.TfWeights;
import io.bioimage.modelrunner.bioimageio.description.weights.TorchscriptWeights;
import io.bioimage.modelrunner.bioimageio.description.weights.WeightFormat;
import io.bioimage.modelrunner.versionmanagement.SupportedVersions;
import io.bioimage.modelrunner.versionmanagement.VersionStringUtils;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Registry of the weight entries declared for a model, plus the currently
 * selected entry.
 *
 * <p>Instances are normally created with {@link #build(Map)}. Each recognised
 * weight format is stored under a key of the form
 * {@code <identifier>_v<trainingVersion>}; when the training version is not
 * available the key is {@code <identifier>_vUnknown<n>}, where {@code n} is the
 * first non-negative integer that keeps the key unique.
 *
 * <p>NOTE(review): this class performs no synchronisation; confirm with callers
 * whether instances (and the static map of loaded weights) are shared between
 * threads.
 */
public class ModelWeight {
    private static final String KERAS_ID = "keras_hdf5";
    private static final String ONNX_ID = "onnx";
    private static final String TORCH_ID = "pytorch_state_dict";
    private static final String TF_ID = "tensorflow_saved_model_bundle";
    private static final String TFJS_ID = "tensorflow_js";
    private static final String TORCHSCRIPT_ID = "torchscript";
    private static final String BIOENGINE_ID = "bioengine";

    /** Alternative key that {@link #build(Map)} also treats as torchscript weights. */
    private static final String ALT_TORCHSCRIPT_KEY = "pytorch_script";

    /** Version placeholder used in registry keys when no training version is available. */
    private static final String UNKNOWN_VERSION = "Unknown";

    /**
     * Identifiers in the order in which {@link #setSelectedWeightsFormat(String, String)}
     * tests them as prefixes.
     */
    private static final String[] SELECTABLE_IDS = {
        KERAS_ID, ONNX_ID, TORCH_ID, TF_ID, TFJS_ID, TORCHSCRIPT_ID, BIOENGINE_ID
    };

    /**
     * Weights marked as loaded, keyed by framework. This map is only written to
     * (by {@link #setWeightsAsLoaded()}) and never read within this class.
     */
    private static final Map<String, WeightFormat> loadedWeights = new HashMap<>();

    /**
     * Registered weights keyed by engine name. Initialised eagerly so that an
     * instance obtained through the implicit public constructor is a usable,
     * empty registry instead of failing on every accessor.
     */
    private final Map<String, WeightFormat> weightsDic = new HashMap<>();

    private String selectedEngine;
    private String selectedVersion;
    private WeightFormat selectedWeights;

    /**
     * Creates a registry from a map of weight-format keys to their field maps.
     *
     * <p>Entries whose value is {@code null} and entries whose key is not a
     * recognised weight format are skipped. The key {@code "pytorch_script"} is
     * handled as torchscript weights.
     *
     * @param yamlFieldElements map from weight-format key to the fields of that
     *        weight entry; every non-null value is cast to a {@code Map}
     * @return the populated registry
     * @throws ClassCastException if a non-null value is not a {@code Map}
     */
    public static ModelWeight build(Map<String, Object> yamlFieldElements) {
        ModelWeight model = new ModelWeight();
        for (Map.Entry<String, Object> entry : yamlFieldElements.entrySet()) {
            @SuppressWarnings("unchecked")
            Map<String, Object> weights = (Map<String, Object>) entry.getValue();
            if (weights == null) {
                continue;
            }
            String key = entry.getKey();
            String identifier = ALT_TORCHSCRIPT_KEY.equals(key) ? TORCHSCRIPT_ID : key;
            WeightFormat weightsObject = createWeights(identifier, weights);
            if (weightsObject == null) {
                continue;
            }
            model.weightsDic.put(model.engineName(identifier, weightsObject), weightsObject);
        }
        return model;
    }

    /**
     * Instantiates the weights object that corresponds to a canonical identifier.
     *
     * @param identifier canonical weight-format identifier
     * @param weights fields of the weight entry
     * @return the new weights object, or {@code null} if the identifier is not recognised
     */
    private static WeightFormat createWeights(String identifier, Map<String, Object> weights) {
        if (KERAS_ID.equals(identifier)) {
            return new KerasWeights(weights);
        }
        if (ONNX_ID.equals(identifier)) {
            return new OnnxWeights(weights);
        }
        if (TORCH_ID.equals(identifier)) {
            return new PytorchWeights(weights);
        }
        if (TF_ID.equals(identifier)) {
            return new TfWeights(weights);
        }
        if (TFJS_ID.equals(identifier)) {
            return new TfJsWeights(weights);
        }
        if (TORCHSCRIPT_ID.equals(identifier)) {
            return new TorchscriptWeights(weights);
        }
        return null;
    }

    /**
     * Looks up a registered weights object by engine name.
     *
     * <p>An exact key match is preferred. Otherwise the first registered key
     * (in the map's iteration order) that starts with {@code weightID} is used.
     *
     * @param weightID full engine name or a prefix of it
     * @return the matching weights, or {@code null} if there is no match or
     *         {@code weightID} is {@code null}
     */
    public WeightFormat getModelWeights(String weightID) {
        if (weightID == null) {
            // String.startsWith(null) would throw; a null id simply matches nothing.
            return null;
        }
        WeightFormat exact = this.weightsDic.get(weightID);
        if (exact != null) {
            return exact;
        }
        for (Map.Entry<String, WeightFormat> entry : this.weightsDic.entrySet()) {
            if (entry.getKey().startsWith(weightID)) {
                return entry.getValue();
            }
        }
        return null;
    }

    /**
     * Finds the registered weights of the given framework that fit the given version.
     *
     * <p>A candidate fits when its training version equals
     * {@code SupportedVersions.getClosestSupportedPythonVersion(weightsFormat, version)}
     * or when {@code VersionStringUtils.areTheyTheSameVersionUntilPoint(trainingVersion, version, 2)}
     * holds. Candidates without a training version cannot be compared and are
     * treated as not fitting.
     *
     * @param weightsFormat framework identifier; must not be {@code null}
     * @param version requested version
     * @return the first fitting weights object, or {@code null} when
     *         {@code weightsFormat} is the bioengine identifier
     * @throws IllegalArgumentException if no registered weights fit
     */
    public WeightFormat getSupportedWeightObject(String weightsFormat, String version) throws IllegalArgumentException {
        if (weightsFormat.equals(BIOENGINE_ID)) {
            return null;
        }
        for (WeightFormat candidate : this.weightsDic.values()) {
            if (fits(candidate, weightsFormat, version)) {
                return candidate;
            }
        }
        throw new IllegalArgumentException("JDLL does not support the provided weight format: " + weightsFormat
                + " . Supported weight formats are: " + KERAS_ID + ", " + ONNX_ID + ", " + TORCH_ID + ", "
                + TF_ID + ", " + TFJS_ID + ", " + TORCHSCRIPT_ID + ", " + BIOENGINE_ID);
    }

    /**
     * Tells whether a candidate belongs to the framework and fits the version.
     *
     * @param candidate registered weights to test
     * @param weightsFormat framework identifier, non-null
     * @param version requested version
     * @return {@code true} if the candidate fits
     */
    private static boolean fits(WeightFormat candidate, String weightsFormat, String version) {
        if (!weightsFormat.equals(candidate.getFramework())) {
            return false;
        }
        String trainingVersion = candidate.getTrainingVersion();
        if (trainingVersion == null) {
            // Weights registered under an "Unknown" key carry no version to compare;
            // skipping them avoids a NullPointerException in the checks below.
            return false;
        }
        if (trainingVersion.equals(SupportedVersions.getClosestSupportedPythonVersion(weightsFormat, version))) {
            return true;
        }
        return VersionStringUtils.areTheyTheSameVersionUntilPoint(trainingVersion, version, 2);
    }

    /**
     * Returns all registered weights objects.
     *
     * @return a new list with every registered weights object
     */
    public List<WeightFormat> gettAllSupportedWeightObjects() {
        return this.weightsDic.values().stream().collect(Collectors.toList());
    }

    /**
     * Returns the registry keys (engine names that include the version part).
     *
     * @return a new list with every registry key
     */
    public List<String> getSupportedWeightNamesAndVersion() {
        return this.weightsDic.keySet().stream().collect(Collectors.toList());
    }

    /**
     * Returns the distinct frameworks of the registered weights.
     *
     * @return a new list with each framework reported once
     */
    public List<String> getAllSuportedWeightNames() {
        return this.weightsDic.values().stream()
                .map(WeightFormat::getFramework)
                .distinct()
                .collect(Collectors.toList());
    }

    /**
     * Returns the identifier of the selected engine.
     *
     * @return the selected identifier, or {@code null} if nothing has been selected
     */
    public String getSelectedWeightsIdentifier() {
        return this.selectedEngine;
    }

    /**
     * Returns the version given to the last successful selection.
     *
     * @return the selected version, or {@code null} if nothing has been selected
     * @throws IOException never thrown by this implementation; the clause is
     *         kept so that existing callers which catch it keep compiling
     */
    public String getWeightsSelectedVersion() throws IOException {
        return this.selectedVersion;
    }

    /**
     * Returns the selected weights object.
     *
     * @return the selected weights; {@code null} if nothing has been selected
     *         or if the bioengine was selected
     */
    public WeightFormat getSelectedWeights() {
        return this.selectedWeights;
    }

    /**
     * Selects the weights to use.
     *
     * <p>{@code weightFormat} is resolved to the first known identifier it starts
     * with, and that identifier (not the raw argument) is used for the lookup,
     * so names carrying a suffix are accepted consistently. The lookup is done
     * before any field is assigned, which leaves the previous selection intact
     * when it fails.
     *
     * @param weightFormat name starting with one of the known identifiers; must
     *        not be {@code null}
     * @param version requested version
     * @throws IllegalArgumentException if {@code weightFormat} starts with no
     *         known identifier, or if no registered weights fit
     */
    public void setSelectedWeightsFormat(String weightFormat, String version) {
        String engine = resolveEngine(weightFormat);
        WeightFormat resolved = this.getSupportedWeightObject(engine, version);
        this.selectedEngine = engine;
        this.selectedVersion = version;
        this.selectedWeights = resolved;
    }

    /**
     * Maps a weight-format name to the first known identifier that prefixes it.
     *
     * @param weightFormat name to resolve
     * @return the matching identifier
     * @throws IllegalArgumentException if no identifier is a prefix of the name
     */
    private static String resolveEngine(String weightFormat) {
        for (String id : SELECTABLE_IDS) {
            if (weightFormat.startsWith(id)) {
                return id;
            }
        }
        throw new IllegalArgumentException("Unsupported Deep Learning framework for JDLL.");
    }

    /**
     * Records the selected weights in the static map of loaded weights, keyed by
     * their framework. Does nothing when no weights object is selected.
     */
    public void setWeightsAsLoaded() {
        if (this.selectedWeights != null) {
            loadedWeights.put(this.selectedWeights.getFramework(), this.selectedWeights);
        }
    }

    /**
     * Builds the registry key for a weights object.
     *
     * @param identifier canonical weight-format identifier
     * @param ww weights object whose training version is used
     * @return {@code identifier + "_v" + trainingVersion}, or, when the training
     *         version is {@code null}, {@code identifier + "_vUnknown" + n} with the
     *         smallest {@code n >= 0} not yet present in the registry
     */
    private String engineName(String identifier, WeightFormat ww) {
        String prefix = identifier + "_v";
        String trainingVersion = ww.getTrainingVersion();
        if (trainingVersion != null) {
            return prefix + trainingVersion;
        }
        int counter = 0;
        while (this.weightsDic.containsKey(prefix + UNKNOWN_VERSION + counter)) {
            counter++;
        }
        return prefix + UNKNOWN_VERSION + counter;
    }

    /**
     * @return the identifier used for Keras weights
     */
    public static String getKerasID() {
        return KERAS_ID;
    }

    /**
     * @return the identifier used for ONNX weights
     */
    public static String getOnnxID() {
        return ONNX_ID;
    }

    /**
     * @return the identifier used for Pytorch state-dict weights
     */
    public static String getPytorchID() {
        return TORCH_ID;
    }

    /**
     * @return the identifier used for Tensorflow JS weights
     */
    public static String getTensorflowJsID() {
        return TFJS_ID;
    }

    /**
     * @return the identifier used for Tensorflow saved-model-bundle weights
     */
    public static String getTensorflowID() {
        return TF_ID;
    }

    /**
     * @return the identifier used for torchscript weights
     */
    public static String getTorchscriptID() {
        return TORCHSCRIPT_ID;
    }

    /**
     * @return the identifier used for the bioengine
     */
    public static String getBioengineID() {
        return BIOENGINE_ID;
    }
}

