-
Notifications
You must be signed in to change notification settings - Fork 28
Epilogue elimination in MXFP4 #975
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ca7adb3
388b3a8
4997008
ac90421
856a901
9a6d8f5
951d391
0f53afe
2b4f9e7
097d84f
d0b3775
fb34507
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -336,6 +336,29 @@ def _valid_bytes_buffer(elem_type: IrType) -> int: | |
| return ans | ||
|
|
||
|
|
||
| def _compute_total_valid_bytes( | ||
| elem_type: IrType, | ||
| symbolic_shape: tuple, | ||
| use_real_bounds: bool, | ||
| ) -> int: | ||
| """Return the total valid byte count for a buffer SRD. | ||
|
|
||
| When *use_real_bounds* is True and *symbolic_shape* resolves to a concrete | ||
| number, the result is clamped to the actual tensor size (still bounded by | ||
| the hardware maximum). Otherwise falls back to the hardware maximum | ||
| returned by ``_valid_bytes_buffer``. | ||
| """ | ||
| if use_real_bounds and symbolic_shape is not None: | ||
| elem_bytes = elem_type.width // 8 | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For mxfp4, this returns 0. Would that be a problem? |
||
| total_elements = subs_idxc(sympy.prod(s for s in symbolic_shape)) | ||
| if isinstance(total_elements, (int, float)) or ( | ||
| hasattr(total_elements, "is_number") and total_elements.is_number | ||
| ): | ||
| total_bytes = int(total_elements) * elem_bytes | ||
| return min(total_bytes, _valid_bytes_buffer(elem_type)) | ||
| return _valid_bytes_buffer(elem_type) | ||
|
|
||
|
|
||
| def _get_out_of_bounds_index(element_type: IrType) -> int: | ||
| """ | ||
| returns the first index that's out of bounds of a buffer based on the element type and maximum bytes | ||
|
|
@@ -362,25 +385,99 @@ def _get_constant_value(candidate: Value): | |
| return candidate.owner.opview.value.value | ||
|
|
||
|
|
||
| def _compute_branchless_valid_bytes( | ||
| emitter: WaveEmitter, | ||
| symbolic_shape: tuple, | ||
| elem_type: IrType, | ||
| guard_condition: sympy.Basic, | ||
| ) -> Value: | ||
| """Compute a dynamic validBytes that becomes 0 when OOB. | ||
|
|
||
| The guard_condition is a sympy expression like: | ||
| iv + prefetch_offset < max_iv | ||
|
|
||
| We emit: | ||
| cond = gen_sympy_index(guard_condition) # index type, nonzero=true | ||
| real_valid = compute_static_validBytes() | ||
| validBytes = select(cond != 0, real_valid, 0) | ||
|
|
||
| When the condition is false (last iterations), validBytes=0 makes the | ||
| SRD's NUM_RECORDS=0 so gather_to_lds DMA is a hardware no-op. | ||
harsh-nod marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| uint64 = IntegerType.get_signless(64) | ||
| total_bytes = _compute_total_valid_bytes( | ||
| elem_type, symbolic_shape, use_real_bounds=True | ||
| ) | ||
|
|
||
| real_valid = arith_d.constant(uint64, get_constant_attr(total_bytes, uint64)) | ||
| zero_valid = arith_d.constant(uint64, get_constant_attr(0, uint64)) | ||
|
|
||
| cond_val = gen_sympy_index(add_emitter_subs(emitter), guard_condition) | ||
| i1 = IntegerType.get_signless(1) | ||
| if cond_val.type != i1: | ||
| zero_idx = arith_d.constant(cond_val.type, 0) | ||
| cond_val = arith_d.cmpi(arith_d.CmpIPredicate.ne, cond_val, zero_idx) | ||
|
|
||
| return arith_d.select(cond_val, real_valid, zero_valid) | ||
|
|
||
|
|
||
| def _cast_buffer_and_encode_stride( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you remove `is_read` from this, as the functionality of this function is going beyond its original definition? [Inline code snippets showing the proposed signature and how "compute valid bytes" becomes a separate step were lost in extraction.] |
||
| ptr: Value, strides: tuple[Value], elem_type: IrType, emitter: WaveEmitter | ||
| ptr: Value, | ||
| strides: tuple[Value], | ||
| elem_type: IrType, | ||
| emitter: WaveEmitter, | ||
| symbolic_shape: tuple = None, | ||
| is_read: bool = True, | ||
| valid_bytes_override: Value = None, | ||
| ) -> Value: | ||
| uint64 = IntegerType.get_signless(64) | ||
| uint14 = IntegerType.get_signless(14) | ||
| elem_bytes = elem_type.width // 8 | ||
|
|
||
| if valid_bytes_override is not None: | ||
| valid_bytes_val = valid_bytes_override | ||
| else: | ||
| use_real_bounds = ( | ||
| emitter.options.eliminate_epilogue and symbolic_shape is not None | ||
| ) | ||
| total_bytes = _compute_total_valid_bytes( | ||
| elem_type, symbolic_shape, use_real_bounds | ||
| ) | ||
|
|
||
| # With resetOffset, the SRD base is adjusted forward by the memref's | ||
| # offset, so validBytes must be total_bytes - offset_bytes to avoid | ||
| # over-reporting the valid range past the actual allocation. | ||
| # | ||
| # For writes, skip the offset subtraction: real buffer bounds are only | ||
| # useful for reads (OOB loads return 0), and subtracting the offset | ||
| # from a clamped total_bytes can produce values that overflow the | ||
| # 32-bit SRD NUM_RECORDS field for output buffers larger than 2 GB. | ||
xintin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if use_real_bounds and is_read: | ||
| metadata = memref_d.extract_strided_metadata(ptr) | ||
| offset_elements = metadata[1] | ||
| offset_bytes = arith_d.index_cast(uint64, offset_elements) | ||
| elem_bytes_val = arith_d.constant( | ||
| uint64, get_constant_attr(elem_bytes, uint64) | ||
| ) | ||
| offset_bytes = arith_d.muli(offset_bytes, elem_bytes_val) | ||
| total_bytes_val = arith_d.constant( | ||
| uint64, get_constant_attr(total_bytes, uint64) | ||
| ) | ||
| valid_bytes_val = arith_d.subi(total_bytes_val, offset_bytes) | ||
| else: | ||
| valid_bytes_val = arith_d.constant( | ||
| uint64, get_constant_attr(total_bytes, uint64) | ||
| ) | ||
|
|
||
| valid_bytes = _valid_bytes_buffer( | ||
| elem_type | ||
| ) # max bytes that are in range to be addressed from a buffer | ||
| valid_bytes_constant = get_constant_attr(valid_bytes, uint64) | ||
| valid_bytes_constant = arith_d.constant(uint64, valid_bytes_constant) | ||
| stride_rank = len(strides) | ||
| swizzle_stride = None | ||
|
|
||
| if stride_rank >= 2: | ||
| # fastest_dim_bound == second to last stride. | ||
| stride_candidate = strides[-2] | ||
| stride_int = _get_constant_value(stride_candidate) | ||
| # Only swizzle if stride is static and <= 8192(the useful case). | ||
| # Only swizzle if stride is static and fits in signed i14 | ||
| # (max representable positive value is 8191, but 8192 wraps to | ||
| # -8192 which the hardware accepts). | ||
| if stride_int and stride_int <= 8192: | ||
| swizzle_stride = arith_d.index_cast(uint14, stride_candidate) | ||
|
|
||
|
|
@@ -390,14 +487,14 @@ def _cast_buffer_and_encode_stride( | |
| cache_swizzle_stride=swizzle_stride, | ||
| bounds_check=True, | ||
| reset_offset=True, | ||
| valid_bytes=valid_bytes_constant, | ||
| valid_bytes=valid_bytes_val, | ||
| ) | ||
| else: | ||
| ptr = amdgpu_d.fat_raw_buffer_cast( | ||
| ptr, | ||
| bounds_check=True, | ||
| reset_offset=True, | ||
| valid_bytes=valid_bytes_constant, | ||
| valid_bytes=valid_bytes_val, | ||
| ) | ||
|
|
||
| return ptr | ||
|
|
@@ -508,7 +605,9 @@ def extract(vec, ind): | |
| mem, offset_th = _linearize_memref( | ||
| mem, start_indices_wg, start_indices_th, strides | ||
| ) | ||
| mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter) | ||
| mem = _cast_buffer_and_encode_stride( | ||
| mem, strides, element_type, emitter, symbolic_shape, is_read | ||
| ) | ||
| elif is_global_mem and not is_read: | ||
| mem, offset_th = _linearize_memref( | ||
| mem, start_indices_wg, start_indices_th, strides | ||
|
|
@@ -553,7 +652,9 @@ def extract(vec, ind): | |
| mem, offset_th = _linearize_memref( | ||
| mem, start_indices_wg, start_indices_th, strides | ||
| ) | ||
| mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter) | ||
| mem = _cast_buffer_and_encode_stride( | ||
| mem, strides, element_type, emitter, symbolic_shape, is_read | ||
| ) | ||
|
|
||
| indices = [offset_th] if buffer_ops_enabled else start_indices | ||
|
|
||
|
|
@@ -1109,7 +1210,22 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): | |
| ] | ||
|
|
||
| src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides) | ||
| src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter) | ||
|
|
||
| valid_bytes_override = None | ||
| guard_condition = node.meta.get("g2s_guard", None) | ||
| if guard_condition is not None: | ||
| valid_bytes_override = _compute_branchless_valid_bytes( | ||
| emitter, src_symbolic_shape, element_type, guard_condition | ||
| ) | ||
|
|
||
| src = _cast_buffer_and_encode_stride( | ||
| src, | ||
| strides, | ||
| element_type, | ||
| emitter, | ||
| src_symbolic_shape, | ||
| valid_bytes_override=valid_bytes_override, | ||
| ) | ||
|
|
||
| # We previously checked mask is same for all elements, so we can use | ||
| # elements_per_thread=1 to build the mask. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -183,7 +183,11 @@ def reorder_graph(loop: Any, clusters: Any): ... | |
|
|
||
|
|
||
| @define_schedule_op | ||
| def pipeline(iterate: Sequence[fx.Node], multi_buffer_count: Optional[int] = None): ... | ||
| def pipeline( | ||
| iterate: Sequence[fx.Node], | ||
| eliminate_epilogue: bool = False, | ||
| multi_buffer_count: Optional[int] = None, | ||
| ): ... | ||
|
|
||
|
|
||
| @define_schedule_op | ||
|
|
@@ -804,11 +808,13 @@ def __init__( | |
| iterate: Sequence[fx.Node], | ||
| kernel_trace: "CapturedTrace", | ||
| constraints: list[Constraint], | ||
| eliminate_epilogue: bool = False, | ||
| multi_buffer_count: Optional[int] = None, | ||
| ): | ||
| self.iterate = iterate | ||
| self.kernel_trace = kernel_trace | ||
| self.constraints = constraints | ||
| self.eliminate_epilogue = eliminate_epilogue | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not directly related but does the handle() method of Pipeline need to support taking in multi-buffer count? |
||
| self.multi_buffer_count = multi_buffer_count | ||
|
|
||
| # Access options from the current ScheduleContext | ||
|
|
@@ -875,6 +881,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): | |
| scheduling_type=SchedulingType.MANUAL, | ||
| visualize=False, | ||
| multi_buffer_count=self.multi_buffer_count, | ||
| eliminate_epilogue=self.eliminate_epilogue, | ||
| ) | ||
|
|
||
| # Store the pipelined iterate node and node mapping, then create proxies for the stages | ||
|
|
@@ -937,7 +944,11 @@ def PROLOGUE(self): | |
|
|
||
| @property | ||
| def EPILOGUE(self): | ||
| """Get a reference to the EPILOGUE stage (nodes after pipelined iterate).""" | ||
| """Get a reference to the EPILOGUE stage (nodes after pipelined iterate). | ||
| Returns None when eliminate_epilogue=True (epilogue was not generated). | ||
| """ | ||
| if self.eliminate_epilogue: | ||
| return None | ||
| return self._EPILOGUE | ||
|
|
||
| def _update_kernel_node_mapping(self): | ||
|
|
@@ -1110,8 +1121,11 @@ def handle( | |
| kernel_trace, | ||
| constraints: list[Constraint], | ||
| iterate: Sequence[fx.Node], | ||
| eliminate_epilogue: bool = False, | ||
| ): | ||
| real_pipelined_loop = PipelinedLoop(iterate, kernel_trace, constraints) | ||
| real_pipelined_loop = PipelinedLoop( | ||
| iterate, kernel_trace, constraints, eliminate_epilogue=eliminate_epilogue | ||
| ) | ||
|
|
||
| # Return the real object directly (no proxy needed) | ||
| return real_pipelined_loop | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Once reviewed, I will just add a flag to the main test function (whichever it might be at that point).
This test will be removed.