What does cudaSetDevice() do to a CUDA device's context stack?

3.3k views Asked by At

Suppose I have an active CUDA context associated with device i, and I now call cudaSetDevice(i). What happens?

  1. Nothing?
  2. Primary context replaces the top of the stack?
  3. Primary context is pushed onto the stack?

It actually seems to be inconsistent. I've written this program, running on a machine with a single device:

#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>

// Experiment: with a driver-API context (ctx1) already on the calling thread's
// context stack, interleave cudaSetDevice(0) / cudaMalloc() with
// cuCtxPushCurrent(ctx1), then pop and print the stack until a pop fails.
int main()
{
        CUcontext ctx1, primary;
        // The driver API reports CUresult and the runtime API reports cudaError_t.
        // They are distinct enum types, so each API gets its own status variable,
        // and every call's result is captured before it is asserted on.
        CUresult dstatus = cuInit(0);
        assert (dstatus == CUDA_SUCCESS);
        dstatus = cuCtxCreate(&ctx1, 0, 0);
        assert (dstatus == CUDA_SUCCESS);
        // Print the handle so entries in the stack dump can be identified.
        std::cout << "context ctx1 is " << (void*) ctx1 << '\n';
        dstatus = cuCtxPushCurrent(ctx1);
        assert (dstatus == CUDA_SUCCESS);
        cudaError_t rstatus = cudaSetDevice(0);
        assert (rstatus == cudaSuccess);
        void* ptr1;
        void* ptr2;
        rstatus = cudaMalloc(&ptr1, 1024);
        assert (rstatus == cudaSuccess);
        // Whatever is current after cudaSetDevice() + cudaMalloc() is recorded as
        // `primary`; the assert below checks that it is not ctx1.
        dstatus = cuCtxGetCurrent(&primary);
        assert (dstatus == CUDA_SUCCESS);
        assert(primary != ctx1);
        std::cout << "primary context is " << (void*) primary << '\n';
        dstatus = cuCtxPushCurrent(ctx1);
        assert (dstatus == CUDA_SUCCESS);
        rstatus = cudaMalloc(&ptr2, 1024);
        assert (rstatus == cudaSuccess);
        rstatus = cudaSetDevice(0);
        assert (rstatus == cudaSuccess);
        // Drain the stack from the top, using a dedicated variable so `primary`
        // keeps the handle recorded above.
        CUcontext popped;
        for (int i = 0; cuCtxPopCurrent(&popped) == CUDA_SUCCESS; ++i) {
                std::cout << "Next context on stack (" << i << ") is " << (void*) popped << '\n';
        }
}

and I get the following output:

context ctx1 is 0x563ec6225e30
primary context is 0x563ec61f5490
Next context on stack (0) is 0x563ec61f5490
Next context on stack (1) is 0x563ec61f5490
Next context on stack (2) is 0x563ec6225e30

It seems that the behavior is sometimes a replacement and sometimes a push.

What's going on?

1

There is 1 answer

6
Robert Crovella On BEST ANSWER

TL;DR: Based on the code you have provided, in both instances of your particular usage, it seems that cudaSetDevice() is replacing the context at the top of the stack.

Let's modify your code a bit, and then see what we can infer about the effect of each API call in your code on the context stack:

$ cat t1759.cu
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>
// Pops contexts off the calling thread's context stack until a pop fails,
// printing one line per popped entry. Each line is prefixed with the
// checkpoint number j and labels the handle as ctx1, ctx2 or unknown.
// Note that reporting the stack this way also removes every entry from it.
void check(int j, CUcontext ctx1, CUcontext ctx2){
  int depth = 0;
  CUcontext popped;
  for (;;) {
    if (cuCtxPopCurrent(&popped) != CUDA_SUCCESS) return;
    // ctx1 is tested before ctx2, so it wins if both handles are equal.
    const char* suffix = ") is unknown:";
    if (popped == ctx1)      suffix = ") is ctx1:";
    else if (popped == ctx2) suffix = ") is ctx2:";
    std::cout << j << ":Next context on stack (" << depth++ << suffix << (void*) popped << '\n';
  }
}
// Replays the sequence of driver-API and runtime-API calls one at a time.
// Once the call tagged "checkpoint i" has completed and its status has been
// asserted, the context stack is dumped via check() and the function returns.
// Calls after the requested checkpoint are never executed.
void runtest(int i)
{
        CUcontext ctx1, primary = NULL;
        // Dumps the stack and reports true when `step` is the requested checkpoint.
        // Captures by reference so check() receives `primary` as it is at call
        // time (still NULL until checkpoint 5 has run).
        auto stopAt = [&](int step) -> bool {
                if (i != step) return false;
                check(i, ctx1, primary);
                return true;
        };
        cuInit(0);
        auto dstatus = cuCtxCreate(&ctx1, 0, 0);    // checkpoint 1
        assert (dstatus == CUDA_SUCCESS);
        if (stopAt(1)) return;
        dstatus = cuCtxPushCurrent(ctx1);           // checkpoint 2
        assert (dstatus == CUDA_SUCCESS);
        if (stopAt(2)) return;
        auto rstatus = cudaSetDevice(0);            // checkpoint 3
        assert (rstatus == cudaSuccess);
        if (stopAt(3)) return;
        void* ptr1;
        void* ptr2;
        rstatus = cudaMalloc(&ptr1, 1024);          // checkpoint 4
        assert (rstatus == cudaSuccess);
        if (stopAt(4)) return;
        dstatus = cuCtxGetCurrent(&primary);        // checkpoint 5
        assert (dstatus == CUDA_SUCCESS);
        assert(primary != ctx1);
        if (stopAt(5)) return;
        dstatus = cuCtxPushCurrent(ctx1);           // checkpoint 6
        assert (dstatus == CUDA_SUCCESS);
        if (stopAt(6)) return;
        rstatus = cudaMalloc(&ptr2, 1024);          // checkpoint 7
        assert (rstatus == cudaSuccess);
        if (stopAt(7)) return;
        rstatus = cudaSetDevice(0);                 // checkpoint 8
        assert (rstatus == cudaSuccess);
        if (stopAt(8)) return;
}

// Runs checkpoint scenarios 1 through 8 in turn, calling cudaDeviceReset()
// before each one. Returns 1 if a reset fails.
int main(){
        for (int i = 1; i < 9; i++){
          // Stop on a failed reset: a scenario run after a failed reset would
          // not start from the state the test expects.
          cudaError_t rstatus = cudaDeviceReset();
          if (rstatus != cudaSuccess) {
            std::cerr << "cudaDeviceReset failed before checkpoint " << i << ": "
                      << cudaGetErrorString(rstatus) << '\n';
            return 1;
          }
          runtest(i);}
}
$ nvcc -o t1759 t1759.cu -lcuda -std=c++11
$ ./t1759
1:Next context on stack (0) is ctx1:0x11087e0
2:Next context on stack (0) is ctx1:0x1741160
2:Next context on stack (1) is ctx1:0x1741160
3:Next context on stack (0) is unknown:0x10dc520
3:Next context on stack (1) is ctx1:0x1c5aa70
4:Next context on stack (0) is unknown:0x10dc520
4:Next context on stack (1) is ctx1:0x23eaa00
5:Next context on stack (0) is ctx2:0x10dc520
5:Next context on stack (1) is ctx1:0x32caf30
6:Next context on stack (0) is ctx1:0x3a44ed0
6:Next context on stack (1) is ctx2:0x10dc520
6:Next context on stack (2) is ctx1:0x3a44ed0
7:Next context on stack (0) is ctx1:0x41cfd90
7:Next context on stack (1) is ctx2:0x10dc520
7:Next context on stack (2) is ctx1:0x41cfd90
8:Next context on stack (0) is ctx2:0x10dc520
8:Next context on stack (1) is ctx2:0x10dc520
8:Next context on stack (2) is ctx1:0x4959c70
$

Based on the above, as we proceed through each API call in your code:

1.

        auto dstatus = cuCtxCreate(&ctx1, 0, 0);    // checkpoint 1
1:Next context on stack (0) is ctx1:0x11087e0

The context creation also pushes the newly created context onto the stack, as described in the CUDA driver API documentation for cuCtxCreate.

2.

        dstatus = cuCtxPushCurrent(ctx1);           // checkpoint 2
2:Next context on stack (0) is ctx1:0x1741160
2:Next context on stack (1) is ctx1:0x1741160

No surprise, pushing the same context on the stack creates another stack entry for it.

3.

        auto rstatus = cudaSetDevice(0);            // checkpoint 3
3:Next context on stack (0) is unknown:0x10dc520
3:Next context on stack (1) is ctx1:0x1c5aa70

The cudaSetDevice() call has replaced the top of the stack with an "unknown" context. (Only unknown at this point because we have not retrieved the handle value of the "other" context).

4.

        rstatus = cudaMalloc(&ptr1, 1024);          // checkpoint 4
4:Next context on stack (0) is unknown:0x10dc520
4:Next context on stack (1) is ctx1:0x23eaa00

No difference in stack configuration due to this call.

5.

        dstatus = cuCtxGetCurrent(&primary);        // checkpoint 5
5:Next context on stack (0) is ctx2:0x10dc520
5:Next context on stack (1) is ctx1:0x32caf30

No difference in stack configuration due to this call, but we now know that the top of stack context is the current context (and we can surmise it is the primary context).

6.

        dstatus = cuCtxPushCurrent(ctx1);           // checkpoint 6
6:Next context on stack (0) is ctx1:0x3a44ed0
6:Next context on stack (1) is ctx2:0x10dc520
6:Next context on stack (2) is ctx1:0x3a44ed0

No real surprise here. We are pushing ctx1 on the stack, and so the stack has 3 entries, the first one being the driver API created context, and the next two entries being the same as the stack configuration from step 5, just moved down one stack location.

7.

        rstatus = cudaMalloc(&ptr2, 1024);          // checkpoint 7
7:Next context on stack (0) is ctx1:0x41cfd90
7:Next context on stack (1) is ctx2:0x10dc520
7:Next context on stack (2) is ctx1:0x41cfd90

Again, this call has no effect on stack configuration.

8.

        rstatus = cudaSetDevice(0);                 // checkpoint 8
8:Next context on stack (0) is ctx2:0x10dc520
8:Next context on stack (1) is ctx2:0x10dc520
8:Next context on stack (2) is ctx1:0x4959c70

Once again, we see that the behavior here is that the cudaSetDevice() call has replaced the top of stack context with the primary context.

The conclusion I have from your test code is that I see no inconsistency of behavior of the cudaSetDevice() call when intermixed with various runtime and driver API calls as you have in your code.

From my perspective, this sort of programming paradigm is insanity. I can't imagine why you would want to intermix driver API and runtime API code this way.