Question: Is there a clean and straightforward way to create a custom dataset in TensorFlow by subclassing tf.data.Dataset, similar to the functionality available in PyTorch?
Details: I'm currently working on a project that involves training deep learning models using TensorFlow. In PyTorch, I found it convenient to create custom datasets by subclassing torch.utils.data.Dataset, which allowed me to encapsulate the data loading and preprocessing logic.
from torch.utils.data import Dataset

class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        ...

    def __len__(self):
        ...

    def __getitem__(self, idx):
        ...
However, in TensorFlow, I'm struggling to find a similar mechanism for creating custom datasets. I've been using the tf.data.Dataset API for handling data pipelines, but I haven't come across a way to subclass it and define my own custom dataset.
Is there a recommended approach in TensorFlow for achieving this? Ideally, I would like to have the flexibility to implement custom data loading, preprocessing, and augmentation logic within the dataset subclass, as it provides a clean and modular structure.
Any guidance or examples on how to create a custom dataset by subclassing tf.data.Dataset would be greatly appreciated. Thank you!
P.S. A similar question has already been asked in "Is there a proper way to subclass Tensorflow's Dataset?", with no good answer.
Answer: I don't know if this is exactly what you want, but if you want to create a custom dataset (similar to a DataModule in PyTorch Lightning), the simplest way is to subclass tfds.dataset_builders.TfDataBuilder and override the __init__ method with your custom data loading, preprocessing, and augmentation logic. See https://www.tensorflow.org/datasets/format_specific_dataset_builders#defining_a_new_dataset_builder_class. You can also add metadata about your dataset, pass in a batch_size, etc., and use built-in data loading functions such as tf.data.Dataset.from_tensor_slices.
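For example, a minimal sketch following the linked docs (the dataset name, the toy "number" feature, and the in-memory tensors are placeholders for your own loading and preprocessing logic):

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

class MyDatasetBuilder(tfds.dataset_builders.TfDataBuilder):
    def __init__(self):
        # Any custom loading/preprocessing/augmentation can happen here,
        # as long as each split ends up as a tf.data.Dataset.
        train_ds = tf.data.Dataset.from_tensor_slices({"number": [1, 2, 3]})
        test_ds = tf.data.Dataset.from_tensor_slices({"number": [4, 5]})
        super().__init__(
            name="my_dataset",  # placeholder name
            version="1.0.0",
            split_datasets={"train": train_ds, "test": test_ds},
            features=tfds.features.FeaturesDict({
                "number": tfds.features.Scalar(dtype=np.int64),
            }),
            description="Toy dataset built from in-memory tensors.",
        )

# Writes the dataset in TFDS format; afterwards you can read it back per split.
builder = MyDatasetBuilder()
builder.download_and_prepare()
train_ds = builder.as_dataset(split="train")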
If you want a generator similar to torch.utils.data.Dataset, you can subclass tf.keras.utils.Sequence (https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence), which has a similar structure (the only major difference I can think of is that __getitem__ should return a whole batch, not just a single element as in PyTorch). You can then pass that object to model.fit().
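A rough sketch of what that could look like for the face-landmarks case (the image size, normalization, and attribute names are assumptions, not part of any official API):

import math
import numpy as np
import tensorflow as tf

class FaceLandmarksSequence(tf.keras.utils.Sequence):
    """Hypothetical batch loader; adapt the preprocessing to your data."""

    def __init__(self, image_paths, landmarks, batch_size=32):
        super().__init__()
        self.image_paths = image_paths  # list of image file paths
        self.landmarks = landmarks      # array of targets, one row per image
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch.
        return math.ceil(len(self.image_paths) / self.batch_size)

    def __getitem__(self, idx):
        # Unlike PyTorch's __getitem__, this returns a whole batch.
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        images = np.stack([
            tf.keras.utils.img_to_array(
                tf.keras.utils.load_img(p, target_size=(224, 224))) / 255.0
            for p in self.image_paths[lo:hi]
        ])
        return images, self.landmarks[lo:hi]

# model.fit(FaceLandmarksSequence(paths, targets), epochs=10)  # with a compiled Keras model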
So you can combine those two approaches: define data loading, preprocessing, splits, etc. in the __init__() of the TfDataBuilder subclass, and then use a Sequence class to define the logic for loading a single batch.
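Purely as an illustration of that combination (reusing MyDatasetBuilder and the imports from above; the class name is made up, and materializing the whole split in memory is only reasonable for small datasets):

class BuilderBackedSequence(tf.keras.utils.Sequence):
    """Hypothetical glue: splits come from the TFDS builder, batches from the Sequence."""

    def __init__(self, split="train", batch_size=32):
        super().__init__()
        builder = MyDatasetBuilder()
        builder.download_and_prepare()
        self.examples = list(builder.as_dataset(split=split).as_numpy_iterator())
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(len(self.examples) / self.batch_size)

    def __getitem__(self, idx):
        batch = self.examples[idx * self.batch_size:(idx + 1) * self.batch_size]
        return np.stack([ex["number"] for ex in batch])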