I've been using the torchdata library (v0.6.0) to construct datapipes for my machine learning model, but I can't seem to figure out how torchdata expects its users to make a train/test split.
Supposing I have a datapipe dp, my first attempt was to use the Sampler datapipe together with a torch.utils.data.SubsetRandomSampler (which is what this part of the documentation led me to expect), but it doesn't work the way I'd thought:
>>> from torchdata.datapipes.iter import Sampler
>>> from torchdata.datapipes.map import SequenceWrapper
>>> from torch.utils.data import SubsetRandomSampler
>>> dp = SequenceWrapper(range(5))
>>> Sampler(dp, SubsetRandomSampler([0, 1, 2]))
Traceback (most recent call last):
...
TypeError: 'SubsetRandomSampler' object is not callable
Maybe torchdata has its own samplers that I'm not familiar with.
The only other way I can think of doing this would be to use a Demultiplexer, but this feels unclean to me, because we have to enumerate then "de-enumerate":
>>> train_len = len(dp) * 0.8
>>> dp1, dp2 = dp.enumerate().demux(num_instances=2, classifier_fn=lambda x: x[0] >= train_len)
>>> dp1, dp2 = (d.map(lambda x: x[1]) for d in (dp1, dp2))
Is there an "intended" way of doing this with torchdata which I'm missing?
PyTorch's tutorial on using DataPipes answers the question:
If you want to use the built-in random_split() method of an Iterable-style DataPipe:
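Something along these lines (a minimal sketch using the random_split functional of IterDataPipe; the "train"/"valid" keys and the seed are just illustrative):
>>> from torchdata.datapipes.iter import IterableWrapper
>>> dp = IterableWrapper(range(5))
>>> train_dp, valid_dp = dp.random_split(
...     total_length=5, weights={"train": 0.8, "valid": 0.2}, seed=0
... )
>>> sorted(list(train_dp) + list(valid_dp))  # every element lands in exactly one split
[0, 1, 2, 3, 4]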
Edit: You can directly access the DataPipe from within the split dataset (this works with both IterDataPipe and MapDataPipe):
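For instance, assuming the split was produced with torch.utils.data.random_split() (which returns Subset objects), the wrapped pipe is reachable through the Subset's dataset attribute:
>>> from torch.utils.data import random_split
>>> from torchdata.datapipes.map import SequenceWrapper
>>> dp = SequenceWrapper(range(5))
>>> train, valid = random_split(dp, [4, 1])
>>> train.dataset is dp
True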
If you want the output of the random_split() function to be a MapDataPipe, you can always wrap the outputs in SequenceWrapper():
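Continuing the previous snippet, something like this (the variable names are just illustrative):
>>> from torchdata.datapipes.map import SequenceWrapper
>>> train_map_dp = SequenceWrapper(train)
>>> valid_map_dp = SequenceWrapper(valid)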
And the same idea with IterDataPipe:
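Presumably via IterableWrapper (again just a sketch, continuing from above):
>>> from torchdata.datapipes.iter import IterableWrapper
>>> train_iter_dp = IterableWrapper(train)
>>> valid_iter_dp = IterableWrapper(valid)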