In a TensorFlow input pipeline, what is the use of take(1) in 'for feature_batch, label_batch in train_ds.take(1)'?


I have started learning TensorFlow to improve my machine learning skills. While trying to understand the input pipeline, I came across take(1). What is it used for in the following line?

for feature_batch, label_batch in train_ds.take(1)

There are 2 answers

1
Nicolas Gervais - Open to Work On

take(n) builds a new dataset containing at most the first n elements of the original one. I'm guessing that in this specific example, someone wanted to show what the data looked like and took only one element (one batch, since the dataset is batched). If you don't use take, iterating over the dataset will eventually fetch all elements:

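import tensorflow as tf

# The integers 0..8 grouped into batches of 3, so the dataset has 3 elements.
dataset = tf.data.Dataset.range(9).batch(3)

# Iterating over the whole dataset yields every batch.
for i in dataset:
    print(i)

Output:

tf.Tensor([0 1 2], shape=(3,), dtype=int64)
tf.Tensor([3 4 5], shape=(3,), dtype=int64)
tf.Tensor([6 7 8], shape=(3,), dtype=int64)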

Now, if you take one element only:

# take(1) keeps only the first element, which here is the first batch.
for i in dataset.take(1):
    print(i)

Output:

tf.Tensor([0 1 2], shape=(3,), dtype=int64)

From the documentation of tf.data.Dataset.take:

Creates a Dataset with at most count elements from this dataset.
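
The "at most" in that sentence matters: asking for more elements than the dataset holds is not an error, you simply get everything. A minimal sketch, reusing the same toy dataset as above and counting the elements by iterating (the variable names are only for illustration):

```python
import tensorflow as tf

# Same toy dataset as above: 3 elements, each a batch of 3 integers.
dataset = tf.data.Dataset.range(9).batch(3)

# Count how many elements (batches) each variant yields.
n_one = sum(1 for _ in dataset.take(1))    # only the first batch
n_many = sum(1 for _ in dataset.take(10))  # more than available: all 3 batches

print(n_one, n_many)
```

Note that take counts dataset elements, so on a batched dataset it counts batches, not individual examples.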

0
Sankarsan Mohanty On

for feature_batch, label_batch in train_ds.take(1)

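In the code above, take(1) selects the first batch of train_ds.

For example, if you have defined your batch size as 32, train_ds.take(1) is a dataset with a single element, and that element is one batch of up to 32 examples. The loop body therefore runs once, with feature_batch and label_batch each holding 32 examples.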
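
To check what take(1) yields on a batched dataset of (features, labels) pairs, here is a minimal sketch. The train_ds below is a made-up stand-in (100 random examples with 4 features each, batch size 32), not the dataset from the question:

```python
import tensorflow as tf

# Hypothetical stand-in for train_ds: 100 examples, 4 features each.
features = tf.random.uniform((100, 4))
labels = tf.zeros((100,), dtype=tf.int32)
train_ds = tf.data.Dataset.from_tensor_slices((features, labels)).batch(32)

first = train_ds.take(1)

# Number of elements (batches) in the new dataset, not number of examples.
num_batches = int(first.cardinality().numpy())

# The loop body runs once per element, so exactly once here.
for feature_batch, label_batch in first:
    feature_shape = tuple(feature_batch.shape)
    label_shape = tuple(label_batch.shape)
    print(num_batches, feature_shape, label_shape)
```

Under these assumptions the dataset returned by take(1) contains one batch, and the batch size shows up as the leading dimension of feature_batch and label_batch.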

If this answers your question, please mark it as correct.