Filling Float buffer in Metal


Problem:

I need to fill a MTLBuffer of Floats with a constant value — say 1729.68921. I also need it to be as fast as possible.

This rules out filling the buffer on the CPU side (i.e. getting an UnsafeMutablePointer<Float> from the MTLBuffer and assigning values in a serial manner).
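For reference, the serial CPU-side fill being ruled out looks roughly like this (a sketch in C; fill_floats is an illustrative name, and the pointer stands in for what MTLBuffer's contents would give you):

```c
#include <stddef.h>

// Serial CPU-side fill: touch every element through the buffer's
// contents pointer. Simple, but bounded by a single core's write
// bandwidth, which is the cost the question wants to avoid.
static void fill_floats(float *contents, size_t count, float value) {
    for (size_t i = 0; i < count; ++i) {
        contents[i] = value;
    }
}
```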

My approach

Ideally I'd use MTLBlitCommandEncoder.fill(), however AFAIK it's only capable of filling a buffer with a repeated UInt8 value (given that UInt8 is 1 byte long and Float is 4 bytes long, I can't specify an arbitrary value for my Float constant).
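To see why fill() can't help here: it repeats one byte across the range, so the only Float values it can produce are those whose four bytes are all identical (0.0 being the useful one). A quick C sketch, with memset standing in for the blit fill and illustrative helper names:

```c
#include <stdint.h>
#include <string.h>

// A blit fill repeats a single byte pattern; memset models that here.
static float fill_with_byte(uint8_t b) {
    float f;
    memset(&f, b, sizeof f);  // all four bytes set to b
    return f;
}

// A float is producible by a single-byte fill only if all four of its
// bytes are equal -- a set of just 256 possible values.
static int reachable_by_byte_fill(float f) {
    uint8_t bytes[4];
    memcpy(bytes, &f, sizeof bytes);
    return bytes[0] == bytes[1] && bytes[1] == bytes[2] && bytes[2] == bytes[3];
}
```

Byte 0x00 yields 0.0f, but a value like 1729.68921 has four distinct bytes, so no single-byte pattern can reach it.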

So far I can see only 2 options left, but both seem to be overkill:

  1. create another buffer B filled with the constant value and copy its contents into my buffer via MTLBlitCommandEncoder
  2. create a kernel function that'd fill the buffer

Questions

What's the fastest way of filling MTLBuffer of Floats with a constant value?


1 Answer

Accepted answer by warrenm:

Using a compute shader that writes to multiple buffer elements from each thread was the fastest approach in my experiments. This is hardware-dependent, so you should test on the full range of devices you expect the app to be deployed on.

I wrote two compute shaders: one that fills 16 contiguous array elements without checking against the array bounds, and one that sets a single array element after checking against the length of the buffer:

kernel void fill_16_unchecked(device float *buffer  [[buffer(0)]],
                              constant float &value [[buffer(1)]],
                              uint index            [[thread_position_in_grid]])
{
    for (int i = 0; i < 16; ++i) {
        buffer[index * 16 + i] = value;
    }
}

kernel void single_fill_checked(device float *buffer         [[buffer(0)]],
                                constant float &value        [[buffer(1)]],
                                constant uint &buffer_length [[buffer(2)]],
                                uint index                   [[thread_position_in_grid]])
{
    if (index < buffer_length) {
        buffer[index] = value;
    }
}

If you know that your buffer count will always be a multiple of the thread execution width multiplied by the number of elements you set in the loop, you can just use the first function. The second function is a fallback for when you might dispatch a grid that would otherwise overrun the buffer.
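The split between the two dispatches is plain integer arithmetic, and it's worth checking that the full threadgroups plus the checked remainder actually cover the whole buffer. A C sketch of that bookkeeping (plan_fill and FillPlan are illustrative names; 16 is the per-thread element count from the first kernel):

```c
#include <stdint.h>

// Decompose a buffer of `count` floats into full threadgroups for the
// unchecked 16-elements-per-thread kernel, plus a checked remainder.
typedef struct {
    uint32_t fullGroups;      // threadgroups for the unchecked kernel
    uint32_t remainder;       // elements left over for the checked kernel
    uint32_t remainderGroups; // threadgroups for the checked kernel
} FillPlan;

static FillPlan plan_fill(uint32_t count, uint32_t executionWidth) {
    uint32_t elemsPerGroup = executionWidth * 16;
    FillPlan p;
    p.fullGroups = count / elemsPerGroup;
    p.remainder  = count % elemsPerGroup;
    // Round up so the checked dispatch has at least `remainder` threads;
    // the in-kernel bounds check absorbs any excess.
    p.remainderGroups = p.remainder ? (p.remainder + executionWidth - 1) / executionWidth : 0;
    return p;
}
```

For any count, fullGroups * executionWidth * 16 + remainder equals the buffer length, and remainderGroups * executionWidth is at least remainder, so every element is covered.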

Once you have two pipelines built from these functions, you can dispatch the work with a pair of compute commands as follows:

NSInteger executionWidth = [unchecked16Pipeline threadExecutionWidth];
id<MTLComputeCommandEncoder> computeEncoder = [commandBuffer computeCommandEncoder];
[computeEncoder setBuffer:buffer offset:0 atIndex:0];
[computeEncoder setBytes:&value length:sizeof(float) atIndex:1];
if (bufferCount / (executionWidth * 16) != 0) {
    [computeEncoder setComputePipelineState:unchecked16Pipeline];
    [computeEncoder dispatchThreadgroups:MTLSizeMake(bufferCount / (executionWidth * 16), 1, 1)
                   threadsPerThreadgroup:MTLSizeMake(executionWidth, 1, 1)];
}
if (bufferCount % (executionWidth * 16) != 0) {
    // Start the checked dispatch where the unchecked one left off. The
    // starting element index is a multiple of executionWidth * 16, so the
    // byte offset satisfies setBuffer's alignment requirements.
    uint32_t remainder = (uint32_t)(bufferCount % (executionWidth * 16));
    NSUInteger startIndex = bufferCount - remainder;
    [computeEncoder setComputePipelineState:checkedSinglePipeline];
    [computeEncoder setBuffer:buffer offset:startIndex * sizeof(float) atIndex:0];
    [computeEncoder setBytes:&remainder length:sizeof(remainder) atIndex:2];
    [computeEncoder dispatchThreadgroups:MTLSizeMake((remainder + executionWidth - 1) / executionWidth, 1, 1)
                   threadsPerThreadgroup:MTLSizeMake(executionWidth, 1, 1)];
}
[computeEncoder endEncoding];

Note that doing the work in this manner will not necessarily be faster than the naive approach that just writes one element per thread. In my tests, it was 40% faster on A8, roughly equivalent on A10, and 2-3x slower (!) on A9. Always test with your own workload.