I am trying to learn how to use the transformers library to predict the next word of a sentence. My code always predicts a "period" as the next token. Can someone help me see what I am doing wrong?
import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM
# Load the pre-trained model and tokenizer
model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForMaskedLM.from_pretrained(model_name)
# Example sentence for predicting the next word
sentence = "I want to go to the"
# Tokenize the sentence
tokens = tokenizer.tokenize(sentence)
# Convert tokens to token IDs
token_ids = tokenizer.convert_tokens_to_ids(tokens)
# Add [CLS] and [SEP] tokens to the token IDs
token_ids = [tokenizer.cls_token_id] + token_ids + [tokenizer.sep_token_id]
# Create tensor input with the token IDs
input_ids = torch.tensor([token_ids])
# Get the predictions for the next word using top-k sampling
with torch.no_grad():
    outputs = model(input_ids)
    predictions = outputs.logits[0, -1]  # Predictions for the last token
# Apply top-k sampling to obtain the predicted next word
top_k = 5 # Number of top-k predictions to consider
probabilities = torch.softmax(predictions, dim=-1)
top_k_predictions = torch.topk(probabilities, k=top_k)
predicted_token_ids = top_k_predictions.indices.tolist()
# Convert predicted token IDs to actual words
predicted_words = tokenizer.convert_ids_to_tokens(predicted_token_ids)
# Print the predicted next words
print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for word in predicted_words:
    print(word)
@steve-landiss
The DistilBERT model is trained to predict masked or missing words in a sentence, not specifically the next word. It is also important to note that such models are not guaranteed to produce meaningful results: they generate outputs based on the probabilities learned during training and can still produce nonsensical ones. To improve quality, you can fine-tune the model on a dataset of your own. Beyond that, there are several ways to get better results:
1. Increasing top_k: a larger value of top_k gives you a broader range of predicted words to choose from (see the first sketch below).
2. Ensembling: instead of relying on a single language model, use an ensemble of multiple models and combine their predictions.
3. Using larger models: consider a larger language model, such as BERT or GPT-2 (see the second sketch below).
4. Post-processing: apply post-processing techniques to refine the model's outputs; for example, filter out unwanted tokens such as the "period" you mentioned.
5. Context window: adjust the amount of context (the length of the input you feed the model) used for generating predictions.
Here is some code with a few of these adjustments that should give you a deeper sense of how to play with them.
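First, a minimal sketch of adjustments 1 and 4, using the same distilbert-base-uncased checkpoint as your code. The most direct fix for the period issue is this: your original code reads the logits at the final position, which holds the [SEP] token, so the model mostly predicts sentence-final punctuation there. Appending a [MASK] token and reading the predictions at that position yields genuine next-word candidates. The mask_index variable and the isalpha() filter are illustrative choices of mine, not the only way to do this:

import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM

model_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertForMaskedLM.from_pretrained(model_name)

sentence = "I want to go to the"

# Append a [MASK] token so the model has a position to fill in;
# reading logits at [SEP] (as in the original code) mostly yields
# sentence-ending punctuation such as the period.
tokens = tokenizer.tokenize(sentence) + [tokenizer.mask_token]
token_ids = [tokenizer.cls_token_id] + tokenizer.convert_tokens_to_ids(tokens) + [tokenizer.sep_token_id]
input_ids = torch.tensor([token_ids])

# Index of the [MASK] position (second to last, just before [SEP])
mask_index = len(token_ids) - 2

with torch.no_grad():
    outputs = model(input_ids)
    predictions = outputs.logits[0, mask_index]  # Predictions at the [MASK] position

# Adjustment 1: a larger top_k gives a broader range of candidates
top_k = 20
probabilities = torch.softmax(predictions, dim=-1)
top_k_predictions = torch.topk(probabilities, k=top_k)
predicted_words = tokenizer.convert_ids_to_tokens(top_k_predictions.indices.tolist())

# Adjustment 4 (post-processing): drop punctuation and subword pieces
filtered_words = [w for w in predicted_words if w.isalpha()]

print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for word in filtered_words:
    print(word)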
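Second, a sketch of adjustment 3 using GPT-2. Unlike DistilBERT, GPT-2 is a causal (left-to-right) language model, so the logits at the last position genuinely score the next token and no [MASK] trick is needed. The 'gpt2' name below is the smallest released checkpoint; larger variants such as 'gpt2-medium' follow the same API:

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

sentence = "I want to go to the"
input_ids = tokenizer.encode(sentence, return_tensors='pt')

with torch.no_grad():
    # Logits at the last position score the *next* token directly
    logits = model(input_ids).logits[0, -1]

top_k = 5
top = torch.topk(torch.softmax(logits, dim=-1), k=top_k)
print(f"Original Sentence: {sentence}")
print("Predicted Next Words:")
for token_id in top.indices.tolist():
    print(tokenizer.decode([token_id]))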