From a8c5a21649e51d33a67b0f29e892d66d74fde3db Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Wed, 25 Feb 2026 12:36:13 -0600 Subject: [PATCH 01/12] index splitting - another PR Signed-off-by: Sanket Pandit --- .../compiler/wave_codegen/read_write.py | 100 ++++++++++++++++-- 1 file changed, 92 insertions(+), 8 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 9941136769..f144fdcfa5 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -127,6 +127,48 @@ def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]: return thread_independent_index, thread_dependent_index +def _split_index_three_way( + src: IndexExpr | int, uniform_syms: set +) -> tuple[IndexExpr, IndexExpr, IndexExpr]: + """ + Split index expr into workgroup, uniform (e.g. loop induction), and + thread-dependent parts. Keeping the uniform part separate from the + thread part allows the backend to place the uniform contribution in + the buffer-load soffset field instead of a VALU add. + """ + if isinstance(src, int): + return sympy.sympify(0), sympy.sympify(0), sympy.sympify(src) + + subs_wg = {WORKGROUP_0: 0, WORKGROUP_1: 0, WORKGROUP_2: 0} + subs_all_uniform = {**subs_wg, **{s: 0 for s in uniform_syms}} + + # Thread-only: zero out all uniform symbols (WG + induction vars). + thread_index = safe_subs(src, subs_all_uniform) + + # No-WG: zero out only WG symbols (keeps induction vars + thread). + no_wg = safe_subs(src, subs_wg) + + # Uniform part = no_wg - thread_only (induction-var-dependent terms). + uniform_index = _simplify(no_wg - thread_index) + + # WG part = src - no_wg. + wg_index = _simplify(src - no_wg) + + # Validate WG part contains only WG symbols. + if wg_index.free_symbols - set(subs_wg.keys()): + wg_index = sympy.sympify(0) + no_wg = src + uniform_index = _simplify(no_wg - thread_index) + + # Validate uniform part has no thread-dependent symbols. + thread_syms = {THREAD_0, THREAD_1, THREAD_2} + if uniform_index.free_symbols & thread_syms: + thread_index = no_wg + uniform_index = sympy.sympify(0) + + return wg_index, uniform_index, thread_index + + def _extract0(src): static_pos = [0] * src.type.rank return vector_d.extract(src, static_position=static_pos, dynamic_position=[]) @@ -151,14 +193,25 @@ def _build_start_indices( emitter: WaveEmitter, src_indices: dict[IndexExpr, IndexSequence | IndexExpr], dynamic_values: dict[IndexExpr, Any] = {}, -) -> tuple[list[OpResult], list[OpResult], list[OpResult]]: + uniform_syms: set | None = None, +) -> ( + tuple[list[OpResult], list[OpResult], list[OpResult]] + | tuple[list[OpResult], list[OpResult], list[OpResult], list[OpResult]] +): start_indices = _get_start_indices(src_indices) - split_indices = [_split_index(i) for i in start_indices] subs = add_emitter_subs(emitter, dynamic_values) indices = [gen_sympy_index(subs, i) for i in start_indices] + + if uniform_syms: + split_indices = [_split_index_three_way(i, uniform_syms) for i in start_indices] + indices_wg = [gen_sympy_index(subs, i[0]) for i in split_indices] + indices_unif = [gen_sympy_index(subs, i[1]) for i in split_indices] + indices_th = [gen_sympy_index(subs, i[2]) for i in split_indices] + return indices, indices_wg, indices_unif, indices_th + + split_indices = [_split_index(i) for i in start_indices] indices_wg = [gen_sympy_index(subs, i[0]) for i in split_indices] indices_th = [gen_sympy_index(subs, i[1]) for i in split_indices] - return indices, indices_wg, indices_th @@ -1079,13 +1132,25 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): store_type = VectorType.get((elements_per_thread,), element_type) - src_index, src_index_wg, src_index_th = _build_start_indices( - emitter, new_src_idx, src_dynamic_vals_map_start - ) + induction_vars = set(emitter.get_induction_vars_and_syms()[1]) - ip = InsertionPoint.current + # Three-way split: separate induction-variable (uniform) offsets from + # per-lane thread offsets so the backend can place the uniform part in + # the buffer-load soffset field instead of emitting a VALU add. + src_index_iv = None + if induction_vars: + src_index, src_index_wg, src_index_iv, src_index_th = _build_start_indices( + emitter, + new_src_idx, + src_dynamic_vals_map_start, + uniform_syms=induction_vars, + ) + else: + src_index, src_index_wg, src_index_th = _build_start_indices( + emitter, new_src_idx, src_dynamic_vals_map_start + ) - induction_vars = set(emitter.get_induction_vars_and_syms()[1]) + ip = InsertionPoint.current # Hoist to the function level, if not using induction variables. if not any( @@ -1114,6 +1179,25 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node): src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides) src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter) + # Add uniform (induction-variable) contribution as a separate SGPR offset. + # Keeping it as a distinct arith.addi (VGPR + SGPR) lets the AMDGPU + # backend's SIFoldOperands fold the SGPR into the buffer_load soffset. + if src_index_iv is not None: + overflow_flags = arith_d.IntegerOverflowFlags.nsw + offset_iv = None + for iv_idx, stride in zip(src_index_iv, strides): + if isinstance(iv_idx, int): + iv_idx = arith_d.constant(IndexType.get(), iv_idx) + off = arith_d.muli(iv_idx, stride, overflow_flags=overflow_flags) + if offset_iv is None: + offset_iv = off + else: + offset_iv = arith_d.addi(offset_iv, off, overflow_flags=overflow_flags) + if offset_iv is not None and _get_constant_value(offset_iv) != 0: + offset_th = arith_d.addi( + offset_th, offset_iv, overflow_flags=overflow_flags + ) + # We previously checked mask is same for all elements, so we can use # elements_per_thread=1 to build the mask. mask = _build_mask( From c0ff14fedccc9e3d1f9d3259770139651f48736b Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Wed, 25 Feb 2026 12:35:56 -0600 Subject: [PATCH 02/12] schedule fixes for asm backend Signed-off-by: Sanket Pandit --- wave_lang/kernel/wave/.gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/wave_lang/kernel/wave/.gitignore b/wave_lang/kernel/wave/.gitignore index fd68a7a039..30a78986a1 100644 --- a/wave_lang/kernel/wave/.gitignore +++ b/wave_lang/kernel/wave/.gitignore @@ -12,6 +12,7 @@ downloads/ eggs/ .eggs/ lib/ +!asm/wave_asm/ lib64/ parts/ sdist/ From 591a8e68045bce70364bc589b2ba9051eef9f096 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Tue, 24 Feb 2026 23:19:09 -0600 Subject: [PATCH 03/12] wave asm backend infra Signed-off-by: Sanket Pandit --- wave_lang/kernel/wave/compile_options.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index bc123aa2c6..0e52441973 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -132,4 +132,7 @@ class WaveCompileOptions: use_wave_asm_backend: bool = ( False # Use WaveASM (waveasm-translate) instead of Python backend ) + use_wave_asm_backend: bool = ( + False # Use WaveASM (waveasm-translate) instead of Python backend + ) mma_type: Optional["MMAType"] = None # MMA type for ASM backend dispatch From 3726fde497fb4dfbae4ff2f880129a4e56aa4732 Mon Sep 17 00:00:00 2001 From: Sanket Pandit Date: Sun, 22 Feb 2026 15:25:30 -0600 Subject: [PATCH 04/12] fix address >4GB Signed-off-by: Sanket Pandit --- docs/per-workgroup-srd-base-adjustment.md | 68 +++++++++++++++++++ .../compiler/wave_codegen/read_write.py | 8 ++- 2 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 docs/per-workgroup-srd-base-adjustment.md diff --git a/docs/per-workgroup-srd-base-adjustment.md b/docs/per-workgroup-srd-base-adjustment.md new file mode 100644 index 0000000000..4e26e3adcf --- /dev/null +++ b/docs/per-workgroup-srd-base-adjustment.md @@ -0,0 +1,68 @@ +# Per-Workgroup SRD Base Adjustment for >4GB Output Buffers + +## Problem + +For large GEMM shapes (e.g., `M=32768, N=57344, K=16384`), the C output matrix `memref<32768x57344xf32>` is ~7GB. This caused two failures: + +1. **Assembly error**: `s_mov_b32 srd[2], 0x1C0000000` — the SRD `num_records` field is 32 bits, but the computed buffer size (7,516,192,768 bytes) exceeds 2^32. +2. **Address overflow**: even if `num_records` were clamped, the store voffset `row * 229376 + col * 4` overflows a 32-bit VGPR for workgroups targeting the upper portion of the output matrix. + +## Fix + +Split the store byte offset into a **workgroup base** (folded into the SRD base address via 64-bit SALU) and a **thread offset** (small, used as `voffset`). This matches AITER's per-workgroup SRD pattern. + +### Layer 1: MLIR codegen (Python) + +**File**: `wave_lang/kernel/compiler/wave_codegen/read_write.py` + +The existing `_linearize_memref` function already separates workgroup offsets (from `block_id * tile_size`) into the memref base pointer and returns thread-only offsets for indexing. It was previously gated on `buffer_ops_enabled`. + +Change: for global writes without `buffer_ops`, also call `_linearize_memref` (skipping `_cast_buffer_and_encode_stride`). This produces a `memref.reinterpret_cast` with a dynamic per-workgroup element offset and 1D thread-only store indices: + +```mlir +%wg_offset = arith.addi(arith.muli(%block_id_x_times_128, 57344), %block_id_y_times_256) +%tile_mem = memref.reinterpret_cast(%c_raw) offset: [%wg_offset], sizes: [536870910], strides: [1] +vector.store(%val, %tile_mem, [%thread_offset]) +``` + +Thread offsets stay within ~28MB (the 128x256 tile), fitting comfortably in 32 bits. + +### Layer 2: C++ backend + +Three changes in the WaveASM C++ backend: + +**1. Clamp buffer size** (`TranslateFromMLIR.cpp`): In `emitSRDPrologue`, clamp `pending.bufferSize` to `0xFFFFFFFF` before emitting `s_mov_b32` for `num_records`. This is a safety net — the original full-sized `reinterpret_cast` still exists in the MLIR but is unused by stores after linearization. + +**2. Track pending SRD adjustments** (`MemRefHandlers.cpp`): In `handleMemRefReinterpretCast`, detect dynamic offsets (from `_linearize_memref`) and store the element offset Value, source SRD index, and element byte width in a `PendingSRDBaseAdjust` map. The actual SALU ops are deferred to the store handler to survive DCE. + +**3. Emit SRD adjustment inline** (`TranslateFromMLIR.cpp`): In `handleVectorStore`, when a pending adjustment exists for the store target, emit: + +```asm +s_mov_b64 s[N:N+1], s[src:src+1] ; copy source SRD base +v_readfirstlane_b32 s[N+3], vOffset ; element offset → SGPR +s_mul_hi_u32 s[N+2], s[N+3], 4 ; byte offset high (for >4GB) +s_mul_i32 s[N+3], s[N+3], 4 ; byte offset low +s_add_u32 s[N], s[N], s[N+3] ; base_lo += byteOffLo (sets SCC) +s_addc_u32 s[N+1], s[N+1], s[N+2] ; base_hi += byteOffHi + carry +s_mov_b32 s[N+2], 0x7FFFFFF8 ; num_records (tile-sized) +s_mov_b32 s[N+3], 0x20000 ; stride descriptor +``` + +The adjustment uses `PSRegType` (precolored physical SGPRs) for all intermediates, with `s[N+2]` and `s[N+3]` serving as temporaries before being overwritten by `num_records` and `stride`. After the first store emits the adjustment, subsequent stores reuse the adjusted SRD via `setSRDIndex`. + +### Layer 3: Dialect changes + +**File**: `WaveASMOps.td` + +- Added `S_ADDC_U32` (carry-dependent add, reads SCC from preceding `s_add_u32`). +- Made `S_ADD_U32` and `S_ADDC_U32` non-`Pure`. These ops set SCC as a side effect; removing `Pure` prevents the canonicalizer from DCE'ing the SRD adjustment chain (whose PSRegType results have no explicit SSA users — they communicate through physical register aliasing with the later `PrecoloredSRegOp`). + +## Files modified + +| File | Change | +|------|--------| +| `wave_codegen/read_write.py` | Call `_linearize_memref` for global writes without `buffer_ops` | +| `TranslateFromMLIR.cpp` | Clamp `bufferSize` in `emitSRDPrologue`; emit SRD adjustment in `handleVectorStore` | +| `TranslateFromMLIR.h` | Add `PendingSRDBaseAdjust` struct and tracking methods | +| `handlers/MemRefHandlers.cpp` | Detect dynamic offset in `handleMemRefReinterpretCast`, track for deferred emission | +| `WaveASMOps.td` | Add `S_ADDC_U32`; make `S_ADD_U32`/`S_ADDC_U32` non-Pure | diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index f144fdcfa5..6a535a0b1e 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -610,8 +610,14 @@ def extract(vec, ind): mem, start_indices_wg, start_indices_th, strides ) mem = _cast_buffer_and_encode_stride(mem, strides, element_type, emitter) + elif is_global_mem and not is_read: + mem, offset_th = _linearize_memref( + mem, start_indices_wg, start_indices_th, strides + ) - indices = [offset_th] if buffer_ops_enabled else start_indices + indices = ( + [offset_th] if (buffer_ops_enabled or offset_th is not None) else start_indices + ) if no_masked_load_store_ops: # find the index at which memory out of bounds of buffer From 35391d33b41be8f8f3c4d716bfe0b9f3cb6ba133 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 26 Feb 2026 22:51:42 +0100 Subject: [PATCH 05/12] Support dynamic M/N/K in MXFP4 preshuffle-B kernels MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Six infrastructure fixes needed for dynamic dims with the wave_asm backend: 1. scheduling/schedule.py: translate node_mapping keys back to original graph nodes after graph_copy in the dynamic pipelining path, fixing identity mismatch in _update_kernel_node_mapping (0/165 → 165/165). 2. wave_schedule_ops.py: use iterate's owning graph for subgraph reordering instead of hardcoding get_root_graph() (pipelined iterate lives inside a conditional subgraph with dynamic shapes). 3. unrolling.py: guard unroll count validation with is_literal() so symbolic counts don't raise TypeError. 4. emitter.py: handle Rational operands in Mod via the identity Mod(a/b,c)=Mod(a,b*c)/b; resolve terminal rationals with divsi. 5. read_write.py: initialize offset_th=None before masked code path (pre-existing bug only triggered by dynamic shapes). 6. host_codegen.py: resolve derived buffer shapes (K/2, K/32) for dynamic symbol recovery using infer_dim + sympy.solve, and evaluate dimension expressions via gen_sympy_index. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ivan Butygin --- examples/python/7.1_schedule.py | 8 +++- wave_lang/kernel/compiler/host_codegen.py | 37 +++++++++++++++++-- .../kernel/compiler/wave_codegen/emitter.py | 1 - .../compiler/wave_codegen/read_write.py | 1 + wave_lang/kernel/ops/wave_schedule_ops.py | 11 ++++-- wave_lang/kernel/wave/scheduling/schedule.py | 10 ++++- wave_lang/kernel/wave/unrolling.py | 10 +++-- 7 files changed, 63 insertions(+), 15 deletions(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 693f3af92a..1b740ba372 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -33,6 +33,7 @@ b_preshuffle, e8m0_shuffle, ) +import wave_lang.kernel.lang as tkl from wave_lang.kernel.lang.global_symbols import GLOBAL_ADDRESS_SPACE from utils import parse_args, list_tests, run_test @@ -254,8 +255,13 @@ def test_dbuf_4wave_mxfp_asymmetric_gemm_cpp( def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp( is_debug=False, shape=(1024, 1024, 8192), block=(128, 256, 256) ): - """Preshuffle-B MXFP4 GEMM using C++ WaveASM backend.""" + """Preshuffle-B MXFP4 GEMM using C++ WaveASM backend with dynamic M/N/K.""" gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4)) + # Make M, N, K dynamic so the compiler does not specialize on problem size. + dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in dynamic_symbols: + del options.subs[sym] + options.dynamic_symbols = dynamic_symbols options.backend = "asm" options.wave_runtime = True options.use_wave_asm_backend = True diff --git a/wave_lang/kernel/compiler/host_codegen.py b/wave_lang/kernel/compiler/host_codegen.py index 34fd8b5e7d..e742549d9c 100644 --- a/wave_lang/kernel/compiler/host_codegen.py +++ b/wave_lang/kernel/compiler/host_codegen.py @@ -25,7 +25,10 @@ tensor_d, ) +import sympy + from .._support.indexing import IndexSymbol +from ..wave.utils.general_utils import infer_dim from ...support.location_config import LocationCaptureConfig from .builder import ( ModuleBuilder, @@ -150,14 +153,30 @@ def isolated_test_call( argument_dims = get_dynamic_dims(host_sig.buffer_bindings, dynamic_symbols) # Map dynamic symbols to buffer argument indices and dimensions. + # For derived shapes like K/2, also store the inverse expression + # so we can recover K from the buffer dimension at runtime. arg_dim_mapping: dict[IndexSymbol, tuple[int, int]] = {} + # Maps symbol -> sympy expression to recover it from the dim value. + # For direct matches (M in shape[M, ...]) this is just a dummy d. + # For derived (K/2 in shape[M, K/2]) this is e.g. 2*d. + _dim_val = sympy.Symbol("_dim_val") + arg_dim_inverse: dict[IndexSymbol, sympy.Expr] = {} for arg_idx, b in enumerate(host_sig.buffer_bindings): shape = b.kernel_buffer_type.symbolic_shape - for dim_idx, dim_symbol in enumerate(shape): - if dim_symbol in arg_dim_mapping: + for dim_idx, dim_expr in enumerate(shape): + base_sym = infer_dim(dim_expr) + if base_sym in arg_dim_mapping: continue - - arg_dim_mapping[dim_symbol] = (arg_idx, dim_idx) + arg_dim_mapping[base_sym] = (arg_idx, dim_idx) + if dim_expr == base_sym: + arg_dim_inverse[base_sym] = _dim_val + else: + # Solve shape_expr = d for the base symbol. + solutions = sympy.solve(dim_expr - _dim_val, base_sym) + assert len(solutions) == 1, ( + f"Cannot solve {dim_expr} = _dim_val for {base_sym}" + ) + arg_dim_inverse[base_sym] = solutions[0] if async_dispatch: fence_type = IrType.parse("!hal.fence") @@ -217,6 +236,8 @@ def isolated_test_call( ] # Get the dynamic symbols values from the buffer dimensions. + # For derived shapes (K/2), apply the inverse expression to + # recover the original symbol value. dynamic_argument_map: dict[IndexSymbol, Value] = {} for symbol in dynamic_symbols: if symbol in arg_dim_mapping: @@ -338,6 +359,14 @@ def isolated_test_call( else: # If no device constraints, just dispatch the kernel directly # with the provided host signature arguments. + from .wave_codegen.emitter import gen_sympy_index as _gen + + def _resolve_dim(expr): + """Resolve a shape expression to an IR value.""" + if expr in dynamic_argument_map: + return dynamic_argument_map[expr] + return _gen(dynamic_argument_map, expr) + out = flow_d.DispatchOp( memref_to_tensor(output_types), [dynamic_argument_map[dim] for dim in dynamic_symbols] diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py index 30b494800c..397de51a6c 100644 --- a/wave_lang/kernel/compiler/wave_codegen/emitter.py +++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py @@ -836,7 +836,6 @@ def _rem(lhs, rhs): rem_expr(muli_expr(lhs, rhs.denominator), rhs.numerator), rhs.denominator, ) - return rem_expr(lhs, rhs) def _floordiv(lhs, rhs): diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 6a535a0b1e..5d3755d5cd 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -590,6 +590,7 @@ def extract(vec, ind): zero = get_constant_attr(0, element_type) zero = arith_d.constant(element_type, zero) + offset_th = None if mask is None: mask_vec_type = VectorType.get( [elements_per_thread], IntegerType.get_signless(1) diff --git a/wave_lang/kernel/ops/wave_schedule_ops.py b/wave_lang/kernel/ops/wave_schedule_ops.py index ea6a42bc0b..159a1549e4 100755 --- a/wave_lang/kernel/ops/wave_schedule_ops.py +++ b/wave_lang/kernel/ops/wave_schedule_ops.py @@ -648,14 +648,17 @@ def _reorder_subgraph( original_subgraph_name = custom_iterate.subgraph_name reordered_subgraph_name = f"reordered_{original_subgraph_name}" + # The iterate's owning graph holds the subgraph registration. + # For static shapes this is the root graph; for dynamic shapes + # it may be a conditional subgraph. + parent_graph = custom_iterate.graph + kernel_trace.add_subgraph(reordered_subgraph_name, reordered_subgraph) - kernel_trace.get_root_graph().subgraphs[ - reordered_subgraph_name - ] = reordered_subgraph + parent_graph.subgraphs[reordered_subgraph_name] = reordered_subgraph custom_iterate.update_arg("subgraph_name", reordered_subgraph_name) del kernel_trace.region_graph.subgraphs[original_subgraph_name] - del kernel_trace.get_root_graph().subgraphs[original_subgraph_name] + del parent_graph.subgraphs[original_subgraph_name] logger.info( f"Successfully reordered graph: {original_subgraph_name} -> {reordered_subgraph_name}" diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index 61e641ec1f..d97210e030 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -264,7 +264,7 @@ def build_guarded_pipeline_with_remainder( else: pipelined_iterations = (max_induction_variable // num_stages) * num_stages - conditional_body_graph, _ = graph_copy(reduction_graph) + conditional_body_graph, body_old_to_new = graph_copy(reduction_graph) placeholder_init_args = [placeholders[arg] for arg in reduction.init_args] placeholder_captures = [placeholders[cap] for cap in reduction.implicit_captures] @@ -314,6 +314,14 @@ def build_guarded_pipeline_with_remainder( multi_buffer_count, ) + # node_mapping keys are from the copied body graph. Translate them back + # to the original reduction_graph nodes so that + # _update_kernel_node_mapping can match tracked lists by identity. + new_to_old = {v: k for k, v in body_old_to_new.items()} + node_mapping = { + new_to_old.get(k, k): v for k, v in node_mapping.items() + } + # Set the count for the pipelined loop # With step > 1 (e.g., from unrolling), we need to reduce the count by more # to prevent out-of-bounds access. The last kernel iteration's stage 0 loads diff --git a/wave_lang/kernel/wave/unrolling.py b/wave_lang/kernel/wave/unrolling.py index 061f6b67e8..3d9977f9a3 100644 --- a/wave_lang/kernel/wave/unrolling.py +++ b/wave_lang/kernel/wave/unrolling.py @@ -7,6 +7,7 @@ from sympy import Integer from torch import fx +from wave_lang.kernel._support.indexing import is_literal from wave_lang.kernel._support.tracing import CapturedTrace from ..ops.wave_ops import ( @@ -70,10 +71,11 @@ def unroll( assert isinstance( iterate.count, int | Integer ), "Iteration count must be a statically determinable integer" - if iterate.count / unroll_factor < 1: - raise ValueError("Unroll factor is too large for the iteration count.") - if iterate.count % unroll_factor != 0: - raise ValueError("Unroll factor must divide the iteration count evenly.") + if is_literal(iterate.count): + if int(iterate.count) / unroll_factor < 1: + raise ValueError("Unroll factor is too large for the iteration count.") + if int(iterate.count) % unroll_factor != 0: + raise ValueError("Unroll factor must divide the iteration count evenly.") if iterate.condition is not None: raise ValueError("Unrolling is not supported for iterates with conditions.") From c6143773f6e42ca0bf730a47047c21560daaa7f9 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Thu, 26 Feb 2026 23:26:48 +0100 Subject: [PATCH 06/12] [asm] Add arith.divsi handler for power-of-2 divisors Uses V_ASHRREV_I32 (arithmetic right shift) for signed division by power-of-2 constants, mirroring the existing arith.divui handler which uses V_LSHRREV_B32 (logical right shift). Required for dynamic M/N/K support where K/2 and K/32 appear as arith.divsi operations. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ivan Butygin --- waveasm/lib/Transforms/TranslateFromMLIR.cpp | 2 ++ .../lib/Transforms/handlers/ArithHandlers.cpp | 34 +++++++++++++++++-- waveasm/lib/Transforms/handlers/Handlers.h | 2 ++ 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/waveasm/lib/Transforms/TranslateFromMLIR.cpp b/waveasm/lib/Transforms/TranslateFromMLIR.cpp index 37983f34d4..c7f36f4847 100644 --- a/waveasm/lib/Transforms/TranslateFromMLIR.cpp +++ b/waveasm/lib/Transforms/TranslateFromMLIR.cpp @@ -522,6 +522,7 @@ LogicalResult handleArithAddI(Operation *op, TranslationContext &ctx); LogicalResult handleArithSubI(Operation *op, TranslationContext &ctx); LogicalResult handleArithMulI(Operation *op, TranslationContext &ctx); LogicalResult handleArithDivUI(Operation *op, TranslationContext &ctx); +LogicalResult handleArithDivSI(Operation *op, TranslationContext &ctx); LogicalResult handleArithRemUI(Operation *op, TranslationContext &ctx); LogicalResult handleArithIndexCast(Operation *op, TranslationContext &ctx); LogicalResult handleArithAndI(Operation *op, TranslationContext &ctx); @@ -1593,6 +1594,7 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) { REGISTER_HANDLER(arith::SubIOp, handleArithSubI); REGISTER_HANDLER(arith::MulIOp, handleArithMulI); REGISTER_HANDLER(arith::DivUIOp, handleArithDivUI); + REGISTER_HANDLER(arith::DivSIOp, handleArithDivSI); REGISTER_HANDLER(arith::RemUIOp, handleArithRemUI); REGISTER_HANDLER(arith::IndexCastOp, handleArithIndexCast); diff --git a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp index 82b8e9e41d..ae3b47a78a 100644 --- a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp @@ -170,7 +170,7 @@ LogicalResult handleArithDivUI(Operation *op, TranslationContext &ctx) { return failure(); } - // Check if RHS is a power of 2 constant - use shift instead + // Check if RHS is a power of 2 constant - use shift instead. if (auto constOp = rhs->getDefiningOp()) { int64_t divisor = constOp.getValue(); if (isPowerOf2(divisor)) { @@ -184,12 +184,40 @@ LogicalResult handleArithDivUI(Operation *op, TranslationContext &ctx) { } } - // General case: non-power-of-2 division requires complex reciprocal sequence - // Emit an error rather than silently producing incorrect code + // General case: non-power-of-2 division requires complex reciprocal sequence. + // Emit an error rather than silently producing incorrect code. return op->emitError("unsigned integer division by non-power-of-2 is not " "yet implemented; divisor must be a power of 2"); } +LogicalResult handleArithDivSI(Operation *op, TranslationContext &ctx) { + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + auto vregType = ctx.createVRegType(); + + auto divOp = cast(op); + std::optional lhs, rhs; + if (failed(validateBinaryOperands(divOp, ctx, lhs, rhs))) + return failure(); + + // Power-of-2 constant divisor → arithmetic right shift. + if (auto constOp = rhs->getDefiningOp()) { + int64_t divisor = constOp.getValue(); + if (divisor > 0 && isPowerOf2(divisor)) { + int64_t shiftAmt = log2(divisor); + auto immShift = ctx.createImmType(shiftAmt); + auto shiftConst = ConstantOp::create(builder, loc, immShift, shiftAmt); + auto result = + V_ASHRREV_I32::create(builder, loc, vregType, shiftConst, *lhs); + ctx.getMapper().mapValue(divOp.getResult(), result); + return success(); + } + } + + return op->emitError("signed integer division by non-power-of-2 is not " + "yet implemented; divisor must be a positive power of 2"); +} + LogicalResult handleArithRemUI(Operation *op, TranslationContext &ctx) { auto &builder = ctx.getBuilder(); auto loc = op->getLoc(); diff --git a/waveasm/lib/Transforms/handlers/Handlers.h b/waveasm/lib/Transforms/handlers/Handlers.h index 8d6481c1dc..5cf8464997 100644 --- a/waveasm/lib/Transforms/handlers/Handlers.h +++ b/waveasm/lib/Transforms/handlers/Handlers.h @@ -71,6 +71,8 @@ mlir::LogicalResult handleArithMulI(mlir::Operation *op, TranslationContext &ctx); mlir::LogicalResult handleArithDivUI(mlir::Operation *op, TranslationContext &ctx); +mlir::LogicalResult handleArithDivSI(mlir::Operation *op, + TranslationContext &ctx); mlir::LogicalResult handleArithRemUI(mlir::Operation *op, TranslationContext &ctx); mlir::LogicalResult handleArithIndexCast(mlir::Operation *op, From c3b2254fa76328896d6dd2ba89562b78d0d78ce9 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 27 Feb 2026 01:07:32 +0100 Subject: [PATCH 07/12] Scalarize mask computation and index select for buffer ops path When use_buffer_ops is enabled, avoid emitting vector arith ops for bounds-check masks and OOB index selection. Instead: - _build_mask: new scalarize option builds per-element scalar cmpi and assembles the mask with vector.from_elements. - _create_vec_read_write: replace vector broadcast/addi/select with a scalar loop computing offset_th+i per element. - Enable use_buffer_ops in the dynamic-dims preshuffle-B test. This eliminates all vector<16xindex> ops from the dynamic-dims MLIR, which the WaveASM backend cannot translate. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ivan Butygin --- examples/python/7.1_schedule.py | 1 + .../compiler/wave_codegen/read_write.py | 96 ++++++++++--------- 2 files changed, 51 insertions(+), 46 deletions(-) diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index 1b740ba372..62b9408e52 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -262,6 +262,7 @@ def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp( for sym in dynamic_symbols: del options.subs[sym] options.dynamic_symbols = dynamic_symbols + options.use_buffer_ops = True options.backend = "asm" options.wave_runtime = True options.use_wave_asm_backend = True diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 5d3755d5cd..490c6f34c4 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -225,6 +225,7 @@ def _build_mask( elements_per_thread: int, bounds: Optional[dict[IndexSymbol, IndexExpr]], dynamic_values: dict[IndexExpr, Any] = {}, + scalarize: bool = False, ) -> Optional[OpResult]: if not bounds: return None @@ -234,6 +235,25 @@ def _build_mask( last_dim = list(index)[fastest_dim] new_index = {k: _get_start_index(v) for k, v in index.items()} + if scalarize: + # Build mask from per-element scalar comparisons to avoid vector ops. + subs = add_emitter_subs(emitter, dynamic_values) + base = new_index[last_dim] + i1 = IntegerType.get_signless(1) + bits = [] + for i in range(elements_per_thread): + new_index[last_dim] = base + i + elem_expr = functools.reduce( + lambda a, b: sympy.And(a, b), + (new_index[dim] < bound for dim, bound in bounds.items()), + ) + bits.append(gen_sympy_index(subs, elem_expr)) + new_index[last_dim] = base + mask_vec_type = VectorType.get( + [elements_per_thread], IntegerType.get_signless(1) + ) + return vector_d.from_elements(mask_vec_type, bits) + new_index[last_dim] = new_index[last_dim] + idxc.iota(elements_per_thread) mask_expr = functools.reduce( @@ -597,15 +617,6 @@ def extract(vec, ind): ) mask = _constant_mask(mask_vec_type) - # make offsets 0, 1, 2 ... - offsets_vec_type = VectorType.get(vector_type.shape, IndexType.get()) - vals = [IntegerAttr.get(IndexType.get(), v) for v in range(elements_per_thread)] - - offsets_vec = arith_d.constant( - offsets_vec_type, DenseElementsAttr.get(vals, offsets_vec_type) - ) - - offset_th = None if buffer_ops_enabled: mem, offset_th = _linearize_memref( mem, start_indices_wg, start_indices_th, strides @@ -621,49 +632,28 @@ def extract(vec, ind): ) if no_masked_load_store_ops: - # find the index at which memory out of bounds of buffer + # Out-of-bounds index causes hardware to return zero. oob_index_value = _get_out_of_bounds_index(element_type) oob_index = arith_d.constant(IndexType.get(), oob_index_value) - oob_index = vector_d.broadcast( - VectorType.get(vector_type.shape, IndexType.get()), oob_index - ) - - offset_th = vector_d.broadcast( - VectorType.get(vector_type.shape, IndexType.get()), offset_th - ) - - uint32_vec_type = VectorType.get([elements_per_thread], uint32) - indexvec_type = VectorType.get([elements_per_thread], IndexType.get()) - - offsets_vec = arith_d.index_cast(uint32_vec_type, offsets_vec) - offset_th = arith_d.index_cast(uint32_vec_type, offset_th) - - # add the thread offset and the vec offsets - offsets_vec = arith_d.addi(offsets_vec, offset_th) - offsets_vec = arith_d.index_cast(indexvec_type, offsets_vec) - - # based on mask, select between the offsets_vec and out of bounds. In this case all 3 operands can be vectors - selected_index = arith_d.select(mask, offsets_vec, oob_index) - elems = list() - if splatted_mask: - # mask is same for all of them, can just pick the first index - selected_index = extract(selected_index, 0) - + # Mask is uniform — select once and do a single vector load/store. + selected_index = arith_d.select(mask_splat, offset_th, oob_index) if is_read: return vector_d.load(vector_type, mem, indices=[selected_index]) - else: vector_d.store(value, mem, indices=[selected_index]) return + # Per-element scalar index computation avoids vector broadcasts. + elems = list() + singlenumvec_type = VectorType.get([1], vector_type.element_type) for i in range(elements_per_thread): - # mask is not same for all elements, need to unroll - this_index = extract(selected_index, i) # this element + i_const = arith_d.constant(IndexType.get(), i) + elem_offset = arith_d.addi(offset_th, i_const) + mask_bit = extract(mask, i) + this_index = arith_d.select(mask_bit, elem_offset, oob_index) - # Unmasked load, using selected_index - singlenumvec_type = VectorType.get([1], vector_type.element_type) if is_read: elem = vector_d.load(singlenumvec_type, mem, indices=[this_index]) elem = extract(elem, 0) @@ -674,10 +664,8 @@ def extract(vec, ind): vector_d.store(single_num_vector, mem, indices=[this_index]) if is_read: - # now make a vector from all the elements loaded return vector_d.from_elements(vector_type, elems) - - else: # it was a store, return + else: return else: @@ -700,6 +688,7 @@ def _build_mask_with_mapping( elements_per_thread: int, bounds: Optional[tuple[IndexSymbol, ...]], dynamic_vals_map: dict[IndexExpr, Value], + scalarize: bool = False, ) -> Optional[Value]: """ Build a mask for read/write operations, when a mapping is used @@ -731,9 +720,12 @@ def _build_mask_with_mapping( elements_per_thread, bounds, dynamic_vals_map, + scalarize=scalarize, ) else: - return _build_mask(emitter, index, elements_per_thread, bounds) + return _build_mask( + emitter, index, elements_per_thread, bounds, scalarize=scalarize + ) @handle_op(read) @@ -762,6 +754,9 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): ) dynamic_vals_map_start = _build_dyn_vals_map(mapping, dyn_vals) + is_global_mem = kb_ir_type.memory_space is None + scalarize_mask = emitter.options.use_buffer_ops and is_global_mem + if mapping: transformed_index = transform_index_on_mapping( mapping, input_shape, index, is_read=True @@ -775,10 +770,13 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): elements_per_thread, bounds, dynamic_vals_map_start, + scalarize=scalarize_mask, ) index = transformed_index else: - mask = _build_mask(emitter, index, elements_per_thread, bounds) + mask = _build_mask( + emitter, index, elements_per_thread, bounds, scalarize=scalarize_mask + ) start_indices, start_indices_wg, start_indices_th = _build_start_indices( emitter, index, dynamic_vals_map_start @@ -853,6 +851,9 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): dynamic_vals_map_start = _build_dyn_vals_map(mapping, dyn_vals) element_type = kb_ir_type.element_type + is_global_mem = kb_ir_type.memory_space is None + scalarize_mask = emitter.options.use_buffer_ops and is_global_mem + if mapping: transformed_index = transform_index_on_mapping( mapping, output_shape, index, is_read=False @@ -866,10 +867,13 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): elements_per_thread, bounds, dynamic_vals_map_start, + scalarize=scalarize_mask, ) index = transformed_index else: - mask = _build_mask(emitter, index, elements_per_thread, bounds) + mask = _build_mask( + emitter, index, elements_per_thread, bounds, scalarize=scalarize_mask + ) start_indices, start_indices_wg, start_indices_th = _build_start_indices( emitter, index, dynamic_vals_map_start From 9ea75add0864dcd8efb7c871c01c207780967948 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 27 Feb 2026 01:41:53 +0100 Subject: [PATCH 08/12] [asm] Add vector.from_elements handler and fix sub-dword memref.load - Add handleVectorFromElements: packs sub-dword scalars into VGPRs using V_LSHL_OR_B32 chains, combines DWORDs with PackOp. - Fix handleMemRefLoad to emit buffer_load_ubyte / ds_read_u8 for i8 element types instead of unconditionally emitting dword loads. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ivan Butygin --- waveasm/lib/Transforms/TranslateFromMLIR.cpp | 2 + waveasm/lib/Transforms/handlers/Handlers.h | 2 + .../Transforms/handlers/MemRefHandlers.cpp | 39 +++++++++++---- .../Transforms/handlers/VectorHandlers.cpp | 50 +++++++++++++++++++ 4 files changed, 82 insertions(+), 11 deletions(-) diff --git a/waveasm/lib/Transforms/TranslateFromMLIR.cpp b/waveasm/lib/Transforms/TranslateFromMLIR.cpp index c7f36f4847..08b5197935 100644 --- a/waveasm/lib/Transforms/TranslateFromMLIR.cpp +++ b/waveasm/lib/Transforms/TranslateFromMLIR.cpp @@ -562,6 +562,7 @@ LogicalResult handleVectorShapeCast(Operation *op, TranslationContext &ctx); LogicalResult handleVectorBitCast(Operation *op, TranslationContext &ctx); LogicalResult handleVectorFma(Operation *op, TranslationContext &ctx); LogicalResult handleVectorReduction(Operation *op, TranslationContext &ctx); +LogicalResult handleVectorFromElements(Operation *op, TranslationContext &ctx); LogicalResult handleVectorExtractStridedSlice(Operation *op, TranslationContext &ctx); @@ -1652,6 +1653,7 @@ void OpHandlerRegistry::registerDefaultHandlers(mlir::MLIRContext *ctx) { REGISTER_HANDLER(vector::TransferWriteOp, handleVectorTransferWrite); REGISTER_HANDLER(vector::FMAOp, handleVectorFma); REGISTER_HANDLER(vector::ReductionOp, handleVectorReduction); + REGISTER_HANDLER(vector::FromElementsOp, handleVectorFromElements); // AMDGPU dialect REGISTER_HANDLER(amdgpu::LDSBarrierOp, handleAMDGPULdsBarrier); diff --git a/waveasm/lib/Transforms/handlers/Handlers.h b/waveasm/lib/Transforms/handlers/Handlers.h index 5cf8464997..70c09d869a 100644 --- a/waveasm/lib/Transforms/handlers/Handlers.h +++ b/waveasm/lib/Transforms/handlers/Handlers.h @@ -157,6 +157,8 @@ mlir::LogicalResult handleVectorFma(mlir::Operation *op, TranslationContext &ctx); mlir::LogicalResult handleVectorReduction(mlir::Operation *op, TranslationContext &ctx); +mlir::LogicalResult handleVectorFromElements(mlir::Operation *op, + TranslationContext &ctx); //===----------------------------------------------------------------------===// // AMDGPU Dialect Handlers diff --git a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp index 3243f449f1..026c98a3ca 100644 --- a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp @@ -196,31 +196,40 @@ LogicalResult handleMemRefLoad(Operation *op, TranslationContext &ctx) { auto memrefType = loadOp.getMemRefType(); auto vregType = ctx.createVRegType(); + Type elemType = memrefType.getElementType(); + int64_t elemBits = elemType.getIntOrFloatBitWidth(); + if (isLDSMemRef(memrefType)) { - // LDS load + // LDS load. Value vaddr; if (!loadOp.getIndices().empty()) { - if (auto mapped = ctx.getMapper().getMapped(loadOp.getIndices()[0])) { + if (auto mapped = ctx.getMapper().getMapped(loadOp.getIndices()[0])) vaddr = *mapped; - } } if (!vaddr) { auto immType = ctx.createImmType(0); vaddr = ConstantOp::create(builder, loc, immType, 0); } - auto readOp = DS_READ_B32::create(builder, loc, TypeRange{vregType}, vaddr); - ctx.getMapper().mapValue(loadOp.getResult(), readOp.getResult(0)); + Value result; + if (elemBits <= 8) + result = + DS_READ_U8::create(builder, loc, TypeRange{vregType}, vaddr) + .getResult(0); + else + result = + DS_READ_B32::create(builder, loc, TypeRange{vregType}, vaddr) + .getResult(0); + ctx.getMapper().mapValue(loadOp.getResult(), result); } else { - // Global load + // Global load. auto sregType = ctx.createSRegType(4, 4); auto srd = PrecoloredSRegOp::create(builder, loc, sregType, 8, 4); Value voffset; if (!loadOp.getIndices().empty()) { - if (auto mapped = ctx.getMapper().getMapped(loadOp.getIndices()[0])) { + if (auto mapped = ctx.getMapper().getMapped(loadOp.getIndices()[0])) voffset = *mapped; - } } if (!voffset) { auto immType = ctx.createImmType(0); @@ -229,9 +238,17 @@ LogicalResult handleMemRefLoad(Operation *op, TranslationContext &ctx) { auto zeroImm = builder.getType(0); auto zeroConst = ConstantOp::create(builder, loc, zeroImm, 0); - auto loadInstr = BUFFER_LOAD_DWORD::create( - builder, loc, TypeRange{vregType}, srd, voffset, zeroConst); - ctx.getMapper().mapValue(loadOp.getResult(), loadInstr.getResult(0)); + + Value result; + if (elemBits <= 8) + result = BUFFER_LOAD_UBYTE::create(builder, loc, TypeRange{vregType}, + srd, voffset, zeroConst) + .getResult(0); + else + result = BUFFER_LOAD_DWORD::create(builder, loc, TypeRange{vregType}, + srd, voffset, zeroConst) + .getResult(0); + ctx.getMapper().mapValue(loadOp.getResult(), result); } return success(); diff --git a/waveasm/lib/Transforms/handlers/VectorHandlers.cpp b/waveasm/lib/Transforms/handlers/VectorHandlers.cpp index 03f74f122a..1bf834440c 100644 --- a/waveasm/lib/Transforms/handlers/VectorHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/VectorHandlers.cpp @@ -271,4 +271,54 @@ LogicalResult handleVectorExtractStridedSlice(Operation *op, return success(); } +LogicalResult handleVectorFromElements(Operation *op, TranslationContext &ctx) { + auto fromElemOp = cast(op); + auto &builder = ctx.getBuilder(); + auto loc = op->getLoc(); + + auto resultType = fromElemOp.getType(); + int64_t numElems = resultType.getNumElements(); + int64_t elemBitWidth = resultType.getElementType().getIntOrFloatBitWidth(); + + // Number of sub-elements that pack into a single 32-bit VGPR. + int64_t elemsPerDword = 32 / elemBitWidth; + int64_t numDwords = (numElems + elemsPerDword - 1) / elemsPerDword; + + SmallVector dwords; + for (int64_t d = 0; d < numDwords; ++d) { + Value packed; + for (int64_t e = 0; e < elemsPerDword; ++e) { + int64_t idx = d * elemsPerDword + e; + if (idx >= numElems) + break; + + auto elem = ctx.getMapper().getMapped(fromElemOp.getElements()[idx]); + if (!elem) + return op->emitError("element not mapped at index ") << idx; + + if (e == 0) { + packed = *elem; + } else { + // Pack into the next byte/halfword lane via shift-or. + int64_t shiftAmt = e * elemBitWidth; + auto shiftImm = ConstantOp::create(builder, loc, + ctx.createImmType(shiftAmt), + shiftAmt); + packed = V_LSHL_OR_B32::create(builder, loc, ctx.createVRegType(), + *elem, shiftImm, packed); + } + } + dwords.push_back(packed); + } + + if (numDwords == 1) { + ctx.getMapper().mapValue(fromElemOp.getResult(), dwords[0]); + } else { + auto resultVRegType = ctx.createVRegType(numDwords, numDwords); + auto pack = PackOp::create(builder, loc, resultVRegType, dwords); + ctx.getMapper().mapValue(fromElemOp.getResult(), pack); + } + return success(); +} + } // namespace waveasm From c23546f69784701305353db6c22fd4dcb0e62ae5 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 2 Mar 2026 16:35:48 +0100 Subject: [PATCH 09/12] c++ fixes Signed-off-by: Ivan Butygin --- .../lib/Transforms/handlers/ArithHandlers.cpp | 5 +++-- .../lib/Transforms/handlers/MemRefHandlers.cpp | 18 ++++++++---------- .../lib/Transforms/handlers/VectorHandlers.cpp | 5 ++--- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp index ae3b47a78a..2909ad805a 100644 --- a/waveasm/lib/Transforms/handlers/ArithHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/ArithHandlers.cpp @@ -214,8 +214,9 @@ LogicalResult handleArithDivSI(Operation *op, TranslationContext &ctx) { } } - return op->emitError("signed integer division by non-power-of-2 is not " - "yet implemented; divisor must be a positive power of 2"); + return op->emitError( + "signed integer division by non-power-of-2 is not " + "yet implemented; divisor must be a positive power of 2"); } LogicalResult handleArithRemUI(Operation *op, TranslationContext &ctx) { diff --git a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp index 026c98a3ca..3c6d3fa939 100644 --- a/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/MemRefHandlers.cpp @@ -213,13 +213,11 @@ LogicalResult handleMemRefLoad(Operation *op, TranslationContext &ctx) { Value result; if (elemBits <= 8) - result = - DS_READ_U8::create(builder, loc, TypeRange{vregType}, vaddr) - .getResult(0); + result = DS_READ_U8::create(builder, loc, TypeRange{vregType}, vaddr) + .getResult(0); else - result = - DS_READ_B32::create(builder, loc, TypeRange{vregType}, vaddr) - .getResult(0); + result = DS_READ_B32::create(builder, loc, TypeRange{vregType}, vaddr) + .getResult(0); ctx.getMapper().mapValue(loadOp.getResult(), result); } else { // Global load. @@ -241,12 +239,12 @@ LogicalResult handleMemRefLoad(Operation *op, TranslationContext &ctx) { Value result; if (elemBits <= 8) - result = BUFFER_LOAD_UBYTE::create(builder, loc, TypeRange{vregType}, - srd, voffset, zeroConst) + result = BUFFER_LOAD_UBYTE::create(builder, loc, TypeRange{vregType}, srd, + voffset, zeroConst) .getResult(0); else - result = BUFFER_LOAD_DWORD::create(builder, loc, TypeRange{vregType}, - srd, voffset, zeroConst) + result = BUFFER_LOAD_DWORD::create(builder, loc, TypeRange{vregType}, srd, + voffset, zeroConst) .getResult(0); ctx.getMapper().mapValue(loadOp.getResult(), result); } diff --git a/waveasm/lib/Transforms/handlers/VectorHandlers.cpp b/waveasm/lib/Transforms/handlers/VectorHandlers.cpp index 1bf834440c..8d9c979770 100644 --- a/waveasm/lib/Transforms/handlers/VectorHandlers.cpp +++ b/waveasm/lib/Transforms/handlers/VectorHandlers.cpp @@ -301,9 +301,8 @@ LogicalResult handleVectorFromElements(Operation *op, TranslationContext &ctx) { } else { // Pack into the next byte/halfword lane via shift-or. int64_t shiftAmt = e * elemBitWidth; - auto shiftImm = ConstantOp::create(builder, loc, - ctx.createImmType(shiftAmt), - shiftAmt); + auto shiftImm = ConstantOp::create( + builder, loc, ctx.createImmType(shiftAmt), shiftAmt); packed = V_LSHL_OR_B32::create(builder, loc, ctx.createVRegType(), *elem, shiftImm, packed); } From f6b09d0ea6180e641ad10676cbb02b89b5c239d6 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 2 Mar 2026 16:36:25 +0100 Subject: [PATCH 10/12] simplify ranges Signed-off-by: Ivan Butygin --- tests/kernel/wave_utils_test.py | 14 +- tests/unittests/symbol_utils_test.py | 31 +++ .../compiler/wave_codegen/read_write.py | 20 +- .../analysis/partition_strided_operators.py | 28 ++- wave_lang/kernel/wave/utils/mapping_utils.py | 139 +------------ wave_lang/kernel/wave/utils/symbol_utils.py | 193 ++++++++++++++++-- 6 files changed, 247 insertions(+), 178 deletions(-) diff --git a/tests/kernel/wave_utils_test.py b/tests/kernel/wave_utils_test.py index af9dba71e4..031b34e01f 100644 --- a/tests/kernel/wave_utils_test.py +++ b/tests/kernel/wave_utils_test.py @@ -11,9 +11,7 @@ delinearize_index, divide_shape_into_chunks, ) -from wave_lang.kernel.wave.utils.mapping_utils import ( - _simplify_sympy_expr, -) +from wave_lang.kernel.wave.utils.symbol_utils import simplify from wave_lang.kernel.wave.constraints import MMAType from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel @@ -111,10 +109,10 @@ def test_divide_shape_into_chunks(): def test_custom_sympy_simplifications(): a = sympy.Symbol("a", integer=True, nonnegative=True) mod_expr = (sympy.floor(a) * 4 + 3) % 16 - assert str(_simplify_sympy_expr(mod_expr)) == "4*(Mod(a, 4)) + 3" + assert str(simplify(mod_expr)) == "4*(Mod(a, 4)) + 3" floor_expr = sympy.floor(sympy.floor(a) / 3 + sympy.sympify(1) / 6) - assert str(_simplify_sympy_expr(floor_expr)) == "floor(a/3)" + assert str(simplify(floor_expr)) == "floor(a/3)" @pytest.mark.skip("Too slow") @@ -139,7 +137,7 @@ def test_fuzz_custom_sympy_simplifications_mod(): expr = expr.subs({a: vals[0], b: vals[1], c: vals[2]}) expr = sympy.simplify(expr) - expr2 = _simplify_sympy_expr(expr) + expr2 = simplify(expr) if i % 50 == 0 and i > 0: print(f"{100*i/outer_num_iters}%") @@ -453,7 +451,7 @@ def check_specific(*vals): expr1 = orig_expr.subs({a: vals[0], b: vals[1], c: vals[2], d: vals[3]}) expr1 = sympy.simplify(expr1) - expr2 = _simplify_sympy_expr(expr1) + expr2 = simplify(expr1) assert expr1.subs({x: vals[4]}) == expr2.subs({x: vals[4]}) check_specific(10, 11, 6, 10, 6) @@ -477,7 +475,7 @@ def check_specific(*vals): expr = orig_expr.subs({a: vals[0], b: vals[1], c: vals[2], d: vals[3]}) expr = sympy.simplify(expr) - expr2 = _simplify_sympy_expr(expr) + expr2 = simplify(expr) if expr != expr2: break diff --git a/tests/unittests/symbol_utils_test.py b/tests/unittests/symbol_utils_test.py index dc2d80b553..72a05909e4 100644 --- a/tests/unittests/symbol_utils_test.py +++ b/tests/unittests/symbol_utils_test.py @@ -108,6 +108,37 @@ def test_bounds_unsupported_returns_none(): assert expr_bounds(x**2) is None +def test_bounds_ceiling(): + x = _sym("x") + inner = sympy.Mod(x, 16, evaluate=False) / 16 + # ceiling([0, 15/16]) = (0, 1). + assert expr_bounds(sympy.ceiling(inner)) == (0, 1) + + +def test_bounds_piecewise(): + x = _sym("x") + pw = sympy.Piecewise( + (sympy.Mod(x, 4, evaluate=False), x > 10), + (sympy.Integer(5), True), + ) + # Branch 0: [0, 3], branch 1: [5, 5] → envelope [0, 5]. + assert expr_bounds(pw) == (0, 5) + + +def test_bounds_max(): + x = _sym("x") + a = sympy.Mod(x, 4, evaluate=False) # [0, 3] + b = sympy.Mod(x, 8, evaluate=False) # [0, 7] + assert expr_bounds(sympy.Max(a, b)) == (0, 7) + + +def test_bounds_min(): + x = _sym("x") + a = sympy.Mod(x, 4, evaluate=False) # [0, 3] + b = sympy.Mod(x, 8, evaluate=False) # [0, 7] + assert expr_bounds(sympy.Min(a, b)) == (0, 3) + + # ---- simplify tests ---- diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 490c6f34c4..94c1bc5d2a 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -62,7 +62,7 @@ ) from ...wave.utils.general_utils import get_fastest_index, infer_dim, linearize_index from ...wave.utils.mapping_utils import transform_index_on_mapping -from ...wave.utils.symbol_utils import safe_subs +from ...wave.utils.symbol_utils import safe_subs, simplify from .emitter import ( WaveEmitter, add_emitter_subs, @@ -95,16 +95,6 @@ def _get_start_indices( return start_indices -@functools.lru_cache -def _simplify(expr): - """ - Simple wrapper around simplify in order to utilize LRU Cache. - This is important to minimize compile time caused by re-simplifying - expressions. - """ - return sympy.simplify(expr) - - def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]: """ Split index expr into thread-dependent and thread-independent parts @@ -116,7 +106,7 @@ def _split_index(src: IndexExpr | int) -> tuple[IndexExpr, IndexExpr]: # Compute thread-independent index as `orig_index - thread_dependent_index` # All thread symbols and dynamic should cancel-out in the result. - thread_independent_index = _simplify(src - thread_dependent_index) + thread_independent_index = simplify(src - thread_dependent_index) if thread_independent_index.free_symbols - set(subs_wg.keys()): # If we have any symbols besides wg symbols, means some thread or # dynamic symbols were not canceled out, use the entire index as @@ -149,16 +139,16 @@ def _split_index_three_way( no_wg = safe_subs(src, subs_wg) # Uniform part = no_wg - thread_only (induction-var-dependent terms). - uniform_index = _simplify(no_wg - thread_index) + uniform_index = simplify(no_wg - thread_index) # WG part = src - no_wg. - wg_index = _simplify(src - no_wg) + wg_index = simplify(src - no_wg) # Validate WG part contains only WG symbols. if wg_index.free_symbols - set(subs_wg.keys()): wg_index = sympy.sympify(0) no_wg = src - uniform_index = _simplify(no_wg - thread_index) + uniform_index = simplify(no_wg - thread_index) # Validate uniform part has no thread-dependent symbols. thread_syms = {THREAD_0, THREAD_1, THREAD_2} diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index bda0a39083..9b3b5b2170 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -400,6 +400,24 @@ def has_gpr_offsets(node: fx.Node) -> bool: custom.graph.erase_node(custom.fx_node) +def _symbol_ranges_from_constraints( + constraints: list[Constraint], +) -> dict[sympy.Symbol, tuple[int, int]]: + """Build symbol→(lo, hi) ranges for thread IDs from constraints.""" + hw = get_hardware_constraint(constraints) + tpw = hw.threads_per_wave + wpb = hw.waves_per_block + ranges: dict[sympy.Symbol, tuple[int, int]] = {} + if wpb is not None: + ranges[THREAD_0] = (sympy.Integer(0), sympy.Integer(tpw * wpb[0] - 1)) + ranges[THREAD_1] = (sympy.Integer(0), sympy.Integer(wpb[1] - 1)) + ranges[THREAD_2] = (sympy.Integer(0), sympy.Integer(wpb[2] - 1)) + else: + # Still know THREAD_0 ∈ [0, tpw - 1] at minimum. + ranges[THREAD_0] = (sympy.Integer(0), sympy.Integer(tpw - 1)) + return ranges + + def merge_contiguous_reads( trace: CapturedTrace, constraints: list[Constraint], target: str ): @@ -414,7 +432,8 @@ def merge_contiguous_reads( physical flat offset starts differ by exactly ept are merged. """ hw_constraint = get_hardware_constraint(constraints) - while _merge_contiguous_reads_once(trace, hw_constraint): + ranges = _symbol_ranges_from_constraints(constraints) + while _merge_contiguous_reads_once(trace, hw_constraint, ranges): pass @@ -443,7 +462,7 @@ def _get_physical_start( return {dim: custom.index[dim].start for dim in symbolic_dims} -def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: +def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint, ranges) -> bool: """Single merge pass: merge adjacent pairs of same-ept reads. Groups reads by (memory operand, ept) and merges pairs whose physical @@ -512,6 +531,7 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: for j in range(i + 1, len(read_infos)): if j in merged: continue + off1, phys1, custom1, node1 = read_infos[i] off2, phys2, custom2, node2 = read_infos[j] @@ -533,7 +553,7 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: if diff is None: continue else: - diff = sym_simplify(raw_diff) + diff = sym_simplify(raw_diff, ranges) if diff != ept and diff != -ept: nv = _numeric_eval_constant(raw_diff) if nv is not None: @@ -562,7 +582,7 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: merge_dim = None break else: - d = sym_simplify(raw_d) + d = sym_simplify(raw_d, ranges) if d != ept and d != 0: nv = _numeric_eval_constant(raw_d) if nv is not None: diff --git a/wave_lang/kernel/wave/utils/mapping_utils.py b/wave_lang/kernel/wave/utils/mapping_utils.py index 24ed506463..eaba428935 100644 --- a/wave_lang/kernel/wave/utils/mapping_utils.py +++ b/wave_lang/kernel/wave/utils/mapping_utils.py @@ -11,7 +11,7 @@ from ..._support.indexing import IndexingContext from ...lang.wave_types import IndexMapping from .general_utils import infer_dim, get_fastest_index -from .symbol_utils import IndexExpr, IndexSequence, IndexSymbol, subs_idxc +from .symbol_utils import IndexExpr, IndexSequence, IndexSymbol, simplify, subs_idxc from ...compiler.utils import strides_from_symbolic_shape K = TypeVar("K") # Key type @@ -43,141 +43,6 @@ def get_dict_with_updated_key( return new_dict -def _simplify_sympy_expr(expr: IndexExpr) -> IndexExpr: - """Apply custom sympy simplifications""" - - def check_mul(mul): - ret = None - for arg in mul.args: - if arg.is_number: - if arg < 0: - return None - - if ret is not None: - return None - - ret = arg - continue - - if not (isinstance(arg, (sympy.floor, sympy.Mod)) or arg.is_integer): - return None - - if not arg.is_nonnegative: - return None - - return ret - - def transform_mod(expr): - """Move constant outside of Mod expr - - Example: - (floor(a) * 4 + 3) % 16 -> (floor(a) * 4) % 16 + 3 - """ - if not isinstance(expr, sympy.Mod): - return None - - p, q = expr.args - if not q.is_number or q < 0: - return None - - if not isinstance(p, sympy.Add): - return None - - c = None - terms = [] - mult = None - for arg in p.args: - if arg.is_number: - if c is not None: - return None - - c = arg - continue - - if not isinstance(arg, sympy.Mul): - return None - - m = check_mul(arg) - if (m is None) or (q % m != 0): - return None - - mult = m if (mult is None) or (m < mult) else mult - terms.append(arg) - - if c >= mult: - return None - - return (sum(terms) % q) + c - - def check_mul_rational(mul): - ret = None - for arg in mul.args: - if isinstance(arg, sympy.Rational): - if ret is not None: - return None - - if arg.p < 0 or arg.q < 0: - return None - - ret = arg - continue - - if not (isinstance(arg, (sympy.floor, sympy.Mod)) or arg.is_integer): - return None - - if not arg.is_nonnegative: - return None - - return ret - - def transform_floor(expr): - """Simplify rational addition inside floor expr - - Example: - floor(floor(a)/3 + 1/6) -> floor(floor(a)/3) - """ - if not isinstance(expr, sympy.floor): - return None - - expr = expr.args[0] - if not isinstance(expr, sympy.Add): - return None - - c = None - for arg in expr.args: - if isinstance(arg, sympy.Rational): - if c is not None: - return None - - c = arg - - if c is None: - return None - - terms = [] - for arg in expr.args: - if isinstance(arg, sympy.Rational): - continue - - if not isinstance(arg, sympy.Mul): - return None - - r = check_mul_rational(arg) - if r is None or r.p != 1: - return None - - if r <= c: - return None - - terms.append(arg) - - return sympy.floor(sum(terms)) - - expr = expr.replace(lambda e: transform_mod(e) is not None, transform_mod) - expr = expr.replace(lambda e: transform_floor(e) is not None, transform_floor) - return sympy.simplify(expr) - - def approximate_difference( expr: IndexExpr, vars: list[IndexSymbol], elements_per_thread: int ) -> bool: @@ -379,7 +244,7 @@ def _make_aligned_index(offset_val: int) -> dict[IndexExpr, IndexSequence]: [new_index[infer_dim(d)] for d in symbolic_shape], strides, ) - diff_expr = _simplify_sympy_expr(offset - prev_offset) + diff_expr = simplify(offset - prev_offset) if diff_expr != 1: return False prev_offset = offset diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 01e6d777f1..b045c33f79 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -63,39 +63,81 @@ #################################################################### +# Ranges type: tuple of (symbol, (lo, hi)) pairs. Hashable for lru_cache. +SymbolRanges = tuple[tuple[sympy.Symbol, tuple[sympy.Expr, sympy.Expr]], ...] + + +def _lookup_range( + sym: sympy.Symbol, ranges: SymbolRanges = () +) -> tuple[sympy.Expr, sympy.Expr] | None: + """Find bounds for *sym* in *ranges*, or return None.""" + for s, bounds in ranges: + if s == sym: + return bounds + return None + + @lru_cache(maxsize=1024) -def expr_bounds(expr: sympy.Expr) -> tuple[sympy.Expr, sympy.Expr] | None: +def expr_bounds( + expr: sympy.Expr, + ranges: SymbolRanges = (), +) -> tuple[sympy.Expr, sympy.Expr] | None: """Compute (lo, hi) bounds for a sympy expression via interval arithmetic. - Free symbols are assumed to be non-negative integers (hardware indices). - Returns (lo, hi) or None if bounds cannot be determined. + Free symbols default to [0, ∞) (hardware indices). Pass *ranges* as a + tuple of ``(symbol, (lo, hi))`` pairs to supply tighter bounds. + Returns ``(lo, hi)`` or ``None`` if bounds cannot be determined. """ if expr.is_Integer or expr.is_Rational: return (expr, expr) if expr.is_Symbol: + r = _lookup_range(expr, ranges) + if r is not None: + return r return (sympy.Integer(0), sympy.oo) if expr.is_nonnegative else None if isinstance(expr, sympy.Mod): p, q = expr.args if q.is_positive and q.is_number: - p_bounds = expr_bounds(p) + p_bounds = expr_bounds(p, ranges) if p_bounds and p_bounds[0] >= 0 and p_bounds[1] < q: return p_bounds return (sympy.Integer(0), q - 1) return None if isinstance(expr, sympy.floor): - inner_bounds = expr_bounds(expr.args[0]) + inner_bounds = expr_bounds(expr.args[0], ranges) if inner_bounds: return (sympy.floor(inner_bounds[0]), sympy.floor(inner_bounds[1])) return None + if isinstance(expr, sympy.ceiling): + inner_bounds = expr_bounds(expr.args[0], ranges) + if inner_bounds: + return (sympy.ceiling(inner_bounds[0]), sympy.ceiling(inner_bounds[1])) + return None + if isinstance(expr, sympy.Piecewise): + # Envelope of all branches — any branch could be active. + branch_bounds = [expr_bounds(val, ranges) for val, _ in expr.args] + if all(b is not None for b in branch_bounds): + return (min(b[0] for b in branch_bounds), max(b[1] for b in branch_bounds)) + return None + if isinstance(expr, sympy.Max): + bounds = [expr_bounds(a, ranges) for a in expr.args] + if all(b is not None for b in bounds): + return (max(b[0] for b in bounds), max(b[1] for b in bounds)) + return None + if isinstance(expr, sympy.Min): + bounds = [expr_bounds(a, ranges) for a in expr.args] + if all(b is not None for b in bounds): + return (min(b[0] for b in bounds), min(b[1] for b in bounds)) + return None if isinstance(expr, sympy.Add): - bounds = [expr_bounds(a) for a in expr.args] + bounds = [expr_bounds(a, ranges) for a in expr.args] if all(b is not None for b in bounds): return (sum(b[0] for b in bounds), sum(b[1] for b in bounds)) return None if isinstance(expr, sympy.Mul): if not expr.args: return (sympy.Integer(1), sympy.Integer(1)) - bounds = [expr_bounds(a) for a in expr.args] + bounds = [expr_bounds(a, ranges) for a in expr.args] if all(b is not None for b in bounds): # Bail out if any bound is infinite (0 * oo = NaN). if any(sympy.oo in b or -sympy.oo in b for b in bounds): @@ -109,18 +151,31 @@ def expr_bounds(expr: sympy.Expr) -> tuple[sympy.Expr, sympy.Expr] | None: return None -@lru_cache(maxsize=1024) -def simplify(expr: sympy.Expr) -> sympy.Expr: +def simplify( + expr: sympy.Expr, + ranges: dict[sympy.Symbol, tuple[sympy.Expr, sympy.Expr]] | None = None, +) -> sympy.Expr: """Simplify a sympy expression using interval arithmetic and sympy.simplify. Extends sympy.simplify with bounds-based reasoning that can resolve floor/Mod sub-expressions (e.g. floor(Mod(x,16)/16) -> 0) that standard sympy cannot handle. Iterates to a fixed point. + + Pass *ranges* as ``{symbol: (lo, hi)}`` to supply tighter bounds than + the default [0, ∞). """ + print(f"simplify: {expr}") + frozen: SymbolRanges = tuple(ranges.items()) if ranges else () + return _simplify_impl(subs_idxc(expr), frozen) + + +@lru_cache(maxsize=1024) +def _simplify_impl(expr: sympy.Expr, ranges: SymbolRanges) -> sympy.Expr: if not isinstance(expr, sympy.Basic): return expr for _ in range(5): - new_expr = _bounds_simplify_once(expr) + new_expr = _algebraic_simplify(expr) + new_expr = _bounds_simplify_once(new_expr, ranges) new_expr = sympy.simplify(new_expr) if new_expr == expr: break @@ -128,7 +183,108 @@ def simplify(expr: sympy.Expr) -> sympy.Expr: return expr -def _bounds_simplify_once(expr: sympy.Expr) -> sympy.Expr: +@lru_cache(maxsize=1024) +def _algebraic_simplify(expr: sympy.Expr) -> sympy.Expr: + """Algebraic rewrites for Mod and floor that sympy misses. + + - ``(floor(a)*k + c) % m -> (floor(a)*k) % m + c`` + when ``k | m`` and ``0 <= c < k`` (pulls constant out of Mod). + - ``floor(floor(a)/q + c) -> floor(floor(a)/q)`` + when ``c < 1/q`` (drops negligible rational offset from floor). + """ + + def _check_mul_nonneg_int(mul): + """Return the numeric factor of *mul* if all factors are nonneg integer-ish.""" + ret = None + for arg in mul.args: + if arg.is_number: + if arg < 0: + return None + if ret is not None: + return None + ret = arg + continue + if not (isinstance(arg, (sympy.floor, sympy.Mod)) or arg.is_integer): + return None + if not arg.is_nonnegative: + return None + return ret + + def _transform_mod(e): + if not isinstance(e, sympy.Mod): + return None + p, q = e.args + if not q.is_number or q < 0 or not isinstance(p, sympy.Add): + return None + c = None + terms = [] + mult = None + for arg in p.args: + if arg.is_number: + if c is not None: + return None + c = arg + continue + if not isinstance(arg, sympy.Mul): + return None + m = _check_mul_nonneg_int(arg) + if m is None or q % m != 0: + return None + mult = m if mult is None or m < mult else mult + terms.append(arg) + if c is None or c >= mult: + return None + return (sum(terms) % q) + c + + def _check_mul_rational(mul): + ret = None + for arg in mul.args: + if isinstance(arg, sympy.Rational): + if ret is not None: + return None + if arg.p < 0 or arg.q < 0: + return None + ret = arg + continue + if not (isinstance(arg, (sympy.floor, sympy.Mod)) or arg.is_integer): + return None + if not arg.is_nonnegative: + return None + return ret + + def _transform_floor(e): + if not isinstance(e, sympy.floor): + return None + inner = e.args[0] + if not isinstance(inner, sympy.Add): + return None + c = None + for arg in inner.args: + if isinstance(arg, sympy.Rational): + if c is not None: + return None + c = arg + if c is None: + return None + terms = [] + for arg in inner.args: + if isinstance(arg, sympy.Rational): + continue + if not isinstance(arg, sympy.Mul): + return None + r = _check_mul_rational(arg) + if r is None or r.p != 1 or r <= c: + return None + terms.append(arg) + return sympy.floor(sum(terms)) + + expr = expr.replace(lambda e: _transform_mod(e) is not None, _transform_mod) + expr = expr.replace(lambda e: _transform_floor(e) is not None, _transform_floor) + return expr + + +@lru_cache(maxsize=1024) +def _bounds_simplify_once(expr: sympy.Expr, ranges: SymbolRanges) -> sympy.Expr: """Single bottom-up pass of bounds-based simplification. Mod nodes are handled specially to avoid a sympy auto-evaluation bug @@ -138,13 +294,13 @@ def _bounds_simplify_once(expr: sympy.Expr) -> sympy.Expr: if not isinstance(expr, sympy.Basic) or expr.is_Atom: return expr - simplified_args = [_bounds_simplify_once(a) for a in expr.args] + simplified_args = [_bounds_simplify_once(a, ranges) for a in expr.args] # Handle Mod before reconstruction to avoid triggering the sympy bug. if isinstance(expr, sympy.Mod): p, q = simplified_args if q.is_positive and q.is_number: - p_bounds = expr_bounds(p) + p_bounds = expr_bounds(p, ranges) if p_bounds and p_bounds[0] >= 0 and p_bounds[1] < q: return p # Keep Mod but prevent buggy auto-evaluation. @@ -154,7 +310,7 @@ def _bounds_simplify_once(expr: sympy.Expr) -> sympy.Expr: expr = expr.func(*simplified_args) if isinstance(expr, sympy.floor): - bounds = expr_bounds(expr.args[0]) + bounds = expr_bounds(expr.args[0], ranges) if ( bounds and bounds[0] != sympy.oo @@ -162,6 +318,15 @@ def _bounds_simplify_once(expr: sympy.Expr) -> sympy.Expr: and sympy.floor(bounds[0]) == sympy.floor(bounds[1]) ): return sympy.Integer(int(sympy.floor(bounds[0]))) + if isinstance(expr, sympy.ceiling): + bounds = expr_bounds(expr.args[0], ranges) + if ( + bounds + and bounds[0] != sympy.oo + and bounds[1] != sympy.oo + and sympy.ceiling(bounds[0]) == sympy.ceiling(bounds[1]) + ): + return sympy.Integer(int(sympy.ceiling(bounds[0]))) return expr From ef4a0fc60a1041e915930a11b372b5b38dc9c1ee Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 2 Mar 2026 17:23:04 +0100 Subject: [PATCH 11/12] format Signed-off-by: Ivan Butygin --- wave_lang/kernel/compiler/host_codegen.py | 6 +++--- wave_lang/kernel/compiler/wave_codegen/read_write.py | 1 - wave_lang/kernel/wave/scheduling/schedule.py | 4 +--- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/wave_lang/kernel/compiler/host_codegen.py b/wave_lang/kernel/compiler/host_codegen.py index e742549d9c..7747c8b464 100644 --- a/wave_lang/kernel/compiler/host_codegen.py +++ b/wave_lang/kernel/compiler/host_codegen.py @@ -173,9 +173,9 @@ def isolated_test_call( else: # Solve shape_expr = d for the base symbol. solutions = sympy.solve(dim_expr - _dim_val, base_sym) - assert len(solutions) == 1, ( - f"Cannot solve {dim_expr} = _dim_val for {base_sym}" - ) + assert ( + len(solutions) == 1 + ), f"Cannot solve {dim_expr} = _dim_val for {base_sym}" arg_dim_inverse[base_sym] = solutions[0] if async_dispatch: diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 94c1bc5d2a..e06d11cffb 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -17,7 +17,6 @@ DenseElementsAttr, IndexType, InsertionPoint, - IntegerAttr, IntegerType, IrType, MemRefType, diff --git a/wave_lang/kernel/wave/scheduling/schedule.py b/wave_lang/kernel/wave/scheduling/schedule.py index d97210e030..11cb374c42 100644 --- a/wave_lang/kernel/wave/scheduling/schedule.py +++ b/wave_lang/kernel/wave/scheduling/schedule.py @@ -318,9 +318,7 @@ def build_guarded_pipeline_with_remainder( # to the original reduction_graph nodes so that # _update_kernel_node_mapping can match tracked lists by identity. new_to_old = {v: k for k, v in body_old_to_new.items()} - node_mapping = { - new_to_old.get(k, k): v for k, v in node_mapping.items() - } + node_mapping = {new_to_old.get(k, k): v for k, v in node_mapping.items()} # Set the count for the pipelined loop # With step > 1 (e.g., from unrolling), we need to reduce the count by more From 28b02f46aa85ac82e2a05d6466c7de097a27c573 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Mon, 2 Mar 2026 17:23:28 +0100 Subject: [PATCH 12/12] dynamic size Signed-off-by: Ivan Butygin --- wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py index 56d9fec944..0490acbbe3 100755 --- a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py +++ b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py @@ -452,12 +452,17 @@ def repeat( } hyperparams.update(get_default_scheduling_params()) + dynamic_symbols = [M, N, K] + for symbol in dynamic_symbols: + del hyperparams[symbol] + options = WaveCompileOptions( subs=hyperparams, canonicalize=True, schedule=SchedulingType.MANUAL, use_global_to_shared=True, minimize_shared_allocs=False, + dynamic_symbols=dynamic_symbols, ) return gemm, options