From ba99de6d2473984cc08b21b82bf0d5af20b57e19 Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Thu, 5 Mar 2026 20:51:47 -0800 Subject: [PATCH 1/6] =?UTF-8?q?[compiler]=20Fix=20Rational=20codegen=20and?= =?UTF-8?q?=20pipeline=20unroll=20bug=20for=20dynamic=20s=E2=80=A6=20(#105?= =?UTF-8?q?7)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: xintin --- tests/kernel/wave_gemm_mxfp_test.py | 7 ++- .../kernel/compiler/wave_codegen/emitter.py | 3 +- wave_lang/kernel/ops/wave_schedule_ops.py | 2 + .../schedules/gemm_mxfp4_double_buffer.py | 1 + wave_lang/kernel/wave/scheduling/schedule.py | 43 ++++++++++++++----- 5 files changed, 44 insertions(+), 12 deletions(-) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 892394f249..1eadcfe494 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -907,7 +907,12 @@ def testScaledGemmMXFP4PreshuffleMacrotiles( ) @pytest.mark.parametrize( "block_shape,wave_shape", - [((128, 256, 256), (1, 4)), ((128, 32, 256), (2, 2)), ((256, 224, 256), (2, 2))], + [ + ((128, 256, 256), (1, 4)), + ((32, 64, 256), (1, 4)), + ((128, 32, 256), (2, 2)), + ((256, 224, 256), (2, 2)), + ], ) @pytest.mark.parametrize( "mfma_variant", diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index da11cbafda..c97d311582 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -1139,13 +1139,14 @@ def _get_const(val): # Build nested select operations # Start with the last expression (typically the default/else case) - result = cases[-1][1] + result = _resolve_rational(cases[-1][1]) # Work backwards through earlier cases to build nested selects # Piecewise((expr1, cond1), (expr2, cond2), (expr3, True)) becomes: # select(cond1, expr1, select(cond2, expr2, expr3)) for i in range(len(cases) - 2, -1, -1): cond, expr = cases[i] + expr = _resolve_rational(expr) result = arith_d.select(cond, *_broadcast(expr, result)) stack.append(result) diff --git a/wave_lang/kernel/ops/wave_schedule_ops.py b/wave_lang/kernel/ops/wave_schedule_ops.py index 159a1549e4..da8bfae40b 100755 --- a/wave_lang/kernel/ops/wave_schedule_ops.py +++ b/wave_lang/kernel/ops/wave_schedule_ops.py @@ -810,6 +810,7 @@ def __init__( self.kernel_trace = kernel_trace self.constraints = constraints self.multi_buffer_count = multi_buffer_count + self.unroll_factor = 1 # Access options from the current ScheduleContext from .._support.tracing import ScheduleContext @@ -875,6 +876,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): scheduling_type=SchedulingType.MANUAL, visualize=False, multi_buffer_count=self.multi_buffer_count, + unroll_factor=self.unroll_factor, ) # Store the pipelined iterate node and node mapping, then create proxies for the stages diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index cc422b3540..186f6bcf17 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1411,6 +1411,7 @@ def mxfp4_dbuf_schedule(): # This forces the pipeline to use double buffering pipeline_loop.multi_buffer_count = 2 + pipeline_loop.unroll_factor = 2 with pipeline_loop as pl: pl.set_stage( diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index 11cb374c42..3a299344fb 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -214,12 +214,13 @@ def build_guarded_pipeline_with_remainder( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + unroll_factor: int = 1, ): """ Build conditional + pipelined loop + remainder loop for dynamic shapes. Structure: - if (max_induction_variable >= num_stages): + if (max_induction_variable >= num_stages + unroll_factor - 1): pipelined_result = pipelined_loop_with_prologue_epilogue() else: pipelined_result = init_values @@ -232,16 +233,24 @@ def build_guarded_pipeline_with_remainder( original_init_args = reduction.init_args main_graph = reduction.graph - # Create condition: max_induction_variable >= num_stages + if unroll_factor < 1: + raise ValueError(f"Expected unroll_factor >= 1, got {unroll_factor}") + + from math import lcm + + rounding_stride = lcm(num_stages, unroll_factor) + + # Need at least rounding_stride iterations so the rounded-down + # pipelined_iterations is non-zero. with main_graph.inserting_before(reduction.fx_node): - num_stages_scalar = get_graph_node( - NewScalar(num_stages, tkl.i32), main_graph, reduction.location + min_iters_scalar = get_graph_node( + NewScalar(rounding_stride, tkl.i32), main_graph, reduction.location ) num_iters_scalar = get_graph_node( NewScalar(max_induction_variable, tkl.i32), main_graph, reduction.location ) condition = get_graph_node( - Ge(num_iters_scalar, num_stages_scalar), main_graph, reduction.location + Ge(num_iters_scalar, min_iters_scalar), main_graph, reduction.location ) # Prepare conditional subgraph @@ -257,12 +266,22 @@ def build_guarded_pipeline_with_remainder( ) ) - # Compute the number of iterations the pipelined loop should process: - # This ensures the pipelined loop only processes complete pipeline stages + # Round pipelined_iterations to a multiple of lcm(num_stages, unroll_factor). + # This ensures complete pipeline stages (multiple of num_stages) and that + # the kernel count (pipelined_iterations - (num_stages - 1)) is divisible + # by unroll_factor, preventing the unrolled scf.for from executing extra + # iterations that read invalid pipeline state. + from math import lcm + + rounding_stride = lcm(num_stages, unroll_factor) if isinstance(max_induction_variable, (int, float)): - pipelined_iterations = (int(max_induction_variable) // num_stages) * num_stages + pipelined_iterations = ( + int(max_induction_variable) // rounding_stride + ) * rounding_stride else: - pipelined_iterations = (max_induction_variable // num_stages) * num_stages + pipelined_iterations = ( + max_induction_variable // rounding_stride + ) * rounding_stride conditional_body_graph, body_old_to_new = graph_copy(reduction_graph) placeholder_init_args = [placeholders[arg] for arg in reduction.init_args] @@ -420,12 +439,13 @@ def construct_pipelined_loop_adaptive( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + unroll_factor: int = 1, ): """ Constructs a pipelined loop wrapped in a conditional, followed by a remainder loop. Structure: - if (num_iterations >= num_stages): + if (num_iterations >= num_stages + unroll_factor - 1): prologue pipelined_loop epilogue @@ -477,6 +497,7 @@ def construct_pipelined_loop_adaptive( visualize, use_scheduling_barriers, multi_buffer_count, + unroll_factor, ) @@ -491,6 +512,7 @@ def apply_pipelined_schedule( scheduling_type: SchedulingType = SchedulingType.NONE, visualize: bool = False, multi_buffer_count: Optional[int] = None, + unroll_factor: int = 1, ) -> Optional[tuple[fx.Node, dict]]: # After scheduling has completed, we have enough information to decide @@ -525,6 +547,7 @@ def apply_pipelined_schedule( visualize, use_scheduling_barriers, multi_buffer_count, + unroll_factor, ) From e3f601fe0ec0fb2dd080e6ab64a72bd8f147eb4c Mon Sep 17 00:00:00 2001 From: xintin Date: Fri, 6 Mar 2026 03:28:46 +0000 Subject: [PATCH 2/6] fix more dyn shapes Signed-off-by: xintin --- tests/kernel/wave_gemm_mxfp_test.py | 60 ++++++++++++++++++- .../compiler/wave_codegen/read_write.py | 8 +-- wave_lang/kernel/wave/scheduling/schedule.py | 17 +++++- 3 files changed, 77 insertions(+), 8 deletions(-) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 1eadcfe494..2f97d8d6b2 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -785,19 +785,77 @@ def testScaledGemmMXFP4AsymmetricScheduleBF16( [ScaledMMAType.F32_16x16x128_F8F6F4], ) @use_water_backend_bool("use_water_backend") +def testScaledGemmMXFP4AsymmetricScheduleBF16( + shape: tuple[int, int, int], + block_shape: tuple[int, int, int], + mfma_variant: ScaledMMAType, + use_water_backend: bool, +): + """End-to-end test for asymmetric MXFP4 GEMM: A through LDS, B direct from global.""" + gemm, options = get_tagged_mxfp4_gemm( + shape, + block_shape, + wave_shape=(1, 4), + mfma_variant=mfma_variant, + b_address_space=GLOBAL_ADDRESS_SPACE, + output_dtype=tkl.bf16, + ) + schedule = get_mxfp4_asymmetric_schedule() + options.minimize_shared_allocs = True + options.linearize_shared_access = True + options.use_buffer_ops = True + options.use_water_backend = use_water_backend + options = set_default_run_config(options) + gemm = wave_compile(options, gemm, schedule) + + x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) + out = device_zeros(x.shape[0], w.shape[1], dtype=torch.bfloat16) + + w_t = w.T.contiguous() + gemm(x, x_scales, w_t, w_scales, out) + torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales) + + torch.testing.assert_close(torch_out, out, check_dtype=False) + + +@require_e2e +@require_cdna4 +@pytest.mark.parametrize( + "shape, block_shape, wave_shape", + [ + ((1024, 1024, 8192), (256, 256, 256), (1, 4)), + ((1024, 1024, 8192), (64, 192, 256), (1, 4)), + ((32 * 32, 32 * 64, 16384), (32, 64, 256), (1, 4)), + ((640, 256, 768), (64, 64, 256), (1, 4)), + ((4224, 4096, 768), (64, 64, 256), (1, 4)), + ], +) +@pytest.mark.parametrize( + "mfma_variant", + [ScaledMMAType.F32_16x16x128_F8F6F4], +) +@param_bool("dynamic_shapes", "dyn") +@use_water_backend_bool("use_water_backend") def testScaledGemmMXFP4PreshuffleB( shape: tuple[int, int, int], block_shape: tuple[int, int, int], + wave_shape: tuple[int, int], mfma_variant: ScaledMMAType, + dynamic_shapes: bool, use_water_backend: bool, ): """End-to-end test for MXFP4 GEMM with preshuffled B data and B scales.""" gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( shape, block_shape, - wave_shape=(1, 4), + wave_shape=wave_shape, mfma_variant=mfma_variant, ) + if dynamic_shapes: + for sym in [tkl.sym.M, tkl.sym.N, tkl.sym.K]: + del options.subs[sym] + options.dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + options.wave_runtime = True schedule = get_mxfp4_asymmetric_schedule() options.minimize_shared_allocs = True options.linearize_shared_access = True diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 5caaada520..31179cf39e 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -210,7 +210,7 @@ def _get_max_buffer_size(elem_type: IrType) -> int: Buffer ops offsets are i32, return maximum memref size in elements. """ - return ((1 << 31) - 1) // (elem_type.width // 8) + return ((1 << 32) - 1) // (elem_type.width // 8) def _get_strides_from_memref(mem: Value) -> list[Value]: @@ -241,7 +241,7 @@ def _linearize_memref( memref_type = mem.type offset = None offset_th = None - overflow_flags = arith_d.IntegerOverflowFlags.nsw + overflow_flags = arith_d.IntegerOverflowFlags.none for ind_wg, ind_th, stride in zip(offsets_wg, offsets_th, strides): if isinstance(ind_wg, int): ind_wg = arith_d.constant(IndexType.get(), ind_wg) @@ -342,7 +342,7 @@ def _valid_bytes_buffer(elem_type: IrType) -> int: """ Make valid bytes to be the address of the last byte of the second to last element that can fit in a 32 bit offset to memory address """ - ans = (1 << 31) - 1 - (elem_type.width // 8) + ans = (1 << 32) - 1 - (elem_type.width // 8) assert isinstance(ans, int) return ans @@ -359,7 +359,7 @@ def _get_out_of_bounds_index(element_type: IrType) -> int: assert (oob_index_value * element_width_in_bytes) > _valid_bytes_buffer( element_type ) - assert (oob_index_value * element_width_in_bytes) < (1 << 31) + assert (oob_index_value * element_width_in_bytes) < (1 << 32) return oob_index_value diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index 3a299344fb..4611123761 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -283,6 +283,17 @@ def build_guarded_pipeline_with_remainder( max_induction_variable // rounding_stride ) * rounding_stride + # When the pipeline guard is false (max_iv < min_pipelined_iterations), + # the conditional returns init values and the remainder loop must start + # from 0 rather than pipelined_iterations. + remainder_start = sympy.Piecewise( + ( + pipelined_iterations, + sympy.Ge(max_induction_variable, min_pipelined_iterations), + ), + (0, True), + ) + conditional_body_graph, body_old_to_new = graph_copy(reduction_graph) placeholder_init_args = [placeholders[arg] for arg in reduction.init_args] placeholder_captures = [placeholders[cap] for cap in reduction.implicit_captures] @@ -395,11 +406,11 @@ def build_guarded_pipeline_with_remainder( main_graph.subgraphs[remainder_subgraph_name] = remainder_graph trace.region_graph.subgraphs[remainder_subgraph_name] = remainder_graph - # Create a scalar node for the starting iteration (where pipelined loop ended) - # This will be pipelined_iterations + # Create a scalar node for the starting iteration. + # When the pipeline ran, this is pipelined_iterations; otherwise 0. with main_graph.inserting_before(reduction.fx_node): start_iter = get_graph_node( - NewScalar(pipelined_iterations, tkl.index), main_graph, reduction.location + NewScalar(remainder_start, tkl.index), main_graph, reduction.location ) remainder_reduction = Iterate( From 503ed39bde1067e4251844a57f420c8e51576bf8 Mon Sep 17 00:00:00 2001 From: xintin Date: Fri, 6 Mar 2026 04:13:51 +0000 Subject: [PATCH 3/6] fix lit tests: buffer size Signed-off-by: xintin --- lit_tests/kernel/wave/codegen.py | 6 +++--- lit_tests/kernel/wave/dynamic_strides.py | 4 ++-- lit_tests/kernel/wave/scaled_gemm.py | 18 +++++++++--------- lit_tests/kernel/wave/scaled_mma.py | 10 +++++----- lit_tests/kernel/wave/topk.py | 8 ++++---- 5 files changed, 23 insertions(+), 23 deletions(-) diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 96271b727a..9fe225383d 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -220,7 +220,7 @@ def read_dynamic_buffer(a: tkl.Memory[B, M, N, ADDRESS_SPACE, tkl.f16]): # Gets offset to tensor's base pointer, then set memref_offset = indexing_offset + base_tensor_offset. # CHECK: %{{.*}}, %[[BASE_TENSOR_OFFSET:.+]], %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[ARG0]] : memref> -> memref, index, index, index, index, index, index, index - # CHECK: %[[MEMREF_OFFSET:.+]] = arith.addi %{{.*}}, %[[BASE_TENSOR_OFFSET]] overflow : index + # CHECK: %[[MEMREF_OFFSET:.+]] = arith.addi %{{.*}}, %[[BASE_TENSOR_OFFSET]] : index # CHECK: %[[MEMREF_CAST:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[MEMREF_OFFSET]]], {{.*}}: memref> to memref> # CHECK: %[[SWIZZLE_CAST:.*]] = arith.index_cast %c16{{.*}} : index to i14 @@ -259,8 +259,8 @@ def read_write( # CHECK: %[[S0:.*]] = memref.reinterpret_cast %[[D0]] to offset: [0], sizes: [16, 16], strides: [16, 1] : memref to memref<16x16xf16, strided<[16, 1]>> # CHECK: %[[I0:.*]] = affine.apply #[[MAP0]]()[%[[thread_id_x]]] # CHECK: %[[V:.*]] = vector.load %[[S0]][%[[I0]], %[[C0]]] : memref<16x16xf16, strided<[16, 1]>>, vector<16xf16> - # CHECK: memref.reinterpret_cast %[[D1]] to offset: [0], sizes: [1073741822], strides: [1] : memref to memref<1073741822xf16, strided<[1]>> - # CHECK: vector.store %[[V]], {{.*}}[{{.*}}] : memref<1073741822xf16, strided<[1]>>, vector<16xf16> + # CHECK: memref.reinterpret_cast %[[D1]] to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xf16, strided<[1]>> + # CHECK: vector.store %[[V]], {{.*}}[{{.*}}] : memref<2147483646xf16, strided<[1]>>, vector<16xf16> # CHECK: return diff --git a/lit_tests/kernel/wave/dynamic_strides.py b/lit_tests/kernel/wave/dynamic_strides.py index fe70fdc5a7..1b9a4b7965 100644 --- a/lit_tests/kernel/wave/dynamic_strides.py +++ b/lit_tests/kernel/wave/dynamic_strides.py @@ -50,5 +50,5 @@ def test_dynamic_strides_gemm(): # Output is linearized using dynamic strides from extract_strided_metadata, then stored to 1D view. # CHECK: memref.extract_strided_metadata %reinterpret_cast_1 : memref<1024x1024xf32, strided<[?, 1]>> - # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [536870910], strides: [1] - # CHECK: vector.store {{.*}} %reinterpret_cast_3{{.*}} : memref<536870910xf32, strided<[1], offset: ?>> + # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [1073741822], strides: [1] + # CHECK: vector.store {{.*}} %reinterpret_cast_3{{.*}} : memref<1073741822xf32, strided<[1], offset: ?>> diff --git a/lit_tests/kernel/wave/scaled_gemm.py b/lit_tests/kernel/wave/scaled_gemm.py index 7bf3b8b178..18dbccfae4 100644 --- a/lit_tests/kernel/wave/scaled_gemm.py +++ b/lit_tests/kernel/wave/scaled_gemm.py @@ -583,10 +583,10 @@ def repeat( # CHECK-DAG: %[[C512_I14:.+]] = arith.constant 512 : i14 # Prologue Global Read - # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1], offset: ?>> + # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [4294967294], strides: [1] : memref to memref<4294967294xi8, strided<[1], offset: ?>> # CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) cacheSwizzleStride(%[[C512_I14]]) resetOffset : memref> to memref> # CHECK-COUNT-4: vector.load {{.*}} : memref>, vector<16xi8> - # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1], offset: ?>> + # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [4294967294], strides: [1] : memref to memref<4294967294xi8, strided<[1], offset: ?>> # CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) cacheSwizzleStride(%[[C32_I14]]) resetOffset : memref> to memref> # CHECK: vector.load {{.*}} : memref>, vector<4xi8> # CHECK-COUNT-4: vector.load {{.*}} : memref>, vector<16xi8> @@ -846,10 +846,10 @@ def repeat( # CHECK-DAG: #[[MAP17:.*]] = affine_map<()[s0, s1] -> (s1 * 32 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 192) floordiv 256) * 256 + 192)> # CHECK: func.func @batched_gemm(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: !stream.binding, %arg4: !stream.binding, %arg5: index, %arg6: index) attributes {translation_info = #translation} { # CHECK-DAG: %[[CST1:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32> - # CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2147483647> : vector<16xindex> + # CHECK-DAG: %[[CST2:.*]] = arith.constant dense<4294967295> : vector<16xindex> # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index # CHECK-DAG: %[[C8192:.*]] = arith.constant 8192 : index - # CHECK-DAG: %[[C2147483646_I64:.*]] = arith.constant 2147483646 : i64 + # CHECK-DAG: %[[C2147483646_I64:.*]] = arith.constant 4294967294 : i64 # CHECK-DAG: %[[C_NEG_8192_I14:.*]] = arith.constant -8192 : i14 # CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x # CHECK-DAG: %[[BLOCK_ID_Z:.*]] = gpu.block_id z @@ -857,16 +857,16 @@ def repeat( # CHECK-DAG: %[[THREAD_ID_Y:.*]] = gpu.thread_id y upper_bound 2 # CHECK-DAG: %[[AFFINE_APPLY2:.*]] = affine.apply #[[MAP3]]()[%arg6] # CHECK-DAG: %[[AFFINE_APPLY1:.*]] = affine.apply #[[MAP2]]()[%[[THREAD_ID_X]]] - # CHECK: %[[MUL1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[AFFINE_APPLY2]] overflow : index - # CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1], offset: ?>> - # CHECK: %[[CAST:.*]] = memref.cast %[[REINTERPRET_CAST]] : memref<2147483646xi8, strided<[1], offset: ?>> to memref> + # CHECK: %[[MUL1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[AFFINE_APPLY2]] : index + # CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [4294967294], strides: [1] : memref to memref<4294967294xi8, strided<[1], offset: ?>> + # CHECK: %[[CAST:.*]] = memref.cast %[[REINTERPRET_CAST]] : memref<4294967294xi8, strided<[1], offset: ?>> to memref> # CHECK: %[[BUFF_CAST:.*]] = amdgpu.fat_raw_buffer_cast %[[CAST]] validBytes(%[[C2147483646_I64]]) cacheSwizzleStride(%[[C_NEG_8192_I14]]) resetOffset : memref> to memref> # CHECK: %[[AFFINE_APPLY3:.*]] = affine.apply #[[MAP6]]()[%[[THREAD_ID_X]], %[[THREAD_ID_Y]], %[[BLOCK_ID_X]]] # CHECK: %[[CMP1:.*]] = arith.cmpi slt, %[[AFFINE_APPLY3]], %arg6 : index # CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[CMP1]] : i1 to vector<16xi1> # CHECK: %[[AFFINE_APPLY4:.*]] = affine.apply #[[MAP17]]()[%[[THREAD_ID_X]], %[[THREAD_ID_Y]]] - # CHECK: %[[MUL2:.*]] = arith.muli %[[AFFINE_APPLY4]], %[[C8192]] overflow : index - # CHECK: %[[ADD1:.*]] = arith.addi %[[MUL2]], %[[AFFINE_APPLY1]] overflow : index + # CHECK: %[[MUL2:.*]] = arith.muli %[[AFFINE_APPLY4]], %[[C8192]] : index + # CHECK: %[[ADD1:.*]] = arith.addi %[[MUL2]], %[[AFFINE_APPLY1]] : index # CHECK: %[[IDX_CAST1:.*]] = arith.index_cast %[[ADD1]] : index to i32 # CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[IDX_CAST1]] : i32 to vector<16xi32> # CHECK: %[[ADD3:.*]] = arith.addi %[[BROADCAST2]], %[[CST1]] : vector<16xi32> diff --git a/lit_tests/kernel/wave/scaled_mma.py b/lit_tests/kernel/wave/scaled_mma.py index f37dae9b56..39f02f906a 100644 --- a/lit_tests/kernel/wave/scaled_mma.py +++ b/lit_tests/kernel/wave/scaled_mma.py @@ -121,17 +121,17 @@ def scaled_mma( # CHECK: %[[SCALED_MFMA:.+]] = amdgpu.scaled_mfma 16x16x128 (%[[VECTOR_LOAD_6]][0] * %[[BITCAST_0]]) * (%[[VECTOR_LOAD_4]][0] * %[[BITCAST_1]]) + %[[CST]] : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> # CHECK: %[[EXTRACT_STRIDED_SLICE_0:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> # CHECK: %[[AFFINE_APPLY_4:.+]] = affine.apply #[[MAP4]]()[%[[THREAD_ID_X]]] - # CHECK: %[[REINTERPRET_CAST_6:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [536870910], strides: [1] : memref to memref<536870910xf32, strided<[1]>> - # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_0]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32> + # CHECK: %[[REINTERPRET_CAST_6:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [1073741822], strides: [1] : memref to memref<1073741822xf32, strided<[1]>> + # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_0]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32> # CHECK: %[[EXTRACT_STRIDED_SLICE_1:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> # CHECK: %[[AFFINE_APPLY_5:.+]] = affine.apply #[[MAP5]]()[%[[THREAD_ID_X]]] - # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_1]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32> + # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_1]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32> # CHECK: %[[EXTRACT_STRIDED_SLICE_2:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> # CHECK: %[[AFFINE_APPLY_6:.+]] = affine.apply #[[MAP6]]()[%[[THREAD_ID_X]]] - # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_2]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32> + # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_2]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32> # CHECK: %[[EXTRACT_STRIDED_SLICE_3:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> # CHECK: %[[AFFINE_APPLY_7:.+]] = affine.apply #[[MAP7]]()[%[[THREAD_ID_X]]] - # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_3]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32> + # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_3]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32> # CHECK: return diff --git a/lit_tests/kernel/wave/topk.py b/lit_tests/kernel/wave/topk.py index 8c6bdac9d4..feb679c1ea 100644 --- a/lit_tests/kernel/wave/topk.py +++ b/lit_tests/kernel/wave/topk.py @@ -102,7 +102,7 @@ def topk( # CHECK: vector.from_elements{{.*}} : vector<2xi32> # Write operations for both values and indices (linearized 1D stores) - # CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [1073741822], strides: [1] : memref - # CHECK: vector.store {{.*}} : memref<1073741822xf16{{.*}}>, vector<2xf16> - # CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [536870910], strides: [1] : memref - # CHECK: vector.store {{.*}} : memref<536870910xi32{{.*}}>, vector<2xi32> + # CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [2147483646], strides: [1] : memref + # CHECK: vector.store {{.*}} : memref<2147483646xf16{{.*}}>, vector<2xf16> + # CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [1073741822], strides: [1] : memref + # CHECK: vector.store {{.*}} : memref<1073741822xi32{{.*}}>, vector<2xi32> From 9d4629f0428f0a86a075548ec6e5a4adcb5361d0 Mon Sep 17 00:00:00 2001 From: xintin Date: Fri, 6 Mar 2026 05:35:51 +0000 Subject: [PATCH 4/6] rebase Signed-off-by: xintin --- tests/kernel/wave_gemm_mxfp_test.py | 69 ++------------------ wave_lang/kernel/wave/scheduling/schedule.py | 4 +- 2 files changed, 9 insertions(+), 64 deletions(-) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 2f97d8d6b2..14f06517d0 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -366,10 +366,11 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path): "s_waitcnt vmcnt(0)", "s_waitcnt vmcnt(0) lgkmcnt(0)", "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(14)", "s_waitcnt lgkmcnt(7)", - "s_waitcnt lgkmcnt(5)", - "s_waitcnt lgkmcnt(4)", - "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(6)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(2)", "s_waitcnt lgkmcnt(1)", "s_waitcnt lgkmcnt(0)", ] @@ -785,77 +786,19 @@ def testScaledGemmMXFP4AsymmetricScheduleBF16( [ScaledMMAType.F32_16x16x128_F8F6F4], ) @use_water_backend_bool("use_water_backend") -def testScaledGemmMXFP4AsymmetricScheduleBF16( - shape: tuple[int, int, int], - block_shape: tuple[int, int, int], - mfma_variant: ScaledMMAType, - use_water_backend: bool, -): - """End-to-end test for asymmetric MXFP4 GEMM: A through LDS, B direct from global.""" - gemm, options = get_tagged_mxfp4_gemm( - shape, - block_shape, - wave_shape=(1, 4), - mfma_variant=mfma_variant, - b_address_space=GLOBAL_ADDRESS_SPACE, - output_dtype=tkl.bf16, - ) - schedule = get_mxfp4_asymmetric_schedule() - options.minimize_shared_allocs = True - options.linearize_shared_access = True - options.use_buffer_ops = True - options.use_water_backend = use_water_backend - options = set_default_run_config(options) - gemm = wave_compile(options, gemm, schedule) - - x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) - out = device_zeros(x.shape[0], w.shape[1], dtype=torch.bfloat16) - - w_t = w.T.contiguous() - gemm(x, x_scales, w_t, w_scales, out) - torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales) - - torch.testing.assert_close(torch_out, out, check_dtype=False) - - -@require_e2e -@require_cdna4 -@pytest.mark.parametrize( - "shape, block_shape, wave_shape", - [ - ((1024, 1024, 8192), (256, 256, 256), (1, 4)), - ((1024, 1024, 8192), (64, 192, 256), (1, 4)), - ((32 * 32, 32 * 64, 16384), (32, 64, 256), (1, 4)), - ((640, 256, 768), (64, 64, 256), (1, 4)), - ((4224, 4096, 768), (64, 64, 256), (1, 4)), - ], -) -@pytest.mark.parametrize( - "mfma_variant", - [ScaledMMAType.F32_16x16x128_F8F6F4], -) -@param_bool("dynamic_shapes", "dyn") -@use_water_backend_bool("use_water_backend") def testScaledGemmMXFP4PreshuffleB( shape: tuple[int, int, int], block_shape: tuple[int, int, int], - wave_shape: tuple[int, int], mfma_variant: ScaledMMAType, - dynamic_shapes: bool, use_water_backend: bool, ): """End-to-end test for MXFP4 GEMM with preshuffled B data and B scales.""" gemm, options = get_tagged_mxfp4_gemm_preshuffle_b( shape, block_shape, - wave_shape=wave_shape, + wave_shape=(1, 4), mfma_variant=mfma_variant, ) - if dynamic_shapes: - for sym in [tkl.sym.M, tkl.sym.N, tkl.sym.K]: - del options.subs[sym] - options.dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] - options.wave_runtime = True schedule = get_mxfp4_asymmetric_schedule() options.minimize_shared_allocs = True options.linearize_shared_access = True @@ -892,6 +835,7 @@ def testScaledGemmMXFP4PreshuffleB( (64, 128, 256), (64, 128, 128), (64, 64, 128), + (64, 64, 256), (32, 192, 256), (32, 128, 256), (32, 64, 256), @@ -968,6 +912,7 @@ def testScaledGemmMXFP4PreshuffleMacrotiles( [ ((128, 256, 256), (1, 4)), ((32, 64, 256), (1, 4)), + ((64, 64, 256), (1, 4)), ((128, 32, 256), (2, 2)), ((256, 224, 256), (2, 2)), ], diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index 4611123761..c215a5b90f 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -283,13 +283,13 @@ def build_guarded_pipeline_with_remainder( max_induction_variable // rounding_stride ) * rounding_stride - # When the pipeline guard is false (max_iv < min_pipelined_iterations), + # When the pipeline guard is false (max_iv < rounding_stride), # the conditional returns init values and the remainder loop must start # from 0 rather than pipelined_iterations. remainder_start = sympy.Piecewise( ( pipelined_iterations, - sympy.Ge(max_induction_variable, min_pipelined_iterations), + sympy.Ge(max_induction_variable, rounding_stride), ), (0, True), ) From cae348ac4e6872098f8841f3a17136ab409bea82 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Fri, 6 Mar 2026 11:16:47 +0100 Subject: [PATCH 5/6] [water] promote the duplicate non-unit step verifier (#1050) Until now, we have only been verifying the absence of a second non-unit step in index expressions of read and write operations. Do so for every operation in the trait that attaches the attribute. This is not super-efficient as it requires looking up the attribute on the same parent from all operations, but guarantees the check to happen unlike using the attribute verifier which will not kick in in absence of the hyperparameters attribute even if we can see a problem. A better, longer-term solution is to introduce a top-level wave kernel operation where hyperparameters are mandatory. We can also go for a normal form that will perform a top-down verification collecting the attributes on the way. Closes https://github.com/iree-org/wave/issues/1013. --------- Signed-off-by: Alex Zinenko Signed-off-by: xintin --- .../water/Dialect/Wave/IR/WaveInterfaces.h | 4 ++ .../water/Dialect/Wave/Transforms/Utils.h | 4 -- water/lib/Dialect/Wave/IR/WaveDialect.cpp | 3 +- water/lib/Dialect/Wave/IR/WaveInterfaces.cpp | 45 +++++++++++++++++++ water/lib/Dialect/Wave/IR/WaveOps.cpp | 15 ++----- .../Dialect/Wave/Transforms/InferTypes.cpp | 17 +++---- water/lib/Dialect/Wave/Transforms/Utils.cpp | 9 ---- water/test/Dialect/Wave/ops-invalid.mlir | 22 +++++++-- .../Wave/propagate-elements-per-thread.mlir | 15 ------- 9 files changed, 78 insertions(+), 56 deletions(-) diff --git a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h index 2d2373d12f..256d7b2af2 100644 --- a/water/include/water/Dialect/Wave/IR/WaveInterfaces.h +++ b/water/include/water/Dialect/Wave/IR/WaveInterfaces.h @@ -26,6 +26,10 @@ using EmitErrorFn = llvm::function_ref; class WaveTensorType; +/// Get the hyperparameters from an ancestor operation. +/// Returns nullptr if no hyperparameters are found. +WaveHyperparameterAttr getHyperparameters(mlir::Operation *op); + //----------------------------------------------------------------------------- // HasWaveIndexMapping trait //----------------------------------------------------------------------------- diff --git a/water/include/water/Dialect/Wave/Transforms/Utils.h b/water/include/water/Dialect/Wave/Transforms/Utils.h index 4d3cc3692f..e57195160a 100644 --- a/water/include/water/Dialect/Wave/Transforms/Utils.h +++ b/water/include/water/Dialect/Wave/Transforms/Utils.h @@ -11,10 +11,6 @@ namespace wave { -/// Get the hyperparameters from an ancestor operation. -/// Returns nullptr if no hyperparameters are found. -WaveHyperparameterAttr getHyperparameters(mlir::Operation *op); - // Populates `constraints` with a mapping from an operation with a Wave // constraints attribute attached to that attribute. llvm::LogicalResult collectWaveConstraints( diff --git a/water/lib/Dialect/Wave/IR/WaveDialect.cpp b/water/lib/Dialect/Wave/IR/WaveDialect.cpp index 7a7e7218c0..2b525b222f 100644 --- a/water/lib/Dialect/Wave/IR/WaveDialect.cpp +++ b/water/lib/Dialect/Wave/IR/WaveDialect.cpp @@ -14,14 +14,13 @@ #include "mlir/IR/Dialect.h" #include "water/Dialect/Wave/IR/WaveDialect.cpp.inc" +#include "water/Dialect/Wave/IR/WaveInterfaces.h" #include "water/Dialect/Wave/IR/WaveUtils.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/LogicalResult.h" -#include #include using namespace mlir; diff --git a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp index f96aadeb6f..7125e466fa 100644 --- a/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp +++ b/water/lib/Dialect/Wave/IR/WaveInterfaces.cpp @@ -23,6 +23,19 @@ using namespace mlir; +//----------------------------------------------------------------------------- +// getHyperparameters +//----------------------------------------------------------------------------- + +wave::WaveHyperparameterAttr wave::getHyperparameters(Operation *op) { + for (Operation *current = op; current; current = current->getParentOp()) { + if (auto hyperparams = current->getAttrOfType( + WaveDialect::kHyperparameterAttrName)) + return hyperparams; + } + return nullptr; +} + //----------------------------------------------------------------------------- // Index attribute verification //----------------------------------------------------------------------------- @@ -86,6 +99,38 @@ LogicalResult wave::verifyWaveIndexMappings(Operation *op) { } } + // For ops with the index attribute, verify that each index expression has at + // most one dimension whose step evaluates to a static value different from 1 + // (with hyperparameters substituted). Structural checks stay in op verifiers. + wave::WaveHyperparameterAttr hyperparams = wave::getHyperparameters(op); + for (DictionaryAttr dictAttr : dicts) { + int nonUnitCount = 0; + for (const NamedAttribute &named : dictAttr) { + auto mapping = dyn_cast(named.getValue()); + if (!mapping || !mapping.getStep()) + continue; + + std::optional> stepValues = + wave::evaluateMapWithHyperparams(mapping.getStep(), + mapping.getSymbols(), hyperparams); + if (!stepValues || stepValues->size() != 1) + continue; + + int64_t step = (*stepValues)[0]; + if (step == 1 || step == ShapedType::kDynamic) + continue; + + if (++nonUnitCount > 1) { + InFlightDiagnostic diag = + op->emitOpError() << "'" << WaveDialect::kIndexWaveExprListAttrName + << "' has more than one entry with non-unit step"; + diag.attachNote() << "second non-unit step dimension: " + << named.getName(); + return failure(); + } + } + } + // When the operation implements WaveInferIndexExprsOpInterface, the index // attribute length must match the number of values from // getIndexExprValuesAndDescriptions. Otherwise, default to the number of op diff --git a/water/lib/Dialect/Wave/IR/WaveOps.cpp b/water/lib/Dialect/Wave/IR/WaveOps.cpp index aa5cbb5d70..d87ad5808e 100644 --- a/water/lib/Dialect/Wave/IR/WaveOps.cpp +++ b/water/lib/Dialect/Wave/IR/WaveOps.cpp @@ -1572,9 +1572,8 @@ ReadOp::propagateBackward(MutableArrayRef operandTypes, LogicalResult ReadOp::finalizeTypeInference() { return success(); } -// Check the well-formedness of the index attribute (must have at most one -// non-unit dimension) and its correspondence with the explicit elements per -// thread, if provided, and with the number of elements in the vector type. +// Check the correspondence of the index attribute with the explicit elements +// per thread, if provided, and with the number of elements in the vector type. static LogicalResult verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr, std::optional elementsPerThread, @@ -1605,7 +1604,7 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr, if (!indexDict) return success(); - wave::WaveHyperparameterAttr hyper = wave::WaveHyperparameterAttr(); + wave::WaveHyperparameterAttr hyper = nullptr; for (Operation *cur = op; cur != nullptr && !hyper; cur = cur->getParentOp()) { hyper = cur->getAttrOfType( @@ -1621,7 +1620,7 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr, getUncollapsedVectorShape(tensorType.getShape(), indexDict, hyper); int64_t nonUnit = 1; bool hadDynamic = false; - for (auto [i, size] : llvm::enumerate(shape)) { + for (int64_t size : shape) { if (ShapedType::isDynamic(size)) { hadDynamic = true; continue; @@ -1632,13 +1631,7 @@ verifyIndexElementsPerThread(Operation *op, ArrayAttr indexAttr, } if (nonUnit == 1) { nonUnit = size; - continue; } - - InFlightDiagnostic diag = - op->emitError() << "'index' has more than one entry with non-unit step"; - diag.attachNote() << "second non-unit step dimension: " << i; - return diag; } // If there were unevaluated steps, they may end up matching later on. diff --git a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp index cc8126c53d..61362d1a5e 100644 --- a/water/lib/Dialect/Wave/Transforms/InferTypes.cpp +++ b/water/lib/Dialect/Wave/Transforms/InferTypes.cpp @@ -829,18 +829,11 @@ class ElementsPerThreadForwardAnalysis // Elements per thread may be 1 if _all_ dimensions have a unit step, // otherwise it should be the one non-unit step. - // TODO(#1013): this logic can be reused in the verifier. - if (!elementsPerThread.has_value()) { - elementsPerThread = (*stepValues)[0]; - } else if (*elementsPerThread == 1) { - elementsPerThread = (*stepValues)[0]; - } else if (stepValue != 1) { - // TODO(#1013): turn this into an assertion when the verifier is - // implemented. - op->emitError() << "expected only one non-unit index step, found " - << (*stepValues)[0] << " and " << *elementsPerThread - << " (missing verifier)"; - return WalkResult::interrupt(); + assert((!elementsPerThread.has_value() || *elementsPerThread == 1 || + stepValue == 1) && + "expected only one non-unit index step"); + if (!elementsPerThread.has_value() || *elementsPerThread == 1) { + elementsPerThread = stepValue; } } diff --git a/water/lib/Dialect/Wave/Transforms/Utils.cpp b/water/lib/Dialect/Wave/Transforms/Utils.cpp index 25003e33e0..7234ff5405 100644 --- a/water/lib/Dialect/Wave/Transforms/Utils.cpp +++ b/water/lib/Dialect/Wave/Transforms/Utils.cpp @@ -21,15 +21,6 @@ using namespace mlir; -wave::WaveHyperparameterAttr wave::getHyperparameters(Operation *op) { - for (Operation *current = op; current; current = current->getParentOp()) { - if (auto hyperparams = current->getAttrOfType( - WaveDialect::kHyperparameterAttrName)) - return hyperparams; - } - return nullptr; -} - llvm::LogicalResult wave::collectWaveConstraints( Operation *top, llvm::DenseMap &constraints) { auto *waveDialect = top->getContext()->getLoadedDialect(); diff --git a/water/test/Dialect/Wave/ops-invalid.mlir b/water/test/Dialect/Wave/ops-invalid.mlir index da3ba333b1..0a21325099 100644 --- a/water/test/Dialect/Wave/ops-invalid.mlir +++ b/water/test/Dialect/Wave/ops-invalid.mlir @@ -602,9 +602,11 @@ func.func @bounds_wrong_rank(%mem: !wave.tensor<[@N] of f32>) { // ----- -func.func @read_index_multi_step(%mem: !wave.tensor<[@M, @N] of f32>) { +func.func @read_index_multi_step(%mem: !wave.tensor<[@M, @N] of f32>) attributes { + wave.hyperparameters = #wave.hyperparameters<{M = 1, N = 1}> +} { // expected-error @below {{'index' has more than one entry with non-unit step}} - // expected-note @below {{second non-unit step dimension: 1}} + // expected-note @below {{second non-unit step dimension: "N"}} wave.read %mem index [{ M : <[#wave.index_symbol] -> (T0, 2, 1)>, N : <[#wave.index_symbol] -> (T1, 2, 1)> @@ -643,7 +645,7 @@ func.func @read_index_multi_step_eval(%mem: !wave.tensor<[@M, @N] of f32>) attri wave.hyperparameters = #wave.hyperparameters<{X = 1, Y = 1, M = 100, N = 200}> } { // expected-error @below {{'index' has more than one entry with non-unit step}} - // expected-note @below {{second non-unit step dimension: 1}} + // expected-note @below {{second non-unit step dimension: "N"}} wave.read %mem index [{ M : <[#wave.index_symbol, #wave.symbol<"X">] -> (T0, 2 * X, 1)>, N : <[#wave.index_symbol, #wave.symbol<"X">, #wave.symbol<"Y">] -> (T1, X + Y, 1)> @@ -653,6 +655,20 @@ func.func @read_index_multi_step_eval(%mem: !wave.tensor<[@M, @N] of f32>) attri // ----- +func.func @write_index_multi_step_eval(%val: !wave.tensor<[@M, @N] of f32, >, %mem: !wave.tensor<[@M, @N] of f32, >) attributes { + wave.hyperparameters = #wave.hyperparameters<{X = 1, Y = 1, M = 100, N = 200}> +} { + // expected-error @below {{'index' has more than one entry with non-unit step}} + // expected-note @below {{second non-unit step dimension: "N"}} + wave.write %val, %mem index [{ + M : <[#wave.index_symbol, #wave.symbol<"X">] -> (T0, 2 * X, 1)>, + N : <[#wave.index_symbol, #wave.symbol<"X">, #wave.symbol<"Y">] -> (T1, X + Y, 1)> + }] : !wave.tensor<[@M, @N] of f32, >, !wave.tensor<[@M, @N] of f32, > + return +} + +// ----- + func.func @extract_invalid_position_rank(%src: !wave.tensor<[@M, @N] of f32>) { // expected-error @below {{position must contain exactly one expression, but got 2}} wave.extract %src[#wave.expr_list<[] -> (0, 1)>] : (!wave.tensor<[@M, @N] of f32>) -> !wave.tensor<[@M] of f32> diff --git a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir index f3ab0fd78a..cc680e9280 100644 --- a/water/test/Dialect/Wave/propagate-elements-per-thread.mlir +++ b/water/test/Dialect/Wave/propagate-elements-per-thread.mlir @@ -148,21 +148,6 @@ normalform.module [#wave.normal_form] { // ----- -// Two dimensions with non-unit steps; pass must report "expected only one non-unit". -// Use only register (write has its own verifier for multi-step index). -normalform.module [#wave.normal_form] { - func.func @index_multi_non_unit_step(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 64}>, wave.constraints = []} { - %cst = arith.constant 0.0 : f16 - // expected-error @below {{expected only one non-unit index step}} - %reg = wave.register %cst index [{M : <[] -> (, 4, )>, N : <[] -> (, 8, )>}] : !wave.tensor<[@M, @N] of f16, > - wave.write %reg, %mem index [{M : <[] -> (, 4, )>, N : <[] -> (, 1, )>}] - : !wave.tensor<[@M, @N] of f16, >, !wave.tensor<[@M, @N] of f16, > - return - } -} - -// ----- - // Index missing dimension N for result type [M, N]; pass must report missing dimensions. normalform.module [#wave.normal_form] { func.func @index_missing_dimension(%mem: !wave.tensor<[@M, @N] of f16, >) attributes {wave.hyperparameters = #wave.hyperparameters<{M = 128, N = 64}>, wave.constraints = []} { From 7a21b68432c48d542b40295574b52c7cec8f4058 Mon Sep 17 00:00:00 2001 From: xintin Date: Fri, 6 Mar 2026 17:26:52 +0000 Subject: [PATCH 6/6] remove remainder loop start fix (moved to separate PR) The schedule.py changes are now in xintin/fix_dynamic_pipeline_remainder_loop_start. Signed-off-by: xintin Made-with: Cursor Signed-off-by: xintin --- wave_lang/kernel/wave/scheduling/schedule.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index c215a5b90f..3a299344fb 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -283,17 +283,6 @@ def build_guarded_pipeline_with_remainder( max_induction_variable // rounding_stride ) * rounding_stride - # When the pipeline guard is false (max_iv < rounding_stride), - # the conditional returns init values and the remainder loop must start - # from 0 rather than pipelined_iterations. - remainder_start = sympy.Piecewise( - ( - pipelined_iterations, - sympy.Ge(max_induction_variable, rounding_stride), - ), - (0, True), - ) - conditional_body_graph, body_old_to_new = graph_copy(reduction_graph) placeholder_init_args = [placeholders[arg] for arg in reduction.init_args] placeholder_captures = [placeholders[cap] for cap in reduction.implicit_captures] @@ -406,11 +395,11 @@ def build_guarded_pipeline_with_remainder( main_graph.subgraphs[remainder_subgraph_name] = remainder_graph trace.region_graph.subgraphs[remainder_subgraph_name] = remainder_graph - # Create a scalar node for the starting iteration. - # When the pipeline ran, this is pipelined_iterations; otherwise 0. + # Create a scalar node for the starting iteration (where pipelined loop ended) + # This will be pipelined_iterations with main_graph.inserting_before(reduction.fx_node): start_iter = get_graph_node( - NewScalar(remainder_start, tkl.index), main_graph, reduction.location + NewScalar(pipelined_iterations, tkl.index), main_graph, reduction.location ) remainder_reduction = Iterate(