package plugins.lagache.matchtracks;


import plugins.adufour.blocks.lang.Block;
import plugins.adufour.blocks.util.VarList;
import plugins.adufour.vars.lang.Var;
import plugins.adufour.vars.lang.VarBoolean;
import plugins.adufour.vars.lang.VarDouble;
import plugins.adufour.vars.lang.VarInteger;
import plugins.fab.trackmanager.TrackGroup;
import plugins.fab.trackmanager.TrackSegment;
import plugins.lagache.sodasuite.SODAsuite;
import plugins.nchenouard.spot.Detection;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import icy.plugin.abstract_.Plugin;
import icy.plugin.interface_.PluginBundled;
import icy.system.profile.Chronometer;

/**
 * Icy "Blocks" wrapper around the EMC2 track-linking procedure.
 * <p>
 * The block takes a group of (possibly fragmented) tracks, propagates track ends forward and track
 * starts backward in time (optionally with a thin-plate-spline motion model estimated from the tracks
 * themselves), links fragments with a sparse linear assignment and outputs the concatenated tracks,
 * filling gaps with virtual detections.
 * <p>
 * NOTE(review): the static configuration fields below are shared by every instance of the block, so
 * running several blocks concurrently with different settings is not safe.
 */
public class EMC2Block extends Plugin implements Block, PluginBundled {
	/** Tracks to be linked. */
	public Var<TrackGroup> tracks_in = new Var<TrackGroup>("Tracks (input)", new TrackGroup(null));
	/** Linked tracks produced by {@link #run()}. */
	public Var<TrackGroup> tracks_out = new Var<TrackGroup>("Tracks (output)", new TrackGroup(null));
	/** Upper bound on the number of detection pairs used to fit each frame-to-frame motion model. */
	public VarInteger max_nb_fiducials = new VarInteger("Max. number of fiducials", 100);
	/** Maximum distance for track concatenation (forwarded to the linking cost matrix). */
	public VarDouble Threshold = new VarDouble("Max. distance for track concatenation", 10.0);
	/** Maximum time window for track concatenation (forwarded to the linking cost matrix). */
	public VarInteger time_window = new VarInteger("Max. time window for track concatenation", 200);
	/** Alternative cost factor forwarded to the linking cost matrix (JV linear association). */
	public VarDouble cost_factor = new VarDouble("Alternative cost factor (for JV linear association)", 1.05);
	/** When true, track ends/starts are carried over frames unchanged (no motion correction). */
	public VarBoolean nomotion = new VarBoolean("No motion correction", false);

	@Override
	public void declareInput(VarList inputMap) {
		inputMap.add("Tracks (input)", tracks_in);
		inputMap.add("Max. number of fiducials", max_nb_fiducials);
		inputMap.add("Max. distance for tracks linking", Threshold);
		inputMap.add("Max. time window for tracks linking", time_window);
		inputMap.add("Alternative cost factor", cost_factor);
		inputMap.add("No motion correction for global linking", nomotion);
	}

	@Override
	public void declareOutput(VarList outputMap) {
		outputMap.add("Tracks (output)", tracks_out);
	}

	/**
	 * Maximum number of fiducials used by {@link #forward_end_computation} and
	 * {@link #backward_start_computation}; overwritten by {@link #run()}.
	 */
	static Integer nb_fiducials = 300;
	/** Alternative cost factor forwarded to the cost-matrix builder; overwritten by {@link #run()}. */
	static double alternative_cost_factor=1.05d;
	/** Percentile forwarded to the cost-matrix builder (not exposed as a block input). */
	static double percentile=0.9;

	/**
	 * Reads the block inputs, links the input tracks and publishes the result on {@link #tracks_out}.
	 * Also stores the fiducial count and cost factor in the shared static configuration.
	 */
	@Override
	public void run() {
		alternative_cost_factor = cost_factor.getValue();
		Chronometer chrono = new Chronometer("Chrono");

		nb_fiducials = max_nb_fiducials.getValue();
		TrackGroup group = tracks_in.getValue();
		ArrayList<TrackSegment> projected_tracks = project_tracks(nb_fiducials, Threshold.getValue(),
				time_window.getValue(), nomotion.getValue(), group);

		TrackGroup TG = new TrackGroup(null);
		TG.setDescription("linked tracks");
		for (TrackSegment ts : projected_tracks) {
			TG.addTrackSegment(ts);
		}
		tracks_out.setValue(TG);

		System.out.println("Time for track projection block");
		chrono.displayInSeconds();
	}

	/**
	 * Returns the smallest distance, over the frames separating the two tracks, between the
	 * forward-propagated end of {@code ts1} and the backward-propagated start of {@code ts2}.
	 *
	 * @param ts1            track whose end is considered
	 * @param ts2            track whose start is considered
	 * @param backward_start backward-propagated track starts (see {@link #backward_start_computation})
	 * @param forward_end    forward-propagated track ends (see {@link #forward_end_computation})
	 * @param threshold      distances greater than or equal to this value are rejected
	 * @param numT           number of frames, used to index {@code backward_start}
	 * @return the minimum distance when it is below {@code threshold}, otherwise
	 *         {@link Double#MAX_VALUE}. {@code MAX_VALUE} is also returned when {@code ts1} ends
	 *         after {@code ts2} starts, or at the very frame {@code ts2} starts (no frame compared).
	 */
	public static double compute_shortest_distance(TrackSegment ts1,TrackSegment ts2,ArrayList<HashMap<TrackSegment,double[]>> backward_start,ArrayList<HashMap<TrackSegment,double[]>> forward_end,double threshold,int numT) {
		int last_time_1 = ts1.getLastDetection().getT();
		int ini_time_2 = ts2.getFirstDetection().getT();
		if (last_time_1 > ini_time_2) {
			return Double.MAX_VALUE;
		}
		// NOTE(review): both lists are indexed with absolute frame numbers here (forward_end.get(t),
		// backward_start.get(numT-1-t)), whereas createVirtualDetections indexes forward_end with
		// t - initial_time. The two only agree when initial_time == 0 — confirm against the callers.
		double min_distance = threshold + 1;
		for (int t = last_time_1; t < ini_time_2; t++) {
			double[] pos_last_t = forward_end.get(t).get(ts1);
			double[] pos_ini_t = backward_start.get(numT - 1 - t).get(ts2);
			double temp = Math.sqrt(Math.pow(pos_last_t[0] - pos_ini_t[0], 2) + Math.pow(pos_last_t[1] - pos_ini_t[1], 2));
			if (temp < min_distance) {
				min_distance = temp;
			}
		}
		return (min_distance < threshold) ? min_distance : Double.MAX_VALUE;
	}

	/**
	 * Reverse lookup: returns the first key (in the map's iteration order) whose value equals
	 * {@code value}, or {@code null} if there is none. Null values in the map are tolerated.
	 */
	public static Integer getKeyFromValue(Map<Integer,TrackSegment> hm, TrackSegment value) {
		for (Map.Entry<Integer, TrackSegment> entry : hm.entrySet()) {
			if (Objects.equals(entry.getValue(), value)) {
				return entry.getKey();
			}
		}
		return null;
	}

	/**
	 * Forward propagation of track end points over [initial_time, last_time].
	 * Entry {@code k} of the result holds positions at frame {@code initial_time + k}.
	 * Uses the shared static {@link #nb_fiducials} as the fiducial budget.
	 */
	public static ArrayList<HashMap<TrackSegment,double[]>> forward_end_computation(int initial_time,int last_time, ArrayList<TrackSegment> ts_liste, boolean nomotion) {
		return propagate_track_ends(initial_time, last_time, ts_liste, nomotion, nb_fiducials, true);
	}

	/**
	 * Backward propagation of track start points over [initial_time, last_time].
	 * Entry {@code k} of the result holds positions at frame {@code last_time - k}.
	 * Uses the shared static {@link #nb_fiducials} as the fiducial budget.
	 */
	public static ArrayList<HashMap<TrackSegment,double[]>> backward_start_computation(int initial_time,int last_time, ArrayList<TrackSegment> ts_liste,boolean nomotion) {
		return propagate_track_ends(initial_time, last_time, ts_liste, nomotion, nb_fiducials, false);
	}

	/**
	 * Propagates track end points forward in time ({@code forward == true}) or track start points
	 * backward in time ({@code forward == false}) over [initial_time, last_time].
	 * <p>
	 * One map is produced per frame of the pass. A track enters the maps at the frame of its last
	 * (forward) / first (backward) detection, at that detection's position, and is then carried to
	 * every following frame of the pass: unchanged when {@code nomotion} is true, otherwise through a
	 * thin-plate spline fitted on detections the tracks contain at both consecutive frames.
	 */
	private static ArrayList<HashMap<TrackSegment, double[]>> propagate_track_ends(int initial_time, int last_time,
			List<TrackSegment> ts_liste, boolean nomotion, int max_fiducials, boolean forward) {
		ArrayList<HashMap<TrackSegment, double[]>> frames = new ArrayList<HashMap<TrackSegment, double[]>>();
		int step = forward ? 1 : -1;
		int nb_frames = last_time - initial_time + 1;
		for (int k = 0; k < nb_frames; k++) {
			int t = forward ? initial_time + k : last_time - k;
			HashMap<TrackSegment, double[]> current = new HashMap<TrackSegment, double[]>();

			// Tracks ending (forward) / starting (backward) at t are anchored at their real detection.
			for (TrackSegment ts : ts_liste) {
				Detection anchor = forward ? ts.getLastDetection() : ts.getFirstDetection();
				if (anchor.getT() == t) {
					current.put(ts, new double[] { anchor.getX(), anchor.getY() });
				}
			}

			// Carry the previous frame's positions to t. The check is "not empty" rather than
			// "size > 1": otherwise the tracks anchored at the first frame of the pass were lost.
			if (!frames.isEmpty()) {
				HashMap<TrackSegment, double[]> previous = frames.get(frames.size() - 1);
				if (nomotion) {
					current.putAll(previous);
				} else if (!previous.isEmpty()) {
					// The previous map holds positions at frame t - step, so the motion model must map
					// frame (t - step) onto frame t.
					ThinPlateR2LogRSplineKernelTransform tps = fit_motion_model(ts_liste, t - step, t, max_fiducials);
					for (Map.Entry<TrackSegment, double[]> entry : previous.entrySet()) {
						current.put(entry.getKey(), tps.apply(entry.getValue()));
					}
				}
			}
			frames.add(current);
		}
		return frames;
	}

	/**
	 * Fits a 2D thin-plate spline mapping positions at {@code from_t} onto positions at {@code to_t},
	 * using up to {@code max_fiducials} tracks that have a detection at both frames. Returns the
	 * default (identity) transform when no such pair is available.
	 */
	private static ThinPlateR2LogRSplineKernelTransform fit_motion_model(List<TrackSegment> ts_liste, int from_t,
			int to_t, int max_fiducials) {
		HashMap<Detection, Detection> pairs = new HashMap<Detection, Detection>();
		for (TrackSegment ts : ts_liste) {
			Detection src = ts.getDetectionAtTime(from_t);
			Detection tgt = ts.getDetectionAtTime(to_t);
			if (src != null && tgt != null) {
				pairs.put(src, tgt);
			}
		}
		int size = Math.min(max_fiducials, pairs.size());
		if (size <= 0) {
			return new ThinPlateR2LogRSplineKernelTransform(2);
		}
		double[][] srcPts = new double[2][size];
		double[][] tgtPts = new double[2][size];
		int i = 0;
		for (Map.Entry<Detection, Detection> pair : pairs.entrySet()) {
			if (i >= size) {
				break;
			}
			srcPts[0][i] = pair.getKey().getX();
			srcPts[1][i] = pair.getKey().getY();
			tgtPts[0][i] = pair.getValue().getX();
			tgtPts[1][i] = pair.getValue().getY();
			i++;
		}
		ThinPlateR2LogRSplineKernelTransform tps = new ThinPlateR2LogRSplineKernelTransform(2, srcPts, tgtPts);
		tps.solve();
		return tps;
	}

	/**
	 * Links the tracks of {@code group} and returns the concatenated tracks.
	 * <p>
	 * Side effect: the id of every input track segment is overwritten with its index in the group.
	 *
	 * @param nb_fiducials maximum number of fiducials per frame-to-frame motion model
	 * @param threshold    maximum linking distance
	 * @param time_window  maximum time window for linking
	 * @param nomotion     disable the motion correction
	 * @param group        input tracks
	 * @return one track per chain of linked input tracks, with gaps filled by virtual detections
	 */
	public static ArrayList<TrackSegment> project_tracks(int nb_fiducials, double threshold, Integer time_window, Boolean nomotion, TrackGroup group) {
		ArrayList<TrackSegment> ts_liste = group.getTrackSegmentList();
		ArrayList<TrackSegment> projected_tracks = new ArrayList<TrackSegment>();
		if (ts_liste.isEmpty()) {
			// Nothing to link; avoids handing an inverted time range to the linking machinery.
			return projected_tracks;
		}

		// Time span covered by the tracks; each track gets its index as id.
		int last_time = 0;
		int initial_time = Integer.MAX_VALUE;
		for (int j = 0; j < ts_liste.size(); j++) {
			TrackSegment ts = ts_liste.get(j);
			ts.setId(j);
			last_time = Math.max(last_time, ts.getLastDetection().getT());
			initial_time = Math.min(initial_time, ts.getFirstDetection().getT());
		}

		// The nb_fiducials parameter is honoured here (previously the static field was used instead).
		boolean no_motion = nomotion;
		ArrayList<HashMap<TrackSegment, double[]>> forward_end =
				propagate_track_ends(initial_time, last_time, ts_liste, no_motion, nb_fiducials, true);
		ArrayList<HashMap<TrackSegment, double[]>> backward_start =
				propagate_track_ends(initial_time, last_time, ts_liste, no_motion, nb_fiducials, false);

		SparseLinkingCostMatrixCreator SLCMC = new SparseLinkingCostMatrixCreator(ts_liste, ts_liste, backward_start, forward_end, threshold, time_window, alternative_cost_factor, percentile, initial_time, last_time);
		SLCMC.process();
		SparseLinker SL = new SparseLinker(SLCMC);
		SL.process();
		Map<TrackSegment, TrackSegment> M = SL.getResult();

		// Concatenate every chain of linked tracks, bridging gaps with virtual detections.
		for (ArrayList<TrackSegment> chain : chain_linked_tracks(ts_liste, M)) {
			ArrayList<Detection> liste_globale = new ArrayList<Detection>();
			for (int k = 0; k < chain.size(); k++) {
				TrackSegment ts1 = chain.get(k);
				liste_globale.addAll(ts1.getDetectionList());
				if (k + 1 < chain.size()) {
					liste_globale.addAll(createVirtualDetections(ts1, chain.get(k + 1), forward_end, initial_time));
				}
			}
			projected_tracks.add(new TrackSegment(liste_globale));
		}
		return projected_tracks;
	}

	/**
	 * Groups tracks into chains following the links {@code previous -> next}.
	 * <p>
	 * Tracks are visited by increasing start frame (ties keep {@code ts_liste} order). A track is
	 * appended to its predecessor's chain when the predecessor ends strictly before it starts;
	 * otherwise it opens a new chain. Previously such tracks were silently dropped.
	 */
	private static ArrayList<ArrayList<TrackSegment>> chain_linked_tracks(List<TrackSegment> ts_liste,
			Map<TrackSegment, TrackSegment> links) {
		// Predecessor of each track: first candidate in ts_liste order, compared by reference.
		IdentityHashMap<TrackSegment, TrackSegment> predecessor = new IdentityHashMap<TrackSegment, TrackSegment>();
		for (TrackSegment candidate : ts_liste) {
			TrackSegment next = links.get(candidate);
			if (next != null && !predecessor.containsKey(next)) {
				predecessor.put(next, candidate);
			}
		}

		ArrayList<TrackSegment> by_start = new ArrayList<TrackSegment>(ts_liste);
		// Collections.sort is stable, so tracks starting at the same frame keep their input order.
		Collections.sort(by_start, new Comparator<TrackSegment>() {
			@Override
			public int compare(TrackSegment a, TrackSegment b) {
				return Integer.compare(a.getFirstDetection().getT(), b.getFirstDetection().getT());
			}
		});

		ArrayList<ArrayList<TrackSegment>> chains = new ArrayList<ArrayList<TrackSegment>>();
		IdentityHashMap<TrackSegment, ArrayList<TrackSegment>> chain_of = new IdentityHashMap<TrackSegment, ArrayList<TrackSegment>>();
		for (TrackSegment ts : by_start) {
			TrackSegment prev = predecessor.get(ts);
			ArrayList<TrackSegment> chain = (prev == null) ? null : chain_of.get(prev);
			if (chain != null && prev.getLastDetection().getT() < ts.getFirstDetection().getT()) {
				chain.add(ts);
			} else {
				chain = new ArrayList<TrackSegment>();
				chain.add(ts);
				chains.add(chain);
			}
			chain_of.put(ts, chain);
		}
		return chains;
	}

	/**
	 * Creates one virtual detection per frame strictly between the last detection of {@code ts1} and
	 * the first detection of {@code ts2}, placed at the forward-propagated end position of
	 * {@code ts1} for that frame.
	 *
	 * @param forward_end  forward-propagated ends, indexed by {@code t - initial_time}
	 * @param initial_time first frame covered by {@code forward_end}
	 * @return the virtual detections in increasing time order (empty when the tracks are adjacent)
	 */
	public static ArrayList<Detection> createVirtualDetections(TrackSegment ts1,TrackSegment ts2,ArrayList<HashMap<TrackSegment,double[]>> forward_end,int initial_time) {
		ArrayList<Detection> vts = new ArrayList<Detection>();
		int first_gap_frame = ts1.getLastDetection().getT() + 1;
		int next_start = ts2.getFirstDetection().getT();
		for (int t = first_gap_frame; t < next_start; t++) {
			double[] pos_last_t = forward_end.get(t - initial_time).get(ts1);
			Detection detect = new Detection(pos_last_t[0], pos_last_t[1], 0, t);
			detect.setDetectionType(Detection.DETECTIONTYPE_VIRTUAL_DETECTION);
			vts.add(detect);
		}
		return vts;
	}

	@Override
	public String getMainPluginClassName() {
		return EMC2.class.getName();
	}
}
