-
Notifications
You must be signed in to change notification settings - Fork 28
Implement mxfp4 split-k gemm #958
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,24 +1,33 @@ | ||
| """ | ||
| MXFP4 Scaled GEMM Scheduling for GFX950 (MI350) | ||
|
|
||
| Double-buffered MXFP4 GEMM with 4-wave and 8-wave configurations. | ||
| Double-buffered MXFP4 GEMM with 4-wave and 8-wave configurations, plus split-K. | ||
| Uses get_tagged_mxfp4_gemm (templates) + get_mxfp4_dbuf_schedule (schedules). | ||
| Split-K kernels use the wave_asm backend with atomic bf16 output. | ||
|
|
||
| Usage: | ||
| python 7.1_schedule.py --test test_dbuf_4wave_mxfp_gemm | ||
| python 7.1_schedule.py --test test_dbuf_8wave_mxfp_gemm | ||
| python 7.1_schedule.py --test test_dbuf_8wave_mxfp_gemm --debug | ||
| python 7.1_schedule.py --test test_splitk_gemm | ||
| python 7.1_schedule.py --test test_splitk_preshuffle_scales_gemm | ||
| python 7.1_schedule.py --list_tests | ||
| """ | ||
|
|
||
| import sys | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
|
|
||
| import torch | ||
|
|
||
| from wave_lang.kernel.wave.compile import wave_compile | ||
| from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile | ||
| from wave_lang.kernel.wave.constraints import ScaledMMAType | ||
| from wave_lang.kernel.wave.utils.run_utils import set_default_run_config | ||
| from wave_lang.kernel.wave.templates import ( | ||
| get_tagged_mxfp4_gemm, | ||
| get_tagged_mxfp4_gemm_preshuffle_b, | ||
| ) | ||
| from wave_lang.kernel.wave.templates.gemm import get_splitk_mxfp4_gemm_kernel | ||
| from wave_lang.kernel.wave.schedules import ( | ||
| get_mxfp4_dbuf_schedule, | ||
| get_mxfp4_dbuf_pingpong_schedule, | ||
|
|
@@ -31,9 +40,22 @@ | |
| b_preshuffle, | ||
| e8m0_shuffle, | ||
| ) | ||
| from wave_lang.kernel.lang.global_symbols import GLOBAL_ADDRESS_SPACE | ||
| import wave_lang.kernel.lang as tkl | ||
| from wave_lang.kernel.lang.global_symbols import ( | ||
| GLOBAL_ADDRESS_SPACE, | ||
| SHARED_ADDRESS_SPACE, | ||
| ) | ||
| from utils import parse_args, list_tests, run_test | ||
|
|
||
| _EXAMPLES_DIR = Path(__file__).parent | ||
| _WAVE_ROOT = _EXAMPLES_DIR.parent.parent | ||
| _E2E_DIR = ( | ||
| _WAVE_ROOT / "wave_lang" / "kernel" / "wave" / "asm" / "wave_asm" / "test" / "e2e" | ||
| ) | ||
| for _p in [str(_EXAMPLES_DIR), str(_WAVE_ROOT), str(_E2E_DIR)]: | ||
| if _p not in sys.path: | ||
| sys.path.insert(0, _p) | ||
|
|
||
|
|
||
| def _run_mxfp_gemm(gemm, shape): | ||
| """Run compiled GEMM kernel and verify against reference.""" | ||
|
|
@@ -182,6 +204,198 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm( | |
| print("MXFP GEMM preshuffle-B 4-wave test passed!") | ||
|
|
||
|
|
||
| @dataclass | ||
| class SplitKKernelHandle: | ||
| """Opaque handle for a compiled split-K MXFP4 kernel.""" | ||
|
|
||
| gpu_func: object | ||
| binary_path: Path | ||
| kernel_name: str | ||
| grid: tuple[int, int, int] | ||
| block: tuple[int, int, int] | ||
| lds_size: int | ||
| num_splits: int | ||
|
|
||
|
|
||
| def get_splitk_kernel( | ||
| shape: tuple[int, int, int], | ||
| block: tuple[int, int, int] = (128, 128, 128), | ||
| num_splits: int = 2, | ||
| waves_per_block: tuple[int, int] = (2, 2), | ||
| preshuffle_scales: bool = False, | ||
| compiler=None, | ||
| ) -> SplitKKernelHandle: | ||
| """Compile a split-K MXFP4 GEMM kernel through the wave_asm backend. | ||
|
|
||
| Output tensor must be bf16 and zero-initialised before each call. | ||
| w must be in [N, K/2] layout. | ||
|
|
||
| Args: | ||
| preshuffle_scales: If True, a_scale and b_scale are read from GLOBAL | ||
| memory using the e8m0_shuffle IndexMapping. The caller must pass | ||
| scales pre-shuffled with e8m0_shuffle(). | ||
| """ | ||
| from waveasm_e2e import WaveASMCompiler, capture_wave_kernel_info | ||
| from test_asm_backend_e2e import get_target_arch | ||
|
|
||
| if compiler is None: | ||
| compiler = WaveASMCompiler(target=get_target_arch()) | ||
|
|
||
| splitk_fn, hyperparams = get_splitk_mxfp4_gemm_kernel( | ||
| shape, | ||
| num_splits=num_splits, | ||
| mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4, | ||
| block_shape=block, | ||
| waves_per_block=waves_per_block, | ||
| preshuffle_scales=preshuffle_scales, | ||
| ) | ||
| hyperparams[tkl.sym.ADDRESS_SPACE] = SHARED_ADDRESS_SPACE | ||
| hyperparams[tkl.sym.B_ADDRESS_SPACE] = SHARED_ADDRESS_SPACE | ||
|
|
||
| options = WaveCompileOptions( | ||
| subs=hyperparams, | ||
| canonicalize=True, | ||
| backend="asm", | ||
| wave_runtime=True, | ||
| compile_to_mlir=False, | ||
| use_global_to_shared=True, | ||
| ) | ||
| options = set_default_run_config(options) | ||
|
|
||
| kernel_info = capture_wave_kernel_info(options, splitk_fn) | ||
| cpp_result = compiler.compile_full( | ||
| kernel_info.mlir_text, kernel_info.workgroup_size | ||
| ) | ||
| if not cpp_result.success: | ||
| raise RuntimeError(f"wave_asm compilation failed: {cpp_result.error_message}") | ||
|
|
||
| import wave_runtime | ||
|
|
||
| wave_runtime.load_hip_functions() | ||
| _binary, gpu_func = wave_runtime.load_binary( | ||
| str(cpp_result.binary_path), | ||
| cpp_result.get_kernel_name() or kernel_info.kernel_name, | ||
| ) | ||
|
|
||
| return SplitKKernelHandle( | ||
| gpu_func=gpu_func, | ||
| binary_path=cpp_result.binary_path, | ||
| kernel_name=cpp_result.get_kernel_name() or kernel_info.kernel_name, | ||
| grid=kernel_info.grid_size, | ||
| block=kernel_info.workgroup_size, | ||
| lds_size=kernel_info.lds_size, | ||
| num_splits=num_splits, | ||
| ) | ||
|
|
||
|
|
||
| def run_splitk_kernel( | ||
| handle: SplitKKernelHandle, | ||
| x: torch.Tensor, | ||
| x_scales: torch.Tensor, | ||
| w: torch.Tensor, | ||
| w_scales: torch.Tensor, | ||
| c_out: torch.Tensor, | ||
| ) -> None: | ||
| """Launch a compiled split-K kernel. | ||
|
|
||
| c_out must be zero-initialised (dtype=torch.bfloat16). | ||
| w must be in [N, K/2] layout. | ||
| """ | ||
| import wave_runtime | ||
|
|
||
| stream = torch.cuda.current_stream().cuda_stream | ||
| kli = wave_runtime.KernelLaunchInfo( | ||
| stream, | ||
| handle.gpu_func, | ||
| handle.lds_size, | ||
| handle.grid[0], | ||
| handle.grid[1], | ||
| handle.grid[2], | ||
| handle.block[0], | ||
| handle.block[1], | ||
| handle.block[2], | ||
| 1, | ||
| 1, | ||
| 1, | ||
| ) | ||
| kern_args = wave_runtime.Int64Vector( | ||
| [t.data_ptr() for t in [x, x_scales, w, w_scales, c_out]] | ||
| ) | ||
| wave_runtime.launch(kli, kern_args, [], []) | ||
|
|
||
|
|
||
| def test_splitk_gemm( | ||
| is_debug: bool = False, | ||
| shape: tuple[int, int, int] = (1024, 1024, 8192), | ||
| block: tuple[int, int, int] = (128, 128, 256), | ||
| ): | ||
| """Split-K MXFP4 GEMM (wave_asm backend, bf16 output, unshuffled scales).""" | ||
| m, n, k = shape | ||
| num_splits = 2 | ||
| x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) | ||
| # w from generate_gemm_afp4wfp4_inputs is [K/2, N]; split-K kernel wants [N, K/2] | ||
| w_nk = w.T.contiguous() | ||
| torch_ref = torchScaledGemmMXFP4(x, w, x_scales, w_scales) | ||
|
|
||
| handle = get_splitk_kernel(shape, block=block, num_splits=num_splits) | ||
|
|
||
| c_out = torch.zeros(m, n, dtype=torch.bfloat16, device="cuda") | ||
| run_splitk_kernel( | ||
| handle, x.cuda(), x_scales.cuda(), w_nk.cuda(), w_scales.cuda(), c_out | ||
| ) | ||
| torch.cuda.synchronize() | ||
|
|
||
| bf16_eps = 2**-7 | ||
| atol = num_splits * bf16_eps * max(torch_ref.abs().max().item(), 1.0) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why does atol depend on num_splits? Shouldn't it be independent of num_splits?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, the number of splits increases the accumulation of errors. I.e., we get error from casting to BF16, and then we accumulate in BF16, which means we accumulate error. |
||
| torch.testing.assert_close( | ||
| torch_ref, | ||
| c_out.cpu().to(torch.float32), | ||
| check_dtype=False, | ||
| check_device=False, | ||
| atol=atol, | ||
| rtol=0.0, | ||
| ) | ||
| print("Split-K MXFP4 GEMM test passed!") | ||
|
|
||
|
|
||
| def test_splitk_preshuffle_scales_gemm( | ||
| is_debug: bool = False, | ||
| shape: tuple[int, int, int] = (1024, 1024, 8192), | ||
| block: tuple[int, int, int] = (128, 128, 256), | ||
| ): | ||
| """Split-K MXFP4 GEMM (wave_asm backend, bf16 output, preshuffled scales).""" | ||
| m, n, k = shape | ||
| num_splits = 2 | ||
| x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) | ||
| w_nk = w.T.contiguous() | ||
| torch_ref = torchScaledGemmMXFP4(x, w, x_scales, w_scales) | ||
|
|
||
| x_scales_sh = e8m0_shuffle(x_scales) | ||
| w_scales_sh = e8m0_shuffle(w_scales) | ||
|
|
||
| handle = get_splitk_kernel( | ||
| shape, block=block, num_splits=num_splits, preshuffle_scales=True | ||
| ) | ||
|
|
||
| c_out = torch.zeros(m, n, dtype=torch.bfloat16, device="cuda") | ||
| run_splitk_kernel( | ||
| handle, x.cuda(), x_scales_sh.cuda(), w_nk.cuda(), w_scales_sh.cuda(), c_out | ||
| ) | ||
| torch.cuda.synchronize() | ||
|
|
||
| bf16_eps = 2**-7 | ||
| atol = num_splits * bf16_eps * max(torch_ref.abs().max().item(), 1.0) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here regarding num_splits |
||
| torch.testing.assert_close( | ||
| torch_ref, | ||
| c_out.cpu().to(torch.float32), | ||
| check_dtype=False, | ||
| check_device=False, | ||
| atol=atol, | ||
| rtol=0.0, | ||
| ) | ||
| print("Split-K MXFP4 GEMM (preshuffled scales) test passed!") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| args = parse_args() | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,12 +41,22 @@ | |
| require_rdna4, | ||
| use_water_backend_bool, | ||
| ) | ||
| from wave_lang.kernel.wave.constraints import MMAType, MMAOperand, GenericDot | ||
| from wave_lang.kernel.wave.constraints import ( | ||
| MMAType, | ||
| MMAOperand, | ||
| GenericDot, | ||
| ScaledMMAType, | ||
| ) | ||
| from wave_lang.kernel.wave.utils.mxfp_utils import ( | ||
| generate_gemm_afp4wfp4_inputs, | ||
| torchScaledGemmMXFP4, | ||
| ) | ||
| from wave_lang.kernel.wave.templates.gemm import ( | ||
| get_gemm_kernel, | ||
| get_gemm_kernel_transpose_a_b, | ||
| get_persistent_gemm_kernel, | ||
| get_splitk_gemm_kernel, | ||
| get_splitk_mxfp4_gemm_kernel, | ||
| get_streamk_gemm_kernel, | ||
| get_hybrid_streamk_gemm_kernel, | ||
| get_persistent_reordering_kernel, | ||
|
|
@@ -3614,3 +3624,53 @@ def testSplitKGemm( | |
|
|
||
| torch_ref = a.cpu().to(torch.float32) @ b.cpu().T.to(torch.float32) | ||
| assert_close(c.cpu(), torch_ref, rtol=1e-3, atol=1e-2) | ||
|
|
||
|
|
||
| @require_e2e | ||
| @require_cdna4 | ||
| @pytest.mark.parametrize( | ||
| "shape, num_splits", | ||
| [ | ||
| ((256, 256, 256), 2), | ||
| ((256, 256, 512), 4), | ||
| ((256, 256, 1024), 4), | ||
| ((512, 512, 1024), 2), | ||
| ], | ||
| ) | ||
| def testSplitKMxfp4Gemm( | ||
| shape: tuple[int, int, int], | ||
| num_splits: int, | ||
| ): | ||
| splitk_gemm, hyperparams = get_splitk_mxfp4_gemm_kernel( | ||
| shape, | ||
| num_splits=num_splits, | ||
| mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4, | ||
| ) | ||
|
|
||
| options = WaveCompileOptions( | ||
| subs=hyperparams, | ||
| canonicalize=True, | ||
| ) | ||
| options = set_default_run_config(options) | ||
| splitk_gemm = wave_compile(options, splitk_gemm) | ||
|
|
||
| m, n, k = shape | ||
| x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs( | ||
| shape, device=torch.device("cpu") | ||
| ) | ||
| torch_ref = torchScaledGemmMXFP4(x, w, x_scales, w_scales) | ||
|
|
||
| x_gpu = x.cuda() | ||
| w_t_gpu = w.T.contiguous().cuda() | ||
| x_scales_gpu = x_scales.cuda() | ||
| w_scales_gpu = w_scales.cuda() | ||
| c_gpu = device_zeros(m, n, dtype=torch.bfloat16) | ||
|
|
||
| splitk_gemm(x_gpu, x_scales_gpu, w_t_gpu, w_scales_gpu, c_gpu) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we have the bitcast to fp16 controllable through a flag so that for the correctness tests, we disable the bitcast? |
||
| # Each split accumulates a partial f32 sum, casts it to bf16, then atomically | ||
| # adds it into the bf16 output. At partial-sum magnitudes of ~250-430 the bf16 | ||
| # ULP is 2-4, so with 2-4 splits the worst-case absolute rounding error is | ||
| # num_splits * max_partial_magnitude * bf16_eps ≈ 4 * 430 * 2^-7 ≈ 13. | ||
| # Observed worst-case across all parametrized shapes is ~3, so atol=4.0 is | ||
| # generous but justified by the bf16 accumulation arithmetic. | ||
| assert_close(c_gpu.cpu().to(torch.float32), torch_ref, rtol=1e-1, atol=4.0) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of this, could you modify the imports so we can do something like
import WaveASMCompiler, capture_wave_kernel_info?