Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions llmlingua/prompt_compressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2086,8 +2086,11 @@ def segment_structured_context(
if not text.endswith("</llmlingua>"):
text = text + "</llmlingua>"

# Regular expression to match <llmlingua, rate=x, compress=y>content</llmlingua>, allowing rate and compress in any order
pattern = r"<llmlingua\s*(?:,\s*rate\s*=\s*([\d\.]+))?\s*(?:,\s*compress\s*=\s*(True|False))?\s*(?:,\s*rate\s*=\s*([\d\.]+))?\s*(?:,\s*compress\s*=\s*(True|False))?\s*>([^<]+)</llmlingua>"
# Regular expression to match <llmlingua, rate=x, compress=y>content</llmlingua>, allowing rate and compress in any order.
# The content group uses a non-greedy pattern that allows inner HTML/XML tags (e.g. <tag>...</tag>)
# while still stopping at </llmlingua>. The original [^<]+ would drop the last segment
# whenever the content contained any nested tag.
pattern = r"<llmlingua\s*(?:,\s*rate\s*=\s*([\d\.]+))?\s*(?:,\s*compress\s*=\s*(True|False))?\s*(?:,\s*rate\s*=\s*([\d\.]+))?\s*(?:,\s*compress\s*=\s*(True|False))?\s*>((?:[^<]*(?:<(?!/llmlingua>)[^>]*>)?)*?)</llmlingua>"
matches = re.findall(pattern, text)

# Extracting segment contents
Expand Down
120 changes: 120 additions & 0 deletions tests/test_nested_tag_regex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]

"""
Unit tests for issue #201: the regex pattern in segment_structured_context
must handle nested tags inside <llmlingua>...</llmlingua> blocks.

These tests exercise the pattern directly (no model loading needed).
"""

import re
import sys
import unittest
from unittest.mock import MagicMock


def _install_mock_modules():
"""Stub heavy ML dependencies so llmlingua imports without torch/transformers.

Note: patching sys.modules here is process-local. pytest --dist=loadfile
(used in the Makefile) ensures this file runs in its own worker so the
stubs do not leak into integration tests that load real models.
"""
for mod_name in [
"torch",
"torch.nn",
"torch.nn.functional",
"torch.utils",
"torch.utils.data",
"transformers",
"numpy",
"numpy.linalg",
"nltk",
"nltk.tokenize",
"accelerate",
"tiktoken",
"yaml",
]:
if mod_name not in sys.modules:
sys.modules[mod_name] = MagicMock()


_install_mock_modules()

# Extract the pattern used in segment_structured_context without loading a model
import importlib.util
import inspect

spec = importlib.util.spec_from_file_location(
"prompt_compressor", "llmlingua/prompt_compressor.py"
)

# Read the pattern directly from the source rather than executing the module
_SOURCE = open("llmlingua/prompt_compressor.py").read()
_PATTERN_LINE = next(
line for line in _SOURCE.splitlines() if "re.findall(pattern, text)" in line or "pattern = r" in line.split("findall")[0]
if 'llmlingua\\s' in line
)
# Safer: use the canonical regex so tests stay in sync with the code
PATTERN = re.search(r'pattern = (r"[^"]*")', _SOURCE).group(1)
PATTERN = eval(PATTERN) # convert r"..." string literal to actual pattern string


class NestedTagRegexTester(unittest.TestCase):
"""Tests for the <llmlingua>...</llmlingua> parsing regex."""

def _findall(self, text):
return re.findall(PATTERN, text)

def test_plain_content_still_matched(self):
"""Regression: plain text inside llmlingua tags must still be captured."""
text = (
"<llmlingua, compress=False>Speaker 4:</llmlingua>"
"<llmlingua, rate=0.4> Thank you. </llmlingua>"
)
matches = self._findall(text)
contents = [m[4] for m in matches]
self.assertEqual(contents, ["Speaker 4:", " Thank you. "])

def test_nested_tag_does_not_drop_segment(self):
"""Fix for #201: content with inner <tag>...</tag> must not truncate the match."""
text = (
"<llmlingua, compress=False>Speaker 4:</llmlingua>"
"<llmlingua, rate=0.6>"
" We have <tag> nested content </tag> here."
"</llmlingua>"
)
matches = self._findall(text)
self.assertEqual(len(matches), 2, "Both segments must be captured")
self.assertIn("nested content", matches[1][4])
self.assertIn("here.", matches[1][4])

def test_multiple_segments_with_and_without_nested_tags(self):
"""Full scenario from the issue report."""
text = (
"<llmlingua, compress=False>Speaker 4:</llmlingua>"
"<llmlingua, rate=0.4> Thank you. </llmlingua>"
"<llmlingua, compress=False>\nSpeaker 0:</llmlingua>"
"<llmlingua, rate=0.4> Item 11. </llmlingua>"
"<llmlingua, compress=False>\nSpeaker 4:</llmlingua>"
"<llmlingua, rate=0.6>"
" We have <tag> nested </tag> and customers."
"</llmlingua>"
)
matches = self._findall(text)
self.assertEqual(len(matches), 6, "All 6 segments must be captured")
self.assertIn("nested", matches[5][4])

def test_rate_attribute_still_captured(self):
"""The rate and compress capture groups must still work with nested tags."""
text = "<llmlingua, rate=0.4> Text with <br/> inline tag </llmlingua>"
matches = self._findall(text)
self.assertEqual(len(matches), 1)
# group 0 or group 2 holds rate (either position)
rate = matches[0][0] or matches[0][2]
self.assertEqual(rate, "0.4")


if __name__ == "__main__":
unittest.main()