package plugins.ylemontag.gaussiannoiseestimator;

import icy.sequence.Sequence;
import icy.system.SystemUtil;
import icy.system.thread.Processor;
import icy.type.DataType;
import icy.type.collection.array.Array1DUtil;

import java.util.Arrays;

import plugins.adufour.filtering.Convolution1D;
import plugins.adufour.filtering.ConvolutionException;

/**
 * Core functions used to estimate the standard deviation of an additive
 * white Gaussian noise that affects a sequence.
 * 
 * The estimation is performed independently for each (frame T, slice Z,
 * channel C) plane. The plane is filtered with a single-scale 2D wavelet
 * transform, and the filtered planes are not subsampled. The standard
 * deviation is derived from the median of the absolute values of the detail
 * (vertical, horizontal and diagonal) coefficients.
 * 
 * @author Yoann Le Montagner
 */
public class GaussianNoiseEstimator
{
	/**
	 * Ratio median(|X|) / standard deviation for a zero-mean Gaussian X.
	 * 
	 * This value R is defined as the solution of the following equation:
	 * 
	 * <pre>
	 *                    / R
	 *   3       1        |            t^2
	 *   - = ---------- * |    exp( - ----- )*dt
	 *   4   sqrt(2*pi)   |             2
	 *                    /-infty
	 * </pre>
	 */
	private static final double MEDIAN_TO_STANDARD_DEVIATION_RATIO = 0.6744898;
	
	/**
	 * Result structure returned by the estimation functions.
	 * 
	 * Holds one value per (t, z, c) coordinate in a flat buffer whose layout
	 * is given by {@link #index(int, int, int)}. No synchronization is done:
	 * concurrent calls to {@link #set} are only safe on distinct coordinates.
	 */
	public static class Result
	{
		private final int      _sizeT ;
		private final int      _sizeZ ;
		private final int      _sizeC ;
		private final double[] _buffer;
		
		/**
		 * Constructor
		 * 
		 * This constructor allocates a new buffer of results, of size
		 * 'sizeT*sizeZ*sizeC', initialized with zeros.
		 * 
		 * @param sizeT number of frames
		 * @param sizeZ number of slices
		 * @param sizeC number of channels
		 */
		public Result(int sizeT, int sizeZ, int sizeC)
		{
			this(sizeT, sizeZ, sizeC, new double[sizeT*sizeZ*sizeC]);
		}
		
		/**
		 * Constructor
		 * 
		 * This constructor wraps an existing buffer of results, that must be of
		 * size 'sizeT*sizeZ*sizeC'. The buffer is not copied.
		 * 
		 * @param sizeT  number of frames
		 * @param sizeZ  number of slices
		 * @param sizeC  number of channels
		 * @param buffer buffer to wrap, laid out as described by {@link #index}
		 * @throws IllegalArgumentException if the buffer length does not match
		 *         'sizeT*sizeZ*sizeC'
		 */
		public Result(int sizeT, int sizeZ, int sizeC, double[] buffer)
		{
			if(buffer.length!=sizeT*sizeZ*sizeC) {
				throw new IllegalArgumentException("Invalid buffer length");
			}
			_sizeT  = sizeT ;
			_sizeZ  = sizeZ ;
			_sizeC  = sizeC ;
			_buffer = buffer;
		}
		
		/**
		 * Access to the raw buffer of results
		 * 
		 * @return the internal buffer itself (not a copy), laid out as
		 *         described by {@link #index}
		 */
		public double[] get()
		{
			return _buffer;
		}
		
		/**
		 * Access to a single result
		 * 
		 * @throws IllegalArgumentException if a coordinate is out of range
		 */
		public double get(int t, int z, int c)
		{
			return _buffer[index(t, z, c)];
		}
		
		/**
		 * Set the result for the given coordinates
		 * 
		 * @throws IllegalArgumentException if a coordinate is out of range
		 */
		public void set(int t, int z, int c, double value)
		{
			_buffer[index(t, z, c)] = value;
		}
		
		/**
		 * Index in the buffer that correspond to the given coordinates
		 * 
		 * @return c + sizeC*z + sizeC*sizeZ*t
		 * @throws IllegalArgumentException if a coordinate is out of range
		 */
		public int index(int t, int z, int c)
		{
			if(t<0 || t>=_sizeT) throw new IllegalArgumentException("Invalid coordinate T");
			if(z<0 || z>=_sizeZ) throw new IllegalArgumentException("Invalid coordinate Z");
			if(c<0 || c>=_sizeC) throw new IllegalArgumentException("Invalid coordinate C");
			return c + _sizeC*(z + _sizeZ*t);
		}
	}
	
	/**
	 * Compute the standard deviation of the noise that affects the given sequence
	 * for all the frames, slices and channels in the sequence.
	 * 
	 * When there is more than one plane, the planes are processed in parallel
	 * by an Icy {@link Processor}. The processor is shut down before returning.
	 * 
	 * @param seq sequence to analyze
	 * @return one estimate per (t, z, c) plane (an empty result if the
	 *         sequence has no plane)
	 * @throws RuntimeException the first failure (in buffer order) raised
	 *         while processing a plane, if any
	 */
	public static Result computeStandardDeviation(Sequence seq)
	{
		// Allocate the result
		int sizeT = seq.getSizeT();
		int sizeZ = seq.getSizeZ();
		int sizeC = seq.getSizeC();
		Result retVal = new Result(sizeT, sizeZ, sizeC);
		
		int jobCount = sizeT*sizeZ*sizeC;
		
		// Nothing to compute: do not spin up a thread pool for zero tasks
		if(jobCount==0) {
			return retVal;
		}
		
		// Single plane: no need for multi-threading
		if(jobCount==1) {
			new Job(seq, retVal, 0, 0, 0).doTheJob();
			return retVal;
		}
		
		// Schedule one job per plane, keeping track of them to collect failures
		Job[] jobs = new Job[jobCount];
		Processor processor = new Processor(jobCount, SystemUtil.getAvailableProcessors());
		try {
			for(int t=0; t<sizeT; ++t) {
				for(int z=0; z<sizeZ; ++z) {
					for(int c=0; c<sizeC; ++c) {
						Job job = new Job(seq, retVal, t, z, c);
						jobs[retVal.index(t, z, c)] = job;
						processor.addTask(job);
					}
				}
			}
			processor.waitAll();
		}
		finally {
			// Release the worker threads once all the work is done (or aborted)
			processor.shutdown();
		}
		
		// Report worker failures to the caller, as the single-plane path does,
		// instead of silently returning 0 for the failed planes
		for(Job job : jobs) {
			job.rethrowIfFailed();
		}
		
		// Return the result
		return retVal;
	}
	
	/**
	 * Multi-threading for multi-frame/slice/channels sequences: each thread
	 * computes the result for a given frame 't', slice 'z' and channel 'c'.
	 * 
	 * When executed through {@link #run()}, a RuntimeException raised by the
	 * computation is stored rather than lost in the worker thread. It can then
	 * be rethrown in the caller's thread with {@link #rethrowIfFailed()}.
	 */
	private static class Job implements Runnable
	{
		private final Sequence _seq   ;
		private final Result   _result;
		private final int      _t     ;
		private final int      _z     ;
		private final int      _c     ;
		
		// Written by the worker thread, read by the scheduling thread
		private volatile RuntimeException _error;
		
		public Job(Sequence seq, Result result, int t, int z, int c)
		{
			_seq    = seq   ;
			_result = result;
			_t      = t     ;
			_z      = z     ;
			_c      = c     ;
		}
		
		/**
		 * Compute the estimate for the plane (t, z, c) and store it in the result
		 */
		public void doTheJob()
		{
			double value = computeStandardDeviation(_seq, _t, _z, _c);
			_result.set(_t, _z, _c, value);
		}

		@Override
		public void run()
		{
			try {
				doTheJob();
			}
			catch(RuntimeException err) {
				_error = err;
			}
		}
		
		/**
		 * Rethrow the exception recorded by {@link #run()}, if any
		 */
		public void rethrowIfFailed()
		{
			RuntimeException err = _error;
			if(err!=null) {
				throw err;
			}
		}
	}
	
	/**
	 * Compute the standard deviation of the noise that affects the given sequence
	 * at frame 't', slice 'z' and channel 'c'.
	 * 
	 * @return the estimated standard deviation, or NaN if the plane yields no
	 *         wavelet coefficient
	 * @throws IllegalArgumentException if the underlying convolution fails
	 */
	public static double computeStandardDeviation(Sequence seq, int t, int z, int c)
	{
		// Extract the wavelet coefficients
		double[] waveletCoefficients = computeWaveletCoefficients(seq.getSizeX(), seq.getSizeY(),
			seq.getDataXY(t, z, c), seq.getDataType_());
		
		// Discard the sign of the wavelet coefficients
		for(int k=0; k<waveletCoefficients.length; ++k) {
			waveletCoefficients[k] = Math.abs(waveletCoefficients[k]);
		}
		
		// Get the median of the absolute value of the wavelet coefficients
		double median = computeMedian(waveletCoefficients);
		
		// Assuming Gaussianity, the ratio median/standard deviation is constant
		return median / MEDIAN_TO_STANDARD_DEVIATION_RATIO;
	}
	
	/**
	 * Compute the median of the input (that will be modified: it is sorted in place).
	 * 
	 * @return the median, or NaN if the input is empty
	 */
	private static double computeMedian(double[] input)
	{
		int lg = input.length;
		if(lg==0) {
			return Double.NaN;
		}
		Arrays.sort(input);
		int mid = lg/2;
		if(lg%2==0) {
			return (input[mid] + input[mid-1]) / 2;
		}
		else {
			return input[mid];
		}
	}
	
	/**
	 * Compute the wavelet detail coefficients at scale 1 for the given 'data' image.
	 * 
	 * @return the concatenation of the lowX/highY, highX/lowY and highX/highY
	 *         filtered planes (in that order)
	 */
	private static double[] computeWaveletCoefficients(int sizeX, int sizeY, Object data, DataType dataType)
	{
		// Alternative (unused): Haar wavelet
		//double[] kernel_l = new double[] { 0, 1/Math.sqrt(2),  1/Math.sqrt(2) };
		//double[] kernel_h = new double[] { 0, 1/Math.sqrt(2), -1/Math.sqrt(2) };
		
		// Daubechies wavelet with 8 taps (4 vanishing moments). The leading 0
		// gives the kernels an odd length. NOTE(review): presumably Convolution1D
		// expects odd-length kernels centered on their middle element - confirm.
		double[] kernel_l = new double[] { 0, -0.010597, 0.032883,  0.030841, -0.187035, -0.027984, 0.630881,  0.714847,  0.230378 };
		double[] kernel_h = new double[] { 0, -0.230378, 0.714847, -0.630881, -0.027984,  0.187035, 0.030841, -0.032883, -0.010597 };
		
		// Apply the wavelet kernels in the X direction
		double[] data_lX = convolve(sizeX, sizeY, data, dataType, kernel_l, null);
		double[] data_hX = convolve(sizeX, sizeY, data, dataType, kernel_h, null);
		
		// Apply the wavelet kernels in the Y direction
		double[] data_lX_hY = convolve(sizeX, sizeY, data_lX, DataType.DOUBLE, null, kernel_h);
		double[] data_hX_lY = convolve(sizeX, sizeY, data_hX, DataType.DOUBLE, null, kernel_l);
		double[] data_hX_hY = convolve(sizeX, sizeY, data_hX, DataType.DOUBLE, null, kernel_h);
		
		// Aggregate the vertical, horizontal and diagonal coefficients
		return concatenate(data_lX_hY, data_hX_lY, data_hX_hY);
	}
	
	/**
	 * Concatenate the given arrays, in order, into a newly allocated array
	 */
	private static double[] concatenate(double[]... parts)
	{
		int totalLength = 0;
		for(double[] part : parts) {
			totalLength += part.length;
		}
		double[] retVal = new double[totalLength];
		int offset = 0;
		for(double[] part : parts) {
			System.arraycopy(part, 0, retVal, offset, part.length);
			offset += part.length;
		}
		return retVal;
	}
	
	/**
	 * Execute a convolution on the input array 'data', without modifying it.
	 * 
	 * The data is first converted to a fresh double array, which is then
	 * convolved in place. NOTE(review): a null kernel is assumed to mean "no
	 * filtering along that axis" in Convolution1D - confirm.
	 * 
	 * @return the convolved copy of 'data'
	 * @throws IllegalArgumentException wrapping the ConvolutionException if
	 *         the convolution fails
	 */
	private static double[] convolve(int sizeX, int sizeY, Object data, DataType dataType,
		double[] kernelX, double[] kernelY)
	{
		double[] dataCopy = Array1DUtil.arrayToDoubleArray(data, dataType.isSigned());
		double[][] buffer = new double[][] { dataCopy };
		try {
			Convolution1D.convolve(buffer, sizeX, sizeY, kernelX, kernelY, null);
		}
		catch(ConvolutionException err) {
			throw new IllegalArgumentException(err);
		}
		return buffer[0];
	}
}
