Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ais_bench/benchmark/datasets/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
"ais_bench.benchmark.datasets.VocalSoundDataset",
]
# Multimodal APIs.
MM_APIS = ["ais_bench.benchmark.models.VLLMCustomAPIChat"]
MM_APIS = ["ais_bench.benchmark.models.VLLMCustomAPIChat",
"ais_bench.benchmark.models.VLLMCustomAPIChatStream",
]

def get_cache_dir(default_dir):
# TODO Add any necessary supplementary information for here
Expand Down
147 changes: 10 additions & 137 deletions ais_bench/benchmark/utils/model_postprocessors.py
Original file line number Diff line number Diff line change
@@ -1,152 +1,21 @@
import re
from functools import partial
from multiprocessing import Pool
from typing import Union

from tqdm import tqdm

from ais_bench.benchmark.registry import TEXT_POSTPROCESSORS

from ais_bench.benchmark.utils.postprocess.postprocessors.naive import NaiveExtractor, format_input_naive
from ais_bench.benchmark.utils.postprocess.postprocessors.xfinder.extractor import Extractor
from ais_bench.benchmark.utils.logging.logger import AISLogger
from ais_bench.benchmark.utils.postprocess.postprocessors.xfinder.xfinder_utils import (DataProcessor,
convert_to_xfinder_format)


logger = AISLogger()
logger.warning("ais_bench.benchmark.utils.model_postprocessors is deprecated, please use ais_bench.benchmark.utils.postprocess.model_postprocessors instead.")

def gen_output_naive(ori_data, extractor):
extracted_answers = []
for item in tqdm(ori_data):
user_input = extractor.prepare_input(item)
extracted_answer = extractor.gen_output(user_input)
item['extracted_answer'] = extracted_answer
extracted_answers.append(extracted_answer)

return extracted_answers


@TEXT_POSTPROCESSORS.register_module('naive')
def naive_model_postprocess(preds: list,
model_name: str,
custom_instruction: str,
api_url: Union[str, list],
num_processes: int = 8,
**kwargs) -> list:
"""Postprocess the text extracted by custom model.
Args:
preds (list): The question, reference answer and model prediction.
model_name (str): The name of the model.
custom_instruction (str): Custom instruction for the dataset.
url (Union[str, list]): The api url of the model.

Returns:
list: The postprocessed answers.
"""

def _eval_pred(texts, extractor, num_processes):
ori_data = texts
extracted_answers = []
batched_ori_data = []
# Split data into batches
num_processes = min(num_processes, len(ori_data))
batch_size = len(ori_data) // num_processes
for i in range(0, len(ori_data), batch_size):
batched_ori_data.append(ori_data[i:i + batch_size])
with Pool(num_processes) as p:
results = p.map(partial(gen_output_naive, extractor=extractor),
batched_ori_data)
for result in results:
extracted_answers.extend(result)
return extracted_answers

format_data = format_input_naive(preds)
assert api_url is not None, 'Please provide the api url.'
extractor = NaiveExtractor(
model_name=model_name,
custom_instruction=custom_instruction,
url=api_url.split(',') if ',' in api_url else api_url)
calc_acc_func = partial(_eval_pred,
extractor=extractor,
num_processes=num_processes)
extracted_answers = calc_acc_func(format_data)
return extracted_answers


def gen_output_xfinder(ori_data, extractor):
ext_cor_pairs = []
extracted_data = []
extracted_answers = []
for item in tqdm(ori_data):
user_input = extractor.prepare_input(item)
extracted_answer = extractor.gen_output(user_input)
ext_cor_pairs.append([
item['key_answer_type'], item['standard_answer_range'],
extracted_answer, item['correct_answer']
])
item['xfinder_extracted_answer'] = extracted_answer
extracted_answers.append(extracted_answer)
extracted_data.append(item)

return extracted_answers, ext_cor_pairs, extracted_data


@TEXT_POSTPROCESSORS.register_module('xfinder')
def xfinder_postprocess(preds: list, question_type: str, model_name: str,
api_url: Union[str, list], **kwargs) -> list:
"""Postprocess the text extracted by xFinder model.
Args:
preds (list): The question, reference answer and model prediction.
question_type (str): The type of the question.
url (Union[str, list]): The api url of the xFinder model.


Returns:
list: The postprocessed texts.
"""

def _eval_pred(texts, data_processor, extractor, num_processes=8):
ori_data = data_processor.read_data(texts)
extracted_correct_pairs = []
extracted_data = []
extracted_answers = []
batched_ori_data = []
# Split data into batches
num_processes = min(num_processes, len(ori_data))
batch_size = len(ori_data) // num_processes
for i in range(0, len(ori_data), batch_size):
batched_ori_data.append(ori_data[i:i + batch_size])
with Pool(num_processes) as p:
results = p.map(partial(gen_output_xfinder, extractor=extractor),
batched_ori_data)
for result in results:
extracted_answers += result[0]
extracted_correct_pairs += result[1]
extracted_data += result[2]
return extracted_answers

format_data = convert_to_xfinder_format(question_type, preds)
assert api_url is not None, 'Please provide the api url.'
data_processor = DataProcessor()
extractor = Extractor(
model_name=model_name,
url=api_url.split(',') if ',' in api_url else api_url)
calc_acc_func = partial(_eval_pred,
data_processor=data_processor,
extractor=extractor)
extracted_answers = calc_acc_func(format_data)
return extracted_answers


def list_decorator(func):
"""Decorator: make the function able to handle list input"""
def wrapper(text_or_list, *args, **kwargs):
if isinstance(text_or_list, list):
logger.debug(f"list_decorator({func.__name__}): processing list of {len(text_or_list)} item(s)")
return [func(text, *args, **kwargs) for text in text_or_list]
else:
return func(text_or_list, *args, **kwargs)
logger.debug(f"list_decorator({func.__name__}): processing single item")
return func(text_or_list, *args, **kwargs)
return wrapper


Expand Down Expand Up @@ -180,20 +49,24 @@ def extract_non_reasoning_content(
>>> text = "Start<think>reasoning here</think> End"
>>> extract_non_reasoning_content(text)
'Start End'

>>> # When input is a list
>>> texts = ["Start<think>reasoning</think> End", "Test</think> Result"]
>>> extract_non_reasoning_content(texts)
['Start End', 'Result']
"""
logger.debug(f"extract_non_reasoning_content: start_token='{think_start_token}', end_token='{think_end_token}'")
# If text contains only end token, split by end token and take the last part
if not isinstance(text, str):
return text
if think_start_token not in text and think_end_token in text:
return text.split(think_end_token)[-1].strip()
result = text.split(think_end_token)[-1].strip()
logger.debug(f"extract_non_reasoning_content: only end token present -> length={len(result)}")
return result

# Original behavior for complete tag pairs
reasoning_regex = re.compile(rf'{think_start_token}(.*?){think_end_token}',
re.DOTALL)
non_reasoning_content = reasoning_regex.sub('', text).strip()
return non_reasoning_content
logger.debug(f"extract_non_reasoning_content: removed reasoning sections -> length={len(non_reasoning_content)}")
return non_reasoning_content
125 changes: 0 additions & 125 deletions tests/UT/utils/test_model_postprocessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,131 +41,6 @@ def setUp(self):
}
]

@patch('ais_bench.benchmark.utils.model_postprocessors.Pool')
@patch('ais_bench.benchmark.utils.model_postprocessors.NaiveExtractor')
@patch('ais_bench.benchmark.utils.model_postprocessors.format_input_naive')
def test_naive_model_postprocess(self, mock_format, mock_extractor_class, mock_pool):
"""Test naive_model_postprocess function."""
# Setup mocks
mock_format.return_value = self.test_preds
mock_extractor = MagicMock()
mock_extractor_class.return_value = mock_extractor

# Mock the extractor's methods
def mock_gen_output(ori_data, extractor):
results = []
for item in ori_data:
extracted = extractor.gen_output(item)
results.append(extracted)
return results

mock_extractor.prepare_input.return_value = "prepared_input"
mock_extractor.gen_output.return_value = "extracted_answer"

# Mock Pool
mock_pool_instance = MagicMock()
mock_pool.return_value.__enter__.return_value = mock_pool_instance
mock_pool.return_value.__exit__.return_value = False

# Mock map to call the function directly
def mock_map(func, batches):
results = []
for batch in batches:
batch_result = func(batch)
results.append(batch_result)
return results

mock_pool_instance.map = mock_map

# Call the function
result = mp.naive_model_postprocess(
preds=self.test_preds,
model_name="test_model",
custom_instruction="Extract the answer",
api_url="http://test.api",
num_processes=2
)

# Verify
self.assertIsInstance(result, list)
self.assertEqual(len(result), 2)
mock_format.assert_called_once_with(self.test_preds)
mock_extractor_class.assert_called_once()

@patch('ais_bench.benchmark.utils.model_postprocessors.Pool')
@patch('ais_bench.benchmark.utils.model_postprocessors.Extractor')
@patch('ais_bench.benchmark.utils.model_postprocessors.DataProcessor')
@patch('ais_bench.benchmark.utils.model_postprocessors.convert_to_xfinder_format')
def test_xfinder_postprocess(self, mock_convert, mock_data_processor_class,
mock_extractor_class, mock_pool):
"""Test xfinder_postprocess function."""
# Setup mocks
mock_convert.return_value = self.test_preds
mock_data_processor = MagicMock()
mock_data_processor_class.return_value = mock_data_processor
mock_data_processor.read_data.return_value = [
{
'key_answer_type': 'number',
'standard_answer_range': '0-10',
'correct_answer': '4'
}
]

mock_extractor = MagicMock()
mock_extractor_class.return_value = mock_extractor
mock_extractor.prepare_input.return_value = "prepared_input"
mock_extractor.gen_output.return_value = "extracted_answer"

# Mock Pool
mock_pool_instance = MagicMock()
mock_pool.return_value.__enter__.return_value = mock_pool_instance
mock_pool.return_value.__exit__.return_value = False

def mock_map(func, batches):
results = []
for batch in batches:
batch_result = func(batch)
results.append(batch_result)
return results

mock_pool_instance.map = mock_map

# Call the function
result = mp.xfinder_postprocess(
preds=self.test_preds,
question_type="math",
model_name="test_model",
api_url="http://test.api"
)

# Verify
self.assertIsInstance(result, list)
mock_convert.assert_called_once()
mock_data_processor_class.assert_called_once()
mock_extractor_class.assert_called_once()

def test_naive_model_postprocess_no_api_url(self):
"""Test naive_model_postprocess raises error when api_url is None."""
with self.assertRaises(AssertionError) as cm:
mp.naive_model_postprocess(
preds=self.test_preds,
model_name="test_model",
custom_instruction="Extract",
api_url=None
)
self.assertIn("api url", str(cm.exception).lower())

def test_xfinder_postprocess_no_api_url(self):
"""Test xfinder_postprocess raises error when api_url is None."""
with self.assertRaises(AssertionError) as cm:
mp.xfinder_postprocess(
preds=self.test_preds,
question_type="math",
model_name="test_model",
api_url=None
)
self.assertIn("api url", str(cm.exception).lower())

def test_list_decorator_with_list(self):
"""Test list_decorator with list input."""
@mp.list_decorator
Expand Down