From 519c70a117cb45f6c6a45fb09b109ba20b517c98 Mon Sep 17 00:00:00 2001 From: xintin Date: Mon, 2 Mar 2026 20:57:46 +0000 Subject: [PATCH 01/15] initial commit Signed-off-by: xintin --- .../analysis/partition_strided_operators.py | 187 ++++++++++++++++++ 1 file changed, 187 insertions(+) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 18b92c99f..9a66449a8 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -404,6 +404,192 @@ def has_gpr_offsets(node: fx.Node) -> bool: custom.graph.erase_node(custom.fx_node) +# TODO: I will add a field in IndexMapping to check if it is a preshuffle mapping. +# `kind`: `preshuffle` or `shuffleb` or `linear`. +def _is_preshuffle_mapping(mapping) -> bool: + """Check if a mapping uses the preshuffle e8m0_shuffle pattern.""" + if mapping is None: + return False + # Scale buffers are always 2D: (N or M, K/32). + if len(mapping.output_mapping) != 2 or len(mapping.input_mapping) != 2: + return False + # Output must be identity (logical coords), input must be shuffled. + if not (mapping.is_output_identity() and not mapping.is_input_identity()): + return False + # The preshuffle formula uses floor (m//32, k//4) and Mod (k%4, (m//16)%2) + # to interleave scale bytes within each 256-byte region. + input_atoms = set() + for expr in mapping.input_mapping.values(): + input_atoms.update(type(a) for a in sympy.preorder_traversal(expr)) + return sympy.floor in input_atoms and sympy.Mod in input_atoms + + +def _coalesce_preshuffle_global_reads( + trace: CapturedTrace, constraints: list[Constraint] +): + """Coalesce scattered preshuffle-mapped global byte reads into dword loads. + + After merge_contiguous_reads, preshuffle-mapped global reads may remain as + individual byte loads when the per-wave tile doesn't divide evenly into + groups of 4 contiguous bytes (e.g. BLOCK_N=192 with 4 waves gives 48 + N-rows/wave, whose second 32-row e8m0 block only yields 2 reads instead + of 4). + + Uses pairwise symbolic differences to find reads within the same 4-byte + aligned region, then replaces them with a single 4-byte read plus + ExtractSlice ops for the needed bytes. + """ + from collections import defaultdict + from ...compiler.utils import strides_from_symbolic_shape + from ..._support.indexing import IndexingContext + from ..utils.mapping_utils import transform_index_on_mapping + + idxc = IndexingContext.current() + + # Collect single-byte preshuffle scale reads from global memory, grouped + # by their source buffer. These are the reads that merge_contiguous_reads + # couldn't handle (e.g. because they have bounds, or their preshuffle + # offsets aren't strictly adjacent). + mem_groups: dict[int, list] = defaultdict(list) + for node in trace.walk(lambda n: isinstance(get_custom(n), Read)): + custom = get_custom(node) + # Already widened by merge_contiguous_reads + if custom.elements_per_thread != 1: + continue + # Only target reads with the e8m0 preshuffle mapping. + if not _is_preshuffle_mapping(custom.mapping): + continue + # Shared-memory reads are handled by preshuffle_scale_to_shared. + if subs_idxc(custom.memory_type.address_space) != GLOBAL_ADDRESS_SPACE: + continue + mem_groups[id(custom.memory)].append(node) + + coalesced = 0 + for _, nodes in mem_groups.items(): + # Need at least 2 reads to coalesce. + if len(nodes) < 2: + continue + + sample_custom = get_custom(nodes[0]) + memory_node = sample_custom.memory + memory = get_custom(memory_node) + symbolic_shape = memory.type.symbolic_shape + symbolic_dims = [infer_dim(d) for d in symbolic_shape] + strides = strides_from_symbolic_shape( + idxc, symbolic_shape, allow_mixed_shapes=True + ) + if strides is None: + continue + + # For each read, compute its physical flat offset (in bytes) and store + # it along with the read node, custom op, and physical coordinates. + read_infos = [] + for node in nodes: + custom = get_custom(node) + # Transform the logical index to physical coordinates using the mapping. + physical = transform_index_on_mapping( + custom.mapping, symbolic_shape, custom.index, is_read=True + ) + if not all(dim in physical for dim in symbolic_dims): + continue + # Compute the physical flat offset (in bytes) by multiplying the + # physical coordinates by the corresponding stride and summing the results. + flat = sum( + physical[dim] * stride for dim, stride in zip(symbolic_dims, strides) + ) + read_infos.append((node, custom, subs_idxc(flat), physical)) + + if len(read_infos) < 2: + continue + + # Try all pairs to find contiguous ones (diff == ept). + merged = set() + for i in range(len(read_infos)): + if i in merged: + continue + node_i, custom_i, flat_i, phys_i = read_infos[i] + + group = [(i, node_i, custom_i, 0, phys_i)] + for j in range(len(read_infos)): + if j == i or j in merged: + continue + node_j, custom_j, flat_j, phys_j = read_infos[j] + raw_diff = subs_idxc(flat_j - flat_i) + diff_val = _numeric_eval_constant(raw_diff) + if diff_val is None: + continue + if 0 < diff_val < 4: + group.append((j, node_j, custom_j, diff_val, phys_j)) + + if len(group) < 2: + continue + + group.sort(key=lambda x: x[3]) + max_off = group[-1][3] + if max_off >= 4: + continue + + base_phys = group[0][4] + + earliest_node = group[0][1] + for g in group[1:]: + candidate = g[1] + for n in custom_i.graph.nodes: + if n is candidate: + earliest_node = candidate + break + if n is earliest_node: + break + + with get_custom(earliest_node).graph.inserting_before(earliest_node): + # Create a new wide read with 4-byte elements, using the earliest + # read's physical coordinates as the base. + wide_index = {} + for dim_idx, dim in enumerate(symbolic_dims): + if dim_idx == len(symbolic_dims) - 1: + wide_index[dim] = IndexSequence(base_phys[dim], 4, 1) + else: + wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) + + wide_read = Read( + memory_node, + elements_per_thread=4, + mapping=None, + _write_dependency=custom_i._write_dependency, + flags=custom_i.flags, + ).add_to_graph(custom_i.graph, loc=custom_i.location) + wide_custom = get_custom(wide_read) + wide_custom.index = wide_index + if hasattr(earliest_node, "vector_shapes"): + wide_read.vector_shapes = deepcopy(earliest_node.vector_shapes) + propagate_tag(group[0][1], wide_read) + + # Create ExtractSlice ops for each byte in the contiguous region. + # because the wide read is 4-byte aligned, we need to extract the + # correct byte from the wide read. + for g_idx, g_node, g_custom, byte_pos, _ in group: + with g_custom.graph.inserting_before(g_node): + extract = ExtractSlice( + wide_read, [byte_pos], [1], [1] + ).add_to_graph(g_custom.graph, loc=g_custom.location) + if hasattr(g_node, "vector_shapes"): + get_custom(extract).vector_shapes = deepcopy( + g_node.vector_shapes + ) + propagate_tag(g_node, extract) + + g_custom.replace_all_uses_with(extract) + g_custom.graph.erase_node(g_node) + + merged.update(g[0] for g in group) + coalesced += 1 + + if coalesced > 0: + logger.info( + f"Coalesced {coalesced} preshuffle global read group(s) into dword loads" + ) + + def merge_contiguous_reads( trace: CapturedTrace, constraints: list[Constraint], target: str ): @@ -420,6 +606,7 @@ def merge_contiguous_reads( hw_constraint = get_hardware_constraint(constraints) while _merge_contiguous_reads_once(trace, hw_constraint): pass + _coalesce_preshuffle_global_reads(trace, constraints) def _get_physical_start( From a855992c7191423841a5e90066dd445ad799f874 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 01:16:17 +0000 Subject: [PATCH 02/15] remove bound check Signed-off-by: xintin --- .../wave/analysis/partition_strided_operators.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 9a66449a8..718506423 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -606,7 +606,7 @@ def merge_contiguous_reads( hw_constraint = get_hardware_constraint(constraints) while _merge_contiguous_reads_once(trace, hw_constraint): pass - _coalesce_preshuffle_global_reads(trace, constraints) + # _coalesce_preshuffle_global_reads(trace, constraints) def _get_physical_start( @@ -664,10 +664,8 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: continue if custom.mapping_dynamic_vals: continue - # Skip reads that have bounds: the merged read would lose the - # mapping and source→target index, making mask generation incorrect. - if custom.bounds is not None: - continue + # Bounded reads are allowed into the group; the merge loop + # below only merges pairs with identical bounds. key = (custom.memory, custom.elements_per_thread, region_id) groups[key].append(node) @@ -741,6 +739,14 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: else: continue + # Only merge bounded reads when both have the exact same + # bounds (same OOB mask). This lets preshuffle reads + # (which all share identical bounds) pass through even if + # they can't merge here — _coalesce_preshuffle_global_reads + # handles them downstream. + if lo_custom.bounds != hi_custom.bounds: + continue + # Find dimension that advances by ept. merge_dim = None for dim in symbolic_dims: From 840ff09c01eb8d0433da47bb152f557585004952 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 05:22:52 +0000 Subject: [PATCH 03/15] updated Signed-off-by: xintin --- .../analysis/partition_strided_operators.py | 325 +++++++----------- wave_lang/kernel/wave/opsel_scaled_mfma.py | 44 +-- 2 files changed, 150 insertions(+), 219 deletions(-) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 718506423..1202ee930 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -404,192 +404,6 @@ def has_gpr_offsets(node: fx.Node) -> bool: custom.graph.erase_node(custom.fx_node) -# TODO: I will add a field in IndexMapping to check if it is a preshuffle mapping. -# `kind`: `preshuffle` or `shuffleb` or `linear`. -def _is_preshuffle_mapping(mapping) -> bool: - """Check if a mapping uses the preshuffle e8m0_shuffle pattern.""" - if mapping is None: - return False - # Scale buffers are always 2D: (N or M, K/32). - if len(mapping.output_mapping) != 2 or len(mapping.input_mapping) != 2: - return False - # Output must be identity (logical coords), input must be shuffled. - if not (mapping.is_output_identity() and not mapping.is_input_identity()): - return False - # The preshuffle formula uses floor (m//32, k//4) and Mod (k%4, (m//16)%2) - # to interleave scale bytes within each 256-byte region. - input_atoms = set() - for expr in mapping.input_mapping.values(): - input_atoms.update(type(a) for a in sympy.preorder_traversal(expr)) - return sympy.floor in input_atoms and sympy.Mod in input_atoms - - -def _coalesce_preshuffle_global_reads( - trace: CapturedTrace, constraints: list[Constraint] -): - """Coalesce scattered preshuffle-mapped global byte reads into dword loads. - - After merge_contiguous_reads, preshuffle-mapped global reads may remain as - individual byte loads when the per-wave tile doesn't divide evenly into - groups of 4 contiguous bytes (e.g. BLOCK_N=192 with 4 waves gives 48 - N-rows/wave, whose second 32-row e8m0 block only yields 2 reads instead - of 4). - - Uses pairwise symbolic differences to find reads within the same 4-byte - aligned region, then replaces them with a single 4-byte read plus - ExtractSlice ops for the needed bytes. - """ - from collections import defaultdict - from ...compiler.utils import strides_from_symbolic_shape - from ..._support.indexing import IndexingContext - from ..utils.mapping_utils import transform_index_on_mapping - - idxc = IndexingContext.current() - - # Collect single-byte preshuffle scale reads from global memory, grouped - # by their source buffer. These are the reads that merge_contiguous_reads - # couldn't handle (e.g. because they have bounds, or their preshuffle - # offsets aren't strictly adjacent). - mem_groups: dict[int, list] = defaultdict(list) - for node in trace.walk(lambda n: isinstance(get_custom(n), Read)): - custom = get_custom(node) - # Already widened by merge_contiguous_reads - if custom.elements_per_thread != 1: - continue - # Only target reads with the e8m0 preshuffle mapping. - if not _is_preshuffle_mapping(custom.mapping): - continue - # Shared-memory reads are handled by preshuffle_scale_to_shared. - if subs_idxc(custom.memory_type.address_space) != GLOBAL_ADDRESS_SPACE: - continue - mem_groups[id(custom.memory)].append(node) - - coalesced = 0 - for _, nodes in mem_groups.items(): - # Need at least 2 reads to coalesce. - if len(nodes) < 2: - continue - - sample_custom = get_custom(nodes[0]) - memory_node = sample_custom.memory - memory = get_custom(memory_node) - symbolic_shape = memory.type.symbolic_shape - symbolic_dims = [infer_dim(d) for d in symbolic_shape] - strides = strides_from_symbolic_shape( - idxc, symbolic_shape, allow_mixed_shapes=True - ) - if strides is None: - continue - - # For each read, compute its physical flat offset (in bytes) and store - # it along with the read node, custom op, and physical coordinates. - read_infos = [] - for node in nodes: - custom = get_custom(node) - # Transform the logical index to physical coordinates using the mapping. - physical = transform_index_on_mapping( - custom.mapping, symbolic_shape, custom.index, is_read=True - ) - if not all(dim in physical for dim in symbolic_dims): - continue - # Compute the physical flat offset (in bytes) by multiplying the - # physical coordinates by the corresponding stride and summing the results. - flat = sum( - physical[dim] * stride for dim, stride in zip(symbolic_dims, strides) - ) - read_infos.append((node, custom, subs_idxc(flat), physical)) - - if len(read_infos) < 2: - continue - - # Try all pairs to find contiguous ones (diff == ept). - merged = set() - for i in range(len(read_infos)): - if i in merged: - continue - node_i, custom_i, flat_i, phys_i = read_infos[i] - - group = [(i, node_i, custom_i, 0, phys_i)] - for j in range(len(read_infos)): - if j == i or j in merged: - continue - node_j, custom_j, flat_j, phys_j = read_infos[j] - raw_diff = subs_idxc(flat_j - flat_i) - diff_val = _numeric_eval_constant(raw_diff) - if diff_val is None: - continue - if 0 < diff_val < 4: - group.append((j, node_j, custom_j, diff_val, phys_j)) - - if len(group) < 2: - continue - - group.sort(key=lambda x: x[3]) - max_off = group[-1][3] - if max_off >= 4: - continue - - base_phys = group[0][4] - - earliest_node = group[0][1] - for g in group[1:]: - candidate = g[1] - for n in custom_i.graph.nodes: - if n is candidate: - earliest_node = candidate - break - if n is earliest_node: - break - - with get_custom(earliest_node).graph.inserting_before(earliest_node): - # Create a new wide read with 4-byte elements, using the earliest - # read's physical coordinates as the base. - wide_index = {} - for dim_idx, dim in enumerate(symbolic_dims): - if dim_idx == len(symbolic_dims) - 1: - wide_index[dim] = IndexSequence(base_phys[dim], 4, 1) - else: - wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) - - wide_read = Read( - memory_node, - elements_per_thread=4, - mapping=None, - _write_dependency=custom_i._write_dependency, - flags=custom_i.flags, - ).add_to_graph(custom_i.graph, loc=custom_i.location) - wide_custom = get_custom(wide_read) - wide_custom.index = wide_index - if hasattr(earliest_node, "vector_shapes"): - wide_read.vector_shapes = deepcopy(earliest_node.vector_shapes) - propagate_tag(group[0][1], wide_read) - - # Create ExtractSlice ops for each byte in the contiguous region. - # because the wide read is 4-byte aligned, we need to extract the - # correct byte from the wide read. - for g_idx, g_node, g_custom, byte_pos, _ in group: - with g_custom.graph.inserting_before(g_node): - extract = ExtractSlice( - wide_read, [byte_pos], [1], [1] - ).add_to_graph(g_custom.graph, loc=g_custom.location) - if hasattr(g_node, "vector_shapes"): - get_custom(extract).vector_shapes = deepcopy( - g_node.vector_shapes - ) - propagate_tag(g_node, extract) - - g_custom.replace_all_uses_with(extract) - g_custom.graph.erase_node(g_node) - - merged.update(g[0] for g in group) - coalesced += 1 - - if coalesced > 0: - logger.info( - f"Coalesced {coalesced} preshuffle global read group(s) into dword loads" - ) - - def merge_contiguous_reads( trace: CapturedTrace, constraints: list[Constraint], target: str ): @@ -606,7 +420,6 @@ def merge_contiguous_reads( hw_constraint = get_hardware_constraint(constraints) while _merge_contiguous_reads_once(trace, hw_constraint): pass - # _coalesce_preshuffle_global_reads(trace, constraints) def _get_physical_start( @@ -635,11 +448,22 @@ def _get_physical_start( def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: - """Single merge pass: merge adjacent pairs of same-ept reads. + """Single merge pass: merge reads that access nearby physical memory. + + Two strategies are applied per (memory, ept) group: + + 1. **Pairwise contiguous merge** — pairs whose physical flat offset + starts differ by exactly ``ept`` are merged into a ``2*ept`` read + with two ExtractSlice outputs. + + 2. **Multi-way coalescing** (``ept==1`` only) — unmerged byte reads + whose flat offsets fall within a power-of-2 aligned window (up to + ``max_elems_per_load``) are replaced by a single wide read with + per-byte ExtractSlice outputs. Diffs are evaluated via numeric + probing (``_numeric_eval_constant``), so this handles non-constant + symbolic offsets such as preshuffle mappings. - Groups reads by (memory operand, ept) and merges pairs whose physical - flat offset starts differ by exactly ept. Returns True if any merges - happened. + Returns True if any merges or coalescing happened. """ from collections import defaultdict from ...compiler.utils import strides_from_symbolic_shape @@ -704,6 +528,16 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: off1, phys1, custom1, node1 = read_infos[i] off2, phys2, custom2, node2 = read_infos[j] + # The pairwise merge drops the mapping and produces a + # physical-space read with no mask. That is unsafe for + # bounded reads whose OOB behaviour relies on the mask + # generated from bounds + mapping. Skip them here; + # bounded ept==1 reads (e.g. preshuffle scales backed by + # fat_raw_buffer) are handled by the multi-way coalescing + # pass below, where OOB loads safely return zero. + if custom1.bounds is not None or custom2.bounds is not None: + continue + raw_diff = subs_idxc(off2 - off1) # For reads with non-identity mappings (e.g. preshuffle @@ -739,14 +573,6 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: else: continue - # Only merge bounded reads when both have the exact same - # bounds (same OOB mask). This lets preshuffle reads - # (which all share identical bounds) pass through even if - # they can't merge here — _coalesce_preshuffle_global_reads - # handles them downstream. - if lo_custom.bounds != hi_custom.bounds: - continue - # Find dimension that advances by ept. merge_dim = None for dim in symbolic_dims: @@ -826,6 +652,107 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: merged_any = True break + # Multi-way coalescing for ept==1 reads that survived the pairwise + # pass. Groups reads whose numerically-probed flat offsets fall + # within an aligned window of max_load_width bytes, then emits a + # single wide read + per-byte ExtractSlice ops. This generalises + # the former _coalesce_preshuffle_global_reads to all mapping types. + if ept == 1 and len(read_infos) >= 2: + element_type = get_custom(reads[0]).type.dtype + max_load_width = hw_constraint.max_elems_per_load(element_type) + + unmerged_infos = [ + read_infos[k] for k in range(len(read_infos)) if k not in merged + ] + if len(unmerged_infos) >= 2: + coalesced_set: set[int] = set() + for anchor_idx in range(len(unmerged_infos)): + if anchor_idx in coalesced_set: + continue + off_a, phys_a, custom_a, node_a = unmerged_infos[anchor_idx] + + group = [(anchor_idx, node_a, custom_a, 0, phys_a)] + for probe_idx in range(len(unmerged_infos)): + if probe_idx == anchor_idx or probe_idx in coalesced_set: + continue + off_p, _, custom_p, node_p = unmerged_infos[probe_idx] + raw_diff = subs_idxc(off_p - off_a) + diff_val = _numeric_eval_constant(raw_diff) + if diff_val is None: + continue + if 0 < diff_val < max_load_width: + _, _, _, phys_p = unmerged_infos[probe_idx] + group.append( + (probe_idx, node_p, custom_p, diff_val, phys_p) + ) + + if len(group) < 2: + continue + + group.sort(key=lambda x: x[3]) + max_off = group[-1][3] + wide_ept = 1 + while wide_ept <= max_off: + wide_ept *= 2 + if wide_ept > max_load_width: + continue + + base_phys = group[0][4] + + earliest_node = group[0][1] + for g in group[1:]: + candidate = g[1] + for n in custom_a.graph.nodes: + if n is candidate: + earliest_node = candidate + break + if n is earliest_node: + break + + with get_custom(earliest_node).graph.inserting_before( + earliest_node + ): + wide_index = {} + for dim_idx, dim in enumerate(symbolic_dims): + if dim_idx == len(symbolic_dims) - 1: + wide_index[dim] = IndexSequence( + base_phys[dim], wide_ept, 1 + ) + else: + wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) + + wide_read = Read( + custom_a.memory, + elements_per_thread=wide_ept, + mapping=None, + _write_dependency=custom_a._write_dependency, + flags=custom_a.flags, + ).add_to_graph(custom_a.graph, loc=custom_a.location) + wide_custom = get_custom(wide_read) + wide_custom.index = wide_index + if hasattr(earliest_node, "vector_shapes"): + wide_read.vector_shapes = deepcopy( + earliest_node.vector_shapes + ) + propagate_tag(group[0][1], wide_read) + + for g_idx, g_node, g_custom, byte_pos, _ in group: + with g_custom.graph.inserting_before(g_node): + extract = ExtractSlice( + wide_read, [byte_pos], [1], [1] + ).add_to_graph(g_custom.graph, loc=g_custom.location) + if hasattr(g_node, "vector_shapes"): + get_custom(extract).vector_shapes = deepcopy( + g_node.vector_shapes + ) + propagate_tag(g_node, extract) + + g_custom.replace_all_uses_with(extract) + g_custom.graph.erase_node(g_node) + + coalesced_set.update(g[0] for g in group) + merged_any = True + return merged_any diff --git a/wave_lang/kernel/wave/opsel_scaled_mfma.py b/wave_lang/kernel/wave/opsel_scaled_mfma.py index 2a4910b93..9172c8a30 100644 --- a/wave_lang/kernel/wave/opsel_scaled_mfma.py +++ b/wave_lang/kernel/wave/opsel_scaled_mfma.py @@ -165,17 +165,19 @@ def _find_yield_op(for_view) -> Optional[Operation]: def _find_mergeable_groups( for_view, yield_op: Operation ) -> list[tuple[Value, Value, dict[int, int]]]: - """Find groups of 4 ``vector<1xi8>`` iter_args that can be coalesced. + """Find groups of ``vector<1xi8>`` iter_args that can be coalesced. A group is valid when: - * All 4 init values are ``extract_strided_slice`` at offsets - {0, 1, ..., SCALE_VECTOR_WIDTH-1} from the same ``vector<4xi8>`` - source. - * All 4 yield values follow the same pattern from a (possibly - different) ``vector<4xi8>`` source. + * At least 2 init values are ``extract_strided_slice`` at distinct + offsets from the same ``vector<4xi8>`` source. + * The corresponding yield values follow the same pattern from a + (possibly different) ``vector<4xi8>`` source. * For each member, init_offset == yield_offset (byte identity is preserved across iterations). + Partial groups (e.g. only offsets {0, 2}) are accepted — the + coalesced ``vector<4xi8>`` iter_arg simply carries unused bytes. + Returns a list of ``(init_source, yield_source, {offset: iter_index})``. """ i8 = IntegerType.get_signless(8) @@ -199,9 +201,6 @@ def _find_mergeable_groups( continue eligible.append((i, init_off, init_src, yield_src)) - # Group by init source. Multiple args can share the same source - # (e.g. two MFMAs using the same scale load), so partition by offset - # to form distinct groups of exactly 4. by_init_src = defaultdict(list) for entry in eligible: _, _, init_src, _ = entry @@ -214,12 +213,16 @@ def _find_mergeable_groups( _, off, _, _ = entry by_offset[off].append(entry) - while all(len(by_offset[o]) > 0 for o in range(SCALE_VECTOR_WIDTH)): + # Greedily form groups from available offsets. Accept any group + # with >= 2 distinct offsets (full groups of 4 are the common + # case; partial groups like {0, 2} arise from preshuffle scales). + present_offsets = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]] + while len(present_offsets) >= 2: members = {} init_source = None yield_owners = set() yield_source = None - for o in range(SCALE_VECTOR_WIDTH): + for o in present_offsets: idx, _, isrc, ysrc = by_offset[o].pop(0) members[o] = idx init_source = isrc @@ -227,6 +230,7 @@ def _find_mergeable_groups( yield_owners.add(id(ysrc.owner)) if len(yield_owners) == 1: result.append((init_source, yield_source, members)) + present_offsets = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]] return result @@ -314,15 +318,13 @@ def _rewire_for_results( members = plan.groups[g_idx][2] new_idx = plan.group_new_iter_idx[g_idx] has_users = any( - any(True for _ in old_results[members[o]].uses) - for o in range(SCALE_VECTOR_WIDTH) + any(True for _ in old_results[members[o]].uses) for o in members ) if not has_users: continue with InsertionPoint(for_op): - for o in range(SCALE_VECTOR_WIDTH): - old_i = members[o] + for o, old_i in members.items(): if not any(True for _ in old_results[old_i].uses): continue extract_slice = make_extract_slice(new_results[new_idx], o) @@ -330,12 +332,15 @@ def _rewire_for_results( def _coalesce_vector_iter_args(module: Module) -> None: - """Merge groups of 4 ``vector<1xi8>`` scf.for iter_args into ``vector<4xi8>``. + """Merge groups of ``vector<1xi8>`` scf.for iter_args into ``vector<4xi8>``. - Pipeline double-buffering splits a ``vector<4xi8>`` scale load into 4 + Pipeline double-buffering splits a ``vector<4xi8>`` scale load into individual bytes for loop-carry. This pass merges them back so that ``_trace_scale_chain`` sees the full ``extract_strided_slice`` pattern inside the loop body and the opsel optimisation fires. + + Handles both full groups (all 4 offsets present) and partial groups + (e.g. only offsets {0, 2} from preshuffle scales). """ i8 = IntegerType.get_signless(8) i64 = IntegerType.get_signless(64) @@ -364,7 +369,7 @@ def make_extract_slice(source: Value, offset: int): if not groups: continue - logger.debug(f"Coalescing {len(groups)} group(s) of 4 vector<1xi8> iter_args") + logger.debug(f"Coalescing {len(groups)} group(s) of vector<1xi8> iter_args") plan = _build_coalesce_plan(groups, for_view, yield_op) old_iter_args = list(for_view.inner_iter_args) @@ -396,8 +401,7 @@ def make_extract_slice(source: Value, offset: int): with InsertionPoint(first_op): for g_idx, (_, _, members) in enumerate(groups): merged_arg = new_for.inner_iter_args[plan.group_new_iter_idx[g_idx]] - for offset in range(SCALE_VECTOR_WIDTH): - iter_idx = members[offset] + for offset, iter_idx in members.items(): extract_slice = make_extract_slice(merged_arg, offset) extract_results[iter_idx] = extract_slice.result From 948ee26b19f4039e6625b947701aaf9e3655485a Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 05:45:52 +0000 Subject: [PATCH 04/15] flat exprsn Signed-off-by: xintin --- .../kernel/compiler/wave_codegen/handlers.py | 12 ++++ .../analysis/partition_strided_operators.py | 72 ++++++++++++++++--- wave_lang/kernel/wave/opsel_scaled_mfma.py | 7 ++ wave_lang/kernel/wave/utils/symbol_utils.py | 2 + 4 files changed, 83 insertions(+), 10 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index b913675e1..b67783eb8 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -1955,6 +1955,18 @@ def handle_extract_slice(emitter: WaveEmitter, node: fx.Node): strides, ) + mask_expr = getattr(node, "precomputed_mask_expr", None) + if mask_expr is not None: + mask = gen_sympy_index(add_emitter_subs(emitter), mask_expr) + ept = sizes[0] if sizes else 1 + mask_vec_type = VectorType.get([ept], IntegerType.get_signless(1)) + if mask.type != mask_vec_type: + mask = vector_d.broadcast(mask_vec_type, mask) + element_type = result_type.element_type + zero = arith_d.constant(element_type, get_constant_attr(0, element_type)) + zero_vec = vector_d.broadcast(result_type, zero) + element = arith_d.select(mask, element, zero_vec) + emitter.bind_node_proxy(node, IRProxyValue(element)) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 1202ee930..9b641d2b9 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -447,6 +447,57 @@ def _get_physical_start( return {dim: custom.index[dim].start for dim in symbolic_dims} +def _flatten_bounds_to_mask_expr( + custom: Read, + symbolic_shape: tuple, +): + """Pre-compute a read's bounds check as a flat sympy boolean. + + Replicates the _build_mask_with_mapping decision logic: when the + mapping's transformed index contains all bounded dims (and has no + dynamic_val_indices), the transformed index is used; otherwise the + original logical index is used. Returns a sympy boolean expression + (e.g. And(idx < bound, ...)) or None if the read has no bounds. + + """ + if not custom.bounds: + return None + + import functools + from ..utils.mapping_utils import transform_index_on_mapping + + index = custom.index + + if custom.mapping is not None and not custom.has_identity_mapping(): + transformed = transform_index_on_mapping( + custom.mapping, symbolic_shape, index, is_read=True + ) + use_transformed = ( + all(dim in transformed for dim in custom.bounds) + and not custom.mapping.dynamic_val_indices + ) + if use_transformed: + index = transformed + + conditions = [] + for dim, bound in custom.bounds.items(): + if dim not in index: + continue + start = ( + index[dim].start if isinstance(index[dim], IndexSequence) else index[dim] + ) + if isinstance(start, int): + start = sympy.Integer(start) + if isinstance(bound, int): + bound = sympy.Integer(bound) + conditions.append(sympy.StrictLessThan(start, bound)) + + if not conditions: + return None + + return functools.reduce(sympy.And, conditions) + + def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: """Single merge pass: merge reads that access nearby physical memory. @@ -528,16 +579,6 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: off1, phys1, custom1, node1 = read_infos[i] off2, phys2, custom2, node2 = read_infos[j] - # The pairwise merge drops the mapping and produces a - # physical-space read with no mask. That is unsafe for - # bounded reads whose OOB behaviour relies on the mask - # generated from bounds + mapping. Skip them here; - # bounded ept==1 reads (e.g. preshuffle scales backed by - # fat_raw_buffer) are handled by the multi-way coalescing - # pass below, where OOB loads safely return zero. - if custom1.bounds is not None or custom2.bounds is not None: - continue - raw_diff = subs_idxc(off2 - off1) # For reads with non-identity mappings (e.g. preshuffle @@ -632,6 +673,9 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: get_custom(extract0).vector_shapes = deepcopy( lo_custom.vector_shapes ) + lo_mask = _flatten_bounds_to_mask_expr(lo_custom, symbolic_shape) + if lo_mask is not None: + extract0.precomputed_mask_expr = lo_mask propagate_tag(lo_node, extract0) extract1 = ExtractSlice( @@ -641,6 +685,9 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: get_custom(extract1).vector_shapes = deepcopy( hi_custom.vector_shapes ) + hi_mask = _flatten_bounds_to_mask_expr(hi_custom, symbolic_shape) + if hi_mask is not None: + extract1.precomputed_mask_expr = hi_mask propagate_tag(hi_node, extract1) lo_custom.replace_all_uses_with(extract0) @@ -745,6 +792,11 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: get_custom(extract).vector_shapes = deepcopy( g_node.vector_shapes ) + g_mask = _flatten_bounds_to_mask_expr( + g_custom, symbolic_shape + ) + if g_mask is not None: + extract.precomputed_mask_expr = g_mask propagate_tag(g_node, extract) g_custom.replace_all_uses_with(extract) diff --git a/wave_lang/kernel/wave/opsel_scaled_mfma.py b/wave_lang/kernel/wave/opsel_scaled_mfma.py index 9172c8a30..3880dfba8 100644 --- a/wave_lang/kernel/wave/opsel_scaled_mfma.py +++ b/wave_lang/kernel/wave/opsel_scaled_mfma.py @@ -98,6 +98,8 @@ def _trace_scale_chain(scale_value): return None slice_op = bitcast_source.owner + if _is_op_named(slice_op, "arith.select"): + slice_op = slice_op.operands[1].owner if not _is_op_named(slice_op, "vector.extract_strided_slice"): return None @@ -139,9 +141,14 @@ def _trace_extract_strided_slice( ) -> Optional[tuple[Value, int]]: """Check if *value* is produced by extract_strided_slice of a vector<4xi8>. + Looks through ``arith.select`` (inserted by flatten-bounds masking) + to find the underlying extract_strided_slice. + Returns ``(source_vec4xi8, byte_offset)`` or ``None``. """ op = value.owner + if _is_op_named(op, "arith.select"): + op = op.operands[1].owner if not _is_op_named(op, "vector.extract_strided_slice"): return None source = op.operands[0] diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 9c1013378..38b3ef3ff 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -409,6 +409,8 @@ def _numeric_eval_constant(expr, num_samples: int = 48): free, evaluator = (), None if not free: + if isinstance(expr, int): + return expr if expr.has(*_BAD_ATOMS): return None if expr.is_integer is not True: From 3170b2fa9a481b45762b48971de439c46b1b1efc Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 06:50:16 +0000 Subject: [PATCH 05/15] updated lit tests Signed-off-by: xintin --- lit_tests/kernel/wave/scaled_gemm.py | 12 ++++++------ lit_tests/kernel/wave/scaled_mma.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/lit_tests/kernel/wave/scaled_gemm.py b/lit_tests/kernel/wave/scaled_gemm.py index 7bf3b8b17..96dbdddfa 100644 --- a/lit_tests/kernel/wave/scaled_gemm.py +++ b/lit_tests/kernel/wave/scaled_gemm.py @@ -361,9 +361,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: } # Epilogue Local Read - # CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> - # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> # Epilogue MFMA @@ -471,8 +471,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: rocdl.s.waitcnt # CHECK: amdgpu.lds_barrier - # Steady state local loads - # CHECK-COUNT-48: vector.load{{.*}} memref<{{.*}}, #gpu.address_space> + # Steady state local loads (8+4 scale loads as vector<8xi8> + 16+8 data loads as vector<16xi8>) + # CHECK-COUNT-36: vector.load{{.*}} memref<{{.*}}, #gpu.address_space> # Steady State global load to lds # CHECK-COUNT-34: amdgpu.gather_to_lds @@ -637,9 +637,9 @@ def repeat( # CHECK: } # Epilogue Local Read - # CHECK-COUNT-16: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-16: vector.load {{.*}} : memref<34816xi8, #gpu.address_space>, vector<16xi8> - # CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-8: vector.load {{.*}} : memref<34816xi8, #gpu.address_space>, vector<16xi8> # Epilogue MFMA diff --git a/lit_tests/kernel/wave/scaled_mma.py b/lit_tests/kernel/wave/scaled_mma.py index f37dae9b5..a762fb4a2 100644 --- a/lit_tests/kernel/wave/scaled_mma.py +++ b/lit_tests/kernel/wave/scaled_mma.py @@ -352,9 +352,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-COUNT-4: vector.load {{.*}} : memref<16384x8192xi8, strided<[8192, 1]>>, vector<16xi8> # CHECK-COUNT-1: vector.load {{.*}} : memref<16384x512xi8, strided<[512, 1]>>, vector<4xi8> # CHECK: amdgpu.lds_barrier - # CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> - # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> # CHECK-COUNT-8: vector.bitcast {{.*}} : vector<16xi8> to vector<32xf4E2M1FN> # CHECK-COUNT-8: vector.bitcast {{.*}} : vector<1xi8> to vector<1xf8E8M0FNU> From 3c9f6cedddc3bd16498fe09a424edee54f777257 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 15:11:23 +0000 Subject: [PATCH 06/15] fix wave asm backend Signed-off-by: xintin --- wave_lang/kernel/wave/asm/handlers_memory.py | 23 +++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/wave_lang/kernel/wave/asm/handlers_memory.py b/wave_lang/kernel/wave/asm/handlers_memory.py index a32394028..e0ee25d3d 100644 --- a/wave_lang/kernel/wave/asm/handlers_memory.py +++ b/wave_lang/kernel/wave/asm/handlers_memory.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from wave_lang.support.ir_imports import ( + VectorType, amdgpu_d, gpu_d, memref_d, @@ -79,14 +80,24 @@ def handle_vector_extract_strided_slice_op( offset_val = int(str(offsets).split("[")[1].split("]")[0]) size_val = int(str(sizes).split("[")[1].split("]")[0]) + # Convert MLIR element offsets to physical register indices. + # Each 32-bit VGPR holds (32 // elem_bits) elements, e.g. 4 for i8. + # An extract at element offset 4 on vector<8xi8> targets register 1, + # not register 4. + source_vec_type = VectorType(operation.operands[0].type) + elem_bits = source_vec_type.element_type.width + elems_per_reg = 32 // elem_bits + + reg_offset = offset_val // elems_per_reg + reg_count = max(1, (size_val * elem_bits + 31) // 32) + # Extract the appropriate subset of registers - if size_val == 1: - # Single scalar extract - return just the one register as a tuple - extracted_reg = source_regs[offset_val] - result_regs = (extracted_reg,) + if reg_count == 1: + # Single register extract - return just the one register as a tuple + result_regs = (source_regs[reg_offset],) else: - # Multi-element extract - return a slice - result_regs = source_regs[offset_val : offset_val + size_val] + # Multi-register extract - return a slice + result_regs = source_regs[reg_offset : reg_offset + reg_count] result_ssa = str(operation.result) self.walker.kernel_ctx.ssa_to_reg[result_ssa] = result_regs From f72bd107afecb777d0a93543c1d2311e7f48adb7 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 19:47:14 +0000 Subject: [PATCH 07/15] lint Signed-off-by: xintin --- .../wave/analysis/partition_strided_operators.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 9b641d2b9..c7fc0784c 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools from collections.abc import Sequence from copy import deepcopy from itertools import groupby @@ -30,6 +31,7 @@ ) from ..assumptions import get_divisibility_subs from ..constraints import Constraint +from ..utils.mapping_utils import transform_index_on_mapping from ..utils.tag_utils import propagate_tag from ..utils.general_utils import ( all_equal, @@ -433,8 +435,6 @@ def _get_physical_start( coordinates. For identity-mapped reads (mapping=None), reads the start offsets directly from the index. """ - from ..utils.mapping_utils import transform_index_on_mapping - if custom.mapping is not None and not custom.has_identity_mapping(): physical = transform_index_on_mapping( custom.mapping, symbolic_shape, custom.index, is_read=True @@ -457,15 +457,12 @@ def _flatten_bounds_to_mask_expr( mapping's transformed index contains all bounded dims (and has no dynamic_val_indices), the transformed index is used; otherwise the original logical index is used. Returns a sympy boolean expression - (e.g. And(idx < bound, ...)) or None if the read has no bounds. + (eg. And(idx < bound, ...)) or None if the read has no bounds. """ if not custom.bounds: return None - import functools - from ..utils.mapping_utils import transform_index_on_mapping - index = custom.index if custom.mapping is not None and not custom.has_identity_mapping(): @@ -503,11 +500,11 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: Two strategies are applied per (memory, ept) group: - 1. **Pairwise contiguous merge** — pairs whose physical flat offset + 1. **Pairwise contiguous merge**: pairs whose physical flat offset starts differ by exactly ``ept`` are merged into a ``2*ept`` read with two ExtractSlice outputs. - 2. **Multi-way coalescing** (``ept==1`` only) — unmerged byte reads + 2. **Multi-way coalescing** (``ept==1`` only): unmerged byte reads whose flat offsets fall within a power-of-2 aligned window (up to ``max_elems_per_load``) are replaced by a single wide read with per-byte ExtractSlice outputs. Diffs are evaluated via numeric @@ -736,6 +733,9 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: if len(group) < 2: continue + # Pick the smallest power-of-2 width that covers all + # byte offsets in the group, and skip if it exceeds + # the hardware's max load width. group.sort(key=lambda x: x[3]) max_off = group[-1][3] wide_ept = 1 From e11256b76e05c581f54a0f0f278c9290be6b00c5 Mon Sep 17 00:00:00 2001 From: xintin Date: Wed, 4 Mar 2026 17:30:41 +0000 Subject: [PATCH 08/15] add comment Signed-off-by: xintin --- wave_lang/kernel/compiler/wave_codegen/handlers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index b67783eb8..09e68f56b 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -1955,6 +1955,10 @@ def handle_extract_slice(emitter: WaveEmitter, node: fx.Node): strides, ) + # When reads are coalesced into a single wider load, the original + # per-read bounds checks are lost. The partition pass preserves them + # as a precomputed sympy mask on each ExtractSlice so we can zero + # out-of-bounds lanes here. mask_expr = getattr(node, "precomputed_mask_expr", None) if mask_expr is not None: mask = gen_sympy_index(add_emitter_subs(emitter), mask_expr) From 2c71a628aec18fe13d2cbd890ddffea509a53b50 Mon Sep 17 00:00:00 2001 From: xintin Date: Wed, 4 Mar 2026 19:23:55 +0000 Subject: [PATCH 09/15] refactored Signed-off-by: xintin --- .../analysis/partition_strided_operators.py | 539 +++++++++--------- wave_lang/kernel/wave/utils/symbol_utils.py | 2 +- 2 files changed, 279 insertions(+), 262 deletions(-) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index c7fc0784c..74d4d475e 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -495,32 +495,53 @@ def _flatten_bounds_to_mask_expr( return functools.reduce(sympy.And, conditions) -def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: - """Single merge pass: merge reads that access nearby physical memory. - - Two strategies are applied per (memory, ept) group: - - 1. **Pairwise contiguous merge**: pairs whose physical flat offset - starts differ by exactly ``ept`` are merged into a ``2*ept`` read - with two ExtractSlice outputs. - - 2. **Multi-way coalescing** (``ept==1`` only): unmerged byte reads - whose flat offsets fall within a power-of-2 aligned window (up to - ``max_elems_per_load``) are replaced by a single wide read with - per-byte ExtractSlice outputs. Diffs are evaluated via numeric - probing (``_numeric_eval_constant``), so this handles non-constant - symbolic offsets such as preshuffle mappings. - - Returns True if any merges or coalescing happened. +def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source): + """Create a merged Read node covering ``wide_ept`` elements.""" + wide_read = Read( + anchor_custom.memory, + elements_per_thread=wide_ept, + mapping=None, + _write_dependency=anchor_custom._write_dependency, + flags=anchor_custom.flags, + ).add_to_graph(anchor_custom.graph, loc=anchor_custom.location) + wide_custom = get_custom(wide_read) + wide_custom.index = wide_index + if hasattr(tag_source, "vector_shapes"): + wide_read.vector_shapes = deepcopy(tag_source.vector_shapes) + propagate_tag(tag_source, wide_read) + return wide_read + + +def _emit_extract_slice( + wide_read, offset, size, orig_custom, orig_node, symbolic_shape +): + """Create an ExtractSlice from a wide read with bounds mask and tag.""" + extract = ExtractSlice(wide_read, [offset], [size], [1]).add_to_graph( + orig_custom.graph, loc=orig_custom.location + ) + extract_custom = get_custom(extract) + extract_custom.index = deepcopy(orig_custom.index) + if hasattr(orig_node, "vector_shapes"): + extract_custom.vector_shapes = deepcopy(orig_node.vector_shapes) + mask = _flatten_bounds_to_mask_expr(orig_custom, symbolic_shape) + if mask is not None: + extract.precomputed_mask_expr = mask + propagate_tag(orig_node, extract) + return extract + + +def _group_reads_by_memory( + trace: CapturedTrace, +) -> dict[tuple, list[fx.Node]]: + """Group reads by (memory, ept, region). + + A new region starts at each subgraph boundary and whenever a + side-effecting op (write, barrier, ...) is encountered, so we never + merge reads across such ops. Reads with dynamic mapping values are + skipped to keep the merge logic simple. """ from collections import defaultdict - from ...compiler.utils import strides_from_symbolic_shape - from ..._support.indexing import IndexingContext - # Group reads by (memory, ept, region). A new region starts at each - # subgraph boundary and whenever a side-effecting op (write, barrier, ...) - # is encountered, so we never merge reads across such ops. Reads with - # dynamic mapping values are skipped to keep the merge logic simple. groups: dict[tuple, list[fx.Node]] = defaultdict(list) region_id = 0 for subgraph in trace.region_graph.subgraphs.values(): @@ -536,11 +557,234 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: continue if custom.mapping_dynamic_vals: continue - # Bounded reads are allowed into the group; the merge loop - # below only merges pairs with identical bounds. key = (custom.memory, custom.elements_per_thread, region_id) groups[key].append(node) + return groups + + +def _resolve_symbolic_diff(raw_diff, has_complex_mapping, expected_vals=None): + """Resolve a raw sympy offset difference to a value, or None. + + Strategy: + 1. If already a plain int / sympy.Integer, return it directly. + 2. If the mapping is complex (non-identity), use numeric probing. + 3. Otherwise try sym_simplify. If ``expected_vals`` is given and + the result isn't among them, fall back to numeric probing; + return None when neither approach succeeds. + """ + if isinstance(raw_diff, (int, sympy.Integer)): + return int(raw_diff) + if has_complex_mapping: + return _numeric_eval_constant(raw_diff) + simplified = sym_simplify(raw_diff) + if expected_vals is None or simplified in expected_vals: + return simplified + nv = _numeric_eval_constant(raw_diff) + return nv + + +def _pairwise_merge(read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint): + """Merge pairs of reads whose flat offsets differ by exactly ``ept``. + + Returns ``(merged_indices, did_merge)`` where *merged_indices* is + the set of read_infos indices that were consumed. + """ + merged = set() + did_merge = False + + for i in range(len(read_infos)): + if i in merged: + continue + for j in range(i + 1, len(read_infos)): + if j in merged: + continue + off1, phys1, custom1, node1 = read_infos[i] + off2, phys2, custom2, node2 = read_infos[j] + + raw_diff = subs_idxc(off2 - off1) + has_complex_mapping = ( + custom1.mapping is not None and not custom1.has_identity_mapping() + ) + diff = _resolve_symbolic_diff( + raw_diff, has_complex_mapping, expected_vals={ept, -ept} + ) + if diff is None: + continue + + if diff == ept: + lo_phys, hi_phys = phys1, phys2 + lo_custom, hi_custom = custom1, custom2 + lo_node, hi_node = node1, node2 + elif diff == -ept: + lo_phys, hi_phys = phys2, phys1 + lo_custom, hi_custom = custom2, custom1 + lo_node, hi_node = node2, node1 + else: + continue + + merge_dim = None + for dim in symbolic_dims: + raw_d = subs_idxc(hi_phys[dim] - lo_phys[dim]) + d = _resolve_symbolic_diff( + raw_d, has_complex_mapping, expected_vals={ept, 0} + ) + if d is None: + merge_dim = None + break + if d == ept: + merge_dim = dim + elif d != 0: + merge_dim = None + break + if merge_dim is None: + continue + + new_ept = 2 * ept + element_type = lo_custom.type.dtype + if new_ept > hw_constraint.max_elems_per_load(element_type): + continue + with lo_custom.graph.inserting_before(lo_node): + new_index = { + dim: IndexSequence( + lo_phys[dim], + new_ept if dim == merge_dim else 1, + 1, + ) + for dim in symbolic_dims + } + merged_read = _emit_wide_read(lo_custom, new_index, new_ept, lo_node) + + # Masks are attached per-slice rather than on the merged + # read because lo and hi may have different bounds + # conditions (e.g. lo is in-bounds while hi crosses the + # tensor boundary). A single mask on the wide read + # cannot express that — each half needs its own. + lo_extract = _emit_extract_slice( + merged_read, 0, ept, lo_custom, lo_node, symbolic_shape + ) + hi_extract = _emit_extract_slice( + merged_read, ept, ept, hi_custom, hi_node, symbolic_shape + ) + + lo_custom.replace_all_uses_with(lo_extract) + hi_custom.replace_all_uses_with(hi_extract) + lo_custom.graph.erase_node(lo_node) + hi_custom.graph.erase_node(hi_node) + merged.update({i, j}) + did_merge = True + break + + return merged, did_merge + + +def _multiway_coalesce( + read_infos, merged, reads, symbolic_dims, symbolic_shape, hw_constraint +): + """Coalesce unmerged ept==1 reads whose flat offsets fall in an aligned window. + + Groups reads whose numerically-probed flat offsets fall within a + power-of-2 aligned window (up to ``max_elems_per_load``), then emits + a single wide read with per-byte ExtractSlice ops. + """ + element_type = get_custom(reads[0]).type.dtype + max_load_width = hw_constraint.max_elems_per_load(element_type) + + unmerged_infos = [read_infos[k] for k in range(len(read_infos)) if k not in merged] + if len(unmerged_infos) < 2: + return False + + coalesced_any = False + coalesced_set: set[int] = set() + for anchor_idx in range(len(unmerged_infos)): + if anchor_idx in coalesced_set: + continue + off_a, phys_a, custom_a, node_a = unmerged_infos[anchor_idx] + + group = [(anchor_idx, node_a, custom_a, 0, phys_a)] + for probe_idx in range(len(unmerged_infos)): + if probe_idx == anchor_idx or probe_idx in coalesced_set: + continue + off_p, _, custom_p, node_p = unmerged_infos[probe_idx] + raw_diff = subs_idxc(off_p - off_a) + diff_val = _numeric_eval_constant(raw_diff) + if diff_val is None: + continue + if 0 < diff_val < max_load_width: + _, _, _, phys_p = unmerged_infos[probe_idx] + group.append((probe_idx, node_p, custom_p, diff_val, phys_p)) + + if len(group) < 2: + continue + + group.sort(key=lambda x: x[3]) + max_off = group[-1][3] + wide_ept = 1 + while wide_ept <= max_off: + wide_ept *= 2 + if wide_ept > max_load_width: + continue + + base_phys = group[0][4] + + earliest_node = group[0][1] + for g in group[1:]: + candidate = g[1] + for n in custom_a.graph.nodes: + if n is candidate: + earliest_node = candidate + break + if n is earliest_node: + break + + with get_custom(earliest_node).graph.inserting_before(earliest_node): + wide_index = {} + for dim_idx, dim in enumerate(symbolic_dims): + if dim_idx == len(symbolic_dims) - 1: + wide_index[dim] = IndexSequence(base_phys[dim], wide_ept, 1) + else: + wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) + + wide_read = _emit_wide_read(custom_a, wide_index, wide_ept, earliest_node) + + extracts = [] + for g_idx, g_node, g_custom, byte_pos, _ in group: + with g_custom.graph.inserting_before(g_node): + ext = _emit_extract_slice( + wide_read, byte_pos, 1, g_custom, g_node, symbolic_shape + ) + extracts.append((ext, g_custom, g_node)) + + for ext, g_custom, g_node in extracts: + g_custom.replace_all_uses_with(ext) + g_custom.graph.erase_node(g_node) + + coalesced_set.update(g[0] for g in group) + coalesced_any = True + + return coalesced_any + + +def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: + """Single merge pass: merge reads that access nearby physical memory. + + Two strategies are applied per (memory, ept) group: + + 1. **Pairwise contiguous merge** (``_pairwise_merge``): pairs whose + physical flat offset starts differ by exactly ``ept`` are merged + into a ``2*ept`` read with two ExtractSlice outputs. + + 2. **Multi-way coalescing** (``_multiway_coalesce``, ``ept==1`` only): + unmerged byte reads whose flat offsets fall within a power-of-2 + aligned window are replaced by a single wide read with per-byte + ExtractSlice outputs. + + Returns True if any merges or coalescing happened. + """ + from ...compiler.utils import strides_from_symbolic_shape + from ..._support.indexing import IndexingContext + + groups = _group_reads_by_memory(trace) idxc = IndexingContext.current() merged_any = False @@ -566,244 +810,17 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: ) read_infos.append((flat_offset, phys_start, custom, node)) - merged = set() - for i in range(len(read_infos)): - if i in merged: - continue - for j in range(i + 1, len(read_infos)): - if j in merged: - continue - off1, phys1, custom1, node1 = read_infos[i] - off2, phys2, custom2, node2 = read_infos[j] - - raw_diff = subs_idxc(off2 - off1) - - # For reads with non-identity mappings (e.g. preshuffle - # scales), the flat-offset diff contains complex floor/Mod - # expressions that sympy.simplify cannot reduce. Use fast - # numeric probing instead. - has_complex_mapping = ( - custom1.mapping is not None and not custom1.has_identity_mapping() - ) + merged, did_merge = _pairwise_merge( + read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint + ) + merged_any |= did_merge - # subs_idxc may fully resolve to a plain int. - if isinstance(raw_diff, (int, sympy.Integer)): - diff = int(raw_diff) - elif has_complex_mapping: - diff = _numeric_eval_constant(raw_diff) - if diff is None: - continue - else: - diff = sym_simplify(raw_diff) - if diff != ept and diff != -ept: - nv = _numeric_eval_constant(raw_diff) - if nv is not None: - diff = nv - - if diff == ept: - lo_phys, hi_phys = phys1, phys2 - lo_custom, hi_custom = custom1, custom2 - lo_node, hi_node = node1, node2 - elif diff == -ept: - lo_phys, hi_phys = phys2, phys1 - lo_custom, hi_custom = custom2, custom1 - lo_node, hi_node = node2, node1 - else: - continue - - # Find dimension that advances by ept. - merge_dim = None - for dim in symbolic_dims: - raw_d = subs_idxc(hi_phys[dim] - lo_phys[dim]) - if isinstance(raw_d, (int, sympy.Integer)): - d = int(raw_d) - elif has_complex_mapping: - d = _numeric_eval_constant(raw_d) - if d is None: - merge_dim = None - break - else: - d = sym_simplify(raw_d) - if d != ept and d != 0: - nv = _numeric_eval_constant(raw_d) - if nv is not None: - d = nv - if d == ept: - merge_dim = dim - elif not (d == 0): - merge_dim = None - break - if merge_dim is None: - continue - - # Respect hardware vector width limit. - new_ept = 2 * ept - element_type = lo_custom.type.dtype - if new_ept > hw_constraint.max_elems_per_load(element_type): - continue - with lo_custom.graph.inserting_before(lo_node): - new_index = { - dim: IndexSequence( - lo_phys[dim], - new_ept if dim == merge_dim else 1, - 1, - ) - for dim in symbolic_dims - } - - merged_read = Read( - lo_custom.memory, - elements_per_thread=new_ept, - mapping=None, - _write_dependency=lo_custom._write_dependency, - flags=lo_custom.flags, - ).add_to_graph(lo_custom.graph, loc=lo_custom.location) - merged_custom = get_custom(merged_read) - merged_custom.index = new_index - merged_custom.vector_shapes = deepcopy(lo_custom.vector_shapes) - propagate_tag(lo_node, merged_read) - - extract0 = ExtractSlice(merged_read, [0], [ept], [1]).add_to_graph( - lo_custom.graph, loc=lo_custom.location - ) - get_custom(extract0).index = deepcopy(lo_custom.index) - get_custom(extract0).vector_shapes = deepcopy( - lo_custom.vector_shapes - ) - lo_mask = _flatten_bounds_to_mask_expr(lo_custom, symbolic_shape) - if lo_mask is not None: - extract0.precomputed_mask_expr = lo_mask - propagate_tag(lo_node, extract0) - - extract1 = ExtractSlice( - merged_read, [ept], [ept], [1] - ).add_to_graph(lo_custom.graph, loc=lo_custom.location) - get_custom(extract1).index = deepcopy(hi_custom.index) - get_custom(extract1).vector_shapes = deepcopy( - hi_custom.vector_shapes - ) - hi_mask = _flatten_bounds_to_mask_expr(hi_custom, symbolic_shape) - if hi_mask is not None: - extract1.precomputed_mask_expr = hi_mask - propagate_tag(hi_node, extract1) - - lo_custom.replace_all_uses_with(extract0) - hi_custom.replace_all_uses_with(extract1) - lo_custom.graph.erase_node(lo_node) - hi_custom.graph.erase_node(hi_node) - - merged.update({i, j}) - merged_any = True - break - - # Multi-way coalescing for ept==1 reads that survived the pairwise - # pass. Groups reads whose numerically-probed flat offsets fall - # within an aligned window of max_load_width bytes, then emits a - # single wide read + per-byte ExtractSlice ops. This generalises - # the former _coalesce_preshuffle_global_reads to all mapping types. + # Only ept==1 (byte) reads need multi-way coalescing; wider reads + # are already handled by the pairwise merge above. if ept == 1 and len(read_infos) >= 2: - element_type = get_custom(reads[0]).type.dtype - max_load_width = hw_constraint.max_elems_per_load(element_type) - - unmerged_infos = [ - read_infos[k] for k in range(len(read_infos)) if k not in merged - ] - if len(unmerged_infos) >= 2: - coalesced_set: set[int] = set() - for anchor_idx in range(len(unmerged_infos)): - if anchor_idx in coalesced_set: - continue - off_a, phys_a, custom_a, node_a = unmerged_infos[anchor_idx] - - group = [(anchor_idx, node_a, custom_a, 0, phys_a)] - for probe_idx in range(len(unmerged_infos)): - if probe_idx == anchor_idx or probe_idx in coalesced_set: - continue - off_p, _, custom_p, node_p = unmerged_infos[probe_idx] - raw_diff = subs_idxc(off_p - off_a) - diff_val = _numeric_eval_constant(raw_diff) - if diff_val is None: - continue - if 0 < diff_val < max_load_width: - _, _, _, phys_p = unmerged_infos[probe_idx] - group.append( - (probe_idx, node_p, custom_p, diff_val, phys_p) - ) - - if len(group) < 2: - continue - - # Pick the smallest power-of-2 width that covers all - # byte offsets in the group, and skip if it exceeds - # the hardware's max load width. - group.sort(key=lambda x: x[3]) - max_off = group[-1][3] - wide_ept = 1 - while wide_ept <= max_off: - wide_ept *= 2 - if wide_ept > max_load_width: - continue - - base_phys = group[0][4] - - earliest_node = group[0][1] - for g in group[1:]: - candidate = g[1] - for n in custom_a.graph.nodes: - if n is candidate: - earliest_node = candidate - break - if n is earliest_node: - break - - with get_custom(earliest_node).graph.inserting_before( - earliest_node - ): - wide_index = {} - for dim_idx, dim in enumerate(symbolic_dims): - if dim_idx == len(symbolic_dims) - 1: - wide_index[dim] = IndexSequence( - base_phys[dim], wide_ept, 1 - ) - else: - wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) - - wide_read = Read( - custom_a.memory, - elements_per_thread=wide_ept, - mapping=None, - _write_dependency=custom_a._write_dependency, - flags=custom_a.flags, - ).add_to_graph(custom_a.graph, loc=custom_a.location) - wide_custom = get_custom(wide_read) - wide_custom.index = wide_index - if hasattr(earliest_node, "vector_shapes"): - wide_read.vector_shapes = deepcopy( - earliest_node.vector_shapes - ) - propagate_tag(group[0][1], wide_read) - - for g_idx, g_node, g_custom, byte_pos, _ in group: - with g_custom.graph.inserting_before(g_node): - extract = ExtractSlice( - wide_read, [byte_pos], [1], [1] - ).add_to_graph(g_custom.graph, loc=g_custom.location) - if hasattr(g_node, "vector_shapes"): - get_custom(extract).vector_shapes = deepcopy( - g_node.vector_shapes - ) - g_mask = _flatten_bounds_to_mask_expr( - g_custom, symbolic_shape - ) - if g_mask is not None: - extract.precomputed_mask_expr = g_mask - propagate_tag(g_node, extract) - - g_custom.replace_all_uses_with(extract) - g_custom.graph.erase_node(g_node) - - coalesced_set.update(g[0] for g in group) - merged_any = True + merged_any |= _multiway_coalesce( + read_infos, merged, reads, symbolic_dims, symbolic_shape, hw_constraint + ) return merged_any diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 38b3ef3ff..1cb79e1bd 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -181,7 +181,7 @@ def transform_mod(expr): return None mult = m if (mult is None) or (m < mult) else mult terms.append(arg) - if c >= mult: + if c is None or mult is None or c >= mult: return None return (sum(terms) % q) + c From 62c3bb3bdc8e1a86aad1a4674c0026f8d94c4fee Mon Sep 17 00:00:00 2001 From: xintin Date: Wed, 4 Mar 2026 20:18:48 +0000 Subject: [PATCH 10/15] one mask for wide read Signed-off-by: xintin --- .../kernel/compiler/wave_codegen/handlers.py | 16 ---- .../compiler/wave_codegen/read_write.py | 13 +++- .../analysis/partition_strided_operators.py | 76 ++++++++++++++++--- 3 files changed, 76 insertions(+), 29 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index 09e68f56b..b913675e1 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -1955,22 +1955,6 @@ def handle_extract_slice(emitter: WaveEmitter, node: fx.Node): strides, ) - # When reads are coalesced into a single wider load, the original - # per-read bounds checks are lost. The partition pass preserves them - # as a precomputed sympy mask on each ExtractSlice so we can zero - # out-of-bounds lanes here. - mask_expr = getattr(node, "precomputed_mask_expr", None) - if mask_expr is not None: - mask = gen_sympy_index(add_emitter_subs(emitter), mask_expr) - ept = sizes[0] if sizes else 1 - mask_vec_type = VectorType.get([ept], IntegerType.get_signless(1)) - if mask.type != mask_vec_type: - mask = vector_d.broadcast(mask_vec_type, mask) - element_type = result_type.element_type - zero = arith_d.constant(element_type, get_constant_attr(0, element_type)) - zero_vec = vector_d.broadcast(result_type, zero) - element = arith_d.select(mask, element, zero_vec) - emitter.bind_node_proxy(node, IRProxyValue(element)) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 5caaada52..aedf6b20a 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -711,7 +711,18 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): ) dynamic_vals_map_start = _build_dyn_vals_map(mapping, dyn_vals) - if mapping: + is_global_mem = kb_src.type.memory_space is None + buffer_ops_enabled = emitter.options.use_buffer_ops and is_global_mem + + precomputed_mask_expr = getattr(node, "precomputed_mask_expr", None) + if precomputed_mask_expr is not None and not buffer_ops_enabled: + mask = gen_sympy_index(add_emitter_subs(emitter), precomputed_mask_expr) + mask_vec_type = VectorType.get( + [elements_per_thread], IntegerType.get_signless(1) + ) + if mask.type != mask_vec_type: + mask = vector_d.broadcast(mask_vec_type, mask) + elif mapping: transformed_index = transform_index_on_mapping( mapping, input_shape, index, is_read=True ) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 74d4d475e..8d6a9b90e 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -495,7 +495,49 @@ def _flatten_bounds_to_mask_expr( return functools.reduce(sympy.And, conditions) -def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source): +def _build_wide_mask_expr(sub_reads, symbolic_shape, wide_ept): + """Build a concatenated sympy mask for a wide read from its sub-reads. + + Each entry in *sub_reads* is ``(offset, size, orig_custom)`` where + *offset* is the lane offset within the wide vector and *size* is the + number of lanes that sub-read occupies. *wide_ept* is the total + number of elements in the wide read (may exceed ``sum(sizes)`` when + there are gaps, e.g. multiway coalesce with non-contiguous offsets). + + Builds ``Or(And(lane_in_range_0, mask_0), And(lane_in_range_1, mask_1), ...)`` + using pure boolean ops so that ``gen_sympy_index`` can lower it without + the nested-Piecewise ``select_stack`` ordering issue. + + Returns a sympy boolean expression or ``None`` when no sub-read has + bounds. + """ + from ..._support.indexing import IndexingContext + + masks = [ + (offset, size, _flatten_bounds_to_mask_expr(custom, symbolic_shape)) + for offset, size, custom in sub_reads + ] + + if not any(m is not None for _, _, m in masks): + return None + + idxc = IndexingContext.current() + iota = idxc.iota(wide_ept) + + terms = [] + for offset, size, mask in masks: + upper = offset + size + lane_cond = sympy.And( + sympy.GreaterThan(iota, offset) if offset > 0 else sympy.true, + sympy.StrictLessThan(iota, upper), + ) + bound_cond = mask if mask is not None else sympy.true + terms.append(sympy.And(lane_cond, bound_cond)) + + return functools.reduce(sympy.Or, terms) + + +def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source, mask_expr=None): """Create a merged Read node covering ``wide_ept`` elements.""" wide_read = Read( anchor_custom.memory, @@ -508,6 +550,8 @@ def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source): wide_custom.index = wide_index if hasattr(tag_source, "vector_shapes"): wide_read.vector_shapes = deepcopy(tag_source.vector_shapes) + if mask_expr is not None: + wide_read.precomputed_mask_expr = mask_expr propagate_tag(tag_source, wide_read) return wide_read @@ -515,7 +559,7 @@ def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source): def _emit_extract_slice( wide_read, offset, size, orig_custom, orig_node, symbolic_shape ): - """Create an ExtractSlice from a wide read with bounds mask and tag.""" + """Create an ExtractSlice from a wide read and propagate metadata.""" extract = ExtractSlice(wide_read, [offset], [size], [1]).add_to_graph( orig_custom.graph, loc=orig_custom.location ) @@ -523,9 +567,6 @@ def _emit_extract_slice( extract_custom.index = deepcopy(orig_custom.index) if hasattr(orig_node, "vector_shapes"): extract_custom.vector_shapes = deepcopy(orig_node.vector_shapes) - mask = _flatten_bounds_to_mask_expr(orig_custom, symbolic_shape) - if mask is not None: - extract.precomputed_mask_expr = mask propagate_tag(orig_node, extract) return extract @@ -557,6 +598,8 @@ def _group_reads_by_memory( continue if custom.mapping_dynamic_vals: continue + if getattr(node, "precomputed_mask_expr", None) is not None: + continue key = (custom.memory, custom.elements_per_thread, region_id) groups[key].append(node) return groups @@ -643,6 +686,11 @@ def _pairwise_merge(read_infos, ept, symbolic_dims, symbolic_shape, hw_constrain element_type = lo_custom.type.dtype if new_ept > hw_constraint.max_elems_per_load(element_type): continue + wide_mask = _build_wide_mask_expr( + [(0, ept, lo_custom), (ept, ept, hi_custom)], + symbolic_shape, + new_ept, + ) with lo_custom.graph.inserting_before(lo_node): new_index = { dim: IndexSequence( @@ -652,13 +700,10 @@ def _pairwise_merge(read_infos, ept, symbolic_dims, symbolic_shape, hw_constrain ) for dim in symbolic_dims } - merged_read = _emit_wide_read(lo_custom, new_index, new_ept, lo_node) + merged_read = _emit_wide_read( + lo_custom, new_index, new_ept, lo_node, mask_expr=wide_mask + ) - # Masks are attached per-slice rather than on the merged - # read because lo and hi may have different bounds - # conditions (e.g. lo is in-bounds while hi crosses the - # tensor boundary). A single mask on the wide read - # cannot express that — each half needs its own. lo_extract = _emit_extract_slice( merged_read, 0, ept, lo_custom, lo_node, symbolic_shape ) @@ -737,6 +782,11 @@ def _multiway_coalesce( if n is earliest_node: break + wide_mask = _build_wide_mask_expr( + [(byte_pos, 1, g_custom) for _, _, g_custom, byte_pos, _ in group], + symbolic_shape, + wide_ept, + ) with get_custom(earliest_node).graph.inserting_before(earliest_node): wide_index = {} for dim_idx, dim in enumerate(symbolic_dims): @@ -745,7 +795,9 @@ def _multiway_coalesce( else: wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) - wide_read = _emit_wide_read(custom_a, wide_index, wide_ept, earliest_node) + wide_read = _emit_wide_read( + custom_a, wide_index, wide_ept, earliest_node, mask_expr=wide_mask + ) extracts = [] for g_idx, g_node, g_custom, byte_pos, _ in group: From e7b061a7a5babcd7701c4005dc981243a09e7bfd Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 4 Mar 2026 23:10:37 +0100 Subject: [PATCH 11/15] Use numeric probing for pairwise read merging to avoid symbolic diff explosion MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The old _pairwise_merge used O(n²) symbolic diff resolution via sympy.lambdify, which hangs on huge preshuffle index expressions (postorder_traversal of the diff tree never completes). Replace with xreplace-based numeric evaluation of each offset independently, dict lookup for O(1) candidate matching, and verification across multiple probe value sets. Fixes dynamic preshuffle MXFP4 GEMM compilation hanging in merge_contiguous_reads (128 reads now merge in ~1s). Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ivan Butygin --- .../analysis/partition_strided_operators.py | 322 +++++++++++++----- 1 file changed, 236 insertions(+), 86 deletions(-) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 8d6a9b90e..75fd99089 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -626,99 +626,199 @@ def _resolve_symbolic_diff(raw_diff, has_complex_mapping, expected_vals=None): return nv +def _do_merge( + lo_i, hi_i, merge_dim, read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint +): + """Emit a wide read merging reads at lo_i and hi_i. Returns True on success.""" + _, lo_phys, lo_custom, lo_node = read_infos[lo_i] + _, _, hi_custom, hi_node = read_infos[hi_i] + new_ept = 2 * ept + element_type = lo_custom.type.dtype + if new_ept > hw_constraint.max_elems_per_load(element_type): + return False + wide_mask = _build_wide_mask_expr( + [(0, ept, lo_custom), (ept, ept, hi_custom)], + symbolic_shape, + new_ept, + ) + with lo_custom.graph.inserting_before(lo_node): + new_index = { + dim: IndexSequence( + lo_phys[dim], + new_ept if dim == merge_dim else 1, + 1, + ) + for dim in symbolic_dims + } + merged_read = _emit_wide_read( + lo_custom, new_index, new_ept, lo_node, mask_expr=wide_mask + ) + lo_extract = _emit_extract_slice( + merged_read, 0, ept, lo_custom, lo_node, symbolic_shape + ) + hi_extract = _emit_extract_slice( + merged_read, ept, ept, hi_custom, hi_node, symbolic_shape + ) + lo_custom.replace_all_uses_with(lo_extract) + hi_custom.replace_all_uses_with(hi_extract) + lo_custom.graph.erase_node(lo_node) + hi_custom.graph.erase_node(hi_node) + return True + + +def _find_merge_dim_from_diffs(dim_diffs, ept, symbolic_dims): + """Return the single dimension whose diff equals ept, or None.""" + merge_dim = None + for dim in symbolic_dims: + d = dim_diffs[dim] + if d == ept: + if merge_dim is not None: + return None + merge_dim = dim + elif d != 0: + return None + return merge_dim + + +# Probe value sets for numeric offset evaluation. Diverse primes avoid +# floor/Mod aliasing; all positive (symbols are nonneg). +_MERGE_PROBES = [ + lambda i: 137 + i * 31, + lambda i: 251 + i * 47, + lambda i: 503 + i * 17, +] + + +def _eval_expr(expr, probe_map): + """Evaluate a sympy expression with concrete symbol values.""" + if isinstance(expr, (int, float)): + return int(expr) + return int(expr.xreplace(probe_map)) + + def _pairwise_merge(read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint): """Merge pairs of reads whose flat offsets differ by exactly ``ept``. Returns ``(merged_indices, did_merge)`` where *merged_indices* is the set of read_infos indices that were consumed. + + Evaluates each offset independently with concrete probe values and + uses dict lookup for O(n) candidate matching, avoiding O(n²) + symbolic diff resolution. """ + n = len(read_infos) + if n < 2: + return set(), False + merged = set() did_merge = False - for i in range(len(read_infos)): - if i in merged: - continue - for j in range(i + 1, len(read_infos)): - if j in merged: - continue - off1, phys1, custom1, node1 = read_infos[i] - off2, phys2, custom2, node2 = read_infos[j] + # Resolve all flat offsets and per-dim physical starts once. + resolved_flat = [subs_idxc(info[0]) for info in read_infos] + resolved_phys = [ + {dim: subs_idxc(info[1][dim]) for dim in symbolic_dims} for info in read_infos + ] - raw_diff = subs_idxc(off2 - off1) - has_complex_mapping = ( - custom1.mapping is not None and not custom1.has_identity_mapping() - ) - diff = _resolve_symbolic_diff( - raw_diff, has_complex_mapping, expected_vals={ept, -ept} - ) - if diff is None: - continue + # Collect free symbols across all expressions. + all_free = set() + for expr in resolved_flat: + if hasattr(expr, "free_symbols"): + all_free.update(expr.free_symbols) + for phys in resolved_phys: + for expr in phys.values(): + if hasattr(expr, "free_symbols"): + all_free.update(expr.free_symbols) + free_list = sorted(all_free, key=str) + + # Build probe maps. + probe_maps = [{s: gen(i) for i, s in enumerate(free_list)} for gen in _MERGE_PROBES] + + # Evaluate flat offsets with first probe for candidate matching. + num_flat_0 = [None] * n + for i in range(n): + try: + num_flat_0[i] = _eval_expr(resolved_flat[i], probe_maps[0]) + except (TypeError, ValueError, ZeroDivisionError): + pass + + # Evaluate per-dim physical starts with first probe. + num_phys_0 = [None] * n + for i in range(n): + if num_flat_0[i] is None: + continue + try: + num_phys_0[i] = { + dim: _eval_expr(resolved_phys[i][dim], probe_maps[0]) + for dim in symbolic_dims + } + except (TypeError, ValueError, ZeroDivisionError): + pass + + # Build dict: numeric_flat_offset -> [indices] for O(1) partner lookup. + from collections import defaultdict - if diff == ept: - lo_phys, hi_phys = phys1, phys2 - lo_custom, hi_custom = custom1, custom2 - lo_node, hi_node = node1, node2 - elif diff == -ept: - lo_phys, hi_phys = phys2, phys1 - lo_custom, hi_custom = custom2, custom1 - lo_node, hi_node = node2, node1 - else: - continue + offset_map = defaultdict(list) + for i in range(n): + if num_flat_0[i] is not None: + offset_map[num_flat_0[i]].append(i) - merge_dim = None - for dim in symbolic_dims: - raw_d = subs_idxc(hi_phys[dim] - lo_phys[dim]) - d = _resolve_symbolic_diff( - raw_d, has_complex_mapping, expected_vals={ept, 0} - ) - if d is None: - merge_dim = None - break - if d == ept: - merge_dim = dim - elif d != 0: - merge_dim = None - break - if merge_dim is None: - continue + def _verify_with_extra_probes(lo_i, hi_i, expected_flat_diff, expected_dim_diffs): + """Confirm diffs are consistent across additional probe sets.""" + for probe in probe_maps[1:]: + try: + flat_lo = _eval_expr(resolved_flat[lo_i], probe) + flat_hi = _eval_expr(resolved_flat[hi_i], probe) + if flat_hi - flat_lo != expected_flat_diff: + return False + for dim in symbolic_dims: + d_lo = _eval_expr(resolved_phys[lo_i][dim], probe) + d_hi = _eval_expr(resolved_phys[hi_i][dim], probe) + if d_hi - d_lo != expected_dim_diffs[dim]: + return False + except (TypeError, ValueError, ZeroDivisionError): + return False + return True - new_ept = 2 * ept - element_type = lo_custom.type.dtype - if new_ept > hw_constraint.max_elems_per_load(element_type): - continue - wide_mask = _build_wide_mask_expr( - [(0, ept, lo_custom), (ept, ept, hi_custom)], - symbolic_shape, - new_ept, - ) - with lo_custom.graph.inserting_before(lo_node): - new_index = { - dim: IndexSequence( - lo_phys[dim], - new_ept if dim == merge_dim else 1, - 1, - ) + for i in range(n): + if i in merged or num_flat_0[i] is None: + continue + vi = num_flat_0[i] + found = False + for target, i_is_lo in ((vi + ept, True), (vi - ept, False)): + for j in offset_map.get(target, []): + if j in merged or j == i: + continue + if num_phys_0[j] is None: + continue + lo_i, hi_i = (i, j) if i_is_lo else (j, i) + # Per-dim check with first probe. + dim_diffs = { + dim: num_phys_0[hi_i][dim] - num_phys_0[lo_i][dim] for dim in symbolic_dims } - merged_read = _emit_wide_read( - lo_custom, new_index, new_ept, lo_node, mask_expr=wide_mask - ) - - lo_extract = _emit_extract_slice( - merged_read, 0, ept, lo_custom, lo_node, symbolic_shape - ) - hi_extract = _emit_extract_slice( - merged_read, ept, ept, hi_custom, hi_node, symbolic_shape - ) - - lo_custom.replace_all_uses_with(lo_extract) - hi_custom.replace_all_uses_with(hi_extract) - lo_custom.graph.erase_node(lo_node) - hi_custom.graph.erase_node(hi_node) - - merged.update({i, j}) - did_merge = True - break + merge_dim = _find_merge_dim_from_diffs(dim_diffs, ept, symbolic_dims) + if merge_dim is None: + continue + # Verify with additional probes. + flat_diff = num_flat_0[hi_i] - num_flat_0[lo_i] + if not _verify_with_extra_probes(lo_i, hi_i, flat_diff, dim_diffs): + continue + if _do_merge( + lo_i, + hi_i, + merge_dim, + read_infos, + ept, + symbolic_dims, + symbolic_shape, + hw_constraint, + ): + merged.update({i, j}) + did_merge = True + found = True + break + if found: + break return merged, did_merge @@ -739,25 +839,63 @@ def _multiway_coalesce( if len(unmerged_infos) < 2: return False + # Pre-evaluate flat offsets with probe values to avoid symbolic diffs. + resolved_offs = [subs_idxc(info[0]) for info in unmerged_infos] + all_free = set() + for expr in resolved_offs: + if hasattr(expr, "free_symbols"): + all_free.update(expr.free_symbols) + free_list = sorted(all_free, key=str) + probe0 = {s: _MERGE_PROBES[0](i) for i, s in enumerate(free_list)} + extra_probes = [ + {s: gen(i) for i, s in enumerate(free_list)} for gen in _MERGE_PROBES[1:] + ] + num_offs = [None] * len(unmerged_infos) + for i, expr in enumerate(resolved_offs): + try: + num_offs[i] = _eval_expr(expr, probe0) + except (TypeError, ValueError, ZeroDivisionError): + pass + coalesced_any = False coalesced_set: set[int] = set() for anchor_idx in range(len(unmerged_infos)): if anchor_idx in coalesced_set: continue - off_a, phys_a, custom_a, node_a = unmerged_infos[anchor_idx] + if num_offs[anchor_idx] is None: + continue + _, phys_a, custom_a, node_a = unmerged_infos[anchor_idx] group = [(anchor_idx, node_a, custom_a, 0, phys_a)] for probe_idx in range(len(unmerged_infos)): if probe_idx == anchor_idx or probe_idx in coalesced_set: continue - off_p, _, custom_p, node_p = unmerged_infos[probe_idx] - raw_diff = subs_idxc(off_p - off_a) - diff_val = _numeric_eval_constant(raw_diff) - if diff_val is None: + if num_offs[probe_idx] is None: + continue + diff_val = num_offs[probe_idx] - num_offs[anchor_idx] + if not (0 < diff_val < max_load_width): continue - if 0 < diff_val < max_load_width: - _, _, _, phys_p = unmerged_infos[probe_idx] - group.append((probe_idx, node_p, custom_p, diff_val, phys_p)) + # Verify diff is constant across extra probes. + consistent = True + for ep in extra_probes: + try: + va = _eval_expr(resolved_offs[anchor_idx], ep) + vp = _eval_expr(resolved_offs[probe_idx], ep) + if vp - va != diff_val: + consistent = False + break + except (TypeError, ValueError, ZeroDivisionError): + consistent = False + break + if not consistent: + continue + _, custom_p, node_p = ( + unmerged_infos[probe_idx][1], + unmerged_infos[probe_idx][2], + unmerged_infos[probe_idx][3], + ) + _, _, _, phys_p = unmerged_infos[probe_idx] + group.append((probe_idx, node_p, custom_p, diff_val, phys_p)) if len(group) < 2: continue @@ -862,17 +1000,29 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: ) read_infos.append((flat_offset, phys_start, custom, node)) + import time as _time + + print(f"[DEBUG merge] ept={ept} n_reads={len(read_infos)}", flush=True) + _t0 = _time.time() merged, did_merge = _pairwise_merge( read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint ) + print( + f"[DEBUG merge] _pairwise_merge {_time.time()-_t0:.3f}s merged={len(merged)} did_merge={did_merge}", + flush=True, + ) merged_any |= did_merge # Only ept==1 (byte) reads need multi-way coalescing; wider reads # are already handled by the pairwise merge above. if ept == 1 and len(read_infos) >= 2: + _t0 = _time.time() merged_any |= _multiway_coalesce( read_infos, merged, reads, symbolic_dims, symbolic_shape, hw_constraint ) + print( + f"[DEBUG merge] _multiway_coalesce {_time.time()-_t0:.3f}s", flush=True + ) return merged_any From 4c2f7f7d1f8a2abf91207f2720ff19bee863a8aa Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Wed, 4 Mar 2026 23:37:27 +0100 Subject: [PATCH 12/15] Allow re-merging reads with precomputed masks across merge levels Previously, _group_reads_by_memory skipped reads with precomputed_mask_expr, preventing merged ept=2 reads from being further merged to ept=4/8/16. Fix by removing the skip and remapping the sub-read's iota symbol ($IOTA{old_size} -> $IOTA{wide_ept} - offset) when composing masks in _build_wide_mask_expr. Result: dynamic preshuffle MXFP4 b-tensor loads go from 332 vector<2xi8> + 84 vector<16xi8> to 8 vector<2xi8> + 120 vector<16xi8>. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ivan Butygin --- .../analysis/partition_strided_operators.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 75fd99089..5732a68d5 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -513,10 +513,17 @@ def _build_wide_mask_expr(sub_reads, symbolic_shape, wide_ept): """ from ..._support.indexing import IndexingContext - masks = [ - (offset, size, _flatten_bounds_to_mask_expr(custom, symbolic_shape)) - for offset, size, custom in sub_reads - ] + from ..._support.indexing import index_symbol + + masks = [] + for offset, size, custom in sub_reads: + existing = getattr(custom.fx_node, "precomputed_mask_expr", None) + if existing is not None: + masks.append((offset, size, existing)) + else: + masks.append( + (offset, size, _flatten_bounds_to_mask_expr(custom, symbolic_shape)) + ) if not any(m is not None for _, _, m in masks): return None @@ -531,7 +538,14 @@ def _build_wide_mask_expr(sub_reads, symbolic_shape, wide_ept): sympy.GreaterThan(iota, offset) if offset > 0 else sympy.true, sympy.StrictLessThan(iota, upper), ) - bound_cond = mask if mask is not None else sympy.true + if mask is not None: + # Remap any iota from a previous merge level to the new wide iota. + old_iota_sym = index_symbol(f"$IOTA{size}") + if mask.has(old_iota_sym): + mask = mask.subs(old_iota_sym, iota - offset) + bound_cond = mask + else: + bound_cond = sympy.true terms.append(sympy.And(lane_cond, bound_cond)) return functools.reduce(sympy.Or, terms) @@ -598,8 +612,6 @@ def _group_reads_by_memory( continue if custom.mapping_dynamic_vals: continue - if getattr(node, "precomputed_mask_expr", None) is not None: - continue key = (custom.memory, custom.elements_per_thread, region_id) groups[key].append(node) return groups @@ -1002,7 +1014,10 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: import time as _time - print(f"[DEBUG merge] ept={ept} n_reads={len(read_infos)}", flush=True) + print( + f"[DEBUG merge] mem={memory.fx_node.name} ept={ept} region={_region} n_reads={len(read_infos)}", + flush=True, + ) _t0 = _time.time() merged, did_merge = _pairwise_merge( read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint From 9d8997731acf53e1ab95461ac2da5f8e7ddea848 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 6 Mar 2026 01:00:13 +0100 Subject: [PATCH 13/15] Apply divisibility substitutions before numeric probing in read coalescing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The pairwise merge uses numeric probing to verify that adjacent reads have consistent per-dim offset diffs across multiple probe points. With symbolic K, the 2D decomposition (row = offset floordiv K/2, col = offset mod K/2) gives inconsistent diffs when probe values don't respect divisibility constraints — e.g. at K=137, K/2=68, adjacent bytes straddle a row boundary that doesn't exist at K=256. Fix by applying divisibility forward subs (K -> 256*K') before probing, so floordiv/Mod evaluate consistently across all probes. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ivan Butygin --- .../analysis/partition_strided_operators.py | 45 ++++++++++++++++--- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 5732a68d5..fe8ef040e 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -420,7 +420,8 @@ def merge_contiguous_reads( physical flat offset starts differ by exactly ept are merged. """ hw_constraint = get_hardware_constraint(constraints) - while _merge_contiguous_reads_once(trace, hw_constraint): + fwd, _ = get_divisibility_subs(constraints) + while _merge_contiguous_reads_once(trace, hw_constraint, divisibility_fwd=fwd): pass @@ -708,7 +709,9 @@ def _eval_expr(expr, probe_map): return int(expr.xreplace(probe_map)) -def _pairwise_merge(read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint): +def _pairwise_merge( + read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint, divisibility_fwd=None +): """Merge pairs of reads whose flat offsets differ by exactly ``ept``. Returns ``(merged_indices, did_merge)`` where *merged_indices* is @@ -731,6 +734,15 @@ def _pairwise_merge(read_infos, ept, symbolic_dims, symbolic_shape, hw_constrain {dim: subs_idxc(info[1][dim]) for dim in symbolic_dims} for info in read_infos ] + # Apply divisibility forward subs so that floordiv/Mod with symbolic + # divisors evaluate consistently across probe points. + if divisibility_fwd: + resolved_flat = [safe_subs(e, divisibility_fwd) for e in resolved_flat] + resolved_phys = [ + {dim: safe_subs(e, divisibility_fwd) for dim, e in phys.items()} + for phys in resolved_phys + ] + # Collect free symbols across all expressions. all_free = set() for expr in resolved_flat: @@ -836,7 +848,13 @@ def _verify_with_extra_probes(lo_i, hi_i, expected_flat_diff, expected_dim_diffs def _multiway_coalesce( - read_infos, merged, reads, symbolic_dims, symbolic_shape, hw_constraint + read_infos, + merged, + reads, + symbolic_dims, + symbolic_shape, + hw_constraint, + divisibility_fwd=None, ): """Coalesce unmerged ept==1 reads whose flat offsets fall in an aligned window. @@ -853,6 +871,8 @@ def _multiway_coalesce( # Pre-evaluate flat offsets with probe values to avoid symbolic diffs. resolved_offs = [subs_idxc(info[0]) for info in unmerged_infos] + if divisibility_fwd: + resolved_offs = [safe_subs(e, divisibility_fwd) for e in resolved_offs] all_free = set() for expr in resolved_offs: if hasattr(expr, "free_symbols"): @@ -967,7 +987,9 @@ def _multiway_coalesce( return coalesced_any -def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: +def _merge_contiguous_reads_once( + trace: CapturedTrace, hw_constraint, divisibility_fwd=None +) -> bool: """Single merge pass: merge reads that access nearby physical memory. Two strategies are applied per (memory, ept) group: @@ -1020,7 +1042,12 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: ) _t0 = _time.time() merged, did_merge = _pairwise_merge( - read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint + read_infos, + ept, + symbolic_dims, + symbolic_shape, + hw_constraint, + divisibility_fwd=divisibility_fwd, ) print( f"[DEBUG merge] _pairwise_merge {_time.time()-_t0:.3f}s merged={len(merged)} did_merge={did_merge}", @@ -1033,7 +1060,13 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: if ept == 1 and len(read_infos) >= 2: _t0 = _time.time() merged_any |= _multiway_coalesce( - read_infos, merged, reads, symbolic_dims, symbolic_shape, hw_constraint + read_infos, + merged, + reads, + symbolic_dims, + symbolic_shape, + hw_constraint, + divisibility_fwd=divisibility_fwd, ) print( f"[DEBUG merge] _multiway_coalesce {_time.time()-_t0:.3f}s", flush=True From abf6de55d99b5af322b6ece41044474ac3025079 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 6 Mar 2026 01:16:42 +0100 Subject: [PATCH 14/15] Add unit tests for numeric probing with MXFP4 preshuffle expressions Tests use actual B-scale preshuffle index expressions (row = floor(offset / (K/2)), col = offset mod (K/2)) to verify that: - Flat offset diffs are always correct regardless of probe values. - Per-dim diffs are inconsistent without divisibility subs (the bug). - Per-dim diffs become consistent after K -> 256*K' substitution. - _find_merge_dim_from_diffs correctly identifies the merge dimension. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ivan Butygin --- tests/unittests/test_numeric_probing.py | 287 ++++++++++++++++++++++++ 1 file changed, 287 insertions(+) create mode 100644 tests/unittests/test_numeric_probing.py diff --git a/tests/unittests/test_numeric_probing.py b/tests/unittests/test_numeric_probing.py new file mode 100644 index 000000000..d30d5c5ee --- /dev/null +++ b/tests/unittests/test_numeric_probing.py @@ -0,0 +1,287 @@ +# Copyright 2025 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Tests for numeric probing in read coalescing. + +The pairwise merge uses numeric probing to verify that adjacent reads +have consistent per-dim offset diffs. With symbolic K, the 2D +decomposition (row = offset floordiv K/2, col = offset mod K/2) +can give inconsistent diffs when probe values don't respect divisibility +constraints. Applying divisibility forward subs (K -> 256*K') fixes +this. + +The expressions here are taken from the MXFP4 preshuffle B-scale +codegen (test_dbuf_4wave_mxfp_dynamic_preshuffle_b_gemm on gfx950). +""" + +import pytest +import sympy + +from wave_lang.kernel.wave.analysis.partition_strided_operators import ( + _MERGE_PROBES, + _eval_expr, + _find_merge_dim_from_diffs, +) +from wave_lang.kernel.wave.utils.symbol_utils import ( + _numeric_eval_constant, + safe_subs, +) + + +# --- Symbols matching the MXFP4 B-scale path --- +# t = thread_id_x (0..63), K = GEMM K dimension (multiple of 256), +# wg_m/wg_n = workgroup tile offsets. +t = sympy.Symbol("t", integer=True, nonneg=True) +K = sympy.Symbol("K", integer=True, positive=True) +wg_m = sympy.Symbol("wg_m", integer=True, nonneg=True) +wg_n = sympy.Symbol("wg_n", integer=True, nonneg=True) +half_K = K / 2 + +# Preshuffle base offset for thread t (from MLIR #map27/#map28 pattern): +# base(t) = t*16 + ((t%64)//16)*256 - (t//16)*256 +# For t in 0..63 this simplifies to t*16, but sympy doesn't know that. +_base = t * 16 + ((t % 64) // 16) * 256 - (t // 16) * 256 + + +def _make_row_col(offset): + """Build (row, col) from a linearized preshuffle offset.""" + row = wg_m * 64 + wg_n * 256 + sympy.floor(offset / half_K) + col = offset % half_K + return row, col + + +# The MXFP4 preshuffle MMA layout scatters 16 scale bytes across +# offsets with large gaps. From the MLIR maps (#map27-#map36), the +# actual offsets within a single thread's 16-byte read are: +# base, base+1024, base+2, base+3, base+11, base+15, +# base+16*N, base+16*N+1024, ... +# Here we use the first 4 groups from the MLIR maps: offsets +# base+0, base+1024, base+2, base+3 — which is enough to +# demonstrate the inconsistency. +_PRESHUFFLE_DELTAS = [ + 0, + 1, + 1024, + 1025, + 2, + 3, + 11, + 15, + 16, + 17, + 1040, + 1041, + 18, + 19, + 27, + 31, +] + +BYTE_READS = [] +for delta in _PRESHUFFLE_DELTAS: + row_i, col_i = _make_row_col(_base + delta) + flat_i = row_i * half_K + col_i + BYTE_READS.append({"row": row_i, "col": col_i, "flat": flat_i, "delta": delta}) + +# Divisibility forward sub: K -> 256 * K'. +K_prime = sympy.Symbol("_K_div_256", integer=True, positive=True) +DIV_FWD = [(K, 256 * K_prime)] + + +def _build_merge_probe_maps(expressions): + """Build _MERGE_PROBES probe maps over free symbols in expressions.""" + all_free = set() + for expr in expressions: + if hasattr(expr, "free_symbols"): + all_free |= expr.free_symbols + free_list = sorted(all_free, key=str) + return [{s: gen(i) for i, s in enumerate(free_list)} for gen in _MERGE_PROBES] + + +def _find_reads_by_delta(d0, d1): + """Return the two BYTE_READS entries with the given deltas.""" + r0 = next(br for br in BYTE_READS if br["delta"] == d0) + r1 = next(br for br in BYTE_READS if br["delta"] == d1) + return r0, r1 + + +class TestMergeProbeConsistency: + """Verify that numeric probing gives consistent per-dim diffs + for adjacent B-scale reads, with and without divisibility subs.""" + + def test_flat_offset_diff_always_correct(self): + """Flat offset diff between bytes equals their delta difference.""" + all_flats = [br["flat"] for br in BYTE_READS] + probes = _build_merge_probe_maps(all_flats) + for probe in probes: + for i in range(len(BYTE_READS) - 1): + flat_i = _eval_expr(BYTE_READS[i]["flat"], probe) + flat_j = _eval_expr(BYTE_READS[i + 1]["flat"], probe) + expected = BYTE_READS[i + 1]["delta"] - BYTE_READS[i]["delta"] + assert flat_j - flat_i == expected, ( + f"flat diff mismatch for deltas " + f"{BYTE_READS[i]['delta']},{BYTE_READS[i+1]['delta']}" + ) + + def test_per_dim_diffs_inconsistent_without_div_subs(self): + """Without divisibility subs, per-dim diffs are inconsistent + across probe sets for reads with large preshuffle gaps. + + Reads at offsets base+0 and base+1024 (from MXFP4 maps #map27 + and #map29) have row_diff and col_diff that depend on K/2. + At K/2=68: row_diff=15, col_diff=-3. + At K/2=125: row_diff=8, col_diff=20. Etc. + """ + r0, r1 = _find_reads_by_delta(0, 1024) + all_exprs = list(r0.values()) + list(r1.values()) + probes = _build_merge_probe_maps(all_exprs) + + diffs_per_probe = [] + for probe in probes: + row_diff = _eval_expr(r1["row"], probe) - _eval_expr(r0["row"], probe) + col_diff = _eval_expr(r1["col"], probe) - _eval_expr(r0["col"], probe) + diffs_per_probe.append((row_diff, col_diff)) + + all_agree = all(d == diffs_per_probe[0] for d in diffs_per_probe[1:]) + assert not all_agree, ( + "Expected inconsistent per-dim diffs across probes, " + f"but all agreed: {diffs_per_probe[0]}" + ) + + def test_adjacent_diffs_consistent_with_div_subs(self): + """With divisibility subs (K -> 256*K'), adjacent byte pairs + have consistent diffs: row_diff=0, col_diff=delta_diff.""" + div_reads = [ + {k: safe_subs(v, DIV_FWD) for k, v in br.items()} for br in BYTE_READS + ] + all_exprs = [] + for br in div_reads: + all_exprs.extend(v for k, v in br.items() if k != "delta") + probes = _build_merge_probe_maps(all_exprs) + + # Test all consecutive pairs sorted by delta. + sorted_reads = sorted(zip(BYTE_READS, div_reads), key=lambda x: x[0]["delta"]) + for probe in probes: + for idx in range(len(sorted_reads) - 1): + orig_a, div_a = sorted_reads[idx] + orig_b, div_b = sorted_reads[idx + 1] + row_diff = _eval_expr(div_b["row"], probe) - _eval_expr( + div_a["row"], probe + ) + col_diff = _eval_expr(div_b["col"], probe) - _eval_expr( + div_a["col"], probe + ) + expected_flat = orig_b["delta"] - orig_a["delta"] + assert row_diff == 0, ( + f"row diff != 0 for deltas " f"{orig_a['delta']},{orig_b['delta']}" + ) + assert col_diff == expected_flat, ( + f"col diff {col_diff} != {expected_flat} for deltas " + f"{orig_a['delta']},{orig_b['delta']}" + ) + + def test_large_gap_consistent_with_div_subs(self): + """Reads 1024 apart have consistent diffs after div subs.""" + r0, r1 = _find_reads_by_delta(0, 1024) + div_r0 = {k: safe_subs(v, DIV_FWD) for k, v in r0.items()} + div_r1 = {k: safe_subs(v, DIV_FWD) for k, v in r1.items()} + + all_exprs = [v for d in [div_r0, div_r1] for k, v in d.items() if k != "delta"] + probes = _build_merge_probe_maps(all_exprs) + + for probe in probes: + row_diff = _eval_expr(div_r1["row"], probe) - _eval_expr( + div_r0["row"], probe + ) + col_diff = _eval_expr(div_r1["col"], probe) - _eval_expr( + div_r0["col"], probe + ) + assert row_diff == 0, f"row diff {row_diff} != 0" + assert col_diff == 1024, f"col diff {col_diff} != 1024" + + def test_merge_dim_found_with_div_subs(self): + """After div subs, _find_merge_dim_from_diffs finds the col dim.""" + r0, r1 = _find_reads_by_delta(0, 1) + div_r0 = {k: safe_subs(v, DIV_FWD) for k, v in r0.items()} + div_r1 = {k: safe_subs(v, DIV_FWD) for k, v in r1.items()} + all_exprs = [v for d in [div_r0, div_r1] for k, v in d.items() if k != "delta"] + probes = _build_merge_probe_maps(all_exprs) + + dim_row = sympy.Symbol("dim_row") + dim_col = sympy.Symbol("dim_col") + dims = [dim_row, dim_col] + for probe in probes: + row_diff = _eval_expr(div_r1["row"], probe) - _eval_expr( + div_r0["row"], probe + ) + col_diff = _eval_expr(div_r1["col"], probe) - _eval_expr( + div_r0["col"], probe + ) + result = _find_merge_dim_from_diffs( + {dim_row: row_diff, dim_col: col_diff}, 1, dims + ) + assert result == dim_col + + +class TestNumericEvalConstant: + """Test _numeric_eval_constant with preshuffle-like expressions. + + Note: _numeric_eval_constant uses _PROBE_POOL which includes 0, + causing ZeroDivisionError on floor/Mod expressions with symbolic + divisors. It returns None conservatively on any error, so these + complex expressions are beyond its reach. This is documented + behavior, not a bug — the merge probing (_MERGE_PROBES) handles + these cases instead. + """ + + def test_simple_constant(self): + """Sanity: detects that a simple constant expression is constant.""" + x = sympy.Symbol("x", integer=True, nonneg=True) + assert _numeric_eval_constant(3 * x - 3 * x) == 0 + assert _numeric_eval_constant(x - x + 7) == 7 + + def test_floor_mod_identity_returns_none(self): + """floor(x/d)*d + x%d - x is always 0, but _numeric_eval_constant + can't prove it because _PROBE_POOL includes 0 → ZeroDivisionError.""" + x = sympy.Symbol("x", integer=True, nonneg=True) + d = sympy.Symbol("d", integer=True, positive=True) + expr = sympy.floor(x / d) * d + sympy.Mod(x, d) - x + # Returns None because probing hits d=0. + assert _numeric_eval_constant(expr) is None + + def test_row_diff_not_constant_without_div_subs(self): + """Row diff for adjacent bytes is not constant when K is symbolic.""" + row_diff = BYTE_READS[1]["row"] - BYTE_READS[0]["row"] + assert _numeric_eval_constant(row_diff) is None + + +class TestFindMergeDim: + """Test _find_merge_dim_from_diffs helper.""" + + def test_single_dim_matches(self): + dim_a = sympy.Symbol("a") + dim_b = sympy.Symbol("b") + dims = [dim_a, dim_b] + assert _find_merge_dim_from_diffs({dim_a: 0, dim_b: 1}, 1, dims) == dim_b + + def test_both_change_returns_none(self): + dim_a = sympy.Symbol("a") + dim_b = sympy.Symbol("b") + dims = [dim_a, dim_b] + assert _find_merge_dim_from_diffs({dim_a: 1, dim_b: 1}, 1, dims) is None + + def test_wrong_diff_returns_none(self): + dim_a = sympy.Symbol("a") + dim_b = sympy.Symbol("b") + dims = [dim_a, dim_b] + assert _find_merge_dim_from_diffs({dim_a: 0, dim_b: 2}, 1, dims) is None + + def test_nonzero_non_ept_returns_none(self): + """If a dim changes by something other than 0 or ept, reject.""" + dim_a = sympy.Symbol("a") + dim_b = sympy.Symbol("b") + dims = [dim_a, dim_b] + assert _find_merge_dim_from_diffs({dim_a: 3, dim_b: 0}, 1, dims) is None From 943570d87526c965598ca561c96b8fbc052ddd49 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Fri, 6 Mar 2026 01:55:46 +0100 Subject: [PATCH 15/15] Add lit test for preshuffle B-scale coalescing with dynamic dims Verifies that divisibility substitutions (K % 256) enable the read coalescer to produce clean vector<16xi8> B-scale and vector<4xi8> A-scale loads from fat_raw_buffer, with no vector.from_elements fragmentation, when M, N, K are dynamic. Co-Authored-By: Claude Opus 4.6 Signed-off-by: Ivan Butygin --- lit_tests/kernel/wave/scaled_gemm.py | 55 ++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/lit_tests/kernel/wave/scaled_gemm.py b/lit_tests/kernel/wave/scaled_gemm.py index 96dbdddfa..7fb39936c 100644 --- a/lit_tests/kernel/wave/scaled_gemm.py +++ b/lit_tests/kernel/wave/scaled_gemm.py @@ -10,6 +10,7 @@ ScaledMMAType, ) from wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType +from wave_lang.kernel.wave.templates import get_tagged_mxfp4_gemm_preshuffle_b from wave_lang.kernel.wave.templates.test_kernels import ( get_broadcasted_scale_gemm_mxfp4, ) @@ -997,3 +998,57 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # Unmasked vector stores for output. # CHECK: vector.store # CHECK: return + + +@run_test +def test_dynamic_preshuffle_b_scale_coalescing(): + """Verify B-scale reads coalesce into clean vector<16xi8> with dynamic dims. + + Uses the preshuffle-B MXFP4 template with dynamic M, N, K and small + block sizes. The K % 256 divisibility assumption lets the coalescer + apply divisibility substitutions during numeric probing, so the 2D + decomposition (row = offset floordiv K/2, col = offset mod K/2) gives + consistent per-dim diffs across probe sets. Without this fix, probes + like K=137 make K/2=68, causing inconsistent row/col diffs and + fragmenting 16-byte scale reads into {2, 16, 8, 4} loads glued by + vector.from_elements. + """ + shape = (256, 256, 256) + block = (128, 128, 256) + kernel, options = get_tagged_mxfp4_gemm_preshuffle_b( + shape, + block, + wave_shape=(2, 2), + mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4, + reorder_workgroups=False, + ) + dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K] + for sym in dynamic_symbols: + del options.subs[sym] + options.dynamic_symbols = dynamic_symbols + options.schedule = SchedulingType.NONE + options.use_buffer_ops = True + options.compile_to_mlir = True + options.device = "hip" + options.target = "gfx950" + result = wave_compile(options, kernel) + print(result.asm) + + # CHECK-LABEL: test_dynamic_preshuffle_b_scale_coalescing + + # Dynamic index arguments for M, N, K. + # CHECK: func.func @gemm(%arg0: {{.*}}, %arg1: {{.*}}, %arg2: {{.*}}, %arg3: {{.*}}, %arg4: {{.*}}, %arg5: index, %arg6: index, %arg7: index) + + # Buffer ops: fat_raw_buffer_cast for global buffers. + # CHECK: amdgpu.fat_raw_buffer_cast + + # B-scale reads are clean vector<16xi8> from fat_raw_buffer — no + # fragmentation into mixed-width loads glued by from_elements. + # A-scale reads are vector<4xi8>. + # CHECK: scf.for + # CHECK-COUNT-8: vector.load %{{.*}} : memref>, vector<16xi8> + # CHECK-COUNT-2: vector.load %{{.*}} : memref>, vector<4xi8> + # CHECK: amdgpu.scaled_mfma + + # No byte-level reassembly — coalescing succeeded. + # CHECK-NOT: vector.from_elements