diff --git a/.dockerignore b/.dockerignore deleted file mode 120000 index 3e4e48b0b5fe..000000000000 --- a/.dockerignore +++ /dev/null @@ -1 +0,0 @@ -.gitignore \ No newline at end of file diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000000..00d145ca5546 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,277 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +**/build/ +**/develop-eggs/ +**/dist/ +**/downloads/ +**/eggs/ +.eggs/ +**/lib/ +**/lib64/ +**/parts/ +**/sdist/ +**/var/ +**/wheels/ +**/share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ + +# Tokenizer cache for tests +.tokenizer_cache/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+.idea/ + +# MacOS +.DS_Store + +# Vim +*.swp + +# Documentation +docs/_build + +# SGL +benchmark/mmlu/data +benchmark/mmlu/data.tar +benchmark/llava_bench/images +benchmark/llava_bench/mme_pack +*.jsonl +tmp*.txt + +# Torch Compile logs +tl_out/ + +# Plots +*.png +*.pdf + +# personal +work_dirs/ +*.csv + +!logo.png + +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +compile_commands.json + +*.iml + +# VSCode +.vscode + +1 + +# Autoenv +.env.leave + +# Rust lib +Cargo.lock + +# Generated vision test fixtures (regenerate with: python scripts/generate_vision_golden.py) +sgl-model-gateway/tests/fixtures/golden/ + +lmms-eval + +**/.claude/ +**/.serena/ +ctags/ +outputs/ + +# Eval Cache +.longbench_cache/ + +# CUDA kernel develop, profile and debug +.clangd +*.nsys-rep +*.ncu-rep +*.nvcudmp + +# setuptools-scm generated version file +python/sglang/_version.py + +# Generated protobuf files (regenerate during wheel build or with compile_proto.py) +python/sglang/srt/grpc/*_pb2.py +python/sglang/srt/grpc/*_pb2_grpc.py +python/sglang/srt/grpc/*_pb2.pyi + +# MUSA section +# Generated source files by torchada +sgl-kernel/csrc_musa/ +sgl-kernel/include_musa/ +sgl-kernel/csrc/**/*_musa/ + +# MUSA core dump files +*.mudmp diff --git a/python/pyproject.toml b/python/pyproject.toml index fd8cdb2d9d55..82922ab198ce 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -13,23 +13,20 @@ classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] +dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle"] -dependencies = [ +[project.optional-dependencies] +runtime_common = [ "IPython", "aiohttp", - "apache-tvm-ffi>=0.1.5,<0.2", "anthropic>=0.20.0",
- "av ; sys_platform == 'linux' and (platform_machine == 'aarch64' or platform_machine == 'arm64' and platform_machine == 'armv7l')", "blobfile==3.0.0", "build", "compressed-tensors", - "cuda-python==12.9", "decord2", "datasets", "einops", "fastapi", - "flashinfer_python==0.6.2", # keep it aligned with jit-cache version in Dockerfile - "flashinfer_cubin==0.6.2", "gguf", "hf_transfer", "huggingface_hub", @@ -39,8 +36,6 @@ dependencies = [ "msgspec", "ninja", "numpy", - "nvidia-cutlass-dsl>=4.3.4", - "nvidia-ml-py", "openai-harmony==0.0.4", "openai==2.6.1", "orjson", @@ -55,52 +50,41 @@ dependencies = [ "pydantic", "python-multipart", "pyzmq>=25.1.2", - "quack-kernels==0.2.4", "requests", "scipy", "sentencepiece", "setproctitle", - "sgl-kernel==0.3.21", "soundfile==0.13.1", "tiktoken", "timm==1.0.16", - "torch_memory_saver==0.0.9", - "torch==2.9.1", "torchao==0.9.0", - "torchaudio==2.9.1", - "torchcodec==0.8.0 ; sys_platform != 'linux' or (sys_platform == 'linux' and platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')", # torchcodec does not exist in those systems. If not provided, transformer will use torchvision instead by default. 
- "torchvision", "tqdm", "transformers==4.57.1", "uvicorn", "uvloop", "xgrammar==0.1.27", - "grpcio==1.75.1", # keep it align with compile_proto.py "grpcio-tools==1.75.1", # keep it align with compile_proto.py "grpcio-reflection==1.75.1", # required by srt/entrypoints/grpc_server.py - "grpcio-health-checking==1.75.1", # required for Kubernetes gRPC health probes ] -[[tool.uv.index]] -name = "pypi" -url = "https://pypi.org/simple" -default = true - -[[tool.uv.index]] -name = "torch-cu129" -url = "https://download.pytorch.org/whl/cu129" -explicit = true +tracing = [ + "opentelemetry-sdk", + "opentelemetry-api", + "opentelemetry-exporter-otlp", + "opentelemetry-exporter-otlp-proto-grpc", +] -[tool.uv.sources] -torch = [ - { index = "pypi", marker = "platform_machine == 'x86_64'"}, - { index = "torch-cu129", marker = "platform_machine == 'aarch64'"}, +# HIP (Heterogeneous-computing Interface for Portability) for AMD +# => base docker rocm/vllm-dev:20250114, not from public vllm whl +srt_hip = [ + "sglang[runtime_common]", + "torch", + "petit_kernel==0.0.2", + "wave-lang==3.8.2", ] -[project.optional-dependencies] -checkpoint-engine = ["checkpoint-engine==0.1.2"] -diffusion = [ +diffusion_hip = [ "PyYAML==6.0.1", "cloudpickle", "diffusers==0.36.0", @@ -109,44 +93,63 @@ diffusion = [ "moviepy>=2.0.0", "opencv-python-headless==4.10.0.84", "remote-pdb", - "st_attn==0.0.7 ; platform_machine != 'aarch64' and platform_machine != 'arm64'", - "vsa==0.0.4 ; platform_machine != 'aarch64' and platform_machine != 'arm64'", + "st_attn==0.0.7", + "vsa==0.0.4", "runai_model_streamer>=0.15.5", - "cache-dit==1.2.0", + "cache-dit==1.1.8", "addict" ] -tracing = [ - "opentelemetry-api", - "opentelemetry-exporter-otlp", - "opentelemetry-exporter-otlp-proto-grpc", - "opentelemetry-sdk", +# For Intel Gaudi(device : hpu) follow the installation guide +# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html +srt_hpu = ["sglang[runtime_common]"] + +# 
https://docs.sglang.io/platforms/mthreads_gpu.md +srt_musa = [ + "sglang[runtime_common]", + "torch", + "torch_musa", + "torchada>=0.1.25", + "mthreads-ml-py", + "numpy<2.0", +] + +diffusion_musa = [ + "PyYAML==6.0.1", + "cloudpickle", + "diffusers==0.36.0", + "imageio==2.36.0", + "imageio-ffmpeg==0.5.1", + "moviepy>=2.0.0", + "opencv-python-headless==4.10.0.84", + "remote-pdb", + "st_attn==0.0.7", + "vsa==0.0.4", + "runai_model_streamer>=0.15.5", + "cache-dit==1.1.8", + "addict" ] test = [ "accelerate", - "bitsandbytes", "expecttest", + "gguf", "jsonlines", "matplotlib", "pandas", - "parameterized", "peft", "pytest", "sentence_transformers", "tabulate", ] -dev = ["sglang[test]"] - -all = [ - "sglang[diffusion]", - "sglang[tracing]", -] +all_hip = ["sglang[srt_hip]", "sglang[diffusion_hip]"] +all_hpu = ["sglang[srt_hpu]"] +all_musa = ["sglang[srt_musa]", "sglang[diffusion_musa]"] -[tool.uv.extra-build-dependencies] -st-attn = ["torch", "setuptools"] -vsa = ["torch", "setuptools"] +dev_hip = ["sglang[all_hip]", "sglang[test]"] +dev_hpu = ["sglang[all_hpu]", "sglang[test]"] +dev_musa = ["sglang[all_musa]", "sglang[test]"] [project.urls] "Homepage" = "https://github.com/sgl-project/sglang" @@ -186,6 +189,4 @@ exclude = [ [tool.setuptools_scm] root = ".." version_file = "sglang/_version.py" -git_describe_command = ["bash", "-c", "git tag --list --sort=-version:refname 'v*.*.*' | head -1 | xargs git describe --tags --long"] -# Allow editable installs even when .git metadata is not available. 
-fallback_version = "0.0.0.dev0" +git_describe_command = ["git", "describe", "--tags", "--long", "--match", "v*"] diff --git a/python/pyproject_other.toml b/python/pyproject_other.toml deleted file mode 100755 index 82922ab198ce..000000000000 --- a/python/pyproject_other.toml +++ /dev/null @@ -1,192 +0,0 @@ -[build-system] -requires = ["setuptools>=61.0", "setuptools-scm>=8.0", "wheel", "grpcio-tools==1.75.1"] -build-backend = "setuptools.build_meta" - -[project] -name = "sglang" -dynamic = ["version"] -description = "SGLang is a fast serving framework for large language models and vision language models." -readme = "README.md" -requires-python = ">=3.10" -license = { file = "LICENSE" } -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", -] -dependencies = ["aiohttp", "requests", "tqdm", "numpy", "IPython", "setproctitle"] - -[project.optional-dependencies] -runtime_common = [ - "IPython", - "aiohttp", - "anthropic>=0.20.0", - "blobfile==3.0.0", - "build", - "compressed-tensors", - "decord2", - "datasets", - "einops", - "fastapi", - "gguf", - "hf_transfer", - "huggingface_hub", - "interegular", - "llguidance>=0.7.11,<0.8.0", - "modelscope", - "msgspec", - "ninja", - "numpy", - "openai-harmony==0.0.4", - "openai==2.6.1", - "orjson", - "outlines==0.1.11", - "packaging", - "partial_json_parser", - "pillow", - "prometheus-client>=0.20.0", - "psutil", - "py-spy", - "pybase64", - "pydantic", - "python-multipart", - "pyzmq>=25.1.2", - "requests", - "scipy", - "sentencepiece", - "setproctitle", - "soundfile==0.13.1", - "tiktoken", - "timm==1.0.16", - "torchao==0.9.0", - "tqdm", - "transformers==4.57.1", - "uvicorn", - "uvloop", - "xgrammar==0.1.27", - "grpcio==1.75.1", # keep it align with compile_proto.py - "grpcio-tools==1.75.1", # keep it align with compile_proto.py - "grpcio-reflection==1.75.1", # required by srt/entrypoints/grpc_server.py -] - -tracing = [ - "opentelemetry-sdk", - "opentelemetry-api", - 
"opentelemetry-exporter-otlp", - "opentelemetry-exporter-otlp-proto-grpc", -] - -# HIP (Heterogeneous-computing Interface for Portability) for AMD -# => base docker rocm/vllm-dev:20250114, not from public vllm whl -srt_hip = [ - "sglang[runtime_common]", - "torch", - "petit_kernel==0.0.2", - "wave-lang==3.8.2", -] - -diffusion_hip = [ - "PyYAML==6.0.1", - "cloudpickle", - "diffusers==0.36.0", - "imageio==2.36.0", - "imageio-ffmpeg==0.5.1", - "moviepy>=2.0.0", - "opencv-python-headless==4.10.0.84", - "remote-pdb", - "st_attn==0.0.7", - "vsa==0.0.4", - "runai_model_streamer>=0.15.5", - "cache-dit==1.1.8", - "addict" -] - -# For Intel Gaudi(device : hpu) follow the installation guide -# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html -srt_hpu = ["sglang[runtime_common]"] - -# https://docs.sglang.io/platforms/mthreads_gpu.md -srt_musa = [ - "sglang[runtime_common]", - "torch", - "torch_musa", - "torchada>=0.1.25", - "mthreads-ml-py", - "numpy<2.0", -] - -diffusion_musa = [ - "PyYAML==6.0.1", - "cloudpickle", - "diffusers==0.36.0", - "imageio==2.36.0", - "imageio-ffmpeg==0.5.1", - "moviepy>=2.0.0", - "opencv-python-headless==4.10.0.84", - "remote-pdb", - "st_attn==0.0.7", - "vsa==0.0.4", - "runai_model_streamer>=0.15.5", - "cache-dit==1.1.8", - "addict" -] - -test = [ - "accelerate", - "expecttest", - "gguf", - "jsonlines", - "matplotlib", - "pandas", - "peft", - "pytest", - "sentence_transformers", - "tabulate", -] - -all_hip = ["sglang[srt_hip]", "sglang[diffusion_hip]"] -all_hpu = ["sglang[srt_hpu]"] -all_musa = ["sglang[srt_musa]", "sglang[diffusion_musa]"] - -dev_hip = ["sglang[all_hip]", "sglang[test]"] -dev_hpu = ["sglang[all_hpu]", "sglang[test]"] -dev_musa = ["sglang[all_musa]", "sglang[test]"] - -[project.urls] -"Homepage" = "https://github.com/sgl-project/sglang" -"Bug Tracker" = "https://github.com/sgl-project/sglang/issues" - -[project.scripts] -sglang = "sglang.cli.main:main" - -[tool.setuptools.package-data] -"sglang" = [ - 
"srt/**/*", - "jit_kernel/**/*" -] - -[tool.setuptools.packages.find] -exclude = [ - "assets*", - "benchmark*", - "docs*", - "dist*", - "playground*", - "scripts*", - "tests*", -] - -[tool.wheel] -exclude = [ - "assets*", - "benchmark*", - "docs*", - "dist*", - "playground*", - "scripts*", - "tests*", -] - -[tool.setuptools_scm] -root = ".." -version_file = "sglang/_version.py" -git_describe_command = ["git", "describe", "--tags", "--long", "--match", "v*"] diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..e25d0492d493 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 
+ }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git 
a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=384,N=128,device_name=AMD_Instinct_MI308X,dtype=int4_w4a16.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=384,N=128,device_name=AMD_Instinct_MI308X,dtype=int4_w4a16.json new file mode 100644 index 000000000000..4f8f76774b92 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=384,N=128,device_name=AMD_Instinct_MI308X,dtype=int4_w4a16.json @@ -0,0 +1,90 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + 
"num_warps": 8, + "num_stages": 2 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X_VF.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X_VF.json new file mode 100644 index 000000000000..4d4b752fa5d6 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X_VF.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, 
+ "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X_VF.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X_VF.json new file mode 100644 index 000000000000..a218fc40642c --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X_VF.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + 
"num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X_VF.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X_VF.json new file mode 100644 index 000000000000..3682cc548f35 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X_VF.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + 
"waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + 
"GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..d7f14d6656eb --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json @@ -0,0 +1,178 @@ +{ + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, 
+ "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16384": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32768": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "65536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + 
"BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "131072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X_VF.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X_VF.json new file mode 100644 index 000000000000..21742854c613 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X_VF.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + 
"matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + 
"num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..d9d2f5eac52f --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8.json @@ -0,0 +1,175 @@ +{ + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, 
+ "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 4, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16384": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32768": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "65536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + 
"waves_per_eu": 1, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "131072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 185f1bea3ea5..f99e6957a767 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -693,6 +693,7 @@ def _weight_loader_impl( "CompressedTensorsWNA16MarlinMoEMethod", "CompressedTensorsWNA16MoEMethod", "CompressedTensorsWNA16TritonMoEMethod", + "CompressedTensorsWNA16AiterMoEMethod", ] ) else loaded_weight @@ -909,6 +910,7 @@ def weight_loader_fused( in [ "CompressedTensorsWNA16MoEMethod", "CompressedTensorsWNA16TritonMoEMethod", + "CompressedTensorsWNA16AiterMoEMethod", ] ) else loaded_weight diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 176311eaf284..4f99eacd305d 100644 --- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -145,6 +145,22 @@ def get_moe_method( ) return CompressedTensorsMxInt4MoEMethod(quant_config) elif _is_hip: + # Try aiter/FlyDSL for W4A16 (per-row scale or group_size=32) + _w4a16_config = quant_config.target_scheme_map["Linear"].get("weights") + _is_w4a16 = _w4a16_config and _w4a16_config.num_bits == 4 + _group_size = _w4a16_config.group_size if _w4a16_config else -1 + _is_flydsl_compatible = _group_size in (-1, 0, 32) + + logger.info(f"[DEBUG] W4A16 check: _use_aiter={_use_aiter}, _is_w4a16={_is_w4a16}, group_size={_group_size}, _is_flydsl_compatible={_is_flydsl_compatible}") + if _use_aiter and 
_is_w4a16 and _is_flydsl_compatible: + try: + logger.info_once( + "Using CompressedTensorsWNA16AiterMoEMethod (ROCm + FlyDSL)" + ) + return CompressedTensorsWNA16AiterMoEMethod(quant_config) + except ValueError as e: + logger.warning(f"FlyDSL W4A16 not available: {e}, falling back to Triton") + logger.info_once( "Using CompressedTensorsWNA16TritonMoEMethod (ROCm)" ) @@ -1448,6 +1464,148 @@ def apply( return self.runner.run(dispatch_output, quant_info) + +class CompressedTensorsWNA16AiterMoEMethod(CompressedTensorsWNA16MoEMethod): + """ROCm/HIP W4A16 MoE method using aiter/FlyDSL kernels. + + Requirements: + - Only supports W4A16 (4-bit weights, bf16 activations) + - Supports per-row scale (group_size == -1) or group_size == 32 + - Only supports SiLU activation + """ + + def __init__(self, quant_config, num_gpu_experts=-1): + super().__init__(quant_config, num_gpu_experts) + # FlyDSL W4A16 supports per-row scale or group_size=32 + if self.group_size not in (-1, 0, 32): + raise ValueError( + f"FlyDSL W4A16 only supports per-row scale or group_size=32, " + f"got group_size={self.group_size}" + ) + # Ensure this is W4A16 (4-bit weights) + if self.num_bits != 4: + raise ValueError( + f"FlyDSL aiter path only supports W4A16 (4-bit weights), " + f"got num_bits={self.num_bits}" + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if getattr(layer, "is_aiter_converted", False): + return + + import os, sys as _sys + DSL2_ROOT = os.environ.get("DSL2_ROOT", "/opt/FlyDSL") + if DSL2_ROOT not in _sys.path: + _sys.path.insert(0, DSL2_ROOT) + from tests.utils import shuffle_weight as flydsl_shuffle + from tests.kernels.test_moe_gemm import _pack_shuffled_int8_to_packed_int4_no_perm + + def _gptq_int32_to_flydsl_packed(w_int32): + """Convert GPTQ int32 [E, K//8, N] to FlyDSL shuffled packed int4 [E, N, K//2]. + + Steps: + 1. Unpack int32 to individual unsigned int4 values (as int8) + 2. Convert unsigned [0,15] to signed [-8,7] by subtracting 8 + 3. 
Apply FlyDSL preshuffle (on individual int8 values) + 4. Pack with FlyDSL's interleaved int4 packing + """ + E = w_int32.shape[0] + # [E, K//8, N] -> transpose -> [E, N, K//8] + w = w_int32.transpose(1, 2).contiguous() + N = w.shape[1] + K_div8 = w.shape[2] + K = K_div8 * 8 + + # Unpack int32 -> 8 x uint4 values along K + w_expanded = w.unsqueeze(-1).expand(E, N, K_div8, 8) # [E, N, K//8, 8] + shifts = torch.arange(8, device=w.device) * 4 # [0, 4, 8, ..., 28] + nibbles = ((w_expanded >> shifts) & 0xF).to(torch.int8) # [E, N, K//8, 8] + nibbles = nibbles.reshape(E, N, K) # [E, N, K] unsigned int4 as int8 + + # Convert unsigned [0,15] to signed [-8,7] + signed = nibbles.to(torch.int16) - 8 + signed = signed.to(torch.int8) # [E, N, K] signed int4 as int8 + + # FlyDSL preshuffle (operates on individual values) + shuffled = flydsl_shuffle(signed, layout=(16, 16)) + + # FlyDSL interleaved int4 packing + packed = _pack_shuffled_int8_to_packed_int4_no_perm(shuffled) + return packed.view(E, N, K // 2) + + # Convert w13 weights + w13 = layer.w13_weight_packed.data + w13 = _gptq_int32_to_flydsl_packed(w13) + layer.w13_weight_packed = torch.nn.Parameter(w13, requires_grad=False) + logger.debug(f"[FlyDSL] w13 converted: {w13.shape} {w13.dtype}") + + # Convert w2 weights + w2 = layer.w2_weight_packed.data + w2 = _gptq_int32_to_flydsl_packed(w2) + layer.w2_weight_packed = torch.nn.Parameter(w2, requires_grad=False) + logger.debug(f"[FlyDSL] w2 converted: {w2.shape} {w2.dtype}") + + # Convert scales for FlyDSL: + # per-row: [E, 1, N] -> squeeze -> [E, N] + # groupwise: [E, K//gs, N] -> keep as-is (Opt 0: cache-friendly layout) + w13_scale = layer.w13_weight_scale.data + if self.group_size > 0 and w13_scale.dim() == 3 and w13_scale.shape[1] > 1: + # Groupwise: keep [E, K//gs, N] layout (Opt 0: stride-1 access for adjacent threads) + w13_scale = w13_scale.contiguous() + logger.debug(f"[FlyDSL] w13_scale groupwise [E,K//gs,N]: {w13_scale.shape} (group_size={self.group_size})") + elif 
w13_scale.dim() == 3 and w13_scale.shape[1] == 1: + # Per-row: squeeze [E, 1, N] -> [E, N] + w13_scale = w13_scale.squeeze(1) + layer.w13_weight_scale = torch.nn.Parameter(w13_scale.contiguous(), requires_grad=False) + + w2_scale = layer.w2_weight_scale.data + if self.group_size > 0 and w2_scale.dim() == 3 and w2_scale.shape[1] > 1: + # Groupwise: keep [E, K//gs, N] layout (Opt 0: stride-1 access for adjacent threads) + w2_scale = w2_scale.contiguous() + logger.debug(f"[FlyDSL] w2_scale groupwise [E,K//gs,N]: {w2_scale.shape} (group_size={self.group_size})") + elif w2_scale.dim() == 3 and w2_scale.shape[1] == 1: + # Per-row: squeeze [E, 1, N] -> [E, N] + w2_scale = w2_scale.squeeze(1) + layer.w2_weight_scale = torch.nn.Parameter(w2_scale.contiguous(), requires_grad=False) + + layer.w13_weight_packed.is_shuffled = True + layer.w2_weight_packed.is_shuffled = True + layer.is_aiter_converted = True + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: "StandardDispatchOutput", + ) -> "CombineInput": + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + assert self.moe_runner_config.activation == "silu", "FlyDSL W4A16 only supports SiLU" + + x = dispatch_output.hidden_states + topk_weights, topk_ids, _ = dispatch_output.topk_output + + output = fused_moe( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + topk_weights, + topk_ids, + quant_type=QuantType.No, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=None, + a2_scale=None, + activation=ActivationType.Silu, + ) + return StandardCombineInput(hidden_states=output) + + + class NPUCompressedTensorsW4A8Int8DynamicMoEMethod(CompressedTensorsMoEMethod): ### TODO: Get rid of code duplication with python/sglang/srt/modelslim/modelslim_moe.py @OrangeRedeng @TamirBaydasov diff --git 
a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py.bak b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py.bak new file mode 100644 index 000000000000..176311eaf284 --- /dev/null +++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py.bak @@ -0,0 +1,2171 @@ +# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors +# SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations + +import enum +import logging +from enum import Enum +from typing import TYPE_CHECKING + +import torch +from compressed_tensors import CompressionFormat +from compressed_tensors.quantization import QuantizationStrategy + +from sglang.srt.distributed import ( + get_moe_expert_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, +) +from sglang.srt.distributed.device_communicators.pynccl_allocator import ( + use_symmetric_memory, +) +from sglang.srt.hardware_backend.npu.quantization.fused_moe_method_npu import ( + NPUW4A8Int8DynamicMoEMethod, + NPUW4A16Int4DynamicMoEMethod, + NPUW8A8Int8DynamicMoEMethod, +) +from sglang.srt.layers.dp_attention import is_allocation_symmetric +from sglang.srt.layers.moe import ( + MoeRunner, + MoeRunnerBackend, + MoeRunnerConfig, + get_moe_runner_backend, +) +from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType +from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo +from sglang.srt.layers.moe.utils import RoutingMethodType +from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase +from sglang.srt.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS, +) +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant +from sglang.srt.layers.quantization.fp8_utils import ( + is_blackwell_supported, + normalize_e4m3fn_to_e4m3fnuz, +) +from 
sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack +from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales +from sglang.srt.layers.quantization.utils import ( + all_close_1d, + per_tensor_dequantize, + prepare_static_weights_for_trtllm_fp4_moe, + reorder_w1w3_to_w3w1, + replace_parameter, + swizzle_blockscale, +) +from sglang.srt.utils import ( + get_bool_env_var, + is_cuda, + is_flashinfer_available, + is_hip, + is_npu, + next_power_of_2, + set_weight_attrs, +) + +if TYPE_CHECKING: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + from sglang.srt.layers.moe.token_dispatcher import ( + CombineInput, + StandardDispatchOutput, + ) + from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import ( + CompressedTensorsConfig, + ) + +_is_hip = is_hip() +_is_npu = is_npu() +_is_cuda = is_cuda() + +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip + +if _use_aiter: + from aiter import ActivationType, QuantType + from aiter.fused_moe import fused_moe + from aiter.ops.shuffle import shuffle_weight + +if is_flashinfer_available(): + from flashinfer.fp4_quantization import block_scale_interleave + from flashinfer.fused_moe import ( + convert_to_block_layout, + trtllm_mxint4_block_scale_moe, + ) + from flashinfer.fused_moe.core import ( + _maybe_get_cached_w3_w1_permute_indices, + get_w2_permute_indices_with_cache, + ) + + +logger = logging.getLogger(__name__) + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +__all__ = [ + "CompressedTensorsMoEMethod", + "CompressedTensorsW4A4Nvfp4MoEMethod", + "NPUCompressedTensorsW4A8Int8DynamicMoEMethod", + "CompressedTensorsW8A8Fp8MoEMethod", + "NPUCompressedTensorsW8A8Int8MoEMethod", + "CompressedTensorsWNA16MoEMethod", + "CompressedTensorsMxInt4MoEMethod", + "NPUCompressedTensorsW4A16Int4DynamicMoEMethod", +] + + +class CompressedTensorsMoEMethod(FusedMoEMethodBase): + def __new__(cls, *args, **kwargs): + if cls is 
CompressedTensorsMoEMethod: + return super().__new__(cls) + return super().__new__(cls) + + @staticmethod + def get_moe_method( + quant_config: CompressedTensorsConfig, + layer: torch.nn.Module, + prefix: str, + ) -> "CompressedTensorsMoEMethod": + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + + weight_quant = quant_config.target_scheme_map["Linear"].get("weights") + input_quant = quant_config.target_scheme_map["Linear"].get("input_activations") + + if quant_config._is_wNa16_group_channel(weight_quant, input_quant): + if not _is_npu: + if ( + quant_config._is_mxint4a16(weight_quant, input_quant) + and get_moe_runner_backend().is_flashinfer_trtllm() + ): + logger.info_once( + "Using CompressedTensorsMxInt4MoEMethod with flashinfer_trtllm backend" + ) + return CompressedTensorsMxInt4MoEMethod(quant_config) + elif _is_hip: + logger.info_once( + "Using CompressedTensorsWNA16TritonMoEMethod (ROCm)" + ) + return CompressedTensorsWNA16TritonMoEMethod(quant_config) + else: + logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod") + return CompressedTensorsWNA16MoEMethod(quant_config) + else: + if ( + quant_config._is_dynamic_token_w4(weight_quant, input_quant) + and input_quant is None + ): + logger.info_once( + "Using NPUCompressedTensorsW4A16Int4DynamicMoEMethod" + ) + return NPUCompressedTensorsW4A16Int4DynamicMoEMethod(quant_config) + elif quant_config._is_fp4a4_nvfp4(weight_quant, input_quant): + logger.info_once("Using CompressedTensorsW4A4Nvfp4MoEMethod") + return CompressedTensorsW4A4Nvfp4MoEMethod(quant_config) + elif quant_config._is_fp8_w8a8(weight_quant, input_quant): + logger.info_once("Using CompressedTensorsW8A8Fp8MoEMethod") + return CompressedTensorsW8A8Fp8MoEMethod(quant_config) + elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant): + if _is_npu: + logger.info_once("Using NPUCompressedTensorsW8A8Int8DynamicMoEMethod") + return 
NPUCompressedTensorsW8A8Int8DynamicMoEMethod(quant_config) + else: + raise NotImplementedError( + f"The W8A8Int8 Fused MoE scheme is implemented only for NPU for now." + ) + elif quant_config._is_dynamic_token_w4a8(weight_quant, input_quant): + if _is_npu: + logger.info_once("Using NPUCompressedTensorsW4A8Int8DynamicMoEMethod") + return NPUCompressedTensorsW4A8Int8DynamicMoEMethod(quant_config) + else: + raise NotImplementedError( + f"The W4A8Int8 Fused MoE scheme is implemented only for NPU for now." + ) + else: + raise RuntimeError( + f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}" + ) + + +class CompressedTensorsW4A4Nvfp4MoEMethod(CompressedTensorsMoEMethod): + + def __init__(self, quant_config: CompressedTensorsConfig): + if not is_blackwell_supported(): + raise ValueError( + "Current platform does not support NVFP4" + " quantization. Please use Blackwell and" + " above." + ) + self.quant_config = quant_config + self.group_size = 16 + self.use_flashinfer_trtllm = get_moe_runner_backend().is_flashinfer_trtllm() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + layer.params_dtype = params_dtype + + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // 2, + requires_grad=False, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // 2, + dtype=torch.uint8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", 
w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # Weight Scales + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + # 2 fp4 items are packed in the input dimension + hidden_size // self.group_size, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + # 2 fp4 items are packed in the input dimension + intermediate_size_per_partition // self.group_size, + dtype=torch.float8_e4m3fn, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.GROUP.value} + ) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # Weight Global Scales + w13_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_weight_global_scale", w13_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w13_weight_scale_2, extra_weight_attrs) + + w2_weight_scale_2 = torch.nn.Parameter( + torch.empty(num_experts, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w2_weight_global_scale", w2_weight_scale_2) + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} + ) + set_weight_attrs(w2_weight_scale_2, extra_weight_attrs) + + # Input Global Scales + w13_input_scale = torch.nn.Parameter( + torch.empty(num_experts, 2, dtype=torch.float32), requires_grad=False + ) + layer.register_parameter("w13_input_global_scale", w13_input_scale) + extra_weight_attrs.update( + {"quant_method": 
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert loaded NVFP4 checkpoint tensors into kernel-ready form.

        Renames *_weight_packed -> *_weight, combines per-expert global scales
        into the alpha/input-scale tensors the kernels consume, and then either
        shuffles weights for the FlashInfer TRT-LLM kernel or swizzles block
        scales for the CUTLASS path.
        """
        # From packed to weight: the loader fills w13_weight_packed/w2_weight_packed;
        # expose them under the attribute names the MoE kernels expect.
        layer.w13_weight = torch.nn.Parameter(
            layer.w13_weight_packed.data, requires_grad=False
        )
        delattr(layer, "w13_weight_packed")

        layer.w2_weight = torch.nn.Parameter(
            layer.w2_weight_packed.data, requires_grad=False
        )
        delattr(layer, "w2_weight_packed")

        if self.use_flashinfer_trtllm:
            # TRT-LLM expects the gate/up projections in (w3, w1) order; swap the
            # two halves of the fused w13 weight and its block scales along dim=-2.
            w, s = reorder_w1w3_to_w3w1(
                layer.w13_weight.data, layer.w13_weight_scale.data, dim=-2
            )
            layer.w13_weight = torch.nn.Parameter(w, requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(s, requires_grad=False)

        # Only one global scale per expert is kept below (column 0); warn if the
        # checkpoint's w1 and w3 global scales actually differ.
        if not torch.allclose(
            layer.w13_weight_global_scale[:, 0], layer.w13_weight_global_scale[:, 1]
        ):
            logger.warning_once(
                "w1_weight_global_scale must match w3_weight_global_scale. "
                "Accuracy may be affected."
            )

        # Take inverse of global scale saved to disk
        layer.w13_weight_scale_2 = torch.nn.Parameter(
            1 / layer.w13_weight_global_scale[:, 0], requires_grad=False
        )

        layer.w2_weight_scale_2 = torch.nn.Parameter(
            1 / layer.w2_weight_global_scale.data, requires_grad=False
        )

        # w13: TRT-LLM uses a single scalar input scale broadcast to all local
        # experts; the CUTLASS path keeps a per-expert min over the two halves.
        if self.use_flashinfer_trtllm:
            w13_input_global_scale = (
                layer.w13_input_global_scale.min()
                .to(torch.float32)
                .expand(layer.num_local_experts)
            )
        else:
            w13_input_global_scale = layer.w13_input_global_scale.min(dim=1).values.to(
                torch.float32
            )
        # g1_alphas = (1 / input_scale) * (1 / weight_global_scale): dequant factor
        # applied after the first GEMM.
        layer.g1_alphas = torch.nn.Parameter(
            ((1 / w13_input_global_scale) * layer.w13_weight_scale_2),
            requires_grad=False,
        )

        layer.w13_input_scale_quant = torch.nn.Parameter(
            (w13_input_global_scale), requires_grad=False
        )

        # w2: same treatment for the down projection.
        if self.use_flashinfer_trtllm:
            w2_input_global_scale = (
                layer.w2_input_global_scale.min()
                .to(torch.float32)
                .expand(layer.num_local_experts)
            )
        else:
            w2_input_global_scale = layer.w2_input_global_scale

        layer.g2_alphas = torch.nn.Parameter(
            ((1 / w2_input_global_scale) * layer.w2_weight_scale_2).to(torch.float32),
            requires_grad=False,
        )

        layer.w2_input_scale_quant = torch.nn.Parameter(
            (w2_input_global_scale), requires_grad=False
        )

        # TensorRT-LLM specific processing
        if self.use_flashinfer_trtllm:
            # Prepare static weights for TRT-LLM kernel (one-time shuffle into the
            # kernel's expected layout).
            (
                gemm1_weights_fp4_shuffled,
                gemm1_scales_fp4_shuffled,
                gemm2_weights_fp4_shuffled,
                gemm2_scales_fp4_shuffled,
            ) = prepare_static_weights_for_trtllm_fp4_moe(
                layer.w13_weight,
                layer.w2_weight,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                layer.w2_weight.size(-2),  # hidden_size
                layer.w13_weight.size(-2) // 2,  # intermediate_size
                layer.w13_weight.size(0),  # num_experts
            )
            logger.debug("Finished shuffling weights for TRT-LLM MOE")

            layer.gemm1_weights_fp4_shuffled = torch.nn.Parameter(
                gemm1_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_weights_fp4_shuffled = torch.nn.Parameter(
                gemm2_weights_fp4_shuffled, requires_grad=False
            )
            layer.gemm1_scales_fp4_shuffled = torch.nn.Parameter(
                gemm1_scales_fp4_shuffled, requires_grad=False
            )
            layer.gemm2_scales_fp4_shuffled = torch.nn.Parameter(
                gemm2_scales_fp4_shuffled, requires_grad=False
            )

            # Additional parameter needed for TRT-LLM: combined scalar applied to
            # the first GEMM output before re-quantizing for the second GEMM.
            layer.g1_scale_c = torch.nn.Parameter(
                (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32),
                requires_grad=False,
            )

            # Clean up weights that won't be used by TRT-LLM (the shuffled copies
            # above are the only ones the kernel reads).
            del layer.w2_weight
            del layer.w2_weight_scale
            del layer.w13_weight
            del layer.w13_weight_scale
        else:
            # CUTLASS path: swizzle the block scales into the layout expected by
            # the blockscaled-FP4 kernels.
            layer.w13_weight_scale = torch.nn.Parameter(
                swizzle_blockscale(layer.w13_weight_scale), requires_grad=False
            )

            layer.w2_weight_scale = torch.nn.Parameter(
                swizzle_blockscale(layer.w2_weight_scale), requires_grad=False
            )

            # Packed FP4 stores two values per byte, hence the "* 2" when
            # recovering logical sizes from the packed last dims.
            layer.cutlass_moe_params = CutlassMoEParams(
                CutlassMoEType.BlockscaledFP4,
                layer.w13_weight.device,
                num_experts=layer.num_experts,
                intermediate_size_per_partition=layer.w2_weight.shape[2] * 2,
                hidden_size=layer.w13_weight.shape[2] * 2,
            )

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Store the config; the Triton runner is only used by the CUTLASS fallback.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Run the CUTLASS blockscaled-FP4 fused MoE on dispatched tokens."""
        from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids

        # All scale tensors were precomputed in process_weights_after_loading.
        output = cutlass_moe_fp4(
            a=x,
            a1_gscale=layer.w13_input_scale_quant,
            w1_fp4=layer.w13_weight,
            w1_blockscale=layer.w13_weight_scale,
            w1_alphas=layer.g1_alphas,
            a2_gscale=layer.w2_input_scale_quant,
            w2_fp4=layer.w2_weight,
            w2_blockscale=layer.w2_weight_scale,
            w2_alphas=layer.g2_alphas,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            params=layer.cutlass_moe_params,
            apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
        ).to(x.dtype)

        return StandardCombineInput(hidden_states=output)

    def apply_with_router_logits(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> torch.Tensor:
        """Run the FlashInfer TRT-LLM FP4 block-scale MoE, fusing routing
        (softmax/top-k over raw router logits) into the kernel."""
        assert self.use_flashinfer_trtllm

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output

        from flashinfer import fp4_quantize, trtllm_fp4_block_scale_moe

        from sglang.srt.layers.moe.utils import RoutingMethodType

        router_logits = topk_output.router_logits
        topk_config = topk_output.topk_config

        # Quantize input hidden states using fp4_quantize
        hs_fp4_bytes, hs_sf_bytes = fp4_quantize(
            x,
            layer.w13_input_scale_quant,
            self.group_size,  # sf_vec_size
            False,  # use_ue8m0
            False,  # is_sf_swizzled_layout
        )
        # Two FP4 values per byte -> packed hidden dim is half the logical one.
        hs_fp4 = hs_fp4_bytes.reshape(x.shape[0], x.shape[1] // 2)
        hs_scale = hs_sf_bytes.view(torch.float8_e4m3fn).reshape(-1)

        correction_bias = (
            None
            if topk_config.correction_bias is None
            else topk_config.correction_bias.to(x.dtype)
        )

        assert layer.routing_method_type is not None

        # DeepSeekV3 style routing requires float32 router logits
        if layer.routing_method_type == RoutingMethodType.DeepSeekV3:
            router_logits = router_logits.to(torch.float32)

        routed_scaling_factor = self.moe_runner_config.routed_scaling_factor
        routed_scaling_factor = (
            routed_scaling_factor if routed_scaling_factor is not None else 1.0
        )

        # Allocate the output from symmetric memory (when enabled) so the TP
        # group can combine results without an extra copy.
        with use_symmetric_memory(
            get_tp_group(), disabled=not is_allocation_symmetric()
        ):
            num_tokens = hs_fp4.shape[0]
            hidden_size = (
                hs_fp4.shape[-1] * 2
                if hs_fp4.dtype == torch.uint8
                else hs_fp4.shape[-1]
            )
            symm_output = torch.empty(
                num_tokens, hidden_size, dtype=torch.bfloat16, device=hs_fp4.device
            )

        # Kernel returns a tuple; [0] is the finalized hidden states
        # (do_finalize=True writes them into symm_output).
        return trtllm_fp4_block_scale_moe(
            routing_logits=router_logits,
            routing_bias=correction_bias,
            hidden_states=hs_fp4,
            hidden_states_scale=hs_scale,
            gemm1_weights=layer.gemm1_weights_fp4_shuffled,
            gemm1_weights_scale=layer.gemm1_scales_fp4_shuffled.view(
                torch.float8_e4m3fn
            ),
            gemm1_bias=None,
            gemm1_alpha=None,
            gemm1_beta=None,
            gemm1_clamp_limit=None,
            gemm2_weights=layer.gemm2_weights_fp4_shuffled,
            gemm2_weights_scale=layer.gemm2_scales_fp4_shuffled.view(
                torch.float8_e4m3fn
            ),
            gemm2_bias=None,
            output1_scale_scalar=layer.g1_scale_c,
            output1_scale_gate_scalar=layer.g1_alphas,
            output2_scale_scalar=layer.g2_alphas,
            num_experts=layer.num_experts,
            top_k=topk_config.top_k,
            n_group=topk_config.num_expert_group,
            topk_group=topk_config.topk_group,
            intermediate_size=layer.intermediate_size_per_partition,
            local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
            local_num_experts=layer.num_local_experts,
            routed_scaling_factor=routed_scaling_factor,
            routing_method_type=layer.routing_method_type,
            do_finalize=True,
            tune_max_num_tokens=next_power_of_2(hs_fp4.shape[0]),
            output=symm_output,
        )[0]
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate FP8 expert weights plus weight/input scale parameters.

        Scale tensor shapes depend on the weight quantization strategy
        (per-tensor, per-channel, or block-wise).
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        # Weights are stored as FP8 regardless of the activation dtype passed in.
        params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size
            tp_size = get_tensor_model_parallel_world_size()
            block_n, block_k = (
                self.weight_block_size[0],
                self.weight_block_size[1],
            )
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}."
                )
            if tp_size > 1 and intermediate_size_per_partition % block_k != 0:
                # Required by row parallel
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}."
                )

        # WEIGHTS
        # w13 fuses the gate (w1) and up (w3) projections, hence 2x intermediate.
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        # per-tensor quantization
        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            weight_quant_method = FusedMoeWeightScaleSupported.TENSOR.value
        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
            # One scale per output channel.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    1,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
                requires_grad=False,
            )
            weight_quant_method = FusedMoeWeightScaleSupported.CHANNEL.value
        elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
            # One scale per (block_n x block_k) tile; ceil-divide to cover edges.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) // block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_n - 1) // block_n,
                    (intermediate_size_per_partition + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            weight_quant_method = FusedMoeWeightScaleSupported.BLOCK.value
        else:
            raise ValueError(
                f"Unsupported weight quantization strategy: {self.weight_quant.strategy}"
            )

        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update({"quant_method": weight_quant_method})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES (only materialized for static activation quantization)
        if self.static_input_scales:
            assert (
                self.input_quant.strategy == QuantizationStrategy.TENSOR
            ), "Only per-tensor quantization is supported for static input scales"
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            # Dynamic activation quantization: scales computed at runtime.
            layer.w13_input_scale = None
            layer.w2_input_scale = None
    def process_weights_after_loading(self, layer: torch.nn.Module | FusedMoE) -> None:
        """Normalize loaded FP8 scales/weights into the form the kernels need.

        Steps (each conditional on the configured strategy/platform):
        1. Collapse static per-expert input scales to a single max scale.
        2. Convert e4m3fn weights to e4m3fnuz on platforms that require it.
        3. For per-tensor weights, requantize w1/w3 to a shared per-expert scale.
        4. For per-channel weights on aiter, pre-shuffle the weight layout.
        """
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
            if layer.w13_input_scale is None or layer.w2_input_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None."
                )
            if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                layer.w2_input_scale
            ):
                logger.warning(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer."
                )
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if is_fp8_fnuz():
            # Normalize the weights and scales (e4m3fn -> e4m3fnuz, e.g. on ROCm).
            w13_weight, w13_weight_scale, w13_input_scale = (
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                )
            )
            w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
            )
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False
            )
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False
                )
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(
                w2_weight_scale, requires_grad=False
            )
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False
                )

        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max then dequant and requant each expert.
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.num_local_experts):
                start = 0
                # shard 0 = w1 (gate), shard 1 = w3 (up); each spans shard_size rows.
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    (
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        _,
                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])

                    start += shard_size

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )

        if self.weight_quant.strategy == QuantizationStrategy.CHANNEL and _use_aiter:
            with torch.no_grad():
                # Pre-shuffle weights into the (16, 16) tile layout aiter expects.
                layer.w13_weight = torch.nn.Parameter(
                    shuffle_weight(layer.w13_weight.data, (16, 16)),
                    requires_grad=False,
                )
                torch.cuda.empty_cache()
                layer.w2_weight = torch.nn.Parameter(
                    shuffle_weight(layer.w2_weight.data, (16, 16)),
                    requires_grad=False,
                )
                torch.cuda.empty_cache()

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Triton runner serves both the block-wise and per-tensor/channel paths.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Dispatch to aiter's fused MoE (ROCm, per-channel) or the Triton runner."""
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output

        moe_runner_config = self.moe_runner_config

        if _use_aiter and self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
            assert not moe_runner_config.no_combine, "unsupported"
            topk_weights, topk_ids, _ = topk_output
            if moe_runner_config.apply_router_weight_on_input:
                assert (
                    topk_weights.dim() == 2
                ), "`topk_weights` should be in shape (num_tokens, topk)"
                _, topk = topk_weights.shape
                assert (
                    topk == 1
                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
                # Router weight already folded into x; neutralize the weights so the
                # kernel does not apply them a second time.
                x = x * topk_weights.to(x.dtype)
                topk_weights = torch.ones_like(
                    topk_weights, dtype=torch.float32
                )  # topk_weights must be FP32 (float32)
            output = fused_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                activation=(
                    ActivationType.Silu
                    if moe_runner_config.activation == "silu"
                    else ActivationType.Gelu
                ),
                quant_type=QuantType.per_Token,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )
            return StandardCombineInput(hidden_states=output)
        elif self.weight_quant.strategy == QuantizationStrategy.BLOCK:
            # Block-wise scales need block_shape forwarded to the Triton kernel.
            quant_info = TritonMoeQuantInfo(
                w13_weight=layer.w13_weight,
                w2_weight=layer.w2_weight,
                use_fp8_w8a8=True,
                w13_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                a13_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.weight_block_size,
            )
            return self.runner.run(dispatch_output, quant_info)
        else:
            # Per-tensor or per-channel path.
            quant_info = TritonMoeQuantInfo(
                w13_weight=layer.w13_weight,
                w2_weight=layer.w2_weight,
                use_fp8_w8a8=True,
                per_channel_quant=self.weight_quant.strategy
                == QuantizationStrategy.CHANNEL,
                w13_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                a13_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )
            return self.runner.run(dispatch_output, quant_info)
QuantizationStrategy.CHANNEL + and self.input_quant.strategy == QuantizationStrategy.TOKEN + ) + if not per_channel: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found " + f"{self.weight_quant}, {self.input_quant}" + ) + + self.static_input_scales = not self.input_quant.dynamic + if self.static_input_scales: + raise ValueError( + "For INT8 Fused MoE layers, we require channelwise, " + "dynamic per token quantization. Found static input scales." + ) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + params_dtype = torch.int8 + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + w2_weight_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add PER-CHANNEL quantization for FusedMoE.weight_loader. 
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Delegate any post-load packing/layout work to the NPU kernel object.
        self.kernel.process_weights_after_loading(layer)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # No generic runner needed: the NPU kernel executes the MoE itself.
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        # Forward dispatched tokens straight to the NPU dynamic-quant kernel.
        return self.kernel.apply(layer, dispatch_output)
+ config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy + self.group_size = config.group_size + self.actorder = config.actorder + assert config.symmetric, "Only symmetric quantization is supported for MoE" + + if not ( + self.quant_config.quant_format == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS + ): + raise ValueError( + "For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}", + ) + self.num_gpu_experts = num_gpu_experts + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. Will + # shard for TP along the transposed dims + extra_weight_attrs.update( + {"is_transposed": True, "quant_method": self.strategy} + ) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.packed_factor, + 2 * intermediate_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size_per_partition // self.packed_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # In the case where we have actorder/g_idx, + # we do not partition the w2 scales + load_full_w2 = self.actorder and self.group_size != -1 + + if load_full_w2: + w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size + else: + w2_scales_size = 
    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate GPTQ-packed int32 expert weights, scales, shapes, and
        activation-ordering (g_idx) metadata for the WNA16 MoE path."""
        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
        # shard for TP along the transposed dims
        extra_weight_attrs.update(
            {"is_transposed": True, "quant_method": self.strategy}
        )
        # Packed layout: `packed_factor` quantized values per int32 along the
        # input (hidden) dimension.
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size // self.packed_factor,
                2 * intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_packed", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition // self.packed_factor,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_packed", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # In the case where we have actorder/g_idx,
        # we do not partition the w2 scales
        load_full_w2 = self.actorder and self.group_size != -1

        if load_full_w2:
            w2_scales_size = intermediate_size_per_partition * layer.moe_tp_size
        else:
            w2_scales_size = intermediate_size_per_partition

        # With act-order and TP > 1 the K dimension is incomplete per rank.
        self.is_k_full = (not self.actorder) or layer.moe_tp_size == 1

        if self.strategy == "channel":
            # Channelwise => a single scale group; group_size -1 signals this
            # to the marlin kernels.
            num_groups_w2 = num_groups_w13 = 1
            self.group_size = -1
        else:
            num_groups_w2 = w2_scales_size // self.group_size
            num_groups_w13 = hidden_size // self.group_size

        w13_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                num_groups_w13,
                2 * intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale", w13_scale)
        set_weight_attrs(w13_scale, extra_weight_attrs)

        w2_scale = torch.nn.Parameter(
            torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale", w2_scale)
        set_weight_attrs(w2_scale, extra_weight_attrs)
        set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})

        # Original (unpacked) weight shapes, recorded per expert by the loader.
        w2_weight_shape = torch.nn.Parameter(
            torch.empty(num_experts, 2), requires_grad=False
        )
        layer.register_parameter("w2_weight_shape", w2_weight_shape)
        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
        w13_weight_shape = torch.nn.Parameter(
            torch.empty(num_experts, 2), requires_grad=False
        )

        layer.register_parameter("w13_weight_shape", w13_weight_shape)
        set_weight_attrs(w13_weight_shape, extra_weight_attrs)

        # g_idx: per-input-channel group indices for activation reordering.
        w13_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
        set_weight_attrs(w13_g_idx, extra_weight_attrs)

        w2_g_idx = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
        set_weight_attrs(w2_g_idx, extra_weight_attrs)

        w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)

        w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)

        layer.a13_scale = None
        layer.a2_scale = None
        # Marlin repack happens lazily in process_weights_after_loading.
        layer.marlin_state = GPTQMarlinState.REPACK

        if not hasattr(layer, "_original_shapes"):
            layer._original_shapes = {}

        # Force record: these are the target GPTQ shapes for rollback.
        layer._original_shapes["w13_weight_packed"] = tuple(w13_weight.shape)
        layer._original_shapes["w2_weight_packed"] = tuple(w2_weight.shape)

        # Also record the shapes of the scales.
        layer._original_shapes["w2_weight_scale"] = tuple(w2_scale.shape)
        layer._original_shapes["w13_weight_scale"] = tuple(w13_scale.shape)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack loaded GPTQ weights/scales into Marlin format, in place.

        Idempotent (guarded by `is_marlin_converted`); the pre-repack shapes are
        recorded so restore_weights_before_loading can roll them back.
        """
        # Skip if the layer is already converted to Marlin format to prevent double-packing.
        if getattr(layer, "is_marlin_converted", False):
            return

        if not hasattr(layer, "_original_shapes"):
            layer._original_shapes = {}

        def replace_tensor(name, new_t):
            # In-place swap that records the pre-repack shape for rollback.
            target_attr = getattr(layer, name)

            # Only save if the key doesn't exist to prevent overwriting with Marlin shapes.
            if name not in layer._original_shapes:
                # This is a safety check; `create_weights` usually handles this already.
                layer._original_shapes[name] = tuple(target_attr.shape)

            # It is important to use resize_() here since it ensures
            # the same buffer is reused
            target_attr.resize_(new_t.shape)
            target_attr.copy_(new_t)
            del new_t

        num_experts = layer.w13_weight_g_idx.shape[0]
        device = layer.w13_weight_g_idx.device

        # when running models with grouped act order,
        # resort to g_idx values provided in checkpoint
        if self.actorder == "group":
            w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
            w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
            w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
            w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)

            for e in range(num_experts):
                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to(
                    torch.int32
                )
                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to(
                    torch.int32
                )
                w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
                    w13_g_idx_sort_indices[e]
                ]
                w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]]

            replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
            replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)

        else:
            # No act-order: marlin accepts empty g_idx / sort-index tensors.
            layer.w13_weight_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w2_weight_g_idx = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )
            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
                requires_grad=False,
            )

        # Repack qweights: size_k is recovered from the packed dim via packed_factor.
        marlin_w13_qweight = gptq_marlin_moe_repack(
            layer.w13_weight_packed,
            layer.w13_g_idx_sort_indices,
            layer.w13_weight_packed.shape[1] * self.packed_factor,
            layer.w13_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w13_weight_packed", marlin_w13_qweight)
        marlin_w2_qweight = gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
            layer.w2_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w2_weight_packed", marlin_w2_qweight)
        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            layer.w13_weight_scale,
            layer.w13_weight_packed.shape[2],
            layer.w13_weight_scale.shape[2],
            self.group_size,
        )
        replace_tensor("w13_weight_scale", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            layer.w2_weight_scale,
            layer.w2_weight_scale.shape[1]
            * (self.group_size if self.group_size != -1 else self.packed_factor),
            layer.w2_weight_scale.shape[2],
            self.group_size,
        )
        replace_tensor("w2_weight_scale", marlin_w2_scales)

        layer.is_marlin_converted = True

    def restore_weights_before_loading(self, layer: torch.nn.Module):
        """Forcibly resize parameters back to their original shapes (e.g., GPTQ format) before loading weights."""

        if not hasattr(layer, "_original_shapes"):
            return

        for name, orig_shape in layer._original_shapes.items():
            param = getattr(layer, name, None)

            if param is not None and param.shape != orig_shape:
                param.resize_(orig_shape)

        # Allow process_weights_after_loading to repack again.
        layer.is_marlin_converted = False

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Marlin path launches its own fused kernel; no generic runner needed.
        self.moe_runner_config = moe_runner_config
    def apply(
        self,
        layer: torch.nn.Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        """Run the fused Marlin MoE kernel on dispatched tokens (SiLU only)."""
        from sglang.srt.layers.moe.fused_moe_triton.fused_marlin_moe import (
            fused_marlin_moe,
        )
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        assert (
            self.moe_runner_config.activation == "silu"
        ), "Only SiLU activation is supported."

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output

        topk_weights, topk_ids, router_logits = topk_output

        # Get expert_map for EP support: maps global expert ids to local slots.
        expert_map = None
        global_num_experts = -1
        if hasattr(layer, "dispatcher") and hasattr(
            layer.dispatcher, "local_expert_mapping"
        ):
            expert_map = layer.dispatcher.local_expert_mapping
            if expert_map is not None:
                global_num_experts = self.moe_runner_config.num_experts

        output = fused_marlin_moe(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
            layer.w13_weight_scale,
            layer.w2_weight_scale,
            router_logits,
            topk_weights,
            topk_ids,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            g_idx1=layer.w13_weight_g_idx,
            g_idx2=layer.w2_weight_g_idx,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            num_bits=self.num_bits,
            is_k_full=self.is_k_full,
            routed_scaling_factor=self.moe_runner_config.routed_scaling_factor,
        )
        return StandardCombineInput(hidden_states=output)
+ """ + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if getattr(layer, "is_triton_converted", False): + return + + num_experts = layer.w13_weight_packed.shape[0] + + # Convert w13 weights: [E, K//8, N] int32 -> [E, N, K//2] uint8 + w13 = layer.w13_weight_packed.data + w13 = w13.transpose(1, 2).contiguous().view(torch.uint8) + layer.w13_weight_packed = torch.nn.Parameter(w13, requires_grad=False) + + # Convert w2 weights: [E, K//8, N] int32 -> [E, N, K//2] uint8 + w2 = layer.w2_weight_packed.data + w2 = w2.transpose(1, 2).contiguous().view(torch.uint8) + layer.w2_weight_packed = torch.nn.Parameter(w2, requires_grad=False) + + # Convert w13 scales: [E, K//group_size, N] -> [E, N, K//group_size] + w13_scale = layer.w13_weight_scale.data + w13_scale = w13_scale.transpose(1, 2).contiguous() + layer.w13_weight_scale = torch.nn.Parameter(w13_scale, requires_grad=False) + + # Convert w2 scales: [E, K//group_size, N] -> [E, N, K//group_size] + w2_scale = layer.w2_weight_scale.data + w2_scale = w2_scale.transpose(1, 2).contiguous() + layer.w2_weight_scale = torch.nn.Parameter(w2_scale, requires_grad=False) + + layer.is_triton_converted = True + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config) + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: "StandardDispatchOutput", + ) -> "CombineInput": + from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo + + assert ( + self.moe_runner_config.activation == "silu" + ), "Only SiLU activation is supported." 
+ + quant_info = TritonMoeQuantInfo( + w13_weight=layer.w13_weight_packed, + w2_weight=layer.w2_weight_packed, + use_int4_w4a16=True, + w13_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + block_shape=[0, self.group_size], + ) + return self.runner.run(dispatch_output, quant_info) + + +class NPUCompressedTensorsW4A8Int8DynamicMoEMethod(CompressedTensorsMoEMethod): + + ### TODO: Get rid of code duplication with python/sglang/srt/modelslim/modelslim_moe.py @OrangeRedeng @TamirBaydasov + def __init__(self, quantization_config) -> None: + self.group_size = 0 + self.is_per_channel_weight = self.group_size == 0 + self.tp_size = 1 + self.activation_use_clip = ( + self.quantization_config.get("config_groups", {}) + .get("group_1", {}) + .get("activation_use_clip", False) + ) + self.kernel = NPUW4A8Int8DynamicMoEMethod() + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + self.num_experts = num_experts + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} + ) + + # >> weight + w13_output_size = intermediate_size_per_partition + w2_output_size = hidden_size // 2 + w13_weight = torch.nn.Parameter( + torch.empty(num_experts, w13_output_size, hidden_size, dtype=torch.int8), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + w2_output_size, + intermediate_size_per_partition, + dtype=torch.int8, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # >> scale + weight_scale_dtype = torch.int64 if self.activation_use_clip else torch.float32 + w13_weight_scale = 
torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + 1, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + + w2_weight_scale = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, 1, dtype=weight_scale_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # >> offset + w13_weight_offset = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_offset", w13_weight_offset) + set_weight_attrs(w13_weight_offset, extra_weight_attrs) + + w2_weight_offset = torch.nn.Parameter( + torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), + requires_grad=False, + ) + layer.register_parameter("w2_weight_offset", w2_weight_offset) + set_weight_attrs(w2_weight_offset, extra_weight_attrs) + + # >>> special param for w4a8 + if self.activation_use_clip: + self._init_activation_clip_params( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + extra_weight_attrs, + ) + else: + self._init_extra_scale_params( + layer, + num_experts, + hidden_size, + intermediate_size_per_partition, + extra_weight_attrs, + ) + + def _init_activation_clip_params( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + extra_weight_attrs: dict, + ) -> None: + """ + Initializes bias and alpha parameters for quantization schemes that use activation clipping. + + This helper registers `w13_bias`, `w2_bias`, and `w2_alpha`, which are required to + shift and scale the activations or outputs to compensate for the precision loss + introduced by clamping activations. 
+ """ + w13_bias = torch.nn.Parameter( + torch.ones( + num_experts, 2 * intermediate_size_per_partition, dtype=torch.float + ), + requires_grad=False, + ) + layer.register_parameter("w13_bias", w13_bias) + set_weight_attrs(w13_bias, extra_weight_attrs) + + w2_bias = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, dtype=torch.float), + requires_grad=False, + ) + layer.register_parameter("w2_bias", w2_bias) + set_weight_attrs(w2_bias, extra_weight_attrs) + + w2_alpha = torch.nn.Parameter( + torch.ones(num_experts, dtype=torch.float), requires_grad=False + ) + layer.register_parameter("w2_alpha", w2_alpha) + set_weight_attrs(w2_alpha, extra_weight_attrs) + + def _init_extra_scale_params( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + extra_weight_attrs: dict, + ) -> None: + """ + Initializes additional scaling, offset, and bias parameters for quantization schemes without activation clipping. + + This method registers the following parameters: + 1. Scale Biases: `w13_scale_bias` and `w2_scale_bias`. + 2. Secondary Quantization Params (initialized only for grouped quantization): + `w13_weight_scale_second`, `w13_weight_offset_second`, + `w2_weight_scale_second`, and `w2_weight_offset_second`. 
+ """ + if not self.is_per_channel_weight: + w13_weight_scale_second = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_second", w13_weight_scale_second) + set_weight_attrs(w13_weight_scale_second, extra_weight_attrs) + + w13_weight_offset_second = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter( + "w13_weight_offset_second", w13_weight_offset_second + ) + set_weight_attrs(w13_weight_offset_second, extra_weight_attrs) + + w2_weight_scale_second = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale_second", w2_weight_scale_second) + set_weight_attrs(w2_weight_scale_second, extra_weight_attrs) + + w2_weight_offset_second = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_offset_second", w2_weight_offset_second) + set_weight_attrs(w2_weight_offset_second, extra_weight_attrs) + + w13_scale_bias = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w13_scale_bias", w13_scale_bias) + set_weight_attrs(w13_scale_bias, extra_weight_attrs) + + w2_scale_bias = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, 16 // self.tp_size, dtype=torch.float32 + ), + requires_grad=False, + ) + layer.register_parameter("w2_scale_bias", w2_scale_bias) + set_weight_attrs(w2_scale_bias, extra_weight_attrs) + + def 
process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading( + layer, self.is_per_channel_weight, self.activation_use_clip + ) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + return self.kernel.apply(layer, dispatch_output) + + def apply_without_routing_weights( + self, + layer, + hidden_states, + hidden_states_scale, + group_list_type, + group_list, + output_dtype, + ): + return self.kernel.apply_without_routing_weights( + layer, + hidden_states, + hidden_states_scale, + group_list_type, + group_list, + output_dtype, + ) + + +class NPUCompressedTensorsW4A16Int4DynamicMoEMethod(CompressedTensorsMoEMethod): + + def __init__(self, quantization_config) -> None: + self.pack_factor = 8 # weight dtype is int4, but use int32 to create + target = ( + "MoEGMM" if "MoEGMM" in quantization_config.target_scheme_map else "Linear" + ) + if target in quantization_config.target_scheme_map: + self.group_size = quantization_config.target_scheme_map[target][ + "weights" + ].group_size + else: + self.group_size = 128 + + self.kernel = NPUW4A16Int4DynamicMoEMethod() + + # TODO: See if we can merge this method's logic + # with CompressedTensorsWNA16MoEMethod. Need more models and tests. 
+ # @OrangeRedeng @TamirBaydasov + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + self.num_experts = num_experts + if ( + extra_weight_attrs.get( + "moe_intermediate_size", intermediate_size_per_partition + ) + // intermediate_size_per_partition + > 1 + ): + quant_method = FusedMoeWeightScaleSupported.GROUP.value + else: + quant_method = FusedMoeWeightScaleSupported.CHANNEL.value + extra_weight_attrs.update({"quant_method": quant_method}) + # weight + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # scale + weight_scale_dtype = torch.bfloat16 + w13_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + w2_weight_scale = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # offset + 
w13_weight_offset = torch.nn.Parameter( + torch.zeros( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.group_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_offset", w13_weight_offset) + set_weight_attrs(w13_weight_offset, extra_weight_attrs) + + w2_weight_offset = torch.nn.Parameter( + torch.zeros( + num_experts, + hidden_size, + intermediate_size_per_partition // self.group_size, + dtype=weight_scale_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_offset", w2_weight_offset) + set_weight_attrs(w2_weight_offset, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + + return self.kernel.apply(layer, dispatch_output) + + def apply_without_routing_weights( + self, + layer, + hidden_states, + hidden_states_scale, + group_list_type, + group_list, + output_dtype, + ): + return self.kernel.apply_without_routing_weights( + layer, + hidden_states, + hidden_states_scale, + group_list_type, + group_list, + output_dtype, + ) + + +class CompressedTensorsMxInt4MoEMethod(CompressedTensorsMoEMethod): + def __init__(self, quant_config: CompressedTensorsConfig): + self.quant_config = quant_config + config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy + self.group_size = config.group_size + self.actorder = config.actorder + assert ( + config.strategy == "group" + and config.group_size == 32 + and config.num_bits == 4 + ), "MxInt4 only supports group strategy with group size 32" + assert 
config.symmetric, "Only symmetric quantization is supported for MoE" + assert ( + get_moe_runner_backend().is_flashinfer_trtllm() + ), "MxInt4 only supports flashinfer_trtllm backend" + assert ( + not config.actorder + ), "Actorder is not supported by flashinfer_trtllm backend" + self.moe_ep_rank = get_moe_expert_parallel_rank() + + if self.quant_config.quant_format != CompressionFormat.pack_quantized.value: + raise ValueError( + f"For Fused MoE layers, only {CompressionFormat.pack_quantized.value} " + "is supported for the mxint4" + ) + self._cache_permute_indices = {} + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + extra_weight_attrs.update({"quant_method": self.strategy}) + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size // self.packed_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition // self.packed_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + w2_scales_size = intermediate_size_per_partition + num_groups_w2 = w2_scales_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + assert params_dtype == torch.bfloat16 + w13_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * intermediate_size_per_partition, + num_groups_w13, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter( + torch.ones(num_experts, hidden_size, 
num_groups_w2, dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + + w13_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w2_weight_shape = torch.nn.Parameter( + torch.empty(num_experts, 2), requires_grad=False + ) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + + # Adapted from https://github.com/flashinfer-ai/flashinfer/blob/main/tests/moe/test_trtllm_gen_fused_moe.py + def prepare_static_weights_for_kernel( + self, + gemm1_weights, + gemm2_weights, + gemm1_scales, + gemm2_scales, + num_experts, + ): + """Prepare quantized weights for kernel (done offline with weights).""" + + epilogue_tile_m = 128 + gemm1_weights_mxint4_shuffled = [] + gemm1_scales_shuffled = [] + gemm2_weights_mxint4_shuffled = [] + gemm2_scales_shuffled = [] + + def repack(w): + assert w.dim() == 2 and w.dtype == torch.int32 + shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=w.device) + w = (w.unsqueeze(2) >> shifts) & 0x0F + w = (w - 8).to(torch.int8).reshape(w.shape[0], -1, 2) + w = (w[..., 0] & 0x0F) | ((w[..., 1] & 0x0F) << 4) + w = w.to(torch.uint8) + return w + + for i in range(num_experts): + # NOTE(HandH1998): + # the huggingface weight format follows (w/s + 8) to pack, + # however, trtllm requires (w/s) to pack + # we need to convert the weight to trtllm's format first + cur_expert_gemm1_weight = repack(gemm1_weights[i]) + cur_expert_gemm2_weight = repack(gemm2_weights[i]) + + # Calculate the permute indices for the following: + # 1. Reorder rows of W1 and scales for fused gated activation + # 2. 
Shuffle weights and scaling factors for transposed mma output + # for both w3_w1 and w2 weights and scale factors + permute_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + cur_expert_gemm1_weight, + epilogue_tile_m, + ) + gemm1_weights_shuffled = cur_expert_gemm1_weight[ + permute_indices.to(gemm1_weights.device) + ].contiguous() + permute_sf_indices = _maybe_get_cached_w3_w1_permute_indices( + self._cache_permute_indices, + gemm1_scales[i].to(torch.bfloat16), + epilogue_tile_m, + num_elts_per_sf=32, + ) + gemm1_scales_shuffled.append( + block_scale_interleave( + gemm1_scales[i] + .to(torch.bfloat16)[permute_sf_indices.to(gemm1_scales.device)] + .contiguous() + ) + ) + + permute_indices = get_w2_permute_indices_with_cache( + self._cache_permute_indices, + cur_expert_gemm2_weight, + epilogue_tile_m, + ) + gemm2_weights_shuffled = cur_expert_gemm2_weight[ + permute_indices.to(gemm2_weights.device) + ].contiguous() + + permute_sf_indices = get_w2_permute_indices_with_cache( + self._cache_permute_indices, + gemm2_scales[i].to(torch.bfloat16), + epilogue_tile_m, + num_elts_per_sf=16, + ) + gemm2_scales_shuffled.append( + block_scale_interleave( + gemm2_scales[i] + .to(torch.bfloat16)[permute_sf_indices.to(gemm2_scales.device)] + .contiguous() + ) + ) + + block_k = 128 + gemm1_weights_shuffled = convert_to_block_layout( + gemm1_weights_shuffled.view(torch.uint8), block_k + ) + gemm2_weights_shuffled = convert_to_block_layout( + gemm2_weights_shuffled.view(torch.uint8), block_k + ) + + gemm1_weights_mxint4_shuffled.append(gemm1_weights_shuffled) + gemm2_weights_mxint4_shuffled.append(gemm2_weights_shuffled) + + gemm1_weights_mxint4_shuffled = torch.stack(gemm1_weights_mxint4_shuffled) + gemm2_weights_mxint4_shuffled = torch.stack(gemm2_weights_mxint4_shuffled) + gemm1_scales_shuffled = torch.stack(gemm1_scales_shuffled).view(torch.bfloat16) + gemm2_scales_shuffled = torch.stack(gemm2_scales_shuffled).view(torch.bfloat16) + + return ( + 
gemm1_weights_mxint4_shuffled, + gemm1_scales_shuffled, + gemm2_weights_mxint4_shuffled, + gemm2_scales_shuffled, + ) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + num_experts = layer.w13_weight_packed.shape[0] + ( + gemm1_weights_mxint4_shuffled, + gemm1_scales_shuffled, + gemm2_weights_mxint4_shuffled, + gemm2_scales_shuffled, + ) = self.prepare_static_weights_for_kernel( + layer.w13_weight_packed, + layer.w2_weight_packed, + layer.w13_weight_scale, + layer.w2_weight_scale, + num_experts=num_experts, + ) + replace_parameter(layer, "w13_weight_packed", gemm1_weights_mxint4_shuffled) + replace_parameter(layer, "w2_weight_packed", gemm2_weights_mxint4_shuffled) + replace_parameter(layer, "w13_weight_scale", gemm1_scales_shuffled) + replace_parameter(layer, "w2_weight_scale", gemm2_scales_shuffled) + + def create_moe_runner( + self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig + ): + self.moe_runner_config = moe_runner_config + + def apply( + self, + layer: torch.nn.Module, + dispatch_output: StandardDispatchOutput, + ) -> CombineInput: + from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput + + assert ( + self.moe_runner_config.is_gated + ), "Only gated MoEs are supported for flashinfer mxint4" + + x = dispatch_output.hidden_states + topk_output = dispatch_output.topk_output + + router_logits = topk_output.router_logits + topk_config = topk_output.topk_config + correction_bias = ( + None + if topk_config.correction_bias is None + else topk_config.correction_bias.to(x.dtype) + ) + + local_num_experts = self.moe_runner_config.num_local_experts + routing_method_type = layer.routing_method_type + assert routing_method_type is not None + # DeepSeekV3 style routing requires float32 router logits, + # see this PR for details: https://github.com/flashinfer-ai/flashinfer/commit/d84e1d560da0a27961c19ca788d96c19cb9dcfb6 + if routing_method_type == RoutingMethodType.DeepSeekV3: + router_logits = 
router_logits.to(torch.float32) + routed_scaling_factor = self.moe_runner_config.routed_scaling_factor + routed_scaling_factor = ( + routed_scaling_factor if routed_scaling_factor is not None else 1.0 + ) + + with use_symmetric_memory( + get_tp_group(), disabled=not is_allocation_symmetric() + ): + num_tokens = x.shape[0] + hidden_size = x.shape[-1] + symm_output = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=x.device + ) + + output = trtllm_mxint4_block_scale_moe( + routing_logits=router_logits, # float + routing_bias=correction_bias, + hidden_states=x, + gemm1_weights=layer.w13_weight_packed, + gemm1_weights_scale=layer.w13_weight_scale, + gemm1_alpha=self.moe_runner_config.gemm1_alpha, + gemm1_beta=None, + gemm1_clamp_limit=self.moe_runner_config.gemm1_clamp_limit, + gemm2_weights=layer.w2_weight_packed, + gemm2_weights_scale=layer.w2_weight_scale, + num_experts=self.moe_runner_config.num_experts, + top_k=topk_config.top_k, + n_group=topk_config.num_expert_group, + topk_group=topk_config.topk_group, + intermediate_size=self.moe_runner_config.intermediate_size_per_partition, + local_expert_offset=self.moe_ep_rank * local_num_experts, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + routing_method_type=routing_method_type, + tune_max_num_tokens=next_power_of_2(x.shape[0]), + output=symm_output, + ) + + return StandardCombineInput(hidden_states=output) diff --git a/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..c098ef2dbb9a --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + 
"num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..6f5adbb93612 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + 
"num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..6f5adbb93612 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json 
b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..4225c78eb72c --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..4225c78eb72c --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + 
"BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + 
"num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..5e6789d00e0a --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + 
"num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + 
"BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..49ac14d2a576 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 
16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..dcbb0efc53e4 --- /dev/null +++ 
b/python/sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..dfe5c1e43d68 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 
1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + 
"2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..dfe5c1e43d68 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 
64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git 
a/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..a87f5de1b183 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 
16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json new file mode 100644 index 000000000000..468f9e78da00 --- /dev/null +++ b/python/sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X_VF,dtype=fp8_w8a8,block_shape=[128, 128].json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + 
"BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, + "num_warps": 4, + 
"num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format b/python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format deleted file mode 120000 index 5a7a8cea7bb0..000000000000 --- a/python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format +++ /dev/null @@ -1 +0,0 @@ -../../../../../sgl-kernel/.clang-format \ No newline at end of file diff --git a/python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format b/python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format new file mode 100644 index 000000000000..afbd654a7903 --- /dev/null +++ b/python/sglang/srt/mem_cache/cpp_radix_tree/.clang-format @@ -0,0 +1,15 @@ +BasedOnStyle: Google +IndentWidth: 2 +ColumnLimit: 120 +AllowShortFunctionsOnASingleLine: Empty +DerivePointerAlignment: false +PointerAlignment: Left +NamespaceIndentation: None +SortIncludes: true +AllowShortLoopsOnASingleLine: false +BinPackParameters: false # Prevents packing parameters in declarations +BinPackArguments: false # Prevents packing arguments in function calls +AlignAfterOpenBracket: AlwaysBreak # Forces a break after the opening parenthesis +AlignOperands: Align # Aligns arguments vertically 
+PenaltyBreakBeforeFirstCallParameter: 1 # Encourages breaking before the first argument +PenaltyReturnTypeOnItsOwnLine: 100 # Keeps return type with function name diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index b0b2ede6dbde..24abffb63c86 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -362,8 +362,21 @@ def __init__(self, model_runner: ModelRunner): # Capture try: + orig_enable_torch_compile = False + if self.enable_torch_compile: + orig_enable_torch_compile = True + self.enable_torch_compile = False + orig_compile_bs = self.compile_bs + self.compile_bs = [] + with model_capture_mode(): self.capture() + + self.enable_torch_compile = orig_enable_torch_compile + if orig_enable_torch_compile: + self.compile_bs = orig_compile_bs + set_torch_compile_config() + except RuntimeError as e: raise Exception( f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}" @@ -678,6 +691,18 @@ def capture_one_batch_size( ) self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens) + if self.enable_torch_compile and (bs in self.compile_bs): + if not hasattr(self, "_compiled_forward"): + compiled_forward = torch.compile(forward) + for _ in range(2): + compiled_forward( + input_ids, + forward_batch.positions, + forward_batch, + ) + self._compiled_forward = compiled_forward + forward = self._compiled_forward + if lora_ids is not None: self.model_runner.lora_manager.prepare_lora_batch(forward_batch) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 988ed91e831d..21f7e1cf5d15 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -2629,7 +2629,14 @@ def forward( pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Union[torch.Tensor, PPProxyTensors]: total_num_layers = self.end_layer - self.start_layer - 
device = input_embeds.device if input_embeds is not None else input_ids.device + if input_embeds is not None: + device = input_embeds.device + elif input_ids is not None: + device = input_ids.device + else: + # For non-first PP rank, get device from pp_proxy_tensors + assert pp_proxy_tensors is not None, "pp_proxy_tensors required for non-first PP rank" + device = pp_proxy_tensors["hidden_states"].device zero_allocator = BumpAllocator( buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1), dtype=torch.float32, diff --git a/python/sglang/srt/models/kimi_k25.py b/python/sglang/srt/models/kimi_k25.py index 309eb539a9a7..3ca1da861017 100644 --- a/python/sglang/srt/models/kimi_k25.py +++ b/python/sglang/srt/models/kimi_k25.py @@ -1,6 +1,6 @@ import logging from copy import deepcopy -from typing import Iterable, List, Optional, Sequence, Tuple +from typing import Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -9,11 +9,13 @@ from transformers import activations from sglang.srt.configs.kimi_k25 import KimiK25Config, KimiK25VisionConfig +from sglang.srt.distributed.parallel_state import get_pp_group from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.managers.mm_utils import ( MultiModalityDataPaddingPatternMultimodalTokens, general_mm_embed_routine, ) +from sglang.srt.model_executor.forward_batch_info import PPProxyTensors try: from transformers.activations import PytorchGELUTanh @@ -659,6 +661,9 @@ def __init__( self.language_model = DeepseekV3ForCausalLM(config.text_config, quant_config) + # Initialize PP group for pipeline parallelism support + self.pp_group = get_pp_group() + # Ensure that the dtype of the vision_tower and mm_projector matches that of the language_model. # This solves the dtype mismatch issue when using device_map="auto" and torch_dtype. 
if hasattr(self.language_model, "dtype"): @@ -695,7 +700,20 @@ def forward( positions: torch.Tensor, forward_batch: ForwardBatch, get_embedding: bool = False, - ): + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[torch.Tensor, PPProxyTensors]: + """Forward pass with Pipeline Parallelism support. + + Args: + input_ids: Input token IDs + positions: Position IDs + forward_batch: Forward batch information + get_embedding: Whether to return embeddings + pp_proxy_tensors: Pipeline parallelism proxy tensors for inter-stage communication + + Returns: + Hidden states or logits depending on PP rank + """ hidden_states = general_mm_embed_routine( input_ids=input_ids, forward_batch=forward_batch, @@ -704,6 +722,7 @@ def forward( Modality.IMAGE: self.get_image_feature, }, positions=positions, + pp_proxy_tensors=pp_proxy_tensors, ) return hidden_states @@ -745,3 +764,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): EntryClass = [KimiK25ForConditionalGeneration] + diff --git a/sgl-kernel/csrc/allreduce/deterministic_all_reduce_hip.hip b/sgl-kernel/csrc/allreduce/deterministic_all_reduce_hip.hip new file mode 100644 index 000000000000..1de97463cc3c --- /dev/null +++ b/sgl-kernel/csrc/allreduce/deterministic_all_reduce_hip.hip @@ -0,0 +1,179 @@ +#include "hip/hip_runtime.h" +// Deterministic All-Reduce for ROCm/HIP +// +// This is a wrapper that forces the use of the existing 1-stage all-reduce kernel +// (cross_device_reduce_1stage) which is inherently deterministic due to fixed +// accumulation ordering (no atomics, no race conditions). +// +// How the 1-stage kernel works: +// - Each GPU reads ALL data from ALL other GPUs via direct memory access +// - Each GPU reduces the data locally in a fixed order +// - Result: every GPU has the complete reduced output +// +// This is NOT a reduce-scatter + all-gather (RS+AG) approach. 
+// The 2-stage kernel (cross_device_reduce_2stage) implements RS+AG but may have +// non-deterministic behavior, so we explicitly avoid it here. + +#include +#include +#include +#include + +#include "custom_all_reduce_hip.cuh" + +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +// Helper function for weak contiguity check +bool _is_weak_contiguous_det(torch::Tensor& t) { + return t.is_contiguous() || + (t.storage().nbytes() - t.storage_offset() * t.element_size() == t.numel() * t.element_size()); +} + +// Deterministic all-reduce for registered buffers (ROCm) +// Uses the 1-stage kernel which is deterministic (fixed ordering) +void deterministic_all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp)); + auto stream = c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(_is_weak_contiguous_det(out)); + TORCH_CHECK(_is_weak_contiguous_det(inp)); + + auto fa = reinterpret_cast(_fa); + + // For ROCm, manually call the 1-stage kernel to ensure deterministic ordering + // Get rank data pointer + sglang::RankData* ptrs; + hipStreamCaptureStatus status; + AT_CUDA_CHECK(hipStreamIsCapturing(stream, &status)); + if (status == hipStreamCaptureStatusActive) { + ptrs = fa->d_rank_data_base_ + fa->graph_unreg_buffers_.size(); + fa->graph_unreg_buffers_.push_back(inp.data_ptr()); + } else { + auto it = fa->buffers_.find(inp.data_ptr()); + if (it == fa->buffers_.end()) { + throw std::runtime_error("buffer not registered!"); + } + ptrs = it->second; + } + + int size = out.numel(); + int threads = 512; + + switch (out.scalar_type()) { + case at::ScalarType::Float: { + using T = float; + using P = typename sglang::packed_t::P; + auto d = P::size; + if (size % d != 0) { + throw std::runtime_error("size must be multiple of " + std::to_string(d)); 
+ } + size /= d; + int blocks = std::min(16, (size + threads - 1) / threads); + // Always use 1-stage kernel for determinism + switch (fa->world_size_) { + case 2: + hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + case 4: + hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + case 6: + hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + case 8: + hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + default: + throw std::runtime_error("world_size must be in (2,4,6,8)"); + } + break; + } + case at::ScalarType::Half: { + using T = half; + using P = typename sglang::packed_t::P; + auto d = P::size; + if (size % d != 0) { + throw std::runtime_error("size must be multiple of " + std::to_string(d)); + } + size /= d; + int blocks = std::min(16, (size + threads - 1) / threads); + switch (fa->world_size_) { + case 2: + hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + case 4: + hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + case 6: + hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + case 8: + 
hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + default: + throw std::runtime_error("world_size must be in (2,4,6,8)"); + } + break; + } +#if (__HIP_ARCH__ >= 800 || !defined(__HIP_ARCH__)) + case at::ScalarType::BFloat16: { + using T = nv_bfloat16; + using P = typename sglang::packed_t::P; + auto d = P::size; + if (size % d != 0) { + throw std::runtime_error("size must be multiple of " + std::to_string(d)); + } + size /= d; + int blocks = std::min(16, (size + threads - 1) / threads); + switch (fa->world_size_) { + case 2: + hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + case 4: + hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + case 6: + hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + case 8: + hipLaunchKernelGGL((sglang::cross_device_reduce_1stage), dim3(blocks), dim3(threads), 0, stream, + ptrs, fa->sg_, fa->self_sg_, reinterpret_cast(out.data_ptr()), fa->rank_, size); + break; + default: + throw std::runtime_error("world_size must be in (2,4,6,8)"); + } + break; + } +#endif + default: + throw std::runtime_error("deterministic allreduce only supports float32, float16 and bfloat16"); + } +} + +// Deterministic all-reduce for unregistered buffers (ROCm) +void deterministic_all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out) { + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp)); + auto stream = 
c10::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + + auto input_size = inp.numel() * inp.element_size(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), + "registered buffer is too small to contain the input"); + AT_CUDA_CHECK(hipMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), + input_size, hipMemcpyDeviceToDevice, stream)); + deterministic_all_reduce_reg(_fa, reg_buffer, out); +} diff --git a/sgl-kernel/csrc/allreduce/quick_all_reduce.hip b/sgl-kernel/csrc/allreduce/quick_all_reduce.hip new file mode 100644 index 000000000000..de07cadce10b --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce.hip @@ -0,0 +1,112 @@ +// !!! This is a file automatically generated by hipify!!! +#include +#include +#include +#include + +#ifdef USE_ROCM + +#include "quick_all_reduce_hip.h" + +quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, std::optional qr_max_size) { + if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); + if (world_size == 6) throw std::invalid_argument("world size == 6 is not supported"); + if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in"); + quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms(); + fptr->init(world_size, rank, qr_max_size); + return (quickreduce::fptr_t)fptr; +} + +void qr_destroy(quickreduce::fptr_t _fa) { + if (_fa) { + auto fa = reinterpret_cast(_fa); + fa->destroy(); + delete fa; + } +} + +torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + hipIpcMemHandle_t handle = fa->get_handle(); + auto options = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = torch::empty({static_cast(sizeof(hipIpcMemHandle_t))}, options); + 
std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); + return data_handle; +} + +void qr_open_handles(quickreduce::fptr_t _fa, const std::vector& handles) { + auto fa = reinterpret_cast(_fa); + std::vector ipc_handles; + ipc_handles.reserve(handles.size()); + for (auto& handle : handles) { + // Ensure the tensor is on the same device as the current device. + hipIpcMemHandle_t ipc_handle; + std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t)); + ipc_handles.push_back(ipc_handle); + } + fa->open_ipc_handles(ipc_handles); +} + +void qr_all_reduce( + quickreduce::fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, int64_t quant_level, bool cast_bf2half) { + auto fa = reinterpret_cast(_fa); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(inp)); + auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize); + if (out.scalar_type() == at::ScalarType::Half) { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + if (cast_bf2half) { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } else { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), + quant_level, + stream); + } + } else { + throw std::runtime_error("quick allreduce only supports float16 and bfloat16"); + } +} + +int64_t qr_max_size() { + // The default is 2GB (2,147,483,648 bytes) + return static_cast(std::numeric_limits::max()) + 1; +} + +#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; 
\ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; + +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true) + +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false) + +#endif // USE_ROCM diff --git a/sgl-kernel/csrc/allreduce/quick_all_reduce_hip.h b/sgl-kernel/csrc/allreduce/quick_all_reduce_hip.h new file mode 100644 index 000000000000..20c644047138 --- /dev/null +++ b/sgl-kernel/csrc/allreduce/quick_all_reduce_hip.h @@ -0,0 +1,238 @@ +// !!! This is a file automatically generated by hipify!!! +#pragma once + +#include + +#include + +#include "quick_all_reduce.cuh" + +#define HIP_CHECK(err) \ + do { \ + hipError_t err_ = (err); \ + if (err_ != hipSuccess) { \ + std::printf("HIP error %d at %s:%d. 
%s\n", err_, __FILE__, __LINE__, hipGetErrorString(err_)); \ + throw std::runtime_error("HIP error"); \ + } \ + } while (0) + +namespace quickreduce { +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +template +__global__ __quickreduce_launch_bounds_two_shot__ static void allreduce_prototype_twoshot( + T const* A, + T* B, + uint32_t N, + uint32_t num_blocks, + int rank, + uint8_t** dbuffer_list, + uint32_t data_offset, + uint32_t flag_color, + int64_t data_size_per_phase) { + int block = blockIdx.x; + int grid = gridDim.x; + + while (block < num_blocks) { + AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset, flag_color, data_size_per_phase); + block += grid; + flag_color++; + } +} + +#define TWOSHOT_DISPATCH(__codec) \ + if (world_size == 2) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color, \ + this->kMaxProblemSize); \ + } else if (world_size == 4) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color, \ + this->kMaxProblemSize); \ + } else if (world_size == 8) { \ + using LineCodec = __codec; \ + using AllReduceKernel = AllReduceTwoshot; \ + hipLaunchKernelGGL( \ + (allreduce_prototype_twoshot), \ + dim3(grid), \ + dim3(kBlockTwoShot), \ + 0, \ + stream, \ + A, \ + B, \ + N, \ + num_blocks, \ + rank, \ + dbuffer_list, \ + data_offset, \ + flag_color, \ + this->kMaxProblemSize); \ + } + +enum QuickReduceQuantLevel { + F16 = 0, + INT8 = 1, + INT6 = 2, + INT4 = 3, +}; + +struct DeviceComms { + // Max problem size is 2GB (in bytes) or half of 
uint32_t max value. + int64_t kMaxProblemSize = static_cast(std::numeric_limits::max()) + 1; + + // Max TP-8 + static int constexpr kMaxWorldSize = 8; + + bool initialized = false; + uint32_t flag_color = 1; + int world_size; + int rank; + + uint8_t* dbuffer; + uint8_t** dbuffer_list; + hipIpcMemHandle_t buffer_ipc_handle; + std::vector all_buffer_ipc_handles; + std::vector buffer_list; + uint32_t data_offset; + + DeviceComms() : initialized(false), world_size(1), rank(0) {} + ~DeviceComms() { + destroy(); + } + + void init(int world_size, int rank, std::optional max_problem_size = std::nullopt) { + destroy(); + this->world_size = world_size; + this->rank = rank; + if (max_problem_size.has_value() && max_problem_size.value() > 0) { + this->kMaxProblemSize = max_problem_size.value(); + } + // Allocate buffer size for worst case: F16 2-stage buffer. + uint32_t flags_buffer_size = 2 * world_size * kMaxNumBlocks * sizeof(uint32_t); + static int64_t data_buffer_size = 2 * this->kMaxProblemSize; + int64_t total_buffer_size = flags_buffer_size + data_buffer_size; + data_offset = flags_buffer_size; + HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size, hipDeviceMallocUncached)); + + // Clear the flags buffer. + HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size)); + + // Device-side list of IPC buffers. + buffer_list.resize(world_size); + HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*))); + + // Create IPC handles for rank's communication buffer. 
+ all_buffer_ipc_handles.resize(world_size); + HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer)); + + initialized = true; + } + int get_world_size() { + return world_size; + } + int get_rank() { + return rank; + } + bool status() { + return initialized; + } + hipIpcMemHandle_t const get_handle() { + return buffer_ipc_handle; + } + + void destroy() { + if (initialized) { + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i])); + } + } + + HIP_CHECK(hipFree(dbuffer)); + HIP_CHECK(hipFree(dbuffer_list)); + + initialized = false; + } + } + + void open_ipc_handles(std::vector const& ipc_handles) { + assert(ipc_handles.size() == all_buffer_ipc_handles.size()); + for (int i = 0; i < world_size; i++) { + all_buffer_ipc_handles[i] = ipc_handles[i]; + } + + // Open device memory access to the IPC communication buffers. + // Note: For our own rank, we do not need to open a handle. + for (int i = 0; i < world_size; i++) { + if (i != rank) { + HIP_CHECK( + hipIpcOpenMemHandle((void**)&buffer_list[i], all_buffer_ipc_handles[i], hipIpcMemLazyEnablePeerAccess)); + } else { + buffer_list[i] = dbuffer; + } + } + + HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(), world_size * sizeof(uint8_t*), hipMemcpyHostToDevice)); + } + + template + void allreduce(T const* A, T* B, uint32_t N, int quant_level, hipStream_t stream) { + if (world_size != 2 && world_size != 4 && world_size != 8) { + throw std::runtime_error("All Reduce not supported for world_size = " + std::to_string(world_size)); + } + + // Configuration. 
+ uint32_t msg_size = N * sizeof(T); + uint32_t num_blocks = divceil(msg_size, kTileSize); + uint32_t grid = min(kMaxNumBlocks, num_blocks); + auto quant_level_ = static_cast(quant_level); + switch (quant_level_) { + case QuickReduceQuantLevel::INT8: + TWOSHOT_DISPATCH(CodecQ8) + break; + case QuickReduceQuantLevel::INT6: + TWOSHOT_DISPATCH(CodecQ6) + break; + case QuickReduceQuantLevel::INT4: + TWOSHOT_DISPATCH(CodecQ4) + break; + default: + TWOSHOT_DISPATCH(CodecFP) + break; + } + HIP_CHECK(hipGetLastError()); + // Rotate the flag color. + flag_color += divceil(N, grid); + } +}; + +} // namespace quickreduce diff --git a/sgl-kernel/csrc/elementwise/activation.hip b/sgl-kernel/csrc/elementwise/activation.hip new file mode 100644 index 000000000000..25cc9774e630 --- /dev/null +++ b/sgl-kernel/csrc/elementwise/activation.hip @@ -0,0 +1,172 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include + +#ifndef USE_ROCM + +#include + +#include "utils_hip.h" + +#else +#include "hip/hip_act_and_mul_hip.cuh" +#endif + +// Adapted from flashinfer activation +// https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/csrc/activation.cu#L44 + +namespace detail { + +template +__device__ __forceinline__ float to_f32(const T& x) { +#if USE_ROCM + return castToFloat(x); +#else + return static_cast(x); +#endif +} + +template +__device__ __forceinline__ T from_f32(float f32) { +#if USE_ROCM + return castFromFloat(f32); +#else + return static_cast(f32); +#endif +} + +} // namespace detail + +template +__device__ __forceinline__ T silu(const T& x) { + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val / (1.0f + expf(-f32_val))); +} + +template +__device__ __forceinline__ T gelu(const T& x) { + constexpr float kAlpha = M_SQRT1_2; + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val * (0.5f * (1.0f + erf(f32_val * kAlpha)))); +} + +// gelu_quick(x) = x * torch.sigmoid(1.702 * x) +template +__device__ __forceinline__ T gelu_quick_act(const T& x) { + float f32_val = detail::to_f32(x); + return detail::from_f32(f32_val / (1.0f + expf(-f32_val * 1.702f))); +} + +template +__device__ __forceinline__ T gelu_tanh(const T& x) { + constexpr float kAlpha = 0.044715f; + constexpr float kBeta = 0.7978845608028654f; + float f32_val = detail::to_f32(x); + const float cdf = 0.5f * (1.0f + tanhf((kBeta * (f32_val + kAlpha * f32_val * f32_val * f32_val)))); + return detail::from_f32(f32_val * cdf); +} + +void silu_and_mul(at::Tensor& out, at::Tensor& input) { + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + + 
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(::min(d / vec_size, 1024U)); +#if USE_ROCM + hipLaunchKernelGGL(( sgl_hip::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else + hipLaunchKernelGGL(( flashinfer::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#endif + return true; + }); +} + +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input) { + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(::min(d / vec_size, 1024U)); +#if USE_ROCM + hipLaunchKernelGGL(( sgl_hip::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else + hipLaunchKernelGGL(( flashinfer::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#endif + return true; + }); +} + +void gelu_and_mul(at::Tensor& out, at::Tensor& input) { + int d = input.size(-1) / 2; + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(::min(d / vec_size, 1024U)); +#if USE_ROCM + hipLaunchKernelGGL(( 
sgl_hip::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#else + hipLaunchKernelGGL(( flashinfer::activation::act_and_mul_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); +#endif + + return true; + }); +} + +#if USE_ROCM +void gelu_quick(at::Tensor& out, const at::Tensor& input) { + int d = input.size(-1); + int64_t num_tokens = input.numel() / input.size(-1); + dim3 grid(num_tokens); + + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(input)); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(input.scalar_type(), c_type, [&] { + uint32_t vec_size = 16 / sizeof(c_type); + dim3 block(::min(d / vec_size, 1024U)); + hipLaunchKernelGGL(( sgl_hip::activation::act_only_kernel) + , dim3(grid), dim3(block), 0, stream, static_cast(out.data_ptr()), static_cast(input.data_ptr()), d); + + return true; + }); +} +#endif diff --git a/sgl-kernel/csrc/elementwise/pos_enc.hip b/sgl-kernel/csrc/elementwise/pos_enc.hip new file mode 100644 index 000000000000..ecd6c8994420 --- /dev/null +++ b/sgl-kernel/csrc/elementwise/pos_enc.hip @@ -0,0 +1,210 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Adapted from +// https://github.com/vllm-project/vllm/blob/014ece97c7aa49084a1119dca792af081a18dbc1/csrc/pos_encoding_kernels.cu + +#include +#include +#include + +#include "utils_hip.h" + +template +inline __device__ void apply_token_rotary_embedding( + scalar_t* __restrict__ arr, + const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, + int rot_offset, + int embed_dim) { + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. 
+ x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = SGLANG_LDG(cos_ptr + x_index); + sin = SGLANG_LDG(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. + x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = SGLANG_LDG(cos_ptr + x_index / 2); + sin = SGLANG_LDG(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* cache_ptr, + const int head_size, + const int num_heads, + const int num_kv_heads, + const int rot_dim, + const int token_idx, + const int64_t query_stride, + const int64_t key_stride, + const int64_t head_stride) { + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_stride; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + + if (key != nullptr) { + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_stride; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + } +} + +template +__global__ void rotary_embedding_kernel( + const int64_t* __restrict__ 
    positions,  // [batch_size, seq_len] or
                // [num_tokens]
    scalar_t* __restrict__ query,  // [batch_size, seq_len, num_heads,
                                   // head_size] or [num_tokens, num_heads,
                                   // head_size]
    scalar_t* __restrict__ key,    // nullptr or
                                   // [batch_size, seq_len, num_kv_heads,
                                   // head_size] or [num_tokens, num_kv_heads,
                                   // head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim //
                                                 // 2]
    const int rot_dim,
    const int64_t query_stride,
    const int64_t key_stride,
    const int64_t head_stride,
    const int num_heads,
    const int num_kv_heads,
    const int head_size) {
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  // cos/sin values for this position start at pos * rot_dim in the cache.
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  apply_rotary_embedding(
      query,
      key,
      cache_ptr,
      head_size,
      num_heads,
      num_kv_heads,
      rot_dim,
      token_idx,
      query_stride,
      key_stride,
      head_stride);
}

// Host-side launcher: validates shapes, derives strides/head counts from the
// tensors, and launches one thread block per token.
// NOTE(review): angle-bracket template arguments were stripped by extraction
// throughout this file (e.g. `std::optional` is presumably
// `std::optional<torch::Tensor>`, `data_ptr()` presumably carries an explicit
// element type, and the kernel launch presumably carries
// `<scalar_t, IS_NEOX>`-style arguments) — restore against the upstream source
// before compiling.
void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,      // [batch_size, seq_len, num_heads * head_size] or
                               // [num_tokens, num_heads * head_size] or
                               // [batch_size, seq_len, num_heads, head_size] or
                               // [num_tokens, num_heads, head_size]
    std::optional key,
    // null or
    // [batch_size, seq_len, num_kv_heads * head_size] or
    // [num_tokens, num_kv_heads * head_size] or
    // [batch_size, seq_len, num_heads, head_size] or
    // [num_tokens, num_heads, head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  // num_tokens = batch_size * seq_len
  int64_t num_tokens = positions.numel();
  int positions_ndim = positions.dim();

  // Make sure num_tokens dim is consistent across positions, query, and key
  TORCH_CHECK(
      positions_ndim == 1 || positions_ndim == 2, "positions must have shape [num_tokens] or [batch_size, seq_len]");
  if (positions_ndim == 1) {
    TORCH_CHECK(
        query.size(0) == positions.size(0) && (!key.has_value() || key->size(0) == positions.size(0)),
        "query, key and positions must have the same number of tokens");
  }
  if (positions_ndim == 2) {
    TORCH_CHECK(
        query.size(0) == positions.size(0) && (!key.has_value() || key->size(0) == positions.size(0)) &&
            query.size(1) == positions.size(1) && (!key.has_value() || key->size(1) == positions.size(1)),
        "query, key and positions must have the same batch_size and seq_len");
  }

  // Make sure head_size is valid for query and key
  // hidden_size = num_heads * head_size
  int query_hidden_size = query.numel() / num_tokens;
  int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0;
  TORCH_CHECK(query_hidden_size % head_size == 0);
  TORCH_CHECK(key_hidden_size % head_size == 0);

  // Make sure query and key have consistent number of heads
  int num_heads = query_hidden_size / head_size;
  // When key is absent, num_kv_heads defaults to num_heads so the
  // divisibility check below is trivially satisfied.
  int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads;
  TORCH_CHECK(num_heads % num_kv_heads == 0);

  int rot_dim = cos_sin_cache.size(1);
  // For 1-D positions the token dim is dim 0; for 2-D it is dim 1.
  int seq_dim_idx = positions_ndim - 1;
  int64_t query_stride = query.stride(seq_dim_idx);
  int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0;
  // Determine head stride: for [*, heads, head_size] use stride of last dim;
  // for flat [*, heads*head_size], heads blocks are contiguous of size
  // head_size
  int query_ndim = query.dim();
  int64_t head_stride = (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size;

  // One block per token; each (head, rot_dim/2) pair handled by one thread,
  // capped at 512 threads per block.
  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(query));
  const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
  DISPATCH_FLOAT_TYPES(query.scalar_type(), "rotary_embedding", [&] {
    if (is_neox) {
      hipLaunchKernelGGL(( rotary_embedding_kernel), dim3(grid), dim3(block), 0, stream,
          positions.data_ptr(),
          query.data_ptr(),
          key.has_value() ? key->data_ptr() : nullptr,
          cos_sin_cache.data_ptr(),
          rot_dim,
          query_stride,
          key_stride,
          head_stride,
          num_heads,
          num_kv_heads,
          head_size);
    } else {
      hipLaunchKernelGGL(( rotary_embedding_kernel), dim3(grid), dim3(block), 0, stream,
          positions.data_ptr(),
          query.data_ptr(),
          key.has_value() ? key->data_ptr() : nullptr,
          cos_sin_cache.data_ptr(),
          rot_dim,
          query_stride,
          key_stride,
          head_stride,
          num_heads,
          num_kv_heads,
          head_size);
    }
  });
}
diff --git a/sgl-kernel/csrc/elementwise/topk.hip b/sgl-kernel/csrc/elementwise/topk.hip new file mode 100644 index 000000000000..7a423812c038 --- /dev/null +++ b/sgl-kernel/csrc/elementwise/topk.hip @@ -0,0 +1,547 @@
// !!! This is a file automatically generated by hipify!!!
/**
 * @NOTE: This file is adapted from
 * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py
 * We:
 * 1. adapt from tilelang to pure cuda
 * 2. optimize the performance a little
 * 3. fix the potential illegal memory access
 */
// NOTE(review): the header names of these #include directives were stripped
// by extraction (`<...>` contents lost) — restore from upstream.
#include
#include
#include
#include
#include
#include
#include

#include
#include
#include

namespace {

// Number of indices selected per row; all kernels below assume exactly this many.
constexpr int TopK = 2048;
constexpr int kThreadsPerBlock = 1024;

#ifdef USE_ROCM
// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a
// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`.
#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES
constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES);
#else
constexpr size_t kSmem = 48 * 1024;  // bytes
#endif
#else
// Reduced from 128KB to 32KB to improve occupancy.
// Each radix pass needs at most ~TopK candidates in the threshold bin,
// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient.
constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t);  // 32KB (bytes)
#endif

// Pointers and stride for one batched top-k launch (one block per row).
struct FastTopKParams {
  const float* __restrict__ input;         // [B, input_stride]
  const int32_t* __restrict__ row_starts;  // [B]
  int32_t* __restrict__ indices;           // [B, TopK]
  int32_t* __restrict__ lengths;           // [B]
  int64_t input_stride;
};

// when length <= TopK, we can directly write the indices
__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) {
  const auto tid = threadIdx.x;
  for (int i = tid; i < TopK; i += kThreadsPerBlock) {
    indice[i] = (i < length) ? i : -1;
  }
}

// keep the first `length` entries, set others to -1
__device__ void naive_topk_transform(
    const float* __restrict__ score,
    int32_t length,
    int32_t* __restrict__ dst_page_table,
    const int32_t* __restrict__ src_page_table) {
  const auto tid = threadIdx.x;
  for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
    dst_page_table[i] = (i < length) ? src_page_table[i] : -1;
  }
}

// keep the first `length` entries, set others to -1
__device__ void naive_topk_transform_ragged(
    const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) {
  const auto tid = threadIdx.x;
  for (auto i = tid; i < TopK; i += kThreadsPerBlock) {
    // NOTE(review): static_cast's template argument was stripped by
    // extraction; presumably static_cast<int32_t>(i).
    topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1;
  }
}

// Map a float to an order-preserving 8-bit key (top byte of the fp16
// monotone encoding): flip all bits for negatives, set the sign bit for
// non-negatives, so unsigned comparison matches float comparison.
__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t {
  __half h = __float2half_rn(x);
  uint16_t bits = __half_as_ushort(h);
  uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000);
  return static_cast(key >> 8);
}

// Same monotone encoding on the full 32-bit float pattern.
__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t {
  uint32_t bits = __float_as_uint(x);
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

// Radix-select of the TopK largest scores in input[row_start .. row_start+length).
// Stage 1 builds an 8-bit histogram over the fp16-derived key to find the
// "threshold bin"; stage 2 refines candidates inside that bin with up to four
// 8-bit passes over the 32-bit key. Winning element indices are written to
// `index[0..TopK)` (unordered).
__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) {
  // An optimized topk kernel copied from tilelang kernel
  // We assume length > TopK here, or it will crash
  int topk = TopK;
  constexpr auto BLOCK_SIZE = 1024;
  constexpr auto RADIX = 256;
  // Capacity of each of the two candidate buffers in dynamic shared memory.
  constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int));

  alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128];
  alignas(128) __shared__ int s_counter;
  alignas(128) __shared__ int s_threshold_bin_id;
  alignas(128) __shared__ int s_num_input[2];

  auto& s_histogram = s_histogram_buf[0];
  // allocate for two rounds
  extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE];

  const int tx = threadIdx.x;

  // stage 1: 8bit coarse histogram
  if (tx < RADIX + 1) s_histogram[tx] = 0;
  __syncthreads();

  for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
    const auto bin = convert_to_uint8(input[idx + row_start]);
    ::atomicAdd(&s_histogram[bin], 1);
  }
  __syncthreads();

  // Parallel suffix-sum over the histogram (ping-pong between the two
  // buffers, log2(RADIX)=8 doubling steps). Afterwards s_histogram[b] holds
  // the count of elements whose key is >= b.
  const auto run_cumsum = [&] {
#pragma unroll 8
    for (int i = 0; i < 8; ++i) {
      static_assert(1 << 8 == RADIX);
      if (C10_LIKELY(tx < RADIX)) {
        const auto j = 1 << i;
        const auto k = i & 1;
        auto value = s_histogram_buf[k][tx];
        if (tx < RADIX - j) {
          value += s_histogram_buf[k][tx + j];
        }
        s_histogram_buf[k ^ 1][tx] = value;
      }
      __syncthreads();
    }
  };

  run_cumsum();
  // The threshold bin is the unique bin where counts >= bin exceed topk but
  // counts of strictly greater bins do not.
  if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
    s_threshold_bin_id = tx;
    s_num_input[0] = 0;
    s_counter = 0;
  }
  __syncthreads();

  const auto threshold_bin = s_threshold_bin_id;
  // Everything in bins strictly above the threshold is selected outright;
  // `topk` becomes the number of slots still to fill from the threshold bin.
  topk -= s_histogram[threshold_bin + 1];

  if (topk == 0) {
    // Exact fit: emit all elements above the threshold bin and finish.
    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
      const auto bin = static_cast(convert_to_uint8(input[idx + row_start]));
      if (bin > threshold_bin) {
        const auto pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      }
    }
    __syncthreads();
    return;
  } else {
    __syncthreads();
    if (tx < RADIX + 1) {
      s_histogram[tx] = 0;
    }
    __syncthreads();

    // Emit the definite winners; stash threshold-bin candidates (capped at
    // SMEM_INPUT_SIZE to avoid the illegal access mentioned in the header)
    // and build the next-level histogram in the same pass.
    for (int idx = tx; idx < length; idx += BLOCK_SIZE) {
      const auto raw_input = input[idx + row_start];
      const auto bin = static_cast(convert_to_uint8(raw_input));
      if (bin > threshold_bin) {
        const auto pos = ::atomicAdd(&s_counter, 1);
        index[pos] = idx;
      } else if (bin == threshold_bin) {
        const auto pos = ::atomicAdd(&s_num_input[0], 1);
        /// NOTE: (dark) fuse the histogram computation here
        if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
          s_input_idx[0][pos] = idx;
          const auto bin = convert_to_uint32(raw_input);
          const auto sub_bin = (bin >> 24) & 0xFF;
          ::atomicAdd(&s_histogram[sub_bin], 1);
        }
      }
    }
    __syncthreads();
  }

  // stage 2: refine with 8bit radix passes
#pragma unroll 4
  for (int round = 0; round < 4; ++round) {
    __shared__ int s_last_remain;
    const auto r_idx = round % 2;

    // clip here to prevent overflow
    const auto _raw_num_input = s_num_input[r_idx];
    const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE);

    run_cumsum();
    if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) {
      s_threshold_bin_id = tx;
      s_num_input[r_idx ^ 1] = 0;
      // Slots left for ties in the final round.
      s_last_remain = topk - s_histogram[tx + 1];
    }
    __syncthreads();

    const auto threshold_bin = s_threshold_bin_id;
    topk -= s_histogram[threshold_bin + 1];

    if (topk == 0) {
      // Remaining winners are exactly the candidates above this sub-bin.
      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
        const auto idx = s_input_idx[r_idx][i];
        const auto offset = 24 - round * 8;
        const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF;
        if (bin > threshold_bin) {
          const auto pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        }
      }
      __syncthreads();
      break;
    } else {
      __syncthreads();
      if (tx < RADIX + 1) {
        s_histogram[tx] = 0;
      }
      __syncthreads();
      for (int i = tx; i < num_input; i += BLOCK_SIZE) {
        const auto idx = s_input_idx[r_idx][i];
        const auto raw_input = input[idx + row_start];
        const auto offset = 24 - round * 8;
        const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF;
        if (bin > threshold_bin) {
          const auto pos = ::atomicAdd(&s_counter, 1);
          index[pos] = idx;
        } else if (bin == threshold_bin) {
          if (round == 3) {
            // Last pass: break ties by handing out the remaining slots,
            // counting s_last_remain down; writers fill the tail of `index`.
            const auto pos = ::atomicAdd(&s_last_remain, -1);
            if (pos > 0) {
              index[TopK - pos] = idx;
            }
          } else {
            const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1);
            if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) {
              /// NOTE: (dark) fuse the histogram computation here
              s_input_idx[r_idx ^ 1][pos] = idx;
              const auto bin = convert_to_uint32(raw_input);
              const auto sub_bin = (bin >> (offset - 8)) & 0xFF;
              ::atomicAdd(&s_histogram[sub_bin], 1);
            }
          }
        }
      }
      __syncthreads();
    }
  }
}

// One block per row: pick the fast radix path only when the row is longer
// than TopK (the radix kernel assumes length > TopK).
__global__ __launch_bounds__(kThreadsPerBlock)  // topk
    void topk_kernel(const FastTopKParams params) {
  const auto& [input, row_starts, indices, lengths, input_stride] = params;
  // NOTE(review): static_cast template argument stripped by extraction;
  // presumably static_cast<int64_t>(blockIdx.x).
  const auto bid = static_cast(blockIdx.x);
  const auto row_start = row_starts == nullptr ?
0 : row_starts[bid];
  const auto length = lengths[bid];
  const auto indice = indices + bid * TopK;
  const auto score = input + bid * input_stride;
  if (length <= TopK) {
    return naive_topk_cuda(score, indice, length);
  } else {
    return fast_topk_cuda_tl(score, indice, row_start, length);
  }
}

// Decode path: one block per request; gathers the selected page-table entries
// from src into dst (dst rows are always TopK wide, padded with -1).
__global__ __launch_bounds__(kThreadsPerBlock)  // decode
    void topk_transform_decode_kernel(
        const FastTopKParams params,
        int32_t* __restrict__ dst_page_table,
        const int32_t* __restrict__ src_page_table,
        const int64_t src_stride) {
  const auto& [input, _1, _2, lengths, input_stride] = params;
  const auto bid = static_cast(blockIdx.x);
  const auto tid = threadIdx.x;
  // Decode rows always start at 0 (no ragged offset).
  const auto row_start = 0;
  const auto length = lengths[bid];
  const auto src_page_entry = src_page_table + bid * src_stride;
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;
  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    __shared__ int s_indices[TopK];
    fast_topk_cuda_tl(score, s_indices, row_start, length);
    // copy src[s_indices] to dst, we manually unroll here
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}

// Prefill path: blocks are launched per expanded token, so each block first
// resolves which prefill request it belongs to via cu_seqlens_q, then gathers
// that request's page-table row.
__global__ __launch_bounds__(kThreadsPerBlock)  // prefill
    void topk_transform_prefill_kernel(
        const FastTopKParams params,
        int32_t* __restrict__ dst_page_table,
        const int32_t* __restrict__ src_page_table,
        const int64_t src_stride,
        const int32_t* __restrict__ cu_seqlens_q,
        const int64_t prefill_bs) {
  const auto& [input, row_starts, _, lengths, input_stride] = params;
  const auto bid = static_cast(blockIdx.x);
  const auto tid = threadIdx.x;
  const auto length = lengths[bid];
  const auto row_start = row_starts == nullptr ? 0 : row_starts[bid];
  const auto dst_page_entry = dst_page_table + bid * TopK;
  const auto score = input + bid * input_stride;

  /// NOTE: prefill bs is usually small, we can just use a simple loop here
  /// We ensure that last cu_seqlens is equal to number of blocks launched
  // Exactly one thread finds the half-open interval containing bid and
  // publishes the request's src row through shared memory.
  __shared__ const int32_t* s_src_page_entry;
  if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) {
    if (tid < prefill_bs) {
      if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) {
        s_src_page_entry = src_page_table + tid * src_stride;
      }
    }
  } else {
    for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) {
      if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) {
        s_src_page_entry = src_page_table + i * src_stride;
      }
    }
  }
  __syncthreads();
  const auto src_page_entry = s_src_page_entry;

  if (length <= TopK) {
    return naive_topk_transform(score, length, dst_page_entry, src_page_entry);
  } else {
    __shared__ int s_indices[TopK];
    fast_topk_cuda_tl(score, s_indices, row_start, length);
    // copy src[s_indices] to dst, we manually unroll here
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_page_entry[idx_0] = src_page_entry[pos_0];
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_page_entry[idx_1] = src_page_entry[pos_1];
  }
}

// Ragged-KV prefill path: instead of gathering through a page table, emit the
// selected in-row indices shifted by a per-row offset into the ragged buffer.
__global__ __launch_bounds__(kThreadsPerBlock)  // prefill, ragged kv
    void topk_transform_prefill_ragged_kernel(
        const FastTopKParams params,
        int32_t* __restrict__ topk_indices_ragged,
        const int32_t* __restrict__ topk_indices_offset) {
  const auto& [input, row_starts, _, lengths, input_stride] = params;
  const auto bid = static_cast(blockIdx.x);
  const auto tid = threadIdx.x;
  const auto row_start = row_starts == nullptr ? 0 : row_starts[bid];
  const auto length = lengths[bid];
  const auto dst_indices_entry = topk_indices_ragged + bid * TopK;
  const auto score = input + bid * input_stride;
  const auto offset = topk_indices_offset[bid];

  if (length <= TopK) {
    return naive_topk_transform_ragged(score, length, dst_indices_entry, offset);
  } else {
    __shared__ int s_indices[TopK];
    fast_topk_cuda_tl(score, s_indices, row_start, length);
    // copy src[s_indices] to dst, we manually unroll here
    static_assert(TopK % kThreadsPerBlock == 0);
    static_assert(TopK / kThreadsPerBlock == 2);
    const auto idx_0 = tid;
    const auto pos_0 = s_indices[idx_0];
    dst_indices_entry[idx_0] = pos_0 + offset;
    const auto idx_1 = tid + kThreadsPerBlock;
    const auto pos_1 = s_indices[idx_1];
    dst_indices_entry[idx_1] = pos_1 + offset;
  }
}

// Validate the host tensors and pack raw pointers/strides into FastTopKParams.
// NOTE(review): the std::optional parameters lost their template arguments in
// extraction; presumably std::optional<at::Tensor>.
auto get_params(
    const at::Tensor& score,
    const at::Tensor& lengths,
    std::optional row_starts_opt = std::nullopt,
    std::optional indices_opt = std::nullopt) -> FastTopKParams {
  const auto B = score.size(0);
  // Rows may be strided, but elements within a row must be contiguous.
  TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1);
  if (row_starts_opt.has_value()) {
    const auto& row_starts = row_starts_opt.value();
    TORCH_CHECK(row_starts.dim() == 1);
    TORCH_CHECK(row_starts.size(0) == B);
  }
  TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous());
  TORCH_CHECK(lengths.size(0) == B);
  int32_t* indices_data_ptr = nullptr;
  if (indices_opt.has_value()) {
    const auto& indices = indices_opt.value();
    TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous());
    TORCH_CHECK(indices.size(0) == B);
    TORCH_CHECK(indices.size(1) == TopK);
    indices_data_ptr = indices.data_ptr();
  }

  return FastTopKParams{
      .input = score.data_ptr(),
      .row_starts = row_starts_opt.has_value() ?
row_starts_opt->data_ptr() : nullptr,
      .indices = indices_data_ptr,
      .lengths = lengths.data_ptr(),
      .input_stride = score.stride(0),
  };
}

// Raise the kernel's dynamic shared-memory cap to kSmem exactly once per
// process (the lambda result is cached in a function-local static).
// NOTE(review): the template parameter list was stripped by extraction;
// presumably carries the kernel function pointer `f` and `max_dynamic_smem`.
template
void setup_kernel_smem_once() {
  [[maybe_unused]]
  static const auto result = [] {
#ifdef USE_ROCM
    // hipify will turn hipFuncSetAttribute -> hipFuncSetAttribute. On ROCm,
    // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing
    // a function pointer directly, so cast explicitly.
    return ::hipFuncSetAttribute(
        reinterpret_cast(f), ::hipFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
#else
    // CUDA: keep original behavior (no cast needed).
    return ::hipFuncSetAttribute(f, ::hipFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem);
#endif
  }();
  TORCH_CHECK(result == hipSuccess, "set_up_kernel_once failed:", ::hipGetErrorString(result));
}

}  // namespace

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")

// Host entry point: plain batched top-k; writes selected indices into
// `indices` ([B, TopK], -1 padded for short rows).
void fast_topk_interface(
    const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) {
  CHECK_CUDA(score);
  CHECK_CUDA(indices);
  if (row_starts_opt.has_value()) {
    CHECK_CUDA(row_starts_opt.value());
  }
  CHECK_CUDA(lengths);
  const auto params = get_params(score, lengths, row_starts_opt, indices);
  const auto B = score.size(0);
  const auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
  const auto grid = dim3{static_cast(B)};
  const auto block = dim3{kThreadsPerBlock};
  // NOTE(review): the template argument naming the kernel was stripped here
  // and at the other setup_kernel_smem_once call sites below.
  setup_kernel_smem_once();
  hipLaunchKernelGGL(( topk_kernel), dim3(grid), dim3(block), kSmem, stream, params);
  const auto result = hipGetLastError();
  TORCH_CHECK(result == hipSuccess, "topk kernel failed:", ::hipGetErrorString(result));
}

// Host entry point: top-k fused with a page-table gather. Dispatches to the
// decode kernel (one request per row) or the prefill kernel (rows expanded per
// token, resolved via cu_seqlens_q).
void fast_topk_transform_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& dst_page_table,
    const at::Tensor& src_page_table,
    const at::Tensor& cu_seqlens_q,
    std::optional row_starts_opt) {
  CHECK_CUDA(score);
  CHECK_CUDA(lengths);
  CHECK_CUDA(dst_page_table);
  CHECK_CUDA(src_page_table);
  CHECK_CUDA(cu_seqlens_q);
  if (row_starts_opt.has_value()) {
    CHECK_CUDA(row_starts_opt.value());
  }
  const auto params = get_params(score, lengths, row_starts_opt);
  const auto B = score.size(0);
  TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous());
  TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1);
  TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous());
  const auto prefill_bs = cu_seqlens_q.size(0) - 1;
  TORCH_CHECK(dst_page_table.size(0) == B);
  TORCH_CHECK(dst_page_table.size(1) == TopK);
  TORCH_CHECK(src_page_table.size(0) == prefill_bs);
  TORCH_CHECK(prefill_bs <= B);  // prefill_bs should be smaller than expanded bs

  // launch kernel
  const auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
  const auto grid = dim3{static_cast(B)};
  const auto block = dim3{kThreadsPerBlock};
  const auto src_stride = src_page_table.stride(0);

  // dispatch to decode or prefill
  // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel
  // decode: row_starts_opt is null, invokes the decode kernel
  // target verify: row_starts_opt is null, invokes the prefill kernel
  const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B;
  if (is_decode) {
    setup_kernel_smem_once();
    hipLaunchKernelGGL(( topk_transform_decode_kernel), dim3(grid), dim3(block), kSmem, stream,
        params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride);
  } else {
    setup_kernel_smem_once();
    hipLaunchKernelGGL(( topk_transform_prefill_kernel), dim3(grid), dim3(block), kSmem, stream,
        params,
        dst_page_table.data_ptr(),
        src_page_table.data_ptr(),
        src_stride,
        cu_seqlens_q.data_ptr(),
        prefill_bs);
  }

  const auto result = hipGetLastError();
  TORCH_CHECK(result == hipSuccess, "topk kernel failed:", ::hipGetErrorString(result));
}

// Host entry point: top-k emitting ragged (offset-shifted) indices instead of
// gathering through a page table.
void fast_topk_transform_ragged_interface(
    const at::Tensor& score,
    const at::Tensor& lengths,
    at::Tensor& topk_indices_ragged,
    const at::Tensor& topk_indices_offset,
    std::optional row_starts_opt) {
  CHECK_CUDA(score);
  CHECK_CUDA(lengths);
  CHECK_CUDA(topk_indices_ragged);
  CHECK_CUDA(topk_indices_offset);
  if (row_starts_opt.has_value()) {
    CHECK_CUDA(row_starts_opt.value());
  }

  const auto params = get_params(score, lengths, row_starts_opt);
  const auto B = score.size(0);
  TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous());
  TORCH_CHECK(topk_indices_offset.dim() == 1);

  TORCH_CHECK(topk_indices_ragged.size(0) == B);
  TORCH_CHECK(topk_indices_ragged.size(1) == TopK);
  TORCH_CHECK(topk_indices_offset.size(0) == B);

  // launch kernel
  const auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
  const auto grid = dim3{static_cast(B)};
  const auto block = dim3{kThreadsPerBlock};

  setup_kernel_smem_once();
  hipLaunchKernelGGL(( topk_transform_prefill_ragged_kernel), dim3(grid), dim3(block), kSmem, stream,
      params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr());

  const auto result = hipGetLastError();
  TORCH_CHECK(result == hipSuccess, "topk kernel failed:", ::hipGetErrorString(result));
}
diff --git a/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_hip.hip b/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_hip.hip new file mode 100644 index 000000000000..477944a1c178 --- /dev/null +++ b/sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_hip.hip @@ -0,0 +1,272 @@
// !!! This is a file automatically generated by hipify!!!
// Adapted from
// https://github.com/mlc-ai/xgrammar/blob/v0.1.18/python/xgrammar/kernels/apply_token_bitmask_inplace_cuda.cu

/*
 * SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// clang-format off
// NOTE(review): header names stripped by extraction — restore from upstream.
#include
#include
#include
#include
#include


#if !defined(USE_ROCM) && (!defined(TORCH_HIP_VERSION) || TORCH_HIP_VERSION < 12040)
// Stub for toolkits too old to build the real kernel.
// NOTE(review): at::optional lost its template argument in extraction;
// presumably at::optional<at::Tensor>.
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional indices = at::nullopt) {
  TORCH_CHECK(false, "CUDA version must be >= 12.4 for ApplyTokenBitmaskInplace");
}
#else

#ifndef CUDART_INF_FP16
#ifndef USE_ROCM
#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)
#endif
#endif

#ifndef CUDART_INF_BF16
#ifndef USE_ROCM
#define CUDART_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U)
#endif
#endif

constexpr int32_t BITS_PER_BLOCK = 32;
constexpr int32_t THREADS_PER_THREAD_BLOCK = 256;

// -inf constant for the element type T; specialized below for half/bf16
// because unary negation of the CUDART constants differs per platform.
template
__device__ T NegativeInfinity() {
  return -INFINITY;
}

template <>
__device__ __half NegativeInfinity<__half>() {
#ifdef USE_ROCM
  return __float2half(-INFINITY);
#else
  return -CUDART_INF_FP16;
#endif
}

template <>
__device__ __hip_bfloat16 NegativeInfinity<__hip_bfloat16>() {
#ifdef USE_ROCM
  return __hip_bfloat16(-INFINITY);
#else
  return -CUDART_INF_BF16;
#endif
}

// A PackedT value whose every T-lane is -inf, for vectorized masked stores.
template
__device__ PackedT PackedNegativeInfinity() {
  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
  T packed[kAlignment];
#pragma unroll
  for (int i = 0; i < kAlignment; i++) {
    packed[i] = NegativeInfinity();
  }
  return *reinterpret_cast(packed);
}

// Sets logits[row, v] = -inf for every vocab id v whose bitmask bit is 0.
// Grid: x covers the vocab in chunks of THREADS_PER_THREAD_BLOCK*kBitsPerThread,
// y is the (possibly indices-remapped) batch row. Loads/stores are PackedT-wide.
// NOTE(review): this kernel's template parameter list and the
// reinterpret_cast/Packed* template arguments were stripped by extraction.
template
__global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel(
    T* __restrict__ logits,
    const int32_t* __restrict__ bitmask,
    const int32_t* __restrict__ indices,
    int32_t vocab_size,
    int32_t logits_stride,
    int32_t bitmask_stride) {
  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
  constexpr uint32_t kPackedMask = (1 << kAlignment) - 1;

  const int batch_idx = (indices == nullptr) ? blockIdx.y : indices[blockIdx.y];

  const int block_offset = blockIdx.x * THREADS_PER_THREAD_BLOCK * kBitsPerThread;
  T* logits_gmem_ptr = logits + batch_idx * logits_stride + block_offset;
  const int32_t* bitmask_gmem_ptr = bitmask + batch_idx * bitmask_stride + block_offset / BITS_PER_BLOCK;
  const int bitmask_inner_idx = threadIdx.x % (BITS_PER_BLOCK / kAlignment);
  T logits_reg[kAlignment];

#pragma unroll
  for (int offset = threadIdx.x * kAlignment; offset < THREADS_PER_THREAD_BLOCK * kBitsPerThread;
       offset += THREADS_PER_THREAD_BLOCK * kAlignment) {
    if (block_offset + offset >= vocab_size) {
      break;
    }

    // Invert the mask: a 1 bit here means "masked out" (set to -inf).
    const uint32_t bitmask_val =
        (~bitmask_gmem_ptr[offset / BITS_PER_BLOCK] >> (bitmask_inner_idx * kAlignment)) & kPackedMask;

    if (bitmask_val == 0) {
      // Fully allowed group: nothing to write.
      continue;
    }

    if (bitmask_val == kPackedMask) {
      // Fully masked group: one vectorized -inf store.
      *reinterpret_cast(logits_gmem_ptr + offset) = PackedNegativeInfinity();
      continue;
    }

    // Mixed group: read-modify-write through registers.
    *reinterpret_cast(logits_reg) = *reinterpret_cast(logits_gmem_ptr + offset);
#pragma unroll
    for (int i = 0; i < kAlignment; i++) {
      if (((bitmask_val >> i) & 1)) {
        logits_reg[i] = NegativeInfinity();
      }
    }
    *reinterpret_cast(logits_gmem_ptr + offset) = *reinterpret_cast(logits_reg);
  }
}

// Integer ceiling division.
// NOTE(review): the SFINAE constraint inside template<...> was stripped.
template ::value>>
constexpr auto CeilDiv(T numerator, T denominator) {
  return (numerator + denominator - 1) / denominator;
}

// Choose kBitsPerThread (4/8/16/32) so the launch fits both the occupancy
// budget and the vector alignment, then launch the kernel.
template
void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(
    T* __restrict__ logits,
    const int32_t* __restrict__ bitmask,
    const int32_t* __restrict__ indices,
    int32_t vocab_size,
    int32_t logits_stride,
    int32_t bitmask_stride,
    int32_t num_rows) {
  constexpr int kAlignment = sizeof(PackedT) / sizeof(T);
  const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows);
  const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row);

  const dim3 block(THREADS_PER_THREAD_BLOCK);
  hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();

  if (num_bits_per_thread <= 4 && kAlignment <= 4) {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows);
    hipLaunchKernelGGL(( LogitsBitmaskKernel)
        , dim3(grid), dim3(block), 0, stream, logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  } else if (num_bits_per_thread <= 8 && kAlignment <= 8) {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows);
    hipLaunchKernelGGL(( LogitsBitmaskKernel)
        , dim3(grid), dim3(block), 0, stream, logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  } else if (num_bits_per_thread <= 16 && kAlignment <= 16) {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows);
    hipLaunchKernelGGL(( LogitsBitmaskKernel)
        , dim3(grid), dim3(block), 0, stream, logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  } else {
    const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows);
    hipLaunchKernelGGL(( LogitsBitmaskKernel)
        , dim3(grid), dim3(block), 0, stream, logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
  }
}

// Pick the widest vector type the row stride allows (float4 when the stride
// is a multiple of the vector width, otherwise scalar T).
template
void ApplyTokenBitmaskInplaceDispatchToPackedT(
    T* __restrict__ logits,
    const int32_t* __restrict__ bitmask,
    const int32_t* __restrict__ indices,
    int32_t vocab_size,
    int32_t logits_stride,
    int32_t bitmask_stride,
    int32_t num_rows) {
  if (logits_stride % (sizeof(float4) / sizeof(T)) == 0) {
    ApplyTokenBitmaskInplaceDispatchToBitsPerThread(
        logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
  } else {
    ApplyTokenBitmaskInplaceDispatchToBitsPerThread(
        logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride, num_rows);
  }
}

// Public entry point: validates tensors, normalizes 1-D inputs to a
// (rows, vocab) view, resolves optional row remapping via `indices`, and
// dispatches on the logits dtype.
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional indices = at::nullopt) {
  TORCH_CHECK(logits.is_cuda(), "logits must be a CUDA tensor.");
  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous.");
  TORCH_CHECK(logits.dim() == 1 || logits.dim() == 2, "logits must be a 1D or 2D tensor.");
  std::pair logits_shape =
      logits.dim() == 2 ? std::make_pair(static_cast(logits.size(0)), static_cast(logits.size(1)))
                        : std::make_pair(1, static_cast(logits.size(0)));

  TORCH_CHECK(bitmask.is_cuda(), "bitmask must be a CUDA tensor.");
  TORCH_CHECK(bitmask.is_contiguous(), "bitmask must be contiguous.");
  TORCH_CHECK(bitmask.dim() == 1 || bitmask.dim() == 2, "bitmask must be a 1D or 2D tensor.");
  std::pair bitmask_shape =
      bitmask.dim() == 2 ? std::make_pair(static_cast(bitmask.size(0)), static_cast(bitmask.size(1)))
                         : std::make_pair(1, static_cast(bitmask.size(0)));

  TORCH_CHECK(bitmask.dtype() == torch::kInt32, "bitmask must be of type int32.");

  TORCH_CHECK(
      (logits_shape.second + BITS_PER_BLOCK - 1) / BITS_PER_BLOCK >= bitmask_shape.second,
      "The provided logits's vocab size should be no less than the bitmask's vocab size "
      "(converted from bitmask size). But got vocab size ",
      logits_shape.second,
      " vs bitmask size ",
      bitmask_shape.second);

  // Clamp the processed vocab to what the bitmask actually covers.
  int vocab_size = ::min(logits_shape.second, bitmask_shape.second * BITS_PER_BLOCK);

  int32_t num_rows = logits_shape.first;
  int32_t* indices_ptr = nullptr;
  if (indices) {
    TORCH_CHECK(indices->is_cuda(), "indices must be a CUDA tensor.");
    TORCH_CHECK(indices->is_contiguous(), "indices must be contiguous.");
    TORCH_CHECK(indices->dim() == 1, "indices must be a 1D tensor.");
    TORCH_CHECK(indices->dtype() == torch::kInt32, "indices must be of type int32.");
    num_rows = indices->size(0);
    indices_ptr = indices->data_ptr();
  } else {
    TORCH_CHECK(logits_shape.first == bitmask_shape.first, "logits and bitmask must have the same batch size.");
  }

  switch (logits.scalar_type()) {
    case torch::kFloat32: {
      ApplyTokenBitmaskInplaceDispatchToPackedT(
          logits.data_ptr(),
          bitmask.data_ptr(),
          indices_ptr,
          vocab_size,
          logits_shape.second,
          bitmask_shape.second,
          num_rows);
      break;
    }
    case torch::kFloat16: {
      ApplyTokenBitmaskInplaceDispatchToPackedT(
          reinterpret_cast<__half*>(logits.data_ptr()),
          bitmask.data_ptr(),
          indices_ptr,
          vocab_size,
          logits_shape.second,
          bitmask_shape.second,
          num_rows);
      break;
    }
    case torch::kBFloat16: {
      ApplyTokenBitmaskInplaceDispatchToPackedT(
          reinterpret_cast<__hip_bfloat16*>(logits.data_ptr()),
          bitmask.data_ptr(),
          indices_ptr,
          vocab_size,
          logits_shape.second,
          bitmask_shape.second,
          num_rows);
      break;
    }
    default:
      TORCH_CHECK(false, "logits dtype must be float, half or bfloat16.");
      break;
  }
}
#endif
// clang-format on
diff --git a/sgl-kernel/csrc/kvcacheio/transfer.hip b/sgl-kernel/csrc/kvcacheio/transfer.hip new file mode 100644 index 000000000000..5546a1370804 --- /dev/null +++ b/sgl-kernel/csrc/kvcacheio/transfer.hip @@ -0,0 +1,809 @@
// !!! This is a file automatically generated by hipify!!!
+#include "hip/hip_runtime.h" +#include +#include +#include + +#include + +#ifndef USE_ROCM +#define WARP_SIZE 32 +#include "pytorch_extension_utils.h" +#else +#include "pytorch_extension_utils_rocm.h" +#include "utils_hip.h" // WARP_SIZE +#endif + +__device__ __forceinline__ void +transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) { + const uint64_t* __restrict__ src = static_cast(src_addr); + uint64_t* __restrict__ dst = static_cast(dst_addr); + const int total_chunks = item_size_bytes / sizeof(uint64_t); + +#pragma unroll + for (int j = lane_id; j < total_chunks; j += WARP_SIZE) { +#ifndef USE_ROCM + uint64_t tmp; + asm volatile("ld.global.nc.b64 %0,[%1];" : "=l"(tmp) : "l"(src + j) : "memory"); + asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(dst + j), "l"(tmp) : "memory"); + +#else + uint64_t tmp = __builtin_nontemporal_load(src + j); + __builtin_nontemporal_store(tmp, dst + j); +#endif + } +} + +template +__device__ __forceinline__ T* get_global_offset_lf( + T* base, + const uintptr_t* __restrict__ /*unused*/, + int64_t layer_id, + int64_t layer_dim, + int64_t page_id, + int64_t item_size_bytes) { + // layer first + return base + layer_id * layer_dim + page_id * item_size_bytes; +} + +template +__device__ __forceinline__ T* get_global_offset_pf( + T* base, + const uintptr_t* __restrict__ /*unused*/, + int64_t layer_id, + int64_t page_dim, + int64_t page_id, + int64_t item_size_bytes) { + // page first + return base + page_id * page_dim + layer_id * item_size_bytes; +} + +// get offset from layer base table when layers are not contiguous +template +__device__ __forceinline__ T* get_global_offset_lf_tbl( + T* /*unused*/, + const uintptr_t* __restrict__ layer_base_tbl, + int64_t layer_id, + int64_t /*unused*/, + int64_t page_id, + int64_t item_size_bytes) { + return reinterpret_cast(layer_base_tbl[layer_id]) + page_id * item_size_bytes; +} + +template +__device__ __forceinline__ T* get_global_offset_per_head_lf( + 
T* base, + const uintptr_t* __restrict__ /*unused*/, + int64_t layer_id, + int64_t layer_dim, + int64_t page_id, + int64_t item_size_bytes, + int64_t head_id, + int64_t head_num, + int64_t /*unused*/) { + // layer first offset func per head + return base + layer_id * layer_dim + page_id * item_size_bytes + item_size_bytes / head_num * head_id; +} + +template +__device__ __forceinline__ T* get_global_offset_per_head_lf_tbl( + T* /*unused*/, + const uintptr_t* __restrict__ layer_base_tbl, + int64_t layer_id, + int64_t /*unused*/, + int64_t page_id, + int64_t item_size_bytes, + int64_t head_id, + int64_t head_num, + int64_t /*unused*/) { + return reinterpret_cast(layer_base_tbl[layer_id]) + page_id * item_size_bytes + + item_size_bytes / head_num * head_id; +} + +template +__device__ __forceinline__ T* get_global_offset_ph( + T* base, + const uintptr_t* __restrict__ /*unused*/, + int64_t layer_id, + int64_t page_dim, + int64_t page_id, + int64_t item_size_bytes, + int64_t head_id, + int64_t head_num, + int64_t page_size) { + // page head layout: [page_num, head_num, page_size, layer_num, head_dim] + return base + page_id / page_size * page_size * page_dim + // page_num dimension offset + page_dim / head_num * head_id * page_size + // head_num dimension offset + page_id % page_size * page_dim / head_num + // page_size dimension offset + layer_id * item_size_bytes / head_num; // layer_num dimension offset +} + +template +__global__ void transfer_page_head_kernel_impl( + const void* __restrict__ src_k, + void* __restrict__ dst_k, + const void* __restrict__ src_v, + void* __restrict__ dst_v, + const int64_t* __restrict__ src_indices, + const int64_t* __restrict__ dst_indices, + int64_t start_layer_id, + int64_t num_layers_to_process, + int64_t num_items, + int64_t items_per_warp, + int64_t item_size_bytes, + int64_t src_layout_dim, + int64_t dst_layout_dim, + const uintptr_t* __restrict__ src_k_layer_tbl, + const uintptr_t* __restrict__ dst_k_layer_tbl, + const uintptr_t* 
__restrict__ src_v_layer_tbl, + const uintptr_t* __restrict__ dst_v_layer_tbl, + const int64_t page_size, + const int64_t head_num) { + int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + int32_t lane_id = tid % WARP_SIZE; + int32_t warp_id = tid / WARP_SIZE; + const int64_t head_size_bytes = item_size_bytes / head_num; + + for (int i = 0; i < items_per_warp; ++i) { + int64_t item_id = warp_id * items_per_warp + i; + if (item_id >= num_items) { + break; + } + const int64_t src_page_id = src_indices[item_id]; + const int64_t dst_page_id = dst_indices[item_id]; + + // Loop over layers if necessary + for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) { + // For page head layout, the cache of each head in the token is discontinuous, need to loop + for (int64_t head_id = 0; head_id < head_num; ++head_id) { + const char* src_k_ptr = SrcOffsetFn( + static_cast(src_k), + src_k_layer_tbl, + layer_id, + src_layout_dim, + src_page_id, + item_size_bytes, + head_id, + head_num, + page_size); + char* dst_k_ptr = DstOffsetFn( + static_cast(dst_k), + dst_k_layer_tbl, + layer_id, + dst_layout_dim, + dst_page_id, + item_size_bytes, + head_id, + head_num, + page_size); + transfer_item_warp(lane_id, src_k_ptr, dst_k_ptr, head_size_bytes); + + const char* src_v_ptr = SrcOffsetFn( + static_cast(src_v), + src_v_layer_tbl, + layer_id, + src_layout_dim, + src_page_id, + item_size_bytes, + head_id, + head_num, + page_size); + char* dst_v_ptr = DstOffsetFn( + static_cast(dst_v), + dst_v_layer_tbl, + layer_id, + dst_layout_dim, + dst_page_id, + item_size_bytes, + head_id, + head_num, + page_size); + transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, head_size_bytes); + } + } + } +} + +template +__global__ void transfer_kernel_impl( + const void* __restrict__ src_k, + void* __restrict__ dst_k, + const void* __restrict__ src_v, + void* __restrict__ dst_v, + const int64_t* __restrict__ src_indices, + const int64_t* __restrict__ dst_indices, + 
int64_t start_layer_id, + int64_t num_layers_to_process, + int64_t num_items, + int64_t items_per_warp, + int64_t item_size_bytes, + int64_t src_layout_dim, + int64_t dst_layout_dim, + const uintptr_t* __restrict__ src_k_layer_tbl, + const uintptr_t* __restrict__ dst_k_layer_tbl, + const uintptr_t* __restrict__ src_v_layer_tbl, + const uintptr_t* __restrict__ dst_v_layer_tbl) { + int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; + int32_t lane_id = tid % WARP_SIZE; + int32_t warp_id = tid / WARP_SIZE; + + for (int i = 0; i < items_per_warp; ++i) { + int64_t item_id = warp_id * items_per_warp + i; + if (item_id >= num_items) { + break; + } + const int64_t src_page_id = src_indices[item_id]; + const int64_t dst_page_id = dst_indices[item_id]; + + // Loop over layers if necessary + for (int64_t layer_id = start_layer_id; layer_id < start_layer_id + num_layers_to_process; ++layer_id) { + const char* src_ptr = SrcOffsetFn( + static_cast(src_k), src_k_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes); + char* dst_ptr = DstOffsetFn( + static_cast(dst_k), dst_k_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes); + transfer_item_warp(lane_id, src_ptr, dst_ptr, item_size_bytes); + + if constexpr (!IsMLA) { + const char* src_v_ptr = SrcOffsetFn( + static_cast(src_v), src_v_layer_tbl, layer_id, src_layout_dim, src_page_id, item_size_bytes); + char* dst_v_ptr = DstOffsetFn( + static_cast(dst_v), dst_v_layer_tbl, layer_id, dst_layout_dim, dst_page_id, item_size_bytes); + transfer_item_warp(lane_id, src_v_ptr, dst_v_ptr, item_size_bytes); + } + } + } +} + +template +void transfer_kv_launcher( + const at::Tensor& src_k, + at::Tensor& dst_k, + const at::Tensor& src_v, + at::Tensor& dst_v, + const at::Tensor& src_indices, + const at::Tensor& dst_indices, + int64_t start_layer_id, + int64_t num_layers_to_process, + int64_t item_size, + int64_t src_layout_dim, + int64_t dst_layout_dim, + const at::Tensor& src_k_layers, + const at::Tensor& 
dst_k_layers, + const at::Tensor& src_v_layers, + const at::Tensor& dst_v_layers, + int64_t block_quota, + int64_t num_warps_per_block, + const int64_t page_size = 16, + const int64_t head_num = 1) { + TORCH_CHECK(src_indices.is_cuda(), "Source indices must be a CUDA tensor"); + TORCH_CHECK(dst_indices.is_cuda(), "Destination indices must be a CUDA tensor"); + TORCH_CHECK(src_indices.scalar_type() == at::kLong, "Source indices must be of type long"); + TORCH_CHECK(dst_indices.scalar_type() == at::kLong, "Destination indices must be of type long"); + TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length"); + TORCH_CHECK(item_size % 8 == 0, "Item byte size must be divisible by 8"); + + auto div_up = [](int64_t x, int64_t y) { return (x + y - 1) / y; }; + const int64_t num_items = src_indices.numel(); + const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block); + const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block); + dim3 grid_dim(num_blocks, 1, 1); + const int32_t threads_per_block = num_warps_per_block * WARP_SIZE; + + const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr; + void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr; + const void* src_v_ptr = IsMLA || !src_v.defined() ? nullptr : src_v.data_ptr(); + void* dst_v_ptr = IsMLA || !dst_v.defined() ? nullptr : dst_v.data_ptr(); + const uintptr_t* src_k_tbl_ptr = src_k_layers.defined() ? src_k_layers.data_ptr() : nullptr; + const uintptr_t* dst_k_tbl_ptr = dst_k_layers.defined() ? dst_k_layers.data_ptr() : nullptr; + const uintptr_t* src_v_tbl_ptr = IsMLA || !src_v_layers.defined() ? nullptr : src_v_layers.data_ptr(); + const uintptr_t* dst_v_tbl_ptr = IsMLA || !dst_v_layers.defined() ? 
nullptr : dst_v_layers.data_ptr(); + + hipStream_t torch_current_stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + if constexpr (PageHeadLayout) { + hipLaunchKernelGGL(( transfer_page_head_kernel_impl), dim3(grid_dim), dim3(threads_per_block), 0, torch_current_stream, + src_k_ptr, + dst_k_ptr, + src_v_ptr, + dst_v_ptr, + src_indices.data_ptr(), + dst_indices.data_ptr(), + start_layer_id, + num_layers_to_process, + num_items, + items_per_warp, + item_size, + src_layout_dim, + dst_layout_dim, + src_k_tbl_ptr, + dst_k_tbl_ptr, + src_v_tbl_ptr, + dst_v_tbl_ptr, + page_size, + head_num); + } else { + hipLaunchKernelGGL(( transfer_kernel_impl), dim3(grid_dim), dim3(threads_per_block), 0, torch_current_stream, + src_k_ptr, + dst_k_ptr, + src_v_ptr, + dst_v_ptr, + src_indices.data_ptr(), + dst_indices.data_ptr(), + start_layer_id, + num_layers_to_process, + num_items, + items_per_warp, + item_size, + src_layout_dim, + dst_layout_dim, + src_k_tbl_ptr, + dst_k_tbl_ptr, + src_v_tbl_ptr, + dst_v_tbl_ptr); + } + C10_HIP_KERNEL_LAUNCH_CHECK(); +} + +void transfer_kv_per_layer( + const at::Tensor src_k, + at::Tensor dst_k, + const at::Tensor src_v, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t block_quota, + int64_t num_warps_per_block) { + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, false>( + src_k, + dst_k, + src_v, + dst_v, + src_indices, + dst_indices, + 0, + 1, + item_size, + 0, + 0, + empty, + empty, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_per_layer_pf_lf( + const at::Tensor src_k, + at::Tensor dst_k, + const at::Tensor src_v, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t layer_id, + int64_t item_size, + int64_t src_layout_dim, + int64_t block_quota, + int64_t num_warps_per_block) { + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, false>( + src_k, + dst_k, + src_v, + 
dst_v, + src_indices, + dst_indices, + layer_id, + 1, + item_size, + src_layout_dim, + 0, + empty, + empty, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_per_layer_ph_lf( + const at::Tensor src_k, + at::Tensor dst_k, + const at::Tensor src_v, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t layer_id, + int64_t item_size, + int64_t src_layout_dim, + int64_t page_size, + int64_t head_num, + int64_t block_quota, + int64_t num_warps_per_block) { + at::Tensor empty; + transfer_kv_launcher, get_global_offset_per_head_lf, false, true>( + src_k, + dst_k, + src_v, + dst_v, + src_indices, + dst_indices, + layer_id, + 1, + item_size, + src_layout_dim, + 0, + empty, + empty, + empty, + empty, + block_quota, + num_warps_per_block, + page_size, + head_num); +} + +void transfer_kv_all_layer( + const at::Tensor src_k_layers, + const at::Tensor dst_k_layers, + const at::Tensor src_v_layers, + const at::Tensor dst_v_layers, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf_tbl, false>( + empty, + empty, + empty, + empty, + src_indices, + dst_indices, + 0, + num_layers, + item_size, + 0, + 0, + src_k_layers, + dst_k_layers, + src_v_layers, + dst_v_layers, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer_lf_pf( + const at::Tensor src_k_layers, + at::Tensor dst_k, + const at::Tensor src_v_layers, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t dst_layout_dim, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not 
match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_pf, false>( + empty, + dst_k, + empty, + dst_v, + src_indices, + dst_indices, + 0, + num_layers, + item_size, + 0, + dst_layout_dim, + src_k_layers, + empty, + src_v_layers, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer_lf_ph( + const at::Tensor src_k_layers, + at::Tensor dst_k, + const at::Tensor src_v_layers, + at::Tensor dst_v, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t dst_layout_dim, + int64_t num_layers, + int64_t page_size, + int64_t head_num, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_k_layers.size(0), "Number of layers in source k tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_ph, false, true>( + empty, + dst_k, + empty, + dst_v, + src_indices, + dst_indices, + 0, + num_layers, + item_size, + 0, + dst_layout_dim, + src_k_layers, + empty, + src_v_layers, + empty, + block_quota, + num_warps_per_block, + page_size, + head_num); +} + +void transfer_kv_per_layer_mla( + const at::Tensor src, + at::Tensor dst, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t block_quota, + int64_t num_warps_per_block) { + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, true>( + src, + dst, + empty, + empty, + src_indices, + dst_indices, + 0, + 1, + item_size, + 0, + 0, + empty, + empty, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_per_layer_mla_pf_lf( + const at::Tensor src, + at::Tensor dst, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t layer_id, + int64_t item_size, + int64_t src_layout_dim, + int64_t block_quota, + int64_t num_warps_per_block) { + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf, true>( + src, + dst, + empty, + empty, + src_indices, + dst_indices, + layer_id, + 1, 
+ item_size, + src_layout_dim, + 0, + empty, + empty, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer_mla( + const at::Tensor src_layers, + const at::Tensor dst_layers, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_lf_tbl, true>( + empty, + empty, + empty, + empty, + src_indices, + dst_indices, + 0, + num_layers, + item_size, + 0, + 0, + src_layers, + dst_layers, + empty, + empty, + block_quota, + num_warps_per_block); +} + +void transfer_kv_all_layer_mla_lf_pf( + const at::Tensor src_layers, + at::Tensor dst, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t item_size, + int64_t dst_layout_dim, + int64_t num_layers, + int64_t block_quota, + int64_t num_warps_per_block) { + TORCH_CHECK(num_layers == src_layers.size(0), "Number of layers in source tensor does not match num_layers"); + at::Tensor empty; + transfer_kv_launcher, get_global_offset_pf, true>( + empty, + dst, + empty, + empty, + src_indices, + dst_indices, + 0, + num_layers, + item_size, + 0, + dst_layout_dim, + src_layers, + empty, + empty, + empty, + block_quota, + num_warps_per_block); +} + +inline void transfer_page_direct( + const at::Tensor src_buffer, + at::Tensor dst_buffer, + int64_t src_page_index, + int64_t dst_page_index, + int64_t page_size) { + dst_buffer.slice(0, dst_page_index, dst_page_index + page_size) + .copy_( + src_buffer.slice(0, src_page_index, src_page_index + page_size), + /* non_blocking= */ true); +} + +void transfer_kv_direct( + const std::vector& src_layers, + std::vector dst_layers, + const at::Tensor src_indices, + const at::Tensor dst_indices, + int64_t page_size) { + TORCH_CHECK( + src_layers.size() == dst_layers.size(), 
"Source and destination layers must have the same number of layers"); + TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length"); + TORCH_CHECK(page_size > 0, "Page size must be positive"); + TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size"); + + auto src_indices_cpu = src_indices.cpu(); + auto dst_indices_cpu = dst_indices.cpu(); + + const auto num_indices = src_indices_cpu.numel(); + const int64_t num_layers = src_layers.size(); + int64_t* src_indices_ptr = src_indices_cpu.data_ptr(); + int64_t* dst_indices_ptr = dst_indices_cpu.data_ptr(); + + int64_t start_index = 0; + int64_t end_index = 0; + + for (int64_t i = 0; i < num_indices; ++i) { + if (i < num_indices - 1) { + auto src_diff = src_indices_ptr[i + 1] - src_indices_ptr[i]; + auto dst_diff = dst_indices_ptr[i + 1] - dst_indices_ptr[i]; + + if (src_diff == 1 && dst_diff == 1) { + continue; + } + end_index = i + 1; + } else { // last batch + end_index = num_indices; + } + auto src_index = src_indices_ptr[start_index]; + auto dst_index = dst_indices_ptr[start_index]; + auto num_tokens = end_index - start_index; + + for (int64_t j = 0; j < num_layers; ++j) { + transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, num_tokens); + } + start_index = end_index; + } +} + +template +inline void transfer_kv_page_first_direct_impl( + const std::vector& src_ptrs, + std::vector dst_ptrs, + const at::Tensor& src_indices, + const at::Tensor& dst_indices, + int64_t start_layer_id, + int64_t page_size) { + TORCH_CHECK(src_indices.numel() == dst_indices.numel(), "Source and destination indices must have the same length"); + TORCH_CHECK(page_size > 0, "Page size must be positive"); + TORCH_CHECK(src_indices.numel() % page_size == 0, "Source indices size must be divisible by page size"); + + auto src_indices_cpu = src_indices.cpu(); + auto dst_indices_cpu = dst_indices.cpu(); + const int64_t 
num_pages = src_indices_cpu.size(0) / page_size; + + if constexpr (IsLf2Pf) { + const bool is_mla = dst_ptrs.size() == 1; + const int64_t num_layers = is_mla ? src_ptrs.size() : src_ptrs.size() / 2; + + for (const auto i : c10::irange(num_pages)) { + auto s_index = src_indices_cpu[i * page_size].item(); + auto d_index = dst_indices_cpu[i * page_size].item() / page_size; + for (int64_t j = 0; j < num_layers; ++j) { + transfer_page_direct( + src_ptrs[j], dst_ptrs[0].select(0, d_index).select(0, start_layer_id + j), s_index, 0, page_size); + if (!is_mla) { + transfer_page_direct( + src_ptrs[j + num_layers], + dst_ptrs[1].select(0, d_index).select(0, start_layer_id + j), + s_index, + 0, + page_size); + } + } + } + } else { + const bool is_mla = src_ptrs.size() == 1; + const int64_t num_layers = is_mla ? dst_ptrs.size() : dst_ptrs.size() / 2; + + for (const auto i : c10::irange(num_pages)) { + auto s_index = src_indices_cpu[i * page_size].item() / page_size; + auto d_index = dst_indices_cpu[i * page_size].item(); + for (int64_t j = 0; j < num_layers; ++j) { + transfer_page_direct( + src_ptrs[0].select(0, s_index).select(0, start_layer_id + j), dst_ptrs[j], 0, d_index, page_size); + if (!is_mla) { + transfer_page_direct( + src_ptrs[1].select(0, s_index).select(0, start_layer_id + j), + dst_ptrs[j + num_layers], + 0, + d_index, + page_size); + } + } + } + } +} + +void transfer_kv_per_layer_direct_pf_lf( + const std::vector& src_ptrs, + std::vector dst_ptrs, + const at::Tensor& src_indices, + const at::Tensor& dst_indices, + int64_t layer_id, + int64_t page_size) { + transfer_kv_page_first_direct_impl(src_ptrs, dst_ptrs, src_indices, dst_indices, layer_id, page_size); +} + +void transfer_kv_all_layer_direct_lf_pf( + const std::vector& src_ptrs, + std::vector dst_ptrs, + const at::Tensor& src_indices, + const at::Tensor& dst_indices, + int64_t page_size) { + transfer_kv_page_first_direct_impl(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size); +} diff --git 
a/sgl-kernel/csrc/moe/moe_align_kernel.hip b/sgl-kernel/csrc/moe/moe_align_kernel.hip new file mode 100644 index 000000000000..1c60f7f714d4 --- /dev/null +++ b/sgl-kernel/csrc/moe/moe_align_kernel.hip @@ -0,0 +1,385 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include + +#include "utils_hip.h" + +#define VEC_SIZE 4 +using Vec = int4; + +template +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ cumsum_buffer, + size_t numel) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = blockDim.x * gridDim.x; + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1); + sorted_token_ids[rank_post_pad] = i; + } +} + +#ifdef __CUDA_ARCH__ +__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) { + int original = v; +#pragma unroll + for (int offset = 1; offset < WARP_SIZE; offset <<= 1) { + int n = __shfl_up_sync(mask, v, offset); + if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n; + } + return v - original; +} +#endif + +template +__global__ void 
moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + int32_t* __restrict__ cumsum, + bool pad_sorted_token_ids, + const int32_t scan_size, + int32_t max_num_tokens_padded) { + // Use a separate thread block to populate sorted_token_ids + if (blockIdx.x == 1) { + if (pad_sorted_token_ids) { + Vec fill_vec; + fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel; + int32_t total_vecs = (max_num_tokens_padded + VEC_SIZE - 1) / VEC_SIZE; + Vec* out_ptr = reinterpret_cast(sorted_token_ids); + for (int32_t i = threadIdx.x; i < total_vecs; i += blockDim.x) { + out_ptr[i] = fill_vec; + } + } + return; + } + + extern __shared__ int32_t smem[]; + int32_t* shared_counts = smem; // [num_experts] + int32_t* prefix = shared_counts + num_experts; // [num_experts + 1] + int32_t* scan_buf = prefix + num_experts + 1; // [scan_size] + __shared__ int32_t s_total_tokens_post_pad; + + const size_t tid = threadIdx.x; + const size_t stride = blockDim.x; + + if (tid < num_experts) { + shared_counts[tid] = 0; + } + + __syncthreads(); + + for (size_t i = tid; i < numel; i += stride) { + int expert_id = topk_ids[i] + 1; + atomicAdd(&shared_counts[expert_id], 1); + } + + __syncthreads(); + + int32_t padded_count = 0; + if (tid < num_experts) { + int32_t count = shared_counts[tid]; + padded_count = (count + block_size - 1) / block_size * block_size; + scan_buf[tid] = padded_count; + } + +#ifndef __CUDA_ARCH__ // HIP + + if (tid >= num_experts && tid < scan_size) { + scan_buf[tid] = 0; + } + + __syncthreads(); + + // Blelloch scan + int offset = 1; +#pragma unroll + for (int d = scan_size >> 1; d > 0; d >>= 1) { + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + scan_buf[bi] += scan_buf[ai]; + } + offset <<= 1; + __syncthreads(); + } + + // 
down-sweep + if (tid == 0) { + prefix[num_experts] = scan_buf[scan_size - 1]; + scan_buf[scan_size - 1] = 0; + } + __syncthreads(); + +#pragma unroll + for (int d = 1; d < scan_size; d <<= 1) { + offset >>= 1; + if (tid < d) { + int ai = offset * (2 * tid + 1) - 1; + int bi = offset * (2 * tid + 2) - 1; + if (bi < scan_size) { + int temp = scan_buf[ai]; + scan_buf[ai] = scan_buf[bi]; + scan_buf[bi] += temp; + } + } + __syncthreads(); + } + + if (tid < num_experts) { + prefix[tid] = scan_buf[tid]; + } + + if (tid == 0) { + s_total_tokens_post_pad = prefix[num_experts]; + *total_tokens_post_pad = s_total_tokens_post_pad; + } + __syncthreads(); + +#else // CUDA + + // Intra warp prefix sum + int32_t* warp_sums = scan_buf + scan_size; // [<= 32] + const int warp_id = tid / WARP_SIZE; + const int lane_id = tid & (WARP_SIZE - 1); + const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE; + const int warp_sum = warp_exclusive_scan(padded_count) + padded_count; + if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum; + __syncthreads(); + + // warp0 accumulate all the block's prefix sum + if (tid < WARP_SIZE) { + int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0; + int incl = warp_exclusive_scan(val) + val; + warp_sums[tid] = incl; + } + __syncthreads(); + + // Every thread obtains the whole block's sum + if (tid == 0) { + prefix[num_experts] = warp_sums[num_warps_for_scan - 1]; + s_total_tokens_post_pad = prefix[num_experts]; + *total_tokens_post_pad = s_total_tokens_post_pad; + } + __syncthreads(); + + // Fill 0 to scan_buf extended area (tid >= num_expert) + if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0; + __syncthreads(); + + // Perform 2 level exclusive-prefix-sum to scan_buf + int v = (tid < scan_size) ? scan_buf[tid] : 0; + int pre = warp_exclusive_scan(v); + if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v; + __syncthreads(); + + if (warp_id == 0) { + int val = (lane_id < num_warps_for_scan) ? 
warp_sums[lane_id] : 0; + warp_sums[lane_id] = warp_exclusive_scan(val); + } + __syncthreads(); + + int offset = warp_sums[warp_id]; + if (tid < scan_size) scan_buf[tid] = pre + offset; + __syncthreads(); + + // Write prefix[0..num_experts - 1] and cumsum + if (tid < num_experts) prefix[tid] = scan_buf[tid]; +#endif + + if (tid <= num_experts) { + cumsum[tid] = prefix[tid]; + } + // fill expert_ids + const int32_t num_blocks = s_total_tokens_post_pad / block_size; + for (int32_t i = tid; i < num_blocks; i += stride) { + int32_t block_start = i * block_size; + int left = 0, right = num_experts; + while (left < right) { + int mid = (left + right) >> 1; + if (prefix[mid] <= block_start) { + left = mid + 1; + } else { + right = mid; + } + } + expert_ids[i] = left - 2; + } +} + +template +__global__ void moe_align_block_size_small_batch_expert_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, + size_t numel, + bool pad_sorted_token_ids, + int32_t max_num_tokens_padded) { + // Adapted from + // https://github.com/vllm-project/vllm/pull/29642/files#diff-5647b1413f4ae9aacba904eca8f8a8aee9079321eadff4c10101a2c6962dcc53R226 + // Use an additional group of threads to fill sorted_token_ids. + // Since the kernel will use sorted_token_ids afterward, + // we fill sorted_token_ids within the same threadblock to make + // synchronization easier. 
+ if (threadIdx.x < fill_threads) { + // Initialize sorted_token_ids with numel + if (pad_sorted_token_ids) { + for (int32_t it = threadIdx.x; it < max_num_tokens_padded; it += fill_threads) { + sorted_token_ids[it] = numel; + } + } + // Three __syncthreads() corresponding to the other threads + __syncthreads(); + __syncthreads(); + __syncthreads(); + return; + } + + const size_t tid = threadIdx.x - fill_threads; + const size_t stride = blockDim.x - fill_threads; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[(tid + 1) * num_experts + i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + ++tokens_cnts[(tid + 1) * num_experts + expert_id]; + } + + __syncthreads(); + + if (tid < num_experts) { + tokens_cnts[tid] = 0; + for (int i = 1; i <= stride; ++i) { + tokens_cnts[i * num_experts + tid] += tokens_cnts[(i - 1) * num_experts + tid]; + } + } + + __syncthreads(); + + if (tid == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[stride * num_experts + i - 1], block_size) * block_size; + } + *total_tokens_post_pad = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + if (tid < num_experts) { + for (int i = cumsum[tid]; i < cumsum[tid + 1]; i += block_size) { + expert_ids[i / block_size] = tid - 1; + } + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i] + 1; + int32_t rank_post_pad = tokens_cnts[tid * num_experts + expert_id] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[tid * num_experts + expert_id]; + } +} + +void moe_align_block_size( + torch::Tensor topk_ids, + int64_t num_experts, + int64_t block_size, + torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor num_tokens_post_pad, + torch::Tensor 
cumsum_buffer, + bool pad_sorted_token_ids) { + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + + int threads = 1024; + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + + int64_t max_num_tokens_padded = sorted_token_ids.size(0); + + DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t threads = max((int32_t)num_experts, WARP_SIZE); + constexpr int32_t fill_threads = 256; + const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); + + auto small_batch_expert_kernel = moe_align_block_size_small_batch_expert_kernel; + hipLaunchKernelGGL(( small_batch_expert_kernel), dim3(1), dim3(fill_threads + threads), shared_mem_size, stream, + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, + topk_ids.numel(), + pad_sorted_token_ids, + max_num_tokens_padded); + } else { + auto align_kernel = moe_align_block_size_kernel; + + const size_t scan_size = next_pow2(num_experts); + const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * sizeof(int32_t); + hipLaunchKernelGGL(( align_kernel), dim3(2), dim3(threads), shared_mem_size, stream, + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), + num_experts, + block_size, + topk_ids.numel(), + cumsum_buffer.data_ptr(), + pad_sorted_token_ids, + scan_size, + max_num_tokens_padded); + + const int block_threads = ::min(256, (int)threads); + const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = ::min(num_blocks, max_blocks); + + auto sort_kernel = count_and_sort_expert_tokens_kernel; + hipLaunchKernelGGL(( sort_kernel), 
dim3(actual_blocks), dim3(block_threads), 0, stream, + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + cumsum_buffer.data_ptr(), + topk_ids.numel()); + } + }); +} diff --git a/sgl-kernel/csrc/moe/moe_topk_sigmoid_kernels.hip b/sgl-kernel/csrc/moe/moe_topk_sigmoid_kernels.hip new file mode 100644 index 000000000000..e4bcd78ed290 --- /dev/null +++ b/sgl-kernel/csrc/moe/moe_topk_sigmoid_kernels.hip @@ -0,0 +1,594 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/moe/topk_softmax_kernels.cu +// which is originally adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#ifndef USE_ROCM +#include +#include +#include +#else +#include +#include +#endif + +#include "utils_hip.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? 
(a) : (b)) + +// Define reduction operators based on CUDA version +// CUDA 13 (12.9+) deprecated hipcub::Max/Min in favor of cuda::maximum/minimum +#if TORCH_HIP_VERSION >= 12090 +using MaxReduceOp = cuda::maximum<>; +using MinReduceOp = cuda::minimum<>; +#else +using MaxReduceOp = hipcub::Max; +using MinReduceOp = hipcub::Min; +#endif + +/// Aligned array type +template < + typename T, + /// Number of elements in the array + int N, + /// Alignment requirement in bytes + int Alignment = sizeof(T) * N> +class alignas(Alignment) AlignedArray { + T data[N]; +}; + +// ========================== Util functions to convert types ========================== +template +__device__ float convert_to_float(T x) { + if constexpr (std::is_same_v) { + return __half2float(x); + } else if constexpr (std::is_same_v) { + return __bfloat162float(x); + } else if constexpr (std::is_same_v) { + return x; + } else { + return static_cast(x); + } +} + +// ====================== Sigmoid things =============================== +// We have our own implementation of sigmoid here so we can support transposing the output +// in the sigmoid kernel when we extend this module to support expert-choice routing. +template +__launch_bounds__(TPB) __global__ void moeSigmoid( + const T* input, const bool* finished, float* output, const int num_cols, const float* correction_bias) { + const int thread_row_offset = blockIdx.x * num_cols; + + // Don't touch finished rows. 
+ if ((finished != nullptr) && finished[blockIdx.x]) { + return; + } + + // First pass: Apply transformation, find max, and write transformed values to output + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + float val = convert_to_float(input[idx]); + + val = 1.0f / (1.0f + expf(-val)); + + // Apply correction bias if provided + if (correction_bias != nullptr) { + val = val + correction_bias[ii]; + } + + output[idx] = val; // Store transformed value + } +} + +template +__launch_bounds__(TPB) __global__ void moeTopK( + const float* inputs_after_sigmoid, + const bool* finished, + float* output, + int* indices, + const int num_experts, + const int k, + const int start_expert, + const int end_expert, + const bool renormalize, + const float* correction_bias) { + using cub_kvp = hipcub::KeyValuePair; + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + hipcub::ArgMax arg_max; + + const int block_row = blockIdx.x; + + const bool row_is_active = finished ? 
!finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + float row_sum_for_renormalize = 0; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = -1.f; // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_sigmoid[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + // Ignore experts the node isn't responsible for with expert parallelism + const int expert = result_kvp.key; + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + const int idx = k * block_row + k_idx; + float val = result_kvp.value; + if (correction_bias != nullptr) { + val -= correction_bias[expert]; + } + output[idx] = val; + indices[idx] = should_process_row ? (expert - start_expert) : num_experts; + assert(indices[idx] >= 0); + row_sum_for_renormalize += val; + } + __syncthreads(); + } + + if (renormalize && threadIdx.x == 0) { + float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * block_row + k_idx; + output[idx] = output[idx] * row_sum_for_renormalize_inv; + } + } +} + +// ====================== TopK sigmoid things =============================== + +/* + A Top-K gating sigmoid written to exploit when the number of experts in the MoE layers + are a small power of 2. 
This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the sigmoid, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small power of 2. + 2) This implementation assumes k is small, but will work for any k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSigmoid( + const T* input, + const bool* finished, + float* output, + const int num_rows, + int* indices, + const int k, + const int start_expert, + const int end_expert, + const bool renormalize, + const float* correction_bias) { + // We begin by enforcing compile time assertions and setting up compile time constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. 
We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) { + return; + } + const bool row_is_active = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. + const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. 
In theory, + // this can support all powers of 2 up to 16. + // NOTE(woosuk): The original implementation uses CUTLASS aligned array here. + // We defined our own aligned array and use it here to avoid the dependency on CUTLASS. + using AccessType = AlignedArray; + + // Finally, we pull in the data from global mem + T row_chunk_temp[VPT]; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk_temp); + const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + // Note(Byron): interleaved loads to achieve better memory coalescing + // | thread[0] | thread[1] | thread[2] | thread[3] | thread[0] | thread[1] | thread[2] | thread[3] | ... + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + float row_chunk[VPT]; +#pragma unroll + // Note(Byron): upcast logits to float32 + for (int ii = 0; ii < VPT; ++ii) { + float val = convert_to_float(row_chunk_temp[ii]); + val = 1.0f / (1.0f + expf(-val)); + // Apply correction bias if provided + if (correction_bias != nullptr) { + /* + LDG is interleaved + |thread0 LDG| |thread1 LDG| |thread0 LDG| |thread1 LDG| + |--------- group0 --------| |----------group1 --------| + ^ local2 + */ + const int group_id = ii / ELTS_PER_LDG; + const int local_id = ii % ELTS_PER_LDG; + const int expert_idx = first_elt_read_by_thread + group_id * THREADS_PER_ROW * ELTS_PER_LDG + local_id; + val = val + correction_bias[expert_idx]; + } + + row_chunk[ii] = val; + } + + // Now, row_chunk contains the sigmoid of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index. 
+ int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + float row_sum_for_renormalize = 0; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. +// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, max_val, mask, THREADS_PER_ROW); + int other_expert = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. 
+ const int idx = k * thread_row + k_idx; + if (correction_bias != nullptr) { + max_val -= correction_bias[expert]; + } + output[idx] = max_val; + indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; + row_sum_for_renormalize += max_val; + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; + } + } + } + + // Fuse renormalization of topk_weights into this kernel + if (renormalize && thread_group_idx == 0) { + float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize; +#pragma unroll + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * thread_row + k_idx; + output[idx] = output[idx] * row_sum_for_renormalize_inv; + } + } +} + +namespace detail { +// Constructs some constants needed to partition the work across threads at compile time. 
+template +struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topkGatingSigmoidLauncherHelper( + const T* input, + const bool* finished, + float* output, + int* indices, + const int num_rows, + const int k, + const int start_expert, + const int end_expert, + const bool renormalize, + const float* correction_bias, + hipStream_t stream) { + static constexpr std::size_t MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + hipLaunchKernelGGL(( topkGatingSigmoid), dim3(num_blocks), dim3(block_dim), 0, stream, + input, finished, output, num_rows, indices, k, start_expert, end_expert, renormalize, correction_bias); +} + +#define LAUNCH_SIGMOID(TYPE, NUM_EXPERTS, WARPS_PER_TB) \ + topkGatingSigmoidLauncherHelper( \ + gating_output, \ + nullptr, \ + topk_weights, \ + topk_indices, \ + num_tokens, \ + topk, \ + 0, \ + num_experts, \ + renormalize, \ + correction_bias, \ + stream); + +template +void topkGatingSigmoidKernelLauncher( + const T* gating_output, + float* topk_weights, + int* topk_indices, + float* sigmoid_workspace, + const int num_tokens, + const int num_experts, + const int topk, + const bool renormalize, + const 
float* correction_bias, + hipStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + switch (num_experts) { + case 1: + LAUNCH_SIGMOID(T, 1, WARPS_PER_TB); + break; + case 2: + LAUNCH_SIGMOID(T, 2, WARPS_PER_TB); + break; + case 4: + LAUNCH_SIGMOID(T, 4, WARPS_PER_TB); + break; + case 8: + LAUNCH_SIGMOID(T, 8, WARPS_PER_TB); + break; + case 16: + LAUNCH_SIGMOID(T, 16, WARPS_PER_TB); + break; + case 32: + LAUNCH_SIGMOID(T, 32, WARPS_PER_TB); + break; + case 64: + LAUNCH_SIGMOID(T, 64, WARPS_PER_TB); + break; + case 128: + LAUNCH_SIGMOID(T, 128, WARPS_PER_TB); + break; + case 256: + LAUNCH_SIGMOID(T, 256, WARPS_PER_TB); + break; + default: { + TORCH_CHECK( + sigmoid_workspace != nullptr, + "sigmoid_workspace must be provided for num_experts that are not a power of 2."); + static constexpr int TPB = 256; + hipLaunchKernelGGL(( moeSigmoid) + , dim3(num_tokens), dim3(TPB), 0, stream, gating_output, nullptr, sigmoid_workspace, num_experts, correction_bias); + hipLaunchKernelGGL(( moeTopK), dim3(num_tokens), dim3(TPB), 0, stream, + sigmoid_workspace, + nullptr, + topk_weights, + topk_indices, + num_experts, + topk, + 0, + num_experts, + renormalize, + correction_bias); + } + } +} + +void topk_sigmoid( + torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& gating_output, // [num_tokens, num_experts] + const bool renormalize, + const c10::optional& correction_bias) { + // Check data type + TORCH_CHECK( + gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half || + gating_output.scalar_type() == at::ScalarType::BFloat16, + "gating_output must be float32, float16, or bfloat16"); + + // Check dimensions + TORCH_CHECK(gating_output.dim() == 2, "gating_output must be 2D tensor [num_tokens, num_experts]"); + TORCH_CHECK(topk_weights.dim() == 2, "topk_weights must be 2D tensor [num_tokens, topk]"); + TORCH_CHECK(topk_indices.dim() == 2, "topk_indices must 
be 2D tensor [num_tokens, topk]"); + + // Check shapes + TORCH_CHECK( + gating_output.size(0) == topk_weights.size(0), + "First dimension of topk_weights must match num_tokens in gating_output"); + TORCH_CHECK( + gating_output.size(0) == topk_indices.size(0), + "First dimension of topk_indices must match num_tokens in gating_output"); + TORCH_CHECK( + topk_weights.size(-1) == topk_indices.size(-1), + "Second dimension of topk_indices must match topk in topk_weights"); + TORCH_CHECK(topk_weights.size(-1) <= gating_output.size(-1), "topk must be less than or equal to num_experts"); + + const int num_experts = static_cast(gating_output.size(-1)); + const int num_tokens = static_cast(gating_output.size(0)); + const int topk = static_cast(topk_weights.size(-1)); + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + const bool needs_workspace = !is_pow_2 || num_experts > 256; + const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0; + + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output)); + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + torch::Tensor sigmoid_workspace = + torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float)); + + const at::ScalarType dtype = gating_output.scalar_type(); + + // Validate correction_bias if provided - must always be float32 + const float* bias_ptr = nullptr; + if (correction_bias.has_value()) { + const torch::Tensor& bias_tensor = correction_bias.value(); + TORCH_CHECK(bias_tensor.dim() == 1, "correction_bias must be 1D tensor [num_experts]"); + TORCH_CHECK(bias_tensor.size(0) == num_experts, "correction_bias size must match num_experts"); + TORCH_CHECK( + bias_tensor.scalar_type() == at::ScalarType::Float, + "correction_bias must be float32, got ", + bias_tensor.scalar_type()); + bias_ptr = bias_tensor.data_ptr(); + } + + if (dtype == at::ScalarType::Float) { + 
topkGatingSigmoidKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + sigmoid_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + bias_ptr, + stream); + } else if (dtype == at::ScalarType::Half) { + topkGatingSigmoidKernelLauncher<__half>( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + sigmoid_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + bias_ptr, + stream); + } else if (dtype == at::ScalarType::BFloat16) { + topkGatingSigmoidKernelLauncher<__hip_bfloat16>( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + sigmoid_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + bias_ptr, + stream); + } else { + TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype); + } +} diff --git a/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.hip b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.hip new file mode 100644 index 000000000000..ca09ce080553 --- /dev/null +++ b/sgl-kernel/csrc/moe/moe_topk_softmax_kernels.hip @@ -0,0 +1,822 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +// Adapt from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/moe/topk_softmax_kernels.cu +// which is originally adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#ifndef USE_ROCM +#include +#include +#include +#else +#include +#include +#endif + +#include "utils_hip.h" + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +// Define reduction operators based on CUDA version +// CUDA 13 (12.9+) deprecated hipcub::Max/Min in favor of cuda::maximum/minimum +#if TORCH_HIP_VERSION >= 12090 +using MaxReduceOp = cuda::maximum<>; +using MinReduceOp = cuda::minimum<>; +#else +using MaxReduceOp = hipcub::Max; +using MinReduceOp = hipcub::Min; +#endif + +using cub_kvp = hipcub::KeyValuePair; + +/// Aligned array type +template < + typename T, + /// Number of elements in the array + int N, + /// Alignment requirement in bytes + int Alignment = sizeof(T) * N> +class alignas(Alignment) AlignedArray { + T data[N]; +}; + +// ========================== Util functions to convert types ========================== +template +__device__ float convert_to_float(T x) { + if constexpr (std::is_same_v) { + return __half2float(x); + } else if constexpr (std::is_same_v) { + return __bfloat162float(x); + } else if constexpr (std::is_same_v) { + return x; + } else { + return static_cast(x); + } +} + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing the output +// in the softmax kernel when we extend this module to support expert-choice routing. 
+template +__launch_bounds__(TPB) __global__ void moeSoftmax( + const T* input, + const bool* finished, + float* output, + const int num_cols, + const float moe_softcapping, + const float* correction_bias) { + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + float threadData(-FLT_MAX); + + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) { + return; + } + + // First pass: Apply transformation, find max, and write transformed values to output + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + float val = convert_to_float(input[idx]); + + // Apply tanh softcapping if enabled + if (moe_softcapping != 0.0f) { + val = tanhf(val / moe_softcapping) * moe_softcapping; + } + + // Apply correction bias if provided + if (correction_bias != nullptr) { + val = val + correction_bias[ii]; + } + + output[idx] = val; // Store transformed value + threadData = max(val, threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, MaxReduceOp()); + + if (threadIdx.x == 0) { + float_max = maxElem; + } + __syncthreads(); + + // Second pass: Compute sum using transformed values from output + threadData = 0; + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + threadData += exp((output[idx] - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Sum(threadData); + + if (threadIdx.x == 0) { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + // Third pass: Compute final softmax using transformed values from output + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) { + const int idx = thread_row_offset + ii; + const float softmax_val = exp((output[idx] - float_max)) * normalizing_factor; + output[idx] = softmax_val; + } +} + +namespace moe { 
+struct TopKPair { + static const int PAIR = 2; + static const int MAX_INDEX = 0; + cub_kvp max; + cub_kvp secondMax; + + __device__ TopKPair() {} + __device__ TopKPair(cub_kvp max, cub_kvp secondMax) : max(max), secondMax(secondMax) {} +}; + +struct TopKPairArgMax { + __device__ TopKPairArgMax() {} + __device__ __forceinline__ TopKPair operator()(const TopKPair& candidate1, const TopKPair& candidate2) const { + cub_kvp globalMax, globalSecondMax; + + // Determine the global maximum + if (candidate1.max.value > candidate2.max.value) { + globalMax = candidate1.max; + } else { + globalMax = candidate2.max; + } + + // Determine the global second maximum + if (globalMax.key == candidate1.max.key) { + // If candidate1 contributed the max, compare its secondMax with candidate2's max + globalSecondMax = (candidate1.secondMax.value > candidate2.max.value) ? candidate1.secondMax : candidate2.max; + } else { + // If candidate2 contributed the max, compare its secondMax with candidate1's max + globalSecondMax = (candidate2.secondMax.value > candidate1.max.value) ? candidate2.secondMax : candidate1.max; + } + return TopKPair(globalMax, globalSecondMax); + } +}; +} // namespace moe + +template +__launch_bounds__(TPB) __global__ void moeTopKFast( + float* inputs_after_softmax, + const bool* finished, + float* output, + int* indices, + const int num_experts, + const int k, + const int start_expert, + const int end_expert, + const bool renormalize) { + using namespace moe; + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + TopKPair thread_pair; + + const int block_row = blockIdx.x; + + const bool row_is_active = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + float row_sum_for_renormalize = 0; + // Each loop finds the top 2 elements, + // thus requiring only ⌈k/2⌉ loops (calculated as (k + 1) / 2). 
+ for (int k_idx = 0; k_idx < (k + TopKPair::PAIR - 1) / TopKPair::PAIR; ++k_idx) { + // Initializing the top 2 elements by the minimum value. + thread_pair.max.key = 0; + thread_pair.max.value = -1.f; + thread_pair.secondMax.key = 0; + thread_pair.secondMax.value = -1.f; + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + // updating the thread_pair according to inp_kvp's value + if (inp_kvp.value > thread_pair.max.value) { + thread_pair.secondMax = thread_pair.max; + thread_pair.max = inp_kvp; + } else if (inp_kvp.value > thread_pair.secondMax.value) { + thread_pair.secondMax = inp_kvp; + } + } + + TopKPairArgMax reducer; + const TopKPair result_pair = BlockReduce(tmpStorage).Reduce(thread_pair, reducer); + if (threadIdx.x == 0) { +#pragma unroll + // updating 2 elements to the result. + for (int i = 0; i < TopKPair::PAIR; i++) { + if (k_idx * 2 + i >= k) break; + cub_kvp result = (i == TopKPair::MAX_INDEX) ? result_pair.max : result_pair.secondMax; + int expert = result.key; + bool node_uses_expert = expert >= start_expert && expert < end_expert; + bool should_process_row = row_is_active && node_uses_expert; + // The inputs_after_softmax is modified in-place to avoid unnecessary loops for finding the top k-1 value. + // 1.f represents the minimum value. + inputs_after_softmax[thread_read_offset + expert] = -1.f; + int idx = k * block_row + k_idx * 2 + i; + output[idx] = result.value; + indices[idx] = should_process_row ? 
(expert - start_expert) : num_experts; + assert(indices[idx] >= 0); + row_sum_for_renormalize += result.value; + } + } + __syncthreads(); + } + + if (renormalize && threadIdx.x == 0) { + float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * block_row + k_idx; + output[idx] = output[idx] * row_sum_for_renormalize_inv; + } + } +} + +template +__launch_bounds__(TPB) __global__ void moeTopK( + float* inputs_after_softmax, + const bool* finished, + float* output, + int* indices, + const int num_experts, + const int k, + const int start_expert, + const int end_expert, + const bool renormalize) { + using cub_kvp = hipcub::KeyValuePair; + using BlockReduce = hipcub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + hipcub::ArgMax arg_max; + + const int block_row = blockIdx.x; + + const bool row_is_active = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + float row_sum_for_renormalize = 0; + for (int k_idx = 0; k_idx < k; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = -1.f; // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + // Ignore experts the node isn't responsible for with expert parallelism + const int expert = result_kvp.key; + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? 
(expert - start_expert) : num_experts; + assert(indices[idx] >= 0); + row_sum_for_renormalize += result_kvp.value; + // The inputs_after_softmax is modified in-place to avoid unnecessary loops for finding the top k-1 value. + // 1.f represents the minimum value. + inputs_after_softmax[thread_read_offset + expert] = -1.f; + } + __syncthreads(); + } + + if (renormalize && threadIdx.x == 0) { + float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize; + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * block_row + k_idx; + output[idx] = output[idx] * row_sum_for_renormalize_inv; + } + } +} + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the MoE layers + are a small power of 2. This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small power of 2. + 2) This implementation assumes k is small, but will work for any k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ void topkGatingSoftmax( + const T* input, + const bool* finished, + float* output, + const int num_rows, + int* indices, + const int k, + const int start_expert, + const int end_expert, + const bool renormalize, + const float moe_softcapping, + const float* correction_bias) { + // We begin by enforcing compile time assertions and setting up compile time constants. 
+ static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. 
+ const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) { + return; + } + const bool row_is_active = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. + const T* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const T* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, + // this can support all powers of 2 up to 16. + // NOTE(woosuk): The original implementation uses CUTLASS aligned array here. + // We defined our own aligned array and use it here to avoid the dependency on CUTLASS. + using AccessType = AlignedArray; + + // Finally, we pull in the data from global mem + T row_chunk_temp[VPT]; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk_temp); + const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + // Note(Byron): interleaved loads to achieve better memory coalescing + // | thread[0] | thread[1] | thread[2] | thread[3] | thread[0] | thread[1] | thread[2] | thread[3] | ... 
+ for (int ii = 0; ii < LDG_PER_THREAD; ++ii) { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + float row_chunk[VPT]; +#pragma unroll + // Note(Byron): upcast logits to float32 + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = convert_to_float(row_chunk_temp[ii]); + } + + // Apply tanh softcapping and correction bias + if (moe_softcapping != 0.0f || correction_bias != nullptr) { +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + float val = row_chunk[ii]; + + // Apply tanh softcapping if enabled + if (moe_softcapping != 0.0f) { + val = tanhf(val / moe_softcapping) * moe_softcapping; + } + + // Apply correction bias if provided + if (correction_bias != nullptr) { + /* + LDG is interleaved + |thread0 LDG| |thread1 LDG| |thread0 LDG| |thread1 LDG| + |--------- group0 --------| |----------group1 --------| + ^ local2 + */ + const int group_id = ii / ELTS_PER_LDG; + const int local_id = ii % ELTS_PER_LDG; + const int expert_idx = first_elt_read_by_thread + group_id * THREADS_PER_ROW * ELTS_PER_LDG + local_id; + val = val + correction_bias[expert_idx]; + } + + row_chunk[ii] = val; + } + } + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + float thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) { + thread_max = max(thread_max, row_chunk[ii]); + } + + /*********************************/ + /********* Softmax Begin *********/ + /*********************************/ + +// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. 
+// lane id: 0-31 within a warp +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + // butterfly reduce with (lane id ^ mask) + thread_max = max(thread_max, SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + row_sum += SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + /*******************************/ + /********* Softmax End *********/ + /*******************************/ + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index. 
+ int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + float row_sum_for_renormalize = 0; + + for (int k_idx = 0; k_idx < k; ++k_idx) { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. +// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) { + float other_max = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, max_val, mask, THREADS_PER_ROW); + int other_expert = SGLANG_SHFL_XOR_SYNC_WIDTH(0xffffffff, expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. 
+ const int idx = k * thread_row + k_idx; + output[idx] = max_val; + indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; + row_sum_for_renormalize += max_val; + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; + } + } + } + + // Fuse renormalization of topk_weights into this kernel + if (renormalize && thread_group_idx == 0) { + float row_sum_for_renormalize_inv = 1.f / row_sum_for_renormalize; +#pragma unroll + for (int k_idx = 0; k_idx < k; ++k_idx) { + const int idx = k * thread_row + k_idx; + output[idx] = output[idx] * row_sum_for_renormalize_inv; + } + } +} + +namespace detail { +// Constructs some constants needed to partition the work across threads at compile time. 
+template +struct TopkConstants { + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(T); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topkGatingSoftmaxLauncherHelper( + const T* input, + const bool* finished, + float* output, + int* indices, + const int num_rows, + const int k, + const int start_expert, + const int end_expert, + const bool renormalize, + const float moe_softcapping, + const float* correction_bias, + hipStream_t stream) { + static constexpr std::size_t MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(T) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + hipLaunchKernelGGL(( topkGatingSoftmax), dim3(num_blocks), dim3(block_dim), 0, stream, + input, + finished, + output, + num_rows, + indices, + k, + start_expert, + end_expert, + renormalize, + moe_softcapping, + correction_bias); +} + +#define LAUNCH_SOFTMAX(TYPE, NUM_EXPERTS, WARPS_PER_TB) \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, \ + nullptr, \ + topk_weights, \ + topk_indices, \ + num_tokens, \ + topk, \ + 0, \ + num_experts, \ + renormalize, \ + moe_softcapping, \ + correction_bias, \ + stream); + +template +void topkGatingSoftmaxKernelLauncher( + const T* gating_output, + float* topk_weights, + int* topk_indices, + float* softmax_workspace, + const int 
num_tokens, + const int num_experts, + const int topk, + const bool renormalize, + const float moe_softcapping, + const float* correction_bias, + hipStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + switch (num_experts) { + case 1: + LAUNCH_SOFTMAX(T, 1, WARPS_PER_TB); + break; + case 2: + LAUNCH_SOFTMAX(T, 2, WARPS_PER_TB); + break; + case 4: + LAUNCH_SOFTMAX(T, 4, WARPS_PER_TB); + break; + case 8: + LAUNCH_SOFTMAX(T, 8, WARPS_PER_TB); + break; + case 16: + LAUNCH_SOFTMAX(T, 16, WARPS_PER_TB); + break; + case 32: + LAUNCH_SOFTMAX(T, 32, WARPS_PER_TB); + break; + case 64: + LAUNCH_SOFTMAX(T, 64, WARPS_PER_TB); + break; + case 128: + LAUNCH_SOFTMAX(T, 128, WARPS_PER_TB); + break; + case 256: + LAUNCH_SOFTMAX(T, 256, WARPS_PER_TB); + break; + default: { + TORCH_CHECK( + softmax_workspace != nullptr, + "softmax_workspace must be provided for num_experts that are not a power of 2."); + static constexpr int TPB = 256; + hipLaunchKernelGGL(( moeSoftmax), dim3(num_tokens), dim3(TPB), 0, stream, + gating_output, nullptr, softmax_workspace, num_experts, moe_softcapping, correction_bias); + if (topk == 1) { + // Note: As an optimization for better performance, + // the softmax_workspace is overwritten in-place by both moeTopK and moeTopKFast. 
+ hipLaunchKernelGGL(( moeTopK), dim3(num_tokens), dim3(TPB), 0, stream, + softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize); + } else { + hipLaunchKernelGGL(( moeTopKFast), dim3(num_tokens), dim3(TPB), 0, stream, + softmax_workspace, nullptr, topk_weights, topk_indices, num_experts, topk, 0, num_experts, renormalize); + } + } + } +} + +void topk_softmax( + torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& gating_output, // [num_tokens, num_experts] + const bool renormalize, + const double moe_softcapping, + const c10::optional& correction_bias) { + // Check data type + TORCH_CHECK( + gating_output.scalar_type() == at::ScalarType::Float || gating_output.scalar_type() == at::ScalarType::Half || + gating_output.scalar_type() == at::ScalarType::BFloat16, + "gating_output must be float32, float16, or bfloat16"); + + // Check dimensions + TORCH_CHECK(gating_output.dim() == 2, "gating_output must be 2D tensor [num_tokens, num_experts]"); + TORCH_CHECK(topk_weights.dim() == 2, "topk_weights must be 2D tensor [num_tokens, topk]"); + TORCH_CHECK(topk_indices.dim() == 2, "topk_indices must be 2D tensor [num_tokens, topk]"); + + // Check shapes + TORCH_CHECK( + gating_output.size(0) == topk_weights.size(0), + "First dimension of topk_weights must match num_tokens in gating_output"); + TORCH_CHECK( + gating_output.size(0) == topk_indices.size(0), + "First dimension of topk_indices must match num_tokens in gating_output"); + TORCH_CHECK( + topk_weights.size(-1) == topk_indices.size(-1), + "Second dimension of topk_indices must match topk in topk_weights"); + TORCH_CHECK(topk_weights.size(-1) <= gating_output.size(-1), "topk must be less than or equal to num_experts"); + + const int num_experts = static_cast(gating_output.size(-1)); + const int num_tokens = static_cast(gating_output.size(0)); + const int topk = static_cast(topk_weights.size(-1)); + + const 
bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + const bool needs_workspace = !is_pow_2 || num_experts > 256; + const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0; + + const at::hip::OptionalHIPGuardMasqueradingAsCUDA device_guard(device_of(gating_output)); + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + torch::Tensor softmax_workspace = + torch::empty({workspace_size}, gating_output.options().dtype(at::ScalarType::Float)); + + const at::ScalarType dtype = gating_output.scalar_type(); + + // Validate correction_bias if provided - must always be float32 + const float* bias_ptr = nullptr; + if (correction_bias.has_value()) { + const torch::Tensor& bias_tensor = correction_bias.value(); + TORCH_CHECK(bias_tensor.dim() == 1, "correction_bias must be 1D tensor [num_experts]"); + TORCH_CHECK(bias_tensor.size(0) == num_experts, "correction_bias size must match num_experts"); + TORCH_CHECK( + bias_tensor.scalar_type() == at::ScalarType::Float, + "correction_bias must be float32, got ", + bias_tensor.scalar_type()); + bias_ptr = bias_tensor.data_ptr(); + } + + // Cast moe_softcapping from double to float for CUDA kernels + const float moe_softcapping_f = static_cast(moe_softcapping); + + if (dtype == at::ScalarType::Float) { + topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + moe_softcapping_f, + bias_ptr, + stream); + } else if (dtype == at::ScalarType::Half) { + topkGatingSoftmaxKernelLauncher<__half>( + reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + moe_softcapping_f, + bias_ptr, + stream); + } else if (dtype == at::ScalarType::BFloat16) { + topkGatingSoftmaxKernelLauncher<__hip_bfloat16>( + 
reinterpret_cast(gating_output.data_ptr()), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + renormalize, + moe_softcapping_f, + bias_ptr, + stream); + } else { + TORCH_CHECK(false, "Unsupported gating_output dtype: ", dtype); + } +} diff --git a/sgl-kernel/csrc/speculative/eagle_utils.hip b/sgl-kernel/csrc/speculative/eagle_utils.hip new file mode 100644 index 000000000000..a67d5b3a7c04 --- /dev/null +++ b/sgl-kernel/csrc/speculative/eagle_utils.hip @@ -0,0 +1,409 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/* + * Copyright (c) 2025 by SGLang team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#ifndef USE_ROCM +#include "pytorch_extension_utils.h" +#else +#include "pytorch_extension_utils_rocm.h" +#endif + +typedef enum { FULL_MASK = 0, QLEN_ONLY = 1, QLEN_ONLY_BITPACKING = 2 } TreeMaskMode; + +// parent_list [bs, topk * (depth - 1) + 1)] +// selected_index [bs, draft_token_num - 1] +// verified_seq_len [bs] +// tree_mask [draft_token*(seq_len[0]+draft_token) | draft_token*(seq_len[1]+draft_token) | ..] 
= +// [sum(verified_seq_len)*draft_token+bs*draft_token*draft_token] positions [bs * draft_token] retrive_index [b, +// draft_token] retrive_next_token [b, draft_token] retrive_next_sibling [b, draft_token] +__global__ void build_tree_efficient( + int64_t* parent_list, + int64_t* selected_index, + int64_t* verified_seq_len, + bool* tree_mask, + int64_t* positions, + int64_t* retrive_index, + int64_t* retrive_next_token, + int64_t* retrive_next_sibling, + int topk, + int depth, + int draft_token_num, + int tree_mask_mode) { + int bid = blockIdx.x; + int tid = threadIdx.x; + + if (tid >= draft_token_num) { + return; + } + int seq_tree_idx = draft_token_num * draft_token_num * bid; + for (int i = 0; i < bid; i++) { + seq_tree_idx += verified_seq_len[i] * draft_token_num; + } + int seq_len = verified_seq_len[bid]; + int token_tree_idx; + if (tree_mask_mode == FULL_MASK) { + token_tree_idx = seq_tree_idx + (seq_len + draft_token_num) * tid + seq_len + 1; + } else { + token_tree_idx = draft_token_num * draft_token_num * bid + draft_token_num * tid + 1; + } + tree_mask[token_tree_idx - 1] = true; + for (int i = 0; i < draft_token_num - 1; i++) { + tree_mask[token_tree_idx + i] = false; + } + + int position = 0; + if (tid == 0) { + positions[bid * draft_token_num] = seq_len; + + int retrive_index_offset = bid * draft_token_num; + for (int i = draft_token_num - 1; i > 0; --i) { + int current_token_idx = retrive_index_offset + i; + retrive_index[bid * draft_token_num + i] = current_token_idx; + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk; + int parent_position = 0; + if (parent_tb_idx > 0) { + int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (; parent_position < draft_token_num; ++parent_position) { + if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) { + ++parent_position; + break; + } + } + } + if (parent_position == draft_token_num) { + printf( + "WARNING: 
invalid eagle tree!!! Detected a token with no parent token selected. " + "Please check if the logprob has nan. The token will be ignored to keep proceeding.\n"); + continue; + } + + if (retrive_next_token[bid * draft_token_num + parent_position] == -1) { + retrive_next_token[bid * draft_token_num + parent_position] = i; + } else { + int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position]; + retrive_next_token[bid * draft_token_num + parent_position] = i; + retrive_next_sibling[bid * draft_token_num + i] = origin_next_token; + } + } + retrive_index[bid * draft_token_num] = bid * draft_token_num; + } else { + int cur_position = tid - 1; + while (true) { + position += 1; + tree_mask[token_tree_idx + cur_position] = true; + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk; + if (parent_tb_idx == 0) { + break; + } + + int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (cur_position = 0; cur_position < draft_token_num; ++cur_position) { + if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) { + break; + } + } + } + positions[bid * draft_token_num + tid] = position + seq_len; + } +} + +// parent_list [bs, topk * (depth - 1) + 1)] +// selected_index [bs, draft_token_num - 1] +// verified_seq_len [bs] +// tree_mask: [draft_token*num_bytes_per_item | .. 
] = [bs*draft_token*num_bytes_per_item] +// positions [bs * draft_token] +// retrive_index [bs, draft_token] +// retrive_next_token [bs, draft_token] +// retrive_next_sibling [bs, draft_token] +__global__ void build_tree_efficient_partial_packed( + int64_t* parent_list, + int64_t* selected_index, + int64_t* verified_seq_len, + uint8_t* tree_mask, + int64_t* positions, + int64_t* retrive_index, + int64_t* retrive_next_token, + int64_t* retrive_next_sibling, + int topk, + int depth, + int draft_token_num, + size_t num_bytes_per_item) { + int bid = blockIdx.x; + int tid = threadIdx.x; + + if (tid >= draft_token_num) { + return; + } + int seq_len = verified_seq_len[bid]; + int token_tree_idx = (bid * draft_token_num + tid) * num_bytes_per_item; + tree_mask[token_tree_idx] = 1; // little endian + + int position = 0; + if (tid == 0) { + positions[bid * draft_token_num] = seq_len; + + int retrive_index_offset = bid * draft_token_num; + for (int i = draft_token_num - 1; i > 0; --i) { + int current_token_idx = retrive_index_offset + i; + retrive_index[bid * draft_token_num + i] = current_token_idx; + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + i - 1] / topk; + int parent_position = 0; + if (parent_tb_idx > 0) { + int parent_token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (; parent_position < draft_token_num; ++parent_position) { + if (selected_index[bid * (draft_token_num - 1) + parent_position] == parent_token_idx) { + ++parent_position; + break; + } + } + } + if (parent_position == draft_token_num) { + printf( + "WARNING: invalid eagle tree!!! Detected a token with no parent token selected. " + "Please check if the logprob has nan. 
The token will be ignored to keep proceeding.\n"); + continue; + } + + if (retrive_next_token[bid * draft_token_num + parent_position] == -1) { + retrive_next_token[bid * draft_token_num + parent_position] = i; + } else { + int origin_next_token = retrive_next_token[bid * draft_token_num + parent_position]; + retrive_next_token[bid * draft_token_num + parent_position] = i; + retrive_next_sibling[bid * draft_token_num + i] = origin_next_token; + } + } + retrive_index[bid * draft_token_num] = bid * draft_token_num; + } else { + int cur_position = tid - 1; + while (true) { + position += 1; + int byte_idx = (cur_position + 1) / 8; + int bit_idx = (cur_position + 1) % 8; + tree_mask[token_tree_idx + byte_idx] |= (1 << bit_idx); + int parent_tb_idx = selected_index[bid * (draft_token_num - 1) + cur_position] / topk; + if (parent_tb_idx == 0) { + break; + } + + int token_idx = parent_list[bid * (topk * (depth - 1) + 1) + parent_tb_idx]; + for (cur_position = 0; cur_position < draft_token_num; ++cur_position) { + if (selected_index[bid * (draft_token_num - 1) + cur_position] == token_idx) { + break; + } + } + } + positions[bid * draft_token_num + tid] = position + seq_len; + } +} + +void build_tree_kernel_efficient( + at::Tensor parent_list, + at::Tensor selected_index, + at::Tensor verified_seq_len, + at::Tensor tree_mask, + at::Tensor positions, + at::Tensor retrive_index, + at::Tensor retrive_next_token, + at::Tensor retrive_next_sibling, + int64_t topk, + int64_t depth, + int64_t draft_token_num, + int64_t tree_mask_mode) { + // TODO (ying) check shape + // TODO (ying) check type + int bs = parent_list.size(0); + dim3 grid(bs); + dim3 block(draft_token_num); + const hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + + if (tree_mask_mode == QLEN_ONLY_BITPACKING) { + size_t num_bytes_per_item = 1; + if (draft_token_num > 16) { + num_bytes_per_item = 4; + } else if (draft_token_num > 8) { + num_bytes_per_item = 2; + } + hipLaunchKernelGGL(( 
build_tree_efficient_partial_packed), dim3(grid), dim3(block), 0, stream, + static_cast(parent_list.data_ptr()), + static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), + static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + int32_t(topk), + int32_t(depth), + int32_t(draft_token_num), + num_bytes_per_item); + } else { + hipLaunchKernelGGL(( build_tree_efficient), dim3(grid), dim3(block), 0, stream, + static_cast(parent_list.data_ptr()), + static_cast(selected_index.data_ptr()), + static_cast(verified_seq_len.data_ptr()), + static_cast(tree_mask.data_ptr()), + static_cast(positions.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + int32_t(topk), + int32_t(depth), + int32_t(draft_token_num), + int32_t(tree_mask_mode)); + } +} + +template +__global__ void VerifyTreeGreedy( + IdType* predicts, + IdType* accept_index, + IdType* accept_token_num, // mutable + IdType2* candidates, + IdType2* retrive_index, + IdType2* retrive_next_token, + IdType2* retrive_next_sibling, + IdType2* target_predict, + uint32_t batch_size, + uint32_t num_speculative_tokens, + uint32_t num_draft_tokens) { + uint32_t bx = blockIdx.x; + + IdType2 last_accepted_retrive_idx = retrive_index[bx * num_draft_tokens]; + accept_index[bx * num_speculative_tokens] = last_accepted_retrive_idx; + uint32_t num_accepted_tokens = 0; + IdType2 cur_index = 0; + + for (uint32_t j = 1; j < num_speculative_tokens; ++j) { + cur_index = retrive_next_token[bx * num_draft_tokens + cur_index]; + while (cur_index != -1) { + IdType2 draft_index = retrive_index[bx * num_draft_tokens + cur_index]; + IdType2 draft_token_id = candidates[bx * num_draft_tokens + cur_index]; + IdType2 target_token_id = target_predict[last_accepted_retrive_idx]; + + 
if (draft_token_id == target_token_id) { + // accept token + predicts[last_accepted_retrive_idx] = target_token_id; + ++num_accepted_tokens; + accept_index[bx * num_speculative_tokens + num_accepted_tokens] = draft_index; + last_accepted_retrive_idx = draft_index; + break; + } else { + cur_index = retrive_next_sibling[bx * num_draft_tokens + cur_index]; + } + } + if (cur_index == -1) break; + } + accept_token_num[bx] = num_accepted_tokens; + predicts[last_accepted_retrive_idx] = target_predict[last_accepted_retrive_idx]; +} + +// predicts: [tot_num_draft_tokens] +// accept_index: [bs, num_spec_step] +// accept_token_num: [bs] +// candidates: [bs, num_draft_tokens] +// retrive_index: [bs, num_draft_tokens] +// retrive_next_token: [bs, num_draft_tokens] +// retrive_next_sibling: [bs, num_draft_tokens] +// target_predict: [bs, num_draft_tokens] +void verify_tree_greedy( + at::Tensor predicts, + at::Tensor accept_index, + at::Tensor accept_token_num, // mutable + at::Tensor candidates, + at::Tensor retrive_index, + at::Tensor retrive_next_token, + at::Tensor retrive_next_sibling, + at::Tensor target_predict) { + CHECK_INPUT(candidates); + CHECK_INPUT(retrive_index); + CHECK_INPUT(retrive_next_token); + CHECK_INPUT(retrive_next_sibling); + CHECK_INPUT(target_predict); + auto device = target_predict.device(); + CHECK_EQ(candidates.device(), device); + CHECK_EQ(retrive_index.device(), device); + CHECK_EQ(retrive_next_token.device(), device); + CHECK_EQ(retrive_next_sibling.device(), device); + CHECK_EQ(target_predict.device(), device); + CHECK_DIM(1, predicts); + CHECK_DIM(2, accept_index); + CHECK_DIM(1, accept_token_num); + CHECK_DIM(2, candidates); + CHECK_DIM(2, retrive_index); + CHECK_DIM(2, retrive_next_token); + CHECK_DIM(2, retrive_next_sibling); + CHECK_DIM(2, target_predict); + unsigned int batch_size = candidates.size(0); + unsigned int num_spec_step = accept_index.size(1); + unsigned int num_draft_tokens = candidates.size(1); + CHECK_EQ(batch_size, 
accept_index.size(0)); + CHECK_EQ(batch_size, accept_token_num.size(0)); + CHECK_EQ(batch_size, retrive_index.size(0)); + CHECK_EQ(batch_size, retrive_next_token.size(0)); + CHECK_EQ(batch_size, retrive_next_sibling.size(0)); + CHECK_EQ(batch_size, target_predict.size(0)); + CHECK_EQ(num_draft_tokens, retrive_index.size(1)); + CHECK_EQ(num_draft_tokens, retrive_next_token.size(1)); + CHECK_EQ(num_draft_tokens, retrive_next_sibling.size(1)); + CHECK_EQ(num_draft_tokens, target_predict.size(1)); + CHECK_EQ(batch_size, accept_index.size(0)); + CHECK_EQ(batch_size, accept_token_num.size(0)); + if (predicts.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'predicts' to be of type int (torch.int32)."); + } + if (accept_index.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'accept_index' to be of type int (torch.int32)."); + } + if (accept_token_num.scalar_type() != at::kInt) { + throw std::runtime_error("Expected 'accept_token_num' to be of type int (torch.int32)."); + } + if (candidates.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'candidates' to be of type long (torch.int64)."); + } + if (retrive_index.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'retrive_index' to be of type long (torch.int64)."); + } + if (retrive_next_token.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'retrive_next_token' to be of type long (torch.int64)."); + } + if (retrive_next_sibling.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'retrive_next_sibling' to be of type long (torch.int64)."); + } + if (target_predict.scalar_type() != at::kLong) { + throw std::runtime_error("Expected 'target_predict' to be of type long (torch.int64)."); + } + + hipStream_t stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA(); + dim3 grid(batch_size); + dim3 block(1); + + hipLaunchKernelGGL(( VerifyTreeGreedy), dim3(grid), dim3(block), 0, stream, + static_cast(predicts.data_ptr()), + 
static_cast(accept_index.data_ptr()), + static_cast(accept_token_num.data_ptr()), + static_cast(candidates.data_ptr()), + static_cast(retrive_index.data_ptr()), + static_cast(retrive_next_token.data_ptr()), + static_cast(retrive_next_sibling.data_ptr()), + static_cast(target_predict.data_ptr()), + batch_size, + num_spec_step, + num_draft_tokens); +} diff --git a/sgl-kernel/include/hip/hip_act_and_mul_hip.cuh b/sgl-kernel/include/hip/hip_act_and_mul_hip.cuh new file mode 100644 index 000000000000..4b71d9213b20 --- /dev/null +++ b/sgl-kernel/include/hip/hip_act_and_mul_hip.cuh @@ -0,0 +1,89 @@ +// !!! This is a file automatically generated by hipify!!! +#include "hip/hip_runtime.h" +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#pragma once + +#include "utils_hip.h" + +#define kBitsToLoad 128 +#define kBytesToLoad (kBitsToLoad / 8) + +// Adapted from +// [flashinfer::activation::act_and_mul_kernel](https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/include/flashinfer/activation.cuh#L29) + +namespace sgl_hip { +namespace activation { + +template +__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) { + constexpr uint32_t vec_size = kBytesToLoad / sizeof(T); + const int64_t token_idx = blockIdx.x; + const int64_t thread_idx = threadIdx.x; + const int64_t stride = blockDim.x; + const int64_t offset = token_idx * 2 * d; + +#pragma unroll 1 + for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) { + sgl_hip::vec_t x_vec, y_vec, out_vec; + x_vec.cast_load(input + offset + idx * vec_size); + y_vec.cast_load(input + offset + d + idx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + out_vec[i] = Activation(x_vec[i]) * y_vec[i]; + } + out_vec.cast_store(out + token_idx * d + idx * vec_size); + } + + const int64_t remaining_offset = d - d % (stride * vec_size); + // process the remaining elements +#pragma unroll 1 + for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) { + T x = input[offset + remaining_offset + idx], y = input[offset + remaining_offset + d + idx]; + out[token_idx * d + remaining_offset + idx] = Activation(x) * y; + } +} + +template +__global__ void act_only_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) { + constexpr uint32_t vec_size = kBytesToLoad / sizeof(T); + const int64_t token_idx = blockIdx.x; + const int64_t thread_idx = threadIdx.x; + const int64_t stride = blockDim.x; + const int64_t offset = token_idx * d; + +#pragma unroll 1 + for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) { + sgl_hip::vec_t x_vec, 
y_vec, out_vec; + x_vec.cast_load(input + offset + idx * vec_size); +#pragma unroll + for (uint32_t i = 0; i < vec_size; ++i) { + out_vec[i] = Activation(x_vec[i]); + } + out_vec.cast_store(out + token_idx * d + idx * vec_size); + } + + const int64_t remaining_offset = d - d % (stride * vec_size); + // process the remaining elements +#pragma unroll 1 + for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) { + T x = input[offset + remaining_offset + idx]; + out[token_idx * d + remaining_offset + idx] = Activation(x); + } +} + +} // namespace activation +} // namespace sgl_hip diff --git a/sgl-kernel/include/utils_hip.h b/sgl-kernel/include/utils_hip.h new file mode 100644 index 000000000000..1c7f2f254867 --- /dev/null +++ b/sgl-kernel/include/utils_hip.h @@ -0,0 +1,470 @@ +// !!! This is a file automatically generated by hipify!!! +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include +#include +#include + +#ifdef USE_ROCM +#include +#endif + +#ifdef USE_ROCM +// Adapted from flashinfer-rocm [PR#491](https://github.com/flashinfer-ai/flashinfer/pull/491) +#define _DISPATCH_CASE_F16(c_type, ...) \ + case at::ScalarType::Half: { \ + using c_type = __half; \ + return __VA_ARGS__(); \ + } + +#define _DISPATCH_CASE_BF16(c_type, ...) 
\ + case at::ScalarType::BFloat16: { \ + using c_type = __hip_bfloat16; \ + return __VA_ARGS__(); \ + } +#endif // USE_ROCM + +#ifndef USE_ROCM +// Adapt from FlashInfer +#ifdef FLASHINFER_ENABLE_F16 +#define _DISPATCH_CASE_F16(c_type, ...) \ + case at::ScalarType::Half: { \ + using c_type = nv_half; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_F16(c_type, ...) +#endif // FLASHINFER_ENABLE_F16 + +#ifdef FLASHINFER_ENABLE_BF16 +#define _DISPATCH_CASE_BF16(c_type, ...) \ + case at::ScalarType::BFloat16: { \ + using c_type = nv_bfloat16; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_BF16(c_type, ...) +#endif // FLASHINFER_ENABLE_BF16 + +#ifdef FLASHINFER_ENABLE_FP8_E4M3 +#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = __nv_fp8_e4m3; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) +#endif // FLASHINFER_ENABLE_FP8_E4M3 + +#ifdef FLASHINFER_ENABLE_FP8_E5M2 +#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = __nv_fp8_e5m2; \ + return __VA_ARGS__(); \ + } +#else +#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) +#endif // FLASHINFER_ENABLE_FP8_E5M2 + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) 
\ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_SWITCH(var_name, cond, ...) \ + [&]() -> bool { \ + switch (cond) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \ + [&]() -> bool { \ + switch (pack_u16(cond1, cond2)) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" << int(cond1) << ", " \ + << int(cond2) << ")"; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_CASE(case_expr, case_var, ...) \ + case case_expr: { \ + constexpr auto case_var = case_expr; \ + return __VA_ARGS__(); \ + } + +#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \ + case pack_u16(case_expr1, case_expr2): { \ + constexpr auto case_var1 = case_expr1; \ + constexpr auto case_var2 = case_expr2; \ + return __VA_ARGS__(); \ + } + +#define DISPATCH_BOOL(expr, const_expr, ...) 
\ + [&]() -> bool { \ + if (expr) { \ + constexpr bool const_expr = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + return __VA_ARGS__(); \ + } \ + }() + +inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \ + TORCH_CHECK( \ + num_qo_heads % num_kv_heads == 0, \ + "num_qo_heads(", \ + num_qo_heads, \ + ") must be divisible by num_kv_heads(", \ + num_kv_heads, \ + ")") + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. 
", a, " vs ", b) + +inline bool is_float8_tensor(const at::Tensor& tensor) { + return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || tensor.scalar_type() == at::ScalarType::Float8_e5m2; +} +#endif // USE_ROCM + +struct cuda_error : public std::runtime_error { + /** + * @brief Constructs a `cuda_error` object with the given `message`. + * + * @param message The error char array used to construct `cuda_error` + */ + cuda_error(const char* message) : std::runtime_error(message) {} + /** + * @brief Constructs a `cuda_error` object with the given `message` string. + * + * @param message The `std::string` used to construct `cuda_error` + */ + cuda_error(std::string const& message) : cuda_error{message.c_str()} {} +}; + +#define CHECK_CUDA_SUCCESS(cmd) \ + do { \ + hipError_t e = cmd; \ + if (e != hipSuccess) { \ + std::stringstream _message; \ + auto s = hipGetErrorString(e); \ + _message << std::string(s) + "\n" << __FILE__ << ':' << __LINE__; \ + throw cuda_error(_message.str()); \ + } \ + } while (0) + +#define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CUDA_INPUT(x) \ + CHECK_IS_CUDA(x); \ + CHECK_IS_CONTIGUOUS(x) + +inline int getSMVersion() { + int device{-1}; + CHECK_CUDA_SUCCESS(hipGetDevice(&device)); + int sm_major = 0; + int sm_minor = 0; + CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_major, hipDeviceAttributeComputeCapabilityMajor, device)); + CHECK_CUDA_SUCCESS(hipDeviceGetAttribute(&sm_minor, hipDeviceAttributeComputeCapabilityMinor, device)); + return sm_major * 10 + sm_minor; +} + +inline bool isDeviceType(const std::string& device_type) { + int deviceCount; + CHECK_CUDA_SUCCESS(hipGetDeviceCount(&deviceCount)); + + int device_id = -1; + if (deviceCount >= 1) { + CHECK_CUDA_SUCCESS(hipGetDevice(&device_id)); + } else { + return false; + } + + hipDeviceProp_t prop; + 
CHECK_CUDA_SUCCESS(hipGetDeviceProperties(&prop, device_id)); + if (device_type == std::string(prop.name)) { + return true; + } + return false; +} + +inline bool getBoolEnv(char const* name) { + char const* env = std::getenv(name); + return env && env[0] == '1' && env[1] == '\0'; +} + +inline bool getEnvEnablePDL() { + static std::once_flag flag; + static bool enablePDL = false; + std::call_once(flag, [&]() { + if (getSMVersion() >= 90) { + // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1` + enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL"); + } + }); + return enablePDL; +} + +// SGLANG_SHFL_XOR_* adapted from https://github.com/vllm-project/vllm/blob/v0.7.3/csrc/cuda_compat.h#L19-L28 +#ifndef USE_ROCM +#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor_sync((mask), (var), (lane_mask)) +#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor_sync((mask), (var), (lane_mask), (width)) +#else +#define SGLANG_SHFL_XOR_SYNC(mask, var, lane_mask) __shfl_xor((var), (lane_mask)) +#define SGLANG_SHFL_XOR_SYNC_WIDTH(mask, var, lane_mask, width) __shfl_xor((var), (lane_mask), (width)) +#endif + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float: { \ + using c_type = float; \ + return __VA_ARGS__(); \ + } \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_CASE_INTEGRAL_TYPES(...) 
\ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define DISPATCH_CASE_FLOAT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define DISPATCH_FLOAT_TYPES(TYPE, NAME, ...) AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_FLOAT_TYPES(__VA_ARGS__)) + +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) + +#ifndef USE_ROCM +#define WARP_SIZE 32 +#else +#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) +#define WARP_SIZE 64 +#else +#define WARP_SIZE 32 +#endif +#endif + +#ifdef USE_ROCM + +#include "hip/hip_math_def.h" +#include "hip/hip_vec_dtypes.h" + +#else + +template +__device__ __forceinline__ float castToFloat(srcDtype val) { + return static_cast(val); +} + +template +__device__ __forceinline__ dstDtype castFromFloat(float val) { + return static_cast(val); +} + +#endif + +// add FP8 support +#ifndef USE_ROCM +#include +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); +#else // USE_ROCM +#if HIP_FP8_TYPE_FNUZ +#include +using FP8_TYPE = c10::Float8_e4m3fnuz; +constexpr auto FP8_E4M3_MAX = 224.0f; +#else +#if HIP_FP8_TYPE_E4M3 +#include +using FP8_TYPE = c10::Float8_e4m3fn; +C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); +#else +#error "fp8 is not supported in this processor (arch < gfx942)." 
+#endif // HIP_FP8_TYPE_E4M3 +#endif // HIP_FP8_TYPE_FNUZ +#endif // USE_ROCM + +#define FULL_MASK 0xffffffff + +__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) { +#ifndef USE_ROCM + float old; + old = (value >= 0) ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); + return old; +#else + int* addr_as_i = (int*)addr; + int old = *addr_as_i, assumed; + do { + assumed = old; + old = atomicCAS(addr_as_i, assumed, __float_as_int(fmaxf(value, __int_as_float(assumed)))); + } while (assumed != old); + return __int_as_float(old); +#endif +} + +__device__ __forceinline__ float warpReduceMax(float value) { + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 16)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 8)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 4)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 2)); + value = fmaxf(value, __shfl_xor_sync(FULL_MASK, value, 1)); + return value; +} + +__device__ __forceinline__ float blockReduceMax(float value) { + static __shared__ float warpLevelMaxs[WARP_SIZE]; + const int laneId = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + + value = warpReduceMax(value); + + if (laneId == 0) warpLevelMaxs[warpId] = value; + __syncthreads(); + + value = (threadIdx.x < blockDim.x / WARP_SIZE) ? warpLevelMaxs[laneId] : 0; + if (warpId == 0) value = warpReduceMax(value); + + return value; +} + +// Pads to a multiple of `alignment` rows. 
+inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment = 4, bool is_column_major = false) { + int64_t rows = tensor.size(0); + int64_t cols = tensor.size(1); + int64_t pad_rows = (alignment - (rows % alignment)) % alignment; // Compute padding size + + if (pad_rows == 0) { + return tensor; // Already aligned + } + + torch::Tensor padding = torch::zeros({pad_rows, cols}, tensor.options()); + torch::Tensor tensor_padded = torch::cat({tensor, padding}, 0); // Pad along rows + + // Ensure column-major layout + if (is_column_major) { + return tensor_padded.t().contiguous().t(); + } + return tensor_padded; +} + +// Get the next power of 2 of a number +inline uint32_t next_pow2(uint32_t x) noexcept { + if (x <= 1) return 1; + return 1u << (32 - __builtin_clz(x - 1)); +} + +/* + * LDG Support + */ +#ifndef USE_ROCM +#define SGLANG_LDG(arg) __ldg(arg) +#else +#define SGLANG_LDG(arg) *(arg) +#endif diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 3e918ace89b4..40ca884a7e5f 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -1,17 +1,15 @@ [build-system] requires = [ + "setuptools>=75.0", "scikit-build-core>=0.10", "torch>=2.8.0", "wheel", ] -build-backend = "scikit_build_core.build" +build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" version = "0.3.21" -authors = [ - { name="Yineng Zhang", email="me@zhyncs.com" }, -] description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.10" @@ -32,11 +30,3 @@ exclude = [ "dist*", "tests*", ] - -[tool.scikit-build] -cmake.build-type = "Release" -minimum-version = "build-system.requires" - -wheel.py-api = "cp310" -wheel.license-files = [] -wheel.packages = ["python/sgl_kernel"] diff --git a/sgl-kernel/pyproject_rocm.toml b/sgl-kernel/pyproject_rocm.toml deleted file mode 100644 index 40ca884a7e5f..000000000000 --- a/sgl-kernel/pyproject_rocm.toml +++ /dev/null @@ -1,32 +0,0 @@ -[build-system] -requires = [ - 
"setuptools>=75.0", - "scikit-build-core>=0.10", - "torch>=2.8.0", - "wheel", -] -build-backend = "setuptools.build_meta" - -[project] -name = "sgl-kernel" -version = "0.3.21" -description = "Kernel Library for SGLang" -readme = "README.md" -requires-python = ">=3.10" -license = { file = "LICENSE" } -classifiers = [ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: Apache Software License", - "Environment :: GPU :: NVIDIA CUDA" -] -dependencies = [] - -[project.urls] -"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel" -"Bug Tracker" = "https://github.com/sgl-project/sglang/issues" - -[tool.wheel] -exclude = [ - "dist*", - "tests*", -] diff --git a/sgl-model-gateway/LICENSE b/sgl-model-gateway/LICENSE deleted file mode 120000 index ea5b60640b01..000000000000 --- a/sgl-model-gateway/LICENSE +++ /dev/null @@ -1 +0,0 @@ -../LICENSE \ No newline at end of file diff --git a/sgl-model-gateway/LICENSE b/sgl-model-gateway/LICENSE new file mode 100644 index 000000000000..9c422689c8f5 --- /dev/null +++ b/sgl-model-gateway/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. 
+ + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright 2023-2024 SGLang Team + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License.