From fd584ea415386d2a569dc59f4d41577e6f70f185 Mon Sep 17 00:00:00 2001
From: SJTUyh
Date: Mon, 16 Mar 2026 14:34:23 +0800
Subject: [PATCH] adapt tau2 bench

---
 ais_bench/benchmark/cli/config_manager.py     |   6 +-
 ais_bench/benchmark/cli/utils.py              |   2 +
 ais_bench/benchmark/cli/workers.py            |  40 ++-
 ais_bench/benchmark/partitioners/base.py      |   4 +-
 ais_bench/benchmark/partitioners/naive.py     |   4 +-
 ais_bench/benchmark/registry.py               |   2 +-
 ais_bench/benchmark/tasks/base.py             |   8 +
 .../benchmark/tasks/custom_tasks/__init__.py  |   0
 .../tasks/custom_tasks/tau2_bench_task.py     | 264 ++++++++++++++++++
 ais_bench/benchmark/utils/prompt/prompt.py    |   2 +
 .../configs/agent_example/tau2_bench_task.py  |  61 ++++
 requirements/runtime.txt                      |   2 +-
 12 files changed, 379 insertions(+), 16 deletions(-)
 create mode 100644 ais_bench/benchmark/tasks/custom_tasks/__init__.py
 create mode 100644 ais_bench/benchmark/tasks/custom_tasks/tau2_bench_task.py
 create mode 100644 ais_bench/configs/agent_example/tau2_bench_task.py

diff --git a/ais_bench/benchmark/cli/config_manager.py b/ais_bench/benchmark/cli/config_manager.py
index dde76527..bd8f91d6 100644
--- a/ais_bench/benchmark/cli/config_manager.py
+++ b/ais_bench/benchmark/cli/config_manager.py
@@ -12,8 +12,8 @@ from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need, fill_test_range_use_num_prompts
 
 class CustomConfigChecker:
-    MODEL_REQUIRED_FIELDS = ['type', 'abbr', 'attr']
-    DATASET_REQUIRED_FIELDS = ['type', 'abbr', 'reader_cfg', 'infer_cfg', 'eval_cfg']
+    MODEL_REQUIRED_FIELDS = ['abbr']
+    DATASET_REQUIRED_FIELDS = ['abbr']
     SUMMARIZER_REQUIRED_FIELDS = ['attr']
 
     def __init__(self, config, file_path):
@@ -106,6 +106,8 @@ def load_config(self, workflow):
 
     def _fill_dataset_configs(self):
         for dataset_cfg in self.cfg["datasets"]:
+            if dataset_cfg.get("infer_cfg", None) is None:
+                continue
             fill_test_range_use_num_prompts(self.cfg["cli_args"].get("num_prompts"), dataset_cfg)
             fill_model_path_if_datasets_need(self.cfg["models"][0], dataset_cfg)
             retriever_cfg = dataset_cfg["infer_cfg"]["retriever"]
diff --git a/ais_bench/benchmark/cli/utils.py b/ais_bench/benchmark/cli/utils.py
index 01ff50e5..ac2037e7 100644
--- a/ais_bench/benchmark/cli/utils.py
+++ b/ais_bench/benchmark/cli/utils.py
@@ -16,6 +16,8 @@
 logger = AISLogger()
 
 def get_config_type(obj) -> str:
+    if obj is None:
+        return None
     if isinstance(obj, str):
         return obj
     return f"{obj.__module__}.{obj.__name__}"
diff --git a/ais_bench/benchmark/cli/workers.py b/ais_bench/benchmark/cli/workers.py
index ce1dd8bb..21483e45 100644
--- a/ais_bench/benchmark/cli/workers.py
+++ b/ais_bench/benchmark/cli/workers.py
@@ -15,6 +15,7 @@
 from ais_bench.benchmark.partitioners import NaivePartitioner
 from ais_bench.benchmark.runners import LocalRunner
 from ais_bench.benchmark.tasks import OpenICLEvalTask, OpenICLApiInferTask, OpenICLInferTask
+from ais_bench.benchmark.tasks.base import EmptyTask
 from ais_bench.benchmark.summarizers import DefaultSummarizer, DefaultPerfSummarizer
 from ais_bench.benchmark.calculators import DefaultPerfMetricCalculator
 from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need
@@ -26,6 +27,7 @@ class BaseWorker(ABC):
 
     def __init__(self, args) -> None:
         self.args = args
+        self.skip = False
 
     @abstractmethod
     def update_cfg(self, cfg: ConfigDict) -> None:
@@ -39,21 +41,29 @@ def do_work(self, cfg: ConfigDict):
 
 
 class Infer(BaseWorker):
-    def update_cfg(self, cfg: ConfigDict) -> None:
+    def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
         def get_task_type() -> str:
             if cfg["models"][0]["attr"] == "service":
                 return get_config_type(OpenICLApiInferTask)
             else:
                 return get_config_type(OpenICLInferTask)
 
+        custom_infer = cfg.get("infer")
+        custom_task = None
+        if custom_infer:
+            custom_task = custom_infer["runner"]["task"].get("type")
+        if custom_task == EmptyTask:
+            self.skip = True
+            return cfg
+
         new_cfg = dict(
             infer=dict(
                 partitioner=dict(type=get_config_type(NaivePartitioner)),
                 runner=dict(
                     max_num_workers=self.args.max_num_workers,
                     max_workers_per_gpu=self.args.max_workers_per_gpu,
                     debug=self.args.debug,
-                    task=dict(type=get_task_type()),
+                    task=dict(type=get_config_type(custom_task) if custom_task else get_task_type()),
                     type=get_config_type(LocalRunner),
                 ),
             ),
@@ -66,6 +76,9 @@ def get_task_type() -> str:
         return cfg
 
     def do_work(self, cfg: ConfigDict):
+        if self.skip:
+            logger.info("EmptyTask is selected, skipping inference.")
+            return
         partitioner = PARTITIONERS.build(cfg.infer.partitioner)
         logger.info("Starting inference tasks...")
         tasks = partitioner(cfg)
@@ -118,7 +131,7 @@ def __init__(self, args) -> None:
         super().__init__(args)
         self.judge_model_type = None
 
-    def update_cfg(self, cfg: ConfigDict) -> None:
+    def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
         for dataset_cfg in cfg["datasets"]:
             judge_infer_cfg = dataset_cfg.get("judge_infer_cfg")
             if judge_infer_cfg:
@@ -258,20 +271,28 @@ def _result_post_process(self, tasks, cfg: ConfigDict):
 
 
 class Eval(BaseWorker):
-    def update_cfg(self, cfg: ConfigDict) -> None:
+    def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
+        custom_eval = cfg.get("eval")
+        custom_task = None
+        if custom_eval:
+            custom_task = custom_eval["runner"]["task"].get("type")
+        if custom_task == EmptyTask:
+            self.skip = True
+            return cfg
+
         new_cfg = dict(
             eval=dict(
                 partitioner=dict(type=get_config_type(NaivePartitioner)),
                 runner=dict(
                     max_num_workers=self.args.max_num_workers,
+                    max_workers_per_gpu=self.args.max_workers_per_gpu,
                     debug=self.args.debug,
-                    task=dict(type=get_config_type(OpenICLEvalTask)),
+                    task=dict(type=get_config_type(custom_task) if custom_task else get_config_type(OpenICLEvalTask)),
+                    type=get_config_type(LocalRunner),
                 ),
             ),
         )
-        new_cfg["eval"]["runner"]["type"] = get_config_type(LocalRunner)
-        new_cfg["eval"]["runner"]["max_workers_per_gpu"] = self.args.max_workers_per_gpu
         cfg.merge_from_dict(new_cfg)
         if cfg.cli_args.dump_eval_details:
             cfg.eval.runner.task.dump_details = True
@@ -283,6 +304,9 @@ def update_cfg(self, cfg: ConfigDict) -> None:
         return cfg
 
     def do_work(self, cfg: ConfigDict):
+        if self.skip:
+            logger.info("EmptyTask is selected, skipping evaluation.")
+            return
         partitioner = PARTITIONERS.build(cfg.eval.partitioner)
         logger.info("Starting evaluation tasks...")
         self._cfg_pre_process(cfg)
diff --git a/ais_bench/benchmark/partitioners/base.py b/ais_bench/benchmark/partitioners/base.py
index 5851f167..cbe0e013 100644
--- a/ais_bench/benchmark/partitioners/base.py
+++ b/ais_bench/benchmark/partitioners/base.py
@@ -111,8 +111,8 @@ def _check_task_cfg(self, tasks):
         filtered_tasks = []
         for task in tasks:
             mode = task.get("cli_args", {}).get("mode")
-            dataset_type = task["datasets"][0][0]["type"]
-            model_type = task["models"][0]["type"]
+            dataset_type = task["datasets"][0][0].get("type", None)
+            model_type = task["models"][0].get("type", None)
             if mode not in ["perf", "perf_viz"] and dataset_type in ONLY_PERF_DATASETS:
                 self.logger.warning(
                     f"'{dataset_type}' can only be used for performance evaluation, "
diff --git a/ais_bench/benchmark/partitioners/naive.py b/ais_bench/benchmark/partitioners/naive.py
index 1a156ca9..429d9193 100644
--- a/ais_bench/benchmark/partitioners/naive.py
+++ b/ais_bench/benchmark/partitioners/naive.py
@@ -23,7 +23,7 @@ class NaivePartitioner(BasePartitioner):
     """
 
     def __init__(self,
-                 out_dir: str,
+                 out_dir: str = '',
                  n: int = 1,
                  keep_keys: Optional[List[str]] = None):
         super().__init__(out_dir=out_dir, keep_keys=keep_keys)
@@ -33,7 +33,7 @@ def partition(self,
                   model_dataset_combinations: List[Dict[str, List[ConfigDict]]],
                   work_dir: str,
-                  out_dir: str,
+                  out_dir: str = '',
                   add_cfg: Dict = {}) -> List[Dict]:
         """Partition model-dataset pairs into tasks. Each task is defined as a
         dict and will run independently as a unit. Its structure is as
diff --git a/ais_bench/benchmark/registry.py b/ais_bench/benchmark/registry.py
index 5d1ba970..5faee5b3 100644
--- a/ais_bench/benchmark/registry.py
+++ b/ais_bench/benchmark/registry.py
@@ -53,7 +53,7 @@ def register_module(
 
 PARTITIONERS = Registry('partitioner', locations=get_locations('partitioners'))
 RUNNERS = Registry('runner', locations=get_locations('runners'))
-TASKS = Registry('task', locations=get_locations('tasks'))
+TASKS = Registry('task', locations=get_locations('tasks') + get_locations('tasks.custom_tasks'))
 MODELS = Registry('model', locations=get_locations('models'))
 # TODO: LOAD_DATASET -> DATASETS
 LOAD_DATASET = Registry('load_dataset', locations=get_locations('datasets'))
diff --git a/ais_bench/benchmark/tasks/base.py b/ais_bench/benchmark/tasks/base.py
index 5d7e6d1d..13890baa 100644
--- a/ais_bench/benchmark/tasks/base.py
+++ b/ais_bench/benchmark/tasks/base.py
@@ -134,6 +134,14 @@ def get_output_paths(self, file_extension: str = "json") -> List[str]:
         return output_paths
 
 
+class EmptyTask(BaseTask):
+    def run(self):
+        pass
+
+    def get_command(self, cfg_path, template) -> str:
+        return ""
+
+
 class TaskStateManager:
     def __init__(self, tmp_path: str, task_name: str, is_debug: bool, refresh_interval: int = 0.5):
         self.logger = AISLogger()
diff --git a/ais_bench/benchmark/tasks/custom_tasks/__init__.py b/ais_bench/benchmark/tasks/custom_tasks/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/ais_bench/benchmark/tasks/custom_tasks/tau2_bench_task.py b/ais_bench/benchmark/tasks/custom_tasks/tau2_bench_task.py
new file mode 100644
index 00000000..cbc98642
--- /dev/null
+++ b/ais_bench/benchmark/tasks/custom_tasks/tau2_bench_task.py
@@ -0,0 +1,264 @@
+import argparse
+import copy
+import os
+import json
+import os.path as osp
+import random
+import threading
+import sys
+import time
+from typing import Any
+
+from mmengine.config import Config, ConfigDict
+from mmengine.utils import mkdir_or_exist
+
+from pathlib import Path
+from tqdm import tqdm
+
+from ais_bench.benchmark.registry import (TASKS)
+from ais_bench.benchmark.tasks.base import TaskStateManager
+from ais_bench.benchmark.utils.config import ConfigDict
+from ais_bench.benchmark.utils.logging import AISLogger
+from ais_bench.benchmark.utils.logging.exceptions import AISBenchConfigError
+from ais_bench.benchmark.utils.logging.error_codes import UTILS_CODES
+from ais_bench.benchmark.utils.core.abbr import task_abbr_from_cfg, model_abbr_from_cfg, dataset_abbr_from_cfg
+from ais_bench.benchmark.tasks.base import BaseTask
+
+# ================= Replace litellm's cost-calculation function =================
+import litellm
+import logging
+
+litellm_logger = logging.getLogger("litellm")
+litellm_logger.setLevel(logging.CRITICAL)
+
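+# litellm prices every completion against its built-in model-pricing map; self-hosted models
+# such as "openai/qwen3" are not in that map, so the cost lookup is wrapped to return 0.0
+# instead of raising "This model isn't mapped yet".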
+try:
+    from litellm.utils import get_response_cost as litellm_get_response_cost
+except ImportError:
+    try:
+        from litellm.cost_calculator import get_response_cost as litellm_get_response_cost
+    except ImportError:
+        litellm_get_response_cost = None
+
+
+def patched_get_response_cost(*args, **kwargs):
+    if litellm_get_response_cost is None:
+        return 0.0
+    try:
+        return litellm_get_response_cost(*args, **kwargs)
+    except Exception as e:
+        if "This model isn't mapped yet" in str(e):
+            return 0.0
+        raise e
+
+
+try:
+    litellm.utils.get_response_cost = patched_get_response_cost
+except AttributeError:
+    pass
+try:
+    litellm.cost_calculator.get_response_cost = patched_get_response_cost
+except AttributeError:
+    pass
+# ================= Replace litellm's cost-calculation function =================
+
+DEFAULT_FAKE_API_KEY = "fake_api_key"
+
+from tau2.data_model.simulation import RunConfig
+from tau2.run import run_domain, get_tasks
+from tau2.metrics.agent_metrics import compute_metrics
+
+# ================= Replace tau2's cost-calculation function =================
+import tau2.utils.llm_utils as tau2_llm_utils
+import loguru
+
+_original_tau2_get_response_cost = tau2_llm_utils.get_response_cost
+_original_tau2_logger_error = tau2_llm_utils.logger.error
+
+
+def _patched_logger_error(message, *args, **kwargs):
+    if "This model isn't mapped yet" in str(message):
+        return
+    _original_tau2_logger_error(message, *args, **kwargs)
+
+
+tau2_llm_utils.logger.error = _patched_logger_error
+# ================= Replace tau2's cost-calculation function =================
+
+
+@TASKS.register_module()
+class TAU2BenchTask(BaseTask):
+    name_prefix = "TAU2BenchTask"
+    log_subdir = "logs/eval"
+    output_subdir = "results"
+
+    def __init__(self, cfg: ConfigDict) -> None:
+        super().__init__(cfg)
+        self.captured_metrics = None
+
+    def get_command(self, cfg_path, template) -> str:
+        sys.path.append(os.getcwd())
+        script_path = __file__
+        python = sys.executable
+        return f'{python} {script_path} {cfg_path}'
+
+    def run(self, task_state_manager: TaskStateManager):
+        self.logger.info(f'Task {task_abbr_from_cfg(self.cfg)}')
+        self.task_state_manager: TaskStateManager = task_state_manager
+
+        self._set_api_key()
+
+        self._prepare_out_dir()
+
+        self._refresh_cfg()
+
+        self.run_config: RunConfig = self._construct_run_cfg()
+
+        simulation_results = self._run_with_tqdm()
+
+        self._dump_eval_results(simulation_results)
+
+    def _get_task_count(self, config):
+        if config.task_set_name is None:
+            task_set_name = config.domain
+        else:
+            task_set_name = config.task_set_name
+        tasks = get_tasks(
+            task_set_name=task_set_name,
+            task_split_name=config.task_split_name,
+            task_ids=config.task_ids,
+            num_tasks=config.num_tasks,
+        )
+        return len(tasks)
+
+    def _run_with_tqdm(self):
+        """
+        Display the progress bar while running the simulation.
+        """
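+        # run_domain() writes its intermediate results to "<save_to>.json" as simulations
+        # finish; a daemon thread polls that file so the tqdm bar and the TaskStateManager
+        # stay in sync while the blocking tau2 run is in flight.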
+        self.logger.info(f"Pipeline Execute Config: {self.run_config}")
+        total_tasks = self._get_task_count(self.run_config) * self.run_config.num_trials
+        save_to = f"{self.run_config.save_to}.json"
+        pbar = tqdm(total=total_tasks, desc="Running TAU2 Bench", unit="task")
+        self.task_state_manager.update_task_state(
+            {
+                "status": "running",
+                "total_count": total_tasks,
+                "progress_description": "Running TAU2 Bench",
+                "finish_count": 0,
+            }
+        )
+        completed = 0
+
+        def monitor_file():
+            nonlocal completed
+            while True:
+                if osp.exists(save_to):
+                    with open(save_to, 'r') as f:
+                        data = json.load(f)
+                    new_completed = len(data.get('simulations', []))
+                    if new_completed > completed:
+                        pbar.update(new_completed - completed)
+                        self.task_state_manager.update_task_state(
+                            {
+                                "finish_count": new_completed,
+                            }
+                        )
+                        completed = new_completed
+                time.sleep(0.3)
+                if completed >= total_tasks:
+                    pbar.update(completed - pbar.n)
+                    break
+
+        monitor_thread = threading.Thread(target=monitor_file, daemon=True)
+        monitor_thread.start()
+
+        try:
+            results = run_domain(self.run_config)
+            monitor_thread.join()
+        finally:
+            pbar.update(total_tasks - pbar.n)
+            self.task_state_manager.update_task_state(
+                {
+                    "finish_count": total_tasks,
+                }
+            )
+            pbar.close()
+
+        return results
+
+    def _set_api_key(self):
+        api_key = self.cfg["models"][0].get("api_key")
+        if api_key is None:
+            api_key = DEFAULT_FAKE_API_KEY
+        os.environ["OPENAI_API_KEY"] = api_key
+
+    def _prepare_out_dir(self):
+        self.out_dir = osp.join(self.work_dir, self.output_subdir, self.cfg["models"][0]["abbr"])
+        mkdir_or_exist(osp.join(self.out_dir, self.cfg["datasets"][0][0]["abbr"]))
+        out_detail_json = osp.join(self.out_dir, self.cfg["datasets"][0][0]["abbr"], "tau2_run_detail")
+        if osp.exists(out_detail_json):
+            os.remove(out_detail_json)
+        self.cfg["datasets"][0][0]["args"]["save_to"] = osp.abspath(out_detail_json)
+
+    def _refresh_cfg(self):
+        for key, value in self.cfg["models"][0].items():
+            if key == "type":
+                continue
+            self.cfg["datasets"][0][0]["args"][key] = value
+
+    def _construct_run_cfg(self) -> RunConfig:
+        kwargs = {}
+        for key, value in self.cfg["datasets"][0][0]["args"].items():
+            if value is None:
+                continue
+            kwargs[key] = value
+        self.logger.info(f"Run Config: {kwargs}")
+        run_cfg = RunConfig(**kwargs)
+        return run_cfg
+
+    def _dump_eval_results(self, simulation_results):
+        self.captured_metrics = compute_metrics(simulation_results)
+        if self.captured_metrics is None:
+            self.logger.error("No metrics captured. Please check the Tau2 run.")
+            return
+        out_json = osp.join(f"{self.out_dir}", f"{self.cfg['datasets'][0][0]['abbr']}.json")
+        results = {"accuracy": 100 * self.captured_metrics.avg_reward}
+        with open(out_json, "w") as f:
+            json.dump(results, f, indent=4)
+        self.logger.info(f"Evaluation results saved to {out_json}")
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Model Inferencer')
+    parser.add_argument('config', help='Config file path')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    logger = AISLogger(__name__)
+    args = parse_args()
+    cfg = Config.fromfile(args.config)
+    task_state_manager = TaskStateManager(
+        tmp_path=os.path.join(cfg["work_dir"], "status_tmp"),
+        task_name=task_abbr_from_cfg(cfg),
+        is_debug=cfg["cli_args"]["debug"],
+    )
+    manager_t = threading.Thread(
+        target=task_state_manager.launch,
+        args=()
+    )
+    manager_t.start()
+    task_state_manager.update_task_state(
+        {
+            "status": "start",
+            "task_log_path": os.path.join(TAU2BenchTask.log_subdir, f"{task_abbr_from_cfg(cfg)}.out"),
+        }
+    )
+    start_time = time.perf_counter()
+    try:
+        inferencer: TAU2BenchTask = TAU2BenchTask(cfg)
+        inferencer.run(task_state_manager)
+    except Exception as e:
+        task_state_manager.update_task_state({"status": "error"})
+        raise e
+
+    end_time = time.perf_counter()
+    logger.info(f'Local infer task time elapsed: {end_time - start_time:.2f}s')
+    task_state_manager.update_task_state({"status": "finish"})
+    manager_t.join()
\ No newline at end of file
diff --git a/ais_bench/benchmark/utils/prompt/prompt.py b/ais_bench/benchmark/utils/prompt/prompt.py
index b485854d..610e0923 100644
--- a/ais_bench/benchmark/utils/prompt/prompt.py
+++ b/ais_bench/benchmark/utils/prompt/prompt.py
@@ -61,6 +61,8 @@ def get_prompt_hash(dataset_cfg: Union[ConfigDict, List[ConfigDict]]) -> str:
         hashes = ','.join([get_prompt_hash(cfg) for cfg in dataset_cfg])
         hash_object = hashlib.sha256(hashes.encode())
         return hash_object.hexdigest()
+    if not dataset_cfg.get("infer_cfg"):
+        return "/"
     if 'reader_cfg' in dataset_cfg.infer_cfg:
         # new config
         reader_cfg = dict(type='DatasetReader',
diff --git a/ais_bench/configs/agent_example/tau2_bench_task.py b/ais_bench/configs/agent_example/tau2_bench_task.py
new file mode 100644
index 00000000..85a9d419
--- /dev/null
+++ b/ais_bench/configs/agent_example/tau2_bench_task.py
@@ -0,0 +1,61 @@
+from mmengine.config import read_base
+from ais_bench.benchmark.models import VLLMCustomAPIChat
+from ais_bench.benchmark.tasks.custom_tasks.tau2_bench_task import TAU2BenchTask
+from ais_bench.benchmark.tasks.base import EmptyTask
+
+with read_base():
+    from ais_bench.benchmark.configs.summarizers.example import summarizer
+
+models = [
+    dict(
+        abbr="openai-v1-chat",
+        api_key=None,  # API key; defaults to a placeholder string, the task sets OPENAI_API_KEY internally
+        agent = None,  # agent implementation to use, defaults to DEFAULT_AGENT_IMPLEMENTATION
+        llm_agent = "openai/qwen3",  # LLM used by the agent, defaults to DEFAULT_LLM_AGENT
+        llm_args_agent = {"api_base": "http://localhost:2498/v1", "temperature": 0.5},  # arguments for the agent LLM, default {"temperature": DEFAULT_LLM_TEMPERATURE_AGENT}
+    )
+]
+
+work_dir = 'outputs/default/'
+
+datasets = []
+sub_tasks = ["airline", "retail", "telecom"]
+for task in sub_tasks:
+    datasets.append(
+        dict(
+            abbr=f'tau2_bench_{task}',
+            args = dict(
+                domain = task,  # -d, simulation domain to run; valid values come from get_options().domains: ["airline", "retail", "telecom"]
+                num_trials = 1,  # number of runs per task, defaults to 1
+                # agent = "baseline",  # agent implementation to use, defaults to DEFAULT_AGENT_IMPLEMENTATION
+                # agent_llm = "openai/gpt-4o",  # LLM used by the agent, defaults to DEFAULT_LLM_AGENT
+                # agent_llm_args = {"api_base": "http://localhost:2998/v1", "temperature": 0.0},  # arguments for the agent LLM, default {"temperature": DEFAULT_LLM_TEMPERATURE_AGENT}
+                user = None,  # user simulator implementation to use, defaults to DEFAULT_USER_IMPLEMENTATION
+                llm_user = "openai/qwen3",  # LLM used by the user simulator, defaults to DEFAULT_LLM_USER
+                llm_args_user = {"api_base": "http://localhost:2498/v1", "temperature": 1.0},  # arguments for the user LLM, default {"temperature": DEFAULT_LLM_TEMPERATURE_USER}
+                task_set_name = None,  # task set to run; if not provided, the domain's default task set is loaded
+                task_split_name = None,  # task split to run, defaults to 'base'
+                task_ids = None,  # optional, only run the tasks with the given IDs
+                num_tasks = 5,  # number of tasks to run
+                max_steps = None,  # maximum number of simulation steps, defaults to DEFAULT_MAX_STEPS
+                max_errors = None,  # maximum number of consecutive tool errors allowed in a simulation, defaults to DEFAULT_MAX_ERRORS
+                # save_to = None,  # save path for the simulation results, stored under data/simulations/.json
+                max_concurrency = 5,  # maximum number of simulations run concurrently, defaults to DEFAULT_MAX_CONCURRENCY
+                seed = None,  # random seed for the simulation, defaults to DEFAULT_SEED
+                log_level = "INFO",  # log level for the simulation, defaults to DEFAULT_LOG_LEVEL
+                enforce_communication_protocol = False,  # whether to enforce the communication protocol rules, defaults to False
+            ),
+        )
+    )
+
+infer = dict(
+    runner=dict(
+        task=dict(type=EmptyTask)
+    ),
+)
+
+eval = dict(
+    runner=dict(
+        task=dict(type=TAU2BenchTask)
+    ),
+)
diff --git a/requirements/runtime.txt b/requirements/runtime.txt
index 66c22508..c317918e 100644
--- a/requirements/runtime.txt
+++ b/requirements/runtime.txt
@@ -35,7 +35,7 @@ rouge
 rouge_chinese
 rouge_score
 sacrebleu
-scikit_learn==1.5.0
+scikit_learn>=1.5.0
 scipy
 seaborn
 tabulate
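
For reference, TAU2BenchTask boils down to building a tau2 RunConfig from the dataset "args"
(plus the model fields merged in by _refresh_cfg), handing it to run_domain(), and reporting
100 * avg_reward from compute_metrics(). A rough standalone sketch of that flow, assuming an
OpenAI-compatible endpoint at http://localhost:2498/v1 and an illustrative save path:

    import os

    from tau2.data_model.simulation import RunConfig
    from tau2.run import run_domain
    from tau2.metrics.agent_metrics import compute_metrics

    # Mirror _set_api_key(): a placeholder key is exported when none is configured.
    os.environ["OPENAI_API_KEY"] = "fake_api_key"

    run_cfg = RunConfig(
        domain="airline",
        num_trials=1,
        num_tasks=5,
        llm_agent="openai/qwen3",
        llm_args_agent={"api_base": "http://localhost:2498/v1", "temperature": 0.5},
        llm_user="openai/qwen3",
        llm_args_user={"api_base": "http://localhost:2498/v1", "temperature": 1.0},
        max_concurrency=5,
        save_to="outputs/tau2_run_detail",  # illustrative; the AIS task passes an absolute path here
        log_level="INFO",
    )
    results = run_domain(run_cfg)
    metrics = compute_metrics(results)
    print(f"accuracy: {100 * metrics.avg_reward:.2f}")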