Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def read_dynamic_buffer(a: tkl.Memory[B, M, N, ADDRESS_SPACE, tkl.f16]):

# Gets offset to tensor's base pointer, then set memref_offset = indexing_offset + base_tensor_offset.
# CHECK: %{{.*}}, %[[BASE_TENSOR_OFFSET:.+]], %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[ARG0]] : memref<?x?x16xf16, strided<[?, 16, 1], offset: ?>> -> memref<f16>, index, index, index, index, index, index, index
# CHECK: %[[MEMREF_OFFSET:.+]] = arith.addi %{{.*}}, %[[BASE_TENSOR_OFFSET]] overflow<nsw> : index
# CHECK: %[[MEMREF_OFFSET:.+]] = arith.addi %{{.*}}, %[[BASE_TENSOR_OFFSET]] : index

# CHECK: %[[MEMREF_CAST:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[MEMREF_OFFSET]]], {{.*}}: memref<?x?x16xf16, strided<[?, 16, 1], offset: ?>> to memref<?xf16, strided<[1], offset: ?>>
# CHECK: %[[SWIZZLE_CAST:.*]] = arith.index_cast %c16{{.*}} : index to i14
Expand Down Expand Up @@ -259,8 +259,8 @@ def read_write(
# CHECK: %[[S0:.*]] = memref.reinterpret_cast %[[D0]] to offset: [0], sizes: [16, 16], strides: [16, 1] : memref<f16> to memref<16x16xf16, strided<[16, 1]>>
# CHECK: %[[I0:.*]] = affine.apply #[[MAP0]]()[%[[thread_id_x]]]
# CHECK: %[[V:.*]] = vector.load %[[S0]][%[[I0]], %[[C0]]] : memref<16x16xf16, strided<[16, 1]>>, vector<16xf16>
# CHECK: memref.reinterpret_cast %[[D1]] to offset: [0], sizes: [1073741822], strides: [1] : memref<f16> to memref<1073741822xf16, strided<[1]>>
# CHECK: vector.store %[[V]], {{.*}}[{{.*}}] : memref<1073741822xf16, strided<[1]>>, vector<16xf16>
# CHECK: memref.reinterpret_cast %[[D1]] to offset: [0], sizes: [2147483646], strides: [1] : memref<f16> to memref<2147483646xf16, strided<[1]>>
# CHECK: vector.store %[[V]], {{.*}}[{{.*}}] : memref<2147483646xf16, strided<[1]>>, vector<16xf16>
Comment on lines +262 to +263
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fly-by: where does this number come from? This is an 8GB buffer, whereas it looks like we have M, N = 16, 16, meaning I'd expect to see 256 here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see the diff, it is updated from 1073741822 (f32) to 2147483646 (f16)
that is ((2^32 - 1) // 2) - 1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but why do we use this number? Memref sizes are meaningful for MLIR optimization, you must not have a wrong size.

# CHECK: return


Expand Down
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/dynamic_strides.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ def test_dynamic_strides_gemm():

# Output is linearized using dynamic strides from extract_strided_metadata, then stored to 1D view.
# CHECK: memref.extract_strided_metadata %reinterpret_cast_1 : memref<1024x1024xf32, strided<[?, 1]>>
# CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [536870910], strides: [1]
# CHECK: vector.store {{.*}} %reinterpret_cast_3{{.*}} : memref<536870910xf32, strided<[1], offset: ?>>
# CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [1073741822], strides: [1]
# CHECK: vector.store {{.*}} %reinterpret_cast_3{{.*}} : memref<1073741822xf32, strided<[1], offset: ?>>
18 changes: 9 additions & 9 deletions lit_tests/kernel/wave/scaled_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,10 +583,10 @@ def repeat(
# CHECK-DAG: %[[C512_I14:.+]] = arith.constant 512 : i14

# Prologue Global Read
# CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref<i8> to memref<2147483646xi8, strided<[1], offset: ?>>
# CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [4294967294], strides: [1] : memref<i8> to memref<4294967294xi8, strided<[1], offset: ?>>
# CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) cacheSwizzleStride(%[[C512_I14]]) resetOffset : memref<?xi8, strided<[1], offset: ?>> to memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>
# CHECK-COUNT-4: vector.load {{.*}} : memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<16xi8>
# CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref<i8> to memref<2147483646xi8, strided<[1], offset: ?>>
# CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [4294967294], strides: [1] : memref<i8> to memref<4294967294xi8, strided<[1], offset: ?>>
# CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) cacheSwizzleStride(%[[C32_I14]]) resetOffset : memref<?xi8, strided<[1], offset: ?>> to memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>
# CHECK: vector.load {{.*}} : memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi8>
# CHECK-COUNT-4: vector.load {{.*}} : memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>, vector<16xi8>
Expand Down Expand Up @@ -846,27 +846,27 @@ def repeat(
# CHECK-DAG: #[[MAP17:.*]] = affine_map<()[s0, s1] -> (s1 * 32 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 192) floordiv 256) * 256 + 192)>
# CHECK: func.func @batched_gemm(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: !stream.binding, %arg4: !stream.binding, %arg5: index, %arg6: index) attributes {translation_info = #translation} {
# CHECK-DAG: %[[CST1:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
# CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2147483647> : vector<16xindex>
# CHECK-DAG: %[[CST2:.*]] = arith.constant dense<4294967295> : vector<16xindex>
# CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
# CHECK-DAG: %[[C8192:.*]] = arith.constant 8192 : index
# CHECK-DAG: %[[C2147483646_I64:.*]] = arith.constant 2147483646 : i64
# CHECK-DAG: %[[C2147483646_I64:.*]] = arith.constant 4294967294 : i64
# CHECK-DAG: %[[C_NEG_8192_I14:.*]] = arith.constant -8192 : i14
# CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x
# CHECK-DAG: %[[BLOCK_ID_Z:.*]] = gpu.block_id z
# CHECK-DAG: %[[THREAD_ID_X:.*]] = gpu.thread_id x upper_bound 256
# CHECK-DAG: %[[THREAD_ID_Y:.*]] = gpu.thread_id y upper_bound 2
# CHECK-DAG: %[[AFFINE_APPLY2:.*]] = affine.apply #[[MAP3]]()[%arg6]
# CHECK-DAG: %[[AFFINE_APPLY1:.*]] = affine.apply #[[MAP2]]()[%[[THREAD_ID_X]]]
# CHECK: %[[MUL1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[AFFINE_APPLY2]] overflow<nsw> : index
# CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref<i8> to memref<2147483646xi8, strided<[1], offset: ?>>
# CHECK: %[[CAST:.*]] = memref.cast %[[REINTERPRET_CAST]] : memref<2147483646xi8, strided<[1], offset: ?>> to memref<?xi8, strided<[1], offset: ?>>
# CHECK: %[[MUL1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[AFFINE_APPLY2]] : index
# CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [4294967294], strides: [1] : memref<i8> to memref<4294967294xi8, strided<[1], offset: ?>>
# CHECK: %[[CAST:.*]] = memref.cast %[[REINTERPRET_CAST]] : memref<4294967294xi8, strided<[1], offset: ?>> to memref<?xi8, strided<[1], offset: ?>>
# CHECK: %[[BUFF_CAST:.*]] = amdgpu.fat_raw_buffer_cast %[[CAST]] validBytes(%[[C2147483646_I64]]) cacheSwizzleStride(%[[C_NEG_8192_I14]]) resetOffset : memref<?xi8, strided<[1], offset: ?>> to memref<?xi8, #amdgpu.address_space<fat_raw_buffer>>
# CHECK: %[[AFFINE_APPLY3:.*]] = affine.apply #[[MAP6]]()[%[[THREAD_ID_X]], %[[THREAD_ID_Y]], %[[BLOCK_ID_X]]]
# CHECK: %[[CMP1:.*]] = arith.cmpi slt, %[[AFFINE_APPLY3]], %arg6 : index
# CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[CMP1]] : i1 to vector<16xi1>
# CHECK: %[[AFFINE_APPLY4:.*]] = affine.apply #[[MAP17]]()[%[[THREAD_ID_X]], %[[THREAD_ID_Y]]]
# CHECK: %[[MUL2:.*]] = arith.muli %[[AFFINE_APPLY4]], %[[C8192]] overflow<nsw> : index
# CHECK: %[[ADD1:.*]] = arith.addi %[[MUL2]], %[[AFFINE_APPLY1]] overflow<nsw> : index
# CHECK: %[[MUL2:.*]] = arith.muli %[[AFFINE_APPLY4]], %[[C8192]] : index
# CHECK: %[[ADD1:.*]] = arith.addi %[[MUL2]], %[[AFFINE_APPLY1]] : index
# CHECK: %[[IDX_CAST1:.*]] = arith.index_cast %[[ADD1]] : index to i32
# CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[IDX_CAST1]] : i32 to vector<16xi32>
# CHECK: %[[ADD3:.*]] = arith.addi %[[BROADCAST2]], %[[CST1]] : vector<16xi32>
Expand Down
10 changes: 5 additions & 5 deletions lit_tests/kernel/wave/scaled_mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,17 @@ def scaled_mma(
# CHECK: %[[SCALED_MFMA:.+]] = amdgpu.scaled_mfma 16x16x128 (%[[VECTOR_LOAD_6]][0] * %[[BITCAST_0]]) * (%[[VECTOR_LOAD_4]][0] * %[[BITCAST_1]]) + %[[CST]] : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32>
# CHECK: %[[EXTRACT_STRIDED_SLICE_0:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
# CHECK: %[[AFFINE_APPLY_4:.+]] = affine.apply #[[MAP4]]()[%[[THREAD_ID_X]]]
# CHECK: %[[REINTERPRET_CAST_6:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [536870910], strides: [1] : memref<f32> to memref<536870910xf32, strided<[1]>>
# CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_0]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32>
# CHECK: %[[REINTERPRET_CAST_6:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [1073741822], strides: [1] : memref<f32> to memref<1073741822xf32, strided<[1]>>
# CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_0]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32>
# CHECK: %[[EXTRACT_STRIDED_SLICE_1:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
# CHECK: %[[AFFINE_APPLY_5:.+]] = affine.apply #[[MAP5]]()[%[[THREAD_ID_X]]]
# CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_1]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32>
# CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_1]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32>
# CHECK: %[[EXTRACT_STRIDED_SLICE_2:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
# CHECK: %[[AFFINE_APPLY_6:.+]] = affine.apply #[[MAP6]]()[%[[THREAD_ID_X]]]
# CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_2]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32>
# CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_2]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32>
# CHECK: %[[EXTRACT_STRIDED_SLICE_3:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32>
# CHECK: %[[AFFINE_APPLY_7:.+]] = affine.apply #[[MAP7]]()[%[[THREAD_ID_X]]]
# CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_3]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32>
# CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_3]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32>
# CHECK: return


Expand Down
8 changes: 4 additions & 4 deletions lit_tests/kernel/wave/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def topk(
# CHECK: vector.from_elements{{.*}} : vector<2xi32>

# Write operations for both values and indices (linearized 1D stores)
# CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [1073741822], strides: [1] : memref<f16>
# CHECK: vector.store {{.*}} : memref<1073741822xf16{{.*}}>, vector<2xf16>
# CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [536870910], strides: [1] : memref<i32>
# CHECK: vector.store {{.*}} : memref<536870910xi32{{.*}}>, vector<2xi32>
# CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [2147483646], strides: [1] : memref<f16>
# CHECK: vector.store {{.*}} : memref<2147483646xf16{{.*}}>, vector<2xf16>
# CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [1073741822], strides: [1] : memref<i32>
# CHECK: vector.store {{.*}} : memref<1073741822xi32{{.*}}>, vector<2xi32>
8 changes: 5 additions & 3 deletions tests/kernel/wave_gemm_mxfp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,11 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path):
"s_waitcnt vmcnt(0)",
"s_waitcnt vmcnt(0) lgkmcnt(0)",
"s_waitcnt vmcnt(0)",
"s_waitcnt lgkmcnt(14)",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is... a lot.

Copy link
Contributor Author

@xintin xintin Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's what error message stated.
If the error message is wrong, then that should be corrected or guarded too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What error message? This is a test looking for the presence of exact strings. It likely told you that a new string is present. But we need to understnad what that means, in particular we are adding a lot waits here, which will decrease performance.

"s_waitcnt lgkmcnt(7)",
"s_waitcnt lgkmcnt(5)",
"s_waitcnt lgkmcnt(4)",
"s_waitcnt lgkmcnt(3)",
"s_waitcnt lgkmcnt(6)",
"s_waitcnt lgkmcnt(1)",
"s_waitcnt lgkmcnt(2)",
"s_waitcnt lgkmcnt(1)",
"s_waitcnt lgkmcnt(0)",
]
Expand Down Expand Up @@ -834,6 +835,7 @@ def testScaledGemmMXFP4PreshuffleB(
(64, 128, 256),
(64, 128, 128),
(64, 64, 128),
(64, 64, 256),
(32, 192, 256),
(32, 128, 256),
(32, 64, 256),
Expand Down
8 changes: 4 additions & 4 deletions wave_lang/kernel/compiler/wave_codegen/read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def _get_max_buffer_size(elem_type: IrType) -> int:
Buffer ops offsets are i32, return maximum memref size in elements.
"""
return ((1 << 31) - 1) // (elem_type.width // 8)
return ((1 << 32) - 1) // (elem_type.width // 8)


def _get_strides_from_memref(mem: Value) -> list[Value]:
Expand Down Expand Up @@ -241,7 +241,7 @@ def _linearize_memref(
memref_type = mem.type
offset = None
offset_th = None
overflow_flags = arith_d.IntegerOverflowFlags.nsw
overflow_flags = arith_d.IntegerOverflowFlags.none
for ind_wg, ind_th, stride in zip(offsets_wg, offsets_th, strides):
if isinstance(ind_wg, int):
ind_wg = arith_d.constant(IndexType.get(), ind_wg)
Expand Down Expand Up @@ -342,7 +342,7 @@ def _valid_bytes_buffer(elem_type: IrType) -> int:
"""
Make valid bytes to be the address of the last byte of the second to last element that can fit in a 32 bit offset to memory address
"""
ans = (1 << 31) - 1 - (elem_type.width // 8)
ans = (1 << 32) - 1 - (elem_type.width // 8)

assert isinstance(ans, int)
return ans
Expand All @@ -359,7 +359,7 @@ def _get_out_of_bounds_index(element_type: IrType) -> int:
assert (oob_index_value * element_width_in_bytes) > _valid_bytes_buffer(
element_type
)
assert (oob_index_value * element_width_in_bytes) < (1 << 31)
assert (oob_index_value * element_width_in_bytes) < (1 << 32)
return oob_index_value


Expand Down
Loading