/*
 * Decompiled with CFR 0.152.
 */
package io.bioimage.modelrunner.bioimageio.description.stardist;

import io.bioimage.modelrunner.bioimageio.description.Axes;
import io.bioimage.modelrunner.bioimageio.description.Axis;
import io.bioimage.modelrunner.bioimageio.description.AxisV05;
import io.bioimage.modelrunner.bioimageio.description.ModelDescriptorV05;
import io.bioimage.modelrunner.bioimageio.description.TensorSpec;
import io.bioimage.modelrunner.bioimageio.description.TransformSpec;
import io.bioimage.modelrunner.model.special.stardist.StardistAbstract;
import io.bioimage.modelrunner.numpy.DecodeNumpy;
import io.bioimage.modelrunner.tensor.Utils;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.type.NativeType;
import net.imglib2.type.numeric.RealType;
import net.imglib2.view.Views;

/**
 * Model descriptor for StarDist models described by a bioimage.io v0.5 rdf.yaml.
 *
 * <p>After the parent class has parsed the yaml, this descriptor rewrites every input and output
 * tensor description so that its axes are listed in the order {@code x, y, c} (or
 * {@code x, y, z, c} when the tensor has a {@code z} axis). Any other axis is dropped. The tensor
 * specs are then rebuilt from the rewritten description.
 *
 * <p>The test tensors are converted to the new axes order and saved in the local model folder as
 * {@code stardist_test_input_<i>.npy} and {@code stardist_test_output_<i>.npy}.
 */
public class ModelDescriptorStardistV05 extends ModelDescriptorV05 {
    /** Axes order of each input tensor before the StarDist rewrite, indexed like the inputs. */
    private final List<String> oldOrdersInp = new ArrayList<String>();
    /** Axes order of each output tensor before the StarDist rewrite, indexed like the outputs. */
    private final List<String> oldOrdersOut = new ArrayList<String>();
    /** Prefix of the converted test tensor files written to the model folder. */
    private static final String STARDIST_TEST = "stardist_test";

    /**
     * Parses the yaml through the parent class, then rewrites tensors and test tensors for StarDist.
     *
     * @param yamlElements the parsed rdf.yaml contents
     */
    public ModelDescriptorStardistV05(Map<String, Object> yamlElements) {
        super(yamlElements);
        // NOTE(review): these are overridable methods called from a constructor; a subclass that
        // overrides them runs before its own field initialisers.
        this.input_tensors = this.buildInputTensorsStardist();
        this.output_tensors = this.buildOutputTensorsStardist();
        this.modifyTestInputs();
        this.modifyTestOutputs();
    }

    @Override
    public boolean areRequirementsInstalled() {
        return StardistAbstract.isInstalled();
    }

    @Override
    public String getModelFamily() {
        return "stardist";
    }

    /**
     * Replaces the {@code "inputs"} yaml entry with StarDist-ordered descriptions of the current
     * input tensors, then rebuilds the input specs via {@code super.buildInputTensors()}.
     *
     * <p>Records each input's previous axes order so its test tensor can be converted later.
     *
     * @return the rebuilt input tensor specs
     */
    protected List<TensorSpec> buildInputTensorsStardist() {
        List<Map<String, Object>> tensors = new ArrayList<>();
        for (TensorSpec tt : this.input_tensors) {
            Map<String, Object> map = this.describeTensor(tt, this.oldOrdersInp);
            List<Map<String, Object>> preList = new ArrayList<>();
            for (TransformSpec prep : tt.getPreprocessing()) {
                preList.add(prep.getSpecMap());
            }
            map.put("preprocessing", preList);
            tensors.add(map);
        }
        this.yamlElements.put("inputs", tensors);
        return super.buildInputTensors();
    }

    /**
     * Replaces the {@code "outputs"} yaml entry with StarDist-ordered descriptions of the current
     * output tensors, then rebuilds the output specs via {@code super.buildOutputTensors()}.
     *
     * <p>Records each output's previous axes order so its test tensor can be converted later.
     *
     * @return the rebuilt output tensor specs
     */
    protected List<TensorSpec> buildOutputTensorsStardist() {
        List<Map<String, Object>> tensors = new ArrayList<>();
        for (TensorSpec tt : this.output_tensors) {
            Map<String, Object> map = this.describeTensor(tt, this.oldOrdersOut);
            List<Map<String, Object>> postList = new ArrayList<>();
            for (TransformSpec post : tt.getPostprocessing()) {
                postList.add(post.getSpecMap());
            }
            map.put("postprocessing", postList);
            tensors.add(map);
        }
        this.yamlElements.put("outputs", tensors);
        return super.buildOutputTensors();
    }

    /**
     * Converts every input test tensor to its StarDist axes order and points the input spec at
     * {@code stardist_test_input_<i>.npy}.
     *
     * <p>Does nothing when there is no local model folder. An input whose conversion fails keeps
     * its original test tensor name.
     */
    protected <T extends RealType<T> & NativeType<T>> void modifyTestInputs() {
        if (this.localModelPath == null) {
            return;
        }
        for (int i = 0; i < this.oldOrdersInp.size(); ++i) {
            String newTestName = STARDIST_TEST + "_input_" + i + ".npy";
            if (this.convertTestTensor(this.input_tensors.get(i), this.oldOrdersInp.get(i), newTestName)) {
                this.setInputTestNpyName(i, newTestName);
            }
        }
    }

    /**
     * Converts every output test tensor to its StarDist axes order and points the output spec at
     * {@code stardist_test_output_<i>.npy}.
     *
     * <p>Does nothing when there is no local model folder. An output whose conversion fails keeps
     * its original test tensor name.
     */
    protected <T extends RealType<T> & NativeType<T>> void modifyTestOutputs() {
        if (this.localModelPath == null) {
            return;
        }
        // Iterate the *output* orders and specs; the inputs were handled by modifyTestInputs.
        for (int i = 0; i < this.oldOrdersOut.size(); ++i) {
            String newTestName = STARDIST_TEST + "_output_" + i + ".npy";
            if (this.convertTestTensor(this.output_tensors.get(i), this.oldOrdersOut.get(i), newTestName)) {
                this.setOutputTestNpyName(i, newTestName);
            }
        }
    }

    /**
     * Ensures {@code newTestName} exists in the local model folder. It is a copy of the test tensor
     * of {@code tt}, with axes not present in the spec removed (keeping index 0) and the remaining
     * axes rearranged to the spec's order.
     *
     * @param tt the rebuilt tensor spec
     * @param oldOrder the axes order the spec had before the StarDist rewrite
     * @param newTestName file name of the converted test tensor, relative to the model folder
     * @return {@code true} if the converted file already existed or was written; {@code false} if
     *     the spec has no test tensor or reading/writing failed with an I/O error
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private boolean convertTestTensor(TensorSpec tt, String oldOrder, String newTestName) {
        String newTestPath = this.localModelPath + File.separator + newTestName;
        if (new File(newTestPath).exists()) {
            return true;
        }
        String testName = tt.getTestTensorName();
        if (testName == null) {
            return false;
        }
        String targetOrder = tt.getAxesOrder();
        try {
            RandomAccessibleInterval im = DecodeNumpy.loadNpy(this.localModelPath + File.separator + testName);
            List<Integer> removeDims = removeExtraDims(oldOrder, targetOrder);
            String currentOrder = getNewAxes(oldOrder, removeDims);
            im = getNewRai(im, removeDims);
            im = transposeToAxesOrder(im, targetOrder, currentOrder);
            DecodeNumpy.saveNpy(newTestPath, im);
            return true;
        } catch (IOException e) {
            // Best effort: an unreadable test tensor must not prevent the descriptor from loading;
            // the spec simply keeps its original test tensor name.
            System.err.println("[StarDist] Could not convert test tensor '" + testName
                    + "' into '" + newTestName + "': " + e.getMessage());
            return false;
        }
    }

    /**
     * Builds the shared part of a tensor description: StarDist-ordered axes, id, description, and
     * sample/test tensor sources. Also appends the spec's current axes order to {@code oldOrders}.
     */
    private Map<String, Object> describeTensor(TensorSpec tt, List<String> oldOrders) {
        Map<String, Object> map = this.reverseAxesShape(tt);
        oldOrders.add(tt.getAxesOrder());
        map.put("id", tt.getName());
        map.put("description", tt.getDescription());
        map.put("sample_tensor", sourceEntry(tt.getSampleTensorName()));
        map.put("test_tensor", sourceEntry(tt.getTestTensorName()));
        return map;
    }

    /** Wraps a file name as a {@code {"source": name}} map. */
    private static Map<String, String> sourceEntry(String source) {
        Map<String, String> entry = new HashMap<>();
        entry.put("source", source);
        return entry;
    }

    /**
     * Returns a new map whose {@code "axes"} entry lists the original v0.5 axis descriptions of
     * {@code tt} in the order x, y, (z,) c. The z axis is included only when the axes order
     * contains {@code z}.
     *
     * @throws IllegalArgumentException if one of the required axes is missing
     */
    private Map<String, Object> reverseAxesShape(TensorSpec tt) {
        Axes axes = tt.getAxesInfo();
        String nAxesOrder = axes.getAxesOrder().contains("z") ? "xyzc" : "xyc";
        List<Map<String, Object>> list = new ArrayList<>();
        for (String ax : nAxesOrder.split("")) {
            Axis axis = axes.getAxesList().stream()
                    .filter(aa -> aa.getAxis().equals(ax))
                    .findFirst()
                    .orElseThrow(() -> new IllegalArgumentException("Axis '" + ax + "' missing for StarDist"));
            list.add(((AxisV05) axis).getOriginalDescription());
        }
        Map<String, Object> map = new HashMap<>();
        map.put("axes", list);
        return map;
    }

    /**
     * Removes the given dimensions by taking the slice at index 0 of each.
     *
     * <p>Dimensions are removed from the highest index down so that the remaining indices stay
     * valid; this assumes {@code removeDims} is sorted ascending (as produced by
     * {@link #removeExtraDims}).
     */
    private static <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> getNewRai(RandomAccessibleInterval<T> rai, List<Integer> removeDims) {
        for (int k = removeDims.size() - 1; k >= 0; --k) {
            int dim = removeDims.get(k);
            rai = Views.hyperSlice(rai, dim, 0L);
        }
        return rai;
    }

    /** Returns {@code ogAxes} without the characters at the positions listed in {@code removeDims}. */
    private static String getNewAxes(String ogAxes, List<Integer> removeDims) {
        StringBuilder newAxes = new StringBuilder(ogAxes.length());
        for (int i = 0; i < ogAxes.length(); ++i) {
            if (!removeDims.contains(i)) {
                newAxes.append(ogAxes.charAt(i));
            }
        }
        return newAxes.toString();
    }

    /** Returns, in ascending order, the positions in {@code ogAxes} of axes absent from {@code targetAxes}. */
    private static List<Integer> removeExtraDims(String ogAxes, String targetAxes) {
        List<Integer> remove = new ArrayList<>();
        for (int i = 0; i < ogAxes.length(); ++i) {
            if (targetAxes.indexOf(ogAxes.charAt(i)) < 0) {
                remove.add(i);
            }
        }
        return remove;
    }

    /**
     * Builds a permutation that, for the i-th axis letter of {@code targetAxes}, holds that letter's
     * position in {@code ogAxes}. It then applies the permutation with {@code Utils.rearangeAxes}.
     *
     * <p>NOTE(review): callers pass the spec's axes order as {@code ogAxes} and the image's current
     * order as {@code targetAxes}. Confirm this matches the permutation convention of
     * {@code Utils.rearangeAxes}.
     */
    private static <T extends RealType<T> & NativeType<T>> RandomAccessibleInterval<T> transposeToAxesOrder(RandomAccessibleInterval<T> rai, String ogAxes, String targetAxes) {
        int[] transformation = new int[ogAxes.length()];
        for (int i = 0; i < targetAxes.length(); ++i) {
            transformation[i] = ogAxes.indexOf(targetAxes.charAt(i));
        }
        return Utils.rearangeAxes(rai, transformation);
    }
}

