diff --git a/lit_tests/kernel/wave/scaled_gemm.py b/lit_tests/kernel/wave/scaled_gemm.py index 7bf3b8b17..7fb39936c 100644 --- a/lit_tests/kernel/wave/scaled_gemm.py +++ b/lit_tests/kernel/wave/scaled_gemm.py @@ -10,6 +10,7 @@ ScaledMMAType, ) from wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType +from wave_lang.kernel.wave.templates import get_tagged_mxfp4_gemm_preshuffle_b from wave_lang.kernel.wave.templates.test_kernels import ( get_broadcasted_scale_gemm_mxfp4, ) @@ -361,9 +362,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: } # Epilogue Local Read - # CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> - # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> # Epilogue MFMA @@ -471,8 +472,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: rocdl.s.waitcnt # CHECK: amdgpu.lds_barrier - # Steady state local loads - # CHECK-COUNT-48: vector.load{{.*}} memref<{{.*}}, #gpu.address_space> + # Steady state local loads (8+4 scale loads as vector<8xi8> + 16+8 data loads as vector<16xi8>) + # CHECK-COUNT-36: vector.load{{.*}} memref<{{.*}}, #gpu.address_space> # Steady State global load to lds # CHECK-COUNT-34: amdgpu.gather_to_lds @@ -637,9 +638,9 @@ def repeat( # CHECK: } # Epilogue Local Read - # CHECK-COUNT-16: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-16: vector.load {{.*}} : memref<34816xi8, #gpu.address_space>, vector<16xi8> - # CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-8: vector.load {{.*}} : memref<34816xi8, #gpu.address_space>, vector<16xi8> # Epilogue MFMA @@ -997,3 +998,57 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # Unmasked vector stores for output. # CHECK: vector.store # CHECK: return + + +@run_test +def test_dynamic_preshuffle_b_scale_coalescing(): + """Verify B-scale reads coalesce into clean vector<16xi8> with dynamic dims. + + Uses the preshuffle-B MXFP4 template with dynamic M, N, K and small + block sizes. The K % 256 divisibility assumption lets the coalescer + apply divisibility substitutions during numeric probing, so the 2D + decomposition (row = offset floordiv K/2, col = offset mod K/2) gives + consistent per-dim diffs across probe sets. Without this fix, probes + like K=137 make K/2=68, causing inconsistent row/col diffs and + fragmenting 16-byte scale reads into {2, 16, 8, 4} loads glued by + vector.from_elements. + """ + shape = (256, 256, 256) + block = (128, 128, 256) + kernel, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, + block, + wave_shape=(2, 2), + mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4, + reorder_workgroups=False, + ) + dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in dynamic_symbols: + del options.subs[sym] + options.dynamic_symbols = dynamic_symbols + options.schedule = SchedulingType.NONE + options.use_buffer_ops = True + options.compile_to_mlir = True + options.device = "hip" + options.target = "gfx950" + result = wave_compile(options, kernel) + print(result.asm) + + # CHECK-LABEL: test_dynamic_preshuffle_b_scale_coalescing + + # Dynamic index arguments for M, N, K. + # CHECK: func.func @gemm(%arg0: {{.*}}, %arg1: {{.*}}, %arg2: {{.*}}, %arg3: {{.*}}, %arg4: {{.*}}, %arg5: index, %arg6: index, %arg7: index) + + # Buffer ops: fat_raw_buffer_cast for global buffers. + # CHECK: amdgpu.fat_raw_buffer_cast + + # B-scale reads are clean vector<16xi8> from fat_raw_buffer — no + # fragmentation into mixed-width loads glued by from_elements. + # A-scale reads are vector<4xi8>. + # CHECK: scf.for + # CHECK-COUNT-8: vector.load %{{.*}} : memref>, vector<16xi8> + # CHECK-COUNT-2: vector.load %{{.*}} : memref>, vector<4xi8> + # CHECK: amdgpu.scaled_mfma + + # No byte-level reassembly — coalescing succeeded. + # CHECK-NOT: vector.from_elements diff --git a/lit_tests/kernel/wave/scaled_mma.py b/lit_tests/kernel/wave/scaled_mma.py index f37dae9b5..a762fb4a2 100644 --- a/lit_tests/kernel/wave/scaled_mma.py +++ b/lit_tests/kernel/wave/scaled_mma.py @@ -352,9 +352,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-COUNT-4: vector.load {{.*}} : memref<16384x8192xi8, strided<[8192, 1]>>, vector<16xi8> # CHECK-COUNT-1: vector.load {{.*}} : memref<16384x512xi8, strided<[512, 1]>>, vector<4xi8> # CHECK: amdgpu.lds_barrier - # CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> - # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> # CHECK-COUNT-8: vector.bitcast {{.*}} : vector<16xi8> to vector<32xf4E2M1FN> # CHECK-COUNT-8: vector.bitcast {{.*}} : vector<1xi8> to vector<1xf8E8M0FNU> diff --git a/tests/unittests/test_numeric_probing.py b/tests/unittests/test_numeric_probing.py new file mode 100644 index 000000000..d30d5c5ee --- /dev/null +++ b/tests/unittests/test_numeric_probing.py @@ -0,0 +1,287 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Tests for numeric probing in read coalescing. + +The pairwise merge uses numeric probing to verify that adjacent reads +have consistent per-dim offset diffs. With symbolic K, the 2D +decomposition (row = offset floordiv K/2, col = offset mod K/2) +can give inconsistent diffs when probe values don't respect divisibility +constraints. Applying divisibility forward subs (K -> 256*K') fixes +this. + +The expressions here are taken from the MXFP4 preshuffle B-scale +codegen (test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm on gfx950). +""" + +import pytest +import sympy + +from wave_lang.kernel.wave.analysis.partition_strided_operators import ( + _MERGE_PROBES, + _eval_expr, + _find_merge_dim_from_diffs, +) +from wave_lang.kernel.wave.utils.symbol_utils import ( + _numeric_eval_constant, + safe_subs, +) + + +# --- Symbols matching the MXFP4 B-scale path --- +# t = thread_id_x (0..63), K = GEMM K dimension (multiple of 256), +# wg_m/wg_n = workgroup tile offsets. +t = sympy.Symbol("t", integer=True, nonneg=True) +K = sympy.Symbol("K", integer=True, positive=True) +wg_m = sympy.Symbol("wg_m", integer=True, nonneg=True) +wg_n = sympy.Symbol("wg_n", integer=True, nonneg=True) +half_K = K / 2 + +# Preshuffle base offset for thread t (from MLIR #map27/#map28 pattern): +# base(t) = t*16 + ((t%64)//16)*256 - (t//16)*256 +# For t in 0..63 this simplifies to t*16, but sympy doesn't know that. +_base = t * 16 + ((t % 64) // 16) * 256 - (t // 16) * 256 + + +def _make_row_col(offset): + """Build (row, col) from a linearized preshuffle offset.""" + row = wg_m * 64 + wg_n * 256 + sympy.floor(offset / half_K) + col = offset % half_K + return row, col + + +# The MXFP4 preshuffle MMA layout scatters 16 scale bytes across +# offsets with large gaps. From the MLIR maps (#map27-#map36), the +# actual offsets within a single thread's 16-byte read are: +# base, base+1024, base+2, base+3, base+11, base+15, +# base+16*N, base+16*N+1024, ... +# Here we use the first 4 groups from the MLIR maps: offsets +# base+0, base+1024, base+2, base+3 — which is enough to +# demonstrate the inconsistency. +_PRESHUFFLE_DELTAS = [ + 0, + 1, + 1024, + 1025, + 2, + 3, + 11, + 15, + 16, + 17, + 1040, + 1041, + 18, + 19, + 27, + 31, +] + +BYTE_READS = [] +for delta in _PRESHUFFLE_DELTAS: + row_i, col_i = _make_row_col(_base + delta) + flat_i = row_i * half_K + col_i + BYTE_READS.append({"row": row_i, "col": col_i, "flat": flat_i, "delta": delta}) + +# Divisibility forward sub: K -> 256 * K'. +K_prime = sympy.Symbol("_K_div_256", integer=True, positive=True) +DIV_FWD = [(K, 256 * K_prime)] + + +def _build_merge_probe_maps(expressions): + """Build _MERGE_PROBES probe maps over free symbols in expressions.""" + all_free = set() + for expr in expressions: + if hasattr(expr, "free_symbols"): + all_free |= expr.free_symbols + free_list = sorted(all_free, key=str) + return [{s: gen(i) for i, s in enumerate(free_list)} for gen in _MERGE_PROBES] + + +def _find_reads_by_delta(d0, d1): + """Return the two BYTE_READS entries with the given deltas.""" + r0 = next(br for br in BYTE_READS if br["delta"] == d0) + r1 = next(br for br in BYTE_READS if br["delta"] == d1) + return r0, r1 + + +class TestMergeProbeConsistency: + """Verify that numeric probing gives consistent per-dim diffs + for adjacent B-scale reads, with and without divisibility subs.""" + + def test_flat_offset_diff_always_correct(self): + """Flat offset diff between bytes equals their delta difference.""" + all_flats = [br["flat"] for br in BYTE_READS] + probes = _build_merge_probe_maps(all_flats) + for probe in probes: + for i in range(len(BYTE_READS) - 1): + flat_i = _eval_expr(BYTE_READS[i]["flat"], probe) + flat_j = _eval_expr(BYTE_READS[i + 1]["flat"], probe) + expected = BYTE_READS[i + 1]["delta"] - BYTE_READS[i]["delta"] + assert flat_j - flat_i == expected, ( + f"flat diff mismatch for deltas " + f"{BYTE_READS[i]['delta']},{BYTE_READS[i+1]['delta']}" + ) + + def test_per_dim_diffs_inconsistent_without_div_subs(self): + """Without divisibility subs, per-dim diffs are inconsistent + across probe sets for reads with large preshuffle gaps. + + Reads at offsets base+0 and base+1024 (from MXFP4 maps #map27 + and #map29) have row_diff and col_diff that depend on K/2. + At K/2=68: row_diff=15, col_diff=-3. + At K/2=125: row_diff=8, col_diff=20. Etc. + """ + r0, r1 = _find_reads_by_delta(0, 1024) + all_exprs = list(r0.values()) + list(r1.values()) + probes = _build_merge_probe_maps(all_exprs) + + diffs_per_probe = [] + for probe in probes: + row_diff = _eval_expr(r1["row"], probe) - _eval_expr(r0["row"], probe) + col_diff = _eval_expr(r1["col"], probe) - _eval_expr(r0["col"], probe) + diffs_per_probe.append((row_diff, col_diff)) + + all_agree = all(d == diffs_per_probe[0] for d in diffs_per_probe[1:]) + assert not all_agree, ( + "Expected inconsistent per-dim diffs across probes, " + f"but all agreed: {diffs_per_probe[0]}" + ) + + def test_adjacent_diffs_consistent_with_div_subs(self): + """With divisibility subs (K -> 256*K'), adjacent byte pairs + have consistent diffs: row_diff=0, col_diff=delta_diff.""" + div_reads = [ + {k: safe_subs(v, DIV_FWD) for k, v in br.items()} for br in BYTE_READS + ] + all_exprs = [] + for br in div_reads: + all_exprs.extend(v for k, v in br.items() if k != "delta") + probes = _build_merge_probe_maps(all_exprs) + + # Test all consecutive pairs sorted by delta. + sorted_reads = sorted(zip(BYTE_READS, div_reads), key=lambda x: x[0]["delta"]) + for probe in probes: + for idx in range(len(sorted_reads) - 1): + orig_a, div_a = sorted_reads[idx] + orig_b, div_b = sorted_reads[idx + 1] + row_diff = _eval_expr(div_b["row"], probe) - _eval_expr( + div_a["row"], probe + ) + col_diff = _eval_expr(div_b["col"], probe) - _eval_expr( + div_a["col"], probe + ) + expected_flat = orig_b["delta"] - orig_a["delta"] + assert row_diff == 0, ( + f"row diff != 0 for deltas " f"{orig_a['delta']},{orig_b['delta']}" + ) + assert col_diff == expected_flat, ( + f"col diff {col_diff} != {expected_flat} for deltas " + f"{orig_a['delta']},{orig_b['delta']}" + ) + + def test_large_gap_consistent_with_div_subs(self): + """Reads 1024 apart have consistent diffs after div subs.""" + r0, r1 = _find_reads_by_delta(0, 1024) + div_r0 = {k: safe_subs(v, DIV_FWD) for k, v in r0.items()} + div_r1 = {k: safe_subs(v, DIV_FWD) for k, v in r1.items()} + + all_exprs = [v for d in [div_r0, div_r1] for k, v in d.items() if k != "delta"] + probes = _build_merge_probe_maps(all_exprs) + + for probe in probes: + row_diff = _eval_expr(div_r1["row"], probe) - _eval_expr( + div_r0["row"], probe + ) + col_diff = _eval_expr(div_r1["col"], probe) - _eval_expr( + div_r0["col"], probe + ) + assert row_diff == 0, f"row diff {row_diff} != 0" + assert col_diff == 1024, f"col diff {col_diff} != 1024" + + def test_merge_dim_found_with_div_subs(self): + """After div subs, _find_merge_dim_from_diffs finds the col dim.""" + r0, r1 = _find_reads_by_delta(0, 1) + div_r0 = {k: safe_subs(v, DIV_FWD) for k, v in r0.items()} + div_r1 = {k: safe_subs(v, DIV_FWD) for k, v in r1.items()} + all_exprs = [v for d in [div_r0, div_r1] for k, v in d.items() if k != "delta"] + probes = _build_merge_probe_maps(all_exprs) + + dim_row = sympy.Symbol("dim_row") + dim_col = sympy.Symbol("dim_col") + dims = [dim_row, dim_col] + for probe in probes: + row_diff = _eval_expr(div_r1["row"], probe) - _eval_expr( + div_r0["row"], probe + ) + col_diff = _eval_expr(div_r1["col"], probe) - _eval_expr( + div_r0["col"], probe + ) + result = _find_merge_dim_from_diffs( + {dim_row: row_diff, dim_col: col_diff}, 1, dims + ) + assert result == dim_col + + +class TestNumericEvalConstant: + """Test _numeric_eval_constant with preshuffle-like expressions. + + Note: _numeric_eval_constant uses _PROBE_POOL which includes 0, + causing ZeroDivisionError on floor/Mod expressions with symbolic + divisors. It returns None conservatively on any error, so these + complex expressions are beyond its reach. This is documented + behavior, not a bug — the merge probing (_MERGE_PROBES) handles + these cases instead. + """ + + def test_simple_constant(self): + """Sanity: detects that a simple constant expression is constant.""" + x = sympy.Symbol("x", integer=True, nonneg=True) + assert _numeric_eval_constant(3 * x - 3 * x) == 0 + assert _numeric_eval_constant(x - x + 7) == 7 + + def test_floor_mod_identity_returns_none(self): + """floor(x/d)*d + x%d - x is always 0, but _numeric_eval_constant + can't prove it because _PROBE_POOL includes 0 → ZeroDivisionError.""" + x = sympy.Symbol("x", integer=True, nonneg=True) + d = sympy.Symbol("d", integer=True, positive=True) + expr = sympy.floor(x / d) * d + sympy.Mod(x, d) - x + # Returns None because probing hits d=0. + assert _numeric_eval_constant(expr) is None + + def test_row_diff_not_constant_without_div_subs(self): + """Row diff for adjacent bytes is not constant when K is symbolic.""" + row_diff = BYTE_READS[1]["row"] - BYTE_READS[0]["row"] + assert _numeric_eval_constant(row_diff) is None + + +class TestFindMergeDim: + """Test _find_merge_dim_from_diffs helper.""" + + def test_single_dim_matches(self): + dim_a = sympy.Symbol("a") + dim_b = sympy.Symbol("b") + dims = [dim_a, dim_b] + assert _find_merge_dim_from_diffs({dim_a: 0, dim_b: 1}, 1, dims) == dim_b + + def test_both_change_returns_none(self): + dim_a = sympy.Symbol("a") + dim_b = sympy.Symbol("b") + dims = [dim_a, dim_b] + assert _find_merge_dim_from_diffs({dim_a: 1, dim_b: 1}, 1, dims) is None + + def test_wrong_diff_returns_none(self): + dim_a = sympy.Symbol("a") + dim_b = sympy.Symbol("b") + dims = [dim_a, dim_b] + assert _find_merge_dim_from_diffs({dim_a: 0, dim_b: 2}, 1, dims) is None + + def test_nonzero_non_ept_returns_none(self): + """If a dim changes by something other than 0 or ept, reject.""" + dim_a = sympy.Symbol("a") + dim_b = sympy.Symbol("b") + dims = [dim_a, dim_b] + assert _find_merge_dim_from_diffs({dim_a: 3, dim_b: 0}, 1, dims) is None diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 5caaada52..aedf6b20a 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -711,7 +711,18 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): ) dynamic_vals_map_start = _build_dyn_vals_map(mapping, dyn_vals) - if mapping: + is_global_mem = kb_src.type.memory_space is None + buffer_ops_enabled = emitter.options.use_buffer_ops and is_global_mem + + precomputed_mask_expr = getattr(node, "precomputed_mask_expr", None) + if precomputed_mask_expr is not None and not buffer_ops_enabled: + mask = gen_sympy_index(add_emitter_subs(emitter), precomputed_mask_expr) + mask_vec_type = VectorType.get( + [elements_per_thread], IntegerType.get_signless(1) + ) + if mask.type != mask_vec_type: + mask = vector_d.broadcast(mask_vec_type, mask) + elif mapping: transformed_index = transform_index_on_mapping( mapping, input_shape, index, is_read=True ) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 18b92c99f..fe8ef040e 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools from collections.abc import Sequence from copy import deepcopy from itertools import groupby @@ -30,6 +31,7 @@ ) from ..assumptions import get_divisibility_subs from ..constraints import Constraint +from ..utils.mapping_utils import transform_index_on_mapping from ..utils.tag_utils import propagate_tag from ..utils.general_utils import ( all_equal, @@ -418,7 +420,8 @@ def merge_contiguous_reads( physical flat offset starts differ by exactly ept are merged. """ hw_constraint = get_hardware_constraint(constraints) - while _merge_contiguous_reads_once(trace, hw_constraint): + fwd, _ = get_divisibility_subs(constraints) + while _merge_contiguous_reads_once(trace, hw_constraint, divisibility_fwd=fwd): pass @@ -433,8 +436,6 @@ def _get_physical_start( coordinates. For identity-mapped reads (mapping=None), reads the start offsets directly from the index. """ - from ..utils.mapping_utils import transform_index_on_mapping - if custom.mapping is not None and not custom.has_identity_mapping(): physical = transform_index_on_mapping( custom.mapping, symbolic_shape, custom.index, is_read=True @@ -447,21 +448,156 @@ def _get_physical_start( return {dim: custom.index[dim].start for dim in symbolic_dims} -def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: - """Single merge pass: merge adjacent pairs of same-ept reads. +def _flatten_bounds_to_mask_expr( + custom: Read, + symbolic_shape: tuple, +): + """Pre-compute a read's bounds check as a flat sympy boolean. + + Replicates the _build_mask_with_mapping decision logic: when the + mapping's transformed index contains all bounded dims (and has no + dynamic_val_indices), the transformed index is used; otherwise the + original logical index is used. Returns a sympy boolean expression + (eg. And(idx < bound, ...)) or None if the read has no bounds. - Groups reads by (memory operand, ept) and merges pairs whose physical - flat offset starts differ by exactly ept. Returns True if any merges - happened. """ - from collections import defaultdict - from ...compiler.utils import strides_from_symbolic_shape + if not custom.bounds: + return None + + index = custom.index + + if custom.mapping is not None and not custom.has_identity_mapping(): + transformed = transform_index_on_mapping( + custom.mapping, symbolic_shape, index, is_read=True + ) + use_transformed = ( + all(dim in transformed for dim in custom.bounds) + and not custom.mapping.dynamic_val_indices + ) + if use_transformed: + index = transformed + + conditions = [] + for dim, bound in custom.bounds.items(): + if dim not in index: + continue + start = ( + index[dim].start if isinstance(index[dim], IndexSequence) else index[dim] + ) + if isinstance(start, int): + start = sympy.Integer(start) + if isinstance(bound, int): + bound = sympy.Integer(bound) + conditions.append(sympy.StrictLessThan(start, bound)) + + if not conditions: + return None + + return functools.reduce(sympy.And, conditions) + + +def _build_wide_mask_expr(sub_reads, symbolic_shape, wide_ept): + """Build a concatenated sympy mask for a wide read from its sub-reads. + + Each entry in *sub_reads* is ``(offset, size, orig_custom)`` where + *offset* is the lane offset within the wide vector and *size* is the + number of lanes that sub-read occupies. *wide_ept* is the total + number of elements in the wide read (may exceed ``sum(sizes)`` when + there are gaps, e.g. multiway coalesce with non-contiguous offsets). + + Builds ``Or(And(lane_in_range_0, mask_0), And(lane_in_range_1, mask_1), ...)`` + using pure boolean ops so that ``gen_sympy_index`` can lower it without + the nested-Piecewise ``select_stack`` ordering issue. + + Returns a sympy boolean expression or ``None`` when no sub-read has + bounds. + """ from ..._support.indexing import IndexingContext - # Group reads by (memory, ept, region). A new region starts at each - # subgraph boundary and whenever a side-effecting op (write, barrier, ...) - # is encountered, so we never merge reads across such ops. Reads with - # dynamic mapping values are skipped to keep the merge logic simple. + from ..._support.indexing import index_symbol + + masks = [] + for offset, size, custom in sub_reads: + existing = getattr(custom.fx_node, "precomputed_mask_expr", None) + if existing is not None: + masks.append((offset, size, existing)) + else: + masks.append( + (offset, size, _flatten_bounds_to_mask_expr(custom, symbolic_shape)) + ) + + if not any(m is not None for _, _, m in masks): + return None + + idxc = IndexingContext.current() + iota = idxc.iota(wide_ept) + + terms = [] + for offset, size, mask in masks: + upper = offset + size + lane_cond = sympy.And( + sympy.GreaterThan(iota, offset) if offset > 0 else sympy.true, + sympy.StrictLessThan(iota, upper), + ) + if mask is not None: + # Remap any iota from a previous merge level to the new wide iota. + old_iota_sym = index_symbol(f"$IOTA{size}") + if mask.has(old_iota_sym): + mask = mask.subs(old_iota_sym, iota - offset) + bound_cond = mask + else: + bound_cond = sympy.true + terms.append(sympy.And(lane_cond, bound_cond)) + + return functools.reduce(sympy.Or, terms) + + +def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source, mask_expr=None): + """Create a merged Read node covering ``wide_ept`` elements.""" + wide_read = Read( + anchor_custom.memory, + elements_per_thread=wide_ept, + mapping=None, + _write_dependency=anchor_custom._write_dependency, + flags=anchor_custom.flags, + ).add_to_graph(anchor_custom.graph, loc=anchor_custom.location) + wide_custom = get_custom(wide_read) + wide_custom.index = wide_index + if hasattr(tag_source, "vector_shapes"): + wide_read.vector_shapes = deepcopy(tag_source.vector_shapes) + if mask_expr is not None: + wide_read.precomputed_mask_expr = mask_expr + propagate_tag(tag_source, wide_read) + return wide_read + + +def _emit_extract_slice( + wide_read, offset, size, orig_custom, orig_node, symbolic_shape +): + """Create an ExtractSlice from a wide read and propagate metadata.""" + extract = ExtractSlice(wide_read, [offset], [size], [1]).add_to_graph( + orig_custom.graph, loc=orig_custom.location + ) + extract_custom = get_custom(extract) + extract_custom.index = deepcopy(orig_custom.index) + if hasattr(orig_node, "vector_shapes"): + extract_custom.vector_shapes = deepcopy(orig_node.vector_shapes) + propagate_tag(orig_node, extract) + return extract + + +def _group_reads_by_memory( + trace: CapturedTrace, +) -> dict[tuple, list[fx.Node]]: + """Group reads by (memory, ept, region). + + A new region starts at each subgraph boundary and whenever a + side-effecting op (write, barrier, ...) is encountered, so we never + merge reads across such ops. Reads with dynamic mapping values are + skipped to keep the merge logic simple. + """ + from collections import defaultdict + groups: dict[tuple, list[fx.Node]] = defaultdict(list) region_id = 0 for subgraph in trace.region_graph.subgraphs.values(): @@ -477,13 +613,402 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: continue if custom.mapping_dynamic_vals: continue - # Skip reads that have bounds: the merged read would lose the - # mapping and source→target index, making mask generation incorrect. - if custom.bounds is not None: - continue key = (custom.memory, custom.elements_per_thread, region_id) groups[key].append(node) + return groups + + +def _resolve_symbolic_diff(raw_diff, has_complex_mapping, expected_vals=None): + """Resolve a raw sympy offset difference to a value, or None. + + Strategy: + 1. If already a plain int / sympy.Integer, return it directly. + 2. If the mapping is complex (non-identity), use numeric probing. + 3. Otherwise try sym_simplify. If ``expected_vals`` is given and + the result isn't among them, fall back to numeric probing; + return None when neither approach succeeds. + """ + if isinstance(raw_diff, (int, sympy.Integer)): + return int(raw_diff) + if has_complex_mapping: + return _numeric_eval_constant(raw_diff) + simplified = sym_simplify(raw_diff) + if expected_vals is None or simplified in expected_vals: + return simplified + nv = _numeric_eval_constant(raw_diff) + return nv + + +def _do_merge( + lo_i, hi_i, merge_dim, read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint +): + """Emit a wide read merging reads at lo_i and hi_i. Returns True on success.""" + _, lo_phys, lo_custom, lo_node = read_infos[lo_i] + _, _, hi_custom, hi_node = read_infos[hi_i] + new_ept = 2 * ept + element_type = lo_custom.type.dtype + if new_ept > hw_constraint.max_elems_per_load(element_type): + return False + wide_mask = _build_wide_mask_expr( + [(0, ept, lo_custom), (ept, ept, hi_custom)], + symbolic_shape, + new_ept, + ) + with lo_custom.graph.inserting_before(lo_node): + new_index = { + dim: IndexSequence( + lo_phys[dim], + new_ept if dim == merge_dim else 1, + 1, + ) + for dim in symbolic_dims + } + merged_read = _emit_wide_read( + lo_custom, new_index, new_ept, lo_node, mask_expr=wide_mask + ) + lo_extract = _emit_extract_slice( + merged_read, 0, ept, lo_custom, lo_node, symbolic_shape + ) + hi_extract = _emit_extract_slice( + merged_read, ept, ept, hi_custom, hi_node, symbolic_shape + ) + lo_custom.replace_all_uses_with(lo_extract) + hi_custom.replace_all_uses_with(hi_extract) + lo_custom.graph.erase_node(lo_node) + hi_custom.graph.erase_node(hi_node) + return True + + +def _find_merge_dim_from_diffs(dim_diffs, ept, symbolic_dims): + """Return the single dimension whose diff equals ept, or None.""" + merge_dim = None + for dim in symbolic_dims: + d = dim_diffs[dim] + if d == ept: + if merge_dim is not None: + return None + merge_dim = dim + elif d != 0: + return None + return merge_dim + + +# Probe value sets for numeric offset evaluation. Diverse primes avoid +# floor/Mod aliasing; all positive (symbols are nonneg). +_MERGE_PROBES = [ + lambda i: 137 + i * 31, + lambda i: 251 + i * 47, + lambda i: 503 + i * 17, +] + + +def _eval_expr(expr, probe_map): + """Evaluate a sympy expression with concrete symbol values.""" + if isinstance(expr, (int, float)): + return int(expr) + return int(expr.xreplace(probe_map)) + + +def _pairwise_merge( + read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint, divisibility_fwd=None +): + """Merge pairs of reads whose flat offsets differ by exactly ``ept``. + + Returns ``(merged_indices, did_merge)`` where *merged_indices* is + the set of read_infos indices that were consumed. + + Evaluates each offset independently with concrete probe values and + uses dict lookup for O(n) candidate matching, avoiding O(n²) + symbolic diff resolution. + """ + n = len(read_infos) + if n < 2: + return set(), False + + merged = set() + did_merge = False + + # Resolve all flat offsets and per-dim physical starts once. + resolved_flat = [subs_idxc(info[0]) for info in read_infos] + resolved_phys = [ + {dim: subs_idxc(info[1][dim]) for dim in symbolic_dims} for info in read_infos + ] + + # Apply divisibility forward subs so that floordiv/Mod with symbolic + # divisors evaluate consistently across probe points. + if divisibility_fwd: + resolved_flat = [safe_subs(e, divisibility_fwd) for e in resolved_flat] + resolved_phys = [ + {dim: safe_subs(e, divisibility_fwd) for dim, e in phys.items()} + for phys in resolved_phys + ] + + # Collect free symbols across all expressions. + all_free = set() + for expr in resolved_flat: + if hasattr(expr, "free_symbols"): + all_free.update(expr.free_symbols) + for phys in resolved_phys: + for expr in phys.values(): + if hasattr(expr, "free_symbols"): + all_free.update(expr.free_symbols) + free_list = sorted(all_free, key=str) + + # Build probe maps. + probe_maps = [{s: gen(i) for i, s in enumerate(free_list)} for gen in _MERGE_PROBES] + + # Evaluate flat offsets with first probe for candidate matching. + num_flat_0 = [None] * n + for i in range(n): + try: + num_flat_0[i] = _eval_expr(resolved_flat[i], probe_maps[0]) + except (TypeError, ValueError, ZeroDivisionError): + pass + + # Evaluate per-dim physical starts with first probe. + num_phys_0 = [None] * n + for i in range(n): + if num_flat_0[i] is None: + continue + try: + num_phys_0[i] = { + dim: _eval_expr(resolved_phys[i][dim], probe_maps[0]) + for dim in symbolic_dims + } + except (TypeError, ValueError, ZeroDivisionError): + pass + + # Build dict: numeric_flat_offset -> [indices] for O(1) partner lookup. + from collections import defaultdict + + offset_map = defaultdict(list) + for i in range(n): + if num_flat_0[i] is not None: + offset_map[num_flat_0[i]].append(i) + + def _verify_with_extra_probes(lo_i, hi_i, expected_flat_diff, expected_dim_diffs): + """Confirm diffs are consistent across additional probe sets.""" + for probe in probe_maps[1:]: + try: + flat_lo = _eval_expr(resolved_flat[lo_i], probe) + flat_hi = _eval_expr(resolved_flat[hi_i], probe) + if flat_hi - flat_lo != expected_flat_diff: + return False + for dim in symbolic_dims: + d_lo = _eval_expr(resolved_phys[lo_i][dim], probe) + d_hi = _eval_expr(resolved_phys[hi_i][dim], probe) + if d_hi - d_lo != expected_dim_diffs[dim]: + return False + except (TypeError, ValueError, ZeroDivisionError): + return False + return True + + for i in range(n): + if i in merged or num_flat_0[i] is None: + continue + vi = num_flat_0[i] + found = False + for target, i_is_lo in ((vi + ept, True), (vi - ept, False)): + for j in offset_map.get(target, []): + if j in merged or j == i: + continue + if num_phys_0[j] is None: + continue + lo_i, hi_i = (i, j) if i_is_lo else (j, i) + # Per-dim check with first probe. + dim_diffs = { + dim: num_phys_0[hi_i][dim] - num_phys_0[lo_i][dim] + for dim in symbolic_dims + } + merge_dim = _find_merge_dim_from_diffs(dim_diffs, ept, symbolic_dims) + if merge_dim is None: + continue + # Verify with additional probes. + flat_diff = num_flat_0[hi_i] - num_flat_0[lo_i] + if not _verify_with_extra_probes(lo_i, hi_i, flat_diff, dim_diffs): + continue + if _do_merge( + lo_i, + hi_i, + merge_dim, + read_infos, + ept, + symbolic_dims, + symbolic_shape, + hw_constraint, + ): + merged.update({i, j}) + did_merge = True + found = True + break + if found: + break + + return merged, did_merge + + +def _multiway_coalesce( + read_infos, + merged, + reads, + symbolic_dims, + symbolic_shape, + hw_constraint, + divisibility_fwd=None, +): + """Coalesce unmerged ept==1 reads whose flat offsets fall in an aligned window. + + Groups reads whose numerically-probed flat offsets fall within a + power-of-2 aligned window (up to ``max_elems_per_load``), then emits + a single wide read with per-byte ExtractSlice ops. + """ + element_type = get_custom(reads[0]).type.dtype + max_load_width = hw_constraint.max_elems_per_load(element_type) + + unmerged_infos = [read_infos[k] for k in range(len(read_infos)) if k not in merged] + if len(unmerged_infos) < 2: + return False + + # Pre-evaluate flat offsets with probe values to avoid symbolic diffs. + resolved_offs = [subs_idxc(info[0]) for info in unmerged_infos] + if divisibility_fwd: + resolved_offs = [safe_subs(e, divisibility_fwd) for e in resolved_offs] + all_free = set() + for expr in resolved_offs: + if hasattr(expr, "free_symbols"): + all_free.update(expr.free_symbols) + free_list = sorted(all_free, key=str) + probe0 = {s: _MERGE_PROBES[0](i) for i, s in enumerate(free_list)} + extra_probes = [ + {s: gen(i) for i, s in enumerate(free_list)} for gen in _MERGE_PROBES[1:] + ] + num_offs = [None] * len(unmerged_infos) + for i, expr in enumerate(resolved_offs): + try: + num_offs[i] = _eval_expr(expr, probe0) + except (TypeError, ValueError, ZeroDivisionError): + pass + + coalesced_any = False + coalesced_set: set[int] = set() + for anchor_idx in range(len(unmerged_infos)): + if anchor_idx in coalesced_set: + continue + if num_offs[anchor_idx] is None: + continue + _, phys_a, custom_a, node_a = unmerged_infos[anchor_idx] + + group = [(anchor_idx, node_a, custom_a, 0, phys_a)] + for probe_idx in range(len(unmerged_infos)): + if probe_idx == anchor_idx or probe_idx in coalesced_set: + continue + if num_offs[probe_idx] is None: + continue + diff_val = num_offs[probe_idx] - num_offs[anchor_idx] + if not (0 < diff_val < max_load_width): + continue + # Verify diff is constant across extra probes. + consistent = True + for ep in extra_probes: + try: + va = _eval_expr(resolved_offs[anchor_idx], ep) + vp = _eval_expr(resolved_offs[probe_idx], ep) + if vp - va != diff_val: + consistent = False + break + except (TypeError, ValueError, ZeroDivisionError): + consistent = False + break + if not consistent: + continue + _, custom_p, node_p = ( + unmerged_infos[probe_idx][1], + unmerged_infos[probe_idx][2], + unmerged_infos[probe_idx][3], + ) + _, _, _, phys_p = unmerged_infos[probe_idx] + group.append((probe_idx, node_p, custom_p, diff_val, phys_p)) + + if len(group) < 2: + continue + + group.sort(key=lambda x: x[3]) + max_off = group[-1][3] + wide_ept = 1 + while wide_ept <= max_off: + wide_ept *= 2 + if wide_ept > max_load_width: + continue + + base_phys = group[0][4] + + earliest_node = group[0][1] + for g in group[1:]: + candidate = g[1] + for n in custom_a.graph.nodes: + if n is candidate: + earliest_node = candidate + break + if n is earliest_node: + break + + wide_mask = _build_wide_mask_expr( + [(byte_pos, 1, g_custom) for _, _, g_custom, byte_pos, _ in group], + symbolic_shape, + wide_ept, + ) + with get_custom(earliest_node).graph.inserting_before(earliest_node): + wide_index = {} + for dim_idx, dim in enumerate(symbolic_dims): + if dim_idx == len(symbolic_dims) - 1: + wide_index[dim] = IndexSequence(base_phys[dim], wide_ept, 1) + else: + wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) + + wide_read = _emit_wide_read( + custom_a, wide_index, wide_ept, earliest_node, mask_expr=wide_mask + ) + + extracts = [] + for g_idx, g_node, g_custom, byte_pos, _ in group: + with g_custom.graph.inserting_before(g_node): + ext = _emit_extract_slice( + wide_read, byte_pos, 1, g_custom, g_node, symbolic_shape + ) + extracts.append((ext, g_custom, g_node)) + + for ext, g_custom, g_node in extracts: + g_custom.replace_all_uses_with(ext) + g_custom.graph.erase_node(g_node) + + coalesced_set.update(g[0] for g in group) + coalesced_any = True + + return coalesced_any + +def _merge_contiguous_reads_once( + trace: CapturedTrace, hw_constraint, divisibility_fwd=None +) -> bool: + """Single merge pass: merge reads that access nearby physical memory. + + Two strategies are applied per (memory, ept) group: + + 1. **Pairwise contiguous merge** (``_pairwise_merge``): pairs whose + physical flat offset starts differ by exactly ``ept`` are merged + into a ``2*ept`` read with two ExtractSlice outputs. + + 2. **Multi-way coalescing** (``_multiway_coalesce``, ``ept==1`` only): + unmerged byte reads whose flat offsets fall within a power-of-2 + aligned window are replaced by a single wide read with per-byte + ExtractSlice outputs. + + Returns True if any merges or coalescing happened. + """ + from ...compiler.utils import strides_from_symbolic_shape + from ..._support.indexing import IndexingContext + + groups = _group_reads_by_memory(trace) idxc = IndexingContext.current() merged_any = False @@ -509,129 +1034,43 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: ) read_infos.append((flat_offset, phys_start, custom, node)) - merged = set() - for i in range(len(read_infos)): - if i in merged: - continue - for j in range(i + 1, len(read_infos)): - if j in merged: - continue - off1, phys1, custom1, node1 = read_infos[i] - off2, phys2, custom2, node2 = read_infos[j] - - raw_diff = subs_idxc(off2 - off1) - - # For reads with non-identity mappings (e.g. preshuffle - # scales), the flat-offset diff contains complex floor/Mod - # expressions that sympy.simplify cannot reduce. Use fast - # numeric probing instead. - has_complex_mapping = ( - custom1.mapping is not None and not custom1.has_identity_mapping() - ) - - # subs_idxc may fully resolve to a plain int. - if isinstance(raw_diff, (int, sympy.Integer)): - diff = int(raw_diff) - elif has_complex_mapping: - diff = _numeric_eval_constant(raw_diff) - if diff is None: - continue - else: - diff = sym_simplify(raw_diff) - if diff != ept and diff != -ept: - nv = _numeric_eval_constant(raw_diff) - if nv is not None: - diff = nv - - if diff == ept: - lo_phys, hi_phys = phys1, phys2 - lo_custom, hi_custom = custom1, custom2 - lo_node, hi_node = node1, node2 - elif diff == -ept: - lo_phys, hi_phys = phys2, phys1 - lo_custom, hi_custom = custom2, custom1 - lo_node, hi_node = node2, node1 - else: - continue + import time as _time - # Find dimension that advances by ept. - merge_dim = None - for dim in symbolic_dims: - raw_d = subs_idxc(hi_phys[dim] - lo_phys[dim]) - if isinstance(raw_d, (int, sympy.Integer)): - d = int(raw_d) - elif has_complex_mapping: - d = _numeric_eval_constant(raw_d) - if d is None: - merge_dim = None - break - else: - d = sym_simplify(raw_d) - if d != ept and d != 0: - nv = _numeric_eval_constant(raw_d) - if nv is not None: - d = nv - if d == ept: - merge_dim = dim - elif not (d == 0): - merge_dim = None - break - if merge_dim is None: - continue - - # Respect hardware vector width limit. - new_ept = 2 * ept - element_type = lo_custom.type.dtype - if new_ept > hw_constraint.max_elems_per_load(element_type): - continue - with lo_custom.graph.inserting_before(lo_node): - new_index = { - dim: IndexSequence( - lo_phys[dim], - new_ept if dim == merge_dim else 1, - 1, - ) - for dim in symbolic_dims - } - - merged_read = Read( - lo_custom.memory, - elements_per_thread=new_ept, - mapping=None, - _write_dependency=lo_custom._write_dependency, - flags=lo_custom.flags, - ).add_to_graph(lo_custom.graph, loc=lo_custom.location) - merged_custom = get_custom(merged_read) - merged_custom.index = new_index - merged_custom.vector_shapes = deepcopy(lo_custom.vector_shapes) - propagate_tag(lo_node, merged_read) - - extract0 = ExtractSlice(merged_read, [0], [ept], [1]).add_to_graph( - lo_custom.graph, loc=lo_custom.location - ) - get_custom(extract0).index = deepcopy(lo_custom.index) - get_custom(extract0).vector_shapes = deepcopy( - lo_custom.vector_shapes - ) - propagate_tag(lo_node, extract0) - - extract1 = ExtractSlice( - merged_read, [ept], [ept], [1] - ).add_to_graph(lo_custom.graph, loc=lo_custom.location) - get_custom(extract1).index = deepcopy(hi_custom.index) - get_custom(extract1).vector_shapes = deepcopy( - hi_custom.vector_shapes - ) - propagate_tag(hi_node, extract1) - - lo_custom.replace_all_uses_with(extract0) - hi_custom.replace_all_uses_with(extract1) - lo_custom.graph.erase_node(lo_node) - hi_custom.graph.erase_node(hi_node) - - merged.update({i, j}) - merged_any = True - break + print( + f"[DEBUG merge] mem={memory.fx_node.name} ept={ept} region={_region} n_reads={len(read_infos)}", + flush=True, + ) + _t0 = _time.time() + merged, did_merge = _pairwise_merge( + read_infos, + ept, + symbolic_dims, + symbolic_shape, + hw_constraint, + divisibility_fwd=divisibility_fwd, + ) + print( + f"[DEBUG merge] _pairwise_merge {_time.time()-_t0:.3f}s merged={len(merged)} did_merge={did_merge}", + flush=True, + ) + merged_any |= did_merge + + # Only ept==1 (byte) reads need multi-way coalescing; wider reads + # are already handled by the pairwise merge above. + if ept == 1 and len(read_infos) >= 2: + _t0 = _time.time() + merged_any |= _multiway_coalesce( + read_infos, + merged, + reads, + symbolic_dims, + symbolic_shape, + hw_constraint, + divisibility_fwd=divisibility_fwd, + ) + print( + f"[DEBUG merge] _multiway_coalesce {_time.time()-_t0:.3f}s", flush=True + ) return merged_any diff --git a/wave_lang/kernel/wave/asm/handlers_memory.py b/wave_lang/kernel/wave/asm/handlers_memory.py index a32394028..e0ee25d3d 100644 --- a/wave_lang/kernel/wave/asm/handlers_memory.py +++ b/wave_lang/kernel/wave/asm/handlers_memory.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from wave_lang.support.ir_imports import ( + VectorType, amdgpu_d, gpu_d, memref_d, @@ -79,14 +80,24 @@ def handle_vector_extract_strided_slice_op( offset_val = int(str(offsets).split("[")[1].split("]")[0]) size_val = int(str(sizes).split("[")[1].split("]")[0]) + # Convert MLIR element offsets to physical register indices. + # Each 32-bit VGPR holds (32 // elem_bits) elements, e.g. 4 for i8. + # An extract at element offset 4 on vector<8xi8> targets register 1, + # not register 4. + source_vec_type = VectorType(operation.operands[0].type) + elem_bits = source_vec_type.element_type.width + elems_per_reg = 32 // elem_bits + + reg_offset = offset_val // elems_per_reg + reg_count = max(1, (size_val * elem_bits + 31) // 32) + # Extract the appropriate subset of registers - if size_val == 1: - # Single scalar extract - return just the one register as a tuple - extracted_reg = source_regs[offset_val] - result_regs = (extracted_reg,) + if reg_count == 1: + # Single register extract - return just the one register as a tuple + result_regs = (source_regs[reg_offset],) else: - # Multi-element extract - return a slice - result_regs = source_regs[offset_val : offset_val + size_val] + # Multi-register extract - return a slice + result_regs = source_regs[reg_offset : reg_offset + reg_count] result_ssa = str(operation.result) self.walker.kernel_ctx.ssa_to_reg[result_ssa] = result_regs diff --git a/wave_lang/kernel/wave/opsel_scaled_mfma.py b/wave_lang/kernel/wave/opsel_scaled_mfma.py index 2a4910b93..3880dfba8 100644 --- a/wave_lang/kernel/wave/opsel_scaled_mfma.py +++ b/wave_lang/kernel/wave/opsel_scaled_mfma.py @@ -98,6 +98,8 @@ def _trace_scale_chain(scale_value): return None slice_op = bitcast_source.owner + if _is_op_named(slice_op, "arith.select"): + slice_op = slice_op.operands[1].owner if not _is_op_named(slice_op, "vector.extract_strided_slice"): return None @@ -139,9 +141,14 @@ def _trace_extract_strided_slice( ) -> Optional[tuple[Value, int]]: """Check if *value* is produced by extract_strided_slice of a vector<4xi8>. + Looks through ``arith.select`` (inserted by flatten-bounds masking) + to find the underlying extract_strided_slice. + Returns ``(source_vec4xi8, byte_offset)`` or ``None``. """ op = value.owner + if _is_op_named(op, "arith.select"): + op = op.operands[1].owner if not _is_op_named(op, "vector.extract_strided_slice"): return None source = op.operands[0] @@ -165,17 +172,19 @@ def _find_yield_op(for_view) -> Optional[Operation]: def _find_mergeable_groups( for_view, yield_op: Operation ) -> list[tuple[Value, Value, dict[int, int]]]: - """Find groups of 4 ``vector<1xi8>`` iter_args that can be coalesced. + """Find groups of ``vector<1xi8>`` iter_args that can be coalesced. A group is valid when: - * All 4 init values are ``extract_strided_slice`` at offsets - {0, 1, ..., SCALE_VECTOR_WIDTH-1} from the same ``vector<4xi8>`` - source. - * All 4 yield values follow the same pattern from a (possibly - different) ``vector<4xi8>`` source. + * At least 2 init values are ``extract_strided_slice`` at distinct + offsets from the same ``vector<4xi8>`` source. + * The corresponding yield values follow the same pattern from a + (possibly different) ``vector<4xi8>`` source. * For each member, init_offset == yield_offset (byte identity is preserved across iterations). + Partial groups (e.g. only offsets {0, 2}) are accepted — the + coalesced ``vector<4xi8>`` iter_arg simply carries unused bytes. + Returns a list of ``(init_source, yield_source, {offset: iter_index})``. """ i8 = IntegerType.get_signless(8) @@ -199,9 +208,6 @@ def _find_mergeable_groups( continue eligible.append((i, init_off, init_src, yield_src)) - # Group by init source. Multiple args can share the same source - # (e.g. two MFMAs using the same scale load), so partition by offset - # to form distinct groups of exactly 4. by_init_src = defaultdict(list) for entry in eligible: _, _, init_src, _ = entry @@ -214,12 +220,16 @@ def _find_mergeable_groups( _, off, _, _ = entry by_offset[off].append(entry) - while all(len(by_offset[o]) > 0 for o in range(SCALE_VECTOR_WIDTH)): + # Greedily form groups from available offsets. Accept any group + # with >= 2 distinct offsets (full groups of 4 are the common + # case; partial groups like {0, 2} arise from preshuffle scales). + present_offsets = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]] + while len(present_offsets) >= 2: members = {} init_source = None yield_owners = set() yield_source = None - for o in range(SCALE_VECTOR_WIDTH): + for o in present_offsets: idx, _, isrc, ysrc = by_offset[o].pop(0) members[o] = idx init_source = isrc @@ -227,6 +237,7 @@ def _find_mergeable_groups( yield_owners.add(id(ysrc.owner)) if len(yield_owners) == 1: result.append((init_source, yield_source, members)) + present_offsets = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]] return result @@ -314,15 +325,13 @@ def _rewire_for_results( members = plan.groups[g_idx][2] new_idx = plan.group_new_iter_idx[g_idx] has_users = any( - any(True for _ in old_results[members[o]].uses) - for o in range(SCALE_VECTOR_WIDTH) + any(True for _ in old_results[members[o]].uses) for o in members ) if not has_users: continue with InsertionPoint(for_op): - for o in range(SCALE_VECTOR_WIDTH): - old_i = members[o] + for o, old_i in members.items(): if not any(True for _ in old_results[old_i].uses): continue extract_slice = make_extract_slice(new_results[new_idx], o) @@ -330,12 +339,15 @@ def _rewire_for_results( def _coalesce_vector_iter_args(module: Module) -> None: - """Merge groups of 4 ``vector<1xi8>`` scf.for iter_args into ``vector<4xi8>``. + """Merge groups of ``vector<1xi8>`` scf.for iter_args into ``vector<4xi8>``. - Pipeline double-buffering splits a ``vector<4xi8>`` scale load into 4 + Pipeline double-buffering splits a ``vector<4xi8>`` scale load into individual bytes for loop-carry. This pass merges them back so that ``_trace_scale_chain`` sees the full ``extract_strided_slice`` pattern inside the loop body and the opsel optimisation fires. + + Handles both full groups (all 4 offsets present) and partial groups + (e.g. only offsets {0, 2} from preshuffle scales). """ i8 = IntegerType.get_signless(8) i64 = IntegerType.get_signless(64) @@ -364,7 +376,7 @@ def make_extract_slice(source: Value, offset: int): if not groups: continue - logger.debug(f"Coalescing {len(groups)} group(s) of 4 vector<1xi8> iter_args") + logger.debug(f"Coalescing {len(groups)} group(s) of vector<1xi8> iter_args") plan = _build_coalesce_plan(groups, for_view, yield_op) old_iter_args = list(for_view.inner_iter_args) @@ -396,8 +408,7 @@ def make_extract_slice(source: Value, offset: int): with InsertionPoint(first_op): for g_idx, (_, _, members) in enumerate(groups): merged_arg = new_for.inner_iter_args[plan.group_new_iter_idx[g_idx]] - for offset in range(SCALE_VECTOR_WIDTH): - iter_idx = members[offset] + for offset, iter_idx in members.items(): extract_slice = make_extract_slice(merged_arg, offset) extract_results[iter_idx] = extract_slice.result diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 9c1013378..1cb79e1bd 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -181,7 +181,7 @@ def transform_mod(expr): return None mult = m if (mult is None) or (m < mult) else mult terms.append(arg) - if c >= mult: + if c is None or mult is None or c >= mult: return None return (sum(terms) % q) + c @@ -409,6 +409,8 @@ def _numeric_eval_constant(expr, num_samples: int = 48): free, evaluator = (), None if not free: + if isinstance(expr, int): + return expr if expr.has(*_BAD_ATOMS): return None if expr.is_integer is not True: