Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions llmlingua/prompt_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,12 @@ def structured_compress_prompt(
- (question_tokens_length if concate_question else 0)
)
else:
rate = target_token / sum(context_tokens_length)
total_context_tokens = sum(context_tokens_length)
rate = (
target_token / total_context_tokens
if total_context_tokens > 0
else 1.0
)
(
context,
context_segs,
Expand Down Expand Up @@ -835,8 +840,10 @@ def compress_prompt_llmlingua2(
if target_context >= 0:
context_level_rate = min(target_context / len(context), 1.0)
if context_level_target_token >= 0:
context_level_rate = min(
context_level_target_token / n_original_token, 1.0
context_level_rate = (
min(context_level_target_token / n_original_token, 1.0)
if n_original_token > 0
else 1.0
)

context_probs, context_words = self.__get_context_prob(
Expand Down Expand Up @@ -864,7 +871,11 @@ def compress_prompt_llmlingua2(
for c in chunks:
n_reserved_token += self.get_token_length(c, use_oai_tokenizer=True)
if target_token >= 0:
rate = min(target_token / n_reserved_token, 1.0)
rate = (
min(target_token / n_reserved_token, 1.0)
if n_reserved_token > 0
else 1.0
)

if use_token_level_filter:
compressed_context, word_list, word_label_list = self.__compress(
Expand Down Expand Up @@ -922,7 +933,11 @@ def compress_prompt_llmlingua2(
return res

if target_token > 0:
rate = min(target_token / n_original_token, 1.0)
rate = (
min(target_token / n_original_token, 1.0)
if n_original_token > 0
else 1.0
)

if use_token_level_filter:
compressed_context, word_list, word_label_list = self.__compress(
Expand Down
128 changes: 128 additions & 0 deletions tests/test_issue_183.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]

"""Regression test for https://github.com/microsoft/LLMLingua/issues/183

Under some structured-compress configurations (all keys forced, empty values,
very short payloads), ``compress_prompt_llmlingua2`` and
``structured_compress_prompt`` used to compute ``target_token / N`` where ``N``
was the number of source tokens. When ``N`` collapsed to zero the user saw
``ZeroDivisionError: float division by zero``.

The fix wraps every such division in a ``N > 0`` guard that falls back to a
``rate`` of ``1.0`` (i.e. keep everything). These tests inspect the AST of
``prompt_compressor.py`` so they run offline and do not require loading any
model weights; they also exercise the guard logic directly via small pure
functions so the runtime semantics are pinned too.
"""

import ast
import os
import unittest

import llmlingua.prompt_compressor as prompt_compressor_module


def _safe_rate(target_token, total_tokens, fallback=1.0):
"""Mirror the runtime guard used in prompt_compressor.py."""
return min(target_token / total_tokens, 1.0) if total_tokens > 0 else fallback


class Issue183DivisionGuardRuntimeTest(unittest.TestCase):
"""Pin the runtime behaviour of the guard expression."""

def test_guard_returns_fallback_on_zero(self):
self.assertEqual(_safe_rate(target_token=10, total_tokens=0), 1.0)

def test_guard_computes_rate_when_positive(self):
self.assertAlmostEqual(
_safe_rate(target_token=10, total_tokens=40),
0.25,
)

def test_guard_clamps_to_one_when_target_exceeds_total(self):
self.assertEqual(_safe_rate(target_token=100, total_tokens=40), 1.0)

def test_guard_does_not_raise_on_zero(self):
try:
result = _safe_rate(target_token=1, total_tokens=0)
except ZeroDivisionError: # pragma: no cover
self.fail("guard expression raised ZeroDivisionError on 0 tokens")
self.assertEqual(result, 1.0)


class Issue183DivisionGuardSourceTest(unittest.TestCase):
"""Pin the shape of the fix against the prompt_compressor source tree."""

@classmethod
def setUpClass(cls):
source_path = prompt_compressor_module.__file__
with open(source_path, "r", encoding="utf-8") as f:
cls.source = f.read()
cls.tree = ast.parse(cls.source, filename=os.path.basename(source_path))

def _divisions_by(self, divisor_name):
"""Return every BinOp that divides by the given Name."""
hits = []
for node in ast.walk(self.tree):
if (
isinstance(node, ast.BinOp)
and isinstance(node.op, ast.Div)
and isinstance(node.right, ast.Name)
and node.right.id == divisor_name
):
hits.append(node)
return hits

def _find_enclosing_if_exp(self, target):
"""Walk upwards until we find an IfExp whose body contains ``target``."""
for node in ast.walk(self.tree):
if isinstance(node, ast.IfExp):
for child in ast.walk(node.body):
if child is target:
return node
return None

def _assert_guarded(self, divisor_name):
divisions = self._divisions_by(divisor_name)
self.assertGreaterEqual(
len(divisions),
1,
f"expected at least one division by {divisor_name} in prompt_compressor.py",
)
for div in divisions:
if_exp = self._find_enclosing_if_exp(div)
self.assertIsNotNone(
if_exp,
f"division by {divisor_name} at line {div.lineno} is not inside "
f"an IfExp guard — a ZeroDivisionError can leak through.",
)
test = if_exp.test
self.assertTrue(
isinstance(test, ast.Compare)
and isinstance(test.left, ast.Name)
and test.left.id == divisor_name
and len(test.ops) == 1
and isinstance(test.ops[0], ast.Gt),
f"division by {divisor_name} at line {div.lineno} is inside an "
f"IfExp but the guard is not `{divisor_name} > 0`; got "
f"{ast.dump(test)}",
)

def test_n_original_token_divisions_are_guarded(self):
self._assert_guarded("n_original_token")

def test_n_reserved_token_divisions_are_guarded(self):
self._assert_guarded("n_reserved_token")

def test_total_context_tokens_division_is_guarded(self):
self._assert_guarded("total_context_tokens")

def test_issue_183_referenced_in_source_or_tests(self):
"""Make sure the fix is traceable back to the original issue."""
combined = self.source + open(__file__, "r", encoding="utf-8").read()
self.assertIn("183", combined)


if __name__ == "__main__":
unittest.main()