How to get perplexity per token rather than average perplexity?

I can get the perplexity of a whole sentence from here:

device = "cuda"
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

device = "cuda"
model_id = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
sent = 'Happy Birthday!'
input_ids = tokenizer(sent, return_tensors='pt')['input_ids']
target_ids = input_ids.clone()
outputs = model(input_ids.to(device), labels=target_ids)
ppl = torch.exp(outputs.loss)
print(ppl)
>>> tensor(1499.6934, device='cuda:0', grad_fn=<ExpBackward0>)

But how can I get the perplexity value for each token, instead of the average perplexity of the entire sequence of tokens? The input sentence in this example, 'Happy Birthday!', is composed of 3 tokens. Based on the formula for perplexity:

PPL(X) = exp( -(1/t) * sum_{i=1}^{t} log p(x_i | x_{<i}) )

This should result in 3 values: the log probability of the first token, the log probability of the second token given the first, and the log probability of the third token given the first two. The negative of each should then be exponentiated to get the perplexity of each token.
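
Written out per position (my reading of the formula above), that would be:

ppl_1 = exp(-log p(x_1))
ppl_2 = exp(-log p(x_2 | x_1))
ppl_3 = exp(-log p(x_3 | x_1, x_2))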

I currently have the following:

import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

device = "cuda"
model_id = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

sent = 'Happy Birthday!'
input_ids = tokenizer(sent, return_tensors='pt')['input_ids'].to(device)
target_ids = input_ids.clone()

# Initialize an empty list to store perplexities for each token
perplexities = []

# Calculate perplexity for each token
for i in range(input_ids.shape[1]):
    output = model(input_ids[:, :i+1], labels=target_ids[:, :i+1])
    log_prob = output.loss.item()
    perplexity = torch.exp(torch.tensor(log_prob))
    perplexities.append(perplexity.item())

# Perplexities is now a list containing the perplexity values for each token
for i, token in enumerate([tokenizer.decode(i) for i in input_ids[0]]):
    print(f"Token: {token}, Perplexity: {perplexities[i]}")

The output:

>>> Token: Happy, Perplexity: nan
>>> Token:  Birthday, Perplexity: 54192.46484375
>>> Token: !, Perplexity: 1499.693359375

But I'm not sure what I'm doing wrong, as the last token seems to have the same perplexity as the entire sentence.

There are 2 answers

Answer by Martin Weyssow (accepted):

This is happening because, in the second code snippet, you loop over the input sequence by adding a new token at each iteration:

i=0: input_ids[:, :i+1] := tensor([[25082]], device='cuda:0')
i=1: input_ids[:, :i+1] := tensor([[25082, 33511]], device='cuda:0')
i=2: input_ids[:, :i+1] := tensor([[25082, 33511,     0]], device='cuda:0')

Then, the computation of the perplexity in the last iteration of the loop is essentially identical to doing this:

outputs = model(input_ids.to(device), labels=target_ids)
ppl = torch.exp(outputs.loss)

Here's how you can compute the perplexity and per-token perplexity (see https://github.com/huggingface/transformers/blob/v4.35.0/src/transformers/models/gpt2/modeling_gpt2.py#L1103):

import torch.nn.functional as F
[...]
sent = 'Happy Birthday!'
input_ids = tokenizer(sent, return_tensors='pt')['input_ids'].to(device)
labels = input_ids.clone()

output = model(input_ids, labels=labels)
logits = output.logits
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction='none')
per_token_perplexity = torch.exp(loss)
average_perplexity = torch.exp(torch.mean(loss))
print(per_token_perplexity)
print(average_perplexity)

The output:

tensor([5.4192e+04, 4.1502e+01], device='cuda:0', grad_fn=<ExpBackward0>)
tensor(1499.6934, device='cuda:0', grad_fn=<ExpBackward0>)
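
To line these values up with the tokens they score, here is a small sketch on top of the snippet above (note there are only two values for the three-token input, since the first token has no preceding context to be scored against):

# loss[i] scores the token at position i+1 given tokens 0..i,
# so skip the first token when pairing values with tokens
tokens = [tokenizer.decode(t) for t in input_ids[0]]
for token, ppl in zip(tokens[1:], per_token_perplexity.tolist()):
    print(f"Token: {token}, Perplexity: {ppl}")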

Answer by Liam Payne:

FYI, you can also read the per-token log probabilities directly from the log-softmax of the logits (this assumes the logits and input_ids from the snippets above):

# Calculate log softmax to get log probabilities
log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

# Initialize perplexity list
token_ppls = []


# Traverse each position (starting at 1) and compute the perplexity of its token
for i in range(1, input_ids.size(1)):
    # Log probability the model assigned to the actual token at position i,
    # read from the distribution predicted at position i-1
    target_id = input_ids[0, i]
    target_log_prob = log_probs[0, i-1, target_id].item()

    # Calculate the token's perplexity and add it to the list
    token_ppl = torch.exp(-torch.tensor(target_log_prob)).item()
    token_ppls.append(token_ppl)
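
For what it's worth, token_ppls should end up holding the same two values as per_token_perplexity in the accepted answer (roughly 54192.46 and 41.50 for this sentence), since exp(-log_softmax(logits)[0, i-1, target_id]) is exactly the exponential of the unreduced cross-entropy at that position. A quick sanity check, assuming both snippets were run on the same inputs in the same session:

# The two lists should agree up to floating-point noise
print(token_ppls)
print(per_token_perplexity.tolist())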