Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sgl-kernel/csrc/cpu/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
// block size for AMX gemm
constexpr int block_size_m() { return 2 * TILE_M; }
constexpr int block_size_n() { return 2 * TILE_N; }

// split-K factor; used to size the split-K scratch buffer
// (C_splitk_tmp holds [splitk_num * 2 * M * N] floats at the call site)
constexpr int get_splitk_num() { return 4; }
// define threshold using brgemm (intel AMX)
// NOTE(review): brgemm is only selected above a small M; presumably tiny-M
// cases are faster on the non-brgemm path — confirm against benchmarks
template <typename T> inline bool can_use_brgemm(int M);
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
Expand Down Expand Up @@ -100,6 +100,7 @@ void shared_expert_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
float* __restrict__ C_tmp,
float* __restrict__ C_splitk_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
Expand Down
10 changes: 8 additions & 2 deletions sgl-kernel/csrc/cpu/moe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1238,10 +1238,11 @@ at::Tensor shared_expert_cpu(
// for int8 w8a8:
// 3. Aq_tmp : [M, K] or [M, N]
// 4. As_tmp : [M]
// 5. C_splitk_tmp: [splitk_num * 2 * M * N]
//
// for fp8 w8a16:
// 5. intermediate_cache0 : [M, 2N]
// 6. B_tmp: [T, BLOCK_M, max(K, N)]
// 6. intermediate_cache0 : [M, 2N]
// 7. B_tmp: [T, BLOCK_M, max(K, N)]
//
int num_threads = at::get_num_threads();
int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
Expand All @@ -1262,6 +1263,10 @@ at::Tensor shared_expert_cpu(
uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * N)));

constexpr int64_t SPLITK_NUM = get_splitk_num();
auto buffer_splitk = at::empty({SPLITK_NUM * 2 * M * N}, hidden_states.options().dtype(at::kFloat));
float* __restrict__ C_spiltk_tmp = (float*)buffer_splitk.data_ptr<float>();

auto w1s = w1_scale.value();
auto w2s = w2_scale.value();
TORCH_CHECK(w1s.numel() == 2 * N);
Expand All @@ -1271,6 +1276,7 @@ at::Tensor shared_expert_cpu(
out_hidden_states.data_ptr<scalar_t>(),
intermediate_cache1,
C_tmp,
C_spiltk_tmp,
Aq_tmp,
As_tmp,
hidden_states.data_ptr<scalar_t>(),
Expand Down
Loading