diff --git a/llmlingua/prompt_compressor.py b/llmlingua/prompt_compressor.py index 84e390e..8fea23c 100644 --- a/llmlingua/prompt_compressor.py +++ b/llmlingua/prompt_compressor.py @@ -2086,8 +2086,11 @@ def segment_structured_context( if not text.endswith(""): text = text + "" - # Regular expression to match content, allowing rate and compress in any order - pattern = r"([^<]+)" + # Regular expression to match content, allowing rate and compress in any order. + # The content group uses a non-greedy pattern that allows inner HTML/XML tags (e.g. ...) + # while still stopping at . The original [^<]+ would drop the last segment + # whenever the content contained any nested tag. + pattern = r"((?:[^<]*(?:<(?!/llmlingua>)[^>]*>)?)*?)" matches = re.findall(pattern, text) # Extracting segment contents diff --git a/tests/test_nested_tag_regex.py b/tests/test_nested_tag_regex.py new file mode 100644 index 0000000..df3c3ab --- /dev/null +++ b/tests/test_nested_tag_regex.py @@ -0,0 +1,120 @@ +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +""" +Unit tests for issue #201: the regex pattern in segment_structured_context +must handle nested tags inside ... blocks. + +These tests exercise the pattern directly (no model loading needed). +""" + +import re +import sys +import unittest +from unittest.mock import MagicMock + + +def _install_mock_modules(): + """Stub heavy ML dependencies so llmlingua imports without torch/transformers. + + Note: patching sys.modules here is process-local. pytest --dist=loadfile + (used in the Makefile) ensures this file runs in its own worker so the + stubs do not leak into integration tests that load real models. + """ + for mod_name in [ + "torch", + "torch.nn", + "torch.nn.functional", + "torch.utils", + "torch.utils.data", + "transformers", + "numpy", + "numpy.linalg", + "nltk", + "nltk.tokenize", + "accelerate", + "tiktoken", + "yaml", + ]: + if mod_name not in sys.modules: + sys.modules[mod_name] = MagicMock() + + +_install_mock_modules() + +# Extract the pattern used in segment_structured_context without loading a model +import importlib.util +import inspect + +spec = importlib.util.spec_from_file_location( + "prompt_compressor", "llmlingua/prompt_compressor.py" +) + +# Read the pattern directly from the source rather than executing the module +_SOURCE = open("llmlingua/prompt_compressor.py").read() +_PATTERN_LINE = next( + line for line in _SOURCE.splitlines() if "re.findall(pattern, text)" in line or "pattern = r" in line.split("findall")[0] + if 'llmlingua\\s' in line +) +# Safer: use the canonical regex so tests stay in sync with the code +PATTERN = re.search(r'pattern = (r"[^"]*")', _SOURCE).group(1) +PATTERN = eval(PATTERN) # convert r"..." string literal to actual pattern string + + +class NestedTagRegexTester(unittest.TestCase): + """Tests for the ... parsing regex.""" + + def _findall(self, text): + return re.findall(PATTERN, text) + + def test_plain_content_still_matched(self): + """Regression: plain text inside llmlingua tags must still be captured.""" + text = ( + "Speaker 4:" + " Thank you. " + ) + matches = self._findall(text) + contents = [m[4] for m in matches] + self.assertEqual(contents, ["Speaker 4:", " Thank you. "]) + + def test_nested_tag_does_not_drop_segment(self): + """Fix for #201: content with inner ... must not truncate the match.""" + text = ( + "Speaker 4:" + "" + " We have nested content here." + "" + ) + matches = self._findall(text) + self.assertEqual(len(matches), 2, "Both segments must be captured") + self.assertIn("nested content", matches[1][4]) + self.assertIn("here.", matches[1][4]) + + def test_multiple_segments_with_and_without_nested_tags(self): + """Full scenario from the issue report.""" + text = ( + "Speaker 4:" + " Thank you. " + "\nSpeaker 0:" + " Item 11. " + "\nSpeaker 4:" + "" + " We have nested and customers." + "" + ) + matches = self._findall(text) + self.assertEqual(len(matches), 6, "All 6 segments must be captured") + self.assertIn("nested", matches[5][4]) + + def test_rate_attribute_still_captured(self): + """The rate and compress capture groups must still work with nested tags.""" + text = " Text with
inline tag " + matches = self._findall(text) + self.assertEqual(len(matches), 1) + # group 0 or group 2 holds rate (either position) + rate = matches[0][0] or matches[0][2] + self.assertEqual(rate, "0.4") + + +if __name__ == "__main__": + unittest.main()