diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h
index fc6f045fe3b0..c8b87d60113b 100644
--- a/sgl-kernel/csrc/cpu/gemm.h
+++ b/sgl-kernel/csrc/cpu/gemm.h
@@ -10,7 +10,7 @@
 // block size for AMX gemm
 constexpr int block_size_m() { return 2 * TILE_M; }
 constexpr int block_size_n() { return 2 * TILE_N; }
-
+constexpr int get_splitk_num() { return 4; }
 // define threshold using brgemm (intel AMX)
 template <typename T> inline bool can_use_brgemm(int M);
 template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
@@ -100,6 +100,7 @@ void shared_expert_int8_kernel_impl(
     scalar_t* __restrict__ output,
     scalar_t* __restrict__ ic1,
     float* __restrict__ C_tmp,
+    float* __restrict__ C_splitk_tmp,
     uint8_t* __restrict__ Aq_tmp,
     float* __restrict__ As_tmp,
     const scalar_t* __restrict__ input,
diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp
index 3344b24cb96f..ecfcc6df0ac6 100644
--- a/sgl-kernel/csrc/cpu/moe.cpp
+++ b/sgl-kernel/csrc/cpu/moe.cpp
@@ -1238,10 +1238,11 @@ at::Tensor shared_expert_cpu(
   //   for int8 w8a8:
   //   3. Aq_tmp : [M, K] or [M, N]
   //   4. As_tmp : [M]
+  //   5. C_splitk_tmp : [splitk_num * 2 * M * N]
   //
   //   for fp8 w8a16:
-  //   5. intermediate_cache0 : [M, 2N]
-  //   6. B_tmp : [T, BLOCK_M, max(K, N)]
+  //   6. intermediate_cache0 : [M, 2N]
+  //   7. B_tmp : [T, BLOCK_M, max(K, N)]
   //
   int num_threads = at::get_num_threads();
   int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
@@ -1262,6 +1263,10 @@ at::Tensor shared_expert_cpu(
   uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
   float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * N)));
 
+  constexpr int64_t SPLITK_NUM = get_splitk_num();
+  auto buffer_splitk = at::empty({SPLITK_NUM * 2 * M * N}, hidden_states.options().dtype(at::kFloat));
+  float* __restrict__ C_splitk_tmp = (float*)buffer_splitk.data_ptr();
+
   auto w1s = w1_scale.value();
   auto w2s = w2_scale.value();
   TORCH_CHECK(w1s.numel() == 2 * N);
@@ -1271,6 +1276,7 @@ at::Tensor shared_expert_cpu(
       out_hidden_states.data_ptr<scalar_t>(),
       intermediate_cache1,
       C_tmp,
+      C_splitk_tmp,
      Aq_tmp,
       As_tmp,
       hidden_states.data_ptr<scalar_t>(),
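Note on the new scratch buffer: `C_splitk_tmp` is laid out as `[SPLITK_NUM, 2, M, N]` — for every k-split, one float partial for the gate projection (C0) immediately followed by one for the up projection (C1). A minimal standalone sketch of the indexing this layout implies (the helper names here are illustrative, not part of the kernel):

```cpp
#include <cstdint>
#include <vector>

// Illustrative helpers mirroring the [SPLITK_NUM, 2, M, N] layout; the
// kernel itself just does the pointer arithmetic inline.
constexpr int64_t kSplitkNum = 4;  // matches get_splitk_num()

inline float* partial_c0(float* base, int64_t kb, int64_t M, int64_t N) {
  return base + kb * 2 * M * N;          // gate partial of split kb
}
inline float* partial_c1(float* base, int64_t kb, int64_t M, int64_t N) {
  return base + kb * 2 * M * N + M * N;  // up partial of split kb
}

int main() {
  const int64_t M = 8, N = 32;
  std::vector<float> buf(kSplitkNum * 2 * M * N, 0.f);
  // e.g. element (m=1, n=2) of the up-projection partial from split 3:
  float v = partial_c1(buf.data(), 3, M, N)[1 * N + 2];
  return v == 0.f ? 0 : 1;
}
```

Stage 1.b below reduces over the leading `SPLITK_NUM` axis, which is why `offset_reduce` is later passed as `2 * M * N`.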
diff --git a/sgl-kernel/csrc/cpu/moe_int8.cpp b/sgl-kernel/csrc/cpu/moe_int8.cpp
index 4ebbead1d2ab..8136b01aee2e 100644
--- a/sgl-kernel/csrc/cpu/moe_int8.cpp
+++ b/sgl-kernel/csrc/cpu/moe_int8.cpp
@@ -272,13 +272,166 @@ void tinygemm_kernel(
   }
 }
 
+/// gemm for w13, without silu_and_mul fusion (due to split-k)
+template <typename scalar_t, int BLOCK_M, int BLOCK_N>
+struct tinygemm_kernel_vnni_v2 {
+  static inline void apply(
+      const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1,
+      float* __restrict__ C0, float* __restrict__ C1,
+      const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
+      const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
+      int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_Bcomp) {
+    TORCH_CHECK(false, "tinygemm_kernel_vnni_v2: scalar path not implemented!");
+  }
+};
+
+#if defined(CPU_CAPABILITY_AVX512)
+template <int BLOCK_M, int BLOCK_N>
+struct tinygemm_kernel_vnni_v2<at::BFloat16, BLOCK_M, BLOCK_N> {
+  static inline void apply(
+      const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1,
+      float* __restrict__ C0, float* __restrict__ C1,
+      const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
+      const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
+      int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_Bcomp) {
+    constexpr int ROWS = BLOCK_M;
+    constexpr int COLS = BLOCK_N / 16;
+    static_assert(COLS % 2 == 0);
+
+    __m512i va;
+    __m512i vb0[COLS];
+    __m512i vb1[COLS];
+    __m512i vc0[ROWS * COLS];
+    __m512i vc1[ROWS * COLS];
+    __m512i vcomp0[COLS];
+    __m512i vcomp1[COLS];
+    __m512 vas;
+    __m512 vbs0[COLS];
+    __m512 vbs1[COLS];
+
+    auto loadc = [&](auto i) {
+      vc0[i] = _mm512_set1_epi32(0);
+      vc1[i] = _mm512_set1_epi32(0);
+    };
+    Unroll<ROWS * COLS>{}(loadc);
+
+    const int64_t K4 = K >> 2;
+    const int64_t lda4 = lda >> 2;
+    const int64_t ldb4 = ldb;  // ldb * 4 >> 2: B is packed 4 int8 per int32
+    const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
+    const int32_t* b0_ptr = reinterpret_cast<const int32_t*>(B0);
+    const int32_t* b1_ptr = reinterpret_cast<const int32_t*>(B1);
+
+    auto compute = [&](auto i, int64_t k) {
+      constexpr int row = i / COLS;
+      constexpr int col = i % COLS;
+
+      if constexpr (col == 0) {
+        va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
+      }
+      if constexpr (row == 0) {
+        vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16);
+        vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16);
+      }
+      vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]);
+      vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]);
+    };
+    for (int64_t k = 0; k < K4; ++k) {
+      Unroll<ROWS * COLS>{}(compute, k);
+    }
+
+    auto scalec_and_store = [&](auto i) {
+      constexpr int row = i / COLS;
+      constexpr int col = i % COLS;
+
+      // load a scale
+      if constexpr (col == 0) {
+        vas = _mm512_set1_ps(As[row]);
+      }
+      // load b scale and vcomp
+      if constexpr (row == 0) {
+        vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
+        vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
+        if (need_Bcomp) {
+          vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
+          vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
+        }
+      }
+      auto c0_ = need_Bcomp ? _mm512_sub_epi32(vc0[i], vcomp0[col]) : vc0[i];
+      auto c1_ = need_Bcomp ? _mm512_sub_epi32(vc1[i], vcomp1[col]) : vc1[i];
+      __m512 c0 = _mm512_cvtepi32_ps(c0_);
+      __m512 c1 = _mm512_cvtepi32_ps(c1_);
+      c0 = _mm512_mul_ps(_mm512_mul_ps(c0, vas), vbs0[col]);
+      _mm512_storeu_ps(reinterpret_cast<__m512*>(C0 + row * ldc + col * 16), c0);
+      c1 = _mm512_mul_ps(_mm512_mul_ps(c1, vas), vbs1[col]);
+      _mm512_storeu_ps(reinterpret_cast<__m512*>(C1 + row * ldc + col * 16), c1);
+    };
+    Unroll<ROWS * COLS>{}(scalec_and_store);
+  }
+};
+#endif
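`need_Bcomp` gates the subtraction of the precomputed compensation row. The correction is needed because `_mm512_dpbusd_epi32` multiplies unsigned by signed bytes; assuming the activations are stored with a +128 offset (a common u8/s8 trick — the exact zero point used by `quantize_row_int8` is not visible in this diff), the identity being exploited is sketched below:

```cpp
#include <cassert>
#include <cstdint>

// Scalar sketch of the zero-point compensation (illustrative assumption:
// a_u8 = a_s8 + 128, and Bcomp[n] = 128 * sum_k B[k][n] precomputed at
// weight-packing time).
int main() {
  const int K = 8;
  int8_t a_s8[K], b_s8[K];
  uint8_t a_u8[K];
  for (int k = 0; k < K; ++k) {
    a_s8[k] = static_cast<int8_t>(k - 4);
    b_s8[k] = static_cast<int8_t>(2 * k - 7);
    a_u8[k] = static_cast<uint8_t>(a_s8[k] + 128);  // offset encoding
  }
  int32_t acc_u8s8 = 0, ref = 0, bcomp = 0;
  for (int k = 0; k < K; ++k) {
    acc_u8s8 += static_cast<int32_t>(a_u8[k]) * b_s8[k];  // what dpbusd sums
    ref      += static_cast<int32_t>(a_s8[k]) * b_s8[k];  // what we want
    bcomp    += 128 * b_s8[k];                            // compensation
  }
  assert(acc_u8s8 - bcomp == ref);  // subtracting once recovers the result
  return 0;
}
```

Since the compensation term is a column sum over the full K and the split-k partials are added together before use, subtracting it in exactly one split suffices — the call site later passes `need_Bcomp = (kb == 0)`.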
+
+#define LAUNCH_TINYGEMM_KERNEL_VNNI_V2(MB_SIZE, NB_SIZE)                                \
+  tinygemm_kernel_vnni_v2<scalar_t, MB_SIZE, NB_SIZE>::apply(                           \
+      A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4,                         \
+      C0 + mb_start * ldc + nb_start, C1 + mb_start * ldc + nb_start, As + mb_start,    \
+      Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,             \
+      K, lda, ldb, ldc, need_Bcomp);
+
+template <typename scalar_t>
+void tinygemm_kernel(
+    const uint8_t* __restrict__ A,
+    const int8_t* __restrict__ B0,
+    const int8_t* __restrict__ B1,
+    const int8_t* __restrict__ B0_comp,
+    const int8_t* __restrict__ B1_comp,
+    float* __restrict__ C0,
+    float* __restrict__ C1,
+    bool need_Bcomp,
+    const float* __restrict__ As,
+    const float* __restrict__ Bs0,
+    const float* __restrict__ Bs1,
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc) {
+  const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0_comp);
+  const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1_comp);
+
+  // pattern: 1-(2+2)-(8+8)
+  constexpr int64_t BLOCK_M = 4;
+  constexpr int64_t BLOCK_N = 32;
+  const int64_t MB = div_up(M, BLOCK_M);
+  const int64_t NB = div_up(N, BLOCK_N);
+  for (int64_t mb = 0; mb < MB; ++mb) {
+    int64_t mb_start = mb * BLOCK_M;
+    int64_t mb_size = std::min(BLOCK_M, M - mb_start);
+    for (int64_t nb = 0; nb < NB; ++nb) {
+      int64_t nb_start = nb * BLOCK_N;
+      int64_t nb_size = std::min(BLOCK_N, N - nb_start);
+
+      switch (mb_size << 4 | nb_size >> 4) {
+        case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI_V2(1, 32); break;
+        case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI_V2(2, 32); break;
+        case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI_V2(3, 32); break;
+        case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI_V2(4, 32); break;
+        default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
+      }
+    }
+  }
+}
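The dispatch key packs both tile sizes into nibbles: `mb_size` in the high nibble and `nb_size / 16` in the low one. A small sketch verifying the encoding used by the switch above:

```cpp
#include <cassert>
#include <cstdint>

// Same key scheme as the switch above: low nibble = tile width in
// multiples of 16 lanes, next nibble = tile height in rows.
constexpr int64_t dispatch_key(int64_t mb_size, int64_t nb_size) {
  return mb_size << 4 | nb_size >> 4;
}

int main() {
  static_assert(dispatch_key(1, 32) == 0x12, "");
  static_assert(dispatch_key(4, 32) == 0x42, "");
  // A ragged N tail (e.g. nb_size = 16) yields 0x11 and falls through to
  // the TORCH_CHECK default — hence the requirement that N be a multiple
  // of the block size.
  assert(dispatch_key(1, 16) == 0x11);
  return 0;
}
```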
+
 /// gemm for w2
 template <typename scalar_t, int BLOCK_M, int BLOCK_N>
 struct tinygemm_kernel_vnni2 {
   static inline void apply(
       const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
       const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
-      int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
+      int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_Bcomp) {
     TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
   }
 };
@@ -289,7 +442,7 @@ struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> {
   static inline void apply(
       const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
       const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
-      int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
+      int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool need_Bcomp) {
     constexpr int ROWS = BLOCK_M;
     constexpr int COLS = BLOCK_N / 16;
@@ -343,11 +496,14 @@ struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> {
         if constexpr (col % 2 == 0) {
           vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16);
           vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
-          vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
-          vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
+          if (need_Bcomp) {
+            vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
+            vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
+          }
         }
       }
-      __m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col]));
+      auto x_ = need_Bcomp ? _mm512_sub_epi32(vc[i], vcomp[col]) : vc[i];
+      __m512 x = _mm512_cvtepi32_ps(x_);
       x = _mm512_mul_ps(_mm512_mul_ps(x, vas), vbs[col]);
       _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x);
     };
@@ -360,7 +516,7 @@ struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> {
   tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply(                             \
       A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start,              \
       As + mb_start, Bs + nb_start, Bcomp + nb_start,                                   \
-      K, lda, ldb, ldc);
+      K, lda, ldb, ldc, need_Bcomp);
 
 template <typename scalar_t>
 void tinygemm_kernel(
@@ -375,8 +531,8 @@ void tinygemm_kernel(
     const uint8_t* __restrict__ A,
     const int8_t* __restrict__ B,
     float* __restrict__ C,
     const float* __restrict__ As,
     const float* __restrict__ Bs,
     int64_t M,
     int64_t N,
     int64_t K,
     int64_t lda,
     int64_t ldb,
     int64_t ldc) {
-  // B compensation
+  bool need_Bcomp = true;
   const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
 
   // pattern: 1-4-16
@@ -402,6 +558,48 @@ void tinygemm_kernel(
     }
   }
 }
 
+template <typename scalar_t>
+void tinygemm_kernel(
+    const uint8_t* __restrict__ A,
+    const int8_t* __restrict__ B,
+    const int8_t* __restrict__ B_comp,
+    bool need_Bcomp,
+    float* __restrict__ C,
+    const float* __restrict__ As,
+    const float* __restrict__ Bs,
+    int64_t M,
+    int64_t N,
+    int64_t K,
+    int64_t lda,
+    int64_t ldb,
+    int64_t ldc) {
+  // B compensation
+  const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B_comp);
+
+  // pattern: 1-4-16
+  constexpr int64_t BLOCK_M = 4;
+  constexpr int64_t BLOCK_N = 64;
+  const int64_t MB = div_up(M, BLOCK_M);
+  const int64_t NB = div_up(N, BLOCK_N);
+  for (int64_t mb = 0; mb < MB; ++mb) {
+    int64_t mb_start = mb * BLOCK_M;
+    int64_t mb_size = std::min(BLOCK_M, M - mb_start);
+    for (int64_t nb = 0; nb < NB; ++nb) {
+      int64_t nb_start = nb * BLOCK_N;
+      int64_t nb_size = std::min(BLOCK_N, N - nb_start);
+
+      switch (mb_size << 4 | nb_size >> 4) {
+        case 0x14: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 64); break;
+        case 0x24: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 64); break;
+        case 0x34: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 64); break;
+        case 0x44: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 64); break;
+        default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
+      }
+    }
+  }
+}
+
 } // anonymous namespace
 
 template <typename scalar_t>
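Both launch macros advance columns by `nb_start * 4` (and the kernels use `ldb4 = ldb`) because the int8 weights are assumed to be in a VNNI layout: `B[K][N]` repacked as `B[K/4][N][4]`, so each 32-bit lane feeds `_mm512_dpbusd_epi32` four consecutive K values. An illustrative repacking sketch (the real packing happens at weight-prepack time, not in this file):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t K = 8, N = 16;  // K must be a multiple of 4
  std::vector<int8_t> B(K * N), Bv(K * N);
  for (int64_t k = 0; k < K; ++k)
    for (int64_t n = 0; n < N; ++n)
      B[k * N + n] = static_cast<int8_t>(k * N + n);
  // VNNI repack: element (k, n) lands at [k/4][n][k%4]
  for (int64_t k = 0; k < K; ++k)
    for (int64_t n = 0; n < N; ++n)
      Bv[(k / 4) * N * 4 + n * 4 + (k % 4)] = B[k * N + n];
  // one column step advances 4 int8 values in the packed buffer, which is
  // why the launch macros offset B by nb_start * 4
  assert(Bv[0 * N * 4 + 5 * 4 + 3] == B[3 * N + 5]);
  return 0;
}
```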
@@ -602,11 +800,64 @@ void fused_experts_int8_kernel_impl(
 INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16);
 INSTANTIATE_MOE_INT8_TEMPLATE(at::Half);
 
+// silu_and_mul:   shape                leading dimension
+//   input0        [m_size, BLOCK_N]    N
+//   input1        [m_size, BLOCK_N]    N
+//   output        [M, N]               N
+template <typename scalar_t, int BLOCK_N>
+inline void silu_and_mul(
+    scalar_t* __restrict__ output,
+    const float* __restrict__ input0,  // x: x0, x1
+    const float* __restrict__ input1,  // y: y0, y1
+    int64_t m_size,
+    int64_t N,
+    int64_t offset_reduce,
+    int64_t splitk_num) {
+  using bVec = at::vec::Vectorized<scalar_t>;
+  using fVec = at::vec::Vectorized<float>;
+
+  const fVec one = fVec(1.f);
+
+  // BLOCK_N is a multiple of bVec::size(), so no remainder handling
+  for (int64_t m = 0; m < m_size; ++m) {
+    scalar_t* __restrict__ out = output + m * N;
+    const float* __restrict__ x = input0 + m * N;
+    const float* __restrict__ y = input1 + m * N;
+
+    for (int64_t d = 0; d < BLOCK_N; d += bVec::size()) {
+      fVec x0 = fVec::loadu(x + d);
+      fVec x1 = fVec::loadu(x + d + fVec::size());
+      fVec y0 = fVec::loadu(y + d);
+      fVec y1 = fVec::loadu(y + d + fVec::size());
+      // reduce sum if splitk_num > 1
+      for (int64_t id = 1; id < splitk_num; id++) {
+        x0 += fVec::loadu(x + id * offset_reduce + d);
+        x1 += fVec::loadu(x + id * offset_reduce + d + fVec::size());
+        y0 += fVec::loadu(y + id * offset_reduce + d);
+        y1 += fVec::loadu(y + id * offset_reduce + d + fVec::size());
+      }
+
+      // silu
+      x0 = x0 / (one + x0.neg().exp_u20());
+      x1 = x1 / (one + x1.neg().exp_u20());
+      // mul
+      x0 = x0 * y0;
+      x1 = x1 * y1;
+      // convert
+      bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
+      out_vec.store(out + d);
+    }
+  }
+}
+
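A scalar reference for the vectorized `silu_and_mul` above can be handy for unit-testing the split-k reduction. This sketch keeps the output in float rather than the kernel's bf16/half store, and the names are illustrative:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

float silu(float v) { return v / (1.f + std::exp(-v)); }

// Reduce the split-k partials for both halves, apply SiLU to the gate
// half, multiply by the up half.
void silu_and_mul_ref(float* out, const float* x, const float* y,
                      int64_t n, int64_t offset_reduce, int64_t splitk_num) {
  for (int64_t d = 0; d < n; ++d) {
    float xs = x[d], ys = y[d];
    for (int64_t id = 1; id < splitk_num; ++id) {  // sum remaining partials
      xs += x[id * offset_reduce + d];
      ys += y[id * offset_reduce + d];
    }
    out[d] = silu(xs) * ys;
  }
}

int main() {
  const int64_t N = 4, SPLITK = 2, OFFSET = 8;  // partials strided by 2*M*N
  std::vector<float> x(OFFSET * SPLITK, 0.5f), y(OFFSET * SPLITK, 2.f), out(N);
  silu_and_mul_ref(out.data(), x.data(), y.data(), N, OFFSET, SPLITK);
  // each lane: silu(0.5 + 0.5) * (2 + 2) = silu(1) * 4
  assert(std::fabs(out[0] - silu(1.f) * 4.f) < 1e-6f);
  return 0;
}
```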
 template <typename scalar_t>
 void shared_expert_int8_kernel_impl(
     scalar_t* __restrict__ output,
     scalar_t* __restrict__ ic1,
     float* __restrict__ C_tmp,
+    float* __restrict__ C_splitk_tmp,
     uint8_t* __restrict__ Aq_tmp,
     float* __restrict__ As_tmp,
     const scalar_t* __restrict__ input,
@@ -619,7 +870,6 @@ void shared_expert_int8_kernel_impl(
     int64_t M,
     int64_t N,
     int64_t K) {
-
   // handle 2 tiles per block
   constexpr int64_t BLOCK_M = block_size_m();
   constexpr int64_t BLOCK_N = block_size_n();
@@ -635,10 +885,11 @@ void shared_expert_int8_kernel_impl(
     }
   });
 
-  // stage 1: intermediate_cache1 = silu(hidden_states @ w1)
+  // stage 1: intermediate_cache1 = silu(hidden_states @ w1)
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);
-
+  const int64_t SPLITK_NUM = get_splitk_num();
+  const int64_t K_SPLIT_SIZE = div_up(K, SPLITK_NUM);
   TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
 
   // K and N are packed for int8
@@ -647,45 +898,65 @@ void shared_expert_int8_kernel_impl(
   const int64_t packed_K = get_row_size(K);
   const int64_t stride_n = packed_K;
 
   // here we only parallel on half of 2N to fuse silu_and_mul with gemm
-  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+  at::parallel_for(0, SPLITK_NUM * MB * NB, 0, [&](int64_t begin, int64_t end) {
     for (int64_t i = begin; i < end; ++i) {
-      int64_t mb = i / NB;
-      int64_t nb = i % NB;
-
+      int64_t kb = i / (MB * NB);
+      int64_t mb = (i - kb * MB * NB) / NB;
+      int64_t nb = (i - kb * MB * NB) % NB;
+      float* __restrict__ C0 = C_splitk_tmp + kb * 2 * M * N;
+      float* __restrict__ C1 = C0 + M * N;
       // nb0 from top half and nb1 from bottom half
       int64_t nb0 = nb, nb1 = nb + NB;
       int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
       int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
+      int64_t k_size = std::min(K - kb * K_SPLIT_SIZE, K_SPLIT_SIZE);
 
-      // A shape [m_size, K]
+      // A shape [m_size, K] -> [m_size, K_SPLIT_SIZE] when splitk
       const uint8_t* A = Aq_tmp + mb * BLOCK_M * K;
       const float* As = As_tmp + mb * BLOCK_M;
 
-      // B shape [K, n_size] in vnni format
+      // B shape [K, n_size] in vnni format -> [K_SPLIT_SIZE, n_size] when splitk
      const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
       const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
       const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
       const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
 
-      // fused 1.b: silu_and_mul(A @ B0, A @ B1)
-      tinygemm_kernel<scalar_t>(
-          /* A   */ A,
-          /* B0  */ B0,
-          /* B1  */ B1,
-          /* C   */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
-          /* As  */ As,
-          /* Bs0 */ Bs0,
-          /* Bs1 */ Bs1,
-          /* M   */ m_size,
-          /* N   */ n_size,
-          /* K   */ K,
-          /* lda */ K,
-          /* ldb */ n_size,
-          /* ldc */ N);
+      // stage 1.a: GEMMs with split-k, partials accumulated into C_splitk_tmp
+      tinygemm_kernel<scalar_t>(
+          /* A      */ A + kb * K_SPLIT_SIZE,
+          /* B0     */ B0 + kb * K_SPLIT_SIZE * BLOCK_N,
+          /* B1     */ B1 + kb * K_SPLIT_SIZE * BLOCK_N,
+          /* B0comp */ B0 + BLOCK_N * K,
+          /* B1comp */ B1 + BLOCK_N * K,
+          /* C0     */ C0 + mb * BLOCK_M * N + nb * BLOCK_N,
+          /* C1     */ C1 + mb * BLOCK_M * N + nb * BLOCK_N,
+          /* need_Bcomp */ kb == 0,
+          /* As  */ As,
+          /* Bs0 */ Bs0,
+          /* Bs1 */ Bs1,
+          /* M   */ m_size,
+          /* N   */ n_size,
+          /* K   */ k_size,
+          /* lda */ K,
+          /* ldb */ n_size,
+          /* ldc */ N);
     }
   });
 
-  // stage 1.5: quantize ic1 to uint8, [M * topk, N]
+  // stage 1.b: reduce the split-k partials (layout [SPLITK_NUM, 2, M, N]), then silu_and_mul
+  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
+    for (int64_t i = begin; i < end; ++i) {
+      int64_t mb = i / NB;
+      int64_t nb = i % NB;
+      int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
+      float* __restrict__ C0 = C_splitk_tmp + mb * BLOCK_M * N + nb * BLOCK_N;
+      float* __restrict__ C1 = C_splitk_tmp + M * N + mb * BLOCK_M * N + nb * BLOCK_N;
+      silu_and_mul<scalar_t, BLOCK_N>(
+          ic1 + mb * BLOCK_M * N + nb * BLOCK_N, C0, C1, m_size, N, 2 * M * N, SPLITK_NUM);
+    }
+  });
+
+  // stage 1.c: quantize ic1 to uint8, [M, N]
   at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
     for (int64_t m = begin; m < end; ++m) {
       quantize_row_int8(
@@ -753,7 +1024,7 @@ void shared_expert_int8_kernel_impl(
 #define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE)                                   \
   template void shared_expert_int8_kernel_impl<TYPE>(                                   \
       TYPE* __restrict__ output, TYPE* __restrict__ ic1,                                \
-      float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp,                          \
+      float* __restrict__ C_tmp, float* __restrict__ C_splitk_tmp, uint8_t* __restrict__ Aq_tmp, \
       float* __restrict__ As_tmp, const TYPE* __restrict__ input,                       \
       const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2,       \
       const float* __restrict__ w1s, const float* __restrict__ w2s,                     \
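Stage 1.a flattens `(kb, mb, nb)` into a single index range so `at::parallel_for` can balance work across all three dimensions at once. A quick sketch checking that the decode used there is a bijection:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const int64_t SPLITK_NUM = 4, MB = 3, NB = 5;
  int64_t hits = 0;
  for (int64_t i = 0; i < SPLITK_NUM * MB * NB; ++i) {
    // same decode as stage 1.a
    int64_t kb = i / (MB * NB);
    int64_t mb = (i - kb * MB * NB) / NB;
    int64_t nb = (i - kb * MB * NB) % NB;
    assert(kb >= 0 && kb < SPLITK_NUM && mb < MB && nb < NB);
    // round-trip: the triple maps back to the flat index
    assert(i == (kb * MB + mb) * NB + nb);
    ++hits;
  }
  assert(hits == SPLITK_NUM * MB * NB);  // every task visited exactly once
  return 0;
}
```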