package plugins.spop.clahe;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;

import icy.sequence.Sequence;
import icy.sequence.SequenceUtil;
import icy.system.SystemUtil;
import icy.system.thread.Processor;
import icy.type.DataType;
import plugins.adufour.blocks.lang.Block;
import plugins.adufour.blocks.util.VarList;
import plugins.adufour.ezplug.EzPlug;
import plugins.adufour.ezplug.EzStatus;
import plugins.adufour.ezplug.EzStoppable;
import plugins.adufour.ezplug.EzVarDouble;
import plugins.adufour.ezplug.EzVarInteger;
import plugins.adufour.ezplug.EzVarSequence;
import plugins.adufour.ezplug.EzVarText;

/**
 * Applies CLAHE (Contrast Limited Adaptive Histogram Equalization) to a sequence, either on the
 * slice shown in the first viewer ("2D"), on every XY slice independently ("Multi 2D"), or with a
 * 3D neighbourhood ("3D"). The input is converted to unsigned 8-bit data (with rescaling) before
 * processing. Usable both as an EzPlug and as a Protocols {@link Block}.
 */
public class Clahe extends EzPlug implements EzStoppable, Block
{
    /** Sequence to enhance. */
    EzVarSequence input = new EzVarSequence("Input");
    
    /** Processing mode: "2D", "Multi 2D" or "3D". */
    EzVarText type = new EzVarText("Type", new String[] { "2D", "Multi 2D", "3D" }, 2, false);
    /** Histogram resolution: pixel values are mapped to bins 0..bins. */
    EzVarInteger bins_menu = new EzVarInteger("No.bins", 255, 2, 255, 1);
    /** Half size of the neighbourhood along X and Y, in pixels. */
    EzVarInteger half_size = new EzVarInteger("Half size XY", 15, 1, 10000, 1);
    /** Half size of the neighbourhood along Z, in slices (only used in "3D" mode). */
    EzVarInteger half_sizeZ = new EzVarInteger("Half size Z", 3, 1, 63, 1);
    /** Maximum slope: the per-bin clip limit is slope * (window size) / bins. */
    EzVarDouble slope_menu = new EzVarDouble("Maximum slope", 3, 1, 255, 0.1);
    
    /** Enhanced sequence (exposed as the block output). */
    EzVarSequence output = new EzVarSequence("Output");
    
    @Override
    public void initialize()
    {
        addEzComponent(input);
        addEzComponent(type);
        addEzComponent(bins_menu);
        addEzComponent(half_size);
        addEzComponent(half_sizeZ);
        
        addEzComponent(slope_menu);
        
        setTimeDisplay(true);
    }
    
    @Override
    public void execute()
    {
        int blockRadius = half_size.getValue();
        int block_z = half_sizeZ.getValue();
        int bins = bins_menu.getValue();
        float slope = slope_menu.getValue().floatValue();
        
        Sequence source = input.getValue(true);
        // The CLAHE tasks work on byte data: convert (with rescaling) whatever the input type is
        Sequence seq_in = SequenceUtil.convertToType(source, DataType.UBYTE, true);
        
        Sequence out = null;
        String mode = type.getValue();
        
        if (mode.equalsIgnoreCase("2D"))
        {
            out = clahe_run2D(seq_in, getDisplayedZ(source), blockRadius, bins, slope, getStatus());
            if (out != null) out.setName("Clahe 2D");
        }
        else if (mode.equalsIgnoreCase("Multi 2D"))
        {
            out = clahe_run2D(seq_in, -1, blockRadius, bins, slope, getStatus());
            if (out != null) out.setName("Clahe multi2D");
        }
        else if (mode.equalsIgnoreCase("3D"))
        {
            out = clahe_run3D(seq_in, blockRadius, block_z, bins, slope, getStatus());
            if (out != null) out.setName("Clahe 3D");
        }
        
        if (out != null)
        {
            if (getUI() != null) addSequence(out);
            output.setValue(out);
        }
    }
    
    /**
     * Returns the Z position displayed by the first viewer of the given sequence, or 0 when the
     * sequence has no viewer (for instance when the plugin runs without a UI).
     */
    private static int getDisplayedZ(Sequence sequence)
    {
        if (sequence.getFirstViewer() == null) return 0;
        return sequence.getFirstViewer().getPositionZ();
    }
    
    /**
     * Applies 2D CLAHE to XY planes, running one task per (t, z, c) plane.
     * 
     * @param seqIn
     *            input sequence ({@link #execute()} converts it to UBYTE beforehand)
     * @param slice
     *            the slice to process (or -1 to process all slices independently). NB: to process
     *            the whole stack in 3D, use
     *            {@link #clahe_run3D(Sequence, int, int, int, double, EzStatus)}
     * @param blockRadius
     *            half size of the XY neighbourhood, in pixels
     * @param bins
     *            highest histogram bin index
     * @param slope
     *            maximum slope (controls the clip limit)
     * @param status
     *            progress bar to update, or {@code null}
     * @return a processed copy of {@code seqIn} (slices outside the processed range are copied
     *         unchanged), or {@code null} if the calling thread was interrupted
     * @throws IllegalStateException
     *             if a processing task failed (the original failure is attached as the cause)
     */
    private static Sequence clahe_run2D(Sequence seqIn, int slice, int blockRadius, int bins, float slope, EzStatus status)
    {
        Sequence seqOut = SequenceUtil.getCopy(seqIn);
        
        int dim_x = seqIn.getSizeX();
        int dim_y = seqIn.getSizeY();
        int dim_z = seqIn.getSizeZ();
        int dim_t = seqIn.getSizeT();
        int dim_c = seqIn.getSizeC();
        int z_init = 0;
        
        if (slice != -1)
        {
            z_init = slice;
            dim_z = z_init + 1;
        }
        
        int nbTasks = dim_t * (dim_z - z_init) * dim_c;
        List<Future<?>> tasks = new ArrayList<Future<?>>(nbTasks);
        Processor processor = new Processor(SystemUtil.getNumberOfCPUs());
        
        try
        {
            for (int t = 0; t < dim_t; t++)
                for (int z = z_init; z < dim_z; z++)
                    for (int c = 0; c < dim_c; c++)
                    {
                        byte[] srcSlice = seqIn.getDataXYAsByte(t, z, c);
                        byte[] dstSlice = seqOut.getDataXYAsByte(t, z, c);
                        tasks.add(processor.submit(new Clahe2D_task(srcSlice, dstSlice, blockRadius, bins, slope, dim_x, dim_y)));
                    }
            
            awaitTasks(tasks, status);
        }
        catch (InterruptedException e)
        {
            // Keep the interrupted state
            Thread.currentThread().interrupt();
            return null;
        }
        finally
        {
            processor.shutdownNow();
        }
        
        return seqOut;
    }
    
    /**
     * Applies 3D CLAHE: each voxel is equalized using an XY neighbourhood of half size
     * {@code blockRadius} extended over {@code block_z} slices above and below. One task is run
     * per (t, c, z).
     * 
     * @param seqIn
     *            input sequence ({@link #execute()} converts it to UBYTE beforehand)
     * @param blockRadius
     *            half size of the XY neighbourhood, in pixels
     * @param block_z
     *            half size of the neighbourhood along Z, in slices
     * @param bins
     *            highest histogram bin index
     * @param slope
     *            maximum slope (controls the clip limit)
     * @param status
     *            progress bar to update, or {@code null}
     * @return a processed copy of {@code seqIn}, or {@code null} if the calling thread was
     *         interrupted
     * @throws IllegalStateException
     *             if a processing task failed (the original failure is attached as the cause)
     */
    private static Sequence clahe_run3D(Sequence seqIn, int blockRadius, int block_z, int bins, double slope, EzStatus status)
    {
        Sequence seqOut = SequenceUtil.getCopy(seqIn);
        
        int dim_x = seqIn.getSizeX();
        int dim_y = seqIn.getSizeY();
        int dim_z = seqIn.getSizeZ();
        int dim_t = seqIn.getSizeT();
        int dim_c = seqIn.getSizeC();
        
        int nbTasks = dim_t * dim_z * dim_c;
        List<Future<?>> tasks = new ArrayList<Future<?>>(nbTasks);
        Processor processor = new Processor(SystemUtil.getNumberOfCPUs());
        
        try
        {
            for (int t = 0; t < dim_t; t++)
            {
                for (int c = 0; c < dim_c; c++)
                {
                    byte[][] src = seqIn.getDataXYZAsByte(t, c);
                    byte[][] dst = seqOut.getDataXYZAsByte(t, c);
                    
                    for (int z = 0; z < dim_z; z++)
                        tasks.add(processor.submit(new Clahe3D_task(src, dst, blockRadius, block_z, bins, slope, z, dim_x, dim_y, dim_z)));
                }
            }
            
            awaitTasks(tasks, status);
        }
        catch (InterruptedException e)
        {
            // Keep the interrupted state; a partially processed stack is not returned
            Thread.currentThread().interrupt();
            return null;
        }
        finally
        {
            processor.shutdownNow();
        }
        
        return seqOut;
    }
    
    /**
     * Blocks until every submitted task has completed, updating the progress bar along the way.
     * 
     * @param tasks
     *            the submitted tasks
     * @param status
     *            progress bar to update, or {@code null}
     * @throws InterruptedException
     *             if the calling thread is interrupted while waiting
     * @throws IllegalStateException
     *             if a task terminated with an exception (attached as the cause)
     */
    private static void awaitTasks(List<Future<?>> tasks, EzStatus status) throws InterruptedException
    {
        int nbTasks = tasks.size();
        int tasksDone = 0;
        for (Future<?> task : tasks)
        {
            try
            {
                task.get();
            }
            catch (ExecutionException e)
            {
                throw new IllegalStateException("CLAHE processing task failed", e.getCause());
            }
            tasksDone++;
            // Don't refresh the progress bar too often (especially for short-lived tasks)
            if (status != null && tasksDone % 2 == 0) status.setCompletion((double) tasksDone / nbTasks);
        }
    }
    
    @Override
    public void clean()
    {
    }
    
    @Override
    public void declareInput(VarList inputMap)
    {
        inputMap.add("Input", input.getVariable());
        
        inputMap.add("Type", type.getVariable());
        inputMap.add("No.bins", bins_menu.getVariable());
        inputMap.add("Half size XY", half_size.getVariable());
        inputMap.add("Half size Z", half_sizeZ.getVariable());
        inputMap.add("Maximum slope", slope_menu.getVariable());
    }
    
    @Override
    public void declareOutput(VarList outputMap)
    {
        outputMap.add("Output", output.getVariable());
    }
    
    /**
     * Rounds a non-negative value to the nearest integer (halves round up). Because the cast
     * truncates toward zero, the result is wrong for negative inputs below -0.5.
     */
    static int roundPositive(float a)
    {
        return (int) (a + 0.5f);
    }
}

/**
 * Computes 3D CLAHE for one Z slice of an 8-bit stack. Each voxel of slice {@code z} is
 * equalized against the clipped histogram of the box spanning {@code blockRadius} pixels in X/Y
 * and {@code block_z} slices in Z around it (clamped to the stack bounds). The histogram is
 * updated incrementally while sweeping each row along X. Results are written into
 * {@code dst[z]} only.
 */
class Clahe3D_task implements Runnable
{
    int z, dim_x, dim_y, dim_z;
    int blockRadius;
    int block_z;
    int bins;
    double slope;
    byte[][] src;
    byte[][] dst;
    
    /**
     * @param src
     *            source stack, indexed {@code [z][y * dim_x + x]}
     * @param dst
     *            destination stack with the same layout; only {@code dst[z]} is written
     * @param blockRadius
     *            half size of the neighbourhood along X and Y
     * @param block_z
     *            half size of the neighbourhood along Z
     * @param bins
     *            highest histogram bin index (histograms have {@code bins + 1} entries)
     * @param slope
     *            maximum slope; the clip limit is {@code slope * n / bins} for a window of n
     *            voxels
     * @param z
     *            index of the slice to process
     */
    public Clahe3D_task(byte[][] src, byte[][] dst, int blockRadius, int block_z, int bins, double slope, int z, int dim_x, int dim_y, int dim_z)
    {
        this.z = z;
        this.dim_x = dim_x;
        this.dim_y = dim_y;
        this.dim_z = dim_z;
        this.blockRadius = blockRadius;
        this.block_z = block_z;
        this.bins = bins;
        this.slope = slope;
        this.src = src;
        this.dst = dst;
    }
    
    @Override
    public void run()
    {
        float binFactor = bins / 255f;
        
        // Z extent of the neighbourhood, clamped to the stack
        int zMin = Math.max(0, z - block_z);
        int zMax = Math.min(dim_z, z + block_z + 1);
        int d = zMax - zMin;
        
        byte[] slice = src[z];
        byte[] outSlice = dst[z];
        
        // Columns [0, xMax0) are pre-loaded into each row's histogram. For x = 0 the sweep below
        // then adds column blockRadius when it exists, completing the window [0, blockRadius].
        // Capping at dim_x (not dim_x - 1) keeps the last column when blockRadius >= dim_x.
        int xMax0 = Math.min(dim_x, blockRadius);
        
        for (int y = 0; y < dim_y; ++y)
        {
            int yMin = Math.max(0, y - blockRadius);
            int yMax = Math.min(dim_y, y + blockRadius + 1);
            int h = yMax - yMin;
            
            /* initially fill histogram */
            int[] hist = new int[bins + 1];
            int[] clippedHist = new int[bins + 1];
            for (int zi = zMin; zi < zMax; ++zi)
            {
                byte[] neighbourSlice = src[zi];
                for (int yi = yMin; yi < yMax; ++yi)
                {
                    int offset = yi * dim_x;
                    for (int xi = 0; xi < xMax0; ++xi, ++offset)
                        ++hist[Clahe.roundPositive((neighbourSlice[offset] & 0xff) * binFactor)];
                }
            }
            
            for (int x = 0; x < dim_x; ++x)
            {
                int v = Clahe.roundPositive((slice[y * dim_x + x] & 0xff) * binFactor);
                
                int xMin = Math.max(0, x - blockRadius);
                int xMax = x + blockRadius + 1;
                int w = Math.min(dim_x, xMax) - xMin;
                int n = h * w * d;
                
                int limit = (int) (slope * n / bins + 0.5f);
                
                /* remove left behind values from histogram */
                if (xMin > 0)
                {
                    int xMin1 = xMin - 1;
                    for (int zi = zMin; zi < zMax; ++zi)
                    {
                        byte[] neighbourSlice = src[zi];
                        for (int yi = yMin; yi < yMax; ++yi)
                            --hist[Clahe.roundPositive((neighbourSlice[yi * dim_x + xMin1] & 0xff) * binFactor)];
                    }
                }
                
                /* add newly included values to histogram */
                if (xMax <= dim_x)
                {
                    int xMax1 = xMax - 1;
                    for (int zi = zMin; zi < zMax; ++zi)
                    {
                        byte[] neighbourSlice = src[zi];
                        for (int yi = yMin; yi < yMax; ++yi)
                            ++hist[Clahe.roundPositive((neighbourSlice[yi * dim_x + xMax1] & 0xff) * binFactor)];
                    }
                }
                
                /* clip histogram and redistribute clipped entries */
                System.arraycopy(hist, 0, clippedHist, 0, hist.length);
                int clippedEntries = 0, clippedEntriesBefore;
                do
                {
                    clippedEntriesBefore = clippedEntries;
                    clippedEntries = 0;
                    for (int i = 0; i <= bins; ++i)
                    {
                        int excess = clippedHist[i] - limit;
                        if (excess > 0)
                        {
                            clippedEntries += excess;
                            clippedHist[i] = limit;
                        }
                    }
                    
                    // spread the clipped entries evenly, then the remainder at regular intervals
                    int share = clippedEntries / (bins + 1);
                    int remainder = clippedEntries % (bins + 1);
                    for (int i = 0; i <= bins; ++i)
                        clippedHist[i] += share;
                    
                    if (remainder != 0)
                    {
                        int step = bins / remainder;
                        for (int i = 0; i <= bins; i += step)
                            ++clippedHist[i];
                    }
                }
                while (clippedEntries != clippedEntriesBefore);
                
                /* build cdf of clipped histogram */
                // hMin: index of the first non-empty bin (stays at bins if all lower bins are empty)
                int hMin = bins;
                for (int i = 0; i < hMin; ++i)
                    if (clippedHist[i] != 0) hMin = i;
                
                int cdf = 0;
                for (int i = hMin; i <= v; ++i)
                    cdf += clippedHist[i];
                
                int cdfMax = cdf;
                for (int i = v + 1; i <= bins; ++i)
                    cdfMax += clippedHist[i];
                
                int cdfMin = clippedHist[hMin];
                
                // NOTE(review): if cdfMax == cdfMin the float division is by zero, and the output is
                // decided by the (int) cast of NaN/Infinity rather than by an explicit rule.
                outSlice[y * dim_x + x] = (byte) Clahe.roundPositive(255f * (cdf - cdfMin) / (cdfMax - cdfMin));
            }
        }
    }
}

/**
 * Computes 2D CLAHE for a single 8-bit XY plane stored row-major in {@code src}. Each pixel is
 * equalized against the clipped histogram of the window spanning {@code blockRadius} pixels around
 * it (clamped to the image). The histogram is updated incrementally while sweeping each row along
 * X. Results are written into {@code dst}.
 */
class Clahe2D_task implements Runnable
{
    int dim_x, dim_y;
    int blockRadius;
    int bins;
    float slope;
    byte[] src;
    byte[] dst;
    
    /**
     * @param src
     *            source plane, indexed {@code y * dim_x + x}
     * @param dst
     *            destination plane with the same layout
     * @param blockRadius
     *            half size of the neighbourhood along X and Y
     * @param bins
     *            highest histogram bin index (histograms have {@code bins + 1} entries)
     * @param slope
     *            maximum slope; the clip limit is {@code slope * n / bins} for a window of n
     *            pixels
     */
    public Clahe2D_task(byte[] src, byte[] dst, int blockRadius, int bins, float slope, int dim_x, int dim_y)
    {
        this.dim_x = dim_x;
        this.dim_y = dim_y;
        this.blockRadius = blockRadius;
        this.bins = bins;
        this.slope = slope;
        this.src = src;
        this.dst = dst;
    }
    
    @Override
    public void run()
    {
        float binFactor = bins / 255f;
        
        // Columns [0, xMax0) are pre-loaded into each row's histogram. For x = 0 the sweep below
        // then adds column blockRadius when it exists, completing the window [0, blockRadius].
        // Capping at dim_x (not dim_x - 1) keeps the last column when blockRadius >= dim_x.
        int xMax0 = Math.min(dim_x, blockRadius);
        
        for (int y = 0; y < dim_y; ++y)
        {
            int yMin = Math.max(0, y - blockRadius);
            int yMax = Math.min(dim_y, y + blockRadius + 1);
            int h = yMax - yMin;
            
            /* initially fill histogram */
            int[] hist = new int[bins + 1];
            int[] clippedHist = new int[bins + 1];
            
            for (int yi = yMin; yi < yMax; ++yi)
                for (int xi = 0, offset = yi * dim_x; xi < xMax0; ++xi, ++offset)
                    ++hist[Clahe.roundPositive((src[offset] & 0xff) * binFactor)];
            
            for (int x = 0, offset = y * dim_x; x < dim_x; ++x, ++offset)
            {
                int v = Clahe.roundPositive((src[offset] & 0xff) * binFactor);
                
                int xMin = Math.max(0, x - blockRadius);
                int xMax = x + blockRadius + 1;
                int w = Math.min(dim_x, xMax) - xMin;
                int n = h * w;
                
                int limit = Clahe.roundPositive(slope * n / bins);
                
                /* remove left behind values from histogram */
                if (xMin > 0)
                {
                    for (int yi = yMin, yOff = yi * dim_x + xMin - 1; yi < yMax; ++yi, yOff += dim_x)
                        --hist[Clahe.roundPositive((src[yOff] & 0xff) * binFactor)];
                }
                
                /* add newly included values to histogram */
                if (xMax <= dim_x)
                {
                    for (int yi = yMin, yOff = yi * dim_x + xMax - 1; yi < yMax; ++yi, yOff += dim_x)
                        ++hist[Clahe.roundPositive((src[yOff] & 0xff) * binFactor)];
                }
                
                /* clip histogram and redistribute clipped entries */
                System.arraycopy(hist, 0, clippedHist, 0, hist.length);
                int clippedEntries = 0, prevClippedEntries;
                do
                {
                    prevClippedEntries = clippedEntries;
                    clippedEntries = 0;
                    for (int i = 0; i <= bins; ++i)
                    {
                        int excess = clippedHist[i] - limit;
                        if (excess > 0)
                        {
                            clippedEntries += excess;
                            clippedHist[i] = limit;
                        }
                    }
                    
                    // spread the clipped entries evenly, then the remainder at regular intervals
                    int share = clippedEntries / (bins + 1);
                    int remainder = clippedEntries % (bins + 1);
                    for (int i = 0; i <= bins; ++i)
                        clippedHist[i] += share;
                    
                    if (remainder != 0)
                    {
                        int step = bins / remainder;
                        for (int i = 0; i <= bins; i += step)
                            ++clippedHist[i];
                    }
                }
                while (clippedEntries != prevClippedEntries);
                
                /* build cdf of clipped histogram */
                // hMin: index of the first non-empty bin (stays at bins if all lower bins are empty)
                int hMin = bins;
                for (int i = 0; i < hMin; ++i)
                    if (clippedHist[i] != 0) hMin = i;
                
                int cdf = 0;
                for (int i = hMin; i <= v; ++i)
                    cdf += clippedHist[i];
                
                int cdfMax = cdf;
                for (int i = v + 1; i <= bins; ++i)
                    cdfMax += clippedHist[i];
                
                int cdfMin = clippedHist[hMin];
                
                // NOTE(review): if cdfMax == cdfMin the float division is by zero, and the output is
                // decided by the (int) cast of NaN/Infinity rather than by an explicit rule.
                dst[offset] = (byte) Clahe.roundPositive(255f * (cdf - cdfMin) / (cdfMax - cdfMin));
            }
        }
    }
}
