How to integrate tf.data.Dataset with Ray Tune for distributed training

Using tensorflow-cpu==2.9.3 and petastorm==0.12.1 on Python 3.7.

I've created tf.data.Dataset objects using petastorm for the train and validation datasets:

  • ds_train (a DatasetV1Adapter; I think this is the old version of tf.data.Dataset)
  • ds_valid (a DatasetV1Adapter)

First trial: following the Ray Tune FAQ on using large datasets (https://docs.ray.io/en/latest/tune/faq.html#how-can-i-use-large-datasets-in-tune), I passed the datasets in with tune.with_parameters:

from ray import tune

def train_model(config, ds_train, ds_valid):
    model = ...  # simple deep learning model
    history = model.fit(x=ds_train, validation_data=ds_valid)
    return history

tuner = tune.Tuner(
    tune.with_resources(
        tune.with_parameters(train_model, ds_train=ds_train, ds_valid=ds_valid),
        resources={"cpu": 1},
    )
)
tuner.fit()

This fails with tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot convert a Tensor of dtype variant to a NumPy array, presumably because tune.with_parameters has to serialize its arguments into the Ray object store, and a tf.data.Dataset cannot be serialized that way.
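The error itself can be reproduced outside of Ray: a tf.data.Dataset is represented by a variant-dtype tensor handle, and anything that tries to turn that handle into a NumPy array (as serialization machinery typically does) hits the same exception. A minimal sketch:

import numpy as np
import tensorflow as tf

ds = tf.data.Dataset.range(3)
variant = tf.data.experimental.to_variant(ds)  # the dataset's variant-dtype handle
np.array(variant)  # InvalidArgumentError: Cannot convert a Tensor of dtype variant to a NumPy array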

Second trial: creating the tf.data.Dataset inside the train_model function works, but it consumes more memory, since every Ray Tune worker has to build its own tf.data.Dataset, eventually leading to OOM errors.
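Roughly, this second trial looked like the sketch below, assuming petastorm's make_batch_reader / make_petastorm_dataset; the S3 URLs, reader settings, and model are placeholders, not the exact code:

from petastorm import make_batch_reader
from petastorm.tf_utils import make_petastorm_dataset

def train_model(config):
    # Each Ray Tune worker opens its own petastorm readers and builds the
    # datasets locally, so nothing needs to be serialized into the trial;
    # the cost is that every worker holds its own pipeline in memory.
    with make_batch_reader("s3://bucket/train") as train_reader, \
         make_batch_reader("s3://bucket/valid") as valid_reader:
        ds_train = make_petastorm_dataset(train_reader)
        ds_valid = make_petastorm_dataset(valid_reader)
        model = ...  # simple deep learning model
        model.fit(x=ds_train, validation_data=ds_valid)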

Third trial: I saved ds_train and ds_valid with tf.data.experimental.save(ds_train, path), and then inside train_model let each Ray Tune worker simply load the dataset via tf.data.experimental.load(path).
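A minimal sketch of that approach, assuming paths on a disk visible to every worker (the paths and model are placeholders):

import tensorflow as tf

# Done once, up front: materialize the datasets to local disk.
tf.data.experimental.save(ds_train, "/data/ds_train")
tf.data.experimental.save(ds_valid, "/data/ds_valid")

def train_model(config):
    # Each Ray Tune worker reloads the snapshot instead of rebuilding the
    # pipeline, trading memory pressure for local disk usage.
    ds_train = tf.data.experimental.load("/data/ds_train")
    ds_valid = tf.data.experimental.load("/data/ds_valid")
    model = ...  # simple deep learning model
    model.fit(x=ds_train, validation_data=ds_valid)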

This removes the benefit of petastorm, which streams the source data for the tf.data.Dataset from AWS S3; every time the dataset grows, the local disk has to grow with it.

What are some best practices for doing distributed training with Ray Tune on a tf.data.Dataset?
