Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 85 additions & 2 deletions llmlingua/prompt_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,82 @@
AutoTokenizer,
)

# ``DynamicCache`` replaced the legacy tuple-of-(k, v) ``past_key_values``
# format in transformers >= 4.36. Older transformers versions do not ship
# it, so fall back to ``None`` — the helpers below then pass legacy tuples
# through untouched.
try:
    from transformers.cache_utils import DynamicCache
except ImportError:
    DynamicCache = None


def _dynamic_cache_to_legacy(cache):
"""Convert a ``DynamicCache`` to a legacy tuple-of-(k, v) list.

Works across transformers versions: older ones expose
``to_legacy_cache()``; newer ones (>= 5.0) expose a ``layers`` attribute
whose entries have ``keys``/``values`` tensors.
"""
if hasattr(cache, "to_legacy_cache"):
try:
return cache.to_legacy_cache()
except Exception:
pass
if hasattr(cache, "layers"):
return [(layer.keys, layer.values) for layer in cache.layers]
# Last-resort: assume it's already iterable as (k, v) pairs.
return list(cache)


def _legacy_to_dynamic_cache(past_key_values):
    """Build a ``DynamicCache`` from a legacy tuple-of-(k, v) list."""
    if DynamicCache is None:
        # transformers predates DynamicCache: keep the legacy tuple as-is.
        return past_key_values
    legacy_pairs = tuple((key, value) for key, value in past_key_values)
    if hasattr(DynamicCache, "from_legacy_cache"):
        return DynamicCache.from_legacy_cache(legacy_pairs)
    # Transformers >= 5.0 dropped ``from_legacy_cache``; the constructor
    # accepts the data through the ``ddp_cache_data`` keyword instead.
    try:
        return DynamicCache(ddp_cache_data=legacy_pairs)
    except TypeError:
        # Extremely old / unusual API surface — rebuild layer by layer.
        rebuilt = DynamicCache()
        for layer_idx, (key, value) in enumerate(legacy_pairs):
            rebuilt.update(key, value, layer_idx)
        return rebuilt


def _to_legacy_past_key_values(past_key_values):
"""Normalize ``past_key_values`` from the transformers model output to the
legacy tuple-of-(k, v) format that the rest of this module expects.

Transformers >= 4.36 returns a ``DynamicCache`` instead of a tuple, which
breaks ``for k, v in past_key_values`` loops and
``past_key_values[0][0].shape[2]`` accesses used by
``iterative_compress_prompt``. See microsoft/LLMLingua#210.
"""
if past_key_values is None:
return None
if DynamicCache is not None and isinstance(past_key_values, DynamicCache):
return _dynamic_cache_to_legacy(past_key_values)
if hasattr(past_key_values, "to_legacy_cache"):
return past_key_values.to_legacy_cache()
if hasattr(past_key_values, "layers"):
return _dynamic_cache_to_legacy(past_key_values)
return past_key_values


def _to_model_past_key_values(past_key_values):
"""Convert a legacy tuple-of-(k, v) cache back into the object the model
expects. On newer transformers this is a ``DynamicCache``; on older
versions the legacy tuple is still accepted, so we pass it through.
"""
if past_key_values is None:
return None
if DynamicCache is None:
return past_key_values
if isinstance(past_key_values, DynamicCache):
return past_key_values
return _legacy_to_dynamic_cache(past_key_values)


from .utils import (
TokenClfDataset,
get_pure_token,
Expand Down Expand Up @@ -178,6 +254,10 @@ def get_ppl(
tokenized_text = self.tokenizer(text, return_tensors="pt")
input_ids = tokenized_text["input_ids"].to(self.device)
attention_mask = tokenized_text["attention_mask"].to(self.device)
# Normalize any cache the caller hands us to the legacy tuple format
# so the ``past_key_values[0][0].shape[2]`` read below stays valid
# even when transformers >= 4.36 returned a DynamicCache. See #210.
past_key_values = _to_legacy_past_key_values(past_key_values)
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
else:
Expand All @@ -189,10 +269,13 @@ def get_ppl(
response = self.model(
input_ids[:, past_length:end],
attention_mask=attention_mask[:, :end],
past_key_values=past_key_values,
past_key_values=_to_model_past_key_values(past_key_values),
use_cache=True,
)
past_key_values = response.past_key_values
# Normalize the returned cache back to the legacy tuple format so
# ``for k, v in past_key_values`` loops in iterative_compress_prompt
# continue to work on newer transformers.
past_key_values = _to_legacy_past_key_values(response.past_key_values)

shift_logits = response.logits[..., :-1, :].contiguous()
shift_labels = input_ids[..., past_length + 1 : end].contiguous()
Expand Down
127 changes: 127 additions & 0 deletions tests/test_issue_210.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Copyright (c) 2023-2025 Microsoft
# Licensed under The MIT License [see LICENSE for details]

"""Regression tests for microsoft/LLMLingua#210.

llmlingua's ``iterative_compress_prompt`` assumes ``past_key_values`` is a
tuple-of-(k, v) per layer: it slices ``past_key_values[0][0].shape[2]`` to
read the cached length and iterates ``for k, v in past_key_values`` four
separate times.

Transformers >= 4.36 now returns a ``DynamicCache`` object instead of a
tuple. This test verifies the normalization helpers in
``llmlingua.prompt_compressor`` correctly round-trip between the two
representations so that the iteration/slicing code in
``iterative_compress_prompt`` continues to work.
"""

import unittest

import torch

from llmlingua.prompt_compressor import (
DynamicCache,
_to_legacy_past_key_values,
_to_model_past_key_values,
)


def _make_legacy_cache(num_layers=4, batch=1, num_heads=2, seq_len=5, head_dim=3):
"""Build a legacy ``past_key_values``: a tuple of (k, v) tensors per layer."""
torch.manual_seed(0)
return tuple(
(
torch.randn(batch, num_heads, seq_len, head_dim),
torch.randn(batch, num_heads, seq_len, head_dim),
)
for _ in range(num_layers)
)


class Issue210Tester(unittest.TestCase):
    def _assert_kv_equal(self, observed, expected):
        """Assert two legacy caches hold identical (k, v) tensors per layer."""
        observed_pairs = list(observed)
        self.assertEqual(len(observed_pairs), len(expected))
        for (k_obs, v_obs), (k_exp, v_exp) in zip(observed_pairs, expected):
            self.assertTrue(torch.equal(k_obs, k_exp))
            self.assertTrue(torch.equal(v_obs, v_exp))

    def test_to_legacy_handles_none(self):
        self.assertIsNone(_to_legacy_past_key_values(None))

    def test_to_model_handles_none(self):
        self.assertIsNone(_to_model_past_key_values(None))

    def test_legacy_roundtrip_is_iterable_as_kv_pairs(self):
        """``for k, v in past_key_values`` must continue to work."""
        original = _make_legacy_cache()
        normalized = _to_legacy_past_key_values(original)
        self.assertIsNotNone(normalized)
        self._assert_kv_equal(normalized, original)

    def test_legacy_roundtrip_preserves_shape_access(self):
        """``past_key_values[0][0].shape[2]`` is used in iterative_compress_prompt."""
        normalized = _to_legacy_past_key_values(_make_legacy_cache(seq_len=7))
        self.assertEqual(normalized[0][0].shape[2], 7)

    @unittest.skipIf(DynamicCache is None, "transformers.cache_utils not available")
    def test_dynamic_cache_to_legacy(self):
        """A DynamicCache must be normalized to a tuple-of-(k, v) per layer."""
        original = _make_legacy_cache()
        dynamic = _to_model_past_key_values(original)
        self.assertIsInstance(dynamic, DynamicCache)

        normalized = _to_legacy_past_key_values(dynamic)
        self.assertIsNotNone(normalized)
        # Must be iterable as (k, v) pairs with the original contents...
        self._assert_kv_equal(normalized, original)
        # ...and the shape read used in iterative_compress_prompt must work.
        self.assertEqual(normalized[0][0].shape[2], original[0][0].shape[2])

    @unittest.skipIf(DynamicCache is None, "transformers.cache_utils not available")
    def test_to_model_is_idempotent_on_dynamic_cache(self):
        """A DynamicCache fed to ``_to_model_past_key_values`` must come back
        as the very same object (no redundant rebuild)."""
        dynamic = _to_model_past_key_values(_make_legacy_cache())
        self.assertIs(_to_model_past_key_values(dynamic), dynamic)

    @unittest.skipIf(DynamicCache is None, "transformers.cache_utils not available")
    def test_roundtrip_tuple_to_dynamic_to_tuple(self):
        """tuple -> DynamicCache -> tuple must preserve content."""
        original = _make_legacy_cache(num_layers=3, seq_len=4)
        restored = _to_legacy_past_key_values(_to_model_past_key_values(original))
        self._assert_kv_equal(restored, original)

    def test_slice_pattern_from_iterative_compress_prompt(self):
        """Reproduce the ``[[k[..., :s, :], v[..., :s, :]] for k, v in past]``
        pattern used by ``iterative_compress_prompt`` on a normalized cache
        that originally came from a DynamicCache."""
        original = _make_legacy_cache(seq_len=10)
        past = _to_legacy_past_key_values(_to_model_past_key_values(original))

        start, width = 2, 3
        trimmed = [
            [
                torch.cat([k[..., :start, :], k[..., start + width :, :]], dim=-2),
                torch.cat([v[..., :start, :], v[..., start + width :, :]], dim=-2),
            ]
            for k, v in past
        ]
        self.assertEqual(len(trimmed), len(original))
        for k, v in trimmed:
            self.assertEqual(k.shape[2], 10 - width)
            self.assertEqual(v.shape[2], 10 - width)


# Allow running this regression suite directly:
#   python tests/test_issue_210.py
if __name__ == "__main__":
    unittest.main()