PyTorch Move Nested Modules to GPU?

I am new to PyTorch and have some custom nn.Modules that I would like to run on a GPU. Let's call them M_outer, M_inner, and M_sub. In general, the structure looks like:

class M_outer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.inner_1 = M_inner(**kwargs)
        self.inner_2 = M_inner(**kwargs)
        # ...
    
    def forward(self, input):
        result = self.inner_1(input)
        result = self.inner_2(result)
        # ...
        return result
    

class M_inner(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.sub_1 = M_sub(**kwargs)
        self.sub_2 = M_sub(**kwargs)
        # ...
    
    def forward(self, input):
        result = self.sub_1(input)
        result = self.sub_2(result)
        # ...
        return result

class M_sub(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.emb_1 = nn.Embedding(x, y)
        self.emb_2 = nn.Embedding(x, y)
        # ...
        self.norm  = nn.LayerNorm(y)
    
    def forward(self, input):
        emb = (self.emb_1(input) + self.emb_2(input))        
        # ...
        return self.norm(emb)

and I try to move my module to a GPU via:

model = M_outer(**params).to(device)

Yet I am still getting errors from the embedding layers saying that some tensors are on the CPU.

I have read the documentation. I have read useful posts on the PyTorch Discuss forum and related Stack Overflow posts.

I cannot register an nn.Embedding via nn.Parameter. What am I missing?

1 Answer

Answered by Emanuel Huber:

PyTorch moves all registered submodules to the specified device, so your example should work just fine. I altered it a little for reproducibility:

import torch
from torch import nn

class M_outer(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.fc = M_inner()
    def forward(self, input):
        return self.fc(input)
    

class M_inner(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.fc = M_sub()
    def forward(self, input):
        return self.fc(input)

class M_sub(nn.Module):
    def __init__(self, **kwargs):
        super().__init__()
        self.fc = nn.Linear(1, 1)
    def forward(self, input):
        return self.fc(input)

# fall back to CPU if CUDA isn't available
device = "cuda" if torch.cuda.is_available() else "cpu"
model = M_outer().to(device)
t = torch.randn(1).unsqueeze(0).to(device)
model(t)
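
You can sanity-check that the move worked by inspecting where the parameters ended up:

print(next(model.parameters()).device)  # prints cuda:0 when the model is on the GPU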

One thing to notice is that .to(device) only moves registered parameters, buffers, and submodules. PyTorch won't move class members that are plain tensors, so if you use a static tensor for a calculation inside your class during inference, you need to move it to the device manually, or register it as a buffer so that .to(device) handles it for you.
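
For example, here is a minimal sketch of the difference (the scale and offset attributes are made up for illustration):

class WithStaticTensors(nn.Module):
    def __init__(self):
        super().__init__()
        # registered buffer: moved by .to(device) and saved in the state_dict
        self.register_buffer("scale", torch.ones(1))
        # plain tensor attribute: .to(device) ignores it, so it stays on the CPU
        self.offset = torch.zeros(1)

    def forward(self, input):
        # self.scale already lives on the right device after model.to(device);
        # self.offset has to be moved by hand
        return input * self.scale + self.offset.to(input.device)

Registering the tensor as a buffer is usually the cleaner choice, since it also gets saved and loaded with the model's state_dict.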