From 5ec059b472e736b49b70ae3cf4cc9e380a87d6b6 Mon Sep 17 00:00:00 2001
From: GaoHua <1484391106@qq.com>
Date: Fri, 13 Mar 2026 08:53:00 +0000
Subject: [PATCH 1/3] Fix compatibility issues with model configuration
---
.../benchmark/datasets/utils/datasets.py | 4 +-
.../benchmark/utils/model_postprocessors.py | 155 ++----------------
2 files changed, 21 insertions(+), 138 deletions(-)
diff --git a/ais_bench/benchmark/datasets/utils/datasets.py b/ais_bench/benchmark/datasets/utils/datasets.py
index 6bb31701..41a6594e 100644
--- a/ais_bench/benchmark/datasets/utils/datasets.py
+++ b/ais_bench/benchmark/datasets/utils/datasets.py
@@ -23,7 +23,9 @@
"ais_bench.benchmark.datasets.VocalSoundDataset",
]
# Multimodal APIs.
-MM_APIS = ["ais_bench.benchmark.models.VLLMCustomAPIChat"]
+MM_APIS = ["ais_bench.benchmark.models.VLLMCustomAPIChat",
+ "ais_bench.benchmark.models.VLLMCustomAPIChatStream",
+]
def get_cache_dir(default_dir):
# TODO Add any necessary supplementary information for here
diff --git a/ais_bench/benchmark/utils/model_postprocessors.py b/ais_bench/benchmark/utils/model_postprocessors.py
index 34917705..c3a7192e 100644
--- a/ais_bench/benchmark/utils/model_postprocessors.py
+++ b/ais_bench/benchmark/utils/model_postprocessors.py
@@ -1,152 +1,25 @@
import re
-from functools import partial
-from multiprocessing import Pool
-from typing import Union
-
-from tqdm import tqdm
from ais_bench.benchmark.registry import TEXT_POSTPROCESSORS
-
-from ais_bench.benchmark.utils.postprocess.postprocessors.naive import NaiveExtractor, format_input_naive
-from ais_bench.benchmark.utils.postprocess.postprocessors.xfinder.extractor import Extractor
from ais_bench.benchmark.utils.logging.logger import AISLogger
-from ais_bench.benchmark.utils.postprocess.postprocessors.xfinder.xfinder_utils import (DataProcessor,
- convert_to_xfinder_format)
logger = AISLogger()
logger.warning("ais_bench.benchmark.utils.model_postprocessors is deprecated, please use ais_bench.benchmark.utils.postprocess.model_postprocessors instead.")
-def gen_output_naive(ori_data, extractor):
- extracted_answers = []
- for item in tqdm(ori_data):
- user_input = extractor.prepare_input(item)
- extracted_answer = extractor.gen_output(user_input)
- item['extracted_answer'] = extracted_answer
- extracted_answers.append(extracted_answer)
-
- return extracted_answers
-
-
-@TEXT_POSTPROCESSORS.register_module('naive')
-def naive_model_postprocess(preds: list,
- model_name: str,
- custom_instruction: str,
- api_url: Union[str, list],
- num_processes: int = 8,
- **kwargs) -> list:
- """Postprocess the text extracted by custom model.
- Args:
- preds (list): The question, reference answer and model prediction.
- model_name (str): The name of the model.
- custom_instruction (str): Custom instruction for the dataset.
- url (Union[str, list]): The api url of the model.
-
- Returns:
- list: The postprocessed answers.
- """
-
- def _eval_pred(texts, extractor, num_processes):
- ori_data = texts
- extracted_answers = []
- batched_ori_data = []
- # Split data into batches
- num_processes = min(num_processes, len(ori_data))
- batch_size = len(ori_data) // num_processes
- for i in range(0, len(ori_data), batch_size):
- batched_ori_data.append(ori_data[i:i + batch_size])
- with Pool(num_processes) as p:
- results = p.map(partial(gen_output_naive, extractor=extractor),
- batched_ori_data)
- for result in results:
- extracted_answers.extend(result)
- return extracted_answers
-
- format_data = format_input_naive(preds)
- assert api_url is not None, 'Please provide the api url.'
- extractor = NaiveExtractor(
- model_name=model_name,
- custom_instruction=custom_instruction,
- url=api_url.split(',') if ',' in api_url else api_url)
- calc_acc_func = partial(_eval_pred,
- extractor=extractor,
- num_processes=num_processes)
- extracted_answers = calc_acc_func(format_data)
- return extracted_answers
-
-
-def gen_output_xfinder(ori_data, extractor):
- ext_cor_pairs = []
- extracted_data = []
- extracted_answers = []
- for item in tqdm(ori_data):
- user_input = extractor.prepare_input(item)
- extracted_answer = extractor.gen_output(user_input)
- ext_cor_pairs.append([
- item['key_answer_type'], item['standard_answer_range'],
- extracted_answer, item['correct_answer']
- ])
- item['xfinder_extracted_answer'] = extracted_answer
- extracted_answers.append(extracted_answer)
- extracted_data.append(item)
-
- return extracted_answers, ext_cor_pairs, extracted_data
-
-
-@TEXT_POSTPROCESSORS.register_module('xfinder')
-def xfinder_postprocess(preds: list, question_type: str, model_name: str,
- api_url: Union[str, list], **kwargs) -> list:
- """Postprocess the text extracted by xFinder model.
- Args:
- preds (list): The question, reference answer and model prediction.
- question_type (str): The type of the question.
- url (Union[str, list]): The api url of the xFinder model.
-
-
- Returns:
- list: The postprocessed texts.
- """
-
- def _eval_pred(texts, data_processor, extractor, num_processes=8):
- ori_data = data_processor.read_data(texts)
- extracted_correct_pairs = []
- extracted_data = []
- extracted_answers = []
- batched_ori_data = []
- # Split data into batches
- num_processes = min(num_processes, len(ori_data))
- batch_size = len(ori_data) // num_processes
- for i in range(0, len(ori_data), batch_size):
- batched_ori_data.append(ori_data[i:i + batch_size])
- with Pool(num_processes) as p:
- results = p.map(partial(gen_output_xfinder, extractor=extractor),
- batched_ori_data)
- for result in results:
- extracted_answers += result[0]
- extracted_correct_pairs += result[1]
- extracted_data += result[2]
- return extracted_answers
-
- format_data = convert_to_xfinder_format(question_type, preds)
- assert api_url is not None, 'Please provide the api url.'
- data_processor = DataProcessor()
- extractor = Extractor(
- model_name=model_name,
- url=api_url.split(',') if ',' in api_url else api_url)
- calc_acc_func = partial(_eval_pred,
- data_processor=data_processor,
- extractor=extractor)
- extracted_answers = calc_acc_func(format_data)
- return extracted_answers
-
def list_decorator(func):
"""Decorator: make the function able to handle list input"""
def wrapper(text_or_list, *args, **kwargs):
if isinstance(text_or_list, list):
+ logger.debug(
+ f"list_decorator({func.__name__}): processing list of {len(text_or_list)} item(s)"
+ )
return [func(text, *args, **kwargs) for text in text_or_list]
- else:
- return func(text_or_list, *args, **kwargs)
+ logger.debug(
+ f"list_decorator({func.__name__}): processing single item"
+ )
+ return func(text_or_list, *args, **kwargs)
return wrapper
@@ -180,20 +53,28 @@ def extract_non_reasoning_content(
>>> text = "Startreasoning here End"
>>> extract_non_reasoning_content(text)
'Start End'
-
+
>>> # When input is a list
>>> texts = ["Startreasoning End", "Test Result"]
>>> extract_non_reasoning_content(texts)
['Start End', 'Result']
"""
+ logger.debug(f"extract_non_reasoning_content: start_token='{think_start_token}', end_token='{think_end_token}'")
# If text contains only end token, split by end token and take the last part
if not isinstance(text, str):
return text
if think_start_token not in text and think_end_token in text:
- return text.split(think_end_token)[-1].strip()
+ result = text.split(think_end_token)[-1].strip()
+ logger.debug(
+ f"extract_non_reasoning_content: only end token present -> length={len(result)}"
+ )
+ return result
# Original behavior for complete tag pairs
reasoning_regex = re.compile(rf'{think_start_token}(.*?){think_end_token}',
re.DOTALL)
non_reasoning_content = reasoning_regex.sub('', text).strip()
- return non_reasoning_content
\ No newline at end of file
+ logger.debug(
+ f"extract_non_reasoning_content: removed reasoning sections -> length={len(non_reasoning_content)}"
+ )
+ return non_reasoning_content
From 510fe1c2c966ed9fa8697d0b484c577559eed581 Mon Sep 17 00:00:00 2001
From: GaoHua <1484391106@qq.com>
Date: Fri, 13 Mar 2026 09:29:19 +0000
Subject: [PATCH 2/3] Remove unused unit tests
---
tests/UT/utils/test_model_postprocessors.py | 125 --------------------
1 file changed, 125 deletions(-)
diff --git a/tests/UT/utils/test_model_postprocessors.py b/tests/UT/utils/test_model_postprocessors.py
index be10136b..ebaf122c 100644
--- a/tests/UT/utils/test_model_postprocessors.py
+++ b/tests/UT/utils/test_model_postprocessors.py
@@ -41,131 +41,6 @@ def setUp(self):
}
]
- @patch('ais_bench.benchmark.utils.model_postprocessors.Pool')
- @patch('ais_bench.benchmark.utils.model_postprocessors.NaiveExtractor')
- @patch('ais_bench.benchmark.utils.model_postprocessors.format_input_naive')
- def test_naive_model_postprocess(self, mock_format, mock_extractor_class, mock_pool):
- """Test naive_model_postprocess function."""
- # Setup mocks
- mock_format.return_value = self.test_preds
- mock_extractor = MagicMock()
- mock_extractor_class.return_value = mock_extractor
-
- # Mock the extractor's methods
- def mock_gen_output(ori_data, extractor):
- results = []
- for item in ori_data:
- extracted = extractor.gen_output(item)
- results.append(extracted)
- return results
-
- mock_extractor.prepare_input.return_value = "prepared_input"
- mock_extractor.gen_output.return_value = "extracted_answer"
-
- # Mock Pool
- mock_pool_instance = MagicMock()
- mock_pool.return_value.__enter__.return_value = mock_pool_instance
- mock_pool.return_value.__exit__.return_value = False
-
- # Mock map to call the function directly
- def mock_map(func, batches):
- results = []
- for batch in batches:
- batch_result = func(batch)
- results.append(batch_result)
- return results
-
- mock_pool_instance.map = mock_map
-
- # Call the function
- result = mp.naive_model_postprocess(
- preds=self.test_preds,
- model_name="test_model",
- custom_instruction="Extract the answer",
- api_url="http://test.api",
- num_processes=2
- )
-
- # Verify
- self.assertIsInstance(result, list)
- self.assertEqual(len(result), 2)
- mock_format.assert_called_once_with(self.test_preds)
- mock_extractor_class.assert_called_once()
-
- @patch('ais_bench.benchmark.utils.model_postprocessors.Pool')
- @patch('ais_bench.benchmark.utils.model_postprocessors.Extractor')
- @patch('ais_bench.benchmark.utils.model_postprocessors.DataProcessor')
- @patch('ais_bench.benchmark.utils.model_postprocessors.convert_to_xfinder_format')
- def test_xfinder_postprocess(self, mock_convert, mock_data_processor_class,
- mock_extractor_class, mock_pool):
- """Test xfinder_postprocess function."""
- # Setup mocks
- mock_convert.return_value = self.test_preds
- mock_data_processor = MagicMock()
- mock_data_processor_class.return_value = mock_data_processor
- mock_data_processor.read_data.return_value = [
- {
- 'key_answer_type': 'number',
- 'standard_answer_range': '0-10',
- 'correct_answer': '4'
- }
- ]
-
- mock_extractor = MagicMock()
- mock_extractor_class.return_value = mock_extractor
- mock_extractor.prepare_input.return_value = "prepared_input"
- mock_extractor.gen_output.return_value = "extracted_answer"
-
- # Mock Pool
- mock_pool_instance = MagicMock()
- mock_pool.return_value.__enter__.return_value = mock_pool_instance
- mock_pool.return_value.__exit__.return_value = False
-
- def mock_map(func, batches):
- results = []
- for batch in batches:
- batch_result = func(batch)
- results.append(batch_result)
- return results
-
- mock_pool_instance.map = mock_map
-
- # Call the function
- result = mp.xfinder_postprocess(
- preds=self.test_preds,
- question_type="math",
- model_name="test_model",
- api_url="http://test.api"
- )
-
- # Verify
- self.assertIsInstance(result, list)
- mock_convert.assert_called_once()
- mock_data_processor_class.assert_called_once()
- mock_extractor_class.assert_called_once()
-
- def test_naive_model_postprocess_no_api_url(self):
- """Test naive_model_postprocess raises error when api_url is None."""
- with self.assertRaises(AssertionError) as cm:
- mp.naive_model_postprocess(
- preds=self.test_preds,
- model_name="test_model",
- custom_instruction="Extract",
- api_url=None
- )
- self.assertIn("api url", str(cm.exception).lower())
-
- def test_xfinder_postprocess_no_api_url(self):
- """Test xfinder_postprocess raises error when api_url is None."""
- with self.assertRaises(AssertionError) as cm:
- mp.xfinder_postprocess(
- preds=self.test_preds,
- question_type="math",
- model_name="test_model",
- api_url=None
- )
- self.assertIn("api url", str(cm.exception).lower())
-
def test_list_decorator_with_list(self):
"""Test list_decorator with list input."""
@mp.list_decorator
From f699f77180fa2c655f28df76c214a8b7cec16394 Mon Sep 17 00:00:00 2001
From: zhang GaoHua <73919261+GaoHuaZhang@users.noreply.github.com>
Date: Mon, 16 Mar 2026 10:27:04 +0800
Subject: [PATCH 3/3] Fix logger.debug calls for consistency
---
.../benchmark/utils/model_postprocessors.py | 16 ++++------------
1 file changed, 4 insertions(+), 12 deletions(-)
diff --git a/ais_bench/benchmark/utils/model_postprocessors.py b/ais_bench/benchmark/utils/model_postprocessors.py
index c3a7192e..ff90180d 100644
--- a/ais_bench/benchmark/utils/model_postprocessors.py
+++ b/ais_bench/benchmark/utils/model_postprocessors.py
@@ -12,13 +12,9 @@ def list_decorator(func):
"""Decorator: make the function able to handle list input"""
def wrapper(text_or_list, *args, **kwargs):
if isinstance(text_or_list, list):
- logger.debug(
- f"list_decorator({func.__name__}): processing list of {len(text_or_list)} item(s)"
- )
+ logger.debug(f"list_decorator({func.__name__}): processing list of {len(text_or_list)} item(s)")
return [func(text, *args, **kwargs) for text in text_or_list]
- logger.debug(
- f"list_decorator({func.__name__}): processing single item"
- )
+ logger.debug(f"list_decorator({func.__name__}): processing single item")
return func(text_or_list, *args, **kwargs)
return wrapper
@@ -65,16 +61,12 @@ def extract_non_reasoning_content(
return text
if think_start_token not in text and think_end_token in text:
result = text.split(think_end_token)[-1].strip()
- logger.debug(
- f"extract_non_reasoning_content: only end token present -> length={len(result)}"
- )
+ logger.debug(f"extract_non_reasoning_content: only end token present -> length={len(result)}")
return result
# Original behavior for complete tag pairs
reasoning_regex = re.compile(rf'{think_start_token}(.*?){think_end_token}',
re.DOTALL)
non_reasoning_content = reasoning_regex.sub('', text).strip()
- logger.debug(
- f"extract_non_reasoning_content: removed reasoning sections -> length={len(non_reasoning_content)}"
- )
+ logger.debug(f"extract_non_reasoning_content: removed reasoning sections -> length={len(non_reasoning_content)}")
return non_reasoning_content