Fix incorrect KV cache length during offline CacheGen KV extraction #8
Background
When extracting KV caches for CacheGen offline dataset construction, we use:
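The exact call is not reproduced here; below is a minimal sketch of the extraction step, assuming a Hugging Face transformers causal LM. The model id, `prompt`, and variable names are illustrative, not the repository's actual script:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistral-community/Mistral-7B-v0.2"  # one of the backends discussed below
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"
)

prompt = "..."  # one instance from the offline dataset
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
input_len = input_ids.shape[1]

with torch.no_grad():
    generated = model.generate(
        input_ids,
        max_new_tokens=1,
        return_dict_in_generate=True,  # exposes past_key_values on the output
    )

# Per-layer (key, value) tensors; depending on the transformers version this is
# a legacy tuple or a Cache object, both of which can be iterated layer by layer.
kv = generated.past_key_values
```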
However, depending on the model (e.g., Mistral/LLaMA/LongChat variants) and the transformers version, generate() may return past_key_values whose sequence length is either input_len + 1 (the prompt plus the first generated token) or exactly input_len (the prompt only). The original implementation always dropped the last position:
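A sketch of that behavior (not the verbatim diff; `kv` and `input_len` are the variables from the extraction sketch above):

```python
# Previous behavior (sketch): unconditionally drop the last timestep,
# assuming the cache always holds input_len + 1 positions.
# Each layer's key/value tensor is (batch, num_heads, seq_len, head_dim),
# so the sequence dimension is dim 2.
kv_dropped_last = tuple(
    (k[:, :, :-1, :], v[:, :, :-1, :])
    for k, v in kv
)
```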
This silently produced off-by-one KV caches for models where seq_len == input_len, resulting in truncated KV tensors and misalignment with the true prompt length.
Problem
For certain model backends (e.g., mistral-community/Mistral-7B-v0.2),
generated.past_key_values already has the correct length (input_len).
Blindly slicing [:-1] incorrectly removes one valid token’s KV.
Example (longchat dataset; first instance):
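The original output is not reproduced here; the numbers below are hypothetical and only illustrate the mismatch:

```python
# Hypothetical numbers, for illustration only (not the actual instance).
input_len = 12000               # prompt length of the dataset instance
kv_len = 12000                  # sequence length already reported by past_key_values
sliced_len = kv_len - 1         # after the unconditional [:-1] slice
assert sliced_len != input_len  # off by one: the saved cache is truncated
```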
This produces incomplete KV caches.
Fix
Instead of always removing the last timestep,
we explicitly slice the KV tensors to exactly match the original input token length:
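A minimal sketch of this approach, assuming per-layer key/value tensors of shape `(batch, num_heads, seq_len, head_dim)` (variable names are illustrative):

```python
# Fix (sketch): slice each layer's key/value tensors to exactly input_len
# along the sequence dimension, whether generate() returned input_len or
# input_len + 1 positions.
kv_for_saving = tuple(
    (k[:, :, :input_len, :], v[:, :, :input_len, :])
    for k, v in kv
)

# Sanity check: the extracted cache now matches the true prompt length.
assert all(
    k.shape[2] == input_len and v.shape[2] == input_len
    for k, v in kv_for_saving
)
```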
This handles both cases safely.