I get the following error when I try to crop a batch of images inside a tf.data.Dataset pipeline:

    InvalidArgumentError: Input shape axis 0 must equal 4, got shape [5] [[{{node crop_to_bounding_box/unstack}}]] [Op:IteratorGetNext]

    import tensorflow as tf

    def crop(img_batch, label_batch):
        #cropped_image = img_batch
        cropped_image = tf.image.crop_to_bounding_box(img_batch, 0, 0, 100, 100)
        return cropped_image, label_batch

    train_dataset_cropped = train_dataset.map(crop)

When I run the following for loop, I get the error mentioned above:

    for img_batch, label_batch in train_dataset_cropped:
        print(type(img_batch), img_batch.shape, label_batch.shape)

Note that the pipeline works without the tf.image.crop_to_bounding_box call inside the crop function (directly using cropped_image = img_batch).
Do you know how to correctly crop a batch of images inside a tf.data.Dataset pipeline?
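
One thing that may be worth checking here is the static shape that tf.data has recorded for the batches, since the error mentions a shape vector of length 5 where 4 was expected. The sketch below is not the original pipeline: it assumes synthetic data (10 RGB images of 128x128, batched by 5), and the names `images`, `labels`, and `batch_shapes` are made up for illustration. It prints `element_spec` and declares the rank with `tf.ensure_shape` before cropping. If the rank shows as unknown in the real pipeline, that could be related to the error, but that is an assumption, not something confirmed by the question.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for the real data (assumption): 10 RGB images of 128x128.
images = np.zeros((10, 128, 128, 3), dtype=np.float32)
labels = np.arange(10, dtype=np.int64)
train_dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(5)

# element_spec reports the static shape and dtype recorded for each element.
print(train_dataset.element_spec)

def crop(img_batch, label_batch):
    # Declare a 4-D (batch, height, width, channels) shape before cropping.
    # This does not change the data when the shape is already known.
    img_batch = tf.ensure_shape(img_batch, [None, None, None, 3])
    cropped_image = tf.image.crop_to_bounding_box(img_batch, 0, 0, 100, 100)
    return cropped_image, label_batch

train_dataset_cropped = train_dataset.map(crop)

# Collect the shape of every cropped batch.
batch_shapes = [img.shape.as_list() for img, _ in train_dataset_cropped]
print(batch_shapes)
```

In this synthetic setup every batch already has a known 4-D shape, so the `tf.ensure_shape` call only matters for pipelines where that information has been lost.
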
I didn't find any documentation for this, but I think you can't call methods from tf.image in a function that will be used within a tf.data.Dataset.map. A simple workaround for your problem is then to do: