Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/sglang/srt/cpu_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ def _process_weight_after_loading(module, weight_names) -> None:
)


def is_cpu():
    """Return True when the server is configured to run on the CPU device.

    The import is deferred to call time to avoid a circular import between
    ``cpu_utils`` and ``schedule_batch`` at module load.
    """
    from sglang.srt.managers.schedule_batch import global_server_args_dict

    device = global_server_args_dict["device"]
    return device == "cpu"


def is_cpu_amx():
    """Return True when running on a CPU device that has Intel AMX support."""
    if not is_cpu():
        return False
    return cpu_has_amx_support()

class PackWeightMethod:
def __init__(self, weight_names):
self.weight_names = weight_names
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn

from sglang.srt.utils import is_cuda, is_hip
from sglang.srt.cpu_utils import cpu_has_amx_support
from sglang.srt.cpu_utils import cpu_has_amx_support, is_cpu

_is_cuda = is_cuda()
_is_hip = is_hip()
Expand Down Expand Up @@ -42,7 +42,7 @@ def dispatch_forward(self):
return self.forward_cuda
elif _is_hip:
return self.forward_hip
elif global_server_args_dict["device"] == "cpu":
elif is_cpu():
return self.forward_cpu
else:
return self.forward_native
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
if is_cuda_available():
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul

from sglang.srt.cpu_utils import cpu_has_amx_support
from sglang.srt.cpu_utils import cpu_has_amx_support, is_cpu_amx
if cpu_has_amx_support():
import sgl_kernel.cpu

Expand Down Expand Up @@ -178,7 +178,7 @@ def get_act_fn(
return act_fn


if not (is_cuda_available() or (not is_cuda_available() and cpu_has_amx_support())):
if not (is_cuda_available() or is_cpu_amx()):
logger.info(
"sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
)
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/layernorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
rmsnorm,
)

from sglang.srt.cpu_utils import cpu_has_amx_support
from sglang.srt.cpu_utils import cpu_has_amx_support, is_cpu_amx
if cpu_has_amx_support():
import sgl_kernel.cpu

Expand Down Expand Up @@ -136,7 +136,7 @@ def forward_cuda(
return out


if not (is_cuda_available() or (not is_cuda_available() and cpu_has_amx_support())):
if not (is_cuda_available() or is_cpu_amx()):
logger.info(
"sgl-kernel is not available on Non-NV platforms or Non-AMX CPUs. Fallback to other kernel libraries."
)
Expand Down
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/moe/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch.nn.functional as F

from sglang.srt.utils import get_compiler_backend
from sglang.srt.cpu_utils import cpu_has_amx_support
from sglang.srt.cpu_utils import cpu_has_amx_support, is_cpu_amx
if cpu_has_amx_support():
import sgl_kernel.cpu

Expand Down Expand Up @@ -181,7 +181,7 @@ def select_experts(
assert num_expert_group is not None
if correction_bias is None:
device = hidden_states.device
if device == torch.device("cpu") and cpu_has_amx_support():
if is_cpu_amx():
M = hidden_states.size(0)
topk_weights = torch.empty(
M, top_k, dtype=torch.float32, device=device
Expand All @@ -208,7 +208,7 @@ def select_experts(
)
else:
device = hidden_states.device
if device == torch.device("cpu") and cpu_has_amx_support():
if is_cpu_amx():
M = hidden_states.size(0)
topk_weights = torch.empty(
M, top_k, dtype=torch.float32, device=device
Expand Down
4 changes: 2 additions & 2 deletions python/sglang/srt/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
if _is_cuda_available:
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace

from sglang.srt.cpu_utils import cpu_has_amx_support
from sglang.srt.cpu_utils import cpu_has_amx_support, is_cpu_amx
if cpu_has_amx_support():
import sgl_kernel.cpu

Expand Down Expand Up @@ -725,7 +725,7 @@ def forward(
positions = torch.add(positions, offsets) if offsets is not None else positions

# TODO: Add scenario of self.rotary_dim < self.head_size
if positions.device == torch.device("cpu") and cpu_has_amx_support():
if is_cpu_amx():
return sgl_kernel.cpu.rotary_position_embedding(
positions, query, key, self.cos_sin_cache)
else:
Expand Down
Loading