Oversampling the dataset with PyTorch


I'm quite new to PyTorch and Python, and I have a binary classification problem where one class has more samples than the other. I decided to oversample the minority class by applying more augmentation to it: for example, I would generate 7 augmented images from each sample of the minority class, but only 3 from each sample of the majority class. I'm using imgaug for augmentation with PyTorch, and I'm not sure which is better: augmenting the dataset first and then passing it to a torch.utils.data.Dataset class, or reading the data and augmenting it inside the __init__ method of the Dataset class.
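For concreteness, here is a minimal sketch of the in-Dataset route; it augments lazily in __getitem__ rather than eagerly in __init__, so the 7x/3x copies never all need to sit in memory at once. The class name, copy counts, and imgaug pipeline below are illustrative assumptions, not code from the question:

import numpy as np
import torch
from torch.utils.data import Dataset
import imgaug.augmenters as iaa

class OversampledDataset(Dataset):
    """Expands each sample into several augmented copies,
    more copies for the minority class than for the majority class."""

    def __init__(self, images, labels, minority_label=1,
                 minority_copies=7, majority_copies=3):
        # images: list of HxWxC uint8 numpy arrays; labels: list of 0/1 ints
        self.images = images
        self.labels = labels
        # hypothetical imgaug pipeline; replace with your own augmenters
        self.aug = iaa.Sequential([
            iaa.Fliplr(0.5),               # horizontal flip with p=0.5
            iaa.Affine(rotate=(-20, 20)),  # random rotation
        ])
        # flat index: sample i appears `copies` times, giving the 7:3 ratio
        self.index = []
        for i, y in enumerate(labels):
            copies = minority_copies if y == minority_label else majority_copies
            self.index.extend([i] * copies)

    def __len__(self):
        return len(self.index)

    def __getitem__(self, idx):
        i = self.index[idx]
        # a fresh random augmentation is drawn every time the item is read
        img = self.aug.augment_image(self.images[i])
        img = torch.from_numpy(img.copy()).permute(2, 0, 1).float() / 255.0
        return img, self.labels[i]

Augmenting everything up front instead would trade memory for faster epochs, but it also freezes the augmentations, whereas the lazy version produces different variants every epoch.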


1 Answer

Craig.Li

I think there is another way to deal with unbalanced data: nn.BCEWithLogitsLoss is a common choice for binary classification problems, and it accepts a pos_weight argument that rescales the loss contribution of the positive class. (Note that pos_weight belongs to BCEWithLogitsLoss, not to plain BCELoss.) If you do this, you can apply the same augmentation to all samples instead of oversampling one class. Here is the code:

import torch
from torchvision import transforms

# defines the augmentation
transform = transforms.Compose([transforms.RandomRotation(20),
                                transforms.Resize((32, 32)),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406],
                                                     [0.229, 0.224, 0.225])])
# initializes the dataset (Dataset is your custom torch.utils.data.Dataset subclass)
dataset = Dataset(train_data_path, transforms=transform)
# defines the loss function; pos_weight only exists on BCEWithLogitsLoss,
# which expects raw logits rather than sigmoid outputs
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.]))
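As a usage note: the PyTorch docs suggest setting pos_weight to the ratio of negative to positive samples, and BCEWithLogitsLoss applies the sigmoid internally, so the model should output raw logits. A short sketch, where model, images, targets, num_neg, and num_pos are placeholders:

# a common heuristic: weight positives by how badly they are outnumbered
pos_weight = torch.tensor([num_neg / num_pos])
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

logits = model(images)                 # raw logits, shape (batch, 1)
loss = criterion(logits, targets.float().unsqueeze(1))
loss.backward()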