diff --git a/lit_tests/kernel/wave/scaled_gemm.py b/lit_tests/kernel/wave/scaled_gemm.py
index 7bf3b8b17..96dbdddfa 100644
--- a/lit_tests/kernel/wave/scaled_gemm.py
+++ b/lit_tests/kernel/wave/scaled_gemm.py
@@ -361,9 +361,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: }
 
     # Epilogue Local Read
-    # CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8>
+    # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8>
     # CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8>
-    # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8>
     # CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8>
 
     # Epilogue MFMA
@@ -471,8 +471,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: rocdl.s.waitcnt
     # CHECK: amdgpu.lds_barrier
 
-    # Steady state local loads
-    # CHECK-COUNT-48: vector.load{{.*}} memref<{{.*}}, #gpu.address_space>
+    # Steady state local loads (8+4 scale loads as vector<8xi8> + 16+8 data loads as vector<16xi8>)
+    # CHECK-COUNT-36: vector.load{{.*}} memref<{{.*}}, #gpu.address_space>
 
     # Steady State global load to lds
     # CHECK-COUNT-34: amdgpu.gather_to_lds
@@ -637,9 +637,9 @@ def repeat(
     # CHECK: }
 
     # Epilogue Local Read
-    # CHECK-COUNT-16: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<1xi8>
+    # CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<8xi8>
     # CHECK-COUNT-16: vector.load {{.*}} : memref<34816xi8, #gpu.address_space>, vector<16xi8>
-    # CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<1xi8>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<4096xi8, #gpu.address_space>, vector<8xi8>
     # CHECK-COUNT-8: vector.load {{.*}} : memref<34816xi8, #gpu.address_space>, vector<16xi8>
 
     # Epilogue MFMA
diff --git a/lit_tests/kernel/wave/scaled_mma.py b/lit_tests/kernel/wave/scaled_mma.py
index f37dae9b5..a762fb4a2 100644
--- a/lit_tests/kernel/wave/scaled_mma.py
+++ b/lit_tests/kernel/wave/scaled_mma.py
@@ -352,9 +352,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-COUNT-4: vector.load {{.*}} : memref<16384x8192xi8, strided<[8192, 1]>>, vector<16xi8>
     # CHECK-COUNT-1: vector.load {{.*}} : memref<16384x512xi8, strided<[512, 1]>>, vector<4xi8>
     # CHECK: amdgpu.lds_barrier
-    # CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8>
+    # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8>
     # CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8>
-    # CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<1xi8>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space>, vector<8xi8>
     # CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space>, vector<16xi8>
     # CHECK-COUNT-8: vector.bitcast {{.*}} : vector<16xi8> to vector<32xf4E2M1FN>
     # CHECK-COUNT-8: vector.bitcast {{.*}} : vector<1xi8> to vector<1xf8E8M0FNU>
diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py
index 5c871e860..49b5a3ac5 100644
--- a/wave_lang/kernel/compiler/wave_codegen/read_write.py
+++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py
@@ -466,6 +466,7 @@ def _create_vec_read_write(
     memory: CustomOp,
     mask: Optional[Value],
     node_index: Optional[IndexSequence] = None,
+    use_wide_load_select: bool = False,
 ) -> Optional[Value]:
     is_read = value is None
     uint32 = IntegerType.get_signless(32)
@@ -558,7 +559,8 @@ def extract(vec, ind):
     indices = [offset_th] if buffer_ops_enabled else start_indices
 
     if no_masked_load_store_ops:
-        # find the index at which memory out of bounds of buffer
+        scalar_offset_th = offset_th
+
         oob_index_value = _get_out_of_bounds_index(element_type)
         oob_index = arith_d.constant(IndexType.get(), oob_index_value)
@@ -582,7 +584,6 @@ def extract(vec, ind):
             # based on mask, select between the offsets_vec and out of bounds. In this case all 3 operands can be vectors
             selected_index = arith_d.select(mask, offsets_vec, oob_index)
 
-        elems = list()
         if splatted_mask:
             # mask is same for all of them, can just pick the first index
@@ -595,27 +596,28 @@ def extract(vec, ind):
                 vector_d.store(value, mem, indices=[selected_index])
             return
 
-        for i in range(elements_per_thread):
-            # mask is not same for all elements, need to unroll
-            this_index = extract(selected_index, i)  # this element
+        if is_read and use_wide_load_select:
+            result = vector_d.load(vector_type, mem, indices=[scalar_offset_th])
+            zero_vec = vector_d.broadcast(vector_type, zero)
+            return arith_d.select(mask, result, zero_vec)
 
-            # Unmasked load, using selected_index
-            singlenumvec_type = VectorType.get([1], vector_type.element_type)
-            if is_read:
+        if is_read:
+            elems = []
+            for i in range(elements_per_thread):
+                this_index = extract(selected_index, i)
+                singlenumvec_type = VectorType.get([1], vector_type.element_type)
                 elem = vector_d.load(singlenumvec_type, mem, indices=[this_index])
                 elem = extract(elem, 0)
                 elems.append(elem)
-            else:
-                elem = extract(value, i)
-                single_num_vector = vector_d.broadcast(singlenumvec_type, elem)
-                vector_d.store(single_num_vector, mem, indices=[this_index])
-
-        if is_read:
-            # now make a vector from all the elements loaded
             return vector_d.from_elements(vector_type, elems)
-        else:  # it was a store, return
-            return
+
+        for i in range(elements_per_thread):
+            this_index = extract(selected_index, i)
+            elem = extract(value, i)
+            singlenumvec_type = VectorType.get([1], vector_type.element_type)
+            single_num_vector = vector_d.broadcast(singlenumvec_type, elem)
+            vector_d.store(single_num_vector, mem, indices=[this_index])
+        return
 
     else:
         # normal masked load/store
@@ -699,7 +701,18 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
     )
     dynamic_vals_map_start = _build_dyn_vals_map(mapping, dyn_vals)
 
-    if mapping:
+    is_global_mem = kb_src.type.memory_space is None
+    buffer_ops_enabled = emitter.options.use_buffer_ops and is_global_mem
+
+    precomputed_mask_expr = getattr(node, "precomputed_mask_expr", None)
+    if precomputed_mask_expr is not None:
+        mask = gen_sympy_index(add_emitter_subs(emitter), precomputed_mask_expr)
+        mask_vec_type = VectorType.get(
+            [elements_per_thread], IntegerType.get_signless(1)
+        )
+        if mask.type != mask_vec_type:
+            mask = vector_d.broadcast(mask_vec_type, mask)
+    elif mapping:
         transformed_index = transform_index_on_mapping(
             mapping, input_shape, index, is_read=True
         )
@@ -744,6 +757,7 @@
         get_custom(memory),
         mask,
         node_index=index,
+        use_wide_load_select=precomputed_mask_expr is not None,
     )
     emitter.bind_node_proxy(node, IRProxyValue(result))
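Outside the patch, here is a minimal NumPy sketch of what the two lowering strategies in `_create_vec_read_write` compute; `mem`, `base`, `lane_offsets`, and `oob_index` are illustrative stand-ins, not emitter APIs:

```python
import numpy as np

def wide_load_select(mem, base, mask, ept):
    # New path (use_wide_load_select): one contiguous wide load at the
    # scalar base offset, then a per-lane select that zeroes masked lanes.
    result = mem[base : base + ept]
    return np.where(mask, result, 0)

def unrolled_masked_load(mem, lane_offsets, mask, oob_index):
    # Old path: one single-element load per lane, with out-of-bounds
    # lanes redirected to a known-safe index.
    return np.array(
        [mem[off] if m else mem[oob_index] for off, m in zip(lane_offsets, mask)]
    )
```

The wide form trades `elements_per_thread` scalar loads for a single vector load plus a select against zero; `handle_read` only requests it when a `precomputed_mask_expr` is attached to the node, i.e. for the wide reads manufactured by the coalescing pass below.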
diff --git a/wave_lang/kernel/wave/analysis/partition_strided_operators.py b/wave_lang/kernel/wave/analysis/partition_strided_operators.py
index 18b92c99f..4b47004a8 100644
--- a/wave_lang/kernel/wave/analysis/partition_strided_operators.py
+++ b/wave_lang/kernel/wave/analysis/partition_strided_operators.py
@@ -4,6 +4,7 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+import functools
 from collections.abc import Sequence
 from copy import deepcopy
 from itertools import groupby
@@ -30,6 +31,7 @@
 )
 from ..assumptions import get_divisibility_subs
 from ..constraints import Constraint
+from ..utils.mapping_utils import transform_index_on_mapping
 from ..utils.tag_utils import propagate_tag
 from ..utils.general_utils import (
     all_equal,
@@ -433,8 +435,6 @@ def _get_physical_start(
     coordinates. For identity-mapped reads (mapping=None), reads the
     start offsets directly from the index.
     """
-    from ..utils.mapping_utils import transform_index_on_mapping
-
     if custom.mapping is not None and not custom.has_identity_mapping():
         physical = transform_index_on_mapping(
             custom.mapping, symbolic_shape, custom.index, is_read=True
@@ -447,21 +447,140 @@ def _get_physical_start(
     return {dim: custom.index[dim].start for dim in symbolic_dims}
 
 
-def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool:
-    """Single merge pass: merge adjacent pairs of same-ept reads.
+def _flatten_bounds_to_mask_expr(
+    custom: Read,
+    symbolic_shape: tuple,
+):
+    """Pre-compute a read's bounds check as a flat sympy boolean.
+
+    Replicates the _build_mask_with_mapping decision logic: when the
+    mapping's transformed index contains all bounded dims (and has no
+    dynamic_val_indices), the transformed index is used; otherwise the
+    original logical index is used. Returns a sympy boolean expression
+    (e.g. And(idx < bound, ...)) or None if the read has no bounds.
 
-    Groups reads by (memory operand, ept) and merges pairs whose physical
-    flat offset starts differ by exactly ept. Returns True if any merges
-    happened.
     """
-    from collections import defaultdict
-    from ...compiler.utils import strides_from_symbolic_shape
+    if not custom.bounds:
+        return None
+
+    index = custom.index
+
+    if custom.mapping is not None and not custom.has_identity_mapping():
+        transformed = transform_index_on_mapping(
+            custom.mapping, symbolic_shape, index, is_read=True
+        )
+        use_transformed = (
+            all(dim in transformed for dim in custom.bounds)
+            and not custom.mapping.dynamic_val_indices
+        )
+        if use_transformed:
+            index = transformed
+
+    conditions = []
+    for dim, bound in custom.bounds.items():
+        if dim not in index:
+            continue
+        start = (
+            index[dim].start if isinstance(index[dim], IndexSequence) else index[dim]
+        )
+        start = sympy.sympify(start)
+        bound = sympy.sympify(bound)
+        conditions.append(sympy.StrictLessThan(start, bound))
+
+    if not conditions:
+        return None
+
+    return functools.reduce(sympy.And, conditions)
+
+
+def _build_wide_mask_expr(sub_reads, symbolic_shape, wide_ept):
+    """Build a concatenated sympy mask for a wide read from its sub-reads.
+
+    Each entry in *sub_reads* is ``(offset, size, orig_custom)`` where
+    *offset* is the lane offset within the wide vector and *size* is the
+    number of lanes that sub-read occupies. *wide_ept* is the total
+    number of elements in the wide read (may exceed ``sum(sizes)`` when
+    there are gaps, e.g. multiway coalesce with non-contiguous offsets).
+
+    Builds ``Or(And(lane_in_range_0, mask_0), And(lane_in_range_1, mask_1), ...)``
+    using pure boolean ops so that ``gen_sympy_index`` can lower it without
+    the nested-Piecewise ``select_stack`` ordering issue.
+
+    Returns a sympy boolean expression or ``None`` when no sub-read has
+    bounds.
+    """
     from ..._support.indexing import IndexingContext
 
-    # Group reads by (memory, ept, region). A new region starts at each
-    # subgraph boundary and whenever a side-effecting op (write, barrier, ...)
-    # is encountered, so we never merge reads across such ops. Reads with
-    # dynamic mapping values are skipped to keep the merge logic simple.
+    masks = [
+        (offset, size, _flatten_bounds_to_mask_expr(custom, symbolic_shape))
+        for offset, size, custom in sub_reads
+    ]
+
+    if not any(m is not None for _, _, m in masks):
+        return None
+
+    idxc = IndexingContext.current()
+    iota = idxc.iota(wide_ept)
+
+    terms = []
+    for offset, size, mask in masks:
+        upper = offset + size
+        lane_cond = sympy.And(
+            sympy.GreaterThan(iota, offset) if offset > 0 else sympy.true,
+            sympy.StrictLessThan(iota, upper),
+        )
+        bound_cond = mask if mask is not None else sympy.true
+        terms.append(sympy.And(lane_cond, bound_cond))
+
+    return functools.reduce(sympy.Or, terms)
+
+
+def _emit_wide_read(anchor_custom, wide_index, wide_ept, tag_source, mask_expr=None):
+    """Create a merged Read node covering ``wide_ept`` elements."""
+    wide_read = Read(
+        anchor_custom.memory,
+        elements_per_thread=wide_ept,
+        mapping=None,
+        _write_dependency=anchor_custom._write_dependency,
+        flags=anchor_custom.flags,
+    ).add_to_graph(anchor_custom.graph, loc=anchor_custom.location)
+    wide_custom = get_custom(wide_read)
+    wide_custom.index = wide_index
+    if hasattr(tag_source, "vector_shapes"):
+        wide_read.vector_shapes = deepcopy(tag_source.vector_shapes)
+    if mask_expr is not None:
+        wide_read.precomputed_mask_expr = mask_expr
+    propagate_tag(tag_source, wide_read)
+    return wide_read
+
+
+def _emit_extract_slice(
+    wide_read, offset, size, orig_custom, orig_node, symbolic_shape
+):
+    """Create an ExtractSlice from a wide read and propagate metadata."""
+    extract = ExtractSlice(wide_read, [offset], [size], [1]).add_to_graph(
+        orig_custom.graph, loc=orig_custom.location
+    )
+    extract_custom = get_custom(extract)
+    extract_custom.index = deepcopy(orig_custom.index)
+    if hasattr(orig_node, "vector_shapes"):
+        extract_custom.vector_shapes = deepcopy(orig_node.vector_shapes)
+    propagate_tag(orig_node, extract)
+    return extract
+
+
+def _group_reads_by_memory(
+    trace: CapturedTrace,
+) -> dict[tuple, list[fx.Node]]:
+    """Group reads by (memory, ept, region).
+
+    A new region starts at each subgraph boundary and whenever a
+    side-effecting op (write, barrier, ...) is encountered, so we never
+    merge reads across such ops. Reads with dynamic mapping values are
+    skipped to keep the merge logic simple.
+    """
+    from collections import defaultdict
+
     groups: dict[tuple, list[fx.Node]] = defaultdict(list)
     region_id = 0
     for subgraph in trace.region_graph.subgraphs.values():
@@ -477,13 +596,245 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool:
             continue
         if custom.mapping_dynamic_vals:
             continue
-        # Skip reads that have bounds: the merged read would lose the
-        # mapping and source→target index, making mask generation incorrect.
-        if custom.bounds is not None:
+        if getattr(node, "precomputed_mask_expr", None) is not None:
             continue
         key = (custom.memory, custom.elements_per_thread, region_id)
         groups[key].append(node)
 
+    return groups
+
+
+def _resolve_symbolic_diff(raw_diff, has_complex_mapping, expected_vals=None):
+    """Resolve a raw sympy offset difference to a value, or None.
+
+    Strategy:
+      1. If already a plain int / sympy.Integer, return it directly.
+      2. If the mapping is complex (non-identity), use numeric probing.
+      3. Otherwise try sym_simplify. If ``expected_vals`` is given and
+         the result isn't among them, fall back to numeric probing;
+         return None when neither approach succeeds.
+    """
+    if isinstance(raw_diff, (int, sympy.Integer)):
+        return int(raw_diff)
+    if has_complex_mapping:
+        return _numeric_eval_constant(raw_diff)
+    simplified = sym_simplify(raw_diff)
+    if expected_vals is None or simplified in expected_vals:
+        return simplified
+    nv = _numeric_eval_constant(raw_diff)
+    return nv
+
+
+def _pairwise_merge(read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint):
+    """Merge pairs of reads whose flat offsets differ by exactly ``ept``.
+
+    Returns ``(merged_indices, did_merge)`` where *merged_indices* is
+    the set of read_infos indices that were consumed.
+    """
+    merged = set()
+    did_merge = False
+
+    for i in range(len(read_infos)):
+        if i in merged:
+            continue
+        for j in range(i + 1, len(read_infos)):
+            if j in merged:
+                continue
+            off1, phys1, custom1, node1 = read_infos[i]
+            off2, phys2, custom2, node2 = read_infos[j]
+
+            raw_diff = subs_idxc(off2 - off1)
+            has_complex_mapping = (
+                custom1.mapping is not None and not custom1.has_identity_mapping()
+            )
+            diff = _resolve_symbolic_diff(
+                raw_diff, has_complex_mapping, expected_vals={ept, -ept}
+            )
+            if diff is None:
+                continue
+
+            if diff == ept:
+                lo_phys, hi_phys = phys1, phys2
+                lo_custom, hi_custom = custom1, custom2
+                lo_node, hi_node = node1, node2
+            elif diff == -ept:
+                lo_phys, hi_phys = phys2, phys1
+                lo_custom, hi_custom = custom2, custom1
+                lo_node, hi_node = node2, node1
+            else:
+                continue
+
+            merge_dim = None
+            for dim in symbolic_dims:
+                raw_d = subs_idxc(hi_phys[dim] - lo_phys[dim])
+                d = _resolve_symbolic_diff(
+                    raw_d, has_complex_mapping, expected_vals={ept, 0}
+                )
+                if d is None:
+                    merge_dim = None
+                    break
+                if d == ept:
+                    merge_dim = dim
+                elif d != 0:
+                    merge_dim = None
+                    break
+            if merge_dim is None:
+                continue
+
+            new_ept = 2 * ept
+            element_type = lo_custom.type.dtype
+            if new_ept > hw_constraint.max_elems_per_load(element_type):
+                continue
+
+            wide_mask = _build_wide_mask_expr(
+                [(0, ept, lo_custom), (ept, ept, hi_custom)],
+                symbolic_shape,
+                new_ept,
+            )
+            with lo_custom.graph.inserting_before(lo_node):
+                new_index = {
+                    dim: IndexSequence(
+                        lo_phys[dim],
+                        new_ept if dim == merge_dim else 1,
+                        1,
+                    )
+                    for dim in symbolic_dims
+                }
+                merged_read = _emit_wide_read(
+                    lo_custom, new_index, new_ept, lo_node, mask_expr=wide_mask
+                )
+
+                lo_extract = _emit_extract_slice(
+                    merged_read, 0, ept, lo_custom, lo_node, symbolic_shape
+                )
+                hi_extract = _emit_extract_slice(
+                    merged_read, ept, ept, hi_custom, hi_node, symbolic_shape
+                )
+
+            lo_custom.replace_all_uses_with(lo_extract)
+            hi_custom.replace_all_uses_with(hi_extract)
+            lo_custom.graph.erase_node(lo_node)
+            hi_custom.graph.erase_node(hi_node)
+
+            merged.update({i, j})
+            did_merge = True
+            break
+
+    return merged, did_merge
+
+
+def _multiway_coalesce(
+    read_infos, merged, reads, symbolic_dims, symbolic_shape, hw_constraint
+):
+    """Coalesce unmerged ept==1 reads whose flat offsets fall in an aligned window.
+
+    Groups reads whose numerically-probed flat offsets fall within a
+    power-of-2 aligned window (up to ``max_elems_per_load``), then emits
+    a single wide read with per-byte ExtractSlice ops.
+    """
+    element_type = get_custom(reads[0]).type.dtype
+    max_load_width = hw_constraint.max_elems_per_load(element_type)
+
+    unmerged_infos = [read_infos[k] for k in range(len(read_infos)) if k not in merged]
+    if len(unmerged_infos) < 2:
+        return False
+
+    coalesced_any = False
+    coalesced_set: set[int] = set()
+    for anchor_idx in range(len(unmerged_infos)):
+        if anchor_idx in coalesced_set:
+            continue
+        off_a, phys_a, custom_a, node_a = unmerged_infos[anchor_idx]
+
+        group = [(anchor_idx, node_a, custom_a, 0, phys_a)]
+        for probe_idx in range(len(unmerged_infos)):
+            if probe_idx == anchor_idx or probe_idx in coalesced_set:
+                continue
+            off_p, _, custom_p, node_p = unmerged_infos[probe_idx]
+            raw_diff = subs_idxc(off_p - off_a)
+            diff_val = _numeric_eval_constant(raw_diff)
+            if diff_val is None:
+                continue
+            if 0 < diff_val < max_load_width:
+                _, _, _, phys_p = unmerged_infos[probe_idx]
+                group.append((probe_idx, node_p, custom_p, diff_val, phys_p))
+
+        if len(group) < 2:
+            continue
+
+        group.sort(key=lambda x: x[3])
+        max_off = group[-1][3]
+        wide_ept = 1
+        while wide_ept <= max_off:
+            wide_ept *= 2
+        if wide_ept > max_load_width:
+            continue
+
+        base_phys = group[0][4]
+
+        earliest_node = group[0][1]
+        for g in group[1:]:
+            candidate = g[1]
+            for n in custom_a.graph.nodes:
+                if n is candidate:
+                    earliest_node = candidate
+                    break
+                if n is earliest_node:
+                    break
+
+        wide_mask = _build_wide_mask_expr(
+            [(byte_pos, 1, g_custom) for _, _, g_custom, byte_pos, _ in group],
+            symbolic_shape,
+            wide_ept,
+        )
+        with get_custom(earliest_node).graph.inserting_before(earliest_node):
+            wide_index = {}
+            for dim_idx, dim in enumerate(symbolic_dims):
+                if dim_idx == len(symbolic_dims) - 1:
+                    wide_index[dim] = IndexSequence(base_phys[dim], wide_ept, 1)
+                else:
+                    wide_index[dim] = IndexSequence(base_phys[dim], 1, 1)
+
+            wide_read = _emit_wide_read(
+                custom_a, wide_index, wide_ept, earliest_node, mask_expr=wide_mask
+            )
+
+        extracts = []
+        for g_idx, g_node, g_custom, byte_pos, _ in group:
+            with g_custom.graph.inserting_before(g_node):
+                ext = _emit_extract_slice(
+                    wide_read, byte_pos, 1, g_custom, g_node, symbolic_shape
+                )
+            extracts.append((ext, g_custom, g_node))
+
+        for ext, g_custom, g_node in extracts:
+            g_custom.replace_all_uses_with(ext)
+            g_custom.graph.erase_node(g_node)
+        coalesced_set.update(g[0] for g in group)
+        coalesced_any = True
+
+    return coalesced_any
+
+
+def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool:
+    """Single merge pass: merge reads that access nearby physical memory.
+
+    Two strategies are applied per (memory, ept) group:
+
+    1. **Pairwise contiguous merge** (``_pairwise_merge``): pairs whose
+       physical flat offset starts differ by exactly ``ept`` are merged
+       into a ``2*ept`` read with two ExtractSlice outputs.
+
+    2. **Multi-way coalescing** (``_multiway_coalesce``, ``ept==1`` only):
+       unmerged byte reads whose flat offsets fall within a power-of-2
+       aligned window are replaced by a single wide read with per-byte
+       ExtractSlice outputs.
+
+    Returns True if any merges or coalescing happened.
+    """
+    from ...compiler.utils import strides_from_symbolic_shape
+    from ..._support.indexing import IndexingContext
+
+    groups = _group_reads_by_memory(trace)
 
     idxc = IndexingContext.current()
     merged_any = False
@@ -509,129 +860,17 @@ def _merge_contiguous_reads_once(trace: CapturedTrace, hw_constraint) -> bool:
         )
         read_infos.append((flat_offset, phys_start, custom, node))
 
-        merged = set()
-        for i in range(len(read_infos)):
-            if i in merged:
-                continue
-            for j in range(i + 1, len(read_infos)):
-                if j in merged:
-                    continue
-                off1, phys1, custom1, node1 = read_infos[i]
-                off2, phys2, custom2, node2 = read_infos[j]
-
-                raw_diff = subs_idxc(off2 - off1)
-
-                # For reads with non-identity mappings (e.g. preshuffle
-                # scales), the flat-offset diff contains complex floor/Mod
-                # expressions that sympy.simplify cannot reduce. Use fast
-                # numeric probing instead.
-                has_complex_mapping = (
-                    custom1.mapping is not None and not custom1.has_identity_mapping()
-                )
-
-                # subs_idxc may fully resolve to a plain int.
-                if isinstance(raw_diff, (int, sympy.Integer)):
-                    diff = int(raw_diff)
-                elif has_complex_mapping:
-                    diff = _numeric_eval_constant(raw_diff)
-                    if diff is None:
-                        continue
-                else:
-                    diff = sym_simplify(raw_diff)
-                    if diff != ept and diff != -ept:
-                        nv = _numeric_eval_constant(raw_diff)
-                        if nv is not None:
-                            diff = nv
-
-                if diff == ept:
-                    lo_phys, hi_phys = phys1, phys2
-                    lo_custom, hi_custom = custom1, custom2
-                    lo_node, hi_node = node1, node2
-                elif diff == -ept:
-                    lo_phys, hi_phys = phys2, phys1
-                    lo_custom, hi_custom = custom2, custom1
-                    lo_node, hi_node = node2, node1
-                else:
-                    continue
-
-                # Find dimension that advances by ept.
-                merge_dim = None
-                for dim in symbolic_dims:
-                    raw_d = subs_idxc(hi_phys[dim] - lo_phys[dim])
-                    if isinstance(raw_d, (int, sympy.Integer)):
-                        d = int(raw_d)
-                    elif has_complex_mapping:
-                        d = _numeric_eval_constant(raw_d)
-                        if d is None:
-                            merge_dim = None
-                            break
-                    else:
-                        d = sym_simplify(raw_d)
-                        if d != ept and d != 0:
-                            nv = _numeric_eval_constant(raw_d)
-                            if nv is not None:
-                                d = nv
-                    if d == ept:
-                        merge_dim = dim
-                    elif not (d == 0):
-                        merge_dim = None
-                        break
-                if merge_dim is None:
-                    continue
-
-                # Respect hardware vector width limit.
-                new_ept = 2 * ept
-                element_type = lo_custom.type.dtype
-                if new_ept > hw_constraint.max_elems_per_load(element_type):
-                    continue
-
-                with lo_custom.graph.inserting_before(lo_node):
-                    new_index = {
-                        dim: IndexSequence(
-                            lo_phys[dim],
-                            new_ept if dim == merge_dim else 1,
-                            1,
-                        )
-                        for dim in symbolic_dims
-                    }
-
-                    merged_read = Read(
-                        lo_custom.memory,
-                        elements_per_thread=new_ept,
-                        mapping=None,
-                        _write_dependency=lo_custom._write_dependency,
-                        flags=lo_custom.flags,
-                    ).add_to_graph(lo_custom.graph, loc=lo_custom.location)
-                    merged_custom = get_custom(merged_read)
-                    merged_custom.index = new_index
-                    merged_custom.vector_shapes = deepcopy(lo_custom.vector_shapes)
-                    propagate_tag(lo_node, merged_read)
-
-                    extract0 = ExtractSlice(merged_read, [0], [ept], [1]).add_to_graph(
-                        lo_custom.graph, loc=lo_custom.location
-                    )
-                    get_custom(extract0).index = deepcopy(lo_custom.index)
-                    get_custom(extract0).vector_shapes = deepcopy(
-                        lo_custom.vector_shapes
-                    )
-                    propagate_tag(lo_node, extract0)
-
-                    extract1 = ExtractSlice(
-                        merged_read, [ept], [ept], [1]
-                    ).add_to_graph(lo_custom.graph, loc=lo_custom.location)
-                    get_custom(extract1).index = deepcopy(hi_custom.index)
-                    get_custom(extract1).vector_shapes = deepcopy(
-                        hi_custom.vector_shapes
-                    )
-                    propagate_tag(hi_node, extract1)
-
-                    lo_custom.replace_all_uses_with(extract0)
-                    hi_custom.replace_all_uses_with(extract1)
-                    lo_custom.graph.erase_node(lo_node)
-                    hi_custom.graph.erase_node(hi_node)
+        merged, did_merge = _pairwise_merge(
+            read_infos, ept, symbolic_dims, symbolic_shape, hw_constraint
+        )
+        merged_any |= did_merge
 
-                    merged.update({i, j})
-                    merged_any = True
-                    break
+        # Only ept==1 (byte) reads need multi-way coalescing; wider reads
+        # are already handled by the pairwise merge above.
+        if ept == 1 and len(read_infos) >= 2:
+            merged_any |= _multiway_coalesce(
+                read_infos, merged, reads, symbolic_dims, symbolic_shape, hw_constraint
+            )
 
     return merged_any
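For intuition, a small sympy sketch of the boolean shape `_build_wide_mask_expr` returns. `lane` stands in for the per-lane `idxc.iota(wide_ept)` value and `idx0`, `idx1`, `BOUND` for the flattened per-read bounds, so the symbols here are illustrative only:

```python
import sympy

lane = sympy.Symbol("lane")  # stands in for idxc.iota(wide_ept)
idx0, idx1, BOUND = sympy.symbols("idx0 idx1 BOUND")

# Two byte sub-reads placed at lane offsets 0 and 2 of a 4-wide read:
term0 = sympy.And(
    sympy.StrictLessThan(lane, 1),      # lane in [0, 1)
    sympy.StrictLessThan(idx0, BOUND),  # first sub-read's bounds check
)
term1 = sympy.And(
    sympy.GreaterThan(lane, 2),         # lane in [2, 3)
    sympy.StrictLessThan(lane, 3),
    sympy.StrictLessThan(idx1, BOUND),  # second sub-read's bounds check
)
mask = sympy.Or(term0, term1)           # lanes 1 and 3 are never enabled
```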
diff --git a/wave_lang/kernel/wave/asm/handlers_memory.py b/wave_lang/kernel/wave/asm/handlers_memory.py
index a32394028..e0ee25d3d 100644
--- a/wave_lang/kernel/wave/asm/handlers_memory.py
+++ b/wave_lang/kernel/wave/asm/handlers_memory.py
@@ -5,6 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 from wave_lang.support.ir_imports import (
+    VectorType,
     amdgpu_d,
     gpu_d,
     memref_d,
@@ -79,14 +80,24 @@ def handle_vector_extract_strided_slice_op(
         offset_val = int(str(offsets).split("[")[1].split("]")[0])
         size_val = int(str(sizes).split("[")[1].split("]")[0])
 
+        # Convert MLIR element offsets to physical register indices.
+        # Each 32-bit VGPR holds (32 // elem_bits) elements, e.g. 4 for i8.
+        # An extract at element offset 4 on vector<8xi8> targets register 1,
+        # not register 4.
+        source_vec_type = VectorType(operation.operands[0].type)
+        elem_bits = source_vec_type.element_type.width
+        elems_per_reg = 32 // elem_bits
+
+        reg_offset = offset_val // elems_per_reg
+        reg_count = max(1, (size_val * elem_bits + 31) // 32)
+
         # Extract the appropriate subset of registers
-        if size_val == 1:
-            # Single scalar extract - return just the one register as a tuple
-            extracted_reg = source_regs[offset_val]
-            result_regs = (extracted_reg,)
+        if reg_count == 1:
+            # Single register extract - return just the one register as a tuple
+            result_regs = (source_regs[reg_offset],)
         else:
-            # Multi-element extract - return a slice
-            result_regs = source_regs[offset_val : offset_val + size_val]
+            # Multi-register extract - return a slice
+            result_regs = source_regs[reg_offset : reg_offset + reg_count]
 
         result_ssa = str(operation.result)
         self.walker.kernel_ctx.ssa_to_reg[result_ssa] = result_regs
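A worked instance of the register-index arithmetic above, in plain Python:

```python
elem_bits = 8                    # i8 elements
elems_per_reg = 32 // elem_bits  # 4 i8 elements per 32-bit VGPR

# extract_strided_slice at element offset 4, size 1, on vector<8xi8>:
offset_val, size_val = 4, 1
reg_offset = offset_val // elems_per_reg               # register 1, not 4
reg_count = max(1, (size_val * elem_bits + 31) // 32)  # one register
assert (reg_offset, reg_count) == (1, 1)

# A 16-element i8 extract starting at element 0 spans four VGPRs:
assert max(1, (16 * elem_bits + 31) // 32) == 4
```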
diff --git a/wave_lang/kernel/wave/opsel_scaled_mfma.py b/wave_lang/kernel/wave/opsel_scaled_mfma.py
index 2a4910b93..3880dfba8 100644
--- a/wave_lang/kernel/wave/opsel_scaled_mfma.py
+++ b/wave_lang/kernel/wave/opsel_scaled_mfma.py
@@ -98,6 +98,8 @@ def _trace_scale_chain(scale_value):
         return None
 
     slice_op = bitcast_source.owner
+    if _is_op_named(slice_op, "arith.select"):
+        slice_op = slice_op.operands[1].owner
     if not _is_op_named(slice_op, "vector.extract_strided_slice"):
         return None
@@ -139,9 +141,14 @@ def _trace_extract_strided_slice(
 ) -> Optional[tuple[Value, int]]:
     """Check if *value* is produced by extract_strided_slice of a vector<4xi8>.
 
+    Looks through ``arith.select`` (inserted by flatten-bounds masking)
+    to find the underlying extract_strided_slice.
+
     Returns ``(source_vec4xi8, byte_offset)`` or ``None``.
     """
     op = value.owner
+    if _is_op_named(op, "arith.select"):
+        op = op.operands[1].owner
     if not _is_op_named(op, "vector.extract_strided_slice"):
         return None
     source = op.operands[0]
@@ -165,17 +172,19 @@ def _find_yield_op(for_view) -> Optional[Operation]:
 def _find_mergeable_groups(
     for_view, yield_op: Operation
 ) -> list[tuple[Value, Value, dict[int, int]]]:
-    """Find groups of 4 ``vector<1xi8>`` iter_args that can be coalesced.
+    """Find groups of ``vector<1xi8>`` iter_args that can be coalesced.
 
     A group is valid when:
-    * All 4 init values are ``extract_strided_slice`` at offsets
-      {0, 1, ..., SCALE_VECTOR_WIDTH-1} from the same ``vector<4xi8>``
-      source.
-    * All 4 yield values follow the same pattern from a (possibly
-      different) ``vector<4xi8>`` source.
+    * At least 2 init values are ``extract_strided_slice`` at distinct
+      offsets from the same ``vector<4xi8>`` source.
+    * The corresponding yield values follow the same pattern from a
+      (possibly different) ``vector<4xi8>`` source.
     * For each member, init_offset == yield_offset (byte identity is
       preserved across iterations).
 
+    Partial groups (e.g. only offsets {0, 2}) are accepted — the
+    coalesced ``vector<4xi8>`` iter_arg simply carries unused bytes.
+
     Returns a list of ``(init_source, yield_source, {offset: iter_index})``.
     """
     i8 = IntegerType.get_signless(8)
@@ -199,9 +208,6 @@ def _find_mergeable_groups(
             continue
         eligible.append((i, init_off, init_src, yield_src))
 
-    # Group by init source. Multiple args can share the same source
-    # (e.g. two MFMAs using the same scale load), so partition by offset
-    # to form distinct groups of exactly 4.
     by_init_src = defaultdict(list)
     for entry in eligible:
         _, _, init_src, _ = entry
@@ -214,12 +220,16 @@
         _, off, _, _ = entry
         by_offset[off].append(entry)
 
-    while all(len(by_offset[o]) > 0 for o in range(SCALE_VECTOR_WIDTH)):
+    # Greedily form groups from available offsets. Accept any group
+    # with >= 2 distinct offsets (full groups of 4 are the common
+    # case; partial groups like {0, 2} arise from preshuffle scales).
+    present_offsets = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]]
+    while len(present_offsets) >= 2:
         members = {}
         init_source = None
         yield_owners = set()
         yield_source = None
-        for o in range(SCALE_VECTOR_WIDTH):
+        for o in present_offsets:
             idx, _, isrc, ysrc = by_offset[o].pop(0)
             members[o] = idx
             init_source = isrc
             yield_source = ysrc
             yield_owners.add(id(ysrc.owner))
         if len(yield_owners) == 1:
             result.append((init_source, yield_source, members))
+        present_offsets = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]]
 
     return result
@@ -314,15 +325,13 @@ def _rewire_for_results(
         members = plan.groups[g_idx][2]
         new_idx = plan.group_new_iter_idx[g_idx]
         has_users = any(
-            any(True for _ in old_results[members[o]].uses)
-            for o in range(SCALE_VECTOR_WIDTH)
+            any(True for _ in old_results[members[o]].uses) for o in members
         )
         if not has_users:
             continue
 
         with InsertionPoint(for_op):
-            for o in range(SCALE_VECTOR_WIDTH):
-                old_i = members[o]
+            for o, old_i in members.items():
                 if not any(True for _ in old_results[old_i].uses):
                     continue
                 extract_slice = make_extract_slice(new_results[new_idx], o)
@@ -330,12 +339,15 @@ def _coalesce_vector_iter_args(module: Module) -> None:
-    """Merge groups of 4 ``vector<1xi8>`` scf.for iter_args into ``vector<4xi8>``.
+    """Merge groups of ``vector<1xi8>`` scf.for iter_args into ``vector<4xi8>``.
 
-    Pipeline double-buffering splits a ``vector<4xi8>`` scale load into 4
+    Pipeline double-buffering splits a ``vector<4xi8>`` scale load into
     individual bytes for loop-carry. This pass merges them back so that
     ``_trace_scale_chain`` sees the full ``extract_strided_slice`` pattern
     inside the loop body and the opsel optimisation fires.
+
+    Handles both full groups (all 4 offsets present) and partial groups
+    (e.g. only offsets {0, 2} from preshuffle scales).
     """
     i8 = IntegerType.get_signless(8)
     i64 = IntegerType.get_signless(64)
@@ -364,7 +376,7 @@ def make_extract_slice(source: Value, offset: int):
         if not groups:
             continue
 
-        logger.debug(f"Coalescing {len(groups)} group(s) of 4 vector<1xi8> iter_args")
+        logger.debug(f"Coalescing {len(groups)} group(s) of vector<1xi8> iter_args")
 
         plan = _build_coalesce_plan(groups, for_view, yield_op)
         old_iter_args = list(for_view.inner_iter_args)
@@ -396,8 +408,7 @@ def make_extract_slice(source: Value, offset: int):
         with InsertionPoint(first_op):
             for g_idx, (_, _, members) in enumerate(groups):
                 merged_arg = new_for.inner_iter_args[plan.group_new_iter_idx[g_idx]]
-                for offset in range(SCALE_VECTOR_WIDTH):
-                    iter_idx = members[offset]
+                for offset, iter_idx in members.items():
                     extract_slice = make_extract_slice(merged_arg, offset)
                     extract_results[iter_idx] = extract_slice.result
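A toy model of the greedy grouping above; it drops the same-source and yield-owner checks and only shows how a partial group such as offsets {0, 2} is formed (the data here is made up for illustration):

```python
from collections import defaultdict

SCALE_VECTOR_WIDTH = 4

# (iter_index, byte_offset) pairs for the eligible iter_args.
by_offset = defaultdict(list)
for iter_idx, off in [(0, 0), (1, 2)]:
    by_offset[off].append(iter_idx)

groups = []
present = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]]
while len(present) >= 2:
    groups.append({o: by_offset[o].pop(0) for o in present})
    present = [o for o in range(SCALE_VECTOR_WIDTH) if by_offset[o]]

# One partial group; bytes 1 and 3 of the merged vector<4xi8> go unused.
assert groups == [{0: 0, 2: 1}]
```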
""" i8 = IntegerType.get_signless(8) i64 = IntegerType.get_signless(64) @@ -364,7 +376,7 @@ def make_extract_slice(source: Value, offset: int): if not groups: continue - logger.debug(f"Coalescing {len(groups)} group(s) of 4 vector<1xi8> iter_args") + logger.debug(f"Coalescing {len(groups)} group(s) of vector<1xi8> iter_args") plan = _build_coalesce_plan(groups, for_view, yield_op) old_iter_args = list(for_view.inner_iter_args) @@ -396,8 +408,7 @@ def make_extract_slice(source: Value, offset: int): with InsertionPoint(first_op): for g_idx, (_, _, members) in enumerate(groups): merged_arg = new_for.inner_iter_args[plan.group_new_iter_idx[g_idx]] - for offset in range(SCALE_VECTOR_WIDTH): - iter_idx = members[offset] + for offset, iter_idx in members.items(): extract_slice = make_extract_slice(merged_arg, offset) extract_results[iter_idx] = extract_slice.result diff --git a/wave_lang/kernel/wave/utils/symbol_utils.py b/wave_lang/kernel/wave/utils/symbol_utils.py index 9c1013378..1cb79e1bd 100644 --- a/wave_lang/kernel/wave/utils/symbol_utils.py +++ b/wave_lang/kernel/wave/utils/symbol_utils.py @@ -181,7 +181,7 @@ def transform_mod(expr): return None mult = m if (mult is None) or (m < mult) else mult terms.append(arg) - if c >= mult: + if c is None or mult is None or c >= mult: return None return (sum(terms) % q) + c @@ -409,6 +409,8 @@ def _numeric_eval_constant(expr, num_samples: int = 48): free, evaluator = (), None if not free: + if isinstance(expr, int): + return expr if expr.has(*_BAD_ATOMS): return None if expr.is_integer is not True: