Skip Dataset entries in TFRecordDataset.map()


How do I skip entries in a TFRecord file when generating a TFRecordDataset?

Given a TFRecord file and a tf.contrib.data.TFRecordDataset object, I create a new dataset by mapping a parser over the serialized protobuf entries. For example,

import tensorflow as tf

features = {'some_data': tf.FixedLenFeature([], tf.string)}

def parser(example_proto):
    e = tf.parse_single_example(example_proto, features)
    data = e['some_data']
    # ...do a bunch of stuff to data...
    return data

x = tf.contrib.data.TFRecordDataset(filename)
x = x.map(parser)
x = x.cache(cache_filename)
x = x.repeat()
x = x.batch(batch_size)

This lets me read in the data and do some preprocessing, then cache the results and batch it up for my model.

My question is, what if I want to skip one of the TFRecord entries (e.g., if the data is invalid/bad)? For example, in parser(), maybe I could return None, or some sort of tf.cond to indicate an invalid entry, or trip some assertion.


There is 1 answer

Answer by Yao Zhang (best answer)

(Summarizing the comment as an answer)

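The filter() method of Dataset drops entries according to a predicate: it keeps only the elements for which the predicate returns true. Have parser() return the data as usual, then call filter() after map() with a function that checks whether each parsed entry is valid.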
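As a rough illustration of the filter() approach, here is a minimal sketch. It assumes TensorFlow 2.x, where the question's names live under tf.data and tf.io rather than tf.contrib.data. The is_valid() rule (treating an empty string as a bad record), the write_demo_file() helper, and the demo file contents are all hypothetical and exist only to make the example self-contained.

```python
import os
import tempfile

import tensorflow as tf  # assumption: TensorFlow 2.x with eager execution

features = {'some_data': tf.io.FixedLenFeature([], tf.string)}


def parser(example_proto):
    # Same role as parser() in the question: decode one serialized Example.
    e = tf.io.parse_single_example(example_proto, features)
    return e['some_data']


def is_valid(data):
    # Hypothetical validity rule: an empty string counts as a bad record.
    # The predicate must return a scalar boolean tensor.
    return tf.strings.length(data) > 0


def write_demo_file(path, payloads):
    # Hypothetical helper that writes a small TFRecord file for the demo.
    with tf.io.TFRecordWriter(path) as writer:
        for payload in payloads:
            feature = tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[payload]))
            example = tf.train.Example(
                features=tf.train.Features(feature={'some_data': feature}))
            writer.write(example.SerializeToString())


filename = os.path.join(tempfile.mkdtemp(), 'demo.tfrecord')
write_demo_file(filename, [b'good-1', b'', b'good-2'])

x = tf.data.TFRecordDataset(filename)
x = x.map(parser)
x = x.filter(is_valid)  # entries for which is_valid() is False are skipped
x = x.batch(2)

# Collect the surviving entries to confirm the empty record was dropped.
kept = [value for batch in x for value in batch.numpy().tolist()]
print(kept)
```

In the pipeline from the question, filter() would probably sit between map() and cache(), so that invalid entries are neither cached nor repeated.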