package plugins.adufour.extrema;

import icy.image.IcyBufferedImage;
import icy.sequence.Sequence;
import icy.type.DataType;
import icy.type.collection.array.Array1DUtil;
import java.util.concurrent.atomic.AtomicInteger;
import plugins.adufour.blocks.lang.Block;
import plugins.adufour.blocks.util.VarList;
import plugins.adufour.ezplug.EzPlug;
import plugins.adufour.ezplug.EzVar;
import plugins.adufour.ezplug.EzVarDouble;
import plugins.adufour.ezplug.EzVarEnum;
import plugins.adufour.ezplug.EzVarInteger;
import plugins.adufour.ezplug.EzVarListener;
import plugins.adufour.ezplug.EzVarSequence;
import plugins.adufour.vars.lang.VarSequence;

public class LocalExtrema extends EzPlug implements Block
{
    /**
     * Kind of extrema to detect.
     */
    public enum Extrema
    {
        MINIMA, MAXIMA, ALL
    }
    
    /**
     * Counter appended to the name of each output sequence. An {@link AtomicInteger} is used
     * because incrementing a {@code volatile int} is not atomic if several instances execute
     * concurrently.
     */
    private static final AtomicInteger cpt       = new AtomicInteger(1);
    
    private EzVarSequence       input     = new EzVarSequence("Input sequence");
    
    private EzVarEnum<Extrema>  type      = new EzVarEnum<LocalExtrema.Extrema>("Extrema to detect", Extrema.values(), Extrema.MAXIMA);
    
    private EzVarInteger        sizeX     = new EzVarInteger("Search radius (X)", 1, Short.MAX_VALUE, 1);
    private EzVarInteger        sizeY     = new EzVarInteger("Search radius (Y)", 1, Short.MAX_VALUE, 1);
    private EzVarInteger        sizeZ     = new EzVarInteger("Search radius (Z)", 1, Short.MAX_VALUE, 1);
    
    private EzVarDouble         minThreshold = new EzVarDouble("Minima threshold");
    private EzVarDouble         maxThreshold = new EzVarDouble("Maxima threshold");
    
    private VarSequence         output    = new VarSequence("Binary output sequence", null);
    
    @Override
    public void initialize()
    {
        addEzComponent(input);
        addEzComponent(type);
        addEzComponent(sizeX);
        addEzComponent(sizeY);
        addEzComponent(sizeZ);
        
        // the Z radius is only meaningful for stacks with more than one slice
        input.addVarChangeListener(new EzVarListener<Sequence>()
        {
            @Override
            public void variableChanged(EzVar<Sequence> source, Sequence newValue)
            {
                sizeZ.setVisible(newValue != null && newValue.getSizeZ() > 1);
            }
        });
        
        addEzComponent(minThreshold);
        addEzComponent(maxThreshold);
    }
    
    @Override
    protected void execute()
    {
        Sequence in = input.getValue(true);
        Sequence out = null;
        
        switch (type.getValue())
        {
            case MINIMA:
                out = new Sequence(in.getName() + "_LocalMinima#" + cpt.getAndIncrement());
                localMinima(in, out, minThreshold.getValue(), sizeX.getValue(), sizeY.getValue(), sizeZ.getValue());
            break;
            case MAXIMA:
                out = new Sequence(in.getName() + "_LocalMaxima#" + cpt.getAndIncrement());
                localMaxima(in, out, maxThreshold.getValue(), sizeX.getValue(), sizeY.getValue(), sizeZ.getValue());
            break;
            case ALL:
                out = new Sequence(in.getName() + "_LocalExtrema#" + cpt.getAndIncrement());
                localExtrema(in, out, minThreshold.getValue(), maxThreshold.getValue(), sizeX.getValue(), sizeY.getValue(), sizeZ.getValue());
            break;
        }
        output.setValue(out);
        
        if (!isHeadLess()) addSequence(out);
    }
    
    /**
     * Computes the local minima and maxima of every channel of every (T, Z) plane of
     * {@code in}. The result is written to {@code out} as a ternary UBYTE mask:
     * <ul>
     * <li>255 for a local maximum,</li>
     * <li>0 for a local minimum,</li>
     * <li>127 for any other pixel.</li>
     * </ul>
     * <p>
     * A pixel is a local maximum if its value is strictly above {@code maxThreshold} and no pixel
     * in its neighborhood has a strictly higher value. Local minima are defined symmetrically
     * with {@code minThreshold}. The neighborhood is the box of half-size
     * (halfSizeX, halfSizeY, halfSizeZ) around the pixel, clipped to the image bounds. A pixel
     * satisfying both conditions, which is only possible when
     * {@code minThreshold > maxThreshold}, is reported as a maximum. NaN pixels are never
     * reported as extrema.
     * 
     * @param in
     *            the sequence to analyze
     * @param out
     *            the sequence receiving one UBYTE image per (T, Z) position of {@code in}
     * @param minThreshold
     *            only pixels strictly below this value can be local minima
     * @param maxThreshold
     *            only pixels strictly above this value can be local maxima
     * @param halfSizeX
     *            neighborhood half-size along X
     * @param halfSizeY
     *            neighborhood half-size along Y
     * @param halfSizeZ
     *            neighborhood half-size along Z
     */
    public void localExtrema(Sequence in, Sequence out, double minThreshold, double maxThreshold, int halfSizeX, int halfSizeY, int halfSizeZ)
    {
        detectExtrema(in, out, Extrema.ALL, minThreshold, maxThreshold, halfSizeX, halfSizeY, halfSizeZ);
    }
    
    /**
     * Computes the local maxima of every channel of every (T, Z) plane of {@code in}. The result
     * is written to {@code out} as a binary UBYTE mask: 255 for a local maximum, 0 otherwise.
     * <p>
     * A pixel is a local maximum if its value is strictly above {@code threshold} and no pixel in
     * its neighborhood has a strictly higher value. The neighborhood is the box of half-size
     * (halfSizeX, halfSizeY, halfSizeZ) around the pixel, clipped to the image bounds. Plateaus
     * therefore count as maxima. NaN pixels are never reported.
     * 
     * @param in
     *            the sequence to analyze
     * @param out
     *            the sequence receiving one UBYTE image per (T, Z) position of {@code in}
     * @param threshold
     *            only pixels strictly above this value can be local maxima
     * @param halfSizeX
     *            neighborhood half-size along X
     * @param halfSizeY
     *            neighborhood half-size along Y
     * @param halfSizeZ
     *            neighborhood half-size along Z
     */
    public void localMaxima(Sequence in, Sequence out, double threshold, int halfSizeX, int halfSizeY, int halfSizeZ)
    {
        // the minima threshold is unused in MAXIMA mode
        detectExtrema(in, out, Extrema.MAXIMA, Double.NaN, threshold, halfSizeX, halfSizeY, halfSizeZ);
    }
    
    /**
     * Computes the local minima of every channel of every (T, Z) plane of {@code in}. The result
     * is written to {@code out} as a binary UBYTE mask: 255 for a local minimum, 0 otherwise.
     * <p>
     * A pixel is a local minimum if its value is strictly below {@code threshold} and no pixel in
     * its neighborhood has a strictly lower value. The neighborhood is the box of half-size
     * (halfSizeX, halfSizeY, halfSizeZ) around the pixel, clipped to the image bounds. Plateaus
     * therefore count as minima. NaN pixels are never reported.
     * 
     * @param in
     *            the sequence to analyze
     * @param out
     *            the sequence receiving one UBYTE image per (T, Z) position of {@code in}
     * @param threshold
     *            only pixels strictly below this value can be local minima
     * @param halfSizeX
     *            neighborhood half-size along X
     * @param halfSizeY
     *            neighborhood half-size along Y
     * @param halfSizeZ
     *            neighborhood half-size along Z
     */
    public void localMinima(Sequence in, Sequence out, double threshold, int halfSizeX, int halfSizeY, int halfSizeZ)
    {
        // the maxima threshold is unused in MINIMA mode
        detectExtrema(in, out, Extrema.MINIMA, threshold, Double.NaN, halfSizeX, halfSizeY, halfSizeZ);
    }
    
    /**
     * Shared implementation of {@link #localMinima}, {@link #localMaxima} and
     * {@link #localExtrema}. For each (T, Z) position of {@code in} it builds one UBYTE image and
     * stores it in {@code out}.
     * 
     * @param mode
     *            which extrema to search for; also selects the output encoding
     *            (binary 255/0 for MINIMA and MAXIMA, ternary 255/0/127 for ALL)
     */
    private static void detectExtrema(Sequence in, Sequence out, Extrema mode, double minThreshold, double maxThreshold, int halfSizeX, int halfSizeY, int halfSizeZ)
    {
        final int w = in.getSizeX();
        final int h = in.getSizeY();
        final int d = in.getSizeZ();
        final int c = in.getSizeC();
        final int t = in.getSizeT();
        
        final DataType dataType = in.getDataType_();
        
        final boolean findMinima = mode != Extrema.MAXIMA;
        final boolean findMaxima = mode != Extrema.MINIMA;
        
        // output codes: binary mask for a single kind of extremum, ternary mask for ALL
        final byte maxCode = (byte) 0xff;
        final byte minCode = mode == Extrema.ALL ? 0 : (byte) 0xff;
        final byte otherCode = mode == Extrema.ALL ? (byte) 127 : 0;
        
        // a radius larger than the image is equivalent to the image size; clamping the upper
        // bound keeps the (index + radius) computations below from overflowing
        final int rx = Math.min(halfSizeX, w);
        final int ry = Math.min(halfSizeY, h);
        final int rz = Math.min(halfSizeZ, d);
        
        for (int time = 0; time < t; time++)
        {
            for (int slice = 0; slice < d; slice++)
            {
                IcyBufferedImage inSlice = in.getImage(time, slice);
                IcyBufferedImage outSlice = new IcyBufferedImage(w, h, c, DataType.UBYTE);
                
                // neighbouring slices, clipped to the stack (empty if the radius is negative)
                int zMin = Math.max(0, slice - rz);
                int zMax = Math.min(d - 1, slice + rz);
                
                for (int channel = 0; channel < c; channel++)
                {
                    // fetch the neighbouring planes once per channel rather than once per pixel
                    Object[] neighborPlanes = new Object[Math.max(0, zMax - zMin + 1)];
                    for (int z = zMin; z <= zMax; z++)
                    {
                        neighborPlanes[z - zMin] = in.getDataXY(time, z, channel);
                    }
                    
                    Object centerPlane = inSlice.getDataXY(channel);
                    byte[] outputPlane = outSlice.getDataXYAsByte(channel);
                    
                    for (int j = 0, xyOffset = 0; j < h; j++)
                    {
                        int yMin = Math.max(0, j - ry);
                        int yMax = Math.min(h - 1, j + ry);
                        
                        for (int i = 0; i < w; i++, xyOffset++)
                        {
                            double value = Array1DUtil.getValue(centerPlane, xyOffset, dataType);
                            
                            // strict comparisons: a NaN value never passes either threshold
                            boolean isMin = findMinima && value < minThreshold;
                            boolean isMax = findMaxima && value > maxThreshold;
                            
                            if (isMin || isMax)
                            {
                                int xMin = Math.max(0, i - rx);
                                int xMax = Math.min(w - 1, i + rx);
                                
                                // a strictly higher neighbour rules out a maximum, a strictly
                                // lower one rules out a minimum; stop as soon as both are ruled out
                                neighborhood: for (Object plane : neighborPlanes)
                                {
                                    for (int y = yMin; y <= yMax; y++)
                                    {
                                        int rowOffset = y * w;
                                        
                                        for (int x = xMin; x <= xMax; x++)
                                        {
                                            double neighbor = Array1DUtil.getValue(plane, rowOffset + x, dataType);
                                            
                                            if (value < neighbor) isMax = false;
                                            if (value > neighbor) isMin = false;
                                            
                                            if (!isMin && !isMax) break neighborhood;
                                        }
                                    }
                                }
                            }
                            
                            outputPlane[xyOffset] = isMax ? maxCode : isMin ? minCode : otherCode;
                        }
                    }
                }
                
                out.setImage(time, slice, outSlice);
            }
        }
    }
    
    @Override
    public void clean()
    {
        // nothing to release
    }
    
    @Override
    public void declareInput(VarList inputMap)
    {
        inputMap.add(input.getVariable());
        inputMap.add(type.getVariable());
        inputMap.add(sizeX.getVariable());
        inputMap.add(sizeY.getVariable());
        inputMap.add(sizeZ.getVariable());
        inputMap.add(minThreshold.getVariable());
        inputMap.add(maxThreshold.getVariable());
    }
    
    @Override
    public void declareOutput(VarList outputMap)
    {
        outputMap.add(output);
    }
}
