Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion wave_lang/kernel/wave/templates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@

from .attention_common import AttentionShape
from .tagged_attention import get_tagged_bshd_attention_kernel
from .tagged_mxfp4_gemm import get_tagged_mxfp4_gemm, get_tagged_mxfp4_gemm_preshuffle_b
from .tagged_mxfp4_gemm import (
compute_best_group_size_n,
get_tagged_mxfp4_gemm,
get_tagged_mxfp4_gemm_preshuffle_b,
)

__all__ = [
"AttentionShape",
"compute_best_group_size_n",
"get_tagged_bshd_attention_kernel",
"get_tagged_mxfp4_gemm",
"get_tagged_mxfp4_gemm_preshuffle_b",
Expand Down
129 changes: 129 additions & 0 deletions wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
bitcast_a, bitcast_a_scale, bitcast_b, bitcast_b_scale, scaled_mma.
"""

from math import ceil
from sympy import Piecewise, ceiling, floor, Max

import wave_lang.kernel.lang as tkl
Expand Down Expand Up @@ -47,10 +48,23 @@ def get_tagged_mxfp4_gemm(
block_shape: (BLOCK_M, BLOCK_N, BLOCK_K) tile sizes.
mfma_variant: Scaled MMA instruction type.
wave_shape: (WAVE_M, WAVE_N) waves per workgroup.
reorder_workgroups: Enable N-dim workgroup reordering. When True,
compute_best_group_size_n() is called to auto-select the optimal
group size and decide whether reordering is actually beneficial.
group_size_n: Number of N-tiles per reordering group.

Returns:
(kernel_function, WaveCompileOptions)
"""
# Auto-select group_size_n and determine whether reordering helps
if reorder_workgroups:
group_size_n, reorder_workgroups = compute_best_group_size_n(
shape[0], shape[1], shape[2], block_shape[0], block_shape[1]
)
print(
f"[workgroup_reorder] enabled={reorder_workgroups}, group_size_n={group_size_n}"
)

M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
Expand Down Expand Up @@ -154,10 +168,23 @@ def get_tagged_mxfp4_gemm_preshuffle_b(
wave_shape: (WAVE_M, WAVE_N) waves per workgroup.
mfma_variant: Scaled MMA instruction type.
a_address_space: Address space for A and A_scale (typically SHARED).
reorder_workgroups: Enable N-dim workgroup reordering. When True,
compute_best_group_size_n() is called to auto-select the optimal
group size and decide whether reordering is actually beneficial.
group_size_n: Number of N-tiles per reordering group.

Returns:
(kernel_function, WaveCompileOptions)
"""
# Auto-select group_size_n and determine whether reordering helps
if reorder_workgroups:
group_size_n, reorder_workgroups = compute_best_group_size_n(
shape[0], shape[1], shape[2], block_shape[0], block_shape[1]
)
print(
f"[workgroup_reorder] enabled={reorder_workgroups}, group_size_n={group_size_n}"
)

M = tkl.sym.M
N = tkl.sym.N
K = tkl.sym.K
Expand Down Expand Up @@ -311,6 +338,108 @@ def repeat(
return gemm, options


def compute_best_group_size_n(
M: int,
N: int,
K: int,
block_m: int,
block_n: int,
num_xcds: int = 8,
cus_per_xcd: int = 32,
) -> tuple[int, bool]:
"""Auto-select group_size_n and decide whether N-dim reordering is beneficial.

Dispatch model (MI300X / MI350):
Hardware assigns flat workgroup indices round-robin to XCDs.
Each XCD runs cus_per_xcd CUs in parallel, forming a "batch" of
cus_per_xcd concurrent workgroups.

Each batch covers U_A unique M-tiles × U_B unique N-tiles.
Per K-iteration DRAM fetches = U_A + U_B.
Minimise U_A + U_B subject to U_A × U_B ≈ cus_per_xcd (= 32).
Optimal: (U_A, U_B) = (4, 8) or (8, 4) → sum = 12.

WITHOUT N-reordering:
U_B_natural ≈ (cus_per_xcd × num_xcds) / num_wg_0 = 256 / num_wg_0
sum_natural = U_A_natural + U_B_natural

WITH N-reordering (group_size_n = gsn, multiple of num_xcds):
U_B = gsn / num_xcds
(cost function) sum_gsn = cus_per_xcd × num_xcds / gsn + gsn / num_xcds
= 256 / gsn + gsn / 8

Optimal gsn (derivation from cost function set to zero and solved for gsn)
≈ num_xcds × √cus_per_xcd ≈ 45 → closest power of two: gsn=32 and gsn=64

Worked examples (block_m = block_n = 256, MI300X defaults):

Shape (M, N) num_wg_0 U_B_natural sum_natural best_gsn enable
(4096, 57344) 16 16 18 32 YES ← num_wg_0 < 32
(8192, 57344) 32 8 12 -- NO ← already optimal
(16384, 16384) 64 4 12 -- NO ← already optimal
(32768, 16384) 128 2 18 64 YES ← num_wg_0 > 64

group_size_n selection:
Both gsn=32 (U_A=8, U_B=4) and gsn=64 (U_A=4, U_B=8) achieve sum=12.
Tie-breaking:
• Exact divisors of num_wg_1 are preferred (no tail group).
• B-heavy shapes (num_wg_1 >= num_wg_0): prefer gsn=32 (lower U_B →
more concurrent B sharing per batch).
• A-heavy shapes (num_wg_0 > num_wg_1): prefer gsn=64 (lower U_A →
more concurrent A sharing per batch).

Args:
M, N, K: Problem dimensions (K is accepted for API consistency
but does not affect the batch balance model).
block_m, block_n: Tile sizes along M and N.
num_xcds: XCD count (MI300X / MI350: 8).
cus_per_xcd: CUs per XCD (MI300X / MI350: 32).

Returns:
(best_group_size_n, reorder_enabled)
reorder_enabled=False means column-major dispatch already achieves the
optimal batch balance (sum=12); best_group_size_n is still returned
(32) as a safe default.
"""
num_wg_0 = ceil(M / block_m) # M-tiles
num_wg_1 = ceil(N / block_n) # N-tiles

total_wg = num_wg_0 * num_wg_1
if total_wg < num_xcds: # fewer wgs than XCDs, model meaningless and reordering too
return 32, False

candidates = [g for g in (16, 32, 64) if g % num_xcds == 0 and g <= num_wg_1]
if not candidates:
return num_xcds, False

def ub(g: int) -> int:
return g // num_xcds

def ua(g: int) -> int:
return cus_per_xcd // max(1, ub(g))

def gsn_sum(g: int) -> int:
return ua(g) + ub(g)

# Natural batch composition (no reordering)
u_b_nat = max(1, min(cus_per_xcd, cus_per_xcd * num_xcds // max(1, num_wg_0)))
u_a_nat = max(1, cus_per_xcd // u_b_nat)
sum_natural = u_a_nat + u_b_nat

best_sum = min(gsn_sum(g) for g in candidates)
reorder_enabled = best_sum < sum_natural

if not reorder_enabled:
return 32, False

optimal = [g for g in candidates if gsn_sum(g) == best_sum]
exact = [g for g in optimal if num_wg_1 % g == 0]
pool = exact if exact else optimal

# Tie-break: B-heavy → smaller gsn (more B sharing); A-heavy → larger gsn.
return (min(pool) if num_wg_1 >= num_wg_0 else max(pool)), True


def _reorder_mxfp4_workgroups(m, n, block_m, block_n, group_size_n):
"""Remap workgroup indices to a new order based on group_size_n along N dimension.

Expand Down