I get the following error when I try to crop a batch of images inside a tf.data.Dataset pipeline:

InvalidArgumentError: Input shape axis 0 must equal 4, got shape [5]
    [[{{node crop_to_bounding_box/unstack}}]] [Op:IteratorGetNext]

def crop(img_batch, label_batch):
    #cropped_image = img_batch
    cropped_image = tf.image.crop_to_bounding_box(img_batch, 0, 0, 100, 100)
    return cropped_image, label_batch


train_dataset_cropped = train_dataset.map(crop)

The error is raised when I iterate over the mapped dataset with the following for loop:

for img_batch, label_batch in train_dataset_cropped:
    print(type(img_batch), img_batch.shape, label_batch.shape)

Note that the pipeline works when I remove the tf.image.crop_to_bounding_box call from the crop function (i.e. when I use cropped_image = img_batch directly).

Do you know how to correctly crop a batch of images inside a tf.data.Dataset pipeline?

There is 1 answer

Antoine On

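Functions from tf.image can normally be called inside a function passed to tf.data.Dataset.map, so I don't think that is the problem in itself. The error says that crop_to_bounding_box expected a shape with 4 entries and got 5, i.e. it ended up working on a rank-5 tensor. As far as I can tell, one way this happens is when the static rank of img_batch is unknown inside map (for example after tf.py_function, or with from_generator when no shapes are declared): the function then treats the input as a single 3-D image and adds a batch dimension, which turns an already batched 4-D tensor into a 5-D one. A simple workaround that does not depend on the static shape is plain slicing: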

def crop(img_batch, label_batch):
    cropped_image = img_batch[:, :100, :100] # if your dataset is already batched
    # cropped_image = img_batch[:100, :100] # otherwise
    return cropped_image, label_batch
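
For reference, here is a minimal sketch that compares the slicing workaround with tf.image.crop_to_bounding_box on a batched dataset. The random images, the labels, and the function names crop_with_slicing and crop_with_tf_image are made up for illustration and stand in for the real train_dataset. It assumes batches laid out as [batch, height, width, channels]; the tf.ensure_shape call is one possible way to declare that rank when the pipeline has lost it, not necessarily what the original pipeline needs.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for train_dataset: 10 RGB images of 128x128, batched by 4.
images = np.random.rand(10, 128, 128, 3).astype("float32")
labels = np.arange(10, dtype="int64")
train_dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)


def crop_with_slicing(img_batch, label_batch):
    # Keep the top-left 100x100 region of every image in the batch.
    return img_batch[:, :100, :100, :], label_batch


def crop_with_tf_image(img_batch, label_batch):
    # Assumption: each element is a 4-D batch. Declaring the rank explicitly
    # lets crop_to_bounding_box treat the input as a batch rather than as a
    # single image of unknown rank.
    img_batch = tf.ensure_shape(img_batch, [None, None, None, None])
    cropped = tf.image.crop_to_bounding_box(img_batch, 0, 0, 100, 100)
    return cropped, label_batch


sliced = train_dataset.map(crop_with_slicing)
cropped = train_dataset.map(crop_with_tf_image)

shapes = []
for (a, _), (b, _) in zip(sliced, cropped):
    # Both approaches should select exactly the same pixels.
    assert np.array_equal(a.numpy(), b.numpy())
    shapes.append(a.numpy().shape)
print(shapes)
```

If the dataset is not batched yet, the same idea applies with a 3-D shape ([None, None, None]) and img[:100, :100, :].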