Fine-tuning a BERT model as a chatbot gives an error while training


I have been trying to fine-tune a BERT model to produce response sentences in the style of a particular character, given input sentences, but I get a rather odd error every time. The code is below.

Here source_texts is a list of sentences that give the context, and target_texts is a list of sentences that respond to those context statements.


from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("bert-base-cased").to(device)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

input_ids = []
output_ids = []
for i in range(len(source_texts)):
    input_ids.append(tokenizer.encode(source_texts[i], return_tensors="pt"))
    output_ids.append(tokenizer.encode(target_texts[i], return_tensors="pt"))

import torch
device = torch.device("cuda")

from transformers import BertForMaskedLM, AdamW

model = BertForMaskedLM.from_pretrained("bert-base-cased")
optimizer = AdamW(model.parameters(), lr=1e-5)
loss_fn = torch.nn.CrossEntropyLoss()

def train(input_id, output_id):
    input_id = input_id.to(device)
    output_id = output_id.to(device)

    model.zero_grad()
    
    logits, _ = model(input_id, labels=output_id)
    
    # Compute the loss
    loss = loss_fn(logits.view(-1, logits.size(-1)), output_id.view(-1))
    
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(50):
    # Train the model on the training dataset
    train_loss = 0.0
    for input_sequences, output_sequences in zip(input_ids, output_ids):
        input_sequences = input_sequences.to(device)
        output_sequences = output_sequences.to(device)
        train_loss += train(input_sequences, output_sequences)

This is the error that I am getting.

Any help would be really appreciated.


There is 1 answer

Edwin Cheong On

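Hi, I looked at your code: you moved only the inputs to the GPU, not the model. PyTorch creates a model on the CPU by default, so the model and its inputs end up on different devices. Move the model to the same device before training: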

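import torch
from transformers import BertForMaskedLM

device = torch.device('cuda')

model = BertForMaskedLM.from_pretrained("bert-base-cased")
model.to(device)  # move the model's weights to the same device as the inputs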
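
For reference, here is a minimal sketch of one training step with the model and the inputs on the same device. It is not the asker's exact setup: the tiny, randomly initialised BertConfig, the train_step helper, the CPU fallback, and the random token IDs are all assumptions made so the example runs without downloading weights or needing a GPU. It also reads the loss from outputs.loss rather than unpacking the return value, and it uses labels with the same shape as input_ids, which BertForMaskedLM expects.

```python
import torch
from transformers import BertConfig, BertForMaskedLM

# Assumption: fall back to the CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumption: a tiny random model stands in for "bert-base-cased"
# so the sketch runs without downloading pretrained weights.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertForMaskedLM(config).to(device)  # model lives on `device`
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)


def train_step(input_ids, labels):
    """Run one optimisation step (hypothetical helper) and return the loss."""
    # Tensor.to returns a new tensor, so the result must be reassigned.
    input_ids = input_ids.to(device)
    labels = labels.to(device)

    model.train()
    optimizer.zero_grad()

    # With `labels` given, the returned object carries both loss and logits.
    outputs = model(input_ids=input_ids, labels=labels)
    outputs.loss.backward()
    optimizer.step()
    return outputs.loss.item()


# Random token IDs as placeholder data; labels match input_ids in shape.
input_ids = torch.randint(0, config.vocab_size, (1, 8))
labels = input_ids.clone()
loss = train_step(input_ids, labels)
```

With the real checkpoint, BertForMaskedLM.from_pretrained("bert-base-cased").to(device) would replace the tiny model. Source and target sentences of different lengths would need padding or truncation to a common length before they can be used as input_ids and labels in this way.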