torchvision.transforms.Normalize() slows down learning when added to torchvision.transforms.Compose()


When I use

import torchvision

train_transforms = torchvision.transforms.Compose([
  torchvision.transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
  torchvision.transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
])

to load the MNIST dataset, training slows down, even with mean = 0 and std = 1.

Best answer, by RafazZ:

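The transformations are performed on the CPU, once per sample every time that sample is loaded, so their cost does not depend on the mean/std values: even mean = 0 and std = 1 costs the same (and never set std to 0, since Normalize divides by it). To speed up the transforms you have two options: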

  1. If your pipeline has no random data augmentation, transform the data once and save it as normalized tensors (for example with torch.save or pickle), then load those tensors instead of re-transforming every sample on each epoch.
  2. You can also use torch.utils.data.DataLoader with suitable arguments: for example, num_workers sets how many worker processes load and transform the data in parallel, and pin_memory=True copies batches into page-locked memory, which speeds up host-to-GPU transfers when you are using CUDA.
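The two options can be combined in a short sketch. It is not the answerer's code: it assumes torch is installed, uses random uint8 tensors as a stand-in for the raw MNIST images so it runs without a download, and the helper names precompute_normalized and make_loader and the file name mnist_normalized.pt are hypothetical. precompute_normalized reproduces ToTensor() followed by Normalize((0.1307,), (0.3081,)) for the whole dataset at once: scale to [0, 1], subtract the mean, divide by the std.

```python
import os
import tempfile

import torch
from torch.utils.data import DataLoader, TensorDataset

MEAN, STD = 0.1307, 0.3081  # the MNIST statistics used in the question


def precompute_normalized(raw_images: torch.Tensor) -> torch.Tensor:
    """Do the work of ToTensor() + Normalize() once for the whole dataset.

    raw_images: uint8 tensor of shape (N, H, W), an assumed input layout.
    Returns a float32 tensor of shape (N, 1, H, W).
    """
    x = raw_images.float().div(255.0)  # ToTensor(): uint8 -> float in [0, 1]
    x = (x - MEAN) / STD  # Normalize(): (x - mean) / std
    return x.unsqueeze(1)  # add the channel dimension ToTensor() adds


def make_loader(images: torch.Tensor, labels: torch.Tensor, num_workers: int = 2) -> DataLoader:
    """Serve precomputed tensors; no per-sample transform runs during training."""
    return DataLoader(
        TensorDataset(images, labels),
        batch_size=64,
        shuffle=True,
        num_workers=num_workers,  # worker processes that prepare batches in parallel
        pin_memory=torch.cuda.is_available(),  # page-locked memory only helps CUDA copies
    )


if __name__ == "__main__":
    # Stand-in for the raw MNIST training images and labels.
    raw_images = torch.randint(0, 256, (1000, 28, 28), dtype=torch.uint8)
    labels = torch.randint(0, 10, (1000,))

    # Option 1: normalize once and save the result for later runs.
    path = os.path.join(tempfile.gettempdir(), "mnist_normalized.pt")
    torch.save({"images": precompute_normalized(raw_images), "labels": labels}, path)
    cached = torch.load(path, weights_only=True)

    # Option 2: let DataLoader worker processes feed the training loop.
    loader = make_loader(cached["images"], cached["labels"], num_workers=2)
    batch_images, batch_labels = next(iter(loader))
    print(batch_images.shape)  # torch.Size([64, 1, 28, 28])
```

Once the tensors are normalized up front, the per-sample work left in each epoch is just indexing and batching, so whether num_workers > 0 still pays off depends on the machine and is worth timing. The `if __name__ == "__main__":` guard matters on platforms that start DataLoader workers with the spawn method (Windows, macOS).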