Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ais_bench/benchmark/cli/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from ais_bench.benchmark.utils.file import match_cfg_file
from ais_bench.benchmark.utils.config.run import try_fill_in_custom_cfgs
from ais_bench.benchmark.utils.logging.exceptions import CommandError, AISBenchConfigError
from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need, fill_test_range_use_num_prompts
from ais_bench.benchmark.cli.utils import fill_model_path_if_datasets_need, fill_test_range_use_num_prompts, recur_convert_config_type

class CustomConfigChecker:
MODEL_REQUIRED_FIELDS = ['type', 'abbr', 'attr']
DATASET_REQUIRED_FIELDS = ['type', 'abbr', 'reader_cfg', 'infer_cfg', 'eval_cfg']
DATASET_REQUIRED_FIELDS = ['type', 'abbr']
SUMMARIZER_REQUIRED_FIELDS = ['attr']

def __init__(self, config, file_path):
Expand Down Expand Up @@ -327,6 +327,8 @@ def _dump_and_reload_config(self):
# dump config
output_config_path = osp.join(self.cfg.work_dir, 'configs',
f'{self.cfg_time_str}_{os.getpid()}.py')

recur_convert_config_type(self.cfg)
self.cfg.dump(output_config_path)
# eval nums set
if (self.args.num_prompts and self.args.num_prompts < 0) or self.args.num_prompts == 0:
Expand Down
20 changes: 20 additions & 0 deletions ais_bench/benchmark/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from datetime import datetime

from mmengine.config import ConfigDict, Config
from ais_bench.benchmark.utils.logging.exceptions import AISBenchConfigError
from ais_bench.benchmark.utils.logging.logger import AISLogger
from ais_bench.benchmark.utils.logging.error_codes import UTILS_CODES
Expand All @@ -20,6 +21,25 @@ def get_config_type(obj) -> str:
return obj
return f"{obj.__module__}.{obj.__name__}"

def recur_convert_config_type(cfg):
    """Recursively convert every ``type`` entry of a config to a string.

    Walks dict-like config nodes and lists in place, replacing each value
    stored under a ``"type"`` key with its string form via
    ``get_config_type`` so the config can be dumped to a file.

    Args:
        cfg: The config node to convert — a dict/ConfigDict/Config, a list,
            or a leaf value.

    Returns:
        The same node, with containers mutated in place; leaf values are
        returned unchanged.
    """
    if isinstance(cfg, (dict, ConfigDict, Config)):
        for key, value in cfg.items():
            if key == "type":
                # Only the `type` field holds a class/function to stringify.
                cfg[key] = get_config_type(value)
            else:
                cfg[key] = recur_convert_config_type(value)
    elif isinstance(cfg, list):
        # The recursive call returns non-container items unchanged, so no
        # per-item type check is needed here.
        for i, item in enumerate(cfg):
            cfg[i] = recur_convert_config_type(item)
    return cfg


def get_current_time_str():
return datetime.now().strftime("%Y%m%d_%H%M%S")
Expand Down
72 changes: 39 additions & 33 deletions ais_bench/benchmark/cli/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,26 +42,30 @@ class Infer(BaseWorker):
def update_cfg(self, cfg: ConfigDict) -> None:
def get_task_type() -> str:
if cfg["models"][0]["attr"] == "service":
return get_config_type(OpenICLApiInferTask)
return OpenICLApiInferTask
else:
return get_config_type(OpenICLInferTask)
return OpenICLInferTask

new_cfg = dict(
infer=dict(
partitioner=dict(type=get_config_type(NaivePartitioner)),
runner=dict(
max_num_workers=self.args.max_num_workers,
max_workers_per_gpu=self.args.max_workers_per_gpu,
debug=self.args.debug,
task=dict(type=get_task_type()),
type=get_config_type(LocalRunner),
),
),
)
def update_new_infer_cfg(new_cfg: ConfigDict) -> None:
runner_cfg = new_cfg['infer']['runner']
runner_cfg['max_num_workers'] = self.args.max_num_workers
runner_cfg['max_workers_per_gpu'] = self.args.max_workers_per_gpu
runner_cfg['debug'] = self.args.debug or cfg.cli_args.debug

if cfg.get('infer'):
new_cfg = dict(infer=cfg.infer)
else:
new_cfg = dict(
infer=dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
task=dict(type=get_task_type()),
type=LocalRunner,
),
),
)
update_new_infer_cfg(new_cfg)
cfg.merge_from_dict(new_cfg)
if cfg.cli_args.debug:
cfg.infer.runner.debug = True
cfg.infer.partitioner["out_dir"] = osp.join(cfg["work_dir"], "predictions/")
return cfg

Expand Down Expand Up @@ -259,26 +263,28 @@ def _result_post_process(self, tasks, cfg: ConfigDict):

class Eval(BaseWorker):
def update_cfg(self, cfg: ConfigDict) -> None:
new_cfg = dict(
eval=dict(
partitioner=dict(type=get_config_type(NaivePartitioner)),
runner=dict(
max_num_workers=self.args.max_num_workers,
debug=self.args.debug,
task=dict(type=get_config_type(OpenICLEvalTask)),
def update_eval_cfg(new_cfg: ConfigDict) -> None:
runner_cfg = new_cfg['eval']['runner']
runner_cfg['max_num_workers'] = self.args.max_num_workers
runner_cfg['max_workers_per_gpu'] = self.args.max_workers_per_gpu
runner_cfg['debug'] = self.args.debug
runner_cfg['dump_details'] = cfg.cli_args.dump_eval_details
runner_cfg['cal_extract_rate'] = cfg.cli_args.dump_extract_rate

if cfg.get('eval'):
new_cfg = dict(eval=cfg.eval)
else:
new_cfg = dict(
eval=dict(
partitioner=dict(type=NaivePartitioner),
runner=dict(
type=LocalRunner,
task=dict(type=OpenICLEvalTask),
),
),
)
))

new_cfg["eval"]["runner"]["type"] = get_config_type(LocalRunner)
new_cfg["eval"]["runner"]["max_workers_per_gpu"] = self.args.max_workers_per_gpu
update_eval_cfg(new_cfg)
cfg.merge_from_dict(new_cfg)
if cfg.cli_args.dump_eval_details:
cfg.eval.runner.task.dump_details = True
if cfg.cli_args.dump_extract_rate:
cfg.eval.runner.task.cal_extract_rate = True
if cfg.cli_args.debug:
cfg.eval.runner.debug = True
cfg.eval.partitioner["out_dir"] = osp.join(cfg["work_dir"], "results/")
return cfg

Expand Down
1 change: 1 addition & 0 deletions ais_bench/benchmark/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,4 @@
from ais_bench.benchmark.datasets.mmstar import * # noqa: F401, F403
from ais_bench.benchmark.datasets.dapo_math import * # noqa: F401, F403
from ais_bench.benchmark.datasets.mooncake_trace import * # noqa: F401, F403
from ais_bench.benchmark.datasets.swebench import * # noqa: F401, F403
69 changes: 69 additions & 0 deletions ais_bench/benchmark/datasets/swebench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import re
import random
from datasets import load_dataset, Dataset, DatasetDict

from ais_bench.benchmark.registry import LOAD_DATASET
from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError
from ais_bench.benchmark.utils.logging.error_codes import DSET_CODES
from ais_bench.benchmark.datasets.base import BaseDataset

DATASET_MAPPING = {
"full": "princeton-nlp/SWE-Bench",
"verified": "princeton-nlp/SWE-Bench_Verified",
"lite": "princeton-nlp/SWE-Bench_Lite",
"multimodal": "princeton-nlp/SWE-Bench_Multimodal",
"multilingual": "swe-bench/SWE-Bench_Multilingual",
}


@LOAD_DATASET.register_module()
class SWEBenchDataset(BaseDataset):
    """Loader for SWE-bench datasets (local parquet file or Hugging Face hub)."""

    def filter_instances(
        self, instances: list[dict], *, filter_spec: str, shuffle: bool = False
    ) -> list[dict]:
        """Filter and slice a list of SWEBench instances.

        Args:
            instances: Raw instance records, each with an ``instance_id``.
            filter_spec: Regex matched with ``re.match`` against each
                ``instance_id``; an empty spec matches everything.
            shuffle: If True, sort by ``instance_id`` and then shuffle with a
                fixed seed so the order is reproducible.

        Returns:
            The filtered (and possibly shuffled) list of instances.
        """
        if shuffle:
            # sorted() already returns a new list, so no defensive copy is
            # needed; sorting first makes the seeded shuffle reproducible
            # regardless of the incoming order.
            instances = sorted(instances, key=lambda x: x["instance_id"])
            random.seed(42)
            random.shuffle(instances)
        before_filter = len(instances)
        instances = [
            instance
            for instance in instances
            if re.match(filter_spec, instance["instance_id"])
        ]
        if (after_filter := len(instances)) != before_filter:
            self.logger.info(
                f"Instance filter: {before_filter} -> {after_filter} instances"
            )
        return instances

    def load(
        self,
        path: str,
        name: str,
        split: str = "test",
        filter_spec: str = "",
        shuffle: bool = False,
        **kwargs,
    ):
        """Load a SWE-bench dataset, preferring a local parquet file.

        Falls back to downloading the named dataset from the Hugging Face
        hub when the local load fails.

        Args:
            path: Path to a local parquet file with the dataset records.
            name: One of the keys in ``DATASET_MAPPING``.
            split: Split name to load (default ``"test"``).
            filter_spec: Regex forwarded to :meth:`filter_instances`.
            shuffle: Forwarded to :meth:`filter_instances`.

        Raises:
            ParameterValueError: If ``name`` is unknown, or both the local
                and hub loads fail.
        """
        if name not in DATASET_MAPPING:
            raise ParameterValueError(
                DSET_CODES.INVALID_PARAM_VALUE,
                f"Invalid swebench dataset name, expected one of {list(DATASET_MAPPING.keys())} but got {name}",
            )
        try:
            # load_dataset(..., data_files=...) returns a DatasetDict; select
            # the requested split so list(dataset) below yields records, not
            # split names.
            dataset = load_dataset("parquet", data_files={split: path})[split]
        except Exception as e:
            self.logger.warning(
                f"Failed to load swebench dataset {name} from {path} with error: {e}, trying to load from Hugging Face"
            )
            try:
                dataset = load_dataset(DATASET_MAPPING[name], split=split)
            except Exception as e:
                raise ParameterValueError(
                    DSET_CODES.DATA_PREPROCESSING_ERROR,
                    f"Failed to load swebench dataset {name} from Hugging Face with error: {e}.",
                )
        dataset = self.filter_instances(list(dataset), filter_spec=filter_spec, shuffle=shuffle)
        return Dataset.from_list(dataset)
3 changes: 3 additions & 0 deletions ais_bench/benchmark/tasks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from ais_bench.benchmark.tasks.openicl_eval import * # noqa: F401, F403
from ais_bench.benchmark.tasks.openicl_infer import * # noqa: F401, F403
from ais_bench.benchmark.tasks.openicl_api_infer import OpenICLApiInferTask
from ais_bench.benchmark.tasks.swebench_infer import SWEBenchInferTask
from ais_bench.benchmark.tasks.swebench_eval import SWEBenchEvalTask

159 changes: 159 additions & 0 deletions ais_bench/benchmark/tasks/swebench_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import argparse
import json
import os
import os.path as osp
import sys
import threading
import time

from mmengine.config import Config, ConfigDict
from mmengine.utils import mkdir_or_exist

from ais_bench.benchmark.registry import TASKS
from ais_bench.benchmark.tasks.base import BaseTask, TaskStateManager
from ais_bench.benchmark.utils.core.abbr import (
get_infer_output_path,
task_abbr_from_cfg,
)
from ais_bench.benchmark.utils.logging import AISLogger


@TASKS.register_module()
class SWEBenchEvalTask(BaseTask):
    """SWEBench Evaluation Task.

    Evaluates SWE-bench predictions using the official harness and writes
    results to work_dir/results.
    """

    # Identifiers consumed by BaseTask — presumably for task naming and for
    # locating log/output directories under work_dir; confirm in BaseTask.
    name_prefix = "SWEBenchEval"
    log_subdir = "logs/eval"
    output_subdir = "results"

    def __init__(self, cfg: ConfigDict):
        super().__init__(cfg)

    def get_command(self, cfg_path: str, template: str) -> str:
        """Build the shell command that re-runs this module as a script.

        Args:
            cfg_path: Path to the dumped config, passed as argv[1].
            template: Command template containing a ``{task_cmd}`` placeholder.

        Returns:
            The formatted command string.
        """
        # Make the current working directory importable for the subprocess.
        sys.path.append(os.getcwd())
        script_path = __file__
        python = sys.executable
        command = f"{python} {script_path} {cfg_path}"
        return template.format(task_cmd=command)

    def run(self, task_state_manager: TaskStateManager):
        """Evaluate predictions with the SWE-bench harness.

        Reads the jsonl predictions produced by the infer stage, runs the
        official harness, and writes a summary JSON under
        ``work_dir/results``.

        Args:
            task_state_manager: Manager used to report progress and status.

        Raises:
            FileNotFoundError: If the predictions file does not exist.
            ImportError: If the SWE-bench harness package is not installed.
        """
        self.task_state_manager = task_state_manager
        self.logger.info("SWEBenchEvalTask %s", task_abbr_from_cfg(self.cfg))

        # NOTE(review): only the first dataset config is evaluated — confirm
        # tasks are partitioned one dataset per task.
        dataset_cfg = self.dataset_cfgs[0]
        dataset_name = dataset_cfg.get("name", "lite")

        # Predictions written by the infer stage for this model/dataset pair.
        pred_path = get_infer_output_path(
            self.model_cfg,
            dataset_cfg,
            osp.join(self.work_dir, "predictions"),
            file_extension="jsonl",
        )
        if not osp.isfile(pred_path):
            raise FileNotFoundError(
                f"Predictions file not found: {pred_path}. Run infer first."
            )

        out_path = get_infer_output_path(
            self.model_cfg,
            dataset_cfg,
            osp.join(self.work_dir, self.output_subdir),
            file_extension="json",
        )
        mkdir_or_exist(osp.dirname(out_path))

        task_state_manager.update_task_state(
            {"status": "eval", "progress_description": "SWE-bench harness"}
        )

        # Import lazily so the harness is only required when actually
        # evaluating, not when the package is merely imported.
        try:
            import swebench.harness.run_evaluation as run_eval
        except ImportError as e:
            raise ImportError(
                "SWEBenchEvalTask requires the SWE-bench harness. "
                "Install from: https://github.com/princeton-nlp/SWE-bench"
            ) from e

        # Flatten "/" so the run id is safe to embed in file paths
        # (presumed harness requirement — confirm).
        run_id = task_abbr_from_cfg(self.cfg).replace("/", "_")
        eval_runner = self.cfg.get("eval", {}).get("runner", {})
        max_workers = eval_runner.get("max_num_workers", 4)
        report_dir = osp.dirname(out_path)

        try:
            run_eval.main(
                dataset_name=dataset_name,
                split="test",
                instance_ids=[],
                predictions_path=pred_path,
                max_workers=max_workers,
                force_rebuild=False,
                cache_level="env",
                clean=False,
                open_file_limit=4096,
                run_id=run_id,
                timeout=1800,
                namespace=None,
                rewrite_reports=False,
                modal=False,
                report_dir=report_dir,
            )
            harness_exit = 0
        except SystemExit as e:
            # The harness may call sys.exit(); treat a None code as failure.
            harness_exit = e.code if e.code is not None else 1
        except Exception as e:
            # Best-effort: record the failure in the summary file below
            # instead of aborting the whole task.
            self.logger.exception("Harness failed: %s", e)
            harness_exit = 1

        # Persist a small summary; the harness writes its own detailed
        # report under report_dir.
        results = {
            "harness_exit_code": harness_exit,
            "dataset_name": dataset_name,
            "predictions_path": pred_path,
            "run_id": run_id,
        }
        with open(out_path, "w") as f:
            json.dump(results, f, indent=2)

        if harness_exit != 0:
            self.logger.warning("Harness exited with code %s", harness_exit)


def parse_args():
    """Parse command-line arguments for the standalone eval entry point.

    Returns:
        argparse.Namespace with a single ``config`` attribute — the path to
        the dumped config file.
    """
    arg_parser = argparse.ArgumentParser(description="SWEBench Eval")
    arg_parser.add_argument("config", help="Config file path")
    return arg_parser.parse_args()


if __name__ == "__main__":
    # Standalone entry point: launched as a subprocess via the command built
    # by SWEBenchEvalTask.get_command().
    logger = AISLogger()
    args = parse_args()
    cfg = Config.fromfile(args.config)
    # The state manager reports progress back to the parent through files
    # under work_dir/status_tmp.
    task_state_manager = TaskStateManager(
        tmp_path=os.path.join(cfg["work_dir"], "status_tmp"),
        task_name=task_abbr_from_cfg(cfg),
        is_debug=cfg["cli_args"]["debug"],
    )
    # Run the state-manager loop on a background thread for the task's
    # lifetime; joined after the final status update below.
    manager_t = threading.Thread(target=task_state_manager.launch, args=())
    manager_t.start()
    task_state_manager.update_task_state(
        {
            "status": "start",
            "task_log_path": os.path.join(
                "logs/eval/", f"{task_abbr_from_cfg(cfg)}.out"
            ),
        }
    )
    start_time = time.perf_counter()
    try:
        task = SWEBenchEvalTask(cfg)
        task.run(task_state_manager)
    except Exception as e:
        # Mark the task as errored for the parent before propagating.
        task_state_manager.update_task_state({"status": "error"})
        raise
    end_time = time.perf_counter()
    logger.info("SWEBench eval time: %.2fs", end_time - start_time)
    task_state_manager.update_task_state({"status": "finish"})
    manager_t.join()
Loading
Loading