From ba04383786d8b536a974c34e30a49a664f41fd30 Mon Sep 17 00:00:00 2001 From: "Chuan(Richard) Li" Date: Thu, 26 Feb 2026 16:53:26 -0800 Subject: [PATCH 1/4] Add decode-phase Flash Attention kernel for LLM inference Implement a single-query-token Flash Attention kernel targeting the decode phase of autoregressive LLM inference. Uses online softmax with warp-level xor-shuffle reductions on AMD wave64. Includes correctness and performance tests against PyTorch SDPA reference. Made-with: Cursor --- kernels/flash_decode_attention.py | 284 +++++++++++++++++++ tests/kernels/test_flash_decode_attention.py | 173 +++++++++++ 2 files changed, 457 insertions(+) create mode 100644 kernels/flash_decode_attention.py create mode 100644 tests/kernels/test_flash_decode_attention.py diff --git a/kernels/flash_decode_attention.py b/kernels/flash_decode_attention.py new file mode 100644 index 00000000..8f9cc283 --- /dev/null +++ b/kernels/flash_decode_attention.py @@ -0,0 +1,284 @@ +"""Flash Decode Attention kernel builder. + +Single-query (decode-phase) attention using online softmax: + O[h,d] = sum_j( softmax(Q[h,:] . K[h,j,:] / sqrt(d_k))_j * V[h,j,d] ) + +Architecture: + Grid: (total_heads, 1, 1) -- one wavefront per (batch, head) + Block: (WARP_SIZE, 1, 1) -- AMD wave64, barrier-free dot product reduction + +Each thread owns ELEMS_PER_THREAD = head_dim / WARP_SIZE output elements. +Dot products Q.K[j] use intra-warp xor-shuffle sum reduction so all lanes +see the same score without shared-memory barriers. +Online softmax avoids materializing the full attention score matrix. + +Memory layout (row-major, batch and heads flattened into dim-0): + Q: [total_heads, head_dim] + K: [total_heads, seq_len, head_dim] + V: [total_heads, seq_len, head_dim] + O: [total_heads, head_dim] + +where total_heads = batch_size * num_heads. +""" + +from _mlir import ir + +from flydsl.dialects.ext import flir, arith, gpu +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_decode_attention" + +WARP_SIZE = 64 + + +def dtype_to_elem_type(dtype_str: str): + if dtype_str == "f32": + return T.f32() + if dtype_str == "f16": + return T.f16() + if dtype_str == "bf16": + return T.bf16() + raise ValueError(f"unsupported dtype: {dtype_str}") + + +def build_flash_decode_attention_module( + seq_len: int, + head_dim: int, + dtype_str: str = "f16", +): + """Build MLIR module for decode-phase flash attention. + + Args: + seq_len: KV cache sequence length (compile-time constant). + head_dim: per-head dimension, must be divisible by WARP_SIZE (64). + dtype_str: element type for Q/K/V/O ("f32", "f16", or "bf16"). + + Returns: + An MlirModule instance whose ``__call__`` launches the kernel. + """ + if head_dim % WARP_SIZE != 0: + raise ValueError( + f"head_dim ({head_dim}) must be divisible by WARP_SIZE ({WARP_SIZE})" + ) + + arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + BLOCK_THREADS = WARP_SIZE + ELEMS_PER_THREAD = head_dim // WARP_SIZE + + _state = {} + + class _FlashDecodeAttn(flir.MlirModule): + GPU_MODULE_NAME = "flash_decode_attn" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + _state["elem_type"] = dtype_to_elem_type(dtype_str) + _state["compute_type"] = T.f32() + + @flir.kernel + def flash_decode_attention_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + K: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + V: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + O: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + total_heads: lambda: T.index(), + ): + elem_type = _state["elem_type"] + compute_type = _state["compute_type"] + fm_fast = flir.arith.FastMathFlags.fast + + h = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + rsqrt_d = arith.constant(1.0 / (head_dim ** 0.5), type=compute_type) + + # Thread t owns output elements [d_base .. d_base + ELEMS_PER_THREAD). + c_ept = flir.const_index(ELEMS_PER_THREAD) + d_base = flir.arith.MulIOp( + arith.as_value(tid), arith.as_value(c_ept) + ).result + + # Pre-compute per-element head-dim indices. + d_indices = [] + for e in range_constexpr(ELEMS_PER_THREAD): + d_off = flir.const_index(e) + d_indices.append( + flir.arith.AddIOp( + arith.as_value(d_base), arith.as_value(d_off) + ).result + ) + + # Load this thread's Q elements into registers. + q_local = [] + for e in range_constexpr(ELEMS_PER_THREAD): + q_e = flir.memref.load(Q, [arith.as_value(h), d_indices[e]]) + q_f = ( + q_e + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(q_e)) + ) + q_local.append(q_f) + + # ---- online softmax state ---- + m = c_neg_inf # running max + l = c_zero_f # running denominator (sum of exp) + acc = [c_zero_f] * ELEMS_PER_THREAD # weighted V accumulator + + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + + # ---- main loop over KV-cache positions (compile-time unrolled) ---- + for j_py in range_constexpr(seq_len): + j = flir.const_index(j_py) + + # Partial dot product: Q[d_base:d_base+EPT] . K[h, j, d_base:d_base+EPT] + partial = c_zero_f + for e in range_constexpr(ELEMS_PER_THREAD): + k_e = flir.memref.load( + K, [arith.as_value(h), arith.as_value(j), d_indices[e]] + ) + k_f = ( + k_e + if dtype_str == "f32" + else flir.arith.extf( + compute_type, arith.as_value(k_e) + ) + ) + qk = flir.arith.MulFOp( + arith.as_value(q_local[e]), + arith.as_value(k_f), + fastmath=fm_fast, + ).result + partial = flir.arith.AddFOp( + arith.as_value(partial), qk, fastmath=fm_fast + ).result + + # Warp-wide sum reduction (xor-shuffle, wave64). + w = arith.as_value(partial) + for sh in [32, 16, 8, 4, 2, 1]: + off = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp( + arith.as_value(w), off, width_i32, mode="xor" + ).shuffleResult + ) + w = flir.arith.AddFOp( + arith.as_value(w), peer, fastmath=fm_fast + ).result + + # score = dot(Q, K_j) / sqrt(head_dim) + score = flir.arith.MulFOp( + arith.as_value(w), + arith.as_value(rsqrt_d), + fastmath=fm_fast, + ).result + + # Online softmax update: + # m_new = max(m, score) + # correction = exp2((m_old - m_new) * log2e) + # p = exp2((score - m_new) * log2e) + # l = l * correction + p + # acc[e] = acc[e] * correction + p * V[h, j, e] + m_new = flir.arith.MaximumFOp( + arith.as_value(m), arith.as_value(score) + ).result + + diff_m = flir.arith.SubFOp( + arith.as_value(m), m_new, fastmath=fm_fast + ).result + corr_arg = flir.arith.MulFOp( + diff_m, arith.as_value(c_log2e), fastmath=fm_fast + ).result + correction = flir.math.exp2(corr_arg, fastmath=fm_fast) + + diff_s = flir.arith.SubFOp( + arith.as_value(score), m_new, fastmath=fm_fast + ).result + p_arg = flir.arith.MulFOp( + diff_s, arith.as_value(c_log2e), fastmath=fm_fast + ).result + p = flir.math.exp2(p_arg, fastmath=fm_fast) + + l_corr = flir.arith.MulFOp( + arith.as_value(l), + arith.as_value(correction), + fastmath=fm_fast, + ).result + l = flir.arith.AddFOp( + l_corr, arith.as_value(p), fastmath=fm_fast + ).result + + # Update accumulator with weighted V. + new_acc = [] + for e in range_constexpr(ELEMS_PER_THREAD): + v_e = flir.memref.load( + V, [arith.as_value(h), arith.as_value(j), d_indices[e]] + ) + v_f = ( + v_e + if dtype_str == "f32" + else flir.arith.extf( + compute_type, arith.as_value(v_e) + ) + ) + a_corr = flir.arith.MulFOp( + arith.as_value(acc[e]), + arith.as_value(correction), + fastmath=fm_fast, + ).result + pv = flir.arith.MulFOp( + arith.as_value(p), + arith.as_value(v_f), + fastmath=fm_fast, + ).result + new_acc.append( + flir.arith.AddFOp(a_corr, pv, fastmath=fm_fast).result + ) + + acc = new_acc + m = m_new + + # ---- store output: O[h, d] = acc[d] / l ---- + for e in range_constexpr(ELEMS_PER_THREAD): + out_f32 = flir.arith.DivFOp( + arith.as_value(acc[e]), + arith.as_value(l), + fastmath=fm_fast, + ).result + if dtype_str != "f32": + out_e = flir.arith.truncf(elem_type, out_f32) + else: + out_e = out_f32 + flir.memref.store( + arith.as_value(out_e), + O, + [arith.as_value(h), d_indices[e]], + ) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + K: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + V: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + O: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + total_heads: lambda: T.index(), + ): + c1 = flir.arith_ext.index(1) + gx = total_heads + bx = flir.arith_ext.index(BLOCK_THREADS) + flir.gpu_ext.LaunchFuncOp( + ["flash_decode_attn", "flash_decode_attention_kernel"], + grid_size=(gx, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, total_heads], + ) + + return _FlashDecodeAttn() diff --git a/tests/kernels/test_flash_decode_attention.py b/tests/kernels/test_flash_decode_attention.py new file mode 100644 index 00000000..08a04e19 --- /dev/null +++ b/tests/kernels/test_flash_decode_attention.py @@ -0,0 +1,173 @@ +#!/usr/bin/env python3 +""" +Flash Decode Attention Test + +Verifies the single-query (decode-phase) attention kernel: + O = softmax(Q @ K^T / sqrt(d)) @ V + +where Q has a single token per (batch, head). + +Grid: (batch_size * num_heads, 1, 1) +Block: (64, 1, 1) -- single AMD wave64 +""" + +import sys +import os +from pathlib import Path + +_repo = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(_repo)) + +import pytest + +try: + import torch + import torch.nn.functional as F +except ImportError: + torch = None +if torch is None or not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True) + +from tests.test_common import run_perftest + +import flydsl +from kernels.flash_decode_attention import ( + build_flash_decode_attention_module, + KERNEL_NAME, +) + +WARMUP_ITERS = 5 +BENCH_ITERS = 20 + +DTYPE_MAP = { + "f32": torch.float32, + "f16": torch.float16, + "bf16": torch.bfloat16, +} + +ATOL_MAP = { + "f32": 1e-4, + "f16": 2e-2, + "bf16": 3e-2, +} + + +def run_test( + batch_size: int, + num_heads: int, + seq_len: int, + head_dim: int, + dtype_str: str = "f16", +): + total_heads = batch_size * num_heads + torch_dtype = DTYPE_MAP[dtype_str] + atol = ATOL_MAP[dtype_str] + + print( + f"\nTesting Flash Decode Attention " + f"(B={batch_size}, H={num_heads}, S={seq_len}, D={head_dim}, dtype={dtype_str})" + ) + + try: + m = build_flash_decode_attention_module(seq_len, head_dim, dtype_str) + exe = flydsl.compile(m) + except Exception as e: + print(f"[FAIL] Compile failed: {e}") + import traceback + traceback.print_exc() + return False + + torch.manual_seed(42) + + Q_ref = torch.randn(batch_size, num_heads, 1, head_dim, device="cuda", dtype=torch.float32) + K_ref = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float32) + V_ref = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float32) + + # PyTorch reference (always in fp32 for stability). + expected = F.scaled_dot_product_attention( + Q_ref, K_ref, V_ref, is_causal=False + ) # [B, H, 1, D] + expected = expected.squeeze(2).reshape(total_heads, head_dim).to(torch.float32) + + # Prepare device tensors in target dtype, flattened to [total_heads, ...]. + Q_dev = Q_ref.squeeze(2).reshape(total_heads, head_dim).to(torch_dtype).contiguous() + K_dev = K_ref.reshape(total_heads, seq_len, head_dim).to(torch_dtype).contiguous() + V_dev = V_ref.reshape(total_heads, seq_len, head_dim).to(torch_dtype).contiguous() + O_dev = torch.empty(total_heads, head_dim, device="cuda", dtype=torch_dtype) + + print(" Launching kernel...") + + def kernel_launch(): + exe(Q_dev, K_dev, V_dev, O_dev, total_heads) + + kernel_launch() + torch.cuda.synchronize() + + _, avg_us = run_perftest( + lambda: (kernel_launch(), torch.cuda.synchronize()), + num_iters=BENCH_ITERS, + num_warmup=WARMUP_ITERS, + ) + torch.cuda.synchronize() + avg_ms = avg_us / 1000.0 + + elem_bytes = 4 if dtype_str == "f32" else 2 + kv_bytes = 2 * total_heads * seq_len * head_dim * elem_bytes + bandwidth_gbs = kv_bytes / (avg_us / 1e6) / 1e9 + print(f" Kernel avg time: {avg_ms:.4f} ms (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") + print(f" Bandwidth (KV read): {bandwidth_gbs:.2f} GB/s") + + output_f32 = O_dev.to(torch.float32) + error = (output_f32 - expected).abs().max().item() + print(f" Max absolute error: {error:.2e} (atol={atol})") + + if error < atol: + print(" PASSED") + return True + else: + print(" FAILED") + print(" Expected (first 8):", expected[0, :8]) + print(" Got (first 8):", output_f32[0, :8]) + return False + + +def test_flash_decode_attention(): + """Pytest entry point -- small configs for CI.""" + configs = [ + # (batch, heads, seq_len, head_dim, dtype) + (1, 1, 64, 128, "f32"), + (2, 4, 128, 128, "f16"), + (1, 2, 64, 128, "bf16"), + ] + + shapes_env = os.environ.get("FLYDSL_FLASH_ATTN_SHAPES", "").strip() + if shapes_env: + configs = [] + for part in shapes_env.split(";"): + p = part.strip() + if not p: + continue + b, h, s, d, dt = [x.strip() for x in p.split(",")] + configs.append((int(b), int(h), int(s), int(d), dt)) + + print("=" * 80) + print("Running Flash Decode Attention Tests") + print("=" * 80) + + failures = 0 + for batch, heads, seq_len, head_dim, dtype in configs: + if not run_test(batch, heads, seq_len, head_dim, dtype): + failures += 1 + + print("\n" + "=" * 80) + if failures == 0: + print("ALL TESTS PASSED") + else: + print(f"{failures} TESTS FAILED") + print("=" * 80) + + assert failures == 0, f"{failures} test(s) failed" + + +if __name__ == "__main__": + test_flash_decode_attention() From 980371c0eadcc6acd408a53b56c8fe65fcf8b1bc Mon Sep 17 00:00:00 2001 From: "Chuan(Richard) Li" Date: Mon, 2 Mar 2026 10:33:06 -0800 Subject: [PATCH 2/4] Fix memref.load/store indices: wrap d_indices[e] with arith.as_value() All memref index operands must be MLIR Values. The d_indices list elements were raw OpResult objects not wrapped in arith.as_value(), causing 'Operand 2 of operation memref.load must be a Value' at compile time. Made-with: Cursor --- kernels/flash_decode_attention.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/kernels/flash_decode_attention.py b/kernels/flash_decode_attention.py index 8f9cc283..6ec7b1f9 100644 --- a/kernels/flash_decode_attention.py +++ b/kernels/flash_decode_attention.py @@ -119,7 +119,7 @@ def flash_decode_attention_kernel( # Load this thread's Q elements into registers. q_local = [] for e in range_constexpr(ELEMS_PER_THREAD): - q_e = flir.memref.load(Q, [arith.as_value(h), d_indices[e]]) + q_e = flir.memref.load(Q, [arith.as_value(h), arith.as_value(d_indices[e])]) q_f = ( q_e if dtype_str == "f32" @@ -142,7 +142,7 @@ def flash_decode_attention_kernel( partial = c_zero_f for e in range_constexpr(ELEMS_PER_THREAD): k_e = flir.memref.load( - K, [arith.as_value(h), arith.as_value(j), d_indices[e]] + K, [arith.as_value(h), arith.as_value(j), arith.as_value(d_indices[e])] ) k_f = ( k_e @@ -219,7 +219,7 @@ def flash_decode_attention_kernel( new_acc = [] for e in range_constexpr(ELEMS_PER_THREAD): v_e = flir.memref.load( - V, [arith.as_value(h), arith.as_value(j), d_indices[e]] + V, [arith.as_value(h), arith.as_value(j), arith.as_value(d_indices[e])] ) v_f = ( v_e @@ -259,7 +259,7 @@ def flash_decode_attention_kernel( flir.memref.store( arith.as_value(out_e), O, - [arith.as_value(h), d_indices[e]], + [arith.as_value(h), arith.as_value(d_indices[e])], ) @flir.jit From a4e9589319a2e3748229d2b3e4bb456c50c9e9b3 Mon Sep 17 00:00:00 2001 From: "Chuan(Richard) Li" Date: Mon, 2 Mar 2026 16:44:55 -0800 Subject: [PATCH 3/4] Fix MLIR Value wrapping in flash decode attention kernel Wrap all raw OpResult values with arith.as_value() before passing them as operands to flir.arith and flir.math ops. Without wrapping, the MLIR verifier rejects them with: 'Operand N of operation arith.addf must be a Value (is not a Value)'. Fixes compile failures on MI325 and MI355 CI runners. Made-with: Cursor --- kernels/flash_decode_attention.py | 33 +++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/kernels/flash_decode_attention.py b/kernels/flash_decode_attention.py index 6ec7b1f9..bf692620 100644 --- a/kernels/flash_decode_attention.py +++ b/kernels/flash_decode_attention.py @@ -157,7 +157,8 @@ def flash_decode_attention_kernel( fastmath=fm_fast, ).result partial = flir.arith.AddFOp( - arith.as_value(partial), qk, fastmath=fm_fast + arith.as_value(partial), arith.as_value(qk), + fastmath=fm_fast, ).result # Warp-wide sum reduction (xor-shuffle, wave64). @@ -191,20 +192,28 @@ def flash_decode_attention_kernel( ).result diff_m = flir.arith.SubFOp( - arith.as_value(m), m_new, fastmath=fm_fast + arith.as_value(m), arith.as_value(m_new), + fastmath=fm_fast, ).result corr_arg = flir.arith.MulFOp( - diff_m, arith.as_value(c_log2e), fastmath=fm_fast + arith.as_value(diff_m), arith.as_value(c_log2e), + fastmath=fm_fast, ).result - correction = flir.math.exp2(corr_arg, fastmath=fm_fast) + correction = flir.math.exp2( + arith.as_value(corr_arg), fastmath=fm_fast + ) diff_s = flir.arith.SubFOp( - arith.as_value(score), m_new, fastmath=fm_fast + arith.as_value(score), arith.as_value(m_new), + fastmath=fm_fast, ).result p_arg = flir.arith.MulFOp( - diff_s, arith.as_value(c_log2e), fastmath=fm_fast + arith.as_value(diff_s), arith.as_value(c_log2e), + fastmath=fm_fast, ).result - p = flir.math.exp2(p_arg, fastmath=fm_fast) + p = flir.math.exp2( + arith.as_value(p_arg), fastmath=fm_fast + ) l_corr = flir.arith.MulFOp( arith.as_value(l), @@ -212,7 +221,8 @@ def flash_decode_attention_kernel( fastmath=fm_fast, ).result l = flir.arith.AddFOp( - l_corr, arith.as_value(p), fastmath=fm_fast + arith.as_value(l_corr), arith.as_value(p), + fastmath=fm_fast, ).result # Update accumulator with weighted V. @@ -239,7 +249,10 @@ def flash_decode_attention_kernel( fastmath=fm_fast, ).result new_acc.append( - flir.arith.AddFOp(a_corr, pv, fastmath=fm_fast).result + flir.arith.AddFOp( + arith.as_value(a_corr), arith.as_value(pv), + fastmath=fm_fast, + ).result ) acc = new_acc @@ -253,7 +266,7 @@ def flash_decode_attention_kernel( fastmath=fm_fast, ).result if dtype_str != "f32": - out_e = flir.arith.truncf(elem_type, out_f32) + out_e = flir.arith.truncf(elem_type, arith.as_value(out_f32)) else: out_e = out_f32 flir.memref.store( From 279b8335d377987ed77988dfc4da4f921a1e5988 Mon Sep 17 00:00:00 2001 From: "Chuan(Richard) Li" Date: Mon, 2 Mar 2026 22:32:55 -0800 Subject: [PATCH 4/4] Add real LLM model shapes and perf comparison to flash decode test Add MODEL_CONFIGS with LLaMA 3.1 8B/70B and Mixtral/DeepSeek-like decode shapes. Use benchmark_common (PerfRow, bench_gpu_us_torch, print_perf_table) for consistent perf reporting. Optional PyTorch SDPA comparison via FLYDSL_COMPARE_SDPA=1. Model shapes enabled via FLYDSL_FLASH_ATTN_MODELS=1. Made-with: Cursor --- tests/kernels/test_flash_decode_attention.py | 114 +++++++++++++++---- 1 file changed, 89 insertions(+), 25 deletions(-) diff --git a/tests/kernels/test_flash_decode_attention.py b/tests/kernels/test_flash_decode_attention.py index 08a04e19..b03e6268 100644 --- a/tests/kernels/test_flash_decode_attention.py +++ b/tests/kernels/test_flash_decode_attention.py @@ -9,6 +9,12 @@ Grid: (batch_size * num_heads, 1, 1) Block: (64, 1, 1) -- single AMD wave64 + +Model shape references (decode phase, KV-cache heads): + LLaMA 3.1 8B : 8 KV heads, head_dim=128 + LLaMA 3.1 70B : 8 KV heads, head_dim=128 + DeepSeek V3 : compressed KV via MLA, head_dim=128 + Mixtral 8x7B : 8 KV heads, head_dim=128 """ import sys @@ -29,6 +35,11 @@ pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True) from tests.test_common import run_perftest +from tests.kernels.benchmark_common import ( + PerfRow, + bench_gpu_us_torch, + print_perf_table, +) import flydsl from kernels.flash_decode_attention import ( @@ -36,8 +47,8 @@ KERNEL_NAME, ) -WARMUP_ITERS = 5 -BENCH_ITERS = 20 +WARMUP_ITERS = 10 +BENCH_ITERS = 100 DTYPE_MAP = { "f32": torch.float32, @@ -52,17 +63,26 @@ } +def _sdpa_ref_us(Q_ref, K_ref, V_ref, warmup=10, iters=100): + """Benchmark PyTorch scaled_dot_product_attention as baseline.""" + def fn(): + F.scaled_dot_product_attention(Q_ref, K_ref, V_ref, is_causal=False) + return bench_gpu_us_torch(fn, warmup=warmup, iters=iters) + + def run_test( batch_size: int, num_heads: int, seq_len: int, head_dim: int, dtype_str: str = "f16", + do_compare: bool = False, ): total_heads = batch_size * num_heads torch_dtype = DTYPE_MAP[dtype_str] atol = ATOL_MAP[dtype_str] + shape_tag = f"B{batch_size}_H{num_heads}_S{seq_len}_D{head_dim}" print( f"\nTesting Flash Decode Attention " f"(B={batch_size}, H={num_heads}, S={seq_len}, D={head_dim}, dtype={dtype_str})" @@ -75,7 +95,7 @@ def run_test( print(f"[FAIL] Compile failed: {e}") import traceback traceback.print_exc() - return False + return False, None torch.manual_seed(42) @@ -83,13 +103,11 @@ def run_test( K_ref = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float32) V_ref = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float32) - # PyTorch reference (always in fp32 for stability). expected = F.scaled_dot_product_attention( Q_ref, K_ref, V_ref, is_causal=False - ) # [B, H, 1, D] + ) expected = expected.squeeze(2).reshape(total_heads, head_dim).to(torch.float32) - # Prepare device tensors in target dtype, flattened to [total_heads, ...]. Q_dev = Q_ref.squeeze(2).reshape(total_heads, head_dim).to(torch_dtype).contiguous() K_dev = K_ref.reshape(total_heads, seq_len, head_dim).to(torch_dtype).contiguous() V_dev = V_ref.reshape(total_heads, seq_len, head_dim).to(torch_dtype).contiguous() @@ -103,43 +121,76 @@ def kernel_launch(): kernel_launch() torch.cuda.synchronize() - _, avg_us = run_perftest( - lambda: (kernel_launch(), torch.cuda.synchronize()), - num_iters=BENCH_ITERS, - num_warmup=WARMUP_ITERS, - ) - torch.cuda.synchronize() - avg_ms = avg_us / 1000.0 + flir_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + avg_ms = flir_gpu_us / 1000.0 elem_bytes = 4 if dtype_str == "f32" else 2 kv_bytes = 2 * total_heads * seq_len * head_dim * elem_bytes - bandwidth_gbs = kv_bytes / (avg_us / 1e6) / 1e9 + bandwidth_gbs = kv_bytes / (flir_gpu_us / 1e6) / 1e9 print(f" Kernel avg time: {avg_ms:.4f} ms (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})") print(f" Bandwidth (KV read): {bandwidth_gbs:.2f} GB/s") + sdpa_us = None + if do_compare: + Q_sdpa = Q_ref.to(torch_dtype) + K_sdpa = K_ref.to(torch_dtype) + V_sdpa = V_ref.to(torch_dtype) + sdpa_us = _sdpa_ref_us(Q_sdpa, K_sdpa, V_sdpa, warmup=WARMUP_ITERS, iters=BENCH_ITERS) + print(f" PyTorch SDPA avg time: {sdpa_us / 1000.0:.4f} ms") + output_f32 = O_dev.to(torch.float32) error = (output_f32 - expected).abs().max().item() print(f" Max absolute error: {error:.2e} (atol={atol})") if error < atol: print(" PASSED") - return True + ok = True else: print(" FAILED") print(" Expected (first 8):", expected[0, :8]) print(" Got (first 8):", output_f32[0, :8]) - return False + ok = False + + perf_row = PerfRow( + op="flash_decode_attn", + shape=shape_tag, + dtype=dtype_str, + flir_gpu_us=flir_gpu_us, + aiter_gpu_us=sdpa_us, + ) + return ok, perf_row + + +# --------------------------------------------------------------------------- +# CI test configs (kept small for fast CI turnaround) +# --------------------------------------------------------------------------- +CI_CONFIGS = [ + # (batch, heads, seq_len, head_dim, dtype) + (1, 1, 64, 128, "f32"), + (2, 4, 128, 128, "f16"), + (1, 2, 64, 128, "bf16"), +] + +# --------------------------------------------------------------------------- +# Real model shapes (decode phase with moderate seq_len for compile-time +# unrolled loop; production seq_len would need scf.for dynamic loop) +# --------------------------------------------------------------------------- +MODEL_CONFIGS = [ + # LLaMA 3.1 8B : 8 KV heads, head_dim=128 + (1, 8, 256, 128, "f16"), + (8, 8, 256, 128, "f16"), + (32, 8, 128, 128, "bf16"), + # LLaMA 3.1 70B : 8 KV heads, head_dim=128 + (1, 8, 256, 128, "bf16"), + (4, 8, 256, 128, "f16"), + # Mixtral / DeepSeek-like : larger head count + (1, 32, 128, 128, "f16"), + (1, 32, 256, 128, "bf16"), +] def test_flash_decode_attention(): - """Pytest entry point -- small configs for CI.""" - configs = [ - # (batch, heads, seq_len, head_dim, dtype) - (1, 1, 64, 128, "f32"), - (2, 4, 128, 128, "f16"), - (1, 2, 64, 128, "bf16"), - ] - + """Pytest entry point -- CI correctness + optional model shapes.""" shapes_env = os.environ.get("FLYDSL_FLASH_ATTN_SHAPES", "").strip() if shapes_env: configs = [] @@ -149,15 +200,25 @@ def test_flash_decode_attention(): continue b, h, s, d, dt = [x.strip() for x in p.split(",")] configs.append((int(b), int(h), int(s), int(d), dt)) + else: + run_models = os.environ.get("FLYDSL_FLASH_ATTN_MODELS", "0") == "1" + configs = CI_CONFIGS + (MODEL_CONFIGS if run_models else []) + + do_compare = os.environ.get("FLYDSL_COMPARE_SDPA", "0") == "1" print("=" * 80) print("Running Flash Decode Attention Tests") + print(f" Configs: {len(configs)} compare_sdpa={do_compare}") print("=" * 80) failures = 0 + perf_rows = [] for batch, heads, seq_len, head_dim, dtype in configs: - if not run_test(batch, heads, seq_len, head_dim, dtype): + ok, row = run_test(batch, heads, seq_len, head_dim, dtype, do_compare=do_compare) + if not ok: failures += 1 + if row is not None: + perf_rows.append(row) print("\n" + "=" * 80) if failures == 0: @@ -166,6 +227,9 @@ def test_flash_decode_attention(): print(f"{failures} TESTS FAILED") print("=" * 80) + if do_compare and perf_rows: + print_perf_table(perf_rows) + assert failures == 0, f"{failures} test(s) failed"