package plugins.nchenouard.trackprocessorperformance;

import java.util.ArrayList;

import plugins.fab.trackmanager.TrackSegment;
import plugins.nchenouard.spot.Detection;

/**
 * Utilities to compute several tracking performance criteria for a given pairing between a reference and candidate set of tracks
 *
 * @version February 3, 2012
 * 
 * @author Nicolas Chenouard
 *
 * */

public class PerformanceAnalyzer
{
	/** Private copy of the reference tracks passed to the constructor. */
	final ArrayList<TrackSegment> referenceTracks;
	/** Private copy of the candidate tracks passed to the constructor. */
	final ArrayList<TrackSegment> candidateTracks;
	/** Private copy of the reference/candidate pairing passed to the constructor. */
	final ArrayList<TrackPair> trackPairs;

	/**
	 * Build the analyzer. The three lists are copied, so later changes to the caller's lists
	 * do not affect the analyzer (the tracks and pairs themselves are shared, not cloned).
	 *
	 * @param referenceTracks the set of reference tracks
	 * @param candidateTracks the set of candidate tracks
	 * @param trackPairs the pairing between the set of tracks. Each track in the reference set has to be represented.
	 * */
	public PerformanceAnalyzer(ArrayList<TrackSegment> referenceTracks, ArrayList<TrackSegment> candidateTracks, ArrayList<TrackPair> trackPairs)
	{
		this.referenceTracks = new ArrayList<TrackSegment>(referenceTracks);
		this.candidateTracks = new ArrayList<TrackSegment>(candidateTracks);
		this.trackPairs = new ArrayList<TrackPair>(trackPairs);
	}

	/**
	 * @return the number of reference tracks
	 * */
	public int getNumRefTracks()
	{
		return referenceTracks.size();
	}

	/**
	 * @return the total number of detections for reference tracks, computed for each track as
	 * the number of time points between its first and last detection (both included)
	 * */
	public int getNumRefDetections()
	{
		return countTimePoints(referenceTracks);
	}

	/**
	 * @return the number of candidate tracks
	 * */
	public int getNumCandidateTracks()
	{
		return candidateTracks.size();
	}

	/**
	 * @return the total number of detections for candidate tracks, computed for each track as
	 * the number of time points between its first and last detection (both included)
	 * */
	public int getNumCandidateDetections()
	{
		return countTimePoints(candidateTracks);
	}

	/**
	 * Sum of the track-to-track distances over all the pairs.
	 *
	 * @param distType the type of distance to compute between paired tracks
	 * @param maxDist the distance parameter forwarded to {@link TrackToTrackDistance}
	 * @return the distance between the pairs
	 * */
	public double getPairedTracksDistance(DistanceTypes distType, double maxDist)
	{
		return sumPairedDistances(distType, maxDist);
	}

	/**
	 * Paired distance normalized by the distance obtained when every reference track is
	 * compared to no track at all.
	 *
	 * @param distType the type of distance to compute between paired tracks
	 * @param maxDist the distance parameter forwarded to {@link TrackToTrackDistance}
	 * @return the normalized distance between the pairs (alpha criterion). The value is the
	 * result of a floating point division and is therefore NaN or infinite when the
	 * normalization term is zero (e.g. no reference track).
	 * */
	public double getPairedTracksNormalizedDistance(DistanceTypes distType, double maxDist)
	{
		double distance = sumPairedDistances(distType, maxDist);
		// divide by the maximum distance that corresponds to reference tracks with no associated tracks
		double normalization = sumUnpairedDistances(referenceTracks, distType, maxDist);
		return 1d - distance / normalization;
	}

	/**
	 * Same as the normalized distance, with an additional penalty in the denominator for
	 * each candidate track that does not appear in any pair.
	 *
	 * @param distType the type of distance to compute between paired tracks
	 * @param maxDist the distance parameter forwarded to {@link TrackToTrackDistance}
	 * @return the full distance between the pairs (beta criterion) that accounts for non-associated candidate tracks.
	 * The value is the result of a floating point division and is therefore NaN or infinite
	 * when bound + penalty is zero.
	 * */
	public double getFullTrackingScore(DistanceTypes distType, double maxDist)
	{
		double distance = sumPairedDistances(distType, maxDist);
		// bound on the distance: every reference track compared to no track
		double bound = sumUnpairedDistances(referenceTracks, distType, maxDist);
		// penalty for wrong tracks: every non-paired candidate track compared to no track
		double penalty = 0;
		for (TrackSegment ts : candidateTracks)
		{
			if (findPairForCandidate(ts) == null)
				penalty += new TrackToTrackDistance(ts, null, distType, maxDist).distance;
		}
		return (bound - distance) / (bound + penalty);
	}

	/**
	 * @return the number of non-associated candidate tracks
	 * */
	public int getNumSpuriousTracks()
	{
		int numSpuriousTracks = 0;
		for (TrackSegment ts : candidateTracks)
		{
			if (findPairForCandidate(ts) == null)
				numSpuriousTracks++;
		}
		return numSpuriousTracks;
	}

	/**
	 * @return the number of non-associated reference tracks (or associated with a dummy track)
	 * */
	public int getNumMissedTracks()
	{
		int numMissedTrack = 0;
		for (TrackSegment ts : referenceTracks)
		{
			boolean found = false;
			for (TrackPair tp : trackPairs)
			{
				if (tp.referenceTrack == ts)
				{
					// only the first pair holding this reference track is considered;
					// a null or empty candidate is a dummy track and counts as a miss
					found = tp.candidateTrack != null && !tp.candidateTrack.getDetectionList().isEmpty();
					break;
				}
			}
			if (!found)
				numMissedTrack++;
		}
		return numMissedTrack;
	}

	/**
	 * @return the number of pairs between reference and candidate tracks, counted as the
	 * number of candidate tracks that appear in at least one pair
	 * */
	public int getNumPairedTracks()
	{
		int numCorrectTracks = 0;
		for (TrackSegment ts : candidateTracks)
		{
			if (findPairForCandidate(ts) != null)
				numCorrectTracks++;
		}
		return numCorrectTracks;
	}

	/**
	 * @param maxDist the distance parameter forwarded to {@link TrackToTrackDistance}
	 * @return the total number of paired detections
	 * */
	public int getNumPairedDetections(double maxDist)
	{
		int numRecoveredDetections = 0;
		for (TrackPair tp : trackPairs)
			numRecoveredDetections += matchingDistance(tp, maxDist).numMatchingDetections;
		return numRecoveredDetections;
	}

	/**
	 * @param maxDist the distance parameter forwarded to {@link TrackToTrackDistance}
	 * @return the number of detections for the reference tracks that are not paired to a candidate detection
	 * */
	public int getNumMissedDetections(double maxDist)
	{
		int numMissedDetections = 0;
		for (TrackPair tp : trackPairs)
			numMissedDetections += matchingDistance(tp, maxDist).numNonMatchedDetections;
		return numMissedDetections;
	}

	/**
	 * @param maxDist the distance parameter forwarded to {@link TrackToTrackDistance}
	 * @return the number of detections for the candidate tracks that are not paired to a reference detection
	 * */
	public int getNumWrongDetections(double maxDist)
	{
		int numSpuriousDetections = 0;
		for (TrackSegment ts : candidateTracks)
		{
			TrackPair pair = findPairForCandidate(ts);
			if (pair != null)
			{
				numSpuriousDetections += matchingDistance(pair, maxDist).numWrongDetections;
			}
			else
			{
				// non-paired track: each of its real detections is spurious;
				// virtual detections are not considered as spurious detections
				for (Detection d : ts.getDetectionList())
				{
					if (d.getDetectionType() == Detection.DETECTIONTYPE_REAL_DETECTION)
						numSpuriousDetections++;
				}
			}
		}
		return numSpuriousDetections;
	}

	/**
	 * Euclidean distances between detections of paired tracks, for the pairs whose candidate
	 * track is neither null nor empty.
	 *
	 * @param maxDist not used by this method; kept for interface compatibility
	 * @return the list of detection-to-detection distances over all such pairs
	 * */
	public ArrayList<Double> getDistanceDetectionList(double maxDist)
	{
		ArrayList<Double> distanceList = new ArrayList<Double>();
		for (TrackPair tp : trackPairs)
		{
			if (tp.candidateTrack != null && !tp.candidateTrack.getDetectionList().isEmpty())
				distanceList.addAll(getDetectionEuclidianDistances(tp.referenceTrack, tp.candidateTrack));
		}
		return distanceList;
	}

	/**
	 * Summary statistics of the distances between matching detections, accumulated over the
	 * pairs whose candidate track is neither null nor empty.
	 *
	 * @param maxDist the distance parameter forwarded to {@link TrackToTrackDistance}
	 * @return an array {rmse, minimum distance, maximum distance, standard deviation of the distance},
	 * or four zeros when there is no matching detection
	 * */
	public double[] getDistanceDetectionData(double maxDist)
	{
		double sumDistance = 0;
		double sumSquareDistance = 0;
		double minDistance = Double.MAX_VALUE;
		double maxDistance = 0;
		int numDetections = 0;
		for (TrackPair tp : trackPairs)
		{
			if (tp.candidateTrack == null || tp.candidateTrack.getDetectionList().isEmpty())
				continue;
			TrackToTrackDistance d = matchingDistance(tp, maxDist);
			sumDistance += d.sumDetectionDistance;
			sumSquareDistance += d.sumSquareDetectionDistance;
			if (d.minDetectionDistance < minDistance)
				minDistance = d.minDetectionDistance;
			if (d.maxDetectionDistance > maxDistance)
				maxDistance = d.maxDetectionDistance;
			numDetections += d.numMatchingDetections;
		}
		if (numDetections == 0)
			return new double[]{0, 0, 0, 0};

		double meanSquare = sumSquareDistance / numDetections;
		double mean = sumDistance / numDetections;
		double rmse = Math.sqrt(meanSquare);
		// E[d^2] - E[d]^2 can come out slightly negative because of rounding when all the
		// distances are (nearly) equal; clamp at 0 so that the square root is never NaN
		double variance = Math.max(0d, meanSquare - mean * mean);
		double stdDistance = Math.sqrt(variance);
		return new double[]{rmse, minDistance, maxDistance, stdDistance};
	}

	/**
	 * @return the Euclidean distances between detections of paired tracks, over all the pairs
	 * (pairs with a null or empty candidate track contribute no distance)
	 * */
	public ArrayList<Double> getAllPairsDetectionEuclidianDistances()
	{
		ArrayList<Double> distanceList = new ArrayList<Double>();
		for (TrackPair tp : trackPairs)
			distanceList.addAll(getDetectionEuclidianDistances(tp.referenceTrack, tp.candidateTrack));
		return distanceList;
	}

	/**
	 * Euclidean distances between the detections of two tracks at each time point of their
	 * common time span.
	 *
	 * @param ts1 the first track
	 * @param ts2 the second track
	 * @return one distance per time point at which both tracks have a detection; an empty list
	 * when either track is null or empty or when the time spans do not overlap
	 * */
	protected ArrayList<Double> getDetectionEuclidianDistances(TrackSegment ts1, TrackSegment ts2)
	{
		ArrayList<Double> distanceList = new ArrayList<Double>();
		if (ts1 == null || ts1.getDetectionList().isEmpty() || ts2 == null || ts2.getDetectionList().isEmpty())
			return distanceList;
		// common time span; the loop below does not execute when the spans do not intersect
		int firstT = Math.max(ts1.getFirstDetection().getT(), ts2.getFirstDetection().getT());
		int endT = Math.min(ts1.getLastDetection().getT(), ts2.getLastDetection().getT());
		for (int t = firstT; t <= endT; t++)
		{
			Detection d1 = ts1.getDetectionAtTime(t);
			Detection d2 = ts2.getDetectionAtTime(t);
			// a track may have no detection at some time points of its span: skip them
			if (d1 != null && d2 != null)
				distanceList.add(Double.valueOf(Math.sqrt(squaredDistance(d1, d2))));
		}
		return distanceList;
	}

	/**
	 * @return the lengths of the jumps between detections at consecutive time points, over all the reference tracks
	 * */
	public ArrayList<Double> getReferenceTracksJumpLengthList()
	{
		ArrayList<Double> lengthList = new ArrayList<Double>();
		for (TrackSegment ts : referenceTracks)
			lengthList.addAll(getJumpLengthList(ts));
		return lengthList;
	}

	/**
	 * @return the lengths of the jumps between detections at consecutive time points, over all the candidate tracks
	 * */
	public ArrayList<Double> getCandidateTracksJumpLengthList()
	{
		ArrayList<Double> lengthList = new ArrayList<Double>();
		for (TrackSegment ts : candidateTracks)
			lengthList.addAll(getJumpLengthList(ts));
		return lengthList;
	}

	/**
	 * @param ts1 the track to analyze
	 * @return the Euclidean distance between the detections at times t and t+1, for each t of
	 * the track time span at which both detections exist
	 * */
	protected ArrayList<Double> getJumpLengthList(TrackSegment ts1)
	{
		ArrayList<Double> lengthList = new ArrayList<Double>();
		int firstT = ts1.getFirstDetection().getT();
		int lastT = ts1.getLastDetection().getT();
		for (int t = firstT; t < lastT; t++)
		{
			Detection d1 = ts1.getDetectionAtTime(t);
			if (d1 == null)
				continue;
			Detection d2 = ts1.getDetectionAtTime(t + 1);
			if (d2 != null)
				lengthList.add(Double.valueOf(Math.sqrt(squaredDistance(d1, d2))));
		}
		return lengthList;
	}

	/**
	 * @return the mean squared displacements of the candidate tracks, see {@link #getMSDs(ArrayList)}
	 * */
	public double[] getCandidateTracksMSDs()
	{
		return getMSDs(candidateTracks);
	}

	/**
	 * @return the mean squared displacements of the reference tracks, see {@link #getMSDs(ArrayList)}
	 * */
	public double[] getReferenceTracksMSDs()
	{
		return getMSDs(referenceTracks);
	}

	/**
	 * Mean squared displacement as a function of the time gap, pooled over a set of tracks.
	 *
	 * @param tracks the tracks to analyze
	 * @return an array whose element k is the average squared displacement between detections
	 * separated by k+1 time points. Its length is the largest time span (last time minus first
	 * time) among the tracks; an entry is 0 when no pair of detections exists for that gap.
	 * */
	protected double[] getMSDs(ArrayList<TrackSegment> tracks)
	{
		int maxTGap = 0;
		for (TrackSegment ts : tracks)
		{
			int span = ts.getLastDetection().getT() - ts.getFirstDetection().getT();
			if (span > maxTGap)
				maxTGap = span;
		}

		double[] msds = new double[maxTGap];
		int[] numJumps = new int[maxTGap];

		for (TrackSegment ts : tracks)
		{
			int firstT = ts.getFirstDetection().getT();
			int lastT = ts.getLastDetection().getT();
			// gaps larger than the span of this track cannot contribute
			int largestGap = Math.min(maxTGap, lastT - firstT);
			for (int tGap = 1; tGap <= largestGap; tGap++)
			{
				for (int t = firstT; t <= lastT - tGap; t++)
				{
					Detection d1 = ts.getDetectionAtTime(t);
					if (d1 == null)
						continue;
					Detection d2 = ts.getDetectionAtTime(t + tGap);
					if (d2 != null)
					{
						msds[tGap - 1] += squaredDistance(d1, d2);
						numJumps[tGap - 1] += 1;
					}
				}
			}
		}
		for (int k = 0; k < msds.length; k++)
		{
			if (numJumps[k] > 0)
				msds[k] /= numJumps[k];
		}
		return msds;
	}

	/** Sums, over the given tracks, the number of time points between first and last detection (both included). */
	private static int countTimePoints(ArrayList<TrackSegment> tracks)
	{
		int numDetections = 0;
		for (TrackSegment ts : tracks)
			numDetections += (ts.getLastDetection().getT() - ts.getFirstDetection().getT() + 1);
		return numDetections;
	}

	/** Sums the track-to-track distance of every pair. */
	private double sumPairedDistances(DistanceTypes distType, double maxDist)
	{
		double distance = 0;
		for (TrackPair tp : trackPairs)
			distance += new TrackToTrackDistance(tp.referenceTrack, tp.candidateTrack, distType, maxDist).distance;
		return distance;
	}

	/** Sums the distance of each given track to no track at all (null second track). */
	private static double sumUnpairedDistances(ArrayList<TrackSegment> tracks, DistanceTypes distType, double maxDist)
	{
		double total = 0;
		for (TrackSegment ts : tracks)
			total += new TrackToTrackDistance(ts, null, distType, maxDist).distance;
		return total;
	}

	/** Builds the matching-type distance between the two tracks of a pair. */
	private static TrackToTrackDistance matchingDistance(TrackPair tp, double maxDist)
	{
		return new TrackToTrackDistance(tp.referenceTrack, tp.candidateTrack, DistanceTypes.DISTANCE_MATCHING, maxDist);
	}

	/** Returns the first pair whose candidate track is the given instance (identity comparison), or null if there is none. */
	private TrackPair findPairForCandidate(TrackSegment candidate)
	{
		for (TrackPair tp : trackPairs)
		{
			if (tp.candidateTrack == candidate)
				return tp;
		}
		return null;
	}

	/** Squared Euclidean distance between two detections, using their x, y and z coordinates. */
	private static double squaredDistance(Detection d1, Detection d2)
	{
		return (d1.getX()-d2.getX())*(d1.getX()-d2.getX()) + (d1.getY()-d2.getY())*(d1.getY()-d2.getY()) + (d1.getZ()-d2.getZ())*(d1.getZ()-d2.getZ());
	}

}
