From c0a99936c05ae3a67f1ece07c21116bfb9847a70 Mon Sep 17 00:00:00 2001
From: Ousama Ben Younes
Date: Sat, 11 Apr 2026 18:39:59 +0000
Subject: [PATCH] fix: guard token-count divisions against ZeroDivisionError (#183)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

compress_prompt_llmlingua2 and structured_compress_prompt computed
target_token / n_original_token (and its siblings) in four spots. Under
certain structured-compress payloads — e.g. a compress_json call where
every key is force-reserved or the context resolves to empty strings —
the denominator collapses to zero and the call dies with
"ZeroDivisionError: float division by zero" before the user ever sees a
compressed result.

Wrap each division in an `if n > 0` guard that falls back to rate=1.0
(i.e. keep everything) when no source tokens are available.

Covered sites: structured_compress_prompt (sum of context_tokens_length),
compress_prompt_llmlingua2 (n_original_token at the context-level and
token-level paths, n_reserved_token at the context-level path).

Generated by Claude Code Vibe coded by ousamabenyounes Co-Authored-By: Claude --- llmlingua/prompt_compressor.py | 25 +++++-- tests/test_issue_183.py | 128 +++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 5 deletions(-) create mode 100644 tests/test_issue_183.py diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py index 84e390e..fa78e91 100644 --- a/llmlingua/prompt_compressor.py +++ b/llmlingua/prompt_compressor.py @@ -384,7 +384,12 @@ def structured_compress_prompt( - (question_tokens_length if concate_question else 0) ) else: - rate = target_token / sum(context_tokens_length) + total_context_tokens = sum(context_tokens_length) + rate = ( + target_token / total_context_tokens + if total_context_tokens > 0 + else 1.0 + ) ( context, context_segs, @@ -835,8 +840,10 @@ def compress_prompt_llmlingua2( if target_context >= 0: context_level_rate = min(target_context / len(context), 1.0) if context_level_target_token >= 0: - context_level_rate = min( - context_level_target_token / n_original_token, 1.0 + context_level_rate = ( + min(context_level_target_token / n_original_token, 1.0) + if n_original_token > 0 + else 1.0 ) context_probs, context_words = self.__get_context_prob( @@ -864,7 +871,11 @@ def compress_prompt_llmlingua2( for c in chunks: n_reserved_token += self.get_token_length(c, use_oai_tokenizer=True) if target_token >= 0: - rate = min(target_token / n_reserved_token, 1.0) + rate = ( + min(target_token / n_reserved_token, 1.0) + if n_reserved_token > 0 + else 1.0 + ) if use_token_level_filter: compressed_context, word_list, word_label_list = self.__compress( @@ -922,7 +933,11 @@ def compress_prompt_llmlingua2( return res if target_token > 0: - rate = min(target_token / n_original_token, 1.0) + rate = ( + min(target_token / n_original_token, 1.0) + if n_original_token > 0 + else 1.0 + ) if use_token_level_filter: compressed_context, word_list, word_label_list = self.__compress( diff --git a/tests/test_issue_183.py 
b/tests/test_issue_183.py
new file mode 100644
index 0000000..ebbaf27
--- /dev/null
+++ b/tests/test_issue_183.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2023 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+
+"""Regression test for https://github.com/microsoft/LLMLingua/issues/183
+
+Under some structured-compress configurations (all keys forced, empty values,
+very short payloads), ``compress_prompt_llmlingua2`` and
+``structured_compress_prompt`` used to compute ``target_token / N`` where ``N``
+was the number of source tokens. When ``N`` collapsed to zero the user saw
+``ZeroDivisionError: float division by zero``.
+
+The fix wraps every such division in an ``N > 0`` guard that falls back to a
+``rate`` of ``1.0`` (i.e. keep everything). These tests inspect the AST of
+``prompt_compressor.py`` so they run offline and do not require loading any
+model weights; they also exercise the guard logic directly via small pure
+functions so the runtime semantics are pinned too.
+"""
+
+import ast
+import os
+import unittest
+
+import llmlingua.prompt_compressor as prompt_compressor_module
+
+
+def _safe_rate(target_token, total_tokens, fallback=1.0):
+    """Return the clamped rate, or ``fallback`` when there are no tokens.
+
+    This is a local copy of the ``min(target / n, 1.0) if n > 0 else 1.0``
+    guard; it pins the semantics of that expression and does not call into
+    ``prompt_compressor.py``.
+
+    Args:
+        target_token: Requested number of tokens (numerator of the rate).
+        total_tokens: Number of source tokens (denominator of the rate).
+        fallback: Value returned when ``total_tokens`` is not positive.
+
+    Returns:
+        ``target_token / total_tokens`` capped at ``1.0``, or ``fallback``.
+    """
+    return min(target_token / total_tokens, 1.0) if total_tokens > 0 else fallback
+
+
+class Issue183DivisionGuardRuntimeTest(unittest.TestCase):
+    """Pin the runtime behaviour of the guard expression."""
+
+    def test_guard_returns_fallback_on_zero(self):
+        self.assertEqual(_safe_rate(target_token=10, total_tokens=0), 1.0)
+
+    def test_guard_computes_rate_when_positive(self):
+        self.assertAlmostEqual(
+            _safe_rate(target_token=10, total_tokens=40),
+            0.25,
+        )
+
+    def test_guard_clamps_to_one_when_target_exceeds_total(self):
+        self.assertEqual(_safe_rate(target_token=100, total_tokens=40), 1.0)
+
+    def test_guard_does_not_raise_on_zero(self):
+        try:
+            result = _safe_rate(target_token=1, total_tokens=0)
+        except ZeroDivisionError:  # pragma: no cover
+            self.fail("guard expression raised ZeroDivisionError on 0 tokens")
+        self.assertEqual(result, 1.0)
+
+
+class Issue183DivisionGuardSourceTest(unittest.TestCase):
+    """Pin the shape of the fix against the prompt_compressor source tree."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Parse ``prompt_compressor.py`` once and index every node's parent."""
+        source_path = prompt_compressor_module.__file__
+        with open(source_path, "r", encoding="utf-8") as f:
+            cls.source = f.read()
+        cls.tree = ast.parse(cls.source, filename=os.path.basename(source_path))
+        # ast nodes have no parent pointer; build one so guards can be found
+        # by walking upwards from a division.
+        cls.parents = {
+            child: parent
+            for parent in ast.walk(cls.tree)
+            for child in ast.iter_child_nodes(parent)
+        }
+
+    def _divisions_by(self, divisor_name):
+        """Return every BinOp that divides by the given Name."""
+        return [
+            node
+            for node in ast.walk(self.tree)
+            if isinstance(node, ast.BinOp)
+            and isinstance(node.op, ast.Div)
+            and isinstance(node.right, ast.Name)
+            and node.right.id == divisor_name
+        ]
+
+    def _enclosing_if_exps(self, target):
+        """Yield, innermost first, each IfExp whose body contains ``target``."""
+        child = target
+        parent = self.parents.get(child)
+        while parent is not None:
+            # Only the ``body`` branch is protected by the IfExp's test.
+            if isinstance(parent, ast.IfExp) and parent.body is child:
+                yield parent
+            child, parent = parent, self.parents.get(parent)
+
+    @staticmethod
+    def _is_positive_guard(test, divisor_name):
+        """Return True when ``test`` is exactly ``<divisor_name> > 0``."""
+        return (
+            isinstance(test, ast.Compare)
+            and isinstance(test.left, ast.Name)
+            and test.left.id == divisor_name
+            and len(test.ops) == 1
+            and isinstance(test.ops[0], ast.Gt)
+            and len(test.comparators) == 1
+            and isinstance(test.comparators[0], ast.Constant)
+            and test.comparators[0].value == 0
+        )
+
+    def _assert_guarded(self, divisor_name):
+        """Assert every division by ``divisor_name`` is behind ``<name> > 0``.
+
+        Fails when the name is never used as a divisor, when a division is
+        not in the body of any conditional expression, or when none of the
+        enclosing conditional expressions tests ``<name> > 0``.
+        """
+        divisions = self._divisions_by(divisor_name)
+        self.assertGreaterEqual(
+            len(divisions),
+            1,
+            f"expected at least one division by {divisor_name} in prompt_compressor.py",
+        )
+        for div in divisions:
+            guards = list(self._enclosing_if_exps(div))
+            self.assertTrue(
+                guards,
+                f"division by {divisor_name} at line {div.lineno} is not inside "
+                "an IfExp guard — a ZeroDivisionError can leak through.",
+            )
+            self.assertTrue(
+                any(self._is_positive_guard(g.test, divisor_name) for g in guards),
+                f"division by {divisor_name} at line {div.lineno} is inside an "
+                f"IfExp but the guard is not `{divisor_name} > 0`; got "
+                f"{[ast.dump(g.test) for g in guards]}",
+            )
+
+    def test_n_original_token_divisions_are_guarded(self):
+        self._assert_guarded("n_original_token")
+
+    def test_n_reserved_token_divisions_are_guarded(self):
+        self._assert_guarded("n_reserved_token")
+
+    def test_total_context_tokens_division_is_guarded(self):
+        self._assert_guarded("total_context_tokens")
+
+    def test_issue_183_referenced_in_source_or_tests(self):
+        """Make sure the fix is traceable back to the original issue."""
+        # NOTE: this module's own docstring names the issue, so the check
+        # holds even if prompt_compressor.py itself carries no reference.
+        with open(__file__, "r", encoding="utf-8") as f:
+            combined = self.source + f.read()
+        self.assertIn("183", combined)
+
+
+if __name__ == "__main__":
+    unittest.main()