package plugins.dmandache.denoise;

import java.util.Random;
import edu.emory.mathcs.jtransforms.fft.DoubleFFT_2D;

/**
 * FFT, IFFT and sampling methods.
 * <p>
 * Images, spectra and masks are flat, row-major arrays: the element at column
 * {@code x} and row {@code y} is stored at index {@code x + w*y}.
 *
 * @author Diana Mandache
 * @version 1.0
 * @date 2017-02-20
 * @license gpl v3.0
 */
public class FFTwrapper{

	/** Value of {@code samplingMethod} that selects {@link #FFTmask_cutoff}. */
	private static final String METHOD_CUTOFF = "Cutoff Frequency";

	/** Value of {@code samplingMethod} that selects {@link #FFTmask_gaussian}. */
	private static final String METHOD_GAUSSIAN = "Gaussian";

	/** Value of {@code samplingMethod} that selects {@link #FFTmask_random}. */
	private static final String METHOD_RANDOM = "Random";

	/**
	 * Compute the fast Fourier transform (FFT) of a real image and multiply
	 * the resulting coefficients by a sampling mask.
	 * <p>
	 * The input array is copied, so {@code in} is left untouched. The mask is
	 * applied in the layout produced by the transform (no quadrant swap is
	 * performed here).
	 *
	 * @param in image (real values, row-major)
	 * @param mask sampling mask, multiplied element-wise with the coefficients
	 * @param w width of the image and of the mask
	 * @param h height of the image and of the mask
	 * @return masked FFT of the image: {@code out[0]} holds the real parts and
	 *         {@code out[1]} the imaginary parts, each of length {@code w*h}
	 * */
	public static double[][] FFT_2D_with_mask(double[] in, double[] mask, int w, int h) 
	{
		final int size = w * h;

		// realForwardFull works in place and needs room for the interleaved
		// (re, im) result, i.e. twice the number of pixels.
		double[] spectrum = new double[2 * size];
		System.arraycopy(in, 0, spectrum, 0, in.length);

		final DoubleFFT_2D fft = new DoubleFFT_2D(h, w);
		fft.realForwardFull(spectrum);

		// Split the interleaved result into separate real / imaginary planes.
		double[][] out = new double[2][size];
		for (int i = 0; i < size; i++) {
			out[0][i] = spectrum[2*i];
			out[1][i] = spectrum[2*i + 1];
		}

		for (int k = 0; k < mask.length; k++) {
			out[0][k] *= mask[k];
			out[1][k] *= mask[k];
		}

		return out;
	}

	/**
	 * Compute the inverse fast Fourier transform (IFFT) and keep the real part
	 * of the result.
	 *
	 * @param in FFT of an image: {@code in[0]} real parts, {@code in[1]}
	 *           imaginary parts (row-major, length {@code w*h} each)
	 * @param w width of the image
	 * @param h height of the image
	 * @return real part of the scaled inverse transform, row-major, length {@code w*h}
	 * */
	public static double[] IFFT_2D(double[][] in, int w, int h){
		final DoubleFFT_2D fft = new DoubleFFT_2D(h, w);

		// Repack into the row-wise interleaved (re, im) layout expected by
		// complexInverse(double[][], boolean).
		double[][] rows = new double[h][2*w];
		for (int y = 0; y < h; y++) {
			final int rowStart = y * w;
			for (int x = 0; x < w; x++) {
				rows[y][2*x]     = in[0][rowStart + x]; // real
				rows[y][2*x + 1] = in[1][rowStart + x]; // imag
			}
		}

		fft.complexInverse(rows, true); // true: scale the result

		double[] out = new double[w*h];
		for (int y = 0; y < h; y++) {
			final int rowStart = y * w;
			for (int x = 0; x < w; x++) {
				out[rowStart + x] = rows[y][2*x];
			}
		}
		return out;
	}

	/**
	 * Compute the magnitude of a complex array.
	 *
	 * @param Re real parts
	 * @param Im imaginary parts (at least as long as {@code Re})
	 * @return element-wise {@code sqrt(Re^2 + Im^2)}, same length as {@code Re}
	 * */
	public static double[] Magnitude( double[] Re, double[] Im){
		double[] mag = new double[Re.length];
		for (int i = 0; i < Re.length; i++) {
			final double re = Re[i];
			final double im = Im[i];
			mag[i] = Math.sqrt(re*re + im*im);
		}
		return mag;
	}

	/**
	 * Creates a sampling mask by taking all low frequency coefficients
	 * (corresponding to frequencies smaller than or equal to a given cutoff
	 * frequency) and randomly sampling in the remaining space until the
	 * sub-sampling rate is reached.
	 * <p>
	 * Frequencies are normalized per axis as {@code (x - w/2) / (w/2)}, so the
	 * zero frequency sits at {@code (w/2, h/2)} while the mask is built. If
	 * the low frequency disc alone already holds more coefficients than the
	 * sampling rate asks for, no random coefficient is added.
	 *
	 * @param w width of the mask
	 * @param h height of the mask
	 * @param samplingRate sub-sampling rate (fraction of coefficients kept, clamped to [0, 1])
	 * @param cutOffFreq cut-off frequency, in normalized units
	 * @param swap {@code true} to return the mask centered (zero frequency at
	 *             {@code (w/2, h/2)}), {@code false} to return it with the
	 *             quadrants swapped (zero frequency at index 0)
	 * @return mask of length {@code w*h} holding 1 for sampled coefficients and 0 elsewhere
	 * */
	public static double[] FFTmask_cutoff(int w, int h, double samplingRate, double cutOffFreq, boolean swap){
		final int total = w * h;                                  // signal size
		final int samples = sampleCount(total, samplingRate);     // number of samples
		double[] mask = new double[total];

		final int x0 = w / 2;   // center point coords
		final int y0 = h / 2;

		// Sampling in low frequency (smaller than cut-off frequency) space
		int numLowCoeff = 0;
		for (int y = 0; y < h; y++) {
			final double fY = normalizedFrequency(y, y0);
			for (int x = 0; x < w; x++) {
				final double fX = normalizedFrequency(x, x0);
				if (Math.sqrt(fX*fX + fY*fY) <= cutOffFreq) {
					mask[x + w*y] = 1;
					numLowCoeff++;
				}
			}
		}

		// Random sampling in the remaining (high frequency) space
		fillUniform(mask, w, h, samples - numLowCoeff, new Random());

		return arrange(mask, w, h, swap);
	}

	/**
	 * Creates a sampling mask by choosing the coefficients randomly following
	 * a uniform distribution.
	 *
	 * @param w width of the mask
	 * @param h height of the mask
	 * @param samplingRate sub-sampling rate (fraction of coefficients kept, clamped to [0, 1])
	 * @param swap {@code true} to return the mask as generated, {@code false}
	 *             to return it with the quadrants swapped
	 * @return mask of length {@code w*h} holding 1 for sampled coefficients and 0 elsewhere
	 * */
	public static double[] FFTmask_random(int w, int h, double samplingRate, boolean swap){
		final int total = w * h;   // signal size
		double[] mask = new double[total];

		fillUniform(mask, w, h, sampleCount(total, samplingRate), new Random());

		return arrange(mask, w, h, swap);
	}

	/**
	 * Creates a sampling mask by choosing the coefficients randomly following
	 * a normal distribution (Gaussian sampling).
	 * <p>
	 * Positions are drawn per axis with mean {@code w/2} (resp. {@code h/2})
	 * and standard deviation {@code w/6} (resp. {@code h/6}); draws that fall
	 * outside the mask or on an already selected coefficient are rejected.
	 * Because cells far from the center are rarely drawn, generation becomes
	 * slow as the sampling rate approaches 1.
	 *
	 * @param w width of the mask
	 * @param h height of the mask
	 * @param samplingRate sub-sampling rate (fraction of coefficients kept, clamped to [0, 1])
	 * @param swap {@code true} to return the mask centered (densest around
	 *             {@code (w/2, h/2)}), {@code false} to return it with the
	 *             quadrants swapped (densest around index 0)
	 * @return mask of length {@code w*h} holding 1 for sampled coefficients and 0 elsewhere
	 * */
	public static double[] FFTmask_gaussian(int w, int h, double samplingRate, boolean swap){
		final int total = w * h;   // signal size
		double[] mask = new double[total];

		// The mask starts empty and the count is clamped to its size, so the
		// rejection loop below always has a free cell left to hit.
		int remaining = sampleCount(total, samplingRate);

		final Random position = new Random();
		while (remaining > 0) {
			// abs() folds negative draws back onto the positive side
			final int x = Math.abs((int)(position.nextGaussian() * w/6 + w/2));
			final int y = Math.abs((int)(position.nextGaussian() * h/6 + h/2));

			if (x < w && y < h && mask[x + w*y] == 0) {
				mask[x + w*y] = 1;
				remaining--;
			}
		}

		return arrange(mask, w, h, swap);
	}

	/**
	 * Creates a sampling mask with the requested sampling method.
	 *
	 * @param w width of the mask
	 * @param h height of the mask
	 * @param samplingRate sub-sampling rate
	 * @param samplingMethod sampling method: {@code "Cutoff Frequency"},
	 *                       {@code "Gaussian"} or {@code "Random"}
	 * @param cutOffFreq cut-off frequency (only used by the cutoff method)
	 * @param swap boolean for swapping the quadrants, forwarded to the selected method
	 * @return the generated mask, or an all-zero mask of length {@code w*h}
	 *         when {@code samplingMethod} is {@code null} or not recognized
	 * */
	public static double[] FFTmask(int w, int h, double samplingRate, String samplingMethod, double cutOffFreq, boolean swap){
		// Compare by content: '==' on strings only tests reference identity
		// and fails for any method name that is not an interned literal.
		if (METHOD_CUTOFF.equals(samplingMethod)) {
			return FFTmask_cutoff(w, h, samplingRate, cutOffFreq, swap);
		}
		if (METHOD_GAUSSIAN.equals(samplingMethod)) {
			return FFTmask_gaussian(w, h, samplingRate, swap);
		}
		if (METHOD_RANDOM.equals(samplingMethod)) {
			return FFTmask_random(w, h, samplingRate, swap);
		}
		return new double[w*h];
	}

	/**
	 * Swaps mask quadrants: every element is moved circularly by
	 * {@code ceil(w/2)} columns and {@code ceil(h/2)} rows, so that the
	 * element at the center {@code (w/2, h/2)} ends up at index 0.
	 * <p>
	 * For even sizes the operation is its own inverse; for odd sizes it is not.
	 *
	 * @param mask mask to rearrange (not modified), row-major, length {@code w*h}
	 * @param w width of the mask
	 * @param h height of the mask
	 * @return new array holding the rearranged mask
	 * */
	public static double[] swap (double[] mask, int w, int h){
		double[] swappedMask = new double[w*h];

		// Integer form of ceil(w/2): Math.ceil(w/2) would round *after* the
		// integer division and therefore yield the floor for odd sizes.
		final int shiftX = (w + 1) / 2;
		final int shiftY = (h + 1) / 2;

		for (int y = 0; y < h; y++) {
			final int targetRow = w * ((y + shiftY) % h);
			final int sourceRow = w * y;
			for (int x = 0; x < w; x++) {
				swappedMask[targetRow + (x + shiftX) % w] = mask[sourceRow + x];
			}
		}
		return swappedMask;
	}

	/**
	 * Generate an array of sampling masks, each one produced by an independent
	 * call to {@link #FFTmask}.
	 *
	 * @param w width of the masks
	 * @param h height of the masks
	 * @param samplingRate sub-sampling rate
	 * @param numMasks number of masks to generate
	 * @param samplingMethod sampling method, see {@link #FFTmask}
	 * @param cutOffFreq cut-off frequency (only used by the cutoff method)
	 * @param swap boolean for swapping the quadrants
	 * @return array of {@code numMasks} masks
	 * */
	public static double[][] multipleFFTMasks(int w, int h, double samplingRate, int numMasks, String samplingMethod, double cutOffFreq, boolean swap){
		// Only the outer array is allocated here; every row is supplied by FFTmask.
		double[][] masks = new double[numMasks][];
		for (int m = 0; m < numMasks; m++) {
			masks[m] = FFTmask(w, h, samplingRate, samplingMethod, cutOffFreq, swap);
		}
		return masks;
	}

	/** Number of coefficients to keep for the given rate, clamped to {@code [0, total]}. */
	private static int sampleCount(int total, double samplingRate) {
		final long requested = Math.round(samplingRate * total);
		return (int) Math.max(0L, Math.min((long) total, requested));
	}

	/** Frequency of {@code index} relative to {@code center}, normalized by {@code center}; 0 on a single-element axis. */
	private static double normalizedFrequency(int index, int center) {
		if (center == 0) {
			return 0.0; // avoids 0/0 = NaN when the axis has a single element
		}
		return (index - (double) center) / (double) center;
	}

	/** Sets {@code count} uniformly drawn zero entries of the mask to 1, or all remaining zero entries if fewer are left. */
	private static void fillUniform(double[] mask, int w, int h, int count, Random position) {
		int free = 0;
		for (int i = 0; i < mask.length; i++) {
			if (mask[i] == 0) {
				free++;
			}
		}

		// Never ask for more cells than are free, otherwise the rejection
		// loop below could not terminate.
		int remaining = Math.min(count, free);
		while (remaining > 0) {
			final int x = position.nextInt(w);
			final int y = position.nextInt(h);

			if (mask[x + w*y] == 0) {
				mask[x + w*y] = 1;
				remaining--;
			}
		}
	}

	/** Returns the centered mask itself when {@code keepCentered} is set, otherwise its quadrant-swapped copy. */
	private static double[] arrange(double[] centered, int w, int h, boolean keepCentered) {
		return keepCentered ? centered : swap(centered, w, h);
	}
}
