PyTorch text classification example - KeyError while training (pandas indexing issue)


I am trying to train a model for text classification. Every time I run the training block, I run into one index-related KeyError or another. I have reset the index many times and tried dropping columns, but it didn't help. Below is the code (please ignore the lines after the error message; they are just ad hoc checks to see whether the index value is present).

I tried resetting the index, dropped the secondary index that was inadvertently created, and performed the train/test split on the cleaned dataset. Still, the KeyError persists.
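For context, here is a minimal standalone sketch (toy data, not my real dataset) of the behavior I suspect is involved: train_test_split keeps the original index labels on the Series it returns, so a plain s[idx] lookup is label-based and can raise a KeyError for a position whose label ended up in the other split.

import pandas as pd
from sklearn.model_selection import train_test_split

s = pd.Series(list("abcdefghij"))      # RangeIndex labels 0..9
train, val = train_test_split(s, test_size=0.2, random_state=0)

print(train.index.tolist())            # shuffled original labels, e.g. [9, 1, 6, ...]
print(train.iloc[0])                   # positional lookup: always valid
# train[9]                             # label lookup: KeyError if label 9 went to val

The full code and traceback follow.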

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import spacy
# Define a custom dataset class
class TicketDataset(Dataset):
    def __init__(self, data, labels, vocab):
        self.data = data
        self.labels = labels
        self.vocab = vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        try:
            text = self.data[idx]
            label = self.labels[idx]
            
            # Convert text to numerical representation
            text_ids = [self.vocab.get(word, 0) for word in text.split()]
            
            # Pad sequences to a fixed length
            max_len = 191
            text_ids = text_ids[:max_len] + [0] * (max_len - len(text_ids))
            
            return torch.tensor(text_ids), torch.tensor(label)
        except KeyError as e:
            if e.args[0] == 3247:  # Check for specific key (3247)
                print(f"Error: Key {e.args[0]} not found in data or labels.")
                return None, None  # NB: default_collate cannot batch None, so the loader will still fail here
            else:
                raise e  # Re-raise other KeyError exceptions
# Define the neural network model
class TextClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes):
        super(TextClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, text_ids):
        # Embed text tokens
        embedded = self.embedding(text_ids)

        # Pass through LSTM
        _, (hidden, _) = self.lstm(embedded.unsqueeze(0))

        # Extract final hidden state
        output = hidden.squeeze(0)

        # Apply linear layer and softmax
        logits = self.fc(output)
        probs = F.softmax(logits, dim=1)
        return probs
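# Aside: nn.CrossEntropyLoss applies log-softmax internally, so returning
# softmax probabilities from forward() softmaxes twice; returning the raw
# logits is the usual pattern. Also, for a batched (batch, seq_len) input the
# embedding output is already 3-D, and unsqueeze(0) makes it 4-D, which
# nn.LSTM rejects; nn.LSTM(..., batch_first=True) with no unsqueeze would
# handle batches directly.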
nn_data = pd.read_csv('Clean_data_nn.csv')
nn_data.drop(columns=['Unnamed: 0'], inplace=True)
nn_data['Issue_Description'] = nn_data['Issue_Description'].str.replace(r'\s{2,}', ' ', regex=True)
nn_data = nn_data[['Issue_Description','Category']]

nlp = spacy.load("en_core_web_lg")
nn_data['Issue_Description'] = nn_data['Issue_Description'].astype(str)
values_to_map = nn_data['Category'].unique()
label_mapping = {category: idx for idx, category in enumerate(values_to_map)}
nn_data['Category'] = nn_data['Category'].map(label_mapping)
label_mapping
# {'batch': 0, 'functional': 1, 'ad hoc': 2, 'data reports': 3,
#  'data setup issue': 4, 'working as expected': 5, 'performance': 6,
#  'access requests or uar': 7, 'incorrect assignment': 8}
nn_data.reset_index(drop=True, inplace=True)
nn_data.iloc[3222]
# Issue_Description    panel broker file
# Category                             3
# Name: 3222, dtype: object
text = ''
vocab = {}
text = ' '.join(nn_data['Issue_Description'])
# Process the text with spaCy
doc = nlp(text)
# Extract unique words (lemmas) from the processed text
vocab = {token.lemma_: idx for idx, token in enumerate(doc)}
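# Aside: enumerate(doc) numbers token *positions*, so duplicate lemmas keep only
# their last position and the ids can exceed len(vocab) (used as vocab_size
# below), which can index past the nn.Embedding table; id 0 also collides with
# the unknown-word default in vocab.get(word, 0). Note too that the keys are
# lemmas while __getitem__ looks up raw tokens from text.split(), so many
# lookups will fall back to 0.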
# Text preprocessing (modify this based on your needs)
text_length = max(len(text.split()) for text in nn_data['Issue_Description'])
text_length
# 191
# Load your data (replace with your actual loading logic)
data = nn_data['Issue_Description']  # Replace with your data
labels = nn_data['Category']  # Replace with your labels
from sklearn.model_selection import train_test_split
# Prepare datasets
train_data, val_data, train_labels, val_labels = train_test_split(data, labels, test_size=0.2)
train_dataset = TicketDataset(train_data, train_labels, vocab)
val_dataset = TicketDataset(val_data, val_labels, vocab)
train_dataset
# <__main__.TicketDataset at 0x27f3df68850>
# Define hyperparameters
vocab_size = len(vocab)
embedding_dim = 128
hidden_dim = 64
num_classes = 9
batch_size = 32
learning_rate = 0.001
epochs = 10
# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Create the model and optimizer
model = TextClassifier(vocab_size, embedding_dim, hidden_dim, num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Define the loss function (consider using weighted cross-entropy for imbalanced data)
criterion = nn.CrossEntropyLoss()
# Training loop
for epoch in range(epochs):
    model.train()
    for i, (text_ids, labels) in enumerate(train_loader):
        print(i, text_ids, labels)
        optimizer.zero_grad()
        outputs = model(text_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if i % 100 == 0:
            print(f"Epoch: {epoch+1}/{epochs}, Step: {i+1}/{len(train_loader)}, Loss: {loss.item():.4f}")`
`Error Message When running the training block: 
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~\PycharmProjects\pythonProject\venv\Lib\site-packages\pandas\core\indexes\base.py:3791, in Index.get_loc(self, key)
   3790 try:
-> 3791     return self._engine.get_loc(casted_key)
   3792 except KeyError as err:

File index.pyx:152, in pandas._libs.index.IndexEngine.get_loc()

File index.pyx:181, in pandas._libs.index.IndexEngine.get_loc()

File pandas\_libs\hashtable_class_helper.pxi:2606, in pandas._libs.hashtable.Int64HashTable.get_item()

File pandas\_libs\hashtable_class_helper.pxi:2630, in pandas._libs.hashtable.Int64HashTable.get_item()

KeyError: 9

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
Cell In[213], line 4
      2 for epoch in range(epochs):
      3     model.train()
----> 4     for i, (text_ids, labels) in enumerate(train_loader):
      5         print(i, text_ids, labels)
      6         optimizer.zero_grad()

File ~\PycharmProjects\pythonProject\venv\Lib\site-packages\torch\utils\data\dataloader.py:631, in _BaseDataLoaderIter.__next__(self)
    628 if self._sampler_iter is None:
    629     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    630     self._reset()  # type: ignore[call-arg]
--> 631 data = self._next_data()
    632 self._num_yielded += 1
    633 if self._dataset_kind == _DatasetKind.Iterable and \
    634         self._IterableDataset_len_called is not None and \
    635         self._num_yielded > self._IterableDataset_len_called:

File ~\PycharmProjects\pythonProject\venv\Lib\site-packages\torch\utils\data\dataloader.py:675, in _SingleProcessDataLoaderIter._next_data(self)
    673 def _next_data(self):
    674     index = self._next_index()  # may raise StopIteration
--> 675     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    676     if self._pin_memory:
    677         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~\PycharmProjects\pythonProject\venv\Lib\site-packages\torch\utils\data\_utils\fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

File ~\PycharmProjects\pythonProject\venv\Lib\site-packages\torch\utils\data\_utils\fetch.py:51, in <listcomp>(.0)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

Cell In[192], line 29, in TicketDataset.__getitem__(self, idx)
     27     return None, None  # Handle the error gracefully
     28 else:
---> 29     raise e

Cell In[192], line 13, in TicketDataset.__getitem__(self, idx)
     11 def __getitem__(self, idx):
     12     try:
---> 13         text = self.data[idx]
     14         label = self.labels[idx]
     16         # Convert text to numerical representation

File ~\PycharmProjects\pythonProject\venv\Lib\site-packages\pandas\core\series.py:1040, in Series.__getitem__(self, key)
   1037     return self._values[key]
   1039 elif key_is_scalar:
-> 1040     return self._get_value(key)
   1042 # Convert generator to list before going through hashable part
   1043 # (We will iterate through the generator there to check for slices)
   1044 if is_iterator(key):

File ~\PycharmProjects\pythonProject\venv\Lib\site-packages\pandas\core\series.py:1156, in Series._get_value(self, label, takeable)
   1153     return self._values[label]
   1155 # Similar to Index.get_value, but we do not fall back to positional
-> 1156 loc = self.index.get_loc(label)
   1158 if is_integer(loc):
   1159     return self._values[loc]

File ~\PycharmProjects\pythonProject\venv\Lib\site-packages\pandas\core\indexes\base.py:3798, in Index.get_loc(self, key)
   3793     if isinstance(casted_key, slice) or (
   3794         isinstance(casted_key, abc.Iterable)
   3795         and any(isinstance(x, slice) for x in casted_key)
   3796     ):
   3797         raise InvalidIndexError(key)
-> 3798     raise KeyError(key) from err
   3799 except TypeError:
   3800     # If we have a listlike key, _check_indexing_error will raise
   3801     #  InvalidIndexError. Otherwise we fall through and re-raise
   3802     #  the TypeError.
   3803     self._check_indexing_error(key)

KeyError: 9
len(train_loader)
# 147
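For reference, this is the direction I am leaning for a fix (a minimal sketch, untested against my full pipeline): train_test_split returns Series that keep their original, shuffled index labels, while the DataLoader feeds __getitem__ positions 0..len-1, so either realign the index after the split or make the lookups positional.

from sklearn.model_selection import train_test_split

train_data, val_data, train_labels, val_labels = train_test_split(data, labels, test_size=0.2)

# Option 1: drop the shuffled index labels so positions and labels coincide again
train_data = train_data.reset_index(drop=True)
train_labels = train_labels.reset_index(drop=True)
val_data = val_data.reset_index(drop=True)
val_labels = val_labels.reset_index(drop=True)

# Option 2: keep the split as-is and use positional access in __getitem__:
#     text = self.data.iloc[idx]
#     label = self.labels.iloc[idx]
# (.iloc is position-based, so any index labels are fine)

With either option the try/except and the KeyError == 3247 special case should become unnecessary, and shuffle=True on the training DataLoader would be the more usual setting.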
