Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
[submodule "sgl-kernel/3rdparty/cutlass"]
path = sgl-kernel/3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
[submodule "sgl-kernel/3rdparty/composable_kernel"]
path = sgl-kernel/3rdparty/composable_kernel
url = https://github.com/ROCm/composable_kernel.git
34 changes: 33 additions & 1 deletion python/sglang/srt/layers/moe/fused_moe_triton/layer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

import os

from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
Expand All @@ -18,7 +20,11 @@
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.utils import set_weight_attrs

from sglang.srt.utils import (
is_hip,
set_weight_attrs,
)

if torch.cuda.is_available():
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
Expand Down Expand Up @@ -97,6 +103,32 @@ def create_weights(
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)

def permute_weight(x: torch.Tensor) -> torch.Tensor:
    """Shuffle an (experts, n, k) MoE weight tensor into the interleaved
    layout expected by the CK fused-MoE kernels.

    The (n, k) tail is reinterpreted as (n//16, 16, k//32, 4, 8) and the
    16-row tile is interleaved with the 4x8 sub-tile. No-op unless the
    VLLM_MOE_SHUFFLE env var is set to a non-zero integer.

    Assumes n % 16 == 0 and k % 32 == 0 — TODO confirm callers guarantee this.
    """
    experts, n, k = x.shape

    out = x
    # NOTE(review): the original read `envs.VLLM_MOE_SHUFFLE`, but no `envs`
    # module is imported in this file (it is a vLLM module) — that line would
    # raise NameError. Read the env var directly, matching the
    # CK_MOE/MOE_PADDING style used elsewhere in this file.
    if bool(int(os.getenv("VLLM_MOE_SHUFFLE", "0"))):
        # `//` (not `/`): Tensor.view() requires integer sizes; float
        # division raised TypeError here.
        out = x.view(experts, n // 16, 16, k // 32, 4, 8)
        out = out.permute(0, 1, 3, 4, 2, 5)
        out = out.contiguous()
    return out

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Optionally shuffle expert weights into the CK kernel layout on ROCm.

    Active only on HIP when the CK_MOE env var is a non-zero integer;
    otherwise a no-op. Mutates `layer` in place.
    """
    if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
        # BUG FIX: the permuted parameters were assigned to `self.*`, so
        # `layer.w13_weight` / `layer.w2_weight` — the tensors the kernels
        # actually consume — were never shuffled. Store them on the layer,
        # mirroring the MOE_PADDING path elsewhere in the codebase.
        # NOTE(review): `permute_weight` must resolve to a module-level
        # helper here; as a bare class-body `def permute_weight(x)` it would
        # not be reachable by this name — TODO confirm it is in scope.
        layer.w13_weight = torch.nn.Parameter(
            permute_weight(layer.w13_weight.data),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
        layer.w2_weight = torch.nn.Parameter(
            permute_weight(layer.w2_weight.data),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
    return

def apply(
self,
layer: torch.nn.Module,
Expand Down
84 changes: 60 additions & 24 deletions python/sglang/srt/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,18 @@ def __init__(self, quant_config):
self.quant_config = quant_config
self.block_quant = self.quant_config.weight_block_size is not None

def permute_weight(x: torch.Tensor) -> torch.Tensor:
    """Shuffle an (experts, n, k) fp8 MoE weight tensor into the interleaved
    layout expected by the CK fused-MoE kernels.

    The (n, k) tail is reinterpreted as (n//16, 16, k//64, 4, 16) and the
    16-row tile is interleaved with the 4x16 sub-tile. No-op unless the
    VLLM_MOE_SHUFFLE env var is set to a non-zero integer.

    Assumes n % 16 == 0 and k % 64 == 0 — TODO confirm callers guarantee this.
    """
    experts, n, k = x.shape

    out = x
    # NOTE(review): the original read `envs.VLLM_MOE_SHUFFLE`, but no `envs`
    # module is imported in this file (it is a vLLM module) — that line would
    # raise NameError. Read the env var directly, matching the
    # CK_MOE/MOE_PADDING style used elsewhere in this file.
    if bool(int(os.getenv("VLLM_MOE_SHUFFLE", "0"))):
        # `//` (not `/`): Tensor.view() requires integer sizes; float
        # division raised TypeError here.
        out = x.view(experts, n // 16, 16, k // 64, 4, 16)
        out = out.permute(0, 1, 3, 4, 2, 5)
        out = out.contiguous()
    return out

def create_weights(
self,
layer: Module,
Expand Down Expand Up @@ -616,18 +628,30 @@ def process_weights_after_loading(self, layer: Module) -> None:
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

# If ROCm, apply weight padding (min. Mem channel contention) only if set
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
# ROCm-only post-load weight massaging. CK_MOE takes precedence over
# MOE_PADDING: the CK path shuffles weights into the composable-kernel
# layout; the padding path pads the last dim (min. mem-channel contention).
if is_hip():
    if bool(int(os.getenv("CK_MOE", "0"))):
        # BUG FIX: these were assigned to `self.*`, so the shuffled tensors
        # never replaced `layer.w13_weight`/`layer.w2_weight`, which are what
        # the kernels consume. Assign to the layer, as the MOE_PADDING
        # branch below already does.
        # NOTE(review): `permute_weight` must resolve to a module-level
        # helper here — TODO confirm it is in scope.
        layer.w13_weight = torch.nn.Parameter(
            permute_weight(layer.w13_weight.data),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
        layer.w2_weight = torch.nn.Parameter(
            permute_weight(layer.w2_weight.data),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
    elif bool(int(os.getenv("MOE_PADDING", "0"))):
        # If ROCm, apply weight padding (min. Mem channel contention) only if set
        layer.w13_weight = torch.nn.Parameter(
            F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
        layer.w2_weight = torch.nn.Parameter(
            F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
return

# If checkpoint is fp8, we need to handle that the
Expand Down Expand Up @@ -708,18 +732,30 @@ def process_weights_after_loading(self, layer: Module) -> None:
max_w13_scales, requires_grad=False
)

# If ROCm, apply weight padding (min. Mem channel contention) only if set
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
# ROCm-only post-load weight massaging (fp8-checkpoint path). CK_MOE takes
# precedence over MOE_PADDING: the CK path shuffles weights into the
# composable-kernel layout; the padding path pads the last dim
# (min. mem-channel contention).
if is_hip():
    if bool(int(os.getenv("CK_MOE", "0"))):
        # BUG FIX: these were assigned to `self.*`, so the shuffled tensors
        # never replaced `layer.w13_weight`/`layer.w2_weight`, which are what
        # the kernels consume. Assign to the layer, as the MOE_PADDING
        # branch below already does.
        # NOTE(review): `permute_weight` must resolve to a module-level
        # helper here — TODO confirm it is in scope.
        layer.w13_weight = torch.nn.Parameter(
            permute_weight(layer.w13_weight.data),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
        layer.w2_weight = torch.nn.Parameter(
            permute_weight(layer.w2_weight.data),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
    elif bool(int(os.getenv("MOE_PADDING", "0"))):
        # If ROCm, apply weight padding (min. Mem channel contention) only if set
        layer.w13_weight = torch.nn.Parameter(
            F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
        layer.w2_weight = torch.nn.Parameter(
            F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
return

def apply(
Expand Down
1 change: 1 addition & 0 deletions sgl-kernel/3rdparty/composable_kernel
Submodule composable_kernel added at 888317
191 changes: 154 additions & 37 deletions sgl-kernel/setup.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,24 @@
from pathlib import Path

import os
import sys
import shutil
import torch

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

root = Path(__file__).parent.resolve()

def is_cuda() -> bool:
    """Truthy when torch targets the NVIDIA CUDA platform and a GPU is visible."""
    gpu_present = torch.cuda.is_available()
    return gpu_present and torch.version.cuda


def is_hip() -> bool:
    """Truthy when torch targets the AMD ROCm (HIP) platform and a GPU is visible."""
    gpu_present = torch.cuda.is_available()
    return gpu_present and torch.version.hip


def get_version():
with open(root / "pyproject.toml") as f:
Expand All @@ -21,44 +35,147 @@ def update_wheel_platform_tag():
)
old_wheel.rename(new_wheel)

if is_cuda():
cutlass = root / "3rdparty" / "cutlass"
include_dirs = [
cutlass.resolve() / "include",
cutlass.resolve() / "tools" / "util" / "include",
]
nvcc_flags = [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
]
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
ext_modules = [
CUDAExtension(
name="sgl_kernel.ops._kernels",
sources=[
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
],
include_dirs=include_dirs,
extra_compile_args={
"nvcc": nvcc_flags,
"cxx": cxx_flags,
},
libraries=libraries,
extra_link_args=extra_link_args,
),
]
else:
def validate_and_update_archs(archs):
    """Validate that every requested GPU arch is one we can build for.

    Raises ValueError listing the offending entries. (Despite the name, no
    updating is performed — the list is validated only.)
    """
    # List of allowed architectures
    allowed_archs = {"native", "gfx90a",
                     "gfx940", "gfx941", "gfx942", "gfx1100"}

    # `raise`, not `assert`: assertions are stripped under `python -O`,
    # which would silently skip validation.
    bad = [arch for arch in archs if arch not in allowed_archs]
    if bad:
        raise ValueError(
            f"One of GPU archs of {archs} is invalid or not supported"
        )

def rename_cpp_to_cu(els, dst, recurisve=False):
    """Copy sources listed in `els` into `dst`, renaming ``*.cpp`` to ``*.cu``.

    Each element may be a file or a directory; directories are walked one
    level deep unless `recurisve` (sic — name kept for callers) is true.
    Files that are neither .cpp nor .cu are still copied under their
    original name. Non-existent entries are skipped silently. Returns the
    destination paths in traversal order.
    """
    copied = []

    def _copy_one(name, src_dir):
        # Only .cpp needs renaming; .cu passes through unchanged.
        target = name
        if name.endswith(".cpp") or name.endswith(".cu"):
            target = name.replace(".cpp", ".cu")
        copied.append(f'{dst}/{target}')
        shutil.copy(f'{src_dir}/{name}', f'{dst}/{target}')

    for el in els:
        if not os.path.exists(el):
            continue
        if not os.path.isdir(el):
            _copy_one(os.path.basename(el), os.path.dirname(el))
            continue
        for entry in os.listdir(el):
            child = f'{el}/{entry}'
            if os.path.isdir(child):
                if recurisve:
                    copied.extend(rename_cpp_to_cu([child], dst, recurisve))
                continue
            _copy_one(entry, el)
    return copied

this_dir = os.path.dirname(os.path.abspath(__file__))
ck_dir = os.environ.get("CK_DIR", f"{root}/3rdparty/composable_kernel")
bd_dir = f"{this_dir}/build"

if not os.path.exists(bd_dir):
os.makedirs(bd_dir)

shutil.copytree(ck_dir, f'{bd_dir}/ck', dirs_exist_ok=True)

ck_dir = f'{bd_dir}/ck'

archs = os.getenv("GPU_ARCHS", "gfx942").split(";")
validate_and_update_archs(archs)

cc_flag = [f"--offload-arch={arch}" for arch in archs]

cc_flag += [
"-mllvm", "-enable-post-misched=0",
"-mllvm", "-amdgpu-early-inline-all=true",
"-mllvm", "-amdgpu-function-calls=false",
"-mllvm", "--amdgpu-kernarg-preload-count=16",
"-mllvm", "-amdgpu-coerce-illegal-types=1",
"-Wno-unused-result",
"-Wno-switch-bool",
"-Wno-vla-cxx-extension",
"-Wno-undefined-func-template",
]

extra_compile_args = {
"cxx": ["-O3", "-std=c++17"],
"nvcc":
[
"-O3", "-std=c++17",
"-fPIC",
"-DUSE_PROF_API=1",
"-DENABLE_FP8",
"-D__HIP_PLATFORM_HCC__=1",
"-D__HIP_PLATFORM_AMD__=1",
"-U__HIP_NO_HALF_CONVERSIONS__",
"-U__HIP_NO_HALF_OPERATORS__",
]
+ cc_flag,
}

include_dirs = [
f"{this_dir}/build",
f"{ck_dir}/include",
f"{ck_dir}/library/include",
f"{ck_dir}/example/ck_tile/15_fused_moe",
]

renamed_ck_srcs = rename_cpp_to_cu(
[ # f'for other kernels'
f"{ck_dir}/example/ck_tile/15_fused_moe/instances",
], bd_dir)

build_srcs = ["src/sgl-kernel/csrc/moe_align_kernel.cu"]

cutlass = root / "3rdparty" / "cutlass"
include_dirs = [
cutlass.resolve() / "include",
cutlass.resolve() / "tools" / "util" / "include",
]
nvcc_flags = [
"-O3",
"-Xcompiler",
"-fPIC",
"-gencode=arch=compute_75,code=sm_75",
"-gencode=arch=compute_80,code=sm_80",
"-gencode=arch=compute_89,code=sm_89",
"-gencode=arch=compute_90,code=sm_90",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF2_OPERATORS__",
]
cxx_flags = ["-O3"]
libraries = ["c10", "torch", "torch_python"]
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
ext_modules = [
CUDAExtension(
name="sgl_kernel.ops._kernels",
sources=[
"src/sgl-kernel/csrc/trt_reduce_internal.cu",
"src/sgl-kernel/csrc/trt_reduce_kernel.cu",
"src/sgl-kernel/csrc/moe_align_kernel.cu",
"src/sgl-kernel/csrc/sgl_kernel_ops.cu",
],
include_dirs=include_dirs,
extra_compile_args={
"nvcc": nvcc_flags,
"cxx": cxx_flags,
},
libraries=libraries,
extra_link_args=extra_link_args,
),
]
ext_modules = [
CUDAExtension(
name="sgl_kernel.ops._kernels",
sources=build_srcs+renamed_ck_srcs,
extra_compile_args=extra_compile_args,
libraries=["hiprtc", "amdhip64", "c10", "torch", "torch_python"],
include_dirs=include_dirs,
extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"],
),
]

setup(
name="sgl-kernel",
Expand Down
Loading
Loading