Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 217 additions & 3 deletions examples/python/7.1_schedule.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
"""
MXFP4 Scaled GEMM Scheduling for GFX950 (MI350)

Double-buffered MXFP4 GEMM with 4-wave and 8-wave configurations.
Double-buffered MXFP4 GEMM with 4-wave and 8-wave configurations, plus split-K.
Uses get_tagged_mxfp4_gemm (templates) + get_mxfp4_dbuf_schedule (schedules).
Split-K kernels use the wave_asm backend with atomic bf16 output.

Usage:
python 7.1_schedule.py --test test_dbuf_4wave_mxfp_gemm
python 7.1_schedule.py --test test_dbuf_8wave_mxfp_gemm
python 7.1_schedule.py --test test_dbuf_8wave_mxfp_gemm --debug
python 7.1_schedule.py --test test_splitk_gemm
python 7.1_schedule.py --test test_splitk_preshuffle_scales_gemm
python 7.1_schedule.py --list_tests
"""

import sys
from dataclasses import dataclass
from pathlib import Path

import torch

from wave_lang.kernel.wave.compile import wave_compile
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.constraints import ScaledMMAType
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
from wave_lang.kernel.wave.templates import (
get_tagged_mxfp4_gemm,
get_tagged_mxfp4_gemm_preshuffle_b,
)
from wave_lang.kernel.wave.templates.gemm import get_splitk_mxfp4_gemm_kernel
from wave_lang.kernel.wave.schedules import (
get_mxfp4_dbuf_schedule,
get_mxfp4_dbuf_pingpong_schedule,
Expand All @@ -31,9 +40,22 @@
b_preshuffle,
e8m0_shuffle,
)
from wave_lang.kernel.lang.global_symbols import GLOBAL_ADDRESS_SPACE
import wave_lang.kernel.lang as tkl
from wave_lang.kernel.lang.global_symbols import (
GLOBAL_ADDRESS_SPACE,
SHARED_ADDRESS_SPACE,
)
from utils import parse_args, list_tests, run_test

_EXAMPLES_DIR = Path(__file__).parent
_WAVE_ROOT = _EXAMPLES_DIR.parent.parent
_E2E_DIR = (
_WAVE_ROOT / "wave_lang" / "kernel" / "wave" / "asm" / "wave_asm" / "test" / "e2e"
)
for _p in [str(_EXAMPLES_DIR), str(_WAVE_ROOT), str(_E2E_DIR)]:
if _p not in sys.path:
sys.path.insert(0, _p)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of this, could you modify the imports so that we can write something like
`import WaveASMCompiler, capture_wave_kernel_info`?



def _run_mxfp_gemm(gemm, shape):
"""Run compiled GEMM kernel and verify against reference."""
Expand Down Expand Up @@ -182,6 +204,198 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm(
print("MXFP GEMM preshuffle-B 4-wave test passed!")


@dataclass
class SplitKKernelHandle:
    """Opaque handle for a compiled split-K MXFP4 kernel.

    Produced by get_splitk_kernel() and consumed by run_splitk_kernel();
    bundles the loaded GPU function object with the launch geometry that was
    captured at compile time.
    """

    gpu_func: object  # GPU function object returned by wave_runtime.load_binary()
    binary_path: Path  # path to the compiled kernel binary on disk
    kernel_name: str  # resolved kernel symbol name inside the binary
    grid: tuple[int, int, int]  # launch grid size (x, y, z)
    block: tuple[int, int, int]  # workgroup size (x, y, z)
    lds_size: int  # dynamic shared-memory (LDS) size passed to the launcher
    num_splits: int  # number of K-dimension splits the kernel was compiled for


def get_splitk_kernel(
    shape: tuple[int, int, int],
    block: tuple[int, int, int] = (128, 128, 128),
    num_splits: int = 2,
    waves_per_block: tuple[int, int] = (2, 2),
    preshuffle_scales: bool = False,
    compiler=None,
) -> SplitKKernelHandle:
    """Compile a split-K MXFP4 GEMM kernel through the wave_asm backend.

    Output tensor must be bf16 and zero-initialised before each call.
    w must be in [N, K/2] layout.

    Args:
        shape: Problem size (M, N, K).
        block: Workgroup tile sizes for (M, N, K).
        num_splits: Number of splits along the K dimension.
        waves_per_block: Wave layout within a workgroup.
        preshuffle_scales: If True, a_scale and b_scale are read from GLOBAL
            memory using the e8m0_shuffle IndexMapping. The caller must pass
            scales pre-shuffled with e8m0_shuffle().
        compiler: Optional pre-built WaveASMCompiler; one targeting the
            current arch is constructed when omitted.

    Returns:
        A SplitKKernelHandle bundling the loaded GPU function and launch info.

    Raises:
        RuntimeError: If the wave_asm compilation step fails.
    """
    # Imported lazily: these modules live under the e2e test tree that was
    # prepended to sys.path at module import time.
    from waveasm_e2e import WaveASMCompiler, capture_wave_kernel_info
    from test_asm_backend_e2e import get_target_arch

    if compiler is None:
        compiler = WaveASMCompiler(target=get_target_arch())

    kernel_fn, params = get_splitk_mxfp4_gemm_kernel(
        shape,
        num_splits=num_splits,
        mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4,
        block_shape=block,
        waves_per_block=waves_per_block,
        preshuffle_scales=preshuffle_scales,
    )
    # Both operands are staged through shared memory (LDS).
    params[tkl.sym.ADDRESS_SPACE] = SHARED_ADDRESS_SPACE
    params[tkl.sym.B_ADDRESS_SPACE] = SHARED_ADDRESS_SPACE

    options = set_default_run_config(
        WaveCompileOptions(
            subs=params,
            canonicalize=True,
            backend="asm",
            wave_runtime=True,
            compile_to_mlir=False,
            use_global_to_shared=True,
        )
    )

    kernel_info = capture_wave_kernel_info(options, kernel_fn)
    asm_result = compiler.compile_full(
        kernel_info.mlir_text, kernel_info.workgroup_size
    )
    if not asm_result.success:
        raise RuntimeError(f"wave_asm compilation failed: {asm_result.error_message}")

    import wave_runtime

    wave_runtime.load_hip_functions()
    # Prefer the name reported by the compiler; fall back to the captured one.
    resolved_name = asm_result.get_kernel_name() or kernel_info.kernel_name
    _binary, gpu_func = wave_runtime.load_binary(
        str(asm_result.binary_path), resolved_name
    )

    return SplitKKernelHandle(
        gpu_func=gpu_func,
        binary_path=asm_result.binary_path,
        kernel_name=resolved_name,
        grid=kernel_info.grid_size,
        block=kernel_info.workgroup_size,
        lds_size=kernel_info.lds_size,
        num_splits=num_splits,
    )


def run_splitk_kernel(
    handle: SplitKKernelHandle,
    x: torch.Tensor,
    x_scales: torch.Tensor,
    w: torch.Tensor,
    w_scales: torch.Tensor,
    c_out: torch.Tensor,
) -> None:
    """Launch a compiled split-K kernel on the current CUDA stream.

    c_out must be zero-initialised (dtype=torch.bfloat16).
    w must be in [N, K/2] layout.
    """
    import wave_runtime

    grid_x, grid_y, grid_z = handle.grid
    block_x, block_y, block_z = handle.block
    launch_info = wave_runtime.KernelLaunchInfo(
        torch.cuda.current_stream().cuda_stream,
        handle.gpu_func,
        handle.lds_size,
        grid_x,
        grid_y,
        grid_z,
        block_x,
        block_y,
        block_z,
        1,
        1,
        1,
    )
    # Arguments are passed as raw device pointers, in kernel signature order.
    tensors = (x, x_scales, w, w_scales, c_out)
    arg_ptrs = wave_runtime.Int64Vector([t.data_ptr() for t in tensors])
    wave_runtime.launch(launch_info, arg_ptrs, [], [])


def test_splitk_gemm(
    is_debug: bool = False,
    shape: tuple[int, int, int] = (1024, 1024, 8192),
    block: tuple[int, int, int] = (128, 128, 256),
):
    """Split-K MXFP4 GEMM (wave_asm backend, bf16 output, unshuffled scales)."""
    m, n, _ = shape
    num_splits = 2
    x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape)
    # w from generate_gemm_afp4wfp4_inputs is [K/2, N]; split-K kernel wants [N, K/2]
    w_nk = w.T.contiguous()
    reference = torchScaledGemmMXFP4(x, w, x_scales, w_scales)

    handle = get_splitk_kernel(shape, block=block, num_splits=num_splits)

    c_out = torch.zeros(m, n, dtype=torch.bfloat16, device="cuda")
    run_splitk_kernel(
        handle, x.cuda(), x_scales.cuda(), w_nk.cuda(), w_scales.cuda(), c_out
    )
    torch.cuda.synchronize()

    # Each split rounds its partial sum to bf16 before accumulating into the
    # output, so the absolute tolerance scales with num_splits.
    bf16_eps = 2**-7
    atol = num_splits * bf16_eps * max(reference.abs().max().item(), 1.0)
    torch.testing.assert_close(
        reference,
        c_out.cpu().to(torch.float32),
        check_dtype=False,
        check_device=False,
        atol=atol,
        rtol=0.0,
    )
    print("Split-K MXFP4 GEMM test passed!")


def test_splitk_preshuffle_scales_gemm(
    is_debug: bool = False,
    shape: tuple[int, int, int] = (1024, 1024, 8192),
    block: tuple[int, int, int] = (128, 128, 256),
):
    """Split-K MXFP4 GEMM (wave_asm backend, bf16 output, preshuffled scales)."""
    m, n, _ = shape
    num_splits = 2
    x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape)
    # Kernel expects w in [N, K/2] layout.
    w_nk = w.T.contiguous()
    reference = torchScaledGemmMXFP4(x, w, x_scales, w_scales)

    # The preshuffle_scales kernel variant reads scales through the
    # e8m0_shuffle IndexMapping, so shuffle them on the host first.
    x_scales_sh = e8m0_shuffle(x_scales)
    w_scales_sh = e8m0_shuffle(w_scales)

    handle = get_splitk_kernel(
        shape, block=block, num_splits=num_splits, preshuffle_scales=True
    )

    c_out = torch.zeros(m, n, dtype=torch.bfloat16, device="cuda")
    run_splitk_kernel(
        handle, x.cuda(), x_scales_sh.cuda(), w_nk.cuda(), w_scales_sh.cuda(), c_out
    )
    torch.cuda.synchronize()

    # bf16 rounding happens once per split, so the error budget scales with
    # num_splits.
    bf16_eps = 2**-7
    atol = num_splits * bf16_eps * max(reference.abs().max().item(), 1.0)
    torch.testing.assert_close(
        reference,
        c_out.cpu().to(torch.float32),
        check_dtype=False,
        check_device=False,
        atol=atol,
        rtol=0.0,
    )
    print("Split-K MXFP4 GEMM (preshuffled scales) test passed!")


if __name__ == "__main__":
args = parse_args()

Expand Down
62 changes: 61 additions & 1 deletion tests/kernel/wave_gemm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,22 @@
require_rdna4,
use_water_backend_bool,
)
from wave_lang.kernel.wave.constraints import MMAType, MMAOperand, GenericDot
from wave_lang.kernel.wave.constraints import (
MMAType,
MMAOperand,
GenericDot,
ScaledMMAType,
)
from wave_lang.kernel.wave.utils.mxfp_utils import (
generate_gemm_afp4wfp4_inputs,
torchScaledGemmMXFP4,
)
from wave_lang.kernel.wave.templates.gemm import (
get_gemm_kernel,
get_gemm_kernel_transpose_a_b,
get_persistent_gemm_kernel,
get_splitk_gemm_kernel,
get_splitk_mxfp4_gemm_kernel,
get_streamk_gemm_kernel,
get_hybrid_streamk_gemm_kernel,
get_persistent_reordering_kernel,
Expand Down Expand Up @@ -3614,3 +3624,53 @@ def testSplitKGemm(

torch_ref = a.cpu().to(torch.float32) @ b.cpu().T.to(torch.float32)
assert_close(c.cpu(), torch_ref, rtol=1e-3, atol=1e-2)


@require_e2e
@require_cdna4
@pytest.mark.parametrize(
    "shape, num_splits",
    [
        ((256, 256, 256), 2),
        ((256, 256, 512), 4),
        ((256, 256, 1024), 4),
        ((512, 512, 1024), 2),
    ],
)
def testSplitKMxfp4Gemm(
    shape: tuple[int, int, int],
    num_splits: int,
):
    """Compile and run the split-K MXFP4 GEMM kernel, checking against a
    PyTorch reference."""
    gemm_fn, hyperparams = get_splitk_mxfp4_gemm_kernel(
        shape,
        num_splits=num_splits,
        mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4,
    )

    options = set_default_run_config(
        WaveCompileOptions(
            subs=hyperparams,
            canonicalize=True,
        )
    )
    compiled_gemm = wave_compile(options, gemm_fn)

    m, n, _ = shape
    x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(
        shape, device=torch.device("cpu")
    )
    ref = torchScaledGemmMXFP4(x, w, x_scales, w_scales)

    x_gpu = x.cuda()
    w_t_gpu = w.T.contiguous().cuda()
    x_scales_gpu = x_scales.cuda()
    w_scales_gpu = w_scales.cuda()
    c_gpu = device_zeros(m, n, dtype=torch.bfloat16)

    compiled_gemm(x_gpu, x_scales_gpu, w_t_gpu, w_scales_gpu, c_gpu)

    # Each split accumulates a partial f32 sum, casts it to bf16, then atomically
    # adds it into the bf16 output. At partial-sum magnitudes of ~250-430 the bf16
    # ULP is 2-4, so with 2-4 splits the worst-case absolute rounding error is
    # num_splits * max_partial_magnitude * bf16_eps ≈ 4 * 430 * 2^-7 ≈ 13.
    # Observed worst-case across all parametrized shapes is ~3, so atol=4.0 is
    # generous but justified by the bf16 accumulation arithmetic.
    assert_close(c_gpu.cpu().to(torch.float32), ref, rtol=1e-1, atol=4.0)
Loading
Loading