/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.runmode.ops;

import io.bioimage.modelrunner.bioimageio.BioimageioRepo;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptor;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorFactory;
import io.bioimage.modelrunner.bioimageio.description.exceptions.ModelSpecsException;
import io.bioimage.modelrunner.download.FileDownloader;
import io.bioimage.modelrunner.runmode.RunMode;
import io.bioimage.modelrunner.runmode.ops.OpInterface;
import io.bioimage.modelrunner.runmode.ops.StardistInferJdllOp;
import io.bioimage.modelrunner.tensor.Tensor;
import io.bioimage.modelrunner.utils.FileUtils;
import io.bioimage.modelrunner.utils.JSONUtils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.MalformedURLException;
import java.nio.channels.Channel;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.integer.UnsignedShortType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Cast;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
import net.imglib2.view.Views;

public class StardistFineTuneJdllOp
implements OpInterface {
    /** Model to fine tune. Starts as the value given to the factory methods and is rewritten to local folder paths during set-up. */
    private String model;
    /** Directory in which a new fine tuned model folder is created; left null by the single-argument factory method. */
    private String nModelParentPath;
    /** Optional weights file forwarded to the Python op; validated in setWeigthsFile() (must end in ".h5"). */
    private String weightsToFineTune;
    /** Number of channels of the model input, read from the rdf.yaml in findNChannels(). */
    private int nChannelsModel;
    /** Learning rate written to the StarDist config. */
    private float lr = 1.0E-5f;
    /** Batch size written to the StarDist config. */
    private int batchSize = 16;
    /** Number of epochs written to the StarDist config. */
    private int epochs = 1;
    /** Set to true in setModel() when the requested model is a key of one of the pre-trained model maps. */
    private boolean downloadStardistPretrained = false;
    /** Training images, stored as float tensors by setTrainingSamples(). */
    private Tensor<FloatType> trainingSamples;
    /** Ground truth, stored as unsigned short tensors by setGroundTruth(). */
    private Tensor<UnsignedShortType> groundTruth;
    /** Value returned by getOpDir(); assigned in installOp(). */
    private String opFilePath;
    /** Value returned by getCondaEnv(); assigned in installOp(). */
    private String envPath;
    /** Inputs handed to the Python op; rebuilt on every call to getOpInputs(). */
    private LinkedHashMap<String, Object> inputsMap;
    // Maps from StarDist pre-trained model names to Bioimage.io model nicknames; both are filled in the static initializer.
    // findNChannels() treats the keys of the "3C" map as 3-channel models and the keys of the "1C" map as 1-channel models.
    // NOTE(review): "2D_versatile_he" is registered under 1C and "2D_versatile_fluo" under 3C, which looks swapped - confirm.
    private static final Map<String, String> PRETRAINED_3C_STARDIST_MODELS = new HashMap<String, String>();
    private static final Map<String, String> PRETRAINED_1C_STARDIST_MODELS;
    // Keys looked up under config>stardist of the rdf.yaml, and the JSON files they are written to.
    private static final String STARDIST_CONFIG_KEY = "config";
    private static final String CONFIG_JSON = "config.json";
    private static final String STARDIST_THRES_KEY = "thresholds";
    private static final String THRES_JSON = "thresholds.json";
    // Keys of the inputs map returned by getOpInputs().
    private static final String MODEL_KEY = "model";
    private static final String TRAIN_SAMPLES_KEY = "train_samples";
    private static final String GROUND_TRUTH_KEY = "ground_truth";
    private static final String WEIGHTS_TO_FINE_TUNE_KEY = "weights_file";
    // Keys written into the StarDist config map before it is saved as config.json.
    private static final String PATCH_SIZE_KEY = "train_patch_size";
    private static final String BATCH_SIZE_KEY = "train_batch_size";
    private static final String LR_KEY = "train_learning_rate";
    private static final String EPOCHS_KEY = "train_epochs";
    /** File name of the Keras weights, searched for in the model folder and among the rdf.yaml attachments. */
    private static final String STARDIST_WEIGHTS_FILE = "stardist_weights.h5";
    /** Suffix required for the file given to setWeightsToFineTune(). */
    private static final String KERAS_SUFFIX_FILE = ".h5";
    /** Not referenced by the visible code of this class. */
    private static final String DOWNLOAD_STARDIST_KEY = "download_pretrained_stardist";
    /** Value returned by getMethodName(). */
    private static final String OP_METHOD_NAME = "finetune_stardist";
    /** Number of outputs of the op; getNumberOfOutputs() returns this same value. */
    private static final int N_STARDIST_OUTPUTS = 1;
    /** Value returned by getOpPythonFilename(). */
    private static final String STARDIST_OP_FNAME = "stardist_fine_tune.py";
    // Axes orders expected for the training samples (4 or 5 dimensions) and for the ground truth.
    private static final String STARDIST_2D_AXES = "byxc";
    private static final String STARDIST_3D_AXES = "bzyxc";
    private static final String GROUNDTRUTH_AXES = "byx";

    /**
     * Manual developer harness: builds empty dummy tensors, fine tunes a model located
     * at a hard-coded local path for one epoch and runs the op.
     *
     * @param args ignored
     * @throws IOException if any of the file operations fails
     * @throws InterruptedException if the set-up of the model is interrupted
     * @throws Exception if running the op fails
     */
    public static void main(String[] args) throws IOException, InterruptedException, Exception {
        // Dummy training batch: 2 images of 64x64 pixels and 3 channels.
        ArrayImgFactory<FloatType> sampleFactory = new ArrayImgFactory<>(new FloatType());
        Img<FloatType> sampleImg = sampleFactory.create(new long[]{2L, 64L, 64L, 3L});
        Tensor<FloatType> sampleTensor = Tensor.build("input0", STARDIST_2D_AXES, sampleImg);

        // Dummy ground truth with the same batch size, height and width.
        ArrayImgFactory<FloatType> labelFactory = new ArrayImgFactory<>(new FloatType());
        Img<FloatType> labelImg = labelFactory.create(new long[]{2L, 64L, 64L});
        Tensor<FloatType> labelTensor = Tensor.build("gt", GROUNDTRUTH_AXES, labelImg);

        String modelsDir = "C:\\Users\\angel\\OneDrive\\Documentos\\pasteur\\git\\model-runner-java\\models";
        String modelDir = "C:\\Users\\angel\\OneDrive\\Documentos\\pasteur\\git\\model-runner-java\\models\\finetuned_StarDist H&E Nuclei Segmentation_04102023_123644";

        StardistFineTuneJdllOp fineTuneOp = StardistFineTuneJdllOp.finetuneInPlace(modelDir);
        fineTuneOp.installOp();
        fineTuneOp.setBatchSize(2);
        fineTuneOp.setEpochs(1);
        fineTuneOp.setFineTuningData(sampleTensor, labelTensor);
        fineTuneOp.setLearingRate(1.0E-5f);

        RunMode runMode = RunMode.createRunMode(fineTuneOp);
        Map<String, Object> outputs = runMode.runOP();
        System.out.print(false);
    }

    /**
     * Creates an op that fine tunes a StarDist model and stores the resulting model in a
     * new folder inside {@code newModelDir}.
     *
     * @param modelToFineTune
     *  the model to fine tune, see the null-check message for the accepted values
     * @param newModelDir
     *  existing directory where the fine tuned model folder is created
     * @return the configured op
     * @throws IOException if there is any error setting up the model files
     * @throws InterruptedException if the set-up of the model is interrupted
     * @throws IllegalArgumentException if {@code newModelDir} is not a directory, if the model
     *  is not recognised as a StarDist model or if its rdf.yaml cannot be read
     * @throws ModelSpecsException if the specs of the model are not valid
     */
    public static StardistFineTuneJdllOp finetuneInPlace(String modelToFineTune, String newModelDir) throws IOException, InterruptedException, IllegalArgumentException, ModelSpecsException {
        Objects.requireNonNull(modelToFineTune, "modelToFineTune' cannot be null. It should correspond to either a Bioimage.io folder containing a StarDist model, the nickname of a StarDist model in the Bioimage.io (example: chatty-frog) or to one if the StarDist pre-trained available weigths (example: 2D_versatile_fluo)");
        Objects.requireNonNull(newModelDir, "newModelDir' cannot be null. It should be a path to the directory where\tthe we want the fine tuned model to be saved.");
        File targetDir = new File(newModelDir);
        if (!targetDir.isDirectory()) {
            throw new IllegalArgumentException("Argument 'newModelDir' should be an existing directory. In that directory the fine tuned StarDist model is going to be created.");
        }

        StardistFineTuneJdllOp fineTuneOp = new StardistFineTuneJdllOp();
        fineTuneOp.nModelParentPath = newModelDir;
        fineTuneOp.model = modelToFineTune;
        fineTuneOp.setModel();
        try {
            fineTuneOp.findNChannels();
        }
        catch (Exception ex) {
            // At this point 'model' points inside the model folder, so its parent is reported.
            String modelRoot = new File(fineTuneOp.model).getParent();
            throw new IllegalArgumentException("Unable to correctly read the rdf.yaml file of Bioimage.io StarDist model at :" + modelRoot, ex);
        }
        return fineTuneOp;
    }

    /**
     * Creates an op that fine tunes the StarDist model contained in an existing local folder.
     * No parent directory for a new model is set, so the set-up works on the given folder.
     *
     * @param modelToFineTune
     *  path to an existing directory that contains the model to fine tune
     * @return the configured op
     * @throws IOException if there is any error setting up the model files
     * @throws InterruptedException if the set-up of the model is interrupted
     * @throws IllegalArgumentException if {@code modelToFineTune} is not a directory, if it is
     *  not recognised as a StarDist model or if its rdf.yaml cannot be read
     * @throws ModelSpecsException if the specs of the model are not valid
     */
    public static StardistFineTuneJdllOp finetuneInPlace(String modelToFineTune) throws IOException, InterruptedException, IllegalArgumentException, ModelSpecsException {
        Objects.requireNonNull(modelToFineTune, "modelToFineTune' cannot be null. It should correspond to either a Bioimage.io folder containing a StarDist model, the nickname of a StarDist model in the Bioimage.io (example: chatty-frog) or to one if the StarDist pre-trained available weigths (example: 2D_versatile_fluo)");
        File modelDir = new File(modelToFineTune);
        if (!modelDir.isDirectory()) {
            throw new IllegalArgumentException("Argument 'modelToFineTune' should be an existing directory. That directory should contain the model that wants to be fine-tuned and overwritten.");
        }

        StardistFineTuneJdllOp fineTuneOp = new StardistFineTuneJdllOp();
        fineTuneOp.model = modelToFineTune;
        fineTuneOp.setModel();
        try {
            fineTuneOp.findNChannels();
        }
        catch (Exception ex) {
            // At this point 'model' points inside the model folder, so its parent is reported.
            String modelRoot = new File(fineTuneOp.model).getParent();
            throw new IllegalArgumentException("Unable to correctly read the rdf.yaml file of Bioimage.io StarDist model at :" + modelRoot, ex);
        }
        return fineTuneOp;
    }

    /**
     * Sets the weights file that should be fine tuned. The value is only validated later,
     * when the op inputs are built in {@link #getOpInputs()}.
     *
     * @param weigthsToFineTune
     *  name or path of the weights file; it has to end with ".h5" to pass the later validation
     */
    public void setWeightsToFineTune(String weigthsToFineTune) {
        this.weightsToFineTune = weigthsToFineTune;
    }

    /**
     * Overload that receives the fine tuning data as lists of tensors.
     * <p>
     * NOTE(review): the body is empty, so both arguments are ignored and no training data is
     * stored. Callers must use the overload that takes single tensors; otherwise
     * {@link #getOpInputs()} fails its null checks. Presumably this is an unfinished stub -
     * confirm whether it should be implemented or removed.
     *
     * @param <T>
     *  data type of the tensors
     * @param trainingSamples
     *  currently ignored
     * @param groundTruth
     *  currently ignored
     */
    public <T extends RealType<T> & NativeType<T>> void setFineTuningData(List<Tensor<T>> trainingSamples, List<Tensor<T>> groundTruth) {
    }

    /**
     * Validates and stores the data used for fine tuning.
     * <p>
     * The dimension checks cannot hand a re-ordered tensor back to this method (Java passes
     * references by value), so the tensors are explicitly permuted here to the axes order
     * expected for StarDist ("byxc"/"bzyxc" for the samples and "byx" for the ground truth)
     * before being stored. Tensors already in that order are stored untouched.
     *
     * @param <T>
     *  data type of the tensors
     * @param trainingSamples
     *  training images, with 4 (byxc) or 5 (bzyxc) dimensions in any order
     * @param groundTruth
     *  ground truth of the training images, with 3 dimensions (byx) in any order
     * @throws IllegalArgumentException if the tensors do not have the required axes, if their
     *  batch size, height and width do not match, or if the number of channels of the training
     *  samples is not the one of the model
     */
    public <T extends RealType<T> & NativeType<T>> void setFineTuningData(Tensor<T> trainingSamples, Tensor<T> groundTruth) {
        this.checkTrainAndGroundTruthDimensions(trainingSamples, groundTruth);
        boolean is3d = trainingSamples.getAxesOrderString().length() == STARDIST_3D_AXES.length();
        String sampleAxes = is3d ? STARDIST_3D_AXES : STARDIST_2D_AXES;
        this.setTrainingSamples(StardistFineTuneJdllOp.permuteToAxesOrder(trainingSamples, sampleAxes));
        this.setGroundTruth(StardistFineTuneJdllOp.permuteToAxesOrder(groundTruth, GROUNDTRUTH_AXES));
    }

    /**
     * Returns a tensor whose axes follow {@code targetAxes}, swapping one pair of axes at a
     * time on a view of the data. The input tensor is returned as is when already ordered.
     */
    private static <T extends RealType<T> & NativeType<T>> Tensor<T> permuteToAxesOrder(Tensor<T> tensor, String targetAxes) {
        Tensor<T> ordered = tensor;
        for (int pos = 0; pos < targetAxes.length(); pos++) {
            String axes = ordered.getAxesOrderString();
            int found = axes.indexOf(targetAxes.charAt(pos));
            if (found == -1) {
                throw new IllegalArgumentException("The tensor '" + ordered.getName() + "' should have dimension '"
                        + targetAxes.charAt(pos) + "' in the axes order, but it does not (" + axes + ").");
            }
            if (found == pos) {
                continue;
            }
            // Swap the axis found at 'found' with the one currently sitting at 'pos'.
            StringBuilder swapped = new StringBuilder(axes);
            swapped.setCharAt(pos, axes.charAt(found));
            swapped.setCharAt(found, axes.charAt(pos));
            ordered = Tensor.build(ordered.getName(), swapped.toString(), Views.permute(ordered.getData(), found, pos));
        }
        return ordered;
    }

    /**
     * Sets the batch size used for the fine tuning. The value is written to the StarDist
     * config when the op inputs are built.
     *
     * @param batchSize
     *  batch size to use; no validation is applied here
     */
    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    /**
     * Sets the learning rate used for the fine tuning. The value is written to the StarDist
     * config when the op inputs are built.
     *
     * @param learningRate
     *  learning rate to use; no validation is applied here
     */
    public void setLearingRate(float learningRate) {
        this.lr = learningRate;
    }

    /**
     * Sets the number of epochs used for the fine tuning. The value is written to the
     * StarDist config when the op inputs are built.
     *
     * @param epochs
     *  number of epochs; no validation is applied here
     */
    public void setEpochs(int epochs) {
        this.epochs = epochs;
    }

    /**
     * @return the name of the Python file of this op ("stardist_fine_tune.py")
     */
    @Override
    public String getOpPythonFilename() {
        return STARDIST_OP_FNAME;
    }

    /**
     * @return the number of outputs of this op, which is always one
     */
    @Override
    public int getNumberOfOutputs() {
        return N_STARDIST_OUTPUTS;
    }

    /**
     * Reports whether the op is installed.
     * <p>
     * NOTE(review): this always returns true and does not look at the op directory or the
     * environment path - confirm whether a real check is intended.
     *
     * @return always true
     */
    @Override
    public boolean isOpInstalled() {
        return true;
    }

    /**
     * Sets the directory of the op and the path of the environment used to run it.
     * <p>
     * NOTE(review): both values are hard-coded absolute paths of a single developer machine
     * and nothing is actually installed, so this only works on that machine - presumably a
     * placeholder to be replaced.
     */
    @Override
    public void installOp() {
        this.opFilePath = "C:\\Users\\angel\\OneDrive\\Documentos\\pasteur\\git\\model-runner-java\\python\\ops\\\\stardist_fine_tune";
        this.envPath = "C:\\Users\\angel\\git\\jep\\miniconda\\envs\\stardist";
    }

    /**
     * Builds the inputs of the Python op: the model, the training samples, the ground truth
     * and, if it was set, the weights file. It also writes the StarDist configuration files
     * of the model with the current training parameters.
     *
     * @return the inputs of the op, in insertion order
     * @throws NullPointerException if the training samples or the ground truth were not set
     * @throws RuntimeException if the configuration files cannot be set up
     * @throws IllegalArgumentException if the weights file set is not valid
     */
    @Override
    public LinkedHashMap<String, Object> getOpInputs() {
        LinkedHashMap<String, Object> inputs = new LinkedHashMap<>();
        // The field and the local refer to the same map; setWeigthsFile() writes through the field.
        this.inputsMap = inputs;
        Objects.requireNonNull(this.trainingSamples, "Please make sure that the training samples have been provided and that they are not null.Use the method: setFineTuningData(Tensor<T> trainingSamples, Tensor<T> groundTruth)");
        inputs.put(MODEL_KEY, this.model);
        Objects.requireNonNull(this.groundTruth, "Please make sure that the ground truth has been provided and that it is not null.Use the method: setFineTuningData(Tensor<T> trainingSamples, Tensor<T> groundTruth)");
        inputs.put(TRAIN_SAMPLES_KEY, this.trainingSamples);
        inputs.put(GROUND_TRUTH_KEY, this.groundTruth);
        try {
            this.setUpConfigs();
        }
        catch (ModelSpecsException | IOException ex) {
            throw new RuntimeException(ex);
        }
        boolean hasCustomWeights = this.weightsToFineTune != null;
        if (hasCustomWeights) {
            this.setWeigthsFile();
        }
        return inputs;
    }

    /**
     * Validates the weights file selected with {@link #setWeightsToFineTune(String)} and adds
     * it to the op inputs.
     * <p>
     * The file has to end with ".h5". If it points to an existing file, that file has to be
     * located directly in the model folder; otherwise it is interpreted as a file name
     * relative to the model folder and has to exist there.
     *
     * @throws IllegalArgumentException if the weights file does not fulfil the conditions above
     */
    private void setWeigthsFile() {
        if (!this.weightsToFineTune.endsWith(KERAS_SUFFIX_FILE)) {
            throw new IllegalArgumentException("StarDist weigths files must always end with '.h5' and the provided file does not: " + this.weightsToFineTune);
        }
        File weightsFile = new File(this.weightsToFineTune);
        if (weightsFile.isFile()) {
            // Compare absolute locations so that a file name without parent does not produce a
            // NullPointerException and equivalent spellings of the same folder are accepted.
            File weightsDir = weightsFile.getAbsoluteFile().getParentFile();
            File modelDir = new File(this.model).getAbsoluteFile();
            if (!modelDir.equals(weightsDir)) {
                throw new IllegalArgumentException("StarDist weigths files that can be fine tuned with this modelshould be in the folder: " + this.model);
            }
        } else if (!new File(this.model, this.weightsToFineTune).isFile()) {
            // The original condition was inverted and rejected exactly the files that did exist.
            throw new IllegalArgumentException("The StarDist weigths file provided (" + this.weightsToFineTune + ") cannot be found in the StarDist model folder : " + this.model);
        }
        this.inputsMap.put(WEIGHTS_TO_FINE_TUNE_KEY, this.weightsToFineTune);
    }

    /**
     * @return the path of the environment used to run the op, null until {@link #installOp()} is called
     */
    @Override
    public String getCondaEnv() {
        return this.envPath;
    }

    /**
     * @return the name of the method of the op ("finetune_stardist")
     */
    @Override
    public String getMethodName() {
        return OP_METHOD_NAME;
    }

    /**
     * @return the directory of the op, null until {@link #installOp()} is called
     */
    @Override
    public String getOpDir() {
        return this.opFilePath;
    }

    /**
     * Sets up the model to fine tune depending on what {@code model} refers to:
     * <ul>
     * <li>a key of the pre-trained StarDist maps: the corresponding model is set up</li>
     * <li>an existing directory: it has to contain an rdf.yaml describing a StarDist model</li>
     * <li>anything else: it has to be the name of a StarDist model of the Bioimage.io</li>
     * </ul>
     * The directory test is evaluated once and the cases are handled with guard clauses; the
     * former trailing "cannot recognise" branch could never be reached and was removed.
     *
     * @throws IOException if there is any error setting up the model files
     * @throws InterruptedException if the download of the model is interrupted
     * @throws IllegalArgumentException if the model is not a valid StarDist model
     * @throws ModelSpecsException if the specs of the model are not valid
     */
    private void setModel() throws IOException, InterruptedException, IllegalArgumentException, ModelSpecsException {
        Objects.requireNonNull(this.model, "The modelName input argument cannot be null.");
        if (PRETRAINED_1C_STARDIST_MODELS.containsKey(this.model) || PRETRAINED_3C_STARDIST_MODELS.containsKey(this.model)) {
            this.downloadStardistPretrained = true;
            this.setUpStardistModelFromStardistRepo();
            return;
        }
        if (new File(this.model).isDirectory()) {
            if (!new File(this.model, "rdf.yaml").isFile()) {
                throw new IllegalArgumentException("The directory selected does not correspond to a valid Bioimage.io model, it does not contain the required specs file: rdf.yaml");
            }
            if (!StardistInferJdllOp.isModelFileStardist(this.model + File.separator + "rdf.yaml")) {
                throw new IllegalArgumentException("The directory selected does not correspond to a Bioimage.io StarDist model, as per its specs file: rdf.yaml");
            }
            this.setUpStardistModelFromLocal();
            return;
        }
        if (!StardistInferJdllOp.isModelNameStardist(this.model)) {
            throw new IllegalArgumentException("The model name provided does not correspond to a valid Stardist model present in the Bioimage.io online reposritory.");
        }
        this.setUpStardistModelFromBioimageio();
    }

    /**
     * Translates the name of a pre-trained StarDist model into the nickname of the
     * corresponding Bioimage.io model and sets that model up.
     * <p>
     * Names that are registered without a Bioimage.io model (mapped to null) used to be
     * skipped silently, which left the op without a model and failed later with an unrelated
     * message. They are now rejected here.
     *
     * @throws IOException if there is any error setting up the model files
     * @throws InterruptedException if the download of the model is interrupted
     * @throws IllegalArgumentException if no Bioimage.io model is registered for the name
     * @throws ModelSpecsException if the specs of the model are not valid
     */
    private void setUpStardistModelFromStardistRepo() throws IOException, InterruptedException, IllegalArgumentException, ModelSpecsException {
        String nickname = PRETRAINED_1C_STARDIST_MODELS.get(this.model);
        if (nickname == null) {
            nickname = PRETRAINED_3C_STARDIST_MODELS.get(this.model);
        }
        if (nickname == null) {
            throw new IllegalArgumentException("The pre-trained StarDist model '" + this.model
                    + "' does not have a corresponding model in the Bioimage.io and cannot be fine tuned.");
        }
        this.model = nickname;
        this.setUpStardistModelFromBioimageio();
    }

    /**
     * Downloads the model from the Bioimage.io, looking it up first by name and then by ID,
     * tries to rename the downloaded folder adding the prefix "finetuned_" and sets up the
     * StarDist weights.
     * <p>
     * When the model is found neither by name nor by ID nothing is downloaded. This case is
     * now rejected instead of continuing with a path that does not exist.
     *
     * @throws IOException if there is any error downloading or setting up the model files
     * @throws InterruptedException if the download of the model is interrupted
     * @throws IllegalArgumentException if the model cannot be found in the Bioimage.io
     * @throws ModelSpecsException if the specs of the model are not valid
     */
    private void setUpStardistModelFromBioimageio() throws IOException, InterruptedException, IllegalArgumentException, ModelSpecsException {
        BioimageioRepo br = BioimageioRepo.connect();
        if (br.selectByName(this.model) != null) {
            this.model = br.downloadByName(this.model, this.nModelParentPath);
        } else if (br.selectByID(this.model) != null) {
            this.model = br.downloadModelByID(this.model, this.nModelParentPath);
        } else {
            throw new IllegalArgumentException("The model '" + this.model
                    + "' cannot be found in the Bioimage.io repository, neither by name nor by ID.");
        }
        File folder = new File(this.model);
        String fineTuned = folder.getParent() + File.separator + "finetuned_" + folder.getName();
        // If the rename fails the downloaded folder keeps its name and is used as it is.
        if (folder.renameTo(new File(fineTuned))) {
            this.model = fineTuned;
        }
        this.downloadBioimageioStardistWeights();
    }

    /**
     * Makes sure that the model folder contains a subfolder named "stardist" and then sets
     * up the Keras weights in it.
     *
     * @throws IllegalArgumentException declared by the method; not thrown directly here
     * @throws IOException if the subfolder cannot be created or the weights cannot be set up
     * @throws ModelSpecsException if the specs of the model are not valid
     */
    private void downloadBioimageioStardistWeights() throws IllegalArgumentException, IOException, ModelSpecsException {
        File weightsDir = new File(this.model, "stardist");
        // mkdirs() is only attempted when the folder is missing.
        boolean dirAvailable = weightsDir.exists() || weightsDir.mkdirs();
        if (!dirAvailable) {
            throw new IOException("Unable to create folder named 'stardist' at: " + this.model);
        }
        this.setUpKerasWeights();
    }

    /**
     * Writes the training parameters (patch size, batch size, learning rate and epochs) into
     * the StarDist configuration of the model.
     * <p>
     * If the parent folder of the model contains an rdf.yaml, the configuration is created
     * from it. Otherwise the model folder has to contain both config.json and
     * thresholds.json, and config.json is updated in place.
     *
     * @throws IOException if a required file is missing or cannot be read or written
     * @throws ModelSpecsException if the specs of the model are not valid
     */
    private void setUpConfigs() throws IOException, ModelSpecsException {
        File rdfFile = new File(new File(this.model).getParent() + File.separator + "rdf.yaml");
        if (rdfFile.exists()) {
            this.setUpConfigsBioimageio();
            return;
        }

        String configPath = this.model + File.separator + CONFIG_JSON;
        String thresPath = this.model + File.separator + THRES_JSON;
        if (!new File(configPath).exists()) {
            throw new IOException("Missing necessary file for StarDist: config.json");
        }
        if (!new File(thresPath).exists()) {
            throw new IOException("Missing necessary file for StarDist: thresholds.json");
        }

        Map<String, Object> stardistConfig = JSONUtils.load(configPath);
        int[] shape = this.trainingSamples.getShape();
        String axes = this.trainingSamples.getAxesOrderString();
        int width = shape[axes.indexOf("x")];
        int height = shape[axes.indexOf("y")];
        stardistConfig.put(PATCH_SIZE_KEY, new int[]{width, height});
        stardistConfig.put(BATCH_SIZE_KEY, this.batchSize);
        stardistConfig.put(LR_KEY, Float.valueOf(this.lr));
        stardistConfig.put(EPOCHS_KEY, this.epochs);
        JSONUtils.writeJSONFile(configPath, stardistConfig);
    }

    /**
     * Creates config.json and thresholds.json in the model folder from the information
     * stored under config>stardist in the rdf.yaml of the parent folder, after adding the
     * training parameters (patch size, batch size, learning rate and epochs) to the config.
     *
     * @throws IOException if the files cannot be read or written
     * @throws ModelSpecsException if the specs of the model are not valid
     * @throws IllegalArgumentException if the rdf.yaml lacks any of the required fields
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private void setUpConfigsBioimageio() throws IOException, ModelSpecsException {
        String rdfPath = new File(this.model).getParent() + File.separator + "rdf.yaml";
        ModelDescriptor descriptor = ModelDescriptorFactory.readFromLocalFile(rdfPath);

        // 'instanceof' is false for null, so a single test covers the missing field too.
        Object stardistSection = descriptor.getConfig().getSpecMap().get("stardist");
        if (!(stardistSection instanceof Map)) {
            throw new IllegalArgumentException("The rdf.yaml file of the Bioimage.io StarDist model at: " + this.model + " is invalid. The field config>stardist is missing. Look for StarDist models in the Bioimage.io repo to see how the rdf.yaml should look like.");
        }
        Map stardistMap = (Map) stardistSection;

        Object configSection = stardistMap.get(STARDIST_CONFIG_KEY);
        if (!(configSection instanceof Map)) {
            throw new IllegalArgumentException("The rdf.yaml file of the Bioimage.io StarDist model at: " + this.model + " is invalid. The field config>stardist>" + STARDIST_CONFIG_KEY + " is missing. Look for StarDist models in the Bioimage.io repo to see how the rdf.yaml should look like.");
        }
        Object thresSection = stardistMap.get(STARDIST_THRES_KEY);
        if (!(thresSection instanceof Map)) {
            throw new IllegalArgumentException("The rdf.yaml file of the Bioimage.io StarDist model at: " + this.model + " is invalid. The field config>stardist>" + STARDIST_THRES_KEY + " is missing. Look for StarDist models in the Bioimage.io repo to see how the rdf.yaml should look like.");
        }
        Map configMap = (Map) configSection;
        Map thresMap = (Map) thresSection;

        int[] shape = this.trainingSamples.getShape();
        String axes = this.trainingSamples.getAxesOrderString();
        int width = shape[axes.indexOf("x")];
        int height = shape[axes.indexOf("y")];
        configMap.put(PATCH_SIZE_KEY, new int[]{width, height});
        configMap.put(BATCH_SIZE_KEY, this.batchSize);
        configMap.put(LR_KEY, Float.valueOf(this.lr));
        configMap.put(EPOCHS_KEY, this.epochs);

        JSONUtils.writeJSONFile(this.model + File.separator + CONFIG_JSON, configMap);
        JSONUtils.writeJSONFile(this.model + File.separator + THRES_JSON, thresMap);
    }

    /**
     * Makes the Keras weights available inside the "stardist" subfolder of the model and
     * points {@code model} to that subfolder.
     * <p>
     * The weights are looked for in this order:
     * <ol>
     * <li>already present in the "stardist" subfolder: nothing else is done</li>
     * <li>present in the model folder: they are copied into the subfolder</li>
     * <li>otherwise: they are downloaded from the attachments listed in the rdf.yaml</li>
     * </ol>
     * The second step used to test the file of the first step again, so the copy could never
     * happen and the weights were always downloaded.
     *
     * @throws IOException if the rdf.yaml does not list the weights among its attachments
     * @throws ModelSpecsException if the specs of the model are not valid
     */
    private void setUpKerasWeights() throws IOException, ModelSpecsException {
        String modelRoot = this.model;
        ModelDescriptor descriptor = ModelDescriptorFactory.readFromLocalFile(modelRoot + File.separator + "rdf.yaml");
        String stardistDir = modelRoot + File.separator + "stardist";
        File targetWeights = new File(stardistDir, STARDIST_WEIGHTS_FILE);
        File rootWeights = new File(modelRoot, STARDIST_WEIGHTS_FILE);
        // From here on the model refers to the 'stardist' subfolder, whatever path is taken below.
        this.model = stardistDir;
        if (targetWeights.exists()) {
            return;
        }
        if (rootWeights.exists()) {
            try {
                Files.copy(rootWeights.toPath(), targetWeights.toPath(), StandardCopyOption.REPLACE_EXISTING);
                return;
            }
            catch (IOException ignored) {
                // Best effort: if the local copy fails, fall back to downloading the weights.
            }
        }
        StardistFineTuneJdllOp.downloadFileFromInternet(StardistFineTuneJdllOp.getKerasWeigthsLink(descriptor), targetWeights);
    }

    /**
     * Finds among the attachments of the model the URL whose file name is
     * "stardist_weights.h5". Attachments that are not valid URLs are skipped.
     *
     * @param descriptor
     *  specs of the model
     * @return the URL of the Keras weights
     * @throws IOException if none of the attachments corresponds to the Keras weights
     */
    private static String getKerasWeigthsLink(ModelDescriptor descriptor) throws IOException {
        for (String attachment : descriptor.getAttachments()) {
            String fileName;
            try {
                fileName = FileDownloader.getFileNameFromURLString(attachment);
            }
            catch (MalformedURLException ignored) {
                // Not a valid URL, so it cannot be the link to the weights.
                continue;
            }
            if (fileName.equals(STARDIST_WEIGHTS_FILE)) {
                return attachment;
            }
        }
        throw new IOException("Stardist rdf.yaml file at : " + descriptor.getModelPath() + " is invalid, as it does not contain the URL to StarDist Keras weights in the attachements field. Look for a StarDist model on the Bioimage.io repository for an example of a correct version.");
    }

    /**
     * Prepares the folder of the fine tuned model from a local Bioimage.io model folder.
     * <p>
     * Without a parent directory for the new model, the model folder itself is renamed with
     * the prefix "finetuned_". Otherwise a new "finetuned_" folder is created in the parent
     * directory and the rdf.yaml and, if present, the "stardist" subfolder are copied to it.
     * In both cases a numeric suffix is appended until the folder name is not in use.
     *
     * @throws IllegalArgumentException declared by the method; not thrown directly here
     * @throws IOException if the new folder cannot be created or the files cannot be set up
     * @throws ModelSpecsException if the specs of the model are not valid
     */
    private void setUpStardistModelFromLocal() throws IllegalArgumentException, IOException, ModelSpecsException {
        File sourceDir = new File(this.model);
        String folderName = "finetuned_" + sourceDir.getName();
        if (this.nModelParentPath == null) {
            String target = StardistFineTuneJdllOp.firstUnusedPath(sourceDir.getParent() + File.separator + folderName);
            // If the rename fails, the original folder keeps being used.
            if (sourceDir.renameTo(new File(target))) {
                this.model = target;
            }
        } else {
            String target = StardistFineTuneJdllOp.firstUnusedPath(this.nModelParentPath + File.separator + folderName);
            if (!new File(target).mkdirs()) {
                throw new IOException("Unable to create directory for fine tuned model at: " + target);
            }
            Files.copy(Paths.get(this.model, "rdf.yaml"), Paths.get(target, "rdf.yaml"), StandardCopyOption.REPLACE_EXISTING);
            boolean hasStardistDir = new File(this.model + File.separator + "stardist").isDirectory();
            if (hasStardistDir) {
                try {
                    FileUtils.copyFolder(Paths.get(this.model, "stardist"), Paths.get(target, "stardist"));
                }
                catch (IOException ignored) {
                    // A failed copy is ignored; the weights set-up below runs in any case.
                }
            }
            this.model = target;
        }
        this.downloadBioimageioStardistWeights();
    }

    /**
     * Returns {@code basePath} if nothing exists at that path, or the first of
     * {@code basePath-1}, {@code basePath-2}, ... that does not exist.
     */
    private static String firstUnusedPath(String basePath) {
        String candidate = basePath;
        for (int suffix = 1; new File(candidate).exists(); suffix++) {
            candidate = basePath + "-" + suffix;
        }
        return candidate;
    }

    /**
     * Checks that the training samples and the ground truth have the axes required for
     * StarDist and that their width, height and batch size coincide.
     *
     * @param <T>
     *  data type of the tensors
     * @param trainingSamples
     *  training images
     * @param groundTruth
     *  ground truth of the training images
     * @throws IllegalArgumentException if any of the checks fails
     */
    private <T extends RealType<T> & NativeType<T>> void checkTrainAndGroundTruthDimensions(Tensor<T> trainingSamples, Tensor<T> groundTruth) {
        StardistFineTuneJdllOp.checkTrainingSamplesTensorDimsForStardist(trainingSamples);
        StardistFineTuneJdllOp.checkGroundTruthTensorDimsForStardist(groundTruth);

        int[] trainShape = trainingSamples.getShape();
        String trainAxes = trainingSamples.getAxesOrderString();
        int[] gtShape = groundTruth.getShape();
        String gtAxes = groundTruth.getAxesOrderString();

        int trainWidth = trainShape[trainAxes.indexOf("x")];
        int gtWidth = gtShape[gtAxes.indexOf("x")];
        if (trainWidth != gtWidth) {
            throw new IllegalArgumentException("Training samples (" + trainWidth + ") and ground truth (" + gtWidth + ") width (x-axis) must be the same.");
        }
        int trainHeight = trainShape[trainAxes.indexOf("y")];
        int gtHeight = gtShape[gtAxes.indexOf("y")];
        if (trainHeight != gtHeight) {
            throw new IllegalArgumentException("Training samples (" + trainHeight + ") and ground truth (" + gtHeight + ") height (y-axis) must be the same.");
        }
        int trainBatch = trainShape[trainAxes.indexOf("b")];
        int gtBatch = gtShape[gtAxes.indexOf("b")];
        if (trainBatch != gtBatch) {
            throw new IllegalArgumentException("Training samples (" + trainBatch + ") and ground truth (" + gtBatch + ") batch size (b-axis) must be the same.");
        }
    }

    /**
     * Checks that the training samples have 4 (byxc) or 5 (bzyxc) dimensions and that all
     * the axes required for that number of dimensions are present.
     *
     * @param <T>
     *  data type of the tensor
     * @param trainingSamples
     *  training images
     * @throws IllegalArgumentException if the number of dimensions or the axes are not valid
     */
    private static <T extends RealType<T> & NativeType<T>> void checkTrainingSamplesTensorDimsForStardist(Tensor<T> trainingSamples) {
        String axes = trainingSamples.getAxesOrderString();
        int nDims = axes.length();
        if (nDims != 4 && nDims != 5) {
            throw new IllegalArgumentException("Training input tensors should have 4 dimensions (byxc) or 5 (bzyxc), but it has " + axes.length() + " (" + axes + ").");
        }
        String expectedAxes = nDims == 5 ? STARDIST_3D_AXES : STARDIST_2D_AXES;
        StardistFineTuneJdllOp.checkDimOrderAndTranspose(trainingSamples, expectedAxes, "training input");
    }

    /**
     * Checks that the ground truth has 3 dimensions (byx) and that all of those axes are
     * present.
     *
     * @param <T>
     *  data type of the tensor
     * @param gt
     *  ground truth tensor
     * @throws IllegalArgumentException if the number of dimensions or the axes are not valid
     */
    private static <T extends RealType<T> & NativeType<T>> void checkGroundTruthTensorDimsForStardist(Tensor<T> gt) {
        String gtAxes = gt.getAxesOrderString();
        boolean hasExpectedRank = gtAxes.length() == GROUNDTRUTH_AXES.length();
        if (!hasExpectedRank) {
            throw new IllegalArgumentException("Ground truth tensors should have 3 dimensions (byx), but it has " + gtAxes.length() + " (" + gtAxes + ").");
        }
        StardistFineTuneJdllOp.checkDimOrderAndTranspose(gt, GROUNDTRUTH_AXES, "ground truth");
    }

    /**
     * Checks that the tensor contains every axis of {@code stardistAxes} and permutes a view
     * of it, one pair of axes at a time, until it follows that order.
     * <p>
     * NOTE(review): the permuted tensor is only assigned to the local parameter and the
     * method returns nothing, so the caller keeps its original tensor. The visible effect of
     * this method is the validation of the axes.
     *
     * @param <T>
     *  data type of the tensor
     * @param tensor
     *  tensor to check
     * @param stardistAxes
     *  axes order wanted
     * @param errMsgObject
     *  description of the tensor used in the error message
     * @throws IllegalArgumentException if the tensor lacks any of the wanted axes
     */
    private static <T extends RealType<T> & NativeType<T>> void checkDimOrderAndTranspose(Tensor<T> tensor, String stardistAxes, String errMsgObject) {
        String[] wantedAxes = stardistAxes.split("");
        int pos = 0;
        while (pos < stardistAxes.length()) {
            String currentAxes = tensor.getAxesOrderString();
            int found = currentAxes.indexOf(wantedAxes[pos]);
            if (found == -1) {
                throw new IllegalArgumentException("The " + errMsgObject + " tensors provided should have dimension '" + wantedAxes[pos] + "' in the axes order, but it does not (" + currentAxes + ").");
            }
            if (found == pos) {
                pos++;
                continue;
            }
            IntervalView<T> permuted = Views.permute(tensor.getData(), found, pos);
            StringBuilder swappedAxes = new StringBuilder(currentAxes);
            swappedAxes.setCharAt(pos, stardistAxes.charAt(pos));
            swappedAxes.setCharAt(found, currentAxes.charAt(pos));
            tensor = Tensor.build(tensor.getName(), swappedAxes.toString(), permuted);
            // After a swap the scan restarts at the second axis, as the counter was reset to 0 and then incremented.
            pos = 1;
        }
    }

    /**
     * Stores the training samples as a float tensor, after checking that their number of
     * channels is the one of the model. Tensors of any other data type are copied into a new
     * float tensor, float tensors are stored directly.
     *
     * @param <T>
     *  data type of the tensor
     * @param trainingSamples
     *  training images
     * @throws IllegalArgumentException if the number of channels differs from the one of the model
     */
    private <T extends RealType<T> & NativeType<T>> void setTrainingSamples(Tensor<T> trainingSamples) {
        int channelAxis = trainingSamples.getAxesOrderString().indexOf("c");
        int tensorChannels = trainingSamples.getShape()[channelAxis];
        if (this.nChannelsModel != tensorChannels) {
            throw new IllegalArgumentException("The pre-trained selected model only supports " + this.nChannelsModel + "-channel inputs whereas the provided training input tensor has " + tensorChannels + " channels.");
        }
        boolean alreadyFloat = Util.getTypeFromInterval(trainingSamples.getData()) instanceof FloatType;
        if (alreadyFloat) {
            this.trainingSamples = Cast.unchecked(trainingSamples);
        } else {
            this.trainingSamples = Tensor.createCopyOfTensorInWantedDataType(trainingSamples, new FloatType());
        }
    }

    /**
     * Stores the ground truth as an unsigned short tensor. Tensors of any other data type are
     * copied into a new unsigned short tensor, unsigned short tensors are stored directly.
     *
     * @param <T>
     *  data type of the tensor
     * @param groundTruth
     *  ground truth of the training images
     */
    private <T extends RealType<T> & NativeType<T>> void setGroundTruth(Tensor<T> groundTruth) {
        boolean alreadyUnsignedShort = Util.getTypeFromInterval(groundTruth.getData()) instanceof UnsignedShortType;
        if (alreadyUnsignedShort) {
            this.groundTruth = Cast.unchecked(groundTruth);
        } else {
            this.groundTruth = Tensor.createCopyOfTensorInWantedDataType(groundTruth, new UnsignedShortType());
        }
    }

    /**
     * Finds the number of channels of the model input.
     * <p>
     * For pre-trained StarDist models a preliminary value is assigned from the map that
     * contains the model, but in every case the value finally stored is the one read from
     * the rdf.yaml located in the parent folder of {@code model}: the entry of the channel
     * axis in the minimum tile size of the first input tensor.
     *
     * @throws Exception if the rdf.yaml cannot be read or does not contain the information
     */
    private void findNChannels() throws Exception {
        if (this.downloadStardistPretrained) {
            if (PRETRAINED_1C_STARDIST_MODELS.containsKey(this.model)) {
                this.nChannelsModel = 1;
            } else if (PRETRAINED_3C_STARDIST_MODELS.containsKey(this.model)) {
                this.nChannelsModel = 3;
            }
        }
        String rdfPath = new File(this.model).getParent() + File.separator + "rdf.yaml";
        ModelDescriptor descriptor = ModelDescriptorFactory.readFromLocalFile(rdfPath);
        int channelAxis = descriptor.getInputTensors().get(0).getAxesOrder().indexOf("c");
        // This assignment overrides whatever was set above.
        this.nChannelsModel = descriptor.getInputTensors().get(0).getMinTileSizeArr()[channelAxis];
    }

    /**
     * Downloads the file at the given URL into the target file.
     * <p>
     * A failed download does not throw: the error is wrapped in an {@link IOException} whose
     * stack trace is printed, and the method returns normally.
     * <p>
     * The former stream and channel locals were never assigned, so the clean-up code that
     * closed them could not do anything and has been removed.
     *
     * @param downloadURL
     *  URL of the file to download
     * @param targetFile
     *  file where the download is saved
     */
    public static void downloadFileFromInternet(String downloadURL, File targetFile) {
        try {
            FileDownloader downloader = new FileDownloader(downloadURL, targetFile);
            downloader.download();
        }
        catch (IOException | ExecutionException e) {
            String msg = "The link for the file: " + targetFile.getName() + " is broken.";
            new IOException(msg, e).printStackTrace();
        }
    }

    // Registers the pre-trained StarDist model names and the nickname of the Bioimage.io
    // model used for each of them. A null value means that no Bioimage.io model is
    // registered for that name.
    // NOTE(review): the fluorescence model sits in the "3C" map and the H&E model in the
    // "1C" map, which looks swapped with respect to the map names - confirm.
    static {
        PRETRAINED_3C_STARDIST_MODELS.put("2D_versatile_fluo", "fearless-crab");
        PRETRAINED_3C_STARDIST_MODELS.put("2D_paper_dsb2018", null);
        PRETRAINED_1C_STARDIST_MODELS = new HashMap<String, String>();
        PRETRAINED_1C_STARDIST_MODELS.put("2D_versatile_he", "chatty-frog");
    }
}

