-
Notifications
You must be signed in to change notification settings - Fork 19
【Feature】SWEBench Support #191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,69 @@ | ||||||
| import re | ||||||
| import random | ||||||
| from datasets import load_dataset, Dataset, DatasetDict | ||||||
|
|
||||||
| from ais_bench.benchmark.registry import LOAD_DATASET | ||||||
| from ais_bench.benchmark.utils.logging.exceptions import ParameterValueError | ||||||
| from ais_bench.benchmark.utils.logging.error_codes import DSET_CODES | ||||||
| from ais_bench.benchmark.datasets.base import BaseDataset | ||||||
|
|
||||||
# Maps short, user-facing dataset aliases to the corresponding
# Hugging Face Hub repository ids of the SWE-Bench dataset family.
DATASET_MAPPING = {
    "full": "princeton-nlp/SWE-Bench",
    "verified": "princeton-nlp/SWE-Bench_Verified",
    "lite": "princeton-nlp/SWE-Bench_Lite",
    "multimodal": "princeton-nlp/SWE-Bench_Multimodal",
    "multilingual": "swe-bench/SWE-Bench_Multilingual",
}
|
|
||||||
|
|
||||||
@LOAD_DATASET.register_module()
class SWEBenchDataset(BaseDataset):
    """Loader for the SWE-Bench dataset family (full/verified/lite/...).

    Loads instances either from a local parquet file or, as a fallback,
    from the Hugging Face Hub, then optionally shuffles and filters them
    by a regex over ``instance_id``.
    """

    def filter_instances(
        self, instances: list[dict], *, filter_spec: str, shuffle: bool = False
    ) -> list[dict]:
        """Filter and slice a list of SWEBench instances.

        Args:
            instances: SWE-Bench instance records; each must contain an
                ``instance_id`` key.
            filter_spec: Regex matched via ``re.match`` against each
                ``instance_id``. An empty string matches everything.
            shuffle: If True, sort by ``instance_id`` first and then
                shuffle with a fixed seed so the order is reproducible
                and independent of the incoming order.

        Returns:
            The (possibly shuffled) instances whose id matches
            ``filter_spec``.
        """
        if shuffle:
            # sorted() already returns a new list, so no explicit copy is
            # needed before shuffling in place.
            instances = sorted(instances, key=lambda x: x["instance_id"])
            random.seed(42)  # fixed seed for reproducible sampling
            random.shuffle(instances)
        before_filter = len(instances)
        instances = [
            instance
            for instance in instances
            if re.match(filter_spec, instance["instance_id"])
        ]
        if (after_filter := len(instances)) != before_filter:
            self.logger.info(
                f"Instance filter: {before_filter} -> {after_filter} instances"
            )
        return instances

    def load(
        self,
        path: str,
        name: str,
        split: str = "test",
        filter_spec: str = "",
        shuffle: bool = False,
        **kwargs,
    ):
        """Load a SWE-Bench split from a local parquet file or the Hub.

        Args:
            path: Local parquet file holding the split.
            name: One of the ``DATASET_MAPPING`` aliases.
            split: Dataset split name (default ``"test"``).
            filter_spec: Regex forwarded to :meth:`filter_instances`.
            shuffle: Forwarded to :meth:`filter_instances`.

        Returns:
            A :class:`datasets.Dataset` of the filtered instances.

        Raises:
            ParameterValueError: If ``name`` is unknown, or if both the
                local and Hub loads fail.
        """
        if name not in DATASET_MAPPING:
            raise ParameterValueError(
                DSET_CODES.INVALID_PARAM_VALUE,
                f"Invalid swebench dataset name, expected one of {list(DATASET_MAPPING.keys())} but got {name}",
            )
        try:
            # load_dataset("parquet", data_files=...) returns a DatasetDict
            # keyed by split name; select the split so that iterating below
            # yields instance dicts rather than split-name strings.
            dataset = load_dataset("parquet", data_files={split: path})[split]
        except Exception as e:
            self.logger.warning(
                f"Failed to load swebench dataset {name} from {path} with error: {e}, trying to load from Hugging Face"
            )
            try:
                dataset = load_dataset(DATASET_MAPPING[name], split=split)
            except Exception as e:
                raise ParameterValueError(
                    DSET_CODES.DATA_PREPROCESSING_ERROR,
                    f"Failed to load swebench dataset {name} from Hugging Face with error: {e}.",
                )
        dataset = self.filter_instances(list(dataset), filter_spec=filter_spec, shuffle=shuffle)
        return Dataset.from_list(dataset)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,6 @@ | ||
| from ais_bench.benchmark.tasks.openicl_eval import * # noqa: F401, F403 | ||
| from ais_bench.benchmark.tasks.openicl_infer import * # noqa: F401, F403 | ||
| from ais_bench.benchmark.tasks.openicl_api_infer import OpenICLApiInferTask | ||
| from ais_bench.benchmark.tasks.swebench_infer import SWEBenchInferTask | ||
| from ais_bench.benchmark.tasks.swebench_eval import SWEBenchEvalTask | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,159 @@ | ||
| import argparse | ||
| import json | ||
| import os | ||
| import os.path as osp | ||
| import sys | ||
| import threading | ||
| import time | ||
|
|
||
| from mmengine.config import Config, ConfigDict | ||
| from mmengine.utils import mkdir_or_exist | ||
|
|
||
| from ais_bench.benchmark.registry import TASKS | ||
| from ais_bench.benchmark.tasks.base import BaseTask, TaskStateManager | ||
| from ais_bench.benchmark.utils.core.abbr import ( | ||
| get_infer_output_path, | ||
| task_abbr_from_cfg, | ||
| ) | ||
| from ais_bench.benchmark.utils.logging import AISLogger | ||
|
|
||
|
|
||
@TASKS.register_module()
class SWEBenchEvalTask(BaseTask):
    """SWEBench Evaluation Task.

    Evaluates SWE-bench predictions using the official harness and writes
    results to work_dir/results.
    """

    name_prefix = "SWEBenchEval"
    log_subdir = "logs/eval"
    output_subdir = "results"

    # NOTE: the original defined a no-op __init__ that only called
    # super().__init__(cfg); it is removed as redundant — BaseTask.__init__
    # runs either way.

    def get_command(self, cfg_path: str, template: str) -> str:
        """Build the shell command that re-invokes this module as a script.

        Args:
            cfg_path: Path to the task config file passed to the script.
            template: Command template containing a ``{task_cmd}`` slot.

        Returns:
            The templated command string.
        """
        # Make the current working directory importable for the spawned task.
        sys.path.append(os.getcwd())
        script_path = __file__
        python = sys.executable
        command = f"{python} {script_path} {cfg_path}"
        return template.format(task_cmd=command)

    def run(self, task_state_manager: TaskStateManager):
        """Run the SWE-bench harness on predictions and persist a summary.

        Args:
            task_state_manager: Status reporter updated as the task moves
                through its phases.

        Raises:
            FileNotFoundError: If the predictions file from the infer
                stage is missing.
            ImportError: If the optional SWE-bench harness package is not
                installed.
        """
        self.task_state_manager = task_state_manager
        self.logger.info("SWEBenchEvalTask %s", task_abbr_from_cfg(self.cfg))

        dataset_cfg = self.dataset_cfgs[0]
        # NOTE(review): this yields an alias like "lite" — confirm the
        # harness resolves it to the full HF dataset id.
        dataset_name = dataset_cfg.get("name", "lite")

        # Predictions produced by the infer stage.
        pred_path = get_infer_output_path(
            self.model_cfg,
            dataset_cfg,
            osp.join(self.work_dir, "predictions"),
            file_extension="jsonl",
        )
        if not osp.isfile(pred_path):
            raise FileNotFoundError(
                f"Predictions file not found: {pred_path}. Run infer first."
            )

        out_path = get_infer_output_path(
            self.model_cfg,
            dataset_cfg,
            osp.join(self.work_dir, self.output_subdir),
            file_extension="json",
        )
        mkdir_or_exist(osp.dirname(out_path))

        task_state_manager.update_task_state(
            {"status": "eval", "progress_description": "SWE-bench harness"}
        )

        # Imported lazily so the framework works without the optional
        # SWE-bench dependency installed.
        try:
            import swebench.harness.run_evaluation as run_eval
        except ImportError as e:
            raise ImportError(
                "SWEBenchEvalTask requires the SWE-bench harness. "
                "Install from: https://github.com/princeton-nlp/SWE-bench"
            ) from e

        # Run id must be filesystem/docker safe; task abbreviations may
        # contain path separators.
        run_id = task_abbr_from_cfg(self.cfg).replace("/", "_")
        eval_runner = self.cfg.get("eval", {}).get("runner", {})
        max_workers = eval_runner.get("max_num_workers", 4)
        # Per-instance timeout in seconds; configurable via the eval runner
        # config, previously hard-coded to 1800 (same default kept).
        timeout = eval_runner.get("timeout", 1800)
        report_dir = osp.dirname(out_path)

        try:
            run_eval.main(
                dataset_name=dataset_name,
                split="test",
                instance_ids=[],
                predictions_path=pred_path,
                max_workers=max_workers,
                force_rebuild=False,
                cache_level="env",
                clean=False,
                open_file_limit=4096,
                run_id=run_id,
                timeout=timeout,
                namespace=None,
                rewrite_reports=False,
                modal=False,
                report_dir=report_dir,
            )
            harness_exit = 0
        except SystemExit as e:
            # The harness may terminate via sys.exit(); treat a missing
            # exit code as failure.
            harness_exit = e.code if e.code is not None else 1
        except Exception as e:
            self.logger.exception("Harness failed: %s", e)
            harness_exit = 1

        # Persist a small machine-readable summary next to the report.
        results = {
            "harness_exit_code": harness_exit,
            "dataset_name": dataset_name,
            "predictions_path": pred_path,
            "run_id": run_id,
        }
        with open(out_path, "w") as f:
            json.dump(results, f, indent=2)

        if harness_exit != 0:
            self.logger.warning("Harness exited with code %s", harness_exit)
|
|
||
|
|
||
def parse_args():
    """Return the parsed CLI arguments (a single positional config path)."""
    arg_parser = argparse.ArgumentParser(description="SWEBench Eval")
    arg_parser.add_argument("config", help="Config file path")
    return arg_parser.parse_args()
|
|
||
|
|
||
if __name__ == "__main__":
    # Script entry point: run a single SWEBench eval task described by the
    # config file given on the command line, reporting status through a
    # background TaskStateManager thread.
    logger = AISLogger()
    args = parse_args()
    cfg = Config.fromfile(args.config)
    task_state_manager = TaskStateManager(
        tmp_path=os.path.join(cfg["work_dir"], "status_tmp"),
        task_name=task_abbr_from_cfg(cfg),
        is_debug=cfg["cli_args"]["debug"],
    )
    # Status reporting runs concurrently with the eval itself.
    manager_t = threading.Thread(target=task_state_manager.launch, args=())
    manager_t.start()
    task_state_manager.update_task_state(
        {
            "status": "start",
            "task_log_path": os.path.join(
                "logs/eval/", f"{task_abbr_from_cfg(cfg)}.out"
            ),
        }
    )
    start_time = time.perf_counter()
    try:
        task = SWEBenchEvalTask(cfg)
        task.run(task_state_manager)
    except Exception:
        # Mark the task failed for the monitor, then propagate.  (The
        # original bound the exception to an unused name.)
        task_state_manager.update_task_state({"status": "error"})
        raise
    end_time = time.perf_counter()
    logger.info("SWEBench eval time: %.2fs", end_time - start_time)
    task_state_manager.update_task_state({"status": "finish"})
    manager_t.join()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The recursive call for list items can be simplified. Instead of checking the type of the item before the recursive call, you can just call
recur_convert_config_typeon every item. The function already handles non-container types by returning them as is.