/**
 * 
 */
package plugins.adufour.trackprocessors.speed;

import icy.math.ArrayMath;
import icy.plugin.abstract_.Plugin;
import icy.plugin.interface_.PluginBundled;
import icy.sequence.Sequence;

import java.util.ArrayList;
import java.util.List;

import plugins.adufour.blocks.lang.Block;
import plugins.adufour.blocks.util.VarList;
import plugins.adufour.vars.lang.Var;
import plugins.adufour.vars.lang.VarBoolean;
import plugins.adufour.vars.lang.VarDouble;
import plugins.adufour.vars.lang.VarEnum;
import plugins.adufour.vars.lang.VarSequence;
import plugins.fab.trackmanager.TrackGroup;
import plugins.fab.trackmanager.TrackSegment;
import plugins.nchenouard.spot.Detection;

/**
 * Block (and static utility) filtering the tracks of a {@link TrackGroup} according to a single criterion (start /
 * end time, duration, displacement, linearity, extent or speed) compared against a reference value.
 * <p>
 * When executed as a block, tracks are not removed from the group: rejected tracks are disabled through the enabled
 * flag of their detections.
 * 
 * @author Stephane
 */
public class FilterTracks extends Plugin implements Block, PluginBundled
{
    /**
     * Track property on which the filter is applied.<br>
     * The <code>*_DISP</code> criteria are speeds: displacement between consecutive detections divided by the time
     * scale (1 frame when real units are not used).
     */
    public enum Criterion
    {
        CRITERION_T_START("start time"), CRITERION_CURRENT_T("current time"), CRITERION_T_END("end time"),
        CRITERION_DURATION("track duration"), CRITERION_TOT_DISP("total displacement"),
        CRITERION_NET_DISP("net displacement"), CRITERION_LINEARITY("linearity"), CRITERION_EXTENT("extent"),
        CRITERION_MIN_DISP("minimum speed"), CRITERION_MAX_DISP("maximum speed"), CRITERION_AVG_DISP("average speed");

        /** Human-readable label, returned by {@link #toString()}. */
        final String description;

        private Criterion(String description)
        {
            this.description = description;
        }

        @Override
        public String toString()
        {
            return description;
        }
    }

    /**
     * Strict comparison applied between a measured track property and the reference value.
     */
    public enum ComparisonType
    {
        GREATER_THAN(">"), LOWER_THAN("<");

        /** Comparison symbol, returned by {@link #toString()}. */
        private final String symbol;

        private ComparisonType(String symbol)
        {
            this.symbol = symbol;
        }

        /**
         * @param d1
         *        measured value
         * @param d2
         *        reference value
         * @return <code>true</code> if <code>d1</code> is strictly greater (resp. lower) than <code>d2</code>.
         *         Always <code>false</code> when either value is NaN.
         */
        boolean compare(double d1, double d2)
        {
            switch (this)
            {
                case GREATER_THAN:
                    return d1 > d2;
                case LOWER_THAN:
                    return d1 < d2;
                default:
                    throw new UnsupportedOperationException("Unknown operation: " + toString());
            }
        }

        @Override
        public String toString()
        {
            return symbol;
        }
    }

    // VAR
    /** Input track group to filter. */
    public final Var<TrackGroup> tracks;
    /** Sequence whose pixel size / time interval are used when {@link #useRealUnits} is set. */
    public final VarSequence sequence;
    /** Whether measures are expressed in real units (from {@link #sequence} metadata) or in pixel / frame units. */
    public final VarBoolean useRealUnits;
    /** Track property to filter on. */
    public final VarEnum<Criterion> criterion;
    /** Comparison used to accept a track. */
    public final VarEnum<ComparisonType> operation;
    /** Reference value the measured property is compared against. */
    public final VarDouble value;

    public FilterTracks()
    {
        super();

        tracks = new Var<>("Track group", new TrackGroup(null));
        sequence = new VarSequence("Sequence", null);
        useRealUnits = new VarBoolean("Use real units (\u03BCm/sec)", false);
        criterion = new VarEnum<>("Filter on", Criterion.CRITERION_DURATION);
        operation = new VarEnum<>("Accept if", ComparisonType.GREATER_THAN);
        value = new VarDouble("Value", 0d);
    }

    /**
     * Filters the given list of tracks on a specific criterion, keeping the tracks that satisfy it.
     * <p>
     * Tracks with fewer than 2 detections cannot be measured: they are always kept and their enabled state is left
     * untouched.<br>
     * With {@link Criterion#CRITERION_CURRENT_T} the comparison is applied to the time of each detection: a track is
     * kept if at least one of its detections satisfies it and, when <code>setTrackEnabledFlag</code> is set, only
     * those detections remain enabled.<br>
     * A track that never moves has an undefined (NaN) linearity and is rejected by both comparison types.
     * 
     * @param tracks
     *        tracks to filter
     * @param criterion
     *        criterion to use to filter tracks
     * @param comp
     *        comparison type
     * @param value
     *        value used for comparison
     * @param sequence
     *        sequence used to retrieve real units from (using pixel size information from metadata).<br>
     *        Keep this value to <code>null</code> if you want use raw pixel / frame units instead.
     * @param setTrackEnabledFlag
     *        if set to <code>TRUE</code> then {@link TrackSegment#setAllDetectionEnabled(boolean)} and
     *        {@link Detection#setEnabled(boolean)} properties are modified (measured tracks are first re-enabled,
     *        then disabled if rejected)
     * @return filtered list of tracks (tracks satisfying the filter criterion)
     */
    public static List<TrackSegment> filterTracks(List<TrackSegment> tracks, Criterion criterion, ComparisonType comp,
            double value, Sequence sequence, boolean setTrackEnabledFlag)
    {
        final List<TrackSegment> result = new ArrayList<>();

        // scale factors converting pixel / frame units into real units (identity when no sequence is given)
        final double xScale;
        final double yScale;
        final double zScale;
        final double tScale;

        if (sequence != null)
        {
            xScale = sequence.getPixelSizeX();
            yScale = sequence.getPixelSizeY();
            zScale = sequence.getPixelSizeZ();
            tScale = sequence.getTimeInterval();
        }
        else
        {
            xScale = 1d;
            yScale = 1d;
            zScale = 1d;
            tScale = 1d;
        }

        for (TrackSegment ts : tracks)
        {
            final int nbDetections = ts.getDetectionList().size();

            // at least 2 detections are needed to measure a track: keep it as is and pass to the next one
            if (nbDetections < 2)
            {
                result.add(ts);
                continue;
            }

            // scaled positions of the detections
            final double[] xData = new double[nbDetections];
            final double[] yData = new double[nbDetections];
            final double[] zData = new double[nbDetections];
            // Euclidean displacement between consecutive detections (size: N-1)
            final double[] disp = new double[nbDetections - 1];

            for (int i = 0; i < nbDetections; i++)
            {
                final Detection det = ts.getDetectionAt(i);

                xData[i] = det.getX() * xScale;
                yData[i] = det.getY() * yScale;
                zData[i] = det.getZ() * zScale;

                if (i > 0)
                    disp[i - 1] = distance(xData, yData, zData, i, i - 1);
            }

            final double startTime = ts.getFirstDetection().getT() * tScale;
            final double endTime = ts.getLastDetection().getT() * tScale;
            // NOTE(review): derived from the number of detections, i.e. assumes one detection per frame without gap
            // (endTime - startTime would differ otherwise) — confirm which definition is intended
            final double duration = (nbDetections - 1) * tScale;

            // straight-line displacement from first to last detection
            final double netDisp = distance(xData, yData, zData, nbDetections - 1, 0);

            // path length; displacements are non-negative distances so they can be summed directly
            final double totDisp = ArrayMath.sum(disp);

            // linearity: ratio of net to total displacement, in [0, 1] (1 = straight line).
            // NaN (0 / 0) for a track that never moves.
            final double linearity = netDisp / totDisp;

            // convert displacement per frame into speed (in place)
            ArrayMath.divide(disp, tScale, disp);

            final double minDisp = ArrayMath.min(disp);
            final double maxDisp = ArrayMath.max(disp);
            final double avgDisp = ArrayMath.mean(disp);

            // by default we put all detections in 'enabled' state
            if (setTrackEnabledFlag)
                ts.setAllDetectionEnabled(true);

            // time to filter
            final boolean accepted;

            switch (criterion)
            {
                case CRITERION_T_START:
                    accepted = comp.compare(startTime, value);
                    break;

                case CRITERION_CURRENT_T:
                    // per-detection filter: the track is accepted if at least one detection matches, so that the
                    // per-detection enabled flags are not overwritten by the rejection step below
                    accepted = filterDetectionsByTime(ts, comp, value, tScale, setTrackEnabledFlag);
                    break;

                case CRITERION_T_END:
                    accepted = comp.compare(endTime, value);
                    break;

                case CRITERION_DURATION:
                    accepted = comp.compare(duration, value);
                    break;

                case CRITERION_LINEARITY:
                    accepted = comp.compare(linearity, value);
                    break;

                case CRITERION_EXTENT:
                    // extent is O(N^2) to compute, so only do it when actually needed
                    accepted = comp.compare(maxPairwiseDistance(xData, yData, zData), value);
                    break;

                case CRITERION_TOT_DISP:
                    accepted = comp.compare(totDisp, value);
                    break;

                case CRITERION_NET_DISP:
                    accepted = comp.compare(netDisp, value);
                    break;

                case CRITERION_MIN_DISP:
                    accepted = comp.compare(minDisp, value);
                    break;

                case CRITERION_MAX_DISP:
                    accepted = comp.compare(maxDisp, value);
                    break;

                case CRITERION_AVG_DISP:
                    accepted = comp.compare(avgDisp, value);
                    break;

                default:
                    accepted = false;
                    break;
            }

            // rejected track --> disable all its detections
            if (setTrackEnabledFlag && !accepted)
                ts.setAllDetectionEnabled(false);

            // accepted ? --> add to result list
            if (accepted)
                result.add(ts);
        }

        return result;
    }

    /**
     * Applies the comparison to the (scaled) time of each detection of the given track.
     * 
     * @param ts
     *        track whose detections are tested
     * @param comp
     *        comparison type
     * @param value
     *        reference time value
     * @param tScale
     *        time scale applied to each detection's time point
     * @param setEnabledFlag
     *        if <code>true</code>, each detection is enabled if it matches and disabled otherwise
     * @return <code>true</code> if at least one detection satisfies the comparison
     */
    private static boolean filterDetectionsByTime(TrackSegment ts, ComparisonType comp, double value, double tScale,
            boolean setEnabledFlag)
    {
        boolean anyMatch = false;

        for (Detection det : ts.getDetectionList())
        {
            final boolean match = comp.compare(det.getT() * tScale, value);

            if (setEnabledFlag)
                det.setEnabled(match);

            anyMatch |= match;
        }

        return anyMatch;
    }

    /**
     * @return Euclidean distance between points <code>i</code> and <code>j</code> of the given coordinate arrays
     */
    private static double distance(double[] x, double[] y, double[] z, int i, int j)
    {
        final double dx = x[i] - x[j];
        final double dy = y[i] - y[j];
        final double dz = z[i] - z[j];

        return Math.sqrt(dx * dx + dy * dy + dz * dz);
    }

    /**
     * @return largest Euclidean distance between any 2 points of the given coordinate arrays (0 if fewer than 2
     *         points)
     */
    private static double maxPairwiseDistance(double[] x, double[] y, double[] z)
    {
        double extent = 0d;

        for (int i = 0; i < x.length; i++)
        {
            for (int j = i + 1; j < x.length; j++)
            {
                final double dist = distance(x, y, z, i, j);

                if (dist > extent)
                    extent = dist;
            }
        }

        return extent;
    }

    @Override
    public void declareInput(VarList inputMap)
    {
        inputMap.add("trackgroup", tracks);
        inputMap.add("sequence", sequence);
        // NOTE(review): the ID has a doubled 's'; kept as-is since it presumably identifies this input in saved
        // protocols — confirm before renaming
        inputMap.add("useRealUnitss", useRealUnits);
        inputMap.add("criterion", criterion);
        inputMap.add("comp", operation);
        inputMap.add("value", value);
    }

    @Override
    public void declareOutput(VarList outputMap)
    {
        // no output here: filtering is reflected through the detections' enabled flags
    }

    /**
     * Filters the tracks of the input group whose detections are all currently enabled. The filtered list is not
     * returned: rejected tracks are disabled in place (see
     * {@link #filterTracks(List, Criterion, ComparisonType, double, Sequence, boolean)}).
     */
    @Override
    public void run()
    {
        // execution from protocol
        final TrackGroup trackGroup = tracks.getValue();

        if (trackGroup != null)
        {
            // build the list of enabled tracks
            final List<TrackSegment> tracksToFilter = new ArrayList<>();

            for (TrackSegment track : trackGroup.getTrackSegmentList())
                if (track.isAllDetectionEnabled())
                    tracksToFilter.add(track);

            // filter tracks (modifying internal track segment property)
            filterTracks(tracksToFilter, criterion.getValue(), operation.getValue(), value.getValue().doubleValue(),
                    useRealUnits.getValue().booleanValue() ? sequence.getValue() : null, true);
        }
    }

    @Override
    public String getMainPluginClassName()
    {
        return SpeedProfiler.class.getName();
    }
}
