package plugins.nchenouard.tvdenoising;

import icy.image.IcyBufferedImage;
import icy.sequence.Sequence;
import icy.type.collection.array.ArrayUtil;
import plugins.adufour.blocks.lang.Block;
import plugins.adufour.blocks.util.VarList;
import plugins.adufour.ezplug.EzPlug;
import plugins.adufour.ezplug.EzStoppable;
import plugins.adufour.ezplug.EzVarBoolean;
import plugins.adufour.ezplug.EzVarDouble;
import plugins.adufour.ezplug.EzVarEnum;
import plugins.adufour.ezplug.EzVarInteger;
import plugins.adufour.ezplug.EzVarSequence;
import plugins.nchenouard.tvdenoising.TVFISTA.RegularizationType;


/**
 * Total Variation (TV) based image regularization and denoising for Icy.
 * 
 * Regularization is applied in 2D, independently on each z-slice of each time frame.
 * It relies on the FISTA-based optimization algorithm described in:
 * Beck, A.; Teboulle, M.,
 * "Fast Gradient-Based Algorithms for Constrained Total Variation Image Denoising and Deblurring Problems,"
 * IEEE Transactions on Image Processing, vol. 18, no. 11, pp. 2419-2434, Nov. 2009.
 * doi: 10.1109/TIP.2009.2028250
 * URL: http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=5173518&isnumber=5286712
 * 
 * @author Nicolas Chenouard (nicolas.chenouard.dev@gmail.com) and Fabrice de Chaumont
 * @version 1.0
 * @date 2014-05-22
 * @license gpl v3.0
*/

public class TVDenoising extends EzPlug implements Block, EzStoppable
{
	/** Sequence to regularize (execute() processes channel 0 only). */
	EzVarSequence inputSequence = new EzVarSequence("Input sequence" );
	/** Block output: the regularized sequence (only set when running as a block). */
	EzVarSequence outputSequence = new EzVarSequence("Output sequence" );
	/** Regularization level, forwarded to TVFISTA as lambda. */
	EzVarDouble inputLambda = new EzVarDouble( "Smoothing level", 10 , 0 , Double.MAX_VALUE , 0.1 );
	/** Maximum number of FISTA iterations for each 2D image. */
	EzVarInteger maximumIter = new EzVarInteger( "Maximum number of iterations", 1000 , 0 , Integer.MAX_VALUE , 1 );

	/** Type of 2D TV regularization forwarded to TVFISTA. */
	EzVarEnum<TVFISTA.RegularizationType> regularizationType = 
			new EzVarEnum<TVFISTA.RegularizationType>("Regularization type", TVFISTA.RegularizationType.values() );
	/** When true, every (t, z) image is processed; otherwise only the image at t = 0, z = 0. */
	EzVarBoolean processWholeSequence = new EzVarBoolean("Process whole sequence", true);

	/** Set to true by declareInput(), i.e. when the plugin runs inside a Blocks protocol. */
	boolean isUsedAsBlock = false;
	/**
	 * Stop request, set by stopExecution() and polled by the processing loop in execute().
	 * stopExecution() is presumably invoked from the UI thread while execute() runs on a
	 * worker thread, so the flag is volatile to guarantee the loop observes the request.
	 */
	volatile boolean toStop = false;
	/** True while the (interruptible) whole-sequence loop is running. */
	volatile boolean isRunning = false;

	/**
	 * Regularize a whole sequence using the Total Variation criterion.
	 * Each (t, z) image is regularized independently in 2D.
	 * 
	 * @param seq the sequence to process
	 * @param channel the channel to process
	 * @param numIter maximum number of iterations
	 * @param lambda the regularization level parameter
	 * @param regularization type of regularization to use for 2D images
	 * @return a new single-channel sequence named {@code <input name>-TVregularized}
	 *         holding the regularized images as double data
	 * */
	public static Sequence regularizeTVSequence(Sequence seq, int channel, int numIter, double lambda, TVFISTA.RegularizationType regularization)
	{
		Sequence outputSeq = new Sequence(seq.getName() + "-TVregularized");
		for (int t = 0; t < seq.getSizeT(); t++)
		{
			for (int z = 0; z < seq.getSizeZ(); z++)
			{
				outputSeq.setImage(t, z, regularizeTVImage(seq.getImage(t, z), channel, numIter, lambda, regularization));
			}
		}
		return outputSeq;
	}

	/**
	 * Regularize a single 2D image using the Total Variation criterion.
	 * 
	 * @param image the image to process
	 * @param channel the channel to process
	 * @param numIter maximum number of iterations
	 * @param lambda the regularization level parameter
	 * @param regularization type of regularization to use for 2D images
	 * @return a new single-channel image of the same size holding the regularized data as doubles
	 * */
	public static IcyBufferedImage regularizeTVImage(IcyBufferedImage image, int channel, int numIter, double lambda, TVFISTA.RegularizationType regularization)
	{
		int width = image.getSizeX();
		int height = image.getSizeY();
		// convert the channel to doubles, honouring the signedness of the source data type
		double[] im = (double[]) ArrayUtil.arrayToDoubleArray(image.getDataXY(channel), image.isSignedDataType());
		double[] output =  TVFISTA.regularizeTV(im, width, height, numIter, lambda, regularization);
		return new IcyBufferedImage(width, height, output);
	}

	/**
	 * Run the regularization on channel 0 of the input sequence, using the current
	 * parameter values. In interactive mode the output sequence is displayed right away
	 * and filled progressively; in block mode it is published through outputSequence.
	 * A stop request interrupts the whole-sequence loop between two 2D images, in which
	 * case the block output is not set.
	 */
	@Override
	protected void execute() {
		Sequence seq = inputSequence.getValue();

		if(seq == null)
			return;

		double lambda = inputLambda.getValue();
		RegularizationType regularization = regularizationType.getValue();
		int numIter = maximumIter.getValue().intValue();
		boolean wholeSequence = processWholeSequence.getValue().booleanValue();

		Sequence outputSeq = new Sequence(seq.getName() + "-TVregularized");
		// show the (initially empty) result immediately so progress is visible
		if ( !isUsedAsBlock )
			addSequence(outputSeq);
		if (wholeSequence)
		{
			isRunning = true;
			toStop = false;
			try{
				for (int t = 0; t < seq.getSizeT(); t++)
				{
					if (toStop)
						return;
					for (int z = 0; z < seq.getSizeZ(); z++)
					{
						if (toStop)
							return;
						outputSeq.setImage(t, z, regularizeTVImage(seq.getImage(t, z), 0, numIter, lambda, regularization));
						// move the viewer to the image just computed to follow progress
						if ( !isUsedAsBlock && outputSeq.getFirstViewer() != null)
						{
							outputSeq.getFirstViewer().setPositionT(t);
							outputSeq.getFirstViewer().setPositionZ(z);
						}
					}
				}
			}
			finally
			{
				// reset flags so a late stop request does not affect the next run
				toStop = false;
				isRunning = false;
			}
		}
		else
		{
			// only the first image (t = 0, z = 0) is processed in this mode
			outputSeq.addImage(regularizeTVImage(seq.getImage(0, 0), 0, numIter, lambda, regularization));
		}
		if ( isUsedAsBlock )
			outputSequence.setValue( outputSeq );
	}

	/** Build the interactive EzPlug interface. */
	@Override
	protected void initialize() {
		addEzComponent( inputSequence );
		addEzComponent( inputLambda );
		addEzComponent( maximumIter );
		addEzComponent( processWholeSequence );
		addEzComponent( regularizationType );		
	}

	/** Nothing to release. */
	@Override
	public void clean() {

	}

	/**
	 * Declare block inputs and switch the plugin to block mode.
	 * Note: processWholeSequence is not exposed as a block input, so block runs use
	 * its current value (true by default).
	 */
	@Override
	public void declareInput(VarList inputMap) {

		isUsedAsBlock = true;

		inputMap.add( inputSequence.getVariable() );
		inputMap.add( inputLambda.getVariable() );
		inputMap.add( regularizationType.getVariable() );
		inputMap.add( maximumIter.getVariable() );
	}

	/** Declare the regularized sequence as the single block output. */
	@Override
	public void declareOutput(VarList outputMap) {

		outputMap.add( outputSequence.getVariable() );

	}

	/** Request interruption of a running whole-sequence computation; ignored otherwise. */
	@Override
	public void stopExecution()
	{
		if (isRunning)
			toStop = true;
	}

}
