diff --git a/.gitmodules b/.gitmodules index 3a14f6297a3a..6745dd375588 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "sgl-kernel/3rdparty/cutlass"] path = sgl-kernel/3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "sgl-kernel/3rdparty/composable_kernel"] + path = sgl-kernel/3rdparty/composable_kernel + url = https://github.com/ROCm/composable_kernel.git diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 96eaf856616f..85bafed8b2fe 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -1,5 +1,7 @@ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py +import os + from abc import abstractmethod from enum import Enum from typing import Callable, List, Optional, Tuple @@ -18,7 +20,11 @@ QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.utils import set_weight_attrs + +from sglang.srt.utils import ( + is_hip, + set_weight_attrs, +) if torch.cuda.is_available(): from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts @@ -97,6 +103,32 @@ def create_weights( layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + def permute_weight(x: torch.Tensor) -> torch.Tensor: + b_ = x.shape[0]; + n_ = x.shape[1]; + k_ = x.shape[2]; + + x_ = x + if envs.VLLM_MOE_SHUFFLE: + x_ = x_.view(b_, n_ / 16, 16, k_ / 32, 4, 8) + x_ = x_.permute(0, 1, 3, 4, 2, 5) + x_ = x_.contiguous() + return x_ + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if is_hip() and bool(int(os.getenv("CK_MOE", "0"))): + self.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + self.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + return + def apply( self, layer: torch.nn.Module, diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index a263cb2362a9..efa2badf2dc1 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -432,6 +432,18 @@ def __init__(self, quant_config): self.quant_config = quant_config self.block_quant = self.quant_config.weight_block_size is not None + def permute_weight(x: torch.Tensor) -> torch.Tensor: + b_ = x.shape[0]; + n_ = x.shape[1]; + k_ = x.shape[2]; + + x_ = x + if envs.VLLM_MOE_SHUFFLE: + x_ = x_.view(b_, n_ / 16, 16, k_ / 64, 4, 16) + x_ = x_.permute(0, 1, 3, 4, 2, 5) + x_ = x_.contiguous() + return x_ + def create_weights( self, layer: Module, @@ -616,18 +628,30 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - # If ROCm, apply weight padding (min. Mem channel contention) only if set - if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))): - layer.w13_weight = torch.nn.Parameter( - F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() - layer.w2_weight = torch.nn.Parameter( - F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() + if is_hip(): + if bool(int(os.getenv("CK_MOE", "0"))): + self.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + self.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + elif bool(int(os.getenv("MOE_PADDING", "0"))): + # If ROCm, apply weight padding (min. Mem channel contention) only if set + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() return # If checkpoint is fp8, we need to handle that the @@ -708,18 +732,30 @@ def process_weights_after_loading(self, layer: Module) -> None: max_w13_scales, requires_grad=False ) - # If ROCm, apply weight padding (min. Mem channel contention) only if set - if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))): - layer.w13_weight = torch.nn.Parameter( - F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() - layer.w2_weight = torch.nn.Parameter( - F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() + if is_hip(): + if bool(int(os.getenv("CK_MOE", "0"))): + self.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + self.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + elif bool(int(os.getenv("MOE_PADDING", "0"))): + # If ROCm, apply weight padding (min. Mem channel contention) only if set + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() return def apply( diff --git a/sgl-kernel/3rdparty/composable_kernel b/sgl-kernel/3rdparty/composable_kernel new file mode 160000 index 000000000000..888317e698e9 --- /dev/null +++ b/sgl-kernel/3rdparty/composable_kernel @@ -0,0 +1 @@ +Subproject commit 888317e698e9803c62bd38568abc9e05d7709f33 diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index eb81991e3684..704ff16b5ec4 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -1,10 +1,24 @@ from pathlib import Path +import os +import sys +import shutil +import torch + from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension root = Path(__file__).parent.resolve() +def is_cuda() -> bool: + """Return whether it is CUDA on the NVIDIA CUDA platform.""" + return torch.cuda.is_available() and torch.version.cuda + + +def is_hip() -> bool: + """Return whether it is HIP on the AMD ROCm platform.""" + return torch.cuda.is_available() and torch.version.hip + def get_version(): with open(root / "pyproject.toml") as f: @@ -21,44 +35,147 @@ def update_wheel_platform_tag(): ) old_wheel.rename(new_wheel) +if is_cuda(): + cutlass = root / "3rdparty" / "cutlass" + include_dirs = [ + cutlass.resolve() / "include", + cutlass.resolve() / "tools" / "util" / "include", + ] + nvcc_flags = [ + "-O3", + "-Xcompiler", + "-fPIC", + "-gencode=arch=compute_75,code=sm_75", + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_89,code=sm_89", + "-gencode=arch=compute_90,code=sm_90", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + ] + cxx_flags = ["-O3"] + libraries = ["c10", "torch", "torch_python"] + extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"] + ext_modules = [ + CUDAExtension( + name="sgl_kernel.ops._kernels", + sources=[ + "src/sgl-kernel/csrc/trt_reduce_internal.cu", + "src/sgl-kernel/csrc/trt_reduce_kernel.cu", + "src/sgl-kernel/csrc/moe_align_kernel.cu", + "src/sgl-kernel/csrc/sgl_kernel_ops.cu", + ], + include_dirs=include_dirs, + extra_compile_args={ + "nvcc": nvcc_flags, + "cxx": cxx_flags, + }, + libraries=libraries, + extra_link_args=extra_link_args, + ), + ] +else: + def validate_and_update_archs(archs): + # List of allowed architectures + allowed_archs = ["native", "gfx90a", + "gfx940", "gfx941", "gfx942", "gfx1100"] + + # Validate if each element in archs is in allowed_archs + assert all( + arch in allowed_archs for arch in archs + ), f"One of GPU archs of {archs} is invalid or not supported" + + def rename_cpp_to_cu(els, dst, recurisve=False): + def do_rename_and_mv(name, src, dst, ret): + newName = name + if name.endswith(".cpp") or name.endswith(".cu"): + newName = name.replace(".cpp", ".cu") + ret.append(f'{dst}/{newName}') + shutil.copy(f'{src}/{name}', f'{dst}/{newName}') + ret = [] + for el in els: + if not os.path.exists(el): + continue + if os.path.isdir(el): + for entry in os.listdir(el): + if os.path.isdir(f'{el}/{entry}'): + if recurisve: + ret += rename_cpp_to_cu([f'{el}/{entry}'], + dst, recurisve) + continue + do_rename_and_mv(entry, el, dst, ret) + else: + do_rename_and_mv(os.path.basename(el), + os.path.dirname(el), dst, ret) + return ret + + this_dir = os.path.dirname(os.path.abspath(__file__)) + ck_dir = os.environ.get("CK_DIR", f"{root}/3rdparty/composable_kernel") + bd_dir = f"{this_dir}/build" + + if not os.path.exists(bd_dir): + os.makedirs(bd_dir) + + shutil.copytree(ck_dir, f'{bd_dir}/ck', dirs_exist_ok=True) + + ck_dir = f'{bd_dir}/ck' + + archs = os.getenv("GPU_ARCHS", "gfx942").split(";") + validate_and_update_archs(archs) + + cc_flag = [f"--offload-arch={arch}" for arch in archs] + + cc_flag += [ + "-mllvm", "-enable-post-misched=0", + "-mllvm", "-amdgpu-early-inline-all=true", + "-mllvm", "-amdgpu-function-calls=false", + "-mllvm", "--amdgpu-kernarg-preload-count=16", + "-mllvm", "-amdgpu-coerce-illegal-types=1", + "-Wno-unused-result", + "-Wno-switch-bool", + "-Wno-vla-cxx-extension", + "-Wno-undefined-func-template", + ] + + extra_compile_args = { + "cxx": ["-O3", "-std=c++17"], + "nvcc": + [ + "-O3", "-std=c++17", + "-fPIC", + "-DUSE_PROF_API=1", + "-DENABLE_FP8", + "-D__HIP_PLATFORM_HCC__=1", + "-D__HIP_PLATFORM_AMD__=1", + "-U__HIP_NO_HALF_CONVERSIONS__", + "-U__HIP_NO_HALF_OPERATORS__", + ] + + cc_flag, + } + + include_dirs = [ + f"{this_dir}/build", + f"{ck_dir}/include", + f"{ck_dir}/library/include", + f"{ck_dir}/example/ck_tile/15_fused_moe", + ] + + renamed_ck_srcs = rename_cpp_to_cu( + [ # f'for other kernels' + f"{ck_dir}/example/ck_tile/15_fused_moe/instances", + ], bd_dir) + + build_srcs = ["src/sgl-kernel/csrc/moe_align_kernel.cu"] -cutlass = root / "3rdparty" / "cutlass" -include_dirs = [ - cutlass.resolve() / "include", - cutlass.resolve() / "tools" / "util" / "include", -] -nvcc_flags = [ - "-O3", - "-Xcompiler", - "-fPIC", - "-gencode=arch=compute_75,code=sm_75", - "-gencode=arch=compute_80,code=sm_80", - "-gencode=arch=compute_89,code=sm_89", - "-gencode=arch=compute_90,code=sm_90", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF2_OPERATORS__", -] -cxx_flags = ["-O3"] -libraries = ["c10", "torch", "torch_python"] -extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"] -ext_modules = [ - CUDAExtension( - name="sgl_kernel.ops._kernels", - sources=[ - "src/sgl-kernel/csrc/trt_reduce_internal.cu", - "src/sgl-kernel/csrc/trt_reduce_kernel.cu", - "src/sgl-kernel/csrc/moe_align_kernel.cu", - "src/sgl-kernel/csrc/sgl_kernel_ops.cu", - ], - include_dirs=include_dirs, - extra_compile_args={ - "nvcc": nvcc_flags, - "cxx": cxx_flags, - }, - libraries=libraries, - extra_link_args=extra_link_args, - ), -] + ext_modules = [ + CUDAExtension( + name="sgl_kernel.ops._kernels", + sources=build_srcs+renamed_ck_srcs, + extra_compile_args=extra_compile_args, + libraries=["hiprtc", "amdhip64", "c10", "torch", "torch_python"], + include_dirs=include_dirs, + extra_link_args=["-Wl,-rpath,$ORIGIN/../../torch/lib"], + ), + ] setup( name="sgl-kernel", diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_fused_experts.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_fused_experts.cu new file mode 100644 index 000000000000..68c5ae9f7293 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_fused_experts.cu @@ -0,0 +1,104 @@ +#include +#include +#include + +#include + +#include "utils.hpp" + +#include +#include "fused_moe.hpp" + +#define FOREACH_BUFFER_TORCH_TYPE_MAP(F) \ + F("fp32", torch::kFloat) \ + F("fp16", torch::kHalf) \ + F("bf16", torch::kBFloat16) \ + F("int32", torch::kInt32) \ + F("int8", torch::kInt8) \ + F("fp8", c10::kFloat8_e4m3fnuz) + +inline std::string torchDTypeToStr(caffe2::TypeMeta dtype) +{ +#define TYPE_CASE(type, torch_type) \ + case torch_type: \ + { \ + return type; \ + } + + switch (dtype.toScalarType()) + { + FOREACH_BUFFER_TORCH_TYPE_MAP(TYPE_CASE); + default: + throw std::runtime_error("CKPyInterface: Unsupported data type " + std::to_string((int8_t)(dtype.toScalarType()))); + } + +#undef TYPE_CASE +} + +void moe_fused_experts(torch::Tensor hidden_states, torch::Tensor w1, torch::Tensor w2, + torch::Tensor topk_weights, torch::Tensor topk_ids, + torch::Tensor w1_scale, torch::Tensor w2_scale, + torch::Tensor a1_scale, torch::Tensor a2_scale, + torch::Tensor sorted_ids, torch::Tensor sorted_weights, + torch::Tensor sorted_expert_ids, torch::Tensor num_tokens_post_pad, + torch::Tensor out, int64_t block_m, int fused_qunt, int gate_only) { + + auto prec_i = torchDTypeToStr(hidden_states.dtype()); + auto prec_w = torchDTypeToStr(w1.dtype()); + auto prec_o = torchDTypeToStr(out.dtype()); + auto prec_kw = torchDTypeToStr(topk_weights.dtype()); + + auto prec_st = "fp32" + auto prec_sw = "fp32" + auto prec_sq = "fp32" + + if (fused_qunt != 0) { + prec_st = torchDTypeToStr(a1_scale.dtype()); + prec_sw = torchDTypeToStr(w1_scale.dtype()) + prec_sq = torchDTypeToStr(a2_scale.dtype()) + } + + auto hidden_size = w1.size(2); + auto shared_intermediate_size_0 = w1.size(1); + + auto tokens = hidden_states.size(0); + auto experts = w1.size(0); + + auto topk = topk_ids_host.size(1); + + auto stride = hidden_size; + + fused_moe_traits traits{prec_i, + prec_w, + prec_o, + prec_st, + prec_sw, + prec_sq, + prec_kw, + block_m, + gate_only, + fused_quant}; + + fused_moe_args args{hidden_states.data_ptr(), + fused_quant != 0 ? a1_scale.data_ptr() : nullptr, + w1.data_ptr(), + w2.data_ptr(), + fused_quant != 0 ? w1_scale.data_ptr() : nullptr, + fused_quant != 0 ? w2_scale.data_ptr() : nullptr, + fused_quant == 1 ? a2_scale.data_ptr() : nullptr, + out.data_ptr(), + topk_ids.data_ptr(), + topk_weights.data_ptr(), + sorted_ids.data_ptr(), + sorted_weights.data_ptr(), + sorted_expert_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + block_m, + hidden_size, + shared_intermediate_size_0, + tokens, + experts, + topk, + stride}; + +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu index 69b8f8eebc53..c467f03156ed 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu @@ -12,6 +12,14 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer); +void moe_fused_experts(torch::Tensor hidden_states, torch::Tensor w1, torch::Tensor w2, + torch::Tensor topk_weights, torch::Tensor topk_ids, + torch::Tensor w1_scale, torch::Tensor w2_scale, + torch::Tensor a1_scale, torch::Tensor a2_scale, + torch::Tensor sorted_ids, torch::Tensor sorted_weights, + torch::Tensor sorted_expert_ids, torch::Tensor num_tokens_post_pad, + torch::Tensor out, int64_t block_m) + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // trt_reduce m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); @@ -19,4 +27,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)"); // moe_align_block_size m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); + m.def("moe_fused_experts", &moe_fused_experts, "MOE implementation by ck") }