How do I skip entries in a TFRecord file when generating a TFRecordDataset?
Given a TFRecord file and a tf.contrib.data.TFRecordDataset object, I create a new dataset by mapping over the protobuf definition. For example,
import tensorflow as tf
from tensorflow.contrib.data import TFRecordDataset

features = {'some_data': tf.FixedLenFeature([], tf.string)}

def parser(example_proto):
    # Parse one serialized Example into a dict of tensors.
    e = tf.parse_single_example(example_proto, features)
    data = e['some_data']
    # ...do a bunch of stuff to data...
    return data

x = TFRecordDataset(filename)
x = x.map(parser)
x = x.cache(cache_filename)
x = x.repeat()
x = x.batch(batch_size)
This lets me read in the data and do some preprocessing, then cache the results and batch it up for my model.
My question is, what if I want to skip one of the TFRecord entries (e.g., if the data is invalid/bad)? For example, in parser(), maybe I could return None, or use some sort of tf.cond to indicate an invalid entry, or trip some assertion.
(Summarizing the comment as an answer)
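The filter() method of Dataset can filter entries according to a predicate. The predicate is called on each element and must return a scalar boolean tensor; elements for which it evaluates to false are dropped from the resulting dataset.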
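As a rough illustration of that idea, here is a minimal sketch. It makes several assumptions that are not in the question: it uses the tf.data API with eager execution (TensorFlow 2.x) rather than tf.contrib.data, it replaces the TFRecord file with a small in-memory dataset, and it treats an empty string as the "bad" entry. The names records, is_valid and collect are hypothetical.

```python
import tensorflow as tf

# Toy stand-in for the parsed records. Treating an empty string as an
# invalid entry is an assumption made purely for illustration.
records = tf.data.Dataset.from_tensor_slices([b"good-1", b"", b"good-2", b""])

def is_valid(data):
    # The predicate must return a scalar tf.bool tensor.
    return tf.strings.length(data) > 0

# Drop invalid entries before batching so that batches contain only
# valid records.
dataset = records.filter(is_valid).batch(2)

def collect(ds):
    # Flatten the batches into a plain Python list (needs eager execution).
    return [item for batch in ds for item in batch.numpy().tolist()]
```

In the pipeline from the question, the equivalent call would presumably sit between x.map(parser) and x.cache(cache_filename), so that invalid entries are never written to the cache. The actual validity check depends on what parser() returns.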