diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 5110fd3c9..f968775bb 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -13,6 +13,7 @@ import torch +from wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType from wave_lang.kernel.wave.compile import wave_compile from wave_lang.kernel.wave.utils.run_utils import set_default_run_config from wave_lang.kernel.wave.templates import ( @@ -31,7 +32,10 @@ b_preshuffle, e8m0_shuffle, ) -from wave_lang.kernel.lang.global_symbols import GLOBAL_ADDRESS_SPACE +from wave_lang.kernel.lang.global_symbols import ( + GLOBAL_ADDRESS_SPACE, + SHARED_ADDRESS_SPACE, +) from utils import parse_args, list_tests, run_test @@ -64,7 +68,7 @@ def _run_mxfp_gemm_preshuffle_b(gemm, shape): x_scales_ps, w_scales_ps = x_scales_ps.cuda(), w_scales_ps.cuda() out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=torch.float32).cuda() - gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out) + gemm(x, x_scales, w_t_ps, w_scales_ps, out) torch.testing.assert_close( torch_out, out.cpu(), check_dtype=False, check_device=False ) @@ -74,17 +78,26 @@ def test_dbuf_4wave_mxfp_gemm( is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256) ): """Double-buffered MXFP4 GEMM, 4 waves, no stagger.""" - gemm, options = get_tagged_mxfp4_gemm(shape, block, wave_shape=(2, 2)) - schedule = get_mxfp4_dbuf_schedule(use_stagger=False) - - options.print_ir_after = "all" if is_debug else [] - options.print_mlir_file = "gemm_mxfp4_dbuf_4wave.mlir" - options.print_mlir = True - options = set_default_run_config(options) - gemm = wave_compile(options, gemm, schedule) - - _run_mxfp_gemm(gemm, shape) - print("MXFP GEMM double-buffer 4-wave test passed!") + for block in [ + (256, 256, 256), + (128, 256, 256), + (128, 128, 256), + ]: + gemm, options = get_tagged_mxfp4_gemm( + shape, block, wave_shape=(2, 2), dynamic_dims=True + ) + # schedule = get_mxfp4_dbuf_schedule(use_stagger=False) + + options.print_ir_after = "all" if is_debug else [] + options.print_mlir_file = "gemm_mxfp4_dbuf_4wave.mlir" + options.print_mlir = True + options.schedule = SchedulingType.NONE + options.dump_intermediates = f"build/dynamic/wave_gemm_mxfp4_dbuf_4wave_MT{block[0]}x{block[1]}x{block[2]}" + options = set_default_run_config(options) + gemm = wave_compile(options, gemm) + + _run_mxfp_gemm(gemm, shape) + print(f"Dynamic MXFP GEMM double-buffer 4-wave test passed for block {block}!") def test_dbuf_8wave_pingpong_mxfp_gemm( @@ -167,19 +180,49 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm( is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256) ): """Asymmetric MXFP4 GEMM with preshuffled B data and B scales.""" - gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) - options.minimize_shared_allocs = True - options.linearize_shared_access = True - options.use_buffer_ops = True - options.dump_intermediates = "build/intermediates" - schedule = get_mxfp4_asymmetric_schedule() - - options.print_ir_after = "all" if is_debug else [] - options = set_default_run_config(options) - gemm = wave_compile(options, gemm, schedule) - - _run_mxfp_gemm_preshuffle_b(gemm, shape) - print("MXFP GEMM preshuffle-B 4-wave test passed!") + macrotiles = [ + (64, 192, 256), + (256, 256, 128), + ] + for block in macrotiles: + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, + block, + wave_shape=(1, 4), + dynamic_dims=False, + a_address_space=SHARED_ADDRESS_SPACE, + ) + + block_m, block_n, block_k = block + + options.schedule = SchedulingType.MANUAL + options.minimize_shared_allocs = True + options.linearize_shared_access = True + options.use_buffer_ops = True + # options.mlir_print_ir_after_all = True + options.dump_intermediates = f"build/b_preshuffle_dynamic/wave_gemm_mxfp4_preshuffle_MT{block_m}x{block_n}x{block_k}" + schedule = get_mxfp4_asymmetric_schedule() + + options.print_ir_after = "all" if is_debug else [] + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + _run_mxfp_gemm_preshuffle_b(gemm, shape) + print(f"MXFP GEMM preshuffle-B Dynamic 4-wave test passed for shape {shape}!") + + # gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) + # options.minimize_shared_allocs = True + # options.linearize_shared_access = True + # options.use_buffer_ops = True + # options.dump_intermediates = "build/intermediates" + # schedule = get_mxfp4_asymmetric_schedule() + + # options.print_ir_after = "all" if is_debug else [] + # options = set_default_run_config(options) + # gemm = wave_compile(options, gemm, schedule) + + # _run_mxfp_gemm_preshuffle_b(gemm, shape) + # print("MXFP GEMM preshuffle-B 4-wave test passed!") if __name__ == "__main__": diff --git a/wave_lang/kernel/wave/perf/benchmark_mxfp4.py b/wave_lang/kernel/wave/perf/benchmark_mxfp4.py index b56a2db6e..3ca10e245 100644 --- a/wave_lang/kernel/wave/perf/benchmark_mxfp4.py +++ b/wave_lang/kernel/wave/perf/benchmark_mxfp4.py @@ -22,6 +22,7 @@ import shutil import subprocess import sys +import tempfile import traceback from datetime import datetime from pathlib import Path @@ -29,11 +30,18 @@ import torch from wave_lang.kernel.wave.compile import wave_compile -from wave_lang.kernel.wave.schedules import get_mxfp4_dbuf_schedule +from wave_lang.kernel.wave.schedules import ( + get_mxfp4_asymmetric_schedule, + get_mxfp4_dbuf_schedule, +) +from wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType +from wave_lang.kernel.wave.templates import get_tagged_mxfp4_gemm_preshuffle_b from wave_lang.kernel.wave.templates.tagged_mxfp4_gemm import get_tagged_mxfp4_gemm from wave_lang.kernel.lang.global_symbols import * from wave_lang.kernel.wave.utils.run_utils import set_default_run_config from wave_lang.kernel.wave.utils.mxfp_utils import ( + b_preshuffle, + e8m0_shuffle, generate_gemm_afp4wfp4_inputs, torchScaledGemmMXFP4, ) @@ -51,13 +59,22 @@ def get_mxfp4_gemm_wave( shape: tuple[int, int, int], macrotiles: tuple[int, int, int], + dump_dir: Optional[Path] = None, + gemm_id: Optional[str] = None, ): - gemm, options = get_tagged_mxfp4_gemm(shape, macrotiles) - schedule = get_mxfp4_dbuf_schedule(use_stagger=True) + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, macrotiles, wave_shape=(1, 4) + ) + options.schedule = SchedulingType.MANUAL + options.minimize_shared_allocs = True + options.linearize_shared_access = True + options.use_buffer_ops = True + if dump_dir is not None and gemm_id is not None: + options.dump_intermediates = str(dump_dir / "inter" / gemm_id) + schedule = get_mxfp4_asymmetric_schedule() options = set_default_run_config(options) - - compiled_gemm = wave_compile(options, gemm, schedule) - return compiled_gemm + gemm = wave_compile(options, gemm, schedule) + return gemm # --------------------------------------------------------------------------- @@ -149,10 +166,16 @@ def run_worker( gemm_rt = get_mxfp4_gemm_wave(shape, macrotiles) device = torch.device("cuda") - x, w, x_scale, w_scale = generate_gemm_afp4wfp4_inputs((m, n, k), device) + x, w, x_scale, w_scale = generate_gemm_afp4wfp4_inputs((m, n, k)) w_t = w.T.contiguous() - wave_out = torch.empty(m, n, device=device, dtype=torch.float32) - inputs = (x, x_scale, w_t, w_scale, wave_out) + w_t_ps = b_preshuffle(w_t) + w_scale_ps = e8m0_shuffle(w_scale) + x, w_t_ps = x.cuda(), w_t_ps.cuda() + x_scale, w_scale_ps = x_scale.cuda(), w_scale_ps.cuda() + wave_out = torch.empty( + x.shape[0], w_t_ps.shape[0], device=device, dtype=torch.float32 + ) + inputs = (x, x_scale, w_t_ps, w_scale_ps, wave_out) mean_us = _run_torch_benchmark( gemm_rt, inputs, warmup_iters=warmup_iters, benchmark_iters=benchmark_iters @@ -165,13 +188,21 @@ def validate_mxfp4_gemm(shape: tuple[int, int, int], compiled_gemm) -> bool: m, n, k = shape try: device = torch.device("cuda") - x, w, x_scale, w_scale = generate_gemm_afp4wfp4_inputs(shape, device) + x, w, x_scale, w_scale = generate_gemm_afp4wfp4_inputs((m, n, k)) w_t = w.T.contiguous() - wave_out = torch.zeros(m, n, device=device, dtype=torch.float32) + w_t_ps = b_preshuffle(w_t) + w_scale_ps = e8m0_shuffle(w_scale) + x, w_t_ps = x.cuda(), w_t_ps.cuda() + x_scale, w_scale_ps = x_scale.cuda(), w_scale_ps.cuda() + wave_out = torch.empty( + x.shape[0], w_t_ps.shape[0], device=device, dtype=torch.float32 + ) - compiled_gemm(x, x_scale, w_t, w_scale, wave_out) + compiled_gemm(x, x_scale, w_t_ps, w_scale_ps, wave_out) torch_ref = torchScaledGemmMXFP4(x, w, x_scale, w_scale) - torch.testing.assert_close(wave_out, torch_ref, check_device=False) + torch.testing.assert_close( + wave_out, torch_ref, check_dtype=False, check_device=False + ) return True except Exception as e: raise RuntimeError(f"Validation failed for shape {shape}: {e}") from e @@ -274,7 +305,9 @@ def run_validate_and_benchmark( # Compile for validation (wave_runtime=True) try: - gemm_rt = get_mxfp4_gemm_wave(shape, macrotiles) + gemm_rt = get_mxfp4_gemm_wave( + shape, macrotiles, dump_dir=dump_dir, gemm_id=gemm_id + ) except Exception as e: raise BenchmarkError( f"Compilation failed for shape {shape}: {e}", stage="compile_failed" @@ -371,8 +404,8 @@ def parse_args(): p.add_argument( "--dump-dir", type=Path, - default=Path("/tmp/bench_mxfp4_dump"), - help="Directory for rocprof output (default: /tmp/bench_mxfp4_dump)", + default=None, + help="Directory for dumps (mlir, inter, rocprof). If not set, a timestamped subdir is created in the temp directory.", ) p.add_argument( "--kernel-regex", @@ -466,10 +499,11 @@ def main(): print("--shapes is required.", file=sys.stderr) sys.exit(1) - dump_dir = Path(args.dump_dir) - dump_dir.mkdir(parents=True, exist_ok=True) run_id = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - run_dump_dir = dump_dir / run_id + if args.dump_dir is not None: + run_dump_dir = Path(args.dump_dir) + else: + run_dump_dir = Path(tempfile.gettempdir()) / "bench_mxfp4_dump" / run_id run_dump_dir.mkdir(parents=True, exist_ok=True) print(f"Dump directory for this run: {run_dump_dir}") kernel_regex = args.kernel_regex @@ -497,35 +531,53 @@ def main(): failed_compilation = [] failed_validation = [] failed_benchmark = [] - for shape, macrotiles in shape_rows: - M, N, K = shape - mt_m, mt_n, mt_k = macrotiles - try: - runtime_us, tflops, status = run_validate_and_benchmark( - shape, - macrotiles, - run_dump_dir, - att_library_path, - kernel_regex=kernel_regex, - warmup_iters=warmup_iters, - benchmark_iters=benchmark_iters, - ) - except BenchmarkError as e: - print(f" {e}", file=sys.stderr) - traceback.print_exc() - status = e.stage - runtime_us, tflops = None, None - ok = status == "ok" - if status == "compile_failed": - failed_compilation.append((M, N, K)) - elif status == "validation_failed": - failed_validation.append((M, N, K)) - elif status == "benchmark_failed": - failed_benchmark.append((M, N, K)) - mean_us = runtime_us if runtime_us is not None else 0.0 - tflops_val = tflops if tflops is not None else 0.0 - results.append( - { + + # Open CSV and write header so we can append results as they come in (survives interrupt/crash) + args.output.parent.mkdir(parents=True, exist_ok=True) + csv_fieldnames = [ + "M", + "N", + "K", + "MT_M", + "MT_N", + "MT_K", + "runtime_us", + "tflops", + "ok", + ] + with open(args.output, "w", newline="") as csv_file: + csv_writer = csv.DictWriter(csv_file, fieldnames=csv_fieldnames) + csv_writer.writeheader() + csv_file.flush() + + for shape, macrotiles in shape_rows: + M, N, K = shape + mt_m, mt_n, mt_k = macrotiles + try: + runtime_us, tflops, status = run_validate_and_benchmark( + shape, + macrotiles, + run_dump_dir, + att_library_path, + kernel_regex=kernel_regex, + warmup_iters=warmup_iters, + benchmark_iters=benchmark_iters, + ) + except BenchmarkError as e: + print(f" {e}", file=sys.stderr) + traceback.print_exc() + status = e.stage + runtime_us, tflops = None, None + ok = status == "ok" + if status == "compile_failed": + failed_compilation.append((M, N, K, mt_m, mt_n, mt_k)) + elif status == "validation_failed": + failed_validation.append((M, N, K, mt_m, mt_n, mt_k)) + elif status == "benchmark_failed": + failed_benchmark.append((M, N, K, mt_m, mt_n, mt_k)) + mean_us = runtime_us if runtime_us is not None else 0.0 + tflops_val = tflops if tflops is not None else 0.0 + row = { "M": M, "N": N, "K": K, @@ -536,11 +588,13 @@ def main(): "tflops": tflops_val, "ok": ok, } - ) - status_str = "ok" if ok else status - print( - f" ({M}, {N}, {K}) MT({mt_m},{mt_n},{mt_k}): {mean_us:.2f} us, {tflops_val:.4f} TFLOPs [{status_str}]" - ) + results.append(row) + csv_writer.writerow(row) + csv_file.flush() + status_str = "ok" if ok else status + print( + f" ({M}, {N}, {K}) MT({mt_m},{mt_n},{mt_k}): {mean_us:.2f} us, {tflops_val:.4f} TFLOPs [{status_str}]" + ) if failed_compilation: print( @@ -573,37 +627,6 @@ def main(): f"{best['runtime_us']:.2f} us, {best['tflops']:.4f} TFLOPs" ) - args.output.parent.mkdir(parents=True, exist_ok=True) - with open(args.output, "w", newline="") as f: - w = csv.DictWriter( - f, - fieldnames=[ - "M", - "N", - "K", - "MT_M", - "MT_N", - "MT_K", - "runtime_us", - "tflops", - "ok", - ], - ) - w.writeheader() - for r in results: - w.writerow( - { - "M": r["M"], - "N": r["N"], - "K": r["K"], - "MT_M": r["MT_M"], - "MT_N": r["MT_N"], - "MT_K": r["MT_K"], - "runtime_us": r["runtime_us"], - "tflops": r["tflops"], - "ok": r["ok"], - } - ) print(f"Results written to {args.output}") diff --git a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py index 6d186c432..f532be6b8 100755 --- a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py +++ b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py @@ -35,6 +35,7 @@ def get_tagged_mxfp4_gemm( mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4, a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE, b_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE, + dynamic_dims: bool = False, reorder_workgroups=True, group_size_n=32, ): @@ -117,6 +118,16 @@ def repeat( N: shape[1], K: shape[2], } + + dynamic_symbols = [] + if dynamic_dims: + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(K) + del hyperparams[M] + del hyperparams[N] + del hyperparams[K] + hyperparams.update(get_default_scheduling_params()) options = WaveCompileOptions( @@ -125,6 +136,7 @@ def repeat( schedule=SchedulingType.MANUAL, use_global_to_shared=True, minimize_shared_allocs=False, + dynamic_symbols=dynamic_symbols, ) return gemm, options @@ -136,6 +148,7 @@ def get_tagged_mxfp4_gemm_preshuffle_b( wave_shape: tuple[int, int] = (2, 2), mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4, a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE, + dynamic_dims: bool = False, reorder_workgroups=True, group_size_n=32, ): @@ -186,6 +199,9 @@ def get_tagged_mxfp4_gemm_preshuffle_b( constraints += [tkw.ReorderingConstraint(new_wg0, 0)] constraints += [tkw.ReorderingConstraint(new_wg1, 1)] + if dynamic_dims: + constraints += [tkw.Assumption(K > BLOCK_K * 4)] + # --- B data preshuffle mapping (aiter shuffle_weight) --- # Each 16-row x 32-byte tile is reordered from [n, k_sub, k_elem] to # [k_sub, n, k_elem] so a contiguous 256-byte read fetches one K-chunk @@ -272,7 +288,7 @@ def repeat( ) -> tkl.Register[M, N, tkl.f32]: a_reg = tkw.read(a, tag="read_a") a_reg = tkw.bitcast(a_reg, tkl.f4e2m1fn, tag="bitcast_a") - a_scale_reg = tkw.read(a_scale, mapping=a_scale_mapping, tag="read_a_scale") + a_scale_reg = tkw.read(a_scale, tag="read_a_scale") a_scale_reg = tkw.bitcast(a_scale_reg, tkl.f8e8m0fnu, tag="bitcast_a_scale") b_reg = tkw.read(b, mapping=b_preshuffle_mapping, tag="read_b") b_reg = tkw.bitcast(b_reg, tkl.f4e2m1fn, tag="bitcast_b") @@ -300,12 +316,24 @@ def repeat( } hyperparams.update(get_default_scheduling_params()) + dynamic_symbols = [] + if dynamic_dims: + dynamic_symbols.append(M) + dynamic_symbols.append(N) + dynamic_symbols.append(K) + del hyperparams[M] + del hyperparams[N] + del hyperparams[K] + hyperparams[K_PACKED] = K // 2 + hyperparams[K_SCALE_SHUFFLED] = (((K // 32) + 7) // 8) * 8 + options = WaveCompileOptions( subs=hyperparams, canonicalize=True, schedule=SchedulingType.MANUAL, use_global_to_shared=True, minimize_shared_allocs=False, + dynamic_symbols=dynamic_symbols, ) return gemm, options