/*
 * Decompiled with CFR 0.152.
 */
package com.nativelibs4java.opencl.util;

import com.nativelibs4java.opencl.CLBuildException;
import com.nativelibs4java.opencl.CLContext;
import com.nativelibs4java.opencl.CLDevice;
import com.nativelibs4java.opencl.CLKernel;
import com.nativelibs4java.opencl.CLProgram;
import com.nativelibs4java.opencl.CLQueue;
import com.nativelibs4java.opencl.JavaCL;
import com.ochafik.util.listenable.Pair;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

/**
 * Generates, builds and caches simple element-wise OpenCL kernels for unary
 * ({@link Fun1}) and binary ({@link Fun2}) functions over OpenCL scalar and
 * vector element types ({@link Primitive}).
 *
 * <p>Kernel source is generated on first request, built in the context of this
 * instance's queue, and cached per (function, type) combination. The
 * {@code getKernel} methods are {@code synchronized}; the {@link CLKernel}
 * objects they return are cached and therefore shared between callers.
 */
public class ParallelMath {
    /** Context of {@link #queue}; assigned by the constructor. */
    protected CLContext context;
    /** Queue this instance was created with. */
    protected CLQueue queue;
    /** Cache of unary kernels, keyed by function, then element type. */
    private final EnumMap<Fun1, EnumMap<Primitive, Fun1Kernels>> fun1Kernels =
            new EnumMap<Fun1, EnumMap<Primitive, Fun1Kernels>>(Fun1.class);
    /** Cache of binary kernels, keyed by function, then (in1, in2, out) element types. */
    private final EnumMap<Fun2, Map<PrimitiveTrio, CLKernel>> fun2Kernels =
            new EnumMap<Fun2, Map<PrimitiveTrio, CLKernel>>(Fun2.class);

    /**
     * Creates an instance bound to a default queue of the context returned by
     * {@code JavaCL.createBestContext()}.
     */
    public ParallelMath() {
        this(JavaCL.createBestContext().createDefaultQueue(new CLDevice.QueueProperties[0]));
    }

    /**
     * Creates an instance that builds its kernels in the context of {@code queue}.
     *
     * @param queue queue whose context is used to build kernels
     */
    public ParallelMath(CLQueue queue) {
        this.queue = queue;
        // Assign the field (previously a shadowing local), so subclasses reading
        // the protected 'context' field do not see null.
        this.context = queue.getContext();
    }

    /** @return the queue this instance was created with */
    public CLQueue getQueue() {
        return this.queue;
    }

    /** @return the context of {@link #getQueue()} */
    public CLContext getContext() {
        return this.getQueue().getContext();
    }

    /**
     * Appends the OpenCL C source of an element-wise unary kernel to {@code out}.
     * For {@code i = get_global_id(0)}, the kernel computes
     * {@code out[i] = function(in[i])}, or {@code out[i] = function(out[i])}
     * when {@code inPlace} is set.
     *
     * @param function unary function to apply
     * @param type     element type of the buffers
     * @param out      builder that receives the kernel source
     * @param inPlace  whether the kernel reads from and writes to the single {@code out} buffer
     * @return the name of the generated kernel
     */
    protected String createVectFun1Source(Fun1 function, Primitive type, StringBuilder out, boolean inPlace) {
        String t = type.type();
        String kernelName = "vect_" + function.name() + "_" + t + (inPlace ? "_inplace" : "");
        out.append("__kernel void " + kernelName + "(\n");
        if (!inPlace) {
            out.append("\t__global const " + t + "* in,\n");
        }
        out.append("\t__global " + t + "* out\n");
        out.append(") {\n");
        out.append("\tint i = get_global_id(0);\n");
        out.append("\tout[i] = ");
        // Index the argument inside the call: emit "f(in[i])", not the
        // malformed "f(in)[i])" the previous version produced.
        function.expr((inPlace ? "out" : "in") + "[i]", out);
        out.append(";\n");
        out.append("}\n");
        return kernelName;
    }

    /**
     * Appends the OpenCL C source of an element-wise binary kernel to {@code out}.
     * For {@code i = get_global_id(0)}, the kernel computes
     * {@code out[i] = (typeOut)(function(in1[i], in2[i]))}.
     *
     * @param function binary function or operator to apply
     * @param type1    element type of the first input buffer
     * @param type2    element type of the second input buffer
     * @param typeOut  element type of the output buffer
     * @param out      builder that receives the kernel source
     * @return the name of the generated kernel
     */
    protected String createVectFun2Source(Fun2 function, Primitive type1, Primitive type2, Primitive typeOut, StringBuilder out) {
        String t1 = type1.type();
        String t2 = type2.type();
        String to = typeOut.type();
        String kernelName = "vect_" + function.name() + "_" + t1 + "_" + t2 + "_" + to;
        out.append("__kernel void " + kernelName + "(\n");
        out.append("\t__global const " + t1 + "* in1,\n");
        out.append("\t__global const " + t2 + "* in2,\n");
        out.append("\t__global " + to + "* out\n");
        out.append(") {\n");
        out.append("\tint i = get_global_id(0);\n");
        // Parenthesize the whole expression: a cast binds tighter than infix
        // operators, so "(T)a + b" would only have converted 'a'.
        out.append("\tout[i] = (" + to + ")(");
        function.expr("in1[i]", "in2[i]", out);
        out.append(");\n");
        out.append("}\n");
        return kernelName;
    }

    /**
     * Returns the cached unary kernel for {@code op} over {@code prim}, building it
     * on first use. The in-place and out-of-place variants are compiled together
     * in one program and cached together.
     *
     * @param op      unary function
     * @param prim    element type
     * @param inPlace {@code true} for the single-buffer variant, {@code false} for the in/out variant
     * @return the requested kernel
     * @throws CLBuildException propagated from building the generated program
     */
    public synchronized CLKernel getKernel(Fun1 op, Primitive prim, boolean inPlace) throws CLBuildException {
        EnumMap<Primitive, Fun1Kernels> byType = this.fun1Kernels.get(op);
        if (byType == null) {
            byType = new EnumMap<Primitive, Fun1Kernels>(Primitive.class);
            this.fun1Kernels.put(op, byType);
        }
        Fun1Kernels kers = byType.get(prim);
        if (kers == null) {
            StringBuilder out = new StringBuilder(300);
            String inPlaceName = this.createVectFun1Source(op, prim, out, true);
            String notInPlaceName = this.createVectFun1Source(op, prim, out, false);
            CLProgram prog = this.getContext().createProgram(out.toString()).build();
            kers = new Fun1Kernels();
            kers.inPlace = prog.createKernel(inPlaceName, new Object[0]);
            kers.notInPlace = prog.createKernel(notInPlaceName, new Object[0]);
            byType.put(prim, kers);
        }
        return inPlace ? kers.inPlace : kers.notInPlace;
    }

    /**
     * Returns the cached binary kernel for {@code op} where both inputs and the
     * output use element type {@code prim}.
     *
     * @param op   binary function or operator
     * @param prim element type of both inputs and the output
     * @return the requested kernel
     * @throws CLBuildException propagated from building the generated program
     */
    public synchronized CLKernel getKernel(Fun2 op, Primitive prim) throws CLBuildException {
        return this.getKernel(op, prim, prim, prim);
    }

    /**
     * Returns the cached binary kernel for {@code op} with the given input and
     * output element types, building it on first use.
     *
     * @param op      binary function or operator
     * @param prim1   element type of the first input
     * @param prim2   element type of the second input
     * @param primOut element type of the output
     * @return the requested kernel
     * @throws CLBuildException propagated from building the generated program
     */
    public synchronized CLKernel getKernel(Fun2 op, Primitive prim1, Primitive prim2, Primitive primOut) throws CLBuildException {
        Map<PrimitiveTrio, CLKernel> byTypes = this.fun2Kernels.get(op);
        if (byTypes == null) {
            byTypes = new HashMap<PrimitiveTrio, CLKernel>();
            this.fun2Kernels.put(op, byTypes);
        }
        PrimitiveTrio key = new PrimitiveTrio(prim1, prim2, primOut);
        CLKernel ker = byTypes.get(key);
        if (ker == null) {
            StringBuilder out = new StringBuilder(300);
            String name = this.createVectFun2Source(op, prim1, prim2, primOut, out);
            CLProgram prog = this.getContext().createProgram(out.toString()).build();
            ker = prog.createKernel(name, new Object[0]);
            byTypes.put(key, ker);
        }
        return ker;
    }

    /**
     * Cache key made of the (in1, in2, out) element types. NOTE(review): map
     * lookups rely on {@code Pair}'s equals/hashCode being value-based; confirm
     * this in the Pair implementation.
     */
    static class PrimitiveTrio
    extends Pair<Primitive, Pair<Primitive, Primitive>> {
        public PrimitiveTrio(Primitive a, Primitive b, Primitive c) {
            super(a, new Pair<Primitive, Primitive>(b, c));
        }
    }

    /** The in-place and out-of-place kernels built from one unary program. */
    private static class Fun1Kernels {
        CLKernel inPlace;
        CLKernel notInPlace;

        private Fun1Kernels() {
        }
    }

    /** OpenCL scalar and vector element types. */
    public static enum Primitive {
        Float,
        Double,
        Long,
        Int,
        Short,
        Byte,
        Float2,
        Double2,
        Long2,
        Int2,
        Short2,
        Byte2,
        Float3,
        Double3,
        Long3,
        Int3,
        Short3,
        Byte3,
        Float4,
        Double4,
        Long4,
        Int4,
        Short4,
        Byte4,
        Float8,
        Double8,
        Long8,
        Int8,
        Short8,
        Byte8,
        Float16,
        Double16,
        Long16,
        Int16,
        Short16,
        Byte16;


        /**
         * Returns the OpenCL C type name for this constant. This is the lower-cased
         * constant name, except that {@code Byte*} maps to {@code char*}, because
         * OpenCL C has no {@code byte} type.
         */
        String type() {
            // Locale.ROOT: default-locale casing can break names (e.g. Turkish "Int" -> "\u0131nt").
            String lower = this.name().toLowerCase(Locale.ROOT);
            return lower.startsWith("byte") ? "char" + lower.substring("byte".length()) : lower;
        }
    }

    /**
     * Binary functions and operators. Constants with an infix operator are
     * emitted as {@code a OP b}; the others are emitted as an OpenCL built-in
     * call {@code f(a, b)}. Per the OpenCL C spec, {@code %}, {@code >>} and
     * {@code <<} are defined only for integer operand types.
     */
    public static enum Fun2 {
        atan2,
        // OpenCL C's geometric built-in is "distance"; there is no "dist".
        dist("distance", null),
        modulo("%"),
        rshift(">>"),
        lshift("<<"),
        add("+"),
        substract("-"),
        multiply("*"),
        divide("/");

        /** Infix operator to emit, or null to emit a function call. */
        String infixOp;
        /** OpenCL function name when it differs from the constant name, else null. */
        String functionName;

        private Fun2() {
        }

        private Fun2(String infixOp) {
            this.infixOp = infixOp;
        }

        private Fun2(String functionName, String infixOp) {
            this.functionName = functionName;
            this.infixOp = infixOp;
        }

        /** Appends the OpenCL expression applying this function to {@code a} and {@code b}. */
        void expr(String a, String b, StringBuilder out) {
            if (this.infixOp == null) {
                String fn = this.functionName != null ? this.functionName : this.name();
                out.append(fn).append('(').append(a).append(", ").append(b).append(")");
            } else {
                out.append(a).append(' ').append(this.infixOp).append(' ').append(b);
            }
        }
    }

    /** Unary functions, emitted as OpenCL built-in calls named after the constant. */
    public static enum Fun1 {
        log,
        exp,
        sqrt,
        sin,
        cos,
        tan,
        atan,
        asin,
        acos,
        sinh,
        cosh,
        tanh,
        asinh,
        acosh,
        atanh;


        /** Appends {@code name(a)} to {@code out}. */
        void expr(String a, StringBuilder out) {
            out.append(this.name()).append('(').append(a).append(")");
        }
    }
}

