diff --git a/accordo/README.md b/accordo/README.md index 58ab1f1..82eca4e 100644 --- a/accordo/README.md +++ b/accordo/README.md @@ -31,8 +31,8 @@ validator = Accordo(binary="./app_ref", kernel_name="reduce_sum") ref = validator.capture_snapshot(binary="./app_ref") opt = validator.capture_snapshot(binary="./app_opt") -# Compare with specified tolerance -result = validator.compare_snapshots(ref, opt, tolerance=1e-6) +# Compare with allclose-style controls +result = validator.compare_snapshots(ref, opt, atol=1e-6, rtol=1e-5, equal_nan=False) if result.is_valid: print(f"✓ PASS: {result.num_arrays_validated} arrays matched") @@ -48,7 +48,7 @@ ref = validator.capture_snapshot(binary="./ref") for opt_binary in ["./opt_v1", "./opt_v2", "./opt_v3"]: opt = validator.capture_snapshot(binary=opt_binary) - result = validator.compare_snapshots(ref, opt, tolerance=1e-6) + result = validator.compare_snapshots(ref, opt, atol=1e-6, rtol=1e-5, equal_nan=False) print(f"{opt_binary}: {'✓ PASS' if result.is_valid else '✗ FAIL'}") ``` @@ -84,12 +84,13 @@ The CLI passes each flag as a **single executable path** (no embedded spaces or **Methods:** - `capture_snapshot(binary, timeout_seconds=30)` → `Snapshot` -- `compare_snapshots(reference, optimized, tolerance=1e-6)` → `ValidationResult` +- `compare_snapshots(reference, optimized, tolerance=1e-6, *, atol=None, rtol=0.0, equal_nan=False)` → `ValidationResult` ### `Snapshot` **Attributes:** - `arrays` (list[np.ndarray]): Captured output arrays +- `dispatch_arrays` (list[list[np.ndarray]] | None): Captured outputs for each dispatch - `execution_time_ms` (float): Execution time - `grid_size`, `block_size` (dict | None): Kernel dimensions diff --git a/accordo/accordo/__init__.py b/accordo/accordo/__init__.py index 829655d..f700062 100644 --- a/accordo/accordo/__init__.py +++ b/accordo/accordo/__init__.py @@ -16,8 +16,8 @@ >>> ref = validator.capture_snapshot(binary="./app_ref") >>> opt = validator.capture_snapshot(binary="./app_opt") >>> - >>> # Compare with
specified tolerance - >>> result = validator.compare_snapshots(ref, opt, tolerance=1e-6) + >>> # Compare with configurable allclose-style tolerances + >>> result = validator.compare_snapshots(ref, opt, atol=1e-6, rtol=1e-5, equal_nan=False) >>> print(f"Valid: {result.is_valid}") Efficient Example (multiple comparisons): @@ -30,7 +30,7 @@ >>> # Compare against multiple optimizations >>> for opt_bin in ["./opt1", "./opt2", "./opt3"]: ... opt = validator.capture_snapshot(binary=opt_bin) - ... result = validator.compare_snapshots(ref, opt, tolerance=1e-6) + ... result = validator.compare_snapshots(ref, opt, atol=1e-6, rtol=1e-5) ... print(f"{opt_bin}: {'PASS' if result.is_valid else 'FAIL'}") Multiple Kernels: diff --git a/accordo/accordo/_internal/ipc/communication.py b/accordo/accordo/_internal/ipc/communication.py index 9bc5fd7..93d8fd0 100644 --- a/accordo/accordo/_internal/ipc/communication.py +++ b/accordo/accordo/_internal/ipc/communication.py @@ -103,6 +103,30 @@ def read_ipc_handles(args, ipc_file_name, sentinel_file=None): return handles, sizes +def _read_ipc_records(ipc_file_name): + """Read all IPC records in-order from file as (handle_np, size) tuples.""" + if not os.path.exists(ipc_file_name): + return [] + + with open(ipc_file_name, "rb") as file: + data = file.read() + + records = [] + messages = data.split(b"BEGIN\n") + for message in messages: + if b"END\n" not in message: + continue + content = message.split(b"END\n")[0] + if len(content) != 72: + continue + handle_data = content[:64] + size_data = content[64:72] + handle_np = np.frombuffer(handle_data, dtype=np.uint8) + size_value = int.from_bytes(size_data, byteorder="little") + records.append((handle_np, size_value)) + return records + + def send_response(pipe_name): """Send completion response through named pipe.""" with open(pipe_name, "w") as fifo: @@ -123,7 +147,7 @@ def get_kern_arg_data( baseline_time_ms: Baseline execution time (for dynamic timeout) Returns: - List of NumPy arrays with 
argument data + List of dispatch captures. Each dispatch is a list of NumPy arrays. Raises: TimeoutError: If IPC operation times out @@ -254,27 +278,10 @@ def _open_pipe(): f"Timeout after {ipc_timeout_seconds} seconds during IPC communication" ) - # Pipe connected and IPC file ready; read it - while True: - if process_pid is not None and not _process_is_alive(process_pid): - raise RuntimeError( - f"Accordo process (PID {process_pid}) crashed or terminated during execution. " - "Check for segfaults or GPU memory access errors." - ) - if time.time() - start_time > ipc_timeout_seconds: - raise TimeoutError( - f"Timeout after {ipc_timeout_seconds} seconds during IPC communication" - ) - try: - ipc_handles, ptr_sizes = read_ipc_handles( - args, ipc_file_name, sentinel_file=sentinel_file - ) - break - except AccordoKernelNeverDispatched: - raise - except Exception: - time.sleep(0.1) + # Pipe connected and IPC file ready; dispatch batches are handled below. finally: + # Close initial readiness reader. A dedicated keepalive reader is created + # during dispatch processing so subsequent dispatch writes never block. 
if pipe_fd is not None: try: os.close(pipe_fd) @@ -290,37 +297,86 @@ def _open_pipe(): "__hip_bfloat16*": ml_dtypes.bfloat16, } - results = [] pointer_args = list(filter(lambda arg: "*" in arg and "const" not in arg, args)) logging.debug(f"pointer_args: {pointer_args}") + output_arg_count = len(pointer_args) + if output_arg_count == 0: + return [[]] + + processed_records = 0 + dispatch_results = [] + processing_start = time.time() + keepalive_fd = os.open(pipe_name, os.O_RDONLY | os.O_NONBLOCK) + try: + while True: + if time.time() - processing_start > ipc_timeout_seconds: + raise TimeoutError( + f"Timeout after {ipc_timeout_seconds} seconds during IPC communication" + ) - for handle, arg, array_size in zip(ipc_handles, pointer_args, ptr_sizes): - ptr = open_ipc_handle(handle) - logging.debug(f"Opened IPC Ptr: {ptr} (0x{ptr:x})") - - # Strip type qualifiers (restrict, const, volatile) and type specifiers (struct, union, class, enum) - words_to_strip = ("restrict", "const", "volatile", "struct", "union", "class", "enum") - arg_type = " ".join(word for word in arg.split() if word not in words_to_strip) - logging.debug(f"arg_type (after stripping qualifiers and specifiers): {arg_type}") - - if arg_type in type_map: - dtype = type_map[arg_type] - logging.debug(f"dtype: {dtype}") - - # Special handling for FP16 and bfloat16 - if arg_type == "__half*": - temp_array = memcpy_d2h(ptr, array_size // 2, ctypes.c_uint16) - host_array = np.frombuffer(temp_array, dtype=np.float16) - elif arg_type == "__hip_bfloat16*": - temp_array = memcpy_d2h(ptr, array_size // 2, ctypes.c_uint16) - host_array = np.frombuffer(temp_array, dtype=ml_dtypes.bfloat16) - else: - num_elements = array_size // ctypes.sizeof(dtype) - host_array = memcpy_d2h(ptr, num_elements, dtype) - else: - raise TypeError(f"Unsupported pointer type: {arg_type}") - - logging.debug(f"Received data from IPC ({arg_type}/{len(host_array)}): {host_array}") - results.append(host_array) - - return results + records = 
_read_ipc_records(ipc_file_name) + while len(records) - processed_records >= output_arg_count: + batch = records[processed_records : processed_records + output_arg_count] + processed_records += output_arg_count + + dispatch_arrays = [] + for (handle, array_size), arg in zip(batch, pointer_args): + ptr = open_ipc_handle(handle) + logging.debug(f"Opened IPC Ptr: {ptr} (0x{ptr:x})") + + words_to_strip = ( + "restrict", + "const", + "volatile", + "struct", + "union", + "class", + "enum", + ) + arg_type = " ".join(word for word in arg.split() if word not in words_to_strip) + logging.debug( + f"arg_type (after stripping qualifiers and specifiers): {arg_type}" + ) + + if arg_type not in type_map: + raise TypeError(f"Unsupported pointer type: {arg_type}") + + dtype = type_map[arg_type] + logging.debug(f"dtype: {dtype}") + + if arg_type == "__half*": + temp_array = memcpy_d2h(ptr, array_size // 2, ctypes.c_uint16) + host_array = np.frombuffer(temp_array, dtype=np.float16) + elif arg_type == "__hip_bfloat16*": + temp_array = memcpy_d2h(ptr, array_size // 2, ctypes.c_uint16) + host_array = np.frombuffer(temp_array, dtype=ml_dtypes.bfloat16) + else: + num_elements = array_size // ctypes.sizeof(dtype) + host_array = memcpy_d2h(ptr, num_elements, dtype) + + logging.debug( + f"Received data from IPC ({arg_type}/{len(host_array)}): {host_array}" + ) + dispatch_arrays.append(host_array) + + dispatch_results.append(dispatch_arrays) + send_response(pipe_name) + + if process_pid is not None and not _process_is_alive(process_pid): + if len(records) - processed_records > 0: + raise RuntimeError( + "Accordo process exited with incomplete IPC batch. " + "This indicates dispatch/IPC synchronization failure." + ) + if dispatch_results: + return dispatch_results + raise RuntimeError( + f"Accordo process (PID {process_pid}) terminated before producing IPC data." 
+ ) + + time.sleep(0.05) + finally: + try: + os.close(keepalive_fd) + except OSError: + pass diff --git a/accordo/accordo/mcp/server.py b/accordo/accordo/mcp/server.py index 278a170..5d3fd1d 100644 --- a/accordo/accordo/mcp/server.py +++ b/accordo/accordo/mcp/server.py @@ -4,6 +4,8 @@ """MCP Server for Accordo - Automated Kernel Validation.""" +from typing import Optional + from fastmcp import FastMCP from accordo import Accordo @@ -16,6 +18,9 @@ def run_validate_kernel_correctness( reference_command: list[str], optimized_command: list[str], tolerance: float = 1e-6, + atol: Optional[float] = None, + rtol: float = 0.0, + equal_nan: bool = False, working_directory: str = ".", ) -> dict: """Run kernel correctness validation. Call this from Python; MCP tool wraps it.""" @@ -28,7 +33,14 @@ def run_validate_kernel_correctness( ref_snapshot = validator.capture_snapshot(binary=reference_command) opt_snapshot = validator.capture_snapshot(binary=optimized_command) - result = validator.compare_snapshots(ref_snapshot, opt_snapshot, tolerance=tolerance) + result = validator.compare_snapshots( + ref_snapshot, + opt_snapshot, + tolerance=tolerance, + atol=atol, + rtol=rtol, + equal_nan=equal_nan, + ) return { "is_valid": result.is_valid, @@ -43,6 +55,9 @@ def validate_kernel_correctness( reference_command: list[str], optimized_command: list[str], tolerance: float = 1e-6, + atol: Optional[float] = None, + rtol: float = 0.0, + equal_nan: bool = False, working_directory: str = ".", ) -> dict: """ @@ -56,6 +71,9 @@ def validate_kernel_correctness( reference_command: Command for reference version as list (e.g., ['./ref']) optimized_command: Command for optimized version as list (e.g., ['./opt']) tolerance: Numerical tolerance for comparisons (default: 1e-6) + atol: Absolute tolerance (overrides tolerance if provided) + rtol: Relative tolerance for comparisons (default: 0.0) + equal_nan: Whether NaN values should compare equal (default: False) working_directory: Working directory for 
commands (default: '.') Returns: @@ -66,6 +84,9 @@ def validate_kernel_correctness( reference_command=reference_command, optimized_command=optimized_command, tolerance=tolerance, + atol=atol, + rtol=rtol, + equal_nan=equal_nan, working_directory=working_directory, ) diff --git a/accordo/accordo/result.py b/accordo/accordo/result.py index 9fbf6d2..83729bb 100644 --- a/accordo/accordo/result.py +++ b/accordo/accordo/result.py @@ -30,11 +30,15 @@ class ArrayMismatch: mean_difference: float reference_sample: np.ndarray optimized_sample: np.ndarray + dispatch_index: Optional[int] = None def __str__(self) -> str: """Human-readable string representation.""" + dispatch_prefix = ( + f"dispatch {self.dispatch_index}: " if self.dispatch_index is not None else "" + ) return ( - f"Mismatch in arg '{self.arg_name}' ({self.arg_type}): " + f"Mismatch in {dispatch_prefix}arg '{self.arg_name}' ({self.arg_type}): " f"max_diff={self.max_difference:.2e}, mean_diff={self.mean_difference:.2e}" ) diff --git a/accordo/accordo/snapshot.py b/accordo/accordo/snapshot.py index 59ef1f9..530238d 100644 --- a/accordo/accordo/snapshot.py +++ b/accordo/accordo/snapshot.py @@ -14,12 +14,15 @@ class Snapshot: """Represents a captured snapshot of kernel argument data. Attributes: - arrays: List of numpy arrays containing kernel argument data + arrays: Output arrays from the first kernel dispatch (for backward compatibility). + Use dispatch_arrays for per-dispatch access when multiple dispatches are captured. execution_time_ms: Time taken to execute and capture the snapshot (milliseconds) binary: The binary command that was executed working_directory: The directory where the binary was executed grid_size: Optional grid dimensions dict with x,y,z (if available) block_size: Optional workgroup dimensions dict with x,y,z (if available) + dispatch_arrays: Optional list of captured dispatch arrays. Each dispatch + is a list of output arrays in kernel argument order. 
Example: @@ -40,6 +43,7 @@ class Snapshot: working_directory: str grid_size: Optional[dict] = None block_size: Optional[dict] = None + dispatch_arrays: Optional[List[List[np.ndarray]]] = None def __repr__(self) -> str: """Pretty representation of snapshot.""" @@ -60,6 +64,8 @@ def summary(self) -> str: f" Execution Time: {self.execution_time_ms:.2f}ms", f" Number of Arrays: {len(self.arrays)}", ] + if self.dispatch_arrays is not None: + lines.append(f" Number of Dispatches: {len(self.dispatch_arrays)}") if self.grid_size is not None: lines.append( diff --git a/accordo/accordo/validator.py b/accordo/accordo/validator.py index 09172ff..c37c7a3 100644 --- a/accordo/accordo/validator.py +++ b/accordo/accordo/validator.py @@ -15,7 +15,7 @@ import numpy as np from ._internal.codegen import generate_kernel_metadata -from ._internal.ipc.communication import get_kern_arg_data, send_response +from ._internal.ipc.communication import get_kern_arg_data from .exceptions import AccordoBuildError, AccordoProcessError, AccordoTimeoutError from .kernel_args import extract_kernel_arguments from .result import ArrayMismatch, ValidationResult @@ -79,18 +79,22 @@ def _build_accordo(accordo_path: Path, parallel_jobs: int = 16) -> Path: raise AccordoBuildError(f"Accordo build failed: {str(e)}") -def _validate_arrays(arr1: np.ndarray, arr2: np.ndarray, tolerance: float) -> bool: +def _validate_arrays( + arr1: np.ndarray, arr2: np.ndarray, atol: float, rtol: float, equal_nan: bool +) -> bool: """Validate two arrays are close within tolerance. 
Args: arr1: First array arr2: Second array - tolerance: Absolute tolerance + atol: Absolute tolerance + rtol: Relative tolerance + equal_nan: Whether NaN values compare equal Returns: - True if arrays match within tolerance + True if arrays match within tolerances """ - return np.allclose(arr1, arr2, atol=tolerance, rtol=0) + return np.allclose(arr1, arr2, atol=atol, rtol=rtol, equal_nan=equal_nan) class Accordo: @@ -280,12 +284,13 @@ def capture_snapshot( pass return Snapshot( - arrays=result_arrays, + arrays=result_arrays[0] if result_arrays else [], execution_time_ms=execution_time_ms, binary=binary, working_directory=self.working_directory, grid_size=grid, block_size=block, + dispatch_arrays=result_arrays, ) except _TimeoutException: signal.alarm(0) @@ -308,26 +313,44 @@ def compare_snapshots( self, reference_snapshot: Snapshot, optimized_snapshot: Snapshot, - tolerance: float = 1e-6, + tolerance: Optional[float] = 1e-6, + *, + atol: Optional[float] = None, + rtol: float = 0.0, + equal_nan: bool = False, ) -> ValidationResult: """Compare two snapshots and validate their arrays. Args: reference_snapshot: Snapshot from reference binary optimized_snapshot: Snapshot from optimized binary - tolerance: Absolute tolerance for array comparison + tolerance: Legacy absolute tolerance (backward-compatible alias for atol) + atol: Absolute tolerance for array comparison + rtol: Relative tolerance for array comparison + equal_nan: If True, NaN values compare equal (torch.isclose-like behavior) Returns: ValidationResult with validation status and details Example: - >>> result = validator.compare_snapshots(ref, opt, tolerance=1e-4) + >>> result = validator.compare_snapshots(ref, opt, tolerance=1e-4, rtol=1e-5, equal_nan=False) >>> if result.is_valid: ... 
print(f"✓ PASS: {result.num_arrays_validated} arrays matched") """ + effective_atol = atol if atol is not None else (1e-6 if tolerance is None else tolerance) + reference_dispatches = ( + reference_snapshot.dispatch_arrays + if reference_snapshot.dispatch_arrays is not None + else [reference_snapshot.arrays] + ) + optimized_dispatches = ( + optimized_snapshot.dispatch_arrays + if optimized_snapshot.dispatch_arrays is not None + else [optimized_snapshot.arrays] + ) results = { - "reference": reference_snapshot.arrays, - "optimized": optimized_snapshot.arrays, + "reference": reference_dispatches, + "optimized": optimized_dispatches, } execution_times = { "reference": reference_snapshot.execution_time_ms, @@ -337,7 +360,9 @@ def compare_snapshots( return self._validate_results( results=results, execution_times=execution_times, - tolerance=tolerance, + atol=effective_atol, + rtol=rtol, + equal_nan=equal_nan, ) def _run_instrumented_app( @@ -346,7 +371,7 @@ def _run_instrumented_app( label: str, extra_env: Optional[dict] = None, timeout_seconds: int = 30, - ) -> List[np.ndarray]: + ) -> List[List[np.ndarray]]: """Run an instrumented application and collect kernel argument data. Args: @@ -356,7 +381,7 @@ def _run_instrumented_app( timeout_seconds: Timeout for IPC wait (used by get_kern_arg_data) Returns: - List of numpy arrays with kernel argument data + List of dispatch captures. Each dispatch is a list of output arrays. """ timestamp = int(time.time() * 1000) pipe_name = f"/tmp/kernel_pipe_{timestamp}_{label}" @@ -422,34 +447,38 @@ def _run_instrumented_app( pass raise - # Send completion response - send_response(pipe_name) - return result_arrays def _validate_results( self, results: dict, execution_times: dict, - tolerance: float, + atol: float, + rtol: float, + equal_nan: bool, ) -> ValidationResult: """Validate results from reference and optimized runs. 
Args: - results: Dictionary with "reference" and "optimized" array lists + results: Dictionary with "reference" and "optimized" dispatch lists execution_times: Execution times for each run - tolerance: Absolute tolerance for comparison + atol: Absolute tolerance for comparison + rtol: Relative tolerance for comparison + equal_nan: Whether NaN values compare equal Returns: ValidationResult with validation status """ - reference_arrays = results["reference"] - optimized_arrays = results["optimized"] + reference_dispatches = results["reference"] + optimized_dispatches = results["optimized"] - if len(reference_arrays) != len(optimized_arrays): + if len(reference_dispatches) != len(optimized_dispatches): return ValidationResult( is_valid=False, - error_message=f"Array count mismatch: {len(reference_arrays)} vs {len(optimized_arrays)}", + error_message=( + "Dispatch count mismatch: " + f"{len(reference_dispatches)} vs {len(optimized_dispatches)}" + ), execution_time_ms=execution_times, ) @@ -464,40 +493,65 @@ def _validate_results( if "*" in arg_type and "const" not in arg_type ] - for array_idx, (ref_arr, opt_arr) in enumerate(zip(reference_arrays, optimized_arrays)): - # Map array index to the correct kernel argument index - kernel_arg_idx = output_arg_indices[array_idx] - arg_name, arg_type = self.kernel_args[kernel_arg_idx] - - if not _validate_arrays(ref_arr, opt_arr, tolerance): - # Array mismatch - diff = np.abs(ref_arr - opt_arr) - mismatch = ArrayMismatch( - arg_index=kernel_arg_idx, # Use kernel arg index, not array index - arg_name=arg_name, - arg_type=arg_type, - max_difference=float(np.max(diff)), - mean_difference=float(np.mean(diff)), - reference_sample=ref_arr[:10] if len(ref_arr) > 10 else ref_arr, - optimized_sample=opt_arr[:10] if len(opt_arr) > 10 else opt_arr, - ) - mismatches.append(mismatch) - - logging.debug( - f"Output array {array_idx} (kernel arg {kernel_arg_idx} '{arg_name}' {arg_type}) - NOT close" - ) - logging.debug(f" Max difference: 
{mismatch.max_difference}") - logging.debug(f" Mean difference: {mismatch.mean_difference}") - else: - # Array matched - matched_arrays[arg_name] = { - "index": kernel_arg_idx, # Use kernel arg index - "type": arg_type, - "size": len(ref_arr), - } - logging.debug( - f"Output array {array_idx} (kernel arg {kernel_arg_idx} '{arg_name}' {arg_type}) - MATCH" + for dispatch_idx, (reference_arrays, optimized_arrays) in enumerate( + zip(reference_dispatches, optimized_dispatches) + ): + if len(reference_arrays) != len(optimized_arrays): + return ValidationResult( + is_valid=False, + error_message=( + f"Array count mismatch at dispatch {dispatch_idx}: " + f"{len(reference_arrays)} vs {len(optimized_arrays)}" + ), + execution_time_ms=execution_times, ) + for array_idx, (ref_arr, opt_arr) in enumerate(zip(reference_arrays, optimized_arrays)): + # Map array index to the correct kernel argument index + kernel_arg_idx = output_arg_indices[array_idx] + arg_name, arg_type = self.kernel_args[kernel_arg_idx] + matched_key = f"dispatch_{dispatch_idx}:{arg_name}" + + if not _validate_arrays(ref_arr, opt_arr, atol, rtol, equal_nan): + # Array mismatch + diff = np.abs(ref_arr - opt_arr) + finite_diff = diff[~np.isnan(diff)] + if finite_diff.size > 0: + max_diff = float(np.max(finite_diff)) + mean_diff = float(np.mean(finite_diff)) + else: + max_diff = 0.0 + mean_diff = 0.0 + mismatch = ArrayMismatch( + arg_index=kernel_arg_idx, # Use kernel arg index, not array index + arg_name=arg_name, + arg_type=arg_type, + max_difference=max_diff, + mean_difference=mean_diff, + reference_sample=ref_arr[:10] if len(ref_arr) > 10 else ref_arr, + optimized_sample=opt_arr[:10] if len(opt_arr) > 10 else opt_arr, + dispatch_index=dispatch_idx, + ) + mismatches.append(mismatch) + + logging.debug( + f"Dispatch {dispatch_idx} output array {array_idx} " + f"(kernel arg {kernel_arg_idx} '{arg_name}' {arg_type}) - NOT close" + ) + logging.debug(f" Max difference: {mismatch.max_difference}") + logging.debug(f" 
Mean difference: {mismatch.mean_difference}") + else: + # Array matched + matched_arrays[matched_key] = { + "index": kernel_arg_idx, # Use kernel arg index + "type": arg_type, + "size": len(ref_arr), + "dispatch": dispatch_idx, + "arg_name": arg_name, + } + logging.debug( + f"Dispatch {dispatch_idx} output array {array_idx} " + f"(kernel arg {kernel_arg_idx} '{arg_name}' {arg_type}) - MATCH" + ) # Determine overall success is_valid = len(mismatches) == 0 diff --git a/accordo/tests/test_reduction_validation.py b/accordo/tests/test_reduction_validation.py index d935015..6d1cfc8 100644 --- a/accordo/tests/test_reduction_validation.py +++ b/accordo/tests/test_reduction_validation.py @@ -18,9 +18,10 @@ from textwrap import dedent from pathlib import Path +import numpy as np import pytest -from accordo import Accordo +from accordo import Accordo, Snapshot from accordo.exceptions import AccordoKernelNeverDispatched # ----------------------------------------------------------------------------- @@ -189,6 +190,49 @@ def _scale_values_source(dtype: str) -> str: } """ +MULTI_DISPATCH_KERNEL = """ +__global__ void scale_values_dispatch(const float* input, float* output, float factor, int N) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < N) output[idx] = input[idx] * factor; +} +""" + + +def _multi_dispatch_source(second_factor: str, second_launch: bool = True) -> str: + """Build a source that launches the same kernel one or two times.""" + second_launch_code = ( + f"hipLaunchKernelGGL(scale_values_dispatch, dim3((N + 255) / 256), dim3(256), 0, 0, d_in, d_out, {second_factor}, N);\n" + " hipDeviceSynchronize();\n" + if second_launch + else "" + ) + return f""" +#include +#include +#include +{MULTI_DISPATCH_KERNEL} + +int main() {{ + const int N = 256; + size_t bytes = N * sizeof(float); + float *d_in, *d_out; + hipMalloc(&d_in, bytes); + hipMalloc(&d_out, bytes); + float* h_in = (float*)malloc(bytes); + for (int i = 0; i < N; i++) h_in[i] = (float)(i + 1); + 
hipMemcpy(d_in, h_in, bytes, hipMemcpyHostToDevice); + + hipLaunchKernelGGL(scale_values_dispatch, dim3((N + 255) / 256), dim3(256), 0, 0, d_in, d_out, 2.0f, N); + hipDeviceSynchronize(); + {second_launch_code} + + free(h_in); + hipFree(d_in); + hipFree(d_out); + return 0; +}} +""" + def _compile_hip(kernel_code: str, name: str, tmp_dir: Path) -> Path: """Write HIP source, compile with hipcc, return path to binary.""" @@ -357,6 +401,156 @@ def test_validation_result_summary(): assert "pass" in pass_result.summary().lower() or "match" in pass_result.summary().lower() +def test_compare_snapshots_supports_rtol(): + """Relative tolerance should allow proportional differences when atol is zero.""" + with tempfile.TemporaryDirectory(prefix="accordo_test_") as tmp_dir: + tmp_path = Path(tmp_dir) + bin_path = _compile_hip(_reduce_source(REDUCE_KERNEL_BASELINE), "baseline", tmp_path) + validator = Accordo( + binary=str(bin_path), + kernel_name="reduce_sum", + working_directory=str(tmp_path), + ) + + ref_arr = np.array([1000.0], dtype=np.float32) + opt_arr = np.array([1000.5], dtype=np.float32) + ref = Snapshot( + arrays=[ref_arr], + dispatch_arrays=[[ref_arr]], + execution_time_ms=1.0, + binary=[str(bin_path)], + working_directory=str(tmp_path), + ) + opt = Snapshot( + arrays=[opt_arr], + dispatch_arrays=[[opt_arr]], + execution_time_ms=1.0, + binary=[str(bin_path)], + working_directory=str(tmp_path), + ) + + strict = validator.compare_snapshots(ref, opt, atol=0.0, rtol=0.0) + relaxed = validator.compare_snapshots(ref, opt, atol=0.0, rtol=1e-3) + + assert not strict.is_valid + assert relaxed.is_valid + + +def test_compare_snapshots_equal_nan_toggle(): + """NaN equality should honor equal_nan flag.""" + with tempfile.TemporaryDirectory(prefix="accordo_test_") as tmp_dir: + tmp_path = Path(tmp_dir) + bin_path = _compile_hip(_reduce_source(REDUCE_KERNEL_BASELINE), "baseline", tmp_path) + validator = Accordo( + binary=str(bin_path), + kernel_name="reduce_sum", + 
working_directory=str(tmp_path), + ) + + ref_arr = np.array([1.0, np.nan, 3.0], dtype=np.float32) + opt_arr = np.array([1.0, np.nan, 3.0], dtype=np.float32) + ref = Snapshot( + arrays=[ref_arr], + dispatch_arrays=[[ref_arr]], + execution_time_ms=1.0, + binary=[str(bin_path)], + working_directory=str(tmp_path), + ) + opt = Snapshot( + arrays=[opt_arr], + dispatch_arrays=[[opt_arr]], + execution_time_ms=1.0, + binary=[str(bin_path)], + working_directory=str(tmp_path), + ) + + strict = validator.compare_snapshots(ref, opt, equal_nan=False) + nan_equal = validator.compare_snapshots(ref, opt, equal_nan=True) + + assert not strict.is_valid + assert nan_equal.is_valid + + +def test_compare_snapshots_tolerance_backward_compatibility(): + """Legacy tolerance argument should continue to work as absolute tolerance.""" + with tempfile.TemporaryDirectory(prefix="accordo_test_") as tmp_dir: + tmp_path = Path(tmp_dir) + bin_path = _compile_hip(_reduce_source(REDUCE_KERNEL_BASELINE), "baseline", tmp_path) + validator = Accordo( + binary=str(bin_path), + kernel_name="reduce_sum", + working_directory=str(tmp_path), + ) + ref_arr = np.array([1.0], dtype=np.float32) + opt_arr = np.array([1.00009], dtype=np.float32) + ref = Snapshot( + arrays=[ref_arr], + dispatch_arrays=[[ref_arr]], + execution_time_ms=1.0, + binary=[str(bin_path)], + working_directory=str(tmp_path), + ) + opt = Snapshot( + arrays=[opt_arr], + dispatch_arrays=[[opt_arr]], + execution_time_ms=1.0, + binary=[str(bin_path)], + working_directory=str(tmp_path), + ) + result = validator.compare_snapshots(ref, opt, tolerance=1e-4) + assert result.is_valid + + +def test_multi_dispatch_second_dispatch_mismatch_detected(): + """Validation should compare all dispatches, not only the first one.""" + with tempfile.TemporaryDirectory(prefix="accordo_test_") as tmp_dir: + tmp_path = Path(tmp_dir) + ref_bin = _compile_hip( + _multi_dispatch_source("3.0f", second_launch=True), "ref_multi", tmp_path + ) + bad_bin = _compile_hip( + 
_multi_dispatch_source("4.0f", second_launch=True), "bad_multi", tmp_path + ) + validator = Accordo( + binary=str(ref_bin), + kernel_name="scale_values_dispatch", + working_directory=str(tmp_path), + ) + ref_snap = validator.capture_snapshot(binary=str(ref_bin), timeout_seconds=30) + bad_snap = validator.capture_snapshot(binary=str(bad_bin), timeout_seconds=30) + result = validator.compare_snapshots(ref_snap, bad_snap, tolerance=1e-6) + + assert len(ref_snap.dispatch_arrays or []) == 2 + assert len(bad_snap.dispatch_arrays or []) == 2 + assert not result.is_valid + assert any(m.dispatch_index == 1 for m in result.mismatches) + + +def test_multi_dispatch_count_mismatch_fails(): + """Validation should fail clearly when dispatch counts differ.""" + with tempfile.TemporaryDirectory(prefix="accordo_test_") as tmp_dir: + tmp_path = Path(tmp_dir) + ref_bin = _compile_hip( + _multi_dispatch_source("3.0f", second_launch=True), "ref_multi", tmp_path + ) + one_dispatch_bin = _compile_hip( + _multi_dispatch_source("3.0f", second_launch=False), "one_dispatch", tmp_path + ) + validator = Accordo( + binary=str(ref_bin), + kernel_name="scale_values_dispatch", + working_directory=str(tmp_path), + ) + ref_snap = validator.capture_snapshot(binary=str(ref_bin), timeout_seconds=30) + one_snap = validator.capture_snapshot(binary=str(one_dispatch_bin), timeout_seconds=30) + result = validator.compare_snapshots(ref_snap, one_snap, tolerance=1e-6) + + assert len(ref_snap.dispatch_arrays or []) == 2 + assert len(one_snap.dispatch_arrays or []) == 1 + assert not result.is_valid + assert result.error_message and "Dispatch count mismatch" in result.error_message + + # ----------------------------------------------------------------------------- # IPC robustness: kernel never dispatched, Python dies before "done" # -----------------------------------------------------------------------------