68 changes: 68 additions & 0 deletions docs/per-workgroup-srd-base-adjustment.md
@@ -0,0 +1,68 @@
# Per-Workgroup SRD Base Adjustment for >4GB Output Buffers

## Problem

For large GEMM shapes (e.g., `M=32768, N=57344, K=16384`), the C output matrix `memref<32768x57344xf32>` is ~7GB. This caused two failures:

1. **Assembly error**: `s_mov_b32 srd[2], 0x1C0000000` — the SRD `num_records` field is 32 bits, but the computed buffer size (7,516,192,768 bytes) exceeds 2^32.
2. **Address overflow**: even if `num_records` were clamped, the store voffset `row * 229376 + col * 4` overflows a 32-bit VGPR for workgroups targeting the upper portion of the output matrix.
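To make both failures concrete, a quick arithmetic check (plain Python, not part of the codebase) showing that the buffer size and the worst-case store voffset each exceed 2^32:

```python
M, N = 32768, 57344
elem_bytes = 4  # f32

# SRD num_records is a 32-bit field, but the C buffer is ~7 GB.
buffer_bytes = M * N * elem_bytes
assert buffer_bytes == 0x1C0000000  # 7,516,192,768 bytes
assert buffer_bytes >= 2**32        # does not fit in 32 bits

# Worst-case store voffset (last row, last column) also overflows 32 bits.
row_stride_bytes = N * elem_bytes   # 229376
max_voffset = (M - 1) * row_stride_bytes + (N - 1) * elem_bytes
assert max_voffset >= 2**32
```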

## Fix

Split the store byte offset into a **workgroup base** (folded into the SRD base address via 64-bit SALU) and a **thread offset** (small, used as `voffset`). This matches AITER's per-workgroup SRD pattern.

### Layer 1: MLIR codegen (Python)

**File**: `wave_lang/kernel/compiler/wave_codegen/read_write.py`

The existing `_linearize_memref` function already separates workgroup offsets (from `block_id * tile_size`) into the memref base pointer and returns thread-only offsets for indexing. It was previously gated on `buffer_ops_enabled`.

Change: for global writes without `buffer_ops`, also call `_linearize_memref` (skipping `_cast_buffer_and_encode_stride`). This produces a `memref.reinterpret_cast` with a dynamic per-workgroup element offset and 1D thread-only store indices:

```mlir
%row_off   = arith.muli %block_id_x_times_128, %c57344 : index
%wg_offset = arith.addi %row_off, %block_id_y_times_256 : index
%tile_mem  = memref.reinterpret_cast %c_raw to offset: [%wg_offset], sizes: [536870910], strides: [1]
vector.store %val, %tile_mem[%thread_offset]
```

Thread offsets stay within ~28MB (the 128x256 tile's rows, strided across the full 57344-element N dimension), fitting comfortably in 32 bits.
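A sketch of the base/offset split the linearization performs (plain Python with a hypothetical `split_offset` helper, not the actual `_linearize_memref` code). It checks that the recombined address equals the original linear address and that the thread part fits in 32 bits:

```python
M, N = 32768, 57344
BLOCK_M, BLOCK_N = 128, 256
elem_bytes = 4

def split_offset(block_x, block_y, tid_row, tid_col):
    # Workgroup base (elements), folded into the SRD base address.
    wg_offset = (block_x * BLOCK_M) * N + block_y * BLOCK_N
    # Thread-only offset (elements), used as voffset.
    thread_offset = tid_row * N + tid_col
    return wg_offset, thread_offset

# Last workgroup, last thread in its tile.
wg, th = split_offset(M // BLOCK_M - 1, N // BLOCK_N - 1, BLOCK_M - 1, BLOCK_N - 1)
full_bytes = (wg + th) * elem_bytes
row = (M // BLOCK_M - 1) * BLOCK_M + (BLOCK_M - 1)
col = (N // BLOCK_N - 1) * BLOCK_N + (BLOCK_N - 1)
assert full_bytes == (row * N + col) * elem_bytes  # recombines correctly
assert th * elem_bytes < 2**32                     # thread voffset is 32-bit safe
```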

### Layer 2: C++ backend

Three changes in the WaveASM C++ backend:

**1. Clamp buffer size** (`TranslateFromMLIR.cpp`): In `emitSRDPrologue`, clamp `pending.bufferSize` to `0xFFFFFFFF` before emitting `s_mov_b32` for `num_records`. This is a safety net — the original full-sized `reinterpret_cast` still exists in the MLIR but is unused by stores after linearization.
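The clamp itself is trivial; a minimal sketch (hypothetical names, not the actual `emitSRDPrologue` C++ code):

```python
SRD_NUM_RECORDS_MAX = 0xFFFFFFFF  # num_records is a 32-bit SRD field

def clamp_buffer_size(buffer_size: int) -> int:
    # Safety net only: after linearization, stores index the tile-sized
    # reinterpret_cast, never the full >4GB buffer.
    return min(buffer_size, SRD_NUM_RECORDS_MAX)

assert clamp_buffer_size(7_516_192_768) == 0xFFFFFFFF  # the ~7GB C buffer
assert clamp_buffer_size(4096) == 4096                 # small buffers unchanged
```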

**2. Track pending SRD adjustments** (`MemRefHandlers.cpp`): In `handleMemRefReinterpretCast`, detect dynamic offsets (from `_linearize_memref`) and store the element offset Value, source SRD index, and element byte width in a `PendingSRDBaseAdjust` map. The actual SALU ops are deferred to the store handler to survive DCE.

**3. Emit SRD adjustment inline** (`TranslateFromMLIR.cpp`): In `handleVectorStore`, when a pending adjustment exists for the store target, emit:

```asm
s_mov_b64 s[N:N+1], s[src:src+1] ; copy source SRD base
v_readfirstlane_b32 s[N+3], vOffset ; element offset → SGPR
s_mul_hi_u32 s[N+2], s[N+3], 4 ; byte offset high (for >4GB)
s_mul_i32 s[N+3], s[N+3], 4 ; byte offset low
s_add_u32 s[N], s[N], s[N+3] ; base_lo += byteOffLo (sets SCC)
s_addc_u32 s[N+1], s[N+1], s[N+2] ; base_hi += byteOffHi + carry
s_mov_b32 s[N+2], 0x7FFFFFF8 ; num_records (tile-sized)
s_mov_b32 s[N+3], 0x20000 ; stride descriptor
```

The adjustment uses `PSRegType` (precolored physical SGPRs) for all intermediates, with `s[N+2]` and `s[N+3]` serving as temporaries before being overwritten by `num_records` and `stride`. After the first store emits the adjustment, subsequent stores reuse the adjusted SRD via `setSRDIndex`.

### Layer 3: Dialect changes

**File**: `WaveASMOps.td`

- Added `S_ADDC_U32` (carry-dependent add, reads SCC from preceding `s_add_u32`).
- Made `S_ADD_U32` and `S_ADDC_U32` non-`Pure`. These ops set SCC as a side effect; removing `Pure` prevents the canonicalizer from DCE'ing the SRD adjustment chain (whose PSRegType results have no explicit SSA users — they communicate through physical register aliasing with the later `PrecoloredSRegOp`).

## Files modified

| File | Change |
|------|--------|
| `wave_codegen/read_write.py` | Call `_linearize_memref` for global writes without `buffer_ops` |
| `TranslateFromMLIR.cpp` | Clamp `bufferSize` in `emitSRDPrologue`; emit SRD adjustment in `handleVectorStore` |
| `TranslateFromMLIR.h` | Add `PendingSRDBaseAdjust` struct and tracking methods |
| `handlers/MemRefHandlers.cpp` | Detect dynamic offset in `handleMemRefReinterpretCast`, track for deferred emission |
| `WaveASMOps.td` | Add `S_ADDC_U32`; make `S_ADD_U32`/`S_ADDC_U32` non-Pure |
9 changes: 8 additions & 1 deletion examples/python/7.1_schedule.py
@@ -33,6 +33,7 @@
b_preshuffle,
e8m0_shuffle,
)
import wave_lang.kernel.lang as tkl
from wave_lang.kernel.lang.global_symbols import GLOBAL_ADDRESS_SPACE
from utils import parse_args, list_tests, run_test

@@ -254,8 +255,14 @@ def test_dbuf_4wave_mxfp_asymmetric_gemm_cpp(
def test_dbuf_4wave_mxfp_preshuffle_b_gemm_cpp(
is_debug=False, shape=(1024, 1024, 8192), block=(128, 256, 256)
):
"""Preshuffle-B MXFP4 GEMM using C++ WaveASM backend."""
"""Preshuffle-B MXFP4 GEMM using C++ WaveASM backend with dynamic M/N/K."""
gemm, options = get_tagged_mxfp4_gemm_preshuffle_b(shape, block, wave_shape=(1, 4))
# Make M, N, K dynamic so the compiler does not specialize on problem size.
dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K]
for sym in dynamic_symbols:
del options.subs[sym]
options.dynamic_symbols = dynamic_symbols
options.use_buffer_ops = True
options.backend = "asm"
options.wave_runtime = True
options.use_wave_asm_backend = True
14 changes: 6 additions & 8 deletions tests/kernel/wave_utils_test.py
@@ -11,9 +11,7 @@
delinearize_index,
divide_shape_into_chunks,
)
from wave_lang.kernel.wave.utils.mapping_utils import (
_simplify_sympy_expr,
)
from wave_lang.kernel.wave.utils.symbol_utils import simplify
from wave_lang.kernel.wave.constraints import MMAType
from wave_lang.kernel.wave.compile import WaveCompileOptions, wave_compile
from wave_lang.kernel.wave.templates.gemm import get_gemm_kernel
@@ -111,10 +109,10 @@ def test_divide_shape_into_chunks():
def test_custom_sympy_simplifications():
a = sympy.Symbol("a", integer=True, nonnegative=True)
mod_expr = (sympy.floor(a) * 4 + 3) % 16
assert str(_simplify_sympy_expr(mod_expr)) == "4*(Mod(a, 4)) + 3"
assert str(simplify(mod_expr)) == "4*(Mod(a, 4)) + 3"

floor_expr = sympy.floor(sympy.floor(a) / 3 + sympy.sympify(1) / 6)
assert str(_simplify_sympy_expr(floor_expr)) == "floor(a/3)"
assert str(simplify(floor_expr)) == "floor(a/3)"


@pytest.mark.skip("Too slow")
@@ -139,7 +137,7 @@ def test_fuzz_custom_sympy_simplifications_mod():
expr = expr.subs({a: vals[0], b: vals[1], c: vals[2]})
expr = sympy.simplify(expr)

expr2 = _simplify_sympy_expr(expr)
expr2 = simplify(expr)

if i % 50 == 0 and i > 0:
print(f"{100*i/outer_num_iters}%")
@@ -453,7 +451,7 @@ def check_specific(*vals):
expr1 = orig_expr.subs({a: vals[0], b: vals[1], c: vals[2], d: vals[3]})
expr1 = sympy.simplify(expr1)

expr2 = _simplify_sympy_expr(expr1)
expr2 = simplify(expr1)
assert expr1.subs({x: vals[4]}) == expr2.subs({x: vals[4]})

check_specific(10, 11, 6, 10, 6)
@@ -477,7 +475,7 @@ def check_specific(*vals):
expr = orig_expr.subs({a: vals[0], b: vals[1], c: vals[2], d: vals[3]})
expr = sympy.simplify(expr)

expr2 = _simplify_sympy_expr(expr)
expr2 = simplify(expr)
if expr != expr2:
break

31 changes: 31 additions & 0 deletions tests/unittests/symbol_utils_test.py
@@ -108,6 +108,37 @@ def test_bounds_unsupported_returns_none():
assert expr_bounds(x**2) is None


def test_bounds_ceiling():
x = _sym("x")
inner = sympy.Mod(x, 16, evaluate=False) / 16
# ceiling([0, 15/16]) = (0, 1).
assert expr_bounds(sympy.ceiling(inner)) == (0, 1)


def test_bounds_piecewise():
x = _sym("x")
pw = sympy.Piecewise(
(sympy.Mod(x, 4, evaluate=False), x > 10),
(sympy.Integer(5), True),
)
# Branch 0: [0, 3], branch 1: [5, 5] → envelope [0, 5].
assert expr_bounds(pw) == (0, 5)


def test_bounds_max():
x = _sym("x")
a = sympy.Mod(x, 4, evaluate=False) # [0, 3]
b = sympy.Mod(x, 8, evaluate=False) # [0, 7]
assert expr_bounds(sympy.Max(a, b)) == (0, 7)


def test_bounds_min():
x = _sym("x")
a = sympy.Mod(x, 4, evaluate=False) # [0, 3]
b = sympy.Mod(x, 8, evaluate=False) # [0, 7]
assert expr_bounds(sympy.Min(a, b)) == (0, 3)


# ---- simplify tests ----


37 changes: 33 additions & 4 deletions wave_lang/kernel/compiler/host_codegen.py
@@ -25,7 +25,10 @@
tensor_d,
)

import sympy

from .._support.indexing import IndexSymbol
from ..wave.utils.general_utils import infer_dim
from ...support.location_config import LocationCaptureConfig
from .builder import (
ModuleBuilder,
@@ -150,14 +153,30 @@ def isolated_test_call(
argument_dims = get_dynamic_dims(host_sig.buffer_bindings, dynamic_symbols)

# Map dynamic symbols to buffer argument indices and dimensions.
# For derived shapes like K/2, also store the inverse expression
# so we can recover K from the buffer dimension at runtime.
arg_dim_mapping: dict[IndexSymbol, tuple[int, int]] = {}
# Maps symbol -> sympy expression to recover it from the dim value.
# For direct matches (M in shape[M, ...]) this is just a dummy d.
# For derived (K/2 in shape[M, K/2]) this is e.g. 2*d.
_dim_val = sympy.Symbol("_dim_val")
arg_dim_inverse: dict[IndexSymbol, sympy.Expr] = {}
for arg_idx, b in enumerate(host_sig.buffer_bindings):
shape = b.kernel_buffer_type.symbolic_shape
for dim_idx, dim_symbol in enumerate(shape):
if dim_symbol in arg_dim_mapping:
for dim_idx, dim_expr in enumerate(shape):
base_sym = infer_dim(dim_expr)
if base_sym in arg_dim_mapping:
continue

arg_dim_mapping[dim_symbol] = (arg_idx, dim_idx)
arg_dim_mapping[base_sym] = (arg_idx, dim_idx)
if dim_expr == base_sym:
arg_dim_inverse[base_sym] = _dim_val
else:
# Solve shape_expr = d for the base symbol.
solutions = sympy.solve(dim_expr - _dim_val, base_sym)
assert (
len(solutions) == 1
), f"Cannot solve {dim_expr} = _dim_val for {base_sym}"
arg_dim_inverse[base_sym] = solutions[0]

if async_dispatch:
fence_type = IrType.parse("!hal.fence")
@@ -217,6 +236,8 @@
]

# Get the dynamic symbols values from the buffer dimensions.
# For derived shapes (K/2), apply the inverse expression to
# recover the original symbol value.
dynamic_argument_map: dict[IndexSymbol, Value] = {}
for symbol in dynamic_symbols:
if symbol in arg_dim_mapping:
@@ -338,6 +359,14 @@ def isolated_test_call(
else:
# If no device constraints, just dispatch the kernel directly
# with the provided host signature arguments.
from .wave_codegen.emitter import gen_sympy_index as _gen

def _resolve_dim(expr):
"""Resolve a shape expression to an IR value."""
if expr in dynamic_argument_map:
return dynamic_argument_map[expr]
return _gen(dynamic_argument_map, expr)

out = flow_d.DispatchOp(
memref_to_tensor(output_types),
[dynamic_argument_map[dim] for dim in dynamic_symbols]
1 change: 0 additions & 1 deletion wave_lang/kernel/compiler/wave_codegen/emitter.py
@@ -836,7 +836,6 @@ def _rem(lhs, rhs):
rem_expr(muli_expr(lhs, rhs.denominator), rhs.numerator),
rhs.denominator,
)

return rem_expr(lhs, rhs)

def _floordiv(lhs, rhs):