/*
 * Decompiled with CFR 0.152.
 */
package de.mirkosertic.bytecoder.backend.opencl;

import de.mirkosertic.bytecoder.allocator.Allocator;
import de.mirkosertic.bytecoder.api.Logger;
import de.mirkosertic.bytecoder.api.opencl.Context;
import de.mirkosertic.bytecoder.api.opencl.FloatSerializable;
import de.mirkosertic.bytecoder.api.opencl.Kernel;
import de.mirkosertic.bytecoder.api.opencl.OpenCLOptions;
import de.mirkosertic.bytecoder.api.opencl.OpenCLType;
import de.mirkosertic.bytecoder.backend.CompileOptions;
import de.mirkosertic.bytecoder.backend.opencl.OpenCLCompileBackend;
import de.mirkosertic.bytecoder.backend.opencl.OpenCLCompileResult;
import de.mirkosertic.bytecoder.backend.opencl.OpenCLInputOutputs;
import de.mirkosertic.bytecoder.backend.opencl.OpenCLPlatform;
import de.mirkosertic.bytecoder.core.BytecodeLinkerContext;
import de.mirkosertic.bytecoder.core.BytecodeLoader;
import de.mirkosertic.bytecoder.core.BytecodeMethodSignature;
import de.mirkosertic.bytecoder.optimizer.KnownOptimizer;
import de.mirkosertic.bytecoder.ssa.TypeRef;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.Buffer;
import java.nio.FloatBuffer;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.jocl.CL;
import org.jocl.NativePointerObject;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.cl_command_queue;
import org.jocl.cl_context;
import org.jocl.cl_context_properties;
import org.jocl.cl_device_id;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;
import org.jocl.cl_program;

class OpenCLContext
implements Context {
    private static final Map<Class, OpenCLCompileResult> ALREADY_COMPILED = new HashMap<Class, OpenCLCompileResult>();
    private final OpenCLCompileBackend backend;
    private final CompileOptions compileOptions;
    private final cl_context context;
    private final cl_command_queue commandQueue;
    private final Map<Class, CachedKernel> cachedKernels;
    private final Logger logger;
    private final OpenCLPlatform platform;

    OpenCLContext(OpenCLPlatform aPlatform, Logger aLogger, OpenCLOptions aOptions) {
        this.logger = aLogger;
        this.platform = aPlatform;
        this.cachedKernels = new HashMap<Class, CachedKernel>();
        this.backend = new OpenCLCompileBackend();
        this.compileOptions = new CompileOptions(this.logger, false, KnownOptimizer.ALL, true, "opencl", 512, 512, false, aOptions.isPreferStackifier(), Allocator.linear, new String[0], new String[0], null);
        cl_context_properties contextProperties = new cl_context_properties();
        contextProperties.addProperty(4228L, aPlatform.selectedPlatform.id);
        this.context = CL.clCreateContextFromType((cl_context_properties)contextProperties, (long)-1L, null, null, null);
        this.commandQueue = CL.clCreateCommandQueue((cl_context)this.context, (cl_device_id)aPlatform.selectedDevice.id, (long)0L, null);
    }

    private CachedKernel kernelFor(Kernel aKernel) throws IOException {
        Class<?> theKernelClass = aKernel.getClass();
        CachedKernel theCachedKernel = this.cachedKernels.get(theKernelClass);
        if (null != theCachedKernel) {
            return theCachedKernel;
        }
        OpenCLCompileResult theResult = ALREADY_COMPILED.get(theKernelClass);
        OpenCLCompileResult.OpenCLContent content = null;
        if (null == theResult) {
            Method theMethod;
            try {
                theMethod = aKernel.getClass().getDeclaredMethod("processWorkItem", new Class[0]);
            }
            catch (Exception e) {
                throw new IllegalArgumentException("Error resolving kernel method", e);
            }
            BytecodeMethodSignature theSignature = this.backend.signatureFrom(theMethod);
            BytecodeLoader theLoader = new BytecodeLoader(theKernelClass.getClassLoader());
            BytecodeLinkerContext theLinkerContext = new BytecodeLinkerContext(theLoader, this.compileOptions.getLogger());
            theResult = this.backend.generateCodeFor(this.compileOptions, theLinkerContext, aKernel.getClass(), theMethod.getName(), theSignature);
            content = (OpenCLCompileResult.OpenCLContent)theResult.getContent()[0];
            this.logger.debug("Generated Kernel code : {}", new Object[]{content.asString()});
            ALREADY_COMPILED.put(theKernelClass, theResult);
        } else {
            content = (OpenCLCompileResult.OpenCLContent)theResult.getContent()[0];
        }
        cl_program theCLProgram = CL.clCreateProgramWithSource((cl_context)this.context, (int)1, (String[])new String[]{content.asString()}, null, null);
        try {
            CL.clBuildProgram((cl_program)theCLProgram, (int)0, null, null, null, null);
        }
        catch (Exception e) {
            throw new RuntimeException("Error compiling : " + content.asString(), e);
        }
        cl_kernel theKernel = CL.clCreateKernel((cl_program)theCLProgram, (String)"BytecoderKernel", null);
        CachedKernel theCached = new CachedKernel(content.getInputOutputs(), theCLProgram, theKernel);
        this.cachedKernels.put(theKernelClass, theCached);
        return theCached;
    }

    public void compute(int aNumberOfStreams, Kernel aKernel) throws IOException {
        Class<?> theKernelClass = aKernel.getClass();
        CachedKernel theCachedKernel = this.kernelFor(aKernel);
        List<OpenCLInputOutputs.KernelArgument> theArguments = theCachedKernel.inputOutputs.arguments();
        cl_mem[] theMemObjects = new cl_mem[theArguments.size()];
        HashMap<Integer, DataRef> theOutputs = new HashMap<Integer, DataRef>();
        try {
            block14: for (int i = 0; i < theArguments.size(); ++i) {
                OpenCLInputOutputs.KernelArgument theArgument = theArguments.get(i);
                TypeRef theFieldType = TypeRef.toType(theArgument.getField().getValue().getTypeRef());
                if (theFieldType.isArray()) {
                    DataRef theDataRef;
                    TypeRef.ArrayTypeRef theArrayTypeRef = (TypeRef.ArrayTypeRef)theFieldType;
                    TypeRef theArrayElement = TypeRef.toType(theArrayTypeRef.arrayType().getType());
                    switch (theArrayElement.resolve()) {
                        case INT: {
                            Field theField = theKernelClass.getDeclaredField(theArgument.getField().getValue().getName().stringValue());
                            theField.setAccessible(true);
                            Object[] theData = (int[])theField.get(aKernel);
                            theDataRef = new DataRef(Pointer.to((int[])theData), 4 * theData.length);
                            break;
                        }
                        case FLOAT: {
                            Field theField = theKernelClass.getDeclaredField(theArgument.getField().getValue().getName().stringValue());
                            theField.setAccessible(true);
                            Object[] theData = (float[])theField.get(aKernel);
                            theDataRef = new DataRef(Pointer.to((float[])theData), 4 * theData.length);
                            break;
                        }
                        case REFERENCE: {
                            Field theField = theKernelClass.getDeclaredField(theArgument.getField().getValue().getName().stringValue());
                            theField.setAccessible(true);
                            Object[] theData = (Object[])theField.get(aKernel);
                            Class<?> theObjectType = theData.getClass().getComponentType();
                            theDataRef = OpenCLContext.toDataRef(theData, theObjectType);
                            break;
                        }
                        default: {
                            throw new IllegalArgumentException("Not supported array element type " + theArrayElement.resolve() + " for kernel argument " + theArgument.getField().getValue().getName());
                        }
                    }
                    switch (theArgument.getType()) {
                        case INPUT: 
                        case OUTPUT: 
                        case INPUTOUTPUT: {
                            theMemObjects[i] = CL.clCreateBuffer((cl_context)this.context, (long)9L, (long)theDataRef.size, (Pointer)theDataRef.pointer, null);
                            theOutputs.put(i, theDataRef);
                        }
                    }
                    CL.clSetKernelArg((cl_kernel)theCachedKernel.kernel, (int)i, (long)Sizeof.cl_mem, (Pointer)Pointer.to((NativePointerObject)theMemObjects[i]));
                    continue;
                }
                Field theField = theKernelClass.getDeclaredField(theArgument.getField().getValue().getName().stringValue());
                theField.setAccessible(true);
                switch (theFieldType.resolve()) {
                    case FLOAT: {
                        float theData = ((Float)theField.get(aKernel)).floatValue();
                        CL.clSetKernelArg((cl_kernel)theCachedKernel.kernel, (int)i, (long)4L, (Pointer)Pointer.to((float[])new float[]{theData}));
                        continue block14;
                    }
                    case INT: {
                        int theData = (Integer)theField.get(aKernel);
                        CL.clSetKernelArg((cl_kernel)theCachedKernel.kernel, (int)i, (long)4L, (Pointer)Pointer.to((int[])new int[]{theData}));
                        continue block14;
                    }
                    default: {
                        throw new IllegalArgumentException("Type " + theFieldType + " is not supported for kernel argument " + theArgument.getField().getValue().getName().stringValue());
                    }
                }
            }
        }
        catch (Exception e) {
            throw new RuntimeException("Error extracting kernel parameter", e);
        }
        long[] global_work_size = new long[]{aNumberOfStreams};
        long[] local_work_size = null;
        CL.clEnqueueNDRangeKernel((cl_command_queue)this.commandQueue, (cl_kernel)theCachedKernel.kernel, (int)1, null, (long[])global_work_size, local_work_size, (int)0, null, null);
        CL.clFinish((cl_command_queue)this.commandQueue);
        for (Map.Entry theEntry : theOutputs.entrySet()) {
            DataRef theDataRef = (DataRef)theEntry.getValue();
            CL.clEnqueueReadBuffer((cl_command_queue)this.commandQueue, (cl_mem)theMemObjects[(Integer)theEntry.getKey()], (boolean)true, (long)0L, (long)theDataRef.size, (Pointer)theDataRef.pointer, (int)0, null, null);
            theDataRef.updateFromBuffer();
        }
        for (cl_mem theMem : theMemObjects) {
            if (theMem == null) continue;
            CL.clReleaseMemObject((cl_mem)theMem);
        }
    }

    private static DataRef toDataRef(final Object[] aArray, Class aDataType) {
        if (FloatSerializable.class.isAssignableFrom(aDataType)) {
            OpenCLType theType = aDataType.getAnnotation(OpenCLType.class);
            int theSize = aArray.length * theType.elementCount();
            final FloatBuffer theBuffer = FloatBuffer.allocate(theSize);
            for (Object anAArray : aArray) {
                FloatSerializable theVec = (FloatSerializable)anAArray;
                theVec.writeTo(theBuffer);
            }
            return new DataRef(Pointer.to((Buffer)theBuffer), 4 * theSize){

                @Override
                public void updateFromBuffer() {
                    theBuffer.rewind();
                    for (Object anAArray : aArray) {
                        FloatSerializable theVec = (FloatSerializable)anAArray;
                        theVec.readFrom(theBuffer);
                    }
                }
            };
        }
        throw new IllegalArgumentException("Not supported datatype : " + aDataType);
    }

    public void close() {
        for (CachedKernel theCached : this.cachedKernels.values()) {
            theCached.close();
        }
        CL.clReleaseCommandQueue((cl_command_queue)this.commandQueue);
        CL.clReleaseContext((cl_context)this.context);
    }

    private static class DataRef {
        private final Pointer pointer;
        private final int size;

        DataRef(Pointer aPointer, int aSize) {
            this.pointer = aPointer;
            this.size = aSize;
        }

        void updateFromBuffer() {
        }
    }

    private static class CachedKernel {
        private final OpenCLInputOutputs inputOutputs;
        private final cl_program program;
        private final cl_kernel kernel;

        CachedKernel(OpenCLInputOutputs aInputOutputs, cl_program aProgram, cl_kernel aKernel) {
            this.inputOutputs = aInputOutputs;
            this.program = aProgram;
            this.kernel = aKernel;
        }

        void close() {
            CL.clReleaseKernel((cl_kernel)this.kernel);
            CL.clReleaseProgram((cl_program)this.program);
        }
    }
}

