diff --git a/mas_arena/benchmark_runner.py b/mas_arena/benchmark_runner.py index 66d438d..10daeda 100644 --- a/mas_arena/benchmark_runner.py +++ b/mas_arena/benchmark_runner.py @@ -132,7 +132,7 @@ def _prepare_benchmark(self, benchmark_name, data_path, limit, agent_system, age primary_id = benchmark_config.get("normalization_keys", {}).get("id", None) if primary_id is not None: for problem in problems: - if str(problem["task_id"]) == data_id: + if str(problem[primary_id]) == data_id: problems = [problem] break @@ -394,10 +394,10 @@ def run(self, benchmark_name="math", data_path=None, limit=None, agent_system="s concurrency=1 # Run sequentially )) - async def arun(self, benchmark_name="math", data_path=None, limit=None, agent_system="single_agent", agent_config=None, verbose=True, concurrency=10): + async def arun(self, benchmark_name="math", data_path=None, limit=None, agent_system="single_agent", agent_config=None, verbose=True, data_id=None, concurrency=10): # Prepare benchmark; we only need problems and config here _, problems, benchmark_config, output_file = self._prepare_benchmark( - benchmark_name, data_path, limit, agent_system, agent_config, verbose + benchmark_name, data_path, limit, agent_system, agent_config, verbose, data_id ) if verbose: