From 347f21ea77e15c7e6bbcb007670b6a23d0f4a0cf Mon Sep 17 00:00:00 2001 From: Ousama Ben Younes Date: Sat, 11 Apr 2026 19:01:01 +0000 Subject: [PATCH] fix: dispatch LLMLingua-2 models by loosened name markers (#168) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit is_begin_of_new_word and get_pure_token in llmlingua/utils.py used to dispatch on literal substrings — "bert-base-multilingual-cased" for the BERT family and "xlm-roberta-large" for the XLM-RoBERTa family. Users who download LLMLingua-2 weights and load them from a renamed local directory (for example "/home/webservice/llm/compressFromNet/llmlingua-2-xlm" as in the reporter's setup) fall straight through to `raise NotImplementedError` deep inside `compress_prompt_llmlingua2`, with no hint of what went wrong. Factor the dispatch into `_model_family` and broaden the marker lists to include the shorter common prefixes (`xlm-roberta`, `bert-base-multilingual`), plus explicit `llmlingua-2-xlm` and `llmlingua-2-bert` local-path variants. The canonical names, tinybert/mobilebert, slingua and securitylingua all keep working unchanged, and the fallback NotImplementedError now lists the supported markers so diagnosing a bad model_name no longer requires reading the source. 
Generated by Claude Code Vibe coded by ousamabenyounes Co-Authored-By: Claude --- llmlingua/utils.py | 63 +++++++++++++++----- tests/test_issue_168.py | 129 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 178 insertions(+), 14 deletions(-) create mode 100644 tests/test_issue_168.py diff --git a/llmlingua/utils.py b/llmlingua/utils.py index da9986e..9d52705 100644 --- a/llmlingua/utils.py +++ b/llmlingua/utils.py @@ -78,18 +78,48 @@ def seed_everything(seed: int): torch.backends.cudnn.benchmark = False +_BERT_FAMILY_MARKERS = ( + "bert-base-multilingual", + "tinybert", + "mobilebert", + "llmlingua-2-bert", +) + +_XLM_ROBERTA_FAMILY_MARKERS = ( + "xlm-roberta", + "slingua", + "securitylingua", + "llmlingua-2-xlm", +) + + +def _model_family(model_name): + """Classify a model name / local path as 'bert' or 'xlm-roberta'. + + Users frequently load LLMLingua-2 weights from a renamed local directory + (e.g. ``/path/to/llmlingua-2-xlm``) — see issue #168 — so the markers + below intentionally match on shorter substrings than the full + ``xlm-roberta-large`` / ``bert-base-multilingual-cased`` canonical names. 
+ """ + if not model_name: + return None + needle = model_name.lower() + if any(m in needle for m in _BERT_FAMILY_MARKERS): + return "bert" + if any(m in needle for m in _XLM_ROBERTA_FAMILY_MARKERS): + return "xlm-roberta" + return None + + def is_begin_of_new_word(token, model_name, force_tokens, token_map): - if "bert-base-multilingual-cased" in model_name \ - or "tinybert" in model_name.lower() \ - or "mobilebert" in model_name.lower(): + family = _model_family(model_name) + if family == "bert": if token.lstrip("##") in force_tokens or token.lstrip("##") in set( token_map.values() ): return True return not token.startswith("##") - elif "xlm-roberta-large" in model_name \ - or 'slingua' in model_name.lower() \ - or 'securitylingua' in model_name.lower(): + elif family == "xlm-roberta": if ( token in string.punctuation or token in force_tokens @@ -98,7 +128,11 @@ def is_begin_of_new_word(token, model_name, force_tokens, token_map): return True return token.startswith("▁") else: - raise NotImplementedError() + raise NotImplementedError( + f"Unsupported model_name={model_name!r}. Expected a BERT-family " + f"or XLM-RoBERTa-family LLMLingua-2 model. Known markers: " + f"{_BERT_FAMILY_MARKERS + _XLM_ROBERTA_FAMILY_MARKERS}." + ) def replace_added_token(token, token_map): @@ -108,16 +142,17 @@ def replace_added_token(token, token_map): def get_pure_token(token, model_name): - if "bert-base-multilingual-cased" in model_name \ - or "tinybert" in model_name.lower() \ - or "mobilebert" in model_name.lower(): + family = _model_family(model_name) + if family == "bert": return token.lstrip("##") - elif "xlm-roberta-large" in model_name \ - or 'slingua' in model_name.lower() \ - or 'securitylingua' in model_name.lower(): + elif family == "xlm-roberta": return token.lstrip("▁") else: - raise NotImplementedError() + raise NotImplementedError( + f"Unsupported model_name={model_name!r}. Expected a BERT-family " + f"or XLM-RoBERTa-family LLMLingua-2 model. 
Known markers: " + f"{_BERT_FAMILY_MARKERS + _XLM_ROBERTA_FAMILY_MARKERS}." + ) def process_structured_json_data(json_data, json_config): diff --git a/tests/test_issue_168.py b/tests/test_issue_168.py new file mode 100644 index 0000000..32d3e48 --- /dev/null +++ b/tests/test_issue_168.py @@ -0,0 +1,129 @@ +# Copyright (c) 2023 Microsoft +# Licensed under The MIT License [see LICENSE for details] + +"""Regression test for https://github.com/microsoft/LLMLingua/issues/168 + +Users often download LLMLingua-2 weights and load them from a renamed local +directory — the reporter's path was +``/home/webservice/llm/compressFromNet/llmlingua-2-xlm``. The old dispatch +in ``is_begin_of_new_word`` / ``get_pure_token`` only matched the literal +canonical names ``bert-base-multilingual-cased`` and ``xlm-roberta-large``, +so those local-path setups crashed with ``NotImplementedError`` deep inside +``compress_prompt_llmlingua2``. + +The fix broadens the substring dispatch via ``_model_family`` so that common +local-path renames (``llmlingua-2-xlm``, ``llmlingua-2-bert``, +``xlm-roberta`` without ``-large``, etc.) still resolve to the correct +tokenizer family. These tests pin that behaviour. 
+""" + +import unittest + +from llmlingua.utils import ( + _BERT_FAMILY_MARKERS, + _XLM_ROBERTA_FAMILY_MARKERS, + _model_family, + get_pure_token, + is_begin_of_new_word, +) + + +class Issue168ModelFamilyDispatchTest(unittest.TestCase): + def test_reporter_local_path_is_recognized(self): + path = "/home/webservice/llm/compressFromNet/llmlingua-2-xlm" + self.assertEqual(_model_family(path), "xlm-roberta") + + def test_canonical_xlm_roberta_large_still_matches(self): + self.assertEqual( + _model_family("microsoft/llmlingua-2-xlm-roberta-large-meetingbank"), + "xlm-roberta", + ) + + def test_canonical_bert_still_matches(self): + self.assertEqual( + _model_family( + "microsoft/llmlingua-2-bert-base-multilingual-cased-meetingbank" + ), + "bert", + ) + + def test_bare_xlm_roberta_matches(self): + self.assertEqual(_model_family("xlm-roberta-base"), "xlm-roberta") + + def test_llmlingua2_bert_local_path_matches(self): + self.assertEqual( + _model_family("/data/models/llmlingua-2-bert-base-multilingual"), + "bert", + ) + + def test_tinybert_and_mobilebert_still_match(self): + self.assertEqual(_model_family("tinybert-whatever"), "bert") + self.assertEqual(_model_family("MobileBERT"), "bert") + + def test_slingua_still_matches_xlm_roberta(self): + self.assertEqual(_model_family("/path/to/slingua-v1"), "xlm-roberta") + + def test_unknown_model_returns_none(self): + self.assertIsNone(_model_family("unknown-llm-v42")) + + def test_empty_or_none_model_name(self): + self.assertIsNone(_model_family("")) + self.assertIsNone(_model_family(None)) + + def test_marker_lists_are_populated(self): + self.assertIn("xlm-roberta", _XLM_ROBERTA_FAMILY_MARKERS) + self.assertIn("llmlingua-2-xlm", _XLM_ROBERTA_FAMILY_MARKERS) + self.assertIn("bert-base-multilingual", _BERT_FAMILY_MARKERS) + self.assertIn("llmlingua-2-bert", _BERT_FAMILY_MARKERS) + + +class Issue168IsBeginOfNewWordTest(unittest.TestCase): + def test_xlm_roberta_local_path_no_longer_raises(self): + path = 
"/home/webservice/llm/compressFromNet/llmlingua-2-xlm" + try: + result = is_begin_of_new_word("▁hello", path, force_tokens=[], token_map={}) + except NotImplementedError: + self.fail( + "is_begin_of_new_word raised NotImplementedError on the #168 " + "reporter's local path — fix regressed" + ) + self.assertTrue(result) + + def test_xlm_roberta_local_path_continuation(self): + path = "/home/webservice/llm/compressFromNet/llmlingua-2-xlm" + self.assertFalse( + is_begin_of_new_word("world", path, force_tokens=[], token_map={}) + ) + + def test_bert_local_path_subword(self): + path = "/data/models/llmlingua-2-bert-base-multilingual" + self.assertFalse( + is_begin_of_new_word("##ing", path, force_tokens=[], token_map={}) + ) + self.assertTrue( + is_begin_of_new_word("hello", path, force_tokens=[], token_map={}) + ) + + def test_unknown_model_still_raises(self): + with self.assertRaises(NotImplementedError): + is_begin_of_new_word( + "token", "unknown-model", force_tokens=[], token_map={} + ) + + +class Issue168GetPureTokenTest(unittest.TestCase): + def test_xlm_roberta_local_path_strips_leading_marker(self): + path = "/home/webservice/llm/compressFromNet/llmlingua-2-xlm" + self.assertEqual(get_pure_token("▁hello", path), "hello") + + def test_bert_local_path_strips_hash_prefix(self): + path = "/data/models/llmlingua-2-bert-base-multilingual" + self.assertEqual(get_pure_token("##ing", path), "ing") + + def test_unknown_model_raises(self): + with self.assertRaises(NotImplementedError): + get_pure_token("token", "unknown-model") + + +if __name__ == "__main__": + unittest.main()