package plugins.lagache.matchtracks;

import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.swing.SwingUtilities;


import org.apache.poi.hssf.usermodel.HSSFWorkbook;
import org.apache.poi.ss.usermodel.Row;
import org.apache.poi.ss.usermodel.Sheet;
import org.apache.poi.ss.usermodel.Workbook;
import org.apache.poi.ss.util.WorkbookUtil;

import icy.main.Icy;
import icy.plugin.abstract_.Plugin;
import icy.roi.BooleanMask2D;
import icy.roi.ROI2D;
import icy.sequence.Sequence;
import icy.swimmingPool.SwimmingObject;
import plugins.adufour.blocks.lang.Block;
import plugins.adufour.blocks.util.VarList;
import plugins.adufour.vars.lang.Var;
import plugins.adufour.vars.lang.VarDouble;
import plugins.adufour.vars.lang.VarSequence;
import plugins.adufour.vars.lang.VarWorkbook;
import plugins.fab.trackmanager.TrackGroup;
import plugins.fab.trackmanager.TrackSegment;
import plugins.kernel.roi.roi2d.ROI2DPoint;
import plugins.nchenouard.spot.Detection;
import plugins.nchenouard.spot.DetectionResult;
import plugins.nchenouard.spot.Spot;

/**
 * Protocols block that compares a group of test tracks against a group of reference tracks.
 * <p>
 * Every detection of the input {@link DetectionResult} is assigned to the closest reference track
 * and, independently, to the closest test track. Closeness is the XY distance to the track's
 * detection at the same time point, and it must be strictly below {@link #max_distance}.
 * A test track matches a reference track when strictly more than {@link #match_percentage}
 * (used as a fraction) of the detections assigned to the test track are also assigned to that
 * reference track.
 * <p>
 * Each reference track is classified by its number of matching test tracks: none, exactly one, or
 * several. Test tracks that match no reference track form their own group. The six resulting track
 * groups are published as outputs, with a one-row summary workbook and a per-reference-track
 * details workbook.
 */
public class TracksComparisonBlock extends Plugin implements Block {
	/** Name of the sheet written in both output workbooks. */
	private static final String SHEET_NAME = "Results";

	public VarSequence sequence = new VarSequence("sequence",null);
	public Var<TrackGroup> tracks_ref = new Var<TrackGroup>("Reference Tracks",new TrackGroup(null));
	public Var<TrackGroup> tracks_test = new Var<TrackGroup>("Test Tracks",new TrackGroup(null));
	public Var<DetectionResult> detections = new Var<DetectionResult>("Detections", new DetectionResult());

	/** Matching threshold, used as a fraction of a test track's associated detections (default 0.9). */
	public VarDouble match_percentage = new VarDouble("% of detections that match reference tracks", 0.9);
	/** Detections farther than this (XY distance, exclusive) from every track are left unassigned. */
	public VarDouble max_distance = new VarDouble("Max. distance for detection association", 10.0);

	VarWorkbook                book              = new VarWorkbook("Workbook",(Workbook) null);	
	VarWorkbook                book_detail              = new VarWorkbook("Workbook (details)",(Workbook) null);

	public Var<TrackGroup> tracks_ref_match = new Var<TrackGroup>("Reference tracks (match)",new TrackGroup(null));
	public Var<TrackGroup> tracks_test_match = new Var<TrackGroup>("Test tracks (match)",new TrackGroup(null));
	public Var<TrackGroup> tracks_ref_no_match = new Var<TrackGroup>("Reference tracks (no match)",new TrackGroup(null));
	public Var<TrackGroup> tracks_test_no_match = new Var<TrackGroup>("Test tracks (no match)",new TrackGroup(null));
	public Var<TrackGroup> tracks_ref_multiple = new Var<TrackGroup>("Reference tracks (multiple match)",new TrackGroup(null));
	public Var<TrackGroup> tracks_test_multiple = new Var<TrackGroup>("Test tracks (multiple match)",new TrackGroup(null));

	/**
	 * Declares the block inputs: the sequence, both track groups, the detections and the two matching
	 * parameters.
	 */
	@Override
	public void declareInput(VarList inputMap) {
		inputMap.add("Sequence",sequence);
		inputMap.add("Reference tracks",tracks_ref);
		inputMap.add("Test tracks",tracks_test);
		inputMap.add("Detections",detections);
		// Exposed so the matching threshold and association distance can be tuned from a protocol
		// instead of being stuck at their default values.
		inputMap.add("Match percentage",match_percentage);
		inputMap.add("Max distance",max_distance);
	}

	/** Declares the six classified track groups and the two result workbooks as outputs. */
	@Override
	public void declareOutput(VarList outputMap) {
		outputMap.add("Reference tracks (match)",tracks_ref_match);
		outputMap.add("Reference tracks (no match)",tracks_ref_no_match);
		outputMap.add("Reference tracks (multiple match)",tracks_ref_multiple);
		outputMap.add("Test tracks (match)",tracks_test_match);
		outputMap.add("Test tracks (no match)",tracks_test_no_match);
		outputMap.add("Test tracks (multiple match)",tracks_test_multiple);		
		outputMap.add("Workbook",book);
		outputMap.add("Workbook (details)",book_detail);
	}

	/**
	 * Runs the comparison and fills every output variable.
	 * <p>
	 * A test track that matches several reference tracks is copied once for each of them into the
	 * corresponding output groups. All output segments are fresh {@link TrackSegment}s holding a copy
	 * of the original detection lists.
	 *
	 * @throws IllegalArgumentException if the sequence, either track group or the detections are
	 *         missing
	 */
	@Override
	public void run() {
		Sequence seq = sequence.getValue();
		if (seq == null) {
			throw new IllegalArgumentException("TracksComparisonBlock: no input sequence provided");
		}
		TrackGroup ref = tracks_ref.getValue();
		TrackGroup test = tracks_test.getValue();
		DetectionResult detectionResult = detections.getValue();
		if (ref == null || test == null || detectionResult == null) {
			throw new IllegalArgumentException(
					"TracksComparisonBlock: reference tracks, test tracks and detections are required");
		}

		// Output track groups, one per matching category.
		TrackGroup matchedTest = newGroup(seq, "-test-matched");
		TrackGroup matchedRef = newGroup(seq, "-ref-matched");
		TrackGroup unmatchedRef = newGroup(seq, "-ref-unmatched");
		TrackGroup unmatchedTest = newGroup(seq, "-test-unmatched");
		TrackGroup manyMatchedRef = newGroup(seq, "-ref-many-matched");
		TrackGroup manyMatchedTest = newGroup(seq, "-test-many-matched");

		List<TrackSegment> tsRef = ref.getTrackSegmentList();
		List<TrackSegment> tsTest = test.getTrackSegmentList();

		// Rebuild one Detection per input spot (at its mass center) for every time point of the sequence.
		List<Detection> allDetections = new ArrayList<Detection>();
		for (int t = 0; t < seq.getSizeT(); t++) {
			for (Spot s : detectionResult.getDetectionsAtT(t)) {
				allDetections.add(new Detection(s.mass_center.x, s.mass_center.y, s.mass_center.z, t));
			}
		}

		// Assign every detection to its closest reference track and, independently, to its closest test track.
		double maxDistance = max_distance.getValue();
		Map<TrackSegment, List<Detection>> refDetectionsByTrack = associateDetections(tsRef, allDetections, maxDistance);
		Map<TrackSegment, List<Detection>> testDetectionsByTrack = associateDetections(tsTest, allDetections, maxDistance);

		// For each reference track, collect the test tracks that share enough detections with it.
		// LinkedHashMap keeps the reference-track order, so output groups are built deterministically.
		double threshold = match_percentage.getValue();
		Map<TrackSegment, List<TrackSegment>> matchingTestToRef = new LinkedHashMap<TrackSegment, List<TrackSegment>>();
		for (TrackSegment tsr : tsRef) {
			matchingTestToRef.put(tsr, new ArrayList<TrackSegment>());
		}
		// Largest number of detections each reference track shares with any single test track.
		Map<TrackSegment, Integer> refMaxSharedDetections = new HashMap<TrackSegment, Integer>();
		// Test tracks matched to at least one reference track; the rest are reported as unmatched.
		Set<TrackSegment> matchedTestTracks = new HashSet<TrackSegment>();

		for (TrackSegment tsr : tsRef) {
			List<Detection> refDetections = refDetectionsByTrack.get(tsr);
			int maxShared = 0;
			for (TrackSegment tst : tsTest) {
				List<Detection> testDetections = testDetectionsByTrack.get(tst);
				int shared = countShared(testDetections, refDetections);
				if (shared > maxShared) {
					maxShared = shared;
				}
				if (shared > threshold * testDetections.size()) {
					matchingTestToRef.get(tsr).add(tst);
					matchedTestTracks.add(tst);
				}
			}
			refMaxSharedDetections.put(tsr, maxShared);
		}

		// Classify reference tracks by their number of matching test tracks (exactly one defines a match).
		for (Map.Entry<TrackSegment, List<TrackSegment>> entry : matchingTestToRef.entrySet()) {
			TrackSegment tsr = entry.getKey();
			List<TrackSegment> matches = entry.getValue();
			if (matches.isEmpty()) {
				unmatchedRef.addTrackSegment(copySegment(tsr));
			} else if (matches.size() == 1) {
				matchedRef.addTrackSegment(copySegment(tsr));
				matchedTest.addTrackSegment(copySegment(matches.get(0)));
			} else {
				manyMatchedRef.addTrackSegment(copySegment(tsr));
				for (TrackSegment tst : matches) {
					manyMatchedTest.addTrackSegment(copySegment(tst));
				}
			}
		}

		// Test tracks that matched no reference track at all.
		for (TrackSegment tst : tsTest) {
			if (!matchedTestTracks.contains(tst)) {
				unmatchedTest.addTrackSegment(copySegment(tst));
			}
		}

		tracks_ref_match.setValue(matchedRef);
		tracks_ref_no_match.setValue(unmatchedRef);
		tracks_ref_multiple.setValue(manyMatchedRef);
		tracks_test_match.setValue(matchedTest);
		tracks_test_no_match.setValue(unmatchedTest);
		tracks_test_multiple.setValue(manyMatchedTest);

		// Summary workbook: a header row plus one row of group sizes.
		Sheet sheet = getClearedSheet(prepareWorkbook(book), SHEET_NAME);
		writeHeader(sheet, "nb ref match", "nb ref no match", "nb ref multiple match",
				"nb test match", "nb test no match", "nb test multiple match");
		writeNumbers(sheet, 1,
				matchedRef.getTrackSegmentList().size(),
				unmatchedRef.getTrackSegmentList().size(),
				manyMatchedRef.getTrackSegmentList().size(),
				matchedTest.getTrackSegmentList().size(),
				unmatchedTest.getTrackSegmentList().size(),
				manyMatchedTest.getTrackSegmentList().size());

		// Details workbook: one row per reference track.
		// NOTE(review): the third column is titled "% of match" but holds the largest number of
		// detections this reference track shares with a single test track (a count, not a
		// percentage) - confirm which one is intended.
		Sheet sheetDetail = getClearedSheet(prepareWorkbook(book_detail), SHEET_NAME);
		writeHeader(sheetDetail, "nb ref track", "length", "% of match");
		int rowDetail = 1;
		for (TrackSegment ts : tsRef) {
			writeNumbers(sheetDetail, rowDetail++,
					ts.getId(),
					ts.getDetectionList().size(),
					refMaxSharedDetections.get(ts).intValue());
		}
	}

	/** Creates an empty track group for {@code seq} whose description is the sequence name plus {@code suffix}. */
	private static TrackGroup newGroup(Sequence seq, String suffix) {
		TrackGroup group = new TrackGroup(seq);
		group.setDescription(seq.getName() + suffix);
		return group;
	}

	/**
	 * Assigns each detection to the closest track that has a detection at the same time point.
	 * Only the XY distance is used. A detection is assigned only if that distance is strictly below
	 * {@code maxDistance}.
	 *
	 * @return a map containing every track of {@code tracks} (possibly with an empty list) to the
	 *         detections assigned to it
	 */
	private static Map<TrackSegment, List<Detection>> associateDetections(List<TrackSegment> tracks,
			List<Detection> detectionList, double maxDistance) {
		Map<TrackSegment, List<Detection>> result = new HashMap<TrackSegment, List<Detection>>();
		for (TrackSegment ts : tracks) {
			result.put(ts, new ArrayList<Detection>());
		}
		for (Detection d : detectionList) {
			TrackSegment closest = null;
			double minDist = maxDistance;
			int t = d.getT();
			for (TrackSegment ts : tracks) {
				Detection onTrack = ts.getDetectionAtTime(t);
				if (onTrack != null) {
					double dist = distance2D(onTrack, d);
					if (dist < minDist) {
						closest = ts;
						minDist = dist;
					}
				}
			}
			// closest is only set when some distance was strictly below maxDistance.
			if (closest != null) {
				result.get(closest).add(d);
			}
		}
		return result;
	}

	/** Euclidean distance between two detections in the XY plane (Z is ignored). */
	private static double distance2D(Detection a, Detection b) {
		double dx = a.getX() - b.getX();
		double dy = a.getY() - b.getY();
		return Math.sqrt(dx * dx + dy * dy);
	}

	/** Counts the elements of {@code candidates} that are contained (per {@code equals}) in {@code pool}. */
	private static int countShared(List<Detection> candidates, List<Detection> pool) {
		int count = 0;
		for (Detection d : candidates) {
			if (pool.contains(d)) {
				count++;
			}
		}
		return count;
	}

	/** Returns a new track segment holding a shallow copy of {@code ts}'s detection list. */
	private static TrackSegment copySegment(TrackSegment ts) {
		return new TrackSegment(new ArrayList<Detection>(ts.getDetectionList()));
	}

	/**
	 * Returns the workbook held by {@code var}, creating an HSSF workbook first if it is empty. The
	 * missing-cell policy is set to create blank cells on access.
	 */
	private static Workbook prepareWorkbook(VarWorkbook var) {
		if (var.getValue() == null) {
			var.setValue(new HSSFWorkbook());
		}
		Workbook wb = var.getValue();
		wb.setMissingCellPolicy(Row.CREATE_NULL_AS_BLANK);
		return wb;
	}

	/**
	 * Returns the sheet named {@code name}, creating it if needed. Every existing row is removed so
	 * that rows written by a previous run cannot linger.
	 */
	private static Sheet getClearedSheet(Workbook wb, String name) {
		Sheet sheet = wb.getSheet(name);
		if (sheet == null) {
			return wb.createSheet(name);
		}
		for (int i = sheet.getLastRowNum(); i >= 0; i--) {
			Row r = sheet.getRow(i);
			if (r != null) {
				sheet.removeRow(r);
			}
		}
		return sheet;
	}

	/** Writes {@code titles} into row 0 of {@code sheet}, one per column. */
	private static void writeHeader(Sheet sheet, String... titles) {
		Row header = sheet.createRow(0);
		for (int i = 0; i < titles.length; i++) {
			header.createCell(i).setCellValue(titles[i]);
		}
	}

	/** Writes {@code values} into row {@code rowIndex} of {@code sheet}, one per column. */
	private static void writeNumbers(Sheet sheet, int rowIndex, double... values) {
		Row row = sheet.createRow(rowIndex);
		for (int i = 0; i < values.length; i++) {
			row.createCell(i).setCellValue(values[i]);
		}
	}

	/**
	 * Wraps {@code trackGroup} in a {@link SwimmingObject} and adds it to Icy's swimming pool. The add
	 * is scheduled on the Swing event thread.
	 *
	 * @param trackGroup the track group to publish
	 * @param sequence currently unused; kept for API compatibility
	 */
	public static void sendTracksToPool(final TrackGroup trackGroup,
			final Sequence sequence) {
		SwingUtilities.invokeLater(new Runnable()
		{
			@Override
			public void run() {
				SwimmingObject result = new SwimmingObject(trackGroup);
				Icy.getMainInterface().getSwimmingPool().add(result);
			}
		});
	}

	/**
	 * Finds the reference track that is closest to the largest number of detections of {@code tst}.
	 * <p>
	 * For each detection of {@code tst}, the closest key of {@code tracks_w_detections} that has a
	 * detection at the same time point is found (XY distance). The comparison uses the squared
	 * distance against a cutoff of 100000, i.e. a distance of about 316. The track hit most often
	 * wins. When several tracks tie, the winner depends on {@link HashMap} iteration order.
	 *
	 * @return the best reference track, or {@code null} if no detection of {@code tst} had a candidate
	 */
	public static TrackSegment search_for_closest_ref_track(TrackSegment tst,HashMap<TrackSegment, ArrayList<Detection>> tracks_w_detections)
	{
		HashMap<TrackSegment,Integer> matchCounts = new HashMap<TrackSegment,Integer>();
		for (TrackSegment tsr : tracks_w_detections.keySet()) {
			matchCounts.put(tsr, 0);
		}
		for (Detection d : tst.getDetectionList()) {
			double minSquaredDist = 100000;
			TrackSegment closest = null;
			for (TrackSegment tsr : tracks_w_detections.keySet()) {
				Detection onTrack = tsr.getDetectionAtTime(d.getT());
				if (onTrack != null) {
					double squaredDist = Math.pow(d.getX() - onTrack.getX(), 2) + Math.pow(d.getY() - onTrack.getY(), 2);
					if (squaredDist < minSquaredDist) {
						minSquaredDist = squaredDist;
						closest = tsr;
					}
				}
			}
			if (closest != null) {
				matchCounts.put(closest, matchCounts.get(closest) + 1);
			}
		}
		// Pick the reference track with the most hits (strictly more than 0).
		int bestCount = 0;
		TrackSegment best = null;
		for (Map.Entry<TrackSegment, Integer> entry : matchCounts.entrySet()) {
			if (entry.getValue() > bestCount) {
				bestCount = entry.getValue();
				best = entry.getKey();
			}
		}
		return best;
	}

	/**
	 * Returns the 2D mass center of {@code roi}.
	 * <p>
	 * For a {@link ROI2DPoint}, returns the origin of its bounds. Otherwise, returns the mean position
	 * of the pixels set in its boolean mask, offset by the ROI's 2D position. An empty mask yields NaN
	 * coordinates (0/0 division).
	 */
	public static Point2D getMassCenter2D(ROI2D roi)
	{
		if (roi instanceof ROI2DPoint)
		{
			return new Point2D.Double(roi.getBounds2D().getX(),roi.getBounds2D().getY());
		}
		double x = 0, y = 0;
		long len = 0;

		final BooleanMask2D mask = roi.getBooleanMask(true);
		final boolean m[] = mask.mask;
		final int h = mask.bounds.height;
		final int w = mask.bounds.width;

		// The mask is scanned row by row (width w), accumulating coordinates of set pixels.
		int off = 0;
		for (int j = 0; j < h; j++)
		{
			for (int i = 0; i < w; i++)
			{
				if (m[off++])
				{
					x += i;
					y += j;
					len++;
				}
			}
		}

		final Point2D pos2d = roi.getPosition2D();
		return new Point2D.Double(pos2d.getX() + (x / len), pos2d.getY() + (y / len));
	}
}