From d28ee7e5075c5569a1032e28ba12dc1bdaca92b2 Mon Sep 17 00:00:00 2001
From: cbacode <1586230797@qq.com>
Date: Sun, 23 Nov 2025 21:09:41 +0800
Subject: [PATCH 1/2] changes

---
 alphasql/database/utils.py       |   8 +-
 alphasql/runner/evaluation.py    | 153 +++++++++++++++++++++++++++++++
 alphasql/runner/sql_selection.py |  27 +++++-
 requirements.txt                 |   4 +-
 4 files changed, 187 insertions(+), 5 deletions(-)
 create mode 100644 alphasql/runner/evaluation.py

diff --git a/alphasql/database/utils.py b/alphasql/database/utils.py
index 65558cf..39e0f40 100644
--- a/alphasql/database/utils.py
+++ b/alphasql/database/utils.py
@@ -6,6 +6,8 @@
 
 from alphasql.database.sql_execution import execute_sql_without_timeout
 
+MAX_LENGTH_COLUMN_DESCRIPTION = 100
+
 def lower_str_list(str_list: List[Any]) -> List[Any]:
     """
     Convert a list of strings or nested lists to a list of lowercase strings or nested lists.
@@ -158,7 +160,11 @@ def load_value_examples(db_id: str, database_root_dir: str, table_name: str, col
     """
     db_path = Path(database_root_dir) / db_id / f"{db_id}.sqlite"
     examples = execute_sql_without_timeout(db_path, f"SELECT DISTINCT `{column_name}` FROM `{table_name}` WHERE `{column_name}` IS NOT NULL AND `{column_name}` != '' LIMIT {max_num_examples};").result
-    return [example[0] for example in examples]
+    res = [example[0] for example in examples]
+    if len(str(res)) > MAX_LENGTH_COLUMN_DESCRIPTION:
+        return [res[0]]
+    else:
+        return res
 
 def load_database_schema_dict(db_id: str, database_root_dir: str) -> Dict[str, Dict[str, Dict[str, Any]]]:
     """
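Note: the guard added to load_value_examples above keeps schema descriptions short by falling back to a single example value whenever the stringified list of distinct values would exceed MAX_LENGTH_COLUMN_DESCRIPTION (100 characters). A minimal sketch of the behavior, with made-up values:

    MAX_LENGTH_COLUMN_DESCRIPTION = 100
    res = ["a long distinct value from the column"] * 5
    if len(str(res)) > MAX_LENGTH_COLUMN_DESCRIPTION:  # str(res) is about 200 chars here
        res = [res[0]]                                 # keep one representative example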
diff --git a/alphasql/runner/evaluation.py b/alphasql/runner/evaluation.py
new file mode 100644
index 0000000..2e67a58
--- /dev/null
+++ b/alphasql/runner/evaluation.py
@@ -0,0 +1,153 @@
+import sys
+import json
+import argparse
+import sqlite3
+import multiprocessing as mp
+from func_timeout import func_timeout, FunctionTimedOut
+
+def load_json(dir):
+    with open(dir, 'r') as j:
+        contents = json.loads(j.read())
+    return contents
+
+def result_callback(result):
+    exec_result.append(result)
+
+
+def execute_sql(predicted_sql, ground_truth, db_path):
+    # Connect to the database
+    conn = sqlite3.connect(db_path)
+    cursor = conn.cursor()
+    cursor.execute(predicted_sql)
+    predicted_res = cursor.fetchall()
+    cursor.execute(ground_truth)
+    ground_truth_res = cursor.fetchall()
+    res = 0
+    if set(predicted_res) == set(ground_truth_res):
+        res = 1
+    return res
+
+
+def execute_model(predicted_sql, ground_truth, db_place, idx, meta_time_out):
+    try:
+        res = func_timeout(meta_time_out, execute_sql,
+                           args=(predicted_sql, ground_truth, db_place))
+    except KeyboardInterrupt:
+        sys.exit(0)
+    except FunctionTimedOut:
+        result = [('timeout',)]
+        res = 0
+    except Exception as e:
+        result = [('error',)]  # possibly len(query) > 512 or not executable
+        res = 0
+    # print(result)
+    # result = str(set([ret[0] for ret in result]))
+    result = {'sql_idx': idx, 'res': res}
+    # print(result)
+    return result
+
+
+def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
+    clean_sqls = []
+    db_path_list = []
+    if mode == 'gpt':
+        sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
+        for idx, sql_str in sql_data.items():
+            if type(sql_str) == str:
+                sql, db_name = sql_str.split('\t----- bird -----\t')
+            else:
+                sql, db_name = " ", "financial"
+            clean_sqls.append(sql)
+            db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
+
+    elif mode == 'gt':
+        sqls = open(sql_path + data_mode + '_gold.sql')
+        sql_txt = sqls.readlines()
+        # sql_txt = [sql.split('\t')[0] for sql in sql_txt]
+        for idx, sql_str in enumerate(sql_txt):
+            sql, db_name = sql_str.strip().split('\t')
+            clean_sqls.append(sql)
+            db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
+
+    return clean_sqls, db_path_list
+
+def run_sqls_parallel(sqls, db_places, num_cpus=1, meta_time_out=30.0):
+    pool = mp.Pool(processes=num_cpus)
+    for i, sql_pair in enumerate(sqls):
+
+        predicted_sql, ground_truth = sql_pair
+        pool.apply_async(execute_model, args=(predicted_sql, ground_truth, db_places[i], i, meta_time_out), callback=result_callback)
+    pool.close()
+    pool.join()
+
+def sort_results(list_of_dicts):
+    return sorted(list_of_dicts, key=lambda x: x['sql_idx'])
+
+def compute_acc_by_diff(exec_results, diff_json_path):
+    num_queries = len(exec_results)
+    results = [res['res'] for res in exec_results]
+    contents = load_json(diff_json_path)
+    simple_results, moderate_results, challenging_results = [], [], []
+
+    for i, content in enumerate(contents):
+        if content['difficulty'] == 'simple':
+            simple_results.append(exec_results[i])
+
+        if content['difficulty'] == 'moderate':
+            moderate_results.append(exec_results[i])
+
+        if content['difficulty'] == 'challenging':
+            challenging_results.append(exec_results[i])
+
+    simple_acc = sum([res['res'] for res in simple_results]) / len(simple_results)
+    moderate_acc = sum([res['res'] for res in moderate_results]) / len(moderate_results)
+    challenging_acc = sum([res['res'] for res in challenging_results]) / len(challenging_results)
+    all_acc = sum(results) / num_queries
+    count_lists = [len(simple_results), len(moderate_results), len(challenging_results), num_queries]
+    return simple_acc * 100, moderate_acc * 100, challenging_acc * 100, all_acc * 100, count_lists
+
+
+def print_data(score_lists, count_lists):
+    levels = ['simple', 'moderate', 'challenging', 'total']
+    print("{:20} {:20} {:20} {:20} {:20}".format("", *levels))
+    print("{:20} {:<20} {:<20} {:<20} {:<20}".format('count', *count_lists))
+
+    print('====================================== ACCURACY =====================================')
+    print("{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format('accuracy', *score_lists))
+
+
+if __name__ == '__main__':
+    args_parser = argparse.ArgumentParser()
+    args_parser.add_argument('--predicted_sql_path', type=str, required=True, default='')
+    args_parser.add_argument('--ground_truth_path', type=str, required=True, default='')
+    args_parser.add_argument('--data_mode', type=str, required=True, default='dev')
+    args_parser.add_argument('--db_root_path', type=str, required=True, default='')
+    args_parser.add_argument('--num_cpus', type=int, default=1)
+    args_parser.add_argument('--meta_time_out', type=float, default=30.0)
+    args_parser.add_argument('--mode_gt', type=str, default='gt')
+    args_parser.add_argument('--mode_predict', type=str, default='gpt')
+    args_parser.add_argument('--difficulty', type=str, default='simple')
+    args_parser.add_argument('--diff_json_path', type=str, default='')
+    args = args_parser.parse_args()
+    exec_result = []
+
+    pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode='gt',
+                                          data_mode=args.data_mode)
+    # generate gt sqls:
+    gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt',
+                                           data_mode=args.data_mode)
+
+    query_pairs = list(zip(pred_queries, gt_queries))
+    run_sqls_parallel(query_pairs, db_places=db_paths, num_cpus=args.num_cpus, meta_time_out=args.meta_time_out)
+    exec_result = sort_results(exec_result)
+
+    print('start calculating')
+    simple_acc, moderate_acc, challenging_acc, acc, count_lists = \
+        compute_acc_by_diff(exec_result, args.diff_json_path)
+    score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+    print_data(score_lists, count_lists)
+    print('===========================================================================================')
+    print("Finished evaluation")
+    
\ No newline at end of file
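Note on the scoring in evaluation.py above: a prediction counts as correct exactly when its result rows and the gold rows are equal as sets, so row order and duplicate rows are ignored. An illustrative pair (rows are made up):

    predicted_res    = [(2, 'b'), (1, 'a'), (1, 'a')]
    ground_truth_res = [(1, 'a'), (2, 'b')]
    res = 1 if set(predicted_res) == set(ground_truth_res) else 0  # res == 1 despite order and duplicates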
diff --git a/alphasql/runner/sql_selection.py b/alphasql/runner/sql_selection.py
index 06a3db6..3fc44be 100644
--- a/alphasql/runner/sql_selection.py
+++ b/alphasql/runner/sql_selection.py
@@ -14,12 +14,18 @@ def select_final_sql_query(results_file_path: str, db_root_dir: str):
     question_id = int(results_file_path.split("/")[-1].split(".")[0])
     with open(results_file_path, "rb") as f:
         results = pickle.load(f)
+    if results is None or len(results) == 0:
+        return {
+            "question_id": question_id,
+            "db_id": "financial",
+            "sql": "ERROR"
+        }
     db_id = results[0][0].db_id
     db_path = f"{db_root_dir}/{db_id}/{db_id}.sqlite"
 
     result_groups = defaultdict(list)
     result_groups_with_invalid_result = defaultdict(list)
-    for idx, result in tqdm(enumerate(results), desc=f"Processing results for {question_id}"):
+    for idx, result in enumerate(results):
         sql_query = result[-1].final_sql_query
         answer = cached_execute_sql_with_timeout(db_path, sql_query)
         if answer.result_type.value == "success":
@@ -53,6 +59,7 @@
 
     return {
         "question_id": question_id,
+        "db_id": db_id,
         "sql": final_selected_sql_query
     }
 
@@ -82,6 +89,7 @@
 
     return {
         "question_id": question_id,
+        "db_id": db_id,
        "sql": final_selected_sql_query
     }
 
@@ -93,10 +101,23 @@ def main(args):
 
         for future in tqdm(as_completed(future_to_path), total=len(future_to_path), desc="Processing results"):
             selected_item = future.result()
-            final_pred_sqls[str(selected_item["question_id"])] = selected_item["sql"]
+            final_pred_sqls[str(selected_item["question_id"])] = selected_item["sql"] + '\t----- bird -----\t' + selected_item["db_id"]
+
+    # Fill missing IDs with default value
+    if final_pred_sqls:
+        question_ids = [int(qid) for qid in final_pred_sqls.keys()]
+        min_id = min(question_ids)
+        max_id = max(question_ids)
+
+        for qid in range(min_id, max_id + 1):
+            if str(qid) not in final_pred_sqls:
+                final_pred_sqls[str(qid)] = "ERROR\t----- bird -----\tfinancial"
+
+    # Sort by question_id before output
+    sorted_pred_sqls = dict(sorted(final_pred_sqls.items(), key=lambda x: int(x[0])))
 
     with open(args.output_path, "w") as f:
-        json.dump(final_pred_sqls, f, indent=4)
+        json.dump(sorted_pred_sqls, f, indent=4)
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
diff --git a/requirements.txt b/requirements.txt
index 2e4fec5..3cf167c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,4 +8,6 @@ loguru
 tqdm
 pyyaml
 weave
-sqlglot[rs]
\ No newline at end of file
+sqlglot[rs]
+prettytable
+func_timeout
\ No newline at end of file
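Note: the sql_selection.py change above writes each prediction in exactly the layout that package_sqls(mode='gpt') in evaluation.py splits back apart, with the SQL and the database id joined by the literal separator '\t----- bird -----\t'. An illustrative entry from the output JSON (question id, SQL, and db_id are made up):

    {
        "0": "SELECT count(*) FROM account\t----- bird -----\tfinancial"
    }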
From a4db148c68ae914dc4760243de4f7c20c11083b0 Mon Sep 17 00:00:00 2001
From: cbacode <1586230797@qq.com>
Date: Sun, 23 Nov 2025 21:15:19 +0800
Subject: [PATCH 2/2] changes

---
 alphasql/runner/evaluation.py | 11 ++++++++---
 script/evaluation.sh          | 13 +++++++++++++
 script/sql_selection.sh       | 20 +++++++++++++++----
 3 files changed, 38 insertions(+), 6 deletions(-)
 create mode 100644 script/evaluation.sh

diff --git a/alphasql/runner/evaluation.py b/alphasql/runner/evaluation.py
index 2e67a58..3da2584 100644
--- a/alphasql/runner/evaluation.py
+++ b/alphasql/runner/evaluation.py
@@ -1,3 +1,6 @@
+# Copied from https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird
+# The way input files are located has been modified.
+
 import sys
 import json
 import argparse
@@ -52,7 +55,8 @@ def package_sqls(sql_path, db_root_path, mode='gpt', data_mode='dev'):
     clean_sqls = []
     db_path_list = []
     if mode == 'gpt':
-        sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
+        # sql_data = json.load(open(sql_path + 'predict_' + data_mode + '.json', 'r'))
+        sql_data = json.load(open(sql_path, 'r'))
         for idx, sql_str in sql_data.items():
             if type(sql_str) == str:
                 sql, db_name = sql_str.split('\t----- bird -----\t')
@@ -62,7 +66,8 @@
             db_path_list.append(db_root_path + db_name + '/' + db_name + '.sqlite')
 
     elif mode == 'gt':
-        sqls = open(sql_path + data_mode + '_gold.sql')
+        # sqls = open(sql_path + data_mode + '_gold.sql')
+        sqls = open(sql_path)
         sql_txt = sqls.readlines()
         # sql_txt = [sql.split('\t')[0] for sql in sql_txt]
         for idx, sql_str in enumerate(sql_txt):
@@ -133,7 +138,7 @@ def print_data(score_lists, count_lists):
     args = args_parser.parse_args()
     exec_result = []
 
-    pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode='gt',
+    pred_queries, db_paths = package_sqls(args.predicted_sql_path, args.db_root_path, mode='gpt',
                                           data_mode=args.data_mode)
     # generate gt sqls:
     gt_queries, db_paths_gt = package_sqls(args.ground_truth_path, args.db_root_path, mode='gt',
diff --git a/script/evaluation.sh b/script/evaluation.sh
new file mode 100644
index 0000000..10a6d9c
--- /dev/null
+++ b/script/evaluation.sh
@@ -0,0 +1,13 @@
+db_root_path='data/bird/dev/dev_databases/'
+data_mode='dev'
+diff_json_path='data/bird/dev/dev.json'
+predicted_sql_path_kg='results/pred_sqls_7B.json'
+ground_truth_path='data/bird/dev/dev.sql'
+num_cpus=16
+meta_time_out=30.0
+mode_gt='gt'
+mode_predict='gpt'
+output_path='results/evaluation_results_7B.txt'
+
+echo "Evaluating SQLs..."
+python3 -u -m alphasql.runner.evaluation --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path_kg} --data_mode ${data_mode} --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} > ${output_path}
\ No newline at end of file
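Note: with the package_sqls change above, --predicted_sql_path and --ground_truth_path are treated as complete file paths rather than directory prefixes, so the scripts can point straight at the selection output. Roughly what the two loaders now read (paths taken from the script above):

    import json
    preds = json.load(open('results/pred_sqls_7B.json'))  # {"<qid>": "<SQL>\t----- bird -----\t<db_id>", ...}
    golds = open('data/bird/dev/dev.sql').readlines()     # one "<SQL>\t<db_id>" line per question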
diff --git a/script/sql_selection.sh b/script/sql_selection.sh
index 793bca9..66c48ab 100644
--- a/script/sql_selection.sh
+++ b/script/sql_selection.sh
@@ -3,12 +3,26 @@
 DB_ROOT_DIR="data/bird/dev/dev_databases"
 PROCESS_NUM=32
 
-RESULTS_DIR="results/Qwen2.5-Coder-32B-Instruct/bird/dev"
-OUTPUT_PATH="./pred_sqls.json"
+RESULTS_DIR="results/Qwen2.5-Coder-7B-Instruct/bird/dev"
+OUTPUT_PATH="results/pred_sqls_7B.json"
 
 echo "Selecting SQLs..."
 python -m alphasql.runner.sql_selection \
     --results_dir $RESULTS_DIR \
     --db_root_dir $DB_ROOT_DIR \
     --process_num $PROCESS_NUM \
-    --output_path $OUTPUT_PATH
\ No newline at end of file
+    --output_path $OUTPUT_PATH
+
+db_root_path='data/bird/dev/dev_databases/'
+data_mode='dev'
+diff_json_path='data/bird/dev/dev.json'
+predicted_sql_path_kg='results/pred_sqls_7B.json'
+ground_truth_path='data/bird/dev/dev.sql'
+num_cpus=16
+meta_time_out=30.0
+mode_gt='gt'
+mode_predict='gpt'
+output_path='results/evaluation_results_7B.txt'
+
+echo "Evaluating SQLs..."
+python3 -u -m alphasql.runner.evaluation --db_root_path ${db_root_path} --predicted_sql_path ${predicted_sql_path_kg} --data_mode ${data_mode} --ground_truth_path ${ground_truth_path} --num_cpus ${num_cpus} --mode_gt ${mode_gt} --mode_predict ${mode_predict} --diff_json_path ${diff_json_path} --meta_time_out ${meta_time_out} > ${output_path}
\ No newline at end of file
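Note: after script/sql_selection.sh finishes, results/evaluation_results_7B.txt holds the table printed by print_data, roughly in this shape (all numbers below are placeholders, not measured results):

                          simple               moderate             challenging          total
    count                 925                  465                  144                  1534
    ====================================== ACCURACY =====================================
    accuracy              60.00                50.00                40.00                55.00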