-
Notifications
You must be signed in to change notification settings - Fork 19
[feature]Support tau2 bench #192
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -15,6 +15,7 @@ | |||||
| from ais_bench.benchmark.partitioners import NaivePartitioner | ||||||
| from ais_bench.benchmark.runners import LocalRunner | ||||||
| from ais_bench.benchmark.tasks import OpenICLEvalTask, OpenICLApiInferTask, OpenICLInferTask | ||||||
| from ais_bench.benchmark.tasks.base import EmptyTask | ||||||
| from ais_bench.benchmark.summarizers import DefaultSummarizer, DefaultPerfSummarizer | ||||||
| from ais_bench.benchmark.calculators import DefaultPerfMetricCalculator | ||||||
| from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need | ||||||
|
|
@@ -26,6 +27,7 @@ | |||||
| class BaseWorker(ABC): | ||||||
| def __init__(self, args) -> None: | ||||||
| self.args = args | ||||||
| self.skip = False | ||||||
|
|
||||||
| @abstractmethod | ||||||
| def update_cfg(self, cfg: ConfigDict) -> None: | ||||||
|
|
@@ -39,21 +41,29 @@ def do_work(self, cfg: ConfigDict): | |||||
|
|
||||||
|
|
||||||
| class Infer(BaseWorker): | ||||||
| def update_cfg(self, cfg: ConfigDict) -> None: | ||||||
| def update_cfg(self, cfg: ConfigDict) -> ConfigDict: | ||||||
| def get_task_type() -> str: | ||||||
| if cfg["models"][0]["attr"] == "service": | ||||||
| return get_config_type(OpenICLApiInferTask) | ||||||
| else: | ||||||
| return get_config_type(OpenICLInferTask) | ||||||
|
|
||||||
| custom_infer = cfg.get("infer") | ||||||
| custom_task = None | ||||||
| if custom_infer: | ||||||
| custom_task = custom_infer["runner"]["task"].get("type") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Accessing nested dictionary keys directly without checking for their existence can lead to a
Suggested change
|
||||||
| if custom_task == EmptyTask: | ||||||
| self.skip = True | ||||||
| return cfg | ||||||
|
|
||||||
| new_cfg = dict( | ||||||
| infer=dict( | ||||||
| partitioner=dict(type=get_config_type(NaivePartitioner)), | ||||||
| partitioner= dict(type=get_config_type(NaivePartitioner)), | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||
| runner=dict( | ||||||
| max_num_workers=self.args.max_num_workers, | ||||||
| max_workers_per_gpu=self.args.max_workers_per_gpu, | ||||||
| debug=self.args.debug, | ||||||
| task=dict(type=get_task_type()), | ||||||
| task=dict(type=get_config_type(custom_task) if custom_task else get_task_type()), | ||||||
| type=get_config_type(LocalRunner), | ||||||
| ), | ||||||
| ), | ||||||
|
|
@@ -66,6 +76,9 @@ def get_task_type() -> str: | |||||
| return cfg | ||||||
|
|
||||||
| def do_work(self, cfg: ConfigDict): | ||||||
| if self.skip: | ||||||
| logger.info("EmptyTask is selected, skip inference.") | ||||||
| return | ||||||
| partitioner = PARTITIONERS.build(cfg.infer.partitioner) | ||||||
| logger.info("Starting inference tasks...") | ||||||
| tasks = partitioner(cfg) | ||||||
|
|
@@ -118,7 +131,7 @@ def __init__(self, args) -> None: | |||||
| super().__init__(args) | ||||||
| self.judge_model_type = None | ||||||
|
|
||||||
| def update_cfg(self, cfg: ConfigDict) -> None: | ||||||
| def update_cfg(self, cfg: ConfigDict) -> ConfigDict: | ||||||
| for dataset_cfg in cfg["datasets"]: | ||||||
| judge_infer_cfg = dataset_cfg.get("judge_infer_cfg") | ||||||
| if judge_infer_cfg: | ||||||
|
|
@@ -258,20 +271,28 @@ def _result_post_process(self, tasks, cfg: ConfigDict): | |||||
|
|
||||||
|
|
||||||
| class Eval(BaseWorker): | ||||||
| def update_cfg(self, cfg: ConfigDict) -> None: | ||||||
| def update_cfg(self, cfg: ConfigDict) -> ConfigDict: | ||||||
| custom_eval = cfg.get("eval") | ||||||
| custom_task = None | ||||||
| if custom_eval: | ||||||
| custom_task = custom_eval["runner"]["task"].get("type") | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the
Suggested change
|
||||||
| if custom_task == EmptyTask: | ||||||
| self.skip = True | ||||||
| return cfg | ||||||
|
|
||||||
| new_cfg = dict( | ||||||
| eval=dict( | ||||||
| partitioner=dict(type=get_config_type(NaivePartitioner)), | ||||||
| runner=dict( | ||||||
| max_num_workers=self.args.max_num_workers, | ||||||
| max_workers_per_gpu=self.args.max_workers_per_gpu, | ||||||
| debug=self.args.debug, | ||||||
| task=dict(type=get_config_type(OpenICLEvalTask)), | ||||||
| task=dict(type=get_config_type(custom_task) if custom_task else get_config_type(OpenICLEvalTask)), | ||||||
| type=get_config_type(LocalRunner), | ||||||
| ), | ||||||
| ), | ||||||
| ) | ||||||
|
|
||||||
| new_cfg["eval"]["runner"]["type"] = get_config_type(LocalRunner) | ||||||
| new_cfg["eval"]["runner"]["max_workers_per_gpu"] = self.args.max_workers_per_gpu | ||||||
| cfg.merge_from_dict(new_cfg) | ||||||
| if cfg.cli_args.dump_eval_details: | ||||||
| cfg.eval.runner.task.dump_details = True | ||||||
|
|
@@ -283,6 +304,9 @@ def update_cfg(self, cfg: ConfigDict) -> None: | |||||
| return cfg | ||||||
|
|
||||||
| def do_work(self, cfg: ConfigDict): | ||||||
| if self.skip: | ||||||
| logger.info("EmptyTask is selected, skip evaluation.") | ||||||
| return | ||||||
| partitioner = PARTITIONERS.build(cfg.eval.partitioner) | ||||||
| logger.info("Starting evaluation tasks...") | ||||||
| self._cfg_pre_process(cfg) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
While adding a
Nonecheck is a good improvement for robustness, the function's return type hint-> stron line 18 is now incorrect because the function can returnNone. Please update the signature to-> Optional[str]to accurately reflect its behavior. You will also need to addfrom typing import Optionalat the top of the file.