I am using TensorFlow 2.
When using the `Model.fit()` method with a `tf.data.Dataset`, the `batch_size` argument is ignored. So to train my model on batches, I first have to turn my dataset of samples into a dataset of batches of samples by calling `tf.data.Dataset.batch(batch_size)`.
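A minimal sketch of that batching step, assuming TensorFlow 2.x with its bundled Keras. The toy arrays `x` and `y` and the one-layer model are purely illustrative. The batch size is fixed by `Dataset.batch`, so `fit` is called without a `batch_size` argument.

```python
import numpy as np
import tensorflow as tf

# Toy data: 10 samples with 3 features each (illustrative only).
x = np.arange(30, dtype="float32").reshape(10, 3)
y = np.arange(10, dtype="float32")

# A dataset of individual samples...
samples = tf.data.Dataset.from_tensor_slices((x, y))
# ...turned into a dataset of batches of 4 samples (the last batch has 2).
batches = samples.batch(4)

batch_sizes = [int(xb.shape[0]) for xb, _ in batches]
print(batch_sizes)

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# The batch size comes from the dataset itself; no batch_size is passed here.
history = model.fit(batches, epochs=1, verbose=0)
```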
However, after reading the documentation, I still don't clearly understand how the `.fit()` method shuffles my dataset at each epoch.
Since my dataset is a dataset of batches, will it shuffle the batches among themselves (so the batches remain unchanged)? Or will it shuffle all the samples and then regroup them into new batches (which is the desired behaviour)?
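The difference between the two behaviours in the question can be seen with `tf.data` alone. In this sketch (assuming TensorFlow 2.x; the eight integer "samples" and the fixed `seed` are illustrative), shuffling after `batch` only permutes whole batches, while shuffling before `batch` permutes individual samples and then regroups them.

```python
import tensorflow as tf

# Eight "samples" 0..7 (illustrative stand-in for real data).
ds = tf.data.Dataset.range(8)

# shuffle AFTER batch: the order of the batches may change,
# but each batch always contains the same samples.
batch_then_shuffle = ds.batch(4).shuffle(buffer_size=2, seed=0)

# shuffle BEFORE batch: individual samples are shuffled,
# then regrouped into new batches of 4.
shuffle_then_batch = ds.shuffle(buffer_size=8, seed=0).batch(4)

bts = [b.numpy().tolist() for b in batch_then_shuffle]
stb = [b.numpy().tolist() for b in shuffle_then_batch]
print("batch -> shuffle:", bts)
print("shuffle -> batch:", stb)
```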
Thanks a lot for your help.
The `shuffle` parameter has no effect on the `fit` function when using the `tf.data.Dataset` API.

If we read the documentation (emphasis is mine):
It's not super clear, but it hints that the `shuffle` argument will be ignored when using a `tf.data.Dataset`, as it behaves like a generator.

To be certain, let's dive into the code. If we look at the code of the `fit` method, we see that the data is handled by a special class, `DataHandler`. Looking at the code of this class, we see that it relies on adapter classes to handle different kinds of data. We are interested in the class that handles `tf.data.Dataset`, `DatasetAdapter`, and we can see that this class does not take the `shuffle` parameter into account.

If you want to shuffle your dataset, use the `shuffle` function from the `tf.data.Dataset` API, and call it before `batch` so that individual samples, rather than whole batches, are shuffled.
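A sketch of that pipeline, assuming TensorFlow 2.x. `shuffle` is placed before `batch`, and `buffer_size` covers the whole (small, illustrative) dataset. Because `reshuffle_each_iteration` defaults to `True`, each new pass over the dataset draws a new sample order, so batch compositions change from one pass to the next. Here two explicit iterations stand in for two epochs; `fit` also iterates the dataset again at each epoch when `steps_per_epoch` is not set.

```python
import tensorflow as tf

NUM_SAMPLES = 8  # illustrative dataset size
BATCH_SIZE = 4

ds = (
    tf.data.Dataset.range(NUM_SAMPLES)
    # buffer_size >= dataset size gives a uniform shuffle of all samples;
    # reshuffle_each_iteration=True (the default) draws a new order on every pass.
    .shuffle(buffer_size=NUM_SAMPLES, reshuffle_each_iteration=True)
    .batch(BATCH_SIZE)
)

# Each full pass over `ds` plays the role of one epoch.
epochs = [[b.numpy().tolist() for b in ds] for _ in range(2)]
for i, batches in enumerate(epochs):
    print(f"epoch {i}:", batches)
```

For datasets too large to buffer entirely in memory, a smaller `buffer_size` still works but gives only an approximate shuffle.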