diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index 970d6b52b..cb51c4e98 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -75,7 +75,11 @@ from ...wave.compile_options import WaveCompileOptions from ...wave.constraints import Constraint, HardwareConstraint, TilingConstraint from ...wave.utils.general_utils import get_hardware_constraint -from ...wave.utils.symbol_utils import subs_idxc, is_literal +from ...wave.utils.symbol_utils import ( + decompose_affine_by_uniformity, + subs_idxc, + is_literal, +) logger = get_logger("wave.ops_location_check") @@ -598,6 +602,103 @@ def add_emitter_subs( return dynamics +def get_uniformity_classes( + emitter: WaveEmitter, +) -> list[set]: + """Return symbol sets ordered from most-uniform to least-uniform. + + The returned list is suitable for passing to + :func:`decompose_affine_by_uniformity` or + :func:`gen_sympy_index_decomposed`. + + Currently returns up to two classes: + * ``{WORKGROUP_0, WORKGROUP_1, WORKGROUP_2}`` + * induction-variable symbols (if the emitter is inside a loop) + + Thread-ID symbols and dynamic values are the implicit *remainder* + (most divergent) and need not be listed. + """ + wg = {WORKGROUP_0, WORKGROUP_1, WORKGROUP_2} + classes: list[set] = [wg] + iv_syms = set(emitter.get_induction_vars_and_syms()[1]) + if iv_syms: + classes.append(iv_syms) + return classes + + +def gen_sympy_index_decomposed( + emitter: WaveEmitter, + expr: "sympy.Expr", + dynamic_values: dict = {}, + uniform_sym_classes: list[set] | None = None, +) -> tuple[Value, list[Value]]: + """Lower a sympy expression to MLIR with automatic uniformity decomposition. + + Decomposes *expr* into additive components by uniformity class, emits + each component via :func:`gen_sympy_index`, and combines them so that + uniform (SGPR-eligible) contributions are separate ``arith.addi`` ops. + This lets the AMDGPU backend keep uniform parts in SGPRs and + potentially fold them into hardware instruction fields (e.g. soffset). + + Args: + emitter: The current wave emitter (provides symbol bindings). + expr: Sympy expression to lower. + dynamic_values: Extra symbol-to-Value mappings. + uniform_sym_classes: Override for uniformity classes. When + ``None``, uses :func:`get_uniformity_classes`. + + Returns: + ``(combined, components)`` where *combined* is the final MLIR + Value (sum of all components) and *components* is a list of + per-class Values (one per class + remainder). + """ + import sympy as _sympy + + subs = add_emitter_subs(emitter, dynamic_values) + classes = uniform_sym_classes or get_uniformity_classes(emitter) + parts = decompose_affine_by_uniformity(expr, classes) + + zero = _sympy.sympify(0) + component_values = [gen_sympy_index(subs, p) for p in parts] + + # Combine: sum uniform components first (SGPR + SGPR stays SGPR), + # then add the divergent remainder last (VGPR + SGPR). + overflow_flags = ( + arith_d.IntegerOverflowFlags.nsw | arith_d.IntegerOverflowFlags.nuw + ) + uniform_sum = None + for cv in component_values[:-1]: + if _is_zero(cv): + continue + if uniform_sum is None: + uniform_sum = cv + else: + uniform_sum = arith_d.addi(uniform_sum, cv, overflow_flags=overflow_flags) + + remainder = component_values[-1] + + if uniform_sum is None: + combined = remainder + elif _is_zero(remainder): + combined = uniform_sum + else: + combined = arith_d.addi(remainder, uniform_sum, overflow_flags=overflow_flags) + + return combined, component_values + + +def _is_zero(val: Value) -> bool: + """Return True if *val* is a constant-zero index.""" + if not hasattr(val, "owner") or not hasattr(val.owner, "opview"): + return False + op = val.owner.opview + if isinstance(op, arith_d.ConstantOp): + v = op.attributes["value"] + if isinstance(v, IntegerAttr) and int(v) == 0: + return True + return False + + _emulate_ceildiv = bool(int(environ.get("WAVE_EMULATE_CEILDIV", 0))) _use_affine_expr = bool(int(environ.get("WAVE_USE_AFFINE_EXPR", 1))) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index a6fd84eb5..2c9cb0b33 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -62,7 +62,7 @@ ) from ...wave.utils.general_utils import get_fastest_index, infer_dim, linearize_index from ...wave.utils.mapping_utils import transform_index_on_mapping -from ...wave.utils.symbol_utils import safe_subs +from ...wave.utils.symbol_utils import decompose_affine_by_uniformity, safe_subs from .emitter import ( WaveEmitter, add_emitter_subs, @@ -71,8 +71,10 @@ cast_py_value, cast_vector, gen_sympy_index, + gen_sympy_index_decomposed, get_constant_attr, get_type_or_element_type, + get_uniformity_classes, handle_op, ) @@ -105,26 +107,17 @@ def _simplify(expr): return sympy.simplify(expr) -def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]: - """ - Split index expr into thread-dependent and thread-independent parts - """ - subs_wg = {WORKGROUP_0: 0, WORKGROUP_1: 0, WORKGROUP_2: 0} - # Replace all wg symbols with 0s to get thread-dependent index. - # All dynamic values will also be part of thread-index. - thread_dependent_index = safe_subs(src, subs_wg) +_WG_SYMS = {WORKGROUP_0, WORKGROUP_1, WORKGROUP_2} - # Compute thread-independent index as `orig_index - thread_dependent_index` - # All thread symbols and dynamic should cancel-out in the result. - thread_independent_index = _simplify(src - thread_dependent_index) - if thread_independent_index.free_symbols - set(subs_wg.keys()): - # If we have any symbols besides wg symbols, means some thread or - # dynamic symbols were not canceled out, use the entire index as - # thread dependent index. - thread_independent_index = sympy.sympify(0) - thread_dependent_index = src - return thread_independent_index, thread_dependent_index +def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]: + """Split index expr into workgroup and thread-dependent parts. + + Thin wrapper around :func:`decompose_affine_by_uniformity` with a + single symbol class (workgroup symbols). + """ + parts = decompose_affine_by_uniformity(src, [_WG_SYMS]) + return parts[0], parts[1] def _extract0(src): @@ -151,15 +144,82 @@ def _build_start_indices( emitter: WaveEmitter, src_indices: dict[IndexExpr, IndexSequence | IndexExpr], dynamic_values: dict[IndexExpr, Any] = {}, -) -> tuple[list[OpResult], list[OpResult], list[OpResult]]: + uniform_sym_classes: list[set] | None = None, +) -> tuple: + """Build MLIR index values with N-way uniformity decomposition. + + When *uniform_sym_classes* is ``None`` (default), performs the legacy + two-way split (workgroup / thread) and returns a 3-tuple:: + + (full_indices, wg_indices, thread_indices) + + When *uniform_sym_classes* is a list of symbol sets (ordered + most-uniform first, e.g. ``[induction_var_syms]``), prepends the + workgroup class automatically and returns an ``(n + 2)``-tuple:: + + (full_indices, wg_indices, class_0_indices, ..., thread_indices) + """ start_indices = _get_start_indices(src_indices) - split_indices = [_split_index(i) for i in start_indices] subs = add_emitter_subs(emitter, dynamic_values) indices = [gen_sympy_index(subs, i) for i in start_indices] - indices_wg = [gen_sympy_index(subs, i[0]) for i in split_indices] - indices_th = [gen_sympy_index(subs, i[1]) for i in split_indices] - return indices, indices_wg, indices_th + classes = [_WG_SYMS] + (uniform_sym_classes or []) + decomposed = [ + decompose_affine_by_uniformity(i, classes) for i in start_indices + ] + + n_parts = len(classes) + 1 # one per class + remainder + parts: list[list[OpResult]] = [] + for k in range(n_parts): + parts.append([gen_sympy_index(subs, d[k]) for d in decomposed]) + + if not uniform_sym_classes: + return indices, parts[0], parts[1] + + return (indices,) + tuple(parts) + + +def _compute_linear_offset( + indices: list[Value | int], + strides: list[Value], +) -> Value | None: + """Linearize per-dimension index values with strides into a scalar offset. + + Returns *None* when ``indices`` is empty. + """ + overflow_flags = arith_d.IntegerOverflowFlags.nsw + offset = None + for idx, stride in zip(indices, strides): + if isinstance(idx, int): + idx = arith_d.constant(IndexType.get(), idx) + off = arith_d.muli(idx, stride, overflow_flags=overflow_flags) + if offset is None: + offset = off + else: + offset = arith_d.addi(offset, off, overflow_flags=overflow_flags) + return offset + + +def _apply_uniform_offsets( + offset_th: Value, + uniform_parts: list[list[Value]], + strides: list[Value], +) -> Value: + """Add uniform (SGPR-eligible) contributions to the thread offset. + + Each entry in *uniform_parts* is a list of per-dimension index values + for one uniformity class. They are linearized with *strides* and + added to *offset_th* as separate ``arith.addi`` ops so the backend + can keep them in SGPRs (e.g. fold into ``buffer_load`` soffset). + """ + overflow_flags = arith_d.IntegerOverflowFlags.nsw + for unif_indices in uniform_parts: + unif_offset = _compute_linear_offset(unif_indices, strides) + if unif_offset is not None and _get_constant_value(unif_offset) != 0: + offset_th = arith_d.addi( + offset_th, unif_offset, overflow_flags=overflow_flags + ) + return offset_th def _get_symbolic_shape(node: fx.Node) -> tuple[IndexExpr]: @@ -469,6 +529,7 @@ def _create_vec_read_write( memory: CustomOp, mask: Optional[Value], node_index: Optional[IndexSequence] = None, + uniform_parts: list[list[Value]] | None = None, ) -> Optional[Value]: is_read = value is None uint32 = IntegerType.get_signless(32) @@ -512,6 +573,8 @@ def extract(vec, ind): mem, start_indices_wg, start_indices_th, strides ) mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter) + if uniform_parts: + offset_th = _apply_uniform_offsets(offset_th, uniform_parts, strides) if linearize_shared_mem: mem = _linearize_shared_mem(mem) linearized_index = { @@ -548,6 +611,8 @@ def extract(vec, ind): mem, start_indices_wg, start_indices_th, strides ) mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter) + if uniform_parts: + offset_th = _apply_uniform_offsets(offset_th, uniform_parts, strides) indices = [offset_th] if buffer_ops_enabled else start_indices @@ -711,9 +776,21 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): else: mask = _build_mask(emitter, index, elements_per_thread, bounds) - start_indices, start_indices_wg, start_indices_th = _build_start_indices( - emitter, index, dynamic_vals_map_start - ) + induction_vars = set(emitter.get_induction_vars_and_syms()[1]) + uniform_parts: list[list[Value]] = [] + if induction_vars: + start_indices, start_indices_wg, *uniform_parts, start_indices_th = ( + _build_start_indices( + emitter, + index, + dynamic_vals_map_start, + uniform_sym_classes=[induction_vars], + ) + ) + else: + start_indices, start_indices_wg, start_indices_th = _build_start_indices( + emitter, index, dynamic_vals_map_start + ) use_llvm_load = flags != MemoryAccessFlags.NONE if use_llvm_load: @@ -738,6 +815,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): get_custom(memory), mask, node_index=index, + uniform_parts=uniform_parts or None, ) emitter.bind_node_proxy(node, IRProxyValue(result)) @@ -802,9 +880,21 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): else: mask = _build_mask(emitter, index, elements_per_thread, bounds) - start_indices, start_indices_wg, start_indices_th = _build_start_indices( - emitter, index, dynamic_vals_map_start - ) + induction_vars = set(emitter.get_induction_vars_and_syms()[1]) + uniform_parts: list[list[Value]] = [] + if induction_vars: + start_indices, start_indices_wg, *uniform_parts, start_indices_th = ( + _build_start_indices( + emitter, + index, + dynamic_vals_map_start, + uniform_sym_classes=[induction_vars], + ) + ) + else: + start_indices, start_indices_wg, start_indices_th = _build_start_indices( + emitter, index, dynamic_vals_map_start + ) use_llvm_store = flags != MemoryAccessFlags.NONE if use_llvm_store: @@ -825,6 +915,7 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): get_custom(memory), mask, node_index=index, + uniform_parts=uniform_parts or None, ) @@ -1070,13 +1161,24 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): store_type = VectorType.get((elements_per_thread,), element_type) - src_index, src_index_wg, src_index_th = _build_start_indices( - emitter, new_src_idx, src_dynamic_vals_map_start - ) + induction_vars = set(emitter.get_induction_vars_and_syms()[1]) - ip = InsertionPoint.current + uniform_parts: list[list[Value]] = [] + if induction_vars: + src_index, src_index_wg, *uniform_parts, src_index_th = ( + _build_start_indices( + emitter, + new_src_idx, + src_dynamic_vals_map_start, + uniform_sym_classes=[induction_vars], + ) + ) + else: + src_index, src_index_wg, src_index_th = _build_start_indices( + emitter, new_src_idx, src_dynamic_vals_map_start + ) - induction_vars = set(emitter.get_induction_vars_and_syms()[1]) + ip = InsertionPoint.current # Hoist to the function level, if not using induction variables. if not any( @@ -1105,6 +1207,9 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides) src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter) + if uniform_parts: + offset_th = _apply_uniform_offsets(offset_th, uniform_parts, strides) + # We previously checked mask is same for all elements, so we can use # elements_per_thread=1 to build the mask. mask = _build_mask( diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 66b7f6c5a..4fd73ef10 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -21,6 +21,80 @@ ) +#################################################################### +# Affine-expression uniformity decomposition +#################################################################### + + +def decompose_affine_by_uniformity( + expr: IndexExpr | int, + symbol_classes: list[set], +) -> list[IndexExpr]: + """Decompose a polynomial expression into additive components by symbol uniformity. + + Given symbol classes ``[C0, C1, ..., Cn]`` ordered from most-uniform + (e.g. workgroup IDs) to least-uniform (e.g. induction variables), + produces ``n + 1`` components ``[part_0, ..., part_n, remainder]`` + such that ``sum(parts) == expr``. + + * ``part_k`` ideally depends only on symbols from ``C_k``. + * ``remainder`` contains thread-dependent terms and constants. + * Cross-class terms (e.g. ``WG * IV``) are cascaded to the next + less-uniform class during validation. + + This generalises two-way (workgroup / thread) and three-way + (workgroup / induction-var / thread) index splits into a single + N-way routine. + + Args: + expr: Sympy expression (or plain int) to decompose. + symbol_classes: Ordered list of symbol sets, most-uniform first. + + Returns: + List of ``n + 1`` IndexExprs (one per class + remainder). + """ + zero = sympy.sympify(0) + n = len(symbol_classes) + + if isinstance(expr, (int, float)): + return [zero] * n + [sympy.sympify(expr)] + + if not isinstance(expr, sympy.Basic): + expr = sympy.sympify(expr) + + # Progressive remainders: r[k] = expr with classes 0..k-1 zeroed out. + zero_subs: dict = {} + remainders = [expr] + for cls in symbol_classes: + zero_subs.update({s: 0 for s in cls}) + remainders.append(safe_subs(expr, zero_subs)) + + # component[k] = r[k] - r[k+1]: the part attributable to class k. + components: list[IndexExpr] = [] + for k in range(n): + components.append(sympy.simplify(remainders[k] - remainders[k + 1])) + components.append(remainders[n]) # remainder + + # Cascade validation: component[k] must only contain symbols from + # classes 0..k. If it has symbols from "below", merge into the next + # less-uniform component. + allowed: set = set() + for k in range(n): + allowed = allowed | symbol_classes[k] + if components[k] == zero: + continue + actual = ( + components[k].free_symbols + if isinstance(components[k], sympy.Basic) + else set() + ) + if actual - allowed: + components[k + 1] = sympy.simplify(components[k + 1] + components[k]) + components[k] = zero + + return components + + #################################################################### # Interval-arithmetic simplification for floor/Mod expressions. ####################################################################