67 changes: 61 additions & 6 deletions lit_tests/kernel/wave/scaled_gemm.py
@@ -10,6 +10,7 @@
    ScaledMMAType,
)
from wave_lang.kernel.wave.scheduling.schedule_enums import SchedulingType
+from wave_lang.kernel.wave.templates import get_tagged_mxfp4_gemm_preshuffle_b
from wave_lang.kernel.wave.templates.test_kernels import (
    get_broadcasted_scale_gemm_mxfp4,
)
@@ -361,9 +362,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK: }

# Epilogue Local Read
-# CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<1xi8>
-# CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<8xi8>
-# CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>
+# CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<1xi8>
+# CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<8xi8>
+# CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>

# Epilogue MFMA
@@ -471,8 +472,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK: rocdl.s.waitcnt
# CHECK: amdgpu.lds_barrier

-# Steady state local loads
-# CHECK-COUNT-48: vector.load{{.*}} memref<{{.*}}, #gpu.address_space<workgroup>>
+# Steady state local loads (8+4 scale loads as vector<8xi8> + 16+8 data loads as vector<16xi8>)
+# CHECK-COUNT-36: vector.load{{.*}} memref<{{.*}}, #gpu.address_space<workgroup>>

# Steady State global load to lds
# CHECK-COUNT-34: amdgpu.gather_to_lds
@@ -637,9 +638,9 @@ def repeat(
# CHECK: }

# Epilogue Local Read
-# CHECK-COUNT-16: vector.load {{.*}} : memref<4096xi8, #gpu.address_space<workgroup>>, vector<1xi8>
-# CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space<workgroup>>, vector<8xi8>
-# CHECK-COUNT-16: vector.load {{.*}} : memref<34816xi8, #gpu.address_space<workgroup>>, vector<16xi8>
+# CHECK-COUNT-8: vector.load {{.*}} : memref<4096xi8, #gpu.address_space<workgroup>>, vector<1xi8>
+# CHECK-COUNT-4: vector.load {{.*}} : memref<4096xi8, #gpu.address_space<workgroup>>, vector<8xi8>
+# CHECK-COUNT-8: vector.load {{.*}} : memref<34816xi8, #gpu.address_space<workgroup>>, vector<16xi8>

# Epilogue MFMA
@@ -997,3 +998,57 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# Unmasked vector stores for output.
# CHECK: vector.store
# CHECK: return


@run_test
def test_dynamic_preshuffle_b_scale_coalescing():
    """Verify B-scale reads coalesce into clean vector<16xi8> with dynamic dims.

    Uses the preshuffle-B MXFP4 template with dynamic M, N, K and small
    block sizes. The K % 256 divisibility assumption lets the coalescer
    apply divisibility substitutions during numeric probing, so the 2D
    decomposition (row = offset floordiv K/2, col = offset mod K/2) gives
    consistent per-dim diffs across probe sets. Without this fix, probes
    like K=137 make K/2=68, causing inconsistent row/col diffs and
    fragmenting 16-byte scale reads into {2, 16, 8, 4} loads glued by
    vector.from_elements.
    """
    shape = (256, 256, 256)
    block = (128, 128, 256)
    kernel, options = get_tagged_mxfp4_gemm_preshuffle_b(
        shape,
        block,
        wave_shape=(2, 2),
        mfma_variant=ScaledMMAType.F32_16x16x128_F8F6F4,
        reorder_workgroups=False,
    )
    dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K]
    for sym in dynamic_symbols:
        del options.subs[sym]
    options.dynamic_symbols = dynamic_symbols
    options.schedule = SchedulingType.NONE
    options.use_buffer_ops = True
    options.compile_to_mlir = True
    options.device = "hip"
    options.target = "gfx950"
    result = wave_compile(options, kernel)
    print(result.asm)

    # CHECK-LABEL: test_dynamic_preshuffle_b_scale_coalescing

    # Dynamic index arguments for M, N, K.
    # CHECK: func.func @gemm(%arg0: {{.*}}, %arg1: {{.*}}, %arg2: {{.*}}, %arg3: {{.*}}, %arg4: {{.*}}, %arg5: index, %arg6: index, %arg7: index)

    # Buffer ops: fat_raw_buffer_cast for global buffers.
    # CHECK: amdgpu.fat_raw_buffer_cast

    # B-scale reads are clean vector<16xi8> from fat_raw_buffer — no
    # fragmentation into mixed-width loads glued by from_elements.
    # A-scale reads are vector<4xi8>.
    # CHECK: scf.for
    # CHECK-COUNT-8: vector.load %{{.*}} : memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<16xi8>
    # CHECK-COUNT-2: vector.load %{{.*}} : memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi8>
    # CHECK: amdgpu.scaled_mfma

    # No byte-level reassembly — coalescing succeeded.
    # CHECK-NOT: vector.from_elements
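
# Why the K % 256 assumption matters during offset probing: a minimal,
# runnable sketch of the idea described in the docstring above. This is
# not the wave_lang coalescer itself; `off` and the probe values are
# illustrative. `off(lane, K)` models the symbolic byte offset of 16
# consecutive B-scale bytes starting 60 bytes into a row of a K/2-wide view.
def off(lane, K):
    return K // 2 + 60 + lane


def row_col_diffs(K):
    # Decompose each offset as (row, col) = divmod(offset, K // 2) and
    # collect the per-dimension diffs between adjacent lanes.
    rc = [divmod(off(lane, K), K // 2) for lane in range(16)]
    return {(r2 - r1, c2 - c1) for (r1, c1), (r2, c2) in zip(rc, rc[1:])}


# Probes honoring K % 256 == 0 all agree on a single diff of (0, 1): the
# 16 bytes are row-contiguous, so the read coalesces into one vector<16xi8>.
print(row_col_diffs(256), row_col_diffs(512))  # {(0, 1)} {(0, 1)}

# A probe like K=137 (K // 2 == 68) crosses a row boundary mid-read, so the
# diffs disagree and the coalescer would fall back to fragmented loads
# stitched together with vector.from_elements.
print(row_col_diffs(137))  # {(0, 1), (1, -67)}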
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/scaled_mma.py
@@ -352,9 +352,9 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
# CHECK-COUNT-4: vector.load {{.*}} : memref<16384x8192xi8, strided<[8192, 1]>>, vector<16xi8>
# CHECK-COUNT-1: vector.load {{.*}} : memref<16384x512xi8, strided<[512, 1]>>, vector<4xi8>
# CHECK: amdgpu.lds_barrier
-# CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<1xi8>
-# CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<8xi8>
-# CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>
+# CHECK-COUNT-8: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<1xi8>
+# CHECK-COUNT-4: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<8xi8>
+# CHECK-COUNT-8: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>
# CHECK-COUNT-8: vector.bitcast {{.*}} : vector<16xi8> to vector<32xf4E2M1FN>
# CHECK-COUNT-8: vector.bitcast {{.*}} : vector<1xi8> to vector<1xf8E8M0FNU>
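
# The bitcast pairs above encode the MXFP4 layout: each i8 packs two
# f4E2M1FN values (hence vector<16xi8> -> vector<32xf4E2M1FN>), and each
# i8 scale is an f8E8M0FNU power of two shared by a block of elements.
# A hedged sketch of that arithmetic, not wave_lang's decoder; the
# low-nibble-first ordering here is an assumption.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_f4(nibble: int) -> float:
    # f4E2M1FN: 1 sign bit, 2 exponent bits, 1 mantissa bit, no inf/nan.
    sign = -1.0 if nibble & 0x8 else 1.0
    return sign * E2M1_MAGNITUDES[nibble & 0x7]


def decode_mxfp4_block(data16: bytes, scale_byte: int) -> list[float]:
    # 16 packed bytes -> 32 scaled values; f8E8M0FNU uses exponent bias 127.
    scale = 2.0 ** (scale_byte - 127)
    out = []
    for byte in data16:
        out.append(decode_f4(byte & 0xF) * scale)  # low nibble (assumed first)
        out.append(decode_f4(byte >> 4) * scale)   # high nibble
    return out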