diff --git a/python/sglang/srt/cpu_utils.py b/python/sglang/srt/cpu_utils.py
index c8f783460f2a..95d943abb12d 100644
--- a/python/sglang/srt/cpu_utils.py
+++ b/python/sglang/srt/cpu_utils.py
@@ -94,6 +94,14 @@ def _process_weight_after_loading(module, weight_names) -> None:
     )
 
 
+def is_cpu():
+    from sglang.srt.managers.schedule_batch import global_server_args_dict
+    return global_server_args_dict["device"] == "cpu"
+
+
+def is_cpu_amx():
+    return is_cpu() and cpu_has_amx_support()
+
 class PackWeightMethod:
     def __init__(self, weight_names):
         self.weight_names = weight_names
diff --git a/python/sglang/srt/custom_op.py b/python/sglang/srt/custom_op.py
index caa1da43ada2..7fb616aec40d 100644
--- a/python/sglang/srt/custom_op.py
+++ b/python/sglang/srt/custom_op.py
@@ -4,7 +4,7 @@
 from torch import nn
 
 from sglang.srt.utils import is_cuda, is_hip
-from sglang.srt.cpu_utils import cpu_has_amx_support
+from sglang.srt.cpu_utils import cpu_has_amx_support, is_cpu
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
@@ -42,7 +42,7 @@ def dispatch_forward(self):
             return self.forward_cuda
         elif _is_hip:
             return self.forward_hip
-        elif global_server_args_dict["device"] == "cpu":
+        elif is_cpu():
             return self.forward_cpu
         else:
             return self.forward_native
diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 4d30ed9b7ebd..fba5732f7bbe 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -26,7 +26,7 @@
 if is_cuda_available():
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
-from sglang.srt.cpu_utils import cpu_has_amx_support
+from sglang.srt.cpu_utils import cpu_has_amx_support, is_cpu_amx
 
 if cpu_has_amx_support():
     import sgl_kernel.cpu
@@ -178,7 +178,7 @@ def get_act_fn(
     return act_fn
 
 
-if not (is_cuda_available() or (not is_cuda_available() and cpu_has_amx_support())):
+if not (is_cuda_available() or is_cpu_amx()):
     logger.info(
         "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index fca3c5e1cd19..147e5b77266f 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -29,7 +29,7 @@
     rmsnorm,
 )
 
-from sglang.srt.cpu_utils import cpu_has_amx_support
+from sglang.srt.cpu_utils import cpu_has_amx_support, is_cpu_amx
 
 if cpu_has_amx_support():
     import sgl_kernel.cpu
@@ -136,7 +136,7 @@ def forward_cuda(
         return out
 
 
-if not (is_cuda_available() or (not is_cuda_available() and cpu_has_amx_support())):
+if not (is_cuda_available() or is_cpu_amx()):
     logger.info(
         "sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py
index a5c92daa43a3..e52ae20d2e32 100644
--- a/python/sglang/srt/layers/moe/topk.py
+++ b/python/sglang/srt/layers/moe/topk.py
@@ -18,7 +18,7 @@
 import torch.nn.functional as F
 
 from sglang.srt.utils import get_compiler_backend
-from sglang.srt.cpu_utils import cpu_has_amx_support
+from sglang.srt.cpu_utils import cpu_has_amx_support, is_cpu_amx
 
 if cpu_has_amx_support():
     import sgl_kernel.cpu
@@ -181,7 +181,7 @@ def select_experts(
         assert num_expert_group is not None
         if correction_bias is None:
             device = hidden_states.device
-            if device == torch.device("cpu") and cpu_has_amx_support():
+            if is_cpu_amx():
                 M = hidden_states.size(0)
                 topk_weights = torch.empty(
                     M, top_k, dtype=torch.float32, device=device
@@ -208,7 +208,7 @@ def select_experts(
             )
         else:
             device = hidden_states.device
-            if device == torch.device("cpu") and cpu_has_amx_support():
+            if is_cpu_amx():
                 M = hidden_states.size(0)
                 topk_weights = torch.empty(
                     M, top_k, dtype=torch.float32, device=device
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 77087cbb208a..5ac256d404ba 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -15,7 +15,7 @@
 if _is_cuda_available:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 
-from sglang.srt.cpu_utils import cpu_has_amx_support
+from sglang.srt.cpu_utils import cpu_has_amx_support, is_cpu_amx
 
 if cpu_has_amx_support():
     import sgl_kernel.cpu
@@ -725,7 +725,7 @@ def forward(
         positions = torch.add(positions, offsets) if offsets is not None else positions
 
         # TODO: Add scenario of self.rotary_dim < self.head_size
-        if positions.device == torch.device("cpu") and cpu_has_amx_support():
+        if is_cpu_amx():
             return sgl_kernel.cpu.rotary_position_embedding(
                 positions, query, key, self.cos_sin_cache)
         else:
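Below is a minimal, self-contained sketch of the predicate pattern this patch converges on; it is not part of the patch itself. `global_server_args_dict` and `cpu_has_amx_support` are mocked stand-ins here (the real ones live in `sglang.srt.managers.schedule_batch` and `sglang.srt.cpu_utils`), and `dispatch_forward` only labels the branches that `CustomOp.dispatch_forward` would return. One nuance the helpers introduce: call sites such as `rotary_embedding.py` previously compared a tensor's `.device` against `torch.device("cpu")`, whereas `is_cpu()` consults the server-wide `device` argument, so every call site agrees on the backend regardless of where a particular tensor lives.

```python
# Standalone sketch, runnable anywhere; the two stand-ins below are mocks
# of the real SGLang objects and are not the actual implementation.

global_server_args_dict = {"device": "cpu"}  # stand-in for SGLang's parsed server args


def cpu_has_amx_support() -> bool:
    # Stand-in: the real helper probes the CPU for Intel AMX instructions.
    return True


def is_cpu() -> bool:
    # Single source of truth: the configured backend, not a tensor's .device.
    return global_server_args_dict["device"] == "cpu"


def is_cpu_amx() -> bool:
    # CPU backend *and* AMX kernels usable, i.e. sgl_kernel.cpu applies.
    return is_cpu() and cpu_has_amx_support()


def dispatch_forward() -> str:
    # Mirrors the branch order of CustomOp.dispatch_forward after this patch.
    if is_cpu():
        return "forward_cpu"
    return "forward_native"


print(dispatch_forward())  # forward_cpu  (device == "cpu")
print(is_cpu_amx())        # True         (CPU backend with AMX)
```

The guard rewrite in `activation.py` and `layernorm.py` also leans on `A or (not A and B)` being equivalent to `A or B`: dropping the redundant `not is_cuda_available()` term leaves `is_cuda_available() or is_cpu_amx()`, which is both shorter and slightly stricter, since `is_cpu_amx()` additionally requires the configured device to be `cpu` before AMX support counts.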