From 495b17c4f8971ef306bd4a3109440284e031b8b9 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Sun, 1 Mar 2026 19:25:25 -0600 Subject: [PATCH 01/20] working: no valu Signed-off-by: Sanket Pandit --- .../kernel/compiler/wave_codegen/emitter.py | 17 +- .../compiler/wave_codegen/read_write.py | 286 +++++++++++++++++- wave_lang/kernel/lang/global_symbols.py | 7 + .../wave/analysis/compute_iv_strides.py | 237 +++++++++++++++ wave_lang/kernel/wave/compile.py | 4 + wave_lang/kernel/wave/constraints.py | 20 +- .../waveasm/Transforms/TranslateFromMLIR.h | 24 +- waveasm/lib/Transforms/TranslateFromMLIR.cpp | 93 +++++- 8 files changed, 672 insertions(+), 16 deletions(-) create mode 100644 wave_lang/kernel/wave/analysis/compute_iv_strides.py diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index c97d311582..a3783488e2 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -137,6 +137,14 @@ def emit_program_invariants(self): ), ] + threads_per_wave = self.hardware_constraint.threads_per_wave + tpw = arith_d.constant(IndexType.get(), threads_per_wave) + self.wave_ids = [ + arith_d.divui(self.thread_ids[0], tpw), + self.thread_ids[1], + self.thread_ids[2], + ] + def emit_func(self) -> Operation: bindings = self.root_sig.sig.linear_bindings @@ -607,7 +615,11 @@ def add_emitter_subs( arith_d.constant(IndexType.get(), 0), # DEVICE_DIM_2 ] all_symbols = ( - emitter.thread_ids + emitter.workgroup_ids + device_zeros + induction_vars + emitter.thread_ids + + emitter.workgroup_ids + + device_zeros + + emitter.wave_ids + + induction_vars ) dynamics = dict( zip( @@ -621,6 +633,9 @@ def add_emitter_subs( DEVICE_DIM_0, DEVICE_DIM_1, DEVICE_DIM_2, + WAVE_ID_0, + WAVE_ID_1, + WAVE_ID_2, ] + induction_var_syms, all_symbols, diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 147b27704c..a8b51d7a6d 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -102,7 +102,14 @@ def _split_index( """ Split index expr into thread-dependent and thread-independent parts """ - subs_wg = {WORKGROUP_0: 0, WORKGROUP_1: 0, WORKGROUP_2: 0} + subs_wg = { + WORKGROUP_0: 0, + WORKGROUP_1: 0, + WORKGROUP_2: 0, + WAVE_ID_0: 0, + WAVE_ID_1: 0, + WAVE_ID_2: 0, + } # Replace all wg symbols with 0s to get thread-dependent index. # All dynamic values will also be part of thread-index. thread_dependent_index = safe_subs(src, subs_wg) @@ -647,6 +654,257 @@ def extract(vec, ind): return +_WAVEASM_UNIFORM = { + WORKGROUP_0: 0, + WORKGROUP_1: 0, + WORKGROUP_2: 0, + THREAD_1: 0, + THREAD_2: 0, + WAVE_ID_0: 0, + WAVE_ID_1: 0, + WAVE_ID_2: 0, +} +_WAVEASM_UNIFORM_KEYS = set(_WAVEASM_UNIFORM.keys()) + + +_linearize_cache: dict = {} + + +def _linearize_read_waveasm( + emitter: WaveEmitter, + mem: Value, + node_index: Optional[dict], + dynamic_values: dict[IndexExpr, Any], + symbolic_shape: tuple[IndexExpr, ...], +) -> tuple[Value, Value]: + """ + Linearize a global read for the WaveASM backend, treating THREAD_1/2 + as wave-uniform (SRD base) so the per-lane voffset depends only on + THREAD_0. Returns (linearized_mem, th_offset). 
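+
+    Illustrative split (hypothetical index expression; the tile size and
+    symbol names are examples, not taken from a specific kernel):
+
+        expr    = WORKGROUP_0*BLOCK_M + WAVE_ID_0*16 + THREAD_0
+        th_expr = THREAD_0                             # per-lane voffset
+        wg_expr = WORKGROUP_0*BLOCK_M + WAVE_ID_0*16   # folded into the
+                                                       # reinterpret_cast
+                                                       # offset (SRD base)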
+ """ + kb_type = MemRefType(mem.type) + phys_strides, _ = kb_type.get_strides_and_offset() + overflow_flags = arith_d.IntegerOverflowFlags.nsw + start_exprs = _get_start_indices(node_index) + subs = add_emitter_subs(emitter, dynamic_values) + + wg_offset = None + th_offset = None + for expr, ps in zip(start_exprs, phys_strides): + th_expr = safe_subs(expr, _WAVEASM_UNIFORM) + wg_expr = sympy.expand(expr - th_expr) + if wg_expr.free_symbols - _WAVEASM_UNIFORM_KEYS: + wg_expr = sympy.sympify(0) + th_expr = expr + + wg_val = gen_sympy_index(subs, wg_expr) + th_val = gen_sympy_index(subs, th_expr) + stride_val = arith_d.constant(IndexType.get(), ps) + + wg_term = arith_d.muli(wg_val, stride_val, overflow_flags=overflow_flags) + th_term = arith_d.muli(th_val, stride_val, overflow_flags=overflow_flags) + wg_offset = ( + wg_term + if wg_offset is None + else arith_d.addi(wg_offset, wg_term, overflow_flags=overflow_flags) + ) + th_offset = ( + th_term + if th_offset is None + else arith_d.addi(th_offset, th_term, overflow_flags=overflow_flags) + ) + + if not hasattr(emitter, "_linearize_cache"): + emitter._linearize_cache = {} + cache_key = mem + if cache_key in emitter._linearize_cache: + return emitter._linearize_cache[cache_key], th_offset + + max_buf = _get_max_buffer_size(kb_type.element_type) - 1 + dyn_val = ShapedType.get_dynamic_size() + result_type = MemRefType.get( + [max_buf], + kb_type.element_type, + layout=Attribute.parse("strided<[1], offset: ?>"), + memory_space=kb_type.memory_space, + ) + linearized_mem = memref_d.reinterpret_cast( + result_type, + mem, + offsets=[wg_offset], + sizes=[], + strides=[], + static_offsets=[dyn_val], + static_sizes=[max_buf], + static_strides=[1], + ) + emitter._linearize_cache[cache_key] = linearized_mem + return linearized_mem, th_offset + + +def _get_or_create_flat_memref( + emitter: WaveEmitter, + mem: Value, +) -> Value: + """Return a rank-1 view of *mem* with offset 0 (pure shape change). + + All reads from the same source buffer share one reinterpret_cast, + so the backend maps them all to a single SRD — no per-read SRD copies. + """ + if not hasattr(emitter, "_flat_memref_cache"): + emitter._flat_memref_cache = {} + key = id(mem) + if key in emitter._flat_memref_cache: + return emitter._flat_memref_cache[key] + + kb_type = MemRefType(mem.type) + max_buf = _get_max_buffer_size(kb_type.element_type) - 1 + result_type = MemRefType.get( + [max_buf], + kb_type.element_type, + layout=Attribute.parse("strided<[1], offset: 0>"), + memory_space=kb_type.memory_space, + ) + flat = memref_d.reinterpret_cast( + result_type, + mem, + offsets=[], + sizes=[], + strides=[], + static_offsets=[0], + static_sizes=[max_buf], + static_strides=[1], + ) + emitter._flat_memref_cache[key] = flat + return flat + + +def _emit_iv_split_read( + emitter: WaveEmitter, + node: fx.Node, + index: dict[IndexExpr, IndexSequence | IndexExpr], + kb_src: Value, + input_shape: tuple[IndexExpr, ...], + vector_type: VectorType, + dynamic_vals_map_start: dict[IndexExpr, Any], +) -> Optional[Value]: + """ + Emit a VALU-free global read inside a tiled loop. + + Follows the AITER methodology: + 1. ONE shared rank-1 memref per source buffer (no per-read SRD copies). + 2. Full linearized offset at IV=0 → voffset, hoisted before the loop. + 3. IV * k_stride added inside loop → BufferLoadStrengthReduction + promotes it to soffset, yielding zero in-loop VALU. 
+ + Uses a 3-point linearity check on the post-mapping codegen-time indices + so it works even when the pre-codegen pass couldn't tag the node. + """ + iv_vals, iv_syms = emitter.get_induction_vars_and_syms() + if not iv_syms: + return None + + kb_type = MemRefType(kb_src.type) + if kb_type.rank == 0: + return None + + ip = InsertionPoint.current + owner = ip.block.owner + if isinstance(owner, func_d.FuncOp): + return None + + # --- Determine k_stride_per_iv --- + if getattr(node, "iv_linear", False): + k_stride_per_iv = node.iv_k_stride + else: + phys_strides, _ = kb_type.get_strides_and_offset() + dyn_sentinel = ShapedType.get_dynamic_stride_or_offset() + if any(s == dyn_sentinel for s in phys_strides): + return None + + step_int = _get_constant_value(owner.operands[2]) + if step_int is None or step_int <= 0: + return None + + start_exprs = _get_start_indices(index) + if len(start_exprs) != len(phys_strides): + return None + + all_zero = { + THREAD_0: 0, + THREAD_1: 0, + THREAD_2: 0, + WORKGROUP_0: 0, + WORKGROUP_1: 0, + WORKGROUP_2: 0, + WAVE_ID_0: 0, + WAVE_ID_1: 0, + WAVE_ID_2: 0, + } + iv_sym = iv_syms[0] + try: + d1 = d2 = 0 + for expr, ps in zip(start_exprs, phys_strides): + v0 = int(safe_subs(expr, {**all_zero, iv_sym: 0})) + v1 = int(safe_subs(expr, {**all_zero, iv_sym: step_int})) + v2 = int(safe_subs(expr, {**all_zero, iv_sym: 2 * step_int})) + d1 += (v1 - v0) * ps + d2 += (v2 - v1) * ps + except (TypeError, ValueError, sympy.SympifyError): + return None + + if d1 != d2 or d1 == 0: + return None + k_stride_per_iv, rem = divmod(d1, step_int) + if rem != 0: + return None + + # --- Zero IV in index expressions --- + iv_zero_subs = {sym: 0 for sym in iv_syms} + index_no_iv = {} + for dim, seq in index.items(): + start = _get_start_index(seq) + new_start = safe_subs(start, iv_zero_subs) + if isinstance(seq, IndexSequence): + index_no_iv[dim] = IndexSequence(new_start, seq.size) + else: + index_no_iv[dim] = new_start + + # --- Hoist: compute full linearized voffset at IV=0, create shared flat memref --- + kb_type = MemRefType(kb_src.type) + phys_strides, _ = kb_type.get_strides_and_offset() + hoist_ip = InsertionPoint(owner) + subs_map = add_emitter_subs(emitter, dynamic_vals_map_start) + overflow_flags = arith_d.IntegerOverflowFlags.nsw + + with hoist_ip: + flat_mem = _get_or_create_flat_memref(emitter, kb_src) + + iv0_exprs = _get_start_indices(index_no_iv) + lin_offset = None + for expr, ps in zip(iv0_exprs, phys_strides): + val = gen_sympy_index(subs_map, expr) + stride_c = arith_d.constant(IndexType.get(), ps) + term = arith_d.muli(val, stride_c, overflow_flags=overflow_flags) + lin_offset = ( + term + if lin_offset is None + else arith_d.addi(lin_offset, term, overflow_flags=overflow_flags) + ) + + # --- In-loop: total = hoisted_voffset + IV * k_stride --- + iv_sym = iv_syms[0] + iv_mlir = subs_map.get(iv_sym) + if iv_mlir is None: + return None + + k_stride_val = gen_sympy_index(subs_map, sympy.sympify(k_stride_per_iv)) + iv_offset = arith_d.muli(iv_mlir, k_stride_val, overflow_flags=overflow_flags) + total_offset = arith_d.addi(lin_offset, iv_offset, overflow_flags=overflow_flags) + + return vector_d.load(vector_type, flat_mem, [total_offset]) + + def _build_mask_with_mapping( emitter: WaveEmitter, mapping: IndexMapping, @@ -736,11 +994,35 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): else: mask = _build_mask(emitter, index, elements_per_thread, bounds) + is_global = get_custom(memory).type.address_space != SHARED_ADDRESS_SPACE + use_llvm_load = flags != 
MemoryAccessFlags.NONE + + if ( + is_global + and mask is None + and not use_llvm_load + and emitter.options.use_wave_asm_backend + and not read_meets_hw_transpose_requirements( + get_custom(node), emitter.constraints, emitter.options.target + ) + ): + result = _emit_iv_split_read( + emitter, + node, + index, + kb_src, + input_shape, + vector_type, + dynamic_vals_map_start, + ) + if result is not None: + emitter.bind_node_proxy(node, IRProxyValue(result)) + return + start_indices, start_indices_wg, start_indices_th = _build_start_indices( emitter, index, dynamic_vals_map_start ) - use_llvm_load = flags != MemoryAccessFlags.NONE if use_llvm_load: result = _create_llvm_read_write( kb_src, kb_ir_type, start_indices, vector_type, flags diff --git a/wave_lang/kernel/lang/global_symbols.py b/wave_lang/kernel/lang/global_symbols.py index 0da42050a2..f0add29b08 100644 --- a/wave_lang/kernel/lang/global_symbols.py +++ b/wave_lang/kernel/lang/global_symbols.py @@ -41,6 +41,13 @@ def get_workgroup_symbol(i: int): THREAD_1 = index_symbol(THREAD_SYMBOL_NAMES[1]) THREAD_2 = index_symbol(THREAD_SYMBOL_NAMES[2]) +# Wave-uniform symbols: same value for all lanes in a wave, SGPR-eligible. +# WAVE_ID_N = floor(linearized_thread_id / threads_per_wave) projected onto +# workgroup dimension N. Expanded to actual MLIR values at codegen time. +WAVE_ID_0 = index_symbol("$WAVE0") +WAVE_ID_1 = index_symbol("$WAVE1") +WAVE_ID_2 = index_symbol("$WAVE2") + # Input selector symbol for selecting input from different tensors. INPUT_SELECTOR = index_symbol("$INPUT_SELECTOR") diff --git a/wave_lang/kernel/wave/analysis/compute_iv_strides.py b/wave_lang/kernel/wave/analysis/compute_iv_strides.py new file mode 100644 index 0000000000..a115567a48 --- /dev/null +++ b/wave_lang/kernel/wave/analysis/compute_iv_strides.py @@ -0,0 +1,237 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Pre-codegen pass that computes the linearized K-stride for global reads +inside tiled loops. Tags eligible read nodes with ``iv_k_stride`` (int) +and ``iv_linear = True`` so that the codegen emitter can produce +VALU-free addressing (voffset + soffset) without any numerical estimation +at emit time. + +The analysis operates on the *pre-mapping* index expressions (single-variable +sympy objects), which are tractable for linearity checks. The post-mapping +composition is intractable but not needed here. 
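+
+Illustrative example (hypothetical GEMM-style read; the shape and tile
+names are assumptions, not taken from a specific kernel): for a row-major
+A[M, K] read whose K coordinate advances by BLOCK_K per loop iteration,
+the pass would tag the node roughly as
+
+    node.iv_k_stride = BLOCK_K   # K is innermost with unit stride
+    node.iv_linear = True
+
+and the emitter folds IV * iv_k_stride into a single hoistable offset,
+leaving no per-iteration vector ALU in the address computation.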
+""" + +from __future__ import annotations + +from typing import Optional + +import sympy +import torch.fx as fx + +from ..._support.indexing import ( + IndexExpr, + IndexingContext, + IndexSequence, + IndexSymbol, + subs_idxc, +) +from ..._support.tracing import CapturedTrace +from ...lang.global_symbols import SHARED_ADDRESS_SPACE +from ...lang.wave_types import IndexMapping +from ...ops.wave_ops import Read, get_custom +from ..constraints import Constraint, TilingConstraint +from ..utils.general_utils import infer_dim +from ..utils.symbol_utils import safe_subs +from ...compiler.utils import strides_from_symbolic_shape + +import logging + +logger = logging.getLogger(__name__) + + +def _get_tiling_constraints( + constraints: list[Constraint], +) -> list[TilingConstraint]: + return [c for c in constraints if isinstance(c, TilingConstraint)] + + +def _linearity_check_3pt( + expr: IndexExpr, sym: IndexSymbol, step: int +) -> Optional[IndexExpr]: + """ + 3-point linearity check on *expr* w.r.t. *sym*. + Returns the constant stride per unit of *sym*, or None if non-linear. + All other free symbols are zeroed for the check. + """ + others = {s: 0 for s in expr.free_symbols if s != sym} + try: + v0 = safe_subs(expr, {**others, sym: 0}) + v1 = safe_subs(expr, {**others, sym: step}) + v2 = safe_subs(expr, {**others, sym: 2 * step}) + except (TypeError, ValueError, sympy.SympifyError): + return None + d1 = sympy.simplify(v1 - v0) + d2 = sympy.simplify(v2 - v1) + if sympy.simplify(d2 - d1) != 0: + return None + if step == 0: + return None + return sympy.simplify(d1 / step) + + +def _compute_stride_for_read( + node: fx.Node, + tiling_constraints: list[TilingConstraint], + idxc: IndexingContext, +) -> Optional[IndexExpr]: + """ + For a single Read node, compute the linearized K-stride per IV step. + Returns the integer stride, or None if the read is ineligible. + """ + custom = get_custom(node) + if not isinstance(custom, Read): + return None + + if not hasattr(node, "index") or node.index is None: + return None + + index = node.index # dict[IndexSymbol, IndexSequence] + mapping: Optional[IndexMapping] = custom.mapping + + # Identify which tiling constraint's IV appears in the index expressions. + iv_constraint = None + iv_sym = None + for tc in tiling_constraints: + if tc.induction_var is None: + continue + for dim_sym, seq in index.items(): + start = seq.start if isinstance(seq, IndexSequence) else seq + if isinstance(start, (int, float)): + continue + if tc.induction_var in start.free_symbols: + if iv_constraint is not None and iv_constraint is not tc: + # Multiple IVs — bail + return None + iv_constraint = tc + iv_sym = tc.induction_var + break + + if iv_constraint is None or iv_sym is None: + return None + + tiled_dim = iv_constraint.dim + # Step 1: Find the tiled dimension in the index dict and verify linearity. + # + # The index dict maps logical dimensions to IndexSequence objects. + # The tiled dimension's start should be linear in IV. 
+ tiled_dim_key = None + for dim_sym in index: + base = infer_dim(dim_sym) + if base == infer_dim(tiled_dim): + tiled_dim_key = dim_sym + break + + if tiled_dim_key is None: + return None + + tiled_seq = index[tiled_dim_key] + tiled_start = tiled_seq.start if isinstance(tiled_seq, IndexSequence) else tiled_seq + + stride_into_iter = _linearity_check_3pt(tiled_start, iv_sym, 1) + if stride_into_iter is None: + return None + + # Get the memory node's symbolic shape + memory_node = custom.memory + mem_custom = get_custom(memory_node) + mem_shape = mem_custom.type.symbolic_shape + + # Get physical strides from the memory shape + phys_strides = strides_from_symbolic_shape(idxc, mem_shape, allow_mixed_shapes=True) + if phys_strides is None: + return None + # Substitute concrete values + phys_strides_expr = [subs_idxc(s) for s in phys_strides] + + # Step 2: If mapping present, verify mapping linearity and compute per-dim deltas. + if mapping is not None: + # The output_mapping maps logical dims to iterators. + # Find which iterator carries the tiled axis. + k_iter = None + for dim_sym, iter_expr in mapping.output_mapping.items(): + if infer_dim(dim_sym) == infer_dim(tiled_dim): + k_iter = iter_expr + break + + if k_iter is None: + return None + + # The IV advances the iterator by stride_into_iter per step. + iv_stride_into_iter = stride_into_iter + + # 3-point check on each physical dimension's input_mapping expression. + all_iters = list(mapping.iters.keys()) + other_iters = [it for it in all_iters if it != k_iter] + + per_dim_delta = {} + for dim_sym, expr in mapping.input_mapping.items(): + zero_others = {it: 0 for it in other_iters} + d = _linearity_check_3pt( + safe_subs(expr, zero_others), k_iter, iv_stride_into_iter + ) + if d is None: + return None + per_dim_delta[dim_sym] = d + + # Compute linearized stride + k_stride = 0 + for dim_sym, ps in zip(mem_shape, phys_strides_expr): + base_dim = infer_dim(dim_sym) + delta = 0 + for pd_dim, pd_val in per_dim_delta.items(): + if infer_dim(pd_dim) == base_dim: + delta = pd_val + break + k_stride += delta * ps + + else: + # No mapping: logical dimensions ARE physical dimensions. + # The tiled dimension advances by stride_into_iter, others by 0. 
+ k_stride = 0 + for dim_sym, ps in zip(mem_shape, phys_strides_expr): + if infer_dim(dim_sym) == infer_dim(tiled_dim): + k_stride += stride_into_iter * ps + + k_stride = sympy.simplify(k_stride) + return k_stride if k_stride != 0 else None + + +def compute_iv_strides( + trace: CapturedTrace, + constraints: list[Constraint], +): + """ + Walk the graph and tag each eligible global-memory Read with: + - node.iv_k_stride (int) — linearized bytes-offset per IV step + - node.iv_linear (bool) — True + """ + tiling_constraints = _get_tiling_constraints(constraints) + if not tiling_constraints: + return + + idxc = IndexingContext.current() + + def tag_node(node: fx.Node) -> bool: + custom = get_custom(node) + if not isinstance(custom, Read): + return False + # Only global reads + if hasattr(custom, "memory_type") and hasattr( + custom.memory_type, "address_space" + ): + if custom.memory_type.address_space == SHARED_ADDRESS_SPACE: + return False + + stride = _compute_stride_for_read(node, tiling_constraints, idxc) + if stride is not None: + node.iv_k_stride = stride + node.iv_linear = True + logger.debug(f"Tagged read {node.name} with iv_k_stride={stride}") + return False + + trace.walk(tag_node) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 8d571b6edf..4a8d32b7a2 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -37,6 +37,7 @@ set_node_indices_water_checked, set_post_expansion_indices, ) +from .analysis.compute_iv_strides import compute_iv_strides from .analysis.partition_strided_operators import ( merge_contiguous_reads, partition_gather_like_ops, @@ -569,6 +570,9 @@ def build_graph_passes( options.minimize_shared_allocs, ), ] + # Run IV-stride analysis after scheduling so loop-carried IV symbols are + # present in read indices; earlier placement cannot see them reliably. + graph_passes += [partial(compute_iv_strides, trace, launchable.constraints)] graph_passes += [ partial( add_shared_memory_barriers, diff --git a/wave_lang/kernel/wave/constraints.py b/wave_lang/kernel/wave/constraints.py index d49538b8a8..37e86810f7 100644 --- a/wave_lang/kernel/wave/constraints.py +++ b/wave_lang/kernel/wave/constraints.py @@ -881,16 +881,22 @@ def set_wave_id_from_hardware_and_workgroup_constraint( The wave_id is the same as the thread_id, with the exception of wave_id[0] = thread_id[0] / threads_per_wave This is a convention that we adopt. + + Uses first-class WAVE_ID_N symbols so that index expressions stay + simple (no floor(THREAD/64)) and can be recognised as wave-uniform + by the read/write lowering. """ old_wave_id = self.wave_id assert self.dim == workgroup_constraint.dim, "Dimension mismatch" - self.wave_id = hardware_constraint.get_thread_id_from_workgroup_dim( - workgroup_constraint.workgroup_dim - ) - # Only handling the wg_dim_0 case because Wave assumes - # all threads in a wave are handled in wg_dim_0. - if workgroup_constraint.workgroup_dim == 0: - self.wave_id = floor(self.wave_id / hardware_constraint.threads_per_wave) + match workgroup_constraint.workgroup_dim: + case 0: + self.wave_id = WAVE_ID_0 + case 1: + self.wave_id = WAVE_ID_1 + case 2: + self.wave_id = WAVE_ID_2 + case _: + raise ValueError("Invalid workgroup dimension. 
Expected 0, 1 or 2.") assert ( old_wave_id is None or self.wave_id == old_wave_id ), f"Conflicting preset wave_id old: {old_wave_id} new: {self.wave_id}" diff --git a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h index ed47d37c79..f67770fc36 100644 --- a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h +++ b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h @@ -516,22 +516,38 @@ class TranslationContext { mlir::Value elementOffset; int64_t srcSrdBase; int64_t elementBytes; + mlir::Value computedSoffset; }; void setPendingSRDBaseAdjust(mlir::Value memref, mlir::Value elemOffset, int64_t srcSrdBase, int64_t elementBytes) { - pendingSRDBaseAdjustMap[memref] = {elemOffset, srcSrdBase, elementBytes}; + pendingSRDBaseAdjustMap[memref] = {elemOffset, srcSrdBase, elementBytes, + nullptr}; } - /// Get a pending SRD base adjustment (returns nullptr if none) - const PendingSRDBaseAdjust * - getPendingSRDBaseAdjust(mlir::Value memref) const { + /// Get a pending SRD base adjustment (mutable for caching) + PendingSRDBaseAdjust *getPendingSRDBaseAdjust(mlir::Value memref) { auto it = pendingSRDBaseAdjustMap.find(memref); if (it != pendingSRDBaseAdjustMap.end()) return &it->second; return nullptr; } + /// Reuse an already-computed adjusted SRD with equivalent adjustment params. + std::optional + findComputedAdjustedSRDIndex(mlir::Value elementOffset, int64_t srcSrdBase, + int64_t elementBytes) const { + for (const auto &entry : pendingSRDBaseAdjustMap) { + const auto &adj = entry.second; + if (adj.elementOffset != elementOffset || adj.srcSrdBase != srcSrdBase || + adj.elementBytes != elementBytes || !adj.computedSoffset) + continue; + if (auto psreg = llvm::dyn_cast(adj.computedSoffset.getType())) + return psreg.getIndex(); + } + return std::nullopt; + } + /// Clear a pending SRD base adjustment after it has been applied void clearPendingSRDBaseAdjust(mlir::Value memref) { pendingSRDBaseAdjustMap.erase(memref); diff --git a/waveasm/lib/Transforms/TranslateFromMLIR.cpp b/waveasm/lib/Transforms/TranslateFromMLIR.cpp index 37983f34d4..dda8b7fa98 100644 --- a/waveasm/lib/Transforms/TranslateFromMLIR.cpp +++ b/waveasm/lib/Transforms/TranslateFromMLIR.cpp @@ -437,6 +437,83 @@ Value lookupSRD(Value memref, TranslationContext &ctx, Location loc) { return PrecoloredSRegOp::create(builder, loc, sregType, 8, 4); } +/// Apply a pending SRD base adjustment from a linearized memref +/// reinterpret_cast. Allocates a fresh SRD slot, copies source base, +/// adds byte offset, and sets num_records + stride. +/// Caches the computed SRD in adj->computedSoffset for reuse. 
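+///
+/// Roughly the emitted scalar sequence (illustrative; N is the fresh SRD
+/// slot, src is adj->srcSrdBase, actual register numbers vary):
+///   s_mov_b64  s[N:N+1], s[src:src+1]      ; copy 64-bit base
+///   v_readfirstlane_b32 / s_mov_b32        ; elementOffset into an SGPR
+///   s_mul_i32 + s_mul_hi_u32               ; 64-bit byte offset
+///   s_add_u32 / s_addc_u32                 ; base += byte offset
+///   s_mov_b32  s[N+2], s[src+2]            ; copy num_records
+///   s_mov_b32  s[N+3], s[src+3]            ; copy stride word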
+static std::pair +applyPendingSRDBaseAdjust(TranslationContext::PendingSRDBaseAdjust *adj, + MemRefType memrefType, TranslationContext &ctx, + Location loc) { + (void)memrefType; + if (adj->computedSoffset) { + if (auto psreg = dyn_cast(adj->computedSoffset.getType())) { + int64_t idx = psreg.getIndex(); + auto srdType = ctx.createSRegType(4, 4); + auto srd = + PrecoloredSRegOp::create(ctx.getBuilder(), loc, srdType, idx, 4); + return {srd, idx}; + } + } + + if (auto reusedIdx = ctx.findComputedAdjustedSRDIndex( + adj->elementOffset, adj->srcSrdBase, adj->elementBytes)) { + auto &builder = ctx.getBuilder(); + auto srdType = ctx.createSRegType(4, 4); + Value srd = PrecoloredSRegOp::create(builder, loc, srdType, *reusedIdx, 4); + adj->computedSoffset = srd; + return {srd, *reusedIdx}; + } + + auto &builder = ctx.getBuilder(); + int64_t N = ctx.getNextSwizzleSRDIndex(); + auto *mlirCtx = builder.getContext(); + + std::string copyBase = "s_mov_b64 s[" + std::to_string(N) + ":" + + std::to_string(N + 1) + "], s[" + + std::to_string(adj->srcSrdBase) + ":" + + std::to_string(adj->srcSrdBase + 1) + "]"; + RawOp::create(builder, loc, copyBase); + + Value offsetVal = adj->elementOffset; + auto tmpType = PSRegType::get(mlirCtx, N + 3, 1); + if (isVGPRType(offsetVal.getType())) { + offsetVal = V_READFIRSTLANE_B32::create(builder, loc, tmpType, offsetVal); + } else { + offsetVal = S_MOV_B32::create(builder, loc, tmpType, offsetVal); + } + + auto elemSizeImm = ConstantOp::create( + builder, loc, ctx.createImmType(adj->elementBytes), adj->elementBytes); + auto hiType = PSRegType::get(mlirCtx, N + 2, 1); + auto loType = PSRegType::get(mlirCtx, N + 3, 1); + auto byteOffHi = + S_MUL_HI_U32::create(builder, loc, hiType, offsetVal, elemSizeImm); + auto byteOffLo = + S_MUL_I32::create(builder, loc, loType, offsetVal, elemSizeImm); + + auto sccType = ctx.createSRegType(); + auto base0Type = PSRegType::get(mlirCtx, N, 1); + auto base1Type = PSRegType::get(mlirCtx, N + 1, 1); + auto base0 = PrecoloredSRegOp::create(builder, loc, base0Type, N, 1); + auto base1 = PrecoloredSRegOp::create(builder, loc, base1Type, N + 1, 1); + S_ADD_U32::create(builder, loc, base0Type, sccType, base0, byteOffLo); + S_ADDC_U32::create(builder, loc, base1Type, sccType, base1, byteOffHi); + + int64_t srcBase = adj->srcSrdBase; + std::string copySize = "s_mov_b32 s" + std::to_string(N + 2) + ", s" + + std::to_string(srcBase + 2); + RawOp::create(builder, loc, copySize); + std::string copyStride = "s_mov_b32 s" + std::to_string(N + 3) + ", s" + + std::to_string(srcBase + 3); + RawOp::create(builder, loc, copyStride); + + auto srdType = ctx.createSRegType(4, 4); + Value srd = PrecoloredSRegOp::create(builder, loc, srdType, N, 4); + adj->computedSoffset = srd; + return {srd, N}; +} + SmallVector emitBufferLoads(Value srd, Value voffset, int64_t instOffset, int64_t numBytes, TranslationContext &ctx, Location loc) { @@ -948,7 +1025,13 @@ LogicalResult handleVectorLoad(Operation *op, TranslationContext &ctx) { // Global load - buffer_load_dwordx* with splitting for large vectors auto [voffset, instOffset] = computeVOffsetFromIndices(memrefType, loadOp.getIndices(), ctx, loc); - Value srd = lookupSRD(loadOp.getBase(), ctx, loc); + Value srd; + if (auto *adj = ctx.getPendingSRDBaseAdjust(loadOp.getBase())) { + auto [adjSrd, _] = applyPendingSRDBaseAdjust(adj, memrefType, ctx, loc); + srd = adjSrd; + } else { + srd = lookupSRD(loadOp.getBase(), ctx, loc); + } auto loadResults = emitBufferLoads(srd, voffset, instOffset, numBytes, ctx, loc); @@ 
-988,7 +1071,13 @@ LogicalResult handleVectorMaskedLoad(Operation *op, TranslationContext &ctx) { auto [voffset, instOffset] = computeVOffsetFromIndices( memrefType, maskedLoadOp.getIndices(), ctx, loc); - Value srd = lookupSRD(maskedLoadOp.getBase(), ctx, loc); + Value srd; + if (auto *adj = ctx.getPendingSRDBaseAdjust(maskedLoadOp.getBase())) { + auto [adjSrd, _] = applyPendingSRDBaseAdjust(adj, memrefType, ctx, loc); + srd = adjSrd; + } else { + srd = lookupSRD(maskedLoadOp.getBase(), ctx, loc); + } auto loadResults = emitBufferLoads(srd, voffset, instOffset, numBytes, ctx, loc); From 03ca30b732ede2cc34ad7252b1f3614893abf18d Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Sun, 1 Mar 2026 19:54:04 -0600 Subject: [PATCH 02/20] remove dead code Signed-off-by: Sanket Pandit --- .../compiler/wave_codegen/read_write.py | 175 ++++--------- .../wave/analysis/compute_iv_strides.py | 237 ------------------ wave_lang/kernel/wave/compile.py | 4 - .../waveasm/Transforms/TranslateFromMLIR.h | 24 +- waveasm/lib/Transforms/TranslateFromMLIR.cpp | 93 +------ 5 files changed, 48 insertions(+), 485 deletions(-) delete mode 100644 wave_lang/kernel/wave/analysis/compute_iv_strides.py diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index a8b51d7a6d..79c848cbe5 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -654,94 +654,6 @@ def extract(vec, ind): return -_WAVEASM_UNIFORM = { - WORKGROUP_0: 0, - WORKGROUP_1: 0, - WORKGROUP_2: 0, - THREAD_1: 0, - THREAD_2: 0, - WAVE_ID_0: 0, - WAVE_ID_1: 0, - WAVE_ID_2: 0, -} -_WAVEASM_UNIFORM_KEYS = set(_WAVEASM_UNIFORM.keys()) - - -_linearize_cache: dict = {} - - -def _linearize_read_waveasm( - emitter: WaveEmitter, - mem: Value, - node_index: Optional[dict], - dynamic_values: dict[IndexExpr, Any], - symbolic_shape: tuple[IndexExpr, ...], -) -> tuple[Value, Value]: - """ - Linearize a global read for the WaveASM backend, treating THREAD_1/2 - as wave-uniform (SRD base) so the per-lane voffset depends only on - THREAD_0. Returns (linearized_mem, th_offset). 
- """ - kb_type = MemRefType(mem.type) - phys_strides, _ = kb_type.get_strides_and_offset() - overflow_flags = arith_d.IntegerOverflowFlags.nsw - start_exprs = _get_start_indices(node_index) - subs = add_emitter_subs(emitter, dynamic_values) - - wg_offset = None - th_offset = None - for expr, ps in zip(start_exprs, phys_strides): - th_expr = safe_subs(expr, _WAVEASM_UNIFORM) - wg_expr = sympy.expand(expr - th_expr) - if wg_expr.free_symbols - _WAVEASM_UNIFORM_KEYS: - wg_expr = sympy.sympify(0) - th_expr = expr - - wg_val = gen_sympy_index(subs, wg_expr) - th_val = gen_sympy_index(subs, th_expr) - stride_val = arith_d.constant(IndexType.get(), ps) - - wg_term = arith_d.muli(wg_val, stride_val, overflow_flags=overflow_flags) - th_term = arith_d.muli(th_val, stride_val, overflow_flags=overflow_flags) - wg_offset = ( - wg_term - if wg_offset is None - else arith_d.addi(wg_offset, wg_term, overflow_flags=overflow_flags) - ) - th_offset = ( - th_term - if th_offset is None - else arith_d.addi(th_offset, th_term, overflow_flags=overflow_flags) - ) - - if not hasattr(emitter, "_linearize_cache"): - emitter._linearize_cache = {} - cache_key = mem - if cache_key in emitter._linearize_cache: - return emitter._linearize_cache[cache_key], th_offset - - max_buf = _get_max_buffer_size(kb_type.element_type) - 1 - dyn_val = ShapedType.get_dynamic_size() - result_type = MemRefType.get( - [max_buf], - kb_type.element_type, - layout=Attribute.parse("strided<[1], offset: ?>"), - memory_space=kb_type.memory_space, - ) - linearized_mem = memref_d.reinterpret_cast( - result_type, - mem, - offsets=[wg_offset], - sizes=[], - strides=[], - static_offsets=[dyn_val], - static_sizes=[max_buf], - static_strides=[1], - ) - emitter._linearize_cache[cache_key] = linearized_mem - return linearized_mem, th_offset - - def _get_or_create_flat_memref( emitter: WaveEmitter, mem: Value, @@ -813,51 +725,48 @@ def _emit_iv_split_read( if isinstance(owner, func_d.FuncOp): return None - # --- Determine k_stride_per_iv --- - if getattr(node, "iv_linear", False): - k_stride_per_iv = node.iv_k_stride - else: - phys_strides, _ = kb_type.get_strides_and_offset() - dyn_sentinel = ShapedType.get_dynamic_stride_or_offset() - if any(s == dyn_sentinel for s in phys_strides): - return None - - step_int = _get_constant_value(owner.operands[2]) - if step_int is None or step_int <= 0: - return None - - start_exprs = _get_start_indices(index) - if len(start_exprs) != len(phys_strides): - return None - - all_zero = { - THREAD_0: 0, - THREAD_1: 0, - THREAD_2: 0, - WORKGROUP_0: 0, - WORKGROUP_1: 0, - WORKGROUP_2: 0, - WAVE_ID_0: 0, - WAVE_ID_1: 0, - WAVE_ID_2: 0, - } - iv_sym = iv_syms[0] - try: - d1 = d2 = 0 - for expr, ps in zip(start_exprs, phys_strides): - v0 = int(safe_subs(expr, {**all_zero, iv_sym: 0})) - v1 = int(safe_subs(expr, {**all_zero, iv_sym: step_int})) - v2 = int(safe_subs(expr, {**all_zero, iv_sym: 2 * step_int})) - d1 += (v1 - v0) * ps - d2 += (v2 - v1) * ps - except (TypeError, ValueError, sympy.SympifyError): - return None - - if d1 != d2 or d1 == 0: - return None - k_stride_per_iv, rem = divmod(d1, step_int) - if rem != 0: - return None + # --- Determine k_stride_per_iv via 3-point linearity check --- + phys_strides, _ = kb_type.get_strides_and_offset() + dyn_sentinel = ShapedType.get_dynamic_stride_or_offset() + if any(s == dyn_sentinel for s in phys_strides): + return None + + step_int = _get_constant_value(owner.operands[2]) + if step_int is None or step_int <= 0: + return None + + start_exprs = _get_start_indices(index) + if 
len(start_exprs) != len(phys_strides): + return None + + all_zero = { + THREAD_0: 0, + THREAD_1: 0, + THREAD_2: 0, + WORKGROUP_0: 0, + WORKGROUP_1: 0, + WORKGROUP_2: 0, + WAVE_ID_0: 0, + WAVE_ID_1: 0, + WAVE_ID_2: 0, + } + iv_sym = iv_syms[0] + try: + d1 = d2 = 0 + for expr, ps in zip(start_exprs, phys_strides): + v0 = int(safe_subs(expr, {**all_zero, iv_sym: 0})) + v1 = int(safe_subs(expr, {**all_zero, iv_sym: step_int})) + v2 = int(safe_subs(expr, {**all_zero, iv_sym: 2 * step_int})) + d1 += (v1 - v0) * ps + d2 += (v2 - v1) * ps + except (TypeError, ValueError, sympy.SympifyError): + return None + + if d1 != d2 or d1 == 0: + return None + k_stride_per_iv, rem = divmod(d1, step_int) + if rem != 0: + return None # --- Zero IV in index expressions --- iv_zero_subs = {sym: 0 for sym in iv_syms} diff --git a/wave_lang/kernel/wave/analysis/compute_iv_strides.py b/wave_lang/kernel/wave/analysis/compute_iv_strides.py deleted file mode 100644 index a115567a48..0000000000 --- a/wave_lang/kernel/wave/analysis/compute_iv_strides.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright 2025 The IREE Authors -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -""" -Pre-codegen pass that computes the linearized K-stride for global reads -inside tiled loops. Tags eligible read nodes with ``iv_k_stride`` (int) -and ``iv_linear = True`` so that the codegen emitter can produce -VALU-free addressing (voffset + soffset) without any numerical estimation -at emit time. - -The analysis operates on the *pre-mapping* index expressions (single-variable -sympy objects), which are tractable for linearity checks. The post-mapping -composition is intractable but not needed here. -""" - -from __future__ import annotations - -from typing import Optional - -import sympy -import torch.fx as fx - -from ..._support.indexing import ( - IndexExpr, - IndexingContext, - IndexSequence, - IndexSymbol, - subs_idxc, -) -from ..._support.tracing import CapturedTrace -from ...lang.global_symbols import SHARED_ADDRESS_SPACE -from ...lang.wave_types import IndexMapping -from ...ops.wave_ops import Read, get_custom -from ..constraints import Constraint, TilingConstraint -from ..utils.general_utils import infer_dim -from ..utils.symbol_utils import safe_subs -from ...compiler.utils import strides_from_symbolic_shape - -import logging - -logger = logging.getLogger(__name__) - - -def _get_tiling_constraints( - constraints: list[Constraint], -) -> list[TilingConstraint]: - return [c for c in constraints if isinstance(c, TilingConstraint)] - - -def _linearity_check_3pt( - expr: IndexExpr, sym: IndexSymbol, step: int -) -> Optional[IndexExpr]: - """ - 3-point linearity check on *expr* w.r.t. *sym*. - Returns the constant stride per unit of *sym*, or None if non-linear. - All other free symbols are zeroed for the check. 
- """ - others = {s: 0 for s in expr.free_symbols if s != sym} - try: - v0 = safe_subs(expr, {**others, sym: 0}) - v1 = safe_subs(expr, {**others, sym: step}) - v2 = safe_subs(expr, {**others, sym: 2 * step}) - except (TypeError, ValueError, sympy.SympifyError): - return None - d1 = sympy.simplify(v1 - v0) - d2 = sympy.simplify(v2 - v1) - if sympy.simplify(d2 - d1) != 0: - return None - if step == 0: - return None - return sympy.simplify(d1 / step) - - -def _compute_stride_for_read( - node: fx.Node, - tiling_constraints: list[TilingConstraint], - idxc: IndexingContext, -) -> Optional[IndexExpr]: - """ - For a single Read node, compute the linearized K-stride per IV step. - Returns the integer stride, or None if the read is ineligible. - """ - custom = get_custom(node) - if not isinstance(custom, Read): - return None - - if not hasattr(node, "index") or node.index is None: - return None - - index = node.index # dict[IndexSymbol, IndexSequence] - mapping: Optional[IndexMapping] = custom.mapping - - # Identify which tiling constraint's IV appears in the index expressions. - iv_constraint = None - iv_sym = None - for tc in tiling_constraints: - if tc.induction_var is None: - continue - for dim_sym, seq in index.items(): - start = seq.start if isinstance(seq, IndexSequence) else seq - if isinstance(start, (int, float)): - continue - if tc.induction_var in start.free_symbols: - if iv_constraint is not None and iv_constraint is not tc: - # Multiple IVs — bail - return None - iv_constraint = tc - iv_sym = tc.induction_var - break - - if iv_constraint is None or iv_sym is None: - return None - - tiled_dim = iv_constraint.dim - # Step 1: Find the tiled dimension in the index dict and verify linearity. - # - # The index dict maps logical dimensions to IndexSequence objects. - # The tiled dimension's start should be linear in IV. - tiled_dim_key = None - for dim_sym in index: - base = infer_dim(dim_sym) - if base == infer_dim(tiled_dim): - tiled_dim_key = dim_sym - break - - if tiled_dim_key is None: - return None - - tiled_seq = index[tiled_dim_key] - tiled_start = tiled_seq.start if isinstance(tiled_seq, IndexSequence) else tiled_seq - - stride_into_iter = _linearity_check_3pt(tiled_start, iv_sym, 1) - if stride_into_iter is None: - return None - - # Get the memory node's symbolic shape - memory_node = custom.memory - mem_custom = get_custom(memory_node) - mem_shape = mem_custom.type.symbolic_shape - - # Get physical strides from the memory shape - phys_strides = strides_from_symbolic_shape(idxc, mem_shape, allow_mixed_shapes=True) - if phys_strides is None: - return None - # Substitute concrete values - phys_strides_expr = [subs_idxc(s) for s in phys_strides] - - # Step 2: If mapping present, verify mapping linearity and compute per-dim deltas. - if mapping is not None: - # The output_mapping maps logical dims to iterators. - # Find which iterator carries the tiled axis. - k_iter = None - for dim_sym, iter_expr in mapping.output_mapping.items(): - if infer_dim(dim_sym) == infer_dim(tiled_dim): - k_iter = iter_expr - break - - if k_iter is None: - return None - - # The IV advances the iterator by stride_into_iter per step. - iv_stride_into_iter = stride_into_iter - - # 3-point check on each physical dimension's input_mapping expression. 
- all_iters = list(mapping.iters.keys()) - other_iters = [it for it in all_iters if it != k_iter] - - per_dim_delta = {} - for dim_sym, expr in mapping.input_mapping.items(): - zero_others = {it: 0 for it in other_iters} - d = _linearity_check_3pt( - safe_subs(expr, zero_others), k_iter, iv_stride_into_iter - ) - if d is None: - return None - per_dim_delta[dim_sym] = d - - # Compute linearized stride - k_stride = 0 - for dim_sym, ps in zip(mem_shape, phys_strides_expr): - base_dim = infer_dim(dim_sym) - delta = 0 - for pd_dim, pd_val in per_dim_delta.items(): - if infer_dim(pd_dim) == base_dim: - delta = pd_val - break - k_stride += delta * ps - - else: - # No mapping: logical dimensions ARE physical dimensions. - # The tiled dimension advances by stride_into_iter, others by 0. - k_stride = 0 - for dim_sym, ps in zip(mem_shape, phys_strides_expr): - if infer_dim(dim_sym) == infer_dim(tiled_dim): - k_stride += stride_into_iter * ps - - k_stride = sympy.simplify(k_stride) - return k_stride if k_stride != 0 else None - - -def compute_iv_strides( - trace: CapturedTrace, - constraints: list[Constraint], -): - """ - Walk the graph and tag each eligible global-memory Read with: - - node.iv_k_stride (int) — linearized bytes-offset per IV step - - node.iv_linear (bool) — True - """ - tiling_constraints = _get_tiling_constraints(constraints) - if not tiling_constraints: - return - - idxc = IndexingContext.current() - - def tag_node(node: fx.Node) -> bool: - custom = get_custom(node) - if not isinstance(custom, Read): - return False - # Only global reads - if hasattr(custom, "memory_type") and hasattr( - custom.memory_type, "address_space" - ): - if custom.memory_type.address_space == SHARED_ADDRESS_SPACE: - return False - - stride = _compute_stride_for_read(node, tiling_constraints, idxc) - if stride is not None: - node.iv_k_stride = stride - node.iv_linear = True - logger.debug(f"Tagged read {node.name} with iv_k_stride={stride}") - return False - - trace.walk(tag_node) diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index 4a8d32b7a2..8d571b6edf 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -37,7 +37,6 @@ set_node_indices_water_checked, set_post_expansion_indices, ) -from .analysis.compute_iv_strides import compute_iv_strides from .analysis.partition_strided_operators import ( merge_contiguous_reads, partition_gather_like_ops, @@ -570,9 +569,6 @@ def build_graph_passes( options.minimize_shared_allocs, ), ] - # Run IV-stride analysis after scheduling so loop-carried IV symbols are - # present in read indices; earlier placement cannot see them reliably. 
- graph_passes += [partial(compute_iv_strides, trace, launchable.constraints)] graph_passes += [ partial( add_shared_memory_barriers, diff --git a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h index f67770fc36..ed47d37c79 100644 --- a/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h +++ b/waveasm/include/waveasm/Transforms/TranslateFromMLIR.h @@ -516,38 +516,22 @@ class TranslationContext { mlir::Value elementOffset; int64_t srcSrdBase; int64_t elementBytes; - mlir::Value computedSoffset; }; void setPendingSRDBaseAdjust(mlir::Value memref, mlir::Value elemOffset, int64_t srcSrdBase, int64_t elementBytes) { - pendingSRDBaseAdjustMap[memref] = {elemOffset, srcSrdBase, elementBytes, - nullptr}; + pendingSRDBaseAdjustMap[memref] = {elemOffset, srcSrdBase, elementBytes}; } - /// Get a pending SRD base adjustment (mutable for caching) - PendingSRDBaseAdjust *getPendingSRDBaseAdjust(mlir::Value memref) { + /// Get a pending SRD base adjustment (returns nullptr if none) + const PendingSRDBaseAdjust * + getPendingSRDBaseAdjust(mlir::Value memref) const { auto it = pendingSRDBaseAdjustMap.find(memref); if (it != pendingSRDBaseAdjustMap.end()) return &it->second; return nullptr; } - /// Reuse an already-computed adjusted SRD with equivalent adjustment params. - std::optional - findComputedAdjustedSRDIndex(mlir::Value elementOffset, int64_t srcSrdBase, - int64_t elementBytes) const { - for (const auto &entry : pendingSRDBaseAdjustMap) { - const auto &adj = entry.second; - if (adj.elementOffset != elementOffset || adj.srcSrdBase != srcSrdBase || - adj.elementBytes != elementBytes || !adj.computedSoffset) - continue; - if (auto psreg = llvm::dyn_cast(adj.computedSoffset.getType())) - return psreg.getIndex(); - } - return std::nullopt; - } - /// Clear a pending SRD base adjustment after it has been applied void clearPendingSRDBaseAdjust(mlir::Value memref) { pendingSRDBaseAdjustMap.erase(memref); diff --git a/waveasm/lib/Transforms/TranslateFromMLIR.cpp b/waveasm/lib/Transforms/TranslateFromMLIR.cpp index dda8b7fa98..37983f34d4 100644 --- a/waveasm/lib/Transforms/TranslateFromMLIR.cpp +++ b/waveasm/lib/Transforms/TranslateFromMLIR.cpp @@ -437,83 +437,6 @@ Value lookupSRD(Value memref, TranslationContext &ctx, Location loc) { return PrecoloredSRegOp::create(builder, loc, sregType, 8, 4); } -/// Apply a pending SRD base adjustment from a linearized memref -/// reinterpret_cast. Allocates a fresh SRD slot, copies source base, -/// adds byte offset, and sets num_records + stride. -/// Caches the computed SRD in adj->computedSoffset for reuse. 
-static std::pair -applyPendingSRDBaseAdjust(TranslationContext::PendingSRDBaseAdjust *adj, - MemRefType memrefType, TranslationContext &ctx, - Location loc) { - (void)memrefType; - if (adj->computedSoffset) { - if (auto psreg = dyn_cast(adj->computedSoffset.getType())) { - int64_t idx = psreg.getIndex(); - auto srdType = ctx.createSRegType(4, 4); - auto srd = - PrecoloredSRegOp::create(ctx.getBuilder(), loc, srdType, idx, 4); - return {srd, idx}; - } - } - - if (auto reusedIdx = ctx.findComputedAdjustedSRDIndex( - adj->elementOffset, adj->srcSrdBase, adj->elementBytes)) { - auto &builder = ctx.getBuilder(); - auto srdType = ctx.createSRegType(4, 4); - Value srd = PrecoloredSRegOp::create(builder, loc, srdType, *reusedIdx, 4); - adj->computedSoffset = srd; - return {srd, *reusedIdx}; - } - - auto &builder = ctx.getBuilder(); - int64_t N = ctx.getNextSwizzleSRDIndex(); - auto *mlirCtx = builder.getContext(); - - std::string copyBase = "s_mov_b64 s[" + std::to_string(N) + ":" + - std::to_string(N + 1) + "], s[" + - std::to_string(adj->srcSrdBase) + ":" + - std::to_string(adj->srcSrdBase + 1) + "]"; - RawOp::create(builder, loc, copyBase); - - Value offsetVal = adj->elementOffset; - auto tmpType = PSRegType::get(mlirCtx, N + 3, 1); - if (isVGPRType(offsetVal.getType())) { - offsetVal = V_READFIRSTLANE_B32::create(builder, loc, tmpType, offsetVal); - } else { - offsetVal = S_MOV_B32::create(builder, loc, tmpType, offsetVal); - } - - auto elemSizeImm = ConstantOp::create( - builder, loc, ctx.createImmType(adj->elementBytes), adj->elementBytes); - auto hiType = PSRegType::get(mlirCtx, N + 2, 1); - auto loType = PSRegType::get(mlirCtx, N + 3, 1); - auto byteOffHi = - S_MUL_HI_U32::create(builder, loc, hiType, offsetVal, elemSizeImm); - auto byteOffLo = - S_MUL_I32::create(builder, loc, loType, offsetVal, elemSizeImm); - - auto sccType = ctx.createSRegType(); - auto base0Type = PSRegType::get(mlirCtx, N, 1); - auto base1Type = PSRegType::get(mlirCtx, N + 1, 1); - auto base0 = PrecoloredSRegOp::create(builder, loc, base0Type, N, 1); - auto base1 = PrecoloredSRegOp::create(builder, loc, base1Type, N + 1, 1); - S_ADD_U32::create(builder, loc, base0Type, sccType, base0, byteOffLo); - S_ADDC_U32::create(builder, loc, base1Type, sccType, base1, byteOffHi); - - int64_t srcBase = adj->srcSrdBase; - std::string copySize = "s_mov_b32 s" + std::to_string(N + 2) + ", s" + - std::to_string(srcBase + 2); - RawOp::create(builder, loc, copySize); - std::string copyStride = "s_mov_b32 s" + std::to_string(N + 3) + ", s" + - std::to_string(srcBase + 3); - RawOp::create(builder, loc, copyStride); - - auto srdType = ctx.createSRegType(4, 4); - Value srd = PrecoloredSRegOp::create(builder, loc, srdType, N, 4); - adj->computedSoffset = srd; - return {srd, N}; -} - SmallVector emitBufferLoads(Value srd, Value voffset, int64_t instOffset, int64_t numBytes, TranslationContext &ctx, Location loc) { @@ -1025,13 +948,7 @@ LogicalResult handleVectorLoad(Operation *op, TranslationContext &ctx) { // Global load - buffer_load_dwordx* with splitting for large vectors auto [voffset, instOffset] = computeVOffsetFromIndices(memrefType, loadOp.getIndices(), ctx, loc); - Value srd; - if (auto *adj = ctx.getPendingSRDBaseAdjust(loadOp.getBase())) { - auto [adjSrd, _] = applyPendingSRDBaseAdjust(adj, memrefType, ctx, loc); - srd = adjSrd; - } else { - srd = lookupSRD(loadOp.getBase(), ctx, loc); - } + Value srd = lookupSRD(loadOp.getBase(), ctx, loc); auto loadResults = emitBufferLoads(srd, voffset, instOffset, numBytes, ctx, loc); @@ 
-1071,13 +988,7 @@ LogicalResult handleVectorMaskedLoad(Operation *op, TranslationContext &ctx) { auto [voffset, instOffset] = computeVOffsetFromIndices( memrefType, maskedLoadOp.getIndices(), ctx, loc); - Value srd; - if (auto *adj = ctx.getPendingSRDBaseAdjust(maskedLoadOp.getBase())) { - auto [adjSrd, _] = applyPendingSRDBaseAdjust(adj, memrefType, ctx, loc); - srd = adjSrd; - } else { - srd = lookupSRD(maskedLoadOp.getBase(), ctx, loc); - } + Value srd = lookupSRD(maskedLoadOp.getBase(), ctx, loc); auto loadResults = emitBufferLoads(srd, voffset, instOffset, numBytes, ctx, loc); From 04ac332234d7fc042b4f4c537c4196cbdf8c6a55 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Sun, 1 Mar 2026 23:32:40 -0600 Subject: [PATCH 03/20] gather_to_lds index splitting Signed-off-by: Sanket Pandit --- .../compiler/wave_codegen/read_write.py | 178 ++++++++++++++++-- 1 file changed, 164 insertions(+), 14 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 79c848cbe5..18e4983084 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -1286,30 +1286,184 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): store_type = VectorType.get((elements_per_thread,), element_type) - src_index, src_index_wg, src_index_th = _build_start_indices( - emitter, new_src_idx, src_dynamic_vals_map_start - ) - ip = InsertionPoint.current induction_vars = set(emitter.get_induction_vars_and_syms()[1]) - # Hoist to the function level, if not using induction variables. + # Hoist dst indices to function level if they don't use induction variables. + dst_ip = ip if not any( induction_vars.intersection(set(index.start.free_symbols)) for index in dst_idx.values() ): - while not isinstance(ip.block.owner, func_d.FuncOp): - ip = InsertionPoint(ip.block.owner) + while not isinstance(dst_ip.block.owner, func_d.FuncOp): + dst_ip = InsertionPoint(dst_ip.block.owner) - with ip: + with dst_ip: dst_index, _, _ = _build_start_indices( emitter, dst_idx, dst_dynamic_vals_map_start ) - # We are indexing shared mem so i32 is enough. i32 = IntegerType.get_signless(32) dst_index = [assume_index_subgroup_uniform(idx, i32) for idx in dst_index] + # Try iv-split for the source (global) address — same AITER-style + # approach as _emit_iv_split_read: flat memref, hoisted voffset, + # scalar IV stride. + iv_vals, iv_syms = emitter.get_induction_vars_and_syms() + owner = ip.block.owner + use_iv_split = ( + iv_syms + and not isinstance(owner, func_d.FuncOp) + and emitter.options.use_wave_asm_backend + and MemRefType(src.type).rank > 0 + ) + + if use_iv_split: + # Use symbolic strides (from tensor shape) for the linearized offset, + # not the MLIR memref strides (which may be rank-1 after cast). 
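+        # For example (hypothetical row-major source of shape [M, K]), the
+        # symbolic strides are [K, 1]; the incoming memref may already be a
+        # flattened rank-1 view whose single stride [1] would linearize the
+        # 2-D index incorrectly.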
+ sym_stride_vals = strides_from_symbolic_shape( + IndexingContext.current(), src_symbolic_shape, allow_mixed_shapes=True + ) + sym_strides_int = [] + try: + sym_strides_int = [int(subs_idxc(s)) for s in sym_stride_vals] + except (TypeError, ValueError): + sym_strides_int = [] + + step_int = ( + _get_constant_value(owner.operands[2]) + if not isinstance(owner, func_d.FuncOp) + else None + ) + + if sym_strides_int and step_int is not None and step_int > 0: + start_exprs = _get_start_indices(new_src_idx) + if len(start_exprs) == len(sym_strides_int): + all_zero = { + THREAD_0: 0, + THREAD_1: 0, + THREAD_2: 0, + WORKGROUP_0: 0, + WORKGROUP_1: 0, + WORKGROUP_2: 0, + WAVE_ID_0: 0, + WAVE_ID_1: 0, + WAVE_ID_2: 0, + } + iv_sym = iv_syms[0] + try: + d1 = d2 = 0 + for expr, ps in zip(start_exprs, sym_strides_int): + e0 = subs_idxc(safe_subs(expr, {**all_zero, iv_sym: 0})) + e1 = subs_idxc(safe_subs(expr, {**all_zero, iv_sym: step_int})) + e2 = subs_idxc( + safe_subs(expr, {**all_zero, iv_sym: 2 * step_int}) + ) + v0 = int(e0) + v1 = int(e1) + v2 = int(e2) + d1 += (v1 - v0) * ps + d2 += (v2 - v1) * ps + lin_ok = d1 == d2 and d1 != 0 and d1 % step_int == 0 + except (TypeError, ValueError, sympy.SympifyError): + lin_ok = False + + if lin_ok: + k_stride_per_iv = d1 // step_int + iv_zero_subs = {sym: 0 for sym in iv_syms} + idx_no_iv = {} + for dim, seq in new_src_idx.items(): + start = _get_start_index(seq) + new_start = safe_subs(start, iv_zero_subs) + if isinstance(seq, IndexSequence): + idx_no_iv[dim] = IndexSequence(new_start, seq.size) + else: + idx_no_iv[dim] = new_start + + subs_map = add_emitter_subs(emitter, src_dynamic_vals_map_start) + overflow_flags = arith_d.IntegerOverflowFlags.nsw + hoist_ip = InsertionPoint(owner) + + with hoist_ip: + iv0_exprs = _get_start_indices(idx_no_iv) + lin_offset = None + for expr, ps in zip(iv0_exprs, sym_strides_int): + val = gen_sympy_index(subs_map, expr) + stride_c = arith_d.constant(IndexType.get(), ps) + term = arith_d.muli( + val, stride_c, overflow_flags=overflow_flags + ) + lin_offset = ( + term + if lin_offset is None + else arith_d.addi( + lin_offset, term, overflow_flags=overflow_flags + ) + ) + + iv_mlir = subs_map.get(iv_sym) + if iv_mlir is not None: + k_stride_val = arith_d.constant( + IndexType.get(), k_stride_per_iv + ) + iv_offset = arith_d.muli( + iv_mlir, k_stride_val, overflow_flags=overflow_flags + ) + total_offset = arith_d.addi( + lin_offset, iv_offset, overflow_flags=overflow_flags + ) + + mask = _build_mask( + emitter, + src_idx, + elements_per_thread=1, + bounds=src_bounds, + dynamic_values=src_dynamic_vals_map_start, + ) + if mask: + mask = vector_d.extract( + mask, static_position=[0], dynamic_position=[] + ) + oob_index_value = _get_out_of_bounds_index(element_type) + oob_index = arith_d.constant( + IndexType.get(), oob_index_value + ) + total_offset = arith_d.select(mask, total_offset, oob_index) + + # Linearize src to 1D for single-index addressing. + # Use _linearize_memref with zero wg/th offsets since + # the full offset is in total_offset already. 
+ lin_strides = [ + gen_sympy_index( + add_emitter_subs(emitter, src_dynamic_vals_map_start), s + ) + for s in sym_stride_vals + ] + zero_indices = [arith_d.constant(IndexType.get(), 0)] * len( + lin_strides + ) + lin_src, _ = _linearize_memref( + src, zero_indices, zero_indices, lin_strides + ) + lin_src = _cast_buffer_and_encode_stride( + lin_src, lin_strides, element_type, emitter + ) + + amdgpu_d.gather_to_lds( + src=lin_src, + src_indices=[total_offset], + dst=dst, + dst_indices=dst_index, + transfer_type=store_type, + ) + return + use_iv_split = False + + # Fallback: original linearization path + src_index, src_index_wg, src_index_th = _build_start_indices( + emitter, new_src_idx, src_dynamic_vals_map_start + ) + strides = strides_from_symbolic_shape( IndexingContext.current(), src_symbolic_shape, allow_mixed_shapes=True ) @@ -1321,8 +1475,6 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides) src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter) - # We previously checked mask is same for all elements, so we can use - # elements_per_thread=1 to build the mask. mask = _build_mask( emitter, src_idx, @@ -1336,11 +1488,9 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): oob_index = arith_d.constant(IndexType.get(), oob_index_value) offset_th = arith_d.select(mask, offset_th, oob_index) - src_index = [offset_th] - amdgpu_d.gather_to_lds( src=src, - src_indices=src_index, + src_indices=[offset_th], dst=dst, dst_indices=dst_index, transfer_type=store_type, From ea79100eb37c71026b69bd03ed205b06e010a298 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Sun, 1 Mar 2026 23:44:49 -0600 Subject: [PATCH 04/20] unify splitting logic Signed-off-by: Sanket Pandit --- .../compiler/wave_codegen/read_write.py | 321 ++++++------------ 1 file changed, 109 insertions(+), 212 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 18e4983084..a7791960e0 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -691,74 +691,69 @@ def _get_or_create_flat_memref( return flat -def _emit_iv_split_read( +_IV_SPLIT_ALL_ZERO = { + THREAD_0: 0, + THREAD_1: 0, + THREAD_2: 0, + WORKGROUP_0: 0, + WORKGROUP_1: 0, + WORKGROUP_2: 0, + WAVE_ID_0: 0, + WAVE_ID_1: 0, + WAVE_ID_2: 0, +} + + +def _try_iv_split_offset( emitter: WaveEmitter, - node: fx.Node, index: dict[IndexExpr, IndexSequence | IndexExpr], - kb_src: Value, - input_shape: tuple[IndexExpr, ...], - vector_type: VectorType, - dynamic_vals_map_start: dict[IndexExpr, Any], + strides: list[int], + dynamic_vals: dict[IndexExpr, Any], + use_subs_idxc: bool = False, ) -> Optional[Value]: - """ - Emit a VALU-free global read inside a tiled loop. + """Compute a hoisted IV-split linearized offset for a loop-carried read. - Follows the AITER methodology: - 1. ONE shared rank-1 memref per source buffer (no per-read SRD copies). - 2. Full linearized offset at IV=0 → voffset, hoisted before the loop. - 3. IV * k_stride added inside loop → BufferLoadStrengthReduction - promotes it to soffset, yielding zero in-loop VALU. + Returns the MLIR Value ``hoisted_voffset + IV * k_stride`` if the index + expressions are provably affine in the loop IV, or ``None`` to fall back + to the default address path. 
- Uses a 3-point linearity check on the post-mapping codegen-time indices - so it works even when the pre-codegen pass couldn't tag the node. + The caller is responsible for emitting the actual load/gather using the + returned offset. + + Parameters + ---------- + strides : per-dimension integer strides for linearisation. + use_subs_idxc : if True, apply ``subs_idxc`` before ``int()`` in the + 3-point check (needed when expressions contain residual shape symbols). """ iv_vals, iv_syms = emitter.get_induction_vars_and_syms() if not iv_syms: return None - kb_type = MemRefType(kb_src.type) - if kb_type.rank == 0: - return None - ip = InsertionPoint.current owner = ip.block.owner if isinstance(owner, func_d.FuncOp): return None - # --- Determine k_stride_per_iv via 3-point linearity check --- - phys_strides, _ = kb_type.get_strides_and_offset() - dyn_sentinel = ShapedType.get_dynamic_stride_or_offset() - if any(s == dyn_sentinel for s in phys_strides): - return None - step_int = _get_constant_value(owner.operands[2]) if step_int is None or step_int <= 0: return None start_exprs = _get_start_indices(index) - if len(start_exprs) != len(phys_strides): + if len(start_exprs) != len(strides): return None - all_zero = { - THREAD_0: 0, - THREAD_1: 0, - THREAD_2: 0, - WORKGROUP_0: 0, - WORKGROUP_1: 0, - WORKGROUP_2: 0, - WAVE_ID_0: 0, - WAVE_ID_1: 0, - WAVE_ID_2: 0, - } iv_sym = iv_syms[0] try: d1 = d2 = 0 - for expr, ps in zip(start_exprs, phys_strides): - v0 = int(safe_subs(expr, {**all_zero, iv_sym: 0})) - v1 = int(safe_subs(expr, {**all_zero, iv_sym: step_int})) - v2 = int(safe_subs(expr, {**all_zero, iv_sym: 2 * step_int})) - d1 += (v1 - v0) * ps - d2 += (v2 - v1) * ps + for expr, ps in zip(start_exprs, strides): + e0 = safe_subs(expr, {**_IV_SPLIT_ALL_ZERO, iv_sym: 0}) + e1 = safe_subs(expr, {**_IV_SPLIT_ALL_ZERO, iv_sym: step_int}) + e2 = safe_subs(expr, {**_IV_SPLIT_ALL_ZERO, iv_sym: 2 * step_int}) + if use_subs_idxc: + e0, e1, e2 = subs_idxc(e0), subs_idxc(e1), subs_idxc(e2) + d1 += (int(e1) - int(e0)) * ps + d2 += (int(e2) - int(e1)) * ps except (TypeError, ValueError, sympy.SympifyError): return None @@ -768,7 +763,6 @@ def _emit_iv_split_read( if rem != 0: return None - # --- Zero IV in index expressions --- iv_zero_subs = {sym: 0 for sym in iv_syms} index_no_iv = {} for dim, seq in index.items(): @@ -779,19 +773,14 @@ def _emit_iv_split_read( else: index_no_iv[dim] = new_start - # --- Hoist: compute full linearized voffset at IV=0, create shared flat memref --- - kb_type = MemRefType(kb_src.type) - phys_strides, _ = kb_type.get_strides_and_offset() hoist_ip = InsertionPoint(owner) - subs_map = add_emitter_subs(emitter, dynamic_vals_map_start) + subs_map = add_emitter_subs(emitter, dynamic_vals) overflow_flags = arith_d.IntegerOverflowFlags.nsw with hoist_ip: - flat_mem = _get_or_create_flat_memref(emitter, kb_src) - iv0_exprs = _get_start_indices(index_no_iv) lin_offset = None - for expr, ps in zip(iv0_exprs, phys_strides): + for expr, ps in zip(iv0_exprs, strides): val = gen_sympy_index(subs_map, expr) stride_c = arith_d.constant(IndexType.get(), ps) term = arith_d.muli(val, stride_c, overflow_flags=overflow_flags) @@ -801,17 +790,13 @@ def _emit_iv_split_read( else arith_d.addi(lin_offset, term, overflow_flags=overflow_flags) ) - # --- In-loop: total = hoisted_voffset + IV * k_stride --- - iv_sym = iv_syms[0] iv_mlir = subs_map.get(iv_sym) if iv_mlir is None: return None - k_stride_val = gen_sympy_index(subs_map, sympy.sympify(k_stride_per_iv)) + k_stride_val = 
arith_d.constant(IndexType.get(), k_stride_per_iv) iv_offset = arith_d.muli(iv_mlir, k_stride_val, overflow_flags=overflow_flags) - total_offset = arith_d.addi(lin_offset, iv_offset, overflow_flags=overflow_flags) - - return vector_d.load(vector_type, flat_mem, [total_offset]) + return arith_d.addi(lin_offset, iv_offset, overflow_flags=overflow_flags) def _build_mask_with_mapping( @@ -911,22 +896,30 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): and mask is None and not use_llvm_load and emitter.options.use_wave_asm_backend + and MemRefType(kb_src.type).rank > 0 and not read_meets_hw_transpose_requirements( get_custom(node), emitter.constraints, emitter.options.target ) ): - result = _emit_iv_split_read( - emitter, - node, - index, - kb_src, - input_shape, - vector_type, - dynamic_vals_map_start, - ) - if result is not None: - emitter.bind_node_proxy(node, IRProxyValue(result)) - return + kb_type = MemRefType(kb_src.type) + phys_strides, _ = kb_type.get_strides_and_offset() + dyn_sentinel = ShapedType.get_dynamic_stride_or_offset() + if not any(s == dyn_sentinel for s in phys_strides): + total_offset = _try_iv_split_offset( + emitter, + index, + list(phys_strides), + dynamic_vals_map_start, + ) + if total_offset is not None: + ip = InsertionPoint.current + owner = ip.block.owner + hoist_ip = InsertionPoint(owner) + with hoist_ip: + flat_mem = _get_or_create_flat_memref(emitter, kb_src) + result = vector_d.load(vector_type, flat_mem, [total_offset]) + emitter.bind_node_proxy(node, IRProxyValue(result)) + return start_indices, start_indices_wg, start_indices_th = _build_start_indices( emitter, index, dynamic_vals_map_start @@ -1306,158 +1299,62 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): i32 = IntegerType.get_signless(32) dst_index = [assume_index_subgroup_uniform(idx, i32) for idx in dst_index] - # Try iv-split for the source (global) address — same AITER-style - # approach as _emit_iv_split_read: flat memref, hoisted voffset, - # scalar IV stride. - iv_vals, iv_syms = emitter.get_induction_vars_and_syms() - owner = ip.block.owner - use_iv_split = ( - iv_syms - and not isinstance(owner, func_d.FuncOp) - and emitter.options.use_wave_asm_backend - and MemRefType(src.type).rank > 0 - ) - - if use_iv_split: - # Use symbolic strides (from tensor shape) for the linearized offset, - # not the MLIR memref strides (which may be rank-1 after cast). + # Try iv-split for the source (global) address. 
+ if emitter.options.use_wave_asm_backend and MemRefType(src.type).rank > 0: sym_stride_vals = strides_from_symbolic_shape( IndexingContext.current(), src_symbolic_shape, allow_mixed_shapes=True ) - sym_strides_int = [] try: sym_strides_int = [int(subs_idxc(s)) for s in sym_stride_vals] except (TypeError, ValueError): sym_strides_int = [] - step_int = ( - _get_constant_value(owner.operands[2]) - if not isinstance(owner, func_d.FuncOp) - else None - ) + if sym_strides_int: + total_offset = _try_iv_split_offset( + emitter, + new_src_idx, + sym_strides_int, + src_dynamic_vals_map_start, + use_subs_idxc=True, + ) + if total_offset is not None: + mask = _build_mask( + emitter, + src_idx, + elements_per_thread=1, + bounds=src_bounds, + dynamic_values=src_dynamic_vals_map_start, + ) + if mask: + mask = vector_d.extract( + mask, static_position=[0], dynamic_position=[] + ) + oob_index_value = _get_out_of_bounds_index(element_type) + oob_index = arith_d.constant(IndexType.get(), oob_index_value) + total_offset = arith_d.select(mask, total_offset, oob_index) + + lin_strides = [ + gen_sympy_index( + add_emitter_subs(emitter, src_dynamic_vals_map_start), s + ) + for s in sym_stride_vals + ] + zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(lin_strides) + lin_src, _ = _linearize_memref( + src, zero_indices, zero_indices, lin_strides + ) + lin_src = _cast_buffer_and_encode_stride( + lin_src, lin_strides, element_type, emitter + ) - if sym_strides_int and step_int is not None and step_int > 0: - start_exprs = _get_start_indices(new_src_idx) - if len(start_exprs) == len(sym_strides_int): - all_zero = { - THREAD_0: 0, - THREAD_1: 0, - THREAD_2: 0, - WORKGROUP_0: 0, - WORKGROUP_1: 0, - WORKGROUP_2: 0, - WAVE_ID_0: 0, - WAVE_ID_1: 0, - WAVE_ID_2: 0, - } - iv_sym = iv_syms[0] - try: - d1 = d2 = 0 - for expr, ps in zip(start_exprs, sym_strides_int): - e0 = subs_idxc(safe_subs(expr, {**all_zero, iv_sym: 0})) - e1 = subs_idxc(safe_subs(expr, {**all_zero, iv_sym: step_int})) - e2 = subs_idxc( - safe_subs(expr, {**all_zero, iv_sym: 2 * step_int}) - ) - v0 = int(e0) - v1 = int(e1) - v2 = int(e2) - d1 += (v1 - v0) * ps - d2 += (v2 - v1) * ps - lin_ok = d1 == d2 and d1 != 0 and d1 % step_int == 0 - except (TypeError, ValueError, sympy.SympifyError): - lin_ok = False - - if lin_ok: - k_stride_per_iv = d1 // step_int - iv_zero_subs = {sym: 0 for sym in iv_syms} - idx_no_iv = {} - for dim, seq in new_src_idx.items(): - start = _get_start_index(seq) - new_start = safe_subs(start, iv_zero_subs) - if isinstance(seq, IndexSequence): - idx_no_iv[dim] = IndexSequence(new_start, seq.size) - else: - idx_no_iv[dim] = new_start - - subs_map = add_emitter_subs(emitter, src_dynamic_vals_map_start) - overflow_flags = arith_d.IntegerOverflowFlags.nsw - hoist_ip = InsertionPoint(owner) - - with hoist_ip: - iv0_exprs = _get_start_indices(idx_no_iv) - lin_offset = None - for expr, ps in zip(iv0_exprs, sym_strides_int): - val = gen_sympy_index(subs_map, expr) - stride_c = arith_d.constant(IndexType.get(), ps) - term = arith_d.muli( - val, stride_c, overflow_flags=overflow_flags - ) - lin_offset = ( - term - if lin_offset is None - else arith_d.addi( - lin_offset, term, overflow_flags=overflow_flags - ) - ) - - iv_mlir = subs_map.get(iv_sym) - if iv_mlir is not None: - k_stride_val = arith_d.constant( - IndexType.get(), k_stride_per_iv - ) - iv_offset = arith_d.muli( - iv_mlir, k_stride_val, overflow_flags=overflow_flags - ) - total_offset = arith_d.addi( - lin_offset, iv_offset, overflow_flags=overflow_flags - ) - - mask = 
_build_mask( - emitter, - src_idx, - elements_per_thread=1, - bounds=src_bounds, - dynamic_values=src_dynamic_vals_map_start, - ) - if mask: - mask = vector_d.extract( - mask, static_position=[0], dynamic_position=[] - ) - oob_index_value = _get_out_of_bounds_index(element_type) - oob_index = arith_d.constant( - IndexType.get(), oob_index_value - ) - total_offset = arith_d.select(mask, total_offset, oob_index) - - # Linearize src to 1D for single-index addressing. - # Use _linearize_memref with zero wg/th offsets since - # the full offset is in total_offset already. - lin_strides = [ - gen_sympy_index( - add_emitter_subs(emitter, src_dynamic_vals_map_start), s - ) - for s in sym_stride_vals - ] - zero_indices = [arith_d.constant(IndexType.get(), 0)] * len( - lin_strides - ) - lin_src, _ = _linearize_memref( - src, zero_indices, zero_indices, lin_strides - ) - lin_src = _cast_buffer_and_encode_stride( - lin_src, lin_strides, element_type, emitter - ) - - amdgpu_d.gather_to_lds( - src=lin_src, - src_indices=[total_offset], - dst=dst, - dst_indices=dst_index, - transfer_type=store_type, - ) - return - use_iv_split = False + amdgpu_d.gather_to_lds( + src=lin_src, + src_indices=[total_offset], + dst=dst, + dst_indices=dst_index, + transfer_type=store_type, + ) + return # Fallback: original linearization path src_index, src_index_wg, src_index_th = _build_start_indices( From a147bec5ddbbbe020cb82f8d929822602ad67358 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Mon, 2 Mar 2026 16:06:20 -0600 Subject: [PATCH 05/20] do it purely symbolically Signed-off-by: Sanket Pandit --- .../compiler/wave_codegen/read_write.py | 52 +++++++++---------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index a7791960e0..b53156308f 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -691,19 +691,6 @@ def _get_or_create_flat_memref( return flat -_IV_SPLIT_ALL_ZERO = { - THREAD_0: 0, - THREAD_1: 0, - THREAD_2: 0, - WORKGROUP_0: 0, - WORKGROUP_1: 0, - WORKGROUP_2: 0, - WAVE_ID_0: 0, - WAVE_ID_1: 0, - WAVE_ID_2: 0, -} - - def _try_iv_split_offset( emitter: WaveEmitter, index: dict[IndexExpr, IndexSequence | IndexExpr], @@ -723,8 +710,8 @@ def _try_iv_split_offset( Parameters ---------- strides : per-dimension integer strides for linearisation. - use_subs_idxc : if True, apply ``subs_idxc`` before ``int()`` in the - 3-point check (needed when expressions contain residual shape symbols). + use_subs_idxc : if True, apply ``subs_idxc`` before simplification + (needed when expressions contain residual shape symbols). """ iv_vals, iv_syms = emitter.get_induction_vars_and_syms() if not iv_syms: @@ -744,21 +731,30 @@ def _try_iv_split_offset( return None iv_sym = iv_syms[0] - try: - d1 = d2 = 0 - for expr, ps in zip(start_exprs, strides): - e0 = safe_subs(expr, {**_IV_SPLIT_ALL_ZERO, iv_sym: 0}) - e1 = safe_subs(expr, {**_IV_SPLIT_ALL_ZERO, iv_sym: step_int}) - e2 = safe_subs(expr, {**_IV_SPLIT_ALL_ZERO, iv_sym: 2 * step_int}) - if use_subs_idxc: - e0, e1, e2 = subs_idxc(e0), subs_idxc(e1), subs_idxc(e2) - d1 += (int(e1) - int(e0)) * ps - d2 += (int(e2) - int(e1)) * ps - except (TypeError, ValueError, sympy.SympifyError): - return None - if d1 != d2 or d1 == 0: + # Symbolic linearity proof: substitute IV = step * j (j is a fresh + # integer symbol) keeping all other symbols (T0, WG, WAVE, ...) live. 
+ # Because step is a concrete power-of-2 that aligns with tile sizes, + # floor/Mod sub-expressions collapse and sympy.simplify reduces the + # linearized offset to c*j + f(T0, WG, ...). The per-step delta + # lin(j+1) - lin(j) then simplifies to a pure integer constant, + # proving the stride is independent of thread/wave/workgroup indices. + _j = sympy.Symbol("_j", integer=True, nonnegative=True) + iv_as_j = step_int * _j + lin_sym = sympy.Integer(0) + for expr, ps in zip(start_exprs, strides): + e = safe_subs(expr, {iv_sym: iv_as_j}) + if use_subs_idxc: + e = subs_idxc(e) + e = sympy.simplify(e) + lin_sym += e * ps + lin_sym = sympy.simplify(lin_sym) + lin_sym_next = sympy.simplify(lin_sym.subs(_j, _j + 1)) + delta = sympy.simplify(lin_sym_next - lin_sym) + + if not delta.is_Integer or delta == 0: return None + d1 = int(delta) k_stride_per_iv, rem = divmod(d1, step_int) if rem != 0: return None From c597b9ae561c895a77f7d766800e57d8be1f14a6 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Mon, 2 Mar 2026 16:17:23 -0600 Subject: [PATCH 06/20] just get the coefficient Signed-off-by: Sanket Pandit --- .../kernel/compiler/wave_codegen/read_write.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index b53156308f..d32cc79ce6 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -736,9 +736,9 @@ def _try_iv_split_offset( # integer symbol) keeping all other symbols (T0, WG, WAVE, ...) live. # Because step is a concrete power-of-2 that aligns with tile sizes, # floor/Mod sub-expressions collapse and sympy.simplify reduces the - # linearized offset to c*j + f(T0, WG, ...). The per-step delta - # lin(j+1) - lin(j) then simplifies to a pure integer constant, - # proving the stride is independent of thread/wave/workgroup indices. + # linearized offset to c*j + f(T0, WG, ...). Extracting the + # coefficient of j and verifying j doesn't appear in the remainder + # proves the stride is constant for all thread/wave/workgroup values. 
_j = sympy.Symbol("_j", integer=True, nonnegative=True) iv_as_j = step_int * _j lin_sym = sympy.Integer(0) @@ -749,13 +749,12 @@ def _try_iv_split_offset( e = sympy.simplify(e) lin_sym += e * ps lin_sym = sympy.simplify(lin_sym) - lin_sym_next = sympy.simplify(lin_sym.subs(_j, _j + 1)) - delta = sympy.simplify(lin_sym_next - lin_sym) - if not delta.is_Integer or delta == 0: + coeff = lin_sym.coeff(_j) + remainder = sympy.simplify(lin_sym - coeff * _j) + if not coeff.is_Integer or coeff == 0 or _j in remainder.free_symbols: return None - d1 = int(delta) - k_stride_per_iv, rem = divmod(d1, step_int) + k_stride_per_iv, rem = divmod(int(coeff), step_int) if rem != 0: return None From 84e800127dd7259bd72e0ed510fcb57be17ac681 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Mon, 2 Mar 2026 17:41:37 -0600 Subject: [PATCH 07/20] remove wave_id changes Signed-off-by: Sanket Pandit --- .../kernel/compiler/wave_codegen/emitter.py | 17 +--------------- .../compiler/wave_codegen/read_write.py | 9 +-------- wave_lang/kernel/lang/global_symbols.py | 7 ------- wave_lang/kernel/wave/constraints.py | 20 +++++++------------ 4 files changed, 9 insertions(+), 44 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index a3783488e2..c97d311582 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -137,14 +137,6 @@ def emit_program_invariants(self): ), ] - threads_per_wave = self.hardware_constraint.threads_per_wave - tpw = arith_d.constant(IndexType.get(), threads_per_wave) - self.wave_ids = [ - arith_d.divui(self.thread_ids[0], tpw), - self.thread_ids[1], - self.thread_ids[2], - ] - def emit_func(self) -> Operation: bindings = self.root_sig.sig.linear_bindings @@ -615,11 +607,7 @@ def add_emitter_subs( arith_d.constant(IndexType.get(), 0), # DEVICE_DIM_2 ] all_symbols = ( - emitter.thread_ids - + emitter.workgroup_ids - + device_zeros - + emitter.wave_ids - + induction_vars + emitter.thread_ids + emitter.workgroup_ids + device_zeros + induction_vars ) dynamics = dict( zip( @@ -633,9 +621,6 @@ def add_emitter_subs( DEVICE_DIM_0, DEVICE_DIM_1, DEVICE_DIM_2, - WAVE_ID_0, - WAVE_ID_1, - WAVE_ID_2, ] + induction_var_syms, all_symbols, diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index d32cc79ce6..839d3028f6 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -102,14 +102,7 @@ def _split_index( """ Split index expr into thread-dependent and thread-independent parts """ - subs_wg = { - WORKGROUP_0: 0, - WORKGROUP_1: 0, - WORKGROUP_2: 0, - WAVE_ID_0: 0, - WAVE_ID_1: 0, - WAVE_ID_2: 0, - } + subs_wg = {WORKGROUP_0: 0, WORKGROUP_1: 0, WORKGROUP_2: 0} # Replace all wg symbols with 0s to get thread-dependent index. # All dynamic values will also be part of thread-index. thread_dependent_index = safe_subs(src, subs_wg) diff --git a/wave_lang/kernel/lang/global_symbols.py b/wave_lang/kernel/lang/global_symbols.py index f0add29b08..0da42050a2 100644 --- a/wave_lang/kernel/lang/global_symbols.py +++ b/wave_lang/kernel/lang/global_symbols.py @@ -41,13 +41,6 @@ def get_workgroup_symbol(i: int): THREAD_1 = index_symbol(THREAD_SYMBOL_NAMES[1]) THREAD_2 = index_symbol(THREAD_SYMBOL_NAMES[2]) -# Wave-uniform symbols: same value for all lanes in a wave, SGPR-eligible. 
-# WAVE_ID_N = floor(linearized_thread_id / threads_per_wave) projected onto -# workgroup dimension N. Expanded to actual MLIR values at codegen time. -WAVE_ID_0 = index_symbol("$WAVE0") -WAVE_ID_1 = index_symbol("$WAVE1") -WAVE_ID_2 = index_symbol("$WAVE2") - # Input selector symbol for selecting input from different tensors. INPUT_SELECTOR = index_symbol("$INPUT_SELECTOR") diff --git a/wave_lang/kernel/wave/constraints.py b/wave_lang/kernel/wave/constraints.py index 37e86810f7..d49538b8a8 100644 --- a/wave_lang/kernel/wave/constraints.py +++ b/wave_lang/kernel/wave/constraints.py @@ -881,22 +881,16 @@ def set_wave_id_from_hardware_and_workgroup_constraint( The wave_id is the same as the thread_id, with the exception of wave_id[0] = thread_id[0] / threads_per_wave This is a convention that we adopt. - - Uses first-class WAVE_ID_N symbols so that index expressions stay - simple (no floor(THREAD/64)) and can be recognised as wave-uniform - by the read/write lowering. """ old_wave_id = self.wave_id assert self.dim == workgroup_constraint.dim, "Dimension mismatch" - match workgroup_constraint.workgroup_dim: - case 0: - self.wave_id = WAVE_ID_0 - case 1: - self.wave_id = WAVE_ID_1 - case 2: - self.wave_id = WAVE_ID_2 - case _: - raise ValueError("Invalid workgroup dimension. Expected 0, 1 or 2.") + self.wave_id = hardware_constraint.get_thread_id_from_workgroup_dim( + workgroup_constraint.workgroup_dim + ) + # Only handling the wg_dim_0 case because Wave assumes + # all threads in a wave are handled in wg_dim_0. + if workgroup_constraint.workgroup_dim == 0: + self.wave_id = floor(self.wave_id / hardware_constraint.threads_per_wave) assert ( old_wave_id is None or self.wave_id == old_wave_id ), f"Conflicting preset wave_id old: {old_wave_id} new: {self.wave_id}" From e49763239ecc5c37531173de0da8bcba8b078af2 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Mon, 2 Mar 2026 23:20:54 -0600 Subject: [PATCH 08/20] relax conditions Signed-off-by: Sanket Pandit --- .../compiler/wave_codegen/read_write.py | 99 +++++++++---------- 1 file changed, 47 insertions(+), 52 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 839d3028f6..6e3a69b639 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -714,6 +714,8 @@ def _try_iv_split_offset( owner = ip.block.owner if isinstance(owner, func_d.FuncOp): return None + if owner.name != "scf.for": + return None step_int = _get_constant_value(owner.operands[2]) if step_int is None or step_int <= 0: @@ -883,8 +885,6 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): is_global and mask is None and not use_llvm_load - and emitter.options.use_wave_asm_backend - and MemRefType(kb_src.type).rank > 0 and not read_meets_hw_transpose_requirements( get_custom(node), emitter.constraints, emitter.options.target ) @@ -1288,61 +1288,56 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): dst_index = [assume_index_subgroup_uniform(idx, i32) for idx in dst_index] # Try iv-split for the source (global) address. 
- if emitter.options.use_wave_asm_backend and MemRefType(src.type).rank > 0: - sym_stride_vals = strides_from_symbolic_shape( - IndexingContext.current(), src_symbolic_shape, allow_mixed_shapes=True - ) - try: - sym_strides_int = [int(subs_idxc(s)) for s in sym_stride_vals] - except (TypeError, ValueError): - sym_strides_int = [] + sym_stride_vals = strides_from_symbolic_shape( + IndexingContext.current(), src_symbolic_shape, allow_mixed_shapes=True + ) + try: + sym_strides_int = [int(subs_idxc(s)) for s in sym_stride_vals] + except (TypeError, ValueError): + sym_strides_int = [] - if sym_strides_int: - total_offset = _try_iv_split_offset( + if sym_strides_int: + total_offset = _try_iv_split_offset( + emitter, + new_src_idx, + sym_strides_int, + src_dynamic_vals_map_start, + use_subs_idxc=True, + ) + if total_offset is not None: + mask = _build_mask( emitter, - new_src_idx, - sym_strides_int, - src_dynamic_vals_map_start, - use_subs_idxc=True, + src_idx, + elements_per_thread=1, + bounds=src_bounds, + dynamic_values=src_dynamic_vals_map_start, ) - if total_offset is not None: - mask = _build_mask( - emitter, - src_idx, - elements_per_thread=1, - bounds=src_bounds, - dynamic_values=src_dynamic_vals_map_start, - ) - if mask: - mask = vector_d.extract( - mask, static_position=[0], dynamic_position=[] - ) - oob_index_value = _get_out_of_bounds_index(element_type) - oob_index = arith_d.constant(IndexType.get(), oob_index_value) - total_offset = arith_d.select(mask, total_offset, oob_index) - - lin_strides = [ - gen_sympy_index( - add_emitter_subs(emitter, src_dynamic_vals_map_start), s - ) - for s in sym_stride_vals - ] - zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(lin_strides) - lin_src, _ = _linearize_memref( - src, zero_indices, zero_indices, lin_strides - ) - lin_src = _cast_buffer_and_encode_stride( - lin_src, lin_strides, element_type, emitter + if mask: + mask = vector_d.extract(mask, static_position=[0], dynamic_position=[]) + oob_index_value = _get_out_of_bounds_index(element_type) + oob_index = arith_d.constant(IndexType.get(), oob_index_value) + total_offset = arith_d.select(mask, total_offset, oob_index) + + lin_strides = [ + gen_sympy_index( + add_emitter_subs(emitter, src_dynamic_vals_map_start), s ) + for s in sym_stride_vals + ] + zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(lin_strides) + lin_src, _ = _linearize_memref(src, zero_indices, zero_indices, lin_strides) + lin_src = _cast_buffer_and_encode_stride( + lin_src, lin_strides, element_type, emitter + ) - amdgpu_d.gather_to_lds( - src=lin_src, - src_indices=[total_offset], - dst=dst, - dst_indices=dst_index, - transfer_type=store_type, - ) - return + amdgpu_d.gather_to_lds( + src=lin_src, + src_indices=[total_offset], + dst=dst, + dst_indices=dst_index, + transfer_type=store_type, + ) + return # Fallback: original linearization path src_index, src_index_wg, src_index_th = _build_start_indices( From 082bf2875fe179119e517d939d76b2e9481de9db Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Mon, 2 Mar 2026 23:21:08 -0600 Subject: [PATCH 09/20] fix lit tests Signed-off-by: Sanket Pandit --- lit_tests/kernel/wave/asm.py | 87 ---------------------- lit_tests/kernel/wave/gather_to_shared.py | 10 +-- lit_tests/kernel/wave/gemm.py | 4 +- lit_tests/kernel/wave/merge_scale_reads.py | 10 +-- lit_tests/kernel/wave/scaled_gemm.py | 24 +++--- 5 files changed, 23 insertions(+), 112 deletions(-) diff --git a/lit_tests/kernel/wave/asm.py b/lit_tests/kernel/wave/asm.py index f2ff99598f..4605fcb8bf 100644 --- 
a/lit_tests/kernel/wave/asm.py +++ b/lit_tests/kernel/wave/asm.py @@ -483,93 +483,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: s_endpgm -@run_test -def test_gemm_gather_to_lds(): - """ - Test GEMM with gather_to_lds (global_to_shared) enabled. - - When use_global_to_shared=True, the compiler generates buffer_load_dword...lds - instructions that load directly from global memory to LDS, bypassing VGPRs. - - Verifies: - - buffer_load_dword ... lds instructions are emitted - - M0 register setup for LDS addressing - - Proper barrier synchronization (vmcnt + lgkmcnt + s_barrier) - """ - constraints: list[tkw.Constraint] = [ - tkw.WorkgroupConstraint(M, BLOCK_M, 0), - tkw.WorkgroupConstraint(N, BLOCK_N, 1), - tkw.TilingConstraint(K, BLOCK_K), - tkw.WaveConstraint(M, BLOCK_M), - tkw.WaveConstraint(N, BLOCK_N), - tkw.HardwareConstraint( - threads_per_wave=64, - mma_type=tkw.MMAType.F32_16x16x16_F16, - ), - ] - - @tkw.wave(constraints) - def gemm_g2s( - a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], - b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], - ): - c_reg = tkl.Register[M, N, tkl.f32](0.0) - - @tkw.iterate(K, init_args=[c_reg]) - def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: - a_reg = tkw.read(a) - b_reg = tkw.read(b) - acc = tkw.mma(a_reg, b_reg, acc) - return acc - - tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) - - compile_options = WaveCompileOptions( - subs={ - M: 32, - N: 32, - K: 32, - BLOCK_M: 16, - BLOCK_N: 16, - BLOCK_K: 16, - LOAD_ELEMS_PER_THREAD: 4, - STORE_ELEMS_PER_THREAD: 4, - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, - }, - canonicalize=True, - compile_to_mlir=True, - use_global_to_shared=True, # Enable gather_to_lds - ) - compile_options.compile_to_asm = True - gemm_g2s = wave_compile(compile_options, gemm_g2s) - print(gemm_g2s.asm) - - # CHECK-LABEL: test_gemm_gather_to_lds - # CHECK: .protected gemm_g2s - # CHECK: .amdhsa_kernel gemm_g2s - # CHECK: gemm_g2s: - - # Verify loop structure for K-loop - # CHECK: loop_0_header: - # CHECK: s_cmp_lt_u32 s{{[0-9]+}}, s{{[0-9]+}} - # CHECK: s_cbranch_scc1 loop_0_body - - # Verify loop body has MFMA instruction - # CHECK: loop_0_body: - # CHECK: v_mfma_f32_16x16x16_f16 - - # Verify loop latch - # CHECK: loop_0_latch: - # CHECK: s_branch loop_0_header - - # Verify loop exit and stores - # CHECK: loop_0_exit: - # CHECK: buffer_store_dword - # CHECK: s_endpgm - - @run_test def test_cse_intermediate_caching(): """ diff --git a/lit_tests/kernel/wave/gather_to_shared.py b/lit_tests/kernel/wave/gather_to_shared.py index b5c69447c6..d292607179 100644 --- a/lit_tests/kernel/wave/gather_to_shared.py +++ b/lit_tests/kernel/wave/gather_to_shared.py @@ -145,17 +145,17 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: print(gemm.asm) # CHECK-LABEL: test_gather_to_shared_wave_tile_aligned_coalescing - # CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 16 + s1 * 2 - (s1 floordiv 8) * 16)> + # CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 8) * 16)> # CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + (s0 floordiv 64) * 8 + (s0 mod 64) floordiv 8 - ((s1 * 16 + (s0 floordiv 64) * 8 + (s0 mod 64) floordiv 8) floordiv 32) * 32)> # CHECK: func.func @gemm - # CHECK: %[[BLOCK_ID_Y:.+]] = gpu.block_id y # CHECK: %[[TIDX:.+]] = gpu.thread_id x # CHECK: %[[TIDY:.+]] = gpu.thread_id y - # CHECK: 
%[[WAVE_ALIGNED_OFFSET:.+]] = affine.apply #[[MAP2]]()[%[[TIDX]], %[[TIDY]], %[[BLOCK_ID_Y]]] + # CHECK: affine.apply #[[MAP2]]()[%[[TIDX]], %[[TIDY]], %{{.*}}] + # CHECK: %[[TH_OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[TIDX]]] # CHECK: scf.for %[[IND_VAR:.+]] = %c0 # CHECK: amdgpu.lds_barrier - # CHECK: %[[UPDATE_OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[IND_VAR]], %[[TIDX]]] - # CHECK: %[[LHS:.+]] = arith.addi %{{.*}}, %[[UPDATE_OFFSET]] + # CHECK: %[[K_STRIDE:.+]] = arith.muli %[[IND_VAR]], %{{.*}} + # CHECK: %[[LHS:.+]] = arith.addi %{{.*}}, %[[K_STRIDE]] @run_test diff --git a/lit_tests/kernel/wave/gemm.py b/lit_tests/kernel/wave/gemm.py index 7115cc0917..827b7f0949 100644 --- a/lit_tests/kernel/wave/gemm.py +++ b/lit_tests/kernel/wave/gemm.py @@ -2296,11 +2296,11 @@ def repeat( # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index # CHECK-DAG: %[[WG_ID2:.*]] = gpu.block_id z - # CHECK: %[[LHS_GLOBAL:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [1, 64, 64], strides: [4096, 64, 1] : memref to memref<1x64x64xf16, strided<[4096, 64, 1]>> # CHECK: %[[RHS_GLOBAL:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [6, 128, 64], strides: [8192, 64, 1] : memref to memref<6x128x64xf16, strided<[8192, 64, 1]>> # CHECK: %[[HKV_IDX:.+]] = affine.apply #[[MAP]]()[%[[WG_ID2]]] + # CHECK: %[[LHS_GLOBAL:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{[0-9]+}}], strides: [1] : memref to memref<{{[0-9]+}}xf16, strided<[1]>> # CHECK: scf.for - # CHECK: %[[LHS_READ:.+]] = vector.load %[[LHS_GLOBAL]][%[[HKV_IDX]], %{{.+}}, {{.+}}] : {{.*}}, vector<8xf16> + # CHECK: %[[LHS_READ:.+]] = vector.load %[[LHS_GLOBAL]][%{{.+}}] : {{.*}}, vector<8xf16> # CHECK: %[[RHS_READ:.+]] = vector.load %[[RHS_GLOBAL]][%[[WG_ID2]], %{{.+}}, {{.+}}] : {{.*}}, vector<8xf16> # CHECK-COUNT-2: vector.extract_strided_slice # CHECK-COUNT-1: amdgpu.mfma diff --git a/lit_tests/kernel/wave/merge_scale_reads.py b/lit_tests/kernel/wave/merge_scale_reads.py index 5d56e5b733..3b41fdfc3c 100644 --- a/lit_tests/kernel/wave/merge_scale_reads.py +++ b/lit_tests/kernel/wave/merge_scale_reads.py @@ -180,13 +180,13 @@ def test_preshuffle_scale_merge_block_k_256(): # CHECK-LABEL: test_preshuffle_scale_merge_block_k_256 # Each scale tensor produces 2 merged vector<4xi8> loads from global. - # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8> - # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8> - # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8> - # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[1]>>, vector<4xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[1]>>, vector<4xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[1]>>, vector<4xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[1]>>, vector<4xi8> # No unmerged scalar scale loads from global should remain. - # CHECK-NOT: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<1xi8> + # CHECK-NOT: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[1]>>, vector<1xi8> # Check that amdgpu.scaled_mfma uses opsel (indexed access into scale values) # The key indicator is the [N] indexing syntax on f8E8M0FNU scale operands. 
Check %REG[1] as a simple check that we are doing a non-zero index diff --git a/lit_tests/kernel/wave/scaled_gemm.py b/lit_tests/kernel/wave/scaled_gemm.py index 7bf3b8b178..50fc8d84fd 100644 --- a/lit_tests/kernel/wave/scaled_gemm.py +++ b/lit_tests/kernel/wave/scaled_gemm.py @@ -196,20 +196,18 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-LABEL: test_scaled_gemm_mxfp8 # CHECK-DAG: #map = affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16 + (s0 floordiv 64) * 16)> - # CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)> - # CHECK-DAG: #map2 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16)> - # CHECK-DAG: #map3 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16 + 64)> - # CHECK-DAG: #map4 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> - # CHECK-DAG: #map5 = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)> + # CHECK-DAG: #map1 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16)> + # CHECK-DAG: #map2 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16 + 64)> + # CHECK-DAG: #map3 = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)> + # CHECK-DAG: #map4 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)> + # CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> # CHECK-DAG: #map6 = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)> - # CHECK-DAG: #map7 = affine_map<()[s0, s1] -> (s0 * 128 + ((s1 mod 64) floordiv 16) * 16)> - # CHECK-DAG: #map8 = affine_map<()[s0, s1] -> (s0 * 128 + ((s1 mod 64) floordiv 16) * 16 + 64)> - # CHECK-DAG: #map9 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)> - # CHECK-DAG: #map10 = affine_map<()[s0] -> (s0 * 32)> - # CHECK-DAG: #map11 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)> - # CHECK-DAG: #map12 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)> - # CHECK-DAG: #map13 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 2)> - # CHECK-DAG: #map14 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 3)> + # CHECK-DAG: #map7 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)> + # CHECK-DAG: #map8 = affine_map<()[s0] -> (s0 * 32)> + # CHECK-DAG: #map9 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)> + # CHECK-DAG: #map10 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)> + # CHECK-DAG: #map11 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 2)> + # CHECK-DAG: #map12 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 3)> # CHECK: func.func @scaled_gemm # CHECK-COUNT-1: memref.alloc() # CHECK: scf.for From 108b11fbbcee9f4d445f762a4b765138c290b13d Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Tue, 3 Mar 2026 10:38:00 -0600 Subject: [PATCH 10/20] fixing lit tests Signed-off-by: Sanket Pandit --- lit_tests/kernel/wave/asm.py | 88 -------------------------- lit_tests/kernel/wave/gemm.py | 14 ++-- lit_tests/kernel/wave/scaled_gemm.py | 73 ++++++++++----------- lit_tests/kernel/wave/scaled_mma.py | 8 +-- lit_tests/kernel/wave/wave_schedule.py | 6 +- 5 files changed, 47 insertions(+), 142 deletions(-) diff --git a/lit_tests/kernel/wave/asm.py b/lit_tests/kernel/wave/asm.py index 4605fcb8bf..6ebecc97d3 100644 --- a/lit_tests/kernel/wave/asm.py +++ b/lit_tests/kernel/wave/asm.py @@ -395,94 +395,6 @@ def mma_multi( # 
CHECK: s_endpgm -@run_test -def test_gemm_multi_wave_k_loop(): - """ - Test multi-wave GEMM with K-loop (BLOCK_K=64). - - Uses 4 waves per workgroup (BLOCK_M=32, BLOCK_N=32, WAVE_M=16, WAVE_N=16) - with BLOCK_K=64 to test loop generation with chained MFMA accumulators. - - Verifies: - - Loop induction variable initialization - - Multiple MFMA instructions with accumulator chaining - - Loop increment and branch back - """ - constraints: list[tkw.Constraint] = [ - tkw.WorkgroupConstraint(M, BLOCK_M, 0), - tkw.WorkgroupConstraint(N, BLOCK_N, 1), - tkw.TilingConstraint(K, BLOCK_K), - tkw.WaveConstraint(M, BLOCK_M // 2), # 2 waves in M dimension - tkw.WaveConstraint(N, BLOCK_N // 2), # 2 waves in N dimension - tkw.HardwareConstraint( - threads_per_wave=64, - mma_type=tkw.MMAType.F32_16x16x16_F16, - ), - ] - - @tkw.wave(constraints) - def gemm_multi_wave( - a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], - b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], - c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], - ): - c_reg = tkl.Register[M, N, tkl.f32](0.0) - - @tkw.iterate(K, init_args=[c_reg]) - def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: - a_reg = tkw.read(a) - b_reg = tkw.read(b) - acc = tkw.mma(a_reg, b_reg, acc) - return acc - - tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) - - compile_options = WaveCompileOptions( - subs={ - M: 64, - N: 64, - K: 128, - BLOCK_M: 32, - BLOCK_N: 32, - BLOCK_K: 64, - ADDRESS_SPACE: SHARED_ADDRESS_SPACE, - ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, - LOAD_ELEMS_PER_THREAD: 4, - STORE_ELEMS_PER_THREAD: 4, - }, - canonicalize=True, - compile_to_mlir=True, - ) - compile_options.compile_to_asm = True - gemm_multi_wave = wave_compile(compile_options, gemm_multi_wave) - print(gemm_multi_wave.asm) - - # CHECK-LABEL: test_gemm_multi_wave_k_loop - # CHECK: .protected gemm_multi_wave - # CHECK: .amdhsa_kernel gemm_multi_wave - # CHECK: .amdhsa_system_vgpr_workitem_id {{[0-9]+}} - # CHECK: gemm_multi_wave: - - # Verify loop structure - header with comparison and conditional branch - # CHECK: loop_0_header: - # CHECK: s_cmp_lt_u32 s{{[0-9]+}}, s{{[0-9]+}} - # CHECK: s_cbranch_scc1 loop_0_body - - # Verify loop body has MFMA instructions - # CHECK: loop_0_body: - # CHECK: v_mfma_f32_16x16x16_f16 - - # Verify loop latch - increment and branch back - # CHECK: loop_0_latch: - # CHECK: s_add_u32 s{{[0-9]+}}, s{{[0-9]+}}, s{{[0-9]+}} - # CHECK: s_branch loop_0_header - - # Verify loop exit and result stores - # CHECK: loop_0_exit: - # CHECK: buffer_store_dword - # CHECK: s_endpgm - - @run_test def test_cse_intermediate_caching(): """ diff --git a/lit_tests/kernel/wave/gemm.py b/lit_tests/kernel/wave/gemm.py index 827b7f0949..815b003fa1 100644 --- a/lit_tests/kernel/wave/gemm.py +++ b/lit_tests/kernel/wave/gemm.py @@ -111,7 +111,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 4)> # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)> # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)> - # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 * 16 + ((s1 mod 64) floordiv 16) * 4)> # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 * 32)> # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)> # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)> @@ -726,7 +725,6 @@ def 
repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # %[[IV_K:.+]] = affine.apply #[[MAP_IV_K]]()[%[[IV]], %[[TID_X]]] # CHECK-LABEL: test_packed_gemm - # CHECK-DAG: #[[MAP_IV_K:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + ((s1 mod 64) floordiv 16) * 2)> # CHECK: func.func @packed_gemm # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index @@ -736,11 +734,10 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: %[[RHS_SHARED:.+]] = memref.view %[[ALLOC]][%c0][] : memref<2560xi8, #gpu.address_space> to memref<32x10xi32, #gpu.address_space> # CHECK: %[[LHS_SHARED:.+]] = memref.view %[[ALLOC]][%c1280][] : memref<2560xi8, #gpu.address_space> to memref<32x10xi32, #gpu.address_space> # CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[C4]] step %[[C1]] - # CHECK: %[[IV_K:.+]] = affine.apply #[[MAP_IV_K]]()[%[[IV]], %[[TID_X]]] - # CHECK: %[[LHS_REG:.+]] = vector.load %{{.*}}[%{{.*}}, %[[IV_K]]] : memref<64x32xi32, strided<[32, 1]>>, vector<2xi32> + # CHECK: %[[LHS_REG:.+]] = vector.load {{.*}} : memref<{{.*}}>, vector<2xi32> # CHECK: amdgpu.lds_barrier # CHECK: vector.store %[[LHS_REG]], %[[LHS_SHARED]] - # CHECK: %[[RHS_REG:.+]] = vector.load %{{.*}}[%{{.*}}, %[[IV_K]]] : memref<128x32xi32, strided<[32, 1]>>, vector<2xi32> + # CHECK: %[[RHS_REG:.+]] = vector.load {{.*}} : memref<{{.*}}>, vector<2xi32> # CHECK: vector.store %[[RHS_REG]], %[[RHS_SHARED]] # CHECK: amdgpu.lds_barrier # CHECK-COUNT-2: vector.load {{.*}} : {{.*}}, vector<2xi32> @@ -810,7 +807,6 @@ def repeat( # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 4)> # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)> # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)> - # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 * 16 + ((s1 mod 64) floordiv 16) * 4)> # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 * 32)> # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)> # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)> @@ -1347,7 +1343,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-COUNT-4: vector.load %[[ALLOC_1]] # Steady State Global Read - # CHECK-COUNT-2: vector.load {{.*}} : memref<128x128xf16, strided<[128, 1]>>, vector<8xf16> + # CHECK-COUNT-2: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16> # CHECK-COUNT-2: rocdl.sched.group.barrier # Compute @@ -1738,7 +1734,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: rocdl.sched.barrier # 1st Cluster: Global load LHS - # CHECK-COUNT-2: vector.load {{.*}} : memref<4096x4096xf16, strided<[4096, 1]>>, vector<8xf16> + # CHECK-COUNT-2: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16> # CHECK: rocdl.sched.barrier # 1st Cluster: Second slice of Local read lhs and rhs @@ -1749,7 +1745,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: rocdl.sched.barrier # 1st Cluster: Global load RHS - # CHECK-COUNT-4: vector.load {{.*}} : memref<4096x4096xf16, strided<[4096, 1]>>, vector<8xf16> + # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16> # CHECK: rocdl.s.barrier # CHECK: rocdl.sched.barrier diff --git a/lit_tests/kernel/wave/scaled_gemm.py b/lit_tests/kernel/wave/scaled_gemm.py index 50fc8d84fd..261375d0c0 100644 --- a/lit_tests/kernel/wave/scaled_gemm.py +++ 
b/lit_tests/kernel/wave/scaled_gemm.py @@ -95,19 +95,17 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: print(scaled_gemm.asm) # CHECK-LABEL: test_scaled_gemm_mxfp4 - # CHECK-DAG: #map = affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16 + (s0 floordiv 64) * 16)> - # CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)> - # CHECK-DAG: #map2 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16)> - # CHECK-DAG: #map3 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> - # CHECK-DAG: #map4 = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)> - # CHECK-DAG: #map5 = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)> - # CHECK-DAG: #map6 = affine_map<()[s0, s1] -> (s0 * 64 + ((s1 mod 64) floordiv 16) * 16)> - # CHECK-DAG: #map7 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)> - # CHECK-DAG: #map8 = affine_map<()[s0] -> (s0 * 32)> - # CHECK-DAG: #map9 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)> - # CHECK-DAG: #map10 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)> - # CHECK-DAG: #map11 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 2)> - # CHECK-DAG: #map12 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 3)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16 + (s0 floordiv 64) * 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 * 32)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 2)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 3)> # CHECK: func.func @scaled_gemm # CHECK-COUNT-1: memref.alloc() # CHECK: scf.for @@ -195,19 +193,18 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: print(scaled_gemm.asm) # CHECK-LABEL: test_scaled_gemm_mxfp8 - # CHECK-DAG: #map = affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16 + (s0 floordiv 64) * 16)> - # CHECK-DAG: #map1 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16)> - # CHECK-DAG: #map2 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16 + 64)> - # CHECK-DAG: #map3 = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)> - # CHECK-DAG: #map4 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)> - # CHECK-DAG: #map5 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> - # CHECK-DAG: #map6 = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)> - # CHECK-DAG: #map7 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)> - # CHECK-DAG: #map8 = affine_map<()[s0] -> (s0 * 32)> - # CHECK-DAG: #map9 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)> - # CHECK-DAG: #map10 = affine_map<()[s0] -> ((s0 
floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)> - # CHECK-DAG: #map11 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 2)> - # CHECK-DAG: #map12 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 3)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16 + (s0 floordiv 64) * 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16 + 64)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 * 32)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 2)> + # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 3)> # CHECK: func.func @scaled_gemm # CHECK-COUNT-1: memref.alloc() # CHECK: scf.for @@ -315,10 +312,10 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-LABEL: gemm_mxfp4_prefetch # Prologue Global Read - # CHECK-COUNT-4: vector.load {{.*}} : memref<1024x512xi8, strided<[512, 1]>>, vector<16xi8> - # CHECK: vector.load {{.*}} : memref<1024x32xi8, strided<[32, 1]>>, vector<4xi8> - # CHECK-COUNT-4: vector.load {{.*}} : memref<1024x512xi8, strided<[512, 1]>>, vector<16xi8> - # CHECK: vector.load {{.*}} : memref<1024x32xi8, strided<[32, 1]>>, vector<4xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8> + # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8> + # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8> # Prologue Local Write # CHECK-COUNT-4: vector.store {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> @@ -330,22 +327,22 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: scf.for # Steady State global_load_rhs_scale - # CHECK: vector.load %{{.*}} : memref<1024x32xi8, strided<[32, 1]>>, vector<4xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}>, vector<4xi8> # Steady State local_load_rhs_scale # CHECK=COUNT-16: vector.load %{{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> # Steady State global_load_lhs_scale - # CHECK: vector.load %{{.*}} : memref<1024x32xi8, strided<[32, 1]>>, vector<4xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}>, vector<4xi8> # Steady State local_load_lhs_scale # CHECK=COUNT-16: vector.load %{{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> # Steady State global_load_rhs - # CHECK-COUNT-4: vector.load %{{.*}} : memref<1024x512xi8, strided<[512, 1]>>, vector<16xi8> + # CHECK-COUNT-4: vector.load %{{.*}} : memref<{{.*}}>, vector<16xi8> # Steady State local_load_rhs # CHECK=COUNT-16: vector.load %{{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> # Steady State global_load_lhs - # CHECK-COUNT-4: vector.load %{{.*}} : memref<1024x512xi8, strided<[512, 1]>>, vector<16xi8> + # CHECK-COUNT-4: 
vector.load %{{.*}} : memref<{{.*}}>, vector<16xi8> # Steady State local_load_lhs # CHECK=COUNT-16: vector.load %{{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> @@ -734,12 +731,12 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-DAG: %[[SCALED_LOGITS_BOUND:.+]] = arith.constant dense<96> : vector<16xindex> # CHECK: scf.for %{{.*}} = %[[C0]] to %[[C2]] step %[[C1]] # CHECK: %[[SCALED_LOGITS_MASK:.+]] = arith.cmpi slt, %{{.*}}, %[[SCALED_LOGITS_BOUND]] : vector<16xindex> - # CHECK: vector.maskedload {{.*}}, %[[SCALED_LOGITS_MASK]], %[[CST_0]] : memref<1024x96xi8, strided<[96, 1]>>, vector<16xi1>, vector<16xi8> into vector<16xi8> + # CHECK: vector.maskedload {{.*}}, %[[SCALED_LOGITS_MASK]], %[[CST_0]] : memref<{{.*}}>, vector<16xi1>, vector<16xi8> into vector<16xi8> # CHECK: %[[SCALED_SCALES_MASK_VAL:.+]] = arith.cmpi slt, %{{.*}}, %[[SCALED_SCALES_BOUND]] : index # CHECK: %[[SCALED_SCALES_MASK:.+]] = vector.broadcast %[[SCALED_SCALES_MASK_VAL]] : i1 to vector<1xi1> - # CHECK: vector.maskedload {{.*}}, %[[SCALED_SCALES_MASK]], %[[CST]] : memref<1024x6xi8, strided<[6, 1]>>, vector<1xi1>, vector<1xi8> into vector<1xi8> - # CHECK: vector.maskedload {{.*}}, %[[SCALED_LOGITS_MASK]], %[[CST_0]] : memref<1024x96xi8, strided<[96, 1]>>, vector<16xi1>, vector<16xi8> into vector<16xi8> - # CHECK: vector.maskedload {{.*}}, %[[SCALED_SCALES_MASK]], %[[CST]] : memref<1024x6xi8, strided<[6, 1]>>, vector<1xi1>, vector<1xi8> into vector<1xi8> + # CHECK: vector.maskedload {{.*}}, %[[SCALED_SCALES_MASK]], %[[CST]] : memref<{{.*}}>, vector<1xi1>, vector<1xi8> into vector<1xi8> + # CHECK: vector.maskedload {{.*}}, %[[SCALED_LOGITS_MASK]], %[[CST_0]] : memref<{{.*}}>, vector<16xi1>, vector<16xi8> into vector<16xi8> + # CHECK: vector.maskedload {{.*}}, %[[SCALED_SCALES_MASK]], %[[CST]] : memref<{{.*}}>, vector<1xi1>, vector<1xi8> into vector<1xi8> # CHECK: amdgpu.scaled_mfma # CHECK: scf.yield # CHECK: } diff --git a/lit_tests/kernel/wave/scaled_mma.py b/lit_tests/kernel/wave/scaled_mma.py index f37dae9b56..d681534f64 100644 --- a/lit_tests/kernel/wave/scaled_mma.py +++ b/lit_tests/kernel/wave/scaled_mma.py @@ -347,10 +347,10 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index # CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index # CHECK: scf.for %{{.*}} = %[[C0]] to %[[C64]] step %[[C1]] - # CHECK-COUNT-4: vector.load {{.*}} : memref<16384x8192xi8, strided<[8192, 1]>>, vector<16xi8> - # CHECK-COUNT-1: vector.load {{.*}} : memref<16384x512xi8, strided<[512, 1]>>, vector<4xi8> - # CHECK-COUNT-4: vector.load {{.*}} : memref<16384x8192xi8, strided<[8192, 1]>>, vector<16xi8> - # CHECK-COUNT-1: vector.load {{.*}} : memref<16384x512xi8, strided<[512, 1]>>, vector<4xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8> + # CHECK-COUNT-1: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8> + # CHECK-COUNT-1: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8> # CHECK: amdgpu.lds_barrier # CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> # CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> diff --git a/lit_tests/kernel/wave/wave_schedule.py b/lit_tests/kernel/wave/wave_schedule.py index ee4c21a3c0..40e02ee12a 100644 --- a/lit_tests/kernel/wave/wave_schedule.py +++ b/lit_tests/kernel/wave/wave_schedule.py @@ -46,7 +46,7 @@ def 
test_gemm_with_wave_schedule(): # CHECK-COUNT-4: vector.load %[[VIEW_1]] # Steady State Global Read - # CHECK-COUNT-2: vector.load {{.*}} : memref<128x128xf16, strided<[128, 1]>>, vector<8xf16> + # CHECK-COUNT-2: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16> # CHECK-COUNT-2: rocdl.sched.group.barrier # Compute @@ -111,7 +111,7 @@ def test_gemm_prefetch_reorder_stagger(): # CHECK: rocdl.sched.barrier # 1st Cluster: Global load LHS - # CHECK-COUNT-2: vector.load {{.*}} : memref<4096x4096xf16, strided<[4096, 1]>>, vector<8xf16> + # CHECK-COUNT-2: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16> # CHECK: rocdl.sched.barrier # 1st Cluster: Second slice of Local read lhs and rhs @@ -122,7 +122,7 @@ def test_gemm_prefetch_reorder_stagger(): # CHECK: rocdl.sched.barrier # 1st Cluster: Global load RHS - # CHECK-COUNT-4: vector.load {{.*}} : memref<4096x4096xf16, strided<[4096, 1]>>, vector<8xf16> + # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16> # CHECK: rocdl.s.barrier # CHECK: rocdl.sched.barrier From ac96d1985ce15be112879bff0e95ad864995eb67 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Tue, 3 Mar 2026 10:58:06 -0600 Subject: [PATCH 11/20] fixing more lit tests Signed-off-by: Sanket Pandit --- lit_tests/kernel/wave/gemm.py | 4 ++-- lit_tests/kernel/wave/scaled_gemm.py | 16 ++++++++-------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/lit_tests/kernel/wave/gemm.py b/lit_tests/kernel/wave/gemm.py index 815b003fa1..3c5b18c35e 100644 --- a/lit_tests/kernel/wave/gemm.py +++ b/lit_tests/kernel/wave/gemm.py @@ -283,8 +283,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-DAG: #[[MAP_IDX_N:.+]] = affine_map<()[s0, s1, s2, s3] -> (s1 * 32 + s2 * 64 + s0 floordiv 4 - ((s1 * 32 + s0 floordiv 4) floordiv 64) * 64 + ((s2 + s3 * 8) floordiv 32) * 256 - (s2 floordiv 4) * 256)> # CHECK-DAG: %[[IDX_M_READ:.+]] = affine.apply #[[MAP_IDX_M]]()[%thread_id_x, %thread_id_y, %block_id_y, %block_id_x] # CHECK-DAG: %[[IDX_N_READ:.+]] = affine.apply #[[MAP_IDX_N]]()[%thread_id_x, %thread_id_y, %block_id_x, %block_id_y] - # CHECK-DAG: vector.load {{.*}}[%[[IDX_M_READ]], {{.*}}] - # CHECK-DAG: vector.load {{.*}}[%[[IDX_N_READ]], {{.*}}] + # CHECK-DAG: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16> + # CHECK-DAG: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16> # CHECK: amdgpu.mfma # CHECK: vector.store {{.*}} : memref<{{.*}}xf32{{.*}}>, vector<1xf32> diff --git a/lit_tests/kernel/wave/scaled_gemm.py b/lit_tests/kernel/wave/scaled_gemm.py index 261375d0c0..1fbd47bdd7 100644 --- a/lit_tests/kernel/wave/scaled_gemm.py +++ b/lit_tests/kernel/wave/scaled_gemm.py @@ -580,12 +580,12 @@ def repeat( # Prologue Global Read # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1], offset: ?>> # CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) cacheSwizzleStride(%[[C512_I14]]) resetOffset : memref> to memref> - # CHECK-COUNT-4: vector.load {{.*}} : memref>, vector<16xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8> # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1], offset: ?>> # CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) cacheSwizzleStride(%[[C32_I14]]) resetOffset : memref> to memref> - # CHECK: vector.load {{.*}} : memref>, vector<4xi8> - # CHECK-COUNT-4: vector.load {{.*}} : memref>, vector<16xi8> - # 
CHECK: vector.load {{.*}} : memref>, vector<4xi8> + # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8> + # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8> # Prologue Linearize shared memory + Local Write # CHECK: memref.reinterpret_cast {{.*}} to offset: [0], sizes: [34816], strides: [1] : memref<1x256x136xi8, #gpu.address_space> to memref<34816xi8, #gpu.address_space> @@ -601,22 +601,22 @@ def repeat( # CHECK: scf.for # Steady State global_load_rhs_scale - # CHECK: vector.load %{{.*}} : memref>, vector<4xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}>, vector<4xi8> # Steady State local_load_rhs_scale # CHECK=COUNT-16: vector.load %{{.*}} : memref<4096xi8, #gpu.address_space>, vector<1xi8> # Steady State global_load_lhs_scale - # CHECK: vector.load %{{.*}} : memref>, vector<4xi8> + # CHECK: vector.load %{{.*}} : memref<{{.*}}>, vector<4xi8> # Steady State local_load_lhs_scale # CHECK=COUNT-16: vector.load %{{.*}} : memref<4096xi8, #gpu.address_space>, vector<1xi8> # Steady State global_load_rhs - # CHECK-COUNT-4: vector.load %{{.*}} : memref>, vector<16xi8> + # CHECK-COUNT-4: vector.load %{{.*}} : memref<{{.*}}>, vector<16xi8> # Steady State local_load_rhs # CHECK=COUNT-16: vector.load %{{.*}} : memref<34816xi8, #gpu.address_space>, vector<16xi8> # Steady State global_load_lhs - # CHECK-COUNT-4: vector.load %{{.*}} : memref>, vector<16xi8> + # CHECK-COUNT-4: vector.load %{{.*}} : memref<{{.*}}>, vector<16xi8> # Steady State local_load_lhs # CHECK=COUNT-16: vector.load %{{.*}} : memref<34816xi8, #gpu.address_space>, vector<16xi8> From ce17492c8417ed78d1f87e546d1e7b28954211b0 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Tue, 3 Mar 2026 13:18:18 -0600 Subject: [PATCH 12/20] fix last test Signed-off-by: Sanket Pandit --- lit_tests/kernel/wave/gemm.py | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/lit_tests/kernel/wave/gemm.py b/lit_tests/kernel/wave/gemm.py index 3c5b18c35e..7169c5947b 100644 --- a/lit_tests/kernel/wave/gemm.py +++ b/lit_tests/kernel/wave/gemm.py @@ -564,11 +564,11 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index # CHECK-DAG: %[[C768:.+]] = arith.constant 768 : index # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : vector<4xi32> - # CHECK-DAG: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [64, 64], strides: [64, 1] : memref to memref<64x64xi8, strided<[64, 1]>> - # CHECK-DAG: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [128, 64], strides: [64, 1] : memref to memref<128x64xi8, strided<[64, 1]>> # CHECK: %[[BASE_ALLOC:.+]] = memref.alloc() : memref<1536xi8, #gpu.address_space> # CHECK: %[[ALLOC_0:.+]] = memref.view %[[BASE_ALLOC]][%[[C0]]] # CHECK: %[[ALLOC_1:.+]] = memref.view %[[BASE_ALLOC]][%[[C768]]] + # CHECK: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}> + # CHECK: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}> # CHECK: scf.for %[[IVAR:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[CST]]) -> (vector<4xi32>) { # CHECK: %[[REG_0:.+]] = vector.load %[[GLOBAL_0]] # CHECK: vector.store %[[REG_0]], %[[ALLOC_1]] @@ -638,11 +638,11 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.f32]: # 
CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index # CHECK-DAG: %[[C1280:.+]] = arith.constant 1280 : index # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : vector<4xi32> - # CHECK-DAG: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %0 to offset: [0], sizes: [64, 64], strides: [64, 1] : memref to memref<64x64xi8, strided<[64, 1]>> - # CHECK-DAG: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %1 to offset: [0], sizes: [128, 64], strides: [64, 1] : memref to memref<128x64xi8, strided<[64, 1]>> # CHECK: %[[BASE_ALLOC:.+]] = memref.alloc() : memref<2560xi8, #gpu.address_space> # CHECK: %[[ALLOC_0:.+]] = memref.view %[[BASE_ALLOC]][%[[C0]]] # CHECK: %[[ALLOC_1:.+]] = memref.view %[[BASE_ALLOC]][%[[C1280]]] + # CHECK: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}> + # CHECK: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}> # CHECK: scf.for %[[IVAR:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[CST]]) -> (vector<4xi32>) { # CHECK: %[[REG_0:.+]] = vector.load %[[GLOBAL_0]] # CHECK: vector.store %[[REG_0]], %[[ALLOC_1]] @@ -1339,8 +1339,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: scf.for # CHECK-COUNT-1: amdgpu.lds_barrier # Steady State Local Read - # CHECK-COUNT-4: vector.load %[[ALLOC_0]] - # CHECK-COUNT-4: vector.load %[[ALLOC_1]] + # CHECK-COUNT-4: vector.load %[[VIEW_0]] + # CHECK-COUNT-4: vector.load %[[VIEW_1]] # Steady State Global Read # CHECK-COUNT-2: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16> @@ -1356,8 +1356,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: scf.yield # Prologue - # CHECK-COUNT-4: vector.load %[[ALLOC_0]] - # CHECK-COUNT-4: vector.load %[[ALLOC_1]] + # CHECK-COUNT-4: vector.load %[[VIEW_0]] + # CHECK-COUNT-4: vector.load %[[VIEW_1]] # CHECK-COUNT-8: amdgpu.mfma @@ -2292,12 +2292,12 @@ def repeat( # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} { # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index # CHECK-DAG: %[[WG_ID2:.*]] = gpu.block_id z - # CHECK: %[[RHS_GLOBAL:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [6, 128, 64], strides: [8192, 64, 1] : memref to memref<6x128x64xf16, strided<[8192, 64, 1]>> # CHECK: %[[HKV_IDX:.+]] = affine.apply #[[MAP]]()[%[[WG_ID2]]] - # CHECK: %[[LHS_GLOBAL:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{[0-9]+}}], strides: [1] : memref to memref<{{[0-9]+}}xf16, strided<[1]>> + # CHECK: %[[GLOBAL_LHS:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}> + # CHECK: %[[GLOBAL_RHS:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}> # CHECK: scf.for - # CHECK: %[[LHS_READ:.+]] = vector.load %[[LHS_GLOBAL]][%{{.+}}] : {{.*}}, vector<8xf16> - # CHECK: %[[RHS_READ:.+]] = vector.load %[[RHS_GLOBAL]][%[[WG_ID2]], %{{.+}}, {{.+}}] : {{.*}}, vector<8xf16> + # CHECK: %[[LHS_READ:.+]] = vector.load %[[GLOBAL_LHS]] + # CHECK: %[[RHS_READ:.+]] = vector.load %[[GLOBAL_RHS]] # CHECK-COUNT-2: vector.extract_strided_slice # CHECK-COUNT-1: amdgpu.mfma # CHECK-COUNT-2: vector.extract_strided_slice @@ -2460,19 +2460,17 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-LABEL: test_explicit_shared_gemm # CHECK: func.func @gemm # CHECK-SAME: 
(%[[ARG0:.*]]: !stream.binding, %[[ARG1:.*]]: !stream.binding, %[[ARG2:.*]]: !stream.binding) - # CHECK-DAG: %[[GLOBAL_A:.+]] = memref.reinterpret_cast %{{.*}} to offset: [{{.*}}], sizes: [64, 64], strides: [64, 1] : memref to memref<64x64xf16{{.*}}> - # CHECK-DAG: %[[GLOBAL_B:.+]] = memref.reinterpret_cast %{{.*}} to offset: [{{.*}}], sizes: [128, 64], strides: [64, 1] : memref to memref<128x64xf16{{.*}}> # Verify explicit shared memory allocations (two separate allocs) # CHECK: %[[ALLOC_A:.+]] = memref.alloc() : memref<{{.*}}xf16, #gpu.address_space> # CHECK: %[[ALLOC_B:.+]] = memref.alloc() : memref<{{.*}}xf16, #gpu.address_space> # CHECK: scf.for # Verify load from global memory (A) - # CHECK: vector.load %[[GLOBAL_A]] + # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<{{.*}}> # Verify barrier before shared memory writes # CHECK: amdgpu.lds_barrier # Verify write to shared memory # CHECK: vector.store %{{.*}}, %[[ALLOC_A]] - # CHECK: vector.load %[[GLOBAL_B]] + # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<{{.*}}> # CHECK: vector.store %{{.*}}, %[[ALLOC_B]] # Verify barrier before shared memory reads # CHECK: amdgpu.lds_barrier From 23f7b370ed60b8d75be4e34fff8da7d0c7eb5d59 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Tue, 3 Mar 2026 13:33:18 -0600 Subject: [PATCH 13/20] refactor gather to lds Signed-off-by: Sanket Pandit --- .../compiler/wave_codegen/read_write.py | 74 ++++++------------- 1 file changed, 21 insertions(+), 53 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 6e3a69b639..c45e1eed0b 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -1287,73 +1287,41 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): i32 = IntegerType.get_signless(32) dst_index = [assume_index_subgroup_uniform(idx, i32) for idx in dst_index] - # Try iv-split for the source (global) address. + # Compute symbolic strides (shared by both iv-split and fallback paths). sym_stride_vals = strides_from_symbolic_shape( IndexingContext.current(), src_symbolic_shape, allow_mixed_shapes=True ) + subs_map = add_emitter_subs(emitter, src_dynamic_vals_map_start) + strides = [gen_sympy_index(subs_map, s) for s in sym_stride_vals] + + # Try iv-split: linearize with a single hoisted offset + scalar K stride. 
try: sym_strides_int = [int(subs_idxc(s)) for s in sym_stride_vals] except (TypeError, ValueError): sym_strides_int = [] + src_offset = None if sym_strides_int: - total_offset = _try_iv_split_offset( + src_offset = _try_iv_split_offset( emitter, new_src_idx, sym_strides_int, src_dynamic_vals_map_start, use_subs_idxc=True, ) - if total_offset is not None: - mask = _build_mask( - emitter, - src_idx, - elements_per_thread=1, - bounds=src_bounds, - dynamic_values=src_dynamic_vals_map_start, - ) - if mask: - mask = vector_d.extract(mask, static_position=[0], dynamic_position=[]) - oob_index_value = _get_out_of_bounds_index(element_type) - oob_index = arith_d.constant(IndexType.get(), oob_index_value) - total_offset = arith_d.select(mask, total_offset, oob_index) - - lin_strides = [ - gen_sympy_index( - add_emitter_subs(emitter, src_dynamic_vals_map_start), s - ) - for s in sym_stride_vals - ] - zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(lin_strides) - lin_src, _ = _linearize_memref(src, zero_indices, zero_indices, lin_strides) - lin_src = _cast_buffer_and_encode_stride( - lin_src, lin_strides, element_type, emitter - ) - - amdgpu_d.gather_to_lds( - src=lin_src, - src_indices=[total_offset], - dst=dst, - dst_indices=dst_index, - transfer_type=store_type, - ) - return - # Fallback: original linearization path - src_index, src_index_wg, src_index_th = _build_start_indices( - emitter, new_src_idx, src_dynamic_vals_map_start - ) - - strides = strides_from_symbolic_shape( - IndexingContext.current(), src_symbolic_shape, allow_mixed_shapes=True - ) - strides = [ - gen_sympy_index(add_emitter_subs(emitter, src_dynamic_vals_map_start), s) - for s in strides - ] + if src_offset is not None: + zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(strides) + lin_src, _ = _linearize_memref(src, zero_indices, zero_indices, strides) + else: + src_index, src_index_wg, src_index_th = _build_start_indices( + emitter, new_src_idx, src_dynamic_vals_map_start + ) + lin_src, src_offset = _linearize_memref( + src, src_index_wg, src_index_th, strides + ) - src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides) - src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter) + lin_src = _cast_buffer_and_encode_stride(lin_src, strides, element_type, emitter) mask = _build_mask( emitter, @@ -1366,11 +1334,11 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): mask = vector_d.extract(mask, static_position=[0], dynamic_position=[]) oob_index_value = _get_out_of_bounds_index(element_type) oob_index = arith_d.constant(IndexType.get(), oob_index_value) - offset_th = arith_d.select(mask, offset_th, oob_index) + src_offset = arith_d.select(mask, src_offset, oob_index) amdgpu_d.gather_to_lds( - src=src, - src_indices=[offset_th], + src=lin_src, + src_indices=[src_offset], dst=dst, dst_indices=dst_index, transfer_type=store_type, From d428379d6bf98f6bd5a23406ef9758e1f90bf7a5 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Tue, 3 Mar 2026 15:03:41 -0600 Subject: [PATCH 14/20] refactor gather to lds Signed-off-by: Sanket Pandit --- wave_lang/kernel/compiler/wave_codegen/read_write.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index c45e1eed0b..c12daaac38 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -1271,16 +1271,14 @@ def 
handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): induction_vars = set(emitter.get_induction_vars_and_syms()[1]) - # Hoist dst indices to function level if they don't use induction variables. - dst_ip = ip if not any( induction_vars.intersection(set(index.start.free_symbols)) for index in dst_idx.values() ): - while not isinstance(dst_ip.block.owner, func_d.FuncOp): - dst_ip = InsertionPoint(dst_ip.block.owner) + while not isinstance(ip.block.owner, func_d.FuncOp): + ip = InsertionPoint(ip.block.owner) - with dst_ip: + with ip: dst_index, _, _ = _build_start_indices( emitter, dst_idx, dst_dynamic_vals_map_start ) From b1a3ae77dbe0b45741afc9f151478be5e8aa6243 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Tue, 3 Mar 2026 16:26:06 -0600 Subject: [PATCH 15/20] gaurd against buffer_ops, we don't support yet Signed-off-by: Sanket Pandit --- wave_lang/kernel/compiler/wave_codegen/read_write.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index c12daaac38..1c0da16b5f 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -885,6 +885,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): is_global and mask is None and not use_llvm_load + and not emitter.options.use_buffer_ops and not read_meets_hw_transpose_requirements( get_custom(node), emitter.constraints, emitter.options.target ) @@ -1299,7 +1300,7 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): sym_strides_int = [] src_offset = None - if sym_strides_int: + if sym_strides_int and not emitter.options.use_buffer_ops: src_offset = _try_iv_split_offset( emitter, new_src_idx, From 83fad5e0a19c06ef264f06c9ef9b9edc9af405e4 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Wed, 4 Mar 2026 13:05:53 -0600 Subject: [PATCH 16/20] only iv_split when one induction variable Signed-off-by: Sanket Pandit --- wave_lang/kernel/compiler/wave_codegen/read_write.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 1c0da16b5f..d7146ea69a 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -707,7 +707,7 @@ def _try_iv_split_offset( (needed when expressions contain residual shape symbols). 
""" iv_vals, iv_syms = emitter.get_induction_vars_and_syms() - if not iv_syms: + if len(iv_syms) != 1: return None ip = InsertionPoint.current From 3abd61582da68ac9157efee7f4b2fed61427ced5 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Wed, 4 Mar 2026 13:37:04 -0600 Subject: [PATCH 17/20] some cleanup Signed-off-by: Sanket Pandit --- .../compiler/wave_codegen/read_write.py | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index d7146ea69a..d4f3b1ba10 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -75,6 +75,7 @@ get_type_or_element_type, handle_op, ) +from ...wave.constraints import TilingConstraint def _get_start_index(i: IndexSequence | IndexExpr) -> IndexExpr: @@ -706,10 +707,6 @@ def _try_iv_split_offset( use_subs_idxc : if True, apply ``subs_idxc`` before simplification (needed when expressions contain residual shape symbols). """ - iv_vals, iv_syms = emitter.get_induction_vars_and_syms() - if len(iv_syms) != 1: - return None - ip = InsertionPoint.current owner = ip.block.owner if isinstance(owner, func_d.FuncOp): @@ -717,6 +714,24 @@ def _try_iv_split_offset( if owner.name != "scf.for": return None + # Find the IV symbol for this scf.for directly from its block argument. + current_iv = owner.induction_variable + + # do a reverse lookup of the dimension/symbol that the current IV is associated with + dim = next((d for d, v in emitter.induction_vars.items() if v == current_iv), None) + if dim is None: + return None + iv_sym = next( + ( + c.induction_var + for c in emitter.constraints + if isinstance(c, TilingConstraint) and c.dim == dim + ), + None, + ) + if iv_sym is None: + return None + step_int = _get_constant_value(owner.operands[2]) if step_int is None or step_int <= 0: return None @@ -725,15 +740,7 @@ def _try_iv_split_offset( if len(start_exprs) != len(strides): return None - iv_sym = iv_syms[0] - - # Symbolic linearity proof: substitute IV = step * j (j is a fresh - # integer symbol) keeping all other symbols (T0, WG, WAVE, ...) live. - # Because step is a concrete power-of-2 that aligns with tile sizes, - # floor/Mod sub-expressions collapse and sympy.simplify reduces the - # linearized offset to c*j + f(T0, WG, ...). Extracting the - # coefficient of j and verifying j doesn't appear in the remainder - # proves the stride is constant for all thread/wave/workgroup values. + # Symbolic linearity proof w.r.t. the current loop's IV only. 
_j = sympy.Symbol("_j", integer=True, nonnegative=True) iv_as_j = step_int * _j lin_sym = sympy.Integer(0) @@ -753,7 +760,7 @@ def _try_iv_split_offset( if rem != 0: return None - iv_zero_subs = {sym: 0 for sym in iv_syms} + iv_zero_subs = {iv_sym: 0} index_no_iv = {} for dim, seq in index.items(): start = _get_start_index(seq) From 18d04ff463613766ac743bc7395a3ff6104ab7a6 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Wed, 4 Mar 2026 22:46:26 -0600 Subject: [PATCH 18/20] schedule fix Signed-off-by: Sanket Pandit --- wave_lang/kernel/wave/schedules/gemm_triple_buffer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py b/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py index 0025725432..eea7e950dd 100644 --- a/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py @@ -150,6 +150,8 @@ def async_two_cluster_three_stage_schedule(): unroll_factor = 2 tkw.unroll(pipeline_loop.KERNEL, unroll_factor) + tkw.insert_after(pipeline_loop.KERNEL, tkw.MemoryCounterWaitBarrier(load=0)) + return async_two_cluster_three_stage_schedule From 6f3c79fefea7107fc52ca04a387268c43631a8d0 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Wed, 4 Mar 2026 23:10:43 -0600 Subject: [PATCH 19/20] enable buffer_ops Signed-off-by: Sanket Pandit --- wave_lang/kernel/compiler/wave_codegen/read_write.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index d4f3b1ba10..522394f3cb 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -892,7 +892,6 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): is_global and mask is None and not use_llvm_load - and not emitter.options.use_buffer_ops and not read_meets_hw_transpose_requirements( get_custom(node), emitter.constraints, emitter.options.target ) @@ -1307,7 +1306,7 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): sym_strides_int = [] src_offset = None - if sym_strides_int and not emitter.options.use_buffer_ops: + if sym_strides_int: src_offset = _try_iv_split_offset( emitter, new_src_idx, From b3a8281cb3263549f2ef96ef2ae914ac3566247e Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Wed, 4 Mar 2026 23:26:00 -0600 Subject: [PATCH 20/20] add comments Signed-off-by: Sanket Pandit --- .../kernel/compiler/wave_codegen/read_write.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 522394f3cb..899d8de21e 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -740,7 +740,9 @@ def _try_iv_split_offset( if len(start_exprs) != len(strides): return None - # Symbolic linearity proof w.r.t. the current loop's IV only. + # Phase 1: Symbolic linearity proof w.r.t. the current loop's IV only. + # substitute IV = step*_j and check + # that the linearized index is c*_j + remainder (no _j in remainder). _j = sympy.Symbol("_j", integer=True, nonnegative=True) iv_as_j = step_int * _j lin_sym = sympy.Integer(0) @@ -760,6 +762,7 @@ def _try_iv_split_offset( if rem != 0: return None + # Phase 2: Substitute IV=0 to get the loop-invariant base offset. 
iv_zero_subs = {iv_sym: 0} index_no_iv = {} for dim, seq in index.items(): @@ -770,6 +773,7 @@ def _try_iv_split_offset( else: index_no_iv[dim] = new_start + # Emit the hoisted linearized offset BEFORE the scf.for. hoist_ip = InsertionPoint(owner) subs_map = add_emitter_subs(emitter, dynamic_vals) overflow_flags = arith_d.IntegerOverflowFlags.nsw @@ -787,6 +791,7 @@ def _try_iv_split_offset( else arith_d.addi(lin_offset, term, overflow_flags=overflow_flags) ) + # Back inside the loop: total = hoisted_base + IV * k_stride. iv_mlir = subs_map.get(iv_sym) if iv_mlir is None: return None @@ -888,6 +893,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): is_global = get_custom(memory).type.address_space != SHARED_ADDRESS_SPACE use_llvm_load = flags != MemoryAccessFlags.NONE + # IV-split fast path for global reads: hoist address before the loop. if ( is_global and mask is None @@ -907,6 +913,7 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): dynamic_vals_map_start, ) if total_offset is not None: + # Load from a shared flat rank-1 view (one SRD per buffer). ip = InsertionPoint.current owner = ip.block.owner hoist_ip = InsertionPoint(owner) @@ -1292,14 +1299,14 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): i32 = IntegerType.get_signless(32) dst_index = [assume_index_subgroup_uniform(idx, i32) for idx in dst_index] - # Compute symbolic strides (shared by both iv-split and fallback paths). + # Symbolic strides shared by iv-split and fallback linearization. sym_stride_vals = strides_from_symbolic_shape( IndexingContext.current(), src_symbolic_shape, allow_mixed_shapes=True ) subs_map = add_emitter_subs(emitter, src_dynamic_vals_map_start) strides = [gen_sympy_index(subs_map, s) for s in sym_stride_vals] - # Try iv-split: linearize with a single hoisted offset + scalar K stride. + # IV-split: try hoisting the src offset before the loop. try: sym_strides_int = [int(subs_idxc(s)) for s in sym_stride_vals] except (TypeError, ValueError): @@ -1316,9 +1323,11 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): ) if src_offset is not None: + # IV-split path: offset=0 reinterpret_cast, full address in src_offset. zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(strides) lin_src, _ = _linearize_memref(src, zero_indices, zero_indices, strides) else: + # Fallback: wg offset baked into memref base, th offset as voffset. src_index, src_index_wg, src_index_th = _build_start_indices( emitter, new_src_idx, src_dynamic_vals_map_start )
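
The Phase 1 / Phase 2 comments in _try_iv_split_offset are easier to follow with a concrete
expression in hand. The standalone sketch below shows the same check using plain sympy; it is
illustrative only — IV, T0, the tile size 32, the stride 136, and the step value are made-up
stand-ins, not values taken from the kernels touched by this patch series.

    # Standalone sketch of the IV-split linearity check (illustrative values only).
    import sympy

    IV, T0 = sympy.symbols("IV T0", integer=True, nonnegative=True)
    _j = sympy.Symbol("_j", integer=True, nonnegative=True)

    step = 1                        # constant loop step
    index = IV * 32 + T0 % 64       # hypothetical per-lane index along one dim
    stride = 136                    # hypothetical physical stride for that dim

    # Phase 1: substitute IV = step * _j and require the linearized offset to
    # reduce to k_stride * _j + remainder, with no _j left in the remainder
    # and k_stride a compile-time constant.
    lin = sympy.expand(index.subs(IV, step * _j) * stride)
    k_stride = lin.coeff(_j)
    remainder = sympy.simplify(lin - k_stride * _j)
    assert not remainder.has(_j)        # offset is linear in the IV
    assert not k_stride.free_symbols    # stride is a constant (4352 here)

    # Phase 2: the loop-invariant base is the offset at IV = 0; inside the
    # loop the address is base + IV * k_stride.
    base = sympy.expand(index.subs(IV, 0) * stride)
    print(k_stride, base)               # 4352  136*Mod(T0, 64)

When either check fails, the split is not legal for that read and the helper simply returns
None, which sends codegen down the existing per-iteration addressing path.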