From 2e311ee0482d100252203724e3b1110e576ef442 Mon Sep 17 00:00:00 2001 From: harsh-nod Date: Thu, 5 Mar 2026 23:40:14 +0000 Subject: [PATCH] [compiler] Fix Piecewise Rational codegen and pipeline unroll bug for dynamic shapes Handle _Rational values in Piecewise/select expressions by resolving them via floor division before passing to arith.select, which cannot operate on deferred _Rational values. Fix pipeline unroll interaction for dynamic shapes where the kernel loop count was not divisible by the unroll factor, causing the scf.for loop to execute extra iterations that read invalid pipeline state. Thread unroll_factor through PipelinedLoop to build_guarded_pipeline_with_remainder and use lcm(num_stages, unroll_factor) as the rounding stride for pipelined_iterations. Add (32,64,256) block shape to testScaledGemmMXFP4PreshuffleBDynamic. Signed-off-by: Harsh Menon --- tests/kernel/wave_gemm_mxfp_test.py | 7 ++- .../kernel/compiler/wave_codegen/emitter.py | 3 +- wave_lang/kernel/ops/wave_schedule_ops.py | 2 + .../schedules/gemm_mxfp4_double_buffer.py | 1 + wave_lang/kernel/wave/scheduling/schedule.py | 43 ++++++++++++++----- 5 files changed, 44 insertions(+), 12 deletions(-) diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 892394f24..1eadcfe49 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -907,7 +907,12 @@ def testScaledGemmMXFP4PreshuffleMacrotiles( ) @pytest.mark.parametrize( "block_shape,wave_shape", - [((128, 256, 256), (1, 4)), ((128, 32, 256), (2, 2)), ((256, 224, 256), (2, 2))], + [ + ((128, 256, 256), (1, 4)), + ((32, 64, 256), (1, 4)), + ((128, 32, 256), (2, 2)), + ((256, 224, 256), (2, 2)), + ], ) @pytest.mark.parametrize( "mfma_variant", diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index da11cbafd..c97d31158 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ 
b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -1139,13 +1139,14 @@ def _get_const(val): # Build nested select operations # Start with the last expression (typically the default/else case) - result = cases[-1][1] + result = _resolve_rational(cases[-1][1]) # Work backwards through earlier cases to build nested selects # Piecewise((expr1, cond1), (expr2, cond2), (expr3, True)) becomes: # select(cond1, expr1, select(cond2, expr2, expr3)) for i in range(len(cases) - 2, -1, -1): cond, expr = cases[i] + expr = _resolve_rational(expr) result = arith_d.select(cond, *_broadcast(expr, result)) stack.append(result) diff --git a/wave_lang/kernel/ops/wave_schedule_ops.py b/wave_lang/kernel/ops/wave_schedule_ops.py index 159a1549e..da8bfae40 100755 --- a/wave_lang/kernel/ops/wave_schedule_ops.py +++ b/wave_lang/kernel/ops/wave_schedule_ops.py @@ -810,6 +810,7 @@ def __init__( self.kernel_trace = kernel_trace self.constraints = constraints self.multi_buffer_count = multi_buffer_count + self.unroll_factor = 1 # Access options from the current ScheduleContext from .._support.tracing import ScheduleContext @@ -875,6 +876,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): scheduling_type=SchedulingType.MANUAL, visualize=False, multi_buffer_count=self.multi_buffer_count, + unroll_factor=self.unroll_factor, ) # Store the pipelined iterate node and node mapping, then create proxies for the stages diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index cc422b354..186f6bcf1 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1411,6 +1411,7 @@ def mxfp4_dbuf_schedule(): # This forces the pipeline to use double buffering pipeline_loop.multi_buffer_count = 2 + pipeline_loop.unroll_factor = 2 with pipeline_loop as pl: pl.set_stage( diff --git a/wave_lang/kernel/wave/scheduling/schedule.py 
b/wave_lang/kernel/wave/scheduling/schedule.py index 11cb374c4..3a299344f 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -214,12 +214,13 @@ def build_guarded_pipeline_with_remainder( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + unroll_factor: int = 1, ): """ Build conditional + pipelined loop + remainder loop for dynamic shapes. Structure: - if (max_induction_variable >= num_stages): + if (max_induction_variable >= lcm(num_stages, unroll_factor)): pipelined_result = pipelined_loop_with_prologue_epilogue() else: pipelined_result = init_values @@ -232,16 +233,24 @@ def build_guarded_pipeline_with_remainder( original_init_args = reduction.init_args main_graph = reduction.graph - # Create condition: max_induction_variable >= num_stages + if unroll_factor < 1: + raise ValueError(f"Expected unroll_factor >= 1, got {unroll_factor}") + + from math import lcm + + rounding_stride = lcm(num_stages, unroll_factor) + + # Need at least rounding_stride iterations so the rounded-down + # pipelined_iterations is non-zero. 
with main_graph.inserting_before(reduction.fx_node): - num_stages_scalar = get_graph_node( - NewScalar(num_stages, tkl.i32), main_graph, reduction.location + min_iters_scalar = get_graph_node( + NewScalar(rounding_stride, tkl.i32), main_graph, reduction.location ) num_iters_scalar = get_graph_node( NewScalar(max_induction_variable, tkl.i32), main_graph, reduction.location ) condition = get_graph_node( - Ge(num_iters_scalar, num_stages_scalar), main_graph, reduction.location + Ge(num_iters_scalar, min_iters_scalar), main_graph, reduction.location ) # Prepare conditional subgraph @@ -257,12 +266,22 @@ def build_guarded_pipeline_with_remainder( ) ) - # Compute the number of iterations the pipelined loop should process: - # This ensures the pipelined loop only processes complete pipeline stages + # Round pipelined_iterations to a multiple of lcm(num_stages, unroll_factor). + # This ensures complete pipeline stages (multiple of num_stages) and that + # the kernel count (pipelined_iterations - (num_stages - 1)) is divisible + # by unroll_factor, preventing the unrolled scf.for from executing extra + # iterations that read invalid pipeline state. 
+ from math import lcm + + rounding_stride = lcm(num_stages, unroll_factor) if isinstance(max_induction_variable, (int, float)): - pipelined_iterations = (int(max_induction_variable) // num_stages) * num_stages + pipelined_iterations = ( + int(max_induction_variable) // rounding_stride + ) * rounding_stride else: - pipelined_iterations = (max_induction_variable // num_stages) * num_stages + pipelined_iterations = ( + max_induction_variable // rounding_stride + ) * rounding_stride conditional_body_graph, body_old_to_new = graph_copy(reduction_graph) placeholder_init_args = [placeholders[arg] for arg in reduction.init_args] @@ -420,12 +439,13 @@ def construct_pipelined_loop_adaptive( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + unroll_factor: int = 1, ): """ Constructs a pipelined loop wrapped in a conditional, followed by a remainder loop. Structure: - if (num_iterations >= num_stages): + if (num_iterations >= lcm(num_stages, unroll_factor)): prologue pipelined_loop epilogue @@ -477,6 +497,7 @@ def construct_pipelined_loop_adaptive( visualize, use_scheduling_barriers, multi_buffer_count, + unroll_factor, ) @@ -491,6 +512,7 @@ def apply_pipelined_schedule( scheduling_type: SchedulingType = SchedulingType.NONE, visualize: bool = False, multi_buffer_count: Optional[int] = None, + unroll_factor: int = 1, ) -> Optional[tuple[fx.Node, dict]]: # After scheduling has completed, we have enough information to decide @@ -525,6 +547,7 @@ def apply_pipelined_schedule( visualize, use_scheduling_barriers, multi_buffer_count, + unroll_factor, )