diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 3515ccd65667..5e79984c9f7b 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -7,7 +7,7 @@ set -o pipefail echo "--- Confirming Clean Initial State" while true; do sleep 3 - if grep -q clean /opt/amdgpu/etc/gpu_state; then + if grep -q clean ${BUILDKITE_AGENT_META_DATA_RESET_TARGET}; then echo "GPUs state is \"clean\"" break fi @@ -46,11 +46,11 @@ cleanup_docker echo "--- Resetting GPUs" -echo "reset" > /opt/amdgpu/etc/gpu_state +echo "reset" > ${BUILDKITE_AGENT_META_DATA_RESET_TARGET} while true; do sleep 3 - if grep -q clean /opt/amdgpu/etc/gpu_state; then + if grep -q clean ${BUILDKITE_AGENT_META_DATA_RESET_TARGET}; then echo "GPUs state is \"clean\"" break fi @@ -141,8 +141,9 @@ if [[ $commands == *"--shard-id="* ]]; then fi done else + echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES" docker run \ - --device /dev/kfd --device /dev/dri \ + --device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \ --network host \ --shm-size=16gb \ --rm \ diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index a847a68a6ef7..a3cb93be8d47 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -92,7 +92,9 @@ steps: - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test # 10min + working_dir: "/vllm-workspace/tests" mirror_hardwares: [amd] + amd_gpus: 4 # Just for the sake of queue testing fast_check: true source_file_dependencies: - vllm/core @@ -105,6 +107,7 @@ steps: working_dir: "/vllm-workspace/tests" fast_check: true mirror_hardwares: [amd] + amd_gpus: 1 # Just for the sake of queue testing source_file_dependencies: - vllm/ commands: @@ -158,6 +161,7 @@ steps: - label: Regression Test # 5min mirror_hardwares: [amd] + amd_gpus: 1 source_file_dependencies: - vllm/ - tests/test_regression @@ -168,6 +172,7 @@ steps: - label: Engine Test # 10min mirror_hardwares: [amd] + amd_gpus: 1 source_file_dependencies: - vllm/ - tests/engine @@ -176,6 +181,7 @@ steps: - pytest -v -s engine test_sequence.py test_config.py test_logger.py # OOM in the CI unless we run this separately - pytest -v -s tokenization + working_dir: "/vllm-workspace/tests" # optional - label: V1 Test #mirror_hardwares: [amd] @@ -217,7 +223,9 @@ steps: - python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min + working_dir: "/vllm-workspace/tests" mirror_hardwares: [amd] + amd_gpus: 1 source_file_dependencies: - vllm/ - tests/prefix_caching @@ -235,7 +243,9 @@ steps: - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers - label: LogitsProcessor Test # 5min + working_dir: "/vllm-workspace/tests" mirror_hardwares: [amd] + amd_gpus: 1 source_file_dependencies: - vllm/model_executor/layers - vllm/model_executor/guided_decoding @@ -256,7 +266,9 @@ steps: - pytest -v -s spec_decode/e2e/test_eagle_correctness.py - label: LoRA Test %N # 15min each + working_dir: "/vllm-workspace/tests" mirror_hardwares: [amd] + amd_gpus: 8 source_file_dependencies: - vllm/lora - tests/lora @@ -282,7 +294,9 @@ steps: - pytest -v -s compile/test_full_graph.py - label: Kernels Test %N # 1h each + working_dir: "/vllm-workspace/tests" mirror_hardwares: [amd] + amd_gpus: 8 source_file_dependencies: - csrc/ - vllm/attention @@ -292,8 +306,10 @@ steps: parallelism: 4 - label: Tensorizer Test # 11min + working_dir: "/vllm-workspace/tests" mirror_hardwares: [amd] soft_fail: true + amd_gpus: 1 source_file_dependencies: - vllm/model_executor/model_loader - tests/tensorizer_loader @@ -305,6 +321,7 @@ steps: - label: Benchmarks # 9min working_dir: "/vllm-workspace/.buildkite" mirror_hardwares: [amd] + amd_gpus: 1 source_file_dependencies: - benchmarks/ commands: @@ -334,8 +351,10 @@ steps: - pytest -v -s encoder_decoder - label: OpenAI-Compatible Tool Use # 20 min + working_dir: "/vllm-workspace/tests" fast_check: false mirror_hardwares: [ amd ] + amd_gpus: 1 source_file_dependencies: - vllm/ - tests/tool_use diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2 new file mode 100644 index 000000000000..7106395910d3 --- /dev/null +++ b/.buildkite/test-template.j2 @@ -0,0 +1,46 @@ +{% set docker_image = "public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT" %} +{% set docker_image_amd = "rocm/vllm-ci:$BUILDKITE_COMMIT" %} +{% set default_working_dir = "vllm/tests" %} +{% set hf_home = "/root/.cache/huggingface" %} + +steps: + - label: ":docker: build image" + depends_on: ~ + commands: + - "docker build --build-arg max_jobs=16 --tag {{ docker_image_amd }} -f Dockerfile.rocm --target test --progress plain ." + - "docker push {{ docker_image_amd }}" + key: "amd-build" + env: + DOCKER_BUILDKIT: "1" + retry: + automatic: + - exit_status: -1 # Agent was lost + limit: 5 + - exit_status: -10 # Agent was lost + limit: 5 + agents: + queue: amd-cpu + +{% for step in steps %} +{% if step.mirror_hardwares and "amd" in step.mirror_hardwares %} + - label: "AMD: {{ step.label }}" + depends_on: + - "amd-build" + agents: +{% if step.amd_gpus and step.amd_gpus==8%} + queue: amd_gpu_8 +{% elif step.amd_gpus and step.amd_gpus==4%} + queue: amd_gpu_4 +{% elif step.amd_gpus and step.amd_gpus==2%} + queue: amd_gpu_4 +{% else%} + queue: amd_gpu_1 +{% endif%} + commands: + - bash .buildkite/run-amd-test.sh "cd {{ (step.working_dir or default_working_dir) | safe }} ; {{ step.command or (step.commands | join(" && ")) | safe }}" + env: + DOCKER_BUILDKIT: "1" + priority: 100 + soft_fail: true +{% endif %} +{% endfor %} diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml deleted file mode 100644 index 556b60d2fca1..000000000000 --- a/.github/workflows/lint-and-deploy.yaml +++ /dev/null @@ -1,82 +0,0 @@ -name: Lint and Deploy Charts - -on: pull_request - -jobs: - lint-and-deploy: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - fetch-depth: 0 - - - name: Set up Helm - uses: azure/setup-helm@fe7b79cd5ee1e45176fcad797de68ecaf3ca4814 # v4.2.0 - with: - version: v3.14.4 - - #Python is required because ct lint runs Yamale and yamllint which require Python. - - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: '3.13' - - - name: Set up chart-testing - uses: helm/chart-testing-action@e6669bcd63d7cb57cb4380c33043eebe5d111992 # v2.6.1 - with: - version: v3.10.1 - - - name: Run chart-testing (lint) - run: ct lint --target-branch ${{ github.event.repository.default_branch }} --chart-dirs examples/online_serving/chart-helm --charts examples/online_serving/chart-helm - - - name: Setup minio - run: | - docker network create vllm-net - docker run -d -p 9000:9000 --name minio --net vllm-net \ - -e "MINIO_ACCESS_KEY=minioadmin" \ - -e "MINIO_SECRET_KEY=minioadmin" \ - -v /tmp/data:/data \ - -v /tmp/config:/root/.minio \ - minio/minio server /data - export AWS_ACCESS_KEY_ID=minioadmin - export AWS_SECRET_ACCESS_KEY=minioadmin - export AWS_EC2_METADATA_DISABLED=true - mkdir opt-125m - cd opt-125m && curl -O -Ls "https://huggingface.co/facebook/opt-125m/resolve/main/{pytorch_model.bin,config.json,generation_config.json,merges.txt,special_tokens_map.json,tokenizer_config.json,vocab.json}" && cd .. - aws --endpoint-url http://127.0.0.1:9000/ s3 mb s3://testbucket - aws --endpoint-url http://127.0.0.1:9000/ s3 cp opt-125m/ s3://testbucket/opt-125m --recursive - - - name: Create kind cluster - uses: helm/kind-action@0025e74a8c7512023d06dc019c617aa3cf561fde # v1.10.0 - - - name: Build the Docker image vllm cpu - run: docker buildx build -f Dockerfile.cpu -t vllm-cpu-env . - - - name: Configuration of docker images, network and namespace for the kind cluster - run: | - docker pull amazon/aws-cli:2.6.4 - kind load docker-image amazon/aws-cli:2.6.4 --name chart-testing - kind load docker-image vllm-cpu-env:latest --name chart-testing - docker network connect vllm-net "$(docker ps -aqf "name=chart-testing-control-plane")" - kubectl create ns ns-vllm - - - name: Run chart-testing (install) - run: | - export AWS_ACCESS_KEY_ID=minioadmin - export AWS_SECRET_ACCESS_KEY=minioadmin - sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & - helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" - - - name: curl test - run: | - kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 & - sleep 10 - CODE="$(curl -v -f --location http://localhost:8001/v1/completions \ - --header "Content-Type: application/json" \ - --data '{ - "model": "opt-125m", - "prompt": "San Francisco is a", - "max_tokens": 7, - "temperature": 0 - }'):$CODE" - echo "$CODE" \ No newline at end of file diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index e40ceaaa8b03..f3dda4c25c79 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -16,7 +16,9 @@ jobs: release: # Retrieve tag and create release name: Create Release - runs-on: ubuntu-latest + runs-on: self-hosted + container: + image: rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0 outputs: upload_url: ${{ steps.create_release.outputs.upload_url }} steps: @@ -39,73 +41,42 @@ jobs: const script = require('.github/workflows/scripts/create_release.js') await script(github, context, core) - # NOTE(simon): No longer build wheel using Github Actions. See buildkite's release workflow. - # wheel: - # name: Build Wheel - # runs-on: ${{ matrix.os }} - # needs: release + wheel: + name: Build Wheel + runs-on: self-hosted + container: + image: rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0 + needs: release - # strategy: - # fail-fast: false - # matrix: - # os: ['ubuntu-20.04'] - # python-version: ['3.9', '3.10', '3.11', '3.12'] - # pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt. - # cuda-version: ['11.8', '12.1'] + strategy: + fail-fast: false - # steps: - # - name: Checkout - # uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - # - name: Setup ccache - # uses: hendrikmuhs/ccache-action@ed74d11c0b343532753ecead8a951bb09bb34bc9 # v1.2.14 - # with: - # create-symlink: true - # key: ${{ github.job }}-${{ matrix.python-version }}-${{ matrix.cuda-version }} - - # - name: Set up Linux Env - # if: ${{ runner.os == 'Linux' }} - # run: | - # bash -x .github/workflows/scripts/env.sh - - # - name: Set up Python - # uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - # with: - # python-version: ${{ matrix.python-version }} - - # - name: Install CUDA ${{ matrix.cuda-version }} - # run: | - # bash -x .github/workflows/scripts/cuda-install.sh ${{ matrix.cuda-version }} ${{ matrix.os }} - - # - name: Install PyTorch ${{ matrix.pytorch-version }} with CUDA ${{ matrix.cuda-version }} - # run: | - # bash -x .github/workflows/scripts/pytorch-install.sh ${{ matrix.python-version }} ${{ matrix.pytorch-version }} ${{ matrix.cuda-version }} - - # - name: Build wheel - # shell: bash - # env: - # CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size - # run: | - # bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} - # wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) - # asset_name=${wheel_name//"linux"/"manylinux1"} - # echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" - # echo "asset_name=${asset_name}" >> "$GITHUB_ENV" + steps: + - name: Prepare + run: | + pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2 + pip3 install -U triton - # - name: Upload Release Asset - # uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 - # env: - # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # with: - # upload_url: ${{ needs.release.outputs.upload_url }} - # asset_path: ./dist/${{ env.wheel_name }} - # asset_name: ${{ env.asset_name }} - # asset_content_type: application/* + - name: Checkout + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - # (Danielkinz): This last step will publish the .whl to pypi. Warning: untested - # - name: Publish package - # uses: pypa/gh-action-pypi-publish@release/v1.8 - # with: - # repository-url: https://test.pypi.org/legacy/ - # password: ${{ secrets.PYPI_API_TOKEN }} - # skip-existing: true + - name: Build wheel + shell: bash + env: + CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size + run: | + bash -x .github/workflows/scripts/build.sh + wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) + asset_name=${wheel_name//"linux"/"manylinux1"} + echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" + echo "asset_name=${asset_name}" >> "$GITHUB_ENV" + + - name: Upload vllm Release Asset + uses: actions/upload-release-asset@e8f9f06c4b078e705bd2ea027f0926603fc9b4d5 # v1.0.2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + upload_url: ${{ needs.release.outputs.upload_url }} + asset_path: ./dist/${{ env.wheel_name }} + asset_name: ${{ env.asset_name }} + asset_content_type: application/* diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml deleted file mode 100644 index df62539c0b3d..000000000000 --- a/.github/workflows/reminder_comment.yml +++ /dev/null @@ -1,21 +0,0 @@ -name: PR Reminder Comment Bot -on: - pull_request_target: - types: [opened] - -jobs: - pr_reminder: - runs-on: ubuntu-latest - steps: - - name: Remind to run full CI on PR - uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 - with: - script: | - github.rest.issues.createComment({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - body: 'šŸ‘‹ Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\nšŸš€' - }) - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh index 122e4e101e20..f0a4e4baf1ae 100644 --- a/.github/workflows/scripts/build.sh +++ b/.github/workflows/scripts/build.sh @@ -1,23 +1,20 @@ #!/bin/bash set -eux -python_executable=python$1 -cuda_home=/usr/local/cuda-$2 +python_executable=python3 # Update paths -PATH=${cuda_home}/bin:$PATH -LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH - # Install requirements -$python_executable -m pip install -r requirements-build.txt -r requirements-cuda.txt +$python_executable -m pip install -r requirements-rocm.txt # Limit the number of parallel jobs to avoid OOM export MAX_JOBS=1 # Make sure release wheels are built for the following architectures -export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" -export VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real" +export PYTORCH_ROCM_ARCH="gfx90a;gfx942" + +rm -f "$(which sccache)" -bash tools/check_repo.sh +export MAX_JOBS=32 # Build $python_executable setup.py bdist_wheel --dist-dir=dist diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4568efcbba21..14006fcab5d7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: rev: v2.4.0 hooks: - id: codespell - exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*' + exclude: 'benchmarks/sonnet.txt|(build|tests/(lora/data|models/fixtures|prompts))/.*|csrc/rocm/.*|csrc/gradlib/.*' - repo: https://github.com/PyCQA/isort rev: 5.13.2 hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index c823c9ff895c..d77d810f9aed 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,7 +34,7 @@ set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") # Supported AMD GPU architectures. -set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101") +set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1200;gfx1201") # # Supported/expected torch versions for CUDA/ROCm. @@ -149,6 +149,19 @@ else() "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") endif() +# +# Setting up debug flags for pleasant debug experience. +# +set(CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG "${CMAKE_${VLLM_GPU_LANG}_FLAGS_DEBUG} -O0 -ggdb3") +set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -ggdb3") + + +# +# Suppressing the noisy warning +# +set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result") + # # Query torch for additional GPU compilation flags for the given # `VLLM_GPU_LANG`. @@ -156,6 +169,20 @@ endif() # get_torch_gpu_compiler_flags(VLLM_GPU_FLAGS ${VLLM_GPU_LANG}) +# +# Get supported FP8 format based on GPU arches +# +get_supported_fp8_format(FP8_FORMAT ${VLLM_GPU_LANG} "${VLLM_GPU_ARCHES}") +if(${FP8_FORMAT} STREQUAL "E4M3FN") + message(STATUS "FP8 format: E4M3FN") + list(APPEND VLLM_GPU_FLAGS "-DUSE_CUDA_FP8_FORMAT") +elseif(${FP8_FORMAT} STREQUAL "E4M3FNUZ") + message(STATUS "FP8 format: E4M3FNUZ") + list(APPEND VLLM_GPU_FLAGS "-DUSE_HIP_FP8_FORMAT") +elseif(${FP8_FORMAT} STREQUAL "CONFLICT") + message(FATAL_ERROR "Target architectures support different types of FP8 formats!") +endif() + # # Set nvcc parallelism. # @@ -175,7 +202,14 @@ file(MAKE_DIRECTORY ${FETCHCONTENT_BASE_DIR}) # Ensure the directory exists message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") # -# Define other extension targets +# Set rocm version dev int. +# +if(VLLM_GPU_LANG STREQUAL "HIP") + list(APPEND VLLM_GPU_FLAGS "-DROCM_VERSION=${ROCM_VERSION_DEV_INT}") +endif() + +# +# Define extension targets # # @@ -450,6 +484,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if CUDA endif endif() +if(VLLM_GPU_LANG STREQUAL "HIP") + list(APPEND VLLM_EXT_SRC + "csrc/custom_all_reduce.cu") +endif() + message(STATUS "Enabling C extension.") define_gpu_extension_target( _C @@ -523,7 +562,10 @@ if(VLLM_GPU_LANG STREQUAL "HIP") # set(VLLM_ROCM_EXT_SRC "csrc/rocm/torch_bindings.cpp" - "csrc/rocm/attention.cu") + "csrc/rocm/attention.cu" + "csrc/rocm/custom_kernels.cu" + "csrc/rocm/fused_kernels.cu" + "csrc/rocm/custom.cu") define_gpu_extension_target( _rocm_C @@ -534,6 +576,24 @@ if(VLLM_GPU_LANG STREQUAL "HIP") ARCHITECTURES ${VLLM_GPU_ARCHES} USE_SABI 3 WITH_SOABI) + + # + # _gradlib_C extension + # + set(VLLM_GRADLIB_EXT_SRC + "csrc/gradlib/torch_bindings.cpp" + "csrc/gradlib/hipbsolgemm.cu" + "csrc/gradlib/rocsolgemm.cu") + + define_gpu_extension_target( + _gradlib_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_GRADLIB_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) endif() # vllm-flash-attn currently only supported on CUDA diff --git a/Dockerfile.base b/Dockerfile.base new file mode 100644 index 000000000000..e33e73b30309 --- /dev/null +++ b/Dockerfile.base @@ -0,0 +1,158 @@ +ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete +ARG HIPBLASLT_BRANCH="4d40e36" +ARG HIPBLAS_COMMON_BRANCH="7c1566b" +ARG LEGACY_HIPBLASLT_OPTION= +ARG RCCL_BRANCH="648a58d" +ARG RCCL_REPO="https://github.com/ROCm/rccl" +ARG TRITON_BRANCH="e5be006" +ARG TRITON_REPO="https://github.com/triton-lang/triton.git" +ARG PYTORCH_BRANCH="3a585126" +ARG PYTORCH_VISION_BRANCH="v0.19.1" +ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" +ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" +ARG FA_BRANCH="b7d29fb" +ARG FA_REPO="https://github.com/ROCm/flash-attention.git" + +FROM ${BASE_IMAGE} AS base + +ENV PATH=/opt/rocm/llvm/bin:$PATH +ENV ROCM_PATH=/opt/rocm +ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib: +ARG PYTORCH_ROCM_ARCH=gfx90a;gfx942 +ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} + +ARG PYTHON_VERSION=3.12 + +RUN mkdir -p /app +WORKDIR /app +ENV DEBIAN_FRONTEND=noninteractive + +# Install Python and other dependencies +RUN apt-get update -y \ + && apt-get install -y software-properties-common git curl sudo vim less \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update -y \ + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ + python${PYTHON_VERSION}-lib2to3 python-is-python3 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version && python3 -m pip --version + +RUN pip install -U packaging cmake ninja wheel setuptools pybind11 Cython + +FROM base AS build_hipblaslt +ARG HIPBLASLT_BRANCH +ARG HIPBLAS_COMMON_BRANCH +# Set to "--legacy_hipblas_direct" for ROCm<=6.2 +ARG LEGACY_HIPBLASLT_OPTION +RUN git clone https://github.com/ROCm/hipBLAS-common.git +RUN cd hipBLAS-common \ + && git checkout ${HIPBLAS_COMMON_BRANCH} \ + && mkdir build \ + && cd build \ + && cmake .. \ + && make package \ + && dpkg -i ./*.deb +RUN git clone https://github.com/ROCm/hipBLASLt +RUN cd hipBLASLt \ + && git checkout ${HIPBLASLT_BRANCH} \ + && ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \ + && cd build/release \ + && make package +RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install + +FROM base AS build_rccl +ARG RCCL_BRANCH +ARG RCCL_REPO +RUN git clone ${RCCL_REPO} +RUN cd rccl \ + && git checkout ${RCCL_BRANCH} \ + && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH} +RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install + +FROM base AS build_triton +ARG TRITON_BRANCH +ARG TRITON_REPO +RUN git clone ${TRITON_REPO} +RUN cd triton \ + && git checkout ${TRITON_BRANCH} \ + && cd python \ + && python3 setup.py bdist_wheel --dist-dir=dist +RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install + +FROM base AS build_amdsmi +RUN cd /opt/rocm/share/amd_smi \ + && pip wheel . --wheel-dir=dist +RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install + +FROM base AS build_pytorch +ARG PYTORCH_BRANCH +ARG PYTORCH_VISION_BRANCH +ARG PYTORCH_REPO +ARG PYTORCH_VISION_REPO +ARG FA_BRANCH +ARG FA_REPO +RUN git clone ${PYTORCH_REPO} pytorch +RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \ + pip install -r requirements.txt && git submodule update --init --recursive \ + && python3 tools/amd_build/build_amd.py \ + && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \ + && pip install dist/*.whl +RUN git clone ${PYTORCH_VISION_REPO} vision +RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \ + && python3 setup.py bdist_wheel --dist-dir=dist \ + && pip install dist/*.whl +RUN git clone ${FA_REPO} +RUN cd flash-attention \ + && git checkout ${FA_BRANCH} \ + && git submodule update --init \ + && MAX_JOBS=64 GPU_ARCHS=${PYTORCH_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist +RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ + && cp /app/vision/dist/*.whl /app/install \ + && cp /app/flash-attention/dist/*.whl /app/install + +FROM base AS final +RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \ + dpkg -i /install/*deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status +RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \ + dpkg -i /install/*deb \ + && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \ + && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status +RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \ + pip install /install/*.whl +RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ + pip install /install/*.whl +RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ + pip install /install/*.whl + +ARG BASE_IMAGE +ARG HIPBLASLT_BRANCH +ARG LEGACY_HIPBLASLT_OPTION +ARG RCCL_BRANCH +ARG RCCL_REPO +ARG TRITON_BRANCH +ARG TRITON_REPO +ARG PYTORCH_BRANCH +ARG PYTORCH_VISION_BRANCH +ARG PYTORCH_REPO +ARG PYTORCH_VISION_REPO +ARG FA_BRANCH +ARG FA_REPO +RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ + && echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \ + && echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \ + && echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \ + && echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \ + && echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \ + && echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \ + && echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \ + && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \ + && echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \ + && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ + && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ + && echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \ + && echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt diff --git a/Dockerfile.base_navi b/Dockerfile.base_navi new file mode 100644 index 000000000000..389933840cd3 --- /dev/null +++ b/Dockerfile.base_navi @@ -0,0 +1,143 @@ +ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:6.3.1-complete +ARG HIPBLASLT_BRANCH="4d40e36" +ARG HIPBLAS_COMMON_BRANCH="7c1566b" +ARG LEGACY_HIPBLASLT_OPTION= +ARG RCCL_BRANCH="648a58d" +ARG RCCL_REPO="https://github.com/ROCm/rccl" +ARG TRITON_BRANCH="e5be006" +ARG TRITON_REPO="https://github.com/triton-lang/triton.git" +ARG PYTORCH_BRANCH="8d4926e" +ARG PYTORCH_VISION_BRANCH="v0.19.1" +ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" +ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" + +FROM ${BASE_IMAGE} AS base + +ENV PATH=/opt/rocm/llvm/bin:$PATH +ENV ROCM_PATH=/opt/rocm +ENV LD_LIBRARY_PATH=/opt/rocm/lib:/usr/local/lib: +ARG PYTORCH_ROCM_ARCH=gfx1100;gfx1101;gfx1200;gfx1201 +ENV PYTORCH_ROCM_ARCH=${PYTORCH_ROCM_ARCH} + +ARG PYTHON_VERSION=3.12 + +RUN mkdir -p /app +WORKDIR /app +ENV DEBIAN_FRONTEND=noninteractive + +# Install Python and other dependencies +RUN apt-get update -y \ + && apt-get install -y software-properties-common git curl sudo vim less \ + && add-apt-repository ppa:deadsnakes/ppa \ + && apt-get update -y \ + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ + python${PYTHON_VERSION}-lib2to3 python-is-python3 \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version && python3 -m pip --version + +RUN pip install -U packaging cmake ninja wheel setuptools Cython pybind11 + +FROM base AS build_hipblaslt +ARG HIPBLASLT_BRANCH +# Set to "--legacy_hipblas_direct" for ROCm<=6.2 +ARG LEGACY_HIPBLASLT_OPTION +RUN git clone https://github.com/ROCm/hipBLAS-common.git +RUN cd hipBLAS-common \ + && git checkout ${HIPBLAS_COMMON_BRANCH} \ + && mkdir build \ + && cd build \ + && cmake .. \ + && make package \ + && dpkg -i ./*.deb +RUN git clone https://github.com/ROCm/hipBLASLt +RUN cd hipBLASLt \ + && git checkout ${HIPBLASLT_BRANCH} \ + && ./install.sh -d --architecture ${PYTORCH_ROCM_ARCH} ${LEGACY_HIPBLASLT_OPTION} \ + && cd build/release \ + && make package +RUN mkdir -p /app/install && cp /app/hipBLASLt/build/release/*.deb /app/hipBLAS-common/build/*.deb /app/install + +FROM base AS build_rccl +ARG RCCL_BRANCH +ARG RCCL_REPO +RUN git clone ${RCCL_REPO} +RUN cd rccl \ + && git checkout ${RCCL_BRANCH} \ + && ./install.sh -p --amdgpu_targets ${PYTORCH_ROCM_ARCH} +RUN mkdir -p /app/install && cp /app/rccl/build/release/*.deb /app/install + +FROM base AS build_triton +ARG TRITON_BRANCH +ARG TRITON_REPO +RUN git clone ${TRITON_REPO} +RUN cd triton \ + && git checkout ${TRITON_BRANCH} \ + && cd python \ + && python3 setup.py bdist_wheel --dist-dir=dist +RUN mkdir -p /app/install && cp /app/triton/python/dist/*.whl /app/install + +FROM base AS build_amdsmi +RUN cd /opt/rocm/share/amd_smi \ + && pip wheel . --wheel-dir=dist +RUN mkdir -p /app/install && cp /opt/rocm/share/amd_smi/dist/*.whl /app/install + +FROM base AS build_pytorch +ARG PYTORCH_BRANCH +ARG PYTORCH_VISION_BRANCH +ARG PYTORCH_REPO +ARG PYTORCH_VISION_REPO +RUN git clone ${PYTORCH_REPO} pytorch +RUN cd pytorch && git checkout ${PYTORCH_BRANCH} && \ + pip install -r requirements.txt && git submodule update --init --recursive \ + && python3 tools/amd_build/build_amd.py \ + && CMAKE_PREFIX_PATH=$(python3 -c 'import sys; print(sys.prefix)') python3 setup.py bdist_wheel --dist-dir=dist \ + && pip install dist/*.whl +RUN git clone ${PYTORCH_VISION_REPO} vision +RUN cd vision && git checkout ${PYTORCH_VISION_BRANCH} \ + && python3 setup.py bdist_wheel --dist-dir=dist \ + && pip install dist/*.whl +RUN mkdir -p /app/install && cp /app/pytorch/dist/*.whl /app/install \ + && cp /app/vision/dist/*.whl /app/install + +FROM base AS final +RUN --mount=type=bind,from=build_hipblaslt,src=/app/install/,target=/install \ + dpkg -i /install/*deb \ + && sed -i 's/, hipblaslt-dev \(.*\), hipcub-dev/, hipcub-dev/g' /var/lib/dpkg/status \ + && sed -i 's/, hipblaslt \(.*\), hipfft/, hipfft/g' /var/lib/dpkg/status +RUN --mount=type=bind,from=build_rccl,src=/app/install/,target=/install \ + dpkg -i /install/*deb \ + && sed -i 's/, rccl-dev \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status \ + && sed -i 's/, rccl \(.*\), rocalution/, rocalution/g' /var/lib/dpkg/status +RUN --mount=type=bind,from=build_triton,src=/app/install/,target=/install \ + pip install /install/*.whl +RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \ + pip install /install/*.whl +RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \ + pip install /install/*.whl + +ARG BASE_IMAGE +ARG HIPBLASLT_BRANCH +ARG LEGACY_HIPBLASLT_OPTION +ARG RCCL_BRANCH +ARG RCCL_REPO +ARG TRITON_BRANCH +ARG TRITON_REPO +ARG PYTORCH_BRANCH +ARG PYTORCH_VISION_BRANCH +ARG PYTORCH_REPO +ARG PYTORCH_VISION_REPO +RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \ + && echo "HIPBLAS_COMMON_BRANCH: ${HIPBLAS_COMMON_BRANCH}" >> /app/versions.txt \ + && echo "HIPBLASLT_BRANCH: ${HIPBLASLT_BRANCH}" >> /app/versions.txt \ + && echo "LEGACY_HIPBLASLT_OPTION: ${LEGACY_HIPBLASLT_OPTION}" >> /app/versions.txt \ + && echo "RCCL_BRANCH: ${RCCL_BRANCH}" >> /app/versions.txt \ + && echo "RCCL_REPO: ${RCCL_REPO}" >> /app/versions.txt \ + && echo "TRITON_BRANCH: ${TRITON_BRANCH}" >> /app/versions.txt \ + && echo "TRITON_REPO: ${TRITON_REPO}" >> /app/versions.txt \ + && echo "PYTORCH_BRANCH: ${PYTORCH_BRANCH}" >> /app/versions.txt \ + && echo "PYTORCH_VISION_BRANCH: ${PYTORCH_VISION_BRANCH}" >> /app/versions.txt \ + && echo "PYTORCH_REPO: ${PYTORCH_REPO}" >> /app/versions.txt \ + && echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \ No newline at end of file diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 14c522afd7f9..8c86c618103e 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -25,7 +25,7 @@ WORKDIR ${COMMON_WORKDIR} FROM base AS fetch_vllm_0 ONBUILD COPY ./ vllm/ FROM base AS fetch_vllm_1 -ARG VLLM_REPO="https://github.com/vllm-project/vllm.git" +ARG VLLM_REPO="https://github.com/ROCm/vllm.git" ARG VLLM_BRANCH="main" ONBUILD RUN git clone ${VLLM_REPO} \ && cd vllm \ @@ -108,6 +108,7 @@ ARG COMMON_WORKDIR # Copy over the benchmark scripts as well COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples +# "Dummy alternation" ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 ENV TOKENIZERS_PARALLELISM=false diff --git a/ROCm_performance.md b/ROCm_performance.md new file mode 100644 index 000000000000..df8b586dc35f --- /dev/null +++ b/ROCm_performance.md @@ -0,0 +1,20 @@ +# Overview of the optional performance features uinque to https://github.com/ROCm/vllm + +## Triton attention +The default attention function on ROCm is using triton attention kernel. To fallback to the https://github.com/ROCm/flash-attention implementation set up the following environment symbol: +`VLLM_USE_TRITON_FLASH_ATTN=0` + +## Tunable ops +Pytorch tunable ops are supported. +Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order to enable both the runtime tuning and the subsequent use of tuned results. To only use the tuned results without tuning any newly encountered shapes, set `PYTORCH_TUNABLEOP_TUNING=0` + +## Custom PagedAttention + +On ROCm, to have better performance, a custom paged attention is available by switching on the env variable: `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`. +Currently, this env variable is enabled by default. To fallback to PagedAttention v2 kernel assign the env variable to 0. +The custom PagedAttention kernel is enabled for dtype: bf16, fp16, block-size=16, head-size=128, and max context length <= 16k, with GQA ratio (num_heads//num_kv_heads) between 1 to 16. On all the other cases, we fallback to PagedAttention v2 kernel. + +## NCCL Performance environment variable + +For MI300x, setting environment variable NCCL_MIN_NCHANNELS=112 is expected to improve performance. + diff --git a/benchmarks/P3L.py b/benchmarks/P3L.py new file mode 100755 index 000000000000..1fe0405212e7 --- /dev/null +++ b/benchmarks/P3L.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +""" +Patch-Perplexity (P3L) + +This is a script that produces a realistic PPL measurement +for the quantized KV cache system by processing a sequence of +non-overlapping patches of the reference text. Generation of the +consecutive symbols in each patch is governed (forced) +by the reference text. + +The initial context size for the system is set by the parameter +"--context-size". + +The number of output symbols to generate starting from a given +context is set by the parameter "--sample-size". This variable also +defines the size of the individual patch. + +For the N-token reference text that is split into M patches with the +system's context size C it takes M*preload + (N-C)*generation time. + +Quick correctness validation tips: + +Running llama-2-7b model +( + ./vllm/examples/P3L.py + --model=meta-llama/Llama-2-7b-chat-hf + --context-size=1024 + --sample-size=512 +) +should result in PPL ~ 6.524227946419175 + +Running llama-2-7b model +( + ./vllm/examples/P3L.py + --model=meta-llama/Llama-2-7b-chat-hf + --context-size=1024 + --sample-size=512 + --patch-size=1 +) +should result in PPL ~ PPL=3.8968611189957523 + +Running the script with multiple batches is possible +by specifying the --batch-size parameter. + +""" + +import argparse +import dataclasses +import datetime +import json +import math +import os + +from huggingface_hub import hf_hub_download + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def get_wikitext2_text(tokenizer): + hf_hub_download(repo_id='alexei-v-ivanov-amd/wiki', + repo_type="dataset", + filename='wiki.test.raw', + local_dir='./') + with open('./wiki.test.raw') as f: + test_text = "\n".join(line.strip() for line in f) + test_enc = tokenizer(test_text) + + os.remove('./wiki.test.raw') + + return test_enc, test_text + + +def vllm_init(args): + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**dataclasses.asdict(engine_args)) + + sampling_params = SamplingParams(n=1, + temperature=0.0, + top_p=1, + ignore_eos=True, + ppl_measurement=True, + future_context=[], + prompt_logprobs=1, + logprobs=1, + presence_penalty=0.0) + + return llm, sampling_params + + +def vllm_predict(CONT, llm, sampl_par): + result = llm.generate(prompt_token_ids=CONT, sampling_params=sampl_par) + return result + + +def main(args: argparse.Namespace): + + MESSAGE = f"Initialising @ {datetime.datetime.now()}" + logger.info(MESSAGE) + print(MESSAGE) + my_ppl = 0.0 + + logger.info("Initializing the engine.") + my_llm, my_sampl_par = vllm_init(args) + my_tokenizer = my_llm.llm_engine.tokenizer.tokenizer + logger.info(my_sampl_par) + logger.info("Initialized the engine.") + + my_n_samples = args.sample_size + + if (args.context_size+my_n_samples) > \ + my_llm.llm_engine.model_config.max_model_len: + MESSAGE = ("" \ + "Error! The total number of tokens:\n" \ + f" prefix ({args.context_size}) + " \ + f"to be generated ({my_n_samples})" \ + f" can't be bigger than the model limit " \ + f"({my_llm.llm_engine.model_config.max_model_len}).") + logger.info(MESSAGE) + print(MESSAGE) + return + + my_test_enc, my_test_text = get_wikitext2_text(my_tokenizer) + logger.info("Loaded the test data.") + + my_n_patches = math.ceil( + (len(my_test_enc['input_ids']) - args.context_size - 1) / my_n_samples) + if args.patch_size is not None: + my_n_patches = args.patch_size + + num_tokens_generated = 0 + starting_time = datetime.datetime.now() + MESSAGE = (f"Starting generation @ {starting_time}\n" \ + " Have the test sample of " + f"{len(my_test_enc['input_ids'])} tokens" \ + f" will try to process {my_n_patches} patche(s)," \ + f" generating {my_n_samples} tokens in each patch" \ + f" from the initial context of {args.context_size} tokens.") + + logger.info(MESSAGE) + print(MESSAGE) + + my_batchsize = args.batch_size + + for c in range(0, my_n_patches, my_batchsize): + + CONTEXT = [] + my_sampl_par.future_context = [] + my_sampl_par.cntr = [] + + for b in range(my_batchsize): + if (c + b) < my_n_patches: + upper_boundary = min( + (c + b + 1) * my_n_samples + args.context_size, + len(my_test_enc['input_ids'])) + CONTEXT.append( + my_test_enc['input_ids'][(c + b) * my_n_samples:(c + b) * + my_n_samples + args.context_size]) + + my_sampl_par.future_context.append( + my_test_enc['input_ids'][(c + b) * my_n_samples + + args.context_size:upper_boundary]) + + my_sampl_par.cntr.append(c + b) + + my_sampl_par.max_tokens = max( + len(my_sampl_par.future_context[b]) for b in range(len(CONTEXT))) + + LOGPROBS = vllm_predict(CONTEXT, my_llm, my_sampl_par) + for b in range(len(CONTEXT)): + num_tokens_generated += len(LOGPROBS[b].outputs[0].token_ids) + my_ppl -= LOGPROBS[b].outputs[0].cumulative_logprob + + if (num_tokens_generated < my_n_samples * len(CONTEXT)): + MESSAGE = (f"Warning: The number of generated tokens is" \ + f"less than requested ({num_tokens_generated}" \ + f" < {my_n_samples*len(CONTEXT)}).") + logger.info(MESSAGE) + print(MESSAGE) + + MESSAGE = (f"Iterations {c+1} through {c+len(CONTEXT)}" \ + " of {my_n_patches} Intermediate" \ + "Estimates:\n" \ + f"\tCross-entropy_intermediate={my_ppl/num_tokens_generated}\n" \ + f"\tPerplexity_intermediate=" \ + f"{math.exp(my_ppl/num_tokens_generated)}") + + logger.info(MESSAGE) + print(MESSAGE) + + ending_time = datetime.datetime.now() + MESSAGE = (f"Done @ {ending_time} after processing for" \ + f" {ending_time-starting_time}" \ + f" generated {num_tokens_generated} tokens.") + + logger.info(MESSAGE) + print(MESSAGE) + + MESSAGE = (f"\tIntegral Cross-Entropy={my_ppl}\n\tAverage Cross-Entropy=" \ + f"{my_ppl/num_tokens_generated}" \ + f"\n\tPPL={math.exp(my_ppl/num_tokens_generated)}") + + if args.output_json: + results = { + "integral_cross_entropy": my_ppl, + "average_cross_entropy": my_ppl / num_tokens_generated, + "ppl": math.exp(my_ppl / num_tokens_generated), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + logger.info(MESSAGE) + print(MESSAGE) + return + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Measure the PPPL (P3L) score of a given model.') + parser.add_argument('--context-size', type=int, default=4096) + parser.add_argument('--sample-size', type=int, default=512) + parser.add_argument('--batch-size', type=int, default=1) + parser.add_argument('--patch-size', type=int, default=None) + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the latency results in JSON format.') + + parser = EngineArgs.add_cli_args(parser) + args = parser.parse_args() + + main(args) diff --git a/benchmarks/P3L_mling.py b/benchmarks/P3L_mling.py new file mode 100755 index 000000000000..740f08681638 --- /dev/null +++ b/benchmarks/P3L_mling.py @@ -0,0 +1,273 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +""" +*MULTILINGUAL* Patch-Perplexity (P3L) + +This is a script that produces a realistic PPL measurement +for the quantized KV cache system by processing a sequence of +non-overlapping patches of the reference text. Generation of the +consecutive symbols in each patch is governed (forced) +by the reference text. + +The initial context size for the system is set by the parameter +"--context-size". + +The number of output symbols to generate starting from a given +context is set by the parameter "--sample-size". This variable also +defines the size of the individual patch. + +For the N-token reference text that is split into M patches with the +system's context size C it takes M*preload + (N-C)*generation time. + +Quick correctness validation tips: + +Running DeepSeek-V2 model +( + ./vllm/examples/P3L_mling.py + --model=meta-llama/Llama-2-7b-chat-hf + --context-size=1024 + --sample-size=512 +) + +should result in PPL ~ 8.42927 + +Running DeepSeek-V2 model +( + ./vllm/examples/P3L_mling.py + --model=meta-llama/Llama-2-7b-chat-hf + --context-size=1024 + --sample-size=512 + --patch-size=1 + --lang-script="cmn_Hant" +) +should result in PPL ~ 2.67962 + +The multi-linguality is implemented through the additional +key "--lang-script", which defaults to English in Latin +scripture ("eng_Latn"). + +Please refer to + +https://confluence.amd.com/display/MLSE/Multi-Lingual+P3L+Test + +for the complete set of possible language-scripture choices. + +Running the script with multiple batches is possible +by specifying the --batch-size parameter. + +""" + +import argparse +import dataclasses +import datetime +import json +import math +import os + +import pandas +from huggingface_hub import hf_hub_download + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def get_wikitext2_text(tokenizer): + hf_hub_download(repo_id='alexei-v-ivanov-amd/wiki', + repo_type="dataset", + filename='wiki.test.raw', + local_dir='./') + with open('./wiki.test.raw') as f: + test_text = "\n".join(line.strip() for line in f) + test_enc = tokenizer(test_text) + + os.remove('./wiki.test.raw') + + return test_enc, test_text + + +def get_flores_plus_text(tokenizer, lng_scrpt): + hf_hub_download(repo_id='alexei-v-ivanov-amd/flores_plus', + repo_type="dataset", + filename=lng_scrpt + '.parquet', + local_dir='./') + + df = pandas.read_parquet('./' + lng_scrpt + '.parquet') + test_text = "\n\n".join(line.strip() for line in df['text']) + test_enc = tokenizer(test_text) + + os.remove('./' + lng_scrpt + '.parquet') + + return test_enc, test_text + + +def vllm_init(args): + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**dataclasses.asdict(engine_args)) + + sampling_params = SamplingParams(n=1, + temperature=0.0, + top_p=1, + ignore_eos=True, + ppl_measurement=True, + future_context=[], + prompt_logprobs=1, + logprobs=1, + presence_penalty=0.0) + + return llm, sampling_params + + +def vllm_predict(CONT, llm, sampl_par): + result = llm.generate(prompt_token_ids=CONT, sampling_params=sampl_par) + return result + + +def main(args: argparse.Namespace): + + MESSAGE = f"Initialising @ {datetime.datetime.now()}" + logger.info(MESSAGE) + print(MESSAGE) + my_ppl = 0.0 + + logger.info("Initializing the engine.") + my_llm, my_sampl_par = vllm_init(args) + my_tokenizer = my_llm.llm_engine.tokenizer.tokenizer + logger.info(my_sampl_par) + logger.info("Initialized the engine.") + + my_n_samples = args.sample_size + my_lang_script = args.lang_script + + if (args.context_size+my_n_samples) > \ + my_llm.llm_engine.model_config.max_model_len: + MESSAGE = ("" \ + "Error! The total number of tokens:\n" \ + f" prefix ({args.context_size}) + " \ + f"to be generated ({my_n_samples})" \ + f" can't be bigger than the model limit " \ + f"({my_llm.llm_engine.model_config.max_model_len}).") + logger.info(MESSAGE) + print(MESSAGE) + return + + my_test_enc, my_test_text = get_flores_plus_text(my_tokenizer, + my_lang_script) + + logger.info("Loaded the test data.") + + my_n_patches = math.ceil( + (len(my_test_enc['input_ids']) - args.context_size - 1) / my_n_samples) + if args.patch_size is not None: + my_n_patches = args.patch_size + + num_tokens_generated = 0 + starting_time = datetime.datetime.now() + MESSAGE = (f"Starting generation @ {starting_time}\n" \ + " Have the test sample of " + f"{len(my_test_enc['input_ids'])} tokens" \ + f" will try to process {my_n_patches} patche(s)," \ + f" generating {my_n_samples} tokens in each patch" \ + f" from the initial context of {args.context_size} tokens.") + + logger.info(MESSAGE) + print(MESSAGE) + + my_batchsize = args.batch_size + + for c in range(0, my_n_patches, my_batchsize): + + CONTEXT = [] + my_sampl_par.future_context = [] + my_sampl_par.cntr = [] + + for b in range(my_batchsize): + if (c + b) < my_n_patches: + upper_boundary = min( + (c + b + 1) * my_n_samples + args.context_size, + len(my_test_enc['input_ids'])) + CONTEXT.append( + my_test_enc['input_ids'][(c + b) * my_n_samples:(c + b) * + my_n_samples + args.context_size]) + + my_sampl_par.future_context.append( + my_test_enc['input_ids'][(c + b) * my_n_samples + + args.context_size:upper_boundary]) + + my_sampl_par.cntr.append(c + b) + + my_sampl_par.max_tokens = max( + len(my_sampl_par.future_context[b]) for b in range(len(CONTEXT))) + + LOGPROBS = vllm_predict(CONTEXT, my_llm, my_sampl_par) + for b in range(len(CONTEXT)): + num_tokens_generated += len(LOGPROBS[b].outputs[0].token_ids) + my_ppl -= LOGPROBS[b].outputs[0].cumulative_logprob + + if (num_tokens_generated < my_n_samples * len(CONTEXT)): + MESSAGE = (f"Warning: The number of generated tokens is" \ + f"less than requested ({num_tokens_generated}" \ + f" < {my_n_samples*len(CONTEXT)}).") + logger.info(MESSAGE) + print(MESSAGE) + + MESSAGE = (f"Iterations {c+1} through {c+len(CONTEXT)}" \ + " of {my_n_patches} Intermediate" \ + "Estimates:\n" \ + f"\tCross-entropy_intermediate={my_ppl/num_tokens_generated}\n" \ + f"\tPerplexity_intermediate=" \ + f"{math.exp(my_ppl/num_tokens_generated)}") + + logger.info(MESSAGE) + print(MESSAGE) + + ending_time = datetime.datetime.now() + MESSAGE = (f"Done @ {ending_time} after processing for" \ + f" {ending_time-starting_time}" \ + f" generated {num_tokens_generated} tokens.") + + logger.info(MESSAGE) + print(MESSAGE) + + MESSAGE = (f"\tIntegral Cross-Entropy={my_ppl}\n\tAverage Cross-Entropy=" \ + f"{my_ppl/num_tokens_generated}" \ + f"\n\tPPL={math.exp(my_ppl/num_tokens_generated)}") + + if args.output_json: + results = { + "integral_cross_entropy": my_ppl, + "average_cross_entropy": my_ppl / num_tokens_generated, + "ppl": math.exp(my_ppl / num_tokens_generated), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + logger.info(MESSAGE) + print(MESSAGE) + return + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description='Measure the PPPL (P3L) score of a given model.') + parser.add_argument( + '--data', + type=str, + default='./wikitext/wikitext-2-v1/test-00000-of-00001.parquet') + parser.add_argument('--context-size', type=int, default=4096) + parser.add_argument('--sample-size', type=int, default=512) + parser.add_argument('--batch-size', type=int, default=1) + parser.add_argument('--patch-size', type=int, default=None) + parser.add_argument('--lang-script', type=str, default="eng_Latn") + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the latency results in JSON format.') + + parser = EngineArgs.add_cli_args(parser) + args = parser.parse_args() + + main(args) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index daedaadb1a77..913bb0cf19a5 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -11,8 +11,9 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, create_kv_caches_with_random) -NUM_BLOCKS = 1024 +NUM_BLOCKS = 128 * 1024 PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 @torch.inference_mode() @@ -80,6 +81,12 @@ def main( # Prepare for the paged attention kernel. output = torch.empty_like(query) if version == "v2": + if current_platform.is_rocm(): + global PARTITION_SIZE + if not args.custom_paged_attn: + PARTITION_SIZE = 1024 + else: + PARTITION_SIZE = PARTITION_SIZE_ROCM num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), @@ -123,25 +130,48 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: v_scale, ) elif version == "v2": - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - ) + if not args.custom_paged_attn: + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + None, + PARTITION_SIZE, + ) else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -154,13 +184,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: # Warmup. print("Warming up...") run_benchmark = run_cuda_benchmark - run_benchmark(num_iters=3, profile=False) + run_benchmark(num_iters=500, profile=False) # Benchmark. if do_profile: latency = run_benchmark(num_iters=1, profile=True) else: - latency = run_benchmark(num_iters=100, profile=False) + latency = run_benchmark(num_iters=10000, profile=False) print(f"Kernel running time: {latency * 1000000:.3f} us") @@ -195,6 +225,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: help="Data type for kv cache storage. If 'auto', will use model " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") + parser.add_argument("--custom-paged-attn", + action="store_true", + help="Use custom paged attention") args = parser.parse_args() print(args) diff --git a/benchmarks/kernels/moe_tune_script.sh b/benchmarks/kernels/moe_tune_script.sh new file mode 100755 index 000000000000..2ee1748b5535 --- /dev/null +++ b/benchmarks/kernels/moe_tune_script.sh @@ -0,0 +1,33 @@ +#!/bin/bash + +export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 + + +## ---- Mixtral fp8 tuning example ---- ## +python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-Instruct-v0.1-FP8/ --tp-size 1 --tune --dtype fp8_w8a8 +python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-Instruct-v0.1-FP8/ --tp-size 2 --tune --dtype fp8_w8a8 +python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-Instruct-v0.1-FP8/ --tp-size 4 --tune --dtype fp8_w8a8 +python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-Instruct-v0.1-FP8/ --tp-size 8 --tune --dtype fp8_w8a8 + + +## ---- Mixtral fp16 tuning example ---- ## +# we don't need --dtype fp16; it has been set as default for rocm in the script. + +python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-v0.1/ --tp-size 1 --tune +python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-v0.1/ --tp-size 2 --tune +python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-v0.1/ --tp-size 4 --tune +python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-v0.1/ --tp-size 8 --tune + + + +## ---- After the tuning is finished ---- ## +# The tuning script saves the configurations in a json file at the same directory from where you launch the script. +# The name of the json file will look something like this: E=8,N=14336,device_name=AMD_Instinct_MI300X.json +# +# [IMPORTANT] -> Once the tuning is complete, move the tuned config file(s) to the following path: +# vllm/vllm/model_executor/layers/fused_moe/configs/ + + +## ---- Notes ---- ## +# 1. The tuned file is specific for a TP size. This means a tuned file obtained for --tp-size 8 can only be used when running the model under TP=8 setting. +# 2. The script uses Ray for multi-gpu tuning. Export HIP_VISIBLE_DEVICES accordingly to expose the required no. of GPUs and use multiple gpus for tuning. \ No newline at end of file diff --git a/benchmarks/profiling/README.md b/benchmarks/profiling/README.md new file mode 100644 index 000000000000..8e029d8b9c1b --- /dev/null +++ b/benchmarks/profiling/README.md @@ -0,0 +1,59 @@ +# VLLM Benchmark Profiling + +This profiling directory provides a method to profile VLLM throughput and latency benchmarks using ROCm profiling utilities. + +## 1. Dependencies + +Before using the profiling feature, you need to install the required dependencies: + +### Install ROCm Profile Data + +```bash +git clone -b nvtx_enabled https://github.com/ROCm/rocmProfileData.git +cd rocmProfileData && make && sudo make install +``` + +### Install hipMarker + +```bash +cd rocmProfileData/hipMarker && python3 setup.py install +``` + +## 2. Profiling Benchmarks + +Profiling can be used to monitor the performance of the VLLM benchmarks with ROCm. The key flags used for profiling are: + +- `--profile-rpd`: Profiles the generation process of a single batch. +- `--profile-dir PROFILE_DIR`: Specifies the path to save the profiler output, which can later be visualized using tools like [ui.perfetto.dev](https://ui.perfetto.dev/) or [chrome.tracing](chrome://tracing/). + +### Profiling Using Default Directory + +By default, profiling results are saved in either `vllm_benchmark_latency_result` or `vllm_benchmark_throughput_result`. To run a benchmark and profile it using the default directory, execute: + +```bash +python3 benchmark_throughput.py --input-len {len} --output-len {len} --model {model} --profile-rpd +``` + +### Profiling With a Custom Directory + +You can specify a custom directory for saving profiler outputs by using the `--profile-dir` flag: + +```bash +python3 benchmark_throughput.py --input-len {len} --output-len {len} --model {model} --profile-rpd --profile-dir {/path/to/custom/dir} +``` + +After profiling is complete, an `.rpd` file containing the trace data will be saved to the specified directory. + +## 3. Convert Trace Data to JSON Format + +To view the trace data, it needs to be converted into a format that is compatible with tools like Chrome tracing or Perfetto. + +You can use the `rpd2tracing.py` script in rocmProfileData to convert the `.rpd` file into a JSON file: + +```bash +python3 rocmProfileData/tools/rpd2tracing.py trace.rpd trace.json +``` + +Once the trace is converted, open the `.json` file in [Chrome](chrome://tracing/) or [Perfetto](https://ui.perfetto.dev/) for visualization. + + diff --git a/benchmarks/profiling/benchmark_latency.py b/benchmarks/profiling/benchmark_latency.py new file mode 100644 index 000000000000..34b157eb6ab6 --- /dev/null +++ b/benchmarks/profiling/benchmark_latency.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark the latency of processing a single batch of requests.""" +import argparse +import dataclasses +import json +import os +import time +from contextlib import contextmanager, nullcontext +from pathlib import Path +from typing import List, Optional + +import numpy as np +import torch +from tqdm import tqdm + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.inputs import PromptType +from vllm.utils import FlexibleArgumentParser + + +def main(args: argparse.Namespace): + print(args) + + @contextmanager + def rpd_profiler_context(): + from rpdTracerControl import rpdTracerControl as rpd + llm.start_profile() + yield + llm.stop_profile() + rpd.top_totals() + + @contextmanager + def torch_profiler_context(profile_result_dir: Optional[str] = None): + p = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + str(profile_result_dir))) + p.start() + try: + with torch.no_grad(): + yield p + finally: + p.stop() + print(p.key_averages().table(sort_by="self_cuda_time_total", + row_limit=-1)) + + def get_profiling_context(profile_result_dir: Optional[str] = None): + if args.profile_torch: + return torch_profiler_context(profile_result_dir) + elif args.profile_rpd: + return rpd_profiler_context() + else: + return nullcontext() + + if args.profile_torch or args.profile_rpd: + profile_result_dir = Path(args.profile_result_dir + or "./vllm_benchmark_latency_result") + profile_result_dir.mkdir(parents=True, exist_ok=True) + name = os.path.basename(os.path.normpath(args.model)) + model_trace_name = ( + f"{name}_in_{args.input_len}_out_{args.output_len}_" + f"batch_{args.batch_size}_tp_{args.tensor_parallel_size}") + print( + f"Profiling (results will be saved to '{profile_result_dir}')...") + if args.profile_rpd: + profile_result_dir /= f"{model_trace_name}.rpd" + os.environ["VLLM_RPD_PROFILER_DIR"] = str(profile_result_dir) + + engine_args = EngineArgs.from_cli_args(args) + + # NOTE(woosuk): If the request cannot be processed in a single batch, + # the engine will automatically process the request in multiple batches. + llm = LLM(**dataclasses.asdict(engine_args)) + + sampling_params = SamplingParams( + n=args.n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=args.output_len, + ) + print(sampling_params) + dummy_prompt_token_ids = np.random.randint(10000, + size=(args.batch_size, + args.input_len)) + dummy_inputs: List[PromptType] = [{ + "prompt_token_ids": batch + } for batch in dummy_prompt_token_ids.tolist()] + + def run_to_completion(profile_result_dir: Optional[str] = None): + if profile_result_dir: + with get_profiling_context(profile_result_dir): + llm.generate(dummy_inputs, + sampling_params=sampling_params, + use_tqdm=False) + else: + start_time = time.perf_counter() + llm.generate(dummy_inputs, + sampling_params=sampling_params, + use_tqdm=False) + end_time = time.perf_counter() + latency = end_time - start_time + return latency + + print("Warming up...") + for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): + run_to_completion(profile_result_dir=None) + + if args.profile_torch or args.profile_rpd: + run_to_completion(profile_result_dir=profile_result_dir) + return + + # Benchmark. + latencies = [] + for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): + latencies.append(run_to_completion(profile_result_dir=None)) + latencies = np.array(latencies) + percentages = [10, 25, 50, 75, 90, 99] + percentiles = np.percentile(latencies, percentages) + print(f'Avg latency: {np.mean(latencies)} seconds') + for percentage, percentile in zip(percentages, percentiles): + print(f'{percentage}% percentile latency: {percentile} seconds') + + # Output JSON results if specified + if args.output_json: + results = { + "avg_latency": np.mean(latencies), + "latencies": latencies.tolist(), + "percentiles": dict(zip(percentages, percentiles.tolist())), + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == '__main__': + parser = FlexibleArgumentParser( + description='Benchmark the latency of processing a single batch of ' + 'requests till completion.') + parser.add_argument('--input-len', type=int, default=32) + parser.add_argument('--output-len', type=int, default=128) + parser.add_argument('--batch-size', type=int, default=8) + parser.add_argument('--n', + type=int, + default=1, + help='Number of generated sequences per prompt.') + parser.add_argument('--use-beam-search', action='store_true') + parser.add_argument('--num-iters-warmup', + type=int, + default=10, + help='Number of iterations to run for warmup.') + parser.add_argument('--num-iters', + type=int, + default=30, + help='Number of iterations to run.') + parser.add_argument( + '--profile-torch', + action='store_true', + help='profile the generation process of a single batch') + parser.add_argument( + '--profile-rpd', + action='store_true', + help='profile the generation process of a single batch') + parser.add_argument( + '--profile-result-dir', + type=str, + default=os.getenv('VLLM_RPD_PROFILER_DIR', default=None), + help=('path to save the profiler output. Can be visualized ' + 'with ui.perfetto.dev or Tensorboard.')) + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the latency results in JSON format.') + + parser = EngineArgs.add_cli_args(parser) + args = parser.parse_args() + main(args) diff --git a/benchmarks/profiling/benchmark_throughput.py b/benchmarks/profiling/benchmark_throughput.py new file mode 100644 index 000000000000..dbf689de9525 --- /dev/null +++ b/benchmarks/profiling/benchmark_throughput.py @@ -0,0 +1,419 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark offline inference throughput.""" +import argparse +import dataclasses +import json +import os +import random +import time +from contextlib import contextmanager, nullcontext +from pathlib import Path +from typing import List, Optional, Tuple + +import torch +import uvloop +from tqdm import tqdm +from transformers import (AutoModelForCausalLM, AutoTokenizer, + PreTrainedTokenizerBase) + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) +from vllm.utils import FlexibleArgumentParser, merge_async_iterators + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int], +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset + + +def run_vllm( + requests: List[Tuple[str, int, int]], + n: int, + engine_args: EngineArgs, +) -> float: + from vllm import LLM, SamplingParams + + @contextmanager + def rpd_profiler_context(): + from rpdTracerControl import rpdTracerControl as rpd + llm.start_profile() + yield + llm.stop_profile() + rpd.top_totals() + + @contextmanager + def torch_profiler_context(profile_dir: Optional[str] = None): + p = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + str(profile_dir))) + p.start() + try: + with torch.no_grad(): + yield p + finally: + p.stop() + print(p.key_averages().table(sort_by="self_cuda_time_total", + row_limit=-1)) + + def get_profiling_context(profile_dir: Optional[str] = None): + if args.profile_torch: + return torch_profiler_context(profile_dir) + elif args.profile_rpd: + return rpd_profiler_context() + else: + return nullcontext() + + if args.profile_torch or args.profile_rpd: + profile_dir = Path(args.profile_dir + or "./vllm_benchmark_throughput_result") + profile_dir.mkdir(parents=True, exist_ok=True) + name = os.path.basename(os.path.normpath(args.model)) + model_trace_name = ( + f"{name}_in_{args.input_len}_out_{args.output_len}_" + f"tp_{args.tensor_parallel_size}") + print(f"Profiling (results will be saved to '{profile_dir}')...") + if args.profile_rpd: + profile_dir /= f"{model_trace_name}.rpd" + os.environ["VLLM_RPD_PROFILER_DIR"] = str(profile_dir) + + llm = LLM(**dataclasses.asdict(engine_args)) + + # Add the requests to the engine. + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + for prompt, _, output_len in requests: + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=output_len, + )) + + if args.profile_torch or args.profile_rpd: + with get_profiling_context(profile_dir): + llm.generate(prompts, sampling_params, use_tqdm=True) + return + else: + start = time.perf_counter() + llm.generate(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() + return end - start + + +async def run_vllm_async( + requests: List[Tuple[str, int, int]], + n: int, + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, +) -> float: + from vllm import SamplingParams + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + + # Add the requests to the engine. + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + for prompt, _, output_len in requests: + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=output_len, + )) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + generator = llm.generate(prompt, sp, request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + +def run_hf( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: PreTrainedTokenizerBase, + n: int, + use_beam_search: bool, + max_batch_size: int, + trust_remote_code: bool, +) -> float: + assert not use_beam_search + llm = AutoModelForCausalLM.from_pretrained( + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + if llm.config.model_type == "llama": + # To enable padding in the HF backend. + tokenizer.pad_token = tokenizer.eos_token + llm = llm.cuda() + + pbar = tqdm(total=len(requests)) + start = time.perf_counter() + batch: List[str] = [] + max_prompt_len = 0 + max_output_len = 0 + for i in range(len(requests)): + prompt, prompt_len, output_len = requests[i] + # Add the prompt to the batch. + batch.append(prompt) + max_prompt_len = max(max_prompt_len, prompt_len) + max_output_len = max(max_output_len, output_len) + if len(batch) < max_batch_size and i != len(requests) - 1: + # Check if we can add more requests to the batch. + _, next_prompt_len, next_output_len = requests[i + 1] + if (max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len)) <= 2048: + # We can add more requests to the batch. + continue + + # Generate the sequences. + input_ids = tokenizer(batch, return_tensors="pt", + padding=True).input_ids + llm_outputs = llm.generate( + input_ids=input_ids.cuda(), + do_sample=not use_beam_search, + num_return_sequences=n, + temperature=1.0, + top_p=1.0, + use_cache=True, + max_new_tokens=max_output_len, + ) + # Include the decoding time. + tokenizer.batch_decode(llm_outputs, skip_special_tokens=True) + pbar.update(len(batch)) + + # Clear the batch. + batch = [] + max_prompt_len = 0 + max_output_len = 0 + end = time.perf_counter() + return end - start + + +def run_mii( + requests: List[Tuple[str, int, int]], + model: str, + tensor_parallel_size: int, + output_len: int, +) -> float: + from mii import client, serve + llm = serve(model, tensor_parallel=tensor_parallel_size) + prompts = [prompt for prompt, _, _ in requests] + + start = time.perf_counter() + llm.generate(prompts, max_new_tokens=output_len) + end = time.perf_counter() + client = client(model) + client.terminate_server() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: + # Synthesize a prompt with the given input length. + prompt = { "prompt_token_ids" : [42] * (args.input_len - 1) } \ + if args.skip_tokenizer_init else "hi" * (args.input_len - 1) + requests = [(prompt, args.input_len, args.output_len) + for _ in range(args.num_prompts)] + else: + requests = sample_requests(args.dataset, args.num_prompts, tokenizer, + args.output_len) + + if args.backend == "vllm": + if args.async_engine: + elapsed_time = uvloop.run( + run_vllm_async( + requests, + args.n, + AsyncEngineArgs.from_cli_args(args), + args.disable_frontend_multiprocessing, + )) + else: + elapsed_time = run_vllm(requests, args.n, + EngineArgs.from_cli_args(args)) + elif args.backend == "hf": + assert args.tensor_parallel_size == 1 + elapsed_time = run_hf(requests, args.model, tokenizer, args.n, + args.use_beam_search, args.hf_max_batch_size, + args.trust_remote_code) + elif args.backend == "mii": + elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, + args.output_len) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum(prompt_len + output_len + for _, prompt_len, output_len in requests) + + if args.profile_torch or args.profile_rpd: + # Profiling complete + pass + else: + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--num-prompts", + type=int, + default=1000, + help="Number of prompts to process.") + parser.add_argument("--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.") + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") + parser.add_argument( + '--profile-torch', + action='store_true', + help='profile the generation process of a single batch') + parser.add_argument( + '--profile-rpd', + action='store_true', + help='profile the generation process of a single batch') + parser.add_argument( + '--profile-dir', + type=str, + default=None, + help=('path to save the profiler output. Can be visualized ' + 'with ui.perfetto.dev or Tensorboard.')) + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + if args.backend == "vllm": + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + elif args.backend == "hf": + if args.hf_max_batch_size is None: + raise ValueError("HF max batch size is required for HF backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + elif args.backend == "mii": + if args.dtype != "auto": + raise ValueError("dtype must be auto for MII backend.") + if args.n != 1: + raise ValueError("n must be 1 for MII backend.") + if args.use_beam_search: + raise ValueError("Beam search is not supported for MII backend.") + if args.quantization is not None: + raise ValueError("Quantization is only for vLLM backend.") + if args.hf_max_batch_size is not None: + raise ValueError("HF max batch size is only for HF backend.") + if args.tokenizer != args.model: + raise ValueError("Tokenizer must be the same as the model for MII " + "backend.") + main(args) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 1c1c539819d0..825fac8cd368 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -444,3 +444,33 @@ function (define_gpu_extension_target GPU_MOD_NAME) install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME}) endfunction() + + +# gfx12xx should not be compiled together with gfx94x (MI300) because they support different types of FP8 format. +# FP8_FORMAT will be returned (E4M3FN / E4M3FNUZ / NONE / CONFLICT) +macro (get_supported_fp8_format FP8_FORMAT GPU_LANG GPU_ARCHES) + set(_USING_CUDA_FP8_FORMAT "FALSE") + set(_USING_HIP_FP8_FORMAT "FALSE") + + if (NOT (${GPU_LANG} STREQUAL "HIP")) + set(_USING_CUDA_FP8_FORMAT "TRUE") + else() + foreach (_ARCH ${GPU_ARCHES}) + if (_ARCH MATCHES "gfx94.") + set(_USING_HIP_FP8_FORMAT "TRUE") + elseif(_ARCH MATCHES "gfx12..") + set(_USING_CUDA_FP8_FORMAT "TRUE") + endif() + endforeach() + endif() + + if ((${_USING_CUDA_FP8_FORMAT} STREQUAL "FALSE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "FALSE")) + set(FP8_FORMAT "NONE") + elseif((${_USING_CUDA_FP8_FORMAT} STREQUAL "FALSE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "TRUE")) + set(FP8_FORMAT "E4M3FNUZ") + elseif((${_USING_CUDA_FP8_FORMAT} STREQUAL "TRUE") AND (${_USING_HIP_FP8_FORMAT} STREQUAL "FALSE")) + set(FP8_FORMAT "E4M3FN") + else() + set(FP8_FORMAT "CONFLICT") + endif() +endmacro() diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 88275dbdd83a..41176c801fb0 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -7,6 +7,10 @@ #include "cuda_compat.h" #include "dispatch_utils.h" +#ifdef USE_ROCM + #include "quantization/fp8/amd/hip_float8.h" +#endif + namespace vllm { template +__global__ void scaled_act_and_mul_kernel( + c10::Float8_e4m3fnuz* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d, const float scale) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); + const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); + float r = ACT_FN(x) * y * scale; + out[token_idx * d + idx] = c10::Float8_e4m3fnuz( + hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits()); + } +} +#endif + template __device__ __forceinline__ T silu_kernel(const T& x) { // x * sigmoid(x) @@ -79,6 +101,25 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { input.data_ptr(), d); \ }); +// Launch activation and gating kernel. +#ifdef USE_ROCM + #define LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "scaled_act_and_mul_kernel", [&] { \ + vllm::scaled_act_and_mul_kernel> \ + <<>>( \ + out.data_ptr(), \ + input.data_ptr(), d, \ + 1.0 / (*scale.data_ptr())); \ + }); +#endif + void silu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { @@ -93,6 +134,14 @@ void mul_and_silu(torch::Tensor& out, // [..., d] LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false); } +void scaled_silu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& scale) { +#ifdef USE_ROCM + LAUNCH_SCALED_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); +#endif +} + void gelu_and_mul(torch::Tensor& out, // [..., d] torch::Tensor& input) // [..., 2 * d] { diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 9b3a5c4b1014..87bea0f3b279 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -48,7 +48,7 @@ // TODO(woosuk): Tune NUM_THREADS. template + int NUM_THREADS> void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, int num_kv_heads, float scale, @@ -133,19 +133,38 @@ void paged_attention_v1_launcher( } } -#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE, \ + NUM_THREADS) \ paged_attention_v1_launcher( \ + IS_BLOCK_SPARSE, NUM_THREADS>( \ out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); -#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - if (is_block_sparse) { \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - } else { \ - CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ +#define CALL_V1_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, \ + IS_FP8_KV_CACHE, IS_BLOCK_SPARSE) \ + switch (num_threads) { \ + case 128: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + IS_BLOCK_SPARSE, 128); \ + break; \ + case 1024: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + IS_BLOCK_SPARSE, 1024); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported num threads: ", num_threads); \ + break; \ + } + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + if (is_block_sparse) { \ + CALL_V1_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + true); \ + } else { \ + CALL_V1_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + false); \ } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes @@ -183,7 +202,7 @@ void paged_attention_v1( torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { + const int64_t blocksparse_head_sliding_step, const int64_t num_threads) { const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 9935359e02fb..fc84a6774b8e 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -19,6 +19,14 @@ #include "attention_kernels.cuh" +#ifdef USE_ROCM + #include + #include "../quantization/fp8/amd/quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" +#endif + #ifndef USE_ROCM #define WARP_SIZE 32 #else @@ -48,7 +56,7 @@ template + int NUM_THREADS, int PARTITION_SIZE = 512> void paged_attention_v2_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, @@ -139,20 +147,39 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE, \ + NUM_THREADS, PARTITION_SIZE) \ paged_attention_v2_launcher( \ + IS_BLOCK_SPARSE, NUM_THREADS, PARTITION_SIZE>( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ blocksparse_vert_stride, blocksparse_block_size, \ blocksparse_head_sliding_step); -#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ - if (is_block_sparse) { \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ - } else { \ - CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ +#define CALL_V2_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, \ + IS_FP8_KV_CACHE, IS_BLOCK_SPARSE) \ + switch (num_threads) { \ + case 128: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + IS_BLOCK_SPARSE, 128, 512); \ + break; \ + case 1024: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + IS_BLOCK_SPARSE, 1024, 1024); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported num threads: ", num_threads); \ + break; \ + } + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + if (is_block_sparse) { \ + CALL_V2_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + true); \ + } else { \ + CALL_V2_LAUNCHER_W_NUM_THREADS(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, \ + false); \ } // NOTE(woosuk): To reduce the compilation time, we omitted block sizes @@ -194,7 +221,7 @@ void paged_attention_v2( torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step) { + const int64_t blocksparse_head_sliding_step, const int64_t num_threads) { const bool is_block_sparse = (blocksparse_vert_stride > 1); DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, CALL_V2_LAUNCHER_BLOCK_SIZE) diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 123278bfed71..8152cb3d348f 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -5,29 +5,32 @@ #include "custom_all_reduce.cuh" -// Fake pointer type, must match fptr_t type in ops.h. -// We use this type alias to indicate when pointers are passed in as int64_t. +// fake pointer type, must match fptr_t type in ops.h using fptr_t = int64_t; static_assert(sizeof(void*) == sizeof(fptr_t)); -fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, - torch::Tensor& rank_data, int64_t rank, +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int64_t rank, bool full_nvlink) { - int world_size = fake_ipc_ptrs.size(); + int world_size = offsets.size(); if (world_size > 8) throw std::invalid_argument("world size > 8 is not supported"); if (world_size % 2 != 0) throw std::invalid_argument("Odd num gpus is not supported for now"); + if (world_size != handles.size()) + throw std::invalid_argument( + "handles length should equal to offsets length"); if (rank < 0 || rank >= world_size) throw std::invalid_argument("invalid rank passed in"); - vllm::Signal* ipc_ptrs[8]; + cudaIpcMemHandle_t ipc_handles[8]; for (int i = 0; i < world_size; i++) { - ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); + std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); } - return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(), - rank_data.numel(), rank, world_size, - full_nvlink); + return (fptr_t) new vllm::CustomAllreduce( + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); } /** @@ -52,48 +55,26 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -/** - * Performs an out-of-place allreduce and stores result in out. - * - * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered. - * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first - * copied into _reg_buffer. - */ -void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes) { +void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + cudaStream_t stream) { auto fa = reinterpret_cast(_fa); - const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); - auto stream = c10::cuda::getCurrentCUDAStream().stream(); - - TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); - TORCH_CHECK_EQ(inp.numel(), out.numel()); TORCH_CHECK(_is_weak_contiguous(out)); - TORCH_CHECK(_is_weak_contiguous(inp)); - auto input_size = inp.numel() * inp.element_size(); - auto reg_buffer = reinterpret_cast(_reg_buffer); - if (reg_buffer) { - TORCH_CHECK_LE(input_size, reg_buffer_sz_bytes); - AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, - cudaMemcpyDeviceToDevice, stream)); - } else { - reg_buffer = inp.data_ptr(); - } switch (out.scalar_type()) { case at::ScalarType::Float: { - fa->allreduce(stream, reinterpret_cast(reg_buffer), + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), reinterpret_cast(out.data_ptr()), out.numel()); break; } case at::ScalarType::Half: { - fa->allreduce(stream, reinterpret_cast(reg_buffer), + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), reinterpret_cast(out.data_ptr()), out.numel()); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case at::ScalarType::BFloat16: { fa->allreduce( - stream, reinterpret_cast(reg_buffer), + stream, reinterpret_cast(inp.data_ptr()), reinterpret_cast(out.data_ptr()), out.numel()); break; } @@ -104,41 +85,91 @@ void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, } } +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + _all_reduce(_fa, inp, out, stream); +} + +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + + auto input_size = inp.numel() * inp.element_size(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), + "registered buffer is too small to contain the input"); + AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), + input_size, cudaMemcpyDeviceToDevice, stream)); + _all_reduce(_fa, reg_buffer, out, stream); +} + void dispose(fptr_t _fa) { - delete reinterpret_cast(_fa); + auto fa = reinterpret_cast(_fa); + delete fa; } int64_t meta_size() { return sizeof(vllm::Signal); } -void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs) { +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets) { auto fa = reinterpret_cast(_fa); - TORCH_CHECK(fake_ipc_ptrs.size() == fa->world_size_); - void* ipc_ptrs[8]; - for (int i = 0; i < fake_ipc_ptrs.size(); i++) { - ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); - } - fa->register_buffer(ipc_ptrs); + fa->register_buffer(handles, offsets, t.data_ptr()); } -// Use vector to represent byte data for python binding compatibility. -std::tuple, std::vector> -get_graph_buffer_ipc_meta(fptr_t _fa) { +std::tuple> get_graph_buffer_ipc_meta( + fptr_t _fa) { auto fa = reinterpret_cast(_fa); - auto [handle, offsets] = fa->get_graph_buffer_ipc_meta(); - std::vector bytes(handle.begin(), handle.end()); - return std::make_tuple(bytes, offsets); + auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto handles = + torch::empty({static_cast(handle_bytes.size())}, options); + std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); + return {handles, std::move(offsets)}; } -// Use vector to represent byte data for python binding compatibility. -void register_graph_buffers(fptr_t _fa, - const std::vector>& handles, +void register_graph_buffers(fptr_t _fa, const std::vector& handles, const std::vector>& offsets) { auto fa = reinterpret_cast(_fa); - std::vector bytes; - bytes.reserve(handles.size()); - for (int i = 0; i < handles.size(); i++) { - bytes.emplace_back(handles[i].begin(), handles[i].end()); - } - bytes.reserve(handles.size()); - fa->register_graph_buffers(bytes, offsets); + fa->register_graph_buffers(handles, offsets); } + +#ifdef USE_ROCM + +void free_meta_buffer(void* buffer) { CUDACHECK(cudaFree(buffer)); } + +torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) { + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = + torch::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, options); + CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)data_handle.data_ptr(), + inp.data_ptr())); + return data_handle; +} + +torch::Tensor allocate_meta_buffer(int64_t size) { + auto device_index = c10::cuda::current_device(); + at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index)); + void* buffer; + cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + AT_CUDA_CHECK( + hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached)); + AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream)); + AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + auto options = torch::TensorOptions() + .dtype(torch::kI8) + .device(torch::kCUDA, device_index); + return torch::from_blob(buffer, {size}, free_meta_buffer, options); +} + +#endif diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index b9df4ed160b0..838605e5f04d 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -1,12 +1,16 @@ #pragma once #include -#include +#ifdef USE_ROCM + #include +typedef __hip_bfloat16 nv_bfloat16; +#else + #include +#endif #include #include #include -#include #include #include #include @@ -24,26 +28,30 @@ namespace vllm { -constexpr int kMaxBlocks = 36; -// Counter may overflow, but it's fine since unsigned int overflow is -// well-defined behavior. -using FlagType = uint32_t; +constexpr int kMaxBlocks = 64; +// note: we don't want to use atomics for signals because peer atomics are no +// supported on PCIe links struct Signal { - alignas(128) FlagType self_counter[kMaxBlocks][8]; - // Two sets of peer counters are needed for two syncs. The reason is that - // it's possible for peer GPU block to arrive at the second sync point while - // the current GPU block haven't passed the first sync point. Thus, peer GPU - // may write counter+1 while current GPU is busy waiting for counter. We use - // alternating counter array to avoid this possibility. - alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; + alignas(128) uint32_t start[kMaxBlocks][8]; + alignas(128) uint32_t end[kMaxBlocks][8]; + alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank }; +#ifdef USE_ROCM +struct __align__(16) RankData { + const void* ptrs[8]; +}; +#else struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; +#endif struct __align__(16) RankSignals { - Signal* signals[8]; +#ifndef USE_ROCM + volatile +#endif + Signal* signals[8]; }; // like std::array, but aligned @@ -134,71 +142,97 @@ DINLINE O downcast(array_t val) { } } -static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 - asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), - "l"(flag_addr)); -#else - asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), - "l"(flag_addr)); +// This function is meant to be used as the first synchronization in the all +// reduce kernel. Thus, it doesn't need to make any visibility guarantees for +// prior memory accesses. Note: volatile writes will not be reordered against +// other volatile writes. +template +DINLINE void start_sync(const RankSignals& sg, +#ifndef USE_ROCM + volatile #endif -} - -static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { - FlagType flag; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 - asm volatile("ld.acquire.sys.global.u32 %0, [%1];" - : "=r"(flag) - : "l"(flag_addr)); + Signal* self_sg, + int rank) { +#ifdef USE_ROCM + uint32_t flag = self_sg->_flag[blockIdx.x] + 1; + if (threadIdx.x < ngpus) { + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + __scoped_atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], + flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM); + // wait until we got true from all ranks + while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], + __ATOMIC_RELAXED, + __MEMORY_SCOPE_DEVICE) < flag); + } + __syncthreads(); + // use one thread to update flag + if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; #else - asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" - : "=r"(flag) - : "l"(flag_addr)); + if (threadIdx.x < ngpus) { + // reset flag for next time + self_sg->end[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while (!self_sg->start[blockIdx.x][threadIdx.x]); + } + __syncthreads(); #endif - return flag; -} - -static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) { - asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); } -static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { - FlagType flag; - asm volatile("ld.volatile.global.u32 %0, [%1];" - : "=r"(flag) - : "l"(flag_addr)); - return flag; -} - -// is_start: whether this is the very first synchronization barrier. -// need_fence: whether a memory fence is needed. If true, a release-acquire -// semantic is used to enforce memory access order before and after this -// barrier. -template -DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, - int rank) { - if constexpr (!is_start) __syncthreads(); - static_assert( - !(is_start && need_fence)); // Start barrier shouldn't need fence. +// This function is meant to be used as the second or the final synchronization +// barrier in the all reduce kernel. If it's the final synchronization barrier, +// we don't need to make any visibility guarantees for prior memory accesses. +template +DINLINE void end_sync(const RankSignals& sg, +#ifndef USE_ROCM + volatile +#endif + Signal* self_sg, + int rank) { +#ifdef USE_ROCM + __syncthreads(); + // eliminate the case that prior writes are not visible after signals become + // visible. Note that I did not managed to make this happen through a lot of + // testing. Might be the case that hardware provides stronger guarantee than + // the memory model. + uint32_t flag = self_sg->_flag[blockIdx.x] + 1; if (threadIdx.x < ngpus) { - // Increment the counter. Technically we only need one counter, but we use - // multiple per block to eliminate the need to share the counter via smem. - auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; - // Write the expected counter value to peer and wait for correct value from - // peer. - auto peer_counter_ptr = - &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; - auto self_counter_ptr = - &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; - if constexpr (need_fence) { - st_flag_release(peer_counter_ptr, val); - while (ld_flag_acquire(self_counter_ptr) != val); - } else { - st_flag_volatile(peer_counter_ptr, val); - while (ld_flag_volatile(self_counter_ptr) != val); - } + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + __scoped_atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], + flag, + final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE, + __MEMORY_SCOPE_SYSTEM); + // wait until we got true from all ranks + while ( + __scoped_atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], + final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE, + __MEMORY_SCOPE_DEVICE) < flag); + } + __syncthreads(); + // use one thread to update flag + if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; +#else + __syncthreads(); + // eliminate the case that prior writes are not visible after signals become + // visible. Note that I did not managed to make this happen through a lot of + // testing. Might be the case that hardware provides stronger guarantee than + // the memory model. + if constexpr (!final_sync) __threadfence_system(); + if (threadIdx.x < ngpus) { + // reset flag for next time + self_sg->start[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while (!self_sg->end[blockIdx.x][threadIdx.x]); } - if constexpr (is_start || need_fence) __syncthreads(); + if constexpr (!final_sync) __syncthreads(); +#endif } template @@ -213,30 +247,42 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg, + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, +#ifndef USE_ROCM + volatile +#endif + Signal* self_sg, T* __restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_dp; - multi_gpu_barrier(sg, self_sg, rank); + start_sync(sg, self_sg, rank); // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } - multi_gpu_barrier(sg, self_sg, rank); + end_sync(sg, self_sg, rank); } template +#ifdef USE_ROCM DINLINE P* get_tmp_buf(Signal* sg) { +#else +DINLINE P* get_tmp_buf(volatile Signal* sg) { +#endif return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, +#ifndef USE_ROCM + volatile +#endif + Signal* self_sg, T* __restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; @@ -255,12 +301,12 @@ __global__ void __launch_bounds__(512, 1) tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; - multi_gpu_barrier(sg, self_sg, rank); + start_sync(sg, self_sg, rank); // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { tmp_out[idx - start] = packed_reduce(ptrs, idx); } - multi_gpu_barrier(sg, self_sg, rank); + end_sync(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between // the two stages, because visibility across devices is only guaranteed @@ -289,52 +335,46 @@ class CustomAllreduce { int world_size_; bool full_nvlink_; + // below are device pointers RankSignals sg_; - // Stores an map from a pointer to its peer pointters from all ranks. std::unordered_map buffers_; Signal* self_sg_; - // Stores rank data from all ranks. This is mainly for cuda graph purposes. - // For cuda graph to work, all kernel arguments must be fixed during graph - // capture time. However, the peer pointers are not known during graph capture - // time. Therefore, during capture, we increment the rank data pointer and use - // that as the argument to the kernel. The kernel arguments are stored in - // graph_unreg_buffers_. The actual peer pointers will be filled in at the - // memory pointed to by the pointers in graph_unreg_buffers_ when - // the IPC handles are exchanged between ranks. - // - // The overall process looks like this: - // 1. Graph capture. - // 2. Each rank obtains the IPC handles for each addresses used during cuda - // graph capture using get_graph_buffer_ipc_meta. - // 3. (In Python) all gather the IPC handles. - // 4. Obtain the peer pointers by opening the IPC handles, and store them in - // the rank data array at corresponding positions. + // stores the registered device pointers from all ranks RankData *d_rank_data_base_, *d_rank_data_end_; std::vector graph_unreg_buffers_; // a map from IPC handles to opened IPC pointers std::map ipc_handles_; /** - * Signals are an array of ipc-enabled buffers from all ranks. - * For each of the buffer, the layout is as follows: - * | -- sizeof(Signal) -- | ------ a few MB ----- | - * The first section is for allreduce synchronization, and the second section - * is for storing the intermediate results required by some allreduce algos. + * meta is a pointer to device metadata and temporary buffer for allreduce. * - * Note: this class does not own any device memory. Any required buffers - * are passed in from the constructor. + * There's a total of sizeof(Signal) of prefix before the actual data, + * so meta + 1 points to actual temporary buffer. + * + * note: this class does not own any device memory. Any required buffers + * are passed in from the constructor */ - CustomAllreduce(Signal** signals, void* rank_data, size_t rank_data_sz, - int rank, int world_size, bool full_nvlink = true) + CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t* handles, + const std::vector& offsets, int rank, + bool full_nvlink = true) : rank_(rank), - world_size_(world_size), + world_size_(offsets.size()), full_nvlink_(full_nvlink), - self_sg_(signals[rank]), + self_sg_(meta), d_rank_data_base_(reinterpret_cast(rank_data)), d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { for (int i = 0; i < world_size_; i++) { - sg_.signals[i] = signals[i]; + Signal* rank_sg; + if (i != rank_) { + char* handle = open_ipc_handle(&handles[i]); + handle += offsets[i]; + rank_sg = (Signal*)handle; + } else { + rank_sg = self_sg_; + } + sg_.signals[i] = rank_sg; } } @@ -351,10 +391,11 @@ class CustomAllreduce { return it->second; } - std::pair> get_graph_buffer_ipc_meta() { + std::pair, std::vector> + get_graph_buffer_ipc_meta() { auto num_buffers = graph_unreg_buffers_.size(); auto handle_sz = sizeof(cudaIpcMemHandle_t); - std::string handles(handle_sz * num_buffers, static_cast(0)); + std::vector handles(handle_sz * num_buffers, 0); std::vector offsets(num_buffers); for (int i = 0; i < num_buffers; i++) { auto ptr = graph_unreg_buffers_[i]; @@ -362,7 +403,11 @@ class CustomAllreduce { // note: must share the base address of each allocation, or we get wrong // address if (cuPointerGetAttribute(&base_ptr, +#ifdef USE_ROCM + HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR, +#else CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, +#endif (CUdeviceptr)ptr) != CUDA_SUCCESS) throw std::runtime_error("failed to get pointer attr"); CUDACHECK(cudaIpcGetMemHandle( @@ -379,22 +424,26 @@ class CustomAllreduce { std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); } - /** - * Register already-shared IPC pointers. - */ - void register_buffer(void** ptrs) { + void register_buffer(const std::vector& handles, + const std::vector& offsets, void* self) { check_rank_data_capacity(); RankData data; for (int i = 0; i < world_size_; i++) { - data.ptrs[i] = ptrs[i]; + if (i != rank_) { + char* handle = open_ipc_handle(handles[i].data()); + handle += offsets[i]; + data.ptrs[i] = handle; + } else { + data.ptrs[i] = self; + } } auto d_data = d_rank_data_base_++; CUDACHECK( cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); - buffers_[ptrs[rank_]] = d_data; + buffers_[self] = d_data; } - // Note: when registering graph buffers, we intentionally choose to not + // note: when registering graph buffers, we intentionally choose to not // deduplicate the addresses. That means if the allocator reuses some // addresses, they will be registered again. This is to account for the remote // possibility of different allocation patterns between ranks. For example, @@ -429,52 +478,52 @@ class CustomAllreduce { } /** - * Performs allreduce, assuming input has already been registered. - * - * Block and grid default configs are results after careful grid search. Using - * 36 blocks give the best or close to the best runtime on the devices I - * tried: A100, A10, A30, T4, V100. You'll notice that NCCL kernels also only - * take a small amount of SMs. Not quite sure the underlying reason, but my - * guess is that too many SMs will cause contention on NVLink bus. + * This is the result after careful grid search. Using 36 blocks give the best + * or close to the best runtime on the devices I tried: A100, A10, A30, T4, + * V100. You'll notice that NCCL kernels also only take a small amount of SMs. + * Not quite sure the underlying reason, but my guess is that too many SMs + * will cause contention on NVLink bus. */ template void allreduce(cudaStream_t stream, T* input, T* output, int size, - int threads = 512, int block_limit = 36) { - auto d = packed_t::P::size; - if (size % d != 0) +#ifndef USE_ROCM + int threads = 512, int block_limit = 36){ +#else + int threads = 512, int block_limit = 16) { +#endif + auto d = packed_t::P::size; + if (size % d != 0) + throw std::runtime_error( + "custom allreduce currently requires input length to be multiple " + "of " + + std::to_string(d)); + if (block_limit > kMaxBlocks) + throw std::runtime_error("max supported block limit is " + + std::to_string(kMaxBlocks) + ". Got " + + std::to_string(block_limit)); + + RankData* ptrs; + cudaStreamCaptureStatus status; + CUDACHECK(cudaStreamIsCapturing(stream, &status)); + if (status == cudaStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) throw std::runtime_error( - "custom allreduce currently requires input length to be multiple " - "of " + - std::to_string(d)); - if (block_limit > kMaxBlocks) - throw std::runtime_error("max supported block limit is " + - std::to_string(kMaxBlocks) + ". Got " + - std::to_string(block_limit)); - - RankData* ptrs; - cudaStreamCaptureStatus status; - CUDACHECK(cudaStreamIsCapturing(stream, &status)); - if (status == cudaStreamCaptureStatusActive) { - ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); - graph_unreg_buffers_.push_back(input); - } else { - auto it = buffers_.find(input); - if (it == buffers_.end()) - throw std::runtime_error( - "buffer address " + - std::to_string(reinterpret_cast(input)) + - " is not registered!"); - ptrs = it->second; - } + "buffer address " + + std::to_string(reinterpret_cast(input)) + + " is not registered!"); + ptrs = it->second; + } - size /= d; - auto bytes = size * sizeof(typename packed_t::P); - int blocks = std::min(block_limit, (size + threads - 1) / threads); + size /= d; + auto bytes = size * sizeof(typename packed_t::P); + int blocks = std::min(block_limit, (size + threads - 1) / threads); #define KL(ngpus, name) \ name<<>>(ptrs, sg_, self_sg_, output, \ rank_, size); - // TODO(hanzhi713): Threshold is different for A100 and H100. - // Add per device threshold. #define REDUCE_CASE(ngpus) \ case ngpus: { \ if (world_size_ == 2) { \ @@ -490,27 +539,27 @@ class CustomAllreduce { break; \ } - switch (world_size_) { - REDUCE_CASE(2) - REDUCE_CASE(4) - REDUCE_CASE(6) - REDUCE_CASE(8) - default: - throw std::runtime_error( - "custom allreduce only supports num gpus in (2,4,6,8). Actual num " - "gpus = " + - std::to_string(world_size_)); - } + switch (world_size_) { + REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). Actual num " + "gpus = " + + std::to_string(world_size_)); + } #undef REDUCE_CASE #undef KL - } +} - ~CustomAllreduce() { - for (auto [_, ptr] : ipc_handles_) { - CUDACHECK(cudaIpcCloseMemHandle(ptr)); - } +~CustomAllreduce() { + for (auto [_, ptr] : ipc_handles_) { + CUDACHECK(cudaIpcCloseMemHandle(ptr)); } -}; +} +}; // namespace vllm /** * To inspect PTX/SASS, copy paste this header file to compiler explorer and add a template instantiation: diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index b59ea40d980f..4435c433e656 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -1,15 +1,15 @@ /** * This is a standalone test for custom allreduce. * To compile, make sure you have MPI and NCCL installed in your system. - * export MPI_HOME=xxx + * export MPI_HOME=XXX * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o - * custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi + * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi * * Warning: this C++ test is not designed to be very readable and was used * during the rapid prototyping process. * * To run: - * mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test + * mpirun -np 8 ./custom_all_reduce_test */ #include #include @@ -20,9 +20,16 @@ #include #include "cuda_profiler_api.h" -#include "custom_all_reduce.cuh" #include "mpi.h" -#include "nccl.h" +#ifdef USE_ROCM + #include +typedef __hip_bfloat16 nv_bfloat16; + #include "rccl/rccl.h" + #include "custom_all_reduce_hip.cuh" +#else + #include "nccl.h" + #include "custom_all_reduce.cuh" +#endif #define MPICHECK(cmd) \ do { \ @@ -44,13 +51,16 @@ } while (0) __global__ void dummy_kernel() { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 - for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms -#else +#ifdef USE_ROCM for (int i = 0; i < 100; i++) { - long long int start = clock64(); - while (clock64() - start < 150000000); // approximately 98.4ms on P40 + uint64_t start = wall_clock64(); + uint64_t cycles_elapsed; + do { + cycles_elapsed = wall_clock64() - start; + } while (cycles_elapsed < 100); } +#else + for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms #endif } @@ -121,8 +131,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, * registration, they are allocated and registered together in the test for * convenience. */ +#ifdef USE_ROCM + CUDACHECK(hipExtMallocWithFlags( + (void**)&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal), + hipDeviceMallocUncached)); +#else CUDACHECK( cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal))); +#endif CUDACHECK( cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal))); CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T))); @@ -135,26 +151,24 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit, void* rank_data; size_t rank_data_sz = 16 * 1024 * 1024; CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); - vllm::Signal* ipc_ptrs[8]; - for (int i = 0; i < nRanks; i++) { - if (i == myRank) - ipc_ptrs[i] = buffer; - else - CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptrs[i], data_handles[i], - cudaIpcMemLazyEnablePeerAccess)); - } - vllm::CustomAllreduce fa(ipc_ptrs, rank_data, rank_data_sz, myRank, nRanks); + std::vector offsets(nRanks, 0); + vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, + offsets, myRank); auto* self_data = reinterpret_cast(reinterpret_cast(buffer) + sizeof(vllm::Signal) + data_size * sizeof(T)); // hack buffer registration { - void* data[8]; + std::vector handles; + handles.reserve(nRanks); for (int i = 0; i < nRanks; i++) { - data[i] = - ((char*)ipc_ptrs[i]) + sizeof(vllm::Signal) + data_size * sizeof(T); + char* begin = (char*)&data_handles[i]; + char* end = (char*)&data_handles[i + 1]; + handles.emplace_back(begin, end); } - fa.register_buffer(data); + std::vector offsets(nRanks, + sizeof(vllm::Signal) + data_size * sizeof(T)); + fa.register_buffer(handles, offsets, self_data); } double* ground_truth; @@ -311,17 +325,20 @@ int main(int argc, char** argv) { bool performance_test = true; cudaProfilerStart(); - // Uncomment to scan through different block size configs. - // for (int threads : {256, 512, 1024}) { + // for (int threads : {256, 512}) { // for (int block_limit = 16; block_limit < 112; block_limit += 4) { - // run(myRank, nRanks, comm, threads, block_limit, 1024 * 1024, - // performance_test); + // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); // } // } - // Scan through different sizes to test performance. +#ifdef USE_ROCM + for (int sz = 512; sz <= (8 << 20); sz *= 2) { + run(myRank, nRanks, comm, 512, 16, sz + 8 * 47, performance_test); + } +#else for (int sz = 512; sz <= (8 << 20); sz *= 2) { run(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test); } +#endif cudaProfilerStop(); MPICHECK(MPI_Finalize()); diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 03414b7e1ae9..3b477ba08199 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -15,7 +15,7 @@ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) // TODO(luka/varun): use FP8_TYPE macro after refactoring -#ifndef USE_ROCM +#ifdef USE_CUDA_FP8_FORMAT #define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \ AT_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) diff --git a/csrc/gradlib/hipbsolgemm.cu b/csrc/gradlib/hipbsolgemm.cu new file mode 100644 index 000000000000..81512d473a4d --- /dev/null +++ b/csrc/gradlib/hipbsolgemm.cu @@ -0,0 +1,499 @@ +// #ifdef __gfx908__ +// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below +// just for gfx908 and not for others +// // below lines enable hip float to half conversion which are disabled by +// default in hip_fp16.h #undef __HIP_NO_HALF_OPERATORS__ #undef +// __HIP_NO_HALF_CONVERSIONS__ #endif + +#include +#include +#include +#include +#include +#include +#include +// #include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include "nvToolsExt.h" + +// #include + +// #ifdef USE_ROCM +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #endif + +// #ifdef __HIP_PLATFORM_HCC__ +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #if USE_GEMM_FLAGS_FP16_ALT_IMPL +// #ifdef ROCM_BACKWARD_PASS_GUARD +// flag = at::BackwardPassGuard::is_backward_pass() ? +// rocblas_gemm_flags_fp16_alt_impl : 0; #endif #endif #endif + +#ifndef CHECK_HIP_ERROR + #define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) { \ + fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +#ifndef CHECK_HIPBLAS_ERROR + #define CHECK_HIPBLAS_ERROR(error) \ + if (error != HIPBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "hipBLAS error: '%s'(%d) at %s:%d\n", \ + hipblasStatusToString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +namespace { +/*thread_local*/ cudaStream_t weight_stream; +// BUG: DLM has event and stream on different devices error +// In multi-GPU scenerio, do names defined in this namespace exist on all +// devices? C++ keyword: thread_local <- maybe this can help? +/*thread_local*/ cudaEvent_t event; + +// hipBLASLt +hipblasLtHandle_t hipblaslt_handle; +hipblasLtMatmulPreference_t preference; +size_t workspace_size = 2 * 128 * 1024 * 1024; +// uint64_t workspace_size = 0; +void* d_workspace; +int request_solutions = 1; +int returnedAlgoCount = 0; + +struct MatMulConfig { + hipblasOperation_t op_A; + hipblasOperation_t op_B; + int M; + int N; + int K; + hipDataType dtype; + + friend auto operator<(const MatMulConfig& left, const MatMulConfig& right) + -> bool { + return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < + std::tie(right.op_A, right.op_B, right.M, right.N, right.K, + right.dtype); + } +}; + +// std::map, +// std::vector> heuristic_map; +std::map heuristic_map; + +hipEvent_t start, stop; +int bench_iters{1}; +int warmup_iters{1}; + +bool cout_print = false; + +torch::Tensor dTensor; + +std::map dtype_map{ + {at::kHalf, HIP_R_16F}, + {at::kBFloat16, HIP_R_16BF}, + {at::kFloat, HIP_R_32F}, + {at::kFloat8_e4m3fnuz, HIP_R_8F_E4M3_FNUZ}}; + +// std::vector heuristicResult; +} // namespace + +// find all hipblaslt solutions for given gemm problem +std::vector hipblasLtMatmul_findallsols_wrapper( + hipblasLtHandle_t handle, hipblasOperation_t op_A, hipblasOperation_t op_B, + int m, int n, int k, const void* alpha, const void* a, int lda, + const void* b, int ldb, const void* beta, void* c, int ldc, + const void* bias, hipDataType intype, hipDataType outtype, + hipStream_t& stream) { + int flag{0}; + hipblasLtMatrixLayout_t matA, matB, matC; + hipblasLtMatmulDesc_t matmul; + if (op_A == HIPBLAS_OP_N) { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); + } else { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); + } + if (op_B == HIPBLAS_OP_N) { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); + } else { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); + } + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); + CHECK_HIPBLAS_ERROR( + hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); + + if (bias) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(void*))); + auto epilogue = HIPBLASLT_EPILOGUE_BIAS; + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(int32_t))); + } + + // std::vector heuristicResult(10); + // CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( + // handle, matmul, matA, matB, matC, matC, + // preference, 10, heuristicResult.data(), &returnedAlgoCount)); + std::vector heuristicResult; + CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos( + handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, op_A, op_B, intype, + intype, outtype, outtype, HIPBLAS_COMPUTE_32F, heuristicResult)); + + std::vector algoIndex; + int returned_algo_count = heuristicResult.size(); + // for (int i = 0; i < returnedAlgoCount; i++) { + for (int i = 0; i < returned_algo_count; i++) { + auto algo = heuristicResult[i].algo; + size_t ret_workspace_size = 0; + auto status = hipblaslt_ext::matmulIsAlgoSupported( + handle, matmul, alpha, matA, matB, beta, matC, matC, algo, + ret_workspace_size); + if (status == HIPBLAS_STATUS_SUCCESS) { + if (ret_workspace_size < workspace_size) { + algoIndex.push_back(hipblaslt_ext::getIndexFromAlgo(algo)); + } + } + } + + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); + return algoIndex; +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// +/** + * hipBLASLt GEMM call + */ +hipblasStatus_t hipblasLtMatmul_sol_wrapper( + hipblasLtHandle_t handle, hipblasOperation_t op_A, hipblasOperation_t op_B, + int m, int n, int k, const void* alpha, const void* a, int lda, + const void* scaleA, const void* b, int ldb, const void* scaleB, + const void* beta, void* c, int ldc, const void* scaleC, const void* bias, + hipDataType intype, hipDataType outtype, hipStream_t& stream, + int solution_index = -1) { + // TODO: flag is not supported for hipblasLt yet + int flag{0}; + // if (dtype == HIPBLAS_R_16F) { + // use fp16 alt impl for MI200 + // https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + // flag = rocblas_gemm_flags_fp16_alt_impl; + //} + + // nvtxRangePushA("hipBLASLt variables creation"); + hipblasLtMatrixLayout_t matA, matB, matC; + hipblasLtMatmulDesc_t matmul; + if (op_A == HIPBLAS_OP_N) { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); + } else { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); + } + if (op_B == HIPBLAS_OP_N) { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); + } else { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); + } + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); + CHECK_HIPBLAS_ERROR( + hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); + if (scaleA != nullptr) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA, + sizeof(scaleA))); + } + if (scaleB != nullptr) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, + sizeof(scaleB))); + } + if (scaleC != nullptr) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, + sizeof(scaleC))); + } + if (bias) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(void*))); + auto epilogue = HIPBLASLT_EPILOGUE_BIAS; + static_assert(sizeof(epilogue) == sizeof(int32_t)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(int32_t))); + } + // nvtxRangePop(); + // if heuristic does not exist in the map, do search and push into the map + // auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; + // if (heuristic_map.count(gemm_key) <= 0) { + std::vector heuristicResult(1); + if (solution_index < 0) { + // nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); + std::cout + << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" + << std::endl; + if (cout_print) { + std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") + << (op_B == HIPBLAS_OP_N ? "N" : "T") << " (" << m << ", " << n + << ", " << k << "), dtype: " << intype << ", (lda, ldb, ldc): (" + << lda << ", " << ldb << ", " << ldc << "), " << std::endl; + } + CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( + handle, matmul, matA, matB, matC, matC, preference, request_solutions, + heuristicResult.data(), &returnedAlgoCount)); + if ((returnedAlgoCount != request_solutions) && cout_print) { + std::cout << "less solution found! request: " << request_solutions + << ", found: " << returnedAlgoCount << std::endl; + } + } else { + std::vector algoIndex(1); + algoIndex[0] = solution_index; + CHECK_HIPBLAS_ERROR( + hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResult)); + } + + hipblasStatus_t status = hipblasLtMatmul( + handle, matmul, alpha, a, matA, b, matB, beta, c, matC, c, matC, + &heuristicResult[0].algo, d_workspace, workspace_size, stream); + + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); + + return status; +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// +torch::Tensor hipb_mm(const torch::Tensor& mat1, const torch::Tensor& mat2, + const int64_t solution_index, + at::optional bias, + at::optional out_dtype, + at::optional scale1, + at::optional scale2, + at::optional scaleOut) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; + + TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); + + auto inDtype{mat1.options().dtype().toScalarType()}; + auto outDtype{out_dtype.has_value() ? out_dtype.value() : inDtype}; + auto options{at::TensorOptions().dtype(outDtype).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; + + bool transpose_result = true; + bool transpose_mat1; + bool transpose_mat2; + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + transpose_mat2 = false; + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + transpose_mat2 = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + transpose_mat1 = false; + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + transpose_mat1 = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + + if (transpose_result) { + bool tmp = transpose_mat1; + transpose_mat1 = !transpose_mat2; + transpose_mat2 = !tmp; + mat1_strides = mat2.strides(); + mat2_strides = mat1.strides(); + mat1_sizes = mat2.sizes(); + mat2_sizes = mat1.sizes(); + } + + float one{1.0f}; + float zero{0.0f}; + int64_t m = mat1_sizes[transpose_result ? 1 : 0]; + int64_t k = mat1_sizes[transpose_result ? 0 : 1]; + int64_t n = mat2_sizes[transpose_result ? 0 : 1]; + int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; + int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; + int64_t result_ld = result.stride(transpose_result ? 0 : 1); + + void *d_scale1 = nullptr, *d_scale2 = nullptr, *d_scaleOut = nullptr; + if (scale1.has_value()) { + d_scale1 = static_cast(scale1.value().data_ptr()); + } + if (scale2.has_value()) { + d_scale2 = static_cast(scale2.value().data_ptr()); + } + if (scaleOut.has_value()) { + d_scaleOut = static_cast(scaleOut.value().data_ptr()); + } + + auto hipblasInType = dtype_map.at(inDtype); + auto hipblasOutType = dtype_map.at(outDtype); + + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + if (transpose_result) std::swap(d_scale1, d_scale2); + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; + void* bias_ptr = + bias.has_value() ? static_cast(bias.value().data_ptr()) : nullptr; + + CHECK_HIPBLAS_ERROR(hipblasLtMatmul_sol_wrapper( + hipblaslt_handle, transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, &one, ptrA, + mat1_ld, d_scale1, ptrB, mat2_ld, d_scale2, &zero, ptrC, result_ld, + d_scaleOut, bias_ptr, hipblasInType, hipblasOutType, current_stream, + solution_index)); + + return result; +} + +// find all hipblas solutions and return them to python land +std::vector hipb_findallsols( + const torch::Tensor& mat1, const torch::Tensor& mat2, + at::optional bias = at::nullopt, + at::optional out_dtype = at::nullopt) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; + TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); + + auto inType{mat1.options().dtype().toScalarType()}; + auto outType{out_dtype.has_value() ? out_dtype.value() : inType}; + + auto options{at::TensorOptions().dtype(outType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; + bool transpose_result = true; + bool transpose_mat1; + bool transpose_mat2; + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + transpose_mat2 = false; + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + transpose_mat2 = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + transpose_mat1 = false; + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + transpose_mat1 = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + if (transpose_result) { + bool tmp = transpose_mat1; + transpose_mat1 = !transpose_mat2; + transpose_mat2 = !tmp; + mat1_strides = mat2.strides(); + mat2_strides = mat1.strides(); + mat1_sizes = mat2.sizes(); + mat2_sizes = mat1.sizes(); + } + float one{1.0f}; + float zero{0.0f}; + int64_t m = mat1_sizes[transpose_result ? 1 : 0]; + int64_t k = mat1_sizes[transpose_result ? 0 : 1]; + int64_t n = mat2_sizes[transpose_result ? 0 : 1]; + int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; + int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; + int64_t result_ld = result.stride(transpose_result ? 0 : 1); + hipDataType hipblasInType = dtype_map.at(inType); + hipDataType hipblasOutType = dtype_map.at(outType); + + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; + + auto bias_ptr = + bias.has_value() ? static_cast(bias.value().data_ptr()) : nullptr; + + return hipblasLtMatmul_findallsols_wrapper( + hipblaslt_handle, transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, &one, ptrA, + mat1_ld, ptrB, mat2_ld, &zero, ptrC, result_ld, bias_ptr, hipblasInType, + hipblasOutType, current_stream); +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +void hipb_create_extension() { + // CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); + // CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); + + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtCreate(&hipblaslt_handle)); + CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceSetAttribute( + preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, + sizeof(workspace_size))); + + // CHECK_HIP_ERROR(hipEventCreate(&start)); + // CHECK_HIP_ERROR(hipEventCreate(&stop)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +void hipb_destroy_extension() { + // CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); + // CHECK_HIP_ERROR(hipEventDestroy(event)); + + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); + CHECK_HIP_ERROR(hipFree(d_workspace)); + + // CHECK_HIP_ERROR(hipEventDestroy(start)); + // CHECK_HIP_ERROR(hipEventDestroy(stop)); +} diff --git a/csrc/gradlib/ops.h b/csrc/gradlib/ops.h new file mode 100644 index 000000000000..43107ec3c098 --- /dev/null +++ b/csrc/gradlib/ops.h @@ -0,0 +1,27 @@ +#pragma once + +#include + +void hipb_create_extension(); +void hipb_destroy_extension(); +torch::Tensor hipb_mm(const torch::Tensor& mat1, const torch::Tensor& mat2, + const int64_t solution_index, + at::optional bias = at::nullopt, + at::optional out_dtype = at::nullopt, + at::optional scale1 = at::nullopt, + at::optional scale2 = at::nullopt, + at::optional scaleOut = at::nullopt); + +std::vector hipb_findallsols(const torch::Tensor& mat1, + const torch::Tensor& mat2, + at::optional bias, + at::optional out_dtype); + +void rocb_create_extension(); +void rocb_destroy_extension(); +torch::Tensor RocSolIdxBlas(const torch::Tensor& mat1, + const torch::Tensor& mat2, + const int64_t solution_index); + +std::vector RocFindAllSolIdxBlas(const torch::Tensor& mat1, + const torch::Tensor& mat2); \ No newline at end of file diff --git a/csrc/gradlib/rocsolgemm.cu b/csrc/gradlib/rocsolgemm.cu new file mode 100644 index 000000000000..9d5347e0a7dc --- /dev/null +++ b/csrc/gradlib/rocsolgemm.cu @@ -0,0 +1,567 @@ +// #ifdef __gfx908__ +// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below +// just for gfx908 and not for others +// // below lines enable hip float to half conversion which are disabled by +// default in hip_fp16.h #undef __HIP_NO_HALF_OPERATORS__ #undef +// __HIP_NO_HALF_CONVERSIONS__ #endif + +#define ROCBLAS_NO_DEPRECATED_WARNINGS +#define ROCBLAS_BETA_FEATURES_API + +#include +#include +#include +#include +#include +#include +#include +// #include +#include +#include +#include +#include + +#include +// #include +#include + +#include +#include +#include +#include +#include +#include +#include "nvToolsExt.h" + +#include + +// #ifdef USE_ROCM +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #endif + +// #ifdef __HIP_PLATFORM_HCC__ +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #if USE_GEMM_FLAGS_FP16_ALT_IMPL +// #ifdef ROCM_BACKWARD_PASS_GUARD +// flag = at::BackwardPassGuard::is_backward_pass() ? +// rocblas_gemm_flags_fp16_alt_impl : 0; #endif #endif #endif + +#ifndef CHECK_HIP_ERROR + #define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) { \ + fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +#ifndef CHECK_HIPBLAS_ERROR + #define CHECK_HIPBLAS_ERROR(error) \ + if (error != HIPBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "hipBLAS error: '%s'(%d) at %s:%d\n", \ + hipblasStatusToString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ + } +#endif + +namespace { +rocblas_handle r_handle; + +/*thread_local*/ cudaStream_t weight_stream; +// BUG: DLM has event and stream on different devices error +// In multi-GPU scenerio, do names defined in this namespace exist on all +// devices? C++ keyword: thread_local <- maybe this can help? +/*thread_local*/ cudaEvent_t event; + +// hipBLASLt +hipblasLtHandle_t hipblaslt_handle; +hipblasLtMatmulPreference_t preference; +uint64_t workspace_size = 32 * 1024 * 1024; +// uint64_t workspace_size = 0; +void* d_workspace; +int request_solutions = 1; +int returnedAlgoCount = 0; + +struct MatMulConfig { + hipblasOperation_t op_A; + hipblasOperation_t op_B; + int M; + int N; + int K; + hipblasDatatype_t dtype; + + friend auto operator<(const MatMulConfig& left, const MatMulConfig& right) + -> bool { + return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < + std::tie(right.op_A, right.op_B, right.M, right.N, right.K, + right.dtype); + } +}; + +// std::map, +// std::vector> heuristic_map; +std::map heuristic_map; + +hipEvent_t start, stop; +int bench_iters{1}; +int warmup_iters{1}; + +bool cout_print = true; +} // namespace + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +/** + * hipBLASLt GEMM call + */ +/* +hipblasStatus_t hipblasLtMatmul_wrapper( + hipblasLtHandle_t handle, + hipblasOperation_t op_A, + hipblasOperation_t op_B, + int m, int n, int k, + const void *alpha, + const void *a, + int lda, + const void *b, + int ldb, + const void *beta, + void *c, + int ldc, + hipblasDatatype_t dtype, + hipStream_t &stream) +{ + // TODO: flag is not supported for hipblasLt yet + int flag { 0 }; + if (dtype == HIPBLAS_R_16F) { + // use fp16 alt impl for MI200 + // +https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + flag = rocblas_gemm_flags_fp16_alt_impl; + } + + nvtxRangePushA("hipBLASLt variables creation"); + hipblasLtMatrixLayout_t matA, matB, matC; + hipblasLtMatmulDesc_t matmul; + if (op_A == HIPBLAS_OP_N) { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda)); + } else { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda)); + } + if (op_B == HIPBLAS_OP_N) { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb)); + } else { + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb)); + } + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, +HIPBLAS_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, +HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); + nvtxRangePop(); + + // if heuristic does not exist in the map, do search and push into the map + auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; + if (heuristic_map.count(gemm_key) <= 0) { + nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); + if (cout_print) { + std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? +"N" : "T") + << " (" << m << ", " << n << ", " << k << "), dtype: " << dtype + << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc +<< "), " << std::endl; + } + std::vector +heuristicResult(request_solutions); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( + handle, matmul, matA, matB, matC, matC, + preference, request_solutions, heuristicResult.data(), +&returnedAlgoCount)); if((returnedAlgoCount != request_solutions) && cout_print) +{ std::cout << "less solution found! request: " << request_solutions + << ", found: " << returnedAlgoCount << std::endl; + } + + if (returnedAlgoCount == 1) { + heuristic_map[gemm_key] = heuristicResult[0]; + } else { + // benchmark requested solutions and pick best one + int bestIndex { -1 }; + double bestMs { std::numeric_limits::max() }; + for (int sol { 0 }; sol < returnedAlgoCount; ++sol) { + // warm up + for (int iter { 0 }; iter < warmup_iters; ++iter) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, // In case beta != 0, these runs can overwrite the values +in c + // since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, workspace_size, + stream)); + } + // performance measuring + double eventMs; + CHECK_HIP_ERROR(hipEventRecord(start, stream)); + for (int iter { 0 }; iter < bench_iters; ++iter) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, // In case beta != 0, these runs can overwrite the values +in c + // since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, workspace_size, + stream)); + } + CHECK_HIP_ERROR(hipEventRecord(stop, stream)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + float temp; + CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); + eventMs = double(temp); + eventMs /= bench_iters; + + if (cout_print) { + std::cout << " Sol " << sol << ": average time per iter " << +std::to_string(eventMs) << " ms"; + } + if (bestMs > eventMs) { + bestMs = eventMs; + bestIndex = sol; + if (cout_print) { + std::cout << " *" << std::endl; + } + } else { + if (cout_print) { + std::cout << std::endl; + } + } + } + heuristic_map[gemm_key] = heuristicResult[bestIndex]; + } + nvtxRangePop(); + } + + hipblasStatus_t status = hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, + &heuristic_map[gemm_key].algo, + d_workspace, workspace_size, + stream); + + nvtxRangePushA("hipBLASLt variables deletion"); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); + nvtxRangePop(); + + return status; +} +*/ +///////////////////////////////////////////////////////////////////////////////////////////////////////// +std::vector RocFindAllSolIdxBlas(const torch::Tensor& mat1, + const torch::Tensor& mat2) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; + + TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); + + auto abcType{mat1.options().dtype()}; + auto options{at::TensorOptions().dtype(abcType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; + + bool transpose_result = true; + bool transpose_mat1; + bool transpose_mat2; + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + transpose_mat2 = false; + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + transpose_mat2 = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + transpose_mat1 = false; + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + transpose_mat1 = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + if (transpose_result) { + bool tmp = transpose_mat1; + transpose_mat1 = !transpose_mat2; + transpose_mat2 = !tmp; + mat1_strides = mat2.strides(); + mat2_strides = mat1.strides(); + mat1_sizes = mat2.sizes(); + mat2_sizes = mat1.sizes(); + } + float one{1.0f}; + float zero{0.0f}; + int64_t m = mat1_sizes[transpose_result ? 1 : 0]; + int64_t k = mat1_sizes[transpose_result ? 0 : 1]; + int64_t n = mat2_sizes[transpose_result ? 0 : 1]; + int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; + int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; + int64_t result_ld = result.stride(transpose_result ? 0 : 1); + + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; + + rocblas_set_stream(r_handle, current_stream); + uint32_t flags{0}; + rocblas_datatype abcRtype; + if (abcType == at::kHalf) { + abcRtype = rocblas_datatype_f16_r; + } else if (abcType == at::kBFloat16) { + abcRtype = rocblas_datatype_bf16_r; + } else if (abcType == at::kFloat) { + abcRtype = rocblas_datatype_f32_r; + } else { + assert(false && "Wrong datatype!"); + } + +#define GEMM_EX_ARGS \ + r_handle, \ + transpose_mat1 ? rocblas_operation_transpose : rocblas_operation_none, \ + transpose_mat2 ? rocblas_operation_transpose : rocblas_operation_none, \ + m, n, k, &one, ptrA, abcRtype, mat1_ld, ptrB, abcRtype, mat2_ld, &zero, \ + ptrC, abcRtype, result_ld, ptrC, abcRtype, result_ld, \ + rocblas_datatype_f32_r, rocblas_gemm_algo_solution_index + + rocblas_int sizeSolve; + // CHECK_ROCBLAS_ERROR( + rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, NULL, + &sizeSolve); + + // Fill array with list of solutions that match type + // Note: some of these may be invalid + std::vector solutionsSolve(sizeSolve); + // CHECK_ROCBLAS_ERROR( + rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, + solutionsSolve.data(), &sizeSolve); + + std::vector validSolutions; + for (auto sol : solutionsSolve) { + auto status = rocblas_gemm_ex( + r_handle, + transpose_mat1 ? rocblas_operation_transpose : rocblas_operation_none, + transpose_mat2 ? rocblas_operation_transpose : rocblas_operation_none, + m, n, k, &one, ptrA, abcRtype, mat1_ld, ptrB, abcRtype, mat2_ld, &zero, + ptrC, abcRtype, result_ld, ptrC, abcRtype, result_ld, + rocblas_datatype_f32_r, rocblas_gemm_algo_solution_index, sol, + rocblas_gemm_flags_none); + if (status == rocblas_status_success) { + validSolutions.push_back(sol); + } + } + + return validSolutions; +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// +torch::Tensor RocSolIdxBlas(const torch::Tensor& mat1, + const torch::Tensor& mat2, + const int64_t solution_index = 0) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; + // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | mat2 info: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; + + TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); + + auto abcType{mat1.options().dtype()}; + auto options{at::TensorOptions().dtype(abcType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; + // std::cout << " | result info: size: " << result.sizes() << " stride: " << + // result.strides() << std::endl; + + bool transpose_result = true; + bool transpose_mat1; + bool transpose_mat2; + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + transpose_mat2 = false; + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + transpose_mat2 = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + transpose_mat1 = false; + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + transpose_mat1 = true; + } else { + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); + } + + if (transpose_result) { + bool tmp = transpose_mat1; + transpose_mat1 = !transpose_mat2; + transpose_mat2 = !tmp; + mat1_strides = mat2.strides(); + mat2_strides = mat1.strides(); + mat1_sizes = mat2.sizes(); + mat2_sizes = mat1.sizes(); + } + // std::cout << " | transpose_result: " << (transpose_result ? "true" : + // "false") << std::endl + // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << + // std::endl + // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << + // std::endl; + // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | B matrix: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; + + float one{1.0f}; + float zero{0.0f}; + int64_t m = mat1_sizes[transpose_result ? 1 : 0]; + int64_t k = mat1_sizes[transpose_result ? 0 : 1]; + int64_t n = mat2_sizes[transpose_result ? 0 : 1]; + int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; + int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; + int64_t result_ld = result.stride(transpose_result ? 0 : 1); + // std::cout << " | (m, n, k): " << m << ", " << n << ", " << k << std::endl + // << " | (lda, ldb, ldc): " << mat1_ld << ", " << mat2_ld << ", " + // << result_ld << std::endl; + + /* + int flag { 0 }; + hipblasDatatype_t hipblasType; + if (abcType == at::kHalf) { + hipblasType = HIPBLAS_R_16F; + } else if (abcType == at::kBFloat16) { + hipblasType = HIPBLAS_R_16B; + } else if (abcType == at::kFloat) { + hipblasType = HIPBLAS_R_32F; + } else { + assert(false && "Wrong datatype!"); + } + */ + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; + /* + + CHECK_HIPBLAS_ERROR(hipblasLtMatmul_wrapper( + hipblaslt_handle, + transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + m, n, k, + &one, + ptrA, mat1_ld, + ptrB, mat2_ld, + &zero, + ptrC, result_ld, + hipblasType, + current_stream)); + */ + rocblas_set_stream(r_handle, current_stream); + uint32_t flags{0}; + // int32_t solution_index {0}; + rocblas_datatype abcRtype; + if (abcType == at::kHalf) { + abcRtype = rocblas_datatype_f16_r; + } else if (abcType == at::kBFloat16) { + abcRtype = rocblas_datatype_bf16_r; + } else if (abcType == at::kFloat) { + abcRtype = rocblas_datatype_f32_r; + } else { + assert(false && "Wrong datatype!"); + } + + // CHECK_ROCBLAS_ERROR( + rocblas_gemm_ex( + r_handle, + transpose_mat1 ? rocblas_operation_transpose : rocblas_operation_none, + transpose_mat2 ? rocblas_operation_transpose : rocblas_operation_none, m, + n, k, &one, ptrA, abcRtype, mat1_ld, ptrB, abcRtype, mat2_ld, &zero, ptrC, + abcRtype, result_ld, ptrC, abcRtype, result_ld, rocblas_datatype_f32_r, + rocblas_gemm_algo_solution_index, solution_index, flags); + //); + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +void rocb_create_extension() { + /* + CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); + CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); + + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtCreate(&hipblaslt_handle)); + CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceSetAttribute( + preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, + sizeof(workspace_size))); + + CHECK_HIP_ERROR(hipEventCreate(&start)); + CHECK_HIP_ERROR(hipEventCreate(&stop)); */ + rocblas_create_handle(&r_handle); +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +void rocb_destroy_extension() { + /* + CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); + CHECK_HIP_ERROR(hipEventDestroy(event)); + + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); + CHECK_HIP_ERROR(hipFree(d_workspace)); + + CHECK_HIP_ERROR(hipEventDestroy(start)); + CHECK_HIP_ERROR(hipEventDestroy(stop)); */ + rocblas_destroy_handle(r_handle); +} diff --git a/csrc/gradlib/torch_bindings.cpp b/csrc/gradlib/torch_bindings.cpp new file mode 100644 index 000000000000..1818df584c5d --- /dev/null +++ b/csrc/gradlib/torch_bindings.cpp @@ -0,0 +1,18 @@ +#include "core/registration.h" +#include "gradlib/ops.h" + +TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, gradlib_ops) { + // Gradlib custom ops + + gradlib_ops.def("hipb_create_extension", &hipb_create_extension); + gradlib_ops.def("hipb_destroy_extension", &hipb_destroy_extension); + gradlib_ops.def("hipb_mm", &hipb_mm); + gradlib_ops.def("hipb_findallsols", &hipb_findallsols); + + gradlib_ops.def("rocb_create_extension", &rocb_create_extension); + gradlib_ops.def("rocb_destroy_extension", &rocb_destroy_extension); + gradlib_ops.def("rocb_mm", &RocSolIdxBlas); + gradlib_ops.def("rocb_findallsols", &RocFindAllSolIdxBlas); +} + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index fb6882f3e7c3..e14ad972a0a4 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -10,19 +10,76 @@ #include #endif +#ifdef USE_ROCM + #include "quantization/fp8/amd/quant_utils.cuh" +#else + #include "quantization/fp8/nvidia/quant_utils.cuh" +#endif + namespace vllm { -// TODO(woosuk): Further optimize this kernel. -template -__global__ void rms_norm_kernel( - scalar_t* __restrict__ out, // [..., hidden_size] - const scalar_t* __restrict__ input, // [..., hidden_size] - const scalar_t* __restrict__ weight, // [hidden_size] - const float epsilon, const int num_tokens, const int hidden_size) { +// This kernel uses the _f16Vec to represent vectorized data. +// A conversion to/from float should exist +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, + const size_t hidden_size, const size_t vec_hidden_size) { + __shared__ float s_variance; + float v8_variance_sum = 0.0f; + + const int64_t tx = threadIdx.x; + const int64_t bx = blockIdx.x; + const int64_t num_threads = blockDim.x; + + auto* __restrict__ out_v = reinterpret_cast<_f16Vec*>(out); + auto* __restrict__ input_v = + reinterpret_cast*>( + input + bx * static_cast(hidden_size)); + auto* __restrict__ weight_v = + reinterpret_cast*>(weight); + + // Compute variance. Be careful, hidden_size should multiple of 4. + for (size_t idx = tx; idx < vec_hidden_size; idx += num_threads) { + _f16Vec temp = input_v[idx]; + v8_variance_sum += temp.sum_squares(); + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + + float variance = + BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, num_threads); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + variance = s_variance; + + for (size_t idx = tx; idx < vec_hidden_size; idx += num_threads) { + _f16Vec temp = input_v[idx]; + temp *= variance; + temp *= weight_v[idx]; + out_v[bx * static_cast(vec_hidden_size) + idx] = temp; + } +} + +// Non vectorized kernel for unusual shapes/types without conversion +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, + const size_t hidden_size, const size_t) { __shared__ float s_variance; float variance = 0.0f; - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + for (size_t idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } @@ -36,7 +93,7 @@ __global__ void rms_norm_kernel( } __syncthreads(); - for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + for (size_t idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { float x = (float)input[blockIdx.x * hidden_size + idx]; out[blockIdx.x * hidden_size + idx] = ((scalar_t)(x * s_variance)) * weight[idx]; @@ -134,24 +191,56 @@ fused_add_rms_norm_kernel( } } +/* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. */ + +template <> +struct Vec { + using Type = uint2; +}; + +template <> +struct Vec { + using Type = uint4; +}; + +template <> +struct Vec { + using Type = bf16_8_t; +}; + } // namespace vllm +#define LAUNCH_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { \ + vllm::rms_norm_kernel<<>>( \ + out.data_ptr(), input.data_ptr(), \ + weight.data_ptr(), epsilon, num_tokens, hidden_size, \ + vec_hidden_size); \ + }); + void rms_norm(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; + int vec_size = 16 / input.element_size(); + int vec_hidden_size = hidden_size / vec_size; + bool can_run_vectorize = (hidden_size % vec_size) == 0; dim3 grid(num_tokens); - dim3 block(std::min(hidden_size, 1024)); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { - vllm::rms_norm_kernel<<>>( - out.data_ptr(), input.data_ptr(), - weight.data_ptr(), epsilon, num_tokens, hidden_size); - }); + if (vec_size % 8 == 0 && can_run_vectorize) { + dim3 block(std::min(vec_hidden_size, 1024)); + LAUNCH_RMS_NORM(8); + } else { + dim3 block(std::min(hidden_size, 1024)); + LAUNCH_RMS_NORM(0); + } } #define LAUNCH_FUSED_ADD_RMS_NORM(width) \ diff --git a/csrc/ops.h b/csrc/ops.h index e39d4ef3188a..f9f0f49faa29 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -38,7 +38,7 @@ void paged_attention_v1( torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step); + const int64_t blocksparse_head_sliding_step, const int64_t num_threads); void paged_attention_v2( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, @@ -50,7 +50,7 @@ void paged_attention_v2( torch::Tensor& v_scale, const int64_t tp_rank, const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, - const int64_t blocksparse_head_sliding_step); + const int64_t blocksparse_head_sliding_step, const int64_t num_threads); void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); @@ -90,6 +90,9 @@ void silu_and_mul(torch::Tensor& out, torch::Tensor& input); void mul_and_silu(torch::Tensor& out, torch::Tensor& input); +void scaled_silu_and_mul(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); void gelu_tanh_and_mul(torch::Tensor& out, torch::Tensor& input); @@ -233,18 +236,24 @@ void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const std::optional& has_initial_state, bool silu_activation, int64_t pad_slot_id); -#ifndef USE_ROCM using fptr_t = int64_t; -fptr_t init_custom_ar(const std::vector& fake_ipc_ptrs, - torch::Tensor& rank_data, int64_t rank, bool full_nvlink); -void all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, - fptr_t reg_buffer, int64_t reg_buffer_sz_bytes); +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int64_t rank, + bool full_nvlink); +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out); void dispose(fptr_t _fa); int64_t meta_size(); -void register_buffer(fptr_t _fa, const std::vector& fake_ipc_ptrs); -std::tuple, std::vector> -get_graph_buffer_ipc_meta(fptr_t _fa); -void register_graph_buffers(fptr_t _fa, - const std::vector>& handles, +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets); +std::tuple> get_graph_buffer_ipc_meta( + fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector& handles, const std::vector>& offsets); +#ifdef USE_ROCM +torch::Tensor allocate_meta_buffer(int64_t size); +torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp); #endif diff --git a/csrc/quantization/fp8/amd/quant_utils.cuh b/csrc/quantization/fp8/amd/quant_utils.cuh index eb66834222f3..4b77817f2df8 100644 --- a/csrc/quantization/fp8/amd/quant_utils.cuh +++ b/csrc/quantization/fp8/amd/quant_utils.cuh @@ -5,9 +5,7 @@ #include #include -#include "../../../attention/dtype_fp8.cuh" -#include "../../../attention/dtype_float32.cuh" -#include "../../../attention/dtype_bfloat16.cuh" +#include "../../../attention/attention_dtypes.h" namespace vllm { #ifdef USE_ROCM @@ -40,8 +38,7 @@ vec_conversion(const uint8_t& a) { template <> __inline__ __device__ uint32_t vec_conversion(const uint16_t& a) { - #if defined(__HIP__MI300__) && \ - defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + #if defined(__HIP__MI300__) const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); union { __half2_raw h2r; @@ -144,8 +141,7 @@ __inline__ __device__ float vec_conversion(const uint8_t& a) { template <> __inline__ __device__ float2 vec_conversion(const uint16_t& a) { - #if defined(__HIP__MI300__) && \ - defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) + #if defined(__HIP__MI300__) float2 res; const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); res.x = f2[0]; @@ -310,7 +306,7 @@ vec_conversion(const Float8_& a) { // fp8 -> half template <> __inline__ __device__ uint16_t -scaled_vec_conversion(const uint8_t& a, const float scale) { +scaled_vec_conversion(const uint8_t& a, float scale) { hip_fp8 f8{a, hip_fp8::from_bits()}; __half_raw res; res.data = static_cast(f8) * scale; @@ -319,10 +315,9 @@ scaled_vec_conversion(const uint8_t& a, const float scale) { // fp8x2 -> half2 template <> -__inline__ __device__ uint32_t scaled_vec_conversion( - const uint16_t& a, const float scale) { - #if defined(__HIP__MI300__) && \ - defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint16_t& a, float scale) { + #if defined(__HIP__MI300__) const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); union { __half2_raw h2r; @@ -348,7 +343,7 @@ __inline__ __device__ uint32_t scaled_vec_conversion( // fp8x4 -> half2x2 template <> __inline__ __device__ uint2 -scaled_vec_conversion(const uint32_t& a, const float scale) { +scaled_vec_conversion(const uint32_t& a, float scale) { union { uint2 u32x2; uint32_t u32[2]; @@ -361,8 +356,8 @@ scaled_vec_conversion(const uint32_t& a, const float scale) { // fp8x8 -> half2x4 template <> -__inline__ __device__ uint4 -scaled_vec_conversion(const uint2& a, const float scale) { +__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, + float scale) { union { uint4 u64x2; uint2 u64[2]; @@ -377,20 +372,17 @@ using __nv_bfloat16 = __hip_bfloat16; // fp8 -> __nv_bfloat16 template <> __inline__ __device__ __nv_bfloat16 -scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, - const float scale) { +scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { hip_fp8 f8{a, hip_fp8::from_bits()}; float f{f8}; return __float2bfloat16(f * scale); } -using __nv_bfloat162 = __hip_bfloat162; - // fp8x2 -> __nv_bfloat162 template <> __inline__ __device__ __nv_bfloat162 scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, - const float scale) { + float scale) { __nv_bfloat162 res; res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); res.y = @@ -400,8 +392,8 @@ scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, // fp8x4 -> bf16_4_t template <> -__inline__ __device__ bf16_4_t scaled_vec_conversion( - const uint32_t& a, const float scale) { +__inline__ __device__ bf16_4_t +scaled_vec_conversion(const uint32_t& a, float scale) { bf16_4_t res; res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), @@ -412,7 +404,7 @@ __inline__ __device__ bf16_4_t scaled_vec_conversion( // fp8x8 -> bf16_8_t template <> __inline__ __device__ bf16_8_t -scaled_vec_conversion(const uint2& a, const float scale) { +scaled_vec_conversion(const uint2& a, float scale) { bf16_4_t tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale); tmp2 = scaled_vec_conversion(a.y, scale); @@ -427,7 +419,7 @@ scaled_vec_conversion(const uint2& a, const float scale) { // fp8 -> float template <> __inline__ __device__ float scaled_vec_conversion( - const uint8_t& a, const float scale) { + const uint8_t& a, float scale) { hip_fp8 fp8{a, hip_fp8::from_bits()}; return static_cast(fp8) * scale; } @@ -435,9 +427,8 @@ __inline__ __device__ float scaled_vec_conversion( // fp8x2 -> float2 template <> __inline__ __device__ float2 -scaled_vec_conversion(const uint16_t& a, const float scale) { - #if defined(__HIP__MI300__) && \ - defined(__HIP_FP8_EXPERIMENTAL_BULK_CONVERT__) +scaled_vec_conversion(const uint16_t& a, float scale) { + #if defined(__HIP__MI300__) float2 res; const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); res.x = f2[0] * scale; @@ -462,10 +453,18 @@ scaled_vec_conversion(const uint32_t& a, const float scale) { return res; } +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 +scaled_vec_conversion(const uint32_t& a, float scale) { + Float4_ res = scaled_vec_conversion(a, scale); + return {res.x.x, res.x.y, res.y.x, res.y.y}; +} + // fp8x8 -> float8 template <> __inline__ __device__ Float8_ -scaled_vec_conversion(const uint2& a, const float scale) { +scaled_vec_conversion(const uint2& a, float scale) { Float4_ tmp1, tmp2; tmp1 = scaled_vec_conversion(a.x, scale); tmp2 = scaled_vec_conversion(a.y, scale); @@ -477,44 +476,178 @@ scaled_vec_conversion(const uint2& a, const float scale) { return res; } -/* Quantize(HP / scale) => FP8 */ - -// TODO(Hai): vectorized to add - // half -> fp8 template <> __inline__ __device__ uint8_t -scaled_vec_conversion(const uint16_t& a, const float scale) { +scaled_vec_conversion(const uint16_t& a, float scale) { __half_raw tmp; tmp.x = a; - hip_fp8 f8{static_cast(tmp.data) / scale}; + hip_fp8 f8{static_cast(tmp.data / scale)}; return f8.data; } +// halfx2 -> fp8x2 +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint32_t& a, float scale) { + #ifdef __HIP__MI300__ + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = tmp.h2r.x.data / scale; + f2.f = tmp.h2r.y.data / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); + #else + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint8_t ui8[2]; + uint16_t ui16; + } res; + res.ui8[0] = scaled_vec_conversion(tmp.h2r.x.x, scale); + res.ui8[1] = scaled_vec_conversion(tmp.h2r.y.x, scale); + return res.ui16; + #endif +} + +// half2x2 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint2& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// half2x4 -> fp8x8 +template <> +__inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, + float scale) { + union { + uint2 ui2[2]; + uint4 ui4; + } tmp; + tmp.ui4 = a; + uint2 res; + res.x = scaled_vec_conversion(tmp.ui2[0], scale); + res.y = scaled_vec_conversion(tmp.ui2[1], scale); + return res; +} + // bf16 -> fp8 template <> __inline__ __device__ uint8_t scaled_vec_conversion( - const __nv_bfloat16& a, const float scale) { + const __nv_bfloat16& a, float scale) { hip_fp8 res{__bfloat162float(a) / scale}; return res.data; } +// bf16x2 -> fp8x2 +template <> +__inline__ __device__ uint16_t scaled_vec_conversion( + const __nv_bfloat162& a, float scale) { + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; +} + +// bf16x4 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const bf16_4_t& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// bf16x8 -> fp8x8 +template <> +__inline__ __device__ uint2 +scaled_vec_conversion(const bf16_8_t& a, float scale) { + uint2 res; + res.x = scaled_vec_conversion({a.x, a.y}, scale); + res.y = scaled_vec_conversion({a.z, a.w}, scale); + return res; +} + // float -> fp8 template <> __inline__ __device__ uint8_t -scaled_vec_conversion(const float& a, const float scale) { - hip_fp8 f8(a / scale); +scaled_vec_conversion(const float& a, float scale) { + hip_fp8 f8(a); return f8.data; } -// fp8x4 -> float4 +// floatx2 -> fp8x2 template <> -__inline__ __device__ float4 -scaled_vec_conversion(const uint32_t& a, const float scale) { - Float4_ tmp = scaled_vec_conversion(a, scale); - float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); - return res; +__inline__ __device__ uint16_t +scaled_vec_conversion(const float2& a, float scale) { + #ifdef __HIP__MI300__ + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = a.x / scale; + f2.f = a.y / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); + #else + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; + #endif +} + +// floatx4 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const float4& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale); + tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); + return tmp.ui32; } #endif // ENABLE_FP8 diff --git a/csrc/quantization/fp8/common.cuh b/csrc/quantization/fp8/common.cuh index 15bd5b6ed156..bdfa43a80e7a 100644 --- a/csrc/quantization/fp8/common.cuh +++ b/csrc/quantization/fp8/common.cuh @@ -5,7 +5,7 @@ #include #include -#ifndef USE_ROCM +#ifdef USE_CUDA_FP8_FORMAT #include using FP8_TYPE = c10::Float8_e4m3fn; C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = @@ -43,7 +43,7 @@ __device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val, } float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX)); -#ifndef USE_ROCM +#ifdef USE_CUDA_FP8_FORMAT return static_cast(r); #else // Use hardware cvt instruction for fp8 on rocm diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index ffa9d44610a7..01b29428131a 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -51,6 +51,9 @@ using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; using float16x4 = __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; typedef float16x4 _Half4; +using float16x2 = + __attribute__((__vector_size__(2 * sizeof(_Float16)))) _Float16; +typedef float16x2 _Half2; typedef struct _Half8 { _Half4 xy[2]; } _Half8; @@ -63,23 +66,34 @@ typedef struct _B16x8 { } _B16x8; using _B8x8 = uint2; +using _B8x4 = int32_t; // used in builtins +using bit8_t = uint8_t; -////// Non temporal load stores /////// +typedef struct _B8x16 { + _B8x8 xy[2]; +} _B8x16; +////// Non temporal loads /////// template -__device__ __forceinline__ T load(T* addr) { - return addr[0]; +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); } -template -__device__ __forceinline__ void store(T value, T* addr) { - addr[0] = value; +__device__ __forceinline__ _B16x8 load_ntmprl_16Byte(const _B16x8* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + auto res = make_float4(dat0, dat1, dat2, dat3); + return *reinterpret_cast<_B16x8*>(&res); } +/////////////////////////////////// template -__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, - const _B16x4& inpB, - const floatx4& inpC) { +__device__ __forceinline__ floatx4 gcn_mfma4x4x4_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { if constexpr (std::is_same::value) { return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, blgp); @@ -91,6 +105,21 @@ __device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, } } +template +__device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(inpA, inpB, inpC, absz, + cbid, blgp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + template __device__ __forceinline__ float to_float(const T& inp) { if constexpr (std::is_same::value) { @@ -102,6 +131,23 @@ __device__ __forceinline__ float to_float(const T& inp) { } } +template +__device__ __forceinline__ float to_float_b16(const bit16_t& inp) { + union tmpcvt { + bit16_t u; + _Float16 f; + __hip_bfloat16 b; + } t16; + t16.u = inp; + if constexpr (std::is_same::value) { + return (float)t16.f; + } else if constexpr (std::is_same::value) { + return __bfloat162float(t16.b); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + template __device__ __forceinline__ T from_float(const float& inp) { if constexpr (std::is_same::value) { @@ -122,17 +168,22 @@ __device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { } t16; _B16x4 ret; if constexpr (std::is_same::value) { - #pragma unroll - for (int i = 0; i < 4; i++) { - t16.f = (_Float16)inp[i]; - ret[i] = t16.u; - } - return ret; + union h2cvt { + __half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __float22half2_rn(make_float2(inp[0], inp[1])); + u.h2[1] = __float22half2_rn(make_float2(inp[2], inp[3])); + return u.b16x4; } else if constexpr (std::is_same::value) { - #pragma unroll for (int i = 0; i < 4; i++) { - t16.b = __float2bfloat16(inp[i]); - ret[i] = t16.u; + union fcvt { + uint32_t u32; + float f32; + } u; + u.f32 = inp[i]; + u.u32 += 0x7fff + ((u.u32 >> 16) & 1); // BF16 RNE with no nan/inf check + ret[i] = uint16_t(u.u32 >> 16); } return ret; } else { @@ -150,21 +201,25 @@ __device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, } t1, t2, res; _B16x4 ret; if constexpr (std::is_same::value) { - #pragma unroll - for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.f = t1.f + t2.f; - ret[i] = res.u; - } - return ret; + union h2cvt { + _B16x4 b16x4; + __half2 h2[2]; + } u1, u2, s; + u1.b16x4 = inp1; + u2.b16x4 = inp2; + s.h2[0] = u1.h2[0] + u2.h2[0]; + s.h2[1] = u1.h2[1] + u2.h2[1]; + return s.b16x4; } else if constexpr (std::is_same::value) { - #pragma unroll for (int i = 0; i < 4; i++) { - t1.u = inp1[i]; - t2.u = inp2[i]; - res.b = t1.b + t2.b; - ret[i] = res.u; + union fcvt { + float f32; + uint32_t i32; + } u1, u2, s; + u1.i32 = uint32_t(inp1[i]) << 16; + u2.i32 = uint32_t(inp2[i]) << 16; + s.f32 = u1.f32 + u2.f32; + ret[i] = uint16_t(s.i32 >> 16); } return ret; } else { @@ -192,15 +247,600 @@ __device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, } } +template +__device__ __forceinline__ _B16x8 +scaled_convert_b8x8_custom(const _B8x8 input, const float scale) { + union { + floatx4 f32x4[2]; + vllm::Float8_ f32x8; + } tmpf8; + tmpf8.f32x8 = vllm::fp8::vec_conversion( + *reinterpret_cast(&input)); + + tmpf8.f32x4[0] *= scale; + tmpf8.f32x4[1] *= scale; + + _B16x8 ret; + ret.xy[0] = from_floatx4(tmpf8.f32x4[0]); + ret.xy[1] = from_floatx4(tmpf8.f32x4[1]); + return ret; +} + +__device__ __forceinline__ floatx4 to_float_fp8x4(const _B8x4& inp) { + #if defined(__gfx90a__) + float4 f32x4 = vllm::fp8::vec_conversion( + *reinterpret_cast(&inp)); + return *reinterpret_cast(&f32x4); + #else // MI3xx+ optimized builtins + const auto f0 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, false); + const auto f1 = __builtin_amdgcn_cvt_pk_f32_fp8(inp, true); + floatx4 ret; + ret[0] = f0[0]; + ret[1] = f0[1]; + ret[2] = f1[0]; + ret[3] = f1[1]; + return ret; + #endif +} + +template +__device__ __forceinline__ _B16x4 from_floatx4_rtz(const floatx4& inp) { + _B16x4 ret; + if constexpr (std::is_same::value) { + union h2cvt { + _Half2 h2[2]; + _B16x4 b16x4; + } u; + u.h2[0] = __builtin_amdgcn_cvt_pkrtz(inp[0], inp[1]); + u.h2[1] = __builtin_amdgcn_cvt_pkrtz(inp[2], inp[3]); + return u.b16x4; + } else if constexpr (std::is_same::value) { + for (int i = 0; i < 4; i++) { + union fcvt { + uint32_t i32; + float f32; + } u; + u.f32 = inp[i]; + ret[i] = uint16_t(u.i32 >> 16); + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) { + union { + _B8x8 b8x8; + _B8x4 b8x4[2]; + } tmp; + tmp.b8x8 = input; + _B16x8 ret; + for (int i = 0; i < 2; i++) { + ret.xy[i] = from_floatx4_rtz(to_float_fp8x4(tmp.b8x4[i])); + } + return ret; +} + /////////////////////////////////////// +// grid (num_seqs, num_partitions,num_kv_heads) +// block (256) +template +__global__ +__launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr, + const float* __restrict__ fp8_out_scale_ptr) { + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + const int lane16id = laneid % 16; + const int rowid = laneid / 16; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + + constexpr int T_PAR_SIZE = 256; // token partition size set to 256 + + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + + const int partition_start_token_idx = + partition_idx * T_PAR_SIZE; // partition_size; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + + constexpr int GQA_RATIO4 = DIVIDE_ROUND_UP(GQA_RATIO, 4); + + __shared__ float shared_qk_max[NWARPS][16 + 1]; + __shared__ float shared_exp_sum[NWARPS][16 + 1]; + // shared_logits is used for multiple purposes + __shared__ _B16x4 shared_logits[NWARPS][4][16][4]; + + // for QK mfma16x16, layout is QHead/Tokenx16 across every 16 lanes, 16 Bytes + // HeadElements in each lane, 4x16B HeadElements across 4 rows of warp + constexpr int ROWS_PER_WARP = + WARP_SIZE / 16; // rows refers to 16 lanes; refer dpp terminology + constexpr int CONTIGUOUS_KV_ELEMS_16B_LOAD = + 16 / sizeof(cache_t); // 8 for 16 bit cache type, 16 for 8 bit types + constexpr int QKHE_PER_FETCH = + CONTIGUOUS_KV_ELEMS_16B_LOAD * + ROWS_PER_WARP; // each fetch across a warp fetches these many elements + constexpr int QK_SIZE_RATIO = + sizeof(scalar_t) / + sizeof(cache_t); // 1 for 16bit types, 2 for 8bit types + constexpr int QKHELOOP = HEAD_SIZE / QKHE_PER_FETCH; // 4xQKHE_16B across + // warp + + _B16x8 Qlocal[QKHELOOP] + [QK_SIZE_RATIO]; // note that 16 contiguous elements of Q should + // be fetched per lane for 8 bit cache types : + // QK_SIZE_RATIO changes for this + + constexpr int CONTIGUOUS_SCALAR_ELEMS_16B = 16 / sizeof(scalar_t); + + constexpr int TOKENS_PER_WARP = + T_PAR_SIZE / + NWARPS; // sub partition of tokens per warp for qk calculation + constexpr int TLOOP = + TOKENS_PER_WARP / + 16; // each mfma16x16x16 instruction processes 16 tokens + + _B16x8 Klocal[TLOOP][QKHELOOP]; // can be interpreted as B8x16 for 8 bit + // types + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + const int total_num_heads = gridDim.z * GQA_RATIO; + + // for QK mfma, tokens in multiples of TOKENS_PER_WARP are spread across warps + // each mfma takes QH16xT16x16HE across warp + // repeat mfmas across QKHELOOP dimension + // output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens + // across 4 rows x 4 tokens per lane + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq; + + int kphysical_block_number[TLOOP]; + + // fetch k physical block numbers + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kblock_idx = (kglobal_token_idx < context_len) + ? kglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + kphysical_block_number[token_depth] = block_table_seq[kblock_idx]; + } + + // fetch Q in shared across warps and then write to registers + const int local_qhead_idx = 4 * warpid + rowid; + const int global_qhead_idx = wg_start_head_idx + local_qhead_idx; + const int64_t seq_idx64 = static_cast(seq_idx); + const scalar_t* q_ptr = + q + seq_idx64 * q_stride + global_qhead_idx * HEAD_SIZE; + + const int qhead_element = lane16id * CONTIGUOUS_SCALAR_ELEMS_16B; + if ((local_qhead_idx < GQA_RATIO) && (qhead_element < HEAD_SIZE)) { + const scalar_t* q_fetch_ptr = q_ptr + qhead_element; + const _B16x8* q_fetch_ptr_16B = + reinterpret_cast(q_fetch_ptr); + _B16x8 tmp = *q_fetch_ptr_16B; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const int offset1 = + lane16id / + 4; // 16 contiguous chunks of head elems are spread across 4x4lanes + shared_logits[offset1][lane4id][local_qhead_idx][0] = tmp.xy[0]; + shared_logits[offset1][lane4id][local_qhead_idx][1] = tmp.xy[1]; + } else { + for (int i = 0; i < 2; i++) { + const int head_elem = lane16id * 2 + i; // element id in _B16x4 terms + const int offset3 = head_elem % 4; + const int offset2 = (head_elem / 4) % 4; + const int offset1 = head_elem / 4 / 4; + shared_logits[offset1][offset2][local_qhead_idx][offset3] = tmp.xy[i]; + } + } + } + __syncthreads(); + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i = 0; i < 2; i++) { + Qlocal[qkhe_depth][qkratio].xy[i] = + shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO] + [2 * qkratio + i]; + } + } + } + + // set to true to enable non temporal kv loads: has some benefit in very high + // batch size cases + constexpr bool NT_KV_LOAD = false; + + constexpr int KX = + 16 / sizeof(cache_t); // vLLM defines x as 16 Bytes of kv cache elements + const cache_t* k_ptr = k_cache + wg_start_kv_head_idx * kv_head_stride; + + const int row_head_elem = rowid * CONTIGUOUS_KV_ELEMS_16B_LOAD; + // fetch K values + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int64_t kblock_number = + static_cast(kphysical_block_number[token_depth]); + const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride; + const int klocal_token_idx = + TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id; + const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx; + const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE; + const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX; + + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + const int head_elem = row_head_elem + qkhe_depth * QKHE_PER_FETCH; + const int offset1 = head_elem / KX; + const int offset2 = head_elem % KX; + const cache_t* k_fetch_ptr = k_ptr3 + offset1 * BLOCK_SIZE * KX + offset2; + const _B16x8* k_fetch_ptr_16B = + reinterpret_cast(k_fetch_ptr); + if constexpr (NT_KV_LOAD) { + Klocal[token_depth][qkhe_depth] = load_ntmprl_16Byte(k_fetch_ptr_16B); + } else { + Klocal[token_depth][qkhe_depth] = *k_fetch_ptr_16B; + } + } + } + + float alibi_slope; + if constexpr (ALIBI_ENABLED) { + const int alibi_head_idx = wg_start_head_idx + lane16id; + alibi_slope = (lane16id < GQA_RATIO) ? alibi_slopes[alibi_head_idx] : 0.f; + } + + constexpr int VTOKENS_PER_LANE = + TOKENS_PER_WARP / ROWS_PER_WARP; // 64/4 = 16 contiguous vtokens per lane + constexpr int VBLOCKS_PER_LANE = + 1; // assumes block size >=16, each lane can correspond to 1 block only + constexpr int VTLOOP = NWARPS; // corresponds to tokens across warps + constexpr int VTLANELOOP = DIVIDE_ROUND_UP( + VTOKENS_PER_LANE, + CONTIGUOUS_KV_ELEMS_16B_LOAD); // optimized for 16B fetches; assumes + // minimum block size is 16 + constexpr int VHELOOP = + HEAD_SIZE / 16 / NWARPS; // head_size distributed across warps; each mfma + // instr works on 16 head elements + + int vphysical_block_number[VTLOOP][VBLOCKS_PER_LANE]; + + // fetch v physical block numbers + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vblock_depth = 0; vblock_depth < VBLOCKS_PER_LANE; + vblock_depth++) { + const int vlocal_token_idx = + vtoken_depth * VTOKENS_PER_LANE * ROWS_PER_WARP + + rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE; + const int vglobal_token_idx = + partition_start_token_idx + vlocal_token_idx; + const int vblock_idx = (vglobal_token_idx < context_len) + ? vglobal_token_idx / BLOCK_SIZE + : last_ctx_block; + vphysical_block_number[vtoken_depth][vblock_depth] = + block_table_seq[vblock_idx]; + } + } + + _B16x8 Vlocal[VTLOOP][VHELOOP] + [VTLANELOOP]; // this can be interpreted as B8x16 too + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride + + ((rowid * VTOKENS_PER_LANE) % BLOCK_SIZE); + + // v fetches are 16head elems across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + const int vhead_elem = vhe_depth * NWARPS * 16 + warpid * 16 + lane16id; + const cache_t* v_ptr2 = v_ptr + vhead_elem * BLOCK_SIZE; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + const int vblock_depth = 0; + const int64_t vblock_number = static_cast( + vphysical_block_number[vtoken_depth][vblock_depth]); + const cache_t* v_ptr3 = v_ptr2 + (vblock_number * kv_block_stride); + + const cache_t* v_fetch_ptr = + v_ptr3 + vfetch_depth * CONTIGUOUS_KV_ELEMS_16B_LOAD; + const _B16x8* v_fetch_ptr_16B = + reinterpret_cast(v_fetch_ptr); + if constexpr (NT_KV_LOAD) { + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = + load_ntmprl_16Byte(v_fetch_ptr_16B); + } else { + Vlocal[vtoken_depth][vhe_depth][vfetch_depth] = *v_fetch_ptr_16B; + } + } + } + } + + // calculate post qk mfma scale + float scale2 = scale; + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + // multiply by k_scale if fp8 kv cache + scale2 *= *k_scale_ptr; + } + + floatx4 dout[TLOOP]; + // qk mfma + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] = {0}; + for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) { + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + for (int i = 0; i < 2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr( + Klocal[token_depth][qkhe_depth].xy[i], + Qlocal[qkhe_depth][qkratio].xy[i], dout[token_depth]); + } + } + } else { // kv cache dtype fp8 + auto Ktmp = Klocal[token_depth][qkhe_depth]; + _B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp); + for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) { + _B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio]; + _B16x8 Klocaltmp = convert_b8x8_custom(Ktmp8x8); + for (int i = 0; i < 2; i++) { + dout[token_depth] = gcn_mfma16x16x16_instr( + Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i], + dout[token_depth]); + } + } + } + } + dout[token_depth] *= scale2; + } + + const int qkout_token_idx = + partition_start_token_idx + TOKENS_PER_WARP * warpid + rowid * 4; + + // apply alibi + if constexpr (ALIBI_ENABLED) { + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + const int alibi_offset = local_token_idx - context_len + 1; + for (int i = 0; i < 4; i++) { + dout[token_depth][i] += alibi_slope * (alibi_offset + i); + } + } + } + + // calculate qk_max and exp_sum per warp and write to shared memory + float qk_max = -FLT_MAX; + float exp_sum = 0.0f; + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 4; i++) { + const float tmp = + (local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX; + qk_max = fmaxf(qk_max, tmp); + } + } + + for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); + } + + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + const int local_token_idx = qkout_token_idx + token_depth * 16; + for (int i = 0; i < 4; i++) { + const float tmp = (local_token_idx + i < context_len) + ? __expf(dout[token_depth][i] - qk_max) + : 0.0f; + dout[token_depth][i] = tmp; + exp_sum += tmp; + } + } + + for (int mask = WARP_SIZE / 2; mask >= 16; mask /= 2) { + exp_sum += __shfl_xor(exp_sum, mask); + } + + __syncthreads(); // sync before writing to shared mem + + float* shared_mem = reinterpret_cast(shared_logits); + if (laneid < 16) { + const int qk_max_offset = warpid * 16 + lane16id; + shared_mem[qk_max_offset] = qk_max; + const int exp_sum_offset = NWARPS * 16 + qk_max_offset; + shared_mem[exp_sum_offset] = exp_sum; + } + + __syncthreads(); + + // calculate partition qk_max and exp_sum + float partition_qk_max = -FLT_MAX; + float warp_qk_max_exp[NWARPS]; + float partition_exp_sum = 0.0f; + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = shared_mem[w * 16 + lane16id]; + partition_qk_max = fmaxf(partition_qk_max, warp_qk_max_exp[w]); + } + + for (int w = 0; w < NWARPS; w++) { + warp_qk_max_exp[w] = __expf(warp_qk_max_exp[w] - partition_qk_max); + partition_exp_sum += + shared_mem[NWARPS * 16 + w * 16 + lane16id] * warp_qk_max_exp[w]; + } + + const float inv_sum_scale = + __fdividef(1.f, partition_exp_sum + 1e-6f) * warp_qk_max_exp[warpid]; + + __syncthreads(); + + // disable rtz conversion due to its impact on accuracy. + constexpr bool LOGITS_RTZ_CONVERSION = false; + + // write logits to shared mem + for (int token_depth = 0; token_depth < TLOOP; token_depth++) { + dout[token_depth] *= inv_sum_scale; + if constexpr (LOGITS_RTZ_CONVERSION) { + // use rtz conversion for better performance, with negligible impact on + // accuracy. + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4_rtz(dout[token_depth]); + } else { + shared_logits[warpid][token_depth][lane16id][rowid] = + from_floatx4(dout[token_depth]); + } + } + // write out partition max_logits and exp_sum + if (threadIdx.x < GQA_RATIO) { + const int qhead_idx = lane16id; + const int offset = seq_idx * total_num_heads * max_num_partitions + + (wg_start_head_idx + qhead_idx) * max_num_partitions + + partition_idx; + max_logits[offset] = partition_qk_max; + exp_sums[offset] = partition_exp_sum; + } + + __syncthreads(); + + constexpr int ELEMS8_ELEMS4_RATIO = 8 / 4; + constexpr int ELEMS16_ELEMS8_RATIO = 16 / 8; + + _B16x4 outelems[VHELOOP]; + // Softmax V mfma + // v layout: 16he across lanes x 16 tokens per lane + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + floatx4 tmp_out = {0}; + + for (int vtoken_depth = 0; vtoken_depth < VTLOOP; vtoken_depth++) { + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = rowid * VTLANELOOP * ELEMS8_ELEMS4_RATIO + + vfetch_depth * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems spread + // across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocal[vtoken_depth][vhe_depth][vfetch_depth].xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } + // KV cache fp8 + } else { + for (int vfetch_depth = 0; vfetch_depth < VTLANELOOP; vfetch_depth++) { + _B16x8 Vtmp = Vlocal[vtoken_depth][vhe_depth][vfetch_depth]; + // reinterpret V format as 16 elements of 8bits + _B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp); + for (int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) { + _B8x8 Vtmp8x8 = Vtmp8x16.xy[j]; + _B16x8 Vlocaltmp = convert_b8x8_custom(Vtmp8x8); + for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) { + const int offset = + rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO + + j * ELEMS8_ELEMS4_RATIO + i; + const int offset1 = offset % ROWS_PER_WARP; + const int offset2 = offset / ROWS_PER_WARP; + // output format is 16 qheads across 16 lanes, 16 head elems + // spread across 4 rows + tmp_out = gcn_mfma16x16x16_instr( + Vlocaltmp.xy[i], + shared_logits[vtoken_depth][offset2][lane16id][offset1], + tmp_out); + } + } + } + } + } + // apply post Softmax V mfma v_scale + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + tmp_out *= *v_scale_ptr; + } + outelems[vhe_depth] = from_floatx4(tmp_out); + } + + __syncthreads(); + + // store Softmax-V mfma output to shared mem + for (int vhe_depth = 0; vhe_depth < VHELOOP; vhe_depth++) { + // lane16 id head dimension; rowid head element dimension + shared_logits[warpid][vhe_depth][lane16id][rowid] = outelems[vhe_depth]; + } + + __syncthreads(); + + // write to tmp_out with coalesced writes after reading from shared mem + if (warpid == 0) { + _B16x8 vout[GQA_RATIO4]; + // each lane writes out 16Bytes of tmp_out along head elem dimension + const int head_elem_idx = lane16id * 8; + if (head_elem_idx < HEAD_SIZE) { + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + const int offset1 = (head_elem_idx / 16) % 4; + const int offset2 = head_elem_idx / 16 / NWARPS; + const int offset3 = (head_elem_idx / 4) % 4; + for (int i = 0; i < 2; i++) { + vout[h].xy[i] = + shared_logits[offset1][offset2][local_head_idx][offset3 + i]; + } + } -// grid (num_seqs, num_partitions,num_heads/gqa_ratio) -// block (partition size) + const int hsz_maxp_mult = HEAD_SIZE * max_num_partitions; + scalar_t* out_ptr = out + seq_idx * total_num_heads * hsz_maxp_mult + + partition_idx * HEAD_SIZE; + for (int h = 0; h < GQA_RATIO4; h++) { + const int local_head_idx = 4 * h + rowid; + if (local_head_idx < GQA_RATIO) { + const int out_head_idx = wg_start_head_idx + local_head_idx; + scalar_t* out_ptr2 = out_ptr + out_head_idx * hsz_maxp_mult; + scalar_t* out_ptr3 = out_ptr2 + head_elem_idx; + _B16x8* out_ptr_B16x8 = reinterpret_cast<_B16x8*>(out_ptr3); + *out_ptr_B16x8 = vout[h]; + } + } + } + } +} + +///////////////////////////////////////////////////////////// +// grid (num_seqs, num_partitions, num_kv_heads) +// block (256 : partition size) +// each WG handles 1 partition per sequence template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -215,10 +855,11 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, // max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr, + const float* __restrict__ fp8_out_scale_ptr) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -235,27 +876,35 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( if (partition_start_token_idx >= context_len) { return; } - constexpr int QHLOOP = - DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, - // total qheads =8, so qhloop is 2 + // every 4 lanes fetch 4 different qheads + // qhloop = num loops over qhead dimension + constexpr int QHLOOP = DIVIDE_ROUND_UP(GQA_RATIO, 4); constexpr int GQA_RATIO4 = 4 * QHLOOP; __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; _B16x8 Qlocal[QHLOOP]; constexpr int x = 16 / sizeof(scalar_t); + // kheloop = num loops over head_size for 16Bytes of Q/dequantized K elements constexpr int KHELOOP = HEAD_SIZE / x; _B16x8 Klocal[KHELOOP]; _B8x8 Klocalb8[KHELOOP]; - constexpr int VHELOOP = - HEAD_SIZE / - WARP_SIZE; // v head_size dimension is distributed across lanes - constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 - // 8xtokens + // for SoftMax-V Gemm, V head_size dimension is distributed across warp + // vheloop = num loops to cover v head size dimension + constexpr int VHELOOP = HEAD_SIZE / WARP_SIZE; + // softmax out has warp_size tokens across warp + // vtloop = num loops to cover warp_size(64) tokens with 16Bytes of + // dequantized V elements + constexpr int VTLOOP = WARP_SIZE / 8; + // num vblocks to cover warp_size(64) v elements + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; _B16x8 Vlocal[VHELOOP][VTLOOP]; _B8x8 Vlocalb8[VHELOOP][VTLOOP]; floatx4 dout[QHLOOP]; float qk_max[QHLOOP]; - #pragma unroll + + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; + for (int h = 0; h < QHLOOP; h++) { dout[h] = {0}; qk_max[h] = -FLT_MAX; @@ -267,37 +916,37 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const int warp_start_token_idx = partition_start_token_idx + warpid * WARP_SIZE; - if (warp_start_token_idx >= context_len) { // warp out of context + // entire warp out of context + if (warp_start_token_idx >= context_len) { #pragma unroll for (int h = 0; h < GQA_RATIO4; h++) { shared_qk_max[warpid][h] = -FLT_MAX; shared_exp_sum[warpid][h] = 0.0f; } - } else { // warp within context - + // warp within context + } else { const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); const int last_ctx_block = num_context_blocks - 1; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - + // token id within partition const int local_token_idx = threadIdx.x; + // token id within sequence const int global_token_idx = partition_start_token_idx + local_token_idx; + // fetch block number for k const int block_idx = (global_token_idx < context_len) ? global_token_idx / BLOCK_SIZE : last_ctx_block; - // fetch block number for q and k - // int32 physical_block_number leads to overflow when multiplied with - // kv_block_stride + + // fetch k physical block number + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride const int64_t physical_block_number = static_cast(block_table[block_idx]); // fetch vphysical block numbers up front - constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; - int vphysical_blocks[VBLOCKS]; - const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { const int vblock_idx = warp_start_block_idx + b; const int vblock_idx_ctx = @@ -305,12 +954,13 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( vphysical_blocks[b] = block_table[vblock_idx_ctx]; } - // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + // fetch q elements + // every 4 lanes fetch 8 elems, so warp fetches 8*16 = 128 elems const scalar_t* q_ptr = q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); const int qhead_elemh8 = laneid / 4; - #pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { const int qhead_idx = h * 4 + lane4id; Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; @@ -324,22 +974,24 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( Qlocal[QHLOOP - 1].xy[1] = {0}; } + // fetch k elements const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + wg_start_kv_head_idx * kv_head_stride; - const int physical_block_offset = - local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset - // is already cast as _H8 + // physical_block_offset is already cast in terms of _B16x8 + const int physical_block_offset = local_token_idx % BLOCK_SIZE; + + // each K fetch is for 8 elements of cache_t which are later dequantized to + // scalar_t for fp8 if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); - #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; } } else { + // vllm defines X as 16 Bytes of elements of cache_t constexpr int X = 16 / sizeof(cache_t); const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; - #pragma unroll for (int d = 0; d < KHELOOP; d++) { const int head_elem = d * 8; const int offset1 = head_elem / X; @@ -349,9 +1001,9 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } + // optional alibi fetch float alibi_slope[QHLOOP]; - if (alibi_slopes != nullptr) { - #pragma unroll + if constexpr (ALIBI_ENABLED) { for (int h = 0; h < QHLOOP; h++) { const int qhead_idx = h * 4 + lane4id; alibi_slope[h] = (qhead_idx < GQA_RATIO) @@ -361,10 +1013,10 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + // fetch vcache in kv cache auto case if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride @@ -373,21 +1025,20 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const _B16x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; // iterate over each head elem (within head_size) - #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; // iterate over all velems within block - #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } } - } else { + } // if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) + // fetch vcache in fp8 case + else { // if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); // iterate over each v block - #pragma unroll for (int b = 0; b < VBLOCKS; b++) { // int32 physical_block_number leads to overflow when multiplied with // kv_block_stride @@ -396,164 +1047,153 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const _B8x8* v_ptrh8b = v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; // iterate over each head elem (within head_size) - #pragma unroll for (int h = 0; h < VHELOOP; h++) { const int head_size_elem = h * WARP_SIZE + laneid; const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; // iterate over all velems within block - #pragma unroll for (int d = 0; d < BLOCK_SIZE / 8; d++) { - // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; - const _B8x8 Vlocalb8 = v_ptrh8be[d]; - Vlocal[h][b * BLOCK_SIZE / 8 + d] = - scaled_convert_b8x8(Vlocalb8, *v_scale_ptr); + Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; } } } } + #define QK_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ + Klocal[x] = convert_b8x8_custom(Klocalb8[x]); \ + } \ + for (int h = 0; h < QHLOOP; h++) { \ + dout[h] = gcn_mfma4x4x4_instr( \ + Qlocal[h].xy[0], Klocal[x].xy[0], dout[h]); \ + dout[h] = gcn_mfma4x4x4_instr( \ + Qlocal[h].xy[1], Klocal[x].xy[1], dout[h]); \ + } + // QK mfma with Q mfma block broadcast + // Q values across head_size dimension stored across lanes + // K values across head_size dimension are stored depthwise within lane + // Q broadcast with absz, cbid of mfma instruction + QK_mfma(0); + QK_mfma(1); + QK_mfma(2); + QK_mfma(3); + QK_mfma(4); + QK_mfma(5); + QK_mfma(6); + QK_mfma(7); + // below only needed for head size 128 + if constexpr (KHELOOP > 8) { + QK_mfma(8); + QK_mfma(9); + QK_mfma(10); + QK_mfma(11); + QK_mfma(12); + QK_mfma(13); + QK_mfma(14); + QK_mfma(15); + } + #undef QK_mfma + + float scale2 = scale; if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { - #pragma unroll - for (int d = 0; d < KHELOOP; d++) { - Klocal[d] = - scaled_convert_b8x8(Klocalb8[d], *k_scale_ptr); - } + // post mfma scaling for fp8 + scale2 *= *k_scale_ptr; } - #pragma unroll for (int h = 0; h < QHLOOP; h++) { - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[0].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[0].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[1].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[1].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[2].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[2].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[3].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[3].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[4].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[4].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[5].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[5].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[6].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[6].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[7].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[7].xy[1], dout[h]); - if constexpr (KHELOOP > 8) { - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[8].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[8].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[9].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[9].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[10].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[10].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[11].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[11].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[12].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[12].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[13].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[13].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[14].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[14].xy[1], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], - Klocal[15].xy[0], dout[h]); - dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], - Klocal[15].xy[1], dout[h]); - } // KHELOOP>8 - dout[h] *= scale; + dout[h] *= scale2; } - // transpose dout so that 4 token ids are in each lane, and 4 heads are across - // 4 lanes - #pragma unroll + + // transpose dout so that 4 token ids are in each lane, and 4 heads are + // across 4 lanes for (int h = 0; h < QHLOOP; h++) { floatx4 tmp = {0}; - #pragma unroll for (int i = 0; i < 4; i++) { const float B = (lane4id == i) ? 1.0f : 0.0f; - // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); - // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); } dout[h] = tmp; } const int lane4_token_idx = 4 * (global_token_idx >> 2); - const int alibi_offset = lane4_token_idx - context_len + 1; - if (alibi_slopes != nullptr) { - #pragma unroll + + if constexpr (ALIBI_ENABLED) { + const int alibi_offset = lane4_token_idx - context_len + 1; for (int h = 0; h < QHLOOP; h++) { - #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] += alibi_slope[h] * (alibi_offset + i); } } } - #pragma unroll + const int bpermute_mask = 4 * (16 * ((laneid >> 2) % 4) + lane4id); + for (int h = 0; h < QHLOOP; h++) { qk_max[h] = -FLT_MAX; - #pragma unroll for (int i = 0; i < 4; i++) { qk_max[h] = (lane4_token_idx + i < context_len) ? fmaxf(qk_max[h], dout[h][i]) : qk_max[h]; } - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { - qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); - } + + // for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + // } + // faster version of above code with dpp + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + + auto tmp = __builtin_amdgcn_ds_bpermute( + bpermute_mask, *reinterpret_cast(&qk_max[h])); + qk_max[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); + asm("v_nop\n v_nop\n v_max_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(qk_max[h]) + : "v"(qk_max[h]), "v"(qk_max[h])); } float exp_sum[QHLOOP]; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { exp_sum[h] = 0.0f; - #pragma unroll for (int i = 0; i < 4; i++) { dout[h][i] = (lane4_token_idx + i < context_len) ? __expf(dout[h][i] - qk_max[h]) : 0.0f; exp_sum[h] += dout[h][i]; } - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { - exp_sum[h] += __shfl_xor(exp_sum[h], mask); - } + // for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + // exp_sum[h] += __shfl_xor(exp_sum[h], mask); + // } + // faster version of above code with dpp + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + + auto tmp = __builtin_amdgcn_ds_bpermute( + bpermute_mask, *reinterpret_cast(&exp_sum[h])); + exp_sum[h] = *reinterpret_cast(&tmp); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:4" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); + asm("v_nop\n v_nop\n v_add_f32_dpp %0, %1, %2 row_ror:8" + : "=v"(exp_sum[h]) + : "v"(exp_sum[h]), "v"(exp_sum[h])); } - #pragma unroll - for (int h = 0; h < QHLOOP; h++) { - const int head_idx = 4 * h + lane4id; - shared_qk_max[warpid][head_idx] = qk_max[h]; - shared_exp_sum[warpid][head_idx] = exp_sum[h]; + if (laneid < 4) { + for (int h = 0; h < QHLOOP; h++) { + const int head_idx = 4 * h + lane4id; + shared_qk_max[warpid][head_idx] = qk_max[h]; + shared_exp_sum[warpid][head_idx] = exp_sum[h]; + } } } // warp within context @@ -564,18 +1204,16 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; - #pragma unroll + // calculate qk_max and exp_sums for partition for (int h = 0; h < QHLOOP; h++) { float global_qk_max = -FLT_MAX; float warp_qk_max[NWARPS]; const int head_idx = 4 * h + lane4id; - #pragma unroll for (int w = 0; w < NWARPS; w++) { warp_qk_max[w] = shared_qk_max[w][head_idx]; global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); } float global_exp_sum = 0.0f; - #pragma unroll for (int w = 0; w < NWARPS; w++) { global_exp_sum += shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); @@ -590,99 +1228,94 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( __expf(qk_max[h] - global_qk_max); dout[h] *= global_inv_sum_scale; } + constexpr bool LOGITS_RTZ_CONVERSION = false; // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there // are 4x16 tokens across warp _B16x4 logits[QHLOOP]; - #pragma unroll for (int h = 0; h < QHLOOP; h++) { - logits[h] = from_floatx4(dout[h]); + if constexpr (LOGITS_RTZ_CONVERSION) { + // use rtz for faster performance with no perceivable accuracy loss + logits[h] = from_floatx4_rtz(dout[h]); + } else { + logits[h] = from_floatx4(dout[h]); + } } - __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; - if (warp_start_token_idx >= context_len) { // warp out of context - #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { - #pragma unroll for (int vh = 0; vh < VHELOOP; vh++) { vout_shared[qh][vh][laneid][warpid] = {0}; } } } else { // warp in context - // iterate across heads - #pragma unroll - for (int qh = 0; qh < QHLOOP; qh++) { - // iterate over each v head elem (within head_size) - #pragma unroll - for (int vh = 0; vh < VHELOOP; vh++) { - floatx4 acc = {0}; - // iterate over tokens - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[0], - acc); - acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[1], - acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][5].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][5].xy[1], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][6].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][6].xy[1], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][7].xy[0], acc); - acc = gcn_mfma_instr(logits[qh], - Vlocal[vh][7].xy[1], acc); - vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); + #define SV_mfma(x) \ + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { \ + Vlocal[vh][x] = convert_b8x8_custom(Vlocalb8[vh][x]); \ + } \ + for (int qh = 0; qh < QHLOOP; qh++) { \ + acc[qh] = gcn_mfma4x4x4_instr( \ + logits[qh], Vlocal[vh][x].xy[0], acc[qh]); \ + acc[qh] = gcn_mfma4x4x4_instr( \ + logits[qh], Vlocal[vh][x].xy[1], acc[qh]); \ + } + + for (int vh = 0; vh < VHELOOP; vh++) { + floatx4 acc[QHLOOP]; + for (int qh = 0; qh < QHLOOP; qh++) { + acc[qh] = {0}; + } + // SoftMax-V calculation + // logits -> token dimension is distributed across lanes + // Vlocal -> token dimension is depthwise within lane + // uses mfma instruction block broadcast for logits + SV_mfma(0); + SV_mfma(1); + SV_mfma(2); + SV_mfma(3); + SV_mfma(4); + SV_mfma(5); + SV_mfma(6); + SV_mfma(7); + + for (int qh = 0; qh < QHLOOP; qh++) { + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + // post mfma v scale for fp8 + acc[qh] *= *v_scale_ptr; + } + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc[qh]); } } + + #undef SV_mfma } // warp in context __syncthreads(); + // final write to tmp_out after vout accumulation if (warpid == 0) { + const float out_scale = + (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f; _B16x4 vout[QHLOOP][VHELOOP]; // iterate across heads - scalar_t* out_ptr; - int out_num_partitions; - if (context_len > partition_size) { - out_num_partitions = max_num_partitions; - out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - partition_idx * HEAD_SIZE; - } else { - out_num_partitions = 1; - out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; - } - #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { - // iterate over each v head elem (within head_size) - #pragma unroll + // iterate over each v head elem (within head_size) for (int vh = 0; vh < VHELOOP; vh++) { vout[qh][vh] = {0}; - #pragma unroll for (int w = 0; w < NWARPS; w++) { vout[qh][vh] = addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); } + } + } + + scalar_t* out_ptr = out + + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + const int out_num_partitions = max_num_partitions; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + for (int qh = 0; qh < QHLOOP; qh++) { + for (int vh = 0; vh < VHELOOP; vh++) { const int head_size_elem = vh * WARP_SIZE + laneid; - bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); - #pragma unroll for (int i = 0; i < 4; i++) { const int head_idx = 4 * qh + i; if (head_idx < GQA_RATIO) { @@ -693,15 +1326,15 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( } } } - } + } // warpid == 0 } // Grid: (num_heads, num_seqs). -template +template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] const float* __restrict__ exp_sums, // [num_seqs, num_heads, // max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, @@ -709,24 +1342,19 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; const int context_len = context_lens[seq_idx]; const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - if (num_partitions == 1) { - // if num_partitions==1, main kernel will write to out directly, no work in - // reduction kernel - return; - } - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; __shared__ float shared_global_exp_sum; - __shared__ float shared_exp_sums[2 * WARP_SIZE]; + // max num partitions supported is warp_size * NPAR_LOOPS + __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE]; if (warpid == 0) { const float* max_logits_ptr = max_logits + @@ -735,14 +1363,25 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // valid partition is the last valid partition in case threadid > num // partitions - const int valid_partition = - (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1; - const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions) - ? WARP_SIZE + threadIdx.x - : num_partitions - 1; - float reg_max_logit = max_logits_ptr[valid_partition]; - float reg_max_logit2 = max_logits_ptr[valid_partition2]; - float max_logit = fmaxf(reg_max_logit, reg_max_logit2); + int valid_partition[NPAR_LOOPS]; + float reg_max_logit[NPAR_LOOPS]; + const int last_valid_partition = num_partitions - 1; + + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + valid_partition[i] = + (partition_no < num_partitions) ? partition_no : last_valid_partition; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + reg_max_logit[i] = max_logits_ptr[valid_partition[i]]; + } + float max_logit = reg_max_logit[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + max_logit = fmaxf(max_logit, reg_max_logit[i]); + } #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { @@ -753,17 +1392,28 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - float global_exp_sum = 0.0f; - float rescaled_exp_sum = exp_sums_ptr[valid_partition]; - float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; - rescaled_exp_sum *= - (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; - rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions) - ? expf(reg_max_logit2 - max_logit) - : 0.0f; - global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; - shared_exp_sums[threadIdx.x] = rescaled_exp_sum; - shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; + float rescaled_exp_sum[NPAR_LOOPS]; + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + rescaled_exp_sum[i] *= (partition_no < num_partitions) + ? expf(reg_max_logit[i] - max_logit) + : 0.0f; + } + float global_exp_sum = rescaled_exp_sum[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + global_exp_sum += rescaled_exp_sum[i]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + shared_exp_sums[partition_no] = rescaled_exp_sum[i]; + } #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { @@ -840,39 +1490,76 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } } - if (num_partitions > MAX_NPAR) { - idx = 0; + for (int p = 1; p < NPAR_LOOPS; p++) { + if (num_partitions > p * MAX_NPAR) { + idx = 0; #pragma unroll - for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; - j += HEAD_SIZE) { - // lastj is last valid partition - const int lastj_offset = - (j < num_partition_offset) ? j : last_partition_offset; - tmps[idx] = tmp_out_ptr[lastj_offset]; - idx++; - } + for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } #pragma unroll - for (int j = 0; j < MAX_NPAR; j++) { - acc += to_float(tmps[j]) * shared_exp_sums[j + MAX_NPAR]; + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR]; + } } } const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + const float out_scale = + (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f; acc *= inv_global_exp_sum; - scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - out_ptr[threadIdx.x] = from_float(acc); + acc *= out_scale; + OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + if constexpr (std::is_same::value) { + out_ptr[threadIdx.x] = hip_fp8(acc).data; + } else { + out_ptr[threadIdx.x] = from_float(acc); + } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr, + const float* __restrict__ fp8_out_scale_ptr) { + UNREACHABLE_CODE +} + +template -__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, // head_size/x, block_size, x] @@ -887,19 +1574,20 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, // max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, const float* k_scale, const float* v_scale) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, const float* k_scale, const float* v_scale, + const float* __restrict__ fp8_out_scale_ptr) { UNREACHABLE_CODE } // Grid: (num_heads, num_seqs). -template +template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] const float* __restrict__ exp_sums, // [num_seqs, num_heads, // max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, @@ -907,31 +1595,52 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { UNREACHABLE_CODE } #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support -#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ +#define LAUNCH_CUSTOM_ATTENTION_MFMA16(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma16_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ - k_scale_ptr, v_scale_ptr); + k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr); + +#define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO) \ + paged_attention_ll4mi_QKV_mfma4_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr); + +#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ + paged_attention_ll4mi_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ + context_lens_ptr, max_num_partitions, fp8_out_scale_ptr); template + int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD, + bool ALIBI_ENABLED> void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, const std::optional& alibi_slopes, - torch::Tensor& k_scale, torch::Tensor& v_scale) { + torch::Tensor& k_scale, torch::Tensor& v_scale, + const c10::optional& fp8_out_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -946,7 +1655,6 @@ void paged_attention_custom_launcher( ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); @@ -955,107 +1663,177 @@ void paged_attention_custom_launcher( KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + // NOTE: fp8_out_scale is optional. + const float* fp8_out_scale_ptr = + fp8_out_scale + ? reinterpret_cast(fp8_out_scale.value().data_ptr()) + : nullptr; + OUTT* out_ptr = reinterpret_cast(out.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + + // partition size is fixed at 256 since both mfma4 and mfma16 kernels support + // it mfma4 kernel also supports partition size 512 + constexpr int PARTITION_SIZE = 256; const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); const int gqa_ratio = num_heads / num_kv_heads; assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); - assert(max_num_partitions <= 128); - constexpr int NTHR = PARTITION_SIZE; + constexpr int NTHR = 256; dim3 grid(num_seqs, max_num_partitions, num_kv_heads); dim3 block(NTHR); const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + // mfma4 kernel is faster than mfma16 for gqa_ratio <= 4 switch (gqa_ratio) { case 1: - LAUNCH_CUSTOM_ATTENTION(1); + LAUNCH_CUSTOM_ATTENTION_MFMA4(1); break; case 2: - LAUNCH_CUSTOM_ATTENTION(2); + LAUNCH_CUSTOM_ATTENTION_MFMA4(2); break; case 3: - LAUNCH_CUSTOM_ATTENTION(3); + LAUNCH_CUSTOM_ATTENTION_MFMA4(3); break; case 4: - LAUNCH_CUSTOM_ATTENTION(4); + LAUNCH_CUSTOM_ATTENTION_MFMA4(4); break; case 5: - LAUNCH_CUSTOM_ATTENTION(5); + LAUNCH_CUSTOM_ATTENTION_MFMA16(5); break; case 6: - LAUNCH_CUSTOM_ATTENTION(6); + LAUNCH_CUSTOM_ATTENTION_MFMA16(6); break; case 7: - LAUNCH_CUSTOM_ATTENTION(7); + LAUNCH_CUSTOM_ATTENTION_MFMA16(7); break; case 8: - LAUNCH_CUSTOM_ATTENTION(8); + LAUNCH_CUSTOM_ATTENTION_MFMA16(8); break; case 9: - LAUNCH_CUSTOM_ATTENTION(9); + LAUNCH_CUSTOM_ATTENTION_MFMA16(9); break; case 10: - LAUNCH_CUSTOM_ATTENTION(10); + LAUNCH_CUSTOM_ATTENTION_MFMA16(10); break; case 11: - LAUNCH_CUSTOM_ATTENTION(11); + LAUNCH_CUSTOM_ATTENTION_MFMA16(11); break; case 12: - LAUNCH_CUSTOM_ATTENTION(12); + LAUNCH_CUSTOM_ATTENTION_MFMA16(12); break; case 13: - LAUNCH_CUSTOM_ATTENTION(13); + LAUNCH_CUSTOM_ATTENTION_MFMA16(13); break; case 14: - LAUNCH_CUSTOM_ATTENTION(14); + LAUNCH_CUSTOM_ATTENTION_MFMA16(14); break; case 15: - LAUNCH_CUSTOM_ATTENTION(15); + LAUNCH_CUSTOM_ATTENTION_MFMA16(15); break; case 16: - LAUNCH_CUSTOM_ATTENTION(16); + LAUNCH_CUSTOM_ATTENTION_MFMA16(16); break; default: TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); break; } - // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); - // dim3 block2(1024); - // LAUNCH_CUSTOM_ATTENTION2; - - // reduction kernel is only required if max_context_len > partition size, - // otherwise main kernel writes directly to final output - // note there are cases with graphing where max_context_len is the max - // supported by graphing, not the actual max among all the sequences: in that - // case reduction kernel will still run but return immediately - if (max_context_len > PARTITION_SIZE) { - dim3 reduce_grid(num_heads, num_seqs); - dim3 reduce_block(head_size); - paged_attention_ll4mi_reduce_kernel - <<>>( - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, - context_lens_ptr, max_num_partitions); + + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE); + // reduction kernel supports upto 8 NPAR_loops * 64 (warp_size) * 256 + // (partition size) = 128K context length + switch (npar_loops) { + case 1: + LAUNCH_CUSTOM_REDUCTION(1); + break; + case 2: + LAUNCH_CUSTOM_REDUCTION(2); + break; + case 3: + LAUNCH_CUSTOM_REDUCTION(3); + break; + case 4: + LAUNCH_CUSTOM_REDUCTION(4); + break; + case 5: + LAUNCH_CUSTOM_REDUCTION(5); + break; + case 6: + LAUNCH_CUSTOM_REDUCTION(6); + break; + case 7: + LAUNCH_CUSTOM_REDUCTION(7); + break; + case 8: + LAUNCH_CUSTOM_REDUCTION(8); + break; + default: + TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); + break; } } -#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, max_context_len, \ - alibi_slopes, k_scale, v_scale); +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ + PSIZE, ALIBI_ENABLED) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes, k_scale, v_scale, fp8_out_scale); + +#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + OUTT, PSIZE) \ + if (alibi_slopes) { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ + true); \ + } else { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ + false); \ + } + +#define CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + OUTT) \ + switch (partition_size) { \ + case 256: \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ + 256); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported partition size: ", partition_size); \ + break; \ + } +#if defined(__HIPCC__) && defined(__gfx90a__) + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + if (fp8_out_scale) { \ + TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \ + } else { \ + CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \ + } +#else + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + if (fp8_out_scale) { \ + CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + uint8_t); \ + } else { \ + CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \ + } +#endif #define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ switch (block_size) { \ case 16: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ break; \ case 32: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -1074,7 +1852,6 @@ void paged_attention_custom_launcher( TORCH_CHECK(false, "Unsupported head size: ", head_size); \ break; \ } - void paged_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] @@ -1092,7 +1869,8 @@ void paged_attention( int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale) { + torch::Tensor& v_scale, const c10::optional& fp8_out_scale, + int64_t partition_size) { const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Half) { @@ -1122,4 +1900,4 @@ void paged_attention( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP \ No newline at end of file +#undef DIVIDE_ROUND_UP diff --git a/csrc/rocm/custom.cu b/csrc/rocm/custom.cu new file mode 100644 index 000000000000..fae1b4fbfbe3 --- /dev/null +++ b/csrc/rocm/custom.cu @@ -0,0 +1,78 @@ +#include +#include +#include + +// declare templates for front (cpp) and back (cuda) sides of function: +// template + +void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block); +void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block) { + auto M = in_a.size(0); + auto K = in_a.size(1); + LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), rows_per_block); +} + +void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block); + +// template +void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block) { + auto M = in_a.size(0); + auto K = in_a.size(1); + // if (N != in_b.numel()) + // throw std::invalid_argument("Size mismatch A.numel(): " + + // std::to_string(in_a.numel()) + // + ", B.numel(): " + + // std::to_string(in_b.numel())); + + // out_c.resize_({N}); + + // call the kernel function... + LLGemm1(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), rows_per_block); +} + +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, + const int N, cudaStream_t stream, const int CuCount); + +void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount) { + auto M = in_a.size(0); + auto K = in_a.size(1); + int N = N_in; + wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, + at::cuda::getCurrentCUDAStream(), CuCount); +} + +void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int solidx); + +void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + const int64_t solidx = 0) { + auto M = in_a.size(0); + auto K = in_a.size(1); + + LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), solidx); +} +// instantiate the CPP template for T=float: +// template void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor +// out_c); + +void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, + int numAColumns, int numBRows, int numBColumns, int numCRows, + int numCColumns, cudaStream_t stream); + +void MMCustomGPU(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c) { + auto matA_sizes{in_a.sizes()}; + auto matB_sizes{in_b.sizes()}; + auto matO_sizes{out_c.sizes()}; + MMGPUKernel(in_a.data_ptr(), in_b.data_ptr(), + out_c.data_ptr(), matA_sizes[0], matA_sizes[1], + matB_sizes[0], matB_sizes[1], matO_sizes[0], matO_sizes[1], + at::cuda::getCurrentCUDAStream()); +} diff --git a/csrc/rocm/custom_kernels.cu b/csrc/rocm/custom_kernels.cu new file mode 100644 index 000000000000..ba90b3f75a07 --- /dev/null +++ b/csrc/rocm/custom_kernels.cu @@ -0,0 +1,1309 @@ +#include +#include +#include +#include +#include "cuda_compat.h" + +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + // auto dat0 = *(addr_alias); + // auto dat1 = *(addr_alias+1); + // auto dat2 = *(addr_alias+2); + // auto dat3 = *(addr_alias+3); + return make_float4(dat0, dat1, dat2, dat3); +} + +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c, + const int K) { + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid / 16; + const int qthreadid = threadid % 16; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; + float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; + float acc[NUM_A_ROWS_PER_BLOCK] = {0.0}; + __half2 acch2; + __half2 oval; + + // As we later use warp shuffle operations, we may have more threads in the + // block than the actual available data, hence the if guard here. + if (threadid * 8 < K) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // rowA_elem4[i] holds 8 * half numbers seen as a single float4. + rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]); + } + } + + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; + + __half2 Af2; + __half2 Bf2; + float2 S; + + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); + __half2* ah2lptr; + +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // Multiply-add on 8 half. + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __half22float2(acch2); + + // See comment above concerning the if guard. + if (threadid * 8 < K) { + acc[i] = S.x + S.y; // accumulation on float + } + } + +// all reduce across warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. + if (lane < NUM_A_ROWS_PER_BLOCK) { + red_smem[lane][warp] = acc[lane]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + if (qwarpid < NUM_A_ROWS_PER_BLOCK) { + acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f; +#pragma unroll + for (int mask = 16 / 2; mask >= 1; mask /= 2) { + acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); + } + float oval2 = __shfl_xor(acc[qwarpid], 16); + + if (threadid % WARP_SIZE == 0 or threadid % WARP_SIZE == 32) { + oval = __float22half2_rn(make_float2(acc[qwarpid], oval2)); + c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; + } + } +} + +// define the kernel calling code: +// template +void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block = 4) { + float4* af4 = reinterpret_cast(in_a); + auto* bf4 = reinterpret_cast<__half2*>(in_b); + auto* c = reinterpret_cast<__half2*>(out_c); + + // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle + // operations. + const int NUM_THREADS = + K * 2 / 16 % WARP_SIZE == 0 + ? K * 2 / 16 + : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE); + + int NUM_BLOCKS = M / rows_per_block; + + if (rows_per_block == 2) { + LLGemm1_kernel<2><<>>(af4, bf4, c, K); + } else if (rows_per_block == 4) { + LLGemm1_kernel<4><<>>(af4, bf4, c, K); + } else if (rows_per_block == 8) { + LLGemm1_kernel<8><<>>(af4, bf4, c, K); + } else if (rows_per_block == 16) { + LLGemm1_kernel<16><<>>(af4, bf4, c, K); + } else { + NUM_BLOCKS = M / 4; + LLGemm1_kernel<4><<>>(af4, bf4, c, K); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + +// instantiate the kernel template for T=float: +// template void AddGPUKernel(float *in_a, float *in_b, float *out_c, +// const int M, const int K, cudaStream_t stream); + +const unsigned int TILE_WIDTH = 32; + +// Compute C = A * B +__global__ void matrixMultiplyShared(float* A, float* B, float* C, int numARows, + int numAColumns, int numBRows, + int numBColumns, int numCRows, + int numCColumns) { + __shared__ float sA[TILE_WIDTH][TILE_WIDTH]; // Tile size of 32x32 + __shared__ float sB[TILE_WIDTH][TILE_WIDTH]; + + int Row = blockDim.y * blockIdx.y + threadIdx.y; + int Col = blockDim.x * blockIdx.x + threadIdx.x; + float Cvalue = 0.0; + sA[threadIdx.y][threadIdx.x] = 0.0; + sB[threadIdx.y][threadIdx.x] = 0.0; + + for (int ph = 0; ph < (((numAColumns - 1) / TILE_WIDTH) + 1); ph++) { + if ((Row < numARows) && (threadIdx.x + (ph * TILE_WIDTH)) < numAColumns) { + sA[threadIdx.y][threadIdx.x] = + A[(Row * numAColumns) + threadIdx.x + (ph * TILE_WIDTH)]; + } else { + sA[threadIdx.y][threadIdx.x] = 0.0; + } + if (Col < numBColumns && (threadIdx.y + ph * TILE_WIDTH) < numBRows) { + sB[threadIdx.y][threadIdx.x] = + B[(threadIdx.y + ph * TILE_WIDTH) * numBColumns + Col]; + } else { + sB[threadIdx.y][threadIdx.x] = 0.0; + } + __syncthreads(); + for (int j = 0; j < TILE_WIDTH; ++j) { + Cvalue += sA[threadIdx.y][j] * sB[j][threadIdx.x]; + } + } + if (Row < numCRows && Col < numCColumns) { + C[Row * numCColumns + Col] = Cvalue; + } +} + +void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, + int numAColumns, int numBRows, int numBColumns, int numCRows, + int numCColumns, cudaStream_t stream) { + // Initialize the grid and block dimensions + dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1); + dim3 dimGrid((numCColumns / TILE_WIDTH) + 1, (numCRows / TILE_WIDTH) + 1, 1); + //@@ Launch the GPU Kernel here + matrixMultiplyShared<<>>( + in_a, in_b, out_c, numARows, numAColumns, numBRows, numBColumns, numCRows, + numCColumns); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + +template +__global__ __launch_bounds__(512) void HGEMV_WFPerRow( + int m, int n, const _Float16* A, int lda, const _Float16* x, _Float16* y) { + int num_row_per_block = CTA / nThreads_per_row; + int row_id = (blockIdx.x * num_row_per_block + threadIdx.y) * MT0; + int inc = (gridDim.x * num_row_per_block) * MT0; + + while (row_id < m) { + float2 sum2[MT0]; + +#pragma unroll + for (int i = 0; i < MT0; ++i) { + sum2[i] = {0.0, 0.0}; + } + + for (int j = threadIdx.x; j < n; j += (nThreads_per_row * MT1)) { + bool is_active = j < n; + if (is_active) { + float2 x2[MT1 >> 1]; +#pragma unroll + for (int offset = 0; offset < MT1; offset += 2) { + x2[offset >> 1] = {x[j + nThreads_per_row * offset], + x[j + nThreads_per_row * (offset + 1)]}; + } + float2 a2[MT0][MT1 >> 1]; +#pragma unroll + for (int i = 0; i < MT0; i++) { +#pragma unroll + for (int offset = 0; offset < MT1; offset += 2) { + a2[i][offset >> 1] = { + A[(row_id + i) * n + j + nThreads_per_row * offset], + A[(row_id + i) * n + j + nThreads_per_row * (offset + 1)]}; + } + } + +#pragma unroll + for (int i = 0; i < MT0; i++) { +#pragma unroll + for (int offset = 0; offset < (MT1 >> 1); offset++) { + sum2[i] += a2[i][offset] * x2[offset]; + } + } + } + } + float sum[MT0]; +#pragma unroll + for (int i = 0; i < MT0; i++) { + sum[i] = sum2[i].x + sum2[i].y; + } + +#pragma unroll + for (int i = 0; i < MT0; i++) { +#pragma unroll + for (int offset = nThreads_per_row >> 1; offset >= 1; + offset = offset >> 1) { + sum[i] += __shfl_down(sum[i], offset, nThreads_per_row); + } + } + if (threadIdx.x == 0) { +#pragma unroll + for (int i = 0; i < MT0; i++) { + y[row_id + i] = sum[i]; + } + } + row_id += inc; + } +} + +void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int solidx = 0) { + // m -> M, n-> K + dim3 grid(1024); + dim3 block(64, 8); + if (solidx == 0) { + HGEMV_WFPerRow<64, 512, 4, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else if (solidx == 1) { + HGEMV_WFPerRow<64, 512, 2, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else if (solidx == 2) { + HGEMV_WFPerRow<64, 512, 1, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else { + HGEMV_WFPerRow<64, 512, 4, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + +///////////////////////////////////////////// + +#define DTYPE half + +__device__ __forceinline__ int mindiv(int N, int div1, int div2) { + int nPrRnd = div1 * div2; + int rnds0 = N / nPrRnd; + nPrRnd -= div1 * 3; + int rnds3 = N / nPrRnd; + nPrRnd -= div1; + int rnds4 = N / nPrRnd; + nPrRnd -= div1; + int rnds5 = N / nPrRnd; + nPrRnd -= div1; + int rnds6 = N / nPrRnd; + nPrRnd -= div1; + int rnds7 = N / nPrRnd; + nPrRnd -= div1; + int rnds8 = N / nPrRnd; + nPrRnd -= div1; + int rnds9 = N / nPrRnd; + nPrRnd -= div1; + int rtn = div2; + if (rnds0 == rnds3) rtn = div2 - 3; + if (rnds0 == rnds4) rtn = div2 - 4; + if (rnds0 == rnds5) rtn = div2 - 5; + if (rnds0 == rnds6) rtn = div2 - 6; + if (rnds0 == rnds7) rtn = div2 - 7; + if (rnds0 == rnds8) rtn = div2 - 8; + if (rnds0 == rnds9) rtn = div2 - 9; + return rtn; +} + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] fits LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + // uint32_t commitColumn[YTILE]; + // for (uint32_t i = 0; i < YTILE; i++) { + // commitColumn[i] = 1; + //} + + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + // if (n < N && (n + YTILE) >= N) { + // uint32_t startColumn = N - YTILE; + // for (uint32_t i = 0; i < (n - startColumn); i++) { + // commitColumn[i] = 0; + // } + // n = startColumn; + //} + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + // if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + // else + // bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + // if (n < N && (n + YTILE) >= N) { + // uint32_t startColumn = N - YTILE; + // for (uint32_t i = 0; i < (n - startColumn); i++) { + // commitColumn[i] = 0; + // } + // n = startColumn; + //} + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] marginally exceeds LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + bigType bigB8[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets big A[] cases, where it is much larger than LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + if (threadIdx.y >= _WvPrGrp) return; + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + #define PCML + #ifndef PCML + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + #endif + + #define TUC (THRDS * UNRL * A_CHUNK) + uint32_t kBase = 0; + // find biggest k size that fits in LDS + uint32_t kFit = (32 * 1024) / M; + // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple + // of TUC + kFit = (kFit % TUC == 0) + ? kFit + : (kFit - kFit % TUC); // round up to multiple of TUC + // if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + #ifdef PCML + int YW = (YTILE * _WvPrGrp); + uint32_t Nrndp = (N % YW == 0) ? N : (N - N % YW + YW); + while (n < Nrndp) { + #else + while (n < N) { + #endif + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + bigType bigB8[UNRL]; + bigType bigB9[UNRL]; + bigType bigB10[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + #ifdef PCML + if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS + if (k1 != 0) kBase += kFit; + __syncthreads(); + for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { + uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (kBase + kOff >= K) break; + if (kOff >= kFit) break; + for (uint32_t m = 0; m < M; m++) { + uint32_t k_in = kBase + m * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + if (n >= N) continue; + #endif + + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + #ifdef PCML + bigA[m][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * m]))); + #else + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + #endif + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + #ifdef PCML + if (n >= N) { + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; + continue; + } + #endif + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, + const int K_in, const int N_in, cudaStream_t stream, + const int CuCount = 0) { + dim3 grid(CuCount); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + +#define WVSPLTK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + /*wvSpltK_hf:*/ \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } else { \ + wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } \ + } + + switch (N_in) { + case 1: + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + break; + case 2: + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + break; + case 3: + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + break; + case 4: + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + + "," + std::to_string(K_in) + "," + + std::to_string(N_in)); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + } +} \ No newline at end of file diff --git a/csrc/rocm/fused_kernels.cu b/csrc/rocm/fused_kernels.cu new file mode 100644 index 000000000000..4f3eea456294 --- /dev/null +++ b/csrc/rocm/fused_kernels.cu @@ -0,0 +1,195 @@ +#include +#include +#include +#include + +constexpr int WARP_SIZE = 64; + +template +__device__ __forceinline__ T silu(const T& x) { + // x * sigmoid(x) + return (T)(((float)x) / (1.0f + expf((float)-x))); +} + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + // auto dat0 = *(addr_alias); + // auto dat1 = *(addr_alias+1); + // auto dat2 = *(addr_alias+2); + // auto dat3 = *(addr_alias+3); + return make_float4(dat0, dat1, dat2, dat3); +} + +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm_Silu_kernel(float4* af4, __half2* bf4, _Float16* c, + const int d) { + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 * blockDim.x; + const int row_addr_d = row_addr + d * blockDim.x; + // int row_addr_1 = row_addr + CUDA_NUM_THREADS; + // int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS; + // int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid / 16; + const int qthreadid = threadid % 16; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + // float4 colB_elem4; + __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; + float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; + float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; + __half2 acch2; + __half2 oval; + + // rowA_elem4 = af4[row_addr + threadid]; + //__syncthreads(); + // rowA_elem4_1 = af4[row_addr_1 + threadid]; + // rowA_elem4_2 = af4[row_addr_2 + threadid]; + // rowA_elem4_3 = af4[row_addr_3 + threadid]; +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK / 2; i++) { + rowA_elem4[2 * i] = load_ntmprl(&af4[row_addr + i * blockDim.x + threadid]); + rowA_elem4[2 * i + 1] = + load_ntmprl(&af4[row_addr_d + i * blockDim.x + threadid]); + // rowA_elem4[i] = af4[row_addr + i*blockDim.x + threadid]; + //__syncthreads(); + } + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; + + // __syncthreads(); + __half2 Af2; + __half2 Bf2; + float2 S; + // auto Bh2ptr = reinterpret_cast<__half2 *>(&colB_elem4); + // auto Bf2x = *Bh2ptr; + // auto Bf2y = *(Bh2ptr+1); + // auto Bf2z = *(Bh2ptr+2); + // auto Bf2w = *(Bh2ptr+3); + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); + __half2* ah2lptr; +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __half22float2(acch2); + acc[i] = S.x + S.y; + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. + // if (lane == 0) { + // #pragma unroll + // for (int i=0; i= 1; mask /= 2) { + // #pragma unroll + // for (int i=0; i +void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block = 4) { + float4* af4 = reinterpret_cast(in_a); + auto* bf4 = reinterpret_cast<__half2*>(in_b); + auto* c = reinterpret_cast<_Float16*>(out_c); + const int d = M / 2; + const int NUM_THREADS = K * 2 / 16; + int NUM_BLOCKS = M / rows_per_block; + if (rows_per_block == 2) { + LLGemm_Silu_kernel<2> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 4) { + LLGemm_Silu_kernel<4> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 8) { + LLGemm_Silu_kernel<8> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 16) { + LLGemm_Silu_kernel<16> + <<>>(af4, bf4, c, d); + } else { + NUM_BLOCKS = M / 4; + LLGemm_Silu_kernel<4> + <<>>(af4, bf4, c, d); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index ba161951772a..59bd28e3bc12 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -2,6 +2,15 @@ #include +void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block); + +void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block); + +void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount); + void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, @@ -11,4 +20,6 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, int64_t max_context_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale); + torch::Tensor& v_scale, + const c10::optional& fp8_out_scale, + int64_t partition_size); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index a5d2e2f97a3e..50640a96725e 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -13,6 +13,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { // vLLM custom ops for rocm + rocm_ops.def( + "LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> " + "()"); + rocm_ops.impl("LLMM1", torch::kCUDA, &LLMM1); + rocm_ops.def( + "LLMM_Silu(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) " + "-> ()"); + rocm_ops.impl("LLMM_Silu", torch::kCUDA, &LLMM_Silu); // Custom attention op // Compute the attention between an input query and the cached @@ -27,8 +35,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype," - " Tensor k_scale, Tensor v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale," + " Tensor? fp8_out_scale," + " int partition_size) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); + rocm_ops.def( + "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," + " int CuCount) -> ()"); + rocm_ops.impl("wvSpltK", torch::kCUDA, &wvSpltK); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 186e9c0e81b7..235373240ac3 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -33,7 +33,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," - " int blocksparse_head_sliding_step) -> ()"); + " int blocksparse_head_sliding_step," + " int num_threads) -> ()"); ops.impl("paged_attention_v1", torch::kCUDA, &paged_attention_v1); // PagedAttention V2. @@ -47,7 +48,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," - " int blocksparse_head_sliding_step) -> ()"); + " int blocksparse_head_sliding_step," + " int num_threads) -> ()"); ops.impl("paged_attention_v2", torch::kCUDA, &paged_attention_v2); // Activation ops @@ -55,6 +57,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + // Activation function used in SwiGLU. + ops.def("scaled_silu_and_mul(Tensor! out, Tensor input, Tensor scale) -> ()"); + ops.impl("scaled_silu_and_mul", torch::kCUDA, &scaled_silu_and_mul); + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); @@ -500,25 +506,39 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { &get_max_shared_memory_per_block_device_attribute); } -#ifndef USE_ROCM TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) { // Custom all-reduce kernels custom_ar.def( - "init_custom_ar(int[] ipc_tensors, Tensor rank_data, " - "int rank, bool full_nvlink) -> int"); + "init_custom_ar(Tensor meta, Tensor rank_data, " + "str[] handles, int[] offsets, int rank, " + "bool full_nvlink) -> int"); custom_ar.impl("init_custom_ar", torch::kCUDA, &init_custom_ar); + + custom_ar.def("all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); + custom_ar.impl("all_reduce_reg", torch::kCUDA, &all_reduce_reg); + custom_ar.def( - "all_reduce(int fa, Tensor inp, Tensor! out, int reg_buffer, " - "int reg_buffer_sz_bytes) -> ()"); - custom_ar.impl("all_reduce", torch::kCUDA, &all_reduce); + "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> " + "()"); + custom_ar.impl("all_reduce_unreg", torch::kCUDA, &all_reduce_unreg); custom_ar.def("dispose", &dispose); custom_ar.def("meta_size", &meta_size); - custom_ar.def("register_buffer", ®ister_buffer); + custom_ar.def( + "register_buffer(int fa, Tensor t, str[] handles, " + "int[] offsets) -> ()"); + custom_ar.impl("register_buffer", torch::kCUDA, ®ister_buffer); + custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); custom_ar.def("register_graph_buffers", ®ister_graph_buffers); -} +#ifdef USE_ROCM + custom_ar.def("allocate_meta_buffer", &allocate_meta_buffer); + custom_ar.impl("allocate_meta_buffer", torch::kCUDA, &allocate_meta_buffer); + custom_ar.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle); + custom_ar.impl("get_meta_buffer_ipc_handle", torch::kCPU, + &get_meta_buffer_ipc_handle); #endif +} REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/type_convert.cuh b/csrc/type_convert.cuh index 21b9d0ae515d..47b3a767d355 100644 --- a/csrc/type_convert.cuh +++ b/csrc/type_convert.cuh @@ -49,7 +49,7 @@ struct _typeConvert { } }; - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + #if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) // CUDA_ARCH < 800 does not have BF16 support // TODO: Add in ROCm support once public headers handle bf16 maturely template <> @@ -162,4 +162,4 @@ struct alignas(16) _f16Vec { return result; } }; -} // namespace vllm \ No newline at end of file +} // namespace vllm diff --git a/docs/dev-docker/README.md b/docs/dev-docker/README.md new file mode 100644 index 000000000000..1ce6da2da95d --- /dev/null +++ b/docs/dev-docker/README.md @@ -0,0 +1,491 @@ +# vllm FP8 Latency and Throughput benchmarks with vLLM on the AMD Instinctā„¢ MI300X accelerator + +Documentation for Inferencing with vLLM on AMD Instinctā„¢ MI300X platforms. + +## Overview + +vLLM is a toolkit and library for large language model (LLM) inference and serving. It deploys the PagedAttention algorithm, which reduces memory consumption and increases throughput by leveraging dynamic key and value allocation in GPU memory. vLLM also incorporates many recent LLM acceleration and quantization algorithms, such as fp8 GeMM, fp8 KV cache, continuous batching, flash attention, hip graph, tensor parallel, GPTQ, AWQ, and token speculation. In addition, AMD implements high-performance custom kernels and modules in vLLM to enhance performance further. + +This documentation includes information for running the popular Llama 3.1 series models from Meta using a pre-built AMD vLLM docker image optimized for an AMD Instinctā„¢ MI300X or MI325X accelerator. The container is publicly available at [AMD Infinity Hub](https://www.amd.com/en/developer/resources/infinity-hub.html) + +The pre-built image includes: + +- ROCmā„¢ 6.3.1 +- vLLM 0.6.6 +- PyTorch 2.6dev (nightly) + +## Pull latest Docker Image + +Pull the most recent validated docker image with `docker pull rocm/vllm-dev:main` + +## What is New + +20250124: +- Fix accuracy issue with 405B FP8 Triton FA +- Fixed accuracy issue with TP8 +20250117: +- [Experimental DeepSeek-V3 and DeepSeek-R1 support](#running-deepseek-v3-and-deepseek-r1) + +## Performance Results + +The data in the following tables is a reference point to help users validate observed performance. It should not be considered as the peak performance that can be delivered by AMD Instinctā„¢ MI300X accelerator with vLLM. See the MLPerf section in this document for information about MLPerf 4.1 inference results. The performance numbers above were collected using the steps below. + +### Throughput Measurements + +The table below shows performance data where a local inference client is fed requests at an infinite rate and shows the throughput client-server scenario under maximum load. + +| Model | Precision | TP Size | Input | Output | Num Prompts | Max Num Seqs | Throughput (tokens/s) | +|-------|-----------|---------|-------|--------|-------------|--------------|-----------------------| +| Llama 3.1 70B (amd/Llama-3.1-70B-Instruct-FP8-KV) | FP8 | 8 | 128 | 2048 | 3200 | 3200 | 15105 | +| | | | 128 | 4096 | 1500 | 1500 | 10505 | +| | | | 500 | 2000 | 2000 | 2000 | 12664 | +| | | | 2048 | 2048 | 1500 | 1500 | 8239 | +| Llama 3.1 405B (amd/Llama-3.1-405B-Instruct-FP8-KV) | FP8 | 8 | 128 | 2048 | 1500 | 1500 | 4065 | +| | | | 128 | 4096 | 1500 | 1500 | 3171 | +| | | | 500 | 2000 | 2000 | 2000 | 2985 | +| | | | 2048 | 2048 | 500 | 500 | 1999 | + +*TP stands for Tensor Parallelism.* + +## Latency Measurements + +The table below shows latency measurement, which typically involves assessing the time from when the system receives an input to when the model produces a result. + +| Model | Precision | TP Size | Batch Size | Input | Output | MI300X Latency (ms) | +|-------|-----------|----------|------------|--------|---------|-------------------| +| Llama 3.1 70B (amd/Llama-3.1-70B-Instruct-FP8-KV) | FP8 | 8 | 1 | 128 | 2048 | 19088.59 | +| | | | 2 | 128 | 2048 | 19610.46 | +| | | | 4 | 128 | 2048 | 19911.30 | +| | | | 8 | 128 | 2048 | 21858.80 | +| | | | 16 | 128 | 2048 | 23537.59 | +| | | | 32 | 128 | 2048 | 25342.94 | +| | | | 64 | 128 | 2048 | 32548.19 | +| | | | 128 | 128 | 2048 | 45216.37 | +| | | | 1 | 2048 | 2048 | 19154.43 | +| | | | 2 | 2048 | 2048 | 19670.60 | +| | | | 4 | 2048 | 2048 | 19976.32 | +| | | | 8 | 2048 | 2048 | 22485.63 | +| | | | 16 | 2048 | 2048 | 25246.27 | +| | | | 32 | 2048 | 2048 | 28967.08 | +| | | | 64 | 2048 | 2048 | 39920.41 | +| | | | 128 | 2048 | 2048 | 59514.25 | +| Llama 3.1 405B (amd/Llama-3.1-70B-Instruct-FP8-KV) | FP8 | 8 | 1 | 128 | 2048 | 51739.70 | +| | | | 2 | 128 | 2048 | 52769.15 | +| | | | 4 | 128 | 2048 | 54557.07 | +| | | | 8 | 128 | 2048 | 56901.86 | +| | | | 16 | 128 | 2048 | 60432.12 | +| | | | 32 | 128 | 2048 | 67353.01 | +| | | | 64 | 128 | 2048 | 81085.33 | +| | | | 128 | 128 | 2048 | 116138.51 | +| | | | 1 | 2048 | 2048 | 52217.76 | +| | | | 2 | 2048 | 2048 | 53227.47 | +| | | | 4 | 2048 | 2048 | 55512.44 | +| | | | 8 | 2048 | 2048 | 59931.41 | +| | | | 16 | 2048 | 2048 | 66890.14 | +| | | | 32 | 2048 | 2048 | 80687.64 | +| | | | 64 | 2048 | 2048 | 108503.12 | +| | | | 128 | 2048 | 2048 | 168845.50 | + +*TP stands for Tensor Parallelism.* + +## Reproducing Benchmarked Results + +### Preparation - Obtaining access to models + +The vllm-dev docker image should work with any model supported by vLLM. When running with FP8, AMD has quantized models available for a variety of popular models, or you can quantize models yourself using Quark. If needed, the vLLM benchmark scripts will automatically download models and then store them in a Hugging Face cache directory for reuse in future tests. Alternatively, you can choose to download the model to the cache (or to another directory on the system) in advance. + +Many HuggingFace models, including Llama-3.1, have gated access. You will need to set up an account at (https://huggingface.co), search for the model of interest, and request access if necessary. You will also need to create a token for accessing these models from vLLM: open your user profile (https://huggingface.co/settings/profile), select "Access Tokens", press "+ Create New Token", and create a new Read token. + +### System optimization + +Before running performance tests you should ensure the system is optimized according to the [ROCm Documentation](https://rocm.docs.amd.com/en/latest/how-to/system-optimization/mi300x.html). In particular, it is important to ensure that NUMA auto-balancing is disabled. + +*Note: Check that NUMA balancing is properly set by inspecting the output of the command below, which should have a value of 0, with, `cat /proc/sys/kernel/numa_balancing`* + +### Launch AMD vLLM Docker + +Download and launch the docker. The HF_TOKEN is required to be set (either here or after launching the container) if you want to allow vLLM to download gated models automatically; use your HuggingFace token in place of `` in the command below: + +```bash +docker run -it --rm --ipc=host --network=host --group-add render \ + --privileged --security-opt seccomp=unconfined \ + --cap-add=CAP_SYS_ADMIN --cap-add=SYS_PTRACE \ + --device=/dev/kfd --device=/dev/dri --device=/dev/mem \ + -e HF_HOME=/data \ + -e HF_TOKEN= \ + -v /data:/data \ + rocm/vllm-dev:main +``` + +Note: The instructions in this document use `/data` to store the models. If you choose a different directory, you will also need to make that change to the host volume mount when launching the docker container. For example, `-v /home/username/models:/data` in place of `-v /data:/data` would store the models in /home/username/models on the host. Some models can be quite large; please ensure that you have sufficient disk space prior to downloading the model. Since the model download may take a long time, you can use `tmux` or `screen` to avoid getting disconnected. + +### Downloading models with huggingface-cli + +If you would like want to download models directly (instead of allowing vLLM to download them automatically), you can use the huggingface-cli inside the running docker container. (remove an extra white space) Login using the token that you created earlier. (Note, it is not necessary to save it as a git credential.) + +```bash +huggingface-cli login +``` + +You can download a model to the huggingface-cache directory using a command similar to the following (substituting the name of the model you wish to download): + +```bash +sudo mkdir -p /data/huggingface-cache +sudo chmod -R a+w /data/huggingface-cache +HF_HOME=/data/huggingface-cache huggingface-cli download meta-llama/Llama-3.1-405B-Instruct --exclude "original/*" +``` + +Alternatively, you may wish to download the model to a specific directory, e.g. so you can quantize the model with Quark: + +```bash +sudo mkdir -p /data/llama-3.1 +sudo chmod -R a+w /data/llama-3.1 +huggingface-cli download meta-llama/Llama-3.1-405B-Instruct --exclude "original/*" --local-dir /data/llama-3.1/Llama-3.1-405B-Instruct +``` + +In the benchmark commands provided later in this document, replace the model name (e.g. `amd/Llama-3.1-405B-Instruct-FP8-KV`) with the path to the model (e.g. `/data/llama-3.1/Llama-3.1-405B-Instruct`) + +### Use pre-quantized models + +AMD has provided [FP8-quantized versions](https://huggingface.co/collections/amd/quark-quantized-ocp-fp8-models-66db7936d18fcbaf95d4405c) of several models in order to make them easier to run on MI300X / MI325X, including: + +- +- +- + +Some models may be private to those who are members of . + +These FP8 quantized checkpoints were generated with AMD’s Quark Quantizer. For more information about Quark, please refer to + +### Quantize your own models + +This is an optional step if you would like to quantize your own model instead of using AMD's pre-quantized models. These instructions use Llama-3.1-405B as an example, but the commands are similar for other models. + +First download the model from to the /data/llama-3.1 directory as described above. + +[Download and install Quark](https://quark.docs.amd.com/latest/install.html) + +Run the quantization script in the example folder using the following command line: + +```bash +# path to quark quantization script +export QUARK_DIR=/data/quark-0.6.0+dba9ca364/examples/torch/language_modeling/llm_ptq/quantize_quark.py +# path to Model +export MODEL_DIR=/data/llama-3.1/Llama-3.1-405B-Instruct +python3 $QUARK_DIR \ +--model_dir $MODEL_DIR \ +--output_dir Llama-3.1-405B-Instruct-FP8-KV \ +--kv_cache_dtype fp8 \ +--quant_scheme w_fp8_a_fp8 \ +--num_calib_data 128 \ +--model_export quark_safetensors \ +--no_weight_matrix_merge \ +--multi_gpu +``` + +Note: the `--multi_gpu` parameter can be omitted for small models that fit on a single GPU. + +## Performance testing with AMD vLLM Docker + +### Performance environment variables + +Some environment variables enhance the performance of the vLLM kernels on the MI300X / MI325X accelerator. See the AMD Instinct MI300X workload optimization guide for more information. + +```bash +export VLLM_USE_TRITON_FLASH_ATTN=0 +``` + +### vLLM engine performance settings + +vLLM provides a number of engine options which can be changed to improve performance. Refer to the [vLLM Engine Args](https://docs.vllm.ai/en/stable/usage/engine_args.html) documentation for the complete list of vLLM engine options. + +Below is a list of a few of the key vLLM engine arguments for performance; these can be passed to the vLLM benchmark scripts: +- **--max-model-len** : Maximum context length supported by the model instance. Can be set to a lower value than model configuration value to improve performance and gpu memory utilization. +- **--max-num-batched-tokens** : The maximum prefill size, i.e., how many prompt tokens can be packed together in a single prefill. Set to a higher value to improve prefill performance at the cost of higher gpu memory utilization. 65536 works well for LLama models. +- **--max-num-seqs** : The maximum decode batch size (default 256). Using larger values will allow more prompts to be processed concurrently, resulting in increased throughput (possibly at the expense of higher latency). If the value is too large, there may not be enough GPU memory for the KV cache, resulting in requests getting preempted. The optimal value will depend on the GPU memory, model size, and maximum context length. +- **--max-seq-len-to-capture** : Maximum sequence length for which Hip-graphs are captured and utilized. It's recommended to use Hip-graphs for the best decode performance. The default value of this parameter is 8K, which is lower than the large context lengths supported by recent models such as LLama. Set this parameter to max-model-len or maximum context length supported by the model for best performance. +- **--gpu-memory-utilization** : The ratio of GPU memory reserved by a vLLM instance. Default value is 0.9. Increasing the value (potentially as high as 0.99) will increase the amount of memory available for KV cache. When running in graph mode (i.e. not using `--enforce-eager`), it may be necessary to use a slightly smaller value of 0.92 - 0.95 to ensure adequate memory is available for the HIP graph. + +### Latency Benchmark + +vLLM's benchmark_latency.py script measures end-to-end latency for a specified model, input/output length, and batch size. + +You can run latency tests for FP8 models with: + +```bash +export VLLM_USE_TRITON_FLASH_ATTN=0 +MODEL=amd/Llama-3.1-405B-Instruct-FP8-KV +BS=1 +IN=128 +OUT=2048 +TP=8 + +python3 /app/vllm/benchmarks/benchmark_latency.py \ + --distributed-executor-backend mp \ + --quantization fp8 \ + --kv-cache-dtype fp8 \ + --dtype float16 \ + --gpu-memory-utilization 0.9 \ + --trust-remote-code \ + --model $MODEL \ + --batch-size $BS \ + --input-len $IN \ + --output-len $OUT \ + --tensor-parallel-size $TP \ + --num-iters-warmup 3 \ + --num-iters 5 \ + --output-json output.json +``` + +For FP16 models, remove `--quantization fp8 --kv-cache-dtype fp8`. + +When measuring models with long context lengths, performance may improve by setting `--max-model-len` to a smaller value. It is important, however, to ensure that the `--max-model-len` is at least as large as the IN + OUT token counts. + +To estimate Time To First Token (TTFT) with the benchmark_latency.py script, set the OUT to 1 token. It is also recommended to use `--enforce-eager` to get a more accurate measurement of the time that it actually takes to generate the first token. (For a more comprehensive measurement of TTFT, use the Online Serving Benchmark.) + +For additional information about the available parameters run: + +```bash +/app/vllm/benchmarks/benchmark_latency.py -h +``` + +### Throughput Benchmark + +vLLM's benchmark_throughput.py script measures offline throughput. It can either use an input dataset or random prompts with fixed input/output lengths. + +You can run latency tests for FP8 models with: + +```bash +export VLLM_USE_TRITON_FLASH_ATTN=0 +MODEL=amd/Llama-3.1-405B-Instruct-FP8-KV +IN=128 +OUT=2048 +TP=8 +PROMPTS=1500 +MAX_NUM_SEQS=1500 + +python3 /app/vllm/benchmarks/benchmark_throughput.py \ + --distributed-executor-backend mp \ + --quantization fp8 \ + --kv-cache-dtype fp8 \ + --dtype float16 \ + --gpu-memory-utilization 0.9 \ + --trust-remote-code \ + --num-scheduler-steps 10 \ + --enable-chunked-prefill False \ + --model $MODEL \ + --max-model-len 8192 \ + --max-num-batched-tokens 131072 \ + --max-seq-len-to-capture 131072 \ + --input-len $IN \ + --output-len $OUT \ + --tensor-parallel-size $TP \ + --num-prompts $PROMPTS \ + --max-num-seqs $MAX_NUM_SEQS \ + --output-json output.json +``` + +For FP16 models, remove `--quantization fp8 --kv-cache-dtype fp8`. + +When measuring models with long context lengths, performance may improve by setting `--max-model-len` to a smaller value (8192 in this example). It is important, however, to ensure that the `--max-model-len` is at least as large as the IN + OUT token counts. + +It is important to tune vLLM’s --max-num-seqs value to an appropriate value depending on the model and input/output lengths. Larger values will allow vLLM to leverage more of the GPU memory for KV Cache and process more prompts concurrently. But if the value is too large, the KV cache will reach its capacity and vLLM will have to cancel and re-process some prompts. Suggested values for various models and configurations are listed below. + +For models that fit on a single GPU, it is usually best to run with `--tensor-parallel-size 1`. Requests can be distributed across multiple copies of vLLM running on different GPUs. This will be more efficient than running a single copy of the model with `--tensor-parallel-size 8`. (Note: the benchmark_throughput.py script does not include direct support for using multiple copies of vLLM) + +For optimal performance, the PROMPTS value should be a multiple of the MAX_NUM_SEQS value -- for example, if MAX_NUM_SEQS=1500 then the PROMPTS value could be 1500, 3000, etc. If PROMPTS is smaller than MAX_NUM_SEQS then there won’t be enough prompts for vLLM to maximize concurrency. + +For additional information about the available parameters run: + +```bash +python3 /app/vllm/benchmarks/benchmark_throughput.py -h +``` + +### Online Serving Benchmark + +Benchmark Llama-3.1-70B with input 4096 tokens, output 512 tokens and tensor parallelism 8 as an example, + +```bash +export VLLM_USE_TRITON_FLASH_ATTN=0 +vllm serve amd/Llama-3.1-70B-Instruct-FP8-KV \ + --swap-space 16 \ + --disable-log-requests \ + --quantization fp8 \ + --kv-cache-dtype fp8 \ + --dtype float16 \ + --max-model-len 8192 \ + --tensor-parallel-size 8 \ + --max-num-batched-tokens 65536 \ + --gpu-memory-utilization 0.99 \ + --num_scheduler-steps 10 +``` + +Change port (for example --port 8005) if port=8000 is currently being used by other processes. + +Run client in a separate terminal. Use port_id from previous step else port-id=8000. + +```bash +python /app/vllm/benchmarks/benchmark_serving.py \ + --port 8000 \ + --model amd/Llama-3.1-70B-Instruct-FP8-KV \ + --dataset-name random \ + --random-input-len 4096 \ + --random-output-len 512 \ + --request-rate 1 \ + --ignore-eos \ + --num-prompts 500 \ + --percentile-metrics ttft,tpot,itl,e2el +``` + +Once all prompts are processed, terminate the server gracefully (ctrl+c). + +### Running DeepSeek-V3 and DeepSeek-R1 + +We have experimental support for running both DeepSeek-V3 and DeepSeek-R1 models. +*Note there are currently limitations and `--max-model-len` cannot be greater than 32768* + +```bash +docker run -it --rm --ipc=host --network=host --group-add render \ + --privileged --security-opt seccomp=unconfined \ + --cap-add=CAP_SYS_ADMIN --cap-add=SYS_PTRACE \ + --device=/dev/kfd --device=/dev/dri --device=/dev/mem \ + -e VLLM_USE_TRITON_FLASH_ATTN=0 \ + -e VLLM_FP8_PADDING=0 \ + rocm/vllm-dev:main +# Online serving +vllm serve deepseek-ai/DeepSeek-V3 \ + --disable-log-requests \ + --tensor-parallel-size 8 \ + --trust-remote-code \ + --max-model-len 32768 + +python3 /app/vllm/benchmarks/benchmark_serving.py \ + --backend vllm \ + --model deepseek-ai/DeepSeek-V3 \ + --max-concurrency 256\ + --dataset-name random \ + --random-input-len 128 \ + --random-output-len 128 \ + --num-prompts 1000 + +# Offline throughput +python3 /app/vllm/benchmarks/benchmark_throughput.py --model deepseek-ai/DeepSeek-V3 \ + --input-len <> --output-len <> --tensor-parallel-size 8 \ + --quantization fp8 --kv-cache-dtype fp8 --dtype float16 \ + --max-model-len 32768 --trust-remote-code +# Offline Latency +python benchmarks/benchmark_latency.py --model deepseek-ai/DeepSeek-V3 \ +--tensor-parallel-size 8 --trust-remote-code --max-model-len 32768 \ +--batch-size <> --input-len <> --output-len <> +``` + +### CPX mode + +Currently only CPX-NPS1 mode is supported. So ONLY tp=1 is supported in CPX mode. +But multiple instances can be started simultaneously (if needed) in CPX-NPS1 mode. + +Set GPUs in CPX mode with: + +```bash +rocm-smi --setcomputepartition cpx +``` + +Example of running Llama3.1-8B on 1 CPX-NPS1 GPU with input 4096 and output 512. As mentioned above, tp=1. + +```bash +HIP_VISIBLE_DEVICES=0 \ +python3 /app/vllm/benchmarks/benchmark_throughput.py \ + --max-model-len 4608 \ + --num-scheduler-steps 10 \ + --num-prompts 100 \ + --model amd/Llama-3.1-8B-Instruct-FP8-KV \ + --input-len 4096 \ + --output-len 512 \ + --dtype float16 \ + --tensor-parallel-size 1 \ + --output-json \ + --quantization fp8 \ + --gpu-memory-utilization 0.99 +``` + +Set GPU to SPX mode. + +```bash +rocm-smi --setcomputepartition spx +``` + +### Speculative Decoding + +Speculative decoding is one of the key features in vLLM. It has been supported on MI300. Here below is an example of the performance benchmark w/wo speculative decoding for Llama 3.1 405B with Llama 3.1 8B as the draft model. + +Without Speculative Decoding - + +```bash +export VLLM_USE_TRITON_FLASH_ATTN=0 +python /app/vllm/benchmarks/benchmark_latency.py --model amd/Llama-3.1-405B-Instruct-FP8-KV --max-model-len 26720 -tp 8 --batch-size 1 --input-len 1024 --output-len 128 +``` + +With Speculative Decoding - + +```bash +export VLLM_USE_TRITON_FLASH_ATTN=0 +python /app/vllm/benchmarks/benchmark_latency.py --model amd/Llama-3.1-405B-Instruct-FP8-KV --max-model-len 26720 -tp 8 --batch-size 1 --input-len 1024 --output-len 128 --speculative-model amd/Llama-3.1-8B-Instruct-FP8-KV --num-speculative-tokens 5 +``` + +You should see some performance improvement about the e2e latency. + +### AITER + +To get [AITER](https://github.com/ROCm/aiter) kernels support, follow the [Docker build steps](#Docker-manifest) using the [aiter_intergration_final](https://github.com/ROCm/vllm/tree/aiter_intergration_final) branch +There is a published release candidate image at `rocm/vllm-dev:nightly_aiter_intergration_final_20250130` + +To enable the feature make sure the following environment is set: `VLLM_USE_AITER=1`. +The default value is `0` in vLLM, but is set to `1` in the aiter docker. + +## MMLU_PRO_Biology Accuracy Evaluation + +### FP16 + +vllm (pretrained=models--meta-llama--Llama-3.1-405B-Instruct/snapshots/069992c75aed59df00ec06c17177e76c63296a26,dtype=float16,tensor_parallel_size=8), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 64 + +| Tasks |Version| Filter |n-shot| Metric | |Value | |Stderr| +|-------|------:|--------------|-----:|-----------|---|-----:|---|-----:| +|biology| 0|custom-extract| 5|exact_match|↑ |0.8466|± |0.0135| + +### FP8 + +vllm (pretrained=models--meta-llama--Llama-3.1-405B-Instruct/snapshots/069992c75aed59df00ec06c17177e76c63296a26,dtype=float16,quantization=fp8,quantized_weights_path=/llama.safetensors,tensor_parallel_size=8), gen_kwargs: (None), limit: None, num_fewshot: None, batch_size: 32 + +| Tasks |Version| Filter |n-shot| Metric | |Value| |Stderr| +|-------|------:|--------------|-----:|-----------|---|----:|---|-----:| +|biology| 0|custom-extract| 5|exact_match|↑ |0.848|± |0.0134| + +## Performance + +### MLPerf Performance Results + +#### LLama-2-70B + +Please refer to the [Benchmarking Machine Learning using ROCm and AMD GPUs: Reproducing Our MLPerf Inference Submission — ROCm Blogs](https://rocm.blogs.amd.com/artificial-intelligence/mlperf-inf-4-1/README.html) for information on reproducing MLPerf 4.1 Inference results. Note that due to changes in vLLM, it is not possible to use these instructions with the current rocm/vllm-dev docker image. Due to recent changes in vLLM, the instructions for MLPerf 4.1 submission do not apply to the current rocm/vllm-dev docker image. + +## Docker Manifest + +To reproduce the release docker: + +```bash + git clone https://github.com/ROCm/vllm.git + cd vllm + git checkout 8e87b08c2a284c1a20eb3d8e0fbdc84918bf27dc + docker build -f Dockerfile.rocm -t --build-arg BUILD_HIPBLASLT=1 --build-arg USE_CYTHON=1 . +``` + +### AITER + +Use Aiter release candidate branch instead: + +```bash + git clone https://github.com/ROCm/vllm.git + cd vllm + git checkout aiter_intergration_final + docker build -f Dockerfile.rocm -t --build-arg BUILD_HIPBLASLT=1 --build-arg USE_CYTHON=1 . +``` diff --git a/gradlib/GemmTuner.py b/gradlib/GemmTuner.py new file mode 100644 index 000000000000..daa3c4e87e33 --- /dev/null +++ b/gradlib/GemmTuner.py @@ -0,0 +1,336 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +import random +from pathlib import Path + +import pandas as pd +import torch +import torch.nn.functional as F + +import vllm._gradlib_C # noqa: F401 + +rtol = 1e-5 +atol = 1 + +CACHE_INVALIDATE_BUFFERS = int(os.getenv("CACHE_INVALIDATE_BUFFERS", "37")) +ONE = torch.ones(1, dtype=torch.float32, device='cuda') + + +class Gemm: + + def __init__(self, m, n, k, bias, indtype, outdtype, rocblas_decode=False): + self.m = m + self.k = k + self.n = n + self.bias = torch.randn(m, device='cuda').to(indtype) if bias else None + self.indtype = indtype + self.outdtype = outdtype + self.use_rocblas = (indtype == outdtype + and indtype is not torch.float8_e4m3fnuz) + self.nb = CACHE_INVALIDATE_BUFFERS + self.inp = torch.randn((self.n, self.k), + device='cuda').to(self.indtype) + self.weights = torch.randn((self.m, self.k), + device='cuda').to(self.indtype) + # weights2 is used in measurement/warm iters to ensure + # HBM fetch for weight tensors + self.weights2 = torch.randn((self.nb, self.m, self.k), + device='cuda').to(self.indtype) + self.blob = torch.ones(128 * 1024 * 1024, + dtype=torch.float32, + device='cuda') + self.topn = 20 #number of top solutions from each source + self.hipb_sols = [] + self.rocb_sols = [] + self.rtol = 1e-5 + self.atol = 1 + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + # prefer hipblaslt unless rocblas time is less than this + # ratio of hipblaslt time + self.hipb_prefer_ratio = 0.995 + self.rocblas_decode = rocblas_decode + + def find_hipblas_sols(self): + sols = torch.ops._gradlib_C.hipb_findallsols(self.inp, + self.weights.t(), + self.bias, self.outdtype) + print('M N K bias dtype', + self.m, + self.n, + self.k, + self.bias is not None, + self.indtype, + '>>> Total hipb solutions', + len(sols), + flush=True) + #print(sols) + self.hipb_sols = sols + + def check_gemm_ref(self, libtype, solidx): + if self.indtype == torch.float8_e4m3fnuz: + ref = torch._scaled_mm(self.inp, + self.weights.t(), + scale_a=ONE, + scale_b=ONE, + out_dtype=self.outdtype) + if type(ref) is tuple and len(ref) == 2: + ref = ref[0] + else: + ref = F.linear(self.inp, self.weights, self.bias) + if libtype == 'hipblaslt': + c = torch.ops._gradlib_C.hipb_mm(self.inp, self.weights.t(), + solidx, self.bias, self.outdtype, + None, None, None) + elif libtype == 'rocblas': + c = torch.ops._gradlib_C.rocb_mm(self.inp, self.weights.t(), + solidx) + if self.bias is not None: + c += self.bias + if torch.allclose(c.to(self.outdtype), + ref.to(self.outdtype), + atol=self.atol, + rtol=self.rtol): + return True + + print('>>>', + libtype, + 'Solidx', + solidx, + 'FAILED reference test', + flush=True) + #print(ref, flush=True) + #print(c, flush=True) + return False + + def hipb_time_sol(self, solidx, cold_iters=2, warm_iters=10): + #print('>>>hipbtime',solidx) + for i in range(cold_iters): + torch.ops._gradlib_C.hipb_mm(self.inp, self.weights.t(), solidx, + None, self.outdtype, None, None, None) + self.start.record() + for i in range(warm_iters): + torch.ops._gradlib_C.hipb_mm( + self.inp, self.weights2[random.randint(0, self.nb - 1)].t(), + solidx, None, self.outdtype, None, None, None) + self.end.record() + torch.cuda.synchronize() + gtime = self.start.elapsed_time(self.end) / warm_iters + #print('>>> Solidx GTime',solidx,gtime,'ms') + return gtime + + def hipb_time_all_sols(self, fast_mode=0, top_sols=0): + coldi = 20 + warmi = 20 + if fast_mode: + coldi = 2 + warmi = 2 + solutions = self.hipb_sols + if top_sols: + solutions = self.hipb_top_sols + gtimes = {} + for solidx in solutions: + gtimes[solidx] = self.hipb_time_sol(solidx, + cold_iters=coldi, + warm_iters=warmi) + self.hipb_gtimedf = pd.DataFrame.from_dict( + gtimes, orient='index', + columns=['gtimems']).sort_values(by='gtimems') + self.hipb_gtimedf.to_csv('/tmp/hipb_gtimedf.csv') + print('>>> HipBlasLt top solutions, Fast Mode', fast_mode) + print(self.hipb_gtimedf.head(self.topn)) + + def rocb_time_sol(self, solidx, cold_iters=2, warm_iters=10): + + def rocb_mm_bias(inp, w, solidx, bias): + return torch.ops._gradlib_C.rocb_mm(inp, w, solidx) + bias + + def rocb_mm_nobias(inp, w, solidx, _): + return torch.ops._gradlib_C.rocb_mm(inp, w, solidx) + + rocb_fun = rocb_mm_bias if self.bias is not None else rocb_mm_nobias + for _ in range(cold_iters): + rocb_fun(self.inp, self.weights.t(), solidx, self.bias) + + self.start.record() + for _ in range(warm_iters): + rocb_fun(self.inp, self.weights2[random.randint(0, + self.nb - 1)].t(), + solidx, self.bias) + + self.end.record() + torch.cuda.synchronize() + gtime = self.start.elapsed_time(self.end) / warm_iters + #print('>>> RocSolidx GTime',solidx,gtime,'ms') + return gtime + + def find_rocblas_sols(self): + sols = torch.ops._gradlib_C.rocb_findallsols(self.inp, + self.weights.t()) + print('M N K dtype', + self.m, + self.n, + self.k, + self.indtype, + '>>> Total rocb solutions', + len(sols), + flush=True) + #print(sols) + self.rocb_sols = sols + + def rocb_time_all_sols(self, fast_mode=0, top_sols=0): + coldi = 20 + warmi = 20 + if fast_mode: + coldi = 2 + warmi = 2 + solutions = self.rocb_sols + if top_sols: + solutions = self.rocb_top_sols + gtimes = {} + for solidx in solutions: + gtimes[solidx] = self.rocb_time_sol(solidx, coldi, warmi) + self.rocb_gtimedf = pd.DataFrame.from_dict( + gtimes, orient='index', + columns=['gtimems']).sort_values(by='gtimems') + self.rocb_gtimedf.to_csv('/tmp/rocb_gtimedf.csv') + print('>>> Rocblas top solutions, Fast Mode', fast_mode, flush=True) + print(self.rocb_gtimedf.head(self.topn), flush=True) + + def warmup(self, warmi=500): + for i in range(warmi): + self.blob = self.blob + 0.00001 + + def functional_check_topn_fastest(self): + rocb_topn = [] + for solidx in self.rocb_gtimedf.index[:self.topn]: + if self.check_gemm_ref(libtype='rocblas', solidx=solidx): + rocb_topn.append(solidx) + self.rocb_top_sols = rocb_topn + hipb_topn = [] + for solidx in self.hipb_gtimedf.index[:self.topn]: + if self.check_gemm_ref(libtype='hipblaslt', solidx=solidx): + hipb_topn.append(solidx) + self.hipb_top_sols = hipb_topn + + def find_fastest_solution(self): + if self.use_rocblas: + self.find_rocblas_sols() + if not (self.rocblas_decode and self.n == 1): + self.find_hipblas_sols() + self.warmup() + self.rocb_time_all_sols(fast_mode=1) + self.warmup() + self.hipb_time_all_sols(fast_mode=1) + self.functional_check_topn_fastest() + self.warmup() + self.rocb_time_all_sols(fast_mode=0, top_sols=1) + self.warmup() + self.hipb_time_all_sols(fast_mode=0, top_sols=1) + if len(self.rocb_gtimedf) > 0 and len(self.hipb_gtimedf) > 0: + best_rocb_time = self.rocb_gtimedf.gtimems.iloc[0] + best_hipb_time = self.hipb_gtimedf.gtimems.iloc[0] + if best_rocb_time < best_hipb_time * self.hipb_prefer_ratio: + self.best_libtype = 'rocblas' + self.best_solidx = self.rocb_gtimedf.index[0] + self.best_soltime = best_rocb_time + else: + self.best_libtype = 'hipblaslt' + self.best_solidx = self.hipb_gtimedf.index[0] + self.best_soltime = best_hipb_time + #self.check_gemm_ref(self.best_libtype,self.best_solidx) + elif len(self.hipb_gtimedf) > 0: + print('>>> Only hipblas solutions found!', flush=True) + best_hipb_time = self.hipb_gtimedf.gtimems.iloc[0] + self.best_libtype = 'hipblaslt' + self.best_solidx = self.hipb_gtimedf.index[0] + self.best_soltime = best_hipb_time + elif len(self.rocb_gtimedf) > 0: + print('>>> Only rocblas solutions found!', flush=True) + best_rocb_time = self.rocb_gtimedf.gtimems.iloc[0] + self.best_libtype = 'rocblas' + self.best_solidx = self.rocb_gtimedf.index[0] + self.best_soltime = best_rocb_time + else: + print('>>> No rocblas or hipblas solutions found!', flush=True) + self.best_libtype = 'rocblas' + self.best_solidx = 0 + self.best_soltime = 0 + print('>>> Fastest Solution is', + self.best_libtype, + self.best_solidx, + self.best_soltime, + flush=True) + + +class GemmTuner: + + def __init__(self, + indtype, + outdtype, + tuned_file=None, + rocblas_decode=False): + self.gemm_problems = pd.DataFrame(columns=['M', 'N', 'K', 'bias']) + self.indtype = indtype + self.outdtype = outdtype + self.rocblas_decode = rocblas_decode + self.tuned_file = tuned_file + if Path(tuned_file).is_file(): + self.tuned_shapes = pd.read_csv(tuned_file) + else: + self.tuned_shapes = None + + def add_gemm(self, m, n, k, indtype=None, bias=False): + indtype = self.indtype if self.indtype is not None else indtype + assert indtype is not None + outdtype = self.outdtype if self.outdtype is not None else indtype + assert outdtype is not None + if (self.tuned_shapes is None or (self.tuned_shapes[ + (self.tuned_shapes['M'] == m) & (self.tuned_shapes['N'] == n) & + (self.tuned_shapes['K'] == k) & + (self.tuned_shapes['bias'] == bias) & + (self.tuned_shapes['dtype'] == str(indtype)) & + (self.tuned_shapes['outdtype'] == str(outdtype))].empty)): + entry = { + 'M': [m], + 'N': [n], + 'K': [k], + 'bias': [bias], + 'dtype': [indtype], + 'outdtype': [outdtype] + } + df = pd.DataFrame(entry) + self.gemm_problems = pd.concat([self.gemm_problems, df], + ignore_index=True) + else: + print(f">>>Info: Found Duplicate shape(M:{m}," + " N:{n}, K:{k} bias:{bias}), skipping") + + def find_best_sols(self): + df = self.gemm_problems + soldf = pd.DataFrame(columns=['libtype', 'solidx', 'soltimes']) + for i in range(len(df)): + ds = df.loc[i, :] + indtype = self.indtype or ds['dtype'] + outdtype = self.outdtype or indtype + gemmobj = Gemm(ds['M'], + ds['N'], + ds['K'], + ds['bias'], + indtype=indtype, + outdtype=outdtype, + rocblas_decode=self.rocblas_decode) + gemmobj.find_fastest_solution() + soldf.loc[i, 'libtype'] = gemmobj.best_libtype + soldf.loc[i, 'solidx'] = gemmobj.best_solidx + soldf.loc[i, 'soltimes'] = gemmobj.best_soltime + + del gemmobj + torch.cuda.empty_cache() + + finaldf = pd.concat([self.gemm_problems, soldf], axis=1) + if self.tuned_shapes is not None: + finaldf = pd.concat([finaldf, self.tuned_shapes]) + finaldf['solidx'] = finaldf['solidx'].convert_dtypes('int64') + finaldf.to_csv(self.tuned_file, index=False) + print(finaldf) diff --git a/gradlib/gemm_runner.py b/gradlib/gemm_runner.py new file mode 100644 index 000000000000..47d952ab2761 --- /dev/null +++ b/gradlib/gemm_runner.py @@ -0,0 +1,79 @@ +# SPDX-License-Identifier: Apache-2.0 +import sys + +import pandas as pd +import torch +import torch.nn.functional as F + +import vllm._gradlib_C # noqa: F401 + +torch.ops._gradlib_C.rocb_create_extension() +torch.ops._gradlib_C.hipb_create_extension() + + +class TunedGemm: + + def __init__(self, tuned_csv_file): + self.bestsols = pd.read_csv(tuned_csv_file, index_col=[0]) + self.create_ds() + + def create_ds(self): + df = self.bestsols + solds = {} + for i in range(len(df)): + ds = df.iloc[i] + key = (ds['M'], ds['N'], ds['K']) + if ds['libtype'] == 'hipblaslt': + soltype = 1 + elif ds['libtype'] == 'rocblas': + soltype = 2 + solds[key] = (soltype, int(ds['solidx'])) + #print(solds) + self.solids = solds + + def query_sol(self, m, n, k): + return self.solids.get((m, n, k), (0, 0)) + + def mm(self, inp, weights): + soltype, solidx = self.query_sol(m=weights.shape[0], + n=inp.shape[0], + k=inp.shape[1]) + if soltype == 1: + out = torch.ops._gradlib_C.hipb_mm(inp, weights.t(), solidx, None, + None, None, None, None) + elif soltype == 2: + out = torch.ops._gradlib_C.rocb_mm(inp, weights.t(), solidx) + else: + out = F.linear(inp, weights) + return out + + def run_all_tuned_sols(self): + for i in range(len(self.bestsols)): + ds = self.bestsols.iloc[i] + print('>>> Running tuned solution') + print(ds) + inp = torch.randn((ds['N'], ds['K']), + dtype=get_dtype(ds['dtype']), + device='cuda') + weights = torch.randn((ds['M'], ds['K']), + dtype=get_dtype(ds['dtype']), + device='cuda') + self.mm(inp, weights) + + +def get_dtype(dtype_csv): + if dtype_csv == 'torch.float16': + dtype = torch.float16 + elif dtype_csv == 'torch.bfloat16': + dtype = torch.bfloat16 + elif dtype_csv == 'torch.float32': + dtype = torch.float32 + elif dtype_csv == 'torch.float8_e4m3fnuz': + dtype = torch.float8_e4m3fnuz + return dtype + + +if __name__ == '__main__': + tgemm = TunedGemm(sys.argv[1]) #csv file with tuned sols goes in argv[1] + print(tgemm.bestsols) + tgemm.run_all_tuned_sols() diff --git a/gradlib/gemm_tuner.py b/gradlib/gemm_tuner.py new file mode 100644 index 000000000000..30a3c5f9ad4d --- /dev/null +++ b/gradlib/gemm_tuner.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 +import argparse +import json +import os +from pathlib import Path + +import torch # isort: split +import pandas as pd +from GemmTuner import GemmTuner + +import vllm._gradlib_C # noqa: F401 + +torch.ops._gradlib_C.rocb_create_extension() +torch.ops._gradlib_C.hipb_create_extension() + + +def generate_mk_sets(model_dir, tp=1): + with open(f'{model_dir}/config.json') as f: + data = json.load(f) + hidden_size = data['hidden_size'] + intermediate_size = data['intermediate_size'] + total_num_heads = data['num_attention_heads'] + total_num_kv_heads = data['num_key_value_heads'] + dtype = get_dtype(data['torch_dtype']) + head_dim = hidden_size // total_num_heads + return [((total_num_heads + (2 * total_num_kv_heads)) * head_dim // tp, + hidden_size), (hidden_size, hidden_size // tp), + (intermediate_size * 2 // tp, hidden_size), + (hidden_size, intermediate_size // tp)], hidden_size, dtype + + +dtypes = { + 'f32': torch.float32, + 'float32': torch.float32, + 'f16': torch.float16, + 'float16': torch.float16, + 'bf16': torch.bfloat16, + 'bfloat16': torch.bfloat16, + 'fp8': torch.float8_e4m3fnuz, +} + + +def get_dtype(dtype_str): + if dtype_str is None: + return None + if dtype_str.startswith('torch'): + return getattr(torch, dtype_str.split('.')[1]) + if dtype_str in dtypes: + return dtypes[dtype_str] + else: + print('>>> Warning! Invalid dtype', dtype_str, + 'using default dtype f16') + return None + + +def list_of_ints(arg): + return list(map(int, arg.split(','))) + + +def load_input_gemms(input_file): + if Path(input_file).is_file(): + return + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--model_dir", + type=str, + default=os.getenv('GTUNE_MODEL', ""), + help="Enter the location of your model directory") + parser.add_argument("--tuned_file", + type=str, + default=os.getenv('GTUNE_TUNED', "tuned.csv"), + help="output file for tuned gemm solutions") + parser.add_argument( + "--input_file", + type=str, + default=os.getenv('GTUNE_INPUT', None), + help="list of gemms to tune for, mutually exclusive with model_dir") + parser.add_argument("--tp", + type=int, + default=os.getenv('GTUNE_TP', 1), + help="Tensor parallelism to be used.") + parser.add_argument( + "--indtype", + type=str, + default=None, + choices=["f32", "f16", "bf16", "fp8"], + help="dtype: f32 f16 bf16 fp8. Use this to override the" + " input_file or if no input_file provided") + parser.add_argument( + "--outdtype", + type=str, + choices=["f32", "f16", "bf16", "fp8"], + help="dtype: f32 f16 bf16 fp8. Use to override the default value," + " which is the same as indtype for each shape (see --indtype.)") + parser.add_argument("--rocblas-decode", + action="store_true", + default=False, + help="forces rocblas solution on decode N=1") + parser.add_argument("--batch_size", + type=int, + default=os.getenv('GTUNE_BATCH_SIZE', 1), + help="Batch size to tune for") + parser.add_argument("--nsets", + type=list_of_ints, + default=[1, 512, 1024, 2048, 3072, 4096, 8192, 16384], + help="N sizes to tune for: 1,128,2048") + parser.add_argument("--all_bias", + action="store_true", + help="Tune for both bias and non bias cases," + " regardless of what was used" + " to collect the shapes") + args = parser.parse_args() + + if args.outdtype is None: + args.outdtype = args.indtype + indtype = get_dtype(args.indtype) + outdtype = get_dtype(args.outdtype) + + gtuner = GemmTuner(indtype, outdtype, args.tuned_file, args.rocblas_decode) + nsets = [i * args.batch_size for i in args.nsets] + if args.input_file: + print(f">>> Loading {args.input_file}") + if not Path(args.input_file).is_file(): + print(f">>> ERROR: {args.input_file} does not exist. Exiting") + exit(1) + shapes = pd.read_csv(args.input_file) + for i in range(len(shapes)): + ds = shapes.iloc[i] + for bias in [True, False] if args.all_bias else [ds['bias']]: + gtuner.add_gemm(ds['M'], + ds['N'], + ds['K'], + indtype=get_dtype(ds['dtype']), + bias=bias) + else: + if not args.model_dir: + print(">>> Warning! NO MODEL SPECIFIED. Tuning for LL2 13B TP1") + #LL2 13B sizes + mksets = [(15360, 5120), (5120, 5120), (27648, 5120), + (5120, 13824)] + gtuner.add_gemm(m=32000, n=1, k=5120) # logits gemm + dtype = torch.float16 + else: + mksets, hidden_size, dtype = generate_mk_sets( + args.model_dir, args.tp) + gtuner.add_gemm( + m=32000 // args.tp, + n=1 * args.batch_size, + k=hidden_size, + indtype=dtype, + ) #TODO: Handle cases where vocab_size is not divisible by tp + + for n in sorted(nsets): + for m, k in mksets: + gtuner.add_gemm(m, n, k, indtype=dtype) + + gtuner.find_best_sols() diff --git a/pyproject.toml b/pyproject.toml index 9892967b82d7..3aeba1c2921b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -92,7 +92,7 @@ exclude = [ [tool.codespell] ignore-words-list = "dout, te, indicies, subtile, ElementE" -skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build" +skip = "./tests/models/fixtures,./tests/prompts,./benchmarks/sonnet.txt,./tests/lora/data,./build,./csrc/gradlib,./csrc/rocm" [tool.isort] use_parentheses = true diff --git a/requirements-rocm.txt b/requirements-rocm.txt index ccc906234177..07def140ea6c 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -2,6 +2,7 @@ -r requirements-common.txt # Dependencies for AMD GPUs +numpy==1.26.4 awscli boto3 botocore @@ -10,3 +11,4 @@ ray >= 2.10.0 peft pytest-asyncio tensorizer>=2.9.0 +setuptools-scm>=8 diff --git a/rocm_patch/libamdhip64.so.6 b/rocm_patch/libamdhip64.so.6 new file mode 100644 index 000000000000..b551a2c9d890 Binary files /dev/null and b/rocm_patch/libamdhip64.so.6 differ diff --git a/setup.py b/setup.py index 50265d46e7d6..1b128f3f04eb 100755 --- a/setup.py +++ b/setup.py @@ -597,6 +597,7 @@ def _read_requirements(filename: str) -> List[str]: if _is_hip(): ext_modules.append(CMakeExtension(name="vllm._rocm_C")) + ext_modules.append(CMakeExtension(name="vllm._gradlib_C")) if _is_cuda(): ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) diff --git a/setup_cython.py b/setup_cython.py new file mode 100644 index 000000000000..51f5d07c82bc --- /dev/null +++ b/setup_cython.py @@ -0,0 +1,38 @@ +# SPDX-License-Identifier: Apache-2.0 +import Cython.Compiler.Options +from Cython.Build import cythonize +from setuptools import setup + +Cython.Compiler.Options.annotate = True + +infiles = [] + +infiles += [ + "vllm/engine/llm_engine.py", + "vllm/transformers_utils/detokenizer.py", + "vllm/engine/output_processor/single_step.py", + "vllm/outputs.py", + "vllm/engine/output_processor/stop_checker.py", +] + +infiles += [ + "vllm/core/scheduler.py", + "vllm/sequence.py", + "vllm/core/block_manager.py", +] + +infiles += [ + "vllm/model_executor/layers/sampler.py", + "vllm/sampling_params.py", + "vllm/utils.py", +] + +setup(ext_modules=cythonize(infiles, + annotate=False, + force=True, + compiler_directives={ + 'language_level': "3", + 'infer_types': True + })) + +# example usage: python3 setup_cython.py build_ext --inplace diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index e9b537ed5150..a12953d5fbe1 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -6,6 +6,7 @@ from vllm import SamplingParams +from ....test_utils import xfail_if_rocm62 from .conftest import get_token_ids_from_llm_generator @@ -181,6 +182,7 @@ def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator, assert baseline_token_ids == test_token_ids +@xfail_if_rocm62 @pytest.mark.parametrize( "common_llm_kwargs", [ @@ -259,6 +261,7 @@ def test_chunked_prefill_block_manager(baseline_llm_generator, assert baseline_token_ids == test_token_ids +@xfail_if_rocm62 @pytest.mark.parametrize( "common_llm_kwargs", [{ @@ -340,6 +343,7 @@ def test_block_manager_prefix_caching_enabled_with_preemption( assert baseline_token_ids == test_token_ids +@xfail_if_rocm62 @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index c874608e40a2..1a8873b00999 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -8,6 +8,7 @@ from tests.kernels.utils import override_backend_env_variable from vllm import LLM, SamplingParams +from ....test_utils import xfail_if_rocm62 from .conftest import get_text_from_llm_generator # relatively small model with 4k sliding window @@ -76,6 +77,7 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, assert sum(cmp) > 0.7 * len(cmp) +@xfail_if_rocm62 @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/distributed/test_custom_all_reduce.py b/tests/distributed/test_custom_all_reduce.py index 46887bca42a9..fc682f93f3cb 100644 --- a/tests/distributed/test_custom_all_reduce.py +++ b/tests/distributed/test_custom_all_reduce.py @@ -97,13 +97,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port): inp = torch.ones(sz, dtype=torch.float32, device=device) out = inp for _ in range(num_communication): - out = fa.all_reduce(out, registered=False) + out = fa.all_reduce_unreg(out) torch.testing.assert_close(out, inp * (tp_size**num_communication)) inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) out = inp for _ in range(num_communication): - out = fa.all_reduce(out, registered=False) + out = fa.all_reduce_unreg(out) torch.testing.assert_close(out, inp * (tp_size**num_communication)) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 4c42a0ed8112..9ba219d4c59a 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -61,7 +61,8 @@ def worker_fn(): device=get_world_group().device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) - tensor = pynccl_comm.all_reduce(tensor) + with pynccl_comm.change_state(enable=True): + tensor = pynccl_comm.all_reduce(tensor) torch.cuda.synchronize() assert torch.all(tensor == pynccl_comm.world_size).cpu().item() @@ -82,16 +83,17 @@ def multiple_allreduce_worker_fn(): group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1] pynccl_comm = PyNcclCommunicator(group=group, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) - # two groups can communicate independently - if torch.distributed.get_rank() in [0, 1]: - tensor = pynccl_comm.all_reduce(tensor) - tensor = pynccl_comm.all_reduce(tensor) - torch.cuda.synchronize() - assert torch.all(tensor == 4).cpu().item() - else: - tensor = pynccl_comm.all_reduce(tensor) - torch.cuda.synchronize() - assert torch.all(tensor == 2).cpu().item() + with pynccl_comm.change_state(enable=True): + # two groups can communicate independently + if torch.distributed.get_rank() in [0, 1]: + tensor = pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) + torch.cuda.synchronize() + assert torch.all(tensor == 4).cpu().item() + else: + tensor = pynccl_comm.all_reduce(tensor) + torch.cuda.synchronize() + assert torch.all(tensor == 2).cpu().item() @pytest.mark.skipif(torch.cuda.device_count() < 4, @@ -137,7 +139,9 @@ def worker_fn_with_cudagraph(): # run something in the default stream to initialize torch engine a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}') torch.cuda.synchronize() - with torch.cuda.graph(graph): + with torch.cuda.graph( + graph, stream=pynccl_comm.stream), pynccl_comm.change_state( + enable=True): a_out = pynccl_comm.all_reduce(a) torch.cuda.synchronize() graph.replay() @@ -166,7 +170,8 @@ def all_gather_worker_fn(): for r in range(world_size) ]).to(device) - pynccl_comm.all_gather(result, tensor) + with pynccl_comm.change_state(enable=True): + pynccl_comm.all_gather(result, tensor) torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) @@ -203,7 +208,8 @@ def reduce_scatter_worker_fn(): expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size] for tensor in all_tensors).to(device) - pynccl_comm.reduce_scatter(result, tensor) + with pynccl_comm.change_state(enable=True): + pynccl_comm.reduce_scatter(result, tensor) torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) @@ -230,13 +236,15 @@ def send_recv_worker_fn(): else: tensor = torch.empty(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) - - if pynccl_comm.rank == 0: - pynccl_comm.send(tensor, - dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) - else: - pynccl_comm.recv(tensor, - src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + with pynccl_comm.change_state(enable=True): + if pynccl_comm.rank == 0: + pynccl_comm.send(tensor, + dst=(pynccl_comm.rank + 1) % + pynccl_comm.world_size) + else: + pynccl_comm.recv(tensor, + src=(pynccl_comm.rank - 1) % + pynccl_comm.world_size) torch.cuda.synchronize() assert torch.all(tensor == 1).cpu().item() @@ -267,12 +275,15 @@ def multiple_send_recv_worker_fn(): 1024, dtype=torch.float32, device=device) - if torch.distributed.get_rank() in [0, 1]: - pynccl_comm.send(tensor, - dst=(pynccl_comm.rank + 1) % pynccl_comm.world_size) - else: - pynccl_comm.recv(tensor, - src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + with pynccl_comm.change_state(enable=True): + if torch.distributed.get_rank() in [0, 1]: + pynccl_comm.send(tensor, + dst=(pynccl_comm.rank + 1) % + pynccl_comm.world_size) + else: + pynccl_comm.recv(tensor, + src=(pynccl_comm.rank - 1) % + pynccl_comm.world_size) torch.cuda.synchronize() if torch.distributed.get_rank() in [0, 2]: assert torch.all(tensor == 1).cpu().item() diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index b667d8d9e030..28260ad6fe9a 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -9,7 +9,7 @@ from tests.kernels.utils import opcheck from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.utils import get_max_shared_memory_bytes +from vllm.utils import get_max_shared_memory_bytes, is_navi from .allclose_default import get_default_atol, get_default_rtol @@ -35,7 +35,7 @@ # This should be sync with get_supported_head_sizes() in # vllm.attention.ops.paged_attn.PagedAttention -HEAD_SIZES = [32, 64, 80, 96, 112, 120, 128, 192, 256] +HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256] BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] @@ -186,6 +186,10 @@ def test_paged_attention( # Using default kv_scale k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + # additional argument for v1/v2 pa kernel + num_threads = 1024 if current_platform.is_rocm() \ + and not is_navi() else 128 + # Call the paged attention kernel. output = torch.empty_like(query) if version == "v1": @@ -206,14 +210,16 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v1, - (output, query, key_cache, value_cache, num_kv_heads, scale, - block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v1, + (output, query, key_cache, value_cache, num_kv_heads, scale, + block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0, num_threads), + cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) elif version in ("v2", "rocm"): + if current_platform.is_rocm(): + PARTITION_SIZE = 1024 if version == "v2" else 256 num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) assert PARTITION_SIZE % block_size == 0 num_seqs, num_heads, head_size = output.shape @@ -226,6 +232,7 @@ def test_paged_attention( dtype=torch.float32, ) max_logits = torch.empty_like(exp_sums) + if version == "v2": ops.paged_attention_v2( output, @@ -247,13 +254,14 @@ def test_paged_attention( v_scale, ) - opcheck(torch.ops._C.paged_attention_v2, - (output, exp_sums, max_logits, tmp_output, query, - key_cache, value_cache, num_kv_heads, scale, block_tables, - seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), - cond=(head_size == HEAD_SIZES[0] - and block_size == BLOCK_SIZES[0])) + opcheck( + torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, key_cache, + value_cache, num_kv_heads, scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale, 0, 0, 0, 64, 0, num_threads), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) else: ops.paged_attention_rocm( @@ -274,13 +282,15 @@ def test_paged_attention( kv_cache_dtype, k_scale, v_scale, + None, + PARTITION_SIZE, ) opcheck(torch.ops._rocm_C.paged_attention, (output, exp_sums, max_logits, tmp_output, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale), + kv_cache_dtype, k_scale, v_scale, None, PARTITION_SIZE), cond=(head_size == HEAD_SIZES[0] and block_size == BLOCK_SIZES[0])) @@ -432,4 +442,4 @@ def test_multi_query_kv_attention( ) atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 - torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) \ No newline at end of file diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index e653d34d00ee..17703a21b108 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -23,7 +23,7 @@ # There may not be enough gpu memory due to large NUM_BLOCKS. # Reduce NUM_BLOCKS when it happens. NUM_BLOCKS = 4321 # Arbitrary values for testing -PARTITION_SIZE = 512 +PARTITION_SIZE = 512 if not current_platform.is_rocm() else 1024 DTYPES = [torch.half, torch.bfloat16] NUM_GEN_SEQS = [3] # Arbitrary values for testing NUM_PREFILL_SEQS = [3] # Arbitrary values for testing diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 0d11e8652ce6..3a7683bcd119 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -23,7 +23,10 @@ from vllm.platforms import current_platform # List of support backends for encoder/decoder models -LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] +LIST_ENC_DEC_SUPPORTED_BACKENDS = ([ + _Backend.XFORMERS, _Backend.FLASH_ATTN +] if not current_platform.is_rocm() else [_Backend.ROCM_FLASH]) + HEAD_SIZES = [64, 256] NUM_HEADS = [1, 16] @@ -873,8 +876,6 @@ def test_encoder_only( attn_backend.name) -@pytest.mark.skipif(current_platform.is_rocm(), - reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP) @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("attn_backend", LIST_ENC_DEC_SUPPORTED_BACKENDS) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index 0f13fbc96503..d3679bbeab50 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -5,9 +5,12 @@ """ import pytest import torch +from torch.nn import Parameter +from torch.nn import functional as F from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock +import vllm.envs as envs import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import (compute_max_diff, opcheck, stack_and_dev, torch_moe, torch_moe_single) @@ -48,13 +51,21 @@ def test_fused_moe( w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 score = torch.randn((m, e), device="cuda", dtype=dtype) + torch_output = torch_moe(a, w1, w2, score, topk) + + # Pad the input if use padding + if envs.VLLM_MOE_PADDING: + w1 = F.pad(w1, (0, 128), "constant", 0) + torch.cuda.empty_cache() + w2 = F.pad(w2, (0, 128), "constant", 0) + torch.cuda.empty_cache() triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) iterative_output = iterative_moe(a, w1, w2, score, topk, renormalize=False) torch.testing.assert_close(iterative_output, torch_output, - atol=2e-2, + atol=1e-2, rtol=0) @@ -179,6 +190,17 @@ def test_mixtral_moe(dtype: torch.dtype): # vLLM uses 1D query [num_tokens, hidden_dim] vllm_inputs = hf_inputs.flatten(0, 1) + # pad the weight if using padding + if envs.VLLM_MOE_PADDING: + vllm_moe.experts.w13_weight = Parameter(F.pad( + vllm_moe.experts.w13_weight, (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + vllm_moe.experts.w2_weight = Parameter(F.pad( + vllm_moe.experts.w2_weight, (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + # Run forward passes for both MoE blocks hf_states, _ = hf_moe.forward(hf_inputs) vllm_states = vllm_moe.forward(vllm_inputs) @@ -195,6 +217,8 @@ def test_mixtral_moe(dtype: torch.dtype): atol=mixtral_moe_tol[dtype]) +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Make this test work with MoE padding on HIP") @pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("n", [128, 2048]) @pytest.mark.parametrize("k", [128, 1024]) diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 5be111d71308..f97b95a81da5 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -16,7 +16,8 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.platforms.interface import _Backend from vllm.utils import (STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL, - STR_XFORMERS_ATTN_VAL, make_tensor_with_pad) + STR_ROCM_FLASH_ATTN_VAL, STR_XFORMERS_ATTN_VAL, + make_tensor_with_pad) # For now, disable "test_aot_dispatch_dynamic" since there are some # bugs related to this test in PyTorch 2.4. @@ -528,6 +529,12 @@ def make_backend(backend_name: str) -> AttentionBackend: # NOTE: xFormers backend cannot be imported for CPU and AMD GPUs. from vllm.attention.backends.xformers import XFormersBackend return XFormersBackend() + + elif backend_name == STR_ROCM_FLASH_ATTN_VAL: + from vllm.attention.backends.rocm_flash_attn import ( # noqa: F401 + ROCmFlashAttentionBackend) + return ROCmFlashAttentionBackend + elif backend_name == STR_FLASH_ATTN_VAL: from vllm.attention.backends.flash_attn import FlashAttentionBackend return FlashAttentionBackend() diff --git a/tests/test_utils.py b/tests/test_utils.py index 5b69ffd18bb2..d459fbcfc650 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -112,6 +112,23 @@ def dummy(*, old_arg: object = None, new_arg: object = None): dummy(old_arg=1) +def is_rocm62(): + import torch + return isinstance(torch.version.hip, + str) and torch.version.hip.startswith("6.2") + + +def xfail_if_rocm62(function=None, + reason: str = "Tests are not yet ready for ROCm 6.2", + strict: bool = False): + if function: + return pytest.mark.xfail(is_rocm62(), reason=reason, + strict=strict)(function) + else: + assert callable(function) + return pytest.mark.xfail(is_rocm62(), reason=reason, strict=strict) + + def test_get_open_port(): os.environ["VLLM_PORT"] = "5678" # make sure we can get multiple ports, even if the env var is set diff --git a/tests/vllm_test_utils/vllm_test_utils/__init__.py b/tests/vllm_test_utils/vllm_test_utils/__init__.py index 1d1219fbeffa..88c91f550763 100644 --- a/tests/vllm_test_utils/vllm_test_utils/__init__.py +++ b/tests/vllm_test_utils/vllm_test_utils/__init__.py @@ -5,6 +5,5 @@ """ from .blame import BlameResult, blame -from .monitor import MonitoredValues, monitor -__all__ = ["blame", "BlameResult", "monitor", "MonitoredValues"] +__all__ = ["blame", "BlameResult"] diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index bdc9a6a33df0..844824a54172 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -11,6 +11,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.scalar_type import ScalarType +from vllm.utils import is_navi logger = init_logger(__name__) @@ -63,7 +64,9 @@ def paged_attention_v1( seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step) + blocksparse_head_sliding_step, + num_threads = 1024 if current_platform.is_rocm() \ + and not is_navi() else 128) def paged_attention_v2( @@ -95,7 +98,9 @@ def paged_attention_v2( num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, - blocksparse_block_size, blocksparse_head_sliding_step) + blocksparse_block_size, blocksparse_head_sliding_step, + num_threads = 1024 if current_platform.is_rocm() \ + and not is_navi() else 128) def paged_attention_rocm( @@ -116,12 +121,15 @@ def paged_attention_rocm( kv_cache_dtype: str, k_scale: torch.Tensor, v_scale: torch.Tensor, + fp8_out_scale: Optional[torch.Tensor], + partition_size: int, ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale) + kv_cache_dtype, k_scale, v_scale, + fp8_out_scale, partition_size) # pos encoding ops @@ -158,6 +166,19 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) +def scaled_rms_norm(out: torch.Tensor, input: torch.Tensor, + weight: torch.Tensor, scale: torch.Tensor, + epsilon: float) -> None: + torch.ops._C.rms_norm_static_fp8_quant(out, input, weight, scale, epsilon) + + +def scaled_fused_add_rms_norm(out: torch.Tensor, input: torch.Tensor, + residual: torch.Tensor, weight: torch.Tensor, + scale: torch.Tensor, epsilon: float) -> None: + torch.ops._C.fused_add_rms_norm_static_fp8_quant(out, input, residual, + weight, scale, epsilon) + + def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, @@ -801,7 +822,8 @@ def scaled_fp8_quant( shape: Union[Tuple[int, int], torch.Size] = input.shape # For rocm, the output fp8 dtype is torch.float_e3m3fnuz out_dtype: torch.dtype = torch.float8_e4m3fnuz \ - if current_platform.is_rocm() else torch.float8_e4m3fn + if current_platform.is_rocm() and not \ + is_navi() else torch.float8_e4m3fn if num_token_padding: shape = (max(num_token_padding, input.shape[0]), shape[1]) output = torch.empty(shape, device=input.device, dtype=out_dtype) @@ -1060,16 +1082,20 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int: # custom ar -def init_custom_ar(ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor, - rank: int, full_nvlink: bool) -> int: - return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank, - full_nvlink) +def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor, + handles: List[str], offsets: List[int], rank: int, + full_nvlink: bool) -> int: + return torch.ops._C_custom_ar.init_custom_ar(meta, rank_data, handles, + offsets, rank, full_nvlink) -def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, - reg_buffer_sz_bytes: int) -> None: - torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, - reg_buffer_sz_bytes) +def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None: + torch.ops._C_custom_ar.all_reduce_reg(fa, inp, out) + + +def all_reduce_unreg(fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, + out: torch.Tensor) -> None: + torch.ops._C_custom_ar.all_reduce_unreg(fa, inp, reg_buffer, out) def dispose(fa: int) -> None: @@ -1080,14 +1106,39 @@ def meta_size() -> int: return torch.ops._C_custom_ar.meta_size() -def register_buffer(fa: int, ipc_tensors: List[int]) -> None: - return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors) +def register_buffer(fa: int, t: torch.Tensor, handles: List[str], + offsets: List[int]) -> None: + return torch.ops._C_custom_ar.register_buffer(fa, t, handles, offsets) -def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: +def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]: return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa) -def register_graph_buffers(fa: int, handles: List[List[int]], +def register_graph_buffers(fa: int, handles: List[str], offsets: List[List[int]]) -> None: torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets) + + +def allocate_meta_buffer(size: int) -> torch.Tensor: + return torch.ops._C_custom_ar.allocate_meta_buffer(size) + + +def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor: + return torch.ops._C_custom_ar.get_meta_buffer_ipc_handle(inp) + + +# ROCm custom +def LLMM1(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, + rows_per_block: int) -> None: + torch.ops._rocm_C.LLMM1(a, b, out, rows_per_block) + + +def LLMM_Silu(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, + rows_per_block: int) -> None: + torch.ops._rocm_C.LLMM_Silu(a, b, out, rows_per_block) + + +def wvSpltK(a: torch.Tensor, b: torch.Tensor, out: torch.Tensor, N: int, + cu_count: int) -> None: + torch.ops._rocm_C.wvSpltK(a, b, out, N, cu_count) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 5f0a54013540..290f323cf73f 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -234,8 +234,10 @@ class AttentionLayer(Protocol): _k_scale: torch.Tensor _v_scale: torch.Tensor - _k_scale_float: float - _v_scale_float: float + _k_scale_float: torch.Tensor + _v_scale_float: torch.Tensor + _q_scale: torch.Tensor + _prob_scale: torch.Tensor def forward( self, @@ -275,6 +277,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError @@ -291,6 +294,7 @@ def forward( k_pe: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 9765e7881ad9..63c3f0120eb8 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -369,6 +369,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: BlocksparseFlashAttentionMetadata, + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 6a82127acdf7..c0a5b0ab24de 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -671,6 +671,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashAttentionMetadata, + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 715ed6748b84..5f0448ed7cec 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -937,6 +937,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: FlashInferMetadata, + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index 1518e518e91b..6d24848e69a4 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -161,6 +161,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: HPUAttentionMetadata, + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index b4879af4cf20..a332e1c7135f 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -179,6 +179,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: IpexAttnMetadata, # type: ignore + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. diff --git a/vllm/attention/backends/mla/utils.py b/vllm/attention/backends/mla/utils.py index 9b63192ed0f6..56aaec05545f 100644 --- a/vllm/attention/backends/mla/utils.py +++ b/vllm/attention/backends/mla/utils.py @@ -27,7 +27,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( scaled_dequantize, scaled_quantize) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding -from vllm.vllm_flash_attn import flash_attn_varlen_func + +try: + from vllm.vllm_flash_attn import flash_attn_varlen_func +except ImportError: + from flash_attn import flash_attn_varlen_func @dataclass @@ -424,6 +428,7 @@ def forward( k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, attn_metadata: T, + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: if output is not None: diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index b61dfe63ddca..f9622da42540 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -157,6 +157,7 @@ def forward( value: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor], attn_metadata: PallasMetadata, + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 02bff57a62b7..26059605c51f 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -22,7 +22,7 @@ logger = init_logger(__name__) -_PARTITION_SIZE_ROCM = 512 +_PARTITION_SIZE_ROCM = 256 _GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName _ON_NAVI = "gfx1" in _GPU_ARCH _ON_MI250_MI300 = any(arch in _GPU_ARCH @@ -491,7 +491,7 @@ def __init__( f"Head size {head_size} is not supported by PagedAttention. " f"Supported head sizes are: {supported_head_sizes}.") - self.use_naive_attn = False + self.use_naive_attn = envs.VLLM_USE_SDPA_ATTENTION # Default False # NOTE: Allow for switching between Triton and CK. Defaulting to triton. self.use_triton_flash_attn = envs.VLLM_USE_TRITON_FLASH_ATTN if self.use_triton_flash_attn: @@ -551,6 +551,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: ROCmFlashAttentionMetadata, + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -680,6 +681,13 @@ def forward( query.dtype, seq_lens, make_attn_mask=False) # type: ignore + full_scales = ( + layer._q_scale.item(), layer._k_scale.item(), + layer._v_scale.item(), layer._prob_scale.item(), + fp8_out_scale.item()) if ( + fp8_out_scale and layer._q_scale + and layer._prob_scale + and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) else None out, _ = self.attn_func( query, key, @@ -693,6 +701,7 @@ def forward( self.scale, attn_masks[0][None] if attn_masks is not None else None, + full_scales, ) elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: @@ -792,10 +801,16 @@ def forward( device=output.device, ) max_logits = torch.empty_like(exp_sums) + cpa_fp8_out = False if num_prefill_tokens > 0: out = output[num_prefill_tokens:] else: - out = output + if fp8_out_scale is not None: + out = torch.empty_like(output, + dtype=torch.float8_e4m3fnuz) + cpa_fp8_out = True + else: + out = output ops.paged_attention_rocm( out, exp_sums, @@ -818,7 +833,11 @@ def forward( self.kv_cache_dtype, layer._k_scale, layer._v_scale, + fp8_out_scale if cpa_fp8_out else None, + _PARTITION_SIZE_ROCM, ) + if cpa_fp8_out: + return out.view(num_seqs, num_heads * head_size) else: output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, @@ -854,6 +873,7 @@ def _sdpa_attention( num_heads: int, head_size: int, scale: float, + is_causal: bool, attn_masks: Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: start = 0 @@ -871,7 +891,7 @@ def _sdpa_attention( key[:, start:end, :], value[:, start:end, :], dropout_p=0.0, - is_causal=attn_masks is None, + is_causal=is_causal, attn_mask=attn_masks[i] if attn_masks else None, scale=scale).movedim(query.dim() - 2, 0) output[start:end, :, :] = sub_out @@ -888,4 +908,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_seq_len <= 128 * 1024) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 25fe6ed95c5d..dbbaa415dd12 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -441,6 +441,7 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: TorchSDPAMetadata, # type: ignore + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index ad53e4e70b0f..0adfab4ed085 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -230,9 +230,18 @@ def build(self, seq_lens: List[int], query_lens: List[int], # The shape of graph_block_tables is # [max batch size, max context len // block size]. input_block_tables = self.runner.graph_block_tables[:batch_size] + max_blocks = input_block_tables.shape[1] for i, block_table in enumerate(self.block_tables): if block_table: - input_block_tables[i, :len(block_table)] = block_table + num_blocks = len(block_table) + if num_blocks <= max_blocks: + input_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + input_block_tables[ + i, :max_blocks] = block_table[:max_blocks] block_tables = torch.from_numpy(input_block_tables).to( device, non_blocking=True) else: diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 723a4558d0b3..93aa94e35a85 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -425,6 +425,7 @@ def forward( value: Optional[torch.Tensor], kv_cache: torch.Tensor, attn_metadata: "XFormersMetadata", + fp8_out_scale: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 19ee89630ffa..29254af64a10 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -79,6 +79,8 @@ def __init__( self.calculate_kv_scales = calculate_kv_scales self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32) + self._q_scale = torch.tensor(1.0, dtype=torch.float32) + self._prob_scale = torch.tensor(1.0, dtype=torch.float32) # We also keep the float32 versions of k/v_scale for attention # backends that don't support tensors (Flashinfer) @@ -145,6 +147,7 @@ def __init__( ).parallel_config.pipeline_parallel_size) ] + self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) @@ -155,10 +158,11 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, + fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: if self.calculate_kv_scales and \ attn_metadata.enable_kv_scales_calculation: - self.calc_kv_scales(key, value) + self.calc_kv_scales(query, key, value) if self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) @@ -173,19 +177,21 @@ def forward( value = value.view(-1, self.num_kv_heads, self.head_size) if self.use_direct_call: unified_attention_with_output(query, key, value, output, - self.layer_name) + self.layer_name, fp8_out_scale) else: torch.ops.vllm.unified_attention_with_output( - query, key, value, output, self.layer_name) + query, key, value, output, self.layer_name, fp8_out_scale) return output.view(-1, hidden_size) else: if self.use_direct_call: - return unified_attention(query, key, value, self.layer_name) + return unified_attention(query, key, value, self.layer_name, + fp8_out_scale) else: return torch.ops.vllm.unified_attention( - query, key, value, self.layer_name) + query, key, value, self.layer_name, fp8_out_scale) - def calc_kv_scales(self, key, value): + def calc_kv_scales(self, query, key, value): + self._q_scale.copy_(torch.abs(query).max() / self.q_range) self._k_scale.copy_(torch.abs(key).max() / self.k_range) self._v_scale.copy_(torch.abs(value).max() / self.v_range) self._k_scale_float = self._k_scale.item() @@ -283,12 +289,14 @@ def unified_attention( key: torch.Tensor, value: torch.Tensor, layer_name: str, + fp8_out_scale: Optional[torch.Tensor], ) -> torch.Tensor: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata self = forward_context.attn_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] - return self.impl.forward(self, query, key, value, kv_cache, attn_metadata) + return self.impl.forward(self, query, key, value, kv_cache, attn_metadata, + fp8_out_scale) def unified_attention_fake( @@ -296,6 +304,7 @@ def unified_attention_fake( key: torch.Tensor, value: torch.Tensor, layer_name: str, + fp8_out_scale: Optional[torch.Tensor], ) -> torch.Tensor: return torch.empty_like(query).contiguous() @@ -315,6 +324,7 @@ def unified_attention_with_output( value: torch.Tensor, output: torch.Tensor, layer_name: str, + fp8_out_scale: Optional[torch.Tensor], ) -> None: forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata @@ -326,6 +336,7 @@ def unified_attention_with_output( value, kv_cache, attn_metadata, + fp8_out_scale, output=output) @@ -335,6 +346,7 @@ def unified_attention_with_output_fake( value: torch.Tensor, output: torch.Tensor, layer_name: str, + fp8_out_scale: Optional[torch.Tensor], ) -> None: return diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 2c60bd0c38d6..ab63efeba499 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -6,13 +6,15 @@ import torch from vllm import _custom_ops as ops +from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON +from vllm.utils import is_navi if HAS_TRITON: from vllm.attention.ops.prefix_prefill import context_attention_fwd # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. -_PARTITION_SIZE = 512 +_PARTITION_SIZE = 512 if not current_platform.is_rocm() or is_navi() else 1024 @dataclass diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index fbb6757ee304..68bf368e30a7 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -739,7 +739,9 @@ def context_attention_fwd(q, assert (v_cache.dtype == torch.uint8) if kv_cache_dtype in ("fp8", "fp8_e4m3"): - target_dtype = torch.float8_e4m3fn + target_dtype = (torch.float8_e4m3fn + if not current_platform.is_rocm() else + torch.float8_e4m3fnuz) elif kv_cache_dtype == "fp8_e5m2": target_dtype = torch.float8_e5m2 else: diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index 745818eb6cff..12641996188f 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -25,6 +25,8 @@ import triton import triton.language as tl +from vllm.utils import is_navi + torch_dtype: tl.constexpr = torch.float16 @@ -105,6 +107,9 @@ def _attn_fwd_inner( ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, + USE_FP8: tl.constexpr, + qk_scale, + p_descale, ): # loop over k, v, and update accumulator for start_n in range(block_min, block_max, BLOCK_N): @@ -146,6 +151,8 @@ def _attn_fwd_inner( qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- qk += tl.dot(q, k) + if USE_FP8: + qk *= qk_scale if bias_ptr is not None: bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") @@ -197,7 +204,12 @@ def _attn_fwd_inner( l_i = l_i * alpha + l_ij # update m_i and l_i m_i = m_ij + + if USE_FP8: + p *= p_descale + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) if bias_ptr is not None: @@ -208,103 +220,183 @@ def _attn_fwd_inner( return acc, l_i, m_i -@triton.autotune( - configs=[ +def get_cdna_autotune_configs(): + return [ triton.Config( { - "BLOCK_M": 256, - "BLOCK_N": 64, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 256, + 'BLOCK_N': 64, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 256, - "BLOCK_N": 128, - "waves_per_eu": 2, - "PRE_LOAD_V": False, + 'BLOCK_M': 256, + 'BLOCK_N': 128, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 1, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 1, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": True, + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'PRE_LOAD_V': True }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 128, - "BLOCK_N": 64, - "waves_per_eu": 3, - "PRE_LOAD_V": False, + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'waves_per_eu': 3, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4, - ), + num_warps=4), triton.Config( { - "BLOCK_M": 64, - "BLOCK_N": 64, - "waves_per_eu": 4, - "PRE_LOAD_V": False, + 'BLOCK_M': 64, + 'BLOCK_N': 64, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), triton.Config( { - "BLOCK_M": 32, - "BLOCK_N": 32, - "waves_per_eu": 4, - "PRE_LOAD_V": False, + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=8, - ), + num_warps=8), # TODO: This config fails with head_size not pow2 with data mismatches. # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, # 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # "BLOCK_M": 16, + # "BLOCK_N": 16, + # "waves_per_eu": 1, + # "PRE_LOAD_V": False, + # }, + # num_stages=1, + # num_warps=4, + # ), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + + +def get_rdna_autotune_configs(): + return [ + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 32, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), triton.Config( { - "BLOCK_M": 16, - "BLOCK_N": 16, - "waves_per_eu": 1, - "PRE_LOAD_V": False, + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 4, + 'PRE_LOAD_V': False }, num_stages=1, - num_warps=4, - ), - ], - key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], + num_warps=2), + triton.Config( + { + 'BLOCK_M': 32, + 'BLOCK_N': 16, + 'waves_per_eu': 2, + 'PRE_LOAD_V': False + }, + num_stages=1, + num_warps=2), + # Fails in AccelerateAMDMatmul (Triton) assert when using FP8: + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 4, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 2, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + # # Fall-back config. + # triton.Config( + # { + # 'BLOCK_M': 16, + # 'BLOCK_N': 16, + # 'waves_per_eu': 1, + # 'PRE_LOAD_V': False + # }, + # num_stages=1, + # num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL', 'USE_FP8'] + + +def get_autotune_configs(): + if is_navi(): + return get_rdna_autotune_configs() + else: + return get_cdna_autotune_configs() + + +autotune_configs, autotune_keys = get_autotune_configs() + +float8_info = torch.finfo(torch.float8_e4m3fnuz) + + +@triton.autotune( + configs=autotune_configs, + key=autotune_keys, ) @triton.jit def attn_fwd( @@ -313,28 +405,34 @@ def attn_fwd( V, bias, sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, L, Out, - stride_qz, - stride_qh, - stride_qm, - stride_qk, - stride_kz, - stride_kh, - stride_kn, - stride_kk, - stride_vz, - stride_vh, - stride_vk, - stride_vn, - stride_oz, - stride_oh, - stride_om, - stride_on, - stride_bz, - stride_bh, - stride_bm, - stride_bn, + stride_qz: tl.int64, + stride_qh: tl.int64, + stride_qm: tl.int64, + stride_qk: tl.int64, + stride_kz: tl.int64, + stride_kh: tl.int64, + stride_kn: tl.int64, + stride_kk: tl.int64, + stride_vz: tl.int64, + stride_vh: tl.int64, + stride_vk: tl.int64, + stride_vn: tl.int64, + stride_oz: tl.int64, + stride_oh: tl.int64, + stride_om: tl.int64, + stride_on: tl.int64, + stride_bz: tl.int64, + stride_bh: tl.int64, + stride_bm: tl.int64, + stride_bn: tl.int64, cu_seqlens_q, cu_seqlens_k, dropout_p, @@ -350,11 +448,14 @@ def attn_fwd( IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + USE_FP8: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, + FP8_MIN: tl.constexpr = float8_info.min, + FP8_MAX: tl.constexpr = float8_info.max, ): start_m = tl.program_id(0) off_h_q = tl.program_id(1) @@ -508,7 +609,12 @@ def attn_fwd( qk_scale = sm_scale * 1.44269504089 # Q is loaded once at the beginning and shared by all N blocks. q = load_fn(Q_block_ptr, True, padded_head, "zero") - q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + if not USE_FP8: + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + acc_scale = 1.0 + else: + qk_scale *= q_scale * k_scale + acc_scale = p_scale * v_scale # Here we compute how many full and masked blocks we have. padded_block_k = n_extra_tokens != 0 @@ -563,6 +669,9 @@ def attn_fwd( ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head, + USE_FP8, + qk_scale, + p_descale, ) block_min = block_max block_max = n_blocks * BLOCK_N @@ -609,8 +718,14 @@ def attn_fwd( ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head, + USE_FP8, + qk_scale, + p_descale, ) # epilogue + + if USE_FP8: + acc *= acc_scale acc = acc / l_i[:, None] if ENABLE_DROPOUT: acc = acc / (1 - dropout_p) @@ -621,6 +736,9 @@ def attn_fwd( end_m_idx = (start_m + 1) * BLOCK_M start_m_idx = start_m * BLOCK_M causal_start_idx = seqlen_q - seqlen_k + if USE_FP8: + acc *= o_descale + acc = tl.clamp(acc, FP8_MIN, FP8_MAX) acc = acc.to(Out.type.element_ty) if IS_CAUSAL: # noqa: SIM102 if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: @@ -630,7 +748,7 @@ def attn_fwd( mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) out_ptrs_mask = (mask_m_offsets[:, None] >= out_mask_boundary[None, :]) - z = 0.0 + z = tl.zeros((1, ), tl.float32) acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m @@ -711,7 +829,29 @@ def forward( causal=False, sm_scale=1.0, bias=None, + fp8_scales=None, ): + if fp8_scales is not None: + use_fp8 = True + (q_scale, k_scale, v_scale, p_scale, o_scale) = fp8_scales + float8 = torch.float8_e4m3fnuz + + def check_and_convert(t, scale): + if t.dtype != float8: + descale = 1.0 / scale + ts = (t * descale).clamp(min=float8_info.min, + max=float8_info.max) + return ts.to(float8) + else: + return t + + q = check_and_convert(q, q_scale) + k = check_and_convert(k, k_scale) + v = check_and_convert(v, v_scale) + else: + use_fp8 = False + q_scale = k_scale = v_scale = p_scale = o_scale = 1.0 + if o is None: o = torch.empty_like(q, dtype=v.dtype) @@ -774,12 +914,24 @@ def forward( else: bias_strides = (0, 0, 0, 0) + p_descale = 1.0 / p_scale + o_descale = 1.0 / o_scale + + arg_max_seqlens_q = 0 if is_navi() else max_seqlens_q + arg_max_seqlens_k = 0 if is_navi() else max_seqlens_k + attn_fwd[grid]( q, k, v, bias, sm_scale, + q_scale, + k_scale, + v_scale, + p_scale, + p_descale, + o_descale, None, o, *q_strides, @@ -796,14 +948,15 @@ def forward( HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, - MAX_SEQLENS_Q=max_seqlens_q, - MAX_SEQLENS_K=max_seqlens_k, + MAX_SEQLENS_Q=arg_max_seqlens_q, + MAX_SEQLENS_K=arg_max_seqlens_k, IS_CAUSAL=causal, VARLEN=True, BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if bias is None else 1, ENABLE_DROPOUT=False, RETURN_ENCODED_SOFTMAX=False, + USE_FP8=use_fp8, ) ctx.grid = grid diff --git a/vllm/config.py b/vllm/config.py index d70a637956ed..c3a9a1c6d7ed 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -34,7 +34,8 @@ from vllm.transformers_utils.s3_utils import S3Model from vllm.transformers_utils.utils import is_s3 from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless, - get_cpu_memory, random_uuid, resolve_obj_by_qualname) + get_cpu_memory, is_mi250, is_navi, random_uuid, + resolve_obj_by_qualname) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -592,10 +593,12 @@ def _verify_quantization(self) -> None: # Detect which checkpoint is it for name in QUANTIZATION_METHODS: + from vllm.platforms import current_platform method = get_quantization_config(name) quantization_override = method.override_quantization_method( quant_cfg, self.quantization) - if quantization_override: + if (quantization_override and quantization_override + in current_platform.supported_quantization): quant_method = quantization_override self.quantization = quantization_override break @@ -1401,6 +1404,30 @@ def __post_init__(self) -> None: self._verify_args() + if is_mi250() and self.tensor_parallel_size > 1: + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "working correctly on multi AMD MI250.") + + if is_navi() and self.tensor_parallel_size <= 2: + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "working correctly when using two AMD Navi GPUs.") + + if is_mi250() and self.tensor_parallel_size > 1: + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "working correctly on multi AMD MI250.") + + if is_navi() and self.tensor_parallel_size <= 2: + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "working correctly when using two AMD Navi GPUs.") + @property def use_ray(self) -> bool: return self.distributed_executor_backend == "ray" or ( @@ -1410,7 +1437,6 @@ def use_ray(self) -> bool: def _verify_args(self) -> None: # Lazy import to avoid circular import from vllm.executor.executor_base import ExecutorBase - from vllm.platforms import current_platform if self.distributed_executor_backend not in ( "ray", "mp", "uni", "external_launcher", None) and not (isinstance( @@ -1424,11 +1450,12 @@ def _verify_args(self) -> None: if self.use_ray: from vllm.executor import ray_utils ray_utils.assert_ray_available() - if current_platform.is_rocm(): + if (not self.disable_custom_all_reduce and self.world_size > 1 + and self.pipeline_parallel_size > 1): self.disable_custom_all_reduce = True logger.info( "Disabled the custom all-reduce kernel because it is not " - "supported on AMD GPUs.") + "supported with pipeline parallelism.") if self.ray_workers_use_nsight and not self.use_ray: raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py index a2614ed5d0bd..37007e48e041 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce.py +++ b/vllm/distributed/device_communicators/custom_all_reduce.py @@ -1,8 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 -import ctypes from contextlib import contextmanager -from typing import List, Optional, Union +from typing import Any, List, Optional, Union import torch import torch.distributed as dist @@ -10,7 +9,6 @@ import vllm.envs as envs from vllm import _custom_ops as ops -from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary from vllm.distributed.device_communicators.custom_all_reduce_utils import ( gpu_p2p_access_check) from vllm.distributed.parallel_state import in_the_same_node_as @@ -22,7 +20,7 @@ ops.meta_size() custom_ar = True except Exception: - # For AMD GPUs and CPUs + # For CPUs custom_ar = False logger = init_logger(__name__) @@ -52,10 +50,15 @@ class CustomAllreduce: _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] # max_size: max supported allreduce size + _MAX_CAR_SIZE = 8192 * 1024 + if current_platform.is_rocm(): + # crossover is at 16MB buffer size for ROCm + _MAX_CAR_SIZE = 2 * 8192 * 1024 + def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], - max_size=8192 * 1024) -> None: + max_size=_MAX_CAR_SIZE) -> None: """ Args: group: the process group to work on. If None, it will use the @@ -128,10 +131,8 @@ def __init__(self, # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - assert current_platform.is_cuda() - from vllm.platforms.cuda import CudaPlatform - cuda_platform: CudaPlatform = current_platform - full_nvlink = cuda_platform.is_full_nvlink(physical_device_ids) + assert current_platform.is_cuda() or current_platform.is_rocm() + full_nvlink = current_platform.is_full_nvlink(physical_device_ids) if world_size > 2 and not full_nvlink: logger.warning( "Custom allreduce is disabled because it's not supported on" @@ -141,7 +142,8 @@ def __init__(self, # test P2P capability, this checks software/cudaruntime support # this is expensive to compute at the first time # then we cache the result - if not _can_p2p(rank, world_size): + # On AMD GPU, p2p is always enabled between XGMI connected GPUs + if not current_platform.is_rocm() and not _can_p2p(rank, world_size): logger.warning( "Custom allreduce is disabled because your platform lacks " "GPU P2P capability or P2P test failed. To silence this " @@ -149,14 +151,22 @@ def __init__(self, return self.disabled = False - # Buffers memory are owned by this Python class and passed to C++. - # Meta data composes of two parts: meta data for synchronization and a - # temporary buffer for storing intermediate allreduce results. - self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size, - group=group) + # buffers memory are owned by this Python class and passed to C++ + # meta data composes of two parts: meta data for synchronization + # (256 bytes) and a temporary buffer for storing intermediate + # allreduce results. + if current_platform.is_rocm(): + # meta data buffers need to be "uncached" for signal on MI200 + self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size) + else: + self.meta = torch.zeros(ops.meta_size() + max_size, + dtype=torch.uint8, + device=self.device) # This is a pre-registered IPC buffer. In eager mode, input tensors # are first copied into this buffer before allreduce is performed - self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + self.buffer = torch.empty(max_size, + dtype=torch.uint8, + device=self.device) # This is a buffer for storing the tuples of pointers pointing to # IPC buffers from all ranks. Each registered tuple has size of # 8*world_size bytes where world_size is at most 8. Allocating 8MB @@ -168,43 +178,21 @@ def __init__(self, self.max_size = max_size self.rank = rank self.world_size = world_size + if current_platform.is_rocm(): + # _share_cuda_() doesn't accept meta buffer not allocated from + # PyTorch cache allocator, use direct HIP call to get IPC handle + handle = ops.get_meta_buffer_ipc_handle(self.meta) + shard_data = ( + bytes(handle), # ipc handle to base ptr + 0, # offset of base ptr + ) + handles, offsets = self._gather_ipc_meta(shard_data) + else: + handles, offsets = self._get_ipc_meta(self.meta) self.full_nvlink = full_nvlink - self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank, - self.full_nvlink) - ops.register_buffer(self._ptr, self.buffer_ptrs) - - @staticmethod - def create_shared_buffer( - size_in_bytes: int, - group: Optional[ProcessGroup] = None) -> List[int]: - """ - Creates a shared buffer and returns a list of pointers - representing the buffer on all processes in the group. - """ - lib = CudaRTLibrary() - pointer = lib.cudaMalloc(size_in_bytes) - handle = lib.cudaIpcGetMemHandle(pointer) - world_size = dist.get_world_size(group=group) - rank = dist.get_rank(group=group) - handles = [None] * world_size - dist.all_gather_object(handles, handle, group=group) - - pointers: List[int] = [] - for i, h in enumerate(handles): - if i == rank: - pointers.append(pointer.value) # type: ignore - else: - pointers.append( - lib.cudaIpcOpenMemHandle(h).value) # type: ignore - - return pointers - - @staticmethod - def free_shared_buffer(pointers: List[int], - group: Optional[ProcessGroup] = None) -> None: - rank = dist.get_rank(group=group) - lib = CudaRTLibrary() - lib.cudaFree(ctypes.c_void_p(pointers[rank])) + self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles, + offsets, rank, self.full_nvlink) + self.register_buffer(self.buffer) @contextmanager def capture(self): @@ -221,24 +209,69 @@ def capture(self): if not self.disabled: self.register_graph_buffers() - def register_graph_buffers(self): - handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) - logger.info("Registering %d cuda graph addresses", len(offset)) - # We cannot directly use `dist.all_gather_object` here - # because it is incompatible with `gloo` backend under inference mode. - # see https://github.com/pytorch/pytorch/issues/126032 for details. - all_data = [[None, None] - for _ in range(dist.get_world_size(group=self.group))] - all_data[self.rank] = [handle, offset] - ranks = sorted(dist.get_process_group_ranks(group=self.group)) + def _get_ipc_meta(self, inp: torch.Tensor): + if current_platform.is_rocm(): + # _share_cuda_() doesn't accept meta buffer not allocated from + # PyTorch cache allocator, use direct HIP call to get IPC handle + handle = ops.get_meta_buffer_ipc_handle(inp) + shard_data = ( + bytes(handle), # ipc handle to base ptr + 0, # offset of base ptr + ) + else: + data = inp.untyped_storage()._share_cuda_() + handle = data[1] + # https://github.com/pytorch/pytorch/pull/130890 changes + # the binary format of the ipc handle + # it starts from pytorch 2.5 + if len(handle) > 64: + assert len(handle) == 66 + # only support SHAREABLE_HANDLE_VERSION = 1 + assert int(handle[0]) == 1 + # only support SHAREABLE_CUDA_MALLOC = 'c' + assert handle[1] == ord("c") + handle = handle[2:] + # TODO: support expandable segment + shard_data = ( + handle, # ipc handle to base ptr + data[3], # offset of base ptr + ) + return self._gather_ipc_meta(shard_data) + + def _gather_ipc_meta(self, shard_data): + # Note: don't use `[[None]] * self.world_size` here + # because it will create a list of the same reference + all_data: List[Optional[Any]] = [[None] + for i in range(self.world_size)] + all_data[self.rank][0] = shard_data + + ranks = dist.get_process_group_ranks(group=self.group) + ranks.sort() for i, rank in enumerate(ranks): dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu") - # Unpack list of tuples to tuple of lists. - handles = [d[0] for d in all_data] # type: ignore - offsets = [d[1] for d in all_data] # type: ignore + + # we cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. + + handles = [] + offsets = [] + for i in range(len(all_data)): + handles.append(all_data[i][0][0]) # type: ignore + offsets.append(all_data[i][0][1]) # type: ignore + return handles, offsets + + def register_buffer(self, inp: torch.Tensor): + handles, offsets = self._get_ipc_meta(inp) + ops.register_buffer(self._ptr, inp, handles, offsets) + + def register_graph_buffers(self): + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) + logger.info("Registering %d cuda graph addresses", len(offset)) ops.register_graph_buffers(self._ptr, handles, offsets) def should_custom_ar(self, inp: torch.Tensor): @@ -256,50 +289,45 @@ def should_custom_ar(self, inp: torch.Tensor): return inp_size < self.max_size return False - def all_reduce(self, - inp: torch.Tensor, - *, - out: torch.Tensor = None, - registered: bool = False): - """Performs an out-of-place all reduce. - - If registered is True, this assumes inp's pointer is already - IPC-registered. Otherwise, inp is first copied into a pre-registered - buffer. - """ + # all reduce, assuming inp tensor is IPC registered with register_buffer, + # or, in the context of cuda graphs, register_graph_buffers + def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): if out is None: out = torch.empty_like(inp) - if registered: - ops.all_reduce(self._ptr, inp, out, 0, 0) - else: - ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank], - self.max_size) + ops.all_reduce_reg(self._ptr, inp, out) + return out + + # all reduce, assuming inp tensor is NOT IPC registered + def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) return out def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: - """The main allreduce API that provides support for cuda graph.""" - # When custom allreduce is disabled, this will be None. + # when custom allreduce is disabled, this will be None if self.disabled or not self.should_custom_ar(input): return None if self._IS_CAPTURING: if torch.cuda.is_current_stream_capturing(): - return self.all_reduce(input, registered=True) + return self.all_reduce_reg(input) else: - # If warm up, mimic the allocation pattern since custom - # allreduce is out-of-place. + # if warm up, mimic the allocation pattern + # since custom allreduce is out-of-place return torch.empty_like(input) else: - # Note: outside of cuda graph context, custom allreduce incurs a - # cost of cudaMemcpy, which should be small (<=1% of overall - # latency) compared to the performance gain of using custom kernels - return self.all_reduce(input, registered=False) + # note: outside of cuda graph context, + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels + return self.all_reduce_unreg(input) + + return None def close(self): if not self.disabled and self._ptr: ops.dispose(self._ptr) self._ptr = 0 - self.free_shared_buffer(self.meta_ptrs) - self.free_shared_buffer(self.buffer_ptrs) def __del__(self): self.close() diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 0ccd423121cb..4d2fd67d47d6 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 +from contextlib import contextmanager from typing import Optional, Union # ===================== import region ===================== @@ -12,7 +13,6 @@ ncclRedOpTypeEnum, ncclUniqueId) from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import current_stream logger = init_logger(__name__) @@ -53,6 +53,7 @@ def __init__( if self.world_size == 1: self.available = False self.disabled = True + self.stream = None return try: self.nccl = NCCLLibrary(library_path) @@ -61,6 +62,7 @@ def __init__( # e.g. in a non-GPU environment self.available = False self.disabled = True + self.stream = None return self.available = True @@ -98,12 +100,12 @@ def __init__( with torch.cuda.device(device): self.comm: ncclComm_t = self.nccl.ncclCommInitRank( self.world_size, self.unique_id, self.rank) + self.stream = torch.cuda.Stream() - stream = current_stream() # A small all_reduce for warmup. data = torch.zeros(1, device=device) self.all_reduce(data) - stream.synchronize() + self.stream.synchronize() del data def all_reduce(self, @@ -122,7 +124,7 @@ def all_reduce(self, out_tensor = torch.empty_like(in_tensor) if stream is None: - stream = current_stream() + stream = self.stream self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), buffer_type(out_tensor.data_ptr()), in_tensor.numel(), @@ -144,7 +146,7 @@ def all_gather(self, f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {input_tensor.device}") if stream is None: - stream = current_stream() + stream = self.stream self.nccl.ncclAllGather( buffer_type(input_tensor.data_ptr()), buffer_type(output_tensor.data_ptr()), input_tensor.numel(), @@ -165,7 +167,7 @@ def reduce_scatter(self, f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {input_tensor.device}") if stream is None: - stream = current_stream() + stream = self.stream self.nccl.ncclReduceScatter( buffer_type(input_tensor.data_ptr()), buffer_type(output_tensor.data_ptr()), output_tensor.numel(), @@ -180,7 +182,7 @@ def send(self, tensor: torch.Tensor, dst: int, stream=None): f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}") if stream is None: - stream = current_stream() + stream = self.stream self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), dst, self.comm, cudaStream_t(stream.cuda_stream)) @@ -192,7 +194,7 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}") if stream is None: - stream = current_stream() + stream = self.stream self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), src, self.comm, cudaStream_t(stream.cuda_stream)) @@ -204,7 +206,7 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None): f"this nccl communicator is created to work on {self.device}, " f"but the input tensor is on {tensor.device}") if stream is None: - stream = current_stream() + stream = self.stream if src == self.rank: sendbuff = buffer_type(tensor.data_ptr()) # NCCL requires the sender also to have a receive buffer @@ -215,3 +217,27 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None): self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), ncclDataTypeEnum.from_torch(tensor.dtype), src, self.comm, cudaStream_t(stream.cuda_stream)) + + @contextmanager + def change_state(self, + enable: Optional[bool] = None, + stream: Optional[torch.cuda.Stream] = None): + """ + A context manager to change the state of the communicator. + """ + if enable is None: + # guess a default value when not specified + enable = self.available + + if stream is None: + stream = self.stream + + old_disable = self.disabled + old_stream = self.stream + + self.stream = stream + self.disabled = not enable + yield + + self.disabled = old_disable + self.stream = old_stream diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index c5c5dfbbab76..b714bb452e5e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -41,7 +41,8 @@ import vllm.envs as envs from vllm.distributed.utils import StatelessProcessGroup from vllm.logger import init_logger -from vllm.utils import direct_register_custom_op, supports_custom_op +from vllm.utils import (current_stream, direct_register_custom_op, + supports_custom_op) if TYPE_CHECKING: from vllm.config import VllmConfig @@ -307,7 +308,15 @@ def graph_capture( stream.wait_stream(curr_stream) with torch.cuda.stream(stream), maybe_ca_context: - yield graph_capture_context + pynccl_comm = self.pynccl_comm + maybe_pynccl_context: Any + if not pynccl_comm: + maybe_pynccl_context = nullcontext() + else: + maybe_pynccl_context = pynccl_comm.change_state( + stream=torch.cuda.current_stream()) + with maybe_pynccl_context: + yield graph_capture_context def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: """ @@ -359,7 +368,7 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: return out pynccl_comm = self.pynccl_comm assert pynccl_comm is not None - out = pynccl_comm.all_reduce(input_) + out = pynccl_comm.all_reduce(input_, stream=current_stream()) if out is None: # fall back to the default all-reduce using PyTorch. # this usually happens during testing. diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d82d9ad9df32..7471476273dc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1211,7 +1211,7 @@ def _process_model_outputs(self, return None def _advance_to_next_step( - self, output: List[SamplerOutput], + self, output: SamplerOutput, seq_group_metadata_list: List[SequenceGroupMetadata], scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: """Given model output from a single run, append the tokens to the diff --git a/vllm/envs.py b/vllm/envs.py index 5018f6deb7f4..7543ef76dfb2 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -12,7 +12,15 @@ VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60 VLLM_NCCL_SO_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None - VLLM_USE_TRITON_FLASH_ATTN: bool = False + VLLM_ROCM_PREFER_TORCH: bool = False + VLLM_ROCM_PREFER_TRITON: bool = True + VLLM_USE_SDPA_ATTENTION: bool = False + VLLM_USE_TRITON_FLASH_ATTN: bool = True + VLLM_USE_ROCM_SKINNY_GEMM: bool = True + VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True + VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = True + VLLM_USE_ROCM_FP8_FLASH_ATTN: bool = False + RANK: int = 0 VLLM_FLASH_ATTN_VERSION: Optional[int] = None LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -67,16 +75,20 @@ VLLM_RPC_TIMEOUT: int = 10000 # ms VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_RPD_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False VLLM_SKIP_P2P_CHECK: bool = False VLLM_DISABLED_KERNELS: List[str] = [] VLLM_USE_V1: bool = False + VLLM_MOE_PADDING: bool = False + VLLM_FP8_PADDING: bool = True VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False - K_SCALE_CONSTANT: int = 200 - V_SCALE_CONSTANT: int = 100 + Q_SCALE_CONSTANT: int = 20 + K_SCALE_CONSTANT: int = 20 + V_SCALE_CONSTANT: int = 10 VLLM_SERVER_DEV_MODE: bool = False VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_MLA_DISABLE: bool = False @@ -213,6 +225,21 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None), + # flag to tell vllm to prefer torch on ROCm + "VLLM_ROCM_PREFER_TORCH": + lambda: (os.environ.get("VLLM_ROCM_PREFER_TORCH", "False").lower() in + ("true", "1")), + + # flag to tell vllm to prefer triton on ROCm + "VLLM_ROCM_PREFER_TRITON": + lambda: (os.environ.get("VLLM_ROCM_PREFER_TRITON", "True").lower() in + ("true", "1")), + + # flag to control if vllm should use naive scaled dot-product attention + "VLLM_USE_SDPA_ATTENTION": + lambda: (os.environ.get("VLLM_USE_SDPA_ATTENTION", "False").lower() in + ("true", "1")), + # flag to control if vllm should use triton flash attention "VLLM_USE_TRITON_FLASH_ATTN": lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in @@ -228,6 +255,32 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: bool( os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"), + # small gemms custom implementation for MI3* cards + "VLLM_USE_ROCM_SKINNY_GEMM": + lambda: (os.getenv("VLLM_USE_ROCM_SKINNY_GEMM", "True").lower() in + ("true", "1")), + + # custom paged attention implemented for MI3* cards + "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN": + lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in + ("true", "1")), + + # have custom paged attention implemented for MI3* cards write out fp8 + "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT": + lambda: + (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "True").lower() in + ("true", "1")), + + # use quantized q,k,v,softmax(qk^T), attn output during prefill + "VLLM_USE_ROCM_FP8_FLASH_ATTN": + lambda: (os.getenv("VLLM_USE_ROCM_FP8_FLASH_ATTN", "False").lower() in + ("true", "1")), + + # rank of the process in the distributed setting, used to determine + # the driver worker + "RANK": + lambda: int(os.environ.get("RANK", "0")), + # local rank of the process in the distributed setting, used to determine # the GPU device id "LOCAL_RANK": @@ -406,7 +459,7 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: os.path.join(get_default_cache_root(), "vllm", "xla_cache"), )), "VLLM_FUSED_MOE_CHUNK_SIZE": - lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), + lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")), # If set, vllm will skip the deprecation warnings. "VLLM_NO_DEPRECATION_WARNING": @@ -453,6 +506,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), + # Enables rpd profiler if set. Path to the directory where torch profiler + # traces are saved. Note that it must be an absolute path. + "VLLM_RPD_PROFILER_DIR": + lambda: (None if os.getenv("VLLM_RPD_PROFILER_DIR", None) is None else os. + path.expanduser(os.getenv("VLLM_RPD_PROFILER_DIR", "."))), + # If set, vLLM will use Triton implementations of AWQ. "VLLM_USE_TRITON_AWQ": lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), @@ -482,13 +541,27 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), - # Divisor for dynamic key scale factor calculation for FP8 KV Cache + # Pad the weight for moe kernel or not + "VLLM_MOE_PADDING": + lambda: bool(int(os.getenv("VLLM_MOE_PADDING", "0"))), + + # Pad the weight for moe kernel or not + "VLLM_FP8_PADDING": + lambda: bool(int(os.getenv("VLLM_FP8_PADDING", "1"))), + + # Divisor for dynamic query scale factor calculation for FP8 attention + "Q_SCALE_CONSTANT": + lambda: int(os.getenv("Q_SCALE_CONSTANT", "20")), + + # Divisor for dynamic key scale factor calculation + # for FP8 KV Cache and attention "K_SCALE_CONSTANT": - lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), + lambda: int(os.getenv("K_SCALE_CONSTANT", "20")), - # Divisor for dynamic value scale factor calculation for FP8 KV Cache + # Divisor for dynamic value scale factor calculation + # for FP8 KV Cache and attention "V_SCALE_CONSTANT": - lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), + lambda: int(os.getenv("V_SCALE_CONSTANT", "10")), # If set, enable multiprocessing in LLM for the V1 code path. "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index f782920d06a0..5197200f2aaa 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -75,11 +75,21 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 return F.silu(x[..., :d]) * x[..., d:] - def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: + def forward_cuda(self, + x: torch.Tensor, + scale: Optional[torch.Tensor] = None) -> torch.Tensor: + d = x.shape[-1] // 2 output_shape = (x.shape[:-1] + (d, )) - out = torch.empty(output_shape, dtype=x.dtype, device=x.device) - self.op(out, x) + if scale is None: + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + self.op(out, x) + else: + # for scaled fp8 output + out = torch.empty(output_shape, + dtype=torch.float8_e4m3fnuz, + device=x.device) + torch.ops._C.scaled_silu_and_mul(out, x, scale) return out def forward_xpu(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py old mode 100644 new mode 100755 index 6f933c3fa3c9..a71628df00c4 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -37,7 +37,7 @@ def get_config() -> Optional[Dict[str, Any]]: import vllm.model_executor.layers.fused_moe.fused_moe # noqa from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts, fused_moe, fused_topk, get_config_file_name, - grouped_topk) + grouped_topk, invoke_fused_moe_kernel, moe_align_block_size) __all__ += [ "fused_moe", @@ -45,4 +45,6 @@ def get_config() -> Optional[Dict[str, Any]]: "fused_experts", "get_config_file_name", "grouped_topk", + "invoke_fused_moe_kernel", + "moe_align_block_size", ] diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json index 66f9106bd1be..022d5ece7f87 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json @@ -45,8 +45,8 @@ }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2, @@ -96,7 +96,7 @@ "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, - "kpack": 2 + "kpack": 1 }, "96": { "BLOCK_SIZE_M": 32, @@ -123,7 +123,7 @@ "256": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 8, "num_stages": 2, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..66aa2600226d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X_OAM.json new file mode 100644 index 000000000000..83be69c7e61f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..f245285bd821 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325_OAM.json new file mode 100644 index 000000000000..3918c93b160a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 000000000000..5154a3ec25ad --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,233 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..58f9e38f5221 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X_OAM.json new file mode 100644 index 000000000000..3ee1a5c267dc --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..16e0a91baf31 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325_OAM.json new file mode 100644 index 000000000000..d766fc062ddc --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json index 1b46cb571651..de2320e4b28c 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json @@ -8,7 +8,7 @@ "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, - "kpack": 2 + "kpack": 1 }, "2": { "BLOCK_SIZE_M": 16, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..90d0e6f6ba3f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X_OAM.json new file mode 100644 index 000000000000..193782905a72 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..6d5b1ae5b15f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325_OAM.json new file mode 100644 index 000000000000..ffc1b23ea90d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 000000000000..5448f4bc3273 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..9d1b36acd64d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X_OAM.json new file mode 100644 index 000000000000..2daaea099d09 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..2758e48fc406 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325_OAM.json new file mode 100644 index 000000000000..fc31215cbae8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 000000000000..51752fd8a1d2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json index ed5b655d8993..5a3f415d5414 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json @@ -45,8 +45,8 @@ }, "16": { "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 2, diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..9e28dade2cee --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X_OAM.json new file mode 100644 index 000000000000..f885cd13a4ad --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..6cb80f48329f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325_OAM.json new file mode 100644 index 000000000000..de9d0aba75a7 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 000000000000..c5fa16efdedd --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,233 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..a971953062cf --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM.json new file mode 100644 index 000000000000..4edf28f3e7c2 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..2c49f359c22a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325_OAM.json new file mode 100644 index 000000000000..c7db6c0cbd3f --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 000000000000..d5d93c53ae24 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,233 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json index 822f04e33e87..8dec5e3afaba 100644 --- a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json @@ -8,7 +8,7 @@ "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, - "kpack": 2 + "kpack": 1 }, "2": { "BLOCK_SIZE_M": 16, @@ -63,7 +63,7 @@ "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, - "kpack": 1 + "kpack": 2 }, "32": { "BLOCK_SIZE_M": 16, @@ -128,7 +128,7 @@ "num_warps": 8, "num_stages": 2, "waves_per_eu": 0, - "matrix_instr_nonkdim": 32, + "matrix_instr_nonkdim": 16, "kpack": 2 }, "512": { diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..1bad9550f060 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X_OAM.json new file mode 100644 index 000000000000..f6d70ae78eab --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..7a07bbf41419 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325_OAM.json new file mode 100644 index 000000000000..3a3268cc17a8 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 000000000000..be63f0128f5a --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,233 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..91260051a533 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X_OAM.json new file mode 100644 index 000000000000..d6220f55015d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..c27ca0a36594 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8.json @@ -0,0 +1,164 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0 + }, + "64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "96": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325_OAM.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325_OAM.json new file mode 100644 index 000000000000..da477b1fb15e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325_OAM.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics.json b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics.json new file mode 100644 index 000000000000..f246469d3b55 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics.json @@ -0,0 +1,233 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1bed35525e9d..81501b343138 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -18,6 +18,7 @@ from vllm.utils import direct_register_custom_op logger = init_logger(__name__) +padding_size = 128 if envs.VLLM_MOE_PADDING else 0 @triton.jit @@ -726,7 +727,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, expert_ids, num_tokens_post_padded, B.shape[1], - A.shape[1], + A.shape[1] - padding_size, EM, topk_ids.numel(), A.stride(0), @@ -1137,7 +1138,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, assert hidden_states.shape[1] // 2 == w1.shape[ 2], "Hidden size mismatch" else: - assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert hidden_states.shape[ + 1] == w1.shape[2] - padding_size, "Hidden size mismatch" assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" @@ -1161,7 +1163,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, get_config_func = functools.partial( try_get_optimal_moe_config, w1.shape, - w2.shape, + (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size), topk_ids.shape[1], config_dtype, block_shape=block_shape, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3c7ef5e0080f..05442d415865 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -5,7 +5,9 @@ from typing import Callable, List, Optional, Tuple import torch +import torch.nn.functional as F +import vllm.envs as envs from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) @@ -92,6 +94,17 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) + if envs.VLLM_MOE_PADDING: + layer.w13_weight = torch.nn.Parameter(F.pad( + layer.w13_weight.data, (0, 128), "constant", 0), + requires_grad=False) + torch.cuda.empty_cache() + layer.w2_weight = torch.nn.Parameter(F.pad(layer.w2_weight.data, + (0, 128), "constant", + 0), + requires_grad=False) + torch.cuda.empty_cache() + if current_platform.is_cpu(): if current_platform.get_cpu_architecture() == CpuArchEnum.X86: import intel_extension_for_pytorch as ipex diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index b476fb0dbc7e..93ea30938ff0 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -77,12 +77,24 @@ def forward_cuda( self, x: torch.Tensor, residual: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if self.variance_size_override is not None: return self.forward_native(x, residual) from vllm import _custom_ops as ops + if scale is not None: + out = torch.empty_like(x, dtype=torch.float8_e4m3fnuz) + if residual is not None: + ops.scaled_fused_add_rms_norm(out, x, residual, + self.weight.data, scale, + self.variance_epsilon) + return out, residual + ops.scaled_rms_norm(out, x, self.weight.data, scale, + self.variance_epsilon) + return out + if residual is not None: ops.fused_add_rms_norm( x, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 08f1e103e53b..4d10c7651e85 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -5,7 +5,6 @@ from typing import Dict, List, Optional, Tuple import torch -import torch.nn.functional as F from torch.nn.parameter import Parameter, UninitializedParameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -16,6 +15,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.tuned_gemm import tgemm # yapf: disable from vllm.model_executor.parameter import (BasevLLMParameter, BlockQuantScaleParameter, @@ -138,8 +138,7 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - - return F.linear(x, layer.weight, bias) + return tgemm.mm(x, layer.weight, bias) class LinearBase(torch.nn.Module): diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index ff77af44d770..8cce8c3c91bd 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -4,6 +4,7 @@ import torch +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) @@ -172,8 +173,15 @@ def apply(self, # num_tokens >= threshold FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256 - if FP16_MATMUL_HEURISTIC_CONDITION: - out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) + prefer_torch = envs.VLLM_ROCM_PREFER_TORCH + prefer_triton = envs.VLLM_ROCM_PREFER_TRITON + + if (FP16_MATMUL_HEURISTIC_CONDITION + or (prefer_torch and not prefer_triton)): + if prefer_triton: + out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) + else: + out = torch_awq_dequantize(qweight, scales, qzeros) out = torch.matmul(reshaped_x, out) else: out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, @@ -181,3 +189,44 @@ def apply(self, if bias is not None: out.add_(bias) return out.reshape(out_shape) + + +def torch_awq_dequantize(qweights: torch.Tensor, scales: torch.Tensor, + qzeros: torch.Tensor) -> torch.Tensor: + reverse_awq_func_desc = torch.tensor([0, 16, 4, 20, 8, 24, 12, 28], + dtype=torch.int32, + device=qweights.device) + if qzeros is None: + qzeros = torch.zeros_like(qweights) + + while qweights.dim() < 2: + qweights = torch.unsqueeze(qweights, 0) + while qzeros.dim() < 2: + qzeros = torch.unsqueeze(qzeros, 0) + while scales.dim() < 2: + scales = torch.unsqueeze(scales, 0) + + rows = qweights.size(-2) + group_size_zeros = rows // qzeros.size(-2) + group_size_scales = rows // scales.size(-2) + + qweights_shape = list(qweights.shape) + qweights_shape[-1] *= 8 + qzeros_shape = list(qzeros.shape) + qzeros_shape[-1] *= 8 + + qweights = torch.unsqueeze(qweights, -1) + qzeros = torch.unsqueeze(qzeros, -1) + + unpacked_weights = torch.bitwise_right_shift(qweights, + reverse_awq_func_desc) + unpacked_weights = torch.bitwise_and(unpacked_weights, 0xf) + unpacked_weights = unpacked_weights.to(torch.int8).view(qweights_shape) + + unpacked_zeros = torch.bitwise_right_shift(qzeros, reverse_awq_func_desc) + unpacked_zeros = torch.bitwise_and(unpacked_zeros, 0xf) + unpacked_zeros = unpacked_zeros.to(torch.int8).view(qzeros_shape) + unpacked_zeros = unpacked_zeros.repeat_interleave(group_size_zeros, dim=-2) + + functional_scales = scales.repeat_interleave(group_size_scales, dim=-2) + return (unpacked_weights - unpacked_zeros) * functional_scales diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 24f7542e1238..5b07569164e8 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -451,6 +451,10 @@ def get_cache_scale(self, name: str) -> Optional[str]: return name.replace(".k_proj.output_scale", ".attn.k_scale") if name.endswith(".output_scale") and ".v_proj" in name: return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") # If no matches, return None return None diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py index 5dcc41a9e5da..b841d24e7a19 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py @@ -23,6 +23,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): def __init__(self, strategy: str, is_static_input_scheme: bool): self.strategy = strategy + self.out_dtype = torch.get_default_dtype() self.is_static_input_scheme = is_static_input_scheme self.cutlass_fp8_supported = cutlass_fp8_supported() @@ -142,6 +143,7 @@ def apply_weights(self, input=x, weight=layer.weight, weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, input_scale=layer.input_scale, bias=bias, cutlass_fp8_supported=self.cutlass_fp8_supported, diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index da5ef36c5105..864b0fa63561 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -72,6 +72,7 @@ class FBGEMMFp8LinearMethod(LinearMethodBase): def __init__(self, quant_config: FBGEMMFp8Config): self.quant_config = quant_config self.cutlass_fp8_supported = cutlass_fp8_supported() + self.out_dtype = torch.get_default_dtype() def create_weights( self, @@ -160,6 +161,7 @@ def apply(self, input=x, weight=layer.weight, weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, input_scale=None, input_scale_ub=layer.input_scale_ub, bias=bias, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 86e025310f4e..3ea954bf9144 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional import torch +import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter @@ -31,6 +32,7 @@ PerTensorScaleParameter) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform +from vllm.utils import is_navi ACTIVATION_SCHEMES = ["static", "dynamic"] @@ -114,6 +116,26 @@ def get_quant_method(self, layer: torch.nn.Module, return Fp8KVCacheMethod(self) return None + def get_cache_scale(self, name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in compressed-tensors. If this is the case, return its equivalent + param name expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") + # If no matches, return None + return None + class Fp8LinearMethod(LinearMethodBase): """Linear method for FP8. @@ -140,6 +162,7 @@ def __init__(self, quant_config: Fp8Config): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization + self.out_dtype = torch.get_default_dtype() self.use_marlin = (not current_platform.has_device_capability(89) or envs.VLLM_TEST_FORCE_FP8_MARLIN) # Disable marlin for rocm @@ -165,6 +188,8 @@ def create_weights( weight_loader = extra_weight_attrs.get("weight_loader") if self.block_quant: + assert not envs.VLLM_FP8_PADDING, ( + "FP8 weight padding is not supported in block quantization.") tp_size = get_tensor_model_parallel_world_size() assert self.quant_config.weight_block_size is not None block_n, block_k = ( @@ -252,7 +277,7 @@ def process_weights_after_loading(self, layer: Module) -> None: # TODO(rob): refactor block quant into separate class. if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" - if current_platform.is_rocm(): + if current_platform.is_rocm() and not is_navi(): weight, weight_scale_inv, _ = \ normalize_e4m3fn_to_e4m3fnuz( weight=layer.weight, @@ -288,9 +313,13 @@ def process_weights_after_loading(self, layer: Module) -> None: # If checkpoint is fp8, handle that there are N scales for N # shards in a fused module else: + layer.weight_scale.data[layer.weight_scale.data == torch.finfo( + torch.float32).min] = 1 layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data, requires_grad=False) if self.quant_config.activation_scheme == "static": + layer.input_scale.data[layer.input_scale.data == torch.finfo( + torch.float32).min] = 1 layer.input_scale = torch.nn.Parameter(layer.input_scale.data, requires_grad=False) # If using marlin (w8a16), kernel uses channelwise weights, @@ -307,8 +336,8 @@ def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight weight_scale = layer.weight_scale - # If rocm, use float8_e4m3fnuz. - if current_platform.is_rocm(): + # If rocm (except Navi4x), use float8_e4m3fnuz. + if current_platform.is_rocm() and not is_navi(): weight, weight_scale, input_scale = \ normalize_e4m3fn_to_e4m3fnuz( weight=weight, @@ -324,6 +353,14 @@ def process_weights_after_loading(self, layer: Module) -> None: logical_widths=layer.logical_widths, ) + # Pad the weight + if envs.VLLM_FP8_PADDING and weight.stride(-1) == 1 \ + and (weight.stride(-2) * weight.element_size()) % 512 == 0: + num_pad = 256 // weight.element_size() + weight = F.pad(weight, (0, num_pad), "constant", + 0)[..., :-num_pad] + torch.cuda.empty_cache() + # Update layer with new values. layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) @@ -370,6 +407,7 @@ def apply(self, input=x, weight=layer.weight, weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, input_scale=layer.input_scale, bias=bias, cutlass_fp8_supported=self.cutlass_fp8_supported, @@ -519,7 +557,7 @@ def process_weights_after_loading(self, layer: Module) -> None: # TODO (rob): refactor block quant into separate class. if self.block_quant: assert self.quant_config.activation_scheme == "dynamic" - if current_platform.is_rocm(): + if current_platform.is_rocm() and not is_navi(): w13_weight, w13_weight_scale_inv, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( layer.w13_weight, layer.w13_weight_scale_inv, @@ -545,9 +583,9 @@ def process_weights_after_loading(self, layer: Module) -> None: # If checkpoint is fp16, quantize in place. if not self.quant_config.is_checkpoint_fp8_serialized: - # If rocm, use float8_e4m3fnuz as dtype - fp8_dtype = torch.float8_e4m3fnuz \ - if current_platform.is_rocm() else torch.float8_e4m3fn + # If rocm (except Navi4x), use float8_e4m3fnuz as dtype + fp8_dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm() + and not is_navi() else torch.float8_e4m3fn) w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype) w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype) @@ -594,8 +632,9 @@ def process_weights_after_loading(self, layer: Module) -> None: layer.w13_input_scale.max(), requires_grad=False) layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False) - # If rocm, normalize the weights and scales to e4m3fnuz - if current_platform.is_rocm(): + # If rocm (except Navi4x, which uses e4m3fn), + # normalize the weights and scales to e4m3fnuz + if current_platform.is_rocm() and not is_navi(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 388a4f16699c..a63682d3e115 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -2,10 +2,12 @@ import torch +import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.platforms import current_platform +from vllm.utils import is_navi logger = init_logger(__name__) @@ -35,57 +37,91 @@ def create_weights(self, layer: torch.nn.Module): requires_grad=False) layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + # Initialize Q and P = softmax(QK^T) scales + layer.q_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) + layer.prob_scale = torch.nn.Parameter(torch.tensor(-1.0), + requires_grad=False) def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError( f"{self.__class__.__name__}.apply should not be called.") def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 - # regardless whether the kv-scale is available in the checkpoint. - # No need to process kv scales after loading if we are going to - # calculate them on the fly. - if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: - if layer.k_scale > 0.0 and layer.v_scale > 0.0: - # We prefer to use separate k_scale and v_scale if present - k_scale = layer.k_scale.to("cpu").tolist() - v_scale = layer.v_scale.to("cpu").tolist() - if current_platform.is_rocm(): - k_scale *= 2 - v_scale *= 2 - elif layer.k_scale < 0.0 and layer.v_scale < 0.0: - # If no scales were loaded (both scales are invalid negative - # values), use the default value of 1.0 - k_scale = 1.0 - v_scale = 1.0 - else: - # If we find a single kv_scale in the checkpoint, we remap - # kv_scale to k_scale during weight loading, and duplicate - # k_scale to v_scale here - assert layer.k_scale > 0.0 - scale_to_duplicate = max(layer.k_scale, layer.v_scale) - k_scale = scale_to_duplicate.to("cpu").tolist() - v_scale = scale_to_duplicate.to("cpu").tolist() - if current_platform.is_rocm(): - k_scale *= 2 - v_scale *= 2 - - if not isinstance(k_scale, float) or not isinstance( - v_scale, float): - raise ValueError("Only support per-tensor scaling factor " - "for fp8 KV cache") - - # These are used in the final Attention.forward() - layer._k_scale.copy_(k_scale) - layer._v_scale.copy_(v_scale) - layer._k_scale_float = k_scale - layer._v_scale_float = v_scale - if (k_scale == 1.0 and v_scale == 1.0 - and "e5m2" not in layer.kv_cache_dtype): - logger.warning_once( - "Using KV cache scaling factor 1.0 for fp8_e4m3. This " - "may cause accuracy issues. Please make sure k/v_scale " - "scaling factors are available in the fp8 checkpoint.") + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + k_scale *= 2 + v_scale *= 2 + layer.calculate_kv_scales = False + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + k_scale *= 2 + v_scale *= 2 + layer.calculate_kv_scales = False + + if not isinstance(k_scale, float) or not isinstance(v_scale, float): + raise ValueError("Only support per-tensor scaling factor " + "for fp8 KV cache") + + # These are used in the final Attention.forward() + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale + if (k_scale == 1.0 and v_scale == 1.0 + and (layer.kv_cache_dtype != "auto" + or envs.VLLM_USE_ROCM_FP8_FLASH_ATTN) + and "e5m2" not in layer.kv_cache_dtype): + logger.warning_once( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint.") + + if layer.q_scale > 0.0: + q_scale = layer.q_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + q_scale *= 2 + layer.calculate_kv_scales = False + else: + q_scale = 1.0 + if layer.prob_scale > 0.0: + prob_scale = layer.prob_scale.to("cpu").tolist() + if current_platform.is_rocm() and not is_navi(): + prob_scale *= 2 + else: + prob_scale = 1.0 + + if not isinstance(q_scale, float) or not isinstance(prob_scale, float): + raise ValueError("Only support per-tensor scaling factor" + "for fp8-quantized Q/prob") + + # These are used in the final Attention.forward() + layer._q_scale.copy_(q_scale) + layer._prob_scale.copy_(prob_scale) + if (q_scale == 1.0 + or prob_scale == 1.0) and envs.VLLM_USE_ROCM_FP8_FLASH_ATTN: + logger.warning_once( + f"Using Q scale {q_scale} and prob scale {prob_scale} " + "with fp8 attention. This may cause accuracy issues. " + "Please make sure Q/prob scaling factors are " + "available in the fp8 checkpoint.") del layer.k_scale del layer.v_scale + del layer.q_scale + del layer.prob_scale diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 0451cf82b997..debfd00392a9 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 import fnmatch -import re from typing import Any, Dict, List, Optional, cast import torch @@ -124,6 +123,13 @@ def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": for q_config in q_configs: q_config["output_tensors"] = None + # In case q_proj output is also quantized, remove the configuration + # to keep qkv consistency. + q_proj_q_config = cast(Dict[str, Any], + layer_quant_config.get("*q_proj")) + if q_proj_q_config is not None: + q_proj_q_config["output_tensors"] = None + return cls(quant_config=config, kv_cache_group=kv_cache_group, kv_cache_config=kv_cache_config, @@ -150,6 +156,19 @@ def _check_scheme_supported(self, else: return False + def is_fp8_w8a8(self) -> bool: + # Returns True if all quantized layers in model are fp8 w8a8 + global_quant_config = cast( + Dict[str, Any], self.quant_config.get("global_quant_config")) + layer_quant_configs = cast(Dict[str, Any], + self.quant_config.get("layer_quant_config")) + for config in (global_quant_config, *layer_quant_configs.values()): + weight_config = cast(Dict[str, Any], config.get("weight")) + input_config = cast(Dict[str, Any], config.get("input_tensors")) + if not self._is_fp8_w8a8(weight_config, input_config): + return False + return True + def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]], input_quant: Optional[Dict[str, Any]]) -> bool: # Confirm weights and input quantized. @@ -288,25 +307,14 @@ def get_cache_scale(self, name: str) -> Optional[str]: :param name: param name :return: matching param name for KV cache scale in vLLM """ - if self.kv_cache_group is None or len(self.kv_cache_group) == 0: - return None - - kv_proj_names = [ - re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group - ] - if name.endswith(".output_scale"): - if len(kv_proj_names) == 1 and kv_proj_names[0] in name: - kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale" - return name.replace(kv_output_scale_name, ".attn.k_scale") - - elif len(kv_proj_names) == 2: - for kv_proj_name in kv_proj_names: - if kv_proj_name in name and kv_proj_name == "k_proj": - return name.replace(".k_proj.output_scale", - ".attn.k_scale") - elif kv_proj_name in name and kv_proj_name == "v_proj": - return name.replace(".v_proj.output_scale", - ".attn.v_scale") + if name.endswith(".output_scale") and ".k_proj" in name: + return name.replace(".k_proj.output_scale", ".attn.k_scale") + if name.endswith(".output_scale") and ".v_proj" in name: + return name.replace(".v_proj.output_scale", ".attn.v_scale") + if name.endswith(".output_scale") and ".q_proj" in name: + return name.replace(".q_proj.output_scale", ".attn.q_scale") + if name.endswith("self_attn.prob_output_scale"): + return name.replace(".prob_output_scale", ".attn.prob_scale") # If no matches, return None return None diff --git a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py index c885e98a4d66..fcdef26c5ea8 100644 --- a/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +++ b/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -23,6 +23,7 @@ def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]): self.qscheme = qscheme self.is_static_input_scheme = is_static_input_scheme self.cutlass_fp8_supported = cutlass_fp8_supported() + self.out_dtype = torch.get_default_dtype() @classmethod def get_min_capability(cls) -> int: @@ -136,6 +137,7 @@ def apply_weights(self, input=x, weight=layer.weight, weight_scale=layer.weight_scale, + out_dtype=self.out_dtype, input_scale=layer.input_scale, bias=bias, cutlass_fp8_supported=self.cutlass_fp8_supported, diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py index 29c7268ad9e0..e7e526c05eff 100644 --- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -17,6 +17,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( apply_fp8_linear) from vllm.platforms import current_platform +from vllm.utils import is_navi logger = init_logger(__name__) @@ -47,6 +48,16 @@ def apply_w8a8_block_fp8_linear( shape_supported_by_cutlass = (weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) + if current_platform.is_rocm(): + scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) + + input_2d.shape[:-1])[::-1] + scale_b_shape = (weight_scale.view(-1, 1) + if weight_scale.dim() <= 1 else weight_scale.T).shape + ar, ac = scale_a_shape + br, bc = scale_b_shape + if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0]) + or br not in (1, weight.shape[0])): + shape_supported_by_cutlass = False if cutlass_block_fp8_supported and shape_supported_by_cutlass: q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1], @@ -114,8 +125,8 @@ def input_to_float8( """This function quantizes input values to float8 values " "with tensor-wise quantization.""" if dtype is None: - dtype = (torch.float8_e4m3fnuz - if current_platform.is_rocm() else torch.float8_e4m3fn) + dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm() + and not is_navi() else torch.float8_e4m3fn) finfo = torch.finfo(dtype) min_val, max_val = x.aminmax() amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) @@ -246,8 +257,8 @@ def per_token_group_quant_fp8( scaling factor for quantization. """ if dtype is None: - dtype = (torch.float8_e4m3fnuz - if current_platform.is_rocm() else torch.float8_e4m3fn) + dtype = (torch.float8_e4m3fnuz if current_platform.is_rocm() + and not is_navi() else torch.float8_e4m3fn) assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " f"by `group_size` {group_size}") diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 3fd88e8754a5..99575a3922cc 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -106,6 +106,7 @@ def apply_fp8_linear( input: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor, + out_dtype: Optional[torch.dtype] = None, input_scale: Optional[torch.Tensor] = None, input_scale_ub: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, @@ -120,6 +121,9 @@ def apply_fp8_linear( input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[1]] + if out_dtype is None: + out_dtype = input.dtype + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A if cutlass_fp8_supported: qinput, x_scale = ops.scaled_fp8_quant( @@ -143,11 +147,14 @@ def apply_fp8_linear( # Note: we pad the input because torch._scaled_mm is more performant # for matrices with batch dimension > 16. # This could change in the future. - qinput, x_scale = ops.scaled_fp8_quant( - input_2d, - input_scale, - num_token_padding=17, - use_per_token_if_dynamic=use_per_token_if_dynamic) + if input.dtype != torch.float8_e4m3fnuz: + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + num_token_padding=17, + use_per_token_if_dynamic=use_per_token_if_dynamic) + else: + qinput, x_scale = input_2d, input_scale per_tensor_weights = (weight_scale.numel() == 1) per_tensor_activations = (x_scale.numel() == 1) @@ -156,7 +163,7 @@ def apply_fp8_linear( # Fused GEMM_DQ output = torch._scaled_mm(qinput, weight, - out_dtype=input.dtype, + out_dtype=out_dtype, scale_a=x_scale, scale_b=weight_scale, bias=bias) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 6af734be5e98..4155b5a1e8ba 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -21,6 +21,7 @@ CompletionSequenceGroupOutput, Logprob, PromptLogprobs, SampleLogprobs, SequenceOutput) from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics +from vllm.utils import rpd_mark if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -67,6 +68,7 @@ class SampleResultArgsType: multinomial_samples: MultinomialSamplesType sample_results_dict: SampleResultsDictType sampling_metadata: SamplingMetadata + forced_samples: Optional[torch.Tensor] greedy_samples: Optional[torch.Tensor] beam_search_logprobs: Optional[torch.Tensor] @@ -214,6 +216,7 @@ def _init_sampling_tensors( self._do_top_p_top_k = do_top_p_top_k self._do_min_p = do_min_p + @rpd_mark(name="Sampler Forward") def forward( self, logits: torch.Tensor, @@ -466,6 +469,39 @@ def _greedy_sample( return results +def _forced_sample( + selected_seq_groups: List[SequenceGroupToSample], + samples: torch.Tensor, +) -> List[Tuple[List[int], List[int]]]: + """Run forced sampling on a given samples. + Args: + selected_seq_groups: A list of sequence groups batched. + samples: (num_selected_samples,) A tensor of samples. The length of + samples could be smaller than selected_seq_groups if + seq_group.do_sample is False. + Returns: + Tuple of (next_token_ids, parent_ids). The length of returned list is + same as the length of selected_seq_groups. If the corresponding + seq_group has do_sample=False, tuple contains ([], []) + + The next_token_ids is guided (forced) by the id containing in the + sampling_parameters.future_context property. + """ + samples = samples.tolist() + sample_idx = 0 + results = [] + for seq_group in selected_seq_groups: + seq_ids = seq_group.seq_ids + num_parent_seqs = len(seq_ids) + assert num_parent_seqs == 1, ( + "Deterministic sampling should have only one seq.") + parent_ids = list(range(num_parent_seqs)) + next_token_ids = [samples[sample_idx]] + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + return results + + def _random_sample( selected_seq_groups: List[SequenceGroupToSample], random_samples: torch.Tensor, @@ -664,6 +700,7 @@ def get_pythonized_sample_results( ( sample_metadata, sampling_metadata, + forced_samples, greedy_samples, multinomial_samples, beam_search_logprobs, @@ -671,6 +708,7 @@ def get_pythonized_sample_results( ) = ( sample_result_args.sample_metadata, sample_result_args.sampling_metadata, + sample_result_args.forced_samples, sample_result_args.greedy_samples, sample_result_args.multinomial_samples, sample_result_args.beam_search_logprobs, @@ -686,6 +724,8 @@ def get_pythonized_sample_results( elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): sample_results = _random_sample(seq_groups, multinomial_samples[sampling_type]) + elif sampling_type == SamplingType.FORCED: + sample_results = _forced_sample(seq_groups, forced_samples) elif sampling_type == SamplingType.BEAM: sample_results = _beam_search_sample(seq_groups, beam_search_logprobs) @@ -730,6 +770,7 @@ def _sample_with_torch( sample_results_dict: SampleResultsDictType = {} sample_metadata: SampleMetadataType = {} multinomial_samples: MultinomialSamplesType = {} + forced_samples: Optional[torch.Tensor] = None greedy_samples: Optional[torch.Tensor] = None beam_search_logprobs: Optional[torch.Tensor] = None @@ -754,10 +795,10 @@ def _sample_with_torch( seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id] sample_metadata[sampling_type] = (seq_group_id, seq_groups) long_sample_indices = sample_indices.long() + if sampling_type == SamplingType.GREEDY: greedy_samples = torch.argmax(logprobs[long_sample_indices], dim=-1) - if sampled_token_ids_tensor is not None: # Store sampled tokens in output tensor. sampled_token_ids_tensor[ @@ -799,6 +840,23 @@ def _sample_with_torch( # Store sampled tokens in output tensor. sampled_token_ids_tensor[long_sample_indices] = \ multinomial_samples[sampling_type].to(torch.long) + elif sampling_type == SamplingType.FORCED: + forced_samples = torch.tensor([], dtype=torch.int32) + for sgidx in range(len(seq_groups)): + if (seq_groups[sgidx].sampling_params.future_context + is not None): + forced_sample = torch.tensor([ + seq_groups[sgidx].sampling_params.future_context[sgidx] + [min( + len(sampling_metadata.seq_groups[sgidx].seq_data[ + sampling_params.cntr[sgidx]].output_token_ids), + len(seq_groups[sgidx].sampling_params. + future_context[sgidx]) - 1)] + ]) + else: + forced_sample = torch.argmax(logprobs[long_sample_indices], + dim=-1) + forced_samples = torch.cat([forced_samples, forced_sample]) elif sampling_type == SamplingType.BEAM: beam_search_logprobs = logprobs[sample_indices] @@ -810,6 +868,7 @@ def _sample_with_torch( maybe_deferred_args = SampleResultArgsType( sampling_metadata=sampling_metadata, sample_metadata=sample_metadata, + forced_samples=forced_samples, multinomial_samples=multinomial_samples, greedy_samples=greedy_samples, beam_search_logprobs=beam_search_logprobs, @@ -830,6 +889,7 @@ def _sample_with_torch( ) +@rpd_mark() def _sample( probs: torch.Tensor, logprobs: torch.Tensor, @@ -1259,7 +1319,8 @@ def _build_sampler_output( deferred_sample_results_args=deferred_sample_results_args) -def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: +def _get_next_prompt_tokens( + seq_group: SequenceGroupToSample) -> Tuple[int, ...]: """Get a list of next prompt tokens to compute logprob from a given sequence group. diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py new file mode 100644 index 000000000000..8fb44cdc96c2 --- /dev/null +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from pathlib import Path + +import pandas as pd +import torch +import torch.nn.functional as F + +from vllm import _custom_ops as ops +from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM +from vllm.platforms import current_platform +from vllm.utils import is_navi + +support_tuned_gemms = False +if current_platform.is_rocm(): + import vllm._gradlib_C # noqa: F401 + support_tuned_gemms = True + + +def hipb_mm(inp, weights, solidx, bias=None): + return torch.ops._gradlib_C.hipb_mm(inp, weights, solidx, bias, None, None, + None, None) + + +def rocb_mm(inp, weights, solidx): + return torch.ops._gradlib_C.rocb_mm(inp, weights, solidx) + + +class TunedGemm: + + def __init__(self): + self.extensions_created = False + self.save_gemm = int(os.environ.get('VLLM_TUNE_GEMM', 0)) + self.untune_path = os.environ.get('VLLM_UNTUNE_FILE', + "/tmp/vllm_untuned.csv") + self.tune_path = os.environ.get('VLLM_TUNE_FILE', "tuned.csv") + self.bestsols = {} + self.load_best_sols() + self.create_ds() + self.cu_count = torch.cuda.get_device_properties( + device='cuda').multi_processor_count + + self.use_skinny = (current_platform.is_rocm() + and VLLM_USE_ROCM_SKINNY_GEMM and not is_navi()) + + if (self.save_gemm == 1): + self.tuned_df = pd.DataFrame( + columns=['M', 'N', 'K', 'bias', 'dtype']) + else: + self.tuned_df = None + + def load_best_sols(self): + if self.tune_path is not None and Path(self.tune_path).is_file(): + self.bestsols = pd.read_csv(self.tune_path) + + def create_ds(self): + df: pd.DataFrame = self.bestsols + solds = {} + for i in range(len(df)): + ds = df.iloc[i] + key = (ds['M'], ds['N'], ds['K'], ds['bias'], ds['dtype']) + if ds['libtype'] == 'hipblaslt': + soltype = 1 + elif ds['libtype'] == 'rocblas': + soltype = 2 + solds[key] = (soltype, int(ds['solidx'])) + self.solids = solds + + def query_sol(self, m, n, k, bias, dtype): + return self.solids.get((m, n, k, bias, str(dtype)), (0, 0)) + + def apply_skinny(self, m, n, k, inp_view, weights): + if not self.use_skinny: + return None + if inp_view.dtype != torch.float16 or k % 8 != 0: + return None + if m > 8 and 0 < n <= 4: + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') + ops.wvSpltK(weights, inp_view, out, n, self.cu_count) + return out + elif m % 4 == 0 and n == 1 and k <= 8192: + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') + ops.LLMM1(weights, inp_view, out, 4) + return out + else: + return None + + def mm(self, inp, weights, bias=None): + if not support_tuned_gemms: + return F.linear(inp, weights, bias) + # F.Linear can take a 3 dimensional input. vllm + # uses this for linear units. However, sampler + # will use torch.matmul with 2 dimensions only + if inp.dim() == 3: + try: + inp_view = inp.view(-1, inp.size(-1)) + batched = True + except RuntimeError: + return F.linear(inp, weights, bias) + else: + inp_view = inp + batched = False + if self.extensions_created is False: + torch.ops._gradlib_C.rocb_create_extension() + torch.ops._gradlib_C.hipb_create_extension() + self.extensions_created = True + m = weights.shape[0] + n = inp_view.shape[0] + k = inp_view.shape[1] + use_bias = bias is not None + soltype, solidx = self.query_sol(m=m, + n=n, + k=k, + bias=use_bias, + dtype=inp.dtype) + out = self.apply_skinny(m, n, k, inp_view, weights) + if out is not None: + if batched: + out = out.view(inp.shape[0], inp.shape[1], weights.shape[0]) + if bias is not None: + return out + bias + return out + elif soltype == 1: + out = hipb_mm(inp_view, weights.t(), solidx, bias) + elif soltype == 2: + out = rocb_mm(inp_view, weights.t(), solidx) + if bias is not None: + out = out + bias + else: + if (self.save_gemm == 1): + self.tuned_df = pd.concat([ + self.tuned_df, + pd.DataFrame({ + 'M': [m], + 'N': [n], + 'K': [k], + 'bias': [bias is not None], + 'dtype': [inp.dtype], + }) + ]).drop_duplicates() + self.tuned_df.to_csv(self.untune_path, index=False) + return F.linear(inp, weights, bias) + if batched: + out = out.view(inp.shape[0], inp.shape[1], weights.shape[0]) + return out + + +tgemm = TunedGemm() diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index e409094dd535..bedceb12deae 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -12,6 +12,7 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) +from vllm.model_executor.layers.tuned_gemm import tgemm from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -40,7 +41,7 @@ def apply(self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return F.linear(x, layer.weight, bias) + return tgemm.mm(x, layer.weight, bias) def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor: @@ -135,7 +136,7 @@ def __post_init__(self): assert self.num_added_elements <= self.num_added_elements_padded -@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) +#@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def get_masked_input_and_mask( input_: torch.Tensor, org_vocab_start_index: int, org_vocab_end_index: int, num_org_vocab_padding: int, diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index e73627da05d4..d4ca3ac4db8e 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -47,7 +47,6 @@ row_parallel_weight_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -56,7 +55,6 @@ maybe_prefix) -@torch.compile(backend=current_platform.simple_compile_backend) def layer_norm_func(hidden_states, weight, variance_epsilon): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index f5fede4d8226..e894ca5416c8 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -150,11 +150,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: shared_output = self.shared_experts(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - final_hidden_states = self.experts( - hidden_states=hidden_states, - router_logits=router_logits) * self.routed_scaling_factor + final_hidden_states = self.experts(hidden_states=hidden_states, + router_logits=router_logits) if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output + final_hidden_states = final_hidden_states + shared_output \ + * (1. / self.routed_scaling_factor) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states) @@ -532,6 +532,7 @@ def __init__( eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor def forward( self, @@ -556,9 +557,14 @@ def forward( ) # Fully Connected + if isinstance(self.mlp, DeepseekV2MoE): + hidden_states *= 1. / self.mlp.routed_scaling_factor hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) hidden_states = self.mlp(hidden_states) + if isinstance(self.mlp, DeepseekV2MLP): + hidden_states *= 1. / self.routed_scaling_factor + residual *= 1. / self.routed_scaling_factor return hidden_states, residual diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py index 4449eb8e8b14..258df7040744 100644 --- a/vllm/model_executor/models/glm4_vision_encoder.py +++ b/vllm/model_executor/models/glm4_vision_encoder.py @@ -20,6 +20,7 @@ RowParallelLinear) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.utils import is_navi3 class PatchEmbedding(nn.Module): @@ -86,6 +87,30 @@ def __init__( self.output_dropout = torch.nn.Dropout(config.dropout_prob) def forward(self, x: torch.Tensor) -> torch.Tensor: + if is_navi3(): + try: + # git clone -b howiejay/navi_support https://github.com/ROCm/flash-attention.git + from flash_attn import flash_attn_func + B, L, _ = x.shape + qkv, _ = self.query_key_value(x) # B, L, 3 * H * D + q, k, v = qkv.chunk(3, dim=-1) + + q = q.reshape(B, L, self.num_heads_per_rank, + self.head_dim) # B, L, H, D + k = k.reshape(B, L, self.num_heads_per_rank, + self.head_dim) # B, L, H, D + v = v.reshape(B, L, self.num_heads_per_rank, + self.head_dim) # B, L, H, D + + out = flash_attn_func(q, k, v) + + output, _ = self.dense(out.view(B, L, -1)) + output = self.output_dropout(output) + + return output + except ModuleNotFoundError: + pass + qkv, _ = self.query_key_value(x) # B, L, 3 * H * D q, k, v = qkv.chunk(3, dim=-1) diff --git a/vllm/model_executor/models/grok1.py b/vllm/model_executor/models/grok1.py new file mode 100644 index 000000000000..e85655ddb595 --- /dev/null +++ b/vllm/model_executor/models/grok1.py @@ -0,0 +1,510 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Grok1 model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs import Grok1Config + +from .interfaces import SupportsLoRA +from .utils import is_pp_missing_parameter, make_layers + +attn_output_multiplier = 0.08838834764831845 +output_multiplier_scale = 0.5773502691896257 +max_attn_val = 30.0 + + +class Grok1MoE(nn.Module): + """A tensor-parallel MoE implementation for Grok1 that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class Grok1Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Grok1DecoderLayer(nn.Module): + + def __init__( + self, + config: Grok1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.use_fp8 = isinstance( + quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig) + and quant_config.is_fp8_w8a8()) + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.attn = Grok1Attention(hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + self.moe_block = Grok1MoE(num_experts=config.num_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.moe_block") + + self.pre_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attn_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.pre_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_moe_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + scale = None if not self.use_fp8 else \ + self.attn.qkv_proj.input_scale + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.pre_attn_norm(hidden_states, None, scale) + else: + hidden_states, residual = self.pre_attn_norm( + hidden_states, residual, scale) + + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states = self.post_attn_norm(hidden_states) + + ### fused_moe performance bad + hidden_states, residual = self.pre_moe_norm(hidden_states, residual) + + hidden_states = self.moe_block(hidden_states) + + hidden_states = self.post_moe_norm(hidden_states) + return hidden_states, residual + + +class Grok1Model(nn.Module): + + def __init__( + self, + config: Grok1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embedding_multiplier_scale = config.embedding_multiplier_scale + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Grok1DecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + hidden_states = hidden_states * self.embedding_multiplier_scale + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Grok1ForCausalLM(nn.Module, SupportsLoRA): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: Grok1Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = Grok1Model(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if self.config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + output_multiplier_scale) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="linear", + ckpt_down_proj_name="linear_1", + ckpt_up_proj_name="linear_v", + num_experts=self.config.num_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + + if "norm.scale" in name: + name = name.replace("scale", "weight") + + if "lm_head" in name and self.config.tie_word_embeddings: + continue + + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d91c8782a121..c7e4b5cd8ec1 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -28,6 +28,8 @@ from torch import nn from transformers import LlamaConfig +import vllm.envs as envs +from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig @@ -39,6 +41,8 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.fp8 import Fp8Config +from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -46,7 +50,9 @@ from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors +from vllm.utils import is_navi from .interfaces import SupportsLoRA, SupportsPP from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index, @@ -81,14 +87,29 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.down_proj", ) + self.use_fp8 = (isinstance(quant_config, Fp8Config) or + (isinstance(quant_config, QuarkConfig) + and quant_config.is_fp8_w8a8()) + if current_platform.is_rocm() and not is_navi() else + False) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") self.act_fn = SiluAndMul() def forward(self, x): - x, _ = self.gate_up_proj(x) - x = self.act_fn(x) + if current_platform.is_rocm() and x.shape[0] == 1 and x.shape[1] == 1: + out = torch.empty(x.shape[0], + self.gate_up_proj.weight.shape[0] // 2, + dtype=x.dtype, + device=x.device) + ops.LLMM_Silu(self.gate_up_proj.weight, x.view(-1, x.size(-1)), + out, 8) + x = out.view(x.shape[0], x.shape[1], out.shape[1]) + else: + x, _ = self.gate_up_proj(x) + x = self.act_fn( + x, self.down_proj.input_scale if self.use_fp8 else None) x, _ = self.down_proj(x) return x @@ -179,6 +200,15 @@ def __init__(self, else: sliding_window = None + # For CUDA devices and Navi4x, attn_fp8 will be set to false. + use_fp8 = isinstance( + quant_config, Fp8Config) or (isinstance(quant_config, QuarkConfig) + and quant_config.is_fp8_w8a8()) + self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ + and current_platform.is_rocm() \ + and not is_navi() \ + and use_fp8 + self.attn = Attention( self.num_heads, self.head_dim, @@ -200,7 +230,9 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn( + q, k, v, kv_cache, attn_metadata, + self.o_proj.input_scale if self.attn_fp8_out else None) output, _ = self.o_proj(attn_output) return output @@ -216,6 +248,11 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size + self.use_fp8 = (isinstance(quant_config, Fp8Config) or + (isinstance(quant_config, QuarkConfig) + and quant_config.is_fp8_w8a8()) + if current_platform.is_rocm() and not is_navi() else + False) rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) if rope_scaling is not None and getattr( @@ -270,20 +307,23 @@ def forward( residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention + scale = None if not self.use_fp8 else \ + self.self_attn.qkv_proj.input_scale if residual is None: residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.input_layernorm(hidden_states, None, scale) else: hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual, scale) hidden_states = self.self_attn(positions=positions, hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata) # Fully Connected + scale = None if not self.use_fp8 else self.mlp.gate_up_proj.input_scale hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual, scale) hidden_states = self.mlp(hidden_states) return hidden_states, residual diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 962f95f10fc5..a4cfd2eedd46 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -75,6 +75,7 @@ "MistralForCausalLM": ("llama", "LlamaForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + "Grok1ModelForCausalLM": ("grok1", "Grok1ForCausalLM"), # transformers's mpt class has lower case "MptForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"), diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index cd851c0d87a7..576f61077697 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: Apache-2.0 import os -from functools import lru_cache +from functools import lru_cache, wraps from typing import TYPE_CHECKING, Dict, List, Optional import torch +from amdsmi import (AmdSmiException, amdsmi_get_gpu_board_info, + amdsmi_get_processor_handles, amdsmi_init, + amdsmi_shut_down, amdsmi_topo_get_link_type) import vllm.envs as envs from vllm.logger import init_logger @@ -60,6 +63,41 @@ "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") } +# Prevent use of clashing `{CUDA/HIP}_VISIBLE_DEVICES`` +if "HIP_VISIBLE_DEVICES" in os.environ: + val = os.environ["HIP_VISIBLE_DEVICES"] + if cuda_val := os.environ.get("CUDA_VISIBLE_DEVICES", None): + assert val == cuda_val + else: + os.environ["CUDA_VISIBLE_DEVICES"] = val + +# AMDSMI utils +# Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`, +# all the related functions work on real physical device ids. +# the major benefit of using AMDSMI is that it will not initialize CUDA + + +def with_amdsmi_context(fn): + + @wraps(fn) + def wrapper(*args, **kwargs): + amdsmi_init() + try: + return fn(*args, **kwargs) + finally: + amdsmi_shut_down() + + return wrapper + + +def device_id_to_physical_device_id(device_id: int) -> int: + if "CUDA_VISIBLE_DEVICES" in os.environ: + device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + physical_device_id = device_ids[device_id] + return int(physical_device_id) + else: + return device_id + class RocmPlatform(Platform): _enum = PlatformEnum.ROCM @@ -79,6 +117,9 @@ class RocmPlatform(Platform): def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla) -> str: + if use_mla: + logger.info("Using Triton MLA backend.") + return "vllm.attention.backends.triton_mla.TritonMLABackend" selected_backend = (_Backend.ROCM_FLASH if selected_backend == _Backend.FLASH_ATTN else selected_backend) if selected_backend == _Backend.ROCM_FLASH: @@ -96,10 +137,39 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) + @staticmethod + @with_amdsmi_context + def is_full_nvlink(physical_device_ids: List[int]) -> bool: + """ + Query if the set of gpus are fully connected by xgmi (1 hop) + """ + handles = [ + amdsmi_get_processor_handles()[i] for i in physical_device_ids + ] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + link_type = amdsmi_topo_get_link_type( + handle, peer_handle) + # type is 2 for XGMI + if link_type["hops"] != 1 or link_type["type"] != 2: + return False + except AmdSmiException as error: + logger.error("AMD 1 hop XGMI detection failed.", + exc_info=error) + return False + return True + @classmethod + @with_amdsmi_context @lru_cache(maxsize=8) def get_device_name(cls, device_id: int = 0) -> str: - return torch.cuda.get_device_name(device_id) + physical_device_id = device_id_to_physical_device_id(device_id) + handle = amdsmi_get_processor_handles()[physical_device_id] + # Note: this may not be exactly the same as the torch device name + # E.g. `AMD Instinct MI300X OAM` vs `AMD Instinct MI300X` + return amdsmi_get_gpu_board_info(handle)["product_name"] @classmethod def get_device_total_memory(cls, device_id: int = 0) -> int: @@ -166,4 +236,4 @@ def get_current_memory_usage(cls, device: Optional[torch.types.Device] = None ) -> float: torch.cuda.reset_peak_memory_stats(device) - return torch.cuda.max_memory_allocated(device) + return torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0] \ No newline at end of file diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 97f9e2129573..506e34072203 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -23,6 +23,7 @@ class SamplingType(IntEnum): GREEDY = 0 RANDOM = 1 RANDOM_SEED = 2 + FORCED = 3 # maybe make msgspec? @@ -124,6 +125,8 @@ class SamplingParams( min_p: Float that represents the minimum probability for a token to be considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this. + ppl_measurement: Measure perplexity towards the deterministic string + instead of probabilistic regressing. seed: Random seed to use for the generation. stop: List of strings that stop the generation when they are generated. The returned output will not contain the stop strings. @@ -178,6 +181,9 @@ class SamplingParams( top_p: float = 1.0 top_k: int = -1 min_p: float = 0.0 + ppl_measurement: bool = False + future_context: Optional[List[int]] = None + cntr: Optional[List[int]] = None seed: Optional[int] = None stop: Optional[Union[str, List[str]]] = None stop_token_ids: Optional[List[int]] = None @@ -221,6 +227,9 @@ def from_optional( top_p: Optional[float] = 1.0, top_k: int = -1, min_p: float = 0.0, + ppl_measurement: bool = False, + future_context: Optional[List[int]] = None, + cntr: Optional[int] = None, seed: Optional[int] = None, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, @@ -261,6 +270,9 @@ def from_optional( top_p=1.0 if top_p is None else top_p, top_k=top_k, min_p=min_p, + ppl_measurement=ppl_measurement, + future_context=future_context, + cntr=cntr, seed=seed, stop=stop, stop_token_ids=stop_token_ids, @@ -440,6 +452,8 @@ def update_from_generation_config( @cached_property def sampling_type(self) -> SamplingType: + if self.ppl_measurement: + return SamplingType.FORCED if self.temperature < _SAMPLING_EPS: return SamplingType.GREEDY if self.seed is not None: @@ -475,6 +489,7 @@ def __repr__(self) -> str: f"top_p={self.top_p}, " f"top_k={self.top_k}, " f"min_p={self.min_p}, " + f"ppl_measurement={self.ppl_measurement}, " f"seed={self.seed}, " f"stop={self.stop}, " f"stop_token_ids={self.stop_token_ids}, " diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 1c0f20a6e045..57c4da5e5516 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -28,7 +28,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, DbrxConfig, DeepseekVLV2Config, EAGLEConfig, ExaoneConfig, - H2OVLChatConfig, + Grok1Config, H2OVLChatConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MllamaConfig, MLPSpeculatorConfig, MPTConfig, @@ -75,6 +75,7 @@ "solar": SolarConfig, "telechat": Telechat2Config, "ultravox": UltravoxConfig, + "grok-1": Grok1Config, **_CONFIG_REGISTRY_OVERRIDE_HF } @@ -243,6 +244,10 @@ def get_config( raise RuntimeError(err_msg) from e else: raise e + if config.model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[config.model_type] + config = config_class.from_pretrained( + model, revision=revision, code_revision=code_revision) elif config_format == ConfigFormat.MISTRAL: config = load_params_config(model, revision, token=HF_TOKEN, **kwargs) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index c484a755ab4e..4521e8a40a2d 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -10,6 +10,7 @@ # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.grok1 import Grok1Config from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig @@ -45,4 +46,5 @@ "SolarConfig", "Telechat2Config", "UltravoxConfig", + "Grok1Config", ] \ No newline at end of file diff --git a/vllm/transformers_utils/configs/grok1.py b/vllm/transformers_utils/configs/grok1.py new file mode 100644 index 000000000000..5057698c33a5 --- /dev/null +++ b/vllm/transformers_utils/configs/grok1.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: Apache-2.0 +from transformers.configuration_utils import PretrainedConfig + + +class Grok1Config(PretrainedConfig): + model_type = "grok-1" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__(self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=32768, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=32, + attn_output_multiplier=1.0, + max_attn_value=1.0, + max_position_embeddings=4096, + embedding_multiplier_scale: float = 1.0, + output_multiplier_scale: float = 1.0, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=True, + num_experts_per_tok=2, + num_experts=8, + output_router_logits=False, + router_aux_loss_coef=0.001, + **kwargs): + self.vocab_size = vocab_size + self.attn_output_multiplier = attn_output_multiplier + self.max_attn_value = max_attn_value + self.max_position_embeddings = max_position_embeddings + self.embedding_multiplier_scale = embedding_multiplier_scale + self.output_multiplier_scale = output_multiplier_scale + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + + self.num_experts_per_tok = num_experts_per_tok + self.num_experts = num_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) diff --git a/vllm/utils.py b/vllm/utils.py index a2b53fcf252d..a3d0010557ce 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -177,6 +177,145 @@ class _Sentinel: ALL_PINNED_SENTINEL = _Sentinel() +class rpd_trace: + + def __init__(self, + filename=None, + name=None, + nvtx=False, + args=None, + skip=False): + self.skip = skip + if not self.skip: + self.name = name + self.args = args if args else "" + self.rpd = self.initialize_rpd_tracer(filename, nvtx) + + def _recreate_cm(self): + return self + + def __call__(self, func): + if not self.skip: + if self.name: + self.name += f"{func.__name__}" + else: + self.name = f"{func.__qualname__}" + + @wraps(func) + def inner(*args, **kwds): + with self._recreate_cm(): + return func(*args, **kwds) + + return inner + return func + + def __enter__(self): + if not self.skip: + self.rpd.__enter__() + self.rpd.rangePush("python", f"{self.name}", f"{self.args}") + return self + + def __exit__(self, *exc): + if not self.skip: + self.rpd.rangePop() + self.rpd.__exit__(None, None, None) + return False + + @staticmethod + def setup_environment_variables(filename): + os.environ['RPDT_AUTOSTART'] = '0' + os.environ['RPDT_FILENAME'] = filename + + def initialize_rpd_tracer(self, filename, nvtx): + try: + from rpdTracerControl import rpdTracerControl + rpd_trace.setup_environment_variables(filename) + rpdTracerControl.setFilename(name=filename, append=True) + return rpdTracerControl(nvtx=nvtx) + except Exception as e: + print(f"Error initializing rpdTracerControl: {e}") + raise + + @staticmethod + def create_file(filename): + import sqlite3 + + from rocpd.schema import RocpdSchema + try: + print("Creating empty rpd schema file ...") + filename = str(filename) + with sqlite3.connect(filename) as connection: + schema = RocpdSchema() + schema.writeSchema(connection) + connection.commit() + except sqlite3.OperationalError as e: + print(f"SQLite operational error: {e}") + except Exception as e: + print(f"An error occurred while creating the filename: {e}") + + +@cache +def is_hipScopedMarker_available(): + try: + from hipScopedMarker import hipScopedMarker + except ImportError: + hipScopedMarker = None + return hipScopedMarker is not None + + +class rpd_mark: + + def __init__(self, name=None): + self.name = name + + def __call__(self, func): + + if is_hipScopedMarker_available(): + from hipScopedMarker import hipScopedMarker + + @wraps(func) + def inner(*args, **kwds): + marker_name = self.name if self.name else f"{func.__name__}" + with hipScopedMarker(f"{marker_name}"): + return func(*args, **kwds) + + return inner + + else: + return func + + +class rpd_user_marker: + + def __init__(self, name=None): + self.name = name + self.marker = None + + def __enter__(self): + if is_hipScopedMarker_available(): + from hipScopedMarker import hipScopedMarker + marker_name = self.name if self.name else "UserMarker Undefined" + self.marker = hipScopedMarker(f"{marker_name}") + self.marker.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if is_hipScopedMarker_available() and self.marker: + self.marker.__exit__(exc_type, exc_val, exc_tb) + + def start(self): + if is_hipScopedMarker_available(): + from hipScopedMarker import hipScopedMarker + marker_name = self.name if self.name else "UserMarker Undefined" + self.marker = hipScopedMarker(f"{marker_name}") + self.marker.__enter__() + return self + + def end(self, exc_type=0, exc_val=0, exc_tb=0): + if is_hipScopedMarker_available() and self.marker: + self.marker.__exit__(exc_type, exc_val, exc_tb) + + class Device(enum.Enum): GPU = enum.auto() CPU = enum.auto() @@ -354,6 +493,16 @@ def reset(self): self._index = 0 +@cache +def is_mi250() -> bool: + from vllm.platforms import current_platform + if not current_platform.is_rocm() or not torch.cuda.is_available(): + return False + archName = torch.cuda.get_device_properties('cuda').gcnArchName + return (archName is not None) and \ + ("gfx90a" in archName) + + @cache def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" @@ -1566,6 +1715,28 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: return torch.ops._C.weak_ref_tensor(tensor) +@cache +def is_navi() -> bool: + from vllm.platforms import current_platform + if not current_platform.is_rocm() or not torch.cuda.is_available(): + return False + # All (visible) GPUs must be of the same type, + # otherwise FP8 results can't be guaranteed. + archName = torch.cuda.get_device_properties('cuda').gcnArchName + return archName is not None and "gfx1" in archName + + +@cache +def is_navi3() -> bool: + from vllm.platforms import current_platform + if not current_platform.is_rocm() or not torch.cuda.is_available(): + return False + # All (visible) GPUs must be of the same type, + # otherwise FP8 results can't be guaranteed. + archName = torch.cuda.get_device_properties('cuda').gcnArchName + return archName is not None and "gfx11" in archName + + def weak_ref_tensors( tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]] ) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]: diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index e2d338f75761..a0098f095702 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -105,7 +105,7 @@ def __init__( def _maybe_force_supported_attention_backend(self): ''' - Force vLLM to use the XFormers attention backend, + Force vLLM to use the XFormers or ROCM attention backend, which is currently the only supported option. ''' @@ -122,13 +122,13 @@ def raise_backend_err(): # Backend override enforced by global variable takes # precedence over vLLM backend environment variable. if maybe_global_forced_backend not in\ - [_Backend.XFORMERS, _Backend.FLASH_ATTN]: + [_Backend.XFORMERS, _Backend.FLASH_ATTN, _Backend.ROCM_FLASH]: raise_backend_err() elif is_forced_by_env_var: # noqa: SIM102 # Backend override enforced by vLLM backend # environment variable if maybe_env_var_forced_backend not in\ - [_Backend.XFORMERS, _Backend.FLASH_ATTN]: + [_Backend.XFORMERS, _Backend.FLASH_ATTN, _Backend.ROCM_FLASH]: raise_backend_err() def _list_to_int32_tensor( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 90f08b1dfde8..6edb50dbb59b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -50,8 +50,8 @@ from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache, async_tensor_h2d, flatten_2d_lists, - is_pin_memory_available, supports_dynamo, - weak_ref_tensor) + is_pin_memory_available, rpd_mark, rpd_user_marker, + supports_dynamo, weak_ref_tensor) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -1646,6 +1646,7 @@ def prepare_model_input( is_prompt=is_prompt, virtual_engine=virtual_engine) + @rpd_mark() @torch.inference_mode() @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"]) def execute_model( @@ -1677,6 +1678,12 @@ def execute_model( assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata decode_meta = model_input.attn_metadata.decode_metadata + if prefill_meta: + marker_instance = rpd_user_marker(name="Prefill") + else: + marker_instance = rpd_user_marker(name="Decode") + + marker_instance.start() # TODO(andoorve): We can remove this once all # virtual engines share the same kv cache. virtual_engine = model_input.virtual_engine @@ -1812,6 +1819,7 @@ def execute_model( output.hidden_states = hidden_states + marker_instance.end() return [output] def need_recv_kv(self, model_input, kv_caches) -> bool: diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 582aa460eb4f..54be8ea54179 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -2,6 +2,7 @@ """A GPU worker class.""" import gc import os +from pathlib import Path from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -109,18 +110,45 @@ def __init__( with_stack=True, on_trace_ready=torch.profiler.tensorboard_trace_handler( torch_profiler_trace_dir, use_gzip=True)) + elif envs.VLLM_RPD_PROFILER_DIR: + rpd_profiler_trace_dir = Path(envs.VLLM_RPD_PROFILER_DIR) + + if rpd_profiler_trace_dir.suffix != ".rpd": + rpd_profiler_trace_dir = rpd_profiler_trace_dir / "trace.rpd" + + rpd_profiler_trace_dir.parent.mkdir(parents=True, exist_ok=True) + + logger.info("Profiling enabled. Traces will be saved to: %s", + rpd_profiler_trace_dir) + + from vllm.utils import rpd_trace + + if self.rank == 0: + rpd_trace.create_file(filename=str(rpd_profiler_trace_dir)) + + self.profiler = rpd_trace(filename=str(rpd_profiler_trace_dir), + name='Worker RPD Enabled', + nvtx=True) else: self.profiler = None def start_profile(self): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") - self.profiler.start() + + if envs.VLLM_RPD_PROFILER_DIR: + self.profiler.__enter__() + else: + self.profiler.start() def stop_profile(self): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") - self.profiler.stop() + + if envs.VLLM_RPD_PROFILER_DIR: + self.profiler.__exit__() + else: + self.profiler.stop() def sleep(self, level: int = 1) -> None: free_bytes_before_sleep = torch.cuda.mem_get_info()[0]