From a08ceee9152de1e16d3a6861ea628ed22c0c69fe Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Fri, 20 Feb 2026 14:43:49 -0600 Subject: [PATCH] Generalized affine expression decomposition for SGPR promotion Add an N-way uniformity decomposition pass that splits any affine expression into additive components by symbol class (workgroup, induction variable, thread). This allows the LLVM backend to keep uniform contributions in SGPRs and fold them into hardware instruction fields (e.g. buffer_load soffset), reducing VGPR pressure. Key changes: - symbol_utils.py: add decompose_affine_by_uniformity(), a general N-way decomposition with cascade validation for cross-class terms - emitter.py: add get_uniformity_classes() and gen_sympy_index_decomposed() so any handler can decompose arbitrary affine maps at the emitter level - read_write.py: replace _split_index_three_way with the general utility, add _compute_linear_offset/_apply_uniform_offsets helpers, and apply the decomposition to handle_read, handle_write, and handle_gather_to_lds Co-authored-by: Cursor --- .../kernel/compiler/wave_codegen/emitter.py | 103 ++++++++++- .../compiler/wave_codegen/read_write.py | 175 ++++++++++++++---- wave_lang/kernel/wave/utils/symbol_utils.py | 74 ++++++++ 3 files changed, 316 insertions(+), 36 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index 970d6b52b0..cb51c4e984 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -75,7 +75,11 @@ from ...wave.compile_options import WaveCompileOptions from ...wave.constraints import Constraint, HardwareConstraint, TilingConstraint from ...wave.utils.general_utils import get_hardware_constraint -from ...wave.utils.symbol_utils import subs_idxc, is_literal +from ...wave.utils.symbol_utils import ( + decompose_affine_by_uniformity, + subs_idxc, + is_literal, +) logger = get_logger("wave.ops_location_check") @@ -598,6 +602,103 @@ def add_emitter_subs( return dynamics +def get_uniformity_classes( + emitter: WaveEmitter, +) -> list[set]: + """Return symbol sets ordered from most-uniform to least-uniform. + + The returned list is suitable for passing to + :func:`decompose_affine_by_uniformity` or + :func:`gen_sympy_index_decomposed`. + + Currently returns up to two classes: + * ``{WORKGROUP_0, WORKGROUP_1, WORKGROUP_2}`` + * induction-variable symbols (if the emitter is inside a loop) + + Thread-ID symbols and dynamic values are the implicit *remainder* + (most divergent) and need not be listed. + """ + wg = {WORKGROUP_0, WORKGROUP_1, WORKGROUP_2} + classes: list[set] = [wg] + iv_syms = set(emitter.get_induction_vars_and_syms()[1]) + if iv_syms: + classes.append(iv_syms) + return classes + + +def gen_sympy_index_decomposed( + emitter: WaveEmitter, + expr: "sympy.Expr", + dynamic_values: dict = {}, + uniform_sym_classes: list[set] | None = None, +) -> tuple[Value, list[Value]]: + """Lower a sympy expression to MLIR with automatic uniformity decomposition. + + Decomposes *expr* into additive components by uniformity class, emits + each component via :func:`gen_sympy_index`, and combines them so that + uniform (SGPR-eligible) contributions are separate ``arith.addi`` ops. + This lets the AMDGPU backend keep uniform parts in SGPRs and + potentially fold them into hardware instruction fields (e.g. soffset). + + Args: + emitter: The current wave emitter (provides symbol bindings). + expr: Sympy expression to lower. + dynamic_values: Extra symbol-to-Value mappings. + uniform_sym_classes: Override for uniformity classes. When + ``None``, uses :func:`get_uniformity_classes`. + + Returns: + ``(combined, components)`` where *combined* is the final MLIR + Value (sum of all components) and *components* is a list of + per-class Values (one per class + remainder). + """ + import sympy as _sympy + + subs = add_emitter_subs(emitter, dynamic_values) + classes = uniform_sym_classes or get_uniformity_classes(emitter) + parts = decompose_affine_by_uniformity(expr, classes) + + zero = _sympy.sympify(0) + component_values = [gen_sympy_index(subs, p) for p in parts] + + # Combine: sum uniform components first (SGPR + SGPR stays SGPR), + # then add the divergent remainder last (VGPR + SGPR). + overflow_flags = ( + arith_d.IntegerOverflowFlags.nsw | arith_d.IntegerOverflowFlags.nuw + ) + uniform_sum = None + for cv in component_values[:-1]: + if _is_zero(cv): + continue + if uniform_sum is None: + uniform_sum = cv + else: + uniform_sum = arith_d.addi(uniform_sum, cv, overflow_flags=overflow_flags) + + remainder = component_values[-1] + + if uniform_sum is None: + combined = remainder + elif _is_zero(remainder): + combined = uniform_sum + else: + combined = arith_d.addi(remainder, uniform_sum, overflow_flags=overflow_flags) + + return combined, component_values + + +def _is_zero(val: Value) -> bool: + """Return True if *val* is a constant-zero index.""" + if not hasattr(val, "owner") or not hasattr(val.owner, "opview"): + return False + op = val.owner.opview + if isinstance(op, arith_d.ConstantOp): + v = op.attributes["value"] + if isinstance(v, IntegerAttr) and int(v) == 0: + return True + return False + + _emulate_ceildiv = bool(int(environ.get("WAVE_EMULATE_CEILDIV", 0))) _use_affine_expr = bool(int(environ.get("WAVE_USE_AFFINE_EXPR", 1))) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index a6fd84eb54..2c9cb0b331 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -62,7 +62,7 @@ ) from ...wave.utils.general_utils import get_fastest_index, infer_dim, linearize_index from ...wave.utils.mapping_utils import transform_index_on_mapping -from ...wave.utils.symbol_utils import safe_subs +from ...wave.utils.symbol_utils import decompose_affine_by_uniformity, safe_subs from .emitter import ( WaveEmitter, add_emitter_subs, @@ -71,8 +71,10 @@ cast_py_value, cast_vector, gen_sympy_index, + gen_sympy_index_decomposed, get_constant_attr, get_type_or_element_type, + get_uniformity_classes, handle_op, ) @@ -105,26 +107,17 @@ def _simplify(expr): return sympy.simplify(expr) -def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]: - """ - Split index expr into thread-dependent and thread-independent parts - """ - subs_wg = {WORKGROUP_0: 0, WORKGROUP_1: 0, WORKGROUP_2: 0} - # Replace all wg symbols with 0s to get thread-dependent index. - # All dynamic values will also be part of thread-index. - thread_dependent_index = safe_subs(src, subs_wg) +_WG_SYMS = {WORKGROUP_0, WORKGROUP_1, WORKGROUP_2} - # Compute thread-independent index as `orig_index - thread_dependent_index` - # All thread symbols and dynamic should cancel-out in the result. - thread_independent_index = _simplify(src - thread_dependent_index) - if thread_independent_index.free_symbols - set(subs_wg.keys()): - # If we have any symbols besides wg symbols, means some thread or - # dynamic symbols were not canceled out, use the entire index as - # thread dependent index. - thread_independent_index = sympy.sympify(0) - thread_dependent_index = src - return thread_independent_index, thread_dependent_index +def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]: + """Split index expr into workgroup and thread-dependent parts. + + Thin wrapper around :func:`decompose_affine_by_uniformity` with a + single symbol class (workgroup symbols). + """ + parts = decompose_affine_by_uniformity(src, [_WG_SYMS]) + return parts[0], parts[1] def _extract0(src): @@ -151,15 +144,82 @@ def _build_start_indices( emitter: WaveEmitter, src_indices: dict[IndexExpr, IndexSequence | IndexExpr], dynamic_values: dict[IndexExpr, Any] = {}, -) -> tuple[list[OpResult], list[OpResult], list[OpResult]]: + uniform_sym_classes: list[set] | None = None, +) -> tuple: + """Build MLIR index values with N-way uniformity decomposition. + + When *uniform_sym_classes* is ``None`` (default), performs the legacy + two-way split (workgroup / thread) and returns a 3-tuple:: + + (full_indices, wg_indices, thread_indices) + + When *uniform_sym_classes* is a list of symbol sets (ordered + most-uniform first, e.g. ``[induction_var_syms]``), prepends the + workgroup class automatically and returns an ``(n + 2)``-tuple:: + + (full_indices, wg_indices, class_0_indices, ..., thread_indices) + """ start_indices = _get_start_indices(src_indices) - split_indices = [_split_index(i) for i in start_indices] subs = add_emitter_subs(emitter, dynamic_values) indices = [gen_sympy_index(subs, i) for i in start_indices] - indices_wg = [gen_sympy_index(subs, i[0]) for i in split_indices] - indices_th = [gen_sympy_index(subs, i[1]) for i in split_indices] - return indices, indices_wg, indices_th + classes = [_WG_SYMS] + (uniform_sym_classes or []) + decomposed = [ + decompose_affine_by_uniformity(i, classes) for i in start_indices + ] + + n_parts = len(classes) + 1 # one per class + remainder + parts: list[list[OpResult]] = [] + for k in range(n_parts): + parts.append([gen_sympy_index(subs, d[k]) for d in decomposed]) + + if not uniform_sym_classes: + return indices, parts[0], parts[1] + + return (indices,) + tuple(parts) + + +def _compute_linear_offset( + indices: list[Value | int], + strides: list[Value], +) -> Value | None: + """Linearize per-dimension index values with strides into a scalar offset. + + Returns *None* when ``indices`` is empty. + """ + overflow_flags = arith_d.IntegerOverflowFlags.nsw + offset = None + for idx, stride in zip(indices, strides): + if isinstance(idx, int): + idx = arith_d.constant(IndexType.get(), idx) + off = arith_d.muli(idx, stride, overflow_flags=overflow_flags) + if offset is None: + offset = off + else: + offset = arith_d.addi(offset, off, overflow_flags=overflow_flags) + return offset + + +def _apply_uniform_offsets( + offset_th: Value, + uniform_parts: list[list[Value]], + strides: list[Value], +) -> Value: + """Add uniform (SGPR-eligible) contributions to the thread offset. + + Each entry in *uniform_parts* is a list of per-dimension index values + for one uniformity class. They are linearized with *strides* and + added to *offset_th* as separate ``arith.addi`` ops so the backend + can keep them in SGPRs (e.g. fold into ``buffer_load`` soffset). + """ + overflow_flags = arith_d.IntegerOverflowFlags.nsw + for unif_indices in uniform_parts: + unif_offset = _compute_linear_offset(unif_indices, strides) + if unif_offset is not None and _get_constant_value(unif_offset) != 0: + offset_th = arith_d.addi( + offset_th, unif_offset, overflow_flags=overflow_flags + ) + return offset_th def _get_symbolic_shape(node: fx.Node) -> tuple[IndexExpr]: @@ -469,6 +529,7 @@ def _create_vec_read_write( memory: CustomOp, mask: Optional[Value], node_index: Optional[IndexSequence] = None, + uniform_parts: list[list[Value]] | None = None, ) -> Optional[Value]: is_read = value is None uint32 = IntegerType.get_signless(32) @@ -512,6 +573,8 @@ def extract(vec, ind): mem, start_indices_wg, start_indices_th, strides ) mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter) + if uniform_parts: + offset_th = _apply_uniform_offsets(offset_th, uniform_parts, strides) if linearize_shared_mem: mem = _linearize_shared_mem(mem) linearized_index = { @@ -548,6 +611,8 @@ def extract(vec, ind): mem, start_indices_wg, start_indices_th, strides ) mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter) + if uniform_parts: + offset_th = _apply_uniform_offsets(offset_th, uniform_parts, strides) indices = [offset_th] if buffer_ops_enabled else start_indices @@ -711,9 +776,21 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): else: mask = _build_mask(emitter, index, elements_per_thread, bounds) - start_indices, start_indices_wg, start_indices_th = _build_start_indices( - emitter, index, dynamic_vals_map_start - ) + induction_vars = set(emitter.get_induction_vars_and_syms()[1]) + uniform_parts: list[list[Value]] = [] + if induction_vars: + start_indices, start_indices_wg, *uniform_parts, start_indices_th = ( + _build_start_indices( + emitter, + index, + dynamic_vals_map_start, + uniform_sym_classes=[induction_vars], + ) + ) + else: + start_indices, start_indices_wg, start_indices_th = _build_start_indices( + emitter, index, dynamic_vals_map_start + ) use_llvm_load = flags != MemoryAccessFlags.NONE if use_llvm_load: @@ -738,6 +815,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): get_custom(memory), mask, node_index=index, + uniform_parts=uniform_parts or None, ) emitter.bind_node_proxy(node, IRProxyValue(result)) @@ -802,9 +880,21 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): else: mask = _build_mask(emitter, index, elements_per_thread, bounds) - start_indices, start_indices_wg, start_indices_th = _build_start_indices( - emitter, index, dynamic_vals_map_start - ) + induction_vars = set(emitter.get_induction_vars_and_syms()[1]) + uniform_parts: list[list[Value]] = [] + if induction_vars: + start_indices, start_indices_wg, *uniform_parts, start_indices_th = ( + _build_start_indices( + emitter, + index, + dynamic_vals_map_start, + uniform_sym_classes=[induction_vars], + ) + ) + else: + start_indices, start_indices_wg, start_indices_th = _build_start_indices( + emitter, index, dynamic_vals_map_start + ) use_llvm_store = flags != MemoryAccessFlags.NONE if use_llvm_store: @@ -825,6 +915,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): get_custom(memory), mask, node_index=index, + uniform_parts=uniform_parts or None, ) @@ -1070,13 +1161,24 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): store_type = VectorType.get((elements_per_thread,), element_type) - src_index, src_index_wg, src_index_th = _build_start_indices( - emitter, new_src_idx, src_dynamic_vals_map_start - ) + induction_vars = set(emitter.get_induction_vars_and_syms()[1]) - ip = InsertionPoint.current + uniform_parts: list[list[Value]] = [] + if induction_vars: + src_index, src_index_wg, *uniform_parts, src_index_th = ( + _build_start_indices( + emitter, + new_src_idx, + src_dynamic_vals_map_start, + uniform_sym_classes=[induction_vars], + ) + ) + else: + src_index, src_index_wg, src_index_th = _build_start_indices( + emitter, new_src_idx, src_dynamic_vals_map_start + ) - induction_vars = set(emitter.get_induction_vars_and_syms()[1]) + ip = InsertionPoint.current # Hoist to the function level, if not using induction variables. if not any( @@ -1105,6 +1207,9 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides) src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter) + if uniform_parts: + offset_th = _apply_uniform_offsets(offset_th, uniform_parts, strides) + # We previously checked mask is same for all elements, so we can use # elements_per_thread=1 to build the mask. mask = _build_mask( diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 66b7f6c5a6..4fd73ef108 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -21,6 +21,80 @@ ) +#################################################################### +# Affine-expression uniformity decomposition +#################################################################### + + +def decompose_affine_by_uniformity( + expr: IndexExpr | int, + symbol_classes: list[set], +) -> list[IndexExpr]: + """Decompose a polynomial expression into additive components by symbol uniformity. + + Given symbol classes ``[C0, C1, ..., Cn]`` ordered from most-uniform + (e.g. workgroup IDs) to least-uniform (e.g. induction variables), + produces ``n + 1`` components ``[part_0, ..., part_n, remainder]`` + such that ``sum(parts) == expr``. + + * ``part_k`` ideally depends only on symbols from ``C_k``. + * ``remainder`` contains thread-dependent terms and constants. + * Cross-class terms (e.g. ``WG * IV``) are cascaded to the next + less-uniform class during validation. + + This generalises two-way (workgroup / thread) and three-way + (workgroup / induction-var / thread) index splits into a single + N-way routine. + + Args: + expr: Sympy expression (or plain int) to decompose. + symbol_classes: Ordered list of symbol sets, most-uniform first. + + Returns: + List of ``n + 1`` IndexExprs (one per class + remainder). + """ + zero = sympy.sympify(0) + n = len(symbol_classes) + + if isinstance(expr, (int, float)): + return [zero] * n + [sympy.sympify(expr)] + + if not isinstance(expr, sympy.Basic): + expr = sympy.sympify(expr) + + # Progressive remainders: r[k] = expr with classes 0..k-1 zeroed out. + zero_subs: dict = {} + remainders = [expr] + for cls in symbol_classes: + zero_subs.update({s: 0 for s in cls}) + remainders.append(safe_subs(expr, zero_subs)) + + # component[k] = r[k] - r[k+1]: the part attributable to class k. + components: list[IndexExpr] = [] + for k in range(n): + components.append(sympy.simplify(remainders[k] - remainders[k + 1])) + components.append(remainders[n]) # remainder + + # Cascade validation: component[k] must only contain symbols from + # classes 0..k. If it has symbols from "below", merge into the next + # less-uniform component. + allowed: set = set() + for k in range(n): + allowed = allowed | symbol_classes[k] + if components[k] == zero: + continue + actual = ( + components[k].free_symbols + if isinstance(components[k], sympy.Basic) + else set() + ) + if actual - allowed: + components[k + 1] = sympy.simplify(components[k + 1] + components[k]) + components[k] = zero + + return components + + #################################################################### # Interval-arithmetic simplification for floor/Mod expressions. ####################################################################