PyTorch: ValueError: Expected input batch_size (256) to match target batch_size (128)

2.5k views Asked by At

I get a ValueError while training a BiLSTM part-of-speech tagger with PyTorch: `ValueError: Expected input batch_size (256) to match target batch_size (128)`. The code and the full traceback are below.

def train(model, iterator, optimizer, criterion, tag_pad_idx):
    """Run one optimisation pass over ``iterator``.

    Each batch supplies its model input as ``batch.p`` and its targets as
    ``batch.t``. Predictions and targets are flattened before being handed
    to ``criterion`` and ``categorical_accuracy``, then one optimizer step
    is taken.

    Returns:
        ``(loss, accuracy)`` where each value is the sum of the per-batch
        values divided by ``len(iterator)``.
    """
    model.train()

    running_loss = 0
    running_acc = 0

    for batch in iterator:
        tokens, gold = batch.p, batch.t

        optimizer.zero_grad()

        # Intended shapes (per the original author's notes):
        #   tokens: [sent len, batch size]
        #   scores: [sent len, batch size, output dim]
        #   gold:   [sent len, batch size]
        scores = model(tokens)

        # Merge every leading axis so there is one row per position:
        #   flat_scores: [sent len * batch size, output dim]
        #   flat_gold:   [sent len * batch size]
        n_classes = scores.shape[-1]
        flat_scores = scores.view(-1, n_classes)
        flat_gold = gold.view(-1)

        batch_loss = criterion(flat_scores, flat_gold)
        batch_acc = categorical_accuracy(flat_scores, flat_gold, tag_pad_idx)

        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()
        running_acc += batch_acc.item()

    n_batches = len(iterator)
    return running_loss / n_batches, running_acc / n_batches



def evaluate(model, iterator, criterion, tag_pad_idx):
    """Score ``model`` on every batch of ``iterator`` without updating it.

    The model is switched to eval mode and all work happens under
    ``torch.no_grad()``. Each batch supplies its model input as ``batch.p``
    and its targets as ``batch.t``; both predictions and targets are
    flattened before the loss and accuracy are computed.

    Returns:
        ``(loss, accuracy)`` where each value is the sum of the per-batch
        values divided by ``len(iterator)``.
    """
    model.eval()

    running_loss = 0
    running_acc = 0

    with torch.no_grad():
        for batch in iterator:
            tokens, gold = batch.p, batch.t

            scores = model(tokens)

            # One row per position: merge all leading axes of the scores
            # and flatten the targets to a single axis.
            n_classes = scores.shape[-1]
            flat_scores = scores.view(-1, n_classes)
            flat_gold = gold.view(-1)

            batch_loss = criterion(flat_scores, flat_gold)
            batch_acc = categorical_accuracy(flat_scores, flat_gold, tag_pad_idx)

            running_loss += batch_loss.item()
            running_acc += batch_acc.item()

    n_batches = len(iterator)
    return running_loss / n_batches, running_acc / n_batches

class BiLSTMPOSTagger(nn.Module):
    """Embedding -> LSTM -> linear layer that scores every input position.

    ``forward`` keeps the LSTM output of every time step, so the result has
    the same leading axes as the LSTM output and ``output_dim`` scores in
    the last axis.
    """

    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim,
                 n_layers, bidirectional, dropout, pad_idx):
        """Build the layers.

        Args:
            input_dim: number of rows in the embedding table.
            embedding_dim: size of each embedding vector (LSTM input size).
            hidden_dim: LSTM hidden size per direction.
            output_dim: number of scores produced per position.
            n_layers: number of stacked LSTM layers.
            bidirectional: whether the LSTM runs in both directions.
            dropout: dropout probability for the module-level dropout and,
                when ``n_layers > 1``, for the LSTM itself.
            pad_idx: index passed to the embedding as ``padding_idx``.
        """
        super().__init__()

        # The LSTM's own dropout is only requested when layers are stacked.
        lstm_dropout = dropout if n_layers > 1 else 0
        # Both directions are concatenated, doubling the classifier input.
        fc_in_features = hidden_dim * 2 if bidirectional else hidden_dim

        self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=n_layers,
            bidirectional=bidirectional,
            dropout=lstm_dropout,
        )
        self.fc = nn.Linear(fc_in_features, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """Return scores for every position of ``text``."""
        embedded = self.embedding(text)
        embedded = self.dropout(embedded)

        # Only the per-step outputs are used; the final states are dropped.
        lstm_out, _ = self.lstm(embedded)

        return self.fc(self.dropout(lstm_out))

........................... ........................... ........................... ...........................

# Sizes and padding index taken from the vocabularies built elsewhere.
INPUT_DIM = len(POS.vocab)
OUTPUT_DIM = len(TAG.vocab)
PAD_IDX = POS.vocab.stoi[POS.pad_token]

# Fixed hyperparameters.
EMBEDDING_DIM = 100
HIDDEN_DIM = 128
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.25

print(INPUT_DIM)  #output 22147
print(OUTPUT_DIM) #output 42

# Keyword arguments make the mapping onto the constructor explicit.
model = BiLSTMPOSTagger(
    input_dim=INPUT_DIM,
    embedding_dim=EMBEDDING_DIM,
    hidden_dim=HIDDEN_DIM,
    output_dim=OUTPUT_DIM,
    n_layers=N_LAYERS,
    bidirectional=BIDIRECTIONAL,
    dropout=DROPOUT,
    pad_idx=PAD_IDX,
)

........................... ........................... ........................... ...........................

# Number of train + validation passes to run.
N_EPOCHS = 10

# Lowest validation loss seen so far; starting at +inf guarantees that the
# first epoch is recorded as an improvement.
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    # The measured interval covers both the training and the validation pass.
    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
    
    end_time = time.time()

    # epoch_time is defined outside this excerpt; presumably it splits the
    # elapsed time into whole minutes and seconds - TODO confirm.
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    # Overwrite the checkpoint only when the validation loss improves, so the
    # file always holds the weights of the best epoch so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut1-model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')





ValueError                                Traceback (most recent call last)
<ipython-input-55-83bf30366feb> in <module>()
      7     start_time = time.time()
      8 
----> 9     train_loss, train_acc = train(model, train_iterator, optimizer, criterion, TAG_PAD_IDX)
     10     valid_loss, valid_acc = evaluate(model, valid_iterator, criterion, TAG_PAD_IDX)
     11 

4 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in nll_loss(input, target, weight, size_average, ignore_index, reduce, reduction)
   2260     if input.size(0) != target.size(0):
   2261         raise ValueError('Expected input batch_size ({}) to match target batch_size ({}).'
-> 2262                          .format(input.size(0), target.size(0)))
   2263     if dim == 2:
   2264         ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)

ValueError: Expected input batch_size (256) to match target batch_size (128).
1

There is 1 answer

7
Theodor Peifer On

(continuing from the comments)

I guess that your batch size is 128 (it's not defined anywhere in the question), right? The LSTM returns the outputs of every time step, but for classification you normally only want the last one. The first dimension of `outputs` is the sequence length, which in your case seems to be 2. When you apply `.view`, this 2 gets multiplied by your batch size (128), which gives 256. So straight after your LSTM layer you need to take the last output of the output sequence `outputs`, like this:

def forward(self, text):
    """Return one prediction per sequence, computed from the last time step.

    ``self.lstm`` is used in its default sequence-first layout, so its
    output is ``[seq len, batch size, features]`` and indexing the first
    axis with ``-1`` selects the final time step of every sequence. No
    ``batch_size`` / ``sequence_size`` / ``hidden_size`` values are needed,
    and no ``reshape`` is used: reshaping a sequence-first tensor into a
    batch-first shape would reinterpret memory instead of swapping axes and
    mix time steps of different sequences.

    Args:
        text: index tensor laid out as ``[seq len, batch size]``.

    Returns:
        Tensor of shape ``[batch size, output dim]``.

    NOTE(review): this yields a single prediction per sequence, which only
    matches targets that hold one label per sequence, not one per token.
    """
    embedded = self.dropout(self.embedding(text))
    outputs, _ = self.lstm(embedded)

    # take last output: [seq len, batch size, features] -> [batch size, features]
    outputs = outputs[-1]

    predictions = self.fc(self.dropout(outputs))
    return predictions