diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile new file mode 100644 index 000000000000..0c061cd1871a --- /dev/null +++ b/.devcontainer/Dockerfile @@ -0,0 +1,35 @@ +From lmsysorg/sglang:dev + +# Create non-root user with specified UID and GID +# NOTE: Replace with your own UID and GID. This is a workaround from https://github.com/microsoft/vscode-remote-release/issues/49#issuecomment-489060908. +ARG HOST_UID=1003 +ARG HOST_GID=1003 +RUN groupadd -g $HOST_GID devuser && \ + useradd -m -u $HOST_UID -g $HOST_GID -s /bin/zsh devuser + +# Give devuser sudo access +RUN apt-get update && apt-get install -y sudo && \ + echo "devuser ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/devuser && \ + rm -rf /var/lib/apt/lists/* && \ + apt-get clean + +# Set up oh-my-zsh for devuser +RUN cp -r /root/.oh-my-zsh /home/devuser/.oh-my-zsh && \ + cp /root/.zshrc /home/devuser/.zshrc && \ + cp /root/.vimrc /home/devuser/.vimrc && \ + cp /root/.tmux.conf /home/devuser/.tmux.conf && \ + sed -i 's|/root/.oh-my-zsh|/home/devuser/.oh-my-zsh|g' /home/devuser/.zshrc && \ + chown -R devuser:devuser /home/devuser/ + +# Set workspace directory and ownership +WORKDIR /sgl-workspace/sglang +RUN chown -R devuser:devuser /sgl-workspace + +# Switch to devuser +USER devuser + +# Install uv +RUN curl -LsSf https://astral.sh/uv/install.sh | sh + +# Install rust +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index aee285898644..5767aa2631a4 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,8 +1,9 @@ { "name": "sglang", "build": { - "dockerfile": "../docker/Dockerfile.dev" + "dockerfile": "Dockerfile" }, + "remoteUser": "devuser", "customizations": { "vscode": { "extensions": [ @@ -15,6 +16,9 @@ ] } }, - "workspaceFolder": "/sgl-workspace/sglang", - "forwardPorts": [] + "forwardPorts": [], + "runArgs": [ + "--gpus", + "all" + ] } diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 8e73727a0936..5493c4201c41 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -10,6 +10,7 @@ ## Checklist -- [ ] Format your code according to the [Contributor Guide](https://github.com/sgl-project/sglang/blob/main/docs/references/contribution_guide.md). -- [ ] Add unit tests as outlined in the [Contributor Guide](https://github.com/sgl-project/sglang/blob/main/docs/references/contribution_guide.md). -- [ ] Update documentation as needed, including docstrings or example tutorials. +- [ ] Format your code according to the [Code Formatting with Pre-Commit](https://docs.sglang.ai/references/contribution_guide.html#code-formatting-with-pre-commit). +- [ ] Add unit tests as outlined in the [Running Unit Tests](https://docs.sglang.ai/references/contribution_guide.html#running-unit-tests-adding-to-ci). +- [ ] Update documentation / docstrings / example tutorials as needed, according to [Writing Documentation](https://docs.sglang.ai/references/contribution_guide.html#writing-documentation-running-docs-ci). +- [ ] Provide throughput / latency benchmark results and accuracy evaluation results as needed, according to [Benchmark and Profiling](https://docs.sglang.ai/references/benchmark_and_profiling.html). 
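The new `.devcontainer/Dockerfile` above hardcodes `HOST_UID`/`HOST_GID` of 1003 and notes that they should be replaced with your own IDs. As a minimal sketch (assuming you build the image by hand rather than letting VS Code drive the devcontainer build), the same image can be built with build args that match the host user and run with the GPU access that `devcontainer.json` requests via `--gpus all`:

```bash
# Build the dev image with UID/GID matching the host user, so files created
# inside the container keep the right ownership on the mounted workspace.
docker build \
  --build-arg HOST_UID="$(id -u)" \
  --build-arg HOST_GID="$(id -g)" \
  -t sglang-devcontainer \
  -f .devcontainer/Dockerfile .

# Run it with GPU access, mirroring the "--gpus all" runArgs in devcontainer.json.
docker run --rm -it --gpus all \
  -v "$(pwd)":/sgl-workspace/sglang \
  sglang-devcontainer zsh
```

The tag `sglang-devcontainer` is only an illustrative name; VS Code users would instead edit the two `ARG` defaults in the Dockerfile directly, as the NOTE in that file suggests.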
diff --git a/.github/workflows/pr-test-rust.yml b/.github/workflows/pr-test-rust.yml index 928d0efa5b34..277ddef774e9 100644 --- a/.github/workflows/pr-test-rust.yml +++ b/.github/workflows/pr-test-rust.yml @@ -40,7 +40,7 @@ jobs: cd sgl-router/ cargo test - e2e-rust: + e2e-python: if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' runs-on: 2-gpu-runner steps: @@ -65,7 +65,7 @@ jobs: python3 run_suite.py finish: - needs: [unit-test-rust, e2e-rust] + needs: [unit-test-rust, e2e-python] runs-on: ubuntu-latest steps: - name: Finish diff --git a/.github/workflows/pr-test-sgl-kernel.yml b/.github/workflows/pr-test-sgl-kernel.yml new file mode 100644 index 000000000000..65e452369617 --- /dev/null +++ b/.github/workflows/pr-test-sgl-kernel.yml @@ -0,0 +1,99 @@ +name: PR Test (sgl-kernel) + +on: + push: + branches: [ main ] + paths: + - "sgl-kernel/**" + pull_request: + branches: [ main ] + paths: + - "sgl-kernel/**" + workflow_dispatch: + +concurrency: + group: pr-test-sgl-kernel-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Check clang-format + uses: DoozyX/clang-format-lint-action@v0.18.1 + with: + source: sgl-kernel + extensions: h,c,cpp,hpp,cu,cuh,cc + clangFormatVersion: 16 + style: file + + build-wheels: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9'] + cuda-version: ['12.4'] + + steps: + - uses: actions/checkout@v4 + with: + submodules: 'recursive' + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }} + run: | + cd sgl-kernel + chmod +x ./build.sh + ./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}" + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }} + path: sgl-kernel/dist/* + + unit-test: + if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + needs: build-wheels + runs-on: 1-gpu-runner + steps: + - uses: actions/checkout@v4 + + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-* + + - name: Install + run: | + pip3 install torch==2.5.1 && pip3 install pytest && pip3 install vllm==0.6.4.post1 + pip3 uninstall sgl-kernel -y || true + pip3 install sgl-kernel/dist/*whl --force-reinstall --no-deps + pip3 list | grep sgl-kernel + + - name: Run test + timeout-minutes: 30 + run: | + cd sgl-kernel + find tests -name "test_*.py" | xargs -n 1 python3 + + - name: Uninstall dependencies + run: | + pip3 uninstall sgl-kernel -y + + finish: + needs: [unit-test, lint] + runs-on: ubuntu-latest + steps: + - name: Finish + run: echo "This is an empty step to ensure that all jobs are completed." 
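The new `pr-test-sgl-kernel.yml` workflow above builds a wheel through `sgl-kernel/build.sh` for one Python/CUDA combination and then runs each test file individually. A rough local equivalent of those CI steps, assuming a CUDA-capable machine with the repository checked out, looks like this (this is a sketch of the workflow's commands, not an official build recipe):

```bash
# Mirror the build-wheels job: submodules are required for the 3rdparty sources.
git submodule update --init --recursive

cd sgl-kernel
chmod +x ./build.sh
./build.sh "3.9" "12.4"   # versions taken from the workflow's matrix

# Mirror the unit-test job: install the dependencies pinned in the workflow,
# install the freshly built wheel, and run the tests one file at a time.
pip3 install torch==2.5.1 pytest vllm==0.6.4.post1
pip3 uninstall sgl-kernel -y || true
pip3 install dist/*.whl --force-reinstall --no-deps
find tests -name "test_*.py" | xargs -n 1 python3
```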
diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index f1c7871debb2..487dfb6612ba 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -29,7 +29,7 @@ concurrency: jobs: unit-test-frontend: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -43,16 +43,18 @@ jobs: - name: Run test timeout-minutes: 10 + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} run: | cd test/lang python3 run_suite.py --suite per-commit unit-test-backend-1-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner strategy: matrix: - range: [0-6, 6-16, 16-23, 23-30, 30-100] + range: [0-6, 6-15, 15-22, 22-32, 32-40, 40-100] steps: - name: Checkout code uses: actions/checkout@v3 @@ -75,7 +77,7 @@ jobs: unit-test-backend-2-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 2-gpu-runner steps: - name: Checkout code @@ -87,18 +89,16 @@ jobs: run: | bash scripts/ci_install_dependency.sh - - name: Evaluate data parallelism accuracy (DP=2) + - name: Test data parallelism (DP=2) timeout-minutes: 10 run: | cd test/srt python3 test_data_parallelism.py - - name: Evaluate MLA accuracy (TP=2) + - name: Test data parallelism attention (DP=2) timeout-minutes: 10 run: | cd test/srt - python3 test_mla.py - python3 test_mla_fp8.py python3 test_dp_attention.py - name: Test update weights from distributed @@ -107,14 +107,14 @@ jobs: cd test/srt python3 test_update_weights_from_distributed.py - - name: Evaluate MoE EP accuracy (TP=2) + - name: Test expert parallelism (EP=2) timeout-minutes: 10 run: | cd test/srt python3 test_moe_ep.py performance-test-1-gpu-part-1: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -130,7 +130,7 @@ jobs: timeout-minutes: 10 run: | cd test/srt - python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_default + python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_bs1 - name: Benchmark online latency timeout-minutes: 10 @@ -150,8 +150,15 @@ jobs: cd test/srt python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size + - name: Benchmark online latency (EAGLE) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_bench_serving.TestBenchServing.test_online_latency_eagle + + performance-test-1-gpu-part-2: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -182,7 +189,7 @@ jobs: python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_default_fp8 performance-test-2-gpu: - if: github.repository == 'sgl-project/sglang' || 
github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 2-gpu-runner steps: - name: Checkout code @@ -198,7 +205,13 @@ jobs: timeout-minutes: 10 run: | cd test/srt - python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_default + python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_moe_tp2_bs1 + + - name: Benchmark single latency + torch.compile (TP=2) + timeout-minutes: 10 + run: | + cd test/srt + python3 -m unittest test_bench_one_batch.TestBenchOneBatch.test_torch_compile_tp2_bs1 - name: Benchmark offline throughput (TP=2) timeout-minutes: 10 @@ -212,8 +225,9 @@ jobs: cd test/srt python3 -m unittest test_bench_serving.TestBenchServing.test_moe_offline_throughput_without_radix_cache + accuracy-test-1-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 1-gpu-runner steps: - name: Checkout code @@ -237,7 +251,7 @@ jobs: accuracy-test-2-gpu: - if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request' + if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && github.event.pull_request.draft == false runs-on: 2-gpu-runner steps: - name: Checkout code diff --git a/.github/workflows/release-docker-amd.yml b/.github/workflows/release-docker-amd.yml index 866cc5fa5209..228eecdb9c5c 100644 --- a/.github/workflows/release-docker-amd.yml +++ b/.github/workflows/release-docker-amd.yml @@ -10,19 +10,27 @@ on: jobs: publish: if: github.repository == 'sgl-project/sglang' - runs-on: docker-builder-amd + runs-on: amd-docker environment: 'prod' strategy: matrix: rocm_version: ['6.2.0'] build_type: ['all', 'srt'] steps: - - name: Delete huge unnecessary tools folder - run: rm -rf /opt/hostedtoolcache - - name: Checkout repository uses: actions/checkout@v3 + - name: Free disk space + uses: jlumbroso/free-disk-space@main + with: + tool-cache: false + docker-images: false + android: true + dotnet: true + haskell: true + large-packages: true + swap-storage: false + - name: Login to Docker Hub uses: docker/login-action@v2 with: diff --git a/.github/workflows/release-docs.yml b/.github/workflows/release-docs.yml index ab2129e3721a..37db70c7c4be 100644 --- a/.github/workflows/release-docs.yml +++ b/.github/workflows/release-docs.yml @@ -39,7 +39,7 @@ jobs: - name: Execute notebooks and push to documents env: - GITHUB_TOKEN: ${{ secrets.PAT_TOKEN }} + GITHUB_TOKEN: ${{ secrets.DOCUMENTATION_PAT_TOKEN }} run: | cd docs make clean @@ -49,7 +49,7 @@ jobs: cd _build/html git clone https://$GITHUB_TOKEN@github.com/sgl-project/sgl-project.github.io.git ../sgl-project.github.io --depth 1 - rm -rf ../sgl-project.github.io/* + find ../sgl-project.github.io/ -mindepth 1 -not -path "../sgl-project.github.io/.git*" -not -name CNAME -not -name ".jekyll" -not -name ".nojekyll" -delete cp -r * ../sgl-project.github.io cp ../../README.md ../sgl-project.github.io/README.md cd ../sgl-project.github.io diff --git a/.github/workflows/release-pypi-kernel.yml b/.github/workflows/release-pypi-kernel.yml index f046538a6fad..af34c8423ce7 100644 --- a/.github/workflows/release-pypi-kernel.yml +++ b/.github/workflows/release-pypi-kernel.yml @@ -5,7 +5,7 @@ on: branches: - main paths: - - sgl-kernel/pyproject.toml + - sgl-kernel/version.py workflow_dispatch: 
concurrency: @@ -14,14 +14,17 @@ concurrency: jobs: build-wheels: + if: github.repository == 'sgl-project/sglang' runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] - cuda-version: ['12.1'] + python-version: ['3.9'] + cuda-version: ['12.4'] steps: - uses: actions/checkout@v4 + with: + submodules: 'recursive' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 diff --git a/.github/workflows/release-pypi-router.yml b/.github/workflows/release-pypi-router.yml index df20c211cb3e..547522e8aa6c 100644 --- a/.github/workflows/release-pypi-router.yml +++ b/.github/workflows/release-pypi-router.yml @@ -7,7 +7,7 @@ on: branches: - main paths: - - sglang-router/pyproject.toml + - sgl-router/pyproject.toml workflow_dispatch: jobs: @@ -26,9 +26,9 @@ jobs: with: path: sglang-repo - - name: Move sglang-router folder to root and delete sglang-repo + - name: Move sgl-router folder to root and delete sglang-repo run: | - mv sglang-repo/sglang-router/* . + mv sglang-repo/sgl-router/* . rm -rf sglang-repo ls -alt @@ -69,9 +69,9 @@ jobs: with: path: sglang-repo - - name: Move sglang-router folder to root, copy the license file, and delete sglang-repo + - name: Move sgl-router folder to root, copy the license file, and delete sglang-repo run: | - mv sglang-repo/sglang-router/* . + mv sglang-repo/sgl-router/* . mv sglang-repo/LICENSE . rm -rf sglang-repo ls -alt @@ -84,6 +84,7 @@ jobs: - name: Build SDist run: | pip install build + python -m pip install -U packaging python -m build --sdist - uses: actions/upload-artifact@v4 diff --git a/.github/workflows/release-whl-kernel.yml b/.github/workflows/release-whl-kernel.yml new file mode 100644 index 000000000000..70c451778fa4 --- /dev/null +++ b/.github/workflows/release-whl-kernel.yml @@ -0,0 +1,92 @@ +name: Release SGLang Kernel Wheel (cu118) + +on: + workflow_dispatch: + inputs: + tag_name: + type: string + push: + branches: + - main + paths: + - sgl-kernel/version.py + +jobs: + build-wheels: + if: github.repository == 'sgl-project/sglang' + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9'] + cuda-version: ['11.8'] + + steps: + - uses: actions/checkout@v4 + with: + submodules: 'recursive' + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Build wheels for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }} + run: | + cd sgl-kernel + chmod +x ./build.sh + ./build.sh "${{ matrix.python-version }}" "${{ matrix.cuda-version }}" + + - name: Upload artifacts + uses: actions/upload-artifact@v4 + with: + name: wheel-python${{ matrix.python-version }}-cuda${{ matrix.cuda-version }} + path: sgl-kernel/dist/* + + release: + needs: build-wheels + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Download artifacts + uses: actions/download-artifact@v4 + with: + path: sgl-kernel/dist/ + merge-multiple: true + pattern: wheel-* + + - name: Set tag name + id: set_tag_name + run: | + if [ -z "${{ inputs.tag_name }}" ]; then + TAG_NAME="v$(cat sgl-kernel/version.py | cut -d'"' -f2)" + echo "tag_name=$TAG_NAME" >> $GITHUB_OUTPUT + else + echo "tag_name=${{ inputs.tag_name }}" >> $GITHUB_OUTPUT + fi + + - name: Release + uses: softprops/action-gh-release@v2 + with: + tag_name: ${{ steps.set_tag_name.outputs.tag_name }} + repository: sgl-project/whl + token: ${{ secrets.WHL_TOKEN }} + files: | + sgl-kernel/dist/* + + - name: Clone wheel index + 
run: git clone https://oauth2:${WHL_TOKEN}@github.com/sgl-project/whl.git sgl-whl + env: + WHL_TOKEN: ${{ secrets.WHL_TOKEN }} + + - name: Update wheel index + run: python3 scripts/update_kernel_whl_index.py + + - name: Push wheel index + run: | + cd sgl-whl + git config --local user.name "github-actions[bot]" + git config --local user.email "41898282+github-actions[bot]@users.noreply.github.com" + git add -A + git commit -m "update whl index" + git push diff --git a/.gitignore b/.gitignore index 73fd52992c28..75e29fac373a 100644 --- a/.gitignore +++ b/.gitignore @@ -222,3 +222,8 @@ work_dirs/ compile_commands.json *.iml + +# VSCode +.vscode + +1 diff --git a/.gitmodules b/.gitmodules index 3a14f6297a3a..97f3421449d3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,12 @@ [submodule "sgl-kernel/3rdparty/cutlass"] path = sgl-kernel/3rdparty/cutlass url = https://github.com/NVIDIA/cutlass.git +[submodule "sgl-kernel/3rdparty/cccl"] + path = sgl-kernel/3rdparty/cccl + url = https://github.com/NVIDIA/cccl.git +[submodule "sgl-kernel/3rdparty/flashinfer"] + path = sgl-kernel/3rdparty/flashinfer + url = https://github.com/flashinfer-ai/flashinfer.git +[submodule "sgl-kernel/3rdparty/turbomind"] + path = sgl-kernel/3rdparty/turbomind + url = https://github.com/InternLM/turbomind diff --git a/3rdparty/amd/profiling/PROFILING.md b/3rdparty/amd/profiling/PROFILING.md index 79bc75b503bc..7e15ec844f2b 100644 --- a/3rdparty/amd/profiling/PROFILING.md +++ b/3rdparty/amd/profiling/PROFILING.md @@ -336,7 +336,7 @@ loadTracer.sh python3 -m sglang.launch_server \ --model-path /sgl-workspace/sglang/dummy_grok1 \ --tokenizer-path Xenova/grok-1-tokenizer \ --load-format dummy \ - --quant fp8 \ + --quantization fp8 \ --tp 8 \ --port 30000 \ --disable-radix-cache 2>&1 | tee "$LOGFILE" diff --git a/3rdparty/amd/profiling/server.sh b/3rdparty/amd/profiling/server.sh index aa574f64c940..f877e6c7acd4 100755 --- a/3rdparty/amd/profiling/server.sh +++ b/3rdparty/amd/profiling/server.sh @@ -14,7 +14,7 @@ loadTracer.sh python3 -m sglang.launch_server \ --model-path /sgl-workspace/sglang/dummy_grok1 \ --tokenizer-path Xenova/grok-1-tokenizer \ --load-format dummy \ - --quant fp8 \ + --quantization fp8 \ --tp 8 \ --port 30000 \ --disable-radix-cache 2>&1 | tee "$LOGFILE" diff --git a/3rdparty/amd/tuning/TUNING.md b/3rdparty/amd/tuning/TUNING.md index a38a16d4f7a5..0638041c9743 100644 --- a/3rdparty/amd/tuning/TUNING.md +++ b/3rdparty/amd/tuning/TUNING.md @@ -104,7 +104,7 @@ To maximize moe kernel efficiency, need to use below scripts to find out the bes ```bash #Tuning -#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quant fp" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run). 
+#for example, we have one case like this "python3 -m sglang.bench_latency --model dummy_grok1/ --load-format dummy --tokenizer-path Xenova/grok-1-tokenizer --tp 8 --batch-size 32 --input 1024 --output 8 --attention-backend triton --sampling-backend pytorch --quantization fp8" to run, it defined batch-size 32 input lenth 1024 and output length 8, from "--batch" in moe view point, the prefill batch is 32*1024 = 32768, the decode batch is 32*1(only one output token generated in each run). #so we can tune decode moe use below command python benchmark_moe_rocm.py --model grok1 --tp-size 8 --dtype float8 --batch "32" # and use this command to tune prefill moe diff --git a/3rdparty/amd/tuning/benchmark_moe_rocm.py b/3rdparty/amd/tuning/benchmark_moe_rocm.py index a3f26e8e5028..5aff8c0d664e 100644 --- a/3rdparty/amd/tuning/benchmark_moe_rocm.py +++ b/3rdparty/amd/tuning/benchmark_moe_rocm.py @@ -10,7 +10,10 @@ from tqdm import tqdm from transformers import AutoConfig -from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe, get_config_file_name +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( + fused_moe, + get_config_file_name, +) padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 diff --git a/README.md b/README.md index 97ad1e935c68..1165826c559b 100644 --- a/README.md +++ b/README.md @@ -13,9 +13,9 @@ -------------------------------------------------------------------------------- | [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) -| [**Documentation**](https://sgl-project.github.io/) -| [**Join Slack**](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2um0ad92q-LkU19KQTxCGzlCgRiOiQEw) -| [**Join Bi-Weekly Development Meeting**](https://docs.google.com/document/d/1xEow4eIM152xNcRxqZz9VEcOiTQo8-CEuuQ5qTmkt-E/edit?usp=sharing) +| [**Documentation**](https://docs.sglang.ai/) +| [**Join Slack**](https://slack.sglang.ai/) +| [**Join Bi-Weekly Development Meeting**](https://meeting.sglang.ai/) | [**Slides**](https://github.com/sgl-project/sgl-learning-materials?tab=readme-ov-file#slides) | ## News @@ -45,11 +45,11 @@ The core features include: - **Active Community**: SGLang is open-source and backed by an active community with industry adoption. 
## Getting Started -- [Install SGLang](https://sgl-project.github.io/start/install.html) -- [Quick Start](https://sgl-project.github.io/start/send_request.html) -- [Backend Tutorial](https://sgl-project.github.io/backend/openai_api_completions.html) -- [Frontend Tutorial](https://sgl-project.github.io/frontend/frontend.html) -- [Contribution Guide](https://sgl-project.github.io/references/contribution_guide.html) +- [Install SGLang](https://docs.sglang.ai/start/install.html) +- [Quick Start](https://docs.sglang.ai/start/send_request.html) +- [Backend Tutorial](https://docs.sglang.ai/backend/openai_api_completions.html) +- [Frontend Tutorial](https://docs.sglang.ai/frontend/frontend.html) +- [Contribution Guide](https://docs.sglang.ai/references/contribution_guide.html) ## Benchmark and Performance Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-sglang-llama3/), [v0.3 blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/), [v0.4 blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/) @@ -58,8 +58,7 @@ Learn more in the release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487) ## Adoption and Sponsorship -The project is supported by (alphabetically): AMD, Baseten, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. +The project is supported by (alphabetically): AMD, Baseten, Cursor, DataCrunch, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, LMSYS.org, Meituan, NVIDIA, RunPod, Stanford, UC Berkeley, UCLA, xAI, 01.AI. ## Acknowledgment and Citation -We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). -Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful. +We learned the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql). Please cite the paper, [SGLang: Efficient Execution of Structured Language Model Programs](https://arxiv.org/abs/2312.07104), if you find the project useful. 
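The Getting Started links above point at the install and quick-start docs on docs.sglang.ai. As a minimal sketch of that flow (the model name and prompt are illustrative placeholders; see the linked pages for the authoritative instructions):

```bash
# Install SGLang; the extra wheel index is the one referenced elsewhere in this PR.
pip install "sglang[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer

# Launch the server (port 30000 is the port used throughout this changeset).
python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --port 30000

# Send one request through the OpenAI-compatible endpoint.
curl http://127.0.0.1:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{"model": "meta-llama/Llama-3.1-8B-Instruct", "messages": [{"role": "user", "content": "List three uses of SGLang."}]}'
```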
diff --git a/benchmark/blog_v0_2/405b_sglang.sh b/benchmark/blog_v0_2/405b_sglang.sh index 4e3372ae8c70..491853782805 100644 --- a/benchmark/blog_v0_2/405b_sglang.sh +++ b/benchmark/blog_v0_2/405b_sglang.sh @@ -6,7 +6,7 @@ # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json # Launch sglang -# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87 +# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quantization fp8 --disable-radix --mem-frac 0.87 # offline python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11 diff --git a/benchmark/deepseek_v3/README.md b/benchmark/deepseek_v3/README.md index 15cf0b26a244..ea972831a368 100644 --- a/benchmark/deepseek_v3/README.md +++ b/benchmark/deepseek_v3/README.md @@ -4,6 +4,8 @@ The SGLang and DeepSeek teams collaborated to get DeepSeek V3 FP8 running on NVI Special thanks to Meituan's Search & Recommend Platform Team and Baseten's Model Performance Team for implementing the model, and DataCrunch for providing GPU resources. +For optimizations made on the DeepSeek series models regarding SGLang, please refer to [DeepSeek Model Optimizations in SGLang](https://docs.sglang.ai/references/deepseek.html). + ## Hardware Recommendation - 8 x NVIDIA H200 GPUs @@ -29,7 +31,7 @@ For high QPS scenarios, add the `--enable-dp-attention` argument to boost throug ### Using pip ```bash # Installation -pip install "sglang[all]>=0.4.1.post3" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer +pip install "sglang[all]>=0.4.1.post5" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer # Launch python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code @@ -37,7 +39,7 @@ python3 -m sglang.launch_server --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-r For high QPS scenarios, add the `--enable-dp-attention` argument to boost throughput. -### Example with OpenAI API +### Example: Sending requests with OpenAI API ```python3 import openai @@ -56,8 +58,11 @@ response = client.chat.completions.create( ) print(response) ``` -### Example serving with 2 H20*8 -For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`, and the second node's IP is `10.0.0.2`. + +### Example: Serving with two H20*8 nodes +For example, there are two H20 nodes, each with 8 GPUs. The first node's IP is `10.0.0.1`, and the second node's IP is `10.0.0.2`. Please **use the first node's IP** for both commands. + +If the command fails, try setting the `GLOO_SOCKET_IFNAME` parameter. For more information, see [Common Environment Variables](https://pytorch.org/docs/stable/distributed.html#common-environment-variables). ```bash # node 1 @@ -69,7 +74,7 @@ python -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 16 --di If you have two H100 nodes, the usage is similar to the aforementioned H20. -### Example serving with Docker two H200*8 nodes +### Example: Serving with two H200*8 nodes and docker There are two H200 nodes, each with 8 GPUs. The first node's IP is `192.168.114.10`, and the second node's IP is `192.168.114.11`. Configure the endpoint to expose it to another Docker container using `--host 0.0.0.0` and `--port 40000`, and set up communications with `--dist-init-addr 192.168.114.10:20000`. 
A single H200 with 8 devices can run DeepSeek V3, the dual H200 setup is just to demonstrate multi-node usage. diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py index 9fe9b79baaf8..f01734f0afb0 100644 --- a/benchmark/gsm8k/bench_sglang.py +++ b/benchmark/gsm8k/bench_sglang.py @@ -1,6 +1,7 @@ import argparse import ast import json +import os import re import time @@ -46,9 +47,11 @@ def main(args): set_default_backend(select_sglang_backend(args)) # Read data + data_path = args.data_path url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" - filename = download_and_cache_file(url) - lines = list(read_jsonl(filename)) + if not os.path.isfile(data_path): + data_path = download_and_cache_file(url) + lines = list(read_jsonl(data_path)) # Construct prompts num_questions = args.num_questions diff --git a/benchmark/hellaswag/bench_sglang.py b/benchmark/hellaswag/bench_sglang.py index f09d7256da93..798521f9766d 100644 --- a/benchmark/hellaswag/bench_sglang.py +++ b/benchmark/hellaswag/bench_sglang.py @@ -1,5 +1,6 @@ import argparse import json +import os import time import numpy as np @@ -31,9 +32,11 @@ def main(args): set_default_backend(select_sglang_backend(args)) # Read data + data_path = args.data_path url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl" - filename = download_and_cache_file(url) - lines = list(read_jsonl(filename)) + if not os.path.isfile(data_path): + data_path = download_and_cache_file(url) + lines = list(read_jsonl(data_path)) # Construct prompts num_questions = args.num_questions diff --git a/benchmark/hicache/bench_multiturn.py b/benchmark/hicache/bench_multiturn.py new file mode 100644 index 000000000000..ab34c33da44e --- /dev/null +++ b/benchmark/hicache/bench_multiturn.py @@ -0,0 +1,334 @@ +import argparse +import asyncio +import json +import queue +import random +import threading +import time +from typing import Optional + +import aiohttp +import requests +from tqdm.asyncio import tqdm + +from sglang.bench_serving import ( + RequestFuncOutput, + get_tokenizer, + remove_prefix, + sample_random_requests, +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Script to benchmark concurrent requests to a server." 
+ ) + parser.add_argument( + "--num-clients", + type=int, + default=200, + help="Number of concurrent clients", + ) + parser.add_argument( + "--request-length", + type=int, + default=512, + help="Length of each new request", + ) + parser.add_argument( + "--output-length", + type=int, + default=64, + help="Length of each output", + ) + parser.add_argument( + "--num-rounds", + type=int, + default=5, + help="Number of rounds per client", + ) + parser.add_argument( + "--distribution", + type=str, + default="poisson", + choices=["poisson", "uniform"], + help="Distribution type for request intervals (poisson or uniform)", + ) + parser.add_argument( + "--request-rate", + type=float, + default=1.0, + help="Average number of requests per second", + ) + parser.add_argument( + "--host", + type=str, + default="localhost", + help="Server hostname or IP (default: localhost)", + ) + parser.add_argument( + "--port", + type=int, + default=30000, + help="Server port (default: 30000)", + ) + parser.add_argument( + "--model", + type=str, + default="meta-llama/Llama-3.1-8B-Instruct", + help="model path compatible with Hugging Face Transformers", + ) + return parser.parse_args() + + +async def async_request_sglang_generate( + payload, + url, + pbar: Optional[tqdm] = None, +): + """ + Sends a streaming request to the server. Gathers text token-by-token. + """ + async with aiohttp.ClientSession() as session: + headers = {} + generated_text = "" + ttft = 0.0 + st = time.perf_counter() + most_recent_timestamp = st + output = RequestFuncOutput() + + try: + async with session.post(url=url, json=payload, headers=headers) as response: + if response.status == 200: + async for chunk_bytes in response.content: + chunk_bytes = chunk_bytes.strip() + if not chunk_bytes: + continue + + chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ") + latency = time.perf_counter() - st + if chunk == "[DONE]": + pass + else: + data = json.loads(chunk) + + if data["text"]: + timestamp = time.perf_counter() + # First token + if ttft == 0.0: + ttft = time.perf_counter() - st + output.ttft = ttft + + # Decoding phase + else: + output.itl.append(timestamp - most_recent_timestamp) + + most_recent_timestamp = timestamp + generated_text = data["text"] + + output.generated_text = generated_text + output.success = True + output.latency = latency + else: + output.error = response.reason or "" + output.success = False + except Exception as e: + output.success = False + output.error = str(e) + print(f"Request failed: {e}") + + if pbar: + pbar.update(1) + return output + + +def gen_payload(prompt, output_len): + payload = { + "text": prompt, + "sampling_params": { + "temperature": 0.0, + "max_new_tokens": output_len, + "ignore_eos": True, + }, + "stream": True, + "lora_path": "", + "return_logprob": False, + "logprob_start_len": -1, + } + return payload + + +class ReadyQueue: + """ + Thread-safe queue that can pop requests in different orders based on given policy. 
+ """ + + def __init__(self, init_requests=None, policy="random"): + self.lock = threading.Lock() + self.requests = init_requests or [] + self.policy = policy + + def append(self, item): + with self.lock: + self.requests.append(item) + + def pop(self): + with self.lock: + if not self.requests: + return None + if self.policy == "random": + index = random.randrange(len(self.requests)) + return self.requests.pop(index) + elif self.policy == "fifo": + return self.requests.pop(0) + else: + # todo, varying thinking time of clients + raise ValueError(f"{self.policy} not implemented") + + +class WorkloadGenerator: + def __init__(self, args): + # Construct the base URL for requests + self.url = f"http://{args.host}:{args.port}/generate" + + self.tokenizer = get_tokenizer(args.model) + self.distribution = args.distribution + self.request_rate = args.request_rate + self.start_time = None + self.finished_time = None + + self.candidate_inputs = sample_random_requests( + input_len=args.request_length, + output_len=args.output_length, + num_prompts=args.num_clients * args.num_rounds, + range_ratio=1.0, + tokenizer=self.tokenizer, + dataset_path="", + ) + self.candidate_inputs = [i[0] for i in self.candidate_inputs] + + init_requests = [ + (i, gen_payload(self.candidate_inputs[i], args.output_length)) + for i in range(args.num_clients) + ] + self.client_records = { + i: {"round": 0, "history": init_requests[i][1]["text"]} + for i in range(args.num_clients) + } + self.ready_queue = ReadyQueue(init_requests=init_requests) + self.candidate_inputs = self.candidate_inputs[args.num_clients :] + + self.response_queue = queue.Queue() + self.pbar = tqdm(total=args.num_clients * args.num_rounds) + self.performance_metrics = {"ttft": [], "latency": []} + + async def handle_request(self, item): + try: + client_id, payload = item + response = await async_request_sglang_generate(payload, self.url, self.pbar) + if self.pbar.n == self.pbar.total: + self.finished_time = time.time() + self.response_queue.put((client_id, response)) + except Exception as e: + print(f"Request failed: {e}") + + def request_sender(self): + async def request_loop(): + while True: + # Calculate Poisson-distributed wait time + if self.distribution == "poisson": + sleep_time = random.expovariate(self.request_rate) + elif self.distribution == "uniform": + avg_interval = ( + 1.0 / self.request_rate if self.request_rate > 0 else 1.0 + ) + sleep_time = random.uniform(0, 2 * avg_interval) + else: + raise ValueError("Invalid distribution type") + await asyncio.sleep(sleep_time) # Wait before sending the next request + + new_request = self.ready_queue.pop() + # Submit async request + if new_request: + asyncio.create_task(self.handle_request(new_request)) + else: + if self.pbar.n == self.pbar.total: + break + + # Create and run the event loop for asynchronous requests + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(request_loop()) + loop.close() + + def response_handler(self): + while True: + try: + client_id, response = self.response_queue.get( + timeout=10 + ) # Block until response is available + if not response.success: + raise ValueError(f"Request failed with error: {response.error}") + self.client_records[client_id]["history"] += response.generated_text + self.client_records[client_id]["round"] += 1 + self.performance_metrics["ttft"].append(response.ttft) + self.performance_metrics["latency"].append(response.latency) + + if self.client_records[client_id]["round"] < args.num_rounds: + 
self.client_records[client_id][ + "history" + ] += self.candidate_inputs.pop() + self.ready_queue.append( + ( + client_id, + gen_payload( + self.client_records[client_id]["history"], + args.output_length, + ), + ) + ) + except queue.Empty: + if self.pbar.n == self.pbar.total: + break + + def run(self): + request_thread = threading.Thread(target=self.request_sender, daemon=True) + response_thread = threading.Thread(target=self.response_handler, daemon=True) + + self.start_time = time.time() + request_thread.start() + response_thread.start() + + request_thread.join() + response_thread.join() + + self.pbar.close() + print("All requests completed.") + print("Performance metrics summary:") + print( + f" Total requests: {len(self.performance_metrics['ttft'])} at {self.request_rate} requests per second" + ) + print( + f" Average TTFT: {sum(self.performance_metrics['ttft']) / len(self.performance_metrics['ttft']):.2f}" + ) + print( + f" Median TTFT: {sorted(self.performance_metrics['ttft'])[len(self.performance_metrics['ttft']) // 2]:.2f}" + ) + print( + f" Average latency: {sum(self.performance_metrics['latency']) / len(self.performance_metrics['latency']):.2f}" + ) + print( + f" Median latency: {sorted(self.performance_metrics['latency'])[len(self.performance_metrics['latency']) // 2]:.2f}" + ) + throughput = self.pbar.total / (self.finished_time - self.start_time) + print(f"Throughput: {throughput:.2f} requests per second") + + +if __name__ == "__main__": + args = parse_args() + flush_cache_url = f"http://{args.host}:{args.port}/flush_cache" + + for request_rate in range(1, 41, 2): + args.request_rate = request_rate + requests.post(flush_cache_url) + WorkloadGenerator(args).run() diff --git a/benchmark/kernels/fused_moe_triton/benchmark_moe_align_blocks.py b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py similarity index 90% rename from benchmark/kernels/fused_moe_triton/benchmark_moe_align_blocks.py rename to benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py index 92547ea95ae2..e2c4d8d35067 100644 --- a/benchmark/kernels/fused_moe_triton/benchmark_moe_align_blocks.py +++ b/benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py @@ -1,12 +1,13 @@ import argparse import itertools -import time import torch import triton import triton.language as tl from sgl_kernel import moe_align_block_size +USE_RANDOM_PERM = False + def ceil_div(a, b): return (a + b - 1) // b @@ -141,8 +142,13 @@ def moe_align_block_size_triton( def calculate_diff(batch_size, seq_len): num_experts = 256 block_size = 128 - topk_ids = torch.randint( - 0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda" + topk = 8 + + topk_ids = torch.stack( + [ + torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] + for _ in range(batch_size * seq_len) + ] ) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) @@ -169,7 +175,7 @@ def calculate_diff(batch_size, seq_len): expert_ids_triton = torch.empty_like(expert_ids_cuda) num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) - # 运行两个实现 + # compare the performance of cuda and triton implementation moe_align_block_size( topk_ids, num_experts, @@ -206,6 +212,15 @@ def calculate_diff(batch_size, seq_len): configs = list(itertools.product(batch_size_range, seq_length_range)) +def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: + topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda") + for i in 
range(num_tokens): + topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[ + :topk + ] + return topk_ids + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size", "seq_len"], @@ -222,9 +237,18 @@ def calculate_diff(batch_size, seq_len): def benchmark(batch_size, seq_len, provider): num_experts = 256 block_size = 128 - topk_ids = torch.randint( - 0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda" - ) + topk = 8 + + if USE_RANDOM_PERM: + topk_ids = get_topk_ids(batch_size * seq_len, num_experts, topk) + else: + topk_ids = torch.randint( + 0, + num_experts, + (batch_size * seq_len, topk), + dtype=torch.int32, + device="cuda", + ) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) sorted_ids = torch.empty( diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py new file mode 100644 index 000000000000..57fbcfddf2c1 --- /dev/null +++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_decode.py @@ -0,0 +1,577 @@ +import itertools +import math +import os +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange +from sgl_kernel import lightning_attention_decode as sgl_lightning_attention_decode + + +@triton.jit +def _decode_kernel( + Q, + K, + V, + KV, + Out, + S, + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + d_original: tl.constexpr, + e: tl.constexpr, + e_original: tl.constexpr, +): + off_bh = tl.program_id(0) + off_h = off_bh % h + + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + kv_offset = off_bh * d * e + + s = tl.load(S + off_h) + ratio = tl.exp(-s) + + d_idx = tl.arange(0, d) + e_idx = tl.arange(0, e) + + # Create masks for original dimensions + d_mask = d_idx < d_original + e_mask = e_idx < e_original + + # Load with masking + q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0) + k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0) + v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0) + + # Load KV with 2D masking + kv = tl.load( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + mask=(d_mask[:, None] & e_mask[None, :]), + other=0.0, + ) + + # Compute outer product using element-wise operations + k_v_prod = k[:, None] * v[None, :] + kv = ratio * kv + k_v_prod + + # Store KV with 2D masking + tl.store( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + kv.to(KV.dtype.element_ty), + mask=(d_mask[:, None] & e_mask[None, :]), + ) + + # Compute matrix-vector multiplication using element-wise operations and reduction + o = tl.sum(q[:, None] * kv, axis=0) + + # Store output with masking + tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask) + + +def lightning_attn_decode(q, k, v, kv, s): + """Triton implementation of Lightning Attention decode operation""" + b, h, n, d = q.shape + e = v.shape[-1] + assert n == 1, "Sequence length must be 1 in decode mode" + + # Get padded dimensions (power of 2) + d_padded = next_power_of_2(d) + e_padded = next_power_of_2(e) + + # Create output tensor (padded) + o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + + # Create padded tensors without actually padding the data + q_padded = torch.empty(b, h, n, d_padded, 
dtype=q.dtype, device=q.device) + k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device) + v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + kv_padded = torch.empty( + b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device + ) + + # Copy data to padded tensors + q_padded[..., :d] = q + k_padded[..., :d] = k + v_padded[..., :e] = v + kv_padded[..., :d, :e] = kv + + # Launch kernel + grid = (b * h, 1) + _decode_kernel[grid]( + q_padded, + k_padded, + v_padded, + kv_padded, + o_padded, + s, + b=b, + h=h, + n=n, + d=d_padded, + d_original=d, + e=e_padded, + e_original=e, + ) + + # Get unpadded outputs + o = o_padded[..., :e] + kv_out = kv_padded[..., :d, :e] + + return o, kv_out + + +def next_power_of_2(n): + return 2 ** (int(math.ceil(math.log(n, 2)))) + + +class MiniMaxText01LightningAttention(nn.Module): + def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs): + super().__init__() + if config is None: + config = type("Config", (), kwargs) + + bias = False + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + + self.out_proj = nn.Linear( + self.head_dim * self.num_heads, self.hidden_size, bias=bias + ) + self.act = get_activation_fn(config.hidden_act) + self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads) + + self.qkv_proj = nn.Linear( + self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias + ) + self.output_gate = nn.Linear( + self.hidden_size, self.head_dim * self.num_heads, bias=bias + ) + + # for inference only + self.offset = 0 + self.layer_idx = layer_idx + + def forward( + self, + hidden_states, + attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m) + output_attentions: bool = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + slope_rate: Optional[torch.Tensor] = None, + **kwargs, + ): + if (not self.training) and (not do_eval): + return self.inference( + hidden_states, + attn_mask, + output_attentions, + past_key_value, + use_cache, + slope_rate, + ) + + def inference( + self, + x, + attn_mask: Optional[torch.Tensor] = None, # (b, n) + output_attentions: bool = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1) + ): + # x: b n d + b, n, d = x.shape + # linear map + qkv = self.act(self.qkv_proj(x)) + new_shape = qkv.size()[:-1] + (self.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3) + q = q.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d] + k = k.transpose(1, 2) # [b, n, h, d] -> [b, h, n, d] + v = v.transpose(1, 2) # [b, n, h, d] -> [b, h, n, e] + + self.offset += 1 + ratio = torch.exp(-slope_rate) # [h, 1, 1] + + # decode mode + kv = past_key_value # [b, h, d, e] + output = [] + for i in range(n): + # kv: [b, h, d, e] + # ratio: [h, 1, 1] + # k: [b, h, n, d] + # v: [b, h, n, e] + # k[:, :, i : i + 1]: [b, h, 1, d] + # v[:, :, i : i + 1]: [b, h, 1, e] + # ratio * kv: [b, h, d, e] + # torch.einsum( + # "... n d, ... n e -> ... d e", + # k[:, :, i : i + 1], + # v[:, :, i : i + 1], + # ) + # [b, h, d, e] + [b, h, d, e] -> [b, h, d, e] + kv = ratio * kv + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + # q[:, :, i : i + 1]: [b, h, 1, d] + # kv.to(q.dtype): [b, h, d, e] + # torch.einsum( + # "... n e, ... e d -> ... 
n d", q[:, :, i : i + 1], kv.to(q.dtype) + # ) + # [b, h, 1, d] * [b, h, d, e] -> [b, h, 1, e] + qkv = torch.einsum( + "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype) + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + # reshape + output = rearrange(output, "b h n d -> b n (h d)") + # normalize + output = self.norm(output) + # gate + output = F.sigmoid(self.output_gate(x)) * output + # outproj + output = self.out_proj(output) + + attn_weights = None + + return output, attn_weights, kv + + +def get_activation_fn(activation): + if activation == "gelu": + return F.gelu + elif activation == "relu": + return F.relu + elif activation == "elu": + return F.elu + elif activation == "sigmoid": + return F.sigmoid + elif activation == "exp": + + def f(x): + with torch.no_grad(): + x_max = torch.max(x, dim=-1, keepdims=True).values + y = torch.exp(x - x_max) + return y + + return f + elif activation == "leak": + return F.leaky_relu + elif activation == "1+elu": + + def f(x): + return 1 + F.elu(x) + + return f + elif activation == "2+elu": + + def f(x): + return 2 + F.elu(x) + + return f + elif activation == "silu" or activation == "swish": + return F.silu + elif activation == "sine": + return torch.sin + else: + return lambda x: x + + +class MiniMaxText01RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MiniMaxText01RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +def test_lightning_attention_implementations(model_params): + torch.manual_seed(42) + + batch_size = 64 + seq_len = 1 + dtype = torch.bfloat16 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + hidden_states = torch.randn( + batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device + ) + + attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device) + + slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device) + + model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device) + model_attn.eval() + + d = model_params["head_dim"] + past_kv = torch.randn( + batch_size, + model_params["num_attention_heads"], + d, + d, + device=device, + ) + with torch.no_grad(): + model_output, _, new_kv = model_attn.inference( + hidden_states, + attn_mask=attention_mask, + slope_rate=slope_rate, + past_key_value=past_kv, + ) + + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + past_kv = past_kv.contiguous() + slope_rate = slope_rate.contiguous() + + # Test Triton implementation + triton_output, triton_new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate) + triton_output = triton_output.transpose(1, 2).contiguous() + triton_output = triton_output.view(batch_size, seq_len, -1) + triton_output = model_attn.norm(triton_output) + triton_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * 
triton_output + triton_output = model_attn.out_proj(triton_output) + + # Test SGL implementation + sgl_output = torch.empty_like(v) + sgl_new_kv = torch.empty_like(past_kv) + sgl_lightning_attention_decode(q, k, v, past_kv, slope_rate, sgl_output, sgl_new_kv) + + sgl_output = sgl_output.transpose(1, 2).contiguous() + sgl_output = sgl_output.view(batch_size, seq_len, -1) + sgl_output = model_attn.norm(sgl_output) + sgl_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * sgl_output + sgl_output = model_attn.out_proj(sgl_output) + + # Verify Triton implementation results + torch.testing.assert_close( + model_output, + triton_output, + rtol=1e-3, + atol=1e-2, + msg="Triton lightning attention implementation produces different output results", + ) + torch.testing.assert_close( + new_kv, + triton_new_kv, + rtol=1e-3, + atol=1e-2, + msg="Triton lightning attention implementation produces different kv results", + ) + + # Verify SGL implementation results + torch.testing.assert_close( + model_output, + sgl_output, + rtol=1e-3, + atol=1e-2, + msg="SGL lightning attention implementation produces different output results", + ) + torch.testing.assert_close( + new_kv, + sgl_new_kv, + rtol=1e-3, + atol=1e-2, + msg="SGL lightning attention implementation produces different kv results", + ) + + print("✅ All implementations match") + + +def _build_slope_tensor(n_attention_heads: int): + def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2(n) + else: + closest_power_of_2 = 2 ** math.floor(math.log2(n)) + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + slopes = torch.tensor(get_slopes(n_attention_heads)).reshape( + n_attention_heads, 1, 1 + ) + return slopes + + +def get_benchmark(): + batch_size_range = [i for i in range(1, 33)] # max 32 + seq_length_range = [1] # decode mode sequence length is fixed to 1 + configs = list(itertools.product(batch_size_range, seq_length_range)) + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["Original", "Triton", "SGL"], + line_names=[ + "Original PyTorch Implementation", + "Triton Implementation", + "SGL Implementation", + ], + styles=[("blue", "-"), ("green", "-"), ("red", "-")], + ylabel="us", + plot_name="lightning-attention-decode-performance", + args={}, + ) + ) + def benchmark(batch_size, seq_len, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + + params = { + "hidden_size": 6144, + "num_attention_heads": 64, + "head_dim": 96, + "hidden_act": "gelu", + } + + hidden_states = torch.randn( + batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device + ) + + attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device) + + slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device) + model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device) + model_attn.eval() + + d = params["head_dim"] + past_kv = torch.randn( + batch_size, + params["num_attention_heads"], + d, + d, + device=device, + ) + + quantiles = [0.5, 0.2, 0.8] + if provider == "Original": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: model_attn.inference( + hidden_states, + attn_mask=attention_mask, + slope_rate=slope_rate, + 
past_key_value=past_kv, + ), + quantiles=quantiles, + ) + elif provider == "Triton": + + def run_triton(): + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + output, new_kv = lightning_attn_decode(q, k, v, past_kv, slope_rate) + output = output.transpose(1, 2).contiguous() + output = output.view(batch_size, seq_len, -1) + output = model_attn.norm(output) + output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output + return model_attn.out_proj(output) + + ms, min_ms, max_ms = triton.testing.do_bench( + run_triton, + quantiles=quantiles, + ) + else: # SGL + + def run_sgl(): + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + + output = torch.empty_like(v) + new_kv = torch.empty_like(past_kv) + sgl_lightning_attention_decode( + q, k, v, past_kv, slope_rate, output, new_kv + ) + + output = output.transpose(1, 2).contiguous() + output = output.view(batch_size, seq_len, -1) + output = model_attn.norm(output) + output = torch.sigmoid(model_attn.output_gate(hidden_states)) * output + return model_attn.out_proj(output) + + ms, min_ms, max_ms = triton.testing.do_bench( + run_sgl, + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/lightning_attention_decode/", + help="Path to save lightning attention decode benchmark results", + ) + args = parser.parse_args() + + params = { + "hidden_size": 6144, + "num_attention_heads": 64, + "head_dim": 96, + "hidden_act": "silu", + } + # Run correctness test first + # Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json + test_lightning_attention_implementations(params) + + # Run performance benchmark + benchmark = get_benchmark() + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py new file mode 100644 index 000000000000..cd298487b590 --- /dev/null +++ b/benchmark/kernels/minmax-text-01-lightning_attention/benchmark_lightning_attention_prefill.py @@ -0,0 +1,603 @@ +import itertools +import math +import os +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +import triton +import triton.language as tl +from einops import rearrange + + +# Adapted from https://github.com/OpenNLPLab/lightning-attention/blob/main/lightning_attn/ops/triton/lightning_attn2.py +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Out, + S, # log lambda + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + e: tl.constexpr, + BLOCK: tl.constexpr, + NUM_BLOCK: tl.constexpr, + BLOCK_MODEL: tl.constexpr, +): + ##### get offset + off_bh = tl.program_id(0) + off_h = off_bh % h + off_e = tl.program_id(1) + qk_offset = off_bh * n * d + v_offset = off_bh * n * 
e + o_offset = off_bh * n * e + # channel offset + e_offset = off_e * BLOCK_MODEL + + ##### get block ptr + Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :] + K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None] + V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :] + O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :] + S_block_ptr = S + off_h + + ##### init diag decay(Lambda); q, k decay; kv + s = tl.load(S_block_ptr) + # q, k decay + off_block = tl.arange( + 0, BLOCK + ) # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent + q_decay = tl.exp(-s.to(tl.float32) * off_block[:, None]) + k_trans_decay = tl.exp(-s.to(tl.float32) * (BLOCK - off_block[None, :])) + block_decay = tl.exp(-s.to(tl.float32) * BLOCK) + # diag decay + index = off_block[:, None] - off_block[None, :] + s_index = s * index + s_index = tl.where(index >= 0, -s_index, float("-inf")) + diag_decay = tl.exp(s_index) + kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32) + + ##### compute + for i in range(NUM_BLOCK): + # load + q = tl.load( + Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0 + ).to(tl.float32) + k_trans = tl.load( + K_trans_block_ptr + off_block[None, :] * d, + mask=off_block[None, :] < n, + other=0.0, + ).to(tl.float32) + v = tl.load( + V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0 + ).to(tl.float32) + + # compute + qk = tl.dot(q, k_trans) * diag_decay + o_intra = tl.dot(qk, v) + o_inter = tl.dot(q, kv) * q_decay + o = o_intra + o_inter + + # save and update + tl.store( + O_block_ptr + off_block[:, None] * e, + o.to(O_block_ptr.dtype.element_ty), + mask=off_block[:, None] < n, + ) + kv = block_decay * kv + tl.dot(k_trans * k_trans_decay, v) + off_block += BLOCK + + +def lightning_attn2(q, k, v, s): + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + s = s.contiguous() + + b, h, n, d = q.shape + e = v.shape[-1] + + # Pad d to next power of 2 + d_padded = next_power_of_2(d) + if d_padded != d: + q_padded = F.pad(q, (0, d_padded - d)) + k_padded = F.pad(k, (0, d_padded - d)) + else: + q_padded = q + k_padded = k + + # Pad e to next power of 2 + e_padded = next_power_of_2(e) + if e_padded != e: + v_padded = F.pad(v, (0, e_padded - e)) + else: + v_padded = v + + o_padded = torch.empty((b, h, n, e_padded), dtype=q.dtype, device=q.device) + + BLOCK = 64 + NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK) + # parallel over channel + BLOCK_MODEL = min(triton.next_power_of_2(e_padded), 32) + grid = (b * h, triton.cdiv(e_padded, BLOCK_MODEL)) + + _fwd_kernel[grid]( + q_padded, + k_padded, + v_padded, + o_padded, + s, + b, + h, + n, + d_padded, + e_padded, + BLOCK=BLOCK, + NUM_BLOCK=NUM_BLOCK, + BLOCK_MODEL=BLOCK_MODEL, + ) + + # Remove padding from output + if e_padded != e: + o = o_padded[..., :e] + else: + o = o_padded + + return o + + +def is_support(dim): + return 16 % dim + + +def next_power_of_2(n): + return 2 ** (int(math.ceil(math.log(n, 2)))) + + +def lightning_attn_func(q, k, v, s): + b, h, n, d = q.shape + e = v.shape[-1] + assert is_support(d) and is_support(e) + + # pad v's feature dim to power of 2 + e_pad = next_power_of_2(e) + need_pad = e_pad != e + if need_pad: + v = F.pad(v, (0, e_pad - e)) + + if d > 128: + # split over head + if 64 % d: + m = 64 + elif 32 % d: + m = 32 + elif 16 % d: + m = 16 + arr = [m * i for i in range(d // m + 1)] + if arr[-1] != d: + arr.append(d) + n = len(arr) + o = 0 + for i in range(n - 1): + start = arr[i] + end = arr[i + 
1] + q1 = q[..., start:end] + k1 = k[..., start:end] + o += lightning_attn2(q1, k1, v, s) + else: + o = lightning_attn2(q, k, v, s) + + if need_pad: + o = o[:, :, :, :e] + + return o + + +debug = eval(os.environ.get("debug", default="False")) + +BLOCK = 256 + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->MiniMaxText01 +class MiniMaxText01RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MiniMaxText01RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py +def get_activation_fn(activation): + if debug: + logger.info(f"activation: {activation}") + if activation == "gelu": + return F.gelu + elif activation == "relu": + return F.relu + elif activation == "elu": + return F.elu + elif activation == "sigmoid": + return F.sigmoid + elif activation == "exp": + + def f(x): + with torch.no_grad(): + x_max = torch.max(x, dim=-1, keepdims=True).values + y = torch.exp(x - x_max) + + return y + + return f + elif activation == "leak": + return F.leaky_relu + elif activation == "1+elu": + + def f(x): + return 1 + F.elu(x) + + return f + elif activation == "2+elu": + + def f(x): + return 2 + F.elu(x) + + return f + elif activation == "silu" or activation == "swish": + return F.silu + elif activation == "sine": + return torch.sin + else: + logger.info(f"activation: does not support {activation}, use Identity!!!") + return lambda x: x + + +# Copied from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/modeling_minimax_text_01.py +class MiniMaxText01LightningAttention(nn.Module): + def __init__(self, config=None, layer_idx: Optional[int] = None, **kwargs): + super().__init__() + if config is None: + config = type("Config", (), kwargs) + + bias = False + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads) + + self.out_proj = nn.Linear( + self.head_dim * self.num_heads, self.hidden_size, bias=bias + ) + self.act = get_activation_fn(config.hidden_act) + self.norm = MiniMaxText01RMSNorm(self.head_dim * self.num_heads) + + self.qkv_proj = nn.Linear( + self.hidden_size, 3 * self.head_dim * self.num_heads, bias=bias + ) + self.output_gate = nn.Linear( + self.hidden_size, self.head_dim * self.num_heads, bias=bias + ) + + # for inference only + self.offset = 0 + self.layer_idx = layer_idx + + def forward( + self, + hidden_states, + attn_mask: Optional[torch.Tensor] = None, # (b, h, n, m) + output_attentions: bool = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + slope_rate: Optional[torch.Tensor] = None, + **kwargs, + ): + if (not self.training) and (not do_eval): + return self.inference( + hidden_states, + attn_mask, + output_attentions, + past_key_value, + use_cache, + slope_rate, + ) + + def inference( + self, + x, + attn_mask: Optional[torch.Tensor] = None, # (b, n) + output_attentions: bool = False, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + use_cache: bool = False, + 
slope_rate: Optional[torch.Tensor] = None, # (h, 1, 1) + ): + # x: b n d + b, n, d = x.shape + # linear map + qkv = self.act(self.qkv_proj(x)) + new_shape = qkv.size()[:-1] + (self.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [self.head_dim] * 3, dim=3) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if past_key_value is None: + self.offset = q.shape[-2] + else: + self.offset += 1 + + # for align with metaseq + ratio = torch.exp(-slope_rate) + + # only use for the first time + if past_key_value is None: + slope_rate = slope_rate.to(torch.float32) + if attn_mask is not None: + v = v.masked_fill( + (1 - attn_mask).unsqueeze(1).unsqueeze(-1).to(torch.bool), 0 + ) + NUM_BLOCK = (n + BLOCK - 1) // BLOCK + b, h, n, d = q.shape + e = v.shape[-1] + # other + array = torch.arange(BLOCK).to(q) + 1 + q_decay = torch.exp(-slope_rate * array.reshape(-1, 1)) + k_decay = torch.exp(-slope_rate * (BLOCK - array.reshape(-1, 1))) + index = array[:, None] - array[None, :] + s_index = ( + slope_rate + * index[ + None, + None, + ] + ) + s_index = torch.where(index >= 0, -s_index, float("-inf")) + diag_decay = torch.exp(s_index) + + kv = torch.zeros(b, h, d, e).to(torch.float32).to(q.device) + output = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device) + for i in range(NUM_BLOCK): + si = i * BLOCK + ei = min(si + BLOCK, n) + m = ei - si + qi = q[:, :, si:ei].contiguous() + ki = k[:, :, si:ei].contiguous() + vi = v[:, :, si:ei].contiguous() + qkv_none_diag = torch.matmul(qi * q_decay[:, :m], kv).to(torch.float32) + + # diag + qk = ( + torch.matmul(qi, ki.transpose(-1, -2)).to(torch.float32) + * diag_decay[:, :, :m, :m] + ) + qkv_diag = torch.matmul(qk, vi.to(torch.float32)) + block_decay = torch.exp(-slope_rate * m) + output[:, :, si:ei] = qkv_none_diag + qkv_diag + kv = block_decay * kv + torch.matmul( + (ki * k_decay[:, -m:]).transpose(-1, -2).to(vi.dtype), vi + ) + + else: + kv = past_key_value + output = [] + for i in range(n): + kv = ratio * kv + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... n d", q[:, :, i : i + 1], kv.to(q.dtype) + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + # reshape + output = rearrange(output, "b h n d -> b n (h d)") + # normalize + output = self.norm(output) + # gate + output = F.sigmoid(self.output_gate(x)) * output + # outproj + output = self.out_proj(output) + + attn_weights = None + + return output, attn_weights, kv + + +def _build_slope_tensor(n_attention_heads: int): + def get_slopes(n): + def get_slopes_power_of_2(n): + start = 2 ** (-(2 ** -(math.log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if math.log2(n).is_integer(): + return get_slopes_power_of_2( + n + ) # In the paper, we only train models that have 2^a heads for some a. This function has + else: # some good properties that only occur when the input is a power of 2. To maintain that even + closest_power_of_2 = 2 ** math.floor( + math.log2(n) + ) # when the number of heads is not a power of 2, we use this workaround. 
+ return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2] + ) + + # h, 1, 1 + slopes = torch.tensor(get_slopes(n_attention_heads)).reshape( + n_attention_heads, 1, 1 + ) + + return slopes + + +def test_lightning_attention_implementations(model_params): + torch.manual_seed(42) + + batch_size = 2 + seq_len = 1024 + dtype = torch.bfloat16 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + hidden_states = torch.randn( + batch_size, seq_len, model_params["hidden_size"], dtype=dtype, device=device + ) + + attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device) + + slope_rate = _build_slope_tensor(model_params["num_attention_heads"]).to(device) + + model_attn = MiniMaxText01LightningAttention(**model_params).to(dtype).to(device) + model_attn.eval() + + with torch.no_grad(): + model_output, _, _ = model_attn.inference( + hidden_states, attn_mask=attention_mask, slope_rate=slope_rate + ) + + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, dim=-1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + lib_output = lightning_attn_func(q, k, v, slope_rate) + lib_output = lib_output.transpose(1, 2).contiguous() + lib_output = lib_output.view(batch_size, seq_len, -1) + lib_output = model_attn.norm(lib_output) + lib_output = torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output + lib_output = model_attn.out_proj(lib_output) + + torch.testing.assert_close( + model_output, + lib_output, + rtol=1e-3, + atol=1e-2, + msg="Lightning attention implementations produce different results", + ) + + print("✅ Two implementations match") + + +def get_benchmark(): + batch_size_range = [2**i for i in range(0, 7)] # max 64 + seq_length_range = [256, 512, 1024, 2048, 4096] # max 4096 + configs = list(itertools.product(batch_size_range, seq_length_range)) + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "seq_len"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["MiniMax-Text-01", "OpenNLPLab"], + line_names=[ + "MiniMax-Text-01 Model Implementation", + "OpenNLPLab Library Implementation", + ], + styles=[("blue", "-"), ("green", "-")], + ylabel="us", + plot_name="lightning-attention-prefill-performance", + args={}, + ) + ) + def benchmark(batch_size, seq_len, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + + params = { + "hidden_size": 6144, + "num_attention_heads": 64, + "head_dim": 96, + "hidden_act": "gelu", + } + + hidden_states = torch.randn( + batch_size, seq_len, params["hidden_size"], dtype=dtype, device=device + ) + + attention_mask = torch.ones(batch_size, seq_len, dtype=dtype, device=device) + + slope_rate = _build_slope_tensor(params["num_attention_heads"]).to(device) + model_attn = MiniMaxText01LightningAttention(**params).to(dtype).to(device) + model_attn.eval() + + quantiles = [0.5, 0.2, 0.8] + if provider == "MiniMax-Text-01": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: model_attn.inference( + hidden_states, attn_mask=attention_mask, slope_rate=slope_rate + ), + quantiles=quantiles, + ) + else: + + def run_lib(): + qkv = model_attn.act(model_attn.qkv_proj(hidden_states)) + new_shape = qkv.size()[:-1] + (model_attn.num_heads, -1) + qkv = qkv.view(*new_shape) + q, k, v = torch.split(qkv, [model_attn.head_dim] * 3, 
dim=-1) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + lib_output = lightning_attn_func(q, k, v, slope_rate) + lib_output = lib_output.transpose(1, 2).contiguous() + lib_output = lib_output.view(batch_size, seq_len, -1) + lib_output = model_attn.norm(lib_output) + lib_output = ( + torch.sigmoid(model_attn.output_gate(hidden_states)) * lib_output + ) + return model_attn.out_proj(lib_output) + + ms, min_ms, max_ms = triton.testing.do_bench( + run_lib, + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + return benchmark + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/lightning_attention_prefill/", + help="Path to save lightning attention prefill benchmark results", + ) + args = parser.parse_args() + + # Run correctness test first + # Adapted from https://huggingface.co/MiniMaxAI/MiniMax-Text-01/blob/main/config.json + params = { + "hidden_size": 6144, + "num_attention_heads": 64, + "head_dim": 96, + "hidden_act": "silu", + } + test_lightning_attention_implementations(params) + + # Run performance benchmark + benchmark = get_benchmark() + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/benchmark/kernels/quantization/bench_int8_quant.py b/benchmark/kernels/quantization/bench_int8_quant.py new file mode 100644 index 000000000000..94b795690bfc --- /dev/null +++ b/benchmark/kernels/quantization/bench_int8_quant.py @@ -0,0 +1,94 @@ +import argparse + +import torch +import triton +from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant + +from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 + + +@torch.compile(backend="inductor") +def torch_int8_quant(x): + int8_max = torch.iinfo(torch.int8).max + + abs_max = x.abs().max(dim=-1, keepdim=True).values + scales = abs_max.to(torch.float32) / float(int8_max) + + q_x = (x / scales).round().to(torch.int8) + + return q_x, scales + + +def _test_accuracy_once(M, K, input_dtype, device): + x = torch.randn(M, K, dtype=input_dtype, device=device) * 5000 + out, scales, _ = vllm_scaled_int8_quant(x, symmetric=True) + out1, scales1 = per_token_quant_int8(x) + out2, scales2 = torch_int8_quant(x) + torch.testing.assert_close(out, out2, atol=1, rtol=0) + torch.testing.assert_close(out, out1, atol=1, rtol=0) + torch.testing.assert_close(scales, scales2) + torch.testing.assert_close(scales1, scales2) + print(f"M: {M}, K: {K}, type: {input_dtype} OK") + + +def test_accuracy(): + Ms = [1, 13, 128, 1024, 2048, 4096] + Ks = [512, 1024, 2048, 8192] + input_dtypes = [torch.float16, torch.bfloat16] + for M in Ms: + for K in Ks: + for input_dtype in input_dtypes: + _test_accuracy_once(M, K, input_dtype, "cuda") + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048], + x_log=False, + line_arg="provider", + line_vals=["vllm op", "triton", "torch.compile"], + line_names=["vllm op", "triton", "torch.compile"], + styles=[("blue", "-"), ("orange", "-"), ("red", "-")], + ylabel="ms", + plot_name="int8 per token quant", + args={}, + ) +) +def benchmark(batch_size, provider): + M, K = batch_size, 16384 + x = torch.randn(M, K, dtype=torch.float16, device="cuda") * 1000 + + quantiles = [0.5, 0.2, 0.8] + if provider == "vllm op": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: vllm_scaled_int8_quant(x, symmetric=True), + quantiles=quantiles, + ) + if 
provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: per_token_quant_int8(x), + quantiles=quantiles, + ) + if provider == "torch.compile": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch_int8_quant(x), + quantiles=quantiles, + ) + + return ms, min_ms, max_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./bench_int8_quant_res", + help="Path to save int8 quant benchmark results", + ) + args = parser.parse_args() + + test_accuracy() + + benchmark.run(print_data=True, show_plots=True, save_path=args.save_path) diff --git a/benchmark/tree_of_thought_deep/bench_sglang.py b/benchmark/tree_of_thought_deep/bench_sglang.py index b60f1f00f19c..bfb2a4113de5 100644 --- a/benchmark/tree_of_thought_deep/bench_sglang.py +++ b/benchmark/tree_of_thought_deep/bench_sglang.py @@ -103,6 +103,7 @@ def tree_search(s, question, num_branches): def main(args): lines = read_jsonl(args.data_path) + lines = list(lines) # Construct prompts num_branches = 2 diff --git a/docker/Dockerfile.dev b/docker/Dockerfile.dev index 70860d8ef886..5ff1fa7a51a0 100644 --- a/docker/Dockerfile.dev +++ b/docker/Dockerfile.dev @@ -18,6 +18,9 @@ RUN apt-get update && apt-get install -y \ silversearcher-ag \ cloc \ unzip \ + pkg-config \ + libssl-dev \ + bear \ && rm -rf /var/lib/apt/lists/* \ && apt-get clean @@ -39,7 +42,8 @@ RUN python3 -m pip install --no-cache-dir \ pytest \ black \ isort \ - icdiff + icdiff \ + pre-commit # Install diff-so-fancy RUN curl -LSso /usr/local/bin/diff-so-fancy https://github.com/so-fancy/diff-so-fancy/releases/download/v1.4.4/diff-so-fancy \ diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index 44b3f85b3516..2a55504e6122 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -1,8 +1,8 @@ # Usage (to build SGLang ROCm docker image): -# docker build --build-arg SGL_BRANCH=v0.4.1.post4 -t v0.4.1.post4-rocm620 -f Dockerfile.rocm . +# docker build --build-arg SGL_BRANCH=v0.4.1.post7 -t v0.4.1.post7-rocm620 -f Dockerfile.rocm . # default base image -ARG BASE_IMAGE="rocmshared/vllm-rocm:20241031-tuned" +ARG BASE_IMAGE="rocmshared/vllm-rocm:20250114-tuned-elementwise-layernorm" FROM $BASE_IMAGE AS base USER root @@ -16,6 +16,10 @@ ARG SGL_BRANCH=${SGL_DEFAULT} ARG TRITON_REPO="https://github.com/triton-lang/triton.git" ARG TRITON_COMMIT="845d75a" + +ARG ATER_REPO="https://github.com/HaiShaw/ater" +ARG CK_COMMITS="fa05ae" + RUN git clone ${SGL_REPO} \ && cd sglang \ && if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \ @@ -46,6 +50,11 @@ RUN git clone ${TRITON_REPO} \ && cd python \ && python3 setup.py install +RUN git clone ${ATER_REPO} \ + && cd ater \ + && git submodule update --init --recursive \ + && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop + # Performance environment variable. ENV HIP_FORCE_DEV_KERNARG=1 diff --git a/docs/backend/native_api.ipynb b/docs/backend/native_api.ipynb index 26758f7f9759..f6c10d745c5e 100644 --- a/docs/backend/native_api.ipynb +++ b/docs/backend/native_api.ipynb @@ -348,6 +348,76 @@ "source": [ "terminate_process(reward_process)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Skip Tokenizer and Detokenizer\n", + "\n", + "SGLang Runtime also supports skip tokenizer and detokenizer. This is useful in cases like integrating with RLHF workflow." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer_free_server_process = execute_shell_command(\n", + " \"\"\"\n", + "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010 --skip-tokenizer-init\n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(\"http://localhost:30010\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"meta-llama/Llama-3.2-1B-Instruct\")\n", + "\n", + "input_text = \"What is the capital of France?\"\n", + "\n", + "input_tokens = tokenizer.encode(input_text)\n", + "print_highlight(f\"Input Text: {input_text}\")\n", + "print_highlight(f\"Tokenized Input: {input_tokens}\")\n", + "\n", + "response = requests.post(\n", + " \"http://localhost:30010/generate\",\n", + " json={\n", + " \"input_ids\": input_tokens,\n", + " \"sampling_params\": {\n", + " \"temperature\": 0,\n", + " \"max_new_tokens\": 256,\n", + " \"stop_token_ids\": [tokenizer.eos_token_id],\n", + " },\n", + " \"stream\": False,\n", + " },\n", + ")\n", + "output = response.json()\n", + "output_tokens = output[\"token_ids\"]\n", + "\n", + "output_text = tokenizer.decode(output_tokens, skip_special_tokens=False)\n", + "print_highlight(f\"Tokenized Output: {output_tokens}\")\n", + "print_highlight(f\"Decoded Output: {output_text}\")\n", + "print_highlight(f\"Output Text: {output['meta_info']['finish_reason']}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(tokenizer_free_server_process)" + ] } ], "metadata": { diff --git a/docs/backend/openai_api_completions.ipynb b/docs/backend/openai_api_completions.ipynb index 8660da2f98fd..58b524108db1 100644 --- a/docs/backend/openai_api_completions.ipynb +++ b/docs/backend/openai_api_completions.ipynb @@ -41,10 +41,10 @@ ")\n", "\n", "server_process = execute_shell_command(\n", - " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\"\n", + " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30020 --host 0.0.0.0\"\n", ")\n", "\n", - "wait_for_server(\"http://localhost:30000\")" + "wait_for_server(\"http://localhost:30020\")" ] }, { @@ -68,7 +68,7 @@ "source": [ "import openai\n", "\n", - "client = openai.Client(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = openai.Client(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "response = client.chat.completions.create(\n", " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", @@ -214,125 +214,8 @@ "metadata": {}, "source": [ "## Structured Outputs (JSON, Regex, EBNF)\n", - "You can specify a JSON schema, [regular expression](https://en.wikipedia.org/wiki/Regular_expression) or [EBNF](https://en.wikipedia.org/wiki/Extended_Backus%E2%80%93Naur_form) to constrain the model output. The model output will be guaranteed to follow the given constraints. 
Only one constraint parameter (`json_schema`, `regex`, or `ebnf`) can be specified for a request.\n", "\n", - "SGLang supports two grammar backends:\n", - "\n", - "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n", - "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints.\n", - " - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md)\n", - "\n", - "Initialize the XGrammar backend using `--grammar-backend xgrammar` flag\n", - "```bash\n", - "python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", - "--port 30000 --host 0.0.0.0 --grammar-backend [xgrammar|outlines] # xgrammar or outlines (default: outlines)\n", - "```\n", - "\n", - "### JSON" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import json\n", - "\n", - "json_schema = json.dumps(\n", - " {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"name\": {\"type\": \"string\", \"pattern\": \"^[\\\\w]+$\"},\n", - " \"population\": {\"type\": \"integer\"},\n", - " },\n", - " \"required\": [\"name\", \"population\"],\n", - " }\n", - ")\n", - "\n", - "response = client.chat.completions.create(\n", - " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", - " messages=[\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": \"Give me the information of the capital of France in the JSON format.\",\n", - " },\n", - " ],\n", - " temperature=0,\n", - " max_tokens=128,\n", - " response_format={\n", - " \"type\": \"json_schema\",\n", - " \"json_schema\": {\"name\": \"foo\", \"schema\": json.loads(json_schema)},\n", - " },\n", - ")\n", - "\n", - "print_highlight(response.choices[0].message.content)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Regular expression (use default \"outlines\" backend)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "response = client.chat.completions.create(\n", - " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", - " messages=[\n", - " {\"role\": \"user\", \"content\": \"What is the capital of France?\"},\n", - " ],\n", - " temperature=0,\n", - " max_tokens=128,\n", - " extra_body={\"regex\": \"(Paris|London)\"},\n", - ")\n", - "\n", - "print_highlight(response.choices[0].message.content)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### EBNF (use \"xgrammar\" backend)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# terminate the existing server(that's using default outlines backend) for this demo\n", - "terminate_process(server_process)\n", - "\n", - "# start new server with xgrammar backend\n", - "server_process = execute_shell_command(\n", - " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0 --grammar-backend xgrammar\"\n", - ")\n", - "wait_for_server(\"http://localhost:30000\")\n", - "\n", - "# EBNF example\n", - "ebnf_grammar = r\"\"\"\n", - " root ::= \"Hello\" | \"Hi\" | \"Hey\"\n", - " \"\"\"\n", - "response = client.chat.completions.create(\n", - " model=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n", - " messages=[\n", - " {\"role\": \"system\", \"content\": \"You are a helpful EBNF test bot.\"},\n", - " {\"role\": \"user\", \"content\": \"Say a 
greeting.\"},\n", - " ],\n", - " temperature=0,\n", - " max_tokens=32,\n", - " extra_body={\"ebnf\": ebnf_grammar},\n", - ")\n", - "\n", - "print_highlight(response.choices[0].message.content)" + "For OpenAI compatible structed outputs API, refer to [Structured Outputs](https://docs.sglang.ai/backend/structured_outputs.html#OpenAI-Compatible-API) for more details.\n" ] }, { @@ -362,7 +245,7 @@ "import time\n", "from openai import OpenAI\n", "\n", - "client = OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "requests = [\n", " {\n", @@ -465,7 +348,7 @@ "import time\n", "from openai import OpenAI\n", "\n", - "client = OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "requests = []\n", "for i in range(100):\n", @@ -542,7 +425,7 @@ "from openai import OpenAI\n", "import os\n", "\n", - "client = OpenAI(base_url=\"http://127.0.0.1:30000/v1\", api_key=\"None\")\n", + "client = OpenAI(base_url=\"http://127.0.0.1:30020/v1\", api_key=\"None\")\n", "\n", "requests = []\n", "for i in range(500):\n", diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index a4913b8af6b9..7e8f4ca0a544 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -1,13 +1,16 @@ # Server Arguments +## Common launch commands + - To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command. ``` python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 2 ``` -- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. +- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. We recommend [SGLang Router](https://docs.sglang.ai/router/router.html) for data parallelism. ``` -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2 +python -m sglang_router.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --dp 2 --tp 2 ``` + - If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`. ``` python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --mem-fraction-static 0.7 @@ -26,52 +29,156 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct - To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port, you can use the following commands. 
If you meet deadlock, please try to add `--disable-cuda-graph` ``` # Node 0 -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 0 # Node 1 -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1 +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 1 ``` -## Use Models From ModelScope -
-More +Please consult the documentation below to learn more about the parameters you may provide when launching a server. -To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable SGLANG_USE_MODELSCOPE. -``` -export SGLANG_USE_MODELSCOPE=true -``` -Launch [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) Server -``` -SGLANG_USE_MODELSCOPE=true python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000 -``` -Or start it by docker. -```bash -docker run --gpus all \ - -p 30000:30000 \ - -v ~/.cache/modelscope:/root/.cache/modelscope \ - --env "SGLANG_USE_MODELSCOPE=true" \ - --ipc=host \ - lmsysorg/sglang:latest \ - python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000 -``` +## Model and tokenizer -
+* `model_path`: Path to the model that will be served. +* `tokenizer_path`: Defaults to the `model_path`. +* `tokenizer_mode`: By default `auto`, see [here](https://huggingface.co/docs/transformers/en/main_classes/tokenizer) for different mode. +* `load_format`: The format the weights are loaded in. Defaults to `*.safetensors`/`*.bin`. +* `trust_remote_code`: If `True`, will use locally cached config files, other wise use remote configs in HuggingFace. +* `dtype`: Dtype used for the model, defaults to `bfloat16`. +* `kv_cache_dtype`: Dtype of the kv cache, defaults to the `dtype`. +* `context_length`: The number of tokens our model can process *including the input*. Not that extending the default might lead to strange behavior. +* `device`: The device we put the model, defaults to `cuda`. +* `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.html#Chat-Template). +* `is_embedding`: Set to true to perform [embedding](https://docs.sglang.ai/backend/openai_api_embeddings.html) / [enocode](https://docs.sglang.ai/backend/native_api.html#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api.html#Classify-(reward-model)) tasks. +* `revision`: Adjust if a specific version of the model should be used. +* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. +* `json_model_override_args`: Override model config with the provided JSON. +* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model. -## Example: Run Llama 3.1 405B -
-More +## Serving: HTTP & API -```bash -# Run 405B (fp8) on a single node -python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8 +### HTTP Server configuration -# Run 405B (fp16) on two nodes -## on the first node, replace the `172.16.4.52:20000` with your own first node ip address and port -python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 +* `port` and `host`: Setup the host for HTTP server. By default `host: str = "127.0.0.1"` and `port: int = 30000` -## on the first node, replace the `172.16.4.52:20000` with your own first node ip address and port -python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 -``` +### API configuration + +* `api_key`: Sets an API key for the server and the OpenAI-compatible API. +* `file_storage_pth`: Directory for storing uploaded or generated files from API calls. +* `enable_cache_report`: If set, includes detailed usage of cached tokens in the response usage. + +## Parallelism + +### Tensor parallelism + +* `tp_size`: The number of GPUs the model weights get sharded over. Mainly for saving memory rather than for high throughput, see [this blogpost](https://pytorch.org/tutorials/intermediate/TP_tutorial.html#how-tensor-parallel-works). + +### Data parallelism + +* `dp_size`: Will be deprecated. The number of data-parallel copies of the model. [SGLang router](https://docs.sglang.ai/router/router.html) is recommended instead of the current naive data parallel. +* `load_balance_method`: Will be deprecated. Load balancing strategy for data parallel requests. + +### Expert parallelism + +* `ep_size`: Distribute the experts onto multiple GPUs for MoE models. Remember to shard the model weights with `tp_size=ep_size`, for detailed benchmarking refer to [this PR](https://github.com/sgl-project/sglang/pull/2203). + +## Memory and scheduling + +* `mem_fraction_static`: Fraction of the free GPU memory used for static memory like model weights and KV cache. If building KV cache fails, it should be increased. If CUDA runs out of memory, it should be decreased. +* `max_running_requests`: The maximum number of requests to run concurrently. +* `max_total_tokens`: The maximum number of tokens that can be stored into the KV cache. Use mainly for debugging. +* `chunked_prefill_size`: Perform the prefill in chunks of these size. Larger chunk size speeds up the prefill phase but increases the VRAM consumption. If CUDA runs out of memory, it should be decreased. +* `max_prefill_tokens`: Token budget of how many tokens to accept in one prefill batch. The actual number is the max of this parameter and the `context_length`. +* `schedule_policy`: The scheduling policy to control the processing order of waiting prefill requests in a single engine. +* `schedule_conservativeness`: Can be used to decrease/increase the conservativeness of the server when taking new requests. Highly conservative behavior leads to starvation, but low conservativeness leads to slowed-down performance. +* `cpu_offload_gb`: Reserve this amount of RAM in GB for offloading of model parameters to the CPU. +* `prefill_only_one_req`: When this flag is turned on, the engine prefills only one request at a time. + +## Other runtime options + +* `stream_interval`: Interval (in tokens) for streaming responses. 
Smaller values lead to smoother streaming, and larger values lead to better throughput. +* `random_seed`: Can be used to enforce more deterministic behavior. +* `watchdog_timeout`: Adjusts the watchdog thread’s timeout before killing the server if batch generation takes too long. +* `download_dir`: Use to override the default Hugging Face cache directory for model weights. +* `base_gpu_id`: Use to adjust first GPU used to distribute the model across available GPUs. +* `allow_auto_truncate`: Automatically truncate requests that exceed the maximum input length. + +## Logging + +* `log_level`: Global log verbosity. +* `log_level_http`: Separate verbosity level for the HTTP server logs (if unset, defaults to `log_level`). +* `log_requests`: Logs the inputs and outputs of all requests for debugging. +* `show_time_cost`: Prints or logs detailed timing info for internal operations (helpful for performance tuning). +* `enable_metrics`: Exports Prometheus-like metrics for request usage and performance. +* `decode_log_interval`: How often (in tokens) to log decode progress. + +## Multi-node distributed serving + +* `dist_init_addr`: The TCP address used for initializing PyTorch’s distributed backend (e.g. `192.168.0.2:25000`). +* `nnodes`: Total number of nodes in the cluster. Refer to how to run the [Llama 405B model](https://docs.sglang.ai/references/llama_405B.html#run-405b-fp16-on-two-nodes). +* `node_rank`: Rank (ID) of this node among the `nnodes` in the distributed setup. + + +## LoRA + +* `lora_paths`: You may provide a list of adapters to your model as a list. Each batch element will get model response with the corresponding lora adapter applied. Currently `cuda_graph` and `radix_attention` are not supportet with this option so you need to disable them manually. We are still working on through these [issues](https://github.com/sgl-project/sglang/issues/2929). +* `max_loras_per_batch`: Maximum number of LoRAs in a running batch including base model. + +## Kernel backend + +* `attention_backend`: The backend for attention computation and KV cache management. +* `sampling_backend`: The backend for sampling. + +## Constrained Decoding + +* `grammar_backend`: The grammar backend for constraint decoding. Detailed usage can be found in this [document](https://docs.sglang.ai/backend/structured_outputs.html). +* `constrained_json_whitespace_pattern`: Use with `Outlines` grammar backend to allow JSON with syntatic newlines, tabs or multiple spaces. Details can be found [here](https://dottxt-ai.github.io/outlines/latest/reference/generation/json/#using-pydantic). + +## Speculative decoding + +* `speculative_draft_model_path`: The draft model path for speculative decoding. +* `speculative_algorithm`: The algorithm for speculative decoding. Currently only [Eagle](https://arxiv.org/html/2406.16858v1) is supported. Note that the radix cache, chunked prefill, and overlap scheduler are disabled when using eagle speculative decoding. +* `speculative_num_steps`: How many draft passes we run before verifying. +* `speculative_num_draft_tokens`: The number of tokens proposed in a draft. +* `speculative_eagle_topk`: The number of top candidates we keep for verification at each step for [Eagle](https://arxiv.org/html/2406.16858v1). + + +## Double Sparsity + +* `enable_double_sparsity`: Enables [double sparsity](https://arxiv.org/html/2408.07092v2) which increases throughput. +* `ds_channel_config_path`: The double sparsity config. 
For a guide on how to generate the config for your model see [this repo](https://github.com/andy-yang-1/DoubleSparse/tree/main/config). +* `ds_heavy_channel_num`: Number of channel indices to keep for each layer. +* `ds_heavy_token_num`: Number of tokens used for attention during decode. Skip sparse decoding if `min_seq_len` in batch < this number. +* `ds_heavy_channel_type`: The type of heavy channels. Either `q`, `k` or `qk`. +* `ds_sparse_decode_threshold`: Don't apply sparse decoding if `max_seq_len` in batch < this threshold. + +## Debug options + +*Note: We recommend to stay with the defaults and only use these options for debugging for best possible performance.* + +* `disable_radix_cache`: Disable [Radix](https://lmsys.org/blog/2024-01-17-sglang/) backend for prefix caching. +* `disable_jump_forward`: Disable [jump-forward](https://lmsys.org/blog/2024-02-05-compressed-fsm/#our-method-jump-forward-decoding-with-a-compressed-finite-state-machine) for outlines grammar backend. +* `disable_cuda_graph`: Disable [cuda graph](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/) for model forward. +* `disable_cuda_graph_padding`: Disable cuda graph when padding is needed. In other case still use cuda graph. +* `disable_outlines_disk_cache`: Disable disk cache for outlines grammar backend. +* `disable_custom_all_reduce`: Disable usage of custom all reduce kernel. +* `disable_mla`: Disable [Multi-Head Latent Attention](https://arxiv.org/html/2405.04434v5) for Deepseek model. +* `disable_overlap_schedule`: Disable the [Overhead-Scheduler](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#zero-overhead-batch-scheduler). +* `enable_nan_detection`: Turning this on makes the sampler print a warning if the logits contain `NaN`. +* `enable_p2p_check`: Turns off the default of allowing always p2p check when accessing GPU. +* `triton_attention_reduce_in_fp32`: In triton kernels this will cast the intermediate attention result to `float32`. + +## Optimization + +*Note: Some of these options are still in experimental stage.* -
+* `enable_mixed_chunk`: Enables mixing prefill and decode, see [this discussion](https://github.com/sgl-project/sglang/discussions/1163). +* `enable_dp_attention`: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models. Note that you need to choose `dp_size = tp_size` for this. +* `enable_ep_moe`: Enables expert parallelism, see the description of `ep_size`. +* `enable_torch_compile`: Torch compile the model. This is an experimental feature. +* `torch_compile_max_bs`: The maximum batch size when using `torch_compile`. +* `cuda_graph_max_bs`: Adjust the maximum batchsize when using cuda graph. By default this is chosen for you based on GPU specifics. +* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you. +* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row. +* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8. diff --git a/docs/backend/structured_outputs.ipynb b/docs/backend/structured_outputs.ipynb index f017ef863035..e413743ccfde 100644 --- a/docs/backend/structured_outputs.ipynb +++ b/docs/backend/structured_outputs.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Structured Outputs (JSON, Regex, EBNF)" + "# Structured Outputs" ] }, { @@ -16,11 +16,13 @@ "SGLang supports two grammar backends:\n", "\n", "- [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints.\n", - "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints and currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md).\n", + "- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints.\n", "\n", - "We suggest using XGrammar whenever possible for its better performance. For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n", + "We suggest using XGrammar for its better performance and utility. XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md). For more details, see [XGrammar technical overview](https://blog.mlc.ai/2024/11/22/achieving-efficient-flexible-portable-structured-generation-with-xgrammar).\n", "\n", - "To use Xgrammar, simply add `--grammar-backend` xgrammar when launching the server. If no backend is specified, Outlines will be used as the default." + "To use Xgrammar, simply add `--grammar-backend` xgrammar when launching the server. 
If no backend is specified, Outlines will be used as the default.\n", + "\n", + "For better output quality, **It's advisable to explicitly include instructions in the prompt to guide the model to generate the desired format.** For example, you can specify, 'Please generate the output in the following JSON format: ...'.\n" ] }, { @@ -43,6 +45,10 @@ " print_highlight,\n", ")\n", "import openai\n", + "import os\n", + "\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n", + "\n", "\n", "server_process = execute_shell_command(\n", " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0 --grammar-backend xgrammar\"\n", @@ -88,7 +94,7 @@ " messages=[\n", " {\n", " \"role\": \"user\",\n", - " \"content\": \"Give me the information of the capital of France in the JSON format.\",\n", + " \"content\": \"Please generate the information of the capital of France in the JSON format.\",\n", " },\n", " ],\n", " temperature=0,\n", @@ -192,20 +198,6 @@ "print_highlight(response.choices[0].message.content)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "terminate_process(server_process)\n", - "server_process = execute_shell_command(\n", - " \"python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\"\n", - ")\n", - "\n", - "wait_for_server(\"http://localhost:30000\")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -232,15 +224,6 @@ "print_highlight(response.choices[0].message.content)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "terminate_process(server_process)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -248,21 +231,6 @@ "## Native API and SGLang Runtime (SRT)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "server_process = execute_shell_command(\n", - " \"\"\"\n", - "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010 --grammar-backend xgrammar\n", - "\"\"\"\n", - ")\n", - "\n", - "wait_for_server(\"http://localhost:30010\")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -296,7 +264,7 @@ "\n", "# Make API request\n", "response = requests.post(\n", - " \"http://localhost:30010/generate\",\n", + " \"http://localhost:30000/generate\",\n", " json={\n", " \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n", " \"sampling_params\": {\n", @@ -341,7 +309,7 @@ "\n", "# JSON\n", "response = requests.post(\n", - " \"http://localhost:30010/generate\",\n", + " \"http://localhost:30000/generate\",\n", " json={\n", " \"text\": \"Here is the information of the capital of France in the JSON format.\\n\",\n", " \"sampling_params\": {\n", @@ -371,7 +339,7 @@ "import requests\n", "\n", "response = requests.post(\n", - " \"http://localhost:30010/generate\",\n", + " \"http://localhost:30000/generate\",\n", " json={\n", " \"text\": \"Give me the information of the capital of France.\",\n", " \"sampling_params\": {\n", @@ -394,22 +362,6 @@ "print_highlight(response.json())" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "terminate_process(server_process)\n", - "server_process = execute_shell_command(\n", - " \"\"\"\n", - "python3 -m sglang.launch_server --model-path meta-llama/Llama-3.2-1B-Instruct --port=30010\n", - "\"\"\"\n", - ")\n", - "\n", - 
"wait_for_server(\"http://localhost:30010\")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -424,7 +376,7 @@ "outputs": [], "source": [ "response = requests.post(\n", - " \"http://localhost:30010/generate\",\n", + " \"http://localhost:30000/generate\",\n", " json={\n", " \"text\": \"Paris is the capital of\",\n", " \"sampling_params\": {\n", @@ -461,7 +413,7 @@ "source": [ "import sglang as sgl\n", "\n", - "llm_xgrammar = sgl.Engine(\n", + "llm = sgl.Engine(\n", " model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\", grammar_backend=\"xgrammar\"\n", ")" ] @@ -509,7 +461,7 @@ " \"json_schema\": json.dumps(CapitalInfo.model_json_schema()),\n", "}\n", "\n", - "outputs = llm_xgrammar.generate(prompts, sampling_params)\n", + "outputs = llm.generate(prompts, sampling_params)\n", "for prompt, output in zip(prompts, outputs):\n", " print_highlight(\"===============================\")\n", " print_highlight(f\"Prompt: {prompt}\") # validate the output by the pydantic model\n", @@ -549,7 +501,7 @@ "\n", "sampling_params = {\"temperature\": 0.1, \"top_p\": 0.95, \"json_schema\": json_schema}\n", "\n", - "outputs = llm_xgrammar.generate(prompts, sampling_params)\n", + "outputs = llm.generate(prompts, sampling_params)\n", "for prompt, output in zip(prompts, outputs):\n", " print_highlight(\"===============================\")\n", " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" @@ -586,22 +538,12 @@ " ),\n", "}\n", "\n", - "outputs = llm_xgrammar.generate(prompts, sampling_params)\n", + "outputs = llm.generate(prompts, sampling_params)\n", "for prompt, output in zip(prompts, outputs):\n", " print_highlight(\"===============================\")\n", " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "llm_xgrammar.shutdown()\n", - "llm_outlines = sgl.Engine(model_path=\"meta-llama/Meta-Llama-3.1-8B-Instruct\")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -622,7 +564,7 @@ "\n", "sampling_params = {\"temperature\": 0.8, \"top_p\": 0.95, \"regex\": \"(France|England)\"}\n", "\n", - "outputs = llm_outlines.generate(prompts, sampling_params)\n", + "outputs = llm.generate(prompts, sampling_params)\n", "for prompt, output in zip(prompts, outputs):\n", " print_highlight(\"===============================\")\n", " print_highlight(f\"Prompt: {prompt}\\nGenerated text: {output['text']}\")" @@ -634,7 +576,7 @@ "metadata": {}, "outputs": [], "source": [ - "llm_outlines.shutdown()" + "llm.shutdown()" ] } ], diff --git a/docs/developer/setup_github_runner.md b/docs/developer/setup_github_runner.md index 7b510d72305e..e805cfce7dad 100644 --- a/docs/developer/setup_github_runner.md +++ b/docs/developer/setup_github_runner.md @@ -11,9 +11,9 @@ docker pull nvidia/cuda:12.1.1-devel-ubuntu22.04 # Nvidia docker run --shm-size 128g -it -v /tmp/huggingface:/hf_home --gpus all nvidia/cuda:12.1.1-devel-ubuntu22.04 /bin/bash # AMD -docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.1.post4-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.1.post7-rocm620 /bin/bash # AMD just the last 2 GPUs -docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home 
lmsysorg/sglang:v0.4.1.post4-rocm620 /bin/bash +docker run --rm --device=/dev/kfd --device=/dev/dri/renderD176 --device=/dev/dri/renderD184 --group-add video --shm-size 128g -it -v /tmp/huggingface:/hf_home lmsysorg/sglang:v0.4.1.post7-rocm620 /bin/bash ``` ### Step 2: Configure the runner by `config.sh` diff --git a/docs/index.rst b/docs/index.rst index 80a53d1cb3bb..51796d4a1071 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -56,6 +56,9 @@ The core features include: references/hyperparameter_tuning.md references/benchmark_and_profiling.md references/custom_chat_template.md + references/deepseek.md + references/llama_405B.md + references/modelscope.md references/contribution_guide.md references/troubleshooting.md references/faq.md diff --git a/docs/references/benchmark_and_profiling.md b/docs/references/benchmark_and_profiling.md index 87ac5177424d..0600b192b4fb 100644 --- a/docs/references/benchmark_and_profiling.md +++ b/docs/references/benchmark_and_profiling.md @@ -64,16 +64,31 @@ with nvtx.annotate("description", color="color"): ```bash # set trace path export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log + # start server python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct -python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --profile +# send profiling request from client +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 10 --sharegpt-output-len 100 --profile ``` - -Traces can be visualized using https://ui.perfetto.dev/. +Please make sure that the `SGLANG_TORCH_PROFILER_DIR` should be set at both server and client side, otherwise the trace file cannot be generated correctly . A secure way will be setting `SGLANG_TORCH_PROFILER_DIR` in the `.*rc` file of shell (e.g. `~/.bashrc` for bash shells). - To profile offline ```bash export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log python -m sglang.bench_offline_throughput --model-path meta-llama/Llama-3.1-8B-Instruct --dataset-name random --num-prompts 10 --profile --mem-frac=0.8 ``` + +- View Traces + +Trace files can be loaded and visualized from: +1. https://ui.perfetto.dev/ (any browser) +2. chrome://tracing (Chrome browser only) + +If browser cannot open trace file due to its large size, +client can generate a small trace file (<100MB) by controlling number of prompts and lengths of prompt outputs. +For example, when profiling a server, +```bash +python -m sglang.bench_serving --backend sglang --model-path meta-llama/Llama-3.1-8B-Instruct --num-prompts 2 --sharegpt-output-len 100 --profile +``` +sets the number of prompts to 2 with `--num-prompts` argument and limits the length of output sequences to 100 with `--sharegpt-output-len` argument, which can generate a small trace file for browser to open smoothly. diff --git a/docs/references/contribution_guide.md b/docs/references/contribution_guide.md index b2211f463fb0..b3b7f826894a 100644 --- a/docs/references/contribution_guide.md +++ b/docs/references/contribution_guide.md @@ -14,7 +14,7 @@ git clone https://github.com//sglang.git ### Install Dependencies & Build -Refer to [Install SGLang from Source](https://sgl-project.github.io/start/install.html#method-2-from-source) documentation for more details on setting up the necessary dependencies. +Refer to [Install SGLang from Source](https://docs.sglang.ai/start/install.html#method-2-from-source) documentation for more details on setting up the necessary dependencies. 
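For reference, a from-source development setup usually follows the pattern sketched below. This is a minimal sketch that assumes the editable install with the `python[all]` extra described in the install guide; the fork URL is a placeholder, so defer to the linked page if the commands differ.

```bash
# Minimal sketch of a development install (assumes the editable install from the install guide).
# Replace <your-fork> with your own GitHub fork of sglang.
git clone https://github.com/<your-fork>/sglang.git
cd sglang
pip install --upgrade pip
pip install -e "python[all]"
```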
## Code Formatting with Pre-Commit diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md new file mode 100644 index 000000000000..2bdceb90478e --- /dev/null +++ b/docs/references/deepseek.md @@ -0,0 +1,56 @@ +# DeepSeek Model Optimizations + +SGLang provides several optimizations specifically designed for the DeepSeek model to boost its inference speed. This document outlines current optimizations for DeepSeek. Additionally, the SGLang team is actively developing enhancements for [DeepSeek-V3](https://github.com/sgl-project/sglang/issues/2591). + + +## Multi-head Latent Attention (MLA) Throughput Optimizations + +**Description**: [MLA](https://arxiv.org/pdf/2405.04434) is an innovative attention mechanism introduced by the DeepSeek team, aimed at improving inference efficiency. SGLang has implemented specific optimizations for this, including: + +- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase. +- **Triton Decoding Kernel Optimization**: In the MLA decoding kernel, there is only one KV head. This optimization reduces memory access to the KV cache by processing multiple query heads within one block, accelerating the decoding process. + +- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption. + +- **CUDA Graph & Torch.compile**: Both MLA and Mixture of Experts (MoE) are compatible with CUDA Graph and Torch.compile, which reduces latency and accelerates decoding speed for small batch sizes. + +Overall, with these optimizations, we have achieved up to a 7x acceleration in output throughput compared to the previous version. + +
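To make the CUDA Graph and `torch.compile` point above concrete, a launch could look like the sketch below. This is illustrative only: the model path and batch-size cap are assumptions, MLA itself requires no extra flag because it is enabled by default (see Usage below), and CUDA graphs are on unless explicitly disabled.

```bash
# Hypothetical launch for a DeepSeek-series model: MLA is used automatically;
# torch.compile is opt-in and capped to small batch sizes, where it mainly helps decode latency.
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V2-Lite \
  --trust-remote-code --enable-torch-compile --torch-compile-max-bs 8
```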

+ Multi-head Latent Attention for DeepSeek Series Models +

+ +**Usage**: MLA optimization is enabled by default; to disable it, use `--disable-mla`. + +**Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details. + +## Data Parallelism Attention + +**Description**: This optimization involves data parallelism (DP) for the MLA attention mechanism of DeepSeek Series Models, which allows for a significant reduction in the KV cache size, enabling larger batch sizes. Each DP worker independently handles different types of batches (prefill, decode, idle), which are then synchronized before and after processing through the Mixture-of-Experts (MoE) layer. + +

+ Data Parallelism Attention for DeepSeek Series Models +

+ +**Usage**: This optimization is aimed at improving throughput and should be used for scenarios with high QPS (Queries Per Second). Data Parallelism Attention optimization can be enabled with `--enable-dp-attention` for DeepSeek Series Models. + +
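As a rough illustration (my own sketch, not from this PR), the same switch is also reachable through the offline `Engine` API, whose keyword arguments mirror the server CLI flags; the model name, GPU count, and the choice to set `dp_size` equal to `tp_size` are assumptions here, so adjust them to your deployment.

```python
# Illustrative sketch: turning on data parallelism attention via the offline Engine API.
# Assumes an 8-GPU node and a DeepSeek checkpoint (placeholder model name).
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="deepseek-ai/DeepSeek-V2-Lite",  # placeholder; use your DeepSeek model
        tp_size=8,
        dp_size=8,                  # assumed equal to tp_size for this mode
        enable_dp_attention=True,   # same switch as the --enable-dp-attention CLI flag
        trust_remote_code=True,
    )
    outputs = llm.generate(
        ["The capital of France is"],
        {"temperature": 0, "max_new_tokens": 8},
    )
    print(outputs[0]["text"])
    llm.shutdown()
```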

+ Data Parallelism Attention Performance Comparison +

+ +**Reference**: Check [Blog](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models). + +## Multi Node Tensor Parallelism + +**Description**: For users with limited memory on a single node, SGLang supports serving DeepSeek Series Models, including DeepSeek V3, across multiple nodes using tensor parallelism. This approach partitions the model parameters across multiple GPUs or nodes to handle models that are too large for one node's memory. + +**Usage**: Check [here](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#example-serving-with-2-h208) for usage examples. + +## Block-wise FP8 + +**Description**: SGLang implements block-wise FP8 quantization with two key optimizations: + +- **Activation**: E4M3 format using per-token-per-128-channel sub-vector scales with online casting. +- **Weight**: Per-128x128-block quantization for better numerical stability. + +**Usage**: turned on by default for DeepSeek V3 models. diff --git a/docs/references/llama_405B.md b/docs/references/llama_405B.md new file mode 100644 index 000000000000..a63b012fb27f --- /dev/null +++ b/docs/references/llama_405B.md @@ -0,0 +1,19 @@ +# Run Llama 3.1 405B + +## Run 405B (fp8) on a Single Node + +```bash +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8 +``` + +## Run 405B (fp16) on Two Nodes + +```bash +# on the first node, replace 172.16.4.52:20000 with your own node ip address and port + +python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 + +# on the second node, use the same --dist-init-addr as the first node (the first node's ip address and port) + +python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 +``` diff --git a/docs/references/modelscope.md b/docs/references/modelscope.md new file mode 100644 index 000000000000..4740c2770f9e --- /dev/null +++ b/docs/references/modelscope.md @@ -0,0 +1,28 @@ +# Use Models From ModelScope + +To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable `SGLANG_USE_MODELSCOPE`. + +```bash +export SGLANG_USE_MODELSCOPE=true +``` + +We take [Qwen2-7B-Instruct](https://www.modelscope.cn/models/qwen/qwen2-7b-instruct) as an example. + +Launch the Server: +```bash +python -m sglang.launch_server --model-path qwen/Qwen2-7B-Instruct --port 30000 +``` + +Or start it with Docker: + +```bash +docker run --gpus all \ + -p 30000:30000 \ + -v ~/.cache/modelscope:/root/.cache/modelscope \ + --env "SGLANG_USE_MODELSCOPE=true" \ + --ipc=host \ + lmsysorg/sglang:latest \ + python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-7B-Instruct --host 0.0.0.0 --port 30000 +``` + +Note that ModelScope uses a different cache directory than Hugging Face. You may need to set it manually to avoid running out of disk space. diff --git a/docs/references/sampling_params.md b/docs/references/sampling_params.md index 5dad3fd12597..77d7c9f82e75 100644 --- a/docs/references/sampling_params.md +++ b/docs/references/sampling_params.md @@ -32,6 +32,20 @@ class GenerateReqInput: return_text_in_logprobs: bool = False # Whether to stream output. stream: bool = False + # Whether to log metrics for this request (e.g.
health_generate calls do not log metrics) + log_metrics: bool = True + + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None + # LoRA related + lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + + # Session info for continual prompting + session_params: Optional[Union[List[Dict], Dict]] = None + # Custom logit processor for advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string. + custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None ``` The `sampling_params` follows this format @@ -90,6 +104,14 @@ repetition_penalty: float = 1.0, # difficult to infer the correct token ID by given `stop` strings. # Must be 0 <= value < max_new_tokens. Setting to 0 (default) will disable this penalty. min_new_tokens: int = 0, + + +## Custom Parameters for Custom Logit Processor. +# A dictionary of custom parameters for the custom logit processor. +# The custom logit processor takes a list of dictionaries as input, where each +# dictionary is the custom parameters for one token in a batch of the input. +# See also python/sglang/srt/sampling/custom_logit_processor.py +custom_params: Optional[Dict[str, Any]] = None, ``` ## Examples @@ -189,7 +211,7 @@ You can specify a JSON schema, regular expression or [EBNF](https://en.wikipedia SGLang supports two grammar backends: - [Outlines](https://github.com/dottxt-ai/outlines) (default): Supports JSON schema and regular expression constraints. -- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema and EBNF constraints. +- [XGrammar](https://github.com/mlc-ai/xgrammar): Supports JSON schema, regular expression, and EBNF constraints. - XGrammar currently uses the [GGML BNF format](https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md) Initialize the XGrammar backend using `--grammar-backend xgrammar` flag @@ -253,3 +275,49 @@ response = requests.post( ) print(response.json()) ``` +### Custom Logit Processor +Launch a server with `--enable-custom-logit-processor` flag on. +``` +python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --port 30000 --enable-custom-logit-processor +``` + +Define a custom logit processor that will always sample a specific token id. +```python +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor + +class DeterministicLogitProcessor(CustomLogitProcessor): + """A dummy logit processor that changes the logits to always + sample the given token id. 
+ """ + + def __call__(self, logits, custom_param_list): + # Check that the number of logits matches the number of custom parameters + assert logits.shape[0] == len(custom_param_list) + key = "token_id" + + for i, param_dict in enumerate(custom_param_list): + # Mask all other tokens + logits[i, :] = -float("inf") + # Assign highest probability to the specified token + logits[i, param_dict[key]] = 0.0 + return logits +``` + +Send a request +```python +import requests + +response = requests.post( + "http://localhost:30000/generate", + json={ + "text": "The capital of France is", + "custom_logit_processor": DeterministicLogitProcessor().to_str(), + "sampling_params": { + "temperature": 0.0, + "max_new_tokens": 32, + "custom_params": {"token_id": 5}, + }, + }, +) +print(response.json()) +``` diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 9dafc3d2a3d7..0a00ad0c8a1a 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -5,7 +5,7 @@ - Mistral / Mixtral / Mistral NeMo - Gemma / Gemma 2 - Qwen / Qwen 2 / Qwen 2 MoE / Qwen 2 VL -- DeepSeek / DeepSeek 2 +- DeepSeek / DeepSeek 2 / [DeepSeek 3](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3) - OLMoE - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/) - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-7b-ov --port=30000 --chat-template=chatml-llava` @@ -24,10 +24,11 @@ - InternLM 2 - Exaone 3 - BaiChuan2 -- MiniCPM / MiniCPM 3 +- MiniCPM / MiniCPM 3 / MiniCPMV - XVERSE / XVERSE MoE - SmolLM - GLM-4 +- Phi-3 / Phi-4 - Phi-3-Small - IBM Granite 3 @@ -81,6 +82,7 @@ To port a model from vLLM to SGLang, you can compare these two files [SGLang Lla - Remove `Sample`. - Change `forward()` functions, and add `forward_batch`. - Add `EntryClass` at the end. + - Please ensure the new implementation uses **only SGLang components and does not rely on any vLLM components**. ### Registering an external model implementation @@ -90,7 +92,7 @@ Here is how you can do it: ```python from sglang.srt.models.registry import ModelRegistry -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server # for a single model, you can add it to the registry ModelRegistry.models[model_name] = model_class diff --git a/docs/start/install.md b/docs/start/install.md index 8a81bb177974..81e2345a6738 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -13,7 +13,7 @@ Note: Please check the [FlashInfer installation doc](https://docs.flashinfer.ai/ ## Method 2: From source ``` # Use the last release branch -git clone -b v0.4.1.post4 https://github.com/sgl-project/sglang.git +git clone -b v0.4.1.post7 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -26,7 +26,7 @@ Note: To AMD ROCm system with Instinct/MI GPUs, do following instead: ``` # Use the last release branch -git clone -b v0.4.1.post4 https://github.com/sgl-project/sglang.git +git clone -b v0.4.1.post7 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip @@ -51,7 +51,7 @@ docker run --gpus all \ Note: To AMD ROCm system with Instinct/MI GPUs, it is recommended to use `docker/Dockerfile.rocm` to build images, example and usage as below: ```bash -docker build --build-arg SGL_BRANCH=v0.4.1.post4 -t v0.4.1.post4-rocm620 -f Dockerfile.rocm . +docker build --build-arg SGL_BRANCH=v0.4.1.post7 -t v0.4.1.post7-rocm620 -f Dockerfile.rocm . 
alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/dri --ipc=host \ --shm-size 16G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \ @@ -60,11 +60,11 @@ alias drun='docker run -it --rm --network=host --device=/dev/kfd --device=/dev/d drun -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=" \ - v0.4.1.post4-rocm620 \ + v0.4.1.post7-rocm620 \ python3 -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct --host 0.0.0.0 --port 30000 # Till flashinfer backend available, --attention-backend triton --sampling-backend pytorch are set by default -drun v0.4.1.post4-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 +drun v0.4.1.post7-rocm620 python3 -m sglang.bench_one_batch --batch-size 32 --input 1024 --output 128 --model amd/Meta-Llama-3.1-8B-Instruct-FP8-KV --tp 8 --quantization fp8 ``` ## Method 4: Using docker compose diff --git a/examples/frontend_language/usage/json_decode.py b/examples/frontend_language/usage/json_decode.py index ce8f5ba70627..5dc3522d512a 100644 --- a/examples/frontend_language/usage/json_decode.py +++ b/examples/frontend_language/usage/json_decode.py @@ -9,7 +9,7 @@ from pydantic import BaseModel import sglang as sgl -from sglang.srt.constrained import build_regex_from_object +from sglang.srt.constrained.outlines_backend import build_regex_from_object character_regex = ( r"""\{\n""" diff --git a/examples/frontend_language/usage/triton/models/character_generation/1/model.py b/examples/frontend_language/usage/triton/models/character_generation/1/model.py index 5550e93984b7..4bf86f1b6919 100644 --- a/examples/frontend_language/usage/triton/models/character_generation/1/model.py +++ b/examples/frontend_language/usage/triton/models/character_generation/1/model.py @@ -3,8 +3,8 @@ from pydantic import BaseModel import sglang as sgl -from sglang import function, set_default_backend -from sglang.srt.constrained import build_regex_from_object +from sglang import function +from sglang.srt.constrained.outlines_backend import build_regex_from_object sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000")) diff --git a/examples/runtime/async_io_api.py b/examples/runtime/async_io_api.py deleted file mode 100644 index 23d3d0b90bf9..000000000000 --- a/examples/runtime/async_io_api.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Usage: - -python3 async_io.py -""" - -import asyncio - -from sglang import Runtime - - -async def generate( - engine, - prompt, - sampling_params, -): - tokenizer = engine.get_tokenizer() - - messages = [ - { - "role": "system", - "content": "You will be given question answer tasks.", - }, - {"role": "user", "content": prompt}, - ] - - prompt = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True - ) - - stream = engine.add_request(prompt, sampling_params) - - async for output in stream: - print(output, end="", flush=True) - print() - - -if __name__ == "__main__": - runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf") - print("--- runtime ready ---\n") - - prompt = "Who is Alan Turing?" 
- sampling_params = {"max_new_tokens": 128} - asyncio.run(generate(runtime, prompt, sampling_params)) - - runtime.shutdown() diff --git a/examples/runtime/engine/offline_batch_inference.py b/examples/runtime/engine/offline_batch_inference.py index 724051eab538..92e68dcd72ca 100644 --- a/examples/runtime/engine/offline_batch_inference.py +++ b/examples/runtime/engine/offline_batch_inference.py @@ -1,3 +1,8 @@ +""" +Usage: +python3 offline_batch_inference.py --model meta-llama/Llama-3.1-8B-Instruct +""" + import argparse import dataclasses diff --git a/python/pyproject.toml b/python/pyproject.toml index d536f8832e1d..80cc0e9dc60e 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang" -version = "0.4.1.post4" +version = "0.4.1.post7" description = "SGLang is yet another fast serving framework for large language models and vision language models." readme = "README.md" requires-python = ">=3.8" @@ -23,11 +23,11 @@ runtime_common = [ "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart", "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop", - "xgrammar>=0.1.6" + "xgrammar>=0.1.10" ] srt = [ "sglang[runtime_common]", "cuda-python", - "sgl-kernel>=0.0.2.post11", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", + "sgl-kernel>=0.0.2.post14", "torch", "vllm==0.6.4.post1", "flashinfer==0.1.6" ] @@ -40,10 +40,15 @@ srt_xpu = ["sglang[runtime_common]"] #For Intel Gaudi(device : hpu) follow the installation guide #https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html srt_hpu = ["sglang[runtime_common]"] +# CPU: currently, there are no pre-built vllm wheels for CPU. +# To install vllm for CPU, please follow the instruction here: +# https://docs.vllm.ai/en/latest/getting_started/installation/cpu/index.html +srt_cpu = ["sglang[runtime_common]", "torch"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] litellm = ["litellm>=1.0.0"] +torch_memory_saver = ["torch_memory_saver"] test = [ "jsonlines", "matplotlib", @@ -56,11 +61,13 @@ all = ["sglang[srt]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_hip = ["sglang[srt_hip]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_xpu = ["sglang[srt_xpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] all_hpu = ["sglang[srt_hpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] +all_cpu = ["sglang[srt_cpu]", "sglang[openai]", "sglang[anthropic]", "sglang[litellm]"] dev = ["sglang[all]", "sglang[test]"] dev_hip = ["sglang[all_hip]", "sglang[test]"] dev_xpu = ["sglang[all_xpu]", "sglang[test]"] dev_hpu = ["sglang[all_hpu]", "sglang[test]"] +dev_cpu = ["sglang[all_cpu]", "sglang[test]"] [project.urls] "Homepage" = "https://github.com/sgl-project/sglang" diff --git a/python/sglang/__init__.py b/python/sglang/__init__.py index de9134857a61..70d58043d40c 100644 --- a/python/sglang/__init__.py +++ b/python/sglang/__init__.py @@ -1,5 +1,6 @@ -# SGL API Components +# SGLang public APIs +# Frontend Language APIs from sglang.api import ( Engine, Runtime, @@ -23,16 +24,26 @@ user_end, video, ) +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint from sglang.lang.choices import ( greedy_token_selection, token_length_normalized, unconditional_likelihood_normalized, ) +from sglang.utils import LazyImport + +Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") +LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") +OpenAI = 
LazyImport("sglang.lang.backend.openai", "OpenAI") +VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") + +# Other configs +from sglang.global_config import global_config +from sglang.version import __version__ -# SGLang DSL APIs __all__ = [ - "Runtime", "Engine", + "Runtime", "assistant", "assistant_begin", "assistant_end", @@ -52,27 +63,14 @@ "user_begin", "user_end", "video", + "RuntimeEndpoint", "greedy_token_selection", "token_length_normalized", "unconditional_likelihood_normalized", + "Anthropic", + "LiteLLM", + "OpenAI", + "VertexAI", + "global_config", + "__version__", ] - -# Global Configurations -from sglang.global_config import global_config - -__all__ += ["global_config"] - -from sglang.version import __version__ - -__all__ += ["__version__"] - -# SGLang Backends -from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.utils import LazyImport - -Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic") -LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM") -OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI") -VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI") - -__all__ += ["Anthropic", "LiteLLM", "OpenAI", "VertexAI", "RuntimeEndpoint"] diff --git a/python/sglang/api.py b/python/sglang/api.py index 9a30ad492da3..7ef306380a91 100644 --- a/python/sglang/api.py +++ b/python/sglang/api.py @@ -1,6 +1,5 @@ """Public APIs of the language.""" -import os import re from typing import Callable, List, Optional, Union @@ -33,19 +32,15 @@ def decorator(func): def Runtime(*args, **kwargs): - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - # Avoid importing unnecessary dependency - from sglang.srt.server import Runtime + from sglang.lang.backend.runtime_endpoint import Runtime return Runtime(*args, **kwargs) def Engine(*args, **kwargs): - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - # Avoid importing unnecessary dependency - from sglang.srt.server import Engine + from sglang.srt.entrypoints.engine import Engine return Engine(*args, **kwargs) diff --git a/python/sglang/bench_offline_throughput.py b/python/sglang/bench_offline_throughput.py index f32063b41ca9..9d56ff07c8be 100644 --- a/python/sglang/bench_offline_throughput.py +++ b/python/sglang/bench_offline_throughput.py @@ -27,7 +27,8 @@ sample_random_requests, set_ulimit, ) -from sglang.srt.server import Engine, Runtime +from sglang.lang.backend.runtime_endpoint import Runtime +from sglang.srt.entrypoints.engine import Engine from sglang.srt.server_args import ServerArgs @@ -39,20 +40,22 @@ class BenchArgs: dataset_path: str = "" num_prompts: int = 1000 sharegpt_output_len: Optional[int] = None + sharegpt_context_len: Optional[int] = None random_input_len: int = 1024 random_output_len: int = 1024 random_range_ratio: float = 0.0 - gen_num_groups: int = 64 - gen_prompts_per_group: int = 16 - gen_system_prompt_len: int = 2048 - gen_question_len: int = 128 - gen_output_len: int = 256 + gsp_num_groups: int = 64 + gsp_prompts_per_group: int = 16 + gsp_system_prompt_len: int = 2048 + gsp_question_len: int = 128 + gsp_output_len: int = 256 + seed: int = 1 disable_ignore_eos: bool = False extra_request_body: Optional[str] = None - seed: int = 1 + apply_chat_template: bool = False + profile: bool = False skip_warmup: bool = False do_not_exit: bool = False - profile: bool = False @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -82,6 +85,12 @@ def add_cli_args(parser: argparse.ArgumentParser): default=BenchArgs.sharegpt_output_len, help="Output length for each 
request. Overrides the output length from the ShareGPT dataset.", ) + parser.add_argument( + "--sharegpt-context-len", + type=int, + default=BenchArgs.sharegpt_context_len, + help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.", + ) parser.add_argument( "--random-input-len", type=int, @@ -102,51 +111,62 @@ def add_cli_args(parser: argparse.ArgumentParser): "used only for random dataset.", ) parser.add_argument( - "--gen-num-groups", + "--gsp-num-groups", type=int, - default=BenchArgs.gen_num_groups, + default=BenchArgs.gsp_num_groups, help="Number of groups with shared prefix, used" "only for generate-shared-prefix", ) parser.add_argument( - "--gen-prompts-per-group", + "--gsp-prompts-per-group", type=int, - default=BenchArgs.gen_prompts_per_group, + default=BenchArgs.gsp_prompts_per_group, help="Number of prompts per group of shared prefix, used" "only for generate-shared-prefix", ) parser.add_argument( - "--gen-system-prompt-len", + "--gsp-system-prompt-len", type=int, - default=BenchArgs.gen_system_prompt_len, + default=BenchArgs.gsp_system_prompt_len, help="System prompt length, used" "only for generate-shared-prefix", ) parser.add_argument( - "--gen-question-len", + "--gsp-question-len", type=int, - default=BenchArgs.gen_question_len, + default=BenchArgs.gsp_question_len, help="Question length, used" "only for generate-shared-prefix", ) parser.add_argument( - "--gen-output-len", + "--gsp-output-len", type=int, - default=BenchArgs.gen_output_len, + default=BenchArgs.gsp_output_len, help="Target length in tokens for outputs in generated-shared-prefix dataset", ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--disable-ignore-eos", - type=bool, - default=BenchArgs.disable_ignore_eos, + action="store_true", help="Disable ignore EOS token", ) parser.add_argument( "--extra-request-body", metavar='{"key1": "value1", "key2": "value2"}', type=str, + default=BenchArgs.extra_request_body, help="Append given JSON object to the request payload. You can use this to specify" "additional generate params like sampling params.", ) - parser.add_argument("--seed", type=int, default=1, help="The random seed.") + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + ) parser.add_argument( "--skip-warmup", action="store_true", @@ -157,12 +177,6 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Do not exit the program. This is useful for nsys profile with --duration and --delay.", ) - parser.add_argument( - "--profile", - action="store_true", - help="Use Torch Profiler. 
The endpoint must be launched with " - "SGLANG_TORCH_PROFILER_DIR to enable profiler.", - ) @classmethod def from_cli_args(cls, args: argparse.Namespace): diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index 63787addf0ed..bc7a9c7a1a71 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -9,7 +9,8 @@ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --load-format dummy ## sweep through multiple data points and store (append) the results in a jsonl file: python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run - +## run with profiling: +python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile # Usage (correctness test): python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct @@ -56,12 +57,12 @@ import torch.distributed as dist from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.entrypoints.engine import _set_envs_and_config from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams -from sglang.srt.server import _set_envs_and_config from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import configure_logger, kill_process_tree, suppress_other_loggers @@ -77,6 +78,8 @@ class BenchArgs: correctness_test: bool = False # This is only used for correctness test cut_len: int = 4 + profile: bool = False + profile_filename_prefix: str = "profile" @staticmethod def add_cli_args(parser: argparse.ArgumentParser): @@ -95,6 +98,16 @@ def add_cli_args(parser: argparse.ArgumentParser): ) parser.add_argument("--correctness-test", action="store_true") parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len) + parser.add_argument( + "--profile", action="store_true", help="Use Torch Profiler." + ) + parser.add_argument( + "--profile-filename-prefix", + type=str, + default=BenchArgs.profile_filename_prefix, + help="Prefix of the profiling file names. 
The full profiling result file(s) be " + '"[profile_filename_prefix]_batch[batch_size]_input[input_len]_output[output_len].trace.json.gz"', + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -216,6 +229,7 @@ def extend(reqs, model_runner): model_config=model_runner.model_config, enable_overlap=False, spec_algorithm=SpeculativeAlgorithm.NONE, + enable_custom_logit_processor=False, ) batch.prepare_for_extend() model_worker_batch = batch.get_model_worker_batch() @@ -286,7 +300,16 @@ def synchronize(device): def latency_test_run_once( - run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device + run_name, + model_runner, + rank_print, + reqs, + batch_size, + input_len, + output_len, + device, + profile, + profile_filename_prefix, ): max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len) if batch_size > max_batch_size: @@ -308,6 +331,17 @@ def latency_test_run_once( tot_latency = 0 + profiler = None + if profile: + profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + ) + profiler.start() + # Prefill synchronize(device) tic = time.time() @@ -338,6 +372,14 @@ def latency_test_run_once( f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s" ) + if profile: + profiler.stop() + profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz" + parent_dir = os.path.dirname(os.path.abspath(profile_filename)) + os.makedirs(parent_dir, exist_ok=True) + profiler.export_chrome_trace(profile_filename) + rank_print(f"torch profiler chrome trace saved to {profile_filename}") + # Record decode timing from 2nd output if output_len > 1: med_decode_latency = np.median(decode_latencies) @@ -386,6 +428,8 @@ def latency_test( bench_args.input_len[0], 8, # shorter decoding to speed up the warmup server_args.device, + profile=False, + profile_filename_prefix="", # not used ) rank_print("Benchmark ...") @@ -405,6 +449,8 @@ def latency_test( il, ol, server_args.device, + bench_args.profile if tp_rank == 0 else None, + bench_args.profile_filename_prefix, ) if ret is not None: result_list.append(ret) diff --git a/python/sglang/bench_one_batch_server.py b/python/sglang/bench_one_batch_server.py index 01cc561e1ced..5f0759a7ce1b 100644 --- a/python/sglang/bench_one_batch_server.py +++ b/python/sglang/bench_one_batch_server.py @@ -22,7 +22,7 @@ import numpy as np import requests -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import kill_process_tree diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 4744ad3386ba..10ce965be742 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -452,6 +452,8 @@ def get_dataset(args, tokenizer): num_requests=args.num_prompts, tokenizer=tokenizer, fixed_output_len=args.sharegpt_output_len, + context_len=args.sharegpt_context_len, + apply_chat_template=args.apply_chat_template, ) elif args.dataset_name == "random": input_requests = sample_random_requests( @@ -464,11 +466,11 @@ def get_dataset(args, tokenizer): ) elif args.dataset_name == "generated-shared-prefix": input_requests = sample_generated_shared_prefix_requests( - num_groups=args.gen_num_groups, - prompts_per_group=args.gen_prompts_per_group, - system_prompt_len=args.gen_system_prompt_len, - 
question_len=args.gen_question_len, - output_len=args.gen_output_len, + num_groups=args.gsp_num_groups, + prompts_per_group=args.gsp_prompts_per_group, + system_prompt_len=args.gsp_system_prompt_len, + question_len=args.gsp_question_len, + output_len=args.gsp_output_len, tokenizer=tokenizer, ) else: @@ -514,6 +516,9 @@ class BenchmarkMetrics: p99_itl_ms: float mean_e2e_latency_ms: float median_e2e_latency_ms: float + std_e2e_latency_ms: float + p99_e2e_latency_ms: float + concurrency: float SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" @@ -558,6 +563,8 @@ def sample_sharegpt_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, + context_len: Optional[int] = None, + apply_chat_template=False, ) -> List[Tuple[str, int, int]]: if fixed_output_len is not None and fixed_output_len < 4: raise ValueError("output_len too small") @@ -588,6 +595,15 @@ def sample_sharegpt_requests( # Tokenize the prompts and completions. prompt = dataset[i][0] + + if apply_chat_template: + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + prompt = prompt.replace(tokenizer.bos_token, "") + prompt_token_ids = tokenizer.encode(prompt) completion = dataset[i][1] completion_token_ids = tokenizer.encode(completion) @@ -595,14 +611,15 @@ def sample_sharegpt_requests( output_len = ( len(completion_token_ids) if fixed_output_len is None else fixed_output_len ) - if prompt_len < 4 or output_len < 4: + + if prompt_len < 2 or output_len < 2: # Prune too short sequences. continue - if prompt_len > 1024 or ( - prompt_len + output_len > 2048 and fixed_output_len is None - ): + + if context_len and prompt_len + output_len > context_len: # Prune too long sequences. 
continue + filtered_dataset.append((prompt, prompt_len, output_len)) print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}") @@ -704,8 +721,8 @@ def get_gen_prefix_cache_path(args, tokenizer): # Create a unique cache filename based on the generation parameters cache_key = ( - f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_" - f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_" + f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_" + f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_" f"{tokenizer.__class__.__name__}.pkl" ) return cache_dir / cache_key @@ -873,6 +890,9 @@ def calculate_metrics( p99_itl_ms=np.percentile(itls or 0, 99) * 1000, mean_e2e_latency_ms=np.mean(e2e_latencies) * 1000, median_e2e_latency_ms=np.median(e2e_latencies) * 1000, + std_e2e_latency_ms=np.std(e2e_latencies) * 1000, + p99_e2e_latency_ms=np.percentile(e2e_latencies, 99) * 1000, + concurrency=np.sum(e2e_latencies) / dur_s, ) return metrics, output_lens @@ -1024,6 +1044,7 @@ async def limited_request_func(request_func_input, pbar): "Total token throughput (tok/s):", metrics.total_throughput ) ) + print("{:<40} {:<10.2f}".format("Concurrency:", metrics.concurrency)) print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-")) print( "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms) @@ -1055,27 +1076,41 @@ async def limited_request_func(request_func_input, pbar): and metrics.output_throughput is not None ): result = { + # Arguments "backend": args.backend, "dataset_name": args.dataset_name, "request_rate": request_rate, "max_concurrency": max_concurrency, + "sharegpt_output_len": args.sharegpt_output_len, + "random_input_len": args.random_input_len, + "random_output_len": args.random_output_len, + "random_range_ratio": args.random_range_ratio, + # Results + "duration": benchmark_duration, + "completed": metrics.completed, "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "total_output_tokens_retokenized": metrics.total_output_retokenized, + "request_throughput": metrics.request_throughput, + "input_throughput": metrics.input_throughput, + "output_throughput": metrics.output_throughput, "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, "median_e2e_latency_ms": metrics.median_e2e_latency_ms, + "std_e2e_latency_ms": metrics.std_e2e_latency_ms, + "p99_e2e_latency_ms": metrics.p99_e2e_latency_ms, "mean_ttft_ms": metrics.mean_ttft_ms, "median_ttft_ms": metrics.median_ttft_ms, + "std_ttft_ms": metrics.std_ttft_ms, + "p99_ttft_ms": metrics.p99_ttft_ms, + "mean_tpot_ms": metrics.mean_tpot_ms, + "median_tpot_ms": metrics.median_tpot_ms, + "std_tpot_ms": metrics.std_tpot_ms, + "p99_tpot_ms": metrics.p99_tpot_ms, "mean_itl_ms": metrics.mean_itl_ms, "median_itl_ms": metrics.median_itl_ms, - "input_throughput": metrics.input_throughput, - "output_throughput": metrics.output_throughput, - "sharegpt_output_len": args.sharegpt_output_len, - "random_input_len": args.random_input_len, - "random_output_len": args.random_output_len, - "random_range_ratio": args.random_range_ratio, - "duration": benchmark_duration, - "completed": metrics.completed, + "std_itl_ms": metrics.std_itl_ms, + "p99_itl_ms": metrics.p99_itl_ms, + "concurrency": metrics.concurrency, } else: print(f"Error running benchmark for request rate: {request_rate}") @@ -1095,36 +1130,16 @@ async def limited_request_func(request_func_input, pbar): with open(output_file_name, "a") as file: 
file.write(json.dumps(result) + "\n") - result = { - "duration": benchmark_duration, - "completed": metrics.completed, - "total_input_tokens": metrics.total_input, - "total_output_tokens": metrics.total_output, - "total_output_tokens_retokenized": metrics.total_output_retokenized, - "request_throughput": metrics.request_throughput, - "input_throughput": metrics.input_throughput, - "output_throughput": metrics.output_throughput, - "mean_ttft_ms": metrics.mean_ttft_ms, - "median_ttft_ms": metrics.median_ttft_ms, - "std_ttft_ms": metrics.std_ttft_ms, - "p99_ttft_ms": metrics.p99_ttft_ms, - "mean_tpot_ms": metrics.mean_tpot_ms, - "median_tpot_ms": metrics.median_tpot_ms, - "std_tpot_ms": metrics.std_tpot_ms, - "p99_tpot_ms": metrics.p99_tpot_ms, - "mean_itl_ms": metrics.mean_itl_ms, - "median_itl_ms": metrics.median_itl_ms, - "std_itl_ms": metrics.std_itl_ms, - "p99_itl_ms": metrics.p99_itl_ms, - "input_lens": [output.prompt_len for output in outputs], - "output_lens": output_lens, - "ttfts": [output.ttft for output in outputs], - "itls": [output.itl for output in outputs], - "generated_texts": [output.generated_text for output in outputs], - "errors": [output.error for output in outputs], - "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms, - "median_e2e_latency_ms": metrics.median_e2e_latency_ms, - } + result.update( + { + "input_lens": [output.prompt_len for output in outputs], + "output_lens": output_lens, + "ttfts": [output.ttft for output in outputs], + "itls": [output.itl for output in outputs], + "generated_texts": [output.generated_text for output in outputs], + "errors": [output.error for output in outputs], + } + ) return result @@ -1360,6 +1375,12 @@ def set_ulimit(target_soft_limit=65535): default=None, help="Output length for each request. Overrides the output length from the ShareGPT dataset.", ) + parser.add_argument( + "--sharegpt-context-len", + type=int, + default=None, + help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.", + ) parser.add_argument( "--random-input-len", type=int, @@ -1399,7 +1420,6 @@ def set_ulimit(target_soft_limit=65535): "actual request rate may be lower than specified with --request-rate, " "if the server is not processing requests fast enough to keep up.", ) - parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( "--multi", action="store_true", @@ -1423,14 +1443,15 @@ def set_ulimit(target_soft_limit=65535): help="Disable streaming mode.", ) parser.add_argument( - "--disable-ignore-eos", + "--return-logprob", action="store_true", - help="Disable ignoring EOS.", + help="Return logprob.", ) + parser.add_argument("--seed", type=int, default=1, help="The random seed.") parser.add_argument( - "--return-logprob", + "--disable-ignore-eos", action="store_true", - help="Return logprob.", + help="Disable ignoring EOS.", ) parser.add_argument( "--extra-request-body", @@ -1439,49 +1460,54 @@ def set_ulimit(target_soft_limit=65535): help="Append given JSON object to the request payload. You can use this to specify" "additional generate params like sampling params.", ) + parser.add_argument( + "--apply-chat-template", + action="store_true", + help="Apply chat template", + ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. 
The endpoint must be launched with " + "SGLANG_TORCH_PROFILER_DIR to enable profiler.", + ) + parser.add_argument( + "--lora-name", + type=str, + default=None, + help="The name of LoRA adapter", + ) group = parser.add_argument_group("generated-shared-prefix dataset arguments") group.add_argument( - "--gen-num-groups", + "--gsp-num-groups", type=int, default=64, help="Number of system prompt groups for generated-shared-prefix dataset", ) group.add_argument( - "--gen-prompts-per-group", + "--gsp-prompts-per-group", type=int, default=16, help="Number of prompts per system prompt group for generated-shared-prefix dataset", ) group.add_argument( - "--gen-system-prompt-len", + "--gsp-system-prompt-len", type=int, default=2048, help="Target length in tokens for system prompts in generated-shared-prefix dataset", ) group.add_argument( - "--gen-question-len", + "--gsp-question-len", type=int, default=128, help="Target length in tokens for questions in generated-shared-prefix dataset", ) group.add_argument( - "--gen-output-len", + "--gsp-output-len", type=int, default=256, help="Target length in tokens for outputs in generated-shared-prefix dataset", ) - parser.add_argument( - "--profile", - action="store_true", - help="Use Torch Profiler. The endpoint must be launched with " - "SGLANG_TORCH_PROFILER_DIR to enable profiler.", - ) - parser.add_argument( - "--lora-name", - type=str, - default=None, - help="The name of LoRA adapter", - ) args = parser.parse_args() run_benchmark(args) diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 1261b6d0c9fe..01f10b9f063b 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -1,6 +1,11 @@ +import atexit import json +import multiprocessing import warnings -from typing import List, Optional +from typing import Dict, List, Optional, Union + +import aiohttp +import requests from sglang.global_config import global_config from sglang.lang.backend.base_backend import BaseBackend @@ -251,11 +256,12 @@ def select( } obj = self._generate_http_request(s, data) - normalized_prompt_logprobs = [ - r["meta_info"]["normalized_prompt_logprob"] for r in obj - ] input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj] output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj] + normalized_prompt_logprobs = [ + compute_normalized_prompt_logprobs(r["meta_info"]["input_token_logprobs"]) + for r in obj + ] # Remove extra token if no token healing occurred for i in range(len(input_token_logprobs)): @@ -319,3 +325,176 @@ def _add_images(self, s: StreamExecutor, data): def _assert_success(self, res): if res.status_code != 200: raise RuntimeError(res.json()) + + +def compute_normalized_prompt_logprobs(input_logprobs): + values = [x[0] for x in input_logprobs if x[0]] + return sum(values) / len(values) + + +class Runtime: + """ + A wrapper for the HTTP server. + This is used for launching the server in a python program without + using the commond line interface. + + It is mainly used for the frontend language. + You should use the Engine class if you want to do normal offline processing without the frontend language. + """ + + def __init__( + self, + log_level: str = "error", + *args, + **kwargs, + ): + """See the arguments in server_args.py::ServerArgs""" + # We delay the import of any `sglang.srt` components in `sglang.lang`, so users can run + # client code without installing SRT server and its dependency if they want. 
+ from sglang.srt.entrypoints.http_server import launch_server + from sglang.srt.server_args import ServerArgs + from sglang.srt.utils import is_port_available + + self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) + + # Pre-allocate ports + for port in range(self.server_args.port, 40000): + if is_port_available(port): + break + self.server_args.port = port + + self.url = self.server_args.url() + self.generate_url = self.url + "/generate" + + # NOTE: We store pid instead of proc to fix some issues during __delete__ + self.pid = None + pipe_reader, pipe_writer = multiprocessing.Pipe(duplex=False) + + proc = multiprocessing.Process( + target=launch_server, + args=(self.server_args, pipe_writer), + ) + proc.start() + pipe_writer.close() + self.pid = proc.pid + + # Before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() + atexit.register(self.shutdown) + + # TODO: remove this pipe_writer mechanism and use `/health_generate` instead. + try: + init_state = pipe_reader.recv() + except EOFError: + init_state = "" + + if init_state != "ready": + self.shutdown() + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) + + self.endpoint = RuntimeEndpoint(self.url) + + def shutdown(self): + from sglang.srt.utils import kill_process_tree + + if self.pid is not None: + kill_process_tree(self.pid) + self.pid = None + + def cache_prefix(self, prefix: str): + self.endpoint.cache_prefix(prefix) + + def get_tokenizer(self): + from sglang.srt.hf_transformers_utils import get_tokenizer + + return get_tokenizer( + self.server_args.tokenizer_path, + tokenizer_mode=self.server_args.tokenizer_mode, + trust_remote_code=self.server_args.trust_remote_code, + revision=self.server_args.revision, + ) + + async def async_generate( + self, + prompt: str, + sampling_params: Optional[Dict] = None, + ): + if self.server_args.skip_tokenizer_init: + json_data = { + "input_ids": prompt, + "sampling_params": sampling_params, + "stream": True, + } + else: + json_data = { + "text": prompt, + "sampling_params": sampling_params, + "stream": True, + } + pos = 0 + + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: + async with session.post(self.generate_url, json=json_data) as response: + async for chunk, _ in response.content.iter_chunks(): + chunk = chunk.decode("utf-8") + if chunk and chunk.startswith("data:"): + if chunk == "data: [DONE]\n\n": + break + data = json.loads(chunk[5:].strip("\n")) + if "text" in data: + cur = data["text"][pos:] + if cur: + yield cur + pos += len(cur) + else: + yield data + + add_request = async_generate + + def generate( + self, + prompt: Union[str, List[str]], + sampling_params: Optional[Dict] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + ): + json_data = { + "text": prompt, + "sampling_params": sampling_params, + "return_logprob": return_logprob, + "logprob_start_len": logprob_start_len, + "top_logprobs_num": top_logprobs_num, + "lora_path": lora_path, + } + assert not isinstance(lora_path, list) or len(lora_path) == len(prompt) + response = requests.post( + self.url + "/generate", + json=json_data, + ) + return json.dumps(response.json()) + + def encode( + self, + prompt: Union[str, List[str], List[Dict], 
List[List[Dict]]], + ): + json_data = {"text": prompt} + response = requests.post(self.url + "/encode", json=json_data) + return json.dumps(response.json()) + + async def get_server_info(self): + async with aiohttp.ClientSession() as session: + async with session.get(f"{self.url}/get_server_info") as response: + if response.status == 200: + return await response.json() + else: + error_data = await response.json() + raise RuntimeError( + f"Failed to get server info. {error_data['error']['message']}" + ) + + def __del__(self): + self.shutdown() diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 4a774c4fb6b8..a2c91c561c29 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -88,7 +88,6 @@ def get_chat_template_by_model_path(model_path): ) ) - register_chat_template( ChatTemplate( name="claude", @@ -101,7 +100,6 @@ def get_chat_template_by_model_path(model_path): ) ) - register_chat_template( ChatTemplate( name="chatml", @@ -116,7 +114,6 @@ def get_chat_template_by_model_path(model_path): ) ) - register_chat_template( ChatTemplate( name="chatml-llava", @@ -132,7 +129,6 @@ def get_chat_template_by_model_path(model_path): ) ) - # There is default system prompt for qwen # reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1 # The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" @@ -219,6 +215,21 @@ def get_chat_template_by_model_path(model_path): ) ) +# https://huggingface.co/openbmb/MiniCPM-V-2_6 +register_chat_template( + ChatTemplate( + name="minicpmv", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ("", " "), + "user": ("user:", " "), + "assistant": ("assistant:", ""), + }, + stop_str=("<|im_end|>", "<|endoftext|>"), + image_token="(./)", + ) +) + # The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token. 
register_chat_template( ChatTemplate( @@ -343,6 +354,37 @@ def get_chat_template_by_model_path(model_path): ) +register_chat_template( + ChatTemplate( + name="deepseek-v3", + default_system_prompt=None, + role_prefix_and_suffix={ + "system": ( + "", + "", + ), + "user": ( + "<|User|>", + "", + ), + "assistant": ( + "<|Assistant|>", + "<|end▁of▁sentence|>", + ), + }, + stop_str=("<|end▁of▁sentence|>",), + ) +) + + +@register_chat_template_matching_function +def match_deepseek(model_path: str): + if ( + "deepseek-v3" in model_path.lower() or "deepseek-r1" in model_path.lower() + ) and "base" not in model_path.lower(): + return get_chat_template("deepseek-v3") + + @register_chat_template_matching_function def match_dbrx(model_path: str): if "dbrx" in model_path.lower() and "instruct" in model_path.lower(): diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 6d1ca71adab1..4c294781c20e 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -347,7 +347,7 @@ def fork( size: int = 1, position_ids_offset: Optional[List[int]] = None, ): - if size > 1: + if size > 1 and str(self.text_): self.submit(SglCommitLazy()) self.sync() diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 6b0c25711c66..caae7b0f6cc7 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -3,7 +3,7 @@ import os import sys -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import prepare_server_args from sglang.srt.utils import kill_process_tree diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py deleted file mode 100644 index 138c2127e16e..000000000000 --- a/python/sglang/launch_server_llavavid.py +++ /dev/null @@ -1,25 +0,0 @@ -"""Launch the inference server for Llava-video model.""" - -import json -import sys - -from sglang.srt.server import launch_server, prepare_server_args - -if __name__ == "__main__": - server_args = prepare_server_args(sys.argv[1:]) - - model_override_args = {} - model_override_args["mm_spatial_pool_stride"] = 2 - model_override_args["architectures"] = ["LlavaVidForCausalLM"] - model_override_args["num_frames"] = 16 - model_override_args["model_type"] = "llavavid" - if model_override_args["num_frames"] == 32: - model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"} - model_override_args["max_sequence_length"] = 4096 * 2 - model_override_args["tokenizer_model_max_length"] = 4096 * 2 - model_override_args["model_max_length"] = 4096 * 2 - if "34b" in server_args.model_path.lower(): - model_override_args["image_token_index"] = 64002 - server_args.json_model_override_args = json.dumps(model_override_args) - - launch_server(server_args) diff --git a/python/sglang/srt/_custom_ops.py b/python/sglang/srt/_custom_ops.py index 9eb7caa1bbae..3cb313b9133b 100644 --- a/python/sglang/srt/_custom_ops.py +++ b/python/sglang/srt/_custom_ops.py @@ -1,8 +1,9 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py import contextlib import functools import importlib import logging +import os from typing import TYPE_CHECKING, List, Optional, Tuple, Union import torch @@ -11,12 +12,19 @@ from sglang.srt.utils import is_hpu logger = logging.getLogger(__name__) +use_vllm_custom_allreduce = 
os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True) if not is_hpu(): - try: - import custom_ar - except ImportError as e: - logger.warning("Failed to import from custom_ar with %r", e) + if use_vllm_custom_allreduce: + try: + import vllm._C + except ImportError as e: + logger.warning("Failed to import from vllm._C with %r", e) + else: + try: + import sgl_kernel + except ImportError as e: + logger.warning("Failed to import from custom_ar with %r", e) def hint_on_error(fn): @@ -48,48 +56,78 @@ def wrapper(*args, **kwargs): return wrapper -# custom ar -def init_custom_ar( - ipc_tensors: List[torch.Tensor], - rank_data: torch.Tensor, - rank: int, - full_nvlink: bool, -) -> int: - return torch.ops._C_vllm_ar.init_custom_ar( - ipc_tensors, rank_data, rank, full_nvlink - ) - - -def all_reduce( - fa: int, - inp: torch.Tensor, - out: torch.Tensor, - reg_buffer: int, - reg_buffer_sz_bytes: int, -) -> None: - torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes) - - -def dispose(fa: int) -> None: - torch.ops._C_vllm_ar.dispose(fa) - - -def meta_size() -> int: - return torch.ops._C_vllm_ar.meta_size() - +if use_vllm_custom_allreduce: + # custom ar + def init_custom_ar( + ipc_tensors: List[torch.Tensor], + rank_data: torch.Tensor, + rank: int, + full_nvlink: bool, + ) -> int: + return torch.ops._C_custom_ar.init_custom_ar( + ipc_tensors, rank_data, rank, full_nvlink + ) -def register_buffer(fa: int, ipc_tensors: List[int]) -> None: - return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors) + def all_reduce( + fa: int, + inp: torch.Tensor, + out: torch.Tensor, + reg_buffer: int, + reg_buffer_sz_bytes: int, + ) -> None: + torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes) + + def dispose(fa: int) -> None: + torch.ops._C_custom_ar.dispose(fa) + + def meta_size() -> int: + return torch.ops._C_custom_ar.meta_size() + + def register_buffer(fa: int, ipc_tensors: List[int]) -> None: + return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors) + + def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: + return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) + + def register_graph_buffers( + fa: int, handles: List[List[int]], offsets: List[List[int]] + ) -> None: + torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + +else: + # custom ar + def init_custom_ar( + rank_id: int, + world_size: int, + rank_data_base: torch.Tensor, + buffers: List[int], + tmp_result_buffers: List[int], + barrier_in: List[int], + barrier_out: List[int], + ) -> int: + return sgl_kernel.ops.init_custom_reduce( + rank_id, + world_size, + rank_data_base, + buffers, + tmp_result_buffers, + barrier_in, + barrier_out, + ) + def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: + sgl_kernel.ops.custom_reduce(fa, inp, out) -def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: - return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa) + def dispose(fa: int) -> None: + sgl_kernel.ops.custom_dispose(fa) + def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: + return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa) -def register_graph_buffers( - fa: int, handles: List[List[int]], offsets: List[List[int]] -) -> None: - torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets) + def register_graph_buffers( + fa: int, handles: List[List[int]], offsets: List[List[int]] + ) -> None: + sgl_kernel.ops.register_graph_buffers(fa, handles, offsets) # temporary fix for 
https://github.com/vllm-project/vllm/issues/5456 diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index 600b58e49377..3d81c5d4fd50 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -1,3 +1,5 @@ +from sglang.srt.configs.chatglm import ChatGLMConfig +from sglang.srt.configs.dbrx import DbrxConfig from sglang.srt.configs.exaone import ExaoneConfig from sglang.srt.configs.qwen2vl import Qwen2VLConfig, Qwen2VLVisionConfig @@ -5,4 +7,6 @@ "ExaoneConfig", "Qwen2VLConfig", "Qwen2VLVisionConfig", + "ChatGLMConfig", + "DbrxConfig", ] diff --git a/python/sglang/srt/configs/chatglm.py b/python/sglang/srt/configs/chatglm.py new file mode 100644 index 000000000000..9370c218aab8 --- /dev/null +++ b/python/sglang/srt/configs/chatglm.py @@ -0,0 +1,78 @@ +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/chatglm.py + +# ChatGLM2 and ChatGLM3 share the same config. +# ChatGLM4 is officially supported by Huggingface +# transformers >= 4.46.0 is required +# https://huggingface.co/docs/transformers/en/model_doc/glm +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + attribute_map = { + "num_hidden_layers": "num_layers", + "n_head_kv": "multi_query_group_num", + } + + def __init__( + self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs + ): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + # It is to be compatible with long lora. 
+ self.max_position_embeddings = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = ( + apply_residual_connection_post_layernorm + ) + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + self.interleaved_qkv = interleaved_qkv + super().__init__(**kwargs) diff --git a/python/sglang/srt/configs/dbrx.py b/python/sglang/srt/configs/dbrx.py new file mode 100644 index 000000000000..75ccbde944ea --- /dev/null +++ b/python/sglang/srt/configs/dbrx.py @@ -0,0 +1,279 @@ +# Adapted from +# https://huggingface.co/databricks/dbrx-base/blob/main/configuration_dbrx.py +# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/dbrx.py +"""Dbrx configuration.""" + +from typing import Any, Optional + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DBRX_PRETRAINED_CONFIG_ARCHIVE_MAP = {} # type: ignore + + +class DbrxAttentionConfig(PretrainedConfig): + """Configuration class for Dbrx Attention. + + [`DbrxAttention`] class. It is used to instantiate attention layers + according to the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + attn_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the attention layers. + clip_qkv (`float`, *optional*, defaults to None): + If not `None`, clip the queries, keys, and values in the attention layer to this value. + kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. + rope_theta (float): The base frequency for rope. + """ + + def __init__( + self, + attn_pdrop: float = 0, + clip_qkv: Optional[float] = None, + kv_n_heads: int = 1, + rope_theta: float = 10000.0, + **kwargs: Any, + ): + super().__init__(**kwargs) + self.attn_pdrop = attn_pdrop + self.clip_qkv = clip_qkv + self.kv_n_heads = kv_n_heads + self.rope_theta = rope_theta + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, **kwargs: Any + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "dbrx": + config_dict = config_dict["attn_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + "You are using a model of type %s to instantiate a model of " + "type %s. 
This is not supported for all configurations of " + "models and can yield errors.", + config_dict["model_type"], + cls.model_type, + ) + + return cls.from_dict(config_dict, **kwargs) + + +class DbrxFFNConfig(PretrainedConfig): + """Configuration class for Dbrx FFN. + + [`DbrxFFN`] class. It is used to instantiate feedforward layers according to + the specified arguments, defining the layers architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + ffn_act_fn (dict, optional): A dict specifying activation function for the FFN. + The dict should have a key 'name' with the value being the name of + the activation function along with any additional keyword arguments. + ffn_hidden_size (int, optional): The hidden size of the feedforward network. + moe_num_experts (int, optional): The number of experts in the mixture of experts layer. + moe_top_k (int, optional): The number of experts to use in the mixture of experts layer. + moe_jitter_eps (float, optional): The jitter epsilon for the mixture of experts layer. + moe_loss_weight (float, optional): The loss weight for the mixture of experts layer. + moe_normalize_expert_weights (float, optional): The normalization factor for the expert weights. + uniform_expert_assignment (bool, optional): Whether to use uniform expert assignment. + This should only be used for benchmarking purposes. + """ + + def __init__( + self, + ffn_act_fn: Optional[dict] = None, + ffn_hidden_size: int = 3584, + moe_num_experts: int = 4, + moe_top_k: int = 1, + moe_jitter_eps: Optional[float] = None, + moe_loss_weight: float = 0.01, + moe_normalize_expert_weights: Optional[float] = 1, + uniform_expert_assignment: bool = False, + **kwargs: Any, + ): + super().__init__() + if ffn_act_fn is None: + ffn_act_fn = {"name": "silu"} + self.ffn_act_fn = ffn_act_fn + self.ffn_hidden_size = ffn_hidden_size + self.moe_num_experts = moe_num_experts + self.moe_top_k = moe_top_k + self.moe_jitter_eps = moe_jitter_eps + self.moe_loss_weight = moe_loss_weight + self.moe_normalize_expert_weights = moe_normalize_expert_weights + self.uniform_expert_assignment = uniform_expert_assignment + + for k in ["model_type"]: + if k in kwargs: + kwargs.pop(k) + if len(kwargs) != 0: + raise ValueError(f"Found unknown {kwargs=}") + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path: str, **kwargs: Any + ) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict( + pretrained_model_name_or_path, **kwargs + ) + + if config_dict.get("model_type") == "dbrx": + config_dict = config_dict["ffn_config"] + + if ( + "model_type" in config_dict + and hasattr(cls, "model_type") + and config_dict["model_type"] != cls.model_type + ): + logger.warning( + "You are using a model of type %s to instantiate a model of " + "type %s. This is not supported for all " + "configurations of models and can yield errors.", + config_dict["model_type"], + cls.model_type, + ) + + return cls.from_dict(config_dict, **kwargs) + + +class DbrxConfig(PretrainedConfig): + """Configuration class for Dbrx. + + [`DbrxModel`]. It is used to instantiate a Dbrx model according to the + specified arguments, defining the model architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. 
+ + + Args: + d_model (`int`, *optional*, defaults to 6144): + Dimensionality of the embeddings and hidden states. + n_heads (`int`, *optional*, defaults to 48): + Number of attention heads for each attention layer in the Transformer encoder. + n_layers (`int`, *optional*, defaults to 40): + Number of hidden layers in the Transformer encoder. + max_seq_len (`int`, *optional*, defaults to 32768): + The maximum sequence length of the model. + vocab_size (`int`, *optional*, defaults to 100352): + Vocabulary size of the Dbrx model. Defines the maximum number of different tokens that can be represented by + the `inputs_ids` passed when calling [`DbrxModel`]. + resid_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability applied to the attention output before combining with residual. + emb_pdrop (`float`, *optional*, defaults to 0.0): + The dropout probability for the embedding layer. + attn_config (`dict`, *optional*): + A dictionary used to configure the model's attention module. + ffn_config (`dict`, *optional*): + A dictionary used to configure the model's FFN module. + use_cache (`bool`, *optional*, defaults to `False`): + Whether or not the model should return the last key/values attentions (not used by all models). + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + output_router_logits (`bool`, *optional*, defaults to `False`): + Whether or not the router logits should be returned by the model. Enabling this will also + allow the model to output the auxiliary loss. See [here]() for more details + router_aux_loss_coef (`float`, *optional*, defaults to 0.001): + The aux loss factor for the total loss. + + + Example: + ```python + >>> from transformers import DbrxConfig, DbrxModel + + >>> # Initializing a Dbrx configuration + >>> configuration = DbrxConfig() + + >>> # Initializing a model (with random weights) from the configuration + >>> model = DbrxModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "dbrx" + attribute_map = { + "num_attention_heads": "n_heads", + "hidden_size": "d_model", + "num_hidden_layers": "n_layers", + "max_position_embeddings": "max_seq_len", + } + + def __init__( + self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int = 24, + max_seq_len: int = 2048, + vocab_size: int = 32000, + resid_pdrop: float = 0.0, + emb_pdrop: float = 0.0, + attn_config: Optional[DbrxAttentionConfig] = None, + ffn_config: Optional[DbrxFFNConfig] = None, + use_cache: bool = True, + initializer_range: float = 0.02, + output_router_logits: bool = False, + router_aux_loss_coef: float = 0.05, + **kwargs: Any, + ): + if attn_config is None: + self.attn_config = DbrxAttentionConfig() + elif isinstance(attn_config, dict): + self.attn_config = DbrxAttentionConfig(**attn_config) + else: + self.attn_config = attn_config + + if ffn_config is None: + self.ffn_config = DbrxFFNConfig() + elif isinstance(ffn_config, dict): + self.ffn_config = DbrxFFNConfig(**ffn_config) + else: + self.ffn_config = ffn_config + + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.use_cache = use_cache + self.initializer_range = initializer_range + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + + 
tie_word_embeddings = kwargs.pop("tie_word_embeddings", False) + if tie_word_embeddings: + raise ValueError("tie_word_embeddings is not supported for Dbrx models.") + + super().__init__( + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/python/sglang/srt/configs/device_config.py b/python/sglang/srt/configs/device_config.py index 74deb8919024..d95e848ddae6 100644 --- a/python/sglang/srt/configs/device_config.py +++ b/python/sglang/srt/configs/device_config.py @@ -10,7 +10,7 @@ class DeviceConfig: device: Optional[torch.device] def __init__(self, device: str = "cuda") -> None: - if device in ["cuda", "xpu", "hpu"]: + if device in ["cuda", "xpu", "hpu", "cpu"]: self.device_type = device else: raise RuntimeError(f"Not supported device type: {device}") diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py index 2b2b341faeb5..6cb35ab47c68 100644 --- a/python/sglang/srt/configs/load_config.py +++ b/python/sglang/srt/configs/load_config.py @@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum): GGUF = "gguf" BITSANDBYTES = "bitsandbytes" MISTRAL = "mistral" + LAYERED = "layered" @dataclass diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index a2f9b82844e8..6d144f84433c 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -128,7 +128,7 @@ def __init__( self.num_hidden_layers = self.hf_text_config.num_hidden_layers self.vocab_size = self.hf_text_config.vocab_size - # Veirfy quantization + # Verify quantization self._verify_quantization() # Cache attributes @@ -223,7 +223,11 @@ def _verify_quantization(self) -> None: "compressed_tensors", "compressed-tensors", "experts_int8", + "w8a8_int8", ] + compatible_quantization_methods = { + "w8a8_int8": ["compressed-tensors", "compressed_tensors"] + } if self.quantization is not None: self.quantization = self.quantization.lower() @@ -247,12 +251,17 @@ def _verify_quantization(self) -> None: if self.quantization is None: self.quantization = quant_method elif self.quantization != quant_method: - raise ValueError( - "Quantization method specified in the model config " - f"({quant_method}) does not match the quantization " - f"method specified in the `quantization` argument " - f"({self.quantization})." - ) + if ( + self.quantization not in compatible_quantization_methods + or quant_method + not in compatible_quantization_methods[self.quantization] + ): + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization})." + ) if self.quantization is not None: if self.quantization not in supported_quantization: @@ -393,6 +402,7 @@ def is_multimodal_model(model_architectures: List[str]): or "LlavaVidForCausalLM" in model_architectures or "MllamaForConditionalGeneration" in model_architectures or "Qwen2VLForConditionalGeneration" in model_architectures + or "MiniCPMV" in model_architectures ): return True else: diff --git a/python/sglang/srt/constrained/__init__.py b/python/sglang/srt/constrained/__init__.py deleted file mode 100644 index 458d19252413..000000000000 --- a/python/sglang/srt/constrained/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# TODO(lmzheng): make this an optional dependency -from sglang.srt.constrained.outlines_backend import build_regex_from_object diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py index 7c88229cf168..6f304ea171ea 100644 --- a/python/sglang/srt/constrained/base_grammar_backend.py +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -18,6 +18,8 @@ from threading import Event, Lock from typing import Any, Optional, Tuple +from sglang.srt.server_args import ServerArgs + @dataclass class CacheEntry: @@ -69,3 +71,22 @@ def get_future_value(self, key: Tuple[str, str]) -> Future: def reset(self): with self.cache_lock: self.cache.clear() + + +def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size): + if server_args.grammar_backend == "outlines": + from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend + + grammar_backend = OutlinesGrammarBackend( + tokenizer, + whitespace_pattern=server_args.constrained_json_whitespace_pattern, + allow_jump_forward=not server_args.disable_jump_forward, + ) + elif server_args.grammar_backend == "xgrammar": + from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend + + grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size) + else: + raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}") + + return grammar_backend diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py index b0b2c31c2ac9..c423a567eda8 100644 --- a/python/sglang/srt/constrained/xgrammar_backend.py +++ b/python/sglang/srt/constrained/xgrammar_backend.py @@ -19,6 +19,7 @@ import torch from xgrammar import ( CompiledGrammar, + Grammar, GrammarCompiler, GrammarMatcher, TokenizerInfo, @@ -133,10 +134,13 @@ def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar: logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}") return None elif key_type == "regex": - logger.warning( - "regex hasn't been supported by xgrammar yet. This is skipped." - ) - return None + try: + ctx = self.grammar_compiler.compile_grammar( + Grammar.from_regex(key_string) + ) + except RuntimeError as e: + logging.warning(f"Skip invalid regex: regex={key_string}, {e=}") + return None else: raise ValueError(f"Invalid key_type: {key_type}") diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 60dba87cb081..3a775aa1e95f 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -452,7 +452,6 @@ def generate_chat_conv( # Add a blank message for the assistant. 
conv.append_message(conv.roles[1], None) - return conv @@ -555,3 +554,17 @@ def generate_chat_conv( image_token="<|vision_start|><|image_pad|><|vision_end|>", ) ) + +# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage +register_conv_template( + Conversation( + name="minicpmv", + system_message="You are a helpful assistant", + system_template="<|im_start|>system\n{system_message}.", + roles=("<|im_start|>user", "<|im_start|>assistant"), + sep="<|im_end|>\n", + sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE, + stop_str=("<|im_end|>", "<|endoftext|>"), + image_token="(./)", + ) +) diff --git a/python/sglang/srt/distributed/__init__.py b/python/sglang/srt/distributed/__init__.py index db325cfabf55..12f802055c50 100644 --- a/python/sglang/srt/distributed/__init__.py +++ b/python/sglang/srt/distributed/__init__.py @@ -1,3 +1,3 @@ -from .communication_op import * -from .parallel_state import * -from .utils import * +from sglang.srt.distributed.communication_op import * +from sglang.srt.distributed.parallel_state import * +from sglang.srt.distributed.utils import * diff --git a/python/sglang/srt/distributed/communication_op.py b/python/sglang/srt/distributed/communication_op.py index ddf3b8ef5689..95600edfb410 100644 --- a/python/sglang/srt/distributed/communication_op.py +++ b/python/sglang/srt/distributed/communication_op.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py + from typing import Any, Dict, Optional, Union import torch diff --git a/python/sglang/srt/distributed/device_communicators/__init__.py b/python/sglang/srt/distributed/device_communicators/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py index ab4ee33fcfc4..c902f314112e 100644 --- a/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/cuda_wrapper.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py + """This file is a pure Python wrapper for the cudart library. It avoids the need to compile a separate shared library, and is convenient for use when we just need to call a few functions. 
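The `_custom_ops.py` hunk earlier in this diff and the `custom_all_reduce.py` changes that follow both hinge on the same dispatch idea: the `USE_VLLM_CUSTOM_ALLREDUCE` environment variable selects either the vLLM `torch.ops._C_custom_ar` kernels or the new `sgl_kernel` ops, and the wrapper functions are bound once at import time so callers pay no per-call dispatch cost. A minimal illustrative sketch of that pattern (not the actual sglang module; the boolean parsing is simplified and only two wrappers are shown):

```python
import os

import torch

# Illustrative parsing; the real module reads the raw environment value.
USE_VLLM_CUSTOM_ALLREDUCE = os.environ.get(
    "USE_VLLM_CUSTOM_ALLREDUCE", "1"
).lower() not in ("0", "false")

if USE_VLLM_CUSTOM_ALLREDUCE:

    def dispose(fa: int) -> None:
        torch.ops._C_custom_ar.dispose(fa)

    def all_reduce(
        fa: int,
        inp: torch.Tensor,
        out: torch.Tensor,
        reg_buffer: int,
        reg_buffer_sz_bytes: int,
    ) -> None:
        # Eager-mode inputs are staged through a pre-registered IPC buffer.
        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)

else:
    import sgl_kernel

    def dispose(fa: int) -> None:
        sgl_kernel.ops.custom_dispose(fa)

    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
        # Narrower signature: the TensorRT-LLM-style kernel manages its own buffers.
        sgl_kernel.ops.custom_reduce(fa, inp, out)
```

Callers such as `CustomAllreduce` in the next file then branch on `ops.use_vllm_custom_allreduce` only where the two signatures differ.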
diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py index b6df234407d8..faeac0bbae9e 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py + import ctypes import logging import os @@ -6,7 +7,6 @@ from functools import wraps from typing import Callable, List, Optional, TypeVar, Union -import pynvml import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -20,8 +20,19 @@ from sglang.srt.distributed.parallel_state import in_the_same_node_as from sglang.srt.utils import cuda_device_count_stateless, is_cuda +logger = logging.getLogger(__name__) + +if is_cuda(): + try: + import pynvml + except ImportError as e: + logger.warning("Failed to import pynvml with %r", e) + try: - ops.meta_size() + if ops.use_vllm_custom_allreduce: + ops.meta_size() + else: + import sgl_kernel custom_ar = True except Exception: # For AMD GPUs and CPUs @@ -29,7 +40,6 @@ logger = logging.getLogger(__name__) - _P = ParamSpec("_P") _R = TypeVar("_R") @@ -47,7 +57,7 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R: @with_nvml_context -def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: +def is_full_nvlink(physical_device_ids: List[int]) -> bool: """ query if the set of gpus are fully connected by nvlink (1 hop) """ @@ -175,9 +185,12 @@ def __init__( # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - assert is_cuda() + if is_cuda(): + assert is_cuda() - full_nvlink = is_full_nvlink(physical_device_ids) + full_nvlink = is_full_nvlink(physical_device_ids) + else: + full_nvlink = False if world_size > 2 and not full_nvlink: logger.warning( "Custom allreduce is disabled because it's not supported on" @@ -196,32 +209,64 @@ def __init__( ) return - self.disabled = False - # Buffers memory are owned by this Python class and passed to C++. - # Meta data composes of two parts: meta data for synchronization and a - # temporary buffer for storing intermediate allreduce results. - self.meta_ptrs = self.create_shared_buffer( - ops.meta_size() + max_size, group=group - ) - # This is a pre-registered IPC buffer. In eager mode, input tensors - # are first copied into this buffer before allreduce is performed - self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) - # This is a buffer for storing the tuples of pointers pointing to - # IPC buffers from all ranks. Each registered tuple has size of - # 8*world_size bytes where world_size is at most 8. Allocating 8MB - # is enough for 131072 such tuples. The largest model I've seen only - # needs less than 10000 of registered tuples. 
- self.rank_data = torch.empty( - 8 * 1024 * 1024, dtype=torch.uint8, device=self.device - ) self.max_size = max_size self.rank = rank self.world_size = world_size self.full_nvlink = full_nvlink - self._ptr = ops.init_custom_ar( - self.meta_ptrs, self.rank_data, rank, self.full_nvlink - ) - ops.register_buffer(self._ptr, self.buffer_ptrs) + + if ops.use_vllm_custom_allreduce: + # Buffers memory are owned by this Python class and passed to C++. + # Meta data composes of two parts: meta data for synchronization and a + # temporary buffer for storing intermediate allreduce results. + self.meta_ptrs = self.create_shared_buffer( + ops.meta_size() + max_size, group=group + ) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + ) + self._ptr = ops.init_custom_ar( + self.meta_ptrs, self.rank_data, rank, self.full_nvlink + ) + ops.register_buffer(self._ptr, self.buffer_ptrs) + else: + # From TensorRT-LLM getMaxRequiredWorkspaceSize + self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024] + + # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE; + self.barrier_max_size = 8 * (36 + 2) * 8 + + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + self.tmp_result_buffer_ptrs = self.create_shared_buffer( + max_size, group=group + ) + self.rank_data_base = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=self.device + ) + self.barrier_in_ptrs = self.create_shared_buffer( + self.barrier_max_size, group=group + ) + self.barrier_out_ptrs = self.create_shared_buffer( + self.barrier_max_size, group=group + ) + + self._ptr = ops.init_custom_ar( + rank, + world_size, + self.rank_data_base, + self.buffer_ptrs, + self.tmp_result_buffer_ptrs, + self.barrier_in_ptrs, + self.barrier_out_ptrs, + ) + self.disabled = False @staticmethod def create_shared_buffer( @@ -300,12 +345,31 @@ def should_custom_ar(self, inp: torch.Tensor): return False # for 4 or more non NVLink-capable GPUs, custom allreduce provides # little performance improvement over NCCL. - if self.world_size == 2 or self.full_nvlink: - return inp_size < self.max_size + if ops.use_vllm_custom_allreduce: + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False + + if self.world_size == 2: + return ( + inp_size < self.max_size + and inp_size < self.max_required_workspace_size[0] + ) + + if self.full_nvlink: + return ( + inp_size < self.max_size + and inp_size < self.max_required_workspace_size[1] + ) + return False def all_reduce( - self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False + self, + inp: torch.Tensor, + *, + out: torch.Tensor = None, + registered: bool = False, ): """Performs an out-of-place all reduce. 
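The new `should_custom_ar` logic in the hunk above can be read as a small eligibility predicate: with the `sgl_kernel` backend, the custom kernel is used only when the tensor fits into the pre-registered shared buffer and stays under the TensorRT-LLM workspace limit for the topology (16 MiB for a 2-rank group, 8 MiB when a larger group is fully NVLink-connected). A hedged, standalone restatement (the function name and the `max_size` default are illustrative; the buffer-alignment and dtype checks that precede this in the real method are omitted):

```python
def sgl_kernel_custom_ar_eligible(
    inp_size_bytes: int,
    world_size: int,
    full_nvlink: bool,
    max_size: int = 8 * 1024 * 1024,  # assumed shared-buffer size
) -> bool:
    # From TensorRT-LLM getMaxRequiredWorkspaceSize: [2-rank limit, NVLink limit].
    max_required_workspace_size = (16 * 1024 * 1024, 8 * 1024 * 1024)
    if world_size == 2:
        return (
            inp_size_bytes < max_size
            and inp_size_bytes < max_required_workspace_size[0]
        )
    if full_nvlink:
        return (
            inp_size_bytes < max_size
            and inp_size_bytes < max_required_workspace_size[1]
        )
    # For 4+ non-NVLink GPUs, custom allreduce gains little over NCCL.
    return False
```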
@@ -315,12 +379,15 @@ def all_reduce( """ if out is None: out = torch.empty_like(inp) - if registered: - ops.all_reduce(self._ptr, inp, out, 0, 0) + if ops.use_vllm_custom_allreduce: + if registered: + ops.all_reduce(self._ptr, inp, out, 0, 0) + else: + ops.all_reduce( + self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size + ) else: - ops.all_reduce( - self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size - ) + ops.all_reduce(self._ptr, inp, out) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: @@ -336,17 +403,20 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: # allreduce is out-of-place. return torch.empty_like(input) else: - # Note: outside of cuda graph context, custom allreduce incurs a - # cost of cudaMemcpy, which should be small (<=1% of overall - # latency) compared to the performance gain of using custom kernels return self.all_reduce(input, registered=False) def close(self): if not self.disabled and self._ptr: ops.dispose(self._ptr) + if ops.use_vllm_custom_allreduce: + self.free_shared_buffer(self.meta_ptrs) + self.free_shared_buffer(self.buffer_ptrs) + else: + self.free_shared_buffer(self.buffer_ptrs) + self.free_shared_buffer(self.tmp_result_buffer_ptrs) + self.free_shared_buffer(self.barrier_in_ptrs) + self.free_shared_buffer(self.barrier_out_ptrs) self._ptr = 0 - self.free_shared_buffer(self.meta_ptrs) - self.free_shared_buffer(self.buffer_ptrs) def __del__(self): self.close() diff --git a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py index d807dfd5ce59..4073491aa621 100644 --- a/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +++ b/python/sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py + import ctypes import json import logging @@ -7,7 +8,6 @@ import subprocess import sys import tempfile -from functools import lru_cache from itertools import product from typing import Dict, List, Optional, Sequence diff --git a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py index 72ef3889e014..722e494cf775 100644 --- a/python/sglang/srt/distributed/device_communicators/hpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/hpu_communicator.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py + import torch import torch.distributed as dist from torch.distributed import ProcessGroup diff --git a/python/sglang/srt/distributed/device_communicators/pynccl.py b/python/sglang/srt/distributed/device_communicators/pynccl.py index baee270da907..9f65939f6d91 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl.py @@ -1,8 +1,10 @@ -# Adapted from 
https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py + import logging from contextlib import contextmanager from typing import Optional, Union +# ===================== import region ===================== import torch import torch.distributed as dist from torch.distributed import ProcessGroup, ReduceOp @@ -143,6 +145,57 @@ def all_reduce( cudaStream_t(stream.cuda_stream), ) + def all_gather( + self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + + def reduce_scatter( + self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None, + ): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}" + ) + if stream is None: + stream = self.stream + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), + output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), + self.comm, + cudaStream_t(stream.cuda_stream), + ) + def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return @@ -179,6 +232,32 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): cudaStream_t(stream.cuda_stream), ) + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}" + ) + if stream is None: + stream = self.stream + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast( + sendbuff, + recvbuff, + tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), + src, + self.comm, + cudaStream_t(stream.cuda_stream), + ) + @contextmanager def change_state( self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None diff --git a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py index e72284f51178..afb47733476a 100644 --- a/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py +++ b/python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py @@ -1,4 +1,4 @@ -# Adapted from 
https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py # This file is a pure Python wrapper for the NCCL library. # The main purpose is to use NCCL combined with CUDA graph. @@ -57,7 +57,7 @@ def find_nccl_library() -> str: so_file = "librccl.so.1" else: raise ValueError("NCCL only supports CUDA and ROCm backends.") - logger.info("Found nccl from library %s", so_file) + logger.debug("Found nccl from library %s", so_file) return so_file @@ -187,6 +187,43 @@ class NCCLLibrary: cudaStream_t, ], ), + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclAllGather", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclComm_t, + cudaStream_t, + ], + ), + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function( + "ncclReduceScatter", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ncclRedOp_t, + ncclComm_t, + cudaStream_t, + ], + ), # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); @@ -217,6 +254,23 @@ class NCCLLibrary: cudaStream_t, ], ), + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function( + "ncclBroadcast", + ncclResult_t, + [ + buffer_type, + buffer_type, + ctypes.c_size_t, + ncclDataType_t, + ctypes.c_int, + ncclComm_t, + cudaStream_t, + ], + ), # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. 
# because Python object destruction can happen in random order, @@ -321,6 +375,46 @@ def ncclAllReduce( ) ) + def ncclReduceScatter( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + op: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclReduceScatter"]( + sendbuff, recvbuff, count, datatype, op, comm, stream + ) + ) + + def ncclAllGather( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK( + self._funcs["ncclAllGather"]( + sendbuff, recvbuff, count, datatype, comm, stream + ) + ) + def ncclSend( self, sendbuff: buffer_type, @@ -347,6 +441,22 @@ def ncclRecv( self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream) ) + def ncclBroadcast( + self, + sendbuff: buffer_type, + recvbuff: buffer_type, + count: int, + datatype: int, + root: int, + comm: ncclComm_t, + stream: cudaStream_t, + ) -> None: + self.NCCL_CHECK( + self._funcs["ncclBroadcast"]( + sendbuff, recvbuff, count, datatype, root, comm, stream + ) + ) + def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) diff --git a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py index 1afe6fca5266..7a3b22e27a81 100644 --- a/python/sglang/srt/distributed/device_communicators/shm_broadcast.py +++ b/python/sglang/srt/distributed/device_communicators/shm_broadcast.py @@ -1,11 +1,9 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/shm_broadcast.py -import ipaddress +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/shm_broadcast.py + import logging import os import pickle -import socket import time -import warnings from contextlib import contextmanager from dataclasses import dataclass, field from multiprocessing import shared_memory @@ -18,6 +16,8 @@ from zmq import IPV6 # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore +from sglang.srt.utils import get_ip, get_open_port, is_valid_ipv6_address + # SGLANG_RINGBUFFER_WARNING_INTERVAL can be set to 60 SGLANG_RINGBUFFER_WARNING_INTERVAL = int( os.environ.get("SGLANG_RINGBUFFER_WARNING_INTERVAL", "60") @@ -26,73 +26,6 @@ logger = logging.getLogger(__name__) -def get_ip() -> str: - # SGLANG_HOST_IP env can be ignore - host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "") - if host_ip: - return host_ip - - # IP is not set, try to get it from the network interface - - # try ipv4 - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable - return s.getsockname()[0] - except Exception: - pass - - # try ipv6 - try: - s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) - # Google's public DNS server, see - # https://developers.google.com/speed/public-dns/docs/using#addresses - 
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable - return s.getsockname()[0] - except Exception: - pass - - warnings.warn( - "Failed to get the IP address, using 0.0.0.0 by default." - "The value can be set by the environment variable" - " SGLANG_HOST_IP or HOST_IP.", - stacklevel=2, - ) - return "0.0.0.0" - - -def get_open_port() -> int: - - port = os.getenv("SGLANG_PORT") - if port is not None: - while True: - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", port)) - return port - except OSError: - port += 1 # Increment port number if already in use - logger.info("Port %d is already in use, trying port %d", port - 1, port) - # try ipv4 - try: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - except OSError: - # try ipv6 - with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - -def is_valid_ipv6_address(address: str) -> bool: - try: - ipaddress.IPv6Address(address) - return True - except ValueError: - return False - - class ShmRingBuffer: def __init__( @@ -313,7 +246,7 @@ def __init__( remote_subscribe_port=remote_subscribe_port, ) - logger.info("vLLM message queue communication handle: %s", self.handle) + logger.debug("Message queue communication handle: %s", self.handle) def export_handle(self) -> Handle: return self.handle diff --git a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py index ff0981b80bc8..532279f70c35 100644 --- a/python/sglang/srt/distributed/device_communicators/xpu_communicator.py +++ b/python/sglang/srt/distributed/device_communicators/xpu_communicator.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/xpu_communicator.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/xpu_communicator.py + import torch import torch.distributed as dist from torch.distributed import ProcessGroup diff --git a/python/sglang/srt/distributed/parallel_state.py b/python/sglang/srt/distributed/parallel_state.py index 26d04b04ce91..c6d1a8307818 100644 --- a/python/sglang/srt/distributed/parallel_state.py +++ b/python/sglang/srt/distributed/parallel_state.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/parallel_state.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/parallel_state.py # Copyright 2023 The vLLM team. # Adapted from diff --git a/python/sglang/srt/distributed/utils.py b/python/sglang/srt/distributed/utils.py index a225fbb91820..e117aa30d073 100644 --- a/python/sglang/srt/distributed/utils.py +++ b/python/sglang/srt/distributed/utils.py @@ -1,4 +1,5 @@ -# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/utils.py + # Copyright 2023 The vLLM team. 
# Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py new file mode 100644 index 000000000000..310e92c23d95 --- /dev/null +++ b/python/sglang/srt/entrypoints/engine.py @@ -0,0 +1,449 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +The entry point of inference server. (SRT = SGLang Runtime) + +This file implements python APIs for the inference engine. +""" + +import asyncio +import atexit +import dataclasses +import logging +import multiprocessing as mp +import os +import signal +import threading +from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union + +# Fix a bug of Python threading +setattr(threading, "_register_atexit", lambda *args, **kwargs: None) + +import torch +import uvloop + +from sglang.srt.managers.data_parallel_controller import ( + run_data_parallel_controller_process, +) +from sglang.srt.managers.detokenizer_manager import run_detokenizer_process +from sglang.srt.managers.io_struct import ( + EmbeddingReqInput, + GenerateReqInput, + GetWeightsByNameReqInput, + InitWeightsUpdateGroupReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromTensorReqInput, +) +from sglang.srt.managers.scheduler import run_scheduler_process +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api +from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter +from sglang.srt.utils import ( + MultiprocessingSerializer, + assert_pkg_version, + configure_logger, + kill_process_tree, + maybe_set_triton_cache_manager, + prepare_model_and_tokenizer, + set_prometheus_multiproc_dir, + set_ulimit, +) +from sglang.version import __version__ + +logger = logging.getLogger(__name__) +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +class Engine: + """ + The entry point to the inference engine. + + - The engine consists of three components: + 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. + 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. + 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. + + Note: + 1. The HTTP server, Engine, and TokenizerManager both run in the main process. + 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. + """ + + def __init__(self, **kwargs): + """ + The arguments of this function is the same as `sglang/srt/server_args.py::ServerArgs`. + Please refer to `ServerArgs` for the documentation. 
+ """ + if "server_args" in kwargs: + # Directly load server_args + server_args = kwargs["server_args"] + else: + # Construct server_args from kwargs + if "log_level" not in kwargs: + # Do not print logs by default + kwargs["log_level"] = "error" + server_args = ServerArgs(**kwargs) + + # Shutdown the subprocesses automatically when the program exists + atexit.register(self.shutdown) + + # Launch subprocesses + tokenizer_manager, scheduler_info = _launch_subprocesses( + server_args=server_args + ) + self.tokenizer_manager = tokenizer_manager + self.scheduler_info = scheduler_info + + def generate( + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + stream: bool = False, + ) -> Union[Dict, Iterator[Dict]]: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. + Please refer to `GenerateReqInput` for the documentation. + """ + obj = GenerateReqInput( + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + lora_path=lora_path, + custom_logit_processor=custom_logit_processor, + stream=stream, + ) + loop = asyncio.get_event_loop() + generator = self.tokenizer_manager.generate_request(obj, None) + + if stream: + + def generator_wrapper(): + while True: + try: + chunk = loop.run_until_complete(generator.__anext__()) + yield chunk + except StopAsyncIteration: + break + + return generator_wrapper() + else: + ret = loop.run_until_complete(generator.__anext__()) + return ret + + async def async_generate( + self, + # The input prompt. It can be a single prompt or a batch of prompts. + prompt: Optional[Union[List[str], str]] = None, + sampling_params: Optional[Union[List[Dict], Dict]] = None, + # The token ids for text; one can either specify text or input_ids. + input_ids: Optional[Union[List[List[int]], List[int]]] = None, + return_logprob: Optional[Union[List[bool], bool]] = False, + logprob_start_len: Optional[Union[List[int], int]] = None, + top_logprobs_num: Optional[Union[List[int], int]] = None, + lora_path: Optional[List[Optional[str]]] = None, + custom_logit_processor: Optional[Union[List[str], str]] = None, + stream: bool = False, + ) -> Union[Dict, AsyncIterator[Dict]]: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`. + Please refer to `GenerateReqInput` for the documentation. 
+ """ + obj = GenerateReqInput( + text=prompt, + input_ids=input_ids, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=logprob_start_len, + top_logprobs_num=top_logprobs_num, + lora_path=lora_path, + stream=stream, + custom_logit_processor=custom_logit_processor, + ) + generator = self.tokenizer_manager.generate_request(obj, None) + + if stream is True: + return generator + else: + return await generator.__anext__() + + def encode( + self, + prompt: Union[str, List[str], List[Dict], List[List[Dict]]], + ) -> Dict: + """ + The arguments of this function is the same as `sglang/srt/managers/io_struct.py::EmbeddingReqInput`. + Please refer to `EmbeddingReqInput` for the documentation. + """ + + obj = EmbeddingReqInput(text=prompt) + loop = asyncio.get_event_loop() + generator = self.tokenizer_manager.generate_request(obj, None) + ret = loop.run_until_complete(generator.__anext__()) + return ret + + def shutdown(self): + """Shutdown the engine""" + kill_process_tree(os.getpid(), include_parent=False) + + def start_profile(self): + self.tokenizer_manager.start_profile() + + def stop_profile(self): + self.tokenizer_manager.stop_profile() + + def get_server_info(self): + return { + **dataclasses.asdict(self.tokenizer_manager.server_args), # server args + **self.scheduler_info, + "version": __version__, + } + + def init_weights_update_group( + self, + master_address: str, + master_port: int, + rank_offset: int, + world_size: int, + group_name: str, + backend: str = "nccl", + ): + """Initialize parameter update group.""" + obj = InitWeightsUpdateGroupReqInput( + master_address=master_address, + master_port=master_port, + rank_offset=rank_offset, + world_size=world_size, + group_name=group_name, + backend=backend, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.init_weights_update_group(obj, None) + ) + + def update_weights_from_distributed(self, name: str, dtype, shape): + """Update weights from distributed source.""" + obj = UpdateWeightsFromDistributedReqInput( + name=name, + dtype=dtype, + shape=shape, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.update_weights_from_distributed(obj, None) + ) + + def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): + """Update weights from distributed source.""" + obj = UpdateWeightsFromTensorReqInput( + serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors) + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.update_weights_from_tensor(obj, None) + ) + + def get_weights_by_name(self, name: str, truncate_size: int = 100): + """Get weights by parameter name.""" + obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.get_weights_by_name(obj, None) + ) + + def release_memory_occupation(self): + """Release GPU occupation temporarily.""" + obj = ReleaseMemoryOccupationReqInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.release_memory_occupation(obj, None) + ) + + def resume_memory_occupation(self): + """Resume GPU occupation.""" + obj = ResumeMemoryOccupationReqInput() + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.resume_memory_occupation(obj, None) + ) + + +def _set_envs_and_config(server_args: ServerArgs): + # Set global 
environments + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" + os.environ["NCCL_CUMEM_ENABLE"] = "0" + os.environ["NCCL_NVLS_ENABLE"] = "0" + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" + + # Set prometheus env vars + if server_args.enable_metrics: + set_prometheus_multiproc_dir() + + # Set ulimit + set_ulimit() + + # Fix triton bugs + if server_args.tp_size * server_args.dp_size > 1: + # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. + maybe_set_triton_cache_manager() + + # Check flashinfer version + if server_args.attention_backend == "flashinfer": + assert_pkg_version( + "flashinfer", + "0.1.6", + "Please uninstall the old version and " + "reinstall the latest version by following the instructions " + "at https://docs.flashinfer.ai/installation.html.", + ) + + # Register the signal handler. + # The child processes will send SIGQUIT to this process when any error happens + # This process then clean up the whole process tree + def sigquit_handler(signum, frame): + logger.error( + "Received sigquit from a child proces. It usually means the child failed." + ) + kill_process_tree(os.getpid()) + + signal.signal(signal.SIGQUIT, sigquit_handler) + + # Set mp start method + mp.set_start_method("spawn", force=True) + + +def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]: + """ + Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. + """ + # Configure global environment + configure_logger(server_args) + server_args.check_server_args() + _set_envs_and_config(server_args) + + # Allocate ports for inter-process communications + port_args = PortArgs.init_new(server_args) + logger.info(f"{server_args=}") + + # If using model from www.modelscope.cn, first download the model. + server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( + server_args.model_path, server_args.tokenizer_path + ) + + scheduler_procs = [] + if server_args.dp_size == 1: + # Launch tensor parallel scheduler processes + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + + scheduler_pipe_readers = [] + tp_size_per_node = server_args.tp_size // server_args.nnodes + tp_rank_range = range( + tp_size_per_node * server_args.node_rank, + tp_size_per_node * (server_args.node_rank + 1), + ) + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node + proc = mp.Process( + target=run_scheduler_process, + args=(server_args, port_args, gpu_id, tp_rank, None, writer), + ) + with memory_saver_adapter.configure_subprocess(): + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) + else: + # Launch the data parallel controller + reader, writer = mp.Pipe(duplex=False) + scheduler_pipe_readers = [reader] + proc = mp.Process( + target=run_data_parallel_controller_process, + args=(server_args, port_args, writer), + ) + proc.start() + scheduler_procs.append(proc) + + if server_args.node_rank >= 1: + # In multi-node cases, non-zero rank nodes do not need to run tokenizer or detokenizer, + # so they can just wait here. + + for reader in scheduler_pipe_readers: + data = reader.recv() + assert data["status"] == "ready" + + if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0": + # When using `Engine` as a Python API, we don't want to block here. 
+ return + + for proc in scheduler_procs: + proc.join() + logger.error( + f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" + ) + return + + # Launch detokenizer process + detoken_proc = mp.Process( + target=run_detokenizer_process, + args=( + server_args, + port_args, + ), + ) + detoken_proc.start() + + # Launch tokenizer process + tokenizer_manager = TokenizerManager(server_args, port_args) + if server_args.chat_template: + load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) + + # Wait for the model to finish loading + scheduler_infos = [] + for i in range(len(scheduler_pipe_readers)): + try: + data = scheduler_pipe_readers[i].recv() + except EOFError: + logger.error( + f"Rank {i} scheduler is dead. Please check if there are relevant logs." + ) + scheduler_procs[i].join() + logger.error(f"Exit code: {scheduler_procs[i].exitcode}") + raise + + if data["status"] != "ready": + raise RuntimeError( + "Initialization failed. Please see the error messages above." + ) + scheduler_infos.append(data) + + # Assume all schedulers have the same scheduler_info + scheduler_info = scheduler_infos[0] + tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"] + return tokenizer_manager, scheduler_info diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py new file mode 100644 index 000000000000..0ebce1a85d55 --- /dev/null +++ b/python/sglang/srt/entrypoints/http_server.py @@ -0,0 +1,579 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +The entry point of inference server. (SRT = SGLang Runtime) + +This file implements HTTP APIs for the inferenc engine via fastapi. 
+""" + +import asyncio +import dataclasses +import logging +import multiprocessing as multiprocessing +import os +import threading +import time +from http import HTTPStatus +from typing import AsyncIterator, Dict, Optional + +# Fix a bug of Python threading +setattr(threading, "_register_atexit", lambda *args, **kwargs: None) + +import orjson +import requests +import uvicorn +import uvloop +from fastapi import FastAPI, File, Form, Request, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import ORJSONResponse, Response, StreamingResponse + +from sglang.srt.entrypoints.engine import _launch_subprocesses +from sglang.srt.managers.io_struct import ( + CloseSessionReqInput, + ConfigureLoggingReq, + EmbeddingReqInput, + GenerateReqInput, + GetWeightsByNameReqInput, + InitWeightsUpdateGroupReqInput, + OpenSessionReqInput, + ReleaseMemoryOccupationReqInput, + ResumeMemoryOccupationReqInput, + UpdateWeightFromDiskReqInput, + UpdateWeightsFromDistributedReqInput, +) +from sglang.srt.managers.tokenizer_manager import TokenizerManager +from sglang.srt.metrics.func_timer import enable_func_timer +from sglang.srt.openai_api.adapter import ( + v1_batches, + v1_cancel_batch, + v1_chat_completions, + v1_completions, + v1_delete_file, + v1_embeddings, + v1_files_create, + v1_retrieve_batch, + v1_retrieve_file, + v1_retrieve_file_content, +) +from sglang.srt.openai_api.protocol import ModelCard, ModelList +from sglang.srt.server_args import ServerArgs +from sglang.srt.utils import ( + add_api_key_middleware, + add_prometheus_middleware, + delete_directory, + kill_process_tree, + set_uvicorn_logging_configs, +) +from sglang.utils import get_exception_traceback +from sglang.version import __version__ + +logger = logging.getLogger(__name__) +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +# Fast API +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# Store global states +@dataclasses.dataclass +class _GlobalState: + tokenizer_manager: TokenizerManager + scheduler_info: Dict + + +_global_state: Optional[_GlobalState] = None + + +def set_global_state(global_state: _GlobalState): + global _global_state + _global_state = global_state + + +##### Native API endpoints ##### + + +@app.get("/health") +async def health() -> Response: + """Check the health of the http server.""" + return Response(status_code=200) + + +@app.get("/health_generate") +async def health_generate(request: Request) -> Response: + """Check the health of the inference server by generating one token.""" + + sampling_params = {"max_new_tokens": 1, "temperature": 0.7} + + if _global_state.tokenizer_manager.is_generation: + gri = GenerateReqInput( + input_ids=[0], sampling_params=sampling_params, log_metrics=False + ) + else: + gri = EmbeddingReqInput( + input_ids=[0], sampling_params=sampling_params, log_metrics=False + ) + + try: + async for _ in _global_state.tokenizer_manager.generate_request(gri, request): + break + return Response(status_code=200) + except Exception as e: + logger.exception(e) + return Response(status_code=503) + + +@app.get("/get_model_info") +async def get_model_info(): + """Get the model information.""" + result = { + "model_path": _global_state.tokenizer_manager.model_path, + "tokenizer_path": _global_state.tokenizer_manager.server_args.tokenizer_path, + "is_generation": _global_state.tokenizer_manager.is_generation, + } + return result + + +@app.get("/get_server_info") 
+async def get_server_info(): + return { + **dataclasses.asdict(_global_state.tokenizer_manager.server_args), + **_global_state.scheduler_info, + "version": __version__, + } + + +# fastapi implicitly converts json in the request to obj (dataclass) +@app.api_route("/generate", methods=["POST", "PUT"]) +async def generate_request(obj: GenerateReqInput, request: Request): + """Handle a generate request.""" + if obj.stream: + + async def stream_results() -> AsyncIterator[bytes]: + try: + async for out in _global_state.tokenizer_manager.generate_request( + obj, request + ): + yield b"data: " + orjson.dumps( + out, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + except ValueError as e: + out = {"error": {"message": str(e)}} + yield b"data: " + orjson.dumps( + out, option=orjson.OPT_NON_STR_KEYS + ) + b"\n\n" + yield b"data: [DONE]\n\n" + + return StreamingResponse( + stream_results(), + media_type="text/event-stream", + background=_global_state.tokenizer_manager.create_abort_task(obj), + ) + else: + try: + ret = await _global_state.tokenizer_manager.generate_request( + obj, request + ).__anext__() + return ret + except ValueError as e: + logger.error(f"Error: {e}") + return _create_error_response(e) + + +@app.api_route("/encode", methods=["POST", "PUT"]) +async def encode_request(obj: EmbeddingReqInput, request: Request): + """Handle an embedding request.""" + try: + ret = await _global_state.tokenizer_manager.generate_request( + obj, request + ).__anext__() + return ret + except ValueError as e: + return _create_error_response(e) + + +@app.api_route("/classify", methods=["POST", "PUT"]) +async def classify_request(obj: EmbeddingReqInput, request: Request): + """Handle a reward model request. Now the arguments and return values are the same as embedding models.""" + try: + ret = await _global_state.tokenizer_manager.generate_request( + obj, request + ).__anext__() + return ret + except ValueError as e: + return _create_error_response(e) + + +@app.post("/flush_cache") +async def flush_cache(): + """Flush the radix cache.""" + _global_state.tokenizer_manager.flush_cache() + return Response( + content="Cache flushed.\nPlease check backend logs for more details. " + "(When there are running or waiting requests, the operation will not be performed.)\n", + status_code=200, + ) + + +@app.api_route("/start_profile", methods=["GET", "POST"]) +async def start_profile_async(): + """Start profiling.""" + _global_state.tokenizer_manager.start_profile() + return Response( + content="Start profiling.\n", + status_code=200, + ) + + +@app.api_route("/stop_profile", methods=["GET", "POST"]) +async def stop_profile_async(): + """Stop profiling.""" + _global_state.tokenizer_manager.stop_profile() + return Response( + content="Stop profiling. 
This will take some time.\n", + status_code=200, + ) + + +@app.post("/update_weights_from_disk") +async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): + """Update the weights from disk in-place without re-launching the server.""" + success, message = await _global_state.tokenizer_manager.update_weights_from_disk( + obj, request + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse( + content, + status_code=HTTPStatus.OK, + ) + else: + return ORJSONResponse( + content, + status_code=HTTPStatus.BAD_REQUEST, + ) + + +@app.post("/init_weights_update_group") +async def init_weights_update_group( + obj: InitWeightsUpdateGroupReqInput, request: Request +): + """Initialize the parameter update group.""" + success, message = await _global_state.tokenizer_manager.init_weights_update_group( + obj, request + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse(content, status_code=200) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + +@app.post("/update_weights_from_distributed") +async def update_weights_from_distributed( + obj: UpdateWeightsFromDistributedReqInput, request: Request +): + """Update model parameter from distributed online.""" + success, message = ( + await _global_state.tokenizer_manager.update_weights_from_distributed( + obj, request + ) + ) + content = {"success": success, "message": message} + if success: + return ORJSONResponse(content, status_code=200) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + +@app.api_route("/get_weights_by_name", methods=["GET", "POST"]) +async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): + """Get model parameter by name.""" + try: + ret = await _global_state.tokenizer_manager.get_weights_by_name(obj, request) + if ret is None: + return _create_error_response("Get parameter by name failed") + else: + return ORJSONResponse(ret, status_code=200) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/release_memory_occupation", methods=["GET", "POST"]) +async def release_memory_occupation( + obj: ReleaseMemoryOccupationReqInput, request: Request +): + """Release GPU occupation temporarily""" + try: + await _global_state.tokenizer_manager.release_memory_occupation(obj, request) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/resume_memory_occupation", methods=["GET", "POST"]) +async def resume_memory_occupation( + obj: ResumeMemoryOccupationReqInput, request: Request +): + """Resume GPU occupation""" + try: + await _global_state.tokenizer_manager.resume_memory_occupation(obj, request) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/open_session", methods=["GET", "POST"]) +async def open_session(obj: OpenSessionReqInput, request: Request): + """Open a session, and return its unique session id.""" + try: + session_id = await _global_state.tokenizer_manager.open_session(obj, request) + if session_id is None: + raise Exception( + "Failed to open the session. Check if a session with the same id is still open." 
+ ) + return session_id + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/close_session", methods=["GET", "POST"]) +async def close_session(obj: CloseSessionReqInput, request: Request): + """Close the session""" + try: + await _global_state.tokenizer_manager.close_session(obj, request) + return Response(status_code=200) + except Exception as e: + return _create_error_response(e) + + +@app.api_route("/configure_logging", methods=["GET", "POST"]) +async def configure_logging(obj: ConfigureLoggingReq, request: Request): + """Close the session""" + _global_state.tokenizer_manager.configure_logging(obj) + return Response(status_code=200) + + +##### OpenAI-compatible API endpoints ##### + + +@app.post("/v1/completions") +async def openai_v1_completions(raw_request: Request): + return await v1_completions(_global_state.tokenizer_manager, raw_request) + + +@app.post("/v1/chat/completions") +async def openai_v1_chat_completions(raw_request: Request): + return await v1_chat_completions(_global_state.tokenizer_manager, raw_request) + + +@app.post("/v1/embeddings", response_class=ORJSONResponse) +async def openai_v1_embeddings(raw_request: Request): + response = await v1_embeddings(_global_state.tokenizer_manager, raw_request) + return response + + +@app.get("/v1/models", response_class=ORJSONResponse) +def available_models(): + """Show available models.""" + served_model_names = [_global_state.tokenizer_manager.served_model_name] + model_cards = [] + for served_model_name in served_model_names: + model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) + return ModelList(data=model_cards) + + +@app.post("/v1/files") +async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): + return await v1_files_create( + file, purpose, _global_state.tokenizer_manager.server_args.file_storage_pth + ) + + +@app.delete("/v1/files/{file_id}") +async def delete_file(file_id: str): + # https://platform.openai.com/docs/api-reference/files/delete + return await v1_delete_file(file_id) + + +@app.post("/v1/batches") +async def openai_v1_batches(raw_request: Request): + return await v1_batches(_global_state.tokenizer_manager, raw_request) + + +@app.post("/v1/batches/{batch_id}/cancel") +async def cancel_batches(batch_id: str): + # https://platform.openai.com/docs/api-reference/batch/cancel + return await v1_cancel_batch(_global_state.tokenizer_manager, batch_id) + + +@app.get("/v1/batches/{batch_id}") +async def retrieve_batch(batch_id: str): + return await v1_retrieve_batch(batch_id) + + +@app.get("/v1/files/{file_id}") +async def retrieve_file(file_id: str): + # https://platform.openai.com/docs/api-reference/files/retrieve + return await v1_retrieve_file(file_id) + + +@app.get("/v1/files/{file_id}/content") +async def retrieve_file_content(file_id: str): + # https://platform.openai.com/docs/api-reference/files/retrieve-contents + return await v1_retrieve_file_content(file_id) + + +def _create_error_response(e): + return ORJSONResponse( + {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST + ) + + +def launch_server( + server_args: ServerArgs, + pipe_finish_writer: Optional[multiprocessing.connection.Connection] = None, +): + """ + Launch SRT (SGLang Runtime) Server. + + The SRT server consists of an HTTP server and an SRT engine. + + - HTTP server: A FastAPI server that routes requests to the engine. + - The engine consists of three components: + 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. + 2. 
Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. + 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. + + Note: + 1. The HTTP server, Engine, and TokenizerManager both run in the main process. + 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. + """ + tokenizer_manager, scheduler_info = _launch_subprocesses(server_args=server_args) + set_global_state( + _GlobalState( + tokenizer_manager=tokenizer_manager, + scheduler_info=scheduler_info, + ) + ) + + # Add api key authorization + if server_args.api_key: + add_api_key_middleware(app, server_args.api_key) + + # Add prometheus middleware + if server_args.enable_metrics: + add_prometheus_middleware(app) + enable_func_timer() + + # Send a warmup request + t = threading.Thread( + target=_wait_and_warmup, + args=( + server_args, + pipe_finish_writer, + _global_state.tokenizer_manager.image_token_id, + ), + ) + t.start() + + try: + # Update logging configs + set_uvicorn_logging_configs() + + # Listen for HTTP requests + uvicorn.run( + app, + host=server_args.host, + port=server_args.port, + log_level=server_args.log_level_http or server_args.log_level, + timeout_keep_alive=5, + loop="uvloop", + ) + finally: + t.join() + + +def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text): + headers = {} + url = server_args.url() + if server_args.api_key: + headers["Authorization"] = f"Bearer {server_args.api_key}" + + # Wait until the server is launched + success = False + for _ in range(120): + time.sleep(1) + try: + res = requests.get(url + "/get_model_info", timeout=5, headers=headers) + assert res.status_code == 200, f"{res=}, {res.text=}" + success = True + break + except (AssertionError, requests.exceptions.RequestException): + last_traceback = get_exception_traceback() + pass + + if not success: + if pipe_finish_writer is not None: + pipe_finish_writer.send(last_traceback) + logger.error(f"Initialization failed. warmup error: {last_traceback}") + kill_process_tree(os.getpid()) + return + + model_info = res.json() + + # Send a warmup request + request_name = "/generate" if model_info["is_generation"] else "/encode" + max_new_tokens = 8 if model_info["is_generation"] else 1 + json_data = { + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + } + if server_args.skip_tokenizer_init: + json_data["input_ids"] = [10, 11, 12] + else: + json_data["text"] = "The capital city of France is" + + try: + for _ in range(server_args.dp_size): + res = requests.post( + url + request_name, + json=json_data, + headers=headers, + timeout=600, + ) + assert res.status_code == 200, f"{res}" + except Exception: + last_traceback = get_exception_traceback() + if pipe_finish_writer is not None: + pipe_finish_writer.send(last_traceback) + logger.error(f"Initialization failed. 
warmup error: {last_traceback}") + kill_process_tree(os.getpid()) + return + + # Debug print + # logger.info(f"{res.json()=}") + + logger.info("The server is fired up and ready to roll!") + if pipe_finish_writer is not None: + pipe_finish_writer.send("ready") + + if server_args.delete_ckpt_after_loading: + delete_directory(server_args.model_path) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 92b01d4524f8..ea39d73f2eea 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -30,20 +30,15 @@ ) from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES -try: - from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig - - from sglang.srt.configs import ExaoneConfig, Qwen2VLConfig - - _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { - ChatGLMConfig.model_type: ChatGLMConfig, - DbrxConfig.model_type: DbrxConfig, - ExaoneConfig.model_type: ExaoneConfig, - Qwen2VLConfig.model_type: Qwen2VLConfig, - } -except ImportError: - # We want this file to run without vllm dependency - _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {} +from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2VLConfig + +_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { + ChatGLMConfig.model_type: ChatGLMConfig, + DbrxConfig.model_type: DbrxConfig, + ExaoneConfig.model_type: ExaoneConfig, + Qwen2VLConfig.model_type: Qwen2VLConfig, +} + for name, cls in _CONFIG_REGISTRY.items(): with contextlib.suppress(ValueError): diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py index c4c54f0b03c4..ebb0652c5d2e 100644 --- a/python/sglang/srt/layers/activation.py +++ b/python/sglang/srt/layers/activation.py @@ -25,13 +25,13 @@ if is_flashinfer_available(): from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul -from vllm.distributed import ( +from vllm.model_executor.custom_op import CustomOp + +from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.custom_op import CustomOp - from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.utils import set_weight_attrs diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 140755ff5e67..745598643028 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -66,7 +66,14 @@ def forward( if forward_batch.forward_mode.is_decode(): return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache) else: - return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache) + return self.forward_extend( + q, + k, + v, + layer, + forward_batch, + save_kv_cache, + ) def forward_decode( self, diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 8b823cc5a5dd..7540515c5fd1 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -18,6 +18,7 @@ from sglang.global_config import global_config from sglang.srt.layers.attention import AttentionBackend +from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, 
ForwardMode from sglang.srt.utils import is_flashinfer_available @@ -62,9 +63,9 @@ def __init__(self, model_runner: ModelRunner): self.decode_use_tensor_cores = should_use_tensor_core( kv_cache_dtype=model_runner.kv_cache_dtype, num_attention_heads=model_runner.model_config.num_attention_heads - // model_runner.tp_size, + // get_attention_tp_size(), num_kv_heads=model_runner.model_config.get_num_kv_heads( - model_runner.tp_size + get_attention_tp_size() ), ) self.max_context_len = model_runner.model_config.context_len @@ -84,6 +85,10 @@ def __init__(self, model_runner: ModelRunner): self.num_wrappers = 1 self.dispatch_reason = None + # Qwen2 models require higher flashinfer workspace size + if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures: + global_config.flashinfer_workspace_size = 512 * 1024 * 1024 + # Allocate buffers self.workspace_buffer = torch.empty( global_config.flashinfer_workspace_size, @@ -143,7 +148,7 @@ def __init__(self, model_runner: ModelRunner): self.prefill_cuda_graph_metadata = {} def init_forward_metadata(self, forward_batch: ForwardBatch): - if forward_batch.forward_mode.is_decode(): + if forward_batch.forward_mode.is_decode_or_idle(): self.indices_updater_decode.update( forward_batch.req_pool_indices, forward_batch.seq_lens, @@ -234,7 +239,7 @@ def init_forward_metadata_capture_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInfo], ): - if forward_mode.is_decode(): + if forward_mode.is_decode_or_idle(): decode_wrappers = [] for i in range(self.num_wrappers): decode_wrappers.append( @@ -303,7 +308,7 @@ def init_forward_metadata_replay_cuda_graph( forward_mode: ForwardMode, spec_info: Optional[SpecInfo], ): - if forward_mode.is_decode(): + if forward_mode.is_decode_or_idle(): self.indices_updater_decode.update( req_pool_indices[:bs], seq_lens[:bs], @@ -347,11 +352,15 @@ def forward_extend( else forward_batch.encoder_out_cache_loc ) + logits_soft_cap = layer.logit_cap + if not self.forward_metadata.use_ragged: if k is not None: assert v is not None if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) o = prefill_wrapper_paged.forward( q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), @@ -359,7 +368,9 @@ def forward_extend( causal=not layer.is_cross_attention, sm_scale=layer.scaling, window_left=layer.sliding_window_size, - logits_soft_cap=layer.logit_cap, + logits_soft_cap=logits_soft_cap, + k_scale=layer.k_scale, + v_scale=layer.v_scale, ) else: o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( @@ -368,7 +379,7 @@ def forward_extend( v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), causal=True, sm_scale=layer.scaling, - logits_soft_cap=layer.logit_cap, + logits_soft_cap=logits_soft_cap, ) if self.forward_metadata.extend_no_prefix: @@ -385,7 +396,9 @@ def forward_extend( o, _ = merge_state(o1, s1, o2, s2) if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) @@ -410,13 +423,17 @@ def forward_decode( if k is not None: assert v is not None if save_kv_cache: - forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v) + forward_batch.token_to_kv_pool.set_kv_buffer( + layer, cache_loc, k, v, layer.k_scale, layer.v_scale + ) o = decode_wrapper.forward( 
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id), sm_scale=layer.scaling, logits_soft_cap=layer.logit_cap, + k_scale=layer.k_scale, + v_scale=layer.v_scale, ) return o.view(-1, layer.tp_q_head_num * layer.head_dim) @@ -437,10 +454,10 @@ class FlashInferIndicesUpdaterDecode: def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): # Parse Constants self.num_qo_heads = ( - model_runner.model_config.num_attention_heads // model_runner.tp_size + model_runner.model_config.num_attention_heads // get_attention_tp_size() ) self.num_kv_heads = model_runner.model_config.get_num_kv_heads( - model_runner.tp_size + get_attention_tp_size() ) self.head_dim = model_runner.model_config.head_dim self.data_type = model_runner.kv_cache_dtype @@ -609,10 +626,10 @@ class FlashInferIndicesUpdaterPrefill: def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend): # Parse Constants self.num_qo_heads = ( - model_runner.model_config.num_attention_heads // model_runner.tp_size + model_runner.model_config.num_attention_heads // get_attention_tp_size() ) self.num_kv_heads = model_runner.model_config.get_num_kv_heads( - model_runner.tp_size + get_attention_tp_size() ) self.head_dim = model_runner.model_config.head_dim self.data_type = model_runner.kv_cache_dtype diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 04327b162b90..fade8ed292dc 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -5,6 +5,7 @@ import torch from sglang.srt.layers.attention import AttentionBackend +from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode if TYPE_CHECKING: @@ -28,12 +29,9 @@ def __init__(self, model_runner: ModelRunner): self.decode_attention_fwd = decode_attention_fwd self.extend_attention_fwd = extend_attention_fwd - if model_runner.server_args.enable_dp_attention: - self.num_head = model_runner.model_config.num_attention_heads - else: - self.num_head = ( - model_runner.model_config.num_attention_heads // model_runner.tp_size - ) + self.num_head = ( + model_runner.model_config.num_attention_heads // get_attention_tp_size() + ) self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py new file mode 100644 index 000000000000..4fcfaad56251 --- /dev/null +++ b/python/sglang/srt/layers/attention/vision.py @@ -0,0 +1,204 @@ +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from sglang.srt.distributed import parallel_state +from sglang.srt.distributed import utils as dist_utils +from sglang.srt.layers.attention.triton_ops.prefill_attention import ( + context_attention_fwd, +) +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from sglang.srt.layers.quantization import QuantizationConfig + + +def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange( + 
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2 + ) + + +def apply_rotary_emb_torch( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False +) -> torch.Tensor: + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat( + cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + sin = repeat( + sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" + ) + return torch.cat( + [ + x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, + x[..., ro_dim:], + ], + dim=-1, + ) + + +def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + t_ = t.float() + cos = freqs.cos() + sin = freqs.sin() + output = apply_rotary_emb_torch(t_, cos, sin).type_as(t) + return output + + +class VisionAttention(nn.Module): + """Multi-headed attention without any cache, mostly used for ViT.""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + use_qkv_parallel: bool, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + world_size = parallel_state.get_tensor_model_parallel_world_size() + + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads + ) + self.num_attention_heads_per_partition = dist_utils.divide( + num_heads, world_size + ) + # self.tp_size = get_tensor_model_parallel_world_size() + # num_heads = self.num_heads_per_partition + self.use_qkv_parallel = use_qkv_parallel + if use_qkv_parallel: + self.head_dim = embed_dim // num_heads + self.qkv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_dim, + total_num_heads=num_heads, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + else: + self.qkv_proj = ColumnParallelLinear( + input_size=embed_dim, + output_size=3 * projection_size, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.proj = RowParallelLinear( + input_size=embed_dim, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: Optional[torch.Tensor] = None, + rotary_pos_emb: torch.Tensor = None, + ) -> torch.Tensor: + """ + Input shape: [b, s, embed_dim] + Output shape: [s, b, num_heads * head_size] + """ + + bsz, s, _ = x.shape + if self.use_qkv_parallel: + # [b, s, embed_dim] --> [b, s, embed_dim] + qkv, _ = self.qkv_proj(x) + q, k, v = qkv.chunk(3, dim=-1) + + # [b, s, embed_dim] --> [b * s, num_heads, head_size] + q, k, v = [ + x.reshape( + bsz * s, self.num_attention_heads_per_partition, -1 + ).contiguous() + for x in (q, k, v) + ] + else: + # [b, s, embed_dim] --> [s, b, embed_dim] + x = rearrange(x, "b s ... -> s b ...") + # [s, b, embed_dim] --> [s, b, head * 3 * head_dim] + qkv, _ = self.qkv_proj(x) + # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] + new_x_shape = qkv.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) + qkv = qkv.view(*new_x_shape) + + # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] + q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3) + + # [s, b, head, head_dim] --> [b, s, head, head_dim] + q, k, v = [ + rearrange(x, "s b ... 
-> b s ...").contiguous() for x in (q, k, v) + ] + + if rotary_pos_emb is not None: + q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) + k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) + + if self.use_qkv_parallel: + pass + else: + # [b, s, head, head_dim] --> [b * s, head, head_dim] + q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] + + # [b * s, num_heads, head_size] + output = torch.empty_like(q) + + seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda() + max_seqlen = seq_lens.max().item() + + context_attention_fwd( + q, + k, + v, + output, + cu_seqlens.cuda(), + seq_lens, + max_seqlen, + is_causal=False, + ) + + if self.use_qkv_parallel: + + # [b * s, head, head_dim] --> [b, s, head * head_dim] + output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz) + + # [b, s, head, head_dim] --> [b, s, head, head_dim] + output, _ = self.proj(output) + else: + # [b * s, head, head_dim] --> [b, s, head, head_dim] + context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz) + + # [s, b, num_heads * head_size] + context_layer = rearrange( + context_layer, "b s h d -> s b (h d)" + ).contiguous() + + # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size] + output, _ = self.proj(context_layer) + + output = output.view(bsz, s, -1) + + return output diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py new file mode 100644 index 000000000000..65efa0feb84b --- /dev/null +++ b/python/sglang/srt/layers/dp_attention.py @@ -0,0 +1,69 @@ +import torch + +from sglang.srt.distributed import GroupCoordinator, get_tp_group + +_ATTN_TP_GROUP = None +_ATTN_TP_RANK = None +_ATTN_TP_SIZE = None +_DP_RANK = None +_DP_SIZE = None + + +def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size): + if not enable_dp_attention: + return tp_rank, tp_size, 0 + + attn_tp_size = tp_size // dp_size + dp_rank = tp_rank // attn_tp_size + attn_tp_rank = tp_rank % attn_tp_size + return attn_tp_rank, attn_tp_size, dp_rank + + +def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size): + global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE + + _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info( + enable_dp_attention, tp_rank, tp_size, dp_size + ) + _DP_SIZE = dp_size + + tp_group = get_tp_group() + _ATTN_TP_GROUP = GroupCoordinator( + [ + list(range(head, head + _ATTN_TP_SIZE)) + for head in range(0, tp_size, _ATTN_TP_SIZE) + ], + tp_rank, + torch.distributed.get_backend(tp_group.device_group), + False, + False, + False, + False, + False, + group_name="attention_tp", + ) + + +def get_attention_tp_group(): + assert _ATTN_TP_GROUP is not None, "dp attention not initialized!" + return _ATTN_TP_GROUP + + +def get_attention_tp_rank(): + assert _ATTN_TP_RANK is not None, "dp attention not initialized!" + return _ATTN_TP_RANK + + +def get_attention_tp_size(): + assert _ATTN_TP_SIZE is not None, "dp attention not initialized!" + return _ATTN_TP_SIZE + + +def get_attention_dp_rank(): + assert _DP_RANK is not None, "dp attention not initialized!" + return _DP_RANK + + +def get_attention_dp_size(): + assert _DP_SIZE is not None, "dp attention not initialized!" 
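# --- Annotation (editor's sketch, not part of the patch) ---------------------
# Worked example for compute_dp_attention_world_info above, with hypothetical
# sizes: tp_size=8 split into dp_size=2 attention groups of attn_tp_size=4.
# Global tp_rank 5 then becomes attn_tp_rank 1 inside dp group 1.
tp_rank, tp_size, dp_size = 5, 8, 2
attn_tp_size = tp_size // dp_size       # 4
attn_tp_rank = tp_rank % attn_tp_size   # 1
dp_rank = tp_rank // attn_tp_size       # 1
assert (attn_tp_rank, attn_tp_size, dp_rank) == (1, 4, 1)
# -----------------------------------------------------------------------------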
+ return _DP_SIZE diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index b828c03911e8..bfa5d2b66544 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -1,4 +1,4 @@ -# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/linear.py +"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py""" import logging from abc import abstractmethod @@ -7,7 +7,8 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter -from vllm.distributed import ( + +from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -15,17 +16,13 @@ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, ) - -# workaround -from vllm.model_executor.layers.linear import LinearBase -from vllm.model_executor.parameter import ( +from sglang.srt.layers.parameter import ( BasevLLMParameter, PackedColumnParameter, PackedvLLMParameter, PerTensorScaleParameter, RowvLLMParameter, ) - from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -42,9 +39,13 @@ "GPTQMarlinLinearMethod", "Fp8LinearMethod", "MarlinLinearMethod", - "GPTQLinearMethod", "QQQLinearMethod", + "GPTQMarlin24LinearMethod", + "TPUInt8LinearMethod", + "GPTQLinearMethod", + "FBGEMMFp8LinearMethod", "ModelOptFp8LinearMethod", + "IPEXAWQLinearMethod", ] @@ -170,6 +171,45 @@ def apply( return F.linear(x, layer.weight, bias) +class LinearBase(torch.nn.Module): + """Base linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if quant_config is None: + self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self, prefix=prefix) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + class ReplicatedLinear(LinearBase): """Replicated linear layer. @@ -287,6 +327,8 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, ): super().__init__( input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix @@ -295,7 +337,11 @@ def __init__( self.gather_output = gather_output # Divide the weight matrix along the last dimension. 
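# --- Annotation (editor's sketch, not part of the patch) ---------------------
# How the column-parallel layers in this file shard a checkpoint weight: the
# output dimension is divided by tp_size and each rank narrows out its slice
# (start_idx = tp_rank * shard_size), mirroring the weight_loader changes below.
# Hypothetical shapes: output_size=12, input_size=3, tp_size=4, tp_rank=2.
import torch

full_weight = torch.arange(12 * 3, dtype=torch.float32).reshape(12, 3)
tp_size, tp_rank, output_dim = 4, 2, 0
shard_size = full_weight.shape[output_dim] // tp_size   # 3 rows per rank
start_idx = tp_rank * shard_size                        # rank 2 starts at row 6
local_weight = full_weight.narrow(output_dim, start_idx, shard_size)
assert local_weight.shape == (3, 3)
# -----------------------------------------------------------------------------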
- tp_size = get_tensor_model_parallel_world_size() + if tp_rank is None: + tp_rank = get_tensor_model_parallel_rank() + if tp_size is None: + tp_size = get_tensor_model_parallel_world_size() + self.tp_rank, self.tp_size = tp_rank, tp_size assert self.quant_method is not None self.output_size_per_partition = divide(self.output_size, tp_size) self.output_partition_sizes = [self.output_size_per_partition] @@ -336,7 +382,6 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() output_dim = getattr(param, "output_dim", None) # Special case for GGUF @@ -356,7 +401,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # no need to narrow here if output_dim is not None and not use_bitsandbytes_4bit: shard_size = param_data.shape[output_dim] - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not @@ -373,7 +418,7 @@ def weight_loader_v2(self, param: Parameter, loaded_weight: torch.Tensor): if len(loaded_weight.shape) == 0: assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - param.load_column_parallel_weight(loaded_weight=loaded_weight) + param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank) def forward(self, input_): bias = self.bias if not self.skip_bias_add else None @@ -393,7 +438,7 @@ def extra_repr(self) -> str: s = f"in_features={self.input_size}" s += f", output_features={self.output_size_per_partition}" s += f", bias={self.bias is not None}" - s += f", tp_size={get_tensor_model_parallel_world_size()}" + s += f", tp_size={self.tp_size}" s += f", gather_output={self.gather_output}" return s @@ -431,10 +476,18 @@ def __init__( params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + use_presharded_weights: bool = False, ): self.output_sizes = output_sizes - tp_size = get_tensor_model_parallel_world_size() + if tp_rank is None: + tp_rank = get_tensor_model_parallel_rank() + if tp_size is None: + tp_size = get_tensor_model_parallel_world_size() + self.tp_rank, self.tp_size = tp_rank, tp_size assert all(output_size % tp_size == 0 for output_size in output_sizes) + self.use_presharded_weights = use_presharded_weights super().__init__( input_size=input_size, output_size=sum(output_sizes), @@ -444,6 +497,8 @@ def __init__( params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, + tp_rank=tp_rank, + tp_size=tp_size, ) def weight_loader( @@ -463,12 +518,9 @@ def weight_loader( return if is_gguf_weight: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - output_dim = getattr(param, "output_dim", None) - shard_size = loaded_weight.size(output_dim) // tp_size - start_idx = tp_rank * shard_size + shard_size = loaded_weight.size(output_dim) // self.tp_size + start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) @@ -522,11 +574,9 @@ def weight_loader( return assert loaded_shard_id < len(self.output_sizes) - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() if output_dim is not None: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] 
// tp_size + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size + shard_size = self.output_sizes[loaded_shard_id] // self.tp_size # Special case for quantization. # If quantized, we need to adjust the offset and size to account # for the packing. @@ -545,10 +595,10 @@ def weight_loader( shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id param_data = param_data.narrow(output_dim, shard_offset, shard_size) - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size # bitsandbytes loads the weights of the specific portion # no need to narrow here - if not use_bitsandbytes_4bit: + if not use_bitsandbytes_4bit and not self.use_presharded_weights: loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) # Special case for AQLM codebooks. elif is_metadata: @@ -624,31 +674,33 @@ def weight_loader_v2( elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_merged_column_weight(loaded_weight=loaded_weight) return + # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return assert loaded_shard_id < len(self.output_sizes) - tp_size = get_tensor_model_parallel_world_size() - if isinstance(param, BlockQuantScaleParameter): weight_block_size = self.quant_method.quant_config.weight_block_size block_n, _ = weight_block_size[0], weight_block_size[1] shard_offset = ( (sum(self.output_sizes[:loaded_shard_id]) + block_n - 1) // block_n - ) // tp_size + ) // self.tp_size shard_size = ( - (self.output_sizes[loaded_shard_id] + block_n - 1) // block_n // tp_size + (self.output_sizes[loaded_shard_id] + block_n - 1) + // block_n + // self.tp_size ) else: - shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size - shard_size = self.output_sizes[loaded_shard_id] // tp_size + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size + shard_size = self.output_sizes[loaded_shard_id] // self.tp_size param.load_merged_column_weight( loaded_weight=loaded_weight, shard_id=loaded_shard_id, shard_offset=shard_offset, shard_size=shard_size, + use_presharded_weights=self.use_presharded_weights, ) @@ -689,6 +741,8 @@ def __init__( params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, ): self.hidden_size = hidden_size self.head_size = head_size @@ -697,7 +751,11 @@ def __init__( total_num_kv_heads = total_num_heads self.total_num_kv_heads = total_num_kv_heads # Divide the weight matrix along the last dimension. 
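# --- Annotation (editor's sketch, not part of the patch) ---------------------
# KV-head layout computed just below in QKVParallelLinear: when tp_size is at
# least the number of KV heads, each rank keeps a single KV head and each head
# is shared by roughly tp_size // total_num_kv_heads ranks; the k/v weight
# loader then picks its shard with tp_rank // num_kv_head_replicas.
# Hypothetical sizes:
total_num_kv_heads, tp_size, tp_rank = 4, 8, 5
num_kv_heads = 1                                        # tp_size >= total_num_kv_heads
num_kv_head_replicas = tp_size // total_num_kv_heads    # 2 ranks share a KV head
kv_shard_id = tp_rank // num_kv_head_replicas           # rank 5 loads KV head 2
assert (num_kv_heads, num_kv_head_replicas, kv_shard_id) == (1, 2, 2)
# -----------------------------------------------------------------------------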
- tp_size = get_tensor_model_parallel_world_size() + if tp_rank is None: + tp_rank = get_tensor_model_parallel_rank() + if tp_size is None: + tp_size = get_tensor_model_parallel_world_size() + self.tp_rank, self.tp_size = tp_rank, tp_size self.num_heads = divide(self.total_num_heads, tp_size) if tp_size >= self.total_num_kv_heads: self.num_kv_heads = 1 @@ -724,6 +782,8 @@ def __init__( params_dtype=params_dtype, quant_config=quant_config, prefix=prefix, + tp_rank=tp_rank, + tp_size=tp_size, ) def _get_shard_offset_mapping(self, loaded_shard_id: str): @@ -800,6 +860,7 @@ def weight_loader_v2( elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_qkv_weight(loaded_weight=loaded_weight) return + # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return @@ -820,6 +881,7 @@ def weight_loader_v2( shard_id=loaded_shard_id, shard_offset=shard_offset, shard_size=shard_size, + tp_rank=self.tp_rank, ) def weight_loader( @@ -840,12 +902,9 @@ def weight_loader( return if is_gguf_weight: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - output_dim = getattr(param, "output_dim", None) - shard_size = loaded_weight.size(output_dim) // tp_size - start_idx = tp_rank * shard_size + shard_size = loaded_weight.size(output_dim) // self.tp_size + start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) @@ -934,7 +993,6 @@ def weight_loader( self.weight_loader(param, loaded_weight_shard, shard_id) return - tp_rank = get_tensor_model_parallel_rank() assert loaded_shard_id in ["q", "k", "v"] # If output dim is defined, use the default loading process. @@ -984,9 +1042,9 @@ def weight_loader( param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": - shard_id = tp_rank + shard_id = self.tp_rank else: - shard_id = tp_rank // self.num_kv_head_replicas + shard_id = self.tp_rank // self.num_kv_head_replicas start_idx = shard_id * shard_size # bitsandbytes loads the weights of the specific portion @@ -1055,6 +1113,9 @@ def __init__( reduce_results: bool = True, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + use_presharded_weights: bool = False, ): super().__init__( input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix @@ -1064,10 +1125,14 @@ def __init__( self.reduce_results = reduce_results # Divide the weight matrix along the last dimension. 
- self.tp_rank = get_tensor_model_parallel_rank() - self.tp_size = get_tensor_model_parallel_world_size() + if tp_rank is None: + tp_rank = get_tensor_model_parallel_rank() + if tp_size is None: + tp_size = get_tensor_model_parallel_world_size() + self.tp_rank, self.tp_size = tp_rank, tp_size self.input_size_per_partition = divide(input_size, self.tp_size) assert self.quant_method is not None + self.use_presharded_weights = use_presharded_weights self.quant_method.create_weights( layer=self, @@ -1101,8 +1166,6 @@ def __init__( self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() input_dim = getattr(param, "input_dim", None) use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False) @@ -1116,15 +1179,19 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): if is_gguf_weight and isinstance(param, UninitializedParameter): weight_shape = list(loaded_weight.shape) if input_dim: - weight_shape[input_dim] = weight_shape[input_dim] // tp_size + weight_shape[input_dim] = weight_shape[input_dim] // self.tp_size param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype) param_data = param.data # bitsandbytes loads the weights of the specific portion # no need to narrow here - if input_dim is not None and not use_bitsandbytes_4bit: + if ( + input_dim is not None + and not use_bitsandbytes_4bit + and not self.use_presharded_weights + ): shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size + start_idx = self.tp_rank * shard_size loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size) # Special case for loading scales off disk, which often do not @@ -1143,17 +1210,27 @@ def weight_loader_v2(self, param: BasevLLMParameter, loaded_weight: torch.Tensor assert loaded_weight.numel() == 1 loaded_weight = loaded_weight.reshape(1) - param.load_row_parallel_weight(loaded_weight=loaded_weight) + if isinstance(param, BasevLLMParameter): + # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py, + # It supports additional parameters like tp_rank and use_presharded_weights. + param.load_row_parallel_weight( + loaded_weight, + tp_rank=self.tp_rank, + use_presharded_weights=self.use_presharded_weights, + ) + else: + # `params` is defined in `vllm/model_executor/parameter.py`, + # It does not support additional parameters. + param.load_row_parallel_weight(loaded_weight) def forward(self, input_): if self.input_is_parallel: input_parallel = input_ else: - tp_rank = get_tensor_model_parallel_rank() splitted_input = split_tensor_along_last_dim( input_, num_partitions=self.tp_size ) - input_parallel = splitted_input[tp_rank].contiguous() + input_parallel = splitted_input[self.tp_rank].contiguous() # Matrix multiply. 
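# --- Annotation (editor's sketch, not part of the patch) ---------------------
# The row-parallel contract behind the forward pass above: each rank multiplies
# its slice of the input by its slice of the weight (input dimension sharded),
# and summing the partial outputs across ranks reproduces the full matmul,
# which is what the reduce_results / all-reduce path in the real layer does.
# Toy check with hypothetical shapes:
import torch

x = torch.randn(2, 8)                       # full activation [batch, in]
w = torch.randn(6, 8)                       # full weight [out, in]
tp_size = 2
x_shards = torch.chunk(x, tp_size, dim=-1)
w_shards = torch.chunk(w, tp_size, dim=-1)
partial_outputs = [xs @ ws.t() for xs, ws in zip(x_shards, w_shards)]
assert torch.allclose(sum(partial_outputs), x @ w.t(), atol=1e-5)
# -----------------------------------------------------------------------------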
assert self.quant_method is not None diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 7ca1d51a756d..08ee5a3509b9 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -14,17 +14,18 @@ """Logits processing.""" import dataclasses +import logging from typing import List, Optional, Union import torch import triton import triton.language as tl from torch import nn -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) - from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, @@ -32,6 +33,8 @@ ForwardMode, ) +logger = logging.getLogger(__name__) + @dataclasses.dataclass class LogitsProcessorOutput: @@ -50,8 +53,6 @@ class LogitsProcessorOutput: next_token_top_logprobs_idx: Optional[List] = None ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor - # The normlaized logprobs of prompts. shape: [#seq] - normalized_prompt_logprobs: torch.Tensor = None # The logprobs of input tokens. shape: [#token] input_token_logprobs: torch.Tensor = None # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k] @@ -129,59 +130,70 @@ def forward( hidden_states, lm_head: VocabParallelEmbedding, logits_metadata: Union[LogitsMetadata, ForwardBatch], - ): + ) -> LogitsProcessorOutput: if isinstance(logits_metadata, ForwardBatch): logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata) # Get the last hidden states and last logits for the next token prediction if ( - logits_metadata.forward_mode.is_decode() + logits_metadata.forward_mode.is_decode_or_idle() or logits_metadata.forward_mode.is_target_verify() ): - last_index = None - last_hidden = hidden_states - else: + pruned_states = hidden_states + sample_indices = None + elif ( + logits_metadata.forward_mode.is_extend() + and not logits_metadata.extend_return_logprob + ): + # Prefill without input logprobs. last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1 - last_hidden = hidden_states[last_index] + pruned_states = hidden_states[last_index] + sample_indices = None + else: + # Slice the requested tokens to compute logprob + sample_index_pt = -1 + sample_indices = [] + pt, pruned_states, pruned_input_ids = 0, [], [] + for start_len, extend_len in zip( + logits_metadata.extend_logprob_start_lens_cpu, + logits_metadata.extend_seq_lens_cpu, + ): + pruned_states.append(hidden_states[pt + start_len : pt + extend_len]) + sample_index_pt += extend_len - start_len + sample_indices.append(sample_index_pt) + pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) + pt += extend_len + + pruned_states = torch.cat(pruned_states) + + # Compute logits for both input and sampled tokens. + logits = self._get_logits(pruned_states, lm_head, logits_metadata) + sampled_logits = ( + logits[sample_indices] if sample_indices is not None else logits + ) - # Compute logits - last_logits = self._get_logits(last_hidden, lm_head) if ( not logits_metadata.extend_return_logprob or logits_metadata.capture_hidden_mode.need_capture() ): # Decode mode or extend mode without return_logprob. 
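# --- Annotation (editor's sketch, not part of the patch) ---------------------
# The prefill-with-logprobs branch earlier in this forward() keeps only
# hidden_states[start_len:extend_len] per request and records where each
# request's last token lands in the pruned tensor (sample_indices), so a single
# _get_logits call serves both the input logprobs and the next-token logits.
# Toy lengths (hypothetical):
extend_seq_lens_cpu = [5, 3]               # tokens per request in the batch
extend_logprob_start_lens_cpu = [2, 0]     # prefix tokens whose logprobs are skipped

sample_index_pt, sample_indices, kept_rows = -1, [], 0
for start_len, extend_len in zip(extend_logprob_start_lens_cpu, extend_seq_lens_cpu):
    kept_rows += extend_len - start_len
    sample_index_pt += extend_len - start_len
    sample_indices.append(sample_index_pt)
assert kept_rows == 6 and sample_indices == [2, 5]
# -----------------------------------------------------------------------------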
return LogitsProcessorOutput( - next_token_logits=last_logits, + next_token_logits=sampled_logits, hidden_states=( hidden_states if logits_metadata.capture_hidden_mode.is_full() else ( - last_hidden + pruned_states if logits_metadata.capture_hidden_mode.is_last() else None ) ), ) else: - # Slice the requested tokens to compute logprob - pt, pruned_states, pruned_input_ids = 0, [], [] - for start_len, extend_len in zip( - logits_metadata.extend_logprob_start_lens_cpu, - logits_metadata.extend_seq_lens_cpu, - ): - pruned_states.append(hidden_states[pt + start_len : pt + extend_len]) - pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len]) - pt += extend_len - - # Compute the logits of all required tokens - pruned_states = torch.cat(pruned_states) - del hidden_states - input_token_logits = self._get_logits(pruned_states, lm_head) - del pruned_states + input_logprobs = logits + del hidden_states, logits # Normalize the logprob w/o temperature, top-p - input_logprobs = input_token_logits input_logprobs = self.compute_temp_top_p_normalized_logprobs( input_logprobs, logits_metadata ) @@ -195,25 +207,18 @@ def forward( else: input_top_logprobs_val = input_top_logprobs_idx = None - # Compute the normalized logprobs for the requested tokens. - # Note that we pad a zero at the end for easy batching. input_token_logprobs = input_logprobs[ - torch.arange(input_logprobs.shape[0], device="cuda"), + torch.arange(input_logprobs.shape[0], device=input_logprobs.device), torch.cat( [ torch.cat(pruned_input_ids)[1:], - torch.tensor([0], device="cuda"), + torch.tensor([0], device=input_logprobs.device), ] ), ] - normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( - input_token_logprobs, - logits_metadata, - ) return LogitsProcessorOutput( - next_token_logits=last_logits, - normalized_prompt_logprobs=normalized_prompt_logprobs, + next_token_logits=sampled_logits, input_token_logprobs=input_token_logprobs, input_top_logprobs_val=input_top_logprobs_val, input_top_logprobs_idx=input_top_logprobs_idx, @@ -223,8 +228,11 @@ def _get_logits( self, hidden_states: torch.Tensor, lm_head: VocabParallelEmbedding, + logits_metadata: LogitsMetadata, embedding_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """Get logits from hidden_states.""" + if hasattr(lm_head, "weight"): logits = torch.matmul(hidden_states, lm_head.weight.T) else: @@ -237,8 +245,6 @@ def _get_logits( if self.do_tensor_parallel_all_gather: logits = tensor_model_parallel_all_gather(logits) - # Compute the normalized logprobs for the requested tokens. - # Note that we pad a zero at the end for easy batching. 
logits = logits[:, : self.config.vocab_size].float() if self.final_logit_softcapping: @@ -246,27 +252,6 @@ def _get_logits( return logits - @staticmethod - def _get_normalized_prompt_logprobs( - input_token_logprobs: torch.Tensor, - logits_metadata: LogitsMetadata, - ): - logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32) - pruned_lens = torch.tensor( - logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda" - ) - - start = torch.zeros_like(pruned_lens) - start[1:] = torch.cumsum(pruned_lens[:-1], dim=0) - end = torch.clamp( - start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1 - ) - sum_logp = ( - logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start] - ) - normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1) - return normalized_prompt_logprobs - @staticmethod def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata): max_k = max(logits_metadata.top_logprobs_nums) @@ -311,7 +296,7 @@ def fused_softcap_kernel( n_elements, BLOCK_SIZE: tl.constexpr, ): - pid = tl.program_id(0) + pid = tl.program_id(0).to(tl.int64) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements diff --git a/python/sglang/srt/layers/moe/ep_moe/layer.py b/python/sglang/srt/layers/moe/ep_moe/layer.py index 96e02e312781..8f5a71dff8c3 100644 --- a/python/sglang/srt/layers/moe/ep_moe/layer.py +++ b/python/sglang/srt/layers/moe/ep_moe/layer.py @@ -4,13 +4,12 @@ import torch from torch.nn import Module from vllm import _custom_ops as ops -from vllm.distributed import ( +from vllm.model_executor.custom_op import CustomOp + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.custom_op import CustomOp -from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod - from sglang.srt.layers.custom_op_util import register_custom_op from sglang.srt.layers.moe.ep_moe.kernels import ( grouped_gemm_triton, @@ -25,6 +24,7 @@ QuantizationConfig, QuantizeMethodBase, ) +from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod from sglang.srt.utils import is_hip, set_weight_attrs logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/moe/fused_moe_native.py b/python/sglang/srt/layers/moe/fused_moe_native.py index 638173b647d5..0703e840ca64 100644 --- a/python/sglang/srt/layers/moe/fused_moe_native.py +++ b/python/sglang/srt/layers/moe/fused_moe_native.py @@ -8,6 +8,7 @@ import torch from torch.nn import functional as F +from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.moe.topk import select_experts @@ -44,3 +45,71 @@ def fused_moe_forward_native( x3 = torch.einsum("ti, taoi -> tao", x, w3_weights) expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights) return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype)) + + +def moe_forward_native( + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + 
num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + torch_native=True, + ) + + # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589 + len_experts = layer.num_experts + + cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts)) + cnts.scatter_(1, topk_ids.to(torch.int64), 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + + sorted_tokens = x[idxs // topk_ids.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + + layer_w13_weight = layer.w13_weight[i] + layer_w2_weight = layer.w2_weight[i] + + gate_up = F.linear(tokens_for_this_expert, layer_w13_weight) + gate_up = SiluAndMul()(gate_up) + expert_out = F.linear(gate_up, layer_w2_weight) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + new_x = torch.empty_like(outs) + + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weights.dtype) + .mul_(topk_weights.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index ed132555bd20..c0d558085587 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -15,15 +15,18 @@ from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 -from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip +from sglang.srt.utils import ( + direct_register_custom_op, + get_device_name, + is_cuda_available, + is_hip, +) -is_hip_flag = False -if not is_hip(): +is_cuda = is_cuda_available() +is_hip_flag = is_hip() +if is_cuda: from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size - is_hip_flag = False -else: - is_hip_flag = True logger = logging.getLogger(__name__) padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index 96eaf856616f..75d4c5ead650 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -5,20 +5,21 @@ from typing import Callable, List, Optional, Tuple import torch -from vllm.distributed import ( +from vllm.model_executor.custom_op import CustomOp + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.custom_op import CustomOp - from sglang.srt.layers.custom_op_util import register_custom_op +from sglang.srt.layers.moe.fused_moe_native import moe_forward_native from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.utils import set_weight_attrs +from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs if torch.cuda.is_available(): from 
sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts @@ -27,6 +28,8 @@ import logging +is_hip_ = is_hip() + logger = logging.getLogger(__name__) @@ -97,6 +100,20 @@ def create_weights( layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + if is_hip_ and get_bool_env_var("CK_MOE"): + layer.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + return + def apply( self, layer: torch.nn.Module, @@ -148,17 +165,52 @@ def forward_cuda( correction_bias=correction_bias, ) - return fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - ) + if is_hip_ and get_bool_env_var("CK_MOE"): + import ater + from ater.fused_moe import fused_experts_ck + + return fused_experts_ck( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + else: + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + ) - def forward_cpu(self, *args, **kwargs): - raise NotImplementedError("The CPU backend currently does not support MoE.") + def forward_cpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return moe_forward_native( + layer, + x, + use_grouped_topk, + top_k, + router_logits, + renormalize, + topk_group, + num_expert_group, + custom_routing_function, + correction_bias, + ) def forward_tpu(self, *args, **kwargs) -> torch.Tensor: raise NotImplementedError("The TPU backend currently does not support MoE.") @@ -204,6 +256,7 @@ def __init__( prefix: str = "", custom_routing_function: Optional[Callable] = None, correction_bias: Optional[torch.Tensor] = None, + use_presharded_weights: bool = False, ): super().__init__() @@ -243,6 +296,7 @@ def __init__( params_dtype=params_dtype, weight_loader=self.weight_loader, ) + self.use_presharded_weights = use_presharded_weights def _load_per_tensor_weight_scale( self, @@ -395,10 +449,7 @@ def weight_loader( weight_name: str, shard_id: str, expert_id: int, - use_presharded_weights: bool = False, ) -> None: - self.use_presharded_weights = use_presharded_weights - # compressed-tensors checkpoints with packed weights are stored flipped # TODO (mgoin): check self.quant_method.quant_config.quant_format # against known CompressionFormat enum values that have this quality diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 8190321988dc..527a7d499b6a 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -24,7 +24,9 @@ def fused_topk_native( topk: int, renormalize: bool, ): - assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch" + assert ( + hidden_states.shape[0] == gating_output.shape[0] + ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}" M, _ = 
hidden_states.shape topk_weights = torch.empty( M, topk, dtype=torch.float32, device=hidden_states.device @@ -180,7 +182,7 @@ def select_experts( num_expert_group=num_expert_group, topk_group=topk_group, ) - elif torch_native: + elif torch_native and custom_routing_function is None: topk_weights, topk_ids = fused_topk_native( hidden_states=hidden_states, gating_output=router_logits, diff --git a/python/sglang/srt/layers/parameter.py b/python/sglang/srt/layers/parameter.py new file mode 100644 index 000000000000..d99b2efe85ff --- /dev/null +++ b/python/sglang/srt/layers/parameter.py @@ -0,0 +1,440 @@ +"""Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/parameter.py""" + +import logging +from fractions import Fraction +from typing import Callable, Optional, Union + +import torch +from torch.nn import Parameter + +from sglang.srt.distributed import get_tensor_model_parallel_rank + +__all__ = [ + "BasevLLMParameter", + "PackedvLLMParameter", + "PerTensorScaleParameter", + "ModelWeightParameter", + "ChannelQuantScaleParameter", + "GroupQuantScaleParameter", + "PackedColumnParameter", + "RowvLLMParameter", +] + +logger = logging.getLogger(__name__) + + +class BasevLLMParameter(Parameter): + """ + Base parameter for vLLM linear layers. Extends the torch.nn.parameter + by taking in a linear weight loader. Will copy the loaded weight + into the parameter when the provided weight loader is called. + """ + + def __new__(cls, data: torch.Tensor, **kwargs): + + return super().__new__(cls, data=data, requires_grad=False) + + def __init__(self, data: torch.Tensor, weight_loader: Callable): + """ + Initialize the BasevLLMParameter + + :param data: torch tensor with the parameter data + :param weight_loader: weight loader callable + + :returns: a torch.nn.parameter + """ + + self._weight_loader = weight_loader + + @property + def weight_loader(self): + return self._weight_loader + + def _assert_and_load(self, loaded_weight: torch.Tensor): + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + def load_column_parallel_weight(self, loaded_weight: torch.Tensor): + self._assert_and_load(loaded_weight) + + def load_row_parallel_weight(self, loaded_weight: torch.Tensor): + self._assert_and_load(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + self._assert_and_load(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs): + self._assert_and_load(loaded_weight) + + +class _ColumnvLLMParameter(BasevLLMParameter): + """ + Private class defining weight loading functionality + (load_merged_column_weight, load_qkv_weight) + for parameters being loaded into linear layers with column + parallelism. This includes QKV and MLP layers which are + not already fused on disk. Requires an output dimension + to be defined. Called within the weight loader of + each of the column parallel linear layers. 
+ """ + + def __init__(self, output_dim: int, **kwargs): + self._output_dim = output_dim + super().__init__(**kwargs) + + @property + def output_dim(self): + return self._output_dim + + def load_column_parallel_weight( + self, + loaded_weight: torch.Tensor, + tp_rank: int, + use_presharded_weights: bool = False, + ): + if not use_presharded_weights: + shard_size = self.data.shape[self.output_dim] + loaded_weight = loaded_weight.narrow( + self.output_dim, tp_rank * shard_size, shard_size + ) + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + use_presharded_weights = kwargs.get("use_presharded_weights") + if ( + isinstance(self, (PackedColumnParameter, PackedvLLMParameter)) + and self.packed_dim == self.output_dim + ): + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size + ) + + param_data = self.data + + tp_rank = get_tensor_model_parallel_rank() + param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + if not use_presharded_weights: + loaded_weight = loaded_weight.narrow( + self.output_dim, tp_rank * shard_size, shard_size + ) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs): + + shard_offset = kwargs.get("shard_offset") + shard_size = kwargs.get("shard_size") + shard_id = kwargs.get("shard_id") + num_heads = kwargs.get("num_heads") + + if ( + isinstance(self, (PackedColumnParameter, PackedvLLMParameter)) + and self.output_dim == self.packed_dim + ): + shard_size, shard_offset = self.adjust_shard_indexes_for_packing( + shard_offset=shard_offset, shard_size=shard_size + ) + + param_data = self.data + shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads + param_data = param_data.narrow(self.output_dim, shard_offset, shard_size) + loaded_weight = loaded_weight.narrow( + self.output_dim, shard_id * shard_size, shard_size + ) + + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowvLLMParameter(BasevLLMParameter): + """ + Parameter class defining weight_loading functionality + (load_row_parallel_weight) for parameters being loaded + into linear layers with row parallel functionality. + Requires an input_dim to be defined. + """ + + def __init__(self, input_dim: int, **kwargs): + self._input_dim = input_dim + super().__init__(**kwargs) + + @property + def input_dim(self): + return self._input_dim + + def load_row_parallel_weight( + self, + loaded_weight: torch.Tensor, + tp_rank: int, + use_presharded_weights: bool = False, + ): + if not use_presharded_weights: + shard_size = self.data.shape[self.input_dim] + loaded_weight = loaded_weight.narrow( + self.input_dim, tp_rank * shard_size, shard_size + ) + + if len(loaded_weight.shape) == 0: + loaded_weight = loaded_weight.reshape(1) + + assert self.data.shape == loaded_weight.shape + self.data.copy_(loaded_weight) + + +class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for linear layer weights. Uses both column and + row parallelism. + """ + + pass + + +class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + grouped quantization. Uses both column and row parallelism. 
+ """ + + pass + + +class ChannelQuantScaleParameter(_ColumnvLLMParameter): + """ + Parameter class for weight scales loaded for weights with + channel-wise quantization. Equivalent to _ColumnvLLMParameter. + """ + + pass + + +class PerTensorScaleParameter(BasevLLMParameter): + """ + Parameter class for scales where the number of scales is + equivalent to the number of logical matrices in fused linear + layers (e.g. for QKV, there are 3 scales loaded from disk). + This is relevant to weights with per-tensor quantization. + Adds functionality to map the scalers to a shard during + weight loading. + + Note: additional parameter manipulation may be handled + for each quantization config specifically, within + process_weights_after_loading + """ + + def __init__(self, **kwargs): + self.qkv_idxs = {"q": 0, "k": 1, "v": 2} + super().__init__(**kwargs) + + def _shard_id_as_int(self, shard_id: Union[str, int]) -> int: + if isinstance(shard_id, int): + return shard_id + + # if not int, assume shard_id for qkv + # map to int and return + assert isinstance(shard_id, str) + assert shard_id in self.qkv_idxs + return self.qkv_idxs[shard_id] + + # For row parallel layers, no sharding needed + # load weight into parameter as is + def load_row_parallel_weight(self, *args, **kwargs): + kwargs.pop("tp_rank", None) + kwargs.pop("use_presharded_weights", None) + super().load_row_parallel_weight(*args, **kwargs) + + def load_merged_column_weight(self, *args, **kwargs): + self._load_into_shard_id(*args, **kwargs) + + def load_qkv_weight(self, *args, **kwargs): + self._load_into_shard_id(*args, **kwargs) + + def load_column_parallel_weight(self, *args, **kwargs): + kwargs.pop("tp_rank", None) + kwargs.pop("use_presharded_weights", None) + super().load_row_parallel_weight(*args, **kwargs) + + def _load_into_shard_id( + self, loaded_weight: torch.Tensor, shard_id: Union[str, int], **kwargs + ): + """ + Slice the parameter data based on the shard id for + loading. + """ + + param_data = self.data + shard_id = self._shard_id_as_int(shard_id) + + # AutoFP8 scales do not have a shape + # compressed-tensors scales do have a shape + if len(loaded_weight.shape) != 0: + assert loaded_weight.shape[0] == 1 + loaded_weight = loaded_weight[0] + + param_data = param_data[shard_id] + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class PackedColumnParameter(_ColumnvLLMParameter): + """ + Parameter for model parameters which are packed on disk + and support column parallelism only. See PackedvLLMParameter + for more details on the packed properties. + """ + + def __init__( + self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs + ): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile_size = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile_size(self): + return self._marlin_tile_size + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size, + ) + + +class PackedvLLMParameter(ModelWeightParameter): + """ + Parameter for model weights which are packed on disk. + Example: GPTQ Marlin weights are int4 or int8, packed into int32. 
+ Extends the ModelWeightParameter to take in the + packed factor, the packed dimension, and optionally, marlin + tile size for marlin kernels. Adjusts the shard_size and + shard_offset for fused linear layers model weight loading + by accounting for packing and optionally, marlin tile size. + """ + + def __init__( + self, + packed_factor: Union[int, Fraction], + packed_dim: int, + marlin_tile_size: Optional[int] = None, + **kwargs + ): + self._packed_factor = packed_factor + self._packed_dim = packed_dim + self._marlin_tile_size = marlin_tile_size + super().__init__(**kwargs) + + @property + def packed_dim(self): + return self._packed_dim + + @property + def packed_factor(self): + return self._packed_factor + + @property + def marlin_tile_size(self): + return self._marlin_tile_size + + def adjust_shard_indexes_for_packing(self, shard_size, shard_offset): + return _adjust_shard_indexes_for_packing( + shard_size=shard_size, + shard_offset=shard_offset, + packed_factor=self.packed_factor, + marlin_tile_size=self.marlin_tile_size, + ) + + +def permute_param_layout_( + param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs +) -> BasevLLMParameter: + """ + Permute a parameter's layout to the specified input and output dimensions, + useful for forcing the parameter into a known layout, for example, if I need + a packed (quantized) weight matrix to be in the layout + {input_dim = 0, output_dim = 1, packed_dim = 0} + then I can call: + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + to ensure x is in the correct layout (permuting it to the correct layout if + required, asserting if it cannot get it to the correct layout) + """ + + curr_input_dim = getattr(param, "input_dim", None) + curr_output_dim = getattr(param, "output_dim", None) + + if curr_input_dim is None or curr_output_dim is None: + assert param.data.dim() == 2, ( + "permute_param_layout_ only supports 2D parameters when either " + "input_dim or output_dim is not set" + ) + + # if one of the dimensions is not set, set it to the opposite of the other + # we can only do this since we asserted the parameter is 2D above + if curr_input_dim is None: + assert curr_output_dim is not None, "either input or output dim must be set" + curr_input_dim = (curr_output_dim + 1) % 2 + if curr_output_dim is None: + assert curr_input_dim is not None, "either input or output dim must be set" + curr_output_dim = (curr_input_dim + 1) % 2 + + # create permutation from the current layout to the layout with + # self.input_dim at input_dim and self.output_dim at output_dim preserving + # other dimensions + perm = [ + i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim] + ] + perm.insert(input_dim, curr_input_dim) + perm.insert(output_dim, curr_output_dim) + + if "packed_dim" in kwargs: + assert ( + hasattr(param, "packed_dim") + and param.packed_dim == perm[kwargs["packed_dim"]] + ), "permute_param_layout_ currently doesn't support repacking" + + param.data = param.data.permute(*perm) + if hasattr(param, "_input_dim"): + param._input_dim = input_dim + if hasattr(param, "_output_dim"): + param._output_dim = output_dim + if "packed_dim" in kwargs and hasattr(param, "_packed_dim"): + param._packed_dim = kwargs["packed_dim"] + + return param + + +def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size): + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +def _adjust_shard_indexes_for_packing( + shard_size, shard_offset, packed_factor, marlin_tile_size +): + 
shard_size = shard_size // packed_factor + shard_offset = shard_offset // packed_factor + if marlin_tile_size is not None: + return _adjust_shard_indexes_for_marlin( + shard_size=shard_size, + shard_offset=shard_offset, + marlin_tile_size=marlin_tile_size, + ) + return shard_size, shard_offset diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index df20a7a4ba47..1c0092c1a40d 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -17,12 +17,13 @@ from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config from vllm.model_executor.layers.quantization.marlin import MarlinConfig -from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.fp8 import Fp8Config +from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config +from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, @@ -42,6 +43,7 @@ "bitsandbytes": BitsAndBytesConfig, "qqq": QQQConfig, "experts_int8": ExpertsInt8Config, + "w8a8_int8": W8A8Int8Config, } @@ -54,33 +56,13 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: return QUANTIZATION_METHODS[quantization] -def fp8_get_quant_method(self, layer, prefix): - """Enhanced get_quant_method for FP8 config.""" - from vllm.model_executor.layers.linear import LinearBase - from vllm.model_executor.layers.quantization.utils.quant_utils import ( - is_layer_skipped, - ) - - from sglang.srt.layers.linear import UnquantizedLinearMethod - from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE - from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod - - if isinstance(layer, LinearBase): - if is_layer_skipped(prefix, self.ignored_layers): - return UnquantizedLinearMethod() - return Fp8LinearMethod(self) - elif isinstance(layer, FusedMoE): - return Fp8MoEMethod(self) - return None - - def gptq_get_quant_method(self, layer, prefix): - from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinLinearMethod, GPTQMarlinMoEMethod, ) + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE if isinstance(layer, LinearBase): @@ -91,12 +73,12 @@ def gptq_get_quant_method(self, layer, prefix): def awq_get_quant_method(self, layer, prefix): - from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.awq_marlin import ( AWQMarlinLinearMethod, AWQMoEMethod, ) + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE if isinstance(layer, LinearBase): @@ -106,13 +88,30 @@ def awq_get_quant_method(self, layer, prefix): return None +def patch_vllm_linear_base_isinstance(): + import builtins + + from vllm.model_executor.layers.linear import LinearBase + + from sglang.srt.layers.linear import LinearBase as PatchedLinearBase + + original_isinstance = builtins.isinstance + + def patched_isinstance(obj, classinfo): + if 
classinfo is LinearBase: + return original_isinstance(obj, PatchedLinearBase) + return original_isinstance(obj, classinfo) + + builtins.isinstance = patched_isinstance + + def apply_monkey_patches(): """Apply all monkey patches in one place.""" - setattr(Fp8Config, "get_quant_method", fp8_get_quant_method) setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method) setattr(AWQMarlinConfig, "get_quant_method", awq_get_quant_method) +patch_vllm_linear_base_isinstance() # Apply patches when module is imported apply_monkey_patches() diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index a263cb2362a9..bd59352a7969 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -1,7 +1,6 @@ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py import logging -import os from typing import Any, Callable, Dict, List, Optional import torch @@ -9,8 +8,6 @@ from torch.nn import Module from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( apply_fp8_marlin_linear, @@ -25,9 +22,14 @@ per_tensor_dequantize, requantize_with_max_scale, ) -from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter -from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -40,12 +42,15 @@ from sglang.srt.utils import ( get_bool_env_var, is_hip, + permute_weight, print_warning_once, set_weight_attrs, ) ACTIVATION_SCHEMES = ["static", "dynamic"] +is_hip_ = is_hip() + logger = logging.getLogger(__name__) @@ -161,7 +166,7 @@ def __init__(self, quant_config: Fp8Config): # kernel for fast weight-only FP8 quantization self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") # Disable marlin for ROCm - if is_hip(): + if is_hip_: self.use_marlin = False self.block_quant = self.quant_config.weight_block_size is not None @@ -273,7 +278,7 @@ def process_weights_after_loading(self, layer: Module) -> None: # Block quant doesn't need to process weights after loading if self.block_quant: # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip(): + if is_hip_: # activation_scheme: dynamic weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=layer.weight, @@ -330,7 +335,7 @@ def process_weights_after_loading(self, layer: Module) -> None: weight_scale = layer.weight_scale # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip(): + if is_hip_: weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( weight=weight, weight_scale=weight_scale, @@ -567,7 +572,7 @@ def process_weights_after_loading(self, layer: Module) -> None: # Block quant doesn't need to process weights after loading if self.block_quant: # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip(): + if is_hip_: # activation_scheme: dynamic 
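The patch_vllm_linear_base_isinstance helper added in quantization/__init__.py above follows a general pattern for redirecting isinstance checks against an upstream base class to a local replacement. A minimal sketch of that pattern with toy class names (not the real vLLM/SGLang types):

import builtins

class UpstreamBase: ...
class LocalBase: ...

_original_isinstance = builtins.isinstance

def _patched_isinstance(obj, classinfo):
    # Checks against the upstream base are answered using the local base instead.
    if classinfo is UpstreamBase:
        return _original_isinstance(obj, LocalBase)
    return _original_isinstance(obj, classinfo)

builtins.isinstance = _patched_isinstance

This keeps third-party code that type-checks against the upstream class working with locally defined layers, at the cost of a process-wide builtins override.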
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( weight=layer.w13_weight, @@ -594,7 +599,7 @@ def process_weights_after_loading(self, layer: Module) -> None: # If checkpoint is fp16 or bfloat16, quantize in place. if not self.quant_config.is_checkpoint_fp8_serialized: # If ROCm, use float8_e4m3fnuz instead (MI300x HW) - fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn + fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) @@ -616,18 +621,30 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) - # If ROCm, apply weight padding (min. Mem channel contention) only if set - if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))): - layer.w13_weight = torch.nn.Parameter( - F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() - layer.w2_weight = torch.nn.Parameter( - F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() + if is_hip_: + if get_bool_env_var("CK_MOE"): + layer.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + elif get_bool_env_var("MOE_PADDING"): + # If ROCm, apply weight padding (min. Mem channel contention) only if set + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() return # If checkpoint is fp8, we need to handle that the @@ -658,7 +675,7 @@ def process_weights_after_loading(self, layer: Module) -> None: ) # If ROCm, normalize the weights and scales to e4m3fnuz - if is_hip(): + if is_hip_: # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = ( normalize_e4m3fn_to_e4m3fnuz( @@ -708,18 +725,30 @@ def process_weights_after_loading(self, layer: Module) -> None: max_w13_scales, requires_grad=False ) - # If ROCm, apply weight padding (min. Mem channel contention) only if set - if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))): - layer.w13_weight = torch.nn.Parameter( - F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() - layer.w2_weight = torch.nn.Parameter( - F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), - requires_grad=False, - ) - torch.cuda.empty_cache() + if is_hip_: + if get_bool_env_var("CK_MOE"): + layer.w13_weight = torch.nn.Parameter( + permute_weight(layer.w13_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + permute_weight(layer.w2_weight.data), + requires_grad=False, + ) + torch.cuda.empty_cache() + elif get_bool_env_var("MOE_PADDING"): + # If ROCm, apply weight padding (min. 
Mem channel contention) only if set + layer.w13_weight = torch.nn.Parameter( + F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter( + F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0), + requires_grad=False, + ) + torch.cuda.empty_cache() return def apply( @@ -752,27 +781,55 @@ def apply( correction_bias=correction_bias, ) - # Expert fusion with FP8 quantization - return fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - use_fp8_w8a8=True, - w1_scale=( - layer.w13_weight_scale_inv - if self.block_quant - else layer.w13_weight_scale - ), - w2_scale=( - layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale - ), - a1_scale=layer.w13_input_scale, - a2_scale=layer.w2_input_scale, - block_shape=self.quant_config.weight_block_size, - ) + if is_hip_ and get_bool_env_var("CK_MOE"): + import ater + from ater.fused_moe import fused_experts_ck + + return fused_experts_ck( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + use_fp8_w8a8=True, + w1_scale=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_scale=( + layer.w2_weight_scale_inv + if self.block_quant + else layer.w2_weight_scale + ), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + ) + + else: + # Expert fusion with FP8 quantization + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + w1_scale=( + layer.w13_weight_scale_inv + if self.block_quant + else layer.w13_weight_scale + ), + w2_scale=( + layer.w2_weight_scale_inv + if self.block_quant + else layer.w2_weight_scale + ), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + ) class Fp8KVCacheMethod(BaseKVCacheMethod): diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 140e70dd9d20..d6ff12ee1635 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -1,8 +1,8 @@ from typing import List, Optional, Tuple import torch -from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter +from sglang.srt.layers.parameter import RowvLLMParameter, _ColumnvLLMParameter from sglang.srt.layers.quantization.fp8_kernel import ( per_token_group_quant_fp8, w8a8_block_fp8_matmul, diff --git a/python/sglang/srt/layers/quantization/int8_kernel.py b/python/sglang/srt/layers/quantization/int8_kernel.py new file mode 100644 index 000000000000..91b56f9e0e9c --- /dev/null +++ b/python/sglang/srt/layers/quantization/int8_kernel.py @@ -0,0 +1,54 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _per_token_quant_int8( + x_ptr, + xq_ptr, + scale_ptr, + stride_x, + stride_xq, + N, + BLOCK: tl.constexpr, +): + # Adapted from https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282 + row_id = tl.program_id(0) + + cols = tl.arange(0, BLOCK) + mask = cols < N + + x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32) + absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10) + scale_x = absmax / 127 + x_q = x * (127 / absmax) + x_q = 
tl.extra.cuda.libdevice.round(x_q).to(tl.int8) + + tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask) + tl.store(scale_ptr + row_id, scale_x) + + +def per_token_quant_int8(x): + M = x.numel() // x.shape[-1] + N = x.shape[-1] + x_q = torch.empty_like(x, device=x.device, dtype=torch.int8) + scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32) + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + + assert x.is_contiguous() + _per_token_quant_int8[(M,)]( + x, + x_q, + scales, + stride_x=x.stride(-2), + stride_xq=x_q.stride(-2), + N=N, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + + return x_q, scales diff --git a/python/sglang/srt/layers/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py similarity index 96% rename from python/sglang/srt/layers/modelopt_quant.py rename to python/sglang/srt/layers/quantization/modelopt_quant.py index 2c0887df2391..3e5f996ed10d 100644 --- a/python/sglang/srt/layers/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -5,15 +5,14 @@ import torch from torch.nn.parameter import Parameter -from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( apply_fp8_linear, cutlass_fp8_supported, requantize_with_max_scale, ) -from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter -from sglang.srt.layers.linear import LinearMethodBase +from sglang.srt.layers.linear import LinearBase, LinearMethodBase +from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -142,6 +141,7 @@ def create_weights( data=torch.full( (len(output_partition_sizes),), torch.finfo(torch.float32).min, + dtype=torch.float32, ), weight_loader=weight_loader, ), diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py new file mode 100644 index 000000000000..87ba4cfc5593 --- /dev/null +++ b/python/sglang/srt/layers/quantization/w8a8_int8.py @@ -0,0 +1,117 @@ +from typing import Any, Dict, List, Optional + +import torch + +from sglang.srt.utils import is_cuda_available + +is_cuda = is_cuda_available() +if is_cuda: + from sgl_kernel import int8_scaled_mm + +from torch.nn.parameter import Parameter + +from sglang.srt.layers.linear import LinearMethodBase +from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 + + +class W8A8Int8Config(QuantizationConfig): + """Config class for W8A8 Int8 Quantization. 
+ + - Weight: static, per-channel, symmetric + - Activation: dynamic, per-token, symmetric + """ + + def __init__(self): + pass + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def get_name(self) -> str: + return "w8a8_int8" + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config": + return cls() + + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str, + ) -> Optional["QuantizeMethodBase"]: + from sglang.srt.layers.linear import LinearBase + + if isinstance(layer, LinearBase): + return W8A8Int8LinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class W8A8Int8LinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: W8A8Int8Config): + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.weight = Parameter(layer.weight.t(), requires_grad=False) + layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs + ): + + weight_loader = extra_weight_attrs.get("weight_loader") + self.logical_widths = output_partition_sizes + + weight = ModelWeightParameter( + data=torch.empty( + sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8 + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight_scale", weight_scale) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ): + x_q, x_scale = per_token_quant_int8(x) + + return int8_scaled_mm( + x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias + ) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 4b762c00ba55..0d46e7bba9aa 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -47,6 +47,8 @@ def __init__( self.logit_cap = logit_cap self.sliding_window_size = sliding_window_size or -1 self.is_cross_attention = is_cross_attention + self.k_scale = None + self.v_scale = None def forward( self, diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py index 80158573bd63..ad265830f8f7 100644 --- a/python/sglang/srt/layers/rotary_embedding.py +++ b/python/sglang/srt/layers/rotary_embedding.py @@ -1,54 +1,917 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
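For reference, the math that W8A8Int8LinearMethod.apply delegates to per_token_quant_int8 plus int8_scaled_mm can be sketched in plain PyTorch as below. This is only an illustrative fallback with hypothetical names and layouts, not the sgl_kernel GEMM:

import torch

def w8a8_int8_linear_reference(x, weight_int8, weight_scale, bias=None):
    # Dynamic, symmetric per-token activation quantization to int8.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-10)
    x_scale = absmax / 127.0
    x_q = torch.round(x / x_scale).clamp(-128, 127).to(torch.int8)
    # int8 x int8 GEMM accumulated in int32, then dequantized with the
    # per-token activation scale and the per-channel weight scale.
    # weight_int8: [out_features, in_features], weight_scale: [out_features, 1]
    acc = x_q.to(torch.int32) @ weight_int8.to(torch.int32).t()
    out = acc.to(torch.float32) * x_scale * weight_scale.t()
    if bias is not None:
        out = out + bias
    return out.to(x.dtype)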
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""MRotaryEmbedding""" +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py + +"""Rotary Positional Embeddings.""" +import math from typing import Any, Dict, List, Optional, Tuple, Union import torch +import torch.nn as nn +from vllm.model_executor.custom_op import CustomOp + +from sglang.srt.layers.custom_op_util import register_custom_op + + +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +@register_custom_op("sglang_rotary_embedding") +class RotaryEmbedding(CustomOp): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. 
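A minimal, self-contained sketch of the neox-style rotation that _apply_rotary_emb and RotaryEmbedding.forward_native implement above (toy sizes; the real class registers the cos/sin cache as a buffer and handles a partial rotary_dim):

import torch

head_size = rotary_dim = 8
base, max_pos = 10000, 16
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_pos, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)          # [max_pos, rotary_dim // 2]
cos, sin = freqs.cos(), freqs.sin()

q = torch.randn(4, 2, head_size)                        # [num_tokens, num_heads, head_size]
positions = torch.tensor([0, 1, 2, 3])
c = cos[positions].unsqueeze(-2)                        # [num_tokens, 1, rotary_dim // 2]
s = sin[positions].unsqueeze(-2)
q1, q2 = torch.chunk(q, 2, dim=-1)                      # neox style: split into halves
q_rot = torch.cat((q1 * c - q2 * s, q2 * c + q1 * s), dim=-1)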
+ inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm import _custom_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + return query, key + + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm._ipex_ops import ipex_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, dtype=query.dtype) + ops.rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + return query, key + + def forward_hpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from habana_frameworks.torch.hpex.kernels import ( + RotaryPosEmbeddingMode, + apply_rotary_pos_emb, + ) + + positions = positions.flatten() + if offsets is not None: + positions = positions + offsets + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions).view(num_tokens, 1, -1) + cos, sin = cos_sin.chunk(2, dim=-1) + # HPU RoPE kernel requires hidden dimension for cos and sin to be equal + # to query hidden dimension, so the original tensors need to be + # expanded + # GPT-NeoX kernel requires position_ids = None, offset, mode = BLOCKWISE + # and expansion of cos/sin tensors via concatenation + # GPT-J kernel requires position_ids = None, offset = 0, mode = PAIRWISE + # and expansion of cos/sin tensors via repeat_interleave + rope_mode: RotaryPosEmbeddingMode + if self.is_neox_style: + rope_mode = RotaryPosEmbeddingMode.BLOCKWISE + cos = 
torch.cat((cos, cos), dim=-1) + sin = torch.cat((sin, sin), dim=-1) + else: + rope_mode = RotaryPosEmbeddingMode.PAIRWISE + sin = torch.repeat_interleave(sin, 2, dim=-1, output_size=cos_sin.shape[-1]) + cos = torch.repeat_interleave(cos, 2, dim=-1, output_size=cos_sin.shape[-1]) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = apply_rotary_pos_emb(query_rot, cos, sin, None, 0, rope_mode) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = apply_rotary_pos_emb(key_rot, cos, sin, None, 0, rope_mode) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + It supports multiple scaling factors. Since multiple LoRA adapters may have + different scaling factors, we need multiple cos/sin caches. In this way, + instead of running rotary embedding kernel per lora, we can run multiple + lora in a batched way. + + In addition to that, we also keep the cos/sin cache for the scaling factor + of 1 (default) at all times. + + Exemplary for two scaling factors x=1, y and z with embeddings + [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and + [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and + [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]], + + we construct the cos/sin cache as follows: + [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], + ... + [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] + + We then use offsets to index into the cos/sin cache for + the respective scaling factors. + + The offset to cache can be accessed via `scaling_factor_to_offset` API. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factors: Union[List[float], float], + dtype: torch.dtype, + ) -> None: + if isinstance(scaling_factors, float): + scaling_factors = [scaling_factors] + self.scaling_factors: List[float] = scaling_factors # noqa + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + # Lazy initialized. + self._scaling_factor_to_offset: Dict[float, int] + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + cache_list: List[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: List[int] = [] + for scaling_factor in self.scaling_factors: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. 
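To make the linear-scaling cache construction concrete: positions are divided by the scaling factor before the frequencies are applied, so a context several times longer than the trained length maps back into roughly the original angle range. A standalone sketch with illustrative numbers:

import torch

scaling_factor = 4.0
max_position_embeddings, rotary_dim, base = 8, 8, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))

max_len = int(max_position_embeddings * scaling_factor)
t = torch.arange(max_len, dtype=torch.float) / scaling_factor   # 0.0, 0.25, 0.5, ...
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)           # 4x the rows, angles stay near the trained range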
+ max_len = self.max_position_embeddings * scaling_factor + t = torch.arange(max_len, dtype=torch.float) + t = t / scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) + cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) + return torch.cat(cache_list, dim=0) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self._scaling_factor_to_offset + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * self.scaling_factor + base = self.base * ( + (self.scaling_factor * max_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.rotary_dim / (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> Tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype +) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, dtype=torch.float32 + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + +class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): + """Phi3 family of models scaled rotary embedding. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + original_max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + short_factor: List[float], + long_factor: List[float], + short_mscale: Optional[float] = None, + long_mscale: Optional[float] = None, + ): + super().__init__() + + if rotary_dim != head_size: + raise ValueError( + f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \ + rotary_dim != head_size ({rotary_dim}!={head_size})." + ) + if is_neox_style is False: + raise ValueError( + "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." 
+ ) + + self.head_size = head_size + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.base = base + self.short_factor = short_factor + self.long_factor = long_factor + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt( + 1 + math.log(scale) / math.log(self.original_max_position_embeddings) + ) + if short_mscale is None: + short_mscale = scaling_factor + if long_mscale is None: + long_mscale = scaling_factor + + self.short_mscale = short_mscale + self.long_mscale = long_mscale + + short_cache = self._compute_cos_sin_cache( + original_max_position_embeddings, short_factor, short_mscale + ) + short_cache = short_cache.to(dtype) + self.register_buffer("short_cos_sin_cache", short_cache, persistent=False) + + long_cache = self._compute_cos_sin_cache( + max_position_embeddings, long_factor, long_mscale + ) + long_cache = long_cache.to(dtype) + self.register_buffer("long_cos_sin_cache", long_cache, persistent=False) + + long_short_cache = torch.cat( + [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0 + ) + self.register_buffer( + "long_short_cos_sin_cache", long_short_cache, persistent=False + ) + + def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: + rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) + inv_freq = 1.0 / ( + rescale_factors + * ( + self.base + ** ( + torch.arange(0, self.head_size, 2, dtype=torch.float) + / self.head_size + ) + ) + ) + return inv_freq + + def _compute_cos_sin_cache( + self, + max_position_embeddings: int, + rescale_factors: List[float], + mscale: float, + ) -> torch.Tensor: + inv_freq = self._compute_inv_freq(rescale_factors) + t = torch.arange(max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + + k = self.original_max_position_embeddings + long_prompt_offset = ( + torch.any(positions > k).float() * torch.full_like(positions, k) + ).long() + idx = ( + torch.add(positions, long_prompt_offset) + if long_prompt_offset is not None + else positions + ) + self.long_short_cos_sin_cache: torch.Tensor = self.long_short_cos_sin_cache.to( + idx.device + ) + idx = torch.add(idx, offsets) if offsets is not None else idx + cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) + + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).unsqueeze(-2) + sin = sin.repeat(1, 2).unsqueeze(-2) + + query = query * cos + _rotate_neox(query) * sin + key = key * cos + _rotate_neox(key) * sin + + return query.flatten(-2), key.flatten(-2) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ -class MRotaryEmbedding: + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + device: Optional[str] = "cuda", + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) + / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) + * attn_factor + ) + self.device = device + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) + / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=self.device, + dtype=torch.float32, + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + print("Cache shape", cache.shape) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device) + cos_sin = self.cos_sin_cache[ + torch.add(positions, offsets) if offsets is not None else positions + ] + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. 
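# --- Editor's aside (illustrative sketch, not part of the diff): how the YaRN-style
# classes above (YaRNScalingRotaryEmbedding, DeepseekScalingRotaryEmbedding) blend
# interpolated and extrapolated rotary frequencies in _compute_inv_freq. The ramp
# below stands in for _yarn_find_correction_range / _yarn_linear_ramp_mask, and the
# numbers are toy values chosen only for this sketch.
import torch

rotary_dim, base, scaling_factor = 8, 10000, 4.0
pos_freqs = base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs                     # original RoPE frequencies
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)  # frequencies stretched by the scaling factor
inv_freq_mask = 1.0 - torch.linspace(0, 1, rotary_dim // 2)  # 1.0 keeps extrapolation, 0.0 uses interpolation
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
# High-frequency dimensions keep their original rotation speed while low-frequency
# dimensions are interpolated for long contexts, matching the blend computed above.
# --- End of editor's aside.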
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key + + +class Llama3RotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + low_freq_wavelen = self.orig_max_position / self.low_freq_factor + high_freq_wavelen = self.orig_max_position / self.high_freq_factor + + wave_len = 2 * math.pi / inv_freqs + if self.low_freq_factor != self.high_freq_factor: + smooth = (self.orig_max_position / wave_len - self.low_freq_factor) / ( + self.high_freq_factor - self.low_freq_factor + ) + else: + smooth = 0 + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freqs, + torch.where( + wave_len > low_freq_wavelen, + inv_freqs / self.scaling_factor, + (1 - smooth) * inv_freqs / self.scaling_factor + smooth * inv_freqs, + ), + ) + return new_freqs + + +class MRotaryEmbedding(RotaryEmbedding): """Rotary Embedding with Multimodal Sections.""" + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[List[int]] = None, + ) -> None: + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + self.mrope_section = mrope_section + if self.mrope_section: + assert sum(self.mrope_section) == rotary_dim // 2 + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward(). 
+ + Args: + positions: + [num_tokens,] (text only) or + [3, num_tokens] (T/H/W positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 1 or positions.ndim == 2 + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + cos = torch.cat( + [m[i] for i, m in enumerate(cos.split(self.mrope_section, dim=-1))], + dim=-1, + ) + sin = torch.cat( + [m[i] for i, m in enumerate(sin.split(self.mrope_section, dim=-1))], + dim=-1, + ) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + @staticmethod def get_input_positions( - input_tokens: torch.Tensor, + input_tokens: List[int], image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + image_token_id: int, + video_token_id: int, vision_start_token_id: int, + vision_end_token_id: int, spatial_merge_size: int, context_len: int = 0, + seq_len: Optional[int] = None, ) -> Tuple[List[List[int]], int]: """Get mrope input positions and delta value.""" if isinstance(image_grid_thw, torch.Tensor): image_grid_thw = image_grid_thw.tolist() + if isinstance(video_grid_thw, torch.Tensor): + video_grid_thw = video_grid_thw.tolist() + input_tokens_tensor = torch.tensor(input_tokens) vision_start_indices = torch.argwhere( - input_tokens == vision_start_token_id + input_tokens_tensor == vision_start_token_id ).squeeze(1) - image_indices = vision_start_indices + 1 - image_nums = image_indices.shape[0] + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() llm_pos_ids_list: list = [] st = 0 - input_tokens_len = input_tokens.shape[0] - for image_index in range(image_nums): - ed = image_indices[image_index].item() - t, h, w = ( - image_grid_thw[image_index][0], - image_grid_thw[image_index][1], - image_grid_thw[image_index][2], - ) + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video llm_grid_t, llm_grid_h, llm_grid_w = ( t, h // spatial_merge_size, @@ -84,16 
+947,17 @@ def get_input_positions( ) st = ed + llm_grid_t * llm_grid_h * llm_grid_w - if st < input_tokens_len: + if st < len(input_tokens): st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 - text_len = input_tokens_len - st + text_len = len(input_tokens) - st llm_pos_ids_list.append( torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx ) llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) - llm_positions = llm_positions[:, context_len:] - mrope_position_delta = (llm_positions.max() + 1 - input_tokens_len).item() + mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item() + llm_positions = llm_positions[:, context_len:seq_len] + return llm_positions.tolist(), mrope_position_delta @staticmethod @@ -110,3 +974,292 @@ def get_next_input_positions( ) for _ in range(3) ] + + +_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = ( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling_args, + dtype, + ) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if rope_scaling is None: + rotary_emb = RotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype + ) + else: + if "rope_type" in rope_scaling: + scaling_type = rope_scaling["rope_type"] + elif "type" in rope_scaling: + scaling_type = rope_scaling["type"] + else: + raise ValueError("Unknown RoPE scaling type") + + if scaling_type == "llama3": + scaling_factor = rope_scaling["factor"] + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + scaling_factor, + low_freq_factor, + high_freq_factor, + original_max_position, + ) + elif scaling_type == "default": + if "mrope_section" in rope_scaling: + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + ) + else: + rotary_emb = RotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + ) + elif scaling_type == "linear": + scaling_factor = rope_scaling["factor"] + rotary_emb = LinearScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) + elif scaling_type == "dynamic": + scaling_factor = rope_scaling["factor"] + rotary_emb = DynamicNTKScalingRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + scaling_factor, + dtype, + ) + elif scaling_type == "yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in 
rope_scaling.items() + if k + in ("extrapolation_factor", "attn_factor", "beta_fast", "beta_slow") + } + rotary_emb = YaRNScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + elif scaling_type == "deepseek_yarn": + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + # assert max_position == original_max_position * scaling_factor + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ) + } + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + elif scaling_type == "longrope": + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + head_size, + rotary_dim, + max_position, + original_max_position, + base, + is_neox_style, + dtype, + short_factor, + long_factor, + **extra_kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb + + +def get_rope_cpu( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, + device: Optional[str] = None, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = ( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling_args, + dtype, + ) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + assert rope_scaling is not None + scaling_type = rope_scaling["rope_type"] + assert ( + scaling_type == "deepseek_yarn" + ), "Only deepseek_yarn is supported for CPU for now" + + scaling_factor = rope_scaling["factor"] + original_max_position = rope_scaling["original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k + in ( + "extrapolation_factor", + "attn_factor", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ) + } + extra_kwargs["device"] = device + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, + rotary_dim, + original_max_position, + base, + is_neox_style, + scaling_factor, + dtype, + **extra_kwargs, + ) + + _ROPE_DICT[key] = rotary_emb + return rotary_emb + + +def get_rope_wrapper( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, + device: Optional[str] = None, +): + if device != "cpu": + return get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + 
rope_scaling, + dtype, + partial_rotary_factor, + ) + + return get_rope_cpu( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + rope_scaling, + dtype, + partial_rotary_factor, + device, + ) diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 23037650a31c..3173d533d16e 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -2,12 +2,19 @@ from typing import List import torch +import torch.distributed as dist from torch import nn +from sglang.srt.distributed import get_tensor_model_parallel_group +from sglang.srt.layers.dp_attention import get_attention_tp_group from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -from sglang.srt.utils import crash_on_warnings, is_flashinfer_available +from sglang.srt.utils import ( + crash_on_warnings, + get_bool_env_var, + is_flashinfer_available, +) if is_flashinfer_available(): from flashinfer.sampling import ( @@ -20,11 +27,17 @@ logger = logging.getLogger(__name__) +SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP") + class Sampler(nn.Module): def __init__(self): super().__init__() self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"] + self.tp_sync_group = get_tensor_model_parallel_group().device_group + + if global_server_args_dict["enable_dp_attention"]: + self.tp_sync_group = get_attention_tp_group().device_group def forward( self, @@ -35,6 +48,10 @@ def forward( ): logits = logits_output.next_token_logits + # Apply the custom logit processors if registered in the sampling info. + if sampling_info.has_custom_logit_processor: + self._apply_custom_logit_processor(logits, sampling_info) + if self.use_nan_detectioin and torch.any(torch.isnan(logits)): logger.warning("Detected errors during sampling! NaN in the logits.") logits = torch.where( @@ -104,8 +121,6 @@ def forward( f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}" ) - batch_next_token_ids = batch_next_token_ids.to(torch.int32) - # Attach logprobs to logits_output (in-place modification) if return_logprob: if any(x > 0 for x in top_logprobs_nums): @@ -119,7 +134,54 @@ def forward( batch_next_token_ids, ] - return batch_next_token_ids + if SYNC_TOKEN_IDS_ACROSS_TP or sampling_info.grammars: + # For performance reasons, SGLang does not sync the final token IDs across TP ranks by default. + # This saves one all-reduce, but the correctness of this approach depends on the determinism of several operators: + # the last all-reduce, the last lm_head matmul, and all sampling kernels. + # These kernels are deterministic in most cases, but there are some rare instances where they are not deterministic. + # In such cases, enable this env variable to prevent hanging due to TP ranks becoming desynchronized. + # When using xgrammar, this becomes more likely so we also do the sync when grammar is used. + + torch.distributed.all_reduce( + batch_next_token_ids, + op=dist.ReduceOp.MIN, + group=self.tp_sync_group, + ) + + return batch_next_token_ids.to(torch.int32) + + def _apply_custom_logit_processor( + self, logits: torch.Tensor, sampling_batch_info: SamplingBatchInfo + ): + """Apply custom logit processors to the logits. 
+ This function will modify the logits in-place.""" + + assert logits.shape[0] == len(sampling_batch_info), ( + f"The batch size of logits ({logits.shape[0]}) does not match the batch size of " + f"sampling_batch_info ({len(sampling_batch_info)})" + ) + + for _, ( + processor, + batch_mask, + ) in sampling_batch_info.custom_logit_processor.items(): + # Get the batch indices that need to be processed + batch_indices = batch_mask.nonzero(as_tuple=True)[0] + + assert batch_mask.shape[0] == len(sampling_batch_info), ( + f"The number of batch mask ({batch_mask.shape[0]}) does not match the number of " + f"sampling_batch_info ({len(sampling_batch_info)})" + ) + + # Apply the processor to the logits + logits[batch_mask] = processor( + logits[batch_mask], + [sampling_batch_info.custom_params[i] for i in batch_indices], + ) + + logger.debug( + f"Custom logit processor {processor.__class__.__name__} is applied." + ) def top_k_top_p_min_p_sampling_from_probs_torch( diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index c5bca25df373..e08abd5ae1d5 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -5,6 +5,7 @@ import logging import os import pwd +from typing import Callable, Optional import torch @@ -27,8 +28,18 @@ def save_gemlite_cache(print_error: bool = False) -> bool: return True +def proj_filter( + module: torch.nn.Module, + fqn: str, +): + """Filter function for quantizing projection layers.""" + return "proj" in fqn + + def apply_torchao_config_to_model( - model: torch.nn.Module, torchao_config: str, filter_fn=None + model: torch.nn.Module, + torchao_config: str, + filter_fn: Optional[Callable] = proj_filter, ): """Quantize a modelwith torchao quantization specified by torchao_config @@ -49,11 +60,6 @@ def apply_torchao_config_to_model( ) from torchao.quantization.observer import PerRow, PerTensor - if filter_fn is None: - - def filter_fn(module, fqn): - return "proj" in fqn - if torchao_config == "" or torchao_config is None: return model elif "int8wo" in torchao_config: diff --git a/python/sglang/srt/layers/vocab_parallel_embedding.py b/python/sglang/srt/layers/vocab_parallel_embedding.py index effea1c6c950..ed9d67ef9706 100644 --- a/python/sglang/srt/layers/vocab_parallel_embedding.py +++ b/python/sglang/srt/layers/vocab_parallel_embedding.py @@ -6,14 +6,14 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter -from vllm.distributed import ( + +from sglang.srt.distributed import ( divide, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.parameter import BasevLLMParameter - +from sglang.srt.layers.parameter import BasevLLMParameter from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, @@ -220,6 +220,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", enable_tp: bool = True, + use_presharded_weights: bool = False, ): super().__init__() self.quant_config = quant_config @@ -236,6 +237,12 @@ def __init__( self.padding_size = padding_size self.org_vocab_size = org_num_embeddings or num_embeddings num_added_embeddings = num_embeddings - self.org_vocab_size + self.use_presharded_weights = use_presharded_weights + if use_presharded_weights: + assert ( + num_added_embeddings == 0 + ), "Lora is not supported with presharded weights." 
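# --- Editor's aside (illustrative sketch, not part of the diff): with the change in
# the torchao_utils.py hunk above, filter_fn defaults to proj_filter, but callers can
# still pass their own predicate over the fully qualified module name (fqn). The
# filter below is a hypothetical example for a Llama-style model that quantizes only
# the MLP projections and leaves the attention projections alone.
import torch

def mlp_only_filter(module: torch.nn.Module, fqn: str) -> bool:
    # Match gate/up/down projections; skip q/k/v/o projections.
    return any(name in fqn for name in ("gate_proj", "up_proj", "down_proj"))

# apply_torchao_config_to_model(model, "int8wo", filter_fn=mlp_only_filter)
# --- End of editor's aside.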
+ self.org_vocab_size_padded = pad_vocab_size( self.org_vocab_size, self.padding_size ) @@ -447,10 +454,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): start_idx = start_idx // packed_factor shard_size = shard_size // packed_factor else: - assert loaded_weight.shape[output_dim] == self.org_vocab_size + assert loaded_weight.shape[output_dim] == ( + self.org_vocab_size + // (self.tp_size if self.use_presharded_weights else 1) + ) # Copy the data. - loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) + if not self.use_presharded_weights: + loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) param[: loaded_weight.shape[0]].data.copy_(loaded_weight) param[loaded_weight.shape[0] :].data.fill_(0) @@ -514,6 +525,7 @@ def __init__( padding_size: int = DEFAULT_VOCAB_PADDING_SIZE, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + use_presharded_weights: bool = False, ): super().__init__( num_embeddings, @@ -523,6 +535,7 @@ def __init__( padding_size, quant_config, prefix, + use_presharded_weights=use_presharded_weights, ) self.quant_config = quant_config if bias: diff --git a/python/sglang/srt/lora/lora.py b/python/sglang/srt/lora/lora.py index 839d10222e2b..c8cbe36602b2 100644 --- a/python/sglang/srt/lora/lora.py +++ b/python/sglang/srt/lora/lora.py @@ -19,18 +19,11 @@ # https://github.com/vllm-project/vllm/blob/4abf6336ec65c270343eb895e7b18786e9274176/vllm/lora/layers.py -import json -import os import re -from typing import Any, Dict, List, Optional, Tuple -import safetensors.torch import torch from torch import nn -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, - VocabParallelEmbedding, -) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -38,7 +31,6 @@ QKVParallelLinear, RowParallelLinear, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_loader.loader import DefaultModelLoader diff --git a/python/sglang/srt/managers/cache_controller.py b/python/sglang/srt/managers/cache_controller.py new file mode 100644 index 000000000000..4560a270870f --- /dev/null +++ b/python/sglang/srt/managers/cache_controller.py @@ -0,0 +1,307 @@ +from __future__ import annotations + +""" +Copyright 2023-2025 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import logging +import threading +from queue import PriorityQueue, Queue +from typing import Optional + +import torch + +from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost + +logger = logging.getLogger(__name__) + + +class CacheOperation: + + counter = 0 + + def __init__( + self, + host_indices: torch.Tensor, + device_indices: torch.Tensor, + node_id: int, + priority: Optional[int] = None, + ): + self.host_indices = host_indices + self.device_indices = device_indices + self.node_ids = [node_id] + self.data = None + + self.id = CacheOperation.counter + CacheOperation.counter += 1 + # default priority is the order of creation + self.priority = priority if priority is not None else self.id + + def merge(self, other: "CacheOperation") -> None: + # multiple operations can be merged into a single operation for batch processing + self.host_indices = torch.cat([self.host_indices, other.host_indices]) + self.device_indices = torch.cat([self.device_indices, other.device_indices]) + self.priority = min(self.priority, other.priority) + self.node_ids.extend(other.node_ids) + + def __lt__(self, other: "CacheOperation"): + return self.priority < other.priority + + +class TransferBuffer: + """ + Overlapping buffer preparation and transfer operations to improve throughput. + """ + + def __init__(self, buffer_count: int = 3, max_buffer_size: int = 1000) -> None: + self.buffers = Queue(maxsize=buffer_count) + # todo: adjust the buffer size based on throughput profile of the system + self.max_buffer_size = max_buffer_size + + def full(self) -> bool: + return self.buffers.full() + + def empty(self) -> bool: + return self.buffers.empty() + + def put(self, item, block=True) -> None: + self.buffers.put(item, block=block) + + def get(self, block=True) -> Optional[CacheOperation]: + try: + return self.buffers.get(block=block) + except Exception as e: + logger.error(e) + + +class HiCacheController: + + def __init__( + self, + mem_pool_device: MHATokenToKVPool, + mem_pool_host: MLATokenToKVPoolHost, + write_policy: str = "write_through_selective", + ): + + self.mem_pool_device = mem_pool_device + self.mem_pool_host = mem_pool_host + self.write_policy = write_policy + + if write_policy not in [ + "write_through", + "write_through_selective", + "write_back", + ]: + raise ValueError(f"Invalid write policy: {write_policy}") + + self.write_queue = PriorityQueue() + self.load_queue = PriorityQueue() + + self.ack_write_queue = Queue() + self.ack_load_queue = Queue() + + self.write_buffer = TransferBuffer() + self.load_buffer = TransferBuffer() + + self.write_stream = torch.cuda.Stream() + self.load_stream = torch.cuda.Stream() + + self.write_thread = threading.Thread( + target=self.write_thread_func_buffer, daemon=True + ) + self.load_thread = threading.Thread( + target=self.load_thread_func_buffer, daemon=True + ) + self.write_thread.start() + self.load_thread.start() + + def write( + self, + device_indices: torch.Tensor, + priority: Optional[int] = None, + node_id: int = 0, + ) -> Optional[torch.Tensor]: + """ + Back up KV caches from device memory to host memory. 
+ """ + host_indices = self.mem_pool_host.alloc(len(device_indices)) + if host_indices is None: + return None + self.write_queue.put( + CacheOperation(host_indices, device_indices, node_id, priority) + ) + self.mem_pool_host.protect_write(host_indices) + return host_indices + + def load( + self, + host_indices: torch.Tensor, + priority: Optional[int] = None, + node_id: int = 0, + ) -> Optional[torch.Tensor]: + """ + Load KV caches from host memory to device memory. + """ + device_indices = self.mem_pool_device.alloc(len(host_indices)) + if device_indices is None: + return None + self.load_queue.put( + CacheOperation(host_indices, device_indices, node_id, priority) + ) + self.mem_pool_host.protect_load(host_indices) + return device_indices + + def write_thread_func_direct(self): + """ + Directly write through KV caches to host memory without buffering. + """ + with torch.cuda.stream(self.write_stream): + while True: + try: + operation = self.write_queue.get(block=True) + operation.data = self.mem_pool_device.get_flat_data( + operation.device_indices + ) + self.mem_pool_host.transfer(operation.host_indices, operation.data) + self.mem_pool_host.complete_io(operation.host_indices) + for node_id in operation.node_ids: + self.ack_write_queue.put(node_id) + except Exception as e: + logger.error(e) + + def load_thread_func_direct(self): + """ + Directly load KV caches from host memory to device memory without buffering. + """ + with torch.cuda.stream(self.load_stream): + while True: + try: + operation = self.load_queue.get(block=True) + operation.data = self.mem_pool_host.get_flat_data( + operation.host_indices + ) + self.mem_pool_device.transfer( + operation.device_indices, operation.data + ) + self.mem_pool_host.complete_io(operation.host_indices) + for node_id in operation.node_ids: + self.ack_load_queue.put(node_id) + except Exception as e: + logger.error(e) + + def write_aux_func(self, no_wait=False): + """ + Auxiliary function to prepare the buffer for write operations. + """ + buffer = None + while True: + try: + operation = self.write_queue.get(block=True) + if buffer is None: + buffer = operation + else: + buffer.merge(operation) + if ( + no_wait + or len(buffer.host_indices) >= self.write_buffer.max_buffer_size + or self.write_queue.empty() + or self.write_buffer.empty() + ): + assert ( + buffer.device_indices.is_cuda + ), "Device indices should be on GPU" + buffer.data = self.mem_pool_device.get_flat_data( + buffer.device_indices + ).contiguous() + self.write_buffer.put(buffer, block=True) + buffer = None + except Exception as e: + logger.error(e) + + def load_aux_func(self): + """ + Auxiliary function to prepare the buffer for load operations. 
+ """ + buffer = None + while True: + try: + operation = self.load_queue.get(block=True) + if buffer is None: + buffer = operation + else: + buffer.merge(operation) + if ( + len(buffer.host_indices) >= self.load_buffer.max_buffer_size + or self.load_queue.empty() + or self.load_buffer.empty() + ): + buffer.data = ( + self.mem_pool_host.get_flat_data(buffer.host_indices) + .contiguous() + .pin_memory() + ) + self.load_buffer.put(buffer, block=True) + buffer = None + except Exception as e: + logger.error(e) + + def write_thread_func_buffer(self): + aux_thread = threading.Thread(target=self.write_aux_func, daemon=True) + aux_thread.start() + with torch.cuda.stream(self.write_stream): + while True: + operation = self.write_buffer.get() + if operation is None: + continue + self.mem_pool_host.transfer(operation.host_indices, operation.data) + self.mem_pool_host.complete_io(operation.host_indices) + for node_id in operation.node_ids: + self.ack_write_queue.put(node_id) + + def load_thread_func_buffer(self): + aux_thread = threading.Thread(target=self.load_aux_func, daemon=True) + aux_thread.start() + with torch.cuda.stream(self.load_stream): + while True: + operation = self.load_buffer.get() + if operation is None: + continue + self.mem_pool_device.transfer(operation.device_indices, operation.data) + self.mem_pool_host.complete_io(operation.host_indices) + for node_id in operation.node_ids: + self.ack_load_queue.put(node_id) + + def evict_device( + self, device_indices: torch.Tensor, host_indices: torch.Tensor + ) -> int: + if self.mem_pool_host.is_synced(host_indices): + self.mem_pool_device.free(device_indices) + self.mem_pool_host.update_backup(host_indices) + return len(device_indices) + else: + raise ValueError( + f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}" + ) + + def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int: + if not backup_only: + raise ValueError("Other eviction policies are not supported yet.") + + if self.mem_pool_host.is_backup(host_indices): + self.mem_pool_host.free(host_indices) + return len(host_indices) + else: + raise ValueError( + f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}" + ) diff --git a/python/sglang/srt/managers/configure_logging.py b/python/sglang/srt/managers/configure_logging.py new file mode 100644 index 000000000000..187af4d9c088 --- /dev/null +++ b/python/sglang/srt/managers/configure_logging.py @@ -0,0 +1,46 @@ +""" +Copyright 2023-2025 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" +Configure the logging settings of a server. 
+ +Usage: +python3 -m sglang.srt.managers.configure_logging --url http://localhost:30000 +""" + +import argparse + +import requests + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--url", type=str, default="http://localhost:30000") + parser.add_argument("--log-requests", action="store_true") + parser.add_argument( + "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump" + ) + parser.add_argument("--dump-requests-threshold", type=int, default=1000) + args = parser.parse_args() + + response = requests.post( + args.url + "/configure_logging", + json={ + "log_requests": args.log_requests, + "log_requests_level": 1, # Log full requests + "dump_requests_folder": args.dump_requests_folder, + "dump_requests_threshold": args.dump_requests_threshold, + }, + ) + assert response.status_code == 200 diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index 7ae6689ee694..3b959b1ba768 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -23,6 +23,7 @@ import setproctitle import zmq +from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.managers.io_struct import ( TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -55,6 +56,7 @@ class DataParallelController: def __init__(self, server_args, port_args) -> None: # Parse args + self.max_total_num_tokens = None self.server_args = server_args self.port_args = port_args self.load_balance_method = LoadBalanceMethod.from_str( @@ -63,9 +65,10 @@ def __init__(self, server_args, port_args) -> None: # Init inter-process communication self.context = zmq.Context(1 + server_args.dp_size) - self.recv_from_tokenizer = get_zmq_socket( - self.context, zmq.PULL, port_args.scheduler_input_ipc_name - ) + if server_args.node_rank == 0: + self.recv_from_tokenizer = get_zmq_socket( + self.context, zmq.PULL, port_args.scheduler_input_ipc_name, False + ) # Dispatch method self.round_robin_counter = 0 @@ -75,33 +78,50 @@ def __init__(self, server_args, port_args) -> None: } self.dispatching = dispatch_lookup[self.load_balance_method] - # Start data parallel workers - base_gpu_id = 0 + # Launch data parallel workers + self.scheduler_procs = [] self.workers = [None] * server_args.dp_size + if not server_args.enable_dp_attention: + dp_port_args = self.launch_dp_schedulers(server_args, port_args) + else: + dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args) + + # Only node rank 0 runs the real data parallel controller that dispatches the requests. + if server_args.node_rank == 0: + for dp_rank in range(server_args.dp_size): + self.workers[dp_rank] = get_zmq_socket( + self.context, + zmq.PUSH, + dp_port_args[dp_rank].scheduler_input_ipc_name, + True, + ) + + self.max_req_input_len = None + + def launch_dp_schedulers(self, server_args, port_args): + base_gpu_id = 0 + threads = [] sockets = [] + dp_port_args = [] for dp_rank in range(server_args.dp_size): tmp_port_args = PortArgs.init_new(server_args) tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name + dp_port_args.append(tmp_port_args) - if server_args.enable_dp_attention: - # Data parallelism resues the tensor parallelism group, - # so all dp ranks should use the same nccl port. - tmp_port_args.nccl_port = port_args.nccl_port - else: - # This port is checked free in PortArgs.init_new. 
- # We hold it first so that the next dp worker gets a different port - sockets.append(bind_port(tmp_port_args.nccl_port)) + # This port is checked free in PortArgs.init_new. + # We hold it first so that the next dp worker gets a different port + sockets.append(bind_port(tmp_port_args.nccl_port)) # Create a thread for each worker thread = threading.Thread( - target=self.launch_worker_func, + target=self.launch_tensor_parallel_group, args=(server_args, tmp_port_args, base_gpu_id, dp_rank), ) threads.append(thread) - base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size + base_gpu_id += server_args.tp_size # Free all sockets before starting the threads to launch TP workers for sock in sockets: @@ -113,26 +133,14 @@ def __init__(self, server_args, port_args) -> None: for thread in threads: thread.join() - def launch_worker_func( - self, - server_args: ServerArgs, - port_args: PortArgs, - base_gpu_id: int, - dp_rank: int, - ): - logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.") + return dp_port_args - launch_func_ = ( - self.launch_tensor_parallel_process - if server_args.enable_dp_attention - else self.launch_tensor_parallel_group - ) - self.workers[dp_rank] = launch_func_( - server_args, - port_args, - base_gpu_id, - dp_rank, - ) + def launch_dp_attention_schedulers(self, server_args, port_args): + self.launch_tensor_parallel_group(server_args, port_args, 0, None) + dp_port_args = [] + for dp_rank in range(server_args.dp_size): + dp_port_args.append(PortArgs.init_new(server_args, dp_rank)) + return dp_port_args def launch_tensor_parallel_group( self, @@ -141,8 +149,10 @@ def launch_tensor_parallel_group( base_gpu_id: int, dp_rank: int, ): + if not server_args.enable_dp_attention: + logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.") + # Launch tensor parallel scheduler processes - scheduler_procs = [] scheduler_pipe_readers = [] tp_size_per_node = server_args.tp_size // server_args.nnodes tp_rank_range = range( @@ -150,52 +160,39 @@ def launch_tensor_parallel_group( tp_size_per_node * (server_args.node_rank + 1), ) for tp_rank in tp_rank_range: + rank_port_args = port_args + + if server_args.enable_dp_attention: + # dp attention has different sharding logic + _, _, dp_rank = compute_dp_attention_world_info( + server_args.enable_dp_attention, + tp_rank, + server_args.tp_size, + server_args.dp_size, + ) + # compute zmq ports for this dp rank + rank_port_args = PortArgs.init_new(server_args, dp_rank) + # Data parallelism resues the tensor parallelism group, + # so all dp ranks should use the same nccl port. 
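# --- Editor's aside (illustrative sketch, not part of the diff): the port-reservation
# trick used in launch_dp_schedulers above. Holding each freshly checked nccl port with
# a bound socket keeps PortArgs.init_new from handing the same port to the next DP
# worker; the sockets are released again right before the worker threads start. The
# helper below is a simplified stand-in for the bind_port utility used in the diff.
import socket

def hold_port(port: int) -> socket.socket:
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind(("", port))
    sock.listen(1)
    return sock  # keep a reference until the workers are launched, then close()
# --- End of editor's aside.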
+ rank_port_args.nccl_port = port_args.nccl_port + reader, writer = mp.Pipe(duplex=False) gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node proc = mp.Process( target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer), + args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer), ) proc.start() - scheduler_procs.append(proc) + self.scheduler_procs.append(proc) scheduler_pipe_readers.append(reader) - send_to = get_zmq_socket( - self.context, zmq.PUSH, port_args.scheduler_input_ipc_name - ) - - # Wait for model to finish loading and get max token nums + # Wait for model to finish loading scheduler_info = [] for i in range(len(scheduler_pipe_readers)): scheduler_info.append(scheduler_pipe_readers[i].recv()) self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"] - - return send_to - - def launch_tensor_parallel_process( - self, - server_args: ServerArgs, - port_args: PortArgs, - base_gpu_id: int, - dp_rank: int, - ): - reader, writer = mp.Pipe(duplex=False) - gpu_id = base_gpu_id - tp_rank = dp_rank - proc = mp.Process( - target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer), - ) - proc.start() - send_to = get_zmq_socket( - self.context, zmq.PUSH, port_args.scheduler_input_ipc_name - ) - - scheduler_info = reader.recv() - self.max_total_num_tokens = scheduler_info["max_total_num_tokens"] - - return send_to + self.max_req_input_len = scheduler_info[0]["max_req_input_len"] def round_robin_scheduler(self, req): self.workers[self.round_robin_counter].send_pyobj(req) @@ -221,8 +218,8 @@ def event_loop(self): ): self.dispatching(recv_req) else: - # Send other control messages to all workers - for worker in self.workers: + # Send other control messages to first worker of tp group + for worker in self.workers[:: self.server_args.tp_size]: worker.send_pyobj(recv_req) @@ -238,9 +235,19 @@ def run_data_parallel_controller_process( try: controller = DataParallelController(server_args, port_args) pipe_writer.send( - {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens} + { + "status": "ready", + "max_total_num_tokens": controller.max_total_num_tokens, + "max_req_input_len": controller.max_req_input_len, + } ) - controller.event_loop() + if server_args.node_rank == 0: + controller.event_loop() + for proc in controller.scheduler_procs: + proc.join() + logger.error( + f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}" + ) except Exception: traceback = get_exception_traceback() logger.error(f"DataParallelController hit an exception: {traceback}") diff --git a/python/sglang/srt/managers/detokenizer_manager.py b/python/sglang/srt/managers/detokenizer_manager.py index fd77d338edce..972f9595b2c8 100644 --- a/python/sglang/srt/managers/detokenizer_manager.py +++ b/python/sglang/srt/managers/detokenizer_manager.py @@ -15,6 +15,7 @@ import dataclasses import logging +import os import signal from collections import OrderedDict from typing import Dict, List, Union @@ -35,6 +36,12 @@ logger = logging.getLogger(__name__) +# Maximum number of request states that detokenizer can hold. When exceeded, +# oldest request states will be evicted. Default: 65536 (1<<16). +# For more details, see: https://github.com/sgl-project/sglang/issues/2812 +# Use power of 2 values for better memory allocation. 
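# --- Editor's aside (illustrative sketch, not part of the diff): the eviction
# behaviour that the SGLANG_DETOKENIZER_MAX_STATES comment above refers to. A minimal
# capacity-bounded OrderedDict; the real LimitedCapacityDict at the bottom of this
# file plays this role for the detokenizer's per-request decode states.
from collections import OrderedDict

class CappedDict(OrderedDict):
    def __init__(self, capacity: int):
        super().__init__()
        self.capacity = capacity

    def __setitem__(self, key, value):
        if key not in self and len(self) >= self.capacity:
            self.popitem(last=False)  # evict the oldest entry
        super().__setitem__(key, value)

# With capacity=2, inserting a third request id evicts the first one, which is what
# the KeyError handler in DetokenizerManager.event_loop below reports to the user.
# --- End of editor's aside.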
+DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 << 16)) + @dataclasses.dataclass class DecodeStatus: @@ -58,10 +65,10 @@ def __init__( # Init inter-process communication context = zmq.Context(2) self.recv_from_scheduler = get_zmq_socket( - context, zmq.PULL, port_args.detokenizer_ipc_name + context, zmq.PULL, port_args.detokenizer_ipc_name, True ) self.send_to_tokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name + context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) if server_args.skip_tokenizer_init: @@ -71,9 +78,10 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) - self.decode_status = LimitedCapacityDict() + self.decode_status = LimitedCapacityDict(capacity=DETOKENIZER_MAX_STATES) def trim_matched_stop( self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool @@ -155,7 +163,17 @@ def event_loop(self): # Incremental decoding output_strs = [] for i in range(bs): - s = self.decode_status[recv_obj.rids[i]] + try: + s = self.decode_status[recv_obj.rids[i]] + except KeyError: + raise RuntimeError( + f"Decode status not found for request {recv_obj.rids[i]}. " + "It may be due to the request being evicted from the decode status due to memory pressure. " + "Please increase the maximum number of requests by setting " + "the SGLANG_DETOKENIZER_MAX_STATES environment variable to a bigger value than the default value. " + f"The current value is {DETOKENIZER_MAX_STATES}. " + "For more details, see: https://github.com/sgl-project/sglang/issues/2812" + ) new_text = read_texts[i][len(surr_texts[i]) :] if recv_obj.finished_reasons[i] is None: # Streaming chunk: update the decode status @@ -181,8 +199,6 @@ def event_loop(self): finished_reasons=recv_obj.finished_reasons, output_strs=output_strs, prompt_tokens=recv_obj.prompt_tokens, - origin_input_ids=recv_obj.origin_input_ids, - output_ids=recv_obj.output_ids, completion_tokens=recv_obj.completion_tokens, cached_tokens=recv_obj.cached_tokens, input_token_logprobs_val=recv_obj.input_token_logprobs_val, @@ -193,13 +209,12 @@ def event_loop(self): input_top_logprobs_idx=recv_obj.input_top_logprobs_idx, output_top_logprobs_val=recv_obj.output_top_logprobs_val, output_top_logprobs_idx=recv_obj.output_top_logprobs_idx, - normalized_prompt_logprob=recv_obj.normalized_prompt_logprob, ) ) class LimitedCapacityDict(OrderedDict): - def __init__(self, capacity=1 << 15, *args, **kwargs): + def __init__(self, capacity: int, *args, **kwargs): super().__init__(*args, **kwargs) self.capacity = capacity diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index 7120fa48d525..c8ebbed783ae 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -9,6 +9,8 @@ import numpy as np import transformers +from decord import VideoReader, cpu +from PIL import Image from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.mm_utils import expand2square, process_anyres_image @@ -36,6 +38,7 @@ class BaseImageProcessor(ABC): def __init__(self, hf_config, server_args, _processor): self.hf_config = hf_config self._processor = _processor + self.server_args = server_args self.executor = concurrent.futures.ProcessPoolExecutor( initializer=init_global_processor, @@ -126,7 +129,12 @@ async def _process_single_image( ) async def process_images_async( - self, 
image_data: List[Union[str, bytes]], input_text, request_obj + self, + image_data: List[Union[str, bytes]], + input_text, + request_obj, + *args, + **kwargs, ): if not image_data: return None @@ -229,6 +237,147 @@ async def process_images_async( return image_inputs +class MiniCPMVImageProcessor(BaseImageProcessor): + def __init__(self, hf_config, server_args, _processor): + super().__init__(hf_config, server_args, _processor) + + @staticmethod + def _process_images_task(images, input_text): + result = global_processor.__call__( + text=input_text, images=images, return_tensors="pt" + ) + return { + "input_ids": result["input_ids"], + "pixel_values": result["pixel_values"], + "tgt_sizes": result["tgt_sizes"], + } + + async def _process_images(self, images, input_text): + if self.executor is not None: + loop = asyncio.get_event_loop() + image_inputs = await loop.run_in_executor( + self.executor, + MiniCPMVImageProcessor._process_images_task, + images, + input_text, + ) + else: + image_inputs = self._processor( + images=images, text=input_text, return_tensors="pt" + ) + + return image_inputs + + async def process_images_async( + self, + image_data: List[Union[str, bytes]], + input_text, + request_obj, + max_req_input_len, + ): + if not image_data: + return None + + if not isinstance(image_data, list): + image_data = [image_data] + + image_hashes, image_sizes = [], [] + raw_images = [] + IMAGE_TOKEN = "(./)" + + # roughly calculate the max number of frames + # TODO: the process should be applied to all the visual inputs + def calculate_max_num_frames() -> int: + # Model-specific + NUM_TOKEN_PER_FRAME = 330 + + ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME + return min(ret, 100) + + # if cuda OOM set a smaller number + MAX_NUM_FRAMES = calculate_max_num_frames() + print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}") + + def encode_video(video_path): + if not os.path.exists(video_path): + logger.error(f"Video {video_path} does not exist") + return [] + + if MAX_NUM_FRAMES == 0: + return [] + + def uniform_sample(l, n): + gap = len(l) / n + idxs = [int(i * gap + gap / 2) for i in range(n)] + return [l[i] for i in idxs] + + vr = VideoReader(video_path, ctx=cpu(0)) + sample_fps = round(vr.get_avg_fps() / 1) # FPS + frame_idx = [i for i in range(0, len(vr), sample_fps)] + if len(frame_idx) > MAX_NUM_FRAMES: + frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES) + frames = vr.get_batch(frame_idx).asnumpy() + frames = [Image.fromarray(v.astype("uint8")) for v in frames] + return frames + + if isinstance(input_text, list): + assert len(input_text) and isinstance(input_text[0], int) + input_text = self._processor.tokenizer.decode(input_text) + + # MiniCPMV requires each frame of video as a single image token + text_parts = input_text.split(IMAGE_TOKEN) + new_text_parts = [] + + for image_index, image in enumerate(image_data): + try: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + frames = encode_video(path) + else: + raw_image, size = load_image(image) + frames = [raw_image] + if len(frames) == 0: + continue + except FileNotFoundError as e: + print(e) + return None + + image_sizes += frames[0].size * len(frames) + image_hashes += [hash(image)] * len(frames) + raw_images += frames + new_text_parts.append(text_parts[image_index]) + new_text_parts.append(IMAGE_TOKEN * len(frames)) + + new_text_parts.append(text_parts[-1]) + input_text = "".join(new_text_parts) + if len(raw_images) == 0: + return None + res = await 
self._process_images(images=raw_images, input_text=input_text) + pixel_values = res["pixel_values"] + tgt_sizes = res["tgt_sizes"] + input_ids = res["input_ids"] + + # Collect special token ids + tokenizer = self._processor.tokenizer + im_start_id = [tokenizer.im_start_id] + im_end_id = [tokenizer.im_end_id] + if tokenizer.slice_start_id: + slice_start_id = [tokenizer.slice_start_id] + slice_end_id = [tokenizer.slice_end_id] + + return { + "input_ids": input_ids.flatten().tolist(), + "pixel_values": pixel_values, + "tgt_sizes": tgt_sizes, + "image_hashes": image_hashes, + "modalities": request_obj.modalities or ["image"], + "im_start_id": im_start_id, + "im_end_id": im_end_id, + "slice_start_id": slice_start_id, + "slice_end_id": slice_end_id, + } + + class Qwen2VLImageProcessor(BaseImageProcessor): def __init__(self, hf_config, server_args, _image_processor): self.hf_config = hf_config @@ -289,7 +438,12 @@ async def _process_single_image(self, image_data: Union[bytes, str]): return self._process_single_image_task(image_data) async def process_images_async( - self, image_data: List[Union[str, bytes]], input_text, request_obj + self, + image_data: List[Union[str, bytes]], + input_text, + request_obj, + *args, + **kwargs, ): if not image_data: return None @@ -350,6 +504,8 @@ def get_image_processor( return MllamaImageProcessor(hf_config, server_args, processor) elif "Qwen2VLForConditionalGeneration" in hf_config.architectures: return Qwen2VLImageProcessor(hf_config, server_args, processor.image_processor) + elif "MiniCPMV" in hf_config.architectures: + return MiniCPMVImageProcessor(hf_config, server_args, processor) else: return LlavaImageProcessor(hf_config, server_args, processor.image_processor) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 1aae28b00b76..eee9b6722d4f 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -19,9 +19,7 @@ import uuid from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Optional, Tuple, Union - -import torch +from typing import Dict, List, Optional, Union from sglang.srt.managers.schedule_batch import BaseFinishReason from sglang.srt.sampling.sampling_params import SamplingParams @@ -61,6 +59,9 @@ class GenerateReqInput: return_text_in_logprobs: bool = False # Whether to stream output. stream: bool = False + # Whether to log metrics for this request (e.g. health_generate calls do not log metrics) + log_metrics: bool = True + # The modalities of the image data [image, multi-images, video] modalities: Optional[List[str]] = None # LoRA related @@ -68,6 +69,10 @@ class GenerateReqInput: # Session info for continual prompting session_params: Optional[Union[List[Dict], Dict]] = None + # Custom logit processor for advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string. 
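# --- Editor's aside (illustrative sketch, not part of the diff): rough shape of the
# custom_logit_processor flow that the new GenerateReqInput field below and the
# Sampler._apply_custom_logit_processor hunk above describe. Only to_str() and the
# (logits, custom_params) call signature appear in this PR; the subclass body and the
# banned-token parameters are hypothetical.
import torch
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

class DisallowTokensProcessor(CustomLogitProcessor):
    def __call__(self, logits: torch.Tensor, custom_params: list) -> torch.Tensor:
        # logits: [num_masked_requests, vocab_size]; one params dict per request.
        for i, params in enumerate(custom_params):
            for token_id in params.get("banned_token_ids", []):
                logits[i, token_id] = float("-inf")
        return logits

# The serialized form, DisallowTokensProcessor().to_str(), is what gets passed as the
# custom_logit_processor field of a generate request.
# --- End of editor's aside.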
+ custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None def normalize_batch_and_arguments(self): if ( @@ -182,6 +187,13 @@ def normalize_batch_and_arguments(self): else: assert self.parallel_sample_num == 1 + if self.custom_logit_processor is None: + self.custom_logit_processor = [None] * num + elif not isinstance(self.custom_logit_processor, list): + self.custom_logit_processor = [self.custom_logit_processor] * num + else: + assert self.parallel_sample_num == 1 + def regenerate_rid(self): self.rid = uuid.uuid4().hex return self.rid @@ -198,8 +210,14 @@ def __getitem__(self, i): top_logprobs_num=self.top_logprobs_num[i], return_text_in_logprobs=self.return_text_in_logprobs, stream=self.stream, + log_metrics=self.log_metrics, modalities=self.modalities[i] if self.modalities else None, lora_path=self.lora_path[i] if self.lora_path is not None else None, + custom_logit_processor=( + self.custom_logit_processor[i] + if self.custom_logit_processor is not None + else None + ), ) @@ -232,6 +250,11 @@ class TokenizedGenerateReqInput: # Session info for continual prompting session_params: Optional[SessionParams] = None + # Custom logit processor for advanced sampling control. Must be a serialized instance + # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py + # Use the processor's `to_str()` method to generate the serialized string. + custom_logit_processor: Optional[str] = None + @dataclass class EmbeddingReqInput: @@ -245,6 +268,8 @@ class EmbeddingReqInput: sampling_params: Union[List[Dict], Dict] = None # Dummy input embeds for compatibility input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None + # Whether to log metrics for this request (e.g. health_generate calls do not log metrics) + log_metrics: bool = True def normalize_batch_and_arguments(self): if (self.text is None and self.input_ids is None) or ( @@ -323,9 +348,7 @@ class BatchTokenIDOut: decoded_texts: List[str] decode_ids: List[int] read_offsets: List[int] - # Only used when --return-token-ids` is set - origin_input_ids: Optional[List[int]] - # Only used when `--skip-tokenizer-init` or `--return-token-ids` is set + # Only used when `--skip-tokenizer-init` is on output_ids: Optional[List[int]] # Detokenization configs skip_special_tokens: List[bool] @@ -344,7 +367,6 @@ class BatchTokenIDOut: input_top_logprobs_idx: List[List] output_top_logprobs_val: List[List] output_top_logprobs_idx: List[List] - normalized_prompt_logprob: List[float] @dataclass @@ -356,14 +378,7 @@ class BatchStrOut: # The output decoded strings output_strs: List[str] - # The token ids - origin_input_ids: Optional[List[int]] - output_ids: Optional[List[int]] - # Token counts - # real input and output tokens can be get from - # origin_input_ids and output_ids by enabling --return_token_ids - # TODO (Shuai): Rename this to clarify the meaning. 
prompt_tokens: List[int] completion_tokens: List[int] cached_tokens: List[int] @@ -377,7 +392,6 @@ class BatchStrOut: input_top_logprobs_idx: List[List] output_top_logprobs_val: List[List] output_top_logprobs_idx: List[List] - normalized_prompt_logprob: List[float] @dataclass @@ -468,6 +482,26 @@ class GetWeightsByNameReqOutput: parameter: list +@dataclass +class ReleaseMemoryOccupationReqInput: + pass + + +@dataclass +class ReleaseMemoryOccupationReqOutput: + pass + + +@dataclass +class ResumeMemoryOccupationReqInput: + pass + + +@dataclass +class ResumeMemoryOccupationReqOutput: + pass + + @dataclass class AbortReq: # The request id @@ -479,6 +513,14 @@ class ProfileReq(Enum): STOP_PROFILE = 2 +@dataclass +class ConfigureLoggingReq: + log_requests: Optional[bool] = None + log_requests_level: Optional[int] = None + dump_requests_folder: Optional[str] = None + dump_requests_threshold: Optional[int] = None + + @dataclass class OpenSessionReqInput: capacity_of_str_len: int diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 3b056cc5d492..6c44b17ffd86 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -52,7 +52,6 @@ if TYPE_CHECKING: from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm - INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 # Put some global args for easy access @@ -65,9 +64,9 @@ "enable_nan_detection": ServerArgs.enable_nan_detection, "enable_dp_attention": ServerArgs.enable_dp_attention, "enable_ep_moe": ServerArgs.enable_ep_moe, + "device": ServerArgs.device, } - logger = logging.getLogger(__name__) @@ -116,14 +115,18 @@ def to_json(self): class FINISH_ABORT(BaseFinishReason): - def __init__(self, message="Unknown error"): + def __init__(self, message="Unknown error", status_code=None, err_type=None): super().__init__(is_error=True) self.message = message + self.status_code = status_code + self.err_type = err_type def to_json(self): return { "type": "abort", "message": self.message, + "status_code": self.status_code, + "err_type": self.err_type, } @@ -148,6 +151,15 @@ class ImageInputs: image_grid_thws: List[Tuple[int, int, int]] = None mrope_position_delta: Optional[torch.Tensor] = None + # MiniCPMV related + # All the images in the batch should share the same special image + # bound token ids. + im_start_id: Optional[torch.Tensor] = None + im_end_id: Optional[torch.Tensor] = None + slice_start_id: Optional[torch.Tensor] = None + slice_end_id: Optional[torch.Tensor] = None + tgt_sizes: Optional[list] = None + @staticmethod def from_dict(obj: dict): ret = ImageInputs( @@ -167,6 +179,11 @@ def from_dict(obj: dict): "aspect_ratio_ids", "aspect_ratio_mask", "image_grid_thws", + "im_start_id", + "im_end_id", + "slice_start_id", + "slice_end_id", + "tgt_sizes", ] for arg in optional_args: if arg in obj: @@ -215,6 +232,7 @@ def __init__( lora_path: Optional[str] = None, input_embeds: Optional[List[List[float]]] = None, session_id: Optional[str] = None, + custom_logit_processor: Optional[str] = None, eos_token_ids: Optional[Set[int]] = None, ): # Input and output info @@ -226,14 +244,16 @@ def __init__( else origin_input_ids # Before image padding ) self.origin_input_ids = origin_input_ids - self.output_ids = [] # Each decode stage's output ids - self.fill_ids = None # fill_ids = origin_input_ids + output_ids + # Each decode stage's output ids + self.output_ids = [] + # fill_ids = origin_input_ids + output_ids. Updated if chunked. 
self.session_id = session_id self.input_embeds = input_embeds # Sampling info self.sampling_params = sampling_params self.lora_path = lora_path + self.custom_logit_processor = custom_logit_processor # Memory pool info self.req_pool_idx = None @@ -265,6 +285,7 @@ def __init__( # Prefix info self.prefix_indices = [] # Tokens to run prefill. input_tokens - shared_prefix_tokens. + # Updated if chunked. self.extend_input_len = 0 self.last_node = None @@ -280,11 +301,10 @@ def __init__( self.top_logprobs_num = top_logprobs_num # Logprobs (return value) - self.normalized_prompt_logprob = None - self.input_token_logprobs_val = None - self.input_token_logprobs_idx = None - self.input_top_logprobs_val = None - self.input_top_logprobs_idx = None + self.input_token_logprobs_val: Optional[List[float]] = None + self.input_token_logprobs_idx: Optional[List[int]] = None + self.input_top_logprobs_val: Optional[List[float]] = None + self.input_top_logprobs_idx: Optional[List[int]] = None if return_logprob: self.output_token_logprobs_val = [] @@ -344,9 +364,6 @@ def adjust_max_prefix_ids(self): max_prefix_len = min(max_prefix_len, input_len - 1) if self.return_logprob: - if self.normalized_prompt_logprob is None: - # Need at least two tokens to compute normalized logprob - max_prefix_len = min(max_prefix_len, input_len - 2) max_prefix_len = min(max_prefix_len, self.logprob_start_len) max_prefix_len = max(max_prefix_len, 0) @@ -533,13 +550,13 @@ class ScheduleBatch: next_batch_sampling_info: SamplingBatchInfo = None # Batched arguments to model runner - input_ids: torch.Tensor = None - input_embeds: torch.Tensor = None - req_pool_indices: torch.Tensor = None - seq_lens: torch.Tensor = None + input_ids: torch.Tensor = None # shape: [b], int32 + input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32 + req_pool_indices: torch.Tensor = None # shape: [b], int32 + seq_lens: torch.Tensor = None # shape: [b], int64 # The output locations of the KV cache - out_cache_loc: torch.Tensor = None - output_ids: torch.Tensor = None + out_cache_loc: torch.Tensor = None # shape: [b], int32 + output_ids: torch.Tensor = None # shape: [b], int32 # The sum of all sequence lengths seq_lens_sum: int = None @@ -578,6 +595,9 @@ class ScheduleBatch: spec_algorithm: SpeculativeAlgorithm = None spec_info: Optional[SpecInfo] = None + # Enable custom logit processor + enable_custom_logit_processor: bool = False + @classmethod def init_new( cls, @@ -588,6 +608,7 @@ def init_new( model_config: ModelConfig, enable_overlap: bool, spec_algorithm: SpeculativeAlgorithm, + enable_custom_logit_processor: bool, ): return cls( reqs=reqs, @@ -601,6 +622,7 @@ def init_new( has_grammar=any(req.grammar for req in reqs), device=req_to_token_pool.device, spec_algorithm=spec_algorithm, + enable_custom_logit_processor=enable_custom_logit_processor, ) def batch_size(self): @@ -656,7 +678,7 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) or len(req.prefix_indices) >= im.num_image_tokens ) - self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int32).to( + self.encoder_lens = torch.tensor(self.encoder_lens_cpu, dtype=torch.int64).to( self.device, non_blocking=True ) @@ -690,7 +712,7 @@ def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]) self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to( self.device, non_blocking=True ) - self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to( + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to( 
self.device, non_blocking=True ) @@ -766,10 +788,10 @@ def prepare_for_extend(self): self.input_ids = torch.tensor(sum(input_ids, []), dtype=torch.int32).to( self.device, non_blocking=True ) - self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int32).to( + self.req_pool_indices = torch.tensor(req_pool_indices, dtype=torch.int64).to( self.device, non_blocking=True ) - self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to( + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to( self.device, non_blocking=True ) self.input_embeds = ( @@ -1002,11 +1024,16 @@ def prepare_encoder_info_decode(self): def prepare_for_idle(self): self.forward_mode = ForwardMode.IDLE self.input_ids = torch.empty(0, dtype=torch.int32, device=self.device) - self.seq_lens = torch.empty(0, dtype=torch.int32, device=self.device) + self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device) self.out_cache_loc = torch.empty(0, dtype=torch.int32, device=self.device) self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device) self.seq_lens_sum = 0 self.extend_num_tokens = 0 + self.sampling_info = SamplingBatchInfo.from_schedule_batch( + self, + self.model_config.vocab_size, + enable_overlap_schedule=self.enable_overlap, + ) def prepare_for_decode(self): self.forward_mode = ForwardMode.DECODE @@ -1067,7 +1094,7 @@ def filter_batch( self.encoder_lens_cpu = [self.encoder_lens_cpu[i] for i in keep_indices] self.reqs = [self.reqs[i] for i in keep_indices] - new_indices = torch.tensor(keep_indices, dtype=torch.int32).to( + new_indices = torch.tensor(keep_indices, dtype=torch.int64).to( self.device, non_blocking=True ) self.req_pool_indices = self.req_pool_indices[new_indices] @@ -1085,6 +1112,8 @@ def filter_batch( self.has_grammar = any(req.grammar for req in self.reqs) self.sampling_info.filter_batch(keep_indices, new_indices) + if self.spec_info: + self.spec_info.filter_batch(new_indices) def merge_batch(self, other: "ScheduleBatch"): # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because @@ -1121,7 +1150,7 @@ def merge_batch(self, other: "ScheduleBatch"): self.spec_info.merge_batch(other.spec_info) def get_model_worker_batch(self): - if self.forward_mode.is_decode() or self.forward_mode.is_idle(): + if self.forward_mode.is_decode_or_idle(): extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None else: extend_seq_lens = self.extend_lens @@ -1136,7 +1165,6 @@ def get_model_worker_batch(self): global bid bid += 1 - return ModelWorkerBatch( bid=bid, forward_mode=self.forward_mode, @@ -1180,6 +1208,7 @@ def copy(self): return_logprob=self.return_logprob, decoding_reqs=self.decoding_reqs, spec_algorithm=self.spec_algorithm, + enable_custom_logit_processor=self.enable_custom_logit_processor, ) def __str__(self): diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index d2083d092bcd..a3a099b83de2 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -24,6 +24,7 @@ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache +from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large. 
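The PrefillAdder hunk that follows stops caching rem_total_tokens as a counter that is decremented in place and instead tracks only an offset, recomputing the remaining budget on demand from token_to_kv_pool.available_size() and tree_cache.evictable_size(). The sketch below restates that accounting pattern in isolation; the _Pool and _Cache stubs are hypothetical stand-ins for the real KV pool and prefix cache, not part of the change.

class _Pool:
    """Hypothetical stand-in for BaseTokenToKVPool.available_size()."""
    def __init__(self, n):
        self.n = n
    def available_size(self):
        return self.n


class _Cache:
    """Hypothetical stand-in for the prefix cache's evictable_size()."""
    def __init__(self, n):
        self.n = n
    def evictable_size(self):
        return self.n


class BudgetSketch:
    def __init__(self, pool, cache, mixed_with_decode_tokens=0):
        self.pool, self.cache = pool, cache
        # Only the offset is stored; live sizes are read when the budget is queried.
        self.rem_total_token_offset = mixed_with_decode_tokens

    @property
    def rem_total_tokens(self):
        return (
            self.pool.available_size()
            + self.cache.evictable_size()
            - self.rem_total_token_offset
        )

    def prefill_one_req(self, extend_input_len, max_new_tokens):
        # Reserving budget only grows the offset, mirroring _prefill_one_req.
        self.rem_total_token_offset += extend_input_len + max_new_tokens


budget = BudgetSketch(_Pool(10_000), _Cache(2_000))
budget.prefill_one_req(extend_input_len=512, max_new_tokens=128)
assert budget.rem_total_tokens == 10_000 + 2_000 - 640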
@@ -250,23 +251,24 @@ class PrefillAdder: def __init__( self, tree_cache: BasePrefixCache, + token_to_kv_pool: BaseTokenToKVPool, running_batch: ScheduleBatch, new_token_ratio: float, - rem_total_tokens: int, rem_input_tokens: int, rem_chunk_tokens: Optional[int], mixed_with_decode_tokens: int = 0, ): self.tree_cache = tree_cache + self.token_to_kv_pool = token_to_kv_pool self.running_batch = running_batch self.new_token_ratio = new_token_ratio - self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens self.rem_chunk_tokens = rem_chunk_tokens if self.rem_chunk_tokens is not None: self.rem_chunk_tokens -= mixed_with_decode_tokens - self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens + self.rem_total_token_offset = mixed_with_decode_tokens + self.cur_rem_token_offset = mixed_with_decode_tokens self.req_states = None self.can_run_list = [] @@ -275,8 +277,7 @@ def __init__( self.log_input_tokens = 0 if running_batch is not None: - # Pre-remove the tokens which will be occupied by the running requests - self.rem_total_tokens -= sum( + self.rem_total_token_offset += sum( [ min( (r.sampling_params.max_new_tokens - len(r.output_ids)), @@ -287,6 +288,22 @@ def __init__( ] ) + @property + def rem_total_tokens(self): + return ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + - self.rem_total_token_offset + ) + + @property + def cur_rem_tokens(self): + return ( + self.token_to_kv_pool.available_size() + + self.tree_cache.evictable_size() + - self.cur_rem_token_offset + ) + def budget_state(self): if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0: return AddReqResult.NO_TOKEN @@ -301,8 +318,8 @@ def budget_state(self): def _prefill_one_req( self, prefix_len: int, extend_input_len: int, max_new_tokens: int ): - self.rem_total_tokens -= extend_input_len + max_new_tokens - self.cur_rem_tokens -= extend_input_len + self.rem_total_token_offset += extend_input_len + max_new_tokens + self.cur_rem_token_offset += extend_input_len self.rem_input_tokens -= extend_input_len if self.rem_chunk_tokens is not None: self.rem_chunk_tokens -= extend_input_len @@ -332,12 +349,10 @@ def add_being_chunked_req(self, req: Req): @contextmanager def _lock_node(self, last_node: TreeNode): try: - delta = self.tree_cache.inc_lock_ref(last_node) - self.rem_total_tokens += delta + self.tree_cache.inc_lock_ref(last_node) yield None finally: - delta = self.tree_cache.dec_lock_ref(last_node) - self.rem_total_tokens += delta + self.tree_cache.dec_lock_ref(last_node) def add_one_req_ignore_eos(self, req: Req): def add_req_state(r, insert_sort=False): @@ -433,7 +448,6 @@ def add_one_req(self, req: Req): or input_tokens <= self.rem_chunk_tokens or ( req.return_logprob - and req.normalized_prompt_logprob is None and req.logprob_start_len != len(req.origin_input_ids) - 1 ) ): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 6022a2567343..85bd1c2a4adf 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -13,6 +13,7 @@ # ============================================================================== """A scheduler that manages a tensor parallel GPU worker.""" +import faulthandler import logging import os import signal @@ -21,8 +22,10 @@ import warnings from collections import deque from concurrent import futures +from dataclasses import dataclass +from http import HTTPStatus from types import SimpleNamespace -from 
typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union import psutil import setproctitle @@ -31,7 +34,9 @@ from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.constrained.base_grammar_backend import create_grammar_backend from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer +from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, @@ -46,6 +51,10 @@ OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, + ResumeMemoryOccupationReqOutput, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, UpdateWeightFromDiskReqInput, @@ -71,12 +80,14 @@ from sglang.srt.managers.session_controller import Session from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient +from sglang.srt.managers.utils import validate_input_length from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( broadcast_pyobj, configure_logger, @@ -87,7 +98,7 @@ set_random_seed, suppress_other_loggers, ) -from sglang.utils import get_exception_traceback +from sglang.utils import TypeBasedDispatcher, get_exception_traceback logger = logging.getLogger(__name__) @@ -95,6 +106,19 @@ test_retract = get_bool_env_var("SGLANG_TEST_RETRACT") +@dataclass +class GenerationBatchResult: + logits_output: LogitsProcessorOutput + next_token_ids: List[int] + bid: int + + +@dataclass +class EmbeddingBatchResult: + embeddings: torch.Tensor + bid: int + + class Scheduler: """A scheduler that manages a tensor parallel GPU worker.""" @@ -126,26 +150,36 @@ def __init__( else 1 ) + # Distributed rank info + self.dp_size = server_args.dp_size + self.attn_tp_rank, self.attn_tp_size, self.dp_rank = ( + compute_dp_attention_world_info( + server_args.enable_dp_attention, + self.tp_rank, + self.tp_size, + self.dp_size, + ) + ) + # Init inter-process communication context = zmq.Context(2) - - if self.tp_rank == 0 or self.server_args.enable_dp_attention: + if self.attn_tp_rank == 0: self.recv_from_tokenizer = get_zmq_socket( - context, zmq.PULL, port_args.scheduler_input_ipc_name + context, zmq.PULL, port_args.scheduler_input_ipc_name, False ) self.send_to_tokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name + context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) if server_args.skip_tokenizer_init: # Directly send to the TokenizerManager self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.tokenizer_ipc_name + context, zmq.PUSH, port_args.tokenizer_ipc_name, False ) else: # Send to the DetokenizerManager self.send_to_detokenizer = get_zmq_socket( - context, zmq.PUSH, port_args.detokenizer_ipc_name + context, zmq.PUSH, port_args.detokenizer_ipc_name, False ) else: self.recv_from_tokenizer = None @@ -173,6 +207,7 @@ def __init__( 
server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.tokenizer = self.processor.tokenizer else: @@ -180,6 +215,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) # Check whether overlap can be enabled @@ -208,7 +244,7 @@ def __init__( nccl_port=port_args.nccl_port, ) - # Launch worker for speculative decoding if need + # Launch a worker for speculative decoding if needed if self.spec_algorithm.is_eagle(): from sglang.srt.speculative.eagle_worker import EAGLEWorker @@ -238,10 +274,10 @@ def __init__( _, ) = self.tp_worker.get_worker_info() self.tp_cpu_group = self.tp_worker.get_tp_cpu_group() + self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group() self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func() global_server_args_dict.update(worker_global_server_args_dict) set_random_seed(self.random_seed) - # Print debug info logger.info( f"max_total_num_tokens={self.max_total_num_tokens}, " @@ -281,9 +317,13 @@ def __init__( self.forward_ct = 0 self.forward_ct_decode = 0 self.num_generated_tokens = 0 + self.spec_num_total_accepted_tokens = 0 + self.spec_num_total_forward_ct = 0 self.last_decode_stats_tic = time.time() self.stream_interval = server_args.stream_interval self.current_stream = torch.get_device_module(self.device).current_stream() + if self.device == "cpu": + self.current_stream.synchronize = lambda: None # No-op for CPU # Session info self.sessions: Dict[str, Session] = {} @@ -300,28 +340,9 @@ def __init__( # Init the grammar backend for constrained generation self.grammar_queue: List[Req] = [] if not server_args.skip_tokenizer_init: - if server_args.grammar_backend == "outlines": - from sglang.srt.constrained.outlines_backend import ( - OutlinesGrammarBackend, - ) - - self.grammar_backend = OutlinesGrammarBackend( - self.tokenizer, - whitespace_pattern=server_args.constrained_json_whitespace_pattern, - allow_jump_forward=not server_args.disable_jump_forward, - ) - elif server_args.grammar_backend == "xgrammar": - from sglang.srt.constrained.xgrammar_backend import ( - XGrammarGrammarBackend, - ) - - self.grammar_backend = XGrammarGrammarBackend( - self.tokenizer, vocab_size=self.model_config.vocab_size - ) - else: - raise ValueError( - f"Invalid grammar backend: {server_args.grammar_backend}" - ) + self.grammar_backend = create_grammar_backend( + server_args, self.tokenizer, self.model_config.vocab_size + ) else: self.grammar_backend = None @@ -356,6 +377,10 @@ def __init__( t.start() self.parent_process = psutil.Process().parent() + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=server_args.enable_memory_saver + ) + # Init profiler if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "": self.profiler = None @@ -383,22 +408,53 @@ def __init__( }, ) + # Init request dispatcher + self._request_dispatcher = TypeBasedDispatcher( + [ + (TokenizedGenerateReqInput, self.handle_generate_request), + (TokenizedEmbeddingReqInput, self.handle_embedding_request), + (FlushCacheReq, self.flush_cache_wrapped), + (AbortReq, self.abort_request), + (UpdateWeightFromDiskReqInput, self.update_weights_from_disk), + (InitWeightsUpdateGroupReqInput, self.init_weights_update_group), + ( + UpdateWeightsFromDistributedReqInput, + self.update_weights_from_distributed, + ), + (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), 
+ (GetWeightsByNameReqInput, self.get_weights_by_name), + (ProfileReq, self.profile), + (OpenSessionReqInput, self.open_session), + (CloseSessionReqInput, self.close_session), + ( + ReleaseMemoryOccupationReqInput, + lambda _: self.release_memory_occupation(), + ), + ( + ResumeMemoryOccupationReqInput, + lambda _: self.resume_memory_occupation(), + ), + ] + ) + def watchdog_thread(self): """A watch dog thread that will try to kill the server itself if one batch takes too long.""" self.watchdog_last_forward_ct = 0 self.watchdog_last_time = time.time() while True: + current = time.time() if self.cur_batch is not None: if self.watchdog_last_forward_ct == self.forward_ct: - if time.time() > self.watchdog_last_time + self.watchdog_timeout: + if current > self.watchdog_last_time + self.watchdog_timeout: logger.error(f"Watchdog timeout ({self.watchdog_timeout=})") break else: self.watchdog_last_forward_ct = self.forward_ct - self.watchdog_last_time = time.time() - time.sleep(self.watchdog_timeout / 2) - + self.watchdog_last_time = current + time.sleep(self.watchdog_timeout // 2) + # Wait sometimes so that the parent process can print the error. + time.sleep(5) self.parent_process.send_signal(signal.SIGQUIT) @torch.no_grad() @@ -409,10 +465,6 @@ def event_loop_normal(self): self.process_input_requests(recv_reqs) batch = self.get_next_batch_to_run() - - if self.server_args.enable_dp_attention: # TODO: simplify this - batch = self.prepare_dp_attn_batch(batch) - self.cur_batch = batch if batch: @@ -442,7 +494,7 @@ def event_loop_overlap(self): result_queue.append((batch.copy(), result)) if self.last_batch is None: - # Create a dummy first batch to start the pipeline for overlap scheduler. + # Create a dummy first batch to start the pipeline for overlap schedule. # It is now used for triggering the sampling_info_done event. 
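Both the scheduler's _request_dispatcher above and the simplified process_input_requests below go through TypeBasedDispatcher from sglang.utils, whose implementation is not part of this diff. The following is only a guess at a minimal equivalent, shown to make the registration table above easier to read: it walks the (type, handler) pairs, dispatches on isinstance, and returns the handler's result so the caller can forward non-None outputs back to the tokenizer.

from typing import Any, Callable, List, Tuple, Type


class TypeBasedDispatcherSketch:
    """Illustrative only; the real TypeBasedDispatcher lives in sglang.utils."""

    def __init__(self, mapping: List[Tuple[Type, Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any):
        for ty, fn in self._mapping:
            if isinstance(obj, ty):
                return fn(obj)
        raise ValueError(f"Invalid object: {obj!r}")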
tmp_batch = ScheduleBatch( reqs=None, @@ -467,7 +519,7 @@ def event_loop_overlap(self): def recv_requests(self) -> List[Req]: """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" - if self.tp_rank == 0 or self.server_args.enable_dp_attention: + if self.attn_tp_rank == 0: recv_reqs = [] while True: @@ -479,57 +531,48 @@ def recv_requests(self) -> List[Req]: else: recv_reqs = None - if self.tp_size != 1 and not self.server_args.enable_dp_attention: + if self.server_args.enable_dp_attention: + if self.attn_tp_rank == 0: + work_reqs = [ + req + for req in recv_reqs + if isinstance( + req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) + ) + ] + control_reqs = [ + req + for req in recv_reqs + if not isinstance( + req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) + ) + ] + else: + work_reqs = None + control_reqs = None + + if self.attn_tp_size != 1: + attn_tp_rank_0 = self.dp_rank * self.attn_tp_size + work_reqs = broadcast_pyobj( + work_reqs, + self.attn_tp_rank, + self.attn_tp_cpu_group, + src=attn_tp_rank_0, + ) + if self.tp_size != 1: + control_reqs = broadcast_pyobj( + control_reqs, self.tp_rank, self.tp_cpu_group + ) + recv_reqs = work_reqs + control_reqs + elif self.tp_size != 1: recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) return recv_reqs def process_input_requests(self, recv_reqs: List): for recv_req in recv_reqs: - if isinstance(recv_req, TokenizedGenerateReqInput): - self.handle_generate_request(recv_req) - elif isinstance(recv_req, TokenizedEmbeddingReqInput): - self.handle_embedding_request(recv_req) - elif isinstance(recv_req, FlushCacheReq): - self.flush_cache() - elif isinstance(recv_req, AbortReq): - self.abort_request(recv_req) - elif isinstance(recv_req, UpdateWeightFromDiskReqInput): - success, message = self.update_weights_from_disk(recv_req) - self.send_to_tokenizer.send_pyobj( - UpdateWeightFromDiskReqOutput(success, message) - ) - elif isinstance(recv_req, InitWeightsUpdateGroupReqInput): - success, message = self.init_weights_update_group(recv_req) - self.send_to_tokenizer.send_pyobj( - InitWeightsUpdateGroupReqOutput(success, message) - ) - elif isinstance(recv_req, UpdateWeightsFromDistributedReqInput): - success, message = self.update_weights_from_distributed(recv_req) - self.send_to_tokenizer.send_pyobj( - UpdateWeightsFromDistributedReqOutput(success, message) - ) - elif isinstance(recv_req, UpdateWeightsFromTensorReqInput): - success, message = self.update_weights_from_tensor(recv_req) - self.send_to_tokenizer.send_pyobj( - UpdateWeightsFromTensorReqOutput(success, message) - ) - elif isinstance(recv_req, GetWeightsByNameReqInput): - parameter = self.get_weights_by_name(recv_req) - self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter)) - elif isinstance(recv_req, ProfileReq): - if recv_req == ProfileReq.START_PROFILE: - self.start_profile() - else: - self.stop_profile() - elif isinstance(recv_req, OpenSessionReqInput): - session_id, success = self.open_session(recv_req) - self.send_to_tokenizer.send_pyobj( - OpenSessionReqOutput(session_id=session_id, success=success) - ) - elif isinstance(recv_req, CloseSessionReqInput): - self.close_session(recv_req) - else: - raise ValueError(f"Invalid request: {recv_req}") + output = self._request_dispatcher(recv_req) + if output is not None: + self.send_to_tokenizer.send_pyobj(output) def handle_generate_request( self, @@ -548,6 +591,19 @@ def handle_generate_request( fake_input_ids = [1] * seq_length recv_req.input_ids = fake_input_ids + 
# Handle custom logit processor passed to the request + custom_logit_processor = recv_req.custom_logit_processor + if ( + not self.server_args.enable_custom_logit_processor + and custom_logit_processor is not None + ): + logger.warning( + "The SGLang server is not configured to enable custom logit processor." + "The custom logit processor passed in will be ignored." + "Please set --enable-custom-logits-processor to enable this feature." + ) + custom_logit_processor = None + req = Req( recv_req.rid, recv_req.input_text, @@ -558,6 +614,7 @@ def handle_generate_request( stream=recv_req.stream, lora_path=recv_req.lora_path, input_embeds=recv_req.input_embeds, + custom_logit_processor=custom_logit_processor, eos_token_ids=self.model_config.hf_eos_token_id, ) req.tokenizer = self.tokenizer @@ -589,15 +646,16 @@ def handle_generate_request( req.extend_image_inputs(image_inputs) if len(req.origin_input_ids) >= self.max_req_input_len: - logger.error( + error_msg = ( "Multimodal prompt is too long after expanding multimodal tokens. " - f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. " + f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}." ) + logger.error(error_msg) req.origin_input_ids = [0] req.image_inputs = None req.sampling_params.max_new_tokens = 0 req.finished_reason = FINISH_ABORT( - "Multimodal prompt is too long. Check server logs for details." + error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError" ) self.waiting_queue.append(req) return @@ -609,13 +667,16 @@ def handle_generate_request( # By default, only return the logprobs for output tokens req.logprob_start_len = len(req.origin_input_ids) - 1 - # Truncate prompts that are too long - if len(req.origin_input_ids) > self.max_req_input_len: - logger.warning( - "Request length is longer than the KV cache pool size or " - "the max context length. Truncated!!!" - ) - req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] + # Validate prompts length + error_msg = validate_input_length( + req, + self.max_req_input_len, + self.server_args.allow_auto_truncate, + ) + + if error_msg: + self.waiting_queue.append(req) + return req.sampling_params.max_new_tokens = min( ( @@ -663,13 +724,12 @@ def handle_embedding_request( ) req.tokenizer = self.tokenizer - # Truncate prompts that are too long - if len(req.origin_input_ids) >= self.max_req_input_len: - logger.warning( - "Request length is longer than the KV cache pool size or " - "the max context length. Truncated!!!" - ) - req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] + # Validate prompts length + validate_input_length( + req, + self.max_req_input_len, + self.server_args.allow_auto_truncate, + ) self.waiting_queue.append(req) @@ -715,21 +775,40 @@ def log_decode_stats(self): self.num_generated_tokens = 0 self.last_decode_stats_tic = time.time() num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0 - logger.info( - f"Decode batch. " - f"#running-req: {num_running_reqs}, " - f"#token: {num_used}, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"gen throughput (token/s): {gen_throughput:.2f}, " - f"#queue-req: {len(self.waiting_queue)}" - ) + if self.spec_algorithm.is_none(): + msg = ( + f"Decode batch. 
" + f"#running-req: {num_running_reqs}, " + f"#token: {num_used}, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + f"gen throughput (token/s): {gen_throughput:.2f}, " + f"#queue-req: {len(self.waiting_queue)}" + ) + spec_accept_length = 0 + else: + spec_accept_length = ( + self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct + ) + self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0 + msg = ( + f"Decode batch. " + f"#running-req: {num_running_reqs}, " + f"#token: {num_used}, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + f"accept len: {spec_accept_length:.2f}, " + f"gen throughput (token/s): {gen_throughput:.2f}, " + f"#queue-req: {len(self.waiting_queue)}" + ) + + logger.info(msg) if self.enable_metrics: self.stats.num_running_reqs = num_running_reqs self.stats.num_used_tokens = num_used self.stats.token_usage = num_used / self.max_total_num_tokens self.stats.gen_throughput = gen_throughput self.stats.num_queue_reqs = len(self.waiting_queue) + self.stats.spec_accept_length = spec_accept_length self.metrics_collector.log_stats(self.stats) def check_memory(self): @@ -772,16 +851,23 @@ def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: else: self.running_batch.merge_batch(self.last_batch) - # Run prefill first if possible new_batch = self.get_new_batch_prefill() if new_batch is not None: - return new_batch + # Run prefill first if possible + ret = new_batch + else: + # Run decode + if self.running_batch is None: + ret = None + else: + self.running_batch = self.update_running_batch(self.running_batch) + ret = self.running_batch - # Run decode - if self.running_batch is None: - return None - self.running_batch = self.update_running_batch(self.running_batch) - return self.running_batch + # Handle DP attention + if self.server_args.enable_dp_attention: + ret = self.prepare_dp_attn_batch(ret) + + return ret def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # Check if the grammar is ready in the grammar queue @@ -805,9 +891,9 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # Prefill policy adder = PrefillAdder( self.tree_cache, + self.token_to_kv_pool, self.running_batch, self.new_token_ratio, - self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(), self.max_prefill_tokens, self.chunked_prefill_size, running_bs if self.is_mixed_chunk else 0, @@ -868,7 +954,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.being_chunked_req.is_being_chunked += 1 # Print stats - if self.tp_rank == 0: + if self.attn_tp_rank == 0: self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked) # Create a new batch @@ -880,6 +966,7 @@ def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: self.model_config, self.enable_overlap, self.spec_algorithm, + self.server_args.enable_custom_logit_processor, ) new_batch.prepare_for_extend() @@ -936,7 +1023,7 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: ) # Check for jump-forward - if not self.disable_jump_forward: + if not self.disable_jump_forward and batch.has_grammar: jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) self.waiting_queue.extend(jump_forward_reqs) if batch.is_empty(): @@ -950,12 +1037,14 @@ def update_running_batch(self, batch: ScheduleBatch) -> Optional[ScheduleBatch]: batch.prepare_for_decode() return batch - def run_batch(self, batch: ScheduleBatch): + def run_batch( + self, batch: ScheduleBatch + ) -> Union[GenerationBatchResult, 
EmbeddingBatchResult]: """Run a batch.""" self.forward_ct += 1 if self.is_generation: - if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: + if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0: if self.spec_algorithm.is_none(): model_worker_batch = batch.get_model_worker_batch() logits_output, next_token_ids = ( @@ -968,45 +1057,65 @@ def run_batch(self, batch: ScheduleBatch): model_worker_batch, num_accepted_tokens, ) = self.draft_worker.forward_batch_speculative_generation(batch) + self.spec_num_total_accepted_tokens += ( + num_accepted_tokens + batch.batch_size() + ) + self.spec_num_total_forward_ct += batch.batch_size() self.num_generated_tokens += num_accepted_tokens - elif batch.forward_mode.is_idle(): - model_worker_batch = batch.get_model_worker_batch() - self.tp_worker.forward_batch_idle(model_worker_batch) - return else: - logits_output = None - if self.skip_tokenizer_init: - next_token_ids = torch.full( - (batch.batch_size(),), self.tokenizer.eos_token_id - ) - else: - next_token_ids = torch.full((batch.batch_size(),), 0) + assert False, "batch.extend_num_tokens == 0, this is unexpected!" batch.output_ids = next_token_ids - ret = logits_output, next_token_ids, model_worker_batch.bid + + ret = GenerationBatchResult( + logits_output=logits_output, + next_token_ids=next_token_ids, + bid=model_worker_batch.bid, + ) else: # embedding or reward model assert batch.extend_num_tokens != 0 model_worker_batch = batch.get_model_worker_batch() embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) - ret = embeddings, model_worker_batch.bid + ret = EmbeddingBatchResult( + embeddings=embeddings, bid=model_worker_batch.bid + ) return ret - def process_batch_result(self, batch: ScheduleBatch, result): + def process_batch_result( + self, + batch: ScheduleBatch, + result: Union[GenerationBatchResult, EmbeddingBatchResult], + ): if batch.forward_mode.is_decode(): self.process_batch_result_decode(batch, result) if batch.is_empty(): self.running_batch = None elif batch.forward_mode.is_extend(): self.process_batch_result_prefill(batch, result) + elif batch.forward_mode.is_idle(): + if self.enable_overlap: + self.tp_worker.resolve_batch_result(result.bid) elif batch.forward_mode.is_dummy_first(): batch.next_batch_sampling_info.update_regex_vocab_mask() self.current_stream.synchronize() batch.next_batch_sampling_info.sampling_info_done.set() - def process_batch_result_prefill(self, batch: ScheduleBatch, result): + def process_batch_result_prefill( + self, + batch: ScheduleBatch, + result: Union[GenerationBatchResult, EmbeddingBatchResult], + ): skip_stream_req = None if self.is_generation: - logits_output, next_token_ids, bid = result + ( + logits_output, + next_token_ids, + bid, + ) = ( + result.logits_output, + result.next_token_ids, + result.bid, + ) if self.enable_overlap: logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid) @@ -1020,9 +1129,6 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): logits_output.input_token_logprobs = ( logits_output.input_token_logprobs.tolist() ) - logits_output.normalized_prompt_logprobs = ( - logits_output.normalized_prompt_logprobs.tolist() - ) # Check finish conditions logprob_pt = 0 @@ -1067,7 +1173,7 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): batch.next_batch_sampling_info.sampling_info_done.set() else: # embedding or reward model - embeddings, bid = result + embeddings, bid = result.embeddings, result.bid embeddings = embeddings.tolist() 
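The spec_num_total_accepted_tokens and spec_num_total_forward_ct counters incremented above feed the "accept len" figure in the decode log: each speculative pass adds num_accepted_tokens plus the batch size (the batch-size term presumably accounting for the one token every request emits regardless of draft acceptance), and the logged value is the ratio of the two counters. A small worked sketch of that bookkeeping, with made-up numbers:

# Sketch of the accept-length bookkeeping behind the decode log line.
spec_num_total_accepted_tokens = 0
spec_num_total_forward_ct = 0

# Suppose a batch of 4 requests runs 10 speculative forward passes and the
# accepted draft tokens per pass (excluding the guaranteed token) total 60.
batch_size = 4
for _ in range(10):
    num_accepted_tokens = 6  # accepted draft tokens in this pass
    spec_num_total_accepted_tokens += num_accepted_tokens + batch_size
    spec_num_total_forward_ct += batch_size

spec_accept_length = spec_num_total_accepted_tokens / spec_num_total_forward_ct
print(f"accept len: {spec_accept_length:.2f}")  # (60 + 40) / 40 = 2.50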
# Check finish conditions @@ -1091,8 +1197,16 @@ def process_batch_result_prefill(self, batch: ScheduleBatch, result): self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req) - def process_batch_result_decode(self, batch: ScheduleBatch, result): - logits_output, next_token_ids, bid = result + def process_batch_result_decode( + self, + batch: ScheduleBatch, + result: GenerationBatchResult, + ): + logits_output, next_token_ids, bid = ( + result.logits_output, + result.next_token_ids, + result.bid, + ) self.num_generated_tokens += len(batch.reqs) if self.enable_overlap: @@ -1150,7 +1264,7 @@ def process_batch_result_decode(self, batch: ScheduleBatch, result): self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30) if ( - self.tp_rank == 0 + self.attn_tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0 ): self.log_decode_stats() @@ -1170,9 +1284,6 @@ def add_logprob_return_values( # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len - if req.normalized_prompt_logprob is None: - req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] - if req.input_token_logprobs_val is None: input_token_logprobs_val = output.input_token_logprobs[ pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens @@ -1253,7 +1364,6 @@ def stream_output( decode_ids_list = [] read_offsets = [] output_ids = [] - origin_input_ids = [] skip_special_tokens = [] spaces_between_special_tokens = [] @@ -1271,15 +1381,12 @@ def stream_output( input_top_logprobs_idx = [] output_top_logprobs_val = [] output_top_logprobs_idx = [] - normalized_prompt_logprob = [] else: input_token_logprobs_val = input_token_logprobs_idx = ( output_token_logprobs_val ) = output_token_logprobs_idx = input_top_logprobs_val = ( input_top_logprobs_idx - ) = output_top_logprobs_val = output_top_logprobs_idx = ( - normalized_prompt_logprob - ) = None + ) = output_top_logprobs_val = output_top_logprobs_idx = None for req in reqs: if req is skip_req: @@ -1305,14 +1412,8 @@ def stream_output( decode_ids, read_offset = req.init_incremental_detokenize() decode_ids_list.append(decode_ids) read_offsets.append(read_offset) - if self.skip_tokenizer_init or self.server_args.return_token_ids: + if self.skip_tokenizer_init: output_ids.append(req.output_ids) - else: - output_ids = None - if self.server_args.return_token_ids: - origin_input_ids.append(req.origin_input_ids) - else: - origin_input_ids = None skip_special_tokens.append(req.sampling_params.skip_special_tokens) spaces_between_special_tokens.append( req.sampling_params.spaces_between_special_tokens @@ -1332,7 +1433,6 @@ def stream_output( input_top_logprobs_idx.append(req.input_top_logprobs_idx) output_top_logprobs_val.append(req.output_top_logprobs_val) output_top_logprobs_idx.append(req.output_top_logprobs_idx) - normalized_prompt_logprob.append(req.normalized_prompt_logprob) # Send to detokenizer if rids: @@ -1344,7 +1444,6 @@ def stream_output( decoded_texts, decode_ids_list, read_offsets, - origin_input_ids, output_ids, skip_special_tokens, spaces_between_special_tokens, @@ -1360,7 +1459,6 @@ def stream_output( input_top_logprobs_idx, output_top_logprobs_val, output_top_logprobs_idx, - normalized_prompt_logprob, ) ) else: # embedding or reward model @@ -1402,12 +1500,7 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch): # Check forward mode for cuda graph if not self.server_args.disable_cuda_graph: 
forward_mode_state = torch.tensor( - ( - 1 - if local_batch.forward_mode.is_decode() - or local_batch.forward_mode.is_idle() - else 0 - ), + (1 if local_batch.forward_mode.is_decode_or_idle() else 0), dtype=torch.int32, ) torch.distributed.all_reduce( @@ -1428,6 +1521,7 @@ def get_idle_batch(self): self.model_config, self.enable_overlap, self.spec_algorithm, + self.server_args.enable_custom_logit_processor, ) idle_batch.prepare_for_idle() return idle_batch @@ -1456,6 +1550,9 @@ def move_ready_grammar_requests(self): self.waiting_queue.extend(self.grammar_queue[:num_ready_reqs]) self.grammar_queue = self.grammar_queue[num_ready_reqs:] + def flush_cache_wrapped(self, recv_req: FlushCacheReq): + self.flush_cache() + def flush_cache(self): """Flush the memory pool and cache.""" if len(self.waiting_queue) == 0 and ( @@ -1467,6 +1564,15 @@ def flush_cache(self): self.grammar_backend.reset() self.req_to_token_pool.clear() self.token_to_kv_pool.clear() + + if not self.spec_algorithm.is_none(): + self.draft_worker.model_runner.req_to_token_pool.clear() + self.draft_worker.model_runner.token_to_kv_pool.clear() + + self.num_generated_tokens = 0 + self.forward_ct_decode = 0 + self.spec_num_total_accepted_tokens = 0 + self.spec_num_total_forward_ct = 0 torch.cuda.empty_cache() logger.info("Cache flushed successfully!") if_success = True @@ -1508,12 +1614,12 @@ def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput): assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return success, message + return UpdateWeightFromDiskReqOutput(success, message) def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput): """Initialize the online model parameter update group.""" success, message = self.tp_worker.init_weights_update_group(recv_req) - return success, message + return InitWeightsUpdateGroupReqOutput(success, message) def update_weights_from_distributed( self, @@ -1526,7 +1632,7 @@ def update_weights_from_distributed( assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return success, message + return UpdateWeightsFromDistributedReqOutput(success, message) def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): """Update the online model parameter from tensors.""" @@ -1537,11 +1643,33 @@ def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput): assert flash_cache_success, "Cache flush failed after updating weights" else: logger.error(message) - return success, message + return UpdateWeightsFromTensorReqOutput(success, message) def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.tp_worker.get_weights_by_name(recv_req) - return parameter + return GetWeightsByNameReqOutput(parameter) + + def release_memory_occupation(self): + self.stashed_model_static_state = _export_static_state( + self.tp_worker.worker.model_runner.model + ) + self.memory_saver_adapter.pause() + self.flush_cache() + return ReleaseMemoryOccupationReqOutput() + + def resume_memory_occupation(self): + self.memory_saver_adapter.resume() + _import_static_state( + self.tp_worker.worker.model_runner.model, self.stashed_model_static_state + ) + del self.stashed_model_static_state + return ResumeMemoryOccupationReqOutput() + + def profile(self, recv_req: ProfileReq): + if recv_req == ProfileReq.START_PROFILE: + self.start_profile() + else: + self.stop_profile() def start_profile(self) -> None: if self.profiler is None: @@ -1557,20 
+1685,20 @@ def stop_profile(self) -> None: ) logger.info("Profiler is done") - def open_session(self, recv_req: OpenSessionReqInput) -> Tuple[Optional[str], bool]: + def open_session(self, recv_req: OpenSessionReqInput): # handle error session_id = recv_req.session_id if session_id in self.sessions: logger.warning(f"session id {session_id} already exist, cannot open.") - return session_id, False + return OpenSessionReqOutput(session_id, False) elif session_id is None: logger.warning(f"session id is None, cannot open.") - return session_id, False + return OpenSessionReqOutput(session_id, False) else: self.sessions[session_id] = Session( recv_req.capacity_of_str_len, session_id ) - return session_id, True + return OpenSessionReqOutput(session_id, True) def close_session(self, recv_req: CloseSessionReqInput): # handle error @@ -1581,6 +1709,20 @@ def close_session(self, recv_req: CloseSessionReqInput): del self.sessions[session_id] +def _export_static_state(model): + return dict( + buffers=[ + (name, buffer.detach().clone()) for name, buffer in model.named_buffers() + ] + ) + + +def _import_static_state(model, static_params): + self_named_buffers = dict(model.named_buffers()) + for name, tensor in static_params["buffers"]: + self_named_buffers[name][...] = tensor + + def run_scheduler_process( server_args: ServerArgs, port_args: PortArgs, @@ -1590,6 +1732,7 @@ def run_scheduler_process( pipe_writer, ): setproctitle.setproctitle("sglang::scheduler") + faulthandler.enable() # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var if dp_rank is None and "SGLANG_DP_RANK" in os.environ: @@ -1612,7 +1755,11 @@ def run_scheduler_process( try: scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) pipe_writer.send( - {"status": "ready", "max_total_num_tokens": scheduler.max_total_num_tokens} + { + "status": "ready", + "max_total_num_tokens": scheduler.max_total_num_tokens, + "max_req_input_len": scheduler.max_req_input_len, + } ) if scheduler.enable_overlap: scheduler.event_loop_overlap() diff --git a/python/sglang/srt/managers/session_controller.py b/python/sglang/srt/managers/session_controller.py index e3e94ce6b655..4f4af6367573 100644 --- a/python/sglang/srt/managers/session_controller.py +++ b/python/sglang/srt/managers/session_controller.py @@ -99,7 +99,7 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer): if last_req is not None: # trim bos token if it is an append - if req.input_ids[0] == tokenizer.bos_token_id: + if tokenizer is not None and req.input_ids[0] == tokenizer.bos_token_id: req.input_ids = req.input_ids[1:] input_ids = ( @@ -131,6 +131,7 @@ def create_req(self, req: TokenizedGenerateReqInput, tokenizer): sampling_params=req.sampling_params, lora_path=req.lora_path, session_id=self.session_id, + custom_logit_processor=req.custom_logit_processor, ) if last_req is not None: new_req.image_inputs = last_req.image_inputs diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 08dbd02c5ba3..2be2e532d078 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -18,10 +18,14 @@ import dataclasses import logging import os +import pickle import signal import sys +import threading import time import uuid +from datetime import datetime +from http import HTTPStatus from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union import fastapi @@ -43,6 +47,7 @@ BatchStrOut, 
BatchTokenIDOut, CloseSessionReqInput, + ConfigureLoggingReq, EmbeddingReqInput, FlushCacheReq, GenerateReqInput, @@ -53,6 +58,10 @@ OpenSessionReqInput, OpenSessionReqOutput, ProfileReq, + ReleaseMemoryOccupationReqInput, + ReleaseMemoryOccupationReqOutput, + ResumeMemoryOccupationReqInput, + ResumeMemoryOccupationReqOutput, SessionParams, TokenizedEmbeddingReqInput, TokenizedGenerateReqInput, @@ -71,6 +80,7 @@ get_zmq_socket, kill_process_tree, ) +from sglang.utils import TypeBasedDispatcher, get_exception_traceback asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) @@ -103,16 +113,19 @@ def __init__( port_args: PortArgs, ): # Parse args + self.server_args = server_args self.enable_metrics = server_args.enable_metrics + self.log_requests = server_args.log_requests + self.log_requests_level = 0 # Init inter-process communication context = zmq.asyncio.Context(2) self.recv_from_detokenizer = get_zmq_socket( - context, zmq.PULL, port_args.tokenizer_ipc_name + context, zmq.PULL, port_args.tokenizer_ipc_name, True ) self.send_to_scheduler = get_zmq_socket( - context, zmq.PUSH, port_args.scheduler_input_ipc_name + context, zmq.PUSH, port_args.scheduler_input_ipc_name, True ) # Read model args @@ -145,6 +158,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.tokenizer = self.processor.tokenizer os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -158,11 +172,15 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) # Store states - self.to_create_loop = True + self.no_create_loop = False self.rid_to_state: Dict[str, ReqState] = {} + self.dump_requests_folder = "" # By default do not dump + self.dump_requests_threshold = 1000 + self.dump_request_list: List[Tuple] = [] # The event to notify the weight sync is finished. 
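The _Communicator instances created just below (including the new release/resume memory occupation ones) pair the scheduler send socket with a fan-in count of dp_size; their implementation is not shown in this diff. Inferred purely from the call sites (await communicator(obj) and the handle_recv callbacks registered in the result dispatcher), a plausible sketch looks like the following; the class name and every detail of it are assumptions.

import asyncio
from typing import Any, Generic, List, TypeVar

T = TypeVar("T")


class CommunicatorSketch(Generic[T]):
    """Assumed behavior only: send one request, gather dp_size replies."""

    def __init__(self, sender, fan_in: int):
        self._sender = sender
        self._fan_in = fan_in
        self._result_event = None
        self._result_values = None

    async def __call__(self, obj: Any) -> List[T]:
        self._result_event = asyncio.Event()
        self._result_values = []
        self._sender.send_pyobj(obj)
        await self._result_event.wait()
        return self._result_values

    def handle_recv(self, recv_obj: T):
        # Called by the result dispatcher for each reply from the schedulers.
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_in:
            self._result_event.set()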
self.model_update_lock = RWLock() @@ -188,6 +206,14 @@ def __init__( self.get_weights_by_name_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.release_memory_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + self.resume_memory_occupation_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) + # Set after scheduler is initialized + self.max_req_input_len = None # Metrics if self.enable_metrics: @@ -198,6 +224,44 @@ def __init__( }, ) + self._result_dispatcher = TypeBasedDispatcher( + [ + ( + (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut), + self._handle_batch_output, + ), + (OpenSessionReqOutput, self._handle_open_session_req_output), + ( + UpdateWeightFromDiskReqOutput, + self._handle_update_weights_from_disk_req_output, + ), + ( + InitWeightsUpdateGroupReqOutput, + self.init_weights_update_group_communicator.handle_recv, + ), + ( + UpdateWeightsFromDistributedReqOutput, + self.update_weights_from_distributed_communicator.handle_recv, + ), + ( + UpdateWeightsFromTensorReqOutput, + self.update_weights_from_tensor_communicator.handle_recv, + ), + ( + GetWeightsByNameReqOutput, + self.get_weights_by_name_communicator.handle_recv, + ), + ( + ReleaseMemoryOccupationReqOutput, + self.release_memory_occupation_communicator.handle_recv, + ), + ( + ResumeMemoryOccupationReqOutput, + self.resume_memory_occupation_communicator.handle_recv, + ), + ] + ) + async def generate_request( self, obj: Union[GenerateReqInput, EmbeddingReqInput], @@ -215,8 +279,11 @@ async def generate_request( obj.normalize_batch_and_arguments() - if self.server_args.log_requests: - logger.info(f"Receive: obj={dataclass_to_string_truncated(obj)}") + if self.log_requests: + max_length = 2048 if self.log_requests_level == 0 else 1 << 30 + logger.info( + f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}" + ) async with self.model_update_lock.reader_lock: is_single = obj.is_single @@ -248,15 +315,21 @@ async def _tokenize_one_request( ) input_embeds = obj.input_embeds input_ids = obj.input_ids - elif obj.input_ids is None: - input_ids = self.tokenizer.encode(input_text) - else: + elif obj.input_ids is not None: input_ids = obj.input_ids + else: + if self.tokenizer is None: + raise ValueError( + "The engine initialized with skip_tokenizer_init=True cannot " + "accept text prompts. Please provide input_ids or re-initialize " + "the engine with skip_tokenizer_init=False." + ) + input_ids = self.tokenizer.encode(input_text) if self.is_generation: # TODO: also support getting embeddings for multimodal models image_inputs: Dict = await self.image_processor.process_images_async( - obj.image_data, input_text or input_ids, obj + obj.image_data, input_text or input_ids, obj, self.max_req_input_len ) if image_inputs and "input_ids" in image_inputs: input_ids = image_inputs["input_ids"] @@ -267,12 +340,28 @@ async def _tokenize_one_request( SessionParams(**obj.session_params) if obj.session_params else None ) - if obj.input_ids is not None and len(input_ids) >= self.context_len: + input_token_num = len(input_ids) if input_ids is not None else 0 + if input_token_num >= self.context_len: raise ValueError( - f"The input ({len(input_ids)} tokens) is longer than the " + f"The input ({input_token_num} tokens) is longer than the " f"model's context length ({self.context_len} tokens)." 
) + if ( + obj.sampling_params.get("max_new_tokens") is not None + and obj.sampling_params.get("max_new_tokens") + input_token_num + >= self.context_len + ): + raise ValueError( + f"Requested token count exceeds the model's maximum context length " + f"of {self.context_len} tokens. You requested a total of " + f"{obj.sampling_params.get('max_new_tokens') + input_token_num} " + f"tokens: {input_token_num} tokens from the input messages and " + f"{obj.sampling_params.get('max_new_tokens')} tokens for the " + f"completion. Please reduce the number of tokens in the input " + f"messages or the completion to fit within the limit." + ) + # Parse sampling parameters sampling_params = SamplingParams(**obj.sampling_params) sampling_params.normalize(self.tokenizer) @@ -293,6 +382,7 @@ async def _tokenize_one_request( lora_path=obj.lora_path, input_embeds=input_embeds, session_params=session_params, + custom_logit_processor=obj.custom_logit_processor, ) elif isinstance(obj, EmbeddingReqInput): tokenized_obj = TokenizedEmbeddingReqInput( @@ -336,10 +426,21 @@ async def _wait_one_response( state.out_list = [] if state.finished: - if self.server_args.log_requests: - msg = f"Finish: obj={dataclass_to_string_truncated(obj)}, out={dataclass_to_string_truncated(out)}" + if self.log_requests: + max_length = 2048 if self.log_requests_level == 0 else 1 << 30 + msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}" logger.info(msg) del self.rid_to_state[obj.rid] + + # Check if this was an abort/error created by scheduler + if isinstance(out["meta_info"].get("finish_reason"), dict): + finish_reason = out["meta_info"]["finish_reason"] + if ( + finish_reason.get("type") == "abort" + and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST + ): + raise ValueError(finish_reason["message"]) + yield out break @@ -548,6 +649,22 @@ async def get_weights_by_name( else: return all_parameters + async def release_memory_occupation( + self, + obj: ReleaseMemoryOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.release_memory_occupation_communicator(obj) + + async def resume_memory_occupation( + self, + obj: ResumeMemoryOccupationReqInput, + request: Optional[fastapi.Request] = None, + ): + self.auto_create_handle_loop() + await self.resume_memory_occupation_communicator(obj) + async def open_session( self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None ): @@ -568,9 +685,19 @@ async def open_session( async def close_session( self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None ): - assert not self.to_create_loop, "close session should not be the first request" await self.send_to_scheduler.send_pyobj(obj) + def configure_logging(self, obj: ConfigureLoggingReq): + if obj.log_requests is not None: + self.log_requests = obj.log_requests + if obj.log_requests_level is not None: + self.log_requests_level = obj.log_requests_level + if obj.dump_requests_folder is not None: + self.dump_requests_folder = obj.dump_requests_folder + if obj.dump_requests_threshold is not None: + self.dump_requests_threshold = obj.dump_requests_threshold + logging.info(f"Config logging: {obj=}") + def create_abort_task(self, obj: GenerateReqInput): # Abort the request if the client is disconnected. 
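ConfigureLoggingReq from io_struct above pairs with configure_logging here, so request logging verbosity and request dumping can be changed at runtime instead of only at launch. A usage sketch against an existing TokenizerManager instance; the folder, threshold, and level values are arbitrary examples, and in practice the object would normally arrive through an HTTP endpoint that is outside this diff.

from sglang.srt.managers.io_struct import ConfigureLoggingReq

# Turn on verbose request logging and dump every 100 finished requests to a
# folder as pickle files (see dump_requests further below).
req = ConfigureLoggingReq(
    log_requests=True,
    log_requests_level=1,
    dump_requests_folder="/tmp/sglang_requests",
    dump_requests_threshold=100,
)
tokenizer_manager.configure_logging(req)  # tokenizer_manager: an existing instance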
async def abort_request(): @@ -586,22 +713,35 @@ async def abort_request(): return background_tasks def auto_create_handle_loop(self): - if not self.to_create_loop: + if self.no_create_loop: return - self.to_create_loop = False + self.no_create_loop = True loop = asyncio.get_event_loop() - self.asyncio_tasks.add(loop.create_task(self.handle_loop())) + self.asyncio_tasks.add( + loop.create_task(print_exception_wrapper(self.handle_loop)) + ) - signal_handler = SignalHandler(self) - loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler) - self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog())) + # We cannot add signal handler when the tokenizer manager is not in + # the main thread due to the CPython limitation. + if threading.current_thread() is threading.main_thread(): + signal_handler = SignalHandler(self) + loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler) + else: + logger.warning( + "Signal handler is not added because the tokenizer manager is " + "not in the main thread. This disables graceful shutdown of the " + "tokenizer manager when SIGTERM is received." + ) + self.asyncio_tasks.add( + loop.create_task(print_exception_wrapper(self.sigterm_watchdog)) + ) async def sigterm_watchdog(self): while not self.gracefully_exit: await asyncio.sleep(5) - # drain requests + # Drain requests while True: remain_num_req = len(self.rid_to_state) logger.info( @@ -619,139 +759,64 @@ async def handle_loop(self): """The event loop that handles requests""" while True: - recv_obj: Union[ - BatchStrOut, - BatchEmbeddingOut, - BatchTokenIDOut, - UpdateWeightFromDiskReqOutput, - UpdateWeightsFromDistributedReqOutput, - GetWeightsByNameReqOutput, - InitWeightsUpdateGroupReqOutput, - ] = await self.recv_from_detokenizer.recv_pyobj() - - if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)): - for i, rid in enumerate(recv_obj.rids): - state = self.rid_to_state.get(rid, None) - if state is None: - continue - - meta_info = { - "id": rid, - "finish_reason": recv_obj.finished_reasons[i], - "prompt_tokens": recv_obj.prompt_tokens[i], - } + recv_obj = await self.recv_from_detokenizer.recv_pyobj() + self._result_dispatcher(recv_obj) - if getattr(state.obj, "return_logprob", False): - self.convert_logprob_style( - meta_info, - state.obj.top_logprobs_num, - state.obj.return_text_in_logprobs, - recv_obj, - i, - ) - - if not isinstance(recv_obj, BatchEmbeddingOut): - meta_info.update( - { - "completion_tokens": recv_obj.completion_tokens[i], - "cached_tokens": recv_obj.cached_tokens[i], - } - ) - - if isinstance(recv_obj, BatchStrOut): - out_dict = { - "text": recv_obj.output_strs[i], - "meta_info": meta_info, - } - if self.server_args.return_token_ids: - out_dict.update( - { - "input_ids": recv_obj.origin_input_ids[i], - "output_ids": recv_obj.output_ids[i], - } - ) - elif isinstance(recv_obj, BatchTokenIDOut): - out_dict = { - "token_ids": recv_obj.output_ids[i], - "meta_info": meta_info, - } - else: - assert isinstance(recv_obj, BatchEmbeddingOut) - out_dict = { - "embedding": recv_obj.embeddings[i], - "meta_info": meta_info, - } - state.out_list.append(out_dict) - state.finished = recv_obj.finished_reasons[i] is not None - state.event.set() - - if self.enable_metrics: - completion_tokens = ( - recv_obj.completion_tokens[i] - if recv_obj.completion_tokens - else 0 - ) - - if state.first_token_time is None: - state.first_token_time = time.time() - self.metrics_collector.observe_time_to_first_token( - state.first_token_time - state.created_time - ) - 
else: - if completion_tokens >= 2: - # Compute time_per_output_token for the streaming case - self.metrics_collector.observe_time_per_output_token( - (time.time() - state.first_token_time) - / (completion_tokens - 1) - ) - - if state.finished: - self.metrics_collector.inc_prompt_tokens( - recv_obj.prompt_tokens[i] - ) - self.metrics_collector.inc_generation_tokens( - completion_tokens - ) - self.metrics_collector.observe_e2e_request_latency( - time.time() - state.created_time - ) - # Compute time_per_output_token for the non-streaming case - if not state.obj.stream and completion_tokens >= 1: - self.metrics_collector.observe_time_per_output_token( - (time.time() - state.created_time) - / completion_tokens - ) - elif isinstance(recv_obj, OpenSessionReqOutput): - self.session_futures[recv_obj.session_id].set_result( - recv_obj.session_id if recv_obj.success else None + def _handle_batch_output( + self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut] + ): + for i, rid in enumerate(recv_obj.rids): + state = self.rid_to_state.get(rid, None) + if state is None: + continue + + meta_info = { + "id": rid, + "finish_reason": recv_obj.finished_reasons[i], + "prompt_tokens": recv_obj.prompt_tokens[i], + } + + if getattr(state.obj, "return_logprob", False): + self.convert_logprob_style( + meta_info, + state.obj.top_logprobs_num, + state.obj.return_text_in_logprobs, + recv_obj, + i, + ) + + if not isinstance(recv_obj, BatchEmbeddingOut): + meta_info.update( + { + "completion_tokens": recv_obj.completion_tokens[i], + "cached_tokens": recv_obj.cached_tokens[i], + } ) - elif isinstance(recv_obj, UpdateWeightFromDiskReqOutput): - if self.server_args.dp_size == 1: - self.model_update_result.set_result(recv_obj) - else: # self.server_args.dp_size > 1 - self.model_update_tmp.append(recv_obj) - # set future if the all results are recevied - if len(self.model_update_tmp) == self.server_args.dp_size: - self.model_update_result.set_result(self.model_update_tmp) - elif isinstance(recv_obj, InitWeightsUpdateGroupReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for init parameter update group" - self.init_weights_update_group_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for update weights from distributed" - self.update_weights_from_distributed_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput): - assert ( - self.server_args.dp_size == 1 - ), "dp_size must be 1 for update weights from distributed" - self.update_weights_from_tensor_communicator.handle_recv(recv_obj) - elif isinstance(recv_obj, GetWeightsByNameReqOutput): - self.get_weights_by_name_communicator.handle_recv(recv_obj) + + if isinstance(recv_obj, BatchStrOut): + out_dict = { + "text": recv_obj.output_strs[i], + "meta_info": meta_info, + } + elif isinstance(recv_obj, BatchTokenIDOut): + out_dict = { + "token_ids": recv_obj.output_ids[i], + "meta_info": meta_info, + } else: - raise ValueError(f"Invalid object: {recv_obj=}") + assert isinstance(recv_obj, BatchEmbeddingOut) + out_dict = { + "embedding": recv_obj.embeddings[i], + "meta_info": meta_info, + } + state.out_list.append(out_dict) + state.finished = recv_obj.finished_reasons[i] is not None + state.event.set() + + if self.enable_metrics and state.obj.log_metrics: + self.collect_metrics(state, recv_obj, i) + if self.dump_requests_folder and state.finished and state.obj.log_metrics: + 
self.dump_requests(state, out_dict) def convert_logprob_style( self, @@ -771,9 +836,6 @@ def convert_logprob_style( recv_obj.output_token_logprobs_idx[recv_obj_index], return_text_in_logprobs, ) - meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[ - recv_obj_index - ] if top_logprobs_num > 0: meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens( @@ -823,6 +885,93 @@ def detokenize_top_logprobs_tokens( ret.append(None) return ret + def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int): + completion_tokens = ( + recv_obj.completion_tokens[i] + if getattr(recv_obj, "completion_tokens", None) + else 0 + ) + + if state.first_token_time is None: + state.first_token_time = time.time() + self.metrics_collector.observe_time_to_first_token( + state.first_token_time - state.created_time + ) + else: + if completion_tokens >= 2: + # Compute time_per_output_token for the streaming case + self.metrics_collector.observe_time_per_output_token( + (time.time() - state.first_token_time) / (completion_tokens - 1) + ) + + if state.finished: + self.metrics_collector.observe_one_finished_request( + recv_obj.prompt_tokens[i], completion_tokens + ) + self.metrics_collector.observe_e2e_request_latency( + time.time() - state.created_time + ) + # Compute time_per_output_token for the non-streaming case + if ( + hasattr(state.obj, "stream") + and not state.obj.stream + and completion_tokens >= 1 + ): + self.metrics_collector.observe_time_per_output_token( + (time.time() - state.created_time) / completion_tokens + ) + + def dump_requests(self, state: ReqState, out_dict: dict): + self.dump_request_list.append( + (state.obj, out_dict, state.created_time, time.time()) + ) + + if len(self.dump_request_list) >= self.dump_requests_threshold: + filename = os.path.join( + self.dump_requests_folder, + datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl", + ) + logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}") + + to_dump = self.dump_request_list + self.dump_request_list = [] + + def background_task(): + os.makedirs(self.dump_requests_folder, exist_ok=True) + with open(filename, "wb") as f: + pickle.dump(to_dump, f) + + # Schedule the task to run in the background without awaiting it + asyncio.create_task(asyncio.to_thread(background_task)) + + def _handle_open_session_req_output(self, recv_obj): + self.session_futures[recv_obj.session_id].set_result( + recv_obj.session_id if recv_obj.success else None + ) + + def _handle_update_weights_from_disk_req_output(self, recv_obj): + if self.server_args.dp_size == 1: + self.model_update_result.set_result(recv_obj) + else: # self.server_args.dp_size > 1 + self.model_update_tmp.append(recv_obj) + # set future if the all results are recevied + if len(self.model_update_tmp) == self.server_args.dp_size: + self.model_update_result.set_result(self.model_update_tmp) + + +async def print_exception_wrapper(func): + """ + Sometimes an asyncio function does not print exception. + We do another wrapper to handle the exception. 
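A minimal sketch of the background request-dumping logic above, assuming an illustrative folder and threshold; the blocking pickle write is pushed to a worker thread via asyncio.to_thread so the event loop keeps serving.

```python
import asyncio
import os
import pickle
from datetime import datetime

# Illustrative stand-ins for dump_requests_folder / dump_requests_threshold.
DUMP_FOLDER = "/tmp/sglang_request_dumps"
DUMP_THRESHOLD = 2
dump_list = []

def record_request(obj, out_dict):
    dump_list.append((obj, out_dict, datetime.now()))
    if len(dump_list) >= DUMP_THRESHOLD:
        to_dump = list(dump_list)
        dump_list.clear()
        filename = os.path.join(
            DUMP_FOLDER, datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
        )

        def background_task():
            os.makedirs(DUMP_FOLDER, exist_ok=True)
            with open(filename, "wb") as f:
                pickle.dump(to_dump, f)

        # Offload the blocking pickle/write to a thread; do not await it here.
        asyncio.create_task(asyncio.to_thread(background_task))

async def main():
    record_request({"rid": "a"}, {"text": "hello"})
    record_request({"rid": "b"}, {"text": "world"})
    await asyncio.sleep(0.2)  # give the background write a chance to finish

asyncio.run(main())
```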
+ """ + try: + await func() + except Exception: + traceback = get_exception_traceback() + logger.error(f"TokenizerManager hit an exception: {traceback}") + kill_process_tree(os.getpid(), include_parent=True) + sys.exit(1) + class SignalHandler: def __init__(self, tokenizer_manager): diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 25a1c85f2c69..fd4dbae9900d 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -83,6 +83,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.tokenizer = self.processor.tokenizer else: @@ -90,6 +91,7 @@ def __init__( server_args.tokenizer_path, tokenizer_mode=server_args.tokenizer_mode, trust_remote_code=server_args.trust_remote_code, + revision=server_args.revision, ) self.device = self.model_runner.device @@ -101,6 +103,7 @@ def __init__( self.max_total_num_tokens // 2 if server_args.max_running_requests is None else server_args.max_running_requests + // (server_args.dp_size if server_args.enable_dp_attention else 1) ), self.model_runner.req_to_token_pool.size, ) @@ -142,16 +145,15 @@ def get_pad_input_ids_func(self): def get_tp_cpu_group(self): return self.model_runner.tp_group.cpu_group + def get_attention_tp_cpu_group(self): + return self.model_runner.attention_tp_group.cpu_group + def get_memory_pool(self): return ( self.model_runner.req_to_token_pool, self.model_runner.token_to_kv_pool, ) - def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch): - forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - self.model_runner.forward(forward_batch) - def forward_batch_generation( self, model_worker_batch: ModelWorkerBatch, diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 4c98c6be2e4c..961b0bbdc119 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -82,6 +82,8 @@ def __init__( self.forward_thread.start() self.parent_process = psutil.Process().parent() self.scheduler_stream = torch.get_device_module(self.device).current_stream() + if self.device == "cpu": + self.scheduler_stream.synchronize = lambda: None # No-op for CPU def get_worker_info(self): return self.worker.get_worker_info() @@ -92,6 +94,9 @@ def get_pad_input_ids_func(self): def get_tp_cpu_group(self): return self.worker.get_tp_cpu_group() + def get_attention_tp_cpu_group(self): + return self.worker.get_attention_tp_cpu_group() + def get_memory_pool(self): return ( self.worker.model_runner.req_to_token_pool, @@ -151,11 +156,6 @@ def forward_thread_func_(self): logits_output.input_token_logprobs = ( logits_output.input_token_logprobs.to("cpu", non_blocking=True) ) - logits_output.normalized_prompt_logprobs = ( - logits_output.normalized_prompt_logprobs.to( - "cpu", non_blocking=True - ) - ) next_token_ids = next_token_ids.to("cpu", non_blocking=True) copy_done.record() @@ -174,9 +174,6 @@ def resolve_batch_result(self, bid: int): logits_output.input_token_logprobs = ( logits_output.input_token_logprobs.tolist() ) - logits_output.normalized_prompt_logprobs = ( - logits_output.normalized_prompt_logprobs.tolist() - ) next_token_ids = next_token_ids.tolist() return logits_output, next_token_ids diff --git a/python/sglang/srt/managers/utils.py b/python/sglang/srt/managers/utils.py new file 
mode 100644 index 000000000000..10a1209631eb --- /dev/null +++ b/python/sglang/srt/managers/utils.py @@ -0,0 +1,44 @@ +import logging +from http import HTTPStatus +from typing import Optional + +from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req + +logger = logging.getLogger(__name__) + + +def validate_input_length( + req: Req, max_req_input_len: int, allow_auto_truncate: bool +) -> Optional[str]: + """Validate and potentially truncate input length. + + Args: + req: The request containing input_ids to validate + max_req_input_len: Maximum allowed input length + allow_auto_truncate: Whether to truncate long inputs + + Returns: + Error message if validation fails, None if successful + """ + if len(req.origin_input_ids) >= max_req_input_len: + if allow_auto_truncate: + logger.warning( + "Request length is longer than the KV cache pool size or " + "the max context length. Truncated. " + f"{len(req.origin_input_ids)=}, {max_req_input_len=}." + ) + req.origin_input_ids = req.origin_input_ids[:max_req_input_len] + return None + else: + error_msg = ( + f"Input length ({len(req.origin_input_ids)} tokens) exceeds " + f"the maximum allowed length ({max_req_input_len} tokens). " + f"Use a shorter input or enable --allow-auto-truncate." + ) + logger.error(error_msg) + req.finished_reason = FINISH_ABORT( + error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError" + ) + return error_msg + + return None diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index b67f085b204b..7b9b35611d8d 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -13,6 +13,8 @@ limitations under the License. """ +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter + """ Memory pool. @@ -25,8 +27,9 @@ import threading from enum import IntEnum from functools import wraps -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union +import numpy as np import psutil import torch @@ -35,29 +38,34 @@ logger = logging.getLogger(__name__) +GB = 1024 * 1024 * 1024 + class ReqToTokenPool: """A memory pool that maps a request to its token locations.""" - def __init__(self, size: int, max_context_len: int, device: str, use_records: bool): + def __init__( + self, + size: int, + max_context_len: int, + device: str, + enable_memory_saver: bool, + ): + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + self.size = size self.max_context_len = max_context_len self.device = device - self.req_to_token = torch.zeros( - (size, max_context_len), dtype=torch.int32, device=device - ) + with memory_saver_adapter.region(): + self.req_to_token = torch.zeros( + (size, max_context_len), dtype=torch.int32, device=device + ) self.free_slots = list(range(size)) - self.write_records = [] - self.use_records = use_records - - if self.use_records: - self.write = self.write_with_records - else: - self.write = self.write_without_records def write(self, indices, values): - # Keep the signature for type checking. It will be assigned during runtime. 
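A standalone sketch of the truncate-or-reject policy in validate_input_length above; FakeReq is an illustrative stand-in for sglang's Req, and the real code records a FINISH_ABORT with HTTPStatus.BAD_REQUEST rather than a plain string.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeReq:
    origin_input_ids: List[int]
    finished_reason: Optional[str] = None

def validate_input_length(req: FakeReq, max_req_input_len: int, allow_auto_truncate: bool) -> Optional[str]:
    if len(req.origin_input_ids) >= max_req_input_len:
        if allow_auto_truncate:
            # Too long, but truncation is allowed: clip the input and continue.
            req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
            return None
        error_msg = (
            f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
            f"the maximum allowed length ({max_req_input_len} tokens)."
        )
        req.finished_reason = error_msg  # the real code attaches a FINISH_ABORT here
        return error_msg
    return None

req = FakeReq(origin_input_ids=list(range(10)))
print(validate_input_length(req, max_req_input_len=8, allow_auto_truncate=True))  # None
print(len(req.origin_input_ids))  # 8 (truncated)
```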
- raise NotImplementedError() + self.req_to_token[indices] = values def available_size(self): return len(self.free_slots) @@ -79,23 +87,6 @@ def free(self, free_index: Union[int, List[int]]): def clear(self): self.free_slots = list(range(self.size)) - self.write_records = [] - - def write_without_records(self, indices, values): - self.req_to_token[indices] = values - - def write_with_records(self, indices, values): - self.req_to_token[indices] = values - self.write_records.append((indices, values)) - - def get_write_records(self): - ret = self.write_records - self.write_records = [] - return ret - - def apply_write_records(self, write_records: List[Tuple]): - for indices, values in write_records: - self.req_to_token[indices] = values class BaseTokenToKVPool: @@ -109,8 +100,8 @@ def __init__( ): self.size = size self.dtype = dtype - if dtype == torch.float8_e5m2: - # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2 + if dtype in (torch.float8_e5m2, torch.float8_e4m3fn): + # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2 self.store_dtype = torch.uint8 else: self.store_dtype = dtype @@ -186,37 +177,60 @@ def __init__( head_dim: int, layer_num: int, device: str, + enable_memory_saver: bool, ): super().__init__(size, dtype, device) + + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + self.head_num = head_num self.head_dim = head_dim self.layer_num = layer_num self._create_buffers() + k_size, v_size = self.get_kv_size_bytes() + logger.info( + f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB." + ) + def _create_buffers(self): - # [size, head_num, head_dim] for each layer - # The padded slot 0 is used for writing dummy outputs from padded tokens. - self.k_buffer = [ - torch.empty( - (self.size + 1, self.head_num, self.head_dim), - dtype=self.store_dtype, - device=self.device, - ) - for _ in range(self.layer_num) - ] - self.v_buffer = [ - torch.empty( - (self.size + 1, self.head_num, self.head_dim), - dtype=self.store_dtype, - device=self.device, - ) - for _ in range(self.layer_num) - ] + with self.memory_saver_adapter.region(): + # [size, head_num, head_dim] for each layer + # The padded slot 0 is used for writing dummy outputs from padded tokens. 
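A quick sketch of the KV-cache size accounting that produces the new log line above; buffer shapes are illustrative, and numel() * element_size() is used in place of the patch's np.prod(shape) * dtype.itemsize (the two are equivalent).

```python
import torch

GB = 1024 * 1024 * 1024

def get_kv_size_bytes(k_buffer, v_buffer):
    # Sum bytes over the per-layer K and V buffers.
    k_bytes = sum(k.numel() * k.element_size() for k in k_buffer)
    v_bytes = sum(v.numel() * v.element_size() for v in v_buffer)
    return k_bytes, v_bytes

# Illustrative pool: 8 layers, 4 KV heads, head_dim 128, 16384 + 1 slots, fp16.
layer_num, head_num, head_dim, size = 8, 4, 128, 16384
k_buffer = [torch.empty((size + 1, head_num, head_dim), dtype=torch.float16) for _ in range(layer_num)]
v_buffer = [torch.empty((size + 1, head_num, head_dim), dtype=torch.float16) for _ in range(layer_num)]

k_size, v_size = get_kv_size_bytes(k_buffer, v_buffer)
print(f"KV Cache is allocated. K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB.")
```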
+ self.k_buffer = [ + torch.empty( + (self.size + 1, self.head_num, self.head_dim), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] + self.v_buffer = [ + torch.empty( + (self.size + 1, self.head_num, self.head_dim), + dtype=self.store_dtype, + device=self.device, + ) + for _ in range(self.layer_num) + ] def _clear_buffers(self): del self.k_buffer del self.v_buffer + def get_kv_size_bytes(self): + assert hasattr(self, "k_buffer") + assert hasattr(self, "v_buffer") + k_size_bytes = 0 + for k_cache in self.k_buffer: + k_size_bytes += np.prod(k_cache.shape) * k_cache.dtype.itemsize + v_size_bytes = 0 + for v_cache in self.v_buffer: + v_size_bytes += np.prod(v_cache.shape) * v_cache.dtype.itemsize + return k_size_bytes, v_size_bytes + # Todo: different memory layout def get_flat_data(self, indices): # prepare a large chunk of contiguous data for efficient transfer @@ -256,9 +270,15 @@ def set_kv_buffer( loc: torch.Tensor, cache_k: torch.Tensor, cache_v: torch.Tensor, + k_scale: Optional[float] = None, + v_scale: Optional[float] = None, ): layer_id = layer.layer_id if cache_k.dtype != self.dtype: + if k_scale is not None: + cache_k.div_(k_scale) + if v_scale is not None: + cache_v.div_(v_scale) cache_k = cache_k.to(self.dtype) cache_v = cache_v.to(self.dtype) if self.store_dtype != self.dtype: @@ -286,19 +306,26 @@ def __init__( qk_rope_head_dim: int, layer_num: int, device: str, + enable_memory_saver: bool, ): super().__init__(size, dtype, device) self.kv_lora_rank = kv_lora_rank - # The padded slot 0 is used for writing dummy outputs from padded tokens. - self.kv_buffer = [ - torch.empty( - (size + 1, 1, kv_lora_rank + qk_rope_head_dim), - dtype=self.store_dtype, - device=device, - ) - for _ in range(layer_num) - ] + + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + + with memory_saver_adapter.region(): + # The padded slot 0 is used for writing dummy outputs from padded tokens. 
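A sketch of the scaled FP8 write path added to set_kv_buffer above, assuming a PyTorch build that exposes float8 dtypes; the scale value is illustrative.

```python
import torch

def quantize_for_kv_cache(cache_k: torch.Tensor, k_scale: float,
                          kv_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # Incoming K/V are divided by a per-tensor scale, cast to the fp8 cache dtype,
    # and reinterpreted as uint8 for storage, since index_put is not implemented
    # for float8 tensors.
    cache_k = cache_k / k_scale       # apply the per-tensor scale
    cache_k = cache_k.to(kv_dtype)    # narrow to fp8
    return cache_k.view(torch.uint8)  # store_dtype: reinterpret the bits

k = torch.randn(4, 2, 8, dtype=torch.float16)
stored = quantize_for_kv_cache(k, k_scale=0.5)
print(stored.dtype, stored.shape)  # torch.uint8 torch.Size([4, 2, 8])
```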
+ self.kv_buffer = [ + torch.empty( + (size + 1, 1, kv_lora_rank + qk_rope_head_dim), + dtype=self.store_dtype, + device=device, + ) + for _ in range(layer_num) + ] def get_key_buffer(self, layer_id: int): if self.store_dtype != self.dtype: @@ -339,26 +366,32 @@ def __init__( layer_num: int, device: str, heavy_channel_num: int, + enable_memory_saver: bool, ): super().__init__(size, dtype, device) - # [size, head_num, head_dim] for each layer - self.k_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) - for _ in range(layer_num) - ] - self.v_buffer = [ - torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) - for _ in range(layer_num) - ] - - # [size, head_num, heavy_channel_num] for each layer - self.label_buffer = [ - torch.empty( - (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device - ) - for _ in range(layer_num) - ] + memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=enable_memory_saver + ) + + with memory_saver_adapter.region(): + # [size, head_num, head_dim] for each layer + self.k_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) + for _ in range(layer_num) + ] + self.v_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) + for _ in range(layer_num) + ] + + # [size, head_num, heavy_channel_num] for each layer + self.label_buffer = [ + torch.empty( + (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device + ) + for _ in range(layer_num) + ] def get_key_buffer(self, layer_id: int): return self.k_buffer[layer_id] diff --git a/python/sglang/srt/metrics/collector.py b/python/sglang/srt/metrics/collector.py index 9505f012f067..26eb2fc27d22 100644 --- a/python/sglang/srt/metrics/collector.py +++ b/python/sglang/srt/metrics/collector.py @@ -25,6 +25,7 @@ class SchedulerStats: gen_throughput: float = 0.0 num_queue_reqs: int = 0 cache_hit_rate: float = 0.0 + spec_accept_length: float = 0.0 class SchedulerMetricsCollector: @@ -37,42 +38,49 @@ def __init__(self, labels: Dict[str, str]) -> None: self.num_running_reqs = Gauge( name="sglang:num_running_reqs", - documentation="The number of running requests", + documentation="The number of running requests.", labelnames=labels.keys(), multiprocess_mode="sum", ) self.num_used_tokens = Gauge( name="sglang:num_used_tokens", - documentation="The number of used tokens", + documentation="The number of used tokens.", labelnames=labels.keys(), multiprocess_mode="sum", ) self.token_usage = Gauge( name="sglang:token_usage", - documentation="The token usage", + documentation="The token usage.", labelnames=labels.keys(), multiprocess_mode="mostrecent", ) self.gen_throughput = Gauge( name="sglang:gen_throughput", - documentation="The generate throughput (token/s)", + documentation="The generation throughput (token/s).", labelnames=labels.keys(), multiprocess_mode="sum", ) self.num_queue_reqs = Gauge( name="sglang:num_queue_reqs", - documentation="The number of requests in the waiting queue", + documentation="The number of requests in the waiting queue.", labelnames=labels.keys(), multiprocess_mode="sum", ) self.cache_hit_rate = Gauge( name="sglang:cache_hit_rate", - documentation="The cache hit rate", + documentation="The prefix cache hit rate.", + labelnames=labels.keys(), + multiprocess_mode="mostrecent", + ) + + self.spec_accept_length = Gauge( + name="sglang:spec_accept_length", + documentation="The average acceptance length of speculative decoding.", labelnames=labels.keys(), 
multiprocess_mode="mostrecent", ) @@ -88,6 +96,7 @@ def log_stats(self, stats: SchedulerStats) -> None: self._log_gauge(self.gen_throughput, stats.gen_throughput) self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs) self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate) + self._log_gauge(self.spec_accept_length, stats.spec_accept_length) class TokenizerMetricsCollector: @@ -109,6 +118,12 @@ def __init__(self, labels: Dict[str, str]) -> None: labelnames=labels.keys(), ) + self.num_requests_total = Counter( + name="sglang:num_requests_total", + documentation="Number of requests processed.", + labelnames=labels.keys(), + ) + self.histogram_time_to_first_token = Histogram( name="sglang:time_to_first_token_seconds", documentation="Histogram of time to first token in seconds.", @@ -185,11 +200,10 @@ def _log_counter(self, counter, data: Union[int, float]) -> None: # Convenience function for logging to counter. counter.labels(**self.labels).inc(data) - def inc_prompt_tokens(self, value: int): - self._log_counter(self.prompt_tokens_total, value) - - def inc_generation_tokens(self, value: int): - self._log_counter(self.generation_tokens_total, value) + def observe_one_finished_request(self, prompt_tokens: int, generation_tokens: int): + self.prompt_tokens_total.labels(**self.labels).inc(prompt_tokens) + self.generation_tokens_total.labels(**self.labels).inc(generation_tokens) + self.num_requests_total.labels(**self.labels).inc(1) def observe_time_to_first_token(self, value: Union[float, int]): self._log_histogram(self.histogram_time_to_first_token, value) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index deaea33129d1..169b64343681 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -21,10 +21,10 @@ import torch import tqdm -from vllm.distributed import get_tensor_model_parallel_rank -from vllm.distributed.parallel_state import graph_capture from vllm.model_executor.custom_op import CustomOp +from sglang.srt.distributed import get_tensor_model_parallel_rank +from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native from sglang.srt.layers.torchao_utils import save_gemlite_cache @@ -33,7 +33,6 @@ ForwardBatch, ForwardMode, ) -from sglang.srt.utils import monkey_patch_vllm_all_gather if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner @@ -64,7 +63,7 @@ def patch_model( model: torch.nn.Module, enable_compile: bool, batch_size: int, - tp_group: "GroupCoordinator", + tp_group: GroupCoordinator, ): """Patch the model to make it compatible with with torch.compile""" backup_ca_comm = None @@ -72,7 +71,6 @@ def patch_model( try: if enable_compile: _to_torch(model, reverse=False, batch_size=batch_size) - monkey_patch_vllm_all_gather() backup_ca_comm = tp_group.ca_comm # Use custom-allreduce here. 
# We found the custom allreduce is much faster than the built-in allreduce in torch, @@ -88,7 +86,6 @@ def patch_model( finally: if enable_compile: _to_torch(model, reverse=True, batch_size=batch_size) - monkey_patch_vllm_all_gather(reverse=True) tp_group.ca_comm = backup_ca_comm @@ -122,6 +119,7 @@ def __init__(self, model_runner: "ModelRunner"): self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention self.tp_size = self.model_runner.tp_size + self.dp_size = self.model_runner.server_args.dp_size # Batch sizes to capture self.capture_bs = self.model_runner.server_args.cuda_graph_bs @@ -131,11 +129,6 @@ def __init__(self, model_runner: "ModelRunner"): else: self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] - if model_runner.server_args.disable_cuda_graph_padding: - self.capture_bs = list(range(1, 33)) + [64, 128] - else: - self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] - if max(self.capture_bs) > model_runner.req_to_token_pool.size: # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests # is very samll. We add more values here to make sure we capture the maximum bs. @@ -156,9 +149,18 @@ def __init__(self, model_runner: "ModelRunner"): and bs <= model_runner.server_args.cuda_graph_max_bs ] + self.compile_bs = ( + [ + bs + for bs in self.capture_bs + if bs <= self.model_runner.server_args.torch_compile_max_bs + ] + if self.use_torch_compile + else [] + ) + self.capture_forward_mode = ForwardMode.DECODE self.num_tokens_per_bs = 1 - if model_runner.spec_algorithm.is_eagle(): if self.model_runner.is_draft_worker: self.num_tokens_per_bs = ( @@ -170,16 +172,6 @@ def __init__(self, model_runner: "ModelRunner"): self.model_runner.server_args.speculative_num_draft_tokens ) - self.compile_bs = ( - [ - bs - for bs in self.capture_bs - if bs <= self.model_runner.server_args.torch_compile_max_bs - ] - if self.use_torch_compile - else [] - ) - # Attention backend self.max_bs = max(self.capture_bs) self.max_num_token = self.max_bs * self.num_tokens_per_bs @@ -187,7 +179,6 @@ def __init__(self, model_runner: "ModelRunner"): self.seq_len_fill_value = ( self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) - # FIXME(lsyin): leave it here for now, I don't know whether it is necessary self.encoder_len_fill_value = 0 @@ -196,14 +187,14 @@ def __init__(self, model_runner: "ModelRunner"): # Common inputs with torch.device("cuda"): - self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32) + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) self.seq_lens = torch.full( (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 ) - self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32) + self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64) self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) - self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32) + self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) # Speculative_inference if model_runner.spec_algorithm.is_eagle(): @@ -223,7 +214,7 @@ def __init__(self, model_runner: "ModelRunner"): if self.enable_dp_attention: self.gathered_buffer = torch.zeros( ( - self.max_bs * self.tp_size, + self.max_bs * self.dp_size, self.model_runner.model_config.hidden_size, ), dtype=self.model_runner.dtype, diff --git 
a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index fab8b15a3316..8bd1052754c9 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -38,7 +38,7 @@ import triton.language as tl from sglang.srt.layers.rotary_embedding import MRotaryEmbedding -from sglang.srt.utils import maybe_torch_compile +from sglang.srt.utils import get_compiler_backend if TYPE_CHECKING: from sglang.srt.layers.attention import AttentionBackend @@ -106,6 +106,9 @@ def is_cuda_graph(self): def is_dummy_first(self): return self == ForwardMode.DUMMY_FIRST + def is_decode_or_idle(self): + return self == ForwardMode.DECODE or self == ForwardMode.IDLE + class CaptureHiddenMode(IntEnum): NULL = auto() @@ -279,6 +282,9 @@ def init_new( can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph, lora_paths=batch.lora_paths, sampling_info=batch.sampling_info, + req_to_token_pool=model_runner.req_to_token_pool, + token_to_kv_pool=model_runner.token_to_kv_pool, + attn_backend=model_runner.attn_backend, spec_algorithm=batch.spec_algorithm, spec_info=batch.spec_info, capture_hidden_mode=batch.capture_hidden_mode, @@ -333,11 +339,6 @@ def init_new( if model_runner.model_is_mrope: ret.compute_mrope_positions(model_runner, batch) - # Init attention information - ret.req_to_token_pool = model_runner.req_to_token_pool - ret.token_to_kv_pool = model_runner.token_to_kv_pool - ret.attn_backend = model_runner.attn_backend - # Init lora information if model_runner.server_args.lora_paths is not None: model_runner.lora_manager.prepare_lora_batch(ret) @@ -414,6 +415,6 @@ def compute_position_torch( return positions.to(torch.int64), extend_start_loc -@maybe_torch_compile(dynamic=True) +@torch.compile(dynamic=True, backend=get_compiler_backend()) def clamp_position(seq_lens): return torch.clamp((seq_lens - 1), min=0).to(torch.int64) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 7cd9e759a3dc..e7dc6bd66c53 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -21,20 +21,26 @@ import torch import torch.distributed as dist -from vllm.distributed import ( + +from sglang.srt.configs.device_config import DeviceConfig +from sglang.srt.configs.load_config import LoadConfig +from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.distributed import ( get_tp_group, init_distributed_environment, initialize_model_parallel, set_custom_all_reduce, ) - -from sglang.srt.configs.device_config import DeviceConfig -from sglang.srt.configs.load_config import LoadConfig -from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.attention.triton_backend import TritonAttnBackend +from sglang.srt.layers.dp_attention import ( + get_attention_tp_group, + get_attention_tp_size, + initialize_dp_attention, +) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import Sampler from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model @@ -50,13 +56,15 
@@ from sglang.srt.model_loader import get_model from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils import ( enable_show_time_cost, get_available_gpu_memory, init_custom_process_group, + is_cuda, is_hip, + monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, - monkey_patch_vllm_p2p_access_check, set_cpu_offload_max_bytes, ) @@ -99,8 +107,10 @@ def __init__( self.model_config.attention_arch == AttentionArch.MLA and not self.server_args.disable_mla ): - logger.info("MLA optimization is turned on. Use triton backend.") - self.server_args.attention_backend = "triton" + # TODO: add MLA optimization on CPU + if self.server_args.device != "cpu": + logger.info("MLA optimization is turned on. Use triton backend.") + self.server_args.attention_backend = "triton" if self.server_args.enable_double_sparsity: logger.info( @@ -157,6 +167,7 @@ def __init__( "enable_nan_detection": server_args.enable_nan_detection, "enable_dp_attention": server_args.enable_dp_attention, "enable_ep_moe": server_args.enable_ep_moe, + "device": server_args.device, } ) @@ -165,14 +176,21 @@ def __init__( # Get memory before model loading min_per_gpu_memory = self.init_torch_distributed() + self.memory_saver_adapter = TorchMemorySaverAdapter.create( + enable=self.server_args.enable_memory_saver + ) + # Load the model self.sampler = Sampler() self.load_model() # Apply torchao quantization - apply_torchao_config_to_model( - self.model, global_server_args_dict["torchao_config"] - ) + torchao_applied = getattr(self.model, "torchao_applied", False) + # In layered loading, torchao may have been applied + if not torchao_applied: + apply_torchao_config_to_model( + self.model, global_server_args_dict["torchao_config"] + ) # Apply torch TP if the model supports it supports_torch_tp = getattr(self.model, "supports_torch_tp", False) @@ -205,14 +223,17 @@ def init_torch_distributed(self): if self.device == "cuda": backend = "nccl" elif self.device == "xpu": - # TODO(liangan1):Just use gloo to bypass the initilization fail + # TODO(liangan1): Just use gloo to bypass the initilization fail # Need to use xccl for xpu backend in the future backend = "gloo" elif self.device == "hpu": backend = "hccl" + elif self.device == "cpu": + backend = "gloo" if not self.server_args.enable_p2p_check: - monkey_patch_vllm_p2p_access_check(self.gpu_id) + monkey_patch_p2p_access_check() + if self.server_args.dist_init_addr: dist_init_method = f"tcp://{self.server_args.dist_init_addr}" else: @@ -220,7 +241,7 @@ def init_torch_distributed(self): set_custom_all_reduce(not self.server_args.disable_custom_all_reduce) if not self.is_draft_worker: - # Only initilzie the distributed environment on the target model worker. + # Only initialize the distributed environment on the target model worker. 
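A sketch of the device-to-backend mapping used when initializing torch.distributed above; the single-process gloo example is illustrative.

```python
import torch.distributed as dist

def pick_backend(device: str) -> str:
    return {
        "cuda": "nccl",
        "xpu": "gloo",   # TODO in the patch: move to xccl once available
        "hpu": "hccl",
        "cpu": "gloo",
    }[device]

# Single-process example: initialize a gloo group for CPU, then tear it down.
dist.init_process_group(
    backend=pick_backend("cpu"),
    init_method="tcp://127.0.0.1:29500",
    world_size=1,
    rank=0,
)
dist.destroy_process_group()
```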
init_distributed_environment( backend=backend, world_size=self.tp_size, @@ -229,11 +250,18 @@ def init_torch_distributed(self): distributed_init_method=dist_init_method, ) initialize_model_parallel(tensor_model_parallel_size=self.tp_size) + initialize_dp_attention( + enable_dp_attention=self.server_args.enable_dp_attention, + tp_rank=self.tp_rank, + tp_size=self.tp_size, + dp_size=self.server_args.dp_size, + ) min_per_gpu_memory = get_available_gpu_memory( self.device, self.gpu_id, distributed=self.tp_size > 1 ) self.tp_group = get_tp_group() + self.attention_tp_group = get_attention_tp_group() # Check memory for tensor parallelism if self.tp_size > 1: @@ -251,7 +279,8 @@ def load_model(self): ) # This can reduce thread conflicts and speed up weight loading. - torch.set_num_threads(1) + if self.device != "cpu": + torch.set_num_threads(1) if self.device == "cuda": if torch.cuda.get_device_capability()[0] < 8: logger.info( @@ -271,11 +300,38 @@ def load_model(self): monkey_patch_vllm_gguf_config() # Load the model - self.model = get_model( - model_config=self.model_config, - load_config=self.load_config, - device_config=DeviceConfig(self.device), - ) + # Remove monkey_patch when linear.py quant remove dependencies with vllm + monkey_patch_vllm_parallel_state() + with self.memory_saver_adapter.region(): + self.model = get_model( + model_config=self.model_config, + load_config=self.load_config, + device_config=DeviceConfig(self.device), + ) + monkey_patch_vllm_parallel_state(reverse=True) + + if self.server_args.kv_cache_dtype == "fp8_e4m3": + if self.server_args.quantization_param_path is not None: + if callable(getattr(self.model, "load_kv_cache_scales", None)): + self.model.load_kv_cache_scales( + self.server_args.quantization_param_path + ) + logger.info( + "Loaded KV cache scaling factors from %s", + self.server_args.quantization_param_path, + ) + else: + raise RuntimeError( + "Using FP8 KV cache and scaling factors provided but " + "model %s does not support loading scaling factors.", + self.model.__class__, + ) + else: + logger.warning( + "Using FP8 KV cache but no scaling factors " + "provided. Defaulting to scaling factors of 1.0. " + "This may lead to less accurate results!" + ) # Parse other args self.sliding_window_size = ( @@ -393,7 +449,7 @@ def init_weights_update_group( logger.info( f"init custom process group: master_address={master_address}, master_port={master_port}, " - f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}" + f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}" ) try: @@ -491,7 +547,7 @@ def profile_max_num_token(self, total_gpu_memory: int): ) else: cell_size = ( - self.model_config.get_num_kv_heads(self.tp_size) + self.model_config.get_num_kv_heads(get_attention_tp_size()) * self.model_config.head_dim * self.model_config.num_hidden_layers * 2 @@ -516,6 +572,9 @@ def init_memory_pool( self.kv_cache_dtype = torch.float8_e5m2fnuz else: self.kv_cache_dtype = torch.float8_e5m2 + elif self.server_args.kv_cache_dtype == "fp8_e4m3": + if is_cuda(): + self.kv_cache_dtype = torch.float8_e4m3fn else: raise ValueError( f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}." 
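A sketch of how the kv-cache dtype string is resolved to a torch dtype, mirroring the branches above; the helper name is illustrative, and the fp8_e4m3 mapping is shown without the CUDA-only guard used in the patch.

```python
import torch

def resolve_kv_cache_dtype(kv_cache_dtype: str, model_dtype: torch.dtype,
                           is_hip: bool = False) -> torch.dtype:
    if kv_cache_dtype == "auto":
        return model_dtype
    if kv_cache_dtype == "fp8_e5m2":
        # ROCm uses the fnuz variant of e5m2.
        return torch.float8_e5m2fnuz if is_hip else torch.float8_e5m2
    if kv_cache_dtype == "fp8_e4m3":
        return torch.float8_e4m3fn
    raise ValueError(f"Unsupported kv_cache_dtype: {kv_cache_dtype}.")

print(resolve_kv_cache_dtype("fp8_e4m3", torch.bfloat16))  # torch.float8_e4m3fn
```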
@@ -562,7 +621,7 @@ def init_memory_pool( size=max_num_reqs + 1, max_context_len=self.model_config.context_len + 4, device=self.device, - use_records=False, + enable_memory_saver=self.server_args.enable_memory_saver, ) if ( self.model_config.attention_arch == AttentionArch.MLA @@ -575,25 +634,28 @@ def init_memory_pool( qk_rope_head_dim=self.model_config.qk_rope_head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, ) elif self.server_args.enable_double_sparsity: self.token_to_kv_pool = DoubleSparseTokenToKVPool( self.max_total_num_tokens, dtype=self.kv_cache_dtype, - head_num=self.model_config.get_num_kv_heads(self.tp_size), + head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, heavy_channel_num=self.server_args.ds_heavy_channel_num, + enable_memory_saver=self.server_args.enable_memory_saver, ) else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, dtype=self.kv_cache_dtype, - head_num=self.model_config.get_num_kv_heads(self.tp_size), + head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_dim=self.model_config.head_dim, layer_num=self.model_config.num_hidden_layers, device=self.device, + enable_memory_saver=self.server_args.enable_memory_saver, ) logger.info( f"Memory pool end. " @@ -634,7 +696,6 @@ def init_attention_backend(self): ) def init_double_sparsity_channel_config(self, selected_channel): - selected_channel = "." + selected_channel + "_proj" self.sorted_channels = [] # load channel config @@ -725,7 +786,7 @@ def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput: elif forward_batch.forward_mode.is_idle(): return self.forward_idle(forward_batch) else: - raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}") + raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") def sample( self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py index 776b69aafa7c..9e6b09488e61 100644 --- a/python/sglang/srt/model_loader/loader.py +++ b/python/sglang/srt/model_loader/loader.py @@ -21,14 +21,14 @@ from torch import nn from transformers import AutoModelForCausalLM, PretrainedConfig from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm.distributed import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) from sglang.srt.configs.device_config import DeviceConfig from sglang.srt.configs.load_config import LoadConfig, LoadFormat from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.model_loader.utils import ( get_model_architecture, @@ -374,6 +374,78 @@ def load_model( return model.eval() +class LayeredModelLoader(DefaultModelLoader): + """Model loader that loads weights layer by layer so that one can quantize a + layer before loading another to make the peak memory envelope smaller.""" + + def __init__(self, load_config: LoadConfig): + # Back to the default load format + load_config.load_format = LoadFormat.AUTO + super().__init__(load_config) + + def load_model( + self, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + ) -> nn.Module: + 
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model + from sglang.srt.managers.schedule_batch import global_server_args_dict + + torchao_config = global_server_args_dict.get("torchao_config") + target_device = torch.device(device_config.device) + + with set_default_torch_dtype(model_config.dtype): + # Create model on meta device + with torch.device("meta"): + model = _initialize_model( + model_config, + self.load_config, + ) + + # Check model's layered load support + if not hasattr(model, "load_weights_to_module"): + raise ValueError( + "LayeredModelLoader requires the model to have a " + "`load_weights_to_module` method. " + f"{model_config.model_path} does not support it." + ) + + # Get all weights from disk + weights = self._get_all_weights(model_config, model) + + # Helper function to recursively fill the weights of a module + def fill_module(module, fqn: List[str], weights): + """ + fqn: list of strings representing the fully qualified name of `module`. + """ + # Layer by layer + for name, submod in module.named_children(): + fill_module(submod, fqn + [name], weights) + + # First materialize on target device + module.to_empty(device=target_device, recurse=False) + fqn_path = ".".join(fqn) + # Fill weights + model.load_weights_to_module( + fqn_path, + weights, + ) + # Quantize weights if applicable + if torchao_config and "proj" in fqn_path: + # Note: `None` here is needed to indicate no filter, see + # `apply_torchao_config_to_model` for details. + apply_torchao_config_to_model(module, torchao_config, None) + + # Start calling on root module + fill_module(model, [], weights) + + if torchao_config: + model.torchao_applied = True + + return model.eval() + + class DummyModelLoader(BaseModelLoader): """Model loader that will set model weights to random values.""" @@ -496,7 +568,8 @@ def load_model( device_config: DeviceConfig, ) -> nn.Module: from safetensors.torch import safe_open - from vllm.distributed import get_tensor_model_parallel_rank + + from sglang.srt.distributed import get_tensor_model_parallel_rank local_model_path = self._prepare_weights( model_config.model_path, model_config.revision @@ -556,7 +629,8 @@ def save_model( max_size: Optional[int] = None, ) -> None: from safetensors.torch import save_file - from vllm.distributed import get_tensor_model_parallel_rank + + from sglang.srt.distributed import get_tensor_model_parallel_rank if pattern is None: pattern = ShardedStateLoader.DEFAULT_PATTERN @@ -1147,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: if load_config.load_format == LoadFormat.GGUF: return GGUFModelLoader(load_config) + if load_config.load_format == LoadFormat.LAYERED: + return LayeredModelLoader(load_config) + return DefaultModelLoader(load_config) diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py index 13b323b5d329..f2f67ecab1d4 100644 --- a/python/sglang/srt/model_loader/weight_utils.py +++ b/python/sglang/srt/model_loader/weight_utils.py @@ -9,7 +9,17 @@ import os import tempfile from collections import defaultdict -from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, + Union, +) import filelock import gguf @@ -17,12 +27,13 @@ import numpy as np import torch from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download +from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator from 
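A toy sketch of the layered-loading idea behind LayeredModelLoader: parameters start on the meta device and each submodule is materialized and filled in turn, so peak memory stays near one layer. The weight source here is a plain state_dict rather than sglang's weight iterator, and the optional quantization step is only indicated by a comment.

```python
import torch
from torch import nn

def layered_load(model: nn.Module, state_dict: dict, device: torch.device) -> nn.Module:
    def fill_module(module: nn.Module, prefix: str):
        # Recurse into children first, layer by layer.
        for name, child in module.named_children():
            fill_module(child, f"{prefix}{name}.")
        # Materialize this module's own parameters (non-recursive) ...
        module.to_empty(device=device, recurse=False)
        # ... then copy in the matching weights.
        for name, param in module.named_parameters(recurse=False):
            param.data.copy_(state_dict[prefix + name])
        # A quantization step (e.g. torchao) could run here, before the next
        # module is materialized, keeping the peak memory envelope small.

    fill_module(model, "")
    return model

with torch.device("meta"):
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

reference = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
model = layered_load(model, reference.state_dict(), torch.device("cpu"))
print(model(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```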
safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm -from vllm.distributed import get_tensor_model_parallel_rank from sglang.srt.configs.load_config import LoadConfig from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed import get_tensor_model_parallel_rank from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config from sglang.srt.utils import print_warning_once @@ -638,3 +649,121 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]: # If there were no matches, return the untouched param name return name + + +# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py +class KVCacheQuantSchema(BaseModel): + dtype: str + # Each key is a TP rank. Each value is a dictionary mapping a TP rank's + # layer indices to their per-tensor KV cache scaling factor. + # TODO: Consider pulling this and its validation methods out into its + # own schema class (tricky as its members are variable) + scaling_factor: Dict[int, Dict[int, float]] + + @model_validator(mode="after") + def check_is_fp8(self) -> "KVCacheQuantSchema": + assert self.dtype == "float8_e4m3fn", ( + "Loaded scaling factors intended for KV cache dtype = " + f"{self.dtype} rather than float8_e4m3fn!" + ) + return self + + @model_validator(mode="after") + def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_size = context["tp_size"] + num_hidden_layers = context["num_hidden_layers"] + assert len(self.scaling_factor) == tp_size, ( + f"Loaded dictionary has TP size {len(self.scaling_factor)} " + f"but LLM engine is currently running with TP size {tp_size}." + ) + for tp_rank, layer_maps in self.scaling_factor.items(): + assert len(layer_maps) == num_hidden_layers, ( + f"KV cache scales map for TP rank {tp_rank} is malformed. " + f"Expected {num_hidden_layers} layers, got " + f"{len(layer_maps)}." + ) + for i in range(tp_size): + assert ( + i in self.scaling_factor + ), f"KV cache scales map for TP rank {i} not found." + return self + + @model_validator(mode="after") + def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema": + context = info.context + if context: + tp_rank = context["tp_rank"] + num_hidden_layers = context["num_hidden_layers"] + layer_scales_map = self.scaling_factor[tp_rank] + for i in range(num_hidden_layers): + assert i in layer_scales_map, ( + f"Could not find KV cache scales for layer {i} in " + f"TP rank {tp_rank}." + ) + return self + + +class QuantParamSchema(BaseModel): + # TODO: Generalize and extend with more fields + # (e.g. weights/activations params) once functionality is enabled + model_config = ConfigDict(protected_namespaces=()) + model_type: Optional[str] + kv_cache: KVCacheQuantSchema + + @model_validator(mode="after") + def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema": + context = info.context + if context: + model_type = context.get("model_type", None) + if model_type is not None: + assert model_type == self.model_type, ( + f"Model type is {model_type} but loaded " + f"scaling factors belonging to different " + f"model type {self.model_type}!" 
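A minimal sketch of the context-driven pydantic v2 validation used by the KV-cache scale schema above; the schema and context keys are trimmed to a single check for brevity.

```python
from typing import Dict
from pydantic import BaseModel, ValidationInfo, model_validator

class TinyKVCacheSchema(BaseModel):
    dtype: str
    scaling_factor: Dict[int, Dict[int, float]]  # tp_rank -> layer index -> scale

    @model_validator(mode="after")
    def check_tp_ranks(self, info: ValidationInfo) -> "TinyKVCacheSchema":
        # The caller passes engine-side facts (tp_size, num_hidden_layers, ...)
        # via `context`, and the loaded JSON is checked against them.
        context = info.context
        if context:
            assert len(self.scaling_factor) == context["tp_size"], "TP size mismatch"
        return self

doc = {"dtype": "float8_e4m3fn", "scaling_factor": {0: {0: 1.0, 1: 0.5}}}
schema = TinyKVCacheSchema.model_validate(doc, context={"tp_size": 1})
print(schema.scaling_factor[0])  # {0: 1.0, 1: 0.5}
```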
+ ) + return self + + +def kv_cache_scales_loader( + filename: str, + tp_rank: int, + tp_size: int, + num_hidden_layers: int, + model_type: Optional[str], +) -> Iterable[Tuple[int, float]]: + """ + A simple utility to read in KV cache scaling factors that have been + previously serialized to disk. Used by the model to populate the appropriate + KV cache scaling factors. The serialization should represent a dictionary + whose keys are the TP ranks and values are another dictionary mapping layers + to their KV cache scaling factors. + """ + try: + with open(filename) as f: + context = { + "model_type": model_type, + "num_hidden_layers": num_hidden_layers, + "tp_rank": tp_rank, + "tp_size": tp_size, + } + schema_dct = json.load(f) + schema = QuantParamSchema.model_validate(schema_dct, context=context) + layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] + return layer_scales_map.items() + except FileNotFoundError: + logger.error("File or directory '%s' not found.", filename) + except json.JSONDecodeError: + logger.error("Error decoding JSON in file '%s'.", filename) + except Exception: + logger.error("An error occurred while reading '%s'.", filename) + # This section is reached if and only if any of the excepts are hit + # Return an empty iterable (list) => no KV cache scales are loaded + # which ultimately defaults to 1.0 scales + logger.warning( + "Defaulting to KV cache scaling factors = 1.0 for all " + "layers in TP rank %d as an error occurred during loading.", + tp_rank, + ) + return [] diff --git a/python/sglang/srt/models/baichuan.py b/python/sglang/srt/models/baichuan.py index 3bd60c25d3e4..066157f05ce1 100644 --- a/python/sglang/srt/models/baichuan.py +++ b/python/sglang/srt/models/baichuan.py @@ -24,22 +24,22 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.layers.linear import ( +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - -from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index 9c3bc2ee9e0a..222cc3e2d805 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -21,10 +21,9 @@ import torch from torch import nn from torch.nn import LayerNorm -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.transformers_utils.configs import ChatGLMConfig +from sglang.srt.configs import ChatGLMConfig +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -35,6 +34,7 @@ from 
sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index 83ac3d8671b8..e4b291b66cb2 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -44,12 +44,11 @@ from torch import nn from torch.nn.parameter import Parameter from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -59,9 +58,13 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.utils import get_compiler_backend, set_weight_attrs @@ -372,10 +375,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) loaded_params.add(name) -EntryClass = CohereForCausalLM +class Cohere2ForCausalLM(CohereForCausalLM): + pass + + +EntryClass = [CohereForCausalLM, Cohere2ForCausalLM] diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 852f58a710d6..92fc679391fd 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -19,14 +19,13 @@ import torch import torch.nn as nn -from vllm.distributed import ( + +from sglang.srt.configs import DbrxConfig +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.transformers_utils.configs.dbrx import DbrxConfig - from sglang.srt.layers.linear import ( QKVParallelLinear, ReplicatedLinear, @@ -36,13 +35,17 @@ from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.utils import set_weight_attrs @@ -411,6 +414,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, weight_name) break else: + # Remapping the name of FP8 kv-scale. 
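A sketch of the weight-loading pattern these models now share: remap FP8 KV-scale checkpoint names before the parameter lookup and skip scales that have no destination. `remap_kv_scale_name` here is an illustrative stand-in for sglang's maybe_remap_kv_scale_name, whose exact remapping rules may differ.

```python
from typing import Dict, Iterable, Optional, Tuple

import torch

def remap_kv_scale_name(name: str, params_dict: Dict[str, torch.Tensor]) -> Optional[str]:
    # Illustrative remap: checkpoint "*.k_scale" / "*.v_scale" -> "*.attn.k_scale" etc.
    if name.endswith(".k_scale") or name.endswith(".v_scale"):
        candidate = name.replace(".k_scale", ".attn.k_scale").replace(".v_scale", ".attn.v_scale")
        return candidate if candidate in params_dict else None
    return name

def load_weights(params_dict: Dict[str, torch.Tensor], weights: Iterable[Tuple[str, torch.Tensor]]):
    for name, loaded in weights:
        name = remap_kv_scale_name(name, params_dict)
        if name is None:
            continue  # scale has no destination in this model; skip it
        params_dict[name].copy_(loaded)

params = {"layers.0.attn.k_scale": torch.zeros(())}
load_weights(params, [("layers.0.k_scale", torch.tensor(0.5)), ("layers.0.v_scale", torch.tensor(2.0))])
print(params["layers.0.attn.k_scale"])  # tensor(0.5000)
```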
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index d840cb866bd2..7d2c0700fe45 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -21,13 +21,12 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -40,6 +39,7 @@ from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index a9c0b59cea37..17d7fcf8924c 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -23,14 +23,13 @@ from torch import nn from transformers import PretrainedConfig from vllm import _custom_ops as ops -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, get_tp_group, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -49,6 +48,7 @@ normalize_e4m3fn_to_e4m3fnuz, ) from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -271,13 +271,14 @@ def __init__( quant_config=quant_config, ) rope_scaling["rope_type"] = "deepseek_yarn" - self.rotary_emb = get_rope( + self.rotary_emb = get_rope_wrapper( qk_rope_head_dim, rotary_dim=qk_rope_head_dim, max_position=max_position_embeddings, base=rope_theta, rope_scaling=rope_scaling, is_neox_style=False, + device=global_server_args_dict["device"], ) if rope_scaling: @@ -855,10 +856,9 @@ def forward( forward_batch: ForwardBatch, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, forward_batch) - if not forward_batch.forward_mode.is_idle(): - return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch - ) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/exaone.py b/python/sglang/srt/models/exaone.py index 536c253c33ae..10be1e74d617 100644 --- a/python/sglang/srt/models/exaone.py +++ b/python/sglang/srt/models/exaone.py @@ -20,9 +20,8 @@ import torch from torch import nn -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding 
import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -33,6 +32,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index 10949a2f5727..9940c569e257 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -21,9 +21,8 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -34,6 +33,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_loader.weight_utils import default_weight_loader diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 58d9ce02f20a..06a7b030260a 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -15,13 +15,13 @@ # Adapted from: # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py -from typing import Iterable, Optional, Set, Tuple, Union +from typing import Iterable, Optional, Set, Tuple import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.linear import ( @@ -32,9 +32,13 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + maybe_remap_kv_scale_name, +) from sglang.srt.utils import make_layers @@ -44,23 +48,6 @@ def get_attention_sliding_window_size(config): return config.sliding_window - 1 -# FIXME: temporary solution, remove after next vllm release -from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding - - -class GemmaRotaryEmbedding(RotaryEmbedding): - def _compute_inv_freq(self, base: 
Union[int, float]) -> torch.Tensor: - # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107 - inv_freq = 1.0 / ( - base - ** ( - torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() - / self.rotary_dim - ) - ) - return inv_freq - - class Gemma2MLP(nn.Module): def __init__( self, @@ -143,14 +130,12 @@ def __init__( bias=config.attention_bias, quant_config=quant_config, ) - # from vLLM: TODO(woosuk): Use the `get_rope` interface. - self.rotary_emb = GemmaRotaryEmbedding( - self.head_dim, + self.rotary_emb = get_rope( self.head_dim, - max_position_embeddings, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, base=self.rope_theta, is_neox_style=True, - dtype=torch.get_default_dtype(), ) use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window") @@ -442,6 +427,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/gpt2.py b/python/sglang/srt/models/gpt2.py index 144ad8bbf728..04c3005ce2f3 100644 --- a/python/sglang/srt/models/gpt2.py +++ b/python/sglang/srt/models/gpt2.py @@ -17,16 +17,14 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import GPT2Config -from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding -# from sglang.srt.layers.activation import get_act_fn +from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index f2f5ebd5204d..0d705fb41b60 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -21,8 +21,8 @@ import torch from torch import nn from transformers import GPTBigCodeConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.linear import ( ColumnParallelLinear, diff --git a/python/sglang/srt/models/granite.py b/python/sglang/srt/models/granite.py index d207ff61b26d..255f23227ff5 100644 --- a/python/sglang/srt/models/granite.py +++ b/python/sglang/srt/models/granite.py @@ -22,9 +22,8 @@ import torch from torch import nn from transformers import GraniteConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import 
RMSNorm from sglang.srt.layers.linear import ( @@ -36,6 +35,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index 0485b80fc3a2..c13d3e253688 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -22,12 +22,11 @@ import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.activation import GeluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -40,6 +39,7 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -57,6 +57,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, prefix: str = "", reduce_results=True, + use_presharded_weights: bool = False, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -65,6 +66,7 @@ def __init__( bias=False, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", + use_presharded_weights=use_presharded_weights, ) self.down_proj = RowParallelLinear( intermediate_size, @@ -73,6 +75,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.down_proj", reduce_results=reduce_results, + use_presharded_weights=use_presharded_weights, ) self.act_fn = GeluAndMul(approximate="tanh") @@ -103,6 +106,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, reduce_results=True, + use_presharded_weights: bool = False, ): super().__init__() self.hidden_size = hidden_size @@ -129,6 +133,7 @@ def __init__( renormalize=False, quant_config=quant_config, tp_size=tp_size, + use_presharded_weights=use_presharded_weights, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -156,6 +161,7 @@ def __init__( max_position: int = 4096 * 32, rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, ) -> None: super().__init__() self.config = config @@ -194,6 +200,7 @@ def __init__( hidden_size, bias=False, quant_config=quant_config, + reduce_results=reduce_results, ) self.rotary_emb = get_rope( self.head_dim, @@ -234,10 +241,12 @@ def __init__( config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + use_presharded_weights: bool = False, ) -> None: super().__init__() self.num_experts = config.num_local_experts self.hidden_size = config.hidden_size + self.layer_id = layer_id rope_theta = getattr(config, "rope_theta", 10000) self.self_attn = Grok1Attention( @@ -262,6 +271,7 @@ def __init__( ), quant_config=quant_config, reduce_results=True, + use_presharded_weights=use_presharded_weights, ) self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) 
self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -299,6 +309,7 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + use_presharded_weights: bool = False, ) -> None: super().__init__() self.config = config @@ -311,7 +322,12 @@ def __init__( ) self.layers = nn.ModuleList( [ - Grok1DecoderLayer(config, i, quant_config=quant_config) + Grok1DecoderLayer( + config, + i, + quant_config=quant_config, + use_presharded_weights=use_presharded_weights, + ) for i in range(config.num_hidden_layers) ] ) @@ -347,11 +363,7 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = Grok1Model(config, quant_config=quant_config) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.logits_processor = LogitsProcessor(config) - # Monkey patch _prepare_weights to load pre-sharded weights if ( self.config.num_local_experts > 0 and get_tensor_model_parallel_world_size() > 1 @@ -361,6 +373,14 @@ def __init__( else: self.use_presharded_weights = False + self.model = Grok1Model( + config, + quant_config=quant_config, + use_presharded_weights=self.use_presharded_weights, + ) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.logits_processor = LogitsProcessor(config) + def forward( self, input_ids: torch.Tensor, @@ -376,10 +396,7 @@ def forward( def load_weights( self, weights: Iterable[Tuple[str, torch.Tensor]], - use_presharded_weights: Optional[bool] = None, ): - if use_presharded_weights is None: - use_presharded_weights = self.use_presharded_weights num_experts = self.config.num_local_experts stacked_params_mapping = [ @@ -435,20 +452,12 @@ def load_weight_wrapper(name, loaded_weight, *args, **kwargs): continue name = name.replace(weight_name, param_name) - if use_presharded_weights: - extra_kwargs = { - "use_presharded_weights": use_presharded_weights - } - else: - extra_kwargs = {} - load_weight_wrapper( name, loaded_weight, name, shard_id=shard_id, expert_id=expert_id, - **extra_kwargs, ) break else: diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index 0a737c1388b8..ce8f9a3cf651 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -19,9 +19,8 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -32,6 +31,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index e1688df01a8c..4ea77eede9be 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -22,9 +22,11 @@ import torch from torch import nn from transformers import LlamaConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from 
sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -36,12 +38,16 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + kv_cache_scales_loader, +) from sglang.srt.utils import make_layers from sglang.utils import get_exception_traceback @@ -299,6 +305,30 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.layers[layer_idx], nn.Identity): + layer_self_attn = self.layers[layer_idx].self_attn + + if hasattr(layer_self_attn.attn, "k_scale"): + layer_self_attn.attn.k_scale = scaling_factor + layer_self_attn.attn.v_scale = scaling_factor + else: + raise RuntimeError( + "Self attention has no KV cache scaling " "factor attribute!" 
+ ) + class LlamaForCausalLM(nn.Module): @@ -534,9 +564,16 @@ def set_embed_and_head(self, embed, head): torch.cuda.empty_cache() torch.cuda.synchronize() + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) + class Phi3ForCausalLM(LlamaForCausalLM): pass -EntryClass = [LlamaForCausalLM, Phi3ForCausalLM] +class InternLM3ForCausalLM(LlamaForCausalLM): + pass + + +EntryClass = [LlamaForCausalLM, Phi3ForCausalLM, InternLM3ForCausalLM] diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 3482a8281323..f5e69411acc0 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -18,9 +18,8 @@ import torch from torch import nn -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -31,6 +30,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py index b0c93274e2b4..118be8ff6c81 100644 --- a/python/sglang/srt/models/minicpm3.py +++ b/python/sglang/srt/models/minicpm3.py @@ -19,20 +19,20 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.linear import ( + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, ReplicatedLinear, RowParallelLinear, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - -from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py new file mode 100644 index 000000000000..23147529a647 --- /dev/null +++ b/python/sglang/srt/models/minicpmv.py @@ -0,0 +1,1238 @@ +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" +from functools import cached_property, partial +from typing import ( + Any, + Callable, + Iterable, + List, + Literal, + Optional, + Tuple, + TypedDict, + Union, +) + +import torch +import torch.types +from PIL import Image +from torch import nn +from torch.nn.init import trunc_normal_ +from transformers import PretrainedConfig +from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.models.module_mapping import MultiModelKeys +from vllm.model_executor.sampling_metadata import SamplingMetadata + +from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.attention.vision import VisionAttention +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.schedule_batch import ImageInputs +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.utils import set_default_torch_dtype +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.qwen2 import Qwen2Config, Qwen2ForCausalLM + +RawImageType = Union[Image.Image, torch.Tensor] + + +class Idefics2VisionMLP(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.activation_fn = get_act_fn(config.hidden_act) + self.fc1 = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.fc2 = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class Idefics2EncoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.embed_dim = config.hidden_size + + self.num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + num_heads_per_partition = divide(self.num_heads, tp_size) + self.self_attn = VisionAttention( + embed_dim=config.hidden_size, + num_heads=num_heads_per_partition, + projection_size=config.intermediate_size, + use_qkv_parallel=True, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + 
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = Idefics2VisionMLP(config, quant_config=quant_config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + + """ + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + hidden_states = self.self_attn( + hidden_states, + cu_seqlens=cu_seqlens, + # , forward_batch=forward_batch + ) + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class Idefics2Encoder(nn.Module): + """ + Transformer encoder consisting of `config.num_hidden_layers` self-attention + layers. Each layer is a + [`Idefics2EncoderLayer`]. + + Args: + config: Idefics2Config + """ + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.layers = nn.ModuleList( + [ + Idefics2EncoderLayer( + config, + quant_config=quant_config, + ) + for _ in range(config.num_hidden_layers) + ] + ) + + def forward( + self, + inputs_embeds: torch.Tensor, + cu_seqlens: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + r""" + Args: + inputs_embeds (torch.Tensor): + Optionally, instead of passing `input_ids` you can choose to + directly pass an embedded representation. + This is useful if you want more control over how to convert + `input_ids` indices into associated vectors than the model's + internal embedding lookup matrix. + """ + hidden_states = inputs_embeds + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch + ) + hidden_states = layer_outputs + return hidden_states + + +class Idefics2VisionEmbeddings(nn.Module): + """ + This is a modified version of `siglip.modeling_siglip.SiglipVisionEmbeddings` + to enable images of variable + resolution. + + The modifications are adapted from [Patch n' Pack: NaViT, a Vision + Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304) + which allows treating images in their native aspect ratio and without the + need to resize them to the same fixed size. In particular, we start from the + original pre-trained SigLIP model (which uses images of fixed-size square + images) and adapt it by training on images of variable resolutions.
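+ Position ids for each image are obtained by bucketizing fractional patch coordinates into the fixed (num_patches_per_side x num_patches_per_side) position-embedding grid, so images of different sizes reuse the same pre-trained embeddings (see `forward` below).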
+ """ + + def __init__(self, config: PretrainedConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def forward( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: + batch_size, _, max_im_h, max_im_w = pixel_values.shape + target_dtype = self.patch_embedding.weight.dtype + pixel_values = pixel_values.to( + device=self.patch_embedding.weight.device, dtype=target_dtype + ) + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + max_nb_patches_h, max_nb_patches_w = ( + max_im_h // self.patch_size, + max_im_w // self.patch_size, + ) + boundaries = torch.arange( + 1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side + ) + position_ids = torch.full( + size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0 + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + nb_patches_w = tgt_sizes[batch_idx][1] + else: + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + bucket_coords_h = torch.bucketize( + fractional_coords_h, boundaries, right=True + ) + bucket_coords_w = torch.bucketize( + fractional_coords_w, boundaries, right=True + ) + pos_ids = ( + bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w + ).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + position_ids = position_ids.to(self.position_embedding.weight.device) + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +class Idefics2VisionTransformer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + embed_dim = config.hidden_size + self.config = config + self.embeddings = Idefics2VisionEmbeddings(config) + self.encoder = Idefics2Encoder(config=config, quant_config=quant_config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + def get_input_embeddings(self): + return self.embeddings + + def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,) + + # Prefix-sum the patch lengths to get cu_seqlens, prepending a 0 as the starting offset + cu_seqlens = torch.cat( + [ + torch.tensor([0], device=patch_len.device, dtype=torch.int32), + torch.cumsum(patch_len, dim=0, dtype=torch.int32), + ], + dim=0, + ).to(tgt_sizes.device) + return cu_seqlens + + def forward( + self, + pixel_values, + forward_batch: ForwardBatch, + patch_attention_mask: Optional[torch.BoolTensor] = None, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: + hidden_states = self.embeddings( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, +
# forward_batch=forward_batch, + tgt_sizes=tgt_sizes, + ) + cu_seqlens = self.compute_cu_seqlens(tgt_sizes) + encoder_outputs = self.encoder( + hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch + ) + last_hidden_state = self.post_layernorm(encoder_outputs) + return last_hidden_state + + +class MiniCPMVImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: List[torch.Tensor] + """ + Shape: `(batch_size * num_images, num_channels, height, width)` + + Note that the image size may vary, so we pass it as a list + instead of a batched tensor. + """ + + image_bounds: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(start, stop)` format. + """ + + tgt_sizes: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(height, width)` format. + """ + + +class MiniCPMVImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """ + Shape: `(batch_size * num_images, image_feature_size, hidden_size)` + + `hidden_size` must match the hidden size of the language model backbone. + """ + + image_bounds: torch.Tensor + """ + Shape: `(batch_size * num_images, 2)` + + This should be in `(start, stop)` format. + """ + + +MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs, MiniCPMVImageEmbeddingInputs] + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + +class BaseResampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross-attention layer, using + (grid_size**2) learnable queries and 2D sincos position embeddings. + Outputs: + A tensor of shape (grid_size**2, embed_dim) + """ + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + self.num_queries = num_queries + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + trunc_normal_(self.query, std=0.02) + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = ReplicatedLinear( + kv_dim, + embed_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_proj", + ) + else: + # Maintain the same return value as ReplicatedLinear.forward + self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa + nn.Identity()(*args, **kwargs), + None, + ) + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + self.do_post_projection = do_post_projection + self.ln_post = norm_layer(embed_dim) if do_post_projection else None + self.proj = ( + nn.Parameter((embed_dim**-0.5) * torch.randn(embed_dim, embed_dim)) + if do_post_projection + else None + ) + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class Resampler2_5(BaseResampler): + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + max_size: Tuple[int, int] = (70, 70), + quant_config:
Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + num_queries, + embed_dim, + num_heads, + kv_dim, + norm_layer, + quant_config=quant_config, + prefix=prefix, + ) + + self.max_size = max_size + self._set_2d_pos_cache(self.max_size) + + self.apply(self._init_weights) + + def _set_2d_pos_cache( + self, max_size: Tuple[int, int], device: torch.types.Device = "cpu" + ) -> None: + pos_embed_arr = get_2d_sincos_pos_embed( + self.embed_dim, max_size, version=(2, 5) + ) + pos_embed = torch.from_numpy(pos_embed_arr).float().to(device) + self.register_buffer("pos_embed", pos_embed, persistent=False) + + def _adjust_pos_cache( + self, tgt_sizes: torch.Tensor, device: torch.types.Device + ) -> None: + max_h = tgt_sizes[:, 0].max().item() + max_w = tgt_sizes[:, 1].max().item() + assert isinstance(max_h, int) and isinstance(max_w, int) + + if max_h > self.max_size[0] or max_w > self.max_size[1]: + self.max_size = ( + max(max_h, self.max_size[0]), + max(max_w, self.max_size[1]), + ) + self._set_2d_pos_cache(self.max_size, device) + + def forward(self, x: torch.Tensor, tgt_sizes: torch.Tensor) -> torch.Tensor: + assert x.shape[0] == tgt_sizes.shape[0] + bs = x.shape[0] + + device = x.device + dtype = x.dtype + + patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] + + self._adjust_pos_cache(tgt_sizes, device=device) + + max_patch_len = patch_len.max().item() + assert isinstance(max_patch_len, int) + + key_padding_mask = torch.zeros( + (bs, max_patch_len), dtype=torch.bool, device=device + ) + + pos_embed = [] + for i in range(bs): + tgt_h, tgt_w = tgt_sizes[i].tolist() + pos_embed.append( + self.pos_embed[:tgt_h, :tgt_w, :].reshape((tgt_h * tgt_w, -1)).to(dtype) + ) # patches * D + key_padding_mask[i, patch_len[i] :] = True + pos_embed = torch.nn.utils.rnn.pad_sequence( + pos_embed, batch_first=True, padding_value=0.0 + ).permute( + 1, 0, 2 + ) # BLD => L * B * D + x, _ = self.kv_proj(x) # B * L * D + x = self.ln_kv(x).permute(1, 0, 2) # L * B * D + + q = self.ln_q(self.query) # Q * D + + out = self.attn( + self._repeat(q, bs), # Q * B * D + x + pos_embed, # L * B * D + L * B * D + x, + key_padding_mask=key_padding_mask, + )[0] + # out: Q * B * D + x = out.permute(1, 0, 2) # B * Q * D + + x = self.ln_post(x) + x = x @ self.proj + return x + + +def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]: + version_float = getattr(config, "version", None) + + # The old configs do not include version number + # TODO: Remove this after the HF repos are updated + if version_float is None: + if config.hidden_size == 2304 and config.query_num == 64: + return 2, 0 + return 2, 5 + + version_str = str(version_float) + return tuple(int(x) for x in version_str.split(".")) + + +class MiniCPMVBaseModel(nn.Module): + """ + The abstract class of MiniCPMV can only be inherited, but cannot be + instantiated. 
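+ Subclasses supply the concrete language model, vision tower, and resampler via `init_llm`, `init_vision_module`, and `init_resampler`.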
+ """ + + def __init__( + self, + *, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + # multimodal_config = config.model_config.multimodal_config + super().__init__() + # All MiniCPM-V models disable `tie_word_embeddings` but + # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot + # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model + # and config class + self.config = config + # self.multimodal_config = multimodal_config + + self.version = get_version_by_config(self.config) + self.llm = self.init_llm(config=config, quant_config=quant_config) + self.vpm = self.init_vision_module(config, quant_config) + self.vision_dim = ( + self.vpm.embed_dim + if self.version == (2, 0) + else self.vpm.embeddings.embed_dim + ) + self.embed_dim = self.config.hidden_size + + self.resampler = self.init_resampler( + self.embed_dim, self.vision_dim, quant_config=quant_config + ) + + self.logits_processor = LogitsProcessor(config) + + @cached_property + def sampler(self): + if hasattr(self.llm, "sampler"): + return self.llm.sampler + + return get_sampler() + + def _get_image_bounds( + self, + input_ids: torch.Tensor, + pad_values: List[int], + im_start_id: torch.Tensor, + im_end_id: torch.Tensor, + slice_start_id: Optional[torch.Tensor] = None, + slice_end_id: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Returns a tensor indicating the bounds (start and end token ids) of the images + """ + # All the images in the batch should share the same special image + # bound token ids. + start_cond = input_ids == im_start_id[0] + end_cond = input_ids == im_end_id[0] + if slice_start_id is not None: + start_cond |= input_ids == slice_start_id[0] + end_cond |= input_ids == slice_end_id[0] + + (image_start_tokens,) = torch.where(start_cond) + image_start_tokens += 1 + (image_end_tokens,) = torch.where(end_cond) + + # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the images + if len(image_start_tokens) != len(image_end_tokens): + if ( + len(image_start_tokens) + 1 == len(image_end_tokens) + and input_ids[0] in pad_values + and image_end_tokens[0] < image_start_tokens[0] + ): + image_start_tokens = torch.cat( + [ + torch.tensor([0], device=image_start_tokens.device), + image_start_tokens, + ] + ) + valid_image_nums = min(len(image_start_tokens), len(image_end_tokens)) + + if valid_image_nums == 0: + return torch.zeros((0, 2), device=input_ids.device) + + # Filter out pairs where start_token >= end_token + valid_pairs = [] + for i in range(valid_image_nums): + start_token = image_start_tokens[i] + end_token = image_end_tokens[i] + if start_token < end_token: + valid_pairs.append((start_token, end_token)) + + if not valid_pairs: + return torch.zeros((0, 2), device=input_ids.device) + + # Convert valid pairs to tensor + valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device) + return valid_pairs_tensor + + def get_embedding( + self, + input_ids: torch.Tensor, + image_inputs: Optional[MiniCPMVImageInputs], + forward_batch: ForwardBatch, + ) -> Tuple[torch.Tensor, torch.Tensor]: + vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) + + if image_inputs is None: # No image + vision_hidden_states = torch.tensor([], device=input_ids.device) + else: + if image_inputs["type"] == "image_embeds": + vision_hidden_states = ( + image_inputs["data"] + .type(vlm_embedding.dtype) + .to(vlm_embedding.device) + ) + else: + vision_hidden_states = self.get_vision_hidden_states( + 
forward_batch, image_inputs + ) + + # See NOTE in _parse_and_validate_inputs + image_bounds = image_inputs["image_bounds"] + if len(image_bounds) > 0: + image_indices = torch.stack( + [ + torch.arange(start, end, dtype=torch.long) + for start, end in image_bounds.tolist() + ] + ).to(vlm_embedding.device) + vlm_embedding.scatter_( + 0, + image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), + vision_hidden_states.view(-1, vision_hidden_states.shape[-1]), + ) + + return vlm_embedding, vision_hidden_states + + def _parse_and_validate_inputs( + self, + input_ids: torch.Tensor, + **kwargs: object, + ) -> Optional[MiniCPMVImageInputs]: + pixel_values = kwargs.pop("pixel_values", []) + tgt_sizes = kwargs.pop("tgt_sizes", []) + im_start_id = kwargs.pop("im_start_id", None) + im_end_id = kwargs.pop("im_end_id", None) + slice_start_id = kwargs.pop("slice_start_id", None) + slice_end_id = kwargs.pop("slice_end_id", None) + image_embeds = kwargs.pop("image_embeds", None) + pad_values = kwargs.pop("pad_values", None) + + if image_embeds is not None: + image_bounds = self._get_image_bounds( + input_ids=input_ids, + pad_values=pad_values, + im_start_id=im_start_id, + im_end_id=im_end_id, + slice_start_id=slice_start_id, + slice_end_id=slice_end_id, + ) + if not isinstance(image_embeds, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of image embeds. " + f"Got type: {type(image_embeds)}" + ) + + if isinstance(image_embeds, list): + image_embeds = torch.concat(image_embeds) + + return MiniCPMVImageEmbeddingInputs( + image_bounds=image_bounds, + data=image_embeds, + type="image_embeds", + ) + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}" + ) + + if not isinstance(tgt_sizes, (torch.Tensor, list)): + raise ValueError( + "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}" + ) + + if len(pixel_values) != len(tgt_sizes): + raise ValueError( + "Inconsistent batch lengths, found: " + f"{len(pixel_values)} vs. {len(tgt_sizes)}" + ) + + pixel_values_flat: List[torch.Tensor] = [] + tgt_sizes_flat: List[torch.Tensor] = [] + for pixel_b, tgt_b in zip(pixel_values, tgt_sizes): + if len(pixel_b) != len(tgt_b): + raise ValueError( + "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}" + ) + + for pixel_n, tgt_n in zip(pixel_b, tgt_b): + pixel_values_flat += pixel_n + tgt_sizes_flat += tgt_n + + # NOTE: Input IDs does not contain image tokens during memory profiling, + # so we allow it to be empty + if len(pixel_values_flat) != len(tgt_sizes_flat): + raise ValueError( + "Inconsistent flattened lengths, found: " + f"{len(pixel_values_flat)} vs. 
" + f"{len(tgt_sizes_flat)}" + ) + + if len(pixel_values_flat) == 0: + return None + + image_bounds = self._get_image_bounds( + input_ids=input_ids, + pad_values=pad_values, + im_start_id=im_start_id, + im_end_id=im_end_id, + slice_start_id=slice_start_id, + slice_end_id=slice_end_id, + ) + return MiniCPMVImagePixelInputs( + image_bounds=image_bounds.to(device=input_ids.device), + data=pixel_values_flat, + tgt_sizes=torch.stack(tgt_sizes_flat), + type="pixel_values", + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + **kwargs: Any, + ) -> torch.Tensor: + if forward_batch.image_inputs is not None and forward_batch.image_inputs != [ + None + ]: + kwargs.update( + { + "pixel_values": ( + None + if forward_batch.image_inputs is None + else [ + i.pixel_values + for i in forward_batch.image_inputs + if i is not None + ] + ), + "tgt_sizes": ( + None + if forward_batch.image_inputs is None + else [ + i.tgt_sizes + for i in forward_batch.image_inputs + if i is not None + ] + ), + "im_start_id": forward_batch.image_inputs[0].im_start_id, + "im_end_id": forward_batch.image_inputs[0].im_end_id, + "slice_start_id": forward_batch.image_inputs[0].slice_start_id, + "slice_end_id": forward_batch.image_inputs[0].slice_end_id, + "pad_values": forward_batch.image_inputs[0].pad_values, + } + ) + + image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) + + # Clamp input ids. This is because the input_ids for the image tokens are + # filled with the hash values of the image for the prefix matching in the radix attention. + # There values are useless because their embeddings will be replaced by vision embeddings anyway. + input_ids.clamp_(min=0, max=self.config.vocab_size - 1) + + vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch) + + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None + + hidden_states = self.llm.model( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + input_embeds=vlm_embeddings, + ) + + return self.logits_processor( + input_ids, hidden_states, self.llm.lm_head, forward_batch + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + return self.llm.compute_logits(hidden_states, sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def get_mm_mapping(self) -> MultiModelKeys: + """ + Get the module prefix in multimodal models + """ + return MultiModelKeys.from_string_field( + language_model="llm", connector="resampler", tower_model="vpm" + ) + + def init_llm( + self, + config: Qwen2Config, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + raise NotImplementedError + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + ) -> nn.Module: + raise NotImplementedError + + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + raise NotImplementedError + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + raise NotImplementedError + + def 
get_vision_hidden_states( + self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs + ) -> torch.Tensor: + raise NotImplementedError + + +class MiniCPMV2_6(MiniCPMVBaseModel): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + # LoRA specific attributes + supported_lora_modules = [ + # vision encoder + "fc1", + "fc2", + "out_proj", + # language model + "qkv_proj", # same name with vision encoder + "o_proj", + "gate_up_proj", + "down_proj", + # resampler + "kv_proj", + ] + + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(config=config, quant_config=quant_config) + assert self.version == (2, 6) + + def init_llm( + self, + config: Qwen2Config, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + return Qwen2ForCausalLM(config=config, quant_config=quant_config) + + def init_vision_module( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig], + ) -> nn.Module: + model = Idefics2VisionTransformer( + config=config.vision_config, quant_config=quant_config + ) + if self.config.drop_vision_last_layer: + model.encoder.layers = model.encoder.layers[:-1] + + setattr(model, "embed_dim", model.embeddings.embed_dim) + setattr(model, "patch_size", model.embeddings.patch_size) + return model + + def init_resampler( + self, + embed_dim: int, + vision_dim: int, + quant_config: Optional[QuantizationConfig] = None, + ) -> nn.Module: + with set_default_torch_dtype(torch.float16): + # The resampler in 2.6 remains consistent with the one in 2.5. 
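+ # One attention head per 128 embedding dims; the parameters are created in fp16 here and cast to the default dtype when moved to the GPU below.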
+ resampler = Resampler2_5( + num_queries=self.config.query_num, + embed_dim=embed_dim, + num_heads=embed_dim // 128, + kv_dim=vision_dim, + quant_config=quant_config, + ) + + return resampler.to(device="cuda", dtype=torch.get_default_dtype()) + + def get_vision_embedding( + self, + pixel_values: List[torch.Tensor], + patch_attn_mask: Optional[torch.Tensor] = None, + tgt_sizes: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + vision_embedding = self.vpm( + pixel_values, + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ) + return vision_embedding + + def get_vision_hidden_states( + self, + forward_batch: ForwardBatch, + data: MiniCPMVImageInputs, + ) -> torch.Tensor: + pixel_values = data["data"] + tgt_sizes = data["tgt_sizes"] + + device = self.vpm.embeddings.position_embedding.weight.device + dtype = self.vpm.embeddings.position_embedding.weight.dtype + all_pixel_values_lst = [ + i.flatten(end_dim=1).permute(1, 0) for i in pixel_values + ] + + max_patches = (tgt_sizes[:, 0] * tgt_sizes[:, 1]).max().item() + assert isinstance(max_patches, int) + + all_pixel_values = torch.nn.utils.rnn.pad_sequence( + all_pixel_values_lst, batch_first=True, padding_value=0.0 + ) + B, L, _ = all_pixel_values.shape + all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L) + patch_attn_mask = torch.zeros( + (B, 1, max_patches), dtype=torch.bool, device=device + ) + for i in range(B): + patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True + vision_embedding = self.vpm( + all_pixel_values.type(dtype), + forward_batch=forward_batch, + patch_attention_mask=patch_attn_mask, + tgt_sizes=tgt_sizes, + ) + + return self.resampler(vision_embedding, tgt_sizes) + + def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): + if not isinstance(image_inputs.im_start_id, list) or not isinstance( + image_inputs.im_end_id, list + ): + return input_ids + + new_input_ids = [] + last_idx = 0 + image_idx = -1 + image_inputs.image_offsets = [] + + # Get all special token IDs + im_start_id = ( + image_inputs.im_start_id[0].item() + if isinstance(image_inputs.im_start_id[0], torch.Tensor) + else image_inputs.im_start_id[0] + ) + im_end_id = ( + image_inputs.im_end_id[0].item() + if isinstance(image_inputs.im_end_id[0], torch.Tensor) + else image_inputs.im_end_id[0] + ) + slice_start_id = ( + image_inputs.slice_start_id[0].item() + if isinstance(image_inputs.slice_start_id[0], torch.Tensor) + else image_inputs.slice_start_id[0] + ) + slice_end_id = ( + image_inputs.slice_end_id[0].item() + if isinstance(image_inputs.slice_end_id[0], torch.Tensor) + else image_inputs.slice_end_id[0] + ) + + # Find all start and end positions for both types + start_indices = [ + i + for i, x in enumerate(input_ids) + if x == im_start_id or x == slice_start_id + ] + end_indices = [ + i for i, x in enumerate(input_ids) if x == im_end_id or x == slice_end_id + ] + + if len(start_indices) != len(end_indices): + return input_ids + # Process each region (both image and slice) + for start_idx, end_idx in zip(start_indices, end_indices): + # Add non-image tokens before this region + new_input_ids.extend( + input_ids[last_idx : start_idx + 1] + ) # include start token + + is_image_start = input_ids[start_idx] == im_start_id + + if is_image_start: + image_inputs.image_offsets += [start_idx] + image_idx += 1 + + num_tokens = end_idx - start_idx - 1 # exclude start and end tokens + + # Generate pad_ids + pad_values = [image_inputs.pad_values[image_idx]] + + pad_ids = pad_values * ((num_tokens + 
len(pad_values)) // len(pad_values)) + pad_ids = pad_ids[:num_tokens] + + # Add pad_ids + new_input_ids.extend(pad_ids) + + # Update last_idx to after end token + last_idx = end_idx + + # Add remaining tokens after last region + new_input_ids.extend(input_ids[last_idx:]) + assert len(input_ids) == len(new_input_ids) + return new_input_ids + + +_SUPPORT_VERSION = {(2, 6): MiniCPMV2_6} + + +class MiniCPMV: + """ + Different versions of MiniCPMV use different visual encoders and LLMs, + which is not conducive to the current integration logic of LoRA and + bitsandbytes in vLLM. Therefore, it is necessary to separate them. + """ + + # Ensure that the LoRA support check passes when the class is not + # initialized, but set all these attributes to empty. + packed_modules_mapping = {} + supported_lora_modules = [] + embedding_modules = {} + embedding_padding_modules = [] + + minicpmv: nn.Module + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + if not hasattr(config, "version"): + version = (2, 6) + else: + version = str(config.version).split(".") + version = tuple([int(x) for x in version]) + # Dispatch class based on version + instance_class = _SUPPORT_VERSION.get(version) + if instance_class is None: + raise ValueError("Currently, MiniCPMV only supports version 2.6") + + try: + minicpmv = instance_class(config=config, quant_config=quant_config) + self.minicpmv = minicpmv + except Exception as e: + print(f"Failed to instantiate MiniCPMV: {e}") + raise e + self.config = config + + def __getattr__(self, name): + if name == "minicpmv": + return None + return getattr(self.minicpmv, name) + + def __call__(self, *args, **kwargs): + return self.minicpmv(*args, **kwargs) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.minicpmv.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name or "projector" in name: + continue + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if name.startswith("model.vision_tower") and name not in params_dict: + continue + + # adapt to VisionAttention + name = name.replace(r"self_attn.out_proj", r"self_attn.proj") + + if "sampler" in name: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + # replace the name and load with customized loader + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + +EntryClass = MiniCPMV diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index 9dbdb46ff979..4ea734836afc 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -21,12 +21,11 @@ import torch from torch import nn from transformers import MixtralConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( QKVParallelLinear, @@ -38,6 +37,7 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index e5f49f5662fe..244dc7df2d06 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -23,13 +23,12 @@ import torch.nn.functional as F from torch import nn from transformers import MixtralConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( QKVParallelLinear, @@ -39,6 +38,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 019d21c20861..43f6793e4ef2 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -8,14 +8,14 @@ import torch.nn.functional as F import torch.utils.checkpoint import transformers.models.mllama.configuration_mllama as config_mllama -import vllm.distributed.parallel_state as ps from torch import nn from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast from transformers.models.mllama.modeling_mllama import ( _prepare_aspect_ratio_attention_mask, ) -from vllm.distributed import get_tensor_model_parallel_world_size +import sglang.srt.distributed.parallel_state as ps +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( diff --git a/python/sglang/srt/models/olmo.py b/python/sglang/srt/models/olmo.py index 1cfa27309fe2..4d8a79900f4c 100644 --- a/python/sglang/srt/models/olmo.py +++ b/python/sglang/srt/models/olmo.py @@ -15,14 +15,13 @@ # Adapted from # 
https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/olmo.py#L1 """Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, Optional, Tuple import torch from torch import nn from transformers import OlmoConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -32,6 +31,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py old mode 100755 new mode 100644 index 0944b5720925..f3e1979f8492 --- a/python/sglang/srt/models/olmo2.py +++ b/python/sglang/srt/models/olmo2.py @@ -21,15 +21,13 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, split_tensor_along_last_dim, tensor_model_parallel_all_gather, ) -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.model_loader.weight_utils import default_weight_loader - from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -40,11 +38,13 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.utils import make_layers diff --git a/python/sglang/srt/models/olmoe.py b/python/sglang/srt/models/olmoe.py index df96be3bc94f..10b781d72ffb 100644 --- a/python/sglang/srt/models/olmoe.py +++ b/python/sglang/srt/models/olmoe.py @@ -17,30 +17,24 @@ """Inference-only OLMoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, Optional, Tuple import torch -import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( - get_tensor_model_parallel_world_size, - tensor_model_parallel_all_reduce, -) -from vllm.model_executor.layers.linear import ( - MergedColumnParallelLinear, + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - -from sglang.srt.layers.activation import SiluAndMul -from sglang.srt.layers.layernorm import RMSNorm -from 
sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput +from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/phi3_small.py b/python/sglang/srt/models/phi3_small.py index 1e70c7d7874c..b7195dbaa28b 100644 --- a/python/sglang/srt/models/phi3_small.py +++ b/python/sglang/srt/models/phi3_small.py @@ -5,9 +5,8 @@ from torch import nn from transformers import Phi3Config from transformers.configuration_utils import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, @@ -17,6 +16,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 5492a3e12214..2c99da926b60 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -20,9 +20,8 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -33,6 +32,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 2a20d6c50de1..0c01ab9e5b4b 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -20,9 +20,11 @@ import torch from torch import nn -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -34,12 +36,16 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from 
sglang.srt.model_executor.forward_batch_info import ForwardBatch -from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.model_loader.weight_utils import ( + default_weight_loader, + kv_cache_scales_loader, +) from sglang.srt.utils import make_layers Qwen2Config = None @@ -242,6 +248,9 @@ def __init__( ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -265,9 +274,31 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.layers[layer_idx], nn.Identity): + layer_self_attn = self.layers[layer_idx].self_attn + if hasattr(layer_self_attn.attn, "k_scale"): + layer_self_attn.attn.k_scale = scaling_factor + layer_self_attn.attn.v_scale = scaling_factor + else: + raise RuntimeError( + "Self attention has no KV cache scaling " "factor attribute!" + ) -class Qwen2ForCausalLM(nn.Module): +class Qwen2ForCausalLM(nn.Module): # BitandBytes specific attributes default_bitsandbytes_target_modules = [ ".gate_proj.", @@ -305,6 +336,9 @@ def __init__( self.logits_processor = LogitsProcessor(config) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + @torch.no_grad() def forward( self, @@ -362,5 +396,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + self.model.load_kv_cache_scales(quantization_param_path) + EntryClass = Qwen2ForCausalLM diff --git a/python/sglang/srt/models/qwen2_eagle.py b/python/sglang/srt/models/qwen2_eagle.py new file mode 100644 index 000000000000..01069ef482cd --- /dev/null +++ b/python/sglang/srt/models/qwen2_eagle.py @@ -0,0 +1,131 @@ +""" +Copyright 2023-2024 SGLang Team +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +""" + +# Adapted from +# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py +"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights.""" + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn + +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM + +Qwen2Config = None + + +class Qwen2DecoderLayer(Qwen2DecoderLayer): + def __init__( + self, + config: Qwen2Config, + layer_id: int = 0, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, layer_id, quant_config) + + # Skip the input_layernorm + # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 + if layer_id == 0: + del self.input_layernorm + setattr(self, "input_layernorm", lambda x: x) + + +class Qwen2Model(nn.Module): + def __init__( + self, + config: Qwen2Config, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList( + [ + Qwen2DecoderLayer( + config, i, quant_config=quant_config, prefix=f"model.layers.{i}" + ) + for i in range(config.num_hidden_layers) + ] + ) + self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + hidden_states = self.fc( + torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1) + ) + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + return hidden_states + residual + + +class Qwen2ForCausalLMEagle(Qwen2ForCausalLM): + def __init__( + self, + config: Qwen2Config, + quant_config: Optional[QuantizationConfig] = None, + cache_config=None, + ) -> None: + nn.Module.__init__(self) + self.config = config + self.quant_config = quant_config + self.model = Qwen2Model(config, quant_config=quant_config) + if self.config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, config.hidden_size, quant_config=quant_config + ) + self.logits_processor = LogitsProcessor(config) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." 
+ name + super().load_weights([(name, loaded_weight)]) + + +EntryClass = [Qwen2ForCausalLMEagle] diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 9db2d538234d..6183f30daf43 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -22,12 +22,11 @@ import torch.nn.functional as F from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( @@ -40,6 +39,7 @@ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 2e9ec9d8f507..0fb85679f7af 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -22,6 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Qwen2-VL model compatible with HuggingFace weights.""" +import logging from functools import lru_cache, partial from typing import Iterable, List, Optional, Tuple, Type, TypedDict @@ -30,16 +31,13 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange, repeat -from vllm.distributed import parallel_state -from vllm.distributed import utils as dist_utils -from vllm.logger import init_logger from vllm.model_executor.layers.activation import QuickGELU from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig +from sglang.srt.distributed import parallel_state +from sglang.srt.distributed import utils as dist_utils from sglang.srt.hf_transformers_utils import get_processor -from sglang.srt.layers.attention.triton_ops.prefill_attention import ( - context_attention_fwd, -) +from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.pooler import Pooler, PoolingType @@ -50,7 +48,8 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader from sglang.srt.models.qwen2 import Qwen2Model -logger = init_logger(__name__) +logger = logging.getLogger(__name__) + # === Vision Inputs === # @@ -110,118 +109,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x -def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor: - if not interleaved: - x1, x2 = x.chunk(2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - else: - x1, x2 = x[..., ::2], x[..., 1::2] - return rearrange( - torch.stack((-x2, x1), dim=-1), "... d two -> ... 
(d two)", two=2 - ) - - -def apply_rotary_emb_torch( - x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False -) -> torch.Tensor: - """ - x: (batch_size, seqlen, nheads, headdim) - cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) - """ - ro_dim = cos.shape[-1] * 2 - assert ro_dim <= x.shape[-1] - cos = repeat( - cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - sin = repeat( - sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)" - ) - return torch.cat( - [ - x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, - x[..., ro_dim:], - ], - dim=-1, - ) - - -def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: - t_ = t.float() - cos = freqs.cos() - sin = freqs.sin() - output = apply_rotary_emb_torch(t_, cos, sin).type_as(t) - return output - - -class Qwen2VisionAttention(nn.Module): - - def __init__( - self, - embed_dim: Optional[int] = None, - num_heads: Optional[int] = None, - projection_size: Optional[int] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - super().__init__() - # Per attention head and per partition values. - world_size = parallel_state.get_tensor_model_parallel_world_size() - self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads - ) - self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size - ) - - self.qkv = ColumnParallelLinear( - input_size=embed_dim, - output_size=3 * projection_size, - quant_config=quant_config, - ) - self.proj = RowParallelLinear( - input_size=projection_size, output_size=embed_dim, quant_config=quant_config - ) - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - rotary_pos_emb: torch.Tensor = None, - ) -> torch.Tensor: - # [s, b, c] --> [s, b, head * 3 * head_dim] - x, _ = self.qkv(x) - - # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head, - ) - x = x.view(*new_x_shape) - - # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] - q, k, v = dist_utils.split_tensor_along_last_dim(x, 3) - batch_size = q.shape[1] - - q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)] - if rotary_pos_emb is not None: - q = apply_rotary_pos_emb_vision(q, rotary_pos_emb) - k = apply_rotary_pos_emb_vision(k, rotary_pos_emb) - - seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] - max_seqlen = (seq_lens).max().item() - q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]] - - output = torch.empty_like(q) - context_attention_fwd( - q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False - ) - - context_layer = rearrange(output, "(b s) ... 
-> b s ...", b=batch_size) - context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() - - output, _ = self.proj(context_layer) - return output - - class Qwen2VisionBlock(nn.Module): def __init__( @@ -240,10 +127,11 @@ def __init__( self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.attn = Qwen2VisionAttention( + self.attn = VisionAttention( embed_dim=dim, num_heads=num_heads, projection_size=dim, + use_qkv_parallel=False, quant_config=quant_config, ) self.mlp = Qwen2VisionMLP( @@ -253,9 +141,13 @@ def __init__( def forward( self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor ) -> torch.Tensor: - x = x + self.attn( - self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb + hidden_states = self.norm1(x) + hidden_states = rearrange(hidden_states, "s b ... -> b s ...") + attn = self.attn( + hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb ) + attn = rearrange(attn, "b s ... -> s b ...") + x = x + attn x = x + self.mlp(self.norm2(x)) return x @@ -684,10 +576,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue @@ -696,6 +590,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: + if "visual" in name and "qkv.weight" in name: visual_num_heads = self.config.vision_config.num_heads visual_embed_dim = self.config.vision_config.embed_dim @@ -712,6 +607,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_weight = loaded_weight.view(3, visual_num_heads, head_size) loaded_weight = loaded_weight.transpose(0, 1) loaded_weight = loaded_weight.reshape(-1) + + if "visual" in name: + # adapt to VisionAttention + name = name.replace(r"attn.qkv.", r"attn.qkv_proj.") + try: # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 079d54e3c83d..c169dd6fba42 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -24,9 +24,8 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.rotary_embedding import get_rope +from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/torch_native_llama.py b/python/sglang/srt/models/torch_native_llama.py index 7a55d50457a4..7b3e5bc5ddd5 100644 --- a/python/sglang/srt/models/torch_native_llama.py +++ b/python/sglang/srt/models/torch_native_llama.py @@ -47,17 +47,17 @@ from torch import nn from torch.nn.parameter import Parameter from transformers import LlamaConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -460,7 +460,12 @@ def get_num_params(self): params_dict = dict(self.named_parameters()) return len(params_dict) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + def load_weights_to_module( + self, + fqn: str, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + """Load weights onto submodule pointed by path `fqn`.""" stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), @@ -469,7 +474,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] - params_dict = dict(self.named_parameters()) + module = self.get_submodule(fqn) + params_dict = dict(module.named_parameters(prefix=fqn, recurse=False)) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name or "projector" in name: @@ -486,7 +492,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if name.endswith(".bias") or name not in params_dict: continue param = params_dict[name] weight_loader = param.weight_loader @@ -494,12 +500,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. 
- if name.endswith(".bias") and name not in params_dict: + if name.endswith(".bias") or name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + def load_weights( + self, + weights: Iterable[Tuple[str, torch.Tensor]], + ): + """Load weights onto the full model.""" + self.load_weights_to_module("", weights) + class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM): pass diff --git a/python/sglang/srt/models/xverse.py b/python/sglang/srt/models/xverse.py index e65514215190..7fd241823749 100644 --- a/python/sglang/srt/models/xverse.py +++ b/python/sglang/srt/models/xverse.py @@ -21,19 +21,19 @@ import torch from torch import nn from transformers import LlamaConfig -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/models/xverse_moe.py b/python/sglang/srt/models/xverse_moe.py index 9b4b27f07d26..218b96f9cb46 100644 --- a/python/sglang/srt/models/xverse_moe.py +++ b/python/sglang/srt/models/xverse_moe.py @@ -18,25 +18,25 @@ import torch from torch import nn from transformers import PretrainedConfig -from vllm.distributed import ( + +from sglang.srt.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce, ) -from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import ( +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) -from vllm.model_executor.layers.rotary_embedding import get_rope - from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import fused_moe from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 4fbe20846568..2ed9006c0ea2 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -180,6 +180,7 @@ class CompletionRequest(BaseModel): ignore_eos: bool = False skip_special_tokens: bool = True lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + session_params: Optional[Dict] = 
None class CompletionResponseChoice(BaseModel): @@ -322,6 +323,7 @@ class ChatCompletionRequest(BaseModel): ignore_eos: bool = False skip_special_tokens: bool = True lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None + session_params: Optional[Dict] = None class FunctionResponse(BaseModel): diff --git a/python/sglang/srt/sampling/custom_logit_processor.py b/python/sglang/srt/sampling/custom_logit_processor.py new file mode 100644 index 000000000000..a64b2498f239 --- /dev/null +++ b/python/sglang/srt/sampling/custom_logit_processor.py @@ -0,0 +1,38 @@ +import json +from abc import ABC, abstractmethod +from functools import lru_cache +from typing import Any, Dict, List, Optional + +import dill +import torch + + +@lru_cache(maxsize=None) +def _cache_from_str(json_str: str): + """Deserialize a json string to a Callable object. + This function is cached to avoid redundant deserialization. + """ + data = json.loads(json_str) + return dill.loads(bytes.fromhex(data["callable"])) + + +class CustomLogitProcessor(ABC): + """Abstract base class for callable functions.""" + + @abstractmethod + def __call__( + self, + logits: torch.Tensor, + custom_param_list: Optional[List[Dict[str, Any]]] = None, + ) -> torch.Tensor: + """Define the callable behavior.""" + raise NotImplementedError + + def to_str(self) -> str: + """Serialize the callable function to a JSON-compatible string.""" + return json.dumps({"callable": dill.dumps(self).hex()}) + + @classmethod + def from_str(cls, json_str: str): + """Deserialize a callable function from a JSON string.""" + return _cache_from_str(json_str) diff --git a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py index 4c293b89520d..fe687c569d4c 100644 --- a/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +++ b/python/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py @@ -3,6 +3,16 @@ import torch from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs +from sglang.srt.utils import get_compiler_backend + + +@torch.compile(dynamic=True, backend=get_compiler_backend()) +def apply_scaling_penalties(logits, scaling_penalties): + logits[:] = torch.where( + logits > 0, + logits / scaling_penalties, + logits * scaling_penalties, + ) class BatchedRepetitionPenalizer(_BatchedPenalizer): @@ -56,11 +66,8 @@ def _cumulate_output_tokens(self, output_ids: _TokenIDs): self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask] def _apply(self, logits: torch.Tensor) -> torch.Tensor: - return torch.where( - logits > 0, - logits / self.cumulated_repetition_penalties, - logits * self.cumulated_repetition_penalties, - ) + apply_scaling_penalties(logits, self.cumulated_repetition_penalties) + return logits def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor): self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep] diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 9497e53d3092..9521a34f4f6f 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -3,11 +3,15 @@ import dataclasses import logging import threading -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple import torch import sglang.srt.sampling.penaltylib as penaltylib +from 
sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor +from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import ( + apply_scaling_penalties, +) logger = logging.getLogger(__name__) @@ -30,6 +34,9 @@ class SamplingBatchInfo: # Dispatch in CUDA graph need_min_p_sampling: bool + # Whether any request has custom logit processor + has_custom_logit_processor: bool + # Bias Tensors vocab_size: int grammars: Optional[List] = None @@ -46,6 +53,14 @@ class SamplingBatchInfo: # Device device: str = "cuda" + # Custom Parameters + custom_params: Optional[List[Optional[Dict[str, Any]]]] = None + + # Custom Logit Processor + custom_logit_processor: Optional[ + Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]] + ] = None + @classmethod def from_schedule_batch( cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool @@ -70,6 +85,39 @@ def from_schedule_batch( [r.sampling_params.min_p for r in reqs], dtype=torch.float ).to(device, non_blocking=True) + # Check if any request has custom logit processor + has_custom_logit_processor = ( + batch.enable_custom_logit_processor # check the flag first. + and any(r.custom_logit_processor for r in reqs) # then check the requests. + ) + + if has_custom_logit_processor: + # Merge the same type of custom logit processors together + processor_dict = {} + for i, r in enumerate(reqs): + if r.custom_logit_processor is None: + continue + processor_str = r.custom_logit_processor + if processor_str not in processor_dict: + processor_dict[processor_str] = [] + processor_dict[processor_str].append(i) + + merged_custom_logit_processor = { + hash(processor_str): ( + # The deserialized custom logit processor object + CustomLogitProcessor.from_str(processor_str), + # The mask tensor for the requests that use this custom logit processor + torch.zeros(len(reqs), dtype=torch.bool) + .scatter_(0, torch.tensor(true_indices), True) + .to(device, non_blocking=True), + ) + for processor_str, true_indices in processor_dict.items() + } + custom_params = [r.sampling_params.custom_params for r in reqs] + else: + merged_custom_logit_processor = None + custom_params = None + ret = cls( temperatures=temperatures, top_ps=top_ps, @@ -77,8 +125,11 @@ def from_schedule_batch( min_ps=min_ps, need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs), is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs), + has_custom_logit_processor=has_custom_logit_processor, vocab_size=vocab_size, device=device, + custom_params=custom_params, + custom_logit_processor=merged_custom_logit_processor, ) # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge. 
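To illustrate the custom logit processor plumbing introduced in this diff: a processor is defined on the client side, dill-serialized via to_str(), shipped as a string, and rebuilt with from_str() before being applied to the batch logits together with each request's custom_params. A minimal sketch, assuming sglang and dill are importable; the "banned_token_id" parameter name is invented for the example:

import torch
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

class BanTokenProcessor(CustomLogitProcessor):
    """Mask one token id per request, taken from that request's custom_params."""

    def __call__(self, logits, custom_param_list=None):
        for row, params in enumerate(custom_param_list or []):
            if params and "banned_token_id" in params:
                logits[row, params["banned_token_id"]] = float("-inf")
        return logits

# Serialize for the request payload, then rebuild and apply on the server side.
payload = BanTokenProcessor().to_str()
processor = CustomLogitProcessor.from_str(payload)

logits = torch.zeros(2, 8)
out = processor(logits, custom_param_list=[{"banned_token_id": 3}, None])
assert out[0, 3] == float("-inf") and out[1, 3] == 0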
@@ -178,6 +229,8 @@ def update_regex_vocab_mask(self): def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor): self.penalizer_orchestrator.filter(unfinished_indices, new_indices) + if self.has_custom_logit_processor: + self._filter_batch_custom_logit_processor(unfinished_indices, new_indices) for item in [ "temperatures", @@ -190,6 +243,27 @@ def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor) if value is not None: # logit_bias can be None setattr(self, item, value[new_indices]) + def _filter_batch_custom_logit_processor( + self, unfinished_indices: List[int], new_indices: torch.Tensor + ): + """Filter the custom logit processor and custom params""" + + self.custom_logit_processor = { + k: (p, mask[new_indices]) + for k, (p, mask) in self.custom_logit_processor.items() + if any( + mask[new_indices] + ) # ignore the custom logit processor whose mask is all False + } + self.custom_params = [self.custom_params[i] for i in unfinished_indices] + + # If the custom logit processor is an empty dict, set the flag to False, + # and set the custom logit processor and custom params to None. + if len(self.custom_logit_processor) == 0: + self.custom_logit_processor = None + self.custom_params = None + self.has_custom_logit_processor = False + @staticmethod def merge_bias_tensor( lhs: torch.Tensor, @@ -215,9 +289,76 @@ def merge_bias_tensor( return None + @staticmethod + def merge_custom_logit_processor( + lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]], + rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]], + bs1: int, + bs2: int, + device: str, + ): + if lhs is None and rhs is None: + return None + lhs, rhs = lhs or {}, rhs or {} + + keys = set(lhs.keys()).union(set(rhs.keys())) + merged_dict = {} + + for k in keys: + # Get the logit processor object + processor = lhs[k][0] if k in lhs else rhs[k][0] + # Get and merge the mask tensors from the two dicts + left_mask = ( + lhs[k][1] + if k in lhs + else torch.zeros(bs1, dtype=torch.bool, device=device) + ) + right_mask = ( + rhs[k][1] + if k in rhs + else torch.zeros(bs2, dtype=torch.bool, device=device) + ) + merged_dict[k] = (processor, torch.cat([left_mask, right_mask])) + + assert merged_dict[k][1].shape[0] == bs1 + bs2, ( + f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match " + f"the sum of the batch sizes of the two masks ({bs1 + bs2})" + f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}" + f"\n{lhs=}\n{rhs=}" + ) + + return merged_dict + def merge_batch(self, other: "SamplingBatchInfo"): self.penalizer_orchestrator.merge(other.penalizer_orchestrator) + # Merge the logit bias tensor + self.logit_bias = SamplingBatchInfo.merge_bias_tensor( + self.logit_bias, other.logit_bias, len(self), len(other), self.device + ) + # Merge the custom logit processors and custom params lists + if self.has_custom_logit_processor or other.has_custom_logit_processor: + # Merge the custom logit processors + self.custom_logit_processor = ( + SamplingBatchInfo.merge_custom_logit_processor( + self.custom_logit_processor, + other.custom_logit_processor, + len(self), + len(other), + self.device, + ) + ) + # Merge the custom params lists + self.custom_params = self.custom_params or [None] * len(self) + other.custom_params = other.custom_params or [None] * len(other) + self.custom_params.extend(other.custom_params) + + # Set the flag to True if any of the two has custom logit processor + self.has_custom_logit_processor = True + + # Note: because the __len__()
operator is defined on the temperatures tensor, + # please make sure any merge operation with len(self) or len(other) is done before + # the merge operation of the temperatures tensor below. for item in [ "temperatures", "top_ps", @@ -229,9 +370,6 @@ def merge_batch(self, other: "SamplingBatchInfo"): setattr(self, item, torch.concat([self_val, other_val])) self.is_all_greedy = self.is_all_greedy and other.is_all_greedy - self.logit_bias = SamplingBatchInfo.merge_bias_tensor( - self.logit_bias, other.logit_bias, len(self), len(other), self.device - ) self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling def apply_logits_bias(self, logits: torch.Tensor): @@ -245,11 +383,7 @@ def apply_logits_bias(self, logits: torch.Tensor): # repetition if self.scaling_penalties is not None: - logits[:] = torch.where( - logits > 0, - logits / self.scaling_penalties, - logits * self.scaling_penalties, - ) + apply_scaling_penalties(logits, self.scaling_penalties) # Apply regex vocab_mask if self.vocab_mask is not None: diff --git a/python/sglang/srt/sampling/sampling_params.py b/python/sglang/srt/sampling/sampling_params.py index 2c3817e1b795..2224fb0919a1 100644 --- a/python/sglang/srt/sampling/sampling_params.py +++ b/python/sglang/srt/sampling/sampling_params.py @@ -13,7 +13,7 @@ # ============================================================================== """Sampling parameters for text generation.""" -from typing import List, Optional, Union +from typing import Any, Dict, List, Optional, Union _SAMPLING_EPS = 1e-6 @@ -23,7 +23,7 @@ class SamplingParams: The sampling parameters. See docs/references/sampling_params.md or - https://sgl-project.github.io/references/sampling_params.html + https://docs.sglang.ai/references/sampling_params.html for the documentation. """ @@ -48,6 +48,7 @@ def __init__( no_stop_trim: bool = False, ignore_eos: bool = False, skip_special_tokens: bool = True, + custom_params: Optional[Dict[str, Any]] = None, ) -> None: self.temperature = temperature self.top_p = top_p @@ -71,6 +72,7 @@ def __init__( self.json_schema = json_schema self.ebnf = ebnf self.no_stop_trim = no_stop_trim + self.custom_params = custom_params # Process some special cases if self.temperature < _SAMPLING_EPS: diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index f60af5d73153..869a984d0cf9 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -11,1036 +11,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -""" -The entry point of inference server. -SRT = SGLang Runtime. 
-""" -import asyncio -import atexit -import dataclasses -import json -import logging -import multiprocessing as mp -import os -import signal -import threading -import time -from http import HTTPStatus -from typing import AsyncIterator, Dict, List, Optional, Tuple, Union - -import torch - -# Fix a bug of Python threading -setattr(threading, "_register_atexit", lambda *args, **kwargs: None) - -import aiohttp -import orjson -import requests -import uvicorn -import uvloop -from fastapi import FastAPI, File, Form, Request, UploadFile -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import ORJSONResponse, Response, StreamingResponse -from uvicorn.config import LOGGING_CONFIG - -from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint -from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.managers.data_parallel_controller import ( - run_data_parallel_controller_process, -) -from sglang.srt.managers.detokenizer_manager import run_detokenizer_process -from sglang.srt.managers.io_struct import ( - CloseSessionReqInput, - EmbeddingReqInput, - GenerateReqInput, - GetWeightsByNameReqInput, - InitWeightsUpdateGroupReqInput, - OpenSessionReqInput, - UpdateWeightFromDiskReqInput, - UpdateWeightsFromDistributedReqInput, - UpdateWeightsFromTensorReqInput, -) -from sglang.srt.managers.scheduler import run_scheduler_process -from sglang.srt.managers.tokenizer_manager import TokenizerManager -from sglang.srt.metrics.func_timer import enable_func_timer, time_func_latency -from sglang.srt.openai_api.adapter import ( - load_chat_template_for_openai_api, - v1_batches, - v1_cancel_batch, - v1_chat_completions, - v1_completions, - v1_delete_file, - v1_embeddings, - v1_files_create, - v1_retrieve_batch, - v1_retrieve_file, - v1_retrieve_file_content, -) -from sglang.srt.openai_api.protocol import ModelCard, ModelList -from sglang.srt.server_args import PortArgs, ServerArgs -from sglang.srt.utils import ( - MultiprocessingSerializer, - add_api_key_middleware, - add_prometheus_middleware, - assert_pkg_version, - configure_logger, - delete_directory, - is_port_available, - kill_process_tree, - maybe_set_triton_cache_manager, - prepare_model_and_tokenizer, - set_prometheus_multiproc_dir, - set_ulimit, -) -from sglang.utils import get_exception_traceback -from sglang.version import __version__ - -logger = logging.getLogger(__name__) - -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) - -# Fast API -app = FastAPI() -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) - -tokenizer_manager: TokenizerManager = None -scheduler_info: Dict = None - - -##### Native API endpoints ##### - - -@app.get("/health") -async def health() -> Response: - """Check the health of the http server.""" - return Response(status_code=200) - - -@app.get("/health_generate") -async def health_generate(request: Request) -> Response: - """Check the health of the inference server by generating one token.""" - - sampling_params = {"max_new_tokens": 1, "temperature": 0.7} - - if tokenizer_manager.is_generation: - gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params) - else: - gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params) - - try: - async for _ in tokenizer_manager.generate_request(gri, request): - break - return Response(status_code=200) - except Exception as e: - logger.exception(e) - return Response(status_code=503) - - -@app.get("/get_model_info") -async def 
get_model_info(): - """Get the model information.""" - result = { - "model_path": tokenizer_manager.model_path, - "tokenizer_path": tokenizer_manager.server_args.tokenizer_path, - "is_generation": tokenizer_manager.is_generation, - } - return result - - -@app.get("/get_server_info") -async def get_server_info(): - return { - **dataclasses.asdict(tokenizer_manager.server_args), # server args - **scheduler_info, - "version": __version__, - } - - -@app.post("/flush_cache") -async def flush_cache(): - """Flush the radix cache.""" - tokenizer_manager.flush_cache() - return Response( - content="Cache flushed.\nPlease check backend logs for more details. " - "(When there are running or waiting requests, the operation will not be performed.)\n", - status_code=200, - ) - - -@app.get("/start_profile") -@app.post("/start_profile") -async def start_profile_async(): - """Start profiling.""" - tokenizer_manager.start_profile() - return Response( - content="Start profiling.\n", - status_code=200, - ) - - -@app.get("/stop_profile") -@app.post("/stop_profile") -async def stop_profile_async(): - """Stop profiling.""" - tokenizer_manager.stop_profile() - return Response( - content="Stop profiling. This will take some time.\n", - status_code=200, - ) - - -@app.post("/update_weights_from_disk") -@time_func_latency -async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request): - """Update the weights from disk in-place without re-launching the server.""" - success, message = await tokenizer_manager.update_weights_from_disk(obj, request) - content = {"success": success, "message": message} - if success: - return ORJSONResponse( - content, - status_code=HTTPStatus.OK, - ) - else: - return ORJSONResponse( - content, - status_code=HTTPStatus.BAD_REQUEST, - ) - - -@app.post("/init_weights_update_group") -async def init_weights_update_group( - obj: InitWeightsUpdateGroupReqInput, request: Request -): - """Initialize the parameter update group.""" - success, message = await tokenizer_manager.init_weights_update_group(obj, request) - content = {"success": success, "message": message} - if success: - return ORJSONResponse(content, status_code=200) - else: - return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) - - -@app.post("/update_weights_from_distributed") -async def update_weights_from_distributed( - obj: UpdateWeightsFromDistributedReqInput, request: Request -): - """Update model parameter from distributed online.""" - success, message = await tokenizer_manager.update_weights_from_distributed( - obj, request - ) - content = {"success": success, "message": message} - if success: - return ORJSONResponse(content, status_code=200) - else: - return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) - - -@app.api_route("/get_weights_by_name", methods=["GET", "POST"]) -async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request): - """Get model parameter by name.""" - try: - ret = await tokenizer_manager.get_weights_by_name(obj, request) - if ret is None: - return _create_error_response("Get parameter by name failed") - else: - return ORJSONResponse(ret, status_code=200) - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/open_session", methods=["GET", "POST"]) -async def open_session(obj: OpenSessionReqInput, request: Request): - """Open a session, and return its unique session id.""" - try: - session_id = await tokenizer_manager.open_session(obj, request) - if session_id is None: - raise Exception( - "Failed to open the 
session. Check if a session with the same id is still open." - ) - return session_id - except Exception as e: - return _create_error_response(e) - - -@app.api_route("/close_session", methods=["GET", "POST"]) -async def close_session(obj: CloseSessionReqInput, request: Request): - """Close the session""" - try: - await tokenizer_manager.close_session(obj, request) - return Response(status_code=200) - except Exception as e: - return _create_error_response(e) - - -# fastapi implicitly converts json in the request to obj (dataclass) -@app.api_route("/generate", methods=["POST", "PUT"]) -@time_func_latency -async def generate_request(obj: GenerateReqInput, request: Request): - """Handle a generate request.""" - if obj.stream: - - async def stream_results() -> AsyncIterator[bytes]: - try: - async for out in tokenizer_manager.generate_request(obj, request): - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - except ValueError as e: - out = {"error": {"message": str(e)}} - yield b"data: " + orjson.dumps( - out, option=orjson.OPT_NON_STR_KEYS - ) + b"\n\n" - yield b"data: [DONE]\n\n" - - return StreamingResponse( - stream_results(), - media_type="text/event-stream", - background=tokenizer_manager.create_abort_task(obj), - ) - else: - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - logger.error(f"Error: {e}") - return _create_error_response(e) - - -@app.api_route("/encode", methods=["POST", "PUT"]) -@time_func_latency -async def encode_request(obj: EmbeddingReqInput, request: Request): - """Handle an embedding request.""" - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return _create_error_response(e) - - -@app.api_route("/classify", methods=["POST", "PUT"]) -@time_func_latency -async def classify_request(obj: EmbeddingReqInput, request: Request): - """Handle a reward model request. 
Now the arguments and return values are the same as embedding models.""" - try: - ret = await tokenizer_manager.generate_request(obj, request).__anext__() - return ret - except ValueError as e: - return _create_error_response(e) - - -##### OpenAI-compatible API endpoints ##### - - -@app.post("/v1/completions") -@time_func_latency -async def openai_v1_completions(raw_request: Request): - return await v1_completions(tokenizer_manager, raw_request) - - -@app.post("/v1/chat/completions") -@time_func_latency -async def openai_v1_chat_completions(raw_request: Request): - return await v1_chat_completions(tokenizer_manager, raw_request) - - -@app.post("/v1/embeddings", response_class=ORJSONResponse) -@time_func_latency -async def openai_v1_embeddings(raw_request: Request): - response = await v1_embeddings(tokenizer_manager, raw_request) - return response - - -@app.get("/v1/models", response_class=ORJSONResponse) -def available_models(): - """Show available models.""" - served_model_names = [tokenizer_manager.served_model_name] - model_cards = [] - for served_model_name in served_model_names: - model_cards.append(ModelCard(id=served_model_name, root=served_model_name)) - return ModelList(data=model_cards) - - -@app.post("/v1/files") -async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")): - return await v1_files_create( - file, purpose, tokenizer_manager.server_args.file_storage_pth - ) - - -@app.delete("/v1/files/{file_id}") -async def delete_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/delete - return await v1_delete_file(file_id) - - -@app.post("/v1/batches") -async def openai_v1_batches(raw_request: Request): - return await v1_batches(tokenizer_manager, raw_request) - - -@app.post("/v1/batches/{batch_id}/cancel") -async def cancel_batches(batch_id: str): - # https://platform.openai.com/docs/api-reference/batch/cancel - return await v1_cancel_batch(tokenizer_manager, batch_id) - - -@app.get("/v1/batches/{batch_id}") -async def retrieve_batch(batch_id: str): - return await v1_retrieve_batch(batch_id) - - -@app.get("/v1/files/{file_id}") -async def retrieve_file(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve - return await v1_retrieve_file(file_id) - - -@app.get("/v1/files/{file_id}/content") -async def retrieve_file_content(file_id: str): - # https://platform.openai.com/docs/api-reference/files/retrieve-contents - return await v1_retrieve_file_content(file_id) - - -def _create_error_response(e): - return ORJSONResponse( - {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST - ) - - -def launch_engine( - server_args: ServerArgs, -): - """ - Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess. - """ - - global tokenizer_manager - global scheduler_info - - # Configure global environment - configure_logger(server_args) - server_args.check_server_args() - _set_envs_and_config(server_args) - - # Allocate ports for inter-process communications - port_args = PortArgs.init_new(server_args) - logger.info(f"{server_args=}") - - # If using model from www.modelscope.cn, first download the model. 
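As a quick worked example of the tensor-parallel launch arithmetic in the server.py block being removed here (visible just below): ranks are partitioned evenly across nodes, and each rank is pinned to a local GPU offset. The numbers are made up for illustration:

# Sketch of the tp_rank -> gpu_id mapping used by the scheduler-launch loop below.
tp_size, nnodes, node_rank, base_gpu_id = 8, 2, 1, 0

tp_size_per_node = tp_size // nnodes                 # 4 ranks per node
tp_rank_range = range(tp_size_per_node * node_rank,  # ranks 4..7 live on node 1
                      tp_size_per_node * (node_rank + 1))

for tp_rank in tp_rank_range:
    gpu_id = base_gpu_id + tp_rank % tp_size_per_node  # local GPUs 0..3
    print(f"tp_rank={tp_rank} -> gpu_id={gpu_id}")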
- server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer( - server_args.model_path, server_args.tokenizer_path - ) - - if server_args.dp_size == 1: - # Launch tensor parallel scheduler processes - scheduler_procs = [] - scheduler_pipe_readers = [] - tp_size_per_node = server_args.tp_size // server_args.nnodes - tp_rank_range = range( - tp_size_per_node * server_args.node_rank, - tp_size_per_node * (server_args.node_rank + 1), - ) - for tp_rank in tp_rank_range: - reader, writer = mp.Pipe(duplex=False) - gpu_id = server_args.base_gpu_id + tp_rank % tp_size_per_node - proc = mp.Process( - target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, None, writer), - ) - proc.start() - scheduler_procs.append(proc) - scheduler_pipe_readers.append(reader) - - if server_args.node_rank >= 1: - # For other nodes, they do not need to run tokenizer or detokenizer, - # so they can just wait here. - for proc in scheduler_procs: - proc.join() - else: - # Launch the data parallel controller - reader, writer = mp.Pipe(duplex=False) - scheduler_pipe_readers = [reader] - proc = mp.Process( - target=run_data_parallel_controller_process, - args=(server_args, port_args, writer), - ) - proc.start() - - # Launch detokenizer process - detoken_proc = mp.Process( - target=run_detokenizer_process, - args=( - server_args, - port_args, - ), - ) - detoken_proc.start() - - # Launch tokenizer process - tokenizer_manager = TokenizerManager(server_args, port_args) - if server_args.chat_template: - load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) - - # Wait for model to finish loading - scheduler_infos = [] - for i in range(len(scheduler_pipe_readers)): - try: - data = scheduler_pipe_readers[i].recv() - except EOFError as e: - logger.exception(e) - logger.error( - f"Rank {i} scheduler is dead. Please check if there are relevant logs." - ) - scheduler_procs[i].join() - logger.error(f"Exit code: {scheduler_procs[i].exitcode}") - raise - - if data["status"] != "ready": - raise RuntimeError( - "Initialization failed. Please see the error messages above." - ) - scheduler_infos.append(data) - - # Assume all schedulers have same scheduler_info - scheduler_info = scheduler_infos[0] - - -def launch_server( - server_args: ServerArgs, - pipe_finish_writer: Optional[mp.connection.Connection] = None, -): - """ - Launch SRT (SGLang Runtime) Server - - The SRT server consists of an HTTP server and the SRT engine. - - 1. HTTP server: A FastAPI server that routes requests to the engine. - 2. SRT engine: - 1. TokenizerManager: Tokenizes the requests and sends them to the scheduler. - 2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager. - 3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager. - - Note: - 1. The HTTP server and TokenizerManager both run in the main process. - 2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library. 
- """ - launch_engine(server_args=server_args) - - # Add api key authorization - if server_args.api_key: - add_api_key_middleware(app, server_args.api_key) - - # Add prometheus middleware - if server_args.enable_metrics: - add_prometheus_middleware(app) - enable_func_timer() - - # Send a warmup request - t = threading.Thread( - target=_wait_and_warmup, args=(server_args, pipe_finish_writer) - ) - t.start() - - try: - # Update logging configs - LOGGING_CONFIG["formatters"]["default"][ - "fmt" - ] = "[%(asctime)s] %(levelprefix)s %(message)s" - LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S" - LOGGING_CONFIG["formatters"]["access"][ - "fmt" - ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s' - LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S" - - # Listen for HTTP requests - uvicorn.run( - app, - host=server_args.host, - port=server_args.port, - log_level=server_args.log_level_http or server_args.log_level, - timeout_keep_alive=5, - loop="uvloop", - ) - finally: - t.join() - - -def _set_envs_and_config(server_args: ServerArgs): - # Set global environments - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" - os.environ["NCCL_CUMEM_ENABLE"] = "0" - os.environ["NCCL_NVLS_ENABLE"] = "0" - os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" - os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4" - - # Set prometheus env vars - if server_args.enable_metrics: - set_prometheus_multiproc_dir() - - # Set ulimit - set_ulimit() - - # Fix triton bugs - if server_args.tp_size * server_args.dp_size > 1: - # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency. - maybe_set_triton_cache_manager() - - # Check flashinfer version - if server_args.attention_backend == "flashinfer": - assert_pkg_version( - "flashinfer", - "0.1.6", - "Please uninstall the old version and " - "reinstall the latest version by following the instructions " - "at https://docs.flashinfer.ai/installation.html.", - ) - - # Register the signal handler. - # The child processes will send SIGQUIT to this process when any error happens - # This process then clean up the whole process tree - def sigquit_handler(signum, frame): - kill_process_tree(os.getpid()) - - signal.signal(signal.SIGQUIT, sigquit_handler) - - # Set mp start method - mp.set_start_method("spawn", force=True) - - -def _wait_and_warmup(server_args, pipe_finish_writer): - headers = {} - url = server_args.url() - if server_args.api_key: - headers["Authorization"] = f"Bearer {server_args.api_key}" - - # Wait until the server is launched - success = False - for _ in range(120): - time.sleep(1) - try: - res = requests.get(url + "/get_model_info", timeout=5, headers=headers) - assert res.status_code == 200, f"{res=}, {res.text=}" - success = True - break - except (AssertionError, requests.exceptions.RequestException): - last_traceback = get_exception_traceback() - pass - - if not success: - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - logger.error(f"Initialization failed. 
warmup error: {last_traceback}") - kill_process_tree(os.getpid()) - return - - model_info = res.json() - - # Send a warmup request - request_name = "/generate" if model_info["is_generation"] else "/encode" - max_new_tokens = 8 if model_info["is_generation"] else 1 - json_data = { - "sampling_params": { - "temperature": 0, - "max_new_tokens": max_new_tokens, - }, - } - if server_args.skip_tokenizer_init: - json_data["input_ids"] = [10, 11, 12] - else: - json_data["text"] = "The capital city of France is" - - try: - for _ in range(server_args.dp_size): - res = requests.post( - url + request_name, - json=json_data, - headers=headers, - timeout=600, - ) - assert res.status_code == 200, f"{res}" - except Exception: - last_traceback = get_exception_traceback() - if pipe_finish_writer is not None: - pipe_finish_writer.send(last_traceback) - logger.error(f"Initialization failed. warmup error: {last_traceback}") - kill_process_tree(os.getpid()) - return - - # Debug print - # logger.info(f"{res.json()=}") - - logger.info("The server is fired up and ready to roll!") - if pipe_finish_writer is not None: - pipe_finish_writer.send("ready") - - if server_args.delete_ckpt_after_loading: - delete_directory(server_args.model_path) - - -STREAM_END_SYMBOL = b"data: [DONE]" -STREAM_CHUNK_START_SYMBOL = b"data:" - - -class Engine: - """ - SRT Engine without an HTTP server layer. - - This class provides a direct inference engine without the need for an HTTP server. It is designed for use cases where - launching the HTTP server adds unnecessary complexity or overhead, - """ - - def __init__(self, log_level: str = "error", *args, **kwargs): - """See the arguments in server_args.py::ServerArgs""" - - # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() - atexit.register(self.shutdown) - - server_args = ServerArgs(*args, log_level=log_level, **kwargs) - launch_engine(server_args=server_args) - - def generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Union[List[Dict], Dict]] = None, - # The token ids for text; one can either specify text or input_ids. 
- input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - stream: bool = False, - ): - obj = GenerateReqInput( - text=prompt, - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=return_logprob, - logprob_start_len=logprob_start_len, - top_logprobs_num=top_logprobs_num, - lora_path=lora_path, - stream=stream, - ) - - # get the current event loop - loop = asyncio.get_event_loop() - ret = loop.run_until_complete(generate_request(obj, None)) - - if stream is True: - - def generator_wrapper(): - offset = 0 - loop = asyncio.get_event_loop() - generator = ret.body_iterator - while True: - chunk = loop.run_until_complete(generator.__anext__()) - - if chunk.startswith(STREAM_END_SYMBOL): - break - else: - data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :]) - data["text"] = data["text"][offset:] - offset += len(data["text"]) - yield data - - # we cannot yield in the scope of generate() because python does not allow yield + return in the same function - # however, it allows to wrap the generator as a subfunction and return - return generator_wrapper() - else: - return ret - - async def async_generate( - self, - # The input prompt. It can be a single prompt or a batch of prompts. - prompt: Optional[Union[List[str], str]] = None, - sampling_params: Optional[Dict] = None, - # The token ids for text; one can either specify text or input_ids. - input_ids: Optional[Union[List[List[int]], List[int]]] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - stream: bool = False, - ): - obj = GenerateReqInput( - text=prompt, - input_ids=input_ids, - sampling_params=sampling_params, - return_logprob=return_logprob, - logprob_start_len=logprob_start_len, - top_logprobs_num=top_logprobs_num, - lora_path=lora_path, - stream=stream, - ) - - ret = await generate_request(obj, None) - - if stream is True: - generator = ret.body_iterator - - async def generator_wrapper(): - - offset = 0 - - while True: - chunk = await generator.__anext__() - - if chunk.startswith(STREAM_END_SYMBOL): - break - else: - data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :]) - data["text"] = data["text"][offset:] - offset += len(data["text"]) - yield data - - return generator_wrapper() - else: - return ret - - def shutdown(self): - kill_process_tree(os.getpid(), include_parent=False) - - def get_tokenizer(self): - global tokenizer_manager - - if tokenizer_manager is None: - raise ReferenceError("Tokenizer Manager is not initialized.") - else: - return tokenizer_manager.tokenizer - - def encode( - self, - prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - ): - obj = EmbeddingReqInput(text=prompt) - - # get the current event loop - loop = asyncio.get_event_loop() - return loop.run_until_complete(encode_request(obj, None)) - - def start_profile(self): - tokenizer_manager.start_profile() - - def stop_profile(self): - tokenizer_manager.stop_profile() - - def get_server_info(self): - return { - **dataclasses.asdict(tokenizer_manager.server_args), # server args - **scheduler_info, - "version": __version__, - } - - def init_weights_update_group( - self, - master_address: str, - 
master_port: int, - rank_offset: int, - world_size: int, - group_name: str, - backend: str = "nccl", - ): - """Initialize parameter update group.""" - obj = InitWeightsUpdateGroupReqInput( - master_address=master_address, - master_port=master_port, - rank_offset=rank_offset, - world_size=world_size, - group_name=group_name, - backend=backend, - ) - loop = asyncio.get_event_loop() - return loop.run_until_complete( - tokenizer_manager.init_weights_update_group(obj, None) - ) - - def update_weights_from_distributed(self, name, dtype, shape): - """Update weights from distributed source.""" - obj = UpdateWeightsFromDistributedReqInput( - name=name, - dtype=dtype, - shape=shape, - ) - loop = asyncio.get_event_loop() - return loop.run_until_complete( - tokenizer_manager.update_weights_from_distributed(obj, None) - ) - - def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]): - """Update weights from distributed source.""" - obj = UpdateWeightsFromTensorReqInput( - serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors) - ) - loop = asyncio.get_event_loop() - return loop.run_until_complete( - tokenizer_manager.update_weights_from_tensor(obj, None) - ) - - def get_weights_by_name(self, name, truncate_size=100): - """Get weights by parameter name.""" - obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size) - loop = asyncio.get_event_loop() - return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None)) - - -class Runtime: - """ - A wrapper for the HTTP server. - This is used for launching the server in a python program without - using the commond line interface. - - It is mainly used for the frontend language. - You should use the Engine class above if you want to do normal offline processing. - """ - - def __init__( - self, - log_level: str = "error", - *args, - **kwargs, - ): - """See the arguments in server_args.py::ServerArgs""" - self.server_args = ServerArgs(*args, log_level=log_level, **kwargs) - - # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown() - atexit.register(self.shutdown) - - # Pre-allocate ports - for port in range(self.server_args.port, 40000): - if is_port_available(port): - break - self.server_args.port = port - - self.url = self.server_args.url() - self.generate_url = self.url + "/generate" - - # NOTE: We store pid instead of proc to fix some issues during __delete__ - self.pid = None - pipe_reader, pipe_writer = mp.Pipe(duplex=False) - - proc = mp.Process( - target=launch_server, - args=(self.server_args, pipe_writer), - ) - proc.start() - pipe_writer.close() - self.pid = proc.pid - - try: - init_state = pipe_reader.recv() - except EOFError: - init_state = "" - - if init_state != "ready": - self.shutdown() - raise RuntimeError( - "Initialization failed. Please see the error messages above." 
- ) - - self.endpoint = RuntimeEndpoint(self.url) - - def shutdown(self): - if self.pid is not None: - kill_process_tree(self.pid) - self.pid = None - - def cache_prefix(self, prefix: str): - self.endpoint.cache_prefix(prefix) - - def get_tokenizer(self): - return get_tokenizer( - self.server_args.tokenizer_path, - tokenizer_mode=self.server_args.tokenizer_mode, - trust_remote_code=self.server_args.trust_remote_code, - ) - - async def async_generate( - self, - prompt: str, - sampling_params: Optional[Dict] = None, - ): - if self.server_args.skip_tokenizer_init: - json_data = { - "input_ids": prompt, - "sampling_params": sampling_params, - "stream": True, - } - else: - json_data = { - "text": prompt, - "sampling_params": sampling_params, - "stream": True, - } - pos = 0 - - timeout = aiohttp.ClientTimeout(total=3 * 3600) - async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session: - async with session.post(self.generate_url, json=json_data) as response: - async for chunk, _ in response.content.iter_chunks(): - chunk = chunk.decode("utf-8") - if chunk and chunk.startswith("data:"): - if chunk == "data: [DONE]\n\n": - break - data = json.loads(chunk[5:].strip("\n")) - if "text" in data: - cur = data["text"][pos:] - if cur: - yield cur - pos += len(cur) - else: - yield data - - add_request = async_generate - - def generate( - self, - prompt: Union[str, List[str]], - sampling_params: Optional[Dict] = None, - return_logprob: Optional[Union[List[bool], bool]] = False, - logprob_start_len: Optional[Union[List[int], int]] = None, - top_logprobs_num: Optional[Union[List[int], int]] = None, - lora_path: Optional[List[Optional[str]]] = None, - ): - json_data = { - "text": prompt, - "sampling_params": sampling_params, - "return_logprob": return_logprob, - "logprob_start_len": logprob_start_len, - "top_logprobs_num": top_logprobs_num, - "lora_path": lora_path, - } - assert not isinstance(lora_path, list) or len(lora_path) == len(prompt) - response = requests.post( - self.url + "/generate", - json=json_data, - ) - return json.dumps(response.json()) - - def encode( - self, - prompt: Union[str, List[str], List[Dict], List[List[Dict]]], - ): - json_data = {"text": prompt} - response = requests.post(self.url + "/encode", json=json_data) - return json.dumps(response.json()) - - async def get_server_info(self): - async with aiohttp.ClientSession() as session: - async with session.get(f"{self.url}/get_server_info") as response: - if response.status == 200: - return await response.json() - else: - error_data = await response.json() - raise RuntimeError( - f"Failed to get server info. {error_data['error']['message']}" - ) - - def __del__(self): - self.shutdown() +# Some shortcuts for backward compatibility. +# They will be removed in new versions. 
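The `Engine` removed here lives on as `sglang.srt.entrypoints.engine.Engine`, re-exported just below for backward compatibility. A rough usage sketch based on the `generate()` signature and return shapes shown above; the model path and prompts are placeholders.

```python
# Offline use of the Engine API (no HTTP server). Placeholders throughout.
from sglang.srt.entrypoints.engine import Engine

llm = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

# Non-streaming: a dict for a single prompt, a list of dicts for a batch.
out = llm.generate(
    prompt="The capital city of France is",
    sampling_params={"temperature": 0, "max_new_tokens": 8},
)
print(out["text"])

# Streaming: a generator yielding incremental {"text": ...} chunks.
for chunk in llm.generate(prompt="Count to five:", stream=True):
    print(chunk["text"], end="", flush=True)

llm.shutdown()
```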
+from sglang.srt.entrypoints.engine import Engine +from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index ef4df60a5763..330c38132885 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -23,15 +23,15 @@ import torch from sglang.srt.hf_transformers_utils import check_gguf_file -from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( get_amdgpu_memory_capacity, get_hpu_memory_capacity, get_nvgpu_memory_capacity, is_flashinfer_available, is_hip, - is_ipv6, is_port_available, + is_valid_ipv6_address, + nullable_str, ) logger = logging.getLogger(__name__) @@ -47,6 +47,7 @@ class ServerArgs: trust_remote_code: bool = True dtype: str = "auto" kv_cache_dtype: str = "auto" + quantization_param_path: nullable_str = None quantization: Optional[str] = None context_length: Optional[int] = None device: str = "cuda" @@ -55,7 +56,6 @@ class ServerArgs: is_embedding: bool = False revision: Optional[str] = None skip_tokenizer_init: bool = False - return_token_ids: bool = False # Port for the HTTP server host: str = "127.0.0.1" @@ -91,7 +91,7 @@ class ServerArgs: # API related api_key: Optional[str] = None - file_storage_pth: str = "SGLang_storage" + file_storage_pth: str = "sglang_storage" enable_cache_report: bool = False # Data parallelism @@ -156,6 +156,11 @@ class ServerArgs: triton_attention_num_kv_splits: int = 8 num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False + enable_memory_saver: bool = False + allow_auto_truncate: bool = False + + # Custom logit processor + enable_custom_logit_processor: bool = False def __post_init__(self): # Set missing default values @@ -239,14 +244,13 @@ def __post_init__(self): # Others if self.enable_dp_attention: self.dp_size = self.tp_size + assert self.tp_size % self.dp_size == 0 self.chunked_prefill_size = self.chunked_prefill_size // 2 self.schedule_conservativeness = self.schedule_conservativeness * 0.3 - self.disable_overlap_schedule = True logger.warning( f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. " f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. " "Data parallel size is adjusted to be the same as tensor parallel size. " - "Overlap scheduler is disabled." ) # Speculative Decoding @@ -296,6 +300,11 @@ def add_cli_args(parser: argparse.ArgumentParser): "tokenizer if available, and 'slow' will " "always use the slow tokenizer.", ) + parser.add_argument( + "--skip-tokenizer-init", + action="store_true", + help="If set, skip init tokenizer and pass input_ids in generate request", + ) parser.add_argument( "--load-format", type=str, @@ -308,6 +317,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "dummy", "gguf", "bitsandbytes", + "layered", ], help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' @@ -321,7 +331,10 @@ def add_cli_args(parser: argparse.ArgumentParser): "which is mainly for profiling." '"gguf" will load the weights in the gguf format. ' '"bitsandbytes" will load the weights using bitsandbytes ' - "quantization.", + "quantization." 
+ '"layered" loads weights layer by layer so that one can quantize a ' + "layer before loading another to make the peak memory envelope " + "smaller.", ) parser.add_argument( "--trust-remote-code", @@ -346,8 +359,17 @@ def add_cli_args(parser: argparse.ArgumentParser): "--kv-cache-dtype", type=str, default=ServerArgs.kv_cache_dtype, - choices=["auto", "fp8_e5m2"], - help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.', + choices=["auto", "fp8_e5m2", "fp8_e4m3"], + help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.', + ) + parser.add_argument( + "--quantization-param-path", + type=nullable_str, + default=None, + help="Path to the JSON file containing the KV cache " + "scaling factors. This should generally be supplied, when " + "KV cache dtype is FP8. Otherwise, KV cache scaling factors " + "default to 1.0, which may cause accuracy issues. ", ) parser.add_argument( "--quantization", @@ -363,6 +385,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "bitsandbytes", "gguf", "modelopt", + "w8a8_int8", ], help="The quantization method.", ) @@ -376,7 +399,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--device", type=str, default="cuda", - choices=["cuda", "xpu", "hpu"], + choices=["cuda", "xpu", "hpu", "cpu"], help="The device type.", ) parser.add_argument( @@ -404,18 +427,6 @@ def add_cli_args(parser: argparse.ArgumentParser): "name, a tag name, or a commit id. If unspecified, will use " "the default version.", ) - parser.add_argument( - "--skip-tokenizer-init", - action="store_true", - help="If set, skip init tokenizer and pass input_ids in generate request", - ) - parser.add_argument( - "--return-token-ids", - action="store_true", - default=ServerArgs.return_token_ids, - help="Whether to return token IDs in the output, this may introduce additional overhead.", - ) - # Memory and scheduling parser.add_argument( "--mem-fraction-static", @@ -551,7 +562,7 @@ def add_cli_args(parser: argparse.ArgumentParser): "--decode-log-interval", type=int, default=ServerArgs.decode_log_interval, - help="The log interval of decode batch", + help="The log interval of decode batch.", ) # API related @@ -851,6 +862,21 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Delete the model checkpoint after loading the model.", ) + parser.add_argument( + "--enable-memory-saver", + action="store_true", + help="Allow saving memory using release_memory_occupation and resume_memory_occupation", + ) + parser.add_argument( + "--allow-auto-truncate", + action="store_true", + help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.", + ) + parser.add_argument( + "--enable-custom-logit-processor", + action="store_true", + help="Enable users to pass custom logit processors to the server (disabled by default for security)", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace): @@ -861,7 +887,7 @@ def from_cli_args(cls, args: argparse.Namespace): return cls(**{attr: getattr(args, attr) for attr in attrs}) def url(self): - if is_ipv6(self.host): + if is_valid_ipv6_address(self.host): return f"http://[{self.host}]:{self.port}" else: return f"http://{self.host}:{self.port}" @@ -871,8 +897,8 @@ def check_server_args(self): self.tp_size % self.nnodes == 0 ), "tp_size must be divisible by number of nodes" assert not ( - self.dp_size > 1 and self.nnodes != 1 - ), "multi-node data parallel is 
not supported" + self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention + ), "multi-node data parallel is not supported unless dp attention!" assert ( self.max_loras_per_batch > 0 # FIXME @@ -910,6 +936,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs: return server_args +ZMQ_TCP_PORT_DELTA = 233 + + @dataclasses.dataclass class PortArgs: # The ipc filename for tokenizer to receive inputs from detokenizer (zmq) @@ -923,19 +952,49 @@ class PortArgs: nccl_port: int @staticmethod - def init_new(server_args) -> "PortArgs": + def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs": port = server_args.port + random.randint(100, 1000) while True: if is_port_available(port): break - port += 42 + if port < 60000: + port += 42 + else: + port -= 43 + + if not server_args.enable_dp_attention: + # Normal case, use IPC within a single node + return PortArgs( + tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}", + nccl_port=port, + ) + else: + # DP attention. Use TCP + port to handle both single-node and multi-node. + if server_args.nnodes == 1 and server_args.dist_init_addr is None: + dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA) + else: + dist_init_addr = server_args.dist_init_addr.split(":") + assert ( + len(dist_init_addr) == 2 + ), "please provide --dist-init-addr as host:port of head node" + + dist_init_host, dist_init_port = dist_init_addr + port_base = int(dist_init_port) + 1 + if dp_rank is None: + scheduler_input_port = ( + port_base + 2 + ) # TokenizerManager to DataParallelController + else: + scheduler_input_port = port_base + 2 + 1 + dp_rank - return PortArgs( - tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, - scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, - detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name, - nccl_port=port, - ) + return PortArgs( + tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}", + scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}", + detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}", + nccl_port=port, + ) class LoRAPathAction(argparse.Action): diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index a6fcf2e570df..049ba22750aa 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -14,7 +14,7 @@ from sglang.srt.speculative.spec_info import SpecInfo if TYPE_CHECKING: - from python.sglang.srt.managers.schedule_batch import ScheduleBatch + from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.server_args import ServerArgs @@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices( class EAGLEDraftInput(SpecInfo): def __init__(self): self.prev_mode = ForwardMode.DECODE - self.sample_output = None self.scores: torch.Tensor = None self.score_list: List[torch.Tensor] = [] @@ -190,12 +189,16 @@ def __init__(self): self.cache_list: List[torch.Tenor] = [] self.iter = 0 + # shape: (b, hidden_size) self.hidden_states: torch.Tensor = None + # shape: (b,) self.verified_id: torch.Tensor = None + # shape: (b, vocab_size) + self.sample_output: torch.Tensor = None + self.positions: torch.Tensor = None self.accept_length: torch.Tensor = None - self.has_finished: bool = False - self.unfinished_index: 
List[int] = None + self.accept_length_cpu: List[int] = None def load_server_args(self, server_args: ServerArgs): self.topk: int = server_args.speculative_eagle_topk @@ -218,7 +221,7 @@ def prepare_for_extend(self, batch: ScheduleBatch): :pre_len ] = req.prefix_indices - batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = ( + batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = ( out_cache_loc[pt : pt + req.extend_input_len] ) @@ -228,6 +231,14 @@ def prepare_for_extend(self, batch: ScheduleBatch): assert len(batch.extend_lens) == 1 batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id)) + def filter_batch( + self, + new_indices: torch.Tensor, + ): + self.sample_output = self.sample_output[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] + def prepare_for_decode(self, batch: ScheduleBatch): prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab) top = torch.topk(prob, self.topk, dim=-1) @@ -245,9 +256,10 @@ def prepare_for_decode(self, batch: ScheduleBatch): ) # (b, topk) topk_cs_index, topk_cs_p = topk_cs.indices, topk_cs.values - selected_input_index = ( - topk_cs_index.flatten() // self.topk - ) # shape: (b * topk) + selected_input_index = topk_cs_index.flatten() // self.topk + torch.arange( + 0, batch.batch_size() * self.topk, step=self.topk, device="cuda" + ).repeat_interleave(self.topk) + batch.spec_info.hidden_states = batch.spec_info.hidden_states[ selected_input_index, : ] @@ -286,7 +298,9 @@ def prepare_for_decode(self, batch: ScheduleBatch): self.cache_list.append(batch.out_cache_loc) self.positions = ( batch.seq_lens[:, None] - + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter + + torch.full( + [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long + ) ).flatten() bs = len(batch.seq_lens) @@ -303,24 +317,25 @@ def prepare_for_decode(self, batch: ScheduleBatch): def prepare_extend_after_decode(self, batch: ScheduleBatch): batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel()) - batch.extend_lens = (self.accept_length + 1).tolist() + accept_length_cpu = batch.spec_info.accept_length_cpu + batch.extend_lens = [x + 1 for x in accept_length_cpu] + batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend + seq_lens_cpu = batch.seq_lens.tolist() pt = 0 - seq_lens = batch.seq_lens.tolist() - i = 0 - for req in batch.reqs: if req.finished(): continue # assert seq_len - pre_len == req.extend_input_len - input_len = self.accept_length[i] + 1 - seq_len = seq_lens[i] + input_len = batch.extend_lens[i] + seq_len = seq_lens_cpu[i] batch.req_to_token_pool.req_to_token[req.req_pool_idx][ seq_len - input_len : seq_len ] = batch.out_cache_loc[pt : pt + input_len] pt += input_len i += 1 + assert pt == batch.out_cache_loc.shape[0] self.positions = torch.empty_like(self.verified_id) new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long) @@ -336,6 +351,7 @@ def prepare_extend_after_decode(self, batch: ScheduleBatch): triton.next_power_of_2(self.spec_steps + 1), ) + batch.seq_lens_sum = sum(seq_lens_cpu) batch.input_ids = self.verified_id self.verified_id = new_verified_id @@ -439,7 +455,14 @@ def generate_attn_arg_prefill( return kv_indices, cum_kv_seq_len, qo_indptr, None def merge_batch(self, spec_info: EAGLEDraftInput): - + if self.hidden_states is None: + self.hidden_states = spec_info.hidden_states + self.verified_id = spec_info.verified_id + self.sample_output = 
spec_info.sample_output + self.prev_mode = spec_info.prev_mode + return + if spec_info.hidden_states is None: + return self.hidden_states = torch.cat( [self.hidden_states, spec_info.hidden_states], axis=0 ) @@ -550,11 +573,41 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten triton.next_power_of_2(max_draft_len), ) - accept_index = accept_index[accept_index != -1] + draft_input = EAGLEDraftInput() + new_accept_index = [] + unfinished_index = [] + finished_extend_len = {} # {rid:accept_length + 1} + accept_index_cpu = accept_index.tolist() + predict_cpu = predict.tolist() + has_finished = False + + # iterate every accepted token and check if req has finished after append the token + # should be checked BEFORE free kv cache slots + for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)): + new_accept_index_ = [] + for j, idx in enumerate(accept_index_row): + if idx == -1: + break + id = predict_cpu[idx] + # if not found_finished: + req.output_ids.append(id) + finished_extend_len[req.rid] = j + 1 + req.check_finished() + if req.finished(): + has_finished = True + # set all tokens after finished token to -1 and break + accept_index[i, j + 1 :] = -1 + break + else: + new_accept_index_.append(idx) + if not req.finished(): + new_accept_index.extend(new_accept_index_) + unfinished_index.append(i) + accept_length = (accept_index != -1).sum(dim=1) - 1 + accept_index = accept_index[accept_index != -1] accept_length_cpu = accept_length.tolist() verified_id = predict[accept_index] - verified_id_cpu = verified_id.tolist() evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool) evict_mask[accept_index] = False @@ -570,30 +623,19 @@ def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Ten triton.next_power_of_2(bs), ) batch.seq_lens.add_(accept_length + 1) - new_accept_index = [] - unfinished_index = [] - finished_extend_len = {} # {rid:accept_length + 1} - # retracted_reqs, new_token_ratio = batch.retract_decode() - - low = 0 - draft_input = EAGLEDraftInput() - for i, (req, verified_len) in enumerate(zip(batch.reqs, accept_length_cpu)): - req.output_ids.extend(verified_id_cpu[low : low + verified_len + 1]) - req.check_finished() - if req.finished(): - draft_input.has_finished = True - else: - new_accept_index.append(accept_index[low : low + verified_len + 1]) - unfinished_index.append(i) - low += verified_len + 1 - finished_extend_len[req.rid] = verified_len + 1 if len(new_accept_index) > 0: - new_accept_index = torch.cat(new_accept_index, dim=0) + new_accept_index = torch.tensor(new_accept_index, device="cuda") draft_input.verified_id = predict[new_accept_index] draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index] draft_input.accept_length = accept_length[unfinished_index] - draft_input.unfinished_index = unfinished_index + draft_input.accept_length_cpu = [ + accept_length_cpu[i] for i in unfinished_index + ] + if has_finished: + draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index] + else: + draft_input.seq_lens_for_draft_extend = batch.seq_lens logits_output.next_token_logits = logits_output.next_token_logits[accept_index] return ( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index 16d54c43bafb..06a4372fce2e 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -13,6 +13,7 @@ from sglang.srt.model_executor.model_runner import ModelRunner from 
sglang.srt.server_args import ServerArgs from sglang.srt.speculative.eagle_utils import EAGLEDraftInput +from sglang.srt.utils import rank0_print class EAGLEWorker(TpModelWorker): @@ -40,6 +41,7 @@ def __init__( ) self.target_worker = target_worker self.server_args = server_args + self.finish_extend_len = [] # Share the embedding and lm_head embed, head = self.target_worker.model_runner.model.get_embed_and_head() @@ -49,18 +51,18 @@ def __init__( def forward_draft_decode(self, batch: ScheduleBatch): batch.spec_info.prepare_for_decode(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST logits_output = self.model_runner.forward(forward_batch) self.capture_for_decode(logits_output, forward_batch) def forward_draft_extend(self, batch: ScheduleBatch): self._set_mem_pool(batch, self.model_runner) batch.spec_info.prepare_for_extend(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST logits_output = self.model_runner.forward(forward_batch) self.capture_for_decode(logits_output, forward_batch) self._set_mem_pool(batch, self.target_worker.model_runner) @@ -133,26 +135,23 @@ def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner): batch.req_to_token_pool = runner.req_to_token_pool def forward_draft_extend_after_decode(self, batch: ScheduleBatch): + seq_lens_backup = batch.seq_lens + self._set_mem_pool(batch, self.model_runner) batch.forward_mode = ForwardMode.DRAFT_EXTEND - if batch.spec_info.has_finished: - index = batch.spec_info.unfinished_index - seq_lens = batch.seq_lens - batch.seq_lens = batch.seq_lens[index] - batch.spec_info.prepare_extend_after_decode(batch) + batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST model_worker_batch = batch.get_model_worker_batch() forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST logits_output = self.model_runner.forward(forward_batch) - - batch.spec_info.hidden_states = logits_output.hidden_states self.capture_for_decode(logits_output, forward_batch) - batch.forward_mode = ForwardMode.DECODE - if batch.spec_info.has_finished: - batch.seq_lens = seq_lens self._set_mem_pool(batch, self.target_worker.model_runner) + # Restore backup. 
+ # This is because `seq_lens` can be modified in `prepare_extend_after_decode` + batch.forward_mode = ForwardMode.DECODE + batch.seq_lens = seq_lens_backup + def capture_for_decode( self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch ): @@ -169,6 +168,8 @@ def finish_request(self, reqs: Union[Req, List[Req]]): if not isinstance(reqs, List): reqs = [reqs] for req in reqs: + if req.rid not in self.finish_extend_len: + continue req_len = ( len(req.origin_input_ids) + len(req.output_ids) diff --git a/python/sglang/srt/torch_memory_saver_adapter.py b/python/sglang/srt/torch_memory_saver_adapter.py new file mode 100644 index 000000000000..31f8ebf2f077 --- /dev/null +++ b/python/sglang/srt/torch_memory_saver_adapter.py @@ -0,0 +1,59 @@ +from abc import ABC +from contextlib import contextmanager + +try: + import torch_memory_saver + + _primary_memory_saver = torch_memory_saver.TorchMemorySaver() +except ImportError: + pass + + +class TorchMemorySaverAdapter(ABC): + @staticmethod + def create(enable: bool): + return ( + _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop() + ) + + def configure_subprocess(self): + raise NotImplementedError + + def region(self): + raise NotImplementedError + + def pause(self): + raise NotImplementedError + + def resume(self): + raise NotImplementedError + + +class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter): + def configure_subprocess(self): + return torch_memory_saver.configure_subprocess() + + def region(self): + return _primary_memory_saver.region() + + def pause(self): + return _primary_memory_saver.pause() + + def resume(self): + return _primary_memory_saver.resume() + + +class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter): + @contextmanager + def configure_subprocess(self): + yield + + @contextmanager + def region(self): + yield + + def pause(self): + pass + + def resume(self): + pass diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index 44a5e41a41bd..f1d57e9062a7 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -59,6 +59,7 @@ default_dump_dir, default_override_dir, ) +from uvicorn.config import LOGGING_CONFIG logger = logging.getLogger(__name__) @@ -72,7 +73,7 @@ def is_hip() -> bool: def is_cuda(): - return hasattr(torch, "cuda") and torch.cuda.is_available() + return hasattr(torch, "cuda") and torch.version.cuda is not None def is_cuda_alike(): @@ -97,12 +98,8 @@ def is_flashinfer_available(): return torch.cuda.is_available() and torch.version.cuda -def is_ipv6(address): - try: - ipaddress.IPv6Address(address) - return True - except ipaddress.AddressValueError: - return False +def is_cuda_available(): + return torch.cuda.is_available() and torch.version.cuda def enable_show_time_cost(): @@ -218,6 +215,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info() + elif device == "cpu": + # TODO: rename the variables in the current function to be not GPU specific + free_gpu_memory = psutil.virtual_memory().available + if distributed: tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to( torch.device(device, gpu_id) @@ -335,6 +336,8 @@ def is_port_available(port): return True except socket.error: return False + except OverflowError: + return False def decode_video_base64(video_base64): @@ -440,6 +443,8 @@ def load_image(image_file: Union[str, bytes]): else: raise ValueError(f"Invalid image: {image}") + # if image_size is None: + # image_size = image.size 
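For context on the `TorchMemorySaverAdapter` added above (wired in by the new `--enable-memory-saver` flag): allocations made inside `region()` are the ones that `pause()` / `resume()` can later release and restore. A small sketch using the no-op variant (`enable=False`), so it runs even without `torch_memory_saver` installed; the tensor size is arbitrary.

```python
# Usage sketch for TorchMemorySaverAdapter; enable=False selects the no-op
# adapter, so every call below is safe on any machine.
import torch

from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

saver = TorchMemorySaverAdapter.create(enable=False)

# Tensors allocated inside region() are the ones tracked by the (real) memory
# saver, e.g. model weights or the KV cache pool.
with saver.region():
    kv_pool = torch.empty(1024, 1024)

saver.pause()   # with the real adapter: release the backing device memory
saver.resume()  # with the real adapter: map it back before the tensors are reused
```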
return image, image_size @@ -505,76 +510,32 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N pass -def monkey_patch_vllm_p2p_access_check(gpu_id: int): +def monkey_patch_p2p_access_check(): """ - Monkey patch the slow p2p access check in vllm. + Monkey patch the slow p2p access check. NOTE: We assume the p2p access is always allowed, which can be wrong for some setups. """ - import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt + import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) # Suppress the warnings from this delete function when using sglang.bench_one_batch - from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce + from sglang.srt.distributed.device_communicators.custom_all_reduce import ( + CustomAllreduce, + ) setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None) -vllm_all_gather_backup = None - - -def monkey_patch_vllm_all_gather(reverse: bool = False): - """Monkey patch all-gather to remove in-place operations.""" - from torch.distributed import _functional_collectives as funcol - from vllm.distributed.parallel_state import GroupCoordinator - - global vllm_all_gather_backup - if vllm_all_gather_backup is None: - vllm_all_gather_backup = GroupCoordinator.all_gather - - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: - world_size = self.world_size - # Bypass the function if we are using only 1 GPU. - if world_size == 1: - return input_ - assert ( - -input_.dim() <= dim < input_.dim() - ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" - if dim < 0: - # Convert negative dim to positive. - dim += input_.dim() - input_size = input_.size() - # Allocate output tensor. 
- output_tensor = torch.empty( - (world_size,) + input_size, dtype=input_.dtype, device=input_.device - ) - - output_tensor = funcol.all_gather_tensor( - input_, gather_dim=0, group=self.device_group - ).view((world_size,) + input_size) - - # Reshape - output_tensor = output_tensor.movedim(0, dim) - output_tensor = output_tensor.reshape( - input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] - ) - return output_tensor - - if reverse: - setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup) - else: - setattr(GroupCoordinator, "all_gather", all_gather) - - def monkey_patch_vllm_gguf_config(): - from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.gguf import ( GGUFConfig, GGUFEmbeddingMethod, GGUFLinearMethod, ) + from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding def get_quant_method_with_embedding_replaced( @@ -782,7 +743,9 @@ def first_rank_print(*args, **kwargs): pass -def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str): +def get_zmq_socket( + context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool +): mem = psutil.virtual_memory() total_mem = mem.total / 1024**3 available_mem = mem.available / 1024**3 @@ -795,14 +758,17 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: if socket_type == zmq.PUSH: socket.setsockopt(zmq.SNDHWM, 0) socket.setsockopt(zmq.SNDBUF, buf_size) - socket.connect(f"ipc://{endpoint}") elif socket_type == zmq.PULL: socket.setsockopt(zmq.RCVHWM, 0) socket.setsockopt(zmq.RCVBUF, buf_size) - socket.bind(f"ipc://{endpoint}") else: raise ValueError(f"Unsupported socket type: {socket_type}") + if bind: + socket.bind(endpoint) + else: + socket.connect(endpoint) + return socket @@ -1244,9 +1210,9 @@ def dataclass_to_string_truncated(data, max_length=2048): if isinstance(data, str): if len(data) > max_length: half_length = max_length // 2 - return f'"{data[:half_length]} ... {data[-half_length:]}"' + return f"{repr(data[:half_length])} ... 
{repr(data[-half_length:])}" else: - return f'"{data}"' + return f"{repr(data)}" elif isinstance(data, (list, tuple)): if len(data) > max_length: half_length = max_length // 2 @@ -1257,7 +1223,7 @@ def dataclass_to_string_truncated(data, max_length=2048): return ( "{" + ", ".join( - f"{k}: {dataclass_to_string_truncated(v, max_length)}" + f"'{k}': {dataclass_to_string_truncated(v, max_length)}" for k, v in data.items() ) + "}" @@ -1338,6 +1304,25 @@ def parse_tool_response(text, tools, **kwargs): return text, call_info_list +def permute_weight(x: torch.Tensor) -> torch.Tensor: + b_ = x.shape[0] + n_ = x.shape[1] + k_ = x.shape[2] + + x_ = x + if x.dtype == torch.bfloat16 or x.dtype == torch.float16: + x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8) + elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8: + x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16) + else: + return x_ + + x_ = x_.permute(0, 1, 3, 4, 2, 5) + x_ = x_.contiguous() + x_ = x_.view(*x.shape) + return x_ + + class MultiprocessingSerializer: @staticmethod def serialize(obj): @@ -1373,3 +1358,94 @@ def wrapper(*args, **kwargs): return func(*args, **kwargs) return wrapper + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +def set_uvicorn_logging_configs(): + LOGGING_CONFIG["formatters"]["default"][ + "fmt" + ] = "[%(asctime)s] %(levelprefix)s %(message)s" + LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S" + LOGGING_CONFIG["formatters"]["access"][ + "fmt" + ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s' + LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S" + + +def get_ip() -> str: + # SGLANG_HOST_IP env can be ignore + host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "") + if host_ip: + return host_ip + + # IP is not set, try to get it from the network interface + + # try ipv4 + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + # try ipv6 + try: + s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + # Google's public DNS server, see + # https://developers.google.com/speed/public-dns/docs/using#addresses + s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + warnings.warn( + "Failed to get the IP address, using 0.0.0.0 by default." 
+ "The value can be set by the environment variable" + " SGLANG_HOST_IP or HOST_IP.", + stacklevel=2, + ) + return "0.0.0.0" + + +def get_open_port() -> int: + + port = os.getenv("SGLANG_PORT") + if port is not None: + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError: + port += 1 # Increment port number if already in use + logger.info("Port %d is already in use, trying port %d", port - 1, port) + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + +def rank0_print(msg: str): + from sglang.srt.distributed import get_tensor_model_parallel_rank + + if get_tensor_model_parallel_rank() == 0: + print(msg, flush=True) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index f22f9cafaf39..bae0fcf2a494 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -12,7 +12,6 @@ # limitations under the License. # ============================================================================== -import json import multiprocessing as mp import os from dataclasses import dataclass @@ -22,8 +21,8 @@ import torch.nn.functional as F from transformers import AutoModelForCausalLM +from sglang.srt.entrypoints.engine import Engine from sglang.srt.hf_transformers_utils import get_tokenizer -from sglang.srt.server import Runtime from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER DEFAULT_PROMPTS = [ @@ -278,7 +277,7 @@ def __init__( ): self.model_type = model_type self.is_generation = model_type == "generation" - self.runtime = Runtime( + self.engine = Engine( model_path=model_path, tp_size=tp_size, dtype=get_dtype_str(torch_dtype), @@ -306,7 +305,7 @@ def forward( top_output_logprobs = [] sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} for i, prompt in enumerate(prompts): - response = self.runtime.generate( + response = self.engine.generate( prompt, lora_path=lora_paths[i] if lora_paths else None, sampling_params=sampling_params, @@ -314,7 +313,6 @@ def forward( logprob_start_len=0, top_logprobs_num=NUM_TOP_LOGPROBS, ) - response = json.loads(response) output_strs.append(response["text"]) top_input_logprobs.append( [ @@ -343,8 +341,7 @@ def forward( top_output_logprobs=top_output_logprobs, ) else: - response = self.runtime.encode(prompts) - response = json.loads(response) + response = self.engine.encode(prompts) if self.model_type == "embedding": logits = [x["embedding"] for x in response] return ModelOutput(embed_logits=logits) @@ -366,20 +363,18 @@ def batch_forward( # the return value contains logprobs from prefill output_strs = [] sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0} - response = self.runtime.generate( + response = self.engine.generate( prompts, lora_path=lora_paths if lora_paths else None, sampling_params=sampling_params, ) - response = json.loads(response) output_strs = [r["text"] for r in response] return ModelOutput( output_strs=output_strs, ) else: - response = self.runtime.encode(prompts) - response = json.loads(response) + response = self.engine.encode(prompts) if self.model_type == "embedding": logits = [x["embedding"] for x in response] return 
ModelOutput(embed_logits=logits) @@ -391,8 +386,8 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - self.runtime.shutdown() - del self.runtime + self.engine.shutdown() + del self.engine def monkey_patch_gemma2_sdpa(): diff --git a/python/sglang/test/test_programs.py b/python/sglang/test/test_programs.py index 411a20b9267c..088cb0d0af91 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -535,7 +535,8 @@ def few_shot_hellaswag(s, question, choices): # Compute accuracy accuracy_gen = np.mean(np.array(preds_gen) == np.array(labels)) - assert np.abs(accuracy_gen - accuracy) < 0.01 + print(f"{accuracy=}, {accuracy_gen=}") + assert np.abs(accuracy_gen - accuracy) < 0.05 assert np.abs(latency_gen - latency) < 1 return accuracy, latency diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index cd21c896a044..ee5ae278d139 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -36,10 +36,14 @@ DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8" DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600 DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it" -DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" +DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP2 = "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8,neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8,neuralmagic/Qwen2-72B-Instruct-FP8,neuralmagic/Qwen2-57B-A14B-Instruct-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8" DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4,hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4" +DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN = "Qwen/Qwen2.5-1.5B-Instruct" + +DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf" +DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmzheng/sglang-EAGLE-llama2-chat-7B" def is_in_ci(): @@ -405,7 +409,7 @@ def popen_launch_server( base_url: str, timeout: float, api_key: Optional[str] = None, - other_args: tuple = (), + other_args: list[str] = (), env: Optional[dict] = None, return_stdout_stderr: Optional[tuple] = None, ): @@ -537,6 +541,7 @@ def run_bench_serving( random_input_len=4096, random_output_len=2048, disable_stream=False, + disable_ignore_eos=False, need_warmup=False, ): # Launch the server @@ -560,20 +565,22 @@ def run_bench_serving( tokenizer=tokenizer, num_prompts=num_prompts, sharegpt_output_len=None, + sharegpt_context_len=None, random_input_len=random_input_len, random_output_len=random_output_len, random_range_ratio=0.0, request_rate=request_rate, multi=None, - seed=0, output_file=None, disable_tqdm=False, disable_stream=disable_stream, - disable_ignore_eos=False, return_logprob=False, - lora_name=None, + seed=0, + disable_ignore_eos=disable_ignore_eos, extra_request_body=None, + apply_chat_template=False, profile=None, + lora_name=None, ) try: diff --git 
a/python/sglang/utils.py b/python/sglang/utils.py index 98e0f3f4f8db..742eebc3bc9b 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -1,7 +1,6 @@ """Common utilities""" import base64 -import gc import importlib import json import logging @@ -15,7 +14,7 @@ from concurrent.futures import ThreadPoolExecutor from io import BytesIO from json import dumps -from typing import Optional, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union import numpy as np import requests @@ -363,3 +362,14 @@ def terminate_process(process): def print_highlight(html_content: str): html_content = str(html_content).replace("\n", "
") display(HTML(f"{html_content}")) + + +class TypeBasedDispatcher: + def __init__(self, mapping: List[Tuple[Type, Callable]]): + self._mapping = mapping + + def __call__(self, obj: Any): + for ty, fn in self._mapping: + if isinstance(obj, ty): + return fn(obj) + raise ValueError(f"Invalid object: {obj}") diff --git a/python/sglang/version.py b/python/sglang/version.py index 24e54e5c95d5..18ca924974b2 100644 --- a/python/sglang/version.py +++ b/python/sglang/version.py @@ -1 +1 @@ -__version__ = "0.4.1.post4" +__version__ = "0.4.1.post7" diff --git a/scripts/ci_install_dependency.sh b/scripts/ci_install_dependency.sh index 26c34879e9ba..1a059d5ff683 100755 --- a/scripts/ci_install_dependency.sh +++ b/scripts/ci_install_dependency.sh @@ -12,8 +12,9 @@ bash "${SCRIPT_DIR}/killall_sglang.sh" pip install --upgrade pip pip install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.4/flashinfer/ -# Force reinstall flashinfer +# Force reinstall flashinfer and torch_memory_saver pip install flashinfer==0.1.6 --find-links ${FLASHINFER_REPO} --force-reinstall --no-deps +pip install torch_memory_saver --force-reinstall pip install transformers==4.45.2 sentence_transformers accelerate peft @@ -22,3 +23,6 @@ pip install cutex # For compling xgrammar kernels pip install cuda-python nvidia-cuda-nvrtc-cu12 + +# reinstall sgl-kernel +pip install sgl-kernel --force-reinstall --no-deps diff --git a/scripts/ci_install_rust.sh b/scripts/ci_install_rust.sh index 724207fd7825..519155dfbe85 100755 --- a/scripts/ci_install_rust.sh +++ b/scripts/ci_install_rust.sh @@ -1,9 +1,14 @@ #!/bin/bash set -euxo pipefail -# these are required for actix -apt-get update -apt-get install -y libssl-dev pkg-config +# Check if sudo is available +if command -v sudo >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y libssl-dev pkg-config +else + apt-get update + apt-get install -y libssl-dev pkg-config +fi # Install rustup (Rust installer and version manager) curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y diff --git a/scripts/deprecated/test_httpserver_classify.py b/scripts/deprecated/test_httpserver_classify.py deleted file mode 100644 index cb88802999a7..000000000000 --- a/scripts/deprecated/test_httpserver_classify.py +++ /dev/null @@ -1,85 +0,0 @@ -""" -Usage: -python3 -m sglang.launch_server --model-path /model/llama-classification --is-embedding --disable-radix-cache - -python3 test_httpserver_classify.py -""" - -import argparse - -import numpy as np -import requests - - -def get_logits_deprecated(url: str, prompt: str): - response = requests.post( - url + "/generate", - json={ - "text": prompt, - "sampling_params": { - "max_new_tokens": 0, - }, - "return_logprob": True, - }, - ) - return response.json()["meta_info"]["normalized_prompt_logprob"] - - -def get_logits_batch_deprecated(url: str, prompts: list[str]): - response = requests.post( - url + "/generate", - json={ - "text": prompts, - "sampling_params": { - "max_new_tokens": 0, - }, - "return_logprob": True, - }, - ) - ret = response.json() - logits = np.array( - list( - ret[i]["meta_info"]["normalized_prompt_logprob"] - for i in range(len(prompts)) - ) - ) - return logits - - -def get_logits(url: str, prompt: str): - response = requests.post( - url + "/classify", - json={"text": prompt}, - ) - return response.json()["embedding"] - - -def get_logits_batch(url: str, prompts: list[str]): - response = requests.post( - url + "/classify", - json={"text": prompts}, - ) - return np.array([x["embedding"] for x in 
response.json()]) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="http://127.0.0.1") - parser.add_argument("--port", type=int, default=30000) - args = parser.parse_args() - - url = f"{args.host}:{args.port}" - - # A single request - prompt = "This is a test prompt.<|eot_id|>" - logits = get_logits(url, prompt) - print(f"{logits=}") - - # A batch of requests - prompts = [ - "This is a test prompt.<|eot_id|>", - "This is another test prompt.<|eot_id|>", - "This is a long long long long test prompt.<|eot_id|>", - ] - logits = get_logits_batch(url, prompts) - print(f"{logits=}") diff --git a/scripts/deprecated/test_httpserver_decode_stream.py b/scripts/deprecated/test_httpserver_decode_stream.py index 955c368d1549..616eaf6c4b1e 100644 --- a/scripts/deprecated/test_httpserver_decode_stream.py +++ b/scripts/deprecated/test_httpserver_decode_stream.py @@ -42,7 +42,6 @@ def test_decode_stream(url, return_logprob, top_logprobs_num): if return_logprob: assert data["meta_info"]["input_token_logprobs"] is not None assert data["meta_info"]["output_token_logprobs"] is not None - assert data["meta_info"]["normalized_prompt_logprob"] is not None for logprob, token_id, token_text in data["meta_info"][ "output_token_logprobs" ][prev:]: diff --git a/scripts/deprecated/test_jump_forward.py b/scripts/deprecated/test_jump_forward.py index 60074a040054..315a50b5ba71 100644 --- a/scripts/deprecated/test_jump_forward.py +++ b/scripts/deprecated/test_jump_forward.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, constr import sglang as sgl -from sglang.srt.constrained import build_regex_from_object +from sglang.srt.constrained.outlines_backend import build_regex_from_object from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, diff --git a/scripts/killall_sglang.sh b/scripts/killall_sglang.sh index 4057d2be2fb4..163a60f184b7 100755 --- a/scripts/killall_sglang.sh +++ b/scripts/killall_sglang.sh @@ -1,5 +1,14 @@ #!/bin/bash +# Check if sudo is available +if command -v sudo >/dev/null 2>&1; then + sudo apt-get update + sudo apt-get install -y lsof +else + apt-get update + apt-get install -y lsof +fi + # Show current GPU status nvidia-smi @@ -7,6 +16,7 @@ nvidia-smi kill -9 $(ps aux | grep 'sglang::' | grep -v 'grep' | awk '{print $2}') 2>/dev/null kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}') 2>/dev/null kill -9 $(ps aux | grep 'sglang.bench' | grep -v 'grep' | awk '{print $2}') 2>/dev/null +kill -9 $(ps aux | grep 'sglang.data_parallel' | grep -v 'grep' | awk '{print $2}') 2>/dev/null # Clean all GPU processes if any argument is provided if [ $# -gt 0 ]; then diff --git a/scripts/update_kernel_whl_index.py b/scripts/update_kernel_whl_index.py new file mode 100644 index 000000000000..a42969641f57 --- /dev/null +++ b/scripts/update_kernel_whl_index.py @@ -0,0 +1,16 @@ +# Reference: https://github.com/flashinfer-ai/flashinfer/blob/v0.2.0/scripts/update_whl_index.py + +import hashlib +import pathlib +import re + +for path in sorted(pathlib.Path("sgl-kernel/dist").glob("*.whl")): + with open(path, "rb") as f: + sha256 = hashlib.sha256(f.read()).hexdigest() + ver = re.findall(r"sgl_kernel-([0-9.]+(?:\.post[0-9]+)?)-", path.name)[0] + index_dir = pathlib.Path(f"sgl-whl/cu118/sgl-kernel") + index_dir.mkdir(exist_ok=True) + base_url = "https://github.com/sgl-project/whl/releases/download" + full_url = f"{base_url}/v{ver}/{path.name}#sha256={sha256}" + with (index_dir / 
"index.html").open("a") as f: + f.write(f'{path.name}
\n') diff --git a/sgl-kernel/3rdparty/cccl b/sgl-kernel/3rdparty/cccl new file mode 160000 index 000000000000..b5fe509fd11a --- /dev/null +++ b/sgl-kernel/3rdparty/cccl @@ -0,0 +1 @@ +Subproject commit b5fe509fd11a925f90d6495176707cc1184eed9d diff --git a/sgl-kernel/3rdparty/cutlass b/sgl-kernel/3rdparty/cutlass index bf9da7b76c76..b78588d1630a 160000 --- a/sgl-kernel/3rdparty/cutlass +++ b/sgl-kernel/3rdparty/cutlass @@ -1 +1 @@ -Subproject commit bf9da7b76c766d7ee7d536afc77880a4ef1f1156 +Subproject commit b78588d1630aa6643bf021613717bafb705df4ef diff --git a/sgl-kernel/3rdparty/flashinfer b/sgl-kernel/3rdparty/flashinfer new file mode 160000 index 000000000000..6e6f38d35349 --- /dev/null +++ b/sgl-kernel/3rdparty/flashinfer @@ -0,0 +1 @@ +Subproject commit 6e6f38d3534994c34b2c6b09b5b45c8a7b92ffd2 diff --git a/sgl-kernel/3rdparty/turbomind b/sgl-kernel/3rdparty/turbomind new file mode 160000 index 000000000000..0c9d0c724a99 --- /dev/null +++ b/sgl-kernel/3rdparty/turbomind @@ -0,0 +1 @@ +Subproject commit 0c9d0c724a99974ca3af0c12b24ef8a0444c4fd9 diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt deleted file mode 100644 index 3c267a4de504..000000000000 --- a/sgl-kernel/CMakeLists.txt +++ /dev/null @@ -1,62 +0,0 @@ -cmake_minimum_required(VERSION 3.18) -project(sgl-kernel LANGUAGES CXX CUDA) - -# Basic settings -set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(CMAKE_CUDA_STANDARD 17) -set(CMAKE_CUDA_STANDARD_REQUIRED ON) - -set(CUTLASS_DIR "3rdparty/cutlass") - -# Set CUDA architectures -set(CMAKE_CUDA_ARCHITECTURES "75;80;86;89;90") -message(STATUS "Building for CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") - -find_package(Python3 COMPONENTS Interpreter Development REQUIRED) - -# Find PyTorch -execute_process( - COMMAND ${Python3_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)" - OUTPUT_VARIABLE TORCH_CMAKE_PATH - OUTPUT_STRIP_TRAILING_WHITESPACE -) -list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PATH}") - -find_package(Torch REQUIRED) - -# Warp Reduce library -add_library(_kernels SHARED - src/sgl-kernel/csrc/trt_reduce_internal.cu - src/sgl-kernel/csrc/trt_reduce_kernel.cu - src/sgl-kernel/csrc/moe_align_kernel.cu - src/sgl-kernel/csrc/int8_gemm_kernel.cu - src/sgl-kernel/csrc/sgl_kernel_ops.cu -) - -target_include_directories(_kernels - PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/src/sgl-kernel/csrc - ${CUDA_INCLUDE_DIRS} - ${TORCH_INCLUDE_DIRS} - ${CUTLASS_DIR}/include - ${CUTLASS_DIR}/tools/util/include -) - -target_link_libraries(_kernels - PRIVATE - ${TORCH_LIBRARIES} - Python3::Python -) - -# Set common properties for both libraries -foreach(target _kernels) - set_target_properties(${target} PROPERTIES - CUDA_SEPARABLE_COMPILATION ON - POSITION_INDEPENDENT_CODE ON - CUDA_RESOLVE_DEVICE_SYMBOLS ON - PREFIX "" - SUFFIX ".so" - ) -endforeach() diff --git a/sgl-kernel/Makefile b/sgl-kernel/Makefile index 7a041b1ed408..1384f1bcd81d 100644 --- a/sgl-kernel/Makefile +++ b/sgl-kernel/Makefile @@ -1,22 +1,28 @@ -.PHONY: tree ln install build clean test format +.PHONY: tree ln submodule install build clean rebuild test format tree: - @tree --prune -I "__pycache__|*.egg-info|*.so|build" + @tree --prune -I "__pycache__|*.egg-info|*.so|build|3rdparty|dist" -ln: - @rm -rf build && cmake . 
-DCMAKE_EXPORT_COMPILE_COMMANDS=1 -DCMAKE_CUDA_COMPILER=nvcc -B build && rm -rf compile_commands.json && ln -s build/compile_commands.json compile_commands.json +submodule: + @git submodule update --init --recursive -install: +ln: submodule + @rm -rf build && bear python3 setup.py build + +install: submodule @pip install -e . -build: - @export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel +build: submodule + @rm -rf dist/* || true && export MAX_JOBS=$(nproc) && python3 setup.py bdist_wheel && pip3 install dist/*whl --force-reinstall --no-deps clean: @rm -rf build dist *.egg-info +rebuild: clean submodule build + @echo "Succeed to rebuild" + test: - @pytest tests/ + @find tests -name "test_*.py" | xargs -n 1 python3 format: @find src tests -name '*.cc' -o -name '*.cu' -o -name '*.cuh' -o -name '*.h' -o -name '*.hpp' | xargs clang-format -i && find src tests -name '*.py' | xargs isort && find src tests -name '*.py' | xargs black diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md index 857cae366d83..0572f9758ab3 100644 --- a/sgl-kernel/README.md +++ b/sgl-kernel/README.md @@ -1,5 +1,19 @@ # SGL Kernel -Kernel Library for SGLang +[Kernel Library](https://github.com/sgl-project/sglang/tree/main/sgl-kernel) for SGLang [![PyPI](https://img.shields.io/pypi/v/sgl-kernel)](https://pypi.org/project/sgl-kernel) + +## Installation + +For CUDA 11.8: + +```bash +pip3 install sgl-kernel -i https://docs.sglang.ai/whl/cu118 +``` + +For CUDA 12.1 or CUDA 12.4: + +```bash +pip3 install sgl-kernel +``` diff --git a/sgl-kernel/THIRDPARTYNOTICES.txt b/sgl-kernel/THIRDPARTYNOTICES.txt new file mode 100644 index 000000000000..c930aa5dd3d8 --- /dev/null +++ b/sgl-kernel/THIRDPARTYNOTICES.txt @@ -0,0 +1,225 @@ +Notice for flashinfer-ai/flashinfer +------------------------------- + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------------------- +Some of the code in this project are adapted from other open-source projects with different +licenses. This product also bundles some third-party components under other open source licenses. +This section summarizes those components and their licenses. +See licenses/ for text of these licenses. 
+ +BSD 3-Clause License +-------------------- + +include/flashinfer/attention/hopper/epilogue.cuh +include/flashinfer/attention/hopper/mainloop.cuh +include/flashinfer/attention/hopper/kernel_traits.cuh +include/flashinfer/attention/hopper/named_barrier.cuh +include/flashinfer/attention/hopper/tile_scheduler.cuh +include/flashinfer/attention/hopper/utils.cuh + +BSD 3-Clause "New" License +-------------------------- + +3rdparty/cutlass +include/flashinfer/attention/hopper/block_sparse_gather.cuh diff --git a/sgl-kernel/benchmark/bench_fp8_gemm.py b/sgl-kernel/benchmark/bench_fp8_gemm.py new file mode 100644 index 000000000000..c3f804753568 --- /dev/null +++ b/sgl-kernel/benchmark/bench_fp8_gemm.py @@ -0,0 +1,164 @@ +import argparse +import copy +import itertools + +import torch +import triton +from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm +from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant + +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], +} + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048], + x_log=False, + line_arg="provider", + line_vals=[ + "vllm-fp8-fp16", + "vllm-fp8-bf16", + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ], + line_names=[ + "vllm-fp8-fp16", + "vllm-fp8-bf16", + "sglang-fp8-fp16", + "sglang-fp8-bf16", + ], + styles=[("green", "-"), ("green", "--"), ("blue", "-"), ("blue", "--")], + ylabel="GB/s", + plot_name="fp8 scaled matmul", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + # M, N, K = batch_size, 4096, 8192 + M = batch_size + a = torch.ones((M, K), device="cuda") * 5.0 + b = torch.ones((N, K), device="cuda") * 5.0 + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + b_fp8 = b_fp8.t() + quantiles = [0.5, 0.2, 0.8] + + dtype = torch.float16 if "fp16" in provider else torch.bfloat16 + + if "vllm-fp8" in provider: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: vllm_scaled_mm(a_fp8, 
b_fp8, scale_a_fp8, scale_b_fp8, dtype), + quantiles=quantiles, + ) + elif "sglang-fp8" in provider: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: sgl_scaled_mm( + a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None + ), + quantiles=quantiles, + ) + + gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +def prepare_shapes(args): + KN_model_names = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + assert model in WEIGHT_SHAPES + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run( + print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K + ) + + print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_int8_gemm.py b/sgl-kernel/benchmark/bench_int8_gemm.py index 2657c616cf34..c5a709393c11 100644 --- a/sgl-kernel/benchmark/bench_int8_gemm.py +++ b/sgl-kernel/benchmark/bench_int8_gemm.py @@ -1,3 +1,7 @@ +import argparse +import copy +import itertools + import torch import triton from sgl_kernel import int8_scaled_mm @@ -8,6 +12,56 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor: return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) +WEIGHT_SHAPES = { + "meta-llama/Llama-3.1-8B-Instruct": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-3.3-70B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], + "mistralai/Mistral-Large-Instruct-2407": [ + ([12288, 14336], 1), + ([12288, 12288], 0), + ([12288, 57344], 1), + ([28672, 12288], 0), + ], + "Qwen/Qwen2.5-7B-Instruct": [ + ([3584, 4608], 1), + ([3584, 3584], 0), + ([3584, 37888], 1), + ([18944, 3584], 0), + ], + "Qwen/Qwen2.5-32B-Instruct": [ + ([5120, 7168], 1), + ([5120, 5120], 0), + ([5120, 55296], 1), + ([27648, 5120], 0), + ], + "Qwen/Qwen2.5-72B-Instruct": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 59136], 1), + ([29568, 8192], 0), + ], + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": [ + ([2048, 3072], 1), + ([2048, 4096], 1), + ([2048, 2048], 0), + ([2048, 576], 0), + ([2048, 21888], 1), + ([10944, 2048], 0), + ([2048, 2816], 1), + ([1408, 2048], 0), + ], +} + + @triton.testing.perf_report( triton.testing.Benchmark( x_names=["batch_size"], @@ -22,8 +76,8 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor: args={}, ) ) -def benchmark(batch_size, provider): - M, N, K = batch_size, 4096, 8192 +def benchmark(batch_size, provider, N, K): + M = batch_size a = to_int8(torch.randn((M, K), device="cuda") * 5) b = to_int8(torch.randn((N, K), device="cuda").t() * 5) scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) @@ -52,4 +106,41 @@ def benchmark(batch_size, provider): return gbps(ms), gbps(max_ms), gbps(min_ms) -benchmark.run(print_data=True, show_plots=True, 
save_path="bench_int8_res") +def prepare_shapes(args): + KN_model_names = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + assert model in WEIGHT_SHAPES + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + KN_model_names = prepare_shapes(args) + for K, N, model_name in KN_model_names: + print(f"{model_name} N={N} K={K}: ") + benchmark.run( + print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K + ) + + print("Benchmark finished!") diff --git a/sgl-kernel/benchmark/bench_lightning_attention_decode.py b/sgl-kernel/benchmark/bench_lightning_attention_decode.py new file mode 100644 index 000000000000..24872e61a4d4 --- /dev/null +++ b/sgl-kernel/benchmark/bench_lightning_attention_decode.py @@ -0,0 +1,299 @@ +import itertools +import math + +import torch +import triton +import triton.language as tl +from sgl_kernel import lightning_attention_decode + + +def next_power_of_2(n): + return 2 ** (int(math.ceil(math.log(n, 2)))) + + +@triton.jit +def _decode_kernel( + Q, + K, + V, + KV, + Out, + S, + b: tl.constexpr, + h: tl.constexpr, + n: tl.constexpr, + d: tl.constexpr, + d_original: tl.constexpr, + e: tl.constexpr, + e_original: tl.constexpr, +): + off_bh = tl.program_id(0) + off_h = off_bh % h + + qk_offset = off_bh * n * d + v_offset = off_bh * n * e + o_offset = off_bh * n * e + kv_offset = off_bh * d * e + + s = tl.load(S + off_h) + ratio = tl.exp(-s) + + d_idx = tl.arange(0, d) + e_idx = tl.arange(0, e) + + # Create masks for original dimensions + d_mask = d_idx < d_original + e_mask = e_idx < e_original + + # Load with masking + q = tl.load(Q + qk_offset + d_idx, mask=d_mask, other=0.0) + k = tl.load(K + qk_offset + d_idx, mask=d_mask, other=0.0) + v = tl.load(V + v_offset + e_idx, mask=e_mask, other=0.0) + + # Load KV with 2D masking + kv = tl.load( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + mask=(d_mask[:, None] & e_mask[None, :]), + other=0.0, + ) + + # Compute outer product using element-wise operations + k_v_prod = k[:, None] * v[None, :] + kv = ratio * kv + k_v_prod + + # Store KV with 2D masking + tl.store( + KV + kv_offset + d_idx[:, None] * e + e_idx[None, :], + kv.to(KV.dtype.element_ty), + mask=(d_mask[:, None] & e_mask[None, :]), + ) + + # Compute matrix-vector multiplication using element-wise operations and reduction + o = tl.sum(q[:, None] * kv, axis=0) + + # Store output with masking + tl.store(Out + o_offset + e_idx, o.to(Out.dtype.element_ty), mask=e_mask) + + +def triton_lightning_attn_decode(q, k, v, kv, s): + """Triton implementation of Lightning Attention decode operation""" + b, h, n, d = q.shape + e = v.shape[-1] + assert n == 1, "Sequence length must be 1 in decode mode" + + # Get padded dimensions (power of 2) + d_padded = next_power_of_2(d) + e_padded = next_power_of_2(e) + + # Create output tensor (padded) + o_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + + # Create padded tensors without actually padding the data + q_padded = torch.empty(b, h, n, 
d_padded, dtype=q.dtype, device=q.device) + k_padded = torch.empty(b, h, n, d_padded, dtype=k.dtype, device=k.device) + v_padded = torch.empty(b, h, n, e_padded, dtype=v.dtype, device=v.device) + kv_padded = torch.empty( + b, h, d_padded, e_padded, dtype=torch.float32, device=kv.device + ) + + # Copy data to padded tensors + q_padded[..., :d] = q + k_padded[..., :d] = k + v_padded[..., :e] = v + kv_padded[..., :d, :e] = kv + + # Launch kernel + grid = (b * h, 1) + _decode_kernel[grid]( + q_padded, + k_padded, + v_padded, + kv_padded, + o_padded, + s, + b=b, + h=h, + n=n, + d=d_padded, + d_original=d, + e=e_padded, + e_original=e, + ) + + # Get unpadded outputs + o = o_padded[..., :e] + kv_out = kv_padded[..., :d, :e] + + return o, kv_out + + +def lightning_attention_decode_naive(q, k, v, past_kv, slope): + """Naive implementation of lightning attention decode""" + original_dtype = q.dtype + ratio = torch.exp(-slope) # [h, 1, 1] + + kv = past_kv + b, h, n, d = q.shape + + output = [] + for i in range(n): + kv = ratio * kv.to(torch.float32) + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... n d", + q[:, :, i : i + 1].to(torch.float32), + kv.to(torch.float32), + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + return output.to(original_dtype), kv + + +def lightning_attention_decode_kernel(q, k, v, past_kv, slope, output, new_kv): + return lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + +def calculate_diff(batch_size): + dtype = torch.bfloat16 + device = torch.device("cuda") + num_heads = 64 + head_dim = 96 + seq_len = 1 + + q = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + k = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + output_naive, new_kv_naive = lightning_attention_decode_naive( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ) + + output_kernel = torch.empty_like(output_naive) + new_kv_kernel = torch.empty_like(new_kv_naive) + lightning_attention_decode_kernel( + q.clone(), + k.clone(), + v.clone(), + past_kv.clone(), + slope.clone(), + output_kernel, + new_kv_kernel, + ) + + output_triton, new_kv_triton = triton_lightning_attn_decode( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ) + + if ( + torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2) + and torch.allclose(output_naive, output_triton, atol=1e-2, rtol=1e-2) + and torch.allclose(new_kv_naive, new_kv_kernel, atol=1e-2, rtol=1e-2) + and torch.allclose(new_kv_naive, new_kv_triton, atol=1e-2, rtol=1e-2) + ): + print("✅ All implementations match") + else: + print("❌ Implementations differ") + + +batch_size_range = [i for i in range(1, 65)] # 1 to 128 +configs = [(bs,) for bs in batch_size_range] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["naive", "kernel", "triton"], + line_names=["PyTorch Naive", "SGL Kernel", "Triton"], + styles=[("blue", "-"), ("red", "-"), ("green", "-")], + ylabel="us", + plot_name="lightning-attention-decode-performance", + args={}, + ) +) +def benchmark(batch_size, provider): + dtype = 
torch.bfloat16 + device = torch.device("cuda") + num_heads = 64 + head_dim = 96 + seq_len = 1 + + q = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + k = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + v = torch.randn( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + past_kv = torch.randn(batch_size, num_heads, head_dim, head_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "naive": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: lightning_attention_decode_naive( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ), + quantiles=quantiles, + ) + elif provider == "kernel": + output = torch.empty( + batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype + ) + new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: lightning_attention_decode_kernel( + q.clone(), + k.clone(), + v.clone(), + past_kv.clone(), + slope.clone(), + output, + new_kv, + ), + quantiles=quantiles, + ) + elif provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_lightning_attn_decode( + q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone() + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/lightning_attention_decode_sgl/", + help="Path to save lightning attention decode benchmark results", + ) + args = parser.parse_args() + + # Run correctness test + calculate_diff(batch_size=4) + + # Run performance benchmark + benchmark.run(print_data=True) diff --git a/sgl-kernel/benchmark/bench_sampling_scaling_penalties.py b/sgl-kernel/benchmark/bench_sampling_scaling_penalties.py new file mode 100644 index 000000000000..000dab0d8e9a --- /dev/null +++ b/sgl-kernel/benchmark/bench_sampling_scaling_penalties.py @@ -0,0 +1,159 @@ +import itertools + +import torch +import triton +from sgl_kernel import sampling_scaling_penalties + + +def sampling_scaling_penalties_naive(logits, scaling_penalties): + return torch.where( + logits > 0, logits / scaling_penalties, logits * scaling_penalties + ) + + +def sampling_scaling_penalties_kernel(logits, scaling_penalties): + return sampling_scaling_penalties(logits, scaling_penalties) + + +def test_memory(func, _iter): + total_mem = [] + + for _ in range(_iter): + torch.cuda.memory.reset_peak_memory_stats() + func() + mem = torch.cuda.max_memory_allocated() / (2**20) + total_mem.append(mem) + + return sum(total_mem) / len(total_mem) + + +def calculate_diff(batch_size, vocab_size): + dtype = torch.bfloat16 + device = torch.device("cuda") + + logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype) + scaling_penalties = ( + torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5 + ) + + output_naive = sampling_scaling_penalties_naive( + logits.clone(), scaling_penalties.clone() + ) + output_kernel = sampling_scaling_penalties_kernel( + logits.clone(), scaling_penalties.clone() + ) + + print(f"Naive output={output_naive}") + print(f"Kernel output={output_kernel}") + + if torch.allclose(output_naive, output_kernel, atol=1e-2, rtol=1e-2): + print("✅ Both implementations match") + else: + print("❌ Implementations differ") + + 
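For readers skimming the diff, here is a minimal CPU-only sketch (illustrative values, standard torch ops only) of the elementwise rule that `sampling_scaling_penalties_naive` above encodes and that the correctness check compares against the `sgl_kernel` implementation: positive logits are divided by the penalty and non-positive logits are multiplied by it, so both are pulled toward zero.

```python
import torch

# Hypothetical toy inputs; the benchmark itself uses bfloat16 CUDA tensors.
logits = torch.tensor([2.0, -2.0, 0.5])
penalties = torch.tensor([1.2, 1.2, 1.2])

# Same rule as the naive reference: shrink positive and negative logits alike.
scaled = torch.where(logits > 0, logits / penalties, logits * penalties)
print(scaled)  # ~ tensor([1.6667, -2.4000, 0.4167])
```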
+batch_size_range = [2**i for i in range(0, 12)] +vocab_size_range = [2**i for i in range(10, 17)] +configs = list(itertools.product(batch_size_range, vocab_size_range)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "vocab_size"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["naive", "kernel"], + line_names=["PyTorch Naive", "SGL Kernel"], + styles=[("blue", "-"), ("red", "-")], + ylabel="us", + plot_name="sampling-scaling-penalties-performance", + args={}, + ) +) +def benchmark(batch_size, vocab_size, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + + logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype) + scaling_penalties = ( + torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5 + ) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "naive": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: sampling_scaling_penalties_naive( + logits.clone(), + scaling_penalties.clone(), + ), + quantiles=quantiles, + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: sampling_scaling_penalties_kernel( + logits.clone(), + scaling_penalties.clone(), + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size", "vocab_size"], + x_vals=[list(_) for _ in configs], + line_arg="provider", + line_vals=["naive", "kernel"], + line_names=["PyTorch Naive", "SGL Kernel"], + styles=[("blue", "-"), ("red", "-")], + ylabel="GPU memory usage (MB)", + plot_name="sampling-scaling-penalties-memory", + args={}, + ) +) +def benchmark_memory(batch_size, vocab_size, provider): + dtype = torch.bfloat16 + device = torch.device("cuda") + + print( + f"Running memory benchmark with batch_size={batch_size}, vocab_size={vocab_size}, provider={provider}" + ) + + def run_kernel(): + logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype) + scaling_penalties = ( + torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5 + ) + + if provider == "naive": + return sampling_scaling_penalties_naive(logits, scaling_penalties) + else: + return sampling_scaling_penalties_kernel(logits, scaling_penalties) + + mem = test_memory(run_kernel, _iter=10) + return mem + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--save_path", + type=str, + default="./configs/benchmark_ops/sampling_scaling_penalties/", + help="Path to save sampling_scaling_penalties benchmark results", + ) + args = parser.parse_args() + + # Run correctness test + calculate_diff(batch_size=4, vocab_size=4096) + + # Run performance benchmark + benchmark.run(print_data=True, save_path=args.save_path) + + # Run memory benchmark + benchmark_memory.run(print_data=True, save_path=args.save_path) diff --git a/sgl-kernel/build.sh b/sgl-kernel/build.sh index 799b724dfe6e..1caa892bc845 100755 --- a/sgl-kernel/build.sh +++ b/sgl-kernel/build.sh @@ -4,13 +4,24 @@ PYTHON_VERSION=$1 CUDA_VERSION=$2 PYTHON_ROOT_PATH=/opt/python/cp${PYTHON_VERSION//.}-cp${PYTHON_VERSION//.} +if (( ${CUDA_VERSION%.*} < 12 )); then + ENABLE_SM90A=0 +else + ENABLE_SM90A=1 +fi + docker run --rm \ -v "$(pwd)":/sgl-kernel \ pytorch/manylinux-builder:cuda${CUDA_VERSION} \ bash -c " - ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.4.0 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \ + ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir 
torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \ export CUDA_VERSION=${CUDA_VERSION} && \ + export SGL_KERNEL_ENABLE_BF16=1 && \ + export SGL_KERNEL_ENABLE_FP8=1 && \ + export SGL_KERNEL_ENABLE_SM90A=${ENABLE_SM90A} && \ + mkdir -p /usr/lib/x86_64-linux-gnu/ && \ + ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \ cd /sgl-kernel && \ ${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel " diff --git a/sgl-kernel/developer_guide.md b/sgl-kernel/developer_guide.md new file mode 100644 index 000000000000..26426d90d8a3 --- /dev/null +++ b/sgl-kernel/developer_guide.md @@ -0,0 +1,55 @@ +# Developer Guide for sgl-kernel + +## Development Environment Setup + +Use Docker to set up the development environment. See [Docker setup guide](https://github.com/sgl-project/sglang/blob/main/docs/developer/development_guide_using_docker.md#setup-docker-container). + +Create and enter development container: +```bash +docker run -itd --shm-size 32g --gpus all -v $HOME/.cache:/root/.cache --ipc=host --name sglang_zhyncs lmsysorg/sglang:dev /bin/zsh +docker exec -it sglang_zhyncs /bin/zsh +``` + +## Project Structure + +### Dependencies + +Third-party libraries: + +- [CCCL](https://github.com/NVIDIA/cccl) +- [CUTLASS](https://github.com/NVIDIA/cutlass) +- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) +- [TurboMind](https://github.com/InternLM/turbomind) + +### Kernel Development + +Steps to add a new kernel: + +1. Implement in [src/sgl-kernel/csrc/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/src/sgl-kernel/csrc) +2. Expose interface in [src/sgl-kernel/include/sgl_kernel_ops.h](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/include/sgl_kernel_ops.h) +3. Create torch extension in [src/sgl-kernel/torch_extension.cc](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/torch_extension.cc) +4. Create Python wrapper in [src/sgl-kernel/ops/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/ops/__init__.py) +5. Expose Python interface in [src/sgl-kernel/\_\_init\_\_.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/src/sgl-kernel/__init__.py) +6. Update [setup.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/setup.py) to include new CUDA source + +### Build & Install + +Development build: + +```bash +make build +``` + +Note: + +The `sgl-kernel` is rapidly evolving. If you experience a compilation failure, try using `make rebuild`. + +### Testing & Benchmarking + +1. Add pytest tests in [tests/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/tests) +2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark) +3. 
Run test suite + +### Release new version + +Update version in [pyproject.toml](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/pyproject.toml) and [version.py](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/version.py) diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index 359ffafd70d2..b23c302b564c 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -4,22 +4,20 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.0.2.post11" +version = "0.0.2.post17" description = "Kernel Library for SGLang" readme = "README.md" -requires-python = ">=3.8" +requires-python = ">=3.9" license = { file = "LICENSE" } classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", "Environment :: GPU :: NVIDIA CUDA" ] -dependencies = [ - "torch", -] +dependencies = [] [project.urls] -"Homepage" = "https://github.com/sgl-project/sglang" +"Homepage" = "https://github.com/sgl-project/sglang/tree/main/sgl-kernel" "Bug Tracker" = "https://github.com/sgl-project/sglang/issues" [tool.setuptools] diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index c93e87f6bad3..c8469dc1c0e2 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -1,34 +1,65 @@ +import os from pathlib import Path -from setuptools import setup +import torch +from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension root = Path(__file__).parent.resolve() -def get_version(): +def _update_wheel_platform_tag(): + wheel_dir = Path("dist") + if wheel_dir.exists() and wheel_dir.is_dir(): + old_wheel = next(wheel_dir.glob("*.whl")) + new_wheel = wheel_dir / old_wheel.name.replace( + "linux_x86_64", "manylinux2014_x86_64" + ) + old_wheel.rename(new_wheel) + + +def _get_cuda_version(): + if torch.version.cuda: + return tuple(map(int, torch.version.cuda.split("."))) + return (0, 0) + + +def _get_device_sm(): + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + return major * 10 + minor + return 0 + + +def _get_version(): with open(root / "pyproject.toml") as f: for line in f: if line.startswith("version"): return line.split("=")[1].strip().strip('"') -def update_wheel_platform_tag(): - wheel_dir = Path("dist") - old_wheel = next(wheel_dir.glob("*.whl")) - new_wheel = wheel_dir / old_wheel.name.replace( - "linux_x86_64", "manylinux2014_x86_64" - ) - old_wheel.rename(new_wheel) - - -cutlass = root / "3rdparty" / "cutlass" +operator_namespace = "sgl_kernels" +cutlass_default = root / "3rdparty" / "cutlass" +cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default)) +flashinfer = root / "3rdparty" / "flashinfer" +turbomind = root / "3rdparty" / "turbomind" include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", + root / "src" / "sgl-kernel" / "include", root / "src" / "sgl-kernel" / "csrc", + flashinfer.resolve() / "include", + flashinfer.resolve() / "include" / "gemm", + flashinfer.resolve() / "csrc", + "cublas", + "cublasLt", + turbomind.resolve(), + turbomind.resolve() / "src", ] + nvcc_flags = [ + "-DNDEBUG", + f"-DOPERATOR_NAMESPACE={operator_namespace}", "-O3", "-Xcompiler", "-fPIC", @@ -36,22 +67,76 @@ def update_wheel_platform_tag(): "-gencode=arch=compute_80,code=sm_80", "-gencode=arch=compute_89,code=sm_89", "-gencode=arch=compute_90,code=sm_90", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF2_OPERATORS__", + "-std=c++17", + "-use_fast_math", + 
"-DFLASHINFER_ENABLE_F16", +] +nvcc_flags_fp8 = [ + "-DFLASHINFER_ENABLE_FP8", + "-DFLASHINFER_ENABLE_FP8_E4M3", + "-DFLASHINFER_ENABLE_FP8_E5M2", ] + +sources = [ + "src/sgl-kernel/torch_extension.cc", + "src/sgl-kernel/csrc/trt_reduce_internal.cu", + "src/sgl-kernel/csrc/trt_reduce_kernel.cu", + "src/sgl-kernel/csrc/moe_align_kernel.cu", + "src/sgl-kernel/csrc/int8_gemm_kernel.cu", + "src/sgl-kernel/csrc/fp8_gemm_kernel.cu", + "src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu", + "src/sgl-kernel/csrc/rotary_embedding.cu", + "3rdparty/flashinfer/csrc/activation.cu", + "3rdparty/flashinfer/csrc/bmm_fp8.cu", + "3rdparty/flashinfer/csrc/group_gemm.cu", + "3rdparty/flashinfer/csrc/norm.cu", + "3rdparty/flashinfer/csrc/sampling.cu", + "3rdparty/flashinfer/csrc/renorm.cu", +] + +enable_bf16 = os.getenv("SGL_KERNEL_ENABLE_BF16", "0") == "1" +enable_fp8 = os.getenv("SGL_KERNEL_ENABLE_FP8", "0") == "1" +enable_sm90a = os.getenv("SGL_KERNEL_ENABLE_SM90A", "0") == "1" +cuda_version = _get_cuda_version() +sm_version = _get_device_sm() + +if torch.cuda.is_available(): + if cuda_version >= (12, 0) and sm_version >= 90: + nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") + sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu") + if sm_version >= 90: + nvcc_flags.extend(nvcc_flags_fp8) + if sm_version >= 80: + nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") +else: + # compilation environment without GPU + if enable_sm90a: + nvcc_flags.append("-gencode=arch=compute_90a,code=sm_90a") + sources.append("3rdparty/flashinfer/csrc/group_gemm_sm90.cu") + if enable_fp8: + nvcc_flags.extend(nvcc_flags_fp8) + if enable_bf16: + nvcc_flags.append("-DFLASHINFER_ENABLE_BF16") + +for flag in [ + "-D__CUDA_NO_HALF_OPERATORS__", + "-D__CUDA_NO_HALF_CONVERSIONS__", + "-D__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-D__CUDA_NO_HALF2_OPERATORS__", +]: + try: + torch.utils.cpp_extension.COMMON_NVCC_FLAGS.remove(flag) + except ValueError: + pass + cxx_flags = ["-O3"] -libraries = ["c10", "torch", "torch_python"] -extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"] +libraries = ["c10", "torch", "torch_python", "cuda", "cublas", "cublasLt"] +extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] + ext_modules = [ CUDAExtension( name="sgl_kernel.ops._kernels", - sources=[ - "src/sgl-kernel/csrc/trt_reduce_internal.cu", - "src/sgl-kernel/csrc/trt_reduce_kernel.cu", - "src/sgl-kernel/csrc/moe_align_kernel.cu", - "src/sgl-kernel/csrc/int8_gemm_kernel.cu", - "src/sgl-kernel/csrc/sgl_kernel_ops.cu", - ], + sources=sources, include_dirs=include_dirs, extra_compile_args={ "nvcc": nvcc_flags, @@ -59,17 +144,18 @@ def update_wheel_platform_tag(): }, libraries=libraries, extra_link_args=extra_link_args, + py_limited_api=True, ), ] setup( name="sgl-kernel", - version=get_version(), - packages=["sgl_kernel"], + version=_get_version(), + packages=find_packages(), package_dir={"": "src"}, ext_modules=ext_modules, cmdclass={"build_ext": BuildExtension}, - install_requires=["torch"], + options={"bdist_wheel": {"py_limited_api": "cp39"}}, ) -update_wheel_platform_tag() +_update_wheel_platform_tag() diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 892808f1ee15..df141dee1d07 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -1,15 +1,51 @@ from sgl_kernel.ops import ( + bmm_fp8, custom_dispose, custom_reduce, + fp8_scaled_mm, + fused_add_rmsnorm, + gelu_and_mul, + gelu_tanh_and_mul, + 
gemma_fused_add_rmsnorm, + gemma_rmsnorm, + get_graph_buffer_ipc_meta, init_custom_reduce, int8_scaled_mm, + lightning_attention_decode, + min_p_sampling_from_probs, moe_align_block_size, + register_graph_buffers, + rmsnorm, + rotary_embedding, + sampling_scaling_penalties, + silu_and_mul, + top_k_renorm_prob, + top_k_top_p_sampling_from_probs, + top_p_renorm_prob, ) __all__ = [ - "moe_align_block_size", - "init_custom_reduce", + "bmm_fp8", "custom_dispose", "custom_reduce", + "fp8_scaled_mm", + "fused_add_rmsnorm", + "gelu_and_mul", + "gelu_tanh_and_mul", + "gemma_fused_add_rmsnorm", + "gemma_rmsnorm", + "get_graph_buffer_ipc_meta", + "init_custom_reduce", "int8_scaled_mm", + "lightning_attention_decode", + "min_p_sampling_from_probs", + "moe_align_block_size", + "register_graph_buffers", + "rmsnorm", + "rotary_embedding", + "sampling_scaling_penalties", + "silu_and_mul", + "top_k_renorm_prob", + "top_k_top_p_sampling_from_probs", + "top_p_renorm_prob", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h index a9deeb9a7da7..c83cf49ad830 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h @@ -3,11 +3,8 @@ #pragma once -#include "cutlass/arch/memory.h" -#include "cutlass/arch/memory_sm75.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/numeric_conversion.h" +#include +#include namespace cutlass { namespace epilogue { diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h index 10be552a8ec2..33e82decc2b2 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_universal_base_compat.h @@ -2,16 +2,9 @@ // https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h #pragma once -#include "cutlass/arch/arch.h" -#include "cutlass/cutlass.h" -#include "cutlass/device_kernel.h" -#include "cutlass/gemm/device/default_gemm_configuration.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/gemm/kernel/default_gemm_universal.h" -#include "cutlass/gemm/kernel/gemm_universal.h" -#include "cutlass/gemm/threadblock/threadblock_swizzle.h" -#include "cutlass/numeric_types.h" -#include "cutlass/trace.h" +#include +#include +#include //////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h index cf0b9cfa3e97..674e191a077f 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h +++ b/sgl-kernel/src/sgl-kernel/csrc/cutlass_extensions/gemm/gemm_with_epilogue_visitor.h @@ -3,14 +3,11 @@ #pragma once -#include "cutlass/complex.h" -#include "cutlass/cutlass.h" -#include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" -#include "cutlass/matrix_coord.h" -#include "cutlass/semaphore.h" -#include "cutlass/trace.h" -#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h" 
+#include +#include +#include +#include +#include ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu new file mode 100644 index 000000000000..3e33e143c0ce --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu @@ -0,0 +1,624 @@ +// Adapted from +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +using namespace cute; + +#if defined CUDA_VERSION && CUDA_VERSION >= 12040 +template typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT, + typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>> +struct DeviceGemmFp8RowwiseSm89 { + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + using ElementA = ElementType; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementType; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = OutElementType; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementOutput = OutElementType; + using LayoutOutput = cutlass::layout::RowMajor; + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = AccumElementType; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm89; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + // Number of epilogue stages in EVT + static constexpr int EVTEpilogueStages = 1; + + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout; + + // Definition of EVT + using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementComputeEpilogue, ElementComputeEpilogue, cutlass::FloatRoundStyle::round_to_nearest>; + using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; + using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeAScale = + cutlass::epilogue::threadblock::VisitorCompute; + using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast>; + using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT; + + // With bias + using biasSrc = + cutlass::epilogue::threadblock::VisitorRowBroadcast>; + using ComputeAScaleWithBias = + cutlass::epilogue::threadblock::VisitorCompute; + using EpilogueAScaleWithBias = + cutlass::epilogue::threadblock::Sm80EVT; + + using dTar = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, 
ElementC, cutlass::FloatRoundStyle::round_to_nearest, Stride>; + using EpilogueStore = + typename cutlass::platform::conditional, + cutlass::epilogue::threadblock::Sm80EVT>::type; + + using EpilogueOp = EpilogueStore; + + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, ElementB, LayoutB, + cutlass::ComplexTransform::kNone, AlignmentB, ElementC, LayoutC, AlignmentC, ElementAccumulator, + ElementComputeEpilogue, OperatorClass, ArchTag, CtaShape, WarpShape, InstructionShape, EpilogueOp, + ThreadblockSwizzle, Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); + ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput const* ptr_bias = nullptr; + if constexpr (WithBias) { + TORCH_CHECK(bias.has_value()) + ptr_bias = reinterpret_cast(bias.value().data_ptr()); + } + ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); + ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); + ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr()); + + typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode + {m, n, k}, // Problem size + 1, // Split-k factor + {}, // Epilogue args + ptr_a, // a pointer + ptr_b, // b pointer + nullptr, // c pointer (unused) + nullptr, // d pointer (unused) + m * k, // batch stride a (unused) + n * k, // batch stride b (unused) + m * n, // batch stride c (unused) + m * n, // batch stride d (unused) + lda, // stride a + ldb, // stride b + ldc, // stride c (unused) + ldc); // stride d (unused) + if constexpr (WithBias) { + args.epilogue = {{ + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; + } else { + args.epilogue = {{ + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; + } + + return args; +} + +template +void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + auto args = prepare_sm89_fp8_args(out, a, b, scales_a, scales_b, bias); + Gemm gemm_op; + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + 
TORCH_CHECK(can_implement == cutlass::Status::kSuccess) + + auto status = gemm_op(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess) +} + +template +void sm89_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + if (bias) { + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } else { + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } +} + +template +void sm89_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + uint32_t const m = a.size(0); + uint32_t const n = out.size(1); + + if (m == 1) { + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 16) { + // M in (1, 16] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 64) { + // M in (16, 64] + if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 128) { + // M in (64, 128] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 256) { + // M in (128, 256] + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 7>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 128>, 4>(out, a, b, scales_a, scales_b, bias); + } + } else if (m <= 512) { + // M in (256, 512) + if (n <= 16384) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } + } else { + // M in (512, inf) + if (n <= 8192) { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 3>(out, a, b, scales_a, scales_b, bias); + } else { + return sm89_fp8_dispatch_bias, + cutlass::gemm::GemmShape<64, 32, 64>, 2>(out, a, b, 
scales_a, scales_b, bias); + } + } +} +#endif + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 +template +struct DeviceGemmFp8RowwiseSm90 { + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + // A matrix configuration + using ElementA = ElementType; // Element type for A matrix operand + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = ElementType; // Element type for B matrix operand + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementC = void; // Element type for C matrix operands + using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrices in + // units of elements (up to 16 bytes) + + // Output matrix configuration + using ElementOutput = OutElementType; // Element type for output matrix operands + using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + // // Auxiliary matrix configuration and other fusion types + // using ElementBias = float; + + // Multiply-accumulate blocking/pipelining details + using ElementAccumulator = AccumElementType; // Element type for internal accumulation + using ElementCompute = float; // Element type for compute + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size + + static constexpr bool PONG = false; + static constexpr bool FAST_ACCUM = true; + static constexpr bool USE_BIAS = false; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default + // setting in the Collective Builder + // Implement rowwise scaling epilogue. 
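+  // The epilogue below is expressed as a CUTLASS epilogue visitor tree (EVT): XScale column-broadcasts
+  // the per-row (per-M) activation scales, WScale row-broadcasts the per-column (per-N) weight scales,
+  // and the Compute nodes fold both into the FP32 accumulator, optionally adding a per-column bias.
+  // Per output element the tree evaluates (a plain-C++ sketch with illustrative names, not part of
+  // the kernel):
+  //
+  //   for (int i = 0; i < m; ++i)
+  //     for (int j = 0; j < n; ++j)
+  //       d[i][j] = scale_a[i] * (scale_b[j] * acc[i][j]) + (with_bias ? bias[j] : 0.0f);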
+ using XScale = + cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using WScale = + cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; + + // With bias + using ComputeWithBias = + cutlass::epilogue::fusion::Sm90Compute; + using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = typename cutlass::platform::conditional::type; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementComputeEpilogue, ElementC, LayoutC, + AlignmentC, ElementOutput, LayoutOutput, AlignmentOutput, cutlass::epilogue::TmaWarpSpecialized, + EpilogueEVT>::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + + using SlowAccum = DefaultSchedule; + using FastAccum = FastPongSchedule; // Default apply Pingpong + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA, AlignmentA, ElementB, LayoutB, AlignmentB, ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduleType>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); + ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput const* ptr_bias = nullptr; + if constexpr (WithBias) { + TORCH_CHECK(bias.has_value()) + ptr_bias = reinterpret_cast(bias.value().data_ptr()); + } + ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); + ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); + ElementComputeEpilogue const* ptr_scales_b = 
reinterpret_cast(scales_b.data_ptr()); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1)); + StrideC stride_c; + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); + typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {ptr_a, stride_a, ptr_b, stride_b}, + {{}, // epilogue.thread + nullptr, + stride_c, + ptr_d, + stride_d}}; + if constexpr (WithBias) { + args.epilogue.thread = { + {ptr_scales_a}, + { + {ptr_scales_b}, + {}, // Accumulator + {} // Multiplies + }, + {ptr_bias}, + {}, // Multiplies + }; + } else { + args.epilogue.thread = { + {ptr_scales_a}, + { + {ptr_scales_b}, + {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + + return args; +} + +template +void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + auto args = prepare_sm90_fp8_args(out, a, b, scales_a, scales_b, bias); + Gemm gemm_op; + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess) + + auto status = gemm_op.run(args, workspace.data_ptr(), stream); + + TORCH_CHECK(status == cutlass::Status::kSuccess) +} + +template +void sm90_fp8_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias, bool fast_accum = true, + bool use_persistent = false) { + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; + + if (bias) { + using Gemm = + typename DeviceGemmFp8RowwiseSm90::Gemm; + return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } else { + using Gemm = + typename DeviceGemmFp8RowwiseSm90::Gemm; + return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } +} + +template +void sm90_fp8_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + uint32_t const m = a.size(0); + using FastPingpongScheduler = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + using FastBasicScheduler = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using PersistentTileScheduler = cutlass::gemm::PersistentScheduler; + using BasicTileScheduler = void; + if (m <= 1) { + return sm90_fp8_dispatch_bias, Shape<_1, _8, _1>, FastBasicScheduler, + BasicTileScheduler>(out, a, b, scales_a, scales_b, bias); + } + if (m <= 64) { + // m in [1, 64] + return sm90_fp8_dispatch_bias, Shape<_1, _4, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else if (m <= 256) { + // m in (64, 256] + return sm90_fp8_dispatch_bias, Shape<_1, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else if (m <= 1024) { + // m in (256, 1024] + return sm90_fp8_dispatch_bias, 
Shape<_1, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } else { + // m in (1024, inf) + return sm90_fp8_dispatch_bias, Shape<_2, _1, _1>, FastPingpongScheduler, + PersistentTileScheduler>(out, a, b, scales_a, scales_b, bias); + } +} +#endif + +torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias) { + TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); + TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); + TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); + TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor"); + TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); + TORCH_CHECK(mat_b.stride(0) == 1, "mat_b must be a column major tensor"); + TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); + + TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, + "mat_a must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, + "mat_b must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); + TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); + TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); + + TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched"); + TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched"); + TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous"); + TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous"); + TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32"); + TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32"); + + if (bias) { + TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched"); + TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous"); + TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype"); + } + + torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype)); + TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment"); + + auto sm_version = getSMVersion(); + +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + if (sm_version >= 90) { + if (out_dtype == torch::kBFloat16) { + sm90_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm90_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } + return out; + } +#endif + +#if defined CUDA_VERSION && CUDA_VERSION >= 12040 + if (sm_version == 89) { + if (out_dtype == torch::kBFloat16) { + sm89_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm89_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } + return out; + } +#endif + + TORCH_CHECK_NOT_IMPLEMENTED(false, "fp8_scaled_mm is not implemented for the current compute capability: ", sm_version); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu index cce32c2d894a..c77851c32b61 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/int8_gemm_kernel.cu @@ -3,12 
+3,22 @@ #include #include #include +#include #include +#include +#include +#include +#include +#include +#include + #include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h" #include "cutlass_extensions/gemm/gemm_universal_base_compat.h" #include "cutlass_extensions/gemm/gemm_with_epilogue_visitor.h" -#include "utils.hpp" +#include "utils.h" + +using namespace cute; template @@ -166,6 +176,186 @@ void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const t } } +template +void cutlass_int8_scaled_mm_sm90(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + using ArchTag = cutlass::arch::Sm90; + + using ElementAccumulator = int32_t; + using ElementCompute = float; + using ElementInputA = int8_t; + using ElementInputB = int8_t; + + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; + using TileSchedulerType = cutlass::gemm::PersistentScheduler; + + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementCompute, ElementCompute, + Stride, Int<0>, Int<0>>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementCompute, ElementCompute, + Stride, Int<1>, Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, + Stride, Int<1>, Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + // Scale + using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; + + // With bias + using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute; + using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = typename cutlass::platform::conditional::type; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, ElementOutput, cutlass::layout::RowMajor, AlignmentC, ElementOutput, + cutlass::layout::RowMajor, AlignmentOutput, EpilogueScheduleType, EpilogueEVT>::CollectiveOp; + + using Stages = cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementInputA, cutlass::layout::RowMajor, AlignmentA, ElementInputB, + cutlass::layout::ColumnMajor, AlignmentB, ElementAccumulator, TileShape, ClusterShape, Stages, + MainloopScheduleType>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + Gemm gemm_op; + + int m = mat_a.size(0); + int k = mat_a.size(1); + int n = mat_b.size(1); + + auto a_ptr = static_cast(mat_a.data_ptr()); + auto b_ptr = static_cast(mat_b.data_ptr()); + 
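+  // Layout reminder: mat_a is row-major [m, k] int8 and mat_b is column-major (stride(0) == 1), which
+  // is why its packed stride below is built from make_shape(n, k, 1); scales_a supplies one float per
+  // output row (ColBroadcast) and scales_b one float per output column (RowBroadcast).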
auto o_ptr = static_cast(out.data_ptr()); + + auto a_s_ptr = static_cast(scales_a.data_ptr()); + auto b_s_ptr = static_cast(scales_b.data_ptr()); + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1)); + StrideC stride_c; + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); + + typename Gemm::Arguments args = {cutlass::gemm::GemmUniversalMode::kGemm, + {m, n, k, 1}, + {a_ptr, stride_a, b_ptr, stride_b}, + {{}, // epilogue.thread + nullptr, + stride_c, + o_ptr, + stride_d}}; + + if constexpr (WithBias) { + ElementOutput* bias_ptr = static_cast(bias->data_ptr()); + args.epilogue.thread = { + {a_s_ptr}, + {{b_s_ptr}, {}, {}}, + {bias_ptr}, + {}, + }; + } else { + args.epilogue.thread = { + {a_s_ptr}, + {{b_s_ptr}, {}, {}}, + {}, + }; + } + + auto workspace = torch::empty(gemm_op.get_workspace_size(args), + torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device())); + + auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess, + "gemm cannot implement, error: ", cutlassGetStatusString(can_implement)); + + auto status = gemm_op(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status)); +} + +template +void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + if (bias) { + cutlass_int8_scaled_mm_sm90( + out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + cutlass_int8_scaled_mm_sm90( + out, mat_a, mat_b, scales_a, scales_b, bias); + } +} + +template +void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias) { + int m = mat_a.size(0); + int n = mat_b.size(1); + if (m <= 32) { + if (n < 8192) { + return sm90_dispatch_bias, Shape<_1, _8, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + return sm90_dispatch_bias, Shape<_1, _8, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 64) { + if (n < 8192) { + return sm90_dispatch_bias, Shape<_1, _4, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + return sm90_dispatch_bias, Shape<_1, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (m <= 128) { + if (n <= 4096) { + return sm90_dispatch_bias, Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + return sm90_dispatch_bias, Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecialized>(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else { + return sm90_dispatch_bias, Shape<_2, _1, _1>, + cutlass::gemm::KernelTmaWarpSpecializedPingpong>(out, mat_a, mat_b, scales_a, scales_b, + bias); + } +} + torch::Tensor int8_scaled_mm(const 
torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const c10::optional& bias) { @@ -204,7 +394,24 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma TORCH_CHECK(out_dtype == torch::kHalf, "out_dtype must be Half for SM75"); sm75_dispatch_shape>( out, mat_a, mat_b, scales_a, scales_b, bias); - } else if (sm_version >= 80 && sm_version <= 90) { + } else if (sm_version >= 80 && sm_version < 90) { + if (out_dtype == torch::kBFloat16) { + sm80_dispatch_shape>( + out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm80_dispatch_shape>( + out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (sm_version == 90) { +#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + // cutlass 3.x + if (out_dtype == torch::kBFloat16) { + sm90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } +#else + // fallback to cutlass 2.x if (out_dtype == torch::kBFloat16) { sm80_dispatch_shape>( out, mat_a, mat_b, scales_a, scales_b, bias); @@ -212,6 +419,7 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma sm80_dispatch_shape>( out, mat_a, mat_b, scales_a, scales_b, bias); } +#endif } else { TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability."); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu new file mode 100644 index 000000000000..e62a154cb183 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/lightning_attention_decode_kernel.cu @@ -0,0 +1,118 @@ +#include +#include +#include +#include +#include +#include + +#define THREADS_PER_BLOCK 128 + +template +__global__ void lightning_attention_decode_kernel(const T* __restrict__ q, // [b, h, 1, d] + const T* __restrict__ k, // [b, h, 1, d] + const T* __restrict__ v, // [b, h, 1, e] + const float* __restrict__ past_kv, // [b, h, d, e] + const float* __restrict__ slope, // [h, 1, 1] + T* __restrict__ output, // [b, h, 1, e] + float* __restrict__ new_kv, // [b, h, d, e] + const int batch_size, const int num_heads, const int qk_dim, + const int v_dim) { + extern __shared__ char smem[]; + T* q_shared = reinterpret_cast(smem); + T* k_shared = reinterpret_cast(smem + qk_dim * sizeof(T)); + T* v_shared = reinterpret_cast(smem + 2 * qk_dim * sizeof(T)); + float* new_kv_shared = reinterpret_cast(smem + (2 * qk_dim + v_dim) * sizeof(T)); + T* output_shared = + reinterpret_cast(smem + (2 * qk_dim + v_dim) * sizeof(T) + qk_dim * (v_dim + 1) * sizeof(float)); + + const int32_t tid = threadIdx.x; + const int32_t current_head = blockIdx.x; + const int32_t b = current_head / num_heads; + const int32_t h = current_head % num_heads; + + if (b >= batch_size) return; + + const int32_t qk_offset = b * num_heads * qk_dim + h * qk_dim; + const int32_t v_offset = b * num_heads * v_dim + h * v_dim; + const int32_t kv_offset = b * num_heads * qk_dim * v_dim + h * qk_dim * v_dim; + + for (int d = tid; d < qk_dim; d += blockDim.x) { + q_shared[d] = q[qk_offset + d]; + k_shared[d] = k[qk_offset + d]; + } + for (int e = tid; e < v_dim; e += blockDim.x) { + v_shared[e] = v[v_offset + e]; + } + + __syncthreads(); + + const float ratio = expf(-1.0f * slope[h]); + + for (int d = tid; d < qk_dim; d += blockDim.x) { + T k_val = k_shared[d]; + for (int e = 0; e < v_dim; ++e) { + int past_kv_idx = kv_offset 
+ d * v_dim + e; + T v_val = v_shared[e]; + float new_val = ratio * past_kv[past_kv_idx] + k_val * v_val; + int shared_idx = d * (v_dim + 1) + e; + new_kv_shared[shared_idx] = new_val; + } + } + + __syncthreads(); + + for (int idx = tid; idx < qk_dim * v_dim; idx += blockDim.x) { + int d = idx / v_dim; + int e = idx % v_dim; + int shared_idx = d * (v_dim + 1) + e; + int global_idx = kv_offset + idx; + new_kv[global_idx] = new_kv_shared[shared_idx]; + } + + __syncthreads(); + + for (int e = tid; e < v_dim; e += blockDim.x) { + float sum = 0.0f; + for (int d = 0; d < qk_dim; ++d) { + int shared_idx = d * (v_dim + 1) + e; + sum += q_shared[d] * new_kv_shared[shared_idx]; + } + output_shared[e] = static_cast(sum); + } + + __syncthreads(); + + if (tid == 0) { + for (int e = 0; e < v_dim; ++e) { + output[v_offset + e] = output_shared[e]; + } + } +} + +void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, + const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, + torch::Tensor new_kv) { + TORCH_CHECK(q.is_contiguous(), "q must be contiguous"); + TORCH_CHECK(k.is_contiguous(), "k must be contiguous"); + TORCH_CHECK(v.is_contiguous(), "v must be contiguous"); + TORCH_CHECK(past_kv.is_contiguous(), "past_kv must be contiguous"); + + auto batch_size = q.size(0); + auto num_heads = q.size(1); + auto qk_dim = q.size(3); + auto v_dim = v.size(3); + + dim3 block(THREADS_PER_BLOCK); + dim3 grid(batch_size * num_heads); + + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, q.scalar_type(), "lightning_attention_decode_kernel", ([&] { + size_t smem_size = (2 * qk_dim + 2 * v_dim) * sizeof(scalar_t) + qk_dim * (v_dim + 1) * sizeof(float); + lightning_attention_decode_kernel<<>>( + q.data_ptr(), k.data_ptr(), v.data_ptr(), past_kv.data_ptr(), + slope.data_ptr(), output.data_ptr(), new_kv.data_ptr(), batch_size, num_heads, + qk_dim, v_dim); + })); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu index c7faf9d37758..19e9850b51a9 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu @@ -3,28 +3,14 @@ #include #include #include +#include #include -#include "utils.hpp" - -#ifdef USE_ROCM -#include -#endif - -#ifndef USE_ROCM #define WARP_SIZE 32 -#else -#define WARP_SIZE warpSize -#endif -#ifndef USE_ROCM #define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) -#else -#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ - hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) -#endif #define CEILDIV(x, y) (((x) + (y)-1) / (y)) @@ -39,7 +25,6 @@ AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) __device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, int32_t col) { - // don't worry about overflow because num_experts is relatively small return row * total_col + col; } diff --git a/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu b/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu new file mode 100644 index 000000000000..1dd4c4c52440 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/rotary_embedding.cu @@ -0,0 +1,119 @@ +// Reference: https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu + +#include +#include +#include + 
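+// The kernels below apply rotary position embedding either GPT-NeoX style (the rotated pair is
+// (x[i], x[i + rot_dim/2])) or GPT-J style (interleaved pairs (x[2i], x[2i+1])). The helper below is
+// only an illustrative host-side reference of that per-head math (hypothetical name, never called by
+// the kernels), kept here to make the index arithmetic easier to follow:
+static inline void rotary_reference_per_head(float* head, const float* cos_cache, const float* sin_cache,
+                                             int embed_dim, bool is_neox) {
+  // cos_cache / sin_cache hold embed_dim = rot_dim / 2 values for the current position.
+  for (int i = 0; i < embed_dim; ++i) {
+    const int x_idx = is_neox ? i : 2 * i;
+    const int y_idx = is_neox ? embed_dim + i : 2 * i + 1;
+    const float x = head[x_idx];
+    const float y = head[y_idx];
+    head[x_idx] = x * cos_cache[i] - y * sin_cache[i];
+    head[y_idx] = y * cos_cache[i] + x * sin_cache[i];
+  }
+}
+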
+template +inline __device__ void apply_token_rotary_embedding(scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, int rot_offset, + int embed_dim) { + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = __ldg(cos_ptr + x_index); + sin = __ldg(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. + x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = __ldg(cos_ptr + x_index / 2); + sin = __ldg(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +inline __device__ void apply_rotary_embedding(scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* cache_ptr, const int head_size, const int num_heads, + const int num_kv_heads, const int rot_dim, const int token_idx, + const int64_t query_stride, const int64_t key_stride) { + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } +} + +template +__global__ void rotary_embedding_kernel(const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { + // Each thread block is responsible for one token. 
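+  // One block is launched per token (see dim3 grid(num_tokens) in the host wrapper); threads in the
+  // block stride over (head, rotation-offset) pairs, first for the query heads and then for the key
+  // heads. Each row of cos_sin_cache packs rot_dim values: the first rot_dim/2 are cosines and the
+  // last rot_dim/2 are sines, which is why apply_rotary_embedding splits cache_ptr at embed_dim.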
+ const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding(query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); +} + +void rotary_embedding(torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int64_t head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { + int64_t num_tokens = query.numel() / query.size(-1); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t query_stride = query.stride(-2); + int64_t key_stride = key.stride(-2); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND2( + at::ScalarType::BFloat16, at::ScalarType::Half, query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + rotary_embedding_kernel + <<>>(positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, + query_stride, key_stride, num_heads, num_kv_heads, head_size); + } else { + rotary_embedding_kernel + <<>>(positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, + query_stride, key_stride, num_heads, num_kv_heads, head_size); + } + }); +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu deleted file mode 100644 index 6ed543e6c542..000000000000 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ /dev/null @@ -1,29 +0,0 @@ -#include "utils.hpp" - -// trt_reduce -using fptr_t = int64_t; -fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector& buffers, - const std::vector& barrier_in, const std::vector& barrier_out); -void dispose(fptr_t _fa); -void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); - -// moe_align_block_size -void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, - torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, - torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer); - -// int8_scaled_mm -torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - // trt_reduce - m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); - m.def("dispose", &dispose, "dispose custom allreduce meta"); - m.def("all_reduce", &all_reduce, "custom all reduce (CUDA)"); - // moe_align_block_size - m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); - // int8_scaled_mm - m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)"); -} diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu index a6f2d5216a18..2ee0c98c91e1 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu +++ 
b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cu @@ -26,6 +26,7 @@ #include #include "trt_reduce_internal.cuh" +#include "utils.h" //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -126,10 +127,10 @@ __inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const __syncthreads(); } +template __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank, - size_t const world_size, int const tidx, int const bidx, int const grid_size, - bool start = true, bool need_fence = false) { - if (!start) { + size_t const world_size, int const tidx, int const bidx, int const grid_size) { + if constexpr (!start) { __syncthreads(); } // After this function, the block of id == bidx of each GPU has reached the barrier @@ -141,22 +142,16 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag // Block broadcast its flag (local_rank on emitting dimension) to all receivers uint32_t flag_block_offset = world_size + bidx * world_size; - if (flag % 2 == 1) { - flag_block_offset += (grid_size + 1) * world_size; - } + flag_block_offset += (grid_size + 1) * world_size * (flag % 2); - if (need_fence) { - st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank); - } else { - st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank); - } - // Blocks check that corresponding blocks on other GPUs have also set the flag uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx; - - if (need_fence) { + // Blocks check that corresponding blocks on other GPUs have also set the flag + if constexpr (need_fence) { + st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank); while (ld_flag_acquire(peer_barrier_d) != flag) { } } else { + st_flag_volatile(flag, signals[tidx] + flag_block_offset + local_rank); while (ld_flag_volatile(peer_barrier_d) != flag) { } } @@ -165,8 +160,8 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag __syncthreads(); } -template /* COPY_INPUT = false, PUSH_MODE = false */ -static __global__ void oneShotAllReduceKernel(AllReduceParams params) { +template +static __global__ void __launch_bounds__(512, 1) oneShotAllReduceKernel(AllReduceParams params) { // Suppose that two GPUs participate in the AR exchange, and we start four blocks. // The message is partitioned into chunks as detailed below: // message @@ -193,6 +188,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { int const bidx = blockIdx.x; int const tidx = threadIdx.x; + int const grid_size = gridDim.x; // The number of elements packed into one for comms static constexpr int NUM_ELTS = 16 / sizeof(T); @@ -201,18 +197,23 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { using PackedStruct = typename PackedOn16Bytes::Type; // The source pointers. Distributed round-robin for the different warps. 
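// Updated flow: (1) when COPY_INPUT is set, each rank first copies its chunk from the local input
// into its own shareable buffer; (2) block_barrier waits until the matching block on every peer GPU
// has done the same; (3) each rank then reads all peers' shareable buffers and accumulates them in
// fixed rank order (rank 0 first), so every GPU computes the same result regardless of its rank.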
- T const* buffers[RANKS_PER_NODE]; - + auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs; + T* local_shared_buffer = reinterpret_cast(peer_comm_buffer_ptrs[params.local_rank]); // Start and end offsets of the thread size_t chunk_start = bidx * params.elts_per_block + tidx * NUM_ELTS; size_t chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank); -#pragma unroll - for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - int rank = (params.local_rank + ii) % RANKS_PER_NODE; - buffers[ii] = reinterpret_cast(params.peer_comm_buffer_ptrs[rank]); - } - multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx); + if constexpr (COPY_INPUT) { + T const* local_input_buffer = reinterpret_cast(params.local_input_buffer_ptr); + // Copy from local buffer to shareable buffer + for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) { + *reinterpret_cast(&local_shared_buffer[iter_offset]) = + *reinterpret_cast(&local_input_buffer[iter_offset]); + } + } + // wait for equivalent blocks of other GPUs to have copied data to their shareable buffer + block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, + grid_size); // Each block accumulates the values from the different GPUs on the same node. for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * NUM_ELTS) { @@ -220,7 +221,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { PackedStruct vals[RANKS_PER_NODE]; #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - vals[ii].packed = *reinterpret_cast(&buffers[ii][iter_offset]); + vals[ii].packed = *reinterpret_cast(&((T*)peer_comm_buffer_ptrs[ii])[iter_offset]); } // Sum the values from the different ranks. @@ -229,8 +230,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { #pragma unroll for (int rank = 0; rank < RANKS_PER_NODE; ++rank) { // Always reduce from rank 0 to ensure stable reduce order. - int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE; - sums.packed = add128b(sums, vals[ii]); + sums.packed = add128b(sums, vals[rank]); } // Store to the destination buffer. @@ -238,7 +238,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params) { } } -template +template static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduceParams params) { // Suppose that two GPUs participate in the AR exchange, and we start two blocks. 
// The message is partitioned into chunks as detailed below: @@ -286,20 +286,24 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc static constexpr int PACKED_ELTS = 16 / sizeof(T); using PackedType = typename PackedOn16Bytes::Type; - T* local_shared_buffer = reinterpret_cast(params.peer_comm_buffer_ptrs[params.local_rank]); + T const* local_input_buffer = reinterpret_cast(params.local_input_buffer_ptr); + auto peer_comm_buffer_ptrs = params.peer_comm_buffer_ptrs->ptrs; + T* local_shared_buffer = reinterpret_cast(peer_comm_buffer_ptrs[params.local_rank]); T* local_output_buffer = reinterpret_cast(params.local_output_buffer_ptr); size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS; size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank); T* buffers[RANKS_PER_NODE]; + T* buffers_unorder[RANKS_PER_NODE]; int ranks[RANKS_PER_NODE]; #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { // A mapping of the ranks to scatter reads as much as possible int rank = (params.local_rank + ii) % RANKS_PER_NODE; ranks[ii] = rank; - buffers[ii] = reinterpret_cast(params.peer_comm_buffer_ptrs[rank]); + buffers[ii] = reinterpret_cast(peer_comm_buffer_ptrs[rank]); + buffers_unorder[ii] = reinterpret_cast(peer_comm_buffer_ptrs[ii]); } #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12)) @@ -308,8 +312,22 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc #endif #endif - block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, - grid_size); + if constexpr (COPY_INPUT) { + // Copy all blocks from local buffer to shareable buffer + for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { +#pragma unroll + for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { + size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset; + if (offset_rank >= params.elts_total) { + continue; + } + *reinterpret_cast(&local_shared_buffer[offset_rank]) = + *reinterpret_cast(&local_input_buffer[offset_rank]); + } + } + } + block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, + grid_size); // Each block accumulates the values from the different GPUs on the same node. for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { @@ -319,7 +337,7 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc PackedType vals[RANKS_PER_NODE]; #pragma unroll for (int ii = 0; ii < RANKS_PER_NODE; ++ii) { - vals[ii].packed = *reinterpret_cast(&buffers[ii][responsible_block_offset]); + vals[ii].packed = *reinterpret_cast(&buffers_unorder[ii][responsible_block_offset]); } // Sum the values from the different ranks. @@ -328,16 +346,19 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc #pragma unroll for (int rank = 0; rank < RANKS_PER_NODE; ++rank) { // Always reduce from rank 0 to ensure stable reduce order. - int ii = (rank + RANKS_PER_NODE - params.local_rank) % RANKS_PER_NODE; - sums.packed = add128b(sums, vals[ii]); + sums.packed = add128b(sums, vals[rank]); } - // Store to the local buffer. 
- *reinterpret_cast(&local_shared_buffer[responsible_block_offset]) = sums.packed; + // Store to the local buffer or tmp buffer + if constexpr (COPY_INPUT) { + *reinterpret_cast(&local_shared_buffer[responsible_block_offset]) = sums.packed; + } else { + *reinterpret_cast(¶ms.tmp_result_buffers[params.local_rank][responsible_block_offset]) = sums.packed; + } } - block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx, - grid_size, false, true); + block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, + bidx, grid_size); // Gather all needed elts from other intra-node ranks for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS) { @@ -348,8 +369,13 @@ static __global__ void __launch_bounds__(512, 1) twoShotAllReduceKernel(AllReduc if (offset_rank >= params.elts_total) { continue; } - - *reinterpret_cast(&local_output_buffer[offset_rank]) = *reinterpret_cast(&buffers[ii][offset_rank]); + if constexpr (COPY_INPUT) { + *reinterpret_cast(&local_output_buffer[offset_rank]) = + *reinterpret_cast(&buffers[ii][offset_rank]); + } else { + *reinterpret_cast(&local_output_buffer[offset_rank]) = + *reinterpret_cast(¶ms.tmp_result_buffers[ranks[ii]][offset_rank]); + } } } #if (defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 12)) @@ -417,48 +443,50 @@ std::tuple kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template void dispatchARKernels(AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block, cudaStream_t stream) { switch (algo) { case AllReduceStrategyType::ONESHOT: { - oneShotAllReduceKernel<<>>(param); + oneShotAllReduceKernel<<>>(param); break; } case AllReduceStrategyType::TWOSHOT: { - twoShotAllReduceKernel<<>>(param); + twoShotAllReduceKernel<<>>(param); break; } } } -template -void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream) { - void* buffer = reinterpret_cast(param.peer_comm_buffer_ptrs[param.rank]); - void* local_inp_buffer = param.local_input_buffer_ptr; - CHECK_CUDA_SUCCESS( - cudaMemcpyAsync(buffer, local_inp_buffer, param.elts_total * param.elts_size, cudaMemcpyDeviceToDevice, stream)); - - CHECK_CUDA_SUCCESS(cudaGetLastError()); - +template +void dispatchARKernelsCopyInput(AllReduceStrategyType strat, AllReduceParams& param, cudaStream_t stream) { size_t elts_per_thread = 16 / sizeof(T); auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread); switch (param.ranks_per_node) { case 2: - dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); break; case 4: - dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); break; case 6: - dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); break; case 8: - dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); + dispatchARKernels(strat, param, blocks_per_grid, threads_per_block, stream); break; default: break; } +} + +template +void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, 
AllReduceStrategyType strat, cudaStream_t stream) { + if (param.is_capturing) { + dispatchARKernelsCopyInput(strat, param, stream); + } else { + dispatchARKernelsCopyInput(strat, param, stream); + } CHECK_CUDA_SUCCESS(cudaGetLastError()); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu index 59b548c77e9e..fd0483e39eed 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_kernel.cu @@ -3,34 +3,53 @@ #include #include -#include -#include -#include #include "trt_reduce_internal.cuh" +#include "utils.h" using namespace trt_llm; using fptr_t = int64_t; +using IPC_KEY = std::array; class AllReduceMeta { public: - AllReduceMeta(int64_t rank_id, int64_t world_size, const std::vector& buffers, - const std::vector& barrier_in, const std::vector& barrier_out) { + AllReduceMeta(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, + const std::vector& tmp_result_buffers, const std::vector& barrier_in, + const std::vector& barrier_out) { this->rank_id = (int)rank_id; this->world_size = (int)world_size; - this->buffers = buffers; this->barrier_in = barrier_in; this->barrier_out = barrier_out; + this->tmp_result_buffers = tmp_result_buffers; + + this->rank_data_base = reinterpret_cast(rank_data.data_ptr()); + RankData data; + for (int i = 0; i < world_size; i++) { + data.ptrs[i] = (void*)buffers[i]; + } + auto d_data = this->rank_data_base++; + CHECK_CUDA_SUCCESS(cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); + this->buffers = d_data; + } + + ~AllReduceMeta() { + for (auto [_, ptr] : ipc_handles_) { + CHECK_CUDA_SUCCESS(cudaIpcCloseMemHandle(ptr)); + } } public: int world_size; int rank_id; - std::vector buffers; std::vector barrier_in; std::vector barrier_out; + std::vector tmp_result_buffers; int barrier_flag = 1; + RankData* buffers; + RankData* rank_data_base; + std::vector graph_unreg_buffers; + std::map ipc_handles_; }; // Get the number of bits for a given data type. 
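// CanApplyCustomAllReduce below encodes the fact that the reduce kernels move data in 16-byte packed
// vectors: the element count must be divisible by 16 / sizeof(dtype) (e.g. a multiple of 8 for
// fp16/bf16), and callers are expected to take another all-reduce path for sizes that do not satisfy
// this.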
@@ -52,9 +71,10 @@ inline bool CanApplyCustomAllReduce(int64_t num_elements, at::ScalarType dtype) return num_elements % (16 / ((get_bits(dtype) + 7) / 8)) == 0; } -fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, const std::vector& buffers, - const std::vector& barrier_in, const std::vector& barrier_out) { - auto m = new AllReduceMeta(rank_id, world_size, buffers, barrier_in, barrier_out); +fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, + const std::vector& tmp_result_buffers, const std::vector& barrier_in, + const std::vector& barrier_out) { + auto m = new AllReduceMeta(rank_id, world_size, rank_data, buffers, tmp_result_buffers, barrier_in, barrier_out); return (fptr_t)m; } @@ -63,6 +83,75 @@ void dispose(fptr_t _fa) { delete fa; } +std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa) { + AllReduceMeta* m = reinterpret_cast(_fa); + auto num_buffers = m->graph_unreg_buffers.size(); + auto handle_sz = sizeof(cudaIpcMemHandle_t); + std::string handles(handle_sz * num_buffers, static_cast(0)); + std::vector offsets(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto ptr = m->graph_unreg_buffers[i]; + void* base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (cuPointerGetAttribute(&base_ptr, CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, (CUdeviceptr)ptr) != CUDA_SUCCESS) { + assert(false && "failed to get pointer attr"); + } + + CHECK_CUDA_SUCCESS(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); + } + std::vector bytes(handles.begin(), handles.end()); + return std::make_pair(bytes, offsets); +} + +char* open_ipc_handle(AllReduceMeta* meta, const void* ipc_handle) { + auto [it, new_handle] = meta->ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); + if (new_handle) { + char* ipc_ptr; + CHECK_CUDA_SUCCESS(cudaIpcOpenMemHandle((void**)&ipc_ptr, *((const cudaIpcMemHandle_t*)ipc_handle), + cudaIpcMemLazyEnablePeerAccess)); + it->second = ipc_ptr; + } + return it->second; +} + +// Note: when registering graph buffers, we intentionally choose to not +// deduplicate the addresses. That means if the allocator reuses some +// addresses, they will be registered again. This is to account for the remote +// possibility of different allocation patterns between ranks. For example, +// rank 1 may get the same input address for the second allreduce, but rank 2 +// got a different address. IPC handles have internal reference counting +// mechanism so overhead should be small. 
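+// Typical flow: while a CUDA graph is being captured, all_reduce() records each unregistered input
+// pointer in graph_unreg_buffers instead of dereferencing it; after capture every rank calls
+// get_graph_buffer_ipc_meta(), the resulting (IPC handle, offset) pairs are exchanged across ranks by
+// the caller, and register_graph_buffers() opens the peers' handles and appends one RankData entry
+// per recorded buffer so that graph replays see valid peer pointers.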
+void register_graph_buffers(fptr_t _fa, const std::vector>& handles, + const std::vector>& offsets) { + AllReduceMeta* m = reinterpret_cast(_fa); + std::vector handle_bytes; + handle_bytes.reserve(handles.size()); + for (int i = 0; i < handles.size(); i++) { + handle_bytes.emplace_back(handles[i].begin(), handles[i].end()); + } + auto num_buffers = m->graph_unreg_buffers.size(); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto self_ptr = m->graph_unreg_buffers[i]; + auto& rd = rank_data[i]; + for (int j = 0; j < m->world_size; j++) { + if (j != m->rank_id) { + char* handle = open_ipc_handle(m, &handle_bytes[j][i * sizeof(cudaIpcMemHandle_t)]); + handle += offsets[j][i]; + rd.ptrs[j] = handle; + } else { + rd.ptrs[j] = self_ptr; + } + } + } + CHECK_CUDA_SUCCESS( + cudaMemcpy(m->rank_data_base, rank_data.data(), sizeof(RankData) * num_buffers, cudaMemcpyHostToDevice)); + m->rank_data_base += num_buffers; + m->graph_unreg_buffers.clear(); +} + void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { AllReduceMeta* m = reinterpret_cast(_fa); auto stream = c10::cuda::getCurrentCUDAStream().stream(); @@ -87,8 +176,18 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { params.elts_size = inp.element_size(); params.barrier_flag = ++(m->barrier_flag); + cudaStreamCaptureStatus status; + CHECK_CUDA_SUCCESS(cudaStreamIsCapturing(stream, &status)); + params.is_capturing = (status == cudaStreamCaptureStatusActive); + if (params.is_capturing) { + params.peer_comm_buffer_ptrs = m->rank_data_base + m->graph_unreg_buffers.size(); + m->graph_unreg_buffers.push_back(params.local_input_buffer_ptr); + } else { + params.peer_comm_buffer_ptrs = m->buffers; + } + for (int i = 0; i < world_size; ++i) { - params.peer_comm_buffer_ptrs[i] = reinterpret_cast(m->buffers[i]); + params.tmp_result_buffers[i] = reinterpret_cast(m->tmp_result_buffers[i]); } for (int i = 0; i < world_size; ++i) { params.peer_barrier_ptrs_in[i] = reinterpret_cast(m->barrier_in[i]); diff --git a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h new file mode 100644 index 000000000000..93c53c1e9e4d --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -0,0 +1,114 @@ +#pragma once + +#include +#include + +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } + +// trt_reduce +using fptr_t = int64_t; +fptr_t init_custom_ar(int64_t rank_id, int64_t world_size, torch::Tensor& rank_data, const std::vector& buffers, + const std::vector& tmp_result_buffers, const std::vector& barrier_in, + const std::vector& barrier_out); +void dispose(fptr_t _fa); +void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); +std::tuple, std::vector> get_graph_buffer_ipc_meta(fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector>& handles, + const std::vector>& offsets); + +// moe_align_block_size +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, + torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad, + torch::Tensor token_cnts_buffer, 
torch::Tensor cumsum_buffer); + +// int8_scaled_mm +torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias); + +// fp8_scaled_mm +torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias); + +// lightning_attention_decode +void lightning_attention_decode(const torch::Tensor& q, const torch::Tensor& k, const torch::Tensor& v, + const torch::Tensor& past_kv, const torch::Tensor& slope, torch::Tensor output, + torch::Tensor new_kv); + +// rotary embedding +void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, torch::Tensor& key, int64_t head_size, + torch::Tensor& cos_sin_cache, bool is_neox); + +// rms norm +void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); + +// fused rms norm +void fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream); + +// gemma rms norm +void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream); + +// fused gemma rms norm +void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, + int64_t cuda_stream); + +// silu and mul +void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// gelu tanh and mul +void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// gelu and mul +void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream); + +// bmm fp8 +void bmm_fp8(at::Tensor A, at::Tensor B, at::Tensor D, at::Tensor A_scale, at::Tensor B_scale, + at::Tensor workspace_buffer, int64_t cublas_handle, int64_t cuda_stream); + +// min p sampling from probs +void min_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + std::optional maybe_min_p_arr, double min_p_val, bool deterministic, + int64_t cuda_stream); + +// top k renorm probs +// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. +void top_k_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_k_arr, + unsigned int top_k_val, int64_t cuda_stream); + +// patch here, cause flashinfer use unsigned int. but torch must use int64_t for extension. 
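+// FlashInfer's top_k_renorm_probs takes the threshold as unsigned int, while torch extension
+// bindings must use int64_t, so the wrapper below narrows the argument with a static_cast before
+// forwarding.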
+// wrapper for binding +inline void top_k_renorm_probs_wrapper(at::Tensor probs, at::Tensor renorm_probs, + std::optional maybe_top_k_arr, int64_t top_k_val, + int64_t cuda_stream) { + top_k_renorm_probs(probs, renorm_probs, maybe_top_k_arr, static_cast(top_k_val), cuda_stream); +} + +// top p renorm probs +void top_p_renorm_probs(at::Tensor probs, at::Tensor renorm_probs, std::optional maybe_top_p_arr, + double top_p_val, int64_t cuda_stream); + +// top k top p sampling from probs +void top_k_top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, + at::Tensor success, std::optional maybe_top_k_arr, double top_k_val, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic, + int64_t cuda_stream); + +// top p sampling from probs +void top_p_sampling_from_probs(at::Tensor probs, at::Tensor uniform_samples, at::Tensor samples, at::Tensor success, + std::optional maybe_top_p_arr, double top_p_val, bool deterministic, + int64_t cuda_stream); diff --git a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh similarity index 93% rename from sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh rename to sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh index 1c7c714dc4a8..46522348aafa 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/trt_reduce_internal.cuh +++ b/sgl-kernel/src/sgl-kernel/include/trt_reduce_internal.cuh @@ -17,12 +17,11 @@ */ #pragma once + #include #include #include -#include "utils.hpp" - namespace trt_llm { constexpr size_t WARP_SIZE = 32; constexpr size_t MAX_ALL_REDUCE_BLOCKS = 36; @@ -36,6 +35,10 @@ enum class AllReduceStrategyType : int8_t { AUTO = 3, }; +struct RankData { + void* ptrs[MAX_RANKS_PER_NODE]; +}; + struct AllReduceParams { size_t elts_size; size_t elts_total; @@ -46,9 +49,11 @@ struct AllReduceParams { uint32_t barrier_flag; uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE]; uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE]; - void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE]; + uint32_t* tmp_result_buffers[MAX_RANKS_PER_NODE]; + RankData* peer_comm_buffer_ptrs; void* local_input_buffer_ptr; void* local_output_buffer_ptr; + bool is_capturing; }; inline size_t GetMaxRequiredWorkspaceSize(int world_size) { diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp b/sgl-kernel/src/sgl-kernel/include/utils.h similarity index 56% rename from sgl-kernel/src/sgl-kernel/csrc/utils.hpp rename to sgl-kernel/src/sgl-kernel/include/utils.h index 2fed2d60c039..55594f7b2733 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp +++ b/sgl-kernel/src/sgl-kernel/include/utils.h @@ -1,4 +1,7 @@ #pragma once + +#include +#include #include #include @@ -44,3 +47,20 @@ inline int getSMVersion() { CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); return sm_major * 10 + sm_minor; } + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) 
\ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + case at::ScalarType::Float: { \ + using c_type = float; \ + return __VA_ARGS__(); \ + } \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index e388ae35653b..ced0dafa9d79 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,20 +1,37 @@ -from sgl_kernel.ops._kernels import all_reduce as _all_reduce -from sgl_kernel.ops._kernels import dispose as _dispose -from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar -from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm -from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size +import os +from typing import Optional, Tuple, Union +import sgl_kernel.ops._kernels +import torch +from sgl_kernel.ops.utils import ( + _get_cache_buf, + _get_cuda_stream, + _to_tensor_scalar_tuple, +) -def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out): - return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out) + +def init_custom_reduce( + rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out +): + return torch.ops.sgl_kernels.init_custom_ar( + rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out + ) def custom_dispose(fa): - _dispose(fa) + torch.ops.sgl_kernels.dispose(fa) def custom_reduce(fa, inp, out): - _all_reduce(fa, inp, out) + torch.ops.sgl_kernels.all_reduce(fa, inp, out) + + +def get_graph_buffer_ipc_meta(fa): + return torch.ops.sgl_kernels.get_graph_buffer_ipc_meta(fa) + + +def register_graph_buffers(fa, handles, offsets): + torch.ops.sgl_kernels.register_graph_buffers(fa, handles, offsets) def moe_align_block_size( @@ -27,7 +44,7 @@ def moe_align_block_size( token_cnts_buffer, cumsum_buffer, ): - _moe_align_block_size( + torch.ops.sgl_kernels.moe_align_block_size( topk_ids, num_experts, block_size, @@ -39,8 +56,23 @@ def moe_align_block_size( ) +def sampling_scaling_penalties(logits, scaling_penalties): + return torch.ops.sgl_kernels.sampling_scaling_penalties(logits, scaling_penalties) + + def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): - return _int8_scaled_mm( + return torch.ops.sgl_kernels.int8_scaled_mm( + mat_a, + mat_b, + scales_a, + scales_b, + out_dtype, + bias, + ) + + +def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): + return torch.ops.sgl_kernels.fp8_scaled_mm( mat_a, mat_b, scales_a, @@ -48,3 +80,372 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): out_dtype, bias, ) + + +def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv): + torch.ops.sgl_kernels.lightning_attention_decode( + q, k, v, past_kv, slope, output, new_kv + ) + + +def rotary_embedding(positions, query, key, head_size, cos_sin_cache, is_neox): + return torch.ops.sgl_kernels.rotary_embedding( + positions, query, key, head_size, cos_sin_cache, is_neox + ) + + +# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer +# Kudos to @yzh119 +def rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + out: 
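
For reference, a small illustrative call into the `int8_scaled_mm` wrapper defined above. The shapes, the per-row/per-column float32 scales, and the column-major layout of `mat_b` follow the pattern used in the int8/fp8 GEMM unit tests later in this patch; they are assumptions for the example, not a documented contract.

```python
import torch
from sgl_kernel import int8_scaled_mm

M, N, K = 128, 256, 512
a_int8 = torch.randint(-128, 128, (M, K), dtype=torch.int8, device="cuda")
# Create mat_b as (N, K) and transpose it, mirroring the GEMM tests, so the
# kernel sees a column-major (K, N) operand.
b_int8 = torch.randint(-128, 128, (N, K), dtype=torch.int8, device="cuda").t()
scale_a = torch.rand(M, dtype=torch.float32, device="cuda") * 0.01  # per-row
scale_b = torch.rand(N, dtype=torch.float32, device="cuda") * 0.01  # per-column

out = int8_scaled_mm(a_int8, b_int8, scale_a, scale_b, torch.bfloat16)
print(out.shape, out.dtype)  # torch.Size([128, 256]) torch.bfloat16
```
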
Optional[torch.Tensor] = None, +) -> torch.Tensor: + with input.device as device: + if out is None: + out = torch.empty_like(input) + torch.ops.sgl_kernels.rmsnorm(out, input, weight, eps, _get_cuda_stream(device)) + return out + + +def fused_add_rmsnorm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> None: + with input.device as device: + torch.ops.sgl_kernels.fused_add_rmsnorm( + input, residual, weight, eps, _get_cuda_stream(device) + ) + + +def gemma_rmsnorm( + input: torch.Tensor, + weight: torch.Tensor, + eps: float = 1e-6, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + with input.device as device: + if out is None: + out = torch.empty_like(input) + torch.ops.sgl_kernels.gemma_rmsnorm( + out, input, weight, eps, _get_cuda_stream(device) + ) + return out + + +def gemma_fused_add_rmsnorm( + input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6 +) -> None: + with input.device as device: + torch.ops.sgl_kernels.gemma_fused_add_rmsnorm( + input, residual, weight, eps, _get_cuda_stream(device) + ) + + +def _check_shape(input: torch.Tensor, output: torch.Tensor) -> None: + assert input.ndim == output.ndim, f"{input.ndim} != {output.ndim}" + assert ( + input.shape[:-1] == output.shape[:-1] + ), f"{input.shape[:-1]} != {output.shape[:-1]}" + assert ( + input.shape[-1] == 2 * output.shape[-1] + ), f"{input.shape[-1]} != {2 * output.shape[-1]}" + + +def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + torch.ops.sgl_kernels.silu_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + torch.ops.sgl_kernels.gelu_tanh_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor: + if input.shape[-1] * input.dtype.itemsize % 16 != 0: + raise ValueError("The pointers must be multiple of 16 bytes.") + if out is not None: + _check_shape(input, out) + else: + out = torch.empty( + input.shape[:-1] + (input.shape[-1] // 2,), + device=input.device, + dtype=input.dtype, + ) + with input.device as device: + torch.ops.sgl_kernels.gelu_and_mul(out, input, _get_cuda_stream(device)) + return out + + +def _bmm_fp8_internal( + workspace_buffer: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + D: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, +) -> None: + with A.device as device: + cublas_handle = torch.cuda.current_blas_handle() + torch.ops.sgl_kernels.bmm_fp8( + A, + B, + D, + A_scale, + B_scale, + workspace_buffer, + cublas_handle, + _get_cuda_stream(device), + ) + + +def bmm_fp8( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + dtype: torch.dtype, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + if out is None: 
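
A brief usage sketch of the FlashInfer-style norm and activation wrappers defined above; tensor shapes are illustrative and the calls assume a CUDA device.

```python
import torch
import sgl_kernel

x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
w = torch.randn(4096, dtype=torch.float16, device="cuda")

# RMSNorm allocates the output unless `out=` is provided.
y = sgl_kernel.rmsnorm(x, w, eps=1e-6)

# Fused residual-add + RMSNorm mutates `x` and `residual` in place.
residual = torch.randn_like(x)
sgl_kernel.fused_add_rmsnorm(x, residual, w, eps=1e-6)

# silu_and_mul splits the last dimension in half and returns
# silu(first_half) * second_half, matching the activation tests in this patch.
gate_up = torch.randn(8, 2 * 4096, dtype=torch.float16, device="cuda")
act = sgl_kernel.silu_and_mul(gate_up)
print(y.shape, act.shape)  # torch.Size([8, 4096]) torch.Size([8, 4096])
```
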
+ out = torch.empty( + (A.shape[0], A.shape[1], B.shape[2]), + device=A.device, + dtype=dtype, + ) + workspace_buffer = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device) + _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale) + return out + + +def _top_k_renorm_probs_internal( + probs: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, +) -> torch.Tensor: + with probs.device as device: + probs = probs.float() + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + renorm_probs = torch.empty_like(probs) + torch.ops.sgl_kernels.top_k_renorm_probs_wrapper( + probs, + renorm_probs, + maybe_top_k_arr, + top_k_val, + _get_cuda_stream(device), + ) + return renorm_probs + + +def top_k_renorm_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], +) -> torch.Tensor: + return _top_k_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_k)) + + +top_k_renorm_prob = top_k_renorm_probs + + +def _top_p_renorm_probs_internal( + probs: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, +) -> torch.Tensor: + with probs.device as device: + probs = probs.float() + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + renorm_probs = torch.empty_like(probs) + torch.ops.sgl_kernels.top_p_renorm_probs( + probs, + renorm_probs, + maybe_top_p_arr, + top_p_val, + _get_cuda_stream(device), + ) + return renorm_probs + + +def top_p_renorm_probs( + probs: torch.Tensor, + top_p: Union[torch.Tensor, float], +) -> torch.Tensor: + return _top_p_renorm_probs_internal(probs, *_to_tensor_scalar_tuple(top_p)) + + +top_p_renorm_prob = top_p_renorm_probs + + +def _top_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + torch.ops.sgl_kernels.top_p_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_top_p_arr, + top_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples, success + + +def top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + top_p: Union[torch.Tensor, float], + deterministic: bool = True, + check_nan: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _top_p_sampling_from_probs_internal( + probs, uniform_samples, *_to_tensor_scalar_tuple(top_p), deterministic + ) + + +def _top_k_top_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_top_k_arr: Optional[torch.Tensor], + top_k_val: int, + maybe_top_p_arr: Optional[torch.Tensor], + top_p_val: float, + deterministic: bool, +) -> Tuple[torch.Tensor, torch.Tensor]: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None + maybe_top_p_arr = ( + maybe_top_p_arr.float() if maybe_top_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, 
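
As a quick illustration of the sampling wrappers above, the sketch below draws top-p samples from a softmax distribution. Batch and vocabulary sizes are arbitrary, and `uniform_samples` provides one row of uniform noise per rejection-sampling round, following the convention used in the sampling tests later in this patch.

```python
import torch
import sgl_kernel

batch_size, vocab_size, max_rounds = 4, 32000, 32
probs = torch.softmax(torch.randn(batch_size, vocab_size, device="cuda"), dim=-1)
uniform_samples = torch.rand(max_rounds, batch_size, device="cuda")

# A scalar top_p applies the same threshold to every row; a per-row tensor of
# shape (batch_size,) is also accepted.
samples, success = sgl_kernel.top_p_sampling_from_probs(
    probs, uniform_samples, top_p=0.9
)
assert bool(success.all())
print(samples)  # int32 token ids, one per row
```
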
device=device) + success = torch.empty(probs.size(0), dtype=torch.bool, device=device) + torch.ops.sgl_kernels.top_k_top_p_sampling_from_probs( + probs, + uniform_samples, + samples, + success, + maybe_top_k_arr, + top_k_val, + maybe_top_p_arr, + top_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples, success + + +def top_k_top_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + top_k: Union[torch.Tensor, int], + top_p: Union[torch.Tensor, float], + filter_apply_order: str = "top_k_first", + deterministic: bool = True, + check_nan: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + if filter_apply_order == "top_k_first": + renorm_probs = top_k_renorm_probs(probs, top_k) + return top_p_sampling_from_probs( + renorm_probs, uniform_samples, top_p, deterministic, check_nan=check_nan + ) + elif filter_apply_order == "joint": + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _top_k_top_p_sampling_from_probs_internal( + probs, + uniform_samples, + *_to_tensor_scalar_tuple(top_k), + *_to_tensor_scalar_tuple(top_p), + deterministic, + ) + else: + raise ValueError(f"Invalid filter_apply_order: {filter_apply_order}") + + +def _min_p_sampling_from_probs_internal( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + maybe_min_p_arr: Optional[torch.Tensor], + min_p_val: float, + deterministic: bool, +) -> torch.Tensor: + with probs.device as device: + probs = probs.float() + uniform_samples = uniform_samples.float() + maybe_min_p_arr = ( + maybe_min_p_arr.float() if maybe_min_p_arr is not None else None + ) + samples = torch.empty(probs.size(0), dtype=torch.int32, device=device) + torch.ops.sgl_kernels.min_p_sampling_from_probs( + probs, + uniform_samples, + samples, + maybe_min_p_arr, + min_p_val, + deterministic, + _get_cuda_stream(device), + ) + return samples + + +def min_p_sampling_from_probs( + probs: torch.Tensor, + uniform_samples: torch.Tensor, + min_p: Union[torch.Tensor, float], + deterministic: bool = True, + check_nan: bool = False, +) -> torch.Tensor: + if uniform_samples.dim() == 2: + # Take the first row (round) of uniform_samples + uniform_samples = uniform_samples[0] + + if check_nan: + if torch.any(torch.isnan(probs)): + raise ValueError("Input probs contains NaN.") + return _min_p_sampling_from_probs_internal( + probs, uniform_samples, *_to_tensor_scalar_tuple(min_p), deterministic + ) diff --git a/sgl-kernel/src/sgl-kernel/ops/utils.py b/sgl-kernel/src/sgl-kernel/ops/utils.py new file mode 100644 index 000000000000..31a6bbf9919d --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/ops/utils.py @@ -0,0 +1,26 @@ +from typing import Dict, Tuple + +import torch + + +def _get_cuda_stream(device: torch.device) -> int: + return torch.cuda.current_stream(device).cuda_stream + + +_cache_buf: Dict[Tuple[str, torch.device], torch.Tensor] = {} + + +def _get_cache_buf(name: str, bytes: int, device: torch.device) -> torch.Tensor: + key = (name, device) + buf = _cache_buf.get(key) + if buf is None: + buf = torch.empty(bytes, dtype=torch.uint8, device=device) + _cache_buf[key] = buf + return buf + + +def _to_tensor_scalar_tuple(x): + if isinstance(x, torch.Tensor): + return (x, 0) + else: + return (None, x) diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc new file mode 100644 index 000000000000..caf4f1269b6b --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -0,0 +1,121 @@ + +#include +#include + +#include 
"sgl_kernels_ops.h" + +TORCH_LIBRARY_EXPAND(sgl_kernels, m) { + // trt_reduce + m.def( + "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] " + "barrier_in, int[] barrier_out) -> int"); + m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); + + m.def("dispose", &dispose); + + m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()"); + m.impl("all_reduce", torch::kCUDA, &all_reduce); + + m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])"); + m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta); + + m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()"); + m.impl("register_graph_buffers", torch::kCUDA, ®ister_graph_buffers); + + // moe_align_block_size + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " + "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); + m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + + // int8_scaled_mm + m.def( + "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " + "bias) -> Tensor"); + m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm); + + // fp8_scaled_mm + m.def( + "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? " + "bias) -> Tensor"); + m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm); + + // lightning_attention_decode + m.def( + "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! " + "new_kv) -> ()"); + m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode); + + // rotary embedding + m.def( + "rotary_embedding(Tensor positions, Tensor! query, Tensor! key, int head_size, Tensor cos_sin_cache, bool " + "is_neox) -> ()"); + m.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); + + // rms norm + m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("rmsnorm", torch::kCUDA, &rmsnorm); + + // fused rms norm + m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("fused_add_rmsnorm", torch::kCUDA, &fused_add_rmsnorm); + + // gemma rms norm + m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm); + + // fused gemma rms norm + m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()"); + m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm); + + // silu and mul + m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + + // gelu tanh and mul + m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul); + + // gelu and mul + m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()"); + m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul); + + // bmm fp8 + m.def( + "bmm_fp8(Tensor A, Tensor B, Tensor! 
D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int " + "cublas_handle, int cuda_stream) -> ()"); + m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8); + + // min p sampling from probs + m.def( + "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float " + "min_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs); + + // top k renorm probs + m.def( + "top_k_renorm_probs_wrapper(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int " + "cuda_stream) -> ()"); + m.impl("top_k_renorm_probs_wrapper", torch::kCUDA, &top_k_renorm_probs_wrapper); + + // top p renorm probs + m.def( + "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int " + "cuda_stream) -> ()"); + m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs); + + // top k top p sampling from probs + m.def( + "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " + "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int " + "cuda_stream) -> ()"); + m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs); + + // top p sampling from probs + m.def( + "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? " + "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()"); + m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs); +} + +REGISTER_EXTENSION(_kernels) diff --git a/sgl-kernel/tests/test_activation.py b/sgl-kernel/tests/test_activation.py new file mode 100644 index 000000000000..43593441e3b6 --- /dev/null +++ b/sgl-kernel/tests/test_activation.py @@ -0,0 +1,39 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_activation.py + +import pytest +import sgl_kernel +import torch + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_silu_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.silu(x[..., :dim]) + y = sgl_kernel.silu_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_gelu_tanh_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="tanh") + y = sgl_kernel.gelu_tanh_and_mul(x) + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("dim", [128, 256, 512, 2048, 4096, 11008, 16384]) +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) +@pytest.mark.parametrize("seq_len", [1, 2, 4, 8, 16, 32, 64, 128, 512]) +def test_fused_gelu_mul(dim, batch_size, seq_len): + x = torch.randn(batch_size, seq_len, 2 * dim).to(0).to(torch.float16) + y_ref = x[..., dim:] * torch.nn.functional.gelu(x[..., :dim], approximate="none") + y = sgl_kernel.gelu_and_mul(x) + 
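
A small, illustrative check that the `sgl_kernels` library registered in `torch_extension.cc` above is reachable from Python once the `_kernels` extension has been loaded; this is the dispatch path used by the thin wrappers in `sgl-kernel/src/sgl-kernel/ops/__init__.py`.

```python
import torch
import sgl_kernel  # importing the package loads sgl_kernel.ops._kernels

# Each m.def(...) above becomes an entry under torch.ops.sgl_kernels.
print(torch.ops.sgl_kernels.rmsnorm)
print(torch.ops.sgl_kernels.top_p_sampling_from_probs)

# The registered schema, including the "Tensor!" annotations that mark
# mutated output arguments, can be inspected as well.
print(torch.ops.sgl_kernels.silu_and_mul.default._schema)
```
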
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_bmm_fp8.py b/sgl-kernel/tests/test_bmm_fp8.py new file mode 100644 index 000000000000..e0be92896f61 --- /dev/null +++ b/sgl-kernel/tests/test_bmm_fp8.py @@ -0,0 +1,43 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_bmm_fp8.py + +import pytest +import torch +import torch.nn.functional as F +from sgl_kernel import bmm_fp8 + + +def to_float8(x, dtype=torch.float8_e4m3fn): + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) + return x_scl_sat.to(dtype), scale.float().reciprocal() + + +@pytest.mark.parametrize("input_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("mat2_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("res_dtype", [torch.bfloat16, torch.float16]) +def test_bmm_fp8(input_dtype, mat2_dtype, res_dtype): + if input_dtype == torch.float8_e5m2 and mat2_dtype == torch.float8_e5m2: + pytest.skip("Invalid combination: both input and mat2 are e5m2") + + input = torch.randn([16, 48, 64], device="cuda", dtype=torch.bfloat16) + input_fp8, input_inv_s = to_float8(input, dtype=input_dtype) + + # mat2 row major -> column major + mat2 = torch.randn([16, 80, 64], device="cuda", dtype=torch.bfloat16).transpose( + -2, -1 + ) + mat2_fp8, mat2_inv_s = to_float8(mat2, dtype=mat2_dtype) + + res = torch.empty([16, 48, 80], device="cuda", dtype=res_dtype) + bmm_fp8(input_fp8, mat2_fp8, input_inv_s, mat2_inv_s, res_dtype, res) + + reference = torch.bmm(input, mat2) + cos_sim = F.cosine_similarity(reference.reshape(-1), res.reshape(-1), dim=0) + assert cos_sim > 0.99 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_fp8_gemm.py b/sgl-kernel/tests/test_fp8_gemm.py new file mode 100644 index 000000000000..1a7318659444 --- /dev/null +++ b/sgl-kernel/tests/test_fp8_gemm.py @@ -0,0 +1,67 @@ +import unittest + +import torch +from sgl_kernel import fp8_scaled_mm + + +def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): + o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + + o = o.to(torch.float32) + temp1 = o * scale_a.view(-1, 1) + temp2 = temp1 * scale_b.view(1, -1) + final = temp2.to(out_dtype) + if bias is not None: + final = final + bias.view(1, -1) + + return final + + +class TestFp8Gemm(unittest.TestCase): + def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + a_fp32 = ( + (torch.rand(M, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + ) + a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + b_fp32 = ( + (torch.rand(N, K, dtype=torch.float32, device=device) - 0.5) * 2 * fp8_max + ) + b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) + + scale_a = torch.randn((M,), device=device, dtype=torch.float32) * 0.001 + scale_b = torch.randn((N,), device=device, dtype=torch.float32) * 0.001 + if with_bias: + bias = torch.randn((N,), device=device, dtype=out_dtype) + else: + bias = None + o1 = torch.empty((M, N), device=device, dtype=torch.bfloat16) + b_fp8 = b_fp8.t() + o = torch_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + o1 = 
fp8_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, out_dtype, bias) + rtol = 0.02 + atol = 1 + torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) + print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") + + def test_accuracy(self): + Ms = [1, 128, 512, 1024, 4096] + Ns = [16, 128, 512, 1024, 4096] + Ks = [512, 1024, 4096, 8192, 16384] + bias_opts = [True, False] + out_dtypes = [torch.bfloat16, torch.float16] + for M in Ms: + for N in Ns: + for K in Ks: + for with_bias in bias_opts: + for out_dtype in out_dtypes: + self._test_accuracy_once( + M, N, K, with_bias, out_dtype, "cuda" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sgl-kernel/tests/test_int8_gemm.py b/sgl-kernel/tests/test_int8_gemm.py index 34d17d1c76ac..c33a3effcafd 100644 --- a/sgl-kernel/tests/test_int8_gemm.py +++ b/sgl-kernel/tests/test_int8_gemm.py @@ -25,7 +25,7 @@ def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) if with_bias: - bias = torch.ones((N,), device="cuda", dtype=out_dtype) * 10 + bias = torch.randn((N,), device="cuda", dtype=out_dtype) * 10 else: bias = None diff --git a/sgl-kernel/tests/test_lightning_attention_decode.py b/sgl-kernel/tests/test_lightning_attention_decode.py new file mode 100644 index 000000000000..f2cace00157a --- /dev/null +++ b/sgl-kernel/tests/test_lightning_attention_decode.py @@ -0,0 +1,88 @@ +import pytest +import torch +from sgl_kernel import lightning_attention_decode + + +def naive_lightning_attention_decode(q, k, v, past_kv, slope): + """Naive implementation of lightning attention decode""" + original_dtype = q.dtype + ratio = torch.exp(-slope) # [h, 1, 1] + + kv = past_kv + b, h, n, d = q.shape + + output = [] + for i in range(n): + kv = ratio * kv.to(torch.float32) + torch.einsum( + "... n d, ... n e -> ... d e", + k[:, :, i : i + 1], + v[:, :, i : i + 1], + ) + qkv = torch.einsum( + "... n e, ... e d -> ... 
n d", + q[:, :, i : i + 1].to(torch.float32), + kv.to(torch.float32), + ) + output.append(qkv) + output = torch.concat(output, dim=-2) + + return output.to(original_dtype), kv + + +configs = [ + # (batch_size, num_heads, dim, embed_dim) + (1, 8, 64, 64), + (2, 8, 64, 64), + (1, 32, 32, 64), + (2, 32, 32, 64), + (4, 32, 64, 64), + (4, 32, 64, 64), + (16, 64, 96, 96), + (64, 64, 96, 96), +] + +dtypes = [torch.float32, torch.float16, torch.bfloat16] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("batch_size,num_heads,dim,embed_dim", configs) +def test_lightning_attention_decode(dtype, batch_size, num_heads, dim, embed_dim): + device = torch.device("cuda") + + q = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype) + k = torch.randn(batch_size, num_heads, 1, dim, device=device, dtype=dtype) + v = torch.randn(batch_size, num_heads, 1, embed_dim, device=device, dtype=dtype) + past_kv = torch.randn(batch_size, num_heads, dim, embed_dim, device=device) + slope = torch.randn(num_heads, 1, 1, device=device) + + ref_output, ref_new_kv = naive_lightning_attention_decode(q, k, v, past_kv, slope) + + output = torch.empty_like(ref_output) + new_kv = torch.empty_like(ref_new_kv) + lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv) + + rtol = 1e-2 + atol = 1e-2 + + torch.testing.assert_close( + output, + ref_output, + rtol=rtol, + atol=atol, + msg=f"Output mismatch for batch_size={batch_size}, num_heads={num_heads}, " + f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}", + ) + + torch.testing.assert_close( + new_kv, + ref_new_kv, + rtol=rtol, + atol=atol, + msg=f"New KV mismatch for batch_size={batch_size}, num_heads={num_heads}, " + f"dim={dim}, embed_dim={embed_dim}, dtype={dtype}", + ) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_moe_align.py b/sgl-kernel/tests/test_moe_align.py index 92596a47e5db..2fca90b2f561 100644 --- a/sgl-kernel/tests/test_moe_align.py +++ b/sgl-kernel/tests/test_moe_align.py @@ -3,38 +3,65 @@ def test_moe_align_block_size(): + # For DeepSeek V3, we have 256 experts num_experts = 256 - block_size = 128 - topk_ids = torch.randint(0, num_experts, (3, 4), dtype=torch.int32, device="cuda") - - max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) - sorted_ids = torch.empty( - (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device - ) - sorted_ids.fill_(topk_ids.numel()) - max_num_m_blocks = max_num_tokens_padded // block_size - expert_ids = torch.empty( - (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device - ) - num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) - - token_cnts_buffer = torch.empty( - (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device - ) - cumsum_buffer = torch.empty( - num_experts + 1, dtype=torch.int32, device=topk_ids.device - ) - - moe_align_block_size( - topk_ids, - num_experts, - block_size, - sorted_ids, - expert_ids, - num_tokens_post_pad, - token_cnts_buffer, - cumsum_buffer, - ) - - -test_moe_align_block_size() + + # Test different combinations of block_size, num_tokens and topk + for block_size in [32, 64, 128, 256]: + print(f"\nTesting block_size={block_size}") + for num_tokens in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: + for topk in [1, 2, 4, 8, 16, 32, 64]: + print( + f"Testing block_size={block_size}, num_tokens={num_tokens}, topk={topk}" + ) + + # 
Create random topk_ids with shape [num_tokens, topk] + topk_ids = torch.randint( + 0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda" + ) + + max_num_tokens_padded = topk_ids.numel() + num_experts * ( + block_size - 1 + ) + sorted_ids = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device + ) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty( + (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device + ) + num_tokens_post_pad = torch.empty( + (1), dtype=torch.int32, device=topk_ids.device + ) + + token_cnts_buffer = torch.empty( + (num_experts + 1) * num_experts, + dtype=torch.int32, + device=topk_ids.device, + ) + cumsum_buffer = torch.empty( + num_experts + 1, dtype=torch.int32, device=topk_ids.device + ) + + try: + moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids, + expert_ids, + num_tokens_post_pad, + token_cnts_buffer, + cumsum_buffer, + ) + except Exception as e: + print( + f"Error occurred with block_size={block_size}, num_tokens={num_tokens}, topk={topk}" + ) + print(f"Error message: {str(e)}") + raise e + + +if __name__ == "__main__": + test_moe_align_block_size() diff --git a/sgl-kernel/tests/test_norm.py b/sgl-kernel/tests/test_norm.py new file mode 100644 index 000000000000..7b38dba72bfb --- /dev/null +++ b/sgl-kernel/tests/test_norm.py @@ -0,0 +1,133 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/4e8eb1879f9c3ba6d75511e5893183bf8f289a62/tests/test_norm.py + +import pytest +import sgl_kernel +import torch + + +def llama_rms_norm(x, w, eps=1e-6): + orig_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x * w.float() + x = x.to(orig_dtype) + return x + + +def gemma_rms_norm(x, w, eps=1e-6): + orig_dtype = x.dtype + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x * (1.0 + w.float()) + x = x.to(orig_dtype) + return x + + +def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6): + orig_dtype = x.dtype + x = x + residual + residual = x + x = x.float() + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = x * (1.0 + w.float()) + x = x.to(orig_dtype) + return x, residual + + +def fused_add_rms_norm(x, residual, weight, eps): + orig_dtype = x.dtype + x = x.to(torch.float32) + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + eps) + x = (x * weight.float()).to(orig_dtype) + return x, residual + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("specify_out", [True, False]) +def test_norm(batch_size, hidden_size, dtype, specify_out): + x = torch.randn(batch_size, hidden_size).to(0).to(dtype) + w = torch.randn(hidden_size).to(0).to(dtype) + + y_ref = llama_rms_norm(x, w) + if specify_out: + y = torch.empty_like(x) + sgl_kernel.rmsnorm(x, w, out=y) + else: + y = sgl_kernel.rmsnorm(x, w) + + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def 
test_fused_add_rmsnorm(batch_size, hidden_size, dtype): + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) + weight = torch.randn(hidden_size, dtype=dtype, device="cuda") + + x_native, residual_native = fused_add_rms_norm( + x.clone(), residual.clone(), weight, eps + ) + + x_fused = x.clone() + residual_fused = residual.clone() + sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps) + + torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("specify_out", [True, False]) +def test_gemma_norm(batch_size, hidden_size, dtype, specify_out): + x = torch.randn(batch_size, hidden_size).to(0).to(dtype) + w = torch.randn(hidden_size).to(0).to(dtype) + + y_ref = gemma_rms_norm(x, w) + if specify_out: + y = torch.empty_like(x) + sgl_kernel.gemma_rmsnorm(x, w, out=y) + else: + y = sgl_kernel.gemma_rmsnorm(x, w) + + torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype): + eps = 1e-6 + + x = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda") + residual = torch.randn_like(x) + weight = torch.randn(hidden_size, dtype=dtype, device="cuda") + + x_native, residual_native = gemma_fused_add_rms_norm( + x.clone(), residual.clone(), weight, eps + ) + + x_fused = x.clone() + residual_fused = residual.clone() + sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps) + + torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_rotary_embedding.py b/sgl-kernel/tests/test_rotary_embedding.py new file mode 100644 index 000000000000..1bbe8f1bfebb --- /dev/null +++ b/sgl-kernel/tests/test_rotary_embedding.py @@ -0,0 +1,118 @@ +from typing import Optional, Tuple + +import torch +from vllm.model_executor.layers.rotary_embedding import ( + RotaryEmbedding as VLLMRotaryEmbedding, +) + + +class SGLRotaryEmbedding(VLLMRotaryEmbedding): + + def forward_cuda( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from sgl_kernel import rotary_embedding + + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) + + rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + self.is_neox_style, + ) + return query, key + + +# Compare the output of SGLRotaryEmbedding's forward_cuda with VLLMRotaryEmbedding's forward_native + + +def test_rotary_embedding(): + # Test case 1: FP32 + def run_test( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + batch_size, + seq_len, + num_heads, + test_name, + ): + print(f"\nRunning {test_name}...") + # Initialize both implementations + sgl_rope = SGLRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, 
dtype + ).to("cuda") + vllm_rope = VLLMRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, dtype + ).to("cuda") + + # Regular forward pass + positions = torch.arange(seq_len, device="cuda").repeat(batch_size) + query = torch.randn( + batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype + ) + key = torch.randn( + batch_size * seq_len, num_heads * head_size, device="cuda", dtype=dtype + ) + + # Make copies for both implementations + query_sgl = query.clone() + key_sgl = key.clone() + query_vllm = query.clone() + key_vllm = key.clone() + + # Run both implementations + query_sgl_out, key_sgl_out = sgl_rope.forward_cuda( + positions, query_sgl, key_sgl + ) + query_vllm_out, key_vllm_out = vllm_rope.forward_native( + positions, query_vllm, key_vllm + ) + + # Compare outputs + torch.testing.assert_close(query_sgl_out, query_vllm_out, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(key_sgl_out, key_vllm_out, rtol=1e-3, atol=1e-3) + + print(f"{test_name} passed!") + + # Test Case 1: FP32 with larger dimensions + run_test( + head_size=128, + rotary_dim=64, + max_position=4096, + base=10000, + is_neox_style=True, + dtype=torch.float32, + batch_size=4, + seq_len=32, + num_heads=8, + test_name="FP32 Test", + ) + + # Test Case 2: BF16 with smaller dimensions + run_test( + head_size=64, + rotary_dim=32, + max_position=2048, + base=8000, + is_neox_style=True, + dtype=torch.bfloat16, + batch_size=2, + seq_len=16, + num_heads=4, + test_name="BF16 Test", + ) + + +if __name__ == "__main__": + test_rotary_embedding() diff --git a/sgl-kernel/tests/test_sampling.py b/sgl-kernel/tests/test_sampling.py new file mode 100644 index 000000000000..7d3bc5059eea --- /dev/null +++ b/sgl-kernel/tests/test_sampling.py @@ -0,0 +1,141 @@ +# Adapted from https://github.com/flashinfer-ai/flashinfer/blob/93e1a2634e22355b0856246b032b285ad1d1da6b/tests/test_sampling.py + +import pytest +import sgl_kernel +import torch + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5]) +def test_top_k_top_p_joint_sampling_from_probs(batch_size, vocab_size, p): + torch.manual_seed(42) + if p == 0.1: + k = int(vocab_size * 0.5) + elif p == 0.5: + k = int(vocab_size * 0.1) + else: + raise ValueError("p not recognized") + max_top_k_trails = 32 + eps = 1e-4 + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + # top-p mask + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask_top_p = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask_top_p.scatter_add_(1, indices, (cdf > (1 - p) - eps).int()) + # top-k mask + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask_top_k = (normalized_prob >= pivot.unsqueeze(-1)).int() + # overall mask + mask = torch.minimum(mask_top_p, mask_top_k) + uniform_samples = torch.empty(max_top_k_trails, batch_size, dtype=torch.float32).to( + 0 + ) + top_p_tensor = torch.full((batch_size,), p).to(0) + top_k_tensor = torch.full((batch_size,), k).to(0) + + num_trails = 1000 + for _ in range(num_trails): + uniform_samples.uniform_() + samples, success = sgl_kernel.top_k_top_p_sampling_from_probs( + normalized_prob, + uniform_samples, + top_k_tensor, + top_p_tensor, + filter_apply_order="joint", + ) + assert torch.all(success) + assert torch.all(samples < 
vocab_size) and torch.all(samples >= 0) + assert torch.all(mask[torch.arange(batch_size), samples] == 1), normalized_prob[ + torch.arange(batch_size), samples + ] + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.1, 0.5, 0.9]) +def test_top_p_renorm_probs(batch_size, vocab_size, p): + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + cdf = torch.cumsum(sorted_prob, dim=-1) + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask.scatter_add_(1, indices, (cdf >= (1 - p)).int()) + renorm_prob_ground_truth = normalized_prob + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = sgl_kernel.top_p_renorm_prob(normalized_prob, p) + torch.testing.assert_close( + renorm_prob_ground_truth, + renorm_prob, + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("k", [10, 100, 500]) +def test_top_k_renorm_probs(batch_size, vocab_size, k): + if k > vocab_size: + pytest.skip("k should be less than vocab_size") + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, _ = torch.sort(normalized_prob, descending=True) + pivot = sorted_prob[:, k - 1] + mask = (normalized_prob >= pivot.unsqueeze(-1)).int() + renorm_prob_ground_truth = normalized_prob + renorm_prob_ground_truth[mask == 0] = 0 + renorm_prob_ground_truth = renorm_prob_ground_truth / renorm_prob_ground_truth.sum( + dim=-1, keepdim=True + ) + + renorm_prob = sgl_kernel.top_k_renorm_prob(normalized_prob, k) + torch.testing.assert_close( + renorm_prob_ground_truth, + renorm_prob, + rtol=1e-3, + atol=1e-3, + ) + + +@pytest.mark.parametrize("batch_size", [1, 19, 99, 989]) +@pytest.mark.parametrize("vocab_size", [111, 500, 32000, 128256]) +@pytest.mark.parametrize("p", [0.05, 0.1, 0.2, 0.7, 1]) +def test_min_p_sampling(batch_size, vocab_size, p): + torch.manual_seed(42) + pre_norm_prob = torch.rand(batch_size, vocab_size).to(0) + normalized_prob = pre_norm_prob / pre_norm_prob.sum(dim=-1, keepdim=True) + sorted_prob, indices = torch.sort(normalized_prob, descending=False) + # scale min-p + top_probs = sorted_prob[:, -1].unsqueeze(-1) + scaled_p = p * top_probs + # min-p mask + mask = torch.zeros(batch_size, vocab_size, dtype=torch.int32).to(0) + mask.scatter_add_(1, indices, (sorted_prob >= scaled_p).int()) + uniform_samples = torch.empty(batch_size, dtype=torch.float32).to(0) + min_p_tensor = torch.full((batch_size,), p).to(0) + + num_trails = 1000 + for _ in range(num_trails): + uniform_samples.uniform_() + samples = sgl_kernel.min_p_sampling_from_probs( + normalized_prob, + uniform_samples, + min_p_tensor, + ) + + assert torch.all(mask[torch.arange(batch_size), samples] == 1), samples[ + torch.nonzero(mask[torch.arange(batch_size), samples] == 0) + ] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/sgl-kernel/tests/test_trt_reduce.py b/sgl-kernel/tests/test_trt_reduce.py index a5ce1b41db14..b79580070c0c 100644 --- a/sgl-kernel/tests/test_trt_reduce.py +++ 
b/sgl-kernel/tests/test_trt_reduce.py @@ -10,6 +10,7 @@ import ray import torch import torch.distributed as dist +from sgl_kernel import ops as custom_ops from torch.distributed import ProcessGroup from vllm import _custom_ops as vllm_ops @@ -104,35 +105,38 @@ def test_performance(self): multi_process_parallel(world_size, self, self.performance) def init_custom_allreduce(self, rank, world_size, group): - import sgl_kernel - buffer_max_size = 8 * 1024 * 1024 barrier_max_size = 8 * (24 + 2) * 8 self.buffer_ptrs = self.create_shared_buffer(buffer_max_size, group=group) + self.tmp_result_buffer_ptrs = self.create_shared_buffer( + buffer_max_size, group=group + ) self.barrier_in_ptrs = self.create_shared_buffer(barrier_max_size, group=group) self.barrier_out_ptrs = self.create_shared_buffer(barrier_max_size, group=group) + self.rank_data = torch.empty( + 8 * 1024 * 1024, dtype=torch.uint8, device=torch.device(f"cuda:{rank}") + ) - self.custom_ptr = sgl_kernel.ops.init_custom_reduce( + self.custom_ptr = custom_ops.init_custom_reduce( rank, world_size, + self.rank_data, self.buffer_ptrs, + self.tmp_result_buffer_ptrs, self.barrier_in_ptrs, self.barrier_out_ptrs, ) def custom_allreduce(self, inp, out): - import sgl_kernel - - sgl_kernel.ops.custom_reduce(self.custom_ptr, inp, out) + custom_ops.custom_reduce(self.custom_ptr, inp, out) def free_custom_allreduce(self, group): - import sgl_kernel - self.free_shared_buffer(self.buffer_ptrs, group) + self.free_shared_buffer(self.tmp_result_buffer_ptrs, group) self.free_shared_buffer(self.barrier_in_ptrs, group) self.free_shared_buffer(self.barrier_out_ptrs, group) - sgl_kernel.ops.custom_dispose(self.custom_ptr) + custom_ops.custom_dispose(self.custom_ptr) def init_vllm_allreduce(self, rank, group): self.vllm_rank = rank diff --git a/sgl-kernel/version.py b/sgl-kernel/version.py new file mode 100644 index 000000000000..ad3ff8af9444 --- /dev/null +++ b/sgl-kernel/version.py @@ -0,0 +1 @@ +__version__ = "0.0.2.post17" diff --git a/sgl-router/README.md b/sgl-router/README.md index 617bca5405fe..61c9e692c923 100644 --- a/sgl-router/README.md +++ b/sgl-router/README.md @@ -4,7 +4,7 @@ SGLang router is a standalone module implemented in Rust to achieve data paralle ## User docs -Please check https://sgl-project.github.io/router/router.html +Please check https://docs.sglang.ai/router/router.html ## Developer docs @@ -67,6 +67,16 @@ $ pip install -e . **Note:** When modifying Rust code, you must rebuild the wheel for changes to take effect. +### Troubleshooting + +1. If rust analyzer is not working in VSCode, set `rust-analyzer.linkedProjects` to the absolute path of `Cargo.toml` in your repo. 
For example: + +```json +{ + "rust-analyzer.linkedProjects": ["/workspaces/sglang/sgl-router/Cargo.toml"] +} +``` + ### CI/CD Setup The continuous integration pipeline consists of three main steps: diff --git a/sgl-router/py_src/sglang_router/__init__.py b/sgl-router/py_src/sglang_router/__init__.py index 285ee173ba92..081740479ca6 100644 --- a/sgl-router/py_src/sglang_router/__init__.py +++ b/sgl-router/py_src/sglang_router/__init__.py @@ -1,11 +1,7 @@ # a lightweihgt wrapper on router with argument type and comments -from sglang_router_rs import PolicyType - # no wrapper on policy type => direct export -from .router import Router - -__all__ = ["Router", "PolicyType"] - +from sglang_router.router import Router from sglang_router.version import __version__ +from sglang_router_rs import PolicyType -__all__ += ["__version__"] +__all__ = ["Router", "PolicyType", "__version__"] diff --git a/sgl-router/py_src/sglang_router/launch_router.py b/sgl-router/py_src/sglang_router/launch_router.py index e4f26a8d4bce..38f1fbba2dce 100644 --- a/sgl-router/py_src/sglang_router/launch_router.py +++ b/sgl-router/py_src/sglang_router/launch_router.py @@ -27,12 +27,14 @@ def setup_logger(): @dataclasses.dataclass class RouterArgs: # Worker configuration - worker_urls: List[str] + worker_urls: List[str] = dataclasses.field(default_factory=list) host: str = "127.0.0.1" port: int = 30000 # Routing policy policy: str = "cache_aware" + worker_startup_timeout_secs: int = 300 + worker_startup_check_interval: int = 10 cache_threshold: float = 0.5 balance_abs_threshold: int = 32 balance_rel_threshold: float = 1.0001 @@ -87,6 +89,18 @@ def add_cli_args( choices=["random", "round_robin", "cache_aware"], help="Load balancing policy to use", ) + parser.add_argument( + f"--{prefix}worker-startup-timeout-secs", + type=int, + default=RouterArgs.worker_startup_timeout_secs, + help="Timeout in seconds for worker startup", + ) + parser.add_argument( + f"--{prefix}worker-startup-check-interval", + type=int, + default=RouterArgs.worker_startup_check_interval, + help="Interval in seconds between checks for worker startup", + ) parser.add_argument( f"--{prefix}cache-threshold", type=float, @@ -141,11 +155,18 @@ def from_cli_args( use_router_prefix: If True, look for arguments with 'router-' prefix """ prefix = "router_" if use_router_prefix else "" + worker_urls = args.worker_urls if args.worker_urls is not None else [] return cls( - worker_urls=args.worker_urls, + worker_urls=worker_urls, host=args.host, port=args.port, policy=getattr(args, f"{prefix}policy"), + worker_startup_timeout_secs=getattr( + args, f"{prefix}worker_startup_timeout_secs" + ), + worker_startup_check_interval=getattr( + args, f"{prefix}worker_startup_check_interval" + ), cache_threshold=getattr(args, f"{prefix}cache_threshold"), balance_abs_threshold=getattr(args, f"{prefix}balance_abs_threshold"), balance_rel_threshold=getattr(args, f"{prefix}balance_rel_threshold"), @@ -187,9 +208,11 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]: router = Router( worker_urls=router_args.worker_urls, - policy=policy_from_str(router_args.policy), host=router_args.host, port=router_args.port, + policy=policy_from_str(router_args.policy), + worker_startup_timeout_secs=router_args.worker_startup_timeout_secs, + worker_startup_check_interval=router_args.worker_startup_check_interval, cache_threshold=router_args.cache_threshold, balance_abs_threshold=router_args.balance_abs_threshold, balance_rel_threshold=router_args.balance_rel_threshold, @@ -204,7 +227,7 @@ 
def launch_router(args: argparse.Namespace) -> Optional[Router]: except Exception as e: logger.error(f"Error starting router: {e}") - return None + raise e class CustomHelpFormatter( @@ -237,12 +260,8 @@ def parse_router_args(args: List[str]) -> RouterArgs: def main() -> None: - logger = setup_logger() router_args = parse_router_args(sys.argv[1:]) - router = launch_router(router_args) - - if router is None: - sys.exit(1) + launch_router(router_args) if __name__ == "__main__": diff --git a/sgl-router/py_src/sglang_router/launch_server.py b/sgl-router/py_src/sglang_router/launch_server.py index 6ee192415429..74353c21edbb 100644 --- a/sgl-router/py_src/sglang_router/launch_server.py +++ b/sgl-router/py_src/sglang_router/launch_server.py @@ -13,7 +13,7 @@ from setproctitle import setproctitle from sglang_router.launch_router import RouterArgs, launch_router -from sglang.srt.server import launch_server +from sglang.srt.entrypoints.http_server import launch_server from sglang.srt.server_args import ServerArgs from sglang.srt.utils import is_port_available @@ -23,7 +23,7 @@ def setup_logger(): logger.setLevel(logging.INFO) formatter = logging.Formatter( - "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s", + "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d", datefmt="%Y-%m-%d %H:%M:%S", ) @@ -68,7 +68,7 @@ def run_server(server_args, dp_rank): # create new process group os.setpgrp() - setproctitle(f"sglang::server") + setproctitle("sglang::server") # Set SGLANG_DP_RANK environment variable os.environ["SGLANG_DP_RANK"] = str(dp_rank) @@ -120,9 +120,26 @@ def find_available_ports(base_port: int, count: int) -> List[int]: def cleanup_processes(processes: List[mp.Process]): for process in processes: - logger.info(f"Terminating process {process.pid}") - process.terminate() - logger.info("All processes terminated") + logger.info(f"Terminating process group {process.pid}") + try: + os.killpg(process.pid, signal.SIGTERM) + except ProcessLookupError: + # Process group may already be terminated + pass + + # Wait for processes to terminate + for process in processes: + process.join(timeout=5) + if process.is_alive(): + logger.warning( + f"Process {process.pid} did not terminate gracefully, forcing kill" + ) + try: + os.killpg(process.pid, signal.SIGKILL) + except ProcessLookupError: + pass + + logger.info("All process groups terminated") def main(): @@ -173,7 +190,12 @@ def main(): ] # Start the router - router = launch_router(router_args) + try: + launch_router(router_args) + except Exception as e: + logger.error(f"Failed to start router: {e}") + cleanup_processes(server_processes) + sys.exit(1) if __name__ == "__main__": diff --git a/sgl-router/py_src/sglang_router/router.py b/sgl-router/py_src/sglang_router/router.py index 5ce21c3d78ea..b8757168b242 100644 --- a/sgl-router/py_src/sglang_router/router.py +++ b/sgl-router/py_src/sglang_router/router.py @@ -17,6 +17,8 @@ class Router: - PolicyType.CacheAware: Distribute requests based on cache state and load balance host: Host address to bind the router server. Default: '127.0.0.1' port: Port number to bind the router server. Default: 3001 + worker_startup_timeout_secs: Timeout in seconds for worker startup. Default: 300 + worker_startup_check_interval: Interval in seconds between checks for worker initialization. Default: 10 cache_threshold: Cache threshold (0.0-1.0) for cache-aware routing. Routes to cached worker if the match rate exceeds threshold, otherwise routes to the worker with the smallest tree. 
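
The standalone entrypoint above can also be exercised programmatically, following the same path as `python -m sglang_router.launch_router ...`. The worker URLs below are placeholders; `--worker-startup-timeout-secs` and `--worker-startup-check-interval` come from the `add_cli_args` definitions in this file, while `--worker-urls` is assumed to be the existing flag for the worker list. A sketch, not an officially documented API:

```python
from sglang_router.launch_router import launch_router, parse_router_args

# Parse CLI-style arguments into RouterArgs, then start the router
# (launch_router blocks until the router shuts down).
router_args = parse_router_args([
    "--worker-urls", "http://localhost:30001", "http://localhost:30002",
    "--policy", "cache_aware",
    "--worker-startup-timeout-secs", "600",
    "--worker-startup-check-interval", "10",
])
launch_router(router_args)
```
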
Default: 0.5 @@ -37,6 +39,8 @@ def __init__( policy: PolicyType = PolicyType.RoundRobin, host: str = "127.0.0.1", port: int = 3001, + worker_startup_timeout_secs: int = 300, + worker_startup_check_interval: int = 10, cache_threshold: float = 0.50, balance_abs_threshold: int = 32, balance_rel_threshold: float = 1.0001, @@ -50,6 +54,8 @@ def __init__( policy=policy, host=host, port=port, + worker_startup_timeout_secs=worker_startup_timeout_secs, + worker_startup_check_interval=worker_startup_check_interval, cache_threshold=cache_threshold, balance_abs_threshold=balance_abs_threshold, balance_rel_threshold=balance_rel_threshold, diff --git a/sgl-router/py_src/sglang_router/version.py b/sgl-router/py_src/sglang_router/version.py index 485f44ac21b2..bbab0242f6aa 100644 --- a/sgl-router/py_src/sglang_router/version.py +++ b/sgl-router/py_src/sglang_router/version.py @@ -1 +1 @@ -__version__ = "0.1.1" +__version__ = "0.1.4" diff --git a/sgl-router/py_test/test_launch_router.py b/sgl-router/py_test/test_launch_router.py index 1c3700d423ba..27ed64d6e668 100644 --- a/sgl-router/py_test/test_launch_router.py +++ b/sgl-router/py_test/test_launch_router.py @@ -22,14 +22,14 @@ def terminate_process(process: multiprocessing.Process, timeout: float = 1.0) -> class TestLaunchRouter(unittest.TestCase): - def test_launch_router_no_exception(self): - - # Create SimpleNamespace with default arguments - args = SimpleNamespace( - worker_urls=["http://localhost:8000"], + def setUp(self): + """Set up default arguments for router tests.""" + self.default_args = SimpleNamespace( host="127.0.0.1", port=30000, policy="cache_aware", + worker_startup_timeout_secs=600, + worker_startup_check_interval=10, cache_threshold=0.5, balance_abs_threshold=32, balance_rel_threshold=1.0001, @@ -39,6 +39,15 @@ def test_launch_router_no_exception(self): verbose=False, ) + def create_router_args(self, **kwargs): + """Create router arguments by updating default args with provided kwargs.""" + args_dict = vars(self.default_args).copy() + args_dict.update(kwargs) + return SimpleNamespace(**args_dict) + + def run_router_process(self, args): + """Run router in a separate process and verify it starts successfully.""" + def run_router(): try: from sglang_router.launch_router import launch_router @@ -51,7 +60,6 @@ def run_router(): print(e) return 1 - # Start router in separate process process = multiprocessing.Process(target=run_router) try: process.start() @@ -62,6 +70,14 @@ def run_router(): finally: terminate_process(process) + def test_launch_router_common(self): + args = self.create_router_args(worker_urls=["http://localhost:8000"]) + self.run_router_process(args) + + def test_launch_router_with_empty_worker_urls(self): + args = self.create_router_args(worker_urls=[]) + self.run_router_process(args) + if __name__ == "__main__": unittest.main() diff --git a/sgl-router/py_test/test_launch_server.py b/sgl-router/py_test/test_launch_server.py index e11602933a63..80659fc4f3e0 100644 --- a/sgl-router/py_test/test_launch_server.py +++ b/sgl-router/py_test/test_launch_server.py @@ -22,6 +22,7 @@ def popen_launch_router( timeout: float, policy: str = "cache_aware", max_payload_size: int = None, + api_key: str = None, ): """ Launch the router server process. 
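Taken together, the Python-side changes above give the `sglang_router.Router` wrapper two new keyword arguments, `worker_startup_timeout_secs` and `worker_startup_check_interval`, which are threaded through `RouterArgs` and the CLI. The snippet below is a minimal usage sketch, not part of this patch: the worker URL and timeout values are illustrative, and it assumes a worker is (or will soon be) serving at that address.

```python
# Minimal sketch of the updated Python wrapper (illustrative values only).
from sglang_router import PolicyType, Router

router = Router(
    worker_urls=["http://localhost:8000"],   # hypothetical worker endpoint
    policy=PolicyType.CacheAware,
    host="127.0.0.1",
    port=30000,
    worker_startup_timeout_secs=600,   # wait up to 10 minutes for workers to report healthy
    worker_startup_check_interval=10,  # poll worker /health every 10 seconds
)

# Blocks while serving; with this patch it raises instead of panicking
# if the workers never become healthy within the timeout.
router.start()
```

The same knobs are exposed on the command line as `--worker-startup-timeout-secs` / `--worker-startup-check-interval` for `sglang_router.launch_router`, and with the `router-` prefix (e.g. `--router-worker-startup-timeout-secs`) when going through `sglang_router.launch_server`.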
@@ -33,6 +34,7 @@ def popen_launch_router( timeout: Server launch timeout policy: Router policy, one of "cache_aware", "round_robin", "random" max_payload_size: Maximum payload size in bytes + api_key: API key for the router """ _, host, port = base_url.split(":") host = host[2:] @@ -55,6 +57,9 @@ def popen_launch_router( policy, ] + if api_key is not None: + command.extend(["--api-key", api_key]) + if max_payload_size is not None: command.extend(["--router-max-payload-size", str(max_payload_size)]) @@ -333,6 +338,57 @@ def test_4_payload_size(self): f"1.2MB payload should fail with 413 but got status {response.status_code}", ) + def test_5_api_key(self): + print("Running test_5_api_key...") + + self.process = popen_launch_router( + self.model, + self.base_url, + dp_size=1, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + policy="round_robin", + api_key="correct_api_key", + ) + + # # Test case 1: request without api key should fail + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is, ", "temperature": 0}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, + 401, + "Request without api key should fail with 401", + ) + + # Test case 2: request with invalid api key should fail + with requests.Session() as session: + response = requests.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is, ", "temperature": 0}, + headers={"Authorization": "Bearer 123"}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, + 401, + "Request with invalid api key should fail with 401", + ) + + # Test case 3: request with correct api key should succeed + with requests.Session() as session: + response = session.post( + f"{self.base_url}/generate", + json={"text": "Kanye west is ", "temperature": 0}, + headers={"Authorization": "Bearer correct_api_key"}, + ) + print(f"status code: {response.status_code}, response: {response.text}") + self.assertEqual( + response.status_code, 200, "Request with correct api key should succeed" + ) + if __name__ == "__main__": unittest.main() diff --git a/sgl-router/pyproject.toml b/sgl-router/pyproject.toml index 20096b6b4912..da5c44a1196d 100644 --- a/sgl-router/pyproject.toml +++ b/sgl-router/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sglang-router" -version = "0.1.1" +version = "0.1.4" description = "SGLang router is a standalone module implemented in Rust to achieve data parallelism across SGLang instances." 
authors = [{name = "Byron Hsu", email = "byronhsu1230@gmail.com"}] requires-python = ">=3.8" @@ -20,6 +20,10 @@ classifiers = [ [tool.setuptools.packages] find = { where = ["py_src"] } +# workaround for https://github.com/pypa/twine/issues/1216 +[tool.setuptools] +license-files = [] + [[tool.setuptools-rust.ext-modules]] target = "sglang_router_rs" path = "Cargo.toml" diff --git a/sgl-router/src/lib.rs b/sgl-router/src/lib.rs index 2d8cf4c0c8d6..ba9aeac1fef2 100644 --- a/sgl-router/src/lib.rs +++ b/sgl-router/src/lib.rs @@ -17,6 +17,8 @@ struct Router { port: u16, worker_urls: Vec, policy: PolicyType, + worker_startup_timeout_secs: u64, + worker_startup_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, @@ -34,6 +36,8 @@ impl Router { policy = PolicyType::RoundRobin, host = String::from("127.0.0.1"), port = 3001, + worker_startup_timeout_secs = 300, + worker_startup_check_interval = 10, cache_threshold = 0.50, balance_abs_threshold = 32, balance_rel_threshold = 1.0001, @@ -47,6 +51,8 @@ impl Router { policy: PolicyType, host: String, port: u16, + worker_startup_timeout_secs: u64, + worker_startup_check_interval: u64, cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, @@ -60,6 +66,8 @@ impl Router { port, worker_urls, policy, + worker_startup_timeout_secs, + worker_startup_check_interval, cache_threshold, balance_abs_threshold, balance_rel_threshold, @@ -72,9 +80,17 @@ impl Router { fn start(&self) -> PyResult<()> { let policy_config = match &self.policy { - PolicyType::Random => router::PolicyConfig::RandomConfig, - PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig, + PolicyType::Random => router::PolicyConfig::RandomConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, + }, + PolicyType::RoundRobin => router::PolicyConfig::RoundRobinConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, + }, PolicyType::CacheAware => router::PolicyConfig::CacheAwareConfig { + timeout_secs: self.worker_startup_timeout_secs, + interval_secs: self.worker_startup_check_interval, cache_threshold: self.cache_threshold, balance_abs_threshold: self.balance_abs_threshold, balance_rel_threshold: self.balance_rel_threshold, @@ -93,10 +109,9 @@ impl Router { max_payload_size: self.max_payload_size, }) .await - .unwrap(); - }); - - Ok(()) + .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?; + Ok(()) + }) } } diff --git a/sgl-router/src/router.rs b/sgl-router/src/router.rs index 08f6cdefa759..5ee34c59869d 100644 --- a/sgl-router/src/router.rs +++ b/sgl-router/src/router.rs @@ -3,7 +3,7 @@ use actix_web::http::header::{HeaderValue, CONTENT_TYPE}; use actix_web::{HttpRequest, HttpResponse}; use bytes::Bytes; use futures_util::{StreamExt, TryStreamExt}; -use log::{debug, info, warn}; +use log::{debug, error, info, warn}; use std::collections::HashMap; use std::fmt::Debug; use std::sync::atomic::AtomicUsize; @@ -12,14 +12,30 @@ use std::thread; use std::time::Duration; use tokio; +fn copy_request_headers(req: &HttpRequest) -> Vec<(String, String)> { + req.headers() + .iter() + .filter_map(|(name, value)| { + value + .to_str() + .ok() + .map(|v| (name.to_string(), v.to_string())) + }) + .collect() +} + #[derive(Debug)] pub enum Router { RoundRobin { worker_urls: Arc>>, current_index: AtomicUsize, + timeout_secs: u64, + interval_secs: u64, }, Random { worker_urls: Arc>>, + timeout_secs: u64, + 
interval_secs: u64, }, CacheAware { /* @@ -89,36 +105,73 @@ pub enum Router { cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, + timeout_secs: u64, + interval_secs: u64, _eviction_thread: Option>, }, } #[derive(Debug, Clone)] pub enum PolicyConfig { - RandomConfig, - RoundRobinConfig, + RandomConfig { + timeout_secs: u64, + interval_secs: u64, + }, + RoundRobinConfig { + timeout_secs: u64, + interval_secs: u64, + }, CacheAwareConfig { cache_threshold: f32, balance_abs_threshold: usize, balance_rel_threshold: f32, eviction_interval_secs: u64, max_tree_size: usize, + timeout_secs: u64, + interval_secs: u64, }, } impl Router { pub fn new(worker_urls: Vec, policy_config: PolicyConfig) -> Result { + // Get timeout and interval from policy config + let (timeout_secs, interval_secs) = match &policy_config { + PolicyConfig::RandomConfig { + timeout_secs, + interval_secs, + } => (*timeout_secs, *interval_secs), + PolicyConfig::RoundRobinConfig { + timeout_secs, + interval_secs, + } => (*timeout_secs, *interval_secs), + PolicyConfig::CacheAwareConfig { + timeout_secs, + interval_secs, + .. + } => (*timeout_secs, *interval_secs), + }; + // Wait until all workers are healthy - Self::wait_for_healthy_workers(&worker_urls, 300, 10)?; + Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?; // Create router based on policy... Ok(match policy_config { - PolicyConfig::RandomConfig => Router::Random { + PolicyConfig::RandomConfig { + timeout_secs, + interval_secs, + } => Router::Random { worker_urls: Arc::new(RwLock::new(worker_urls)), + timeout_secs, + interval_secs, }, - PolicyConfig::RoundRobinConfig => Router::RoundRobin { + PolicyConfig::RoundRobinConfig { + timeout_secs, + interval_secs, + } => Router::RoundRobin { worker_urls: Arc::new(RwLock::new(worker_urls)), current_index: std::sync::atomic::AtomicUsize::new(0), + timeout_secs, + interval_secs, }, PolicyConfig::CacheAwareConfig { cache_threshold, @@ -126,6 +179,8 @@ impl Router { balance_rel_threshold, eviction_interval_secs, max_tree_size, + timeout_secs, + interval_secs, } => { let mut running_queue = HashMap::new(); for url in &worker_urls { @@ -176,6 +231,8 @@ impl Router { cache_threshold, balance_abs_threshold, balance_rel_threshold, + timeout_secs, + interval_secs, _eviction_thread: Some(eviction_thread), } } @@ -192,9 +249,13 @@ impl Router { loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { + error!( + "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls + ); return Err(format!( - "Timeout {}s waiting for workers to become healthy", - timeout_secs + "Timeout {}s waiting for workers {:?} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_urls )); } @@ -238,7 +299,7 @@ impl Router { fn select_first_worker(&self) -> Result { match self { Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } + | Router::Random { worker_urls, .. } | Router::CacheAware { worker_urls, .. 
} => { if worker_urls.read().unwrap().is_empty() { Err("No workers are available".to_string()) @@ -254,8 +315,18 @@ impl Router { client: &reqwest::Client, worker_url: &str, route: &str, + req: &HttpRequest, ) -> HttpResponse { - match client.get(format!("{}{}", worker_url, route)).send().await { + let mut request_builder = client.get(format!("{}{}", worker_url, route)); + + // Copy all headers from original request except for /health because it does not need authorization + if route != "/health" { + for (name, value) in copy_request_headers(req) { + request_builder = request_builder.header(name, value); + } + } + + match request_builder.send().await { Ok(res) => { let status = actix_web::http::StatusCode::from_u16(res.status().as_u16()) .unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR); @@ -273,7 +344,12 @@ impl Router { } } - pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse { + pub async fn route_to_first( + &self, + client: &reqwest::Client, + route: &str, + req: &HttpRequest, + ) -> HttpResponse { const MAX_REQUEST_RETRIES: u32 = 3; const MAX_TOTAL_RETRIES: u32 = 6; let mut total_retries = 0; @@ -289,10 +365,17 @@ impl Router { info!("Retrying request after {} failed attempts", total_retries); } - let response = self.send_request(client, &worker_url, route).await; + let response = self.send_request(client, &worker_url, route, req).await; if response.status().is_success() { return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + return response; + } } warn!( @@ -349,6 +432,7 @@ impl Router { Router::RoundRobin { worker_urls, current_index, + .. } => { let idx = current_index .fetch_update( @@ -360,7 +444,7 @@ impl Router { worker_urls.read().unwrap()[idx].clone() } - Router::Random { worker_urls } => worker_urls.read().unwrap() + Router::Random { worker_urls, .. } => worker_urls.read().unwrap() [rand::random::() % worker_urls.read().unwrap().len()] .clone(), @@ -446,19 +530,16 @@ impl Router { .map(|v| v.get("stream").and_then(|s| s.as_bool()).unwrap_or(false)) .unwrap_or(false); - let res = match client + let mut request_builder = client .post(format!("{}{}", worker_url, route)) - .header( - "Content-Type", - req.headers() - .get("Content-Type") - .and_then(|h| h.to_str().ok()) - .unwrap_or("application/json"), - ) - .body(body.to_vec()) - .send() - .await - { + .body(body.to_vec()); + + // Copy all headers from original request + for (name, value) in copy_request_headers(req) { + request_builder = request_builder.header(name, value); + } + + let res = match request_builder.send().await { Ok(res) => res, Err(_) => return HttpResponse::InternalServerError().finish(), }; @@ -546,6 +627,13 @@ impl Router { if response.status().is_success() { return response; + } else { + // if the worker is healthy, it means the request is bad, so return the error response + let health_response = + self.send_request(client, &worker_url, "/health", req).await; + if health_response.status().is_success() { + return response; + } } warn!( @@ -570,16 +658,35 @@ impl Router { } pub async fn add_worker(&self, worker_url: &str) -> Result { - let interval_secs = 10; // check every 10 seconds - let timeout_secs = 300; // 5 minutes + let (timeout_secs, interval_secs) = match self { + Router::Random { + timeout_secs, + interval_secs, + .. 
+ } => (*timeout_secs, *interval_secs), + Router::RoundRobin { + timeout_secs, + interval_secs, + .. + } => (*timeout_secs, *interval_secs), + Router::CacheAware { + timeout_secs, + interval_secs, + .. + } => (*timeout_secs, *interval_secs), + }; let start_time = std::time::Instant::now(); let client = reqwest::Client::new(); loop { if start_time.elapsed() > Duration::from_secs(timeout_secs) { + error!( + "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", + timeout_secs, worker_url + ); return Err(format!( - "Timeout {}s waiting for worker {} to become healthy", + "Timeout {}s waiting for worker {} to become healthy. Please set --router-worker-startup-timeout-secs (sglang_router.launch_server) or --worker-startup-timeout-secs (sglang_worker.router) to a larger value", timeout_secs, worker_url )); } @@ -589,7 +696,7 @@ impl Router { if res.status().is_success() { match self { Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } + | Router::Random { worker_urls, .. } | Router::CacheAware { worker_urls, .. } => { info!("Worker {} health check passed", worker_url); let mut urls = worker_urls.write().unwrap(); @@ -663,7 +770,7 @@ impl Router { pub fn remove_worker(&self, worker_url: &str) { match self { Router::RoundRobin { worker_urls, .. } - | Router::Random { worker_urls } + | Router::Random { worker_urls, .. } | Router::CacheAware { worker_urls, .. } => { let mut urls = worker_urls.write().unwrap(); if let Some(index) = urls.iter().position(|url| url == &worker_url) { diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index 09878f07f8ec..0706c57c06cc 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -18,45 +18,45 @@ impl AppState { worker_urls: Vec, client: reqwest::Client, policy_config: PolicyConfig, - ) -> Self { + ) -> Result { // Create router based on policy - let router = match Router::new(worker_urls, policy_config) { - Ok(router) => router, - Err(error) => panic!("Failed to create router: {}", error), - }; - - Self { router, client } + let router = Router::new(worker_urls, policy_config)?; + Ok(Self { router, client }) } } #[get("/health")] -async fn health(data: web::Data) -> impl Responder { - data.router.route_to_first(&data.client, "/health").await +async fn health(req: HttpRequest, data: web::Data) -> impl Responder { + data.router + .route_to_first(&data.client, "/health", &req) + .await } #[get("/health_generate")] -async fn health_generate(data: web::Data) -> impl Responder { +async fn health_generate(req: HttpRequest, data: web::Data) -> impl Responder { data.router - .route_to_first(&data.client, "/health_generate") + .route_to_first(&data.client, "/health_generate", &req) .await } #[get("/get_server_info")] -async fn get_server_info(data: web::Data) -> impl Responder { +async fn get_server_info(req: HttpRequest, data: web::Data) -> impl Responder { data.router - .route_to_first(&data.client, "/get_server_info") + .route_to_first(&data.client, "/get_server_info", &req) .await } #[get("/v1/models")] -async fn v1_models(data: web::Data) -> impl Responder { - data.router.route_to_first(&data.client, "/v1/models").await +async fn v1_models(req: HttpRequest, data: web::Data) -> impl Responder { + data.router + .route_to_first(&data.client, "/v1/models", &req) + .await } #[get("/get_model_info")] -async fn get_model_info(data: web::Data) -> impl Responder { +async fn 
get_model_info(req: HttpRequest, data: web::Data) -> impl Responder { data.router - .route_to_first(&data.client, "/get_model_info") + .route_to_first(&data.client, "/get_model_info", &req) .await } @@ -131,6 +131,7 @@ pub struct ServerConfig { } pub async fn startup(config: ServerConfig) -> std::io::Result<()> { + // Initialize logger Builder::new() .format(|buf, record| { use chrono::Local; @@ -152,24 +153,30 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> { ) .init(); + info!("🚧 Initializing router on {}:{}", config.host, config.port); + info!("🚧 Initializing workers on {:?}", config.worker_urls); + info!("🚧 Policy Config: {:?}", config.policy_config); + info!( + "🚧 Max payload size: {} MB", + config.max_payload_size / (1024 * 1024) + ); + let client = reqwest::Client::builder() .build() .expect("Failed to create HTTP client"); - let app_state = web::Data::new(AppState::new( - config.worker_urls.clone(), - client, - config.policy_config.clone(), - )); - - info!("✅ Starting router on {}:{}", config.host, config.port); - info!("✅ Serving Worker URLs: {:?}", config.worker_urls); - info!("✅ Policy Config: {:?}", config.policy_config); - info!( - "✅ Max payload size: {} MB", - config.max_payload_size / (1024 * 1024) + let app_state = web::Data::new( + AppState::new( + config.worker_urls.clone(), + client, + config.policy_config.clone(), + ) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?, ); + info!("✅ Serving router on {}:{}", config.host, config.port); + info!("✅ Serving workers on {:?}", config.worker_urls); + HttpServer::new(move || { App::new() .app_data(app_state.clone()) diff --git a/sgl-router/v0.1.0.md b/sgl-router/v0.1.0.md index 9a1ee152f113..747731a71c2d 100644 --- a/sgl-router/v0.1.0.md +++ b/sgl-router/v0.1.0.md @@ -54,7 +54,7 @@ Note: ## Closing remarks: -1. Please read the full usage at https://sgl-project.github.io/router/router.html +1. Please read the full usage at https://docs.sglang.ai/router/router.html 2. The feature is still under active improvement, so please don't hesitate to raise issues or submit PRs if you have any suggestions or feedback. 
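Because the Rust router now copies the incoming request headers onto the proxied request (`copy_request_headers`, skipped only for `/health`), authentication headers set by the client reach workers that were launched with `--api-key`, and the GET routes (`/get_server_info`, `/v1/models`, `/get_model_info`, ...) pass them through as well. The sketch below is an illustrative client-side view of that behavior, assuming a router reachable at the given URL whose workers enforce an API key; the URL, key, and prompt are placeholders.

```python
# Hedged client-side sketch of header forwarding through the router
# (addresses and the API key are illustrative, not from this patch).
import requests

BASE_URL = "http://127.0.0.1:30000"  # hypothetical router address

# The Authorization header is forwarded to the selected worker.
resp = requests.post(
    f"{BASE_URL}/generate",
    json={"text": "The capital of France is", "sampling_params": {"temperature": 0}},
    headers={"Authorization": "Bearer correct_api_key"},
)
print(resp.status_code, resp.text)

# Without the header the worker rejects the request; since the worker's
# /health check still succeeds, the router returns the 401 directly
# instead of treating it as a worker failure and retrying.
resp = requests.post(f"{BASE_URL}/generate", json={"text": "Hello"})
assert resp.status_code == 401
```

This mirrors the `test_5_api_key` case added in `test_launch_server.py` above, where missing or wrong keys yield 401 and the correct bearer token yields 200.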
diff --git a/test/README.md b/test/README.md index 3d739cc04967..868061bbc4a5 100644 --- a/test/README.md +++ b/test/README.md @@ -25,7 +25,7 @@ export OPENAI_API_KEY=sk-***** python3 test_openai_backend.py # Run a single test -python3 -m unittest test_openai_backend.TestOpenAIBackend.test_few_shot_qa +python3 -m unittest test_openai_backend.TestOpenAIServer.test_few_shot_qa # Run a suite with multiple files python3 run_suite.py --suite per-commit diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py index b99606fc1cb8..a4b1b88a23d3 100644 --- a/test/lang/test_srt_backend.py +++ b/test/lang/test_srt_backend.py @@ -1,6 +1,7 @@ """ Usage: python3 -m unittest test_srt_backend.TestSRTBackend.test_gen_min_new_tokens +python3 -m unittest test_srt_backend.TestSRTBackend.test_hellaswag_select """ import unittest @@ -73,7 +74,7 @@ def test_hellaswag_select(self): # Run twice to capture more bugs for _ in range(2): accuracy, latency = test_hellaswag_select() - self.assertGreater(accuracy, 0.71) + self.assertGreater(accuracy, 0.70) def test_gen_min_new_tokens(self): test_gen_min_new_tokens() diff --git a/test/srt/kv_cache_scales_llama3_1_8b.json b/test/srt/kv_cache_scales_llama3_1_8b.json new file mode 100644 index 000000000000..3e890e50e4af --- /dev/null +++ b/test/srt/kv_cache_scales_llama3_1_8b.json @@ -0,0 +1,42 @@ +{ + "model_type": "llama", + "kv_cache": { + "dtype": "float8_e4m3fn", + "scaling_factor": { + "0": { + "0": 1, + "1": 1, + "2": 1, + "3": 1, + "4": 1, + "5": 1, + "6": 1, + "7": 1, + "8": 1, + "9": 1, + "10": 1, + "11": 1, + "12": 1, + "13": 1, + "14": 1, + "15": 1, + "16": 1, + "17": 1, + "18": 1, + "19": 1, + "20": 1, + "21": 1, + "22": 1, + "23": 1, + "24": 1, + "25": 1, + "26": 1, + "27": 1, + "28": 1, + "29": 1, + "30": 1, + "31": 1 + } + } + } +} diff --git a/test/srt/kv_cache_scales_llama3_8b.json b/test/srt/kv_cache_scales_llama3_8b.json new file mode 100644 index 000000000000..466b0d01a74c --- /dev/null +++ b/test/srt/kv_cache_scales_llama3_8b.json @@ -0,0 +1,42 @@ +{ + "model_type": "llama", + "kv_cache": { + "dtype": "float8_e4m3fn", + "scaling_factor": { + "0": { + "0": 0.0408, + "1": 0.0503, + "2": 0.0667, + "3": 0.0909, + "4": 0.1135, + "5": 0.127, + "6": 0.1768, + "7": 0.1488, + "8": 0.1135, + "9": 0.1203, + "10": 0.1013, + "11": 0.0842, + "12": 0.1231, + "13": 0.1096, + "14": 0.1221, + "15": 0.1013, + "16": 0.1067, + "17": 0.0952, + "18": 0.0899, + "19": 0.097, + "20": 0.087, + "21": 0.0994, + "22": 0.0904, + "23": 0.1013, + "24": 0.1019, + "25": 0.1053, + "26": 0.1, + "27": 0.0894, + "28": 0.1013, + "29": 0.1488, + "30": 0.0766, + "31": 0.0821 + } + } + } +} diff --git a/test/srt/kv_cache_scales_qwen2_1_5b.json b/test/srt/kv_cache_scales_qwen2_1_5b.json new file mode 100644 index 000000000000..984747509f70 --- /dev/null +++ b/test/srt/kv_cache_scales_qwen2_1_5b.json @@ -0,0 +1,38 @@ +{ + "model_type": "qwen", + "kv_cache": { + "dtype": "float8_e4m3fn", + "scaling_factor": { + "0": { + "0": 0.9846, + "1": 0.0645, + "2": 0.0731, + "3": 0.0800, + "4": 0.0748, + "5": 0.0780, + "6": 0.0702, + "7": 0.0894, + "8": 0.0410, + "9": 0.0758, + "10": 0.0556, + "11": 0.0731, + "12": 0.0899, + "13": 0.0780, + "14": 0.1441, + "15": 0.0914, + "16": 0.5614, + "17": 0.1067, + "18": 0.0537, + "19": 0.0658, + "20": 0.0523, + "21": 0.0533, + "22": 0.0699, + "23": 0.0635, + "24": 0.0588, + "25": 0.0884, + "26": 0.0947, + "27": 0.1032 + } + } + } +} diff --git a/test/srt/models/test_qwen_models.py b/test/srt/models/test_qwen_models.py new file mode 100644 index 
000000000000..c7788fa8e500 --- /dev/null +++ b/test/srt/models/test_qwen_models.py @@ -0,0 +1,76 @@ +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestQwen2(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "Qwen/Qwen2-7B-Instruct" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.81) + + +class TestQwen2FP8(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "neuralmagic/Qwen2-7B-Instruct-FP8" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + self.assertGreater(metrics["accuracy"], 0.79) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/models/test_reward_models.py b/test/srt/models/test_reward_models.py index 0d80a4d0cde8..69ad563671b5 100644 --- a/test/srt/models/test_reward_models.py +++ b/test/srt/models/test_reward_models.py @@ -20,8 +20,8 @@ from sglang.test.runners import HFRunner, SRTRunner MODELS = [ - ("LxzGordon/URM-LLaMa-3.1-8B", 1, 3e-2), - ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 3e-2), + ("LxzGordon/URM-LLaMa-3.1-8B", 1, 4e-2), + ("Skywork/Skywork-Reward-Llama-3.1-8B-v0.2", 1, 4e-2), ] TORCH_DTYPES = [torch.float16] diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 83d2e90a43a9..69a5470bee40 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -8,10 +8,12 @@ "models/test_embedding_models.py", "models/test_generation_models.py", "models/test_lora.py", + "models/test_qwen_models.py", "models/test_reward_models.py", "sampling/penaltylib", "test_abort.py", "test_chunked_prefill.py", + "test_custom_allreduce.py", "test_double_sparsity.py", "test_eagle_infer.py", "test_embedding_openai_server.py", @@ -22,11 +24,16 @@ "test_json_constrained.py", "test_large_max_new_tokens.py", "test_metrics.py", + "test_mla.py", + "test_mla_fp8.py", "test_no_chunked_prefill.py", "test_no_overlap_scheduler.py", "test_openai_server.py", "test_pytorch_sampling_backend.py", "test_radix_attention.py", + "test_regex_constrained.py", + "test_release_memory_occupation.py", + "test_request_length_validation.py", "test_retract_decode.py", "test_server_args.py", "test_session_control.py", @@ -35,8 +42,7 @@ "test_srt_endpoint.py", "test_torch_compile.py", "test_torch_compile_moe.py", - # Temporarily disable this because it requires PyTorch >= 2.5 - # "test_torch_native_attention_backend.py", + 
"test_torch_native_attention_backend.py", "test_torchao.py", "test_triton_attention_kernels.py", "test_triton_attention_backend.py", @@ -44,13 +50,13 @@ "test_update_weights_from_tensor.py", "test_vision_chunked_prefill.py", "test_vision_openai_server.py", + "test_w8a8_quantization.py", "test_session_control.py", - "test_engine_token_ids.py", + "test_fp8_kvcache.py", ], "nightly": [ "test_nightly_gsm8k_eval.py", - "test_nightly_human_eval.py", - # Disable temporarly + # Disable temporarily # "test_nightly_math_eval.py", ], "sampling/penaltylib": glob.glob( diff --git a/test/srt/test_bench_one_batch.py b/test/srt/test_bench_one_batch.py index c1bc98e8e042..c6562170d610 100644 --- a/test/srt/test_bench_one_batch.py +++ b/test/srt/test_bench_one_batch.py @@ -5,24 +5,46 @@ DEFAULT_MOE_MODEL_NAME_FOR_TEST, is_in_ci, run_bench_one_batch, + write_github_step_summary, ) class TestBenchOneBatch(unittest.TestCase): - def test_default(self): + def test_bs1(self): output_throughput = run_bench_one_batch(DEFAULT_MODEL_NAME_FOR_TEST, []) if is_in_ci(): + write_github_step_summary( + f"### test_bs1\n" + f"output_throughput : {output_throughput:.2f} token/s\n" + ) self.assertGreater(output_throughput, 135) - def test_moe_default(self): + def test_moe_tp2_bs1(self): output_throughput = run_bench_one_batch( DEFAULT_MOE_MODEL_NAME_FOR_TEST, ["--tp", "2"] ) if is_in_ci(): + write_github_step_summary( + f"### test_moe_tp2_bs1\n" + f"output_throughput : {output_throughput:.2f} token/s\n" + ) self.assertGreater(output_throughput, 125) + def test_torch_compile_tp2_bs1(self): + output_throughput = run_bench_one_batch( + DEFAULT_MODEL_NAME_FOR_TEST, + ["--tp", "2", "--enable-torch-compile", "--cuda-graph-max-bs", "2"], + ) + + if is_in_ci(): + write_github_step_summary( + f"### test_torch_compile_tp2_bs1\n" + f"output_throughput : {output_throughput:.2f} token/s\n" + ) + self.assertGreater(output_throughput, 240) + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py index b882f12f9df5..8233438fcaf2 100644 --- a/test/srt/test_bench_serving.py +++ b/test/srt/test_bench_serving.py @@ -1,6 +1,8 @@ import unittest from sglang.test.test_utils import ( + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, DEFAULT_FP8_MODEL_NAME_FOR_TEST, DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_MOE_MODEL_NAME_FOR_TEST, @@ -47,7 +49,7 @@ def test_offline_throughput_non_stream_small_batch_size(self): ) # There is a regression with torch 2.5 # This number was 950 for torch 2.4 - self.assertGreater(res["output_throughput"], 800) + self.assertGreater(res["output_throughput"], 1000) def test_offline_throughput_without_radix_cache(self): res = run_bench_serving( @@ -112,7 +114,7 @@ def test_offline_throughput_default_fp8(self): f"### test_offline_throughput_default_fp8\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - self.assertGreater(res["output_throughput"], 3850) + self.assertGreater(res["output_throughput"], 3900) def test_online_latency_default(self): res = run_bench_serving( @@ -127,10 +129,40 @@ def test_online_latency_default(self): f"### test_online_latency_default\n" f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' ) - self.assertLess(res["median_e2e_latency_ms"], 12000) + self.assertLess(res["median_e2e_latency_ms"], 11000) self.assertLess(res["median_ttft_ms"], 86) self.assertLess(res["median_itl_ms"], 10) + def test_online_latency_eagle(self): + res = run_bench_serving( + 
model=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + num_prompts=50, + request_rate=1, + disable_ignore_eos=True, + dataset_name="sharegpt", + other_server_args=[ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model-path", + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "8", + "--speculative-num-draft-tokens", + "64", + "--mem-fraction-static", + "0.7", + ], + ) + + if is_in_ci(): + write_github_step_summary( + f"### test_online_latency_eagle\n" + f'median_e2e_latency_ms : {res["median_e2e_latency_ms"]:.2f} ms\n' + ) + self.assertLess(res["median_e2e_latency_ms"], 450) + def test_moe_offline_throughput_default(self): res = run_bench_serving( model=DEFAULT_MOE_MODEL_NAME_FOR_TEST, @@ -144,7 +176,7 @@ def test_moe_offline_throughput_default(self): f"### test_moe_offline_throughput_default\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - self.assertGreater(res["output_throughput"], 2150) + self.assertGreater(res["output_throughput"], 2200) def test_moe_offline_throughput_without_radix_cache(self): res = run_bench_serving( @@ -159,7 +191,7 @@ def test_moe_offline_throughput_without_radix_cache(self): f"### test_moe_offline_throughput_without_radix_cache\n" f'Output throughput: {res["output_throughput"]:.2f} token/s\n' ) - self.assertGreater(res["output_throughput"], 2150) + self.assertGreater(res["output_throughput"], 2200) if __name__ == "__main__": diff --git a/test/srt/test_custom_allreduce.py b/test/srt/test_custom_allreduce.py new file mode 100644 index 000000000000..5f6f5d9b4918 --- /dev/null +++ b/test/srt/test_custom_allreduce.py @@ -0,0 +1,164 @@ +import os +import random +import socket +import unittest +from typing import Any + +import ray +import torch +import torch.distributed as dist + +from sglang.srt.distributed import init_distributed_environment +from sglang.srt.distributed.communication_op import ( # noqa + tensor_model_parallel_all_reduce, +) +from sglang.srt.distributed.parallel_state import ( + get_tensor_model_parallel_group, + graph_capture, + initialize_model_parallel, +) + + +def get_open_port() -> int: + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def multi_process_parallel( + world_size: int, + cls: Any, + test_target: Any, +) -> None: + + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. 
+ # NOTE: We need to set working_dir for distributed tests, + # otherwise we may get import errors on ray workers + ray.init(log_to_driver=False) + + distributed_init_port = get_open_port() + refs = [] + for rank in range(world_size): + refs.append(test_target.remote(cls, world_size, rank, distributed_init_port)) + ray.get(refs) + + ray.shutdown() + + +class TestCustomAllReduce(unittest.TestCase): + @classmethod + def setUpClass(cls): + random.seed(42) + # 512B to 32MB + cls.test_sizes = [512, 4096, 32768, 262144, 2097152, 16777216, 33554432] + cls.world_sizes = [2, 4, 6, 8] + cls.test_loop = 10 + + def test_graph_allreduce(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.graph_allreduce) + + def test_eager_allreduce(self): + for world_size in self.world_sizes: + if world_size > torch.cuda.device_count(): + continue + multi_process_parallel(world_size, self, self.eager_allreduce) + + @ray.remote(num_gpus=1, max_calls=1) + def graph_allreduce(self, world_size, rank, distributed_init_port): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + world_size=world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=rank, + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + group = get_tensor_model_parallel_group().device_group + + # A small all_reduce for warmup. + # this is needed because device communicators might be created lazily + # (e.g. NCCL). This will ensure that the communicator is initialized + # before any communication happens, so that this group can be used for + # graph capture immediately. 
+ data = torch.zeros(1) + data = data.to(device=device) + torch.distributed.all_reduce(data, group=group) + torch.cuda.synchronize() + del data + + for sz in self.test_sizes: + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + for _ in range(self.test_loop): + with graph_capture() as graph_capture_context: + # use integers so result matches NCCL exactly + inp1 = torch.randint( + 1, + 16, + (sz,), + dtype=dtype, + device=torch.cuda.current_device(), + ) + inp2 = torch.randint( + 1, + 16, + (sz,), + dtype=dtype, + device=torch.cuda.current_device(), + ) + torch.cuda.synchronize() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph( + graph, stream=graph_capture_context.stream + ): + out1 = tensor_model_parallel_all_reduce(inp1) + # the input buffer is immediately modified to test + # synchronization + dist.all_reduce(inp1, group=group) + out2 = tensor_model_parallel_all_reduce(inp2) + dist.all_reduce(inp2, group=group) + graph.replay() + torch.testing.assert_close(out1, inp1) + torch.testing.assert_close(out2, inp2) + + @ray.remote(num_gpus=1, max_calls=1) + def eager_allreduce(self, world_size, rank, distributed_init_port): + del os.environ["CUDA_VISIBLE_DEVICES"] + device = torch.device(f"cuda:{rank}") + torch.cuda.set_device(device) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + world_size=world_size, + rank=rank, + distributed_init_method=distributed_init_method, + local_rank=rank, + ) + initialize_model_parallel(tensor_model_parallel_size=world_size) + group = get_tensor_model_parallel_group().device_group + + for sz in self.test_sizes: + for dtype in [torch.float32, torch.float16, torch.bfloat16]: + for _ in range(self.test_loop): + inp1 = torch.randint( + 1, 16, (sz,), dtype=dtype, device=torch.cuda.current_device() + ) + out1 = tensor_model_parallel_all_reduce(inp1) + dist.all_reduce(inp1, group=group) + torch.testing.assert_close(out1, inp1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py index 609d4411d77d..b01c260496a8 100644 --- a/test/srt/test_eagle_infer.py +++ b/test/srt/test_eagle_infer.py @@ -1,38 +1,179 @@ +import random +import threading +import time import unittest +from types import SimpleNamespace + +import requests import sglang as sgl +from sglang.srt.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) class TestEAGLEEngine(unittest.TestCase): def test_eagle_accuracy(self): prompt = "Today is a sunny day and I like" - target_model_path = "meta-llama/Llama-2-7b-chat-hf" - speculative_draft_model_path = "lmzheng/sglang-EAGLE-llama2-chat-7B" - sampling_params = {"temperature": 0, "max_new_tokens": 8} + # Get the reference output + ref_engine = sgl.Engine(model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) + ref_output = ref_engine.generate(prompt, sampling_params)["text"] + ref_engine.shutdown() + + # Launch EAGLE engine engine = sgl.Engine( - model_path=target_model_path, - speculative_draft_model_path=speculative_draft_model_path, + model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, speculative_algorithm="EAGLE", - speculative_num_steps=3, - speculative_eagle_topk=4, - 
speculative_num_draft_tokens=16, + speculative_num_steps=5, + speculative_eagle_topk=8, + speculative_num_draft_tokens=64, + mem_fraction_static=0.7, ) + + # Case 1: Test the output of EAGLE engine is the same as normal engine out1 = engine.generate(prompt, sampling_params)["text"] - engine.shutdown() + print(f"{out1=}, {ref_output=}") + self.assertEqual(out1, ref_output) - engine = sgl.Engine(model_path=target_model_path) + # Case 2: Test the output of EAGLE engine does not contain unexpected EOS + prompt = "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like [/INST]" + sampling_params = { + "temperature": 0, + "max_new_tokens": 1024, + "skip_special_tokens": False, + } + + tokenizer = get_tokenizer(DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST) out2 = engine.generate(prompt, sampling_params)["text"] + print(f"{out2=}") + tokens = tokenizer.encode(out2, truncation=False) + assert tokenizer.eos_token_id not in tokens + + # Case 3: Batched prompts + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + sampling_params = {"temperature": 0, "max_new_tokens": 30} + outputs = engine.generate(prompts, sampling_params) + for prompt, output in zip(prompts, outputs): + print("===============================") + print(f"Prompt: {prompt}\nGenerated text: {output['text']}") + + # Shutdown the engine engine.shutdown() - print("==== Answer 1 ====") - print(out1) - print("==== Answer 2 ====") - print(out2) - self.assertEqual(out1, out2) +prompts = [ + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]" + '[INST] <>\\nYou are a helpful assistant.\\n<>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]', + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nSummarize Russell Brunson's Perfect Webinar Script...[/INST]", + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwho are you?[/INST]", + "[INST] <>\\nYou are a helpful assistant.\\n<>\\nwhere are you from?[/INST]", +] + + +class TestEAGLEServer(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--speculative-algorithm", + "EAGLE", + "--speculative-draft-model-path", + DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, + "--speculative-num-steps", + "5", + "--speculative-eagle-topk", + "8", + "--speculative-num-draft-tokens", + "64", + "--mem-fraction-static", + "0.7", + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def send_request(self): + time.sleep(random.uniform(0, 2)) + for prompt in prompts: + url = self.base_url + "/generate" + data = { + "text": prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 1024, + }, + } + response = requests.post(url, json=data) + assert response.status_code == 200 + + def send_requests_abort(self): + for prompt in prompts: + try: + time.sleep(random.uniform(0, 2)) + url = self.base_url + "/generate" + data = { + "model": "base", + "text": prompt, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 1024, + }, + } + # set timeout = 1s,mock disconnected + requests.post(url, json=data, timeout=1) + except Exception as e: + print(e) + pass + + def test_request_abort(self): + concurrency = 4 + threads = [ + threading.Thread(target=self.send_request) for _ in range(concurrency) + ] + [ + 
threading.Thread(target=self.send_requests_abort) + for _ in range(concurrency) + ] + for worker in threads: + worker.start() + for p in threads: + p.join() + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + + self.assertGreater(metrics["accuracy"], 0.20) if __name__ == "__main__": diff --git a/test/srt/test_engine_token_ids.py b/test/srt/test_engine_token_ids.py deleted file mode 100644 index 4dee24edc9de..000000000000 --- a/test/srt/test_engine_token_ids.py +++ /dev/null @@ -1,45 +0,0 @@ -import unittest - -from transformers import AutoTokenizer - -import sglang as sgl -from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST - - -class TestEngineTokenIds(unittest.TestCase): - def test_token_ids_in_generate(self): - llm = sgl.Engine( - model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, return_token_ids=True - ) - tokenizer = AutoTokenizer.from_pretrained(DEFAULT_SMALL_MODEL_NAME_FOR_TEST) - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - sampling_params = {"temperature": 0, "top_p": 0.95} - outputs = llm.generate(prompts, sampling_params) - - for prompt, output in zip(prompts, outputs): - deocode_input = tokenizer.decode( - output["input_ids"], skip_special_tokens=True - ) - assert (deocode_input in prompt) or ( - prompt in deocode_input - ), f"Decode input: {deocode_input} mismatch for: {prompt}" - - deocode_output = tokenizer.decode( - output["output_ids"], skip_special_tokens=True - ) - assert (deocode_output in output["text"]) or ( - output["text"] in deocode_output - ), f"Decode output: {deocode_output} mismatch for: {output['text']}" - - llm.shutdown() - - -if __name__ == "__main__": - unittest.main() diff --git a/test/srt/test_fp8_kvcache.py b/test/srt/test_fp8_kvcache.py new file mode 100644 index 000000000000..4a8a2434699b --- /dev/null +++ b/test/srt/test_fp8_kvcache.py @@ -0,0 +1,113 @@ +import os +import unittest +from types import SimpleNamespace + +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestFp8KvcacheBase(unittest.TestCase): + model_config = None + + @classmethod + def setUpClass(cls): + if cls.model_config is None: + raise NotImplementedError("model_config must be specified in subclass") + + cls.model = cls.model_config["model_name"] + cls.base_url = DEFAULT_URL_FOR_TEST + dirpath = os.path.dirname(__file__) + config_file = os.path.join(dirpath, cls.model_config["config_filename"]) + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--kv-cache-dtype", + "fp8_e4m3", + "--quantization-param-path", + config_file, + ], + ) + + +class TestFp8KvcacheLlama(TestFp8KvcacheBase): + model_config = { + "model_name": DEFAULT_MODEL_NAME_FOR_TEST, + "config_filename": "kv_cache_scales_llama3_8b.json", + } + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + 
num_threads=1024, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.80) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.65) + + +class TestFp8KvcacheQwen(TestFp8KvcacheBase): + model_config = { + "model_name": DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN, + "config_filename": "kv_cache_scales_qwen2_1_5b.json", + } + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mgsm_en(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mgsm_en", + num_examples=None, + num_threads=1024, + ) + + metrics = run_eval(args) + self.assertGreater(metrics["score"], 0.01) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.3) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py index ccaea5be800e..2837107a1e6b 100644 --- a/test/srt/test_metrics.py +++ b/test/srt/test_metrics.py @@ -56,9 +56,9 @@ def test_metrics_enabled(self): "sglang:gen_throughput", "sglang:num_queue_reqs", "sglang:cache_hit_rate", - "sglang:func_latency_seconds", "sglang:prompt_tokens_total", "sglang:generation_tokens_total", + "sglang:num_requests_total", "sglang:time_to_first_token_seconds", "sglang:time_per_output_token_seconds", "sglang:e2e_request_latency_seconds", diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py index b8105a84af1a..34bc4b446452 100644 --- a/test/srt/test_mla.py +++ b/test/srt/test_mla.py @@ -2,6 +2,7 @@ from types import SimpleNamespace from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -20,7 +21,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--tp", "2", "--trust-remote-code"], + other_args=["--trust-remote-code"], ) @classmethod @@ -52,5 +53,37 @@ def test_mgsm_en(self): self.assertGreater(metrics["score"], 0.8) +class TestDeepseekV3(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "lmzheng/sglang-ci-dsv3-test" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--trust-remote-code"], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.62) + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_mla_fp8.py b/test/srt/test_mla_fp8.py index 769bdf34da87..4fe18b526b1e 100644 --- a/test/srt/test_mla_fp8.py +++ b/test/srt/test_mla_fp8.py @@ -21,8 +21,6 @@ def setUpClass(cls): cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=[ - "--tp", - "2", "--trust-remote-code", "--kv-cache-dtype", "fp8_e5m2", diff --git a/test/srt/test_moe_ep.py 
b/test/srt/test_moe_ep.py index 4d9fd435edb5..9f87eb24d719 100644 --- a/test/srt/test_moe_ep.py +++ b/test/srt/test_moe_ep.py @@ -44,7 +44,7 @@ def test_mmlu(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.5 + self.assertGreater(metrics["score"], 0.5) def test_mgsm_en(self): args = SimpleNamespace( @@ -56,7 +56,7 @@ def test_mgsm_en(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.8 + self.assertGreater(metrics["score"], 0.8) class TestEpMoEFP8(unittest.TestCase): diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py index 6f3affbba4d7..dc420f00dfaf 100644 --- a/test/srt/test_moe_eval_accuracy_large.py +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -71,7 +71,7 @@ def test_mgsm_en(self): ) metrics = run_eval(args) - self.assertGreater(metrics["score"], 0.62) + self.assertGreater(metrics["score"], 0.61) if __name__ == "__main__": diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py index 7e23b721e433..06c83048f39b 100644 --- a/test/srt/test_nightly_gsm8k_eval.py +++ b/test/srt/test_nightly_gsm8k_eval.py @@ -1,6 +1,5 @@ import json import os -import subprocess import unittest import warnings from datetime import datetime @@ -16,24 +15,26 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, + is_in_ci, popen_launch_server, + write_github_step_summary, ) MODEL_SCORE_THRESHOLDS = { - "meta-llama/Llama-3.1-8B-Instruct": 0.83, + "meta-llama/Llama-3.1-8B-Instruct": 0.82, "mistralai/Mistral-7B-Instruct-v0.3": 0.58, - "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.84, + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct": 0.85, "google/gemma-2-27b-it": 0.92, - "meta-llama/Llama-3.1-70B-Instruct": 0.96, + "meta-llama/Llama-3.1-70B-Instruct": 0.95, "mistralai/Mixtral-8x7B-Instruct-v0.1": 0.63, "Qwen/Qwen2-57B-A14B-Instruct": 0.87, - "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.84, + "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8": 0.83, "neuralmagic/Mistral-7B-Instruct-v0.3-FP8": 0.54, - "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.83, + "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8": 0.84, "neuralmagic/gemma-2-2b-it-FP8": 0.60, - "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.95, - "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.61, - "neuralmagic/Qwen2-72B-Instruct-FP8": 0.95, + "neuralmagic/Meta-Llama-3.1-70B-Instruct-FP8": 0.94, + "neuralmagic/Mixtral-8x7B-Instruct-v0.1-FP8": 0.62, + "neuralmagic/Qwen2-72B-Instruct-FP8": 0.94, "neuralmagic/Qwen2-57B-A14B-Instruct-FP8": 0.82, "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": 0.84, "hugging-quants/Meta-Llama-3.1-8B-Instruct-GPTQ-INT4": 0.83, @@ -44,7 +45,7 @@ def parse_models(model_string): return [model.strip() for model in model_string.split(",") if model.strip()] -def launch_server(base_url, model, is_fp8, is_tp2): +def popen_launch_server_wrapper(base_url, model, is_fp8, is_tp2): other_args = ["--log-level-http", "warning", "--trust-remote-code"] if is_fp8: if "Llama-3" in model or "gemma-2" in model: @@ -67,7 +68,6 @@ def launch_server(base_url, model, is_fp8, is_tp2): base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, other_args=other_args, - return_stdout_stderr=(subprocess.DEVNULL, subprocess.DEVNULL), ) return process @@ -99,6 +99,9 @@ def write_results_to_json(model, metrics, mode="a"): def check_model_scores(results): failed_models = [] + summary = " | model | score | threshold |\n" + summary += "| ----- | ----- | --------- |\n" + for model, score in results: 
threshold = MODEL_SCORE_THRESHOLDS.get(model) if threshold is None: @@ -111,11 +114,19 @@ def check_model_scores(results): f"Model {model} score ({score:.4f}) is below threshold ({threshold:.4f})" ) + line = f"| {model} | {score} | {threshold} |\n" + summary += line + + print(summary) + + if is_in_ci(): + write_github_step_summary(f"### TestNightlyGsm8KEval\n{summary}") + if failed_models: raise AssertionError("\n".join(failed_models)) -class TestEvalAccuracyLarge(unittest.TestCase): +class TestNightlyGsm8KEval(unittest.TestCase): @classmethod def setUpClass(cls): cls.model_groups = [ @@ -127,13 +138,6 @@ def setUpClass(cls): ] cls.base_url = DEFAULT_URL_FOR_TEST - def setUp(self): - self.process = None - - def tearDown(self): - if self.process: - kill_process_tree(self.process.pid) - def test_mgsm_en_all_models(self): warnings.filterwarnings( "ignore", category=ResourceWarning, message="unclosed.*socket" @@ -144,7 +148,9 @@ def test_mgsm_en_all_models(self): for model_group, is_fp8, is_tp2 in self.model_groups: for model in model_group: with self.subTest(model=model): - self.process = launch_server(self.base_url, model, is_fp8, is_tp2) + process = popen_launch_server_wrapper( + self.base_url, model, is_fp8, is_tp2 + ) args = SimpleNamespace( base_url=self.base_url, @@ -163,8 +169,7 @@ def test_mgsm_en_all_models(self): is_first = False all_results.append((model, metrics["score"])) - - self.tearDown() + kill_process_tree(process.pid) try: with open("results.json", "r") as f: diff --git a/test/srt/test_nightly_human_eval.py b/test/srt/test_nightly_human_eval.py index bffe214b5deb..6558b9effb9b 100644 --- a/test/srt/test_nightly_human_eval.py +++ b/test/srt/test_nightly_human_eval.py @@ -4,7 +4,7 @@ import subprocess import unittest -from test_nightly_gsm8k_eval import launch_server, parse_models +from test_nightly_gsm8k_eval import parse_models, popen_launch_server_wrapper from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -18,7 +18,7 @@ ) -class TestEvalAccuracyLarge(unittest.TestCase): +class TestNightlyHumanEval(unittest.TestCase): @classmethod def setUpClass(cls): if is_in_ci(): @@ -93,7 +93,7 @@ def test_human_eval_all_models(self): # NOTE: only Llama for now if "Llama" in model: with self.subTest(model=model): - self.process = launch_server( + self.process = popen_launch_server_wrapper( self.base_url, model, is_fp8, is_tp2 ) self.run_evalplus(model) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 379e57f356e9..4bedf7439663 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -14,6 +14,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( + DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -675,5 +676,45 @@ def test_function_calling_format(self): ), "Function name should be add for the above response" +class TestOpenAIEmbedding(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + # Configure embedding-specific args + other_args = ["--is-embedding", "--enable-metrics"] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=other_args, + ) + cls.base_url += "/v1" + + @classmethod + def 
tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_embedding_single(self): + """Test single embedding request""" + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.embeddings.create(model=self.model, input="Hello world") + self.assertEqual(len(response.data), 1) + self.assertTrue(len(response.data[0].embedding) > 0) + + def test_embedding_batch(self): + """Test batch embedding request""" + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + response = client.embeddings.create( + model=self.model, input=["Hello world", "Test text"] + ) + self.assertEqual(len(response.data), 2) + self.assertTrue(len(response.data[0].embedding) > 0) + self.assertTrue(len(response.data[1].embedding) > 0) + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_regex_constrained.py b/test/srt/test_regex_constrained.py new file mode 100644 index 000000000000..6d5acec15e23 --- /dev/null +++ b/test/srt/test_regex_constrained.py @@ -0,0 +1,186 @@ +""" +python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_email +python3 -m unittest test_regex_constrained.TestRegexConstrained.test_regex_generate_greeting +""" + +import json +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +def setup_class(cls, disable_overlap: bool): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + + other_args = [ + "--max-running-requests", + "10", + "--grammar-backend", + "xgrammar", + ] + + if disable_overlap: + other_args += ["--disable-overlap-schedule"] + + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + +class TestRegexConstrained(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_class(cls, disable_overlap=False) + cls.check_jump_forward = False + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def run_decode( + self, + regex, + prompt, + return_logprob=False, + top_logprobs_num=0, + n=1, + ): + response = requests.post( + self.base_url + "/generate", + json={ + "text": prompt, + "sampling_params": { + "temperature": 0 if n == 1 else 0.5, + "max_new_tokens": 128, + "n": n, + "regex": regex, + }, + "stream": False, + "return_logprob": return_logprob, + "top_logprobs_num": top_logprobs_num, + "logprob_start_len": 0, + }, + ) + + ret = response.json() + print(json.dumps(ret, indent=2)) + print("=" * 100) + + if not isinstance(ret, list): + self.fail(f"Expected response to be a list, but got {type(ret)}") + + for item in ret: + text = item.get("text", "").strip() + if not text: + self.fail("Generated text is empty.") + + if not self.regex_match(text, regex): + self.fail(f"Text '{text}' does not match regex pattern.") + + def regex_match(self, text, pattern): + import re + + return re.match(pattern, text) is not None + + def test_regex_generate_email(self): + pattern = r"^user@example\.com$" + prompt = "Generate an email address:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_greeting(self): + pattern = r"^(Hello|Hi|Hey)$" + prompt = "Generate a greeting:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_number(self): + pattern = r"^\d{3}$" + prompt = 
"Generate a three-digit number:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_phone(self): + pattern = r"^\(\d{3}\) \d{3}-\d{4}$" + prompt = "Generate a phone number:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_date(self): + pattern = r"^2024-(0[1-9]|1[0-2])-(0[1-9]|[12]\d|3[01])$" + prompt = "Generate a date in YYYY-MM-DD format:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_hex_color(self): + pattern = r"^#[0-9A-F]{6}$" + prompt = "Generate a hex color code:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_complex_json(self): + pattern = r'^\{\s*"name"\s*:\s*"[a-zA-Z0-9 ]+"\s*,\s*"age"\s*:\s*[1-9][0-9]*\s*,\s*"city"\s*:\s*"[a-zA-Z0-9 ]+"\s*\}$' + prompt = "Generate a simple JSON with name, age, and city:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + def test_regex_generate_custom_log_format(self): + pattern = r"^\[2024-01-01T12:00:00Z\] INFO: System\.process - Operation [a-z]+ successfully$" + prompt = "Generate a log entry:" + + self.run_decode( + regex=pattern, + prompt=prompt, + n=3, + ) + + +class TestJumpForward(TestRegexConstrained): + @classmethod + def setUpClass(cls): + setup_class(cls, disable_overlap=True) + cls.check_jump_forward = True + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_release_memory_occupation.py b/test/srt/test_release_memory_occupation.py new file mode 100644 index 000000000000..c84b64e77dfe --- /dev/null +++ b/test/srt/test_release_memory_occupation.py @@ -0,0 +1,98 @@ +import time +import unittest + +import torch +from transformers import AutoModelForCausalLM + +import sglang as sgl +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST + +# (temporarily) set to true to observe memory usage in nvidia-smi more clearly +_DEBUG_EXTRA = True + + +class TestReleaseMemoryOccupation(unittest.TestCase): + def test_release_and_resume_occupation(self): + prompt = "Today is a sunny day and I like" + sampling_params = {"temperature": 0, "max_new_tokens": 8} + model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + expect_output = " to spend it outdoors. 
I decided to" + + engine = sgl.Engine( + model_path=model_name, + random_seed=42, + enable_memory_saver=True, + # disable_cuda_graph=True, # for debugging only + ) + hf_model_new = AutoModelForCausalLM.from_pretrained( + model_name, torch_dtype="bfloat16" + ) + + print("generate (#1)") + outputs = engine.generate(prompt, sampling_params)["text"] + self.assertEqual(outputs, expect_output) + + if _DEBUG_EXTRA: + time.sleep(3) + + self.assertEqual( + _try_allocate_big_tensor(), + False, + "Should not be able to allocate big tensors before releasing", + ) + + print("release_memory_occupation start") + t = time.time() + engine.release_memory_occupation() + if _DEBUG_EXTRA: + print("release_memory_occupation", time.time() - t) + + if _DEBUG_EXTRA: + time.sleep(5) + + self.assertEqual( + _try_allocate_big_tensor(), + True, + "Should be able to allocate big tensors aftre releasing", + ) + + if _DEBUG_EXTRA: + time.sleep(5) + + print("resume_memory_occupation start") + t = time.time() + engine.resume_memory_occupation() + if _DEBUG_EXTRA: + print("resume_memory_occupation", time.time() - t) + + self.assertEqual( + _try_allocate_big_tensor(), + False, + "Should not be able to allocate big tensors after resuming", + ) + + print("update_weights_from_tensor") + # As if: PPO has updated hf model's weights, and now we sync it to SGLang + engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) + + print("generate (#2)") + outputs = engine.generate(prompt, sampling_params)["text"] + self.assertEqual(outputs, expect_output) + + if _DEBUG_EXTRA: + time.sleep(4) + + engine.shutdown() + + +def _try_allocate_big_tensor(size: int = 20_000_000_000): + try: + torch.empty((size,), dtype=torch.uint8, device="cuda") + torch.cuda.empty_cache() + return True + except torch.cuda.OutOfMemoryError: + return False + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_request_length_validation.py b/test/srt/test_request_length_validation.py new file mode 100644 index 000000000000..713e3e21e56b --- /dev/null +++ b/test/srt/test_request_length_validation.py @@ -0,0 +1,71 @@ +import unittest + +import openai + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestRequestLengthValidation(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + + # Start server with auto truncate disabled + cls.process = popen_launch_server( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + api_key=cls.api_key, + other_args=("--max-total-tokens", "1000", "--context-length", "100"), + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_input_length_validation(self): + client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1") + + long_text = "hello " * 100 # Will tokenize to more than context length + + with self.assertRaises(openai.BadRequestError) as cm: + client.chat.completions.create( + model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + messages=[ + {"role": "user", "content": long_text}, + ], + temperature=0, + ) + + self.assertIn("is longer than the model's context length", str(cm.exception)) + + def test_max_tokens_validation(self): + client = openai.Client(api_key=self.api_key, base_url=f"{self.base_url}/v1") + + long_text = "hello " + + with 
self.assertRaises(openai.BadRequestError) as cm: + client.chat.completions.create( + model=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + messages=[ + {"role": "user", "content": long_text}, + ], + temperature=0, + max_tokens=500, + ) + + self.assertIn( + "Requested token count exceeds the model's maximum context", + str(cm.exception), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py index bc99b23ad581..db70944091f2 100644 --- a/test/srt/test_skip_tokenizer_init.py +++ b/test/srt/test_skip_tokenizer_init.py @@ -1,11 +1,8 @@ -""" -python3 -m unittest test_skip_tokenizer_init.TestSkipTokenizerInit.test_parallel_sample -""" - import json import unittest import requests +from transformers import AutoTokenizer from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( @@ -15,35 +12,63 @@ popen_launch_server, ) +_server_process = None +_base_url = None +_tokenizer = None + + +def setUpModule(): + """ + Launch the server once before all tests and initialize the tokenizer. + """ + global _server_process, _base_url, _tokenizer + _server_process = popen_launch_server( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_URL_FOR_TEST, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--skip-tokenizer-init"], + ) + _base_url = DEFAULT_URL_FOR_TEST + + _tokenizer = AutoTokenizer.from_pretrained( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, use_fast=False + ) + print(">>> setUpModule: Server launched, tokenizer ready") + + +def tearDownModule(): + """ + Terminate the server once after all tests have completed. + """ + global _server_process + if _server_process is not None: + kill_process_tree(_server_process.pid) + _server_process = None + print(">>> tearDownModule: Server terminated") -class TestSkipTokenizerInit(unittest.TestCase): - @classmethod - def setUpClass(cls): - cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - cls.base_url = DEFAULT_URL_FOR_TEST - cls.process = popen_launch_server( - cls.model, - cls.base_url, - timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--skip-tokenizer-init"], - ) - @classmethod - def tearDownClass(cls): - kill_process_tree(cls.process.pid) +class TestSkipTokenizerInit(unittest.TestCase): + def run_decode( + self, + prompt_text="The capital of France is", + max_new_tokens=32, + return_logprob=False, + top_logprobs_num=0, + n=1, + ): + input_ids = _tokenizer(prompt_text, return_tensors="pt")["input_ids"][ + 0 + ].tolist() - def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): - max_new_tokens = 32 - input_ids = [128000, 791, 6864, 315, 9822, 374] # The capital of France is response = requests.post( - self.base_url + "/generate", + _base_url + "/generate", json={ "input_ids": input_ids, "sampling_params": { "temperature": 0 if n == 1 else 0.5, "max_new_tokens": max_new_tokens, "n": n, - "stop_token_ids": [119690], + "stop_token_ids": [_tokenizer.eos_token_id], }, "stream": False, "return_logprob": return_logprob, @@ -52,23 +77,37 @@ def run_decode(self, return_logprob=False, top_logprobs_num=0, n=1): }, ) ret = response.json() - print(json.dumps(ret)) + print(json.dumps(ret, indent=2)) def assert_one_item(item): - assert len(item["token_ids"]) == item["meta_info"]["completion_tokens"] - assert len(item["token_ids"]) == max_new_tokens - assert item["meta_info"]["prompt_tokens"] == len(input_ids) - - if return_logprob: - assert len(item["meta_info"]["input_token_logprobs"]) == len( - input_ids - ), 
f'{len(item["meta_info"]["input_token_logprobs"])} vs. f{len(input_ids)}' - assert len(item["meta_info"]["output_token_logprobs"]) == max_new_tokens - + if item["meta_info"]["finish_reason"]["type"] == "stop": + self.assertEqual( + item["meta_info"]["finish_reason"]["matched"], + _tokenizer.eos_token_id, + ) + elif item["meta_info"]["finish_reason"]["type"] == "length": + self.assertEqual( + len(item["token_ids"]), item["meta_info"]["completion_tokens"] + ) + self.assertEqual(len(item["token_ids"]), max_new_tokens) + self.assertEqual(item["meta_info"]["prompt_tokens"], len(input_ids)) + + if return_logprob: + self.assertEqual( + len(item["meta_info"]["input_token_logprobs"]), + len(input_ids), + f'{len(item["meta_info"]["input_token_logprobs"])} mismatch with {len(input_ids)}', + ) + self.assertEqual( + len(item["meta_info"]["output_token_logprobs"]), + max_new_tokens, + ) + + # Determine whether to assert a single item or multiple items based on n if n == 1: assert_one_item(ret) else: - assert len(ret) == n + self.assertEqual(len(ret), n) for i in range(n): assert_one_item(ret[i]) @@ -82,10 +121,10 @@ def test_parallel_sample(self): def test_logprob(self): for top_logprobs_num in [0, 3]: - self.run_decode( - return_logprob=True, - top_logprobs_num=top_logprobs_num, - ) + self.run_decode(return_logprob=True, top_logprobs_num=top_logprobs_num) + + def test_eos_behavior(self): + self.run_decode(max_new_tokens=256) if __name__ == "__main__": diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py index 0fd71efcb0b2..7c57c13e251b 100644 --- a/test/srt/test_srt_endpoint.py +++ b/test/srt/test_srt_endpoint.py @@ -4,11 +4,15 @@ """ import json +import random import unittest +from concurrent.futures import ThreadPoolExecutor +from typing import Optional import numpy as np import requests +from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, @@ -24,7 +28,10 @@ def setUpClass(cls): cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.base_url = DEFAULT_URL_FOR_TEST cls.process = popen_launch_server( - cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=("--enable-custom-logit-processor",), ) @classmethod @@ -248,6 +255,81 @@ def test_logprob_grammar(self): self.assertTrue(all(x is not None for x in logprobs)) + def run_custom_logit_processor(self, target_token_id: Optional[int] = None): + """Test custom logit processor with custom params. + + If target_token_id is None, the custom logit processor won't be passed in. + """ + + custom_params = {"token_id": target_token_id} + + class DeterministicLogitProcessor(CustomLogitProcessor): + """A dummy logit processor that changes the logits to always + sample the given token id. + """ + + def __call__(self, logits, custom_param_list): + assert logits.shape[0] == len(custom_param_list) + key = "token_id" + + for i, param_dict in enumerate(custom_param_list): + # Mask all other tokens + logits[i, :] = -float("inf") + # Assign highest probability to the specified token + logits[i, param_dict[key]] = 0.0 + return logits + + prompts = "Question: Is Paris the Capital of France? Answer:" + + # Base case json data to be posted to the server. + base_json = { + "text": prompts, + "sampling_params": {"temperature": 0.0}, + "return_logprob": True, + } + + # Custom json data with custom logit processor and params. 
+ custom_json = base_json.copy() + # Only set the custom logit processor if target_token_id is not None. + if target_token_id is not None: + custom_json["custom_logit_processor"] = ( + DeterministicLogitProcessor().to_str() + ) + custom_json["sampling_params"]["custom_params"] = custom_params + + custom_response = requests.post( + self.base_url + "/generate", + json=custom_json, + ).json() + + output_token_logprobs = custom_response["meta_info"]["output_token_logprobs"] + sampled_tokens = [x[1] for x in output_token_logprobs] + + # The logit processor should always sample the given token as the logits is deterministic. + if target_token_id is not None: + self.assertTrue( + all(x == custom_params["token_id"] for x in sampled_tokens), + # Print the detailed test case info if the test fails. + f"{target_token_id=}\n{sampled_tokens=}\n{custom_response=}", + ) + + def test_custom_logit_processor(self): + """Test custom logit processor with a single request.""" + self.run_custom_logit_processor(target_token_id=5) + + def test_custom_logit_processor_batch(self): + """Test custom logit processor with a batch of requests.""" + target_token_ids = list(range(32)) + with ThreadPoolExecutor(len(target_token_ids)) as executor: + list(executor.map(self.run_custom_logit_processor, target_token_ids)) + + def test_custom_logit_processor_batch_mixed(self): + """Test a batch of requests mixed of requests with and without custom logit processor.""" + target_token_ids = list(range(32)) + [None] * 16 + random.shuffle(target_token_ids) + with ThreadPoolExecutor(len(target_token_ids)) as executor: + list(executor.map(self.run_custom_logit_processor, target_token_ids)) + def test_get_server_info(self): response = requests.get(self.base_url + "/get_server_info") response_json = response.json() diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py index 7479b6468376..c535d5c06867 100644 --- a/test/srt/test_srt_engine.py +++ b/test/srt/test_srt_engine.py @@ -1,6 +1,6 @@ """ Usage: -python3 -m unittest test_srt_engine.TestSRTEngine.test_3_sync_streaming_combination +python3 -m unittest test_srt_engine.TestSRTEngine.test_4_sync_async_stream_combination """ import asyncio @@ -44,64 +44,97 @@ def test_1_engine_runtime_consistency(self): print(out2) self.assertEqual(out1, out2) - def test_2_engine_multiple_generate(self): + def test_2_engine_runtime_encode_consistency(self): + prompt = "Today is a sunny day and I like" + model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST + + engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42) + out1 = torch.tensor(engine.encode(prompt)["embedding"]) + engine.shutdown() + + runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42) + out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"]) + runtime.shutdown() + + self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3)) + + def test_3_engine_token_ids_consistency(self): # just to ensure there is no issue running multiple generate calls prompt = "Today is a sunny day and I like" model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - sampling_params = {"temperature": 0, "max_new_tokens": 8} - engine = sgl.Engine(model_path=model_path, random_seed=42) - engine.generate(prompt, sampling_params) - engine.generate(prompt, sampling_params) - engine.shutdown() + engine = sgl.Engine( + model_path=model_path, random_seed=42, disable_radix_cache=True + ) + out1 = engine.generate(prompt, sampling_params)["text"] - def test_3_sync_streaming_combination(self): + 
tokenizer = get_tokenizer(model_path) + token_ids = tokenizer.encode(prompt) + out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[ + "text" + ] - prompt = "AI safety is..." - sampling_params = {"temperature": 0.8, "top_p": 0.95} + engine.shutdown() - async def async_streaming(engine): + print("==== Answer 1 ====") + print(out1) - generator = await engine.async_generate( - prompt, sampling_params, stream=True - ) + print("==== Answer 2 ====") + print(out2) + self.assertEqual(out1, out2) - async for output in generator: - print(output["text"], end="", flush=True) - print() + def test_4_sync_async_stream_combination(self): + prompt = "AI safety is" + sampling_params = {"temperature": 0.8, "top_p": 0.95} # Create an LLM. llm = sgl.Engine( model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ) - # 1. sync + non streaming - print("\n\n==== 1. sync + non streaming ====") - output = llm.generate(prompt, sampling_params) + if True: + # 1. sync + non streaming + print("\n\n==== 1. sync + non streaming ====") + output = llm.generate(prompt, sampling_params) + print(output["text"]) + + # 2. sync + streaming + print("\n\n==== 2. sync + streaming ====") + output_generator = llm.generate(prompt, sampling_params, stream=True) + offset = 0 + for output in output_generator: + print(output["text"][offset:], end="", flush=True) + offset = len(output["text"]) + print() - print(output["text"]) + if True: + loop = asyncio.get_event_loop() + # 3. async + non_streaming + print("\n\n==== 3. async + non streaming ====") + output = loop.run_until_complete( + llm.async_generate(prompt, sampling_params) + ) + print(output["text"]) - # 2. sync + streaming - print("\n\n==== 2. sync + streaming ====") - output_generator = llm.generate(prompt, sampling_params, stream=True) - for output in output_generator: - print(output["text"], end="", flush=True) - print() + # 4. async + streaming + async def async_streaming(engine): + generator = await engine.async_generate( + prompt, sampling_params, stream=True + ) - loop = asyncio.get_event_loop() - # 3. async + non_streaming - print("\n\n==== 3. async + non streaming ====") - output = loop.run_until_complete(llm.async_generate(prompt, sampling_params)) - print(output["text"]) + offset = 0 + async for output in generator: + print(output["text"][offset:], end="", flush=True) + offset = len(output["text"]) + print() - # 4. async + streaming - print("\n\n==== 4. async + streaming ====") - loop.run_until_complete(async_streaming(llm)) + print("\n\n==== 4. 
async + streaming ====") + loop.run_until_complete(async_streaming(llm)) llm.shutdown() - def test_4_gsm8k(self): + def test_5_gsm8k(self): args = SimpleNamespace( model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, @@ -113,46 +146,7 @@ def test_4_gsm8k(self): metrics = run_eval(args) self.assertGreater(metrics["accuracy"], 0.3) - def test_5_prompt_input_ids_consistency(self): - prompt = "The capital of UK is" - - model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - engine = sgl.Engine( - model_path=model_path, random_seed=42, disable_radix_cache=True - ) - sampling_params = {"temperature": 0, "max_new_tokens": 8} - out1 = engine.generate(prompt, sampling_params)["text"] - - tokenizer = get_tokenizer(model_path) - token_ids = tokenizer.encode(prompt) - out2 = engine.generate(input_ids=token_ids, sampling_params=sampling_params)[ - "text" - ] - - engine.shutdown() - - print("==== Answer 1 ====") - print(out1) - - print("==== Answer 2 ====") - print(out2) - self.assertEqual(out1, out2) - - def test_6_engine_runtime_encode_consistency(self): - prompt = "Today is a sunny day and I like" - model_path = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST - - engine = sgl.Engine(model_path=model_path, is_embedding=True, random_seed=42) - out1 = torch.tensor(engine.encode(prompt)["embedding"]) - engine.shutdown() - - runtime = sgl.Runtime(model_path=model_path, is_embedding=True, random_seed=42) - out2 = torch.tensor(json.loads(runtime.encode(prompt))["embedding"]) - runtime.shutdown() - - self.assertTrue(torch.allclose(out1, out2, atol=1e-5, rtol=1e-3)) - - def test_7_engine_cpu_offload(self): + def test_6_engine_cpu_offload(self): prompt = "Today is a sunny day and I like" model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST @@ -182,7 +176,7 @@ def test_7_engine_cpu_offload(self): print(out2) self.assertEqual(out1, out2) - def test_8_engine_offline_throughput(self): + def test_7_engine_offline_throughput(self): server_args = ServerArgs( model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST, ) diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index 6f3b344b3cce..e71de3391177 100644 --- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -23,7 +23,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-torch-compile"], + other_args=["--enable-torch-compile", "--cuda-graph-max-bs", "4"], ) @classmethod diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index e19e6b01d513..5be911ab84a4 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -171,7 +171,7 @@ def test_multi_images_chat_completion(self): text = response.choices[0].message.content assert isinstance(text, str) print(text) - assert "man" in text or "cab" in text, text + assert "man" in text or "cab" in text or "SUV" in text or "taxi" in text, text assert "logo" in text or '"S"' in text or "SG" in text, text assert response.id assert response.created @@ -392,34 +392,33 @@ def tearDownClass(cls): def test_chat_completion(self): client = openai.Client(api_key=self.api_key, base_url=self.base_url) - response = client.chat.completions.create( - model="default", - messages=[ - { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": { - "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + with self.assertRaises(openai.BadRequestError) as cm: + client.chat.completions.create( + model="default", + messages=[ + { + "role": 
"user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + }, }, - }, - { - "type": "text", - "text": "Give a lengthy description of this picture", - }, - ], - }, - ], - temperature=0, - ) + { + "type": "text", + "text": "Give a lengthy description of this picture", + }, + ], + }, + ], + temperature=0, + ) - assert response.choices[0].finish_reason == "abort" - assert response.id - assert response.created - assert response.usage.prompt_tokens > 0 - assert response.usage.completion_tokens > 0 - assert response.usage.total_tokens > 0 + self.assertIn( + "Multimodal prompt is too long after expanding multimodal tokens.", + str(cm.exception), + ) class TestMllamaServer(TestOpenAIVisionServer): @@ -444,5 +443,24 @@ def test_video_chat_completion(self): pass +class TestMinicpmvServer(TestOpenAIVisionServer): + @classmethod + def setUpClass(cls): + cls.model = "openbmb/MiniCPM-V-2_6" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.api_key = "sk-123456" + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--trust-remote-code", + "--chat-template", + "minicpmv", + ], + ) + cls.base_url += "/v1" + + if __name__ == "__main__": unittest.main() diff --git a/test/srt/test_w8a8_quantization.py b/test/srt/test_w8a8_quantization.py new file mode 100644 index 000000000000..78579d5e2dea --- /dev/null +++ b/test/srt/test_w8a8_quantization.py @@ -0,0 +1,74 @@ +import time +import unittest +from types import SimpleNamespace + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval +from sglang.test.test_utils import ( + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestW8A8(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = "neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8" + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=["--quantization", "w8a8_int8"], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.7) + + def run_decode(self, max_new_tokens): + response = requests.post( + self.base_url + "/generate", + json={ + "text": "The capital of France is", + "sampling_params": { + "temperature": 0, + "max_new_tokens": max_new_tokens, + }, + "ignore_eos": True, + }, + ) + return response.json() + + def test_throughput(self): + max_tokens = 256 + + tic = time.time() + res = self.run_decode(max_tokens) + tok = time.time() + print(res["text"]) + throughput = max_tokens / (tok - tic) + print(f"Throughput: {throughput} tokens/s") + assert throughput >= 140 + + +if __name__ == "__main__": + unittest.main()