Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion alphasql/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from alphasql.database.sql_execution import execute_sql_without_timeout

MAX_LENGTH_COLUMN_DESCRIPTION = 100

def lower_str_list(str_list: List[Any]) -> List[Any]:
"""
Convert a list of strings or nested lists to a list of lowercase strings or nested lists.
Expand Down Expand Up @@ -158,7 +160,12 @@ def load_value_examples(db_id: str, database_root_dir: str, table_name: str, col
"""
db_path = Path(database_root_dir) / db_id / f"{db_id}.sqlite"
examples = execute_sql_without_timeout(db_path, f"SELECT DISTINCT `{column_name}` FROM `{table_name}` WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' LIMIT {max_num_examples};").result
return [example[0] for example in examples]
res = [example[0] for example in examples]
# print(res, flush=True)
if str(res) > MAX_LENGTH_COLUMN_DESCRIPTION:
return [res[0]]
else:
return res

def load_database_schema_dict(db_id: str, database_root_dir: str) -> Dict[str, Dict[str, Dict[str, Any]]]:
"""
Expand Down
158 changes: 158 additions & 0 deletions alphasql/runner/evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copied from https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird
# The way input files are located has been modified.

import sys
import json
import argparse
import sqlite3
import multiprocessing as mp
from func_timeout import func_timeout, FunctionTimedOut

def load_json(dir):
    """Read the JSON file at ``dir`` and return the decoded object.

    Args:
        dir: Path of the JSON file to read.

    Returns:
        The Python object decoded from the file's JSON content.
    """
    with open(dir, 'r') as handle:
        return json.load(handle)

def result_callback(result):
    """Append one finished job's result to the module-level ``exec_result`` list.

    Passed as the ``callback`` of ``pool.apply_async`` in ``run_sqls_parallel``.
    ``exec_result`` is created in the ``__main__`` section of this file, so
    this function only works when the file is run as a script.

    Args:
        result: The value returned by ``execute_model`` for one query pair.
    """
    exec_result.append(result)


def execute_sql(predicted_sql, ground_truth, db_path):
    """Run both queries against one SQLite database and compare their rows.

    The comparison is done on sets of result rows, so row order and duplicate
    rows are ignored.

    Args:
        predicted_sql: SQL text produced by the model.
        ground_truth: Reference SQL text.
        db_path: Path of the SQLite database file to query.

    Returns:
        1 if both queries yield the same set of rows, otherwise 0.

    Raises:
        sqlite3.Error: If either query cannot be executed.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(predicted_sql)
        predicted_res = cursor.fetchall()
        cursor.execute(ground_truth)
        ground_truth_res = cursor.fetchall()
    finally:
        # Always release the connection, including when a query raises;
        # the original never closed it.
        conn.close()
    return 1 if set(predicted_res) == set(ground_truth_res) else 0



def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
    """Score one predicted/ground-truth query pair under a time limit.

    ``execute_sql`` is invoked through ``func_timeout`` with ``meta_time_out``
    as the limit. A timeout or any execution error counts as an incorrect
    prediction (score 0) instead of aborting the run.

    Args:
        predicted_sql: SQL text produced by the model.
        ground_truth: Reference SQL text.
        db_place: Path of the SQLite database file to query.
        idx: Position of this pair in the evaluation set; echoed back so the
            results can be re-ordered afterwards.
        meta_time_out: Time limit passed to ``func_timeout``.

    Returns:
        A dict ``{'sql_idx': idx, 'res': score}`` where ``score`` is 0 or 1.
    """
    try:
        res = func_timeout(meta_time_out, execute_sql,
                           args=(predicted_sql, ground_truth, db_place))
    except KeyboardInterrupt:
        sys.exit(0)
    except FunctionTimedOut:
        # Kept as its own clause in case FunctionTimedOut is not an
        # ``Exception`` subclass and would bypass the handler below.
        res = 0
    except Exception:
        # Deliberate best effort: an unexecutable query scores 0.
        res = 0
    return {'sql_idx': idx, 'res': res}


def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
    """Load SQL statements and derive the database file path for each one.

    Args:
        sql_path: File to read. In ``'gpt'`` mode it is a JSON object whose
            values are ``"<sql>\\t----- bird -----\\t<db_name>"`` strings; in
            ``'gt'`` mode it is a text file with one ``"<sql>\\t<db_name>"``
            entry per line.
        db_root_path: Directory that contains one sub-directory per database.
            A missing trailing separator is added automatically.
        mode: ``'gpt'`` for predictions, ``'gt'`` for ground truth. Any other
            value yields two empty lists.
        data_mode: Unused; kept for interface compatibility.

    Returns:
        A tuple ``(clean_sqls, db_path_list)`` of equally long lists holding
        the SQL texts and the matching ``<root><db>/<db>.sqlite`` paths.

    Raises:
        ValueError: If an entry does not contain the expected separator.
    """
    # Without this, "root" + "db" would silently produce "rootdb/db.sqlite".
    if db_root_path and not db_root_path.endswith(('/', '\\')):
        db_root_path += '/'

    clean_sqls = []
    db_path_list = []

    def _add(sql, db_name):
        clean_sqls.append(sql)
        db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')

    if mode == 'gpt':
        with open(sql_path, 'r') as f:
            sql_data = json.load(f)
        for sql_str in sql_data.values():
            if isinstance(sql_str, str):
                # Split on the last separator only, so the SQL part is kept whole.
                sql, db_name = sql_str.rsplit('\t----- bird -----\t', 1)
            else:
                # Non-string entry (e.g. null): fall back to a placeholder query.
                sql, db_name = " ", "financial"
            _add(sql, db_name)

    elif mode == 'gt':
        with open(sql_path) as f:
            for sql_str in f:
                # The database name follows the last tab; the SQL itself may
                # contain tabs.
                sql, db_name = sql_str.strip().rsplit('\t', 1)
                _add(sql, db_name)

    return clean_sqls, db_path_list

def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
    """Evaluate every query pair in a pool of worker processes.

    One ``execute_model`` job is submitted per pair; each finished job's
    return value is handed to ``result_callback``. The function returns once
    the pool has been closed and joined.

    Args:
        sqls: Sequence of ``(predicted_sql, ground_truth)`` pairs.
        db_places: Database paths, indexed in parallel with ``sqls``.
        num_cpus: Number of worker processes in the pool.
        meta_time_out: Per-pair time limit forwarded to ``execute_model``.
    """
    workers = mp.Pool(processes=num_cpus)
    for position, (predicted_sql, ground_truth) in enumerate(sqls):
        job_args = (predicted_sql, ground_truth, db_places[position], position, meta_time_out)
        workers.apply_async(execute_model, args=job_args, callback=result_callback)
    workers.close()
    workers.join()

def sort_results(list_of_dicts):
    """Return a new list of the result dicts ordered by their ``'sql_idx'``.

    Args:
        list_of_dicts: Iterable of dicts that each carry a ``'sql_idx'`` key.

    Returns:
        A new list sorted in ascending ``'sql_idx'`` order; the input is not
        modified.
    """
    def by_sql_idx(entry):
        return entry['sql_idx']

    ordered = list(list_of_dicts)
    ordered.sort(key=by_sql_idx)
    return ordered

def compute_acc_by_diff(exec_results, diff_json_path):
    """Compute accuracy per difficulty level and overall, in percent.

    Args:
        exec_results: List of ``{'sql_idx': ..., 'res': 0 or 1}`` dicts, ordered
            so that position ``i`` matches entry ``i`` of the difficulty file.
        diff_json_path: Path of a JSON file holding a list of objects that each
            have a ``'difficulty'`` key (``'simple'``, ``'moderate'`` or
            ``'challenging'``).

    Returns:
        A tuple ``(simple_acc, moderate_acc, challenging_acc, all_acc,
        count_lists)``. The accuracies are percentages; an empty group yields
        ``0.0`` instead of raising ``ZeroDivisionError``. ``count_lists`` holds
        the three group sizes followed by the total number of results.
    """
    contents = load_json(diff_json_path)

    buckets = {'simple': [], 'moderate': [], 'challenging': []}
    for i, content in enumerate(contents):
        bucket = buckets.get(content['difficulty'])
        if bucket is not None:
            bucket.append(exec_results[i]['res'])

    def _percent(flags):
        # An empty group has no meaningful accuracy; report 0.0 rather than
        # dividing by zero.
        return sum(flags) / len(flags) * 100 if flags else 0.0

    overall = [res['res'] for res in exec_results]
    simple = buckets['simple']
    moderate = buckets['moderate']
    challenging = buckets['challenging']

    count_lists = [len(simple), len(moderate), len(challenging), len(overall)]
    return (_percent(simple), _percent(moderate), _percent(challenging),
            _percent(overall), count_lists)



def print_data(score_lists, count_lists):
    """Print a fixed-width table of counts and accuracies to stdout.

    Four lines are written: a header with the level names, the counts row, a
    separator banner, and the accuracies row formatted with two decimals.

    Args:
        score_lists: Four accuracy values (simple, moderate, challenging, total).
        count_lists: Four counts in the same order.
    """
    def emit(template, label, values):
        print(template.format(label, *values))

    emit("{:20} {:20} {:20} {:20} {:20}", "", ['simple', 'moderate', 'challenging', 'total'])
    emit("{:20} {:<20} {:<20} {:<20} {:<20}", 'count', count_lists)

    print('====================================== ACCURACY =====================================')
    emit("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}", 'accuracy', score_lists)


if __name__ == '__main__':
    # Command-line entry point: load predicted and ground-truth SQL, execute
    # every pair in parallel, then print accuracy per difficulty level.
    args_parser = argparse.ArgumentParser()
    args_parser.add_argument('--predicted_sql_path', type=str, required=True)
    args_parser.add_argument('--ground_truth_path', type=str, required=True)
    args_parser.add_argument('--data_mode', type=str, required=True)
    args_parser.add_argument('--db_root_path', type=str, required=True)
    args_parser.add_argument('--num_cpus', type=int, default=1)
    args_parser.add_argument('--meta_time_out', type=float, default=30.0)
    args_parser.add_argument('--mode_gt', type=str, default='gt')
    args_parser.add_argument('--mode_predict', type=str, default='gpt')
    # Accepted for command-line compatibility; not used below.
    args_parser.add_argument('--difficulty', type=str, default='simple')
    args_parser.add_argument('--diff_json_path', type=str, default='')
    args = args_parser.parse_args()

    # Module-level accumulator that result_callback appends to.
    exec_result = []

    # Honour --mode_predict / --mode_gt instead of hard-coded literals; their
    # defaults ('gpt' / 'gt') reproduce the previous behaviour.
    pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path,
                                          mode=args.mode_predict, data_mode=args.data_mode)
    # Ground-truth SQL; its database paths are not needed because the
    # prediction paths are used for execution.
    gt_queries, _ = package_sqls(args.ground_truth_path, args.db_root_path,
                                 mode=args.mode_gt, data_mode=args.data_mode)

    query_pairs = list(zip(pred_queries, gt_queries))
    run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus,
                      meta_time_out=args.meta_time_out)
    # Callbacks arrive in completion order; restore the original query order.
    exec_result = sort_results(exec_result)

    print('start calculate')
    simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
        compute_acc_by_diff(exec_result, args.diff_json_path)
    score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
    print_data(score_lists, count_lists)
    print('===========================================================================================')
    print("Finished evaluation")

27 changes: 24 additions & 3 deletions alphasql/runner/sql_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@ def select_final_sql_query(results_file_path: str, db_root_dir: str):
question_id = int(results_file_path.split("/")[-1].split(".")[0])
with open(results_file_path, "rb") as f:
results = pickle.load(f)
if results is None or len(results) == 0:
return {
"question_id": question_id,
"db_id": "financial",
"sql": "ERROR"
}
db_id = results[0][0].db_id
db_path = f"{db_root_dir}/{db_id}/{db_id}.sqlite"
result_groups = defaultdict(list)
result_groups_with_invalid_result = defaultdict(list)

for idx, result in tqdm(enumerate(results), desc=f"Processing results for {question_id}"):
for idx, result in enumerate(results):
sql_query = result[-1].final_sql_query
answer = cached_execute_sql_with_timeout(db_path, sql_query)
if answer.result_type.value == "success":
Expand Down Expand Up @@ -53,6 +59,7 @@ def select_final_sql_query(results_file_path: str, db_root_dir: str):

return {
"question_id": question_id,
"db_id": db_id,
"sql": final_selected_sql_query
}

Expand Down Expand Up @@ -82,6 +89,7 @@ def select_final_sql_query(results_file_path: str, db_root_dir: str):

return {
"question_id": question_id,
"db_id": db_id,
"sql": final_selected_sql_query
}

Expand All @@ -93,10 +101,23 @@ def main(args):

for future in tqdm(as_completed(future_to_path), total=len(future_to_path), desc="Processing results"):
selected_item = future.result()
final_pred_sqls[str(selected_item["question_id"])] = selected_item["sql"]
final_pred_sqls[str(selected_item["question_id"])] = selected_item["sql"] + '\t----- bird -----\t' + selected_item["db_id"]

# Fill missing IDs with default value
if final_pred_sqls:
question_ids = [int(qid) for qid in final_pred_sqls.keys()]
min_id = min(question_ids)
max_id = max(question_ids)

for qid in range(min_id, max_id + 1):
if str(qid) not in final_pred_sqls:
final_pred_sqls[str(qid)] = "ERROR\t----- bird -----\tfinancial"

# Sort by question_id before output
sorted_pred_sqls = dict(sorted(final_pred_sqls.items(), key=lambda x: int(x[0])))

with open(args.output_path, "w") as f:
json.dump(final_pred_sqls, f, indent=4)
json.dump(sorted_pred_sqls, f, indent=4)

if __name__ == "__main__":
parser = argparse.ArgumentParser()
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ loguru
tqdm
pyyaml
weave
sqlglot[rs]
sqlglot[rs]
prettytable
func_timeout
13 changes: 13 additions & 0 deletions script/evaluation.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Evaluate predicted SQL against the ground truth and write the report to
# ${output_path}.
db_root_path='data/bird/dev/dev_databases/'
data_mode='dev'
diff_json_path='data/bird/dev/dev.json'
predicted_sql_path_kg='results/pred_sqls_7B.json'
ground_truth_path='data/bird/dev/dev.sql'
num_cpus=16
meta_time_out=30.0
mode_gt='gt'
mode_predict='gpt'
output_path='results/evaluation_results_7B.txt'

echo "Evaluating SQLs..."
# Every expansion is quoted so paths containing spaces or glob characters are
# passed through as single arguments.
python3 -u -m alphasql.runner.evaluation \
    --db_root_path "${db_root_path}" \
    --predicted_sql_path "${predicted_sql_path_kg}" \
    --data_mode "${data_mode}" \
    --ground_truth_path "${ground_truth_path}" \
    --num_cpus "${num_cpus}" \
    --mode_gt "${mode_gt}" \
    --mode_predict "${mode_predict}" \
    --diff_json_path "${diff_json_path}" \
    --meta_time_out "${meta_time_out}" \
    > "${output_path}"
20 changes: 17 additions & 3 deletions script/sql_selection.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,26 @@
DB_ROOT_DIR="data/bird/dev/dev_databases"
PROCESS_NUM=32

RESULTS_DIR="results/Qwen2.5-Coder-32B-Instruct/bird/dev"
OUTPUT_PATH="./pred_sqls.json"
RESULTS_DIR="results/Qwen2.5-Coder-7B-Instruct/bird/dev"
OUTPUT_PATH="results/pred_sqls_7B.json"

echo "Selecting SQLs..."
python -m alphasql.runner.sql_selection \
--results_dir $RESULTS_DIR \
--db_root_dir $DB_ROOT_DIR \
--process_num $PROCESS_NUM \
--output_path $OUTPUT_PATH
--output_path $OUTPUT_PATH

# Evaluation step: score the SQL file written by the selection step above and
# write the report to ${output_path}.
db_root_path='data/bird/dev/dev_databases/'
data_mode='dev'
diff_json_path='data/bird/dev/dev.json'
predicted_sql_path_kg='results/pred_sqls_7B.json'
ground_truth_path='data/bird/dev/dev.sql'
num_cpus=16
meta_time_out=30.0
mode_gt='gt'
mode_predict='gpt'
output_path='results/evaluation_results_7B.txt'

echo "Evaluating SQLs..."
# Every expansion is quoted so paths containing spaces or glob characters are
# passed through as single arguments.
python3 -u -m alphasql.runner.evaluation \
    --db_root_path "${db_root_path}" \
    --predicted_sql_path "${predicted_sql_path_kg}" \
    --data_mode "${data_mode}" \
    --ground_truth_path "${ground_truth_path}" \
    --num_cpus "${num_cpus}" \
    --mode_gt "${mode_gt}" \
    --mode_predict "${mode_predict}" \
    --diff_json_path "${diff_json_path}" \
    --meta_time_out "${meta_time_out}" \
    > "${output_path}"