diff --git a/flydsl/src/flydsl/compiler/compiler.py b/flydsl/src/flydsl/compiler/compiler.py index fc48f064..6e671ff9 100644 --- a/flydsl/src/flydsl/compiler/compiler.py +++ b/flydsl/src/flydsl/compiler/compiler.py @@ -37,6 +37,8 @@ def _pipeline_fragments( use_bare_ptr_memref_call_conv: bool = False, use_bare_pointers_for_host: bool = False, use_bare_pointers_for_kernels: bool = False, + unsafe_fp_math: bool = False, + fast_fp_math: bool = False, ) -> list[str]: """FLIR compilation pipeline fragments as a plain list of strings. @@ -50,6 +52,8 @@ def _pipeline_fragments( rocdl_bare_ptr_opt = b2s(use_bare_ptr_memref_call_conv) llvm_bare_host_opt = b2s(use_bare_pointers_for_host) llvm_bare_kern_opt = b2s(use_bare_pointers_for_kernels) + unsafe_math_opt = b2s(unsafe_fp_math) + fast_opt = b2s(fast_fp_math) return [ "flir-to-standard", "trivial-dce", @@ -63,7 +67,7 @@ def _pipeline_fragments( "gpu.module(reconcile-unrealized-casts)", # Keep this as a formatted string so the chip is visible in dumps and matches # the non-dump compilation pipeline. 
- f"rocdl-attach-target{{O=2 abi=600 chip={chip} correct-sqrt=true daz=false fast=false features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math=false wave64=true}}", + f"rocdl-attach-target{{O=2 abi=600 chip={chip} correct-sqrt=true daz=false fast={fast_opt} features= finite-only=false module= triple=amdgcn-amd-amdhsa unsafe-math={unsafe_math_opt} wave64=true}}", "gpu-to-llvm{intersperse-sizes-for-kernels=false " + f"use-bare-pointers-for-host={llvm_bare_host_opt} " + f"use-bare-pointers-for-kernels={llvm_bare_kern_opt}" @@ -92,6 +96,8 @@ def _build_pipeline_str( use_bare_ptr_memref_call_conv: bool = False, use_bare_pointers_for_host: bool = False, use_bare_pointers_for_kernels: bool = False, + unsafe_fp_math: bool = False, + fast_fp_math: bool = False, ) -> str: """Build the full PassManager pipeline string from `_pipeline_fragments`.""" frags = _pipeline_fragments( @@ -99,6 +105,8 @@ def _build_pipeline_str( use_bare_ptr_memref_call_conv=use_bare_ptr_memref_call_conv, use_bare_pointers_for_host=use_bare_pointers_for_host, use_bare_pointers_for_kernels=use_bare_pointers_for_kernels, + unsafe_fp_math=unsafe_fp_math, + fast_fp_math=fast_fp_math, ) return f"builtin.module({','.join(frags)})" @@ -192,54 +200,212 @@ def _infer_kernel_names_from_asm(asm: str) -> list[str]: return names -def _apply_waves_per_eu_hint(mlir_module, waves_per_eu: int): - """Apply AMDGPU waves-per-eu occupancy hint to GPU kernel functions. +def _replace_ocml_exp2_with_intrinsic(module: ir.Module) -> ir.Module: + """Replace __ocml_exp2_f32 library calls with llvm.intr.exp2 intrinsics. - This modifies the MLIR module in-place by adding the 'amdgpu-waves-per-eu' - attribute to gpu.func operations marked as kernels. + The convert-gpu-to-rocdl pass lowers math.exp2 to __ocml_exp2_f32 which + generates a safe but slow 6-instruction pattern (range reduction + v_exp_f32 + + v_ldexp_f32). By replacing with llvm.intr.exp2 + fast math flags, we get + bare v_exp_f32 (1 instruction). 
- Args: - mlir_module: MLIR module containing GPU kernels - waves_per_eu: Number of wavefronts per execution unit (1-4 typical) + Why text replacement instead of using math.exp2 directly: + The MLIR convert-gpu-to-rocdl pass unconditionally lowers math.exp2 to + the __ocml_exp2_f32 library call. There is no pass-level option to emit + the LLVM intrinsic instead, so we do a post-lowering text replacement + on the LLVM IR assembly. + + TODO: Replace this text-based approach with a proper MLIR rewrite pass + when upstream MLIR adds an option to lower math.exp2 to llvm.intr.exp2. + + Returns a new module (or the original if replacement fails). + """ + import re + + try: + asm = module.operation.get_asm(enable_debug_info=True) + + # First replace all call sites, then remove the declaration. + # Use a broad pattern that handles loc() info and whitespace variants. + asm = re.sub( + r'llvm\.call @__ocml_exp2_f32\(([^)]+)\)\s*:\s*\(f32\)\s*->\s*f32', + r'llvm.intr.exp2(\1) {fastmathFlags = #llvm.fastmath} : (f32) -> f32', + asm, + ) + + # Remove the function declaration (it may have loc() info) + asm = re.sub( + r'\s*llvm\.func @__ocml_exp2_f32\(f32\)\s*->\s*f32[^\n]*\n', + '\n', + asm, + ) + + ctx = module.context + new_module = ir.Module.parse(asm, context=ctx) + return new_module + except Exception as e: + import sys + print(f"[flir.compile] WARNING: _replace_ocml_exp2_with_intrinsic failed: {e}", file=sys.stderr) + return module + + +def _apply_unsafe_fp_math_on_llvm_funcs(module: ir.Module) -> None: + """Apply 'unsafe-fp-math'='true' function attribute to GPU kernel llvm.func ops. + + This tells the LLVM AMDGPU backend to use fast/approximate math lowerings, + e.g. bare v_exp_f32 instead of the safe range-reduced exp2 pattern. 
def _apply_unsafe_fp_math_on_llvm_funcs(module: "ir.Module") -> None:
    """Apply 'unsafe-fp-math'='true' function attributes to GPU kernel llvm.func ops.

    This tells the LLVM AMDGPU backend to use fast/approximate math lowerings,
    e.g. bare v_exp_f32 instead of the safe range-reduced exp2 pattern.

    Args:
        module: LLVM-dialect MLIR module (post convert-gpu-to-rocdl).
    """
    # Build (ArrayAttr entry, textual form) pairs in lockstep so the
    # de-duplication below compares each entry against its own string form.
    # (Previously the strings lived in a set that was zipped against the
    # entry list; set iteration order is arbitrary, so entries could be
    # paired with the wrong string and dedup-checked incorrectly.)
    pairs = []
    for attr_name in ("unsafe-fp-math", "no-nans-fp-math", "no-infs-fp-math"):
        key = ir.StringAttr.get(attr_name)
        val = ir.StringAttr.get("true")
        pairs.append((ir.ArrayAttr.get([key, val]), f"{attr_name}=true"))
    # Flush f32 denormals to zero so the AMDGPU backend emits bare v_exp_f32
    # instead of a safe exp2 pattern with range-checking / v_ldexp_f32.
    key_denorm = ir.StringAttr.get("denormal-fp-math-f32")
    val_denorm = ir.StringAttr.get("preserve-sign,preserve-sign")
    pairs.append(
        (
            ir.ArrayAttr.get([key_denorm, val_denorm]),
            "denormal-fp-math-f32=preserve-sign,preserve-sign",
        )
    )
    entries = [entry for entry, _ in pairs]

    def _append_passthrough(func_op):
        # 'passthrough' is an optional ArrayAttr of LLVM function attributes.
        try:
            existing = func_op.attributes["passthrough"]
        except KeyError:
            existing = None

        if existing is None:
            func_op.attributes["passthrough"] = ir.ArrayAttr.get(entries)
            return

        try:
            existing_entries = list(existing)
        except TypeError:
            # Not iterable -> overwrite (best-effort).
            func_op.attributes["passthrough"] = ir.ArrayAttr.get(entries)
            return

        existing_strs = {str(a).strip('"') for a in existing_entries}
        new_entries = list(existing_entries)
        for entry, entry_str in pairs:
            if entry_str not in existing_strs:
                new_entries.append(entry)
        func_op.attributes["passthrough"] = ir.ArrayAttr.get(new_entries)

    try:
        for op in module.body.operations:
            if getattr(op, "OPERATION_NAME", None) != "gpu.module":
                continue
            gpu_module_body = op.regions[0].blocks[0] if hasattr(op, "regions") else op.body
            for inner_op in gpu_module_body.operations:
                if getattr(inner_op, "OPERATION_NAME", None) != "llvm.func":
                    continue
                # Only kernel entry points (unit attribute 'gpu.kernel') get
                # the fast-math attributes.
                if "gpu.kernel" not in inner_op.attributes:
                    continue
                _append_passthrough(inner_op)
    except Exception:
        # Best-effort only: never fail compilation over an optimization hint.
        pass
def _apply_waves_per_eu_on_llvm_funcs(module: ir.Module, waves_per_eu: int) -> None:
    """Apply AMDGPU waves-per-eu hint to llvm.func ops via LLVM passthrough.

    Sets the 'amdgpu-waves-per-eu' attribute on GPU kernel functions, hinting
    the LLVM backend about the desired occupancy per execution unit.

    The passthrough format for valued LLVM attributes is a two-element array:
    ["attribute-name", "attribute-value"].
    """
    hint_text = f"amdgpu-waves-per-eu={waves_per_eu},{waves_per_eu}"
    hint_entry = ir.ArrayAttr.get(
        [
            ir.StringAttr.get("amdgpu-waves-per-eu"),
            ir.StringAttr.get(f"{waves_per_eu},{waves_per_eu}"),
        ]
    )

    def _merge_into_passthrough(func_op):
        try:
            current = func_op.attributes["passthrough"]
        except KeyError:
            current = None

        if current is None:
            func_op.attributes["passthrough"] = ir.ArrayAttr.get([hint_entry])
            return

        # Best-effort: if it's not an ArrayAttr-like object, just overwrite.
        try:
            current_entries = list(current)
        except TypeError:
            func_op.attributes["passthrough"] = ir.ArrayAttr.get([hint_entry])
            return

        if any(str(a).strip('"') == hint_text for a in current_entries):
            return  # Already present; avoid duplicating the hint.
        func_op.attributes["passthrough"] = ir.ArrayAttr.get(current_entries + [hint_entry])

    try:
        for top_op in module.body.operations:
            if getattr(top_op, "OPERATION_NAME", None) != "gpu.module":
                continue
            # gpu.module has a single region with a single block.
            body_block = top_op.regions[0].blocks[0] if hasattr(top_op, "regions") else top_op.body
            for fn in body_block.operations:
                if getattr(fn, "OPERATION_NAME", None) != "llvm.func":
                    continue
                # Kernel entry points carry the unit attribute 'gpu.kernel'.
                if "gpu.kernel" not in fn.attributes:
                    continue
                _merge_into_passthrough(fn)
    except Exception:
        # Best-effort only.
        pass
def _apply_flat_work_group_size_on_llvm_funcs(module: ir.Module, max_workgroup_size: int) -> None:
    """Apply AMDGPU flat-work-group-size hint to GPU kernel llvm.func ops.

    LLVM expects the value as a "min,max" string; we pin min=1 and set max to
    the requested workgroup size.
    """
    wg_text = f"amdgpu-flat-work-group-size=1,{max_workgroup_size}"
    wg_entry = ir.ArrayAttr.get(
        [
            ir.StringAttr.get("amdgpu-flat-work-group-size"),
            ir.StringAttr.get(f"1,{max_workgroup_size}"),
        ]
    )

    def _merge_into_passthrough(func_op):
        try:
            current = func_op.attributes["passthrough"]
        except KeyError:
            current = None

        if current is None:
            func_op.attributes["passthrough"] = ir.ArrayAttr.get([wg_entry])
            return

        # Best-effort: non-iterable passthrough gets overwritten outright.
        try:
            current_entries = list(current)
        except TypeError:
            func_op.attributes["passthrough"] = ir.ArrayAttr.get([wg_entry])
            return

        if any(str(a).strip('"') == wg_text for a in current_entries):
            return  # Already present; keep the attribute list unchanged.
        func_op.attributes["passthrough"] = ir.ArrayAttr.get(current_entries + [wg_entry])

    try:
        for top_op in module.body.operations:
            if getattr(top_op, "OPERATION_NAME", None) != "gpu.module":
                continue
            body_block = top_op.regions[0].blocks[0] if hasattr(top_op, "regions") else top_op.body
            for fn in body_block.operations:
                if getattr(fn, "OPERATION_NAME", None) != "llvm.func":
                    continue
                if "gpu.kernel" not in fn.attributes:
                    continue
                _merge_into_passthrough(fn)
    except Exception:
        # Best-effort only.
        pass
+ # Match only the top-level reconcile-unrealized-casts, not the one inside gpu.module if frag.strip() == "reconcile-unrealized-casts": - asm_for_isa = stage_asm + # Apply waves_per_eu if specified (BEFORE saving asm_for_isa) + if waves_per_eu is not None: + _apply_waves_per_eu_on_llvm_funcs(module, waves_per_eu) + # Apply flat work-group-size hint if specified. + if flat_work_group_size is not None: + _apply_flat_work_group_size_on_llvm_funcs(module, flat_work_group_size) + # Apply unsafe-fp-math function attributes for fast exp2/math + if unsafe_fp_math: + _apply_unsafe_fp_math_on_llvm_funcs(module) + # Replace __ocml_exp2_f32 with llvm.intr.exp2 for fast exp2 + new_mod = _replace_ocml_exp2_with_intrinsic(module) + if new_mod is not module: + module = new_mod + # Get ASM after applying attributes + asm_for_isa = module.operation.get_asm(enable_debug_info=True) if asm_for_isa is not None: isa_out = _dump_isa_from_rocdl_module_asm( @@ -406,9 +596,53 @@ def compile( isa_stage = f"{stage_num_base + len(stage_frags):02d}_final_isa" print(f"[flir.compile] dump {isa_stage} -> {isa_out}") else: - pm = PassManager.parse(pipeline, context=ctx) - pm.enable_verifier(bool(verify)) - pm.run(module.operation) + need_split = ( + (waves_per_eu is not None) + or (flat_work_group_size is not None) + or unsafe_fp_math + ) + if need_split: + # Need to split the pipeline to apply function attributes + # after LLVM lowering but before binary generation. 
+ stage_frags = _pipeline_fragments( + chip=chip, + use_bare_ptr_memref_call_conv=use_bare_ptr_memref_call_conv, + use_bare_pointers_for_host=use_bare_pointers_for_host, + use_bare_pointers_for_kernels=use_bare_pointers_for_kernels, + unsafe_fp_math=unsafe_fp_math, + fast_fp_math=fast_fp_math, + ) + # Run all passes except the last one (gpu-module-to-binary) + pre_binary_frags = stage_frags[:-1] + binary_frag = stage_frags[-1] + + pre_binary_pipeline = f"builtin.module({','.join(pre_binary_frags)})" + pm = PassManager.parse(pre_binary_pipeline, context=ctx) + pm.enable_verifier(bool(verify)) + pm.run(module.operation) + + # Apply waves_per_eu + if waves_per_eu is not None: + _apply_waves_per_eu_on_llvm_funcs(module, waves_per_eu) + # Apply flat work-group-size hint + if flat_work_group_size is not None: + _apply_flat_work_group_size_on_llvm_funcs(module, flat_work_group_size) + # Apply unsafe-fp-math function attributes for fast exp2/math + if unsafe_fp_math: + _apply_unsafe_fp_math_on_llvm_funcs(module) + # Replace __ocml_exp2_f32 with llvm.intr.exp2 for fast exp2 + new_mod = _replace_ocml_exp2_with_intrinsic(module) + if new_mod is not module: + module = new_mod + + # Run the final binary generation pass + pm_binary = PassManager.parse(f"builtin.module({binary_frag})", context=ctx) + pm_binary.enable_verifier(bool(verify)) + pm_binary.run(module.operation) + else: + pm = PassManager.parse(pipeline, context=ctx) + pm.enable_verifier(bool(verify)) + pm.run(module.operation) if print_final_module: print(module) diff --git a/flydsl/src/flydsl/dialects/ext/arith.py b/flydsl/src/flydsl/dialects/ext/arith.py index 2afc3d86..edcf20f4 100644 --- a/flydsl/src/flydsl/dialects/ext/arith.py +++ b/flydsl/src/flydsl/dialects/ext/arith.py @@ -154,12 +154,13 @@ def f64(value: float, *, loc: Location = None, ip: InsertionPoint = None) -> "Ar """Create an f64 constant.""" return constant(value, type=F64Type.get(), loc=loc, ip=ip) -def maximum(lhs: Union["ArithValue", Value], rhs: 
def maximum(lhs: Union["ArithValue", Value], rhs: Union["ArithValue", Value], *, fastmath=None, loc: Location = None) -> "ArithValue":
    """Elementwise maximum of two values; float vs. integer handling is automatic.

    Args:
        lhs: Left operand (ArithValue, Value, or Python number)
        rhs: Right operand (ArithValue, Value, or Python number)
        fastmath: Optional fast-math flags (e.g. arith.FastMathFlags.fast)
        loc: Optional source location

    Returns:
        ArithValue holding max(lhs, rhs)

    Example:
        >>> c = arith.maximum(a, b)  # Function style
        >>> d = a.max(b)  # Method style (equivalent)
    """
    # Delegates to the shared min/max dispatcher, forwarding the optional
    # fast-math flags for float operands.
    return _minmax_op(lhs, rhs, op_type="max", fastmath=fastmath, loc=loc)
@@ -788,6 +789,7 @@ def _minmax_op( rhs: "ArithValue", op_type: str, # "max" or "min" *, + fastmath=None, loc: Location = None, ) -> "ArithValue": """Execute min/max operation based on operand types.""" @@ -809,7 +811,10 @@ def _minmax_op( op_class = _arith.MaximumFOp else: op_class = _arith.MinimumFOp - result = op_class(lhs_val, rhs_val, loc=loc).result + if fastmath is not None: + result = op_class(lhs_val, rhs_val, fastmath=fastmath, loc=loc).result + else: + result = op_class(lhs_val, rhs_val, loc=loc).result elif _is_integer_like_type(lhs_val.type): # Integer min/max (signed/unsigned logic could be tricky, default to signed for now) # TODO: Add unsigned support if needed diff --git a/flydsl/src/flydsl/dialects/ext/buffer_ops.py b/flydsl/src/flydsl/dialects/ext/buffer_ops.py index 84170e9b..c2c5e895 100644 --- a/flydsl/src/flydsl/dialects/ext/buffer_ops.py +++ b/flydsl/src/flydsl/dialects/ext/buffer_ops.py @@ -40,6 +40,25 @@ 'i32_select', ] +# ============================================================================= +# Constants for Hardware OOB (Out-of-Bounds) Handling +# ============================================================================= +# These values are chosen to match Triton's implementation for reliable hardware +# OOB detection in AMD buffer load/store operations. 
# =============================================================================
# Constants for Hardware OOB (Out-of-Bounds) Handling
# =============================================================================
# These values match Triton's implementation for reliable hardware OOB
# detection in AMD buffer load/store operations.
#
# Reference: triton/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp
#   - OOB_OFFSET = static_cast<uint32_t>(std::numeric_limits<int32_t>::max() + int64_t(1))
#   - numRecordsByte = std::numeric_limits<int32_t>::max() - 1
#
# How it works:
#   - When mask=False, the element offset is replaced with OOB_OFFSET (0x80000000)
#   - Hardware compares: if (offset >= num_records) -> return 0 (load) or ignore (store)
#   - 0x80000000 (as unsigned) = 2147483648 > 0x7FFFFFFE = 2147483646
#   - This guarantees hardware OOB detection triggers for masked-out elements
# =============================================================================
OOB_OFFSET = 0x80000000  # -2147483648 as signed i32, 2147483648 as unsigned
MAX_NUM_RECORDS = 0x7FFFFFFE  # 2147483646 (std::numeric_limits<int32_t>::max() - 1)
v = _unwrap_value(num_records_bytes) if not isinstance(v.type, ir.IntegerType) or v.type.width != 32: op = std_arith.IndexCastOp(ir.IntegerType.get_signless(32), v) v = _unwrap_value(op.result) num_records = v elif max_size: - # Use max for flexibility (hardware will check actual bounds) - # Note: flir's rocdl.make.buffer.rsrc requires i32, not i64 - num_records = _create_i32_constant(0xFFFFFFFF) # FALLBACK_MAX_SIZE + # Use MAX_NUM_RECORDS for flexibility with proper OOB handling. + # This value (0x7FFFFFFE) ensures that OOB_OFFSET (0x80000000) will + # always trigger hardware OOB detection. + num_records = _create_i32_constant(MAX_NUM_RECORDS) else: # Use the logical memref size (in bytes) for hardware OOB checking. nbytes = _num_records_from_memref_type() if nbytes is None: - # Fall back to max-size if we can't infer statically. - num_records = _create_i32_constant(0xFFFFFFFF) + # Fall back to MAX_NUM_RECORDS if we can't infer statically. + num_records = _create_i32_constant(MAX_NUM_RECORDS) else: - if nbytes > 0xFFFFFFFF: - nbytes = 0xFFFFFFFF + # Clamp to MAX_NUM_RECORDS for proper OOB handling with masks. + if nbytes > MAX_NUM_RECORDS: + nbytes = MAX_NUM_RECORDS num_records = _create_i32_constant(int(nbytes)) # Create resource descriptor (returns !llvm.ptr<8>) @@ -312,11 +338,13 @@ def buffer_load(rsrc: ir.Value, op = std_arith.MulIOp(offset, bytes_const) offset = _unwrap_value(op.result) - # Apply mask by setting invalid offsets to max + # Apply mask by setting invalid offsets to OOB_OFFSET + # When mask=False, offset becomes OOB_OFFSET (0x80000000), which is always + # >= MAX_NUM_RECORDS (0x7FFFFFFE), triggering hardware OOB (returns 0). 
if mask is not None: mask = _unwrap_value(mask) - max_offset = _create_i32_constant(0x7FFFFFFF) - op = std_arith.SelectOp(mask, offset, max_offset) + oob_offset = _create_i32_constant(OOB_OFFSET) + op = std_arith.SelectOp(mask, offset, oob_offset) offset = _unwrap_value(op.result) # Create vector type @@ -400,11 +428,13 @@ def buffer_store(data: ir.Value, op = std_arith.MulIOp(offset, bytes_const) offset = _unwrap_value(op.result) - # Apply mask by setting invalid offsets to max + # Apply mask by setting invalid offsets to OOB_OFFSET + # When mask=False, offset becomes OOB_OFFSET (0x80000000), which is always + # >= MAX_NUM_RECORDS (0x7FFFFFFE), triggering hardware OOB (store ignored). if mask is not None: mask = _unwrap_value(mask) - max_offset = _create_i32_constant(0x7FFFFFFF) - op = std_arith.SelectOp(mask, offset, max_offset) + oob_offset = _create_i32_constant(OOB_OFFSET) + op = std_arith.SelectOp(mask, offset, oob_offset) offset = _unwrap_value(op.result) # Create instruction offset (soffset) and aux flags diff --git a/flydsl/src/flydsl/dialects/ext/rocdl.py b/flydsl/src/flydsl/dialects/ext/rocdl.py index be82474d..4fad5bb7 100644 --- a/flydsl/src/flydsl/dialects/ext/rocdl.py +++ b/flydsl/src/flydsl/dialects/ext/rocdl.py @@ -17,6 +17,7 @@ from _mlir.dialects.rocdl import * # noqa: F401,F403 # Keep references to ODS-generated builders so we can wrap them without losing access. 
+_ods_mfma_f32_32x32x8f16 = globals().get("mfma_f32_32x32x8f16", None) _ods_mfma_f32_16x16x16f16 = mfma_f32_16x16x16f16 _ods_mfma_f32_16x16x16bf16_1k = globals().get("mfma_f32_16x16x16bf16_1k", None) _ods_mfma_f32_16x16x32_fp8_fp8 = mfma_f32_16x16x32_fp8_fp8 @@ -28,6 +29,8 @@ _ods_readlane = readlane _ods_readfirstlane = readfirstlane _ods_ds_swizzle = ds_swizzle +_ods_permlane16_swap = permlane16_swap +_ods_permlane32_swap = permlane32_swap _ods_raw_ptr_buffer_atomic_fadd = raw_ptr_buffer_atomic_fadd mask_mfma = 0x008 @@ -45,6 +48,61 @@ def sched_dswr(cnt): sched_group_barrier(mask_dswr, cnt, 0) +def _unwrap_i32_scalar(v, *, loc=None): + from _mlir.ir import IntegerType + from . import arith as _arith_ext + + return _arith_ext.unwrap(v, type=IntegerType.get_signless(32), loc=loc) + + +def async_global_load_to_lds(global_ptr, lds_ptr, size, offset=0, aux=0, *, loc=None, ip=None): + """Global->LDS async-style copy wrapper (closest stable ROCDL primitive).""" + from . import arith as _arith_ext + + return global_load_lds( + _arith_ext.unwrap(global_ptr, loc=loc), + _arith_ext.unwrap(lds_ptr, loc=loc), + _unwrap_i32_scalar(size, loc=loc), + _unwrap_i32_scalar(offset, loc=loc), + _unwrap_i32_scalar(aux, loc=loc), + loc=loc, + ip=ip, + ) + + +def async_load_to_lds(global_ptr, lds_ptr, size, offset=0, aux=0, *, loc=None, ip=None): + """Alias for load_to_lds with scalar auto-unwrapping.""" + from . import arith as _arith_ext + + return load_to_lds( + _arith_ext.unwrap(global_ptr, loc=loc), + _arith_ext.unwrap(lds_ptr, loc=loc), + _unwrap_i32_scalar(size, loc=loc), + _unwrap_i32_scalar(offset, loc=loc), + _unwrap_i32_scalar(aux, loc=loc), + loc=loc, + ip=ip, + ) + + +def async_load_fence(wait_vmem=0, wait_ds=0, *, loc=None, ip=None): + """Waitcnt-style fence helper for staged async copy scheduling.""" + # NOTE: wait_loadcnt/wait_dscnt lowerings are not stable on current toolchain. + # Use conservative full waitcnt fence for now. 
+ _ = (wait_vmem, wait_ds) + return s_waitcnt(0, loc=loc, ip=ip) + + +def phase_barrier(mask=0, *, loc=None, ip=None): + """Scheduling barrier wrapper used as phase fence in pipelined kernels.""" + return sched_barrier(mask, loc=loc, ip=ip) + + +def phase_group_barrier(mask, size, group_id=0, *, loc=None, ip=None): + """Group scheduling barrier wrapper used as phase fence in pipelined kernels.""" + return sched_group_barrier(mask, size, group_id, loc=loc, ip=ip) + + def _unwrap_mfma_operand(v, *, loc=None): """MFMA operands are MLIR Values; some trailing operands are i32 flags. @@ -68,6 +126,20 @@ def mfma_f32_16x16x16f16(result_type, operands, *, loc=None, ip=None): """Return the op result directly (no `.result` needed at call sites).""" return mfma_f32_16x16x16f16_op(result_type, operands, loc=loc, ip=ip).result + +def mfma_f32_32x32x8f16_op(result_type, operands, *, loc=None, ip=None): + """Return the op view (original behavior).""" + if _ods_mfma_f32_32x32x8f16 is None: + raise AttributeError("ROCDL op not found: mfma_f32_32x32x8f16") + ops = [_unwrap_mfma_operand(v, loc=loc) for v in operands] + return _ods_mfma_f32_32x32x8f16(result_type, ops, loc=loc, ip=ip) + + +def mfma_f32_32x32x8f16(result_type, operands, *, loc=None, ip=None): + """Return the op result directly (no `.result` needed at call sites).""" + return mfma_f32_32x32x8f16_op(result_type, operands, loc=loc, ip=ip).result + + # for bf16 version mfma def mfma_f32_16x16x16bf16_1k_op(result_type, operands, *, loc=None, ip=None): """Return the op view (original behavior).""" @@ -138,6 +210,73 @@ def ds_swizzle(result_type, src, offset, *, loc=None, ip=None): return _ods_ds_swizzle(result_type, _arith_ext.unwrap(src), _arith_ext.unwrap(offset), loc=loc, ip=ip) +def _unwrap_i32_lane_operand(v, *, loc=None): + from _mlir.ir import IntegerType + from . 
import arith as _arith_ext + + return _arith_ext.unwrap(v, type=IntegerType.get_signless(32), loc=loc) + + +def _permlane_i32x2_struct_type(): + from _mlir import ir as _ir + + # Some Python bindings accept optional spaces in LLVM type parser; keep both. + try: + return _ir.Type.parse("!llvm.struct<(i32, i32)>") + except Exception: + return _ir.Type.parse("!llvm.struct<(i32,i32)>") + + +def _extract_permlane_lane_i32(pair_val, *, loc=None, ip=None): + from _mlir.dialects import llvm as _llvm + from _mlir.ir import IntegerType + + i32 = IntegerType.get_signless(32) + return _llvm.extractvalue(i32, pair_val, [0], loc=loc, ip=ip) + + +def permlane16_swap_pair(old, src, fi=False, bound_control=False, *, loc=None, ip=None): + """High-level permlane16 swap wrapper returning the raw i32x2 struct.""" + return _ods_permlane16_swap( + _permlane_i32x2_struct_type(), + _unwrap_i32_lane_operand(old, loc=loc), + _unwrap_i32_lane_operand(src, loc=loc), + fi, + bound_control, + loc=loc, + ip=ip, + ) + + +def permlane16_swap_i32(old, src, fi=False, bound_control=False, *, loc=None, ip=None): + """High-level permlane16 swap wrapper returning the swapped i32 lane value.""" + pair_val = permlane16_swap_pair( + old, src, fi=fi, bound_control=bound_control, loc=loc, ip=ip + ) + return _extract_permlane_lane_i32(pair_val, loc=loc, ip=ip) + + +def permlane32_swap_pair(old, src, fi=False, bound_control=False, *, loc=None, ip=None): + """High-level permlane32 swap wrapper returning the raw i32x2 struct.""" + return _ods_permlane32_swap( + _permlane_i32x2_struct_type(), + _unwrap_i32_lane_operand(old, loc=loc), + _unwrap_i32_lane_operand(src, loc=loc), + fi, + bound_control, + loc=loc, + ip=ip, + ) + + +def permlane32_swap_i32(old, src, fi=False, bound_control=False, *, loc=None, ip=None): + """High-level permlane32 swap wrapper returning the swapped i32 lane value.""" + pair_val = permlane32_swap_pair( + old, src, fi=fi, bound_control=bound_control, loc=loc, ip=ip + ) + return 
_extract_permlane_lane_i32(pair_val, loc=loc, ip=ip) + + def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, ip=None): """Atomic fadd that accepts `ArithValue` / wrappers (no explicit `arith.unwrap(...)` needed). @@ -173,6 +312,7 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, 'barrier', 's_barrier', 's_barrier_signal', 's_barrier_wait', 's_waitcnt', 's_wait_loadcnt', 's_wait_storecnt', 's_wait_dscnt', 's_wait_expcnt', + 'async_load_fence', # Matrix operations - MFMA (Matrix Fused Multiply-Add) 'mfma_f32_32x32x8f16', 'mfma_f32_16x16x16f16', @@ -182,7 +322,7 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, 'mfma_i32_16x16x32_i8', 'mfma_scale_f32_16x16x128_f8f6f4', # Raw-op constructors (return op view) for the above - 'mfma_f32_16x16x16f16_op', 'mfma_f32_16x16x32_fp8_fp8_op', + 'mfma_f32_32x32x8f16_op', 'mfma_f32_16x16x16f16_op', 'mfma_f32_16x16x32_fp8_fp8_op', 'mfma_f32_16x16x16bf16_1k_op', 'mfma_i32_16x16x32_i8_op', 'mfma_scale_f32_16x16x128_f8f6f4_op', @@ -198,6 +338,8 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, # Shuffle and permutation 'ds_swizzle', 'ds_bpermute', 'permlanex16', 'permlane16_swap', 'permlane32_swap', + 'permlane16_swap_pair', 'permlane16_swap_i32', + 'permlane32_swap_pair', 'permlane32_swap_i32', 'readlane', 'readfirstlane', 'update_dpp', 'ballot', @@ -206,6 +348,7 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, 'raw_buffer_load', 'raw_buffer_store', 'raw_ptr_buffer_load', 'raw_ptr_buffer_store', 'load_to_lds', 'global_load_lds', + 'async_load_to_lds', 'async_global_load_to_lds', 'make_buffer_rsrc', # Atomic operations @@ -219,6 +362,8 @@ def raw_ptr_buffer_atomic_fadd(val, rsrc, voffset, soffset, cache, *, loc=None, # Scheduling and optimization 's_setprio', 's_sleep', 'sched_barrier', 'sched_group_barrier', + 'phase_barrier', 'phase_group_barrier', + 'sched_mfma', 
'sched_vmem', 'sched_dsrd', 'sched_dswr', 'iglp_opt', # Type conversions diff --git a/flydsl/src/flydsl/dialects/ext/scf.py b/flydsl/src/flydsl/dialects/ext/scf.py index 5e87dfb8..f11c31c3 100644 --- a/flydsl/src/flydsl/dialects/ext/scf.py +++ b/flydsl/src/flydsl/dialects/ext/scf.py @@ -20,6 +20,33 @@ from .arith import constant +def _as_value(v): + """Unwrap various 'Value-like' wrappers (ArithValue, etc.) to a raw MLIR Value. + + This is needed because generated op builders (like ``_scf.YieldOp``) check + ``isinstance(v, Value)`` which fails for ``ArithValue`` wrappers created by + ``register_value_caster``. + """ + seen = set() + while True: + if isinstance(v, Value): + return v + obj_id = id(v) + if obj_id in seen: + return v + seen.add(obj_id) + if hasattr(v, "_value"): + v = v._value + continue + if hasattr(v, "value"): + v = v.value + continue + if hasattr(v, "result"): + v = v.result + continue + return v + + def _normalize_if_condition(condition): """Best-effort normalization for scf.if conditions. @@ -183,34 +210,10 @@ def range_( start, stop, step = canonicalize_range(start, stop, step) - # Unwrap various "Value-like" wrappers down to a real `_mlir.ir.Value`. - # We need this because our arithmetic helpers often return wrapper objects - # (e.g. `ArithValue`) which are not accepted as operands by generated op - # builders (like `_scf.ForOp`). - def _as_value(v): - seen = set() - while True: - if isinstance(v, Value): - return v - obj_id = id(v) - if obj_id in seen: - return v - seen.add(obj_id) - if hasattr(v, "_value"): - v = v._value - continue - if hasattr(v, "value"): - v = v.value - continue - if hasattr(v, "result"): - v = v.result - continue - return v - start = _as_value(start) stop = _as_value(stop) step = _as_value(step) - + iter_args = iter_args or [] iter_args = [_as_value(a) for a in iter_args] for_op = _scf.ForOp(start, stop, step, iter_args, loc=loc, ip=ip) @@ -227,7 +230,8 @@ def _as_value(v): # Ensure scf.for body is terminated. 
block = for_op.body if (not block.operations) or not isinstance(block.operations[-1], _scf.YieldOp): - _scf.YieldOp(list(for_op.inner_iter_args)) + # Unwrap ArithValue wrappers before passing to YieldOp + _scf.YieldOp([_as_value(a) for a in for_op.inner_iter_args]) @contextmanager @@ -282,26 +286,6 @@ def for_( return start, stop, step = canonicalize_range(start, stop, step) - # Unwrap various "Value-like" wrappers down to a real `_mlir.ir.Value`. - def _as_value(v): - seen = set() - while True: - if isinstance(v, Value): - return v - obj_id = id(v) - if obj_id in seen: - return v - seen.add(obj_id) - if hasattr(v, "_value"): - v = v._value - continue - if hasattr(v, "value"): - v = v.value - continue - if hasattr(v, "result"): - v = v.result - continue - return v start = _as_value(start) stop = _as_value(stop) @@ -316,7 +300,8 @@ def _as_value(v): finally: block = for_op.body if (not block.operations) or not isinstance(block.operations[-1], _scf.YieldOp): - _scf.YieldOp(list(for_op.inner_iter_args)) + # Unwrap ArithValue wrappers before passing to YieldOp + _scf.YieldOp([_as_value(a) for a in for_op.inner_iter_args]) @contextmanager @@ -475,16 +460,20 @@ def yield_( ip: InsertionPoint = None, ): """Create an scf.yield operation. - + + Automatically unwraps ArithValue wrappers so callers don't need + ``arith.as_value()`` on every yielded operand. + Args: - operands: Values to yield - loc: Location for the operation - ip: Insertion point + operands: Values to yield (accepts ArithValue wrappers). + loc: Location for the operation. + ip: Insertion point. 
""" if loc is None: loc = Location.unknown() - + operands = operands or [] + operands = [_as_value(o) for o in operands] return _scf.YieldOp(operands, loc=loc, ip=ip) diff --git a/flydsl/src/flydsl/lang/ir/module.py b/flydsl/src/flydsl/lang/ir/module.py index a4e01498..a538ab0d 100644 --- a/flydsl/src/flydsl/lang/ir/module.py +++ b/flydsl/src/flydsl/lang/ir/module.py @@ -318,6 +318,21 @@ def wrapper(instance_self, *args, **kwargs): k.qualname = instance_self.GPU_MODULE_NAME except Exception: pass + # Set known_block_size if GPU_BLOCK_SIZE is defined on the module class. + # This tells convert-gpu-to-rocdl the workgroup size so that + # max_flat_workgroup_size is set correctly in the ISA. + block_size = getattr(instance_self, "GPU_BLOCK_SIZE", None) + if block_size is not None: + try: + if isinstance(block_size, int): + block_size = (block_size, 1, 1) + func_op = k._func_op if hasattr(k, "_func_op") else k + op = getattr(func_op, "operation", func_op) + op.attributes["known_block_size"] = ir.DenseI32ArrayAttr.get( + list(block_size) + ) + except Exception: + pass instance_self.kernel_func_op[fn.__name__] = k self._wrapper = wrapper diff --git a/kernels/flash_attn_func.py b/kernels/flash_attn_func.py new file mode 100644 index 00000000..cbe20899 --- /dev/null +++ b/kernels/flash_attn_func.py @@ -0,0 +1,1783 @@ +"""flash_attn_func kernel builder for FlyDSL. + +Aggressive flash_attn_func path: +- True MFMA32 remap: `mfma_f32_32x32x8f16` for both GEMM stages. +- Tile shape: BLOCK_M=128, BLOCK_N=32, 4 waves (256 threads). +- Per-wave Q rows: 32. +- GEMM1 uses `K @ Q^T` so S/P live in MFMA32 register layout. +- Online softmax over KV dimension is done in registers. +- P is kept in registers and fed directly to GEMM2 (`V^T @ P`) without LDS roundtrip. +- K and V^T use separate LDS regions (single-buffered per iteration). + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). 
+Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. +Block: (256,) -- 4 waves of 64 on AMD (wave64). + +Requires: head_dim % 32 == 0, head_dim >= 64, seq_len % 128 == 0. +""" + +import math +import os + +from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl +from flydsl.dialects.ext import vector as vec_ext +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext.scf import yield_ as scf_yield +from _mlir.dialects import memref as _memref +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator +from _mlir import ir +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_attn_func_kernel" + + +def select_flash_attn_func_path(num_heads, head_dim, causal=True, dtype_str="f16"): + """Select active flash_attn_func path tag for build-time specialization.""" + override = os.getenv("FLYDSL_FLASH_ATTN_FUNC_PATH", "auto").strip().lower() + if override in ("fallback", "fallback_n32", "n32"): + return "fallback_n32" + if override in ("fastpath", "ck_n128_fastpath", "n128"): + return "ck_n128_fastpath" + # Keep N128 path feature-gated by default due current occupancy/perf risk. + enable_n128 = os.getenv("FLYDSL_FLASH_ATTN_FUNC_ENABLE_N128", "0") == "1" + if ( + enable_n128 + and dtype_str == "f16" + and causal + and num_heads == 64 + and head_dim == 128 + ): + return "ck_n128_fastpath" + return "fallback_n32" + + +def build_flash_attn_func_module_primary( + num_heads, + head_dim, + causal=True, + dtype_str="f16", + sm_scale=None, +): + """Build a FlyDSL flash_attn_func module.""" + gpu_arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + + # Aggressive MFMA32 configuration for target B=1, H=64, S=8192, D=128. 
+ BLOCK_M = 128 + BLOCK_N = 32 + NUM_WAVES = 4 + WARP_SIZE = 64 + BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 + ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 32 + PATH_TAG = select_flash_attn_func_path( + num_heads, head_dim, causal=causal, dtype_str=dtype_str + ) + BLOCK_N_OUT = 128 if PATH_TAG == "ck_n128_fastpath" else BLOCK_N + N_SUBTILES = BLOCK_N_OUT // BLOCK_N + ENABLE_PREFETCH_3BUF = ( + os.getenv("FLYDSL_FLASH_ATTN_FUNC_ENABLE_PREFETCH3", "0") == "1" + ) + ENABLE_LDS_VEC16 = os.getenv("FLYDSL_FLASH_ATTN_FUNC_ENABLE_LDS_VEC16", "1") == "1" + REDUCE_MODE = os.getenv("FLYDSL_FLASH_ATTN_FUNC_REDUCE_MODE", "xor").strip().lower() + if REDUCE_MODE not in ("xor", "ds_bpermute"): + REDUCE_MODE = "xor" + NUM_PREFETCH_K = 3 if ENABLE_PREFETCH_3BUF else 1 + NUM_PREFETCH_V = 3 if ENABLE_PREFETCH_3BUF else 1 + CK_LDS_SEQ = (1, 2, 0, 1, 0, 1, 2, 0) if ENABLE_PREFETCH_3BUF else (0,) + + # MFMA32 K-dimension is 8. + K_STEP_QK = 8 + K_STEPS_QK = head_dim // K_STEP_QK + # PV stage computes 32 output columns per accumulator chunk. + D_CHUNK = 32 + D_CHUNKS = head_dim // D_CHUNK + PV_K_STEP = 8 + PV_K_STEPS = BLOCK_N // PV_K_STEP # 4 for BN=32 + + assert BLOCK_M % NUM_WAVES == 0 + assert head_dim % 32 == 0, f"head_dim ({head_dim}) must be divisible by 32" + assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" + assert dtype_str == "f16", "flash_attn_func currently only supports f16" + assert BLOCK_N % 32 == 0 + assert BLOCK_N_OUT % BLOCK_N == 0 + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + # Bank-conflict-friendly LDS strides. + K_STRIDE = HEAD_DIM + 2 + VT_STRIDE = BLOCK_N + 2 + + # Vectorized cooperative load constants. 
+ VEC_WIDTH = 16 if ENABLE_LDS_VEC16 else 8 + assert HEAD_DIM % VEC_WIDTH == 0 + THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH + assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 + ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD + + if ROWS_PER_BATCH_LOAD >= BLOCK_N: + NUM_BATCHES_KV = 1 + KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N + else: + assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD + KV_NEEDS_GUARD = False + + # K/VT circular buffers; defaults to 1/1, optional 3/3 with CK-like LDS sequence. + LDS_K_TILE_SIZE = BLOCK_N * K_STRIDE + LDS_VT_TILE_SIZE = HEAD_DIM * VT_STRIDE + LDS_K_TOTAL_SIZE = NUM_PREFETCH_K * LDS_K_TILE_SIZE + LDS_VT_BASE = LDS_K_TOTAL_SIZE + LDS_VT_TOTAL_SIZE = NUM_PREFETCH_V * LDS_VT_TILE_SIZE + LDS_KV_TOTAL_SIZE = LDS_K_TOTAL_SIZE + LDS_VT_TOTAL_SIZE + + allocator = SmemAllocator(None, arch=gpu_arch) + _state = {} + + class _FlashAttnFunc(flir.MlirModule): + GPU_MODULE_NAME = f"flash_attn_func_{dtype_str}_{PATH_TAG}" + KERNEL_VARIANT = PATH_TAG + GPU_MODULE_TARGETS = [f'#rocdl.target'] + + def init_gpu_module(self): + elem_type = T.f16() + _state["elem_type"] = elem_type + _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_TOTAL_SIZE) + allocator.finalize() + + @flir.kernel + def flash_attn_func_kernel( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + seq_len: lambda: T.index(), + ): + compute_type = T.f32() + elem_type = _state["elem_type"] + fm_fast = flir.arith.FastMathFlags.fast + + v4f16_type = ir.VectorType.get([4], elem_type) + vxf16_type = ir.VectorType.get([VEC_WIDTH], elem_type) + v16f32_type = ir.VectorType.get([16], compute_type) + + seq_len_v = arith.as_value(seq_len) + + # ---- LDS view ---- + base_ptr = allocator.get_base() + lds_kv = _state["lds_kv"](base_ptr).get() + + # ---- Thread / block 
indices ---- + block_id = flir.const_index(flir.block_idx("x")) + tid = flir.const_index(flir.thread_idx("x")) + + # ---- Wave decomposition ---- + c_ws = flir.const_index(WARP_SIZE) + wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result) + lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result) + + c32 = flir.const_index(32) + lane_mod_32 = arith.as_value(flir.arith.RemUIOp(lane, c32).result) + lane_div_32 = arith.as_value(flir.arith.DivUIOp(lane, c32).result) # 0/1 + + # ---- Wave offsets ---- + wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value + + # ---- Decompose block_id ---- + c_nh = flir.const_index(NUM_HEADS) + head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) + temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) + c_bm = flir.const_index(BLOCK_M) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) + q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) + batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) + q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value + + # ---- Cooperative load decomposition ---- + c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) + load_row_in_batch = arith.as_value(flir.arith.DivUIOp(tid, c_tpr).result) + load_lane_in_row = arith.as_value(flir.arith.RemUIOp(tid, c_tpr).result) + load_col_base = (arith.ArithValue(load_lane_in_row) * VEC_WIDTH).value + + # ---- Helper: global flat index ---- + def global_idx(token_idx, col): + token = ( + arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) + + arith.ArithValue(token_idx) + ) + return ( + token * STRIDE_TOKEN + + arith.ArithValue(head_idx) * HEAD_DIM + + arith.ArithValue(col) + ).value + + def k_buf_base(buf_id): + return flir.const_index(buf_id * LDS_K_TILE_SIZE) + + def vt_buf_base(buf_id): + return flir.const_index(LDS_VT_BASE + buf_id * LDS_VT_TILE_SIZE) + + # ---- Cooperative K load (row-major, padded stride) ---- + def coop_load_k(tile_start, 
buf_id=0): + k_base = k_buf_base(buf_id) + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(vxf16_type, K, [g_idx])) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(k_base) + + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(vxf16_type, K, [g_idx])) + lds_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + lds_idx = ( + arith.ArithValue(k_base) + + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + + # ---- Cooperative V load (transposed, padded stride) ---- + def coop_load_v_transposed(tile_start, buf_id=0): + vt_base = vt_buf_base(buf_id) + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(vxf16_type, V, [g_idx])) + load_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + for e in range_constexpr(VEC_WIDTH): + elem = 
arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(vt_base) + + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(vxf16_type, V, [g_idx])) + load_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(vt_base) + + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + + # ---- Preload Q^T B-operand packs once (register-resident) ---- + # B operand uses j = lane_mod_32, k-subblock = lane_div_32*4. + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_mod_32) + ).value + q_b_packs = [] + for ks in range_constexpr(K_STEPS_QK): + q_col = ( + flir.const_index(ks * K_STEP_QK) + + arith.ArithValue(lane_div_32) * 4 + ).value + g_idx = global_idx(q_row, q_col) + q_b_packs.append(arith.as_value(vec_ext.load_op(v4f16_type, Q, [g_idx]))) + + # ---- Constants ---- + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_one_f = arith.constant(1.0, type=compute_type) + c_sm_scale = arith.constant(sm_scale, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + c_zero_v16f32 = arith.as_value(arith.constant_vector(0.0, v16f32_type)) + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + shuf_32_i32 = arith.as_value(arith.constant(32, type=T.i32())) + c4_i32 = arith.as_value(arith.constant(4, type=T.i32())) + lane_i32 = 
arith.as_value(flir.arith.IndexCastOp(T.i32(), lane).result) + lane_xor_32_i32 = arith.as_value(flir.arith.XOrIOp(lane_i32, shuf_32_i32).result) + lane_xor_32_byte = arith.as_value( + flir.arith.MulIOp(lane_xor_32_i32, c4_i32).result + ) + + def reduction_peer(v_f32): + if REDUCE_MODE == "ds_bpermute": + v_i32 = arith.as_value(flir.arith.bitcast(T.i32(), v_f32)) + peer_i32 = arith.as_value( + rocdl.ds_bpermute(T.i32(), lane_xor_32_byte, v_i32) + ) + return arith.as_value(flir.arith.bitcast(compute_type, peer_i32)) + return arith.as_value( + gpu.ShuffleOp(v_f32, shuf_32_i32, width_i32, mode="xor").shuffleResult + ) + + # ---- KV loop upper bound ---- + if CAUSAL: + kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value + else: + kv_upper = seq_len_v + + # Loop-carried: [m_old, l_old, o_acc_chunks...] + init_args = [arith.as_value(c_neg_inf), arith.as_value(c_zero_f)] + for _ in range_constexpr(D_CHUNKS): + init_args.append(c_zero_v16f32) + + with scf.for_(0, kv_upper, BLOCK_N_OUT, iter_args=init_args) as loop: + kv_block_start = arith.as_value(loop.induction_variable) + m_running = arith.as_value(loop.inner_iter_args[0]) + l_running = arith.as_value(loop.inner_iter_args[1]) + o_accs = [arith.as_value(loop.inner_iter_args[2 + i]) for i in range_constexpr(D_CHUNKS)] + preload_k_count = NUM_PREFETCH_K if NUM_PREFETCH_K < N_SUBTILES else N_SUBTILES + + if ENABLE_PREFETCH_3BUF: + for pre_k in range_constexpr(preload_k_count): + pre_k_slot = CK_LDS_SEQ[pre_k % len(CK_LDS_SEQ)] % NUM_PREFETCH_K + pre_k_start = (arith.ArithValue(kv_block_start) + pre_k * BLOCK_N).value + coop_load_k(pre_k_start, pre_k_slot) + rocdl.phase_group_barrier(rocdl.mask_vmem_rd, 1, 0) + gpu.barrier() + + for kv_sub in range_constexpr(N_SUBTILES): + kv_start = (arith.ArithValue(kv_block_start) + kv_sub * BLOCK_N).value + + if ENABLE_PREFETCH_3BUF: + k_slot = CK_LDS_SEQ[kv_sub % len(CK_LDS_SEQ)] % NUM_PREFETCH_K + else: + k_slot = 0 + # ==== Cooperative K load -> LDS_KV ==== + coop_load_k(kv_start, 
k_slot) + gpu.barrier() + k_base = k_buf_base(k_slot) + + # ==== GEMM1: S = K @ Q^T (MFMA32), S in v16f32 ==== + s_acc = c_zero_v16f32 + for ks in range_constexpr(K_STEPS_QK): + k_idx = ( + arith.ArithValue(k_base) + + arith.ArithValue(lane_mod_32) * K_STRIDE + + ks * K_STEP_QK + + arith.ArithValue(lane_div_32) * 4 + ).value + k_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [k_idx])) + q_pack = q_b_packs[ks] + s_acc = arith.as_value( + rocdl.mfma_f32_32x32x8f16( + v16f32_type, [k_pack, q_pack, s_acc, 0, 0, 0] + ) + ) + + # ==== Online softmax over KV dimension (register only) ==== + q_row_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), q_row).result + ) + + s_vals = [] + for r in range_constexpr(16): + s_val = arith.as_value( + vec_ext.extract(s_acc, static_position=[r], dynamic_position=[]) + ) + s_val = arith.as_value( + flir.arith.MulFOp( + s_val, arith.as_value(c_sm_scale), fastmath=fm_fast + ).result + ) + if CAUSAL: + kv_row_rel = ( + arith.ArithValue(lane_div_32) * 4 + + (r // 4) * 8 + + (r % 4) + ).value + kv_col = ( + arith.ArithValue(kv_start) + arith.ArithValue(kv_row_rel) + ).value + kv_col_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), kv_col).result + ) + is_masked = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64 + ).result + ) + s_val = arith.as_value( + flir.arith.SelectOp( + is_masked, arith.as_value(c_neg_inf), s_val + ).result + ) + s_vals.append(s_val) + + local_max = s_vals[0] + for r in range_constexpr(15): + local_max = arith.as_value( + flir.arith.MaximumFOp(local_max, s_vals[r + 1]).result + ) + peer_max = reduction_peer(local_max) + row_max = arith.as_value( + flir.arith.MaximumFOp(local_max, peer_max).result + ) + m_new = arith.as_value( + flir.arith.MaximumFOp(m_running, row_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_running, m_new, fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( + flir.arith.MulFOp( + diff_m, 
arith.as_value(c_log2e), fastmath=fm_fast + ).result + ) + corr = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) + + p_vals = [] + local_sum = arith.as_value(c_zero_f) + for r in range_constexpr(16): + diff = arith.as_value( + flir.arith.SubFOp(s_vals[r], m_new, fastmath=fm_fast).result + ) + diff_s = arith.as_value( + flir.arith.MulFOp( + diff, arith.as_value(c_log2e), fastmath=fm_fast + ).result + ) + p = arith.as_value(flir.math.exp2(diff_s, fastmath=fm_fast)) + p_vals.append(p) + local_sum = arith.as_value( + flir.arith.AddFOp(local_sum, p, fastmath=fm_fast).result + ) + + peer_sum = reduction_peer(local_sum) + tile_sum = arith.as_value( + flir.arith.AddFOp(local_sum, peer_sum, fastmath=fm_fast).result + ) + l_corr = arith.as_value( + flir.arith.MulFOp(corr, l_running, fastmath=fm_fast).result + ) + l_new = arith.as_value( + flir.arith.AddFOp(l_corr, tile_sum, fastmath=fm_fast).result + ) + + # ==== Rescale O accumulators ==== + corr_vec = arith.as_value(vec_ext.broadcast(v16f32_type, corr)) + for dc in range_constexpr(D_CHUNKS): + o_accs[dc] = arith.as_value( + flir.arith.MulFOp(o_accs[dc], corr_vec, fastmath=fm_fast).result + ) + + if ENABLE_PREFETCH_3BUF and (kv_sub + preload_k_count) < N_SUBTILES: + next_k_sub = kv_sub + preload_k_count + next_k_start = ( + arith.ArithValue(kv_block_start) + next_k_sub * BLOCK_N + ).value + next_k_slot = CK_LDS_SEQ[next_k_sub % len(CK_LDS_SEQ)] % NUM_PREFETCH_K + coop_load_k(next_k_start, next_k_slot) + + if ENABLE_PREFETCH_3BUF: + v_slot = CK_LDS_SEQ[kv_sub % len(CK_LDS_SEQ)] % NUM_PREFETCH_V + else: + v_slot = 0 + v_base = vt_buf_base(v_slot) + + # ==== Load V^T for current tile into LDS_KV ==== + coop_load_v_transposed(kv_start, v_slot) + rocdl.phase_group_barrier(rocdl.mask_dswr, 1, 0) + gpu.barrier() + + # ==== Build P packs in MFMA32 B-input format from register S ==== + p_f16 = [] + for r in range_constexpr(16): + p_f16.append( + arith.as_value(flir.arith.TruncFOp(elem_type, p_vals[r]).result) + ) + 
p_packs = [] + for pks in range_constexpr(PV_K_STEPS): + p_base = pks * 4 + p_packs.append( + arith.as_value( + vec_ext.from_elements( + v4f16_type, + [ + p_f16[p_base + 0], + p_f16[p_base + 1], + p_f16[p_base + 2], + p_f16[p_base + 3], + ], + ) + ) + ) + + # ==== GEMM2: O^T += V^T @ P (MFMA32) ==== + for dc in range_constexpr(D_CHUNKS): + for pks in range_constexpr(PV_K_STEPS): + v_idx = ( + arith.ArithValue(v_base) + + (dc * D_CHUNK + arith.ArithValue(lane_mod_32)) * VT_STRIDE + + pks * PV_K_STEP + + arith.ArithValue(lane_div_32) * 4 + ).value + v_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [v_idx])) + o_accs[dc] = arith.as_value( + rocdl.mfma_f32_32x32x8f16( + v16f32_type, [v_pack, p_packs[pks], o_accs[dc], 0, 0, 0] + ) + ) + + m_running = m_new + l_running = l_new + + yield_args = [m_running, l_running] + o_accs + scf_yield(yield_args) + + # ---- Normalize and store O ---- + l_final = arith.as_value(loop.results[1]) + o_finals = [arith.as_value(loop.results[2 + dc]) for dc in range_constexpr(D_CHUNKS)] + + inv_l = arith.as_value( + flir.arith.DivFOp(arith.as_value(c_one_f), l_final, fastmath=fm_fast).result + ) + inv_l_vec = arith.as_value(vec_ext.broadcast(v16f32_type, inv_l)) + + for dc in range_constexpr(D_CHUNKS): + o_norm_vec = arith.as_value( + flir.arith.MulFOp(o_finals[dc], inv_l_vec, fastmath=fm_fast).result + ) + for r in range_constexpr(16): + o_val = arith.as_value( + vec_ext.extract(o_norm_vec, static_position=[r], dynamic_position=[]) + ) + o_f16 = arith.as_value(flir.arith.TruncFOp(elem_type, o_val).result) + + d_row_rel = ( + arith.ArithValue(lane_div_32) * 4 + + (r // 4) * 8 + + (r % 4) + ).value + d_col = (flir.const_index(dc * D_CHUNK) + arith.ArithValue(d_row_rel)).value + o_global = global_idx(q_row, d_col) + _memref.StoreOp(o_f16, O, [o_global]) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, 
_state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + batch_size: lambda: T.index(), + seq_len: lambda: T.index(), + ): + c1 = arith.as_value(flir.arith_ext.index(1)) + c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) + c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) + bs_val = arith.as_value(batch_size) + sl_val = arith.as_value(seq_len) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(sl_val, c_bm).result) + bs_qt = arith.as_value(flir.arith.MulIOp(bs_val, num_q_tiles).result) + grid_x = arith.as_value(flir.arith.MulIOp(bs_qt, c_nh).result) + bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) + flir.gpu_ext.LaunchFuncOp( + [self.GPU_MODULE_NAME, KERNEL_NAME], + grid_size=(grid_x, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, seq_len], + ) + + return _FlashAttnFunc() + + +build_flash_attn_func_module = build_flash_attn_func_module_primary +"""flash_attn_func kernel builder for FlyDSL. + +Aggressive flash_attn_func path: +- True MFMA32 remap: `mfma_f32_32x32x8f16` for both GEMM stages. +- Tile shape: BLOCK_M=128, BLOCK_N=32, 4 waves (256 threads). +- Per-wave Q rows: 32. +- GEMM1 uses `K @ Q^T` so S/P live in MFMA32 register layout. +- Online softmax over KV dimension is done in registers. +- P is kept in registers and fed directly to GEMM2 (`V^T @ P`) without LDS roundtrip. +- K uses LDS ping-pong prefetch between adjacent iterations. + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). +Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. +Block: (256,) -- 4 waves of 64 on AMD (wave64). + +Requires: head_dim % 32 == 0, head_dim >= 64, seq_len % 128 == 0. 
+""" + +import math + +from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl +from flydsl.dialects.ext import vector as vec_ext +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext.scf import yield_ as scf_yield +from _mlir.dialects import memref as _memref +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator +from _mlir import ir +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_attn_func_kernel" + + +def _legacy_copy_build_flash_attn_func_module_2( + num_heads, + head_dim, + causal=True, + dtype_str="f16", + sm_scale=None, +): + """Build a FlyDSL flash_attn_func module.""" + gpu_arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + + # Aggressive MFMA32 configuration for target B=1, H=64, S=8192, D=128. + BLOCK_M = 256 + BLOCK_N = 32 + NUM_WAVES = 8 + WARP_SIZE = 64 + BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 + ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 32 + + # MFMA32 K-dimension is 8. + K_STEP_QK = 8 + K_STEPS_QK = head_dim // K_STEP_QK + # PV stage computes 32 output columns per accumulator chunk. + D_CHUNK = 32 + D_CHUNKS = head_dim // D_CHUNK + PV_K_STEP = 8 + PV_K_STEPS = BLOCK_N // PV_K_STEP # 4 for BN=32 + + assert BLOCK_M % NUM_WAVES == 0 + assert head_dim % 32 == 0, f"head_dim ({head_dim}) must be divisible by 32" + assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" + assert dtype_str == "f16", "flash_attn_func currently only supports f16" + assert BLOCK_N % 32 == 0 + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + # Bank-conflict-friendly LDS strides. + K_STRIDE = HEAD_DIM + 2 + VT_STRIDE = BLOCK_N + 2 + + # Vectorized cooperative load constants. 
+    VEC_WIDTH = 8
+    THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH
+    assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0
+    ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD
+
+    if ROWS_PER_BATCH_LOAD >= BLOCK_N:
+        NUM_BATCHES_KV = 1
+        KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N
+    else:
+        assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0
+        NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD
+        KV_NEEDS_GUARD = False
+
+    # Two KV buffers for K ping-pong prefetch.
+    LDS_KV_BUF_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE)
+    LDS_KV_TOTAL_SIZE = 2 * LDS_KV_BUF_SIZE
+
+    allocator = SmemAllocator(None, arch=gpu_arch)
+    _state = {}
+
+    class _FlashAttnFunc(flir.MlirModule):
+        GPU_MODULE_NAME = f"flash_attn_func_{dtype_str}"
+        GPU_MODULE_TARGETS = [f'#rocdl.target<chip = "{gpu_arch}">']
+
+        def init_gpu_module(self):
+            elem_type = T.f16()
+            _state["elem_type"] = elem_type
+            _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_TOTAL_SIZE)
+            allocator.finalize()
+
+        @flir.kernel
+        def flash_attn_func_kernel(
+            self: flir.T.i64,
+            Q: lambda: T.memref(DYN, _state["elem_type"]),
+            K: lambda: T.memref(DYN, _state["elem_type"]),
+            V: lambda: T.memref(DYN, _state["elem_type"]),
+            O: lambda: T.memref(DYN, _state["elem_type"]),
+            seq_len: lambda: T.index(),
+        ):
+            compute_type = T.f32()
+            elem_type = _state["elem_type"]
+            fm_fast = flir.arith.FastMathFlags.fast
+
+            v4f16_type = ir.VectorType.get([4], elem_type)
+            v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type)
+            v16f32_type = ir.VectorType.get([16], compute_type)
+
+            seq_len_v = arith.as_value(seq_len)
+
+            # ---- LDS view ----
+            base_ptr = allocator.get_base()
+            lds_kv = _state["lds_kv"](base_ptr).get()
+
+            # ---- Thread / block indices ----
+            block_id = flir.const_index(flir.block_idx("x"))
+            tid = flir.const_index(flir.thread_idx("x"))
+
+            # ---- Wave decomposition ----
+            c_ws = flir.const_index(WARP_SIZE)
+            wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result)
+            lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result)
+
+            c32 = 
flir.const_index(32) + lane_mod_32 = arith.as_value(flir.arith.RemUIOp(lane, c32).result) + lane_div_32 = arith.as_value(flir.arith.DivUIOp(lane, c32).result) # 0/1 + + # ---- Wave offsets ---- + wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value + + # ---- Decompose block_id ---- + c_nh = flir.const_index(NUM_HEADS) + head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) + temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) + c_bm = flir.const_index(BLOCK_M) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) + q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) + batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) + q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value + + # ---- Cooperative load decomposition ---- + c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) + load_row_in_batch = arith.as_value(flir.arith.DivUIOp(tid, c_tpr).result) + load_lane_in_row = arith.as_value(flir.arith.RemUIOp(tid, c_tpr).result) + load_col_base = (arith.ArithValue(load_lane_in_row) * VEC_WIDTH).value + + # ---- Helper: global flat index ---- + def global_idx(token_idx, col): + token = ( + arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) + + arith.ArithValue(token_idx) + ) + return ( + token * STRIDE_TOKEN + + arith.ArithValue(head_idx) * HEAD_DIM + + arith.ArithValue(col) + ).value + + def kv_buf_base(buf_idx): + return (arith.ArithValue(buf_idx) * LDS_KV_BUF_SIZE).value + + # ---- Cooperative K load (row-major, padded stride) ---- + def coop_load_k(tile_start, buf_idx): + base = kv_buf_base(buf_idx) + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + 
arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(base) + + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + lds_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + lds_idx = ( + arith.ArithValue(base) + + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + + # ---- Cooperative V load (transposed, padded stride) ---- + def coop_load_v_transposed(tile_start, buf_idx): + base = kv_buf_base(buf_idx) + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + load_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(base) + + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = 
arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + load_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(base) + + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + + # ---- Preload Q^T B-operand packs once (register-resident) ---- + # B operand uses j = lane_mod_32, k-subblock = lane_div_32*4. + q_row = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_mod_32) + ).value + q_b_packs = [] + for ks in range_constexpr(K_STEPS_QK): + q_col = ( + flir.const_index(ks * K_STEP_QK) + + arith.ArithValue(lane_div_32) * 4 + ).value + g_idx = global_idx(q_row, q_col) + q_b_packs.append(arith.as_value(vec_ext.load_op(v4f16_type, Q, [g_idx]))) + + # ---- Constants ---- + c0_idx = flir.const_index(0) + c1_idx = flir.const_index(1) + c_neg_inf = arith.constant(float("-inf"), type=compute_type) + c_zero_f = arith.constant(0.0, type=compute_type) + c_one_f = arith.constant(1.0, type=compute_type) + c_sm_scale = arith.constant(sm_scale, type=compute_type) + c_log2e = arith.constant(1.4426950408889634, type=compute_type) + c_zero_v16f32 = arith.as_value(arith.constant_vector(0.0, v16f32_type)) + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + shuf_32_i32 = arith.as_value(arith.constant(32, type=T.i32())) + + # ---- KV loop upper bound ---- + if CAUSAL: + kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value + else: + kv_upper = seq_len_v + + # ---- K ping-pong preload: K(0) -> buf0 ---- + coop_load_k(c0_idx, c0_idx) + gpu.barrier() + + # Loop-carried: [cur_k_buf, m_old, l_old, o_acc_chunks...] 
+ init_args = [arith.as_value(c0_idx), arith.as_value(c_neg_inf), arith.as_value(c_zero_f)] + for _ in range_constexpr(D_CHUNKS): + init_args.append(c_zero_v16f32) + + with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop: + kv_start = arith.as_value(loop.induction_variable) + cur_k_buf = arith.as_value(loop.inner_iter_args[0]) + m_old = arith.as_value(loop.inner_iter_args[1]) + l_old = arith.as_value(loop.inner_iter_args[2]) + o_accs = [arith.as_value(loop.inner_iter_args[3 + i]) for i in range_constexpr(D_CHUNKS)] + + next_k_buf = arith.as_value( + flir.arith.SubIOp(c1_idx, arith.ArithValue(cur_k_buf).value).result + ) + cur_base = kv_buf_base(cur_k_buf) + + # ==== GEMM1: S = K @ Q^T (MFMA32), S in v16f32 ==== + s_acc = c_zero_v16f32 + for ks in range_constexpr(K_STEPS_QK): + k_idx = ( + arith.ArithValue(cur_base) + + arith.ArithValue(lane_mod_32) * K_STRIDE + + ks * K_STEP_QK + + arith.ArithValue(lane_div_32) * 4 + ).value + k_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [k_idx])) + q_pack = q_b_packs[ks] + s_acc = arith.as_value( + rocdl.mfma_f32_32x32x8f16( + v16f32_type, [k_pack, q_pack, s_acc, 0, 0, 0] + ) + ) + + # ==== Online softmax over KV dimension (register only) ==== + q_row_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), q_row).result + ) + + s_vals = [] + for r in range_constexpr(16): + s_val = arith.as_value( + vec_ext.extract(s_acc, static_position=[r], dynamic_position=[]) + ) + s_val = arith.as_value( + flir.arith.MulFOp( + s_val, arith.as_value(c_sm_scale), fastmath=fm_fast + ).result + ) + + if CAUSAL: + kv_row_rel = ( + arith.ArithValue(lane_div_32) * 4 + + (r // 4) * 8 + + (r % 4) + ).value + kv_col = (arith.ArithValue(kv_start) + arith.ArithValue(kv_row_rel)).value + kv_col_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), kv_col).result + ) + is_masked = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64 + ).result + ) + s_val = arith.as_value( + 
flir.arith.SelectOp( + is_masked, arith.as_value(c_neg_inf), s_val + ).result + ) + s_vals.append(s_val) + + local_max = s_vals[0] + for r in range_constexpr(15): + local_max = arith.as_value( + flir.arith.MaximumFOp(local_max, s_vals[r + 1]).result + ) + peer_max = arith.as_value( + gpu.ShuffleOp(local_max, shuf_32_i32, width_i32, mode="xor").shuffleResult + ) + row_max = arith.as_value( + flir.arith.MaximumFOp(local_max, peer_max).result + ) + m_new = arith.as_value( + flir.arith.MaximumFOp(m_old, row_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_old, m_new, fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( + flir.arith.MulFOp(diff_m, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + corr = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) + + p_vals = [] + local_sum = arith.as_value(c_zero_f) + for r in range_constexpr(16): + diff = arith.as_value( + flir.arith.SubFOp(s_vals[r], m_new, fastmath=fm_fast).result + ) + diff_s = arith.as_value( + flir.arith.MulFOp(diff, arith.as_value(c_log2e), fastmath=fm_fast).result + ) + p = arith.as_value(flir.math.exp2(diff_s, fastmath=fm_fast)) + p_vals.append(p) + local_sum = arith.as_value( + flir.arith.AddFOp(local_sum, p, fastmath=fm_fast).result + ) + + peer_sum = arith.as_value( + gpu.ShuffleOp(local_sum, shuf_32_i32, width_i32, mode="xor").shuffleResult + ) + tile_sum = arith.as_value( + flir.arith.AddFOp(local_sum, peer_sum, fastmath=fm_fast).result + ) + l_corr = arith.as_value( + flir.arith.MulFOp(corr, l_old, fastmath=fm_fast).result + ) + l_new = arith.as_value( + flir.arith.AddFOp(l_corr, tile_sum, fastmath=fm_fast).result + ) + + # ==== Rescale O accumulators ==== + corr_vec = arith.as_value(vec_ext.broadcast(v16f32_type, corr)) + for dc in range_constexpr(D_CHUNKS): + o_accs[dc] = arith.as_value( + flir.arith.MulFOp(o_accs[dc], corr_vec, fastmath=fm_fast).result + ) + + # All waves must finish K reads before reusing current buffer for V. 
+ gpu.barrier() + + # ==== Load V^T for current tile into current buffer ==== + coop_load_v_transposed(kv_start, cur_k_buf) + + # ==== Prefetch next K tile into next buffer (if exists) ==== + next_kv_start = (arith.ArithValue(kv_start) + BLOCK_N).value + has_next = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + next_kv_start, + kv_upper, + ).result + ) + with scf.if_(has_next): + coop_load_k(next_kv_start, next_k_buf) + + # Synchronize V current + K next visibility. + gpu.barrier() + + # ==== Build P packs in MFMA32 B-input format from register S ==== + p_f16 = [] + for r in range_constexpr(16): + p_f16.append( + arith.as_value(flir.arith.TruncFOp(elem_type, p_vals[r]).result) + ) + p_packs = [] + for pks in range_constexpr(PV_K_STEPS): + p_base = pks * 4 + p_packs.append( + arith.as_value( + vec_ext.from_elements( + v4f16_type, + [ + p_f16[p_base + 0], + p_f16[p_base + 1], + p_f16[p_base + 2], + p_f16[p_base + 3], + ], + ) + ) + ) + + # ==== GEMM2: O^T += V^T @ P (MFMA32) ==== + for dc in range_constexpr(D_CHUNKS): + for pks in range_constexpr(PV_K_STEPS): + v_idx = ( + arith.ArithValue(cur_base) + + (dc * D_CHUNK + arith.ArithValue(lane_mod_32)) * VT_STRIDE + + pks * PV_K_STEP + + arith.ArithValue(lane_div_32) * 4 + ).value + v_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [v_idx])) + o_accs[dc] = arith.as_value( + rocdl.mfma_f32_32x32x8f16( + v16f32_type, [v_pack, p_packs[pks], o_accs[dc], 0, 0, 0] + ) + ) + + # No trailing barrier: current buffer is only reused after one full iteration gap. 
+ yield_args = [next_k_buf, m_new, l_new] + o_accs + scf_yield(yield_args) + + # ---- Normalize and store O ---- + l_final = arith.as_value(loop.results[2]) + o_finals = [arith.as_value(loop.results[3 + dc]) for dc in range_constexpr(D_CHUNKS)] + + inv_l = arith.as_value( + flir.arith.DivFOp(arith.as_value(c_one_f), l_final, fastmath=fm_fast).result + ) + inv_l_vec = arith.as_value(vec_ext.broadcast(v16f32_type, inv_l)) + + for dc in range_constexpr(D_CHUNKS): + o_norm_vec = arith.as_value( + flir.arith.MulFOp(o_finals[dc], inv_l_vec, fastmath=fm_fast).result + ) + for r in range_constexpr(16): + o_val = arith.as_value( + vec_ext.extract(o_norm_vec, static_position=[r], dynamic_position=[]) + ) + o_f16 = arith.as_value(flir.arith.TruncFOp(elem_type, o_val).result) + + d_row_rel = ( + arith.ArithValue(lane_div_32) * 4 + + (r // 4) * 8 + + (r % 4) + ).value + d_col = (flir.const_index(dc * D_CHUNK) + arith.ArithValue(d_row_rel)).value + o_global = global_idx(q_row, d_col) + _memref.StoreOp(o_f16, O, [o_global]) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + batch_size: lambda: T.index(), + seq_len: lambda: T.index(), + ): + c1 = arith.as_value(flir.arith_ext.index(1)) + c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) + c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) + bs_val = arith.as_value(batch_size) + sl_val = arith.as_value(seq_len) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(sl_val, c_bm).result) + bs_qt = arith.as_value(flir.arith.MulIOp(bs_val, num_q_tiles).result) + grid_x = arith.as_value(flir.arith.MulIOp(bs_qt, c_nh).result) + bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) + flir.gpu_ext.LaunchFuncOp( + [self.GPU_MODULE_NAME, KERNEL_NAME], + grid_size=(grid_x, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, 
seq_len], + ) + + return _FlashAttnFunc() +"""flash_attn_func kernel builder for FlyDSL. + +flash_attn_func design (CK-aligned direction, rewritten from V4.3): +- CK-aligned baseline tile family: BLOCK_M=64, BLOCK_N=32. +- Q loaded once from global memory into MFMA A-operand packs (register-resident). +- K/V streamed tile-by-tile through LDS. +- Online softmax in fp32 over 32 positions per iteration (2x 16-column groups). +- Causal early-exit keeps KV upper bound at q_start + BLOCK_M. + +Layout: Q/K/V/O are 1D flattened from BSHD (batch, seq_len, num_heads, head_dim). +Grid: (batch * num_q_tiles * num_heads,) where num_q_tiles = seq_len / BLOCK_M. +Block: (256,) -- 4 waves of 64 on AMD (wave64). + +Requires: head_dim % 16 == 0, head_dim >= 64, seq_len % 64 == 0. +""" + +import math + +from flydsl.dialects.ext import flir, arith, gpu, scf, rocdl +from flydsl.dialects.ext import vector as vec_ext +from flydsl.dialects.ext.python_control_flow import range_constexpr +from flydsl.dialects.ext.scf import yield_ as scf_yield +from _mlir.dialects import memref as _memref +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils import SmemAllocator +from _mlir import ir +import _mlir.extras.types as T + + +KERNEL_NAME = "flash_attn_func_kernel" + + +def _legacy_copy_build_flash_attn_func_module_3( + num_heads, + head_dim, + causal=True, + dtype_str="f16", + sm_scale=None, +): + """Build a FlyDSL flash_attn_func module. + + Args: + num_heads: Number of attention heads. + head_dim: Dimension per head (must be divisible by 16, >= 64). + causal: Whether to apply causal mask. + dtype_str: "f16" (bf16 not yet supported). + sm_scale: Softmax scale (default: 1/sqrt(head_dim)). + + Returns: + MlirModule compilable via ``flydsl.compile(module)``. + """ + gpu_arch = get_hip_arch() + DYN = ir.ShapedType.get_dynamic_size() + + # CK-oriented direction for the target (B=1, H=64, S=8192, D=128). 
+ BLOCK_M = 64 + BLOCK_N = 32 + NUM_WAVES = 4 + WARP_SIZE = 64 + BLOCK_SIZE = NUM_WAVES * WARP_SIZE # 256 + ROWS_PER_WAVE = BLOCK_M // NUM_WAVES # 16 + K_STEPS = head_dim // 16 + N_MFMA = BLOCK_N // 16 # 2 + + assert BLOCK_M % NUM_WAVES == 0 + assert head_dim % 16 == 0, f"head_dim ({head_dim}) must be divisible by 16" + assert head_dim >= 64, f"head_dim ({head_dim}) must be >= 64" + assert dtype_str == "f16", "flash_attn_func currently only supports f16" + assert BLOCK_N % 16 == 0, f"BLOCK_N ({BLOCK_N}) must be divisible by 16" + + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + + NUM_HEADS = num_heads + HEAD_DIM = head_dim + CAUSAL = causal + STRIDE_TOKEN = NUM_HEADS * HEAD_DIM + + # ---- Bank-conflict-friendly LDS strides ---- + K_STRIDE = HEAD_DIM + 2 + VT_STRIDE = BLOCK_N + 2 + + # ---- Vectorized cooperative load constants ---- + VEC_WIDTH = 8 + THREADS_PER_ROW_LOAD = HEAD_DIM // VEC_WIDTH + assert BLOCK_SIZE % THREADS_PER_ROW_LOAD == 0 + ROWS_PER_BATCH_LOAD = BLOCK_SIZE // THREADS_PER_ROW_LOAD + + assert BLOCK_M % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_Q = BLOCK_M // ROWS_PER_BATCH_LOAD + + if ROWS_PER_BATCH_LOAD >= BLOCK_N: + NUM_BATCHES_KV = 1 + KV_NEEDS_GUARD = ROWS_PER_BATCH_LOAD > BLOCK_N + else: + assert BLOCK_N % ROWS_PER_BATCH_LOAD == 0 + NUM_BATCHES_KV = BLOCK_N // ROWS_PER_BATCH_LOAD + KV_NEEDS_GUARD = False + + # LDS sizes (element counts, f16 = 2 bytes each) + # No Q in LDS: Q is read once from global memory to MFMA A packs. 
+    LDS_KV_SIZE = max(BLOCK_N * K_STRIDE, HEAD_DIM * VT_STRIDE)
+    LDS_P_SIZE = BLOCK_M * BLOCK_N
+
+    allocator = SmemAllocator(None, arch=gpu_arch)
+    _state = {}
+
+    class _FlashAttnFunc(flir.MlirModule):
+        GPU_MODULE_NAME = f"flash_attn_func_{dtype_str}"
+        GPU_MODULE_TARGETS = [f'#rocdl.target<chip = "{gpu_arch}">']
+
+        def init_gpu_module(self):
+            elem_type = T.f16()
+            _state["elem_type"] = elem_type
+            _state["lds_kv"] = allocator.allocate_array(elem_type, LDS_KV_SIZE)
+            _state["lds_p"] = allocator.allocate_array(elem_type, LDS_P_SIZE)
+            allocator.finalize()
+
+        @flir.kernel
+        def flash_attn_func_kernel(
+            self: flir.T.i64,
+            Q: lambda: T.memref(DYN, _state["elem_type"]),
+            K: lambda: T.memref(DYN, _state["elem_type"]),
+            V: lambda: T.memref(DYN, _state["elem_type"]),
+            O: lambda: T.memref(DYN, _state["elem_type"]),
+            seq_len: lambda: T.index(),
+        ):
+            compute_type = T.f32()
+            elem_type = _state["elem_type"]
+            fm_fast = flir.arith.FastMathFlags.fast
+
+            v4f16_type = ir.VectorType.get([4], elem_type)
+            v4f32_type = ir.VectorType.get([4], compute_type)
+            v8f16_type = ir.VectorType.get([VEC_WIDTH], elem_type)
+
+            seq_len_v = arith.as_value(seq_len)
+
+            # ---- LDS views (KV + P only, no Q in LDS) ----
+            base_ptr = allocator.get_base()
+            lds_kv = _state["lds_kv"](base_ptr).get()
+            lds_p = _state["lds_p"](base_ptr).get()
+
+            # ---- Thread / block indices ----
+            block_id = flir.const_index(flir.block_idx("x"))
+            tid = flir.const_index(flir.thread_idx("x"))
+
+            # ---- Wave decomposition ----
+            c_ws = flir.const_index(WARP_SIZE)
+            wave_id = arith.as_value(flir.arith.DivUIOp(tid, c_ws).result)
+            lane = arith.as_value(flir.arith.RemUIOp(tid, c_ws).result)
+
+            # ---- MFMA lane decomposition ----
+            c16 = flir.const_index(16)
+            lane_div_16 = arith.as_value(flir.arith.DivUIOp(lane, c16).result)
+            lane_mod_16 = arith.as_value(flir.arith.RemUIOp(lane, c16).result)
+
+            # ---- Wave offsets ----
+            wave_q_offset = (arith.ArithValue(wave_id) * ROWS_PER_WAVE).value
+            wave_p_offset = 
(arith.ArithValue(wave_id) * ROWS_PER_WAVE * BLOCK_N).value + + # ---- Decompose block_id ---- + c_nh = flir.const_index(NUM_HEADS) + head_idx = arith.as_value(flir.arith.RemUIOp(block_id, c_nh).result) + temp = arith.as_value(flir.arith.DivUIOp(block_id, c_nh).result) + c_bm = flir.const_index(BLOCK_M) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(seq_len_v, c_bm).result) + q_tile_idx = arith.as_value(flir.arith.RemUIOp(temp, num_q_tiles).result) + batch_idx = arith.as_value(flir.arith.DivUIOp(temp, num_q_tiles).result) + q_start = (arith.ArithValue(q_tile_idx) * BLOCK_M).value + + # ---- Cooperative load decomposition ---- + c_tpr = flir.const_index(THREADS_PER_ROW_LOAD) + load_row_in_batch = arith.as_value(flir.arith.DivUIOp(tid, c_tpr).result) + load_lane_in_row = arith.as_value(flir.arith.RemUIOp(tid, c_tpr).result) + load_col_base = (arith.ArithValue(load_lane_in_row) * VEC_WIDTH).value + + # ---- Helper: global flat index ---- + def global_idx(token_idx, col): + token = ( + arith.ArithValue(batch_idx) * arith.ArithValue(seq_len_v) + + arith.ArithValue(token_idx) + ) + return ( + token * STRIDE_TOKEN + + arith.ArithValue(head_idx) * HEAD_DIM + + arith.ArithValue(col) + ).value + + # ---- Cooperative K load (row-major, padded stride) ---- + def coop_load_k(tile_start): + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + lds_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + lds_idx = ( + arith.ArithValue(lds_row) * K_STRIDE + + 
arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, K, [g_idx])) + lds_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + lds_idx = ( + arith.ArithValue(lds_row) * K_STRIDE + + arith.ArithValue(load_col_base) + ).value + vec_ext.store(vec, lds_kv, [lds_idx]) + + # ---- Cooperative V load (transposed, padded stride) ---- + def coop_load_v_transposed(tile_start): + for batch in range_constexpr(NUM_BATCHES_KV): + row_offset = batch * ROWS_PER_BATCH_LOAD + row_idx = ( + arith.ArithValue(tile_start) + + arith.ArithValue(load_row_in_batch) + + row_offset + ).value + if KV_NEEDS_GUARD: + c_bn = flir.const_index(BLOCK_N) + row_valid = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ult, + arith.ArithValue(load_row_in_batch).value, + c_bn, + ).result + ) + with scf.if_(row_valid): + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + load_row = ( + arith.ArithValue(load_row_in_batch) + row_offset + ).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + else: + g_idx = global_idx(row_idx, load_col_base) + vec = arith.as_value(vec_ext.load_op(v8f16_type, V, [g_idx])) + load_row = (arith.ArithValue(load_row_in_batch) + row_offset).value + for e in range_constexpr(VEC_WIDTH): + elem = arith.as_value( + vec_ext.extract(vec, static_position=[e], dynamic_position=[]) + ) + col_e = (arith.ArithValue(load_col_base) + e).value + lds_idx = ( + arith.ArithValue(col_e) * VT_STRIDE + + arith.ArithValue(load_row) + ).value + _memref.StoreOp(elem, lds_kv, [lds_idx]) + + # ---- Load Q once from global 
memory to MFMA A packs ----
+            q_row = (
+                arith.ArithValue(q_start)
+                + arith.ArithValue(wave_q_offset)
+                + arith.ArithValue(lane_mod_16)
+            ).value
+            q_a_packs = []
+            for ks in range_constexpr(K_STEPS):
+                q_col = flir.const_index(ks * 16)
+                q_col = (arith.ArithValue(q_col) + arith.ArithValue(lane_div_16) * 4).value
+                g_idx = global_idx(q_row, q_col)
+                q_a_packs.append(arith.as_value(vec_ext.load_op(v4f16_type, Q, [g_idx])))
+
+            # ---- Constants ----
+            c_neg_inf = arith.constant(float("-inf"), type=compute_type)
+            c_zero_f = arith.constant(0.0, type=compute_type)
+            c_sm_scale = arith.constant(sm_scale, type=compute_type)
+            c_log2e = arith.constant(1.4426950408889634, type=compute_type)
+            c_zero_v4f32 = arith.as_value(arith.constant_vector(0.0, v4f32_type))
+
+            # ---- Init loop-carried state ----
+            # m[4], l[4], o_accs[K_STEPS]
+            init_args = []
+            for _ in range_constexpr(4):
+                init_args.append(arith.as_value(c_neg_inf))
+            for _ in range_constexpr(4):
+                init_args.append(arith.as_value(c_zero_f))
+            for _ in range_constexpr(K_STEPS):
+                init_args.append(c_zero_v4f32)
+
+            # ---- KV loop upper bound ----
+            if CAUSAL:
+                kv_upper = (arith.ArithValue(q_start) + BLOCK_M).value
+            else:
+                kv_upper = seq_len_v
+
+            # ---- KV loop (step BLOCK_N=32) ----
+            with scf.for_(0, kv_upper, BLOCK_N, iter_args=init_args) as loop:
+                kv_start = arith.as_value(loop.induction_variable)
+                m_old = [arith.as_value(loop.inner_iter_args[i]) for i in range(4)]
+                l_old = [arith.as_value(loop.inner_iter_args[4 + i]) for i in range(4)]
+                o_accs = [arith.as_value(loop.inner_iter_args[8 + ds]) for ds in range(K_STEPS)]
+
+                # ==== Cooperative K load -> LDS_KV ====
+                coop_load_k(kv_start)
+                gpu.barrier()
+
+                # ==== Q @ K^T via MFMA -> S[16, BLOCK_N] ====
+                s_accs = [c_zero_v4f32 for _ in range_constexpr(N_MFMA)]
+                for ks in range_constexpr(K_STEPS):
+                    a_pack = q_a_packs[ks]
+                    for nm in range_constexpr(N_MFMA):
+                        k_row = nm * 16
+                        k_lds_idx = (
+                            (arith.ArithValue(lane_mod_16) + k_row) * K_STRIDE
+                            + ks * 16
+ + arith.ArithValue(lane_div_16) * 4 + ).value + b_pack = arith.as_value(vec_ext.load_op(v4f16_type, lds_kv, [k_lds_idx])) + s_accs[nm] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [a_pack, b_pack, s_accs[nm], 0, 0, 0] + ) + ) + + # ==== Online softmax over BLOCK_N positions ==== + # s_vals[nm][ii] where nm in [0..3], ii in [0..3] + s_vals = [[None for _ in range_constexpr(4)] for _ in range_constexpr(N_MFMA)] + for ii in range_constexpr(4): + for nm in range_constexpr(N_MFMA): + s_val = arith.as_value( + vec_ext.extract(s_accs[nm], static_position=[ii], dynamic_position=[]) + ) + s_val = arith.as_value( + flir.arith.MulFOp( + s_val, arith.as_value(c_sm_scale), fastmath=fm_fast + ).result + ) + + if CAUSAL: + q_row_i = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + kv_col = ( + arith.ArithValue(kv_start) + + nm * 16 + + arith.ArithValue(lane_mod_16) + ).value + q_row_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), q_row_i).result + ) + kv_col_i64 = arith.as_value( + flir.arith.IndexCastOp(T.i64(), kv_col).result + ) + is_masked = arith.as_value( + flir.arith.CmpIOp( + flir.arith.CmpIPredicate.ugt, kv_col_i64, q_row_i64 + ).result + ) + s_val = arith.as_value( + flir.arith.SelectOp( + is_masked, arith.as_value(c_neg_inf), s_val + ).result + ) + s_vals[nm][ii] = s_val + + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + m_new = [None] * 4 + corr = [None] * 4 + p_vals = [[None for _ in range_constexpr(4)] for _ in range_constexpr(N_MFMA)] + l_new = [None] * 4 + + for ii in range_constexpr(4): + row_maxes = [] + for nm in range_constexpr(N_MFMA): + row_max_nm = s_vals[nm][ii] + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp( + row_max_nm, sh_i32, width_i32, mode="xor" + ).shuffleResult + ) + row_max_nm = arith.as_value( + flir.arith.MaximumFOp(row_max_nm, peer).result + ) 
+ row_maxes.append(row_max_nm) + + combined_max = row_maxes[0] + for g in range_constexpr(N_MFMA - 1): + combined_max = arith.as_value( + flir.arith.MaximumFOp(combined_max, row_maxes[g + 1]).result + ) + + m_new[ii] = arith.as_value( + flir.arith.MaximumFOp(m_old[ii], combined_max).result + ) + + diff_m = arith.as_value( + flir.arith.SubFOp(m_old[ii], m_new[ii], fastmath=fm_fast).result + ) + diff_m_s = arith.as_value( + flir.arith.MulFOp( + diff_m, arith.as_value(c_log2e), fastmath=fm_fast + ).result + ) + corr[ii] = arith.as_value(flir.math.exp2(diff_m_s, fastmath=fm_fast)) + + row_sums = [] + for nm in range_constexpr(N_MFMA): + diff = arith.as_value( + flir.arith.SubFOp( + s_vals[nm][ii], m_new[ii], fastmath=fm_fast + ).result + ) + diff_s = arith.as_value( + flir.arith.MulFOp( + diff, arith.as_value(c_log2e), fastmath=fm_fast + ).result + ) + p_vals[nm][ii] = arith.as_value(flir.math.exp2(diff_s, fastmath=fm_fast)) + + row_sum_nm = p_vals[nm][ii] + for sh in [8, 4, 2, 1]: + sh_i32 = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp( + row_sum_nm, sh_i32, width_i32, mode="xor" + ).shuffleResult + ) + row_sum_nm = arith.as_value( + flir.arith.AddFOp(row_sum_nm, peer, fastmath=fm_fast).result + ) + row_sums.append(row_sum_nm) + + combined_sum = row_sums[0] + for g in range_constexpr(N_MFMA - 1): + combined_sum = arith.as_value( + flir.arith.AddFOp(combined_sum, row_sums[g + 1], fastmath=fm_fast).result + ) + + l_corr = arith.as_value( + flir.arith.MulFOp(corr[ii], l_old[ii], fastmath=fm_fast).result + ) + l_new[ii] = arith.as_value( + flir.arith.AddFOp(l_corr, combined_sum, fastmath=fm_fast).result + ) + + # ==== Rescale O accumulators ==== + corr_vec = arith.as_value( + vec_ext.from_elements(v4f32_type, [corr[0], corr[1], corr[2], corr[3]]) + ) + for ds in range_constexpr(K_STEPS): + o_accs[ds] = arith.as_value( + flir.arith.MulFOp(o_accs[ds], corr_vec, fastmath=fm_fast).result + ) + + # ==== P store to LDS_P ==== + for 
ii in range_constexpr(4): + p_row = (arith.ArithValue(lane_div_16) * 4 + ii).value + for nm in range_constexpr(N_MFMA): + p_f16 = arith.as_value( + flir.arith.TruncFOp(elem_type, p_vals[nm][ii]).result + ) + p_lds_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(p_row) * BLOCK_N + + nm * 16 + + arith.ArithValue(lane_mod_16) + ).value + _memref.StoreOp(p_f16, lds_p, [p_lds_idx]) + + # ==== Barrier: ensure all waves done reading K ==== + gpu.barrier() + + # ==== Cooperative V load (transposed) ==== + coop_load_v_transposed(kv_start) + gpu.barrier() + + # ==== P @ V via MFMA ==== + # P does not depend on ds; load once and reuse across all K_STEPS. + p_packs = [] + for nm in range_constexpr(N_MFMA): + p_a_idx = ( + arith.ArithValue(wave_p_offset) + + arith.ArithValue(lane_mod_16) * BLOCK_N + + nm * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + p_packs.append(arith.as_value(vec_ext.load_op(v4f16_type, lds_p, [p_a_idx]))) + + for ds in range_constexpr(K_STEPS): + for nm in range_constexpr(N_MFMA): + v_idx = ( + (ds * 16 + arith.ArithValue(lane_mod_16)) * VT_STRIDE + + nm * 16 + + arith.ArithValue(lane_div_16) * 4 + ).value + v_pack = arith.as_value( + vec_ext.load_op(v4f16_type, lds_kv, [v_idx]) + ) + o_accs[ds] = arith.as_value( + rocdl.mfma_f32_16x16x16f16( + v4f32_type, [p_packs[nm], v_pack, o_accs[ds], 0, 0, 0] + ) + ) + + # ==== Barrier: ensure all waves done reading V ==== + gpu.barrier() + + yield_args = m_new + l_new + o_accs + scf_yield(yield_args) + + # ---- Normalize and store O ---- + m_finals = [arith.as_value(loop.results[i]) for i in range(4)] + l_finals = [arith.as_value(loop.results[4 + i]) for i in range(4)] + o_finals = [arith.as_value(loop.results[8 + ds]) for ds in range(K_STEPS)] + + for ds in range_constexpr(K_STEPS): + for ii in range_constexpr(4): + o_val = arith.as_value( + vec_ext.extract(o_finals[ds], static_position=[ii], dynamic_position=[]) + ) + o_norm = arith.as_value( + flir.arith.DivFOp(o_val, l_finals[ii], 
fastmath=fm_fast).result + ) + o_f16 = arith.as_value(flir.arith.TruncFOp(elem_type, o_norm).result) + q_row_o = ( + arith.ArithValue(q_start) + + arith.ArithValue(wave_q_offset) + + arith.ArithValue(lane_div_16) * 4 + + ii + ).value + d_col = (flir.const_index(ds * 16) + arith.ArithValue(lane_mod_16)).value + o_global = global_idx(q_row_o, d_col) + _memref.StoreOp(o_f16, O, [o_global]) + + @flir.jit + def __call__( + self: flir.T.i64, + Q: lambda: T.memref(DYN, _state["elem_type"]), + K: lambda: T.memref(DYN, _state["elem_type"]), + V: lambda: T.memref(DYN, _state["elem_type"]), + O: lambda: T.memref(DYN, _state["elem_type"]), + batch_size: lambda: T.index(), + seq_len: lambda: T.index(), + ): + c1 = arith.as_value(flir.arith_ext.index(1)) + c_nh = arith.as_value(flir.arith_ext.index(NUM_HEADS)) + c_bm = arith.as_value(flir.arith_ext.index(BLOCK_M)) + bs_val = arith.as_value(batch_size) + sl_val = arith.as_value(seq_len) + num_q_tiles = arith.as_value(flir.arith.DivUIOp(sl_val, c_bm).result) + bs_qt = arith.as_value(flir.arith.MulIOp(bs_val, num_q_tiles).result) + grid_x = arith.as_value(flir.arith.MulIOp(bs_qt, c_nh).result) + bx = arith.as_value(flir.arith_ext.index(BLOCK_SIZE)) + flir.gpu_ext.LaunchFuncOp( + [self.GPU_MODULE_NAME, KERNEL_NAME], + grid_size=(grid_x, c1, c1), + block_size=(bx, c1, c1), + kernel_operands=[Q, K, V, O, seq_len], + ) + + return _FlashAttnFunc() diff --git a/kernels/preshuffle_gemm.py b/kernels/preshuffle_gemm.py index 549020a8..04b8effb 100644 --- a/kernels/preshuffle_gemm.py +++ b/kernels/preshuffle_gemm.py @@ -24,7 +24,6 @@ from flydsl.dialects.ext import arith, gpu, buffer_ops, vector, rocdl from flydsl.lang.ir.types import T, memref from kernels.kernels_common import stream_ptr_to_async_token -from flydsl.compiler.compiler import _apply_waves_per_eu_hint from kernels.mfma_preshuffle_pipeline import ( buffer_copy_gmem16_dwordx4, @@ -1190,15 +1189,12 @@ def __call__( m = _GEMM() - # Apply waves_per_eu hint if specified (before 
# ---------------------------------------------------------------------------
# Single-warp (wave64) shuffle reductions
# ---------------------------------------------------------------------------

# Butterfly (xor) offsets that pair every lane of a 64-wide wave exactly once
# per step; after all six steps every lane holds the full reduction.
WAVE64_OFFSETS = [32, 16, 8, 4, 2, 1]


def warp_reduce_sum(val, *, gpu, arith, flir, T, fm_fast, WARP_SIZE=64):
    """Single-warp (wave64) sum reduction via xor shuffle.

    Args:
        val: Scalar f32 SSA value to reduce (one per lane).
        gpu, arith, flir, T: Injected dialect namespaces, so this helper has
            no module-level dependency on the dialect packages.
        fm_fast: Fast-math flags forwarded to each AddFOp.
        WARP_SIZE: Shuffle width; 64 for AMD wave64.

    Returns:
        Scalar f32 value holding the warp-wide sum. All lanes receive the
        same result after the final shuffle step.
    """
    width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32()))
    w = arith.as_value(val)
    for sh in WAVE64_OFFSETS:
        off = arith.as_value(arith.constant(sh, type=T.i32()))
        # Exchange partial sums with the lane whose id differs in bit `sh`.
        peer = arith.as_value(
            gpu.ShuffleOp(w, off, width_i32, mode="xor").shuffleResult
        )
        w = arith.as_value(
            flir.arith.AddFOp(w, peer, fastmath=fm_fast).result
        )
    return w
+ """ + width_i32 = arith.as_value(arith.constant(WARP_SIZE, type=T.i32())) + w = arith.as_value(val) + for sh in WAVE64_OFFSETS: + off = arith.as_value(arith.constant(sh, type=T.i32())) + peer = arith.as_value( + gpu.ShuffleOp(w, off, width_i32, mode="xor").shuffleResult + ) + w = arith.as_value( + flir.arith.MaximumFOp(w, peer).result + ) + return w + + def reduce_vec_max(vec_val, *, VEC_WIDTH, compute_type, vector): if VEC_WIDTH == 1: return vector.extract(vec_val, static_position=[0], dynamic_position=[]) diff --git a/tests/kernels/test_flash_attn_func.py b/tests/kernels/test_flash_attn_func.py new file mode 100644 index 00000000..97b52d43 --- /dev/null +++ b/tests/kernels/test_flash_attn_func.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +"""flash_attn_func kernel test and benchmark for FlyDSL. + +Tests flash_attn_func against PyTorch SDPA. +""" + +import sys +import argparse +import hashlib +import random +from pathlib import Path +import logging + +# Configure logging to show INFO level messages (required for kernel name display) +logging.basicConfig(level=logging.INFO) + +_repo = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(_repo)) + +try: + import torch + import torch.nn.functional as F + import numpy as np +except ImportError: + print("PyTorch not available") + sys.exit(1) + +if not torch.cuda.is_available(): + print("CUDA/ROCm not available") + sys.exit(1) + +import flydsl +from kernels.flash_attn_func import ( + KERNEL_NAME, + build_flash_attn_func_module, + select_flash_attn_func_path, +) +from tests.test_common import run_perftest + +# Tensor initialization range (uniform distribution) +UNIFORM_RANGE = (-1, 1) +DEFAULT_SEED = 123 +FLASH_ATTN_FUNC_COMPILE_KWARGS = { + "unsafe_fp_math": True, + "fast_fp_math": True, + "waves_per_eu": 3, + "flat_work_group_size": 256, +} + + +def setup_seed(seed: int) -> None: + """Set random seed for reproducibility across all RNG sources.""" + random.seed(seed) + torch.manual_seed(seed) + 
def setup_seed(seed: int) -> None:
    """Set random seed for reproducibility across all RNG sources."""
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


def pytorch_ref_attention(q, k, v, causal=True):
    """Reference attention via PyTorch SDPA.

    Inputs are laid out (B, S, H, D); SDPA expects (B, H, S, D), so transpose
    in and out. Computation is done in float32 for a high-precision reference.
    """
    q_t = q.transpose(1, 2).float()
    k_t = k.transpose(1, 2).float()
    v_t = v.transpose(1, 2).float()
    out = F.scaled_dot_product_attention(q_t, k_t, v_t, is_causal=causal)
    return out.transpose(1, 2)


def compute_md5(tensor: torch.Tensor) -> str:
    """Compute MD5 hash of a tensor's raw bytes (for bit-exact comparisons)."""
    return hashlib.md5(
        tensor.contiguous().view(torch.uint8).detach().cpu().numpy().tobytes()
    ).hexdigest()


def compare_arrays(
    arr1: np.ndarray,
    arr2: np.ndarray,
    k: int = 5,
    thresholds=None,
) -> dict:
    """Compare two numpy arrays and compute various difference metrics.

    Args:
        arr1: First input array (result), will be cast to float32.
        arr2: Second input array (reference), will be cast to float32.
        k: Number of top differences to report.
        thresholds: Optional list of difference magnitude bucket edges for the
            histogram; defaults to decades from 1e-6 to 1e1.

    Returns:
        Dictionary with top_k_diff, threshold_stats, nan_info, max_diff, max_diff_thr.

    Raises:
        ValueError: If the shapes differ or the arrays are empty.
    """
    if thresholds is None:
        thresholds = [0, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0, 1e1]

    if arr1.shape != arr2.shape:
        raise ValueError(f"Shape mismatch: arr1 {arr1.shape} vs arr2 {arr2.shape}")
    # Guard: numpy's .max() on a zero-size array raises a cryptic reduction
    # error; fail early with a clear message instead.
    if arr1.size == 0:
        raise ValueError("compare_arrays: input arrays are empty")

    arr1 = arr1.astype(np.float32)
    arr2 = arr2.astype(np.float32)

    result = {"top_k_diff": [], "threshold_stats": [], "nan_info": {}}

    # Check for NaN values
    nan_mask1 = np.isnan(arr1)
    nan_mask2 = np.isnan(arr2)
    if np.any(nan_mask1):
        result["nan_info"]["arr1_nan_count"] = int(np.sum(nan_mask1))
        print(f"  Warning: result contains {result['nan_info']['arr1_nan_count']} NaN values")
    if np.any(nan_mask2):
        result["nan_info"]["arr2_nan_count"] = int(np.sum(nan_mask2))
        print(f"  Warning: reference contains {result['nan_info']['arr2_nan_count']} NaN values")

    # Compute absolute differences
    diff = np.abs(arr1 - arr2)
    total_elements = arr1.size

    # Relative difference normalized by 1 + |ref| to avoid blow-up near zero.
    max_diff_thr = (diff / (1.0 + np.abs(arr2))).max()
    result["max_diff"] = float(diff.max())
    result["max_diff_thr"] = float(max_diff_thr)

    print(f"  diff.abs.max = {diff.max():.6f}")
    print(f"  diff.abs.mean = {diff.mean():.6f}")
    print(f"  max_diff_thr (rel) = {max_diff_thr:.6e}")

    # Find top k differences
    flat_diff = diff.flatten()
    actual_k = min(k, len(flat_diff))
    top_k_indices = np.argpartition(flat_diff, -actual_k)[-actual_k:]
    top_k_indices = top_k_indices[np.argsort(-flat_diff[top_k_indices])]

    orig_indices = np.unravel_index(top_k_indices, diff.shape)
    print(f"  Top-{actual_k} differences:")
    for i in range(actual_k):
        idx = tuple(dim[i] for dim in orig_indices)
        entry = {
            "value": float(diff[idx]),
            "position": idx,
            "arr1_value": float(arr1[idx]),
            "arr2_value": float(arr2[idx]),
        }
        result["top_k_diff"].append(entry)
        print(f"    [{idx}] result={arr1[idx]:.6f}, ref={arr2[idx]:.6f}, diff={diff[idx]:.6f}")

    # Compute threshold statistics
    print(f"  Threshold distribution ({total_elements} elements):")
    for i in range(len(thresholds) - 1):
        lower, upper = thresholds[i], thresholds[i + 1]
        count = int(np.sum((diff >= lower) & (diff < upper)))
        pct = 100.0 * count / total_elements
        result["threshold_stats"].append(
            {"range": f"[{lower:.0e}, {upper:.0e})", "count": count, "percentage": pct}
        )
        print(f"    [{lower:.0e}, {upper:.0e}): {count:>8d} ({pct:6.2f}%)")

    count = int(np.sum(diff >= thresholds[-1]))
    pct = 100.0 * count / total_elements
    result["threshold_stats"].append(
        {"range": f">={thresholds[-1]:.0e}", "count": count, "percentage": pct}
    )
    print(f"    >={thresholds[-1]:.0e} : {count:>8d} ({pct:6.2f}%)")

    return result
def run_config(
    batch, seq_len, num_heads, head_dim, dtype, causal, warmup, iters, seed=DEFAULT_SEED
):
    """Compile, run, validate, and benchmark one flash_attn_func config.

    Returns a dict with either an "err"/"bench_err" message or the metrics
    max_err / mean_err / min_cos / passed plus us / tflops when benchmarked.
    """
    device = "cuda"
    results = {}
    active_path = select_flash_attn_func_path(
        num_heads=num_heads, head_dim=head_dim, causal=causal, dtype_str="f16"
    )
    results["active_path"] = active_path

    # Kernel tiling constraints (BLOCK_M=128, 16-wide MFMA fragments).
    if seq_len % 128 != 0:
        results["err"] = f"seq_len ({seq_len}) must be divisible by 128 for flash_attn_func"
        return results
    if head_dim % 32 != 0 or head_dim < 64:
        results["err"] = f"head_dim ({head_dim}) must be >= 64 and divisible by 32"
        return results

    try:
        m = build_flash_attn_func_module(
            num_heads=num_heads, head_dim=head_dim, causal=causal, dtype_str="f16"
        )
        exe = flydsl.compile(m, **FLASH_ATTN_FUNC_COMPILE_KWARGS)
    except Exception as e:
        results["err"] = f"compile: {e}"
        import traceback

        traceback.print_exc()
        return results

    B, S, H, D = batch, seq_len, num_heads, head_dim
    setup_seed(seed)
    q_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE)
    k_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE)
    v_4d = torch.empty(B, S, H, D, dtype=dtype, device=device).uniform_(*UNIFORM_RANGE)

    # Kernel consumes flat 1-D buffers.
    q_flat = q_4d.contiguous().view(-1)
    k_flat = k_4d.contiguous().view(-1)
    v_flat = v_4d.contiguous().view(-1)
    o_flat = torch.zeros_like(q_flat)

    try:
        exe(q_flat, k_flat, v_flat, o_flat, B, S)
        torch.cuda.synchronize()
    except Exception as e:
        results["err"] = f"exec: {e}"
        import traceback

        traceback.print_exc()
        return results

    ref_4d = pytorch_ref_attention(q_4d.float(), k_4d.float(), v_4d.float(), causal=causal).to(dtype)
    ref_flat = ref_4d.contiguous().view(-1)

    o_f32 = o_flat.float()
    ref_f32 = ref_flat.float()
    max_err = (o_f32 - ref_f32).abs().max().item()
    mean_err = (o_f32 - ref_f32).abs().mean().item()
    cos_sim = F.cosine_similarity(o_f32.view(-1, D), ref_f32.view(-1, D), dim=1)
    min_cos = cos_sim.min().item()
    results["max_err"] = max_err
    results["mean_err"] = mean_err
    results["min_cos"] = min_cos
    results["passed"] = max_err < 1e-2 and min_cos > 0.99

    # Compute and print MD5 hashes
    tag = f"B={B} S={S} H={H} D={D}"
    print(f"  [{tag}] active_path = {active_path}")
    result_md5 = compute_md5(o_flat)
    ref_md5 = compute_md5(ref_flat)
    print(f"  [{tag}] result_md5 = {result_md5}")
    print(f"  [{tag}] ref_md5 = {ref_md5}")
    if result_md5 == ref_md5:
        print(f"  [{tag}] MD5 match: EXACT (bit-identical)")
    else:
        print(f"  [{tag}] MD5 match: DIFFER (not bit-identical)")

    print(f"  [{tag}] --- compare_arrays ---")
    compare_arrays(
        o_flat.to(torch.float32).detach().cpu().numpy(),
        ref_flat.to(torch.float32).detach().cpu().numpy(),
    )

    try:
        def kernel_fn():
            exe(q_flat, k_flat, v_flat, o_flat, B, S)

        _, us = run_perftest(kernel_fn, num_iters=iters, num_warmup=warmup)
        # Causal masking halves the average attended length.
        s_eff = S / 2.0 if causal else float(S)
        flops = 4.0 * S * s_eff * D * H * B
        tflops = flops / (us * 1e-6) / 1e12
        results["us"] = us
        results["tflops"] = tflops
    except Exception as e:
        results["bench_err"] = str(e)

    return results


def main():
    """CLI entry point: run correctness + performance over a config sweep."""
    parser = argparse.ArgumentParser(description="flash_attn_func FlyDSL Test/Benchmark")
    parser.add_argument("--batch", type=int, default=None)
    parser.add_argument("--seq_len", type=int, default=None)
    parser.add_argument("--num_heads", type=int, default=None)
    parser.add_argument("--head_dim", type=int, default=None)
    parser.add_argument("--no-causal", action="store_true")
    parser.add_argument("--warmup", type=int, default=5)
    parser.add_argument("--iters", type=int, default=20)
    parser.add_argument(
        "--seed", type=int, default=DEFAULT_SEED, help=f"Random seed for reproducibility (default: {DEFAULT_SEED})"
    )
    args = parser.parse_args()

    causal = not args.no_causal
    dtype = torch.float16

    print("=" * 130)
    print(f"FlyDSL flash_attn_func ({'causal' if causal else 'non-causal'}, fp16)")
    print("  Tile: BLOCK_M=128, BLOCK_N=32 fallback (default) + CK-like N=128 fast path (gated)")
    print("  Strategy: K@Q^T + register S/P ping-pong + V^T@P")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Compile opts: {FLASH_ATTN_FUNC_COMPILE_KWARGS}")
    print("=" * 130)

    # BUG FIX: previously args.num_heads was omitted from this check, so
    # passing --num_heads alone silently ran the default sweep instead of
    # the requested single configuration.
    if args.seq_len or args.head_dim or args.batch or args.num_heads:
        configs = [(args.batch or 1, args.seq_len or 128, args.num_heads or 8, args.head_dim or 128)]
    else:
        configs = [
            (1, 128, 8, 128),
            (1, 128, 64, 128),
            (1, 256, 32, 128),
            (1, 512, 32, 128),
            (2, 128, 8, 128),
            (1, 8192, 64, 128),
        ]

    hdr = (
        f"{'Config/Path':>56s} | {'Status':>6s} | {'MaxErr':>8s} "
        f"{'MinCos':>8s} | {'Time(us)':>10s} {'TFLOPS':>8s}"
    )
    print(f"\n{hdr}")
    print("-" * len(hdr))

    all_passed = True
    for batch, seq_len, nh, hd in configs:
        tag = f"B={batch} S={seq_len} H={nh} D={hd}"
        try:
            r = run_config(
                batch,
                seq_len,
                nh,
                hd,
                dtype,
                causal,
                warmup=args.warmup,
                iters=args.iters,
                seed=args.seed,
            )
            if "err" in r:
                cfg_path = f"{tag} / {r.get('active_path', 'unknown')}"
                print(f"{cfg_path:>56s} | {'ERROR':>6s} | {r['err'][:60]}")
                all_passed = False
                continue

            status = "PASS" if r["passed"] else "FAIL"
            if not r["passed"]:
                all_passed = False
            cfg_path = f"{tag} / {r.get('active_path', 'unknown')}"

            us_s = f"{r['us']:>10.1f}" if "us" in r else "       N/A"
            tf_s = f"{r['tflops']:>9.3f}" if "tflops" in r else "      N/A"
            print(
                f"{cfg_path:>56s} | {status:>6s} | "
                f"{r['max_err']:>8.2e} {r['min_cos']:>8.5f} | "
                f"{us_s} {tf_s}"
            )
        except Exception as e:
            print(f"{tag:>56s} | {'ERROR':>6s} | {str(e)[:60]}")
            all_passed = False

    print("=" * 130)
    if all_passed:
        print("All tests PASSED")
    else:
        print("Some tests FAILED")
        sys.exit(1)


if __name__ == "__main__":
    main()