package plugins.adufour.roi;

import java.awt.Point;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.math3.ml.clustering.Cluster;
import org.apache.commons.math3.ml.clustering.Clusterable;
import org.apache.commons.math3.ml.clustering.Clusterer;
import org.apache.commons.math3.ml.clustering.KMeansPlusPlusClusterer;

import icy.math.ArrayMath;
import icy.roi.ROI;
import icy.roi.ROI2D;
import icy.roi.ROI3D;
import icy.type.point.Point3D;
import plugins.adufour.ezplug.EzLabel;
import plugins.adufour.ezplug.EzPlug;
import plugins.adufour.ezplug.EzVar;
import plugins.adufour.ezplug.EzVarEnum;
import plugins.adufour.ezplug.EzVarInteger;
import plugins.adufour.ezplug.EzVarListener;
import plugins.adufour.ezplug.EzVarSequence;
import plugins.kernel.roi.roi2d.ROI2DArea;
import plugins.kernel.roi.roi3d.ROI3DArea;

/**
 * Splits ROI of a sequence into several parts, using a k-means++ clustering of their pixel
 * coordinates.
 * <p>
 * Two modes are available:
 * <ul>
 * <li>{@link Mode#MANUAL}: every selected ROI of the chosen sequence is split into a
 * user-defined number of components.</li>
 * <li>{@link Mode#AUTOMATIC}: the median size of all ROI of the sequence is computed, and each ROI
 * whose size is (after rounding) at least twice that median is split into as many parts. The
 * split is only kept if the ROI is sufficiently non-convex (see {@link #CONVEXITY_THRESHOLD}).</li>
 * </ul>
 * Split ROI are removed from the sequence and replaced by their parts.
 */
public class SplitROI extends EzPlug// implements ROIBlock
{
    private enum Mode
    {
        MANUAL("Splits the selected ROI into a given number of components"),
        AUTOMATIC("Calculates the median size of all ROI and\nsplits the largest ROI into as many parts as necessary\nto fit the median size");
        
        /** Human-readable description shown in the plug-in interface. */
        final String desc;
        
        Mode(String description)
        {
            desc = description;
        }
    }
    
    /**
     * In automatic mode, a split is applied only if the parts together contain fewer points than
     * this fraction of the original ROI's convex hull. A ROI that (nearly) fills its convex hull
     * is left untouched.
     */
    private static final double CONVEXITY_THRESHOLD = 0.97;
    
    EzVarSequence   sequence     = new EzVarSequence("Select ROI from");
    
    EzVarEnum<Mode> mode         = new EzVarEnum<Mode>("Split mode", Mode.values());
    
    EzLabel         description  = new EzLabel(mode.getValue().desc);
    
    EzVarInteger    nbComponents = new EzVarInteger("Number of components", 2, 2, 10, 1);
    
    @Override
    public void clean()
    {
    }
    
    @Override
    protected void initialize()
    {
        addEzComponent(sequence);
        addEzComponent(mode);
        addEzComponent(description);
        // keep the description label in sync with the selected mode
        mode.addVarChangeListener(new EzVarListener<SplitROI.Mode>()
        {
            @Override
            public void variableChanged(EzVar<Mode> source, Mode newValue)
            {
                description.setText(newValue.desc);
            }
        });
        addEzComponent(nbComponents);
        // the number of components is only meaningful in manual mode
        mode.addVisibilityTriggerTo(nbComponents, Mode.MANUAL);
    }
    
    @Override
    protected void execute()
    {
        // read the mode once so that ROI collection and replacement always agree
        final Mode splitMode = mode.getValue();
        
        Map<ROI, Integer> roiToSplit = (splitMode == Mode.AUTOMATIC) ? collectOversizedROI() : collectSelectedROI();
        
        for (Map.Entry<ROI, Integer> entry : roiToSplit.entrySet())
        {
            ROI roi = entry.getKey();
            int nbSplits = entry.getValue();
            
            ROI[] splits;
            
            if (roi instanceof ROI2D)
            {
                splits = split((ROI2D) roi, nbSplits);
            }
            else if (roi instanceof ROI3D)
            {
                splits = split((ROI3D) roi, nbSplits);
            }
            else
            {
                System.err.println("[Split ROI] Unsupported ROI: " + roi.getName());
                continue;
            }
            
            // too few pixels to produce several parts: keep the original ROI untouched
            if (splits.length < 2) continue;
            
            if (splitMode == Mode.AUTOMATIC && !isSplitWorthwhile(roi, splits)) continue;
            
            sequence.getValue().removeROI(roi, true);
            sequence.getValue().addROIs(Arrays.asList(splits), true);
        }
    }
    
    /**
     * Collects the selected ROI of the chosen sequence, each associated with the user-defined
     * number of components.
     * 
     * @return the ROI to split (in selection order) mapped to their number of parts
     */
    private Map<ROI, Integer> collectSelectedROI()
    {
        Map<ROI, Integer> result = new LinkedHashMap<ROI, Integer>();
        int nbSplits = nbComponents.getValue();
        
        for (ROI roi : sequence.getValue(true).getSelectedROIs())
            result.put(roi, nbSplits);
        
        return result;
    }
    
    /**
     * Computes the median size (number of points) of all ROI of the chosen sequence. It then
     * collects every ROI whose size, divided by that median and rounded, is at least 2.
     * 
     * @return the ROI to split (in sequence order) mapped to their number of parts. The map is
     *         empty if the sequence has no ROI or if the median size is not strictly positive.
     */
    private Map<ROI, Integer> collectOversizedROI()
    {
        Map<ROI, Integer> result = new LinkedHashMap<ROI, Integer>();
        List<ROI> sequenceROI = sequence.getValue(true).getROIs();
        
        if (sequenceROI.isEmpty()) return result;
        
        double[] volumes = new double[sequenceROI.size()];
        for (int i = 0; i < volumes.length; i++)
            volumes[i] = sequenceROI.get(i).getNumberOfPoints();
        double median = ArrayMath.median(volumes, true);
        
        // a zero (or NaN) median gives no meaningful reference size
        if (!(median > 0)) return result;
        
        for (ROI roi : sequenceROI)
        {
            long nbObjects = Math.round(roi.getNumberOfPoints() / median);
            // clamp before narrowing so that a huge ratio cannot wrap to a negative int
            if (nbObjects >= 2) result.put(roi, (int) Math.min(nbObjects, Integer.MAX_VALUE));
        }
        
        return result;
    }
    
    /**
     * Convexity heuristic used in automatic mode. Compares the total number of points of the
     * parts to the number of points of the original ROI's convex hull.
     * 
     * @param roi
     *            the original ROI
     * @param splits
     *            the parts obtained from {@code roi}
     * @return {@code true} if the parts contain fewer points than
     *         {@link #CONVEXITY_THRESHOLD} times the convex hull size
     */
    private static boolean isSplitWorthwhile(ROI roi, ROI[] splits)
    {
        double hullSize = Convexify.createConvexROI(roi).getNumberOfPoints();
        double newSize = 0;
        for (ROI split : splits)
            newSize += split.getNumberOfPoints();
        
        return newSize < hullSize * CONVEXITY_THRESHOLD;
    }
    
    /**
     * Splits the specified 2D ROI into the specified number of parts, based on a basic distance
     * (k-means++) clustering of its pixel coordinates.
     * 
     * @param roi
     *            the ROI to split
     * @param nbObjects
     *            the requested number of parts (must be at least 1)
     * @return the new area ROI, one per cluster. The T, C, Z position and color are copied from
     *         {@code roi}. If the ROI has fewer pixels than {@code nbObjects}, one part per pixel
     *         is returned; if it has no pixel, the array is empty.
     * @throws IllegalArgumentException
     *             if {@code nbObjects} is smaller than 1
     */
    public static ROI2D[] split(ROI2D roi, int nbObjects)
    {
        List<PixelCoordinates> pixels = new ArrayList<PixelCoordinates>();
        
        // populate the data set with pixel coordinates
        for (Point pt : roi.getBooleanMask(true).getPoints())
            pixels.add(new PixelCoordinates(pt.x, pt.y));
        
        List<? extends Cluster<PixelCoordinates>> clusters = clusterPixels(pixels, nbObjects);
        
        // create a new ROI for each cluster
        ROI2D[] newROI = new ROI2D[clusters.size()];
        for (int i = 0; i < newROI.length; i++)
        {
            ROI2DArea area = new ROI2DArea();
            area.setT(roi.getT());
            area.setC(roi.getC());
            area.setZ(roi.getZ());
            area.setColor(roi.getColor());
            
            for (PixelCoordinates pixel : clusters.get(i).getPoints())
            {
                double[] xy = pixel.getPoint();
                area.addPoint((int) xy[0], (int) xy[1]);
            }
            newROI[i] = area;
        }
        
        return newROI;
    }
    
    /**
     * Splits the specified 3D ROI into the specified number of parts, based on a basic distance
     * (k-means++) clustering of its voxel coordinates.
     * 
     * @param roi
     *            the ROI to split
     * @param nbObjects
     *            the requested number of parts (must be at least 1)
     * @return the new area ROI, one per cluster. The T and C position and color are copied from
     *         {@code roi}. If the ROI has fewer voxels than {@code nbObjects}, one part per voxel
     *         is returned; if it has no voxel, the array is empty.
     * @throws IllegalArgumentException
     *             if {@code nbObjects} is smaller than 1
     */
    public static ROI3D[] split(ROI3D roi, int nbObjects)
    {
        List<PixelCoordinates> voxels = new ArrayList<PixelCoordinates>();
        
        // populate the data set with voxel coordinates
        for (Point3D.Integer pt : roi.getBooleanMask(true).getPoints())
            voxels.add(new PixelCoordinates(pt.x, pt.y, pt.z));
        
        List<? extends Cluster<PixelCoordinates>> clusters = clusterPixels(voxels, nbObjects);
        
        // create a new ROI for each cluster
        ROI3D[] newROI = new ROI3D[clusters.size()];
        for (int i = 0; i < newROI.length; i++)
        {
            ROI3DArea area = new ROI3DArea();
            area.setT(roi.getT());
            area.setC(roi.getC());
            area.setColor(roi.getColor());
            
            for (PixelCoordinates voxel : clusters.get(i).getPoints())
            {
                double[] xyz = voxel.getPoint();
                area.addPoint((int) xyz[0], (int) xyz[1], (int) xyz[2]);
            }
            newROI[i] = area;
        }
        
        return newROI;
    }
    
    /**
     * Clusters the given pixel coordinates with k-means++.
     * 
     * @param pixels
     *            the coordinates to cluster
     * @param nbObjects
     *            the requested number of clusters (must be at least 1). It is clamped to the
     *            number of pixels, because the clusterer rejects data sets smaller than the
     *            cluster count.
     * @return the clusters (an empty list if {@code pixels} is empty)
     * @throws IllegalArgumentException
     *             if {@code nbObjects} is smaller than 1
     */
    private static List<? extends Cluster<PixelCoordinates>> clusterPixels(List<PixelCoordinates> pixels, int nbObjects)
    {
        if (nbObjects < 1) throw new IllegalArgumentException("The number of objects must be at least 1, got " + nbObjects);
        
        int k = Math.min(nbObjects, pixels.size());
        if (k == 0) return new ArrayList<Cluster<PixelCoordinates>>();
        
        // Alternative (unused): density-based clustering, e.g.
        // new DBSCANClusterer<...>(roi.getNumberOfContourPoints() / (4 * Math.PI), (int) roi.getNumberOfPoints() / 2)
        Clusterer<PixelCoordinates> clusterer = new KMeansPlusPlusClusterer<PixelCoordinates>(k);
        return clusterer.cluster(pixels);
    }
    
    /**
     * Immutable integer pixel/voxel coordinates. The coordinate array is built once, because the
     * clusterer queries {@link #getPoint()} repeatedly.
     */
    private static final class PixelCoordinates implements Clusterable
    {
        private final double[] coordinates;
        
        PixelCoordinates(double... coordinates)
        {
            this.coordinates = coordinates;
        }
        
        @Override
        public double[] getPoint()
        {
            return coordinates;
        }
    }
}
