diff --git a/kernels/flash_decode_attention.py b/kernels/flash_decode_attention.py new file mode 100644 index 00000000..bf692620 --- /dev/null +++ b/kernels/flash_decode_attention.py @@ -0,0 +1,297 @@ +"""Flash Decode Attention kernel builder. + +Single-query (decode-phase) attention using online softmax: + O[h,d] = sum_j( softmax(Q[h,:] . K[h,j,:] / sqrt(d_k))_j * V[h,j,d] ) + +Architecture: + Grid: (total_heads, 1, 1) -- one wavefront per (batch, head) + Block: (WARP_SIZE, 1, 1) -- AMD wave64, barrier-free dot product reduction + +Each thread owns ELEMS_PER_THREAD = head_dim / WARP_SIZE output elements. +Dot products Q.K[j] use intra-warp xor-shuffle sum reduction so all lanes +see the same score without shared-memory barriers. +Online softmax avoids materializing the full attention score matrix. + +Memory layout (row-major, batch and heads flattened into dim-0): + Q: [total_heads, head_dim] + K: [total_heads, seq_len, head_dim] + V: [total_heads, seq_len, head_dim] + O: [total_heads, head_dim] + +where total_heads = batch_size * num_heads. +""" + +from _mlir import ir + +from flydsl.dialects.ext import flir, arith, gpu +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_decode_attention" + +WARP_SIZE = 64 + + +def dtype_to_elem_type(dtype_str: str): + if dtype_str == "f32": + return T.f32() + if dtype_str == "f16": + return T.f16() + if dtype_str == "bf16": + return T.bf16() + raise ValueError(f"unsupported dtype: {dtype_str}") + + +def build_flash_decode_attention_module( + seq_len: int, + head_dim: int, + dtype_str: str = "f16", +): + """Build MLIR module for decode-phase flash attention. + + Args: + seq_len: KV cache sequence length (compile-time constant). + head_dim: per-head dimension, must be divisible by WARP_SIZE (64). + dtype_str: element type for Q/K/V/O ("f32", "f16", or "bf16"). 
+ + Returns: + An MlirModule instance whose ``__call__`` launches the kernel. + """ + if head_dim % WARP_SIZE != 0: + raise ValueError( + f"head_dim ({head_dim}) must be divisible by WARP_SIZE ({WARP_SIZE})" + ) + + arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + BLOCK_THREADS = WARP_SIZE + ELEMS_PER_THREAD = head_dim // WARP_SIZE + + _state = {} + + class _FlashDecodeAttn(flir.MlirModule): + GPU_MODULE_NAME = "flash_decode_attn" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + _state["elem_type"] = dtype_to_elem_type(dtype_str) + _state["compute_type"] = T.f32() + + @flir.kernel + def flash_decode_attention_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + K: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + V: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + O: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + total_heads: lambda: T.index(), + ): + elem_type = _state["elem_type"] + compute_type = _state["compute_type"] + fm_fast = flir.arith.FastMathFlags.fast + + h = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + rsqrt_d = arith.constant(1.0 / (head_dim ** 0.5), type=compute_type) + + # Thread t owns output elements [d_base .. d_base + ELEMS_PER_THREAD). + c_ept = flir.const_index(ELEMS_PER_THREAD) + d_base = flir.arith.MulIOp( + arith.as_value(tid), arith.as_value(c_ept) + ).result + + # Pre-compute per-element head-dim indices. + d_indices = [] + for e in range_constexpr(ELEMS_PER_THREAD): + d_off = flir.const_index(e) + d_indices.append( + flir.arith.AddIOp( + arith.as_value(d_base), arith.as_value(d_off) + ).result + ) + + # Load this thread's Q elements into registers. 
+ q_local = [] + for e in range_constexpr(ELEMS_PER_THREAD): + q_e = flir.memref.load(Q, [arith.as_value(h), arith.as_value(d_indices[e])]) + q_f = ( + q_e + if dtype_str == "f32" + else flir.arith.extf(compute_type, arith.as_value(q_e)) + ) + q_local.append(q_f) + + # ---- online softmax state ---- + m = c_neg_inf # running max + l = c_zero_f # running denominator (sum of exp) + acc = [c_zero_f] * ELEMS_PER_THREAD # weighted V accumulator + + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + + # ---- main loop over KV-cache positions (compile-time unrolled) ---- + for j_py in range_constexpr(seq_len): + j = flir.const_index(j_py) + + # Partial dot product: Q[d_base:d_base+EPT] . K[h, j, d_base:d_base+EPT] + partial = c_zero_f + for e in range_constexpr(ELEMS_PER_THREAD): + k_e = flir.memref.load( + K, [arith.as_value(h), arith.as_value(j), arith.as_value(d_indices[e])] + ) + k_f = ( + k_e + if dtype_str == "f32" + else flir.arith.extf( + compute_type, arith.as_value(k_e) + ) + ) + qk = flir.arith.MulFOp( + arith.as_value(q_local[e]), + arith.as_value(k_f), + fastmath=fm_fast, + ).result + partial = flir.arith.AddFOp( + arith.as_value(partial), arith.as_value(qk), + fastmath=fm_fast, + ).result + + # Warp-wide sum reduction (xor-shuffle, wave64). 
+ w = arith.as_value(partial) + for sh in [32, 16, 8, 4, 2, 1]: + off = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp( + arith.as_value(w), off, width_i32, mode="xor" + ).shuffleResult + ) + w = flir.arith.AddFOp( + arith.as_value(w), peer, fastmath=fm_fast + ).result + + # score = dot(Q, K_j) / sqrt(head_dim) + score = flir.arith.MulFOp( + arith.as_value(w), + arith.as_value(rsqrt_d), + fastmath=fm_fast, + ).result + + # Online softmax update: + # m_new = max(m, score) + # correction = exp2((m_old - m_new) * log2e) + # p = exp2((score - m_new) * log2e) + # l = l * correction + p + # acc[e] = acc[e] * correction + p * V[h, j, e] + m_new = flir.arith.MaximumFOp( + arith.as_value(m), arith.as_value(score) + ).result + + diff_m = flir.arith.SubFOp( + arith.as_value(m), arith.as_value(m_new), + fastmath=fm_fast, + ).result + corr_arg = flir.arith.MulFOp( + arith.as_value(diff_m), arith.as_value(c_log2e), + fastmath=fm_fast, + ).result + correction = flir.math.exp2( + arith.as_value(corr_arg), fastmath=fm_fast + ) + + diff_s = flir.arith.SubFOp( + arith.as_value(score), arith.as_value(m_new), + fastmath=fm_fast, + ).result + p_arg = flir.arith.MulFOp( + arith.as_value(diff_s), arith.as_value(c_log2e), + fastmath=fm_fast, + ).result + p = flir.math.exp2( + arith.as_value(p_arg), fastmath=fm_fast + ) + + l_corr = flir.arith.MulFOp( + arith.as_value(l), + arith.as_value(correction), + fastmath=fm_fast, + ).result + l = flir.arith.AddFOp( + arith.as_value(l_corr), arith.as_value(p), + fastmath=fm_fast, + ).result + + # Update accumulator with weighted V. 
+ new_acc = [] + for e in range_constexpr(ELEMS_PER_THREAD): + v_e = flir.memref.load( + V, [arith.as_value(h), arith.as_value(j), arith.as_value(d_indices[e])] + ) + v_f = ( + v_e + if dtype_str == "f32" + else flir.arith.extf( + compute_type, arith.as_value(v_e) + ) + ) + a_corr = flir.arith.MulFOp( + arith.as_value(acc[e]), + arith.as_value(correction), + fastmath=fm_fast, + ).result + pv = flir.arith.MulFOp( + arith.as_value(p), + arith.as_value(v_f), + fastmath=fm_fast, + ).result + new_acc.append( + flir.arith.AddFOp( + arith.as_value(a_corr), arith.as_value(pv), + fastmath=fm_fast, + ).result + ) + + acc = new_acc + m = m_new + + # ---- store output: O[h, d] = acc[d] / l ---- + for e in range_constexpr(ELEMS_PER_THREAD): + out_f32 = flir.arith.DivFOp( + arith.as_value(acc[e]), + arith.as_value(l), + fastmath=fm_fast, + ).result + if dtype_str != "f32": + out_e = flir.arith.truncf(elem_type, arith.as_value(out_f32)) + else: + out_e = out_f32 + flir.memref.store( + arith.as_value(out_e), + O, + [arith.as_value(h), arith.as_value(d_indices[e])], + ) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + K: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + V: lambda: T.memref(DYN, seq_len, head_dim, _state["elem_type"]), + O: lambda: T.memref(DYN, head_dim, _state["elem_type"]), + total_heads: lambda: T.index(), + ): + c1 = flir.arith_ext.index(1) + gx = total_heads + bx = flir.arith_ext.index(BLOCK_THREADS) + flir.gpu_ext.LaunchFuncOp( + ["flash_decode_attn", "flash_decode_attention_kernel"], + grid_size=(gx, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, total_heads], + ) + + return _FlashDecodeAttn() diff --git a/tests/kernels/test_flash_decode_attention.py b/tests/kernels/test_flash_decode_attention.py new file mode 100644 index 00000000..b03e6268 --- /dev/null +++ b/tests/kernels/test_flash_decode_attention.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +""" +Flash 
Decode Attention Test

Verifies the single-query (decode-phase) attention kernel:
    O = softmax(Q @ K^T / sqrt(d)) @ V

where Q has a single token per (batch, head).

Grid:  (batch_size * num_heads, 1, 1)
Block: (64, 1, 1) -- single AMD wave64

Model shape references (decode phase, KV-cache heads):
    LLaMA 3.1 8B  : 8 KV heads, head_dim=128
    LLaMA 3.1 70B : 8 KV heads, head_dim=128
    DeepSeek V3   : compressed KV via MLA, head_dim=128
    Mixtral 8x7B  : 8 KV heads, head_dim=128
"""

import sys
import os
from pathlib import Path

# Make the repo root importable so `kernels.*` and `tests.*` resolve when the
# file is run directly (not just under pytest from the repo root).
_repo = Path(__file__).resolve().parents[2]
sys.path.insert(0, str(_repo))

import pytest

try:
    import torch
    import torch.nn.functional as F
except ImportError:
    torch = None
# Skip the whole module (not just fail) when no GPU runtime is present.
if torch is None or not torch.cuda.is_available():
    pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True)

from tests.test_common import run_perftest
from tests.kernels.benchmark_common import (
    PerfRow,
    bench_gpu_us_torch,
    print_perf_table,
)

import flydsl
from kernels.flash_decode_attention import (
    build_flash_decode_attention_module,
    KERNEL_NAME,
)

WARMUP_ITERS = 10
BENCH_ITERS = 100

# dtype string -> torch dtype used for device buffers.
DTYPE_MAP = {
    "f32": torch.float32,
    "f16": torch.float16,
    "bf16": torch.bfloat16,
}

# Per-dtype max-abs-error tolerance against the f32 SDPA reference.
ATOL_MAP = {
    "f32": 1e-4,
    "f16": 2e-2,
    "bf16": 3e-2,
}


def _sdpa_ref_us(Q_ref, K_ref, V_ref, warmup=10, iters=100):
    """Benchmark PyTorch scaled_dot_product_attention as baseline.

    Returns the average GPU time in microseconds (via bench_gpu_us_torch).
    """
    def fn():
        F.scaled_dot_product_attention(Q_ref, K_ref, V_ref, is_causal=False)
    return bench_gpu_us_torch(fn, warmup=warmup, iters=iters)


def run_test(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    head_dim: int,
    dtype_str: str = "f16",
    do_compare: bool = False,
):
    """Build, launch, benchmark, and verify one configuration.

    Args:
        batch_size, num_heads, seq_len, head_dim: problem shape.
        dtype_str: kernel storage dtype ("f32"/"f16"/"bf16"); selects atol.
        do_compare: if True, also benchmark PyTorch SDPA on the same data.

    Returns:
        (ok, perf_row) where ok is the pass/fail bool and perf_row is a
        PerfRow with timings, or (False, None) if compilation failed.
    """
    total_heads = batch_size * num_heads
    torch_dtype = DTYPE_MAP[dtype_str]
    atol = ATOL_MAP[dtype_str]

    shape_tag = f"B{batch_size}_H{num_heads}_S{seq_len}_D{head_dim}"
    print(
        f"\nTesting Flash Decode Attention "
        f"(B={batch_size}, H={num_heads}, S={seq_len}, D={head_dim}, dtype={dtype_str})"
    )

    try:
        m = build_flash_decode_attention_module(seq_len, head_dim, dtype_str)
        exe = flydsl.compile(m)
    except Exception as e:
        # Compile failure is reported as a test failure, not a crash, so the
        # remaining configs still run.
        print(f"[FAIL] Compile failed: {e}")
        import traceback
        traceback.print_exc()
        return False, None

    torch.manual_seed(42)

    # Reference tensors are kept in f32; the reference output is computed in
    # f32 and the kernel inputs are down-converted from the same data.
    Q_ref = torch.randn(batch_size, num_heads, 1, head_dim, device="cuda", dtype=torch.float32)
    K_ref = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float32)
    V_ref = torch.randn(batch_size, num_heads, seq_len, head_dim, device="cuda", dtype=torch.float32)

    expected = F.scaled_dot_product_attention(
        Q_ref, K_ref, V_ref, is_causal=False
    )
    expected = expected.squeeze(2).reshape(total_heads, head_dim).to(torch.float32)

    # Kernel layout flattens (batch, head) into dim-0 (see kernel docstring).
    Q_dev = Q_ref.squeeze(2).reshape(total_heads, head_dim).to(torch_dtype).contiguous()
    K_dev = K_ref.reshape(total_heads, seq_len, head_dim).to(torch_dtype).contiguous()
    V_dev = V_ref.reshape(total_heads, seq_len, head_dim).to(torch_dtype).contiguous()
    O_dev = torch.empty(total_heads, head_dim, device="cuda", dtype=torch_dtype)

    print(" Launching kernel...")

    def kernel_launch():
        exe(Q_dev, K_dev, V_dev, O_dev, total_heads)

    kernel_launch()
    torch.cuda.synchronize()

    flir_gpu_us = bench_gpu_us_torch(kernel_launch, warmup=WARMUP_ITERS, iters=BENCH_ITERS)
    avg_ms = flir_gpu_us / 1000.0

    # Bandwidth model counts only the K and V reads (the dominant traffic in
    # decode attention); Q/O are a single row per head.
    elem_bytes = 4 if dtype_str == "f32" else 2
    kv_bytes = 2 * total_heads * seq_len * head_dim * elem_bytes
    bandwidth_gbs = kv_bytes / (flir_gpu_us / 1e6) / 1e9
    print(f" Kernel avg time: {avg_ms:.4f} ms (warmup={WARMUP_ITERS}, iters={BENCH_ITERS})")
    print(f" Bandwidth (KV read): {bandwidth_gbs:.2f} GB/s")

    sdpa_us = None
    if do_compare:
        # Compare SDPA at the same storage dtype as the kernel for fairness.
        Q_sdpa = Q_ref.to(torch_dtype)
        K_sdpa = K_ref.to(torch_dtype)
        V_sdpa = V_ref.to(torch_dtype)
        sdpa_us = _sdpa_ref_us(Q_sdpa, K_sdpa, V_sdpa, warmup=WARMUP_ITERS, iters=BENCH_ITERS)
        print(f" PyTorch SDPA avg time: {sdpa_us / 1000.0:.4f} ms")

    output_f32 = O_dev.to(torch.float32)
    error = (output_f32 - expected).abs().max().item()
    print(f" Max absolute error: {error:.2e} (atol={atol})")

    if error < atol:
        print(" PASSED")
        ok = True
    else:
        print(" FAILED")
        print(" Expected (first 8):", expected[0, :8])
        print(" Got (first 8):", output_f32[0, :8])
        ok = False

    perf_row = PerfRow(
        op="flash_decode_attn",
        shape=shape_tag,
        dtype=dtype_str,
        flir_gpu_us=flir_gpu_us,
        aiter_gpu_us=sdpa_us,
    )
    return ok, perf_row


# ---------------------------------------------------------------------------
# CI test configs (kept small for fast CI turnaround)
# ---------------------------------------------------------------------------
CI_CONFIGS = [
    # (batch, heads, seq_len, head_dim, dtype)
    (1, 1, 64, 128, "f32"),
    (2, 4, 128, 128, "f16"),
    (1, 2, 64, 128, "bf16"),
]

# ---------------------------------------------------------------------------
# Real model shapes (decode phase with moderate seq_len for compile-time
# unrolled loop; production seq_len would need scf.for dynamic loop)
# ---------------------------------------------------------------------------
MODEL_CONFIGS = [
    # LLaMA 3.1 8B : 8 KV heads, head_dim=128
    (1, 8, 256, 128, "f16"),
    (8, 8, 256, 128, "f16"),
    (32, 8, 128, 128, "bf16"),
    # LLaMA 3.1 70B : 8 KV heads, head_dim=128
    (1, 8, 256, 128, "bf16"),
    (4, 8, 256, 128, "f16"),
    # Mixtral / DeepSeek-like : larger head count
    (1, 32, 128, 128, "f16"),
    (1, 32, 256, 128, "bf16"),
]


def test_flash_decode_attention():
    """Pytest entry point -- CI correctness + optional model shapes.

    Environment variables:
        FLYDSL_FLASH_ATTN_SHAPES: semicolon-separated "b,h,s,d,dtype"
            tuples; when set, overrides the built-in config lists.
        FLYDSL_FLASH_ATTN_MODELS: "1" appends MODEL_CONFIGS to CI_CONFIGS.
        FLYDSL_COMPARE_SDPA: "1" also benchmarks PyTorch SDPA per config
            and prints a perf table at the end.
    """
    shapes_env = os.environ.get("FLYDSL_FLASH_ATTN_SHAPES", "").strip()
    if shapes_env:
        configs = []
        for part in shapes_env.split(";"):
            p = part.strip()
            if not p:
                continue
            b, h, s, d, dt = [x.strip() for x in p.split(",")]
            configs.append((int(b), int(h), int(s), int(d), dt))
    else:
        run_models = os.environ.get("FLYDSL_FLASH_ATTN_MODELS", "0") == "1"
        configs = CI_CONFIGS + (MODEL_CONFIGS if run_models else [])

    do_compare = os.environ.get("FLYDSL_COMPARE_SDPA", "0") == "1"

    print("=" * 80)
    print("Running Flash Decode Attention Tests")
    print(f" Configs: {len(configs)} compare_sdpa={do_compare}")
    print("=" * 80)

    # Run every config even after a failure so the report is complete.
    failures = 0
    perf_rows = []
    for batch, heads, seq_len, head_dim, dtype in configs:
        ok, row = run_test(batch, heads, seq_len, head_dim, dtype, do_compare=do_compare)
        if not ok:
            failures += 1
        if row is not None:
            perf_rows.append(row)

    print("\n" + "=" * 80)
    if failures == 0:
        print("ALL TESTS PASSED")
    else:
        print(f"{failures} TESTS FAILED")
    print("=" * 80)

    if do_compare and perf_rows:
        print_perf_table(perf_rows)

    assert failures == 0, f"{failures} test(s) failed"


if __name__ == "__main__":
    test_flash_decode_attention()