package plugins.adufour.projection;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import icy.image.IcyBufferedImage;
import icy.math.ArrayMath;
import icy.roi.ROI;
import icy.sequence.Sequence;
import icy.sequence.SequenceUtil;
import icy.system.SystemUtil;
import icy.type.DataType;
import icy.type.collection.array.Array1DUtil;
import plugins.adufour.blocks.lang.Block;
import plugins.adufour.blocks.util.VarList;
import plugins.adufour.ezplug.EzPlug;
import plugins.adufour.ezplug.EzStoppable;
import plugins.adufour.ezplug.EzVarBoolean;
import plugins.adufour.ezplug.EzVarEnum;
import plugins.adufour.ezplug.EzVarSequence;
import plugins.adufour.vars.lang.VarSequence;

/**
 * Intensity projection of a sequence along the Z or T dimension (maximum, average, median,
 * minimum, standard deviation or saturated sum). Usable both as an EzPlug (interactive) and as a
 * Block (protocols).
 */
public class Projection extends EzPlug implements Block, EzStoppable
{
    /** Dimension along which the intensities are projected. */
    public enum ProjectionDirection
    {
        Z, T
    }
    
    /** Reduction applied to the intensity values collected along the projection direction. */
    public enum ProjectionType
    {
        MAX("Maximum"), MEAN("Average"), MED("Median"), MIN("Minimum"), STD("Standard Deviation"), SATSUM("Saturated Sum");
        
        private final String description;
        
        ProjectionType(String description)
        {
            this.description = description;
        }
        
        /**
         * @return the human-readable description of this projection type
         */
        @Override
        public String toString()
        {
            return description;
        }
    }
    
    /** Sequence to project. */
    private final EzVarSequence input = new EzVarSequence("Input");
    
    /** Dimension to project along (Z by default). */
    private final EzVarEnum<ProjectionDirection> projectionDir = new EzVarEnum<Projection.ProjectionDirection>("Project along", ProjectionDirection.values(),
            ProjectionDirection.Z);
            
    /** Reduction to apply (maximum by default). */
    private final EzVarEnum<ProjectionType> projectionType = new EzVarEnum<Projection.ProjectionType>("Projection type", ProjectionType.values(), ProjectionType.MAX);
    
    /** Whether only the data located within the sequence ROI should be projected. */
    private final EzVarBoolean restrictToROI = new EzVarBoolean("Restrict to ROI", false);
    
    /** Result of the last projection (exposed as the block output). */
    private final VarSequence output = new VarSequence("projected sequence", null);
    
    @Override
    protected void initialize()
    {
        addEzComponent(input);
        addEzComponent(projectionDir);
        addEzComponent(projectionType);
        
        restrictToROI.setToolTipText("Check this option to project only the intensity data contained within the sequence ROI");
        addEzComponent(restrictToROI);
        
        setTimeDisplay(true);
    }
    
    @Override
    protected void execute()
    {
        switch (projectionDir.getValue())
        {
        case T:
            output.setValue(tProjection(input.getValue(true), projectionType.getValue(), true, restrictToROI.getValue()));
            break;
        case Z:
            output.setValue(zProjection(input.getValue(true), projectionType.getValue(), true, restrictToROI.getValue()));
            break;
        default:
            throw new UnsupportedOperationException("Projection along " + projectionDir.getValue() + " not supported");
        }
        
        // NOTE(review): getUI() is presumably null when running headless / inside a protocol, in
        // which case the result is only exposed through the block output
        if (getUI() != null) addSequence(output.getValue());
    }
    
    @Override
    public void clean()
    {
        // nothing to release
    }
    
    /**
     * Performs a Z projection of the entire input sequence using the specified algorithm. If the
     * sequence has a single Z slice, then a copy of the sequence is returned.
     * 
     * @param in
     *            the sequence to project
     * @param projection
     *            the type of projection to perform (see {@link ProjectionType} enumeration)
     * @param multiThread
     *            true if the process should be multi-threaded
     * @return the projected sequence
     */
    public static Sequence zProjection(final Sequence in, final ProjectionType projection, boolean multiThread)
    {
        return zProjection(in, projection, multiThread, false);
    }
    
    /**
     * Performs a Z projection of the input sequence using the specified algorithm. If the sequence
     * has a single Z slice and <code>restrictToROI</code> is <code>false</code>, then a copy of the
     * sequence is returned.
     * <p>
     * If the calling thread is interrupted, its interrupt flag is restored and the (partially
     * computed) projection is returned.
     * 
     * @param in
     *            the sequence to project
     * @param projection
     *            the type of projection to perform (see {@link ProjectionType} enumeration)
     * @param multiThread
     *            true if the process should be multi-threaded
     * @param restrictToROI
     *            <code>true</code> projects only data located within the sequence ROI (output
     *            pixels that lie outside every ROI on every slice are not written),
     *            <code>false</code> projects the entire data set. If the sequence has no ROI, the
     *            entire data set is projected.
     * @return the projected sequence (a single Z slice per time point)
     * @throws RuntimeException
     *             if one of the projection tasks failed
     */
    public static Sequence zProjection(final Sequence in, final ProjectionType projection, boolean multiThread, boolean restrictToROI)
    {
        final int depth = in.getSizeZ();
        if (depth == 1 && !restrictToROI) return SequenceUtil.getCopy(in);
        
        final Sequence out = new Sequence(projection.name() + " projection of " + in.getName());
        out.copyMetaDataFrom(in, false);
        
        final int width = in.getSizeX();
        final int height = in.getSizeY();
        final int frames = in.getSizeT();
        final int channels = in.getSizeC();
        final DataType dataType = in.getDataType_();
        
        final Collection<ROI> rois = in.getROISet();
        final boolean processROI = restrictToROI && rois.size() > 0;
        
        // each XY plane is split into one chunk of contiguous pixel offsets per CPU
        final int nbPixels = width * height;
        final int cpus = SystemUtil.getNumberOfCPUs();
        final int chunkSize = nbPixels / cpus;
        
        ExecutorService service = multiThread ? Executors.newFixedThreadPool(cpus) : Executors.newSingleThreadExecutor();
        try
        {
            Collection<Future<?>> futures = new ArrayList<Future<?>>(channels * frames * cpus);
            
            for (int frame = 0; frame < frames; frame++)
            {
                final int t = frame;
                
                if (Thread.currentThread().isInterrupted()) break;
                
                out.setImage(t, 0, new IcyBufferedImage(width, height, channels, dataType));
                
                for (int channel = 0; channel < channels; channel++)
                {
                    final int c = channel;
                    final Object[] in_Z_XY = (Object[]) in.getDataXYZ(t, c);
                    final Object out_XY = out.getDataXY(t, 0, c);
                    
                    for (int cpu = 0; cpu < cpus; cpu++)
                    {
                        final int minOffset = chunkSize * cpu;
                        // the last chunk also takes the pixels left over by the integer division
                        final int maxOffset = (cpu == cpus - 1) ? nbPixels : minOffset + chunkSize;
                        
                        futures.add(service.submit(new Runnable()
                        {
                            @Override
                            public void run()
                            {
                                double[] buffer = new double[depth];
                                
                                for (int offset = minOffset; offset < maxOffset; offset++)
                                {
                                    double[] dataToProject;
                                    
                                    if (processROI)
                                    {
                                        final int x = offset % width;
                                        final int y = offset / width;
                                        int nbValues = 0;
                                        
                                        // keep only the slices where this pixel lies in at least one ROI
                                        for (int z = 0; z < depth; z++)
                                            for (ROI roi : rois)
                                                if (roi.contains(x, y, z, t, c))
                                                {
                                                    buffer[nbValues++] = Array1DUtil.getValue(in_Z_XY[z], offset, dataType);
                                                    break;
                                                }
                                        
                                        // pixel outside every ROI: leave the output pixel untouched
                                        if (nbValues == 0) continue;
                                        
                                        dataToProject = (nbValues == buffer.length) ? buffer : Arrays.copyOf(buffer, nbValues);
                                    }
                                    else
                                    {
                                        for (int z = 0; z < depth; z++)
                                            buffer[z] = Array1DUtil.getValue(in_Z_XY[z], offset, dataType);
                                        dataToProject = buffer;
                                    }
                                    
                                    Array1DUtil.setValue(out_XY, offset, dataType, project(dataToProject, projection, dataType));
                                }
                            }
                        }));
                    }
                }
            }
            
            awaitCompletion(service, futures);
        }
        finally
        {
            // always release the worker threads, even if a task failed
            service.shutdown();
        }
        
        copyColorMaps(in, out);
        out.dataChanged();
        
        return out;
    }
    
    /**
     * Performs a T projection of the entire input sequence using the specified algorithm. If the
     * sequence has only one time point, then a copy of the sequence is returned.
     * 
     * @param in
     *            the sequence to project
     * @param projection
     *            the type of projection to perform (see {@link ProjectionType} enumeration)
     * @param multiThread
     *            true if the process should be multi-threaded
     * @return the projected sequence
     */
    public static Sequence tProjection(final Sequence in, final ProjectionType projection, boolean multiThread)
    {
        return tProjection(in, projection, multiThread, false);
    }
    
    /**
     * Performs a T projection of the input sequence using the specified algorithm. If the sequence
     * has only one time point and <code>restrictToROI</code> is <code>false</code>, then a copy of
     * the sequence is returned.
     * <p>
     * If the calling thread is interrupted, its interrupt flag is restored and the (partially
     * computed) projection is returned.
     * 
     * @param in
     *            the sequence to project
     * @param projection
     *            the type of projection to perform (see {@link ProjectionType} enumeration)
     * @param multiThread
     *            true if the process should be multi-threaded
     * @param restrictToROI
     *            <code>true</code> projects only data located within the sequence ROI (output
     *            pixels that lie outside every ROI at every time point are not written),
     *            <code>false</code> projects the entire data set. If the sequence has no ROI, the
     *            entire data set is projected.
     * @return the projected sequence (a single time point, with all Z slices)
     * @throws RuntimeException
     *             if one of the projection tasks failed
     */
    public static Sequence tProjection(final Sequence in, final ProjectionType projection, boolean multiThread, boolean restrictToROI)
    {
        final int frames = in.getSizeT();
        if (frames == 1 && !restrictToROI) return SequenceUtil.getCopy(in);
        
        final Sequence out = new Sequence(projection.name() + " projection of " + in.getName());
        out.copyMetaDataFrom(in, false);
        
        final int width = in.getSizeX();
        final int height = in.getSizeY();
        final int depth = in.getSizeZ();
        final int channels = in.getSizeC();
        final DataType dataType = in.getDataType_();
        
        final Collection<ROI> rois = in.getROISet();
        final boolean processROI = restrictToROI && rois.size() > 0;
        
        // each XY plane is split into one chunk of contiguous pixel offsets per CPU
        final int nbPixels = width * height;
        final int cpus = SystemUtil.getNumberOfCPUs();
        final int chunkSize = nbPixels / cpus;
        
        // fetch the per-channel [T][Z][XY] data once, rather than once per task
        final Object[][][] in_C_T_Z_XY = new Object[channels][][];
        for (int c = 0; c < channels; c++)
            in_C_T_Z_XY[c] = (Object[][]) in.getDataXYZT(c);
        
        ExecutorService service = multiThread ? Executors.newFixedThreadPool(cpus) : Executors.newSingleThreadExecutor();
        try
        {
            Collection<Future<?>> futures = new ArrayList<Future<?>>(channels * depth * cpus);
            
            for (int slice = 0; slice < depth; slice++)
            {
                if (Thread.currentThread().isInterrupted()) break;
                
                final int z = slice;
                
                out.setImage(0, z, new IcyBufferedImage(width, height, channels, dataType));
                
                for (int channel = 0; channel < channels; channel++)
                {
                    final int c = channel;
                    final Object[][] in_T_Z_XY = in_C_T_Z_XY[c];
                    final Object out_XY = out.getDataXY(0, z, c);
                    
                    for (int cpu = 0; cpu < cpus; cpu++)
                    {
                        final int minOffset = chunkSize * cpu;
                        // the last chunk also takes the pixels left over by the integer division
                        final int maxOffset = (cpu == cpus - 1) ? nbPixels : minOffset + chunkSize;
                        
                        futures.add(service.submit(new Runnable()
                        {
                            @Override
                            public void run()
                            {
                                double[] buffer = new double[frames];
                                
                                for (int offset = minOffset; offset < maxOffset; offset++)
                                {
                                    double[] dataToProject;
                                    
                                    if (processROI)
                                    {
                                        final int x = offset % width;
                                        final int y = offset / width;
                                        int nbValues = 0;
                                        
                                        // keep only the time points where this pixel lies in at least one ROI
                                        for (int t = 0; t < frames; t++)
                                            for (ROI roi : rois)
                                                if (roi.contains(x, y, z, t, c))
                                                {
                                                    buffer[nbValues++] = Array1DUtil.getValue(in_T_Z_XY[t][z], offset, dataType);
                                                    break;
                                                }
                                        
                                        // pixel outside every ROI: leave the output pixel untouched
                                        if (nbValues == 0) continue;
                                        
                                        dataToProject = (nbValues == buffer.length) ? buffer : Arrays.copyOf(buffer, nbValues);
                                    }
                                    else
                                    {
                                        for (int t = 0; t < frames; t++)
                                            buffer[t] = Array1DUtil.getValue(in_T_Z_XY[t][z], offset, dataType);
                                        dataToProject = buffer;
                                    }
                                    
                                    Array1DUtil.setValue(out_XY, offset, dataType, project(dataToProject, projection, dataType));
                                }
                            }
                        }));
                    }
                }
            }
            
            awaitCompletion(service, futures);
        }
        finally
        {
            // always release the worker threads, even if a task failed
            service.shutdown();
        }
        
        copyColorMaps(in, out);
        out.dataChanged();
        
        return out;
    }
    
    /**
     * Reduces the given intensity values to a single value according to the projection type.
     * 
     * @param values
     *            the intensity values to reduce (non-empty)
     * @param projection
     *            the reduction to apply
     * @param dataType
     *            the data type of the output, used to saturate {@link ProjectionType#SATSUM}
     * @return the projected value
     * @throws UnsupportedOperationException
     *             if the projection type is not implemented
     */
    private static double project(double[] values, ProjectionType projection, DataType dataType)
    {
        switch (projection)
        {
        case MAX:
            return ArrayMath.max(values);
        case MEAN:
            return ArrayMath.mean(values);
        case MED:
            return ArrayMath.median(values, false);
        case MIN:
            return ArrayMath.min(values);
        case STD:
            return ArrayMath.std(values, true);
        case SATSUM:
            // saturate at both ends of the type range so that signed integer data cannot wrap around
            return Math.max(dataType.getMinValue(), Math.min(ArrayMath.sum(values), dataType.getMaxValue()));
        default:
            throw new UnsupportedOperationException(projection + " intensity projection not implemented");
        }
    }
    
    /**
     * Waits for all submitted projection tasks to complete. If the calling thread is interrupted
     * while waiting, the tasks that have not started yet are discarded and the interrupt flag is
     * restored. If a task failed, the pending tasks are discarded and the failure is rethrown.
     * 
     * @param service
     *            the executor the tasks were submitted to
     * @param futures
     *            the submitted tasks
     * @throws RuntimeException
     *             wrapping the {@link ExecutionException} of the first failed task
     */
    private static void awaitCompletion(ExecutorService service, Collection<Future<?>> futures)
    {
        try
        {
            for (Future<?> future : futures)
                future.get();
        }
        catch (InterruptedException iE)
        {
            service.shutdownNow();
            Thread.currentThread().interrupt();
        }
        catch (ExecutionException eE)
        {
            service.shutdownNow();
            throw new RuntimeException(eE);
        }
    }
    
    /**
     * Copies the color map of every channel of <code>in</code> into <code>out</code>.
     * 
     * @param in
     *            the source sequence
     * @param out
     *            the projected sequence
     */
    private static void copyColorMaps(Sequence in, Sequence out)
    {
        // NOTE(review): the color model may be missing if no image was added to the output
        // (e.g. interrupted before the first plane) - guard against it
        if (out.getColorModel() == null) return;
        
        for (int c = 0; c < in.getSizeC(); c++)
            out.getColorModel().setColorMap(c, in.getColorMap(c), true);
    }
    
    @Override
    public void declareInput(VarList inputMap)
    {
        inputMap.add("input", input.getVariable());
        inputMap.add("projection direction", projectionDir.getVariable());
        inputMap.add("projection type", projectionType.getVariable());
        inputMap.add("restrict to ROI", restrictToROI.getVariable());
    }
    
    @Override
    public void declareOutput(VarList outputMap)
    {
        outputMap.add("projection output", output);
    }
    
}
