I am trying to use past_key_values to speed up inference:
```python
import torch
from transformers import GPT2LMHeadModel

torch.set_default_device("cuda")

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
model.to("cuda")

# Full sequence in a single forward pass
seq = torch.tensor([1, 2, 3, 4, 5])
original_out = model(input_ids=seq).logits

# Same sequence in two passes, reusing the KV cache for the prefix
seq2 = torch.tensor([1, 2, 3])
key_values = model(input_ids=seq2, use_cache=True).past_key_values

new_seq = torch.tensor([4, 5])
magic = model(input_ids=new_seq, past_key_values=key_values).logits

# Compare the logits of the last token from both runs
print(torch.equal(original_out[-1, :], magic[-1, :]))
```
But this returns False, while I expect it to return True.
Your code is fine; you're just running into floating-point precision issues. `torch.equal` returns True only if two tensors have the same shape and exactly the same values, but your two tensors are slightly different.
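A quick way to see this is to look at the largest element-wise difference between the two last-token logit vectors (a minimal check; the exact magnitude depends on your hardware, dtype, and library versions):

```python
# The cached and uncached paths take different computational routes,
# so their results agree only up to floating-point rounding error.
diff = (original_out[-1, :] - magic[-1, :]).abs().max()
print(diff)  # a small nonzero value, e.g. on the order of 1e-5
```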
Instead, I recommend comparing the two tensors with `torch.allclose`, which allows for a small tolerance:
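```python
# allclose uses rtol=1e-5 and atol=1e-8 by default, which absorbs
# the tiny numerical differences between the two code paths
print(torch.allclose(original_out[-1, :], magic[-1, :]))
```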
Output:
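```
True
```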