How does the Model.fit() shuffle argument deal with batches when using a tf.data.Dataset?


I am using TensorFlow 2.

When using the Model.fit() method with a tf.data.Dataset, the argument 'batch_size' is ignored. Thus, to train my model on batches, I first have to turn my dataset of samples into a dataset of batches of samples by calling tf.data.Dataset.batch(batch_size).

After reading the documentation, I still don't clearly understand how the .fit() method will shuffle my dataset at each epoch.

Since my dataset is a dataset of batches, will it shuffle the batches among each other (so the content of each batch remains unchanged)? Or will it shuffle all the samples and then regroup them into new batches (which is the desired behaviour)?

Thanks a lot for your help.


1 answer

Answered by Lescurel

The shuffle parameter has no effect on the fit function when using the tf.data.Dataset API.

If we read the documentation of the shuffle argument:

shuffle: Boolean (whether to shuffle the training data before each epoch) or str (for 'batch'). This argument is ignored when x is a generator. 'batch' is a special option for dealing with the limitations of HDF5 data; it shuffles in batch-sized chunks. Has no effect when steps_per_epoch is not None.

It's not super clear, but this hints that the shuffle argument will be ignored when using a tf.data.Dataset, as it behaves like a generator.

To be certain, let's dive into the code. If we look at the code of the fit method, we see that the data is handled by a special class, DataHandler. Looking at the code of this class, we see that it relies on adapter classes to handle the different kinds of data. We are interested in the class that handles tf.data.Dataset, DatasetAdapter, and we can see that this class does not take the shuffle parameter into account:

  def __init__(self,
               x,
               y=None,
               sample_weights=None,
               steps=None,
               **kwargs):
    super(DatasetAdapter, self).__init__(x, y, **kwargs)
    # Note that the dataset instance is immutable, its fine to reuse the user
    # provided dataset.
    self._dataset = x

    # The user-provided steps.
    self._user_steps = steps

    self._validate_args(y, sample_weights, steps)

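If you want to shuffle your dataset, use the shuffle method of the tf.data.Dataset API. To get the behaviour you describe (samples shuffled and regrouped into new batches at each epoch), call shuffle before batch. Calling shuffle after batch only reorders the batches and leaves their content unchanged.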
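The difference between the two orderings can be checked on a toy dataset. This is only a minimal sketch: the dataset of 12 integers, the batch size of 4 and the helper epoch_batches are made up for illustration, and the shuffle buffer is set to the full dataset size, which is an assumption that only suits data small enough to fit in memory.

```python
import tensorflow as tf

# Hypothetical toy values, chosen only to make the output easy to read.
NUM_SAMPLES = 12
BATCH_SIZE = 4

# A dataset of individual samples: 0, 1, ..., 11.
samples = tf.data.Dataset.range(NUM_SAMPLES)

# shuffle BEFORE batch: samples are shuffled, then regrouped into new batches.
# reshuffle_each_iteration=True (the default behaviour) gives a new order
# every time the dataset is iterated, i.e. at every epoch.
shuffled_then_batched = samples.shuffle(
    buffer_size=NUM_SAMPLES, reshuffle_each_iteration=True
).batch(BATCH_SIZE)

# batch BEFORE shuffle: only the order of the batches changes,
# the content of each batch stays fixed.
batched_then_shuffled = samples.batch(BATCH_SIZE).shuffle(
    buffer_size=NUM_SAMPLES // BATCH_SIZE
)


def epoch_batches(dataset):
    """Iterate once over the dataset and return its batches as Python lists."""
    return [batch.numpy().tolist() for batch in dataset]


for epoch in range(2):
    print("epoch", epoch, "shuffle -> batch:", epoch_batches(shuffled_then_batched))
    print("epoch", epoch, "batch -> shuffle:", epoch_batches(batched_then_shuffled))
```

With shuffle before batch, the printed batches mix samples from the whole range and change from one epoch to the next. With batch before shuffle, you always see the same three batches, just in a different order. The first dataset can then be passed directly to Model.fit(), without the batch_size or shuffle arguments.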