diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 519b38f9c..239ed7552 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -88,6 +88,23 @@ def _run_mxfp_gemm_preshuffle(gemm, shape, all=False, only_scale=False, only_b=F ) +def _get_8wave_shape_from_block(block): + """Choose an 8-wave shape (4x2 or 2x4) from block M/N dims. + + If either block M or N is 32, force that corresponding wave dimension to 2. + """ + m_blk, n_blk = block[0], block[1] + if m_blk == 32 and n_blk == 32: + raise ValueError( + "Cannot satisfy both M and N=32 with an 8-wave shape constrained to (4, 2) or (2, 4)." + ) + if m_blk == 32: + return (2, 4) + if n_blk == 32: + return (4, 2) + return (4, 2) + + def test_dbuf_4wave_mxfp_gemm( is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256) ): @@ -106,19 +123,29 @@ def test_dbuf_4wave_mxfp_gemm( def test_dbuf_8wave_pingpong_mxfp_gemm( - is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256) + is_debug=False, shape=(1024, 1024, 8192), block=(128, 256, 256), dynamic=False ): """Double-buffered MXFP4 GEMM, 8 waves, ping-pong with stagger. A&B scales are preshuffled and read from global memory directly to VGPRs. A and B are read from global memory directly to LDS. + + Note: for dynamic mode, keep block MxN at or below 128x256 or 256x128 + to avoid exceeding shared-memory limits. 
""" + wave_shape = _get_8wave_shape_from_block(block) gemm, options = get_tagged_mxfp4_gemm_preshuffle_scales( - shape, block, wave_shape=(4, 2) + shape, block, wave_shape=wave_shape ) options.specialize = True options.use_buffer_ops = True options.minimize_shared_allocs = True options.linearize_shared_access = True + + if dynamic: + options.dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in options.dynamic_symbols: + del options.subs[sym] + schedule = get_mxfp4_dbuf_pingpong_schedule(use_stagger=True, shape=shape) options.print_ir_after = "all" if is_debug else [] @@ -126,7 +153,10 @@ def test_dbuf_8wave_pingpong_mxfp_gemm( gemm = wave_compile(options, gemm, schedule) _run_mxfp_gemm_preshuffle(gemm, shape, only_scale=True) - print("MXFP GEMM double-buffer 8-wave ping pong with scale shuffling test passed!") + mode = "dynamic" if dynamic else "static" + print( + f"MXFP GEMM double-buffer 8-wave ping pong with scale shuffling ({mode}) test passed!" + ) def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle( @@ -137,8 +167,9 @@ def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle( Same for B data. However, loading B directly to VGPR consumes too many VGPRs and causes spilling. A is read from global memory directly to LDS. """ + wave_shape = _get_8wave_shape_from_block(block) gemm, options = get_tagged_mxfp4_gemm_preshuffle_scales_and_B( - shape, block, wave_shape=(4, 2) + shape, block, wave_shape=wave_shape ) options.specialize = True options.use_buffer_ops = True @@ -156,7 +187,10 @@ def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle( gemm = wave_compile(options, gemm, schedule) _run_mxfp_gemm_preshuffle(gemm, shape, all=True) - print("MXFP GEMM double-buffer 8-wave ping pong with scale shuffling test passed!") + mode = "dynamic" if dynamic else "static" + print( + f"MXFP GEMM double-buffer 8-wave ping pong with scale shuffling and B shuffled ({mode}) test passed!" 
+ ) def test_dbuf_8wave_mixed_pingpong_mxfp_gemm( diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 1eadcfe49..d627183e0 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -30,10 +30,12 @@ get_tagged_mxfp4_gemm, get_tagged_mxfp4_gemm_preshuffle_b, get_tagged_mxfp4_gemm_preshuffle_scales, + get_tagged_mxfp4_gemm_preshuffle_scales_and_B, ) from wave_lang.kernel.wave.schedules import ( get_mxfp4_dbuf_schedule, get_mxfp4_dbuf_pingpong_schedule, + get_mxfp4_dbuf_pingpong_schedule_Bshuffled, get_mxfp4_asymmetric_schedule, ) from wave_lang.kernel.wave.utils.mxfp_utils import ( @@ -958,6 +960,8 @@ def testScaledGemmMXFP4PreshuffleBDynamic( (256, 160, 256), (256, 224, 256), (128, 128, 256), + (128, 256, 256), + (256, 32, 256), (64, 192, 256), (64, 128, 256), ] @@ -974,10 +978,12 @@ def testScaledGemmMXFP4PreshuffleBDynamic( "mfma_variant", [ScaledMMAType.F32_16x16x128_F8F6F4], ) +@pytest.mark.parametrize("dynamic", [False, True], ids=["static", "dynamic"]) def testScaledGemmMXFP4PreshuffleMacrotiles8WavePingpong( shape: tuple[int, int, int], block_shape: tuple[int, int, int], mfma_variant: ScaledMMAType, + dynamic: bool, ): """8-wave double-buffered MXFP4 GEMM with ping-pong schedule and scale preshuffling. (A&B scales preshuffled, A and B global-to-LDS). 
@@ -991,6 +997,10 @@ def testScaledGemmMXFP4PreshuffleMacrotiles8WavePingpong( options.specialize = True options.use_buffer_ops = True options.minimize_shared_allocs = True + if dynamic: + options.dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in options.dynamic_symbols: + del options.subs[sym] schedule = get_mxfp4_dbuf_pingpong_schedule(use_stagger=True, shape=shape) options = set_default_run_config(options) gemm = wave_compile(options, gemm, schedule) @@ -1014,6 +1024,59 @@ def testScaledGemmMXFP4PreshuffleMacrotiles8WavePingpong( torch.testing.assert_close(torch_out, out, check_dtype=False) +@require_e2e +@require_cdna4 +@pytest.mark.parametrize( + "shape", + [(1024, 1024, 8192)], +) +@pytest.mark.parametrize("block_shape", MACROTILES_PRESHUFFLE_8WAVE_PINGPONG) +@pytest.mark.parametrize( + "mfma_variant", + [ScaledMMAType.F32_16x16x128_F8F6F4], +) +@pytest.mark.parametrize("dynamic", [False, True], ids=["static", "dynamic"]) +def testScaledGemmMXFP4PreshuffleScalesAndBMacrotiles8WavePingpong( + shape: tuple[int, int, int], + block_shape: tuple[int, int, int], + mfma_variant: ScaledMMAType, + dynamic: bool, +): + """8-wave double-buffered MXFP4 GEMM with ping-pong schedule, scale and B preshuffling. + (A&B scales preshuffled and B preshuffled in K-pack order). 
+ """ + gemm, options = get_tagged_mxfp4_gemm_preshuffle_scales_and_B( + shape, + block_shape, + wave_shape=(4, 2), + mfma_variant=mfma_variant, + ) + options.specialize = True + options.use_buffer_ops = True + options.minimize_shared_allocs = True + options.linearize_shared_access = True + if dynamic: + options.dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in options.dynamic_symbols: + del options.subs[sym] + schedule = get_mxfp4_dbuf_pingpong_schedule_Bshuffled(use_stagger=True, shape=shape) + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) + torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales) + + w_t = w.T.contiguous() + w_t_ps = b_preshuffle(w_t) + x_scales_ps = e8m0_shuffle(x_scales) + w_scales_ps = e8m0_shuffle(w_scales) + + out = device_zeros(x.shape[0], w_t_ps.shape[0], dtype=torch.float32) + gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out) + + torch.testing.assert_close(torch_out, out, check_dtype=False) + + def get_gfx1250_scaled_gemm_mxfp4_template( shape: tuple[int], mfma_variant: ScaledMMAType, enable_scheduling: SchedulingType ) -> tuple[WaveCompileOptions, "LaunchableWave"]: diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index 186f6bcf1..fb58205f4 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -378,17 +378,17 @@ def mxfp4_dbuf_schedule(): # If the bus gets congested and cluster memory dependencies are affected, we must add a second barrier to fix the timing and prevent incorrect output results. # In case a second workgroup barrier is needed, another schedule is created to hide the latency of that second barrier, by scheduling safe ds_read ops before the second barrier (see get_mxfp4_dbuf_mixed_pingpong_schedule). 
- use_extra_barrier = False + use_extra_barrier = True # Build cluster 0: first K-partition loads + bitcasts + GatherToLDS cluster_0_ops = [ tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait(load=0), + tkw.WorkgroupBarrier(), ] if use_extra_barrier: cluster_0_ops.append(tkw.WorkgroupBarrier()) cluster_0_ops.extend( [ - tkw.MemoryCounterWait(load=0), - tkw.WorkgroupBarrier(), loop_global_to_shared, tkw.SchedulingBarrier([]), loop_shared_load_a_0, @@ -583,17 +583,17 @@ def mxfp4_dbuf_schedule(): # If the bus gets congested and cluster memory dependencies are affected, we must add a second barrier to fix the timing and prevent incorrect output results. # In case a second workgroup barrier is needed, another schedule is created to hide the latency of that second barrier, by scheduling safe ds_read ops before the second barrier (see get_mxfp4_dbuf_mixed_pingpong_schedule). - use_extra_barrier = False + use_extra_barrier = True # Build cluster 0: first K-partition loads + bitcasts + GatherToLDS cluster_0_ops = [ tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait(load=0), + tkw.WorkgroupBarrier(), ] if use_extra_barrier: cluster_0_ops.append(tkw.WorkgroupBarrier()) cluster_0_ops.extend( [ - tkw.MemoryCounterWait(load=0), - tkw.WorkgroupBarrier(), loop_global_to_shared, tkw.SchedulingBarrier([]), loop_shared_load_a_0,