package io.bioimage.modelrunner.bioimageio.description.weights;

import io.bioimage.modelrunner.engine.engines.OnnxEngine;
import io.bioimage.modelrunner.versionmanagement.SupportedVersions;
import io.bioimage.modelrunner.versionmanagement.VersionStringUtils;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

/* loaded from: input_file:io/bioimage/modelrunner/bioimageio/description/weights/ModelWeight.class */
/**
 * Collection of the weight entries declared for a model, indexed by an
 * "engine name" of the form {@code <format identifier>_v<training version>}.
 *
 * <p>Instances are created with {@link #build(Map)}. One entry can later be
 * marked as selected with {@link #setSelectedWeightsFormat(String, String)}
 * and registered as loaded with {@link #setWeightsAsLoaded()}.
 *
 * <p>The registry of loaded weights is a plain static {@link HashMap}; no
 * synchronization is performed by this class.
 */
public class ModelWeight {
    private static final String KERAS_IDENTIFIER = "keras_hdf5";
    private static final String ONNX_IDENTIFIER = OnnxEngine.NAME;
    private static final String TORCH_IDENTIFIER = "pytorch_state_dict";
    private static final String TF_IDENTIFIER = "tensorflow_saved_model_bundle";
    private static final String TFJS_IDENTIFIER = "tensorflow_js";
    private static final String TORCHSCRIPT_IDENTIFIER = "torchscript";
    private static final String BIOENGINE_IDENTIFIER = "bioengine";

    /** Alternative key accepted by {@link #build(Map)} for torchscript weights. */
    private static final String TORCHSCRIPT_ALIAS = "pytorch_script";

    /** Separator placed between the format identifier and the training version. */
    private static final String VERSION_SEPARATOR = "_v";

    /** Placeholder used in the engine name when a weight entry has no training version. */
    private static final String UNKNOWN_VERSION = "Unknown";

    /**
     * Identifiers in the order in which
     * {@link #setSelectedWeightsFormat(String, String)} tests them as prefixes.
     */
    private static final String[] SELECTABLE_IDENTIFIERS = {
        KERAS_IDENTIFIER,
        ONNX_IDENTIFIER,
        TORCH_IDENTIFIER,
        TF_IDENTIFIER,
        TFJS_IDENTIFIER,
        TORCHSCRIPT_IDENTIFIER,
        BIOENGINE_IDENTIFIER
    };

    /** Weights registered through {@link #setWeightsAsLoaded()}, keyed by framework. */
    private static final Map<String, WeightFormat> loadedWeights = new HashMap<>();

    private String selectedEngine;
    private String selectedVersion;
    private WeightFormat selectedWeights;
    private Map<String, WeightFormat> weightsDic;

    /**
     * Creates a {@code ModelWeight} from a map of weight format identifiers to
     * their specification maps.
     *
     * <p>Entries whose value is {@code null} and entries whose key is not one
     * of the recognised identifiers are ignored. Every other entry is turned
     * into the corresponding {@link WeightFormat} and stored under its engine
     * name.
     *
     * @param weightSpecs map from weight format identifier to its specification map
     * @return the populated instance
     * @throws ClassCastException if a non-null value is not a {@link Map}
     */
    public static ModelWeight build(Map<String, Object> weightSpecs) {
        ModelWeight modelWeight = new ModelWeight();
        modelWeight.weightsDic = new HashMap<>();
        for (Map.Entry<String, Object> entry : weightSpecs.entrySet()) {
            // The specification is expected to be a String-keyed map; only the
            // Map-ness is checked at runtime.
            @SuppressWarnings("unchecked")
            Map<String, Object> spec = (Map<String, Object>) entry.getValue();
            if (spec == null) {
                continue;
            }
            String format = entry.getKey();
            String identifier;
            WeightFormat weights;
            if (KERAS_IDENTIFIER.equals(format)) {
                identifier = KERAS_IDENTIFIER;
                weights = new KerasWeights(spec);
            } else if (ONNX_IDENTIFIER.equals(format)) {
                identifier = ONNX_IDENTIFIER;
                weights = new OnnxWeights(spec);
            } else if (TORCH_IDENTIFIER.equals(format)) {
                identifier = TORCH_IDENTIFIER;
                weights = new PytorchWeights(spec);
            } else if (TF_IDENTIFIER.equals(format)) {
                identifier = TF_IDENTIFIER;
                weights = new TfWeights(spec);
            } else if (TFJS_IDENTIFIER.equals(format)) {
                identifier = TFJS_IDENTIFIER;
                weights = new TfJsWeights(spec);
            } else if (TORCHSCRIPT_IDENTIFIER.equals(format) || TORCHSCRIPT_ALIAS.equals(format)) {
                identifier = TORCHSCRIPT_IDENTIFIER;
                weights = new TorchscriptWeights(spec);
            } else {
                // Unrecognised weight format: not stored.
                continue;
            }
            modelWeight.weightsDic.put(modelWeight.engineName(identifier, weights), weights);
        }
        return modelWeight;
    }

    /**
     * Looks up a weight entry by engine name.
     *
     * <p>An exact key match is preferred. Otherwise the first stored key that
     * starts with {@code engineName} is used; the iteration order of the
     * underlying map decides which key is "first".
     *
     * @param engineName full engine name, or a prefix of one
     * @return the matching weights, or {@code null} if no key matches
     */
    public WeightFormat getModelWeights(String engineName) {
        WeightFormat exactMatch = this.weightsDic.get(engineName);
        if (exactMatch != null) {
            return exactMatch;
        }
        return this.weightsDic.keySet().stream()
                .filter(key -> key.startsWith(engineName))
                .findFirst()
                .map(this.weightsDic::get)
                .orElse(null);
    }

    /**
     * Finds the stored weights that match the given framework and version.
     *
     * <p>A weight entry matches when its framework equals {@code framework}
     * and its training version either equals the closest supported version
     * reported by {@link SupportedVersions} or is considered the same version
     * as {@code version} by {@link VersionStringUtils}. Entries without a
     * training version never match.
     *
     * @param framework framework name to match; the bioengine identifier yields {@code null}
     * @param version requested version
     * @return the first matching weights, or {@code null} for the bioengine identifier
     * @throws IllegalArgumentException if no stored weights match
     */
    public WeightFormat getSupportedWeightObject(String framework, String version) throws IllegalArgumentException {
        if (framework.equals(getBioengineID())) {
            return null;
        }
        WeightFormat match = this.weightsDic.values().stream()
                .filter(candidate -> matches(candidate, framework, version))
                .findFirst()
                .orElse(null);
        if (match == null) {
            throw new IllegalArgumentException("JDLL does not support the provided weight format: " + framework
                    + " . Supported weight formats are: " + KERAS_IDENTIFIER + ", " + ONNX_IDENTIFIER + ", "
                    + TORCH_IDENTIFIER + ", " + TF_IDENTIFIER + ", " + TFJS_IDENTIFIER + ", "
                    + TORCHSCRIPT_IDENTIFIER + ", " + BIOENGINE_IDENTIFIER);
        }
        return match;
    }

    /**
     * Tells whether a weight entry belongs to the framework and is compatible
     * with the requested version.
     */
    private static boolean matches(WeightFormat candidate, String framework, String version) {
        if (!candidate.getFramework().equals(framework)) {
            return false;
        }
        String trainingVersion = candidate.getTrainingVersion();
        if (trainingVersion == null) {
            // Version-less entries are stored under an "Unknown" engine name
            // and cannot be compared against a requested version.
            return false;
        }
        return trainingVersion.equals(SupportedVersions.getClosestSupportedPythonVersion(framework, version))
                || VersionStringUtils.areTheyTheSameVersionUntilPoint(trainingVersion, version, 2);
    }

    /**
     * @return a new list with every stored weight entry
     */
    public List<WeightFormat> gettAllSupportedWeightObjects() {
        return this.weightsDic.values().stream().collect(Collectors.toList());
    }

    /**
     * @return a new list with every stored engine name (identifier plus version)
     */
    public List<String> getSupportedWeightNamesAndVersion() {
        return this.weightsDic.keySet().stream().collect(Collectors.toList());
    }

    /**
     * @return a new list with the distinct framework names of the stored weights
     */
    public List<String> getAllSuportedWeightNames() {
        return this.weightsDic.values().stream()
                .map(WeightFormat::getFramework)
                .distinct()
                .collect(Collectors.toList());
    }

    /**
     * @return the identifier set by the last call to
     *         {@link #setSelectedWeightsFormat(String, String)}, or {@code null}
     */
    public String getSelectedWeightsIdentifier() {
        return this.selectedEngine;
    }

    /**
     * @return the version set by the last call to
     *         {@link #setSelectedWeightsFormat(String, String)}, or {@code null}
     * @throws IOException never thrown by this implementation; the declaration
     *         is kept so existing callers keep compiling
     */
    public String getWeightsSelectedVersion() throws IOException {
        return this.selectedVersion;
    }

    /**
     * @return the currently selected weights, or {@code null} if none are selected
     */
    public WeightFormat getSelectedWeights() {
        return this.selectedWeights;
    }

    /**
     * Selects the weights to use.
     *
     * <p>The selected identifier is the first known identifier that
     * {@code format} starts with. The selected weights are then resolved with
     * {@link #getSupportedWeightObject(String, String)}, which receives
     * {@code format} unchanged.
     *
     * <p>The identifier and version are stored before the weights are
     * resolved, so they stay updated even if the resolution throws.
     *
     * @param format weight format name
     * @param version requested version
     * @throws IllegalArgumentException if {@code format} does not start with a
     *         known identifier, or if no stored weights match
     */
    public void setSelectedWeightsFormat(String format, String version) {
        this.selectedEngine = identifierFor(format);
        this.selectedVersion = version;
        this.selectedWeights = getSupportedWeightObject(format, version);
    }

    /**
     * Returns the first known identifier that is a prefix of {@code format}.
     */
    private static String identifierFor(String format) {
        for (String identifier : SELECTABLE_IDENTIFIERS) {
            if (format.startsWith(identifier)) {
                return identifier;
            }
        }
        throw new IllegalArgumentException("Unsupported Deep Learning framework for JDLL.");
    }

    /**
     * Registers the selected weights in the static registry of loaded weights,
     * keyed by their framework. Does nothing when no weights are selected.
     */
    public void setWeightsAsLoaded() {
        if (this.selectedWeights != null) {
            loadedWeights.put(this.selectedWeights.getFramework(), this.selectedWeights);
        }
    }

    /**
     * Builds the key under which a weight entry is stored.
     *
     * <p>When the entry has no training version, {@code Unknown<n>} is used
     * instead, with the smallest {@code n >= 0} that does not collide with a
     * key already stored.
     */
    private String engineName(String identifier, WeightFormat weights) {
        String prefix = identifier + VERSION_SEPARATOR;
        String trainingVersion = weights.getTrainingVersion();
        if (trainingVersion != null) {
            return prefix + trainingVersion;
        }
        int suffix = 0;
        while (this.weightsDic.containsKey(prefix + UNKNOWN_VERSION + suffix)) {
            suffix++;
        }
        return prefix + UNKNOWN_VERSION + suffix;
    }

    /** @return the identifier of the Keras weight format */
    public static String getKerasID() {
        return KERAS_IDENTIFIER;
    }

    /** @return the identifier of the ONNX weight format */
    public static String getOnnxID() {
        return ONNX_IDENTIFIER;
    }

    /** @return the identifier of the PyTorch state dict weight format */
    public static String getPytorchID() {
        return TORCH_IDENTIFIER;
    }

    /** @return the identifier of the TensorFlow JS weight format */
    public static String getTensorflowJsID() {
        return TFJS_IDENTIFIER;
    }

    /** @return the identifier of the TensorFlow saved model bundle weight format */
    public static String getTensorflowID() {
        return TF_IDENTIFIER;
    }

    /** @return the identifier of the torchscript weight format */
    public static String getTorchscriptID() {
        return TORCHSCRIPT_IDENTIFIER;
    }

    /** @return the identifier of the bioengine weight format */
    public static String getBioengineID() {
        return BIOENGINE_IDENTIFIER;
    }
}
