From 888fa1aa032c330ebfea5714471ee540c6e0870c Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 23 Apr 2025 10:39:47 -0700 Subject: [PATCH 1/7] init PR --- sgl-kernel/csrc/cpu/gemm.h | 3 +- sgl-kernel/csrc/cpu/moe.cpp | 8 +- sgl-kernel/csrc/cpu/moe_int8.cpp | 246 ++++++++++++++++++++++++++----- 3 files changed, 220 insertions(+), 37 deletions(-) diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h index daab64e05b3a..584a91bf2a44 100644 --- a/sgl-kernel/csrc/cpu/gemm.h +++ b/sgl-kernel/csrc/cpu/gemm.h @@ -10,7 +10,7 @@ // block size for AMX gemm constexpr int block_size_m() { return 2 * TILE_M; } constexpr int block_size_n() { return 2 * TILE_N; } - +constexpr int split_k_num() { return 4; } // define threshold using brgemm (intel AMX) template inline bool can_use_brgemm(int M); template <> inline bool can_use_brgemm(int M) { return M > 4; } @@ -98,6 +98,7 @@ void shared_expert_int8_kernel_impl( scalar_t* __restrict__ output, scalar_t* __restrict__ ic1, float* __restrict__ C_tmp, + float* __restrict__ C_splitk_tmp, uint8_t* __restrict__ Aq_tmp, float* __restrict__ As_tmp, const scalar_t* __restrict__ input, diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp index f596d015129b..b3b407947f15 100644 --- a/sgl-kernel/csrc/cpu/moe.cpp +++ b/sgl-kernel/csrc/cpu/moe.cpp @@ -1237,7 +1237,7 @@ at::Tensor shared_expert_cpu( // for fp8 w8a16: // 5. intermediate_cache0 : [M, 2N] // - int num_threads = at::get_num_threads(); + int64_t num_threads = at::get_num_threads(); int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); if (use_int8_w8a8) { @@ -1246,11 +1246,14 @@ at::Tensor shared_expert_cpu( if (use_fp8_w8a16) { buffer_size_nbytes += M * 2 * N * 2; } - + // int spiltk_num = 4; + int spiltk = split_k_num(); auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); + auto buffer_splitk = at::empty({spiltk * 2 * M * N + num_threads * 2 * BLOCK_M * BLOCK_N}, hidden_states.options().dtype(at::kFloat)); AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] { scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer.data_ptr())); float* __restrict__ C_tmp = (float*)((void*)(intermediate_cache1 + M * N)); + float* __restrict__ C_spiltk_tmp = (float*)buffer_splitk.data_ptr(); if (use_int8_w8a8) { uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); @@ -1265,6 +1268,7 @@ at::Tensor shared_expert_cpu( out_hidden_states.data_ptr(), intermediate_cache1, C_tmp, + C_spiltk_tmp, Aq_tmp, As_tmp, hidden_states.data_ptr(), diff --git a/sgl-kernel/csrc/cpu/moe_int8.cpp b/sgl-kernel/csrc/cpu/moe_int8.cpp index 4ebbead1d2ab..5bcc5d679001 100644 --- a/sgl-kernel/csrc/cpu/moe_int8.cpp +++ b/sgl-kernel/csrc/cpu/moe_int8.cpp @@ -8,11 +8,29 @@ template inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { using Vec = at::vec::Vectorized; // no remainder + int64_t d; #pragma GCC unroll 4 - for (int64_t d = 0; d < size; d += Vec::size()) { + for (d = 0; d < size; d += Vec::size()) { Vec data = Vec::loadu(input + d); data.store(out + d); } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); + } +} + +template +inline void _add_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d) + Vec::loadu(out + d); + 
data.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d]+out[d]); + } } template <> @@ -278,7 +296,7 @@ struct tinygemm_kernel_vnni2 { static inline void apply( const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, - int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_comp) { TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); } }; @@ -289,7 +307,7 @@ struct tinygemm_kernel_vnni2 { static inline void apply( const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, - int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_comp) { constexpr int ROWS = BLOCK_M; constexpr int COLS = BLOCK_N / 16; @@ -343,11 +361,14 @@ struct tinygemm_kernel_vnni2 { if constexpr (col % 2 == 0) { vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16); vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); - vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); - vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + if(need_comp){ + vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); + vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + } } } - __m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col])); + auto x_ = need_comp? _mm512_sub_epi32(vc[i], vcomp[col]): vc[i]; + __m512 x = _mm512_cvtepi32_ps(x_); x = _mm512_mul_ps(_mm512_mul_ps(x, vas), vbs[col]); _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x); }; @@ -360,7 +381,7 @@ struct tinygemm_kernel_vnni2 { tinygemm_kernel_vnni2::apply( \ A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ As + mb_start, Bs + nb_start, Bcomp + nb_start, \ - K, lda, ldb, ldc); + K, lda, ldb, ldc, need_comp); template void tinygemm_kernel( @@ -375,7 +396,7 @@ void tinygemm_kernel( int64_t lda, int64_t ldb, int64_t ldc) { - + bool need_comp = true; // B compensation const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); @@ -402,6 +423,48 @@ void tinygemm_kernel( } } +template +void tinygemm_kernel_splitk( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + const int8_t* __restrict__ B_ori, + bool need_comp, + float* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // B compensation + const int32_t* Bcomp = reinterpret_cast(B_ori + block_size_n() * K * split_k_num()); + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + } // anonymous namespace 
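[Note on tinygemm_kernel_splitk's need_comp flag (the call sites below pass kb == 0): the int8 compensation term is linear in B and independent of how K is split, so it can be folded into any single partial sum. Assuming a fixed activation zero point z (the constant itself is baked into the packed-weight layout, not shown in this diff), with S = split_k_num() the reduced result is

    C = As * Bs * ( sum_{kb=0..S-1} sum_{k in slice kb} a_k * b_k  -  z * sum_{k=0..K-1} b_k )

which equals the non-split product exactly; the kernel therefore subtracts the full-K compensation only in the kb == 0 task and lets the later float reduction recover the rest.]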
template @@ -602,11 +665,63 @@ void fused_experts_int8_kernel_impl( INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16); INSTANTIATE_MOE_INT8_TEMPLATE(at::Half); +// silu : shape leading dimension +// input0 [m_size, BLOCK_N] BLOCK_N +// input1 [m_size, BLOCK_N] BLOCK_N +// output [M , N] N +template +inline void silu_and_mul( + scalar_t* __restrict__ output, + const float* __restrict__ input0, // x: x0, x1 + const float* __restrict__ input1, // y: y0, y1 + int64_t m_size, + int64_t N, + int64_t M, + int64_t splitk_size) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + const fVec one = fVec(1.f); + + // no remainder + for (int64_t m = 0; m < m_size; ++m) { + scalar_t* __restrict__ out = output + m * N; + const float* __restrict__ x = input0 + m * N; + const float* __restrict__ y = input1 + m * N; + + for (int64_t d = 0; d < BLOCK_N; d += bVec::size()) { + fVec x0 = fVec::loadu(x + d); + fVec x1 = fVec::loadu(x + d + fVec::size()); + fVec y0 = fVec::loadu(y + d); + fVec y1 = fVec::loadu(y + d + fVec::size()); + for(int64_t id=1; id< splitk_size; id++){ + x0 += fVec::loadu(x + id*2*M*N+ d); + x1 += fVec::loadu(x + id*2*M*N+ d + fVec::size()); + y0 += fVec::loadu(y + id*2*M*N+ d); + y1 += fVec::loadu(y + id*2*M*N+ d + fVec::size()); + } + + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + // convert + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + } +} + + template void shared_expert_int8_kernel_impl( scalar_t* __restrict__ output, scalar_t* __restrict__ ic1, float* __restrict__ C_tmp, + float* __restrict__ C_splitk_tmp, uint8_t* __restrict__ Aq_tmp, float* __restrict__ As_tmp, const scalar_t* __restrict__ input, @@ -619,7 +734,6 @@ void shared_expert_int8_kernel_impl( int64_t M, int64_t N, int64_t K) { - // handle 2 tiles per block constexpr int64_t BLOCK_M = block_size_m(); constexpr int64_t BLOCK_N = block_size_n(); @@ -635,10 +749,10 @@ void shared_expert_int8_kernel_impl( } }); - // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) const int64_t MB = div_up(M, BLOCK_M); const int64_t NB = div_up(N, BLOCK_N); - + const int64_t KB = K/split_k_num(); TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); // K and N are packed for int8 @@ -647,45 +761,109 @@ void shared_expert_int8_kernel_impl( const int64_t stride_n = packed_K; // here we only parallel on half of 2N to fuse silu_and_mul with gemm - at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + at::parallel_for(0, split_k_num() * MB * NB, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + float* __restrict__ C0 = C_splitk_tmp + split_k_num() * 2 * M * N + tid * 2 * BLOCK_M * BLOCK_N; + float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; for (int64_t i = begin; i < end; ++i) { - int64_t mb = i / NB; - int64_t nb = i % NB; - + int64_t kb = i / MB / NB; + int64_t mb = (i - kb * MB * NB) / NB; + int64_t nb = (i - kb * MB * NB) % NB; + float* __restrict__ C0_res = C_splitk_tmp + kb * 2 * M * N; + float* __restrict__ C1_res = C0_res + M * N; // nb0 from top half and nb1 from bottom half int64_t nb0 = nb, nb1 = nb + NB; int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); - // A shape [m_size, K] + // A shape [m_size, K] -> [m_size, KB] const uint8_t* A = Aq_tmp + mb * BLOCK_M * K; + // A scale is [m_size, 1] const float* As = 
As_tmp + mb * BLOCK_M; - // B shape [K, n_size] in vnni format + // // B shape [K, n_size] in vnni format -> [KB, n_size] + // B shape [n_size, K] in vnni format -> [n_size, KB] const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; + // B scale is [n_size, 1] const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N; const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; - // fused 1.b: silu_and_mul(A @ B0, A @ B1) - tinygemm_kernel( - /* A */ A, - /* B0 */ B0, - /* B1 */ B1, - /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, - /* As */ As, - /* Bs0 */ Bs0, - /* Bs1 */ Bs1, - /* M */ m_size, - /* N */ n_size, - /* K */ K, - /* lda */ K, - /* ldb */ n_size, - /* ldc */ N); + // // fused 1.b: silu_and_mul(A @ B0, A @ B1) + // tinygemm_kernel( + // /* A */ A, + // /* B0 */ B0, + // /* B1 */ B1, + // /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + // /* As */ As, + // /* Bs0 */ Bs0, + // /* Bs1 */ Bs1, + // /* M */ m_size, + // /* N */ n_size, + // /* K */ K, + // /* lda */ K, + // /* ldb */ n_size, + // /* ldc */ N); + tinygemm_kernel_splitk( + /* A */ A + kb * KB , + /* B */ B0 + kb * KB * BLOCK_N, + /* Bcomp */ B0, + kb == 0, + /* C */ C0, + /* As */ As, + /* Bs */ Bs0, + /* M */ m_size, + /* N */ n_size, + /* K */ KB, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + tinygemm_kernel_splitk( + /* A */ A + kb * KB, + /* B */ B1 + kb * KB * BLOCK_N, + /* Bcomp */ B1, + kb == 0, + /* C */ C1, + /* As */ As, + /* Bs */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ KB, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + for (int mid = 0; mid< m_size; mid++){ + copy_stub(C0_res + mb*BLOCK_M*N + nb*BLOCK_N + mid * N, C0 + mid * BLOCK_N, n_size); + copy_stub(C1_res + mb*BLOCK_M*N + nb*BLOCK_N + mid * N, C1 + mid * BLOCK_N, n_size); + } } }); - // stage 1.5: quantize ic1 to uint8, [M * topk, N] + + // for(int mid = 0; mid < M; mid++){ + // for(int k_id=1; k_id < split_k_num(); k_id++){ + // _add_stub(C_splitk_tmp+mid*N, C_splitk_tmp+k_id*2*M*N+mid*N, N); + // _add_stub(C_splitk_tmp+M*N+ mid*N, C_splitk_tmp+k_id*2*M*N+M*N+mid*N, N); + // } + // } + + // std::cout<(ic1 + mb*BLOCK_M*N + nb*BLOCK_N, C0, C1, m_size, N, M, split_k_num()); + } + }); + + // stage 1.5: quantize ic1 to uint8, [M, N] at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { for (int64_t m = begin; m < end; ++m) { quantize_row_int8( @@ -753,7 +931,7 @@ void shared_expert_int8_kernel_impl( #define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \ template void shared_expert_int8_kernel_impl ( \ TYPE* __restrict__ output, TYPE* __restrict__ ic1, \ - float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \ + float* __restrict__ C_tmp, float* __restrict__ C_splitk_tmp, uint8_t* __restrict__ Aq_tmp, \ float* __restrict__ As_tmp, const TYPE* __restrict__ input, \ const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \ const float* __restrict__ w1s, const float* __restrict__ w2s, \ From 368441d3600b24fdcbf23b0324d03c10d1989261 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 23 Apr 2025 18:37:47 -0700 Subject: [PATCH 2/7] refactor code --- sgl-kernel/csrc/cpu/gemm.h | 2 +- sgl-kernel/csrc/cpu/moe.cpp | 12 +-- sgl-kernel/csrc/cpu/moe_int8.cpp | 122 ++++++++++++++----------------- 3 files changed, 61 insertions(+), 75 deletions(-) diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h index 584a91bf2a44..fcb5a78c710e 100644 --- a/sgl-kernel/csrc/cpu/gemm.h +++ b/sgl-kernel/csrc/cpu/gemm.h @@ 
-10,7 +10,7 @@ // block size for AMX gemm constexpr int block_size_m() { return 2 * TILE_M; } constexpr int block_size_n() { return 2 * TILE_N; } -constexpr int split_k_num() { return 4; } +constexpr int get_splitk_num() { return 4; } // define threshold using brgemm (intel AMX) template inline bool can_use_brgemm(int M); template <> inline bool can_use_brgemm(int M) { return M > 4; } diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp index b3b407947f15..f528ed4d9f77 100644 --- a/sgl-kernel/csrc/cpu/moe.cpp +++ b/sgl-kernel/csrc/cpu/moe.cpp @@ -1233,9 +1233,10 @@ at::Tensor shared_expert_cpu( // for int8 w8a8: // 3. Aq_tmp : [M, K] or [M, N] // 4. As_tmp : [M] + // 5. C_spiltk_tmp: [spiltk_num * 2 * M * N] // // for fp8 w8a16: - // 5. intermediate_cache0 : [M, 2N] + // 6. intermediate_cache0 : [M, 2N] // int64_t num_threads = at::get_num_threads(); int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); @@ -1246,19 +1247,20 @@ at::Tensor shared_expert_cpu( if (use_fp8_w8a16) { buffer_size_nbytes += M * 2 * N * 2; } - // int spiltk_num = 4; - int spiltk = split_k_num(); + auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); - auto buffer_splitk = at::empty({spiltk * 2 * M * N + num_threads * 2 * BLOCK_M * BLOCK_N}, hidden_states.options().dtype(at::kFloat)); AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] { scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer.data_ptr())); float* __restrict__ C_tmp = (float*)((void*)(intermediate_cache1 + M * N)); - float* __restrict__ C_spiltk_tmp = (float*)buffer_splitk.data_ptr(); if (use_int8_w8a8) { uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * N))); + constexpr int64_t SPLITK_NUM = get_splitk_num(); + auto buffer_splitk = at::empty({SPLITK_NUM * 2 * M * N}, hidden_states.options().dtype(at::kFloat)); + float* __restrict__ C_spiltk_tmp = (float*)buffer_splitk.data_ptr(); + auto w1s = w1_scale.value(); auto w2s = w2_scale.value(); TORCH_CHECK(w1s.numel() == 2 * N); diff --git a/sgl-kernel/csrc/cpu/moe_int8.cpp b/sgl-kernel/csrc/cpu/moe_int8.cpp index 5bcc5d679001..d648814927eb 100644 --- a/sgl-kernel/csrc/cpu/moe_int8.cpp +++ b/sgl-kernel/csrc/cpu/moe_int8.cpp @@ -296,7 +296,7 @@ struct tinygemm_kernel_vnni2 { static inline void apply( const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, - int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_comp) { + int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_Bcomp) { TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); } }; @@ -307,7 +307,7 @@ struct tinygemm_kernel_vnni2 { static inline void apply( const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, - int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_comp) { + int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_Bcomp) { constexpr int ROWS = BLOCK_M; constexpr int COLS = BLOCK_N / 16; @@ -361,13 +361,13 @@ struct tinygemm_kernel_vnni2 { if constexpr (col % 2 == 0) { vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16); vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); - if(need_comp){ + if(need_Bcomp){ 
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); } } } - auto x_ = need_comp? _mm512_sub_epi32(vc[i], vcomp[col]): vc[i]; + auto x_ = need_Bcomp ? _mm512_sub_epi32(vc[i], vcomp[col]): vc[i]; __m512 x = _mm512_cvtepi32_ps(x_); x = _mm512_mul_ps(_mm512_mul_ps(x, vas), vbs[col]); _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x); @@ -381,7 +381,7 @@ struct tinygemm_kernel_vnni2 { tinygemm_kernel_vnni2::apply( \ A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ As + mb_start, Bs + nb_start, Bcomp + nb_start, \ - K, lda, ldb, ldc, need_comp); + K, lda, ldb, ldc, need_Bcomp); template void tinygemm_kernel( @@ -396,8 +396,8 @@ void tinygemm_kernel( int64_t lda, int64_t ldb, int64_t ldc) { - bool need_comp = true; // B compensation + bool need_Bcomp = true; const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); // pattern: 1-4-16 @@ -424,11 +424,11 @@ void tinygemm_kernel( } template -void tinygemm_kernel_splitk( +void tinygemm_kernel( const uint8_t* __restrict__ A, const int8_t* __restrict__ B, - const int8_t* __restrict__ B_ori, - bool need_comp, + const int8_t* __restrict__ B_comp, + bool need_Bcomp, float* __restrict__ C, const float* __restrict__ As, const float* __restrict__ Bs, @@ -440,7 +440,7 @@ void tinygemm_kernel_splitk( int64_t ldc) { // B compensation - const int32_t* Bcomp = reinterpret_cast(B_ori + block_size_n() * K * split_k_num()); + const int32_t* Bcomp = reinterpret_cast(B_comp); // pattern: 1-4-16 constexpr int64_t BLOCK_M = 4; @@ -752,7 +752,8 @@ void shared_expert_int8_kernel_impl( // stage 1: intermediate_cache1 = silu(hidden_states @ w1) const int64_t MB = div_up(M, BLOCK_M); const int64_t NB = div_up(N, BLOCK_N); - const int64_t KB = K/split_k_num(); + const int64_t SPLITK_NUM = get_splitk_num(); + const int64_t K_SPILT_SIZE = div_up(K, SPLITK_NUM); TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); // K and N are packed for int8 @@ -761,31 +762,26 @@ void shared_expert_int8_kernel_impl( const int64_t stride_n = packed_K; // here we only parallel on half of 2N to fuse silu_and_mul with gemm - at::parallel_for(0, split_k_num() * MB * NB, 0, [&](int64_t begin, int64_t end) { - int tid = at::get_thread_num(); - float* __restrict__ C0 = C_splitk_tmp + split_k_num() * 2 * M * N + tid * 2 * BLOCK_M * BLOCK_N; - float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + at::parallel_for(0, SPLITK_NUM * MB * NB, 0, [&](int64_t begin, int64_t end) { for (int64_t i = begin; i < end; ++i) { int64_t kb = i / MB / NB; int64_t mb = (i - kb * MB * NB) / NB; int64_t nb = (i - kb * MB * NB) % NB; - float* __restrict__ C0_res = C_splitk_tmp + kb * 2 * M * N; - float* __restrict__ C1_res = C0_res + M * N; + float* __restrict__ C0 = C_splitk_tmp + kb * 2 * M * N; + float* __restrict__ C1 = C0 + M * N; // nb0 from top half and nb1 from bottom half int64_t nb0 = nb, nb1 = nb + NB; int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t k_size = std::min(K - kb * K_SPILT_SIZE, K_SPILT_SIZE); - // A shape [m_size, K] -> [m_size, KB] + // A shape [m_size, K] -> [m_size, K_SPILT_SIZE] when splitk const uint8_t* A = Aq_tmp + mb * BLOCK_M * K; - // A scale is [m_size, 1] const float* As = As_tmp + mb * BLOCK_M; - // // B shape [K, n_size] in vnni format -> [KB, n_size] - // B shape [n_size, K] in vnni format -> [n_size, KB] + // B shape [K, n_size] in vnni format -> [K_SPILT_SIZE, n_size] when splitk 
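+      // splitk note: each task consumes the kb-th K_SPILT_SIZE slice of A and of
+      // the packed B (the offsets are applied at the call below); Bcomp is
+      // precomputed over the full K, so only the kb == 0 task subtracts it
+      // (need_Bcomp), keeping the reduced sum exact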
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; - // B scale is [n_size, 1] const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N; const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; @@ -804,66 +800,54 @@ void shared_expert_int8_kernel_impl( // /* lda */ K, // /* ldb */ n_size, // /* ldc */ N); - tinygemm_kernel_splitk( - /* A */ A + kb * KB , - /* B */ B0 + kb * KB * BLOCK_N, - /* Bcomp */ B0, - kb == 0, - /* C */ C0, - /* As */ As, - /* Bs */ Bs0, - /* M */ m_size, - /* N */ n_size, - /* K */ KB, - /* lda */ K, - /* ldb */ n_size, - /* ldc */ BLOCK_N); - - tinygemm_kernel_splitk( - /* A */ A + kb * KB, - /* B */ B1 + kb * KB * BLOCK_N, - /* Bcomp */ B1, - kb == 0, - /* C */ C1, - /* As */ As, - /* Bs */ Bs1, - /* M */ m_size, - /* N */ n_size, - /* K */ KB, - /* lda */ K, - /* ldb */ n_size, - /* ldc */ BLOCK_N); - - for (int mid = 0; mid< m_size; mid++){ - copy_stub(C0_res + mb*BLOCK_M*N + nb*BLOCK_N + mid * N, C0 + mid * BLOCK_N, n_size); - copy_stub(C1_res + mb*BLOCK_M*N + nb*BLOCK_N + mid * N, C1 + mid * BLOCK_N, n_size); - } + + // stage 1.a: GEMMs with splitk + tinygemm_kernel( + /* A */ A + kb * K_SPILT_SIZE , + /* B */ B0 + kb * K_SPILT_SIZE * BLOCK_N, + /* Bcomp_start */ B0 + BLOCK_N * K, + /* need_Bcomp */ kb == 0, + /* C */ C0 + mb*BLOCK_M*N + nb*BLOCK_N, + /* As */ As, + /* Bs */ Bs0, + /* M */ m_size, + /* N */ n_size, + /* K */ k_size, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + + tinygemm_kernel( + /* A */ A + kb * K_SPILT_SIZE, + /* B */ B1 + kb * K_SPILT_SIZE * BLOCK_N, + /* Bcomp_start */ B1 + BLOCK_N * K, + /* need_Bcomp */ kb == 0, + /* C */ C1 + mb*BLOCK_M*N + nb*BLOCK_N, + /* As */ As, + /* Bs */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ k_size, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } }); - - // for(int mid = 0; mid < M; mid++){ - // for(int k_id=1; k_id < split_k_num(); k_id++){ - // _add_stub(C_splitk_tmp+mid*N, C_splitk_tmp+k_id*2*M*N+mid*N, N); - // _add_stub(C_splitk_tmp+M*N+ mid*N, C_splitk_tmp+k_id*2*M*N+M*N+mid*N, N); - // } - // } - - // std::cout<(ic1 + mb*BLOCK_M*N + nb*BLOCK_N, C0, C1, m_size, N, M, split_k_num()); + silu_and_mul(ic1 + mb*BLOCK_M*N + nb*BLOCK_N, C0, C1, m_size, N, M, SPLITK_NUM); } }); - // stage 1.5: quantize ic1 to uint8, [M, N] + // stage 1.c: quantize ic1 to uint8, [M, N] at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { for (int64_t m = begin; m < end; ++m) { quantize_row_int8( From 718bb7720bf2b6c6b79845545d83e8ff07968dc0 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 23 Apr 2025 18:48:35 -0700 Subject: [PATCH 3/7] clear up --- sgl-kernel/csrc/cpu/moe.cpp | 2 +- sgl-kernel/csrc/cpu/moe_int8.cpp | 41 ++++++++++---------------------- 2 files changed, 13 insertions(+), 30 deletions(-) diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp index f528ed4d9f77..dc3dbead900d 100644 --- a/sgl-kernel/csrc/cpu/moe.cpp +++ b/sgl-kernel/csrc/cpu/moe.cpp @@ -1238,7 +1238,7 @@ at::Tensor shared_expert_cpu( // for fp8 w8a16: // 6. 
intermediate_cache0 : [M, 2N] // - int64_t num_threads = at::get_num_threads(); + int num_threads = at::get_num_threads(); int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); if (use_int8_w8a8) { diff --git a/sgl-kernel/csrc/cpu/moe_int8.cpp b/sgl-kernel/csrc/cpu/moe_int8.cpp index d648814927eb..909e72f45f18 100644 --- a/sgl-kernel/csrc/cpu/moe_int8.cpp +++ b/sgl-kernel/csrc/cpu/moe_int8.cpp @@ -8,29 +8,11 @@ template inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { using Vec = at::vec::Vectorized; // no remainder - int64_t d; #pragma GCC unroll 4 - for (d = 0; d < size; d += Vec::size()) { + for (int64_t d = 0; d < size; d += Vec::size()) { Vec data = Vec::loadu(input + d); data.store(out + d); } - for (; d < size; ++d) { - out[d] = static_cast(input[d]); - } -} - -template -inline void _add_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { - using Vec = at::vec::Vectorized; - int64_t d; - #pragma GCC unroll 4 - for (d = 0; d < size; d += Vec::size()) { - Vec data = Vec::loadu(input + d) + Vec::loadu(out + d); - data.store(out + d); - } - for (; d < size; ++d) { - out[d] = static_cast(input[d]+out[d]); - } } template <> @@ -666,8 +648,8 @@ INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16); INSTANTIATE_MOE_INT8_TEMPLATE(at::Half); // silu : shape leading dimension -// input0 [m_size, BLOCK_N] BLOCK_N -// input1 [m_size, BLOCK_N] BLOCK_N +// input0 [m_size, BLOCK_N] N +// input1 [m_size, BLOCK_N] N // output [M , N] N template inline void silu_and_mul( @@ -676,8 +658,8 @@ inline void silu_and_mul( const float* __restrict__ input1, // y: y0, y1 int64_t m_size, int64_t N, - int64_t M, - int64_t splitk_size) { + int64_t offset_reduce, + int64_t splitk_num) { using bVec = at::vec::Vectorized; using fVec = at::vec::Vectorized; @@ -695,11 +677,12 @@ inline void silu_and_mul( fVec x1 = fVec::loadu(x + d + fVec::size()); fVec y0 = fVec::loadu(y + d); fVec y1 = fVec::loadu(y + d + fVec::size()); - for(int64_t id=1; id< splitk_size; id++){ - x0 += fVec::loadu(x + id*2*M*N+ d); - x1 += fVec::loadu(x + id*2*M*N+ d + fVec::size()); - y0 += fVec::loadu(y + id*2*M*N+ d); - y1 += fVec::loadu(y + id*2*M*N+ d + fVec::size()); + // reduce sum if splitk_num > 1 + for (int64_t id = 1; id < splitk_num; id++) { + x0 += fVec::loadu(x + id * offset_reduce+ d); + x1 += fVec::loadu(x + id * offset_reduce+ d + fVec::size()); + y0 += fVec::loadu(y + id * offset_reduce+ d); + y1 += fVec::loadu(y + id * offset_reduce+ d + fVec::size()); } // silu @@ -843,7 +826,7 @@ void shared_expert_int8_kernel_impl( int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); float* __restrict__ C0 = C_splitk_tmp + mb*BLOCK_M*N + nb*BLOCK_N; float* __restrict__ C1 = C_splitk_tmp + M*N + mb*BLOCK_M*N + nb*BLOCK_N; - silu_and_mul(ic1 + mb*BLOCK_M*N + nb*BLOCK_N, C0, C1, m_size, N, M, SPLITK_NUM); + silu_and_mul(ic1 + mb*BLOCK_M*N + nb*BLOCK_N, C0, C1, m_size, N, 2*M*N, SPLITK_NUM); } }); From 2b97e92d08f56376c63c0f79284cf52fa71c3f5c Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 23 Apr 2025 20:50:54 -0700 Subject: [PATCH 4/7] clear up --- sgl-kernel/csrc/cpu/moe_int8.cpp | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/sgl-kernel/csrc/cpu/moe_int8.cpp b/sgl-kernel/csrc/cpu/moe_int8.cpp index 909e72f45f18..438b076611c5 100644 --- a/sgl-kernel/csrc/cpu/moe_int8.cpp +++ b/sgl-kernel/csrc/cpu/moe_int8.cpp @@ -768,22 +768,6 @@ void shared_expert_int8_kernel_impl( const float* __restrict__ Bs0 = w1s + nb0 
* BLOCK_N; const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; - // // fused 1.b: silu_and_mul(A @ B0, A @ B1) - // tinygemm_kernel( - // /* A */ A, - // /* B0 */ B0, - // /* B1 */ B1, - // /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, - // /* As */ As, - // /* Bs0 */ Bs0, - // /* Bs1 */ Bs1, - // /* M */ m_size, - // /* N */ n_size, - // /* K */ K, - // /* lda */ K, - // /* ldb */ n_size, - // /* ldc */ N); - // stage 1.a: GEMMs with splitk tinygemm_kernel( /* A */ A + kb * K_SPILT_SIZE , From 6bf698e6c1e157803ab9b33129a84b2db6fbd8f8 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Wed, 23 Apr 2025 22:41:16 -0700 Subject: [PATCH 5/7] refine 2GEMMs --- sgl-kernel/csrc/cpu/moe_int8.cpp | 225 +++++++++++++++++++++++++++---- 1 file changed, 198 insertions(+), 27 deletions(-) diff --git a/sgl-kernel/csrc/cpu/moe_int8.cpp b/sgl-kernel/csrc/cpu/moe_int8.cpp index 438b076611c5..b338efb0d87f 100644 --- a/sgl-kernel/csrc/cpu/moe_int8.cpp +++ b/sgl-kernel/csrc/cpu/moe_int8.cpp @@ -272,6 +272,159 @@ void tinygemm_kernel( } } +/// gemm for w13, no silu and mul fusion (due to splitk) +template +struct tinygemm_kernel_vnni_v2 { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, float* __restrict__ C0, float* __restrict__ C1, + const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_Bcomp) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_vnni_v2 { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, float* __restrict__ C0, float* __restrict__ C1, + const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_Bcomp) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512i va; + __m512i vb0[COLS]; + __m512i vb1[COLS]; + __m512i vc0[ROWS * COLS]; + __m512i vc1[ROWS * COLS]; + __m512i vcomp0[COLS]; + __m512i vcomp1[COLS]; + __m512 vas; + __m512 vbs0[COLS]; + __m512 vbs1[COLS]; + + auto loadc = [&](auto i) { + vc0[i] = _mm512_set1_epi32(0); + vc1[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b0_ptr = reinterpret_cast(B0); + const int32_t* b1_ptr = reinterpret_cast(B1); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16); + vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16); + } + vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]); + vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto scalec_and_store = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + vas = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp 
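+      // vcomp is loaded (and subtracted in scalec_and_store) only when
+      // need_Bcomp is set, i.e. for the kb == 0 split at the call site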
+ if constexpr (row == 0) { + vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16); + vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16); + if (need_Bcomp) { + vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16); + vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16); + } + } + auto c0_ = need_Bcomp ? _mm512_sub_epi32(vc0[i], vcomp0[col]): vc0[i]; + auto c1_ = need_Bcomp ? _mm512_sub_epi32(vc1[i], vcomp1[col]): vc1[i]; + __m512 c0 = _mm512_cvtepi32_ps(c0_); + __m512 c1 = _mm512_cvtepi32_ps(c1_); + c0 = _mm512_mul_ps(_mm512_mul_ps(c0, vas), vbs0[col]); + _mm512_storeu_ps(reinterpret_cast<__m512*>(C0 + row * ldc + col * 16), c0); + c1 = _mm512_mul_ps(_mm512_mul_ps(c1, vas), vbs1[col]); + _mm512_storeu_ps(reinterpret_cast<__m512*>(C1 + row * ldc + col * 16), c1); + }; + Unroll{}(scalec_and_store); + + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_VNNI_V2(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_vnni_v2::apply( \ + A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4, \ + C0 + mb_start * ldc + nb_start, C1 + mb_start * ldc + nb_start, As + mb_start, \ + Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\ + K, lda, ldb, ldc, need_Bcomp); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B0, + const int8_t* __restrict__ B1, + const int8_t* __restrict__ B0_comp, + const int8_t* __restrict__ B1_comp, + float* __restrict__ C0, + float* __restrict__ C1, + bool need_Bcomp, + const float* __restrict__ As, + const float* __restrict__ Bs0, + const float* __restrict__ Bs1, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + const int32_t* Bcomp0 = reinterpret_cast(B0_comp); + const int32_t* Bcomp1 = reinterpret_cast(B1_comp); + + // pattern: 1-(2+2)-(8+8) + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI_V2(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI_V2(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI_V2(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI_V2(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + + /// gemm for w2 template struct tinygemm_kernel_vnni2 { @@ -770,34 +923,52 @@ void shared_expert_int8_kernel_impl( // stage 1.a: GEMMs with splitk tinygemm_kernel( - /* A */ A + kb * K_SPILT_SIZE , - /* B */ B0 + kb * K_SPILT_SIZE * BLOCK_N, - /* Bcomp_start */ B0 + BLOCK_N * K, - /* need_Bcomp */ kb == 0, - /* C */ C0 + mb*BLOCK_M*N + nb*BLOCK_N, - /* As */ As, - /* Bs */ Bs0, - /* M */ m_size, - /* N */ n_size, - /* K */ k_size, - /* lda */ K, - /* ldb */ n_size, - /* ldc */ N); - - tinygemm_kernel( - /* A */ A + kb * K_SPILT_SIZE, - /* B */ B1 + kb * K_SPILT_SIZE * BLOCK_N, - /* Bcomp_start */ B1 + BLOCK_N * K, + /* A */ A + kb * K_SPILT_SIZE, + /* B0 */ B0 + kb * K_SPILT_SIZE * BLOCK_N, + /* B1 */ B1 + kb * K_SPILT_SIZE * BLOCK_N, + /* B0 */ B0 + BLOCK_N * K, + /* B1 */ B1 + BLOCK_N * K, + /* C */ C0 + mb*BLOCK_M*N + nb*BLOCK_N, + /* C */ C1 + mb*BLOCK_M*N + nb*BLOCK_N, /* need_Bcomp */ kb == 0, - /* C */ C1 + mb*BLOCK_M*N + nb*BLOCK_N, - /* As */ As, - /* Bs */ Bs1, - /* M */ m_size, - /* N */ 
n_size, - /* K */ k_size, - /* lda */ K, - /* ldb */ n_size, - /* ldc */ N); + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ k_size, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + // tinygemm_kernel( + // /* A */ A + kb * K_SPILT_SIZE , + // /* B */ B0 + kb * K_SPILT_SIZE * BLOCK_N, + // /* Bcomp_start */ B0 + BLOCK_N * K, + // /* need_Bcomp */ kb == 0, + // /* C */ C0 + mb*BLOCK_M*N + nb*BLOCK_N, + // /* As */ As, + // /* Bs */ Bs0, + // /* M */ m_size, + // /* N */ n_size, + // /* K */ k_size, + // /* lda */ K, + // /* ldb */ n_size, + // /* ldc */ N); + + // tinygemm_kernel( + // /* A */ A + kb * K_SPILT_SIZE, + // /* B */ B1 + kb * K_SPILT_SIZE * BLOCK_N, + // /* Bcomp_start */ B1 + BLOCK_N * K, + // /* need_Bcomp */ kb == 0, + // /* C */ C1 + mb*BLOCK_M*N + nb*BLOCK_N, + // /* As */ As, + // /* Bs */ Bs1, + // /* M */ m_size, + // /* N */ n_size, + // /* K */ k_size, + // /* lda */ K, + // /* ldb */ n_size, + // /* ldc */ N); } }); From cb603c8d345fcbd68e9584793aefcda072b41e70 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Thu, 24 Apr 2025 19:41:22 -0700 Subject: [PATCH 6/7] refine code --- sgl-kernel/csrc/cpu/moe_int8.cpp | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/sgl-kernel/csrc/cpu/moe_int8.cpp b/sgl-kernel/csrc/cpu/moe_int8.cpp index b338efb0d87f..fc627a809ba9 100644 --- a/sgl-kernel/csrc/cpu/moe_int8.cpp +++ b/sgl-kernel/csrc/cpu/moe_int8.cpp @@ -940,35 +940,6 @@ void shared_expert_int8_kernel_impl( /* lda */ K, /* ldb */ n_size, /* ldc */ N); - // tinygemm_kernel( - // /* A */ A + kb * K_SPILT_SIZE , - // /* B */ B0 + kb * K_SPILT_SIZE * BLOCK_N, - // /* Bcomp_start */ B0 + BLOCK_N * K, - // /* need_Bcomp */ kb == 0, - // /* C */ C0 + mb*BLOCK_M*N + nb*BLOCK_N, - // /* As */ As, - // /* Bs */ Bs0, - // /* M */ m_size, - // /* N */ n_size, - // /* K */ k_size, - // /* lda */ K, - // /* ldb */ n_size, - // /* ldc */ N); - - // tinygemm_kernel( - // /* A */ A + kb * K_SPILT_SIZE, - // /* B */ B1 + kb * K_SPILT_SIZE * BLOCK_N, - // /* Bcomp_start */ B1 + BLOCK_N * K, - // /* need_Bcomp */ kb == 0, - // /* C */ C1 + mb*BLOCK_M*N + nb*BLOCK_N, - // /* As */ As, - // /* Bs */ Bs1, - // /* M */ m_size, - // /* N */ n_size, - // /* K */ k_size, - // /* lda */ K, - // /* ldb */ n_size, - // /* ldc */ N); } }); From 41a719ad383f58f57d0fe65a2967bd4fa0f777d5 Mon Sep 17 00:00:00 2001 From: jianan-gu Date: Fri, 25 Apr 2025 10:44:51 +0800 Subject: [PATCH 7/7] Update moe_int8.cpp --- sgl-kernel/csrc/cpu/moe_int8.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sgl-kernel/csrc/cpu/moe_int8.cpp b/sgl-kernel/csrc/cpu/moe_int8.cpp index fc627a809ba9..8136b01aee2e 100644 --- a/sgl-kernel/csrc/cpu/moe_int8.cpp +++ b/sgl-kernel/csrc/cpu/moe_int8.cpp @@ -926,8 +926,8 @@ void shared_expert_int8_kernel_impl( /* A */ A + kb * K_SPILT_SIZE, /* B0 */ B0 + kb * K_SPILT_SIZE * BLOCK_N, /* B1 */ B1 + kb * K_SPILT_SIZE * BLOCK_N, - /* B0 */ B0 + BLOCK_N * K, - /* B1 */ B1 + BLOCK_N * K, + /* B0comp*/ B0 + BLOCK_N * K, + /* B1comp*/ B1 + BLOCK_N * K, /* C */ C0 + mb*BLOCK_M*N + nb*BLOCK_N, /* C */ C1 + mb*BLOCK_M*N + nb*BLOCK_N, /* need_Bcomp */ kb == 0,
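
A minimal scalar sketch of the contract this series establishes, for intuition only: each of the split_k_num() tasks reduces its own K-slice into a float slab, the full-K int8 compensation is subtracted exactly once (by the kb == 0 task), and silu_and_mul then sums the slabs before applying silu(x) * y. Everything below is illustrative -- the fixed zero point of 128 and all names are assumptions standing in for the packed layout's actual constants, not the kernel's API.

    #include <cmath>
    #include <cstdint>
    #include <vector>

    // silu(x) * y, the stage 1.b epilogue
    static float silu_mul(float x, float y) { return x / (1.0f + std::exp(-x)) * y; }

    int main() {
      const int K = 16, SPLITK_NUM = 4, KB = K / SPLITK_NUM;
      std::vector<uint8_t> A(K, 130);            // quantized activations (u8)
      std::vector<int8_t> B0(K, 3), B1(K, -2);   // int8 weights, two halves of w1
      const float As = 0.01f, Bs0 = 0.02f, Bs1 = 0.03f;

      // full-K compensation: z * sum_k B[k]; z = 128 is an assumption here
      int32_t comp0 = 0, comp1 = 0;
      for (int k = 0; k < K; ++k) { comp0 += 128 * B0[k]; comp1 += 128 * B1[k]; }

      // each split accumulates its own K-slice; only kb == 0 subtracts the
      // full-K compensation (need_Bcomp), so it is applied exactly once overall
      float c0 = 0.0f, c1 = 0.0f;
      for (int kb = 0; kb < SPLITK_NUM; ++kb) {
        int32_t acc0 = 0, acc1 = 0;
        for (int k = kb * KB; k < (kb + 1) * KB; ++k) {
          acc0 += int32_t(A[k]) * B0[k];
          acc1 += int32_t(A[k]) * B1[k];
        }
        if (kb == 0) { acc0 -= comp0; acc1 -= comp1; }
        c0 += float(acc0) * As * Bs0;   // scale each partial, as the kernel does
        c1 += float(acc1) * As * Bs1;
      }

      // silu_and_mul reduces the float slabs, then computes silu(c0) * c1
      float out = silu_mul(c0, c1);
      (void)out;
      return 0;
    }

Because the per-split scales are constants, scaling each partial before the reduction (as scalec_and_store does) and scaling the reduced sum are equivalent.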