I am using TensorFlow 2.
When using the Model.fit() method with a tf.data.Dataset, the argument 'batch_size' is ignored. Thus, to train my model on batches, I first have to change my dataset of samples into a dataset of batches of samples by calling tf.data.Dataset.batch(batch_size).
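For concreteness, here is a minimal sketch of that batching step. The toy tensors and the value of batch_size are assumptions chosen only for illustration; they are not taken from a real model.

```python
import tensorflow as tf

# Toy data (hypothetical): 10 samples with 3 features each, one label per sample.
features = tf.random.normal((10, 3))
labels = tf.range(10)

batch_size = 4  # illustrative value

# A dataset of individual samples: each element is one (features, label) pair.
samples = tf.data.Dataset.from_tensor_slices((features, labels))

# A dataset of batches: each element now holds up to batch_size samples.
batches = samples.batch(batch_size)

# Collect the feature shape of every batch; the last batch is smaller
# because 10 is not a multiple of 4.
batch_shapes = [x.shape.as_list() for x, _ in batches]
print(batch_shapes)  # [[4, 3], [4, 3], [2, 3]]
```

The batched dataset is what would then be passed to Model.fit(), without a batch_size argument.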
Then, after reading the documentation, I don't clearly understand how the .fit() method will shuffle my dataset at each epoch.
Since my dataset is a dataset of batches, will it shuffle the batches among each other (the batches themselves remain unchanged)? Or will it shuffle all the samples and then regroup them into new batches (which is the desired behaviour)?
Thanks a lot for your help.
The shuffle parameter has no effect on the fit function when using the tf.data.Dataset API.
If we read the documentation, it's not super clear, but there is a hint that the shuffle argument will be ignored when using a tf.data.Dataset, as it behaves like a generator.
To be certain, let's dive into the code. If you look at the code of the fit method, you will see that the data is handled by a special class, DataHandler. Looking at the code of this class, we see that it is an adapter class that handles different kinds of data. We are interested in the class that handles tf.data.Dataset, DatasetAdapter, and we can see that this class does not take the shuffle parameter into account.
If you want to shuffle your dataset, use the shuffle function from the tf.data.Dataset API.
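As a minimal sketch of that advice (the toy range dataset, batch_size and buffer sizes are assumptions for illustration), shuffling before batching gives new batch contents on every pass over the data, whereas shuffling after batching only reorders fixed batches:

```python
import tensorflow as tf

# Toy dataset (hypothetical): the integers 0..11 stand in for 12 samples.
samples = tf.data.Dataset.range(12)
batch_size = 4

# Shuffle the samples first, then batch. With reshuffle_each_iteration=True
# (the default), each new iteration over the dataset draws a fresh order,
# so the samples are regrouped into different batches.
shuffled_then_batched = (
    samples
    .shuffle(buffer_size=12, reshuffle_each_iteration=True)
    .batch(batch_size)
)

# Batch first, then shuffle: only the order of the batches changes;
# the content of each batch stays the same.
batched_then_shuffled = samples.batch(batch_size).shuffle(buffer_size=3)

# Each pass over the dataset plays the role of one epoch.
for epoch in range(2):
    print(epoch, [b.numpy().tolist() for b in shuffled_then_batched])

# Sorting the batches shows that their contents were never mixed.
fixed_batches = sorted(b.numpy().tolist() for b in batched_then_shuffled)
print(fixed_batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
```

For a uniform shuffle, buffer_size should be at least the number of samples; a smaller buffer only shuffles locally. Passing shuffled_then_batched to fit should then give reshuffled batches at each epoch, assuming fit re-iterates the dataset once per epoch (that is, steps_per_epoch is left unset).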