package plugins.adufour.activemeshes.energy;

import icy.image.IcyBufferedImage;
import icy.sequence.Sequence;
import icy.type.DataType;

import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.concurrent.ExecutorService;

import javax.vecmath.Point3d;
import javax.vecmath.Vector3d;

import plugins.adufour.activemeshes.mesh.Mesh;
import plugins.adufour.activemeshes.mesh.Vertex;
import plugins.adufour.vars.lang.Var;

/**
 * Region-based data attachment term: every vertex of a registered mesh is pushed along its normal
 * according to how the image value sampled at the vertex compares with the mean intensity inside
 * the mesh and with the mean intensity outside of all meshes.
 * <p>
 * The means are not refreshed automatically; callers must invoke {@link #updateMeans(boolean)}
 * (after {@link #setImageData(Sequence, int, int, boolean, boolean)}) whenever the meshes or the
 * image data have changed.
 */
public class ChanVeseMumfordShahTerm extends DataAttachmentTerm
{
    /** Mean intensity inside each registered mesh (iteration order = registration order). */
    private final LinkedHashMap<Mesh, Double> means_in        = new LinkedHashMap<Mesh, Double>();
    
    /** Mean background intensity around each mesh, filled by the "local" update only. */
    private final LinkedHashMap<Mesh, Double> means_out       = new LinkedHashMap<Mesh, Double>();
    
    /** Mean background intensity over the whole volume, filled by the "global" update only. */
    private double                            mean_out        = 0.0;
    
    /** <code>true</code> if the last call to {@link #updateMeans(boolean)} was a local update. */
    private boolean                           useLocalMeans   = false;
    
    /** Label volume: 0 = background, any other value = identifier written for a mesh. */
    private final Sequence                    shortMask       = new Sequence();
    
    /** Direct [z][xy] view on the data of {@link #shortMask}. */
    private short[][]                         shortMask_Z_XY;
    
    /** Margin (in voxels) added around a mesh bounding box for the local background mean. */
    private static final int                  LOCAL_NEIGHBORHOOD_RADIUS = 20;
    
    /**
     * @param service
     *            the executor service forwarded to the parent term
     * @param weight
     *            the weight applied to the forces produced by this term
     */
    public ChanVeseMumfordShahTerm(ExecutorService service, Var<Double> weight)
    {
        super(service, weight);
    }
    
    /**
     * Registers a mesh with this term and attaches to it a model computing the region-based forces.
     * The inner mean of the mesh is initialized to 1 until the next call to
     * {@link #updateMeans(boolean)}.
     * 
     * @param mesh
     *            the mesh to register
     */
    public synchronized void registerMesh(Mesh mesh)
    {
        means_in.put(mesh, 1.0);
        
        mesh.addModel(new Model(mesh)
        {
            @Override
            public void computeForces()
            {
                // the mesh may have been unregistered in the meantime: nothing to compute then
                Double registeredMeanIn = means_in.get(mesh);
                if (registeredMeanIn == null) return;
                
                double meanIn = registeredMeanIn.doubleValue();
                
                // use the per-mesh background mean when the last update was local and produced
                // one for this mesh, and fall back to the global background mean otherwise
                Double localMeanOut = useLocalMeans ? means_out.get(mesh) : null;
                double meanOut = (localMeanOut != null) ? localMeanOut.doubleValue() : mean_out;
                
                // sensitivity is dependent on the data:
                // low mean_in => dim object => high sensitivity
                double sensDenominator = Math.max(meanOut * 2, meanIn);
                
                // a null (or undefined) denominator would yield infinite / NaN forces
                if (sensDenominator == 0 || Double.isNaN(sensDenominator)) return;
                
                double sens = 1 / sensDenominator;
                double w = weight.getValue();
                
                Vector3d cvms = new Vector3d();
                
                for (Vertex v : mesh.vertices)
                {
                    if (v == null) continue;
                    
                    double value = sampler.getPixelValue(v.position.x / imResolution.x, v.position.y / imResolution.y, v.position.z / imResolution.z, true);
                    
                    double inDiff = value - meanIn;
                    double outDiff = value - meanOut;
                    double inDiff2 = inDiff * inDiff;
                    double outDiff2 = outDiff * outDiff;
                    
                    // positive factor pushes along the normal, negative factor against it
                    cvms.scale(mesh.getResolution() * w * (sens * outDiff2 - inDiff2 / sens), v.normal);
                    
                    v.forces.add(cvms);
                }
            }
            
            @Override
            public void removeMesh(Mesh mesh)
            {
                unregisterMesh(mesh);
            }
        });
    }
    
    /**
     * Sets the image data used by this term and (re)initializes the label mask so that it matches
     * the dimensions of the given sequence.
     * 
     * @param sequence
     *            the sequence providing the image data
     * @param t
     *            forwarded to the parent implementation
     * @param c
     *            forwarded to the parent implementation
     * @param isNormalized
     *            set to true if the input data is already normalized with double values between 0
     *            and 1, false otherwise.
     * @param smartRescale
     *            forwarded to the parent implementation
     */
    public void setImageData(Sequence sequence, int t, int c, boolean isNormalized, boolean smartRescale)
    {
        super.setImageData(sequence, t, c, isNormalized, smartRescale);
        
        int w = sequence.getSizeX();
        int h = sequence.getSizeY();
        int d = sequence.getSizeZ();
        
        boolean sameGeometry = shortMask.getNumImage() == d && shortMask.getSizeX() == w && shortMask.getSizeY() == h;
        
        if (sameGeometry)
        {
            // the existing mask can be reused: just erase it
            for (int z = 0; z < d; z++)
                Arrays.fill(shortMask.getDataXYAsShort(0, z, 0), (short) 0);
        }
        else
        {
            // first pass, or the volume dimensions have changed: rebuild the mask from scratch
            // (stale slices must be dropped, otherwise the mask and the data would disagree)
            shortMask.removeAllImages();
            
            for (int z = 0; z < d; z++)
                shortMask.setImage(0, z, new IcyBufferedImage(w, h, 1, DataType.USHORT));
        }
        
        shortMask_Z_XY = shortMask.getDataXYZAsShort(0, 0);
    }
    
    /**
     * Recomputes the mean intensity inside every registered mesh, as well as the background mean.
     * 
     * @param local
     *            <code>true</code> to compute one background mean per mesh, restricted to the
     *            neighborhood of the mesh bounding box, <code>false</code> to compute a single
     *            background mean over the whole volume
     */
    public final void updateMeans(boolean local)
    {
        if (local)
        {
            updateMeans_local();
        }
        else
        {
            updateMeans_global();
        }
        
        // remember which background mean(s) are up to date for the force computation
        useLocalMeans = local;
    }
    
    /**
     * Erases the label mask, then recomputes the inner mean of every registered mesh. Each mesh is
     * handed the mask together with its own identifier (1, 2, 3... in registration order).
     */
    private void updateInnerMeans()
    {
        for (short[] maskSlice : shortMask_Z_XY)
            Arrays.fill(maskSlice, (short) 0);
        
        short cpt = 1;
        
        // replacing the value of an existing key is not a structural modification of the map,
        // hence it is safe to do while iterating on its key set
        for (final Mesh mesh : means_in.keySet())
        {
            final short id = cpt++;
            
            means_in.put(mesh, mesh.computeIntensity(sampler, shortMask_Z_XY, id, imResolution, multiThreadService));
        }
    }
    
    /**
     * Updates the inner means, then computes a single background mean over all the voxels that
     * are not labeled in the mask.
     */
    private final void updateMeans_global()
    {
        updateInnerMeans();
        
        // shortMask is now filled, mean_out can be computed
        
        double[][] data = sampler.getData();
        
        double outSum = 0;
        int outCpt = 0;
        
        for (int z = 0; z < shortMask_Z_XY.length; z++)
        {
            double[] dataSlice = data[z];
            short[] maskSlice = shortMask_Z_XY[z];
            
            for (int offset = 0; offset < dataSlice.length; offset++)
                if (maskSlice[offset] == 0)
                {
                    outSum += dataSlice[offset];
                    outCpt++;
                }
        }
        
        // no background voxel at all: avoid 0/0 = NaN (same convention as the local update)
        mean_out = (outCpt == 0) ? 0 : outSum / outCpt;
    }
    
    /**
     * Updates the inner means, then computes for every mesh a background mean restricted to the
     * bounding box of the mesh enlarged by {@link #LOCAL_NEIGHBORHOOD_RADIUS} voxels (and clamped
     * to the volume).
     */
    private final void updateMeans_local()
    {
        updateInnerMeans();
        
        // shortMask is now filled, mean_out can be computed for each mesh "locally"
        
        double[][] data = sampler.getData();
        
        for (final Mesh mesh : means_in.keySet())
        {
            Point3d boxMin = new Point3d(), boxMax = new Point3d();
            mesh.getBoundingBox(boxMin, boxMax);
            
            final int minX = Math.max(0, (int) Math.floor(boxMin.x / imResolution.x) - LOCAL_NEIGHBORHOOD_RADIUS);
            final int minY = Math.max(0, (int) Math.floor(boxMin.y / imResolution.y) - LOCAL_NEIGHBORHOOD_RADIUS);
            final int minZ = Math.max(0, (int) Math.floor(boxMin.z / imResolution.z) - LOCAL_NEIGHBORHOOD_RADIUS);
            
            final int maxX = Math.min(sampler.dimensions.x - 1, (int) Math.ceil(boxMax.x / imResolution.x) + LOCAL_NEIGHBORHOOD_RADIUS);
            final int maxY = Math.min(sampler.dimensions.y - 1, (int) Math.ceil(boxMax.y / imResolution.y) + LOCAL_NEIGHBORHOOD_RADIUS);
            final int maxZ = Math.min(sampler.dimensions.z - 1, (int) Math.ceil(boxMax.z / imResolution.z) + LOCAL_NEIGHBORHOOD_RADIUS);
            
            double outSum = 0, outCpt = 0;
            
            for (int k = minZ; k <= maxZ; k++)
            {
                double[] dataSlice = data[k];
                short[] maskSlice = shortMask_Z_XY[k];
                
                for (int j = minY; j <= maxY; j++)
                {
                    // offset of voxel (minX, j) in the slice
                    int xyOffset = j * sampler.dimensions.x + minX;
                    
                    for (int i = minX; i <= maxX; i++, xyOffset++)
                    {
                        if (maskSlice[xyOffset] == 0)
                        {
                            outCpt++;
                            outSum += dataSlice[xyOffset];
                        }
                    }
                }
            }
            
            means_out.put(mesh, outCpt == 0 ? 0 : outSum / outCpt);
        }
    }
    
    /**
     * @return the sequence holding the label mask (the internal instance, not a copy)
     */
    public Sequence getBinaryVolume()
    {
        return shortMask;
    }
    
    @Override
    public void unregisterMesh(Mesh mesh)
    {
        means_in.remove(mesh);
        // also drop the local background mean, otherwise the mesh would stay referenced
        means_out.remove(mesh);
    }
    
    @Override
    public void unregisterMeshes()
    {
        means_in.clear();
        means_out.clear();
    }
}
