From 11b396790bcdacb759d2bc4f61c0d3f8d50a359d Mon Sep 17 00:00:00 2001 From: yanguahe Date: Tue, 3 Feb 2026 23:13:30 +0800 Subject: [PATCH 01/17] Add simple GEMM kernel with MFMA 16x16x16 and test scripts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add kernels/simple_gemm.py: Simple GEMM kernel (C = A × B^T) for AMD GPUs using MFMA instructions with XOR16 LDS swizzle and boundary checks for non-aligned M, N, K dimensions - Add tests/kernels/test_simple_gemm.py: Test script with aligned and non-aligned dimension test cases - Add tests/kernels/test_moe_stage1_simple.py: Standalone test script for MoE Stage1 kernel - Add run.sh: Shell script for running tests and collecting ROCm thread traces - Add input.yaml: ROCm profiler configuration for thread trace collection Co-authored-by: Cursor --- input.yaml | 17 + kernels/simple_gemm.py | 616 ++++++++++++++++++++++++ run.sh | 87 ++++ tests/kernels/test_moe_stage1_simple.py | 193 ++++++++ tests/kernels/test_simple_gemm.py | 364 ++++++++++++++ thread_trace/.gitkeep | 0 6 files changed, 1277 insertions(+) create mode 100644 input.yaml create mode 100644 kernels/simple_gemm.py create mode 100755 run.sh create mode 100644 tests/kernels/test_moe_stage1_simple.py create mode 100644 tests/kernels/test_simple_gemm.py create mode 100644 thread_trace/.gitkeep diff --git a/input.yaml b/input.yaml new file mode 100644 index 00000000..8477d1e0 --- /dev/null +++ b/input.yaml @@ -0,0 +1,17 @@ +jobs: + - + kernel_include_regex: (kernel_gemm) + kernel_exclude_regex: + kernel_iteration_range: "[1]" + output_file: out + output_directory: thread_trace/rpf_v3 + output_format: [csv] + truncate_kernels: false + sys_trace: false # enable for pftrace and otf2 + advanced_thread_trace: true # enable for att and ui folder + att_target_cu: 1 + att_shader_engine_mask: "0xf" # collect one CU from 4 SEs + att_simd_select: "0xf" # collect 4 SIMDs on single CU + att_buffer_size: "0x6000000" + - + pmc: [SQ_WAVES, FETCH_SIZE] diff --git a/kernels/simple_gemm.py b/kernels/simple_gemm.py new file mode 100644 index 00000000..b9c6caa5 --- /dev/null +++ b/kernels/simple_gemm.py @@ -0,0 +1,616 @@ +"""Simple GEMM kernel implementation using FlyDSL (MFMA 16x16x16). + +This module provides a simple GEMM kernel (C = A × B^T) for AMD GPUs using MFMA instructions. + +Configuration: +- Block: 256 threads = 4 waves × 64 lanes +- Tile: M=16 × N=64 × K=128 (configurable) +- Currently supports bf16/fp16 input, f32 accumulator, bf16/fp16 output +- Supports non-aligned M, N, K dimensions: + - M and N: kernel boundary checks on output stores + - K: caller pads to multiple of 8 (vector load granularity) + +A matrix loading: +- GM → GPR → LDS: 256 threads cooperatively load the A tile with XOR16 swizzle +- LDS → GPR: Each wave reads 16×32 sub-block for MFMA + +B matrix loading: +- Direct load: Each wave handles 16 columns of N +- Wave 0: N[0:16], Wave 1: N[16:32], Wave 2: N[32:48], Wave 3: N[48:64] + +Output C matrix (16×64): +- Wave 0 → C[0:16, 0:16] +- Wave 1 → C[0:16, 16:32] +- Wave 2 → C[0:16, 32:48] +- Wave 3 → C[0:16, 48:64] +- Boundary check: skip stores for out-of-bounds elements +""" + +import functools + +import flydsl +from flydsl.dialects.ext import flir +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator + +from _mlir import ir +from _mlir.dialects import scf as _scf + +from flydsl.dialects.ext import arith, gpu, buffer_ops, vector, rocdl, scf +from flydsl.lang.ir.types import T, memref + + +@functools.lru_cache(maxsize=1024) +def compile_simple_gemm( + *, + tile_m: int = 16, + tile_n: int = 64, + tile_k: int = 128, + in_dtype: str = "bf16", +): + """Compile a simple GEMM kernel and return the compiled executable. + + This kernel supports non-aligned M and N dimensions via output boundary checks. + The caller is responsible for: + - Passing M_pad, N_pad, K_pad as the actual tensor dimensions (padded to tile sizes) + - Passing M_orig, N_orig as the original dimensions for boundary checking on output + + Args: + tile_m, tile_n, tile_k: Block tile sizes. + in_dtype: Input data type ("bf16" or "fp16"). + """ + if in_dtype not in ("bf16", "fp16"): + raise ValueError(f"in_dtype must be 'bf16' or 'fp16', got {in_dtype!r}") + + is_bf16 = in_dtype == "bf16" + elem_bytes = 2 # bf16 and fp16 are both 2 bytes + + # Validate tile configuration + tile_k_bytes = tile_k * elem_bytes + if tile_k_bytes % 64 != 0: + raise ValueError( + f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes}" + ) + + gpu_arch = get_hip_arch() + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + DYN = ir.ShapedType.get_dynamic_size() + total_threads = 256 + + # LDS configuration: tile_m × tile_k elements + lds_stride = tile_k # No padding for simplicity + + # Type helpers + def _elem_type(): + return T.bf16 if is_bf16 else T.f16 + + def _vec8_type(): + """16B vector (8 bf16/fp16 elements).""" + return T.bf16x8 if is_bf16 else T.f16x8 + + def _out_type(): + """Output element type.""" + return T.bf16 if is_bf16 else T.f16 + + module_name = f"simple_gemm_{in_dtype}_t{tile_m}x{tile_n}x{tile_k}".replace("-", "_") + + class _GEMM(flir.MlirModule): + GPU_MODULE_NAME = module_name + GPU_MODULE_TARGETS = [ + f'#rocdl.target' + ] + + def init_gpu_module(self): + # Allocate LDS for A tile: tile_m × tile_k elements + lds_a_elems = tile_m * lds_stride + _state["lds_a_decl"] = allocator.allocate_array(_elem_type(), lds_a_elems) + allocator.finalize() + + @flir.kernel + def kernel_gemm( + self: flir.T.i64, + arg_c: lambda: memref(DYN, _out_type()), + arg_a: lambda: memref(DYN, _elem_type()), + arg_b: lambda: memref(DYN, _elem_type()), + c_m_pad: lambda: T.index, # Padded M dimension (tensor size) + c_n_pad: lambda: T.index, # Padded N dimension (tensor size) + c_k_pad: lambda: T.index, # Padded K dimension (tensor size) + c_m_orig: lambda: T.index, # Original M dimension (for boundary check) + c_n_orig: lambda: T.index, # Original N dimension (for boundary check) + ): + # ================= Types ================= + f32 = T.f32 + i32 = T.i32 + i64 = T.i64 + vec4_f32 = T.f32x4 + vec4_i16 = T.i16x4 + vec4_f16 = T.f16x4 + vec8_elem = _vec8_type() + vec1_i64 = T.vec(1, i64) + vec2_i64 = T.vec(2, i64) + + # Accumulator initialization + acc_init = arith.constant_vector(0.0, vec4_f32) + + # ================= Layouts ================= + # A layout: [M_pad, K_pad] row-major + layout_a = flir.make_layout((c_m_pad, c_k_pad), stride=(c_k_pad, 1)) + + # B layout: [N_pad, K_pad] row-major (B^T in standard GEMM) + layout_b = flir.make_layout((c_n_pad, c_k_pad), stride=(c_k_pad, 1)) + + # C layout: [M_pad, N_pad] row-major + layout_c = flir.make_layout((c_m_pad, c_n_pad), stride=(c_n_pad, 1)) + + # LDS layout: [tile_m, tile_k] + shape_lds = flir.make_shape(tile_m, tile_k) + stride_lds = flir.make_stride(lds_stride, 1) + layout_lds = flir.make_layout(shape_lds, stride_lds) + + # XOR16 swizzle parameter (in 16-byte blocks) + k_blocks16 = arith.constant(tile_k_bytes // 16, index=True) + + # ================= Thread/Block IDs ================= + tx = gpu.thread_id("x") + bx = gpu.block_id("x") # M dimension + by = gpu.block_id("y") # N dimension + + # Base addresses for this block + bx_m = bx * arith.constant(tile_m, index=True) + by_n = by * arith.constant(tile_n, index=True) + + # ================= Thread Decomposition ================= + # tx -> (wave_id, lane_id) + layout_wave_lane = flir.make_layout((4, 64), stride=(64, 1)) + coord_wave_lane = flir.idx2crd(tx, layout_wave_lane) + wave_id = flir.get(coord_wave_lane, 0) + lane_id = flir.get(coord_wave_lane, 1) + + # lane_id -> (lane_div_16, lane_mod_16) + layout_lane16 = flir.make_layout((4, 16), stride=(16, 1)) + coord_lane16 = flir.idx2crd(lane_id, layout_lane16) + lane_div_16 = flir.get(coord_lane16, 0) + lane_mod_16 = flir.get(coord_lane16, 1) + + # ================= LDS Setup ================= + base_ptr = allocator.get_base() + lds_a_ptr = _state["lds_a_decl"](base_ptr) + lds_a = lds_a_ptr.get() + + # ================= Buffer Resources ================= + a_rsrc = buffer_ops.create_buffer_resource(arg_a, max_size=False) + b_rsrc = buffer_ops.create_buffer_resource(arg_b, max_size=False) + c_rsrc = buffer_ops.create_buffer_resource(arg_c, max_size=False) + + # ================= Wave/Lane Mappings ================= + # For MFMA 16x16x16: + # - A row index: lane_mod_16 (0..15) + # - K pack offset: lane_div_16 * 4 (each lane group handles 4 elements) + row_a_lds = lane_mod_16 + + # K element offset for LDS reads (16 elements per pack, 4 packs per K64) + kpack_elems = 8 # 8 bf16 = 16 bytes + col_offset_base = lane_div_16 * arith.constant(kpack_elems, index=True) + # Convert to bytes for swizzle + col_offset_base_bytes = col_offset_base * arith.constant(elem_bytes, index=True) + + # ================= Tile Configuration ================= + m_repeat = tile_m // 16 # Number of M-dimension repeats + k_unroll = tile_k_bytes // 64 # K64-byte micro-steps + num_waves = 4 + n_per_wave = tile_n // num_waves # Columns per wave + num_acc_n = n_per_wave // 16 # Accumulators per wave along N + + # Wave's N tile base + c_n_per_wave = arith.constant(n_per_wave, index=True) + n_tile_base = wave_id * c_n_per_wave + + # ================= A Tile Loading (GM -> LDS) ================= + # 256 threads load tile_m × tile_k elements (16B per thread) + bytes_a_per_tile = tile_m * tile_k * elem_bytes + bytes_per_thread_a = bytes_a_per_tile // total_threads + num_a_loads = bytes_per_thread_a // 16 # 16B loads + + # A tile layout in dwords for addressing + tile_k_dwords = (tile_k * elem_bytes) // 4 + layout_a_tile_div4 = flir.make_layout( + (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) + ) + + # Convert K to dwords for A addressing + c_k_div4bytes = (c_k_pad * arith.constant(elem_bytes, index=True)) / arith.constant(4, index=True) + layout_a_div4 = flir.make_layout((c_m_pad, c_k_div4bytes), stride=(c_k_div4bytes, 1)) + + c4 = arith.constant(4, index=True) + tx_i32_base = tx * c4 + + atom_a_g2r = flir.make_copy_atom(_elem_type(), vector_size=8) + atom_a_lds = flir.make_copy_atom(_elem_type(), vector_size=8) + + def a_tile_chunk_coord(i: int): + """Map (thread, chunk_id) -> (row_local, col_local_i32) for A loads.""" + chunk_off = arith.constant(i * total_threads * 4, index=True) + tile_idx = tx_i32_base + chunk_off + coord_local = flir.idx2crd(tile_idx, layout_a_tile_div4) + row_local = flir.get(coord_local, 0) + col_local_i32 = flir.get(coord_local, 1) + return row_local, col_local_i32 + + def load_a_tile(base_k): + """Load A tile from global memory (tile_m × tile_k).""" + base_k_bytes = base_k * arith.constant(elem_bytes, index=True) + base_k_div4 = base_k_bytes / arith.constant(4, index=True) + parts = [] + for i in range_constexpr(num_a_loads): + row_local, col_local_i32 = a_tile_chunk_coord(i) + row_global = bx_m + row_local + coord_a_g = flir.make_coord(row_global, base_k_div4 + col_local_i32) + idx_i32 = flir.crd2idx(coord_a_g, layout_a_div4) + # Convert dword index to element index + idx_elem = idx_i32 * arith.constant(2, index=True) # 2 bf16 per dword + + a_view = flir.TensorView( + arg_a, + (8,), # 8 bf16 elements = 16 bytes + strides=(1,), + base_indices=(idx_elem,), + element_type=_elem_type(), + ) + a_vec = flir.copy( + atom_a_g2r, + a_view, + None, + alignment=16, + return_vector=True, + src_buffer_resource=None, # Use memref directly + ) + parts.append(vector.bitcast(T.i32x4, a_vec)) + return parts + + def store_a_tile_to_lds(a_parts): + """Store A tile to LDS with XOR16 swizzle.""" + for i in range_constexpr(num_a_loads): + row_local, col_local_i32 = a_tile_chunk_coord(i) + # Apply XOR16 swizzle + col_local_bytes = col_local_i32 * c4 + col_swz_bytes = flir.swizzle_xor16(row_local, col_local_bytes, k_blocks16) + col_swz = col_swz_bytes / arith.constant(elem_bytes, index=True) + coord_store = flir.make_coord(row_local, col_swz) + idx0 = flir.crd2idx(coord_store, layout_lds) + v8 = vector.bitcast(vec8_elem, a_parts[i]) + s_view = flir.TensorView( + lds_a, + (8,), + strides=(1,), + base_indices=(idx0,), + element_type=_elem_type(), + ) + flir.copy(atom_a_lds, v8, s_view, alignment=16) + + # ================= B Tile Loading (Direct to GPR) ================= + def load_b_packs_k64(base_k, ku: int, ni: int): + """Load B pack for MFMA (16B -> 2 × i64 for K64-byte step).""" + # Global N index for this wave/lane + n_offset = arith.constant(ni * 16, index=True) + n_global = by_n + n_tile_base + n_offset + lane_mod_16 + + # K index within the K64 block + ki64 = arith.constant(ku * 32, index=True) # 64 bytes = 32 bf16 + k_base = base_k + ki64 + + # lane_div_16 determines which 8 elements to load (0-3 -> 0, 8, 16, 24 offset) + k_lane_offset = lane_div_16 * arith.constant(8, index=True) + k_global = k_base + k_lane_offset + + # Calculate linear index + coord_b = flir.make_coord(n_global, k_global) + idx_b = flir.crd2idx(coord_b, layout_b) + + # Load 8 elements (16 bytes) + b_view = flir.TensorView( + arg_b, + (8,), + strides=(1,), + base_indices=(idx_b,), + element_type=_elem_type(), + ) + b_vec = flir.copy( + atom_a_g2r, # Same atom type + b_view, + None, + alignment=16, + return_vector=True, + src_buffer_resource=None, + ) + + # Split into two i64 halves for two MFMA K16 steps + b_i64x2 = vector.bitcast(vec2_i64, b_vec) + b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) + b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) + + # Convert to MFMA operand type + if is_bf16: + # bf16 uses i16 bit patterns + b0_v1 = vector.from_elements(vec1_i64, [b0_i64]) + b1_v1 = vector.from_elements(vec1_i64, [b1_i64]) + return vector.bitcast(vec4_i16, b0_v1), vector.bitcast(vec4_i16, b1_v1) + else: + # fp16 uses f16 directly + b0_v1 = vector.from_elements(vec1_i64, [b0_i64]) + b1_v1 = vector.from_elements(vec1_i64, [b1_i64]) + return vector.bitcast(vec4_f16, b0_v1), vector.bitcast(vec4_f16, b1_v1) + + def load_b_tile(base_k): + """Load entire B tile for K loop.""" + b_tile = [] + for ku in range_constexpr(k_unroll): + packs0 = [] + packs1 = [] + for ni in range_constexpr(num_acc_n): + b0, b1 = load_b_packs_k64(base_k, ku, ni) + packs0.append(b0) + packs1.append(b1) + b_tile.append((packs0, packs1)) + return b_tile + + # ================= A LDS Load ================= + def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): + """Load A pack from LDS for MFMA (16B -> 2 × i64).""" + # Apply XOR16 swizzle + col_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base_bytes, k_blocks16) + col_swz = col_swz_bytes / arith.constant(elem_bytes, index=True) + coord_a = flir.make_coord(curr_row_a_lds, col_swz) + idx_a = flir.crd2idx(coord_a, layout_lds) + idx_a = idx_a + lds_base + + # Load 8 elements + loaded_a = vector.load_op(vec8_elem, lds_a, [idx_a]) + a_i64x2 = vector.bitcast(vec2_i64, loaded_a) + a0_i64 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) + a1_i64 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) + + # Convert to MFMA operand type + if is_bf16: + a0_v1 = vector.from_elements(vec1_i64, [a0_i64]) + a1_v1 = vector.from_elements(vec1_i64, [a1_i64]) + return vector.bitcast(vec4_i16, a0_v1), vector.bitcast(vec4_i16, a1_v1) + else: + a0_v1 = vector.from_elements(vec1_i64, [a0_i64]) + a1_v1 = vector.from_elements(vec1_i64, [a1_i64]) + return vector.bitcast(vec4_f16, a0_v1), vector.bitcast(vec4_f16, a1_v1) + + # ================= MFMA Computation ================= + mfma_res_ty = vec4_f32 + if is_bf16: + mfma_fn = rocdl.mfma_f32_16x16x16bf16_1k + else: + mfma_fn = rocdl.mfma_f32_16x16x16f16 + + def mfma_step(acc_in, a, b): + return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0]) + + def mfma_k64_bytes(acc_in, a0, a1, b0, b1): + """K64-byte wrapper: two MFMA K16 ops.""" + acc_mid = mfma_step(acc_in, a0, b0) + return mfma_step(acc_mid, a1, b1) + + def compute_tile(accs_in, b_tile_in, lds_base): + """Compute one tile of MFMA operations.""" + current_accs = list(accs_in) + + for ku in range_constexpr(k_unroll): + b_packs0, b_packs1 = b_tile_in[ku] + # K byte offset for this ku + ki64 = ku * 64 # 64 bytes per ku + col_base = col_offset_base_bytes + arith.constant(ki64, index=True) + + for mi in range_constexpr(m_repeat): + mi_val = arith.constant(mi * 16, index=True) + curr_row_a_lds = row_a_lds + mi_val + + # Load A pack from LDS + a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + current_accs[acc_idx] = mfma_k64_bytes( + current_accs[acc_idx], + a0, a1, + b_packs0[ni], b_packs1[ni], + ) + + return current_accs + + # ================= Epilogue (Store C) with boundary checks ================= + def store_output(final_accs): + """Store accumulated results to C with boundary checks.""" + lane_div_16_mul4 = lane_div_16 * arith.constant(4, index=True) + + for mi in range_constexpr(m_repeat): + mi_base = arith.constant(mi * 16, index=True) + for ii in range_constexpr(4): # 4 rows per lane group + ii_idx = arith.constant(ii, index=True) + row_off = lane_div_16_mul4 + ii_idx + row_in_tile = mi_base + row_off + row = bx_m + row_in_tile + + col_base = by_n + n_tile_base + lane_mod_16 + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + acc = final_accs[acc_idx] + val = vector.extract(acc, static_position=[ii], dynamic_position=[]) + + # Convert f32 to output type + val_out = arith.trunc_f(_out_type(), val) + + col = col_base + arith.constant(ni * 16, index=True) + + # Boundary check: row < M_orig and col < N_orig + row_valid = arith.cmpu(row, c_m_orig, "ult") + col_valid = arith.cmpu(col, c_n_orig, "ult") + in_bounds = arith.andi(row_valid, col_valid) + + # Conditional store using IfOp + if_op = scf.IfOp(in_bounds) + with if_op.then(): + coord_c = flir.make_coord(row, col) + idx_c = flir.crd2idx(coord_c, layout_c) + buffer_ops.buffer_store(val_out, c_rsrc, idx_c) + + # ================= Main Pipeline ================= + # Single LDS buffer, simple pipeline + lds_base = arith.constant(0, index=True) + + # Initialize accumulators + accs = [acc_init] * (num_acc_n * m_repeat) + + # K loop + c_tile_k = arith.constant(tile_k, index=True) + for k_base in range(arith.constant(0, index=True), c_k_pad, c_tile_k): + # Load A tile to LDS + a_parts = load_a_tile(k_base) + store_a_tile_to_lds(a_parts) + gpu.barrier() + + # Load B tile directly to GPR + b_tile = load_b_tile(k_base) + + # Compute MFMA + accs = compute_tile(accs, b_tile, lds_base) + + # Barrier before next iteration (if any) + gpu.barrier() + + # Store output (with boundary checks) + store_output(accs) + + @flir.jit + def __call__( + self: flir.T.i64, + arg_c: lambda: memref(DYN, _out_type()), + arg_a: lambda: memref(DYN, _elem_type()), + arg_b: lambda: memref(DYN, _elem_type()), + c_m_pad: lambda: T.index, + c_n_pad: lambda: T.index, + c_k_pad: lambda: T.index, + c_m_orig: lambda: T.index, + c_n_orig: lambda: T.index, + ): + c1 = arith.constant(1, index=True) + bdx = arith.constant(256, index=True) + tm = arith.constant(tile_m, index=True) + tn = arith.constant(tile_n, index=True) + one = arith.constant(1, index=True) + # Use padded dimensions for grid size + gx = (c_m_pad + tm - one) / tm + gy = (c_n_pad + tn - one) / tn + flir.gpu_ext.LaunchFuncOp( + [module_name, "kernel_gemm"], + grid_size=(gx, gy, c1), + block_size=(bdx, c1, c1), + kernel_operands=[ + arg_c, + arg_a, + arg_b, + c_m_pad, + c_n_pad, + c_k_pad, + c_m_orig, + c_n_orig, + ], + ) + + m = _GEMM() + return flydsl.compile(m) + + +def _align_up(val: int, align: int) -> int: + """Round up val to the next multiple of align.""" + return ((val + align - 1) // align) * align + + +def run_simple_gemm( + *, + M: int, + N: int, + K: int, + tile_m: int = 16, + tile_n: int = 64, + tile_k: int = 128, + in_dtype: str = "bf16", + A=None, + B=None, +): + """Run simple GEMM: C = A @ B^T. + + This function supports non-aligned M, N, K dimensions. + - M and N: handled by kernel boundary checks on output stores + - K: padded to multiple of tile_k to ensure correct computation + + Args: + M, N, K: Matrix dimensions (A[M,K], B[N,K], C[M,N]). + tile_m, tile_n, tile_k: Tile sizes. + in_dtype: Input data type ("bf16" or "fp16"). + A: Optional input tensor A[M,K]. If None, creates random tensor. + B: Optional input tensor B[N,K]. If None, creates random tensor. + + Returns: + C: Output tensor C[M,N]. + """ + import torch + + # Determine torch dtype + if in_dtype == "bf16": + torch_dtype = torch.bfloat16 + else: + torch_dtype = torch.float16 + + device = "cuda" + + # Create input tensors if not provided + if A is None: + A = torch.randn(M, K, dtype=torch_dtype, device=device) + if B is None: + B = torch.randn(N, K, dtype=torch_dtype, device=device) + + # Ensure inputs are contiguous and on correct device + A = A.contiguous().to(device=device, dtype=torch_dtype) + B = B.contiguous().to(device=device, dtype=torch_dtype) + + # Pad all dimensions to tile sizes for kernel execution + M_pad = _align_up(M, tile_m) + N_pad = _align_up(N, tile_n) + K_pad = _align_up(K, tile_k) + + # Create padded tensors + A_pad = torch.zeros(M_pad, K_pad, dtype=torch_dtype, device=device) + B_pad = torch.zeros(N_pad, K_pad, dtype=torch_dtype, device=device) + C_pad = torch.zeros(M_pad, N_pad, dtype=torch_dtype, device=device) + + # Copy original data to padded tensors + A_pad[:M, :K] = A + B_pad[:N, :K] = B + + # Compile and run kernel + exe = compile_simple_gemm( + tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + in_dtype=in_dtype, + ) + + # Flatten tensors for kernel interface + A_flat = A_pad.view(-1) + B_flat = B_pad.view(-1) + C_flat = C_pad.view(-1) + + # Pass both padded and original dimensions + exe(C_flat, A_flat, B_flat, M_pad, N_pad, K_pad, M, N) + torch.cuda.synchronize() + + # Extract the actual output (only the M×N portion) + C = C_pad[:M, :N].contiguous() + + return C diff --git a/run.sh b/run.sh new file mode 100755 index 00000000..8e701d6a --- /dev/null +++ b/run.sh @@ -0,0 +1,87 @@ +set -x + +shopt -s expand_aliases + +alias l.='ls -d .* --color=auto' +alias ll='ls -l --color=auto' +alias ls='ls --color=auto' +alias python='python3' + +# export HIP_VISIBLE_DEVICES=0 +# export HIP_VISIBLE_DEVICES=1 +# export HIP_VISIBLE_DEVICES=3 +# export HIP_VISIBLE_DEVICES=5 +export HIP_VISIBLE_DEVICES=6 +# export HIP_VISIBLE_DEVICES=7 + + +# export LD_LIBRARY_PATH=/mnt/raid0/heyanguang/code/poc_kl/scripts/common:$LD_LIBRARY_PATH +# export LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/torch/lib:$LD_LIBRARY_PATH +# export PATH=/mnt/raid0/heyanguang/code/poc_kl/scripts/common:$PATH + +rocm-smi | egrep "$HIP_VISIBLE_DEVICES |Device" +pip show triton +rocprofv3 --version + + +function run_flydsl_op { + export FLIR_LOG_MORE=1 + export FLIR_DUMP_IR=1 + export FLIR_REBUILD=1 + export FLIR_DUMP_DIR=./flydsl_dump + + # python tests/kernels/test_moe_stage1_simple.py --size M + + python tests/kernels/test_simple_gemm.py --size XL + # python tests/kernels/test_simple_gemm.py --size NA4 +} + + +function get_flydsl_op_thread_trace { + pushd $PWD + export KERNEL_NAME=kernel_gemm + KERNEL_VERSION="${KERNEL_NAME}_v0" + + + DUMP_TRACE=1 + # DUMP_TRACE=0 + if [ $DUMP_TRACE = 1 ]; then + rm -rf ./pass_2 + cd ./thread_trace + trace_dir=./${KERNEL_VERSION} + rm -rf ./rpf_v3 + rm -rf ./${trace_dir} ./${trace_dir}.tar.gz + mkdir -p ${trace_dir} + cd - + + rocprofv3 -i ./input.yaml -- \ + python tests/kernels/test_simple_gemm.py --size XL + + cd ./thread_trace + cp -r ./rpf_v3/pass_1/*.att ${trace_dir} + cp -r ./rpf_v3/pass_1/ui_* ${trace_dir} + cp -r ./rpf_v3/pass_1/*_agent_info.csv ${trace_dir} + cp -r ./rpf_v3/pass_1/stats_ui_*.csv ${trace_dir} + tar -zcf ./${trace_dir}.tar.gz ./${trace_dir} + ls -lah ./${trace_dir} ./${trace_dir}.tar.gz + cd - + fi + + popd +} + + +# # Press y then n while install +# ./rocprof-trace-decoder-manylinux-2.28-0.1.6-Linux.sh --prefix=/opt/rocm/ +# cd /opt/rocm/ +# ll -ah ./opt/rocm/lib/librocprof-trace-decoder.so +# ll -ah ./lib/librocprof-trace-decoder.so +# cp opt/rocm/lib/librocprof-trace-decoder.so ./lib/ +# ll -ah ./lib/librocprof-trace-decoder.so + + +run_flydsl_op +get_flydsl_op_thread_trace + + +set +x diff --git a/tests/kernels/test_moe_stage1_simple.py b/tests/kernels/test_moe_stage1_simple.py new file mode 100644 index 00000000..2cbbfa47 --- /dev/null +++ b/tests/kernels/test_moe_stage1_simple.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +""" +Simple test script for run_moe_stage1 (no pytest required). + +Usage: + python tests/kernels/test_moe_stage1_simple.py [--size S|M|L] [--dtype fp8|fp16|int8|int4|all] + +Examples: + python tests/kernels/test_moe_stage1_simple.py # Run Small with fp8 + python tests/kernels/test_moe_stage1_simple.py --size M # Run Medium with fp8 + python tests/kernels/test_moe_stage1_simple.py --dtype all # Run Small with all dtypes + python tests/kernels/test_moe_stage1_simple.py --size L --dtype fp8 # Run Large with fp8 +""" + +import argparse +import os +import sys + +# Ensure repo-local flydsl is used +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +import torch + +# Import run_moe_stage1 from the test file +from tests.kernels.test_moe_gemm import run_moe_stage1 + +# Test configurations (from pytest.param) +TEST_CONFIGS = { + "S": { + "tokens": 64, + "model_dim": 256, + "inter_dim": 128, + "experts": 4, + "topk": 2, + "tile_m": 32, + "tile_n": 64, + "tile_k": 128, + "doweight_stage1": False, + "description": "Small smoke test", + }, + "M": { + "tokens": 128, + "model_dim": 1024, + "inter_dim": 256, + "experts": 8, + "topk": 2, + "tile_m": 64, + "tile_n": 128, + "tile_k": 128, + "doweight_stage1": False, + "description": "Medium realistic test", + }, + "L": { + "tokens": 256, + "model_dim": 4096, + "inter_dim": 2048, + "experts": 17, + "topk": 9, + "tile_m": 64, + "tile_n": 128, + "tile_k": 128, + "doweight_stage1": False, + "description": "Large aiter-style test", + }, +} + +DTYPES = ["fp8", "fp16", "int8", "int4"] + + +def run_test(size: str, in_dtype: str, num_iters: int = 5, num_warmup: int = 2, skip_ref: bool = False): + """Run a single stage1 test.""" + config = TEST_CONFIGS[size] + + print("=" * 70) + print(f"Running MoE Stage1 Test: size={size} ({config['description']}), dtype={in_dtype}") + print(f" tokens={config['tokens']}, model_dim={config['model_dim']}, inter_dim={config['inter_dim']}") + print(f" experts={config['experts']}, topk={config['topk']}") + print(f" tile_m={config['tile_m']}, tile_n={config['tile_n']}, tile_k={config['tile_k']}") + print("=" * 70) + + try: + run_moe_stage1( + tokens=config["tokens"], + model_dim=config["model_dim"], + inter_dim=config["inter_dim"], + experts=config["experts"], + topk=config["topk"], + tile_m=config["tile_m"], + tile_n=config["tile_n"], + tile_k=config["tile_k"], + doweight_stage1=config["doweight_stage1"], + in_dtype=in_dtype, + seed=0, + num_iters=num_iters, + num_warmup=num_warmup, + compare_aiter_ck=False, # Skip aiter comparison by default + moe_sort_mode="torch", # Use torch sorting for portability + skip_ref=skip_ref, + ) + print(f"[PASS] size={size}, dtype={in_dtype}\n") + return True + except Exception as e: + print(f"[FAIL] size={size}, dtype={in_dtype}") + print(f" Error: {e}\n") + import traceback + traceback.print_exc() + return False + + +def main(): + parser = argparse.ArgumentParser( + description="Simple MoE Stage1 test (no pytest)", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--size", "-s", + type=str, + choices=["S", "M", "L", "all"], + default="S", + help="Test size: S (small), M (medium), L (large), or all", + ) + parser.add_argument( + "--dtype", "-d", + type=str, + choices=["fp8", "fp16", "int8", "int4", "all"], + default="fp8", + help="Input data type (default: fp8)", + ) + parser.add_argument( + "--num_iters", "-n", + type=int, + default=100, + help="Number of benchmark iterations (default: 5)", + ) + parser.add_argument( + "--num_warmup", "-w", + type=int, + default=2, + help="Number of warmup iterations (default: 2)", + ) + parser.add_argument( + "--skip_ref", + action="store_true", + help="Skip reference correctness check (benchmark only)", + ) + + args = parser.parse_args() + + # Check CUDA availability + if not torch.cuda.is_available(): + print("ERROR: CUDA/ROCm not available. Cannot run GPU tests.") + sys.exit(1) + + torch.set_default_device("cuda") + + # Determine sizes and dtypes to run + sizes = list(TEST_CONFIGS.keys()) if args.size == "all" else [args.size] + dtypes = DTYPES if args.dtype == "all" else [args.dtype] + + print(f"\nRunning MoE Stage1 tests: sizes={sizes}, dtypes={dtypes}") + print(f"GPU: {torch.cuda.get_device_name(0)}\n") + + results = [] + for size in sizes: + for dtype in dtypes: + passed = run_test( + size=size, + in_dtype=dtype, + num_iters=args.num_iters, + num_warmup=args.num_warmup, + skip_ref=args.skip_ref, + ) + results.append((size, dtype, passed)) + + # Summary + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + passed = sum(1 for _, _, p in results if p) + total = len(results) + for size, dtype, p in results: + status = "PASS" if p else "FAIL" + print(f" [{status}] size={size}, dtype={dtype}") + print(f"\nTotal: {passed}/{total} passed") + + sys.exit(0 if passed == total else 1) + + +if __name__ == "__main__": + main() diff --git a/tests/kernels/test_simple_gemm.py b/tests/kernels/test_simple_gemm.py new file mode 100644 index 00000000..2c5595e4 --- /dev/null +++ b/tests/kernels/test_simple_gemm.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python3 +""" +Simple test script for the simple GEMM kernel. + +Usage: + python tests/kernels/test_simple_gemm.py [--size S|M|L|XL|NA1|NA2|all] [--dtype bf16|fp16|all] + +Examples: + python tests/kernels/test_simple_gemm.py # Run Small with bf16 + python tests/kernels/test_simple_gemm.py --size M # Run Medium with bf16 + python tests/kernels/test_simple_gemm.py --dtype all # Run Small with all dtypes + python tests/kernels/test_simple_gemm.py --size all # Run all sizes with bf16 + python tests/kernels/test_simple_gemm.py --size NA1 # Non-aligned test 1 +""" + +import argparse +import logging +import os +import sys + +# Ensure repo-local flydsl is used +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +import torch + +from kernels.simple_gemm import compile_simple_gemm, run_simple_gemm +from tests.test_common import run_perftest, verify_output + +# Configure logging to show INFO level messages (required for kernel name display) +logging.basicConfig(level=logging.INFO) + +# Test configurations +# Aligned tests: M, N, K are multiples of tile sizes +TEST_CONFIGS = { + "S": { + "M": 16, + "N": 64, + "K": 128, + "tile_m": 16, + "tile_n": 64, + "tile_k": 128, + "description": "Small smoke test (single tile)", + }, + "M": { + "M": 64, + "N": 128, + "K": 256, + "tile_m": 16, + "tile_n": 64, + "tile_k": 128, + "description": "Medium test (multi-tile)", + }, + "L": { + "M": 256, + "N": 512, + "K": 512, + "tile_m": 16, + "tile_n": 64, + "tile_k": 128, + "description": "Large test", + }, + "XL": { + "M": 1280, + "N": 2048, + "K": 128, + "tile_m": 16, + "tile_n": 64, + "tile_k": 128, + "description": "Extra large test", + }, + # Non-aligned tests: M, N, K are NOT multiples of 16 + "NA1": { + "M": 33, # Not aligned to 16 + "N": 87, # Not aligned to 64 + "K": 145, # Not aligned to 128 + "tile_m": 16, + "tile_n": 64, + "tile_k": 128, + "description": "Non-aligned test 1 (M=33, N=87, K=145)", + }, + "NA2": { + "M": 57, # Not aligned to 16 + "N": 123, # Not aligned to 64 + "K": 259, # Not aligned to 128 + "tile_m": 16, + "tile_n": 64, + "tile_k": 128, + "description": "Non-aligned test 2 (M=57, N=123, K=259)", + }, + "NA3": { + "M": 100, # Not aligned to 16 + "N": 200, # Not aligned to 64 + "K": 300, # Not aligned to 128 + "tile_m": 16, + "tile_n": 64, + "tile_k": 128, + "description": "Non-aligned test 3 (M=100, N=200, K=300)", + }, + "NA4": { + "M": 171, # Not aligned to 16 + "N": 333, # Not aligned to 64 + "K": 517, # Not aligned to 128 + "tile_m": 16, + "tile_n": 64, + "tile_k": 128, + "description": "Non-aligned test 4 (M=171, N=333, K=517)", + }, +} + +DTYPES = ["bf16", "fp16"] + + +def get_torch_dtype(in_dtype: str): + """Convert string dtype to torch dtype.""" + if in_dtype == "bf16": + return torch.bfloat16 + elif in_dtype == "fp16": + return torch.float16 + else: + raise ValueError(f"Unknown dtype: {in_dtype}") + + +def _align_up(val: int, align: int) -> int: + """Round up val to the next multiple of align.""" + return ((val + align - 1) // align) * align + + +def run_test( + size: str, + in_dtype: str, + num_iters: int = 100, + num_warmup: int = 5, + skip_ref: bool = False, + rtol: float = 1e-2, + atol: float = 1e-2, +): + """Run a single GEMM test.""" + config = TEST_CONFIGS[size] + M = config["M"] + N = config["N"] + K = config["K"] + tile_m = config["tile_m"] + tile_n = config["tile_n"] + tile_k = config["tile_k"] + + # Pad all dimensions to tile sizes + M_pad = _align_up(M, tile_m) + N_pad = _align_up(N, tile_n) + K_pad = _align_up(K, tile_k) + + print("=" * 70) + print(f"Running Simple GEMM Test: size={size} ({config['description']}), dtype={in_dtype}") + print(f" M={M}, N={N}, K={K}") + print(f" M_pad={M_pad}, N_pad={N_pad}, K_pad={K_pad}") + print(f" tile_m={tile_m}, tile_n={tile_n}, tile_k={tile_k}") + print("=" * 70) + + torch_dtype = get_torch_dtype(in_dtype) + device = "cuda" + + try: + # Create random inputs (original size) + torch.manual_seed(42) + A_orig = torch.randn(M, K, dtype=torch_dtype, device=device) + B_orig = torch.randn(N, K, dtype=torch_dtype, device=device) + + # Run reference computation (using float32 for accuracy) with original dimensions + if not skip_ref: + A_f32 = A_orig.to(torch.float32) + B_f32 = B_orig.to(torch.float32) + C_ref = torch.mm(A_f32, B_f32.T).to(torch_dtype) + + # Create padded tensors + A_pad = torch.zeros(M_pad, K_pad, dtype=torch_dtype, device=device) + B_pad = torch.zeros(N_pad, K_pad, dtype=torch_dtype, device=device) + C_pad = torch.zeros(M_pad, N_pad, dtype=torch_dtype, device=device) + + # Copy original data + A_pad[:M, :K] = A_orig + B_pad[:N, :K] = B_orig + + # Compile kernel + print("Compiling kernel...") + exe = compile_simple_gemm( + tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, + in_dtype=in_dtype, + ) + print("Kernel compiled successfully.") + + # Flatten tensors for kernel interface + A_flat = A_pad.view(-1) + B_flat = B_pad.view(-1) + C_flat = C_pad.view(-1) + C_flat.zero_() + + # Define launch function for run_perftest + def launch(): + exe(C_flat, A_flat, B_flat, M_pad, N_pad, K_pad, M, N) + + # Warmup and benchmark using run_perftest + print(f"Running {num_warmup} warmup + {num_iters} benchmark iterations...") + _, us = run_perftest( + launch, + num_iters=num_iters, + num_warmup=num_warmup, + ) + torch.cuda.synchronize() + + # Calculate TFLOPS + flops = 2 * M * N * K # 2 ops per element (multiply + add) + tflops = flops / (us / 1e6) / 1e12 + + print(f" Time per iteration: {us:.3f} us ({us/1000:.3f} ms)") + print(f" Throughput: {tflops:.2f} TFLOPS") + + # Verify correctness + if not skip_ref: + # Run one more time for correctness check + C_flat.zero_() + exe(C_flat, A_flat, B_flat, M_pad, N_pad, K_pad, M, N) + torch.cuda.synchronize() + # Extract only the M×N portion from the padded output + C_result = C_pad[:M, :N] + + # Check correctness using verify_output + passed = verify_output( + C_result.to(torch.float32), + C_ref.to(torch.float32), + rtol=rtol, + atol=atol, + msg=f"size={size}, dtype={in_dtype}" + ) + + if not passed: + # Print more details for debugging + max_diff = (C_result - C_ref).abs().max().item() + mean_diff = (C_result - C_ref).abs().mean().item() + print(f" Max diff: {max_diff:.6f}") + print(f" Mean diff: {mean_diff:.6f}") + print("\n Sample values (first 4x4):") + print(f" Result:\n{C_result[:4, :4]}") + print(f" Reference:\n{C_ref[:4, :4]}") + return False + + print(f"[PASS] size={size}, dtype={in_dtype}\n") + return True + + except Exception as e: + print(f"[FAIL] size={size}, dtype={in_dtype}") + print(f" Error: {e}\n") + import traceback + traceback.print_exc() + return False + + +def main(): + parser = argparse.ArgumentParser( + description="Simple GEMM kernel test", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + parser.add_argument( + "--size", "-s", + type=str, + choices=list(TEST_CONFIGS.keys()) + ["all", "aligned", "nonaligned"], + default="S", + help="Test size: S/M/L/XL (aligned), NA1/NA2/NA3/NA4 (non-aligned), all, aligned, or nonaligned", + ) + parser.add_argument( + "--dtype", "-d", + type=str, + choices=["bf16", "fp16", "all"], + default="bf16", + help="Input data type (default: bf16)", + ) + parser.add_argument( + "--num_iters", "-n", + type=int, + default=100, + help="Number of benchmark iterations (default: 100)", + ) + parser.add_argument( + "--num_warmup", "-w", + type=int, + default=5, + help="Number of warmup iterations (default: 5)", + ) + parser.add_argument( + "--skip_ref", + action="store_true", + help="Skip reference correctness check (benchmark only)", + ) + parser.add_argument( + "--rtol", + type=float, + default=1e-2, + help="Relative tolerance for correctness check (default: 1e-2)", + ) + parser.add_argument( + "--atol", + type=float, + default=1e-2, + help="Absolute tolerance for correctness check (default: 1e-2)", + ) + + args = parser.parse_args() + + # Check CUDA availability + if not torch.cuda.is_available(): + print("ERROR: CUDA/ROCm not available. Cannot run GPU tests.") + sys.exit(1) + + torch.set_default_device("cuda") + + # Determine sizes and dtypes to run + aligned_sizes = ["S", "M", "L", "XL"] + nonaligned_sizes = ["NA1", "NA2", "NA3", "NA4"] + + if args.size == "all": + sizes = list(TEST_CONFIGS.keys()) + elif args.size == "aligned": + sizes = aligned_sizes + elif args.size == "nonaligned": + sizes = nonaligned_sizes + else: + sizes = [args.size] + + dtypes = DTYPES if args.dtype == "all" else [args.dtype] + + print(f"\nRunning Simple GEMM tests: sizes={sizes}, dtypes={dtypes}") + print(f"GPU: {torch.cuda.get_device_name(0)}\n") + + results = [] + for size in sizes: + for dtype in dtypes: + passed = run_test( + size=size, + in_dtype=dtype, + num_iters=args.num_iters, + num_warmup=args.num_warmup, + skip_ref=args.skip_ref, + rtol=args.rtol, + atol=args.atol, + ) + results.append((size, dtype, passed)) + + # Summary + print("\n" + "=" * 70) + print("SUMMARY") + print("=" * 70) + passed = sum(1 for _, _, p in results if p) + total = len(results) + for size, dtype, p in results: + status = "PASS" if p else "FAIL" + print(f" [{status}] size={size}, dtype={dtype}") + print(f"\nTotal: {passed}/{total} passed") + + sys.exit(0 if passed == total else 1) + + +if __name__ == "__main__": + main() diff --git a/thread_trace/.gitkeep b/thread_trace/.gitkeep new file mode 100644 index 00000000..e69de29b From 7a3d05b9fa9e49a9b39ff45d9a1ca2432fa6a1be Mon Sep 17 00:00:00 2001 From: yanguahe Date: Wed, 4 Feb 2026 18:49:24 +0800 Subject: [PATCH 02/17] Add waves_per_eu support and switch to mask-based boundary handling in simple GEMM - Add waves_per_eu parameter to compiler.compile() for AMDGPU occupancy hints - Implement _apply_waves_per_eu_on_llvm_funcs() to set amdgpu-waves-per-eu attribute on GPU kernel functions via LLVM passthrough - Refactor simple_gemm to use mask-based loads/stores for M/N boundaries instead of host-side padding (Triton-like approach) - Only K dimension is padded on host (required for MFMA vector loads) - Add --waves_per_eu CLI argument to test_simple_gemm.py Co-authored-by: Cursor --- flydsl/src/flydsl/compiler/compiler.py | 95 +++++++- kernels/simple_gemm.py | 288 +++++++++++++------------ run.sh | 8 +- tests/kernels/test_simple_gemm.py | 58 +++-- 4 files changed, 279 insertions(+), 170 deletions(-) diff --git a/flydsl/src/flydsl/compiler/compiler.py b/flydsl/src/flydsl/compiler/compiler.py index b283a682..7e2ec98c 100644 --- a/flydsl/src/flydsl/compiler/compiler.py +++ b/flydsl/src/flydsl/compiler/compiler.py @@ -192,6 +192,60 @@ def _infer_kernel_names_from_asm(asm: str) -> list[str]: return names +def _apply_waves_per_eu_on_llvm_funcs(module: ir.Module, waves_per_eu: int) -> None: + """Apply AMDGPU waves-per-eu hint to llvm.func ops via LLVM passthrough. + + This sets the 'amdgpu-waves-per-eu' attribute on GPU kernel functions, + which hints the LLVM backend about the desired occupancy per EU. + + The passthrough attribute format for LLVM attributes with values is: + ["attribute-name", "attribute-value"] + """ + # For attributes with values, passthrough needs an ArrayAttr with [key, value] + attr_key = ir.StringAttr.get("amdgpu-waves-per-eu") + attr_value = ir.StringAttr.get(f"{waves_per_eu},{waves_per_eu}") + new_entry = ir.ArrayAttr.get([attr_key, attr_value]) + new_entry_str = f"amdgpu-waves-per-eu={waves_per_eu},{waves_per_eu}" + + def _append_passthrough(func_op): + try: + existing = func_op.attributes["passthrough"] + except KeyError: + existing = None + + if existing is None: + func_op.attributes["passthrough"] = ir.ArrayAttr.get([new_entry]) + return + + # Best-effort: if it's not an ArrayAttr-like object, just overwrite. + try: + existing_entries = list(existing) + except TypeError: + func_op.attributes["passthrough"] = ir.ArrayAttr.get([new_entry]) + return + + if any(str(a).strip('"') == new_entry_str for a in existing_entries): + return + func_op.attributes["passthrough"] = ir.ArrayAttr.get(existing_entries + [new_entry]) + + try: + for op in module.body.operations: + if getattr(op, "OPERATION_NAME", None) != "gpu.module": + continue + # gpu.module has a single region with a single block + gpu_module_body = op.regions[0].blocks[0] if hasattr(op, 'regions') else op.body + for inner_op in gpu_module_body.operations: + if getattr(inner_op, "OPERATION_NAME", None) != "llvm.func": + continue + # Check for gpu.kernel attribute (it's a unit attribute) + if "gpu.kernel" not in inner_op.attributes: + continue + _append_passthrough(inner_op) + except Exception: + # Best-effort only. + pass + + def compile( flir_module_or_ir: Union[object, ir.Module], *, @@ -203,6 +257,7 @@ def compile( use_bare_ptr_memref_call_conv: bool = False, use_bare_pointers_for_host: bool = False, use_bare_pointers_for_kernels: bool = False, + waves_per_eu: Optional[int] = None, ) -> Executor: """Compile a FLIR module to an Executor. @@ -330,8 +385,14 @@ def compile( # Dump ISA from the *post-LLVM* module (right before fatbin emission). # This mirrors `tests/utils.py:compile_to_hsaco` and yields readable assembly. + # Also apply waves_per_eu here (after LLVM lowering, before binary generation). + # Match only the top-level reconcile-unrealized-casts, not the one inside gpu.module if frag.strip() == "reconcile-unrealized-casts": - asm_for_isa = stage_asm + # Apply waves_per_eu if specified (BEFORE saving asm_for_isa) + if waves_per_eu is not None: + _apply_waves_per_eu_on_llvm_funcs(module, waves_per_eu) + # Get ASM after applying waves_per_eu + asm_for_isa = module.operation.get_asm(enable_debug_info=True) if asm_for_isa is not None: isa_out = _dump_isa_from_rocdl_module_asm( @@ -344,9 +405,35 @@ def compile( isa_stage = f"{stage_num_base + len(stage_frags):02d}_final_isa" print(f"[flir.compile] dump {isa_stage} -> {isa_out}") else: - pm = PassManager.parse(pipeline, context=ctx) - pm.enable_verifier(bool(verify)) - pm.run(module.operation) + if waves_per_eu is not None: + # When waves_per_eu is specified, we need to split the pipeline + # to apply the attribute after LLVM lowering but before binary generation. + stage_frags = _pipeline_fragments( + chip=chip, + use_bare_ptr_memref_call_conv=use_bare_ptr_memref_call_conv, + use_bare_pointers_for_host=use_bare_pointers_for_host, + use_bare_pointers_for_kernels=use_bare_pointers_for_kernels, + ) + # Run all passes except the last one (gpu-module-to-binary) + pre_binary_frags = stage_frags[:-1] + binary_frag = stage_frags[-1] + + pre_binary_pipeline = f"builtin.module({','.join(pre_binary_frags)})" + pm = PassManager.parse(pre_binary_pipeline, context=ctx) + pm.enable_verifier(bool(verify)) + pm.run(module.operation) + + # Apply waves_per_eu + _apply_waves_per_eu_on_llvm_funcs(module, waves_per_eu) + + # Run the final binary generation pass + pm_binary = PassManager.parse(f"builtin.module({binary_frag})", context=ctx) + pm_binary.enable_verifier(bool(verify)) + pm_binary.run(module.operation) + else: + pm = PassManager.parse(pipeline, context=ctx) + pm.enable_verifier(bool(verify)) + pm.run(module.operation) if print_final_module: print(module) diff --git a/kernels/simple_gemm.py b/kernels/simple_gemm.py index b9c6caa5..83064b60 100644 --- a/kernels/simple_gemm.py +++ b/kernels/simple_gemm.py @@ -6,24 +6,26 @@ - Block: 256 threads = 4 waves × 64 lanes - Tile: M=16 × N=64 × K=128 (configurable) - Currently supports bf16/fp16 input, f32 accumulator, bf16/fp16 output -- Supports non-aligned M, N, K dimensions: - - M and N: kernel boundary checks on output stores - - K: caller pads to multiple of 8 (vector load granularity) + +Non-aligned shape handling (Triton-like approach): +- M and N: mask-based loads/stores in kernel (no host padding needed) +- K: padded to tile_k on host (required for MFMA vector loads) +- num_records_bytes: explicitly set in buffer resource descriptor for hardware OOB A matrix loading: - GM → GPR → LDS: 256 threads cooperatively load the A tile with XOR16 swizzle -- LDS → GPR: Each wave reads 16×32 sub-block for MFMA +- Mask-based: OOB elements load zeros via buffer descriptor bounds checking B matrix loading: - Direct load: Each wave handles 16 columns of N -- Wave 0: N[0:16], Wave 1: N[16:32], Wave 2: N[32:48], Wave 3: N[48:64] +- Mask-based: OOB elements load zeros via buffer descriptor bounds checking Output C matrix (16×64): - Wave 0 → C[0:16, 0:16] - Wave 1 → C[0:16, 16:32] - Wave 2 → C[0:16, 32:48] - Wave 3 → C[0:16, 48:64] -- Boundary check: skip stores for out-of-bounds elements +- Mask-based stores: OOB stores are skipped via select(mask, offset, MAX_OFFSET) """ import functools @@ -35,12 +37,16 @@ from flydsl.utils import SmemAllocator from _mlir import ir -from _mlir.dialects import scf as _scf -from flydsl.dialects.ext import arith, gpu, buffer_ops, vector, rocdl, scf +from flydsl.dialects.ext import arith, gpu, buffer_ops, vector, rocdl from flydsl.lang.ir.types import T, memref +def _align_up(val: int, align: int) -> int: + """Round up val to the next multiple of align.""" + return ((val + align - 1) // align) * align + + @functools.lru_cache(maxsize=1024) def compile_simple_gemm( *, @@ -48,23 +54,25 @@ def compile_simple_gemm( tile_n: int = 64, tile_k: int = 128, in_dtype: str = "bf16", + waves_per_eu: int = None, ): """Compile a simple GEMM kernel and return the compiled executable. - This kernel supports non-aligned M and N dimensions via output boundary checks. - The caller is responsible for: - - Passing M_pad, N_pad, K_pad as the actual tensor dimensions (padded to tile sizes) - - Passing M_orig, N_orig as the original dimensions for boundary checking on output + This kernel supports non-aligned M, N, K dimensions via mask-based loads/stores. + No host-side padding required. Args: tile_m, tile_n, tile_k: Block tile sizes. in_dtype: Input data type ("bf16" or "fp16"). + waves_per_eu: Optional hint for AMDGPU backend about the desired number of waves + per execution unit. This affects occupancy optimization. """ if in_dtype not in ("bf16", "fp16"): raise ValueError(f"in_dtype must be 'bf16' or 'fp16', got {in_dtype!r}") is_bf16 = in_dtype == "bf16" elem_bytes = 2 # bf16 and fp16 are both 2 bytes + out_elem_bytes = 2 # output is also bf16/fp16 # Validate tile configuration tile_k_bytes = tile_k * elem_bytes @@ -115,11 +123,9 @@ def kernel_gemm( arg_c: lambda: memref(DYN, _out_type()), arg_a: lambda: memref(DYN, _elem_type()), arg_b: lambda: memref(DYN, _elem_type()), - c_m_pad: lambda: T.index, # Padded M dimension (tensor size) - c_n_pad: lambda: T.index, # Padded N dimension (tensor size) - c_k_pad: lambda: T.index, # Padded K dimension (tensor size) - c_m_orig: lambda: T.index, # Original M dimension (for boundary check) - c_n_orig: lambda: T.index, # Original N dimension (for boundary check) + c_m: lambda: T.index, # Original M dimension + c_n: lambda: T.index, # Original N dimension + c_k: lambda: T.index, # Original K dimension ): # ================= Types ================= f32 = T.f32 @@ -135,15 +141,33 @@ def kernel_gemm( # Accumulator initialization acc_init = arith.constant_vector(0.0, vec4_f32) + # ================= Buffer sizes in bytes for OOB handling ================= + # A: [M, K] -> M * K * elem_bytes + a_nbytes_idx = c_m * c_k * arith.constant(elem_bytes, index=True) + a_nbytes_i32 = arith.index_cast(i32, a_nbytes_idx) + + # B: [N, K] -> N * K * elem_bytes + b_nbytes_idx = c_n * c_k * arith.constant(elem_bytes, index=True) + b_nbytes_i32 = arith.index_cast(i32, b_nbytes_idx) + + # C: [M, N] -> M * N * out_elem_bytes + c_nbytes_idx = c_m * c_n * arith.constant(out_elem_bytes, index=True) + c_nbytes_i32 = arith.index_cast(i32, c_nbytes_idx) + + # ================= Buffer Resources with explicit sizes ================= + a_rsrc = buffer_ops.create_buffer_resource(arg_a, max_size=False, num_records_bytes=a_nbytes_i32) + b_rsrc = buffer_ops.create_buffer_resource(arg_b, max_size=False, num_records_bytes=b_nbytes_i32) + c_rsrc = buffer_ops.create_buffer_resource(arg_c, max_size=False, num_records_bytes=c_nbytes_i32) + # ================= Layouts ================= - # A layout: [M_pad, K_pad] row-major - layout_a = flir.make_layout((c_m_pad, c_k_pad), stride=(c_k_pad, 1)) + # A layout: [M, K] row-major + layout_a = flir.make_layout((c_m, c_k), stride=(c_k, 1)) - # B layout: [N_pad, K_pad] row-major (B^T in standard GEMM) - layout_b = flir.make_layout((c_n_pad, c_k_pad), stride=(c_k_pad, 1)) + # B layout: [N, K] row-major (B^T in standard GEMM) + layout_b = flir.make_layout((c_n, c_k), stride=(c_k, 1)) - # C layout: [M_pad, N_pad] row-major - layout_c = flir.make_layout((c_m_pad, c_n_pad), stride=(c_n_pad, 1)) + # C layout: [M, N] row-major + layout_c = flir.make_layout((c_m, c_n), stride=(c_n, 1)) # LDS layout: [tile_m, tile_k] shape_lds = flir.make_shape(tile_m, tile_k) @@ -180,11 +204,6 @@ def kernel_gemm( lds_a_ptr = _state["lds_a_decl"](base_ptr) lds_a = lds_a_ptr.get() - # ================= Buffer Resources ================= - a_rsrc = buffer_ops.create_buffer_resource(arg_a, max_size=False) - b_rsrc = buffer_ops.create_buffer_resource(arg_b, max_size=False) - c_rsrc = buffer_ops.create_buffer_resource(arg_c, max_size=False) - # ================= Wave/Lane Mappings ================= # For MFMA 16x16x16: # - A row index: lane_mod_16 (0..15) @@ -208,7 +227,7 @@ def kernel_gemm( c_n_per_wave = arith.constant(n_per_wave, index=True) n_tile_base = wave_id * c_n_per_wave - # ================= A Tile Loading (GM -> LDS) ================= + # ================= A Tile Loading (GM -> LDS) with mask ================= # 256 threads load tile_m × tile_k elements (16B per thread) bytes_a_per_tile = tile_m * tile_k * elem_bytes bytes_per_thread_a = bytes_a_per_tile // total_threads @@ -220,14 +239,10 @@ def kernel_gemm( (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) ) - # Convert K to dwords for A addressing - c_k_div4bytes = (c_k_pad * arith.constant(elem_bytes, index=True)) / arith.constant(4, index=True) - layout_a_div4 = flir.make_layout((c_m_pad, c_k_div4bytes), stride=(c_k_div4bytes, 1)) - c4 = arith.constant(4, index=True) + c8 = arith.constant(8, index=True) tx_i32_base = tx * c4 - atom_a_g2r = flir.make_copy_atom(_elem_type(), vector_size=8) atom_a_lds = flir.make_copy_atom(_elem_type(), vector_size=8) def a_tile_chunk_coord(i: int): @@ -240,34 +255,32 @@ def a_tile_chunk_coord(i: int): return row_local, col_local_i32 def load_a_tile(base_k): - """Load A tile from global memory (tile_m × tile_k).""" - base_k_bytes = base_k * arith.constant(elem_bytes, index=True) - base_k_div4 = base_k_bytes / arith.constant(4, index=True) + """Load A tile from global memory (tile_m × tile_k) with mask.""" parts = [] for i in range_constexpr(num_a_loads): row_local, col_local_i32 = a_tile_chunk_coord(i) row_global = bx_m + row_local - coord_a_g = flir.make_coord(row_global, base_k_div4 + col_local_i32) - idx_i32 = flir.crd2idx(coord_a_g, layout_a_div4) - # Convert dword index to element index - idx_elem = idx_i32 * arith.constant(2, index=True) # 2 bf16 per dword - - a_view = flir.TensorView( - arg_a, - (8,), # 8 bf16 elements = 16 bytes - strides=(1,), - base_indices=(idx_elem,), - element_type=_elem_type(), - ) - a_vec = flir.copy( - atom_a_g2r, - a_view, - None, - alignment=16, - return_vector=True, - src_buffer_resource=None, # Use memref directly - ) - parts.append(vector.bitcast(T.i32x4, a_vec)) + # col_local_i32 is in dwords (4 bytes), convert to elements + col_local_elem = col_local_i32 * arith.constant(2, index=True) # 2 bf16 per dword + k_global = base_k + col_local_elem + + # Calculate linear element offset for buffer_load + # buffer_load expects offset in elements (i32 unit), it will scale to bytes internally + # offset = row_global * K + k_global (in dword units for vec4 i32 load) + offset_elem = row_global * c_k + k_global + # Convert to dword offset (divide by 2 since 2 bf16 per dword) + offset_dword = offset_elem / arith.constant(2, index=True) + offset_i32 = arith.index_cast(i32, offset_dword) + + # Mask: row_global < M and (k_global + 7) < K + row_valid = arith.cmpu(row_global, c_m, "ult") + k_end = k_global + c8 + k_valid = arith.cmpu(k_end, c_k + arith.constant(1, index=True), "ult") + mask = arith.andi(row_valid, k_valid) + + # Load 4 dwords (16 bytes = 8 bf16 elements) with mask + a_i32x4 = buffer_ops.buffer_load(a_rsrc, offset_i32, vec_width=4, dtype=i32, mask=mask) + parts.append(a_i32x4) return parts def store_a_tile_to_lds(a_parts): @@ -290,9 +303,9 @@ def store_a_tile_to_lds(a_parts): ) flir.copy(atom_a_lds, v8, s_view, alignment=16) - # ================= B Tile Loading (Direct to GPR) ================= + # ================= B Tile Loading (Direct to GPR) with mask ================= def load_b_packs_k64(base_k, ku: int, ni: int): - """Load B pack for MFMA (16B -> 2 × i64 for K64-byte step).""" + """Load B pack for MFMA (16B -> 2 × i64 for K64-byte step) with mask.""" # Global N index for this wave/lane n_offset = arith.constant(ni * 16, index=True) n_global = by_n + n_tile_base + n_offset + lane_mod_16 @@ -305,28 +318,25 @@ def load_b_packs_k64(base_k, ku: int, ni: int): k_lane_offset = lane_div_16 * arith.constant(8, index=True) k_global = k_base + k_lane_offset - # Calculate linear index - coord_b = flir.make_coord(n_global, k_global) - idx_b = flir.crd2idx(coord_b, layout_b) - - # Load 8 elements (16 bytes) - b_view = flir.TensorView( - arg_b, - (8,), - strides=(1,), - base_indices=(idx_b,), - element_type=_elem_type(), - ) - b_vec = flir.copy( - atom_a_g2r, # Same atom type - b_view, - None, - alignment=16, - return_vector=True, - src_buffer_resource=None, - ) - - # Split into two i64 halves for two MFMA K16 steps + # Calculate linear element offset for buffer_load + # buffer_load with dtype=i32 scales offset by 4, so we need dword offset + # offset_elem = n_global * K + k_global (in bf16 elements) + # offset_dword = offset_elem / 2 (in i32 dwords) + offset_elem = n_global * c_k + k_global + offset_dword = offset_elem / arith.constant(2, index=True) + offset_i32 = arith.index_cast(i32, offset_dword) + + # Mask: n_global < N and (k_global + 7) < K + n_valid = arith.cmpu(n_global, c_n, "ult") + k_end = k_global + c8 + k_valid = arith.cmpu(k_end, c_k + arith.constant(1, index=True), "ult") + mask = arith.andi(n_valid, k_valid) + + # Load 4 dwords (16 bytes = 8 bf16 elements) with mask + b_i32x4 = buffer_ops.buffer_load(b_rsrc, offset_i32, vec_width=4, dtype=i32, mask=mask) + + # Convert to vec8 bf16/fp16, then split into two i64 halves + b_vec = vector.bitcast(vec8_elem, b_i32x4) b_i64x2 = vector.bitcast(vec2_i64, b_vec) b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) @@ -424,9 +434,9 @@ def compute_tile(accs_in, b_tile_in, lds_base): return current_accs - # ================= Epilogue (Store C) with boundary checks ================= + # ================= Epilogue (Store C) with mask ================= def store_output(final_accs): - """Store accumulated results to C with boundary checks.""" + """Store accumulated results to C with mask-based boundary check.""" lane_div_16_mul4 = lane_div_16 * arith.constant(4, index=True) for mi in range_constexpr(m_repeat): @@ -449,17 +459,19 @@ def store_output(final_accs): col = col_base + arith.constant(ni * 16, index=True) - # Boundary check: row < M_orig and col < N_orig - row_valid = arith.cmpu(row, c_m_orig, "ult") - col_valid = arith.cmpu(col, c_n_orig, "ult") - in_bounds = arith.andi(row_valid, col_valid) + # Calculate linear element offset for buffer_store + # buffer_store expects offset in elements (it will scale by element size) + # offset = row * N + col (in bf16/fp16 elements) + offset_elem = row * c_n + col + offset_i32 = arith.index_cast(i32, offset_elem) - # Conditional store using IfOp - if_op = scf.IfOp(in_bounds) - with if_op.then(): - coord_c = flir.make_coord(row, col) - idx_c = flir.crd2idx(coord_c, layout_c) - buffer_ops.buffer_store(val_out, c_rsrc, idx_c) + # Mask: row < M and col < N + row_valid = arith.cmpu(row, c_m, "ult") + col_valid = arith.cmpu(col, c_n, "ult") + mask = arith.andi(row_valid, col_valid) + + # Store with mask (OOB stores are skipped) + buffer_ops.buffer_store(val_out, c_rsrc, offset_i32, mask=mask) # ================= Main Pipeline ================= # Single LDS buffer, simple pipeline @@ -468,15 +480,17 @@ def store_output(final_accs): # Initialize accumulators accs = [acc_init] * (num_acc_n * m_repeat) - # K loop + # K loop - iterate over K in tile_k steps c_tile_k = arith.constant(tile_k, index=True) - for k_base in range(arith.constant(0, index=True), c_k_pad, c_tile_k): - # Load A tile to LDS + # Calculate number of K iterations needed (ceiling division) + # We iterate through all K blocks, mask handles the boundary + for k_base in range(arith.constant(0, index=True), c_k, c_tile_k): + # Load A tile to LDS (with mask for boundary) a_parts = load_a_tile(k_base) store_a_tile_to_lds(a_parts) gpu.barrier() - # Load B tile directly to GPR + # Load B tile directly to GPR (with mask for boundary) b_tile = load_b_tile(k_base) # Compute MFMA @@ -485,7 +499,7 @@ def store_output(final_accs): # Barrier before next iteration (if any) gpu.barrier() - # Store output (with boundary checks) + # Store output (with mask for boundary) store_output(accs) @flir.jit @@ -494,20 +508,18 @@ def __call__( arg_c: lambda: memref(DYN, _out_type()), arg_a: lambda: memref(DYN, _elem_type()), arg_b: lambda: memref(DYN, _elem_type()), - c_m_pad: lambda: T.index, - c_n_pad: lambda: T.index, - c_k_pad: lambda: T.index, - c_m_orig: lambda: T.index, - c_n_orig: lambda: T.index, + c_m: lambda: T.index, + c_n: lambda: T.index, + c_k: lambda: T.index, ): c1 = arith.constant(1, index=True) bdx = arith.constant(256, index=True) tm = arith.constant(tile_m, index=True) tn = arith.constant(tile_n, index=True) one = arith.constant(1, index=True) - # Use padded dimensions for grid size - gx = (c_m_pad + tm - one) / tm - gy = (c_n_pad + tn - one) / tn + # Grid size: ceiling division for non-aligned M and N + gx = (c_m + tm - one) / tm + gy = (c_n + tn - one) / tn flir.gpu_ext.LaunchFuncOp( [module_name, "kernel_gemm"], grid_size=(gx, gy, c1), @@ -516,21 +528,14 @@ def __call__( arg_c, arg_a, arg_b, - c_m_pad, - c_n_pad, - c_k_pad, - c_m_orig, - c_n_orig, + c_m, + c_n, + c_k, ], ) m = _GEMM() - return flydsl.compile(m) - - -def _align_up(val: int, align: int) -> int: - """Round up val to the next multiple of align.""" - return ((val + align - 1) // align) * align + return flydsl.compile(m, waves_per_eu=waves_per_eu) def run_simple_gemm( @@ -544,12 +549,13 @@ def run_simple_gemm( in_dtype: str = "bf16", A=None, B=None, + waves_per_eu: int = None, ): """Run simple GEMM: C = A @ B^T. - This function supports non-aligned M, N, K dimensions. - - M and N: handled by kernel boundary checks on output stores - - K: padded to multiple of tile_k to ensure correct computation + This function supports non-aligned M, N, K dimensions: + - M and N: handled by kernel mask-based loads/stores (Triton-like approach) + - K: padded to tile_k on host (required for MFMA vector loads) Args: M, N, K: Matrix dimensions (A[M,K], B[N,K], C[M,N]). @@ -557,6 +563,7 @@ def run_simple_gemm( in_dtype: Input data type ("bf16" or "fp16"). A: Optional input tensor A[M,K]. If None, creates random tensor. B: Optional input tensor B[N,K]. If None, creates random tensor. + waves_per_eu: Optional hint for AMDGPU backend about the desired number of waves. Returns: C: Output tensor C[M,N]. @@ -581,36 +588,35 @@ def run_simple_gemm( A = A.contiguous().to(device=device, dtype=torch_dtype) B = B.contiguous().to(device=device, dtype=torch_dtype) - # Pad all dimensions to tile sizes for kernel execution - M_pad = _align_up(M, tile_m) - N_pad = _align_up(N, tile_n) + # Pad K to tile_k (required for MFMA vector loads) + # M and N are handled by kernel mask-based boundary checks K_pad = _align_up(K, tile_k) - - # Create padded tensors - A_pad = torch.zeros(M_pad, K_pad, dtype=torch_dtype, device=device) - B_pad = torch.zeros(N_pad, K_pad, dtype=torch_dtype, device=device) - C_pad = torch.zeros(M_pad, N_pad, dtype=torch_dtype, device=device) - - # Copy original data to padded tensors - A_pad[:M, :K] = A - B_pad[:N, :K] = B + if K_pad != K: + A_pad = torch.zeros(M, K_pad, dtype=torch_dtype, device=device) + B_pad = torch.zeros(N, K_pad, dtype=torch_dtype, device=device) + A_pad[:, :K] = A + B_pad[:, :K] = B + A = A_pad + B = B_pad + K = K_pad + + # Create output tensor (original size, no padding needed for M and N) + C = torch.zeros(M, N, dtype=torch_dtype, device=device) # Compile and run kernel exe = compile_simple_gemm( tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, in_dtype=in_dtype, + waves_per_eu=waves_per_eu, ) # Flatten tensors for kernel interface - A_flat = A_pad.view(-1) - B_flat = B_pad.view(-1) - C_flat = C_pad.view(-1) + A_flat = A.view(-1) + B_flat = B.view(-1) + C_flat = C.view(-1) - # Pass both padded and original dimensions - exe(C_flat, A_flat, B_flat, M_pad, N_pad, K_pad, M, N) + # Pass dimensions (K is now padded, M and N are original) + exe(C_flat, A_flat, B_flat, M, N, K) torch.cuda.synchronize() - # Extract the actual output (only the M×N portion) - C = C_pad[:M, :N].contiguous() - return C diff --git a/run.sh b/run.sh index 8e701d6a..b547cbcd 100755 --- a/run.sh +++ b/run.sh @@ -25,6 +25,7 @@ rocprofv3 --version function run_flydsl_op { + export MLIR_ASM_VERBOSE=1 export FLIR_LOG_MORE=1 export FLIR_DUMP_IR=1 export FLIR_REBUILD=1 @@ -32,7 +33,7 @@ function run_flydsl_op { # python tests/kernels/test_moe_stage1_simple.py --size M - python tests/kernels/test_simple_gemm.py --size XL + python tests/kernels/test_simple_gemm.py --size XL --waves_per_eu 1 # python tests/kernels/test_simple_gemm.py --size NA4 } @@ -55,7 +56,8 @@ function get_flydsl_op_thread_trace { cd - rocprofv3 -i ./input.yaml -- \ - python tests/kernels/test_simple_gemm.py --size XL + python tests/kernels/test_simple_gemm.py --size XL --waves_per_eu 1 + # python tests/kernels/test_simple_gemm.py --size XL cd ./thread_trace cp -r ./rpf_v3/pass_1/*.att ${trace_dir} @@ -81,7 +83,7 @@ function get_flydsl_op_thread_trace { run_flydsl_op -get_flydsl_op_thread_trace +# get_flydsl_op_thread_trace set +x diff --git a/tests/kernels/test_simple_gemm.py b/tests/kernels/test_simple_gemm.py index 2c5595e4..e17127c7 100644 --- a/tests/kernels/test_simple_gemm.py +++ b/tests/kernels/test_simple_gemm.py @@ -3,7 +3,7 @@ Simple test script for the simple GEMM kernel. Usage: - python tests/kernels/test_simple_gemm.py [--size S|M|L|XL|NA1|NA2|all] [--dtype bf16|fp16|all] + python tests/kernels/test_simple_gemm.py [--size S|M|L|XL|NA1|NA2|all] [--dtype bf16|fp16|all] [--waves_per_eu N] Examples: python tests/kernels/test_simple_gemm.py # Run Small with bf16 @@ -11,6 +11,7 @@ python tests/kernels/test_simple_gemm.py --dtype all # Run Small with all dtypes python tests/kernels/test_simple_gemm.py --size all # Run all sizes with bf16 python tests/kernels/test_simple_gemm.py --size NA1 # Non-aligned test 1 + python tests/kernels/test_simple_gemm.py --waves_per_eu 2 # Set waves per EU hint to 2 """ import argparse @@ -135,6 +136,7 @@ def run_test( skip_ref: bool = False, rtol: float = 1e-2, atol: float = 1e-2, + waves_per_eu: int = None, ): """Run a single GEMM test.""" config = TEST_CONFIGS[size] @@ -145,15 +147,12 @@ def run_test( tile_n = config["tile_n"] tile_k = config["tile_k"] - # Pad all dimensions to tile sizes - M_pad = _align_up(M, tile_m) - N_pad = _align_up(N, tile_n) + # K must be padded to tile_k for MFMA vector loads K_pad = _align_up(K, tile_k) print("=" * 70) print(f"Running Simple GEMM Test: size={size} ({config['description']}), dtype={in_dtype}") - print(f" M={M}, N={N}, K={K}") - print(f" M_pad={M_pad}, N_pad={N_pad}, K_pad={K_pad}") + print(f" M={M}, N={N}, K={K} (K_pad={K_pad})") print(f" tile_m={tile_m}, tile_n={tile_n}, tile_k={tile_k}") print("=" * 70) @@ -166,38 +165,45 @@ def run_test( A_orig = torch.randn(M, K, dtype=torch_dtype, device=device) B_orig = torch.randn(N, K, dtype=torch_dtype, device=device) - # Run reference computation (using float32 for accuracy) with original dimensions + # Run reference computation (using float32 for accuracy) with original K if not skip_ref: A_f32 = A_orig.to(torch.float32) B_f32 = B_orig.to(torch.float32) C_ref = torch.mm(A_f32, B_f32.T).to(torch_dtype) - # Create padded tensors - A_pad = torch.zeros(M_pad, K_pad, dtype=torch_dtype, device=device) - B_pad = torch.zeros(N_pad, K_pad, dtype=torch_dtype, device=device) - C_pad = torch.zeros(M_pad, N_pad, dtype=torch_dtype, device=device) - - # Copy original data - A_pad[:M, :K] = A_orig - B_pad[:N, :K] = B_orig + # Pad K for kernel (M and N are handled by kernel mask-based boundary checks) + if K_pad != K: + A = torch.zeros(M, K_pad, dtype=torch_dtype, device=device) + B = torch.zeros(N, K_pad, dtype=torch_dtype, device=device) + A[:, :K] = A_orig + B[:, :K] = B_orig + else: + A = A_orig + B = B_orig + + # Create output tensor (original size, no padding needed for M and N) + C = torch.zeros(M, N, dtype=torch_dtype, device=device) # Compile kernel print("Compiling kernel...") + if waves_per_eu is not None: + print(f" waves_per_eu={waves_per_eu}") exe = compile_simple_gemm( tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, in_dtype=in_dtype, + waves_per_eu=waves_per_eu, ) print("Kernel compiled successfully.") # Flatten tensors for kernel interface - A_flat = A_pad.view(-1) - B_flat = B_pad.view(-1) - C_flat = C_pad.view(-1) + A_flat = A.view(-1) + B_flat = B.view(-1) + C_flat = C.view(-1) C_flat.zero_() # Define launch function for run_perftest def launch(): - exe(C_flat, A_flat, B_flat, M_pad, N_pad, K_pad, M, N) + exe(C_flat, A_flat, B_flat, M, N, K_pad) # Warmup and benchmark using run_perftest print(f"Running {num_warmup} warmup + {num_iters} benchmark iterations...") @@ -219,10 +225,9 @@ def launch(): if not skip_ref: # Run one more time for correctness check C_flat.zero_() - exe(C_flat, A_flat, B_flat, M_pad, N_pad, K_pad, M, N) + exe(C_flat, A_flat, B_flat, M, N, K_pad) torch.cuda.synchronize() - # Extract only the M×N portion from the padded output - C_result = C_pad[:M, :N] + C_result = C # Check correctness using verify_output passed = verify_output( @@ -304,6 +309,12 @@ def main(): default=1e-2, help="Absolute tolerance for correctness check (default: 1e-2)", ) + parser.add_argument( + "--waves_per_eu", + type=int, + default=None, + help="AMDGPU waves-per-eu hint for occupancy optimization (e.g., 1, 2, 4)", + ) args = parser.parse_args() @@ -330,6 +341,8 @@ def main(): dtypes = DTYPES if args.dtype == "all" else [args.dtype] print(f"\nRunning Simple GEMM tests: sizes={sizes}, dtypes={dtypes}") + if args.waves_per_eu is not None: + print(f"waves_per_eu: {args.waves_per_eu}") print(f"GPU: {torch.cuda.get_device_name(0)}\n") results = [] @@ -343,6 +356,7 @@ def main(): skip_ref=args.skip_ref, rtol=args.rtol, atol=args.atol, + waves_per_eu=args.waves_per_eu, ) results.append((size, dtype, passed)) From 5e1134296fd8b1fdef0bdab201b4837fac29e888 Mon Sep 17 00:00:00 2001 From: yanguahe Date: Wed, 4 Feb 2026 22:42:21 +0800 Subject: [PATCH 03/17] Fix hardware OOB handling in buffer ops to match Triton implementation - Add OOB_OFFSET (0x80000000) and MAX_NUM_RECORDS (0x7FFFFFFE) constants that match Triton's BufferOpsEmitter for reliable hardware OOB detection - Update buffer load/store to use OOB_OFFSET for masked-out elements, ensuring hardware always detects OOB when mask=False - Simplify GEMM kernel masking by removing redundant K boundary checks since K dimension is guaranteed to be padded to tile_k - Enable additional test cases in run.sh Co-authored-by: Cursor --- flydsl/src/flydsl/dialects/ext/buffer_ops.py | 62 +++++++++++++++----- kernels/simple_gemm.py | 15 ++--- run.sh | 3 +- 3 files changed, 52 insertions(+), 28 deletions(-) diff --git a/flydsl/src/flydsl/dialects/ext/buffer_ops.py b/flydsl/src/flydsl/dialects/ext/buffer_ops.py index 84170e9b..c2c5e895 100644 --- a/flydsl/src/flydsl/dialects/ext/buffer_ops.py +++ b/flydsl/src/flydsl/dialects/ext/buffer_ops.py @@ -40,6 +40,25 @@ 'i32_select', ] +# ============================================================================= +# Constants for Hardware OOB (Out-of-Bounds) Handling +# ============================================================================= +# These values are chosen to match Triton's implementation for reliable hardware +# OOB detection in AMD buffer load/store operations. +# +# Reference: triton/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp +# - OOB_OFFSET = static_cast(std::numeric_limits::max() + int64_t(1)) +# - numRecordsByte = std::numeric_limits::max() - 1 +# +# How it works: +# - When mask=False, offset is replaced with OOB_OFFSET (0x80000000) +# - Hardware compares: if (offset >= num_records) -> return 0 (load) or ignore (store) +# - 0x80000000 (as unsigned) = 2147483648 > 0x7FFFFFFE = 2147483646 +# - This guarantees hardware OOB detection triggers for masked-out elements +# ============================================================================= +OOB_OFFSET = 0x80000000 # -2147483648 as signed i32, 2147483648 as unsigned +MAX_NUM_RECORDS = 0x7FFFFFFE # 2147483646 (std::numeric_limits::max() - 1) + def create_llvm_ptr(value, address_space: int = 0) -> ir.Value: """Convert an index value to LLVM pointer. @@ -195,34 +214,41 @@ def _num_records_from_memref_type() -> Optional[int]: if num_records_bytes is not None: # Caller-provided size in BYTES (preferred for exact hardware OOB behavior). + # NOTE: When using masks, num_records should not exceed MAX_NUM_RECORDS + # to ensure OOB_OFFSET always triggers hardware OOB detection. if isinstance(num_records_bytes, int): nbytes = int(num_records_bytes) if nbytes <= 0: nbytes = 0 - # Descriptor uses i32 bytes; clamp to the max representable. - if nbytes > 0xFFFFFFFF: - nbytes = 0xFFFFFFFF + # Clamp to MAX_NUM_RECORDS to ensure OOB_OFFSET works correctly. + if nbytes > MAX_NUM_RECORDS: + nbytes = MAX_NUM_RECORDS num_records = _create_i32_constant(nbytes) else: # Value path: cast to i32 if needed. + # Note: For dynamic values, we trust the caller to provide valid sizes. + # If the buffer is larger than MAX_NUM_RECORDS, OOB detection may not + # work correctly for masked loads/stores. v = _unwrap_value(num_records_bytes) if not isinstance(v.type, ir.IntegerType) or v.type.width != 32: op = std_arith.IndexCastOp(ir.IntegerType.get_signless(32), v) v = _unwrap_value(op.result) num_records = v elif max_size: - # Use max for flexibility (hardware will check actual bounds) - # Note: flir's rocdl.make.buffer.rsrc requires i32, not i64 - num_records = _create_i32_constant(0xFFFFFFFF) # FALLBACK_MAX_SIZE + # Use MAX_NUM_RECORDS for flexibility with proper OOB handling. + # This value (0x7FFFFFFE) ensures that OOB_OFFSET (0x80000000) will + # always trigger hardware OOB detection. + num_records = _create_i32_constant(MAX_NUM_RECORDS) else: # Use the logical memref size (in bytes) for hardware OOB checking. nbytes = _num_records_from_memref_type() if nbytes is None: - # Fall back to max-size if we can't infer statically. - num_records = _create_i32_constant(0xFFFFFFFF) + # Fall back to MAX_NUM_RECORDS if we can't infer statically. + num_records = _create_i32_constant(MAX_NUM_RECORDS) else: - if nbytes > 0xFFFFFFFF: - nbytes = 0xFFFFFFFF + # Clamp to MAX_NUM_RECORDS for proper OOB handling with masks. + if nbytes > MAX_NUM_RECORDS: + nbytes = MAX_NUM_RECORDS num_records = _create_i32_constant(int(nbytes)) # Create resource descriptor (returns !llvm.ptr<8>) @@ -312,11 +338,13 @@ def buffer_load(rsrc: ir.Value, op = std_arith.MulIOp(offset, bytes_const) offset = _unwrap_value(op.result) - # Apply mask by setting invalid offsets to max + # Apply mask by setting invalid offsets to OOB_OFFSET + # When mask=False, offset becomes OOB_OFFSET (0x80000000), which is always + # >= MAX_NUM_RECORDS (0x7FFFFFFE), triggering hardware OOB (returns 0). if mask is not None: mask = _unwrap_value(mask) - max_offset = _create_i32_constant(0x7FFFFFFF) - op = std_arith.SelectOp(mask, offset, max_offset) + oob_offset = _create_i32_constant(OOB_OFFSET) + op = std_arith.SelectOp(mask, offset, oob_offset) offset = _unwrap_value(op.result) # Create vector type @@ -400,11 +428,13 @@ def buffer_store(data: ir.Value, op = std_arith.MulIOp(offset, bytes_const) offset = _unwrap_value(op.result) - # Apply mask by setting invalid offsets to max + # Apply mask by setting invalid offsets to OOB_OFFSET + # When mask=False, offset becomes OOB_OFFSET (0x80000000), which is always + # >= MAX_NUM_RECORDS (0x7FFFFFFE), triggering hardware OOB (store ignored). if mask is not None: mask = _unwrap_value(mask) - max_offset = _create_i32_constant(0x7FFFFFFF) - op = std_arith.SelectOp(mask, offset, max_offset) + oob_offset = _create_i32_constant(OOB_OFFSET) + op = std_arith.SelectOp(mask, offset, oob_offset) offset = _unwrap_value(op.result) # Create instruction offset (soffset) and aux flags diff --git a/kernels/simple_gemm.py b/kernels/simple_gemm.py index 83064b60..0adde3f4 100644 --- a/kernels/simple_gemm.py +++ b/kernels/simple_gemm.py @@ -240,7 +240,6 @@ def kernel_gemm( ) c4 = arith.constant(4, index=True) - c8 = arith.constant(8, index=True) tx_i32_base = tx * c4 atom_a_lds = flir.make_copy_atom(_elem_type(), vector_size=8) @@ -272,14 +271,11 @@ def load_a_tile(base_k): offset_dword = offset_elem / arith.constant(2, index=True) offset_i32 = arith.index_cast(i32, offset_dword) - # Mask: row_global < M and (k_global + 7) < K + # Mask: row_global < M (K is guaranteed to be padded to tile_k) row_valid = arith.cmpu(row_global, c_m, "ult") - k_end = k_global + c8 - k_valid = arith.cmpu(k_end, c_k + arith.constant(1, index=True), "ult") - mask = arith.andi(row_valid, k_valid) # Load 4 dwords (16 bytes = 8 bf16 elements) with mask - a_i32x4 = buffer_ops.buffer_load(a_rsrc, offset_i32, vec_width=4, dtype=i32, mask=mask) + a_i32x4 = buffer_ops.buffer_load(a_rsrc, offset_i32, vec_width=4, dtype=i32, mask=row_valid) parts.append(a_i32x4) return parts @@ -326,14 +322,11 @@ def load_b_packs_k64(base_k, ku: int, ni: int): offset_dword = offset_elem / arith.constant(2, index=True) offset_i32 = arith.index_cast(i32, offset_dword) - # Mask: n_global < N and (k_global + 7) < K + # Mask: n_global < N (K is guaranteed to be padded to tile_k) n_valid = arith.cmpu(n_global, c_n, "ult") - k_end = k_global + c8 - k_valid = arith.cmpu(k_end, c_k + arith.constant(1, index=True), "ult") - mask = arith.andi(n_valid, k_valid) # Load 4 dwords (16 bytes = 8 bf16 elements) with mask - b_i32x4 = buffer_ops.buffer_load(b_rsrc, offset_i32, vec_width=4, dtype=i32, mask=mask) + b_i32x4 = buffer_ops.buffer_load(b_rsrc, offset_i32, vec_width=4, dtype=i32, mask=n_valid) # Convert to vec8 bf16/fp16, then split into two i64 halves b_vec = vector.bitcast(vec8_elem, b_i32x4) diff --git a/run.sh b/run.sh index b547cbcd..34eb7ba8 100755 --- a/run.sh +++ b/run.sh @@ -34,7 +34,8 @@ function run_flydsl_op { # python tests/kernels/test_moe_stage1_simple.py --size M python tests/kernels/test_simple_gemm.py --size XL --waves_per_eu 1 - # python tests/kernels/test_simple_gemm.py --size NA4 + python tests/kernels/test_simple_gemm.py --size NA4 + python tests/kernels/test_simple_gemm.py --size all --dtype all } From 516fe7d4d68af5f31d062245c2f53b349f237cad Mon Sep 17 00:00:00 2001 From: yanguahe Date: Fri, 6 Feb 2026 20:59:53 +0800 Subject: [PATCH 04/17] Add unsafe_fp_math and fast_fp_math compiler options for faster GPU math - Add unsafe_fp_math and fast_fp_math parameters to compiler pipeline - Replace __ocml_exp2_f32 library calls with llvm.intr.exp2 intrinsics - Apply unsafe-fp-math function attributes to GPU kernel llvm.func ops - Add fastmath parameter support to arith.maximum operation - Improve test reproducibility with seed control and MD5 hash comparison - Add detailed array comparison utility for debugging numerical differences Co-authored-by: Cursor --- flydsl/src/flydsl/compiler/compiler.py | 139 +++++++++++++++++++++- flydsl/src/flydsl/dialects/ext/arith.py | 11 +- tests/kernels/test_simple_gemm.py | 148 +++++++++++++++++++++++- 3 files changed, 285 insertions(+), 13 deletions(-) diff --git a/flydsl/src/flydsl/compiler/compiler.py b/flydsl/src/flydsl/compiler/compiler.py index 7e2ec98c..e3eb320c 100644 --- a/flydsl/src/flydsl/compiler/compiler.py +++ b/flydsl/src/flydsl/compiler/compiler.py @@ -37,6 +37,8 @@ def _pipeline_fragments( use_bare_ptr_memref_call_conv: bool = False, use_bare_pointers_for_host: bool = False, use_bare_pointers_for_kernels: bool = False, + unsafe_fp_math: bool = False, + fast_fp_math: bool = False, ) -> list[str]: """FLIR compilation pipeline fragments as a plain list of strings. @@ -50,6 +52,8 @@ def _pipeline_fragments( rocdl_bare_ptr_opt = b2s(use_bare_ptr_memref_call_conv) llvm_bare_host_opt = b2s(use_bare_pointers_for_host) llvm_bare_kern_opt = b2s(use_bare_pointers_for_kernels) + unsafe_math_opt = b2s(unsafe_fp_math) + fast_opt = b2s(fast_fp_math) return [ "flir-to-standard", "trivial-dce", @@ -63,7 +67,7 @@ def _pipeline_fragments( "gpu.module(reconcile-unrealized-casts)", # Keep this as a formatted string so the chip is visible in dumps and matches # the non-dump compilation pipeline. - f"rocdl-attach-target{{O=2 abi=600 chip={chip} correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}}", + f"rocdl-attach-target{{O=2 abi=600 chip={chip} correct-sqrt=true daz=false fast={fast_opt} features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math={unsafe_math_opt} wave64=true}}", "gpu-to-llvm{intersperse-sizes-for-kernels=false " + f"use-bare-pointers-for-host={llvm_bare_host_opt} " + f"use-bare-pointers-for-kernels={llvm_bare_kern_opt}" @@ -92,6 +96,8 @@ def _build_pipeline_str( use_bare_ptr_memref_call_conv: bool = False, use_bare_pointers_for_host: bool = False, use_bare_pointers_for_kernels: bool = False, + unsafe_fp_math: bool = False, + fast_fp_math: bool = False, ) -> str: """Build the full PassManager pipeline string from `_pipeline_fragments`.""" frags = _pipeline_fragments( @@ -99,6 +105,8 @@ def _build_pipeline_str( use_bare_ptr_memref_call_conv=use_bare_ptr_memref_call_conv, use_bare_pointers_for_host=use_bare_pointers_for_host, use_bare_pointers_for_kernels=use_bare_pointers_for_kernels, + unsafe_fp_math=unsafe_fp_math, + fast_fp_math=fast_fp_math, ) return f"builtin.module({','.join(frags)})" @@ -192,6 +200,101 @@ def _infer_kernel_names_from_asm(asm: str) -> list[str]: return names +def _replace_ocml_exp2_with_intrinsic(module: ir.Module) -> ir.Module: + """Replace __ocml_exp2_f32 library calls with llvm.intr.exp2 intrinsics. + + The convert-gpu-to-rocdl pass lowers math.exp2 to __ocml_exp2_f32 which + generates a safe but slow 6-instruction pattern. By replacing with + llvm.intr.exp2 + fast math flags, we get bare v_exp_f32 (1 instruction). + + Returns a new module (or the original if replacement fails). + """ + import re + + try: + asm = module.operation.get_asm(enable_debug_info=True) + + # First replace all call sites, then remove the declaration. + # Use a broad pattern that handles loc() info and whitespace variants. + asm = re.sub( + r'llvm\.call @__ocml_exp2_f32\(([^)]+)\)\s*:\s*\(f32\)\s*->\s*f32', + r'llvm.intr.exp2(\1) {fastmathFlags = #llvm.fastmath} : (f32) -> f32', + asm, + ) + + # Remove the function declaration (it may have loc() info) + asm = re.sub( + r'\s*llvm\.func @__ocml_exp2_f32\(f32\)\s*->\s*f32[^\n]*\n', + '\n', + asm, + ) + + ctx = module.context + new_module = ir.Module.parse(asm, context=ctx) + return new_module + except Exception as e: + import sys + print(f"[flir.compile] WARNING: _replace_ocml_exp2_with_intrinsic failed: {e}", file=sys.stderr) + return module + + +def _apply_unsafe_fp_math_on_llvm_funcs(module: ir.Module) -> None: + """Apply 'unsafe-fp-math'='true' function attribute to GPU kernel llvm.func ops. + + This tells the LLVM AMDGPU backend to use fast/approximate math lowerings, + e.g. bare v_exp_f32 instead of the safe range-reduced exp2 pattern. + """ + entries = [] + for attr_name in ("unsafe-fp-math", "no-nans-fp-math", "no-infs-fp-math"): + key = ir.StringAttr.get(attr_name) + val = ir.StringAttr.get("true") + entries.append(ir.ArrayAttr.get([key, val])) + # Flush f32 denormals to zero so the AMDGPU backend emits bare v_exp_f32 + # instead of a safe exp2 pattern with range-checking / v_ldexp_f32. + key_denorm = ir.StringAttr.get("denormal-fp-math-f32") + val_denorm = ir.StringAttr.get("preserve-sign,preserve-sign") + entries.append(ir.ArrayAttr.get([key_denorm, val_denorm])) + entries_strs = {f"{n}=true" for n in ("unsafe-fp-math", "no-nans-fp-math", "no-infs-fp-math")} + entries_strs.add("denormal-fp-math-f32=preserve-sign,preserve-sign") + + def _append_passthrough(func_op): + try: + existing = func_op.attributes["passthrough"] + except KeyError: + existing = None + + if existing is None: + func_op.attributes["passthrough"] = ir.ArrayAttr.get(entries) + return + + try: + existing_entries = list(existing) + except TypeError: + func_op.attributes["passthrough"] = ir.ArrayAttr.get(entries) + return + + existing_strs = {str(a).strip('"') for a in existing_entries} + new_entries = list(existing_entries) + for entry, entry_str in zip(entries, entries_strs): + if entry_str not in existing_strs: + new_entries.append(entry) + func_op.attributes["passthrough"] = ir.ArrayAttr.get(new_entries) + + try: + for op in module.body.operations: + if getattr(op, "OPERATION_NAME", None) != "gpu.module": + continue + gpu_module_body = op.regions[0].blocks[0] if hasattr(op, 'regions') else op.body + for inner_op in gpu_module_body.operations: + if getattr(inner_op, "OPERATION_NAME", None) != "llvm.func": + continue + if "gpu.kernel" not in inner_op.attributes: + continue + _append_passthrough(inner_op) + except Exception: + pass + + def _apply_waves_per_eu_on_llvm_funcs(module: ir.Module, waves_per_eu: int) -> None: """Apply AMDGPU waves-per-eu hint to llvm.func ops via LLVM passthrough. @@ -258,6 +361,8 @@ def compile( use_bare_pointers_for_host: bool = False, use_bare_pointers_for_kernels: bool = False, waves_per_eu: Optional[int] = None, + unsafe_fp_math: bool = False, + fast_fp_math: bool = False, ) -> Executor: """Compile a FLIR module to an Executor. @@ -313,6 +418,8 @@ def compile( use_bare_ptr_memref_call_conv=use_bare_ptr_memref_call_conv, use_bare_pointers_for_host=use_bare_pointers_for_host, use_bare_pointers_for_kernels=use_bare_pointers_for_kernels, + unsafe_fp_math=unsafe_fp_math, + fast_fp_math=fast_fp_math, ) with ctx: @@ -370,6 +477,8 @@ def compile( use_bare_ptr_memref_call_conv=use_bare_ptr_memref_call_conv, use_bare_pointers_for_host=use_bare_pointers_for_host, use_bare_pointers_for_kernels=use_bare_pointers_for_kernels, + unsafe_fp_math=unsafe_fp_math, + fast_fp_math=fast_fp_math, ) # Keep dump filenames stable vs the historical numbering scheme: # 00_target_overridden, then 03..14 for pipeline stages, then 15_final_isa. @@ -391,7 +500,14 @@ def compile( # Apply waves_per_eu if specified (BEFORE saving asm_for_isa) if waves_per_eu is not None: _apply_waves_per_eu_on_llvm_funcs(module, waves_per_eu) - # Get ASM after applying waves_per_eu + # Apply unsafe-fp-math function attributes for fast exp2/math + if unsafe_fp_math: + _apply_unsafe_fp_math_on_llvm_funcs(module) + # Replace __ocml_exp2_f32 with llvm.intr.exp2 for fast exp2 + new_mod = _replace_ocml_exp2_with_intrinsic(module) + if new_mod is not module: + module = new_mod + # Get ASM after applying attributes asm_for_isa = module.operation.get_asm(enable_debug_info=True) if asm_for_isa is not None: @@ -405,14 +521,17 @@ def compile( isa_stage = f"{stage_num_base + len(stage_frags):02d}_final_isa" print(f"[flir.compile] dump {isa_stage} -> {isa_out}") else: - if waves_per_eu is not None: - # When waves_per_eu is specified, we need to split the pipeline - # to apply the attribute after LLVM lowering but before binary generation. + need_split = (waves_per_eu is not None) or unsafe_fp_math + if need_split: + # Need to split the pipeline to apply function attributes + # after LLVM lowering but before binary generation. stage_frags = _pipeline_fragments( chip=chip, use_bare_ptr_memref_call_conv=use_bare_ptr_memref_call_conv, use_bare_pointers_for_host=use_bare_pointers_for_host, use_bare_pointers_for_kernels=use_bare_pointers_for_kernels, + unsafe_fp_math=unsafe_fp_math, + fast_fp_math=fast_fp_math, ) # Run all passes except the last one (gpu-module-to-binary) pre_binary_frags = stage_frags[:-1] @@ -424,7 +543,15 @@ def compile( pm.run(module.operation) # Apply waves_per_eu - _apply_waves_per_eu_on_llvm_funcs(module, waves_per_eu) + if waves_per_eu is not None: + _apply_waves_per_eu_on_llvm_funcs(module, waves_per_eu) + # Apply unsafe-fp-math function attributes for fast exp2/math + if unsafe_fp_math: + _apply_unsafe_fp_math_on_llvm_funcs(module) + # Replace __ocml_exp2_f32 with llvm.intr.exp2 for fast exp2 + new_mod = _replace_ocml_exp2_with_intrinsic(module) + if new_mod is not module: + module = new_mod # Run the final binary generation pass pm_binary = PassManager.parse(f"builtin.module({binary_frag})", context=ctx) diff --git a/flydsl/src/flydsl/dialects/ext/arith.py b/flydsl/src/flydsl/dialects/ext/arith.py index 2afc3d86..edcf20f4 100644 --- a/flydsl/src/flydsl/dialects/ext/arith.py +++ b/flydsl/src/flydsl/dialects/ext/arith.py @@ -154,12 +154,13 @@ def f64(value: float, *, loc: Location = None, ip: InsertionPoint = None) -> "Ar """Create an f64 constant.""" return constant(value, type=F64Type.get(), loc=loc, ip=ip) -def maximum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue": +def maximum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, fastmath=None, loc: Location = None) -> "ArithValue": """Compute maximum of two values (automatically handles float/int types). Args: lhs: Left operand (ArithValue, Value, or Python number) rhs: Right operand (ArithValue, Value, or Python number) + fastmath: Optional fast-math flags (e.g. arith.FastMathFlags.fast) loc: Optional source location Returns: @@ -171,7 +172,7 @@ def maximum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, >>> c = arith.maximum(a, b) # Function style >>> d = a.max(b) # Method style (equivalent) """ - return _minmax_op(lhs, rhs, op_type="max", loc=loc) + return _minmax_op(lhs, rhs, op_type="max", fastmath=fastmath, loc=loc) def minimum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, loc: Location = None) -> "ArithValue": """Compute minimum of two values (automatically handles float/int types). @@ -788,6 +789,7 @@ def _minmax_op( rhs: "ArithValue", op_type: str, # "max" or "min" *, + fastmath=None, loc: Location = None, ) -> "ArithValue": """Execute min/max operation based on operand types.""" @@ -809,7 +811,10 @@ def _minmax_op( op_class = _arith.MaximumFOp else: op_class = _arith.MinimumFOp - result = op_class(lhs_val, rhs_val, loc=loc).result + if fastmath is not None: + result = op_class(lhs_val, rhs_val, fastmath=fastmath, loc=loc).result + else: + result = op_class(lhs_val, rhs_val, loc=loc).result elif _is_integer_like_type(lhs_val.type): # Integer min/max (signed/unsigned logic could be tricky, default to signed for now) # TODO: Add unsigned support if needed diff --git a/tests/kernels/test_simple_gemm.py b/tests/kernels/test_simple_gemm.py index e17127c7..255148af 100644 --- a/tests/kernels/test_simple_gemm.py +++ b/tests/kernels/test_simple_gemm.py @@ -15,8 +15,10 @@ """ import argparse +import hashlib import logging import os +import random import sys # Ensure repo-local flydsl is used @@ -24,6 +26,7 @@ if _REPO_ROOT not in sys.path: sys.path.insert(0, _REPO_ROOT) +import numpy as np import torch from kernels.simple_gemm import compile_simple_gemm, run_simple_gemm @@ -112,6 +115,18 @@ DTYPES = ["bf16", "fp16"] +# Tensor initialization range (uniform distribution) +UNIFORM_RANGE = (-1, 1) +DEFAULT_SEED = 123 + + +def setup_seed(seed: int) -> None: + """Set random seed for reproducibility across all RNG sources.""" + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + def get_torch_dtype(in_dtype: str): """Convert string dtype to torch dtype.""" @@ -128,6 +143,105 @@ def _align_up(val: int, align: int) -> int: return ((val + align - 1) // align) * align +def compute_md5(tensor: torch.Tensor) -> str: + """Compute MD5 hash of a tensor's raw bytes.""" + return hashlib.md5( + tensor.contiguous().view(torch.uint8).detach().cpu().numpy().tobytes() + ).hexdigest() + + +def compare_arrays( + arr1: np.ndarray, + arr2: np.ndarray, + k: int = 5, + thresholds: list = None, +) -> dict: + """Compare two numpy arrays and compute various difference metrics. + + Args: + arr1: First input array (result), will be cast to float32. + arr2: Second input array (reference), will be cast to float32. + k: Number of top differences to report. + thresholds: Difference magnitude buckets for histogram. + + Returns: + Dictionary with top_k_diff, threshold_stats, nan_info, max_diff, max_diff_thr. + """ + if thresholds is None: + thresholds = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1] + + if arr1.shape != arr2.shape: + raise ValueError( + f"Shape mismatch: arr1 {arr1.shape} vs arr2 {arr2.shape}" + ) + + arr1 = arr1.astype(np.float32) + arr2 = arr2.astype(np.float32) + + result = {"top_k_diff": [], "threshold_stats": [], "nan_info": {}} + + # Check for NaN values + nan_mask1 = np.isnan(arr1) + nan_mask2 = np.isnan(arr2) + if np.any(nan_mask1): + result["nan_info"]["arr1_nan_count"] = int(np.sum(nan_mask1)) + print(f" Warning: result contains {result['nan_info']['arr1_nan_count']} NaN values") + if np.any(nan_mask2): + result["nan_info"]["arr2_nan_count"] = int(np.sum(nan_mask2)) + print(f" Warning: reference contains {result['nan_info']['arr2_nan_count']} NaN values") + + # Compute absolute differences + diff = np.abs(arr1 - arr2) + total_elements = arr1.size + + max_diff_thr = (diff / (1.0 + np.abs(arr2))).max() + result["max_diff"] = float(diff.max()) + result["max_diff_thr"] = float(max_diff_thr) + + print(f" diff.abs.max = {diff.max():.6f}") + print(f" diff.abs.mean = {diff.mean():.6f}") + print(f" max_diff_thr (rel) = {max_diff_thr:.6e}") + + # Find top k differences + flat_diff = diff.flatten() + actual_k = min(k, len(flat_diff)) + top_k_indices = np.argpartition(flat_diff, -actual_k)[-actual_k:] + top_k_indices = top_k_indices[np.argsort(-flat_diff[top_k_indices])] + + orig_indices = np.unravel_index(top_k_indices, diff.shape) + print(f" Top-{actual_k} differences:") + for i in range(actual_k): + idx = tuple(dim[i] for dim in orig_indices) + entry = { + "value": float(diff[idx]), + "position": idx, + "arr1_value": float(arr1[idx]), + "arr2_value": float(arr2[idx]), + } + result["top_k_diff"].append(entry) + print(f" [{idx}] result={arr1[idx]:.6f}, ref={arr2[idx]:.6f}, diff={diff[idx]:.6f}") + + # Compute threshold statistics + print(f" Threshold distribution ({total_elements} elements):") + for i in range(len(thresholds) - 1): + lower, upper = thresholds[i], thresholds[i + 1] + count = int(np.sum((diff >= lower) & (diff < upper))) + pct = 100.0 * count / total_elements + result["threshold_stats"].append( + {"range": f"[{lower:.0e}, {upper:.0e})", "count": count, "percentage": pct} + ) + print(f" [{lower:.0e}, {upper:.0e}): {count:>8d} ({pct:6.2f}%)") + + count = int(np.sum(diff >= thresholds[-1])) + pct = 100.0 * count / total_elements + result["threshold_stats"].append( + {"range": f">={thresholds[-1]:.0e}", "count": count, "percentage": pct} + ) + print(f" >={thresholds[-1]:.0e} : {count:>8d} ({pct:6.2f}%)") + + return result + + def run_test( size: str, in_dtype: str, @@ -137,6 +251,7 @@ def run_test( rtol: float = 1e-2, atol: float = 1e-2, waves_per_eu: int = None, + seed: int = DEFAULT_SEED, ): """Run a single GEMM test.""" config = TEST_CONFIGS[size] @@ -160,10 +275,10 @@ def run_test( device = "cuda" try: - # Create random inputs (original size) - torch.manual_seed(42) - A_orig = torch.randn(M, K, dtype=torch_dtype, device=device) - B_orig = torch.randn(N, K, dtype=torch_dtype, device=device) + # Create random inputs (uniform distribution in UNIFORM_RANGE) + setup_seed(seed) + A_orig = torch.empty(M, K, dtype=torch_dtype, device=device).uniform_(*UNIFORM_RANGE) + B_orig = torch.empty(N, K, dtype=torch_dtype, device=device).uniform_(*UNIFORM_RANGE) # Run reference computation (using float32 for accuracy) with original K if not skip_ref: @@ -229,6 +344,23 @@ def launch(): torch.cuda.synchronize() C_result = C + # Compute and print MD5 hashes + result_md5 = compute_md5(C_result) + ref_md5 = compute_md5(C_ref) + print(f" result_md5 = {result_md5}") + print(f" ref_md5 = {ref_md5}") + if result_md5 == ref_md5: + print(" MD5 match: EXACT (bit-identical)") + else: + print(" MD5 match: DIFFER (not bit-identical)") + + # Detailed comparison using compare_arrays + print(" --- compare_arrays ---") + compare_arrays( + C_result.to(torch.float32).detach().cpu().numpy(), + C_ref.to(torch.float32).detach().cpu().numpy(), + ) + # Check correctness using verify_output passed = verify_output( C_result.to(torch.float32), @@ -315,6 +447,12 @@ def main(): default=None, help="AMDGPU waves-per-eu hint for occupancy optimization (e.g., 1, 2, 4)", ) + parser.add_argument( + "--seed", + type=int, + default=DEFAULT_SEED, + help=f"Random seed for reproducibility (default: {DEFAULT_SEED})", + ) args = parser.parse_args() @@ -341,6 +479,7 @@ def main(): dtypes = DTYPES if args.dtype == "all" else [args.dtype] print(f"\nRunning Simple GEMM tests: sizes={sizes}, dtypes={dtypes}") + print(f"seed: {args.seed}") if args.waves_per_eu is not None: print(f"waves_per_eu: {args.waves_per_eu}") print(f"GPU: {torch.cuda.get_device_name(0)}\n") @@ -357,6 +496,7 @@ def main(): rtol=args.rtol, atol=args.atol, waves_per_eu=args.waves_per_eu, + seed=args.seed, ) results.append((size, dtype, passed)) From ab58591abbc96c120aecb51dd91ca99bd62b810c Mon Sep 17 00:00:00 2001 From: yanguahe Date: Thu, 12 Feb 2026 18:01:03 +0800 Subject: [PATCH 05/17] Add peng's Flash Attention PR: https://github.com/sunway513/FlyDSL/pull/2 --- flydsl/src/flydsl/dialects/ext/scf.py | 93 ++- flydsl/src/flydsl/lang/ir/module.py | 15 + kernels/flash_attention_v4.py | 539 +++++++++++++++++ kernels/flash_attention_v4_1.py | 612 +++++++++++++++++++ kernels/flash_attention_v4_2.py | 667 +++++++++++++++++++++ kernels/flash_attention_v4_3.py | 650 ++++++++++++++++++++ kernels/reduce.py | 44 ++ run.sh | 10 +- tests/kernels/test_flash_attention_v4.py | 288 +++++++++ tests/kernels/test_flash_attention_v4_1.py | 288 +++++++++ tests/kernels/test_flash_attention_v4_2.py | 268 +++++++++ tests/kernels/test_flash_attention_v4_3.py | 268 +++++++++ 12 files changed, 3687 insertions(+), 55 deletions(-) create mode 100644 kernels/flash_attention_v4.py create mode 100644 kernels/flash_attention_v4_1.py create mode 100644 kernels/flash_attention_v4_2.py create mode 100644 kernels/flash_attention_v4_3.py create mode 100644 tests/kernels/test_flash_attention_v4.py create mode 100644 tests/kernels/test_flash_attention_v4_1.py create mode 100644 tests/kernels/test_flash_attention_v4_2.py create mode 100644 tests/kernels/test_flash_attention_v4_3.py diff --git a/flydsl/src/flydsl/dialects/ext/scf.py b/flydsl/src/flydsl/dialects/ext/scf.py index 5e87dfb8..f11c31c3 100644 --- a/flydsl/src/flydsl/dialects/ext/scf.py +++ b/flydsl/src/flydsl/dialects/ext/scf.py @@ -20,6 +20,33 @@ from .arith import constant +def _as_value(v): + """Unwrap various 'Value-like' wrappers (ArithValue, etc.) to a raw MLIR Value. + + This is needed because generated op builders (like ``_scf.YieldOp``) check + ``isinstance(v, Value)`` which fails for ``ArithValue`` wrappers created by + ``register_value_caster``. + """ + seen = set() + while True: + if isinstance(v, Value): + return v + obj_id = id(v) + if obj_id in seen: + return v + seen.add(obj_id) + if hasattr(v, "_value"): + v = v._value + continue + if hasattr(v, "value"): + v = v.value + continue + if hasattr(v, "result"): + v = v.result + continue + return v + + def _normalize_if_condition(condition): """Best-effort normalization for scf.if conditions. @@ -183,34 +210,10 @@ def range_( start, stop, step = canonicalize_range(start, stop, step) - # Unwrap various "Value-like" wrappers down to a real `_mlir.ir.Value`. - # We need this because our arithmetic helpers often return wrapper objects - # (e.g. `ArithValue`) which are not accepted as operands by generated op - # builders (like `_scf.ForOp`). - def _as_value(v): - seen = set() - while True: - if isinstance(v, Value): - return v - obj_id = id(v) - if obj_id in seen: - return v - seen.add(obj_id) - if hasattr(v, "_value"): - v = v._value - continue - if hasattr(v, "value"): - v = v.value - continue - if hasattr(v, "result"): - v = v.result - continue - return v - start = _as_value(start) stop = _as_value(stop) step = _as_value(step) - + iter_args = iter_args or [] iter_args = [_as_value(a) for a in iter_args] for_op = _scf.ForOp(start, stop, step, iter_args, loc=loc, ip=ip) @@ -227,7 +230,8 @@ def _as_value(v): # Ensure scf.for body is terminated. block = for_op.body if (not block.operations) or not isinstance(block.operations[-1], _scf.YieldOp): - _scf.YieldOp(list(for_op.inner_iter_args)) + # Unwrap ArithValue wrappers before passing to YieldOp + _scf.YieldOp([_as_value(a) for a in for_op.inner_iter_args]) @contextmanager @@ -282,26 +286,6 @@ def for_( return start, stop, step = canonicalize_range(start, stop, step) - # Unwrap various "Value-like" wrappers down to a real `_mlir.ir.Value`. - def _as_value(v): - seen = set() - while True: - if isinstance(v, Value): - return v - obj_id = id(v) - if obj_id in seen: - return v - seen.add(obj_id) - if hasattr(v, "_value"): - v = v._value - continue - if hasattr(v, "value"): - v = v.value - continue - if hasattr(v, "result"): - v = v.result - continue - return v start = _as_value(start) stop = _as_value(stop) @@ -316,7 +300,8 @@ def _as_value(v): finally: block = for_op.body if (not block.operations) or not isinstance(block.operations[-1], _scf.YieldOp): - _scf.YieldOp(list(for_op.inner_iter_args)) + # Unwrap ArithValue wrappers before passing to YieldOp + _scf.YieldOp([_as_value(a) for a in for_op.inner_iter_args]) @contextmanager @@ -475,16 +460,20 @@ def yield_( ip: InsertionPoint = None, ): """Create an scf.yield operation. - + + Automatically unwraps ArithValue wrappers so callers don't need + ``arith.as_value()`` on every yielded operand. + Args: - operands: Values to yield - loc: Location for the operation - ip: Insertion point + operands: Values to yield (accepts ArithValue wrappers). + loc: Location for the operation. + ip: Insertion point. """ if loc is None: loc = Location.unknown() - + operands = operands or [] + operands = [_as_value(o) for o in operands] return _scf.YieldOp(operands, loc=loc, ip=ip) diff --git a/flydsl/src/flydsl/lang/ir/module.py b/flydsl/src/flydsl/lang/ir/module.py index a4e01498..a538ab0d 100644 --- a/flydsl/src/flydsl/lang/ir/module.py +++ b/flydsl/src/flydsl/lang/ir/module.py @@ -318,6 +318,21 @@ def wrapper(instance_self, *args, **kwargs): k.qualname = instance_self.GPU_MODULE_NAME except Exception: pass + # Set known_block_size if GPU_BLOCK_SIZE is defined on the module class. + # This tells convert-gpu-to-rocdl the workgroup size so that + # max_flat_workgroup_size is set correctly in the ISA. + block_size = getattr(instance_self, "GPU_BLOCK_SIZE", None) + if block_size is not None: + try: + if isinstance(block_size, int): + block_size = (block_size, 1, 1) + func_op = k._func_op if hasattr(k, "_func_op") else k + op = getattr(func_op, "operation", func_op) + op.attributes["known_block_size"] = ir.DenseI32ArrayAttr.get( + list(block_size) + ) + except Exception: + pass instance_self.kernel_func_op[fn.__name__] = k self._wrapper = wrapper diff --git a/kernels/flash_attention_v4.py b/kernels/flash_attention_v4.py new file mode 100644 index 00000000..1eb3c7b6 --- /dev/null +++ b/kernels/flash_attention_v4.py @@ -0,0 +1,539 @@ +"""Flash Attention V4 kernel builder for FlyDSL. + +Multi-wave MFMA implementation: 4 waves (256 threads), BLOCK_M=64, BLOCK_N=16. +Each wave owns 16 Q-rows and performs independent MFMA Q@K^T and P@V. +All 256 threads cooperate on K/V tile loads. + +Inspired by the poc_kl ASM kernel (8w, 256x32, mfma_32x32x8, ping-pong LDS). +This V4.0 uses mfma_f32_16x16x16f16 with single-buffered K/V as a first step. + +V4.0 vs V3.1: +- 4 waves (256 threads) vs 1 wave (64 threads) → 4x more MFMA throughput per CU +- BLOCK_M=64 vs 16 → 4x fewer grid workgroups, better scheduling +- Cooperative load: 256 threads → K/V tile (16×HD) loaded in 1 batch (vs 4) +- LDS: ~30KB (Q=16KB + KV=4KB + P=2KB) + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). +Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. +Block: (256,) — 4 waves of 64 on AMD (wave64). + +Requires: head_dim % 16 == 0, seq_len % 64 == 0, head_dim >= 64. +""" + +import math + +from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl +from flydsl.dialects.ext import vector as vec_ext +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext.scf import yield_ as scf_yield +from _mlir.dialects import memref as _memref +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator +from _mlir import ir +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_attention_v4_kernel" + + +def build_flash_attention_v4_module( + num_heads, + head_dim, + causal=True, + dtype_str="f16", + sm_scale=None, +): + """Build a FlyDSL Flash Attention V4 module with multi-wave MFMA tiling. + + Args: + num_heads: Number of attention heads. + head_dim: Dimension per head (must be divisible by 16, >= 64). + causal: Whether to apply causal mask. + dtype_str: "f16" (bf16 not yet supported). + sm_scale: Softmax scale (default: 1/sqrt(head_dim)). + + Returns: + MlirModule compilable via ``flydsl.compile(module)``. + """ + gpu_arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + + BLOCK_M = 64 + BLOCK_N = 16 + NUM_WAVES = 4 + WARP_SIZE = 64 + BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 + ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 + K_STEPS = head_dim // 16 + + assert head_dim % 16 == 0, f"head_dim ({head_dim}) must be divisible by 16" + assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" + assert dtype_str == "f16", "V4 currently only supports f16" + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + # ---- Vectorized cooperative load constants ---- + VEC_WIDTH = 8 # v8f16 = 16 bytes + THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH + assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 + ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD + + # For Q tile (64 rows): NUM_BATCHES_Q = 64 / ROWS_PER_BATCH + assert BLOCK_M % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_Q = BLOCK_M // ROWS_PER_BATCH_LOAD + + # For KV tile (16 rows): NUM_BATCHES_KV = 16 / ROWS_PER_BATCH + assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 or ROWS_PER_BATCH_LOAD >= BLOCK_N + if ROWS_PER_BATCH_LOAD >= BLOCK_N: + NUM_BATCHES_KV = 1 + # Some threads will be idle (load_row >= BLOCK_N). Need guard. + KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N + else: + NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD + KV_NEEDS_GUARD = False + + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + class _FlashAttentionV4(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_v4_{dtype_str}" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + elem_type = T.f16() + _state["elem_type"] = elem_type + _state["lds_q"] = allocator.allocate_array(elem_type, BLOCK_M * HEAD_DIM) + _state["lds_kv"] = allocator.allocate_array(elem_type, BLOCK_N * HEAD_DIM) + _state["lds_p"] = allocator.allocate_array(elem_type, BLOCK_M * BLOCK_N) + allocator.finalize() + + @flir.kernel + def flash_attention_v4_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + seq_len: lambda: T.index(), + ): + compute_type = T.f32() + elem_type = _state["elem_type"] + fm_fast = flir.arith.FastMathFlags.fast + + v4f16_type = ir.VectorType.get([4], elem_type) + v4f32_type = ir.VectorType.get([4], compute_type) + v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) + + seq_len_v = arith.as_value(seq_len) + + # ---- LDS views ---- + base_ptr = allocator.get_base() + lds_q = _state["lds_q"](base_ptr).get() + lds_kv = _state["lds_kv"](base_ptr).get() + lds_p = _state["lds_p"](base_ptr).get() + + # ---- Thread / block indices ---- + block_id = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + # ---- Wave decomposition ---- + c_ws = flir.const_index(WARP_SIZE) + wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) + lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) + + # ---- MFMA lane decomposition (within each wave) ---- + c16 = flir.const_index(16) + lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result) + lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result) + + # ---- Wave's Q-row offset in the Q tile ---- + wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value + # Wave's P offset in lds_p + wave_p_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value + + # ---- Decompose block_id -> (batch_idx, q_tile_idx, head_idx) ---- + c_nh = flir.const_index(NUM_HEADS) + head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) + temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) + c_bm = flir.const_index(BLOCK_M) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) + q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) + batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) + q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value + + # ---- Vectorized load thread decomposition ---- + c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) + load_row_in_batch = arith.as_value( + flir.arith.DivUIOp(tid, c_tpr).result + ) + load_lane_in_row = arith.as_value( + flir.arith.RemUIOp(tid, c_tpr).result + ) + load_col_base = ( + arith.ArithValue(load_lane_in_row) * VEC_WIDTH + ).value + + # ---- Helper: global flat index ---- + def global_idx(token_idx, col): + token = ( + arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) + + arith.ArithValue(token_idx) + ) + return ( + token * STRIDE_TOKEN + + arith.ArithValue(head_idx) * HEAD_DIM + + arith.ArithValue(col) + ).value + + # ---- Cooperative Q load (64 rows, all 256 threads) ---- + def coop_load_q(): + for batch in range_constexpr(NUM_BATCHES_Q): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(q_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, Q, [g_idx]) + ) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(lds_row) * HEAD_DIM + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_q, [lds_idx]) + + # ---- Cooperative KV load (16 rows, 256 threads — may need guard) ---- + def coop_load_kv(src_memref, lds_memref, tile_start): + if KV_NEEDS_GUARD: + # With 256 threads and THREADS_PER_ROW=16, ROWS_PER_BATCH=16 + # which equals BLOCK_N=16. No guard needed in this config. + # But for safety, handle the case where ROWS_PER_BATCH > BLOCK_N. + c_bn = flir.const_index(BLOCK_N) + row_ok = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + load_row_in_batch, c_bn, + ).result + ) + # Only threads with row < BLOCK_N participate + # Use scf.if_ for conditional store + from flydsl.dialects.ext.scf import IfOp + if_op = IfOp(row_ok) + with if_op: + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, src_memref, [g_idx]) + ) + lds_idx = ( + arith.ArithValue(load_row_in_batch) * HEAD_DIM + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_memref, [lds_idx]) + else: + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, src_memref, [g_idx]) + ) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(lds_row) * HEAD_DIM + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_memref, [lds_idx]) + + # ---- Load Q tile to LDS ---- + coop_load_q() + gpu.barrier() + + # ---- Constants ---- + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_sm_scale = arith.constant(sm_scale, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + c_zero_v4f32 = arith.as_value( + arith.constant_vector(0.0, v4f32_type) + ) + + # ---- Init loop-carried state ---- + # [m_0..m_3, l_0..l_3, o_acc_0..o_acc_{K_STEPS-1}] + init_args = [] + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_neg_inf)) + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_zero_f)) + for _ in range_constexpr(K_STEPS): + init_args.append(c_zero_v4f32) + + # ---- KV loop ---- + with scf.for_(0, seq_len_v, BLOCK_N, iter_args=init_args) as loop: + kv_start = arith.as_value(loop.induction_variable) + m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)] + l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)] + o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)] + + # ==== Cooperative K load -> LDS_KV ==== + coop_load_kv(K, lds_kv, kv_start) + gpu.barrier() + + # ==== Q @ K^T via MFMA (each wave uses its Q rows) ==== + s_acc = c_zero_v4f32 + for ks in range_constexpr(K_STEPS): + # A operand (Q): lane's Q row within this wave's 16 rows + q_lds_idx = ( + (arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_mod_16)) * HEAD_DIM + + ks * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + a_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_q, [q_lds_idx]) + ) + # B operand (K^T): same for all waves (shared K tile) + k_lds_idx = ( + arith.ArithValue(lane_mod_16) * HEAD_DIM + + ks * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + b_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx]) + ) + s_acc = arith.as_value( + rocdl.mfma_f32_16x16x16f16(v4f32_type, [a_pack, b_pack, s_acc, 0, 0, 0]) + ) + + # ==== Online softmax (per-wave, per-row) ==== + s_vals = [] + for ii in range_constexpr(4): + s_ii = arith.as_value( + vec_ext.extract(s_acc, static_position=[ii], dynamic_position=[]) + ) + s_ii = arith.as_value( + flir.arith.MulFOp(s_ii, arith.as_value(c_sm_scale), fastmath=fm_fast).result + ) + if CAUSAL: + # Global Q row for this lane's ii-th element + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + kv_col = (arith.ArithValue(kv_start) + arith.ArithValue(lane_mod_16)).value + q_row_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), q_row).result) + kv_col_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col).result) + is_masked = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64, + ).result + ) + s_ii = arith.as_value( + flir.arith.SelectOp(is_masked, arith.as_value(c_neg_inf), s_ii).result + ) + s_vals.append(s_ii) + + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + m_new = [None] * 4 + corr = [None] * 4 + p_vals = [None] * 4 + l_new = [None] * 4 + + for ii in range_constexpr(4): + row_max = s_vals[ii] + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_max, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_max = arith.as_value( + flir.arith.MaximumFOp(row_max, peer).result + ) + + m_new[ii] = arith.as_value( + flir.arith.MaximumFOp(m_old[ii], row_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( + flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) + + diff_s = arith.as_value( + flir.arith.SubFOp(s_vals[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_s_s = arith.as_value( + flir.arith.MulFOp(diff_s, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + p_vals[ii] = arith.as_value(flir.math.exp2(diff_s_s, fastmath=fm_fast)) + + row_sum = p_vals[ii] + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_sum, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_sum = arith.as_value( + flir.arith.AddFOp(row_sum, peer, fastmath=fm_fast).result + ) + + l_corr = arith.as_value( + flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result + ) + l_new[ii] = arith.as_value( + flir.arith.AddFOp(l_corr, row_sum, fastmath=fm_fast).result + ) + + # ==== Rescale O accumulators ==== + corr_vec = arith.as_value( + vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) + ) + for ds in range_constexpr(K_STEPS): + o_accs[ds] = arith.as_value( + flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result + ) + + # ==== P store to LDS_P (each wave writes its 16×16 section) ==== + for ii in range_constexpr(4): + p_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, p_vals[ii]).result + ) + p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value + p_lds_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(p_row) * BLOCK_N + + arith.ArithValue(lane_mod_16) + ).value + _memref.StoreOp(p_f16, lds_p, [p_lds_idx]) + + # ==== Barrier: ensure all waves finished reading K from lds_kv ==== + gpu.barrier() + + # ==== Cooperative V load -> LDS_KV (overwrites K) ==== + coop_load_kv(V, lds_kv, kv_start) + gpu.barrier() + + # ==== P load (A-operand, wave-local) ==== + p_a_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(lane_mod_16) * BLOCK_N + + arith.ArithValue(lane_div_16) * 4 + ).value + p_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_p, [p_a_idx]) + ) + + # ==== P @ V via MFMA ==== + for ds in range_constexpr(K_STEPS): + v_elems = [] + for e in range_constexpr(4): + v_row = (arith.ArithValue(lane_div_16) * 4 + e).value + v_lds_idx = ( + arith.ArithValue(v_row) * HEAD_DIM + + ds * 16 + + arith.ArithValue(lane_mod_16) + ).value + v_val = _memref.LoadOp(lds_kv, [v_lds_idx]).result + v_elems.append(arith.as_value(v_val)) + v_pack = arith.as_value( + vec_ext.from_elements(v4f16_type, v_elems) + ) + o_accs[ds] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [p_pack, v_pack, o_accs[ds], 0, 0, 0] + ) + ) + + # ==== Barrier: ensure all waves finished P@V (reading lds_kv) + # before next iteration overwrites lds_kv with K ==== + gpu.barrier() + + # ==== Yield ==== + yield_args = m_new + l_new + o_accs + scf_yield(yield_args) + + # ---- Normalize and store O ---- + m_finals = [arith.as_value(loop.results[i]) for i in range(4)] + l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] + o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] + + for ds in range_constexpr(K_STEPS): + for ii in range_constexpr(4): + o_val = arith.as_value( + vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) + ) + o_norm = arith.as_value( + flir.arith.DivFOp(o_val, l_finals[ii], fastmath=fm_fast).result + ) + o_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, o_norm).result + ) + # Global store: each wave writes its Q rows + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value + o_global = global_idx(q_row, d_col) + _memref.StoreOp(o_f16, O, [o_global]) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + batch_size: lambda: T.index(), + seq_len: lambda: T.index(), + ): + c1 = arith.as_value(flir.arith_ext.index(1)) + c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) + c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) + bs_val = arith.as_value(batch_size) + sl_val = arith.as_value(seq_len) + num_q_tiles = arith.as_value( + flir.arith.DivUIOp(sl_val, c_bm).result + ) + bs_qt = arith.as_value( + flir.arith.MulIOp(bs_val, num_q_tiles).result + ) + grid_x = arith.as_value( + flir.arith.MulIOp(bs_qt, c_nh).result + ) + bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) + flir.gpu_ext.LaunchFuncOp( + [self.GPU_MODULE_NAME, KERNEL_NAME], + grid_size=(grid_x, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, seq_len], + ) + + return _FlashAttentionV4() diff --git a/kernels/flash_attention_v4_1.py b/kernels/flash_attention_v4_1.py new file mode 100644 index 00000000..9c407a28 --- /dev/null +++ b/kernels/flash_attention_v4_1.py @@ -0,0 +1,612 @@ +"""Flash Attention V4.1 kernel builder for FlyDSL. + +V4.1 optimizations over V4.0: +- Q preloaded to registers (eliminates Q LDS reads from KV loop) +- V stored transposed in LDS (vectorized v4f16 B-operand loads) +- Bank-conflict-free LDS padding (K stride=HD+2, V transposed stride=BLOCK_N+2) + +Tile config: BLOCK_M=64, BLOCK_N=16, 4 waves (256 threads), mfma_f32_16x16x16f16. + +Expected improvements from V4.0: +- ~32% fewer LDS instructions (Q reads eliminated, V loads vectorized) +- Reduced LDS bank conflicts + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). +Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. +Block: (256,) -- 4 waves of 64 on AMD (wave64). + +Requires: head_dim % 16 == 0, seq_len % 64 == 0, head_dim >= 64. +""" + +import math + +from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl +from flydsl.dialects.ext import vector as vec_ext +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext.scf import yield_ as scf_yield +from _mlir.dialects import memref as _memref +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator +from _mlir import ir +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_attention_v4_1_kernel" + + +def build_flash_attention_v4_1_module( + num_heads, + head_dim, + causal=True, + dtype_str="f16", + sm_scale=None, +): + """Build a FlyDSL Flash Attention V4.1 module. + + Args: + num_heads: Number of attention heads. + head_dim: Dimension per head (must be divisible by 16, >= 64). + causal: Whether to apply causal mask. + dtype_str: "f16" (bf16 not yet supported). + sm_scale: Softmax scale (default: 1/sqrt(head_dim)). + + Returns: + MlirModule compilable via ``flydsl.compile(module)``. + """ + gpu_arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + + BLOCK_M = 64 + BLOCK_N = 16 + NUM_WAVES = 4 + WARP_SIZE = 64 + BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 + ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 + K_STEPS = head_dim // 16 + + assert head_dim % 16 == 0, f"head_dim ({head_dim}) must be divisible by 16" + assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" + assert dtype_str == "f16", "V4.1 currently only supports f16" + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + # ---- Bank-conflict-free LDS strides ---- + # K row-major: stride = HD + 2 (makes row stride odd in bank units) + # V transposed: stride = BLOCK_N + 2 (same reasoning) + K_STRIDE = HEAD_DIM + 2 # 130 for HD=128 + VT_STRIDE = BLOCK_N + 2 # 18 for BLOCK_N=16 + + # ---- Vectorized cooperative load constants ---- + VEC_WIDTH = 8 # v8f16 = 16 bytes + THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH + assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 + ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD + + # For Q tile (64 rows) + assert BLOCK_M % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_Q = BLOCK_M // ROWS_PER_BATCH_LOAD + + # For KV tile (16 rows) + assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 or ROWS_PER_BATCH_LOAD >= BLOCK_N + if ROWS_PER_BATCH_LOAD >= BLOCK_N: + NUM_BATCHES_KV = 1 + KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N + else: + NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD + KV_NEEDS_GUARD = False + + # LDS sizes + LDS_Q_SIZE = BLOCK_M * HEAD_DIM # Q unpadded (only read once for register preload) + LDS_KV_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE) # max(K padded, Vt padded) + LDS_P_SIZE = BLOCK_M * BLOCK_N + + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + class _FlashAttentionV4_1(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_v4_1_{dtype_str}" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + elem_type = T.f16() + _state["elem_type"] = elem_type + _state["lds_q"] = allocator.allocate_array(elem_type, LDS_Q_SIZE) + _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_SIZE) + _state["lds_p"] = allocator.allocate_array(elem_type, LDS_P_SIZE) + allocator.finalize() + + @flir.kernel + def flash_attention_v4_1_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + seq_len: lambda: T.index(), + ): + compute_type = T.f32() + elem_type = _state["elem_type"] + fm_fast = flir.arith.FastMathFlags.fast + + v4f16_type = ir.VectorType.get([4], elem_type) + v4f32_type = ir.VectorType.get([4], compute_type) + v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) + + seq_len_v = arith.as_value(seq_len) + + # ---- LDS views ---- + base_ptr = allocator.get_base() + lds_q = _state["lds_q"](base_ptr).get() + lds_kv = _state["lds_kv"](base_ptr).get() + lds_p = _state["lds_p"](base_ptr).get() + + # ---- Thread / block indices ---- + block_id = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + # ---- Wave decomposition ---- + c_ws = flir.const_index(WARP_SIZE) + wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) + lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) + + # ---- MFMA lane decomposition (within each wave) ---- + c16 = flir.const_index(16) + lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result) + lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result) + + # ---- Wave's Q-row offset in the Q tile ---- + wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value + # Wave's P offset in lds_p + wave_p_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value + + # ---- Decompose block_id -> (batch_idx, q_tile_idx, head_idx) ---- + c_nh = flir.const_index(NUM_HEADS) + head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) + temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) + c_bm = flir.const_index(BLOCK_M) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) + q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) + batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) + q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value + + # ---- Vectorized load thread decomposition ---- + c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) + load_row_in_batch = arith.as_value( + flir.arith.DivUIOp(tid, c_tpr).result + ) + load_lane_in_row = arith.as_value( + flir.arith.RemUIOp(tid, c_tpr).result + ) + load_col_base = ( + arith.ArithValue(load_lane_in_row) * VEC_WIDTH + ).value + + # ---- Helper: global flat index ---- + def global_idx(token_idx, col): + token = ( + arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) + + arith.ArithValue(token_idx) + ) + return ( + token * STRIDE_TOKEN + + arith.ArithValue(head_idx) * HEAD_DIM + + arith.ArithValue(col) + ).value + + # ---- Cooperative Q load (64 rows, all 256 threads, unpadded) ---- + def coop_load_q(): + for batch in range_constexpr(NUM_BATCHES_Q): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(q_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, Q, [g_idx]) + ) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(lds_row) * HEAD_DIM + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_q, [lds_idx]) + + # ---- Cooperative K load (row-major with padded stride) ---- + def coop_load_k(tile_start): + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_ok = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + load_row_in_batch, c_bn, + ).result + ) + from flydsl.dialects.ext.scf import IfOp + if_op = IfOp(row_ok) + with if_op: + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, K, [g_idx]) + ) + lds_idx = ( + arith.ArithValue(load_row_in_batch) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + else: + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, K, [g_idx]) + ) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + + # ---- Cooperative V load (transposed with padded stride) ---- + # Global V[row, col] -> LDS Vt[col, row] at lds_kv[col * VT_STRIDE + row] + def coop_load_v_transposed(tile_start): + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_ok = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + load_row_in_batch, c_bn, + ).result + ) + from flydsl.dialects.ext.scf import IfOp + if_op = IfOp(row_ok) + with if_op: + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, V, [g_idx]) + ) + # Scatter-store transposed: V[row, col+e] -> lds[col_e * VT_STRIDE + row] + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row_in_batch) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + else: + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, V, [g_idx]) + ) + load_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + # Scatter-store transposed + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + + # ---- Load Q tile to LDS ---- + coop_load_q() + gpu.barrier() + + # ---- Preload Q A-operand packs into registers ---- + # Each lane loads K_STEPS v4f16 packs from LDS_Q (one-time cost). + # At step ks, thread (b,n) needs Q[wave_row + n, ks*16 + b*4 : ks*16+b*4+4] + q_a_packs = [] + for ks in range_constexpr(K_STEPS): + q_lds_idx = ( + (arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_mod_16)) * HEAD_DIM + + ks * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + q_a_packs.append(arith.as_value( + vec_ext.load_op(v4f16_type, lds_q, [q_lds_idx]) + )) + + # ---- Constants ---- + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_sm_scale = arith.constant(sm_scale, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + c_zero_v4f32 = arith.as_value( + arith.constant_vector(0.0, v4f32_type) + ) + + # ---- Init loop-carried state ---- + # [m_0..m_3, l_0..l_3, o_acc_0..o_acc_{K_STEPS-1}] + init_args = [] + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_neg_inf)) + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_zero_f)) + for _ in range_constexpr(K_STEPS): + init_args.append(c_zero_v4f32) + + # ---- KV loop upper bound ---- + # Causal early-exit: last Q row = q_start + BLOCK_M - 1, + # so only need KV positions 0 .. q_start + BLOCK_M - 1. + if CAUSAL: + kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value + else: + kv_upper = seq_len_v + + # ---- KV loop ---- + with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: + kv_start = arith.as_value(loop.induction_variable) + m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)] + l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)] + o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)] + + # ==== Cooperative K load -> LDS_KV (row-major, padded stride) ==== + coop_load_k(kv_start) + gpu.barrier() + + # ==== Q @ K^T via MFMA (Q from registers, K from LDS) ==== + s_acc = c_zero_v4f32 + for ks in range_constexpr(K_STEPS): + # A operand (Q): from preloaded registers + a_pack = q_a_packs[ks] + # B operand (K^T): from LDS with padded stride + k_lds_idx = ( + arith.ArithValue(lane_mod_16) * K_STRIDE + + ks * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + b_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx]) + ) + s_acc = arith.as_value( + rocdl.mfma_f32_16x16x16f16(v4f32_type, [a_pack, b_pack, s_acc, 0, 0, 0]) + ) + + # ==== Online softmax (per-wave, per-row) ==== + s_vals = [] + for ii in range_constexpr(4): + s_ii = arith.as_value( + vec_ext.extract(s_acc, static_position=[ii], dynamic_position=[]) + ) + s_ii = arith.as_value( + flir.arith.MulFOp(s_ii, arith.as_value(c_sm_scale), fastmath=fm_fast).result + ) + if CAUSAL: + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + kv_col = (arith.ArithValue(kv_start) + arith.ArithValue(lane_mod_16)).value + q_row_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), q_row).result) + kv_col_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col).result) + is_masked = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64, + ).result + ) + s_ii = arith.as_value( + flir.arith.SelectOp(is_masked, arith.as_value(c_neg_inf), s_ii).result + ) + s_vals.append(s_ii) + + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + m_new = [None] * 4 + corr = [None] * 4 + p_vals = [None] * 4 + l_new = [None] * 4 + + for ii in range_constexpr(4): + row_max = s_vals[ii] + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_max, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_max = arith.as_value( + flir.arith.MaximumFOp(row_max, peer).result + ) + + m_new[ii] = arith.as_value( + flir.arith.MaximumFOp(m_old[ii], row_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( + flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) + + diff_s = arith.as_value( + flir.arith.SubFOp(s_vals[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_s_s = arith.as_value( + flir.arith.MulFOp(diff_s, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + p_vals[ii] = arith.as_value(flir.math.exp2(diff_s_s, fastmath=fm_fast)) + + row_sum = p_vals[ii] + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_sum, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_sum = arith.as_value( + flir.arith.AddFOp(row_sum, peer, fastmath=fm_fast).result + ) + + l_corr = arith.as_value( + flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result + ) + l_new[ii] = arith.as_value( + flir.arith.AddFOp(l_corr, row_sum, fastmath=fm_fast).result + ) + + # ==== Rescale O accumulators ==== + corr_vec = arith.as_value( + vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) + ) + for ds in range_constexpr(K_STEPS): + o_accs[ds] = arith.as_value( + flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result + ) + + # ==== P store to LDS_P (each wave writes its 16x16 section) ==== + for ii in range_constexpr(4): + p_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, p_vals[ii]).result + ) + p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value + p_lds_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(p_row) * BLOCK_N + + arith.ArithValue(lane_mod_16) + ).value + _memref.StoreOp(p_f16, lds_p, [p_lds_idx]) + + # ==== Barrier: ensure all waves finished reading K from lds_kv ==== + gpu.barrier() + + # ==== Cooperative V load -> LDS_KV (transposed, padded stride) ==== + coop_load_v_transposed(kv_start) + gpu.barrier() + + # ==== P load (A-operand, wave-local) ==== + p_a_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(lane_mod_16) * BLOCK_N + + arith.ArithValue(lane_div_16) * 4 + ).value + p_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_p, [p_a_idx]) + ) + + # ==== P @ V via MFMA (V from transposed LDS, vectorized v4f16 loads) ==== + # V transposed: V[row, col] at lds_kv[col * VT_STRIDE + row] + # B-operand: V[b*4:b*4+4, ds*16+n] = lds_kv[(ds*16+n) * VT_STRIDE + b*4] + # -> v4f16 at base (ds*16 + lane_mod_16) * VT_STRIDE + lane_div_16 * 4 + for ds in range_constexpr(K_STEPS): + v_lds_idx = ( + (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE + + arith.ArithValue(lane_div_16) * 4 + ).value + v_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [v_lds_idx]) + ) + o_accs[ds] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [p_pack, v_pack, o_accs[ds], 0, 0, 0] + ) + ) + + # ==== Barrier: ensure all waves finished P@V (reading lds_kv) + # before next iteration overwrites lds_kv with K ==== + gpu.barrier() + + # ==== Yield ==== + yield_args = m_new + l_new + o_accs + scf_yield(yield_args) + + # ---- Normalize and store O ---- + m_finals = [arith.as_value(loop.results[i]) for i in range(4)] + l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] + o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] + + for ds in range_constexpr(K_STEPS): + for ii in range_constexpr(4): + o_val = arith.as_value( + vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) + ) + o_norm = arith.as_value( + flir.arith.DivFOp(o_val, l_finals[ii], fastmath=fm_fast).result + ) + o_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, o_norm).result + ) + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value + o_global = global_idx(q_row, d_col) + _memref.StoreOp(o_f16, O, [o_global]) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + batch_size: lambda: T.index(), + seq_len: lambda: T.index(), + ): + c1 = arith.as_value(flir.arith_ext.index(1)) + c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) + c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) + bs_val = arith.as_value(batch_size) + sl_val = arith.as_value(seq_len) + num_q_tiles = arith.as_value( + flir.arith.DivUIOp(sl_val, c_bm).result + ) + bs_qt = arith.as_value( + flir.arith.MulIOp(bs_val, num_q_tiles).result + ) + grid_x = arith.as_value( + flir.arith.MulIOp(bs_qt, c_nh).result + ) + bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) + flir.gpu_ext.LaunchFuncOp( + [self.GPU_MODULE_NAME, KERNEL_NAME], + grid_size=(grid_x, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, seq_len], + ) + + return _FlashAttentionV4_1() diff --git a/kernels/flash_attention_v4_2.py b/kernels/flash_attention_v4_2.py new file mode 100644 index 00000000..de6f28ca --- /dev/null +++ b/kernels/flash_attention_v4_2.py @@ -0,0 +1,667 @@ +"""Flash Attention V4.2 kernel builder for FlyDSL. + +V4.2 optimizations over V4.1: +- BLOCK_N=32 (vs 16): halves KV iterations and barriers +- Q@K^T produces [16,32] via two MFMA 16x16x16 in N dimension +- P@V uses K=32 via two MFMA 16x16x16 in K dimension +- Softmax over 32 positions per row (two 16-wide groups) +- V stored transposed in LDS with bank-conflict-free padding (from V4.1) +- Q preloaded to registers (from V4.1) + +Tile config: BLOCK_M=64, BLOCK_N=32, 4 waves (256 threads), mfma_f32_16x16x16f16. + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). +Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. +Block: (256,) -- 4 waves of 64 on AMD (wave64). + +Requires: head_dim % 16 == 0, seq_len % 64 == 0, head_dim >= 64. +""" + +import math + +from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl +from flydsl.dialects.ext import vector as vec_ext +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext.scf import yield_ as scf_yield +from _mlir.dialects import memref as _memref +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator +from _mlir import ir +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_attention_v4_2_kernel" + + +def build_flash_attention_v4_2_module( + num_heads, + head_dim, + causal=True, + dtype_str="f16", + sm_scale=None, +): + """Build a FlyDSL Flash Attention V4.2 module. + + Args: + num_heads: Number of attention heads. + head_dim: Dimension per head (must be divisible by 16, >= 64). + causal: Whether to apply causal mask. + dtype_str: "f16" (bf16 not yet supported). + sm_scale: Softmax scale (default: 1/sqrt(head_dim)). + + Returns: + MlirModule compilable via ``flydsl.compile(module)``. + """ + gpu_arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + + BLOCK_M = 64 + BLOCK_N = 32 # *** doubled from V4.1 *** + NUM_WAVES = 4 + WARP_SIZE = 64 + BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 + ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 + K_STEPS = head_dim // 16 + # Number of 16-wide MFMA columns in Q@K^T N-dimension + N_MFMA = BLOCK_N // 16 # 2 + + assert head_dim % 16 == 0, f"head_dim ({head_dim}) must be divisible by 16" + assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" + assert dtype_str == "f16", "V4.2 currently only supports f16" + assert BLOCK_N % 16 == 0, f"BLOCK_N ({BLOCK_N}) must be divisible by 16" + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + # ---- Bank-conflict-free LDS strides ---- + K_STRIDE = HEAD_DIM + 2 # 130 for HD=128 + VT_STRIDE = BLOCK_N + 2 # 34 for BLOCK_N=32 + + # ---- Vectorized cooperative load constants ---- + VEC_WIDTH = 8 + THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH + assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 + ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD + + assert BLOCK_M % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_Q = BLOCK_M // ROWS_PER_BATCH_LOAD + + # For KV tile (32 rows with 256 threads) + if ROWS_PER_BATCH_LOAD >= BLOCK_N: + NUM_BATCHES_KV = 1 + KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N + else: + assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD + KV_NEEDS_GUARD = False + + # LDS sizes + LDS_Q_SIZE = BLOCK_M * HEAD_DIM + LDS_KV_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE) + LDS_P_SIZE = BLOCK_M * BLOCK_N # 64*32 = 2048 + + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + class _FlashAttentionV4_2(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_v4_2_{dtype_str}" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + elem_type = T.f16() + _state["elem_type"] = elem_type + _state["lds_q"] = allocator.allocate_array(elem_type, LDS_Q_SIZE) + _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_SIZE) + _state["lds_p"] = allocator.allocate_array(elem_type, LDS_P_SIZE) + allocator.finalize() + + @flir.kernel + def flash_attention_v4_2_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + seq_len: lambda: T.index(), + ): + compute_type = T.f32() + elem_type = _state["elem_type"] + fm_fast = flir.arith.FastMathFlags.fast + + v4f16_type = ir.VectorType.get([4], elem_type) + v4f32_type = ir.VectorType.get([4], compute_type) + v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) + + seq_len_v = arith.as_value(seq_len) + + # ---- LDS views ---- + base_ptr = allocator.get_base() + lds_q = _state["lds_q"](base_ptr).get() + lds_kv = _state["lds_kv"](base_ptr).get() + lds_p = _state["lds_p"](base_ptr).get() + + # ---- Thread / block indices ---- + block_id = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + # ---- Wave decomposition ---- + c_ws = flir.const_index(WARP_SIZE) + wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) + lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) + + # ---- MFMA lane decomposition ---- + c16 = flir.const_index(16) + lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result) + lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result) + + # ---- Wave offsets ---- + wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value + wave_p_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value + + # ---- Decompose block_id ---- + c_nh = flir.const_index(NUM_HEADS) + head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) + temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) + c_bm = flir.const_index(BLOCK_M) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) + q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) + batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) + q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value + + # ---- Load thread decomposition ---- + c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) + load_row_in_batch = arith.as_value( + flir.arith.DivUIOp(tid, c_tpr).result + ) + load_lane_in_row = arith.as_value( + flir.arith.RemUIOp(tid, c_tpr).result + ) + load_col_base = ( + arith.ArithValue(load_lane_in_row) * VEC_WIDTH + ).value + + # ---- Helper: global flat index ---- + def global_idx(token_idx, col): + token = ( + arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) + + arith.ArithValue(token_idx) + ) + return ( + token * STRIDE_TOKEN + + arith.ArithValue(head_idx) * HEAD_DIM + + arith.ArithValue(col) + ).value + + # ---- Cooperative Q load (64 rows, unpadded) ---- + def coop_load_q(): + for batch in range_constexpr(NUM_BATCHES_Q): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(q_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, Q, [g_idx]) + ) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(lds_row) * HEAD_DIM + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_q, [lds_idx]) + + # ---- Cooperative K load (row-major, padded stride) ---- + def coop_load_k(tile_start): + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, K, [g_idx]) + ) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + + # ---- Cooperative V load (transposed, padded stride) ---- + def coop_load_v_transposed(tile_start): + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, V, [g_idx]) + ) + load_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + # Scatter-store transposed: V[row, col+e] -> lds[(col+e)*VT_STRIDE + row] + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + + # ---- Load Q tile to LDS ---- + coop_load_q() + gpu.barrier() + + # ---- Preload Q A-operand packs into registers ---- + q_a_packs = [] + for ks in range_constexpr(K_STEPS): + q_lds_idx = ( + (arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_mod_16)) * HEAD_DIM + + ks * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + q_a_packs.append(arith.as_value( + vec_ext.load_op(v4f16_type, lds_q, [q_lds_idx]) + )) + + # ---- Constants ---- + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_sm_scale = arith.constant(sm_scale, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + c_zero_v4f32 = arith.as_value( + arith.constant_vector(0.0, v4f32_type) + ) + + # ---- Init loop-carried state ---- + # m[4], l[4], o_accs[K_STEPS] + init_args = [] + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_neg_inf)) + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_zero_f)) + for _ in range_constexpr(K_STEPS): + init_args.append(c_zero_v4f32) + + # ---- KV loop upper bound ---- + # Causal early-exit: last Q row = q_start + BLOCK_M - 1, + # so only need KV positions 0 .. q_start + BLOCK_M - 1. + # q_start + BLOCK_M is always a multiple of BLOCK_N (64 % 32 == 0). + if CAUSAL: + kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value + else: + kv_upper = seq_len_v + + # ---- KV loop (step BLOCK_N=32) ---- + with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: + kv_start = arith.as_value(loop.induction_variable) + m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)] + l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)] + o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)] + + # ==== Cooperative K load -> LDS_KV (32 rows, padded stride) ==== + coop_load_k(kv_start) + gpu.barrier() + + # ==== Q @ K^T via MFMA -> S[16, 32] ==== + # Two MFMA outputs: s_acc[0] for KV cols 0..15, s_acc[1] for KV cols 16..31 + s_accs = [c_zero_v4f32, c_zero_v4f32] + for ks in range_constexpr(K_STEPS): + a_pack = q_a_packs[ks] + for nm in range_constexpr(N_MFMA): + # B operand (K^T): K row = nm*16 + lane_mod_16 + k_row = nm * 16 + k_lds_idx = ( + (arith.ArithValue(lane_mod_16) + k_row) * K_STRIDE + + ks * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + b_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx]) + ) + s_accs[nm] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [a_pack, b_pack, s_accs[nm], 0, 0, 0] + ) + ) + + # ==== Online softmax over 32 positions ==== + # For each row ii (0..3): have values at lane_mod_16 in s_accs[0] and s_accs[1] + # Need max and sum over all 32 positions + s_vals_lo = [] # from s_accs[0], KV cols 0..15 + s_vals_hi = [] # from s_accs[1], KV cols 16..31 + for ii in range_constexpr(4): + s_lo = arith.as_value( + vec_ext.extract(s_accs[0], static_position=[ii], dynamic_position=[]) + ) + s_lo = arith.as_value( + flir.arith.MulFOp(s_lo, arith.as_value(c_sm_scale), fastmath=fm_fast).result + ) + s_hi = arith.as_value( + vec_ext.extract(s_accs[1], static_position=[ii], dynamic_position=[]) + ) + s_hi = arith.as_value( + flir.arith.MulFOp(s_hi, arith.as_value(c_sm_scale), fastmath=fm_fast).result + ) + + if CAUSAL: + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + # Low half: KV col = kv_start + lane_mod_16 + kv_col_lo = (arith.ArithValue(kv_start) + arith.ArithValue(lane_mod_16)).value + # High half: KV col = kv_start + 16 + lane_mod_16 + kv_col_hi = (arith.ArithValue(kv_start) + 16 + arith.ArithValue(lane_mod_16)).value + + q_row_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), q_row).result) + kv_lo_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col_lo).result) + kv_hi_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col_hi).result) + + is_masked_lo = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_lo_i64, q_row_i64, + ).result + ) + is_masked_hi = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_hi_i64, q_row_i64, + ).result + ) + s_lo = arith.as_value( + flir.arith.SelectOp(is_masked_lo, arith.as_value(c_neg_inf), s_lo).result + ) + s_hi = arith.as_value( + flir.arith.SelectOp(is_masked_hi, arith.as_value(c_neg_inf), s_hi).result + ) + + s_vals_lo.append(s_lo) + s_vals_hi.append(s_hi) + + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + m_new = [None] * 4 + corr = [None] * 4 + p_vals_lo = [None] * 4 + p_vals_hi = [None] * 4 + l_new = [None] * 4 + + for ii in range_constexpr(4): + # Max over 32 positions: max of lo-half and hi-half + row_max_lo = s_vals_lo[ii] + row_max_hi = s_vals_hi[ii] + + # Reduce lo-half within 16 lanes + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_max_lo, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_max_lo = arith.as_value( + flir.arith.MaximumFOp(row_max_lo, peer).result + ) + + # Reduce hi-half within 16 lanes + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_max_hi, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_max_hi = arith.as_value( + flir.arith.MaximumFOp(row_max_hi, peer).result + ) + + # Combine lo and hi maxes + row_max = arith.as_value( + flir.arith.MaximumFOp(row_max_lo, row_max_hi).result + ) + + m_new[ii] = arith.as_value( + flir.arith.MaximumFOp(m_old[ii], row_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( + flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) + + # exp2 for both halves + diff_lo = arith.as_value( + flir.arith.SubFOp(s_vals_lo[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_lo_s = arith.as_value( + flir.arith.MulFOp(diff_lo, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + p_vals_lo[ii] = arith.as_value(flir.math.exp2(diff_lo_s, fastmath=fm_fast)) + + diff_hi = arith.as_value( + flir.arith.SubFOp(s_vals_hi[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_hi_s = arith.as_value( + flir.arith.MulFOp(diff_hi, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + p_vals_hi[ii] = arith.as_value(flir.math.exp2(diff_hi_s, fastmath=fm_fast)) + + # Sum over 32 positions + row_sum_lo = p_vals_lo[ii] + row_sum_hi = p_vals_hi[ii] + + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_sum_lo, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_sum_lo = arith.as_value( + flir.arith.AddFOp(row_sum_lo, peer, fastmath=fm_fast).result + ) + + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_sum_hi, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_sum_hi = arith.as_value( + flir.arith.AddFOp(row_sum_hi, peer, fastmath=fm_fast).result + ) + + row_sum = arith.as_value( + flir.arith.AddFOp(row_sum_lo, row_sum_hi, fastmath=fm_fast).result + ) + + l_corr = arith.as_value( + flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result + ) + l_new[ii] = arith.as_value( + flir.arith.AddFOp(l_corr, row_sum, fastmath=fm_fast).result + ) + + # ==== Rescale O accumulators ==== + corr_vec = arith.as_value( + vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) + ) + for ds in range_constexpr(K_STEPS): + o_accs[ds] = arith.as_value( + flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result + ) + + # ==== P store to LDS_P ==== + # P is [16, 32] per wave. Two 16x16 blocks: lo (cols 0..15) and hi (cols 16..31) + for ii in range_constexpr(4): + p_lo_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, p_vals_lo[ii]).result + ) + p_hi_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, p_vals_hi[ii]).result + ) + p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value + # Lo: cols 0..15 + p_lds_lo = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(p_row) * BLOCK_N + + arith.ArithValue(lane_mod_16) + ).value + _memref.StoreOp(p_lo_f16, lds_p, [p_lds_lo]) + # Hi: cols 16..31 + p_lds_hi = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(p_row) * BLOCK_N + + 16 + + arith.ArithValue(lane_mod_16) + ).value + _memref.StoreOp(p_hi_f16, lds_p, [p_lds_hi]) + + # ==== Barrier: ensure all waves done reading K ==== + gpu.barrier() + + # ==== Cooperative V load (transposed) ==== + coop_load_v_transposed(kv_start) + gpu.barrier() + + # ==== P @ V via MFMA ==== + # P[16, 32] @ V[32, 16chunk] = O[16, 16chunk] + # Split K=32 into two halves: P_lo[16,16] @ V_top[16,16] + P_hi[16,16] @ V_bot[16,16] + + # Load P A-operand packs: P_lo and P_hi + p_a_lo_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(lane_mod_16) * BLOCK_N + + arith.ArithValue(lane_div_16) * 4 + ).value + p_pack_lo = arith.as_value( + vec_ext.load_op(v4f16_type, lds_p, [p_a_lo_idx]) + ) + + p_a_hi_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(lane_mod_16) * BLOCK_N + + 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + p_pack_hi = arith.as_value( + vec_ext.load_op(v4f16_type, lds_p, [p_a_hi_idx]) + ) + + for ds in range_constexpr(K_STEPS): + # V_top: V rows 0..15, B-operand from transposed LDS + v_top_idx = ( + (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE + + arith.ArithValue(lane_div_16) * 4 + ).value + v_top = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [v_top_idx]) + ) + # Accumulate P_lo @ V_top + o_accs[ds] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [p_pack_lo, v_top, o_accs[ds], 0, 0, 0] + ) + ) + + # V_bot: V rows 16..31, B-operand from transposed LDS + v_bot_idx = ( + (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE + + 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + v_bot = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [v_bot_idx]) + ) + # Accumulate P_hi @ V_bot + o_accs[ds] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [p_pack_hi, v_bot, o_accs[ds], 0, 0, 0] + ) + ) + + # ==== Barrier: ensure all waves done reading V ==== + gpu.barrier() + + # ==== Yield ==== + yield_args = m_new + l_new + o_accs + scf_yield(yield_args) + + # ---- Normalize and store O ---- + m_finals = [arith.as_value(loop.results[i]) for i in range(4)] + l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] + o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] + + for ds in range_constexpr(K_STEPS): + for ii in range_constexpr(4): + o_val = arith.as_value( + vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) + ) + o_norm = arith.as_value( + flir.arith.DivFOp(o_val, l_finals[ii], fastmath=fm_fast).result + ) + o_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, o_norm).result + ) + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value + o_global = global_idx(q_row, d_col) + _memref.StoreOp(o_f16, O, [o_global]) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + batch_size: lambda: T.index(), + seq_len: lambda: T.index(), + ): + c1 = arith.as_value(flir.arith_ext.index(1)) + c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) + c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) + bs_val = arith.as_value(batch_size) + sl_val = arith.as_value(seq_len) + num_q_tiles = arith.as_value( + flir.arith.DivUIOp(sl_val, c_bm).result + ) + bs_qt = arith.as_value( + flir.arith.MulIOp(bs_val, num_q_tiles).result + ) + grid_x = arith.as_value( + flir.arith.MulIOp(bs_qt, c_nh).result + ) + bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) + flir.gpu_ext.LaunchFuncOp( + [self.GPU_MODULE_NAME, KERNEL_NAME], + grid_size=(grid_x, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, seq_len], + ) + + return _FlashAttentionV4_2() diff --git a/kernels/flash_attention_v4_3.py b/kernels/flash_attention_v4_3.py new file mode 100644 index 00000000..bdfe26f9 --- /dev/null +++ b/kernels/flash_attention_v4_3.py @@ -0,0 +1,650 @@ +"""Flash Attention V4.3 kernel builder for FlyDSL. + +V4.3 optimization over V4.2: +- Q loaded directly from global memory to MFMA registers (no Q in LDS). + LDS = KV(8.5KB) + P(4KB) = 12.5KB (was 29KB in V4.2). + This enables 4 workgroups/CU -> 4 waves/SIMD (was 2 waves/SIMD). +- Eliminates 2 barriers (Q store + Q preload sync). + +All other optimizations from V4.2: +- BLOCK_N=32 (vs 16): halves KV iterations and barriers +- Q@K^T produces [16,32] via two MFMA 16x16x16 in N dimension +- P@V uses K=32 via two MFMA 16x16x16 in K dimension +- Softmax over 32 positions per row (two 16-wide groups) +- V stored transposed in LDS with bank-conflict-free padding (from V4.1) +- Causal early-exit + +Tile config: BLOCK_M=64, BLOCK_N=32, 4 waves (256 threads), mfma_f32_16x16x16f16. + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). +Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. +Block: (256,) -- 4 waves of 64 on AMD (wave64). + +Requires: head_dim % 16 == 0, seq_len % 64 == 0, head_dim >= 64. +""" + +import math + +from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl +from flydsl.dialects.ext import vector as vec_ext +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext.scf import yield_ as scf_yield +from _mlir.dialects import memref as _memref +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator +from _mlir import ir +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_attention_v4_3_kernel" + + +def build_flash_attention_v4_3_module( + num_heads, + head_dim, + causal=True, + dtype_str="f16", + sm_scale=None, +): + """Build a FlyDSL Flash Attention V4.3 module (LDS overlay). + + Args: + num_heads: Number of attention heads. + head_dim: Dimension per head (must be divisible by 16, >= 64). + causal: Whether to apply causal mask. + dtype_str: "f16" (bf16 not yet supported). + sm_scale: Softmax scale (default: 1/sqrt(head_dim)). + + Returns: + MlirModule compilable via ``flydsl.compile(module)``. + """ + gpu_arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + + BLOCK_M = 64 + BLOCK_N = 32 # *** doubled from V4.1 *** + NUM_WAVES = 4 + WARP_SIZE = 64 + BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 + ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 + K_STEPS = head_dim // 16 + # Number of 16-wide MFMA columns in Q@K^T N-dimension + N_MFMA = BLOCK_N // 16 # 2 + + assert head_dim % 16 == 0, f"head_dim ({head_dim}) must be divisible by 16" + assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" + assert dtype_str == "f16", "V4.3 currently only supports f16" + assert BLOCK_N % 16 == 0, f"BLOCK_N ({BLOCK_N}) must be divisible by 16" + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + # ---- Bank-conflict-free LDS strides ---- + K_STRIDE = HEAD_DIM + 2 # 130 for HD=128 + VT_STRIDE = BLOCK_N + 2 # 34 for BLOCK_N=32 + + # ---- Vectorized cooperative load constants ---- + VEC_WIDTH = 8 + THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH + assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 + ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD + + assert BLOCK_M % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_Q = BLOCK_M // ROWS_PER_BATCH_LOAD + + # For KV tile (32 rows with 256 threads) + if ROWS_PER_BATCH_LOAD >= BLOCK_N: + NUM_BATCHES_KV = 1 + KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N + else: + assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD + KV_NEEDS_GUARD = False + + # LDS sizes (element counts, f16 = 2 bytes each) + # No Q in LDS — loaded directly from global memory to MFMA registers + LDS_KV_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE) # 4352 elements = 8704 bytes + LDS_P_SIZE = BLOCK_M * BLOCK_N # 2048 elements = 4096 bytes + + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + class _FlashAttentionV4_3(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_v4_3_{dtype_str}" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + elem_type = T.f16() + _state["elem_type"] = elem_type + _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_SIZE) + _state["lds_p"] = allocator.allocate_array(elem_type, LDS_P_SIZE) + allocator.finalize() + + @flir.kernel + def flash_attention_v4_3_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + seq_len: lambda: T.index(), + ): + compute_type = T.f32() + elem_type = _state["elem_type"] + fm_fast = flir.arith.FastMathFlags.fast + + v4f16_type = ir.VectorType.get([4], elem_type) + v4f32_type = ir.VectorType.get([4], compute_type) + v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) + + seq_len_v = arith.as_value(seq_len) + + # ---- LDS views (KV + P only, no Q in LDS) ---- + base_ptr = allocator.get_base() + lds_kv = _state["lds_kv"](base_ptr).get() + lds_p = _state["lds_p"](base_ptr).get() + + # ---- Thread / block indices ---- + block_id = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + # ---- Wave decomposition ---- + c_ws = flir.const_index(WARP_SIZE) + wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) + lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) + + # ---- MFMA lane decomposition ---- + c16 = flir.const_index(16) + lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result) + lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result) + + # ---- Wave offsets ---- + wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value + wave_p_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value + + # ---- Decompose block_id ---- + c_nh = flir.const_index(NUM_HEADS) + head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) + temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) + c_bm = flir.const_index(BLOCK_M) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) + q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) + batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) + q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value + + # ---- Load thread decomposition ---- + c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) + load_row_in_batch = arith.as_value( + flir.arith.DivUIOp(tid, c_tpr).result + ) + load_lane_in_row = arith.as_value( + flir.arith.RemUIOp(tid, c_tpr).result + ) + load_col_base = ( + arith.ArithValue(load_lane_in_row) * VEC_WIDTH + ).value + + # ---- Helper: global flat index ---- + def global_idx(token_idx, col): + token = ( + arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) + + arith.ArithValue(token_idx) + ) + return ( + token * STRIDE_TOKEN + + arith.ArithValue(head_idx) * HEAD_DIM + + arith.ArithValue(col) + ).value + + # ---- Cooperative K load (row-major, padded stride) ---- + def coop_load_k(tile_start): + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, K, [g_idx]) + ) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + + # ---- Cooperative V load (transposed, padded stride) ---- + def coop_load_v_transposed(tile_start): + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value( + vec_ext.load_op(v8f16_type, V, [g_idx]) + ) + load_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + # Scatter-store transposed: V[row, col+e] -> lds_kv[(col+e)*VT_STRIDE + row] + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + + # ---- Load Q directly from global memory to MFMA registers ---- + # Each MFMA lane (b=lane_div_16, n=lane_mod_16) loads v4f16 from + # Q[q_start + wave_offset + n, ks*16 + b*4 : ks*16 + b*4 + 4]. + # No LDS needed for Q — eliminates overlay race condition. + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_mod_16) + ).value + q_a_packs = [] + for ks in range_constexpr(K_STEPS): + q_col = flir.const_index(ks * 16 + 0) + q_col = (arith.ArithValue(q_col) + arith.ArithValue(lane_div_16) * 4).value + g_idx = global_idx(q_row, q_col) + q_a_packs.append(arith.as_value( + vec_ext.load_op(v4f16_type, Q, [g_idx]) + )) + + # ---- Constants ---- + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_sm_scale = arith.constant(sm_scale, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + c_zero_v4f32 = arith.as_value( + arith.constant_vector(0.0, v4f32_type) + ) + + # ---- Init loop-carried state ---- + # m[4], l[4], o_accs[K_STEPS] + init_args = [] + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_neg_inf)) + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_zero_f)) + for _ in range_constexpr(K_STEPS): + init_args.append(c_zero_v4f32) + + # ---- KV loop upper bound ---- + # Causal early-exit: last Q row = q_start + BLOCK_M - 1, + # so only need KV positions 0 .. q_start + BLOCK_M - 1. + # q_start + BLOCK_M is always a multiple of BLOCK_N (64 % 32 == 0). + if CAUSAL: + kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value + else: + kv_upper = seq_len_v + + # ---- KV loop (step BLOCK_N=32) ---- + with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: + kv_start = arith.as_value(loop.induction_variable) + m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)] + l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)] + o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)] + + # ==== Cooperative K load -> LDS_KV (32 rows, padded stride) ==== + coop_load_k(kv_start) + gpu.barrier() + + # ==== Q @ K^T via MFMA -> S[16, 32] ==== + # Two MFMA outputs: s_acc[0] for KV cols 0..15, s_acc[1] for KV cols 16..31 + s_accs = [c_zero_v4f32, c_zero_v4f32] + for ks in range_constexpr(K_STEPS): + a_pack = q_a_packs[ks] + for nm in range_constexpr(N_MFMA): + # B operand (K^T): K row = nm*16 + lane_mod_16 + k_row = nm * 16 + k_lds_idx = ( + (arith.ArithValue(lane_mod_16) + k_row) * K_STRIDE + + ks * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + b_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx]) + ) + s_accs[nm] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [a_pack, b_pack, s_accs[nm], 0, 0, 0] + ) + ) + + # ==== Online softmax over 32 positions ==== + # For each row ii (0..3): have values at lane_mod_16 in s_accs[0] and s_accs[1] + # Need max and sum over all 32 positions + s_vals_lo = [] # from s_accs[0], KV cols 0..15 + s_vals_hi = [] # from s_accs[1], KV cols 16..31 + for ii in range_constexpr(4): + s_lo = arith.as_value( + vec_ext.extract(s_accs[0], static_position=[ii], dynamic_position=[]) + ) + s_lo = arith.as_value( + flir.arith.MulFOp(s_lo, arith.as_value(c_sm_scale), fastmath=fm_fast).result + ) + s_hi = arith.as_value( + vec_ext.extract(s_accs[1], static_position=[ii], dynamic_position=[]) + ) + s_hi = arith.as_value( + flir.arith.MulFOp(s_hi, arith.as_value(c_sm_scale), fastmath=fm_fast).result + ) + + if CAUSAL: + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + # Low half: KV col = kv_start + lane_mod_16 + kv_col_lo = (arith.ArithValue(kv_start) + arith.ArithValue(lane_mod_16)).value + # High half: KV col = kv_start + 16 + lane_mod_16 + kv_col_hi = (arith.ArithValue(kv_start) + 16 + arith.ArithValue(lane_mod_16)).value + + q_row_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), q_row).result) + kv_lo_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col_lo).result) + kv_hi_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col_hi).result) + + is_masked_lo = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_lo_i64, q_row_i64, + ).result + ) + is_masked_hi = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_hi_i64, q_row_i64, + ).result + ) + s_lo = arith.as_value( + flir.arith.SelectOp(is_masked_lo, arith.as_value(c_neg_inf), s_lo).result + ) + s_hi = arith.as_value( + flir.arith.SelectOp(is_masked_hi, arith.as_value(c_neg_inf), s_hi).result + ) + + s_vals_lo.append(s_lo) + s_vals_hi.append(s_hi) + + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + m_new = [None] * 4 + corr = [None] * 4 + p_vals_lo = [None] * 4 + p_vals_hi = [None] * 4 + l_new = [None] * 4 + + for ii in range_constexpr(4): + # Max over 32 positions: max of lo-half and hi-half + row_max_lo = s_vals_lo[ii] + row_max_hi = s_vals_hi[ii] + + # Reduce lo-half within 16 lanes + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_max_lo, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_max_lo = arith.as_value( + flir.arith.MaximumFOp(row_max_lo, peer).result + ) + + # Reduce hi-half within 16 lanes + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_max_hi, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_max_hi = arith.as_value( + flir.arith.MaximumFOp(row_max_hi, peer).result + ) + + # Combine lo and hi maxes + row_max = arith.as_value( + flir.arith.MaximumFOp(row_max_lo, row_max_hi).result + ) + + m_new[ii] = arith.as_value( + flir.arith.MaximumFOp(m_old[ii], row_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( + flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) + + # exp2 for both halves + diff_lo = arith.as_value( + flir.arith.SubFOp(s_vals_lo[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_lo_s = arith.as_value( + flir.arith.MulFOp(diff_lo, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + p_vals_lo[ii] = arith.as_value(flir.math.exp2(diff_lo_s, fastmath=fm_fast)) + + diff_hi = arith.as_value( + flir.arith.SubFOp(s_vals_hi[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_hi_s = arith.as_value( + flir.arith.MulFOp(diff_hi, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + p_vals_hi[ii] = arith.as_value(flir.math.exp2(diff_hi_s, fastmath=fm_fast)) + + # Sum over 32 positions + row_sum_lo = p_vals_lo[ii] + row_sum_hi = p_vals_hi[ii] + + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_sum_lo, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_sum_lo = arith.as_value( + flir.arith.AddFOp(row_sum_lo, peer, fastmath=fm_fast).result + ) + + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_sum_hi, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_sum_hi = arith.as_value( + flir.arith.AddFOp(row_sum_hi, peer, fastmath=fm_fast).result + ) + + row_sum = arith.as_value( + flir.arith.AddFOp(row_sum_lo, row_sum_hi, fastmath=fm_fast).result + ) + + l_corr = arith.as_value( + flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result + ) + l_new[ii] = arith.as_value( + flir.arith.AddFOp(l_corr, row_sum, fastmath=fm_fast).result + ) + + # ==== Rescale O accumulators ==== + corr_vec = arith.as_value( + vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) + ) + for ds in range_constexpr(K_STEPS): + o_accs[ds] = arith.as_value( + flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result + ) + + # ==== P store to LDS_P ==== + # P is [16, 32] per wave. Two 16x16 blocks: lo (cols 0..15) and hi (cols 16..31) + for ii in range_constexpr(4): + p_lo_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, p_vals_lo[ii]).result + ) + p_hi_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, p_vals_hi[ii]).result + ) + p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value + # Lo: cols 0..15 + p_lds_lo = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(p_row) * BLOCK_N + + arith.ArithValue(lane_mod_16) + ).value + _memref.StoreOp(p_lo_f16, lds_p, [p_lds_lo]) + # Hi: cols 16..31 + p_lds_hi = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(p_row) * BLOCK_N + + 16 + + arith.ArithValue(lane_mod_16) + ).value + _memref.StoreOp(p_hi_f16, lds_p, [p_lds_hi]) + + # ==== Barrier: ensure all waves done reading K ==== + gpu.barrier() + + # ==== Cooperative V load (transposed) ==== + coop_load_v_transposed(kv_start) + gpu.barrier() + + # ==== P @ V via MFMA ==== + # P[16, 32] @ V[32, 16chunk] = O[16, 16chunk] + # Split K=32 into two halves: P_lo[16,16] @ V_top[16,16] + P_hi[16,16] @ V_bot[16,16] + + # Load P A-operand packs: P_lo and P_hi + p_a_lo_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(lane_mod_16) * BLOCK_N + + arith.ArithValue(lane_div_16) * 4 + ).value + p_pack_lo = arith.as_value( + vec_ext.load_op(v4f16_type, lds_p, [p_a_lo_idx]) + ) + + p_a_hi_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(lane_mod_16) * BLOCK_N + + 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + p_pack_hi = arith.as_value( + vec_ext.load_op(v4f16_type, lds_p, [p_a_hi_idx]) + ) + + for ds in range_constexpr(K_STEPS): + # V_top: V rows 0..15, B-operand from transposed LDS + v_top_idx = ( + (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE + + arith.ArithValue(lane_div_16) * 4 + ).value + v_top = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [v_top_idx]) + ) + # Accumulate P_lo @ V_top + o_accs[ds] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [p_pack_lo, v_top, o_accs[ds], 0, 0, 0] + ) + ) + + # V_bot: V rows 16..31, B-operand from transposed LDS + v_bot_idx = ( + (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE + + 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + v_bot = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [v_bot_idx]) + ) + # Accumulate P_hi @ V_bot + o_accs[ds] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [p_pack_hi, v_bot, o_accs[ds], 0, 0, 0] + ) + ) + + # ==== Barrier: ensure all waves done reading V ==== + gpu.barrier() + + # ==== Yield ==== + yield_args = m_new + l_new + o_accs + scf_yield(yield_args) + + # ---- Normalize and store O ---- + m_finals = [arith.as_value(loop.results[i]) for i in range(4)] + l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] + o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] + + for ds in range_constexpr(K_STEPS): + for ii in range_constexpr(4): + o_val = arith.as_value( + vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) + ) + o_norm = arith.as_value( + flir.arith.DivFOp(o_val, l_finals[ii], fastmath=fm_fast).result + ) + o_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, o_norm).result + ) + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value + o_global = global_idx(q_row, d_col) + _memref.StoreOp(o_f16, O, [o_global]) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + batch_size: lambda: T.index(), + seq_len: lambda: T.index(), + ): + c1 = arith.as_value(flir.arith_ext.index(1)) + c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) + c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) + bs_val = arith.as_value(batch_size) + sl_val = arith.as_value(seq_len) + num_q_tiles = arith.as_value( + flir.arith.DivUIOp(sl_val, c_bm).result + ) + bs_qt = arith.as_value( + flir.arith.MulIOp(bs_val, num_q_tiles).result + ) + grid_x = arith.as_value( + flir.arith.MulIOp(bs_qt, c_nh).result + ) + bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) + flir.gpu_ext.LaunchFuncOp( + [self.GPU_MODULE_NAME, KERNEL_NAME], + grid_size=(grid_x, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, seq_len], + ) + + return _FlashAttentionV4_3() diff --git a/kernels/reduce.py b/kernels/reduce.py index d8313200..c784ff40 100644 --- a/kernels/reduce.py +++ b/kernels/reduce.py @@ -8,6 +8,50 @@ from flydsl.dialects.ext.python_control_flow import lower_range_for_loops +# --------------------------------------------------------------------------- +# Single-warp (wave64) shuffle reductions +# --------------------------------------------------------------------------- + +WAVE64_OFFSETS = [32, 16, 8, 4, 2, 1] + + +def warp_reduce_sum(val, *, gpu, arith, flir, T, fm_fast, WARP_SIZE=64): + """Single-warp (wave64) sum reduction via xor shuffle. + + Returns: scalar f32 value holding the warp-wide sum. + All lanes receive the same result after the final shuffle step. + """ + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + w = arith.as_value(val) + for sh in WAVE64_OFFSETS: + off = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(w, off, width_i32, mode="xor").shuffleResult + ) + w = arith.as_value( + flir.arith.AddFOp(w, peer, fastmath=fm_fast).result + ) + return w + + +def warp_reduce_max(val, *, gpu, arith, flir, T, WARP_SIZE=64): + """Single-warp (wave64) max reduction via xor shuffle. + + Returns: scalar f32 value holding the warp-wide max. + """ + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + w = arith.as_value(val) + for sh in WAVE64_OFFSETS: + off = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(w, off, width_i32, mode="xor").shuffleResult + ) + w = arith.as_value( + flir.arith.MaximumFOp(w, peer).result + ) + return w + + def reduce_vec_max(vec_val, *, VEC_WIDTH, compute_type, vector): if VEC_WIDTH == 1: return vector.extract(vec_val, static_position=[0], dynamic_position=[]) diff --git a/run.sh b/run.sh index 34eb7ba8..24ca1dfe 100755 --- a/run.sh +++ b/run.sh @@ -33,9 +33,13 @@ function run_flydsl_op { # python tests/kernels/test_moe_stage1_simple.py --size M - python tests/kernels/test_simple_gemm.py --size XL --waves_per_eu 1 - python tests/kernels/test_simple_gemm.py --size NA4 - python tests/kernels/test_simple_gemm.py --size all --dtype all + # python tests/kernels/test_simple_gemm.py --size XL --waves_per_eu 1 + # python tests/kernels/test_simple_gemm.py --size NA4 + # python tests/kernels/test_simple_gemm.py --size all --dtype all + + # python tests/kernels/test_flash_attention_v4_2.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 + python tests/kernels/test_flash_attention_v4_3.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v42 + } diff --git a/tests/kernels/test_flash_attention_v4.py b/tests/kernels/test_flash_attention_v4.py new file mode 100644 index 00000000..9f1d52d8 --- /dev/null +++ b/tests/kernels/test_flash_attention_v4.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +"""Flash Attention V4 (Multi-Wave MFMA) kernel test and benchmark for FlyDSL. + +Tests the V4 multi-wave Flash Attention kernel against PyTorch SDPA reference. +Optionally compares performance with V3 kernels. + +Usage: + python tests/kernels/test_flash_attention_v4.py + python tests/kernels/test_flash_attention_v4.py --seq_len 512 --head_dim 128 + python tests/kernels/test_flash_attention_v4.py --compare-v3 +""" + +import sys +import os +import argparse +from pathlib import Path + +_repo = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(_repo)) + +try: + import torch + import torch.nn.functional as F +except ImportError: + print("PyTorch not available") + sys.exit(1) + +if not torch.cuda.is_available(): + print("CUDA/ROCm not available") + sys.exit(1) + +import flydsl +from kernels.flash_attention_v4 import build_flash_attention_v4_module, KERNEL_NAME + + +def pytorch_ref_attention(q, k, v, causal=True): + """PyTorch SDPA reference. q/k/v: (B, S, H, D) float32 -> (B, S, H, D).""" + q_t = q.transpose(1, 2).float() + k_t = k.transpose(1, 2).float() + v_t = v.transpose(1, 2).float() + out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal) + return out.transpose(1, 2) + + +def bench_gpu_us(fn, warmup=10, iters=50): + """Benchmark a GPU function, return average microseconds.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn() + end.record() + torch.cuda.synchronize() + return (start.elapsed_time(end) / iters) * 1000 + + +def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, + warmup, iters, v3_exe=None): + """Run one configuration. Returns dict with results.""" + device = "cuda" + results = {} + + # V4 requires seq_len divisible by BLOCK_M=64, head_dim by 16, head_dim >= 64 + if seq_len % 64 != 0: + results["err"] = f"seq_len ({seq_len}) must be divisible by 64" + return results + if head_dim % 16 != 0 or head_dim < 64: + results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" + return results + + try: + m = build_flash_attention_v4_module( + num_heads=num_heads, + head_dim=head_dim, + causal=causal, + dtype_str="f16", + ) + exe = flydsl.compile(m) + except Exception as e: + results["err"] = f"compile: {e}" + import traceback + traceback.print_exc() + return results + + B, S, H, D = batch, seq_len, num_heads, head_dim + q_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + k_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + v_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + + q_flat = q_4d.contiguous().view(-1) + k_flat = k_4d.contiguous().view(-1) + v_flat = v_4d.contiguous().view(-1) + o_flat = torch.zeros_like(q_flat) + + try: + exe(q_flat, k_flat, v_flat, o_flat, B, S) + torch.cuda.synchronize() + except Exception as e: + results["err"] = f"exec: {e}" + import traceback + traceback.print_exc() + return results + + # PyTorch reference + ref_4d = pytorch_ref_attention( + q_4d.float(), k_4d.float(), v_4d.float(), causal=causal + ).to(dtype) + ref_flat = ref_4d.contiguous().view(-1) + + # Correctness + o_f32 = o_flat.float() + ref_f32 = ref_flat.float() + max_err = (o_f32 - ref_f32).abs().max().item() + mean_err = (o_f32 - ref_f32).abs().mean().item() + cos_sim = F.cosine_similarity( + o_f32.view(-1, D), ref_f32.view(-1, D), dim=1 + ) + min_cos = cos_sim.min().item() + results["max_err"] = max_err + results["mean_err"] = mean_err + results["min_cos"] = min_cos + + atol = 1e-2 + results["passed"] = max_err < atol and min_cos > 0.99 + + # Benchmark V4 + try: + def kernel_fn(): + o_flat.zero_() + exe(q_flat, k_flat, v_flat, o_flat, B, S) + + us = bench_gpu_us(kernel_fn, warmup=warmup, iters=iters) + s_eff = S / 2.0 if causal else float(S) + flops = 4.0 * S * s_eff * D * H * B + tflops = flops / (us * 1e-6) / 1e12 + results["us"] = us + results["tflops"] = tflops + except Exception as e: + results["bench_err"] = str(e) + + # Benchmark V3 for comparison + if v3_exe is not None: + try: + o_v3 = torch.zeros_like(q_flat) + + def v3_fn(): + o_v3.zero_() + v3_exe(q_flat, k_flat, v_flat, o_v3, B, S) + + v3_us = bench_gpu_us(v3_fn, warmup=warmup, iters=iters) + v3_tflops = flops / (v3_us * 1e-6) / 1e12 + results["v3_us"] = v3_us + results["v3_tflops"] = v3_tflops + except Exception as e: + results["v3_bench_err"] = str(e) + + return results + + +def main(): + parser = argparse.ArgumentParser( + description="Flash Attention V4 (Multi-Wave MFMA) FlyDSL Test/Benchmark" + ) + parser.add_argument("--batch", type=int, default=None) + parser.add_argument("--seq_len", type=int, default=None) + parser.add_argument("--num_heads", type=int, default=None) + parser.add_argument("--head_dim", type=int, default=None) + parser.add_argument( + "--dtype", type=str, default="fp16", choices=["fp16"] + ) + parser.add_argument("--no-causal", action="store_true") + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--compare-v3", action="store_true", + help="Also benchmark V3 for comparison") + args = parser.parse_args() + + causal = not args.no_causal + dtype = torch.float16 + causal_str = "causal" if causal else "non-causal" + + print("=" * 120) + print(f"FlyDSL Flash Attention V4 Multi-Wave MFMA ({causal_str}, fp16)") + print(f" BLOCK_M=64, BLOCK_N=16, 4 waves (256 threads), mfma_f32_16x16x16f16") + print(f"GPU: {torch.cuda.get_device_name(0)}") + print("=" * 120) + + if args.seq_len or args.head_dim or args.batch: + configs = [( + args.batch or 1, + args.seq_len or 128, + args.num_heads or 8, + args.head_dim or 128, + )] + else: + configs = [ + (1, 64, 8, 128), + (1, 128, 8, 128), + (1, 256, 32, 128), + (1, 512, 32, 128), + (2, 128, 8, 128), + ] + + # Pre-compile V3 if comparing + v3_exes = {} + if args.compare_v3: + from kernels.flash_attention_v3 import build_flash_attention_v3_module + for _, _, nh, hd in configs: + key = (nh, hd) + if key not in v3_exes: + try: + m = build_flash_attention_v3_module( + num_heads=nh, head_dim=hd, + causal=causal, dtype_str="f16", + ) + v3_exes[key] = flydsl.compile(m) + except Exception: + v3_exes[key] = None + + if args.compare_v3: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'V4(us)':>10s} {'V4 TFLOPS':>9s} | " + f"{'V3(us)':>10s} {'V3 TFLOPS':>9s} | {'Speedup':>7s}" + ) + else: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" + ) + print(f"\n{hdr}") + print("-" * len(hdr)) + + all_passed = True + for batch, seq_len, nh, hd in configs: + tag = f"B={batch} S={seq_len} H={nh} D={hd}" + try: + v3_exe = v3_exes.get((nh, hd)) if args.compare_v3 else None + r = run_config( + batch, seq_len, nh, hd, dtype, causal, + warmup=args.warmup, iters=args.iters, + v3_exe=v3_exe, + ) + if "err" in r: + print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") + all_passed = False + continue + + status = "PASS" if r["passed"] else "FAIL" + if not r["passed"]: + all_passed = False + + v4_us = f"{r['us']:>10.1f}" if "us" in r else " N/A" + v4_tf = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" + + if args.compare_v3 and "v3_us" in r: + v3_us = f"{r['v3_us']:>10.1f}" + v3_tf = f"{r['v3_tflops']:>9.3f}" + speedup = r["v3_us"] / r["us"] if r.get("us") else 0 + sp_s = f"{speedup:>6.2f}x" + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{v4_us} {v4_tf} | {v3_us} {v3_tf} | {sp_s}" + ) + else: + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{v4_us} {v4_tf}" + ) + except Exception as e: + print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") + all_passed = False + + print("=" * 120) + if all_passed: + print("All tests PASSED") + else: + print("Some tests FAILED") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/kernels/test_flash_attention_v4_1.py b/tests/kernels/test_flash_attention_v4_1.py new file mode 100644 index 00000000..5763d60d --- /dev/null +++ b/tests/kernels/test_flash_attention_v4_1.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python3 +"""Flash Attention V4.1 kernel test and benchmark for FlyDSL. + +Tests V4.1 (Q-in-registers, transposed V, bank-conflict-free padding) against +PyTorch SDPA reference. Optionally compares with V4.0. + +Usage: + python tests/kernels/test_flash_attention_v4_1.py + python tests/kernels/test_flash_attention_v4_1.py --seq_len 512 --head_dim 128 + python tests/kernels/test_flash_attention_v4_1.py --compare-v4 +""" + +import sys +import os +import argparse +from pathlib import Path + +_repo = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(_repo)) + +try: + import torch + import torch.nn.functional as F +except ImportError: + print("PyTorch not available") + sys.exit(1) + +if not torch.cuda.is_available(): + print("CUDA/ROCm not available") + sys.exit(1) + +import flydsl +from kernels.flash_attention_v4_1 import build_flash_attention_v4_1_module, KERNEL_NAME + + +def pytorch_ref_attention(q, k, v, causal=True): + """PyTorch SDPA reference. q/k/v: (B, S, H, D) float32 -> (B, S, H, D).""" + q_t = q.transpose(1, 2).float() + k_t = k.transpose(1, 2).float() + v_t = v.transpose(1, 2).float() + out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal) + return out.transpose(1, 2) + + +def bench_gpu_us(fn, warmup=10, iters=50): + """Benchmark a GPU function, return average microseconds.""" + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn() + end.record() + torch.cuda.synchronize() + return (start.elapsed_time(end) / iters) * 1000 + + +def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, + warmup, iters, v4_exe=None): + """Run one configuration. Returns dict with results.""" + device = "cuda" + results = {} + + if seq_len % 64 != 0: + results["err"] = f"seq_len ({seq_len}) must be divisible by 64" + return results + if head_dim % 16 != 0 or head_dim < 64: + results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" + return results + + try: + m = build_flash_attention_v4_1_module( + num_heads=num_heads, + head_dim=head_dim, + causal=causal, + dtype_str="f16", + ) + exe = flydsl.compile(m) + except Exception as e: + results["err"] = f"compile: {e}" + import traceback + traceback.print_exc() + return results + + B, S, H, D = batch, seq_len, num_heads, head_dim + q_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + k_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + v_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + + q_flat = q_4d.contiguous().view(-1) + k_flat = k_4d.contiguous().view(-1) + v_flat = v_4d.contiguous().view(-1) + o_flat = torch.zeros_like(q_flat) + + try: + exe(q_flat, k_flat, v_flat, o_flat, B, S) + torch.cuda.synchronize() + except Exception as e: + results["err"] = f"exec: {e}" + import traceback + traceback.print_exc() + return results + + # PyTorch reference + ref_4d = pytorch_ref_attention( + q_4d.float(), k_4d.float(), v_4d.float(), causal=causal + ).to(dtype) + ref_flat = ref_4d.contiguous().view(-1) + + # Correctness + o_f32 = o_flat.float() + ref_f32 = ref_flat.float() + max_err = (o_f32 - ref_f32).abs().max().item() + mean_err = (o_f32 - ref_f32).abs().mean().item() + cos_sim = F.cosine_similarity( + o_f32.view(-1, D), ref_f32.view(-1, D), dim=1 + ) + min_cos = cos_sim.min().item() + results["max_err"] = max_err + results["mean_err"] = mean_err + results["min_cos"] = min_cos + + atol = 1e-2 + results["passed"] = max_err < atol and min_cos > 0.99 + + # Benchmark V4.1 + try: + def kernel_fn(): + o_flat.zero_() + exe(q_flat, k_flat, v_flat, o_flat, B, S) + + us = bench_gpu_us(kernel_fn, warmup=warmup, iters=iters) + s_eff = S / 2.0 if causal else float(S) + flops = 4.0 * S * s_eff * D * H * B + tflops = flops / (us * 1e-6) / 1e12 + results["us"] = us + results["tflops"] = tflops + except Exception as e: + results["bench_err"] = str(e) + + # Benchmark V4.0 for comparison + if v4_exe is not None: + try: + o_v4 = torch.zeros_like(q_flat) + + def v4_fn(): + o_v4.zero_() + v4_exe(q_flat, k_flat, v_flat, o_v4, B, S) + + v4_us = bench_gpu_us(v4_fn, warmup=warmup, iters=iters) + v4_tflops = flops / (v4_us * 1e-6) / 1e12 + results["v4_us"] = v4_us + results["v4_tflops"] = v4_tflops + except Exception as e: + results["v4_bench_err"] = str(e) + + return results + + +def main(): + parser = argparse.ArgumentParser( + description="Flash Attention V4.1 FlyDSL Test/Benchmark" + ) + parser.add_argument("--batch", type=int, default=None) + parser.add_argument("--seq_len", type=int, default=None) + parser.add_argument("--num_heads", type=int, default=None) + parser.add_argument("--head_dim", type=int, default=None) + parser.add_argument( + "--dtype", type=str, default="fp16", choices=["fp16"] + ) + parser.add_argument("--no-causal", action="store_true") + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--compare-v4", action="store_true", + help="Also benchmark V4.0 for comparison") + args = parser.parse_args() + + causal = not args.no_causal + dtype = torch.float16 + causal_str = "causal" if causal else "non-causal" + + print("=" * 130) + print(f"FlyDSL Flash Attention V4.1 ({causal_str}, fp16)") + print(f" Q-in-registers, transposed V (vectorized), bank-conflict-free LDS padding") + print(f" BLOCK_M=64, BLOCK_N=16, 4 waves (256 threads), mfma_f32_16x16x16f16") + print(f"GPU: {torch.cuda.get_device_name(0)}") + print("=" * 130) + + if args.seq_len or args.head_dim or args.batch: + configs = [( + args.batch or 1, + args.seq_len or 128, + args.num_heads or 8, + args.head_dim or 128, + )] + else: + configs = [ + (1, 64, 8, 128), + (1, 128, 8, 128), + (1, 256, 32, 128), + (1, 512, 32, 128), + (2, 128, 8, 128), + ] + + # Pre-compile V4.0 if comparing + v4_exes = {} + if args.compare_v4: + from kernels.flash_attention_v4 import build_flash_attention_v4_module + for _, _, nh, hd in configs: + key = (nh, hd) + if key not in v4_exes: + try: + m = build_flash_attention_v4_module( + num_heads=nh, head_dim=hd, + causal=causal, dtype_str="f16", + ) + v4_exes[key] = flydsl.compile(m) + except Exception: + v4_exes[key] = None + + if args.compare_v4: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'V4.1(us)':>10s} {'V4.1 TF':>9s} | " + f"{'V4.0(us)':>10s} {'V4.0 TF':>9s} | {'Speedup':>7s}" + ) + else: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" + ) + print(f"\n{hdr}") + print("-" * len(hdr)) + + all_passed = True + for batch, seq_len, nh, hd in configs: + tag = f"B={batch} S={seq_len} H={nh} D={hd}" + try: + v4_exe = v4_exes.get((nh, hd)) if args.compare_v4 else None + r = run_config( + batch, seq_len, nh, hd, dtype, causal, + warmup=args.warmup, iters=args.iters, + v4_exe=v4_exe, + ) + if "err" in r: + print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") + all_passed = False + continue + + status = "PASS" if r["passed"] else "FAIL" + if not r["passed"]: + all_passed = False + + v41_us = f"{r['us']:>10.1f}" if "us" in r else " N/A" + v41_tf = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" + + if args.compare_v4 and "v4_us" in r: + v4_us = f"{r['v4_us']:>10.1f}" + v4_tf = f"{r['v4_tflops']:>9.3f}" + speedup = r["v4_us"] / r["us"] if r.get("us") else 0 + sp_s = f"{speedup:>6.2f}x" + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{v41_us} {v41_tf} | {v4_us} {v4_tf} | {sp_s}" + ) + else: + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{v41_us} {v41_tf}" + ) + except Exception as e: + print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") + all_passed = False + + print("=" * 130) + if all_passed: + print("All tests PASSED") + else: + print("Some tests FAILED") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/kernels/test_flash_attention_v4_2.py b/tests/kernels/test_flash_attention_v4_2.py new file mode 100644 index 00000000..debba7a0 --- /dev/null +++ b/tests/kernels/test_flash_attention_v4_2.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +"""Flash Attention V4.2 kernel test and benchmark for FlyDSL. + +Tests V4.2 (BLOCK_N=32, transposed V, Q-in-registers) against PyTorch SDPA. +Optionally compares with V4.1. + +Usage: + python tests/kernels/test_flash_attention_v4_2.py + python tests/kernels/test_flash_attention_v4_2.py --seq_len 512 --head_dim 128 + python tests/kernels/test_flash_attention_v4_2.py --compare-v41 +""" + +import sys +import argparse +from pathlib import Path + +_repo = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(_repo)) + +try: + import torch + import torch.nn.functional as F +except ImportError: + print("PyTorch not available") + sys.exit(1) + +if not torch.cuda.is_available(): + print("CUDA/ROCm not available") + sys.exit(1) + +import flydsl +from kernels.flash_attention_v4_2 import build_flash_attention_v4_2_module, KERNEL_NAME + + +def pytorch_ref_attention(q, k, v, causal=True): + q_t = q.transpose(1, 2).float() + k_t = k.transpose(1, 2).float() + v_t = v.transpose(1, 2).float() + out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal) + return out.transpose(1, 2) + + +def bench_gpu_us(fn, warmup=10, iters=50): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn() + end.record() + torch.cuda.synchronize() + return (start.elapsed_time(end) / iters) * 1000 + + +def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, + warmup, iters, prev_exe=None): + device = "cuda" + results = {} + + if seq_len % 64 != 0: + results["err"] = f"seq_len ({seq_len}) must be divisible by 64" + return results + if head_dim % 16 != 0 or head_dim < 64: + results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" + return results + + try: + m = build_flash_attention_v4_2_module( + num_heads=num_heads, head_dim=head_dim, + causal=causal, dtype_str="f16", + ) + exe = flydsl.compile(m) + except Exception as e: + results["err"] = f"compile: {e}" + import traceback + traceback.print_exc() + return results + + B, S, H, D = batch, seq_len, num_heads, head_dim + q_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + k_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + v_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + + q_flat = q_4d.contiguous().view(-1) + k_flat = k_4d.contiguous().view(-1) + v_flat = v_4d.contiguous().view(-1) + o_flat = torch.zeros_like(q_flat) + + try: + exe(q_flat, k_flat, v_flat, o_flat, B, S) + torch.cuda.synchronize() + except Exception as e: + results["err"] = f"exec: {e}" + import traceback + traceback.print_exc() + return results + + ref_4d = pytorch_ref_attention( + q_4d.float(), k_4d.float(), v_4d.float(), causal=causal + ).to(dtype) + ref_flat = ref_4d.contiguous().view(-1) + + o_f32 = o_flat.float() + ref_f32 = ref_flat.float() + max_err = (o_f32 - ref_f32).abs().max().item() + mean_err = (o_f32 - ref_f32).abs().mean().item() + cos_sim = F.cosine_similarity( + o_f32.view(-1, D), ref_f32.view(-1, D), dim=1 + ) + min_cos = cos_sim.min().item() + results["max_err"] = max_err + results["mean_err"] = mean_err + results["min_cos"] = min_cos + results["passed"] = max_err < 1e-2 and min_cos > 0.99 + + try: + def kernel_fn(): + o_flat.zero_() + exe(q_flat, k_flat, v_flat, o_flat, B, S) + + us = bench_gpu_us(kernel_fn, warmup=warmup, iters=iters) + s_eff = S / 2.0 if causal else float(S) + flops = 4.0 * S * s_eff * D * H * B + tflops = flops / (us * 1e-6) / 1e12 + results["us"] = us + results["tflops"] = tflops + except Exception as e: + results["bench_err"] = str(e) + + if prev_exe is not None: + try: + o_prev = torch.zeros_like(q_flat) + def prev_fn(): + o_prev.zero_() + prev_exe(q_flat, k_flat, v_flat, o_prev, B, S) + prev_us = bench_gpu_us(prev_fn, warmup=warmup, iters=iters) + prev_tflops = flops / (prev_us * 1e-6) / 1e12 + results["prev_us"] = prev_us + results["prev_tflops"] = prev_tflops + except Exception as e: + results["prev_bench_err"] = str(e) + + return results + + +def main(): + parser = argparse.ArgumentParser( + description="Flash Attention V4.2 FlyDSL Test/Benchmark" + ) + parser.add_argument("--batch", type=int, default=None) + parser.add_argument("--seq_len", type=int, default=None) + parser.add_argument("--num_heads", type=int, default=None) + parser.add_argument("--head_dim", type=int, default=None) + parser.add_argument("--no-causal", action="store_true") + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--compare-v41", action="store_true", + help="Also benchmark V4.1 for comparison") + args = parser.parse_args() + + causal = not args.no_causal + dtype = torch.float16 + + print("=" * 130) + print(f"FlyDSL Flash Attention V4.2 ({'causal' if causal else 'non-causal'}, fp16)") + print(f" BLOCK_N=32, Q-in-registers, transposed V, bank-conflict-free LDS") + print(f" BLOCK_M=64, 4 waves (256 threads), mfma_f32_16x16x16f16") + print(f"GPU: {torch.cuda.get_device_name(0)}") + print("=" * 130) + + if args.seq_len or args.head_dim or args.batch: + configs = [( + args.batch or 1, + args.seq_len or 128, + args.num_heads or 8, + args.head_dim or 128, + )] + else: + configs = [ + (1, 64, 8, 128), + (1, 128, 8, 128), + (1, 256, 32, 128), + (1, 512, 32, 128), + (2, 128, 8, 128), + ] + + prev_exes = {} + if args.compare_v41: + from kernels.flash_attention_v4_1 import build_flash_attention_v4_1_module + for _, _, nh, hd in configs: + key = (nh, hd) + if key not in prev_exes: + try: + m = build_flash_attention_v4_1_module( + num_heads=nh, head_dim=hd, + causal=causal, dtype_str="f16", + ) + prev_exes[key] = flydsl.compile(m) + except Exception: + prev_exes[key] = None + + if args.compare_v41: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'V4.2(us)':>10s} {'V4.2 TF':>9s} | " + f"{'V4.1(us)':>10s} {'V4.1 TF':>9s} | {'Speedup':>7s}" + ) + else: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" + ) + print(f"\n{hdr}") + print("-" * len(hdr)) + + all_passed = True + for batch, seq_len, nh, hd in configs: + tag = f"B={batch} S={seq_len} H={nh} D={hd}" + try: + prev_exe = prev_exes.get((nh, hd)) if args.compare_v41 else None + r = run_config( + batch, seq_len, nh, hd, dtype, causal, + warmup=args.warmup, iters=args.iters, + prev_exe=prev_exe, + ) + if "err" in r: + print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") + all_passed = False + continue + + status = "PASS" if r["passed"] else "FAIL" + if not r["passed"]: + all_passed = False + + us_s = f"{r['us']:>10.1f}" if "us" in r else " N/A" + tf_s = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" + + if args.compare_v41 and "prev_us" in r: + p_us = f"{r['prev_us']:>10.1f}" + p_tf = f"{r['prev_tflops']:>9.3f}" + speedup = r["prev_us"] / r["us"] if r.get("us") else 0 + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{us_s} {tf_s} | {p_us} {p_tf} | {speedup:>6.2f}x" + ) + else: + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{us_s} {tf_s}" + ) + except Exception as e: + print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") + all_passed = False + + print("=" * 130) + if all_passed: + print("All tests PASSED") + else: + print("Some tests FAILED") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/kernels/test_flash_attention_v4_3.py b/tests/kernels/test_flash_attention_v4_3.py new file mode 100644 index 00000000..e317b802 --- /dev/null +++ b/tests/kernels/test_flash_attention_v4_3.py @@ -0,0 +1,268 @@ +#!/usr/bin/env python3 +"""Flash Attention V4.3 kernel test and benchmark for FlyDSL. + +Tests V4.3 (LDS overlay) against PyTorch SDPA. +Optionally compares with V4.2. + +Usage: + python tests/kernels/test_flash_attention_v4_3.py + python tests/kernels/test_flash_attention_v4_3.py --seq_len 512 --head_dim 128 + python tests/kernels/test_flash_attention_v4_3.py --compare-v42 +""" + +import sys +import argparse +from pathlib import Path + +_repo = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(_repo)) + +try: + import torch + import torch.nn.functional as F +except ImportError: + print("PyTorch not available") + sys.exit(1) + +if not torch.cuda.is_available(): + print("CUDA/ROCm not available") + sys.exit(1) + +import flydsl +from kernels.flash_attention_v4_3 import build_flash_attention_v4_3_module, KERNEL_NAME + + +def pytorch_ref_attention(q, k, v, causal=True): + q_t = q.transpose(1, 2).float() + k_t = k.transpose(1, 2).float() + v_t = v.transpose(1, 2).float() + out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal) + return out.transpose(1, 2) + + +def bench_gpu_us(fn, warmup=10, iters=50): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(iters): + fn() + end.record() + torch.cuda.synchronize() + return (start.elapsed_time(end) / iters) * 1000 + + +def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, + warmup, iters, prev_exe=None): + device = "cuda" + results = {} + + if seq_len % 64 != 0: + results["err"] = f"seq_len ({seq_len}) must be divisible by 64" + return results + if head_dim % 16 != 0 or head_dim < 64: + results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" + return results + + try: + m = build_flash_attention_v4_3_module( + num_heads=num_heads, head_dim=head_dim, + causal=causal, dtype_str="f16", + ) + exe = flydsl.compile(m) + except Exception as e: + results["err"] = f"compile: {e}" + import traceback + traceback.print_exc() + return results + + B, S, H, D = batch, seq_len, num_heads, head_dim + q_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + k_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + v_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + + q_flat = q_4d.contiguous().view(-1) + k_flat = k_4d.contiguous().view(-1) + v_flat = v_4d.contiguous().view(-1) + o_flat = torch.zeros_like(q_flat) + + try: + exe(q_flat, k_flat, v_flat, o_flat, B, S) + torch.cuda.synchronize() + except Exception as e: + results["err"] = f"exec: {e}" + import traceback + traceback.print_exc() + return results + + ref_4d = pytorch_ref_attention( + q_4d.float(), k_4d.float(), v_4d.float(), causal=causal + ).to(dtype) + ref_flat = ref_4d.contiguous().view(-1) + + o_f32 = o_flat.float() + ref_f32 = ref_flat.float() + max_err = (o_f32 - ref_f32).abs().max().item() + mean_err = (o_f32 - ref_f32).abs().mean().item() + cos_sim = F.cosine_similarity( + o_f32.view(-1, D), ref_f32.view(-1, D), dim=1 + ) + min_cos = cos_sim.min().item() + results["max_err"] = max_err + results["mean_err"] = mean_err + results["min_cos"] = min_cos + results["passed"] = max_err < 1e-2 and min_cos > 0.99 + + try: + def kernel_fn(): + o_flat.zero_() + exe(q_flat, k_flat, v_flat, o_flat, B, S) + + us = bench_gpu_us(kernel_fn, warmup=warmup, iters=iters) + s_eff = S / 2.0 if causal else float(S) + flops = 4.0 * S * s_eff * D * H * B + tflops = flops / (us * 1e-6) / 1e12 + results["us"] = us + results["tflops"] = tflops + except Exception as e: + results["bench_err"] = str(e) + + if prev_exe is not None: + try: + o_prev = torch.zeros_like(q_flat) + def prev_fn(): + o_prev.zero_() + prev_exe(q_flat, k_flat, v_flat, o_prev, B, S) + prev_us = bench_gpu_us(prev_fn, warmup=warmup, iters=iters) + prev_tflops = flops / (prev_us * 1e-6) / 1e12 + results["prev_us"] = prev_us + results["prev_tflops"] = prev_tflops + except Exception as e: + results["prev_bench_err"] = str(e) + + return results + + +def main(): + parser = argparse.ArgumentParser( + description="Flash Attention V4.3 FlyDSL Test/Benchmark" + ) + parser.add_argument("--batch", type=int, default=None) + parser.add_argument("--seq_len", type=int, default=None) + parser.add_argument("--num_heads", type=int, default=None) + parser.add_argument("--head_dim", type=int, default=None) + parser.add_argument("--no-causal", action="store_true") + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--compare-v42", action="store_true", + help="Also benchmark V4.2 for comparison") + args = parser.parse_args() + + causal = not args.no_causal + dtype = torch.float16 + + print("=" * 130) + print(f"FlyDSL Flash Attention V4.3 ({'causal' if causal else 'non-causal'}, fp16)") + print(f" LDS overlay: Q space reused for KV+P (16KB vs 29KB)") + print(f" BLOCK_M=64, BLOCK_N=32, 4 waves (256 threads), mfma_f32_16x16x16f16") + print(f"GPU: {torch.cuda.get_device_name(0)}") + print("=" * 130) + + if args.seq_len or args.head_dim or args.batch: + configs = [( + args.batch or 1, + args.seq_len or 128, + args.num_heads or 8, + args.head_dim or 128, + )] + else: + configs = [ + (1, 64, 8, 128), + (1, 128, 8, 128), + (1, 256, 32, 128), + (1, 512, 32, 128), + (2, 128, 8, 128), + ] + + prev_exes = {} + if args.compare_v42: + from kernels.flash_attention_v4_2 import build_flash_attention_v4_2_module + for _, _, nh, hd in configs: + key = (nh, hd) + if key not in prev_exes: + try: + m = build_flash_attention_v4_2_module( + num_heads=nh, head_dim=hd, + causal=causal, dtype_str="f16", + ) + prev_exes[key] = flydsl.compile(m) + except Exception: + prev_exes[key] = None + + if args.compare_v42: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'V4.3(us)':>10s} {'V4.3 TF':>9s} | " + f"{'V4.2(us)':>10s} {'V4.2 TF':>9s} | {'Speedup':>7s}" + ) + else: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" + ) + print(f"\n{hdr}") + print("-" * len(hdr)) + + all_passed = True + for batch, seq_len, nh, hd in configs: + tag = f"B={batch} S={seq_len} H={nh} D={hd}" + try: + prev_exe = prev_exes.get((nh, hd)) if args.compare_v42 else None + r = run_config( + batch, seq_len, nh, hd, dtype, causal, + warmup=args.warmup, iters=args.iters, + prev_exe=prev_exe, + ) + if "err" in r: + print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") + all_passed = False + continue + + status = "PASS" if r["passed"] else "FAIL" + if not r["passed"]: + all_passed = False + + us_s = f"{r['us']:>10.1f}" if "us" in r else " N/A" + tf_s = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" + + if args.compare_v42 and "prev_us" in r: + p_us = f"{r['prev_us']:>10.1f}" + p_tf = f"{r['prev_tflops']:>9.3f}" + speedup = r["prev_us"] / r["us"] if r.get("us") else 0 + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{us_s} {tf_s} | {p_us} {p_tf} | {speedup:>6.2f}x" + ) + else: + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{us_s} {tf_s}" + ) + except Exception as e: + print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") + all_passed = False + + print("=" * 130) + if all_passed: + print("All tests PASSED") + else: + print("Some tests FAILED") + sys.exit(1) + + +if __name__ == "__main__": + main() From 9390ffa99d067dfc95f76f795b7796ad2b71bae0 Mon Sep 17 00:00:00 2001 From: yanguahe Date: Thu, 12 Feb 2026 20:41:07 +0800 Subject: [PATCH 06/17] Add flash_attention_v4_4_kernel and it's test --- kernels/flash_attention_v4_4.py | 583 +++++++++++++++++++++ run.sh | 3 +- tests/kernels/test_flash_attention_v4_2.py | 164 +++++- tests/kernels/test_flash_attention_v4_3.py | 168 +++++- tests/kernels/test_flash_attention_v4_4.py | 392 ++++++++++++++ 5 files changed, 1269 insertions(+), 41 deletions(-) create mode 100644 kernels/flash_attention_v4_4.py create mode 100644 tests/kernels/test_flash_attention_v4_4.py diff --git a/kernels/flash_attention_v4_4.py b/kernels/flash_attention_v4_4.py new file mode 100644 index 00000000..7f76d8ae --- /dev/null +++ b/kernels/flash_attention_v4_4.py @@ -0,0 +1,583 @@ +"""Flash Attention V4.4 kernel builder for FlyDSL. + +V4.4 optimization over V4.3 (CK-aligned design): +- BLOCK_N=64 (vs 32): halves KV loop iterations. +- K loaded in chunks of kK0=32 along head_dim (K0_LOOPS inner iterations). + K_STRIDE = kK0 + 2 = 34 (was HEAD_DIM + 2 = 130). +- V loaded in chunks of kK1=32 (K1_LOOPS inner iterations). + VT_STRIDE = kK1 + 2 = 34. +- K/V prefetching: overlaps global loads with MFMA computation. + K[k0+1] is fetched while computing with K[k0]. + V[0] is fetched during last K computation. + V[k1+1] is fetched while computing with V[k1]. +- LDS reduced: max(64*34, 128*34) = 4352 elem = 8.5KB + 8KB(P) = 16.5KB. +- Softmax over 64 positions (four 16-wide groups). +- Causal early-exit retained. + +Tile config: BLOCK_M=64, BLOCK_N=64, kK0=32, kK1=32, + 4 waves (256 threads), mfma_f32_16x16x16f16. + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). +Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. +Block: (256,) -- 4 waves of 64 on AMD (wave64). + +Requires: head_dim % 32 == 0, seq_len % 64 == 0, head_dim >= 64. +""" + +import math + +from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl +from flydsl.dialects.ext import vector as vec_ext +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext.scf import yield_ as scf_yield +from _mlir.dialects import memref as _memref +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator +from _mlir import ir +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_attention_v4_4_kernel" + + +def build_flash_attention_v4_4_module( + num_heads, + head_dim, + causal=True, + dtype_str="f16", + sm_scale=None, +): + gpu_arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + + BLOCK_M = 64 + BLOCK_N = 64 + NUM_WAVES = 4 + WARP_SIZE = 64 + BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 + ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 + K_STEPS = head_dim // 16 + + N_MFMA = BLOCK_N // 16 # 4 + + kK0 = 32 + kK1 = 32 + K0_LOOPS = head_dim // kK0 + K1_LOOPS = BLOCK_N // kK1 # 2 + K_STEPS_PER_CHUNK = kK0 // 16 # 2 + + assert head_dim % kK0 == 0 + assert BLOCK_N % kK1 == 0 + assert head_dim % 16 == 0 + assert head_dim >= 64 + assert dtype_str == "f16" + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + K_STRIDE = kK0 + 2 # 34 + VT_STRIDE = kK1 + 2 # 34 + + VEC_WIDTH = 8 + K_THREADS_PER_ROW = kK0 // VEC_WIDTH # 4 + V_THREADS_PER_ROW = HEAD_DIM // VEC_WIDTH # 16 + V_ROWS_PER_BATCH = BLOCK_SIZE // V_THREADS_PER_ROW # 16 + NUM_BATCHES_V = kK1 // V_ROWS_PER_BATCH # 2 + + LDS_KV_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE) + LDS_P_SIZE = BLOCK_M * BLOCK_N + + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + class _FlashAttentionV4_4(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_v4_4_{dtype_str}" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + elem_type = T.f16() + _state["elem_type"] = elem_type + _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_SIZE) + _state["lds_p"] = allocator.allocate_array(elem_type, LDS_P_SIZE) + allocator.finalize() + + @flir.kernel + def flash_attention_v4_4_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + seq_len: lambda: T.index(), + ): + compute_type = T.f32() + elem_type = _state["elem_type"] + fm_fast = flir.arith.FastMathFlags.fast + + v4f16_type = ir.VectorType.get([4], elem_type) + v4f32_type = ir.VectorType.get([4], compute_type) + v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) + + seq_len_v = arith.as_value(seq_len) + + base_ptr = allocator.get_base() + lds_kv = _state["lds_kv"](base_ptr).get() + lds_p = _state["lds_p"](base_ptr).get() + + block_id = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + c_ws = flir.const_index(WARP_SIZE) + wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) + lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) + + c16 = flir.const_index(16) + lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result) + lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result) + + wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value + wave_p_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value + + c_nh = flir.const_index(NUM_HEADS) + head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) + temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) + c_bm = flir.const_index(BLOCK_M) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) + q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) + batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) + q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value + + # K load decomposition (4 threads/row for kK0=32) + c_ktpr = flir.const_index(K_THREADS_PER_ROW) + k_load_row = arith.as_value(flir.arith.DivUIOp(tid, c_ktpr).result) + k_load_col_lane = arith.as_value(flir.arith.RemUIOp(tid, c_ktpr).result) + k_load_col_base = (arith.ArithValue(k_load_col_lane) * VEC_WIDTH).value + + # V load decomposition (16 threads/row for HEAD_DIM) + c_vtpr = flir.const_index(V_THREADS_PER_ROW) + v_load_row_in_batch = arith.as_value(flir.arith.DivUIOp(tid, c_vtpr).result) + v_load_col_base = ( + arith.ArithValue(flir.arith.RemUIOp(tid, c_vtpr).result) * VEC_WIDTH + ).value + + def global_idx(token_idx, col): + token = ( + arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) + + arith.ArithValue(token_idx) + ) + return ( + token * STRIDE_TOKEN + + arith.ArithValue(head_idx) * HEAD_DIM + + arith.ArithValue(col) + ).value + + # ---- Prefetch helpers: separate load-to-regs from store-to-lds ---- + + def load_k_to_regs(tile_start, k0_col_offset): + """Issue global load for K chunk → returns v8f16 register.""" + row_idx = (arith.ArithValue(tile_start) + arith.ArithValue(k_load_row)).value + col_idx = (flir.const_index(k0_col_offset) + arith.ArithValue(k_load_col_base)).value + g_idx = global_idx(row_idx, col_idx) + return arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + + def store_k_regs_to_lds(k_reg): + """Store K register data to LDS_KV.""" + lds_idx = ( + arith.ArithValue(k_load_row) * K_STRIDE + + arith.ArithValue(k_load_col_base) + ).value + vec_ext.store(k_reg, lds_kv, [lds_idx]) + + def load_v_to_regs(tile_start, k1_row_offset): + """Issue global loads for V chunk → returns list of v8f16.""" + regs = [] + for batch in range_constexpr(NUM_BATCHES_V): + row_offset = batch * V_ROWS_PER_BATCH + row_idx = ( + arith.ArithValue(tile_start) + k1_row_offset + + arith.ArithValue(v_load_row_in_batch) + row_offset + ).value + g_idx = global_idx(row_idx, v_load_col_base) + regs.append(arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx]))) + return regs + + def store_v_regs_to_lds(v_regs): + """Scatter-store V registers transposed to LDS_KV.""" + for batch_idx in range_constexpr(NUM_BATCHES_V): + vec = v_regs[batch_idx] + load_row = ( + arith.ArithValue(v_load_row_in_batch) + batch_idx * V_ROWS_PER_BATCH + ).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(v_load_col_base) + e).value + lds_idx = ( + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + + # ---- Load Q to registers (once) ---- + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_mod_16) + ).value + q_a_packs = [] + for ks in range_constexpr(K_STEPS): + q_col = flir.const_index(ks * 16) + q_col = (arith.ArithValue(q_col) + arith.ArithValue(lane_div_16) * 4).value + g_idx = global_idx(q_row, q_col) + q_a_packs.append(arith.as_value( + vec_ext.load_op(v4f16_type, Q, [g_idx]) + )) + + # ---- Constants ---- + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_sm_scale = arith.constant(sm_scale, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + c_zero_v4f32 = arith.as_value(arith.constant_vector(0.0, v4f32_type)) + + init_args = [] + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_neg_inf)) + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_zero_f)) + for _ in range_constexpr(K_STEPS): + init_args.append(c_zero_v4f32) + + if CAUSAL: + kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value + else: + kv_upper = seq_len_v + + # ================================================================ + # KV loop (step BLOCK_N=64) + # ================================================================ + with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: + kv_start = arith.as_value(loop.induction_variable) + m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)] + l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)] + o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)] + + s_accs = [c_zero_v4f32] * N_MFMA + + # ======================================================== + # QK GEMM Phase with K prefetching + # ======================================================== + # K[0]: load → store → prefetch K[1] + k_reg = load_k_to_regs(kv_start, 0) + store_k_regs_to_lds(k_reg) + k_reg = load_k_to_regs(kv_start, kK0) # Prefetch K[1] + + for k0 in range_constexpr(K0_LOOPS): + gpu.barrier() # K[k0] in LDS + + # Compute QK with K[k0] + for local_ks in range_constexpr(K_STEPS_PER_CHUNK): + global_ks = k0 * K_STEPS_PER_CHUNK + local_ks + a_pack = q_a_packs[global_ks] + for nm in range_constexpr(N_MFMA): + k_lds_idx = ( + (arith.ArithValue(lane_mod_16) + nm * 16) * K_STRIDE + + local_ks * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + b_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx]) + ) + s_accs[nm] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [a_pack, b_pack, s_accs[nm], 0, 0, 0] + ) + ) + + if k0 < K0_LOOPS - 1: + gpu.barrier() # All reads of K[k0] done + + # Store prefetched K[k0+1] to LDS + store_k_regs_to_lds(k_reg) + + # Prefetch K[k0+2] or V[0] + if k0 + 2 < K0_LOOPS: + k_reg = load_k_to_regs(kv_start, (k0 + 2) * kK0) + if k0 == K0_LOOPS - 2: + # Last K store: also prefetch V[0] + v_regs = load_v_to_regs(kv_start, 0) + + # After last QK compute: no barrier here yet + # (softmax is register-only, no LDS_KV conflict) + + # ======================================================== + # Online Softmax over 64 positions + # ======================================================== + s_vals = [[] for _ in range(N_MFMA)] + for ii in range_constexpr(4): + for nm in range_constexpr(N_MFMA): + s_val = arith.as_value( + vec_ext.extract(s_accs[nm], static_position=[ii], dynamic_position=[]) + ) + s_val = arith.as_value( + flir.arith.MulFOp(s_val, arith.as_value(c_sm_scale), fastmath=fm_fast).result + ) + if CAUSAL: + q_row_c = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + kv_col = ( + arith.ArithValue(kv_start) + nm * 16 + + arith.ArithValue(lane_mod_16) + ).value + q_row_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), q_row_c).result) + kv_col_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col).result) + is_masked = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64, + ).result + ) + s_val = arith.as_value( + flir.arith.SelectOp(is_masked, arith.as_value(c_neg_inf), s_val).result + ) + s_vals[nm].append(s_val) + + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + m_new = [None] * 4 + corr = [None] * 4 + p_vals = [[None] * 4 for _ in range(N_MFMA)] + l_new = [None] * 4 + + for ii in range_constexpr(4): + row_maxes = [] + for nm in range_constexpr(N_MFMA): + row_max_nm = s_vals[nm][ii] + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_max_nm, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_max_nm = arith.as_value( + flir.arith.MaximumFOp(row_max_nm, peer).result + ) + row_maxes.append(row_max_nm) + + combined_max = row_maxes[0] + for _g in range_constexpr(N_MFMA - 1): + combined_max = arith.as_value( + flir.arith.MaximumFOp(combined_max, row_maxes[_g + 1]).result + ) + + m_new[ii] = arith.as_value( + flir.arith.MaximumFOp(m_old[ii], combined_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( + flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) + + row_sums = [] + for nm in range_constexpr(N_MFMA): + diff = arith.as_value( + flir.arith.SubFOp(s_vals[nm][ii], m_new[ii], fastmath=fm_fast).result + ) + diff_s = arith.as_value( + flir.arith.MulFOp(diff, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + p_vals[nm][ii] = arith.as_value(flir.math.exp2(diff_s, fastmath=fm_fast)) + + row_sum_nm = p_vals[nm][ii] + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(row_sum_nm, sh_i32, width_i32, mode="xor").shuffleResult + ) + row_sum_nm = arith.as_value( + flir.arith.AddFOp(row_sum_nm, peer, fastmath=fm_fast).result + ) + row_sums.append(row_sum_nm) + + combined_sum = row_sums[0] + for _g in range_constexpr(N_MFMA - 1): + combined_sum = arith.as_value( + flir.arith.AddFOp(combined_sum, row_sums[_g + 1], fastmath=fm_fast).result + ) + + l_corr = arith.as_value( + flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result + ) + l_new[ii] = arith.as_value( + flir.arith.AddFOp(l_corr, combined_sum, fastmath=fm_fast).result + ) + + # Rescale O + corr_vec = arith.as_value( + vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) + ) + for ds in range_constexpr(K_STEPS): + o_accs[ds] = arith.as_value( + flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result + ) + + # ======================================================== + # P store to LDS_P + # ======================================================== + for ii in range_constexpr(4): + p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value + for nm in range_constexpr(N_MFMA): + p_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, p_vals[nm][ii]).result + ) + p_lds_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(p_row) * BLOCK_N + + nm * 16 + + arith.ArithValue(lane_mod_16) + ).value + _memref.StoreOp(p_f16, lds_p, [p_lds_idx]) + + # ======================================================== + # PV GEMM Phase with V prefetching + # ======================================================== + # Barrier: all K reads done + P stores visible + gpu.barrier() + + # V[0] was prefetched during K phase → store to LDS + store_v_regs_to_lds(v_regs) + # Prefetch V[1] + v_regs = load_v_to_regs(kv_start, kK1) + + for k1 in range_constexpr(K1_LOOPS): + gpu.barrier() # V[k1] in LDS + + # P A-operand + p_a_lo_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(lane_mod_16) * BLOCK_N + + k1 * kK1 + + arith.ArithValue(lane_div_16) * 4 + ).value + p_pack_lo = arith.as_value( + vec_ext.load_op(v4f16_type, lds_p, [p_a_lo_idx]) + ) + + p_a_hi_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(lane_mod_16) * BLOCK_N + + k1 * kK1 + 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + p_pack_hi = arith.as_value( + vec_ext.load_op(v4f16_type, lds_p, [p_a_hi_idx]) + ) + + for ds in range_constexpr(K_STEPS): + v_top_idx = ( + (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE + + arith.ArithValue(lane_div_16) * 4 + ).value + v_top = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [v_top_idx]) + ) + o_accs[ds] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [p_pack_lo, v_top, o_accs[ds], 0, 0, 0] + ) + ) + + v_bot_idx = ( + (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE + + 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + v_bot = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [v_bot_idx]) + ) + o_accs[ds] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [p_pack_hi, v_bot, o_accs[ds], 0, 0, 0] + ) + ) + + if k1 < K1_LOOPS - 1: + gpu.barrier() # All V[k1] reads done + store_v_regs_to_lds(v_regs) # Store prefetched V[k1+1] + + # Final barrier + gpu.barrier() + + yield_args = m_new + l_new + o_accs + scf_yield(yield_args) + + # ---- Normalize and store O ---- + m_finals = [arith.as_value(loop.results[i]) for i in range(4)] + l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] + o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] + + for ds in range_constexpr(K_STEPS): + for ii in range_constexpr(4): + o_val = arith.as_value( + vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) + ) + o_norm = arith.as_value( + flir.arith.DivFOp(o_val, l_finals[ii], fastmath=fm_fast).result + ) + o_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, o_norm).result + ) + q_row_o = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value + o_global = global_idx(q_row_o, d_col) + _memref.StoreOp(o_f16, O, [o_global]) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + batch_size: lambda: T.index(), + seq_len: lambda: T.index(), + ): + c1 = arith.as_value(flir.arith_ext.index(1)) + c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) + c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) + bs_val = arith.as_value(batch_size) + sl_val = arith.as_value(seq_len) + num_q_tiles = arith.as_value( + flir.arith.DivUIOp(sl_val, c_bm).result + ) + bs_qt = arith.as_value( + flir.arith.MulIOp(bs_val, num_q_tiles).result + ) + grid_x = arith.as_value( + flir.arith.MulIOp(bs_qt, c_nh).result + ) + bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) + flir.gpu_ext.LaunchFuncOp( + [self.GPU_MODULE_NAME, KERNEL_NAME], + grid_size=(grid_x, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, seq_len], + ) + + return _FlashAttentionV4_4() diff --git a/run.sh b/run.sh index 24ca1dfe..f7fe9529 100755 --- a/run.sh +++ b/run.sh @@ -38,7 +38,8 @@ function run_flydsl_op { # python tests/kernels/test_simple_gemm.py --size all --dtype all # python tests/kernels/test_flash_attention_v4_2.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 - python tests/kernels/test_flash_attention_v4_3.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v42 + # python tests/kernels/test_flash_attention_v4_3.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v42 + python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v43 } diff --git a/tests/kernels/test_flash_attention_v4_2.py b/tests/kernels/test_flash_attention_v4_2.py index debba7a0..5a69a7a4 100644 --- a/tests/kernels/test_flash_attention_v4_2.py +++ b/tests/kernels/test_flash_attention_v4_2.py @@ -12,7 +12,13 @@ import sys import argparse +import hashlib +import random from pathlib import Path +import logging + +# Configure logging to show INFO level messages (required for kernel name display) +logging.basicConfig(level=logging.INFO) _repo = Path(__file__).resolve().parents[2] sys.path.insert(0, str(_repo)) @@ -20,6 +26,7 @@ try: import torch import torch.nn.functional as F + import numpy as np except ImportError: print("PyTorch not available") sys.exit(1) @@ -30,6 +37,19 @@ import flydsl from kernels.flash_attention_v4_2 import build_flash_attention_v4_2_module, KERNEL_NAME +from tests.test_common import run_perftest + +# Tensor initialization range (uniform distribution) +UNIFORM_RANGE = (-1, 1) +DEFAULT_SEED = 123 + + +def setup_seed(seed: int) -> None: + """Set random seed for reproducibility across all RNG sources.""" + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True def pytorch_ref_attention(q, k, v, causal=True): @@ -40,22 +60,107 @@ def pytorch_ref_attention(q, k, v, causal=True): return out.transpose(1, 2) -def bench_gpu_us(fn, warmup=10, iters=50): - for _ in range(warmup): - fn() - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(iters): - fn() - end.record() - torch.cuda.synchronize() - return (start.elapsed_time(end) / iters) * 1000 +def compute_md5(tensor: torch.Tensor) -> str: + """Compute MD5 hash of a tensor's raw bytes.""" + return hashlib.md5( + tensor.contiguous().view(torch.uint8).detach().cpu().numpy().tobytes() + ).hexdigest() + + +def compare_arrays( + arr1: np.ndarray, + arr2: np.ndarray, + k: int = 5, + thresholds: list = None, +) -> dict: + """Compare two numpy arrays and compute various difference metrics. + + Args: + arr1: First input array (result), will be cast to float32. + arr2: Second input array (reference), will be cast to float32. + k: Number of top differences to report. + thresholds: Difference magnitude buckets for histogram. + + Returns: + Dictionary with top_k_diff, threshold_stats, nan_info, max_diff, max_diff_thr. + """ + if thresholds is None: + thresholds = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1] + + if arr1.shape != arr2.shape: + raise ValueError( + f"Shape mismatch: arr1 {arr1.shape} vs arr2 {arr2.shape}" + ) + + arr1 = arr1.astype(np.float32) + arr2 = arr2.astype(np.float32) + + result = {"top_k_diff": [], "threshold_stats": [], "nan_info": {}} + + # Check for NaN values + nan_mask1 = np.isnan(arr1) + nan_mask2 = np.isnan(arr2) + if np.any(nan_mask1): + result["nan_info"]["arr1_nan_count"] = int(np.sum(nan_mask1)) + print(f" Warning: result contains {result['nan_info']['arr1_nan_count']} NaN values") + if np.any(nan_mask2): + result["nan_info"]["arr2_nan_count"] = int(np.sum(nan_mask2)) + print(f" Warning: reference contains {result['nan_info']['arr2_nan_count']} NaN values") + + # Compute absolute differences + diff = np.abs(arr1 - arr2) + total_elements = arr1.size + + max_diff_thr = (diff / (1.0 + np.abs(arr2))).max() + result["max_diff"] = float(diff.max()) + result["max_diff_thr"] = float(max_diff_thr) + + print(f" diff.abs.max = {diff.max():.6f}") + print(f" diff.abs.mean = {diff.mean():.6f}") + print(f" max_diff_thr (rel) = {max_diff_thr:.6e}") + + # Find top k differences + flat_diff = diff.flatten() + actual_k = min(k, len(flat_diff)) + top_k_indices = np.argpartition(flat_diff, -actual_k)[-actual_k:] + top_k_indices = top_k_indices[np.argsort(-flat_diff[top_k_indices])] + + orig_indices = np.unravel_index(top_k_indices, diff.shape) + print(f" Top-{actual_k} differences:") + for i in range(actual_k): + idx = tuple(dim[i] for dim in orig_indices) + entry = { + "value": float(diff[idx]), + "position": idx, + "arr1_value": float(arr1[idx]), + "arr2_value": float(arr2[idx]), + } + result["top_k_diff"].append(entry) + print(f" [{idx}] result={arr1[idx]:.6f}, ref={arr2[idx]:.6f}, diff={diff[idx]:.6f}") + + # Compute threshold statistics + print(f" Threshold distribution ({total_elements} elements):") + for i in range(len(thresholds) - 1): + lower, upper = thresholds[i], thresholds[i + 1] + count = int(np.sum((diff >= lower) & (diff < upper))) + pct = 100.0 * count / total_elements + result["threshold_stats"].append( + {"range": f"[{lower:.0e}, {upper:.0e})", "count": count, "percentage": pct} + ) + print(f" [{lower:.0e}, {upper:.0e}): {count:>8d} ({pct:6.2f}%)") + + count = int(np.sum(diff >= thresholds[-1])) + pct = 100.0 * count / total_elements + result["threshold_stats"].append( + {"range": f">={thresholds[-1]:.0e}", "count": count, "percentage": pct} + ) + print(f" >={thresholds[-1]:.0e} : {count:>8d} ({pct:6.2f}%)") + + return result def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, - warmup, iters, prev_exe=None): + warmup, iters, prev_exe=None, seed=DEFAULT_SEED): device = "cuda" results = {} @@ -79,9 +184,10 @@ def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, return results B, S, H, D = batch, seq_len, num_heads, head_dim - q_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) - k_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) - v_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + setup_seed(seed) + q_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) + k_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) + v_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) q_flat = q_4d.contiguous().view(-1) k_flat = k_4d.contiguous().view(-1) @@ -115,12 +221,30 @@ def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, results["min_cos"] = min_cos results["passed"] = max_err < 1e-2 and min_cos > 0.99 + # Compute and print MD5 hashes + tag = f"B={B} S={S} H={H} D={D}" + result_md5 = compute_md5(o_flat) + ref_md5 = compute_md5(ref_flat) + print(f" [{tag}] result_md5 = {result_md5}") + print(f" [{tag}] ref_md5 = {ref_md5}") + if result_md5 == ref_md5: + print(f" [{tag}] MD5 match: EXACT (bit-identical)") + else: + print(f" [{tag}] MD5 match: DIFFER (not bit-identical)") + + # Detailed comparison using compare_arrays + print(f" [{tag}] --- compare_arrays ---") + compare_arrays( + o_flat.to(torch.float32).detach().cpu().numpy(), + ref_flat.to(torch.float32).detach().cpu().numpy(), + ) + try: def kernel_fn(): o_flat.zero_() exe(q_flat, k_flat, v_flat, o_flat, B, S) - us = bench_gpu_us(kernel_fn, warmup=warmup, iters=iters) + _, us = run_perftest(kernel_fn, num_iters=iters, num_warmup=warmup) s_eff = S / 2.0 if causal else float(S) flops = 4.0 * S * s_eff * D * H * B tflops = flops / (us * 1e-6) / 1e12 @@ -135,7 +259,7 @@ def kernel_fn(): def prev_fn(): o_prev.zero_() prev_exe(q_flat, k_flat, v_flat, o_prev, B, S) - prev_us = bench_gpu_us(prev_fn, warmup=warmup, iters=iters) + _, prev_us = run_perftest(prev_fn, num_iters=iters, num_warmup=warmup) prev_tflops = flops / (prev_us * 1e-6) / 1e12 results["prev_us"] = prev_us results["prev_tflops"] = prev_tflops @@ -158,6 +282,8 @@ def main(): parser.add_argument("--iters", type=int, default=20) parser.add_argument("--compare-v41", action="store_true", help="Also benchmark V4.1 for comparison") + parser.add_argument("--seed", type=int, default=DEFAULT_SEED, + help=f"Random seed for reproducibility (default: {DEFAULT_SEED})") args = parser.parse_args() causal = not args.no_causal @@ -223,7 +349,7 @@ def main(): r = run_config( batch, seq_len, nh, hd, dtype, causal, warmup=args.warmup, iters=args.iters, - prev_exe=prev_exe, + prev_exe=prev_exe, seed=args.seed, ) if "err" in r: print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") diff --git a/tests/kernels/test_flash_attention_v4_3.py b/tests/kernels/test_flash_attention_v4_3.py index e317b802..28cd1e97 100644 --- a/tests/kernels/test_flash_attention_v4_3.py +++ b/tests/kernels/test_flash_attention_v4_3.py @@ -12,7 +12,13 @@ import sys import argparse +import hashlib +import random from pathlib import Path +import logging + +# Configure logging to show INFO level messages (required for kernel name display) +logging.basicConfig(level=logging.INFO) _repo = Path(__file__).resolve().parents[2] sys.path.insert(0, str(_repo)) @@ -20,6 +26,7 @@ try: import torch import torch.nn.functional as F + import numpy as np except ImportError: print("PyTorch not available") sys.exit(1) @@ -30,6 +37,19 @@ import flydsl from kernels.flash_attention_v4_3 import build_flash_attention_v4_3_module, KERNEL_NAME +from tests.test_common import run_perftest + +# Tensor initialization range (uniform distribution) +UNIFORM_RANGE = (-1, 1) +DEFAULT_SEED = 123 + + +def setup_seed(seed: int) -> None: + """Set random seed for reproducibility across all RNG sources.""" + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True def pytorch_ref_attention(q, k, v, causal=True): @@ -40,22 +60,107 @@ def pytorch_ref_attention(q, k, v, causal=True): return out.transpose(1, 2) -def bench_gpu_us(fn, warmup=10, iters=50): - for _ in range(warmup): - fn() - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(iters): - fn() - end.record() - torch.cuda.synchronize() - return (start.elapsed_time(end) / iters) * 1000 +def compute_md5(tensor: torch.Tensor) -> str: + """Compute MD5 hash of a tensor's raw bytes.""" + return hashlib.md5( + tensor.contiguous().view(torch.uint8).detach().cpu().numpy().tobytes() + ).hexdigest() + + +def compare_arrays( + arr1: np.ndarray, + arr2: np.ndarray, + k: int = 5, + thresholds: list = None, +) -> dict: + """Compare two numpy arrays and compute various difference metrics. + + Args: + arr1: First input array (result), will be cast to float32. + arr2: Second input array (reference), will be cast to float32. + k: Number of top differences to report. + thresholds: Difference magnitude buckets for histogram. + + Returns: + Dictionary with top_k_diff, threshold_stats, nan_info, max_diff, max_diff_thr. + """ + if thresholds is None: + thresholds = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1] + + if arr1.shape != arr2.shape: + raise ValueError( + f"Shape mismatch: arr1 {arr1.shape} vs arr2 {arr2.shape}" + ) + + arr1 = arr1.astype(np.float32) + arr2 = arr2.astype(np.float32) + + result = {"top_k_diff": [], "threshold_stats": [], "nan_info": {}} + + # Check for NaN values + nan_mask1 = np.isnan(arr1) + nan_mask2 = np.isnan(arr2) + if np.any(nan_mask1): + result["nan_info"]["arr1_nan_count"] = int(np.sum(nan_mask1)) + print(f" Warning: result contains {result['nan_info']['arr1_nan_count']} NaN values") + if np.any(nan_mask2): + result["nan_info"]["arr2_nan_count"] = int(np.sum(nan_mask2)) + print(f" Warning: reference contains {result['nan_info']['arr2_nan_count']} NaN values") + + # Compute absolute differences + diff = np.abs(arr1 - arr2) + total_elements = arr1.size + + max_diff_thr = (diff / (1.0 + np.abs(arr2))).max() + result["max_diff"] = float(diff.max()) + result["max_diff_thr"] = float(max_diff_thr) + + print(f" diff.abs.max = {diff.max():.6f}") + print(f" diff.abs.mean = {diff.mean():.6f}") + print(f" max_diff_thr (rel) = {max_diff_thr:.6e}") + + # Find top k differences + flat_diff = diff.flatten() + actual_k = min(k, len(flat_diff)) + top_k_indices = np.argpartition(flat_diff, -actual_k)[-actual_k:] + top_k_indices = top_k_indices[np.argsort(-flat_diff[top_k_indices])] + + orig_indices = np.unravel_index(top_k_indices, diff.shape) + print(f" Top-{actual_k} differences:") + for i in range(actual_k): + idx = tuple(dim[i] for dim in orig_indices) + entry = { + "value": float(diff[idx]), + "position": idx, + "arr1_value": float(arr1[idx]), + "arr2_value": float(arr2[idx]), + } + result["top_k_diff"].append(entry) + print(f" [{idx}] result={arr1[idx]:.6f}, ref={arr2[idx]:.6f}, diff={diff[idx]:.6f}") + + # Compute threshold statistics + print(f" Threshold distribution ({total_elements} elements):") + for i in range(len(thresholds) - 1): + lower, upper = thresholds[i], thresholds[i + 1] + count = int(np.sum((diff >= lower) & (diff < upper))) + pct = 100.0 * count / total_elements + result["threshold_stats"].append( + {"range": f"[{lower:.0e}, {upper:.0e})", "count": count, "percentage": pct} + ) + print(f" [{lower:.0e}, {upper:.0e}): {count:>8d} ({pct:6.2f}%)") + + count = int(np.sum(diff >= thresholds[-1])) + pct = 100.0 * count / total_elements + result["threshold_stats"].append( + {"range": f">={thresholds[-1]:.0e}", "count": count, "percentage": pct} + ) + print(f" >={thresholds[-1]:.0e} : {count:>8d} ({pct:6.2f}%)") + + return result def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, - warmup, iters, prev_exe=None): + warmup, iters, prev_exe=None, seed=DEFAULT_SEED): device = "cuda" results = {} @@ -79,9 +184,10 @@ def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, return results B, S, H, D = batch, seq_len, num_heads, head_dim - q_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) - k_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) - v_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) + setup_seed(seed) + q_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) + k_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) + v_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) q_flat = q_4d.contiguous().view(-1) k_flat = k_4d.contiguous().view(-1) @@ -115,12 +221,30 @@ def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, results["min_cos"] = min_cos results["passed"] = max_err < 1e-2 and min_cos > 0.99 + # Compute and print MD5 hashes + tag = f"B={B} S={S} H={H} D={D}" + result_md5 = compute_md5(o_flat) + ref_md5 = compute_md5(ref_flat) + print(f" [{tag}] result_md5 = {result_md5}") + print(f" [{tag}] ref_md5 = {ref_md5}") + if result_md5 == ref_md5: + print(f" [{tag}] MD5 match: EXACT (bit-identical)") + else: + print(f" [{tag}] MD5 match: DIFFER (not bit-identical)") + + # Detailed comparison using compare_arrays + print(f" [{tag}] --- compare_arrays ---") + compare_arrays( + o_flat.to(torch.float32).detach().cpu().numpy(), + ref_flat.to(torch.float32).detach().cpu().numpy(), + ) + try: def kernel_fn(): - o_flat.zero_() + # o_flat.zero_() exe(q_flat, k_flat, v_flat, o_flat, B, S) - us = bench_gpu_us(kernel_fn, warmup=warmup, iters=iters) + _, us = run_perftest(kernel_fn, num_iters=iters, num_warmup=warmup) s_eff = S / 2.0 if causal else float(S) flops = 4.0 * S * s_eff * D * H * B tflops = flops / (us * 1e-6) / 1e12 @@ -133,9 +257,9 @@ def kernel_fn(): try: o_prev = torch.zeros_like(q_flat) def prev_fn(): - o_prev.zero_() + # o_prev.zero_() prev_exe(q_flat, k_flat, v_flat, o_prev, B, S) - prev_us = bench_gpu_us(prev_fn, warmup=warmup, iters=iters) + _, prev_us = run_perftest(prev_fn, num_iters=iters, num_warmup=warmup) prev_tflops = flops / (prev_us * 1e-6) / 1e12 results["prev_us"] = prev_us results["prev_tflops"] = prev_tflops @@ -158,6 +282,8 @@ def main(): parser.add_argument("--iters", type=int, default=20) parser.add_argument("--compare-v42", action="store_true", help="Also benchmark V4.2 for comparison") + parser.add_argument("--seed", type=int, default=DEFAULT_SEED, + help=f"Random seed for reproducibility (default: {DEFAULT_SEED})") args = parser.parse_args() causal = not args.no_causal @@ -223,7 +349,7 @@ def main(): r = run_config( batch, seq_len, nh, hd, dtype, causal, warmup=args.warmup, iters=args.iters, - prev_exe=prev_exe, + prev_exe=prev_exe, seed=args.seed, ) if "err" in r: print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") diff --git a/tests/kernels/test_flash_attention_v4_4.py b/tests/kernels/test_flash_attention_v4_4.py new file mode 100644 index 00000000..90588ffb --- /dev/null +++ b/tests/kernels/test_flash_attention_v4_4.py @@ -0,0 +1,392 @@ +#!/usr/bin/env python3 +"""Flash Attention V4.4 kernel test and benchmark for FlyDSL. + +Tests V4.4 (CK-aligned, BLOCK_N=64) against PyTorch SDPA. +Optionally compares with V4.3. + +Usage: + python tests/kernels/test_flash_attention_v4_4.py + python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 + python tests/kernels/test_flash_attention_v4_4.py --compare-v43 +""" + +import sys +import argparse +import hashlib +import random +from pathlib import Path +import logging + +# Configure logging to show INFO level messages (required for kernel name display) +logging.basicConfig(level=logging.INFO) + +_repo = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(_repo)) + +try: + import torch + import torch.nn.functional as F + import numpy as np +except ImportError: + print("PyTorch not available") + sys.exit(1) + +if not torch.cuda.is_available(): + print("CUDA/ROCm not available") + sys.exit(1) + +import flydsl +from kernels.flash_attention_v4_4 import build_flash_attention_v4_4_module, KERNEL_NAME +from tests.test_common import run_perftest + +# Tensor initialization range (uniform distribution) +UNIFORM_RANGE = (-1, 1) +DEFAULT_SEED = 123 + + +def setup_seed(seed: int) -> None: + """Set random seed for reproducibility across all RNG sources.""" + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + + +def pytorch_ref_attention(q, k, v, causal=True): + q_t = q.transpose(1, 2).float() + k_t = k.transpose(1, 2).float() + v_t = v.transpose(1, 2).float() + out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal) + return out.transpose(1, 2) + + +def compute_md5(tensor: torch.Tensor) -> str: + """Compute MD5 hash of a tensor's raw bytes.""" + return hashlib.md5( + tensor.contiguous().view(torch.uint8).detach().cpu().numpy().tobytes() + ).hexdigest() + + +def compare_arrays( + arr1: np.ndarray, + arr2: np.ndarray, + k: int = 5, + thresholds: list = None, +) -> dict: + """Compare two numpy arrays and compute various difference metrics. + + Args: + arr1: First input array (result), will be cast to float32. + arr2: Second input array (reference), will be cast to float32. + k: Number of top differences to report. + thresholds: Difference magnitude buckets for histogram. + + Returns: + Dictionary with top_k_diff, threshold_stats, nan_info, max_diff, max_diff_thr. + """ + if thresholds is None: + thresholds = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1] + + if arr1.shape != arr2.shape: + raise ValueError( + f"Shape mismatch: arr1 {arr1.shape} vs arr2 {arr2.shape}" + ) + + arr1 = arr1.astype(np.float32) + arr2 = arr2.astype(np.float32) + + result = {"top_k_diff": [], "threshold_stats": [], "nan_info": {}} + + # Check for NaN values + nan_mask1 = np.isnan(arr1) + nan_mask2 = np.isnan(arr2) + if np.any(nan_mask1): + result["nan_info"]["arr1_nan_count"] = int(np.sum(nan_mask1)) + print(f" Warning: result contains {result['nan_info']['arr1_nan_count']} NaN values") + if np.any(nan_mask2): + result["nan_info"]["arr2_nan_count"] = int(np.sum(nan_mask2)) + print(f" Warning: reference contains {result['nan_info']['arr2_nan_count']} NaN values") + + # Compute absolute differences + diff = np.abs(arr1 - arr2) + total_elements = arr1.size + + max_diff_thr = (diff / (1.0 + np.abs(arr2))).max() + result["max_diff"] = float(diff.max()) + result["max_diff_thr"] = float(max_diff_thr) + + print(f" diff.abs.max = {diff.max():.6f}") + print(f" diff.abs.mean = {diff.mean():.6f}") + print(f" max_diff_thr (rel) = {max_diff_thr:.6e}") + + # Find top k differences + flat_diff = diff.flatten() + actual_k = min(k, len(flat_diff)) + top_k_indices = np.argpartition(flat_diff, -actual_k)[-actual_k:] + top_k_indices = top_k_indices[np.argsort(-flat_diff[top_k_indices])] + + orig_indices = np.unravel_index(top_k_indices, diff.shape) + print(f" Top-{actual_k} differences:") + for i in range(actual_k): + idx = tuple(dim[i] for dim in orig_indices) + entry = { + "value": float(diff[idx]), + "position": idx, + "arr1_value": float(arr1[idx]), + "arr2_value": float(arr2[idx]), + } + result["top_k_diff"].append(entry) + print(f" [{idx}] result={arr1[idx]:.6f}, ref={arr2[idx]:.6f}, diff={diff[idx]:.6f}") + + # Compute threshold statistics + print(f" Threshold distribution ({total_elements} elements):") + for i in range(len(thresholds) - 1): + lower, upper = thresholds[i], thresholds[i + 1] + count = int(np.sum((diff >= lower) & (diff < upper))) + pct = 100.0 * count / total_elements + result["threshold_stats"].append( + {"range": f"[{lower:.0e}, {upper:.0e})", "count": count, "percentage": pct} + ) + print(f" [{lower:.0e}, {upper:.0e}): {count:>8d} ({pct:6.2f}%)") + + count = int(np.sum(diff >= thresholds[-1])) + pct = 100.0 * count / total_elements + result["threshold_stats"].append( + {"range": f">={thresholds[-1]:.0e}", "count": count, "percentage": pct} + ) + print(f" >={thresholds[-1]:.0e} : {count:>8d} ({pct:6.2f}%)") + + return result + + +def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, + warmup, iters, prev_exe=None, seed=DEFAULT_SEED): + device = "cuda" + results = {} + + if seq_len % 64 != 0: + results["err"] = f"seq_len ({seq_len}) must be divisible by 64" + return results + if head_dim % 16 != 0 or head_dim < 64: + results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" + return results + + try: + m = build_flash_attention_v4_4_module( + num_heads=num_heads, head_dim=head_dim, + causal=causal, dtype_str="f16", + ) + exe = flydsl.compile(m) + except Exception as e: + results["err"] = f"compile: {e}" + import traceback + traceback.print_exc() + return results + + B, S, H, D = batch, seq_len, num_heads, head_dim + setup_seed(seed) + q_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) + k_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) + v_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) + + q_flat = q_4d.contiguous().view(-1) + k_flat = k_4d.contiguous().view(-1) + v_flat = v_4d.contiguous().view(-1) + o_flat = torch.zeros_like(q_flat) + + try: + exe(q_flat, k_flat, v_flat, o_flat, B, S) + torch.cuda.synchronize() + except Exception as e: + results["err"] = f"exec: {e}" + import traceback + traceback.print_exc() + return results + + ref_4d = pytorch_ref_attention( + q_4d.float(), k_4d.float(), v_4d.float(), causal=causal + ).to(dtype) + ref_flat = ref_4d.contiguous().view(-1) + + o_f32 = o_flat.float() + ref_f32 = ref_flat.float() + max_err = (o_f32 - ref_f32).abs().max().item() + mean_err = (o_f32 - ref_f32).abs().mean().item() + cos_sim = F.cosine_similarity( + o_f32.view(-1, D), ref_f32.view(-1, D), dim=1 + ) + min_cos = cos_sim.min().item() + results["max_err"] = max_err + results["mean_err"] = mean_err + results["min_cos"] = min_cos + results["passed"] = max_err < 1e-2 and min_cos > 0.99 + + # Compute and print MD5 hashes + tag = f"B={B} S={S} H={H} D={D}" + result_md5 = compute_md5(o_flat) + ref_md5 = compute_md5(ref_flat) + print(f" [{tag}] result_md5 = {result_md5}") + print(f" [{tag}] ref_md5 = {ref_md5}") + if result_md5 == ref_md5: + print(f" [{tag}] MD5 match: EXACT (bit-identical)") + else: + print(f" [{tag}] MD5 match: DIFFER (not bit-identical)") + + # Detailed comparison using compare_arrays + print(f" [{tag}] --- compare_arrays ---") + compare_arrays( + o_flat.to(torch.float32).detach().cpu().numpy(), + ref_flat.to(torch.float32).detach().cpu().numpy(), + ) + + try: + def kernel_fn(): + exe(q_flat, k_flat, v_flat, o_flat, B, S) + + _, us = run_perftest(kernel_fn, num_iters=iters, num_warmup=warmup) + s_eff = S / 2.0 if causal else float(S) + flops = 4.0 * S * s_eff * D * H * B + tflops = flops / (us * 1e-6) / 1e12 + results["us"] = us + results["tflops"] = tflops + except Exception as e: + results["bench_err"] = str(e) + + if prev_exe is not None: + try: + o_prev = torch.zeros_like(q_flat) + def prev_fn(): + prev_exe(q_flat, k_flat, v_flat, o_prev, B, S) + _, prev_us = run_perftest(prev_fn, num_iters=iters, num_warmup=warmup) + prev_tflops = flops / (prev_us * 1e-6) / 1e12 + results["prev_us"] = prev_us + results["prev_tflops"] = prev_tflops + except Exception as e: + results["prev_bench_err"] = str(e) + + return results + + +def main(): + parser = argparse.ArgumentParser( + description="Flash Attention V4.4 FlyDSL Test/Benchmark" + ) + parser.add_argument("--batch", type=int, default=None) + parser.add_argument("--seq_len", type=int, default=None) + parser.add_argument("--num_heads", type=int, default=None) + parser.add_argument("--head_dim", type=int, default=None) + parser.add_argument("--no-causal", action="store_true") + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--compare-v43", action="store_true", + help="Also benchmark V4.3 for comparison") + parser.add_argument("--seed", type=int, default=DEFAULT_SEED, + help=f"Random seed for reproducibility (default: {DEFAULT_SEED})") + args = parser.parse_args() + + causal = not args.no_causal + dtype = torch.float16 + + print("=" * 130) + print(f"FlyDSL Flash Attention V4.4 ({'causal' if causal else 'non-causal'}, fp16)") + print(f" CK-aligned: BLOCK_M=64, BLOCK_N=64, 4 waves (256 threads), mfma_f32_16x16x16f16") + print(f" Vectorized output via LDS, softmax over 64 positions") + print(f"GPU: {torch.cuda.get_device_name(0)}") + print("=" * 130) + + if args.seq_len or args.head_dim or args.batch: + configs = [( + args.batch or 1, + args.seq_len or 128, + args.num_heads or 8, + args.head_dim or 128, + )] + else: + configs = [ + (1, 64, 8, 128), + (1, 128, 8, 128), + (1, 256, 32, 128), + (1, 512, 32, 128), + (2, 128, 8, 128), + ] + + prev_exes = {} + if args.compare_v43: + from kernels.flash_attention_v4_3 import build_flash_attention_v4_3_module + for _, _, nh, hd in configs: + key = (nh, hd) + if key not in prev_exes: + try: + m = build_flash_attention_v4_3_module( + num_heads=nh, head_dim=hd, + causal=causal, dtype_str="f16", + ) + prev_exes[key] = flydsl.compile(m) + except Exception: + prev_exes[key] = None + + if args.compare_v43: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'V4.4(us)':>10s} {'V4.4 TF':>9s} | " + f"{'V4.3(us)':>10s} {'V4.3 TF':>9s} | {'Speedup':>7s}" + ) + else: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" + ) + print(f"\n{hdr}") + print("-" * len(hdr)) + + all_passed = True + for batch, seq_len, nh, hd in configs: + tag = f"B={batch} S={seq_len} H={nh} D={hd}" + try: + prev_exe = prev_exes.get((nh, hd)) if args.compare_v43 else None + r = run_config( + batch, seq_len, nh, hd, dtype, causal, + warmup=args.warmup, iters=args.iters, + prev_exe=prev_exe, seed=args.seed, + ) + if "err" in r: + print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") + all_passed = False + continue + + status = "PASS" if r["passed"] else "FAIL" + if not r["passed"]: + all_passed = False + + us_s = f"{r['us']:>10.1f}" if "us" in r else " N/A" + tf_s = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" + + if args.compare_v43 and "prev_us" in r: + p_us = f"{r['prev_us']:>10.1f}" + p_tf = f"{r['prev_tflops']:>9.3f}" + speedup = r["prev_us"] / r["us"] if r.get("us") else 0 + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{us_s} {tf_s} | {p_us} {p_tf} | {speedup:>6.2f}x" + ) + else: + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{us_s} {tf_s}" + ) + except Exception as e: + print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") + all_passed = False + + print("=" * 130) + if all_passed: + print("All tests PASSED") + else: + print("Some tests FAILED") + sys.exit(1) + + +if __name__ == "__main__": + main() From 9ee814c4280035249432680445202602240e65f8 Mon Sep 17 00:00:00 2001 From: yanguahe Date: Fri, 13 Feb 2026 10:32:22 +0800 Subject: [PATCH 07/17] Romve flash_attention_v4_4_kernel --- kernels/flash_attention_v4_4.py | 583 --------------------- run.sh | 4 +- tests/kernels/test_flash_attention_v4_4.py | 392 -------------- 3 files changed, 2 insertions(+), 977 deletions(-) delete mode 100644 kernels/flash_attention_v4_4.py delete mode 100644 tests/kernels/test_flash_attention_v4_4.py diff --git a/kernels/flash_attention_v4_4.py b/kernels/flash_attention_v4_4.py deleted file mode 100644 index 7f76d8ae..00000000 --- a/kernels/flash_attention_v4_4.py +++ /dev/null @@ -1,583 +0,0 @@ -"""Flash Attention V4.4 kernel builder for FlyDSL. - -V4.4 optimization over V4.3 (CK-aligned design): -- BLOCK_N=64 (vs 32): halves KV loop iterations. -- K loaded in chunks of kK0=32 along head_dim (K0_LOOPS inner iterations). - K_STRIDE = kK0 + 2 = 34 (was HEAD_DIM + 2 = 130). -- V loaded in chunks of kK1=32 (K1_LOOPS inner iterations). - VT_STRIDE = kK1 + 2 = 34. -- K/V prefetching: overlaps global loads with MFMA computation. - K[k0+1] is fetched while computing with K[k0]. - V[0] is fetched during last K computation. - V[k1+1] is fetched while computing with V[k1]. -- LDS reduced: max(64*34, 128*34) = 4352 elem = 8.5KB + 8KB(P) = 16.5KB. -- Softmax over 64 positions (four 16-wide groups). -- Causal early-exit retained. - -Tile config: BLOCK_M=64, BLOCK_N=64, kK0=32, kK1=32, - 4 waves (256 threads), mfma_f32_16x16x16f16. - -Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). -Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. -Block: (256,) -- 4 waves of 64 on AMD (wave64). - -Requires: head_dim % 32 == 0, seq_len % 64 == 0, head_dim >= 64. -""" - -import math - -from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl -from flydsl.dialects.ext import vector as vec_ext -from flydsl.dialects.ext.python_control_flow import range_constexpr -from flydsl.dialects.ext.scf import yield_ as scf_yield -from _mlir.dialects import memref as _memref -from flydsl.runtime.device import get_rocm_arch as get_hip_arch -from flydsl.utils import SmemAllocator -from _mlir import ir -import _mlir.extras.types as T - - -KERNEL_NAME = "flash_attention_v4_4_kernel" - - -def build_flash_attention_v4_4_module( - num_heads, - head_dim, - causal=True, - dtype_str="f16", - sm_scale=None, -): - gpu_arch = get_hip_arch() - DYN = ir.ShapedType.get_dynamic_size() - - BLOCK_M = 64 - BLOCK_N = 64 - NUM_WAVES = 4 - WARP_SIZE = 64 - BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 - ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 - K_STEPS = head_dim // 16 - - N_MFMA = BLOCK_N // 16 # 4 - - kK0 = 32 - kK1 = 32 - K0_LOOPS = head_dim // kK0 - K1_LOOPS = BLOCK_N // kK1 # 2 - K_STEPS_PER_CHUNK = kK0 // 16 # 2 - - assert head_dim % kK0 == 0 - assert BLOCK_N % kK1 == 0 - assert head_dim % 16 == 0 - assert head_dim >= 64 - assert dtype_str == "f16" - - if sm_scale is None: - sm_scale = 1.0 / math.sqrt(head_dim) - - NUM_HEADS = num_heads - HEAD_DIM = head_dim - CAUSAL = causal - STRIDE_TOKEN = NUM_HEADS * HEAD_DIM - - K_STRIDE = kK0 + 2 # 34 - VT_STRIDE = kK1 + 2 # 34 - - VEC_WIDTH = 8 - K_THREADS_PER_ROW = kK0 // VEC_WIDTH # 4 - V_THREADS_PER_ROW = HEAD_DIM // VEC_WIDTH # 16 - V_ROWS_PER_BATCH = BLOCK_SIZE // V_THREADS_PER_ROW # 16 - NUM_BATCHES_V = kK1 // V_ROWS_PER_BATCH # 2 - - LDS_KV_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE) - LDS_P_SIZE = BLOCK_M * BLOCK_N - - allocator = SmemAllocator(None, arch=gpu_arch) - _state = {} - - class _FlashAttentionV4_4(flir.MlirModule): - GPU_MODULE_NAME = f"flash_attn_v4_4_{dtype_str}" - GPU_MODULE_TARGETS = [f'#rocdl.target'] - - def init_gpu_module(self): - elem_type = T.f16() - _state["elem_type"] = elem_type - _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_SIZE) - _state["lds_p"] = allocator.allocate_array(elem_type, LDS_P_SIZE) - allocator.finalize() - - @flir.kernel - def flash_attention_v4_4_kernel( - self: flir.T.i64, - Q: lambda: T.memref(DYN, _state["elem_type"]), - K: lambda: T.memref(DYN, _state["elem_type"]), - V: lambda: T.memref(DYN, _state["elem_type"]), - O: lambda: T.memref(DYN, _state["elem_type"]), - seq_len: lambda: T.index(), - ): - compute_type = T.f32() - elem_type = _state["elem_type"] - fm_fast = flir.arith.FastMathFlags.fast - - v4f16_type = ir.VectorType.get([4], elem_type) - v4f32_type = ir.VectorType.get([4], compute_type) - v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) - - seq_len_v = arith.as_value(seq_len) - - base_ptr = allocator.get_base() - lds_kv = _state["lds_kv"](base_ptr).get() - lds_p = _state["lds_p"](base_ptr).get() - - block_id = flir.const_index(flir.block_idx("x")) - tid = flir.const_index(flir.thread_idx("x")) - - c_ws = flir.const_index(WARP_SIZE) - wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) - lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) - - c16 = flir.const_index(16) - lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result) - lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result) - - wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value - wave_p_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value - - c_nh = flir.const_index(NUM_HEADS) - head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) - temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) - c_bm = flir.const_index(BLOCK_M) - num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) - q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) - batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) - q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value - - # K load decomposition (4 threads/row for kK0=32) - c_ktpr = flir.const_index(K_THREADS_PER_ROW) - k_load_row = arith.as_value(flir.arith.DivUIOp(tid, c_ktpr).result) - k_load_col_lane = arith.as_value(flir.arith.RemUIOp(tid, c_ktpr).result) - k_load_col_base = (arith.ArithValue(k_load_col_lane) * VEC_WIDTH).value - - # V load decomposition (16 threads/row for HEAD_DIM) - c_vtpr = flir.const_index(V_THREADS_PER_ROW) - v_load_row_in_batch = arith.as_value(flir.arith.DivUIOp(tid, c_vtpr).result) - v_load_col_base = ( - arith.ArithValue(flir.arith.RemUIOp(tid, c_vtpr).result) * VEC_WIDTH - ).value - - def global_idx(token_idx, col): - token = ( - arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) - + arith.ArithValue(token_idx) - ) - return ( - token * STRIDE_TOKEN - + arith.ArithValue(head_idx) * HEAD_DIM - + arith.ArithValue(col) - ).value - - # ---- Prefetch helpers: separate load-to-regs from store-to-lds ---- - - def load_k_to_regs(tile_start, k0_col_offset): - """Issue global load for K chunk → returns v8f16 register.""" - row_idx = (arith.ArithValue(tile_start) + arith.ArithValue(k_load_row)).value - col_idx = (flir.const_index(k0_col_offset) + arith.ArithValue(k_load_col_base)).value - g_idx = global_idx(row_idx, col_idx) - return arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) - - def store_k_regs_to_lds(k_reg): - """Store K register data to LDS_KV.""" - lds_idx = ( - arith.ArithValue(k_load_row) * K_STRIDE - + arith.ArithValue(k_load_col_base) - ).value - vec_ext.store(k_reg, lds_kv, [lds_idx]) - - def load_v_to_regs(tile_start, k1_row_offset): - """Issue global loads for V chunk → returns list of v8f16.""" - regs = [] - for batch in range_constexpr(NUM_BATCHES_V): - row_offset = batch * V_ROWS_PER_BATCH - row_idx = ( - arith.ArithValue(tile_start) + k1_row_offset - + arith.ArithValue(v_load_row_in_batch) + row_offset - ).value - g_idx = global_idx(row_idx, v_load_col_base) - regs.append(arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx]))) - return regs - - def store_v_regs_to_lds(v_regs): - """Scatter-store V registers transposed to LDS_KV.""" - for batch_idx in range_constexpr(NUM_BATCHES_V): - vec = v_regs[batch_idx] - load_row = ( - arith.ArithValue(v_load_row_in_batch) + batch_idx * V_ROWS_PER_BATCH - ).value - for e in range_constexpr(VEC_WIDTH): - elem = arith.as_value( - vec_ext.extract(vec, static_position=[e], dynamic_position=[]) - ) - col_e = (arith.ArithValue(v_load_col_base) + e).value - lds_idx = ( - arith.ArithValue(col_e) * VT_STRIDE - + arith.ArithValue(load_row) - ).value - _memref.StoreOp(elem, lds_kv, [lds_idx]) - - # ---- Load Q to registers (once) ---- - q_row = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_mod_16) - ).value - q_a_packs = [] - for ks in range_constexpr(K_STEPS): - q_col = flir.const_index(ks * 16) - q_col = (arith.ArithValue(q_col) + arith.ArithValue(lane_div_16) * 4).value - g_idx = global_idx(q_row, q_col) - q_a_packs.append(arith.as_value( - vec_ext.load_op(v4f16_type, Q, [g_idx]) - )) - - # ---- Constants ---- - c_neg_inf = arith.constant(float("-inf"), type=compute_type) - c_zero_f = arith.constant(0.0, type=compute_type) - c_sm_scale = arith.constant(sm_scale, type=compute_type) - c_log2e = arith.constant(1.4426950408889634, type=compute_type) - c_zero_v4f32 = arith.as_value(arith.constant_vector(0.0, v4f32_type)) - - init_args = [] - for _ in range_constexpr(4): - init_args.append(arith.as_value(c_neg_inf)) - for _ in range_constexpr(4): - init_args.append(arith.as_value(c_zero_f)) - for _ in range_constexpr(K_STEPS): - init_args.append(c_zero_v4f32) - - if CAUSAL: - kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value - else: - kv_upper = seq_len_v - - # ================================================================ - # KV loop (step BLOCK_N=64) - # ================================================================ - with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: - kv_start = arith.as_value(loop.induction_variable) - m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)] - l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)] - o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)] - - s_accs = [c_zero_v4f32] * N_MFMA - - # ======================================================== - # QK GEMM Phase with K prefetching - # ======================================================== - # K[0]: load → store → prefetch K[1] - k_reg = load_k_to_regs(kv_start, 0) - store_k_regs_to_lds(k_reg) - k_reg = load_k_to_regs(kv_start, kK0) # Prefetch K[1] - - for k0 in range_constexpr(K0_LOOPS): - gpu.barrier() # K[k0] in LDS - - # Compute QK with K[k0] - for local_ks in range_constexpr(K_STEPS_PER_CHUNK): - global_ks = k0 * K_STEPS_PER_CHUNK + local_ks - a_pack = q_a_packs[global_ks] - for nm in range_constexpr(N_MFMA): - k_lds_idx = ( - (arith.ArithValue(lane_mod_16) + nm * 16) * K_STRIDE - + local_ks * 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - b_pack = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx]) - ) - s_accs[nm] = arith.as_value( - rocdl.mfma_f32_16x16x16f16( - v4f32_type, [a_pack, b_pack, s_accs[nm], 0, 0, 0] - ) - ) - - if k0 < K0_LOOPS - 1: - gpu.barrier() # All reads of K[k0] done - - # Store prefetched K[k0+1] to LDS - store_k_regs_to_lds(k_reg) - - # Prefetch K[k0+2] or V[0] - if k0 + 2 < K0_LOOPS: - k_reg = load_k_to_regs(kv_start, (k0 + 2) * kK0) - if k0 == K0_LOOPS - 2: - # Last K store: also prefetch V[0] - v_regs = load_v_to_regs(kv_start, 0) - - # After last QK compute: no barrier here yet - # (softmax is register-only, no LDS_KV conflict) - - # ======================================================== - # Online Softmax over 64 positions - # ======================================================== - s_vals = [[] for _ in range(N_MFMA)] - for ii in range_constexpr(4): - for nm in range_constexpr(N_MFMA): - s_val = arith.as_value( - vec_ext.extract(s_accs[nm], static_position=[ii], dynamic_position=[]) - ) - s_val = arith.as_value( - flir.arith.MulFOp(s_val, arith.as_value(c_sm_scale), fastmath=fm_fast).result - ) - if CAUSAL: - q_row_c = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_div_16) * 4 - + ii - ).value - kv_col = ( - arith.ArithValue(kv_start) + nm * 16 - + arith.ArithValue(lane_mod_16) - ).value - q_row_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), q_row_c).result) - kv_col_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col).result) - is_masked = arith.as_value( - flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64, - ).result - ) - s_val = arith.as_value( - flir.arith.SelectOp(is_masked, arith.as_value(c_neg_inf), s_val).result - ) - s_vals[nm].append(s_val) - - width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) - m_new = [None] * 4 - corr = [None] * 4 - p_vals = [[None] * 4 for _ in range(N_MFMA)] - l_new = [None] * 4 - - for ii in range_constexpr(4): - row_maxes = [] - for nm in range_constexpr(N_MFMA): - row_max_nm = s_vals[nm][ii] - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_max_nm, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_max_nm = arith.as_value( - flir.arith.MaximumFOp(row_max_nm, peer).result - ) - row_maxes.append(row_max_nm) - - combined_max = row_maxes[0] - for _g in range_constexpr(N_MFMA - 1): - combined_max = arith.as_value( - flir.arith.MaximumFOp(combined_max, row_maxes[_g + 1]).result - ) - - m_new[ii] = arith.as_value( - flir.arith.MaximumFOp(m_old[ii], combined_max).result - ) - - diff_m = arith.as_value( - flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result - ) - diff_m_s = arith.as_value( - flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) - - row_sums = [] - for nm in range_constexpr(N_MFMA): - diff = arith.as_value( - flir.arith.SubFOp(s_vals[nm][ii], m_new[ii], fastmath=fm_fast).result - ) - diff_s = arith.as_value( - flir.arith.MulFOp(diff, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - p_vals[nm][ii] = arith.as_value(flir.math.exp2(diff_s, fastmath=fm_fast)) - - row_sum_nm = p_vals[nm][ii] - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_sum_nm, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_sum_nm = arith.as_value( - flir.arith.AddFOp(row_sum_nm, peer, fastmath=fm_fast).result - ) - row_sums.append(row_sum_nm) - - combined_sum = row_sums[0] - for _g in range_constexpr(N_MFMA - 1): - combined_sum = arith.as_value( - flir.arith.AddFOp(combined_sum, row_sums[_g + 1], fastmath=fm_fast).result - ) - - l_corr = arith.as_value( - flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result - ) - l_new[ii] = arith.as_value( - flir.arith.AddFOp(l_corr, combined_sum, fastmath=fm_fast).result - ) - - # Rescale O - corr_vec = arith.as_value( - vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) - ) - for ds in range_constexpr(K_STEPS): - o_accs[ds] = arith.as_value( - flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result - ) - - # ======================================================== - # P store to LDS_P - # ======================================================== - for ii in range_constexpr(4): - p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value - for nm in range_constexpr(N_MFMA): - p_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, p_vals[nm][ii]).result - ) - p_lds_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(p_row) * BLOCK_N - + nm * 16 - + arith.ArithValue(lane_mod_16) - ).value - _memref.StoreOp(p_f16, lds_p, [p_lds_idx]) - - # ======================================================== - # PV GEMM Phase with V prefetching - # ======================================================== - # Barrier: all K reads done + P stores visible - gpu.barrier() - - # V[0] was prefetched during K phase → store to LDS - store_v_regs_to_lds(v_regs) - # Prefetch V[1] - v_regs = load_v_to_regs(kv_start, kK1) - - for k1 in range_constexpr(K1_LOOPS): - gpu.barrier() # V[k1] in LDS - - # P A-operand - p_a_lo_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(lane_mod_16) * BLOCK_N - + k1 * kK1 - + arith.ArithValue(lane_div_16) * 4 - ).value - p_pack_lo = arith.as_value( - vec_ext.load_op(v4f16_type, lds_p, [p_a_lo_idx]) - ) - - p_a_hi_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(lane_mod_16) * BLOCK_N - + k1 * kK1 + 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - p_pack_hi = arith.as_value( - vec_ext.load_op(v4f16_type, lds_p, [p_a_hi_idx]) - ) - - for ds in range_constexpr(K_STEPS): - v_top_idx = ( - (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE - + arith.ArithValue(lane_div_16) * 4 - ).value - v_top = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [v_top_idx]) - ) - o_accs[ds] = arith.as_value( - rocdl.mfma_f32_16x16x16f16( - v4f32_type, [p_pack_lo, v_top, o_accs[ds], 0, 0, 0] - ) - ) - - v_bot_idx = ( - (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE - + 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - v_bot = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [v_bot_idx]) - ) - o_accs[ds] = arith.as_value( - rocdl.mfma_f32_16x16x16f16( - v4f32_type, [p_pack_hi, v_bot, o_accs[ds], 0, 0, 0] - ) - ) - - if k1 < K1_LOOPS - 1: - gpu.barrier() # All V[k1] reads done - store_v_regs_to_lds(v_regs) # Store prefetched V[k1+1] - - # Final barrier - gpu.barrier() - - yield_args = m_new + l_new + o_accs - scf_yield(yield_args) - - # ---- Normalize and store O ---- - m_finals = [arith.as_value(loop.results[i]) for i in range(4)] - l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] - o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] - - for ds in range_constexpr(K_STEPS): - for ii in range_constexpr(4): - o_val = arith.as_value( - vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) - ) - o_norm = arith.as_value( - flir.arith.DivFOp(o_val, l_finals[ii], fastmath=fm_fast).result - ) - o_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, o_norm).result - ) - q_row_o = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_div_16) * 4 - + ii - ).value - d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value - o_global = global_idx(q_row_o, d_col) - _memref.StoreOp(o_f16, O, [o_global]) - - @flir.jit - def __call__( - self: flir.T.i64, - Q: lambda: T.memref(DYN, _state["elem_type"]), - K: lambda: T.memref(DYN, _state["elem_type"]), - V: lambda: T.memref(DYN, _state["elem_type"]), - O: lambda: T.memref(DYN, _state["elem_type"]), - batch_size: lambda: T.index(), - seq_len: lambda: T.index(), - ): - c1 = arith.as_value(flir.arith_ext.index(1)) - c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) - c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) - bs_val = arith.as_value(batch_size) - sl_val = arith.as_value(seq_len) - num_q_tiles = arith.as_value( - flir.arith.DivUIOp(sl_val, c_bm).result - ) - bs_qt = arith.as_value( - flir.arith.MulIOp(bs_val, num_q_tiles).result - ) - grid_x = arith.as_value( - flir.arith.MulIOp(bs_qt, c_nh).result - ) - bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) - flir.gpu_ext.LaunchFuncOp( - [self.GPU_MODULE_NAME, KERNEL_NAME], - grid_size=(grid_x, c1, c1), - block_size=(bx, c1, c1), - kernel_operands=[Q, K, V, O, seq_len], - ) - - return _FlashAttentionV4_4() diff --git a/run.sh b/run.sh index f7fe9529..c6c032cc 100755 --- a/run.sh +++ b/run.sh @@ -38,8 +38,8 @@ function run_flydsl_op { # python tests/kernels/test_simple_gemm.py --size all --dtype all # python tests/kernels/test_flash_attention_v4_2.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 - # python tests/kernels/test_flash_attention_v4_3.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v42 - python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v43 + python tests/kernels/test_flash_attention_v4_3.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v42 + # python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v43 } diff --git a/tests/kernels/test_flash_attention_v4_4.py b/tests/kernels/test_flash_attention_v4_4.py deleted file mode 100644 index 90588ffb..00000000 --- a/tests/kernels/test_flash_attention_v4_4.py +++ /dev/null @@ -1,392 +0,0 @@ -#!/usr/bin/env python3 -"""Flash Attention V4.4 kernel test and benchmark for FlyDSL. - -Tests V4.4 (CK-aligned, BLOCK_N=64) against PyTorch SDPA. -Optionally compares with V4.3. - -Usage: - python tests/kernels/test_flash_attention_v4_4.py - python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 - python tests/kernels/test_flash_attention_v4_4.py --compare-v43 -""" - -import sys -import argparse -import hashlib -import random -from pathlib import Path -import logging - -# Configure logging to show INFO level messages (required for kernel name display) -logging.basicConfig(level=logging.INFO) - -_repo = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(_repo)) - -try: - import torch - import torch.nn.functional as F - import numpy as np -except ImportError: - print("PyTorch not available") - sys.exit(1) - -if not torch.cuda.is_available(): - print("CUDA/ROCm not available") - sys.exit(1) - -import flydsl -from kernels.flash_attention_v4_4 import build_flash_attention_v4_4_module, KERNEL_NAME -from tests.test_common import run_perftest - -# Tensor initialization range (uniform distribution) -UNIFORM_RANGE = (-1, 1) -DEFAULT_SEED = 123 - - -def setup_seed(seed: int) -> None: - """Set random seed for reproducibility across all RNG sources.""" - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - - -def pytorch_ref_attention(q, k, v, causal=True): - q_t = q.transpose(1, 2).float() - k_t = k.transpose(1, 2).float() - v_t = v.transpose(1, 2).float() - out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal) - return out.transpose(1, 2) - - -def compute_md5(tensor: torch.Tensor) -> str: - """Compute MD5 hash of a tensor's raw bytes.""" - return hashlib.md5( - tensor.contiguous().view(torch.uint8).detach().cpu().numpy().tobytes() - ).hexdigest() - - -def compare_arrays( - arr1: np.ndarray, - arr2: np.ndarray, - k: int = 5, - thresholds: list = None, -) -> dict: - """Compare two numpy arrays and compute various difference metrics. - - Args: - arr1: First input array (result), will be cast to float32. - arr2: Second input array (reference), will be cast to float32. - k: Number of top differences to report. - thresholds: Difference magnitude buckets for histogram. - - Returns: - Dictionary with top_k_diff, threshold_stats, nan_info, max_diff, max_diff_thr. - """ - if thresholds is None: - thresholds = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1] - - if arr1.shape != arr2.shape: - raise ValueError( - f"Shape mismatch: arr1 {arr1.shape} vs arr2 {arr2.shape}" - ) - - arr1 = arr1.astype(np.float32) - arr2 = arr2.astype(np.float32) - - result = {"top_k_diff": [], "threshold_stats": [], "nan_info": {}} - - # Check for NaN values - nan_mask1 = np.isnan(arr1) - nan_mask2 = np.isnan(arr2) - if np.any(nan_mask1): - result["nan_info"]["arr1_nan_count"] = int(np.sum(nan_mask1)) - print(f" Warning: result contains {result['nan_info']['arr1_nan_count']} NaN values") - if np.any(nan_mask2): - result["nan_info"]["arr2_nan_count"] = int(np.sum(nan_mask2)) - print(f" Warning: reference contains {result['nan_info']['arr2_nan_count']} NaN values") - - # Compute absolute differences - diff = np.abs(arr1 - arr2) - total_elements = arr1.size - - max_diff_thr = (diff / (1.0 + np.abs(arr2))).max() - result["max_diff"] = float(diff.max()) - result["max_diff_thr"] = float(max_diff_thr) - - print(f" diff.abs.max = {diff.max():.6f}") - print(f" diff.abs.mean = {diff.mean():.6f}") - print(f" max_diff_thr (rel) = {max_diff_thr:.6e}") - - # Find top k differences - flat_diff = diff.flatten() - actual_k = min(k, len(flat_diff)) - top_k_indices = np.argpartition(flat_diff, -actual_k)[-actual_k:] - top_k_indices = top_k_indices[np.argsort(-flat_diff[top_k_indices])] - - orig_indices = np.unravel_index(top_k_indices, diff.shape) - print(f" Top-{actual_k} differences:") - for i in range(actual_k): - idx = tuple(dim[i] for dim in orig_indices) - entry = { - "value": float(diff[idx]), - "position": idx, - "arr1_value": float(arr1[idx]), - "arr2_value": float(arr2[idx]), - } - result["top_k_diff"].append(entry) - print(f" [{idx}] result={arr1[idx]:.6f}, ref={arr2[idx]:.6f}, diff={diff[idx]:.6f}") - - # Compute threshold statistics - print(f" Threshold distribution ({total_elements} elements):") - for i in range(len(thresholds) - 1): - lower, upper = thresholds[i], thresholds[i + 1] - count = int(np.sum((diff >= lower) & (diff < upper))) - pct = 100.0 * count / total_elements - result["threshold_stats"].append( - {"range": f"[{lower:.0e}, {upper:.0e})", "count": count, "percentage": pct} - ) - print(f" [{lower:.0e}, {upper:.0e}): {count:>8d} ({pct:6.2f}%)") - - count = int(np.sum(diff >= thresholds[-1])) - pct = 100.0 * count / total_elements - result["threshold_stats"].append( - {"range": f">={thresholds[-1]:.0e}", "count": count, "percentage": pct} - ) - print(f" >={thresholds[-1]:.0e} : {count:>8d} ({pct:6.2f}%)") - - return result - - -def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, - warmup, iters, prev_exe=None, seed=DEFAULT_SEED): - device = "cuda" - results = {} - - if seq_len % 64 != 0: - results["err"] = f"seq_len ({seq_len}) must be divisible by 64" - return results - if head_dim % 16 != 0 or head_dim < 64: - results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" - return results - - try: - m = build_flash_attention_v4_4_module( - num_heads=num_heads, head_dim=head_dim, - causal=causal, dtype_str="f16", - ) - exe = flydsl.compile(m) - except Exception as e: - results["err"] = f"compile: {e}" - import traceback - traceback.print_exc() - return results - - B, S, H, D = batch, seq_len, num_heads, head_dim - setup_seed(seed) - q_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) - k_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) - v_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) - - q_flat = q_4d.contiguous().view(-1) - k_flat = k_4d.contiguous().view(-1) - v_flat = v_4d.contiguous().view(-1) - o_flat = torch.zeros_like(q_flat) - - try: - exe(q_flat, k_flat, v_flat, o_flat, B, S) - torch.cuda.synchronize() - except Exception as e: - results["err"] = f"exec: {e}" - import traceback - traceback.print_exc() - return results - - ref_4d = pytorch_ref_attention( - q_4d.float(), k_4d.float(), v_4d.float(), causal=causal - ).to(dtype) - ref_flat = ref_4d.contiguous().view(-1) - - o_f32 = o_flat.float() - ref_f32 = ref_flat.float() - max_err = (o_f32 - ref_f32).abs().max().item() - mean_err = (o_f32 - ref_f32).abs().mean().item() - cos_sim = F.cosine_similarity( - o_f32.view(-1, D), ref_f32.view(-1, D), dim=1 - ) - min_cos = cos_sim.min().item() - results["max_err"] = max_err - results["mean_err"] = mean_err - results["min_cos"] = min_cos - results["passed"] = max_err < 1e-2 and min_cos > 0.99 - - # Compute and print MD5 hashes - tag = f"B={B} S={S} H={H} D={D}" - result_md5 = compute_md5(o_flat) - ref_md5 = compute_md5(ref_flat) - print(f" [{tag}] result_md5 = {result_md5}") - print(f" [{tag}] ref_md5 = {ref_md5}") - if result_md5 == ref_md5: - print(f" [{tag}] MD5 match: EXACT (bit-identical)") - else: - print(f" [{tag}] MD5 match: DIFFER (not bit-identical)") - - # Detailed comparison using compare_arrays - print(f" [{tag}] --- compare_arrays ---") - compare_arrays( - o_flat.to(torch.float32).detach().cpu().numpy(), - ref_flat.to(torch.float32).detach().cpu().numpy(), - ) - - try: - def kernel_fn(): - exe(q_flat, k_flat, v_flat, o_flat, B, S) - - _, us = run_perftest(kernel_fn, num_iters=iters, num_warmup=warmup) - s_eff = S / 2.0 if causal else float(S) - flops = 4.0 * S * s_eff * D * H * B - tflops = flops / (us * 1e-6) / 1e12 - results["us"] = us - results["tflops"] = tflops - except Exception as e: - results["bench_err"] = str(e) - - if prev_exe is not None: - try: - o_prev = torch.zeros_like(q_flat) - def prev_fn(): - prev_exe(q_flat, k_flat, v_flat, o_prev, B, S) - _, prev_us = run_perftest(prev_fn, num_iters=iters, num_warmup=warmup) - prev_tflops = flops / (prev_us * 1e-6) / 1e12 - results["prev_us"] = prev_us - results["prev_tflops"] = prev_tflops - except Exception as e: - results["prev_bench_err"] = str(e) - - return results - - -def main(): - parser = argparse.ArgumentParser( - description="Flash Attention V4.4 FlyDSL Test/Benchmark" - ) - parser.add_argument("--batch", type=int, default=None) - parser.add_argument("--seq_len", type=int, default=None) - parser.add_argument("--num_heads", type=int, default=None) - parser.add_argument("--head_dim", type=int, default=None) - parser.add_argument("--no-causal", action="store_true") - parser.add_argument("--warmup", type=int, default=5) - parser.add_argument("--iters", type=int, default=20) - parser.add_argument("--compare-v43", action="store_true", - help="Also benchmark V4.3 for comparison") - parser.add_argument("--seed", type=int, default=DEFAULT_SEED, - help=f"Random seed for reproducibility (default: {DEFAULT_SEED})") - args = parser.parse_args() - - causal = not args.no_causal - dtype = torch.float16 - - print("=" * 130) - print(f"FlyDSL Flash Attention V4.4 ({'causal' if causal else 'non-causal'}, fp16)") - print(f" CK-aligned: BLOCK_M=64, BLOCK_N=64, 4 waves (256 threads), mfma_f32_16x16x16f16") - print(f" Vectorized output via LDS, softmax over 64 positions") - print(f"GPU: {torch.cuda.get_device_name(0)}") - print("=" * 130) - - if args.seq_len or args.head_dim or args.batch: - configs = [( - args.batch or 1, - args.seq_len or 128, - args.num_heads or 8, - args.head_dim or 128, - )] - else: - configs = [ - (1, 64, 8, 128), - (1, 128, 8, 128), - (1, 256, 32, 128), - (1, 512, 32, 128), - (2, 128, 8, 128), - ] - - prev_exes = {} - if args.compare_v43: - from kernels.flash_attention_v4_3 import build_flash_attention_v4_3_module - for _, _, nh, hd in configs: - key = (nh, hd) - if key not in prev_exes: - try: - m = build_flash_attention_v4_3_module( - num_heads=nh, head_dim=hd, - causal=causal, dtype_str="f16", - ) - prev_exes[key] = flydsl.compile(m) - except Exception: - prev_exes[key] = None - - if args.compare_v43: - hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'V4.4(us)':>10s} {'V4.4 TF':>9s} | " - f"{'V4.3(us)':>10s} {'V4.3 TF':>9s} | {'Speedup':>7s}" - ) - else: - hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" - ) - print(f"\n{hdr}") - print("-" * len(hdr)) - - all_passed = True - for batch, seq_len, nh, hd in configs: - tag = f"B={batch} S={seq_len} H={nh} D={hd}" - try: - prev_exe = prev_exes.get((nh, hd)) if args.compare_v43 else None - r = run_config( - batch, seq_len, nh, hd, dtype, causal, - warmup=args.warmup, iters=args.iters, - prev_exe=prev_exe, seed=args.seed, - ) - if "err" in r: - print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") - all_passed = False - continue - - status = "PASS" if r["passed"] else "FAIL" - if not r["passed"]: - all_passed = False - - us_s = f"{r['us']:>10.1f}" if "us" in r else " N/A" - tf_s = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" - - if args.compare_v43 and "prev_us" in r: - p_us = f"{r['prev_us']:>10.1f}" - p_tf = f"{r['prev_tflops']:>9.3f}" - speedup = r["prev_us"] / r["us"] if r.get("us") else 0 - print( - f"{tag:>38s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{us_s} {tf_s} | {p_us} {p_tf} | {speedup:>6.2f}x" - ) - else: - print( - f"{tag:>38s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{us_s} {tf_s}" - ) - except Exception as e: - print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") - all_passed = False - - print("=" * 130) - if all_passed: - print("All tests PASSED") - else: - print("Some tests FAILED") - sys.exit(1) - - -if __name__ == "__main__": - main() From d41beae43b6111cecd24133e8ca3371954860045 Mon Sep 17 00:00:00 2001 From: yanguahe Date: Fri, 13 Feb 2026 12:52:28 +0800 Subject: [PATCH 08/17] Add flash_attention_v4_4_kernel and it's test --- flydsl/src/flydsl/compiler/compiler.py | 60 ++- flydsl/src/flydsl/dialects/ext/rocdl.py | 17 +- kernels/flash_attention_v4_4.py | 582 +++++++++++++++++++++ run.sh | 7 +- tests/kernels/test_flash_attention_v4_4.py | 388 ++++++++++++++ 5 files changed, 1050 insertions(+), 4 deletions(-) create mode 100644 kernels/flash_attention_v4_4.py create mode 100644 tests/kernels/test_flash_attention_v4_4.py diff --git a/flydsl/src/flydsl/compiler/compiler.py b/flydsl/src/flydsl/compiler/compiler.py index e3eb320c..6a700298 100644 --- a/flydsl/src/flydsl/compiler/compiler.py +++ b/flydsl/src/flydsl/compiler/compiler.py @@ -349,6 +349,53 @@ def _append_passthrough(func_op): pass +def _apply_flat_work_group_size_on_llvm_funcs(module: ir.Module, max_workgroup_size: int) -> None: + """Apply AMDGPU flat-work-group-size hint to GPU kernel llvm.func ops. + + LLVM expects a string value in the form "min,max". We set min=1 and max to + the requested workgroup size. + """ + attr_key = ir.StringAttr.get("amdgpu-flat-work-group-size") + attr_value = ir.StringAttr.get(f"1,{max_workgroup_size}") + new_entry = ir.ArrayAttr.get([attr_key, attr_value]) + new_entry_str = f"amdgpu-flat-work-group-size=1,{max_workgroup_size}" + + def _append_passthrough(func_op): + try: + existing = func_op.attributes["passthrough"] + except KeyError: + existing = None + + if existing is None: + func_op.attributes["passthrough"] = ir.ArrayAttr.get([new_entry]) + return + + try: + existing_entries = list(existing) + except TypeError: + func_op.attributes["passthrough"] = ir.ArrayAttr.get([new_entry]) + return + + if any(str(a).strip('"') == new_entry_str for a in existing_entries): + return + func_op.attributes["passthrough"] = ir.ArrayAttr.get(existing_entries + [new_entry]) + + try: + for op in module.body.operations: + if getattr(op, "OPERATION_NAME", None) != "gpu.module": + continue + gpu_module_body = op.regions[0].blocks[0] if hasattr(op, 'regions') else op.body + for inner_op in gpu_module_body.operations: + if getattr(inner_op, "OPERATION_NAME", None) != "llvm.func": + continue + if "gpu.kernel" not in inner_op.attributes: + continue + _append_passthrough(inner_op) + except Exception: + # Best-effort only. + pass + + def compile( flir_module_or_ir: Union[object, ir.Module], *, @@ -361,6 +408,7 @@ def compile( use_bare_pointers_for_host: bool = False, use_bare_pointers_for_kernels: bool = False, waves_per_eu: Optional[int] = None, + flat_work_group_size: Optional[int] = None, unsafe_fp_math: bool = False, fast_fp_math: bool = False, ) -> Executor: @@ -500,6 +548,9 @@ def compile( # Apply waves_per_eu if specified (BEFORE saving asm_for_isa) if waves_per_eu is not None: _apply_waves_per_eu_on_llvm_funcs(module, waves_per_eu) + # Apply flat work-group-size hint if specified. + if flat_work_group_size is not None: + _apply_flat_work_group_size_on_llvm_funcs(module, flat_work_group_size) # Apply unsafe-fp-math function attributes for fast exp2/math if unsafe_fp_math: _apply_unsafe_fp_math_on_llvm_funcs(module) @@ -521,7 +572,11 @@ def compile( isa_stage = f"{stage_num_base + len(stage_frags):02d}_final_isa" print(f"[flir.compile] dump {isa_stage} -> {isa_out}") else: - need_split = (waves_per_eu is not None) or unsafe_fp_math + need_split = ( + (waves_per_eu is not None) + or (flat_work_group_size is not None) + or unsafe_fp_math + ) if need_split: # Need to split the pipeline to apply function attributes # after LLVM lowering but before binary generation. @@ -545,6 +600,9 @@ def compile( # Apply waves_per_eu if waves_per_eu is not None: _apply_waves_per_eu_on_llvm_funcs(module, waves_per_eu) + # Apply flat work-group-size hint + if flat_work_group_size is not None: + _apply_flat_work_group_size_on_llvm_funcs(module, flat_work_group_size) # Apply unsafe-fp-math function attributes for fast exp2/math if unsafe_fp_math: _apply_unsafe_fp_math_on_llvm_funcs(module) diff --git a/flydsl/src/flydsl/dialects/ext/rocdl.py b/flydsl/src/flydsl/dialects/ext/rocdl.py index be82474d..a9356d0e 100644 --- a/flydsl/src/flydsl/dialects/ext/rocdl.py +++ b/flydsl/src/flydsl/dialects/ext/rocdl.py @@ -17,6 +17,7 @@ from _mlir.dialects.rocdl import * # noqa: F401,F403 # Keep references to ODS-generated builders so we can wrap them without losing access. +_ods_mfma_f32_32x32x8f16 = globals().get("mfma_f32_32x32x8f16", None) _ods_mfma_f32_16x16x16f16 = mfma_f32_16x16x16f16 _ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None) _ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8 @@ -68,6 +69,20 @@ def mfma_f32_16x16x16f16(result_type, operands, *, loc=None, ip=None): """Return the op result directly (no `.result` needed at call sites).""" return mfma_f32_16x16x16f16_op(result_type, operands, loc=loc, ip=ip).result + +def mfma_f32_32x32x8f16_op(result_type, operands, *, loc=None, ip=None): + """Return the op view (original behavior).""" + if _ods_mfma_f32_32x32x8f16 is None: + raise AttributeError("ROCDL op not found: mfma_f32_32x32x8f16") + ops = [_unwrap_mfma_operand(v, loc=loc) for v in operands] + return _ods_mfma_f32_32x32x8f16(result_type, ops, loc=loc, ip=ip) + + +def mfma_f32_32x32x8f16(result_type, operands, *, loc=None, ip=None): + """Return the op result directly (no `.result` needed at call sites).""" + return mfma_f32_32x32x8f16_op(result_type, operands, loc=loc, ip=ip).result + + # for bf16 version mfma def mfma_f32_16x16x16bf16_1k_op(result_type, operands, *, loc=None, ip=None): """Return the op view (original behavior).""" @@ -182,7 +197,7 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, 'mfma_i32_16x16x32_i8', 'mfma_scale_f32_16x16x128_f8f6f4', # Raw-op constructors (return op view) for the above - 'mfma_f32_16x16x16f16_op', 'mfma_f32_16x16x32_fp8_fp8_op', + 'mfma_f32_32x32x8f16_op', 'mfma_f32_16x16x16f16_op', 'mfma_f32_16x16x32_fp8_fp8_op', 'mfma_f32_16x16x16bf16_1k_op', 'mfma_i32_16x16x32_i8_op', 'mfma_scale_f32_16x16x128_f8f6f4_op', diff --git a/kernels/flash_attention_v4_4.py b/kernels/flash_attention_v4_4.py new file mode 100644 index 00000000..5298682a --- /dev/null +++ b/kernels/flash_attention_v4_4.py @@ -0,0 +1,582 @@ +"""Flash Attention V4.4 kernel builder for FlyDSL. + +V4.4 design (CK-aligned direction, rewritten from V4.3): +- CK-aligned baseline tile family: BLOCK_M=64, BLOCK_N=32. +- Q loaded once from global memory into MFMA A-operand packs (register-resident). +- K/V streamed tile-by-tile through LDS. +- Online softmax in fp32 over 32 positions per iteration (2x 16-column groups). +- Causal early-exit keeps KV upper bound at q_start + BLOCK_M. + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). +Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. +Block: (256,) -- 4 waves of 64 on AMD (wave64). + +Requires: head_dim % 16 == 0, head_dim >= 64, seq_len % 64 == 0. +""" + +import math + +from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl +from flydsl.dialects.ext import vector as vec_ext +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext.scf import yield_ as scf_yield +from _mlir.dialects import memref as _memref +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator +from _mlir import ir +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_attention_v4_4_kernel" + + +def build_flash_attention_v4_4_module( + num_heads, + head_dim, + causal=True, + dtype_str="f16", + sm_scale=None, +): + """Build a FlyDSL Flash Attention V4.4 module. + + Args: + num_heads: Number of attention heads. + head_dim: Dimension per head (must be divisible by 16, >= 64). + causal: Whether to apply causal mask. + dtype_str: "f16" (bf16 not yet supported). + sm_scale: Softmax scale (default: 1/sqrt(head_dim)). + + Returns: + MlirModule compilable via ``flydsl.compile(module)``. + """ + gpu_arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + + # CK-oriented direction for the target (B=1, H=64, S=8192, D=128). + BLOCK_M = 64 + BLOCK_N = 32 + NUM_WAVES = 4 + WARP_SIZE = 64 + BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 + ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 + K_STEPS = head_dim // 16 + N_MFMA = BLOCK_N // 16 # 2 + + assert BLOCK_M % NUM_WAVES == 0 + assert head_dim % 16 == 0, f"head_dim ({head_dim}) must be divisible by 16" + assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" + assert dtype_str == "f16", "V4.4 currently only supports f16" + assert BLOCK_N % 16 == 0, f"BLOCK_N ({BLOCK_N}) must be divisible by 16" + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + # ---- Bank-conflict-friendly LDS strides ---- + K_STRIDE = HEAD_DIM + 2 + VT_STRIDE = BLOCK_N + 2 + + # ---- Vectorized cooperative load constants ---- + VEC_WIDTH = 8 + THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH + assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 + ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD + + assert BLOCK_M % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_Q = BLOCK_M // ROWS_PER_BATCH_LOAD + + if ROWS_PER_BATCH_LOAD >= BLOCK_N: + NUM_BATCHES_KV = 1 + KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N + else: + assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD + KV_NEEDS_GUARD = False + + # LDS sizes (element counts, f16 = 2 bytes each) + # No Q in LDS: Q is read once from global memory to MFMA A packs. + LDS_KV_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE) + LDS_P_SIZE = BLOCK_M * BLOCK_N + + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + class _FlashAttentionV4_4(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_v4_4_{dtype_str}" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + elem_type = T.f16() + _state["elem_type"] = elem_type + _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_SIZE) + _state["lds_p"] = allocator.allocate_array(elem_type, LDS_P_SIZE) + allocator.finalize() + + @flir.kernel + def flash_attention_v4_4_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + seq_len: lambda: T.index(), + ): + compute_type = T.f32() + elem_type = _state["elem_type"] + fm_fast = flir.arith.FastMathFlags.fast + + v4f16_type = ir.VectorType.get([4], elem_type) + v4f32_type = ir.VectorType.get([4], compute_type) + v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) + + seq_len_v = arith.as_value(seq_len) + + # ---- LDS views (KV + P only, no Q in LDS) ---- + base_ptr = allocator.get_base() + lds_kv = _state["lds_kv"](base_ptr).get() + lds_p = _state["lds_p"](base_ptr).get() + + # ---- Thread / block indices ---- + block_id = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + # ---- Wave decomposition ---- + c_ws = flir.const_index(WARP_SIZE) + wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) + lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) + + # ---- MFMA lane decomposition ---- + c16 = flir.const_index(16) + lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result) + lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result) + + # ---- Wave offsets ---- + wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value + wave_p_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value + + # ---- Decompose block_id ---- + c_nh = flir.const_index(NUM_HEADS) + head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) + temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) + c_bm = flir.const_index(BLOCK_M) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) + q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) + batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) + q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value + + # ---- Cooperative load decomposition ---- + c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) + load_row_in_batch = arith.as_value(flir.arith.DivUIOp(tid, c_tpr).result) + load_lane_in_row = arith.as_value(flir.arith.RemUIOp(tid, c_tpr).result) + load_col_base = (arith.ArithValue(load_lane_in_row) * VEC_WIDTH).value + + # ---- Helper: global flat index ---- + def global_idx(token_idx, col): + token = ( + arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) + + arith.ArithValue(token_idx) + ) + return ( + token * STRIDE_TOKEN + + arith.ArithValue(head_idx) * HEAD_DIM + + arith.ArithValue(col) + ).value + + # ---- Cooperative K load (row-major, padded stride) ---- + def coop_load_k(tile_start): + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + lds_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + lds_idx = ( + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + + # ---- Cooperative V load (transposed, padded stride) ---- + def coop_load_v_transposed(tile_start): + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + load_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + load_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + + # ---- Load Q once from global memory to MFMA A packs ---- + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_mod_16) + ).value + q_a_packs = [] + for ks in range_constexpr(K_STEPS): + q_col = flir.const_index(ks * 16) + q_col = (arith.ArithValue(q_col) + arith.ArithValue(lane_div_16) * 4).value + g_idx = global_idx(q_row, q_col) + q_a_packs.append(arith.as_value(vec_ext.load_op(v4f16_type, Q, [g_idx]))) + + # ---- Constants ---- + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_sm_scale = arith.constant(sm_scale, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + c_zero_v4f32 = arith.as_value(arith.constant_vector(0.0, v4f32_type)) + + # ---- Init loop-carried state ---- + # m[4], l[4], o_accs[K_STEPS] + init_args = [] + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_neg_inf)) + for _ in range_constexpr(4): + init_args.append(arith.as_value(c_zero_f)) + for _ in range_constexpr(K_STEPS): + init_args.append(c_zero_v4f32) + + # ---- KV loop upper bound ---- + if CAUSAL: + kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value + else: + kv_upper = seq_len_v + + # ---- KV loop (step BLOCK_N=64) ---- + with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: + kv_start = arith.as_value(loop.induction_variable) + m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)] + l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)] + o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)] + + # ==== Cooperative K load -> LDS_KV ==== + coop_load_k(kv_start) + gpu.barrier() + + # ==== Q @ K^T via MFMA -> S[16, BLOCK_N] ==== + s_accs = [c_zero_v4f32 for _ in range_constexpr(N_MFMA)] + for ks in range_constexpr(K_STEPS): + a_pack = q_a_packs[ks] + for nm in range_constexpr(N_MFMA): + k_row = nm * 16 + k_lds_idx = ( + (arith.ArithValue(lane_mod_16) + k_row) * K_STRIDE + + ks * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + b_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx])) + s_accs[nm] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [a_pack, b_pack, s_accs[nm], 0, 0, 0] + ) + ) + + # ==== Online softmax over BLOCK_N positions ==== + # s_vals[nm][ii] where nm in [0..3], ii in [0..3] + s_vals = [[None for _ in range_constexpr(4)] for _ in range_constexpr(N_MFMA)] + for ii in range_constexpr(4): + for nm in range_constexpr(N_MFMA): + s_val = arith.as_value( + vec_ext.extract(s_accs[nm], static_position=[ii], dynamic_position=[]) + ) + s_val = arith.as_value( + flir.arith.MulFOp( + s_val, arith.as_value(c_sm_scale), fastmath=fm_fast + ).result + ) + + if CAUSAL: + q_row_i = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + kv_col = ( + arith.ArithValue(kv_start) + + nm * 16 + + arith.ArithValue(lane_mod_16) + ).value + q_row_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), q_row_i).result + ) + kv_col_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), kv_col).result + ) + is_masked = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64 + ).result + ) + s_val = arith.as_value( + flir.arith.SelectOp( + is_masked, arith.as_value(c_neg_inf), s_val + ).result + ) + s_vals[nm][ii] = s_val + + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + m_new = [None] * 4 + corr = [None] * 4 + p_vals = [[None for _ in range_constexpr(4)] for _ in range_constexpr(N_MFMA)] + l_new = [None] * 4 + + for ii in range_constexpr(4): + row_maxes = [] + for nm in range_constexpr(N_MFMA): + row_max_nm = s_vals[nm][ii] + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp( + row_max_nm, sh_i32, width_i32, mode="xor" + ).shuffleResult + ) + row_max_nm = arith.as_value( + flir.arith.MaximumFOp(row_max_nm, peer).result + ) + row_maxes.append(row_max_nm) + + combined_max = row_maxes[0] + for g in range_constexpr(N_MFMA - 1): + combined_max = arith.as_value( + flir.arith.MaximumFOp(combined_max, row_maxes[g + 1]).result + ) + + m_new[ii] = arith.as_value( + flir.arith.MaximumFOp(m_old[ii], combined_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( + flir.arith.MulFOp( + diff_m, arith.as_value(c_log2e), fastmath=fm_fast + ).result + ) + corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) + + row_sums = [] + for nm in range_constexpr(N_MFMA): + diff = arith.as_value( + flir.arith.SubFOp( + s_vals[nm][ii], m_new[ii], fastmath=fm_fast + ).result + ) + diff_s = arith.as_value( + flir.arith.MulFOp( + diff, arith.as_value(c_log2e), fastmath=fm_fast + ).result + ) + p_vals[nm][ii] = arith.as_value(flir.math.exp2(diff_s, fastmath=fm_fast)) + + row_sum_nm = p_vals[nm][ii] + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp( + row_sum_nm, sh_i32, width_i32, mode="xor" + ).shuffleResult + ) + row_sum_nm = arith.as_value( + flir.arith.AddFOp(row_sum_nm, peer, fastmath=fm_fast).result + ) + row_sums.append(row_sum_nm) + + combined_sum = row_sums[0] + for g in range_constexpr(N_MFMA - 1): + combined_sum = arith.as_value( + flir.arith.AddFOp(combined_sum, row_sums[g + 1], fastmath=fm_fast).result + ) + + l_corr = arith.as_value( + flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result + ) + l_new[ii] = arith.as_value( + flir.arith.AddFOp(l_corr, combined_sum, fastmath=fm_fast).result + ) + + # ==== Rescale O accumulators ==== + corr_vec = arith.as_value( + vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) + ) + for ds in range_constexpr(K_STEPS): + o_accs[ds] = arith.as_value( + flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result + ) + + # ==== P store to LDS_P ==== + for ii in range_constexpr(4): + p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value + for nm in range_constexpr(N_MFMA): + p_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, p_vals[nm][ii]).result + ) + p_lds_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(p_row) * BLOCK_N + + nm * 16 + + arith.ArithValue(lane_mod_16) + ).value + _memref.StoreOp(p_f16, lds_p, [p_lds_idx]) + + # ==== Barrier: ensure all waves done reading K ==== + gpu.barrier() + + # ==== Cooperative V load (transposed) ==== + coop_load_v_transposed(kv_start) + gpu.barrier() + + # ==== P @ V via MFMA ==== + for ds in range_constexpr(K_STEPS): + for nm in range_constexpr(N_MFMA): + p_a_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(lane_mod_16) * BLOCK_N + + nm * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + p_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_p, [p_a_idx]) + ) + + v_idx = ( + (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE + + nm * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + v_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [v_idx]) + ) + o_accs[ds] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [p_pack, v_pack, o_accs[ds], 0, 0, 0] + ) + ) + + # ==== Barrier: ensure all waves done reading V ==== + gpu.barrier() + + yield_args = m_new + l_new + o_accs + scf_yield(yield_args) + + # ---- Normalize and store O ---- + m_finals = [arith.as_value(loop.results[i]) for i in range(4)] + l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] + o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] + + for ds in range_constexpr(K_STEPS): + for ii in range_constexpr(4): + o_val = arith.as_value( + vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) + ) + o_norm = arith.as_value( + flir.arith.DivFOp(o_val, l_finals[ii], fastmath=fm_fast).result + ) + o_f16 = arith.as_value(flir.arith.TruncFOp(elem_type, o_norm).result) + q_row_o = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value + o_global = global_idx(q_row_o, d_col) + _memref.StoreOp(o_f16, O, [o_global]) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + batch_size: lambda: T.index(), + seq_len: lambda: T.index(), + ): + c1 = arith.as_value(flir.arith_ext.index(1)) + c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) + c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) + bs_val = arith.as_value(batch_size) + sl_val = arith.as_value(seq_len) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(sl_val, c_bm).result) + bs_qt = arith.as_value(flir.arith.MulIOp(bs_val, num_q_tiles).result) + grid_x = arith.as_value(flir.arith.MulIOp(bs_qt, c_nh).result) + bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) + flir.gpu_ext.LaunchFuncOp( + [self.GPU_MODULE_NAME, KERNEL_NAME], + grid_size=(grid_x, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, seq_len], + ) + + return _FlashAttentionV4_4() diff --git a/run.sh b/run.sh index c6c032cc..7a99674d 100755 --- a/run.sh +++ b/run.sh @@ -38,8 +38,11 @@ function run_flydsl_op { # python tests/kernels/test_simple_gemm.py --size all --dtype all # python tests/kernels/test_flash_attention_v4_2.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 - python tests/kernels/test_flash_attention_v4_3.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v42 - # python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v43 + # python tests/kernels/test_flash_attention_v4_3.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v42 + python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v43 + + # rocprof -i perf_counters1.txt -o prof_v44_p1.csv python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 5 --warmup 2 + # rocprof -i perf_counters2.txt -o prof_v44_p2.csv python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 5 --warmup 2 } diff --git a/tests/kernels/test_flash_attention_v4_4.py b/tests/kernels/test_flash_attention_v4_4.py new file mode 100644 index 00000000..ceddcace --- /dev/null +++ b/tests/kernels/test_flash_attention_v4_4.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +"""Flash Attention V4.4 kernel test and benchmark for FlyDSL. + +Tests V4.4 against PyTorch SDPA. +Optionally compares with V4.3. +""" + +import sys +import argparse +import hashlib +import random +from pathlib import Path +import logging + +# Configure logging to show INFO level messages (required for kernel name display) +logging.basicConfig(level=logging.INFO) + +_repo = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(_repo)) + +try: + import torch + import torch.nn.functional as F + import numpy as np +except ImportError: + print("PyTorch not available") + sys.exit(1) + +if not torch.cuda.is_available(): + print("CUDA/ROCm not available") + sys.exit(1) + +import flydsl +from kernels.flash_attention_v4_4 import build_flash_attention_v4_4_module, KERNEL_NAME +from tests.test_common import run_perftest + +# Tensor initialization range (uniform distribution) +UNIFORM_RANGE = (-1, 1) +DEFAULT_SEED = 123 +V4_4_COMPILE_KWARGS = { + "unsafe_fp_math": True, + "fast_fp_math": True, + "waves_per_eu": 4, + "flat_work_group_size": 256, +} + + +def setup_seed(seed: int) -> None: + """Set random seed for reproducibility across all RNG sources.""" + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + + +def pytorch_ref_attention(q, k, v, causal=True): + q_t = q.transpose(1, 2).float() + k_t = k.transpose(1, 2).float() + v_t = v.transpose(1, 2).float() + out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal) + return out.transpose(1, 2) + + +def compute_md5(tensor: torch.Tensor) -> str: + """Compute MD5 hash of a tensor's raw bytes.""" + return hashlib.md5( + tensor.contiguous().view(torch.uint8).detach().cpu().numpy().tobytes() + ).hexdigest() + + +def compare_arrays( + arr1: np.ndarray, + arr2: np.ndarray, + k: int = 5, + thresholds: list = None, +) -> dict: + """Compare two numpy arrays and compute various difference metrics. + + Args: + arr1: First input array (result), will be cast to float32. + arr2: Second input array (reference), will be cast to float32. + k: Number of top differences to report. + thresholds: Difference magnitude buckets for histogram. + + Returns: + Dictionary with top_k_diff, threshold_stats, nan_info, max_diff, max_diff_thr. + """ + if thresholds is None: + thresholds = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1] + + if arr1.shape != arr2.shape: + raise ValueError(f"Shape mismatch: arr1 {arr1.shape} vs arr2 {arr2.shape}") + + arr1 = arr1.astype(np.float32) + arr2 = arr2.astype(np.float32) + + result = {"top_k_diff": [], "threshold_stats": [], "nan_info": {}} + + # Check for NaN values + nan_mask1 = np.isnan(arr1) + nan_mask2 = np.isnan(arr2) + if np.any(nan_mask1): + result["nan_info"]["arr1_nan_count"] = int(np.sum(nan_mask1)) + print(f" Warning: result contains {result['nan_info']['arr1_nan_count']} NaN values") + if np.any(nan_mask2): + result["nan_info"]["arr2_nan_count"] = int(np.sum(nan_mask2)) + print(f" Warning: reference contains {result['nan_info']['arr2_nan_count']} NaN values") + + # Compute absolute differences + diff = np.abs(arr1 - arr2) + total_elements = arr1.size + + max_diff_thr = (diff / (1.0 + np.abs(arr2))).max() + result["max_diff"] = float(diff.max()) + result["max_diff_thr"] = float(max_diff_thr) + + print(f" diff.abs.max = {diff.max():.6f}") + print(f" diff.abs.mean = {diff.mean():.6f}") + print(f" max_diff_thr (rel) = {max_diff_thr:.6e}") + + # Find top k differences + flat_diff = diff.flatten() + actual_k = min(k, len(flat_diff)) + top_k_indices = np.argpartition(flat_diff, -actual_k)[-actual_k:] + top_k_indices = top_k_indices[np.argsort(-flat_diff[top_k_indices])] + + orig_indices = np.unravel_index(top_k_indices, diff.shape) + print(f" Top-{actual_k} differences:") + for i in range(actual_k): + idx = tuple(dim[i] for dim in orig_indices) + entry = { + "value": float(diff[idx]), + "position": idx, + "arr1_value": float(arr1[idx]), + "arr2_value": float(arr2[idx]), + } + result["top_k_diff"].append(entry) + print(f" [{idx}] result={arr1[idx]:.6f}, ref={arr2[idx]:.6f}, diff={diff[idx]:.6f}") + + # Compute threshold statistics + print(f" Threshold distribution ({total_elements} elements):") + for i in range(len(thresholds) - 1): + lower, upper = thresholds[i], thresholds[i + 1] + count = int(np.sum((diff >= lower) & (diff < upper))) + pct = 100.0 * count / total_elements + result["threshold_stats"].append( + {"range": f"[{lower:.0e}, {upper:.0e})", "count": count, "percentage": pct} + ) + print(f" [{lower:.0e}, {upper:.0e}): {count:>8d} ({pct:6.2f}%)") + + count = int(np.sum(diff >= thresholds[-1])) + pct = 100.0 * count / total_elements + result["threshold_stats"].append( + {"range": f">={thresholds[-1]:.0e}", "count": count, "percentage": pct} + ) + print(f" >={thresholds[-1]:.0e} : {count:>8d} ({pct:6.2f}%)") + + return result + + +def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, warmup, iters, prev_exe=None, seed=DEFAULT_SEED): + device = "cuda" + results = {} + + if seq_len % 64 != 0: + results["err"] = f"seq_len ({seq_len}) must be divisible by 64 for V4.4" + return results + if head_dim % 16 != 0 or head_dim < 64: + results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" + return results + + try: + m = build_flash_attention_v4_4_module( + num_heads=num_heads, head_dim=head_dim, causal=causal, dtype_str="f16" + ) + exe = flydsl.compile(m, **V4_4_COMPILE_KWARGS) + except Exception as e: + results["err"] = f"compile: {e}" + import traceback + + traceback.print_exc() + return results + + B, S, H, D = batch, seq_len, num_heads, head_dim + setup_seed(seed) + q_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) + k_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) + v_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) + + q_flat = q_4d.contiguous().view(-1) + k_flat = k_4d.contiguous().view(-1) + v_flat = v_4d.contiguous().view(-1) + o_flat = torch.zeros_like(q_flat) + + try: + exe(q_flat, k_flat, v_flat, o_flat, B, S) + torch.cuda.synchronize() + except Exception as e: + results["err"] = f"exec: {e}" + import traceback + + traceback.print_exc() + return results + + ref_4d = pytorch_ref_attention(q_4d.float(), k_4d.float(), v_4d.float(), causal=causal).to(dtype) + ref_flat = ref_4d.contiguous().view(-1) + + o_f32 = o_flat.float() + ref_f32 = ref_flat.float() + max_err = (o_f32 - ref_f32).abs().max().item() + mean_err = (o_f32 - ref_f32).abs().mean().item() + cos_sim = F.cosine_similarity(o_f32.view(-1, D), ref_f32.view(-1, D), dim=1) + min_cos = cos_sim.min().item() + results["max_err"] = max_err + results["mean_err"] = mean_err + results["min_cos"] = min_cos + results["passed"] = max_err < 1e-2 and min_cos > 0.99 + + # Compute and print MD5 hashes + tag = f"B={B} S={S} H={H} D={D}" + result_md5 = compute_md5(o_flat) + ref_md5 = compute_md5(ref_flat) + print(f" [{tag}] result_md5 = {result_md5}") + print(f" [{tag}] ref_md5 = {ref_md5}") + if result_md5 == ref_md5: + print(f" [{tag}] MD5 match: EXACT (bit-identical)") + else: + print(f" [{tag}] MD5 match: DIFFER (not bit-identical)") + + print(f" [{tag}] --- compare_arrays ---") + compare_arrays( + o_flat.to(torch.float32).detach().cpu().numpy(), + ref_flat.to(torch.float32).detach().cpu().numpy(), + ) + + try: + def kernel_fn(): + exe(q_flat, k_flat, v_flat, o_flat, B, S) + + _, us = run_perftest(kernel_fn, num_iters=iters, num_warmup=warmup) + s_eff = S / 2.0 if causal else float(S) + flops = 4.0 * S * s_eff * D * H * B + tflops = flops / (us * 1e-6) / 1e12 + results["us"] = us + results["tflops"] = tflops + except Exception as e: + results["bench_err"] = str(e) + + if prev_exe is not None: + try: + o_prev = torch.zeros_like(q_flat) + + def prev_fn(): + prev_exe(q_flat, k_flat, v_flat, o_prev, B, S) + + _, prev_us = run_perftest(prev_fn, num_iters=iters, num_warmup=warmup) + prev_tflops = flops / (prev_us * 1e-6) / 1e12 + results["prev_us"] = prev_us + results["prev_tflops"] = prev_tflops + except Exception as e: + results["prev_bench_err"] = str(e) + + return results + + +def main(): + parser = argparse.ArgumentParser(description="Flash Attention V4.4 FlyDSL Test/Benchmark") + parser.add_argument("--batch", type=int, default=None) + parser.add_argument("--seq_len", type=int, default=None) + parser.add_argument("--num_heads", type=int, default=None) + parser.add_argument("--head_dim", type=int, default=None) + parser.add_argument("--no-causal", action="store_true") + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--compare-v43", action="store_true", help="Also benchmark V4.3 for comparison") + parser.add_argument( + "--seed", type=int, default=DEFAULT_SEED, help=f"Random seed for reproducibility (default: {DEFAULT_SEED})" + ) + args = parser.parse_args() + + causal = not args.no_causal + dtype = torch.float16 + + print("=" * 130) + print(f"FlyDSL Flash Attention V4.4 ({'causal' if causal else 'non-causal'}, fp16)") + print(" Tile: BLOCK_M=64, BLOCK_N=32, 4 waves (256 threads), mfma_f32_16x16x16f16") + print(" Strategy: Q-load-once + KV streaming + online softmax over 32 positions") + print(f"GPU: {torch.cuda.get_device_name(0)}") + print(f" Compile opts: {V4_4_COMPILE_KWARGS}") + print("=" * 130) + + if args.seq_len or args.head_dim or args.batch: + configs = [(args.batch or 1, args.seq_len or 128, args.num_heads or 8, args.head_dim or 128)] + else: + configs = [ + (1, 128, 8, 128), + (1, 256, 32, 128), + (1, 512, 32, 128), + (2, 128, 8, 128), + ] + + prev_exes = {} + if args.compare_v43: + from kernels.flash_attention_v4_3 import build_flash_attention_v4_3_module + + for _, _, nh, hd in configs: + key = (nh, hd) + if key not in prev_exes: + try: + m = build_flash_attention_v4_3_module( + num_heads=nh, head_dim=hd, causal=causal, dtype_str="f16" + ) + prev_exes[key] = flydsl.compile(m) + except Exception: + prev_exes[key] = None + + if args.compare_v43: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'V4.4(us)':>10s} {'V4.4 TF':>9s} | " + f"{'V4.3(us)':>10s} {'V4.3 TF':>9s} | {'Speedup':>7s}" + ) + else: + hdr = ( + f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" + ) + print(f"\n{hdr}") + print("-" * len(hdr)) + + all_passed = True + for batch, seq_len, nh, hd in configs: + tag = f"B={batch} S={seq_len} H={nh} D={hd}" + try: + prev_exe = prev_exes.get((nh, hd)) if args.compare_v43 else None + r = run_config( + batch, + seq_len, + nh, + hd, + dtype, + causal, + warmup=args.warmup, + iters=args.iters, + prev_exe=prev_exe, + seed=args.seed, + ) + if "err" in r: + print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") + all_passed = False + continue + + status = "PASS" if r["passed"] else "FAIL" + if not r["passed"]: + all_passed = False + + us_s = f"{r['us']:>10.1f}" if "us" in r else " N/A" + tf_s = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" + + if args.compare_v43 and "prev_us" in r: + p_us = f"{r['prev_us']:>10.1f}" + p_tf = f"{r['prev_tflops']:>9.3f}" + speedup = r["prev_us"] / r["us"] if r.get("us") else 0 + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{us_s} {tf_s} | {p_us} {p_tf} | {speedup:>6.2f}x" + ) + else: + print( + f"{tag:>38s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{us_s} {tf_s}" + ) + except Exception as e: + print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") + all_passed = False + + print("=" * 130) + if all_passed: + print("All tests PASSED") + else: + print("Some tests FAILED") + sys.exit(1) + + +if __name__ == "__main__": + main() From 7705b5076f7455c450c28054b788f804d68ee87d Mon Sep 17 00:00:00 2001 From: yanguahe Date: Fri, 13 Feb 2026 13:42:06 +0800 Subject: [PATCH 09/17] Optimize flash_attention_v4_4 with MFMA32 register pipeline - Switched the **active** `v4_4` kernel path to a true **MFMA32** pipeline (`mfma_f32_32x32x8f16`) with `BLOCK_M=128`, `BLOCK_N=32`, `NUM_WAVES=4`. - Remapped compute flow to **`K @ Q^T -> online softmax -> V^T @ P`**. - Kept intermediate **S/P in registers** (removed the previous `P -> LDS -> VGPR` roundtrip). - Split LDS staging for K and `V^T` into separate regions and removed an inner-loop barrier to cut synchronization overhead. - Updated test constraints and compile options in `test_flash_attention_v4_4.py` (`seq_len % 128`, `head_dim % 32`, `waves_per_eu=3`). - Final measured result at target shape: **12350.8 us/iter**, with accuracy preserved (`diff.abs.max=4.88e-4`, `max_diff_thr=3.255208e-04`), about **2.17x faster** than the previous 26751.5 us. --- kernels/flash_attention_v4_4.py | 1118 +++++++++++++++++++- tests/kernels/test_flash_attention_v4_4.py | 14 +- 2 files changed, 1113 insertions(+), 19 deletions(-) diff --git a/kernels/flash_attention_v4_4.py b/kernels/flash_attention_v4_4.py index 5298682a..15313315 100644 --- a/kernels/flash_attention_v4_4.py +++ b/kernels/flash_attention_v4_4.py @@ -1,5 +1,1098 @@ """Flash Attention V4.4 kernel builder for FlyDSL. +Aggressive V4.4 path: +- True MFMA32 remap: `mfma_f32_32x32x8f16` for both GEMM stages. +- Tile shape: BLOCK_M=128, BLOCK_N=32, 4 waves (256 threads). +- Per-wave Q rows: 32. +- GEMM1 uses `K @ Q^T` so S/P live in MFMA32 register layout. +- Online softmax over KV dimension is done in registers. +- P is kept in registers and fed directly to GEMM2 (`V^T @ P`) without LDS roundtrip. +- K uses LDS ping-pong prefetch between adjacent iterations. + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). +Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. +Block: (256,) -- 4 waves of 64 on AMD (wave64). + +Requires: head_dim % 32 == 0, head_dim >= 64, seq_len % 128 == 0. +""" + +import math + +from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl +from flydsl.dialects.ext import vector as vec_ext +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext.scf import yield_ as scf_yield +from _mlir.dialects import memref as _memref +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator +from _mlir import ir +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_attention_v4_4_kernel" + + +def build_flash_attention_v4_4_module_primary( + num_heads, + head_dim, + causal=True, + dtype_str="f16", + sm_scale=None, +): + """Build a FlyDSL Flash Attention V4.4 module.""" + gpu_arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + + # Aggressive MFMA32 configuration for target B=1, H=64, S=8192, D=128. + BLOCK_M = 128 + BLOCK_N = 32 + NUM_WAVES = 4 + WARP_SIZE = 64 + BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 + ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 32 + + # MFMA32 K-dimension is 8. + K_STEP_QK = 8 + K_STEPS_QK = head_dim // K_STEP_QK + # PV stage computes 32 output columns per accumulator chunk. + D_CHUNK = 32 + D_CHUNKS = head_dim // D_CHUNK + PV_K_STEP = 8 + PV_K_STEPS = BLOCK_N // PV_K_STEP # 4 for BN=32 + + assert BLOCK_M % NUM_WAVES == 0 + assert head_dim % 32 == 0, f"head_dim ({head_dim}) must be divisible by 32" + assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" + assert dtype_str == "f16", "V4.4 currently only supports f16" + assert BLOCK_N % 32 == 0 + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + # Bank-conflict-friendly LDS strides. + K_STRIDE = HEAD_DIM + 2 + VT_STRIDE = BLOCK_N + 2 + + # Vectorized cooperative load constants. + VEC_WIDTH = 8 + THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH + assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 + ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD + + if ROWS_PER_BATCH_LOAD >= BLOCK_N: + NUM_BATCHES_KV = 1 + KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N + else: + assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD + KV_NEEDS_GUARD = False + + # Separate K and V^T regions to remove one intra-iteration barrier. + LDS_K_SIZE = BLOCK_N * K_STRIDE + LDS_VT_SIZE = HEAD_DIM * VT_STRIDE + LDS_VT_BASE = LDS_K_SIZE + LDS_KV_TOTAL_SIZE = LDS_K_SIZE + LDS_VT_SIZE + + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + class _FlashAttentionV4_4(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_v4_4_{dtype_str}" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + elem_type = T.f16() + _state["elem_type"] = elem_type + _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_TOTAL_SIZE) + allocator.finalize() + + @flir.kernel + def flash_attention_v4_4_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + seq_len: lambda: T.index(), + ): + compute_type = T.f32() + elem_type = _state["elem_type"] + fm_fast = flir.arith.FastMathFlags.fast + + v4f16_type = ir.VectorType.get([4], elem_type) + v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) + v16f32_type = ir.VectorType.get([16], compute_type) + + seq_len_v = arith.as_value(seq_len) + + # ---- LDS view ---- + base_ptr = allocator.get_base() + lds_kv = _state["lds_kv"](base_ptr).get() + + # ---- Thread / block indices ---- + block_id = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + # ---- Wave decomposition ---- + c_ws = flir.const_index(WARP_SIZE) + wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) + lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) + + c32 = flir.const_index(32) + lane_mod_32 = arith.as_value(flir.arith.RemUIOp(lane, c32).result) + lane_div_32 = arith.as_value(flir.arith.DivUIOp(lane, c32).result) # 0/1 + + # ---- Wave offsets ---- + wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value + + # ---- Decompose block_id ---- + c_nh = flir.const_index(NUM_HEADS) + head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) + temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) + c_bm = flir.const_index(BLOCK_M) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) + q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) + batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) + q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value + + # ---- Cooperative load decomposition ---- + c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) + load_row_in_batch = arith.as_value(flir.arith.DivUIOp(tid, c_tpr).result) + load_lane_in_row = arith.as_value(flir.arith.RemUIOp(tid, c_tpr).result) + load_col_base = (arith.ArithValue(load_lane_in_row) * VEC_WIDTH).value + + # ---- Helper: global flat index ---- + def global_idx(token_idx, col): + token = ( + arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) + + arith.ArithValue(token_idx) + ) + return ( + token * STRIDE_TOKEN + + arith.ArithValue(head_idx) * HEAD_DIM + + arith.ArithValue(col) + ).value + + # ---- Cooperative K load (row-major, padded stride) ---- + def coop_load_k(tile_start): + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + lds_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + lds_idx = ( + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + + # ---- Cooperative V load (transposed, padded stride) ---- + def coop_load_v_transposed(tile_start): + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + load_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + LDS_VT_BASE + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + load_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + LDS_VT_BASE + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + + # ---- Preload Q^T B-operand packs once (register-resident) ---- + # B operand uses j = lane_mod_32, k-subblock = lane_div_32*4. + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_mod_32) + ).value + q_b_packs = [] + for ks in range_constexpr(K_STEPS_QK): + q_col = ( + flir.const_index(ks * K_STEP_QK) + + arith.ArithValue(lane_div_32) * 4 + ).value + g_idx = global_idx(q_row, q_col) + q_b_packs.append(arith.as_value(vec_ext.load_op(v4f16_type, Q, [g_idx]))) + + # ---- Constants ---- + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_one_f = arith.constant(1.0, type=compute_type) + c_sm_scale = arith.constant(sm_scale, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + c_zero_v16f32 = arith.as_value(arith.constant_vector(0.0, v16f32_type)) + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + shuf_32_i32 = arith.as_value(arith.constant(32, type=T.i32())) + + # ---- KV loop upper bound ---- + if CAUSAL: + kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value + else: + kv_upper = seq_len_v + + # Loop-carried: [m_old, l_old, o_acc_chunks...] + init_args = [arith.as_value(c_neg_inf), arith.as_value(c_zero_f)] + for _ in range_constexpr(D_CHUNKS): + init_args.append(c_zero_v16f32) + + with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: + kv_start = arith.as_value(loop.induction_variable) + m_old = arith.as_value(loop.inner_iter_args[0]) + l_old = arith.as_value(loop.inner_iter_args[1]) + o_accs = [arith.as_value(loop.inner_iter_args[2 + i]) for i in range_constexpr(D_CHUNKS)] + + # ==== Cooperative K load -> LDS_KV ==== + coop_load_k(kv_start) + gpu.barrier() + + # ==== GEMM1: S = K @ Q^T (MFMA32), S in v16f32 ==== + s_acc = c_zero_v16f32 + for ks in range_constexpr(K_STEPS_QK): + k_idx = ( + arith.ArithValue(lane_mod_32) * K_STRIDE + + ks * K_STEP_QK + + arith.ArithValue(lane_div_32) * 4 + ).value + k_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [k_idx])) + q_pack = q_b_packs[ks] + s_acc = arith.as_value( + rocdl.mfma_f32_32x32x8f16( + v16f32_type, [k_pack, q_pack, s_acc, 0, 0, 0] + ) + ) + + # ==== Online softmax over KV dimension (register only) ==== + q_row_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), q_row).result + ) + + s_vals = [] + for r in range_constexpr(16): + s_val = arith.as_value( + vec_ext.extract(s_acc, static_position=[r], dynamic_position=[]) + ) + s_val = arith.as_value( + flir.arith.MulFOp( + s_val, arith.as_value(c_sm_scale), fastmath=fm_fast + ).result + ) + + if CAUSAL: + kv_row_rel = ( + arith.ArithValue(lane_div_32) * 4 + + (r // 4) * 8 + + (r % 4) + ).value + kv_col = (arith.ArithValue(kv_start) + arith.ArithValue(kv_row_rel)).value + kv_col_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), kv_col).result + ) + is_masked = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64 + ).result + ) + s_val = arith.as_value( + flir.arith.SelectOp( + is_masked, arith.as_value(c_neg_inf), s_val + ).result + ) + s_vals.append(s_val) + + local_max = s_vals[0] + for r in range_constexpr(15): + local_max = arith.as_value( + flir.arith.MaximumFOp(local_max, s_vals[r + 1]).result + ) + peer_max = arith.as_value( + gpu.ShuffleOp(local_max, shuf_32_i32, width_i32, mode="xor").shuffleResult + ) + row_max = arith.as_value( + flir.arith.MaximumFOp(local_max, peer_max).result + ) + m_new = arith.as_value( + flir.arith.MaximumFOp(m_old, row_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_old, m_new, fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( + flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + corr = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) + + p_vals = [] + local_sum = arith.as_value(c_zero_f) + for r in range_constexpr(16): + diff = arith.as_value( + flir.arith.SubFOp(s_vals[r], m_new, fastmath=fm_fast).result + ) + diff_s = arith.as_value( + flir.arith.MulFOp(diff, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + p = arith.as_value(flir.math.exp2(diff_s, fastmath=fm_fast)) + p_vals.append(p) + local_sum = arith.as_value( + flir.arith.AddFOp(local_sum, p, fastmath=fm_fast).result + ) + + peer_sum = arith.as_value( + gpu.ShuffleOp(local_sum, shuf_32_i32, width_i32, mode="xor").shuffleResult + ) + tile_sum = arith.as_value( + flir.arith.AddFOp(local_sum, peer_sum, fastmath=fm_fast).result + ) + l_corr = arith.as_value( + flir.arith.MulFOp(corr, l_old, fastmath=fm_fast).result + ) + l_new = arith.as_value( + flir.arith.AddFOp(l_corr, tile_sum, fastmath=fm_fast).result + ) + + # ==== Rescale O accumulators ==== + corr_vec = arith.as_value(vec_ext.broadcast(v16f32_type, corr)) + for dc in range_constexpr(D_CHUNKS): + o_accs[dc] = arith.as_value( + flir.arith.MulFOp(o_accs[dc], corr_vec, fastmath=fm_fast).result + ) + + # ==== Load V^T for current tile into LDS_KV ==== + coop_load_v_transposed(kv_start) + gpu.barrier() + + # ==== Build P packs in MFMA32 B-input format from register S ==== + p_f16 = [] + for r in range_constexpr(16): + p_f16.append( + arith.as_value(flir.arith.TruncFOp(elem_type, p_vals[r]).result) + ) + p_packs = [] + for pks in range_constexpr(PV_K_STEPS): + p_base = pks * 4 + p_packs.append( + arith.as_value( + vec_ext.from_elements( + v4f16_type, + [ + p_f16[p_base + 0], + p_f16[p_base + 1], + p_f16[p_base + 2], + p_f16[p_base + 3], + ], + ) + ) + ) + + # ==== GEMM2: O^T += V^T @ P (MFMA32) ==== + for dc in range_constexpr(D_CHUNKS): + for pks in range_constexpr(PV_K_STEPS): + v_idx = ( + LDS_VT_BASE + + (dc * D_CHUNK + arith.ArithValue(lane_mod_32)) * VT_STRIDE + + pks * PV_K_STEP + + arith.ArithValue(lane_div_32) * 4 + ).value + v_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [v_idx])) + o_accs[dc] = arith.as_value( + rocdl.mfma_f32_32x32x8f16( + v16f32_type, [v_pack, p_packs[pks], o_accs[dc], 0, 0, 0] + ) + ) + + yield_args = [m_new, l_new] + o_accs + scf_yield(yield_args) + + # ---- Normalize and store O ---- + l_final = arith.as_value(loop.results[1]) + o_finals = [arith.as_value(loop.results[2 + dc]) for dc in range_constexpr(D_CHUNKS)] + + inv_l = arith.as_value( + flir.arith.DivFOp(arith.as_value(c_one_f), l_final, fastmath=fm_fast).result + ) + inv_l_vec = arith.as_value(vec_ext.broadcast(v16f32_type, inv_l)) + + for dc in range_constexpr(D_CHUNKS): + o_norm_vec = arith.as_value( + flir.arith.MulFOp(o_finals[dc], inv_l_vec, fastmath=fm_fast).result + ) + for r in range_constexpr(16): + o_val = arith.as_value( + vec_ext.extract(o_norm_vec, static_position=[r], dynamic_position=[]) + ) + o_f16 = arith.as_value(flir.arith.TruncFOp(elem_type, o_val).result) + + d_row_rel = ( + arith.ArithValue(lane_div_32) * 4 + + (r // 4) * 8 + + (r % 4) + ).value + d_col = (flir.const_index(dc * D_CHUNK) + arith.ArithValue(d_row_rel)).value + o_global = global_idx(q_row, d_col) + _memref.StoreOp(o_f16, O, [o_global]) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + batch_size: lambda: T.index(), + seq_len: lambda: T.index(), + ): + c1 = arith.as_value(flir.arith_ext.index(1)) + c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) + c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) + bs_val = arith.as_value(batch_size) + sl_val = arith.as_value(seq_len) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(sl_val, c_bm).result) + bs_qt = arith.as_value(flir.arith.MulIOp(bs_val, num_q_tiles).result) + grid_x = arith.as_value(flir.arith.MulIOp(bs_qt, c_nh).result) + bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) + flir.gpu_ext.LaunchFuncOp( + [self.GPU_MODULE_NAME, KERNEL_NAME], + grid_size=(grid_x, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, seq_len], + ) + + return _FlashAttentionV4_4() + + +build_flash_attention_v4_4_module = build_flash_attention_v4_4_module_primary +"""Flash Attention V4.4 kernel builder for FlyDSL. + +Aggressive V4.4 path: +- True MFMA32 remap: `mfma_f32_32x32x8f16` for both GEMM stages. +- Tile shape: BLOCK_M=128, BLOCK_N=32, 4 waves (256 threads). +- Per-wave Q rows: 32. +- GEMM1 uses `K @ Q^T` so S/P live in MFMA32 register layout. +- Online softmax over KV dimension is done in registers. +- P is kept in registers and fed directly to GEMM2 (`V^T @ P`) without LDS roundtrip. +- K uses LDS ping-pong prefetch between adjacent iterations. + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). +Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. +Block: (256,) -- 4 waves of 64 on AMD (wave64). + +Requires: head_dim % 32 == 0, head_dim >= 64, seq_len % 128 == 0. +""" + +import math + +from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl +from flydsl.dialects.ext import vector as vec_ext +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext.scf import yield_ as scf_yield +from _mlir.dialects import memref as _memref +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator +from _mlir import ir +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_attention_v4_4_kernel" + + +def _legacy_copy_build_flash_attention_v4_4_module_2( + num_heads, + head_dim, + causal=True, + dtype_str="f16", + sm_scale=None, +): + """Build a FlyDSL Flash Attention V4.4 module.""" + gpu_arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + + # Aggressive MFMA32 configuration for target B=1, H=64, S=8192, D=128. + BLOCK_M = 256 + BLOCK_N = 32 + NUM_WAVES = 8 + WARP_SIZE = 64 + BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 + ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 32 + + # MFMA32 K-dimension is 8. + K_STEP_QK = 8 + K_STEPS_QK = head_dim // K_STEP_QK + # PV stage computes 32 output columns per accumulator chunk. + D_CHUNK = 32 + D_CHUNKS = head_dim // D_CHUNK + PV_K_STEP = 8 + PV_K_STEPS = BLOCK_N // PV_K_STEP # 4 for BN=32 + + assert BLOCK_M % NUM_WAVES == 0 + assert head_dim % 32 == 0, f"head_dim ({head_dim}) must be divisible by 32" + assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" + assert dtype_str == "f16", "V4.4 currently only supports f16" + assert BLOCK_N % 32 == 0 + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + # Bank-conflict-friendly LDS strides. + K_STRIDE = HEAD_DIM + 2 + VT_STRIDE = BLOCK_N + 2 + + # Vectorized cooperative load constants. + VEC_WIDTH = 8 + THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH + assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 + ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD + + if ROWS_PER_BATCH_LOAD >= BLOCK_N: + NUM_BATCHES_KV = 1 + KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N + else: + assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD + KV_NEEDS_GUARD = False + + # Two KV buffers for K ping-pong prefetch. + LDS_KV_BUF_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE) + LDS_KV_TOTAL_SIZE = 2 * LDS_KV_BUF_SIZE + + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + class _FlashAttentionV4_4(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_v4_4_{dtype_str}" + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + elem_type = T.f16() + _state["elem_type"] = elem_type + _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_TOTAL_SIZE) + allocator.finalize() + + @flir.kernel + def flash_attention_v4_4_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + seq_len: lambda: T.index(), + ): + compute_type = T.f32() + elem_type = _state["elem_type"] + fm_fast = flir.arith.FastMathFlags.fast + + v4f16_type = ir.VectorType.get([4], elem_type) + v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) + v16f32_type = ir.VectorType.get([16], compute_type) + + seq_len_v = arith.as_value(seq_len) + + # ---- LDS view ---- + base_ptr = allocator.get_base() + lds_kv = _state["lds_kv"](base_ptr).get() + + # ---- Thread / block indices ---- + block_id = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + # ---- Wave decomposition ---- + c_ws = flir.const_index(WARP_SIZE) + wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) + lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) + + c32 = flir.const_index(32) + lane_mod_32 = arith.as_value(flir.arith.RemUIOp(lane, c32).result) + lane_div_32 = arith.as_value(flir.arith.DivUIOp(lane, c32).result) # 0/1 + + # ---- Wave offsets ---- + wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value + + # ---- Decompose block_id ---- + c_nh = flir.const_index(NUM_HEADS) + head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) + temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) + c_bm = flir.const_index(BLOCK_M) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) + q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) + batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) + q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value + + # ---- Cooperative load decomposition ---- + c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) + load_row_in_batch = arith.as_value(flir.arith.DivUIOp(tid, c_tpr).result) + load_lane_in_row = arith.as_value(flir.arith.RemUIOp(tid, c_tpr).result) + load_col_base = (arith.ArithValue(load_lane_in_row) * VEC_WIDTH).value + + # ---- Helper: global flat index ---- + def global_idx(token_idx, col): + token = ( + arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) + + arith.ArithValue(token_idx) + ) + return ( + token * STRIDE_TOKEN + + arith.ArithValue(head_idx) * HEAD_DIM + + arith.ArithValue(col) + ).value + + def kv_buf_base(buf_idx): + return (arith.ArithValue(buf_idx) * LDS_KV_BUF_SIZE).value + + # ---- Cooperative K load (row-major, padded stride) ---- + def coop_load_k(tile_start, buf_idx): + base = kv_buf_base(buf_idx) + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(base) + + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + lds_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + lds_idx = ( + arith.ArithValue(base) + + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + + # ---- Cooperative V load (transposed, padded stride) ---- + def coop_load_v_transposed(tile_start, buf_idx): + base = kv_buf_base(buf_idx) + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + load_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(base) + + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + load_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(base) + + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + + # ---- Preload Q^T B-operand packs once (register-resident) ---- + # B operand uses j = lane_mod_32, k-subblock = lane_div_32*4. + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_mod_32) + ).value + q_b_packs = [] + for ks in range_constexpr(K_STEPS_QK): + q_col = ( + flir.const_index(ks * K_STEP_QK) + + arith.ArithValue(lane_div_32) * 4 + ).value + g_idx = global_idx(q_row, q_col) + q_b_packs.append(arith.as_value(vec_ext.load_op(v4f16_type, Q, [g_idx]))) + + # ---- Constants ---- + c0_idx = flir.const_index(0) + c1_idx = flir.const_index(1) + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_one_f = arith.constant(1.0, type=compute_type) + c_sm_scale = arith.constant(sm_scale, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + c_zero_v16f32 = arith.as_value(arith.constant_vector(0.0, v16f32_type)) + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + shuf_32_i32 = arith.as_value(arith.constant(32, type=T.i32())) + + # ---- KV loop upper bound ---- + if CAUSAL: + kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value + else: + kv_upper = seq_len_v + + # ---- K ping-pong preload: K(0) -> buf0 ---- + coop_load_k(c0_idx, c0_idx) + gpu.barrier() + + # Loop-carried: [cur_k_buf, m_old, l_old, o_acc_chunks...] + init_args = [arith.as_value(c0_idx), arith.as_value(c_neg_inf), arith.as_value(c_zero_f)] + for _ in range_constexpr(D_CHUNKS): + init_args.append(c_zero_v16f32) + + with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: + kv_start = arith.as_value(loop.induction_variable) + cur_k_buf = arith.as_value(loop.inner_iter_args[0]) + m_old = arith.as_value(loop.inner_iter_args[1]) + l_old = arith.as_value(loop.inner_iter_args[2]) + o_accs = [arith.as_value(loop.inner_iter_args[3 + i]) for i in range_constexpr(D_CHUNKS)] + + next_k_buf = arith.as_value( + flir.arith.SubIOp(c1_idx, arith.ArithValue(cur_k_buf).value).result + ) + cur_base = kv_buf_base(cur_k_buf) + + # ==== GEMM1: S = K @ Q^T (MFMA32), S in v16f32 ==== + s_acc = c_zero_v16f32 + for ks in range_constexpr(K_STEPS_QK): + k_idx = ( + arith.ArithValue(cur_base) + + arith.ArithValue(lane_mod_32) * K_STRIDE + + ks * K_STEP_QK + + arith.ArithValue(lane_div_32) * 4 + ).value + k_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [k_idx])) + q_pack = q_b_packs[ks] + s_acc = arith.as_value( + rocdl.mfma_f32_32x32x8f16( + v16f32_type, [k_pack, q_pack, s_acc, 0, 0, 0] + ) + ) + + # ==== Online softmax over KV dimension (register only) ==== + q_row_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), q_row).result + ) + + s_vals = [] + for r in range_constexpr(16): + s_val = arith.as_value( + vec_ext.extract(s_acc, static_position=[r], dynamic_position=[]) + ) + s_val = arith.as_value( + flir.arith.MulFOp( + s_val, arith.as_value(c_sm_scale), fastmath=fm_fast + ).result + ) + + if CAUSAL: + kv_row_rel = ( + arith.ArithValue(lane_div_32) * 4 + + (r // 4) * 8 + + (r % 4) + ).value + kv_col = (arith.ArithValue(kv_start) + arith.ArithValue(kv_row_rel)).value + kv_col_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), kv_col).result + ) + is_masked = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64 + ).result + ) + s_val = arith.as_value( + flir.arith.SelectOp( + is_masked, arith.as_value(c_neg_inf), s_val + ).result + ) + s_vals.append(s_val) + + local_max = s_vals[0] + for r in range_constexpr(15): + local_max = arith.as_value( + flir.arith.MaximumFOp(local_max, s_vals[r + 1]).result + ) + peer_max = arith.as_value( + gpu.ShuffleOp(local_max, shuf_32_i32, width_i32, mode="xor").shuffleResult + ) + row_max = arith.as_value( + flir.arith.MaximumFOp(local_max, peer_max).result + ) + m_new = arith.as_value( + flir.arith.MaximumFOp(m_old, row_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_old, m_new, fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( + flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + corr = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) + + p_vals = [] + local_sum = arith.as_value(c_zero_f) + for r in range_constexpr(16): + diff = arith.as_value( + flir.arith.SubFOp(s_vals[r], m_new, fastmath=fm_fast).result + ) + diff_s = arith.as_value( + flir.arith.MulFOp(diff, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + p = arith.as_value(flir.math.exp2(diff_s, fastmath=fm_fast)) + p_vals.append(p) + local_sum = arith.as_value( + flir.arith.AddFOp(local_sum, p, fastmath=fm_fast).result + ) + + peer_sum = arith.as_value( + gpu.ShuffleOp(local_sum, shuf_32_i32, width_i32, mode="xor").shuffleResult + ) + tile_sum = arith.as_value( + flir.arith.AddFOp(local_sum, peer_sum, fastmath=fm_fast).result + ) + l_corr = arith.as_value( + flir.arith.MulFOp(corr, l_old, fastmath=fm_fast).result + ) + l_new = arith.as_value( + flir.arith.AddFOp(l_corr, tile_sum, fastmath=fm_fast).result + ) + + # ==== Rescale O accumulators ==== + corr_vec = arith.as_value(vec_ext.broadcast(v16f32_type, corr)) + for dc in range_constexpr(D_CHUNKS): + o_accs[dc] = arith.as_value( + flir.arith.MulFOp(o_accs[dc], corr_vec, fastmath=fm_fast).result + ) + + # All waves must finish K reads before reusing current buffer for V. + gpu.barrier() + + # ==== Load V^T for current tile into current buffer ==== + coop_load_v_transposed(kv_start, cur_k_buf) + + # ==== Prefetch next K tile into next buffer (if exists) ==== + next_kv_start = (arith.ArithValue(kv_start) + BLOCK_N).value + has_next = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + next_kv_start, + kv_upper, + ).result + ) + with scf.if_(has_next): + coop_load_k(next_kv_start, next_k_buf) + + # Synchronize V current + K next visibility. + gpu.barrier() + + # ==== Build P packs in MFMA32 B-input format from register S ==== + p_f16 = [] + for r in range_constexpr(16): + p_f16.append( + arith.as_value(flir.arith.TruncFOp(elem_type, p_vals[r]).result) + ) + p_packs = [] + for pks in range_constexpr(PV_K_STEPS): + p_base = pks * 4 + p_packs.append( + arith.as_value( + vec_ext.from_elements( + v4f16_type, + [ + p_f16[p_base + 0], + p_f16[p_base + 1], + p_f16[p_base + 2], + p_f16[p_base + 3], + ], + ) + ) + ) + + # ==== GEMM2: O^T += V^T @ P (MFMA32) ==== + for dc in range_constexpr(D_CHUNKS): + for pks in range_constexpr(PV_K_STEPS): + v_idx = ( + arith.ArithValue(cur_base) + + (dc * D_CHUNK + arith.ArithValue(lane_mod_32)) * VT_STRIDE + + pks * PV_K_STEP + + arith.ArithValue(lane_div_32) * 4 + ).value + v_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [v_idx])) + o_accs[dc] = arith.as_value( + rocdl.mfma_f32_32x32x8f16( + v16f32_type, [v_pack, p_packs[pks], o_accs[dc], 0, 0, 0] + ) + ) + + # No trailing barrier: current buffer is only reused after one full iteration gap. + yield_args = [next_k_buf, m_new, l_new] + o_accs + scf_yield(yield_args) + + # ---- Normalize and store O ---- + l_final = arith.as_value(loop.results[2]) + o_finals = [arith.as_value(loop.results[3 + dc]) for dc in range_constexpr(D_CHUNKS)] + + inv_l = arith.as_value( + flir.arith.DivFOp(arith.as_value(c_one_f), l_final, fastmath=fm_fast).result + ) + inv_l_vec = arith.as_value(vec_ext.broadcast(v16f32_type, inv_l)) + + for dc in range_constexpr(D_CHUNKS): + o_norm_vec = arith.as_value( + flir.arith.MulFOp(o_finals[dc], inv_l_vec, fastmath=fm_fast).result + ) + for r in range_constexpr(16): + o_val = arith.as_value( + vec_ext.extract(o_norm_vec, static_position=[r], dynamic_position=[]) + ) + o_f16 = arith.as_value(flir.arith.TruncFOp(elem_type, o_val).result) + + d_row_rel = ( + arith.ArithValue(lane_div_32) * 4 + + (r // 4) * 8 + + (r % 4) + ).value + d_col = (flir.const_index(dc * D_CHUNK) + arith.ArithValue(d_row_rel)).value + o_global = global_idx(q_row, d_col) + _memref.StoreOp(o_f16, O, [o_global]) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + batch_size: lambda: T.index(), + seq_len: lambda: T.index(), + ): + c1 = arith.as_value(flir.arith_ext.index(1)) + c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) + c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) + bs_val = arith.as_value(batch_size) + sl_val = arith.as_value(seq_len) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(sl_val, c_bm).result) + bs_qt = arith.as_value(flir.arith.MulIOp(bs_val, num_q_tiles).result) + grid_x = arith.as_value(flir.arith.MulIOp(bs_qt, c_nh).result) + bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) + flir.gpu_ext.LaunchFuncOp( + [self.GPU_MODULE_NAME, KERNEL_NAME], + grid_size=(grid_x, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, seq_len], + ) + + return _FlashAttentionV4_4() +"""Flash Attention V4.4 kernel builder for FlyDSL. + V4.4 design (CK-aligned direction, rewritten from V4.3): - CK-aligned baseline tile family: BLOCK_M=64, BLOCK_N=32. - Q loaded once from global memory into MFMA A-operand packs (register-resident). @@ -30,7 +1123,7 @@ KERNEL_NAME = "flash_attention_v4_4_kernel" -def build_flash_attention_v4_4_module( +def _legacy_copy_build_flash_attention_v4_4_module_3( num_heads, head_dim, causal=True, @@ -497,18 +1590,19 @@ def coop_load_v_transposed(tile_start): gpu.barrier() # ==== P @ V via MFMA ==== + # P does not depend on ds; load once and reuse across all K_STEPS. + p_packs = [] + for nm in range_constexpr(N_MFMA): + p_a_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(lane_mod_16) * BLOCK_N + + nm * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + p_packs.append(arith.as_value(vec_ext.load_op(v4f16_type, lds_p, [p_a_idx]))) + for ds in range_constexpr(K_STEPS): for nm in range_constexpr(N_MFMA): - p_a_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(lane_mod_16) * BLOCK_N - + nm * 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - p_pack = arith.as_value( - vec_ext.load_op(v4f16_type, lds_p, [p_a_idx]) - ) - v_idx = ( (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE + nm * 16 @@ -519,7 +1613,7 @@ def coop_load_v_transposed(tile_start): ) o_accs[ds] = arith.as_value( rocdl.mfma_f32_16x16x16f16( - v4f32_type, [p_pack, v_pack, o_accs[ds], 0, 0, 0] + v4f32_type, [p_packs[nm], v_pack, o_accs[ds], 0, 0, 0] ) ) diff --git a/tests/kernels/test_flash_attention_v4_4.py b/tests/kernels/test_flash_attention_v4_4.py index ceddcace..d60ac3b1 100644 --- a/tests/kernels/test_flash_attention_v4_4.py +++ b/tests/kernels/test_flash_attention_v4_4.py @@ -40,7 +40,7 @@ V4_4_COMPILE_KWARGS = { "unsafe_fp_math": True, "fast_fp_math": True, - "waves_per_eu": 4, + "waves_per_eu": 3, "flat_work_group_size": 256, } @@ -162,11 +162,11 @@ def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, warmup, iters device = "cuda" results = {} - if seq_len % 64 != 0: - results["err"] = f"seq_len ({seq_len}) must be divisible by 64 for V4.4" + if seq_len % 128 != 0: + results["err"] = f"seq_len ({seq_len}) must be divisible by 128 for V4.4" return results - if head_dim % 16 != 0 or head_dim < 64: - results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" + if head_dim % 32 != 0 or head_dim < 64: + results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 32" return results try: @@ -283,8 +283,8 @@ def main(): print("=" * 130) print(f"FlyDSL Flash Attention V4.4 ({'causal' if causal else 'non-causal'}, fp16)") - print(" Tile: BLOCK_M=64, BLOCK_N=32, 4 waves (256 threads), mfma_f32_16x16x16f16") - print(" Strategy: Q-load-once + KV streaming + online softmax over 32 positions") + print(" Tile: BLOCK_M=128, BLOCK_N=32, 4 waves (256 threads), mfma_f32_32x32x8f16") + print(" Strategy: K@Q^T + register S/P ping-pong + V^T@P") print(f"GPU: {torch.cuda.get_device_name(0)}") print(f" Compile opts: {V4_4_COMPILE_KWARGS}") print("=" * 130) From b8fc9691a2690091cbcf45d9fce3bc6f9a012539 Mon Sep 17 00:00:00 2001 From: yanguahe Date: Fri, 13 Feb 2026 14:29:59 +0800 Subject: [PATCH 10/17] [WIP] Opt flash_attention_v4_4_kernel --- flydsl/src/flydsl/dialects/ext/rocdl.py | 71 +++++++++++++++++++++++++ kernels/flash_attention_v4_4.py | 2 +- 2 files changed, 72 insertions(+), 1 deletion(-) diff --git a/flydsl/src/flydsl/dialects/ext/rocdl.py b/flydsl/src/flydsl/dialects/ext/rocdl.py index a9356d0e..57355954 100644 --- a/flydsl/src/flydsl/dialects/ext/rocdl.py +++ b/flydsl/src/flydsl/dialects/ext/rocdl.py @@ -29,6 +29,8 @@ _ods_readlane = readlane _ods_readfirstlane = readfirstlane _ods_ds_swizzle = ds_swizzle +_ods_permlane16_swap = permlane16_swap +_ods_permlane32_swap = permlane32_swap _ods_raw_ptr_buffer_atomic_fadd = raw_ptr_buffer_atomic_fadd mask_mfma = 0x008 @@ -153,6 +155,73 @@ def ds_swizzle(result_type, src, offset, *, loc=None, ip=None): return _ods_ds_swizzle(result_type, _arith_ext.unwrap(src), _arith_ext.unwrap(offset), loc=loc, ip=ip) +def _unwrap_i32_lane_operand(v, *, loc=None): + from _mlir.ir import IntegerType + from . import arith as _arith_ext + + return _arith_ext.unwrap(v, type=IntegerType.get_signless(32), loc=loc) + + +def _permlane_i32x2_struct_type(): + from _mlir import ir as _ir + + # Some Python bindings accept optional spaces in LLVM type parser; keep both. + try: + return _ir.Type.parse("!llvm.struct<(i32, i32)>") + except Exception: + return _ir.Type.parse("!llvm.struct<(i32,i32)>") + + +def _extract_permlane_lane_i32(pair_val, *, loc=None, ip=None): + from _mlir.dialects import llvm as _llvm + from _mlir.ir import IntegerType + + i32 = IntegerType.get_signless(32) + return _llvm.extractvalue(i32, pair_val, [0], loc=loc, ip=ip) + + +def permlane16_swap_pair(old, src, fi=False, bound_control=False, *, loc=None, ip=None): + """High-level permlane16 swap wrapper returning the raw i32x2 struct.""" + return _ods_permlane16_swap( + _permlane_i32x2_struct_type(), + _unwrap_i32_lane_operand(old, loc=loc), + _unwrap_i32_lane_operand(src, loc=loc), + fi, + bound_control, + loc=loc, + ip=ip, + ) + + +def permlane16_swap_i32(old, src, fi=False, bound_control=False, *, loc=None, ip=None): + """High-level permlane16 swap wrapper returning the swapped i32 lane value.""" + pair_val = permlane16_swap_pair( + old, src, fi=fi, bound_control=bound_control, loc=loc, ip=ip + ) + return _extract_permlane_lane_i32(pair_val, loc=loc, ip=ip) + + +def permlane32_swap_pair(old, src, fi=False, bound_control=False, *, loc=None, ip=None): + """High-level permlane32 swap wrapper returning the raw i32x2 struct.""" + return _ods_permlane32_swap( + _permlane_i32x2_struct_type(), + _unwrap_i32_lane_operand(old, loc=loc), + _unwrap_i32_lane_operand(src, loc=loc), + fi, + bound_control, + loc=loc, + ip=ip, + ) + + +def permlane32_swap_i32(old, src, fi=False, bound_control=False, *, loc=None, ip=None): + """High-level permlane32 swap wrapper returning the swapped i32 lane value.""" + pair_val = permlane32_swap_pair( + old, src, fi=fi, bound_control=bound_control, loc=loc, ip=ip + ) + return _extract_permlane_lane_i32(pair_val, loc=loc, ip=ip) + + def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, ip=None): """Atomic fadd that accepts `ArithValue` / wrappers (no explicit `arith.unwrap(...)` needed). @@ -213,6 +282,8 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, # Shuffle and permutation 'ds_swizzle', 'ds_bpermute', 'permlanex16', 'permlane16_swap', 'permlane32_swap', + 'permlane16_swap_pair', 'permlane16_swap_i32', + 'permlane32_swap_pair', 'permlane32_swap_i32', 'readlane', 'readfirstlane', 'update_dpp', 'ballot', diff --git a/kernels/flash_attention_v4_4.py b/kernels/flash_attention_v4_4.py index 15313315..18fc6236 100644 --- a/kernels/flash_attention_v4_4.py +++ b/kernels/flash_attention_v4_4.py @@ -7,7 +7,7 @@ - GEMM1 uses `K @ Q^T` so S/P live in MFMA32 register layout. - Online softmax over KV dimension is done in registers. - P is kept in registers and fed directly to GEMM2 (`V^T @ P`) without LDS roundtrip. -- K uses LDS ping-pong prefetch between adjacent iterations. +- K and V^T use separate LDS regions (single-buffered per iteration). Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. From 73f96c7671745aacae046b3ce6f28e700c534bcb Mon Sep 17 00:00:00 2001 From: yanguahe Date: Fri, 13 Feb 2026 16:30:41 +0800 Subject: [PATCH 11/17] Refine flash_attention_v4_4 convergence paths with safe defaults. Add gated CK-style N128/prefetch/reduction experiments plus ROCDL phase-fence wrappers so performance tuning can be A/B tested without regressing the stable target-shape path. Co-authored-by: Cursor --- flydsl/src/flydsl/dialects/ext/rocdl.py | 59 +++ kernels/flash_attention_v4_4.py | 415 +++++++++++++-------- tests/kernels/test_flash_attention_v4_4.py | 27 +- 3 files changed, 337 insertions(+), 164 deletions(-) diff --git a/flydsl/src/flydsl/dialects/ext/rocdl.py b/flydsl/src/flydsl/dialects/ext/rocdl.py index 57355954..4fad5bb7 100644 --- a/flydsl/src/flydsl/dialects/ext/rocdl.py +++ b/flydsl/src/flydsl/dialects/ext/rocdl.py @@ -48,6 +48,61 @@ def sched_dswr(cnt): sched_group_barrier(mask_dswr, cnt, 0) +def _unwrap_i32_scalar(v, *, loc=None): + from _mlir.ir import IntegerType + from . import arith as _arith_ext + + return _arith_ext.unwrap(v, type=IntegerType.get_signless(32), loc=loc) + + +def async_global_load_to_lds(global_ptr, lds_ptr, size, offset=0, aux=0, *, loc=None, ip=None): + """Global->LDS async-style copy wrapper (closest stable ROCDL primitive).""" + from . import arith as _arith_ext + + return global_load_lds( + _arith_ext.unwrap(global_ptr, loc=loc), + _arith_ext.unwrap(lds_ptr, loc=loc), + _unwrap_i32_scalar(size, loc=loc), + _unwrap_i32_scalar(offset, loc=loc), + _unwrap_i32_scalar(aux, loc=loc), + loc=loc, + ip=ip, + ) + + +def async_load_to_lds(global_ptr, lds_ptr, size, offset=0, aux=0, *, loc=None, ip=None): + """Alias for load_to_lds with scalar auto-unwrapping.""" + from . import arith as _arith_ext + + return load_to_lds( + _arith_ext.unwrap(global_ptr, loc=loc), + _arith_ext.unwrap(lds_ptr, loc=loc), + _unwrap_i32_scalar(size, loc=loc), + _unwrap_i32_scalar(offset, loc=loc), + _unwrap_i32_scalar(aux, loc=loc), + loc=loc, + ip=ip, + ) + + +def async_load_fence(wait_vmem=0, wait_ds=0, *, loc=None, ip=None): + """Waitcnt-style fence helper for staged async copy scheduling.""" + # NOTE: wait_loadcnt/wait_dscnt lowerings are not stable on current toolchain. + # Use conservative full waitcnt fence for now. + _ = (wait_vmem, wait_ds) + return s_waitcnt(0, loc=loc, ip=ip) + + +def phase_barrier(mask=0, *, loc=None, ip=None): + """Scheduling barrier wrapper used as phase fence in pipelined kernels.""" + return sched_barrier(mask, loc=loc, ip=ip) + + +def phase_group_barrier(mask, size, group_id=0, *, loc=None, ip=None): + """Group scheduling barrier wrapper used as phase fence in pipelined kernels.""" + return sched_group_barrier(mask, size, group_id, loc=loc, ip=ip) + + def _unwrap_mfma_operand(v, *, loc=None): """MFMA operands are MLIR Values; some trailing operands are i32 flags. @@ -257,6 +312,7 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, 'barrier', 's_barrier', 's_barrier_signal', 's_barrier_wait', 's_waitcnt', 's_wait_loadcnt', 's_wait_storecnt', 's_wait_dscnt', 's_wait_expcnt', + 'async_load_fence', # Matrix operations - MFMA (Matrix Fused Multiply-Add) 'mfma_f32_32x32x8f16', 'mfma_f32_16x16x16f16', @@ -292,6 +348,7 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, 'raw_buffer_load', 'raw_buffer_store', 'raw_ptr_buffer_load', 'raw_ptr_buffer_store', 'load_to_lds', 'global_load_lds', + 'async_load_to_lds', 'async_global_load_to_lds', 'make_buffer_rsrc', # Atomic operations @@ -305,6 +362,8 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, # Scheduling and optimization 's_setprio', 's_sleep', 'sched_barrier', 'sched_group_barrier', + 'phase_barrier', 'phase_group_barrier', + 'sched_mfma', 'sched_vmem', 'sched_dsrd', 'sched_dswr', 'iglp_opt', # Type conversions diff --git a/kernels/flash_attention_v4_4.py b/kernels/flash_attention_v4_4.py index 18fc6236..30271b45 100644 --- a/kernels/flash_attention_v4_4.py +++ b/kernels/flash_attention_v4_4.py @@ -17,6 +17,7 @@ """ import math +import os from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl from flydsl.dialects.ext import vector as vec_ext @@ -32,6 +33,26 @@ KERNEL_NAME = "flash_attention_v4_4_kernel" +def select_v4_4_path(num_heads, head_dim, causal=True, dtype_str="f16"): + """Select active V4.4 path tag for build-time specialization.""" + override = os.getenv("FLYDSL_FA_V44_PATH", "auto").strip().lower() + if override in ("fallback", "fallback_n32", "n32"): + return "fallback_n32" + if override in ("fastpath", "ck_n128_fastpath", "n128"): + return "ck_n128_fastpath" + # Keep N128 path feature-gated by default due current occupancy/perf risk. + enable_n128 = os.getenv("FLYDSL_FA_V44_ENABLE_N128", "0") == "1" + if ( + enable_n128 + and dtype_str == "f16" + and causal + and num_heads == 64 + and head_dim == 128 + ): + return "ck_n128_fastpath" + return "fallback_n32" + + def build_flash_attention_v4_4_module_primary( num_heads, head_dim, @@ -50,6 +71,17 @@ def build_flash_attention_v4_4_module_primary( WARP_SIZE = 64 BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 32 + PATH_TAG = select_v4_4_path(num_heads, head_dim, causal=causal, dtype_str=dtype_str) + BLOCK_N_OUT = 128 if PATH_TAG == "ck_n128_fastpath" else BLOCK_N + N_SUBTILES = BLOCK_N_OUT // BLOCK_N + ENABLE_PREFETCH_3BUF = os.getenv("FLYDSL_FA_V44_ENABLE_PREFETCH3", "0") == "1" + ENABLE_LDS_VEC16 = os.getenv("FLYDSL_FA_V44_ENABLE_LDS_VEC16", "1") == "1" + REDUCE_MODE = os.getenv("FLYDSL_FA_V44_REDUCE_MODE", "xor").strip().lower() + if REDUCE_MODE not in ("xor", "ds_bpermute"): + REDUCE_MODE = "xor" + NUM_PREFETCH_K = 3 if ENABLE_PREFETCH_3BUF else 1 + NUM_PREFETCH_V = 3 if ENABLE_PREFETCH_3BUF else 1 + CK_LDS_SEQ = (1, 2, 0, 1, 0, 1, 2, 0) if ENABLE_PREFETCH_3BUF else (0,) # MFMA32 K-dimension is 8. K_STEP_QK = 8 @@ -65,6 +97,7 @@ def build_flash_attention_v4_4_module_primary( assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" assert dtype_str == "f16", "V4.4 currently only supports f16" assert BLOCK_N % 32 == 0 + assert BLOCK_N_OUT % BLOCK_N == 0 if sm_scale is None: sm_scale = 1.0 / math.sqrt(head_dim) @@ -79,7 +112,8 @@ def build_flash_attention_v4_4_module_primary( VT_STRIDE = BLOCK_N + 2 # Vectorized cooperative load constants. - VEC_WIDTH = 8 + VEC_WIDTH = 16 if ENABLE_LDS_VEC16 else 8 + assert HEAD_DIM % VEC_WIDTH == 0 THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD @@ -92,17 +126,20 @@ def build_flash_attention_v4_4_module_primary( NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD KV_NEEDS_GUARD = False - # Separate K and V^T regions to remove one intra-iteration barrier. - LDS_K_SIZE = BLOCK_N * K_STRIDE - LDS_VT_SIZE = HEAD_DIM * VT_STRIDE - LDS_VT_BASE = LDS_K_SIZE - LDS_KV_TOTAL_SIZE = LDS_K_SIZE + LDS_VT_SIZE + # K/VT circular buffers; defaults to 1/1, optional 3/3 with CK-like LDS sequence. + LDS_K_TILE_SIZE = BLOCK_N * K_STRIDE + LDS_VT_TILE_SIZE = HEAD_DIM * VT_STRIDE + LDS_K_TOTAL_SIZE = NUM_PREFETCH_K * LDS_K_TILE_SIZE + LDS_VT_BASE = LDS_K_TOTAL_SIZE + LDS_VT_TOTAL_SIZE = NUM_PREFETCH_V * LDS_VT_TILE_SIZE + LDS_KV_TOTAL_SIZE = LDS_K_TOTAL_SIZE + LDS_VT_TOTAL_SIZE allocator = SmemAllocator(None, arch=gpu_arch) _state = {} class _FlashAttentionV4_4(flir.MlirModule): - GPU_MODULE_NAME = f"flash_attn_v4_4_{dtype_str}" + GPU_MODULE_NAME = f"flash_attn_v4_4_{dtype_str}_{PATH_TAG}" + KERNEL_VARIANT = PATH_TAG GPU_MODULE_TARGETS = [f'#rocdl.target'] def init_gpu_module(self): @@ -125,7 +162,7 @@ def flash_attention_v4_4_kernel( fm_fast = flir.arith.FastMathFlags.fast v4f16_type = ir.VectorType.get([4], elem_type) - v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) + vxf16_type = ir.VectorType.get([VEC_WIDTH], elem_type) v16f32_type = ir.VectorType.get([16], compute_type) seq_len_v = arith.as_value(seq_len) @@ -178,8 +215,15 @@ def global_idx(token_idx, col): + arith.ArithValue(col) ).value + def k_buf_base(buf_id): + return flir.const_index(buf_id * LDS_K_TILE_SIZE) + + def vt_buf_base(buf_id): + return flir.const_index(LDS_VT_BASE + buf_id * LDS_VT_TILE_SIZE) + # ---- Cooperative K load (row-major, padded stride) ---- - def coop_load_k(tile_start): + def coop_load_k(tile_start, buf_id=0): + k_base = k_buf_base(buf_id) for batch in range_constexpr(NUM_BATCHES_KV): row_offset = batch * ROWS_PER_BATCH_LOAD row_idx = ( @@ -198,27 +242,30 @@ def coop_load_k(tile_start): ) with scf.if_(row_valid): g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + vec = arith.as_value(vec_ext.load_op(vxf16_type, K, [g_idx])) lds_row = ( arith.ArithValue(load_row_in_batch) + row_offset ).value lds_idx = ( - arith.ArithValue(lds_row) * K_STRIDE + arith.ArithValue(k_base) + + arith.ArithValue(lds_row) * K_STRIDE + arith.ArithValue(load_col_base) ).value vec_ext.store(vec, lds_kv, [lds_idx]) else: g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + vec = arith.as_value(vec_ext.load_op(vxf16_type, K, [g_idx])) lds_row = (arith.ArithValue(load_row_in_batch) + row_offset).value lds_idx = ( - arith.ArithValue(lds_row) * K_STRIDE + arith.ArithValue(k_base) + + arith.ArithValue(lds_row) * K_STRIDE + arith.ArithValue(load_col_base) ).value vec_ext.store(vec, lds_kv, [lds_idx]) # ---- Cooperative V load (transposed, padded stride) ---- - def coop_load_v_transposed(tile_start): + def coop_load_v_transposed(tile_start, buf_id=0): + vt_base = vt_buf_base(buf_id) for batch in range_constexpr(NUM_BATCHES_KV): row_offset = batch * ROWS_PER_BATCH_LOAD row_idx = ( @@ -237,7 +284,7 @@ def coop_load_v_transposed(tile_start): ) with scf.if_(row_valid): g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + vec = arith.as_value(vec_ext.load_op(vxf16_type, V, [g_idx])) load_row = ( arith.ArithValue(load_row_in_batch) + row_offset ).value @@ -247,13 +294,14 @@ def coop_load_v_transposed(tile_start): ) col_e = (arith.ArithValue(load_col_base) + e).value lds_idx = ( - LDS_VT_BASE + arith.ArithValue(col_e) * VT_STRIDE + arith.ArithValue(vt_base) + + arith.ArithValue(col_e) * VT_STRIDE + arith.ArithValue(load_row) ).value _memref.StoreOp(elem, lds_kv, [lds_idx]) else: g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + vec = arith.as_value(vec_ext.load_op(vxf16_type, V, [g_idx])) load_row = (arith.ArithValue(load_row_in_batch) + row_offset).value for e in range_constexpr(VEC_WIDTH): elem = arith.as_value( @@ -261,7 +309,8 @@ def coop_load_v_transposed(tile_start): ) col_e = (arith.ArithValue(load_col_base) + e).value lds_idx = ( - LDS_VT_BASE + arith.ArithValue(col_e) * VT_STRIDE + arith.ArithValue(vt_base) + + arith.ArithValue(col_e) * VT_STRIDE + arith.ArithValue(load_row) ).value _memref.StoreOp(elem, lds_kv, [lds_idx]) @@ -291,6 +340,23 @@ def coop_load_v_transposed(tile_start): c_zero_v16f32 = arith.as_value(arith.constant_vector(0.0, v16f32_type)) width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) shuf_32_i32 = arith.as_value(arith.constant(32, type=T.i32())) + c4_i32 = arith.as_value(arith.constant(4, type=T.i32())) + lane_i32 = arith.as_value(flir.arith.IndexCastOp(T.i32(), lane).result) + lane_xor_32_i32 = arith.as_value(flir.arith.XOrIOp(lane_i32, shuf_32_i32).result) + lane_xor_32_byte = arith.as_value( + flir.arith.MulIOp(lane_xor_32_i32, c4_i32).result + ) + + def reduction_peer(v_f32): + if REDUCE_MODE == "ds_bpermute": + v_i32 = arith.as_value(flir.arith.bitcast(T.i32(), v_f32)) + peer_i32 = arith.as_value( + rocdl.ds_bpermute(T.i32(), lane_xor_32_byte, v_i32) + ) + return arith.as_value(flir.arith.bitcast(compute_type, peer_i32)) + return arith.as_value( + gpu.ShuffleOp(v_f32, shuf_32_i32, width_i32, mode="xor").shuffleResult + ) # ---- KV loop upper bound ---- if CAUSAL: @@ -303,172 +369,209 @@ def coop_load_v_transposed(tile_start): for _ in range_constexpr(D_CHUNKS): init_args.append(c_zero_v16f32) - with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: - kv_start = arith.as_value(loop.induction_variable) - m_old = arith.as_value(loop.inner_iter_args[0]) - l_old = arith.as_value(loop.inner_iter_args[1]) + with scf.for_(0, kv_upper, BLOCK_N_OUT, iter_args=init_args) as loop: + kv_block_start = arith.as_value(loop.induction_variable) + m_running = arith.as_value(loop.inner_iter_args[0]) + l_running = arith.as_value(loop.inner_iter_args[1]) o_accs = [arith.as_value(loop.inner_iter_args[2 + i]) for i in range_constexpr(D_CHUNKS)] + preload_k_count = NUM_PREFETCH_K if NUM_PREFETCH_K < N_SUBTILES else N_SUBTILES - # ==== Cooperative K load -> LDS_KV ==== - coop_load_k(kv_start) - gpu.barrier() + if ENABLE_PREFETCH_3BUF: + for pre_k in range_constexpr(preload_k_count): + pre_k_slot = CK_LDS_SEQ[pre_k % len(CK_LDS_SEQ)] % NUM_PREFETCH_K + pre_k_start = (arith.ArithValue(kv_block_start) + pre_k * BLOCK_N).value + coop_load_k(pre_k_start, pre_k_slot) + rocdl.phase_group_barrier(rocdl.mask_vmem_rd, 1, 0) + gpu.barrier() - # ==== GEMM1: S = K @ Q^T (MFMA32), S in v16f32 ==== - s_acc = c_zero_v16f32 - for ks in range_constexpr(K_STEPS_QK): - k_idx = ( - arith.ArithValue(lane_mod_32) * K_STRIDE - + ks * K_STEP_QK - + arith.ArithValue(lane_div_32) * 4 - ).value - k_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [k_idx])) - q_pack = q_b_packs[ks] - s_acc = arith.as_value( - rocdl.mfma_f32_32x32x8f16( - v16f32_type, [k_pack, q_pack, s_acc, 0, 0, 0] + for kv_sub in range_constexpr(N_SUBTILES): + kv_start = (arith.ArithValue(kv_block_start) + kv_sub * BLOCK_N).value + + if ENABLE_PREFETCH_3BUF: + k_slot = CK_LDS_SEQ[kv_sub % len(CK_LDS_SEQ)] % NUM_PREFETCH_K + else: + k_slot = 0 + # ==== Cooperative K load -> LDS_KV ==== + coop_load_k(kv_start, k_slot) + gpu.barrier() + k_base = k_buf_base(k_slot) + + # ==== GEMM1: S = K @ Q^T (MFMA32), S in v16f32 ==== + s_acc = c_zero_v16f32 + for ks in range_constexpr(K_STEPS_QK): + k_idx = ( + arith.ArithValue(k_base) + + arith.ArithValue(lane_mod_32) * K_STRIDE + + ks * K_STEP_QK + + arith.ArithValue(lane_div_32) * 4 + ).value + k_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [k_idx])) + q_pack = q_b_packs[ks] + s_acc = arith.as_value( + rocdl.mfma_f32_32x32x8f16( + v16f32_type, [k_pack, q_pack, s_acc, 0, 0, 0] + ) ) + + # ==== Online softmax over KV dimension (register only) ==== + q_row_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), q_row).result ) - # ==== Online softmax over KV dimension (register only) ==== - q_row_i64 = arith.as_value( - flir.arith.IndexCastOp(T.i64(), q_row).result - ) + s_vals = [] + for r in range_constexpr(16): + s_val = arith.as_value( + vec_ext.extract(s_acc, static_position=[r], dynamic_position=[]) + ) + s_val = arith.as_value( + flir.arith.MulFOp( + s_val, arith.as_value(c_sm_scale), fastmath=fm_fast + ).result + ) + if CAUSAL: + kv_row_rel = ( + arith.ArithValue(lane_div_32) * 4 + + (r // 4) * 8 + + (r % 4) + ).value + kv_col = ( + arith.ArithValue(kv_start) + arith.ArithValue(kv_row_rel) + ).value + kv_col_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), kv_col).result + ) + is_masked = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64 + ).result + ) + s_val = arith.as_value( + flir.arith.SelectOp( + is_masked, arith.as_value(c_neg_inf), s_val + ).result + ) + s_vals.append(s_val) - s_vals = [] - for r in range_constexpr(16): - s_val = arith.as_value( - vec_ext.extract(s_acc, static_position=[r], dynamic_position=[]) + local_max = s_vals[0] + for r in range_constexpr(15): + local_max = arith.as_value( + flir.arith.MaximumFOp(local_max, s_vals[r + 1]).result + ) + peer_max = reduction_peer(local_max) + row_max = arith.as_value( + flir.arith.MaximumFOp(local_max, peer_max).result ) - s_val = arith.as_value( + m_new = arith.as_value( + flir.arith.MaximumFOp(m_running, row_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_running, m_new, fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( flir.arith.MulFOp( - s_val, arith.as_value(c_sm_scale), fastmath=fm_fast + diff_m, arith.as_value(c_log2e), fastmath=fm_fast ).result ) + corr = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) - if CAUSAL: - kv_row_rel = ( - arith.ArithValue(lane_div_32) * 4 - + (r // 4) * 8 - + (r % 4) - ).value - kv_col = (arith.ArithValue(kv_start) + arith.ArithValue(kv_row_rel)).value - kv_col_i64 = arith.as_value( - flir.arith.IndexCastOp(T.i64(), kv_col).result + p_vals = [] + local_sum = arith.as_value(c_zero_f) + for r in range_constexpr(16): + diff = arith.as_value( + flir.arith.SubFOp(s_vals[r], m_new, fastmath=fm_fast).result ) - is_masked = arith.as_value( - flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64 + diff_s = arith.as_value( + flir.arith.MulFOp( + diff, arith.as_value(c_log2e), fastmath=fm_fast ).result ) - s_val = arith.as_value( - flir.arith.SelectOp( - is_masked, arith.as_value(c_neg_inf), s_val - ).result + p = arith.as_value(flir.math.exp2(diff_s, fastmath=fm_fast)) + p_vals.append(p) + local_sum = arith.as_value( + flir.arith.AddFOp(local_sum, p, fastmath=fm_fast).result ) - s_vals.append(s_val) - - local_max = s_vals[0] - for r in range_constexpr(15): - local_max = arith.as_value( - flir.arith.MaximumFOp(local_max, s_vals[r + 1]).result - ) - peer_max = arith.as_value( - gpu.ShuffleOp(local_max, shuf_32_i32, width_i32, mode="xor").shuffleResult - ) - row_max = arith.as_value( - flir.arith.MaximumFOp(local_max, peer_max).result - ) - m_new = arith.as_value( - flir.arith.MaximumFOp(m_old, row_max).result - ) - - diff_m = arith.as_value( - flir.arith.SubFOp(m_old, m_new, fastmath=fm_fast).result - ) - diff_m_s = arith.as_value( - flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - corr = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) - p_vals = [] - local_sum = arith.as_value(c_zero_f) - for r in range_constexpr(16): - diff = arith.as_value( - flir.arith.SubFOp(s_vals[r], m_new, fastmath=fm_fast).result + peer_sum = reduction_peer(local_sum) + tile_sum = arith.as_value( + flir.arith.AddFOp(local_sum, peer_sum, fastmath=fm_fast).result ) - diff_s = arith.as_value( - flir.arith.MulFOp(diff, arith.as_value(c_log2e), fastmath=fm_fast).result + l_corr = arith.as_value( + flir.arith.MulFOp(corr, l_running, fastmath=fm_fast).result ) - p = arith.as_value(flir.math.exp2(diff_s, fastmath=fm_fast)) - p_vals.append(p) - local_sum = arith.as_value( - flir.arith.AddFOp(local_sum, p, fastmath=fm_fast).result + l_new = arith.as_value( + flir.arith.AddFOp(l_corr, tile_sum, fastmath=fm_fast).result ) - peer_sum = arith.as_value( - gpu.ShuffleOp(local_sum, shuf_32_i32, width_i32, mode="xor").shuffleResult - ) - tile_sum = arith.as_value( - flir.arith.AddFOp(local_sum, peer_sum, fastmath=fm_fast).result - ) - l_corr = arith.as_value( - flir.arith.MulFOp(corr, l_old, fastmath=fm_fast).result - ) - l_new = arith.as_value( - flir.arith.AddFOp(l_corr, tile_sum, fastmath=fm_fast).result - ) - - # ==== Rescale O accumulators ==== - corr_vec = arith.as_value(vec_ext.broadcast(v16f32_type, corr)) - for dc in range_constexpr(D_CHUNKS): - o_accs[dc] = arith.as_value( - flir.arith.MulFOp(o_accs[dc], corr_vec, fastmath=fm_fast).result - ) + # ==== Rescale O accumulators ==== + corr_vec = arith.as_value(vec_ext.broadcast(v16f32_type, corr)) + for dc in range_constexpr(D_CHUNKS): + o_accs[dc] = arith.as_value( + flir.arith.MulFOp(o_accs[dc], corr_vec, fastmath=fm_fast).result + ) - # ==== Load V^T for current tile into LDS_KV ==== - coop_load_v_transposed(kv_start) - gpu.barrier() + if ENABLE_PREFETCH_3BUF and (kv_sub + preload_k_count) < N_SUBTILES: + next_k_sub = kv_sub + preload_k_count + next_k_start = ( + arith.ArithValue(kv_block_start) + next_k_sub * BLOCK_N + ).value + next_k_slot = CK_LDS_SEQ[next_k_sub % len(CK_LDS_SEQ)] % NUM_PREFETCH_K + coop_load_k(next_k_start, next_k_slot) - # ==== Build P packs in MFMA32 B-input format from register S ==== - p_f16 = [] - for r in range_constexpr(16): - p_f16.append( - arith.as_value(flir.arith.TruncFOp(elem_type, p_vals[r]).result) - ) - p_packs = [] - for pks in range_constexpr(PV_K_STEPS): - p_base = pks * 4 - p_packs.append( - arith.as_value( - vec_ext.from_elements( - v4f16_type, - [ - p_f16[p_base + 0], - p_f16[p_base + 1], - p_f16[p_base + 2], - p_f16[p_base + 3], - ], - ) + if ENABLE_PREFETCH_3BUF: + v_slot = CK_LDS_SEQ[kv_sub % len(CK_LDS_SEQ)] % NUM_PREFETCH_V + else: + v_slot = 0 + v_base = vt_buf_base(v_slot) + + # ==== Load V^T for current tile into LDS_KV ==== + coop_load_v_transposed(kv_start, v_slot) + rocdl.phase_group_barrier(rocdl.mask_dswr, 1, 0) + gpu.barrier() + + # ==== Build P packs in MFMA32 B-input format from register S ==== + p_f16 = [] + for r in range_constexpr(16): + p_f16.append( + arith.as_value(flir.arith.TruncFOp(elem_type, p_vals[r]).result) ) - ) - - # ==== GEMM2: O^T += V^T @ P (MFMA32) ==== - for dc in range_constexpr(D_CHUNKS): + p_packs = [] for pks in range_constexpr(PV_K_STEPS): - v_idx = ( - LDS_VT_BASE - + (dc * D_CHUNK + arith.ArithValue(lane_mod_32)) * VT_STRIDE - + pks * PV_K_STEP - + arith.ArithValue(lane_div_32) * 4 - ).value - v_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [v_idx])) - o_accs[dc] = arith.as_value( - rocdl.mfma_f32_32x32x8f16( - v16f32_type, [v_pack, p_packs[pks], o_accs[dc], 0, 0, 0] + p_base = pks * 4 + p_packs.append( + arith.as_value( + vec_ext.from_elements( + v4f16_type, + [ + p_f16[p_base + 0], + p_f16[p_base + 1], + p_f16[p_base + 2], + p_f16[p_base + 3], + ], + ) ) ) - yield_args = [m_new, l_new] + o_accs + # ==== GEMM2: O^T += V^T @ P (MFMA32) ==== + for dc in range_constexpr(D_CHUNKS): + for pks in range_constexpr(PV_K_STEPS): + v_idx = ( + arith.ArithValue(v_base) + + (dc * D_CHUNK + arith.ArithValue(lane_mod_32)) * VT_STRIDE + + pks * PV_K_STEP + + arith.ArithValue(lane_div_32) * 4 + ).value + v_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [v_idx])) + o_accs[dc] = arith.as_value( + rocdl.mfma_f32_32x32x8f16( + v16f32_type, [v_pack, p_packs[pks], o_accs[dc], 0, 0, 0] + ) + ) + + m_running = m_new + l_running = l_new + + yield_args = [m_running, l_running] + o_accs scf_yield(yield_args) # ---- Normalize and store O ---- diff --git a/tests/kernels/test_flash_attention_v4_4.py b/tests/kernels/test_flash_attention_v4_4.py index d60ac3b1..66d086d1 100644 --- a/tests/kernels/test_flash_attention_v4_4.py +++ b/tests/kernels/test_flash_attention_v4_4.py @@ -31,7 +31,11 @@ sys.exit(1) import flydsl -from kernels.flash_attention_v4_4 import build_flash_attention_v4_4_module, KERNEL_NAME +from kernels.flash_attention_v4_4 import ( + KERNEL_NAME, + build_flash_attention_v4_4_module, + select_v4_4_path, +) from tests.test_common import run_perftest # Tensor initialization range (uniform distribution) @@ -161,6 +165,10 @@ def compare_arrays( def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, warmup, iters, prev_exe=None, seed=DEFAULT_SEED): device = "cuda" results = {} + active_path = select_v4_4_path( + num_heads=num_heads, head_dim=head_dim, causal=causal, dtype_str="f16" + ) + results["active_path"] = active_path if seq_len % 128 != 0: results["err"] = f"seq_len ({seq_len}) must be divisible by 128 for V4.4" @@ -218,6 +226,7 @@ def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, warmup, iters # Compute and print MD5 hashes tag = f"B={B} S={S} H={H} D={D}" + print(f" [{tag}] active_path = {active_path}") result_md5 = compute_md5(o_flat) ref_md5 = compute_md5(ref_flat) print(f" [{tag}] result_md5 = {result_md5}") @@ -283,7 +292,7 @@ def main(): print("=" * 130) print(f"FlyDSL Flash Attention V4.4 ({'causal' if causal else 'non-causal'}, fp16)") - print(" Tile: BLOCK_M=128, BLOCK_N=32, 4 waves (256 threads), mfma_f32_32x32x8f16") + print(" Tile: BLOCK_M=128, BLOCK_N=32 fallback (default) + CK-like N=128 fast path (gated)") print(" Strategy: K@Q^T + register S/P ping-pong + V^T@P") print(f"GPU: {torch.cuda.get_device_name(0)}") print(f" Compile opts: {V4_4_COMPILE_KWARGS}") @@ -316,13 +325,13 @@ def main(): if args.compare_v43: hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'Config/Path':>56s} | {'Status':>6s} | {'MaxErr':>8s} " f"{'MinCos':>8s} | {'V4.4(us)':>10s} {'V4.4 TF':>9s} | " f"{'V4.3(us)':>10s} {'V4.3 TF':>9s} | {'Speedup':>7s}" ) else: hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'Config/Path':>56s} | {'Status':>6s} | {'MaxErr':>8s} " f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" ) print(f"\n{hdr}") @@ -346,13 +355,15 @@ def main(): seed=args.seed, ) if "err" in r: - print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") + cfg_path = f"{tag} / {r.get('active_path', 'unknown')}" + print(f"{cfg_path:>56s} | {'ERROR':>6s} | {r['err'][:60]}") all_passed = False continue status = "PASS" if r["passed"] else "FAIL" if not r["passed"]: all_passed = False + cfg_path = f"{tag} / {r.get('active_path', 'unknown')}" us_s = f"{r['us']:>10.1f}" if "us" in r else " N/A" tf_s = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" @@ -362,18 +373,18 @@ def main(): p_tf = f"{r['prev_tflops']:>9.3f}" speedup = r["prev_us"] / r["us"] if r.get("us") else 0 print( - f"{tag:>38s} | {status:>6s} | " + f"{cfg_path:>56s} | {status:>6s} | " f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " f"{us_s} {tf_s} | {p_us} {p_tf} | {speedup:>6.2f}x" ) else: print( - f"{tag:>38s} | {status:>6s} | " + f"{cfg_path:>56s} | {status:>6s} | " f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " f"{us_s} {tf_s}" ) except Exception as e: - print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") + print(f"{tag:>56s} | {'ERROR':>6s} | {str(e)[:60]}") all_passed = False print("=" * 130) From 4e0fff028e1cc2d34c11a40edfb14b0699ce8cae Mon Sep 17 00:00:00 2001 From: yanguahe Date: Fri, 13 Feb 2026 17:59:37 +0800 Subject: [PATCH 12/17] Rename flash_attention_v4_4 artifacts to flash_attn_func. Align kernel, test, and run-script references so the renamed entrypoint is used consistently across build and benchmark workflows. Co-authored-by: Cursor --- ...h_attention_v4_4.py => flash_attn_func.py} | 82 ++++++++++--------- run.sh | 7 +- ...ention_v4_4.py => test_flash_attn_func.py} | 29 +++---- 3 files changed, 62 insertions(+), 56 deletions(-) rename kernels/{flash_attention_v4_4.py => flash_attn_func.py} (97%) rename tests/kernels/{test_flash_attention_v4_4.py => test_flash_attn_func.py} (94%) diff --git a/kernels/flash_attention_v4_4.py b/kernels/flash_attn_func.py similarity index 97% rename from kernels/flash_attention_v4_4.py rename to kernels/flash_attn_func.py index 30271b45..cbe20899 100644 --- a/kernels/flash_attention_v4_4.py +++ b/kernels/flash_attn_func.py @@ -1,6 +1,6 @@ -"""Flash Attention V4.4 kernel builder for FlyDSL. +"""flash_attn_func kernel builder for FlyDSL. -Aggressive V4.4 path: +Aggressive flash_attn_func path: - True MFMA32 remap: `mfma_f32_32x32x8f16` for both GEMM stages. - Tile shape: BLOCK_M=128, BLOCK_N=32, 4 waves (256 threads). - Per-wave Q rows: 32. @@ -30,18 +30,18 @@ import _mlir.extras.types as T -KERNEL_NAME = "flash_attention_v4_4_kernel" +KERNEL_NAME = "flash_attn_func_kernel" -def select_v4_4_path(num_heads, head_dim, causal=True, dtype_str="f16"): - """Select active V4.4 path tag for build-time specialization.""" - override = os.getenv("FLYDSL_FA_V44_PATH", "auto").strip().lower() +def select_flash_attn_func_path(num_heads, head_dim, causal=True, dtype_str="f16"): + """Select active flash_attn_func path tag for build-time specialization.""" + override = os.getenv("FLYDSL_FLASH_ATTN_FUNC_PATH", "auto").strip().lower() if override in ("fallback", "fallback_n32", "n32"): return "fallback_n32" if override in ("fastpath", "ck_n128_fastpath", "n128"): return "ck_n128_fastpath" # Keep N128 path feature-gated by default due current occupancy/perf risk. - enable_n128 = os.getenv("FLYDSL_FA_V44_ENABLE_N128", "0") == "1" + enable_n128 = os.getenv("FLYDSL_FLASH_ATTN_FUNC_ENABLE_N128", "0") == "1" if ( enable_n128 and dtype_str == "f16" @@ -53,14 +53,14 @@ def select_v4_4_path(num_heads, head_dim, causal=True, dtype_str="f16"): return "fallback_n32" -def build_flash_attention_v4_4_module_primary( +def build_flash_attn_func_module_primary( num_heads, head_dim, causal=True, dtype_str="f16", sm_scale=None, ): - """Build a FlyDSL Flash Attention V4.4 module.""" + """Build a FlyDSL flash_attn_func module.""" gpu_arch = get_hip_arch() DYN = ir.ShapedType.get_dynamic_size() @@ -71,12 +71,16 @@ def build_flash_attention_v4_4_module_primary( WARP_SIZE = 64 BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 32 - PATH_TAG = select_v4_4_path(num_heads, head_dim, causal=causal, dtype_str=dtype_str) + PATH_TAG = select_flash_attn_func_path( + num_heads, head_dim, causal=causal, dtype_str=dtype_str + ) BLOCK_N_OUT = 128 if PATH_TAG == "ck_n128_fastpath" else BLOCK_N N_SUBTILES = BLOCK_N_OUT // BLOCK_N - ENABLE_PREFETCH_3BUF = os.getenv("FLYDSL_FA_V44_ENABLE_PREFETCH3", "0") == "1" - ENABLE_LDS_VEC16 = os.getenv("FLYDSL_FA_V44_ENABLE_LDS_VEC16", "1") == "1" - REDUCE_MODE = os.getenv("FLYDSL_FA_V44_REDUCE_MODE", "xor").strip().lower() + ENABLE_PREFETCH_3BUF = ( + os.getenv("FLYDSL_FLASH_ATTN_FUNC_ENABLE_PREFETCH3", "0") == "1" + ) + ENABLE_LDS_VEC16 = os.getenv("FLYDSL_FLASH_ATTN_FUNC_ENABLE_LDS_VEC16", "1") == "1" + REDUCE_MODE = os.getenv("FLYDSL_FLASH_ATTN_FUNC_REDUCE_MODE", "xor").strip().lower() if REDUCE_MODE not in ("xor", "ds_bpermute"): REDUCE_MODE = "xor" NUM_PREFETCH_K = 3 if ENABLE_PREFETCH_3BUF else 1 @@ -95,7 +99,7 @@ def build_flash_attention_v4_4_module_primary( assert BLOCK_M % NUM_WAVES == 0 assert head_dim % 32 == 0, f"head_dim ({head_dim}) must be divisible by 32" assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" - assert dtype_str == "f16", "V4.4 currently only supports f16" + assert dtype_str == "f16", "flash_attn_func currently only supports f16" assert BLOCK_N % 32 == 0 assert BLOCK_N_OUT % BLOCK_N == 0 @@ -137,8 +141,8 @@ def build_flash_attention_v4_4_module_primary( allocator = SmemAllocator(None, arch=gpu_arch) _state = {} - class _FlashAttentionV4_4(flir.MlirModule): - GPU_MODULE_NAME = f"flash_attn_v4_4_{dtype_str}_{PATH_TAG}" + class _FlashAttnFunc(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_func_{dtype_str}_{PATH_TAG}" KERNEL_VARIANT = PATH_TAG GPU_MODULE_TARGETS = [f'#rocdl.target'] @@ -149,7 +153,7 @@ def init_gpu_module(self): allocator.finalize() @flir.kernel - def flash_attention_v4_4_kernel( + def flash_attn_func_kernel( self: flir.T.i64, Q: lambda: T.memref(DYN, _state["elem_type"]), K: lambda: T.memref(DYN, _state["elem_type"]), @@ -628,13 +632,13 @@ def __call__( kernel_operands=[Q, K, V, O, seq_len], ) - return _FlashAttentionV4_4() + return _FlashAttnFunc() -build_flash_attention_v4_4_module = build_flash_attention_v4_4_module_primary -"""Flash Attention V4.4 kernel builder for FlyDSL. +build_flash_attn_func_module = build_flash_attn_func_module_primary +"""flash_attn_func kernel builder for FlyDSL. -Aggressive V4.4 path: +Aggressive flash_attn_func path: - True MFMA32 remap: `mfma_f32_32x32x8f16` for both GEMM stages. - Tile shape: BLOCK_M=128, BLOCK_N=32, 4 waves (256 threads). - Per-wave Q rows: 32. @@ -663,17 +667,17 @@ def __call__( import _mlir.extras.types as T -KERNEL_NAME = "flash_attention_v4_4_kernel" +KERNEL_NAME = "flash_attn_func_kernel" -def _legacy_copy_build_flash_attention_v4_4_module_2( +def _legacy_copy_build_flash_attn_func_module_2( num_heads, head_dim, causal=True, dtype_str="f16", sm_scale=None, ): - """Build a FlyDSL Flash Attention V4.4 module.""" + """Build a FlyDSL flash_attn_func module.""" gpu_arch = get_hip_arch() DYN = ir.ShapedType.get_dynamic_size() @@ -697,7 +701,7 @@ def _legacy_copy_build_flash_attention_v4_4_module_2( assert BLOCK_M % NUM_WAVES == 0 assert head_dim % 32 == 0, f"head_dim ({head_dim}) must be divisible by 32" assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" - assert dtype_str == "f16", "V4.4 currently only supports f16" + assert dtype_str == "f16", "flash_attn_func currently only supports f16" assert BLOCK_N % 32 == 0 if sm_scale is None: @@ -733,8 +737,8 @@ def _legacy_copy_build_flash_attention_v4_4_module_2( allocator = SmemAllocator(None, arch=gpu_arch) _state = {} - class _FlashAttentionV4_4(flir.MlirModule): - GPU_MODULE_NAME = f"flash_attn_v4_4_{dtype_str}" + class _FlashAttnFunc(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_func_{dtype_str}" GPU_MODULE_TARGETS = [f'#rocdl.target'] def init_gpu_module(self): @@ -744,7 +748,7 @@ def init_gpu_module(self): allocator.finalize() @flir.kernel - def flash_attention_v4_4_kernel( + def flash_attn_func_kernel( self: flir.T.i64, Q: lambda: T.memref(DYN, _state["elem_type"]), K: lambda: T.memref(DYN, _state["elem_type"]), @@ -1193,10 +1197,10 @@ def __call__( kernel_operands=[Q, K, V, O, seq_len], ) - return _FlashAttentionV4_4() -"""Flash Attention V4.4 kernel builder for FlyDSL. + return _FlashAttnFunc() +"""flash_attn_func kernel builder for FlyDSL. -V4.4 design (CK-aligned direction, rewritten from V4.3): +flash_attn_func design (CK-aligned direction, rewritten from V4.3): - CK-aligned baseline tile family: BLOCK_M=64, BLOCK_N=32. - Q loaded once from global memory into MFMA A-operand packs (register-resident). - K/V streamed tile-by-tile through LDS. @@ -1223,17 +1227,17 @@ def __call__( import _mlir.extras.types as T -KERNEL_NAME = "flash_attention_v4_4_kernel" +KERNEL_NAME = "flash_attn_func_kernel" -def _legacy_copy_build_flash_attention_v4_4_module_3( +def _legacy_copy_build_flash_attn_func_module_3( num_heads, head_dim, causal=True, dtype_str="f16", sm_scale=None, ): - """Build a FlyDSL Flash Attention V4.4 module. + """Build a FlyDSL flash_attn_func module. Args: num_heads: Number of attention heads. @@ -1261,7 +1265,7 @@ def _legacy_copy_build_flash_attention_v4_4_module_3( assert BLOCK_M % NUM_WAVES == 0 assert head_dim % 16 == 0, f"head_dim ({head_dim}) must be divisible by 16" assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" - assert dtype_str == "f16", "V4.4 currently only supports f16" + assert dtype_str == "f16", "flash_attn_func currently only supports f16" assert BLOCK_N % 16 == 0, f"BLOCK_N ({BLOCK_N}) must be divisible by 16" if sm_scale is None: @@ -1301,8 +1305,8 @@ def _legacy_copy_build_flash_attention_v4_4_module_3( allocator = SmemAllocator(None, arch=gpu_arch) _state = {} - class _FlashAttentionV4_4(flir.MlirModule): - GPU_MODULE_NAME = f"flash_attn_v4_4_{dtype_str}" + class _FlashAttnFunc(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_func_{dtype_str}" GPU_MODULE_TARGETS = [f'#rocdl.target'] def init_gpu_module(self): @@ -1313,7 +1317,7 @@ def init_gpu_module(self): allocator.finalize() @flir.kernel - def flash_attention_v4_4_kernel( + def flash_attn_func_kernel( self: flir.T.i64, Q: lambda: T.memref(DYN, _state["elem_type"]), K: lambda: T.memref(DYN, _state["elem_type"]), @@ -1776,4 +1780,4 @@ def __call__( kernel_operands=[Q, K, V, O, seq_len], ) - return _FlashAttentionV4_4() + return _FlashAttnFunc() diff --git a/run.sh b/run.sh index 7a99674d..ec4ff6cc 100755 --- a/run.sh +++ b/run.sh @@ -39,10 +39,11 @@ function run_flydsl_op { # python tests/kernels/test_flash_attention_v4_2.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 # python tests/kernels/test_flash_attention_v4_3.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v42 - python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v43 + # python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v43 + python tests/kernels/test_flash_attn_func.py --iters 100 --compare-v43 - # rocprof -i perf_counters1.txt -o prof_v44_p1.csv python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 5 --warmup 2 - # rocprof -i perf_counters2.txt -o prof_v44_p2.csv python tests/kernels/test_flash_attention_v4_4.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 5 --warmup 2 + # rocprof -i perf_counters1.txt -o prof_v44_p1.csv python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 5 --warmup 2 + # rocprof -i perf_counters2.txt -o prof_v44_p2.csv python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 5 --warmup 2 } diff --git a/tests/kernels/test_flash_attention_v4_4.py b/tests/kernels/test_flash_attn_func.py similarity index 94% rename from tests/kernels/test_flash_attention_v4_4.py rename to tests/kernels/test_flash_attn_func.py index 66d086d1..6748a250 100644 --- a/tests/kernels/test_flash_attention_v4_4.py +++ b/tests/kernels/test_flash_attn_func.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 -"""Flash Attention V4.4 kernel test and benchmark for FlyDSL. +"""flash_attn_func kernel test and benchmark for FlyDSL. -Tests V4.4 against PyTorch SDPA. +Tests flash_attn_func against PyTorch SDPA. Optionally compares with V4.3. """ @@ -31,17 +31,17 @@ sys.exit(1) import flydsl -from kernels.flash_attention_v4_4 import ( +from kernels.flash_attn_func import ( KERNEL_NAME, - build_flash_attention_v4_4_module, - select_v4_4_path, + build_flash_attn_func_module, + select_flash_attn_func_path, ) from tests.test_common import run_perftest # Tensor initialization range (uniform distribution) UNIFORM_RANGE = (-1, 1) DEFAULT_SEED = 123 -V4_4_COMPILE_KWARGS = { +FLASH_ATTN_FUNC_COMPILE_KWARGS = { "unsafe_fp_math": True, "fast_fp_math": True, "waves_per_eu": 3, @@ -165,23 +165,23 @@ def compare_arrays( def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, warmup, iters, prev_exe=None, seed=DEFAULT_SEED): device = "cuda" results = {} - active_path = select_v4_4_path( + active_path = select_flash_attn_func_path( num_heads=num_heads, head_dim=head_dim, causal=causal, dtype_str="f16" ) results["active_path"] = active_path if seq_len % 128 != 0: - results["err"] = f"seq_len ({seq_len}) must be divisible by 128 for V4.4" + results["err"] = f"seq_len ({seq_len}) must be divisible by 128 for flash_attn_func" return results if head_dim % 32 != 0 or head_dim < 64: results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 32" return results try: - m = build_flash_attention_v4_4_module( + m = build_flash_attn_func_module( num_heads=num_heads, head_dim=head_dim, causal=causal, dtype_str="f16" ) - exe = flydsl.compile(m, **V4_4_COMPILE_KWARGS) + exe = flydsl.compile(m, **FLASH_ATTN_FUNC_COMPILE_KWARGS) except Exception as e: results["err"] = f"compile: {e}" import traceback @@ -273,7 +273,7 @@ def prev_fn(): def main(): - parser = argparse.ArgumentParser(description="Flash Attention V4.4 FlyDSL Test/Benchmark") + parser = argparse.ArgumentParser(description="flash_attn_func FlyDSL Test/Benchmark") parser.add_argument("--batch", type=int, default=None) parser.add_argument("--seq_len", type=int, default=None) parser.add_argument("--num_heads", type=int, default=None) @@ -291,11 +291,11 @@ def main(): dtype = torch.float16 print("=" * 130) - print(f"FlyDSL Flash Attention V4.4 ({'causal' if causal else 'non-causal'}, fp16)") + print(f"FlyDSL flash_attn_func ({'causal' if causal else 'non-causal'}, fp16)") print(" Tile: BLOCK_M=128, BLOCK_N=32 fallback (default) + CK-like N=128 fast path (gated)") print(" Strategy: K@Q^T + register S/P ping-pong + V^T@P") print(f"GPU: {torch.cuda.get_device_name(0)}") - print(f" Compile opts: {V4_4_COMPILE_KWARGS}") + print(f" Compile opts: {FLASH_ATTN_FUNC_COMPILE_KWARGS}") print("=" * 130) if args.seq_len or args.head_dim or args.batch: @@ -306,6 +306,7 @@ def main(): (1, 256, 32, 128), (1, 512, 32, 128), (2, 128, 8, 128), + (1, 8192, 64, 128), ] prev_exes = {} @@ -326,7 +327,7 @@ def main(): if args.compare_v43: hdr = ( f"{'Config/Path':>56s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'V4.4(us)':>10s} {'V4.4 TF':>9s} | " + f"{'MinCos':>8s} | {'Func(us)':>10s} {'Func TF':>9s} | " f"{'V4.3(us)':>10s} {'V4.3 TF':>9s} | {'Speedup':>7s}" ) else: From 5767f759faa06447b8db31fe2f1b2f43d1035e4e Mon Sep 17 00:00:00 2001 From: yanguahe Date: Fri, 13 Feb 2026 18:10:18 +0800 Subject: [PATCH 13/17] Remove legacy flash_attention_v4 variants and simplify flash_attn_func benchmarking. Drop the v4.3 comparison path to keep tests focused on flash_attn_func and align run.sh defaults with the updated benchmark flow. Co-authored-by: Cursor --- kernels/flash_attention_v4.py | 539 ----------------- kernels/flash_attention_v4_1.py | 612 ------------------- kernels/flash_attention_v4_2.py | 667 --------------------- kernels/flash_attention_v4_3.py | 650 -------------------- run.sh | 5 +- tests/kernels/test_flash_attention_v4.py | 288 --------- tests/kernels/test_flash_attention_v4_1.py | 288 --------- tests/kernels/test_flash_attention_v4_2.py | 394 ------------ tests/kernels/test_flash_attention_v4_3.py | 394 ------------ tests/kernels/test_flash_attn_func.py | 73 +-- 10 files changed, 16 insertions(+), 3894 deletions(-) delete mode 100644 kernels/flash_attention_v4.py delete mode 100644 kernels/flash_attention_v4_1.py delete mode 100644 kernels/flash_attention_v4_2.py delete mode 100644 kernels/flash_attention_v4_3.py delete mode 100644 tests/kernels/test_flash_attention_v4.py delete mode 100644 tests/kernels/test_flash_attention_v4_1.py delete mode 100644 tests/kernels/test_flash_attention_v4_2.py delete mode 100644 tests/kernels/test_flash_attention_v4_3.py diff --git a/kernels/flash_attention_v4.py b/kernels/flash_attention_v4.py deleted file mode 100644 index 1eb3c7b6..00000000 --- a/kernels/flash_attention_v4.py +++ /dev/null @@ -1,539 +0,0 @@ -"""Flash Attention V4 kernel builder for FlyDSL. - -Multi-wave MFMA implementation: 4 waves (256 threads), BLOCK_M=64, BLOCK_N=16. -Each wave owns 16 Q-rows and performs independent MFMA Q@K^T and P@V. -All 256 threads cooperate on K/V tile loads. - -Inspired by the poc_kl ASM kernel (8w, 256x32, mfma_32x32x8, ping-pong LDS). -This V4.0 uses mfma_f32_16x16x16f16 with single-buffered K/V as a first step. - -V4.0 vs V3.1: -- 4 waves (256 threads) vs 1 wave (64 threads) → 4x more MFMA throughput per CU -- BLOCK_M=64 vs 16 → 4x fewer grid workgroups, better scheduling -- Cooperative load: 256 threads → K/V tile (16×HD) loaded in 1 batch (vs 4) -- LDS: ~30KB (Q=16KB + KV=4KB + P=2KB) - -Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). -Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. -Block: (256,) — 4 waves of 64 on AMD (wave64). - -Requires: head_dim % 16 == 0, seq_len % 64 == 0, head_dim >= 64. -""" - -import math - -from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl -from flydsl.dialects.ext import vector as vec_ext -from flydsl.dialects.ext.python_control_flow import range_constexpr -from flydsl.dialects.ext.scf import yield_ as scf_yield -from _mlir.dialects import memref as _memref -from flydsl.runtime.device import get_rocm_arch as get_hip_arch -from flydsl.utils import SmemAllocator -from _mlir import ir -import _mlir.extras.types as T - - -KERNEL_NAME = "flash_attention_v4_kernel" - - -def build_flash_attention_v4_module( - num_heads, - head_dim, - causal=True, - dtype_str="f16", - sm_scale=None, -): - """Build a FlyDSL Flash Attention V4 module with multi-wave MFMA tiling. - - Args: - num_heads: Number of attention heads. - head_dim: Dimension per head (must be divisible by 16, >= 64). - causal: Whether to apply causal mask. - dtype_str: "f16" (bf16 not yet supported). - sm_scale: Softmax scale (default: 1/sqrt(head_dim)). - - Returns: - MlirModule compilable via ``flydsl.compile(module)``. - """ - gpu_arch = get_hip_arch() - DYN = ir.ShapedType.get_dynamic_size() - - BLOCK_M = 64 - BLOCK_N = 16 - NUM_WAVES = 4 - WARP_SIZE = 64 - BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 - ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 - K_STEPS = head_dim // 16 - - assert head_dim % 16 == 0, f"head_dim ({head_dim}) must be divisible by 16" - assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" - assert dtype_str == "f16", "V4 currently only supports f16" - - if sm_scale is None: - sm_scale = 1.0 / math.sqrt(head_dim) - - NUM_HEADS = num_heads - HEAD_DIM = head_dim - CAUSAL = causal - STRIDE_TOKEN = NUM_HEADS * HEAD_DIM - - # ---- Vectorized cooperative load constants ---- - VEC_WIDTH = 8 # v8f16 = 16 bytes - THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH - assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 - ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD - - # For Q tile (64 rows): NUM_BATCHES_Q = 64 / ROWS_PER_BATCH - assert BLOCK_M % ROWS_PER_BATCH_LOAD == 0 - NUM_BATCHES_Q = BLOCK_M // ROWS_PER_BATCH_LOAD - - # For KV tile (16 rows): NUM_BATCHES_KV = 16 / ROWS_PER_BATCH - assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 or ROWS_PER_BATCH_LOAD >= BLOCK_N - if ROWS_PER_BATCH_LOAD >= BLOCK_N: - NUM_BATCHES_KV = 1 - # Some threads will be idle (load_row >= BLOCK_N). Need guard. - KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N - else: - NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD - KV_NEEDS_GUARD = False - - allocator = SmemAllocator(None, arch=gpu_arch) - _state = {} - - class _FlashAttentionV4(flir.MlirModule): - GPU_MODULE_NAME = f"flash_attn_v4_{dtype_str}" - GPU_MODULE_TARGETS = [f'#rocdl.target'] - - def init_gpu_module(self): - elem_type = T.f16() - _state["elem_type"] = elem_type - _state["lds_q"] = allocator.allocate_array(elem_type, BLOCK_M * HEAD_DIM) - _state["lds_kv"] = allocator.allocate_array(elem_type, BLOCK_N * HEAD_DIM) - _state["lds_p"] = allocator.allocate_array(elem_type, BLOCK_M * BLOCK_N) - allocator.finalize() - - @flir.kernel - def flash_attention_v4_kernel( - self: flir.T.i64, - Q: lambda: T.memref(DYN, _state["elem_type"]), - K: lambda: T.memref(DYN, _state["elem_type"]), - V: lambda: T.memref(DYN, _state["elem_type"]), - O: lambda: T.memref(DYN, _state["elem_type"]), - seq_len: lambda: T.index(), - ): - compute_type = T.f32() - elem_type = _state["elem_type"] - fm_fast = flir.arith.FastMathFlags.fast - - v4f16_type = ir.VectorType.get([4], elem_type) - v4f32_type = ir.VectorType.get([4], compute_type) - v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) - - seq_len_v = arith.as_value(seq_len) - - # ---- LDS views ---- - base_ptr = allocator.get_base() - lds_q = _state["lds_q"](base_ptr).get() - lds_kv = _state["lds_kv"](base_ptr).get() - lds_p = _state["lds_p"](base_ptr).get() - - # ---- Thread / block indices ---- - block_id = flir.const_index(flir.block_idx("x")) - tid = flir.const_index(flir.thread_idx("x")) - - # ---- Wave decomposition ---- - c_ws = flir.const_index(WARP_SIZE) - wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) - lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) - - # ---- MFMA lane decomposition (within each wave) ---- - c16 = flir.const_index(16) - lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result) - lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result) - - # ---- Wave's Q-row offset in the Q tile ---- - wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value - # Wave's P offset in lds_p - wave_p_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value - - # ---- Decompose block_id -> (batch_idx, q_tile_idx, head_idx) ---- - c_nh = flir.const_index(NUM_HEADS) - head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) - temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) - c_bm = flir.const_index(BLOCK_M) - num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) - q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) - batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) - q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value - - # ---- Vectorized load thread decomposition ---- - c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) - load_row_in_batch = arith.as_value( - flir.arith.DivUIOp(tid, c_tpr).result - ) - load_lane_in_row = arith.as_value( - flir.arith.RemUIOp(tid, c_tpr).result - ) - load_col_base = ( - arith.ArithValue(load_lane_in_row) * VEC_WIDTH - ).value - - # ---- Helper: global flat index ---- - def global_idx(token_idx, col): - token = ( - arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) - + arith.ArithValue(token_idx) - ) - return ( - token * STRIDE_TOKEN - + arith.ArithValue(head_idx) * HEAD_DIM - + arith.ArithValue(col) - ).value - - # ---- Cooperative Q load (64 rows, all 256 threads) ---- - def coop_load_q(): - for batch in range_constexpr(NUM_BATCHES_Q): - row_offset = batch * ROWS_PER_BATCH_LOAD - row_idx = ( - arith.ArithValue(q_start) - + arith.ArithValue(load_row_in_batch) - + row_offset - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, Q, [g_idx]) - ) - lds_row = ( - arith.ArithValue(load_row_in_batch) + row_offset - ).value - lds_idx = ( - arith.ArithValue(lds_row) * HEAD_DIM - + arith.ArithValue(load_col_base) - ).value - vec_ext.store(vec, lds_q, [lds_idx]) - - # ---- Cooperative KV load (16 rows, 256 threads — may need guard) ---- - def coop_load_kv(src_memref, lds_memref, tile_start): - if KV_NEEDS_GUARD: - # With 256 threads and THREADS_PER_ROW=16, ROWS_PER_BATCH=16 - # which equals BLOCK_N=16. No guard needed in this config. - # But for safety, handle the case where ROWS_PER_BATCH > BLOCK_N. - c_bn = flir.const_index(BLOCK_N) - row_ok = arith.as_value( - flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ult, - load_row_in_batch, c_bn, - ).result - ) - # Only threads with row < BLOCK_N participate - # Use scf.if_ for conditional store - from flydsl.dialects.ext.scf import IfOp - if_op = IfOp(row_ok) - with if_op: - row_idx = ( - arith.ArithValue(tile_start) - + arith.ArithValue(load_row_in_batch) - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, src_memref, [g_idx]) - ) - lds_idx = ( - arith.ArithValue(load_row_in_batch) * HEAD_DIM - + arith.ArithValue(load_col_base) - ).value - vec_ext.store(vec, lds_memref, [lds_idx]) - else: - for batch in range_constexpr(NUM_BATCHES_KV): - row_offset = batch * ROWS_PER_BATCH_LOAD - row_idx = ( - arith.ArithValue(tile_start) - + arith.ArithValue(load_row_in_batch) - + row_offset - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, src_memref, [g_idx]) - ) - lds_row = ( - arith.ArithValue(load_row_in_batch) + row_offset - ).value - lds_idx = ( - arith.ArithValue(lds_row) * HEAD_DIM - + arith.ArithValue(load_col_base) - ).value - vec_ext.store(vec, lds_memref, [lds_idx]) - - # ---- Load Q tile to LDS ---- - coop_load_q() - gpu.barrier() - - # ---- Constants ---- - c_neg_inf = arith.constant(float("-inf"), type=compute_type) - c_zero_f = arith.constant(0.0, type=compute_type) - c_sm_scale = arith.constant(sm_scale, type=compute_type) - c_log2e = arith.constant(1.4426950408889634, type=compute_type) - c_zero_v4f32 = arith.as_value( - arith.constant_vector(0.0, v4f32_type) - ) - - # ---- Init loop-carried state ---- - # [m_0..m_3, l_0..l_3, o_acc_0..o_acc_{K_STEPS-1}] - init_args = [] - for _ in range_constexpr(4): - init_args.append(arith.as_value(c_neg_inf)) - for _ in range_constexpr(4): - init_args.append(arith.as_value(c_zero_f)) - for _ in range_constexpr(K_STEPS): - init_args.append(c_zero_v4f32) - - # ---- KV loop ---- - with scf.for_(0, seq_len_v, BLOCK_N, iter_args=init_args) as loop: - kv_start = arith.as_value(loop.induction_variable) - m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)] - l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)] - o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)] - - # ==== Cooperative K load -> LDS_KV ==== - coop_load_kv(K, lds_kv, kv_start) - gpu.barrier() - - # ==== Q @ K^T via MFMA (each wave uses its Q rows) ==== - s_acc = c_zero_v4f32 - for ks in range_constexpr(K_STEPS): - # A operand (Q): lane's Q row within this wave's 16 rows - q_lds_idx = ( - (arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_mod_16)) * HEAD_DIM - + ks * 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - a_pack = arith.as_value( - vec_ext.load_op(v4f16_type, lds_q, [q_lds_idx]) - ) - # B operand (K^T): same for all waves (shared K tile) - k_lds_idx = ( - arith.ArithValue(lane_mod_16) * HEAD_DIM - + ks * 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - b_pack = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx]) - ) - s_acc = arith.as_value( - rocdl.mfma_f32_16x16x16f16(v4f32_type, [a_pack, b_pack, s_acc, 0, 0, 0]) - ) - - # ==== Online softmax (per-wave, per-row) ==== - s_vals = [] - for ii in range_constexpr(4): - s_ii = arith.as_value( - vec_ext.extract(s_acc, static_position=[ii], dynamic_position=[]) - ) - s_ii = arith.as_value( - flir.arith.MulFOp(s_ii, arith.as_value(c_sm_scale), fastmath=fm_fast).result - ) - if CAUSAL: - # Global Q row for this lane's ii-th element - q_row = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_div_16) * 4 - + ii - ).value - kv_col = (arith.ArithValue(kv_start) + arith.ArithValue(lane_mod_16)).value - q_row_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), q_row).result) - kv_col_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col).result) - is_masked = arith.as_value( - flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64, - ).result - ) - s_ii = arith.as_value( - flir.arith.SelectOp(is_masked, arith.as_value(c_neg_inf), s_ii).result - ) - s_vals.append(s_ii) - - width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) - m_new = [None] * 4 - corr = [None] * 4 - p_vals = [None] * 4 - l_new = [None] * 4 - - for ii in range_constexpr(4): - row_max = s_vals[ii] - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_max, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_max = arith.as_value( - flir.arith.MaximumFOp(row_max, peer).result - ) - - m_new[ii] = arith.as_value( - flir.arith.MaximumFOp(m_old[ii], row_max).result - ) - - diff_m = arith.as_value( - flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result - ) - diff_m_s = arith.as_value( - flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) - - diff_s = arith.as_value( - flir.arith.SubFOp(s_vals[ii], m_new[ii], fastmath=fm_fast).result - ) - diff_s_s = arith.as_value( - flir.arith.MulFOp(diff_s, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - p_vals[ii] = arith.as_value(flir.math.exp2(diff_s_s, fastmath=fm_fast)) - - row_sum = p_vals[ii] - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_sum, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_sum = arith.as_value( - flir.arith.AddFOp(row_sum, peer, fastmath=fm_fast).result - ) - - l_corr = arith.as_value( - flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result - ) - l_new[ii] = arith.as_value( - flir.arith.AddFOp(l_corr, row_sum, fastmath=fm_fast).result - ) - - # ==== Rescale O accumulators ==== - corr_vec = arith.as_value( - vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) - ) - for ds in range_constexpr(K_STEPS): - o_accs[ds] = arith.as_value( - flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result - ) - - # ==== P store to LDS_P (each wave writes its 16×16 section) ==== - for ii in range_constexpr(4): - p_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, p_vals[ii]).result - ) - p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value - p_lds_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(p_row) * BLOCK_N - + arith.ArithValue(lane_mod_16) - ).value - _memref.StoreOp(p_f16, lds_p, [p_lds_idx]) - - # ==== Barrier: ensure all waves finished reading K from lds_kv ==== - gpu.barrier() - - # ==== Cooperative V load -> LDS_KV (overwrites K) ==== - coop_load_kv(V, lds_kv, kv_start) - gpu.barrier() - - # ==== P load (A-operand, wave-local) ==== - p_a_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(lane_mod_16) * BLOCK_N - + arith.ArithValue(lane_div_16) * 4 - ).value - p_pack = arith.as_value( - vec_ext.load_op(v4f16_type, lds_p, [p_a_idx]) - ) - - # ==== P @ V via MFMA ==== - for ds in range_constexpr(K_STEPS): - v_elems = [] - for e in range_constexpr(4): - v_row = (arith.ArithValue(lane_div_16) * 4 + e).value - v_lds_idx = ( - arith.ArithValue(v_row) * HEAD_DIM - + ds * 16 - + arith.ArithValue(lane_mod_16) - ).value - v_val = _memref.LoadOp(lds_kv, [v_lds_idx]).result - v_elems.append(arith.as_value(v_val)) - v_pack = arith.as_value( - vec_ext.from_elements(v4f16_type, v_elems) - ) - o_accs[ds] = arith.as_value( - rocdl.mfma_f32_16x16x16f16( - v4f32_type, [p_pack, v_pack, o_accs[ds], 0, 0, 0] - ) - ) - - # ==== Barrier: ensure all waves finished P@V (reading lds_kv) - # before next iteration overwrites lds_kv with K ==== - gpu.barrier() - - # ==== Yield ==== - yield_args = m_new + l_new + o_accs - scf_yield(yield_args) - - # ---- Normalize and store O ---- - m_finals = [arith.as_value(loop.results[i]) for i in range(4)] - l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] - o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] - - for ds in range_constexpr(K_STEPS): - for ii in range_constexpr(4): - o_val = arith.as_value( - vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) - ) - o_norm = arith.as_value( - flir.arith.DivFOp(o_val, l_finals[ii], fastmath=fm_fast).result - ) - o_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, o_norm).result - ) - # Global store: each wave writes its Q rows - q_row = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_div_16) * 4 - + ii - ).value - d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value - o_global = global_idx(q_row, d_col) - _memref.StoreOp(o_f16, O, [o_global]) - - @flir.jit - def __call__( - self: flir.T.i64, - Q: lambda: T.memref(DYN, _state["elem_type"]), - K: lambda: T.memref(DYN, _state["elem_type"]), - V: lambda: T.memref(DYN, _state["elem_type"]), - O: lambda: T.memref(DYN, _state["elem_type"]), - batch_size: lambda: T.index(), - seq_len: lambda: T.index(), - ): - c1 = arith.as_value(flir.arith_ext.index(1)) - c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) - c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) - bs_val = arith.as_value(batch_size) - sl_val = arith.as_value(seq_len) - num_q_tiles = arith.as_value( - flir.arith.DivUIOp(sl_val, c_bm).result - ) - bs_qt = arith.as_value( - flir.arith.MulIOp(bs_val, num_q_tiles).result - ) - grid_x = arith.as_value( - flir.arith.MulIOp(bs_qt, c_nh).result - ) - bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) - flir.gpu_ext.LaunchFuncOp( - [self.GPU_MODULE_NAME, KERNEL_NAME], - grid_size=(grid_x, c1, c1), - block_size=(bx, c1, c1), - kernel_operands=[Q, K, V, O, seq_len], - ) - - return _FlashAttentionV4() diff --git a/kernels/flash_attention_v4_1.py b/kernels/flash_attention_v4_1.py deleted file mode 100644 index 9c407a28..00000000 --- a/kernels/flash_attention_v4_1.py +++ /dev/null @@ -1,612 +0,0 @@ -"""Flash Attention V4.1 kernel builder for FlyDSL. - -V4.1 optimizations over V4.0: -- Q preloaded to registers (eliminates Q LDS reads from KV loop) -- V stored transposed in LDS (vectorized v4f16 B-operand loads) -- Bank-conflict-free LDS padding (K stride=HD+2, V transposed stride=BLOCK_N+2) - -Tile config: BLOCK_M=64, BLOCK_N=16, 4 waves (256 threads), mfma_f32_16x16x16f16. - -Expected improvements from V4.0: -- ~32% fewer LDS instructions (Q reads eliminated, V loads vectorized) -- Reduced LDS bank conflicts - -Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). -Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. -Block: (256,) -- 4 waves of 64 on AMD (wave64). - -Requires: head_dim % 16 == 0, seq_len % 64 == 0, head_dim >= 64. -""" - -import math - -from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl -from flydsl.dialects.ext import vector as vec_ext -from flydsl.dialects.ext.python_control_flow import range_constexpr -from flydsl.dialects.ext.scf import yield_ as scf_yield -from _mlir.dialects import memref as _memref -from flydsl.runtime.device import get_rocm_arch as get_hip_arch -from flydsl.utils import SmemAllocator -from _mlir import ir -import _mlir.extras.types as T - - -KERNEL_NAME = "flash_attention_v4_1_kernel" - - -def build_flash_attention_v4_1_module( - num_heads, - head_dim, - causal=True, - dtype_str="f16", - sm_scale=None, -): - """Build a FlyDSL Flash Attention V4.1 module. - - Args: - num_heads: Number of attention heads. - head_dim: Dimension per head (must be divisible by 16, >= 64). - causal: Whether to apply causal mask. - dtype_str: "f16" (bf16 not yet supported). - sm_scale: Softmax scale (default: 1/sqrt(head_dim)). - - Returns: - MlirModule compilable via ``flydsl.compile(module)``. - """ - gpu_arch = get_hip_arch() - DYN = ir.ShapedType.get_dynamic_size() - - BLOCK_M = 64 - BLOCK_N = 16 - NUM_WAVES = 4 - WARP_SIZE = 64 - BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 - ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 - K_STEPS = head_dim // 16 - - assert head_dim % 16 == 0, f"head_dim ({head_dim}) must be divisible by 16" - assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" - assert dtype_str == "f16", "V4.1 currently only supports f16" - - if sm_scale is None: - sm_scale = 1.0 / math.sqrt(head_dim) - - NUM_HEADS = num_heads - HEAD_DIM = head_dim - CAUSAL = causal - STRIDE_TOKEN = NUM_HEADS * HEAD_DIM - - # ---- Bank-conflict-free LDS strides ---- - # K row-major: stride = HD + 2 (makes row stride odd in bank units) - # V transposed: stride = BLOCK_N + 2 (same reasoning) - K_STRIDE = HEAD_DIM + 2 # 130 for HD=128 - VT_STRIDE = BLOCK_N + 2 # 18 for BLOCK_N=16 - - # ---- Vectorized cooperative load constants ---- - VEC_WIDTH = 8 # v8f16 = 16 bytes - THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH - assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 - ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD - - # For Q tile (64 rows) - assert BLOCK_M % ROWS_PER_BATCH_LOAD == 0 - NUM_BATCHES_Q = BLOCK_M // ROWS_PER_BATCH_LOAD - - # For KV tile (16 rows) - assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 or ROWS_PER_BATCH_LOAD >= BLOCK_N - if ROWS_PER_BATCH_LOAD >= BLOCK_N: - NUM_BATCHES_KV = 1 - KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N - else: - NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD - KV_NEEDS_GUARD = False - - # LDS sizes - LDS_Q_SIZE = BLOCK_M * HEAD_DIM # Q unpadded (only read once for register preload) - LDS_KV_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE) # max(K padded, Vt padded) - LDS_P_SIZE = BLOCK_M * BLOCK_N - - allocator = SmemAllocator(None, arch=gpu_arch) - _state = {} - - class _FlashAttentionV4_1(flir.MlirModule): - GPU_MODULE_NAME = f"flash_attn_v4_1_{dtype_str}" - GPU_MODULE_TARGETS = [f'#rocdl.target'] - - def init_gpu_module(self): - elem_type = T.f16() - _state["elem_type"] = elem_type - _state["lds_q"] = allocator.allocate_array(elem_type, LDS_Q_SIZE) - _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_SIZE) - _state["lds_p"] = allocator.allocate_array(elem_type, LDS_P_SIZE) - allocator.finalize() - - @flir.kernel - def flash_attention_v4_1_kernel( - self: flir.T.i64, - Q: lambda: T.memref(DYN, _state["elem_type"]), - K: lambda: T.memref(DYN, _state["elem_type"]), - V: lambda: T.memref(DYN, _state["elem_type"]), - O: lambda: T.memref(DYN, _state["elem_type"]), - seq_len: lambda: T.index(), - ): - compute_type = T.f32() - elem_type = _state["elem_type"] - fm_fast = flir.arith.FastMathFlags.fast - - v4f16_type = ir.VectorType.get([4], elem_type) - v4f32_type = ir.VectorType.get([4], compute_type) - v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) - - seq_len_v = arith.as_value(seq_len) - - # ---- LDS views ---- - base_ptr = allocator.get_base() - lds_q = _state["lds_q"](base_ptr).get() - lds_kv = _state["lds_kv"](base_ptr).get() - lds_p = _state["lds_p"](base_ptr).get() - - # ---- Thread / block indices ---- - block_id = flir.const_index(flir.block_idx("x")) - tid = flir.const_index(flir.thread_idx("x")) - - # ---- Wave decomposition ---- - c_ws = flir.const_index(WARP_SIZE) - wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) - lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) - - # ---- MFMA lane decomposition (within each wave) ---- - c16 = flir.const_index(16) - lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result) - lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result) - - # ---- Wave's Q-row offset in the Q tile ---- - wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value - # Wave's P offset in lds_p - wave_p_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value - - # ---- Decompose block_id -> (batch_idx, q_tile_idx, head_idx) ---- - c_nh = flir.const_index(NUM_HEADS) - head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) - temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) - c_bm = flir.const_index(BLOCK_M) - num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) - q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) - batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) - q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value - - # ---- Vectorized load thread decomposition ---- - c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) - load_row_in_batch = arith.as_value( - flir.arith.DivUIOp(tid, c_tpr).result - ) - load_lane_in_row = arith.as_value( - flir.arith.RemUIOp(tid, c_tpr).result - ) - load_col_base = ( - arith.ArithValue(load_lane_in_row) * VEC_WIDTH - ).value - - # ---- Helper: global flat index ---- - def global_idx(token_idx, col): - token = ( - arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) - + arith.ArithValue(token_idx) - ) - return ( - token * STRIDE_TOKEN - + arith.ArithValue(head_idx) * HEAD_DIM - + arith.ArithValue(col) - ).value - - # ---- Cooperative Q load (64 rows, all 256 threads, unpadded) ---- - def coop_load_q(): - for batch in range_constexpr(NUM_BATCHES_Q): - row_offset = batch * ROWS_PER_BATCH_LOAD - row_idx = ( - arith.ArithValue(q_start) - + arith.ArithValue(load_row_in_batch) - + row_offset - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, Q, [g_idx]) - ) - lds_row = ( - arith.ArithValue(load_row_in_batch) + row_offset - ).value - lds_idx = ( - arith.ArithValue(lds_row) * HEAD_DIM - + arith.ArithValue(load_col_base) - ).value - vec_ext.store(vec, lds_q, [lds_idx]) - - # ---- Cooperative K load (row-major with padded stride) ---- - def coop_load_k(tile_start): - if KV_NEEDS_GUARD: - c_bn = flir.const_index(BLOCK_N) - row_ok = arith.as_value( - flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ult, - load_row_in_batch, c_bn, - ).result - ) - from flydsl.dialects.ext.scf import IfOp - if_op = IfOp(row_ok) - with if_op: - row_idx = ( - arith.ArithValue(tile_start) - + arith.ArithValue(load_row_in_batch) - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, K, [g_idx]) - ) - lds_idx = ( - arith.ArithValue(load_row_in_batch) * K_STRIDE - + arith.ArithValue(load_col_base) - ).value - vec_ext.store(vec, lds_kv, [lds_idx]) - else: - for batch in range_constexpr(NUM_BATCHES_KV): - row_offset = batch * ROWS_PER_BATCH_LOAD - row_idx = ( - arith.ArithValue(tile_start) - + arith.ArithValue(load_row_in_batch) - + row_offset - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, K, [g_idx]) - ) - lds_row = ( - arith.ArithValue(load_row_in_batch) + row_offset - ).value - lds_idx = ( - arith.ArithValue(lds_row) * K_STRIDE - + arith.ArithValue(load_col_base) - ).value - vec_ext.store(vec, lds_kv, [lds_idx]) - - # ---- Cooperative V load (transposed with padded stride) ---- - # Global V[row, col] -> LDS Vt[col, row] at lds_kv[col * VT_STRIDE + row] - def coop_load_v_transposed(tile_start): - if KV_NEEDS_GUARD: - c_bn = flir.const_index(BLOCK_N) - row_ok = arith.as_value( - flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ult, - load_row_in_batch, c_bn, - ).result - ) - from flydsl.dialects.ext.scf import IfOp - if_op = IfOp(row_ok) - with if_op: - row_idx = ( - arith.ArithValue(tile_start) - + arith.ArithValue(load_row_in_batch) - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, V, [g_idx]) - ) - # Scatter-store transposed: V[row, col+e] -> lds[col_e * VT_STRIDE + row] - for e in range_constexpr(VEC_WIDTH): - elem = arith.as_value( - vec_ext.extract(vec, static_position=[e], dynamic_position=[]) - ) - col_e = (arith.ArithValue(load_col_base) + e).value - lds_idx = ( - arith.ArithValue(col_e) * VT_STRIDE - + arith.ArithValue(load_row_in_batch) - ).value - _memref.StoreOp(elem, lds_kv, [lds_idx]) - else: - for batch in range_constexpr(NUM_BATCHES_KV): - row_offset = batch * ROWS_PER_BATCH_LOAD - row_idx = ( - arith.ArithValue(tile_start) - + arith.ArithValue(load_row_in_batch) - + row_offset - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, V, [g_idx]) - ) - load_row = ( - arith.ArithValue(load_row_in_batch) + row_offset - ).value - # Scatter-store transposed - for e in range_constexpr(VEC_WIDTH): - elem = arith.as_value( - vec_ext.extract(vec, static_position=[e], dynamic_position=[]) - ) - col_e = (arith.ArithValue(load_col_base) + e).value - lds_idx = ( - arith.ArithValue(col_e) * VT_STRIDE - + arith.ArithValue(load_row) - ).value - _memref.StoreOp(elem, lds_kv, [lds_idx]) - - # ---- Load Q tile to LDS ---- - coop_load_q() - gpu.barrier() - - # ---- Preload Q A-operand packs into registers ---- - # Each lane loads K_STEPS v4f16 packs from LDS_Q (one-time cost). - # At step ks, thread (b,n) needs Q[wave_row + n, ks*16 + b*4 : ks*16+b*4+4] - q_a_packs = [] - for ks in range_constexpr(K_STEPS): - q_lds_idx = ( - (arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_mod_16)) * HEAD_DIM - + ks * 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - q_a_packs.append(arith.as_value( - vec_ext.load_op(v4f16_type, lds_q, [q_lds_idx]) - )) - - # ---- Constants ---- - c_neg_inf = arith.constant(float("-inf"), type=compute_type) - c_zero_f = arith.constant(0.0, type=compute_type) - c_sm_scale = arith.constant(sm_scale, type=compute_type) - c_log2e = arith.constant(1.4426950408889634, type=compute_type) - c_zero_v4f32 = arith.as_value( - arith.constant_vector(0.0, v4f32_type) - ) - - # ---- Init loop-carried state ---- - # [m_0..m_3, l_0..l_3, o_acc_0..o_acc_{K_STEPS-1}] - init_args = [] - for _ in range_constexpr(4): - init_args.append(arith.as_value(c_neg_inf)) - for _ in range_constexpr(4): - init_args.append(arith.as_value(c_zero_f)) - for _ in range_constexpr(K_STEPS): - init_args.append(c_zero_v4f32) - - # ---- KV loop upper bound ---- - # Causal early-exit: last Q row = q_start + BLOCK_M - 1, - # so only need KV positions 0 .. q_start + BLOCK_M - 1. - if CAUSAL: - kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value - else: - kv_upper = seq_len_v - - # ---- KV loop ---- - with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: - kv_start = arith.as_value(loop.induction_variable) - m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)] - l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)] - o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)] - - # ==== Cooperative K load -> LDS_KV (row-major, padded stride) ==== - coop_load_k(kv_start) - gpu.barrier() - - # ==== Q @ K^T via MFMA (Q from registers, K from LDS) ==== - s_acc = c_zero_v4f32 - for ks in range_constexpr(K_STEPS): - # A operand (Q): from preloaded registers - a_pack = q_a_packs[ks] - # B operand (K^T): from LDS with padded stride - k_lds_idx = ( - arith.ArithValue(lane_mod_16) * K_STRIDE - + ks * 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - b_pack = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx]) - ) - s_acc = arith.as_value( - rocdl.mfma_f32_16x16x16f16(v4f32_type, [a_pack, b_pack, s_acc, 0, 0, 0]) - ) - - # ==== Online softmax (per-wave, per-row) ==== - s_vals = [] - for ii in range_constexpr(4): - s_ii = arith.as_value( - vec_ext.extract(s_acc, static_position=[ii], dynamic_position=[]) - ) - s_ii = arith.as_value( - flir.arith.MulFOp(s_ii, arith.as_value(c_sm_scale), fastmath=fm_fast).result - ) - if CAUSAL: - q_row = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_div_16) * 4 - + ii - ).value - kv_col = (arith.ArithValue(kv_start) + arith.ArithValue(lane_mod_16)).value - q_row_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), q_row).result) - kv_col_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col).result) - is_masked = arith.as_value( - flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64, - ).result - ) - s_ii = arith.as_value( - flir.arith.SelectOp(is_masked, arith.as_value(c_neg_inf), s_ii).result - ) - s_vals.append(s_ii) - - width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) - m_new = [None] * 4 - corr = [None] * 4 - p_vals = [None] * 4 - l_new = [None] * 4 - - for ii in range_constexpr(4): - row_max = s_vals[ii] - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_max, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_max = arith.as_value( - flir.arith.MaximumFOp(row_max, peer).result - ) - - m_new[ii] = arith.as_value( - flir.arith.MaximumFOp(m_old[ii], row_max).result - ) - - diff_m = arith.as_value( - flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result - ) - diff_m_s = arith.as_value( - flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) - - diff_s = arith.as_value( - flir.arith.SubFOp(s_vals[ii], m_new[ii], fastmath=fm_fast).result - ) - diff_s_s = arith.as_value( - flir.arith.MulFOp(diff_s, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - p_vals[ii] = arith.as_value(flir.math.exp2(diff_s_s, fastmath=fm_fast)) - - row_sum = p_vals[ii] - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_sum, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_sum = arith.as_value( - flir.arith.AddFOp(row_sum, peer, fastmath=fm_fast).result - ) - - l_corr = arith.as_value( - flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result - ) - l_new[ii] = arith.as_value( - flir.arith.AddFOp(l_corr, row_sum, fastmath=fm_fast).result - ) - - # ==== Rescale O accumulators ==== - corr_vec = arith.as_value( - vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) - ) - for ds in range_constexpr(K_STEPS): - o_accs[ds] = arith.as_value( - flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result - ) - - # ==== P store to LDS_P (each wave writes its 16x16 section) ==== - for ii in range_constexpr(4): - p_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, p_vals[ii]).result - ) - p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value - p_lds_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(p_row) * BLOCK_N - + arith.ArithValue(lane_mod_16) - ).value - _memref.StoreOp(p_f16, lds_p, [p_lds_idx]) - - # ==== Barrier: ensure all waves finished reading K from lds_kv ==== - gpu.barrier() - - # ==== Cooperative V load -> LDS_KV (transposed, padded stride) ==== - coop_load_v_transposed(kv_start) - gpu.barrier() - - # ==== P load (A-operand, wave-local) ==== - p_a_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(lane_mod_16) * BLOCK_N - + arith.ArithValue(lane_div_16) * 4 - ).value - p_pack = arith.as_value( - vec_ext.load_op(v4f16_type, lds_p, [p_a_idx]) - ) - - # ==== P @ V via MFMA (V from transposed LDS, vectorized v4f16 loads) ==== - # V transposed: V[row, col] at lds_kv[col * VT_STRIDE + row] - # B-operand: V[b*4:b*4+4, ds*16+n] = lds_kv[(ds*16+n) * VT_STRIDE + b*4] - # -> v4f16 at base (ds*16 + lane_mod_16) * VT_STRIDE + lane_div_16 * 4 - for ds in range_constexpr(K_STEPS): - v_lds_idx = ( - (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE - + arith.ArithValue(lane_div_16) * 4 - ).value - v_pack = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [v_lds_idx]) - ) - o_accs[ds] = arith.as_value( - rocdl.mfma_f32_16x16x16f16( - v4f32_type, [p_pack, v_pack, o_accs[ds], 0, 0, 0] - ) - ) - - # ==== Barrier: ensure all waves finished P@V (reading lds_kv) - # before next iteration overwrites lds_kv with K ==== - gpu.barrier() - - # ==== Yield ==== - yield_args = m_new + l_new + o_accs - scf_yield(yield_args) - - # ---- Normalize and store O ---- - m_finals = [arith.as_value(loop.results[i]) for i in range(4)] - l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] - o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] - - for ds in range_constexpr(K_STEPS): - for ii in range_constexpr(4): - o_val = arith.as_value( - vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) - ) - o_norm = arith.as_value( - flir.arith.DivFOp(o_val, l_finals[ii], fastmath=fm_fast).result - ) - o_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, o_norm).result - ) - q_row = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_div_16) * 4 - + ii - ).value - d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value - o_global = global_idx(q_row, d_col) - _memref.StoreOp(o_f16, O, [o_global]) - - @flir.jit - def __call__( - self: flir.T.i64, - Q: lambda: T.memref(DYN, _state["elem_type"]), - K: lambda: T.memref(DYN, _state["elem_type"]), - V: lambda: T.memref(DYN, _state["elem_type"]), - O: lambda: T.memref(DYN, _state["elem_type"]), - batch_size: lambda: T.index(), - seq_len: lambda: T.index(), - ): - c1 = arith.as_value(flir.arith_ext.index(1)) - c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) - c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) - bs_val = arith.as_value(batch_size) - sl_val = arith.as_value(seq_len) - num_q_tiles = arith.as_value( - flir.arith.DivUIOp(sl_val, c_bm).result - ) - bs_qt = arith.as_value( - flir.arith.MulIOp(bs_val, num_q_tiles).result - ) - grid_x = arith.as_value( - flir.arith.MulIOp(bs_qt, c_nh).result - ) - bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) - flir.gpu_ext.LaunchFuncOp( - [self.GPU_MODULE_NAME, KERNEL_NAME], - grid_size=(grid_x, c1, c1), - block_size=(bx, c1, c1), - kernel_operands=[Q, K, V, O, seq_len], - ) - - return _FlashAttentionV4_1() diff --git a/kernels/flash_attention_v4_2.py b/kernels/flash_attention_v4_2.py deleted file mode 100644 index de6f28ca..00000000 --- a/kernels/flash_attention_v4_2.py +++ /dev/null @@ -1,667 +0,0 @@ -"""Flash Attention V4.2 kernel builder for FlyDSL. - -V4.2 optimizations over V4.1: -- BLOCK_N=32 (vs 16): halves KV iterations and barriers -- Q@K^T produces [16,32] via two MFMA 16x16x16 in N dimension -- P@V uses K=32 via two MFMA 16x16x16 in K dimension -- Softmax over 32 positions per row (two 16-wide groups) -- V stored transposed in LDS with bank-conflict-free padding (from V4.1) -- Q preloaded to registers (from V4.1) - -Tile config: BLOCK_M=64, BLOCK_N=32, 4 waves (256 threads), mfma_f32_16x16x16f16. - -Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). -Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. -Block: (256,) -- 4 waves of 64 on AMD (wave64). - -Requires: head_dim % 16 == 0, seq_len % 64 == 0, head_dim >= 64. -""" - -import math - -from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl -from flydsl.dialects.ext import vector as vec_ext -from flydsl.dialects.ext.python_control_flow import range_constexpr -from flydsl.dialects.ext.scf import yield_ as scf_yield -from _mlir.dialects import memref as _memref -from flydsl.runtime.device import get_rocm_arch as get_hip_arch -from flydsl.utils import SmemAllocator -from _mlir import ir -import _mlir.extras.types as T - - -KERNEL_NAME = "flash_attention_v4_2_kernel" - - -def build_flash_attention_v4_2_module( - num_heads, - head_dim, - causal=True, - dtype_str="f16", - sm_scale=None, -): - """Build a FlyDSL Flash Attention V4.2 module. - - Args: - num_heads: Number of attention heads. - head_dim: Dimension per head (must be divisible by 16, >= 64). - causal: Whether to apply causal mask. - dtype_str: "f16" (bf16 not yet supported). - sm_scale: Softmax scale (default: 1/sqrt(head_dim)). - - Returns: - MlirModule compilable via ``flydsl.compile(module)``. - """ - gpu_arch = get_hip_arch() - DYN = ir.ShapedType.get_dynamic_size() - - BLOCK_M = 64 - BLOCK_N = 32 # *** doubled from V4.1 *** - NUM_WAVES = 4 - WARP_SIZE = 64 - BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 - ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 - K_STEPS = head_dim // 16 - # Number of 16-wide MFMA columns in Q@K^T N-dimension - N_MFMA = BLOCK_N // 16 # 2 - - assert head_dim % 16 == 0, f"head_dim ({head_dim}) must be divisible by 16" - assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" - assert dtype_str == "f16", "V4.2 currently only supports f16" - assert BLOCK_N % 16 == 0, f"BLOCK_N ({BLOCK_N}) must be divisible by 16" - - if sm_scale is None: - sm_scale = 1.0 / math.sqrt(head_dim) - - NUM_HEADS = num_heads - HEAD_DIM = head_dim - CAUSAL = causal - STRIDE_TOKEN = NUM_HEADS * HEAD_DIM - - # ---- Bank-conflict-free LDS strides ---- - K_STRIDE = HEAD_DIM + 2 # 130 for HD=128 - VT_STRIDE = BLOCK_N + 2 # 34 for BLOCK_N=32 - - # ---- Vectorized cooperative load constants ---- - VEC_WIDTH = 8 - THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH - assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 - ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD - - assert BLOCK_M % ROWS_PER_BATCH_LOAD == 0 - NUM_BATCHES_Q = BLOCK_M // ROWS_PER_BATCH_LOAD - - # For KV tile (32 rows with 256 threads) - if ROWS_PER_BATCH_LOAD >= BLOCK_N: - NUM_BATCHES_KV = 1 - KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N - else: - assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 - NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD - KV_NEEDS_GUARD = False - - # LDS sizes - LDS_Q_SIZE = BLOCK_M * HEAD_DIM - LDS_KV_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE) - LDS_P_SIZE = BLOCK_M * BLOCK_N # 64*32 = 2048 - - allocator = SmemAllocator(None, arch=gpu_arch) - _state = {} - - class _FlashAttentionV4_2(flir.MlirModule): - GPU_MODULE_NAME = f"flash_attn_v4_2_{dtype_str}" - GPU_MODULE_TARGETS = [f'#rocdl.target'] - - def init_gpu_module(self): - elem_type = T.f16() - _state["elem_type"] = elem_type - _state["lds_q"] = allocator.allocate_array(elem_type, LDS_Q_SIZE) - _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_SIZE) - _state["lds_p"] = allocator.allocate_array(elem_type, LDS_P_SIZE) - allocator.finalize() - - @flir.kernel - def flash_attention_v4_2_kernel( - self: flir.T.i64, - Q: lambda: T.memref(DYN, _state["elem_type"]), - K: lambda: T.memref(DYN, _state["elem_type"]), - V: lambda: T.memref(DYN, _state["elem_type"]), - O: lambda: T.memref(DYN, _state["elem_type"]), - seq_len: lambda: T.index(), - ): - compute_type = T.f32() - elem_type = _state["elem_type"] - fm_fast = flir.arith.FastMathFlags.fast - - v4f16_type = ir.VectorType.get([4], elem_type) - v4f32_type = ir.VectorType.get([4], compute_type) - v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) - - seq_len_v = arith.as_value(seq_len) - - # ---- LDS views ---- - base_ptr = allocator.get_base() - lds_q = _state["lds_q"](base_ptr).get() - lds_kv = _state["lds_kv"](base_ptr).get() - lds_p = _state["lds_p"](base_ptr).get() - - # ---- Thread / block indices ---- - block_id = flir.const_index(flir.block_idx("x")) - tid = flir.const_index(flir.thread_idx("x")) - - # ---- Wave decomposition ---- - c_ws = flir.const_index(WARP_SIZE) - wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) - lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) - - # ---- MFMA lane decomposition ---- - c16 = flir.const_index(16) - lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result) - lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result) - - # ---- Wave offsets ---- - wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value - wave_p_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value - - # ---- Decompose block_id ---- - c_nh = flir.const_index(NUM_HEADS) - head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) - temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) - c_bm = flir.const_index(BLOCK_M) - num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) - q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) - batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) - q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value - - # ---- Load thread decomposition ---- - c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) - load_row_in_batch = arith.as_value( - flir.arith.DivUIOp(tid, c_tpr).result - ) - load_lane_in_row = arith.as_value( - flir.arith.RemUIOp(tid, c_tpr).result - ) - load_col_base = ( - arith.ArithValue(load_lane_in_row) * VEC_WIDTH - ).value - - # ---- Helper: global flat index ---- - def global_idx(token_idx, col): - token = ( - arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) - + arith.ArithValue(token_idx) - ) - return ( - token * STRIDE_TOKEN - + arith.ArithValue(head_idx) * HEAD_DIM - + arith.ArithValue(col) - ).value - - # ---- Cooperative Q load (64 rows, unpadded) ---- - def coop_load_q(): - for batch in range_constexpr(NUM_BATCHES_Q): - row_offset = batch * ROWS_PER_BATCH_LOAD - row_idx = ( - arith.ArithValue(q_start) - + arith.ArithValue(load_row_in_batch) - + row_offset - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, Q, [g_idx]) - ) - lds_row = ( - arith.ArithValue(load_row_in_batch) + row_offset - ).value - lds_idx = ( - arith.ArithValue(lds_row) * HEAD_DIM - + arith.ArithValue(load_col_base) - ).value - vec_ext.store(vec, lds_q, [lds_idx]) - - # ---- Cooperative K load (row-major, padded stride) ---- - def coop_load_k(tile_start): - for batch in range_constexpr(NUM_BATCHES_KV): - row_offset = batch * ROWS_PER_BATCH_LOAD - row_idx = ( - arith.ArithValue(tile_start) - + arith.ArithValue(load_row_in_batch) - + row_offset - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, K, [g_idx]) - ) - lds_row = ( - arith.ArithValue(load_row_in_batch) + row_offset - ).value - lds_idx = ( - arith.ArithValue(lds_row) * K_STRIDE - + arith.ArithValue(load_col_base) - ).value - vec_ext.store(vec, lds_kv, [lds_idx]) - - # ---- Cooperative V load (transposed, padded stride) ---- - def coop_load_v_transposed(tile_start): - for batch in range_constexpr(NUM_BATCHES_KV): - row_offset = batch * ROWS_PER_BATCH_LOAD - row_idx = ( - arith.ArithValue(tile_start) - + arith.ArithValue(load_row_in_batch) - + row_offset - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, V, [g_idx]) - ) - load_row = ( - arith.ArithValue(load_row_in_batch) + row_offset - ).value - # Scatter-store transposed: V[row, col+e] -> lds[(col+e)*VT_STRIDE + row] - for e in range_constexpr(VEC_WIDTH): - elem = arith.as_value( - vec_ext.extract(vec, static_position=[e], dynamic_position=[]) - ) - col_e = (arith.ArithValue(load_col_base) + e).value - lds_idx = ( - arith.ArithValue(col_e) * VT_STRIDE - + arith.ArithValue(load_row) - ).value - _memref.StoreOp(elem, lds_kv, [lds_idx]) - - # ---- Load Q tile to LDS ---- - coop_load_q() - gpu.barrier() - - # ---- Preload Q A-operand packs into registers ---- - q_a_packs = [] - for ks in range_constexpr(K_STEPS): - q_lds_idx = ( - (arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_mod_16)) * HEAD_DIM - + ks * 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - q_a_packs.append(arith.as_value( - vec_ext.load_op(v4f16_type, lds_q, [q_lds_idx]) - )) - - # ---- Constants ---- - c_neg_inf = arith.constant(float("-inf"), type=compute_type) - c_zero_f = arith.constant(0.0, type=compute_type) - c_sm_scale = arith.constant(sm_scale, type=compute_type) - c_log2e = arith.constant(1.4426950408889634, type=compute_type) - c_zero_v4f32 = arith.as_value( - arith.constant_vector(0.0, v4f32_type) - ) - - # ---- Init loop-carried state ---- - # m[4], l[4], o_accs[K_STEPS] - init_args = [] - for _ in range_constexpr(4): - init_args.append(arith.as_value(c_neg_inf)) - for _ in range_constexpr(4): - init_args.append(arith.as_value(c_zero_f)) - for _ in range_constexpr(K_STEPS): - init_args.append(c_zero_v4f32) - - # ---- KV loop upper bound ---- - # Causal early-exit: last Q row = q_start + BLOCK_M - 1, - # so only need KV positions 0 .. q_start + BLOCK_M - 1. - # q_start + BLOCK_M is always a multiple of BLOCK_N (64 % 32 == 0). - if CAUSAL: - kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value - else: - kv_upper = seq_len_v - - # ---- KV loop (step BLOCK_N=32) ---- - with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: - kv_start = arith.as_value(loop.induction_variable) - m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)] - l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)] - o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)] - - # ==== Cooperative K load -> LDS_KV (32 rows, padded stride) ==== - coop_load_k(kv_start) - gpu.barrier() - - # ==== Q @ K^T via MFMA -> S[16, 32] ==== - # Two MFMA outputs: s_acc[0] for KV cols 0..15, s_acc[1] for KV cols 16..31 - s_accs = [c_zero_v4f32, c_zero_v4f32] - for ks in range_constexpr(K_STEPS): - a_pack = q_a_packs[ks] - for nm in range_constexpr(N_MFMA): - # B operand (K^T): K row = nm*16 + lane_mod_16 - k_row = nm * 16 - k_lds_idx = ( - (arith.ArithValue(lane_mod_16) + k_row) * K_STRIDE - + ks * 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - b_pack = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx]) - ) - s_accs[nm] = arith.as_value( - rocdl.mfma_f32_16x16x16f16( - v4f32_type, [a_pack, b_pack, s_accs[nm], 0, 0, 0] - ) - ) - - # ==== Online softmax over 32 positions ==== - # For each row ii (0..3): have values at lane_mod_16 in s_accs[0] and s_accs[1] - # Need max and sum over all 32 positions - s_vals_lo = [] # from s_accs[0], KV cols 0..15 - s_vals_hi = [] # from s_accs[1], KV cols 16..31 - for ii in range_constexpr(4): - s_lo = arith.as_value( - vec_ext.extract(s_accs[0], static_position=[ii], dynamic_position=[]) - ) - s_lo = arith.as_value( - flir.arith.MulFOp(s_lo, arith.as_value(c_sm_scale), fastmath=fm_fast).result - ) - s_hi = arith.as_value( - vec_ext.extract(s_accs[1], static_position=[ii], dynamic_position=[]) - ) - s_hi = arith.as_value( - flir.arith.MulFOp(s_hi, arith.as_value(c_sm_scale), fastmath=fm_fast).result - ) - - if CAUSAL: - q_row = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_div_16) * 4 - + ii - ).value - # Low half: KV col = kv_start + lane_mod_16 - kv_col_lo = (arith.ArithValue(kv_start) + arith.ArithValue(lane_mod_16)).value - # High half: KV col = kv_start + 16 + lane_mod_16 - kv_col_hi = (arith.ArithValue(kv_start) + 16 + arith.ArithValue(lane_mod_16)).value - - q_row_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), q_row).result) - kv_lo_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col_lo).result) - kv_hi_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col_hi).result) - - is_masked_lo = arith.as_value( - flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ugt, kv_lo_i64, q_row_i64, - ).result - ) - is_masked_hi = arith.as_value( - flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ugt, kv_hi_i64, q_row_i64, - ).result - ) - s_lo = arith.as_value( - flir.arith.SelectOp(is_masked_lo, arith.as_value(c_neg_inf), s_lo).result - ) - s_hi = arith.as_value( - flir.arith.SelectOp(is_masked_hi, arith.as_value(c_neg_inf), s_hi).result - ) - - s_vals_lo.append(s_lo) - s_vals_hi.append(s_hi) - - width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) - m_new = [None] * 4 - corr = [None] * 4 - p_vals_lo = [None] * 4 - p_vals_hi = [None] * 4 - l_new = [None] * 4 - - for ii in range_constexpr(4): - # Max over 32 positions: max of lo-half and hi-half - row_max_lo = s_vals_lo[ii] - row_max_hi = s_vals_hi[ii] - - # Reduce lo-half within 16 lanes - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_max_lo, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_max_lo = arith.as_value( - flir.arith.MaximumFOp(row_max_lo, peer).result - ) - - # Reduce hi-half within 16 lanes - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_max_hi, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_max_hi = arith.as_value( - flir.arith.MaximumFOp(row_max_hi, peer).result - ) - - # Combine lo and hi maxes - row_max = arith.as_value( - flir.arith.MaximumFOp(row_max_lo, row_max_hi).result - ) - - m_new[ii] = arith.as_value( - flir.arith.MaximumFOp(m_old[ii], row_max).result - ) - - diff_m = arith.as_value( - flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result - ) - diff_m_s = arith.as_value( - flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) - - # exp2 for both halves - diff_lo = arith.as_value( - flir.arith.SubFOp(s_vals_lo[ii], m_new[ii], fastmath=fm_fast).result - ) - diff_lo_s = arith.as_value( - flir.arith.MulFOp(diff_lo, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - p_vals_lo[ii] = arith.as_value(flir.math.exp2(diff_lo_s, fastmath=fm_fast)) - - diff_hi = arith.as_value( - flir.arith.SubFOp(s_vals_hi[ii], m_new[ii], fastmath=fm_fast).result - ) - diff_hi_s = arith.as_value( - flir.arith.MulFOp(diff_hi, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - p_vals_hi[ii] = arith.as_value(flir.math.exp2(diff_hi_s, fastmath=fm_fast)) - - # Sum over 32 positions - row_sum_lo = p_vals_lo[ii] - row_sum_hi = p_vals_hi[ii] - - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_sum_lo, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_sum_lo = arith.as_value( - flir.arith.AddFOp(row_sum_lo, peer, fastmath=fm_fast).result - ) - - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_sum_hi, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_sum_hi = arith.as_value( - flir.arith.AddFOp(row_sum_hi, peer, fastmath=fm_fast).result - ) - - row_sum = arith.as_value( - flir.arith.AddFOp(row_sum_lo, row_sum_hi, fastmath=fm_fast).result - ) - - l_corr = arith.as_value( - flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result - ) - l_new[ii] = arith.as_value( - flir.arith.AddFOp(l_corr, row_sum, fastmath=fm_fast).result - ) - - # ==== Rescale O accumulators ==== - corr_vec = arith.as_value( - vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) - ) - for ds in range_constexpr(K_STEPS): - o_accs[ds] = arith.as_value( - flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result - ) - - # ==== P store to LDS_P ==== - # P is [16, 32] per wave. Two 16x16 blocks: lo (cols 0..15) and hi (cols 16..31) - for ii in range_constexpr(4): - p_lo_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, p_vals_lo[ii]).result - ) - p_hi_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, p_vals_hi[ii]).result - ) - p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value - # Lo: cols 0..15 - p_lds_lo = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(p_row) * BLOCK_N - + arith.ArithValue(lane_mod_16) - ).value - _memref.StoreOp(p_lo_f16, lds_p, [p_lds_lo]) - # Hi: cols 16..31 - p_lds_hi = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(p_row) * BLOCK_N - + 16 - + arith.ArithValue(lane_mod_16) - ).value - _memref.StoreOp(p_hi_f16, lds_p, [p_lds_hi]) - - # ==== Barrier: ensure all waves done reading K ==== - gpu.barrier() - - # ==== Cooperative V load (transposed) ==== - coop_load_v_transposed(kv_start) - gpu.barrier() - - # ==== P @ V via MFMA ==== - # P[16, 32] @ V[32, 16chunk] = O[16, 16chunk] - # Split K=32 into two halves: P_lo[16,16] @ V_top[16,16] + P_hi[16,16] @ V_bot[16,16] - - # Load P A-operand packs: P_lo and P_hi - p_a_lo_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(lane_mod_16) * BLOCK_N - + arith.ArithValue(lane_div_16) * 4 - ).value - p_pack_lo = arith.as_value( - vec_ext.load_op(v4f16_type, lds_p, [p_a_lo_idx]) - ) - - p_a_hi_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(lane_mod_16) * BLOCK_N - + 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - p_pack_hi = arith.as_value( - vec_ext.load_op(v4f16_type, lds_p, [p_a_hi_idx]) - ) - - for ds in range_constexpr(K_STEPS): - # V_top: V rows 0..15, B-operand from transposed LDS - v_top_idx = ( - (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE - + arith.ArithValue(lane_div_16) * 4 - ).value - v_top = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [v_top_idx]) - ) - # Accumulate P_lo @ V_top - o_accs[ds] = arith.as_value( - rocdl.mfma_f32_16x16x16f16( - v4f32_type, [p_pack_lo, v_top, o_accs[ds], 0, 0, 0] - ) - ) - - # V_bot: V rows 16..31, B-operand from transposed LDS - v_bot_idx = ( - (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE - + 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - v_bot = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [v_bot_idx]) - ) - # Accumulate P_hi @ V_bot - o_accs[ds] = arith.as_value( - rocdl.mfma_f32_16x16x16f16( - v4f32_type, [p_pack_hi, v_bot, o_accs[ds], 0, 0, 0] - ) - ) - - # ==== Barrier: ensure all waves done reading V ==== - gpu.barrier() - - # ==== Yield ==== - yield_args = m_new + l_new + o_accs - scf_yield(yield_args) - - # ---- Normalize and store O ---- - m_finals = [arith.as_value(loop.results[i]) for i in range(4)] - l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] - o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] - - for ds in range_constexpr(K_STEPS): - for ii in range_constexpr(4): - o_val = arith.as_value( - vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) - ) - o_norm = arith.as_value( - flir.arith.DivFOp(o_val, l_finals[ii], fastmath=fm_fast).result - ) - o_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, o_norm).result - ) - q_row = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_div_16) * 4 - + ii - ).value - d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value - o_global = global_idx(q_row, d_col) - _memref.StoreOp(o_f16, O, [o_global]) - - @flir.jit - def __call__( - self: flir.T.i64, - Q: lambda: T.memref(DYN, _state["elem_type"]), - K: lambda: T.memref(DYN, _state["elem_type"]), - V: lambda: T.memref(DYN, _state["elem_type"]), - O: lambda: T.memref(DYN, _state["elem_type"]), - batch_size: lambda: T.index(), - seq_len: lambda: T.index(), - ): - c1 = arith.as_value(flir.arith_ext.index(1)) - c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) - c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) - bs_val = arith.as_value(batch_size) - sl_val = arith.as_value(seq_len) - num_q_tiles = arith.as_value( - flir.arith.DivUIOp(sl_val, c_bm).result - ) - bs_qt = arith.as_value( - flir.arith.MulIOp(bs_val, num_q_tiles).result - ) - grid_x = arith.as_value( - flir.arith.MulIOp(bs_qt, c_nh).result - ) - bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) - flir.gpu_ext.LaunchFuncOp( - [self.GPU_MODULE_NAME, KERNEL_NAME], - grid_size=(grid_x, c1, c1), - block_size=(bx, c1, c1), - kernel_operands=[Q, K, V, O, seq_len], - ) - - return _FlashAttentionV4_2() diff --git a/kernels/flash_attention_v4_3.py b/kernels/flash_attention_v4_3.py deleted file mode 100644 index bdfe26f9..00000000 --- a/kernels/flash_attention_v4_3.py +++ /dev/null @@ -1,650 +0,0 @@ -"""Flash Attention V4.3 kernel builder for FlyDSL. - -V4.3 optimization over V4.2: -- Q loaded directly from global memory to MFMA registers (no Q in LDS). - LDS = KV(8.5KB) + P(4KB) = 12.5KB (was 29KB in V4.2). - This enables 4 workgroups/CU -> 4 waves/SIMD (was 2 waves/SIMD). -- Eliminates 2 barriers (Q store + Q preload sync). - -All other optimizations from V4.2: -- BLOCK_N=32 (vs 16): halves KV iterations and barriers -- Q@K^T produces [16,32] via two MFMA 16x16x16 in N dimension -- P@V uses K=32 via two MFMA 16x16x16 in K dimension -- Softmax over 32 positions per row (two 16-wide groups) -- V stored transposed in LDS with bank-conflict-free padding (from V4.1) -- Causal early-exit - -Tile config: BLOCK_M=64, BLOCK_N=32, 4 waves (256 threads), mfma_f32_16x16x16f16. - -Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). -Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. -Block: (256,) -- 4 waves of 64 on AMD (wave64). - -Requires: head_dim % 16 == 0, seq_len % 64 == 0, head_dim >= 64. -""" - -import math - -from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl -from flydsl.dialects.ext import vector as vec_ext -from flydsl.dialects.ext.python_control_flow import range_constexpr -from flydsl.dialects.ext.scf import yield_ as scf_yield -from _mlir.dialects import memref as _memref -from flydsl.runtime.device import get_rocm_arch as get_hip_arch -from flydsl.utils import SmemAllocator -from _mlir import ir -import _mlir.extras.types as T - - -KERNEL_NAME = "flash_attention_v4_3_kernel" - - -def build_flash_attention_v4_3_module( - num_heads, - head_dim, - causal=True, - dtype_str="f16", - sm_scale=None, -): - """Build a FlyDSL Flash Attention V4.3 module (LDS overlay). - - Args: - num_heads: Number of attention heads. - head_dim: Dimension per head (must be divisible by 16, >= 64). - causal: Whether to apply causal mask. - dtype_str: "f16" (bf16 not yet supported). - sm_scale: Softmax scale (default: 1/sqrt(head_dim)). - - Returns: - MlirModule compilable via ``flydsl.compile(module)``. - """ - gpu_arch = get_hip_arch() - DYN = ir.ShapedType.get_dynamic_size() - - BLOCK_M = 64 - BLOCK_N = 32 # *** doubled from V4.1 *** - NUM_WAVES = 4 - WARP_SIZE = 64 - BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 - ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 - K_STEPS = head_dim // 16 - # Number of 16-wide MFMA columns in Q@K^T N-dimension - N_MFMA = BLOCK_N // 16 # 2 - - assert head_dim % 16 == 0, f"head_dim ({head_dim}) must be divisible by 16" - assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" - assert dtype_str == "f16", "V4.3 currently only supports f16" - assert BLOCK_N % 16 == 0, f"BLOCK_N ({BLOCK_N}) must be divisible by 16" - - if sm_scale is None: - sm_scale = 1.0 / math.sqrt(head_dim) - - NUM_HEADS = num_heads - HEAD_DIM = head_dim - CAUSAL = causal - STRIDE_TOKEN = NUM_HEADS * HEAD_DIM - - # ---- Bank-conflict-free LDS strides ---- - K_STRIDE = HEAD_DIM + 2 # 130 for HD=128 - VT_STRIDE = BLOCK_N + 2 # 34 for BLOCK_N=32 - - # ---- Vectorized cooperative load constants ---- - VEC_WIDTH = 8 - THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH - assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 - ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD - - assert BLOCK_M % ROWS_PER_BATCH_LOAD == 0 - NUM_BATCHES_Q = BLOCK_M // ROWS_PER_BATCH_LOAD - - # For KV tile (32 rows with 256 threads) - if ROWS_PER_BATCH_LOAD >= BLOCK_N: - NUM_BATCHES_KV = 1 - KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N - else: - assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 - NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD - KV_NEEDS_GUARD = False - - # LDS sizes (element counts, f16 = 2 bytes each) - # No Q in LDS — loaded directly from global memory to MFMA registers - LDS_KV_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE) # 4352 elements = 8704 bytes - LDS_P_SIZE = BLOCK_M * BLOCK_N # 2048 elements = 4096 bytes - - allocator = SmemAllocator(None, arch=gpu_arch) - _state = {} - - class _FlashAttentionV4_3(flir.MlirModule): - GPU_MODULE_NAME = f"flash_attn_v4_3_{dtype_str}" - GPU_MODULE_TARGETS = [f'#rocdl.target'] - - def init_gpu_module(self): - elem_type = T.f16() - _state["elem_type"] = elem_type - _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_SIZE) - _state["lds_p"] = allocator.allocate_array(elem_type, LDS_P_SIZE) - allocator.finalize() - - @flir.kernel - def flash_attention_v4_3_kernel( - self: flir.T.i64, - Q: lambda: T.memref(DYN, _state["elem_type"]), - K: lambda: T.memref(DYN, _state["elem_type"]), - V: lambda: T.memref(DYN, _state["elem_type"]), - O: lambda: T.memref(DYN, _state["elem_type"]), - seq_len: lambda: T.index(), - ): - compute_type = T.f32() - elem_type = _state["elem_type"] - fm_fast = flir.arith.FastMathFlags.fast - - v4f16_type = ir.VectorType.get([4], elem_type) - v4f32_type = ir.VectorType.get([4], compute_type) - v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type) - - seq_len_v = arith.as_value(seq_len) - - # ---- LDS views (KV + P only, no Q in LDS) ---- - base_ptr = allocator.get_base() - lds_kv = _state["lds_kv"](base_ptr).get() - lds_p = _state["lds_p"](base_ptr).get() - - # ---- Thread / block indices ---- - block_id = flir.const_index(flir.block_idx("x")) - tid = flir.const_index(flir.thread_idx("x")) - - # ---- Wave decomposition ---- - c_ws = flir.const_index(WARP_SIZE) - wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) - lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) - - # ---- MFMA lane decomposition ---- - c16 = flir.const_index(16) - lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result) - lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result) - - # ---- Wave offsets ---- - wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value - wave_p_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value - - # ---- Decompose block_id ---- - c_nh = flir.const_index(NUM_HEADS) - head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) - temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) - c_bm = flir.const_index(BLOCK_M) - num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) - q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) - batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) - q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value - - # ---- Load thread decomposition ---- - c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) - load_row_in_batch = arith.as_value( - flir.arith.DivUIOp(tid, c_tpr).result - ) - load_lane_in_row = arith.as_value( - flir.arith.RemUIOp(tid, c_tpr).result - ) - load_col_base = ( - arith.ArithValue(load_lane_in_row) * VEC_WIDTH - ).value - - # ---- Helper: global flat index ---- - def global_idx(token_idx, col): - token = ( - arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) - + arith.ArithValue(token_idx) - ) - return ( - token * STRIDE_TOKEN - + arith.ArithValue(head_idx) * HEAD_DIM - + arith.ArithValue(col) - ).value - - # ---- Cooperative K load (row-major, padded stride) ---- - def coop_load_k(tile_start): - for batch in range_constexpr(NUM_BATCHES_KV): - row_offset = batch * ROWS_PER_BATCH_LOAD - row_idx = ( - arith.ArithValue(tile_start) - + arith.ArithValue(load_row_in_batch) - + row_offset - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, K, [g_idx]) - ) - lds_row = ( - arith.ArithValue(load_row_in_batch) + row_offset - ).value - lds_idx = ( - arith.ArithValue(lds_row) * K_STRIDE - + arith.ArithValue(load_col_base) - ).value - vec_ext.store(vec, lds_kv, [lds_idx]) - - # ---- Cooperative V load (transposed, padded stride) ---- - def coop_load_v_transposed(tile_start): - for batch in range_constexpr(NUM_BATCHES_KV): - row_offset = batch * ROWS_PER_BATCH_LOAD - row_idx = ( - arith.ArithValue(tile_start) - + arith.ArithValue(load_row_in_batch) - + row_offset - ).value - g_idx = global_idx(row_idx, load_col_base) - vec = arith.as_value( - vec_ext.load_op(v8f16_type, V, [g_idx]) - ) - load_row = ( - arith.ArithValue(load_row_in_batch) + row_offset - ).value - # Scatter-store transposed: V[row, col+e] -> lds_kv[(col+e)*VT_STRIDE + row] - for e in range_constexpr(VEC_WIDTH): - elem = arith.as_value( - vec_ext.extract(vec, static_position=[e], dynamic_position=[]) - ) - col_e = (arith.ArithValue(load_col_base) + e).value - lds_idx = ( - arith.ArithValue(col_e) * VT_STRIDE - + arith.ArithValue(load_row) - ).value - _memref.StoreOp(elem, lds_kv, [lds_idx]) - - # ---- Load Q directly from global memory to MFMA registers ---- - # Each MFMA lane (b=lane_div_16, n=lane_mod_16) loads v4f16 from - # Q[q_start + wave_offset + n, ks*16 + b*4 : ks*16 + b*4 + 4]. - # No LDS needed for Q — eliminates overlay race condition. - q_row = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_mod_16) - ).value - q_a_packs = [] - for ks in range_constexpr(K_STEPS): - q_col = flir.const_index(ks * 16 + 0) - q_col = (arith.ArithValue(q_col) + arith.ArithValue(lane_div_16) * 4).value - g_idx = global_idx(q_row, q_col) - q_a_packs.append(arith.as_value( - vec_ext.load_op(v4f16_type, Q, [g_idx]) - )) - - # ---- Constants ---- - c_neg_inf = arith.constant(float("-inf"), type=compute_type) - c_zero_f = arith.constant(0.0, type=compute_type) - c_sm_scale = arith.constant(sm_scale, type=compute_type) - c_log2e = arith.constant(1.4426950408889634, type=compute_type) - c_zero_v4f32 = arith.as_value( - arith.constant_vector(0.0, v4f32_type) - ) - - # ---- Init loop-carried state ---- - # m[4], l[4], o_accs[K_STEPS] - init_args = [] - for _ in range_constexpr(4): - init_args.append(arith.as_value(c_neg_inf)) - for _ in range_constexpr(4): - init_args.append(arith.as_value(c_zero_f)) - for _ in range_constexpr(K_STEPS): - init_args.append(c_zero_v4f32) - - # ---- KV loop upper bound ---- - # Causal early-exit: last Q row = q_start + BLOCK_M - 1, - # so only need KV positions 0 .. q_start + BLOCK_M - 1. - # q_start + BLOCK_M is always a multiple of BLOCK_N (64 % 32 == 0). - if CAUSAL: - kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value - else: - kv_upper = seq_len_v - - # ---- KV loop (step BLOCK_N=32) ---- - with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: - kv_start = arith.as_value(loop.induction_variable) - m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)] - l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)] - o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)] - - # ==== Cooperative K load -> LDS_KV (32 rows, padded stride) ==== - coop_load_k(kv_start) - gpu.barrier() - - # ==== Q @ K^T via MFMA -> S[16, 32] ==== - # Two MFMA outputs: s_acc[0] for KV cols 0..15, s_acc[1] for KV cols 16..31 - s_accs = [c_zero_v4f32, c_zero_v4f32] - for ks in range_constexpr(K_STEPS): - a_pack = q_a_packs[ks] - for nm in range_constexpr(N_MFMA): - # B operand (K^T): K row = nm*16 + lane_mod_16 - k_row = nm * 16 - k_lds_idx = ( - (arith.ArithValue(lane_mod_16) + k_row) * K_STRIDE - + ks * 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - b_pack = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx]) - ) - s_accs[nm] = arith.as_value( - rocdl.mfma_f32_16x16x16f16( - v4f32_type, [a_pack, b_pack, s_accs[nm], 0, 0, 0] - ) - ) - - # ==== Online softmax over 32 positions ==== - # For each row ii (0..3): have values at lane_mod_16 in s_accs[0] and s_accs[1] - # Need max and sum over all 32 positions - s_vals_lo = [] # from s_accs[0], KV cols 0..15 - s_vals_hi = [] # from s_accs[1], KV cols 16..31 - for ii in range_constexpr(4): - s_lo = arith.as_value( - vec_ext.extract(s_accs[0], static_position=[ii], dynamic_position=[]) - ) - s_lo = arith.as_value( - flir.arith.MulFOp(s_lo, arith.as_value(c_sm_scale), fastmath=fm_fast).result - ) - s_hi = arith.as_value( - vec_ext.extract(s_accs[1], static_position=[ii], dynamic_position=[]) - ) - s_hi = arith.as_value( - flir.arith.MulFOp(s_hi, arith.as_value(c_sm_scale), fastmath=fm_fast).result - ) - - if CAUSAL: - q_row = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_div_16) * 4 - + ii - ).value - # Low half: KV col = kv_start + lane_mod_16 - kv_col_lo = (arith.ArithValue(kv_start) + arith.ArithValue(lane_mod_16)).value - # High half: KV col = kv_start + 16 + lane_mod_16 - kv_col_hi = (arith.ArithValue(kv_start) + 16 + arith.ArithValue(lane_mod_16)).value - - q_row_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), q_row).result) - kv_lo_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col_lo).result) - kv_hi_i64 = arith.as_value(flir.arith.IndexCastOp(T.i64(), kv_col_hi).result) - - is_masked_lo = arith.as_value( - flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ugt, kv_lo_i64, q_row_i64, - ).result - ) - is_masked_hi = arith.as_value( - flir.arith.CmpIOp( - flir.arith.CmpIPredicate.ugt, kv_hi_i64, q_row_i64, - ).result - ) - s_lo = arith.as_value( - flir.arith.SelectOp(is_masked_lo, arith.as_value(c_neg_inf), s_lo).result - ) - s_hi = arith.as_value( - flir.arith.SelectOp(is_masked_hi, arith.as_value(c_neg_inf), s_hi).result - ) - - s_vals_lo.append(s_lo) - s_vals_hi.append(s_hi) - - width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) - m_new = [None] * 4 - corr = [None] * 4 - p_vals_lo = [None] * 4 - p_vals_hi = [None] * 4 - l_new = [None] * 4 - - for ii in range_constexpr(4): - # Max over 32 positions: max of lo-half and hi-half - row_max_lo = s_vals_lo[ii] - row_max_hi = s_vals_hi[ii] - - # Reduce lo-half within 16 lanes - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_max_lo, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_max_lo = arith.as_value( - flir.arith.MaximumFOp(row_max_lo, peer).result - ) - - # Reduce hi-half within 16 lanes - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_max_hi, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_max_hi = arith.as_value( - flir.arith.MaximumFOp(row_max_hi, peer).result - ) - - # Combine lo and hi maxes - row_max = arith.as_value( - flir.arith.MaximumFOp(row_max_lo, row_max_hi).result - ) - - m_new[ii] = arith.as_value( - flir.arith.MaximumFOp(m_old[ii], row_max).result - ) - - diff_m = arith.as_value( - flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result - ) - diff_m_s = arith.as_value( - flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) - - # exp2 for both halves - diff_lo = arith.as_value( - flir.arith.SubFOp(s_vals_lo[ii], m_new[ii], fastmath=fm_fast).result - ) - diff_lo_s = arith.as_value( - flir.arith.MulFOp(diff_lo, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - p_vals_lo[ii] = arith.as_value(flir.math.exp2(diff_lo_s, fastmath=fm_fast)) - - diff_hi = arith.as_value( - flir.arith.SubFOp(s_vals_hi[ii], m_new[ii], fastmath=fm_fast).result - ) - diff_hi_s = arith.as_value( - flir.arith.MulFOp(diff_hi, arith.as_value(c_log2e), fastmath=fm_fast).result - ) - p_vals_hi[ii] = arith.as_value(flir.math.exp2(diff_hi_s, fastmath=fm_fast)) - - # Sum over 32 positions - row_sum_lo = p_vals_lo[ii] - row_sum_hi = p_vals_hi[ii] - - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_sum_lo, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_sum_lo = arith.as_value( - flir.arith.AddFOp(row_sum_lo, peer, fastmath=fm_fast).result - ) - - for sh in [8, 4, 2, 1]: - sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) - peer = arith.as_value( - gpu.ShuffleOp(row_sum_hi, sh_i32, width_i32, mode="xor").shuffleResult - ) - row_sum_hi = arith.as_value( - flir.arith.AddFOp(row_sum_hi, peer, fastmath=fm_fast).result - ) - - row_sum = arith.as_value( - flir.arith.AddFOp(row_sum_lo, row_sum_hi, fastmath=fm_fast).result - ) - - l_corr = arith.as_value( - flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result - ) - l_new[ii] = arith.as_value( - flir.arith.AddFOp(l_corr, row_sum, fastmath=fm_fast).result - ) - - # ==== Rescale O accumulators ==== - corr_vec = arith.as_value( - vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) - ) - for ds in range_constexpr(K_STEPS): - o_accs[ds] = arith.as_value( - flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result - ) - - # ==== P store to LDS_P ==== - # P is [16, 32] per wave. Two 16x16 blocks: lo (cols 0..15) and hi (cols 16..31) - for ii in range_constexpr(4): - p_lo_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, p_vals_lo[ii]).result - ) - p_hi_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, p_vals_hi[ii]).result - ) - p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value - # Lo: cols 0..15 - p_lds_lo = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(p_row) * BLOCK_N - + arith.ArithValue(lane_mod_16) - ).value - _memref.StoreOp(p_lo_f16, lds_p, [p_lds_lo]) - # Hi: cols 16..31 - p_lds_hi = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(p_row) * BLOCK_N - + 16 - + arith.ArithValue(lane_mod_16) - ).value - _memref.StoreOp(p_hi_f16, lds_p, [p_lds_hi]) - - # ==== Barrier: ensure all waves done reading K ==== - gpu.barrier() - - # ==== Cooperative V load (transposed) ==== - coop_load_v_transposed(kv_start) - gpu.barrier() - - # ==== P @ V via MFMA ==== - # P[16, 32] @ V[32, 16chunk] = O[16, 16chunk] - # Split K=32 into two halves: P_lo[16,16] @ V_top[16,16] + P_hi[16,16] @ V_bot[16,16] - - # Load P A-operand packs: P_lo and P_hi - p_a_lo_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(lane_mod_16) * BLOCK_N - + arith.ArithValue(lane_div_16) * 4 - ).value - p_pack_lo = arith.as_value( - vec_ext.load_op(v4f16_type, lds_p, [p_a_lo_idx]) - ) - - p_a_hi_idx = ( - arith.ArithValue(wave_p_offset) - + arith.ArithValue(lane_mod_16) * BLOCK_N - + 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - p_pack_hi = arith.as_value( - vec_ext.load_op(v4f16_type, lds_p, [p_a_hi_idx]) - ) - - for ds in range_constexpr(K_STEPS): - # V_top: V rows 0..15, B-operand from transposed LDS - v_top_idx = ( - (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE - + arith.ArithValue(lane_div_16) * 4 - ).value - v_top = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [v_top_idx]) - ) - # Accumulate P_lo @ V_top - o_accs[ds] = arith.as_value( - rocdl.mfma_f32_16x16x16f16( - v4f32_type, [p_pack_lo, v_top, o_accs[ds], 0, 0, 0] - ) - ) - - # V_bot: V rows 16..31, B-operand from transposed LDS - v_bot_idx = ( - (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE - + 16 - + arith.ArithValue(lane_div_16) * 4 - ).value - v_bot = arith.as_value( - vec_ext.load_op(v4f16_type, lds_kv, [v_bot_idx]) - ) - # Accumulate P_hi @ V_bot - o_accs[ds] = arith.as_value( - rocdl.mfma_f32_16x16x16f16( - v4f32_type, [p_pack_hi, v_bot, o_accs[ds], 0, 0, 0] - ) - ) - - # ==== Barrier: ensure all waves done reading V ==== - gpu.barrier() - - # ==== Yield ==== - yield_args = m_new + l_new + o_accs - scf_yield(yield_args) - - # ---- Normalize and store O ---- - m_finals = [arith.as_value(loop.results[i]) for i in range(4)] - l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] - o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] - - for ds in range_constexpr(K_STEPS): - for ii in range_constexpr(4): - o_val = arith.as_value( - vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) - ) - o_norm = arith.as_value( - flir.arith.DivFOp(o_val, l_finals[ii], fastmath=fm_fast).result - ) - o_f16 = arith.as_value( - flir.arith.TruncFOp(elem_type, o_norm).result - ) - q_row = ( - arith.ArithValue(q_start) - + arith.ArithValue(wave_q_offset) - + arith.ArithValue(lane_div_16) * 4 - + ii - ).value - d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value - o_global = global_idx(q_row, d_col) - _memref.StoreOp(o_f16, O, [o_global]) - - @flir.jit - def __call__( - self: flir.T.i64, - Q: lambda: T.memref(DYN, _state["elem_type"]), - K: lambda: T.memref(DYN, _state["elem_type"]), - V: lambda: T.memref(DYN, _state["elem_type"]), - O: lambda: T.memref(DYN, _state["elem_type"]), - batch_size: lambda: T.index(), - seq_len: lambda: T.index(), - ): - c1 = arith.as_value(flir.arith_ext.index(1)) - c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) - c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) - bs_val = arith.as_value(batch_size) - sl_val = arith.as_value(seq_len) - num_q_tiles = arith.as_value( - flir.arith.DivUIOp(sl_val, c_bm).result - ) - bs_qt = arith.as_value( - flir.arith.MulIOp(bs_val, num_q_tiles).result - ) - grid_x = arith.as_value( - flir.arith.MulIOp(bs_qt, c_nh).result - ) - bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) - flir.gpu_ext.LaunchFuncOp( - [self.GPU_MODULE_NAME, KERNEL_NAME], - grid_size=(grid_x, c1, c1), - block_size=(bx, c1, c1), - kernel_operands=[Q, K, V, O, seq_len], - ) - - return _FlashAttentionV4_3() diff --git a/run.sh b/run.sh index ec4ff6cc..86959e1c 100755 --- a/run.sh +++ b/run.sh @@ -40,7 +40,10 @@ function run_flydsl_op { # python tests/kernels/test_flash_attention_v4_2.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 # python tests/kernels/test_flash_attention_v4_3.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v42 # python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v43 - python tests/kernels/test_flash_attn_func.py --iters 100 --compare-v43 + # python tests/kernels/test_flash_attn_func.py --iters 100 --compare-v43 + + python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 + python tests/kernels/test_flash_attn_func.py --iters 100 # rocprof -i perf_counters1.txt -o prof_v44_p1.csv python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 5 --warmup 2 # rocprof -i perf_counters2.txt -o prof_v44_p2.csv python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 5 --warmup 2 diff --git a/tests/kernels/test_flash_attention_v4.py b/tests/kernels/test_flash_attention_v4.py deleted file mode 100644 index 9f1d52d8..00000000 --- a/tests/kernels/test_flash_attention_v4.py +++ /dev/null @@ -1,288 +0,0 @@ -#!/usr/bin/env python3 -"""Flash Attention V4 (Multi-Wave MFMA) kernel test and benchmark for FlyDSL. - -Tests the V4 multi-wave Flash Attention kernel against PyTorch SDPA reference. -Optionally compares performance with V3 kernels. - -Usage: - python tests/kernels/test_flash_attention_v4.py - python tests/kernels/test_flash_attention_v4.py --seq_len 512 --head_dim 128 - python tests/kernels/test_flash_attention_v4.py --compare-v3 -""" - -import sys -import os -import argparse -from pathlib import Path - -_repo = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(_repo)) - -try: - import torch - import torch.nn.functional as F -except ImportError: - print("PyTorch not available") - sys.exit(1) - -if not torch.cuda.is_available(): - print("CUDA/ROCm not available") - sys.exit(1) - -import flydsl -from kernels.flash_attention_v4 import build_flash_attention_v4_module, KERNEL_NAME - - -def pytorch_ref_attention(q, k, v, causal=True): - """PyTorch SDPA reference. q/k/v: (B, S, H, D) float32 -> (B, S, H, D).""" - q_t = q.transpose(1, 2).float() - k_t = k.transpose(1, 2).float() - v_t = v.transpose(1, 2).float() - out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal) - return out.transpose(1, 2) - - -def bench_gpu_us(fn, warmup=10, iters=50): - """Benchmark a GPU function, return average microseconds.""" - for _ in range(warmup): - fn() - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(iters): - fn() - end.record() - torch.cuda.synchronize() - return (start.elapsed_time(end) / iters) * 1000 - - -def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, - warmup, iters, v3_exe=None): - """Run one configuration. Returns dict with results.""" - device = "cuda" - results = {} - - # V4 requires seq_len divisible by BLOCK_M=64, head_dim by 16, head_dim >= 64 - if seq_len % 64 != 0: - results["err"] = f"seq_len ({seq_len}) must be divisible by 64" - return results - if head_dim % 16 != 0 or head_dim < 64: - results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" - return results - - try: - m = build_flash_attention_v4_module( - num_heads=num_heads, - head_dim=head_dim, - causal=causal, - dtype_str="f16", - ) - exe = flydsl.compile(m) - except Exception as e: - results["err"] = f"compile: {e}" - import traceback - traceback.print_exc() - return results - - B, S, H, D = batch, seq_len, num_heads, head_dim - q_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) - k_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) - v_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) - - q_flat = q_4d.contiguous().view(-1) - k_flat = k_4d.contiguous().view(-1) - v_flat = v_4d.contiguous().view(-1) - o_flat = torch.zeros_like(q_flat) - - try: - exe(q_flat, k_flat, v_flat, o_flat, B, S) - torch.cuda.synchronize() - except Exception as e: - results["err"] = f"exec: {e}" - import traceback - traceback.print_exc() - return results - - # PyTorch reference - ref_4d = pytorch_ref_attention( - q_4d.float(), k_4d.float(), v_4d.float(), causal=causal - ).to(dtype) - ref_flat = ref_4d.contiguous().view(-1) - - # Correctness - o_f32 = o_flat.float() - ref_f32 = ref_flat.float() - max_err = (o_f32 - ref_f32).abs().max().item() - mean_err = (o_f32 - ref_f32).abs().mean().item() - cos_sim = F.cosine_similarity( - o_f32.view(-1, D), ref_f32.view(-1, D), dim=1 - ) - min_cos = cos_sim.min().item() - results["max_err"] = max_err - results["mean_err"] = mean_err - results["min_cos"] = min_cos - - atol = 1e-2 - results["passed"] = max_err < atol and min_cos > 0.99 - - # Benchmark V4 - try: - def kernel_fn(): - o_flat.zero_() - exe(q_flat, k_flat, v_flat, o_flat, B, S) - - us = bench_gpu_us(kernel_fn, warmup=warmup, iters=iters) - s_eff = S / 2.0 if causal else float(S) - flops = 4.0 * S * s_eff * D * H * B - tflops = flops / (us * 1e-6) / 1e12 - results["us"] = us - results["tflops"] = tflops - except Exception as e: - results["bench_err"] = str(e) - - # Benchmark V3 for comparison - if v3_exe is not None: - try: - o_v3 = torch.zeros_like(q_flat) - - def v3_fn(): - o_v3.zero_() - v3_exe(q_flat, k_flat, v_flat, o_v3, B, S) - - v3_us = bench_gpu_us(v3_fn, warmup=warmup, iters=iters) - v3_tflops = flops / (v3_us * 1e-6) / 1e12 - results["v3_us"] = v3_us - results["v3_tflops"] = v3_tflops - except Exception as e: - results["v3_bench_err"] = str(e) - - return results - - -def main(): - parser = argparse.ArgumentParser( - description="Flash Attention V4 (Multi-Wave MFMA) FlyDSL Test/Benchmark" - ) - parser.add_argument("--batch", type=int, default=None) - parser.add_argument("--seq_len", type=int, default=None) - parser.add_argument("--num_heads", type=int, default=None) - parser.add_argument("--head_dim", type=int, default=None) - parser.add_argument( - "--dtype", type=str, default="fp16", choices=["fp16"] - ) - parser.add_argument("--no-causal", action="store_true") - parser.add_argument("--warmup", type=int, default=5) - parser.add_argument("--iters", type=int, default=20) - parser.add_argument("--compare-v3", action="store_true", - help="Also benchmark V3 for comparison") - args = parser.parse_args() - - causal = not args.no_causal - dtype = torch.float16 - causal_str = "causal" if causal else "non-causal" - - print("=" * 120) - print(f"FlyDSL Flash Attention V4 Multi-Wave MFMA ({causal_str}, fp16)") - print(f" BLOCK_M=64, BLOCK_N=16, 4 waves (256 threads), mfma_f32_16x16x16f16") - print(f"GPU: {torch.cuda.get_device_name(0)}") - print("=" * 120) - - if args.seq_len or args.head_dim or args.batch: - configs = [( - args.batch or 1, - args.seq_len or 128, - args.num_heads or 8, - args.head_dim or 128, - )] - else: - configs = [ - (1, 64, 8, 128), - (1, 128, 8, 128), - (1, 256, 32, 128), - (1, 512, 32, 128), - (2, 128, 8, 128), - ] - - # Pre-compile V3 if comparing - v3_exes = {} - if args.compare_v3: - from kernels.flash_attention_v3 import build_flash_attention_v3_module - for _, _, nh, hd in configs: - key = (nh, hd) - if key not in v3_exes: - try: - m = build_flash_attention_v3_module( - num_heads=nh, head_dim=hd, - causal=causal, dtype_str="f16", - ) - v3_exes[key] = flydsl.compile(m) - except Exception: - v3_exes[key] = None - - if args.compare_v3: - hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'V4(us)':>10s} {'V4 TFLOPS':>9s} | " - f"{'V3(us)':>10s} {'V3 TFLOPS':>9s} | {'Speedup':>7s}" - ) - else: - hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" - ) - print(f"\n{hdr}") - print("-" * len(hdr)) - - all_passed = True - for batch, seq_len, nh, hd in configs: - tag = f"B={batch} S={seq_len} H={nh} D={hd}" - try: - v3_exe = v3_exes.get((nh, hd)) if args.compare_v3 else None - r = run_config( - batch, seq_len, nh, hd, dtype, causal, - warmup=args.warmup, iters=args.iters, - v3_exe=v3_exe, - ) - if "err" in r: - print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") - all_passed = False - continue - - status = "PASS" if r["passed"] else "FAIL" - if not r["passed"]: - all_passed = False - - v4_us = f"{r['us']:>10.1f}" if "us" in r else " N/A" - v4_tf = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" - - if args.compare_v3 and "v3_us" in r: - v3_us = f"{r['v3_us']:>10.1f}" - v3_tf = f"{r['v3_tflops']:>9.3f}" - speedup = r["v3_us"] / r["us"] if r.get("us") else 0 - sp_s = f"{speedup:>6.2f}x" - print( - f"{tag:>38s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{v4_us} {v4_tf} | {v3_us} {v3_tf} | {sp_s}" - ) - else: - print( - f"{tag:>38s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{v4_us} {v4_tf}" - ) - except Exception as e: - print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") - all_passed = False - - print("=" * 120) - if all_passed: - print("All tests PASSED") - else: - print("Some tests FAILED") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/tests/kernels/test_flash_attention_v4_1.py b/tests/kernels/test_flash_attention_v4_1.py deleted file mode 100644 index 5763d60d..00000000 --- a/tests/kernels/test_flash_attention_v4_1.py +++ /dev/null @@ -1,288 +0,0 @@ -#!/usr/bin/env python3 -"""Flash Attention V4.1 kernel test and benchmark for FlyDSL. - -Tests V4.1 (Q-in-registers, transposed V, bank-conflict-free padding) against -PyTorch SDPA reference. Optionally compares with V4.0. - -Usage: - python tests/kernels/test_flash_attention_v4_1.py - python tests/kernels/test_flash_attention_v4_1.py --seq_len 512 --head_dim 128 - python tests/kernels/test_flash_attention_v4_1.py --compare-v4 -""" - -import sys -import os -import argparse -from pathlib import Path - -_repo = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(_repo)) - -try: - import torch - import torch.nn.functional as F -except ImportError: - print("PyTorch not available") - sys.exit(1) - -if not torch.cuda.is_available(): - print("CUDA/ROCm not available") - sys.exit(1) - -import flydsl -from kernels.flash_attention_v4_1 import build_flash_attention_v4_1_module, KERNEL_NAME - - -def pytorch_ref_attention(q, k, v, causal=True): - """PyTorch SDPA reference. q/k/v: (B, S, H, D) float32 -> (B, S, H, D).""" - q_t = q.transpose(1, 2).float() - k_t = k.transpose(1, 2).float() - v_t = v.transpose(1, 2).float() - out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal) - return out.transpose(1, 2) - - -def bench_gpu_us(fn, warmup=10, iters=50): - """Benchmark a GPU function, return average microseconds.""" - for _ in range(warmup): - fn() - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(iters): - fn() - end.record() - torch.cuda.synchronize() - return (start.elapsed_time(end) / iters) * 1000 - - -def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, - warmup, iters, v4_exe=None): - """Run one configuration. Returns dict with results.""" - device = "cuda" - results = {} - - if seq_len % 64 != 0: - results["err"] = f"seq_len ({seq_len}) must be divisible by 64" - return results - if head_dim % 16 != 0 or head_dim < 64: - results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" - return results - - try: - m = build_flash_attention_v4_1_module( - num_heads=num_heads, - head_dim=head_dim, - causal=causal, - dtype_str="f16", - ) - exe = flydsl.compile(m) - except Exception as e: - results["err"] = f"compile: {e}" - import traceback - traceback.print_exc() - return results - - B, S, H, D = batch, seq_len, num_heads, head_dim - q_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) - k_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) - v_4d = torch.randn(B, S, H, D, dtype=dtype, device=device) - - q_flat = q_4d.contiguous().view(-1) - k_flat = k_4d.contiguous().view(-1) - v_flat = v_4d.contiguous().view(-1) - o_flat = torch.zeros_like(q_flat) - - try: - exe(q_flat, k_flat, v_flat, o_flat, B, S) - torch.cuda.synchronize() - except Exception as e: - results["err"] = f"exec: {e}" - import traceback - traceback.print_exc() - return results - - # PyTorch reference - ref_4d = pytorch_ref_attention( - q_4d.float(), k_4d.float(), v_4d.float(), causal=causal - ).to(dtype) - ref_flat = ref_4d.contiguous().view(-1) - - # Correctness - o_f32 = o_flat.float() - ref_f32 = ref_flat.float() - max_err = (o_f32 - ref_f32).abs().max().item() - mean_err = (o_f32 - ref_f32).abs().mean().item() - cos_sim = F.cosine_similarity( - o_f32.view(-1, D), ref_f32.view(-1, D), dim=1 - ) - min_cos = cos_sim.min().item() - results["max_err"] = max_err - results["mean_err"] = mean_err - results["min_cos"] = min_cos - - atol = 1e-2 - results["passed"] = max_err < atol and min_cos > 0.99 - - # Benchmark V4.1 - try: - def kernel_fn(): - o_flat.zero_() - exe(q_flat, k_flat, v_flat, o_flat, B, S) - - us = bench_gpu_us(kernel_fn, warmup=warmup, iters=iters) - s_eff = S / 2.0 if causal else float(S) - flops = 4.0 * S * s_eff * D * H * B - tflops = flops / (us * 1e-6) / 1e12 - results["us"] = us - results["tflops"] = tflops - except Exception as e: - results["bench_err"] = str(e) - - # Benchmark V4.0 for comparison - if v4_exe is not None: - try: - o_v4 = torch.zeros_like(q_flat) - - def v4_fn(): - o_v4.zero_() - v4_exe(q_flat, k_flat, v_flat, o_v4, B, S) - - v4_us = bench_gpu_us(v4_fn, warmup=warmup, iters=iters) - v4_tflops = flops / (v4_us * 1e-6) / 1e12 - results["v4_us"] = v4_us - results["v4_tflops"] = v4_tflops - except Exception as e: - results["v4_bench_err"] = str(e) - - return results - - -def main(): - parser = argparse.ArgumentParser( - description="Flash Attention V4.1 FlyDSL Test/Benchmark" - ) - parser.add_argument("--batch", type=int, default=None) - parser.add_argument("--seq_len", type=int, default=None) - parser.add_argument("--num_heads", type=int, default=None) - parser.add_argument("--head_dim", type=int, default=None) - parser.add_argument( - "--dtype", type=str, default="fp16", choices=["fp16"] - ) - parser.add_argument("--no-causal", action="store_true") - parser.add_argument("--warmup", type=int, default=5) - parser.add_argument("--iters", type=int, default=20) - parser.add_argument("--compare-v4", action="store_true", - help="Also benchmark V4.0 for comparison") - args = parser.parse_args() - - causal = not args.no_causal - dtype = torch.float16 - causal_str = "causal" if causal else "non-causal" - - print("=" * 130) - print(f"FlyDSL Flash Attention V4.1 ({causal_str}, fp16)") - print(f" Q-in-registers, transposed V (vectorized), bank-conflict-free LDS padding") - print(f" BLOCK_M=64, BLOCK_N=16, 4 waves (256 threads), mfma_f32_16x16x16f16") - print(f"GPU: {torch.cuda.get_device_name(0)}") - print("=" * 130) - - if args.seq_len or args.head_dim or args.batch: - configs = [( - args.batch or 1, - args.seq_len or 128, - args.num_heads or 8, - args.head_dim or 128, - )] - else: - configs = [ - (1, 64, 8, 128), - (1, 128, 8, 128), - (1, 256, 32, 128), - (1, 512, 32, 128), - (2, 128, 8, 128), - ] - - # Pre-compile V4.0 if comparing - v4_exes = {} - if args.compare_v4: - from kernels.flash_attention_v4 import build_flash_attention_v4_module - for _, _, nh, hd in configs: - key = (nh, hd) - if key not in v4_exes: - try: - m = build_flash_attention_v4_module( - num_heads=nh, head_dim=hd, - causal=causal, dtype_str="f16", - ) - v4_exes[key] = flydsl.compile(m) - except Exception: - v4_exes[key] = None - - if args.compare_v4: - hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'V4.1(us)':>10s} {'V4.1 TF':>9s} | " - f"{'V4.0(us)':>10s} {'V4.0 TF':>9s} | {'Speedup':>7s}" - ) - else: - hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" - ) - print(f"\n{hdr}") - print("-" * len(hdr)) - - all_passed = True - for batch, seq_len, nh, hd in configs: - tag = f"B={batch} S={seq_len} H={nh} D={hd}" - try: - v4_exe = v4_exes.get((nh, hd)) if args.compare_v4 else None - r = run_config( - batch, seq_len, nh, hd, dtype, causal, - warmup=args.warmup, iters=args.iters, - v4_exe=v4_exe, - ) - if "err" in r: - print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") - all_passed = False - continue - - status = "PASS" if r["passed"] else "FAIL" - if not r["passed"]: - all_passed = False - - v41_us = f"{r['us']:>10.1f}" if "us" in r else " N/A" - v41_tf = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" - - if args.compare_v4 and "v4_us" in r: - v4_us = f"{r['v4_us']:>10.1f}" - v4_tf = f"{r['v4_tflops']:>9.3f}" - speedup = r["v4_us"] / r["us"] if r.get("us") else 0 - sp_s = f"{speedup:>6.2f}x" - print( - f"{tag:>38s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{v41_us} {v41_tf} | {v4_us} {v4_tf} | {sp_s}" - ) - else: - print( - f"{tag:>38s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{v41_us} {v41_tf}" - ) - except Exception as e: - print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") - all_passed = False - - print("=" * 130) - if all_passed: - print("All tests PASSED") - else: - print("Some tests FAILED") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/tests/kernels/test_flash_attention_v4_2.py b/tests/kernels/test_flash_attention_v4_2.py deleted file mode 100644 index 5a69a7a4..00000000 --- a/tests/kernels/test_flash_attention_v4_2.py +++ /dev/null @@ -1,394 +0,0 @@ -#!/usr/bin/env python3 -"""Flash Attention V4.2 kernel test and benchmark for FlyDSL. - -Tests V4.2 (BLOCK_N=32, transposed V, Q-in-registers) against PyTorch SDPA. -Optionally compares with V4.1. - -Usage: - python tests/kernels/test_flash_attention_v4_2.py - python tests/kernels/test_flash_attention_v4_2.py --seq_len 512 --head_dim 128 - python tests/kernels/test_flash_attention_v4_2.py --compare-v41 -""" - -import sys -import argparse -import hashlib -import random -from pathlib import Path -import logging - -# Configure logging to show INFO level messages (required for kernel name display) -logging.basicConfig(level=logging.INFO) - -_repo = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(_repo)) - -try: - import torch - import torch.nn.functional as F - import numpy as np -except ImportError: - print("PyTorch not available") - sys.exit(1) - -if not torch.cuda.is_available(): - print("CUDA/ROCm not available") - sys.exit(1) - -import flydsl -from kernels.flash_attention_v4_2 import build_flash_attention_v4_2_module, KERNEL_NAME -from tests.test_common import run_perftest - -# Tensor initialization range (uniform distribution) -UNIFORM_RANGE = (-1, 1) -DEFAULT_SEED = 123 - - -def setup_seed(seed: int) -> None: - """Set random seed for reproducibility across all RNG sources.""" - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - - -def pytorch_ref_attention(q, k, v, causal=True): - q_t = q.transpose(1, 2).float() - k_t = k.transpose(1, 2).float() - v_t = v.transpose(1, 2).float() - out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal) - return out.transpose(1, 2) - - -def compute_md5(tensor: torch.Tensor) -> str: - """Compute MD5 hash of a tensor's raw bytes.""" - return hashlib.md5( - tensor.contiguous().view(torch.uint8).detach().cpu().numpy().tobytes() - ).hexdigest() - - -def compare_arrays( - arr1: np.ndarray, - arr2: np.ndarray, - k: int = 5, - thresholds: list = None, -) -> dict: - """Compare two numpy arrays and compute various difference metrics. - - Args: - arr1: First input array (result), will be cast to float32. - arr2: Second input array (reference), will be cast to float32. - k: Number of top differences to report. - thresholds: Difference magnitude buckets for histogram. - - Returns: - Dictionary with top_k_diff, threshold_stats, nan_info, max_diff, max_diff_thr. - """ - if thresholds is None: - thresholds = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1] - - if arr1.shape != arr2.shape: - raise ValueError( - f"Shape mismatch: arr1 {arr1.shape} vs arr2 {arr2.shape}" - ) - - arr1 = arr1.astype(np.float32) - arr2 = arr2.astype(np.float32) - - result = {"top_k_diff": [], "threshold_stats": [], "nan_info": {}} - - # Check for NaN values - nan_mask1 = np.isnan(arr1) - nan_mask2 = np.isnan(arr2) - if np.any(nan_mask1): - result["nan_info"]["arr1_nan_count"] = int(np.sum(nan_mask1)) - print(f" Warning: result contains {result['nan_info']['arr1_nan_count']} NaN values") - if np.any(nan_mask2): - result["nan_info"]["arr2_nan_count"] = int(np.sum(nan_mask2)) - print(f" Warning: reference contains {result['nan_info']['arr2_nan_count']} NaN values") - - # Compute absolute differences - diff = np.abs(arr1 - arr2) - total_elements = arr1.size - - max_diff_thr = (diff / (1.0 + np.abs(arr2))).max() - result["max_diff"] = float(diff.max()) - result["max_diff_thr"] = float(max_diff_thr) - - print(f" diff.abs.max = {diff.max():.6f}") - print(f" diff.abs.mean = {diff.mean():.6f}") - print(f" max_diff_thr (rel) = {max_diff_thr:.6e}") - - # Find top k differences - flat_diff = diff.flatten() - actual_k = min(k, len(flat_diff)) - top_k_indices = np.argpartition(flat_diff, -actual_k)[-actual_k:] - top_k_indices = top_k_indices[np.argsort(-flat_diff[top_k_indices])] - - orig_indices = np.unravel_index(top_k_indices, diff.shape) - print(f" Top-{actual_k} differences:") - for i in range(actual_k): - idx = tuple(dim[i] for dim in orig_indices) - entry = { - "value": float(diff[idx]), - "position": idx, - "arr1_value": float(arr1[idx]), - "arr2_value": float(arr2[idx]), - } - result["top_k_diff"].append(entry) - print(f" [{idx}] result={arr1[idx]:.6f}, ref={arr2[idx]:.6f}, diff={diff[idx]:.6f}") - - # Compute threshold statistics - print(f" Threshold distribution ({total_elements} elements):") - for i in range(len(thresholds) - 1): - lower, upper = thresholds[i], thresholds[i + 1] - count = int(np.sum((diff >= lower) & (diff < upper))) - pct = 100.0 * count / total_elements - result["threshold_stats"].append( - {"range": f"[{lower:.0e}, {upper:.0e})", "count": count, "percentage": pct} - ) - print(f" [{lower:.0e}, {upper:.0e}): {count:>8d} ({pct:6.2f}%)") - - count = int(np.sum(diff >= thresholds[-1])) - pct = 100.0 * count / total_elements - result["threshold_stats"].append( - {"range": f">={thresholds[-1]:.0e}", "count": count, "percentage": pct} - ) - print(f" >={thresholds[-1]:.0e} : {count:>8d} ({pct:6.2f}%)") - - return result - - -def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, - warmup, iters, prev_exe=None, seed=DEFAULT_SEED): - device = "cuda" - results = {} - - if seq_len % 64 != 0: - results["err"] = f"seq_len ({seq_len}) must be divisible by 64" - return results - if head_dim % 16 != 0 or head_dim < 64: - results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" - return results - - try: - m = build_flash_attention_v4_2_module( - num_heads=num_heads, head_dim=head_dim, - causal=causal, dtype_str="f16", - ) - exe = flydsl.compile(m) - except Exception as e: - results["err"] = f"compile: {e}" - import traceback - traceback.print_exc() - return results - - B, S, H, D = batch, seq_len, num_heads, head_dim - setup_seed(seed) - q_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) - k_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) - v_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) - - q_flat = q_4d.contiguous().view(-1) - k_flat = k_4d.contiguous().view(-1) - v_flat = v_4d.contiguous().view(-1) - o_flat = torch.zeros_like(q_flat) - - try: - exe(q_flat, k_flat, v_flat, o_flat, B, S) - torch.cuda.synchronize() - except Exception as e: - results["err"] = f"exec: {e}" - import traceback - traceback.print_exc() - return results - - ref_4d = pytorch_ref_attention( - q_4d.float(), k_4d.float(), v_4d.float(), causal=causal - ).to(dtype) - ref_flat = ref_4d.contiguous().view(-1) - - o_f32 = o_flat.float() - ref_f32 = ref_flat.float() - max_err = (o_f32 - ref_f32).abs().max().item() - mean_err = (o_f32 - ref_f32).abs().mean().item() - cos_sim = F.cosine_similarity( - o_f32.view(-1, D), ref_f32.view(-1, D), dim=1 - ) - min_cos = cos_sim.min().item() - results["max_err"] = max_err - results["mean_err"] = mean_err - results["min_cos"] = min_cos - results["passed"] = max_err < 1e-2 and min_cos > 0.99 - - # Compute and print MD5 hashes - tag = f"B={B} S={S} H={H} D={D}" - result_md5 = compute_md5(o_flat) - ref_md5 = compute_md5(ref_flat) - print(f" [{tag}] result_md5 = {result_md5}") - print(f" [{tag}] ref_md5 = {ref_md5}") - if result_md5 == ref_md5: - print(f" [{tag}] MD5 match: EXACT (bit-identical)") - else: - print(f" [{tag}] MD5 match: DIFFER (not bit-identical)") - - # Detailed comparison using compare_arrays - print(f" [{tag}] --- compare_arrays ---") - compare_arrays( - o_flat.to(torch.float32).detach().cpu().numpy(), - ref_flat.to(torch.float32).detach().cpu().numpy(), - ) - - try: - def kernel_fn(): - o_flat.zero_() - exe(q_flat, k_flat, v_flat, o_flat, B, S) - - _, us = run_perftest(kernel_fn, num_iters=iters, num_warmup=warmup) - s_eff = S / 2.0 if causal else float(S) - flops = 4.0 * S * s_eff * D * H * B - tflops = flops / (us * 1e-6) / 1e12 - results["us"] = us - results["tflops"] = tflops - except Exception as e: - results["bench_err"] = str(e) - - if prev_exe is not None: - try: - o_prev = torch.zeros_like(q_flat) - def prev_fn(): - o_prev.zero_() - prev_exe(q_flat, k_flat, v_flat, o_prev, B, S) - _, prev_us = run_perftest(prev_fn, num_iters=iters, num_warmup=warmup) - prev_tflops = flops / (prev_us * 1e-6) / 1e12 - results["prev_us"] = prev_us - results["prev_tflops"] = prev_tflops - except Exception as e: - results["prev_bench_err"] = str(e) - - return results - - -def main(): - parser = argparse.ArgumentParser( - description="Flash Attention V4.2 FlyDSL Test/Benchmark" - ) - parser.add_argument("--batch", type=int, default=None) - parser.add_argument("--seq_len", type=int, default=None) - parser.add_argument("--num_heads", type=int, default=None) - parser.add_argument("--head_dim", type=int, default=None) - parser.add_argument("--no-causal", action="store_true") - parser.add_argument("--warmup", type=int, default=5) - parser.add_argument("--iters", type=int, default=20) - parser.add_argument("--compare-v41", action="store_true", - help="Also benchmark V4.1 for comparison") - parser.add_argument("--seed", type=int, default=DEFAULT_SEED, - help=f"Random seed for reproducibility (default: {DEFAULT_SEED})") - args = parser.parse_args() - - causal = not args.no_causal - dtype = torch.float16 - - print("=" * 130) - print(f"FlyDSL Flash Attention V4.2 ({'causal' if causal else 'non-causal'}, fp16)") - print(f" BLOCK_N=32, Q-in-registers, transposed V, bank-conflict-free LDS") - print(f" BLOCK_M=64, 4 waves (256 threads), mfma_f32_16x16x16f16") - print(f"GPU: {torch.cuda.get_device_name(0)}") - print("=" * 130) - - if args.seq_len or args.head_dim or args.batch: - configs = [( - args.batch or 1, - args.seq_len or 128, - args.num_heads or 8, - args.head_dim or 128, - )] - else: - configs = [ - (1, 64, 8, 128), - (1, 128, 8, 128), - (1, 256, 32, 128), - (1, 512, 32, 128), - (2, 128, 8, 128), - ] - - prev_exes = {} - if args.compare_v41: - from kernels.flash_attention_v4_1 import build_flash_attention_v4_1_module - for _, _, nh, hd in configs: - key = (nh, hd) - if key not in prev_exes: - try: - m = build_flash_attention_v4_1_module( - num_heads=nh, head_dim=hd, - causal=causal, dtype_str="f16", - ) - prev_exes[key] = flydsl.compile(m) - except Exception: - prev_exes[key] = None - - if args.compare_v41: - hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'V4.2(us)':>10s} {'V4.2 TF':>9s} | " - f"{'V4.1(us)':>10s} {'V4.1 TF':>9s} | {'Speedup':>7s}" - ) - else: - hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" - ) - print(f"\n{hdr}") - print("-" * len(hdr)) - - all_passed = True - for batch, seq_len, nh, hd in configs: - tag = f"B={batch} S={seq_len} H={nh} D={hd}" - try: - prev_exe = prev_exes.get((nh, hd)) if args.compare_v41 else None - r = run_config( - batch, seq_len, nh, hd, dtype, causal, - warmup=args.warmup, iters=args.iters, - prev_exe=prev_exe, seed=args.seed, - ) - if "err" in r: - print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") - all_passed = False - continue - - status = "PASS" if r["passed"] else "FAIL" - if not r["passed"]: - all_passed = False - - us_s = f"{r['us']:>10.1f}" if "us" in r else " N/A" - tf_s = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" - - if args.compare_v41 and "prev_us" in r: - p_us = f"{r['prev_us']:>10.1f}" - p_tf = f"{r['prev_tflops']:>9.3f}" - speedup = r["prev_us"] / r["us"] if r.get("us") else 0 - print( - f"{tag:>38s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{us_s} {tf_s} | {p_us} {p_tf} | {speedup:>6.2f}x" - ) - else: - print( - f"{tag:>38s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{us_s} {tf_s}" - ) - except Exception as e: - print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") - all_passed = False - - print("=" * 130) - if all_passed: - print("All tests PASSED") - else: - print("Some tests FAILED") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/tests/kernels/test_flash_attention_v4_3.py b/tests/kernels/test_flash_attention_v4_3.py deleted file mode 100644 index 28cd1e97..00000000 --- a/tests/kernels/test_flash_attention_v4_3.py +++ /dev/null @@ -1,394 +0,0 @@ -#!/usr/bin/env python3 -"""Flash Attention V4.3 kernel test and benchmark for FlyDSL. - -Tests V4.3 (LDS overlay) against PyTorch SDPA. -Optionally compares with V4.2. - -Usage: - python tests/kernels/test_flash_attention_v4_3.py - python tests/kernels/test_flash_attention_v4_3.py --seq_len 512 --head_dim 128 - python tests/kernels/test_flash_attention_v4_3.py --compare-v42 -""" - -import sys -import argparse -import hashlib -import random -from pathlib import Path -import logging - -# Configure logging to show INFO level messages (required for kernel name display) -logging.basicConfig(level=logging.INFO) - -_repo = Path(__file__).resolve().parents[2] -sys.path.insert(0, str(_repo)) - -try: - import torch - import torch.nn.functional as F - import numpy as np -except ImportError: - print("PyTorch not available") - sys.exit(1) - -if not torch.cuda.is_available(): - print("CUDA/ROCm not available") - sys.exit(1) - -import flydsl -from kernels.flash_attention_v4_3 import build_flash_attention_v4_3_module, KERNEL_NAME -from tests.test_common import run_perftest - -# Tensor initialization range (uniform distribution) -UNIFORM_RANGE = (-1, 1) -DEFAULT_SEED = 123 - - -def setup_seed(seed: int) -> None: - """Set random seed for reproducibility across all RNG sources.""" - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - - -def pytorch_ref_attention(q, k, v, causal=True): - q_t = q.transpose(1, 2).float() - k_t = k.transpose(1, 2).float() - v_t = v.transpose(1, 2).float() - out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal) - return out.transpose(1, 2) - - -def compute_md5(tensor: torch.Tensor) -> str: - """Compute MD5 hash of a tensor's raw bytes.""" - return hashlib.md5( - tensor.contiguous().view(torch.uint8).detach().cpu().numpy().tobytes() - ).hexdigest() - - -def compare_arrays( - arr1: np.ndarray, - arr2: np.ndarray, - k: int = 5, - thresholds: list = None, -) -> dict: - """Compare two numpy arrays and compute various difference metrics. - - Args: - arr1: First input array (result), will be cast to float32. - arr2: Second input array (reference), will be cast to float32. - k: Number of top differences to report. - thresholds: Difference magnitude buckets for histogram. - - Returns: - Dictionary with top_k_diff, threshold_stats, nan_info, max_diff, max_diff_thr. - """ - if thresholds is None: - thresholds = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1] - - if arr1.shape != arr2.shape: - raise ValueError( - f"Shape mismatch: arr1 {arr1.shape} vs arr2 {arr2.shape}" - ) - - arr1 = arr1.astype(np.float32) - arr2 = arr2.astype(np.float32) - - result = {"top_k_diff": [], "threshold_stats": [], "nan_info": {}} - - # Check for NaN values - nan_mask1 = np.isnan(arr1) - nan_mask2 = np.isnan(arr2) - if np.any(nan_mask1): - result["nan_info"]["arr1_nan_count"] = int(np.sum(nan_mask1)) - print(f" Warning: result contains {result['nan_info']['arr1_nan_count']} NaN values") - if np.any(nan_mask2): - result["nan_info"]["arr2_nan_count"] = int(np.sum(nan_mask2)) - print(f" Warning: reference contains {result['nan_info']['arr2_nan_count']} NaN values") - - # Compute absolute differences - diff = np.abs(arr1 - arr2) - total_elements = arr1.size - - max_diff_thr = (diff / (1.0 + np.abs(arr2))).max() - result["max_diff"] = float(diff.max()) - result["max_diff_thr"] = float(max_diff_thr) - - print(f" diff.abs.max = {diff.max():.6f}") - print(f" diff.abs.mean = {diff.mean():.6f}") - print(f" max_diff_thr (rel) = {max_diff_thr:.6e}") - - # Find top k differences - flat_diff = diff.flatten() - actual_k = min(k, len(flat_diff)) - top_k_indices = np.argpartition(flat_diff, -actual_k)[-actual_k:] - top_k_indices = top_k_indices[np.argsort(-flat_diff[top_k_indices])] - - orig_indices = np.unravel_index(top_k_indices, diff.shape) - print(f" Top-{actual_k} differences:") - for i in range(actual_k): - idx = tuple(dim[i] for dim in orig_indices) - entry = { - "value": float(diff[idx]), - "position": idx, - "arr1_value": float(arr1[idx]), - "arr2_value": float(arr2[idx]), - } - result["top_k_diff"].append(entry) - print(f" [{idx}] result={arr1[idx]:.6f}, ref={arr2[idx]:.6f}, diff={diff[idx]:.6f}") - - # Compute threshold statistics - print(f" Threshold distribution ({total_elements} elements):") - for i in range(len(thresholds) - 1): - lower, upper = thresholds[i], thresholds[i + 1] - count = int(np.sum((diff >= lower) & (diff < upper))) - pct = 100.0 * count / total_elements - result["threshold_stats"].append( - {"range": f"[{lower:.0e}, {upper:.0e})", "count": count, "percentage": pct} - ) - print(f" [{lower:.0e}, {upper:.0e}): {count:>8d} ({pct:6.2f}%)") - - count = int(np.sum(diff >= thresholds[-1])) - pct = 100.0 * count / total_elements - result["threshold_stats"].append( - {"range": f">={thresholds[-1]:.0e}", "count": count, "percentage": pct} - ) - print(f" >={thresholds[-1]:.0e} : {count:>8d} ({pct:6.2f}%)") - - return result - - -def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, - warmup, iters, prev_exe=None, seed=DEFAULT_SEED): - device = "cuda" - results = {} - - if seq_len % 64 != 0: - results["err"] = f"seq_len ({seq_len}) must be divisible by 64" - return results - if head_dim % 16 != 0 or head_dim < 64: - results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 16" - return results - - try: - m = build_flash_attention_v4_3_module( - num_heads=num_heads, head_dim=head_dim, - causal=causal, dtype_str="f16", - ) - exe = flydsl.compile(m) - except Exception as e: - results["err"] = f"compile: {e}" - import traceback - traceback.print_exc() - return results - - B, S, H, D = batch, seq_len, num_heads, head_dim - setup_seed(seed) - q_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) - k_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) - v_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE) - - q_flat = q_4d.contiguous().view(-1) - k_flat = k_4d.contiguous().view(-1) - v_flat = v_4d.contiguous().view(-1) - o_flat = torch.zeros_like(q_flat) - - try: - exe(q_flat, k_flat, v_flat, o_flat, B, S) - torch.cuda.synchronize() - except Exception as e: - results["err"] = f"exec: {e}" - import traceback - traceback.print_exc() - return results - - ref_4d = pytorch_ref_attention( - q_4d.float(), k_4d.float(), v_4d.float(), causal=causal - ).to(dtype) - ref_flat = ref_4d.contiguous().view(-1) - - o_f32 = o_flat.float() - ref_f32 = ref_flat.float() - max_err = (o_f32 - ref_f32).abs().max().item() - mean_err = (o_f32 - ref_f32).abs().mean().item() - cos_sim = F.cosine_similarity( - o_f32.view(-1, D), ref_f32.view(-1, D), dim=1 - ) - min_cos = cos_sim.min().item() - results["max_err"] = max_err - results["mean_err"] = mean_err - results["min_cos"] = min_cos - results["passed"] = max_err < 1e-2 and min_cos > 0.99 - - # Compute and print MD5 hashes - tag = f"B={B} S={S} H={H} D={D}" - result_md5 = compute_md5(o_flat) - ref_md5 = compute_md5(ref_flat) - print(f" [{tag}] result_md5 = {result_md5}") - print(f" [{tag}] ref_md5 = {ref_md5}") - if result_md5 == ref_md5: - print(f" [{tag}] MD5 match: EXACT (bit-identical)") - else: - print(f" [{tag}] MD5 match: DIFFER (not bit-identical)") - - # Detailed comparison using compare_arrays - print(f" [{tag}] --- compare_arrays ---") - compare_arrays( - o_flat.to(torch.float32).detach().cpu().numpy(), - ref_flat.to(torch.float32).detach().cpu().numpy(), - ) - - try: - def kernel_fn(): - # o_flat.zero_() - exe(q_flat, k_flat, v_flat, o_flat, B, S) - - _, us = run_perftest(kernel_fn, num_iters=iters, num_warmup=warmup) - s_eff = S / 2.0 if causal else float(S) - flops = 4.0 * S * s_eff * D * H * B - tflops = flops / (us * 1e-6) / 1e12 - results["us"] = us - results["tflops"] = tflops - except Exception as e: - results["bench_err"] = str(e) - - if prev_exe is not None: - try: - o_prev = torch.zeros_like(q_flat) - def prev_fn(): - # o_prev.zero_() - prev_exe(q_flat, k_flat, v_flat, o_prev, B, S) - _, prev_us = run_perftest(prev_fn, num_iters=iters, num_warmup=warmup) - prev_tflops = flops / (prev_us * 1e-6) / 1e12 - results["prev_us"] = prev_us - results["prev_tflops"] = prev_tflops - except Exception as e: - results["prev_bench_err"] = str(e) - - return results - - -def main(): - parser = argparse.ArgumentParser( - description="Flash Attention V4.3 FlyDSL Test/Benchmark" - ) - parser.add_argument("--batch", type=int, default=None) - parser.add_argument("--seq_len", type=int, default=None) - parser.add_argument("--num_heads", type=int, default=None) - parser.add_argument("--head_dim", type=int, default=None) - parser.add_argument("--no-causal", action="store_true") - parser.add_argument("--warmup", type=int, default=5) - parser.add_argument("--iters", type=int, default=20) - parser.add_argument("--compare-v42", action="store_true", - help="Also benchmark V4.2 for comparison") - parser.add_argument("--seed", type=int, default=DEFAULT_SEED, - help=f"Random seed for reproducibility (default: {DEFAULT_SEED})") - args = parser.parse_args() - - causal = not args.no_causal - dtype = torch.float16 - - print("=" * 130) - print(f"FlyDSL Flash Attention V4.3 ({'causal' if causal else 'non-causal'}, fp16)") - print(f" LDS overlay: Q space reused for KV+P (16KB vs 29KB)") - print(f" BLOCK_M=64, BLOCK_N=32, 4 waves (256 threads), mfma_f32_16x16x16f16") - print(f"GPU: {torch.cuda.get_device_name(0)}") - print("=" * 130) - - if args.seq_len or args.head_dim or args.batch: - configs = [( - args.batch or 1, - args.seq_len or 128, - args.num_heads or 8, - args.head_dim or 128, - )] - else: - configs = [ - (1, 64, 8, 128), - (1, 128, 8, 128), - (1, 256, 32, 128), - (1, 512, 32, 128), - (2, 128, 8, 128), - ] - - prev_exes = {} - if args.compare_v42: - from kernels.flash_attention_v4_2 import build_flash_attention_v4_2_module - for _, _, nh, hd in configs: - key = (nh, hd) - if key not in prev_exes: - try: - m = build_flash_attention_v4_2_module( - num_heads=nh, head_dim=hd, - causal=causal, dtype_str="f16", - ) - prev_exes[key] = flydsl.compile(m) - except Exception: - prev_exes[key] = None - - if args.compare_v42: - hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'V4.3(us)':>10s} {'V4.3 TF':>9s} | " - f"{'V4.2(us)':>10s} {'V4.2 TF':>9s} | {'Speedup':>7s}" - ) - else: - hdr = ( - f"{'Config':>38s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" - ) - print(f"\n{hdr}") - print("-" * len(hdr)) - - all_passed = True - for batch, seq_len, nh, hd in configs: - tag = f"B={batch} S={seq_len} H={nh} D={hd}" - try: - prev_exe = prev_exes.get((nh, hd)) if args.compare_v42 else None - r = run_config( - batch, seq_len, nh, hd, dtype, causal, - warmup=args.warmup, iters=args.iters, - prev_exe=prev_exe, seed=args.seed, - ) - if "err" in r: - print(f"{tag:>38s} | {'ERROR':>6s} | {r['err'][:60]}") - all_passed = False - continue - - status = "PASS" if r["passed"] else "FAIL" - if not r["passed"]: - all_passed = False - - us_s = f"{r['us']:>10.1f}" if "us" in r else " N/A" - tf_s = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" - - if args.compare_v42 and "prev_us" in r: - p_us = f"{r['prev_us']:>10.1f}" - p_tf = f"{r['prev_tflops']:>9.3f}" - speedup = r["prev_us"] / r["us"] if r.get("us") else 0 - print( - f"{tag:>38s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{us_s} {tf_s} | {p_us} {p_tf} | {speedup:>6.2f}x" - ) - else: - print( - f"{tag:>38s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{us_s} {tf_s}" - ) - except Exception as e: - print(f"{tag:>38s} | {'ERROR':>6s} | {str(e)[:60]}") - all_passed = False - - print("=" * 130) - if all_passed: - print("All tests PASSED") - else: - print("Some tests FAILED") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/tests/kernels/test_flash_attn_func.py b/tests/kernels/test_flash_attn_func.py index 6748a250..897b06cd 100644 --- a/tests/kernels/test_flash_attn_func.py +++ b/tests/kernels/test_flash_attn_func.py @@ -2,7 +2,6 @@ """flash_attn_func kernel test and benchmark for FlyDSL. Tests flash_attn_func against PyTorch SDPA. -Optionally compares with V4.3. """ import sys @@ -162,7 +161,9 @@ def compare_arrays( return result -def run_config(batch, seq_len, num_heads, head_dim, dtype, causal, warmup, iters, prev_exe=None, seed=DEFAULT_SEED): +def run_config( + batch, seq_len, num_heads, head_dim, dtype, causal, warmup, iters, seed=DEFAULT_SEED +): device = "cuda" results = {} active_path = select_flash_attn_func_path( @@ -255,20 +256,6 @@ def kernel_fn(): except Exception as e: results["bench_err"] = str(e) - if prev_exe is not None: - try: - o_prev = torch.zeros_like(q_flat) - - def prev_fn(): - prev_exe(q_flat, k_flat, v_flat, o_prev, B, S) - - _, prev_us = run_perftest(prev_fn, num_iters=iters, num_warmup=warmup) - prev_tflops = flops / (prev_us * 1e-6) / 1e12 - results["prev_us"] = prev_us - results["prev_tflops"] = prev_tflops - except Exception as e: - results["prev_bench_err"] = str(e) - return results @@ -281,7 +268,6 @@ def main(): parser.add_argument("--no-causal", action="store_true") parser.add_argument("--warmup", type=int, default=5) parser.add_argument("--iters", type=int, default=20) - parser.add_argument("--compare-v43", action="store_true", help="Also benchmark V4.3 for comparison") parser.add_argument( "--seed", type=int, default=DEFAULT_SEED, help=f"Random seed for reproducibility (default: {DEFAULT_SEED})" ) @@ -309,32 +295,10 @@ def main(): (1, 8192, 64, 128), ] - prev_exes = {} - if args.compare_v43: - from kernels.flash_attention_v4_3 import build_flash_attention_v4_3_module - - for _, _, nh, hd in configs: - key = (nh, hd) - if key not in prev_exes: - try: - m = build_flash_attention_v4_3_module( - num_heads=nh, head_dim=hd, causal=causal, dtype_str="f16" - ) - prev_exes[key] = flydsl.compile(m) - except Exception: - prev_exes[key] = None - - if args.compare_v43: - hdr = ( - f"{'Config/Path':>56s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'Func(us)':>10s} {'Func TF':>9s} | " - f"{'V4.3(us)':>10s} {'V4.3 TF':>9s} | {'Speedup':>7s}" - ) - else: - hdr = ( - f"{'Config/Path':>56s} | {'Status':>6s} | {'MaxErr':>8s} " - f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" - ) + hdr = ( + f"{'Config/Path':>56s} | {'Status':>6s} | {'MaxErr':>8s} " + f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}" + ) print(f"\n{hdr}") print("-" * len(hdr)) @@ -342,7 +306,6 @@ def main(): for batch, seq_len, nh, hd in configs: tag = f"B={batch} S={seq_len} H={nh} D={hd}" try: - prev_exe = prev_exes.get((nh, hd)) if args.compare_v43 else None r = run_config( batch, seq_len, @@ -352,7 +315,6 @@ def main(): causal, warmup=args.warmup, iters=args.iters, - prev_exe=prev_exe, seed=args.seed, ) if "err" in r: @@ -368,22 +330,11 @@ def main(): us_s = f"{r['us']:>10.1f}" if "us" in r else " N/A" tf_s = f"{r['tflops']:>9.3f}" if "tflops" in r else " N/A" - - if args.compare_v43 and "prev_us" in r: - p_us = f"{r['prev_us']:>10.1f}" - p_tf = f"{r['prev_tflops']:>9.3f}" - speedup = r["prev_us"] / r["us"] if r.get("us") else 0 - print( - f"{cfg_path:>56s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{us_s} {tf_s} | {p_us} {p_tf} | {speedup:>6.2f}x" - ) - else: - print( - f"{cfg_path:>56s} | {status:>6s} | " - f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " - f"{us_s} {tf_s}" - ) + print( + f"{cfg_path:>56s} | {status:>6s} | " + f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | " + f"{us_s} {tf_s}" + ) except Exception as e: print(f"{tag:>56s} | {'ERROR':>6s} | {str(e)[:60]}") all_passed = False From 1c49a1c1c9f5cba29da6da37ad90beef80e65af3 Mon Sep 17 00:00:00 2001 From: yanguahe Date: Fri, 13 Feb 2026 19:38:30 +0800 Subject: [PATCH 14/17] Remove temp file --- input.yaml | 17 - kernels/simple_gemm.py | 615 ------------------------ run.sh | 102 ---- tests/kernels/test_moe_stage1_simple.py | 193 -------- tests/kernels/test_simple_gemm.py | 518 -------------------- thread_trace/.gitkeep | 0 6 files changed, 1445 deletions(-) delete mode 100644 input.yaml delete mode 100644 kernels/simple_gemm.py delete mode 100755 run.sh delete mode 100644 tests/kernels/test_moe_stage1_simple.py delete mode 100644 tests/kernels/test_simple_gemm.py delete mode 100644 thread_trace/.gitkeep diff --git a/input.yaml b/input.yaml deleted file mode 100644 index 8477d1e0..00000000 --- a/input.yaml +++ /dev/null @@ -1,17 +0,0 @@ -jobs: - - - kernel_include_regex: (kernel_gemm) - kernel_exclude_regex: - kernel_iteration_range: "[1]" - output_file: out - output_directory: thread_trace/rpf_v3 - output_format: [csv] - truncate_kernels: false - sys_trace: false # enable for pftrace and otf2 - advanced_thread_trace: true # enable for att and ui folder - att_target_cu: 1 - att_shader_engine_mask: "0xf" # collect one CU from 4 SEs - att_simd_select: "0xf" # collect 4 SIMDs on single CU - att_buffer_size: "0x6000000" - - - pmc: [SQ_WAVES, FETCH_SIZE] diff --git a/kernels/simple_gemm.py b/kernels/simple_gemm.py deleted file mode 100644 index 0adde3f4..00000000 --- a/kernels/simple_gemm.py +++ /dev/null @@ -1,615 +0,0 @@ -"""Simple GEMM kernel implementation using FlyDSL (MFMA 16x16x16). - -This module provides a simple GEMM kernel (C = A × B^T) for AMD GPUs using MFMA instructions. - -Configuration: -- Block: 256 threads = 4 waves × 64 lanes -- Tile: M=16 × N=64 × K=128 (configurable) -- Currently supports bf16/fp16 input, f32 accumulator, bf16/fp16 output - -Non-aligned shape handling (Triton-like approach): -- M and N: mask-based loads/stores in kernel (no host padding needed) -- K: padded to tile_k on host (required for MFMA vector loads) -- num_records_bytes: explicitly set in buffer resource descriptor for hardware OOB - -A matrix loading: -- GM → GPR → LDS: 256 threads cooperatively load the A tile with XOR16 swizzle -- Mask-based: OOB elements load zeros via buffer descriptor bounds checking - -B matrix loading: -- Direct load: Each wave handles 16 columns of N -- Mask-based: OOB elements load zeros via buffer descriptor bounds checking - -Output C matrix (16×64): -- Wave 0 → C[0:16, 0:16] -- Wave 1 → C[0:16, 16:32] -- Wave 2 → C[0:16, 32:48] -- Wave 3 → C[0:16, 48:64] -- Mask-based stores: OOB stores are skipped via select(mask, offset, MAX_OFFSET) -""" - -import functools - -import flydsl -from flydsl.dialects.ext import flir -from flydsl.dialects.ext.python_control_flow import range_constexpr -from flydsl.runtime.device import get_rocm_arch as get_hip_arch -from flydsl.utils import SmemAllocator - -from _mlir import ir - -from flydsl.dialects.ext import arith, gpu, buffer_ops, vector, rocdl -from flydsl.lang.ir.types import T, memref - - -def _align_up(val: int, align: int) -> int: - """Round up val to the next multiple of align.""" - return ((val + align - 1) // align) * align - - -@functools.lru_cache(maxsize=1024) -def compile_simple_gemm( - *, - tile_m: int = 16, - tile_n: int = 64, - tile_k: int = 128, - in_dtype: str = "bf16", - waves_per_eu: int = None, -): - """Compile a simple GEMM kernel and return the compiled executable. - - This kernel supports non-aligned M, N, K dimensions via mask-based loads/stores. - No host-side padding required. - - Args: - tile_m, tile_n, tile_k: Block tile sizes. - in_dtype: Input data type ("bf16" or "fp16"). - waves_per_eu: Optional hint for AMDGPU backend about the desired number of waves - per execution unit. This affects occupancy optimization. - """ - if in_dtype not in ("bf16", "fp16"): - raise ValueError(f"in_dtype must be 'bf16' or 'fp16', got {in_dtype!r}") - - is_bf16 = in_dtype == "bf16" - elem_bytes = 2 # bf16 and fp16 are both 2 bytes - out_elem_bytes = 2 # output is also bf16/fp16 - - # Validate tile configuration - tile_k_bytes = tile_k * elem_bytes - if tile_k_bytes % 64 != 0: - raise ValueError( - f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes}" - ) - - gpu_arch = get_hip_arch() - allocator = SmemAllocator(None, arch=gpu_arch) - _state = {} - - DYN = ir.ShapedType.get_dynamic_size() - total_threads = 256 - - # LDS configuration: tile_m × tile_k elements - lds_stride = tile_k # No padding for simplicity - - # Type helpers - def _elem_type(): - return T.bf16 if is_bf16 else T.f16 - - def _vec8_type(): - """16B vector (8 bf16/fp16 elements).""" - return T.bf16x8 if is_bf16 else T.f16x8 - - def _out_type(): - """Output element type.""" - return T.bf16 if is_bf16 else T.f16 - - module_name = f"simple_gemm_{in_dtype}_t{tile_m}x{tile_n}x{tile_k}".replace("-", "_") - - class _GEMM(flir.MlirModule): - GPU_MODULE_NAME = module_name - GPU_MODULE_TARGETS = [ - f'#rocdl.target' - ] - - def init_gpu_module(self): - # Allocate LDS for A tile: tile_m × tile_k elements - lds_a_elems = tile_m * lds_stride - _state["lds_a_decl"] = allocator.allocate_array(_elem_type(), lds_a_elems) - allocator.finalize() - - @flir.kernel - def kernel_gemm( - self: flir.T.i64, - arg_c: lambda: memref(DYN, _out_type()), - arg_a: lambda: memref(DYN, _elem_type()), - arg_b: lambda: memref(DYN, _elem_type()), - c_m: lambda: T.index, # Original M dimension - c_n: lambda: T.index, # Original N dimension - c_k: lambda: T.index, # Original K dimension - ): - # ================= Types ================= - f32 = T.f32 - i32 = T.i32 - i64 = T.i64 - vec4_f32 = T.f32x4 - vec4_i16 = T.i16x4 - vec4_f16 = T.f16x4 - vec8_elem = _vec8_type() - vec1_i64 = T.vec(1, i64) - vec2_i64 = T.vec(2, i64) - - # Accumulator initialization - acc_init = arith.constant_vector(0.0, vec4_f32) - - # ================= Buffer sizes in bytes for OOB handling ================= - # A: [M, K] -> M * K * elem_bytes - a_nbytes_idx = c_m * c_k * arith.constant(elem_bytes, index=True) - a_nbytes_i32 = arith.index_cast(i32, a_nbytes_idx) - - # B: [N, K] -> N * K * elem_bytes - b_nbytes_idx = c_n * c_k * arith.constant(elem_bytes, index=True) - b_nbytes_i32 = arith.index_cast(i32, b_nbytes_idx) - - # C: [M, N] -> M * N * out_elem_bytes - c_nbytes_idx = c_m * c_n * arith.constant(out_elem_bytes, index=True) - c_nbytes_i32 = arith.index_cast(i32, c_nbytes_idx) - - # ================= Buffer Resources with explicit sizes ================= - a_rsrc = buffer_ops.create_buffer_resource(arg_a, max_size=False, num_records_bytes=a_nbytes_i32) - b_rsrc = buffer_ops.create_buffer_resource(arg_b, max_size=False, num_records_bytes=b_nbytes_i32) - c_rsrc = buffer_ops.create_buffer_resource(arg_c, max_size=False, num_records_bytes=c_nbytes_i32) - - # ================= Layouts ================= - # A layout: [M, K] row-major - layout_a = flir.make_layout((c_m, c_k), stride=(c_k, 1)) - - # B layout: [N, K] row-major (B^T in standard GEMM) - layout_b = flir.make_layout((c_n, c_k), stride=(c_k, 1)) - - # C layout: [M, N] row-major - layout_c = flir.make_layout((c_m, c_n), stride=(c_n, 1)) - - # LDS layout: [tile_m, tile_k] - shape_lds = flir.make_shape(tile_m, tile_k) - stride_lds = flir.make_stride(lds_stride, 1) - layout_lds = flir.make_layout(shape_lds, stride_lds) - - # XOR16 swizzle parameter (in 16-byte blocks) - k_blocks16 = arith.constant(tile_k_bytes // 16, index=True) - - # ================= Thread/Block IDs ================= - tx = gpu.thread_id("x") - bx = gpu.block_id("x") # M dimension - by = gpu.block_id("y") # N dimension - - # Base addresses for this block - bx_m = bx * arith.constant(tile_m, index=True) - by_n = by * arith.constant(tile_n, index=True) - - # ================= Thread Decomposition ================= - # tx -> (wave_id, lane_id) - layout_wave_lane = flir.make_layout((4, 64), stride=(64, 1)) - coord_wave_lane = flir.idx2crd(tx, layout_wave_lane) - wave_id = flir.get(coord_wave_lane, 0) - lane_id = flir.get(coord_wave_lane, 1) - - # lane_id -> (lane_div_16, lane_mod_16) - layout_lane16 = flir.make_layout((4, 16), stride=(16, 1)) - coord_lane16 = flir.idx2crd(lane_id, layout_lane16) - lane_div_16 = flir.get(coord_lane16, 0) - lane_mod_16 = flir.get(coord_lane16, 1) - - # ================= LDS Setup ================= - base_ptr = allocator.get_base() - lds_a_ptr = _state["lds_a_decl"](base_ptr) - lds_a = lds_a_ptr.get() - - # ================= Wave/Lane Mappings ================= - # For MFMA 16x16x16: - # - A row index: lane_mod_16 (0..15) - # - K pack offset: lane_div_16 * 4 (each lane group handles 4 elements) - row_a_lds = lane_mod_16 - - # K element offset for LDS reads (16 elements per pack, 4 packs per K64) - kpack_elems = 8 # 8 bf16 = 16 bytes - col_offset_base = lane_div_16 * arith.constant(kpack_elems, index=True) - # Convert to bytes for swizzle - col_offset_base_bytes = col_offset_base * arith.constant(elem_bytes, index=True) - - # ================= Tile Configuration ================= - m_repeat = tile_m // 16 # Number of M-dimension repeats - k_unroll = tile_k_bytes // 64 # K64-byte micro-steps - num_waves = 4 - n_per_wave = tile_n // num_waves # Columns per wave - num_acc_n = n_per_wave // 16 # Accumulators per wave along N - - # Wave's N tile base - c_n_per_wave = arith.constant(n_per_wave, index=True) - n_tile_base = wave_id * c_n_per_wave - - # ================= A Tile Loading (GM -> LDS) with mask ================= - # 256 threads load tile_m × tile_k elements (16B per thread) - bytes_a_per_tile = tile_m * tile_k * elem_bytes - bytes_per_thread_a = bytes_a_per_tile // total_threads - num_a_loads = bytes_per_thread_a // 16 # 16B loads - - # A tile layout in dwords for addressing - tile_k_dwords = (tile_k * elem_bytes) // 4 - layout_a_tile_div4 = flir.make_layout( - (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) - ) - - c4 = arith.constant(4, index=True) - tx_i32_base = tx * c4 - - atom_a_lds = flir.make_copy_atom(_elem_type(), vector_size=8) - - def a_tile_chunk_coord(i: int): - """Map (thread, chunk_id) -> (row_local, col_local_i32) for A loads.""" - chunk_off = arith.constant(i * total_threads * 4, index=True) - tile_idx = tx_i32_base + chunk_off - coord_local = flir.idx2crd(tile_idx, layout_a_tile_div4) - row_local = flir.get(coord_local, 0) - col_local_i32 = flir.get(coord_local, 1) - return row_local, col_local_i32 - - def load_a_tile(base_k): - """Load A tile from global memory (tile_m × tile_k) with mask.""" - parts = [] - for i in range_constexpr(num_a_loads): - row_local, col_local_i32 = a_tile_chunk_coord(i) - row_global = bx_m + row_local - # col_local_i32 is in dwords (4 bytes), convert to elements - col_local_elem = col_local_i32 * arith.constant(2, index=True) # 2 bf16 per dword - k_global = base_k + col_local_elem - - # Calculate linear element offset for buffer_load - # buffer_load expects offset in elements (i32 unit), it will scale to bytes internally - # offset = row_global * K + k_global (in dword units for vec4 i32 load) - offset_elem = row_global * c_k + k_global - # Convert to dword offset (divide by 2 since 2 bf16 per dword) - offset_dword = offset_elem / arith.constant(2, index=True) - offset_i32 = arith.index_cast(i32, offset_dword) - - # Mask: row_global < M (K is guaranteed to be padded to tile_k) - row_valid = arith.cmpu(row_global, c_m, "ult") - - # Load 4 dwords (16 bytes = 8 bf16 elements) with mask - a_i32x4 = buffer_ops.buffer_load(a_rsrc, offset_i32, vec_width=4, dtype=i32, mask=row_valid) - parts.append(a_i32x4) - return parts - - def store_a_tile_to_lds(a_parts): - """Store A tile to LDS with XOR16 swizzle.""" - for i in range_constexpr(num_a_loads): - row_local, col_local_i32 = a_tile_chunk_coord(i) - # Apply XOR16 swizzle - col_local_bytes = col_local_i32 * c4 - col_swz_bytes = flir.swizzle_xor16(row_local, col_local_bytes, k_blocks16) - col_swz = col_swz_bytes / arith.constant(elem_bytes, index=True) - coord_store = flir.make_coord(row_local, col_swz) - idx0 = flir.crd2idx(coord_store, layout_lds) - v8 = vector.bitcast(vec8_elem, a_parts[i]) - s_view = flir.TensorView( - lds_a, - (8,), - strides=(1,), - base_indices=(idx0,), - element_type=_elem_type(), - ) - flir.copy(atom_a_lds, v8, s_view, alignment=16) - - # ================= B Tile Loading (Direct to GPR) with mask ================= - def load_b_packs_k64(base_k, ku: int, ni: int): - """Load B pack for MFMA (16B -> 2 × i64 for K64-byte step) with mask.""" - # Global N index for this wave/lane - n_offset = arith.constant(ni * 16, index=True) - n_global = by_n + n_tile_base + n_offset + lane_mod_16 - - # K index within the K64 block - ki64 = arith.constant(ku * 32, index=True) # 64 bytes = 32 bf16 - k_base = base_k + ki64 - - # lane_div_16 determines which 8 elements to load (0-3 -> 0, 8, 16, 24 offset) - k_lane_offset = lane_div_16 * arith.constant(8, index=True) - k_global = k_base + k_lane_offset - - # Calculate linear element offset for buffer_load - # buffer_load with dtype=i32 scales offset by 4, so we need dword offset - # offset_elem = n_global * K + k_global (in bf16 elements) - # offset_dword = offset_elem / 2 (in i32 dwords) - offset_elem = n_global * c_k + k_global - offset_dword = offset_elem / arith.constant(2, index=True) - offset_i32 = arith.index_cast(i32, offset_dword) - - # Mask: n_global < N (K is guaranteed to be padded to tile_k) - n_valid = arith.cmpu(n_global, c_n, "ult") - - # Load 4 dwords (16 bytes = 8 bf16 elements) with mask - b_i32x4 = buffer_ops.buffer_load(b_rsrc, offset_i32, vec_width=4, dtype=i32, mask=n_valid) - - # Convert to vec8 bf16/fp16, then split into two i64 halves - b_vec = vector.bitcast(vec8_elem, b_i32x4) - b_i64x2 = vector.bitcast(vec2_i64, b_vec) - b0_i64 = vector.extract(b_i64x2, static_position=[0], dynamic_position=[]) - b1_i64 = vector.extract(b_i64x2, static_position=[1], dynamic_position=[]) - - # Convert to MFMA operand type - if is_bf16: - # bf16 uses i16 bit patterns - b0_v1 = vector.from_elements(vec1_i64, [b0_i64]) - b1_v1 = vector.from_elements(vec1_i64, [b1_i64]) - return vector.bitcast(vec4_i16, b0_v1), vector.bitcast(vec4_i16, b1_v1) - else: - # fp16 uses f16 directly - b0_v1 = vector.from_elements(vec1_i64, [b0_i64]) - b1_v1 = vector.from_elements(vec1_i64, [b1_i64]) - return vector.bitcast(vec4_f16, b0_v1), vector.bitcast(vec4_f16, b1_v1) - - def load_b_tile(base_k): - """Load entire B tile for K loop.""" - b_tile = [] - for ku in range_constexpr(k_unroll): - packs0 = [] - packs1 = [] - for ni in range_constexpr(num_acc_n): - b0, b1 = load_b_packs_k64(base_k, ku, ni) - packs0.append(b0) - packs1.append(b1) - b_tile.append((packs0, packs1)) - return b_tile - - # ================= A LDS Load ================= - def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_base): - """Load A pack from LDS for MFMA (16B -> 2 × i64).""" - # Apply XOR16 swizzle - col_swz_bytes = flir.swizzle_xor16(curr_row_a_lds, col_base_bytes, k_blocks16) - col_swz = col_swz_bytes / arith.constant(elem_bytes, index=True) - coord_a = flir.make_coord(curr_row_a_lds, col_swz) - idx_a = flir.crd2idx(coord_a, layout_lds) - idx_a = idx_a + lds_base - - # Load 8 elements - loaded_a = vector.load_op(vec8_elem, lds_a, [idx_a]) - a_i64x2 = vector.bitcast(vec2_i64, loaded_a) - a0_i64 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) - a1_i64 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) - - # Convert to MFMA operand type - if is_bf16: - a0_v1 = vector.from_elements(vec1_i64, [a0_i64]) - a1_v1 = vector.from_elements(vec1_i64, [a1_i64]) - return vector.bitcast(vec4_i16, a0_v1), vector.bitcast(vec4_i16, a1_v1) - else: - a0_v1 = vector.from_elements(vec1_i64, [a0_i64]) - a1_v1 = vector.from_elements(vec1_i64, [a1_i64]) - return vector.bitcast(vec4_f16, a0_v1), vector.bitcast(vec4_f16, a1_v1) - - # ================= MFMA Computation ================= - mfma_res_ty = vec4_f32 - if is_bf16: - mfma_fn = rocdl.mfma_f32_16x16x16bf16_1k - else: - mfma_fn = rocdl.mfma_f32_16x16x16f16 - - def mfma_step(acc_in, a, b): - return mfma_fn(mfma_res_ty, [a, b, acc_in, 0, 0, 0]) - - def mfma_k64_bytes(acc_in, a0, a1, b0, b1): - """K64-byte wrapper: two MFMA K16 ops.""" - acc_mid = mfma_step(acc_in, a0, b0) - return mfma_step(acc_mid, a1, b1) - - def compute_tile(accs_in, b_tile_in, lds_base): - """Compute one tile of MFMA operations.""" - current_accs = list(accs_in) - - for ku in range_constexpr(k_unroll): - b_packs0, b_packs1 = b_tile_in[ku] - # K byte offset for this ku - ki64 = ku * 64 # 64 bytes per ku - col_base = col_offset_base_bytes + arith.constant(ki64, index=True) - - for mi in range_constexpr(m_repeat): - mi_val = arith.constant(mi * 16, index=True) - curr_row_a_lds = row_a_lds + mi_val - - # Load A pack from LDS - a0, a1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) - - for ni in range_constexpr(num_acc_n): - acc_idx = mi * num_acc_n + ni - current_accs[acc_idx] = mfma_k64_bytes( - current_accs[acc_idx], - a0, a1, - b_packs0[ni], b_packs1[ni], - ) - - return current_accs - - # ================= Epilogue (Store C) with mask ================= - def store_output(final_accs): - """Store accumulated results to C with mask-based boundary check.""" - lane_div_16_mul4 = lane_div_16 * arith.constant(4, index=True) - - for mi in range_constexpr(m_repeat): - mi_base = arith.constant(mi * 16, index=True) - for ii in range_constexpr(4): # 4 rows per lane group - ii_idx = arith.constant(ii, index=True) - row_off = lane_div_16_mul4 + ii_idx - row_in_tile = mi_base + row_off - row = bx_m + row_in_tile - - col_base = by_n + n_tile_base + lane_mod_16 - - for ni in range_constexpr(num_acc_n): - acc_idx = mi * num_acc_n + ni - acc = final_accs[acc_idx] - val = vector.extract(acc, static_position=[ii], dynamic_position=[]) - - # Convert f32 to output type - val_out = arith.trunc_f(_out_type(), val) - - col = col_base + arith.constant(ni * 16, index=True) - - # Calculate linear element offset for buffer_store - # buffer_store expects offset in elements (it will scale by element size) - # offset = row * N + col (in bf16/fp16 elements) - offset_elem = row * c_n + col - offset_i32 = arith.index_cast(i32, offset_elem) - - # Mask: row < M and col < N - row_valid = arith.cmpu(row, c_m, "ult") - col_valid = arith.cmpu(col, c_n, "ult") - mask = arith.andi(row_valid, col_valid) - - # Store with mask (OOB stores are skipped) - buffer_ops.buffer_store(val_out, c_rsrc, offset_i32, mask=mask) - - # ================= Main Pipeline ================= - # Single LDS buffer, simple pipeline - lds_base = arith.constant(0, index=True) - - # Initialize accumulators - accs = [acc_init] * (num_acc_n * m_repeat) - - # K loop - iterate over K in tile_k steps - c_tile_k = arith.constant(tile_k, index=True) - # Calculate number of K iterations needed (ceiling division) - # We iterate through all K blocks, mask handles the boundary - for k_base in range(arith.constant(0, index=True), c_k, c_tile_k): - # Load A tile to LDS (with mask for boundary) - a_parts = load_a_tile(k_base) - store_a_tile_to_lds(a_parts) - gpu.barrier() - - # Load B tile directly to GPR (with mask for boundary) - b_tile = load_b_tile(k_base) - - # Compute MFMA - accs = compute_tile(accs, b_tile, lds_base) - - # Barrier before next iteration (if any) - gpu.barrier() - - # Store output (with mask for boundary) - store_output(accs) - - @flir.jit - def __call__( - self: flir.T.i64, - arg_c: lambda: memref(DYN, _out_type()), - arg_a: lambda: memref(DYN, _elem_type()), - arg_b: lambda: memref(DYN, _elem_type()), - c_m: lambda: T.index, - c_n: lambda: T.index, - c_k: lambda: T.index, - ): - c1 = arith.constant(1, index=True) - bdx = arith.constant(256, index=True) - tm = arith.constant(tile_m, index=True) - tn = arith.constant(tile_n, index=True) - one = arith.constant(1, index=True) - # Grid size: ceiling division for non-aligned M and N - gx = (c_m + tm - one) / tm - gy = (c_n + tn - one) / tn - flir.gpu_ext.LaunchFuncOp( - [module_name, "kernel_gemm"], - grid_size=(gx, gy, c1), - block_size=(bdx, c1, c1), - kernel_operands=[ - arg_c, - arg_a, - arg_b, - c_m, - c_n, - c_k, - ], - ) - - m = _GEMM() - return flydsl.compile(m, waves_per_eu=waves_per_eu) - - -def run_simple_gemm( - *, - M: int, - N: int, - K: int, - tile_m: int = 16, - tile_n: int = 64, - tile_k: int = 128, - in_dtype: str = "bf16", - A=None, - B=None, - waves_per_eu: int = None, -): - """Run simple GEMM: C = A @ B^T. - - This function supports non-aligned M, N, K dimensions: - - M and N: handled by kernel mask-based loads/stores (Triton-like approach) - - K: padded to tile_k on host (required for MFMA vector loads) - - Args: - M, N, K: Matrix dimensions (A[M,K], B[N,K], C[M,N]). - tile_m, tile_n, tile_k: Tile sizes. - in_dtype: Input data type ("bf16" or "fp16"). - A: Optional input tensor A[M,K]. If None, creates random tensor. - B: Optional input tensor B[N,K]. If None, creates random tensor. - waves_per_eu: Optional hint for AMDGPU backend about the desired number of waves. - - Returns: - C: Output tensor C[M,N]. - """ - import torch - - # Determine torch dtype - if in_dtype == "bf16": - torch_dtype = torch.bfloat16 - else: - torch_dtype = torch.float16 - - device = "cuda" - - # Create input tensors if not provided - if A is None: - A = torch.randn(M, K, dtype=torch_dtype, device=device) - if B is None: - B = torch.randn(N, K, dtype=torch_dtype, device=device) - - # Ensure inputs are contiguous and on correct device - A = A.contiguous().to(device=device, dtype=torch_dtype) - B = B.contiguous().to(device=device, dtype=torch_dtype) - - # Pad K to tile_k (required for MFMA vector loads) - # M and N are handled by kernel mask-based boundary checks - K_pad = _align_up(K, tile_k) - if K_pad != K: - A_pad = torch.zeros(M, K_pad, dtype=torch_dtype, device=device) - B_pad = torch.zeros(N, K_pad, dtype=torch_dtype, device=device) - A_pad[:, :K] = A - B_pad[:, :K] = B - A = A_pad - B = B_pad - K = K_pad - - # Create output tensor (original size, no padding needed for M and N) - C = torch.zeros(M, N, dtype=torch_dtype, device=device) - - # Compile and run kernel - exe = compile_simple_gemm( - tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, - in_dtype=in_dtype, - waves_per_eu=waves_per_eu, - ) - - # Flatten tensors for kernel interface - A_flat = A.view(-1) - B_flat = B.view(-1) - C_flat = C.view(-1) - - # Pass dimensions (K is now padded, M and N are original) - exe(C_flat, A_flat, B_flat, M, N, K) - torch.cuda.synchronize() - - return C diff --git a/run.sh b/run.sh deleted file mode 100755 index 86959e1c..00000000 --- a/run.sh +++ /dev/null @@ -1,102 +0,0 @@ -set -x - -shopt -s expand_aliases - -alias l.='ls -d .* --color=auto' -alias ll='ls -l --color=auto' -alias ls='ls --color=auto' -alias python='python3' - -# export HIP_VISIBLE_DEVICES=0 -# export HIP_VISIBLE_DEVICES=1 -# export HIP_VISIBLE_DEVICES=3 -# export HIP_VISIBLE_DEVICES=5 -export HIP_VISIBLE_DEVICES=6 -# export HIP_VISIBLE_DEVICES=7 - - -# export LD_LIBRARY_PATH=/mnt/raid0/heyanguang/code/poc_kl/scripts/common:$LD_LIBRARY_PATH -# export LD_LIBRARY_PATH=/usr/local/lib/python3.12/dist-packages/torch/lib:$LD_LIBRARY_PATH -# export PATH=/mnt/raid0/heyanguang/code/poc_kl/scripts/common:$PATH - -rocm-smi | egrep "$HIP_VISIBLE_DEVICES |Device" -pip show triton -rocprofv3 --version - - -function run_flydsl_op { - export MLIR_ASM_VERBOSE=1 - export FLIR_LOG_MORE=1 - export FLIR_DUMP_IR=1 - export FLIR_REBUILD=1 - export FLIR_DUMP_DIR=./flydsl_dump - - # python tests/kernels/test_moe_stage1_simple.py --size M - - # python tests/kernels/test_simple_gemm.py --size XL --waves_per_eu 1 - # python tests/kernels/test_simple_gemm.py --size NA4 - # python tests/kernels/test_simple_gemm.py --size all --dtype all - - # python tests/kernels/test_flash_attention_v4_2.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 - # python tests/kernels/test_flash_attention_v4_3.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v42 - # python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 --compare-v43 - # python tests/kernels/test_flash_attn_func.py --iters 100 --compare-v43 - - python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 100 - python tests/kernels/test_flash_attn_func.py --iters 100 - - # rocprof -i perf_counters1.txt -o prof_v44_p1.csv python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 5 --warmup 2 - # rocprof -i perf_counters2.txt -o prof_v44_p2.csv python tests/kernels/test_flash_attn_func.py --batch 1 --num_heads 64 --seq_len 8192 --head_dim 128 --iters 5 --warmup 2 - -} - - -function get_flydsl_op_thread_trace { - pushd $PWD - export KERNEL_NAME=kernel_gemm - KERNEL_VERSION="${KERNEL_NAME}_v0" - - - DUMP_TRACE=1 - # DUMP_TRACE=0 - if [ $DUMP_TRACE = 1 ]; then - rm -rf ./pass_2 - cd ./thread_trace - trace_dir=./${KERNEL_VERSION} - rm -rf ./rpf_v3 - rm -rf ./${trace_dir} ./${trace_dir}.tar.gz - mkdir -p ${trace_dir} - cd - - - rocprofv3 -i ./input.yaml -- \ - python tests/kernels/test_simple_gemm.py --size XL --waves_per_eu 1 - # python tests/kernels/test_simple_gemm.py --size XL - - cd ./thread_trace - cp -r ./rpf_v3/pass_1/*.att ${trace_dir} - cp -r ./rpf_v3/pass_1/ui_* ${trace_dir} - cp -r ./rpf_v3/pass_1/*_agent_info.csv ${trace_dir} - cp -r ./rpf_v3/pass_1/stats_ui_*.csv ${trace_dir} - tar -zcf ./${trace_dir}.tar.gz ./${trace_dir} - ls -lah ./${trace_dir} ./${trace_dir}.tar.gz - cd - - fi - - popd -} - - -# # Press y then n while install -# ./rocprof-trace-decoder-manylinux-2.28-0.1.6-Linux.sh --prefix=/opt/rocm/ -# cd /opt/rocm/ -# ll -ah ./opt/rocm/lib/librocprof-trace-decoder.so -# ll -ah ./lib/librocprof-trace-decoder.so -# cp opt/rocm/lib/librocprof-trace-decoder.so ./lib/ -# ll -ah ./lib/librocprof-trace-decoder.so - - -run_flydsl_op -# get_flydsl_op_thread_trace - - -set +x diff --git a/tests/kernels/test_moe_stage1_simple.py b/tests/kernels/test_moe_stage1_simple.py deleted file mode 100644 index 2cbbfa47..00000000 --- a/tests/kernels/test_moe_stage1_simple.py +++ /dev/null @@ -1,193 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple test script for run_moe_stage1 (no pytest required). - -Usage: - python tests/kernels/test_moe_stage1_simple.py [--size S|M|L] [--dtype fp8|fp16|int8|int4|all] - -Examples: - python tests/kernels/test_moe_stage1_simple.py # Run Small with fp8 - python tests/kernels/test_moe_stage1_simple.py --size M # Run Medium with fp8 - python tests/kernels/test_moe_stage1_simple.py --dtype all # Run Small with all dtypes - python tests/kernels/test_moe_stage1_simple.py --size L --dtype fp8 # Run Large with fp8 -""" - -import argparse -import os -import sys - -# Ensure repo-local flydsl is used -_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) -if _REPO_ROOT not in sys.path: - sys.path.insert(0, _REPO_ROOT) - -import torch - -# Import run_moe_stage1 from the test file -from tests.kernels.test_moe_gemm import run_moe_stage1 - -# Test configurations (from pytest.param) -TEST_CONFIGS = { - "S": { - "tokens": 64, - "model_dim": 256, - "inter_dim": 128, - "experts": 4, - "topk": 2, - "tile_m": 32, - "tile_n": 64, - "tile_k": 128, - "doweight_stage1": False, - "description": "Small smoke test", - }, - "M": { - "tokens": 128, - "model_dim": 1024, - "inter_dim": 256, - "experts": 8, - "topk": 2, - "tile_m": 64, - "tile_n": 128, - "tile_k": 128, - "doweight_stage1": False, - "description": "Medium realistic test", - }, - "L": { - "tokens": 256, - "model_dim": 4096, - "inter_dim": 2048, - "experts": 17, - "topk": 9, - "tile_m": 64, - "tile_n": 128, - "tile_k": 128, - "doweight_stage1": False, - "description": "Large aiter-style test", - }, -} - -DTYPES = ["fp8", "fp16", "int8", "int4"] - - -def run_test(size: str, in_dtype: str, num_iters: int = 5, num_warmup: int = 2, skip_ref: bool = False): - """Run a single stage1 test.""" - config = TEST_CONFIGS[size] - - print("=" * 70) - print(f"Running MoE Stage1 Test: size={size} ({config['description']}), dtype={in_dtype}") - print(f" tokens={config['tokens']}, model_dim={config['model_dim']}, inter_dim={config['inter_dim']}") - print(f" experts={config['experts']}, topk={config['topk']}") - print(f" tile_m={config['tile_m']}, tile_n={config['tile_n']}, tile_k={config['tile_k']}") - print("=" * 70) - - try: - run_moe_stage1( - tokens=config["tokens"], - model_dim=config["model_dim"], - inter_dim=config["inter_dim"], - experts=config["experts"], - topk=config["topk"], - tile_m=config["tile_m"], - tile_n=config["tile_n"], - tile_k=config["tile_k"], - doweight_stage1=config["doweight_stage1"], - in_dtype=in_dtype, - seed=0, - num_iters=num_iters, - num_warmup=num_warmup, - compare_aiter_ck=False, # Skip aiter comparison by default - moe_sort_mode="torch", # Use torch sorting for portability - skip_ref=skip_ref, - ) - print(f"[PASS] size={size}, dtype={in_dtype}\n") - return True - except Exception as e: - print(f"[FAIL] size={size}, dtype={in_dtype}") - print(f" Error: {e}\n") - import traceback - traceback.print_exc() - return False - - -def main(): - parser = argparse.ArgumentParser( - description="Simple MoE Stage1 test (no pytest)", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__, - ) - parser.add_argument( - "--size", "-s", - type=str, - choices=["S", "M", "L", "all"], - default="S", - help="Test size: S (small), M (medium), L (large), or all", - ) - parser.add_argument( - "--dtype", "-d", - type=str, - choices=["fp8", "fp16", "int8", "int4", "all"], - default="fp8", - help="Input data type (default: fp8)", - ) - parser.add_argument( - "--num_iters", "-n", - type=int, - default=100, - help="Number of benchmark iterations (default: 5)", - ) - parser.add_argument( - "--num_warmup", "-w", - type=int, - default=2, - help="Number of warmup iterations (default: 2)", - ) - parser.add_argument( - "--skip_ref", - action="store_true", - help="Skip reference correctness check (benchmark only)", - ) - - args = parser.parse_args() - - # Check CUDA availability - if not torch.cuda.is_available(): - print("ERROR: CUDA/ROCm not available. Cannot run GPU tests.") - sys.exit(1) - - torch.set_default_device("cuda") - - # Determine sizes and dtypes to run - sizes = list(TEST_CONFIGS.keys()) if args.size == "all" else [args.size] - dtypes = DTYPES if args.dtype == "all" else [args.dtype] - - print(f"\nRunning MoE Stage1 tests: sizes={sizes}, dtypes={dtypes}") - print(f"GPU: {torch.cuda.get_device_name(0)}\n") - - results = [] - for size in sizes: - for dtype in dtypes: - passed = run_test( - size=size, - in_dtype=dtype, - num_iters=args.num_iters, - num_warmup=args.num_warmup, - skip_ref=args.skip_ref, - ) - results.append((size, dtype, passed)) - - # Summary - print("\n" + "=" * 70) - print("SUMMARY") - print("=" * 70) - passed = sum(1 for _, _, p in results if p) - total = len(results) - for size, dtype, p in results: - status = "PASS" if p else "FAIL" - print(f" [{status}] size={size}, dtype={dtype}") - print(f"\nTotal: {passed}/{total} passed") - - sys.exit(0 if passed == total else 1) - - -if __name__ == "__main__": - main() diff --git a/tests/kernels/test_simple_gemm.py b/tests/kernels/test_simple_gemm.py deleted file mode 100644 index 255148af..00000000 --- a/tests/kernels/test_simple_gemm.py +++ /dev/null @@ -1,518 +0,0 @@ -#!/usr/bin/env python3 -""" -Simple test script for the simple GEMM kernel. - -Usage: - python tests/kernels/test_simple_gemm.py [--size S|M|L|XL|NA1|NA2|all] [--dtype bf16|fp16|all] [--waves_per_eu N] - -Examples: - python tests/kernels/test_simple_gemm.py # Run Small with bf16 - python tests/kernels/test_simple_gemm.py --size M # Run Medium with bf16 - python tests/kernels/test_simple_gemm.py --dtype all # Run Small with all dtypes - python tests/kernels/test_simple_gemm.py --size all # Run all sizes with bf16 - python tests/kernels/test_simple_gemm.py --size NA1 # Non-aligned test 1 - python tests/kernels/test_simple_gemm.py --waves_per_eu 2 # Set waves per EU hint to 2 -""" - -import argparse -import hashlib -import logging -import os -import random -import sys - -# Ensure repo-local flydsl is used -_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")) -if _REPO_ROOT not in sys.path: - sys.path.insert(0, _REPO_ROOT) - -import numpy as np -import torch - -from kernels.simple_gemm import compile_simple_gemm, run_simple_gemm -from tests.test_common import run_perftest, verify_output - -# Configure logging to show INFO level messages (required for kernel name display) -logging.basicConfig(level=logging.INFO) - -# Test configurations -# Aligned tests: M, N, K are multiples of tile sizes -TEST_CONFIGS = { - "S": { - "M": 16, - "N": 64, - "K": 128, - "tile_m": 16, - "tile_n": 64, - "tile_k": 128, - "description": "Small smoke test (single tile)", - }, - "M": { - "M": 64, - "N": 128, - "K": 256, - "tile_m": 16, - "tile_n": 64, - "tile_k": 128, - "description": "Medium test (multi-tile)", - }, - "L": { - "M": 256, - "N": 512, - "K": 512, - "tile_m": 16, - "tile_n": 64, - "tile_k": 128, - "description": "Large test", - }, - "XL": { - "M": 1280, - "N": 2048, - "K": 128, - "tile_m": 16, - "tile_n": 64, - "tile_k": 128, - "description": "Extra large test", - }, - # Non-aligned tests: M, N, K are NOT multiples of 16 - "NA1": { - "M": 33, # Not aligned to 16 - "N": 87, # Not aligned to 64 - "K": 145, # Not aligned to 128 - "tile_m": 16, - "tile_n": 64, - "tile_k": 128, - "description": "Non-aligned test 1 (M=33, N=87, K=145)", - }, - "NA2": { - "M": 57, # Not aligned to 16 - "N": 123, # Not aligned to 64 - "K": 259, # Not aligned to 128 - "tile_m": 16, - "tile_n": 64, - "tile_k": 128, - "description": "Non-aligned test 2 (M=57, N=123, K=259)", - }, - "NA3": { - "M": 100, # Not aligned to 16 - "N": 200, # Not aligned to 64 - "K": 300, # Not aligned to 128 - "tile_m": 16, - "tile_n": 64, - "tile_k": 128, - "description": "Non-aligned test 3 (M=100, N=200, K=300)", - }, - "NA4": { - "M": 171, # Not aligned to 16 - "N": 333, # Not aligned to 64 - "K": 517, # Not aligned to 128 - "tile_m": 16, - "tile_n": 64, - "tile_k": 128, - "description": "Non-aligned test 4 (M=171, N=333, K=517)", - }, -} - -DTYPES = ["bf16", "fp16"] - -# Tensor initialization range (uniform distribution) -UNIFORM_RANGE = (-1, 1) -DEFAULT_SEED = 123 - - -def setup_seed(seed: int) -> None: - """Set random seed for reproducibility across all RNG sources.""" - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - torch.backends.cudnn.deterministic = True - - -def get_torch_dtype(in_dtype: str): - """Convert string dtype to torch dtype.""" - if in_dtype == "bf16": - return torch.bfloat16 - elif in_dtype == "fp16": - return torch.float16 - else: - raise ValueError(f"Unknown dtype: {in_dtype}") - - -def _align_up(val: int, align: int) -> int: - """Round up val to the next multiple of align.""" - return ((val + align - 1) // align) * align - - -def compute_md5(tensor: torch.Tensor) -> str: - """Compute MD5 hash of a tensor's raw bytes.""" - return hashlib.md5( - tensor.contiguous().view(torch.uint8).detach().cpu().numpy().tobytes() - ).hexdigest() - - -def compare_arrays( - arr1: np.ndarray, - arr2: np.ndarray, - k: int = 5, - thresholds: list = None, -) -> dict: - """Compare two numpy arrays and compute various difference metrics. - - Args: - arr1: First input array (result), will be cast to float32. - arr2: Second input array (reference), will be cast to float32. - k: Number of top differences to report. - thresholds: Difference magnitude buckets for histogram. - - Returns: - Dictionary with top_k_diff, threshold_stats, nan_info, max_diff, max_diff_thr. - """ - if thresholds is None: - thresholds = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1] - - if arr1.shape != arr2.shape: - raise ValueError( - f"Shape mismatch: arr1 {arr1.shape} vs arr2 {arr2.shape}" - ) - - arr1 = arr1.astype(np.float32) - arr2 = arr2.astype(np.float32) - - result = {"top_k_diff": [], "threshold_stats": [], "nan_info": {}} - - # Check for NaN values - nan_mask1 = np.isnan(arr1) - nan_mask2 = np.isnan(arr2) - if np.any(nan_mask1): - result["nan_info"]["arr1_nan_count"] = int(np.sum(nan_mask1)) - print(f" Warning: result contains {result['nan_info']['arr1_nan_count']} NaN values") - if np.any(nan_mask2): - result["nan_info"]["arr2_nan_count"] = int(np.sum(nan_mask2)) - print(f" Warning: reference contains {result['nan_info']['arr2_nan_count']} NaN values") - - # Compute absolute differences - diff = np.abs(arr1 - arr2) - total_elements = arr1.size - - max_diff_thr = (diff / (1.0 + np.abs(arr2))).max() - result["max_diff"] = float(diff.max()) - result["max_diff_thr"] = float(max_diff_thr) - - print(f" diff.abs.max = {diff.max():.6f}") - print(f" diff.abs.mean = {diff.mean():.6f}") - print(f" max_diff_thr (rel) = {max_diff_thr:.6e}") - - # Find top k differences - flat_diff = diff.flatten() - actual_k = min(k, len(flat_diff)) - top_k_indices = np.argpartition(flat_diff, -actual_k)[-actual_k:] - top_k_indices = top_k_indices[np.argsort(-flat_diff[top_k_indices])] - - orig_indices = np.unravel_index(top_k_indices, diff.shape) - print(f" Top-{actual_k} differences:") - for i in range(actual_k): - idx = tuple(dim[i] for dim in orig_indices) - entry = { - "value": float(diff[idx]), - "position": idx, - "arr1_value": float(arr1[idx]), - "arr2_value": float(arr2[idx]), - } - result["top_k_diff"].append(entry) - print(f" [{idx}] result={arr1[idx]:.6f}, ref={arr2[idx]:.6f}, diff={diff[idx]:.6f}") - - # Compute threshold statistics - print(f" Threshold distribution ({total_elements} elements):") - for i in range(len(thresholds) - 1): - lower, upper = thresholds[i], thresholds[i + 1] - count = int(np.sum((diff >= lower) & (diff < upper))) - pct = 100.0 * count / total_elements - result["threshold_stats"].append( - {"range": f"[{lower:.0e}, {upper:.0e})", "count": count, "percentage": pct} - ) - print(f" [{lower:.0e}, {upper:.0e}): {count:>8d} ({pct:6.2f}%)") - - count = int(np.sum(diff >= thresholds[-1])) - pct = 100.0 * count / total_elements - result["threshold_stats"].append( - {"range": f">={thresholds[-1]:.0e}", "count": count, "percentage": pct} - ) - print(f" >={thresholds[-1]:.0e} : {count:>8d} ({pct:6.2f}%)") - - return result - - -def run_test( - size: str, - in_dtype: str, - num_iters: int = 100, - num_warmup: int = 5, - skip_ref: bool = False, - rtol: float = 1e-2, - atol: float = 1e-2, - waves_per_eu: int = None, - seed: int = DEFAULT_SEED, -): - """Run a single GEMM test.""" - config = TEST_CONFIGS[size] - M = config["M"] - N = config["N"] - K = config["K"] - tile_m = config["tile_m"] - tile_n = config["tile_n"] - tile_k = config["tile_k"] - - # K must be padded to tile_k for MFMA vector loads - K_pad = _align_up(K, tile_k) - - print("=" * 70) - print(f"Running Simple GEMM Test: size={size} ({config['description']}), dtype={in_dtype}") - print(f" M={M}, N={N}, K={K} (K_pad={K_pad})") - print(f" tile_m={tile_m}, tile_n={tile_n}, tile_k={tile_k}") - print("=" * 70) - - torch_dtype = get_torch_dtype(in_dtype) - device = "cuda" - - try: - # Create random inputs (uniform distribution in UNIFORM_RANGE) - setup_seed(seed) - A_orig = torch.empty(M, K, dtype=torch_dtype, device=device).uniform_(*UNIFORM_RANGE) - B_orig = torch.empty(N, K, dtype=torch_dtype, device=device).uniform_(*UNIFORM_RANGE) - - # Run reference computation (using float32 for accuracy) with original K - if not skip_ref: - A_f32 = A_orig.to(torch.float32) - B_f32 = B_orig.to(torch.float32) - C_ref = torch.mm(A_f32, B_f32.T).to(torch_dtype) - - # Pad K for kernel (M and N are handled by kernel mask-based boundary checks) - if K_pad != K: - A = torch.zeros(M, K_pad, dtype=torch_dtype, device=device) - B = torch.zeros(N, K_pad, dtype=torch_dtype, device=device) - A[:, :K] = A_orig - B[:, :K] = B_orig - else: - A = A_orig - B = B_orig - - # Create output tensor (original size, no padding needed for M and N) - C = torch.zeros(M, N, dtype=torch_dtype, device=device) - - # Compile kernel - print("Compiling kernel...") - if waves_per_eu is not None: - print(f" waves_per_eu={waves_per_eu}") - exe = compile_simple_gemm( - tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, - in_dtype=in_dtype, - waves_per_eu=waves_per_eu, - ) - print("Kernel compiled successfully.") - - # Flatten tensors for kernel interface - A_flat = A.view(-1) - B_flat = B.view(-1) - C_flat = C.view(-1) - C_flat.zero_() - - # Define launch function for run_perftest - def launch(): - exe(C_flat, A_flat, B_flat, M, N, K_pad) - - # Warmup and benchmark using run_perftest - print(f"Running {num_warmup} warmup + {num_iters} benchmark iterations...") - _, us = run_perftest( - launch, - num_iters=num_iters, - num_warmup=num_warmup, - ) - torch.cuda.synchronize() - - # Calculate TFLOPS - flops = 2 * M * N * K # 2 ops per element (multiply + add) - tflops = flops / (us / 1e6) / 1e12 - - print(f" Time per iteration: {us:.3f} us ({us/1000:.3f} ms)") - print(f" Throughput: {tflops:.2f} TFLOPS") - - # Verify correctness - if not skip_ref: - # Run one more time for correctness check - C_flat.zero_() - exe(C_flat, A_flat, B_flat, M, N, K_pad) - torch.cuda.synchronize() - C_result = C - - # Compute and print MD5 hashes - result_md5 = compute_md5(C_result) - ref_md5 = compute_md5(C_ref) - print(f" result_md5 = {result_md5}") - print(f" ref_md5 = {ref_md5}") - if result_md5 == ref_md5: - print(" MD5 match: EXACT (bit-identical)") - else: - print(" MD5 match: DIFFER (not bit-identical)") - - # Detailed comparison using compare_arrays - print(" --- compare_arrays ---") - compare_arrays( - C_result.to(torch.float32).detach().cpu().numpy(), - C_ref.to(torch.float32).detach().cpu().numpy(), - ) - - # Check correctness using verify_output - passed = verify_output( - C_result.to(torch.float32), - C_ref.to(torch.float32), - rtol=rtol, - atol=atol, - msg=f"size={size}, dtype={in_dtype}" - ) - - if not passed: - # Print more details for debugging - max_diff = (C_result - C_ref).abs().max().item() - mean_diff = (C_result - C_ref).abs().mean().item() - print(f" Max diff: {max_diff:.6f}") - print(f" Mean diff: {mean_diff:.6f}") - print("\n Sample values (first 4x4):") - print(f" Result:\n{C_result[:4, :4]}") - print(f" Reference:\n{C_ref[:4, :4]}") - return False - - print(f"[PASS] size={size}, dtype={in_dtype}\n") - return True - - except Exception as e: - print(f"[FAIL] size={size}, dtype={in_dtype}") - print(f" Error: {e}\n") - import traceback - traceback.print_exc() - return False - - -def main(): - parser = argparse.ArgumentParser( - description="Simple GEMM kernel test", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=__doc__, - ) - parser.add_argument( - "--size", "-s", - type=str, - choices=list(TEST_CONFIGS.keys()) + ["all", "aligned", "nonaligned"], - default="S", - help="Test size: S/M/L/XL (aligned), NA1/NA2/NA3/NA4 (non-aligned), all, aligned, or nonaligned", - ) - parser.add_argument( - "--dtype", "-d", - type=str, - choices=["bf16", "fp16", "all"], - default="bf16", - help="Input data type (default: bf16)", - ) - parser.add_argument( - "--num_iters", "-n", - type=int, - default=100, - help="Number of benchmark iterations (default: 100)", - ) - parser.add_argument( - "--num_warmup", "-w", - type=int, - default=5, - help="Number of warmup iterations (default: 5)", - ) - parser.add_argument( - "--skip_ref", - action="store_true", - help="Skip reference correctness check (benchmark only)", - ) - parser.add_argument( - "--rtol", - type=float, - default=1e-2, - help="Relative tolerance for correctness check (default: 1e-2)", - ) - parser.add_argument( - "--atol", - type=float, - default=1e-2, - help="Absolute tolerance for correctness check (default: 1e-2)", - ) - parser.add_argument( - "--waves_per_eu", - type=int, - default=None, - help="AMDGPU waves-per-eu hint for occupancy optimization (e.g., 1, 2, 4)", - ) - parser.add_argument( - "--seed", - type=int, - default=DEFAULT_SEED, - help=f"Random seed for reproducibility (default: {DEFAULT_SEED})", - ) - - args = parser.parse_args() - - # Check CUDA availability - if not torch.cuda.is_available(): - print("ERROR: CUDA/ROCm not available. Cannot run GPU tests.") - sys.exit(1) - - torch.set_default_device("cuda") - - # Determine sizes and dtypes to run - aligned_sizes = ["S", "M", "L", "XL"] - nonaligned_sizes = ["NA1", "NA2", "NA3", "NA4"] - - if args.size == "all": - sizes = list(TEST_CONFIGS.keys()) - elif args.size == "aligned": - sizes = aligned_sizes - elif args.size == "nonaligned": - sizes = nonaligned_sizes - else: - sizes = [args.size] - - dtypes = DTYPES if args.dtype == "all" else [args.dtype] - - print(f"\nRunning Simple GEMM tests: sizes={sizes}, dtypes={dtypes}") - print(f"seed: {args.seed}") - if args.waves_per_eu is not None: - print(f"waves_per_eu: {args.waves_per_eu}") - print(f"GPU: {torch.cuda.get_device_name(0)}\n") - - results = [] - for size in sizes: - for dtype in dtypes: - passed = run_test( - size=size, - in_dtype=dtype, - num_iters=args.num_iters, - num_warmup=args.num_warmup, - skip_ref=args.skip_ref, - rtol=args.rtol, - atol=args.atol, - waves_per_eu=args.waves_per_eu, - seed=args.seed, - ) - results.append((size, dtype, passed)) - - # Summary - print("\n" + "=" * 70) - print("SUMMARY") - print("=" * 70) - passed = sum(1 for _, _, p in results if p) - total = len(results) - for size, dtype, p in results: - status = "PASS" if p else "FAIL" - print(f" [{status}] size={size}, dtype={dtype}") - print(f"\nTotal: {passed}/{total} passed") - - sys.exit(0 if passed == total else 1) - - -if __name__ == "__main__": - main() diff --git a/thread_trace/.gitkeep b/thread_trace/.gitkeep deleted file mode 100644 index e69de29b..00000000 From c274a877f62cc5d306c04552e008b63625557d86 Mon Sep 17 00:00:00 2001 From: yanguahe Date: Sat, 14 Feb 2026 11:29:31 +0800 Subject: [PATCH 15/17] Update test config --- tests/kernels/test_flash_attn_func.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kernels/test_flash_attn_func.py b/tests/kernels/test_flash_attn_func.py index 897b06cd..97b52d43 100644 --- a/tests/kernels/test_flash_attn_func.py +++ b/tests/kernels/test_flash_attn_func.py @@ -289,6 +289,7 @@ def main(): else: configs = [ (1, 128, 8, 128), + (1, 128, 64, 128), (1, 256, 32, 128), (1, 512, 32, 128), (2, 128, 8, 128), From ac1d47709aa0b81de50024a796cb867a766ef3f3 Mon Sep 17 00:00:00 2001 From: yanguahe Date: Sun, 15 Feb 2026 20:09:42 +0800 Subject: [PATCH 16/17] Address review: remove unused _apply_waves_per_eu_hint, clarify exp2 replacement --- flydsl/src/flydsl/compiler/compiler.py | 62 +++++--------------------- 1 file changed, 12 insertions(+), 50 deletions(-) diff --git a/flydsl/src/flydsl/compiler/compiler.py b/flydsl/src/flydsl/compiler/compiler.py index c05137e4..6e671ff9 100644 --- a/flydsl/src/flydsl/compiler/compiler.py +++ b/flydsl/src/flydsl/compiler/compiler.py @@ -204,8 +204,18 @@ def _replace_ocml_exp2_with_intrinsic(module: ir.Module) -> ir.Module: """Replace __ocml_exp2_f32 library calls with llvm.intr.exp2 intrinsics. The convert-gpu-to-rocdl pass lowers math.exp2 to __ocml_exp2_f32 which - generates a safe but slow 6-instruction pattern. By replacing with - llvm.intr.exp2 + fast math flags, we get bare v_exp_f32 (1 instruction). + generates a safe but slow 6-instruction pattern (range reduction + v_exp_f32 + + v_ldexp_f32). By replacing with llvm.intr.exp2 + fast math flags, we get + bare v_exp_f32 (1 instruction). + + Why text replacement instead of using math.exp2 directly: + The MLIR convert-gpu-to-rocdl pass unconditionally lowers math.exp2 to + the __ocml_exp2_f32 library call. There is no pass-level option to emit + the LLVM intrinsic instead, so we do a post-lowering text replacement + on the LLVM IR assembly. + + TODO: Replace this text-based approach with a proper MLIR rewrite pass + when upstream MLIR adds an option to lower math.exp2 to llvm.intr.exp2. Returns a new module (or the original if replacement fails). """ @@ -396,54 +406,6 @@ def _append_passthrough(func_op): pass -def _apply_waves_per_eu_hint(mlir_module, waves_per_eu: int): - """Apply AMDGPU waves-per-eu occupancy hint to GPU kernel functions. - - This modifies the MLIR module in-place by adding the 'amdgpu-waves-per-eu' - attribute to gpu.func operations marked as kernels. - - Args: - mlir_module: MLIR module containing GPU kernels - waves_per_eu: Number of wavefronts per execution unit (1-4 typical) - """ - if waves_per_eu is None: - return - - w = int(waves_per_eu) - if w < 1: - raise ValueError(f"waves_per_eu must be >= 1, got {w}") - - try: - # Get the context from the module - with mlir_module.context: - # Navigate MLIR module structure: module -> gpu.module -> gpu.func - for op in mlir_module.body.operations: - # Look for gpu.module operations - if getattr(op, "OPERATION_NAME", None) != "gpu.module": - continue - - # gpu.module has a single region with a single block - gpu_module_region = op.regions[0] - - # Within gpu.module, find gpu.func operations with gpu.kernel attribute - for inner_op in gpu_module_region.blocks[0].operations: - if getattr(inner_op, "OPERATION_NAME", None) != "gpu.func": - continue - - # Only apply to kernel functions (not device functions) - if "gpu.kernel" not in inner_op.attributes: - continue - - # Add or append to the 'rocdl.waves_per_eu' attribute - # This attribute is read by the ROCDL conversion pass - inner_op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get( - ir.IntegerType.get_signless(32), w - ) - except Exception as e: - # Best-effort: if attribute injection fails, log and continue - # This prevents breaking existing functionality - import warnings - warnings.warn(f"Failed to apply waves_per_eu hint: {e}", RuntimeWarning) def compile( flir_module_or_ir: Union[object, ir.Module], From 5401b3653dd25e96db4a0c9d3e31e081cd537932 Mon Sep 17 00:00:00 2001 From: yanguahe Date: Sun, 15 Feb 2026 22:55:02 +0800 Subject: [PATCH 17/17] Fix CI: update preshuffle_gemm to use compile(waves_per_eu=) instead of removed _apply_waves_per_eu_hint --- kernels/preshuffle_gemm.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 549020a8..04b8effb 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -24,7 +24,6 @@ from flydsl.dialects.ext import arith, gpu, buffer_ops, vector, rocdl from flydsl.lang.ir.types import T, memref from kernels.kernels_common import stream_ptr_to_async_token -from flydsl.compiler.compiler import _apply_waves_per_eu_hint from kernels.mfma_preshuffle_pipeline import ( buffer_copy_gmem16_dwordx4, @@ -1190,15 +1189,12 @@ def __call__( m = _GEMM() - # Apply waves_per_eu hint if specified (before final compilation) - if waves_per_eu is not None: - _apply_waves_per_eu_hint(m.module, waves_per_eu) - return flydsl.compile( m, use_bare_ptr_memref_call_conv=False, use_bare_pointers_for_host=False, use_bare_pointers_for_kernels=False, + waves_per_eu=waves_per_eu, )