From 817c4196f04f2a503cc9da6b8ab56c4a0f996f42 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 12 Feb 2026 15:53:12 +0100 Subject: [PATCH 01/10] interleaved intrin Signed-off-by: Ivan Butygin --- tests/kernel/wave_gemm_mxfp_test.py | 94 +++++++++++++++++++ .../kernel/compiler/wave_codegen/emitter.py | 4 +- .../kernel/compiler/wave_codegen/handlers.py | 8 +- wave_lang/kernel/wave/constraints.py | 45 +++++++-- 4 files changed, 138 insertions(+), 13 deletions(-) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 15925f0dea..93c4f6d87d 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -476,6 +476,100 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: torch.testing.assert_close(torch_out, out, atol=2e-3, rtol=1e-3, check_dtype=False) +@require_e2e +@require_cdna4 +@pytest.mark.parametrize("shape", [(1024, 1024, 1024)]) +@pytest.mark.parametrize( + "mfma_variant", + [ + ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED, + ], +) +@pytest.mark.parametrize( + "enable_scheduling", + [ + SchedulingType.NONE, + ], +) +def testScaledGemmMXFP4ScalesInterleaved( + shape: tuple[int], + mfma_variant: ScaledMMAType, + enable_scheduling: SchedulingType, +): + # Input sizes + M = tkl.sym.M + N = tkl.sym.N + K = tkl.sym.K + # Workgroup tile sizes + BLOCK_M = tkl.sym.BLOCK_M + BLOCK_N = tkl.sym.BLOCK_N + BLOCK_K = tkl.sym.BLOCK_K + # Address space (for GPU, shared(1) or global(0)) + ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE + + # Expose user-constraints + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [tkw.HardwareConstraint(threads_per_wave=64, mma_type=mfma_variant)] + + @tkw.wave(constraints) + def gemm( + a: tkl.Memory[M, K / 2, ADDRESS_SPACE, tkl.i8], + a_scale: tkl.Memory[M, K / 32, ADDRESS_SPACE, tkl.i8], + b: tkl.Memory[N, K / 2, ADDRESS_SPACE, tkl.i8], + b_scale: tkl.Memory[N, K / 32, ADDRESS_SPACE, tkl.i8], + c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.iterate(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a) + a_reg = tkw.bitcast(a_reg, tkl.f4e2m1fn) + a_scale_reg = tkw.read(a_scale) + a_scale_reg = tkw.bitcast(a_scale_reg, tkl.f8e8m0fnu) + b_reg = tkw.read(b) + b_reg = tkw.bitcast(b_reg, tkl.f4e2m1fn) + b_scale_reg = tkw.read(b_scale) + b_scale_reg = tkw.bitcast(b_scale_reg, tkl.f8e8m0fnu) + acc = tkw.scaled_mma(a_reg, a_scale_reg, b_reg, b_scale_reg, acc) + return acc + + tkw.write(repeat, c) + + hyperparams = { + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + BLOCK_M: 32, + BLOCK_N: 32, + BLOCK_K: 256, + M: shape[0], + N: shape[1], + K: shape[2], + } + hyperparams.update(get_default_scheduling_params()) + + options = WaveCompileOptions( + subs=hyperparams, + canonicalize=True, + schedule=enable_scheduling, + ) + options = set_default_run_config(options) + gemm = wave_compile(options, gemm) + + x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) + out = device_zeros(x.shape[0], w.shape[1], dtype=torch.float32) + + w_t = w.T.contiguous() + gemm(x, x_scales, w_t, w_scales, out) + torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales) + + torch.testing.assert_close(torch_out, out, check_dtype=False) + + @require_e2e @require_cdna4 @pytest.mark.parametrize("shape", [(1024, 1024, 1024), (8192, 8192, 8192)]) diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index 970d6b52b0..d1ed6b1690 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -1200,14 +1200,14 @@ def cast_vector(emitter: WaveEmitter, value, *, element_type: Optional[IrType] = return vector_d.broadcast(vector_type, value) -def cast_scalar(emitter: WaveEmitter, value): +def cast_scalar(emitter: WaveEmitter, value: Value, position: int = 0) -> Value: proxy_value = cast_py_value(emitter, value) value = proxy_value.ir_value # After scalar promotion, promote to vector. if isinstance(value.type, VectorType): # Vector -> scalar. - return vector_d.extract(value, static_position=[0], dynamic_position=[]) + return vector_d.extract(value, static_position=[position], dynamic_position=[]) else: # Already a scalar. Coerce or return. # No target element_type. diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index 226a8f9fad..cde78710c1 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -515,7 +515,13 @@ def handle_scaled_mma(emitter: WaveEmitter, node: fx.Node): scales = [cast_vector(emitter, val) for val in [lhs_scale, rhs_scale]] result = emit_wmma_scaled(m, n, k, acc, values, scales) else: - scales = [cast_scalar(emitter, val) for val in [lhs_scale, rhs_scale]] + pos = [0, 0] + if mma_type == ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED: + pos[1] = 1 + scales = [ + cast_scalar(emitter, val, pos) + for pos, val in zip(pos, [lhs_scale, rhs_scale]) + ] result = emit_mfma_scaled(m, n, k, acc, values, scales) emitter.bind_node_proxy(node, IRProxyValue(result)) diff --git a/wave_lang/kernel/wave/constraints.py b/wave_lang/kernel/wave/constraints.py index d49538b8a8..ee712da423 100644 --- a/wave_lang/kernel/wave/constraints.py +++ b/wave_lang/kernel/wave/constraints.py @@ -81,6 +81,8 @@ class ScaledMMAType(Enum): F32_16x16x128_F8F6F4 = 0x1340 F32_32x32x64_F8F6F4 = 0x1341 + F32_16x16x128_F8F6F4_SCALES_INTERLEAVED = 0x1342 + # Intrinsics introduced in GFX1250 GFX1250_F32_16x16x128_F8F6F4 = 0x1940 @@ -290,7 +292,10 @@ def mma_matrix_shapes( | MMAType.I32_32x32x16_I8 ): return (32, 32, 16) - case ScaledMMAType.F32_16x16x128_F8F6F4: + case ( + ScaledMMAType.F32_16x16x128_F8F6F4 + | ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED + ): return (16, 16, 128) case ScaledMMAType.F32_32x32x64_F8F6F4: return (32, 32, 64) @@ -435,19 +440,28 @@ def mma_index_offset(self, mma_type: Optional[MMAType | ScaledMMAType]): ), ), # K ] - case ScaledMMAType.F32_32x32x64_F8F6F4: + case ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED: offset = [ Piecewise( - (lane % 32, ~MMA_ACC), + (lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC) + ), # M + lane % 16, # N + Piecewise( ( - (8 * floor(GPR_NUM / 4) % 32) - + 4 * floor(lane / 32) - + (GPR_NUM % 4), - MMA_ACC, + 64 * floor(GPR_NUM / 16) + + 16 * floor(lane / 16) + + (GPR_NUM % 16), + ~(MMA_LHS_SCALE | MMA_RHS_SCALE | MMA_SCALE_FP4), ), - ), # M - lane % 32, # N - 32 * floor(lane / 32), # K + ( + 32 * floor(lane / 16) - 32, + (MMA_RHS_SCALE), + ), + ( + 32 * floor(lane / 16), + True, + ), + ), # K ] case ScaledMMAType.GFX1250_F32_16x16x128_F8F6F4: offset = [ @@ -639,6 +653,17 @@ def apply_mma_mapping( 1, # N 1, # K ] + case ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED: + size = [ + Piecewise((1, ~MMA_ACC), (4, MMA_ACC)), # M + 1, # N + Piecewise((64, MMA_LHS_SCALE | MMA_RHS_SCALE), (32, True)), # K + ] + stride = [ + Piecewise((1, ~MMA_ACC), (16, MMA_ACC)), # M + 1, # N + 1, # K + ] case ScaledMMAType.F32_32x32x64_F8F6F4: size = [ Piecewise((1, ~MMA_ACC), (16, MMA_ACC)), # M From a376fabc6995728d78e5ead27a667ed9ad0d49df Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 12 Feb 2026 22:32:54 +0100 Subject: [PATCH 02/10] fixes Signed-off-by: Ivan Butygin --- .../kernel/compiler/wave_codegen/handlers.py | 30 ++++++++++++------- wave_lang/kernel/wave/constraints.py | 6 +--- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index cde78710c1..75382bfcf7 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -437,13 +437,21 @@ def handle_mma(emitter: WaveEmitter, node: fx.Node): def emit_mfma_scaled( - m: int, n: int, k: int, acc: Value, values: list[Value], scales: list[Value] + m: int, + n: int, + k: int, + acc: Value, + values: list[Value], + scales: list[Value], + idx_a: int, + idx_b: int, ) -> Value: - m = get_constant_attr(m, IntegerType.get_signless(32)) - n = get_constant_attr(n, IntegerType.get_signless(32)) - k = get_constant_attr(k, IntegerType.get_signless(32)) - idx_a = get_constant_attr(0, IntegerType.get_signless(32)) - idx_b = get_constant_attr(0, IntegerType.get_signless(32)) + i32 = IntegerType.get_signless(32) + m = get_constant_attr(m, i32) + n = get_constant_attr(n, i32) + k = get_constant_attr(k, i32) + idx_a = get_constant_attr(idx_a, i32) + idx_b = get_constant_attr(idx_b, i32) result = amdgpu_d.scaled_mfma( m=m, @@ -518,11 +526,11 @@ def handle_scaled_mma(emitter: WaveEmitter, node: fx.Node): pos = [0, 0] if mma_type == ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED: pos[1] = 1 - scales = [ - cast_scalar(emitter, val, pos) - for pos, val in zip(pos, [lhs_scale, rhs_scale]) - ] - result = emit_mfma_scaled(m, n, k, acc, values, scales) + scales = [cast_vector(emitter, val) for val in [lhs_scale, rhs_scale]] + else: + scales = [cast_scalar(emitter, val) for val in [lhs_scale, rhs_scale]] + + result = emit_mfma_scaled(m, n, k, acc, values, scales, pos[0], pos[1]) emitter.bind_node_proxy(node, IRProxyValue(result)) diff --git a/wave_lang/kernel/wave/constraints.py b/wave_lang/kernel/wave/constraints.py index ee712da423..45aa55d624 100644 --- a/wave_lang/kernel/wave/constraints.py +++ b/wave_lang/kernel/wave/constraints.py @@ -453,10 +453,6 @@ def mma_index_offset(self, mma_type: Optional[MMAType | ScaledMMAType]): + (GPR_NUM % 16), ~(MMA_LHS_SCALE | MMA_RHS_SCALE | MMA_SCALE_FP4), ), - ( - 32 * floor(lane / 16) - 32, - (MMA_RHS_SCALE), - ), ( 32 * floor(lane / 16), True, @@ -657,7 +653,7 @@ def apply_mma_mapping( size = [ Piecewise((1, ~MMA_ACC), (4, MMA_ACC)), # M 1, # N - Piecewise((64, MMA_LHS_SCALE | MMA_RHS_SCALE), (32, True)), # K + Piecewise((128, MMA_LHS_SCALE | MMA_RHS_SCALE), (32, True)), # K ] stride = [ Piecewise((1, ~MMA_ACC), (16, MMA_ACC)), # M From f2608cfc19eede7c23f742657ac33452e1858635 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 12 Feb 2026 23:07:16 +0100 Subject: [PATCH 03/10] pass stub Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/compile.py | 2 + .../kernel/wave/interleave_scaled_mma.py | 76 +++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 wave_lang/kernel/wave/interleave_scaled_mma.py diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 7d97bc2721..edfac705da 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -64,6 +64,7 @@ from .hardware_transpose import mark_hardware_transpose_candidates from .hoisting import hoist_loop_invariant_ops from .in_thread_transpose import in_thread_transpose +from .interleave_scaled_mma import interleave_scaled_mma from .location_check_pass import location_check_pass from .memory_analysis.minimize_shared_allocs import minimize_shared_allocs from .minimize_global_loads import minimize_global_loads @@ -723,6 +724,7 @@ def _trace_launchable_and_get_kernel_signature( # Optimizations. if options.optimization_level: graph_passes += [ + partial(interleave_scaled_mma, trace, launchable.constraints), partial(hoist_loop_invariant_ops, trace, launchable.constraints), partial(tensor_load_to_shared, trace, launchable.constraints, options), partial(multicast, trace, launchable.constraints, options), diff --git a/wave_lang/kernel/wave/interleave_scaled_mma.py b/wave_lang/kernel/wave/interleave_scaled_mma.py new file mode 100644 index 0000000000..a02add4a07 --- /dev/null +++ b/wave_lang/kernel/wave/interleave_scaled_mma.py @@ -0,0 +1,76 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch.fx as fx + +from .._support.tracing import CapturedTrace +from ..ops.wave_ops import ( + NewScalar, + Reshape, + ScaledMMA, + get_custom, +) +from .constraints import ( + Constraint, + ScaledMMAType, +) +from .utils.general_utils import get_hardware_constraint + + +def interleave_scaled_mma(trace: CapturedTrace, constraints: list[Constraint]): + """ + Transforms ScaledMMA operations using F32_16x16x128_F8F6F4 into + F32_16x16x128_F8F6F4_SCALES_INTERLEAVED by combining separate + lhs_scale and rhs_scale into a single 4-element vector + [a_scale, b_scale, 0, 0] passed as both scale inputs. + + This reduces register pressure by packing both scale values + into a single VGPR, using byte index 0 for a_scale and + byte index 1 for b_scale. + """ + hardware_constraint = get_hardware_constraint(constraints) + + def is_target_scaled_mma(node: fx.Node) -> bool: + custom = get_custom(node) + if not isinstance(custom, ScaledMMA): + return False + mma_type = custom.mma_type or hardware_constraint.mma_type + return mma_type == ScaledMMAType.F32_16x16x128_F8F6F4 + + nodes = trace.walk(is_target_scaled_mma) + if not nodes: + return + + for node in nodes: + mma_op = get_custom(node) + scale_dtype = get_custom(mma_op.lhs_scale).type.dtype + + with mma_op.graph.inserting_before(mma_op.fx_node): + # Create zero padding scalars. + zero0 = NewScalar(0.0, scale_dtype).add_to_graph( + mma_op.graph, loc=mma_op.location + ) + zero1 = NewScalar(0.0, scale_dtype).add_to_graph( + mma_op.graph, loc=mma_op.location + ) + + # Combine scales: [lhs_scale, rhs_scale, 0, 0]. + combined = Reshape( + [mma_op.lhs_scale, mma_op.rhs_scale, zero0, zero1], + {}, + ).add_to_graph(mma_op.graph, loc=mma_op.location) + + # Create new ScaledMMA with interleaved type. + new_mma = ScaledMMA( + mma_op.lhs, + combined, + mma_op.rhs, + combined, + mma_op.acc, + ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED, + ).add_to_graph(mma_op.graph, loc=mma_op.location) + + mma_op.replace_all_uses_with(new_mma) From 249eeb59d97fbf7f3bb73c52fe72c8faf805070d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 12 Feb 2026 23:43:39 +0100 Subject: [PATCH 04/10] fixes Signed-off-by: Ivan Butygin --- tests/kernel/wave_gemm_mxfp_test.py | 2 +- wave_lang/kernel/compiler/wave_codegen/handlers.py | 14 +++++++++----- wave_lang/kernel/wave/interleave_scaled_mma.py | 12 ++++++------ wave_lang/kernel/wave/water.py | 10 +++++++++- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 93c4f6d87d..305b26dac2 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -482,7 +482,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: @pytest.mark.parametrize( "mfma_variant", [ - ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED, + ScaledMMAType.F32_16x16x128_F8F6F4, ], ) @pytest.mark.parametrize( diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index 75382bfcf7..03f346c9b2 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -2180,18 +2180,22 @@ def handle_reshape(emitter: WaveEmitter, node: fx.Node): # Determine whether to extract or combine. if len(args) > 1: vectors = [cast_vector(emitter, arg) for arg in args] - shape = vectors[0].type.shape[0] - if shape == 1: + shape = vectors[0].type.shape[0] if vectors[0].type.rank > 0 else 0 + if shape <= 1: # If source is 1-element vector or scalar (which will be casted to - # 1-element vector by `cast_vector`), we can construct the result + # 0-d vector by `cast_vector`), we can construct the result # vector using `extract` and a single `from_elements` op instead of # series of `insert_strided_slice` ops. values = [ - vector_d.extract(vector, static_position=[0], dynamic_position=[]) + vector_d.extract( + vector, + static_position=[] if vector.type.rank == 0 else [0], + dynamic_position=[], + ) for vector in vectors ] element_type = vectors[0].type.element_type - vector_type = VectorType.get([shape * len(args)], element_type) + vector_type = VectorType.get([len(args)], element_type) result = vector_d.from_elements(vector_type, values) emitter.bind_node_proxy(node, IRProxyValue(result)) return diff --git a/wave_lang/kernel/wave/interleave_scaled_mma.py b/wave_lang/kernel/wave/interleave_scaled_mma.py index a02add4a07..2ebba5174d 100644 --- a/wave_lang/kernel/wave/interleave_scaled_mma.py +++ b/wave_lang/kernel/wave/interleave_scaled_mma.py @@ -49,17 +49,14 @@ def is_target_scaled_mma(node: fx.Node) -> bool: scale_dtype = get_custom(mma_op.lhs_scale).type.dtype with mma_op.graph.inserting_before(mma_op.fx_node): - # Create zero padding scalars. - zero0 = NewScalar(0.0, scale_dtype).add_to_graph( - mma_op.graph, loc=mma_op.location - ) - zero1 = NewScalar(0.0, scale_dtype).add_to_graph( + # Create zero padding scalar. + zero = NewScalar(0.0, scale_dtype).add_to_graph( mma_op.graph, loc=mma_op.location ) # Combine scales: [lhs_scale, rhs_scale, 0, 0]. combined = Reshape( - [mma_op.lhs_scale, mma_op.rhs_scale, zero0, zero1], + [mma_op.lhs_scale, mma_op.rhs_scale, zero, zero], {}, ).add_to_graph(mma_op.graph, loc=mma_op.location) @@ -72,5 +69,8 @@ def is_target_scaled_mma(node: fx.Node) -> bool: mma_op.acc, ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED, ).add_to_graph(mma_op.graph, loc=mma_op.location) + new_mma.index = mma_op.index + new_mma.vector_shapes = mma_op.vector_shapes mma_op.replace_all_uses_with(new_mma) + mma_op.erase() diff --git a/wave_lang/kernel/wave/water.py b/wave_lang/kernel/wave/water.py index 62eed0425f..048bfa8ece 100644 --- a/wave_lang/kernel/wave/water.py +++ b/wave_lang/kernel/wave/water.py @@ -383,7 +383,15 @@ def add_opt(pipeline): def add_transform(transform: str, entry_point: str) -> tuple[str, dict[str, Any]]: nonlocal mlir_asm - # Erase the last occurrence of '}' from mlir_asm which closes the module operation + # Add transform.with_named_sequence attribute to the module if missing. + attr_name = "transform.with_named_sequence" + if attr_name not in mlir_asm: + mlir_asm = mlir_asm.replace( + "gpu.container_module", + "gpu.container_module, " + attr_name, + 1, + ) + # Erase the last occurrence of '}' from mlir_asm which closes the module operation. last_close = mlir_asm.rfind("}") if last_close != -1: mlir_asm = mlir_asm[:last_close] From 688e4d25a94936096a095f27073286dbc4cf2b3f Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 12 Feb 2026 23:53:35 +0100 Subject: [PATCH 05/10] move pass and test Signed-off-by: Ivan Butygin --- tests/kernel/wave_gemm_mxfp_test.py | 37 ++++++++++++++++++++++------- wave_lang/kernel/wave/compile.py | 2 +- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 305b26dac2..26263f307d 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -337,35 +337,52 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path): # We encode the exact registers and wait counts as we want to know if # they suddenly change due to backend or upstream MLIR changes. if use_water_backend: - vgpr_count = 164 + vgpr_count = 154 vgpr_spill_count = 0 - sgpr_count = 57 + sgpr_count = 61 sgpr_spill_count = 0 waitcounts = [ "s_waitcnt lgkmcnt(0)", "s_waitcnt vmcnt(0)", - "s_waitcnt vmcnt(0) lgkmcnt(0)", + "s_waitcnt lgkmcnt(6)", + "s_waitcnt lgkmcnt(5)", + "s_waitcnt lgkmcnt(14)", + "s_waitcnt lgkmcnt(6)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(0)", "s_waitcnt vmcnt(0)", "s_waitcnt lgkmcnt(7)", + "s_waitcnt lgkmcnt(8)", "s_waitcnt lgkmcnt(5)", - "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", "s_waitcnt lgkmcnt(3)", "s_waitcnt lgkmcnt(1)", "s_waitcnt lgkmcnt(0)", ] else: - vgpr_count = 164 + vgpr_count = 142 vgpr_spill_count = 0 - sgpr_count = 59 + sgpr_count = 61 sgpr_spill_count = 0 waitcounts = [ "s_waitcnt lgkmcnt(0)", "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(6)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", "s_waitcnt vmcnt(0) lgkmcnt(0)", "s_waitcnt vmcnt(0)", - "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(7)", + "s_waitcnt lgkmcnt(8)", "s_waitcnt lgkmcnt(5)", - "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(0)", "s_waitcnt lgkmcnt(3)", "s_waitcnt lgkmcnt(1)", "s_waitcnt lgkmcnt(0)", @@ -387,6 +404,10 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path): metadata.waitcnt_ops == waitcounts ), f"Expected {waitcounts} waitcnt operations, got {metadata.waitcnt_ops}" + # Verify interleaved scale instructions are generated (op_sel:[0,1,0] + # means byte 0 for a_scale and byte 1 for b_scale in the same VGPR). + assert "op_sel:[0,1,0]" in text, "Expected interleaved scale mfma instructions" + @require_e2e @require_cdna4 diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index edfac705da..6c5ba059ca 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -724,7 +724,6 @@ def _trace_launchable_and_get_kernel_signature( # Optimizations. if options.optimization_level: graph_passes += [ - partial(interleave_scaled_mma, trace, launchable.constraints), partial(hoist_loop_invariant_ops, trace, launchable.constraints), partial(tensor_load_to_shared, trace, launchable.constraints, options), partial(multicast, trace, launchable.constraints, options), @@ -803,6 +802,7 @@ def _trace_launchable_and_get_kernel_signature( trace, options.minimize_shared_allocs, ), + partial(interleave_scaled_mma, trace, launchable.constraints), ] graph_passes += [ partial( From 150fdc73fff40e39d5f5dde368fa95c90113d8fc Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 13 Feb 2026 00:03:56 +0100 Subject: [PATCH 06/10] remove test Signed-off-by: Ivan Butygin --- tests/kernel/wave_gemm_mxfp_test.py | 94 ----------------------------- 1 file changed, 94 deletions(-) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 26263f307d..4c17f11ac8 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -497,100 +497,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: torch.testing.assert_close(torch_out, out, atol=2e-3, rtol=1e-3, check_dtype=False) -@require_e2e -@require_cdna4 -@pytest.mark.parametrize("shape", [(1024, 1024, 1024)]) -@pytest.mark.parametrize( - "mfma_variant", - [ - ScaledMMAType.F32_16x16x128_F8F6F4, - ], -) -@pytest.mark.parametrize( - "enable_scheduling", - [ - SchedulingType.NONE, - ], -) -def testScaledGemmMXFP4ScalesInterleaved( - shape: tuple[int], - mfma_variant: ScaledMMAType, - enable_scheduling: SchedulingType, -): - # Input sizes - M = tkl.sym.M - N = tkl.sym.N - K = tkl.sym.K - # Workgroup tile sizes - BLOCK_M = tkl.sym.BLOCK_M - BLOCK_N = tkl.sym.BLOCK_N - BLOCK_K = tkl.sym.BLOCK_K - # Address space (for GPU, shared(1) or global(0)) - ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE - - # Expose user-constraints - constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] - constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] - constraints += [tkw.TilingConstraint(K, BLOCK_K)] - constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] - constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] - - constraints += [tkw.HardwareConstraint(threads_per_wave=64, mma_type=mfma_variant)] - - @tkw.wave(constraints) - def gemm( - a: tkl.Memory[M, K / 2, ADDRESS_SPACE, tkl.i8], - a_scale: tkl.Memory[M, K / 32, ADDRESS_SPACE, tkl.i8], - b: tkl.Memory[N, K / 2, ADDRESS_SPACE, tkl.i8], - b_scale: tkl.Memory[N, K / 32, ADDRESS_SPACE, tkl.i8], - c: tkl.Memory[M, N, GLOBAL_ADDRESS_SPACE, tkl.f32], - ): - c_reg = tkl.Register[M, N, tkl.f32](0.0) - - @tkw.iterate(K, init_args=[c_reg]) - def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: - a_reg = tkw.read(a) - a_reg = tkw.bitcast(a_reg, tkl.f4e2m1fn) - a_scale_reg = tkw.read(a_scale) - a_scale_reg = tkw.bitcast(a_scale_reg, tkl.f8e8m0fnu) - b_reg = tkw.read(b) - b_reg = tkw.bitcast(b_reg, tkl.f4e2m1fn) - b_scale_reg = tkw.read(b_scale) - b_scale_reg = tkw.bitcast(b_scale_reg, tkl.f8e8m0fnu) - acc = tkw.scaled_mma(a_reg, a_scale_reg, b_reg, b_scale_reg, acc) - return acc - - tkw.write(repeat, c) - - hyperparams = { - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - BLOCK_M: 32, - BLOCK_N: 32, - BLOCK_K: 256, - M: shape[0], - N: shape[1], - K: shape[2], - } - hyperparams.update(get_default_scheduling_params()) - - options = WaveCompileOptions( - subs=hyperparams, - canonicalize=True, - schedule=enable_scheduling, - ) - options = set_default_run_config(options) - gemm = wave_compile(options, gemm) - - x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) - out = device_zeros(x.shape[0], w.shape[1], dtype=torch.float32) - - w_t = w.T.contiguous() - gemm(x, x_scales, w_t, w_scales, out) - torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales) - - torch.testing.assert_close(torch_out, out, check_dtype=False) - - @require_e2e @require_cdna4 @pytest.mark.parametrize("shape", [(1024, 1024, 1024), (8192, 8192, 8192)]) From 82c6c2b55b6851d85a320fbc266ea31427d21507 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 13 Feb 2026 00:07:13 +0100 Subject: [PATCH 07/10] fix Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/constraints.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/wave_lang/kernel/wave/constraints.py b/wave_lang/kernel/wave/constraints.py index 45aa55d624..836f762e8b 100644 --- a/wave_lang/kernel/wave/constraints.py +++ b/wave_lang/kernel/wave/constraints.py @@ -440,24 +440,22 @@ def mma_index_offset(self, mma_type: Optional[MMAType | ScaledMMAType]): ), ), # K ] - case ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED: + case ( + ScaledMMAType.F32_32x32x64_F8F6F4 + | ScaledMMAType.F32_32x32x64_F8F6F4_SCALES_INTERLEAVED + ): offset = [ Piecewise( - (lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC) - ), # M - lane % 16, # N - Piecewise( - ( - 64 * floor(GPR_NUM / 16) - + 16 * floor(lane / 16) - + (GPR_NUM % 16), - ~(MMA_LHS_SCALE | MMA_RHS_SCALE | MMA_SCALE_FP4), - ), + (lane % 32, ~MMA_ACC), ( - 32 * floor(lane / 16), - True, + (8 * floor(GPR_NUM / 4) % 32) + + 4 * floor(lane / 32) + + (GPR_NUM % 4), + MMA_ACC, ), - ), # K + ), # M + lane % 32, # N + 32 * floor(lane / 32), # K ] case ScaledMMAType.GFX1250_F32_16x16x128_F8F6F4: offset = [ From cbbc59bc4eeb370ca6efc2810e86db1c275a2e73 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 13 Feb 2026 00:26:14 +0100 Subject: [PATCH 08/10] upper intrin Signed-off-by: Ivan Butygin --- wave_lang/kernel/compiler/wave_codegen/handlers.py | 5 ++++- wave_lang/kernel/wave/constraints.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index 03f346c9b2..a251fea5a0 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -525,7 +525,10 @@ def handle_scaled_mma(emitter: WaveEmitter, node: fx.Node): else: pos = [0, 0] if mma_type == ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED: - pos[1] = 1 + pos = [0, 1] + scales = [cast_vector(emitter, val) for val in [lhs_scale, rhs_scale]] + elif mma_type == ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER: + pos = [2, 3] scales = [cast_vector(emitter, val) for val in [lhs_scale, rhs_scale]] else: scales = [cast_scalar(emitter, val) for val in [lhs_scale, rhs_scale]] diff --git a/wave_lang/kernel/wave/constraints.py b/wave_lang/kernel/wave/constraints.py index 836f762e8b..2a357b9e6f 100644 --- a/wave_lang/kernel/wave/constraints.py +++ b/wave_lang/kernel/wave/constraints.py @@ -82,6 +82,7 @@ class ScaledMMAType(Enum): F32_32x32x64_F8F6F4 = 0x1341 F32_16x16x128_F8F6F4_SCALES_INTERLEAVED = 0x1342 + F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER = 0x1343 # Intrinsics introduced in GFX1250 GFX1250_F32_16x16x128_F8F6F4 = 0x1940 @@ -295,6 +296,7 @@ def mma_matrix_shapes( case ( ScaledMMAType.F32_16x16x128_F8F6F4 | ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED + | ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER ): return (16, 16, 128) case ScaledMMAType.F32_32x32x64_F8F6F4: @@ -421,7 +423,11 @@ def mma_index_offset(self, mma_type: Optional[MMAType | ScaledMMAType]): + 4 * floor(lane / 32) + (GPR_NUM % 4), # K ] - case ScaledMMAType.F32_16x16x128_F8F6F4: + case ( + ScaledMMAType.F32_16x16x128_F8F6F4 + | ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED + | ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER + ): offset = [ Piecewise( (lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC) @@ -647,7 +653,10 @@ def apply_mma_mapping( 1, # N 1, # K ] - case ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED: + case ( + ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED + | ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER + ): size = [ Piecewise((1, ~MMA_ACC), (4, MMA_ACC)), # M 1, # N From e7b029e718a7e3179d9dbf6e68486dd0c5ca5c6c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 13 Feb 2026 00:36:15 +0100 Subject: [PATCH 09/10] interleave upper Signed-off-by: Ivan Butygin --- tests/kernel/wave_gemm_mxfp_test.py | 32 +++-- .../kernel/wave/interleave_scaled_mma.py | 120 +++++++++++++----- 2 files changed, 103 insertions(+), 49 deletions(-) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 4c17f11ac8..28cad65c45 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -337,7 +337,7 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path): # We encode the exact registers and wait counts as we want to know if # they suddenly change due to backend or upstream MLIR changes. if use_water_backend: - vgpr_count = 154 + vgpr_count = 146 vgpr_spill_count = 0 sgpr_count = 61 sgpr_spill_count = 0 @@ -345,15 +345,14 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path): "s_waitcnt lgkmcnt(0)", "s_waitcnt vmcnt(0)", "s_waitcnt lgkmcnt(6)", - "s_waitcnt lgkmcnt(5)", - "s_waitcnt lgkmcnt(14)", - "s_waitcnt lgkmcnt(6)", - "s_waitcnt lgkmcnt(0)", - "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(2)", "s_waitcnt lgkmcnt(0)", + "s_waitcnt vmcnt(0) lgkmcnt(0)", "s_waitcnt vmcnt(0)", "s_waitcnt lgkmcnt(7)", - "s_waitcnt lgkmcnt(8)", + "s_waitcnt lgkmcnt(6)", + "s_waitcnt lgkmcnt(6)", "s_waitcnt lgkmcnt(5)", "s_waitcnt lgkmcnt(3)", "s_waitcnt lgkmcnt(1)", @@ -363,22 +362,26 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path): "s_waitcnt lgkmcnt(0)", ] else: - vgpr_count = 142 + vgpr_count = 140 vgpr_spill_count = 0 - sgpr_count = 61 + sgpr_count = 59 sgpr_spill_count = 0 waitcounts = [ "s_waitcnt lgkmcnt(0)", "s_waitcnt vmcnt(0)", "s_waitcnt lgkmcnt(6)", - "s_waitcnt lgkmcnt(3)", "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(2)", + "s_waitcnt lgkmcnt(0)", + "s_waitcnt lgkmcnt(4)", + "s_waitcnt lgkmcnt(3)", "s_waitcnt lgkmcnt(1)", "s_waitcnt lgkmcnt(0)", "s_waitcnt vmcnt(0) lgkmcnt(0)", "s_waitcnt vmcnt(0)", "s_waitcnt lgkmcnt(7)", - "s_waitcnt lgkmcnt(8)", + "s_waitcnt lgkmcnt(6)", + "s_waitcnt lgkmcnt(6)", "s_waitcnt lgkmcnt(5)", "s_waitcnt lgkmcnt(3)", "s_waitcnt lgkmcnt(1)", @@ -404,9 +407,10 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path): metadata.waitcnt_ops == waitcounts ), f"Expected {waitcounts} waitcnt operations, got {metadata.waitcnt_ops}" - # Verify interleaved scale instructions are generated (op_sel:[0,1,0] - # means byte 0 for a_scale and byte 1 for b_scale in the same VGPR). - assert "op_sel:[0,1,0]" in text, "Expected interleaved scale mfma instructions" + # Verify interleaved scale instructions are generated. + # op_sel_hi:[0,0,0] selects lower bytes (0,1), op_sel_hi:[1,1,0] selects upper bytes (2,3). + assert "op_sel_hi:[0,0,0]" in text, "Expected lower interleaved scale instructions" + assert "op_sel_hi:[1,1,0]" in text, "Expected upper interleaved scale instructions" @require_e2e diff --git a/wave_lang/kernel/wave/interleave_scaled_mma.py b/wave_lang/kernel/wave/interleave_scaled_mma.py index 2ebba5174d..3a073239fa 100644 --- a/wave_lang/kernel/wave/interleave_scaled_mma.py +++ b/wave_lang/kernel/wave/interleave_scaled_mma.py @@ -23,13 +23,14 @@ def interleave_scaled_mma(trace: CapturedTrace, constraints: list[Constraint]): """ Transforms ScaledMMA operations using F32_16x16x128_F8F6F4 into - F32_16x16x128_F8F6F4_SCALES_INTERLEAVED by combining separate - lhs_scale and rhs_scale into a single 4-element vector - [a_scale, b_scale, 0, 0] passed as both scale inputs. + interleaved variants that pack scale values into a single VGPR. - This reduces register pressure by packing both scale values - into a single VGPR, using byte index 0 for a_scale and - byte index 1 for b_scale. + When two ScaledMMA ops exist in the same subgraph, their scales are + packed into one register [a0, b0, a1, b1] and the first op uses + SCALES_INTERLEAVED (bytes 0,1) while the second uses + SCALES_INTERLEAVED_UPPER (bytes 2,3). + + Unpaired ops fall back to [a, b, 0, 0] with SCALES_INTERLEAVED. """ hardware_constraint = get_hardware_constraint(constraints) @@ -44,33 +45,82 @@ def is_target_scaled_mma(node: fx.Node) -> bool: if not nodes: return + # Group nodes by subgraph so we can pair within each one. + graph_groups: dict[fx.Graph, list[fx.Node]] = {} for node in nodes: - mma_op = get_custom(node) - scale_dtype = get_custom(mma_op.lhs_scale).type.dtype - - with mma_op.graph.inserting_before(mma_op.fx_node): - # Create zero padding scalar. - zero = NewScalar(0.0, scale_dtype).add_to_graph( - mma_op.graph, loc=mma_op.location - ) - - # Combine scales: [lhs_scale, rhs_scale, 0, 0]. - combined = Reshape( - [mma_op.lhs_scale, mma_op.rhs_scale, zero, zero], - {}, - ).add_to_graph(mma_op.graph, loc=mma_op.location) - - # Create new ScaledMMA with interleaved type. - new_mma = ScaledMMA( - mma_op.lhs, - combined, - mma_op.rhs, - combined, - mma_op.acc, - ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED, - ).add_to_graph(mma_op.graph, loc=mma_op.location) - new_mma.index = mma_op.index - new_mma.vector_shapes = mma_op.vector_shapes - - mma_op.replace_all_uses_with(new_mma) - mma_op.erase() + graph_groups.setdefault(node.graph, []).append(node) + + for group in graph_groups.values(): + i = 0 + while i < len(group): + mma_a = get_custom(group[i]) + scale_dtype = get_custom(mma_a.lhs_scale).type.dtype + + if i + 1 < len(group): + # Pair: pack all 4 scales into one register. + mma_b = get_custom(group[i + 1]) + + with mma_a.graph.inserting_before(mma_a.fx_node): + combined = Reshape( + [ + mma_a.lhs_scale, + mma_a.rhs_scale, + mma_b.lhs_scale, + mma_b.rhs_scale, + ], + {}, + ).add_to_graph(mma_a.graph, loc=mma_a.location) + + new_a = ScaledMMA( + mma_a.lhs, + combined, + mma_a.rhs, + combined, + mma_a.acc, + ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED, + ).add_to_graph(mma_a.graph, loc=mma_a.location) + new_a.index = mma_a.index + new_a.vector_shapes = mma_a.vector_shapes + + with mma_b.graph.inserting_before(mma_b.fx_node): + new_b = ScaledMMA( + mma_b.lhs, + combined, + mma_b.rhs, + combined, + mma_b.acc, + ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER, + ).add_to_graph(mma_b.graph, loc=mma_b.location) + new_b.index = mma_b.index + new_b.vector_shapes = mma_b.vector_shapes + + mma_a.replace_all_uses_with(new_a) + mma_a.erase() + mma_b.replace_all_uses_with(new_b) + mma_b.erase() + i += 2 + else: + # Unpaired: pad upper half with zeros. + with mma_a.graph.inserting_before(mma_a.fx_node): + zero = NewScalar(0.0, scale_dtype).add_to_graph( + mma_a.graph, loc=mma_a.location + ) + combined = Reshape( + [mma_a.lhs_scale, mma_a.rhs_scale, zero, zero], + {}, + ).add_to_graph(mma_a.graph, loc=mma_a.location) + + new_a = ScaledMMA( + mma_a.lhs, + combined, + mma_a.rhs, + combined, + mma_a.acc, + ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED, + ).add_to_graph(mma_a.graph, loc=mma_a.location) + new_a.index = mma_a.index + new_a.vector_shapes = mma_a.vector_shapes + + mma_a.replace_all_uses_with(new_a) + mma_a.erase() + i += 1 From 408afc7e85086c7a28f70859b93ec77319bb8c35 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 13 Feb 2026 20:42:54 +0100 Subject: [PATCH 10/10] refactor scales indices Signed-off-by: Ivan Butygin --- .../kernel/compiler/wave_codegen/handlers.py | 30 +++++---- wave_lang/kernel/ops/wave_ops.py | 2 + wave_lang/kernel/wave/constraints.py | 34 +--------- .../kernel/wave/interleave_scaled_mma.py | 64 ++++++------------- 4 files changed, 43 insertions(+), 87 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index a251fea5a0..4d4ec6d5f5 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -496,7 +496,9 @@ def emit_wmma_scaled( @handle_op(scaled_mma) def handle_scaled_mma(emitter: WaveEmitter, node: fx.Node): try: - lhs, lhs_scale, rhs, rhs_scale, acc, mma_type = node.args + lhs, lhs_scale, rhs, rhs_scale, acc, mma_type, scale_idx_a, scale_idx_b = ( + node.args + ) acc = cast_vector(emitter, acc) values = [cast_vector(emitter, val) for val in [lhs, rhs]] except ValueError as e: @@ -523,17 +525,21 @@ def handle_scaled_mma(emitter: WaveEmitter, node: fx.Node): scales = [cast_vector(emitter, val) for val in [lhs_scale, rhs_scale]] result = emit_wmma_scaled(m, n, k, acc, values, scales) else: - pos = [0, 0] - if mma_type == ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED: - pos = [0, 1] - scales = [cast_vector(emitter, val) for val in [lhs_scale, rhs_scale]] - elif mma_type == ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER: - pos = [2, 3] - scales = [cast_vector(emitter, val) for val in [lhs_scale, rhs_scale]] - else: - scales = [cast_scalar(emitter, val) for val in [lhs_scale, rhs_scale]] - - result = emit_mfma_scaled(m, n, k, acc, values, scales, pos[0], pos[1]) + pos_a = scale_idx_a if scale_idx_a is not None else 0 + pos_b = scale_idx_b if scale_idx_b is not None else 0 + scale_a = ( + cast_vector(emitter, lhs_scale) + if scale_idx_a is not None + else cast_scalar(emitter, lhs_scale) + ) + scale_b = ( + cast_vector(emitter, rhs_scale) + if scale_idx_b is not None + else cast_scalar(emitter, rhs_scale) + ) + result = emit_mfma_scaled( + m, n, k, acc, values, [scale_a, scale_b], pos_a, pos_b + ) emitter.bind_node_proxy(node, IRProxyValue(result)) diff --git a/wave_lang/kernel/ops/wave_ops.py b/wave_lang/kernel/ops/wave_ops.py index 2aef02ed47..617ec500c0 100644 --- a/wave_lang/kernel/ops/wave_ops.py +++ b/wave_lang/kernel/ops/wave_ops.py @@ -1922,6 +1922,8 @@ class ScaledMMA(MMABase): rhs_scale: fx.Node acc: fx.Node mma_type: Optional["ScaledMMAType"] = None + scale_idx_a: Optional[int] = None + scale_idx_b: Optional[int] = None @property def indexing_dims(self) -> list[IndexSymbol]: diff --git a/wave_lang/kernel/wave/constraints.py b/wave_lang/kernel/wave/constraints.py index 2a357b9e6f..d49538b8a8 100644 --- a/wave_lang/kernel/wave/constraints.py +++ b/wave_lang/kernel/wave/constraints.py @@ -81,9 +81,6 @@ class ScaledMMAType(Enum): F32_16x16x128_F8F6F4 = 0x1340 F32_32x32x64_F8F6F4 = 0x1341 - F32_16x16x128_F8F6F4_SCALES_INTERLEAVED = 0x1342 - F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER = 0x1343 - # Intrinsics introduced in GFX1250 GFX1250_F32_16x16x128_F8F6F4 = 0x1940 @@ -293,11 +290,7 @@ def mma_matrix_shapes( | MMAType.I32_32x32x16_I8 ): return (32, 32, 16) - case ( - ScaledMMAType.F32_16x16x128_F8F6F4 - | ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED - | ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER - ): + case ScaledMMAType.F32_16x16x128_F8F6F4: return (16, 16, 128) case ScaledMMAType.F32_32x32x64_F8F6F4: return (32, 32, 64) @@ -423,11 +416,7 @@ def mma_index_offset(self, mma_type: Optional[MMAType | ScaledMMAType]): + 4 * floor(lane / 32) + (GPR_NUM % 4), # K ] - case ( - ScaledMMAType.F32_16x16x128_F8F6F4 - | ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED - | ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER - ): + case ScaledMMAType.F32_16x16x128_F8F6F4: offset = [ Piecewise( (lane % 16, ~MMA_ACC), (4 * floor(lane / 16), MMA_ACC) @@ -446,10 +435,7 @@ def mma_index_offset(self, mma_type: Optional[MMAType | ScaledMMAType]): ), ), # K ] - case ( - ScaledMMAType.F32_32x32x64_F8F6F4 - | ScaledMMAType.F32_32x32x64_F8F6F4_SCALES_INTERLEAVED - ): + case ScaledMMAType.F32_32x32x64_F8F6F4: offset = [ Piecewise( (lane % 32, ~MMA_ACC), @@ -653,20 +639,6 @@ def apply_mma_mapping( 1, # N 1, # K ] - case ( - ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED - | ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER - ): - size = [ - Piecewise((1, ~MMA_ACC), (4, MMA_ACC)), # M - 1, # N - Piecewise((128, MMA_LHS_SCALE | MMA_RHS_SCALE), (32, True)), # K - ] - stride = [ - Piecewise((1, ~MMA_ACC), (16, MMA_ACC)), # M - 1, # N - 1, # K - ] case ScaledMMAType.F32_32x32x64_F8F6F4: size = [ Piecewise((1, ~MMA_ACC), (16, MMA_ACC)), # M diff --git a/wave_lang/kernel/wave/interleave_scaled_mma.py b/wave_lang/kernel/wave/interleave_scaled_mma.py index 3a073239fa..22de3066ce 100644 --- a/wave_lang/kernel/wave/interleave_scaled_mma.py +++ b/wave_lang/kernel/wave/interleave_scaled_mma.py @@ -22,15 +22,14 @@ def interleave_scaled_mma(trace: CapturedTrace, constraints: list[Constraint]): """ - Transforms ScaledMMA operations using F32_16x16x128_F8F6F4 into - interleaved variants that pack scale values into a single VGPR. + Packs scale values of ScaledMMA operations into shared VGPRs + using byte indexing to reduce register pressure. When two ScaledMMA ops exist in the same subgraph, their scales are packed into one register [a0, b0, a1, b1] and the first op uses - SCALES_INTERLEAVED (bytes 0,1) while the second uses - SCALES_INTERLEAVED_UPPER (bytes 2,3). + scale_idx (0,1) while the second uses (2,3). - Unpaired ops fall back to [a, b, 0, 0] with SCALES_INTERLEAVED. + Unpaired ops fall back to [a, b, 0, 0] with scale_idx (0,1). """ hardware_constraint = get_hardware_constraint(constraints) @@ -38,6 +37,8 @@ def is_target_scaled_mma(node: fx.Node) -> bool: custom = get_custom(node) if not isinstance(custom, ScaledMMA): return False + if custom.scale_idx_a is not None or custom.scale_idx_b is not None: + return False mma_type = custom.mma_type or hardware_constraint.mma_type return mma_type == ScaledMMAType.F32_16x16x128_F8F6F4 @@ -71,33 +72,16 @@ def is_target_scaled_mma(node: fx.Node) -> bool: {}, ).add_to_graph(mma_a.graph, loc=mma_a.location) - new_a = ScaledMMA( - mma_a.lhs, - combined, - mma_a.rhs, - combined, - mma_a.acc, - ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED, - ).add_to_graph(mma_a.graph, loc=mma_a.location) - new_a.index = mma_a.index - new_a.vector_shapes = mma_a.vector_shapes - - with mma_b.graph.inserting_before(mma_b.fx_node): - new_b = ScaledMMA( - mma_b.lhs, - combined, - mma_b.rhs, - combined, - mma_b.acc, - ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED_UPPER, - ).add_to_graph(mma_b.graph, loc=mma_b.location) - new_b.index = mma_b.index - new_b.vector_shapes = mma_b.vector_shapes - - mma_a.replace_all_uses_with(new_a) - mma_a.erase() - mma_b.replace_all_uses_with(new_b) - mma_b.erase() + mma_a.update_arg("lhs_scale", combined) + mma_a.update_arg("rhs_scale", combined) + mma_a.update_arg("scale_idx_a", 0) + mma_a.update_arg("scale_idx_b", 1) + + mma_b.update_arg("lhs_scale", combined) + mma_b.update_arg("rhs_scale", combined) + mma_b.update_arg("scale_idx_a", 2) + mma_b.update_arg("scale_idx_b", 3) + i += 2 else: # Unpaired: pad upper half with zeros. @@ -110,17 +94,9 @@ def is_target_scaled_mma(node: fx.Node) -> bool: {}, ).add_to_graph(mma_a.graph, loc=mma_a.location) - new_a = ScaledMMA( - mma_a.lhs, - combined, - mma_a.rhs, - combined, - mma_a.acc, - ScaledMMAType.F32_16x16x128_F8F6F4_SCALES_INTERLEAVED, - ).add_to_graph(mma_a.graph, loc=mma_a.location) - new_a.index = mma_a.index - new_a.vector_shapes = mma_a.vector_shapes + mma_a.update_arg("lhs_scale", combined) + mma_a.update_arg("rhs_scale", combined) + mma_a.update_arg("scale_idx_a", 0) + mma_a.update_arg("scale_idx_b", 1) - mma_a.replace_all_uses_with(new_a) - mma_a.erase() i += 1