diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py
index 84e390e..95d0697 100644
--- a/llmlingua/prompt_compressor.py
+++ b/llmlingua/prompt_compressor.py
@@ -22,6 +22,82 @@
     AutoTokenizer,
 )
 
+try:
+    from transformers.cache_utils import DynamicCache
+except ImportError:
+    # Very old transformers (< 4.36) has no cache_utils module; the legacy
+    # tuple-of-(k, v) format is used everywhere, so no conversion is needed.
+    DynamicCache = None
+
+
+def _dynamic_cache_to_legacy(cache):
+    """Convert a ``DynamicCache`` to a legacy tuple-of-(k, v) list.
+
+    Works across transformers versions: older ones expose
+    ``to_legacy_cache()``; newer ones (>= 5.0) expose a ``layers`` attribute
+    whose entries have ``keys``/``values`` tensors.
+    """
+    if hasattr(cache, "to_legacy_cache"):
+        try:
+            return cache.to_legacy_cache()
+        except Exception:
+            # Best-effort: fall through to the layers-based extraction below.
+            pass
+    if hasattr(cache, "layers"):
+        return [(layer.keys, layer.values) for layer in cache.layers]
+    # Last-resort: assume it's already iterable as (k, v) pairs.
+    return list(cache)
+
+
+def _legacy_to_dynamic_cache(past_key_values):
+    """Build a ``DynamicCache`` from a legacy tuple-of-(k, v) list."""
+    if DynamicCache is None:
+        return past_key_values
+    legacy = tuple((k, v) for k, v in past_key_values)
+    if hasattr(DynamicCache, "from_legacy_cache"):
+        return DynamicCache.from_legacy_cache(legacy)
+    # Transformers >= 5.0: use the ddp_cache_data constructor argument.
+    try:
+        return DynamicCache(ddp_cache_data=legacy)
+    except TypeError:
+        # Extremely old / unusual API — rebuild manually via update().
+        cache = DynamicCache()
+        for layer_idx, (k, v) in enumerate(legacy):
+            cache.update(k, v, layer_idx)
+        return cache
+
+
+def _to_legacy_past_key_values(past_key_values):
+    """Normalize ``past_key_values`` from the transformers model output to the
+    legacy tuple-of-(k, v) format that the rest of this module expects.
+
+    Transformers >= 4.36 returns a ``DynamicCache`` instead of a tuple, which
+    breaks ``for k, v in past_key_values`` loops and
+    ``past_key_values[0][0].shape[2]`` accesses used by
+    ``iterative_compress_prompt``. See microsoft/LLMLingua#210.
+    """
+    if past_key_values is None:
+        return None
+    if DynamicCache is not None and isinstance(past_key_values, DynamicCache):
+        return _dynamic_cache_to_legacy(past_key_values)
+    if hasattr(past_key_values, "to_legacy_cache"):
+        return past_key_values.to_legacy_cache()
+    if hasattr(past_key_values, "layers"):
+        return _dynamic_cache_to_legacy(past_key_values)
+    return past_key_values
+
+
+def _to_model_past_key_values(past_key_values):
+    """Convert a legacy tuple-of-(k, v) cache back into the object the model
+    expects. On newer transformers this is a ``DynamicCache``; on older
+    versions the legacy tuple is still accepted, so we pass it through.
+    """
+    if past_key_values is None:
+        return None
+    if DynamicCache is None:
+        return past_key_values
+    if isinstance(past_key_values, DynamicCache):
+        return past_key_values
+    return _legacy_to_dynamic_cache(past_key_values)
+
+
 from .utils import (
     TokenClfDataset,
     get_pure_token,
@@ -178,6 +254,10 @@ def get_ppl(
         tokenized_text = self.tokenizer(text, return_tensors="pt")
         input_ids = tokenized_text["input_ids"].to(self.device)
         attention_mask = tokenized_text["attention_mask"].to(self.device)
+        # Normalize any cache the caller hands us to the legacy tuple format
+        # so the ``past_key_values[0][0].shape[2]`` read below stays valid
+        # even when transformers >= 4.36 returned a DynamicCache. See #210.
+        past_key_values = _to_legacy_past_key_values(past_key_values)
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[2]
         else:
@@ -189,10 +269,13 @@ def get_ppl(
             response = self.model(
                 input_ids[:, past_length:end],
                 attention_mask=attention_mask[:, :end],
-                past_key_values=past_key_values,
+                past_key_values=_to_model_past_key_values(past_key_values),
                 use_cache=True,
             )
-            past_key_values = response.past_key_values
+            # Normalize the returned cache back to the legacy tuple format so
+            # ``for k, v in past_key_values`` loops in iterative_compress_prompt
+            # continue to work on newer transformers.
+            past_key_values = _to_legacy_past_key_values(response.past_key_values)
 
         shift_logits = response.logits[..., :-1, :].contiguous()
         shift_labels = input_ids[..., past_length + 1 : end].contiguous()
diff --git a/tests/test_issue_210.py b/tests/test_issue_210.py
new file mode 100644
index 0000000..1d8ad64
--- /dev/null
+++ b/tests/test_issue_210.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2023-2025 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+"""Regression tests for microsoft/LLMLingua#210.
+
+llmlingua's ``iterative_compress_prompt`` assumes ``past_key_values`` is a
+tuple-of-(k, v) per layer: it slices ``past_key_values[0][0].shape[2]`` to
+read the cached length and iterates ``for k, v in past_key_values`` four
+separate times.
+
+Transformers >= 4.36 now returns a ``DynamicCache`` object instead of a
+tuple. This test verifies the normalization helpers in
+``llmlingua.prompt_compressor`` correctly round-trip between the two
+representations so that the iteration/slicing code in
+``iterative_compress_prompt`` continues to work.
+"""
+
+import unittest
+
+import torch
+
+from llmlingua.prompt_compressor import (
+    DynamicCache,
+    _to_legacy_past_key_values,
+    _to_model_past_key_values,
+)
+
+
+def _make_legacy_cache(num_layers=4, batch=1, num_heads=2, seq_len=5, head_dim=3):
+    """Build a legacy ``past_key_values``: a tuple of (k, v) tensors per layer."""
+    torch.manual_seed(0)
+    return tuple(
+        (
+            torch.randn(batch, num_heads, seq_len, head_dim),
+            torch.randn(batch, num_heads, seq_len, head_dim),
+        )
+        for _ in range(num_layers)
+    )
+
+
+class Issue210Tester(unittest.TestCase):
+    def test_to_legacy_handles_none(self):
+        self.assertIsNone(_to_legacy_past_key_values(None))
+
+    def test_to_model_handles_none(self):
+        self.assertIsNone(_to_model_past_key_values(None))
+
+    def test_legacy_roundtrip_is_iterable_as_kv_pairs(self):
+        """``for k, v in past_key_values`` must continue to work."""
+        legacy = _make_legacy_cache()
+        normalized = _to_legacy_past_key_values(legacy)
+        self.assertIsNotNone(normalized)
+        pairs = [(k, v) for k, v in normalized]
+        self.assertEqual(len(pairs), len(legacy))
+        for (k_new, v_new), (k_old, v_old) in zip(pairs, legacy):
+            self.assertTrue(torch.equal(k_new, k_old))
+            self.assertTrue(torch.equal(v_new, v_old))
+
+    def test_legacy_roundtrip_preserves_shape_access(self):
+        """``past_key_values[0][0].shape[2]`` is used in iterative_compress_prompt."""
+        legacy = _make_legacy_cache(seq_len=7)
+        normalized = _to_legacy_past_key_values(legacy)
+        self.assertEqual(normalized[0][0].shape[2], 7)
+
+    @unittest.skipIf(DynamicCache is None, "transformers.cache_utils not available")
+    def test_dynamic_cache_to_legacy(self):
+        """A DynamicCache must be normalized to a tuple-of-(k, v) per layer."""
+        legacy = _make_legacy_cache()
+        dynamic = _to_model_past_key_values(legacy)
+        self.assertIsInstance(dynamic, DynamicCache)
+
+        normalized = _to_legacy_past_key_values(dynamic)
+        self.assertIsNotNone(normalized)
+        # Must be iterable as (k, v) pairs and shape-addressable.
+        pairs = [(k, v) for k, v in normalized]
+        self.assertEqual(len(pairs), len(legacy))
+        for (k_new, v_new), (k_old, v_old) in zip(pairs, legacy):
+            self.assertTrue(torch.equal(k_new, k_old))
+            self.assertTrue(torch.equal(v_new, v_old))
+
+        # And the shape read used in iterative_compress_prompt must work.
+        self.assertEqual(normalized[0][0].shape[2], legacy[0][0].shape[2])
+
+    @unittest.skipIf(DynamicCache is None, "transformers.cache_utils not available")
+    def test_to_model_is_idempotent_on_dynamic_cache(self):
+        """Passing a DynamicCache through ``_to_model_past_key_values`` must
+        leave it untouched (no redundant rebuild)."""
+        legacy = _make_legacy_cache()
+        dynamic = _to_model_past_key_values(legacy)
+        again = _to_model_past_key_values(dynamic)
+        self.assertIs(again, dynamic)
+
+    @unittest.skipIf(DynamicCache is None, "transformers.cache_utils not available")
+    def test_roundtrip_tuple_to_dynamic_to_tuple(self):
+        """tuple -> DynamicCache -> tuple must preserve content."""
+        legacy = _make_legacy_cache(num_layers=3, seq_len=4)
+        dynamic = _to_model_past_key_values(legacy)
+        back = _to_legacy_past_key_values(dynamic)
+        self.assertEqual(len(back), len(legacy))
+        for (k_new, v_new), (k_old, v_old) in zip(back, legacy):
+            self.assertTrue(torch.equal(k_new, k_old))
+            self.assertTrue(torch.equal(v_new, v_old))
+
+    def test_slice_pattern_from_iterative_compress_prompt(self):
+        """Reproduce the ``[[k[..., :s, :], v[..., :s, :]] for k, v in past]``
+        pattern used by ``iterative_compress_prompt`` on a normalized cache
+        that originally came from a DynamicCache."""
+        legacy = _make_legacy_cache(seq_len=10)
+        dynamic = _to_model_past_key_values(legacy)
+        past = _to_legacy_past_key_values(dynamic)
+
+        s, e = 2, 3
+        sliced = [
+            [
+                torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
+                torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
+            ]
+            for k, v in past
+        ]
+        self.assertEqual(len(sliced), len(legacy))
+        for k, v in sliced:
+            self.assertEqual(k.shape[2], 10 - e)
+            self.assertEqual(v.shape[2], 10 - e)
+
+
+if __name__ == "__main__":
+    unittest.main()