Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 69 additions & 26 deletions examples/python/7.1_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import torch

from wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType
from wave_lang.kernel.wave.compile import wave_compile
from wave_lang.kernel.wave.utils.run_utils import set_default_run_config
from wave_lang.kernel.wave.templates import (
Expand All @@ -31,7 +32,10 @@
b_preshuffle,
e8m0_shuffle,
)
from wave_lang.kernel.lang.global_symbols import GLOBAL_ADDRESS_SPACE
from wave_lang.kernel.lang.global_symbols import (
GLOBAL_ADDRESS_SPACE,
SHARED_ADDRESS_SPACE,
)
from utils import parse_args, list_tests, run_test


Expand Down Expand Up @@ -64,7 +68,7 @@ def _run_mxfp_gemm_preshuffle_b(gemm, shape):
x_scales_ps, w_scales_ps = x_scales_ps.cuda(), w_scales_ps.cuda()
out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=torch.float32).cuda()

gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out)
gemm(x, x_scales, w_t_ps, w_scales_ps, out)
torch.testing.assert_close(
torch_out, out.cpu(), check_dtype=False, check_device=False
)
Expand All @@ -74,17 +78,26 @@ def test_dbuf_4wave_mxfp_gemm(
is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256)
):
"""Double-buffered MXFP4 GEMM, 4 waves, no stagger."""
gemm, options = get_tagged_mxfp4_gemm(shape, block, wave_shape=(2, 2))
schedule = get_mxfp4_dbuf_schedule(use_stagger=False)

options.print_ir_after = "all" if is_debug else []
options.print_mlir_file = "gemm_mxfp4_dbuf_4wave.mlir"
options.print_mlir = True
options = set_default_run_config(options)
gemm = wave_compile(options, gemm, schedule)

_run_mxfp_gemm(gemm, shape)
print("MXFP GEMM double-buffer 4-wave test passed!")
for block in [
(256, 256, 256),
(128, 256, 256),
(128, 128, 256),
]:
gemm, options = get_tagged_mxfp4_gemm(
shape, block, wave_shape=(2, 2), dynamic_dims=True
)
# schedule = get_mxfp4_dbuf_schedule(use_stagger=False)

options.print_ir_after = "all" if is_debug else []
options.print_mlir_file = "gemm_mxfp4_dbuf_4wave.mlir"
options.print_mlir = True
options.schedule = SchedulingType.NONE
options.dump_intermediates = f"build/dynamic/wave_gemm_mxfp4_dbuf_4wave_MT{block[0]}x{block[1]}x{block[2]}"
options = set_default_run_config(options)
gemm = wave_compile(options, gemm)

_run_mxfp_gemm(gemm, shape)
print(f"Dynamic MXFP GEMM double-buffer 4-wave test passed for block {block}!")


def test_dbuf_8wave_pingpong_mxfp_gemm(
Expand Down Expand Up @@ -167,19 +180,49 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm(
is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256)
):
"""Asymmetric MXFP4 GEMM with preshuffled B data and B scales."""
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4))
options.minimize_shared_allocs = True
options.linearize_shared_access = True
options.use_buffer_ops = True
options.dump_intermediates = "build/intermediates"
schedule = get_mxfp4_asymmetric_schedule()

options.print_ir_after = "all" if is_debug else []
options = set_default_run_config(options)
gemm = wave_compile(options, gemm, schedule)

_run_mxfp_gemm_preshuffle_b(gemm, shape)
print("MXFP GEMM preshuffle-B 4-wave test passed!")
macrotiles = [
(64, 192, 256),
(256, 256, 128),
]
for block in macrotiles:
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(
shape,
block,
wave_shape=(1, 4),
dynamic_dims=False,
a_address_space=SHARED_ADDRESS_SPACE,
)

block_m, block_n, block_k = block

options.schedule = SchedulingType.MANUAL
options.minimize_shared_allocs = True
options.linearize_shared_access = True
options.use_buffer_ops = True
# options.mlir_print_ir_after_all = True
options.dump_intermediates = f"build/b_preshuffle_dynamic/wave_gemm_mxfp4_preshuffle_MT{block_m}x{block_n}x{block_k}"
schedule = get_mxfp4_asymmetric_schedule()

options.print_ir_after = "all" if is_debug else []
options = set_default_run_config(options)
gemm = wave_compile(options, gemm, schedule)

_run_mxfp_gemm_preshuffle_b(gemm, shape)
print(f"MXFP GEMM preshuffle-B Dynamic 4-wave test passed for shape {shape}!")

# gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4))
# options.minimize_shared_allocs = True
# options.linearize_shared_access = True
# options.use_buffer_ops = True
# options.dump_intermediates = "build/intermediates"
# schedule = get_mxfp4_asymmetric_schedule()

# options.print_ir_after = "all" if is_debug else []
# options = set_default_run_config(options)
# gemm = wave_compile(options, gemm, schedule)

# _run_mxfp_gemm_preshuffle_b(gemm, shape)
# print("MXFP GEMM preshuffle-B 4-wave test passed!")


if __name__ == "__main__":
Expand Down
Loading