Catching an exception thrown from a callback in cudaLaunchHostFunc


I want to check for an error flag living in managed memory that might have been written by a kernel running on a certain stream. Depending on the error flag I need to throw an exception. I would simply sync this stream and check the flag from the host, but I need to do so from inside a CUDA graph. AFAIK I need to somehow encode this host-side error checking inside a cudaLaunchHostFunc callback.

I am trying to understand how the cudaLaunchHostFunc function deals with exceptions. The documentation does not mention anything about it. Is there any way to catch an exception thrown from inside the function provided to cudaLaunchHostFunc?

Consider the following MWE:

#include <iostream>
#include <stdexcept>

__global__ void kern(){
  int id = blockIdx.x*blockDim.x + threadIdx.x;
  printf("Kernel\n");
  return;
}

void foo(void* data){
  std::cerr<<"Callback"<<std::endl;
  throw std::runtime_error("Error in callback");
}

void launch(){
  cudaStream_t st = 0;
  kern<<<1,1,0,st>>>();
  cudaHostFn_t fn = foo;
  cudaLaunchHostFunc(st, fn, nullptr);
  cudaDeviceSynchronize();
}

int main(){
  try{
    launch();
  }
  catch(...){
    std::cerr<<"Catched exception"<<std::endl;
  }
  return 0;
}

The output of this code is:

Kernel
Callback
terminate called after throwing an instance of 'std::runtime_error'
  what():  Error in callback
Aborted (core dumped)

The exception is thrown but it appears that it is not propagated to the launch function. I would have expected the above launch() function to be equivalent (exception-wise) to the following:

void launch(){
  cudaStream_t st = 0;
  kern<<<1,1,0,st>>>();
  cudaStreamSynchronize(st);
  foo(nullptr);
  // cudaHostFn_t fn = foo;
  // cudaLaunchHostFunc(st, fn, nullptr);
  cudaDeviceSynchronize();
}

which outputs the expected:

Kernel
Callback
Catched exception

Additionally, in the first case, all CUDA calls return cudaSuccess.

Accepted answer (by Raul):

Thanks to the comments I understand now that my question is essentially the same as, for instance, this one: How can I propagate exceptions between threads?

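The callback passed to cudaLaunchHostFunc runs on a thread managed by the CUDA runtime, not on the thread that called launch(), so an exception escaping it can never reach the try block in main(). It unwinds to the top of that runtime thread instead, and std::terminate is called. The techniques used to take exceptions from a worker thread to the main thread therefore apply here too.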

For completeness, the foo and launch functions in my dummy example could be rewritten as follows:

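#include <exception>  // std::exception_ptr, std::current_exception, std::rethrow_exception

void foo(void* data){
  auto e = static_cast<std::exception_ptr*>(data);
  std::cerr<<"Callback"<<std::endl;
  try{
    throw std::runtime_error("Error in callback");
  }
  catch(...){
    // Store the exception instead of letting it escape the runtime-owned thread
    *e = std::current_exception();
  }
}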

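void launch(){
  cudaStream_t st = 0;
  kern<<<1,1,0,st>>>();
  // No cudaStreamSynchronize needed here: the host function is ordered
  // after the kernel because both are issued on the same stream.
  cudaHostFn_t fn = foo;
  std::exception_ptr e;  // must stay alive until the callback has run
  cudaLaunchHostFunc(st, fn, (void*)&e);
  cudaDeviceSynchronize();  // the callback has finished once this returns
  if(e) std::rethrow_exception(e);  // rethrow on the calling thread
}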

Which prints the expected:

Kernel
Callback
Catched exception