package plugins.adufour.opencl4icy;

import icy.system.IcyHandledException;

import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.FloatBuffer;

import plugins.adufour.ezplug.EzException;
import plugins.adufour.ezplug.EzLabel;
import plugins.adufour.ezplug.EzPlug;
import plugins.adufour.ezplug.EzVarEnum;
import plugins.adufour.ezplug.EzVarInteger;

import com.nativelibs4java.opencl.CLBuildException;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLException;
import com.nativelibs4java.opencl.CLFloatBuffer;
import com.nativelibs4java.opencl.CLKernel;
import com.nativelibs4java.opencl.CLMem.MapFlags;
import com.nativelibs4java.opencl.CLMem.Usage;
import com.nativelibs4java.opencl.CLProgram;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.JavaCL;
import com.ochafik.io.ReadText;

/**
 * EzPlug demonstration plug-in running OpenCL code through JavaCL.
 * <p>
 * On initialization it creates an OpenCL context (via {@code JavaCL.createBestContext()}), a default
 * command queue, and builds the program read from the {@value #PROGRAM_RESOURCE} resource located
 * next to this class. On execution it instantiates the kernel selected in the interface and runs it
 * on generated sample data, printing the first values and the elapsed time to the standard output.
 */
public class OpenCL_Lab extends EzPlug
{
    /** Name of the class-path resource (relative to this class) holding the OpenCL source code. */
    private static final String PROGRAM_RESOURCE = "CLfunctions.cl";
    
    /** Number of leading values printed for each array. */
    private static final int    PRINT_COUNT      = 10;
    
    private CLContext    context;
    private CLQueue      queue;
    private CLProgram    program;
    /** Kernel of the latest execution; released before a new one is created and in {@link #clean()}. */
    private CLKernel     kernel;
    
    private EzVarInteger arraySize;
    
    /**
     * OpenCL functions selectable in the interface. Each constant's {@code name()} is used as the
     * kernel name looked up in the built program, so it must match a kernel in
     * {@value #PROGRAM_RESOURCE}.
     */
    private enum CLFUNCTIONS
    {
        multiply2arrays,
    }
    
    /** True only once the context, queue and program have all been created successfully. */
    private boolean                      runnable = false;
    
    private final EzVarEnum<CLFUNCTIONS> function = new EzVarEnum<OpenCL_Lab.CLFUNCTIONS>("Run function", CLFUNCTIONS.values());
    
    @Override
    protected void initialize()
    {
        String output = initCL();
        
        addEzComponent(new EzLabel(output));
        addEzComponent(arraySize = new EzVarInteger("Array size", 100, 1000000, 100));
        addEzComponent(function);
    }
    
    /**
     * Creates the OpenCL context, command queue and program, and reports the outcome.
     * 
     * @return a human-readable status message displayed in the plug-in interface
     * @throws IcyHandledException
     *             if the JavaCL classes themselves cannot be loaded
     */
    private String initCL()
    {
        String output;
        
        try
        {
            context = JavaCL.createBestContext();
            queue = context.createDefaultQueue();
            program = context.createProgram(loadProgramSource()).build();
            
            output = "found OpenCL drivers v. " + context.getDevices()[0].getOpenCLVersion();
            // flag the plug-in as runnable only once every step above has succeeded
            runnable = true;
        }
        catch (IOException e)
        {
            output = "Error (OpenCL lab): unable to load the OpenCL code.";
            e.printStackTrace();
        }
        catch (CLException e)
        {
            output = "Error (OpenCL lab): unable to create the OpenCL context.";
            e.printStackTrace();
        }
        catch (CLBuildException e)
        {
            // the context exists at this point: it is the program compilation that failed
            output = "Error (OpenCL lab): unable to build the OpenCL program.";
            e.printStackTrace();
        }
        catch (UnsatisfiedLinkError linkError)
        {
            output = "Error (OpenCL lab): OpenCL drivers not found.";
        }
        catch (NoClassDefFoundError e)
        {
            throw new IcyHandledException("Error: couldn't load the OpenCL drivers.\n(note: on Microsoft Windows, the drivers can only be loaded once)");
        }
        
        return output;
    }
    
    /**
     * Reads the OpenCL source code bundled with this class.
     * 
     * @return the textual content of {@value #PROGRAM_RESOURCE}
     * @throws IOException
     *             if the resource is missing or cannot be read
     */
    private static String loadProgramSource() throws IOException
    {
        InputStream in = OpenCL_Lab.class.getResourceAsStream(PROGRAM_RESOURCE);
        // getResourceAsStream reports a missing resource with null rather than an exception
        if (in == null) throw new IOException("Resource not found: " + PROGRAM_RESOURCE);
        
        try
        {
            return ReadText.readText(in);
        }
        finally
        {
            in.close();
        }
    }
    
    /**
     * Creates the kernel selected in the interface and runs the matching demonstration.
     * 
     * @throws IcyHandledException
     *             if OpenCL was not initialized successfully
     * @throws EzException
     *             if the selected kernel cannot be created from the program
     */
    @Override
    protected void execute()
    {
        if (!runnable) throw new IcyHandledException("Cannot run the plug-in. Probably because OpenCL was not found or not initialized correctly");
        
        CLFUNCTIONS selected = function.getValue();
        String funcName = selected.name();
        
        // free the kernel of the previous run, otherwise each execution leaks one
        if (kernel != null)
        {
            kernel.release();
            kernel = null;
        }
        
        try
        {
            kernel = program.createKernel(funcName);
        }
        catch (CLBuildException e)
        {
            throw new EzException("Unable to load OpenCL function \"" + funcName + "\":\n" + e.getMessage(), true);
        }
        
        switch (selected)
        {
            case multiply2arrays:
                multiply2arrays();
                break;
            
            default:
                break;
        }
    }
    
    /**
     * Runs the {@code multiply2arrays} kernel on two generated arrays ({@code a[i] = i},
     * {@code b[i] = i + 1}) and prints the first values of the inputs and of the result, together
     * with the elapsed time (buffer allocation, transfers and kernel run).
     * <p>
     * The kernel receives (a, b, output) as arguments with a 1D global size equal to the array size.
     * NOTE(review): the output is labelled "a*b", presumably an element-wise product computed by the
     * kernel in {@value #PROGRAM_RESOURCE} - confirm against the .cl file.
     */
    private void multiply2arrays()
    {
        final int size = arraySize.getValue();
        
        float[] a = new float[size];
        float[] b = new float[size];
        float[] ab = new float[size];
        
        // fill a and b with some values
        for (int i = 0; i < size; i++)
        {
            a[i] = i;
            b[i] = i + 1;
        }
        
        CLFloatBuffer cl_inBuffer_a = null;
        CLFloatBuffer cl_inBuffer_b = null;
        CLFloatBuffer cl_outBuffer = null;
        
        try
        {
            long start = System.nanoTime();
            
            // input arguments are copied into pre-allocated CL buffers through a memory mapping
            cl_inBuffer_a = context.createFloatBuffer(Usage.Input, size);
            cl_inBuffer_b = context.createFloatBuffer(Usage.Input, size);
            writeArrayToBuffer(cl_inBuffer_a, a);
            writeArrayToBuffer(cl_inBuffer_b, b);
            
            // the output uses a "direct" float buffer shared with the device without copy
            // NOTE: using this technique with copy=true for input parameters is less optimal than
            // the mapping version above
            FloatBuffer outBuffer = ByteBuffer.allocateDirect(size * (Float.SIZE / Byte.SIZE)).order(context.getByteOrder()).asFloatBuffer();
            cl_outBuffer = context.createFloatBuffer(Usage.Output, outBuffer, false);
            
            // send the parameters to the kernel and run it
            kernel.setArgs(cl_inBuffer_a, cl_inBuffer_b, cl_outBuffer);
            kernel.enqueueNDRange(queue, new int[] { size });
            
            // blocking read of the result, then copy it into the output array
            cl_outBuffer.read(queue, outBuffer, true);
            outBuffer.get(ab);
            // rewind the output buffer (not necessary here, but ensures clean code)
            outBuffer.rewind();
            
            long end = System.nanoTime();
            
            System.out.print("First values of a:  ");
            print(a, PRINT_COUNT);
            System.out.print("First values of b:  ");
            print(b, PRINT_COUNT);
            System.out.print("First values of a*b: ");
            print(ab, PRINT_COUNT);
            System.out.println("Computation time (OpenCL): " + (end - start) / 1000000 + " milliseconds");
        }
        finally
        {
            // device buffers are not reclaimed automatically: release them even on failure
            releaseBuffer(cl_outBuffer);
            releaseBuffer(cl_inBuffer_b);
            releaseBuffer(cl_inBuffer_a);
        }
    }
    
    /**
     * Copies a Java array into an OpenCL buffer by mapping the buffer for writing, filling it and
     * unmapping it.
     */
    private void writeArrayToBuffer(CLFloatBuffer target, float[] data)
    {
        FloatBuffer mapped = target.map(queue, MapFlags.Write);
        mapped.put(data);
        // rewind the buffer (needed on some drivers)
        mapped.rewind();
        target.unmap(queue, mapped);
    }
    
    /** Releases an OpenCL buffer if it has been allocated ({@code null} is ignored). */
    private static void releaseBuffer(CLFloatBuffer buffer)
    {
        if (buffer != null) buffer.release();
    }
    
    /**
     * Prints, between brackets, at most {@code i} leading values of {@code a} on the standard output.
     * The count is bounded by the array length.
     */
    private void print(float[] a, int i)
    {
        int count = Math.min(i, a.length);
        StringBuilder line = new StringBuilder("[ ");
        for (int j = 0; j < count; j++)
            line.append(a[j]).append(' ');
        line.append(" ]");
        System.out.println(line);
    }
    
    /** Releases every OpenCL resource held by the plug-in, most dependent first. */
    @Override
    public void clean()
    {
        runnable = false;
        if (kernel != null)
        {
            kernel.release();
            kernel = null;
        }
        if (program != null)
        {
            program.release();
            program = null;
        }
        if (queue != null)
        {
            queue.release();
            queue = null;
        }
        if (context != null)
        {
            context.release();
            context = null;
        }
    }
    
}
