Run PyTorch stacked model on Colab TPU

I am trying to run my model on a Colab multi-core TPU, but I don't know how to do it. I tried this tutorial notebook, but I got some errors that I couldn't fix. I think there may be a simpler way to do it.

About my model:

class BERTModel(nn.Module):
    def __init__(self,...):
        super().__init__()
        if ...:
            self.bert_model = XLMRobertaModel.from_pretrained(...)   # huggingface XLM-R
        elif ...:
            self.bert_model = others_model.from_pretrained(...)   # another huggingface model
        
        ... # some other model's parameters
        
    def forward(self,...):
        bert_input = ...
        output = self.bert_model(bert_input)
        
        ... # some function that processes the output
        
    def other_function(self,...):
        # just post-processes the output, e.g. concatenates the layers' embeddings, and returns ...
        
class MAINModel(nn.Module):
    def __init__(self,...):
        super().__init__()
        
        print('Using model 1')
        self.bert_model_1 = BERTModel(...)
        
        print('Using model 2')
        self.bert_model_2 = BERTModel(...)
        
        self.linear = nn.Linear(...)
        
    def forward(self,...):
        bert_input = ...
        bert_output = ...  # combine self.bert_model_1(bert_input) and self.bert_model_2(bert_input)
        linear_output = self.linear(bert_output)
   
        return linear_output

Can you please tell me how to run a model like mine on a Colab TPU? I used Colab Pro to make sure RAM is not a big problem. Thank you so much.

There is 1 answer

Zachary Cain

I would work off the examples here: https://github.com/pytorch/xla/tree/master/contrib/colab

Maybe start with a simpler model like this: https://github.com/pytorch/xla/blob/master/contrib/colab/mnist-training.ipynb

In the pseudocode you shared, there is no reference to the torch_xla library, which is required to use PyTorch on TPUs. I'd recommend starting with one of the working Colab notebooks in the directory I shared and then swapping out parts of the model for your own. To run a model written for GPUs in native PyTorch on TPUs, you only need to modify a few places (usually 3-4) in the overall training code. See here for a description of some of the changes. The other big change is to wrap the default dataloader with a ParallelLoader, as shown in the example MNIST Colab I shared.
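As a rough illustration of those changes, here is a minimal sketch rather than a tested recipe for your model. It assumes the torch_xla 1.x API used by the Colab notebooks at the time (xm.xla_device, pl.ParallelLoader, xm.optimizer_step, xmp.spawn); newer torch_xla releases have renamed or deprecated some of these calls. The nn.Linear model, the random dataset, and the names train_fn and launch are placeholders I made up, so you would swap in your MAINModel and real data.

```python
import importlib.util

# torch_xla is only installed on TPU runtimes, so detect it without importing it.
HAS_XLA = importlib.util.find_spec("torch_xla") is not None


def train_fn(index, epochs=1):
    """Runs on each TPU core; `index` is the process index passed by xmp.spawn."""
    # Imports are deferred so this file still loads on machines without torch_xla.
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    # Change 1: use the XLA device instead of "cuda".
    device = xm.xla_device()

    # Placeholder model (assumption); replace with MAINModel(...).
    model = nn.Linear(16, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.CrossEntropyLoss()

    # Placeholder data (assumption); replace with your own Dataset.
    dataset = TensorDataset(torch.randn(64, 16), torch.randint(0, 2, (64,)))

    # Change 2: shard the data so each core sees a different slice.
    sampler = DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True,
    )
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    loss = None
    for epoch in range(epochs):
        sampler.set_epoch(epoch)
        # Change 3: wrap the default dataloader with a ParallelLoader.
        device_loader = pl.ParallelLoader(loader, [device]).per_device_loader(device)
        model.train()
        for x, y in device_loader:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            # Change 4: step through xm so gradients are reduced across cores.
            xm.optimizer_step(optimizer)

    if loss is not None:
        xm.master_print("last batch loss:", loss.item())


def launch(nprocs=8):
    """Starts one training process per TPU core (8 on a Colab TPU v2)."""
    import torch_xla.distributed.xla_multiprocessing as xmp

    xmp.spawn(train_fn, args=(), nprocs=nprocs, start_method="fork")
```

On a Colab TPU runtime with torch_xla installed, calling launch() in a cell should start the multi-core run. For a first test, launch(nprocs=1) keeps everything on a single core, which is usually easier to debug.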

If you see a specific error in one of the Colabs, feel free to open an issue: https://github.com/pytorch/xla/issues