Distributed training on PyTorch and Spot checkpoints in SageMaker


I'm building a custom model in PyTorch and want to know how to implement checkpoint (snapshot) logic for distributed training.

If a model is trained on multiple spot instances and is implemented in a BYO (bring-your-own) PyTorch image, how does SageMaker know which snapshot to load for a failed job? E.g. there are 4 spot instances and they produce 4 snapshots. Say one instance is terminated - how does SageMaker know which snapshot to load?


1 answer

Answer from Gili Nachum:

Saving - If you're doing data parallelism, checkpoint only from the first process (rank 0), since every rank holds the same model state after each mini-batch.
Loading - SageMaker restores the latest checkpoint directory to every instance, so have each rank load from it and continue training from there (see the sketch below).
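To make this concrete, here is a minimal sketch of that pattern in a DDP training script. It assumes the script writes to /opt/ml/checkpoints, SageMaker's default local checkpoint path (synced to checkpoint_s3_uri and restored to every instance on restart); the file name latest.pt and the helper functions are illustrative choices, not part of any SageMaker API.

```python
import os

import torch
import torch.distributed as dist

# SageMaker's default local checkpoint path; it is synced to checkpoint_s3_uri
# and copied back to every instance when a spot-interrupted job resumes.
CHECKPOINT_DIR = "/opt/ml/checkpoints"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "latest.pt")  # arbitrary file name


def save_checkpoint(model, optimizer, epoch):
    """Save from rank 0 only -- with DDP every rank holds identical weights."""
    if dist.get_rank() == 0:
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
        torch.save(
            {
                "epoch": epoch,
                # model is assumed to be wrapped in DistributedDataParallel,
                # so unwrap it via .module before saving.
                "model_state": model.module.state_dict(),
                "optimizer_state": optimizer.state_dict(),
            },
            CHECKPOINT_PATH,
        )
    dist.barrier()  # keep all ranks in step around the save


def load_checkpoint(model, optimizer, device):
    """Every rank loads the same restored checkpoint and resumes from it."""
    start_epoch = 0
    if os.path.exists(CHECKPOINT_PATH):
        ckpt = torch.load(CHECKPOINT_PATH, map_location=device)
        model.module.load_state_dict(ckpt["model_state"])
        optimizer.load_state_dict(ckpt["optimizer_state"])
        start_epoch = ckpt["epoch"] + 1
    dist.barrier()
    return start_epoch
```

In the training loop you would call load_checkpoint once after wrapping the model in DDP (it returns 0 on a fresh start), then call save_checkpoint at the end of each epoch; because rank 0 writes and every rank reads the same restored directory, it does not matter which of the 4 instances was terminated.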