package plugins.nchenouard.tvdenoising;

/**
 * Total Variation based image regularization in 2D.
 * <p>
 * It makes use of the FISTA based optimization algorithm described in:
 * Beck, A.; Teboulle, M.,
 * "Fast Gradient-Based Algorithms for Constrained Total Variation Image Denoising and Deblurring Problems,"
 * IEEE Transactions on Image Processing, vol. 18, no. 11, pp. 2419-2434, Nov. 2009.
 * doi: 10.1109/TIP.2009.2028250
 * URL: http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=5173518&isnumber=5286712
 * <p>
 * The Java implementation is a direct port of the original Matlab code provided by Beck and Teboulle.
 * <p>
 * Images are stored as flat {@code double[]} arrays: the pixel at coordinate {x, y} is {@code im[x + y*width]}.
 *
 * @author Nicolas Chenouard (nicolas.chenouard.dev@gmail.com)
 * @version 1.0
 * @date 2014-05-22
 * @license gpl v3.0
 */
public class TVFISTA
{
	/** Type of regularization: isotropic finite differences, or l1-separable finite differences. */
	public static enum RegularizationType{ISOTROPIC, L1};

	/** Default tolerance on the relative change of the estimate, used by the stopping criterion. */
	public static final double DEFAULT_EPSILON = 1e-4;

	/** Number of consecutive iterations with a relative change below epsilon required to stop early. */
	private static final int CONVERGENCE_COUNT = 5;

	/**
	 * Regularize a 2D image using total variation, stopping early with the tolerance {@link #DEFAULT_EPSILON}.
	 * @param im a double array containing pixel values im[x + y*width] corresponds to the pixel value at coordinate {x, y}
	 * @param width width of the image
	 * @param height height of the image
	 * @param numIter maximum number of iterations
	 * @param lambda regularization parameter (must be non-negative; 0 returns a copy of {@code im})
	 * @param regularization type of regularization
	 * @return a new array of width*height regularized pixel values (all zeros if numIter <= 0 and lambda > 0)
	 * @throws NullPointerException if {@code im} or {@code regularization} is null
	 * @throws IllegalArgumentException if width/height is negative, im.length != width*height, or lambda is negative or NaN
	 */
	public static double[] regularizeTV(double[] im, int width, int height, int numIter, double lambda, RegularizationType regularization)
	{
		return regularizeTV(im, width, height, numIter, lambda, regularization, DEFAULT_EPSILON);
	}

	/**
	 * Regularize a 2D image using total variation.
	 * <p>
	 * Iterations stop after {@code numIter} iterations, or earlier once the relative change
	 * ||D - Dold||_F / ||D||_F of the estimate stays below {@code epsilon} for 5 consecutive iterations
	 * (the stopping rule of the reference Matlab code).
	 *
	 * @param im a double array containing pixel values im[x + y*width] corresponds to the pixel value at coordinate {x, y}
	 * @param width width of the image
	 * @param height height of the image
	 * @param numIter maximum number of iterations
	 * @param lambda regularization parameter (must be non-negative; 0 returns a copy of {@code im})
	 * @param regularization type of regularization
	 * @param epsilon tolerance on the relative change of the estimate between two iterations
	 * @return a new array of width*height regularized pixel values (all zeros if numIter <= 0 and lambda > 0)
	 * @throws NullPointerException if {@code im} or {@code regularization} is null
	 * @throws IllegalArgumentException if width/height is negative, im.length != width*height, or lambda is negative or NaN
	 */
	public static double[] regularizeTV(double[] im, int width, int height, int numIter, double lambda, RegularizationType regularization, double epsilon)
	{
		if (im == null)
			throw new NullPointerException("im");
		if (regularization == null)
			throw new NullPointerException("regularization");
		if (width < 0 || height < 0)
			throw new IllegalArgumentException("width and height must be non-negative, got " + width + "x" + height);
		if ((long) width * height != im.length)
			throw new IllegalArgumentException("im.length (" + im.length + ") does not match width*height (" + ((long) width * height) + ")");
		if (!(lambda >= 0)) // also rejects NaN
			throw new IllegalArgumentException("lambda must be non-negative, got " + lambda);
		if (im.length == 0)
			return new double[0];
		if (lambda == 0)
			return im.clone(); // without regularization the minimizer is the input itself (and 1/(8*lambda) would be infinite)

		final int size = width*height;
		// P1, R1, Q1: dual variables on vertical differences (between rows y and y+1), (height-1) rows of width values
		final int width1 = width;
		final int height1 = height - 1;
		// P2, R2, Q2: dual variables on horizontal differences (between columns x and x+1), height rows of (width-1) values
		final int width2 = width - 1;
		final int height2 = height;
		// gradient step size used by the reference implementation
		final double step = 1/(8d*lambda);

		double[] D = new double[size];
		double[] Dold = new double[size];

		double[] P1 = new double[width1*height1];
		double[] P1old = new double[width1*height1];
		final double[] R1 = new double[width1*height1];
		final double[] Q1 = new double[width1*height1];

		double[] P2 = new double[width2*height2];
		double[] P2old = new double[width2*height2];
		final double[] R2 = new double[width2*height2];
		final double[] Q2 = new double[width2*height2];

		double tkp1 = 1;
		int cnt = 0;
		for (int iter = 0; iter < numIter && cnt < CONVERGENCE_COUNT; iter++)
		{
			final double tk = tkp1;

			// Keep the previous iterates by swapping buffers; the new P1, P2 and D are fully overwritten below
			double[] tmp = P1old;
			P1old = P1;
			P1 = tmp;

			tmp = P2old;
			P2old = P2;
			P2 = tmp;

			tmp = Dold;
			Dold = D;
			D = tmp;

			// Computing the gradient of the objective function
			smoothPartGradient(im, lambda, R1, width1, height1, R2, width2, height2, D);
			Ltrans(D, width, height, Q1, Q2);

			// Taking a step towards minus of the gradient
			for (int k = 0; k < P1.length; k++)
				P1[k] = R1[k] + step*Q1[k];
			for (int k = 0; k < P2.length; k++)
				P2[k] = R2[k] + step*Q2[k];

			// Projecting the dual variables back onto their feasible set
			switch (regularization)
			{
			case ISOTROPIC:
				projectIsotropic(P1, P2, width, height);
				break;
			case L1:
				projectL1(P1);
				projectL1(P2);
				break;
			}

			// Updating R and t (momentum is 0 on the first iteration since tk = 1)
			tkp1 = (1 + Math.sqrt(1 + 4*tk*tk))/2;
			final double momentum = (tk - 1)/tkp1;
			for (int k = 0; k < R1.length; k++)
				R1[k] = P1[k] + momentum*(P1[k] - P1old[k]);
			for (int k = 0; k < R2.length; k++)
				R2[k] = P2[k] + momentum*(P2[k] - P2old[k]);

			// Stopping rule: relative Frobenius change of the estimate below epsilon for CONVERGENCE_COUNT consecutive iterations
			if (relativeChange(D, Dold) < epsilon)
				cnt++;
			else
				cnt = 0;
		}
		return D;
	}

	/**
	 * Relative Frobenius change ||current - previous|| / ||current||.
	 * Returns 0 when both arrays are identically zero and +Infinity when only {@code current} is zero,
	 * so that no division by zero can occur.
	 */
	private static double relativeChange(double[] current, double[] previous)
	{
		double diff2 = 0;
		double norm2 = 0;
		for (int k = 0; k < current.length; k++)
		{
			final double d = current[k] - previous[k];
			diff2 += d*d;
			norm2 += current[k]*current[k];
		}
		if (norm2 == 0)
			return diff2 == 0 ? 0 : Double.POSITIVE_INFINITY;
		return Math.sqrt(diff2/norm2);
	}

	/**
	 * Isotropic projection of the dual variables, done in place.
	 * <p>
	 * At each pixel {x, y} the pair (P1(x, y), P2(x, y)) is divided by its Euclidean norm when that norm exceeds 1.
	 * Entries that do not exist (last row of P1, last column of P2) count as 0.
	 * Each norm only involves the two entries it rescales, so a single pass needs no extra buffer.
	 */
	private static void projectIsotropic(double[] P1, double[] P2, int width, int height)
	{
		final int width2 = width - 1;
		for (int y = 0; y < height; y++)
		{
			final boolean hasVertical = y < height - 1;
			for (int x = 0; x < width; x++)
			{
				final boolean hasHorizontal = x < width - 1;
				final double a = hasVertical ? P1[x + y*width] : 0;
				final double b = hasHorizontal ? P2[x + y*width2] : 0;
				final double norm = Math.sqrt(a*a + b*b);
				if (norm > 1)
				{
					if (hasVertical)
						P1[x + y*width] = a/norm;
					if (hasHorizontal)
						P2[x + y*width2] = b/norm;
				}
			}
		}
	}

	/** l1-separable projection of a dual variable, done in place: every entry is clamped to [-1, 1]. */
	private static void projectL1(double[] P)
	{
		for (int k = 0; k < P.length; k++)
		{
			if (P[k] > 1)
				P[k] = 1;
			else if (P[k] < -1)
				P[k] = -1;
		}
	}

	/**
	 * Gradient of the smooth part of the cost function.
	 * <p>
	 * Writes {@code output = im - lambda * lForward(r1, r2)} into {@code output}, which must hold im.length values.
	 */
	private static void smoothPartGradient(double[] im, double lambda, double[] r1, int width1, int height1, double[] r2, int width2, int height2, double[] output)
	{
		lForward(r1, width1, height1, r2, width2, height2, output);
		for (int k = 0; k < im.length; k++)
			output[k] = im[k] - lambda*output[k];
	}

	/**
	 * Adjoint of {@link #Ltrans}: maps the dual pair (P1, P2) back to an image of (height1 + 1) rows of width1 values.
	 * <p>
	 * output(x, y) = P1(x, y) - P1(x, y - 1) + P2(x, y) - P2(x - 1, y), with out-of-range terms taken as 0.
	 * Expects width2 == width1 - 1 and height2 == height1 + 1.
	 */
	private static void lForward(double[] P1, int width1, int height1, double[] P2, int width2, int height2, double[] output)
	{
		int m = height1 + 1;
		int n = width1;

		System.arraycopy(P1, 0, output, 0, height1*width1);
		for (int x = 0; x < n; x++)
			output[x + (m - 1)*n] = 0; // fill the last row with zeros

		for (int y = 0; y < m; y++)
			for (int x = 0; x < n - 1; x++)
				output[x + y*n] = output[x + y*n] + P2[x + y*width2];

		for (int y = 1; y < m; y++)
			for (int x = 0; x < n; x++)
				output[x + y*n] = output[x + y*n] - P1[x + (y - 1)*width1];

		for (int y = 0; y < m; y++)
			for (int x = 1; x < n; x++)
				output[x + y*n] = output[x + y*n] - P2[(x - 1) + y*width2];
	}

	/**
	 * Forward finite differences of the image X (height rows of width values).
	 * <p>
	 * P1(x, y) = X(x, y) - X(x, y + 1) for y < height - 1 (stored with row length width);
	 * P2(x, y) = X(x, y) - X(x + 1, y) for x < width - 1 (stored with row length width - 1).
	 */
	private static void Ltrans(double[] X, int width, int height, double[] P1, double[] P2)
	{
		for (int y = 0; y < height - 1; y++)
			for (int x = 0; x < width; x++)
				P1[x + y*width] = X[x + y*width] - X[x + (y + 1)*width];

		for (int y = 0; y < height; y++)
			for (int x = 0; x < width - 1; x++)
				P2[x + y*(width - 1)] = X[x + y*width] - X[x + 1 + y*width];
	}
}