Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions accordo/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ validator = Accordo(binary="./app_ref", kernel_name="reduce_sum")
ref = validator.capture_snapshot(binary="./app_ref")
opt = validator.capture_snapshot(binary="./app_opt")

# Compare with specified tolerance
result = validator.compare_snapshots(ref, opt, tolerance=1e-6)
# Compare with allclose-style controls
result = validator.compare_snapshots(ref, opt, atol=1e-6, rtol=1e-5, equal_nan=False)

if result.is_valid:
print(f"✓ PASS: {result.num_arrays_validated} arrays matched")
Expand All @@ -48,7 +48,7 @@ ref = validator.capture_snapshot(binary="./ref")

for opt_binary in ["./opt_v1", "./opt_v2", "./opt_v3"]:
opt = validator.capture_snapshot(binary=opt_binary)
result = validator.compare_snapshots(ref, opt, tolerance=1e-6)
result = validator.compare_snapshots(ref, opt, atol=1e-6, rtol=1e-5, equal_nan=False)
print(f"{opt_binary}: {'✓ PASS' if result.is_valid else '✗ FAIL'}")
```

Expand Down Expand Up @@ -84,12 +84,13 @@ The CLI passes each flag as a **single executable path** (no embedded spaces or

**Methods:**
- `capture_snapshot(binary, timeout_seconds=30)` → `Snapshot`
- `compare_snapshots(reference, optimized, tolerance=1e-6)` → `ValidationResult`
- `compare_snapshots(reference, optimized, *, atol=1e-6, rtol=0.0, equal_nan=False)` → `ValidationResult`

### `Snapshot`

**Attributes:**
- `arrays` (list[np.ndarray]): Captured output arrays
- `dispatch_arrays` (list[list[np.ndarray]] | None): Captured outputs for each dispatch
- `execution_time_ms` (float): Execution time
- `grid_size`, `block_size` (dict | None): Kernel dimensions

Expand Down
6 changes: 3 additions & 3 deletions accordo/accordo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
>>> ref = validator.capture_snapshot(binary="./app_ref")
>>> opt = validator.capture_snapshot(binary="./app_opt")
>>>
>>> # Compare with specified tolerance
>>> result = validator.compare_snapshots(ref, opt, tolerance=1e-6)
>>> # Compare with configurable allclose-style tolerances
>>> result = validator.compare_snapshots(ref, opt, atol=1e-6, rtol=1e-5, equal_nan=False)
>>> print(f"Valid: {result.is_valid}")
Efficient Example (multiple comparisons):
Expand All @@ -30,7 +30,7 @@
>>> # Compare against multiple optimizations
>>> for opt_bin in ["./opt1", "./opt2", "./opt3"]:
... opt = validator.capture_snapshot(binary=opt_bin)
... result = validator.compare_snapshots(ref, opt, tolerance=1e-6)
... result = validator.compare_snapshots(ref, opt, atol=1e-6, rtol=1e-5)
... print(f"{opt_bin}: {'PASS' if result.is_valid else 'FAIL'}")
Multiple Kernels:
Expand Down
160 changes: 108 additions & 52 deletions accordo/accordo/_internal/ipc/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,30 @@ def read_ipc_handles(args, ipc_file_name, sentinel_file=None):
return handles, sizes


def _read_ipc_records(ipc_file_name):
    """Parse every complete IPC record from *ipc_file_name*, in file order.

    Each record on disk is framed as ``b"BEGIN\\n"`` + 64-byte handle +
    8-byte little-endian size + ``b"END\\n"``. Returns a list of
    ``(handle_np, size)`` tuples, where ``handle_np`` is a 64-element
    ``uint8`` array. Returns ``[]`` when the file does not exist yet
    (the writer may not have produced any records so far). Malformed or
    partially-written frames are silently skipped.
    """
    # No file yet means the producer has not emitted any records.
    if not os.path.exists(ipc_file_name):
        return []

    with open(ipc_file_name, "rb") as stream:
        raw = stream.read()

    parsed = []
    for chunk in raw.split(b"BEGIN\n"):
        # A chunk without the terminator is either leading junk or a
        # record still being written; ignore it.
        if b"END\n" not in chunk:
            continue
        payload, _, _ = chunk.partition(b"END\n")
        # Valid payload is exactly 64 handle bytes + 8 size bytes.
        if len(payload) != 72:
            continue
        handle = np.frombuffer(payload[:64], dtype=np.uint8)
        size = int.from_bytes(payload[64:72], byteorder="little")
        parsed.append((handle, size))
    return parsed


def send_response(pipe_name):
"""Send completion response through named pipe."""
with open(pipe_name, "w") as fifo:
Expand All @@ -123,7 +147,7 @@ def get_kern_arg_data(
baseline_time_ms: Baseline execution time (for dynamic timeout)

Returns:
List of NumPy arrays with argument data
List of dispatch captures. Each dispatch is a list of NumPy arrays.

Raises:
TimeoutError: If IPC operation times out
Expand Down Expand Up @@ -254,27 +278,10 @@ def _open_pipe():
f"Timeout after {ipc_timeout_seconds} seconds during IPC communication"
)

# Pipe connected and IPC file ready; read it
while True:
if process_pid is not None and not _process_is_alive(process_pid):
raise RuntimeError(
f"Accordo process (PID {process_pid}) crashed or terminated during execution. "
"Check for segfaults or GPU memory access errors."
)
if time.time() - start_time > ipc_timeout_seconds:
raise TimeoutError(
f"Timeout after {ipc_timeout_seconds} seconds during IPC communication"
)
try:
ipc_handles, ptr_sizes = read_ipc_handles(
args, ipc_file_name, sentinel_file=sentinel_file
)
break
except AccordoKernelNeverDispatched:
raise
except Exception:
time.sleep(0.1)
# Pipe connected and IPC file ready; dispatch batches are handled below.
finally:
# Close initial readiness reader. A dedicated keepalive reader is created
# during dispatch processing so subsequent dispatch writes never block.
if pipe_fd is not None:
try:
os.close(pipe_fd)
Expand All @@ -290,37 +297,86 @@ def _open_pipe():
"__hip_bfloat16*": ml_dtypes.bfloat16,
}

results = []
pointer_args = list(filter(lambda arg: "*" in arg and "const" not in arg, args))
logging.debug(f"pointer_args: {pointer_args}")
output_arg_count = len(pointer_args)
if output_arg_count == 0:
return [[]]

processed_records = 0
dispatch_results = []
processing_start = time.time()
keepalive_fd = os.open(pipe_name, os.O_RDONLY | os.O_NONBLOCK)
try:
while True:
if time.time() - processing_start > ipc_timeout_seconds:
raise TimeoutError(
f"Timeout after {ipc_timeout_seconds} seconds during IPC communication"
)

for handle, arg, array_size in zip(ipc_handles, pointer_args, ptr_sizes):
ptr = open_ipc_handle(handle)
logging.debug(f"Opened IPC Ptr: {ptr} (0x{ptr:x})")

# Strip type qualifiers (restrict, const, volatile) and type specifiers (struct, union, class, enum)
words_to_strip = ("restrict", "const", "volatile", "struct", "union", "class", "enum")
arg_type = " ".join(word for word in arg.split() if word not in words_to_strip)
logging.debug(f"arg_type (after stripping qualifiers and specifiers): {arg_type}")

if arg_type in type_map:
dtype = type_map[arg_type]
logging.debug(f"dtype: {dtype}")

# Special handling for FP16 and bfloat16
if arg_type == "__half*":
temp_array = memcpy_d2h(ptr, array_size // 2, ctypes.c_uint16)
host_array = np.frombuffer(temp_array, dtype=np.float16)
elif arg_type == "__hip_bfloat16*":
temp_array = memcpy_d2h(ptr, array_size // 2, ctypes.c_uint16)
host_array = np.frombuffer(temp_array, dtype=ml_dtypes.bfloat16)
else:
num_elements = array_size // ctypes.sizeof(dtype)
host_array = memcpy_d2h(ptr, num_elements, dtype)
else:
raise TypeError(f"Unsupported pointer type: {arg_type}")

logging.debug(f"Received data from IPC ({arg_type}/{len(host_array)}): {host_array}")
results.append(host_array)

return results
records = _read_ipc_records(ipc_file_name)
while len(records) - processed_records >= output_arg_count:
batch = records[processed_records : processed_records + output_arg_count]
processed_records += output_arg_count

dispatch_arrays = []
for (handle, array_size), arg in zip(batch, pointer_args):
ptr = open_ipc_handle(handle)
logging.debug(f"Opened IPC Ptr: {ptr} (0x{ptr:x})")

words_to_strip = (
"restrict",
"const",
"volatile",
"struct",
"union",
"class",
"enum",
)
arg_type = " ".join(word for word in arg.split() if word not in words_to_strip)
logging.debug(
f"arg_type (after stripping qualifiers and specifiers): {arg_type}"
)

if arg_type not in type_map:
raise TypeError(f"Unsupported pointer type: {arg_type}")

dtype = type_map[arg_type]
logging.debug(f"dtype: {dtype}")

if arg_type == "__half*":
temp_array = memcpy_d2h(ptr, array_size // 2, ctypes.c_uint16)
host_array = np.frombuffer(temp_array, dtype=np.float16)
elif arg_type == "__hip_bfloat16*":
temp_array = memcpy_d2h(ptr, array_size // 2, ctypes.c_uint16)
host_array = np.frombuffer(temp_array, dtype=ml_dtypes.bfloat16)
else:
num_elements = array_size // ctypes.sizeof(dtype)
host_array = memcpy_d2h(ptr, num_elements, dtype)

logging.debug(
f"Received data from IPC ({arg_type}/{len(host_array)}): {host_array}"
)
dispatch_arrays.append(host_array)

dispatch_results.append(dispatch_arrays)
send_response(pipe_name)

if process_pid is not None and not _process_is_alive(process_pid):
if len(records) - processed_records > 0:
raise RuntimeError(
"Accordo process exited with incomplete IPC batch. "
"This indicates dispatch/IPC synchronization failure."
)
if dispatch_results:
return dispatch_results
raise RuntimeError(
f"Accordo process (PID {process_pid}) terminated before producing IPC data."
)

time.sleep(0.05)
finally:
try:
os.close(keepalive_fd)
except OSError:
pass
23 changes: 22 additions & 1 deletion accordo/accordo/mcp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

"""MCP Server for Accordo - Automated Kernel Validation."""

from typing import Optional

from fastmcp import FastMCP

from accordo import Accordo
Expand All @@ -16,6 +18,9 @@ def run_validate_kernel_correctness(
reference_command: list[str],
optimized_command: list[str],
tolerance: float = 1e-6,
atol: Optional[float] = None,
rtol: float = 0.0,
equal_nan: bool = False,
working_directory: str = ".",
) -> dict:
"""Run kernel correctness validation. Call this from Python; MCP tool wraps it."""
Expand All @@ -28,7 +33,14 @@ def run_validate_kernel_correctness(

ref_snapshot = validator.capture_snapshot(binary=reference_command)
opt_snapshot = validator.capture_snapshot(binary=optimized_command)
result = validator.compare_snapshots(ref_snapshot, opt_snapshot, tolerance=tolerance)
result = validator.compare_snapshots(
ref_snapshot,
opt_snapshot,
tolerance=tolerance,
atol=atol,
rtol=rtol,
equal_nan=equal_nan,
)

return {
"is_valid": result.is_valid,
Expand All @@ -43,6 +55,9 @@ def validate_kernel_correctness(
reference_command: list[str],
optimized_command: list[str],
tolerance: float = 1e-6,
atol: Optional[float] = None,
rtol: float = 0.0,
equal_nan: bool = False,
working_directory: str = ".",
) -> dict:
"""
Expand All @@ -56,6 +71,9 @@ def validate_kernel_correctness(
reference_command: Command for reference version as list (e.g., ['./ref'])
optimized_command: Command for optimized version as list (e.g., ['./opt'])
tolerance: Numerical tolerance for comparisons (default: 1e-6)
atol: Absolute tolerance (overrides tolerance if provided)
rtol: Relative tolerance for comparisons (default: 0.0)
equal_nan: Whether NaN values should compare equal (default: False)
working_directory: Working directory for commands (default: '.')

Returns:
Expand All @@ -66,6 +84,9 @@ def validate_kernel_correctness(
reference_command=reference_command,
optimized_command=optimized_command,
tolerance=tolerance,
atol=atol,
rtol=rtol,
equal_nan=equal_nan,
working_directory=working_directory,
)

Expand Down
6 changes: 5 additions & 1 deletion accordo/accordo/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,15 @@ class ArrayMismatch:
mean_difference: float
reference_sample: np.ndarray
optimized_sample: np.ndarray
dispatch_index: Optional[int] = None

def __str__(self) -> str:
"""Human-readable string representation."""
dispatch_prefix = (
f"dispatch {self.dispatch_index}: " if self.dispatch_index is not None else ""
)
return (
f"Mismatch in arg '{self.arg_name}' ({self.arg_type}): "
f"Mismatch in {dispatch_prefix}arg '{self.arg_name}' ({self.arg_type}): "
f"max_diff={self.max_difference:.2e}, mean_diff={self.mean_difference:.2e}"
)

Expand Down
8 changes: 7 additions & 1 deletion accordo/accordo/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ class Snapshot:
"""Represents a captured snapshot of kernel argument data.

Attributes:
arrays: List of numpy arrays containing kernel argument data
arrays: Output arrays from the first kernel dispatch (for backward compatibility).
Use dispatch_arrays for per-dispatch access when multiple dispatches are captured.
execution_time_ms: Time taken to execute and capture the snapshot (milliseconds)
binary: The binary command that was executed
working_directory: The directory where the binary was executed
grid_size: Optional grid dimensions dict with x,y,z (if available)
block_size: Optional workgroup dimensions dict with x,y,z (if available)
dispatch_arrays: Optional list of captured dispatch arrays. Each dispatch
is a list of output arrays in kernel argument order.


Example:
Expand All @@ -40,6 +43,7 @@ class Snapshot:
working_directory: str
grid_size: Optional[dict] = None
block_size: Optional[dict] = None
dispatch_arrays: Optional[List[List[np.ndarray]]] = None

def __repr__(self) -> str:
"""Pretty representation of snapshot."""
Expand All @@ -60,6 +64,8 @@ def summary(self) -> str:
f" Execution Time: {self.execution_time_ms:.2f}ms",
f" Number of Arrays: {len(self.arrays)}",
]
if self.dispatch_arrays is not None:
lines.append(f" Number of Dispatches: {len(self.dispatch_arrays)}")

if self.grid_size is not None:
lines.append(
Expand Down
Loading
Loading