diff --git a/presto/scripts/run_benchmark.sh b/presto/scripts/run_benchmark.sh
index 2a4d7462..79ab19b7 100755
--- a/presto/scripts/run_benchmark.sh
+++ b/presto/scripts/run_benchmark.sh
@@ -38,6 +38,11 @@ OPTIONS:
     --skip-drop-cache    Skip dropping system caches before each benchmark query (dropped by default).
     -m, --metrics        Collect detailed metrics from Presto REST API after each query.
                          Metrics are stored in query-specific directories.
+    -e, --expected-results-dir
+                         Directory containing expected query result parquet files for validation.
+                         If not specified, the default is derived from the table schema by appending
+                         "_expected" to the data directory (e.g., \$PRESTO_DATA_DIR/tpchsf100_expected).
+                         If the directory does not exist, validation is skipped.
 
 EXAMPLES:
     $0 -b tpch -s bench_sf100
@@ -160,6 +165,15 @@ parse_args() {
             METRICS=true
             shift
             ;;
+        -e|--expected-results-dir)
+            if [[ -n $2 ]]; then
+                EXPECTED_RESULTS_DIR=$2
+                shift 2
+            else
+                echo "Error: --expected-results-dir requires a value"
+                exit 1
+            fi
+            ;;
         *)
             echo "Error: Unknown argument $1"
             print_help
@@ -236,6 +250,10 @@
 if [[ "${SKIP_DROP_CACHE}" == "true" ]]; then
     PYTEST_ARGS+=("--skip-drop-cache")
 fi
 
+if [[ -n ${EXPECTED_RESULTS_DIR} ]]; then
+    # NOTE: flag and value must be separate array elements; a single
+    # "--expected-results-dir ${EXPECTED_RESULTS_DIR}" element would reach
+    # pytest as one unrecognized argument under quoted "${PYTEST_ARGS[@]}" expansion.
+    PYTEST_ARGS+=("--expected-results-dir" "${EXPECTED_RESULTS_DIR}")
+fi
+
 source "${SCRIPT_DIR}/../../scripts/py_env_functions.sh"
 trap delete_python_virtual_env EXIT
diff --git a/presto/testing/performance_benchmarks/common_fixtures.py b/presto/testing/performance_benchmarks/common_fixtures.py
index 8ea3b6db..e9a3ea8c 100644
--- a/presto/testing/performance_benchmarks/common_fixtures.py
+++ b/presto/testing/performance_benchmarks/common_fixtures.py
@@ -1,14 +1,25 @@
 # SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
 # SPDX-License-Identifier: Apache-2.0
+import os
 from pathlib import Path
 
+import duckdb
 import pandas as pd
 import prestodb
 import pytest
+import sqlglot
 
+from common.testing.integration_tests.test_utils import (
+    assert_rows_equal,
+    get_orderby_indices,
+    none_safe_sort_key,
+    normalize_rows,
+)
 from common.testing.performance_benchmarks.benchmark_keys import BenchmarkKeys
+from common.testing.test_utils import get_queries
 
+from ..common.test_utils import get_table_external_location
 from ..integration_tests.analyze_tables import check_tables_analyzed
 from .metrics_collector import collect_metrics
 from .profiler_utils import start_profiler, stop_profiler
@@ -107,7 +118,9 @@ def benchmark_query_function(query_id):
             )
             raw_times_dict[query_id] = result
         except Exception as e:
-            failed_queries_dict[query_id] = f"{e.error_type}: {e.error_name}"
+            error_desc = getattr(e, "error_type", type(e).__name__)
+            error_name = getattr(e, "error_name", str(e))
+            failed_queries_dict[query_id] = f"{error_desc}: {error_name}"
             raw_times_dict[query_id] = None
             raise
         finally:
@@ -115,3 +128,181 @@ def benchmark_query_function(query_id):
         stop_profiler(profile_script_path, profile_output_file_path)
 
     return benchmark_query_function
+
+
+def _derive_expected_results_dir(hostname, port, user, schema):
+    """Derive the expected results directory from the table schema.
+
+    Queries the schema to find a table's external location on the host
+    (e.g. $PRESTO_DATA_DIR/tpchsf100/lineitem), goes up one level to
+    get the data root, and appends '_expected'.
+    """
+    conn = prestodb.dbapi.connect(host=hostname, port=port, user=user, catalog="hive", schema=schema)
+    cursor = conn.cursor()
+    try:
+        table = cursor.execute(f"SHOW TABLES IN {schema}").fetchone()[0]
+        table_location = get_table_external_location(schema, table, cursor)
+        data_root = os.path.dirname(table_location)
+        return f"{data_root}_expected"
+    except Exception as e:
+        print(f"[Validation] Could not derive expected results directory from schema: {e}")
+        return None
+    finally:
+        cursor.close()
+        conn.close()
+
+
+def _classify_limit_query(query_sql):
+    """Classify a query's validation strategy based on its LIMIT and ORDER BY.
+
+    Returns:
+        "full" - no LIMIT, compare all columns
+        "orderby_only" - LIMIT with deterministic ORDER BY (raw columns, COUNT, etc.),
+            compare only ORDER BY columns
+        "skip" - LIMIT with non-deterministic ORDER BY (SUM/AVG float aggregates),
+            skip validation because distributed floating-point aggregation
+            can change the ranking and thus which rows appear in the result set
+    """
+    try:
+        expr = sqlglot.parse_one(query_sql)
+    except sqlglot.errors.ParseError:
+        return "full"
+
+    has_limit = any(isinstance(e, sqlglot.exp.Limit) for e in expr.iter_expressions())
+    if not has_limit:
+        return "full"
+
+    order = next((e for e in expr.find_all(sqlglot.exp.Order)), None)
+    if not order:
+        return "full"
+
+    order_names = set()
+    for ordered in order.expressions:
+        key = ordered.this
+        if isinstance(key, sqlglot.exp.Column):
+            order_names.add(key.name)
+
+    select = expr.find(sqlglot.exp.Select)
+    float_aggs = (sqlglot.exp.Sum, sqlglot.exp.Avg)
+    for s in select.expressions:
+        alias = s.alias if hasattr(s, "alias") else None
+        if alias and alias in order_names:
+            # walk() yields bare nodes in newer sqlglot and (node, parent, key)
+            # tuples in older versions; handle both.
+            for node_tuple in s.walk():
+                node = node_tuple[0] if isinstance(node_tuple, tuple) else node_tuple
+                if isinstance(node, float_aggs):
+                    return "skip"
+
+    return "orderby_only"
+
+
+def _load_query_map(benchmark_types):
+    """Build a {lowercase_query_name: sql} map from the benchmark query JSON files."""
+    query_map = {}
+    for bench_type in benchmark_types:
+        try:
+            queries = get_queries(bench_type)
+            for key, sql in queries.items():
+                query_map[key.lower()] = sql
+        except (FileNotFoundError, OSError):
+            pass
+    return query_map
+
+
+def validate_benchmark_results(config, benchmark_types):
+    """Validate benchmark query results against expected parquet files.
+
+    Called from pytest_terminal_summary so output appears after the benchmark summary.
+    For queries with LIMIT, only ORDER BY columns are compared since other columns
+    can be non-deterministic at the LIMIT boundary.
+    """
+    expected_results_dir = config.getoption("--expected-results-dir")
+    if expected_results_dir is None:
+        hostname = config.getoption("--hostname")
+        port = config.getoption("--port")
+        user = config.getoption("--user")
+        schema = config.getoption("--schema-name")
+        expected_results_dir = _derive_expected_results_dir(hostname, port, user, schema)
+
+    if expected_results_dir is None:
+        print("[Validation] Skipping result validation (could not determine expected results directory).")
+        return
+
+    expected_dir = Path(expected_results_dir)
+    if not expected_dir.is_dir():
+        print(f"[Validation] Skipping result validation (expected results directory '{expected_dir}' not found).")
+        return
+
+    expected_files = sorted(expected_dir.glob("*.parquet"))
+    if not expected_files:
+        print(f"[Validation] Skipping result validation (no parquet files in '{expected_dir}').")
+        return
+
+    output_dir = config.getoption("--output-dir")
+    actual_results_dir = Path(output_dir) / "query_results"
+    if not actual_results_dir.is_dir():
+        print(f"[Validation] Skipping result validation (no query results directory at '{actual_results_dir}').")
+        return
+
+    query_map = _load_query_map(benchmark_types)
+
+    passed_queries = []
+    skipped_queries = []
+    failures = []
+    for expected_file in expected_files:
+        query_name = expected_file.stem
+        actual_file = actual_results_dir / expected_file.name
+        if not actual_file.exists():
+            continue
+
+        query_sql = query_map.get(query_name)
+        strategy = _classify_limit_query(query_sql) if query_sql else "full"
+
+        # Queries whose ORDER BY involves float aggregates (e.g. SUM, AVG) are
+        # non-deterministic under distributed execution: the partial-aggregate
+        # reduction order can change the ranking, so different rows appear in
+        # the LIMIT result set across runs.
+        if strategy == "skip":
+            skipped_queries.append(query_name)
+            continue
+
+        try:
+            expected_rel = duckdb.from_parquet(str(expected_file))
+            actual_rel = duckdb.from_parquet(str(actual_file))
+            types = expected_rel.types
+            columns = expected_rel.columns
+
+            expected_rows = expected_rel.fetchall()
+            actual_rows = actual_rel.fetchall()
+
+            # For LIMIT queries with deterministic ORDER BY (raw columns,
+            # COUNT, etc.), only compare the ORDER BY columns — non-ORDER BY
+            # columns can differ at the boundary when there are ties.
+            if strategy == "orderby_only":
+                order_indices = get_orderby_indices(query_sql, columns)
+                if order_indices:
+                    types = [types[i] for i in order_indices]
+                    expected_rows = [tuple(row[i] for i in order_indices) for row in expected_rows]
+                    actual_rows = [tuple(row[i] for i in order_indices) for row in actual_rows]
+
+            expected_rows = sorted(normalize_rows(expected_rows, types), key=none_safe_sort_key)
+            actual_rows = sorted(normalize_rows(actual_rows, types), key=none_safe_sort_key)
+
+            assert len(actual_rows) == len(expected_rows), (
+                f"Row count mismatch: {len(actual_rows)} vs {len(expected_rows)}"
+            )
+            assert_rows_equal(actual_rows, expected_rows, types)
+            passed_queries.append(query_name)
+        except AssertionError as e:
+            failures.append(f"[Validation] FAILED: {query_name} - {e}")
+        except Exception as e:
+            failures.append(f"[Validation] ERROR: {query_name} - {e}")
+
+    for line in failures:
+        print(line)
+    total = len(passed_queries) + len(failures)
+    passed_list = ", ".join(passed_queries)
+    skipped_list = ", ".join(skipped_queries)
+    parts = [f"{len(passed_queries)}/{total} passed ({passed_list})"]
+    if skipped_queries:
+        parts.append(f"{len(skipped_queries)} skipped non-deterministic ({skipped_list})")
+    print(f"[Validation] {'; '.join(parts)}")
diff --git a/presto/testing/performance_benchmarks/conftest.py b/presto/testing/performance_benchmarks/conftest.py
index 2636fc99..4b643e19 100644
--- a/presto/testing/performance_benchmarks/conftest.py
+++ b/presto/testing/performance_benchmarks/conftest.py
@@ -15,7 +15,7 @@
 from common.testing.performance_benchmarks.conftest import (
     DataLocation,
     pytest_sessionfinish,  # noqa: F401
-    pytest_terminal_summary,  # noqa: F401
+    pytest_terminal_summary as _common_pytest_terminal_summary,
 )
 
 from ..common.fixtures import (
@@ -25,6 +25,7 @@
 from .common_fixtures import (
     benchmark_query,  # noqa: F401
     presto_cursor,  # noqa: F401
+    validate_benchmark_results,
 )
 
 
@@ -42,6 +43,15 @@ def pytest_addoption(parser):
     parser.addoption("--profile-script-path")
     parser.addoption("--metrics", action="store_true", default=False)
     parser.addoption("--skip-drop-cache", action="store_true", default=False)
+    parser.addoption("--expected-results-dir", default=None)
+
+
+def pytest_terminal_summary(terminalreporter, exitstatus, config):
+    _common_pytest_terminal_summary(terminalreporter, exitstatus, config)
+    benchmark_types = []
+    if hasattr(terminalreporter._session, "benchmark_results"):
+        benchmark_types = list(terminalreporter._session.benchmark_results.keys())
+    validate_benchmark_results(config, benchmark_types)
 
 
 def pytest_configure(config):