You are computing past_kv = outputs.past_key_values but passing past_key_values in your subsequent call.
outputs = model.generate(
    input_ids=outputs.sequences,
    attention_mask=torch.ones((1, outputs.sequences.shape[1]), dtype=torch.int).to(device),
    past_key_values=past_key_values,  # change this to past_kv
    temperature=0.7,
    max_new_tokens=1,
    use_cache=True,
    return_dict_in_generate=True
)
Maybe past_key_values is a variable left over from an earlier computation? With this change I'm getting
print(t2-t1) # 4.438955783843994
print(t3-t2) # 0.06613826751708984
as expected.
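
One way to catch this kind of mix-up early is to check how many tokens the cache covers before making the second call. The following is only a rough sketch, not part of the transformers API: cached_length and tokens_left_to_process are hypothetical helper names, and the legacy branch assumes the common tuple layout of per-layer (key, value) tensors shaped (batch, num_heads, seq_len, head_dim), which not every model uses. A stale cache that happens to be shorter than the input would still pass this check, so treat it as a heuristic.

```python
def cached_length(past_kv):
    """Return how many tokens the given KV cache covers (hypothetical helper)."""
    # Newer transformers versions return a Cache object with get_seq_length().
    if hasattr(past_kv, "get_seq_length"):
        return past_kv.get_seq_length()
    # Assumed legacy layout: tuple of per-layer (key, value) pairs, each
    # shaped (batch, num_heads, seq_len, head_dim).
    return past_kv[0][0].shape[-2]


def tokens_left_to_process(past_kv, input_len):
    """Return how many input tokens are not covered by the cache.

    Raises ValueError if no cache is passed, or if the cache covers at
    least as many tokens as the input, which suggests it belongs to a
    different sequence.
    """
    if past_kv is None:
        raise ValueError("no cache passed; the whole input will be recomputed")
    n_cached = cached_length(past_kv)
    if n_cached >= input_len:
        raise ValueError(
            f"cache covers {n_cached} tokens but the input has only "
            f"{input_len}; is this the cache from the previous call?"
        )
    return input_len - n_cached
```

With the names from the question, calling tokens_left_to_process(past_kv, outputs.sequences.shape[1]) just before the second generate call should return a small number if the right cache is being passed.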