I am trying to run my model on a multi-core Colab TPU, but I really don't know how to do it. I tried this tutorial notebook, but I got an error that I can't fix, and I think there may be a simpler way to do it.
About my model:
```python
import torch.nn as nn
from transformers import XLMRobertaModel


class BERTModel(nn.Module):
    def __init__(self, ...):
        super().__init__()  # required for every nn.Module subclass
        if ...:
            self.bert_model = XLMRobertaModel.from_pretrained(...)  # Hugging Face XLM-R
        elif ...:
            self.bert_model = others_model.from_pretrained(...)  # another Hugging Face model
        ...  # some other model parameters

    def forward(self, ...):
        bert_input = ...
        output = self.bert_model(bert_input)
        ...  # some processing of the output

    def other_function(self, ...):
        ...  # post-process the output, e.g. concatenate the layers' embeddings, and return the result


class MAINModel(nn.Module):
    def __init__(self, ...):
        super().__init__()
        print('Using model 1')
        self.bert_model_1 = BERTModel(...)
        print('Using model 2')
        self.bert_model_2 = BERTModel(...)
        self.linear = nn.Linear(...)

    def forward(self, ...):
        bert_input = ...
        bert_output = self.bert_model_1(bert_input)  # and/or self.bert_model_2; MAINModel has no self.bert_model
        linear_output = self.linear(bert_output)
        return linear_output
```
Can you please tell me how to run a model like mine on a Colab TPU? I am using Colab Pro, so RAM should not be a big problem. Thank you so much.
I would work off the examples here: https://github.com/pytorch/xla/tree/master/contrib/colab
Maybe start with a simpler model like this: https://github.com/pytorch/xla/blob/master/contrib/colab/mnist-training.ipynb
In the pseudocode you shared, there is no reference to the library that is required to use PyTorch on TPUs. I'd recommend starting with one of the working Colab notebooks in the directory I shared and then swapping out parts of the model for your own. If you have a model that runs on GPUs with native PyTorch, there are usually only a few places (typically 3-4) in the overall training code that you need to modify to run it on TPUs. See here for a description of some of the changes. The other big change is to wrap the default dataloader with a ParallelLoader, as shown in the example MNIST Colab I shared.
If you see a specific error in one of the Colabs, feel free to open an issue: https://github.com/pytorch/xla/issues