When allocating a TPU under the TPU VM architecture, pod software versions such as `tpu-vm-tf-2.6.2-pod` are available. When I select a pod version as the software version and follow the instructions in Run JAX code on TPU Pod slices, `jax.device_count()` cannot find the TPU.
Is selecting a pod version sufficient to allocate a TPU Pod, or are there additional steps/requirements? How can I select which TPU VMs to run on in the pod?
If you are using JAX, use the JAX images `tpu-vm-base` and `tpu-vm-v4-base` instead of the TensorFlow images (e.g. `tpu-vm-tf-2.12.0-pod`).

Note: for JAX, use the same image, `--version tpu-vm-base`, for both single TPU VM devices (v2-8, v3-8) and TPU VM pod slices (e.g. v3-32, v3-64, v2-32).

For TPU v2 and v3, use `--version tpu-vm-base`, then install JAX on the pod slice; for details, see Run JAX code on TPU Pod slices. For TPU v4, use `--version tpu-vm-v4-base`; for details, see the v4 users guide.
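The image-selection rule above can be sketched as a small Python helper that builds the corresponding `gcloud` commands. This is a minimal sketch, not an official tool: the helper names (`pick_runtime_version`, `create_command`, `run_on_all_workers_command`), the TPU name `my-pod` and the zone are hypothetical placeholders. Only the `gcloud compute tpus tpu-vm create` / `ssh` subcommands and their `--zone`, `--accelerator-type`, `--version`, `--worker` and `--command` flags are real.

```python
import shlex


def pick_runtime_version(accelerator_type: str) -> str:
    """Map an accelerator type such as 'v3-32' to a JAX TPU VM image.

    Encodes the rule from the answer: v2/v3 (single device or pod slice)
    use 'tpu-vm-base', v4 uses 'tpu-vm-v4-base'. TensorFlow images such as
    'tpu-vm-tf-2.12.0-pod' are never returned.
    """
    generation = accelerator_type.split("-", 1)[0].lower()
    if generation in ("v2", "v3"):
        return "tpu-vm-base"
    if generation == "v4":
        return "tpu-vm-v4-base"
    # Other generations are outside what the answer covers.
    raise ValueError(f"no JAX base image rule for {accelerator_type!r}")


def create_command(name: str, zone: str, accelerator_type: str) -> list[str]:
    """Build the 'gcloud ... tpu-vm create' argv for a device or pod slice.

    The slice size comes from the accelerator type (e.g. 'v3-32'),
    not from the software version.
    """
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "create", name,
        f"--zone={zone}",
        f"--accelerator-type={accelerator_type}",
        f"--version={pick_runtime_version(accelerator_type)}",
    ]


def run_on_all_workers_command(name: str, zone: str, remote_cmd: str) -> list[str]:
    """Build the 'gcloud ... tpu-vm ssh' argv that runs remote_cmd on every host.

    '--worker=all' targets all TPU VM hosts of the slice.
    """
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "ssh", name,
        f"--zone={zone}", "--worker=all", f"--command={remote_cmd}",
    ]


if __name__ == "__main__":
    # Hypothetical name and zone, for illustration only.
    print(shlex.join(create_command("my-pod", "europe-west4-a", "v3-32")))
    print(shlex.join(run_on_all_workers_command(
        "my-pod", "europe-west4-a",
        'python3 -c "import jax; print(jax.device_count())"',
    )))
```

On a pod slice you do not pick individual TPU VMs. The accelerator type fixes the set of hosts, and a multi-host JAX program is expected to be started on every host, which is why the sketch uses `--worker=all`. If the slice was created with `tpu-vm-base` and JAX is installed on all workers, each host of a v3-32 slice should report the global device count (32) from `jax.device_count()`.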