diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index 96271b727..9fe225383 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -220,7 +220,7 @@ def read_dynamic_buffer(a: tkl.Memory[B, M, N, ADDRESS_SPACE, tkl.f16]): # Gets offset to tensor's base pointer, then set memref_offset = indexing_offset + base_tensor_offset. # CHECK: %{{.*}}, %[[BASE_TENSOR_OFFSET:.+]], %{{.*}}, %{{.*}} = memref.extract_strided_metadata %[[ARG0]] : memref> -> memref, index, index, index, index, index, index, index - # CHECK: %[[MEMREF_OFFSET:.+]] = arith.addi %{{.*}}, %[[BASE_TENSOR_OFFSET]] overflow : index + # CHECK: %[[MEMREF_OFFSET:.+]] = arith.addi %{{.*}}, %[[BASE_TENSOR_OFFSET]] : index # CHECK: %[[MEMREF_CAST:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [%[[MEMREF_OFFSET]]], {{.*}}: memref> to memref> # CHECK: %[[SWIZZLE_CAST:.*]] = arith.index_cast %c16{{.*}} : index to i14 @@ -259,8 +259,8 @@ def read_write( # CHECK: %[[S0:.*]] = memref.reinterpret_cast %[[D0]] to offset: [0], sizes: [16, 16], strides: [16, 1] : memref to memref<16x16xf16, strided<[16, 1]>> # CHECK: %[[I0:.*]] = affine.apply #[[MAP0]]()[%[[thread_id_x]]] # CHECK: %[[V:.*]] = vector.load %[[S0]][%[[I0]], %[[C0]]] : memref<16x16xf16, strided<[16, 1]>>, vector<16xf16> - # CHECK: memref.reinterpret_cast %[[D1]] to offset: [0], sizes: [1073741822], strides: [1] : memref to memref<1073741822xf16, strided<[1]>> - # CHECK: vector.store %[[V]], {{.*}}[{{.*}}] : memref<1073741822xf16, strided<[1]>>, vector<16xf16> + # CHECK: memref.reinterpret_cast %[[D1]] to offset: [0], sizes: [2147483646], strides: [1] : memref to memref<2147483646xf16, strided<[1]>> + # CHECK: vector.store %[[V]], {{.*}}[{{.*}}] : memref<2147483646xf16, strided<[1]>>, vector<16xf16> # CHECK: return diff --git a/lit_tests/kernel/wave/dynamic_strides.py b/lit_tests/kernel/wave/dynamic_strides.py index fe70fdc5a..1b9a4b796 100644 --- a/lit_tests/kernel/wave/dynamic_strides.py +++ b/lit_tests/kernel/wave/dynamic_strides.py @@ -50,5 +50,5 @@ def test_dynamic_strides_gemm(): # Output is linearized using dynamic strides from extract_strided_metadata, then stored to 1D view. # CHECK: memref.extract_strided_metadata %reinterpret_cast_1 : memref<1024x1024xf32, strided<[?, 1]>> - # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [536870910], strides: [1] - # CHECK: vector.store {{.*}} %reinterpret_cast_3{{.*}} : memref<536870910xf32, strided<[1], offset: ?>> + # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [1073741822], strides: [1] + # CHECK: vector.store {{.*}} %reinterpret_cast_3{{.*}} : memref<1073741822xf32, strided<[1], offset: ?>> diff --git a/lit_tests/kernel/wave/scaled_gemm.py b/lit_tests/kernel/wave/scaled_gemm.py index 7bf3b8b17..18dbccfae 100644 --- a/lit_tests/kernel/wave/scaled_gemm.py +++ b/lit_tests/kernel/wave/scaled_gemm.py @@ -583,10 +583,10 @@ def repeat( # CHECK-DAG: %[[C512_I14:.+]] = arith.constant 512 : i14 # Prologue Global Read - # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1], offset: ?>> + # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [4294967294], strides: [1] : memref to memref<4294967294xi8, strided<[1], offset: ?>> # CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) cacheSwizzleStride(%[[C512_I14]]) resetOffset : memref> to memref> # CHECK-COUNT-4: vector.load {{.*}} : memref>, vector<16xi8> - # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1], offset: ?>> + # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [4294967294], strides: [1] : memref to memref<4294967294xi8, strided<[1], offset: ?>> # CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) cacheSwizzleStride(%[[C32_I14]]) resetOffset : memref> to memref> # CHECK: vector.load {{.*}} : memref>, vector<4xi8> # CHECK-COUNT-4: vector.load {{.*}} : memref>, vector<16xi8> @@ -846,10 +846,10 @@ def repeat( # CHECK-DAG: #[[MAP17:.*]] = affine_map<()[s0, s1] -> (s1 * 32 + s0 floordiv 8 - ((s1 * 32 + s0 floordiv 8 + 192) floordiv 256) * 256 + 192)> # CHECK: func.func @batched_gemm(%arg0: !stream.binding, %arg1: !stream.binding, %arg2: !stream.binding, %arg3: !stream.binding, %arg4: !stream.binding, %arg5: index, %arg6: index) attributes {translation_info = #translation} { # CHECK-DAG: %[[CST1:.*]] = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32> - # CHECK-DAG: %[[CST2:.*]] = arith.constant dense<2147483647> : vector<16xindex> + # CHECK-DAG: %[[CST2:.*]] = arith.constant dense<4294967295> : vector<16xindex> # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index # CHECK-DAG: %[[C8192:.*]] = arith.constant 8192 : index - # CHECK-DAG: %[[C2147483646_I64:.*]] = arith.constant 2147483646 : i64 + # CHECK-DAG: %[[C2147483646_I64:.*]] = arith.constant 4294967294 : i64 # CHECK-DAG: %[[C_NEG_8192_I14:.*]] = arith.constant -8192 : i14 # CHECK-DAG: %[[BLOCK_ID_X:.*]] = gpu.block_id x # CHECK-DAG: %[[BLOCK_ID_Z:.*]] = gpu.block_id z @@ -857,16 +857,16 @@ def repeat( # CHECK-DAG: %[[THREAD_ID_Y:.*]] = gpu.thread_id y upper_bound 2 # CHECK-DAG: %[[AFFINE_APPLY2:.*]] = affine.apply #[[MAP3]]()[%arg6] # CHECK-DAG: %[[AFFINE_APPLY1:.*]] = affine.apply #[[MAP2]]()[%[[THREAD_ID_X]]] - # CHECK: %[[MUL1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[AFFINE_APPLY2]] overflow : index - # CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1], offset: ?>> - # CHECK: %[[CAST:.*]] = memref.cast %[[REINTERPRET_CAST]] : memref<2147483646xi8, strided<[1], offset: ?>> to memref> + # CHECK: %[[MUL1:.*]] = arith.muli %[[BLOCK_ID_Z]], %[[AFFINE_APPLY2]] : index + # CHECK: %[[REINTERPRET_CAST:.*]] = memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [4294967294], strides: [1] : memref to memref<4294967294xi8, strided<[1], offset: ?>> + # CHECK: %[[CAST:.*]] = memref.cast %[[REINTERPRET_CAST]] : memref<4294967294xi8, strided<[1], offset: ?>> to memref> # CHECK: %[[BUFF_CAST:.*]] = amdgpu.fat_raw_buffer_cast %[[CAST]] validBytes(%[[C2147483646_I64]]) cacheSwizzleStride(%[[C_NEG_8192_I14]]) resetOffset : memref> to memref> # CHECK: %[[AFFINE_APPLY3:.*]] = affine.apply #[[MAP6]]()[%[[THREAD_ID_X]], %[[THREAD_ID_Y]], %[[BLOCK_ID_X]]] # CHECK: %[[CMP1:.*]] = arith.cmpi slt, %[[AFFINE_APPLY3]], %arg6 : index # CHECK: %[[BROADCAST1:.*]] = vector.broadcast %[[CMP1]] : i1 to vector<16xi1> # CHECK: %[[AFFINE_APPLY4:.*]] = affine.apply #[[MAP17]]()[%[[THREAD_ID_X]], %[[THREAD_ID_Y]]] - # CHECK: %[[MUL2:.*]] = arith.muli %[[AFFINE_APPLY4]], %[[C8192]] overflow : index - # CHECK: %[[ADD1:.*]] = arith.addi %[[MUL2]], %[[AFFINE_APPLY1]] overflow : index + # CHECK: %[[MUL2:.*]] = arith.muli %[[AFFINE_APPLY4]], %[[C8192]] : index + # CHECK: %[[ADD1:.*]] = arith.addi %[[MUL2]], %[[AFFINE_APPLY1]] : index # CHECK: %[[IDX_CAST1:.*]] = arith.index_cast %[[ADD1]] : index to i32 # CHECK: %[[BROADCAST2:.*]] = vector.broadcast %[[IDX_CAST1]] : i32 to vector<16xi32> # CHECK: %[[ADD3:.*]] = arith.addi %[[BROADCAST2]], %[[CST1]] : vector<16xi32> diff --git a/lit_tests/kernel/wave/scaled_mma.py b/lit_tests/kernel/wave/scaled_mma.py index f37dae9b5..39f02f906 100644 --- a/lit_tests/kernel/wave/scaled_mma.py +++ b/lit_tests/kernel/wave/scaled_mma.py @@ -121,17 +121,17 @@ def scaled_mma( # CHECK: %[[SCALED_MFMA:.+]] = amdgpu.scaled_mfma 16x16x128 (%[[VECTOR_LOAD_6]][0] * %[[BITCAST_0]]) * (%[[VECTOR_LOAD_4]][0] * %[[BITCAST_1]]) + %[[CST]] : f8E8M0FNU, vector<32xf4E2M1FN>, f8E8M0FNU, vector<32xf4E2M1FN>, vector<4xf32> # CHECK: %[[EXTRACT_STRIDED_SLICE_0:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [0], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> # CHECK: %[[AFFINE_APPLY_4:.+]] = affine.apply #[[MAP4]]()[%[[THREAD_ID_X]]] - # CHECK: %[[REINTERPRET_CAST_6:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [536870910], strides: [1] : memref to memref<536870910xf32, strided<[1]>> - # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_0]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32> + # CHECK: %[[REINTERPRET_CAST_6:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [1073741822], strides: [1] : memref to memref<1073741822xf32, strided<[1]>> + # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_0]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32> # CHECK: %[[EXTRACT_STRIDED_SLICE_1:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [1], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> # CHECK: %[[AFFINE_APPLY_5:.+]] = affine.apply #[[MAP5]]()[%[[THREAD_ID_X]]] - # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_1]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32> + # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_1]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32> # CHECK: %[[EXTRACT_STRIDED_SLICE_2:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [2], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> # CHECK: %[[AFFINE_APPLY_6:.+]] = affine.apply #[[MAP6]]()[%[[THREAD_ID_X]]] - # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_2]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32> + # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_2]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32> # CHECK: %[[EXTRACT_STRIDED_SLICE_3:.+]] = vector.extract_strided_slice %[[SCALED_MFMA]] {offsets = [3], sizes = [1], strides = [1]} : vector<4xf32> to vector<1xf32> # CHECK: %[[AFFINE_APPLY_7:.+]] = affine.apply #[[MAP7]]()[%[[THREAD_ID_X]]] - # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_3]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<536870910xf32, strided<[1]>>, vector<1xf32> + # CHECK: vector.store %[[EXTRACT_STRIDED_SLICE_3]], %[[REINTERPRET_CAST_6]][%{{.*}}] : memref<1073741822xf32, strided<[1]>>, vector<1xf32> # CHECK: return diff --git a/lit_tests/kernel/wave/topk.py b/lit_tests/kernel/wave/topk.py index 8c6bdac9d..feb679c1e 100644 --- a/lit_tests/kernel/wave/topk.py +++ b/lit_tests/kernel/wave/topk.py @@ -102,7 +102,7 @@ def topk( # CHECK: vector.from_elements{{.*}} : vector<2xi32> # Write operations for both values and indices (linearized 1D stores) - # CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [1073741822], strides: [1] : memref - # CHECK: vector.store {{.*}} : memref<1073741822xf16{{.*}}>, vector<2xf16> - # CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [536870910], strides: [1] : memref - # CHECK: vector.store {{.*}} : memref<536870910xi32{{.*}}>, vector<2xi32> + # CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [2147483646], strides: [1] : memref + # CHECK: vector.store {{.*}} : memref<2147483646xf16{{.*}}>, vector<2xf16> + # CHECK: memref.reinterpret_cast {{.*}} to offset: [{{.*}}], sizes: [1073741822], strides: [1] : memref + # CHECK: vector.store {{.*}} : memref<1073741822xi32{{.*}}>, vector<2xi32> diff --git a/tests/kernel/wave_gemm_mxfp_test.py b/tests/kernel/wave_gemm_mxfp_test.py index 1eadcfe49..96ef65070 100644 --- a/tests/kernel/wave_gemm_mxfp_test.py +++ b/tests/kernel/wave_gemm_mxfp_test.py @@ -366,10 +366,11 @@ def testScaledBatchedGemmMXFP4Codegen(use_water_backend: bool, tmp_path: Path): "s_waitcnt vmcnt(0)", "s_waitcnt vmcnt(0) lgkmcnt(0)", "s_waitcnt vmcnt(0)", + "s_waitcnt lgkmcnt(14)", "s_waitcnt lgkmcnt(7)", - "s_waitcnt lgkmcnt(5)", - "s_waitcnt lgkmcnt(4)", - "s_waitcnt lgkmcnt(3)", + "s_waitcnt lgkmcnt(6)", + "s_waitcnt lgkmcnt(1)", + "s_waitcnt lgkmcnt(2)", "s_waitcnt lgkmcnt(1)", "s_waitcnt lgkmcnt(0)", ] @@ -834,6 +835,7 @@ def testScaledGemmMXFP4PreshuffleB( (64, 128, 256), (64, 128, 128), (64, 64, 128), + (64, 64, 256), (32, 192, 256), (32, 128, 256), (32, 64, 256), diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 5caaada52..31179cf39 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -210,7 +210,7 @@ def _get_max_buffer_size(elem_type: IrType) -> int: Buffer ops offsets are i32, return maximum memref size in elements. """ - return ((1 << 31) - 1) // (elem_type.width // 8) + return ((1 << 32) - 1) // (elem_type.width // 8) def _get_strides_from_memref(mem: Value) -> list[Value]: @@ -241,7 +241,7 @@ def _linearize_memref( memref_type = mem.type offset = None offset_th = None - overflow_flags = arith_d.IntegerOverflowFlags.nsw + overflow_flags = arith_d.IntegerOverflowFlags.none for ind_wg, ind_th, stride in zip(offsets_wg, offsets_th, strides): if isinstance(ind_wg, int): ind_wg = arith_d.constant(IndexType.get(), ind_wg) @@ -342,7 +342,7 @@ def _valid_bytes_buffer(elem_type: IrType) -> int: """ Make valid bytes to be the address of the last byte of the second to last element that can fit in a 32 bit offset to memory address """ - ans = (1 << 31) - 1 - (elem_type.width // 8) + ans = (1 << 32) - 1 - (elem_type.width // 8) assert isinstance(ans, int) return ans @@ -359,7 +359,7 @@ def _get_out_of_bounds_index(element_type: IrType) -> int: assert (oob_index_value * element_width_in_bytes) > _valid_bytes_buffer( element_type ) - assert (oob_index_value * element_width_in_bytes) < (1 << 31) + assert (oob_index_value * element_width_in_bytes) < (1 << 32) return oob_index_value