diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index fb87425234..bf9f318fc8 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -214,15 +214,21 @@ def test_dbuf_4wave_mxfp_asymmetric_gemm( def test_dbuf_4wave_mxfp_preshuffle_b_gemm( - is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256) + is_debug=False, + shape=(1024, 1024, 8192), + block=(128, 256, 256), + eliminate_epilogue=False, ): """Asymmetric MXFP4 GEMM with preshuffled B data and B scales.""" gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) options.minimize_shared_allocs = True options.linearize_shared_access = True options.use_buffer_ops = True + options.eliminate_epilogue = eliminate_epilogue options.dump_intermediates = "build/intermediates" - schedule = get_mxfp4_asymmetric_schedule(is_bscale_shuffled=True) + schedule = get_mxfp4_asymmetric_schedule( + eliminate_epilogue=eliminate_epilogue, is_bscale_shuffled=True + ) options.print_ir_after = "all" if is_debug else [] options = set_default_run_config(options) @@ -232,6 +238,18 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm( print("MXFP GEMM preshuffle-B 4-wave test passed!") +def test_dbuf_4wave_mxfp_preshuffle_b_no_epilogue_gemm( + is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256) +): + """Asymmetric MXFP4 GEMM with preshuffled B, epilogue eliminated via OOB=0.""" + test_dbuf_4wave_mxfp_preshuffle_b_gemm( + is_debug=is_debug, + shape=shape, + block=block, + eliminate_epilogue=True, + ) + + def test_dbuf_4wave_mxfp_asymmetric_gemm_cpp( is_debug=False, shape=(1024, 1024, 8192), block=(128, 256, 256) ): diff --git a/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py b/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py index 8c69acebb0..74811a78a6 100644 --- a/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py +++ b/lit_tests/kernel/wave/mlir_roundtrip_pipeline.py @@ -273,6 +273,7 @@ def attention_progressive_roundtrip(): "compute_shared_memory_usage", "partition_gather_like_ops", "generate_bounds_exprs", + "guard_g2s_with_bounds_check", "merge_contiguous_reads", "location_check_pass", "simplify_indices", diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 5c871e8604..2d82813594 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -336,6 +336,29 @@ def _valid_bytes_buffer(elem_type: IrType) -> int: return ans +def _compute_total_valid_bytes( + elem_type: IrType, + symbolic_shape: tuple, + use_real_bounds: bool, +) -> int: + """Return the total valid byte count for a buffer SRD. + + When *use_real_bounds* is True and *symbolic_shape* resolves to a concrete + number, the result is clamped to the actual tensor size (still bounded by + the hardware maximum). Otherwise falls back to the hardware maximum + returned by ``_valid_bytes_buffer``. + """ + if use_real_bounds and symbolic_shape is not None: + elem_bytes = elem_type.width // 8 + total_elements = subs_idxc(sympy.prod(s for s in symbolic_shape)) + if isinstance(total_elements, (int, float)) or ( + hasattr(total_elements, "is_number") and total_elements.is_number + ): + total_bytes = int(total_elements) * elem_bytes + return min(total_bytes, _valid_bytes_buffer(elem_type)) + return _valid_bytes_buffer(elem_type) + + def _get_out_of_bounds_index(element_type: IrType) -> int: """ returns the first index that's out of bounds of a buffer based on the element type and maximum bytes @@ -362,25 +385,99 @@ def _get_constant_value(candidate: Value): return candidate.owner.opview.value.value +def _compute_branchless_valid_bytes( + emitter: WaveEmitter, + symbolic_shape: tuple, + elem_type: IrType, + guard_condition: sympy.Basic, +) -> Value: + """Compute a dynamic validBytes that becomes 0 when OOB. + + The guard_condition is a sympy expression like: + iv + prefetch_offset < max_iv + + We emit: + cond = gen_sympy_index(guard_condition) # index type, nonzero=true + real_valid = compute_static_validBytes() + validBytes = select(cond != 0, real_valid, 0) + + When the condition is false (last iterations), validBytes=0 makes the + SRD's NUM_RECORDS=0 so gather_to_lds DMA is a hardware no-op. + """ + uint64 = IntegerType.get_signless(64) + total_bytes = _compute_total_valid_bytes( + elem_type, symbolic_shape, use_real_bounds=True + ) + + real_valid = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) + zero_valid = arith_d.constant(uint64, get_constant_attr(0, uint64)) + + cond_val = gen_sympy_index(add_emitter_subs(emitter), guard_condition) + i1 = IntegerType.get_signless(1) + if cond_val.type != i1: + zero_idx = arith_d.constant(cond_val.type, 0) + cond_val = arith_d.cmpi(arith_d.CmpIPredicate.ne, cond_val, zero_idx) + + return arith_d.select(cond_val, real_valid, zero_valid) + + def _cast_buffer_and_encode_stride( - ptr: Value, strides: tuple[Value], elem_type: IrType, emitter: WaveEmitter + ptr: Value, + strides: tuple[Value], + elem_type: IrType, + emitter: WaveEmitter, + symbolic_shape: tuple = None, + is_read: bool = True, + valid_bytes_override: Value = None, ) -> Value: uint64 = IntegerType.get_signless(64) uint14 = IntegerType.get_signless(14) + elem_bytes = elem_type.width // 8 + + if valid_bytes_override is not None: + valid_bytes_val = valid_bytes_override + else: + use_real_bounds = ( + emitter.options.eliminate_epilogue and symbolic_shape is not None + ) + total_bytes = _compute_total_valid_bytes( + elem_type, symbolic_shape, use_real_bounds + ) + + # With resetOffset, the SRD base is adjusted forward by the memref's + # offset, so validBytes must be total_bytes - offset_bytes to avoid + # over-reporting the valid range past the actual allocation. + # + # For writes, skip the offset subtraction: real buffer bounds are only + # useful for reads (OOB loads return 0), and subtracting the offset + # from a clamped total_bytes can produce values that overflow the + # 32-bit SRD NUM_RECORDS field for output buffers larger than 2 GB. + if use_real_bounds and is_read: + metadata = memref_d.extract_strided_metadata(ptr) + offset_elements = metadata[1] + offset_bytes = arith_d.index_cast(uint64, offset_elements) + elem_bytes_val = arith_d.constant( + uint64, get_constant_attr(elem_bytes, uint64) + ) + offset_bytes = arith_d.muli(offset_bytes, elem_bytes_val) + total_bytes_val = arith_d.constant( + uint64, get_constant_attr(total_bytes, uint64) + ) + valid_bytes_val = arith_d.subi(total_bytes_val, offset_bytes) + else: + valid_bytes_val = arith_d.constant( + uint64, get_constant_attr(total_bytes, uint64) + ) - valid_bytes = _valid_bytes_buffer( - elem_type - ) # max bytes that are in range to be addressed from a buffer - valid_bytes_constant = get_constant_attr(valid_bytes, uint64) - valid_bytes_constant = arith_d.constant(uint64, valid_bytes_constant) stride_rank = len(strides) swizzle_stride = None if stride_rank >= 2: - # fastest_dim_bound == second to last stride. stride_candidate = strides[-2] stride_int = _get_constant_value(stride_candidate) - # Only swizzle if stride is static and <= 8192(the useful case). + # Only swizzle if stride is static and fits in signed i14 + # (max representable positive value is 8191, but 8192 wraps to + # -8192 which the hardware accepts). if stride_int and stride_int <= 8192: swizzle_stride = arith_d.index_cast(uint14, stride_candidate) @@ -390,14 +487,14 @@ def _cast_buffer_and_encode_stride( cache_swizzle_stride=swizzle_stride, bounds_check=True, reset_offset=True, - valid_bytes=valid_bytes_constant, + valid_bytes=valid_bytes_val, ) else: ptr = amdgpu_d.fat_raw_buffer_cast( ptr, bounds_check=True, reset_offset=True, - valid_bytes=valid_bytes_constant, + valid_bytes=valid_bytes_val, ) return ptr @@ -508,7 +605,9 @@ def extract(vec, ind): mem, offset_th = _linearize_memref( mem, start_indices_wg, start_indices_th, strides ) - mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter) + mem = _cast_buffer_and_encode_stride( + mem, strides, element_type, emitter, symbolic_shape, is_read + ) elif is_global_mem and not is_read: mem, offset_th = _linearize_memref( mem, start_indices_wg, start_indices_th, strides @@ -553,7 +652,9 @@ def extract(vec, ind): mem, offset_th = _linearize_memref( mem, start_indices_wg, start_indices_th, strides ) - mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter) + mem = _cast_buffer_and_encode_stride( + mem, strides, element_type, emitter, symbolic_shape, is_read + ) indices = [offset_th] if buffer_ops_enabled else start_indices @@ -1109,7 +1210,22 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): ] src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides) - src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter) + + valid_bytes_override = None + guard_condition = node.meta.get("g2s_guard", None) + if guard_condition is not None: + valid_bytes_override = _compute_branchless_valid_bytes( + emitter, src_symbolic_shape, element_type, guard_condition + ) + + src = _cast_buffer_and_encode_stride( + src, + strides, + element_type, + emitter, + src_symbolic_shape, + valid_bytes_override=valid_bytes_override, + ) # We previously checked mask is same for all elements, so we can use # elements_per_thread=1 to build the mask. diff --git a/wave_lang/kernel/ops/wave_schedule_ops.py b/wave_lang/kernel/ops/wave_schedule_ops.py index 159a1549e4..ec680ebc78 100755 --- a/wave_lang/kernel/ops/wave_schedule_ops.py +++ b/wave_lang/kernel/ops/wave_schedule_ops.py @@ -183,7 +183,11 @@ def reorder_graph(loop: Any, clusters: Any): ... @define_schedule_op -def pipeline(iterate: Sequence[fx.Node], multi_buffer_count: Optional[int] = None): ... +def pipeline( + iterate: Sequence[fx.Node], + eliminate_epilogue: bool = False, + multi_buffer_count: Optional[int] = None, +): ... @define_schedule_op @@ -804,11 +808,13 @@ def __init__( iterate: Sequence[fx.Node], kernel_trace: "CapturedTrace", constraints: list[Constraint], + eliminate_epilogue: bool = False, multi_buffer_count: Optional[int] = None, ): self.iterate = iterate self.kernel_trace = kernel_trace self.constraints = constraints + self.eliminate_epilogue = eliminate_epilogue self.multi_buffer_count = multi_buffer_count # Access options from the current ScheduleContext @@ -875,6 +881,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): scheduling_type=SchedulingType.MANUAL, visualize=False, multi_buffer_count=self.multi_buffer_count, + eliminate_epilogue=self.eliminate_epilogue, ) # Store the pipelined iterate node and node mapping, then create proxies for the stages @@ -937,7 +944,11 @@ def PROLOGUE(self): @property def EPILOGUE(self): - """Get a reference to the EPILOGUE stage (nodes after pipelined iterate).""" + """Get a reference to the EPILOGUE stage (nodes after pipelined iterate). + Returns None when eliminate_epilogue=True (epilogue was not generated). + """ + if self.eliminate_epilogue: + return None return self._EPILOGUE def _update_kernel_node_mapping(self): @@ -1110,8 +1121,11 @@ def handle( kernel_trace, constraints: list[Constraint], iterate: Sequence[fx.Node], + eliminate_epilogue: bool = False, ): - real_pipelined_loop = PipelinedLoop(iterate, kernel_trace, constraints) + real_pipelined_loop = PipelinedLoop( + iterate, kernel_trace, constraints, eliminate_epilogue=eliminate_epilogue + ) # Return the real object directly (no proxy needed) return real_pipelined_loop diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 2e74a70b48..2f92bfeff0 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -73,6 +73,7 @@ from .multicast import multicast from .promotion import compute_shared_memory_usage, promote_placeholders from .schedule_reordering import schedule_reordering +from .scheduling.loop_reconstruction import guard_g2s_with_bounds_check from .scheduling.schedule import schedule_graph from .scheduling.schedule_enums import SchedulingType from .shared_memory_indexing import apply_shared_memory_indexing_corrections @@ -554,6 +555,10 @@ def build_graph_passes( ) ) + graph_passes.append( + partial(guard_g2s_with_bounds_check, trace, launchable.constraints) + ) + if options.optimization_level: graph_passes += [ partial( diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index bc123aa2c6..f126742c05 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -104,6 +104,7 @@ class WaveCompileOptions: dump_schedule: Optional[str] = None use_bound_check: bool = False specialize: bool = False + eliminate_epilogue: bool = False # Cluster barrier signal/wait delay in number of loop iterations # None - no barriers inside the loop diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index dca51dd0eb..da35d8f3a5 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -1142,7 +1142,9 @@ def mxfp4_dbuf_schedule(): return mxfp4_dbuf_schedule -def get_mxfp4_asymmetric_schedule(is_bscale_shuffled: bool = False): +def get_mxfp4_asymmetric_schedule( + eliminate_epilogue: bool = False, is_bscale_shuffled: bool = False +): """Return an asymmetric-prefetch MXFP4 schedule for wave_compile(). Asymmetric data paths: @@ -1162,6 +1164,11 @@ def get_mxfp4_asymmetric_schedule(is_bscale_shuffled: bool = False): Second MMA half: interleaved with B_scale loads and next-iteration first-partition A reads (plus G2S for the iteration after next). + + When eliminate_epilogue=True the loop runs for the full K trip count + and relies on OOB buffer loads returning zero (GFX9+ hardware guarantee) + so that extra iterations contribute nothing to the accumulators. This + removes all epilogue code, reducing total code size. """ M = tkl.sym.M @@ -1204,7 +1211,7 @@ def mxfp4_dbuf_schedule(): # ===================================================================== # Create 2-stage pipeline (double buffering) # ===================================================================== - pipeline_loop = tkw.pipeline(k_loop) + pipeline_loop = tkw.pipeline(k_loop, eliminate_epilogue=eliminate_epilogue) # This forces the pipeline to use double buffering pipeline_loop.multi_buffer_count = 2 @@ -1403,147 +1410,154 @@ def mxfp4_dbuf_schedule(): ), ] - #################### EPILOGUE #################### - - # Filter nodes for EPILOGUE stage - - epilogue_g2v_b = tkw.filter_nodes(g2v_b, subgraph=pipeline_loop.EPILOGUE) - epilogue_g2v_b_scale = tkw.filter_nodes( - g2v_b_scale, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_s2v_a_0 = tkw.filter_nodes(s2v_a_0, subgraph=pipeline_loop.EPILOGUE) - epilogue_s2v_a_scale_0 = tkw.filter_nodes( - s2v_a_scale_0, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_s2v_a_1 = tkw.filter_nodes(s2v_a_1, subgraph=pipeline_loop.EPILOGUE) - epilogue_s2v_a_scale_1 = tkw.filter_nodes( - s2v_a_scale_1, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_a = tkw.filter_nodes( - bitcast_a, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_a_scale = tkw.filter_nodes( - bitcast_a_scale, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_b = tkw.filter_nodes( - bitcast_b, subgraph=pipeline_loop.EPILOGUE - ) - epilogue_bitcast_b_scale = tkw.filter_nodes( - bitcast_b_scale, subgraph=pipeline_loop.EPILOGUE - ) - - epilogue_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.EPILOGUE) - - def split_by_iteration(nodes, key="name"): - # TODO: Replace name-based splitting with a pipeline_drain_iteration - # attribute (analogous to unroll_iteration). expanded_dims can't be - # used here because loop_reconstruction copies them verbatim for - # both drain iterations. - itr0 = [] - itr1 = [] - for node in nodes: - value = getattr(node, key) - if "1_2" in value: - itr0.append(node) - elif "2_2" in value: - itr1.append(node) - else: - raise ValueError(f"Unknown {key} for node: {value}") - return itr0, itr1 + if eliminate_epilogue: + clusters += prologue_clusters + tkw.reorder_graph(pipeline_loop.KERNEL, clusters) + else: + epilogue_g2v_b = tkw.filter_nodes(g2v_b, subgraph=pipeline_loop.EPILOGUE) + epilogue_g2v_b_scale = tkw.filter_nodes( + g2v_b_scale, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_0 = tkw.filter_nodes( + s2v_a_0, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_scale_0 = tkw.filter_nodes( + s2v_a_scale_0, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_1 = tkw.filter_nodes( + s2v_a_1, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_s2v_a_scale_1 = tkw.filter_nodes( + s2v_a_scale_1, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_a = tkw.filter_nodes( + bitcast_a, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_a_scale = tkw.filter_nodes( + bitcast_a_scale, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_b = tkw.filter_nodes( + bitcast_b, subgraph=pipeline_loop.EPILOGUE + ) + epilogue_bitcast_b_scale = tkw.filter_nodes( + bitcast_b_scale, subgraph=pipeline_loop.EPILOGUE + ) - epilogue_mma_itr0, epilogue_mma_itr1 = split_by_iteration(epilogue_mma) - epilogue_s2v_a_1_itr0, epilogue_s2v_a_1_itr1 = split_by_iteration( - epilogue_s2v_a_1 - ) - epilogue_s2v_a_scale_1_itr0, epilogue_s2v_a_scale_1_itr1 = split_by_iteration( - epilogue_s2v_a_scale_1 - ) - epilogue_bitcast_a_itr0, epilogue_bitcast_a_itr1 = split_by_iteration( - epilogue_bitcast_a - ) - epilogue_bitcast_a_scale_itr0, epilogue_bitcast_a_scale_itr1 = ( - split_by_iteration(epilogue_bitcast_a_scale) - ) - epilogue_bitcast_b_itr0, epilogue_bitcast_b_itr1 = split_by_iteration( - epilogue_bitcast_b - ) - epilogue_bitcast_b_scale_itr0, epilogue_bitcast_b_scale_itr1 = ( - split_by_iteration(epilogue_bitcast_b_scale) - ) + epilogue_mma = tkw.filter_nodes(scaled_mma, subgraph=pipeline_loop.EPILOGUE) + + def split_by_iteration(nodes, key="name"): + # TODO: Replace name-based splitting with a + # pipeline_drain_iteration attribute (analogous to + # unroll_iteration). expanded_dims can't be used here because + # loop_reconstruction copies them verbatim for both drain + # iterations. + itr0 = [] + itr1 = [] + for node in nodes: + value = getattr(node, key) + if "1_2" in value: + itr0.append(node) + elif "2_2" in value: + itr1.append(node) + else: + raise ValueError(f"Unknown {key} for node: {value}") + return itr0, itr1 + + epilogue_mma_itr0, epilogue_mma_itr1 = split_by_iteration(epilogue_mma) + epilogue_s2v_a_1_itr0, epilogue_s2v_a_1_itr1 = split_by_iteration( + epilogue_s2v_a_1 + ) + ( + epilogue_s2v_a_scale_1_itr0, + epilogue_s2v_a_scale_1_itr1, + ) = split_by_iteration(epilogue_s2v_a_scale_1) + epilogue_bitcast_a_itr0, epilogue_bitcast_a_itr1 = split_by_iteration( + epilogue_bitcast_a + ) + epilogue_bitcast_a_scale_itr0, epilogue_bitcast_a_scale_itr1 = ( + split_by_iteration(epilogue_bitcast_a_scale) + ) + epilogue_bitcast_b_itr0, epilogue_bitcast_b_itr1 = split_by_iteration( + epilogue_bitcast_b + ) + epilogue_bitcast_b_scale_itr0, epilogue_bitcast_b_scale_itr1 = ( + split_by_iteration(epilogue_bitcast_b_scale) + ) - epilogue_mma_itr0_0, epilogue_mma_itr0_1 = tkw.partition_by_dim( - epilogue_mma_itr0, dim=M, num_partitions=2 - ) - epilogue_bitcast_a_itr0_0, epilogue_bitcast_a_itr0_1 = tkw.partition_by_dim( - epilogue_bitcast_a_itr0, dim=M, num_partitions=2 - ) - epilogue_bitcast_a_scale_itr0_0, epilogue_bitcast_a_scale_itr0_1 = ( - tkw.partition_by_dim(epilogue_bitcast_a_scale_itr0, dim=M, num_partitions=2) - ) + epilogue_mma_itr0_0, epilogue_mma_itr0_1 = tkw.partition_by_dim( + epilogue_mma_itr0, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_itr0_0, epilogue_bitcast_a_itr0_1 = tkw.partition_by_dim( + epilogue_bitcast_a_itr0, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_scale_itr0_0, epilogue_bitcast_a_scale_itr0_1 = ( + tkw.partition_by_dim( + epilogue_bitcast_a_scale_itr0, dim=M, num_partitions=2 + ) + ) - epilogue_mma_itr1_0, epilogue_mma_itr1_1 = tkw.partition_by_dim( - epilogue_mma_itr1, dim=M, num_partitions=2 - ) - epilogue_bitcast_a_itr1_0, epilogue_bitcast_a_itr1_1 = tkw.partition_by_dim( - epilogue_bitcast_a_itr1, dim=M, num_partitions=2 - ) - epilogue_bitcast_a_scale_itr1_0, epilogue_bitcast_a_scale_itr1_1 = ( - tkw.partition_by_dim(epilogue_bitcast_a_scale_itr1, dim=M, num_partitions=2) - ) + epilogue_mma_itr1_0, epilogue_mma_itr1_1 = tkw.partition_by_dim( + epilogue_mma_itr1, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_itr1_0, epilogue_bitcast_a_itr1_1 = tkw.partition_by_dim( + epilogue_bitcast_a_itr1, dim=M, num_partitions=2 + ) + epilogue_bitcast_a_scale_itr1_0, epilogue_bitcast_a_scale_itr1_1 = ( + tkw.partition_by_dim( + epilogue_bitcast_a_scale_itr1, dim=M, num_partitions=2 + ) + ) - epilogue_clusters_itr0 = [ - tkw.cluster( - [ - tkw.MemoryCounterWaitBarrier(load=0, ds=0), - epilogue_bitcast_a_itr0_0, - epilogue_bitcast_a_scale_itr0_0, - epilogue_bitcast_b_itr0, - epilogue_bitcast_b_scale_itr0, - tkw.SchedulingBarrier([]), - epilogue_mma_itr0_0, - epilogue_g2v_b, - epilogue_s2v_a_1_itr0, - epilogue_g2v_b_scale, - epilogue_s2v_a_scale_1_itr0, - epilogue_bitcast_a_itr0_1, - epilogue_bitcast_a_scale_itr0_1, - ], - ), - tkw.cluster( - [ - epilogue_mma_itr0_1, - tkw.MemoryCounterWaitBarrier(load=0, ds=0), - epilogue_s2v_a_0, - epilogue_s2v_a_scale_0, - ], - ), - tkw.cluster( - [ - tkw.MemoryCounterWaitBarrier(load=0, ds=0), - epilogue_bitcast_a_itr1_0, - epilogue_bitcast_a_scale_itr1_0, - epilogue_bitcast_b_itr1, - epilogue_bitcast_b_scale_itr1, - tkw.SchedulingBarrier([]), - epilogue_mma_itr1_0, - epilogue_s2v_a_1_itr1, - epilogue_s2v_a_scale_1_itr1, - ], - ), - tkw.cluster( - [ - epilogue_bitcast_a_itr1_1, - epilogue_bitcast_a_scale_itr1_1, - epilogue_mma_itr1_1, - ], - ), - ] + epilogue_clusters_itr0 = [ + tkw.cluster( + [ + epilogue_bitcast_a_itr0_0, + epilogue_bitcast_a_scale_itr0_0, + epilogue_bitcast_b_itr0, + epilogue_bitcast_b_scale_itr0, + tkw.SchedulingBarrier([]), + epilogue_mma_itr0_0, + epilogue_g2v_b, + epilogue_s2v_a_1_itr0, + epilogue_g2v_b_scale, + epilogue_s2v_a_scale_1_itr0, + epilogue_bitcast_a_itr0_1, + epilogue_bitcast_a_scale_itr0_1, + ], + ), + tkw.cluster( + [ + epilogue_mma_itr0_1, + tkw.SchedulingBarrier([]), + epilogue_s2v_a_0, + epilogue_s2v_a_scale_0, + ], + ), + tkw.cluster( + [ + epilogue_bitcast_a_itr1_0, + epilogue_bitcast_a_scale_itr1_0, + epilogue_bitcast_b_itr1, + epilogue_bitcast_b_scale_itr1, + tkw.SchedulingBarrier([]), + epilogue_mma_itr1_0, + epilogue_s2v_a_1_itr1, + epilogue_s2v_a_scale_1_itr1, + ], + ), + tkw.cluster( + [ + epilogue_bitcast_a_itr1_1, + epilogue_bitcast_a_scale_itr1_1, + epilogue_mma_itr1_1, + ], + ), + ] - tkw.reorder_graph(pipeline_loop.PROLOGUE, prologue_clusters) - tkw.reorder_graph(pipeline_loop.KERNEL, clusters) - # tkw.reorder_graph(pipeline_loop.EPILOGUE, epilogue_clusters_itr0) - unroll_factor = 2 - tkw.unroll(pipeline_loop.KERNEL, unroll_factor) + tkw.reorder_graph(pipeline_loop.PROLOGUE, prologue_clusters) + tkw.reorder_graph(pipeline_loop.KERNEL, clusters) + unroll_factor = 2 + tkw.unroll(pipeline_loop.KERNEL, unroll_factor) tkw.insert_at_start( pipeline_loop.KERNEL, diff --git a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py index b757c8aecf..70e7c7b7b7 100644 --- a/wave_lang/kernel/wave/scheduling/loop_reconstruction.py +++ b/wave_lang/kernel/wave/scheduling/loop_reconstruction.py @@ -1,6 +1,7 @@ from collections import defaultdict, deque from enum import Enum +import sympy import torch.fx as fx from wave_lang.support.logging import get_logger @@ -9,6 +10,7 @@ from ..._support.tracing import CapturedTrace from ...ops.wave_ops import ( CustomOp, + GatherToLDS, GetResult, IterArg, Iterate, @@ -812,6 +814,52 @@ def erase_allocs(allocs: list[fx.Node]): get_custom(alloc).erase() +def guard_g2s_with_bounds_check( + trace: CapturedTrace, + constraints: list[Constraint], +): + """ + Post-scheduling pass: annotate GatherToLDS nodes inside pipelined loops + that have eliminate_epilogue=True with branchless SRD guard metadata. + + When eliminate_epilogue=True, the loop runs for the full trip count. Stage 0 + (the prefetch stage) uses iv + (num_stages-1)*step, which goes OOB in the + last (num_stages-1) iterations. Instead of wrapping gather_to_lds in an + scf.if branch, we annotate each gather_to_lds with a guard condition so + that codegen can emit a dynamic validBytes that becomes 0 when OOB: + + validBytes = select(iv + prefetch_offset < max_iv, real_validBytes, 0) + + This makes the hardware DMA a no-op (reads nothing) without any branch. + """ + for node in trace.walk(lambda n: isinstance(get_custom(n), Iterate)): + if not node.meta.get("eliminate_epilogue", False): + continue + + iterate = get_custom(node) + subgraph_name = iterate.subgraph_name + pipelined_graph = trace.get_subgraph(subgraph_name) + + g2s_nodes = [ + n for n in pipelined_graph.nodes if isinstance(get_custom(n), GatherToLDS) + ] + if not g2s_nodes: + continue + + num_stages = node.meta["num_pipeline_stages"] + step = iterate.step + max_iv = iterate.count + induction_variable = get_induction_variable(iterate, constraints) + + prefetch_offset = (num_stages - 1) * step + guard_condition = sympy.StrictLessThan( + induction_variable + prefetch_offset, max_iv + ) + + for g2s in g2s_nodes: + g2s.meta["g2s_guard"] = guard_condition + + def construct_pipelined_loop( trace: CapturedTrace, reduction: Iterate, @@ -823,11 +871,16 @@ def construct_pipelined_loop( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + eliminate_epilogue: bool = False, ) -> tuple[fx.Node, dict[fx.Node, list[fx.Node]], list[fx.Node]]: """ Given a graph annotated with scheduling parameters, construct a pipelined loop with a prologue, kernel and epilogue. + When eliminate_epilogue=True, the epilogue is not generated. The loop runs + for the full trip count and relies on out-of-bounds loads returning zero + to make the extra iterations contribute nothing to the accumulators. + Returns: pipelined_reduction: The pipelined loop node node_mapping: Dictionary tracking mappings from original nodes to pipelined copies @@ -893,28 +946,61 @@ def construct_pipelined_loop( trace.add_subgraph( get_custom(pipelined_reduction).subgraph_name, pipelined_reduction_graph ) - # Construct epilogue. - # The epilogue induction variables must account for the step size. - # With step > 1, each "iteration" covers step original iterations. - final_results = construct_epilogue( - graph, - reduction, - get_custom(pipelined_reduction), - partitioned_graph, - num_stages, - initiation_interval, - rotating_registers, - induction_variable, - [ - max_induction_variable - num_stages * step + i * step - for i in range(num_stages) - ], - create_drain_stage_schedule(num_stages), - num_rotating_registers, - visualize, - outer_vars=outer_vars, - node_mapping=node_mapping, - ) + + if eliminate_epilogue: + pipelined_reduction.meta["eliminate_epilogue"] = True + pipelined_reduction.meta["num_pipeline_stages"] = num_stages + + # No epilogue: the loop runs for the full trip count. + # The final results are simply the GetResult nodes from the pipelined loop. + pipelined_custom = get_custom(pipelined_reduction) + existing_get_results: list[GetResult] = sorted( + [x for x in pipelined_custom.users if isinstance(x, GetResult)], + key=lambda x: x.res_idx, + ) + iter_args = reduction.iter_args(graph) + # Ensure we have GetResult for every iter arg + last_get_result = ( + existing_get_results[-1].fx_node + if existing_get_results + else pipelined_reduction + ) + existing_indices = {x.res_idx for x in existing_get_results} + for i in range(len(iter_args)): + if i not in existing_indices: + with pipelined_custom.graph.inserting_after(last_get_result): + result = GetResult(pipelined_reduction, i).add_to_graph( + pipelined_custom.graph, + type=iter_args[i].type, + loc=pipelined_custom.location, + ) + existing_get_results.append(get_custom(result)) + last_get_result = result + existing_get_results = sorted(existing_get_results, key=lambda x: x.res_idx) + final_results = [gr.fx_node for gr in existing_get_results] + else: + # Construct epilogue. + # The epilogue induction variables must account for the step size. + # With step > 1, each "iteration" covers step original iterations. + final_results = construct_epilogue( + graph, + reduction, + get_custom(pipelined_reduction), + partitioned_graph, + num_stages, + initiation_interval, + rotating_registers, + induction_variable, + [ + max_induction_variable - num_stages * step + i * step + for i in range(num_stages) + ], + create_drain_stage_schedule(num_stages), + num_rotating_registers, + visualize, + outer_vars=outer_vars, + node_mapping=node_mapping, + ) # Remove the unpipelined reduction and the corresponding subgraph reduction.erase() diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index 11cb374c42..7d123d0cce 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -214,6 +214,7 @@ def build_guarded_pipeline_with_remainder( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + eliminate_epilogue: bool = False, ): """ Build conditional + pipelined loop + remainder loop for dynamic shapes. @@ -312,6 +313,7 @@ def build_guarded_pipeline_with_remainder( visualize, use_scheduling_barriers, multi_buffer_count, + eliminate_epilogue=eliminate_epilogue, ) # node_mapping keys are from the copied body graph. Translate them back @@ -321,12 +323,17 @@ def build_guarded_pipeline_with_remainder( node_mapping = {new_to_old.get(k, k): v for k, v in node_mapping.items()} # Set the count for the pipelined loop - # With step > 1 (e.g., from unrolling), we need to reduce the count by more - # to prevent out-of-bounds access. The last kernel iteration's stage 0 loads - # data for the "next" iteration (offset by step), so we need to ensure - # that stays within bounds. step = get_custom(pipelined_node).step - get_custom(pipelined_node).count = pipelined_iterations - (num_stages - 1) * step + if eliminate_epilogue: + get_custom(pipelined_node).count = pipelined_iterations + else: + # With step > 1 (e.g., from unrolling), we need to reduce the count by more + # to prevent out-of-bounds access. The last kernel iteration's stage 0 loads + # data for the "next" iteration (offset by step), so we need to ensure + # that stays within bounds. + get_custom(pipelined_node).count = ( + pipelined_iterations - (num_stages - 1) * step + ) # Verify we have the right number of results assert len(final_results) == len( @@ -420,6 +427,7 @@ def construct_pipelined_loop_adaptive( visualize: bool = False, use_scheduling_barriers: bool = False, multi_buffer_count: Optional[int] = None, + eliminate_epilogue: bool = False, ): """ Constructs a pipelined loop wrapped in a conditional, followed by a remainder loop. @@ -433,6 +441,12 @@ def construct_pipelined_loop_adaptive( else: return (0, init_values...) remainder_loop(start=iterations_done, end=total_iterations) + + When eliminate_epilogue=True, the epilogue is not generated and the loop + runs for the full trip count. Out-of-bounds loads in the extra iterations + must return zero (guaranteed by buffer_load on GFX950). This trades wasted + prefetch work in the last (num_stages-1) iterations for eliminating all + epilogue code (MFMAs, loads, bitcasts). """ # Check if we have a dynamic shape (max_induction_variable is symbolic) is_dynamic = not ( @@ -456,12 +470,16 @@ def construct_pipelined_loop_adaptive( visualize, use_scheduling_barriers, multi_buffer_count, + eliminate_epilogue=eliminate_epilogue, ) if new_reduction: step = get_custom(new_reduction).step - get_custom(new_reduction).count = ( - max_induction_variable - (num_stages - 1) * step - ) + if eliminate_epilogue: + get_custom(new_reduction).count = max_induction_variable + else: + get_custom(new_reduction).count = ( + max_induction_variable - (num_stages - 1) * step + ) return new_reduction, node_mapping # For dynamic shapes, emit conditional + pipelined loop + remainder loop @@ -477,6 +495,7 @@ def construct_pipelined_loop_adaptive( visualize, use_scheduling_barriers, multi_buffer_count, + eliminate_epilogue=eliminate_epilogue, ) @@ -491,6 +510,7 @@ def apply_pipelined_schedule( scheduling_type: SchedulingType = SchedulingType.NONE, visualize: bool = False, multi_buffer_count: Optional[int] = None, + eliminate_epilogue: bool = False, ) -> Optional[tuple[fx.Node, dict]]: # After scheduling has completed, we have enough information to decide @@ -525,6 +545,7 @@ def apply_pipelined_schedule( visualize, use_scheduling_barriers, multi_buffer_count, + eliminate_epilogue=eliminate_epilogue, )