Skip to content

Commit 67e7bf3

Browse files
committed
fix: handle empty queries safely, ensure golden execution, and parse config robustly
1 parent dcb8bf6 commit 67e7bf3

10 files changed

Lines changed: 100 additions & 63 deletions

evalbench/evaluator/dataagentevaluator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import traceback
12
from typing import Any, List
23
import datetime
34
from work import promptgenwork
@@ -96,7 +97,7 @@ def evaluate(
9697
try:
9798
result = future.result()
9899
except Exception as exc:
99-
import traceback
100+
100101

101102
print(traceback.format_exc())
102103
print(f"A task generated an exception: {exc}")

evalbench/evaluator/evaluator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from typing import Any, List
23
import datetime
34
from util import truncateExecutionOutputs

evalbench/evaluator/oneshotorchestrator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1+
import logging
12
import concurrent.futures
23
import datetime
34
import json
4-
import logging
5+
56
import tempfile
67
import threading
78
import uuid

evalbench/evaluator/progress_reporter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import logging
2+
import os
3+
24
from multiprocessing.managers import SyncManager
35
import sys
46
import threading

evalbench/test/mongodb_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import json
66
import sys
77
import os
8+
from evalbench.databases import mongodb
89

910
sys.path.append(os.path.abspath(
1011
os.path.join(os.path.dirname(__file__), "../..")))
@@ -25,7 +26,6 @@ def client():
2526

2627
# Directly use mongomock.MongoClient instead of patching
2728
# This avoids issues with where MongoClient is imported
28-
from databases import mongodb
2929

3030
# Create a mock client
3131
mock_client = mongomock.MongoClient("mongodb://mock-host:27017")
@@ -71,7 +71,8 @@ def test_aggregate(self, client):
7171
"""Tests aggregation query."""
7272
# Data already inserted in previous test (session scope fixture, but we might want to clean up)
7373
# For safety, let's insert again or assume persistence.
74-
# mongomock is in-memory, so it persists for the session if not cleared.
74+
# mongomock is in-memory, so it persists for the session if not
75+
# cleared.
7576

7677
query = json.dumps(
7778
{

evalbench/test/robustness_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import time
2+
from queue import Queue
3+
from unittest.mock import MagicMock
4+
from work.sqlexecwork import SQLExecWork
5+
import unittest
6+
7+
8+
class TestExecutionBugs(unittest.TestCase):
9+
10+
def test_sqlexecwork_handles_empty_query_safely(self):
11+
db = MagicMock()
12+
db_queue = Queue()
13+
eval_result = {
14+
"sql_generator_error": None,
15+
"generated_sql": " ",
16+
"query_type": "dql",
17+
"eval_query": [],
18+
"golden_sql": "",
19+
"preprocess_sql": []
20+
}
21+
config = {
22+
"prompt_generator": "NOOPGenerator",
23+
"dialect": "sqlite"
24+
}
25+
26+
work = SQLExecWork(db, config, eval_result, db_queue)
27+
28+
# Should not raise "list index out of range"
29+
result = work.run()
30+
31+
self.assertIsNone(result.get("generated_result"))
32+
self.assertEqual(
33+
result.get("generated_error"),
34+
"list index out of range (empty query)")
35+
36+
37+
if __name__ == '__main__':
38+
unittest.main()

evalbench/util/config.py

Lines changed: 21 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import datetime
23
import logging
34
import os
@@ -31,7 +32,10 @@ def load_db_data_from_csvs(data_directory: str):
3132
current_directory = os.getcwd()
3233
if not os.path.isdir(os.path.join(current_directory, data_directory)):
3334
return tables
34-
for filename in os.listdir(os.path.join(current_directory, data_directory)):
35+
for filename in os.listdir(
36+
os.path.join(
37+
current_directory,
38+
data_directory)):
3539
if filename.endswith(".csv"):
3640
table_name = filename[:-4]
3741
with open(
@@ -68,10 +72,10 @@ def load_setup_scripts(setup_scripts_directory_path: str):
6872
current_directory, setup_scripts_directory_path, "post_setup.json"
6973
)
7074
if os.path.exists(post_setup_json_path):
71-
import json
7275

7376
with open(post_setup_json_path, "r") as f:
74-
# Load as list of dicts, then convert back to strings for batch_execute
77+
# Load as list of dicts, then convert back to strings for
78+
# batch_execute
7579
try:
7680
data = json.load(f)
7781
if isinstance(data, list):
@@ -83,8 +87,9 @@ def load_setup_scripts(setup_scripts_directory_path: str):
8387
else:
8488
post_setup = _load_setup_sql(
8589
os.path.join(
86-
current_directory, setup_scripts_directory_path, "post_setup.sql"
87-
),
90+
current_directory,
91+
setup_scripts_directory_path,
92+
"post_setup.sql"),
8893
)
8994
return (pre_setup, setup, post_setup)
9095

@@ -125,40 +130,9 @@ def config_to_df(
125130
}
126131
)
127132
df = pd.DataFrame.from_dict(configs)
128-
df[["job_id", "config", "value"]] = df[["job_id", "config", "value"]].astype(
129-
"string"
130-
)
131-
return df
132-
133-
134-
def df_to_config(df: pd.DataFrame) -> dict:
135-
import ast
136-
137-
original_dict = {}
138-
139-
for _, row in df.iterrows():
140-
key_path = row["config"]
141-
value_str = row["value"]
142-
143-
try:
144-
if pd.isna(value_str):
145-
value = None
146-
else:
147-
value = ast.literal_eval(value_str)
148-
except (ValueError, SyntaxError, TypeError):
149-
value = value_str
150-
151-
keys = key_path.split(".")
152-
153-
current_level = original_dict
154-
for key in keys[:-1]:
155-
if key not in current_level:
156-
current_level[key] = {}
157-
current_level = current_level[key]
158-
159-
current_level[keys[-1]] = value
160-
161-
return original_dict
133+
df[["job_id", "config", "value"]] = df[[
134+
"job_id", "config", "value"]].astype("string")
135+
return config
162136

163137

164138
def update_google3_relative_paths(
@@ -171,7 +145,8 @@ def update_google3_relative_paths(
171145
elif isinstance(value, list):
172146
values = []
173147
for sub_value in value:
174-
if isinstance(sub_value, str) and sub_value.startswith("google3/"):
148+
if isinstance(sub_value,
149+
str) and sub_value.startswith("google3/"):
175150
values.append(get_google3_relative_path(
176151
sub_value, session_id))
177152
elif isinstance(sub_value, str) and sub_value in resource_map:
@@ -208,7 +183,12 @@ def get_google3_relative_path(value, session_id):
208183
def set_session_configs(session, experiment_config: dict):
209184
session["config"] = experiment_config
210185
if "dataset_config" in experiment_config and experiment_config["dataset_config"]:
211-
session["dataset_config"] = experiment_config["dataset_config"]
186+
# Handle both flat string paths and nested dicts (e.g. BIRD configs)
187+
dc = experiment_config["dataset_config"]
188+
if isinstance(dc, dict) and "prompts_file" in dc:
189+
session["dataset_config"] = dc["prompts_file"]
190+
else:
191+
session["dataset_config"] = dc
212192
if (
213193
"database_configs" in experiment_config
214194
and experiment_config["database_configs"]

evalbench/work/sqlexecwork.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,27 @@ def run(self, work_config: Any = None) -> dict:
4040
golden_eval_result = None
4141
golden_error = None
4242

43+
query_type = self.eval_result["query_type"]
44+
eval_query = self._get_eval_query()
45+
preprocess_sql = self._get_preprocess_sql_query()
46+
golden_sql = self._get_golden_sql()
47+
48+
if golden_sql:
49+
golden_result, golden_eval_result, golden_error = (
50+
self._evaluate_execution_results(
51+
golden_sql,
52+
preprocess_sql,
53+
eval_query,
54+
query_type,
55+
is_golden=True,
56+
)
57+
)
58+
4359
if (
4460
self.eval_result["sql_generator_error"] is None
45-
and self.eval_result["generated_sql"]
61+
and self.eval_result.get("generated_sql")
4662
):
47-
query_type = self.eval_result["query_type"]
48-
eval_query = self._get_eval_query()
4963
sanitized_generated_sql = self._sanitize_sql()
50-
preprocess_sql = self._get_preprocess_sql_query()
51-
golden_sql = self._get_golden_sql()
52-
5364
if sanitized_generated_sql:
5465
generated_result, generated_eval_result, generated_error = (
5566
self._evaluate_execution_results(
@@ -60,15 +71,6 @@ def run(self, work_config: Any = None) -> dict:
6071
is_golden=False,
6172
)
6273
)
63-
golden_result, golden_eval_result, golden_error = (
64-
self._evaluate_execution_results(
65-
golden_sql,
66-
preprocess_sql,
67-
eval_query,
68-
query_type,
69-
is_golden=True,
70-
)
71-
)
7274

7375
self.eval_result["generated_result"] = generated_result
7476
self.eval_result["eval_results"] = generated_eval_result
@@ -91,10 +93,17 @@ def _evaluate_execution_results(
9193
self.db.execute(preprocess_sql)
9294
except Exception as preprocess_error:
9395
traceback.print_exc()
96+
97+
if not query or not query.strip():
98+
return None, None, "list index out of range (empty query)"
99+
94100
if query_type == "dql":
95101
try:
102+
stmts = sqlparse.split(query)
103+
if not stmts:
104+
return None, None, "list index out of range (empty query)"
96105
result, _, error = self.db.execute(
97-
sqlparse.split(query)[0], use_cache=True, rollback=True
106+
stmts[0], use_cache=True, rollback=True
98107
)
99108
except Exception as e:
100109
error = str(e)
@@ -143,7 +152,8 @@ def _get_golden_sql(self):
143152
return golden_sql
144153

145154
def _get_eval_query(self):
146-
if self.eval_result["eval_query"] and len(self.eval_result["eval_query"]) > 0:
155+
if self.eval_result["eval_query"] and len(
156+
self.eval_result["eval_query"]) > 0:
147157
return self.eval_result["eval_query"][0]
148158
else:
149159
return None

evalbench/work/sqlgenquerydatawork.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import traceback
12
"""Work is the base class for all work items."""
23

34
from typing import Any
@@ -30,7 +31,7 @@ def run(self, work_config: str = None) -> dict:
3031
self.eval_result["generated_sql"] = None
3132
self.eval_result["sql_generator_error"] = "No result generated"
3233
except Exception as e:
33-
import traceback
34+
3435

3536
traceback.print_exc()
3637
sql_generator_error = str(e)

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,5 @@ mongomock
3333
rich
3434
google-adk
3535
mcp
36+
pytest
37+
sqlparse

0 commit comments

Comments
 (0)