diff --git a/wave_lang/kernel/wave/templates/__init__.py b/wave_lang/kernel/wave/templates/__init__.py index d577e326a..4b127d426 100644 --- a/wave_lang/kernel/wave/templates/__init__.py +++ b/wave_lang/kernel/wave/templates/__init__.py @@ -6,10 +6,15 @@ from .attention_common import AttentionShape from .tagged_attention import get_tagged_bshd_attention_kernel -from .tagged_mxfp4_gemm import get_tagged_mxfp4_gemm, get_tagged_mxfp4_gemm_preshuffle_b +from .tagged_mxfp4_gemm import ( + compute_best_group_size_n, + get_tagged_mxfp4_gemm, + get_tagged_mxfp4_gemm_preshuffle_b, +) __all__ = [ "AttentionShape", + "compute_best_group_size_n", "get_tagged_bshd_attention_kernel", "get_tagged_mxfp4_gemm", "get_tagged_mxfp4_gemm_preshuffle_b", diff --git a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py index 6d186c432..59442f539 100755 --- a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py +++ b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py @@ -17,6 +17,7 @@ bitcast_a, bitcast_a_scale, bitcast_b, bitcast_b_scale, scaled_mma. """ +from math import ceil from sympy import Piecewise, ceiling, floor, Max import wave_lang.kernel.lang as tkl @@ -47,10 +48,23 @@ def get_tagged_mxfp4_gemm( block_shape: (BLOCK_M, BLOCK_N, BLOCK_K) tile sizes. mfma_variant: Scaled MMA instruction type. wave_shape: (WAVE_M, WAVE_N) waves per workgroup. + reorder_workgroups: Enable N-dim workgroup reordering. When True, + compute_best_group_size_n() is called to auto-select the optimal + group size and decide whether reordering is actually beneficial. + group_size_n: Number of N-tiles per reordering group. Returns: (kernel_function, WaveCompileOptions) """ + # Auto-select group_size_n and determine whether reordering helps + if reorder_workgroups: + group_size_n, reorder_workgroups = compute_best_group_size_n( + shape[0], shape[1], shape[2], block_shape[0], block_shape[1] + ) + print( + f"[workgroup_reorder] enabled={reorder_workgroups}, group_size_n={group_size_n}" + ) + M = tkl.sym.M N = tkl.sym.N K = tkl.sym.K @@ -154,10 +168,23 @@ def get_tagged_mxfp4_gemm_preshuffle_b( wave_shape: (WAVE_M, WAVE_N) waves per workgroup. mfma_variant: Scaled MMA instruction type. a_address_space: Address space for A and A_scale (typically SHARED). + reorder_workgroups: Enable N-dim workgroup reordering. When True, + compute_best_group_size_n() is called to auto-select the optimal + group size and decide whether reordering is actually beneficial. + group_size_n: Number of N-tiles per reordering group. Returns: (kernel_function, WaveCompileOptions) """ + # Auto-select group_size_n and determine whether reordering helps + if reorder_workgroups: + group_size_n, reorder_workgroups = compute_best_group_size_n( + shape[0], shape[1], shape[2], block_shape[0], block_shape[1] + ) + print( + f"[workgroup_reorder] enabled={reorder_workgroups}, group_size_n={group_size_n}" + ) + M = tkl.sym.M N = tkl.sym.N K = tkl.sym.K @@ -311,6 +338,108 @@ def repeat( return gemm, options +def compute_best_group_size_n( + M: int, + N: int, + K: int, + block_m: int, + block_n: int, + num_xcds: int = 8, + cus_per_xcd: int = 32, +) -> tuple[int, bool]: + """Auto-select group_size_n and decide whether N-dim reordering is beneficial. + + Dispatch model (MI300X / MI350): + Hardware assigns flat workgroup indices round-robin to XCDs. + Each XCD runs cus_per_xcd CUs in parallel, forming a "batch" of + cus_per_xcd concurrent workgroups. + + Each batch covers U_A unique M-tiles × U_B unique N-tiles. + Per K-iteration DRAM fetches = U_A + U_B. + Minimise U_A + U_B subject to U_A × U_B ≈ cus_per_xcd (= 32). + Optimal: (U_A, U_B) = (4, 8) or (8, 4) → sum = 12. + + WITHOUT N-reordering: + U_B_natural ≈ (cus_per_xcd × num_xcds) / num_wg_0 = 256 / num_wg_0 + sum_natural = U_A_natural + U_B_natural + + WITH N-reordering (group_size_n = gsn, multiple of num_xcds): + U_B = gsn / num_xcds + (cost function) sum_gsn = cus_per_xcd × num_xcds / gsn + gsn / num_xcds + = 256 / gsn + gsn / 8 + + Optimal gsn (derivation from cost function set to zero and solved for gsn) + ≈ num_xcds × √cus_per_xcd ≈ 45 → closest power of two: gsn=32 and gsn=64 + + Worked examples (block_m = block_n = 256, MI300X defaults): + + Shape (M, N) num_wg_0 U_B_natural sum_natural best_gsn enable + (4096, 57344) 16 16 18 32 YES ← num_wg_0 < 32 + (8192, 57344) 32 8 12 -- NO ← already optimal + (16384, 16384) 64 4 12 -- NO ← already optimal + (32768, 16384) 128 2 18 64 YES ← num_wg_0 > 64 + + group_size_n selection: + Both gsn=32 (U_A=8, U_B=4) and gsn=64 (U_A=4, U_B=8) achieve sum=12. + Tie-breaking: + • Exact divisors of num_wg_1 are preferred (no tail group). + • B-heavy shapes (num_wg_1 >= num_wg_0): prefer gsn=32 (lower U_B → + more concurrent B sharing per batch). + • A-heavy shapes (num_wg_0 > num_wg_1): prefer gsn=64 (lower U_A → + more concurrent A sharing per batch). + + Args: + M, N, K: Problem dimensions (K is accepted for API consistency + but does not affect the batch balance model). + block_m, block_n: Tile sizes along M and N. + num_xcds: XCD count (MI300X / MI350: 8). + cus_per_xcd: CUs per XCD (MI300X / MI350: 32). + + Returns: + (best_group_size_n, reorder_enabled) + reorder_enabled=False means column-major dispatch already achieves the + optimal batch balance (sum=12); best_group_size_n is still returned + (32) as a safe default. + """ + num_wg_0 = ceil(M / block_m) # M-tiles + num_wg_1 = ceil(N / block_n) # N-tiles + + total_wg = num_wg_0 * num_wg_1 + if total_wg < num_xcds: # fewer wgs than XCDs, model meaningless and reordering too + return 32, False + + candidates = [g for g in (16, 32, 64) if g % num_xcds == 0 and g <= num_wg_1] + if not candidates: + return num_xcds, False + + def ub(g: int) -> int: + return g // num_xcds + + def ua(g: int) -> int: + return cus_per_xcd // max(1, ub(g)) + + def gsn_sum(g: int) -> int: + return ua(g) + ub(g) + + # Natural batch composition (no reordering) + u_b_nat = max(1, min(cus_per_xcd, cus_per_xcd * num_xcds // max(1, num_wg_0))) + u_a_nat = max(1, cus_per_xcd // u_b_nat) + sum_natural = u_a_nat + u_b_nat + + best_sum = min(gsn_sum(g) for g in candidates) + reorder_enabled = best_sum < sum_natural + + if not reorder_enabled: + return 32, False + + optimal = [g for g in candidates if gsn_sum(g) == best_sum] + exact = [g for g in optimal if num_wg_1 % g == 0] + pool = exact if exact else optimal + + # Tie-break: B-heavy → smaller gsn (more B sharing); A-heavy → larger gsn. + return (min(pool) if num_wg_1 >= num_wg_0 else max(pool)), True + + def _reorder_mxfp4_workgroups(m, n, block_m, block_n, group_size_n): """Remap workgroup indices to a new order based on group_size_n along N dimension.