Sample by index from TFRecordDataset


I have data stored in TFRecord files, each file holding 500 consecutive observations. My current pipeline reads each record in a TFRecord file, parses it, and returns the unbatched tensor with shape (512, 512).

I now want to return pairs of observations together with the delta (the difference of their indices) between them. My approach is to embed this generator

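import itertools

import numpy as np

def gen(inputs: np.ndarray):
    # from_generator passes `args` to the generator as NumPy arrays
    for idx0, idx1 in itertools.combinations(range(len(inputs)), r=2):
        delta = idx1 - idx0
        yield inputs[idx1], inputs[idx0], np.array([delta], dtype=np.int64)  # shape (1,) to match the TensorSpec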

into my existing tf.data.Dataset pipeline, which reads and parses multiple TFRecord files in parallel:

dataset = tf.data.Dataset.from_tensor_slices(files).interleave(
    lambda file: tf.data.TFRecordDataset(file).map(parse_fn)
)

This pipeline yields (512, 512) patches. I can collect all 500 consecutive observations of a TFRecord file by appending .batch(batch_size=500) to the per-file dataset inside the interleave lambda (as in the full pipeline below), which yields tensors with shape (500, 512, 512).
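Whether .batch(500) groups the observations of a single file depends on where it is appended. The sketch below uses hypothetical stand-ins (tf.data.Dataset.range instead of real TFRecord files and parse_fn, and 5 observations per "file" instead of 500): batching inside the interleave lambda keeps each batch within one file, while batching after interleave mixes elements from several files, because interleave alternates between its inputs.

```python
import tensorflow as tf

# Hypothetical stand-in for three TFRecord files: "file" f holds the
# consecutive observations f*100 + 0 ... f*100 + 4.
file_ids = tf.data.Dataset.range(3)

def per_file(file_id):
    return tf.data.Dataset.range(5).map(lambda i: file_id * 100 + i)

# batch() inside the lambda: every batch comes from exactly one file.
inside = file_ids.interleave(lambda f: per_file(f).batch(5), cycle_length=3)

# batch() after interleave: interleave takes one element from each open
# file in turn, so a batch contains observations from different files.
outside = file_ids.interleave(per_file, cycle_length=3).batch(5)

inside_batches = [b.numpy().tolist() for b in inside]
outside_batches = [b.numpy().tolist() for b in outside]
print(inside_batches)
print(outside_batches)
```

The explicit cycle_length=3 only makes the element order predictable for this small example.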

I now want to use the generator to turn each such batch into pairwise (512, 512) patches. However, TensorFlow throws "TypeError: values must be a sequence". This error message is not very helpful to me; is there a way to get this approach to work? Any help is greatly appreciated!
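A likely cause, offered as a hedged sketch rather than a confirmed diagnosis: tf.data.Dataset.from_generator documents `args` as a tuple of tensors, and it converts that tuple with a helper that rejects anything that is not a sequence. A single tf.Tensor such as `batched_input` is not a sequence, which matches the "values must be a sequence" message. Wrapping it as `args=(batched_input,)` avoids that check. The example below uses a hypothetical toy batch of 4 observations of shape (2, 2) instead of (500, 512, 512), calls from_generator through tf.data.Dataset (the same static method the pipeline reaches via tf.data.TFRecordDataset), and also yields delta with shape (1,) so it matches the `tf.TensorSpec(shape=(1, ))` in the output signature.

```python
import itertools

import numpy as np
import tensorflow as tf

def gen(inputs):
    # from_generator hands each element of `args` to the generator as a NumPy array.
    for idx0, idx1 in itertools.combinations(range(len(inputs)), r=2):
        # Wrap delta in a 1-element array so it matches TensorSpec(shape=(1,)).
        yield inputs[idx1], inputs[idx0], np.array([idx1 - idx0], dtype=np.int64)

# Hypothetical toy batch: 4 observations of shape (2, 2).
batched = tf.reshape(tf.range(16, dtype=tf.float32), (4, 2, 2))

pairs = tf.data.Dataset.from_tensors(batched).interleave(
    lambda batched_input: tf.data.Dataset.from_generator(
        gen,
        args=(batched_input,),  # a tuple of tensors, not a bare tensor
        output_signature=(
            tf.TensorSpec(shape=(2, 2), dtype=tf.float32),
            tf.TensorSpec(shape=(2, 2), dtype=tf.float32),
            tf.TensorSpec(shape=(1,), dtype=tf.int64),
        ),
    ),
    cycle_length=1,
)

deltas = [int(d.numpy()[0]) for _, _, d in pairs]
print(deltas)
```

In the full pipeline the corresponding change would be `args=(batched_input,)` in the inner interleave; this sketch has not been run against the real TFRecord data.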

For the sake of completeness, here is my complete pipeline, which works flawlessly without the generator part:

with tf.device('/cpu:0'):
    tfr_dataset = tf.data.Dataset.from_tensor_slices(files).interleave(
        lambda tfr_file: tf.data.TFRecordDataset(
            tfr_file,
        ).map(  # translate tfrecord file to float32
            parse_fn,
            num_parallel_calls=tf.data.AUTOTUNE
        ).batch(  # collect 500 consecutive observations
            batch_size=500
        ).interleave(
            lambda batched_input: tf.data.TFRecordDataset.from_generator(
                gen,  # apply generator on current batched input
                args=batched_input,
                output_signature=(
                    tf.TensorSpec(shape=(512, 512), dtype=tf.float32),
                    tf.TensorSpec(shape=(512, 512), dtype=tf.float32),
                    tf.TensorSpec(shape=(1, ), dtype=tf.int64)
                )
            )
        )
    )

    tfr_dataset = tfr_dataset.batch(batch_size=batch_size)  # batch dataset
    tfr_dataset = tfr_dataset.prefetch(tf.data.AUTOTUNE)  # prefetch
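Since the title asks about sampling by index, one alternative (a different technique from the generator above, not the asker's method) is to build the index pairs with TensorFlow ops and gather one pair at a time inside flat_map, so no Python generator is involved. Gathering per pair matters for scale: a file of 500 observations gives 500 * 499 / 2 = 124,750 pairs, and materializing all of them as two float32 (512, 512) patches each would take about 260 GB. The helper name `pairs_from_batch` and the toy (4, 2, 2) batch are hypothetical; in the pipeline above, `.flat_map(pairs_from_batch)` would take the place of the inner `.interleave(...)` call.

```python
import tensorflow as tf

def pairs_from_batch(batch):
    """Turn one (n, ...) batch into a dataset of (later, earlier, delta) elements."""
    n = tf.shape(batch)[0]
    idx0, idx1 = tf.meshgrid(tf.range(n), tf.range(n), indexing="ij")
    # Keep pairs with idx0 < idx1; row-major masking gives the same order
    # as itertools.combinations(range(n), 2).
    upper = idx0 < idx1
    idx0 = tf.boolean_mask(idx0, upper)
    idx1 = tf.boolean_mask(idx1, upper)

    def take_pair(i0, i1):
        # Gather a single pair per element, so all pairs are never held in memory at once.
        delta = tf.reshape(tf.cast(i1 - i0, tf.int64), (1,))
        return tf.gather(batch, i1), tf.gather(batch, i0), delta

    return tf.data.Dataset.from_tensor_slices((idx0, idx1)).map(take_pair)

# Hypothetical toy batch: 4 observations of shape (2, 2).
batched = tf.reshape(tf.range(16, dtype=tf.float32), (4, 2, 2))
pairs = tf.data.Dataset.from_tensors(batched).flat_map(pairs_from_batch)

deltas = [int(d.numpy()[0]) for _, _, d in pairs]
print(deltas)
```

The trade-off is that the index pairs of a file are enumerated in a fixed order; if random sampling of pairs is wanted, a shuffle on the index dataset before the map would be one option.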
