Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/kernel/wave_gemm_mxfp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,12 @@ def testScaledGemmMXFP4PreshuffleMacrotiles(
)
@pytest.mark.parametrize(
"block_shape,wave_shape",
[((128, 256, 256), (1, 4)), ((128, 32, 256), (2, 2)), ((256, 224, 256), (2, 2))],
[
((128, 256, 256), (1, 4)),
((32, 64, 256), (1, 4)),
((128, 32, 256), (2, 2)),
((256, 224, 256), (2, 2)),
],
)
@pytest.mark.parametrize(
"mfma_variant",
Expand Down
3 changes: 2 additions & 1 deletion wave_lang/kernel/compiler/wave_codegen/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1139,13 +1139,14 @@ def _get_const(val):

# Build nested select operations
# Start with the last expression (typically the default/else case)
result = cases[-1][1]
result = _resolve_rational(cases[-1][1])

# Work backwards through earlier cases to build nested selects
# Piecewise((expr1, cond1), (expr2, cond2), (expr3, True)) becomes:
# select(cond1, expr1, select(cond2, expr2, expr3))
for i in range(len(cases) - 2, -1, -1):
cond, expr = cases[i]
expr = _resolve_rational(expr)
result = arith_d.select(cond, *_broadcast(expr, result))

stack.append(result)
Expand Down
2 changes: 2 additions & 0 deletions wave_lang/kernel/ops/wave_schedule_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@ def __init__(
self.kernel_trace = kernel_trace
self.constraints = constraints
self.multi_buffer_count = multi_buffer_count
self.unroll_factor = 1

# Access options from the current ScheduleContext
from .._support.tracing import ScheduleContext
Expand Down Expand Up @@ -875,6 +876,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
scheduling_type=SchedulingType.MANUAL,
visualize=False,
multi_buffer_count=self.multi_buffer_count,
unroll_factor=self.unroll_factor,
)

# Store the pipelined iterate node and node mapping, then create proxies for the stages
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,7 @@ def mxfp4_dbuf_schedule():

# This forces the pipeline to use double buffering
pipeline_loop.multi_buffer_count = 2
pipeline_loop.unroll_factor = 2

with pipeline_loop as pl:
pl.set_stage(
Expand Down
43 changes: 33 additions & 10 deletions wave_lang/kernel/wave/scheduling/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,13 @@ def build_guarded_pipeline_with_remainder(
visualize: bool = False,
use_scheduling_barriers: bool = False,
multi_buffer_count: Optional[int] = None,
unroll_factor: int = 1,
):
"""
Build conditional + pipelined loop + remainder loop for dynamic shapes.

Structure:
if (max_induction_variable >= num_stages):
if (max_induction_variable >= num_stages + unroll_factor - 1):
pipelined_result = pipelined_loop_with_prologue_epilogue()
else:
pipelined_result = init_values
Expand All @@ -232,16 +233,24 @@ def build_guarded_pipeline_with_remainder(
original_init_args = reduction.init_args
main_graph = reduction.graph

# Create condition: max_induction_variable >= num_stages
if unroll_factor < 1:
raise ValueError(f"Expected unroll_factor >= 1, got {unroll_factor}")

from math import lcm

rounding_stride = lcm(num_stages, unroll_factor)

# Need at least rounding_stride iterations so the rounded-down
# pipelined_iterations is non-zero.
with main_graph.inserting_before(reduction.fx_node):
num_stages_scalar = get_graph_node(
NewScalar(num_stages, tkl.i32), main_graph, reduction.location
min_iters_scalar = get_graph_node(
NewScalar(rounding_stride, tkl.i32), main_graph, reduction.location
)
num_iters_scalar = get_graph_node(
NewScalar(max_induction_variable, tkl.i32), main_graph, reduction.location
)
condition = get_graph_node(
Ge(num_iters_scalar, num_stages_scalar), main_graph, reduction.location
Ge(num_iters_scalar, min_iters_scalar), main_graph, reduction.location
)

# Prepare conditional subgraph
Expand All @@ -257,12 +266,22 @@ def build_guarded_pipeline_with_remainder(
)
)

# Compute the number of iterations the pipelined loop should process:
# This ensures the pipelined loop only processes complete pipeline stages
# Round pipelined_iterations to a multiple of lcm(num_stages, unroll_factor).
# This ensures complete pipeline stages (multiple of num_stages) and that
# the kernel count (pipelined_iterations - (num_stages - 1)) is divisible
# by unroll_factor, preventing the unrolled scf.for from executing extra
# iterations that read invalid pipeline state.
from math import lcm

rounding_stride = lcm(num_stages, unroll_factor)
if isinstance(max_induction_variable, (int, float)):
pipelined_iterations = (int(max_induction_variable) // num_stages) * num_stages
pipelined_iterations = (
int(max_induction_variable) // rounding_stride
) * rounding_stride
else:
pipelined_iterations = (max_induction_variable // num_stages) * num_stages
pipelined_iterations = (
max_induction_variable // rounding_stride
) * rounding_stride

conditional_body_graph, body_old_to_new = graph_copy(reduction_graph)
placeholder_init_args = [placeholders[arg] for arg in reduction.init_args]
Expand Down Expand Up @@ -420,12 +439,13 @@ def construct_pipelined_loop_adaptive(
visualize: bool = False,
use_scheduling_barriers: bool = False,
multi_buffer_count: Optional[int] = None,
unroll_factor: int = 1,
):
"""
Constructs a pipelined loop wrapped in a conditional, followed by a remainder loop.

Structure:
if (num_iterations >= num_stages):
    if (num_iterations >= lcm(num_stages, unroll_factor)):
prologue
pipelined_loop
epilogue
Expand Down Expand Up @@ -477,6 +497,7 @@ def construct_pipelined_loop_adaptive(
visualize,
use_scheduling_barriers,
multi_buffer_count,
unroll_factor,
)


Expand All @@ -491,6 +512,7 @@ def apply_pipelined_schedule(
scheduling_type: SchedulingType = SchedulingType.NONE,
visualize: bool = False,
multi_buffer_count: Optional[int] = None,
unroll_factor: int = 1,
) -> Optional[tuple[fx.Node, dict]]:

# After scheduling has completed, we have enough information to decide
Expand Down Expand Up @@ -525,6 +547,7 @@ def apply_pipelined_schedule(
visualize,
use_scheduling_barriers,
multi_buffer_count,
unroll_factor,
)


Expand Down