Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions examples/python/7.1_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ def _run_mxfp_gemm_preshuffle(gemm, shape, all=False, only_scale=False, only_b=F
)


def _get_8wave_shape_from_block(block):
"""Choose an 8-wave shape (4x2 or 2x4) from block M/N dims.

If either block M or N is 32, force that corresponding wave dimension to 2.
"""
m_blk, n_blk = block[0], block[1]
if m_blk == 32 and n_blk == 32:
raise ValueError(
"Cannot satisfy both M and N=32 with an 8-wave shape constrained to (4, 2) or (2, 4)."
)
if m_blk == 32:
return (2, 4)
if n_blk == 32:
return (4, 2)
return (4, 2)


def test_dbuf_4wave_mxfp_gemm(
is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256)
):
Expand All @@ -106,27 +123,40 @@ def test_dbuf_4wave_mxfp_gemm(


def test_dbuf_8wave_pingpong_mxfp_gemm(
is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256)
is_debug=False, shape=(1024, 1024, 8192), block=(128, 256, 256), dynamic=False
):
"""Double-buffered MXFP4 GEMM, 8 waves, ping-pong with stagger.
A&B scales are preshuffled and read from global memory directly to VGPRs.
A and B are read from global memory directly to LDS.

Note: for dynamic mode, keep block MxN at or below 128x256 or 256x128
to avoid exceeding shared-memory limits.
"""
wave_shape = _get_8wave_shape_from_block(block)
gemm, options = get_tagged_mxfp4_gemm_preshuffle_scales(
shape, block, wave_shape=(4, 2)
shape, block, wave_shape=wave_shape
)
options.specialize = True
options.use_buffer_ops = True
options.minimize_shared_allocs = True
options.linearize_shared_access = True

if dynamic:
options.dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K]
for sym in options.dynamic_symbols:
del options.subs[sym]

schedule = get_mxfp4_dbuf_pingpong_schedule(use_stagger=True, shape=shape)

options.print_ir_after = "all" if is_debug else []
options = set_default_run_config(options)
gemm = wave_compile(options, gemm, schedule)

_run_mxfp_gemm_preshuffle(gemm, shape, only_scale=True)
print("MXFP GEMM double-buffer 8-wave ping pong with scale shuffling test passed!")
mode = "dynamic" if dynamic else "static"
print(
f"MXFP GEMM double-buffer 8-wave ping pong with scale shuffling ({mode}) test passed!"
)


def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle(
Expand All @@ -137,8 +167,9 @@ def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle(
Same for B data. However, loading B directly to VGPR consumes too many VGPRs and causes spilling.
A is read from global memory directly to LDS.
"""
wave_shape = _get_8wave_shape_from_block(block)
gemm, options = get_tagged_mxfp4_gemm_preshuffle_scales_and_B(
shape, block, wave_shape=(4, 2)
shape, block, wave_shape=wave_shape
)
options.specialize = True
options.use_buffer_ops = True
Expand All @@ -156,7 +187,10 @@ def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle(
gemm = wave_compile(options, gemm, schedule)

_run_mxfp_gemm_preshuffle(gemm, shape, all=True)
print("MXFP GEMM double-buffer 8-wave ping pong with scale shuffling test passed!")
mode = "dynamic" if dynamic else "static"
print(
f"MXFP GEMM double-buffer 8-wave ping pong with scale shuffling and B shuffled ({mode}) test passed!"
)


def test_dbuf_8wave_mixed_pingpong_mxfp_gemm(
Expand Down
12 changes: 6 additions & 6 deletions wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,17 +378,17 @@ def mxfp4_dbuf_schedule():

# If the bus gets congested and cluster memory dependencies are affected, we must add a second barrier to fix the timing and prevent incorrect output results.
# In case a second workgroup barrier is needed, another schedule is created to hide the latency of that second barrier, by scheduling safe ds_read ops before the second barrier (see get_mxfp4_dbuf_mixed_pingpong_schedule).
use_extra_barrier = False
use_extra_barrier = True
# Build cluster 0: first K-partition loads + bitcasts + GatherToLDS
cluster_0_ops = [
tkw.SchedulingBarrier([]),
tkw.MemoryCounterWait(load=0),
tkw.WorkgroupBarrier(),
]
if use_extra_barrier:
cluster_0_ops.append(tkw.WorkgroupBarrier())
cluster_0_ops.extend(
[
tkw.MemoryCounterWait(load=0),
tkw.WorkgroupBarrier(),
loop_global_to_shared,
tkw.SchedulingBarrier([]),
loop_shared_load_a_0,
Expand Down Expand Up @@ -583,17 +583,17 @@ def mxfp4_dbuf_schedule():

# If the bus gets congested and cluster memory dependencies are affected, we must add a second barrier to fix the timing and prevent incorrect output results.
# In case a second workgroup barrier is needed, another schedule is created to hide the latency of that second barrier, by scheduling safe ds_read ops before the second barrier (see get_mxfp4_dbuf_mixed_pingpong_schedule).
use_extra_barrier = False
use_extra_barrier = True
# Build cluster 0: first K-partition loads + bitcasts + GatherToLDS
cluster_0_ops = [
tkw.SchedulingBarrier([]),
tkw.MemoryCounterWait(load=0),
tkw.WorkgroupBarrier(),
]
if use_extra_barrier:
cluster_0_ops.append(tkw.WorkgroupBarrier())
cluster_0_ops.extend(
[
tkw.MemoryCounterWait(load=0),
tkw.WorkgroupBarrier(),
loop_global_to_shared,
tkw.SchedulingBarrier([]),
loop_shared_load_a_0,
Expand Down