From 9ec7fdb9a3dd9a118b5cdb21e6f9f40e0adc2fb4 Mon Sep 17 00:00:00 2001 From: William G Hatch Date: Wed, 18 Feb 2026 17:00:16 -0700 Subject: [PATCH 1/3] Implement mxfp4 split-k gemm The core things added are split-k gemm, and it is tested for (1) generation of the `buffer_atomic_pk_add_bf16` instruction that we wanted to use, and (2) for gemm correctness. Overview of some of the major changes: - `remove_global_indexing` in `general_utils.py`: Zeroes out tiling constraint starts (e.g. `K_SPLIT_OFF`) alongside workgroup IDs before dimension scaling, so that the subtraction of the start offset doesn't mix scaled and unscaled units (K vs K/32 for MXFP4 scales). - Fixing spurious bounds on split-K tiling that prevented scale vector merging: TilingConstraint.get_index_bound was conservatively generating bounds for the split-K case because sympy could not prove that ceiling(Min(K, f(wg)) / tile) * tile <= K. These bounds prevented merge_contiguous_reads from combining scalar scale reads into vector<4xi8> loads (it skips reads that already have bounds). Add _work_may_exceed_dim() to structurally detect the aligned split-k pattern and prove no overshoot, avoiding the spurious bound. (This was necessary to get scale_preshuffle to have 4x vector loads when combined with split-k.) Signed-off-by: William G Hatch --- examples/python/7.1_schedule.py | 220 +++++- tests/kernel/wave_gemm_test.py | 56 +- .../wave_asm/test/e2e/test_asm_backend_e2e.py | 708 ++++++++++++++++++ wave_lang/kernel/wave/constraints.py | 43 +- wave_lang/kernel/wave/templates/gemm.py | 201 +++++ wave_lang/kernel/wave/utils/general_utils.py | 12 +- 6 files changed, 1232 insertions(+), 8 deletions(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 5110fd3c92..2d0b497cbd 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -1,24 +1,33 @@ """ MXFP4 Scaled GEMM Scheduling for GFX950 (MI350) -Double-buffered MXFP4 GEMM with 4-wave and 8-wave configurations. +Double-buffered MXFP4 GEMM with 4-wave and 8-wave configurations, plus split-K. Uses get_tagged_mxfp4_gemm (templates) + get_mxfp4_dbuf_schedule (schedules). +Split-K kernels use the wave_asm backend with atomic bf16 output. 
Usage: python 7.1_schedule.py --test test_dbuf_4wave_mxfp_gemm python 7.1_schedule.py --test test_dbuf_8wave_mxfp_gemm python 7.1_schedule.py --test test_dbuf_8wave_mxfp_gemm --debug + python 7.1_schedule.py --test test_splitk_gemm + python 7.1_schedule.py --test test_splitk_preshuffle_scales_gemm python 7.1_schedule.py --list_tests """ +import sys +from dataclasses import dataclass +from pathlib import Path + import torch -from wave_lang.kernel.wave.compile import wave_compile +from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile +from wave_lang.kernel.wave.constraints import ScaledMMAType from wave_lang.kernel.wave.utils.run_utils import set_default_run_config from wave_lang.kernel.wave.templates import ( get_tagged_mxfp4_gemm, get_tagged_mxfp4_gemm_preshuffle_b, ) +from wave_lang.kernel.wave.templates.gemm import get_splitk_mxfp4_gemm_kernel from wave_lang.kernel.wave.schedules import ( get_mxfp4_dbuf_schedule, get_mxfp4_dbuf_pingpong_schedule, @@ -31,9 +40,22 @@ b_preshuffle, e8m0_shuffle, ) -from wave_lang.kernel.lang.global_symbols import GLOBAL_ADDRESS_SPACE +import wave_lang.kernel.lang as tkl +from wave_lang.kernel.lang.global_symbols import ( + GLOBAL_ADDRESS_SPACE, + SHARED_ADDRESS_SPACE, +) from utils import parse_args, list_tests, run_test +_EXAMPLES_DIR = Path(__file__).parent +_WAVE_ROOT = _EXAMPLES_DIR.parent.parent +_E2E_DIR = ( + _WAVE_ROOT / "wave_lang" / "kernel" / "wave" / "asm" / "wave_asm" / "test" / "e2e" +) +for _p in [str(_EXAMPLES_DIR), str(_WAVE_ROOT), str(_E2E_DIR)]: + if _p not in sys.path: + sys.path.insert(0, _p) + def _run_mxfp_gemm(gemm, shape): """Run compiled GEMM kernel and verify against reference.""" @@ -182,6 +204,198 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm( print("MXFP GEMM preshuffle-B 4-wave test passed!") +@dataclass +class SplitKKernelHandle: + """Opaque handle for a compiled split-K MXFP4 kernel.""" + + gpu_func: object + binary_path: Path + kernel_name: str + grid: tuple[int, int, int] + block: tuple[int, int, int] + lds_size: int + num_splits: int + + +def get_splitk_kernel( + shape: tuple[int, int, int], + block: tuple[int, int, int] = (128, 128, 128), + num_splits: int = 2, + waves_per_block: tuple[int, int] = (2, 2), + preshuffle_scales: bool = False, + compiler=None, +) -> SplitKKernelHandle: + """Compile a split-K MXFP4 GEMM kernel through the wave_asm backend. + + Output tensor must be bf16 and zero-initialised before each call. + w must be in [N, K/2] layout. + + Args: + preshuffle_scales: If True, a_scale and b_scale are read from GLOBAL + memory using the e8m0_shuffle IndexMapping. The caller must pass + scales pre-shuffled with e8m0_shuffle(). 
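+
+    Illustrative sketch of the preshuffled-scales path (variable names are
+    hypothetical; test_splitk_preshuffle_scales_gemm below shows the full
+    flow):
+
+        handle = get_splitk_kernel(shape, num_splits=2, preshuffle_scales=True)
+        x_scales_sh = e8m0_shuffle(x_scales)
+        w_scales_sh = e8m0_shuffle(w_scales)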
+ """ + from waveasm_e2e import WaveASMCompiler, capture_wave_kernel_info + from test_asm_backend_e2e import get_target_arch + + if compiler is None: + compiler = WaveASMCompiler(target=get_target_arch()) + + splitk_fn, hyperparams = get_splitk_mxfp4_gemm_kernel( + shape, + num_splits=num_splits, + mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4, + block_shape=block, + waves_per_block=waves_per_block, + preshuffle_scales=preshuffle_scales, + ) + hyperparams[tkl.sym.ADDRESS_SPACE] = SHARED_ADDRESS_SPACE + hyperparams[tkl.sym.B_ADDRESS_SPACE] = SHARED_ADDRESS_SPACE + + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + backend="asm", + wave_runtime=True, + compile_to_mlir=False, + use_global_to_shared=True, + ) + options = set_default_run_config(options) + + kernel_info = capture_wave_kernel_info(options, splitk_fn) + cpp_result = compiler.compile_full( + kernel_info.mlir_text, kernel_info.workgroup_size + ) + if not cpp_result.success: + raise RuntimeError(f"wave_asm compilation failed: {cpp_result.error_message}") + + import wave_runtime + + wave_runtime.load_hip_functions() + _binary, gpu_func = wave_runtime.load_binary( + str(cpp_result.binary_path), + cpp_result.get_kernel_name() or kernel_info.kernel_name, + ) + + return SplitKKernelHandle( + gpu_func=gpu_func, + binary_path=cpp_result.binary_path, + kernel_name=cpp_result.get_kernel_name() or kernel_info.kernel_name, + grid=kernel_info.grid_size, + block=kernel_info.workgroup_size, + lds_size=kernel_info.lds_size, + num_splits=num_splits, + ) + + +def run_splitk_kernel( + handle: SplitKKernelHandle, + x: torch.Tensor, + x_scales: torch.Tensor, + w: torch.Tensor, + w_scales: torch.Tensor, + c_out: torch.Tensor, +) -> None: + """Launch a compiled split-K kernel. + + c_out must be zero-initialised (dtype=torch.bfloat16). + w must be in [N, K/2] layout. 
+ """ + import wave_runtime + + stream = torch.cuda.current_stream().cuda_stream + kli = wave_runtime.KernelLaunchInfo( + stream, + handle.gpu_func, + handle.lds_size, + handle.grid[0], + handle.grid[1], + handle.grid[2], + handle.block[0], + handle.block[1], + handle.block[2], + 1, + 1, + 1, + ) + kern_args = wave_runtime.Int64Vector( + [t.data_ptr() for t in [x, x_scales, w, w_scales, c_out]] + ) + wave_runtime.launch(kli, kern_args, [], []) + + +def test_splitk_gemm( + is_debug: bool = False, + shape: tuple[int, int, int] = (1024, 1024, 8192), + block: tuple[int, int, int] = (128, 128, 256), +): + """Split-K MXFP4 GEMM (wave_asm backend, bf16 output, unshuffled scales).""" + m, n, k = shape + num_splits = 2 + x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) + # w from generate_gemm_afp4wfp4_inputs is [K/2, N]; split-K kernel wants [N, K/2] + w_nk = w.T.contiguous() + torch_ref = torchScaledGemmMXFP4(x, w, x_scales, w_scales) + + handle = get_splitk_kernel(shape, block=block, num_splits=num_splits) + + c_out = torch.zeros(m, n, dtype=torch.bfloat16, device="cuda") + run_splitk_kernel( + handle, x.cuda(), x_scales.cuda(), w_nk.cuda(), w_scales.cuda(), c_out + ) + torch.cuda.synchronize() + + bf16_eps = 2**-7 + atol = num_splits * bf16_eps * max(torch_ref.abs().max().item(), 1.0) + torch.testing.assert_close( + torch_ref, + c_out.cpu().to(torch.float32), + check_dtype=False, + check_device=False, + atol=atol, + rtol=0.0, + ) + print("Split-K MXFP4 GEMM test passed!") + + +def test_splitk_preshuffle_scales_gemm( + is_debug: bool = False, + shape: tuple[int, int, int] = (1024, 1024, 8192), + block: tuple[int, int, int] = (128, 128, 256), +): + """Split-K MXFP4 GEMM (wave_asm backend, bf16 output, preshuffled scales).""" + m, n, k = shape + num_splits = 2 + x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) + w_nk = w.T.contiguous() + torch_ref = torchScaledGemmMXFP4(x, w, x_scales, w_scales) + + x_scales_sh = e8m0_shuffle(x_scales) + w_scales_sh = e8m0_shuffle(w_scales) + + handle = get_splitk_kernel( + shape, block=block, num_splits=num_splits, preshuffle_scales=True + ) + + c_out = torch.zeros(m, n, dtype=torch.bfloat16, device="cuda") + run_splitk_kernel( + handle, x.cuda(), x_scales_sh.cuda(), w_nk.cuda(), w_scales_sh.cuda(), c_out + ) + torch.cuda.synchronize() + + bf16_eps = 2**-7 + atol = num_splits * bf16_eps * max(torch_ref.abs().max().item(), 1.0) + torch.testing.assert_close( + torch_ref, + c_out.cpu().to(torch.float32), + check_dtype=False, + check_device=False, + atol=atol, + rtol=0.0, + ) + print("Split-K MXFP4 GEMM (preshuffled scales) test passed!") + + if __name__ == "__main__": args = parse_args() diff --git a/tests/kernel/wave_gemm_test.py b/tests/kernel/wave_gemm_test.py index bcf96493cf..66a5af6d22 100644 --- a/tests/kernel/wave_gemm_test.py +++ b/tests/kernel/wave_gemm_test.py @@ -41,12 +41,22 @@ require_rdna4, use_water_backend_bool, ) -from wave_lang.kernel.wave.constraints import MMAType, MMAOperand, GenericDot +from wave_lang.kernel.wave.constraints import ( + MMAType, + MMAOperand, + GenericDot, + ScaledMMAType, +) +from wave_lang.kernel.wave.utils.mxfp_utils import ( + generate_gemm_afp4wfp4_inputs, + torchScaledGemmMXFP4, +) from wave_lang.kernel.wave.templates.gemm import ( get_gemm_kernel, get_gemm_kernel_transpose_a_b, get_persistent_gemm_kernel, get_splitk_gemm_kernel, + get_splitk_mxfp4_gemm_kernel, get_streamk_gemm_kernel, get_hybrid_streamk_gemm_kernel, get_persistent_reordering_kernel, @@ -3614,3 +3624,47 @@ def testSplitKGemm( 
torch_ref = a.cpu().to(torch.float32) @ b.cpu().T.to(torch.float32) assert_close(c.cpu(), torch_ref, rtol=1e-3, atol=1e-2) + + +@require_e2e +@require_cdna4 +@pytest.mark.parametrize( + "shape, num_splits", + [ + ((256, 256, 256), 2), + ((256, 256, 512), 4), + ((256, 256, 1024), 4), + ((512, 512, 1024), 2), + ], +) +def testSplitKMxfp4Gemm( + shape: tuple[int, int, int], + num_splits: int, +): + splitk_gemm, hyperparams = get_splitk_mxfp4_gemm_kernel( + shape, + num_splits=num_splits, + mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4, + ) + + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + ) + options = set_default_run_config(options) + splitk_gemm = wave_compile(options, splitk_gemm) + + m, n, k = shape + x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs( + shape, device=torch.device("cpu") + ) + torch_ref = torchScaledGemmMXFP4(x, w, x_scales, w_scales) + + x_gpu = x.cuda() + w_t_gpu = w.T.contiguous().cuda() + x_scales_gpu = x_scales.cuda() + w_scales_gpu = w_scales.cuda() + c_gpu = device_zeros(m, n, dtype=torch.bfloat16) + + splitk_gemm(x_gpu, x_scales_gpu, w_t_gpu, w_scales_gpu, c_gpu) + assert_close(c_gpu.cpu().to(torch.float32), torch_ref, rtol=5e-2, atol=5e-2) diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py b/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py index 3319b1ac9b..1a776b0b74 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py +++ b/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py @@ -1797,6 +1797,714 @@ def count_instructions(asm): assert cpp_stats["buffer_load"] > 0 or python_stats["buffer_load"] > 0 +# ============================================================================= +# Test: Split-K MXFP4 GEMM with bf16 atomic add (buffer_atomic_pk_add_bf16) +# ============================================================================= + + +@pytest.mark.run_e2e +@pytest.mark.parametrize( + "shape,num_splits", + [ + ((256, 256, 256), 2), + ], +) +def test_splitk_mxfp4_bf16_atomic_cpp_backend( + shape, num_splits, compiler, backend, dump_asm +): + """End-to-end test for split-K MXFP4 GEMM with bf16 atomic accumulation. + + Validates that the C++ WaveASM backend correctly emits the + buffer_atomic_pk_add_bf16 instruction for bf16 atomic add operations. + + Requires CDNA4 (gfx950+) for both scaled MFMA and bf16 atomics. + """ + if not is_cdna4(): + pytest.skip("Split-K MXFP4 with bf16 atomics requires gfx950+ (CDNA4)") + + skip_if_no_gpu() + skip_if_no_wave_lang() + + import torch + from torch.testing import assert_close + + import wave_lang.kernel.lang as tkl + from wave_lang.kernel.lang.global_symbols import ( + SHARED_ADDRESS_SPACE, + ) + from wave_lang.kernel.wave.compile import WaveCompileOptions + from wave_lang.kernel.wave.constraints import ScaledMMAType + from wave_lang.kernel.wave.templates.gemm import get_splitk_mxfp4_gemm_kernel + from wave_lang.kernel.wave.utils.mxfp_utils import ( + generate_gemm_afp4wfp4_inputs, + torchScaledGemmMXFP4, + ) + from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + + splitk_gemm, hyperparams = get_splitk_mxfp4_gemm_kernel( + shape, + num_splits=num_splits, + mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4, + ) + + # Override to use shared memory address space with global-to-shared loads, + # since the C++ backend doesn't yet support vector.maskedload (used by + # the global address space path). 
+ hyperparams[tkl.sym.ADDRESS_SPACE] = SHARED_ADDRESS_SPACE + hyperparams[tkl.sym.B_ADDRESS_SPACE] = SHARED_ADDRESS_SPACE + + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + backend="asm", + wave_runtime=True, + compile_to_mlir=False, + use_global_to_shared=True, + ) + options = set_default_run_config(options) + + kernel_info = capture_wave_kernel_info(options, splitk_gemm) + + test_id = f"splitk_mxfp4_bf16_{shape[0]}x{shape[1]}x{shape[2]}_s{num_splits}" + + cpp_result = compiler.compile_full( + kernel_info.mlir_text, kernel_info.workgroup_size + ) + + if dump_asm: + with open(f"/tmp/{test_id}_mlir.txt", "w") as f: + f.write(kernel_info.mlir_text) + if cpp_result.success: + with open(f"/tmp/{test_id}_cpp.s", "w") as f: + f.write(cpp_result.asm_text) + + if not cpp_result.success: + pytest.fail(f"C++ compilation failed: {cpp_result.error_message}") + + # Verify the assembly contains buffer_atomic_pk_add_bf16 + assert ( + "buffer_atomic_pk_add_bf16" in cpp_result.asm_text + ), "Expected buffer_atomic_pk_add_bf16 in generated assembly" + + m, n, k = shape + x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs( + shape, device=torch.device("cpu") + ) + torch_ref = torchScaledGemmMXFP4(x, w, x_scales, w_scales) + + x_gpu = x.cuda() + w_t_gpu = w.T.contiguous().cuda() + x_scales_gpu = x_scales.cuda() + w_scales_gpu = w_scales.cuda() + c_gpu = torch.zeros(m, n, dtype=torch.bfloat16, device="cuda") + + binary_path = cpp_result.binary_path + kernel_name = cpp_result.get_kernel_name() or kernel_info.kernel_name + block = kernel_info.workgroup_size + lds_size = kernel_info.lds_size + grid = kernel_info.grid_size + + run_with_wave_runtime( + binary_path=binary_path, + inputs=[x_gpu, x_scales_gpu, w_t_gpu, w_scales_gpu], + outputs=[c_gpu], + grid=grid, + block=block, + shared_memory_bytes=lds_size, + func_name=kernel_name, + ) + + # BF16 atomic accumulation loses precision vs f32 reference: + # each split truncates f32→bf16 (up to 0.5 ULP) before the atomic add. + # atol=1.0 covers the worst-case bf16 ULP (0.5) at output magnitudes ~64. + assert_close(c_gpu.cpu().to(torch.float32), torch_ref, rtol=5e-2, atol=1.0) + + +# ============================================================================= +# Test: Split-K MXFP4 Assembly Emission (Lit-style inspection test) +# ============================================================================= + + +@pytest.mark.run_e2e +@pytest.mark.parametrize( + "shape,num_splits", + [ + ((256, 256, 256), 2), + ], +) +def test_splitk_mxfp4_bf16_asm_emission(shape, num_splits, compiler): + """Assembly emission test for split-K MXFP4 GEMM with bf16 atomics. + + Captures MLIR from the splitk mxfp4 bf16 kernel, compiles it to + assembly via waveasm-translate, and validates the emitted instructions. + Both the MLIR and assembly are always saved to /tmp/ for inspection. + + This serves as a lit-style test to see exactly what the C++ backend + emits for the full splitk mxfp4 bf16 atomic pipeline. + + Requires CDNA4 (gfx950+) for both scaled MFMA and bf16 atomics. 
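+
+    Illustrative local invocation (the exact pytest flags and conftest
+    options are assumptions; adjust to your setup):
+
+        pytest wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py \
+            -m run_e2e -k splitk_mxfp4_bf16_asm_emission -s
+        less /tmp/splitk_mxfp4_bf16_asm_emission_256x256x256_s2_cpp.s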
+ """ + if not is_cdna4(): + pytest.skip("Split-K MXFP4 with bf16 atomics requires gfx950+ (CDNA4)") + + skip_if_no_gpu() + skip_if_no_wave_lang() + + import wave_lang.kernel.lang as tkl + from wave_lang.kernel.lang.global_symbols import SHARED_ADDRESS_SPACE + from wave_lang.kernel.wave.compile import WaveCompileOptions + from wave_lang.kernel.wave.constraints import ScaledMMAType + from wave_lang.kernel.wave.templates.gemm import get_splitk_mxfp4_gemm_kernel + from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + + splitk_gemm, hyperparams = get_splitk_mxfp4_gemm_kernel( + shape, + num_splits=num_splits, + mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4, + ) + + hyperparams[tkl.sym.ADDRESS_SPACE] = SHARED_ADDRESS_SPACE + hyperparams[tkl.sym.B_ADDRESS_SPACE] = SHARED_ADDRESS_SPACE + + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + backend="asm", + wave_runtime=True, + compile_to_mlir=False, + use_global_to_shared=True, + ) + options = set_default_run_config(options) + + kernel_info = capture_wave_kernel_info(options, splitk_gemm) + + test_id = ( + f"splitk_mxfp4_bf16_asm_emission_{shape[0]}x{shape[1]}x{shape[2]}_s{num_splits}" + ) + + cpp_result = compiler.compile_full( + kernel_info.mlir_text, kernel_info.workgroup_size + ) + + # Always save MLIR and assembly for inspection + with open(f"/tmp/{test_id}_mlir.txt", "w") as f: + f.write(kernel_info.mlir_text) + if cpp_result.success: + with open(f"/tmp/{test_id}_cpp.s", "w") as f: + f.write(cpp_result.asm_text) + + if not cpp_result.success: + pytest.fail(f"C++ compilation failed: {cpp_result.error_message}") + + asm = cpp_result.asm_text + + # Validate expected instructions are present + assert ( + "buffer_atomic_pk_add_bf16" in asm + ), "Expected buffer_atomic_pk_add_bf16 in assembly" + assert ( + "v_mfma_scale_f32_16x16x128_f8f6f4" in asm + ), "Expected v_mfma_scale_f32_16x16x128_f8f6f4 in assembly" + assert "v_cvt_pk_bf16_f32" in asm, "Expected v_cvt_pk_bf16_f32 in assembly" + + # Count key instructions for diagnostic output + lines = asm.split("\n") + atomic_count = sum(1 for l in lines if "buffer_atomic_pk_add_bf16" in l) + cvt_count = sum(1 for l in lines if "v_cvt_pk_bf16_f32" in l) + mfma_count = sum(1 for l in lines if "v_mfma_scale_f32_16x16x128_f8f6f4" in l) + total_lines = len(lines) + + print(f"\n=== Assembly Emission Report ===") + print(f" MLIR: /tmp/{test_id}_mlir.txt") + print(f" Assembly: /tmp/{test_id}_cpp.s") + print(f" Total assembly lines: {total_lines}") + print(f" buffer_atomic_pk_add_bf16: {atomic_count}") + print(f" v_cvt_pk_bf16_f32: {cvt_count}") + print(f" v_mfma_scale_f32_16x16x128_f8f6f4: {mfma_count}") + print(f"================================\n") + + +# ============================================================================= +# Test: Split-K MXFP4 GEMM with preshuffled scales +# ============================================================================= + + +@pytest.mark.run_e2e +@pytest.mark.parametrize( + "shape,num_splits", + [ + ((256, 256, 256), 2), + ((512, 512, 512), 4), + ], +) +def test_splitk_mxfp4_preshuffle_scales_cpp_backend( + shape, num_splits, compiler, backend, dump_asm +): + """End-to-end test for split-K MXFP4 GEMM with preshuffled E8M0 scales. + + Validates that the C++ WaveASM backend correctly handles + vector.maskedload operations generated by the e8m0_shuffle IndexMapping, + which reads scales from global memory with a non-trivial address + computation (bypassing LDS). + + Requires CDNA4 (gfx950+) for scaled MFMA and bf16 atomics. 
+ """ + if not is_cdna4(): + pytest.skip("MXFP4 with preshuffled scales requires gfx950+ (CDNA4)") + + skip_if_no_gpu() + skip_if_no_wave_lang() + + import torch + from torch.testing import assert_close + + import wave_lang.kernel.lang as tkl + from wave_lang.kernel.lang.global_symbols import ( + SHARED_ADDRESS_SPACE, + ) + from wave_lang.kernel.wave.compile import WaveCompileOptions + from wave_lang.kernel.wave.constraints import ScaledMMAType + from wave_lang.kernel.wave.templates.gemm import get_splitk_mxfp4_gemm_kernel + from wave_lang.kernel.wave.utils.mxfp_utils import ( + e8m0_shuffle, + generate_gemm_afp4wfp4_inputs, + torchScaledGemmMXFP4, + ) + from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + + splitk_gemm, hyperparams = get_splitk_mxfp4_gemm_kernel( + shape, + num_splits=num_splits, + mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4, + preshuffle_scales=True, + ) + + hyperparams[tkl.sym.ADDRESS_SPACE] = SHARED_ADDRESS_SPACE + hyperparams[tkl.sym.B_ADDRESS_SPACE] = SHARED_ADDRESS_SPACE + + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + backend="asm", + wave_runtime=True, + compile_to_mlir=False, + use_global_to_shared=True, + ) + options = set_default_run_config(options) + + kernel_info = capture_wave_kernel_info(options, splitk_gemm) + + m, n, k = shape + test_id = f"splitk_mxfp4_preshuffle_scales_{m}x{n}x{k}_s{num_splits}" + + cpp_result = compiler.compile_full( + kernel_info.mlir_text, kernel_info.workgroup_size + ) + + if dump_asm: + with open(f"/tmp/{test_id}_mlir.txt", "w") as f: + f.write(kernel_info.mlir_text) + if cpp_result.success: + with open(f"/tmp/{test_id}_cpp.s", "w") as f: + f.write(cpp_result.asm_text) + + if not cpp_result.success: + pytest.fail(f"C++ compilation failed: {cpp_result.error_message}") + + assert ( + "buffer_atomic_pk_add_bf16" in cpp_result.asm_text + ), "Expected buffer_atomic_pk_add_bf16 in generated assembly" + assert ( + "buffer_load_ubyte" in cpp_result.asm_text + ), "Expected buffer_load_ubyte for preshuffled scale reads" + + x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs( + shape, device=torch.device("cpu") + ) + torch_ref = torchScaledGemmMXFP4(x, w, x_scales, w_scales) + + x_scales_sh = e8m0_shuffle(x_scales) + w_scales_sh = e8m0_shuffle(w_scales) + + x_gpu = x.cuda() + w_t_gpu = w.T.contiguous().cuda() + x_scales_sh_gpu = x_scales_sh.cuda() + w_scales_sh_gpu = w_scales_sh.cuda() + c_gpu = torch.zeros(m, n, dtype=torch.bfloat16, device="cuda") + + binary_path = cpp_result.binary_path + kernel_name = cpp_result.get_kernel_name() or kernel_info.kernel_name + block = kernel_info.workgroup_size + lds_size = kernel_info.lds_size + grid = kernel_info.grid_size + + run_with_wave_runtime( + binary_path=binary_path, + inputs=[x_gpu, x_scales_sh_gpu, w_t_gpu, w_scales_sh_gpu], + outputs=[c_gpu], + grid=grid, + block=block, + shared_memory_bytes=lds_size, + func_name=kernel_name, + ) + + # Each split truncates f32→bf16 before the atomic add, so error + # grows roughly linearly with num_splits (~0.5 ULP per split). 
+ atol = max(1.0, num_splits * 0.5) + assert_close(c_gpu.cpu().to(torch.float32), torch_ref, rtol=5e-2, atol=atol) + + +# ============================================================================= +# Test: Split-K MXFP4 GEMM with f32 output (buffer_atomic_add_f32) +# ============================================================================= + + +@pytest.mark.run_e2e +@pytest.mark.parametrize( + "shape,num_splits", + [ + ((256, 256, 256), 2), + ], +) +def test_splitk_mxfp4_f32_atomic_cpp_backend( + shape, num_splits, compiler, backend, dump_asm +): + """End-to-end test for split-K MXFP4 GEMM with f32 atomic accumulation. + + Like test_splitk_mxfp4_bf16_atomic_cpp_backend but uses f32 output + (buffer_atomic_add_f32) instead of bf16 (buffer_atomic_pk_add_bf16). + This helps isolate whether issues are in the bf16 conversion/packing + path or in the split-K atomic accumulation logic itself. + + Requires CDNA4 (gfx950+) for scaled MFMA. + """ + if not is_cdna4(): + pytest.skip("Split-K MXFP4 requires gfx950+ (CDNA4)") + + skip_if_no_gpu() + skip_if_no_wave_lang() + + import math + + import sympy + import torch + from torch.testing import assert_close + + import wave_lang.kernel.lang as tkl + import wave_lang.kernel.wave as tkw + from wave_lang.kernel.lang.global_symbols import ( + GLOBAL_ADDRESS_SPACE, + SHARED_ADDRESS_SPACE, + WORKGROUP_2, + ) + from wave_lang.kernel.wave.compile import WaveCompileOptions + from wave_lang.kernel.wave.constraints import ScaledMMAType + from wave_lang.kernel.wave.utils.general_utils import get_default_scheduling_params + from wave_lang.kernel.wave.utils.mxfp_utils import ( + generate_gemm_afp4wfp4_inputs, + torchScaledGemmMXFP4, + ) + from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + + m, n, k = shape + block_shape = (128, 128, 128) + waves_per_block = (2, 2) + k_per_split = math.ceil(k / num_splits) + + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + S = tkl.sym.S + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + BLOCK_S = tkl.sym.BLOCK_S + K_SPLIT_OFF = tkl.sym.K_SPLIT_OFF + K_SPLIT_LEN = tkl.sym.K_SPLIT_LEN + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.WorkgroupConstraint(S, BLOCK_S, 2), + tkw.TilingConstraint( + K, + BLOCK_K, + iters=sympy.ceiling(K_SPLIT_LEN / BLOCK_K), + start=K_SPLIT_OFF, + ), + tkw.WaveConstraint(M, sympy.floor(BLOCK_M / waves_per_block[0])), + tkw.WaveConstraint(N, sympy.floor(BLOCK_N / waves_per_block[1])), + tkw.HardwareConstraint( + threads_per_wave=64, + mma_type=ScaledMMAType.F32_16x16x128_F8F6F4, + vector_shapes={S: 0}, + ), + ] + + @tkw.wave(constraints) + def splitk_mxfp4_gemm_f32( + a: tkl.Memory[M, K / 2, ADDRESS_SPACE, tkl.i8], + a_scale: tkl.Memory[M, K / 32, ADDRESS_SPACE, tkl.i8], + b: tkl.Memory[N, K / 2, ADDRESS_SPACE, tkl.i8], + b_scale: tkl.Memory[N, K / 32, ADDRESS_SPACE, tkl.i8], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a) + a_reg = tkw.bitcast(a_reg, tkl.f4e2m1fn) + a_scale_reg = tkw.read(a_scale) + a_scale_reg = tkw.bitcast(a_scale_reg, tkl.f8e8m0fnu) + b_reg = tkw.read(b) + b_reg = tkw.bitcast(b_reg, tkl.f4e2m1fn) + b_scale_reg = tkw.read(b_scale) + b_scale_reg = tkw.bitcast(b_scale_reg, tkl.f8e8m0fnu) + acc = tkw.scaled_mma(a_reg, a_scale_reg, b_reg, 
b_scale_reg, acc) + return acc + + tkw.atomic_add(repeat, c) + + hyperparams = { + ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, + BLOCK_M: block_shape[0], + BLOCK_N: block_shape[1], + BLOCK_K: block_shape[2], + BLOCK_S: 1, + M: m, + N: n, + K: k, + S: num_splits, + K_SPLIT_OFF: WORKGROUP_2 * k_per_split, + K_SPLIT_LEN: sympy.Min(K, (WORKGROUP_2 + 1) * k_per_split) - K_SPLIT_OFF, + } + for key, value in hyperparams.items(): + if isinstance(value, sympy.Expr): + hyperparams[key] = value.subs(hyperparams) + hyperparams.update(get_default_scheduling_params()) + + hyperparams[ADDRESS_SPACE] = SHARED_ADDRESS_SPACE + + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + backend="asm", + wave_runtime=True, + compile_to_mlir=False, + use_global_to_shared=True, + ) + options = set_default_run_config(options) + + kernel_info = capture_wave_kernel_info(options, splitk_mxfp4_gemm_f32) + + test_id = f"splitk_mxfp4_f32_{shape[0]}x{shape[1]}x{shape[2]}_s{num_splits}" + + cpp_result = compiler.compile_full( + kernel_info.mlir_text, kernel_info.workgroup_size + ) + + if dump_asm: + with open(f"/tmp/{test_id}_mlir.txt", "w") as f: + f.write(kernel_info.mlir_text) + if cpp_result.success: + with open(f"/tmp/{test_id}_cpp.s", "w") as f: + f.write(cpp_result.asm_text) + + if not cpp_result.success: + pytest.fail(f"C++ compilation failed: {cpp_result.error_message}") + + # Verify the assembly contains buffer_atomic_add_f32 (not bf16) + assert ( + "buffer_atomic_add_f32" in cpp_result.asm_text + ), "Expected buffer_atomic_add_f32 in generated assembly" + + x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs( + shape, device=torch.device("cpu") + ) + torch_ref = torchScaledGemmMXFP4(x, w, x_scales, w_scales) + + x_gpu = x.cuda() + w_t_gpu = w.T.contiguous().cuda() + x_scales_gpu = x_scales.cuda() + w_scales_gpu = w_scales.cuda() + c_gpu = torch.zeros(m, n, dtype=torch.float32, device="cuda") + + binary_path = cpp_result.binary_path + kernel_name = cpp_result.get_kernel_name() or kernel_info.kernel_name + block = kernel_info.workgroup_size + lds_size = kernel_info.lds_size + grid = kernel_info.grid_size + + run_with_wave_runtime( + binary_path=binary_path, + inputs=[x_gpu, x_scales_gpu, w_t_gpu, w_scales_gpu], + outputs=[c_gpu], + grid=grid, + block=block, + shared_memory_bytes=lds_size, + func_name=kernel_name, + ) + + assert_close(c_gpu.cpu(), torch_ref, rtol=5e-2, atol=5e-2) + + +# ============================================================================= +# Test: bf16 output GEMM without split-K +# ============================================================================= + + +@pytest.mark.parametrize( + "shape,block_k,config", + [ + ((64, 64, 64), 16, (16, 16, 16, 16)), + ((256, 256, 128), 64, (32, 32, 16, 16)), + ], +) +@pytest.mark.parametrize("use_global_to_shared", _global_to_shared_params()) +def test_bf16_gemm_cpp_backend( + shape, block_k, config, use_global_to_shared, compiler, backend, dump_asm +): + """End-to-end test for GEMM with bf16 output (no split-K) using C++ ASM backend. + + Computes in f32 via MFMA, casts to bf16, then writes to global memory. + Tests the arith.truncf (f32->bf16) + buffer_store path without atomics. + + This helps isolate whether bf16 conversion and stores work correctly + outside of the atomic accumulation context. 
+ """ + skip_if_no_gpu() + skip_if_no_wave_lang() + + import torch + from torch.testing import assert_close + + import wave_lang.kernel.lang as tkl + import wave_lang.kernel.wave as tkw + from wave_lang.kernel.lang.global_symbols import ( + GLOBAL_ADDRESS_SPACE, + SHARED_ADDRESS_SPACE, + ) + from wave_lang.kernel.wave.compile import WaveCompileOptions + from wave_lang.kernel.wave.utils.run_utils import set_default_run_config + from wave_lang.kernel.wave.utils.torch_utils import device_randn, device_zeros + + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + BLOCK_M_SYM = tkl.sym.BLOCK_M + BLOCK_N_SYM = tkl.sym.BLOCK_N + BLOCK_K_SYM = tkl.sym.BLOCK_K + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 + + block_m, block_n, WAVE_M, WAVE_N = config + wave_size = 64 + + assert block_m % WAVE_M == 0 + assert block_n % WAVE_N == 0 + + mma_type = tkw.MMAType.F32_16x16x16_F16 + + constraints = [ + tkw.WorkgroupConstraint(M, BLOCK_M_SYM, 0), + tkw.WorkgroupConstraint(N, BLOCK_N_SYM, 1), + tkw.TilingConstraint(K, BLOCK_K_SYM), + tkw.WaveConstraint(M, WAVE_M), + tkw.WaveConstraint(N, WAVE_N), + tkw.HardwareConstraint( + threads_per_wave=wave_size, + mma_type=mma_type, + ), + ] + + @tkw.wave(constraints) + def bf16_gemm_kernel( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.bf16], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a) + b_reg = tkw.read(b) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + repeat_bf16 = tkw.cast(repeat, tkl.bf16) + tkw.write(repeat_bf16, c) + + m, n, k = shape + a = device_randn((m, k), dtype=torch.float16) + b = device_randn((n, k), dtype=torch.float16) + c = device_zeros((m, n), dtype=torch.bfloat16) + + options = WaveCompileOptions( + subs={ + M: m, + N: n, + K: k, + BLOCK_M_SYM: block_m, + BLOCK_N_SYM: block_n, + BLOCK_K_SYM: block_k, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + }, + canonicalize=True, + backend="asm", + wave_runtime=True, + compile_to_mlir=False, + use_global_to_shared=use_global_to_shared, + ) + options = set_default_run_config(options) + + kernel_info = capture_wave_kernel_info(options, bf16_gemm_kernel) + + g2s_str = "g2s" if use_global_to_shared else "no_g2s" + test_id = f"bf16_gemm_{m}x{n}x{k}_bk{block_k}_{block_m}x{block_n}_{g2s_str}" + + cpp_result = None + if backend in ("cpp", "both"): + cpp_result = compiler.compile_full( + kernel_info.mlir_text, kernel_info.workgroup_size + ) + if not cpp_result.success: + pytest.fail(f"C++ compilation failed: {cpp_result.error_message}") + + if dump_asm: + with open(f"/tmp/{test_id}_mlir.txt", "w") as f: + f.write(kernel_info.mlir_text) + if cpp_result and cpp_result.asm_text: + with open(f"/tmp/{test_id}_cpp.s", "w") as f: + f.write(cpp_result.asm_text) + + if cpp_result is None: + pytest.fail("No backend compiled successfully") + + binary_path = cpp_result.binary_path + kernel_name = cpp_result.get_kernel_name() or kernel_info.kernel_name + block = kernel_info.workgroup_size + lds_size = kernel_info.lds_size + grid = kernel_info.grid_size + + run_with_wave_runtime( + binary_path=binary_path, + inputs=[a, b], + outputs=[c], + grid=grid, + block=block, + shared_memory_bytes=lds_size, + func_name=kernel_name, + ) + + # Validate: C = bf16(A @ B^T) + expected = torch.matmul(a.float(), b.float().T).to(torch.bfloat16) + 
assert_close(c, expected, atol=1e-2, rtol=1e-2) + + # ============================================================================= # Main # ============================================================================= diff --git a/wave_lang/kernel/wave/constraints.py b/wave_lang/kernel/wave/constraints.py index d49538b8a8..770dbc785d 100644 --- a/wave_lang/kernel/wave/constraints.py +++ b/wave_lang/kernel/wave/constraints.py @@ -9,7 +9,7 @@ from enum import Enum from typing import Callable, Optional -from sympy import Integer, Piecewise, ceiling, floor +from sympy import Integer, Min, Mul, Piecewise, ceiling, floor from .._support.dtype import DataType from .._support.indexing import IndexExpr, IndexSequence, IndexSymbol @@ -760,6 +760,45 @@ def get_index_bound(self, vector_shape: Optional[int]) -> Optional[IndexExpr]: return bound +def _expr_contains_min_with_bound(expr: IndexExpr, bound) -> bool: + """Return True when *expr* (recursively) contains ``Min(bound, ...)``.""" + if isinstance(expr, Min): + if any(a == bound for a in expr.args): + return True + for arg in getattr(expr, "args", ()): + if _expr_contains_min_with_bound(arg, bound): + return True + return False + + +def _work_may_exceed_dim(work_bound: IndexExpr, dim_bound: IndexExpr) -> bool: + """Conservatively decide whether *work_bound* can exceed *dim_bound*. + + Returns ``False`` (no overshoot) when we can prove that the tiled work + never exceeds the tensor dimension. In particular this handles the + split-K pattern where ``work_bound = tile * ceiling(Min(dim, f(wg)) / tile)`` + and ``dim`` is tile-aligned: ``ceiling(Min(dim, x) / tile) * tile <= dim``. + + Falls back to ``True`` (bounds needed) when the relationship cannot be + determined. + """ + if work_bound == dim_bound: + return False + if isinstance(work_bound, (int, Integer)) and isinstance(dim_bound, (int, Integer)): + return int(work_bound) > int(dim_bound) + if isinstance(dim_bound, (int, Integer)): + dim_int = int(dim_bound) + if _expr_contains_min_with_bound(work_bound, dim_int): + tile = None + if isinstance(work_bound, Mul) and len(work_bound.args) == 2: + for a in work_bound.args: + if isinstance(a, (int, Integer)): + tile = int(a) + if tile is not None and dim_int % tile == 0: + return False + return True + + @dataclass class TilingConstraint(DistributionConstraint): """ @@ -819,7 +858,7 @@ def dim_bound(self) -> IndexExpr: def get_index_bound(self, vector_shape: Optional[int]) -> Optional[IndexExpr]: bound = None - if subs_idxc(self.work_bound) != subs_idxc(self.dim_bound): + if _work_may_exceed_dim(subs_idxc(self.work_bound), subs_idxc(self.dim_bound)): bound = self.dim_bound if ( diff --git a/wave_lang/kernel/wave/templates/gemm.py b/wave_lang/kernel/wave/templates/gemm.py index 181bb51f06..2322b6d799 100644 --- a/wave_lang/kernel/wave/templates/gemm.py +++ b/wave_lang/kernel/wave/templates/gemm.py @@ -243,6 +243,207 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: return splitk_gemm, hyperparams +def get_splitk_mxfp4_gemm_kernel( + shape: tuple[int, int, int], + num_splits: int, + mfma_variant: "ScaledMMAType", + threads_per_wave: int = 64, + block_shape: Optional[tuple[int, int, int]] = None, + waves_per_block: Optional[tuple[int, int]] = None, + preshuffle_B: bool = False, + preshuffle_scales: bool = False, +): + """ + Creates a split-K MXFP4 GEMM kernel that parallelizes the K dimension + across multiple workgroups using atomic_add accumulation. 
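+
+    Each workgroup along grid dimension 2 owns one K-split: split s covers K
+    offsets [s * ceil(K / num_splits), min(K, (s + 1) * ceil(K / num_splits)))
+    and atomically adds its partial f32-accumulated tile, cast to bf16, into C.
+    Worked example (illustrative): K = 1024 with num_splits = 4 gives per-split
+    ranges [0, 256), [256, 512), [512, 768), [768, 1024).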
+ + Inputs are packed MXFP4 data (2 values per byte stored as i8) with + E8M0 scale tensors (1 scale per 32 data elements). The kernel uses + scaled_mma for the compute and atomic_add for the final accumulation. + + The caller must zero-initialize the output tensor C before launch. + + Args: + shape: (M, N, K) problem dimensions (K is the logical element count). + num_splits: Number of splits along the K dimension. + mfma_variant: Scaled MMA instruction type (e.g. ScaledMMAType.F32_16x16x128_F8F6F4). + threads_per_wave: Threads per wave (64 for CDNA). + block_shape: (BLOCK_M, BLOCK_N, BLOCK_K) tile sizes. + waves_per_block: (waves_M, waves_N) waves per workgroup. + preshuffle_B: If True, read B with the aiter preshuffled IndexMapping. + The caller must pre-shuffle b with preshuffle_b_aiter(). + preshuffle_scales: If True, read a_scale and b_scale from GLOBAL memory + using the e8m0_shuffle IndexMapping. The caller must pre-shuffle + scales with e8m0_shuffle(). + """ + + if not block_shape: + block_shape = (128, 128, 128) + + if not waves_per_block: + waves_per_block = (2, 2) + + m, n, k = shape + k_per_split = math.ceil(k / num_splits) + if k_per_split < block_shape[2]: + raise ValueError( + f"K per split ({k_per_split}) is less than BLOCK_K ({block_shape[2]}). " + f"Reduce num_splits or BLOCK_K so that each split has at least BLOCK_K elements." + ) + + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + S = tkl.sym.S + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + BLOCK_S = tkl.sym.BLOCK_S + K_SPLIT_OFF = tkl.sym.K_SPLIT_OFF + K_SPLIT_LEN = tkl.sym.K_SPLIT_LEN + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + B_ADDRESS_SPACE = tkl.sym.B_ADDRESS_SPACE + K_PACKED = tkl.sym.K_PACKED + K_SCALE_SHUFFLED = tkl.sym.K_SCALE_SHUFFLED + + k_packed_val = k // 2 + k_scale_shuffled_val = (((k // 32) + 7) // 8) * 8 + + b_preshuffle_mapping = None + if preshuffle_B: + n_it = tkw.IndexMapping.iterator(0) + k_it = tkw.IndexMapping.iterator(1) + within_nblk = ( + (k_it // 32) * 512 + ((k_it // 16) % 2) * 256 + (n_it % 16) * 16 + k_it % 16 + ) + b_preshuffle_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={ + N: (n_it // 16) * 16 + within_nblk // K_PACKED, + K: within_nblk % K_PACKED, + }, + outputs={N: n_it, K: k_it}, + ) + + a_scale_mapping = None + b_scale_mapping = None + if preshuffle_scales: + i = tkw.IndexMapping.iterator(0) + j = tkw.IndexMapping.iterator(1) + _flat_a = ( + (j // 32) * ((k_scale_shuffled_val // 8) * 256) + + (i // 8) * 256 + + ((i % 8) % 4) * 64 + + ((j % 32) % 16) * 4 + + (((i % 8) // 4) * 2) + + ((j % 32) // 16) + ) + a_scale_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={ + M: _flat_a // k_scale_shuffled_val, + K: _flat_a % k_scale_shuffled_val, + }, + outputs={K: i, M: j}, + ) + kk = tkw.IndexMapping.iterator(0) + n_s = tkw.IndexMapping.iterator(1) + _flat_b = ( + (n_s // 32) * ((k_scale_shuffled_val // 8) * 256) + + (kk // 8) * 256 + + ((kk % 8) % 4) * 64 + + ((n_s % 32) % 16) * 4 + + (((kk % 8) // 4) * 2) + + ((n_s % 32) // 16) + ) + b_scale_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={ + N: _flat_b // k_scale_shuffled_val, + K: _flat_b % k_scale_shuffled_val, + }, + outputs={K: kk, N: n_s}, + ) + + a_scale_space = GLOBAL_ADDRESS_SPACE if preshuffle_scales else ADDRESS_SPACE + b_scale_space = GLOBAL_ADDRESS_SPACE if preshuffle_scales else ADDRESS_SPACE + + constraints: list[tkw.Constraint] = [ + tkw.WorkgroupConstraint(M, BLOCK_M, 0), + tkw.WorkgroupConstraint(N, BLOCK_N, 1), + tkw.WorkgroupConstraint(S, BLOCK_S, 2), + 
tkw.TilingConstraint( + K, + BLOCK_K, + iters=sympy.ceiling(K_SPLIT_LEN / BLOCK_K), + start=K_SPLIT_OFF, + ), + tkw.WaveConstraint(M, sympy.floor(BLOCK_M / waves_per_block[0])), + tkw.WaveConstraint(N, sympy.floor(BLOCK_N / waves_per_block[1])), + tkw.HardwareConstraint( + threads_per_wave=threads_per_wave, + mma_type=mfma_variant, + vector_shapes={S: 0}, + ), + ] + + @tkw.wave(constraints) + def splitk_mxfp4_gemm( + a: tkl.Memory[M, K / 2, ADDRESS_SPACE, tkl.i8], + a_scale: tkl.Memory[M, K / 32, a_scale_space, tkl.i8], + b: tkl.Memory[N, K / 2, B_ADDRESS_SPACE, tkl.i8], + b_scale: tkl.Memory[N, K / 32, b_scale_space, tkl.i8], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.bf16], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a) + a_reg = tkw.bitcast(a_reg, tkl.f4e2m1fn) + a_scale_reg = tkw.read(a_scale, mapping=a_scale_mapping) + a_scale_reg = tkw.bitcast(a_scale_reg, tkl.f8e8m0fnu) + b_reg = tkw.read(b, mapping=b_preshuffle_mapping) + b_reg = tkw.bitcast(b_reg, tkl.f4e2m1fn) + b_scale_reg = tkw.read(b_scale, mapping=b_scale_mapping) + b_scale_reg = tkw.bitcast(b_scale_reg, tkl.f8e8m0fnu) + acc = tkw.scaled_mma(a_reg, a_scale_reg, b_reg, b_scale_reg, acc) + return acc + + repeat_bf16 = tkw.cast(repeat, tkl.bf16) + tkw.atomic_add(repeat_bf16, c) + + hyperparams = { + # Use global address space to work around a compiler bug where + # shared-memory write indices are incorrect for fractional K + # dimensions (K/2, K/32) when TilingConstraint has a non-zero start. + ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, + B_ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, + BLOCK_M: block_shape[0], + BLOCK_N: block_shape[1], + BLOCK_K: block_shape[2], + BLOCK_S: 1, + M: m, + N: n, + K: k, + S: num_splits, + K_SPLIT_OFF: WORKGROUP_2 * k_per_split, + K_SPLIT_LEN: sympy.Min(K, (WORKGROUP_2 + 1) * k_per_split) - K_SPLIT_OFF, + } + for key, value in hyperparams.items(): + if isinstance(value, sympy.Expr): + hyperparams[key] = value.subs(hyperparams) + + if preshuffle_B: + hyperparams[K_PACKED] = k_packed_val + if preshuffle_scales: + hyperparams[K_SCALE_SHUFFLED] = k_scale_shuffled_val + + hyperparams.update(get_default_scheduling_params()) + + return splitk_mxfp4_gemm, hyperparams + + def get_gemm_kernel_transpose_a_b( shape: tuple[int, int, int], dynamic_dims: bool | tuple[bool, bool, bool], diff --git a/wave_lang/kernel/wave/utils/general_utils.py b/wave_lang/kernel/wave/utils/general_utils.py index eb7107c051..e0af2d29d1 100644 --- a/wave_lang/kernel/wave/utils/general_utils.py +++ b/wave_lang/kernel/wave/utils/general_utils.py @@ -173,6 +173,13 @@ def remove_global_indexing( tiling_constraints = [c for c in constraints if isinstance(c, TilingConstraint)] workgroup_ids = [WORKGROUP_0, WORKGROUP_1, WORKGROUP_2] subs = {w: 0 for w in workgroup_ids} + # Zero out tiling constraint starts (e.g. K_SPLIT_OFF) alongside + # workgroup IDs. These starts are global offsets that must be removed + # *before* any dimension scaling (K → K/32 for MXFP4 scales) so that + # the subtraction later doesn't mix scaled and unscaled units. 
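+    # Illustrative instance (hypothetical index, not from a specific kernel):
+    # an MXFP4 a_scale index may carry the start already in scaled units,
+    # e.g. floor(K_SPLIT_OFF / 32) + <thread terms>. Substituting
+    # K_SPLIT_OFF -> 0 removes it in those same units, whereas subtracting the
+    # raw K_SPLIT_OFF (K-element units) afterwards would remove 32x too much.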
+ for tc in tiling_constraints: + if tc.start != sympy.Integer(0) and isinstance(tc.start, sympy.Symbol): + subs[tc.start] = 0 new_index = {key: safe_subs(index[key], subs) for key in index} for key in new_index: @@ -180,13 +187,14 @@ def remove_global_indexing( new_dim = new_index[key] if new_dim.has(constraint.induction_var): new_dim = new_dim.subs({constraint.induction_var: 0}) + local_start = safe_subs(constraint.start, subs) if isinstance(new_dim, IndexSequence): - new_dim.start = new_dim.start - constraint.start + new_dim.start = new_dim.start - local_start else: assert isinstance( new_dim, sympy.Basic ), f"new_dim is not a sympy expression: {new_dim}" - new_dim = new_dim - constraint.start + new_dim = new_dim - local_start new_index[key] = new_dim return new_index From ca5f8e8e0c1c38bf640820928a632a41e0088925 Mon Sep 17 00:00:00 2001 From: William G Hatch Date: Tue, 24 Feb 2026 15:22:11 -0700 Subject: [PATCH 2/3] bunch of fixes Signed-off-by: William G Hatch --- tests/kernel/wave_gemm_test.py | 8 ++- .../lib/Transforms/TranslateFromMLIR.cpp | 29 ++++++++-- .../Transforms/handlers/VectorHandlers.cpp | 20 +++++++ .../wave_asm/test/e2e/test_asm_backend_e2e.py | 25 +++++--- wave_lang/kernel/wave/constraints.py | 57 +++++++++++-------- wave_lang/kernel/wave/utils/general_utils.py | 7 +-- 6 files changed, 102 insertions(+), 44 deletions(-) diff --git a/tests/kernel/wave_gemm_test.py b/tests/kernel/wave_gemm_test.py index 66a5af6d22..dac0da59c0 100644 --- a/tests/kernel/wave_gemm_test.py +++ b/tests/kernel/wave_gemm_test.py @@ -3667,4 +3667,10 @@ def testSplitKMxfp4Gemm( c_gpu = device_zeros(m, n, dtype=torch.bfloat16) splitk_gemm(x_gpu, x_scales_gpu, w_t_gpu, w_scales_gpu, c_gpu) - assert_close(c_gpu.cpu().to(torch.float32), torch_ref, rtol=5e-2, atol=5e-2) + # Each split accumulates a partial f32 sum, casts it to bf16, then atomically + # adds it into the bf16 output. At partial-sum magnitudes of ~250-430 the bf16 + # ULP is 2-4, so with 2-4 splits the worst-case absolute rounding error is + # num_splits * max_partial_magnitude * bf16_eps ≈ 4 * 430 * 2^-7 ≈ 13. + # Observed worst-case across all parametrized shapes is ~3, so atol=4.0 is + # generous but justified by the bf16 accumulation arithmetic. + assert_close(c_gpu.cpu().to(torch.float32), torch_ref, rtol=1e-1, atol=4.0) diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/TranslateFromMLIR.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/TranslateFromMLIR.cpp index 345f6b8f99..39e9a02b67 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/TranslateFromMLIR.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/TranslateFromMLIR.cpp @@ -404,12 +404,31 @@ VOffsetResult computeVOffsetFromIndices(MemRefType memrefType, Value lookupSRD(Value memref, TranslationContext &ctx, Location loc) { auto &builder = ctx.getBuilder(); - if (auto srdIdx = ctx.getSRDIndex(memref)) { - auto sregType = ctx.createSRegType(4, 4); - return PrecoloredSRegOp::create(builder, loc, sregType, *srdIdx, 4); + + // Walk up through memref cast/reinterpret operations to find the + // underlying binding that has an SRD registered. This handles the + // pattern: binding.subspan -> reinterpret_cast -> vector.load + // where the SRD is registered on the binding result but the load + // uses the reinterpret_cast result. 
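+  // Schematic example of the chain being walked (illustrative; not exact
+  // dialect syntax):
+  //   %base = ...binding.subspan...             <- SRD registered on %base
+  //   %view = memref.reinterpret_cast %base ...
+  //   vector.load %view[...]                    <- lookup starts from %view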
+ Value current = memref; + while (current) { + if (auto srdIdx = ctx.getSRDIndex(current)) { + auto sregType = ctx.createSRegType(4, 4); + return PrecoloredSRegOp::create(builder, loc, sregType, *srdIdx, 4); + } + if (auto mapped = ctx.getMapper().getMapped(current)) + return *mapped; + + if (auto castOp = current.getDefiningOp()) + current = castOp.getSource(); + else if (auto castOp = current.getDefiningOp()) + current = castOp.getSource(); + else if (auto subviewOp = current.getDefiningOp()) + current = subviewOp.getSource(); + else + break; } - if (auto mapped = ctx.getMapper().getMapped(memref)) - return *mapped; + auto sregType = ctx.createSRegType(4, 4); return PrecoloredSRegOp::create(builder, loc, sregType, 8, 4); } diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/VectorHandlers.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/VectorHandlers.cpp index 214f783c2f..384b1b3a0c 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/VectorHandlers.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/VectorHandlers.cpp @@ -215,6 +215,26 @@ LogicalResult handleVectorExtractStridedSlice(Operation *op, size = cast(sizes[0]).getInt(); } + // Check for sub-dword element extraction. When elements are smaller than + // 32 bits (e.g. vector<2xi8>), multiple elements are packed in a single + // VGPR. A register-level offset would read the wrong VGPR; instead we + // emit a right-shift to bring the desired byte(s) to bit-position 0. + auto srcVecType = extractOp.getSourceVectorType(); + int64_t elemBits = srcVecType.getElementType().getIntOrFloatBitWidth(); + if (elemBits < 32 && offset != 0) { + int64_t bitShift = offset * elemBits; + auto shiftImm = + ConstantOp::create(builder, loc, ctx.createImmType(bitShift), bitShift); + auto shifted = V_LSHRREV_B32::create(builder, loc, ctx.createVRegType(), + shiftImm, *src); + ctx.getMapper().mapValue(extractOp.getResult(), shifted); + return success(); + } + if (elemBits < 32 && offset == 0) { + ctx.getMapper().mapValue(extractOp.getResult(), *src); + return success(); + } + // Get the source register type to find the base physical register Type srcType = src->getType(); diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py b/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py index 1a776b0b74..79dc19ca03 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py +++ b/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py @@ -2033,8 +2033,8 @@ def test_splitk_mxfp4_bf16_asm_emission(shape, num_splits, compiler): @pytest.mark.parametrize( "shape,num_splits", [ - ((256, 256, 256), 2), - ((512, 512, 512), 4), + ((512, 512, 2048), 2), + ((512, 512, 2048), 4), ], ) def test_splitk_mxfp4_preshuffle_scales_cpp_backend( @@ -2043,9 +2043,15 @@ def test_splitk_mxfp4_preshuffle_scales_cpp_backend( """End-to-end test for split-K MXFP4 GEMM with preshuffled E8M0 scales. Validates that the C++ WaveASM backend correctly handles - vector.maskedload operations generated by the e8m0_shuffle IndexMapping, - which reads scales from global memory with a non-trivial address - computation (bypassing LDS). + vector<4xi8> dword loads (buffer_load_dword) generated by the + e8m0_shuffle IndexMapping after merge_contiguous_reads combines scalar + scale reads into 4x vector loads, reading scales from global memory with + a non-trivial address computation (bypassing LDS). 
+ + BLOCK_K=256 is required so that each thread reads 8 scale elements + (256/32), which merge into 2 groups of 4 -> vector<4xi8>. The + e8m0_shuffle layout also requires K/32 >= 64 (K >= 2048) for the + groups to land contiguously in the row-major scale tensor. Requires CDNA4 (gfx950+) for scaled MFMA and bf16 atomics. """ @@ -2077,6 +2083,7 @@ def test_splitk_mxfp4_preshuffle_scales_cpp_backend( num_splits=num_splits, mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4, preshuffle_scales=True, + block_shape=(128, 128, 256), ) hyperparams[tkl.sym.ADDRESS_SPACE] = SHARED_ADDRESS_SPACE @@ -2115,8 +2122,8 @@ def test_splitk_mxfp4_preshuffle_scales_cpp_backend( "buffer_atomic_pk_add_bf16" in cpp_result.asm_text ), "Expected buffer_atomic_pk_add_bf16 in generated assembly" assert ( - "buffer_load_ubyte" in cpp_result.asm_text - ), "Expected buffer_load_ubyte for preshuffled scale reads" + "buffer_load_dword" in cpp_result.asm_text + ), "Expected buffer_load_dword (vector<4xi8>) for preshuffled scale reads" x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs( shape, device=torch.device("cpu") @@ -2149,8 +2156,8 @@ def test_splitk_mxfp4_preshuffle_scales_cpp_backend( ) # Each split truncates f32→bf16 before the atomic add, so error - # grows roughly linearly with num_splits (~0.5 ULP per split). - atol = max(1.0, num_splits * 0.5) + # grows roughly linearly with num_splits. + atol = max(2.0, num_splits * 1.0) assert_close(c_gpu.cpu().to(torch.float32), torch_ref, rtol=5e-2, atol=atol) diff --git a/wave_lang/kernel/wave/constraints.py b/wave_lang/kernel/wave/constraints.py index 770dbc785d..9786f124e3 100644 --- a/wave_lang/kernel/wave/constraints.py +++ b/wave_lang/kernel/wave/constraints.py @@ -760,45 +760,52 @@ def get_index_bound(self, vector_shape: Optional[int]) -> Optional[IndexExpr]: return bound -def _expr_contains_min_with_bound(expr: IndexExpr, bound) -> bool: - """Return True when *expr* (recursively) contains ``Min(bound, ...)``.""" - if isinstance(expr, Min): - if any(a == bound for a in expr.args): - return True - for arg in getattr(expr, "args", ()): - if _expr_contains_min_with_bound(arg, bound): - return True - return False - - def _work_may_exceed_dim(work_bound: IndexExpr, dim_bound: IndexExpr) -> bool: """Conservatively decide whether *work_bound* can exceed *dim_bound*. - Returns ``False`` (no overshoot) when we can prove that the tiled work - never exceeds the tensor dimension. In particular this handles the - split-K pattern where ``work_bound = tile * ceiling(Min(dim, f(wg)) / tile)`` - and ``dim`` is tile-aligned: ``ceiling(Min(dim, x) / tile) * tile <= dim``. - - Falls back to ``True`` (bounds needed) when the relationship cannot be - determined. + Returns ``False`` when we can prove the tiled work never overshoots the + tensor dimension; ``True`` (bounds check needed) otherwise. """ if work_bound == dim_bound: return False + if isinstance(work_bound, (int, Integer)) and isinstance(dim_bound, (int, Integer)): return int(work_bound) > int(dim_bound) + + # The Min caps the numerator at dim, so + # Min(dim, x) <= dim + # ceiling(Min(dim, x) / tile) <= ceiling(dim / tile) = dim / tile + # tile * ceiling(Min(dim, x) / tile) <= dim + # as long as dim is evenly divisible by tile. 
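+    # Concrete instance (illustrative): dim = 1024, tile = 256, x = 512:
+    #   Min(1024, 512) = 512; ceiling(512 / 256) = 2; 256 * 2 = 512 <= 1024,
+    #   so the tiled work cannot overshoot and no bound is emitted.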
if isinstance(dim_bound, (int, Integer)): dim_int = int(dim_bound) - if _expr_contains_min_with_bound(work_bound, dim_int): - tile = None - if isinstance(work_bound, Mul) and len(work_bound.args) == 2: - for a in work_bound.args: - if isinstance(a, (int, Integer)): - tile = int(a) - if tile is not None and dim_int % tile == 0: + tile, ceil_expr = _extract_tile_and_ceiling(work_bound) + if tile is not None and dim_int % tile == 0 and ceil_expr is not None: + numerator = (ceil_expr.args[0] * tile).simplify() + if isinstance(numerator, Min) and any(a == dim_int for a in numerator.args): return False + + # Cannot prove safety -- assume bounds check is needed. return True +def _extract_tile_and_ceiling( + expr: IndexExpr, +) -> tuple[int | None, IndexExpr | None]: + """Extract ``(tile, ceiling_expr)`` from ``tile * ceiling(...)``. + + Returns ``(None, None)`` when the expression does not match. + """ + if not isinstance(expr, Mul) or len(expr.args) != 2: + return None, None + a, b = expr.args + if isinstance(a, (int, Integer)) and isinstance(b, ceiling): + return int(a), b + if isinstance(b, (int, Integer)) and isinstance(a, ceiling): + return int(b), a + return None, None + + @dataclass class TilingConstraint(DistributionConstraint): """ diff --git a/wave_lang/kernel/wave/utils/general_utils.py b/wave_lang/kernel/wave/utils/general_utils.py index e0af2d29d1..cf547536b1 100644 --- a/wave_lang/kernel/wave/utils/general_utils.py +++ b/wave_lang/kernel/wave/utils/general_utils.py @@ -173,10 +173,9 @@ def remove_global_indexing( tiling_constraints = [c for c in constraints if isinstance(c, TilingConstraint)] workgroup_ids = [WORKGROUP_0, WORKGROUP_1, WORKGROUP_2] subs = {w: 0 for w in workgroup_ids} - # Zero out tiling constraint starts (e.g. K_SPLIT_OFF) alongside - # workgroup IDs. These starts are global offsets that must be removed - # *before* any dimension scaling (K → K/32 for MXFP4 scales) so that - # the subtraction later doesn't mix scaled and unscaled units. + # Zero out tiling constraint starts alongside workgroup IDs. A non-zero + # start is a global base offset, just like a workgroup-ID term, and must be + # removed to produce a shared-memory-local index. for tc in tiling_constraints: if tc.start != sympy.Integer(0) and isinstance(tc.start, sympy.Symbol): subs[tc.start] = 0 From 74c2c142bf92d7bf7d94b53b11ab555711c6f48b Mon Sep 17 00:00:00 2001 From: William G Hatch Date: Thu, 26 Feb 2026 14:48:35 -0700 Subject: [PATCH 3/3] WIP - fix tests --- .../BufferLoadStrengthReduction.cpp | 31 +++-- .../lib/Transforms/LinearScanPass.cpp | 43 ++++--- .../Transforms/handlers/AMDGPUHandlers.cpp | 15 ++- .../Transforms/handlers/VectorHandlers.cpp | 118 +++++++++++++++--- .../wave_asm/test/e2e/test_asm_backend_e2e.py | 1 + 5 files changed, 158 insertions(+), 50 deletions(-) diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/BufferLoadStrengthReduction.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/BufferLoadStrengthReduction.cpp index 8094d05c15..8181bd35e3 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/BufferLoadStrengthReduction.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/BufferLoadStrengthReduction.cpp @@ -407,20 +407,27 @@ static void applyStrengthReduction(LoopOp loopOp) { if (candidates.empty()) return; - // Group by (SRD, stride). Loads sharing the same SRD and same constant - // stride share one soffset iter_arg; different strides get separate ones. + // Group by (SRD, stride, original soffset). 
Loads sharing the same SRD, + // same constant stride, AND same original soffset share one soffset + // iter_arg; different strides or soffsets get separate groups. + // The original soffset is preserved as the initial value of the new + // soffset iter_arg, which is critical for split-K kernels where the + // scale buffer_load has a non-zero soffset (e.g., block_z * 4). struct SRDGroup { Value srd; int64_t stride; Value strideSGPR; + Value originalSoffset; }; SmallVector groups; SmallVector candidateGroupIdx; for (auto [i, info] : llvm::enumerate(candidates)) { + Value origSoff = info.loadOp->getOperand(2); std::optional matchIdx; for (auto [g, group] : llvm::enumerate(groups)) { - if (group.srd == info.srd && group.stride == candidateStrides[i]) { + if (group.srd == info.srd && group.stride == candidateStrides[i] && + group.originalSoffset == origSoff) { matchIdx = g; break; } @@ -433,7 +440,8 @@ static void applyStrengthReduction(LoopOp loopOp) { ConstantOp::create(builder, loc, strideImm, candidateStrides[i]); Value strideSGPR = S_MOV_B32::create(builder, loc, sregType, strideConst); candidateGroupIdx.push_back(groups.size()); - groups.push_back({info.srd, candidateStrides[i], strideSGPR}); + groups.push_back( + {info.srd, candidateStrides[i], strideSGPR, origSoff}); } } @@ -445,13 +453,18 @@ static void applyStrengthReduction(LoopOp loopOp) { initialVoffsets.push_back(voff); } - // Build expanded init args: old args + soffset per SRD group (starts at 0). + // Build expanded init args: old args + soffset per SRD group. + // Each group's initial soffset is the original soffset from its buffer_load. + // This preserves non-zero soffsets (e.g., block_z * scale_stride in split-K). SmallVector expandedInit(initArgs.begin(), initArgs.end()); unsigned soffsetArgBase = expandedInit.size(); - auto zeroImm = builder.getType(0); - auto zeroConst = ConstantOp::create(builder, loc, zeroImm, 0); - auto zeroSoff = S_MOV_B32::create(builder, loc, sregType, zeroConst); - expandedInit.append(groups.size(), zeroSoff); + for (auto &group : groups) { + Value initSoff = group.originalSoffset; + if (isa(initSoff.getType())) { + initSoff = S_MOV_B32::create(builder, loc, sregType, initSoff); + } + expandedInit.push_back(initSoff); + } // Build new loop. auto newLoop = LoopOp::create(builder, loc, expandedInit); diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/LinearScanPass.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/LinearScanPass.cpp index 3d3d2e86e9..6cd786c310 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/LinearScanPass.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/LinearScanPass.cpp @@ -79,38 +79,36 @@ struct LinearScanPass int64_t maxSGPRs = 104; int64_t maxAGPRs = 256; - /// Create a fresh zero-initialized copy of a duplicate init arg to ensure - /// unique physical registers. This is used when CSE merges identical - /// zero-initialized accumulators, causing multiple loop block args to be - /// tied to the same init value. Each block arg needs its own physical - /// register, so we create a new v_mov_b32/s_mov_b32 from zero. + /// Create a fresh copy of a duplicate init arg to ensure unique physical + /// registers. This is used when CSE or other passes cause multiple loop + /// block args to be tied to the same init value. Each block arg needs its + /// own physical register, so we create a new mov that copies the value. /// - /// PRECONDITION: This should only be called for zero-initialized init args - /// (e.g., v_mov_b32 %vreg, 0). 
Calling it for non-zero init args will - /// produce incorrect zero values silently. - Value createZeroInitCopy(LoopOp loopOp, Value initArg) { + /// For VGPR/AGPR: always zero-initialize. Duplicate VGPR/AGPR init args + /// are only produced by CSE merging zero-initialized MFMA accumulators; + /// v_mov_b32 can't copy a multi-register source in a single instruction. + /// + /// For SGPR: copy the actual value. Duplicate SGPR init args can carry + /// non-zero values (e.g., soffsets from BufferLoadStrengthReduction). + Value createInitArgCopy(LoopOp loopOp, Value initArg) { OpBuilder copyBuilder(loopOp); auto loc = loopOp.getLoc(); - // Create a zero immediate. We always use 0 because this function is - // only called for duplicate init args produced by CSE merging identical - // zero-initialized values (e.g., v_mov_b32 vN, 0). - auto immType = ImmType::get(loopOp->getContext(), 0); - Value zeroImm = ConstantOp::create(copyBuilder, loc, immType, 0); - if (isAGPRType(initArg.getType())) { - // AGPR zero-init: V_MOV_B32 with ARegType destination. - // The assembly emitter will produce v_accvgpr_write_b32 aN, 0. auto aregType = cast(initArg.getType()); + auto immType = ImmType::get(loopOp->getContext(), 0); + Value zeroImm = ConstantOp::create(copyBuilder, loc, immType, 0); return V_MOV_B32::create(copyBuilder, loc, aregType, zeroImm); } if (isVGPRType(initArg.getType())) { auto vregType = cast(initArg.getType()); + auto immType = ImmType::get(loopOp->getContext(), 0); + Value zeroImm = ConstantOp::create(copyBuilder, loc, immType, 0); return V_MOV_B32::create(copyBuilder, loc, vregType, zeroImm); } if (isSGPRType(initArg.getType())) { auto sregType = cast(initArg.getType()); - return S_MOV_B32::create(copyBuilder, loc, sregType, zeroImm); + return S_MOV_B32::create(copyBuilder, loc, sregType, initArg); } return nullptr; } @@ -194,9 +192,10 @@ struct LinearScanPass if (collectFailed) return failure(); - // Handle duplicate init args: if CSE merged identical zero-initialized - // accumulators, multiple block args may be tied to the same init value. - // Each block arg needs its own physical register, so insert copies. + // Handle duplicate init args: if CSE or other passes cause multiple block + // args to be tied to the same init value, each block arg still needs its + // own physical register. Insert a copy so the allocator sees distinct SSA + // values and doesn't try to coalesce both block args to the same phys reg. // This must run before liveness analysis since it modifies the IR. 
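# --- Editor's illustration (not part of the patch) -------------------------
# A toy model of why the copy is inserted: an allocator that keys block-arg
# assignments on the SSA init value hands two block args tied to the same
# value a single physical register; copying one of them restores a
# one-to-one mapping. Names below are invented for the example.
def assign_phys_regs(init_values):
    """Assign one physical register per loop block arg, keyed by init value."""
    assignment = {}
    regs = []
    for value in init_values:
        if value not in assignment:
            assignment[value] = f"s{len(assignment)}"
        regs.append(assignment[value])
    return regs

# CSE merged two zero-initialized accumulators into one SSA value:
print(assign_phys_regs(["zero", "zero"]))          # ['s0', 's0'] -- aliased
# After createInitArgCopy gives the duplicate its own SSA value:
print(assign_phys_regs(["zero", "copy_of_zero"]))  # ['s0', 's1']
# ---------------------------------------------------------------------------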
program.walk([&](LoopOp loopOp) { Block &bodyBlock = loopOp.getBodyBlock(); @@ -206,7 +205,7 @@ struct LinearScanPass if (i < loopOp.getInitArgs().size()) { Value initArg = loopOp.getInitArgs()[i]; if (usedInitArgs.contains(initArg)) { - Value copy = createZeroInitCopy(loopOp, initArg); + Value copy = createInitArgCopy(loopOp, initArg); if (copy) { loopOp.getInitArgsMutable()[i].set(copy); } diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/AMDGPUHandlers.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/AMDGPUHandlers.cpp index 1efe853f40..306980600e 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/AMDGPUHandlers.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/AMDGPUHandlers.cpp @@ -412,8 +412,13 @@ LogicalResult handleFatRawBufferCast(Operation *op, TranslationContext &ctx) { std::to_string(srcSrdBase + 1) + ", 0xffff"; RawOp::create(builder, loc, and1); + // SRD word 1 upper bits: cache_swizzle_enable (bit 30) | stride (bits 29:16) + uint32_t srdWord1Upper = + (1u << 30) | (static_cast(swizzleStride & 0x3FFF) << 16); + char hexBuf[16]; + snprintf(hexBuf, sizeof(hexBuf), "0x%X", srdWord1Upper); std::string or1 = "s_or_b32 s" + std::to_string(newSrdBase + 1) + ", s" + - std::to_string(newSrdBase + 1) + ", 0x40400000"; + std::to_string(newSrdBase + 1) + ", " + hexBuf; RawOp::create(builder, loc, or1); std::string mov2 = @@ -1088,7 +1093,13 @@ LogicalResult handleMemRefAtomicRMW(Operation *op, TranslationContext &ctx) { BUFFER_ATOMIC_PK_ADD_BF16::create(builder, loc, packed, srd, alignedVoffset, alignedInstOffset); } else if (elementType.isF32()) { - BUFFER_ATOMIC_ADD_F32::create(builder, loc, *valueMapped, srd, voffset, + Value f32Value = *valueMapped; + if (isAGPRType(f32Value.getType())) { + auto vregType = ctx.createVRegType(); + f32Value = + V_ACCVGPR_READ_B32::create(builder, loc, vregType, f32Value); + } + BUFFER_ATOMIC_ADD_F32::create(builder, loc, f32Value, srd, voffset, instOffset); } else { return op->emitError("unsupported element type for atomic_rmw: ") diff --git a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/VectorHandlers.cpp b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/VectorHandlers.cpp index 384b1b3a0c..70970d402d 100644 --- a/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/VectorHandlers.cpp +++ b/wave_lang/kernel/wave/asm/wave_asm/lib/Transforms/handlers/VectorHandlers.cpp @@ -28,6 +28,7 @@ #include "waveasm/Dialect/WaveASMOps.h" #include "waveasm/Dialect/WaveASMTypes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "llvm/Support/Debug.h" @@ -37,6 +38,26 @@ using namespace mlir; namespace waveasm { +/// Return the effective element bit width in the register file. +/// +/// The arith.truncf handler defers vector conversions (e.g. f32->bf16), +/// leaving registers in the source layout while the MLIR type already +/// reflects the narrow destination type. When a subsequent extract +/// indexes into such a vector, the element width that governs register +/// indexing is the *source* width, not the nominal destination width. 
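# --- Editor's illustration (not part of the patch) -------------------------
# The index arithmetic used in the sub-dword extraction paths below, as a
# small Python model. Element widths and indices are examples.
def sub_dword_location(index: int, elem_bits: int) -> tuple[int, int]:
    """Return (dword_offset, bit_offset) of packed element `index`."""
    assert 32 % elem_bits == 0
    elems_per_dword = 32 // elem_bits
    return index // elems_per_dword, (index % elems_per_dword) * elem_bits

# vector<8xi8>: element 5 sits in the second VGPR, shifted right by 8 bits.
print(sub_dword_location(5, 8))   # (1, 8)
# vector<4xbf16>: element 3 is the upper half of the second VGPR.
print(sub_dword_location(3, 16))  # (1, 16)
# Deferred arith.truncf: the MLIR type says bf16, but the registers still
# hold 32-bit lanes, so the effective width is 32 and element i maps to
# VGPR i with no shift -- the pre-existing per-register path.
print(sub_dword_location(3, 32))  # (3, 0)
# ---------------------------------------------------------------------------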
+static int64_t getEffectiveElemBits(Value source) { + if (auto truncOp = source.getDefiningOp()) { + auto srcType = truncOp.getIn().getType(); + if (auto vecType = dyn_cast(srcType)) { + if (vecType.getNumElements() > 1) + return vecType.getElementType().getIntOrFloatBitWidth(); + } + } + if (auto vecType = dyn_cast(source.getType())) + return vecType.getElementType().getIntOrFloatBitWidth(); + return 32; +} + LogicalResult handleVectorBroadcast(Operation *op, TranslationContext &ctx) { auto broadcastOp = cast(op); @@ -66,6 +87,46 @@ LogicalResult handleVectorExtract(Operation *op, TranslationContext &ctx) { index = staticPos[0]; } + int64_t elemBits = getEffectiveElemBits(extractOp.getSource()); + + // Sub-dword elements: multiple elements are packed per 32-bit VGPR. + // Compute which VGPR holds the element and the bit position within it. + if (elemBits < 32) { + int64_t elemsPerDword = 32 / elemBits; + int64_t dwordOffset = index / elemsPerDword; + int64_t bitOffset = (index % elemsPerDword) * elemBits; + + // Select the correct VGPR via register-level extraction. + Value dwordReg; + Type srcType = src->getType(); + if (auto pvreg = dyn_cast(srcType)) { + int64_t baseIdx = pvreg.getIndex() + dwordOffset; + auto elemType = PVRegType::get(builder.getContext(), baseIdx, 1); + dwordReg = PrecoloredVRegOp::create(builder, loc, elemType, baseIdx, 1); + } else if (dwordOffset == 0) { + dwordReg = *src; + } else { + auto elemType = ctx.createVRegType(1, 1); + auto extractWaveOp = ExtractOp::create( + builder, loc, elemType, *src, + builder.getI64IntegerAttr(dwordOffset)); + dwordReg = extractWaveOp.getResult(); + } + + if (bitOffset != 0) { + auto shiftImm = ConstantOp::create(builder, loc, + ctx.createImmType(bitOffset), + bitOffset); + auto shifted = V_LSHRREV_B32::create(builder, loc, + ctx.createVRegType(), shiftImm, + dwordReg); + ctx.getMapper().mapValue(extractOp.getResult(), shifted); + } else { + ctx.getMapper().mapValue(extractOp.getResult(), dwordReg); + } + return success(); + } + // Get the source register type to find the base physical register Type srcType = src->getType(); int64_t baseIdx = 0; @@ -215,23 +276,46 @@ LogicalResult handleVectorExtractStridedSlice(Operation *op, size = cast(sizes[0]).getInt(); } - // Check for sub-dword element extraction. When elements are smaller than - // 32 bits (e.g. vector<2xi8>), multiple elements are packed in a single - // VGPR. A register-level offset would read the wrong VGPR; instead we - // emit a right-shift to bring the desired byte(s) to bit-position 0. - auto srcVecType = extractOp.getSourceVectorType(); - int64_t elemBits = srcVecType.getElementType().getIntOrFloatBitWidth(); - if (elemBits < 32 && offset != 0) { - int64_t bitShift = offset * elemBits; - auto shiftImm = - ConstantOp::create(builder, loc, ctx.createImmType(bitShift), bitShift); - auto shifted = V_LSHRREV_B32::create(builder, loc, ctx.createVRegType(), - shiftImm, *src); - ctx.getMapper().mapValue(extractOp.getResult(), shifted); - return success(); - } - if (elemBits < 32 && offset == 0) { - ctx.getMapper().mapValue(extractOp.getResult(), *src); + // Sub-dword element extraction. When elements are smaller than 32 bits + // (e.g. bf16, i8), multiple elements are packed per VGPR. We first select + // the correct dword (VGPR), then shift within it if needed. + // Use getEffectiveElemBits to handle deferred arith.truncf where the + // registers are still in the wider source layout. 
+  int64_t elemBits = getEffectiveElemBits(extractOp.getSource());
+  if (elemBits < 32) {
+    int64_t elemsPerDword = 32 / elemBits;
+    int64_t dwordOffset = offset / elemsPerDword;
+    int64_t bitOffset = (offset % elemsPerDword) * elemBits;
+
+    // Select the correct VGPR via register-level extraction.
+    Value dwordReg;
+    Type srcType = src->getType();
+    if (auto pvreg = dyn_cast<PVRegType>(srcType)) {
+      int64_t baseIdx = pvreg.getIndex() + dwordOffset;
+      auto elemType = PVRegType::get(builder.getContext(), baseIdx, 1);
+      dwordReg =
+          PrecoloredVRegOp::create(builder, loc, elemType, baseIdx, 1);
+    } else if (dwordOffset == 0) {
+      dwordReg = *src;
+    } else {
+      auto elemType = ctx.createVRegType(1, 1);
+      auto extractWaveOp = ExtractOp::create(
+          builder, loc, elemType, *src,
+          builder.getI64IntegerAttr(dwordOffset));
+      dwordReg = extractWaveOp.getResult();
+    }
+
+    if (bitOffset != 0) {
+      auto shiftImm = ConstantOp::create(builder, loc,
+                                         ctx.createImmType(bitOffset),
+                                         bitOffset);
+      auto shifted = V_LSHRREV_B32::create(builder, loc,
+                                           ctx.createVRegType(), shiftImm,
+                                           dwordReg);
+      ctx.getMapper().mapValue(extractOp.getResult(), shifted);
+    } else {
+      ctx.getMapper().mapValue(extractOp.getResult(), dwordReg);
+    }
     return success();
   }
 
diff --git a/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py b/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py
index 79dc19ca03..89d897e533 100644
--- a/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py
+++ b/wave_lang/kernel/wave/asm/wave_asm/test/e2e/test_asm_backend_e2e.py
@@ -2170,6 +2170,7 @@ def test_splitk_mxfp4_preshuffle_scales_cpp_backend(
 @pytest.mark.parametrize(
     "shape,num_splits",
     [
+        ((256, 256, 256), 1),
         ((256, 256, 256), 2),
     ],
 )
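Editor's note, appended after the patch rather than inside it: the
handleFatRawBufferCast change above replaces the hard-coded SRD word-1
constant 0x40400000 with a value computed from the swizzle stride. Using the
bit layout stated in the patch comment (cache_swizzle_enable in bit 30,
stride in bits 29:16), the sketch below shows that the old constant was just
the stride == 64 case; the helper name is invented for the illustration.

def srd_word1_upper(swizzle_stride: int) -> int:
    """Upper bits of SRD word 1: swizzle enable | 14-bit swizzle stride."""
    return (1 << 30) | ((swizzle_stride & 0x3FFF) << 16)

print(hex(srd_word1_upper(64)))   # 0x40400000 -- the previously hard-coded value
print(hex(srd_word1_upper(128)))  # 0x40800000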