How to handle imbalanced classes in binary text classification with Transformers and PyTorch

709 views Asked by At

I am working on a binary text classification problem with an imbalanced dataset. How do I apply SMOTE or `WeightedRandomSampler` to correct for the class imbalance? My `Dataset` class currently looks like this:

class GDataset(Dataset):
  """Map-style dataset that tokenizes one passage per item for binary classification.

  Each item is a dict holding the raw passage text, its token ids, its
  attention mask, and the label as a ``torch.long`` tensor.

  Args:
    passage: Indexable collection of passages; ``passage[i]`` is converted
      with ``str()`` before use.
    targets: Indexable collection of labels aligned with ``passage``.
      Label ``1`` is treated as the minority class.
    tokenizer: Object exposing ``encode_plus`` -- presumably a Hugging Face
      tokenizer; confirm against the caller.
    max_len: Length every encoded sequence is padded or truncated to.
    transform: Optional callable ``str -> str`` applied to minority-class
      passages only (e.g. a text augmentation). ``None`` disables it.
  """

  def __init__(self, passage, targets, tokenizer, max_len, transform=None):
    self.passage = passage
    self.targets = targets
    self.tokenizer = tokenizer
    self.max_len = max_len
    # Previously read in __getitem__ but never assigned, which raised
    # AttributeError on the first item.
    self.transform = transform

  def __len__(self):
    """Return the number of passages."""
    return len(self.passage)

  def __getitem__(self, item):
    """Return the encoded passage and its label for index ``item``."""
    passage = str(self.passage[item])
    target = self.targets[item]

    # Augment only the minority class, and do it before tokenizing so the
    # encoded ids reflect the augmented text.
    if target == 1 and self.transform is not None:
      passage = self.transform(passage)

    encoding = self.tokenizer.encode_plus(
      passage,
      add_special_tokens=True,
      max_length=self.max_len,
      return_token_type_ids=False,
      # `pad_to_max_length` is deprecated; padding plus explicit truncation
      # keeps every item exactly `max_len` long so batches can be stacked.
      padding='max_length',
      truncation=True,
      return_attention_mask=True,
      return_tensors='pt',
    )

    return {
      'passage_text': passage,
      # The tokenizer returns shape (1, max_len); flatten drops the batch axis.
      'input_ids': encoding['input_ids'].flatten(),
      'attention_mask': encoding['attention_mask'].flatten(),
      'targets': torch.tensor(target, dtype=torch.long),
    }

Which other class-balancing techniques could I use here, and how would I apply them to this dataset?

0

There are 0 answers