175 changes: 0 additions & 175 deletions lit_tests/kernel/wave/asm.py
@@ -395,181 +395,6 @@ def mma_multi(
     # CHECK: s_endpgm


-@run_test
-def test_gemm_multi_wave_k_loop():
-    """
-    Test multi-wave GEMM with K-loop (BLOCK_K=64).
-    Uses 4 waves per workgroup (BLOCK_M=32, BLOCK_N=32, WAVE_M=16, WAVE_N=16)
-    with BLOCK_K=64 to test loop generation with chained MFMA accumulators.
-    Verifies:
-    - Loop induction variable initialization
-    - Multiple MFMA instructions with accumulator chaining
-    - Loop increment and branch back
-    """
-    constraints: list[tkw.Constraint] = [
-        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
-        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
-        tkw.TilingConstraint(K, BLOCK_K),
-        tkw.WaveConstraint(M, BLOCK_M // 2),  # 2 waves in M dimension
-        tkw.WaveConstraint(N, BLOCK_N // 2),  # 2 waves in N dimension
-        tkw.HardwareConstraint(
-            threads_per_wave=64,
-            mma_type=tkw.MMAType.F32_16x16x16_F16,
-        ),
-    ]
-
-    @tkw.wave(constraints)
-    def gemm_multi_wave(
-        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
-        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
-        c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
-    ):
-        c_reg = tkl.Register[M, N, tkl.f32](0.0)
-
-        @tkw.iterate(K, init_args=[c_reg])
-        def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
-            a_reg = tkw.read(a)
-            b_reg = tkw.read(b)
-            acc = tkw.mma(a_reg, b_reg, acc)
-            return acc
-
-        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
-
-    compile_options = WaveCompileOptions(
-        subs={
-            M: 64,
-            N: 64,
-            K: 128,
-            BLOCK_M: 32,
-            BLOCK_N: 32,
-            BLOCK_K: 64,
-            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
-            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
-            LOAD_ELEMS_PER_THREAD: 4,
-            STORE_ELEMS_PER_THREAD: 4,
-        },
-        canonicalize=True,
-        compile_to_mlir=True,
-    )
-    compile_options.compile_to_asm = True
-    gemm_multi_wave = wave_compile(compile_options, gemm_multi_wave)
-    print(gemm_multi_wave.asm)
-
-    # CHECK-LABEL: test_gemm_multi_wave_k_loop
-    # CHECK: .protected gemm_multi_wave
-    # CHECK: .amdhsa_kernel gemm_multi_wave
-    # CHECK: .amdhsa_system_vgpr_workitem_id {{[0-9]+}}
-    # CHECK: gemm_multi_wave:
-
-    # Verify loop structure - header with comparison and conditional branch
-    # CHECK: loop_0_header:
-    # CHECK: s_cmp_lt_u32 s{{[0-9]+}}, s{{[0-9]+}}
-    # CHECK: s_cbranch_scc1 loop_0_body
-
-    # Verify loop body has MFMA instructions
-    # CHECK: loop_0_body:
-    # CHECK: v_mfma_f32_16x16x16_f16
-
-    # Verify loop latch - increment and branch back
-    # CHECK: loop_0_latch:
-    # CHECK: s_add_u32 s{{[0-9]+}}, s{{[0-9]+}}, s{{[0-9]+}}
-    # CHECK: s_branch loop_0_header
-
-    # Verify loop exit and result stores
-    # CHECK: loop_0_exit:
-    # CHECK: buffer_store_dword
-    # CHECK: s_endpgm
-
-
-@run_test
-def test_gemm_gather_to_lds():
-    """
-    Test GEMM with gather_to_lds (global_to_shared) enabled.
-    When use_global_to_shared=True, the compiler generates buffer_load_dword...lds
-    instructions that load directly from global memory to LDS, bypassing VGPRs.
-    Verifies:
-    - buffer_load_dword ... lds instructions are emitted
-    - M0 register setup for LDS addressing
-    - Proper barrier synchronization (vmcnt + lgkmcnt + s_barrier)
-    """
-    constraints: list[tkw.Constraint] = [
-        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
-        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
-        tkw.TilingConstraint(K, BLOCK_K),
-        tkw.WaveConstraint(M, BLOCK_M),
-        tkw.WaveConstraint(N, BLOCK_N),
-        tkw.HardwareConstraint(
-            threads_per_wave=64,
-            mma_type=tkw.MMAType.F32_16x16x16_F16,
-        ),
-    ]
-
-    @tkw.wave(constraints)
-    def gemm_g2s(
-        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
-        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
-        c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
-    ):
-        c_reg = tkl.Register[M, N, tkl.f32](0.0)
-
-        @tkw.iterate(K, init_args=[c_reg])
-        def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
-            a_reg = tkw.read(a)
-            b_reg = tkw.read(b)
-            acc = tkw.mma(a_reg, b_reg, acc)
-            return acc
-
-        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
-
-    compile_options = WaveCompileOptions(
-        subs={
-            M: 32,
-            N: 32,
-            K: 32,
-            BLOCK_M: 16,
-            BLOCK_N: 16,
-            BLOCK_K: 16,
-            LOAD_ELEMS_PER_THREAD: 4,
-            STORE_ELEMS_PER_THREAD: 4,
-            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
-            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
-        },
-        canonicalize=True,
-        compile_to_mlir=True,
-        use_global_to_shared=True,  # Enable gather_to_lds
-    )
-    compile_options.compile_to_asm = True
-    gemm_g2s = wave_compile(compile_options, gemm_g2s)
-    print(gemm_g2s.asm)
-
-    # CHECK-LABEL: test_gemm_gather_to_lds
-    # CHECK: .protected gemm_g2s
-    # CHECK: .amdhsa_kernel gemm_g2s
-    # CHECK: gemm_g2s:
-
-    # Verify loop structure for K-loop
-    # CHECK: loop_0_header:
-    # CHECK: s_cmp_lt_u32 s{{[0-9]+}}, s{{[0-9]+}}
-    # CHECK: s_cbranch_scc1 loop_0_body
-
-    # Verify loop body has MFMA instruction
-    # CHECK: loop_0_body:
-    # CHECK: v_mfma_f32_16x16x16_f16
-
-    # Verify loop latch
-    # CHECK: loop_0_latch:
-    # CHECK: s_branch loop_0_header
-
-    # Verify loop exit and stores
-    # CHECK: loop_0_exit:
-    # CHECK: buffer_store_dword
-    # CHECK: s_endpgm
-
-
 @run_test
 def test_cse_intermediate_caching():
     """
10 changes: 5 additions & 5 deletions lit_tests/kernel/wave/gather_to_shared.py
@@ -145,17 +145,17 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     print(gemm.asm)

     # CHECK-LABEL: test_gather_to_shared_wave_tile_aligned_coalescing
-    # CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 16 + s1 * 2 - (s1 floordiv 8) * 16)>
+    # CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 8) * 16)>
     # CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + (s0 floordiv 64) * 8 + (s0 mod 64) floordiv 8 - ((s1 * 16 + (s0 floordiv 64) * 8 + (s0 mod 64) floordiv 8) floordiv 32) * 32)>
     # CHECK: func.func @gemm
-    # CHECK: %[[BLOCK_ID_Y:.+]] = gpu.block_id y
     # CHECK: %[[TIDX:.+]] = gpu.thread_id x
     # CHECK: %[[TIDY:.+]] = gpu.thread_id y
-    # CHECK: %[[WAVE_ALIGNED_OFFSET:.+]] = affine.apply #[[MAP2]]()[%[[TIDX]], %[[TIDY]], %[[BLOCK_ID_Y]]]
+    # CHECK: affine.apply #[[MAP2]]()[%[[TIDX]], %[[TIDY]], %{{.*}}]
+    # CHECK: %[[TH_OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[TIDX]]]
     # CHECK: scf.for %[[IND_VAR:.+]] = %c0
     # CHECK: amdgpu.lds_barrier
-    # CHECK: %[[UPDATE_OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[IND_VAR]], %[[TIDX]]]
-    # CHECK: %[[LHS:.+]] = arith.addi %{{.*}}, %[[UPDATE_OFFSET]]
+    # CHECK: %[[K_STRIDE:.+]] = arith.muli %[[IND_VAR]], %{{.*}}
+    # CHECK: %[[LHS:.+]] = arith.addi %{{.*}}, %[[K_STRIDE]]


 @run_test
48 changes: 21 additions & 27 deletions lit_tests/kernel/wave/gemm.py
@@ -111,7 +111,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)>
-    # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 * 16 + ((s1 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 * 32)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)>
@@ -284,8 +283,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-DAG: #[[MAP_IDX_N:.+]] = affine_map<()[s0, s1, s2, s3] -> (s1 * 32 + s2 * 64 + s0 floordiv 4 - ((s1 * 32 + s0 floordiv 4) floordiv 64) * 64 + ((s2 + s3 * 8) floordiv 32) * 256 - (s2 floordiv 4) * 256)>
     # CHECK-DAG: %[[IDX_M_READ:.+]] = affine.apply #[[MAP_IDX_M]]()[%thread_id_x, %thread_id_y, %block_id_y, %block_id_x]
     # CHECK-DAG: %[[IDX_N_READ:.+]] = affine.apply #[[MAP_IDX_N]]()[%thread_id_x, %thread_id_y, %block_id_x, %block_id_y]
-    # CHECK-DAG: vector.load {{.*}}[%[[IDX_M_READ]], {{.*}}]
-    # CHECK-DAG: vector.load {{.*}}[%[[IDX_N_READ]], {{.*}}]
+    # CHECK-DAG: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
+    # CHECK-DAG: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
     # CHECK: amdgpu.mfma
     # CHECK: vector.store {{.*}} : memref<{{.*}}xf32{{.*}}>, vector<1xf32>

@@ -483,11 +482,11 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
     # CHECK-DAG: %[[C768:.+]] = arith.constant 768 : index
     # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : vector<4xi32>
-    # CHECK-DAG: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [64, 64], strides: [64, 1] : memref<i8> to memref<64x64xi8, strided<[64, 1]>>
-    # CHECK-DAG: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<i8> to memref<128x64xi8, strided<[64, 1]>>
     # CHECK: %[[BASE_ALLOC:.+]] = memref.alloc() : memref<1536xi8, #gpu.address_space<workgroup>>
     # CHECK: %[[ALLOC_0:.+]] = memref.view %[[BASE_ALLOC]][%[[C0]]]
     # CHECK: %[[ALLOC_1:.+]] = memref.view %[[BASE_ALLOC]][%[[C768]]]
+    # CHECK: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref<i8> to memref<{{.*}}>
+    # CHECK: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref<i8> to memref<{{.*}}>
     # CHECK: scf.for %[[IVAR:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[CST]]) -> (vector<4xi32>) {
     # CHECK: %[[REG_0:.+]] = vector.load %[[GLOBAL_0]]
     # CHECK: vector.store %[[REG_0]], %[[ALLOC_1]]
@@ -557,11 +556,11 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
     # CHECK-DAG: %[[C1280:.+]] = arith.constant 1280 : index
     # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : vector<4xi32>
-    # CHECK-DAG: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %0 to offset: [0], sizes: [64, 64], strides: [64, 1] : memref<i8> to memref<64x64xi8, strided<[64, 1]>>
-    # CHECK-DAG: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %1 to offset: [0], sizes: [128, 64], strides: [64, 1] : memref<i8> to memref<128x64xi8, strided<[64, 1]>>
     # CHECK: %[[BASE_ALLOC:.+]] = memref.alloc() : memref<2560xi8, #gpu.address_space<workgroup>>
     # CHECK: %[[ALLOC_0:.+]] = memref.view %[[BASE_ALLOC]][%[[C0]]]
     # CHECK: %[[ALLOC_1:.+]] = memref.view %[[BASE_ALLOC]][%[[C1280]]]
+    # CHECK: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref<i8> to memref<{{.*}}>
+    # CHECK: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref<i8> to memref<{{.*}}>
     # CHECK: scf.for %[[IVAR:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[CST]]) -> (vector<4xi32>) {
     # CHECK: %[[REG_0:.+]] = vector.load %[[GLOBAL_0]]
     # CHECK: vector.store %[[REG_0]], %[[ALLOC_1]]
@@ -644,7 +643,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # %[[IV_K:.+]] = affine.apply #[[MAP_IV_K]]()[%[[IV]], %[[TID_X]]]

     # CHECK-LABEL: test_packed_gemm
-    # CHECK-DAG: #[[MAP_IV_K:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + ((s1 mod 64) floordiv 16) * 2)>
     # CHECK: func.func @packed_gemm
     # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
     # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
@@ -654,11 +652,10 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: %[[RHS_SHARED:.+]] = memref.view %[[ALLOC]][%c0][] : memref<2560xi8, #gpu.address_space<workgroup>> to memref<32x10xi32, #gpu.address_space<workgroup>>
     # CHECK: %[[LHS_SHARED:.+]] = memref.view %[[ALLOC]][%c1280][] : memref<2560xi8, #gpu.address_space<workgroup>> to memref<32x10xi32, #gpu.address_space<workgroup>>
     # CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[C4]] step %[[C1]]
-    # CHECK: %[[IV_K:.+]] = affine.apply #[[MAP_IV_K]]()[%[[IV]], %[[TID_X]]]
-    # CHECK: %[[LHS_REG:.+]] = vector.load %{{.*}}[%{{.*}}, %[[IV_K]]] : memref<64x32xi32, strided<[32, 1]>>, vector<2xi32>
+    # CHECK: %[[LHS_REG:.+]] = vector.load {{.*}} : memref<{{.*}}>, vector<2xi32>
     # CHECK: amdgpu.lds_barrier
     # CHECK: vector.store %[[LHS_REG]], %[[LHS_SHARED]]
-    # CHECK: %[[RHS_REG:.+]] = vector.load %{{.*}}[%{{.*}}, %[[IV_K]]] : memref<128x32xi32, strided<[32, 1]>>, vector<2xi32>
+    # CHECK: %[[RHS_REG:.+]] = vector.load {{.*}} : memref<{{.*}}>, vector<2xi32>
     # CHECK: vector.store %[[RHS_REG]], %[[RHS_SHARED]]
     # CHECK: amdgpu.lds_barrier
     # CHECK-COUNT-2: vector.load {{.*}} : {{.*}}, vector<2xi32>
@@ -728,7 +725,6 @@ def repeat(
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)>
-    # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 * 16 + ((s1 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 * 32)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)>
@@ -1261,11 +1257,11 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: scf.for
     # CHECK-COUNT-1: amdgpu.lds_barrier
     # Steady State Local Read
-    # CHECK-COUNT-4: vector.load %[[ALLOC_0]]
-    # CHECK-COUNT-4: vector.load %[[ALLOC_1]]
+    # CHECK-COUNT-4: vector.load %[[VIEW_0]]
+    # CHECK-COUNT-4: vector.load %[[VIEW_1]]

     # Steady State Global Read
-    # CHECK-COUNT-2: vector.load {{.*}} : memref<128x128xf16, strided<[128, 1]>>, vector<8xf16>
+    # CHECK-COUNT-2: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
     # CHECK-COUNT-2: rocdl.sched.group.barrier

     # Compute
@@ -1278,8 +1274,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: scf.yield

     # Prologue
-    # CHECK-COUNT-4: vector.load %[[ALLOC_0]]
-    # CHECK-COUNT-4: vector.load %[[ALLOC_1]]
+    # CHECK-COUNT-4: vector.load %[[VIEW_0]]
+    # CHECK-COUNT-4: vector.load %[[VIEW_1]]
     # CHECK-COUNT-8: amdgpu.mfma


@@ -1656,7 +1652,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: rocdl.sched.barrier

     # 1st Cluster: Global load LHS
-    # CHECK-COUNT-2: vector.load {{.*}} : memref<4096x4096xf16, strided<[4096, 1]>>, vector<8xf16>
+    # CHECK-COUNT-2: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
     # CHECK: rocdl.sched.barrier

     # 1st Cluster: Second slice of Local read lhs and rhs
@@ -1667,7 +1663,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: rocdl.sched.barrier

     # 1st Cluster: Global load RHS
-    # CHECK-COUNT-4: vector.load {{.*}} : memref<4096x4096xf16, strided<[4096, 1]>>, vector<8xf16>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
     # CHECK: rocdl.s.barrier
     # CHECK: rocdl.sched.barrier

@@ -2214,12 +2210,12 @@ def repeat(
     # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} {
     # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
     # CHECK-DAG: %[[WG_ID2:.*]] = gpu.block_id z
-    # CHECK: %[[LHS_GLOBAL:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [1, 64, 64], strides: [4096, 64, 1] : memref<f16> to memref<1x64x64xf16, strided<[4096, 64, 1]>>
-    # CHECK: %[[RHS_GLOBAL:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [6, 128, 64], strides: [8192, 64, 1] : memref<f16> to memref<6x128x64xf16, strided<[8192, 64, 1]>>
     # CHECK: %[[HKV_IDX:.+]] = affine.apply #[[MAP]]()[%[[WG_ID2]]]
+    # CHECK: %[[GLOBAL_LHS:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref<f16> to memref<{{.*}}>
+    # CHECK: %[[GLOBAL_RHS:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref<f16> to memref<{{.*}}>
     # CHECK: scf.for
-    # CHECK: %[[LHS_READ:.+]] = vector.load %[[LHS_GLOBAL]][%[[HKV_IDX]], %{{.+}}, {{.+}}] : {{.*}}, vector<8xf16>
-    # CHECK: %[[RHS_READ:.+]] = vector.load %[[RHS_GLOBAL]][%[[WG_ID2]], %{{.+}}, {{.+}}] : {{.*}}, vector<8xf16>
+    # CHECK: %[[LHS_READ:.+]] = vector.load %[[GLOBAL_LHS]]
+    # CHECK: %[[RHS_READ:.+]] = vector.load %[[GLOBAL_RHS]]
     # CHECK-COUNT-2: vector.extract_strided_slice
     # CHECK-COUNT-1: amdgpu.mfma
     # CHECK-COUNT-2: vector.extract_strided_slice
@@ -2382,19 +2378,17 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-LABEL: test_explicit_shared_gemm
     # CHECK: func.func @gemm
     # CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %[[ARG1:.*]]: !stream.binding, %[[ARG2:.*]]: !stream.binding)
-    # CHECK-DAG: %[[GLOBAL_A:.+]] = memref.reinterpret_cast %{{.*}} to offset: [{{.*}}], sizes: [64, 64], strides: [64, 1] : memref<f16> to memref<64x64xf16{{.*}}>
-    # CHECK-DAG: %[[GLOBAL_B:.+]] = memref.reinterpret_cast %{{.*}} to offset: [{{.*}}], sizes: [128, 64], strides: [64, 1] : memref<f16> to memref<128x64xf16{{.*}}>
     # Verify explicit shared memory allocations (two separate allocs)
     # CHECK: %[[ALLOC_A:.+]] = memref.alloc() : memref<{{.*}}xf16, #gpu.address_space<workgroup>>
     # CHECK: %[[ALLOC_B:.+]] = memref.alloc() : memref<{{.*}}xf16, #gpu.address_space<workgroup>>
     # CHECK: scf.for
     # Verify load from global memory (A)
-    # CHECK: vector.load %[[GLOBAL_A]]
+    # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<{{.*}}>
     # Verify barrier before shared memory writes
     # CHECK: amdgpu.lds_barrier
     # Verify write to shared memory
     # CHECK: vector.store %{{.*}}, %[[ALLOC_A]]
-    # CHECK: vector.load %[[GLOBAL_B]]
+    # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<{{.*}}>
     # CHECK: vector.store %{{.*}}, %[[ALLOC_B]]
     # Verify barrier before shared memory reads
     # CHECK: amdgpu.lds_barrier