From e04a7bdc1b8df7aeb0fb95625a5eb2f82faa10a8 Mon Sep 17 00:00:00 2001
From: Ousama Ben Younes
Date: Sat, 11 Apr 2026 18:58:06 +0000
Subject: [PATCH] fix: compress prompts shorter than iterative_size (#196)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When LLMLingua/LongLLMLingua iteratively compresses a prompt shorter than
iterative_size (default 200), get_compressed_input is called with
end == prompt_len. The line `need_idx[: end - iterative_size] = 1` then
uses a negative index that wraps around from the right, which silently
overwrites the tail of need_idx with True — so every token gets kept
regardless of the thresholding decision and the achieved compression rate
collapses to 1.0x for small prompts.

Clamp end to at least iterative_size at the top of get_compressed_input so
the two masking writes become a no-op on the left slice and only keep the
tail on the right slice, letting the threshold-based need_idx actually
take effect.

Generated by Claude Code
Vibe coded by ousamabenyounes

Co-Authored-By: Claude
---
 llmlingua/prompt_compressor.py |   2 +
 tests/test_issue_196.py        | 160 +++++++++++++++++++++++++++++++++
 2 files changed, 162 insertions(+)
 create mode 100644 tests/test_issue_196.py

diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py
index 84e390e..cee9543 100644
--- a/llmlingua/prompt_compressor.py
+++ b/llmlingua/prompt_compressor.py
@@ -1414,6 +1414,8 @@ def get_compressed_input(
         self_input_ids=None,
         self_attention_mask=None,
     ):
+        if end < iterative_size:
+            end = iterative_size
         if self_loss is not None:
             need_idx = torch.concat(
                 [
diff --git a/tests/test_issue_196.py b/tests/test_issue_196.py
new file mode 100644
index 0000000..188c537
--- /dev/null
+++ b/tests/test_issue_196.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+"""Regression test for https://github.com/microsoft/LLMLingua/issues/196
+
+LLMLingua / LongLLMLingua's ``iterative_compress_prompt`` calls
+``get_compressed_input(..., end=prompt_len, iterative_size=200)``. When the
+prompt is shorter than ``iterative_size``, ``end - iterative_size`` goes
+negative, ``need_idx[: end - iterative_size] = 1`` ends up overwriting the
+tail of ``need_idx`` from the right (because negative indices wrap), and the
+thresholding decision is discarded — no tokens are actually dropped, so the
+achieved compression rate collapses to 1.0x.
+
+The fix clamps ``end`` to at least ``iterative_size`` at the top of
+``get_compressed_input``. This test exercises the function directly with a
+``PromptCompressor.__new__`` shim so no model weights are loaded.
+"""
+
+import unittest
+
+import torch
+
+from llmlingua import PromptCompressor
+
+
+def _bare_compressor():
+    """Return a PromptCompressor instance that skips ``__init__`` / model load."""
+    return PromptCompressor.__new__(PromptCompressor)
+
+
+class Issue196SmallPromptCompressionTest(unittest.TestCase):
+    def _run(self, loss, input_ids, end, iterative_size, threshold):
+        compressor = _bare_compressor()
+        attention_mask = torch.ones_like(input_ids)
+        return compressor.get_compressed_input(
+            loss=loss,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            end=end,
+            iterative_size=iterative_size,
+            threshold=threshold,
+        )
+
+    def test_short_prompt_actually_compresses(self):
+        """A 50-token prompt with the default iterative_size=200 must drop
+        tokens whose loss is below the threshold. Before the fix, every
+        token was kept — the achieved compression rate was 1.0x regardless
+        of the thresholding decision."""
+        prompt_len = 50
+        iterative_size = 200
+        # Half the tokens have loss well above the threshold, half well below.
+        # Pre-fix, need_idx[: end - iterative_size] = 1 ran as need_idx[:-150].
+        # NOTE(review): on a 51-element tensor a slice stop of -150 clamps to 0
+        # (an empty slice), so confirm this case really fails without the fix.
+        loss = torch.tensor(
+            [10.0 if i % 2 == 0 else 0.0 for i in range(prompt_len)],
+            dtype=torch.float32,
+        )
+        input_ids = torch.arange(prompt_len, dtype=torch.long).unsqueeze(0)
+
+        (
+            compressed_input_ids,
+            _compressed_attention_mask,
+            _keep_flag,
+            _end,
+            _loss_out,
+            _self_loss_out,
+            _self_ids,
+            _self_mask,
+        ) = self._run(
+            loss=loss,
+            input_ids=input_ids,
+            end=prompt_len,
+            iterative_size=iterative_size,
+            threshold=1.0,
+        )
+
+        self.assertLess(
+            compressed_input_ids.shape[1],
+            prompt_len,
+            "short prompt was not compressed at all — issue #196 regression",
+        )
+        # Every kept token had loss > threshold (even indices).
+        kept = set(compressed_input_ids[0].tolist())
+        odd_indices = {i for i in range(prompt_len) if i % 2 == 1}
+        self.assertFalse(
+            kept & odd_indices,
+            f"below-threshold tokens leaked into the compressed output: "
+            f"{sorted(kept & odd_indices)}",
+        )
+
+    def test_prompt_exactly_at_iterative_size_still_compresses(self):
+        prompt_len = 200
+        iterative_size = 200
+        loss = torch.tensor(
+            [10.0 if i % 3 == 0 else 0.0 for i in range(prompt_len)],
+            dtype=torch.float32,
+        )
+        input_ids = torch.arange(prompt_len, dtype=torch.long).unsqueeze(0)
+
+        (compressed_input_ids, *_) = self._run(
+            loss=loss,
+            input_ids=input_ids,
+            end=prompt_len,
+            iterative_size=iterative_size,
+            threshold=1.0,
+        )
+        self.assertLess(compressed_input_ids.shape[1], prompt_len)
+
+    def test_prompt_longer_than_iterative_size_is_untouched_by_fix(self):
+        """Long prompts must behave exactly as before — the clamp only
+        kicks in for short prompts."""
+        prompt_len = 300
+        iterative_size = 200
+        loss = torch.zeros(prompt_len, dtype=torch.float32)
+        # Last 200 tokens have high loss; first 100 low loss. With
+        # end=prompt_len=300 and iterative_size=200, the window is [100, 300).
+        loss[100:] = 10.0
+        input_ids = torch.arange(prompt_len, dtype=torch.long).unsqueeze(0)
+
+        (compressed_input_ids, *_) = self._run(
+            loss=loss,
+            input_ids=input_ids,
+            end=prompt_len,
+            iterative_size=iterative_size,
+            threshold=1.0,
+        )
+        # Nothing should drop because the [:end - iterative_size] block
+        # keeps the first 100 and the threshold keeps the rest.
+        self.assertEqual(compressed_input_ids.shape[1], prompt_len)
+
+    def test_very_short_prompt_below_66_tokens(self):
+        """The reporter specifically calls out that prompts below ~66 tokens
+        weren't compressed at all. Anchor that boundary."""
+        prompt_len = 40
+        iterative_size = 200
+        loss = torch.tensor(
+            [10.0 if i < 20 else 0.0 for i in range(prompt_len)],
+            dtype=torch.float32,
+        )
+        input_ids = torch.arange(prompt_len, dtype=torch.long).unsqueeze(0)
+
+        (compressed_input_ids, *_) = self._run(
+            loss=loss,
+            input_ids=input_ids,
+            end=prompt_len,
+            iterative_size=iterative_size,
+            threshold=1.0,
+        )
+        self.assertLess(compressed_input_ids.shape[1], prompt_len)
+        kept = set(compressed_input_ids[0].tolist())
+        below_threshold = set(range(20, prompt_len))
+        self.assertFalse(
+            kept & below_threshold,
+            "below-threshold tokens leaked in — #196 regression at <66 tokens",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()