From d84b272f657e37c2ad276c1871ae3b813329bf41 Mon Sep 17 00:00:00 2001 From: mingfeima Date: Tue, 22 Apr 2025 23:29:46 -0700 Subject: [PATCH 1/2] improve fp8 performance for gemm and fused_moe, shared_moe 1. brgemm impl: move brgemm out of inner loop 2. avx512 impl: move scaling out of inner loop 3. fp8_scaled_mm: change BLOCK_M to 128 to reduce access to B 4. cvt_fp8_bf16: ignore NaN handling ``` Comparing: True max_diff = 0.01562, asum = 10.562, bsum = 10.375 gemm_bf16(native): 89.812 us, gemm_fp8(opt): 124.585 us Comparing: True max_diff = 0.01562, asum = -32.500, bsum = -32.750 gemm_bf16(native): 83.805 us, gemm_fp8(opt): 125.586 us Comparing: True max_diff = 0.01562, asum = -35.750, bsum = -36.500 gemm_bf16(native): 89.579 us, gemm_fp8(opt): 151.284 us Comparing: True max_diff = 0.03125, asum = 4512.000, bsum = 4512.000 gemm_bf16(native): 262.104 us, gemm_fp8(opt): 615.823 us ``` ``` Comparing: True max_diff = 0.01562, asum = 10.562, bsum = 10.375 gemm_bf16(native): 86.403 us, gemm_fp8(opt): 95.792 us Comparing: True max_diff = 0.01562, asum = -32.500, bsum = -32.750 gemm_bf16(native): 84.178 us, gemm_fp8(opt): 100.573 us Comparing: True max_diff = 0.01562, asum = -35.750, bsum = -36.500 gemm_bf16(native): 90.365 us, gemm_fp8(opt): 114.198 us Comparing: True max_diff = 0.03125, asum = 4512.000, bsum = 4512.000 gemm_bf16(native): 267.053 us, gemm_fp8(opt): 404.231 us ``` --- sgl-kernel/csrc/cpu/gemm.h | 4 + sgl-kernel/csrc/cpu/gemm_fp8.cpp | 164 +++++++++++++++---------------- sgl-kernel/csrc/cpu/moe.cpp | 17 +++- sgl-kernel/csrc/cpu/moe_fp8.cpp | 36 +++---- sgl-kernel/csrc/cpu/vec.h | 18 +++- 5 files changed, 132 insertions(+), 107 deletions(-) diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h index daab64e05b3a..fc6f045fe3b0 100644 --- a/sgl-kernel/csrc/cpu/gemm.h +++ b/sgl-kernel/csrc/cpu/gemm.h @@ -74,6 +74,8 @@ void fused_experts_fp8_kernel_impl( scalar_t* __restrict__ ic1, scalar_t* __restrict__ ic2, scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, const scalar_t* __restrict__ input, const at::Float8_e4m3fn* __restrict__ packed_w1, const at::Float8_e4m3fn* __restrict__ packed_w2, @@ -116,6 +118,8 @@ void shared_expert_fp8_kernel_impl( scalar_t* __restrict__ output, scalar_t* __restrict__ ic0, scalar_t* __restrict__ ic1, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, const scalar_t* __restrict__ input, const at::Float8_e4m3fn* __restrict__ packed_w1, const at::Float8_e4m3fn* __restrict__ packed_w2, diff --git a/sgl-kernel/csrc/cpu/gemm_fp8.cpp b/sgl-kernel/csrc/cpu/gemm_fp8.cpp index 420a31e67548..4b3822113493 100644 --- a/sgl-kernel/csrc/cpu/gemm_fp8.cpp +++ b/sgl-kernel/csrc/cpu/gemm_fp8.cpp @@ -2,6 +2,9 @@ #include "vec.h" #include "gemm.h" +// we use 4x32 for BLOCK_M +#define BLOCK_SIZE_M_SCALE 4 + namespace { template @@ -60,33 +63,32 @@ inline void unpack_B( constexpr int BLOCK_N = block_size_n(); static_assert(BLOCK_N == 32); +#pragma GCC unroll 4 for (int k = 0; k < K2; ++k) { - for (int n = 0; n < N; n += 64) { // BLOCK_N = 32 - __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + n); + __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2); - __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0); - __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1); + __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0); + __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1); - __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0); - __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1); + __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0); + __m512bh bf16_1 = 
CVT_FP8_TO_BF16(b8_1); - // Apply scale - __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0)); - __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1)); - __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0)); - __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1)); + // Apply scale + __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0)); + __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1)); + __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0)); + __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1)); - f0_lo = _mm512_mul_ps(f0_lo, vd); - f0_hi = _mm512_mul_ps(f0_hi, vd); - f1_lo = _mm512_mul_ps(f1_lo, vd); - f1_hi = _mm512_mul_ps(f1_hi, vd); + f0_lo = _mm512_mul_ps(f0_lo, vd); + f0_hi = _mm512_mul_ps(f0_hi, vd); + f1_lo = _mm512_mul_ps(f1_lo, vd); + f1_hi = _mm512_mul_ps(f1_hi, vd); - bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo); - bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo); + bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo); + bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo); - _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 0, (__m512i)bf16_0); - _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 32, (__m512i)bf16_1); - } + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0); + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1); } #else TORCH_CHECK(false, "unpack_B: scalar path not implemented!"); @@ -112,12 +114,18 @@ struct tinygemm_kernel_nn{}(loadc); - const int K2 = K >> 1; const int lda2 = lda >> 1; const int ldb2 = ldb; // ldb * 2 >> 1; const float* a_ptr = reinterpret_cast(A); @@ -139,62 +146,50 @@ struct tinygemm_kernel_nn 0) { _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); } - - __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0); - __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1); - - __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0); - __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1); - - // Apply scale - __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0)); - __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1)); - __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0)); - __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1)); - - f0_lo = _mm512_mul_ps(f0_lo, vd); - f0_hi = _mm512_mul_ps(f0_hi, vd); - f1_lo = _mm512_mul_ps(f1_lo, vd); - f1_hi = _mm512_mul_ps(f1_hi, vd); - - vb[col + 0] = _mm512_cvtne2ps_pbh(f0_hi, f0_lo); - vb[col + 1] = _mm512_cvtne2ps_pbh(f1_hi, f1_lo); + vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0)); + vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1)); } } - vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]); }; - for (int k = 0; k < K2; ++k) { - Unroll{}(compute, k); + + constexpr int BLOCK_K2 = BLOCK_K >> 1; + for (int kb = 0; kb < KB; ++kb) { + int kb_start = kb * BLOCK_K2; + int kb_end = std::min(K, kb_start + BLOCK_K2); + // 1. load scale vector + vscale = _mm512_set1_ps(scale[kb]); + // 2. zero vsum for each block + Unroll{}([&](auto i) { + vsum[i] = _mm512_set1_ps(0.f); + }); + // 3. accumulate across each block + for (int k = kb_start; k < kb_end; ++k) { + Unroll{}(compute, k); + } + // 4. 
apply scale + Unroll{}([&](auto i) { + vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); + }); } auto storec = [&](auto i) { constexpr int row = i / COLS; constexpr int col = i % COLS; - // for COLS = 1, 3 use 256bit store - // for COLS = 2, 4 use 512bit store - if constexpr (COLS % 2 == 0) { - if constexpr (col % 2 == 0) { - _mm512_storeu_si512( - reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), - (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); - } - } else { - _mm256_storeu_si256( - reinterpret_cast<__m256i*>(C + row * ldc + col * 16), - (__m256i)(_mm512_cvtneps_pbh(vc[i]))); + // for COLS = 2,4 use 512bit store + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); } }; Unroll{}(storec); @@ -243,25 +238,22 @@ struct brgemm { int lda, int ldb, int ldc) { - constexpr int BLOCK_N = block_size_n(); - // [BLOCK_K, BLOCK_N] -> [BLOCK_K / 2, BLOCK_N * 2] - const int ldb_tmp = block_size_n(); + constexpr int BLOCK_N = block_size_n(); - static_assert(BLOCK_K == 128); + // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] + const int ldb_tmp = BLOCK_N; - // accumulate across K per BLOCK_K for (int k = 0; k < K; k += BLOCK_K) { int kb_size = std::min(BLOCK_K, K - k); int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128 - unpack_B(Btmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); - - const bool add_C = (k != 0); - at::native::cpublas::brgemm( - M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, A + k, Btmp, Ctmp); + unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); } + at::native::cpublas::brgemm( + M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); + // copy from Ctmp to C for (int m = 0; m < M; ++m) { if constexpr (has_bias) { @@ -310,18 +302,10 @@ void tinygemm_kernel( int64_t nb_size = std::min(BLOCK_N, N - nb_start); switch(mb_size << 4 | nb_size >> 4) { - // mb_size = 1 case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; - case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; - // mb_size = 2 case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; - case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; - // mb_size = 3 case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; - case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; - // mb_size = 4 case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; - case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); } } @@ -335,15 +319,17 @@ void fp8_scaled_mm_kernel_impl( const at::Float8_e4m3fn* __restrict__ mat2, const float* __restrict__ scales2, const float* __restrict__ bias, + scalar_t* __restrict__ buffer, int64_t M, int64_t N, int64_t K, int64_t mat1_strideM, int64_t out_strideM, int64_t block_size_N, - int64_t block_size_K) { + int64_t block_size_K, + int64_t buffer_size_per_thread) { - constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; constexpr int64_t BLOCK_N = block_size_n(); const int64_t MB = div_up(M, BLOCK_M); const int64_t NB = div_up(N, BLOCK_N); @@ -359,10 +345,9 @@ void fp8_scaled_mm_kernel_impl( int64_t mb{0}, nb{0}; data_index_init(begin, mb, MB, nb, NB); - // for brgemm, use float32 for accumulate - alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; - // for brgemm when mat2 is float8_e4m3 - alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K]; + int tid = at::get_thread_num(); + scalar_t* __restrict__ Btmp = 
buffer + tid * buffer_size_per_thread; + float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K)); for (int64_t i = begin; i < end; ++i) { UNUSED(i); @@ -470,6 +455,7 @@ at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& sca int64_t block_size_N = block_size[0]; int64_t block_size_K = block_size[1]; + constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; constexpr int64_t BLOCK_N = block_size_n(); TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N"); TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K"); @@ -498,6 +484,12 @@ at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& sca bias_data = bias.value().data_ptr(); } + // Btmp : [T, BLOCK_N * K] + // Ctmp : [T, BLOCK_M * BLOCK_N] + int num_threads = at::get_num_threads(); + int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2; + auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { fp8_scaled_mm_kernel_impl( out.data_ptr(), @@ -505,13 +497,15 @@ at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& sca packed_w.data_ptr(), scales2.data_ptr(), bias_data, + buffer.data_ptr(), M, N, K, mat1_strideM, out_strideM, block_size_N, - block_size_K); + block_size_K, + size_per_thread); }); return out; diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp index f596d015129b..3344b24cb96f 100644 --- a/sgl-kernel/csrc/cpu/moe.cpp +++ b/sgl-kernel/csrc/cpu/moe.cpp @@ -1058,7 +1058,8 @@ at::Tensor fused_experts_cpu( // 6. As_tmp : [M * topk] // // for fp8 w8a16: - // 7. intermediate_cache1 : [M * topk, 2N] + // 7. intermediate_cache0 : [M * topk, 2N] + // 8. B_tmp : [T, BLOCK_N, std::max(K, N)] // int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 + num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) + @@ -1068,7 +1069,7 @@ at::Tensor fused_experts_cpu( buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float); } if (use_fp8_w8a16) { - buffer_size_nbytes += M * topk * 2 * N * 2; + buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2; } auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); @@ -1114,7 +1115,9 @@ at::Tensor fused_experts_cpu( } else if (use_fp8_w8a16) { // here we just ignore C_tmp as it is not used scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K)); - scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * topk * 2 * N)); CHECK_MOE_SCALES_FP8(1, 2); fused_experts_fp8_kernel_impl( @@ -1123,6 +1126,8 @@ at::Tensor fused_experts_cpu( intermediate_cache1, intermediate_cache2, A_tmp, + B_tmp, + C_tmp, hidden_states.data_ptr(), packed_w1.data_ptr(), packed_w2.data_ptr(), @@ -1236,6 +1241,7 @@ at::Tensor shared_expert_cpu( // // for fp8 w8a16: // 5. intermediate_cache0 : [M, 2N] + // 6. 
B_tmp: [T, BLOCK_M, max(K, N)] // int num_threads = at::get_num_threads(); int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); @@ -1244,7 +1250,7 @@ at::Tensor shared_expert_cpu( buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float); } if (use_fp8_w8a16) { - buffer_size_nbytes += M * 2 * N * 2; + buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2; } auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); @@ -1279,12 +1285,15 @@ at::Tensor shared_expert_cpu( K); } else if (use_fp8_w8a16) { scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * 2 * N)); CHECK_MOE_SCALES_FP8(0, 1); shared_expert_fp8_kernel_impl( out_hidden_states.data_ptr(), intermediate_cache0, intermediate_cache1, + B_tmp, + C_tmp, hidden_states.data_ptr(), packed_w1.data_ptr(), packed_w2.data_ptr(), diff --git a/sgl-kernel/csrc/cpu/moe_fp8.cpp b/sgl-kernel/csrc/cpu/moe_fp8.cpp index 3d317ab4cea0..cfe37b0dc1ab 100644 --- a/sgl-kernel/csrc/cpu/moe_fp8.cpp +++ b/sgl-kernel/csrc/cpu/moe_fp8.cpp @@ -145,6 +145,8 @@ void fused_experts_fp8_kernel_impl( scalar_t* __restrict__ ic1, scalar_t* __restrict__ ic2, scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, const scalar_t* __restrict__ input, const at::Float8_e4m3fn* __restrict__ packed_w1, const at::Float8_e4m3fn* __restrict__ packed_w2, @@ -182,9 +184,6 @@ void fused_experts_fp8_kernel_impl( int tid = at::get_thread_num(); scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; - alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K]; - alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; - bool is_brgemm_used = false; for (int64_t i = begin; i < end; ++i) { @@ -215,8 +214,8 @@ void fused_experts_fp8_kernel_impl( /* A */ A, /* B */ B, /* C */ ic0 + offset * 2 * N + nb * BLOCK_N, - /* Btmp */ Btmp, - /* Ctmp */ Ctmp, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, /* scale */ Bs, /* M */ m_size, /* N */ n_size, @@ -257,9 +256,8 @@ void fused_experts_fp8_kernel_impl( // parallel on [MB2, NB2] at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { - alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N]; + int tid = at::get_thread_num(); alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; - alignas(64) float Ctmp[BLOCK_M * BLOCK_K]; bool is_brgemm_used = false; @@ -287,8 +285,8 @@ void fused_experts_fp8_kernel_impl( /* A */ A, /* B */ B, /* C */ C, - /* Btmp */ Btmp, - /* Ctmp */ Ctmp, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, /* scale */ Bs, /* M */ m_size, /* N */ n_size, @@ -329,6 +327,8 @@ void fused_experts_fp8_kernel_impl( TYPE* __restrict__ ic1, \ TYPE* __restrict__ ic2, \ TYPE* __restrict__ A_tmp, \ + TYPE* __restrict__ B_tmp, \ + float* __restrict__ C_tmp, \ const TYPE* __restrict__ input, \ const at::Float8_e4m3fn* __restrict__ packed_w1, \ const at::Float8_e4m3fn* __restrict__ packed_w2, \ @@ -355,6 +355,8 @@ void shared_expert_fp8_kernel_impl( scalar_t* __restrict__ output, scalar_t* __restrict__ ic0, scalar_t* __restrict__ ic1, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, const scalar_t* __restrict__ input, const at::Float8_e4m3fn* __restrict__ packed_w1, const at::Float8_e4m3fn* __restrict__ packed_w2, @@ -380,8 +382,7 @@ void shared_expert_fp8_kernel_impl( const bool use_brgemm 
= can_use_brgemm(M);
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
- alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
- alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
+ int tid = at::get_thread_num();
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
@@ -393,8 +394,8 @@ void shared_expert_fp8_kernel_impl(
/* A */ input + mb * BLOCK_M * K,
/* B */ packed_w1 + nb * BLOCK_N * K,
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
- /* Btmp */ Btmp,
- /* Ctmp */ Ctmp,
+ /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
+ /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size,
/* N */ n_size,
@@ -432,9 +433,8 @@ void shared_expert_fp8_kernel_impl(
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
- alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N];
+ int tid = at::get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
- alignas(64) float Ctmp[BLOCK_M * BLOCK_K];
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
@@ -447,8 +447,8 @@ void shared_expert_fp8_kernel_impl(
/* A */ ic1 + mb * BLOCK_M * N,
/* B */ packed_w2 + nb * BLOCK_N * N,
/* C */ C,
- /* Btmp */ Btmp,
- /* Ctmp */ Ctmp,
+ /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
+ /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size,
/* N */ n_size,
@@ -478,6 +478,8 @@ void shared_expert_fp8_kernel_impl(
TYPE* __restrict__ output, \
TYPE* __restrict__ ic0, \
TYPE* __restrict__ ic1, \
+ TYPE* __restrict__ B_tmp, \
+ float* __restrict__ C_tmp, \
const TYPE* __restrict__ input, \
const at::Float8_e4m3fn* __restrict__ packed_w1, \
const at::Float8_e4m3fn* __restrict__ packed_w2, \
diff --git a/sgl-kernel/csrc/cpu/vec.h b/sgl-kernel/csrc/cpu/vec.h
index 2b0a499e3fb3..7af877b44615 100644
--- a/sgl-kernel/csrc/cpu/vec.h
+++ b/sgl-kernel/csrc/cpu/vec.h
@@ -32,6 +32,22 @@ inline Vectorized convert_from_float_ext(const Vecto
#define CVT_FP16_TO_FP32(a) \
_mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
+// this doesn't handle NaN.
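+// Descriptive notes on the conversion below: e4m3 is 1 sign | 4 exponent (bias 7) | 3 mantissa,
+// bf16 is 1 sign | 8 exponent (bias 127) | 7 mantissa. Each fp8 byte is zero-extended to 16 bits,
+// the 3 mantissa bits are shifted into the top of the bf16 mantissa field, the exponent is
+// rebiased by adding 120 (127 - 7) and shifted into place, and the sign bit is moved to bit 15.
+// Lanes whose input byte is zero are cleared at the end; NaN encodings (and denormals) get no
+// special treatment, which is why NaN handling is skipped on this path.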
+inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { + const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + + const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4); + const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3); + const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7); + const __m512i nonsign = _mm512_or_si512(exp, mant); + + const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8); + const __m512i combined = _mm512_or_si512(nonsign, sign); + + const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512()); + return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined); +} + inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) { // The following conversion is without denorm behavior, that is to say, // Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6) @@ -86,7 +102,7 @@ inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) { inline __m512bh CVT_FP8_TO_BF16(__m256i a) { #ifdef SGLANG_CPU_FP8_CVT_FTZ - return cvt_e4m3_bf16_intrinsic_without_denorm(a); + return cvt_e4m3_bf16_intrinsic_no_nan(a); #else return cvt_e4m3_bf16_intrinsic_with_denorm(a); #endif From 9acd754e28dba8505398877185883ee1a57a198b Mon Sep 17 00:00:00 2001 From: "Meng, Hengyu" Date: Wed, 23 Apr 2025 08:44:26 +0000 Subject: [PATCH 2/2] enable torchao on cpu side --- python/sglang/srt/layers/linear.py | 2 +- python/sglang/srt/layers/torchao_utils.py | 17 ++++++++++++++++- .../sglang/srt/model_executor/model_runner.py | 4 +++- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index fbb7f58d048c..5c31877ac272 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -183,7 +183,7 @@ def apply( bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - if layer.use_intel_amx_backend: + if layer.use_intel_amx_backend and type(layer.weight) is torch.Tensor: return sgl_kernel.cpu.weight_packed_linear(x, layer.weight, bias) return F.linear(x, layer.weight, bias) diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index e08abd5ae1d5..ba373bc44b6e 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -40,6 +40,7 @@ def apply_torchao_config_to_model( model: torch.nn.Module, torchao_config: str, filter_fn: Optional[Callable] = proj_filter, + device: Optional[str] = "cuda", ): """Quantize a modelwith torchao quantization specified by torchao_config @@ -50,6 +51,7 @@ def apply_torchao_config_to_model( 128 """ # Lazy import to suppress some warnings + from torchao.dtypes import Int4CPULayout from torchao.quantization import ( float8_dynamic_activation_float8_weight, float8_weight_only, @@ -74,7 +76,20 @@ def apply_torchao_config_to_model( 128, 256, ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}" - quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn) + if device == "cuda": + quantize_( + model, int4_weight_only(group_size=group_size), filter_fn=filter_fn + ) + elif device == "cpu": + quantize_( + model, + int4_weight_only(group_size=group_size, layout=Int4CPULayout()), + filter_fn=filter_fn, + ) + else: + raise ValueError( + f"TorchAO only supports INT4 weight only on CUDA/CPU device but got: {device}" + ) elif "gemlite" in torchao_config: # gemlite--- or # gemlite-- 
(packing_bitwidth defaults to 32) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0ba5f8b419b9..09fa276ab416 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -188,7 +188,9 @@ def initialize(self, min_per_gpu_memory: float): # In layered loading, torchao may have been applied if not torchao_applied: apply_torchao_config_to_model( - self.model, global_server_args_dict["torchao_config"] + self.model, + global_server_args_dict["torchao_config"], + device=self.device, ) # Apply torch TP if the model supports it
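A minimal usage sketch of the new `device` argument (not part of the patch; the `int4wo-<group_size>` config string format and the `model` placeholder are assumptions for illustration):

```python
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model

model = ...  # loaded nn.Module whose projection layers should be quantized
# On CPU the int4 weight-only path now uses torchao's Int4CPULayout;
# any device other than "cuda" or "cpu" raises a ValueError.
apply_torchao_config_to_model(model, "int4wo-128", device="cpu")
```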