Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions evalbench/evaluator/dataagentevaluator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import traceback
from typing import Any, List
import datetime
from work import promptgenwork
Expand Down Expand Up @@ -96,8 +97,6 @@ def evaluate(
try:
result = future.result()
except Exception as exc:
import traceback

print(traceback.format_exc())
print(f"A task generated an exception: {exc}")

Expand Down
1 change: 0 additions & 1 deletion evalbench/evaluator/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging

import time
from typing import Any, List
import datetime
Expand Down
8 changes: 3 additions & 5 deletions evalbench/test/mongodb_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
import json
import sys
import os
from databases import mongodb


# ---------------------------------------------------------------------------
# Shared fixture
# ---------------------------------------------------------------------------

@pytest.fixture(scope="module")
def client():
"""MongoDB client backed by mongomock, seeded with e-commerce documents."""
Expand All @@ -21,13 +21,11 @@ def client():
"max_executions_per_minute": 100,
"connection_string": "mongodb://mock-host:27017",
}

from databases import mongodb

# Directly use mongomock.MongoClient instead of patching
# This avoids issues with where MongoClient is imported
mock_client = mongomock.MongoClient("mongodb://mock-host:27017")
original_client = mongodb.MongoClient
mongodb.MongoClient = lambda *args, **kwargs: mock_client

try:
db = get_database(db_config, "unit_test_db")

Expand Down
29 changes: 29 additions & 0 deletions evalbench/test/robustness_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,35 @@
import unittest


class TestExecutionBugs(unittest.TestCase):
    """Regression tests for crashes in the SQL execution work item."""

    def test_sqlexecwork_handles_empty_query_safely(self):
        # Settings handed to the work item under test.
        work_config = {
            "prompt_generator": "NOOPGenerator",
            "dialect": "sqlite"
        }
        # The generated SQL is whitespace only and there is no golden SQL.
        pending_eval = {
            "sql_generator_error": None,
            "generated_sql": " ",
            "query_type": "dql",
            "eval_query": [],
            "golden_sql": "",
            "preprocess_sql": []
        }
        mock_db = MagicMock()
        queue = Queue()

        # Should not raise "list index out of range"
        outcome = SQLExecWork(mock_db, work_config, pending_eval, queue).run()

        # No result is produced; the empty query is reported as an error.
        self.assertIsNone(outcome.get("generated_result"))
        expected_error = "list index out of range (empty query)"
        self.assertEqual(outcome.get("generated_error"), expected_error)


class TestStability(unittest.TestCase):

def test_rate_limit_guaranteed_release(self):
Expand Down
60 changes: 20 additions & 40 deletions evalbench/util/config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import datetime
import logging
import os
Expand Down Expand Up @@ -31,7 +32,10 @@ def load_db_data_from_csvs(data_directory: str):
current_directory = os.getcwd()
if not os.path.isdir(os.path.join(current_directory, data_directory)):
return tables
for filename in os.listdir(os.path.join(current_directory, data_directory)):
for filename in os.listdir(
os.path.join(
current_directory,
data_directory)):
if filename.endswith(".csv"):
table_name = filename[:-4]
with open(
Expand Down Expand Up @@ -68,10 +72,10 @@ def load_setup_scripts(setup_scripts_directory_path: str):
current_directory, setup_scripts_directory_path, "post_setup.json"
)
if os.path.exists(post_setup_json_path):
import json

with open(post_setup_json_path, "r") as f:
# Load as list of dicts, then convert back to strings for batch_execute
# Load as list of dicts, then convert back to strings for
# batch_execute
try:
data = json.load(f)
if isinstance(data, list):
Expand All @@ -83,8 +87,9 @@ def load_setup_scripts(setup_scripts_directory_path: str):
else:
post_setup = _load_setup_sql(
os.path.join(
current_directory, setup_scripts_directory_path, "post_setup.sql"
),
current_directory,
setup_scripts_directory_path,
"post_setup.sql"),
)
return (pre_setup, setup, post_setup)

Expand Down Expand Up @@ -125,42 +130,11 @@ def config_to_df(
}
)
df = pd.DataFrame.from_dict(configs)
df[["job_id", "config", "value"]] = df[["job_id", "config", "value"]].astype(
"string"
)
df[["job_id", "config", "value"]] = df[[
"job_id", "config", "value"]].astype("string")
return df


def df_to_config(df: pd.DataFrame) -> dict:
    """Rebuild a nested config dict from a flattened config DataFrame.

    Every row contributes one entry: the ``config`` column holds a
    dot-separated key path (``"a.b.c"``) and the ``value`` column holds the
    value, normally as text.

    Args:
        df: DataFrame with ``config`` and ``value`` columns. Any other
            columns are ignored.

    Returns:
        A nested dict in which each key path has been expanded into nested
        dicts. Missing values (as judged by ``pd.isna``) become ``None``;
        values that parse as Python literals are converted with
        ``ast.literal_eval``; anything else is stored unchanged.
    """
    import ast

    original_dict: dict = {}

    # A frame without rows yields an empty config, even when the expected
    # columns are absent.
    if df.index.empty:
        return original_dict

    # Walk only the two columns that are read instead of materialising a
    # Series for every row via iterrows().
    for key_path, value_str in zip(df["config"], df["value"]):
        try:
            if pd.isna(value_str):
                value = None
            else:
                value = ast.literal_eval(value_str)
        except (ValueError, SyntaxError, TypeError):
            # Not a Python literal (or not a scalar): keep the raw value.
            value = value_str

        *parent_keys, leaf_key = key_path.split(".")

        # Descend through the path, creating intermediate dicts on demand.
        current_level = original_dict
        for key in parent_keys:
            current_level = current_level.setdefault(key, {})

        current_level[leaf_key] = value

    return original_dict


def update_google3_relative_paths(
experiment_config: dict, session_id: str, resource_map: dict
):
Expand All @@ -171,7 +145,8 @@ def update_google3_relative_paths(
elif isinstance(value, list):
values = []
for sub_value in value:
if isinstance(sub_value, str) and sub_value.startswith("google3/"):
if isinstance(sub_value,
str) and sub_value.startswith("google3/"):
values.append(get_google3_relative_path(
sub_value, session_id))
elif isinstance(sub_value, str) and sub_value in resource_map:
Expand Down Expand Up @@ -208,7 +183,12 @@ def get_google3_relative_path(value, session_id):
def set_session_configs(session, experiment_config: dict):
session["config"] = experiment_config
if "dataset_config" in experiment_config and experiment_config["dataset_config"]:
session["dataset_config"] = experiment_config["dataset_config"]
# Handle both flat string paths and nested dicts (e.g. BIRD configs)
dc = experiment_config["dataset_config"]
if isinstance(dc, dict) and "prompts_file" in dc:
session["dataset_config"] = dc["prompts_file"]
else:
session["dataset_config"] = dc
if (
"database_configs" in experiment_config
and experiment_config["database_configs"]
Expand Down
41 changes: 25 additions & 16 deletions evalbench/work/sqlexecwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,27 @@ def _run_inner(self, work_config: Any = None) -> dict:
golden_eval_result = None
golden_error = None

query_type = self.eval_result["query_type"]
eval_query = self._get_eval_query()
preprocess_sql = self._get_preprocess_sql_query()
golden_sql = self._get_golden_sql()

if golden_sql:
golden_result, golden_eval_result, golden_error = (
self._evaluate_execution_results(
golden_sql,
preprocess_sql,
eval_query,
query_type,
is_golden=True,
)
)

if (
self.eval_result["sql_generator_error"] is None
and self.eval_result["generated_sql"]
and self.eval_result.get("generated_sql")
):
query_type = self.eval_result["query_type"]
eval_query = self._get_eval_query()
sanitized_generated_sql = self._sanitize_sql()
preprocess_sql = self._get_preprocess_sql_query()
golden_sql = self._get_golden_sql()

if sanitized_generated_sql:
generated_result, generated_eval_result, generated_error = (
self._evaluate_execution_results(
Expand All @@ -66,15 +77,6 @@ def _run_inner(self, work_config: Any = None) -> dict:
is_golden=False,
)
)
golden_result, golden_eval_result, golden_error = (
self._evaluate_execution_results(
golden_sql,
preprocess_sql,
eval_query,
query_type,
is_golden=True,
)
)

self.eval_result["generated_result"] = generated_result
self.eval_result["eval_results"] = generated_eval_result
Expand All @@ -96,10 +98,17 @@ def _evaluate_execution_results(
self.db.execute(preprocess_sql)
except Exception as preprocess_error:
traceback.print_exc()

if not query or not query.strip():
return None, None, "list index out of range (empty query)"

if query_type == "dql":
try:
stmts = sqlparse.split(query)
if not stmts:
return None, None, "list index out of range (empty query)"
result, _, error = self.db.execute(
sqlparse.split(query)[0], use_cache=True, rollback=True
stmts[0], use_cache=True, rollback=True
)
except Exception as e:
error = str(e)
Expand Down
3 changes: 1 addition & 2 deletions evalbench/work/sqlgenquerydatawork.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import traceback
"""Work is the base class for all work items."""

from typing import Any
Expand Down Expand Up @@ -30,8 +31,6 @@ def run(self, work_config: str = None) -> dict:
self.eval_result["generated_sql"] = None
self.eval_result["sql_generator_error"] = "No result generated"
except Exception as e:
import traceback

traceback.print_exc()
sql_generator_error = str(e)

Expand Down
Loading