I am trying to run my model on a Colab multi-core TPU, but I don't know how to do it. I tried this tutorial notebook, but I got an error I can't fix, and I think there may be a simpler way to do it.
About my model:
import torch.nn as nn
from transformers import XLMRobertaModel

class BERTModel(nn.Module):
    def __init__(self, ...):
        super().__init__()
        if ...:
            self.bert_model = XLMRobertaModel.from_pretrained(...)  # huggingface XLM-R
        elif ...:
            self.bert_model = others_model.from_pretrained(...)  # some other huggingface model
        ...  # some other model parameters

    def forward(self, ...):
        bert_input = ...
        output = self.bert_model(bert_input)
        ...  # some function that processes the output

    def other_function(self, ...):
        ...  # just does some processing on the output, e.g. concatenates layers' embeddings, and returns

class MAINModel(nn.Module):
    def __init__(self, ...):
        super().__init__()
        print('Using model 1')
        self.bert_model_1 = BERTModel(...)
        print('Using model 2')
        self.bert_model_2 = BERTModel(...)
        self.linear = nn.Linear(...)

    def forward(self, ...):
        bert_input = ...
        bert_output_1 = self.bert_model_1(bert_input)
        bert_output_2 = self.bert_model_2(bert_input)
        bert_output = ...  # combine the two BERT outputs
        linear_output = self.linear(bert_output)
        return linear_output
Can you please tell me how to run a model like mine on a Colab TPU? I am using Colab Pro, so RAM should not be a big problem. Thank you so much.
I would work off the examples here: https://github.com/pytorch/xla/tree/master/contrib/colab
Maybe start with a simpler model like this: https://github.com/pytorch/xla/blob/master/contrib/colab/mnist-training.ipynb
In the pseudocode you shared, there is no reference to the torch_xla library, which is required to use PyTorch on TPUs. I'd recommend starting with one of the working Colab notebooks in the directory I shared and then swapping out parts of the model with your own. There are a few (usually 3-4) places in the overall training code that you need to modify to take a model that runs on GPUs with native PyTorch and run it on TPUs; see here for a description of some of the changes. The other big change is to wrap the default dataloader with a ParallelLoader, as shown in the example MNIST Colab I shared. A rough sketch of the typical changes is below. If you hit a specific error in one of the Colabs, feel free to open an issue: https://github.com/pytorch/xla/issues