diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 892394f249..1eadcfe494 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -907,7 +907,12 @@ def testScaledGemmMXFP4PreshuffleMacrotiles( ) @pytest.mark.parametrize( "block_shape,wave_shape", - [((128, 256, 256), (1, 4)), ((128, 32, 256), (2, 2)), ((256, 224, 256), (2, 2))], + [ + ((128, 256, 256), (1, 4)), + ((32, 64, 256), (1, 4)), + ((128, 32, 256), (2, 2)), + ((256, 224, 256), (2, 2)), + ], ) @pytest.mark.parametrize( "mfma_variant", diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index da11cbafda..c97d311582 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -1139,13 +1139,14 @@ def _get_const(val): # Build nested select operations # Start with the last expression (typically the default/else case) - result = cases[-1][1] + result = _resolve_rational(cases[-1][1]) # Work backwards through earlier cases to build nested selects # Piecewise((expr1, cond1), (expr2, cond2), (expr3, True)) becomes: # select(cond1, expr1, select(cond2, expr2, expr3)) for i in range(len(cases) - 2, -1, -1): cond, expr = cases[i] + expr = _resolve_rational(expr) result = arith_d.select(cond, *_broadcast(expr, result)) stack.append(result) diff --git a/wave_lang/kernel/ops/wave_schedule_ops.py b/wave_lang/kernel/ops/wave_schedule_ops.py index 159a1549e4..da8bfae40b 100755 --- a/wave_lang/kernel/ops/wave_schedule_ops.py +++ b/wave_lang/kernel/ops/wave_schedule_ops.py @@ -810,6 +810,7 @@ def __init__( self.kernel_trace = kernel_trace self.constraints = constraints self.multi_buffer_count = multi_buffer_count + self.unroll_factor = 1 # Access options from the current ScheduleContext from .._support.tracing import ScheduleContext @@ -875,6 +876,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): 
scheduling_type=SchedulingType.MANUAL, visualize=False, multi_buffer_count=self.multi_buffer_count, + unroll_factor=self.unroll_factor, ) # Store the pipelined iterate node and node mapping, then create proxies for the stages diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index cc422b3540..186f6bcf17 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1411,6 +1411,7 @@ def mxfp4_dbuf_schedule(): # This forces the pipeline to use double buffering pipeline_loop.multi_buffer_count = 2 + pipeline_loop.unroll_factor = 2 with pipeline_loop as pl: pl.set_stage( diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index 11cb374c42..3a299344fb 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -214,12 +214,13 @@ def build_guarded_pipeline_with_remainder( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + unroll_factor: int = 1, ): """ Build conditional + pipelined loop + remainder loop for dynamic shapes. Structure: - if (max_induction_variable >= num_stages): + if (max_induction_variable >= lcm(num_stages, unroll_factor)): pipelined_result = pipelined_loop_with_prologue_epilogue() else: pipelined_result = init_values @@ -232,16 +233,24 @@ def build_guarded_pipeline_with_remainder( original_init_args = reduction.init_args main_graph = reduction.graph - # Create condition: max_induction_variable >= num_stages + if unroll_factor < 1: + raise ValueError(f"Expected unroll_factor >= 1, got {unroll_factor}") + + from math import lcm + + rounding_stride = lcm(num_stages, unroll_factor) + + # Need at least rounding_stride iterations so the rounded-down + # pipelined_iterations is non-zero. 
with main_graph.inserting_before(reduction.fx_node): - num_stages_scalar = get_graph_node( - NewScalar(num_stages, tkl.i32), main_graph, reduction.location + min_iters_scalar = get_graph_node( + NewScalar(rounding_stride, tkl.i32), main_graph, reduction.location ) num_iters_scalar = get_graph_node( NewScalar(max_induction_variable, tkl.i32), main_graph, reduction.location ) condition = get_graph_node( - Ge(num_iters_scalar, num_stages_scalar), main_graph, reduction.location + Ge(num_iters_scalar, min_iters_scalar), main_graph, reduction.location ) # Prepare conditional subgraph @@ -257,12 +266,22 @@ def build_guarded_pipeline_with_remainder( ) ) - # Compute the number of iterations the pipelined loop should process: - # This ensures the pipelined loop only processes complete pipeline stages + # Round pipelined_iterations to a multiple of lcm(num_stages, unroll_factor). + # This ensures complete pipeline stages (multiple of num_stages) and that + # the kernel count (pipelined_iterations - (num_stages - 1)) is divisible + # by unroll_factor, preventing the unrolled scf.for from executing extra + # iterations that read invalid pipeline state. 
+ from math import lcm + + rounding_stride = lcm(num_stages, unroll_factor) if isinstance(max_induction_variable, (int, float)): - pipelined_iterations = (int(max_induction_variable) // num_stages) * num_stages + pipelined_iterations = ( + int(max_induction_variable) // rounding_stride + ) * rounding_stride else: - pipelined_iterations = (max_induction_variable // num_stages) * num_stages + pipelined_iterations = ( + max_induction_variable // rounding_stride + ) * rounding_stride conditional_body_graph, body_old_to_new = graph_copy(reduction_graph) placeholder_init_args = [placeholders[arg] for arg in reduction.init_args] @@ -420,12 +439,13 @@ def construct_pipelined_loop_adaptive( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + unroll_factor: int = 1, ): """ Constructs a pipelined loop wrapped in a conditional, followed by a remainder loop. Structure: - if (num_iterations >= num_stages): + if (num_iterations >= lcm(num_stages, unroll_factor)): prologue pipelined_loop epilogue @@ -477,6 +497,7 @@ def construct_pipelined_loop_adaptive( visualize, use_scheduling_barriers, multi_buffer_count, + unroll_factor, ) @@ -491,6 +512,7 @@ def apply_pipelined_schedule( scheduling_type: SchedulingType = SchedulingType.NONE, visualize: bool = False, multi_buffer_count: Optional[int] = None, + unroll_factor: int = 1, ) -> Optional[tuple[fx.Node, dict]]: # After scheduling has completed, we have enough information to decide @@ -525,6 +547,7 @@ def apply_pipelined_schedule( visualize, use_scheduling_barriers, multi_buffer_count, + unroll_factor, )