diff --git a/lit_tests/kernel/wave/asm.py b/lit_tests/kernel/wave/asm.py
index f2ff99598..6ebecc97d 100644
--- a/lit_tests/kernel/wave/asm.py
+++ b/lit_tests/kernel/wave/asm.py
@@ -395,181 +395,6 @@ def mma_multi(
     # CHECK: s_endpgm


-@run_test
-def test_gemm_multi_wave_k_loop():
-    """
-    Test multi-wave GEMM with K-loop (BLOCK_K=64).
-
-    Uses 4 waves per workgroup (BLOCK_M=32, BLOCK_N=32, WAVE_M=16, WAVE_N=16)
-    with BLOCK_K=64 to test loop generation with chained MFMA accumulators.
-
-    Verifies:
-    - Loop induction variable initialization
-    - Multiple MFMA instructions with accumulator chaining
-    - Loop increment and branch back
-    """
-    constraints: list[tkw.Constraint] = [
-        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
-        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
-        tkw.TilingConstraint(K, BLOCK_K),
-        tkw.WaveConstraint(M, BLOCK_M // 2),  # 2 waves in M dimension
-        tkw.WaveConstraint(N, BLOCK_N // 2),  # 2 waves in N dimension
-        tkw.HardwareConstraint(
-            threads_per_wave=64,
-            mma_type=tkw.MMAType.F32_16x16x16_F16,
-        ),
-    ]
-
-    @tkw.wave(constraints)
-    def gemm_multi_wave(
-        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
-        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
-        c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
-    ):
-        c_reg = tkl.Register[M, N, tkl.f32](0.0)
-
-        @tkw.iterate(K, init_args=[c_reg])
-        def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
-            a_reg = tkw.read(a)
-            b_reg = tkw.read(b)
-            acc = tkw.mma(a_reg, b_reg, acc)
-            return acc
-
-        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
-
-    compile_options = WaveCompileOptions(
-        subs={
-            M: 64,
-            N: 64,
-            K: 128,
-            BLOCK_M: 32,
-            BLOCK_N: 32,
-            BLOCK_K: 64,
-            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
-            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
-            LOAD_ELEMS_PER_THREAD: 4,
-            STORE_ELEMS_PER_THREAD: 4,
-        },
-        canonicalize=True,
-        compile_to_mlir=True,
-    )
-    compile_options.compile_to_asm = True
-    gemm_multi_wave = wave_compile(compile_options, gemm_multi_wave)
-    print(gemm_multi_wave.asm)
-
-    # CHECK-LABEL: test_gemm_multi_wave_k_loop
-    # CHECK: .protected gemm_multi_wave
-    # CHECK: .amdhsa_kernel gemm_multi_wave
-    # CHECK: .amdhsa_system_vgpr_workitem_id {{[0-9]+}}
-    # CHECK: gemm_multi_wave:
-
-    # Verify loop structure - header with comparison and conditional branch
-    # CHECK: loop_0_header:
-    # CHECK: s_cmp_lt_u32 s{{[0-9]+}}, s{{[0-9]+}}
-    # CHECK: s_cbranch_scc1 loop_0_body
-
-    # Verify loop body has MFMA instructions
-    # CHECK: loop_0_body:
-    # CHECK: v_mfma_f32_16x16x16_f16
-
-    # Verify loop latch - increment and branch back
-    # CHECK: loop_0_latch:
-    # CHECK: s_add_u32 s{{[0-9]+}}, s{{[0-9]+}}, s{{[0-9]+}}
-    # CHECK: s_branch loop_0_header
-
-    # Verify loop exit and result stores
-    # CHECK: loop_0_exit:
-    # CHECK: buffer_store_dword
-    # CHECK: s_endpgm
-
-
-@run_test
-def test_gemm_gather_to_lds():
-    """
-    Test GEMM with gather_to_lds (global_to_shared) enabled.
-
-    When use_global_to_shared=True, the compiler generates buffer_load_dword...lds
-    instructions that load directly from global memory to LDS, bypassing VGPRs.
-
-    Verifies:
-    - buffer_load_dword ... lds instructions are emitted
-    - M0 register setup for LDS addressing
-    - Proper barrier synchronization (vmcnt + lgkmcnt + s_barrier)
-    """
-    constraints: list[tkw.Constraint] = [
-        tkw.WorkgroupConstraint(M, BLOCK_M, 0),
-        tkw.WorkgroupConstraint(N, BLOCK_N, 1),
-        tkw.TilingConstraint(K, BLOCK_K),
-        tkw.WaveConstraint(M, BLOCK_M),
-        tkw.WaveConstraint(N, BLOCK_N),
-        tkw.HardwareConstraint(
-            threads_per_wave=64,
-            mma_type=tkw.MMAType.F32_16x16x16_F16,
-        ),
-    ]
-
-    @tkw.wave(constraints)
-    def gemm_g2s(
-        a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16],
-        b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16],
-        c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32],
-    ):
-        c_reg = tkl.Register[M, N, tkl.f32](0.0)
-
-        @tkw.iterate(K, init_args=[c_reg])
-        def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
-            a_reg = tkw.read(a)
-            b_reg = tkw.read(b)
-            acc = tkw.mma(a_reg, b_reg, acc)
-            return acc
-
-        tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD)
-
-    compile_options = WaveCompileOptions(
-        subs={
-            M: 32,
-            N: 32,
-            K: 32,
-            BLOCK_M: 16,
-            BLOCK_N: 16,
-            BLOCK_K: 16,
-            LOAD_ELEMS_PER_THREAD: 4,
-            STORE_ELEMS_PER_THREAD: 4,
-            ADDRESS_SPACE: SHARED_ADDRESS_SPACE,
-            ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE,
-        },
-        canonicalize=True,
-        compile_to_mlir=True,
-        use_global_to_shared=True,  # Enable gather_to_lds
-    )
-    compile_options.compile_to_asm = True
-    gemm_g2s = wave_compile(compile_options, gemm_g2s)
-    print(gemm_g2s.asm)
-
-    # CHECK-LABEL: test_gemm_gather_to_lds
-    # CHECK: .protected gemm_g2s
-    # CHECK: .amdhsa_kernel gemm_g2s
-    # CHECK: gemm_g2s:
-
-    # Verify loop structure for K-loop
-    # CHECK: loop_0_header:
-    # CHECK: s_cmp_lt_u32 s{{[0-9]+}}, s{{[0-9]+}}
-    # CHECK: s_cbranch_scc1 loop_0_body
-
-    # Verify loop body has MFMA instruction
-    # CHECK: loop_0_body:
-    # CHECK: v_mfma_f32_16x16x16_f16
-
-    # Verify loop latch
-    # CHECK: loop_0_latch:
-    # CHECK: s_branch loop_0_header
-
-    # Verify loop exit and stores
-    # CHECK: loop_0_exit:
-    # CHECK: buffer_store_dword
-    # CHECK: s_endpgm
-
-
 @run_test
 def test_cse_intermediate_caching():
     """
diff --git a/lit_tests/kernel/wave/gather_to_shared.py b/lit_tests/kernel/wave/gather_to_shared.py
index b5c69447c..d29260717 100644
--- a/lit_tests/kernel/wave/gather_to_shared.py
+++ b/lit_tests/kernel/wave/gather_to_shared.py
@@ -145,17 +145,17 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     print(gemm.asm)

     # CHECK-LABEL: test_gather_to_shared_wave_tile_aligned_coalescing
-    # CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 16 + s1 * 2 - (s1 floordiv 8) * 16)>
+    # CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 2 - (s0 floordiv 8) * 16)>
     # CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> (s1 * 16 + s2 * 32 + (s0 floordiv 64) * 8 + (s0 mod 64) floordiv 8 - ((s1 * 16 + (s0 floordiv 64) * 8 + (s0 mod 64) floordiv 8) floordiv 32) * 32)>
     # CHECK: func.func @gemm
-    # CHECK: %[[BLOCK_ID_Y:.+]] = gpu.block_id y
     # CHECK: %[[TIDX:.+]] = gpu.thread_id x
     # CHECK: %[[TIDY:.+]] = gpu.thread_id y
-    # CHECK: %[[WAVE_ALIGNED_OFFSET:.+]] = affine.apply #[[MAP2]]()[%[[TIDX]], %[[TIDY]], %[[BLOCK_ID_Y]]]
+    # CHECK: affine.apply #[[MAP2]]()[%[[TIDX]], %[[TIDY]], %{{.*}}]
+    # CHECK: %[[TH_OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[TIDX]]]
     # CHECK: scf.for %[[IND_VAR:.+]] = %c0
     # CHECK: amdgpu.lds_barrier
-    # CHECK: %[[UPDATE_OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[IND_VAR]], %[[TIDX]]]
-    # CHECK: %[[LHS:.+]] = arith.addi %{{.*}}, %[[UPDATE_OFFSET]]
+    # CHECK: %[[K_STRIDE:.+]] = arith.muli %[[IND_VAR]], %{{.*}}
+    # CHECK: %[[LHS:.+]] = arith.addi %{{.*}}, %[[K_STRIDE]]


 @run_test
diff --git a/lit_tests/kernel/wave/gemm.py b/lit_tests/kernel/wave/gemm.py
index 7115cc091..7169c5947 100644
--- a/lit_tests/kernel/wave/gemm.py
+++ b/lit_tests/kernel/wave/gemm.py
@@ -111,7 +111,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)>
-    # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 * 16 + ((s1 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 * 32)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)>
@@ -284,8 +283,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-DAG: #[[MAP_IDX_N:.+]] = affine_map<()[s0, s1, s2, s3] -> (s1 * 32 + s2 * 64 + s0 floordiv 4 - ((s1 * 32 + s0 floordiv 4) floordiv 64) * 64 + ((s2 + s3 * 8) floordiv 32) * 256 - (s2 floordiv 4) * 256)>
     # CHECK-DAG: %[[IDX_M_READ:.+]] = affine.apply #[[MAP_IDX_M]]()[%thread_id_x, %thread_id_y, %block_id_y, %block_id_x]
     # CHECK-DAG: %[[IDX_N_READ:.+]] = affine.apply #[[MAP_IDX_N]]()[%thread_id_x, %thread_id_y, %block_id_x, %block_id_y]
-    # CHECK-DAG: vector.load {{.*}}[%[[IDX_M_READ]], {{.*}}]
-    # CHECK-DAG: vector.load {{.*}}[%[[IDX_N_READ]], {{.*}}]
+    # CHECK-DAG: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
+    # CHECK-DAG: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
     # CHECK: amdgpu.mfma
     # CHECK: vector.store {{.*}} : memref<{{.*}}xf32{{.*}}>, vector<1xf32>

@@ -565,11 +564,11 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
     # CHECK-DAG: %[[C768:.+]] = arith.constant 768 : index
     # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : vector<4xi32>
-    # CHECK-DAG: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [64, 64], strides: [64, 1] : memref to memref<64x64xi8, strided<[64, 1]>>
-    # CHECK-DAG: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [128, 64], strides: [64, 1] : memref to memref<128x64xi8, strided<[64, 1]>>
     # CHECK: %[[BASE_ALLOC:.+]] = memref.alloc() : memref<1536xi8, #gpu.address_space<workgroup>>
     # CHECK: %[[ALLOC_0:.+]] = memref.view %[[BASE_ALLOC]][%[[C0]]]
     # CHECK: %[[ALLOC_1:.+]] = memref.view %[[BASE_ALLOC]][%[[C768]]]
+    # CHECK: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}>
+    # CHECK: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}>
     # CHECK: scf.for %[[IVAR:.+]] = %[[C0]] to %[[C4]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[CST]]) -> (vector<4xi32>) {
     # CHECK: %[[REG_0:.+]] = vector.load %[[GLOBAL_0]]
     # CHECK: vector.store %[[REG_0]], %[[ALLOC_1]]
@@ -639,11 +638,11 @@ def repeat(acc: tkl.Register[M, N, tkl.i32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
     # CHECK-DAG: %[[C1280:.+]] = arith.constant 1280 : index
     # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0> : vector<4xi32>
-    # CHECK-DAG: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %0 to offset: [0], sizes: [64, 64], strides: [64, 1] : memref to memref<64x64xi8, strided<[64, 1]>>
-    # CHECK-DAG: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %1 to offset: [0], sizes: [128, 64], strides: [64, 1] : memref to memref<128x64xi8, strided<[64, 1]>>
     # CHECK: %[[BASE_ALLOC:.+]] = memref.alloc() : memref<2560xi8, #gpu.address_space<workgroup>>
     # CHECK: %[[ALLOC_0:.+]] = memref.view %[[BASE_ALLOC]][%[[C0]]]
     # CHECK: %[[ALLOC_1:.+]] = memref.view %[[BASE_ALLOC]][%[[C1280]]]
+    # CHECK: %[[GLOBAL_0:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}>
+    # CHECK: %[[GLOBAL_1:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}>
     # CHECK: scf.for %[[IVAR:.+]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ACC:.+]] = %[[CST]]) -> (vector<4xi32>) {
     # CHECK: %[[REG_0:.+]] = vector.load %[[GLOBAL_0]]
     # CHECK: vector.store %[[REG_0]], %[[ALLOC_1]]
@@ -726,7 +725,6 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # %[[IV_K:.+]] = affine.apply #[[MAP_IV_K]]()[%[[IV]], %[[TID_X]]]

     # CHECK-LABEL: test_packed_gemm
-    # CHECK-DAG: #[[MAP_IV_K:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + ((s1 mod 64) floordiv 16) * 2)>
     # CHECK: func.func @packed_gemm
     # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
     # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
@@ -736,11 +734,10 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: %[[RHS_SHARED:.+]] = memref.view %[[ALLOC]][%c0][] : memref<2560xi8, #gpu.address_space<workgroup>> to memref<32x10xi32, #gpu.address_space<workgroup>>
     # CHECK: %[[LHS_SHARED:.+]] = memref.view %[[ALLOC]][%c1280][] : memref<2560xi8, #gpu.address_space<workgroup>> to memref<32x10xi32, #gpu.address_space<workgroup>>
     # CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[C4]] step %[[C1]]
-    # CHECK: %[[IV_K:.+]] = affine.apply #[[MAP_IV_K]]()[%[[IV]], %[[TID_X]]]
-    # CHECK: %[[LHS_REG:.+]] = vector.load %{{.*}}[%{{.*}}, %[[IV_K]]] : memref<64x32xi32, strided<[32, 1]>>, vector<2xi32>
+    # CHECK: %[[LHS_REG:.+]] = vector.load {{.*}} : memref<{{.*}}>, vector<2xi32>
     # CHECK: amdgpu.lds_barrier
     # CHECK: vector.store %[[LHS_REG]], %[[LHS_SHARED]]
-    # CHECK: %[[RHS_REG:.+]] = vector.load %{{.*}}[%{{.*}}, %[[IV_K]]] : memref<128x32xi32, strided<[32, 1]>>, vector<2xi32>
+    # CHECK: %[[RHS_REG:.+]] = vector.load {{.*}} : memref<{{.*}}>, vector<2xi32>
     # CHECK: vector.store %[[RHS_REG]], %[[RHS_SHARED]]
     # CHECK: amdgpu.lds_barrier
     # CHECK-COUNT-2: vector.load {{.*}} : {{.*}}, vector<2xi32>
@@ -810,7 +807,6 @@ def repeat(
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)>
-    # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 * 16 + ((s1 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 * 32)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)>
     # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)>
@@ -1343,11 +1339,11 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: scf.for
     # CHECK-COUNT-1: amdgpu.lds_barrier
     # Steady State Local Read
-    # CHECK-COUNT-4: vector.load %[[ALLOC_0]]
-    # CHECK-COUNT-4: vector.load %[[ALLOC_1]]
+    # CHECK-COUNT-4: vector.load %[[VIEW_0]]
+    # CHECK-COUNT-4: vector.load %[[VIEW_1]]

     # Steady State Global Read
-    # CHECK-COUNT-2: vector.load {{.*}} : memref<128x128xf16, strided<[128, 1]>>, vector<8xf16>
+    # CHECK-COUNT-2: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
     # CHECK-COUNT-2: rocdl.sched.group.barrier

     # Compute
@@ -1360,8 +1356,8 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: scf.yield

     # Prologue
-    # CHECK-COUNT-4: vector.load %[[ALLOC_0]]
-    # CHECK-COUNT-4: vector.load %[[ALLOC_1]]
+    # CHECK-COUNT-4: vector.load %[[VIEW_0]]
+    # CHECK-COUNT-4: vector.load %[[VIEW_1]]

     # CHECK-COUNT-8: amdgpu.mfma

@@ -1738,7 +1734,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: rocdl.sched.barrier

     # 1st Cluster: Global load LHS
-    # CHECK-COUNT-2: vector.load {{.*}} : memref<4096x4096xf16, strided<[4096, 1]>>, vector<8xf16>
+    # CHECK-COUNT-2: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
     # CHECK: rocdl.sched.barrier

     # 1st Cluster: Second slice of Local read lhs and rhs
@@ -1749,7 +1745,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: rocdl.sched.barrier

     # 1st Cluster: Global load RHS
-    # CHECK-COUNT-4: vector.load {{.*}} : memref<4096x4096xf16, strided<[4096, 1]>>, vector<8xf16>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
     # CHECK: rocdl.s.barrier
     # CHECK: rocdl.sched.barrier

@@ -2296,12 +2292,12 @@ def repeat(
     # CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = #[[TRANSLATION:.+]]} {
     # CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
     # CHECK-DAG: %[[WG_ID2:.*]] = gpu.block_id z
-    # CHECK: %[[LHS_GLOBAL:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [1, 64, 64], strides: [4096, 64, 1] : memref to memref<1x64x64xf16, strided<[4096, 64, 1]>>
-    # CHECK: %[[RHS_GLOBAL:.*]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [6, 128, 64], strides: [8192, 64, 1] : memref to memref<6x128x64xf16, strided<[8192, 64, 1]>>
     # CHECK: %[[HKV_IDX:.+]] = affine.apply #[[MAP]]()[%[[WG_ID2]]]
+    # CHECK: %[[GLOBAL_LHS:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}>
+    # CHECK: %[[GLOBAL_RHS:.+]] = memref.reinterpret_cast %{{.*}} to offset: [0], sizes: [{{.*}}], strides: [{{.*}}] : memref to memref<{{.*}}>
     # CHECK: scf.for
-    # CHECK: %[[LHS_READ:.+]] = vector.load %[[LHS_GLOBAL]][%[[HKV_IDX]], %{{.+}}, {{.+}}] : {{.*}}, vector<8xf16>
-    # CHECK: %[[RHS_READ:.+]] = vector.load %[[RHS_GLOBAL]][%[[WG_ID2]], %{{.+}}, {{.+}}] : {{.*}}, vector<8xf16>
+    # CHECK: %[[LHS_READ:.+]] = vector.load %[[GLOBAL_LHS]]
+    # CHECK: %[[RHS_READ:.+]] = vector.load %[[GLOBAL_RHS]]
     # CHECK-COUNT-2: vector.extract_strided_slice
     # CHECK-COUNT-1: amdgpu.mfma
     # CHECK-COUNT-2: vector.extract_strided_slice
@@ -2464,19 +2460,17 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-LABEL: test_explicit_shared_gemm
     # CHECK: func.func @gemm
     # CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %[[ARG1:.*]]: !stream.binding, %[[ARG2:.*]]: !stream.binding)
-    # CHECK-DAG: %[[GLOBAL_A:.+]] = memref.reinterpret_cast %{{.*}} to offset: [{{.*}}], sizes: [64, 64], strides: [64, 1] : memref to memref<64x64xf16{{.*}}>
-    # CHECK-DAG: %[[GLOBAL_B:.+]] = memref.reinterpret_cast %{{.*}} to offset: [{{.*}}], sizes: [128, 64], strides: [64, 1] : memref to memref<128x64xf16{{.*}}>
     # Verify explicit shared memory allocations (two separate allocs)
     # CHECK: %[[ALLOC_A:.+]] = memref.alloc() : memref<{{.*}}xf16, #gpu.address_space<workgroup>>
     # CHECK: %[[ALLOC_B:.+]] = memref.alloc() : memref<{{.*}}xf16, #gpu.address_space<workgroup>>
     # CHECK: scf.for
     # Verify load from global memory (A)
-    # CHECK: vector.load %[[GLOBAL_A]]
+    # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<{{.*}}>
     # Verify barrier before shared memory writes
     # CHECK: amdgpu.lds_barrier
     # Verify write to shared memory
     # CHECK: vector.store %{{.*}}, %[[ALLOC_A]]
-    # CHECK: vector.load %[[GLOBAL_B]]
+    # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<{{.*}}>
     # CHECK: vector.store %{{.*}}, %[[ALLOC_B]]
     # Verify barrier before shared memory reads
     # CHECK: amdgpu.lds_barrier
diff --git a/lit_tests/kernel/wave/merge_scale_reads.py b/lit_tests/kernel/wave/merge_scale_reads.py
index 5d56e5b73..3b41fdfc3 100644
--- a/lit_tests/kernel/wave/merge_scale_reads.py
+++ b/lit_tests/kernel/wave/merge_scale_reads.py
@@ -180,13 +180,13 @@ def test_preshuffle_scale_merge_block_k_256():

     # CHECK-LABEL: test_preshuffle_scale_merge_block_k_256
     # Each scale tensor produces 2 merged vector<4xi8> loads from global.
-    # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8>
-    # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8>
-    # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8>
-    # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<4xi8>
+    # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[1]>>, vector<4xi8>
+    # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[1]>>, vector<4xi8>
+    # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[1]>>, vector<4xi8>
+    # CHECK: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[1]>>, vector<4xi8>

     # No unmerged scalar scale loads from global should remain.
-    # CHECK-NOT: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[{{.*}}, 1]>>, vector<1xi8>
+    # CHECK-NOT: vector.load %{{.*}} : memref<{{.*}}xi8, strided<[1]>>, vector<1xi8>

     # Check that amdgpu.scaled_mfma uses opsel (indexed access into scale values)
     # The key indicator is the [N] indexing syntax on f8E8M0FNU scale operands. Check %REG[1] as a simple check that we are doing a non-zero index
diff --git a/lit_tests/kernel/wave/scaled_gemm.py b/lit_tests/kernel/wave/scaled_gemm.py
index 7bf3b8b17..1fbd47bdd 100644
--- a/lit_tests/kernel/wave/scaled_gemm.py
+++ b/lit_tests/kernel/wave/scaled_gemm.py
@@ -95,19 +95,17 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     print(scaled_gemm.asm)

     # CHECK-LABEL: test_scaled_gemm_mxfp4
-    # CHECK-DAG: #map = affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16 + (s0 floordiv 64) * 16)>
-    # CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)>
-    # CHECK-DAG: #map2 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16)>
-    # CHECK-DAG: #map3 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)>
-    # CHECK-DAG: #map4 = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)>
-    # CHECK-DAG: #map5 = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)>
-    # CHECK-DAG: #map6 = affine_map<()[s0, s1] -> (s0 * 64 + ((s1 mod 64) floordiv 16) * 16)>
-    # CHECK-DAG: #map7 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)>
-    # CHECK-DAG: #map8 = affine_map<()[s0] -> (s0 * 32)>
-    # CHECK-DAG: #map9 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)>
-    # CHECK-DAG: #map10 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)>
-    # CHECK-DAG: #map11 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 2)>
-    # CHECK-DAG: #map12 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 3)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16 + (s0 floordiv 64) * 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 * 32)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 2)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 3)>
     # CHECK: func.func @scaled_gemm
     # CHECK-COUNT-1: memref.alloc()
     # CHECK: scf.for
@@ -195,21 +193,18 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     print(scaled_gemm.asm)

     # CHECK-LABEL: test_scaled_gemm_mxfp8
-    # CHECK-DAG: #map = affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16 + (s0 floordiv 64) * 16)>
-    # CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)>
-    # CHECK-DAG: #map2 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16)>
-    # CHECK-DAG: #map3 = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16 + 64)>
-    # CHECK-DAG: #map4 = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)>
-    # CHECK-DAG: #map5 = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)>
-    # CHECK-DAG: #map6 = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)>
-    # CHECK-DAG: #map7 = affine_map<()[s0, s1] -> (s0 * 128 + ((s1 mod 64) floordiv 16) * 16)>
-    # CHECK-DAG: #map8 = affine_map<()[s0, s1] -> (s0 * 128 + ((s1 mod 64) floordiv 16) * 16 + 64)>
-    # CHECK-DAG: #map9 = affine_map<()[s0, s1] -> (s1 * 4 + (s0 mod 64) floordiv 16)>
-    # CHECK-DAG: #map10 = affine_map<()[s0] -> (s0 * 32)>
-    # CHECK-DAG: #map11 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)>
-    # CHECK-DAG: #map12 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)>
-    # CHECK-DAG: #map13 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 2)>
-    # CHECK-DAG: #map14 = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 3)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16 + (s0 floordiv 64) * 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (((s0 mod 64) floordiv 16) * 16 + 64)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 16 - (s0 floordiv 16) * 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 mod 16 + (s0 floordiv 64) * 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 mod 64) floordiv 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0, s1] -> (s0 + s1 * 16 - (s0 floordiv 16) * 16)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> (s0 * 32)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 1)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 2)>
+    # CHECK-DAG: #{{.*}} = affine_map<()[s0] -> ((s0 floordiv 64) * 16 + ((s0 mod 64) floordiv 16) * 4 + 3)>
     # CHECK: func.func @scaled_gemm
     # CHECK-COUNT-1: memref.alloc()
     # CHECK: scf.for
@@ -317,10 +312,10 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:

     # CHECK-LABEL: gemm_mxfp4_prefetch
     # Prologue Global Read
-    # CHECK-COUNT-4: vector.load {{.*}} : memref<1024x512xi8, strided<[512, 1]>>, vector<16xi8>
-    # CHECK: vector.load {{.*}} : memref<1024x32xi8, strided<[32, 1]>>, vector<4xi8>
-    # CHECK-COUNT-4: vector.load {{.*}} : memref<1024x512xi8, strided<[512, 1]>>, vector<16xi8>
-    # CHECK: vector.load {{.*}} : memref<1024x32xi8, strided<[32, 1]>>, vector<4xi8>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8>
+    # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8>
+    # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8>

     # Prologue Local Write
     # CHECK-COUNT-4: vector.store {{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>
@@ -332,22 +327,22 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK: scf.for

     # Steady State global_load_rhs_scale
-    # CHECK: vector.load %{{.*}} : memref<1024x32xi8, strided<[32, 1]>>, vector<4xi8>
+    # CHECK: vector.load %{{.*}} : memref<{{.*}}>, vector<4xi8>
     # Steady State local_load_rhs_scale
     # CHECK=COUNT-16: vector.load %{{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<1xi8>

     # Steady State global_load_lhs_scale
-    # CHECK: vector.load %{{.*}} : memref<1024x32xi8, strided<[32, 1]>>, vector<4xi8>
+    # CHECK: vector.load %{{.*}} : memref<{{.*}}>, vector<4xi8>
     # Steady State local_load_lhs_scale
     # CHECK=COUNT-16: vector.load %{{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<1xi8>

     # Steady State global_load_rhs
-    # CHECK-COUNT-4: vector.load %{{.*}} : memref<1024x512xi8, strided<[512, 1]>>, vector<16xi8>
+    # CHECK-COUNT-4: vector.load %{{.*}} : memref<{{.*}}>, vector<16xi8>
     # Steady State local_load_rhs
     # CHECK=COUNT-16: vector.load %{{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>

     # Steady State global_load_lhs
-    # CHECK-COUNT-4: vector.load %{{.*}} : memref<1024x512xi8, strided<[512, 1]>>, vector<16xi8>
+    # CHECK-COUNT-4: vector.load %{{.*}} : memref<{{.*}}>, vector<16xi8>
     # Steady State local_load_lhs
     # CHECK=COUNT-16: vector.load %{{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>
@@ -585,12 +580,12 @@ def repeat(
     # Prologue Global Read
     # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1], offset: ?>>
     # CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) cacheSwizzleStride(%[[C512_I14]]) resetOffset : memref> to memref>
-    # CHECK-COUNT-4: vector.load {{.*}} : memref>, vector<16xi8>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8>
     # CHECK: memref.reinterpret_cast %{{.*}} to offset: [%{{.*}}], sizes: [2147483646], strides: [1] : memref to memref<2147483646xi8, strided<[1], offset: ?>>
     # CHECK: amdgpu.fat_raw_buffer_cast %{{.*}} validBytes(%{{.*}}) cacheSwizzleStride(%[[C32_I14]]) resetOffset : memref> to memref>
-    # CHECK: vector.load {{.*}} : memref>, vector<4xi8>
-    # CHECK-COUNT-4: vector.load {{.*}} : memref>, vector<16xi8>
-    # CHECK: vector.load {{.*}} : memref>, vector<4xi8>
+    # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8>
+    # CHECK: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8>

     # Prologue Linearize shared memory + Local Write
     # CHECK: memref.reinterpret_cast {{.*}} to offset: [0], sizes: [34816], strides: [1] : memref<1x256x136xi8, #gpu.address_space<workgroup>> to memref<34816xi8, #gpu.address_space<workgroup>>
@@ -606,22 +601,22 @@ def repeat(
     # CHECK: scf.for

     # Steady State global_load_rhs_scale
-    # CHECK: vector.load %{{.*}} : memref>, vector<4xi8>
+    # CHECK: vector.load %{{.*}} : memref<{{.*}}>, vector<4xi8>
     # Steady State local_load_rhs_scale
     # CHECK=COUNT-16: vector.load %{{.*}} : memref<4096xi8, #gpu.address_space<workgroup>>, vector<1xi8>

     # Steady State global_load_lhs_scale
-    # CHECK: vector.load %{{.*}} : memref>, vector<4xi8>
+    # CHECK: vector.load %{{.*}} : memref<{{.*}}>, vector<4xi8>
     # Steady State local_load_lhs_scale
     # CHECK=COUNT-16: vector.load %{{.*}} : memref<4096xi8, #gpu.address_space<workgroup>>, vector<1xi8>

     # Steady State global_load_rhs
-    # CHECK-COUNT-4: vector.load %{{.*}} : memref>, vector<16xi8>
+    # CHECK-COUNT-4: vector.load %{{.*}} : memref<{{.*}}>, vector<16xi8>
     # Steady State local_load_rhs
     # CHECK=COUNT-16: vector.load %{{.*}} : memref<34816xi8, #gpu.address_space<workgroup>>, vector<16xi8>

     # Steady State global_load_lhs
-    # CHECK-COUNT-4: vector.load %{{.*}} : memref>, vector<16xi8>
+    # CHECK-COUNT-4: vector.load %{{.*}} : memref<{{.*}}>, vector<16xi8>
     # Steady State local_load_lhs
     # CHECK=COUNT-16: vector.load %{{.*}} : memref<34816xi8, #gpu.address_space<workgroup>>, vector<16xi8>
@@ -736,12 +731,12 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-DAG: %[[SCALED_LOGITS_BOUND:.+]] = arith.constant dense<96> : vector<16xindex>
     # CHECK: scf.for %{{.*}} = %[[C0]] to %[[C2]] step %[[C1]]
     # CHECK: %[[SCALED_LOGITS_MASK:.+]] = arith.cmpi slt, %{{.*}}, %[[SCALED_LOGITS_BOUND]] : vector<16xindex>
-    # CHECK: vector.maskedload {{.*}}, %[[SCALED_LOGITS_MASK]], %[[CST_0]] : memref<1024x96xi8, strided<[96, 1]>>, vector<16xi1>, vector<16xi8> into vector<16xi8>
+    # CHECK: vector.maskedload {{.*}}, %[[SCALED_LOGITS_MASK]], %[[CST_0]] : memref<{{.*}}>, vector<16xi1>, vector<16xi8> into vector<16xi8>
     # CHECK: %[[SCALED_SCALES_MASK_VAL:.+]] = arith.cmpi slt, %{{.*}}, %[[SCALED_SCALES_BOUND]] : index
     # CHECK: %[[SCALED_SCALES_MASK:.+]] = vector.broadcast %[[SCALED_SCALES_MASK_VAL]] : i1 to vector<1xi1>
-    # CHECK: vector.maskedload {{.*}}, %[[SCALED_SCALES_MASK]], %[[CST]] : memref<1024x6xi8, strided<[6, 1]>>, vector<1xi1>, vector<1xi8> into vector<1xi8>
-    # CHECK: vector.maskedload {{.*}}, %[[SCALED_LOGITS_MASK]], %[[CST_0]] : memref<1024x96xi8, strided<[96, 1]>>, vector<16xi1>, vector<16xi8> into vector<16xi8>
-    # CHECK: vector.maskedload {{.*}}, %[[SCALED_SCALES_MASK]], %[[CST]] : memref<1024x6xi8, strided<[6, 1]>>, vector<1xi1>, vector<1xi8> into vector<1xi8>
+    # CHECK: vector.maskedload {{.*}}, %[[SCALED_SCALES_MASK]], %[[CST]] : memref<{{.*}}>, vector<1xi1>, vector<1xi8> into vector<1xi8>
+    # CHECK: vector.maskedload {{.*}}, %[[SCALED_LOGITS_MASK]], %[[CST_0]] : memref<{{.*}}>, vector<16xi1>, vector<16xi8> into vector<16xi8>
+    # CHECK: vector.maskedload {{.*}}, %[[SCALED_SCALES_MASK]], %[[CST]] : memref<{{.*}}>, vector<1xi1>, vector<1xi8> into vector<1xi8>
     # CHECK: amdgpu.scaled_mfma
     # CHECK: scf.yield
     # CHECK: }
diff --git a/lit_tests/kernel/wave/scaled_mma.py b/lit_tests/kernel/wave/scaled_mma.py
index f37dae9b5..d681534f6 100644
--- a/lit_tests/kernel/wave/scaled_mma.py
+++ b/lit_tests/kernel/wave/scaled_mma.py
@@ -347,10 +347,10 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
     # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
     # CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index
     # CHECK: scf.for %{{.*}} = %[[C0]] to %[[C64]] step %[[C1]]
-    # CHECK-COUNT-4: vector.load {{.*}} : memref<16384x8192xi8, strided<[8192, 1]>>, vector<16xi8>
-    # CHECK-COUNT-1: vector.load {{.*}} : memref<16384x512xi8, strided<[512, 1]>>, vector<4xi8>
-    # CHECK-COUNT-4: vector.load {{.*}} : memref<16384x8192xi8, strided<[8192, 1]>>, vector<16xi8>
-    # CHECK-COUNT-1: vector.load {{.*}} : memref<16384x512xi8, strided<[512, 1]>>, vector<4xi8>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8>
+    # CHECK-COUNT-1: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<16xi8>
+    # CHECK-COUNT-1: vector.load {{.*}} : memref<{{.*}}>, vector<4xi8>
     # CHECK: amdgpu.lds_barrier
     # CHECK-COUNT-16: vector.load {{.*}} : memref<256x16xi8, #gpu.address_space<workgroup>>, vector<1xi8>
     # CHECK-COUNT-16: vector.load {{.*}} : memref<256x136xi8, #gpu.address_space<workgroup>>, vector<16xi8>
diff --git a/lit_tests/kernel/wave/wave_schedule.py b/lit_tests/kernel/wave/wave_schedule.py
index ee4c21a3c..40e02ee12 100644
--- a/lit_tests/kernel/wave/wave_schedule.py
+++ b/lit_tests/kernel/wave/wave_schedule.py
@@ -46,7 +46,7 @@ def test_gemm_with_wave_schedule():
     # CHECK-COUNT-4: vector.load %[[VIEW_1]]

     # Steady State Global Read
-    # CHECK-COUNT-2: vector.load {{.*}} : memref<128x128xf16, strided<[128, 1]>>, vector<8xf16>
+    # CHECK-COUNT-2: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
     # CHECK-COUNT-2: rocdl.sched.group.barrier

     # Compute
@@ -111,7 +111,7 @@ def test_gemm_prefetch_reorder_stagger():
     # CHECK: rocdl.sched.barrier

     # 1st Cluster: Global load LHS
-    # CHECK-COUNT-2: vector.load {{.*}} : memref<4096x4096xf16, strided<[4096, 1]>>, vector<8xf16>
+    # CHECK-COUNT-2: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
     # CHECK: rocdl.sched.barrier

     # 1st Cluster: Second slice of Local read lhs and rhs
@@ -122,7 +122,7 @@ def test_gemm_prefetch_reorder_stagger():
     # CHECK: rocdl.sched.barrier

     # 1st Cluster: Global load RHS
-    # CHECK-COUNT-4: vector.load {{.*}} : memref<4096x4096xf16, strided<[4096, 1]>>, vector<8xf16>
+    # CHECK-COUNT-4: vector.load {{.*}} : memref<{{.*}}>, vector<8xf16>
     # CHECK: rocdl.s.barrier
     # CHECK: rocdl.sched.barrier

diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py
index 147b27704..899d8de21 100644
--- a/wave_lang/kernel/compiler/wave_codegen/read_write.py
+++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py
@@ -75,6 +75,7 @@
     get_type_or_element_type,
     handle_op,
 )
+from ...wave.constraints import TilingConstraint


 def _get_start_index(i: IndexSequence | IndexExpr) -> IndexExpr:
@@ -647,6 +648,159 @@ def extract(vec, ind):
         return


+def _get_or_create_flat_memref(
+    emitter: WaveEmitter,
+    mem: Value,
+) -> Value:
+    """Return a rank-1 view of *mem* with offset 0 (pure shape change).
+
+    All reads from the same source buffer share one reinterpret_cast,
+    so the backend maps them all to a single SRD and no per-read SRD
+    copies are needed.
+    """
+    if not hasattr(emitter, "_flat_memref_cache"):
+        emitter._flat_memref_cache = {}
+    key = id(mem)
+    if key in emitter._flat_memref_cache:
+        return emitter._flat_memref_cache[key]
+
+    kb_type = MemRefType(mem.type)
+    max_buf = _get_max_buffer_size(kb_type.element_type) - 1
+    result_type = MemRefType.get(
+        [max_buf],
+        kb_type.element_type,
+        layout=Attribute.parse("strided<[1], offset: 0>"),
+        memory_space=kb_type.memory_space,
+    )
+    flat = memref_d.reinterpret_cast(
+        result_type,
+        mem,
+        offsets=[],
+        sizes=[],
+        strides=[],
+        static_offsets=[0],
+        static_sizes=[max_buf],
+        static_strides=[1],
+    )
+    emitter._flat_memref_cache[key] = flat
+    return flat
+
+
+def _try_iv_split_offset(
+    emitter: WaveEmitter,
+    index: dict[IndexExpr, IndexSequence | IndexExpr],
+    strides: list[int],
+    dynamic_vals: dict[IndexExpr, Any],
+    use_subs_idxc: bool = False,
+) -> Optional[Value]:
+    """Compute a hoisted IV-split linearized offset for a loop-carried read.
+
+    Returns the MLIR Value ``hoisted_voffset + IV * k_stride`` if the index
+    expressions are provably affine in the loop IV, or ``None`` to fall back
+    to the default address path.
+
+    The caller is responsible for emitting the actual load/gather using the
+    returned offset.
+
+    Parameters
+    ----------
+    strides : per-dimension integer strides used for linearization.
+    use_subs_idxc : if True, apply ``subs_idxc`` before simplification
+        (needed when expressions contain residual shape symbols).
+    """
+    ip = InsertionPoint.current
+    owner = ip.block.owner
+    if isinstance(owner, func_d.FuncOp):
+        return None
+    if owner.name != "scf.for":
+        return None
+
+    # Find the IV symbol for this scf.for directly from its block argument.
+    current_iv = owner.induction_variable
+
+    # Reverse-lookup the dimension/symbol that the current IV is associated with.
+    dim = next((d for d, v in emitter.induction_vars.items() if v == current_iv), None)
+    if dim is None:
+        return None
+    iv_sym = next(
+        (
+            c.induction_var
+            for c in emitter.constraints
+            if isinstance(c, TilingConstraint) and c.dim == dim
+        ),
+        None,
+    )
+    if iv_sym is None:
+        return None
+
+    step_int = _get_constant_value(owner.operands[2])
+    if step_int is None or step_int <= 0:
+        return None
+
+    start_exprs = _get_start_indices(index)
+    if len(start_exprs) != len(strides):
+        return None
+
+    # Phase 1: Symbolic linearity proof w.r.t. the current loop's IV only.
+    # Substitute IV = step * _j and check that the linearized index is
+    # c * _j + remainder, with no _j left in the remainder.
+    _j = sympy.Symbol("_j", integer=True, nonnegative=True)
+    iv_as_j = step_int * _j
+    lin_sym = sympy.Integer(0)
+    for expr, ps in zip(start_exprs, strides):
+        e = safe_subs(expr, {iv_sym: iv_as_j})
+        if use_subs_idxc:
+            e = subs_idxc(e)
+        e = sympy.simplify(e)
+        lin_sym += e * ps
+    lin_sym = sympy.simplify(lin_sym)
+
+    coeff = lin_sym.coeff(_j)
+    remainder = sympy.simplify(lin_sym - coeff * _j)
+    if not coeff.is_Integer or coeff == 0 or _j in remainder.free_symbols:
+        return None
+    k_stride_per_iv, rem = divmod(int(coeff), step_int)
+    if rem != 0:
+        return None
+
+    # Phase 2: Substitute IV = 0 to get the loop-invariant base offset.
+    iv_zero_subs = {iv_sym: 0}
+    index_no_iv = {}
+    for dim, seq in index.items():
+        start = _get_start_index(seq)
+        new_start = safe_subs(start, iv_zero_subs)
+        if isinstance(seq, IndexSequence):
+            index_no_iv[dim] = IndexSequence(new_start, seq.size)
+        else:
+            index_no_iv[dim] = new_start
+
+    # Emit the hoisted linearized offset BEFORE the scf.for.
+    hoist_ip = InsertionPoint(owner)
+    subs_map = add_emitter_subs(emitter, dynamic_vals)
+    overflow_flags = arith_d.IntegerOverflowFlags.nsw
+
+    with hoist_ip:
+        iv0_exprs = _get_start_indices(index_no_iv)
+        lin_offset = None
+        for expr, ps in zip(iv0_exprs, strides):
+            val = gen_sympy_index(subs_map, expr)
+            stride_c = arith_d.constant(IndexType.get(), ps)
+            term = arith_d.muli(val, stride_c, overflow_flags=overflow_flags)
+            lin_offset = (
+                term
+                if lin_offset is None
+                else arith_d.addi(lin_offset, term, overflow_flags=overflow_flags)
+            )
+
+    # Back inside the loop: total = hoisted_base + IV * k_stride.
+    iv_mlir = subs_map.get(iv_sym)
+    if iv_mlir is None:
+        return None
+
+    k_stride_val = arith_d.constant(IndexType.get(), k_stride_per_iv)
+    iv_offset = arith_d.muli(iv_mlir, k_stride_val, overflow_flags=overflow_flags)
+    return arith_d.addi(lin_offset, iv_offset, overflow_flags=overflow_flags)
+
+
 def _build_mask_with_mapping(
     emitter: WaveEmitter,
     mapping: IndexMapping,
@@ -736,11 +890,43 @@ def handle_read(emitter: WaveEmitter, node: fx.Node):
     else:
         mask = _build_mask(emitter, index, elements_per_thread, bounds)

+    is_global = get_custom(memory).type.address_space != SHARED_ADDRESS_SPACE
+    use_llvm_load = flags != MemoryAccessFlags.NONE
+
+    # IV-split fast path for global reads: hoist address before the loop.
+    if (
+        is_global
+        and mask is None
+        and not use_llvm_load
+        and not read_meets_hw_transpose_requirements(
+            get_custom(node), emitter.constraints, emitter.options.target
+        )
+    ):
+        kb_type = MemRefType(kb_src.type)
+        phys_strides, _ = kb_type.get_strides_and_offset()
+        dyn_sentinel = ShapedType.get_dynamic_stride_or_offset()
+        if not any(s == dyn_sentinel for s in phys_strides):
+            total_offset = _try_iv_split_offset(
+                emitter,
+                index,
+                list(phys_strides),
+                dynamic_vals_map_start,
+            )
+            if total_offset is not None:
+                # Load from a shared flat rank-1 view (one SRD per buffer).
+                ip = InsertionPoint.current
+                owner = ip.block.owner
+                hoist_ip = InsertionPoint(owner)
+                with hoist_ip:
+                    flat_mem = _get_or_create_flat_memref(emitter, kb_src)
+                result = vector_d.load(vector_type, flat_mem, [total_offset])
+                emitter.bind_node_proxy(node, IRProxyValue(result))
+                return
+
     start_indices, start_indices_wg, start_indices_th = _build_start_indices(
         emitter, index, dynamic_vals_map_start
     )

-    use_llvm_load = flags != MemoryAccessFlags.NONE
     if use_llvm_load:
         result = _create_llvm_read_write(
             kb_src, kb_ir_type, start_indices, vector_type, flags
@@ -1095,15 +1281,10 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):

     store_type = VectorType.get((elements_per_thread,), element_type)

-    src_index, src_index_wg, src_index_th = _build_start_indices(
-        emitter, new_src_idx, src_dynamic_vals_map_start
-    )
-
     ip = InsertionPoint.current

     induction_vars = set(emitter.get_induction_vars_and_syms()[1])
-    # Hoist to the function level, if not using induction variables.
     if not any(
         induction_vars.intersection(set(index.start.free_symbols))
         for index in dst_idx.values()
@@ -1115,23 +1296,47 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
     dst_index, _, _ = _build_start_indices(
         emitter, dst_idx, dst_dynamic_vals_map_start
     )
-
     # We are indexing shared mem so i32 is enough.
     i32 = IntegerType.get_signless(32)
     dst_index = [assume_index_subgroup_uniform(idx, i32) for idx in dst_index]

-    strides = strides_from_symbolic_shape(
+    # Symbolic strides shared by iv-split and fallback linearization.
+    sym_stride_vals = strides_from_symbolic_shape(
         IndexingContext.current(), src_symbolic_shape, allow_mixed_shapes=True
     )
-    strides = [
-        gen_sympy_index(add_emitter_subs(emitter, src_dynamic_vals_map_start), s)
-        for s in strides
-    ]
+    subs_map = add_emitter_subs(emitter, src_dynamic_vals_map_start)
+    strides = [gen_sympy_index(subs_map, s) for s in sym_stride_vals]

-    src, offset_th = _linearize_memref(src, src_index_wg, src_index_th, strides)
-    src = _cast_buffer_and_encode_stride(src, strides, element_type, emitter)
+    # IV-split: try hoisting the src offset before the loop.
+    try:
+        sym_strides_int = [int(subs_idxc(s)) for s in sym_stride_vals]
+    except (TypeError, ValueError):
+        sym_strides_int = []
+
+    src_offset = None
+    if sym_strides_int:
+        src_offset = _try_iv_split_offset(
+            emitter,
+            new_src_idx,
+            sym_strides_int,
+            src_dynamic_vals_map_start,
+            use_subs_idxc=True,
+        )
+
+    if src_offset is not None:
+        # IV-split path: offset=0 reinterpret_cast, full address in src_offset.
+        zero_indices = [arith_d.constant(IndexType.get(), 0)] * len(strides)
+        lin_src, _ = _linearize_memref(src, zero_indices, zero_indices, strides)
+    else:
+        # Fallback: wg offset baked into memref base, th offset as voffset.
+        src_index, src_index_wg, src_index_th = _build_start_indices(
+            emitter, new_src_idx, src_dynamic_vals_map_start
+        )
+        lin_src, src_offset = _linearize_memref(
+            src, src_index_wg, src_index_th, strides
+        )
+
+    lin_src = _cast_buffer_and_encode_stride(lin_src, strides, element_type, emitter)

-    # We previously checked mask is same for all elements, so we can use
-    # elements_per_thread=1 to build the mask.
     mask = _build_mask(
         emitter,
         src_idx,
@@ -1143,13 +1348,11 @@ def handle_gather_to_lds(emitter: WaveEmitter, node: fx.Node):
     mask = vector_d.extract(mask, static_position=[0], dynamic_position=[])
     oob_index_value = _get_out_of_bounds_index(element_type)
     oob_index = arith_d.constant(IndexType.get(), oob_index_value)
-    offset_th = arith_d.select(mask, offset_th, oob_index)
-
-    src_index = [offset_th]
+    src_offset = arith_d.select(mask, src_offset, oob_index)

     amdgpu_d.gather_to_lds(
-        src=src,
-        src_indices=src_index,
+        src=lin_src,
+        src_indices=[src_offset],
         dst=dst,
         dst_indices=dst_index,
         transfer_type=store_type,
diff --git a/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py b/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py
index 002572543..eea7e950d 100644
--- a/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py
+++ b/wave_lang/kernel/wave/schedules/gemm_triple_buffer.py
@@ -150,6 +150,8 @@ def async_two_cluster_three_stage_schedule():
     unroll_factor = 2
     tkw.unroll(pipeline_loop.KERNEL, unroll_factor)

+    tkw.insert_after(pipeline_loop.KERNEL, tkw.MemoryCounterWaitBarrier(load=0))
+
     return async_two_cluster_three_stage_schedule
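
Reviewer note (not part of the patch): the core of this change is the linearity test in _try_iv_split_offset, which decides whether a read's address can be split into a loop-invariant base plus IV * k_stride. The standalone sympy sketch below illustrates that test on a made-up index expression; the symbol names (iv, tid, step, row_stride) are hypothetical stand-ins and not identifiers from the patch.

# Minimal sketch of the IV-split linearity check, assuming a typical
# global-read index of the form base(tid) + iv.
import sympy

iv, tid = sympy.symbols("iv tid", integer=True, nonnegative=True)
step = 16          # loop step (one K tile per iteration), assumed constant
row_stride = 128   # assumed physical stride of the row dimension

# Hypothetical linearized read index for one thread.
linear_index = (tid % 16) * row_stride + iv + (tid // 16) * 4

# Substitute iv = step * j and split into a j-coefficient and a remainder.
j = sympy.Symbol("j", integer=True, nonnegative=True)
expr = sympy.expand(linear_index.subs(iv, step * j))
coeff = expr.coeff(j)
remainder = sympy.simplify(expr - coeff * j)

# The read is IV-splittable when the coefficient is a constant integer and
# the remainder no longer mentions j: the remainder is hoisted before the
# loop, and each iteration only adds iv * (coeff / step) to it.
assert coeff.is_Integer and j not in remainder.free_symbols
k_stride_per_iv = int(coeff) // step
print(k_stride_per_iv, remainder)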