Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/sglang/srt/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def apply(
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

if layer.use_intel_amx_backend:
if layer.use_intel_amx_backend and type(layer.weight) is torch.Tensor:
return sgl_kernel.cpu.weight_packed_linear(x, layer.weight, bias)

return F.linear(x, layer.weight, bias)
Expand Down
17 changes: 16 additions & 1 deletion python/sglang/srt/layers/torchao_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def apply_torchao_config_to_model(
model: torch.nn.Module,
torchao_config: str,
filter_fn: Optional[Callable] = proj_filter,
device: Optional[str] = "cuda",
):
"""Quantize a model with torchao quantization specified by torchao_config

Expand All @@ -50,6 +51,7 @@ def apply_torchao_config_to_model(
128
"""
# Lazy import to suppress some warnings
from torchao.dtypes import Int4CPULayout
from torchao.quantization import (
float8_dynamic_activation_float8_weight,
float8_weight_only,
Expand All @@ -74,7 +76,20 @@ def apply_torchao_config_to_model(
128,
256,
], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
if device == "cuda":
quantize_(
model, int4_weight_only(group_size=group_size), filter_fn=filter_fn
)
elif device == "cpu":
quantize_(
model,
int4_weight_only(group_size=group_size, layout=Int4CPULayout()),
filter_fn=filter_fn,
)
else:
raise ValueError(
f"TorchAO only supports INT4 weight only on CUDA/CPU device but got: {device}"
)
elif "gemlite" in torchao_config:
# gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
# gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
Expand Down
4 changes: 3 additions & 1 deletion python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,9 @@ def initialize(self, min_per_gpu_memory: float):
# In layered loading, torchao may have been applied
if not torchao_applied:
apply_torchao_config_to_model(
self.model, global_server_args_dict["torchao_config"]
self.model,
global_server_args_dict["torchao_config"],
device=self.device,
)

# Apply torch TP if the model supports it
Expand Down
4 changes: 4 additions & 0 deletions sgl-kernel/csrc/cpu/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ void fused_experts_fp8_kernel_impl(
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
Expand Down Expand Up @@ -116,6 +118,8 @@ void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
Expand Down
164 changes: 79 additions & 85 deletions sgl-kernel/csrc/cpu/gemm_fp8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
#include "vec.h"
#include "gemm.h"

// we use 4x32 for BLOCK_M
#define BLOCK_SIZE_M_SCALE 4

namespace {

template <typename scalar_t>
Expand Down Expand Up @@ -60,33 +63,32 @@ inline void unpack_B(
constexpr int BLOCK_N = block_size_n();
static_assert(BLOCK_N == 32);

#pragma GCC unroll 4
for (int k = 0; k < K2; ++k) {
for (int n = 0; n < N; n += 64) { // BLOCK_N = 32
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + n);
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);

__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);

__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);

// Apply scale
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
// Apply scale
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));

f0_lo = _mm512_mul_ps(f0_lo, vd);
f0_hi = _mm512_mul_ps(f0_hi, vd);
f1_lo = _mm512_mul_ps(f1_lo, vd);
f1_hi = _mm512_mul_ps(f1_hi, vd);
f0_lo = _mm512_mul_ps(f0_lo, vd);
f0_hi = _mm512_mul_ps(f0_hi, vd);
f1_lo = _mm512_mul_ps(f1_lo, vd);
f1_hi = _mm512_mul_ps(f1_hi, vd);

bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);

_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 0, (__m512i)bf16_0);
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + n * 2 + 32, (__m512i)bf16_1);
}
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
}
#else
TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
Expand All @@ -112,12 +114,18 @@ struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BL
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;

const int KB = div_up(K, BLOCK_K);

// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;

__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
__m512 vsum[ROWS * COLS];

// block quant scale
__m512 vscale;

auto loadc = [&](auto i) {
constexpr int col = i % COLS;
Expand All @@ -129,7 +137,6 @@ struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BL
};
Unroll<ROWS * COLS>{}(loadc);

const int K2 = K >> 1;
const int lda2 = lda >> 1;
const int ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);
Expand All @@ -139,62 +146,50 @@ struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BL
constexpr int row = i / COLS;
constexpr int col = i % COLS;

int idx = k * 2 / block_size_K;
const __m512 vd = _mm512_set1_ps(scale[idx]);

if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
}
if constexpr (row == 0) {

if constexpr (col % 2 == 0) {
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}

__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);

__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);

// Apply scale
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));

f0_lo = _mm512_mul_ps(f0_lo, vd);
f0_hi = _mm512_mul_ps(f0_hi, vd);
f1_lo = _mm512_mul_ps(f1_lo, vd);
f1_hi = _mm512_mul_ps(f1_hi, vd);

vb[col + 0] = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
vb[col + 1] = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
}
}
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
};
for (int k = 0; k < K2; ++k) {
Unroll<ROWS * COLS>{}(compute, k);

constexpr int BLOCK_K2 = BLOCK_K >> 1;
for (int kb = 0; kb < KB; ++kb) {
int kb_start = kb * BLOCK_K2;
int kb_end = std::min(K, kb_start + BLOCK_K2);
// 1. load scale vector
vscale = _mm512_set1_ps(scale[kb]);
// 2. zero vsum for each block
Unroll<ROWS * COLS>{}([&](auto i) {
vsum[i] = _mm512_set1_ps(0.f);
});
// 3. accumulate across each block
for (int k = kb_start; k < kb_end; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
// 4. apply scale
Unroll<ROWS * COLS>{}([&](auto i) {
vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]);
});
}

auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 1, 3 use 256bit store
// for COLS = 2, 4 use 512bit store
if constexpr (COLS % 2 == 0) {
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
}
} else {
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(C + row * ldc + col * 16),
(__m256i)(_mm512_cvtneps_pbh(vc[i])));
// for COLS = 2,4 use 512bit store
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
}
};
Unroll<ROWS * COLS>{}(storec);
Expand Down Expand Up @@ -243,25 +238,22 @@ struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
int lda,
int ldb,
int ldc) {
constexpr int BLOCK_N = block_size_n();

// [BLOCK_K, BLOCK_N] -> [BLOCK_K / 2, BLOCK_N * 2]
const int ldb_tmp = block_size_n();
constexpr int BLOCK_N = block_size_n();

static_assert(BLOCK_K == 128);
// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
const int ldb_tmp = BLOCK_N;

// accumulate across K per BLOCK_K
for (int k = 0; k < K; k += BLOCK_K) {
int kb_size = std::min(BLOCK_K, K - k);

int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
unpack_B(Btmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);

const bool add_C = (k != 0);
at::native::cpublas::brgemm(
M, N, kb_size, lda, ldb_tmp, BLOCK_N, add_C, A + k, Btmp, Ctmp);
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
}

at::native::cpublas::brgemm(
M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);

// copy from Ctmp to C
for (int m = 0; m < M; ++m) {
if constexpr (has_bias) {
Expand Down Expand Up @@ -310,18 +302,10 @@ void tinygemm_kernel(
int64_t nb_size = std::min(BLOCK_N, N - nb_start);

switch(mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
// mb_size = 2
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
// mb_size = 3
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
// mb_size = 4
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
Expand All @@ -335,15 +319,17 @@ void fp8_scaled_mm_kernel_impl(
const at::Float8_e4m3fn* __restrict__ mat2,
const float* __restrict__ scales2,
const float* __restrict__ bias,
scalar_t* __restrict__ buffer,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideM,
int64_t out_strideM,
int64_t block_size_N,
int64_t block_size_K) {
int64_t block_size_K,
int64_t buffer_size_per_thread) {

constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
Expand All @@ -359,10 +345,9 @@ void fp8_scaled_mm_kernel_impl(
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);

// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
// for brgemm when mat2 is float8_e4m3
alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
int tid = at::get_thread_num();
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));

for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
Expand Down Expand Up @@ -470,6 +455,7 @@ at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& sca
int64_t block_size_N = block_size[0];
int64_t block_size_K = block_size[1];

constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
constexpr int64_t BLOCK_N = block_size_n();
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
Expand Down Expand Up @@ -498,20 +484,28 @@ at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& sca
bias_data = bias.value().data_ptr<float>();
}

// Btmp : [T, BLOCK_N * K]
// Ctmp : [T, BLOCK_M * BLOCK_N]
int num_threads = at::get_num_threads();
int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());

AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
fp8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<at::Float8_e4m3fn>(),
scales2.data_ptr<float>(),
bias_data,
buffer.data_ptr<scalar_t>(),
M,
N,
K,
mat1_strideM,
out_strideM,
block_size_N,
block_size_K);
block_size_K,
size_per_thread);
});

return out;
Expand Down
Loading