package plugins.dmandache.denoise;

/**
 * Total Variation (TV) based image regularization in 2D.
 *
 * It makes use of the FISTA based optimization algorithms described in:
 * A. Beck and M. Teboulle,
 * "Fast Gradient-Based Algorithms for Constrained Total Variation Image Denoising and Deblurring Problems,"
 * IEEE Transactions on Image Processing, vol. 18, no. 11, pp. 2419-2434, Nov. 2009.
 * doi: 10.1109/TIP.2009.2028250
 * URL: http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=5173518&isnumber=5286712
 */
public class FISTAwrapper
{
	/** Relative-change threshold ||X_k - X_k-1||_F / ||X_k||_F below which an iteration counts as "converged". */
	private static final double EPSILON = 1e-4;

	/** Number of consecutive "converged" iterations after which the iterative loops stop early. */
	private static final int CONVERGED_ITERATIONS = 5;

	/**
	 * Solves the optimization problem (called the deblurring problem in the paper):
	 *
	 * 			min FrobNorm[ A(x) - b ]^2 + 2*lambda*TVNorm(x) , where A is the masked 2D Fourier transform
	 * 			(FFTwrapper.FFT_2D_with_mask) and b = Y
	 *
	 * Wrapper for functions in ICY_HOME/plugins/nchenouard/tvdenoising/TVDenoising.jar TVFista.class
	 * (the code is duplicated here because of access restrictions on the original).
	 *
	 * The outer loop stops after numIter iterations, or earlier once the relative change
	 * ||X_k - X_k-1||_F / ||X_k||_F stays below 1e-4 for 5 consecutive iterations.
	 *
	 * @author Diana Mandache
	 * @version 1.0
	 * @date 2017-02-20
	 * @license gpl v3.0
	 *
	 * @param Y a double array of size 2 x (w*h) containing the Fourier transform of an image where Y[0][x + y*w] corresponds to the real part of the coefficient at coordinates {x, y} and Y[1][x + y*w] corresponds to the imaginary part of the coefficient at coordinates {x, y}
	 * @param fourierMask the sampling mask in the Fourier domain
	 * @param w width of the image
	 * @param h height of the image
	 * @param numIter maximum number of outer iterations; each inner TV denoising step runs at most max(1, numIter/2) iterations
	 * @param lambda regularization parameter (must be non-zero: the step size is 2/(8*lambda))
	 * @return the reconstructed image of length w*h, pixel {x, y} stored at index x + y*w
	 * */
	public static double[] optimization(double[][] Y, double[] fourierMask, int w, int h, int numIter, double lambda){
		double[] Yk = new double[w*h];		// Y_k, the extrapolated point (starts at zero)
		double[] Xk = new double[w*h];		// X_k, the current estimate (starts at zero)

		double tk = 1;		// t_k,  t_1 = 1

		double L = 8d*lambda;
		// NOTE(review): 2*lambda/L below evaluates to 1/4 for any non-zero lambda, so lambda only
		// influences the gradient step 2/L -- confirm this is the intended parameterization.
		double innerLambda = 2*lambda/L;
		// the inner denoiser must run at least once, otherwise it returns an all-zero image
		int innerIter = Math.max(1, numIter/2);

		int i = 0;
		int cnt = 0;

		while(i < numIter && cnt < CONVERGED_ITERATIONS){
			i++;
			double[] Xkm1 = Xk;		// X_k-1

			// A*Yk - b
			double[][] tempF = FFTwrapper.FFT_2D_with_mask(Yk, fourierMask, w, h);
			tempF = Util.subtract2D(tempF, Y, 2, w*h);

			// Yk - 2/L * A^T(A*Yk - b)
			double[] tempIF = FFTwrapper.IFFT_2D(tempF, w, h);
			tempIF = Util.multiplyScalar1D(tempIF, 2/L);
			tempIF = Util.subtract1D(Yk, tempIF);

			// Xk = D( Yk - 2/L * A^T(A*Yk - b), 2*lambda/L )
			Xk = regularizeTV(tempIF, w, h, innerIter, innerLambda);

			double tkp1 = (1 + Math.sqrt(1 + 4*tk*tk))/2;	// t_k+1

			// Y_k+1 = Xk + (tk - 1)/t_k+1 * ( Xk - X_k-1 )
			Yk = Util.add1D(Xk, Util.multiplyScalar1D( Util.subtract1D(Xk, Xkm1), (tk - 1)/tkp1));
			tk = tkp1;

			// count consecutive iterations whose relative change ||Xk - X_k-1|| / ||Xk|| is small
			if (relativeChange(Xk, Xkm1) < EPSILON)
				cnt++;
			else
				cnt = 0;
		}

		return Xk;
	}

	/**
	 * Solves the optimization problem (called the denoising problem in the paper):
	 *
	 * 				min FrobNorm[ x - b ]^2 + 2*lambda*TVNorm(x)
	 *
	 * Implements the fast gradient projection algorithm FGP(b, lambda, Niter) from the paper.
	 * The loop stops after numIter iterations, or earlier once the relative change
	 * ||D - Dold||_F / ||D||_F stays below 1e-4 for 5 consecutive iterations.
	 *
	 * Wrapper for functions in ICY_HOME/plugins/nchenouard/tvdenoising/TVDenoising.jar TVFista.class
	 * (the code is duplicated here because of access restrictions on the original).
	 *
	 * @author Nicolas Chenouard (nicolas.chenouard.dev@gmail.com)
	 * @version 1.0
	 * @date 2014-05-22
	 * @license gpl v3.0
	 *
	 * @param im a double array containing pixel values; im[x + y*width] corresponds to the pixel value at coordinate {x, y}
	 * @param width width of the image
	 * @param height height of the image
	 * @param numIter maximum number of iterations (0 returns an all-zero image)
	 * @param lambda regularization parameter (must be non-zero: the gradient step is 1/(8*lambda))
	 * @return the denoised image of length width*height, same layout as im
	 * */
	public static double[] regularizeTV(double[] im, int width, int height, int numIter, double lambda)
	{
		double step = 1/(8d*lambda);	// gradient step 1/(8*lambda)
		double tkp1 = 1;				// t_k+1, so that t_1 = 1 on the first iteration

		double[] D = new double[width*height];		// current image estimate
		double[] Dold = new double[width*height];	// previous image estimate

		// p: vertical differences (row y minus row y+1), (height - 1) rows x width columns
		int width1 = width;
		int height1 = height - 1;
		double[] P1 = new double[width1*height1];
		double[] P1old = new double[width1*height1];
		double[] R1 = new double[width1*height1];
		double[] Q1 = new double[width1*height1];

		// q: horizontal differences (column x minus column x+1), height rows x (width - 1) columns
		int width2 = width - 1;
		int height2 = height;
		double[] P2 = new double[width2*height2];
		double[] P2old = new double[width2*height2];
		double[] R2 = new double[width2*height2];
		double[] Q2 = new double[width2*height2];

		int iter = 0;
		int cnt = 0;
		while (iter < numIter && cnt < CONVERGED_ITERATIONS)
		{
			iter++;
			double tk = tkp1;

			// recycle buffers: current iterates become the "old" ones, their old storage is overwritten below
			double[] tmp = P1old;
			P1old = P1;
			P1 = tmp;

			tmp = P2old;
			P2old = P2;
			P2 = tmp;

			tmp = Dold;
			Dold = D;
			D = tmp;

			// D = b - lambda * L(R1, R2), then (Q1, Q2) = L^T(D)
			smoothPartGradient(im, lambda, R1, width1, height1, R2, width2, height2, D);
			lTranspose(D, width, height, Q1, Q2);

			// gradient step: (p_k, q_k) = (r_k, s_k) + 1/(8*lambda) * L^T(b - lambda * L(r_k, s_k))
			for (int k = 0; k < P1.length; k++)
				P1[k] = R1[k] + step*Q1[k];
			for (int k = 0; k < P2.length; k++)
				P2[k] = R2[k] + step*Q2[k];

			// projection onto the box |p| <= 1, |q| <= 1
			clampToUnitBox(P1);
			clampToUnitBox(P2);

			// t_k+1 and the extrapolated point r_k+1 = p_k + (t_k - 1)/t_k+1 * (p_k - p_k-1)
			tkp1 = (1 + Math.sqrt(1 + 4*tk*tk))/2;
			double momentum = (tk - 1)/tkp1;
			extrapolate(R1, P1, P1old, momentum);
			extrapolate(R2, P2, P2old, momentum);

			// count consecutive iterations whose relative change ||D - Dold|| / ||D|| is small
			if (relativeChange(D, Dold) < EPSILON)
				cnt++;
			else
				cnt = 0;
		}
		return D;
	}

	/**
	 * Relative change ||cur - prev||_F / ||cur||_F between two iterates of equal length.
	 * Returns 0 when both are identically zero and +Infinity when only cur is zero,
	 * so a zero iterate counts as converged only if it did not move.
	 */
	private static double relativeChange(double[] cur, double[] prev)
	{
		double diffSq = 0;
		double curSq = 0;
		for (int k = 0; k < cur.length; k++)
		{
			double d = cur[k] - prev[k];
			diffSq += d*d;
			curSq += cur[k]*cur[k];
		}
		if (curSq == 0)
			return diffSq == 0 ? 0 : Double.POSITIVE_INFINITY;
		return Math.sqrt(diffSq/curSq);
	}

	/** Clamps every entry of p to [-1, 1] in place (NaN entries are left unchanged). */
	private static void clampToUnitBox(double[] p)
	{
		for (int k = 0; k < p.length; k++)
		{
			if (p[k] > 1)
				p[k] = 1;
			else if (p[k] < -1)
				p[k] = -1;
		}
	}

	/** Writes r[k] = p[k] + momentum * (p[k] - pOld[k]) for every k. */
	private static void extrapolate(double[] r, double[] p, double[] pOld, double momentum)
	{
		for (int k = 0; k < r.length; k++)
			r[k] = p[k] + momentum*(p[k] - pOld[k]);
	}

	/**
	 * Computes output = im - lambda * L(r1, r2), the image estimate associated with the
	 * current (r1, r2) pair; output must have length width1 * (height1 + 1).
	 * */
	private static void smoothPartGradient(double[] im, double lambda, double[] r1, int width1, int height1, double[] r2, int width2, int height2, double[] output)
	{
		lForward(r1, width1, height1, r2, width2, height2, output); // L(p,q)
		for (int k = 0; k < im.length; k++)
			output[k] = im[k] - lambda*output[k]; // b - lambda * L(p,q)
	}

	/**
	 * Linear operator L(p,q)_{i,j} = p_{i,j} + q_{i,j} - p_{i-1,j} - q_{i,j-1}, where terms whose
	 * indices fall outside p (height1 x width1) or q (height2 x width2) are taken as zero.
	 * The result has (height1 + 1) rows and width1 columns and is written into output.
	 */
	private static void lForward(double[] P1, int width1, int height1, double[] P2, int width2, int height2, double[] output)
	{
		int m = height1 + 1;
		int n = width1;

		// + p_{i,j}, with the last row (no p available) set to zero
		System.arraycopy(P1, 0, output, 0, height1*width1);
		for (int x = 0; x < n; x++)
			output[x + (m - 1)*n] = 0;

		// + q_{i,j} (q has no last column)
		for (int y = 0; y < m; y++)
			for (int x = 0; x < n - 1; x++)
				output[x + y*n] = output[x + y*n] + P2[x + y*width2];

		// - p_{i-1,j} (no contribution on the first row)
		for (int y = 1; y < m; y++)
			for (int x = 0; x < n; x++)
				output[x + y*n] = output[x + y*n] - P1[x + (y - 1)*width1];

		// - q_{i,j-1} (no contribution on the first column)
		for (int y = 0; y < m; y++)
			for (int x = 1; x < n; x++)
				output[x + y*n] = output[x + y*n] - P2[(x - 1) + y*width2];
	}

	/**
	 * Adjoint operator L^T(X) = (P1, P2):
	 * P1[x + y*width] = X(x, y) - X(x, y+1) for y < height-1 (vertical differences),
	 * P2[x + y*(width-1)] = X(x, y) - X(x+1, y) for x < width-1 (horizontal differences).
	 */
	private static void lTranspose(double[] X, int width, int height, double[] P1, double[] P2)
	{
		for (int y = 0; y < height - 1; y++)
			for (int x = 0; x < width; x++)
				P1[x + y*width] = X[x + y*width] - X[x + (y + 1)*width];

		for (int y = 0; y < height; y++)
			for (int x = 0; x < width - 1; x++)
				P2[x + y*(width - 1)] = X[x + y*width] - X[x + 1 + y*width];
	}
}