At every epoch of my training, I need to split my dataset into n batches of t consecutive samples. For example, if my data is [1,2,3,4,5,6,7,8,9,10], n = 2 and t = 3, then valid batches would be
[1-2-3, 4-5-6] and [7-8-9, 10-1-2]
[2-3-4, 8-9-10] and [5-6-7, 1-2-3]
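For concreteness, the first of the two example splits above can be built like this (just a small sketch of the toy case; the starting offset 0 and the name windows_per_batch are my own choices):

import numpy as np

# Values 1..10, windows of t = 3 consecutive samples that wrap around
# the end of the data, grouped into batches of 2 windows each.
data = np.arange(1, 11)
t, windows_per_batch = 3, 2
starts = np.arange(0, len(data), t)  # [0, 3, 6, 9]
windows = [data[np.mod(np.arange(s, s + t), len(data))] for s in starts]
batches = [windows[i:i + windows_per_batch]
           for i in range(0, len(windows), windows_per_batch)]
# batches -> [[1 2 3], [4 5 6]] and [[7 8 9], [10 1 2]]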
My old version is the following, but it samples every point in the data, meaning that I would parse the whole dataset t times per epoch.
import numpy as np
import torch

# Each dataset element is just a possible starting index; the DataLoader
# batches starting indices and the windows are built with np.arange below.
train_dataset = list(range(n))
train_sampler = None
if distributed:
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=bsize, shuffle=(train_sampler is None),
    pin_memory=True, sampler=train_sampler)

for epoch in range(epochs):
    if distributed:
        train_sampler.set_epoch(epoch)
    for starting_i in train_loader:
        # bsize windows of t consecutive indices, wrapping around the end
        batch = np.array([np.mod(np.arange(i, i + t), n) for i in starting_i])
I have now implemented my own sampling function that splits the data into random batches where each sampled starting index is exactly t away from its two closest neighbours. In the non-distributed scenario, I can do
for epoch in range(epochs):
    pad = np.random.randint(n)
    train_loader = np.mod(np.arange(pad, n + pad, t), n)
    np.random.shuffle(train_loader)
    train_loader = np.array_split(train_loader,
                                  np.ceil(len(train_loader) / bsize))
    for starting_i in train_loader:
        batch = np.array([np.mod(np.arange(i, i + t), n) for i in starting_i])
How do I make this version distributed? Do I need to make a custom torch.nn.parallel.DistributedDataParallel or torch.utils.data.DataLoader?
I have checked the DistributedSampler class and my guess is that I have to override the __iter__ method. Am I right?
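Something along these lines is what I have in mind, in case it helps make the question concrete (just a sketch; the class name and the padding logic are mine, and it is untested):

import math
import numpy as np
import torch
from torch.utils.data.distributed import DistributedSampler

class ConsecutiveWindowSampler(DistributedSampler):
    # Hypothetical sampler: yields window *starting* indices spaced exactly
    # t apart, shuffled identically on every rank, then striped across ranks.
    def __init__(self, dataset, t, num_replicas=None, rank=None, seed=0):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank,
                         shuffle=True, seed=seed)
        self.t = t

    def __iter__(self):
        n = len(self.dataset)
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)  # same stream on every rank
        pad = torch.randint(n, (1,), generator=g).item()
        starts = np.mod(np.arange(pad, n + pad, self.t), n)
        starts = starts[torch.randperm(len(starts), generator=g).numpy()]
        # Pad so every rank gets the same number of starting indices.
        per_rank = math.ceil(len(starts) / self.num_replicas)
        total = per_rank * self.num_replicas
        starts = np.concatenate([starts, starts[:total - len(starts)]])
        return iter(starts[self.rank:total:self.num_replicas].tolist())

    def __len__(self):
        return math.ceil(math.ceil(len(self.dataset) / self.t) / self.num_replicas)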
How does DistributedSampler split the dataset? Does it split it sequentially among num_replicas? Say num_replicas = 2. Would my dataset be split into [1,2,3,4,5] and [6,7,8,9,10] between the 2 workers? Or is it random, like [1,4,7,3,10] and [2,9,5,8,6]? The first case would be fine for me because it keeps the samples sequential, but the second would not.
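To see the partitioning for myself, I could run a quick throw-away check like this (single process, since num_replicas and rank are passed explicitly; shuffle=False shows the raw partitioning, while the default shuffle=True permutes the indices first with a per-epoch seed):

import torch
from torch.utils.data.distributed import DistributedSampler

data = list(range(1, 11))  # [1, ..., 10]
for rank in range(2):
    sampler = DistributedSampler(data, num_replicas=2, rank=rank, shuffle=False)
    print(rank, [data[i] for i in sampler])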
I ended up making my own Dataset where the data is [t, t + window, ..., t + n * window]. Every time it is called, it randomizes the starting indices of the windows. Then the sampler does the shuffling as usual. For reproducibility, it has a set_seed method similar to the set_epoch method of samplers. The following version randomizes the data outside the call, and it is much much faster.
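A minimal sketch of what that Dataset looks like (the names are mine, not the exact code I used, but it shows the idea: the windows are rebuilt once per epoch in set_seed instead of inside __getitem__):

import numpy as np
from torch.utils.data import Dataset

class WindowDataset(Dataset):
    # Hypothetical name; each item is a window of `window` consecutive
    # sample indices, wrapping around the end of the data.
    def __init__(self, n, window, seed=0):
        self.n = n            # number of samples in the underlying data
        self.window = window  # length of each consecutive window (t)
        self.set_seed(seed)

    def set_seed(self, seed):
        # Randomize the starting offset once per epoch, outside __getitem__,
        # so indexing stays cheap (mirrors set_epoch of the samplers).
        rng = np.random.default_rng(seed)
        pad = rng.integers(self.n)
        self.starts = np.mod(np.arange(pad, self.n + pad, self.window), self.n)

    def __len__(self):
        return len(self.starts)

    def __getitem__(self, i):
        s = self.starts[i]
        return np.mod(np.arange(s, s + self.window), self.n)

In the training loop I call dataset.set_seed(epoch) right next to train_sampler.set_epoch(epoch), so every rank rebuilds the same windows and the DistributedSampler only has to shuffle and partition whole windows.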