diff --git a/sgl-kernel/csrc/cpu/bmm.cpp b/sgl-kernel/csrc/cpu/bmm.cpp
index 337d6d4c67a2..c834fd35e599 100644
--- a/sgl-kernel/csrc/cpu/bmm.cpp
+++ b/sgl-kernel/csrc/cpu/bmm.cpp
@@ -75,8 +75,8 @@ void bmm_kernel_impl(
 // out  : [B, M, N]
 // scale: [] 0-dim tensor for per tensor quant
 //
-void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni,
-    std::optional<at::Tensor>& scale) {
+void bmm_cpu(
+    at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale) {
   RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
 
   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
diff --git a/sgl-kernel/csrc/cpu/gemm.cpp b/sgl-kernel/csrc/cpu/gemm.cpp
index 5bee42ec0bfd..5a49a831babf 100644
--- a/sgl-kernel/csrc/cpu/gemm.cpp
+++ b/sgl-kernel/csrc/cpu/gemm.cpp
@@ -412,8 +412,8 @@ at::Tensor convert_weight_packed(at::Tensor& weight) {
 // bias : [N]
 // out  : [M, N]
 //
-at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
-    std::optional<at::Tensor>& bias, bool is_vnni) {
+at::Tensor
+weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni) {
   RECORD_FUNCTION(
       "sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
 
diff --git a/sgl-kernel/csrc/cpu/gemm_fp8.cpp b/sgl-kernel/csrc/cpu/gemm_fp8.cpp
index 420a31e67548..69609cb218ea 100644
--- a/sgl-kernel/csrc/cpu/gemm_fp8.cpp
+++ b/sgl-kernel/csrc/cpu/gemm_fp8.cpp
@@ -443,9 +443,14 @@ void tinygemm_kernel(
 INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
 INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
 
-at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
-    std::vector<int64_t> block_size, std::optional<at::Tensor>& bias,
-    at::ScalarType out_dtype, bool is_vnni) {
+at::Tensor fp8_scaled_mm_cpu(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    at::Tensor& scales2,
+    std::vector<int64_t> block_size,
+    const std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni) {
   RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
 
   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
diff --git a/sgl-kernel/csrc/cpu/gemm_int8.cpp b/sgl-kernel/csrc/cpu/gemm_int8.cpp
index c8ff0b251198..e0167daf8e54 100644
--- a/sgl-kernel/csrc/cpu/gemm_int8.cpp
+++ b/sgl-kernel/csrc/cpu/gemm_int8.cpp
@@ -310,9 +310,14 @@ std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
 // bias : [N]
 // out  : [M, N]
 //
-at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2,
-    at::Tensor& scales1, at::Tensor& scales2,
-    std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
+at::Tensor int8_scaled_mm_cpu(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    at::Tensor& scales1,
+    at::Tensor& scales2,
+    const std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni) {
   RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
 
   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
@@ -363,8 +368,13 @@ at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2,
 }
 
 // fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
-at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
-    std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
+at::Tensor int8_scaled_mm_with_quant(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    const at::Tensor& scales2,
+    const std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni) {
   RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
 
   auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp
index 1746928df993..24e78e750dbd 100644
--- a/sgl-kernel/csrc/cpu/moe.cpp
+++ b/sgl-kernel/csrc/cpu/moe.cpp
@@ -925,10 +925,10 @@ at::Tensor fused_experts_cpu(
     at::Tensor& topk_ids,
     bool inplace,
     bool use_int8_w8a8,
-    std::optional<at::Tensor>& w1_scale,
-    std::optional<at::Tensor>& w2_scale,
-    std::optional<at::Tensor>& a1_scale,
-    std::optional<at::Tensor>& a2_scale,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<at::Tensor>& a1_scale,
+    const std::optional<at::Tensor>& a2_scale,
     bool is_vnni) {
   RECORD_FUNCTION(
       "sgl-kernel::fused_experts_cpu", std::vector<c10::IValue>({hidden_states, w1, w2, topk_weights, topk_ids}));
@@ -1117,11 +1117,11 @@ at::Tensor shared_expert_cpu(
     bool inplace,
     bool use_int8_w8a8,
     bool use_fp8_w8a16,
-    std::optional<at::Tensor>& w1_scale,
-    std::optional<at::Tensor>& w2_scale,
-    std::optional<std::vector<int64_t>> block_size,
-    std::optional<at::Tensor>& a1_scale,
-    std::optional<at::Tensor>& a2_scale,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<std::vector<int64_t>> block_size,
+    const std::optional<at::Tensor>& a1_scale,
+    const std::optional<at::Tensor>& a2_scale,
     bool is_vnni) {
   RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector<c10::IValue>({hidden_states, w1, w2}));
 
diff --git a/sgl-kernel/csrc/cpu/qkv_proj.cpp b/sgl-kernel/csrc/cpu/qkv_proj.cpp
index ee15dec2332a..f78cfc5e1d9f 100644
--- a/sgl-kernel/csrc/cpu/qkv_proj.cpp
+++ b/sgl-kernel/csrc/cpu/qkv_proj.cpp
@@ -320,15 +320,19 @@ void rotary_emb_kernel_impl(
 }
 
 } // anonymous namespace
 
-extern at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
-    std::optional<at::Tensor>& bias, bool is_vnni);
+extern at::Tensor
+weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
 
-extern at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
-    std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni);
-
-extern void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni,
-    std::optional<at::Tensor>& scale);
+extern at::Tensor int8_scaled_mm_with_quant(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    const at::Tensor& scales2,
+    const std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni);
+extern void
+bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
 
 // NB: shapes in DeepDeek R1
 //
@@ -352,11 +356,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
     at::Tensor& cos_sin_cache,
     double eps,
     bool use_int8_w8a8,
-    std::optional<at::Tensor>& q_a_proj_scale,
-    std::optional<at::Tensor>& q_b_proj_scale,
-    std::optional<at::Tensor>& kv_a_proj_scale,
+    const std::optional<at::Tensor>& q_a_proj_scale,
+    const std::optional<at::Tensor>& q_b_proj_scale,
+    const std::optional<at::Tensor>& kv_a_proj_scale,
     bool is_vnni) {
-
   RECORD_FUNCTION("sgl-kernel::qkv_proj_with_rope", std::vector<c10::IValue>({
       hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));
diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
index a7f5b901c6e3..c9e454ee01d7 100644
--- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
+++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -13,12 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include
-#include
+#include
 #include
 
 #include "shm.h"
 
+#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
+
+#define PREPARE_MODULE_DEF(NAME) \
+  static struct PyModuleDef _##NAME##_module_def = {PyModuleDef_HEAD_INIT, #NAME, nullptr, 0, nullptr};
+
 // silu_and_mul
 at::Tensor silu_and_mul_cpu(at::Tensor& input);
 
@@ -54,32 +59,54 @@ at::Tensor convert_weight_packed(at::Tensor& weight);
 std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);
 
 // gemm
-at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
-    std::optional<at::Tensor>& bias, bool is_vnni);
+at::Tensor
+weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
 
 // igemm
-at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2,
-    at::Tensor& scales1, at::Tensor& scales2,
-    std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni);
+at::Tensor int8_scaled_mm_cpu(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    at::Tensor& scales1,
+    at::Tensor& scales2,
+    const std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni);
 
 // fp8 gemm
-at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2,
-    at::Tensor& scales2, std::vector<int64_t> block_size,
-    std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni);
+at::Tensor fp8_scaled_mm_cpu(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    at::Tensor& scales2,
+    std::vector<int64_t> block_size,
+    const std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni);
 
 // quant + igemm
-at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
-    std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni);
+at::Tensor int8_scaled_mm_with_quant(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    const at::Tensor& scales2,
+    const std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni);
 
 // bmm
-void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni,
-    std::optional<at::Tensor>& scale);
+void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
 
 // fused moe
-at::Tensor fused_experts_cpu(at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2,
-    at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace, bool use_int8_w8a8,
-    std::optional<at::Tensor>& w1_scale, std::optional<at::Tensor>& w2_scale,
-    std::optional<at::Tensor>& a1_scale, std::optional<at::Tensor>& a2_scale,
+at::Tensor fused_experts_cpu(
+    at::Tensor& hidden_states,
+    at::Tensor& w1,
+    at::Tensor& w2,
+    at::Tensor& topk_weights,
+    at::Tensor& topk_ids,
+    bool inplace,
+    bool use_int8_w8a8,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<at::Tensor>& a1_scale,
+    const std::optional<at::Tensor>& a2_scale,
     bool is_vnni);
 
 at::Tensor shared_expert_cpu(
@@ -91,20 +118,30 @@ at::Tensor shared_expert_cpu(
     bool inplace,
     bool use_int8_w8a8,
     bool use_fp8_w8a16,
-    std::optional<at::Tensor>& w1_scale,
-    std::optional<at::Tensor>& w2_scale,
-    std::optional<std::vector<int64_t>> block_size,
-    std::optional<at::Tensor>& a1_scale,
-    std::optional<at::Tensor>& a2_scale,
+    const std::optional<at::Tensor>& w1_scale,
+    const std::optional<at::Tensor>& w2_scale,
+    const std::optional<std::vector<int64_t>> block_size,
+    const std::optional<at::Tensor>& a1_scale,
+    const std::optional<at::Tensor>& a2_scale,
     bool is_vnni);
 
 // weight absorption
-std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(at::Tensor& hidden_states,
-    at::Tensor& q_a_proj_weight, at::Tensor& q_b_proj_weight, at::Tensor& kv_a_proj_weight,
-    at::Tensor& w_kc, at::Tensor& q_a_layernorm_weight, at::Tensor& kv_a_layernorm_weight,
-    at::Tensor& positions, at::Tensor& cos_sin_cache, double eps, bool use_int8_w8a8,
-    std::optional<at::Tensor>& q_a_proj_scale, std::optional<at::Tensor>& q_b_proj_scale,
-    std::optional<at::Tensor>& kv_a_proj_scale, bool is_vnni);
+std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
+    at::Tensor& hidden_states,
+    at::Tensor& q_a_proj_weight,
+    at::Tensor& q_b_proj_weight,
+    at::Tensor& kv_a_proj_weight,
+    at::Tensor& w_kc,
+    at::Tensor& q_a_layernorm_weight,
+    at::Tensor& kv_a_layernorm_weight,
+    at::Tensor& positions,
+    at::Tensor& cos_sin_cache,
+    double eps,
+    bool use_int8_w8a8,
+    const std::optional<at::Tensor>& q_a_proj_scale,
+    const std::optional<at::Tensor>& q_b_proj_scale,
+    const std::optional<at::Tensor>& kv_a_proj_scale,
+    bool is_vnni);
 
 // shared memory init
 void initialize(int size, int rank);
@@ -119,61 +156,110 @@ at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr
 rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
 
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   // activation
-  m.def("silu_and_mul_cpu", &silu_and_mul_cpu, "SiLU and mul for CPU");
+  m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
+  m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
 
   // norm
-  m.def("rmsnorm_cpu", &rmsnorm_cpu, "Root mean square normalization for CPU");
-  m.def("fused_add_rmsnorm_cpu", &fused_add_rmsnorm_cpu, "Fused add root mean square normalization for CPU");
+  m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
+  m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
+
+  m.def("fused_add_rmsnorm_cpu(Tensor input, Tensor residual, Tensor weight, float eps) -> ()");
+  m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
 
   // topk
-  m.def("grouped_topk_cpu", &grouped_topk_cpu, "Grouped TopK for CPU");
+  m.def(
+      "grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
+      "int topk_group) -> (Tensor, Tensor)");
+  m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu);
 
   // biased group topk
-  m.def("biased_grouped_topk_cpu", &biased_grouped_topk_cpu, "Biased Grouped TopK for CPU");
+  m.def(
+      "biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool "
+      "renormalize, int num_expert_group, int topk_group) -> (Tensor, Tensor)");
+  m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu);
 
   // decode
-  m.def("decode_attention_cpu", &decode_attention_cpu, "Attention decoding for CPU");
+  m.def(
+      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cache, Tensor output, Tensor key, Tensor value, "
+      "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
+      "float logit_cap) -> ()");
+  m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
 
   // extend
-  m.def("extend_attention_cpu", &extend_attention_cpu, "Attention extend for CPU");
+  m.def(
+      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor o_extend, Tensor k_buffer, "
+      "Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
+      "extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
+  m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);
 
   // weight prepack
-  m.def("convert_weight_packed", &convert_weight_packed, "prepack weight to vnni format for intel AMX");
+  m.def("convert_weight_packed(Tensor weight) -> Tensor");
+  m.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
 
   // quant
-  m.def("per_token_quant_int8_cpu", &per_token_quant_int8_cpu, "dynamic quantization for CPU");
+  m.def("per_token_quant_int8_cpu(Tensor A) -> (Tensor, Tensor)");
+  m.impl("per_token_quant_int8_cpu", torch::kCPU, &per_token_quant_int8_cpu);
 
   // gemm
-  m.def("weight_packed_linear", &weight_packed_linear, "weight packed linear for intel AMX");
+  m.def("weight_packed_linear(Tensor mat1, Tensor mat2, Tensor? bias, bool is_vnni) -> Tensor");
+  m.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
 
   // igemm
-  m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX");
+  m.def(
+      "int8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales1, Tensor scales2, Tensor? bias, ScalarType "
+      "out_dtype, bool is_vnni) -> Tensor");
+  m.impl("int8_scaled_mm_cpu", torch::kCPU, &int8_scaled_mm_cpu);
 
   // fp8 gemm
-  m.def("fp8_scaled_mm_cpu", &fp8_scaled_mm_cpu, "fp8 weight packed linear for intel AMX");
+  m.def(
+      "fp8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, int[] block_size, Tensor? bias, ScalarType "
+      "out_dtype, bool is_vnni) -> Tensor");
+  m.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);
 
   // quant + igemm
-  m.def("int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX");
+  m.def(
+      "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? bias, ScalarType out_dtype, bool "
+      "is_vnni) -> Tensor");
+  m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);
 
   // bmm
-  m.def("bmm_cpu", &bmm_cpu, "bmm kernel for intel AMX");
+  m.def("bmm_cpu(Tensor out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
+  m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);
 
   // moe
-  m.def("fused_experts_cpu", &fused_experts_cpu, "fused moe kernel for CPU");
+  m.def(
+      "fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
+      "inplace, bool use_int8_w8a8, Tensor? w1_scale, Tensor? w2_scale, Tensor? a1_scale, Tensor? a2_scale, bool "
+      "is_vnni) -> Tensor");
+  m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
 
   // weight absorption
-  m.def("qkv_proj_with_rope", &qkv_proj_with_rope, "fused qkv projection kernel with weight absorption for intel AMX");
+  m.def(
+      "qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
+      "kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
+      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, Tensor? q_a_proj_scale, Tensor? q_b_proj_scale, Tensor? "
+      "kv_a_proj_scale, bool is_vnni) -> (Tensor, Tensor, Tensor)");
+  m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
 
   // shared expert
-  m.def("shared_expert_cpu", &shared_expert_cpu, "shared expert kernel for CPU");
+  m.def(
+      "shared_expert_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor fused_experts_out, float "
+      "routed_scaling_factor, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? "
+      "w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor");
+  m.impl("shared_expert_cpu", torch::kCPU, &shared_expert_cpu);
+
+  // rope
+  m.def("rotary_position_embedding_cpu(Tensor t_pos, Tensor q_pe, Tensor k_pe, Tensor t_emb_pos) -> (Tensor, Tensor)");
+  m.impl("rotary_position_embedding_cpu", torch::kCPU, &rotary_position_embedding_cpu);
+}
+
+PREPARE_MODULE_DEF(common_ops)
+
+PYBIND11_MODULE(common_ops, m) {
   // all reduce
   m.def("initialize", &initialize, "shared memory initialization for CPU");
   m.def("shm_allreduce", &shm_allreduce, "low latency all_reduce implementation for CPU");
   m.def("shm_allgather", &shm_allgather, "low latency all_gather implementation for CPU");
-
-  // rope
-  m.def("rotary_position_embedding_cpu", &rotary_position_embedding_cpu, "rotary position embedding for CPU");
 }
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index da2959269332..ce67c0dffcd3 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -47,3 +47,5 @@
     tree_speculative_sampling_target_only,
 )
 from sgl_kernel.version import __version__
+
+from sgl_kernel.cpu import *
\ No newline at end of file
diff --git a/sgl-kernel/python/sgl_kernel/cpu.py b/sgl-kernel/python/sgl_kernel/cpu.py
index 4f6a6c783ad3..9f26d82fa30e 100644
--- a/sgl-kernel/python/sgl_kernel/cpu.py
+++ b/sgl-kernel/python/sgl_kernel/cpu.py
@@ -16,7 +16,7 @@ def fused_experts(
     a2_scale=None,
     is_vnni=True,
 ):
-    return sgl_kernel.common_ops.fused_experts_cpu(
+    return torch.ops.sgl_kernel.fused_experts_cpu(
         x,
         w13_weight,
         w2_weight,
@@ -48,7 +48,7 @@ def shared_expert(
     a2_scale=None,
     is_vnni=True,
 ):
-    return sgl_kernel.common_ops.shared_expert_cpu(
+    return torch.ops.sgl_kernel.shared_expert_cpu(
         hidden_states,
         w1,
         w2,
@@ -67,7 +67,7 @@ def shared_expert(
 
 
 def convert_weight_packed(weight):
-    return sgl_kernel.common_ops.convert_weight_packed(weight)
+    return torch.ops.sgl_kernel.convert_weight_packed(weight)
 
 
 def qkv_proj_with_rope(
@@ -87,7 +87,7 @@ def qkv_proj_with_rope(
     kv_a_proj_scale=None,
     is_vnni=True,
 ):
-    return sgl_kernel.common_ops.qkv_proj_with_rope(
+    return torch.ops.sgl_kernel.qkv_proj_with_rope(
         hidden_states,
         q_a_proj_weight,
         q_b_proj_weight,
@@ -121,7 +121,7 @@ def decode_attention(
     sm_scale,
     logit_cap=0.0,
 ):
-    sgl_kernel.common_ops.decode_attention_cpu(
+    torch.ops.sgl_kernel.decode_attention_cpu(
         q,
         k_buffer,
         v_buffer,
@@ -154,7 +154,7 @@ def extend_attention(
     sm_scale,
     logit_cap=0.0,
 ):
-    sgl_kernel.common_ops.extend_attention_cpu(
+    torch.ops.sgl_kernel.extend_attention_cpu(
         q_extend,
         k_extend,
         v_extend,
@@ -178,7 +178,7 @@ def weight_packed_linear(
     bias,
     is_vnni=True,
 ):
-    return sgl_kernel.common_ops.weight_packed_linear(
+    return torch.ops.sgl_kernel.weight_packed_linear(
         x,
         weight,
         bias,
@@ -194,7 +194,7 @@ def grouped_topk(
     num_expert_group,
     topk_group,
 ):
-    return sgl_kernel.common_ops.grouped_topk_cpu(
+    return torch.ops.sgl_kernel.grouped_topk_cpu(
         hidden_states,
         router_logits,
         top_k,
@@ -213,7 +213,7 @@ def biased_grouped_topk(
     num_expert_group,
     topk_group,
 ):
-    return sgl_kernel.common_ops.biased_grouped_topk_cpu(
+    return torch.ops.sgl_kernel.biased_grouped_topk_cpu(
         hidden_states,
         router_logits,
         bias,
@@ -230,7 +230,7 @@ def fused_add_rmsnorm(
     weight,
     eps,
 ):
-    sgl_kernel.common_ops.fused_add_rmsnorm_cpu(
+    torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
         input,
         residual,
         weight,
@@ -243,7 +243,7 @@ def rmsnorm(
     weight,
     eps,
 ):
-    return sgl_kernel.common_ops.rmsnorm_cpu(
+    return torch.ops.sgl_kernel.rmsnorm_cpu(
         input,
         weight,
         eps,
@@ -259,7 +259,7 @@ def int8_scaled_mm(
     out_dtype,
     is_vnni=True,
 ):
-    return sgl_kernel.common_ops.int8_scaled_mm_cpu(
+    return torch.ops.sgl_kernel.int8_scaled_mm_cpu(
         mat1, mat2, scales1, scales2, bias, out_dtype, is_vnni
     )
 
@@ -272,13 +272,13 @@ def int8_scaled_mm_with_quant(
     out_dtype,
     is_vnni=True,
 ):
-    return sgl_kernel.common_ops.int8_scaled_mm_with_quant(
+    return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
         mat1, mat2, scales2, bias, out_dtype, is_vnni
     )
 
 
 def per_token_quant_int8(x):
-    return sgl_kernel.common_ops.per_token_quant_int8_cpu(x)
+    return torch.ops.sgl_kernel.per_token_quant_int8_cpu(x)
 
 
 def fp8_scaled_mm(
@@ -290,7 +290,7 @@ def fp8_scaled_mm(
     out_dtype,
     is_vnni=True,
 ):
-    return sgl_kernel.common_ops.fp8_scaled_mm_cpu(
+    return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
         mat1, mat2, scales2, block_size, bias, out_dtype, is_vnni
     )
 
@@ -301,7 +301,7 @@ def rotary_position_embedding(
     k_pe,
     t_emb_pos,
 ):
-    return sgl_kernel.common_ops.rotary_position_embedding_cpu(
+    return torch.ops.sgl_kernel.rotary_position_embedding_cpu(
         t_pos,
         q_pe,
         k_pe,
@@ -312,8 +312,8 @@ def rotary_position_embedding(
 def silu_and_mul(
     input,
 ):
-    return sgl_kernel.common_ops.silu_and_mul_cpu(input)
+    return torch.ops.sgl_kernel.silu_and_mul_cpu(input)
 
 
 def bmm(out, mat1, mat2, is_vnni=True, scale=None):
-    return sgl_kernel.common_ops.bmm_cpu(out, mat1, mat2, is_vnni, scale)
+    return torch.ops.sgl_kernel.bmm_cpu(out, mat1, mat2, is_vnni, scale)
diff --git a/test/srt/test_rope.py b/test/srt/test_rope.py
index 4858e4a4b8cf..b3134adfcaee 100644
--- a/test/srt/test_rope.py
+++ b/test/srt/test_rope.py
@@ -1,7 +1,7 @@
 import unittest
 
 import expecttest
-import sgl_kernel.cpu
+import sgl_kernel
 import torch