From d71bb8131bbba83a038762f237edf1d7dcb19917 Mon Sep 17 00:00:00 2001 From: Aurore De Spirlet Date: Fri, 6 Mar 2026 02:03:09 +0000 Subject: [PATCH 1/4] add barrier for theoretically correct kernel functioning Signed-off-by: Aurore De Spirlet --- examples/python/7.1_schedule.py | 44 ++++++++++++++++--- .../schedules/gemm_mxfp4_double_buffer.py | 12 ++--- 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 519b38f9c..239ed7552 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -88,6 +88,23 @@ def _run_mxfp_gemm_preshuffle(gemm, shape, all=False, only_scale=False, only_b=F ) +def _get_8wave_shape_from_block(block): + """Choose an 8-wave shape (4x2 or 2x4) from block M/N dims. + + If either block M or N is 32, force that corresponding wave dimension to 2. + """ + m_blk, n_blk = block[0], block[1] + if m_blk == 32 and n_blk == 32: + raise ValueError( + "Cannot satisfy both M and N=32 with an 8-wave shape constrained to (4, 2) or (2, 4)." + ) + if m_blk == 32: + return (2, 4) + if n_blk == 32: + return (4, 2) + return (4, 2) + + def test_dbuf_4wave_mxfp_gemm( is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256) ): @@ -106,19 +123,29 @@ def test_dbuf_4wave_mxfp_gemm( def test_dbuf_8wave_pingpong_mxfp_gemm( - is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256) + is_debug=False, shape=(1024, 1024, 8192), block=(128, 256, 256), dynamic=False ): """Double-buffered MXFP4 GEMM, 8 waves, ping-pong with stagger. A&B scales are preshuffled and read from global memory directly to VGPRs. A and B are read from global memory directly to LDS. + + Note: for dynamic mode, keep block MxN at or below 128x256 or 256x128 + to avoid exceeding shared-memory limits. 
""" + wave_shape = _get_8wave_shape_from_block(block) gemm, options = get_tagged_mxfp4_gemm_preshuffle_scales( - shape, block, wave_shape=(4, 2) + shape, block, wave_shape=wave_shape ) options.specialize = True options.use_buffer_ops = True options.minimize_shared_allocs = True options.linearize_shared_access = True + + if dynamic: + options.dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in options.dynamic_symbols: + del options.subs[sym] + schedule = get_mxfp4_dbuf_pingpong_schedule(use_stagger=True, shape=shape) options.print_ir_after = "all" if is_debug else [] @@ -126,7 +153,10 @@ def test_dbuf_8wave_pingpong_mxfp_gemm( gemm = wave_compile(options, gemm, schedule) _run_mxfp_gemm_preshuffle(gemm, shape, only_scale=True) - print("MXFP GEMM double-buffer 8-wave ping pong with scale shuffling test passed!") + mode = "dynamic" if dynamic else "static" + print( + f"MXFP GEMM double-buffer 8-wave ping pong with scale shuffling ({mode}) test passed!" + ) def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle( @@ -137,8 +167,9 @@ def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle( Same for B data. However, loading B directly to VGPR consumes too many VGPRs and causes spilling. A is read from global memory directly to LDS. """ + wave_shape = _get_8wave_shape_from_block(block) gemm, options = get_tagged_mxfp4_gemm_preshuffle_scales_and_B( - shape, block, wave_shape=(4, 2) + shape, block, wave_shape=wave_shape ) options.specialize = True options.use_buffer_ops = True @@ -156,7 +187,10 @@ def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle( gemm = wave_compile(options, gemm, schedule) _run_mxfp_gemm_preshuffle(gemm, shape, all=True) - print("MXFP GEMM double-buffer 8-wave ping pong with scale shuffling test passed!") + mode = "dynamic" if dynamic else "static" + print( + f"MXFP GEMM double-buffer 8-wave ping pong with scale shuffling and B shuffled ({mode}) test passed!" 
+ ) def test_dbuf_8wave_mixed_pingpong_mxfp_gemm( diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index 186f6bcf1..fb58205f4 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -378,17 +378,17 @@ def mxfp4_dbuf_schedule(): # If the bus gets congested and cluster memory dependency are affected, we must add a second barrier to fix the timing and prevent incorrect output results. # In case a second a second workgroup barrier is needed, another schedule is created to hide the latency of that second barrier, by scheduling safe ds_read ops before the second barrier (see get_mxfp4_dbuf_mixed_pingpong_schedule). - use_extra_barrier = False + use_extra_barrier = True # Build cluster 0: first K-partition loads + bitcasts + GatherToLDS cluster_0_ops = [ tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait(load=0), + tkw.WorkgroupBarrier(), ] if use_extra_barrier: cluster_0_ops.append(tkw.WorkgroupBarrier()) cluster_0_ops.extend( [ - tkw.MemoryCounterWait(load=0), - tkw.WorkgroupBarrier(), loop_global_to_shared, tkw.SchedulingBarrier([]), loop_shared_load_a_0, @@ -583,17 +583,17 @@ def mxfp4_dbuf_schedule(): # If the bus gets congested and cluster memory dependency are affected, we must add a second barrier to fix the timing and prevent incorrect output results. # In case a second a second workgroup barrier is needed, another schedule is created to hide the latency of that second barrier, by scheduling safe ds_read ops before the second barrier (see get_mxfp4_dbuf_mixed_pingpong_schedule). 
- use_extra_barrier = False + use_extra_barrier = True # Build cluster 0: first K-partition loads + bitcasts + GatherToLDS cluster_0_ops = [ tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait(load=0), + tkw.WorkgroupBarrier(), ] if use_extra_barrier: cluster_0_ops.append(tkw.WorkgroupBarrier()) cluster_0_ops.extend( [ - tkw.MemoryCounterWait(load=0), - tkw.WorkgroupBarrier(), loop_global_to_shared, tkw.SchedulingBarrier([]), loop_shared_load_a_0, From 126df230acade0a4676f3dc0232823dcbe8ee3bc Mon Sep 17 00:00:00 2001 From: Aurore De Spirlet Date: Fri, 6 Mar 2026 16:28:40 +0000 Subject: [PATCH 2/4] add all tests from 7.1 to ci Signed-off-by: Aurore De Spirlet --- tests/kernel/wave_gemm_mxfp_test.py | 63 +++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 88107fa82..c48dd7933 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -30,10 +30,12 @@ get_tagged_mxfp4_gemm, get_tagged_mxfp4_gemm_preshuffle_b, get_tagged_mxfp4_gemm_preshuffle_scales, + get_tagged_mxfp4_gemm_preshuffle_scales_and_B, ) from wave_lang.kernel.wave.schedules import ( get_mxfp4_dbuf_schedule, get_mxfp4_dbuf_pingpong_schedule, + get_mxfp4_dbuf_pingpong_schedule_Bshuffled, get_mxfp4_asymmetric_schedule, ) from wave_lang.kernel.wave.utils.mxfp_utils import ( @@ -960,6 +962,8 @@ def testScaledGemmMXFP4PreshuffleBDynamic( (256, 160, 256), (256, 224, 256), (128, 128, 256), + (128, 256, 256), + (256, 32, 256), (64, 192, 256), (64, 128, 256), ] @@ -976,10 +980,12 @@ def testScaledGemmMXFP4PreshuffleBDynamic( "mfma_variant", [ScaledMMAType.F32_16x16x128_F8F6F4], ) +@pytest.mark.parametrize("dynamic", [False, True], ids=["static", "dynamic"]) def testScaledGemmMXFP4PreshuffleMacrotiles8WavePingpong( shape: tuple[int, int, int], block_shape: tuple[int, int, int], mfma_variant: ScaledMMAType, + dynamic: bool, ): """8-wave double-buffered MXFP4 GEMM with ping-pong 
schedule and scale preshuffling. (A&B scales preshuffled, A and B global-to-LDS). @@ -993,6 +999,10 @@ def testScaledGemmMXFP4PreshuffleMacrotiles8WavePingpong( options.specialize = True options.use_buffer_ops = True options.minimize_shared_allocs = True + if dynamic: + options.dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in options.dynamic_symbols: + del options.subs[sym] schedule = get_mxfp4_dbuf_pingpong_schedule(use_stagger=True, shape=shape) options = set_default_run_config(options) gemm = wave_compile(options, gemm, schedule) @@ -1016,6 +1026,59 @@ def testScaledGemmMXFP4PreshuffleMacrotiles8WavePingpong( torch.testing.assert_close(torch_out, out, check_dtype=False) +@require_e2e +@require_cdna4 +@pytest.mark.parametrize( + "shape", + [(1024, 1024, 8192)], +) +@pytest.mark.parametrize("block_shape", MACROTILES_PRESHUFFLE_8WAVE_PINGPONG) +@pytest.mark.parametrize( + "mfma_variant", + [ScaledMMAType.F32_16x16x128_F8F6F4], +) +@pytest.mark.parametrize("dynamic", [False, True], ids=["static", "dynamic"]) +def testScaledGemmMXFP4PreshuffleScalesAndBMacrotiles8WavePingpong( + shape: tuple[int, int, int], + block_shape: tuple[int, int, int], + mfma_variant: ScaledMMAType, + dynamic: bool, +): + """8-wave double-buffered MXFP4 GEMM with ping-pong schedule, scale and B preshuffling. + (A&B scales preshuffled and B preshuffled in K-pack order). 
+ """ + gemm, options = get_tagged_mxfp4_gemm_preshuffle_scales_and_B( + shape, + block_shape, + wave_shape=(4, 2), + mfma_variant=mfma_variant, + ) + options.specialize = True + options.use_buffer_ops = True + options.minimize_shared_allocs = True + options.linearize_shared_access = True + if dynamic: + options.dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in options.dynamic_symbols: + del options.subs[sym] + schedule = get_mxfp4_dbuf_pingpong_schedule_Bshuffled(use_stagger=True, shape=shape) + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) + torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales) + + w_t = w.T.contiguous() + w_t_ps = b_preshuffle(w_t) + x_scales_ps = e8m0_shuffle(x_scales) + w_scales_ps = e8m0_shuffle(w_scales) + + out = device_zeros(x.shape[0], w_t_ps.shape[0], dtype=torch.float32) + gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out) + + torch.testing.assert_close(torch_out, out, check_dtype=False) + + def get_gfx1250_scaled_gemm_mxfp4_template( shape: tuple[int], mfma_variant: ScaledMMAType, enable_scheduling: SchedulingType ) -> tuple[WaveCompileOptions, "LaunchableWave"]: From 8934969461ac9273488ed8d4119eae3a19167d36 Mon Sep 17 00:00:00 2001 From: Aurore De Spirlet Date: Fri, 6 Mar 2026 18:33:22 +0000 Subject: [PATCH 3/4] skip dynamic on xpecific tile size to prevent going out of LDS space Signed-off-by: Aurore De Spirlet --- tests/kernel/wave_gemm_mxfp_test.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index c48dd7933..30a77c7de 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -969,6 +969,14 @@ def testScaledGemmMXFP4PreshuffleBDynamic( ] +_DYNAMIC_ALLOWED_PRESHUFFLE_8WAVE_BLOCKS = { + (128, 128, 256), + (128, 256, 256), + (64, 192, 256), + (64, 128, 256), 
+} + + @require_e2e @require_cdna4 @pytest.mark.parametrize( @@ -990,6 +998,9 @@ def testScaledGemmMXFP4PreshuffleMacrotiles8WavePingpong( """8-wave double-buffered MXFP4 GEMM with ping-pong schedule and scale preshuffling. (A&B scales preshuffled, A and B global-to-LDS). """ + if dynamic and block_shape not in _DYNAMIC_ALLOWED_PRESHUFFLE_8WAVE_BLOCKS: + pytest.skip("Dynamic mode is only covered for selected block shapes.") + gemm, options = get_tagged_mxfp4_gemm_preshuffle_scales( shape, block_shape, @@ -1011,12 +1022,6 @@ def testScaledGemmMXFP4PreshuffleMacrotiles8WavePingpong( torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales) w_t = w.T.contiguous() - w_t_ps = b_preshuffle(w_t) - x_scales_ps = e8m0_shuffle(x_scales) - w_scales_ps = e8m0_shuffle(w_scales) - - out = device_zeros(x.shape[0], w_t_ps.shape[0], dtype=torch.float32) - gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out) x_scales_ps = e8m0_shuffle(x_scales) w_scales_ps = e8m0_shuffle(w_scales) @@ -1047,6 +1052,9 @@ def testScaledGemmMXFP4PreshuffleScalesAndBMacrotiles8WavePingpong( """8-wave double-buffered MXFP4 GEMM with ping-pong schedule, scale and B preshuffling. (A&B scales preshuffled and B preshuffled in K-pack order). 
""" + if dynamic and block_shape not in _DYNAMIC_ALLOWED_PRESHUFFLE_8WAVE_BLOCKS: + pytest.skip("Dynamic mode is only covered for selected block shapes.") + gemm, options = get_tagged_mxfp4_gemm_preshuffle_scales_and_B( shape, block_shape, From eb6383034b68ce5fb299eac43c3cae02053450e2 Mon Sep 17 00:00:00 2001 From: Aurore De Spirlet Date: Fri, 6 Mar 2026 18:40:09 +0000 Subject: [PATCH 4/4] cleaning Signed-off-by: Aurore De Spirlet --- tests/kernel/wave_gemm_mxfp_test.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 30a77c7de..d7e36abd2 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -997,6 +997,7 @@ def testScaledGemmMXFP4PreshuffleMacrotiles8WavePingpong( ): """8-wave double-buffered MXFP4 GEMM with ping-pong schedule and scale preshuffling. (A&B scales preshuffled, A and B global-to-LDS). + Note: In dynamic mode, this test only covers selected block shapes to avoid exceeding LDS memory limits. """ if dynamic and block_shape not in _DYNAMIC_ALLOWED_PRESHUFFLE_8WAVE_BLOCKS: pytest.skip("Dynamic mode is only covered for selected block shapes.")