Background

When extracting KV caches for CacheGen offline dataset construction, we use:

generated = model.generate(input_ids, max_new_tokens=1, return_dict_in_generate=True)
past_key_values = generated["past_key_values"]

However, depending on the model (e.g., Mistral/LLaMA/LongChat variants) and the transformers version,
the KV cache returned by generate() may have either of the following sequence lengths (see the sketch after this list):

  • seq_len == input_len (no decode step performed), or
  • seq_len == input_len + 1 (a decode step was performed internally)
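The following sketch (not from the repository; the model name and prompt are placeholders) shows how to observe which case a given backend falls into:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistral-community/Mistral-7B-v0.2"   # placeholder backend
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

input_ids = tokenizer("a long prompt ...", return_tensors="pt").input_ids.to(model.device)
with torch.no_grad():
    generated = model.generate(input_ids, max_new_tokens=1, return_dict_in_generate=True)

past_key_values = generated["past_key_values"]
# Each layer holds a (key, value) pair of shape (batch, num_kv_heads, seq_len, head_dim);
# depending on the model and transformers version, seq_len is input_len or input_len + 1.
kv_seq_len = past_key_values[0][0].shape[2]
print(input_ids.shape[1], kv_seq_len)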

The original implementation always dropped the last position:

  kv[i][0] = kv[i][0][:, :, :-1][0]
  kv[i][1] = kv[i][1][:, :, :-1][0]

This silently produced off-by-one KV caches for models where seq_len == input_len, resulting in truncated KV tensors and misalignment with the true prompt length.

Problem

For certain model backends (e.g., mistral-community/Mistral-7B-v0.2),
generated.past_key_values already has the correct sequence length (seq_len == input_len).

Blindly slicing [:-1] incorrectly removes one valid token’s KV.

Example (longchat dataset; first instance):

input_ids.shape = (1, 8903)
returned kv seq_len = 8903
after slicing: 8902   <-- incorrect

This produces incomplete KV caches.
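A quick check (reusing the hypothetical setup from the sketch above) reproduces the off-by-one when seq_len == input_len:

token_length = input_ids.shape[1]      # 8903 in the example above
key, value = past_key_values[0]        # first layer's (key, value) pair
old_sliced = key[:, :, :-1]            # original behaviour: always drop the last position
print(old_sliced.shape[2])             # 8902 when seq_len == input_len: one valid token's KV is lost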

Fix

Instead of always removing the last timestep,
we explicitly slice the KV tensors to exactly match the original input token length:

token_length = input_ids.shape[1]
kv[i][0] = kv[i][0][:, :, :token_length][0]
kv[i][1] = kv[i][1][:, :, :token_length][0]

This handles both cases safely.
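For reference, a minimal sketch of the corrected extraction loop; the kv layout and the [0] batch-dimension squeeze follow the snippets above, while the surrounding structure is assumed rather than copied from the repository:

token_length = input_ids.shape[1]
generated = model.generate(input_ids, max_new_tokens=1, return_dict_in_generate=True)
past_key_values = generated["past_key_values"]

kv = []
for layer_idx in range(len(past_key_values)):
    layer_key, layer_value = past_key_values[layer_idx]
    # Clip to the prompt length: a no-op when seq_len == input_len, and it drops
    # the extra decoded position when seq_len == input_len + 1. [0] removes the
    # batch dimension, leaving (num_kv_heads, token_length, head_dim).
    kv.append([
        layer_key[:, :, :token_length][0],
        layer_value[:, :, :token_length][0],
    ])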
