GPT-2 logits are different when I use past_key_values

I am trying to use past_key_values to speed up inference:

import torch
from transformers import GPT2LMHeadModel

torch.set_default_device("cuda")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
model.to("cuda")

seq = torch.tensor([1, 2, 3, 4, 5])
original_out = model(input_ids=seq).logits

seq2 = torch.tensor([1, 2, 3])
key_values = model(input_ids=seq2, use_cache=True).past_key_values
new_seq = torch.tensor([4, 5])
magic = model(input_ids=new_seq, past_key_values=key_values).logits

print(torch.equal(original_out[-1, :], magic[-1, :]))

This prints False, but I expected it to print True.

Answer by cronoik:

Your code is fine, but you are running into floating-point precision issues. torch.equal checks whether two tensors have the same shape and exactly the same values, and your two results differ very slightly:

import torch
from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

seq = torch.tensor([1, 2, 3, 4, 5])
seq2 = torch.tensor([1, 2, 3])
new_seq = torch.tensor([4, 5])

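with torch.inference_mode():
    # Full forward pass over all five tokens; keep the logits of the last position
    original_out = model(input_ids=seq).logits[-1, :]
    # Forward pass over the first three tokens only, to build the cache
    key_values = model(input_ids=seq2, use_cache=True).past_key_values
    # Feed the remaining two tokens together with the cached keys/values
    magic = model(input_ids=new_seq, past_key_values=key_values).logits[-1, :]

print(torch.equal(original_out, magic))
# Difference over the first 20 elements
print(original_out[:20] - magic[:20])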

Output:

False
tensor([ 7.6294e-06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         0.0000e+00, -7.6294e-06,  1.5259e-05,  1.5259e-05,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  2.2888e-05,
        -7.6294e-06, -7.6294e-06,  1.5259e-05,  0.0000e+00,  7.6294e-06])
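
A quick sanity check on those numbers (my own sketch, not part of the original run): the non-zero differences look like 1, 2 and 3 steps of 2**-17, which is the gap between neighbouring float32 values whose magnitude lies between 64 and 128. That the logits fall in that range is an assumption, since the output above only shows the differences, but if it holds, the two results disagree by only a few units in the last place:

```python
# Non-zero magnitudes copied from the printed differences (rounded by the printout).
diffs = [7.6294e-06, 1.5259e-05, 2.2888e-05]

# Gap between adjacent float32 values with magnitude in [64, 128):
# 2**6 (start of that range) times 2**-23 (float32 has 23 fraction bits).
float32_step = 2.0 ** 6 * 2.0 ** -23  # equals 2**-17

# Express each difference as a number of float32 steps.
steps = [d / float32_step for d in diffs]
rounded_steps = [round(s) for s in steps]

print(rounded_steps)  # [1, 2, 3]
```

So nothing is wrong with the cache; the two code paths just round slightly differently.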

I recommend using torch.allclose to compare the two tensors, because it allows for a small tolerance (by default rtol=1e-05 and atol=1e-08) instead of requiring exact equality:

print(torch.allclose(original_out, magic))

Output:

True
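
For reference, torch.allclose documents its check as |input - other| <= atol + rtol * |other|, applied element-wise. The sketch below re-implements that rule in plain Python to show why differences of this size pass. Note that close_enough is a hypothetical helper written for illustration, and the logit values are made up (only the differences are taken from the output above):

```python
def close_enough(a, b, rtol=1e-05, atol=1e-08):
    """Element-wise |a - b| <= atol + rtol * |b|, the rule torch.allclose documents."""
    return all(abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b))

# Made-up logits of a plausible magnitude, for illustration only.
full_pass = [-100.0, -80.0, -65.0]
# The same values shifted by differences like the ones printed above.
cached_pass = [-100.0 + 7.6294e-06, -80.0, -65.0 + 2.2888e-05]

exact_match = full_pass == cached_pass                          # what torch.equal asks
default_tolerance = close_enough(full_pass, cached_pass)        # default rtol and atol
absolute_only = close_enough(full_pass, cached_pass, rtol=0.0)  # atol alone

print(exact_match)        # False
print(default_tolerance)  # True
print(absolute_only)      # False
```

With rtol=0 the comparison fails, so it is the relative tolerance that absorbs the rounding noise. If you need a stricter or looser check, you can pass rtol and atol to torch.allclose explicitly.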