From 350f87c16043cb5f1c1066376cfd3bb94ed9f92a Mon Sep 17 00:00:00 2001 From: misiugodfrey Date: Mon, 9 Mar 2026 14:55:41 -0700 Subject: [PATCH 1/3] first pass validation --- presto/scripts/run_benchmark.sh | 18 ++++ .../performance_benchmarks/common_fixtures.py | 96 +++++++++++++++++++ .../performance_benchmarks/conftest.py | 2 + 3 files changed, 116 insertions(+) diff --git a/presto/scripts/run_benchmark.sh b/presto/scripts/run_benchmark.sh index 2a4d7462..79ab19b7 100755 --- a/presto/scripts/run_benchmark.sh +++ b/presto/scripts/run_benchmark.sh @@ -38,6 +38,11 @@ OPTIONS: --skip-drop-cache Skip dropping system caches before each benchmark query (dropped by default). -m, --metrics Collect detailed metrics from Presto REST API after each query. Metrics are stored in query-specific directories. + -e, --expected-results-dir + Directory containing expected query result parquet files for validation. + If not specified, the default is derived from the table schema by appending + "_expected" to the data directory (e.g., \$PRESTO_DATA_DIR/tpchsf100_expected). + If the directory does not exist, validation is skipped. 
EXAMPLES: $0 -b tpch -s bench_sf100 @@ -160,6 +165,15 @@ parse_args() { METRICS=true shift ;; + -e|--expected-results-dir) + if [[ -n $2 ]]; then + EXPECTED_RESULTS_DIR=$2 + shift 2 + else + echo "Error: --expected-results-dir requires a value" + exit 1 + fi + ;; *) echo "Error: Unknown argument $1" print_help @@ -236,6 +250,10 @@ if [[ "${SKIP_DROP_CACHE}" == "true" ]]; then PYTEST_ARGS+=("--skip-drop-cache") fi +if [[ -n ${EXPECTED_RESULTS_DIR} ]]; then + PYTEST_ARGS+=("--expected-results-dir=${EXPECTED_RESULTS_DIR}") +fi + source "${SCRIPT_DIR}/../../scripts/py_env_functions.sh" trap delete_python_virtual_env EXIT diff --git a/presto/testing/performance_benchmarks/common_fixtures.py b/presto/testing/performance_benchmarks/common_fixtures.py index 8ea3b6db..87616b4e 100644 --- a/presto/testing/performance_benchmarks/common_fixtures.py +++ b/presto/testing/performance_benchmarks/common_fixtures.py @@ -1,14 +1,22 @@ # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: Apache-2.0 +import os from pathlib import Path +import duckdb import pandas as pd import prestodb import pytest +from common.testing.integration_tests.test_utils import ( + assert_rows_equal, + normalize_rows, + none_safe_sort_key, +) from common.testing.performance_benchmarks.benchmark_keys import BenchmarkKeys +from ..common.test_utils import get_table_external_location from ..integration_tests.analyze_tables import check_tables_analyzed from .metrics_collector import collect_metrics from .profiler_utils import start_profiler, stop_profiler @@ -115,3 +123,91 @@ def benchmark_query_function(query_id): stop_profiler(profile_script_path, profile_output_file_path) return benchmark_query_function + + +def _derive_expected_results_dir(hostname, port, user, schema): + """Derive the expected results directory from the table schema. + + Queries the schema to find a table's external location on the host + (e.g. 
$PRESTO_DATA_DIR/tpchsf100/lineitem), goes up one level to + get the data root, and appends '_expected'. + """ + conn = prestodb.dbapi.connect(host=hostname, port=port, user=user, catalog="hive", schema=schema) + cursor = conn.cursor() + try: + table = cursor.execute(f"SHOW TABLES IN {schema}").fetchone()[0] + table_location = get_table_external_location(schema, table, cursor) + data_root = os.path.dirname(table_location) + return f"{data_root}_expected" + except Exception as e: + print(f"[Validation] Could not derive expected results directory from schema: {e}") + return None + finally: + cursor.close() + conn.close() + + +@pytest.fixture(scope="session", autouse=True) +def validate_benchmark_results(request): + """Session-scoped fixture that validates benchmark query results after all queries complete.""" + yield + + expected_results_dir = request.config.getoption("--expected-results-dir") + if expected_results_dir is None: + hostname = request.config.getoption("--hostname") + port = request.config.getoption("--port") + user = request.config.getoption("--user") + schema = request.config.getoption("--schema-name") + expected_results_dir = _derive_expected_results_dir(hostname, port, user, schema) + + if expected_results_dir is None: + print("[Validation] Skipping result validation (could not determine expected results directory).") + return + + expected_dir = Path(expected_results_dir) + if not expected_dir.is_dir(): + print(f"[Validation] Skipping result validation (expected results directory '{expected_dir}' not found).") + return + + expected_files = sorted(expected_dir.glob("*.parquet")) + if not expected_files: + print(f"[Validation] Skipping result validation (no parquet files in '{expected_dir}').") + return + + output_dir = request.config.getoption("--output-dir") + actual_results_dir = Path(output_dir) / "query_results" + if not actual_results_dir.is_dir(): + print(f"[Validation] Skipping result validation (no query results directory at 
'{actual_results_dir}').") + return + + passed = 0 + failed = 0 + for expected_file in expected_files: + query_name = expected_file.name + actual_file = actual_results_dir / query_name + if not actual_file.exists(): + print(f"[Validation] SKIPPED: {query_name} - no actual result found.") + continue + try: + expected_rel = duckdb.from_parquet(str(expected_file)) + actual_rel = duckdb.from_parquet(str(actual_file)) + types = expected_rel.types + + expected_rows = sorted(normalize_rows(expected_rel.fetchall(), types), key=none_safe_sort_key) + actual_rows = sorted(normalize_rows(actual_rel.fetchall(), types), key=none_safe_sort_key) + + assert len(actual_rows) == len(expected_rows), ( + f"Row count mismatch: {len(actual_rows)} vs {len(expected_rows)}" + ) + assert_rows_equal(actual_rows, expected_rows, types) + print(f"[Validation] PASSED: {query_name}") + passed += 1 + except AssertionError as e: + print(f"[Validation] FAILED: {query_name} - {e}") + failed += 1 + except Exception as e: + print(f"[Validation] ERROR: {query_name} - {e}") + failed += 1 + + total = passed + failed + print(f"[Validation] Result validation complete: {passed}/{total} queries passed.") diff --git a/presto/testing/performance_benchmarks/conftest.py b/presto/testing/performance_benchmarks/conftest.py index 2636fc99..963a5611 100644 --- a/presto/testing/performance_benchmarks/conftest.py +++ b/presto/testing/performance_benchmarks/conftest.py @@ -25,6 +25,7 @@ from .common_fixtures import ( benchmark_query, # noqa: F401 presto_cursor, # noqa: F401 + validate_benchmark_results, # noqa: F401 ) @@ -42,6 +43,7 @@ def pytest_addoption(parser): parser.addoption("--profile-script-path") parser.addoption("--metrics", action="store_true", default=False) parser.addoption("--skip-drop-cache", action="store_true", default=False) + parser.addoption("--expected-results-dir", default=None) def pytest_configure(config): From 599d812f0a6db9f82fdd998936041557f5d410f2 Mon Sep 17 00:00:00 2001 From: misiugodfrey 
Date: Tue, 10 Mar 2026 00:18:13 -0700 Subject: [PATCH 2/3] fixed linting --- presto/testing/performance_benchmarks/common_fixtures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/presto/testing/performance_benchmarks/common_fixtures.py b/presto/testing/performance_benchmarks/common_fixtures.py index 87616b4e..c123674d 100644 --- a/presto/testing/performance_benchmarks/common_fixtures.py +++ b/presto/testing/performance_benchmarks/common_fixtures.py @@ -11,8 +11,8 @@ from common.testing.integration_tests.test_utils import ( assert_rows_equal, - normalize_rows, none_safe_sort_key, + normalize_rows, ) from common.testing.performance_benchmarks.benchmark_keys import BenchmarkKeys From df519c952aacb14637025ee265b941287fbee02f Mon Sep 17 00:00:00 2001 From: misiugodfrey Date: Tue, 10 Mar 2026 00:53:05 -0700 Subject: [PATCH 3/3] Change validation based on how deterministic the query is --- .../performance_benchmarks/common_fixtures.py | 149 ++++++++++++++---- .../performance_benchmarks/conftest.py | 12 +- 2 files changed, 132 insertions(+), 29 deletions(-) diff --git a/presto/testing/performance_benchmarks/common_fixtures.py b/presto/testing/performance_benchmarks/common_fixtures.py index c123674d..e9a3ea8c 100644 --- a/presto/testing/performance_benchmarks/common_fixtures.py +++ b/presto/testing/performance_benchmarks/common_fixtures.py @@ -8,13 +8,16 @@ import pandas as pd import prestodb import pytest +import sqlglot from common.testing.integration_tests.test_utils import ( assert_rows_equal, + get_orderby_indices, none_safe_sort_key, normalize_rows, ) from common.testing.performance_benchmarks.benchmark_keys import BenchmarkKeys +from common.testing.test_utils import get_queries from ..common.test_utils import get_table_external_location from ..integration_tests.analyze_tables import check_tables_analyzed @@ -115,7 +118,9 @@ def benchmark_query_function(query_id): ) raw_times_dict[query_id] = result except Exception as e: - 
failed_queries_dict[query_id] = f"{e.error_type}: {e.error_name}" + error_desc = getattr(e, "error_type", type(e).__name__) + error_name = getattr(e, "error_name", str(e)) + failed_queries_dict[query_id] = f"{error_desc}: {error_name}" raw_times_dict[query_id] = None raise finally: @@ -147,17 +152,75 @@ def _derive_expected_results_dir(hostname, port, user, schema): conn.close() -@pytest.fixture(scope="session", autouse=True) -def validate_benchmark_results(request): - """Session-scoped fixture that validates benchmark query results after all queries complete.""" - yield +def _classify_limit_query(query_sql): + """Classify a query's validation strategy based on its LIMIT and ORDER BY. + + Returns: + "full" - no LIMIT, compare all columns + "orderby_only" - LIMIT with deterministic ORDER BY (raw columns, COUNT, etc.), + compare only ORDER BY columns + "skip" - LIMIT with non-deterministic ORDER BY (SUM/AVG float aggregates), + skip validation because distributed floating-point aggregation + can change the ranking and thus which rows appear in the result set + """ + try: + expr = sqlglot.parse_one(query_sql) + except sqlglot.errors.ParseError: + return "full" + + has_limit = any(isinstance(e, sqlglot.exp.Limit) for e in expr.iter_expressions()) + if not has_limit: + return "full" + + order = next((e for e in expr.find_all(sqlglot.exp.Order)), None) + if not order: + return "full" + + order_names = set() + for ordered in order.expressions: + key = ordered.this + if isinstance(key, sqlglot.exp.Column): + order_names.add(key.name) + + select = expr.find(sqlglot.exp.Select) + float_aggs = (sqlglot.exp.Sum, sqlglot.exp.Avg) + for s in select.expressions: + alias = s.alias if hasattr(s, "alias") else None + if alias and alias in order_names: + for node_tuple in s.walk(): + node = node_tuple[0] if isinstance(node_tuple, tuple) else node_tuple + if isinstance(node, float_aggs): + return "skip" + + return "orderby_only" + + +def _load_query_map(benchmark_types): + """Build a 
{lowercase_query_name: sql} map from the benchmark query JSON files.""" + query_map = {} + for bench_type in benchmark_types: + try: + queries = get_queries(bench_type) + for key, sql in queries.items(): + query_map[key.lower()] = sql + except (FileNotFoundError, OSError): + pass + return query_map + - expected_results_dir = request.config.getoption("--expected-results-dir") +def validate_benchmark_results(config, benchmark_types): + """Validate benchmark query results against expected parquet files. + + Called from pytest_terminal_summary so output appears after the benchmark summary. + For queries with LIMIT, only ORDER BY columns are compared since other columns + can be non-deterministic at the LIMIT boundary. + """ + expected_results_dir = config.getoption("--expected-results-dir") if expected_results_dir is None: - hostname = request.config.getoption("--hostname") - port = request.config.getoption("--port") - user = request.config.getoption("--user") - schema = request.config.getoption("--schema-name") + hostname = config.getoption("--hostname") + port = config.getoption("--port") + user = config.getoption("--user") + schema = config.getoption("--schema-name") expected_results_dir = _derive_expected_results_dir(hostname, port, user, schema) if expected_results_dir is None: @@ -174,40 +237,72 @@ def validate_benchmark_results(request): print(f"[Validation] Skipping result validation (no parquet files in '{expected_dir}').") return - output_dir = request.config.getoption("--output-dir") + output_dir = config.getoption("--output-dir") actual_results_dir = Path(output_dir) / "query_results" if not actual_results_dir.is_dir(): print(f"[Validation] Skipping result validation (no query results directory at '{actual_results_dir}').") return - passed = 0 - failed = 0 + query_map = _load_query_map(benchmark_types) + + passed_queries = [] + skipped_queries = [] + failures = [] for expected_file in expected_files: - query_name = expected_file.name - actual_file = 
actual_results_dir / query_name + query_name = expected_file.stem + actual_file = actual_results_dir / expected_file.name if not actual_file.exists(): - print(f"[Validation] SKIPPED: {query_name} - no actual result found.") continue + + query_sql = query_map.get(query_name.lower()) + strategy = _classify_limit_query(query_sql) if query_sql else "full" + + # Queries whose ORDER BY involves float aggregates (e.g. SUM, AVG) are + # non-deterministic under distributed execution: the partial-aggregate + # reduction order can change the ranking, so different rows appear in + # the LIMIT result set across runs. + if strategy == "skip": + skipped_queries.append(query_name) + continue + try: expected_rel = duckdb.from_parquet(str(expected_file)) actual_rel = duckdb.from_parquet(str(actual_file)) types = expected_rel.types + columns = expected_rel.columns + + expected_rows = expected_rel.fetchall() + actual_rows = actual_rel.fetchall() - expected_rows = sorted(normalize_rows(expected_rel.fetchall(), types), key=none_safe_sort_key) - actual_rows = sorted(normalize_rows(actual_rel.fetchall(), types), key=none_safe_sort_key) # For LIMIT queries with deterministic ORDER BY (raw columns, + # COUNT, etc.), only compare the ORDER BY columns — non-ORDER BY + # columns can differ at the boundary when there are ties. 
+ if strategy == "orderby_only": + order_indices = get_orderby_indices(query_sql, columns) + if order_indices: + types = [types[i] for i in order_indices] + expected_rows = [tuple(row[i] for i in order_indices) for row in expected_rows] + actual_rows = [tuple(row[i] for i in order_indices) for row in actual_rows] + + expected_rows = sorted(normalize_rows(expected_rows, types), key=none_safe_sort_key) + actual_rows = sorted(normalize_rows(actual_rows, types), key=none_safe_sort_key) assert len(actual_rows) == len(expected_rows), ( f"Row count mismatch: {len(actual_rows)} vs {len(expected_rows)}" ) assert_rows_equal(actual_rows, expected_rows, types) - print(f"[Validation] PASSED: {query_name}") - passed += 1 + passed_queries.append(query_name) except AssertionError as e: - print(f"[Validation] FAILED: {query_name} - {e}") - failed += 1 + failures.append(f"[Validation] FAILED: {query_name} - {e}") except Exception as e: - print(f"[Validation] ERROR: {query_name} - {e}") - failed += 1 - - total = passed + failed - print(f"[Validation] Result validation complete: {passed}/{total} queries passed.") + failures.append(f"[Validation] ERROR: {query_name} - {e}") + + for line in failures: + print(line) + total = len(passed_queries) + len(failures) + passed_list = ", ".join(passed_queries) + skipped_list = ", ".join(skipped_queries) + parts = [f"{len(passed_queries)}/{total} passed ({passed_list})"] + if skipped_queries: + parts.append(f"{len(skipped_queries)} skipped non-deterministic ({skipped_list})") + print(f"[Validation] {'; '.join(parts)}") diff --git a/presto/testing/performance_benchmarks/conftest.py b/presto/testing/performance_benchmarks/conftest.py index 963a5611..4b643e19 100644 --- a/presto/testing/performance_benchmarks/conftest.py +++ b/presto/testing/performance_benchmarks/conftest.py @@ -15,7 +15,7 @@ from common.testing.performance_benchmarks.conftest import ( DataLocation, pytest_sessionfinish, # noqa: F401 - pytest_terminal_summary, # noqa: F401 + 
pytest_terminal_summary as _common_pytest_terminal_summary, ) from ..common.fixtures import ( @@ -25,7 +25,7 @@ from .common_fixtures import ( benchmark_query, # noqa: F401 presto_cursor, # noqa: F401 - validate_benchmark_results, # noqa: F401 + validate_benchmark_results, ) @@ -46,5 +46,13 @@ def pytest_addoption(parser): parser.addoption("--expected-results-dir", default=None) +def pytest_terminal_summary(terminalreporter, exitstatus, config): + _common_pytest_terminal_summary(terminalreporter, exitstatus, config) + benchmark_types = [] + if hasattr(terminalreporter._session, "benchmark_results"): + benchmark_types = list(terminalreporter._session.benchmark_results.keys()) + validate_benchmark_results(config, benchmark_types) + + def pytest_configure(config): pytest.data_location = DataLocation("--schema-name", "Schema", BenchmarkKeys.SCHEMA_NAME_KEY)