Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions llmlingua/prompt_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,6 +1414,8 @@ def get_compressed_input(
self_input_ids=None,
self_attention_mask=None,
):
if end < iterative_size:
end = iterative_size
if self_loss is not None:
need_idx = torch.concat(
[
Expand Down
160 changes: 160 additions & 0 deletions tests/test_issue_196.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]

"""Regression test for https://github.com/microsoft/LLMLingua/issues/196

LLMLingua / LongLLMLingua's ``iterative_compress_prompt`` calls
``get_compressed_input(..., end=prompt_len, iterative_size=200)``. When the
prompt is shorter than ``iterative_size``, ``end - iterative_size`` goes
negative, and ``need_idx[: end - iterative_size] = 1`` no longer marks the
intended prefix (a negative slice bound counts back from the end and clamps
to an empty slice here), so the thresholding decision is effectively
discarded — no tokens are actually dropped and the achieved compression
rate collapses to 1.0x.

The fix clamps ``end`` to at least ``iterative_size`` at the top of
``get_compressed_input``. This test exercises the function directly with a
``PromptCompressor.__new__`` shim so no model weights are loaded.
"""

import unittest

import torch

from llmlingua import PromptCompressor


def _bare_compressor():
    """Create a ``PromptCompressor`` shim without invoking ``__init__``.

    Bypassing ``__init__`` skips model/tokenizer loading, so the returned
    instance is only suitable for calling methods that do not touch model
    state (like ``get_compressed_input``).
    """
    shim = PromptCompressor.__new__(PromptCompressor)
    return shim


class Issue196SmallPromptCompressionTest(unittest.TestCase):
    def _run(self, loss, input_ids, end, iterative_size, threshold):
        """Drive ``get_compressed_input`` on a bare compressor instance.

        Returns whatever the method returns (first element is the
        compressed ``input_ids`` tensor).
        """
        pc = _bare_compressor()
        mask = torch.ones_like(input_ids)
        return pc.get_compressed_input(
            loss=loss,
            input_ids=input_ids,
            attention_mask=mask,
            end=end,
            iterative_size=iterative_size,
            threshold=threshold,
        )

    def test_short_prompt_actually_compresses(self):
        """With the default ``iterative_size=200``, a 50-token prompt must
        still drop its below-threshold tokens. Prior to the fix every token
        survived — the achieved compression rate was 1.0x no matter what
        the thresholding decided."""
        prompt_len = 50
        iterative_size = 200
        # Alternate high/low losses: even positions sit far above the
        # threshold, odd positions far below. Pre-fix, the negative
        # ``end - iterative_size`` slice bound corrupted the keep
        # bookkeeping and every position survived (issue #196).
        loss = torch.zeros(prompt_len, dtype=torch.float32)
        loss[::2] = 10.0
        input_ids = torch.arange(prompt_len).reshape(1, -1)

        result = self._run(
            loss=loss,
            input_ids=input_ids,
            end=prompt_len,
            iterative_size=iterative_size,
            threshold=1.0,
        )
        compressed_input_ids = result[0]

        self.assertLess(
            compressed_input_ids.shape[1],
            prompt_len,
            "short prompt was not compressed at all — issue #196 regression",
        )
        # Only above-threshold (even-indexed) tokens may survive; the
        # token ids are arange, so ids double as positions.
        kept = set(compressed_input_ids[0].tolist())
        odd_indices = set(range(1, prompt_len, 2))
        leaked = kept & odd_indices
        self.assertFalse(
            leaked,
            f"below-threshold tokens leaked into the compressed output: {sorted(leaked)}",
        )

    def test_prompt_exactly_at_iterative_size_still_compresses(self):
        """Boundary case: prompt length exactly equal to ``iterative_size``."""
        prompt_len = 200
        iterative_size = 200
        # Every third position is above the threshold; the rest must drop.
        loss = torch.zeros(prompt_len, dtype=torch.float32)
        loss[::3] = 10.0
        input_ids = torch.arange(prompt_len).reshape(1, -1)

        result = self._run(
            loss=loss,
            input_ids=input_ids,
            end=prompt_len,
            iterative_size=iterative_size,
            threshold=1.0,
        )
        self.assertLess(result[0].shape[1], prompt_len)

    def test_prompt_longer_than_iterative_size_is_untouched_by_fix(self):
        """Prompts longer than ``iterative_size`` must behave exactly as
        before — the clamp is a no-op for them."""
        prompt_len = 300
        iterative_size = 200
        # First 100 positions carry low loss but fall in the force-keep
        # prefix (everything before ``end - iterative_size``); the last 200
        # carry high loss and are kept by the threshold. Net: nothing drops.
        loss = torch.cat(
            [
                torch.zeros(100, dtype=torch.float32),
                torch.full((200,), 10.0, dtype=torch.float32),
            ]
        )
        input_ids = torch.arange(prompt_len).reshape(1, -1)

        result = self._run(
            loss=loss,
            input_ids=input_ids,
            end=prompt_len,
            iterative_size=iterative_size,
            threshold=1.0,
        )
        self.assertEqual(result[0].shape[1], prompt_len)

    def test_very_short_prompt_below_66_tokens(self):
        """Issue #196 specifically reports that prompts under ~66 tokens
        were never compressed; pin that boundary."""
        prompt_len = 40
        iterative_size = 200
        # First half above threshold, second half below.
        loss = torch.zeros(prompt_len, dtype=torch.float32)
        loss[:20] = 10.0
        input_ids = torch.arange(prompt_len).reshape(1, -1)

        result = self._run(
            loss=loss,
            input_ids=input_ids,
            end=prompt_len,
            iterative_size=iterative_size,
            threshold=1.0,
        )
        compressed_input_ids = result[0]
        self.assertLess(compressed_input_ids.shape[1], prompt_len)
        survivors = set(compressed_input_ids[0].tolist())
        below_threshold = set(range(20, prompt_len))
        self.assertFalse(
            survivors & below_threshold,
            "below-threshold tokens leaked in — #196 regression at <66 tokens",
        )


# Allow running this regression test directly: ``python tests/test_issue_196.py``.
if __name__ == "__main__":
    unittest.main()