From 3343cbbe3f5a94e736865d7ce0a3e61472df0df5 Mon Sep 17 00:00:00 2001 From: xintin Date: Mon, 2 Mar 2026 20:57:46 +0000 Subject: [PATCH 01/10] initial commit Signed-off-by: xintin --- .../analysis/partition_strided_operators.py | 187 ++++++++++++++++++ 1 file changed, 187 insertions(+) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index bda0a39083..a44023b796 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -400,6 +400,192 @@ def has_gpr_offsets(node: fx.Node) -> bool: custom.graph.erase_node(custom.fx_node) +# TODO: I will add a field in IndexMapping to check if it is a preshuffle mapping. +# `kind`: `preshuffle` or `shuffleb` or `linear`. +def _is_preshuffle_mapping(mapping) -> bool: + """Check if a mapping uses the preshuffle e8m0_shuffle pattern.""" + if mapping is None: + return False + # Scale buffers are always 2D: (N or M, K/32). + if len(mapping.output_mapping) != 2 or len(mapping.input_mapping) != 2: + return False + # Output must be identity (logical coords), input must be shuffled. + if not (mapping.is_output_identity() and not mapping.is_input_identity()): + return False + # The preshuffle formula uses floor (m//32, k//4) and Mod (k%4, (m//16)%2) + # to interleave scale bytes within each 256-byte region. + input_atoms = set() + for expr in mapping.input_mapping.values(): + input_atoms.update(type(a) for a in sympy.preorder_traversal(expr)) + return sympy.floor in input_atoms and sympy.Mod in input_atoms + + +def _coalesce_preshuffle_global_reads( + trace: CapturedTrace, constraints: list[Constraint] +): + """Coalesce scattered preshuffle-mapped global byte reads into dword loads. + + After merge_contiguous_reads, preshuffle-mapped global reads may remain as + individual byte loads when the per-wave tile doesn't divide evenly into + groups of 4 contiguous bytes (e.g. BLOCK_N=192 with 4 waves gives 48 + N-rows/wave, whose second 32-row e8m0 block only yields 2 reads instead + of 4). + + Uses pairwise symbolic differences to find reads within the same 4-byte + aligned region, then replaces them with a single 4-byte read plus + ExtractSlice ops for the needed bytes. + """ + from collections import defaultdict + from ...compiler.utils import strides_from_symbolic_shape + from ..._support.indexing import IndexingContext + from ..utils.mapping_utils import transform_index_on_mapping + + idxc = IndexingContext.current() + + # Collect single-byte preshuffle scale reads from global memory, grouped + # by their source buffer. These are the reads that merge_contiguous_reads + # couldn't handle (e.g. because they have bounds, or their preshuffle + # offsets aren't strictly adjacent). + mem_groups: dict[int, list] = defaultdict(list) + for node in trace.walk(lambda n: isinstance(get_custom(n), Read)): + custom = get_custom(node) + # Already widened by merge_contiguous_reads + if custom.elements_per_thread != 1: + continue + # Only target reads with the e8m0 preshuffle mapping. + if not _is_preshuffle_mapping(custom.mapping): + continue + # Shared-memory reads are handled by preshuffle_scale_to_shared. + if subs_idxc(custom.memory_type.address_space) != GLOBAL_ADDRESS_SPACE: + continue + mem_groups[id(custom.memory)].append(node) + + coalesced = 0 + for _, nodes in mem_groups.items(): + # Need at least 2 reads to coalesce. + if len(nodes) < 2: + continue + + sample_custom = get_custom(nodes[0]) + memory_node = sample_custom.memory + memory = get_custom(memory_node) + symbolic_shape = memory.type.symbolic_shape + symbolic_dims = [infer_dim(d) for d in symbolic_shape] + strides = strides_from_symbolic_shape( + idxc, symbolic_shape, allow_mixed_shapes=True + ) + if strides is None: + continue + + # For each read, compute its physical flat offset (in bytes) and store + # it along with the read node, custom op, and physical coordinates. + read_infos = [] + for node in nodes: + custom = get_custom(node) + # Transform the logical index to physical coordinates using the mapping. + physical = transform_index_on_mapping( + custom.mapping, symbolic_shape, custom.index, is_read=True + ) + if not all(dim in physical for dim in symbolic_dims): + continue + # Compute the physical flat offset (in bytes) by multiplying the + # physical coordinates by the corresponding stride and summing the results. + flat = sum( + physical[dim] * stride for dim, stride in zip(symbolic_dims, strides) + ) + read_infos.append((node, custom, subs_idxc(flat), physical)) + + if len(read_infos) < 2: + continue + + # Try all pairs to find contiguous ones (diff == ept). + merged = set() + for i in range(len(read_infos)): + if i in merged: + continue + node_i, custom_i, flat_i, phys_i = read_infos[i] + + group = [(i, node_i, custom_i, 0, phys_i)] + for j in range(len(read_infos)): + if j == i or j in merged: + continue + node_j, custom_j, flat_j, phys_j = read_infos[j] + raw_diff = subs_idxc(flat_j - flat_i) + diff_val = _numeric_eval_constant(raw_diff) + if diff_val is None: + continue + if 0 < diff_val < 4: + group.append((j, node_j, custom_j, diff_val, phys_j)) + + if len(group) < 2: + continue + + group.sort(key=lambda x: x[3]) + max_off = group[-1][3] + if max_off >= 4: + continue + + base_phys = group[0][4] + + earliest_node = group[0][1] + for g in group[1:]: + candidate = g[1] + for n in custom_i.graph.nodes: + if n is candidate: + earliest_node = candidate + break + if n is earliest_node: + break + + with get_custom(earliest_node).graph.inserting_before(earliest_node): + # Create a new wide read with 4-byte elements, using the earliest + # read's physical coordinates as the base. + wide_index = {} + for dim_idx, dim in enumerate(symbolic_dims): + if dim_idx == len(symbolic_dims) - 1: + wide_index[dim] = IndexSequence(base_phys[dim], 4, 1) + else: + wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) + + wide_read = Read( + memory_node, + elements_per_thread=4, + mapping=None, + _write_dependency=custom_i._write_dependency, + flags=custom_i.flags, + ).add_to_graph(custom_i.graph, loc=custom_i.location) + wide_custom = get_custom(wide_read) + wide_custom.index = wide_index + if hasattr(earliest_node, "vector_shapes"): + wide_read.vector_shapes = deepcopy(earliest_node.vector_shapes) + propagate_tag(group[0][1], wide_read) + + # Create ExtractSlice ops for each byte in the contiguous region. + # because the wide read is 4-byte aligned, we need to extract the + # correct byte from the wide read. + for g_idx, g_node, g_custom, byte_pos, _ in group: + with g_custom.graph.inserting_before(g_node): + extract = ExtractSlice( + wide_read, [byte_pos], [1], [1] + ).add_to_graph(g_custom.graph, loc=g_custom.location) + if hasattr(g_node, "vector_shapes"): + get_custom(extract).vector_shapes = deepcopy( + g_node.vector_shapes + ) + propagate_tag(g_node, extract) + + g_custom.replace_all_uses_with(extract) + g_custom.graph.erase_node(g_node) + + merged.update(g[0] for g in group) + coalesced += 1 + + if coalesced > 0: + logger.info( + f"Coalesced {coalesced} preshuffle global read group(s) into dword loads" + ) + + def merge_contiguous_reads( trace: CapturedTrace, constraints: list[Constraint], target: str ): @@ -416,6 +602,7 @@ def merge_contiguous_reads( hw_constraint = get_hardware_constraint(constraints) while _merge_contiguous_reads_once(trace, hw_constraint): pass + _coalesce_preshuffle_global_reads(trace, constraints) def _get_physical_start( From 1ccfd54afcb813e1a866e30572571c931bc93ff7 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 01:16:17 +0000 Subject: [PATCH 02/10] remove bound check Signed-off-by: xintin --- .../wave/analysis/partition_strided_operators.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index a44023b796..99608094db 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -602,7 +602,7 @@ def merge_contiguous_reads( hw_constraint = get_hardware_constraint(constraints) while _merge_contiguous_reads_once(trace, hw_constraint): pass - _coalesce_preshuffle_global_reads(trace, constraints) + # _coalesce_preshuffle_global_reads(trace, constraints) def _get_physical_start( @@ -660,10 +660,8 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: continue if custom.mapping_dynamic_vals: continue - # Skip reads that have bounds: the merged read would lose the - # mapping and source→target index, making mask generation incorrect. - if custom.bounds is not None: - continue + # Bounded reads are allowed into the group; the merge loop + # below only merges pairs with identical bounds. key = (custom.memory, custom.elements_per_thread, region_id) groups[key].append(node) @@ -737,6 +735,14 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: else: continue + # Only merge bounded reads when both have the exact same + # bounds (same OOB mask). This lets preshuffle reads + # (which all share identical bounds) pass through even if + # they can't merge here — _coalesce_preshuffle_global_reads + # handles them downstream. + if lo_custom.bounds != hi_custom.bounds: + continue + # Find dimension that advances by ept. merge_dim = None for dim in symbolic_dims: From a14e4ea02dd8b796c8ba881b3a8c8d2c3520daa0 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 05:22:52 +0000 Subject: [PATCH 03/10] updated Signed-off-by: xintin --- .../analysis/partition_strided_operators.py | 325 +++++++----------- wave_lang/kernel/wave/opsel_scaled_mfma.py | 44 +-- 2 files changed, 150 insertions(+), 219 deletions(-) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 99608094db..f02901abe8 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -400,192 +400,6 @@ def has_gpr_offsets(node: fx.Node) -> bool: custom.graph.erase_node(custom.fx_node) -# TODO: I will add a field in IndexMapping to check if it is a preshuffle mapping. -# `kind`: `preshuffle` or `shuffleb` or `linear`. -def _is_preshuffle_mapping(mapping) -> bool: - """Check if a mapping uses the preshuffle e8m0_shuffle pattern.""" - if mapping is None: - return False - # Scale buffers are always 2D: (N or M, K/32). - if len(mapping.output_mapping) != 2 or len(mapping.input_mapping) != 2: - return False - # Output must be identity (logical coords), input must be shuffled. - if not (mapping.is_output_identity() and not mapping.is_input_identity()): - return False - # The preshuffle formula uses floor (m//32, k//4) and Mod (k%4, (m//16)%2) - # to interleave scale bytes within each 256-byte region. - input_atoms = set() - for expr in mapping.input_mapping.values(): - input_atoms.update(type(a) for a in sympy.preorder_traversal(expr)) - return sympy.floor in input_atoms and sympy.Mod in input_atoms - - -def _coalesce_preshuffle_global_reads( - trace: CapturedTrace, constraints: list[Constraint] -): - """Coalesce scattered preshuffle-mapped global byte reads into dword loads. - - After merge_contiguous_reads, preshuffle-mapped global reads may remain as - individual byte loads when the per-wave tile doesn't divide evenly into - groups of 4 contiguous bytes (e.g. BLOCK_N=192 with 4 waves gives 48 - N-rows/wave, whose second 32-row e8m0 block only yields 2 reads instead - of 4). - - Uses pairwise symbolic differences to find reads within the same 4-byte - aligned region, then replaces them with a single 4-byte read plus - ExtractSlice ops for the needed bytes. - """ - from collections import defaultdict - from ...compiler.utils import strides_from_symbolic_shape - from ..._support.indexing import IndexingContext - from ..utils.mapping_utils import transform_index_on_mapping - - idxc = IndexingContext.current() - - # Collect single-byte preshuffle scale reads from global memory, grouped - # by their source buffer. These are the reads that merge_contiguous_reads - # couldn't handle (e.g. because they have bounds, or their preshuffle - # offsets aren't strictly adjacent). - mem_groups: dict[int, list] = defaultdict(list) - for node in trace.walk(lambda n: isinstance(get_custom(n), Read)): - custom = get_custom(node) - # Already widened by merge_contiguous_reads - if custom.elements_per_thread != 1: - continue - # Only target reads with the e8m0 preshuffle mapping. - if not _is_preshuffle_mapping(custom.mapping): - continue - # Shared-memory reads are handled by preshuffle_scale_to_shared. - if subs_idxc(custom.memory_type.address_space) != GLOBAL_ADDRESS_SPACE: - continue - mem_groups[id(custom.memory)].append(node) - - coalesced = 0 - for _, nodes in mem_groups.items(): - # Need at least 2 reads to coalesce. - if len(nodes) < 2: - continue - - sample_custom = get_custom(nodes[0]) - memory_node = sample_custom.memory - memory = get_custom(memory_node) - symbolic_shape = memory.type.symbolic_shape - symbolic_dims = [infer_dim(d) for d in symbolic_shape] - strides = strides_from_symbolic_shape( - idxc, symbolic_shape, allow_mixed_shapes=True - ) - if strides is None: - continue - - # For each read, compute its physical flat offset (in bytes) and store - # it along with the read node, custom op, and physical coordinates. - read_infos = [] - for node in nodes: - custom = get_custom(node) - # Transform the logical index to physical coordinates using the mapping. - physical = transform_index_on_mapping( - custom.mapping, symbolic_shape, custom.index, is_read=True - ) - if not all(dim in physical for dim in symbolic_dims): - continue - # Compute the physical flat offset (in bytes) by multiplying the - # physical coordinates by the corresponding stride and summing the results. - flat = sum( - physical[dim] * stride for dim, stride in zip(symbolic_dims, strides) - ) - read_infos.append((node, custom, subs_idxc(flat), physical)) - - if len(read_infos) < 2: - continue - - # Try all pairs to find contiguous ones (diff == ept). - merged = set() - for i in range(len(read_infos)): - if i in merged: - continue - node_i, custom_i, flat_i, phys_i = read_infos[i] - - group = [(i, node_i, custom_i, 0, phys_i)] - for j in range(len(read_infos)): - if j == i or j in merged: - continue - node_j, custom_j, flat_j, phys_j = read_infos[j] - raw_diff = subs_idxc(flat_j - flat_i) - diff_val = _numeric_eval_constant(raw_diff) - if diff_val is None: - continue - if 0 < diff_val < 4: - group.append((j, node_j, custom_j, diff_val, phys_j)) - - if len(group) < 2: - continue - - group.sort(key=lambda x: x[3]) - max_off = group[-1][3] - if max_off >= 4: - continue - - base_phys = group[0][4] - - earliest_node = group[0][1] - for g in group[1:]: - candidate = g[1] - for n in custom_i.graph.nodes: - if n is candidate: - earliest_node = candidate - break - if n is earliest_node: - break - - with get_custom(earliest_node).graph.inserting_before(earliest_node): - # Create a new wide read with 4-byte elements, using the earliest - # read's physical coordinates as the base. - wide_index = {} - for dim_idx, dim in enumerate(symbolic_dims): - if dim_idx == len(symbolic_dims) - 1: - wide_index[dim] = IndexSequence(base_phys[dim], 4, 1) - else: - wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) - - wide_read = Read( - memory_node, - elements_per_thread=4, - mapping=None, - _write_dependency=custom_i._write_dependency, - flags=custom_i.flags, - ).add_to_graph(custom_i.graph, loc=custom_i.location) - wide_custom = get_custom(wide_read) - wide_custom.index = wide_index - if hasattr(earliest_node, "vector_shapes"): - wide_read.vector_shapes = deepcopy(earliest_node.vector_shapes) - propagate_tag(group[0][1], wide_read) - - # Create ExtractSlice ops for each byte in the contiguous region. - # because the wide read is 4-byte aligned, we need to extract the - # correct byte from the wide read. - for g_idx, g_node, g_custom, byte_pos, _ in group: - with g_custom.graph.inserting_before(g_node): - extract = ExtractSlice( - wide_read, [byte_pos], [1], [1] - ).add_to_graph(g_custom.graph, loc=g_custom.location) - if hasattr(g_node, "vector_shapes"): - get_custom(extract).vector_shapes = deepcopy( - g_node.vector_shapes - ) - propagate_tag(g_node, extract) - - g_custom.replace_all_uses_with(extract) - g_custom.graph.erase_node(g_node) - - merged.update(g[0] for g in group) - coalesced += 1 - - if coalesced > 0: - logger.info( - f"Coalesced {coalesced} preshuffle global read group(s) into dword loads" - ) - - def merge_contiguous_reads( trace: CapturedTrace, constraints: list[Constraint], target: str ): @@ -602,7 +416,6 @@ def merge_contiguous_reads( hw_constraint = get_hardware_constraint(constraints) while _merge_contiguous_reads_once(trace, hw_constraint): pass - # _coalesce_preshuffle_global_reads(trace, constraints) def _get_physical_start( @@ -631,11 +444,22 @@ def _get_physical_start( def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: - """Single merge pass: merge adjacent pairs of same-ept reads. + """Single merge pass: merge reads that access nearby physical memory. + + Two strategies are applied per (memory, ept) group: - Groups reads by (memory operand, ept) and merges pairs whose physical - flat offset starts differ by exactly ept. Returns True if any merges - happened. + 1. **Pairwise contiguous merge** — pairs whose physical flat offset + starts differ by exactly ``ept`` are merged into a ``2*ept`` read + with two ExtractSlice outputs. + + 2. **Multi-way coalescing** (``ept==1`` only) — unmerged byte reads + whose flat offsets fall within a power-of-2 aligned window (up to + ``max_elems_per_load``) are replaced by a single wide read with + per-byte ExtractSlice outputs. Diffs are evaluated via numeric + probing (``_numeric_eval_constant``), so this handles non-constant + symbolic offsets such as preshuffle mappings. + + Returns True if any merges or coalescing happened. """ from collections import defaultdict from ...compiler.utils import strides_from_symbolic_shape @@ -700,6 +524,16 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: off1, phys1, custom1, node1 = read_infos[i] off2, phys2, custom2, node2 = read_infos[j] + # The pairwise merge drops the mapping and produces a + # physical-space read with no mask. That is unsafe for + # bounded reads whose OOB behaviour relies on the mask + # generated from bounds + mapping. Skip them here; + # bounded ept==1 reads (e.g. preshuffle scales backed by + # fat_raw_buffer) are handled by the multi-way coalescing + # pass below, where OOB loads safely return zero. + if custom1.bounds is not None or custom2.bounds is not None: + continue + raw_diff = subs_idxc(off2 - off1) # For reads with non-identity mappings (e.g. preshuffle @@ -735,14 +569,6 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: else: continue - # Only merge bounded reads when both have the exact same - # bounds (same OOB mask). This lets preshuffle reads - # (which all share identical bounds) pass through even if - # they can't merge here — _coalesce_preshuffle_global_reads - # handles them downstream. - if lo_custom.bounds != hi_custom.bounds: - continue - # Find dimension that advances by ept. merge_dim = None for dim in symbolic_dims: @@ -822,6 +648,107 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: merged_any = True break + # Multi-way coalescing for ept==1 reads that survived the pairwise + # pass. Groups reads whose numerically-probed flat offsets fall + # within an aligned window of max_load_width bytes, then emits a + # single wide read + per-byte ExtractSlice ops. This generalises + # the former _coalesce_preshuffle_global_reads to all mapping types. + if ept == 1 and len(read_infos) >= 2: + element_type = get_custom(reads[0]).type.dtype + max_load_width = hw_constraint.max_elems_per_load(element_type) + + unmerged_infos = [ + read_infos[k] for k in range(len(read_infos)) if k not in merged + ] + if len(unmerged_infos) >= 2: + coalesced_set: set[int] = set() + for anchor_idx in range(len(unmerged_infos)): + if anchor_idx in coalesced_set: + continue + off_a, phys_a, custom_a, node_a = unmerged_infos[anchor_idx] + + group = [(anchor_idx, node_a, custom_a, 0, phys_a)] + for probe_idx in range(len(unmerged_infos)): + if probe_idx == anchor_idx or probe_idx in coalesced_set: + continue + off_p, _, custom_p, node_p = unmerged_infos[probe_idx] + raw_diff = subs_idxc(off_p - off_a) + diff_val = _numeric_eval_constant(raw_diff) + if diff_val is None: + continue + if 0 < diff_val < max_load_width: + _, _, _, phys_p = unmerged_infos[probe_idx] + group.append( + (probe_idx, node_p, custom_p, diff_val, phys_p) + ) + + if len(group) < 2: + continue + + group.sort(key=lambda x: x[3]) + max_off = group[-1][3] + wide_ept = 1 + while wide_ept <= max_off: + wide_ept *= 2 + if wide_ept > max_load_width: + continue + + base_phys = group[0][4] + + earliest_node = group[0][1] + for g in group[1:]: + candidate = g[1] + for n in custom_a.graph.nodes: + if n is candidate: + earliest_node = candidate + break + if n is earliest_node: + break + + with get_custom(earliest_node).graph.inserting_before( + earliest_node + ): + wide_index = {} + for dim_idx, dim in enumerate(symbolic_dims): + if dim_idx == len(symbolic_dims) - 1: + wide_index[dim] = IndexSequence( + base_phys[dim], wide_ept, 1 + ) + else: + wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) + + wide_read = Read( + custom_a.memory, + elements_per_thread=wide_ept, + mapping=None, + _write_dependency=custom_a._write_dependency, + flags=custom_a.flags, + ).add_to_graph(custom_a.graph, loc=custom_a.location) + wide_custom = get_custom(wide_read) + wide_custom.index = wide_index + if hasattr(earliest_node, "vector_shapes"): + wide_read.vector_shapes = deepcopy( + earliest_node.vector_shapes + ) + propagate_tag(group[0][1], wide_read) + + for g_idx, g_node, g_custom, byte_pos, _ in group: + with g_custom.graph.inserting_before(g_node): + extract = ExtractSlice( + wide_read, [byte_pos], [1], [1] + ).add_to_graph(g_custom.graph, loc=g_custom.location) + if hasattr(g_node, "vector_shapes"): + get_custom(extract).vector_shapes = deepcopy( + g_node.vector_shapes + ) + propagate_tag(g_node, extract) + + g_custom.replace_all_uses_with(extract) + g_custom.graph.erase_node(g_node) + + coalesced_set.update(g[0] for g in group) + merged_any = True + return merged_any diff --git a/wave_lang/kernel/wave/opsel_scaled_mfma.py b/wave_lang/kernel/wave/opsel_scaled_mfma.py index 2a4910b934..9172c8a30a 100644 --- a/wave_lang/kernel/wave/opsel_scaled_mfma.py +++ b/wave_lang/kernel/wave/opsel_scaled_mfma.py @@ -165,17 +165,19 @@ def _find_yield_op(for_view) -> Optional[Operation]: def _find_mergeable_groups( for_view, yield_op: Operation ) -> list[tuple[Value, Value, dict[int, int]]]: - """Find groups of 4 ``vector<1xi8>`` iter_args that can be coalesced. + """Find groups of ``vector<1xi8>`` iter_args that can be coalesced. A group is valid when: - * All 4 init values are ``extract_strided_slice`` at offsets - {0, 1, ..., SCALE_VECTOR_WIDTH-1} from the same ``vector<4xi8>`` - source. - * All 4 yield values follow the same pattern from a (possibly - different) ``vector<4xi8>`` source. + * At least 2 init values are ``extract_strided_slice`` at distinct + offsets from the same ``vector<4xi8>`` source. + * The corresponding yield values follow the same pattern from a + (possibly different) ``vector<4xi8>`` source. * For each member, init_offset == yield_offset (byte identity is preserved across iterations). + Partial groups (e.g. only offsets {0, 2}) are accepted — the + coalesced ``vector<4xi8>`` iter_arg simply carries unused bytes. + Returns a list of ``(init_source, yield_source, {offset: iter_index})``. """ i8 = IntegerType.get_signless(8) @@ -199,9 +201,6 @@ def _find_mergeable_groups( continue eligible.append((i, init_off, init_src, yield_src)) - # Group by init source. Multiple args can share the same source - # (e.g. two MFMAs using the same scale load), so partition by offset - # to form distinct groups of exactly 4. by_init_src = defaultdict(list) for entry in eligible: _, _, init_src, _ = entry @@ -214,12 +213,16 @@ def _find_mergeable_groups( _, off, _, _ = entry by_offset[off].append(entry) - while all(len(by_offset[o]) > 0 for o in range(SCALE_VECTOR_WIDTH)): + # Greedily form groups from available offsets. Accept any group + # with >= 2 distinct offsets (full groups of 4 are the common + # case; partial groups like {0, 2} arise from preshuffle scales). + present_offsets = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]] + while len(present_offsets) >= 2: members = {} init_source = None yield_owners = set() yield_source = None - for o in range(SCALE_VECTOR_WIDTH): + for o in present_offsets: idx, _, isrc, ysrc = by_offset[o].pop(0) members[o] = idx init_source = isrc @@ -227,6 +230,7 @@ def _find_mergeable_groups( yield_owners.add(id(ysrc.owner)) if len(yield_owners) == 1: result.append((init_source, yield_source, members)) + present_offsets = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]] return result @@ -314,15 +318,13 @@ def _rewire_for_results( members = plan.groups[g_idx][2] new_idx = plan.group_new_iter_idx[g_idx] has_users = any( - any(True for _ in old_results[members[o]].uses) - for o in range(SCALE_VECTOR_WIDTH) + any(True for _ in old_results[members[o]].uses) for o in members ) if not has_users: continue with InsertionPoint(for_op): - for o in range(SCALE_VECTOR_WIDTH): - old_i = members[o] + for o, old_i in members.items(): if not any(True for _ in old_results[old_i].uses): continue extract_slice = make_extract_slice(new_results[new_idx], o) @@ -330,12 +332,15 @@ def _rewire_for_results( def _coalesce_vector_iter_args(module: Module) -> None: - """Merge groups of 4 ``vector<1xi8>`` scf.for iter_args into ``vector<4xi8>``. + """Merge groups of ``vector<1xi8>`` scf.for iter_args into ``vector<4xi8>``. - Pipeline double-buffering splits a ``vector<4xi8>`` scale load into 4 + Pipeline double-buffering splits a ``vector<4xi8>`` scale load into individual bytes for loop-carry. This pass merges them back so that ``_trace_scale_chain`` sees the full ``extract_strided_slice`` pattern inside the loop body and the opsel optimisation fires. + + Handles both full groups (all 4 offsets present) and partial groups + (e.g. only offsets {0, 2} from preshuffle scales). """ i8 = IntegerType.get_signless(8) i64 = IntegerType.get_signless(64) @@ -364,7 +369,7 @@ def make_extract_slice(source: Value, offset: int): if not groups: continue - logger.debug(f"Coalescing {len(groups)} group(s) of 4 vector<1xi8> iter_args") + logger.debug(f"Coalescing {len(groups)} group(s) of vector<1xi8> iter_args") plan = _build_coalesce_plan(groups, for_view, yield_op) old_iter_args = list(for_view.inner_iter_args) @@ -396,8 +401,7 @@ def make_extract_slice(source: Value, offset: int): with InsertionPoint(first_op): for g_idx, (_, _, members) in enumerate(groups): merged_arg = new_for.inner_iter_args[plan.group_new_iter_idx[g_idx]] - for offset in range(SCALE_VECTOR_WIDTH): - iter_idx = members[offset] + for offset, iter_idx in members.items(): extract_slice = make_extract_slice(merged_arg, offset) extract_results[iter_idx] = extract_slice.result From 1a5f07ee4ad6a3dc579aa995ac05e034512bc690 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 05:45:52 +0000 Subject: [PATCH 04/10] flat exprsn Signed-off-by: xintin --- .../kernel/compiler/wave_codegen/handlers.py | 12 ++++ .../analysis/partition_strided_operators.py | 72 ++++++++++++++++--- wave_lang/kernel/wave/opsel_scaled_mfma.py | 7 ++ wave_lang/kernel/wave/utils/symbol_utils.py | 2 + 4 files changed, 83 insertions(+), 10 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index b913675e19..b67783eb84 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -1955,6 +1955,18 @@ def handle_extract_slice(emitter: WaveEmitter, node: fx.Node): strides, ) + mask_expr = getattr(node, "precomputed_mask_expr", None) + if mask_expr is not None: + mask = gen_sympy_index(add_emitter_subs(emitter), mask_expr) + ept = sizes[0] if sizes else 1 + mask_vec_type = VectorType.get([ept], IntegerType.get_signless(1)) + if mask.type != mask_vec_type: + mask = vector_d.broadcast(mask_vec_type, mask) + element_type = result_type.element_type + zero = arith_d.constant(element_type, get_constant_attr(0, element_type)) + zero_vec = vector_d.broadcast(result_type, zero) + element = arith_d.select(mask, element, zero_vec) + emitter.bind_node_proxy(node, IRProxyValue(element)) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index f02901abe8..420b21591f 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -443,6 +443,57 @@ def _get_physical_start( return {dim: custom.index[dim].start for dim in symbolic_dims} +def _flatten_bounds_to_mask_expr( + custom: Read, + symbolic_shape: tuple, +): + """Pre-compute a read's bounds check as a flat sympy boolean. + + Replicates the _build_mask_with_mapping decision logic: when the + mapping's transformed index contains all bounded dims (and has no + dynamic_val_indices), the transformed index is used; otherwise the + original logical index is used. Returns a sympy boolean expression + (e.g. And(idx < bound, ...)) or None if the read has no bounds. + + """ + if not custom.bounds: + return None + + import functools + from ..utils.mapping_utils import transform_index_on_mapping + + index = custom.index + + if custom.mapping is not None and not custom.has_identity_mapping(): + transformed = transform_index_on_mapping( + custom.mapping, symbolic_shape, index, is_read=True + ) + use_transformed = ( + all(dim in transformed for dim in custom.bounds) + and not custom.mapping.dynamic_val_indices + ) + if use_transformed: + index = transformed + + conditions = [] + for dim, bound in custom.bounds.items(): + if dim not in index: + continue + start = ( + index[dim].start if isinstance(index[dim], IndexSequence) else index[dim] + ) + if isinstance(start, int): + start = sympy.Integer(start) + if isinstance(bound, int): + bound = sympy.Integer(bound) + conditions.append(sympy.StrictLessThan(start, bound)) + + if not conditions: + return None + + return functools.reduce(sympy.And, conditions) + + def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: """Single merge pass: merge reads that access nearby physical memory. @@ -524,16 +575,6 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: off1, phys1, custom1, node1 = read_infos[i] off2, phys2, custom2, node2 = read_infos[j] - # The pairwise merge drops the mapping and produces a - # physical-space read with no mask. That is unsafe for - # bounded reads whose OOB behaviour relies on the mask - # generated from bounds + mapping. Skip them here; - # bounded ept==1 reads (e.g. preshuffle scales backed by - # fat_raw_buffer) are handled by the multi-way coalescing - # pass below, where OOB loads safely return zero. - if custom1.bounds is not None or custom2.bounds is not None: - continue - raw_diff = subs_idxc(off2 - off1) # For reads with non-identity mappings (e.g. preshuffle @@ -628,6 +669,9 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: get_custom(extract0).vector_shapes = deepcopy( lo_custom.vector_shapes ) + lo_mask = _flatten_bounds_to_mask_expr(lo_custom, symbolic_shape) + if lo_mask is not None: + extract0.precomputed_mask_expr = lo_mask propagate_tag(lo_node, extract0) extract1 = ExtractSlice( @@ -637,6 +681,9 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: get_custom(extract1).vector_shapes = deepcopy( hi_custom.vector_shapes ) + hi_mask = _flatten_bounds_to_mask_expr(hi_custom, symbolic_shape) + if hi_mask is not None: + extract1.precomputed_mask_expr = hi_mask propagate_tag(hi_node, extract1) lo_custom.replace_all_uses_with(extract0) @@ -741,6 +788,11 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: get_custom(extract).vector_shapes = deepcopy( g_node.vector_shapes ) + g_mask = _flatten_bounds_to_mask_expr( + g_custom, symbolic_shape + ) + if g_mask is not None: + extract.precomputed_mask_expr = g_mask propagate_tag(g_node, extract) g_custom.replace_all_uses_with(extract) diff --git a/wave_lang/kernel/wave/opsel_scaled_mfma.py b/wave_lang/kernel/wave/opsel_scaled_mfma.py index 9172c8a30a..3880dfba89 100644 --- a/wave_lang/kernel/wave/opsel_scaled_mfma.py +++ b/wave_lang/kernel/wave/opsel_scaled_mfma.py @@ -98,6 +98,8 @@ def _trace_scale_chain(scale_value): return None slice_op = bitcast_source.owner + if _is_op_named(slice_op, "arith.select"): + slice_op = slice_op.operands[1].owner if not _is_op_named(slice_op, "vector.extract_strided_slice"): return None @@ -139,9 +141,14 @@ def _trace_extract_strided_slice( ) -> Optional[tuple[Value, int]]: """Check if *value* is produced by extract_strided_slice of a vector<4xi8>. + Looks through ``arith.select`` (inserted by flatten-bounds masking) + to find the underlying extract_strided_slice. + Returns ``(source_vec4xi8, byte_offset)`` or ``None``. """ op = value.owner + if _is_op_named(op, "arith.select"): + op = op.operands[1].owner if not _is_op_named(op, "vector.extract_strided_slice"): return None source = op.operands[0] diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 01e6d777f1..fe2a922f8e 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -282,6 +282,8 @@ def _numeric_eval_constant(expr, num_samples: int = 48): free, evaluator = (), None if not free: + if isinstance(expr, int): + return expr if expr.has(*_BAD_ATOMS): return None if expr.is_integer is not True: From 985457eaf6d3460931d0cb401cf1b16429c19ec8 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 06:50:16 +0000 Subject: [PATCH 05/10] updated lit tests Signed-off-by: xintin --- lit_tests/kernel/wave/scaled_gemm.py | 12 ++++++------ lit_tests/kernel/wave/scaled_mma.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/lit_tests/kernel/wave/scaled_gemm.py b/lit_tests/kernel/wave/scaled_gemm.py index 2abf657d86..940a2c04e7 100644 --- a/lit_tests/kernel/wave/scaled_gemm.py +++ b/lit_tests/kernel/wave/scaled_gemm.py @@ -359,9 +359,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: } # Epilogue Local Read - # CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> - # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> # Epilogue MFMA @@ -469,8 +469,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: rocdl.s.waitcnt # CHECK: amdgpu.lds_barrier - # Steady state local loads - # CHECK-COUNT-48: vector.load{{.*}} memref<{{.*}}, #gpu.address_space> + # Steady state local loads (8+4 scale loads as vector<8xi8> + 16+8 data loads as vector<16xi8>) + # CHECK-COUNT-36: vector.load{{.*}} memref<{{.*}}, #gpu.address_space> # Steady State global load to lds # CHECK-COUNT-34: amdgpu.gather_to_lds @@ -635,9 +635,9 @@ def repeat( # CHECK: } # Epilogue Local Read - # CHECK-COUNT-16: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-16: vector.load {{.*}} : memref<34816xi8, #gpu.address_space>, vector<16xi8> - # CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-8: vector.load {{.*}} : memref<34816xi8, #gpu.address_space>, vector<16xi8> # Epilogue MFMA diff --git a/lit_tests/kernel/wave/scaled_mma.py b/lit_tests/kernel/wave/scaled_mma.py index f37dae9b56..a762fb4a24 100644 --- a/lit_tests/kernel/wave/scaled_mma.py +++ b/lit_tests/kernel/wave/scaled_mma.py @@ -352,9 +352,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK-COUNT-4: vector.load {{.*}} : memref<16384x8192xi8, strided<[8192, 1]>>, vector<16xi8> # CHECK-COUNT-1: vector.load {{.*}} : memref<16384x512xi8, strided<[512, 1]>>, vector<4xi8> # CHECK: amdgpu.lds_barrier - # CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> - # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8> + # CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8> # CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8> # CHECK-COUNT-8: vector.bitcast {{.*}} : vector<16xi8> to vector<32xf4E2M1FN> # CHECK-COUNT-8: vector.bitcast {{.*}} : vector<1xi8> to vector<1xf8E8M0FNU> From 7e96bae097096a479bf6d41e72d3ed3400180ef1 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 15:11:23 +0000 Subject: [PATCH 06/10] fix wave asm backend Signed-off-by: xintin --- wave_lang/kernel/wave/asm/handlers_memory.py | 23 +++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/wave_lang/kernel/wave/asm/handlers_memory.py b/wave_lang/kernel/wave/asm/handlers_memory.py index a323940284..e0ee25d3d8 100644 --- a/wave_lang/kernel/wave/asm/handlers_memory.py +++ b/wave_lang/kernel/wave/asm/handlers_memory.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from wave_lang.support.ir_imports import ( + VectorType, amdgpu_d, gpu_d, memref_d, @@ -79,14 +80,24 @@ def handle_vector_extract_strided_slice_op( offset_val = int(str(offsets).split("[")[1].split("]")[0]) size_val = int(str(sizes).split("[")[1].split("]")[0]) + # Convert MLIR element offsets to physical register indices. + # Each 32-bit VGPR holds (32 // elem_bits) elements, e.g. 4 for i8. + # An extract at element offset 4 on vector<8xi8> targets register 1, + # not register 4. + source_vec_type = VectorType(operation.operands[0].type) + elem_bits = source_vec_type.element_type.width + elems_per_reg = 32 // elem_bits + + reg_offset = offset_val // elems_per_reg + reg_count = max(1, (size_val * elem_bits + 31) // 32) + # Extract the appropriate subset of registers - if size_val == 1: - # Single scalar extract - return just the one register as a tuple - extracted_reg = source_regs[offset_val] - result_regs = (extracted_reg,) + if reg_count == 1: + # Single register extract - return just the one register as a tuple + result_regs = (source_regs[reg_offset],) else: - # Multi-element extract - return a slice - result_regs = source_regs[offset_val : offset_val + size_val] + # Multi-register extract - return a slice + result_regs = source_regs[reg_offset : reg_offset + reg_count] result_ssa = str(operation.result) self.walker.kernel_ctx.ssa_to_reg[result_ssa] = result_regs From 72b7cd8144665253995aa74e3382f33fc0f631f1 Mon Sep 17 00:00:00 2001 From: xintin Date: Tue, 3 Mar 2026 19:47:14 +0000 Subject: [PATCH 07/10] lint Signed-off-by: xintin --- .../wave/analysis/partition_strided_operators.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index a3c79ade02..1cc8022f79 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -4,6 +4,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import functools from collections.abc import Sequence from copy import deepcopy from itertools import groupby @@ -30,6 +31,7 @@ ) from ..assumptions import Assumption from ..constraints import Constraint +from ..utils.mapping_utils import transform_index_on_mapping from ..utils.tag_utils import propagate_tag from ..utils.general_utils import ( all_equal, @@ -433,8 +435,6 @@ def _get_physical_start( coordinates. For identity-mapped reads (mapping=None), reads the start offsets directly from the index. """ - from ..utils.mapping_utils import transform_index_on_mapping - if custom.mapping is not None and not custom.has_identity_mapping(): physical = transform_index_on_mapping( custom.mapping, symbolic_shape, custom.index, is_read=True @@ -457,15 +457,12 @@ def _flatten_bounds_to_mask_expr( mapping's transformed index contains all bounded dims (and has no dynamic_val_indices), the transformed index is used; otherwise the original logical index is used. Returns a sympy boolean expression - (e.g. And(idx < bound, ...)) or None if the read has no bounds. + (eg. And(idx < bound, ...)) or None if the read has no bounds. """ if not custom.bounds: return None - import functools - from ..utils.mapping_utils import transform_index_on_mapping - index = custom.index if custom.mapping is not None and not custom.has_identity_mapping(): @@ -503,11 +500,11 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: Two strategies are applied per (memory, ept) group: - 1. **Pairwise contiguous merge** — pairs whose physical flat offset + 1. **Pairwise contiguous merge**: pairs whose physical flat offset starts differ by exactly ``ept`` are merged into a ``2*ept`` read with two ExtractSlice outputs. - 2. **Multi-way coalescing** (``ept==1`` only) — unmerged byte reads + 2. **Multi-way coalescing** (``ept==1`` only): unmerged byte reads whose flat offsets fall within a power-of-2 aligned window (up to ``max_elems_per_load``) are replaced by a single wide read with per-byte ExtractSlice outputs. Diffs are evaluated via numeric @@ -736,6 +733,9 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: if len(group) < 2: continue + # Pick the smallest power-of-2 width that covers all + # byte offsets in the group, and skip if it exceeds + # the hardware's max load width. group.sort(key=lambda x: x[3]) max_off = group[-1][3] wide_ept = 1 From c93ec162d31cb7f5cec3edb82b85f34ac189b309 Mon Sep 17 00:00:00 2001 From: xintin Date: Wed, 4 Mar 2026 17:30:41 +0000 Subject: [PATCH 08/10] add comment Signed-off-by: xintin --- wave_lang/kernel/compiler/wave_codegen/handlers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index b67783eb84..09e68f56b8 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -1955,6 +1955,10 @@ def handle_extract_slice(emitter: WaveEmitter, node: fx.Node): strides, ) + # When reads are coalesced into a single wider load, the original + # per-read bounds checks are lost. The partition pass preserves them + # as a precomputed sympy mask on each ExtractSlice so we can zero + # out-of-bounds lanes here. mask_expr = getattr(node, "precomputed_mask_expr", None) if mask_expr is not None: mask = gen_sympy_index(add_emitter_subs(emitter), mask_expr) From 6c3054a24e3570f2da41fb8af0033d223a20a25f Mon Sep 17 00:00:00 2001 From: xintin Date: Wed, 4 Mar 2026 19:23:55 +0000 Subject: [PATCH 09/10] refactored Signed-off-by: xintin --- .../analysis/partition_strided_operators.py | 539 +++++++++--------- wave_lang/kernel/wave/utils/symbol_utils.py | 2 +- 2 files changed, 279 insertions(+), 262 deletions(-) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index 073bf5404c..c9b9b26ded 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -493,32 +493,53 @@ def _flatten_bounds_to_mask_expr( return functools.reduce(sympy.And, conditions) -def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: - """Single merge pass: merge reads that access nearby physical memory. - - Two strategies are applied per (memory, ept) group: - - 1. **Pairwise contiguous merge**: pairs whose physical flat offset - starts differ by exactly ``ept`` are merged into a ``2*ept`` read - with two ExtractSlice outputs. - - 2. **Multi-way coalescing** (``ept==1`` only): unmerged byte reads - whose flat offsets fall within a power-of-2 aligned window (up to - ``max_elems_per_load``) are replaced by a single wide read with - per-byte ExtractSlice outputs. Diffs are evaluated via numeric - probing (``_numeric_eval_constant``), so this handles non-constant - symbolic offsets such as preshuffle mappings. - - Returns True if any merges or coalescing happened. +def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source): + """Create a merged Read node covering ``wide_ept`` elements.""" + wide_read = Read( + anchor_custom.memory, + elements_per_thread=wide_ept, + mapping=None, + _write_dependency=anchor_custom._write_dependency, + flags=anchor_custom.flags, + ).add_to_graph(anchor_custom.graph, loc=anchor_custom.location) + wide_custom = get_custom(wide_read) + wide_custom.index = wide_index + if hasattr(tag_source, "vector_shapes"): + wide_read.vector_shapes = deepcopy(tag_source.vector_shapes) + propagate_tag(tag_source, wide_read) + return wide_read + + +def _emit_extract_slice( + wide_read, offset, size, orig_custom, orig_node, symbolic_shape +): + """Create an ExtractSlice from a wide read with bounds mask and tag.""" + extract = ExtractSlice(wide_read, [offset], [size], [1]).add_to_graph( + orig_custom.graph, loc=orig_custom.location + ) + extract_custom = get_custom(extract) + extract_custom.index = deepcopy(orig_custom.index) + if hasattr(orig_node, "vector_shapes"): + extract_custom.vector_shapes = deepcopy(orig_node.vector_shapes) + mask = _flatten_bounds_to_mask_expr(orig_custom, symbolic_shape) + if mask is not None: + extract.precomputed_mask_expr = mask + propagate_tag(orig_node, extract) + return extract + + +def _group_reads_by_memory( + trace: CapturedTrace, +) -> dict[tuple, list[fx.Node]]: + """Group reads by (memory, ept, region). + + A new region starts at each subgraph boundary and whenever a + side-effecting op (write, barrier, ...) is encountered, so we never + merge reads across such ops. Reads with dynamic mapping values are + skipped to keep the merge logic simple. """ from collections import defaultdict - from ...compiler.utils import strides_from_symbolic_shape - from ..._support.indexing import IndexingContext - # Group reads by (memory, ept, region). A new region starts at each - # subgraph boundary and whenever a side-effecting op (write, barrier, ...) - # is encountered, so we never merge reads across such ops. Reads with - # dynamic mapping values are skipped to keep the merge logic simple. groups: dict[tuple, list[fx.Node]] = defaultdict(list) region_id = 0 for subgraph in trace.region_graph.subgraphs.values(): @@ -534,11 +555,234 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: continue if custom.mapping_dynamic_vals: continue - # Bounded reads are allowed into the group; the merge loop - # below only merges pairs with identical bounds. key = (custom.memory, custom.elements_per_thread, region_id) groups[key].append(node) + return groups + + +def _resolve_symbolic_diff(raw_diff, has_complex_mapping, expected_vals=None): + """Resolve a raw sympy offset difference to a value, or None. + + Strategy: + 1. If already a plain int / sympy.Integer, return it directly. + 2. If the mapping is complex (non-identity), use numeric probing. + 3. Otherwise try sym_simplify. If ``expected_vals`` is given and + the result isn't among them, fall back to numeric probing; + return None when neither approach succeeds. + """ + if isinstance(raw_diff, (int, sympy.Integer)): + return int(raw_diff) + if has_complex_mapping: + return _numeric_eval_constant(raw_diff) + simplified = sym_simplify(raw_diff) + if expected_vals is None or simplified in expected_vals: + return simplified + nv = _numeric_eval_constant(raw_diff) + return nv + + +def _pairwise_merge(read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint): + """Merge pairs of reads whose flat offsets differ by exactly ``ept``. + + Returns ``(merged_indices, did_merge)`` where *merged_indices* is + the set of read_infos indices that were consumed. + """ + merged = set() + did_merge = False + + for i in range(len(read_infos)): + if i in merged: + continue + for j in range(i + 1, len(read_infos)): + if j in merged: + continue + off1, phys1, custom1, node1 = read_infos[i] + off2, phys2, custom2, node2 = read_infos[j] + + raw_diff = subs_idxc(off2 - off1) + has_complex_mapping = ( + custom1.mapping is not None and not custom1.has_identity_mapping() + ) + diff = _resolve_symbolic_diff( + raw_diff, has_complex_mapping, expected_vals={ept, -ept} + ) + if diff is None: + continue + + if diff == ept: + lo_phys, hi_phys = phys1, phys2 + lo_custom, hi_custom = custom1, custom2 + lo_node, hi_node = node1, node2 + elif diff == -ept: + lo_phys, hi_phys = phys2, phys1 + lo_custom, hi_custom = custom2, custom1 + lo_node, hi_node = node2, node1 + else: + continue + + merge_dim = None + for dim in symbolic_dims: + raw_d = subs_idxc(hi_phys[dim] - lo_phys[dim]) + d = _resolve_symbolic_diff( + raw_d, has_complex_mapping, expected_vals={ept, 0} + ) + if d is None: + merge_dim = None + break + if d == ept: + merge_dim = dim + elif d != 0: + merge_dim = None + break + if merge_dim is None: + continue + + new_ept = 2 * ept + element_type = lo_custom.type.dtype + if new_ept > hw_constraint.max_elems_per_load(element_type): + continue + with lo_custom.graph.inserting_before(lo_node): + new_index = { + dim: IndexSequence( + lo_phys[dim], + new_ept if dim == merge_dim else 1, + 1, + ) + for dim in symbolic_dims + } + merged_read = _emit_wide_read(lo_custom, new_index, new_ept, lo_node) + + # Masks are attached per-slice rather than on the merged + # read because lo and hi may have different bounds + # conditions (e.g. lo is in-bounds while hi crosses the + # tensor boundary). A single mask on the wide read + # cannot express that — each half needs its own. + lo_extract = _emit_extract_slice( + merged_read, 0, ept, lo_custom, lo_node, symbolic_shape + ) + hi_extract = _emit_extract_slice( + merged_read, ept, ept, hi_custom, hi_node, symbolic_shape + ) + + lo_custom.replace_all_uses_with(lo_extract) + hi_custom.replace_all_uses_with(hi_extract) + lo_custom.graph.erase_node(lo_node) + hi_custom.graph.erase_node(hi_node) + merged.update({i, j}) + did_merge = True + break + + return merged, did_merge + + +def _multiway_coalesce( + read_infos, merged, reads, symbolic_dims, symbolic_shape, hw_constraint +): + """Coalesce unmerged ept==1 reads whose flat offsets fall in an aligned window. + + Groups reads whose numerically-probed flat offsets fall within a + power-of-2 aligned window (up to ``max_elems_per_load``), then emits + a single wide read with per-byte ExtractSlice ops. + """ + element_type = get_custom(reads[0]).type.dtype + max_load_width = hw_constraint.max_elems_per_load(element_type) + + unmerged_infos = [read_infos[k] for k in range(len(read_infos)) if k not in merged] + if len(unmerged_infos) < 2: + return False + + coalesced_any = False + coalesced_set: set[int] = set() + for anchor_idx in range(len(unmerged_infos)): + if anchor_idx in coalesced_set: + continue + off_a, phys_a, custom_a, node_a = unmerged_infos[anchor_idx] + + group = [(anchor_idx, node_a, custom_a, 0, phys_a)] + for probe_idx in range(len(unmerged_infos)): + if probe_idx == anchor_idx or probe_idx in coalesced_set: + continue + off_p, _, custom_p, node_p = unmerged_infos[probe_idx] + raw_diff = subs_idxc(off_p - off_a) + diff_val = _numeric_eval_constant(raw_diff) + if diff_val is None: + continue + if 0 < diff_val < max_load_width: + _, _, _, phys_p = unmerged_infos[probe_idx] + group.append((probe_idx, node_p, custom_p, diff_val, phys_p)) + + if len(group) < 2: + continue + + group.sort(key=lambda x: x[3]) + max_off = group[-1][3] + wide_ept = 1 + while wide_ept <= max_off: + wide_ept *= 2 + if wide_ept > max_load_width: + continue + + base_phys = group[0][4] + + earliest_node = group[0][1] + for g in group[1:]: + candidate = g[1] + for n in custom_a.graph.nodes: + if n is candidate: + earliest_node = candidate + break + if n is earliest_node: + break + + with get_custom(earliest_node).graph.inserting_before(earliest_node): + wide_index = {} + for dim_idx, dim in enumerate(symbolic_dims): + if dim_idx == len(symbolic_dims) - 1: + wide_index[dim] = IndexSequence(base_phys[dim], wide_ept, 1) + else: + wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) + + wide_read = _emit_wide_read(custom_a, wide_index, wide_ept, earliest_node) + + extracts = [] + for g_idx, g_node, g_custom, byte_pos, _ in group: + with g_custom.graph.inserting_before(g_node): + ext = _emit_extract_slice( + wide_read, byte_pos, 1, g_custom, g_node, symbolic_shape + ) + extracts.append((ext, g_custom, g_node)) + + for ext, g_custom, g_node in extracts: + g_custom.replace_all_uses_with(ext) + g_custom.graph.erase_node(g_node) + + coalesced_set.update(g[0] for g in group) + coalesced_any = True + + return coalesced_any + + +def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: + """Single merge pass: merge reads that access nearby physical memory. + + Two strategies are applied per (memory, ept) group: + + 1. **Pairwise contiguous merge** (``_pairwise_merge``): pairs whose + physical flat offset starts differ by exactly ``ept`` are merged + into a ``2*ept`` read with two ExtractSlice outputs. + + 2. **Multi-way coalescing** (``_multiway_coalesce``, ``ept==1`` only): + unmerged byte reads whose flat offsets fall within a power-of-2 + aligned window are replaced by a single wide read with per-byte + ExtractSlice outputs. + + Returns True if any merges or coalescing happened. + """ + from ...compiler.utils import strides_from_symbolic_shape + from ..._support.indexing import IndexingContext + + groups = _group_reads_by_memory(trace) idxc = IndexingContext.current() merged_any = False @@ -564,244 +808,17 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool: ) read_infos.append((flat_offset, phys_start, custom, node)) - merged = set() - for i in range(len(read_infos)): - if i in merged: - continue - for j in range(i + 1, len(read_infos)): - if j in merged: - continue - off1, phys1, custom1, node1 = read_infos[i] - off2, phys2, custom2, node2 = read_infos[j] - - raw_diff = subs_idxc(off2 - off1) - - # For reads with non-identity mappings (e.g. preshuffle - # scales), the flat-offset diff contains complex floor/Mod - # expressions that sympy.simplify cannot reduce. Use fast - # numeric probing instead. - has_complex_mapping = ( - custom1.mapping is not None and not custom1.has_identity_mapping() - ) + merged, did_merge = _pairwise_merge( + read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint + ) + merged_any |= did_merge - # subs_idxc may fully resolve to a plain int. - if isinstance(raw_diff, (int, sympy.Integer)): - diff = int(raw_diff) - elif has_complex_mapping: - diff = _numeric_eval_constant(raw_diff) - if diff is None: - continue - else: - diff = sym_simplify(raw_diff) - if diff != ept and diff != -ept: - nv = _numeric_eval_constant(raw_diff) - if nv is not None: - diff = nv - - if diff == ept: - lo_phys, hi_phys = phys1, phys2 - lo_custom, hi_custom = custom1, custom2 - lo_node, hi_node = node1, node2 - elif diff == -ept: - lo_phys, hi_phys = phys2, phys1 - lo_custom, hi_custom = custom2, custom1 - lo_node, hi_node = node2, node1 - else: - continue - - # Find dimension that advances by ept. - merge_dim = None - for dim in symbolic_dims: - raw_d = subs_idxc(hi_phys[dim] - lo_phys[dim]) - if isinstance(raw_d, (int, sympy.Integer)): - d = int(raw_d) - elif has_complex_mapping: - d = _numeric_eval_constant(raw_d) - if d is None: - merge_dim = None - break - else: - d = sym_simplify(raw_d) - if d != ept and d != 0: - nv = _numeric_eval_constant(raw_d) - if nv is not None: - d = nv - if d == ept: - merge_dim = dim - elif not (d == 0): - merge_dim = None - break - if merge_dim is None: - continue - - # Respect hardware vector width limit. - new_ept = 2 * ept - element_type = lo_custom.type.dtype - if new_ept > hw_constraint.max_elems_per_load(element_type): - continue - with lo_custom.graph.inserting_before(lo_node): - new_index = { - dim: IndexSequence( - lo_phys[dim], - new_ept if dim == merge_dim else 1, - 1, - ) - for dim in symbolic_dims - } - - merged_read = Read( - lo_custom.memory, - elements_per_thread=new_ept, - mapping=None, - _write_dependency=lo_custom._write_dependency, - flags=lo_custom.flags, - ).add_to_graph(lo_custom.graph, loc=lo_custom.location) - merged_custom = get_custom(merged_read) - merged_custom.index = new_index - merged_custom.vector_shapes = deepcopy(lo_custom.vector_shapes) - propagate_tag(lo_node, merged_read) - - extract0 = ExtractSlice(merged_read, [0], [ept], [1]).add_to_graph( - lo_custom.graph, loc=lo_custom.location - ) - get_custom(extract0).index = deepcopy(lo_custom.index) - get_custom(extract0).vector_shapes = deepcopy( - lo_custom.vector_shapes - ) - lo_mask = _flatten_bounds_to_mask_expr(lo_custom, symbolic_shape) - if lo_mask is not None: - extract0.precomputed_mask_expr = lo_mask - propagate_tag(lo_node, extract0) - - extract1 = ExtractSlice( - merged_read, [ept], [ept], [1] - ).add_to_graph(lo_custom.graph, loc=lo_custom.location) - get_custom(extract1).index = deepcopy(hi_custom.index) - get_custom(extract1).vector_shapes = deepcopy( - hi_custom.vector_shapes - ) - hi_mask = _flatten_bounds_to_mask_expr(hi_custom, symbolic_shape) - if hi_mask is not None: - extract1.precomputed_mask_expr = hi_mask - propagate_tag(hi_node, extract1) - - lo_custom.replace_all_uses_with(extract0) - hi_custom.replace_all_uses_with(extract1) - lo_custom.graph.erase_node(lo_node) - hi_custom.graph.erase_node(hi_node) - - merged.update({i, j}) - merged_any = True - break - - # Multi-way coalescing for ept==1 reads that survived the pairwise - # pass. Groups reads whose numerically-probed flat offsets fall - # within an aligned window of max_load_width bytes, then emits a - # single wide read + per-byte ExtractSlice ops. This generalises - # the former _coalesce_preshuffle_global_reads to all mapping types. + # Only ept==1 (byte) reads need multi-way coalescing; wider reads + # are already handled by the pairwise merge above. if ept == 1 and len(read_infos) >= 2: - element_type = get_custom(reads[0]).type.dtype - max_load_width = hw_constraint.max_elems_per_load(element_type) - - unmerged_infos = [ - read_infos[k] for k in range(len(read_infos)) if k not in merged - ] - if len(unmerged_infos) >= 2: - coalesced_set: set[int] = set() - for anchor_idx in range(len(unmerged_infos)): - if anchor_idx in coalesced_set: - continue - off_a, phys_a, custom_a, node_a = unmerged_infos[anchor_idx] - - group = [(anchor_idx, node_a, custom_a, 0, phys_a)] - for probe_idx in range(len(unmerged_infos)): - if probe_idx == anchor_idx or probe_idx in coalesced_set: - continue - off_p, _, custom_p, node_p = unmerged_infos[probe_idx] - raw_diff = subs_idxc(off_p - off_a) - diff_val = _numeric_eval_constant(raw_diff) - if diff_val is None: - continue - if 0 < diff_val < max_load_width: - _, _, _, phys_p = unmerged_infos[probe_idx] - group.append( - (probe_idx, node_p, custom_p, diff_val, phys_p) - ) - - if len(group) < 2: - continue - - # Pick the smallest power-of-2 width that covers all - # byte offsets in the group, and skip if it exceeds - # the hardware's max load width. - group.sort(key=lambda x: x[3]) - max_off = group[-1][3] - wide_ept = 1 - while wide_ept <= max_off: - wide_ept *= 2 - if wide_ept > max_load_width: - continue - - base_phys = group[0][4] - - earliest_node = group[0][1] - for g in group[1:]: - candidate = g[1] - for n in custom_a.graph.nodes: - if n is candidate: - earliest_node = candidate - break - if n is earliest_node: - break - - with get_custom(earliest_node).graph.inserting_before( - earliest_node - ): - wide_index = {} - for dim_idx, dim in enumerate(symbolic_dims): - if dim_idx == len(symbolic_dims) - 1: - wide_index[dim] = IndexSequence( - base_phys[dim], wide_ept, 1 - ) - else: - wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) - - wide_read = Read( - custom_a.memory, - elements_per_thread=wide_ept, - mapping=None, - _write_dependency=custom_a._write_dependency, - flags=custom_a.flags, - ).add_to_graph(custom_a.graph, loc=custom_a.location) - wide_custom = get_custom(wide_read) - wide_custom.index = wide_index - if hasattr(earliest_node, "vector_shapes"): - wide_read.vector_shapes = deepcopy( - earliest_node.vector_shapes - ) - propagate_tag(group[0][1], wide_read) - - for g_idx, g_node, g_custom, byte_pos, _ in group: - with g_custom.graph.inserting_before(g_node): - extract = ExtractSlice( - wide_read, [byte_pos], [1], [1] - ).add_to_graph(g_custom.graph, loc=g_custom.location) - if hasattr(g_node, "vector_shapes"): - get_custom(extract).vector_shapes = deepcopy( - g_node.vector_shapes - ) - g_mask = _flatten_bounds_to_mask_expr( - g_custom, symbolic_shape - ) - if g_mask is not None: - extract.precomputed_mask_expr = g_mask - propagate_tag(g_node, extract) - - g_custom.replace_all_uses_with(extract) - g_custom.graph.erase_node(g_node) - - coalesced_set.update(g[0] for g in group) - merged_any = True + merged_any |= _multiway_coalesce( + read_infos, merged, reads, symbolic_dims, symbolic_shape, hw_constraint + ) return merged_any diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 38b3ef3ff6..1cb79e1bdd 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -181,7 +181,7 @@ def transform_mod(expr): return None mult = m if (mult is None) or (m < mult) else mult terms.append(arg) - if c >= mult: + if c is None or mult is None or c >= mult: return None return (sum(terms) % q) + c From 09ad6b34fa4d2e9f9e426f9bbd38f444644ffcee Mon Sep 17 00:00:00 2001 From: xintin Date: Wed, 4 Mar 2026 20:18:48 +0000 Subject: [PATCH 10/10] one mask for wide read Signed-off-by: xintin --- .../kernel/compiler/wave_codegen/handlers.py | 16 ---- .../compiler/wave_codegen/read_write.py | 13 +++- .../analysis/partition_strided_operators.py | 76 ++++++++++++++++--- 3 files changed, 76 insertions(+), 29 deletions(-) diff --git a/wave_lang/kernel/compiler/wave_codegen/handlers.py b/wave_lang/kernel/compiler/wave_codegen/handlers.py index 09e68f56b8..b913675e19 100644 --- a/wave_lang/kernel/compiler/wave_codegen/handlers.py +++ b/wave_lang/kernel/compiler/wave_codegen/handlers.py @@ -1955,22 +1955,6 @@ def handle_extract_slice(emitter: WaveEmitter, node: fx.Node): strides, ) - # When reads are coalesced into a single wider load, the original - # per-read bounds checks are lost. The partition pass preserves them - # as a precomputed sympy mask on each ExtractSlice so we can zero - # out-of-bounds lanes here. - mask_expr = getattr(node, "precomputed_mask_expr", None) - if mask_expr is not None: - mask = gen_sympy_index(add_emitter_subs(emitter), mask_expr) - ept = sizes[0] if sizes else 1 - mask_vec_type = VectorType.get([ept], IntegerType.get_signless(1)) - if mask.type != mask_vec_type: - mask = vector_d.broadcast(mask_vec_type, mask) - element_type = result_type.element_type - zero = arith_d.constant(element_type, get_constant_attr(0, element_type)) - zero_vec = vector_d.broadcast(result_type, zero) - element = arith_d.select(mask, element, zero_vec) - emitter.bind_node_proxy(node, IRProxyValue(element)) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 5c871e8604..2ef8e1c174 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -699,7 +699,18 @@ def handle_read(emitter: WaveEmitter, node: fx.Node): ) dynamic_vals_map_start = _build_dyn_vals_map(mapping, dyn_vals) - if mapping: + is_global_mem = kb_src.type.memory_space is None + buffer_ops_enabled = emitter.options.use_buffer_ops and is_global_mem + + precomputed_mask_expr = getattr(node, "precomputed_mask_expr", None) + if precomputed_mask_expr is not None and not buffer_ops_enabled: + mask = gen_sympy_index(add_emitter_subs(emitter), precomputed_mask_expr) + mask_vec_type = VectorType.get( + [elements_per_thread], IntegerType.get_signless(1) + ) + if mask.type != mask_vec_type: + mask = vector_d.broadcast(mask_vec_type, mask) + elif mapping: transformed_index = transform_index_on_mapping( mapping, input_shape, index, is_read=True ) diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py index c9b9b26ded..4b47004a8f 100644 --- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py +++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py @@ -493,7 +493,49 @@ def _flatten_bounds_to_mask_expr( return functools.reduce(sympy.And, conditions) -def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source): +def _build_wide_mask_expr(sub_reads, symbolic_shape, wide_ept): + """Build a concatenated sympy mask for a wide read from its sub-reads. + + Each entry in *sub_reads* is ``(offset, size, orig_custom)`` where + *offset* is the lane offset within the wide vector and *size* is the + number of lanes that sub-read occupies. *wide_ept* is the total + number of elements in the wide read (may exceed ``sum(sizes)`` when + there are gaps, e.g. multiway coalesce with non-contiguous offsets). + + Builds ``Or(And(lane_in_range_0, mask_0), And(lane_in_range_1, mask_1), ...)`` + using pure boolean ops so that ``gen_sympy_index`` can lower it without + the nested-Piecewise ``select_stack`` ordering issue. + + Returns a sympy boolean expression or ``None`` when no sub-read has + bounds. + """ + from ..._support.indexing import IndexingContext + + masks = [ + (offset, size, _flatten_bounds_to_mask_expr(custom, symbolic_shape)) + for offset, size, custom in sub_reads + ] + + if not any(m is not None for _, _, m in masks): + return None + + idxc = IndexingContext.current() + iota = idxc.iota(wide_ept) + + terms = [] + for offset, size, mask in masks: + upper = offset + size + lane_cond = sympy.And( + sympy.GreaterThan(iota, offset) if offset > 0 else sympy.true, + sympy.StrictLessThan(iota, upper), + ) + bound_cond = mask if mask is not None else sympy.true + terms.append(sympy.And(lane_cond, bound_cond)) + + return functools.reduce(sympy.Or, terms) + + +def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source, mask_expr=None): """Create a merged Read node covering ``wide_ept`` elements.""" wide_read = Read( anchor_custom.memory, @@ -506,6 +548,8 @@ def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source): wide_custom.index = wide_index if hasattr(tag_source, "vector_shapes"): wide_read.vector_shapes = deepcopy(tag_source.vector_shapes) + if mask_expr is not None: + wide_read.precomputed_mask_expr = mask_expr propagate_tag(tag_source, wide_read) return wide_read @@ -513,7 +557,7 @@ def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source): def _emit_extract_slice( wide_read, offset, size, orig_custom, orig_node, symbolic_shape ): - """Create an ExtractSlice from a wide read with bounds mask and tag.""" + """Create an ExtractSlice from a wide read and propagate metadata.""" extract = ExtractSlice(wide_read, [offset], [size], [1]).add_to_graph( orig_custom.graph, loc=orig_custom.location ) @@ -521,9 +565,6 @@ def _emit_extract_slice( extract_custom.index = deepcopy(orig_custom.index) if hasattr(orig_node, "vector_shapes"): extract_custom.vector_shapes = deepcopy(orig_node.vector_shapes) - mask = _flatten_bounds_to_mask_expr(orig_custom, symbolic_shape) - if mask is not None: - extract.precomputed_mask_expr = mask propagate_tag(orig_node, extract) return extract @@ -555,6 +596,8 @@ def _group_reads_by_memory( continue if custom.mapping_dynamic_vals: continue + if getattr(node, "precomputed_mask_expr", None) is not None: + continue key = (custom.memory, custom.elements_per_thread, region_id) groups[key].append(node) return groups @@ -641,6 +684,11 @@ def _pairwise_merge(read_infos, ept, symbolic_dims, symbolic_shape, hw_constrain element_type = lo_custom.type.dtype if new_ept > hw_constraint.max_elems_per_load(element_type): continue + wide_mask = _build_wide_mask_expr( + [(0, ept, lo_custom), (ept, ept, hi_custom)], + symbolic_shape, + new_ept, + ) with lo_custom.graph.inserting_before(lo_node): new_index = { dim: IndexSequence( @@ -650,13 +698,10 @@ def _pairwise_merge(read_infos, ept, symbolic_dims, symbolic_shape, hw_constrain ) for dim in symbolic_dims } - merged_read = _emit_wide_read(lo_custom, new_index, new_ept, lo_node) + merged_read = _emit_wide_read( + lo_custom, new_index, new_ept, lo_node, mask_expr=wide_mask + ) - # Masks are attached per-slice rather than on the merged - # read because lo and hi may have different bounds - # conditions (e.g. lo is in-bounds while hi crosses the - # tensor boundary). A single mask on the wide read - # cannot express that — each half needs its own. lo_extract = _emit_extract_slice( merged_read, 0, ept, lo_custom, lo_node, symbolic_shape ) @@ -735,6 +780,11 @@ def _multiway_coalesce( if n is earliest_node: break + wide_mask = _build_wide_mask_expr( + [(byte_pos, 1, g_custom) for _, _, g_custom, byte_pos, _ in group], + symbolic_shape, + wide_ept, + ) with get_custom(earliest_node).graph.inserting_before(earliest_node): wide_index = {} for dim_idx, dim in enumerate(symbolic_dims): @@ -743,7 +793,9 @@ def _multiway_coalesce( else: wide_index[dim] = IndexSequence(base_phys[dim], 1, 1) - wide_read = _emit_wide_read(custom_a, wide_index, wide_ept, earliest_node) + wide_read = _emit_wide_read( + custom_a, wide_index, wide_ept, earliest_node, mask_expr=wide_mask + ) extracts = [] for g_idx, g_node, g_custom, byte_pos, _ in group: