# Normalization helpers for the transformers KV-cache API.
#
# Older llmlingua code assumes ``past_key_values`` is a tuple of per-layer
# (k, v) tensor pairs. Transformers >= 4.36 returns a ``DynamicCache`` object
# instead, which breaks ``for k, v in past_key_values`` loops and
# ``past_key_values[0][0].shape[2]`` reads. See microsoft/LLMLingua#210.

try:
    from transformers.cache_utils import DynamicCache
except ImportError:
    # transformers < 4.36 has no DynamicCache; caches are plain tuples there.
    DynamicCache = None


def _dynamic_cache_to_legacy(cache):
    """Convert a ``DynamicCache`` to a legacy list of per-layer (k, v) pairs.

    Works across transformers versions: older ones expose
    ``to_legacy_cache()``; newer ones (>= 5.0) expose a ``layers`` attribute
    whose entries carry ``keys``/``values`` tensors.
    """
    if hasattr(cache, "to_legacy_cache"):
        try:
            return cache.to_legacy_cache()
        except (AttributeError, TypeError, NotImplementedError):
            # Some versions stub this method out; fall through to ``layers``.
            pass
    if hasattr(cache, "layers"):
        return [(layer.keys, layer.values) for layer in cache.layers]
    # Last resort: assume it is already iterable as (k, v) pairs.
    return list(cache)


def _legacy_to_dynamic_cache(past_key_values):
    """Build a ``DynamicCache`` from a legacy list of per-layer (k, v) pairs."""
    if DynamicCache is None:
        return past_key_values
    # Normalize inner containers (the slicing code builds lists) to tuples.
    legacy = tuple((k, v) for k, v in past_key_values)
    if hasattr(DynamicCache, "from_legacy_cache"):
        return DynamicCache.from_legacy_cache(legacy)
    # Transformers >= 5.0: the constructor accepts ``ddp_cache_data``.
    try:
        return DynamicCache(ddp_cache_data=legacy)
    except TypeError:
        # Extremely old / unusual API — rebuild layer by layer via update().
        cache = DynamicCache()
        for layer_idx, (k, v) in enumerate(legacy):
            cache.update(k, v, layer_idx)
        return cache


def _to_legacy_past_key_values(past_key_values):
    """Normalize ``past_key_values`` from a model output to the legacy
    tuple-of-(k, v) format that the rest of this module expects.

    Transformers >= 4.36 returns a ``DynamicCache`` instead of a tuple,
    which breaks ``for k, v in past_key_values`` loops and
    ``past_key_values[0][0].shape[2]`` accesses used by
    ``iterative_compress_prompt``. See microsoft/LLMLingua#210.
    """
    if past_key_values is None:
        return None
    # Any cache-like object (a real DynamicCache, or anything exposing the
    # cache API) goes through the single conversion helper so every branch
    # benefits from its version fallbacks.
    looks_like_cache = (
        (DynamicCache is not None and isinstance(past_key_values, DynamicCache))
        or hasattr(past_key_values, "to_legacy_cache")
        or hasattr(past_key_values, "layers")
    )
    if looks_like_cache:
        return _dynamic_cache_to_legacy(past_key_values)
    return past_key_values


def _to_model_past_key_values(past_key_values):
    """Convert a legacy tuple-of-(k, v) cache back into the object the model
    expects. On newer transformers this is a ``DynamicCache``; on older
    versions the legacy tuple is still accepted, so it passes through.
    """
    if past_key_values is None:
        return None
    if DynamicCache is None:
        return past_key_values
    if isinstance(past_key_values, DynamicCache):
        # Already in model format — no redundant rebuild.
        return past_key_values
    return _legacy_to_dynamic_cache(past_key_values)
# Copyright (c) 2023-2025 Microsoft
# Licensed under The MIT License [see LICENSE for details]

"""Regression tests for microsoft/LLMLingua#210.

llmlingua's ``iterative_compress_prompt`` assumes ``past_key_values`` is a
tuple of per-layer (k, v) pairs: it reads the cached length from
``past_key_values[0][0].shape[2]`` and unpacks ``for k, v in past_key_values``
in several places.

Transformers >= 4.36 returns a ``DynamicCache`` object instead of a tuple.
These tests verify that the normalization helpers in
``llmlingua.prompt_compressor`` round-trip between the two representations so
that the iteration and slicing code keeps working.
"""

import unittest

import torch

from llmlingua.prompt_compressor import (
    DynamicCache,
    _to_legacy_past_key_values,
    _to_model_past_key_values,
)


def _make_legacy_cache(num_layers=4, batch=1, num_heads=2, seq_len=5, head_dim=3):
    """Build a legacy ``past_key_values``: one (k, v) tensor pair per layer."""
    torch.manual_seed(0)
    layers = []
    for _ in range(num_layers):
        shape = (batch, num_heads, seq_len, head_dim)
        layers.append((torch.randn(*shape), torch.randn(*shape)))
    return tuple(layers)


class Issue210Tester(unittest.TestCase):
    def test_to_legacy_handles_none(self):
        self.assertIsNone(_to_legacy_past_key_values(None))

    def test_to_model_handles_none(self):
        self.assertIsNone(_to_model_past_key_values(None))

    def test_legacy_roundtrip_is_iterable_as_kv_pairs(self):
        """The ``for k, v in past_key_values`` loops must keep working."""
        legacy = _make_legacy_cache()
        normalized = _to_legacy_past_key_values(legacy)
        self.assertIsNotNone(normalized)
        unpacked = [(k, v) for k, v in normalized]
        self.assertEqual(len(unpacked), len(legacy))
        for (k_new, v_new), (k_old, v_old) in zip(unpacked, legacy):
            self.assertTrue(torch.equal(k_new, k_old))
            self.assertTrue(torch.equal(v_new, v_old))

    def test_legacy_roundtrip_preserves_shape_access(self):
        """``past_key_values[0][0].shape[2]`` is read in iterative_compress_prompt."""
        cache = _to_legacy_past_key_values(_make_legacy_cache(seq_len=7))
        self.assertEqual(cache[0][0].shape[2], 7)

    @unittest.skipIf(DynamicCache is None, "transformers.cache_utils not available")
    def test_dynamic_cache_to_legacy(self):
        """A DynamicCache must normalize back to per-layer (k, v) pairs."""
        legacy = _make_legacy_cache()
        dynamic = _to_model_past_key_values(legacy)
        self.assertIsInstance(dynamic, DynamicCache)

        normalized = _to_legacy_past_key_values(dynamic)
        self.assertIsNotNone(normalized)
        # Must unpack as (k, v) pairs ...
        unpacked = [(k, v) for k, v in normalized]
        self.assertEqual(len(unpacked), len(legacy))
        for (k_new, v_new), (k_old, v_old) in zip(unpacked, legacy):
            self.assertTrue(torch.equal(k_new, k_old))
            self.assertTrue(torch.equal(v_new, v_old))
        # ... and support the shape read used in iterative_compress_prompt.
        self.assertEqual(normalized[0][0].shape[2], legacy[0][0].shape[2])

    @unittest.skipIf(DynamicCache is None, "transformers.cache_utils not available")
    def test_to_model_is_idempotent_on_dynamic_cache(self):
        """An existing DynamicCache must pass through without a rebuild."""
        dynamic = _to_model_past_key_values(_make_legacy_cache())
        self.assertIs(_to_model_past_key_values(dynamic), dynamic)

    @unittest.skipIf(DynamicCache is None, "transformers.cache_utils not available")
    def test_roundtrip_tuple_to_dynamic_to_tuple(self):
        """tuple -> DynamicCache -> tuple must preserve every tensor."""
        legacy = _make_legacy_cache(num_layers=3, seq_len=4)
        back = _to_legacy_past_key_values(_to_model_past_key_values(legacy))
        self.assertEqual(len(back), len(legacy))
        for (k_new, v_new), (k_old, v_old) in zip(back, legacy):
            self.assertTrue(torch.equal(k_new, k_old))
            self.assertTrue(torch.equal(v_new, v_old))

    def test_slice_pattern_from_iterative_compress_prompt(self):
        """Reproduce the per-layer drop-a-token-span pattern (concatenate the
        cache before position ``s`` with the cache after position ``s + e``)
        that ``iterative_compress_prompt`` applies to a normalized cache that
        originally came from a DynamicCache."""
        legacy = _make_legacy_cache(seq_len=10)
        past = _to_legacy_past_key_values(_to_model_past_key_values(legacy))

        s, e = 2, 3
        sliced = [
            [
                torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
                torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
            ]
            for k, v in past
        ]
        self.assertEqual(len(sliced), len(legacy))
        for k, v in sliced:
            self.assertEqual(k.shape[2], 10 - e)
            self.assertEqual(v.shape[2], 10 - e)


if __name__ == "__main__":
    unittest.main()