Changing the batch size of a PyTorch DataLoader using a list of batch sizes


While training my neural network model, I used a PyTorch DataLoader to speed up training. However, instead of using a fixed batch size between parameter updates, I have a list of different batch sizes that I want the data loader to use.

Example

from torch.utils.data import DataLoader, TensorDataset

train_dataset = TensorDataset(x_train, y_train)  # x_train.shape: (8400, 4)
dataloader_train = DataLoader(train_dataset, batch_size=64)  # fixed batch size of 64

What I want is a data loader that takes a list of batch sizes, so that the batch size changes from one batch to the next instead of being fixed:

list_batch_size = [30, 60, 110, ..., 231] # with this list's sum being equal to x_train.shape[0] (8400) 

Best answer (by heemayl):

You can use a custom sampler (or batch sampler) for this.
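A (batch) sampler only has to produce one list of dataset indices per batch. The bookkeeping for variable batch sizes can be sketched as a plain Python generator; `variable_batches` is a hypothetical helper, and it assumes that sizes running past the end of the dataset are clipped and that any samples not covered by the list form one final batch:

```python
def variable_batches(dataset_len, batch_sizes):
    """Yield lists of consecutive indices whose lengths follow batch_sizes.

    Hypothetical helper: sizes that overshoot dataset_len are clipped, and
    indices not covered by batch_sizes are yielded as one final batch.
    """
    start = 0
    for size in batch_sizes:
        if start >= dataset_len:
            return
        end = min(start + size, dataset_len)
        yield list(range(start, end))
        start = end
    if start < dataset_len:
        # Leftover samples form one last batch.
        yield list(range(start, dataset_len))


print([len(b) for b in variable_batches(10, [3, 5])])  # [3, 5, 2]
```

Because it is a generator function, each call starts again from index 0, which is the behaviour a sampler needs when the DataLoader is iterated once per epoch.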

Here's a quick proof of concept for a sampler that takes the custom batch sizes as an argument and returns the indices of one batch at a time:

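from torch.utils.data import Sampler


class VariableBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes

    def __iter__(self):
        # Reset the position on every call: DataLoader calls iter() on the
        # sampler at the start of each epoch, so it must start over each time.
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = min(self.batch_sizes[0], self.dataset_len)
        return self

    def __next__(self):
        if self.start_idx >= self.dataset_len:
            raise StopIteration()

        # Consecutive indices of the current batch, as a plain list of ints.
        batch_indices = list(range(self.start_idx, self.end_idx))
        self.start_idx = self.end_idx
        self.batch_idx += 1

        if self.batch_idx < len(self.batch_sizes):
            # Never run past the end of the dataset.
            self.end_idx = min(self.end_idx + self.batch_sizes[self.batch_idx], self.dataset_len)
        else:
            # Batch sizes exhausted: the remaining samples form one last batch.
            self.end_idx = self.dataset_len

        return batch_indices

    def __len__(self):
        # Number of batches per epoch, so that len(data_loader) works.
        n_batches, covered = 0, 0
        for size in self.batch_sizes:
            if covered >= self.dataset_len:
                break
            covered += size
            n_batches += 1
        return n_batches + (1 if covered < self.dataset_len else 0)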

Instantiate the sampler and pass it as the sampler argument when creating the DataLoader, e.g.:

sampler = VariableBatchSampler(dataset_len=len(train_dataset), batch_sizes=[10, 20, 30, 40])
data_loader = DataLoader(train_dataset, sampler=sampler)

Note that each element of the data_loader iterable will then have an extra leading batch dimension, because DataLoader's default batch_size is 1 and it wraps each list of indices in a batch of one. You can remove that dimension with squeeze(dim=0), or, better, pass the sampler as the batch_sampler argument instead:

data_loader = DataLoader(train_dataset, batch_sampler=sampler)
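When the batch order is fixed, a custom class is not strictly required: DataLoader's batch_sampler argument is documented as accepting a Sampler or an iterable that yields lists of indices, so a precomputed list of index lists also works. A minimal sketch, assuming random stand-in data with the question's shapes and a hypothetical batch-size list in place of the elided one:

```python
from itertools import accumulate

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data with the question's shapes: 8400 samples, 4 features.
x_train = torch.randn(8400, 4)
y_train = torch.randn(8400)
train_dataset = TensorDataset(x_train, y_train)

# Hypothetical batch sizes; they must add up to len(train_dataset).
list_batch_size = [30, 60, 110, 8200]
assert sum(list_batch_size) == len(train_dataset)

# One list of consecutive indices per batch: [0..29], [30..89], [90..199], [200..8399].
ends = list(accumulate(list_batch_size))
starts = [0] + ends[:-1]
batches = [list(range(start, end)) for start, end in zip(starts, ends)]

# A list can be iterated again, so every epoch sees the same batches in the same order.
dataloader_train = DataLoader(train_dataset, batch_sampler=batches)

epoch_sizes = []
for epoch in range(2):
    sizes = [xb.shape[0] for xb, yb in dataloader_train]
    epoch_sizes.append(sizes)
    print(epoch, sizes)  # prints "0 [30, 60, 110, 8200]", then "1 [30, 60, 110, 8200]"
```

Passing batch_sampler is mutually exclusive with batch_size, shuffle, sampler and drop_last. If the batches should be reshuffled every epoch, a sampler class like the one above, whose iterator is rebuilt at the start of each epoch, is the more natural place to do that.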