Multi-node/host training with the sharding API


I want to do multi-node training with JAX. For context, I'm on TPUs (2x v3-8 nodes).

The docs suggest using jax.distributed.initialize() along with xmap/pmap. However, xmap/pmap are now discouraged; the official way is to use the sharding API with device_put calls and let XLA auto-parallelize the code across the local devices. This is what I'm using.
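
On a single node, what I'm doing looks roughly like this (the batch shape is made up, just to show the pattern):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over the 8 local devices of one v3-8
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# split the batch axis across the 8 devices
batch = jax.device_put(np.zeros((128, 224, 224, 3), dtype=np.float32),
                       NamedSharding(mesh, P("data")))

# jitted functions that take `batch` are then auto-partitioned by XLA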

But I'm still confused about how to use sharding in a multi-node setup. AIUI, we should be able to do a sort of 3D sharding like (2, 8, 1) for 2 TPU nodes with 8 local devices each, DDP-style. This would allow us to switch between n-way data parallelism and m-way model parallelism, as outlined here.
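
Concretely, I was imagining something like this (the axis names and the reshape are my own guess; I'm not even sure the device order lines up with the hosts):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 16 devices arranged as (nodes, local devices) = (2, 8)
devices = np.array(jax.devices()).reshape(2, 8)
mesh = Mesh(devices, axis_names=("data", "model"))

batch_sharding = NamedSharding(mesh, P("data"))          # n-way data parallelism across nodes
weight_sharding = NamedSharding(mesh, P(None, "model"))  # m-way model parallelism within a node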

That, however, doesn't seem like the correct way to accomplish this.

Can someone provide a minimal example demonstrating how exactly the sharding should be modified to work in a multi-host setting?


1 Answer

Ricardo Gellman answered:

For your 2x v3-8 TPU setup, forget 3D sharding. Instead, create a 1D device mesh over all 16 devices (2 nodes * 8 devices/node) and shard your training data across it for data parallelism. JAX's sharding API (and libraries like Flax) can help you with this. Consider model parallelism for very large models, but it's more complex. A rough sketch of the data-parallel setup (the data array is a placeholder):

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

jax.distributed.initialize()  # run once on every host, before touching the devices

data = <SINGLE_NUMPY_ARRAY>   # full global batch, assumed identical on every host

# 1D mesh over all 16 devices (2 hosts x 8 local devices each)
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data"))  # split the leading (batch) axis

# build a global jax.Array; each device receives its 1/16 slice of the batch
sharded_data = jax.make_array_from_callback(data.shape, sharding,
                                            lambda idx: data[idx])

# a jitted train step applied to sharded_data is auto-partitioned by XLA;
# keep the parameters replicated (e.g. NamedSharding(mesh, P())) and update them inside jit
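
To complete the picture, here is a rough sketch of a data-parallel training step that consumes the sharded batch. It continues from the snippet above; loss_fn, the parameter shapes, and the assumption that the batch has 128 features are purely illustrative, not anything your code has to follow:

import jax
import jax.numpy as jnp
from functools import partial

replicated = NamedSharding(mesh, P())  # keep the parameters replicated on every device

@partial(jax.jit, out_shardings=replicated)
def init_params(key):
    # hypothetical one-layer "model", assuming 128 input features
    return {"w": jax.random.normal(key, (128, 1))}

def loss_fn(params, batch):
    # toy loss, just to keep the sketch self-contained
    return jnp.mean((batch @ params["w"]) ** 2)

@jax.jit
def train_step(params, batch):
    # XLA partitions this over all 16 devices based on the input shardings;
    # the gradient is effectively all-reduced over the "data" axis
    grads = jax.grad(loss_fn)(params, batch)
    return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

params = init_params(jax.random.PRNGKey(0))
params = train_step(params, sharded_data)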