package plugins.fantm.manualregistration;


import java.util.Arrays;
import javax.vecmath.Vector2d;
import java.awt.Point;
import java.awt.Rectangle;
import icy.image.IcyBufferedImage;
import icy.image.IcyBufferedImageUtil;
import icy.sequence.Sequence;
import icy.sequence.SequenceUtil;
import plugins.adufour.ezplug.EzPlug;
import plugins.adufour.blocks.lang.Block;
import plugins.adufour.blocks.util.VarList;
import plugins.adufour.ezplug.EzGroup;
import plugins.adufour.ezplug.EzLabel;
import plugins.adufour.ezplug.EzVarBoolean;
import plugins.adufour.ezplug.EzVarEnum;
import plugins.adufour.ezplug.EzVarSequence;
import plugins.adufour.vars.lang.VarBoolean;
import plugins.adufour.ezplug.EzVarDouble;
import icy.gui.dialog.MessageDialog;

import icy.type.collection.array.Array1DUtil;
import edu.emory.mathcs.jtransforms.fft.FloatFFT_2D;
import flanagan.complex.Complex;
import icy.type.DataType;

/*
 * Manual Registration Plugin
 * Based on adufour's Rigid Registration plugin,
 * but makes it easier to register many images with
 * the same parameters by using a protocol.
 * Channel 1 of each image is translated/rotated onto channel 0.
 */

public class ManualRegistration extends EzPlug implements Block
{
    /** Channel used as the fixed reference; it is never moved. */
    private static final int REFERENCE_CHANNEL = 0;

    /** Channel that is translated / rotated onto the reference channel. */
    private static final int MOVING_CHANNEL = 1;

    /** What a run of the plug-in does. */
    private enum Mode
    {
        /** Measure registration parameters from a two-channel image. */
        Calculate,
        /** Apply registration parameters to a sequence. */
        Apply
    }

    // ---- EzPlug variables ----

    EzVarEnum<Mode> mode = new EzVarEnum<Mode>("Mode", Mode.values(), Mode.Apply);

    // Calculate mode
    EzVarSequence seq = new EzVarSequence("Sequence");

    // Apply mode
    EzVarSequence applyTo = new EzVarSequence("Sequence Input");
    EzVarBoolean preserveSize = new EzVarBoolean("Preserve size", true);
    EzVarDouble preTranslateX = new EzVarDouble("Pre-translation x", 0, -2048, 2048, 1);
    EzVarDouble preTranslateY = new EzVarDouble("Pre-translation y", 0, -2048, 2048, 1);
    EzVarDouble rotation = new EzVarDouble("Rotation (degrees)", 0, -360, 360, 1);
    EzVarDouble postTranslateX = new EzVarDouble("Post-translation x", 0, -2048, 2048, 1);
    EzVarDouble postTranslateY = new EzVarDouble("Post-translation y", 0, -2048, 2048, 1);
    EzVarBoolean applyInPlace = new EzVarBoolean("Apply in place", false);

    EzVarSequence outputSeq = new EzVarSequence("Registered Sequence");

    // Headless (protocol) input: true selects Calculate mode, false selects Apply mode.
    VarBoolean calculateModeOn = new VarBoolean("Calculate Parameters?", false);

    /** Builds the GUI: a mode selector plus one group of controls per mode. */
    @Override
    protected void initialize()
    {
        setTimeDisplay(true);

        addEzComponent(mode);

        EzLabel calculateLabel = new EzLabel("Calculate registration parameters from an image");
        final EzGroup calculating = new EzGroup("Calculate mode", calculateLabel, seq);
        addEzComponent(calculating);
        mode.addVisibilityTriggerTo(calculating, Mode.Calculate);

        EzLabel applyLabel = new EzLabel("Apply registration parameters to a new image");
        final EzGroup applying = new EzGroup("Apply mode", applyLabel, applyTo, preserveSize, preTranslateX,
                preTranslateY, rotation, postTranslateX, postTranslateY, applyInPlace);
        addEzComponent(applying);
        mode.addVisibilityTriggerTo(applying, Mode.Apply);
    }

    /**
     * Declares the protocol (Blocks) inputs. "Preserve size" and "Apply in place" are not exposed,
     * so headless runs use their defaults.
     */
    @Override
    public void declareInput(VarList inputMap)
    {
        // NOTE: the key is misspelled ("Paramters"). It is left unchanged because saved protocols
        // presumably reference the input by this key -- confirm before renaming.
        inputMap.add("Calculate Paramters?", calculateModeOn);
        inputMap.add("Input Sequence", applyTo.getVariable());
        inputMap.add("Pre-translation X", preTranslateX.getVariable());
        inputMap.add("Pre-translation Y", preTranslateY.getVariable());
        inputMap.add("Rotation (degrees)", rotation.getVariable());
        inputMap.add("Post-translation X", postTranslateX.getVariable());
        inputMap.add("Post-translation Y", postTranslateY.getVariable());
    }

    /**
     * Declares the protocol output: the registered sequence (set in Apply mode only).
     */
    @Override
    public void declareOutput(VarList outputMap)
    {
        outputMap.add("Registered Sequence", outputSeq.getVariable());
    }

    /**
     * Runs the selected mode. In headless mode the mode is taken from {@code calculateModeOn}.
     */
    @Override
    protected void execute()
    {
        if (isHeadLess())
        {
            // Protocols have no mode selector; the boolean input chooses the mode instead.
            mode.setValue(calculateModeOn.getValue() ? Mode.Calculate : Mode.Apply);
        }

        if (mode.getValue() == Mode.Calculate)
        {
            calculateParameters();
        }
        else if (mode.getValue() == Mode.Apply)
        {
            applyParameters();
        }
    }

    /**
     * Measures pre-translation, rotation and post-translation of {@link #MOVING_CHANNEL} relative to
     * {@link #REFERENCE_CHANNEL}. The values are shown in a dialog (GUI only) and written into the
     * Apply-mode fields. In headless mode they are only written into those input variables; no
     * output is declared for them.
     */
    private void calculateParameters()
    {
        // In a protocol the only declared sequence input is applyTo.
        Sequence s = isHeadLess() ? applyTo.getValue(true) : seq.getValue(true);
        requireTwoChannels(s);

        Sequence refS = SequenceUtil.extractChannel(s, REFERENCE_CHANNEL);
        Sequence tarS = SequenceUtil.extractChannel(s, MOVING_CHANNEL);

        // 1) Coarse translation, then undo it on a two-channel working copy.
        Vector2d preTranslation = getTranslation(tarS, refS);
        Sequence work = SequenceUtil.concatC(new Sequence[] { refS, tarS });
        applyTranslation(work, new Vector2d(-preTranslation.x, -preTranslation.y));

        // 2) Rotation (radians), measured on the translated channels, then undone.
        refS = SequenceUtil.extractChannel(work, REFERENCE_CHANNEL);
        tarS = SequenceUtil.extractChannel(work, MOVING_CHANNEL);
        double rotationAngle = getRotation(tarS, refS);
        applyRotation(work, -rotationAngle);

        // 3) Residual translation left after the rotation.
        refS = SequenceUtil.extractChannel(work, REFERENCE_CHANNEL);
        tarS = SequenceUtil.extractChannel(work, MOVING_CHANNEL);
        Vector2d postTranslation = getTranslation(tarS, refS);

        // Corrections are the opposite of the measured shifts. "0.0 - v" negates while mapping
        // 0.0 to +0.0, so the dialog never shows "-0.000".
        double preX = 0.0 - preTranslation.x;
        double preY = 0.0 - preTranslation.y;
        double rotationDegrees = Math.toDegrees(0.0 - rotationAngle);
        double postX = 0.0 - postTranslation.x;
        double postY = 0.0 - postTranslation.y;

        // Show exactly the values entered in the Apply fields, so they can be copied into a protocol.
        if (!isHeadLess())
        {
            String str = String.format(
                    "Pre-translation X: %.3f \n Pre-translation Y: %.3f \n Rotation (degrees): %.3f \n Post-translation X: %.3f \n Post-translation Y: %.3f \n",
                    preX, preY, rotationDegrees, postX, postY);
            MessageDialog.showDialog(str);
        }

        preTranslateX.setValue(preX);
        preTranslateY.setValue(preY);
        rotation.setValue(rotationDegrees);
        postTranslateX.setValue(postX);
        postTranslateY.setValue(postY);
    }

    /**
     * Applies pre-translation, rotation and post-translation to {@link #MOVING_CHANNEL} of the
     * input sequence, either in place (GUI option) or on a copy, and publishes the result.
     */
    private void applyParameters()
    {
        Sequence s = applyTo.getValue(true);
        requireTwoChannels(s);

        Sequence outputCopy;
        if (!isHeadLess() && applyInPlace.getValue())
        {
            outputCopy = s;
        }
        else
        {
            outputCopy = SequenceUtil.getCopy(s);
            if (!isHeadLess()) addSequence(outputCopy);
        }

        // In headless mode preserveSize is not a block input, so its default (true) applies.
        boolean keepSize = preserveSize.getValue();

        Vector2d preTranslation = new Vector2d(preTranslateX.getValue(), preTranslateY.getValue());
        double rotationAngle = Math.toRadians(rotation.getValue());
        Vector2d postTranslation = new Vector2d(postTranslateX.getValue(), postTranslateY.getValue());

        applyTranslation(outputCopy, preTranslation, keepSize);
        applyRotation(outputCopy, rotationAngle, keepSize);
        applyTranslation(outputCopy, postTranslation, keepSize);

        outputSeq.setValue(outputCopy);
    }

    /** Nothing to release: execute() keeps no state that outlives a run. */
    @Override
    public void clean()
    {
    }

    /**
     * Fails fast with a clear message when the sequence has no moving channel.
     *
     * @throws IllegalArgumentException if {@code s} has fewer than 2 channels
     */
    private static void requireTwoChannels(Sequence s)
    {
        if (s.getSizeC() <= MOVING_CHANNEL)
        {
            throw new IllegalArgumentException("Manual registration needs at least 2 channels: channel "
                    + MOVING_CHANNEL + " is registered onto channel " + REFERENCE_CHANNEL
                    + ", but the sequence has " + s.getSizeC() + " channel(s)");
        }
    }

    /*
     * The following 'under-the-hood' methods are
     * mostly copied from adufour's "Rigid Registration" plug-in.
     */

    // ---- Calculating functions ----

    /**
     * Throws if the two sequences do not have the same number of time points and slices.
     */
    private static void requireSameTZ(Sequence targetSeq, Sequence referenceSeq)
    {
        if (referenceSeq.getSizeT() != targetSeq.getSizeT() || referenceSeq.getSizeZ() != targetSeq.getSizeZ())
        {
            throw new IllegalArgumentException("Source and target sequences have different (Z,T) dimensions");
        }
    }

    /**
     * Estimates the translation between two sequences by FFT-based cross-correlation of channel 0,
     * averaged over every (T, Z) plane.
     *
     * @param targetSeq    sequence to register
     * @param referenceSeq reference sequence
     * @return mean of the per-plane results of {@link #findTranslation2D}
     * @throws IllegalArgumentException if the sequences have different T or Z sizes
     */
    public static Vector2d getTranslation(Sequence targetSeq, Sequence referenceSeq)
    {
        requireSameTZ(targetSeq, referenceSeq);
        int sizeT = referenceSeq.getSizeT();
        int sizeZ = referenceSeq.getSizeZ();

        Vector2d vector = new Vector2d();
        for (int t = 0; t < sizeT; t++)
            for (int z = 0; z < sizeZ; z++)
            {
                IcyBufferedImage srcImg = referenceSeq.getImage(t, z);
                IcyBufferedImage tgtImg = targetSeq.getImage(t, z);
                vector.add(findTranslation2D(srcImg, 0, tgtImg, 0));
            }
        vector.scale(1.0 / (sizeT * sizeZ));
        return vector;
    }

    /**
     * Finds the integer shift between one channel of two equally sized images from the peak of
     * their cross-correlation.
     *
     * @return the negated peak position, each coordinate wrapped into (-size/2, size/2]
     * @throws UnsupportedOperationException if the images have different bounds
     */
    public static Vector2d findTranslation2D(IcyBufferedImage source, int sourceC, IcyBufferedImage target, int targetC)
    {
        if (!source.getBounds().equals(target.getBounds()))
            throw new UnsupportedOperationException("Cannot register images of different size (yet)");

        int width = source.getWidth();
        int height = source.getHeight();

        float[] _source = Array1DUtil.arrayToFloatArray(source.getDataXY(sourceC), source.isSignedDataType());
        float[] _target = Array1DUtil.arrayToFloatArray(target.getDataXY(targetC), target.isSignedDataType());

        float[] correlationMap = spectralCorrelation(_source, _target, width, height);

        // Position of maximum correlation
        int argMax = argMax(correlationMap, correlationMap.length);

        int transX = argMax % width;
        int transY = argMax / width;

        // The correlation is circular: large positive shifts are really negative ones.
        if (transX > width / 2) transX -= width;
        if (transY > height / 2) transY -= height;

        return new Vector2d(-transX, -transY);
    }

    /** Cross-correlates two real width x height images using a freshly created FFT. */
    private static float[] spectralCorrelation(float[] a1, float[] a2, int width, int height)
    {
        // JTransforms's FFT takes dimensions as (rows, columns)
        FloatFFT_2D fft = new FloatFFT_2D(height, width);

        return spectralCorrelation(a1, a2, width, height, fft);
    }

    /**
     * Computes IFFT(FFT(a1) * conj(FFT(a2))), i.e. the circular cross-correlation of a1 and a2
     * (unnormalized), and returns its real part.
     */
    private static float[] spectralCorrelation(float[] a1, float[] a2, int width, int height, FloatFFT_2D fft)
    {
        float[] sourceFFT = forwardFFT(a1, fft);
        float[] targetFFT = forwardFFT(a2, fft);

        Complex c1 = new Complex(), c2 = new Complex();
        for (int i = 0; i < sourceFFT.length; i += 2)
        {
            c1.setReal(sourceFFT[i]);
            c1.setImag(sourceFFT[i + 1]);

            c2.setReal(targetFFT[i]);
            c2.setImag(targetFFT[i + 1]);

            // correlate c1 and c2 (no need to normalize: only the peak position is used)
            c1.timesEquals(c2.conjugate());

            sourceFFT[i] = (float) c1.getReal();
            sourceFFT[i + 1] = (float) c1.getImag();
        }

        return inverseFFT(sourceFFT, fft);
    }

    /**
     * Forward complex FFT of real data.
     *
     * @return a new array of length 2 * realData.length with interleaved (re, im) values
     */
    private static float[] forwardFFT(float[] realData, FloatFFT_2D fft)
    {
        float[] out = new float[realData.length * 2];

        // Real values go to even indices; imaginary parts stay 0.
        for (int i = 0, j = 0; i < realData.length; i++, j += 2)
            out[j] = realData[i];

        fft.complexForward(out);
        return out;
    }

    /**
     * Applies a scaled inverse FFT to interleaved complex data.
     *
     * @param cplxData the complex data to transform (overwritten in place)
     * @param fft      an FFT object to perform the transform
     * @return the real parts of the inverse transform
     */
    private static float[] inverseFFT(float[] cplxData, FloatFFT_2D fft)
    {
        float[] out = new float[cplxData.length / 2];

        fft.complexInverse(cplxData, true);

        // Keep real parts only (even indices).
        for (int i = 0, j = 0; i < cplxData.length; i += 2, j++)
            out[j] = cplxData[i];

        return out;
    }

    /** Index of the first maximum among array[0..n-1]. */
    private static int argMax(float[] array, int n)
    {
        int argMax = 0;
        float max = array[0];
        for (int i = 1; i < n; i++)
        {
            float val = array[i];
            if (val > max)
            {
                max = val;
                argMax = i;
            }
        }
        return argMax;
    }

    /**
     * Estimates the rotation between two sequences (channel 0), averaged over every (T, Z) plane.
     *
     * @return the mean angle in radians, as returned by {@link #findRotation2D}
     * @throws IllegalArgumentException if the sequences have different T or Z sizes
     */
    public static double getRotation(Sequence targetSeq, Sequence referenceSeq)
    {
        requireSameTZ(targetSeq, referenceSeq);
        int sizeT = referenceSeq.getSizeT();
        int sizeZ = referenceSeq.getSizeZ();

        double angle = 0;
        for (int t = 0; t < sizeT; t++)
            for (int z = 0; z < sizeZ; z++)
            {
                IcyBufferedImage srcImg = referenceSeq.getImage(t, z);
                IcyBufferedImage tgtImg = targetSeq.getImage(t, z);
                angle += findRotation2D(srcImg, 0, tgtImg, 0);
            }
        angle /= (sizeT * sizeZ);
        return angle;
    }

    /**
     * Finds the rotation between one channel of two equally sized images. Both are resampled into
     * (theta, rho) space, where a rotation becomes a shift along theta, and cross-correlated.
     *
     * @return the angle in radians
     * @throws UnsupportedOperationException if the images have different bounds
     */
    public static double findRotation2D(IcyBufferedImage source, int sourceC, IcyBufferedImage target, int targetC)
    {
        if (!source.getBounds().equals(target.getBounds()))
            throw new UnsupportedOperationException("Cannot register images of different size (yet)");

        IcyBufferedImage sourceLogPol = toLogPolar(source.getImage(sourceC));
        IcyBufferedImage targetLogPol = toLogPolar(target.getImage(targetC));

        int width = sourceLogPol.getWidth(), height = sourceLogPol.getHeight();

        float[] _sourceLogPol = sourceLogPol.getDataXYAsFloat(0);
        float[] _targetLogPol = targetLogPol.getDataXYAsFloat(0);

        float[] correlationMap = spectralCorrelation(_sourceLogPol, _targetLogPol, width, height);

        // Only the first half of the rows (the smaller rho shifts) is searched for the peak.
        int argMax = argMax(correlationMap, correlationMap.length / 2);

        // rotation is given along the X (theta) axis
        int rotX = argMax % width;

        if (rotX > width / 2) rotX -= width;

        return -rotX * 2 * Math.PI / width;
    }

    /** Resamples the image around its centre into 1080 angular sectors x 360 rings. */
    private static IcyBufferedImage toLogPolar(IcyBufferedImage image)
    {
        return toLogPolar(image, image.getWidth() / 2, image.getHeight() / 2, 1080, 360);
    }

    /**
     * Resamples every channel into a FLOAT image with X = theta and Y = rho. Despite the name, rho
     * grows linearly (constant step) from the centre to the distance of the image corner, so this
     * is a polar rather than a log-polar transform.
     */
    private static IcyBufferedImage toLogPolar(IcyBufferedImage image, int centerX, int centerY, int sizeTheta, int sizeRho)
    {
        int sizeC = image.getSizeC();

        // theta: number of sectors; pre-compute all sines/cosines
        double theta = 0.0, dtheta = 2 * Math.PI / sizeTheta;
        float[] cosTheta = new float[sizeTheta];
        float[] sinTheta = new float[sizeTheta];
        for (int thetaIndex = 0; thetaIndex < sizeTheta; thetaIndex++, theta += dtheta)
        {
            cosTheta[thetaIndex] = (float) Math.cos(theta);
            sinTheta[thetaIndex] = (float) Math.sin(theta);
        }

        // rho: number of rings
        float drho = (float) (Math.sqrt(centerX * centerX + centerY * centerY) / sizeRho);

        IcyBufferedImage logPol = new IcyBufferedImage(sizeTheta, sizeRho, sizeC, DataType.FLOAT);

        for (int c = 0; c < sizeC; c++)
        {
            float[] out = logPol.getDataXYAsFloat(c);

            // first ring (rho=0): center value
            Array1DUtil.fill(out, 0, sizeTheta, getPixelValue(image, centerX, centerY, c));

            // Other rings
            float rho = drho;
            int outOffset = sizeTheta;
            for (int rhoIndex = 1; rhoIndex < sizeRho; rhoIndex++, rho += drho)
                for (int thetaIndex = 0; thetaIndex < sizeTheta; thetaIndex++, outOffset++)
                {
                    double x = centerX + rho * cosTheta[thetaIndex];
                    double y = centerY + rho * sinTheta[thetaIndex];
                    out[outOffset] = (float) getPixelValue(image, x, y, c);
                }
        }

        logPol.updateChannelsBounds();
        return logPol;
    }

    /**
     * Bilinearly interpolated value of channel c at (x, y), with pixel centres at +0.5.
     *
     * @return the interpolated value, or 0 when the 2x2 neighbourhood leaves the image
     */
    private static float getPixelValue(IcyBufferedImage img, double x, double y, int c)
    {
        int width = img.getWidth();
        int height = img.getHeight();
        Object data = img.getDataXY(c);
        DataType type = img.getDataType_();

        // "center" the coordinates to the center of the pixel
        x -= 0.5;
        y -= 0.5;

        int i = (int) Math.floor(x);
        int j = (int) Math.floor(y);

        // The 2x2 neighbourhood (i..i+1, j..j+1) must lie inside the image: 0 <= i <= width-2.
        if (i < 0 || i >= width - 1 || j < 0 || j >= height - 1) return 0f;

        float value = 0;

        final int offset = i + j * width;
        final int offset_plus_1 = offset + 1;

        x -= i;
        y -= j;

        final double mx = 1 - x;
        final double my = 1 - y;

        value += mx * my * Array1DUtil.getValueAsFloat(data, offset, type);
        value += x * my * Array1DUtil.getValueAsFloat(data, offset_plus_1, type);
        value += mx * y * Array1DUtil.getValueAsFloat(data, offset + width, type);
        value += x * y * Array1DUtil.getValueAsFloat(data, offset_plus_1 + width, type);

        return value;
    }

    // ---- Applying functions ----

    /** A transformation applied to one (T, Z) plane. */
    private interface PlaneTransform
    {
        IcyBufferedImage apply(IcyBufferedImage plane);
    }

    /**
     * Replaces every (T, Z) plane of target by its transformed version.
     *
     * @param sameSize true if the transform keeps the plane size; planes can then be swapped one at a
     *                 time. Otherwise all new planes are computed first and the sequence is emptied
     *                 and refilled, because a Sequence rejects a plane whose size differs from the others.
     */
    private static void transformPlanes(Sequence target, PlaneTransform transform, boolean sameSize)
    {
        int sizeT = target.getSizeT();
        int sizeZ = target.getSizeZ();

        if (sameSize)
        {
            for (int t = 0; t < sizeT; t++)
                for (int z = 0; z < sizeZ; z++)
                    target.setImage(t, z, transform.apply(target.getImage(t, z)));
            return;
        }

        IcyBufferedImage[][] planes = new IcyBufferedImage[sizeT][sizeZ];
        for (int t = 0; t < sizeT; t++)
            for (int z = 0; z < sizeZ; z++)
                planes[t][z] = transform.apply(target.getImage(t, z));

        target.beginUpdate();
        try
        {
            target.removeAllImages();
            for (int t = 0; t < sizeT; t++)
                for (int z = 0; z < sizeZ; z++)
                    target.setImage(t, z, planes[t][z]);
        }
        finally
        {
            target.endUpdate();
        }
    }

    /**
     * Translates channel 1 of every plane of target in place, keeping the image size.
     * Equivalent to {@code applyTranslation(target, translation, true)}.
     */
    public static void applyTranslation(Sequence target, Vector2d translation)
    {
        applyTranslation(target, translation, true);
    }

    /**
     * Translates channel 1 of every plane of target in place (shift rounded to whole pixels).
     *
     * @param preserveImageSize if false, planes grow by |dx| x |dy| so that no data is cropped
     */
    public static void applyTranslation(Sequence target, final Vector2d translation, final boolean preserveImageSize)
    {
        if (translation.lengthSquared() == 0.0) return;

        transformPlanes(target, new PlaneTransform()
        {
            @Override
            public IcyBufferedImage apply(IcyBufferedImage plane)
            {
                return applyTranslation2D(plane, MOVING_CHANNEL, translation, preserveImageSize);
            }
        }, preserveImageSize);
    }

    /**
     * Shifts one channel (or all channels when channel == -1) of an image by the rounded vector.
     *
     * @return a new image, or the input image itself if the rounded shift is (0, 0)
     */
    public static IcyBufferedImage applyTranslation2D(IcyBufferedImage image, int channel, Vector2d vector, boolean preserveImageSize)
    {
        int dx = (int) Math.round(vector.x);
        int dy = (int) Math.round(vector.y);

        if (dx == 0 && dy == 0) return image;

        // Canvas large enough to hold both the shifted and the unshifted channels.
        Rectangle newSize = image.getBounds();
        newSize.width += Math.abs(dx);
        newSize.height += Math.abs(dy);

        Point dstPoint_shiftedChannel = new Point(Math.max(0, dx), Math.max(0, dy));
        Point dstPoint_otherChannels = new Point(Math.max(0, -dx), Math.max(0, -dy));

        IcyBufferedImage newImage = new IcyBufferedImage(newSize.width, newSize.height, image.getSizeC(), image.getDataType_());
        for (int c = 0; c < image.getSizeC(); c++)
        {
            Point dstPoint = (channel == -1 || c == channel) ? dstPoint_shiftedChannel : dstPoint_otherChannels;
            newImage.copyData(image, null, dstPoint, c, c);
        }

        if (preserveImageSize)
        {
            // Crop back to the original size, aligned on the unshifted channels.
            newSize = image.getBounds();
            newSize.x = Math.max(0, -dx);
            newSize.y = Math.max(0, -dy);

            return IcyBufferedImageUtil.getSubImage(newImage, newSize);
        }
        return newImage;
    }

    /**
     * Rotates channel 1 of every plane of target in place, keeping the image size.
     * Equivalent to {@code applyRotation(target, rotation, true)}.
     */
    public static void applyRotation(Sequence target, double rotation)
    {
        applyRotation(target, rotation, true);
    }

    /**
     * Rotates channel 1 of every plane of target in place.
     *
     * @param rotation          angle in radians
     * @param preserveImageSize if false, planes are enlarged to the rotated bounding box
     */
    public static void applyRotation(Sequence target, final double rotation, final boolean preserveImageSize)
    {
        if (rotation == 0.0) return;

        transformPlanes(target, new PlaneTransform()
        {
            @Override
            public IcyBufferedImage apply(IcyBufferedImage plane)
            {
                return applyRotation2D(plane, MOVING_CHANNEL, rotation, preserveImageSize);
            }
        }, preserveImageSize);
    }

    /**
     * Rotates one channel of an image by {@code angle} radians using IcyBufferedImageUtil.rotate.
     * Other channels are copied unchanged (centred on the larger canvas when the size is not kept).
     * NOTE(review): the channel == -1 path passes -1 to img.getImage(); confirm Icy's semantics
     * for that before relying on it.
     *
     * @return a new image, or img itself when angle == 0
     */
    public static IcyBufferedImage applyRotation2D(IcyBufferedImage img, int channel, double angle, boolean preserveImageSize)
    {
        if (angle == 0.0) return img;

        // start with the rotation to calculate the largest bounds
        IcyBufferedImage rotImg = IcyBufferedImageUtil.rotate(img.getImage(channel), angle);

        // half the growth in each dimension, used to re-centre
        Rectangle oldSize = img.getBounds();
        Rectangle newSize = rotImg.getBounds();
        int dw = (newSize.width - oldSize.width) / 2;
        int dh = (newSize.height - oldSize.height) / 2;

        if (channel == -1 || img.getSizeC() == 1)
        {
            if (preserveImageSize)
            {
                oldSize.translate(dw, dh);
                return IcyBufferedImageUtil.getSubImage(rotImg, oldSize);
            }
            return rotImg;
        }

        IcyBufferedImage[] newImages = new IcyBufferedImage[img.getSizeC()];

        if (preserveImageSize)
        {
            for (int c = 0; c < newImages.length; c++)
            {
                if (c == channel)
                {
                    // crop the rotated channel back to the original size
                    oldSize.translate(dw, dh);
                    newImages[c] = IcyBufferedImageUtil.getSubImage(rotImg, oldSize);
                }
                else
                {
                    newImages[c] = img.getImage(c);
                }
            }
        }
        else
        {
            for (int c = 0; c < newImages.length; c++)
            {
                if (c != channel)
                {
                    // enlarge and center the non-rotated channels
                    newImages[c] = new IcyBufferedImage(newSize.width, newSize.height, 1, img.getDataType_());
                    newImages[c].copyData(img.getImage(c), null, new Point(dw, dh));
                }
                else
                {
                    newImages[channel] = rotImg;
                }
            }
        }

        return IcyBufferedImage.createFrom(Arrays.asList(newImages));
    }
}
