Prefetching an iterator of 128-dim array to device


I'm having trouble using flax.jax_utils.prefetch_to_device with the simple function below. I'm loading the SIFT 1M dataset and converting each array to a jnp array.

I then want to prefetch the iterator of 128-dim arrays.

import tensorflow_datasets as tfds
import tensorflow as tf
import jax
import jax.numpy as jnp
import itertools
import jax.dlpack
import jax.tools.colab_tpu
import flax

def _sift1m_iter():
    def prepare_tf_data(xs):
        def _prepare(x):
            dl_arr = tf.experimental.dlpack.to_dlpack(x)
            jax_arr = jax.dlpack.from_dlpack(dl_arr)
            return jax_arr

        return jax.tree_util.tree_map(_prepare, xs['embedding'])

    ds = tfds.load('sift1m', split='database')
    it = map(prepare_tf_data, ds)
    #it = flax.jax_utils.prefetch_to_device(it, 2)  => this causes an error
    return it

However, when I uncomment the prefetch_to_device line and run this code, I get an error:

ValueError: len(shards) = 128 must equal len(devices) = 1.

I'm running this on a CPU-only device, but from the error it seems like the shape of the data I'm passing into prefetch_to_device is wrong.

There is 1 answer

Rômulo Silva

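prefetch_to_device expects every array the iterator yields to have a leading axis whose length equals the number of local devices, so the output of _prepare(x) should have the shape [num_devices, ...] (for example [num_devices, batch_size, 128] for batched data). Each array is split along that leading axis into one shard per device, so a bare 128-dim vector becomes 128 shards, which is where len(shards) = 128 comes from.

In your case, with a single device (you mention a CPU-only machine), each array should have the shape [1, 128].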
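As a rough sketch of that reshaping step (not necessarily how the linked example does it), the snippet below adds the leading device axis before handing the iterator to prefetch_to_device. The helper names shard_for_devices and prefetch_sharded are hypothetical, and the zero-filled arrays are a synthetic stand-in for the SIFT1M embeddings so that nothing needs to be downloaded.

```python
import numpy as np


def shard_for_devices(batch, num_devices):
    """Reshape [B, ...] into [num_devices, B // num_devices, ...].

    Hypothetical helper: the leading axis of the result is the device axis
    that prefetch_to_device splits on.
    """
    batch = np.asarray(batch)
    if batch.shape[0] % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    per_device = batch.shape[0] // num_devices
    return batch.reshape((num_devices, per_device) + batch.shape[1:])


def prefetch_sharded(batches, size=2):
    """Shard each batch across local devices, then prefetch (hypothetical wrapper)."""
    # Imported lazily so the reshape helper above can be tried without JAX/Flax.
    import jax
    from flax import jax_utils

    n = jax.local_device_count()
    sharded = (shard_for_devices(b, n) for b in batches)
    return jax_utils.prefetch_to_device(sharded, size)


# Synthetic stand-in for one SIFT1M embedding and for a batch of four of them.
vector = np.zeros((128,), dtype=np.float32)
batch = np.zeros((4, 128), dtype=np.float32)

# Single vector, single device: just add a leading axis of length 1.
print(vector[None, ...].shape)             # (1, 128)
# Batched data, single device: [4, 128] -> [1, 4, 128].
print(shard_for_devices(batch, 1).shape)   # (1, 4, 128)
```

With the question's code, one option would be to batch the dataset first (for example with ds.batch(...)) and pass the resulting arrays through prefetch_sharded; on a machine with one device, applying x[None, ...] inside _prepare should be enough to get the [1, 128] shape.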

Take a look at how it can be done here.