Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ais_bench/benchmark/cli/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need, fill_test_range_use_num_prompts

class CustomConfigChecker:
MODEL_REQUIRED_FIELDS = ['type', 'abbr', 'attr']
DATASET_REQUIRED_FIELDS = ['type', 'abbr', 'reader_cfg', 'infer_cfg', 'eval_cfg']
MODEL_REQUIRED_FIELDS = ['abbr']
DATASET_REQUIRED_FIELDS = ['abbr']
SUMMARIZER_REQUIRED_FIELDS = ['attr']

def __init__(self, config, file_path):
Expand Down Expand Up @@ -106,6 +106,8 @@ def load_config(self, workflow):

def _fill_dataset_configs(self):
for dataset_cfg in self.cfg["datasets"]:
if dataset_cfg.get("infer_cfg", None) is None:
continue
fill_test_range_use_num_prompts(self.cfg["cli_args"].get("num_prompts"), dataset_cfg)
fill_model_path_if_datasets_need(self.cfg["models"][0], dataset_cfg)
retriever_cfg = dataset_cfg["infer_cfg"]["retriever"]
Expand Down
2 changes: 2 additions & 0 deletions ais_bench/benchmark/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
logger = AISLogger()

def get_config_type(obj) -> str:
if obj is None:
return None
Comment on lines +19 to +20
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

While adding a None check is a good improvement for robustness, the function's return type hint -> str on line 18 is now incorrect because the function can return None. Please update the signature to -> Optional[str] to accurately reflect its behavior. You will also need to add from typing import Optional at the top of the file.

if isinstance(obj, str):
return obj
return f"{obj.__module__}.{obj.__name__}"
Expand Down
40 changes: 32 additions & 8 deletions ais_bench/benchmark/cli/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ais_bench.benchmark.partitioners import NaivePartitioner
from ais_bench.benchmark.runners import LocalRunner
from ais_bench.benchmark.tasks import OpenICLEvalTask, OpenICLApiInferTask, OpenICLInferTask
from ais_bench.benchmark.tasks.base import EmptyTask
from ais_bench.benchmark.summarizers import DefaultSummarizer, DefaultPerfSummarizer
from ais_bench.benchmark.calculators import DefaultPerfMetricCalculator
from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need
Expand All @@ -26,6 +27,7 @@
class BaseWorker(ABC):
def __init__(self, args) -> None:
self.args = args
self.skip = False

@abstractmethod
def update_cfg(self, cfg: ConfigDict) -> None:
Expand All @@ -39,21 +41,29 @@ def do_work(self, cfg: ConfigDict):


class Infer(BaseWorker):
def update_cfg(self, cfg: ConfigDict) -> None:
def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
def get_task_type() -> str:
if cfg["models"][0]["attr"] == "service":
return get_config_type(OpenICLApiInferTask)
else:
return get_config_type(OpenICLInferTask)

custom_infer = cfg.get("infer")
custom_task = None
if custom_infer:
custom_task = custom_infer["runner"]["task"].get("type")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Accessing nested dictionary keys directly without checking for their existence can lead to a KeyError. custom_infer['runner'] or custom_infer['runner']['task'] could fail if these keys are not present in the configuration. You should use .get() with default values for safer access.

Suggested change
custom_task = custom_infer["runner"]["task"].get("type")
custom_task = custom_infer.get("runner", {}).get("task", {}).get("type")

if custom_task == EmptyTask:
self.skip = True
return cfg

new_cfg = dict(
infer=dict(
partitioner=dict(type=get_config_type(NaivePartitioner)),
partitioner= dict(type=get_config_type(NaivePartitioner)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There is an extra space after partitioner=. Please remove it for consistent formatting.

Suggested change
partitioner= dict(type=get_config_type(NaivePartitioner)),
partitioner=dict(type=get_config_type(NaivePartitioner)),

runner=dict(
max_num_workers=self.args.max_num_workers,
max_workers_per_gpu=self.args.max_workers_per_gpu,
debug=self.args.debug,
task=dict(type=get_task_type()),
task=dict(type=get_config_type(custom_task) if custom_task else get_task_type()),
type=get_config_type(LocalRunner),
),
),
Expand All @@ -66,6 +76,9 @@ def get_task_type() -> str:
return cfg

def do_work(self, cfg: ConfigDict):
if self.skip:
logger.info("EmptyTask is selected, skip inference.")
return
partitioner = PARTITIONERS.build(cfg.infer.partitioner)
logger.info("Starting inference tasks...")
tasks = partitioner(cfg)
Expand Down Expand Up @@ -118,7 +131,7 @@ def __init__(self, args) -> None:
super().__init__(args)
self.judge_model_type = None

def update_cfg(self, cfg: ConfigDict) -> None:
def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
for dataset_cfg in cfg["datasets"]:
judge_infer_cfg = dataset_cfg.get("judge_infer_cfg")
if judge_infer_cfg:
Expand Down Expand Up @@ -258,20 +271,28 @@ def _result_post_process(self, tasks, cfg: ConfigDict):


class Eval(BaseWorker):
def update_cfg(self, cfg: ConfigDict) -> None:
def update_cfg(self, cfg: ConfigDict) -> ConfigDict:
custom_eval = cfg.get("eval")
custom_task = None
if custom_eval:
custom_task = custom_eval["runner"]["task"].get("type")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similar to the Infer class, accessing nested dictionary keys directly can cause a KeyError if runner or task keys are missing in the eval configuration. Please use .get() for safe access to prevent potential crashes.

Suggested change
custom_task = custom_eval["runner"]["task"].get("type")
custom_task = custom_eval.get("runner", {}).get("task", {}).get("type")

if custom_task == EmptyTask:
self.skip = True
return cfg

new_cfg = dict(
eval=dict(
partitioner=dict(type=get_config_type(NaivePartitioner)),
runner=dict(
max_num_workers=self.args.max_num_workers,
max_workers_per_gpu=self.args.max_workers_per_gpu,
debug=self.args.debug,
task=dict(type=get_config_type(OpenICLEvalTask)),
task=dict(type=get_config_type(custom_task) if custom_task else get_config_type(OpenICLEvalTask)),
type=get_config_type(LocalRunner),
),
),
)

new_cfg["eval"]["runner"]["type"] = get_config_type(LocalRunner)
new_cfg["eval"]["runner"]["max_workers_per_gpu"] = self.args.max_workers_per_gpu
cfg.merge_from_dict(new_cfg)
if cfg.cli_args.dump_eval_details:
cfg.eval.runner.task.dump_details = True
Expand All @@ -283,6 +304,9 @@ def update_cfg(self, cfg: ConfigDict) -> None:
return cfg

def do_work(self, cfg: ConfigDict):
if self.skip:
logger.info("EmptyTask is selected, skip evaluation.")
return
partitioner = PARTITIONERS.build(cfg.eval.partitioner)
logger.info("Starting evaluation tasks...")
self._cfg_pre_process(cfg)
Expand Down
4 changes: 2 additions & 2 deletions ais_bench/benchmark/partitioners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def _check_task_cfg(self, tasks):
filtered_tasks = []
for task in tasks:
mode = task.get("cli_args", {}).get("mode")
dataset_type = task["datasets"][0][0]["type"]
model_type = task["models"][0]["type"]
dataset_type = task["datasets"][0][0].get("type", None)
model_type = task["models"][0].get("type", None)
if mode not in ["perf", "perf_viz"] and dataset_type in ONLY_PERF_DATASETS:
self.logger.warning(
f"'{dataset_type}' can only be used for performance evaluation, "
Expand Down
4 changes: 2 additions & 2 deletions ais_bench/benchmark/partitioners/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class NaivePartitioner(BasePartitioner):
"""

def __init__(self,
out_dir: str,
out_dir: str = '',
n: int = 1,
keep_keys: Optional[List[str]] = None):
super().__init__(out_dir=out_dir, keep_keys=keep_keys)
Expand All @@ -33,7 +33,7 @@ def partition(self,
model_dataset_combinations: List[Dict[str,
List[ConfigDict]]],
work_dir: str,
out_dir: str,
out_dir: str = '',
add_cfg: Dict = {}) -> List[Dict]:
"""Partition model-dataset pairs into tasks. Each task is defined as a
dict and will run independently as a unit. Its structure is as
Expand Down
2 changes: 1 addition & 1 deletion ais_bench/benchmark/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def register_module(

PARTITIONERS = Registry('partitioner', locations=get_locations('partitioners'))
RUNNERS = Registry('runner', locations=get_locations('runners'))
TASKS = Registry('task', locations=get_locations('tasks'))
TASKS = Registry('task', locations=get_locations('tasks') + get_locations('tasks.custom_tasks'))
MODELS = Registry('model', locations=get_locations('models'))
# TODO: LOAD_DATASET -> DATASETS
LOAD_DATASET = Registry('load_dataset', locations=get_locations('datasets'))
Expand Down
8 changes: 8 additions & 0 deletions ais_bench/benchmark/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,14 @@ def get_output_paths(self, file_extension: str = "json") -> List[str]:
return output_paths


class EmptyTask(BaseTask):
    """Task placeholder that performs no work.

    ``run`` does nothing and ``get_command`` yields an empty command string.
    NOTE(review): presumably selected in a config so that a stage can be
    skipped -- confirm against the callers.
    """

    def run(self) -> None:
        """Do nothing."""
        return None

    def get_command(self, cfg_path, template) -> str:
        """Return an empty command string; both arguments are ignored."""
        command = ""
        return command


class TaskStateManager:
def __init__(self, tmp_path: str, task_name: str, is_debug: bool, refresh_interval: int = 0.5):
self.logger = AISLogger()
Expand Down
Empty file.
Loading
Loading