From ca7adb3709aef21d0032e120163405ca7e48d489 Mon Sep 17 00:00:00 2001 From: xintin Date: Wed, 25 Feb 2026 14:17:15 +0000 Subject: [PATCH 1/9] epilogue elimination Signed-off-by: xintin --- examples/python/7.1_schedule.py | 24 +- .../compiler/wave_codegen/read_write.py | 71 ++++- wave_lang/kernel/ops/wave_schedule_ops.py | 16 +- wave_lang/kernel/wave/compile.py | 5 + wave_lang/kernel/wave/compile_options.py | 1 + .../schedules/gemm_mxfp4_double_buffer.py | 289 +++++++++--------- .../wave/scheduling/loop_reconstruction.py | 153 ++++++++-- wave_lang/kernel/wave/scheduling/schedule.py | 37 ++- 8 files changed, 408 insertions(+), 188 deletions(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 693f3af92a..88f1ab217a 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -213,10 +213,11 @@ def test_dbuf_4wave_mxfp_asymmetric_gemm( def test_dbuf_4wave_mxfp_preshuffle_b_gemm( - is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256) + is_debug=False, shape=(1024, 1024, 8192), block=(128, 256, 256) ): """Asymmetric MXFP4 GEMM with preshuffled B data and B scales.""" gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) + print(block) options.minimize_shared_allocs = True options.linearize_shared_access = True options.use_buffer_ops = True @@ -231,6 +232,27 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm( print("MXFP GEMM preshuffle-B 4-wave test passed!") +def test_dbuf_4wave_mxfp_preshuffle_b_no_epilogue_gemm( + is_debug=False, shape=(1024, 2048, 8192), block=(128, 256, 256) +): + """Asymmetric MXFP4 GEMM with preshuffled B, epilogue eliminated via OOB=0.""" + gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) + options.minimize_shared_allocs = True + options.linearize_shared_access = True + options.use_buffer_ops = True + options.eliminate_epilogue = True + options.dump_intermediates = "build/intermediates" + schedule = get_mxfp4_asymmetric_schedule(eliminate_epilogue=True, is_bscale_shuffled=True) + + options.print_ir_after = "all" if is_debug else [] + options = set_default_run_config(options) + print(block) + gemm = wave_compile(options, gemm, schedule) + + _run_mxfp_gemm_preshuffle(gemm, shape, all=True) + print("MXFP GEMM preshuffle-B no-epilogue 4-wave test passed!") + + def test_dbuf_4wave_mxfp_asymmetric_gemm_cpp( is_debug=False, shape=(1024, 1024, 8192), block=(128, 256, 256) ): diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 110c5ffee5..c64830396c 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -356,16 +356,53 @@ def _get_constant_value(candidate: Value): def _cast_buffer_and_encode_stride( - ptr: Value, strides: tuple[Value], elem_type: IrType, emitter: WaveEmitter + ptr: Value, + strides: tuple[Value], + elem_type: IrType, + emitter: WaveEmitter, + symbolic_shape: tuple = None, + is_read: bool = True, ) -> Value: uint64 = IntegerType.get_signless(64) uint14 = IntegerType.get_signless(14) + elem_bytes = elem_type.width // 8 + + if emitter.options.use_real_buffer_bounds and symbolic_shape is not None: + total_elements = subs_idxc(sympy.prod(s for s in symbolic_shape)) + if isinstance(total_elements, (int, float)) or ( + hasattr(total_elements, "is_number") and total_elements.is_number + ): + total_bytes = int(total_elements) * elem_bytes + max_valid = _valid_bytes_buffer(elem_type) + total_bytes = min(total_bytes, max_valid) + else: + total_bytes = _valid_bytes_buffer(elem_type) + else: + total_bytes = _valid_bytes_buffer(elem_type) + + # With resetOffset, the SRD base is adjusted forward by the memref's + # offset, so validBytes must be total_bytes - offset_bytes to avoid + # over-reporting the valid range past the actual allocation. + # + # For writes, skip the offset subtraction: real buffer bounds are only + # useful for reads (OOB loads return 0), and subtracting the offset + # from a clamped total_bytes can produce values that overflow the + # 32-bit SRD NUM_RECORDS field for output buffers larger than 2 GB. + if is_read: + metadata = memref_d.extract_strided_metadata(ptr) + offset_elements = metadata[1] + offset_bytes = arith_d.index_cast(uint64, offset_elements) + elem_bytes_val = arith_d.constant(uint64, get_constant_attr(elem_bytes, uint64)) + offset_bytes = arith_d.muli(offset_bytes, elem_bytes_val) + total_bytes_val = arith_d.constant( + uint64, get_constant_attr(total_bytes, uint64) + ) + valid_bytes_val = arith_d.subi(total_bytes_val, offset_bytes) + else: + valid_bytes_val = arith_d.constant( + uint64, get_constant_attr(total_bytes, uint64) + ) - valid_bytes = _valid_bytes_buffer( - elem_type - ) # max bytes that are in range to be addressed from a buffer - valid_bytes_constant = get_constant_attr(valid_bytes, uint64) - valid_bytes_constant = arith_d.constant(uint64, valid_bytes_constant) stride_rank = len(strides) swizzle_stride = None @@ -373,8 +410,8 @@ def _cast_buffer_and_encode_stride( # fastest_dim_bound == second to last stride. stride_candidate = strides[-2] stride_int = _get_constant_value(stride_candidate) - # Only swizzle if stride is static and <= 8192(the useful case). - if stride_int and stride_int <= 8192: + # Only swizzle if stride is static and fits in signed i14. + if stride_int and stride_int < 8192: swizzle_stride = arith_d.index_cast(uint14, stride_candidate) if swizzle_stride: @@ -383,14 +420,14 @@ def _cast_buffer_and_encode_stride( cache_swizzle_stride=swizzle_stride, bounds_check=True, reset_offset=True, - valid_bytes=valid_bytes_constant, + valid_bytes=valid_bytes_val, ) else: ptr = amdgpu_d.fat_raw_buffer_cast( ptr, bounds_check=True, reset_offset=True, - valid_bytes=valid_bytes_constant, + valid_bytes=valid_bytes_val, ) return ptr @@ -501,10 +538,8 @@ def extract(vec, ind): mem, offset_th = _linearize_memref( mem, start_indices_wg, start_indices_th, strides ) - mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter) - elif is_global_mem and not is_read: - mem, offset_th = _linearize_memref( - mem, start_indices_wg, start_indices_th, strides + mem = _cast_buffer_and_encode_stride( + mem, strides, element_type, emitter, symbolic_shape, is_read ) if linearize_shared_mem: mem = _linearize_shared_mem(mem) @@ -546,7 +581,9 @@ def extract(vec, ind): mem, offset_th = _linearize_memref( mem, start_indices_wg, start_indices_th, strides ) - mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter) + mem = _cast_buffer_and_encode_stride( + mem, strides, element_type, emitter, symbolic_shape, is_read + ) indices = [offset_th] if buffer_ops_enabled else start_indices @@ -1102,7 +1139,9 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): ] src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides) - src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter) + src = _cast_buffer_and_encode_stride( + src, strides, element_type, emitter, src_symbolic_shape + ) # We previously checked mask is same for all elements, so we can use # elements_per_thread=1 to build the mask. diff --git a/wave_lang/kernel/ops/wave_schedule_ops.py b/wave_lang/kernel/ops/wave_schedule_ops.py index ea6a42bc0b..dceb592eb9 100755 --- a/wave_lang/kernel/ops/wave_schedule_ops.py +++ b/wave_lang/kernel/ops/wave_schedule_ops.py @@ -183,7 +183,7 @@ def reorder_graph(loop: Any, clusters: Any): ... @define_schedule_op -def pipeline(iterate: Sequence[fx.Node], multi_buffer_count: Optional[int] = None): ... +def pipeline(iterate: Sequence[fx.Node], eliminate_epilogue: bool = False, multi_buffer_count: Optional[int] = None): ... @define_schedule_op @@ -801,11 +801,13 @@ def __init__( iterate: Sequence[fx.Node], kernel_trace: "CapturedTrace", constraints: list[Constraint], + eliminate_epilogue: bool = False, multi_buffer_count: Optional[int] = None, ): self.iterate = iterate self.kernel_trace = kernel_trace self.constraints = constraints + self.eliminate_epilogue = eliminate_epilogue self.multi_buffer_count = multi_buffer_count # Access options from the current ScheduleContext @@ -872,6 +874,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): scheduling_type=SchedulingType.MANUAL, visualize=False, multi_buffer_count=self.multi_buffer_count, + eliminate_epilogue=self.eliminate_epilogue, ) # Store the pipelined iterate node and node mapping, then create proxies for the stages @@ -934,7 +937,11 @@ def PROLOGUE(self): @property def EPILOGUE(self): - """Get a reference to the EPILOGUE stage (nodes after pipelined iterate).""" + """Get a reference to the EPILOGUE stage (nodes after pipelined iterate). + Returns None when eliminate_epilogue=True (epilogue was not generated). + """ + if self.eliminate_epilogue: + return None return self._EPILOGUE def _update_kernel_node_mapping(self): @@ -1107,8 +1114,11 @@ def handle( kernel_trace, constraints: list[Constraint], iterate: Sequence[fx.Node], + eliminate_epilogue: bool = False, ): - real_pipelined_loop = PipelinedLoop(iterate, kernel_trace, constraints) + real_pipelined_loop = PipelinedLoop( + iterate, kernel_trace, constraints, eliminate_epilogue=eliminate_epilogue + ) # Return the real object directly (no proxy needed) return real_pipelined_loop diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 2e74a70b48..2f92bfeff0 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -73,6 +73,7 @@ from .multicast import multicast from .promotion import compute_shared_memory_usage, promote_placeholders from .schedule_reordering import schedule_reordering +from .scheduling.loop_reconstruction import guard_g2s_with_bounds_check from .scheduling.schedule import schedule_graph from .scheduling.schedule_enums import SchedulingType from .shared_memory_indexing import apply_shared_memory_indexing_corrections @@ -554,6 +555,10 @@ def build_graph_passes( ) ) + graph_passes.append( + partial(guard_g2s_with_bounds_check, trace, launchable.constraints) + ) + if options.optimization_level: graph_passes += [ partial( diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index bc123aa2c6..8b525bed5d 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -89,6 +89,7 @@ class WaveCompileOptions: wave_runtime: bool = False iree_launch_async: bool = True use_buffer_ops: bool = False + use_real_buffer_bounds: bool = False use_fast_math: bool = False use_global_to_shared: bool = False linearize_shared_access: bool = False diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index dca51dd0eb..def59fc1f3 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1142,7 +1142,7 @@ def mxfp4_dbuf_schedule(): return mxfp4_dbuf_schedule -def get_mxfp4_asymmetric_schedule(is_bscale_shuffled: bool = False): +def get_mxfp4_asymmetric_schedule(eliminate_epilogue: bool = False, is_bscale_shuffled: bool = False): """Return an asymmetric-prefetch MXFP4 schedule for wave_compile(). Asymmetric data paths: @@ -1162,6 +1162,12 @@ def get_mxfp4_asymmetric_schedule(is_bscale_shuffled: bool = False): Second MMA half: interleaved with B_scale loads and next-iteration first-partition A reads (plus G2S for the iteration after next). + + When eliminate_epilogue=True the loop runs for the full K trip count + and relies on OOB buffer loads returning zero (GFX9+ hardware guarantee) + so that extra iterations contribute nothing to the accumulators. This + removes all epilogue code, reducing icache pressure and total code size. + Requires options.use_buffer_ops=True and options.use_real_buffer_bounds=True. """ M = tkl.sym.M @@ -1204,7 +1210,7 @@ def mxfp4_dbuf_schedule(): # ===================================================================== # Create 2-stage pipeline (double buffering) # ===================================================================== - pipeline_loop = tkw.pipeline(k_loop) + pipeline_loop = tkw.pipeline(k_loop, eliminate_epilogue=eliminate_epilogue) # This forces the pipeline to use double buffering pipeline_loop.multi_buffer_count = 2 @@ -1403,147 +1409,154 @@ def mxfp4_dbuf_schedule(): ), ] - #################### EPILOGUE #################### - - # Filter nodes for EPILOGUE stage - - epilogue_g2v_b = tkw.filter_nodes(g2v_b, subgraph=pipeline_loop.EPILOGUE) - epilogue_g2v_b_scale = tkw.filter_nodes( - g2v_b_scale, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_s2v_a_0 = tkw.filter_nodes(s2v_a_0, subgraph=pipeline_loop.EPILOGUE) - epilogue_s2v_a_scale_0 = tkw.filter_nodes( - s2v_a_scale_0, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_s2v_a_1 = tkw.filter_nodes(s2v_a_1, subgraph=pipeline_loop.EPILOGUE) - epilogue_s2v_a_scale_1 = tkw.filter_nodes( - s2v_a_scale_1, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_a = tkw.filter_nodes( - bitcast_a, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_a_scale = tkw.filter_nodes( - bitcast_a_scale, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_b = tkw.filter_nodes( - bitcast_b, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_b_scale = tkw.filter_nodes( - bitcast_b_scale, subgraph=pipeline_loop.EPILOGUE - ) - - epilogue_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.EPILOGUE) - - def split_by_iteration(nodes, key="name"): - # TODO: Replace name-based splitting with a pipeline_drain_iteration - # attribute (analogous to unroll_iteration). expanded_dims can't be - # used here because loop_reconstruction copies them verbatim for - # both drain iterations. - itr0 = [] - itr1 = [] - for node in nodes: - value = getattr(node, key) - if "1_2" in value: - itr0.append(node) - elif "2_2" in value: - itr1.append(node) - else: - raise ValueError(f"Unknown {key} for node: {value}") - return itr0, itr1 + if eliminate_epilogue: + clusters += prologue_clusters + tkw.reorder_graph(pipeline_loop.KERNEL, clusters) + else: + epilogue_g2v_b = tkw.filter_nodes(g2v_b, subgraph=pipeline_loop.EPILOGUE) + epilogue_g2v_b_scale = tkw.filter_nodes( + g2v_b_scale, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_0 = tkw.filter_nodes( + s2v_a_0, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_scale_0 = tkw.filter_nodes( + s2v_a_scale_0, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_1 = tkw.filter_nodes( + s2v_a_1, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_scale_1 = tkw.filter_nodes( + s2v_a_scale_1, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_a = tkw.filter_nodes( + bitcast_a, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_a_scale = tkw.filter_nodes( + bitcast_a_scale, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_b = tkw.filter_nodes( + bitcast_b, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_b_scale = tkw.filter_nodes( + bitcast_b_scale, subgraph=pipeline_loop.EPILOGUE + ) - epilogue_mma_itr0, epilogue_mma_itr1 = split_by_iteration(epilogue_mma) - epilogue_s2v_a_1_itr0, epilogue_s2v_a_1_itr1 = split_by_iteration( - epilogue_s2v_a_1 - ) - epilogue_s2v_a_scale_1_itr0, epilogue_s2v_a_scale_1_itr1 = split_by_iteration( - epilogue_s2v_a_scale_1 - ) - epilogue_bitcast_a_itr0, epilogue_bitcast_a_itr1 = split_by_iteration( - epilogue_bitcast_a - ) - epilogue_bitcast_a_scale_itr0, epilogue_bitcast_a_scale_itr1 = ( - split_by_iteration(epilogue_bitcast_a_scale) - ) - epilogue_bitcast_b_itr0, epilogue_bitcast_b_itr1 = split_by_iteration( - epilogue_bitcast_b - ) - epilogue_bitcast_b_scale_itr0, epilogue_bitcast_b_scale_itr1 = ( - split_by_iteration(epilogue_bitcast_b_scale) - ) + epilogue_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.EPILOGUE) + + def split_by_iteration(nodes, key="name"): + # TODO: Replace name-based splitting with a + # pipeline_drain_iteration attribute (analogous to + # unroll_iteration). expanded_dims can't be used here because + # loop_reconstruction copies them verbatim for both drain + # iterations. + itr0 = [] + itr1 = [] + for node in nodes: + value = getattr(node, key) + if "1_2" in value: + itr0.append(node) + elif "2_2" in value: + itr1.append(node) + else: + raise ValueError(f"Unknown {key} for node: {value}") + return itr0, itr1 + + epilogue_mma_itr0, epilogue_mma_itr1 = split_by_iteration(epilogue_mma) + epilogue_s2v_a_1_itr0, epilogue_s2v_a_1_itr1 = split_by_iteration( + epilogue_s2v_a_1 + ) + ( + epilogue_s2v_a_scale_1_itr0, + epilogue_s2v_a_scale_1_itr1, + ) = split_by_iteration(epilogue_s2v_a_scale_1) + epilogue_bitcast_a_itr0, epilogue_bitcast_a_itr1 = split_by_iteration( + epilogue_bitcast_a + ) + epilogue_bitcast_a_scale_itr0, epilogue_bitcast_a_scale_itr1 = ( + split_by_iteration(epilogue_bitcast_a_scale) + ) + epilogue_bitcast_b_itr0, epilogue_bitcast_b_itr1 = split_by_iteration( + epilogue_bitcast_b + ) + epilogue_bitcast_b_scale_itr0, epilogue_bitcast_b_scale_itr1 = ( + split_by_iteration(epilogue_bitcast_b_scale) + ) - epilogue_mma_itr0_0, epilogue_mma_itr0_1 = tkw.partition_by_dim( - epilogue_mma_itr0, dim=M, num_partitions=2 - ) - epilogue_bitcast_a_itr0_0, epilogue_bitcast_a_itr0_1 = tkw.partition_by_dim( - epilogue_bitcast_a_itr0, dim=M, num_partitions=2 - ) - epilogue_bitcast_a_scale_itr0_0, epilogue_bitcast_a_scale_itr0_1 = ( - tkw.partition_by_dim(epilogue_bitcast_a_scale_itr0, dim=M, num_partitions=2) - ) + epilogue_mma_itr0_0, epilogue_mma_itr0_1 = tkw.partition_by_dim( + epilogue_mma_itr0, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_itr0_0, epilogue_bitcast_a_itr0_1 = tkw.partition_by_dim( + epilogue_bitcast_a_itr0, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_scale_itr0_0, epilogue_bitcast_a_scale_itr0_1 = ( + tkw.partition_by_dim( + epilogue_bitcast_a_scale_itr0, dim=M, num_partitions=2 + ) + ) - epilogue_mma_itr1_0, epilogue_mma_itr1_1 = tkw.partition_by_dim( - epilogue_mma_itr1, dim=M, num_partitions=2 - ) - epilogue_bitcast_a_itr1_0, epilogue_bitcast_a_itr1_1 = tkw.partition_by_dim( - epilogue_bitcast_a_itr1, dim=M, num_partitions=2 - ) - epilogue_bitcast_a_scale_itr1_0, epilogue_bitcast_a_scale_itr1_1 = ( - tkw.partition_by_dim(epilogue_bitcast_a_scale_itr1, dim=M, num_partitions=2) - ) + epilogue_mma_itr1_0, epilogue_mma_itr1_1 = tkw.partition_by_dim( + epilogue_mma_itr1, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_itr1_0, epilogue_bitcast_a_itr1_1 = tkw.partition_by_dim( + epilogue_bitcast_a_itr1, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_scale_itr1_0, epilogue_bitcast_a_scale_itr1_1 = ( + tkw.partition_by_dim( + epilogue_bitcast_a_scale_itr1, dim=M, num_partitions=2 + ) + ) - epilogue_clusters_itr0 = [ - tkw.cluster( - [ - tkw.MemoryCounterWaitBarrier(load=0, ds=0), - epilogue_bitcast_a_itr0_0, - epilogue_bitcast_a_scale_itr0_0, - epilogue_bitcast_b_itr0, - epilogue_bitcast_b_scale_itr0, - tkw.SchedulingBarrier([]), - epilogue_mma_itr0_0, - epilogue_g2v_b, - epilogue_s2v_a_1_itr0, - epilogue_g2v_b_scale, - epilogue_s2v_a_scale_1_itr0, - epilogue_bitcast_a_itr0_1, - epilogue_bitcast_a_scale_itr0_1, - ], - ), - tkw.cluster( - [ - epilogue_mma_itr0_1, - tkw.MemoryCounterWaitBarrier(load=0, ds=0), - epilogue_s2v_a_0, - epilogue_s2v_a_scale_0, - ], - ), - tkw.cluster( - [ - tkw.MemoryCounterWaitBarrier(load=0, ds=0), - epilogue_bitcast_a_itr1_0, - epilogue_bitcast_a_scale_itr1_0, - epilogue_bitcast_b_itr1, - epilogue_bitcast_b_scale_itr1, - tkw.SchedulingBarrier([]), - epilogue_mma_itr1_0, - epilogue_s2v_a_1_itr1, - epilogue_s2v_a_scale_1_itr1, - ], - ), - tkw.cluster( - [ - epilogue_bitcast_a_itr1_1, - epilogue_bitcast_a_scale_itr1_1, - epilogue_mma_itr1_1, - ], - ), - ] + epilogue_clusters_itr0 = [ + tkw.cluster( + [ + epilogue_bitcast_a_itr0_0, + epilogue_bitcast_a_scale_itr0_0, + epilogue_bitcast_b_itr0, + epilogue_bitcast_b_scale_itr0, + tkw.SchedulingBarrier([]), + epilogue_mma_itr0_0, + epilogue_g2v_b, + epilogue_s2v_a_1_itr0, + epilogue_g2v_b_scale, + epilogue_s2v_a_scale_1_itr0, + epilogue_bitcast_a_itr0_1, + epilogue_bitcast_a_scale_itr0_1, + ], + ), + tkw.cluster( + [ + epilogue_mma_itr0_1, + tkw.SchedulingBarrier([]), + epilogue_s2v_a_0, + epilogue_s2v_a_scale_0, + ], + ), + tkw.cluster( + [ + epilogue_bitcast_a_itr1_0, + epilogue_bitcast_a_scale_itr1_0, + epilogue_bitcast_b_itr1, + epilogue_bitcast_b_scale_itr1, + tkw.SchedulingBarrier([]), + epilogue_mma_itr1_0, + epilogue_s2v_a_1_itr1, + epilogue_s2v_a_scale_1_itr1, + ], + ), + tkw.cluster( + [ + epilogue_bitcast_a_itr1_1, + epilogue_bitcast_a_scale_itr1_1, + epilogue_mma_itr1_1, + ], + ), + ] - tkw.reorder_graph(pipeline_loop.PROLOGUE, prologue_clusters) - tkw.reorder_graph(pipeline_loop.KERNEL, clusters) - # tkw.reorder_graph(pipeline_loop.EPILOGUE, epilogue_clusters_itr0) - unroll_factor = 2 - tkw.unroll(pipeline_loop.KERNEL, unroll_factor) + tkw.reorder_graph(pipeline_loop.PROLOGUE, prologue_clusters) + tkw.reorder_graph(pipeline_loop.KERNEL, clusters) + unroll_factor = 2 + tkw.unroll(pipeline_loop.KERNEL, unroll_factor) tkw.insert_at_start( pipeline_loop.KERNEL, diff --git a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py index b757c8aecf..70055c2a3f 100644 --- a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py +++ b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py @@ -1,6 +1,7 @@ from collections import defaultdict, deque from enum import Enum +import sympy import torch.fx as fx from wave_lang.support.logging import get_logger @@ -8,7 +9,9 @@ from ..._support.indexing import IndexSymbol, IndexSequence, IndexExpr from ..._support.tracing import CapturedTrace from ...ops.wave_ops import ( + Conditional, CustomOp, + GatherToLDS, GetResult, IterArg, Iterate, @@ -812,6 +815,74 @@ def erase_allocs(allocs: list[fx.Node]): get_custom(alloc).erase() +def guard_g2s_with_bounds_check( + trace: CapturedTrace, + constraints: list[Constraint], +): + """ + Post-scheduling pass: wrap GatherToLDS nodes inside pipelined loops that + have eliminate_epilogue=True with an scf.if bounds guard. + + When eliminate_epilogue=True, the loop runs for the full trip count. Stage 0 + (the prefetch stage) uses iv + (num_stages-1)*step, which goes OOB in the + last (num_stages-1) iterations. Since gather_to_lds faults on OOB rather + than returning zero, we guard these ops with: + + if iv + (num_stages-1)*step < max_induction_variable: + gather_to_lds(...) + """ + for node in trace.walk(lambda n: isinstance(get_custom(n), Iterate)): + if not node.meta.get("eliminate_epilogue", False): + continue + + iterate = get_custom(node) + subgraph_name = iterate.subgraph_name + pipelined_graph = trace.get_subgraph(subgraph_name) + + g2s_nodes = [ + n for n in pipelined_graph.nodes if isinstance(get_custom(n), GatherToLDS) + ] + if not g2s_nodes: + continue + + num_stages = node.meta["num_pipeline_stages"] + step = iterate.step + max_iv = iterate.count + induction_variable = get_induction_variable(iterate, constraints) + + prefetch_offset = (num_stages - 1) * step + guard_condition = sympy.StrictLessThan( + induction_variable + prefetch_offset, max_iv + ) + + guard_subgraph = fx.Graph() + guard_subgraph_name = f"g2s_guard_{subgraph_name}" + + first_g2s = g2s_nodes[0] + location = get_custom(first_g2s).location + + with pipelined_graph.inserting_before(first_g2s): + cond_node = Conditional( + guard_condition, + subgraph_name=guard_subgraph_name, + implicit_captures=[], + ).add_to_graph(pipelined_graph, loc=location) + + for g2s in g2s_nodes: + custom = get_custom(g2s) + custom.copy(new_graph=guard_subgraph) + custom.erase() + + guard_subgraph.output(None) + + trace.add_subgraph(guard_subgraph_name, guard_subgraph) + + guard_subgraph.parent_op = cond_node + + root_graph = get_custom(cond_node).get_root_graph() + root_graph.subgraphs[guard_subgraph_name] = guard_subgraph + + def construct_pipelined_loop( trace: CapturedTrace, reduction: Iterate, @@ -823,11 +894,16 @@ def construct_pipelined_loop( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + eliminate_epilogue: bool = False, ) -> tuple[fx.Node, dict[fx.Node, list[fx.Node]], list[fx.Node]]: """ Given a graph annotated with scheduling parameters, construct a pipelined loop with a prologue, kernel and epilogue. + When eliminate_epilogue=True, the epilogue is not generated. The loop runs + for the full trip count and relies on out-of-bounds loads returning zero + to make the extra iterations contribute nothing to the accumulators. + Returns: pipelined_reduction: The pipelined loop node node_mapping: Dictionary tracking mappings from original nodes to pipelined copies @@ -893,28 +969,61 @@ def construct_pipelined_loop( trace.add_subgraph( get_custom(pipelined_reduction).subgraph_name, pipelined_reduction_graph ) - # Construct epilogue. - # The epilogue induction variables must account for the step size. - # With step > 1, each "iteration" covers step original iterations. - final_results = construct_epilogue( - graph, - reduction, - get_custom(pipelined_reduction), - partitioned_graph, - num_stages, - initiation_interval, - rotating_registers, - induction_variable, - [ - max_induction_variable - num_stages * step + i * step - for i in range(num_stages) - ], - create_drain_stage_schedule(num_stages), - num_rotating_registers, - visualize, - outer_vars=outer_vars, - node_mapping=node_mapping, - ) + + if eliminate_epilogue: + pipelined_reduction.meta["eliminate_epilogue"] = True + pipelined_reduction.meta["num_pipeline_stages"] = num_stages + + # No epilogue: the loop runs for the full trip count. + # The final results are simply the GetResult nodes from the pipelined loop. + pipelined_custom = get_custom(pipelined_reduction) + existing_get_results: list[GetResult] = sorted( + [x for x in pipelined_custom.users if isinstance(x, GetResult)], + key=lambda x: x.res_idx, + ) + iter_args = reduction.iter_args(graph) + # Ensure we have GetResult for every iter arg + last_get_result = ( + existing_get_results[-1].fx_node + if existing_get_results + else pipelined_reduction + ) + existing_indices = {x.res_idx for x in existing_get_results} + for i in range(len(iter_args)): + if i not in existing_indices: + with pipelined_custom.graph.inserting_after(last_get_result): + result = GetResult(pipelined_reduction, i).add_to_graph( + pipelined_custom.graph, + type=iter_args[i].type, + loc=pipelined_custom.location, + ) + existing_get_results.append(get_custom(result)) + last_get_result = result + existing_get_results = sorted(existing_get_results, key=lambda x: x.res_idx) + final_results = [gr.fx_node for gr in existing_get_results] + else: + # Construct epilogue. + # The epilogue induction variables must account for the step size. + # With step > 1, each "iteration" covers step original iterations. + final_results = construct_epilogue( + graph, + reduction, + get_custom(pipelined_reduction), + partitioned_graph, + num_stages, + initiation_interval, + rotating_registers, + induction_variable, + [ + max_induction_variable - num_stages * step + i * step + for i in range(num_stages) + ], + create_drain_stage_schedule(num_stages), + num_rotating_registers, + visualize, + outer_vars=outer_vars, + node_mapping=node_mapping, + ) # Remove the unpipelined reduction and the corresponding subgraph reduction.erase() diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index 61e641ec1f..7118615857 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -214,6 +214,7 @@ def build_guarded_pipeline_with_remainder( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + eliminate_epilogue: bool = False, ): """ Build conditional + pipelined loop + remainder loop for dynamic shapes. @@ -312,15 +313,21 @@ def build_guarded_pipeline_with_remainder( visualize, use_scheduling_barriers, multi_buffer_count, + eliminate_epilogue=eliminate_epilogue, ) # Set the count for the pipelined loop - # With step > 1 (e.g., from unrolling), we need to reduce the count by more - # to prevent out-of-bounds access. The last kernel iteration's stage 0 loads - # data for the "next" iteration (offset by step), so we need to ensure - # that stays within bounds. step = get_custom(pipelined_node).step - get_custom(pipelined_node).count = pipelined_iterations - (num_stages - 1) * step + if eliminate_epilogue: + get_custom(pipelined_node).count = pipelined_iterations + else: + # With step > 1 (e.g., from unrolling), we need to reduce the count by more + # to prevent out-of-bounds access. The last kernel iteration's stage 0 loads + # data for the "next" iteration (offset by step), so we need to ensure + # that stays within bounds. + get_custom(pipelined_node).count = ( + pipelined_iterations - (num_stages - 1) * step + ) # Verify we have the right number of results assert len(final_results) == len( @@ -414,6 +421,7 @@ def construct_pipelined_loop_adaptive( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + eliminate_epilogue: bool = False, ): """ Constructs a pipelined loop wrapped in a conditional, followed by a remainder loop. @@ -427,6 +435,12 @@ def construct_pipelined_loop_adaptive( else: return (0, init_values...) remainder_loop(start=iterations_done, end=total_iterations) + + When eliminate_epilogue=True, the epilogue is not generated and the loop + runs for the full trip count. Out-of-bounds loads in the extra iterations + must return zero (guaranteed by buffer_load on GFX9+). This trades wasted + prefetch work in the last (num_stages-1) iterations for eliminating all + epilogue code (MFMAs, loads, bitcasts). """ # Check if we have a dynamic shape (max_induction_variable is symbolic) is_dynamic = not ( @@ -450,12 +464,16 @@ def construct_pipelined_loop_adaptive( visualize, use_scheduling_barriers, multi_buffer_count, + eliminate_epilogue=eliminate_epilogue, ) if new_reduction: step = get_custom(new_reduction).step - get_custom(new_reduction).count = ( - max_induction_variable - (num_stages - 1) * step - ) + if eliminate_epilogue: + get_custom(new_reduction).count = max_induction_variable + else: + get_custom(new_reduction).count = ( + max_induction_variable - (num_stages - 1) * step + ) return new_reduction, node_mapping # For dynamic shapes, emit conditional + pipelined loop + remainder loop @@ -471,6 +489,7 @@ def construct_pipelined_loop_adaptive( visualize, use_scheduling_barriers, multi_buffer_count, + eliminate_epilogue=eliminate_epilogue, ) @@ -485,6 +504,7 @@ def apply_pipelined_schedule( scheduling_type: SchedulingType = SchedulingType.NONE, visualize: bool = False, multi_buffer_count: Optional[int] = None, + eliminate_epilogue: bool = False, ) -> Optional[tuple[fx.Node, dict]]: # After scheduling has completed, we have enough information to decide @@ -519,6 +539,7 @@ def apply_pipelined_schedule( visualize, use_scheduling_barriers, multi_buffer_count, + eliminate_epilogue=eliminate_epilogue, ) From 388b3a8fa1119e8b1eaeb861b14a1365926a461d Mon Sep 17 00:00:00 2001 From: xintin Date: Wed, 25 Feb 2026 14:17:15 +0000 Subject: [PATCH 2/9] epilogue elimination Signed-off-by: xintin --- wave_lang/kernel/compiler/wave_codegen/read_write.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index c64830396c..c754e59442 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -410,8 +410,10 @@ def _cast_buffer_and_encode_stride( # fastest_dim_bound == second to last stride. stride_candidate = strides[-2] stride_int = _get_constant_value(stride_candidate) - # Only swizzle if stride is static and fits in signed i14. - if stride_int and stride_int < 8192: + # Only swizzle if stride is static and fits in signed i14 + # (max representable positive value is 8191, but 8192 wraps to + # -8192 which the hardware accepts). + if stride_int and stride_int <= 8192: swizzle_stride = arith_d.index_cast(uint14, stride_candidate) if swizzle_stride: From 499700856031f036a5a9968d8fe803203547aabe Mon Sep 17 00:00:00 2001 From: xintin Date: Wed, 25 Feb 2026 18:20:34 +0000 Subject: [PATCH 3/9] updated branchless guard Signed-off-by: xintin --- .../compiler/wave_codegen/read_write.py | 125 +++++++++++++----- .../wave/scheduling/loop_reconstruction.py | 41 ++---- 2 files changed, 104 insertions(+), 62 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index c754e59442..0f959080aa 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -355,16 +355,26 @@ def _get_constant_value(candidate: Value): return candidate.owner.opview.value.value -def _cast_buffer_and_encode_stride( - ptr: Value, - strides: tuple[Value], - elem_type: IrType, +def _compute_branchless_valid_bytes( emitter: WaveEmitter, - symbolic_shape: tuple = None, - is_read: bool = True, + symbolic_shape: tuple, + elem_type: IrType, + guard_condition: sympy.Basic, ) -> Value: + """Compute a dynamic validBytes that becomes 0 when OOB. + + The guard_condition is a sympy expression like: + iv + prefetch_offset < max_iv + + We emit: + cond = gen_sympy_index(guard_condition) # index type, nonzero=true + real_valid = compute_static_validBytes() + validBytes = select(cond != 0, real_valid, 0) + + When the condition is false (last iterations), validBytes=0 makes the + SRD's NUM_RECORDS=0 so gather_to_lds DMA is a hardware no-op. + """ uint64 = IntegerType.get_signless(64) - uint14 = IntegerType.get_signless(14) elem_bytes = elem_type.width // 8 if emitter.options.use_real_buffer_bounds and symbolic_shape is not None: @@ -380,34 +390,76 @@ def _cast_buffer_and_encode_stride( else: total_bytes = _valid_bytes_buffer(elem_type) - # With resetOffset, the SRD base is adjusted forward by the memref's - # offset, so validBytes must be total_bytes - offset_bytes to avoid - # over-reporting the valid range past the actual allocation. - # - # For writes, skip the offset subtraction: real buffer bounds are only - # useful for reads (OOB loads return 0), and subtracting the offset - # from a clamped total_bytes can produce values that overflow the - # 32-bit SRD NUM_RECORDS field for output buffers larger than 2 GB. - if is_read: - metadata = memref_d.extract_strided_metadata(ptr) - offset_elements = metadata[1] - offset_bytes = arith_d.index_cast(uint64, offset_elements) - elem_bytes_val = arith_d.constant(uint64, get_constant_attr(elem_bytes, uint64)) - offset_bytes = arith_d.muli(offset_bytes, elem_bytes_val) - total_bytes_val = arith_d.constant( - uint64, get_constant_attr(total_bytes, uint64) - ) - valid_bytes_val = arith_d.subi(total_bytes_val, offset_bytes) + real_valid = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) + zero_valid = arith_d.constant(uint64, get_constant_attr(0, uint64)) + + cond_val = gen_sympy_index(add_emitter_subs(emitter), guard_condition) + i1 = IntegerType.get_signless(1) + if cond_val.type != i1: + zero_idx = arith_d.constant(cond_val.type, 0) + cond_val = arith_d.cmpi(arith_d.CmpIPredicate.ne, cond_val, zero_idx) + + return arith_d.select(cond_val, real_valid, zero_valid) + + +def _cast_buffer_and_encode_stride( + ptr: Value, + strides: tuple[Value], + elem_type: IrType, + emitter: WaveEmitter, + symbolic_shape: tuple = None, + is_read: bool = True, + valid_bytes_override: Value = None, +) -> Value: + uint64 = IntegerType.get_signless(64) + uint14 = IntegerType.get_signless(14) + elem_bytes = elem_type.width // 8 + + if valid_bytes_override is not None: + valid_bytes_val = valid_bytes_override else: - valid_bytes_val = arith_d.constant( - uint64, get_constant_attr(total_bytes, uint64) - ) + if emitter.options.use_real_buffer_bounds and symbolic_shape is not None: + total_elements = subs_idxc(sympy.prod(s for s in symbolic_shape)) + if isinstance(total_elements, (int, float)) or ( + hasattr(total_elements, "is_number") and total_elements.is_number + ): + total_bytes = int(total_elements) * elem_bytes + max_valid = _valid_bytes_buffer(elem_type) + total_bytes = min(total_bytes, max_valid) + else: + total_bytes = _valid_bytes_buffer(elem_type) + else: + total_bytes = _valid_bytes_buffer(elem_type) + + # With resetOffset, the SRD base is adjusted forward by the memref's + # offset, so validBytes must be total_bytes - offset_bytes to avoid + # over-reporting the valid range past the actual allocation. + # + # For writes, skip the offset subtraction: real buffer bounds are only + # useful for reads (OOB loads return 0), and subtracting the offset + # from a clamped total_bytes can produce values that overflow the + # 32-bit SRD NUM_RECORDS field for output buffers larger than 2 GB. + if is_read: + metadata = memref_d.extract_strided_metadata(ptr) + offset_elements = metadata[1] + offset_bytes = arith_d.index_cast(uint64, offset_elements) + elem_bytes_val = arith_d.constant( + uint64, get_constant_attr(elem_bytes, uint64) + ) + offset_bytes = arith_d.muli(offset_bytes, elem_bytes_val) + total_bytes_val = arith_d.constant( + uint64, get_constant_attr(total_bytes, uint64) + ) + valid_bytes_val = arith_d.subi(total_bytes_val, offset_bytes) + else: + valid_bytes_val = arith_d.constant( + uint64, get_constant_attr(total_bytes, uint64) + ) stride_rank = len(strides) swizzle_stride = None if stride_rank >= 2: - # fastest_dim_bound == second to last stride. stride_candidate = strides[-2] stride_int = _get_constant_value(stride_candidate) # Only swizzle if stride is static and fits in signed i14 @@ -1141,8 +1193,21 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): ] src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides) + + valid_bytes_override = None + guard_condition = node.meta.get("g2s_branchless_guard", None) + if guard_condition is not None: + valid_bytes_override = _compute_branchless_valid_bytes( + emitter, src_symbolic_shape, element_type, guard_condition + ) + src = _cast_buffer_and_encode_stride( - src, strides, element_type, emitter, src_symbolic_shape + src, + strides, + element_type, + emitter, + src_symbolic_shape, + valid_bytes_override=valid_bytes_override, ) # We previously checked mask is same for all elements, so we can use diff --git a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py index 70055c2a3f..538dace07f 100644 --- a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py +++ b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py @@ -9,7 +9,6 @@ from ..._support.indexing import IndexSymbol, IndexSequence, IndexExpr from ..._support.tracing import CapturedTrace from ...ops.wave_ops import ( - Conditional, CustomOp, GatherToLDS, GetResult, @@ -820,16 +819,18 @@ def guard_g2s_with_bounds_check( constraints: list[Constraint], ): """ - Post-scheduling pass: wrap GatherToLDS nodes inside pipelined loops that - have eliminate_epilogue=True with an scf.if bounds guard. + Post-scheduling pass: annotate GatherToLDS nodes inside pipelined loops + that have eliminate_epilogue=True with branchless SRD guard metadata. When eliminate_epilogue=True, the loop runs for the full trip count. Stage 0 (the prefetch stage) uses iv + (num_stages-1)*step, which goes OOB in the - last (num_stages-1) iterations. Since gather_to_lds faults on OOB rather - than returning zero, we guard these ops with: + last (num_stages-1) iterations. Instead of wrapping gather_to_lds in an + scf.if branch, we annotate each gather_to_lds with a guard condition so + that codegen can emit a dynamic validBytes that becomes 0 when OOB: - if iv + (num_stages-1)*step < max_induction_variable: - gather_to_lds(...) + validBytes = select(iv + prefetch_offset < max_iv, real_validBytes, 0) + + This makes the hardware DMA a no-op (reads nothing) without any branch. """ for node in trace.walk(lambda n: isinstance(get_custom(n), Iterate)): if not node.meta.get("eliminate_epilogue", False): @@ -855,32 +856,8 @@ def guard_g2s_with_bounds_check( induction_variable + prefetch_offset, max_iv ) - guard_subgraph = fx.Graph() - guard_subgraph_name = f"g2s_guard_{subgraph_name}" - - first_g2s = g2s_nodes[0] - location = get_custom(first_g2s).location - - with pipelined_graph.inserting_before(first_g2s): - cond_node = Conditional( - guard_condition, - subgraph_name=guard_subgraph_name, - implicit_captures=[], - ).add_to_graph(pipelined_graph, loc=location) - for g2s in g2s_nodes: - custom = get_custom(g2s) - custom.copy(new_graph=guard_subgraph) - custom.erase() - - guard_subgraph.output(None) - - trace.add_subgraph(guard_subgraph_name, guard_subgraph) - - guard_subgraph.parent_op = cond_node - - root_graph = get_custom(cond_node).get_root_graph() - root_graph.subgraphs[guard_subgraph_name] = guard_subgraph + g2s.meta["g2s_branchless_guard"] = guard_condition def construct_pipelined_loop( From ac90421ed26a5515ba04c21b22a4eab7c3448e87 Mon Sep 17 00:00:00 2001 From: xintin Date: Wed, 25 Feb 2026 23:34:27 +0000 Subject: [PATCH 4/9] updated example Signed-off-by: xintin --- examples/python/7.1_schedule.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 88f1ab217a..d6ab2abd0a 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -233,7 +233,10 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm( def test_dbuf_4wave_mxfp_preshuffle_b_no_epilogue_gemm( - is_debug=False, shape=(1024, 2048, 8192), block=(128, 256, 256) + is_debug=False, + shape=(1024, 1024, 8192), + block=(128, 256, 256), + eliminate_epilogue=True, ): """Asymmetric MXFP4 GEMM with preshuffled B, epilogue eliminated via OOB=0.""" gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) From 856a9016ed9ed392f2de719a2b1d777abb133ecb Mon Sep 17 00:00:00 2001 From: xintin Date: Thu, 26 Feb 2026 00:36:52 +0000 Subject: [PATCH 5/9] removed debug code Signed-off-by: xintin --- examples/python/7.1_schedule.py | 32 ++++++------------- .../compiler/wave_codegen/read_write.py | 6 ++-- wave_lang/kernel/wave/compile_options.py | 4 ++- .../schedules/gemm_mxfp4_double_buffer.py | 3 +- .../wave/scheduling/loop_reconstruction.py | 2 +- wave_lang/kernel/wave/scheduling/schedule.py | 2 +- 6 files changed, 19 insertions(+), 30 deletions(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index d6ab2abd0a..62e747577b 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -213,16 +213,19 @@ def test_dbuf_4wave_mxfp_asymmetric_gemm( def test_dbuf_4wave_mxfp_preshuffle_b_gemm( - is_debug=False, shape=(1024, 1024, 8192), block=(128, 256, 256) + is_debug=False, + shape=(1024, 1024, 8192), + block=(128, 256, 256), + eliminate_epilogue=False, ): """Asymmetric MXFP4 GEMM with preshuffled B data and B scales.""" gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) - print(block) options.minimize_shared_allocs = True options.linearize_shared_access = True options.use_buffer_ops = True + options.eliminate_epilogue = eliminate_epilogue options.dump_intermediates = "build/intermediates" - schedule = get_mxfp4_asymmetric_schedule(is_bscale_shuffled=True) + schedule = get_mxfp4_asymmetric_schedule(eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True) options.print_ir_after = "all" if is_debug else [] options = set_default_run_config(options) @@ -233,27 +236,12 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm( def test_dbuf_4wave_mxfp_preshuffle_b_no_epilogue_gemm( - is_debug=False, - shape=(1024, 1024, 8192), - block=(128, 256, 256), - eliminate_epilogue=True, + is_debug=False, shape=(1024, 2048, 8192), block=(128, 256, 256) ): """Asymmetric MXFP4 GEMM with preshuffled B, epilogue eliminated via OOB=0.""" - gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) - options.minimize_shared_allocs = True - options.linearize_shared_access = True - options.use_buffer_ops = True - options.eliminate_epilogue = True - options.dump_intermediates = "build/intermediates" - schedule = get_mxfp4_asymmetric_schedule(eliminate_epilogue=True, is_bscale_shuffled=True) - - options.print_ir_after = "all" if is_debug else [] - options = set_default_run_config(options) - print(block) - gemm = wave_compile(options, gemm, schedule) - - _run_mxfp_gemm_preshuffle(gemm, shape, all=True) - print("MXFP GEMM preshuffle-B no-epilogue 4-wave test passed!") + test_dbuf_4wave_mxfp_preshuffle_b_gemm( + is_debug=is_debug, shape=shape, block=block, eliminate_epilogue=True, + ) def test_dbuf_4wave_mxfp_asymmetric_gemm_cpp( diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 0f959080aa..a0088e31f2 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -377,7 +377,7 @@ def _compute_branchless_valid_bytes( uint64 = IntegerType.get_signless(64) elem_bytes = elem_type.width // 8 - if emitter.options.use_real_buffer_bounds and symbolic_shape is not None: + if emitter.options.eliminate_epilogue and symbolic_shape is not None: total_elements = subs_idxc(sympy.prod(s for s in symbolic_shape)) if isinstance(total_elements, (int, float)) or ( hasattr(total_elements, "is_number") and total_elements.is_number @@ -418,7 +418,7 @@ def _cast_buffer_and_encode_stride( if valid_bytes_override is not None: valid_bytes_val = valid_bytes_override else: - if emitter.options.use_real_buffer_bounds and symbolic_shape is not None: + if emitter.options.eliminate_epilogue and symbolic_shape is not None: total_elements = subs_idxc(sympy.prod(s for s in symbolic_shape)) if isinstance(total_elements, (int, float)) or ( hasattr(total_elements, "is_number") and total_elements.is_number @@ -1195,7 +1195,7 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides) valid_bytes_override = None - guard_condition = node.meta.get("g2s_branchless_guard", None) + guard_condition = node.meta.get("g2s_guard", None) if guard_condition is not None: valid_bytes_override = _compute_branchless_valid_bytes( emitter, src_symbolic_shape, element_type, guard_condition diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index 8b525bed5d..2da22f0684 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -89,7 +89,9 @@ class WaveCompileOptions: wave_runtime: bool = False iree_launch_async: bool = True use_buffer_ops: bool = False - use_real_buffer_bounds: bool = False + + eliminate_epilogue: bool = False + use_fast_math: bool = False use_global_to_shared: bool = False linearize_shared_access: bool = False diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index def59fc1f3..fd3e9bdfde 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1166,8 +1166,7 @@ def get_mxfp4_asymmetric_schedule(eliminate_epilogue: bool = False, is_bscale_sh When eliminate_epilogue=True the loop runs for the full K trip count and relies on OOB buffer loads returning zero (GFX9+ hardware guarantee) so that extra iterations contribute nothing to the accumulators. This - removes all epilogue code, reducing icache pressure and total code size. - Requires options.use_buffer_ops=True and options.use_real_buffer_bounds=True. + removes all epilogue code, reducing total code size. """ M = tkl.sym.M diff --git a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py index 538dace07f..70e7c7b7b7 100644 --- a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py +++ b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py @@ -857,7 +857,7 @@ def guard_g2s_with_bounds_check( ) for g2s in g2s_nodes: - g2s.meta["g2s_branchless_guard"] = guard_condition + g2s.meta["g2s_guard"] = guard_condition def construct_pipelined_loop( diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index 7118615857..192ef4b499 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -438,7 +438,7 @@ def construct_pipelined_loop_adaptive( When eliminate_epilogue=True, the epilogue is not generated and the loop runs for the full trip count. Out-of-bounds loads in the extra iterations - must return zero (guaranteed by buffer_load on GFX9+). This trades wasted + must return zero (guaranteed by buffer_load on GFX950). This trades wasted prefetch work in the last (num_stages-1) iterations for eliminating all epilogue code (MFMAs, loads, bitcasts). """ From 9a6d8f58eeb1d9fcefc80eb8088cafe616648d50 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 18:35:50 +0000 Subject: [PATCH 6/9] rebased with main Signed-off-by: xintin --- examples/python/7.1_schedule.py | 11 +++- .../compiler/wave_codegen/read_write.py | 58 ++++++++++--------- 2 files changed, 40 insertions(+), 29 deletions(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 62e747577b..c3b40f9dd6 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -225,7 +225,9 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm( options.use_buffer_ops = True options.eliminate_epilogue = eliminate_epilogue options.dump_intermediates = "build/intermediates" - schedule = get_mxfp4_asymmetric_schedule(eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True) + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + ) options.print_ir_after = "all" if is_debug else [] options = set_default_run_config(options) @@ -236,11 +238,14 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm( def test_dbuf_4wave_mxfp_preshuffle_b_no_epilogue_gemm( - is_debug=False, shape=(1024, 2048, 8192), block=(128, 256, 256) + is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256) ): """Asymmetric MXFP4 GEMM with preshuffled B, epilogue eliminated via OOB=0.""" test_dbuf_4wave_mxfp_preshuffle_b_gemm( - is_debug=is_debug, shape=shape, block=block, eliminate_epilogue=True, + is_debug=is_debug, + shape=shape, + block=block, + eliminate_epilogue=True, ) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index a0088e31f2..8391c1a103 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -329,6 +329,29 @@ def _valid_bytes_buffer(elem_type: IrType) -> int: return ans +def _compute_total_valid_bytes( + elem_type: IrType, + symbolic_shape: tuple, + use_real_bounds: bool, +) -> int: + """Return the total valid byte count for a buffer SRD. + + When *use_real_bounds* is True and *symbolic_shape* resolves to a concrete + number, the result is clamped to the actual tensor size (still bounded by + the hardware maximum). Otherwise falls back to the hardware maximum + returned by ``_valid_bytes_buffer``. + """ + if use_real_bounds and symbolic_shape is not None: + elem_bytes = elem_type.width // 8 + total_elements = subs_idxc(sympy.prod(s for s in symbolic_shape)) + if isinstance(total_elements, (int, float)) or ( + hasattr(total_elements, "is_number") and total_elements.is_number + ): + total_bytes = int(total_elements) * elem_bytes + return min(total_bytes, _valid_bytes_buffer(elem_type)) + return _valid_bytes_buffer(elem_type) + + def _get_out_of_bounds_index(element_type: IrType) -> int: """ returns the first index that's out of bounds of a buffer based on the element type and maximum bytes @@ -375,20 +398,9 @@ def _compute_branchless_valid_bytes( SRD's NUM_RECORDS=0 so gather_to_lds DMA is a hardware no-op. """ uint64 = IntegerType.get_signless(64) - elem_bytes = elem_type.width // 8 - - if emitter.options.eliminate_epilogue and symbolic_shape is not None: - total_elements = subs_idxc(sympy.prod(s for s in symbolic_shape)) - if isinstance(total_elements, (int, float)) or ( - hasattr(total_elements, "is_number") and total_elements.is_number - ): - total_bytes = int(total_elements) * elem_bytes - max_valid = _valid_bytes_buffer(elem_type) - total_bytes = min(total_bytes, max_valid) - else: - total_bytes = _valid_bytes_buffer(elem_type) - else: - total_bytes = _valid_bytes_buffer(elem_type) + total_bytes = _compute_total_valid_bytes( + elem_type, symbolic_shape, use_real_bounds=True + ) real_valid = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) zero_valid = arith_d.constant(uint64, get_constant_attr(0, uint64)) @@ -418,18 +430,12 @@ def _cast_buffer_and_encode_stride( if valid_bytes_override is not None: valid_bytes_val = valid_bytes_override else: - if emitter.options.eliminate_epilogue and symbolic_shape is not None: - total_elements = subs_idxc(sympy.prod(s for s in symbolic_shape)) - if isinstance(total_elements, (int, float)) or ( - hasattr(total_elements, "is_number") and total_elements.is_number - ): - total_bytes = int(total_elements) * elem_bytes - max_valid = _valid_bytes_buffer(elem_type) - total_bytes = min(total_bytes, max_valid) - else: - total_bytes = _valid_bytes_buffer(elem_type) - else: - total_bytes = _valid_bytes_buffer(elem_type) + use_real_bounds = ( + emitter.options.eliminate_epilogue and symbolic_shape is not None + ) + total_bytes = _compute_total_valid_bytes( + elem_type, symbolic_shape, use_real_bounds + ) # With resetOffset, the SRD base is adjusted forward by the memref's # offset, so validBytes must be total_bytes - offset_bytes to avoid From 951d3914706e506d0a14a52eff7f4ea413304eb6 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 18:39:28 +0000 Subject: [PATCH 7/9] reorder flag Signed-off-by: xintin --- wave_lang/kernel/wave/compile_options.py | 4 +--- wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py | 4 +++- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index 2da22f0684..f126742c05 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -89,9 +89,6 @@ class WaveCompileOptions: wave_runtime: bool = False iree_launch_async: bool = True use_buffer_ops: bool = False - - eliminate_epilogue: bool = False - use_fast_math: bool = False use_global_to_shared: bool = False linearize_shared_access: bool = False @@ -107,6 +104,7 @@ class WaveCompileOptions: dump_schedule: Optional[str] = None use_bound_check: bool = False specialize: bool = False + eliminate_epilogue: bool = False # Cluster barrier signal/wait delay in number of loop iterations # None - no barriers inside the loop diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index fd3e9bdfde..da35d8f3a5 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1142,7 +1142,9 @@ def mxfp4_dbuf_schedule(): return mxfp4_dbuf_schedule -def get_mxfp4_asymmetric_schedule(eliminate_epilogue: bool = False, is_bscale_shuffled: bool = False): +def get_mxfp4_asymmetric_schedule( + eliminate_epilogue: bool = False, is_bscale_shuffled: bool = False +): """Return an asymmetric-prefetch MXFP4 schedule for wave_compile(). Asymmetric data paths: From 0f53afe9676cbc5c4431361477fbdf63fb7b81e9 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 19:04:37 +0000 Subject: [PATCH 8/9] lint fix Signed-off-by: xintin --- wave_lang/kernel/compiler/wave_codegen/read_write.py | 6 +++++- wave_lang/kernel/ops/wave_schedule_ops.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 8391c1a103..529e781da7 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -445,7 +445,7 @@ def _cast_buffer_and_encode_stride( # useful for reads (OOB loads return 0), and subtracting the offset # from a clamped total_bytes can produce values that overflow the # 32-bit SRD NUM_RECORDS field for output buffers larger than 2 GB. - if is_read: + if use_real_bounds and is_read: metadata = memref_d.extract_strided_metadata(ptr) offset_elements = metadata[1] offset_bytes = arith_d.index_cast(uint64, offset_elements) @@ -601,6 +601,10 @@ def extract(vec, ind): mem = _cast_buffer_and_encode_stride( mem, strides, element_type, emitter, symbolic_shape, is_read ) + elif is_global_mem and not is_read: + mem, offset_th = _linearize_memref( + mem, start_indices_wg, start_indices_th, strides + ) if linearize_shared_mem: mem = _linearize_shared_mem(mem) linearized_index = { diff --git a/wave_lang/kernel/ops/wave_schedule_ops.py b/wave_lang/kernel/ops/wave_schedule_ops.py index dceb592eb9..15bdc7104b 100755 --- a/wave_lang/kernel/ops/wave_schedule_ops.py +++ b/wave_lang/kernel/ops/wave_schedule_ops.py @@ -183,7 +183,11 @@ def reorder_graph(loop: Any, clusters: Any): ... @define_schedule_op -def pipeline(iterate: Sequence[fx.Node], eliminate_epilogue: bool = False, multi_buffer_count: Optional[int] = None): ... +def pipeline( + iterate: Sequence[fx.Node], + eliminate_epilogue: bool = False, + multi_buffer_count: Optional[int] = None, +): ... @define_schedule_op From 2b4f9e7e15fa075df51d747071de6124591bbf45 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 21:11:35 +0000 Subject: [PATCH 9/9] pass mlir_roundtrip_pipeline Signed-off-by: xintin --- lit_tests/kernel/wave/mlir_roundtrip_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py b/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py index 8c69acebb0..74811a78a6 100644 --- a/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py +++ b/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py @@ -273,6 +273,7 @@ def attention_progressive_roundtrip(): "compute_shared_memory_usage", "partition_gather_like_ops", "generate_bounds_exprs", + "guard_g2s_with_bounds_check", "merge_contiguous_reads", "location_check_pass", "simplify_indices",