I have written a custom Keras callback to check the augmented data from a generator. (See this answer for the full code.) However, when I tried to use the same callback for a tf.data.Dataset, it gave me an error:
  File "/path/to/tensorflow_image_callback.py", line 16, in on_batch_end
    imgs = self.train[batch][images_or_labels]
TypeError: 'PrefetchDataset' object is not subscriptable
Do Keras callbacks in general only work with generators, or is it something about the way I've written mine? Is there a way to modify either my callback or the dataset to make it work?
I think there are three pieces to this puzzle. I'm open to changes to any and all of them. Firstly, the init function in the custom callback class:
class TensorBoardImage(tf.keras.callbacks.Callback):
    def __init__(self, logdir, train, validation=None):
        super(TensorBoardImage, self).__init__()
        self.logdir = logdir
        self.file_writer = tf.summary.create_file_writer(logdir)
        self.train = train
        self.validation = validation
Secondly, the on_batch_end function within that same class:
    def on_batch_end(self, batch, logs):
        images_or_labels = 0  # 0 = images, 1 = labels
        imgs = self.train[batch][images_or_labels]
Thirdly, instantiating the callback:
import tensorflow_image_callback
tensorboard_image_callback = tensorflow_image_callback.TensorBoardImage(logdir=tensorboard_log_dir, train=train_dataset, validation=valid_dataset)
model.fit(train_dataset,
          epochs=n_epochs,
          validation_data=valid_dataset,
          callbacks=[
              tensorboard_callback,
              tensorboard_image_callback
          ])
Some related threads which haven't led me to an answer yet:
Accessing validation data within a custom callback
Create keras callback to save model predictions and targets for each batch during training
What ended up working for me was the following, using tfds: first I changed the __init__ function, then on_batch_end. I didn't need to make any changes to the instantiation.
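As a rough illustration of the kind of change involved (a sketch, not the exact code from this answer): a tf.data.Dataset can be iterated but not indexed, so the callback has to hold an iterator and pull the next batch in on_batch_end instead of writing self.train[batch]. The sketch below keeps that logic framework-independent so it can be run as-is. The names BatchFetcher, ImageLoggingHook, log_images and every_n are hypothetical, and it assumes each dataset element is an (images, labels) pair.

```python
class BatchFetcher:
    """Hand out batches from a dataset that supports iter() but not dataset[i].

    Hypothetical helper: `dataset` may be any re-iterable object, for example
    a tf.data.Dataset, a list, or the result of tfds.as_numpy(dataset).
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self._iterator = iter(dataset)

    def next_batch(self):
        try:
            return next(self._iterator)
        except StopIteration:
            # The pass over the data is finished: start a fresh one.
            self._iterator = iter(self.dataset)
            return next(self._iterator)


class ImageLoggingHook:
    """Callback-shaped object: call on_batch_end(batch) after every batch.

    `log_images(images, step=...)` is an assumed user-supplied function that
    writes the images somewhere (for example to TensorBoard).
    """

    def __init__(self, train, log_images, every_n=1):
        self.fetcher = BatchFetcher(train)
        self.log_images = log_images
        self.every_n = every_n

    def on_batch_end(self, batch, logs=None):
        # Always advance the iterator so it stays in step with training,
        # even for batches that are not logged.
        images, _labels = self.fetcher.next_batch()
        if batch % self.every_n == 0:
            self.log_images(images, step=batch)
```

In a real Keras callback, the same two methods would live on a subclass of tf.keras.callbacks.Callback, and log_images would call tf.summary.image(...) inside a "with self.file_writer.as_default():" block. Note that the callback iterates the dataset separately from model.fit, so with shuffling or random augmentation the logged images are examples of what the pipeline produces, not necessarily the exact tensors the model trained on.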
I recommend using it only for debugging; otherwise it saves every nth image in your dataset to TensorBoard every epoch, which can end up using a lot of disk space.