diff --git a/evalbench/evaluator/dataagentevaluator.py b/evalbench/evaluator/dataagentevaluator.py index 86ccda4..cc6418d 100644 --- a/evalbench/evaluator/dataagentevaluator.py +++ b/evalbench/evaluator/dataagentevaluator.py @@ -1,3 +1,4 @@ +import traceback from typing import Any, List import datetime from work import promptgenwork @@ -96,8 +97,6 @@ def evaluate( try: result = future.result() except Exception as exc: - import traceback - print(traceback.format_exc()) print(f"A task generated an exception: {exc}") diff --git a/evalbench/evaluator/evaluator.py b/evalbench/evaluator/evaluator.py index 2d4d26b..3d9e31b 100644 --- a/evalbench/evaluator/evaluator.py +++ b/evalbench/evaluator/evaluator.py @@ -1,5 +1,4 @@ import logging - import time from typing import Any, List import datetime diff --git a/evalbench/test/mongodb_test.py b/evalbench/test/mongodb_test.py index 88a101c..8ad4867 100644 --- a/evalbench/test/mongodb_test.py +++ b/evalbench/test/mongodb_test.py @@ -5,12 +5,12 @@ import json import sys import os +from databases import mongodb # --------------------------------------------------------------------------- # Shared fixture # --------------------------------------------------------------------------- - @pytest.fixture(scope="module") def client(): """MongoDB client backed by mongomock, seeded with e-commerce documents.""" @@ -21,13 +21,11 @@ def client(): "max_executions_per_minute": 100, "connection_string": "mongodb://mock-host:27017", } - - from databases import mongodb - + # Directly use mongomock.MongoClient instead of patching + # This avoids issues with where MongoClient is imported mock_client = mongomock.MongoClient("mongodb://mock-host:27017") original_client = mongodb.MongoClient mongodb.MongoClient = lambda *args, **kwargs: mock_client - try: db = get_database(db_config, "unit_test_db") diff --git a/evalbench/test/robustness_test.py b/evalbench/test/robustness_test.py index 727e563..1bc6dd2 100644 --- a/evalbench/test/robustness_test.py +++ b/evalbench/test/robustness_test.py @@ -8,6 +8,35 @@ import unittest +class TestExecutionBugs(unittest.TestCase): + + def test_sqlexecwork_handles_empty_query_safely(self): + db = MagicMock() + db_queue = Queue() + eval_result = { + "sql_generator_error": None, + "generated_sql": " ", + "query_type": "dql", + "eval_query": [], + "golden_sql": "", + "preprocess_sql": [] + } + config = { + "prompt_generator": "NOOPGenerator", + "dialect": "sqlite" + } + + work = SQLExecWork(db, config, eval_result, db_queue) + + # Should not raise "list index out of range" + result = work.run() + + self.assertIsNone(result.get("generated_result")) + self.assertEqual( + result.get("generated_error"), + "list index out of range (empty query)") + + class TestStability(unittest.TestCase): def test_rate_limit_guaranteed_release(self): diff --git a/evalbench/util/config.py b/evalbench/util/config.py index 5b86c37..832e85c 100644 --- a/evalbench/util/config.py +++ b/evalbench/util/config.py @@ -1,3 +1,4 @@ +import json import datetime import logging import os @@ -31,7 +32,10 @@ def load_db_data_from_csvs(data_directory: str): current_directory = os.getcwd() if not os.path.isdir(os.path.join(current_directory, data_directory)): return tables - for filename in os.listdir(os.path.join(current_directory, data_directory)): + for filename in os.listdir( + os.path.join( + current_directory, + data_directory)): if filename.endswith(".csv"): table_name = filename[:-4] with open( @@ -68,10 +72,10 @@ def load_setup_scripts(setup_scripts_directory_path: str): current_directory, setup_scripts_directory_path, "post_setup.json" ) if os.path.exists(post_setup_json_path): - import json with open(post_setup_json_path, "r") as f: - # Load as list of dicts, then convert back to strings for batch_execute + # Load as list of dicts, then convert back to strings for + # batch_execute try: data = json.load(f) if isinstance(data, list): @@ -83,8 +87,9 @@ def load_setup_scripts(setup_scripts_directory_path: str): else: post_setup = _load_setup_sql( os.path.join( - current_directory, setup_scripts_directory_path, "post_setup.sql" - ), + current_directory, + setup_scripts_directory_path, + "post_setup.sql"), ) return (pre_setup, setup, post_setup) @@ -125,42 +130,11 @@ def config_to_df( } ) df = pd.DataFrame.from_dict(configs) - df[["job_id", "config", "value"]] = df[["job_id", "config", "value"]].astype( - "string" - ) + df[["job_id", "config", "value"]] = df[[ + "job_id", "config", "value"]].astype("string") return df -def df_to_config(df: pd.DataFrame) -> dict: - import ast - - original_dict = {} - - for _, row in df.iterrows(): - key_path = row["config"] - value_str = row["value"] - - try: - if pd.isna(value_str): - value = None - else: - value = ast.literal_eval(value_str) - except (ValueError, SyntaxError, TypeError): - value = value_str - - keys = key_path.split(".") - - current_level = original_dict - for key in keys[:-1]: - if key not in current_level: - current_level[key] = {} - current_level = current_level[key] - - current_level[keys[-1]] = value - - return original_dict - - def update_google3_relative_paths( experiment_config: dict, session_id: str, resource_map: dict ): @@ -171,7 +145,8 @@ def update_google3_relative_paths( elif isinstance(value, list): values = [] for sub_value in value: - if isinstance(sub_value, str) and sub_value.startswith("google3/"): + if isinstance(sub_value, + str) and sub_value.startswith("google3/"): values.append(get_google3_relative_path( sub_value, session_id)) elif isinstance(sub_value, str) and sub_value in resource_map: @@ -208,7 +183,12 @@ def get_google3_relative_path(value, session_id): def set_session_configs(session, experiment_config: dict): session["config"] = experiment_config if "dataset_config" in experiment_config and experiment_config["dataset_config"]: - session["dataset_config"] = experiment_config["dataset_config"] + # Handle both flat string paths and nested dicts (e.g. BIRD configs) + dc = experiment_config["dataset_config"] + if isinstance(dc, dict) and "prompts_file" in dc: + session["dataset_config"] = dc["prompts_file"] + else: + session["dataset_config"] = dc if ( "database_configs" in experiment_config and experiment_config["database_configs"] diff --git a/evalbench/work/sqlexecwork.py b/evalbench/work/sqlexecwork.py index ee56b07..26104eb 100644 --- a/evalbench/work/sqlexecwork.py +++ b/evalbench/work/sqlexecwork.py @@ -46,16 +46,27 @@ def _run_inner(self, work_config: Any = None) -> dict: golden_eval_result = None golden_error = None + query_type = self.eval_result["query_type"] + eval_query = self._get_eval_query() + preprocess_sql = self._get_preprocess_sql_query() + golden_sql = self._get_golden_sql() + + if golden_sql: + golden_result, golden_eval_result, golden_error = ( + self._evaluate_execution_results( + golden_sql, + preprocess_sql, + eval_query, + query_type, + is_golden=True, + ) + ) + if ( self.eval_result["sql_generator_error"] is None - and self.eval_result["generated_sql"] + and self.eval_result.get("generated_sql") ): - query_type = self.eval_result["query_type"] - eval_query = self._get_eval_query() sanitized_generated_sql = self._sanitize_sql() - preprocess_sql = self._get_preprocess_sql_query() - golden_sql = self._get_golden_sql() - if sanitized_generated_sql: generated_result, generated_eval_result, generated_error = ( self._evaluate_execution_results( @@ -66,15 +77,6 @@ def _run_inner(self, work_config: Any = None) -> dict: is_golden=False, ) ) - golden_result, golden_eval_result, golden_error = ( - self._evaluate_execution_results( - golden_sql, - preprocess_sql, - eval_query, - query_type, - is_golden=True, - ) - ) self.eval_result["generated_result"] = generated_result self.eval_result["eval_results"] = generated_eval_result @@ -96,10 +98,17 @@ def _evaluate_execution_results( self.db.execute(preprocess_sql) except Exception as preprocess_error: traceback.print_exc() + + if not query or not query.strip(): + return None, None, "list index out of range (empty query)" + if query_type == "dql": try: + stmts = sqlparse.split(query) + if not stmts: + return None, None, "list index out of range (empty query)" result, _, error = self.db.execute( - sqlparse.split(query)[0], use_cache=True, rollback=True + stmts[0], use_cache=True, rollback=True ) except Exception as e: error = str(e) diff --git a/evalbench/work/sqlgenquerydatawork.py b/evalbench/work/sqlgenquerydatawork.py index 7b05519..e4c6801 100644 --- a/evalbench/work/sqlgenquerydatawork.py +++ b/evalbench/work/sqlgenquerydatawork.py @@ -1,3 +1,4 @@ +import traceback """Work is the base class for all work items.""" from typing import Any @@ -30,8 +31,6 @@ def run(self, work_config: str = None) -> dict: self.eval_result["generated_sql"] = None self.eval_result["sql_generator_error"] = "No result generated" except Exception as e: - import traceback - traceback.print_exc() sql_generator_error = str(e)