JOCL neural network


I wrote a neural network in Java and it seemed like a good idea to move the computation to the GPU for performance reasons. The problem is that it's too slow... I used JOCL to do so. I don't know if the kernel is the problem; here is some code:

private static String programSource = "__kernel void "
        + "sampleKernel(__constant float *input,"
        + "             __global float *gewichte,"
        + "             __constant int *length,"
        + "             __global float *dst)"
        + "{"
        + "    __private int gid = get_global_id(0);"
        + "    __private int pos = (gid*length[0]);"
        + "    __private float tmp = 0;"
        + "    __private int l = length[0];"
        + "    dst[gid] = 0;"
        + "    for(int i = 0; i < l; i++){"
        + "        tmp += gewichte[pos+i]*input[i];"
        + "    }"
        + "    dst[gid] = tanh(tmp);"
        + "}";

Making the weights __constant made the program even slower (maybe it permanently has to swap data between global and local memory because the weights array is too big).
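Whether the weights even fit into __constant memory can be checked by querying the device limit; a minimal sketch, assuming the cl_device_id obtained in the main further below is at hand:

// Hypothetical check of the __constant memory limit (not part of the original code).
long[] maxConstBytes = new long[1];
clGetDeviceInfo(device, CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE,
        Sizeof.cl_ulong, Pointer.to(maxConstBytes), null);
// With layers of 512 neurons, one weight matrix is 512*512*4 bytes = 1 MB,
// which may well exceed this limit.
System.out.println("max __constant buffer size: " + maxConstBytes[0] + " bytes");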

It seems that most of the time is spent on this line:

tmp += gewichte[pos+i]*input[i];

One kernel call represents the computation of one neural network layer, and for every neuron of the layer one work-item should compute tanh(weightsOnThisNeuron · outputFromAllNeuronsOfPreviousLayer).
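Written out sequentially in plain Java, the work of a single work-item (one neuron) corresponds roughly to the following sketch; the method and parameter names are made up for illustration and do not come from the Netz class:

// One neuron: dot product of its weight row with the previous layer's output, then tanh.
static float neuronOutput(float[] weightsOnThisNeuron, float[] previousLayerOutput) {
    float tmp = 0;
    for (int i = 0; i < previousLayerOutput.length; i++) {
        tmp += weightsOnThisNeuron[i] * previousLayerOutput[i];
    }
    return (float) Math.tanh(tmp);
}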

I prepare all the kernels once and store them, so that when I want to execute them they don't have to be prepared again and again.

The only I/O between GPU and CPU happens at the beginning and at the end, when I retrieve the output.
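The retrieveOutput method called in the main further below is not shown; assuming it simply does a blocking read of the last layer's tmp buffer into the output array, it would look roughly like this:

public float[] retrieveOutput() {
    // Blocking read of the last layer's result buffer back to the host.
    clEnqueueReadBuffer(commandQueue, tmp[tmp.length - 1], CL_TRUE, 0,
            Sizeof.cl_float * output.length, Pointer.to(output), 0, null, null);
    return output;
}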

Here is the code where I initialize the network and run the kernels:

public OpenClNetz(float[][][] gew, cl_context context,
        cl_command_queue commandQueue) throws Exception {
    if (context == null) {
        throw new Exception("context == null, Konstruktor schlug fehl");
    }
    if (commandQueue == null) {
        throw new Exception("commandQueue == null, Konstruktor schlug fehl");
    }
    this.layersize = new int[gew.length + 1];
    for (int i = 0; i < layersize.length - 1; i++) {
        this.layersize[i] = gew[i][0].length;
    }
    this.layersize[this.layersize.length - 1] = gew[gew.length - 1].length;
    this.context = context;
    builded = false;
    this.commandQueue = commandQueue;
    this.output = new float[layersize[layersize.length - 1]];
    gewichte = new cl_mem[layersize.length - 1];
    tmp = new cl_mem[layersize.length - 1];
    lengths = new cl_mem[layersize.length - 1];
    input = new cl_mem();
    float[] tmpG;
    int[][] tmpL = new int[layersize.length - 1][];
    for (int i = 0; i < gewichte.length; i++) {
        tmpG = new float[layersize[i] * layersize[i + 1]];
        tmpL[i] = new int[1];
        tmpL[i][0] = layersize[i];
        int n = 0;
        for (int j = 0; j < layersize[i + 1]; j++) {
            for (int k = 0; k < layersize[i]; k++) {
                tmpG[n] = gew[i][j][k];
                n++;
            }
        }
        gewichte[i] = clCreateBuffer(context, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR, Sizeof.cl_float * tmpG.length, Pointer.to(tmpG),
                null);
        lengths[i] = clCreateBuffer(context, CL_MEM_READ_WRITE
                | CL_MEM_COPY_HOST_PTR, Sizeof.cl_int,  Pointer.to(tmpL[i]), null);
        tmp[i] = clCreateBuffer(context, CL_MEM_READ_WRITE, Sizeof.cl_float
                * layersize[i + 1], null, null);
    }

}


public void setInput(float[] in) {
    if (in.length != layersize[0]) {
        System.out.println("array length does not match the input size, setInput failed");
        return;
    }
    input = clCreateBuffer(context, CL_MEM_READ_WRITE
            | CL_MEM_COPY_HOST_PTR, Sizeof.cl_float * layersize[0],
            Pointer.to(in), null);
    clSetKernelArg(kernel[0], 0, Sizeof.cl_mem, Pointer.to(input));
}

public void buildProgramm() {
    program = clCreateProgramWithSource(context, 1,
            new String[] { programSource }, null, null);
    clBuildProgram(program, 0, null, null, null, null);
    builded = true;
    kernel = new cl_kernel[gewichte.length];
    kernel[0] = clCreateKernel(program, "sampleKernel", null);
    clSetKernelArg(kernel[0], 0, Sizeof.cl_mem, Pointer.to(input));
    clSetKernelArg(kernel[0], 1, Sizeof.cl_mem, Pointer.to(gewichte[0]));
    clSetKernelArg(kernel[0], 2, Sizeof.cl_mem, Pointer.to(lengths[0]));
    clSetKernelArg(kernel[0], 3, Sizeof.cl_mem, Pointer.to(tmp[0]));
    for (int i = 1; i < gewichte.length; i++) {
        kernel[i] = clCreateKernel(program, "sampleKernel", null);
        clSetKernelArg(kernel[i], 0, Sizeof.cl_mem, Pointer.to(tmp[i - 1]));
        clSetKernelArg(kernel[i], 1, Sizeof.cl_mem, Pointer.to(gewichte[i]));
        clSetKernelArg(kernel[i], 2, Sizeof.cl_mem, Pointer.to(lengths[i]));
        clSetKernelArg(kernel[i], 3, Sizeof.cl_mem, Pointer.to(tmp[i]));
    }
}


public void run() throws Exception {
    if (!builded) {
        throw new Exception(
                "buildProgramm must be called first, run failed");
    }
    long global_work_size[] = new long[] { layersize[1] };
    this.local_work_size = new long[] { 8 };
    // Execute the kernel
    clEnqueueNDRangeKernel(commandQueue, kernel[0], 1, null,
            global_work_size, local_work_size, 0, null, null);

    for (int i = 1; i < gewichte.length; i++) {
        global_work_size = new long[] { layersize[i + 1] };

        // Execute the kernel
        clEnqueueNDRangeKernel(commandQueue, kernel[i], 1, null,
                global_work_size, local_work_size, 0, null, null);

    }

}

Here is the main:

public class TEST{
public static void main(String args[]) throws Exception
{
    // The platform, device type and device number
    // that will be used
    final int platformIndex = 0;
    final long deviceType = CL_DEVICE_TYPE_DEFAULT;
    final int deviceIndex = 0;

    // Enable exceptions and subsequently omit error checks in this sample
    CL.setExceptionsEnabled(true);

    // Obtain the number of platforms
    int numPlatformsArray[] = new int[1];
    clGetPlatformIDs(0, null, numPlatformsArray);
    int numPlatforms = numPlatformsArray[0];

    // Obtain a platform ID
    cl_platform_id platforms[] = new cl_platform_id[numPlatforms];
    clGetPlatformIDs(platforms.length, platforms, null);
    cl_platform_id platform = platforms[platformIndex];

    // Initialize the context properties
    cl_context_properties contextProperties = new cl_context_properties();
    contextProperties.addProperty(CL_CONTEXT_PLATFORM, platform);

    // Obtain the number of devices for the platform
    int numDevicesArray[] = new int[1];
    clGetDeviceIDs(platform, deviceType, 0, null, numDevicesArray);
    int numDevices = numDevicesArray[0];

    // Obtain a device ID 
    cl_device_id devices[] = new cl_device_id[numDevices];
    clGetDeviceIDs(platform, deviceType, numDevices, devices, null);
    cl_device_id device = devices[deviceIndex];

    // Create a context for the selected device
    cl_context context = clCreateContext(
        contextProperties, 1, new cl_device_id[]{device}, 
        null, null, null);
    // Create a command-queue for the selected device
    cl_command_queue commandQueue = 
        clCreateCommandQueue(context, device, 0, null);




    int[] layersize = {512,512,512};
    float[] in = new float[512];
    for(int i = 0; i < 512; i++){
        in[i] = (float) (Math.random()*1.4 -0.7);
    }
    Netz net = new Netz(layersize);
    net.set_Input(in);
    OpenClNetz netz= new OpenClNetz(net.gewichte,context,commandQueue);
    netz.buildProgramm();
    netz.setInput(in);
    double time = System.currentTimeMillis();
    for(int i = 0; i < 10000; i++){
        netz.run();
    }
    System.out.println(Arrays.toString(netz.retrieveOutput()));
    System.out.println("time OpenCl: " + (System.currentTimeMillis()-time));

    time = System.currentTimeMillis();

    for(int i = 0; i < 10000; i++){
        net.start();
    }

    System.out.println("time normal: " + (System.currentTimeMillis()-time));
    System.out.println(Arrays.toString(netz.retrieveOutput()));
    System.out.println(Arrays.toString(net.start()));

    netz.destroy();






    // Release the command queue and the context
    clReleaseCommandQueue(commandQueue);
    clReleaseContext(context);
}
}

Does anybody have an idea how I can make this faster?

The output is:

normal (running on CPU): 6475 ms

running on GPU (local worksize = 1):  19110 ms
running on GPU (local worksize = 2):  11778 ms
running on GPU (local worksize = 4):  8985 ms
running on GPU (local worksize = 8):  6880 ms
running on GPU (local worksize = 16): 8237 ms    (it becomes slower?! O.o)
running on GPU (local worksize = 32): 9298 ms    (I'm kinda new to JOCL)
running on GPU (local worksize = 64): 10062 ms
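To experiment with the local work size more systematically, the preferred work-group size for the kernel can be queried; a minimal sketch, assuming the device and kernel objects from the code above and a 64-bit platform where size_t is 8 bytes:

// Hypothetical query of the preferred work-group size multiple (not in the original code).
long[] preferredMultiple = new long[1];
clGetKernelWorkGroupInfo(kernel[0], device, CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE,
        Sizeof.size_t, Pointer.to(preferredMultiple), null);
System.out.println("preferred work-group size multiple: " + preferredMultiple[0]);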