diff --git a/ais_bench/benchmark/datasets/utils/datasets.py b/ais_bench/benchmark/datasets/utils/datasets.py index 6bb31701..41a6594e 100644 --- a/ais_bench/benchmark/datasets/utils/datasets.py +++ b/ais_bench/benchmark/datasets/utils/datasets.py @@ -23,7 +23,9 @@ "ais_bench.benchmark.datasets.VocalSoundDataset", ] # Multimodal APIs. -MM_APIS = ["ais_bench.benchmark.models.VLLMCustomAPIChat"] +MM_APIS = ["ais_bench.benchmark.models.VLLMCustomAPIChat", + "ais_bench.benchmark.models.VLLMCustomAPIChatStream", +] def get_cache_dir(default_dir): # TODO Add any necessary supplementary information for here diff --git a/ais_bench/benchmark/utils/model_postprocessors.py b/ais_bench/benchmark/utils/model_postprocessors.py index 34917705..ff90180d 100644 --- a/ais_bench/benchmark/utils/model_postprocessors.py +++ b/ais_bench/benchmark/utils/model_postprocessors.py @@ -1,152 +1,21 @@ import re -from functools import partial -from multiprocessing import Pool -from typing import Union - -from tqdm import tqdm from ais_bench.benchmark.registry import TEXT_POSTPROCESSORS - -from ais_bench.benchmark.utils.postprocess.postprocessors.naive import NaiveExtractor, format_input_naive -from ais_bench.benchmark.utils.postprocess.postprocessors.xfinder.extractor import Extractor from ais_bench.benchmark.utils.logging.logger import AISLogger -from ais_bench.benchmark.utils.postprocess.postprocessors.xfinder.xfinder_utils import (DataProcessor, - convert_to_xfinder_format) logger = AISLogger() logger.warning("ais_bench.benchmark.utils.model_postprocessors is deprecated, please use ais_bench.benchmark.utils.postprocess.model_postprocessors instead.") -def gen_output_naive(ori_data, extractor): - extracted_answers = [] - for item in tqdm(ori_data): - user_input = extractor.prepare_input(item) - extracted_answer = extractor.gen_output(user_input) - item['extracted_answer'] = extracted_answer - extracted_answers.append(extracted_answer) - - return extracted_answers - - -@TEXT_POSTPROCESSORS.register_module('naive') -def naive_model_postprocess(preds: list, - model_name: str, - custom_instruction: str, - api_url: Union[str, list], - num_processes: int = 8, - **kwargs) -> list: - """Postprocess the text extracted by custom model. - Args: - preds (list): The question, reference answer and model prediction. - model_name (str): The name of the model. - custom_instruction (str): Custom instruction for the dataset. - url (Union[str, list]): The api url of the model. - - Returns: - list: The postprocessed answers. - """ - - def _eval_pred(texts, extractor, num_processes): - ori_data = texts - extracted_answers = [] - batched_ori_data = [] - # Split data into batches - num_processes = min(num_processes, len(ori_data)) - batch_size = len(ori_data) // num_processes - for i in range(0, len(ori_data), batch_size): - batched_ori_data.append(ori_data[i:i + batch_size]) - with Pool(num_processes) as p: - results = p.map(partial(gen_output_naive, extractor=extractor), - batched_ori_data) - for result in results: - extracted_answers.extend(result) - return extracted_answers - - format_data = format_input_naive(preds) - assert api_url is not None, 'Please provide the api url.' - extractor = NaiveExtractor( - model_name=model_name, - custom_instruction=custom_instruction, - url=api_url.split(',') if ',' in api_url else api_url) - calc_acc_func = partial(_eval_pred, - extractor=extractor, - num_processes=num_processes) - extracted_answers = calc_acc_func(format_data) - return extracted_answers - - -def gen_output_xfinder(ori_data, extractor): - ext_cor_pairs = [] - extracted_data = [] - extracted_answers = [] - for item in tqdm(ori_data): - user_input = extractor.prepare_input(item) - extracted_answer = extractor.gen_output(user_input) - ext_cor_pairs.append([ - item['key_answer_type'], item['standard_answer_range'], - extracted_answer, item['correct_answer'] - ]) - item['xfinder_extracted_answer'] = extracted_answer - extracted_answers.append(extracted_answer) - extracted_data.append(item) - - return extracted_answers, ext_cor_pairs, extracted_data - - -@TEXT_POSTPROCESSORS.register_module('xfinder') -def xfinder_postprocess(preds: list, question_type: str, model_name: str, - api_url: Union[str, list], **kwargs) -> list: - """Postprocess the text extracted by xFinder model. - Args: - preds (list): The question, reference answer and model prediction. - question_type (str): The type of the question. - url (Union[str, list]): The api url of the xFinder model. - - - Returns: - list: The postprocessed texts. - """ - - def _eval_pred(texts, data_processor, extractor, num_processes=8): - ori_data = data_processor.read_data(texts) - extracted_correct_pairs = [] - extracted_data = [] - extracted_answers = [] - batched_ori_data = [] - # Split data into batches - num_processes = min(num_processes, len(ori_data)) - batch_size = len(ori_data) // num_processes - for i in range(0, len(ori_data), batch_size): - batched_ori_data.append(ori_data[i:i + batch_size]) - with Pool(num_processes) as p: - results = p.map(partial(gen_output_xfinder, extractor=extractor), - batched_ori_data) - for result in results: - extracted_answers += result[0] - extracted_correct_pairs += result[1] - extracted_data += result[2] - return extracted_answers - - format_data = convert_to_xfinder_format(question_type, preds) - assert api_url is not None, 'Please provide the api url.' - data_processor = DataProcessor() - extractor = Extractor( - model_name=model_name, - url=api_url.split(',') if ',' in api_url else api_url) - calc_acc_func = partial(_eval_pred, - data_processor=data_processor, - extractor=extractor) - extracted_answers = calc_acc_func(format_data) - return extracted_answers - def list_decorator(func): """Decorator: make the function able to handle list input""" def wrapper(text_or_list, *args, **kwargs): if isinstance(text_or_list, list): + logger.debug(f"list_decorator({func.__name__}): processing list of {len(text_or_list)} item(s)") return [func(text, *args, **kwargs) for text in text_or_list] - else: - return func(text_or_list, *args, **kwargs) + logger.debug(f"list_decorator({func.__name__}): processing single item") + return func(text_or_list, *args, **kwargs) return wrapper @@ -180,20 +49,24 @@ def extract_non_reasoning_content( >>> text = "Startreasoning here End" >>> extract_non_reasoning_content(text) 'Start End' - + >>> # When input is a list >>> texts = ["Startreasoning End", "Test Result"] >>> extract_non_reasoning_content(texts) ['Start End', 'Result'] """ + logger.debug(f"extract_non_reasoning_content: start_token='{think_start_token}', end_token='{think_end_token}'") # If text contains only end token, split by end token and take the last part if not isinstance(text, str): return text if think_start_token not in text and think_end_token in text: - return text.split(think_end_token)[-1].strip() + result = text.split(think_end_token)[-1].strip() + logger.debug(f"extract_non_reasoning_content: only end token present -> length={len(result)}") + return result # Original behavior for complete tag pairs reasoning_regex = re.compile(rf'{think_start_token}(.*?){think_end_token}', re.DOTALL) non_reasoning_content = reasoning_regex.sub('', text).strip() - return non_reasoning_content \ No newline at end of file + logger.debug(f"extract_non_reasoning_content: removed reasoning sections -> length={len(non_reasoning_content)}") + return non_reasoning_content diff --git a/tests/UT/utils/test_model_postprocessors.py b/tests/UT/utils/test_model_postprocessors.py index be10136b..ebaf122c 100644 --- a/tests/UT/utils/test_model_postprocessors.py +++ b/tests/UT/utils/test_model_postprocessors.py @@ -41,131 +41,6 @@ def setUp(self): } ] - @patch('ais_bench.benchmark.utils.model_postprocessors.Pool') - @patch('ais_bench.benchmark.utils.model_postprocessors.NaiveExtractor') - @patch('ais_bench.benchmark.utils.model_postprocessors.format_input_naive') - def test_naive_model_postprocess(self, mock_format, mock_extractor_class, mock_pool): - """Test naive_model_postprocess function.""" - # Setup mocks - mock_format.return_value = self.test_preds - mock_extractor = MagicMock() - mock_extractor_class.return_value = mock_extractor - - # Mock the extractor's methods - def mock_gen_output(ori_data, extractor): - results = [] - for item in ori_data: - extracted = extractor.gen_output(item) - results.append(extracted) - return results - - mock_extractor.prepare_input.return_value = "prepared_input" - mock_extractor.gen_output.return_value = "extracted_answer" - - # Mock Pool - mock_pool_instance = MagicMock() - mock_pool.return_value.__enter__.return_value = mock_pool_instance - mock_pool.return_value.__exit__.return_value = False - - # Mock map to call the function directly - def mock_map(func, batches): - results = [] - for batch in batches: - batch_result = func(batch) - results.append(batch_result) - return results - - mock_pool_instance.map = mock_map - - # Call the function - result = mp.naive_model_postprocess( - preds=self.test_preds, - model_name="test_model", - custom_instruction="Extract the answer", - api_url="http://test.api", - num_processes=2 - ) - - # Verify - self.assertIsInstance(result, list) - self.assertEqual(len(result), 2) - mock_format.assert_called_once_with(self.test_preds) - mock_extractor_class.assert_called_once() - - @patch('ais_bench.benchmark.utils.model_postprocessors.Pool') - @patch('ais_bench.benchmark.utils.model_postprocessors.Extractor') - @patch('ais_bench.benchmark.utils.model_postprocessors.DataProcessor') - @patch('ais_bench.benchmark.utils.model_postprocessors.convert_to_xfinder_format') - def test_xfinder_postprocess(self, mock_convert, mock_data_processor_class, - mock_extractor_class, mock_pool): - """Test xfinder_postprocess function.""" - # Setup mocks - mock_convert.return_value = self.test_preds - mock_data_processor = MagicMock() - mock_data_processor_class.return_value = mock_data_processor - mock_data_processor.read_data.return_value = [ - { - 'key_answer_type': 'number', - 'standard_answer_range': '0-10', - 'correct_answer': '4' - } - ] - - mock_extractor = MagicMock() - mock_extractor_class.return_value = mock_extractor - mock_extractor.prepare_input.return_value = "prepared_input" - mock_extractor.gen_output.return_value = "extracted_answer" - - # Mock Pool - mock_pool_instance = MagicMock() - mock_pool.return_value.__enter__.return_value = mock_pool_instance - mock_pool.return_value.__exit__.return_value = False - - def mock_map(func, batches): - results = [] - for batch in batches: - batch_result = func(batch) - results.append(batch_result) - return results - - mock_pool_instance.map = mock_map - - # Call the function - result = mp.xfinder_postprocess( - preds=self.test_preds, - question_type="math", - model_name="test_model", - api_url="http://test.api" - ) - - # Verify - self.assertIsInstance(result, list) - mock_convert.assert_called_once() - mock_data_processor_class.assert_called_once() - mock_extractor_class.assert_called_once() - - def test_naive_model_postprocess_no_api_url(self): - """Test naive_model_postprocess raises error when api_url is None.""" - with self.assertRaises(AssertionError) as cm: - mp.naive_model_postprocess( - preds=self.test_preds, - model_name="test_model", - custom_instruction="Extract", - api_url=None - ) - self.assertIn("api url", str(cm.exception).lower()) - - def test_xfinder_postprocess_no_api_url(self): - """Test xfinder_postprocess raises error when api_url is None.""" - with self.assertRaises(AssertionError) as cm: - mp.xfinder_postprocess( - preds=self.test_preds, - question_type="math", - model_name="test_model", - api_url=None - ) - self.assertIn("api url", str(cm.exception).lower()) - def test_list_decorator_with_list(self): """Test list_decorator with list input.""" @mp.list_decorator