diff --git a/examples/BMTrain/.dockerignore b/examples/BMTrain/.dockerignore new file mode 100644 index 00000000..c1591543 --- /dev/null +++ b/examples/BMTrain/.dockerignore @@ -0,0 +1,147 @@ +**/__pycache__/ +**/*.py[cod] +**/*$py.class + +# C extensions +**/*.so + +# Distribution / packaging +**/.Python +**/build/ +**/develop-eggs/ +**/dist/ +**/downloads/ +**/eggs/ +**/.eggs/ +**/lib/ +**/lib64/ +**/parts/ +**/sdist/ +**/var/ +**/wheels/ +**/share/python-wheels/ +**/*.egg-info/ +**/.installed.cfg +**/*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +**/*.manifest +**/*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +**/htmlcov/ +**/.tox/ +**/.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +**/*.cover +**/*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +**/.pyre/ + +# pytype static type analyzer +**/.pytype/ + +# Cython debug symbols +cython_debug/ + +**/*.pt + +**/*.npy + +**/.DS_Store + +**/log +**/*.qdrep +!bmtrain/dist \ No newline at end of file diff --git a/examples/BMTrain/.github/ISSUE_TEMPLATE/bug_report.yml b/examples/BMTrain/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 00000000..47c89fe6 --- /dev/null +++ b/examples/BMTrain/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,79 @@ + +name: 🐞 Bug Report +description: Report a bug/issue related to the PyTorch-based parallel model training toolkit +title: "[BUG] " +labels: ["bug"] +body: +- type: checkboxes + attributes: + label: Is there an existing issue for this? + description: Please search to see if an issue already exists for the bug you encountered. + options: + - label: I have searched the existing issues + required: true +- type: textarea + attributes: + label: Description of the Bug + description: Provide a clear and concise description of what the bug is. + validations: + required: true +- type: textarea + attributes: + label: Environment Information + description: | + Provide details about your environment. + Example: + - GCC version: 9.3.0 + - Torch version: 1.9.0 + - Linux system version: Ubuntu 20.04 + - CUDA version: 11.4 + - Torch's CUDA version (as per `torch.cuda.version()`): 11.3 + value: | + - GCC version: + - Torch version: + - Linux system version: + - CUDA version: + - Torch's CUDA version (as per `torch.cuda.version()`): + render: markdown + validations: + required: true +- type: textarea + attributes: + label: To Reproduce + description: Provide the steps and details to reproduce the behavior. + placeholder: | + 1. Describe your environment setup, including any specific version requirements. + 2. Clearly state the steps you took to trigger the error, including the specific code you executed. + 3. Identify the file and line number where the error occurred, along with the full traceback of the error. Make sure to have `NCCL_DEBUG=INFO` and `CUDA_LAUNCH_BLOCKING=True` set to get accurate debug information. + validations: + required: true +- type: textarea + attributes: + label: Expected Behavior + description: Describe what you expected to happen when you executed the code. + validations: + required: true +- type: textarea + attributes: + label: Screenshots + description: If applicable, please add screenshots to help explain your problem. + validations: + required: false +- type: textarea + attributes: + label: Additional Information + description: | + Provide any other relevant context or information about the problem here. + Links? References? Anything that will give us more context about the issue you are encountering! + Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in. + validations: + required: false +- type: checkboxes + attributes: + label: Confirmation + description: Please confirm that you have reviewed all of the above requirements and verified the information provided before submitting this issue. + options: + - label: I have reviewed and verified all the information provided in this report. + validations: + required: true + diff --git a/examples/BMTrain/.github/ISSUE_TEMPLATE/build_err.yml b/examples/BMTrain/.github/ISSUE_TEMPLATE/build_err.yml new file mode 100644 index 00000000..940fd7ed --- /dev/null +++ b/examples/BMTrain/.github/ISSUE_TEMPLATE/build_err.yml @@ -0,0 +1,94 @@ +name: 🛠️ Build Error +description: Report a build error for this project +title: "[BUILD ERROR] <title>" +labels: ["Build ERR"] +body: +- type: checkboxes + id: prev_issue + attributes: + label: Is there an existing issue for this? + description: Please search to see if an issue already exists for the build error you encountered. + options: + - label: I have searched the existing issues + required: true +- type: textarea + attributes: + label: Description of the Build Error + description: Provide a clear and concise description of what the build error is. + validations: + required: true +- type: textarea + attributes: + label: Expected Behavior + description: Provide a clear and concise description of what you expected to happen. + validations: + required: true +- type: textarea + attributes: + label: To Reproduce + description: Describe the steps you took to trigger the build error. Include any commands you executed or files you modified. + placeholder: | + 1. Go to '...' + 2. Click on '....' + 3. Scroll down to '....' + 4. See error + validations: + required: true +- type: textarea + attributes: + label: Environment Information + description: | + Provide details about your environment. + Example: + - Operating System version: Ubuntu 20.04 + - GCC version: 9.3.0 + - Pybind version: 2.8.1 + - CUDA version: 11.4 + - NVIDIA NCCL CU11 version: 2.14.3 + - CMake version: 3.21.2 + - Pip version: 22.0.0 + value: | + - Operating System version: + - GCC version: + - Pybind version: + - CUDA version: + - NVIDIA NCCL CU11 version: + - CMake version: + - Pip version: + render: markdown + validations: + required: true +- type: dropdown + attributes: + label: Installation Method + description: Please indicate if the error occurred during source code installation or when using the pip install .whl method. + options: + - Source Code Installation + - Pip Install .whl Method + validations: + required: true +- type: textarea + attributes: + label: Full Error Traceback + description: Provide the complete error traceback. + validations: + required: true +- type: textarea + attributes: + label: Additional Information + description: | + Provide any other relevant context or information about the problem here. + Links? References? Anything that will give us more context about the issue you are encountering! + Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in. + validations: + required: false +- type: checkboxes + id: confirm + attributes: + label: Confirmation + description: Please confirm that you have reviewed all of the above requirements and verified the information provided before submitting this report. + options: + - label: I have reviewed and verified all the information provided in this report. + validations: + required: true + diff --git a/examples/BMTrain/.github/ISSUE_TEMPLATE/features_request.yml b/examples/BMTrain/.github/ISSUE_TEMPLATE/features_request.yml new file mode 100644 index 00000000..769948c2 --- /dev/null +++ b/examples/BMTrain/.github/ISSUE_TEMPLATE/features_request.yml @@ -0,0 +1,30 @@ +name: 🚀Feature Request +description: Suggest an idea for this project +title: "[Feature] <title>" +labels: ["enhancement"] +assignees: [] +body: +- type: textarea + attributes: + label: Is your feature request related to a problem? Please describe. + description: "A clear and concise description of what the problem is. Example: I'm always frustrated when..." + validations: + required: true +- type: textarea + attributes: + label: Describe the solution you'd like + description: "A clear and concise description of what you want to happen." + validations: + required: true +- type: textarea + attributes: + label: Describe alternatives you've considered + description: "A clear and concise description of any alternative solutions or features you've considered." + validations: + required: false +- type: textarea + attributes: + label: Additional context + description: "Add any other context or screenshots about the feature request here." + validations: + required: false diff --git a/examples/BMTrain/.github/pull_request_template.md b/examples/BMTrain/.github/pull_request_template.md new file mode 100644 index 00000000..87a5f9b0 --- /dev/null +++ b/examples/BMTrain/.github/pull_request_template.md @@ -0,0 +1,29 @@ +## Pull Request Template + +### Issue Reference +Please mention the issue number if applicable, or write "N/A" if it's a new feature. + +Issue #... + +### Description +Please describe your changes in detail. If it resolves an issue, please state how it resolves it. + +### Type of Change +- [ ] Bug fix (non-breaking change which fixes an issue) +- [ ] New feature (non-breaking change which adds functionality) +- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) +- [ ] This change requires a documentation update + +### How Has This Been Tested? +Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce. + +### Checklist +- [ ] I have read the [CONTRIBUTING](../../CONTRIBUTING.md) document. +- [ ] My code follows the code style of this project. +- [ ] My change requires a change to the documentation. +- [ ] I have updated the documentation accordingly. +- [ ] I have added tests to cover my changes. +- [ ] All new and existing tests passed. + +### Additional Information +Any additional information, configuration, or data that might be necessary for the review. diff --git a/examples/BMTrain/.github/workflows/build.yml b/examples/BMTrain/.github/workflows/build.yml new file mode 100644 index 00000000..11aa61f6 --- /dev/null +++ b/examples/BMTrain/.github/workflows/build.yml @@ -0,0 +1,35 @@ +name: Build + +on: + pull_request_target: + types: [opened, reopened, synchronize] + branches: + - 'dev' + - 'main' + push: + branches: + - 'dev' + +jobs: + build-archive-wheel: + + uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main + secrets: inherit + + fake-publish: + needs: build-archive-wheel + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set Up the Python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Download distribution files + uses: actions/download-artifact@v4 + with: + name: dist + path: dist diff --git a/examples/BMTrain/.github/workflows/build_whl.yml b/examples/BMTrain/.github/workflows/build_whl.yml new file mode 100644 index 00000000..9116b598 --- /dev/null +++ b/examples/BMTrain/.github/workflows/build_whl.yml @@ -0,0 +1,89 @@ +name: Build wheels in docker and archive + +on: + workflow_call: + secrets: + DOCKERHUB_TOKEN: + required: true + DOCKERHUB_USERNAME: + required: true + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['37', '38', '39', '310', '311'] + + + steps: + + - name: Check the disk space and clear unnecessary library + run: | + rm -rf /home/runner/work/BMTrain/BMTrain/dist + sudo rm -rf /usr/share/dotnet + sudo rm -rf /opt/ghc + sudo rm -rf "/usr/local/share/boost" + sudo rm -rf "$AGENT_TOOLSDIRECTORY" + df -hl + + - name: Checkout code + uses: actions/checkout@v3 + + - name: Login to DockerHub + uses: docker/login-action@v2 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Pull Docker image + run: docker pull pytorch/manylinux-cuda113:latest + + - name: Run Docker image and execute script + run: | + version=${{ matrix.python-version }} + docker run -e BUILD_DOCKER_ENV=1 -e CUDACXX=/usr/local/cuda-11.3/bin/nvcc -e PATH="/opt/rh/devtoolset-9/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-9/root/usr/lib64:/opt/rh/devtoolset-9/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i pytorch/manylinux-cuda113:latest /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/pip install build; /opt/python/cp${version}*/bin/python -m build .;for file in dist/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done" + + - name: Upload wheels as artifacts + uses: actions/upload-artifact@v4 + with: + name: wheels_py${{ matrix.python-version }} + path: dist/*.whl + + - name: Upload source distribution (only once) + if: matrix.python-version == '37' # Only upload source distribution once + uses: actions/upload-artifact@v4 + with: + name: source_dist + path: dist/*.tar.gz + + archive: + runs-on: ubuntu-latest + needs: build + steps: + - name: Download all wheels + uses: actions/download-artifact@v4 + with: + path: wheels + pattern: wheels_py* + + - name: Download source distribution + uses: actions/download-artifact@v4 + with: + path: source_dist + name: source_dist + + - name: Combine all wheels into a single directory + run: | + mkdir -p dist + find wheels -name '*.whl' -exec mv {} dist/ \; + find source_dist -name '*.tar.gz' -exec mv {} dist/ \; + + - name: Archive distribution files + uses: actions/upload-artifact@v4 + with: + name: dist + path: | + dist/*.tar.gz + dist/*.whl + overwrite: true \ No newline at end of file diff --git a/examples/BMTrain/.github/workflows/publish.yaml b/examples/BMTrain/.github/workflows/publish.yaml new file mode 100644 index 00000000..fd9b8c50 --- /dev/null +++ b/examples/BMTrain/.github/workflows/publish.yaml @@ -0,0 +1,41 @@ +name: Build and Publish to PyPI + +on: + push: + tags: + + - "v*.*.*" + +jobs: + + build-archive-wheel: + uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main + secrets: + DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} + + publish: + needs: build-archive-wheel + runs-on: ubuntu-latest + steps: + - name: Set Up the Python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Install twine + run: python -m pip install twine + + - name: Download distribution files + uses: actions/download-artifact@v4 + with: + name: dist + path: dist + + - name: Publish to PyPI + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + run: | + cd dist + python -m twine upload *.tar.gz *.whl diff --git a/examples/BMTrain/.github/workflows/release.yml b/examples/BMTrain/.github/workflows/release.yml new file mode 100644 index 00000000..bafc2173 --- /dev/null +++ b/examples/BMTrain/.github/workflows/release.yml @@ -0,0 +1,44 @@ +name: Publish release in Github + +on: + push: + tags: + - "v*.*.*" + +jobs: + + build-archive-wheel: + + uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main + secrets: + DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }} + DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }} + + publish: + needs: build-archive-wheel + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v3 + + - name: Set Up the Python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Download distribution files + uses: actions/download-artifact@v4 + with: + name: dist + path: dist + + - name: Upload Distribution Files to Existing Release + uses: softprops/action-gh-release@v1 + with: + files: | + dist/*.tar.gz + dist/*.whl + tag_name: ${{ github.ref_name }} # 使用当前触发工作流的 tag + token: ${{ secrets.RELEASE_TOKEN }} + env: + GITHUB_REPOSITORY: OpenBMB/BMTrain diff --git a/examples/BMTrain/.gitignore b/examples/BMTrain/.gitignore new file mode 100644 index 00000000..75138102 --- /dev/null +++ b/examples/BMTrain/.gitignore @@ -0,0 +1,155 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +*.pt + +*.npy + +bminference/version.py + +.DS_Store + +log +*.qdrep +.vscode + +!bmtrain/dist +tests/test_log.txt +tests/*.opt +tests/*.ckp \ No newline at end of file diff --git a/examples/BMTrain/CMakeLists.txt b/examples/BMTrain/CMakeLists.txt new file mode 100644 index 00000000..e027e7da --- /dev/null +++ b/examples/BMTrain/CMakeLists.txt @@ -0,0 +1,65 @@ +cmake_minimum_required(VERSION 3.18) +project(bmtrain) +enable_language(C) +enable_language(CXX) +set(CMAKE_CUDA_ARCHITECTURES "61;62;70;72;75;80") +enable_language(CUDA) +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_CXX_STANDARD_REQUIRED True) +set(CMAKE_CUDA_STANDARD 14) +set(CMAKE_CUDA_STANDARD_REQUIRED True) + +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_62,code=sm_62 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_72,code=sm_72 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80") + +if(NOT DEFINED ENV{BUILD_DOCKER_ENV} OR "$ENV{BUILD_DOCKER_ENV}" STREQUAL "0") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_86,code=sm_86") + set(AVX_FLAGS "${AVX_FLAGS} -march=native") +else() + message("Building in docker environment, skipping compute_86 and enable all avx flag") + set(AVX_FLAGS "${AVX_FLAGS} -mavx -mfma -mf16c -mavx512f") +endif() + +set(CMAKE_BUILD_RPATH $ORIGIN) +set(CMAKE_INSTALL_RPATH $ORIGIN) +set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/) + +find_package(NCCL REQUIRED) +find_package(Python ${PYTHON_VERSION} EXACT COMPONENTS Interpreter Development.Module REQUIRED) +message (STATUS "Python_EXECUTABLE: ${Python_EXECUTABLE}") +execute_process(COMMAND ${Python_EXECUTABLE} "-c" + "import pybind11; print(pybind11.get_cmake_dir())" + OUTPUT_VARIABLE PYBIND11_CMAKE_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE) +message (STATUS "PYBIND11_CMAKE_DIR: ${PYB +IND11_CMAKE_DIR}") +list(APPEND CMAKE_PREFIX_PATH ${PYBIND11_CMAKE_DIR}) +find_package(pybind11 REQUIRED) + +message (STATUS "CMAKE_INSTALL_RPATH: ${CMAKE_INSTALL_RPATH}") + +file(GLOB_RECURSE SOURCES "csrc/*.cpp") +file(GLOB_RECURSE CUDA_SOURCES "csrc/cuda/*.cu") + + +pybind11_add_module(C ${SOURCES} ${CUDA_SOURCES}) + +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${AVX_FLAGS}") + +target_link_libraries(C PRIVATE + "-Wl,-Bsymbolic" + "-Wl,-Bsymbolic-functions" + ${NCCL_LIBRARIES} +) +target_include_directories(C PRIVATE ${NCCL_INCLUDE_DIRS}) +target_compile_definitions(C + PRIVATE VERSION_INFO=${EXAMPLE_VERSION_INFO}) + +set_target_properties(C PROPERTIES CUDA_ARCHITECTURES "61;62;70;72;75;80") + +target_include_directories(C + PRIVATE "csrc/include" + PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} +) + + + diff --git a/examples/BMTrain/CONTRIBUTING.md b/examples/BMTrain/CONTRIBUTING.md new file mode 100644 index 00000000..e1c4458f --- /dev/null +++ b/examples/BMTrain/CONTRIBUTING.md @@ -0,0 +1,55 @@ +# Contributing to BMTrain + +We welcome everyone's effort to make the community and the package better. You are welcomed to propose an issue, make a pull request or help others in the community. All of the efforts are appreciated! + +There are many ways that you can contribute to BMTrain: + +- ✉️ Submitting an issue. +- ⌨️ Making a pull request. +- 🤝 Serving the community. + +## Submitting an issue +You can submit an issue if you find bugs or require new features and enhancements. Here are some principles: + +1. **Language.** It is better to write your issue in English so that more people can understand and help you more conveniently. +2. **Search.** It is a good habit to search existing issues using the search bar of GitHub. Make sure there are no duplicated or similar issues with yours and if yes, check their solutions first. +3. **Format.** It is also very helpful to write the issue with a good writing style. We provide templates of common types of issues and everyone is encouraged to use these templates. If the templates do not fit in your issue, feel free to open a blank one. +4. **Writing style.** Write your issues in clear and concise words. It is also important to provide enough details for others to help. For example in a bug report, it is better to provide your running environment and minimal lines of code to reproduce it. + +## Making a pull request (PR) +You can also write codes to contribute. The codes may include a bug fix, a new enhancement, or a new running example. Here we provide the steps to make a pull request: + +1. **Combine the PR with an issue.** Make us and others know what you are going to work on. If your codes try to solve an existing issue, you should comment on the issue and make sure there are no others working on it. If you are proposing a new enhancement, submit an issue first and we can discuss it with you before you work on it. + +2. **Fork the repository.** Fork the repository to your own GitHub space by clicking the "Fork" button. Then clone it on your disk and set the remote repo: +```git +$ git clone https://github.com/<your GitHub>/BMTrain.git +$ cd BMTrain +$ git remote add upstream https://github.com/OpenBMB/BMTrain.git +``` + +3. **Write your code.** Change to a new branch to work on your modifications. +```git +$ git checkout -b your-branch-name +``` +You are encouraged to think up a meaningful and descriptive name for your branch. + +4. **Make a pull request.** After you finish coding, you should first rebase your code and solve the conflicts with the remote codes: +```git +$ git fetch upstream +$ git rebase upstream/main +``` +Then you can push your codes to your own repo: +```git +$ git push -u origin your-branch-name +``` +Finally, you can make the pull request from your GitHub repo and merge it with ours. Your codes will be merged into the main repo after our code review. + + +## Serving the community + +Besides submitting issues and PRs, you can also join our community and help others. Efforts like writing the documents, answering questions as well as discussing new features are appreciated and welcomed. It will also be helpful if you can post your opinions and feelings about using our package on social media. + +We are now developing a reward system and all your contributions will be recorded and rewarded in the future. + + diff --git a/examples/BMTrain/Dockerfile b/examples/BMTrain/Dockerfile new file mode 100644 index 00000000..8e6cbddf --- /dev/null +++ b/examples/BMTrain/Dockerfile @@ -0,0 +1,20 @@ +FROM nvidia/cuda:10.2-devel +WORKDIR /build +RUN apt update && apt install -y --no-install-recommends \ + build-essential \ + python3-dev \ + python3-pip \ + python3-setuptools \ + python3-wheel +RUN pip3 install torch==1.10.0 -i https://pypi.tuna.tsinghua.edu.cn/simple +RUN pip3 install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple +RUN apt install iputils-ping opensm libopensm-dev libibverbs1 libibverbs-dev -y --no-install-recommends +ENV TORCH_CUDA_ARCH_LIST=6.1;7.0;7.5 +ENV BMT_AVX512=1 +ADD other_requirements.txt other_requirements.txt +RUN pip3 install --upgrade pip && pip3 install -r other_requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple +ADD . . +RUN python3 setup.py install + +WORKDIR /root +ADD example example \ No newline at end of file diff --git a/examples/BMTrain/LICENSE b/examples/BMTrain/LICENSE new file mode 100644 index 00000000..7ad7f39e --- /dev/null +++ b/examples/BMTrain/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2022 OpenBMB + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/examples/BMTrain/MANIFEST.in b/examples/BMTrain/MANIFEST.in new file mode 100644 index 00000000..a6f97fa4 --- /dev/null +++ b/examples/BMTrain/MANIFEST.in @@ -0,0 +1,4 @@ +graft csrc +include CMakeLists.txt +graft cmake + diff --git a/examples/BMTrain/README-ZH.md b/examples/BMTrain/README-ZH.md new file mode 100644 index 00000000..d36953a2 --- /dev/null +++ b/examples/BMTrain/README-ZH.md @@ -0,0 +1,369 @@ +<div align="center"> + +<h1><img src="docs/logo.png" height="28px" /> BMTrain</h1> + +**大模型高效训练工具包** + +<p align="center"> + <a href="#总览">总览</a> • <a href="#文档">文档</a> • <a href="#安装">安装</a> • <a href="#使用说明">使用说明</a> • <a href="#性能">性能</a> • <a href="./README.md" target="_blank">English</a> +<br> +</p> + +<p align="center"> + +<a href='https://bmtrain.readthedocs.io/en/latest/?badge=latest'> + <img src='https://readthedocs.org/projects/bmtrain/badge/?version=latest' alt='Documentation Status' /> +</a> + +<a href="https://github.com/OpenBMB/BMTrain/releases"> + <img alt="GitHub release (latest by date including pre-releases)" src="https://img.shields.io/github/v/release/OpenBMB/BMTrain?include_prereleases"> +</a> + +<a href="https://github.com/OpenBMB/BMTrain/blob/main/LICENSE"> + <img alt="GitHub" src="https://img.shields.io/github/license/OpenBMB/BMTrain"> +</a> + +</p> + +</div> + +## 最新动态 +- 2022/06/14 **BMTrain** [0.1.7](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.7) 发布。支持了ZeRO-2优化! +- 2022/03/30 **BMTrain** [0.1.2](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.2) 发布。适配了[OpenPrompt](https://github.com/thunlp/OpenPrompt)和 [OpenDelta](https://github.com/thunlp/OpenDelta)工具包。 +- 2022/03/16 **BMTrain** [0.1.1](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.1) 公开发布了第一个稳定版本,修复了 beta 版本中的一些问题。 +- 2022/02/11 **BMTrain** [0.0.15](https://github.com/OpenBMB/BMTrain/releases/tag/0.0.15) 公开发布了第一个 beta 版本。 + +<div id="总览"></div> + +## 总览 + +BMTrain 是一个高效的大模型训练工具包,可以用于训练数百亿参数的大模型。BMTrain 可以在分布式训练模型的同时,能够保持代码的简洁性。 + +<div id="文档"></div> + +## 文档 +我们的[文档](https://bmtrain.readthedocs.io/en/latest/index.html)提供了关于工具包的更多信息。 + +<div id="安装"></div> + +## 安装 + +- 用 pip 安装(推荐): ``pip install bmtrain`` + +- 从源代码安装: 下载工具包,然后运行 ``pip install .`` (setup.py的安装方式将会在未来被setuptools弃用) + +安装 BMTrain 可能需要花费数分钟的时间,因为在安装时需要编译 c/cuda 源代码。 +我们推荐直接在训练环境中编译 BMTrain,以避免不同环境带来的潜在问题。 + + +<div id="使用说明"></div> + +## 使用说明 + +### 步骤 1: 启用 BMTrain + +首先,你需要在代码开头初始化 BMTrain。正如在使用 PyTorch 的分布式训练模块需要在代码开头使用 **init_process_group** 一样,使用 BMTrain 需要在代码开头使用 **init_distributed**。 + +```python +import bmtrain as bmt +bmt.init_distributed( + seed=0, + zero_level=3, # 目前支持2和3 + # ... +) +``` + +**注意:** 使用 BMTrain 时请不要使用 PyTorch 自带的 `distributed` 模块,包括 `torch.distributed.init_process_group` 以及相关通信函数。 + +### 步骤 2: 使用 ZeRO 优化 + +使用ZeRO优化需要对模型代码进行简单替换: + +* `torch.nn.Module` -> `bmtrain.DistributedModule` +* `torch.nn.Parameter` -> `bmtrain.DistributedParameter` + +并在 transformer 模块上使用 `bmtrain.CheckpointBlock`。 + +下面是一个例子: + +**原始代码** + +```python +import torch +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.empty(1024)) + self.module_list = torch.nn.ModuleList([ + SomeTransformerBlock(), + SomeTransformerBlock(), + SomeTransformerBlock() + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +**替换后代码** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): # 修改这里 + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) # 修改这里 + self.module_list = torch.nn.ModuleList([ + bmt.CheckpointBlock(SomeTransformerBlock()), # 修改这里 + bmt.CheckpointBlock(SomeTransformerBlock()), # 修改这里 + bmt.CheckpointBlock(SomeTransformerBlock()) # 修改这里 + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +### 步骤 3: 通信优化 + +为了进一步缩短通信额外开销,将通信与运算时间重叠,可以使用 `TransformerBlockList` 来进一步优化。 + +在使用时需要对代码进行简单替换: + +* `torch.nn.ModuleList` -> `bmtrain.TransformerBlockList` +* `for module in self.module_list: x = module(x, ...)` -> `x = self.module_list(x, ...)` + +**原始代码** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) + self.module_list = torch.nn.ModuleList([ + bmt.CheckpointBlock(SomeTransformerBlock()), + bmt.CheckpointBlock(SomeTransformerBlock()), + bmt.CheckpointBlock(SomeTransformerBlock()) + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +**替换后代码** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) + self.module_list = bmt.TransformerBlockList([ # 修改这里 + bmt.CheckpointBlock(SomeTransformerBlock()), + bmt.CheckpointBlock(SomeTransformerBlock()), + bmt.CheckpointBlock(SomeTransformerBlock()) + ]) + + def forward(self): + x = self.param + x = self.module_list(x, 1, 2, 3) # 修改这里 + return x + +``` + +### 步骤 4: 运行分布式训练代码 + +BMTrain 使用 PyTorch 原生的分布式训练启动器,你可以根据 PyTorch 版本选择下列命令中的一个。 + +* `${MASTER_ADDR}` 为主节点的 IP 地址 +* `${MASTER_PORT}` 为主节点的端口 +* `${NNODES}` 为节点数量 +* `${GPU_PER_NODE}` 为每个节点的 GPU 数量 +* `${NODE_RANK}` 为本节点的 rank + +#### torch.distributed.launch +```shell +$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py +``` + +#### torchrun + +```shell +$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py +``` + +更多信息请参考 PyTorch [官方文档](https://pytorch.org/docs/stable/distributed.html#launch-utility)。 + +## 样例 + +我们提供了一个使用 BMTrain 训练 GPT-2 的[样例](https://github.com/OpenBMB/BMTrain/tree/main/example)。 +代码主要包含以下几个部分。 + +### 第 1 部分: 模型定义 + +``` +├── layers +│ ├── attention.py +│ ├── embedding.py +│ ├── feedforward.py +│ ├── __init__.py +│ ├── layernorm.py +│ └── linear.py +└── models + ├── gpt.py + └── __init__.py +``` + +上面是代码的目录结构。 + +我们定义了 GPT-2 需要的所有模型层,并使用 BMTrain 的 `DistributedModule` 和 `DistributedParameter` 来启用 ZeRO 优化。 + +### 第 2 部分: 初始化 BMTrain + +```python +bmtrain.init_distributed(seed=0) + +model = GPT( + num_layers=8, + vocab_size=10240, + dim_model=2560, + dim_head=80, + num_heads=32, + dim_ff=8192, + max_distance=1024, + bias=True, + dtype=torch.half +) + +bmtrain.init_parameters(model) # 或者使用`bmtrain.load`加载checkpoint + +# ... 其他初始化(例如数据集) ... +``` + +`bmtrain.init_distributed(seed=0)` 用于初始化分布式训练环境,并设置随机数种子便于复现。 + +`bmtrain.init_parameters(model)` 用于初始化模型的分布式参数。 + +### 第 3 部分: 初始化优化器和学习率调整策略 + +```python +loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) +optimizer = bmtrain.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) +lr_scheduler = bmtrain.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) +``` + +BMTrain 支持**所有** PyTorch 原生的优化器和损失函数,同时你也可以使用 BMTrain 提供的融合(fused)优化器用于混合精度训练。 + +此外,在 `bmtrain.lr_scheduler` 中 BMTrain 也提供了常见的学习率调整策略。 + +### 第 4 部分: 训练 + +```python +# 新建优化器管理器实例 +optim_manager = bmtrain.optim.OptimManager(loss_scale=1024) +# 将所有的 optimzer 及(可选)其对应的 lr_scheduler 收入优化器管理器管理。 +optim_manager.add_optimizer(optimizer, lr_scheduler) +# 可以再次调用 add_optimizer 加入其他优化器 + +for iteration in range(1000): + # ... 为每个rank加载数据 ... + + # 前向传播并计算梯度 + pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + logits = model( + enc_input, + pos, + pos < enc_length[:, None] + ) + batch, seq_len, vocab_out_size = logits.size() + + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + + global_loss = bmtrain.sum_loss(loss).item() # 聚合所有rank上的损失, 仅用于输出训练日志 + + # 梯度清零 + optim_manager.zero_grad() # 为每个 optimizer 调用 zero_grad + + # 损失缩放和反向传播 + optim_manager.backward(loss) + + # 梯度裁剪 + grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0) + + # 更新参数 + optim_manager.step() + + # ... 保存checkpoint、打印日志 ... +``` + +这部分代码略有些长,但写起来就像常见的训练代码一样,你不需要为分布式训练调整太多的代码。 + +你可以根据代码中的注释来了解各部分代码的作用。 + +唯一需要说明的是 `optim_manager`。在使用 BMTrain 后,优化器的部分相关操作需要有一些细节上的调整。我们在 `optim_manager` 帮你实现了这些细节, 你只需要通过 `add_optimizer` 将优化器和学习率调整策略收入 `optim_manager` 管理,并由 `optim_manger` 代为执行 `zero_grad()`, `backward()`, `clip_grad_norm()` 和 `step()` 等操作。 + +如果你没有使用混合精度训练,你可以不用损失缩放,只需要将 `OptimManger(loss_scale=None)` 构造函数中 `loss_scale` 置为 None 即可, 这也是 `OptimManager` 的默认构造参数。 + +如果你使用了混合精度训练,**损失缩放**是混合精度训练中的一项常用技术,我们在 `optim_manager.backward(loss)` 帮你对 `loss` 进行了放缩,用于避免梯度下溢。只需要将 `OptimManger` 构造函数中 `loss_scale` 置为一个浮点数即可。 `loss_scale` 会在训练过程中根据梯度进行自适应的调整。 + +<div id="性能"></div> + +## 性能 + +我们训练了一个有130亿参数的 GPT-2 模型,使用了4台服务器,每台服务器有8张V100显卡。我们测试了训练过程中每个GPU的吞吐量(每个GPU每秒处理的样本数),结果见下表。 + +模型结构: +* 40层 +* 128个注意力头 +* 5120的隐藏层维数 +* 512的序列长度 + + +| batch size | 8 | 16 | 24 | 32 | +|-------------|-------|-------|:------|:------| +| BMTrain | 24.15 | 26.94 | 29.42 | 28.28 | +| ZeRO3(mp=1) | 14.88 | 21.69 | 24.38 | - | +| ZeRO3(mp=4) | 15.51 | - | - | - | +| ZeRO3(mp=8) | 15.51 | - | - | - | +| ZeRO2(mp=1) | - | - | - | - | +| ZeRO2(mp=4) | 22.85 | - | - | - | +| ZeRO2(mp=8) | 21.33 | - | - | - | + +**ZeROa(mp=b)** 表示 DeepSpeed + Megatron ZeRO stage a 和 model parallelism = b。 + +表格中的 **-** 表示超出显存。 + +## 模型支持 + +我们已经将大多数常见的 NLP 模型移植到了 BMTrain 中。你可以在 [ModelCenter](https://github.com/OpenBMB/ModelCenter) 项目中找到支持模型的列表。 + +## 开源社区 +欢迎贡献者参照我们的[贡献指南](https://github.com/OpenBMB/BMTrain/blob/master/CONTRIBUTING.md)贡献相关代码。 + +您也可以在其他平台与我们沟通交流: +- QQ群: 735930538 +- 官方网站: https://www.openbmb.org +- 微博: http://weibo.cn/OpenBMB +- Twitter: https://twitter.com/OpenBMB + +## 开源许可 + +该工具包使用[Apache 2.0](https://github.com/OpenBMB/BMTrain/blob/main/LICENSE)开源许可证。 + +## 其他说明 + +`BMTrain` 工具包对 PyTorch 进行了底层修改,如果你的程序输出了意料之外的结果,可以在 issue 中提交相关信息。 diff --git a/examples/BMTrain/README.md b/examples/BMTrain/README.md new file mode 100644 index 00000000..134929f5 --- /dev/null +++ b/examples/BMTrain/README.md @@ -0,0 +1,375 @@ +<div align="center"> + +<h1><img src="docs/logo.png" height="28px" /> BMTrain</h1> + +**Efficient Training for Big Models** + +<p align="center"> + <a href="#overview">Overview</a> • <a href="#documentation">Documentation</a> • <a href="#install">Installation</a> • <a href="#usage">Usage</a> • <a href="#performance">Performance</a> • <a href="./README-ZH.md" target="_blank">简体中文</a> +<br> +</p> + +<p align="center"> + +<a href='https://bmtrain.readthedocs.io/en/latest/?badge=latest'> + <img src='https://readthedocs.org/projects/bmtrain/badge/?version=latest' alt='Documentation Status' /> +</a> + +<a href="https://github.com/OpenBMB/BMTrain/releases"> + <img alt="GitHub release (latest by date including pre-releases)" src="https://img.shields.io/github/v/release/OpenBMB/BMTrain?include_prereleases"> +</a> + +<a href="https://github.com/OpenBMB/BMTrain/blob/main/LICENSE"> + <img alt="GitHub" src="https://img.shields.io/github/license/OpenBMB/BMTrain"> +</a> + +</p> + +</div> + +## What's New +- 2024/02/26 **BMTrain** [1.0.0](https://github.com/OpenBMB/BMTrain/releases/tag/v1.0.0) released. Code refactoring and Tensor parallel support. See the detail in [update log](docs/UPDATE_1.0.0.md) +- 2023/08/17 **BMTrain** [0.2.3](https://github.com/OpenBMB/BMTrain/releases/tag/v0.2.3) released. See the [update log](docs/UPDATE_0.2.3.md). +- 2022/12/15 **BMTrain** [0.2.0](https://github.com/OpenBMB/BMTrain/releases/tag/0.2.0) released. See the [update log](docs/UPDATE_0.2.0.md). +- 2022/06/14 **BMTrain** [0.1.7](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.7) released. ZeRO-2 optimization is supported! +- 2022/03/30 **BMTrain** [0.1.2](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.2) released. Adapted to [OpenPrompt](https://github.com/thunlp/OpenPrompt)and [OpenDelta](https://github.com/thunlp/OpenDelta). +- 2022/03/16 **BMTrain** [0.1.1](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.1) has publicly released the first stable version, which fixes many bugs that were in the beta version. +- 2022/02/11 **BMTrain** [0.0.15](https://github.com/OpenBMB/BMTrain/releases/tag/0.0.15) has publicly released the first beta version. + +<div id="overview"></div> + +## Overview + +BMTrain is an efficient large model training toolkit that can be used to train large models with tens of billions of parameters. It can train models in a distributed manner while keeping the code as simple as stand-alone training. + +<div id="documentation"></div> + +## Documentation +Our [documentation](https://bmtrain.readthedocs.io/en/latest/index.html) provides more information about the package. + +<div id="install"></div> + +## Installation + +- From pip (recommend) : ``pip install bmtrain`` + +- From source code: download the package and run ``pip install .`` + +Installing BMTrain may take a few to ten minutes, as it requires compiling the c/cuda source code at the time of installation. +We recommend compiling BMTrain directly in the training environment to avoid potential problems caused by the different environments. + +<div id="usage"></div> + +## Usage + +### Step 1: Initialize BMTrain + +Before you can use BMTrain, you need to initialize it at the beginning of your code. Just like using the distributed module of PyTorch requires the use of **init_process_group** at the beginning of the code, using BMTrain requires the use of **init_distributed** at the beginning of the code. + +```python +import bmtrain as bmt +bmt.init_distributed( + seed=0, + # ... +) +``` + +**NOTE:** Do not use PyTorch's distributed module and its associated communication functions when using BMTrain. + +### Step 2: Enable ZeRO Optimization + +To enable ZeRO optimization, you need to make some simple replacements to the original model's code. + +* `torch.nn.Module` -> `bmtrain.DistributedModule` +* `torch.nn.Parameter` -> `bmtrain.DistributedParameter` + +And wrap the transformer blocks with `bmtrain.Block`. + +Here is an example. + +**Original** + +```python +import torch +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.empty(1024)) + self.module_list = torch.nn.ModuleList([ + SomeTransformerBlock(), + SomeTransformerBlock(), + SomeTransformerBlock() + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +**Replaced** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): # changed here + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) # changed here + self.module_list = torch.nn.ModuleList([ + bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now + bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now + bmt.Block(SomeTransformerBlock(), zero_level=3) # changed here, support 2 and 3 now + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +### Step 3: Enable Communication Optimization + + +To further reduce the extra overhead of communication and overlap communication with computing time, `TransformerBlockList` can be used for optimization. + +You can enable them by making the following substitutions to the code: + +* `torch.nn.ModuleList` -> `bmtrain.TransformerBlockList` +* `for module in self.module_list: x = module(x, ...)` -> `x = self.module_list(x, ...)` + +**Original** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) + self.module_list = torch.nn.ModuleList([ + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()) + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +**Replaced** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) + self.module_list = bmt.TransformerBlockList([ # changed here + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()) + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +### Step 4: Launch Distributed Training + +BMTrain uses the same launch command as the distributed module of PyTorch. + +You can choose one of them depending on your version of PyTorch. + +* `${MASTER_ADDR}` means the IP address of the master node. +* `${MASTER_PORT}` means the port of the master node. +* `${NNODES}` means the total number of nodes. +* `${GPU_PER_NODE}` means the number of GPUs per node. +* `${NODE_RANK}` means the rank of this node. + +#### torch.distributed.launch +```shell +$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py +``` + +#### torchrun + +```shell +$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py +``` + + +For more information, please refer to the [documentation](https://pytorch.org/docs/stable/distributed.html#launch-utility). + +## Example + +We provide an [example](https://github.com/OpenBMB/BMTrain/tree/main/example) of training GPT-2 based on BMTrain. +The code mainly consists of the following parts. + +### Part 1: Model Definition + +``` +├── layers +│ ├── attention.py +│ ├── embedding.py +│ ├── feedforward.py +│ ├── __init__.py +│ ├── layernorm.py +│ └── linear.py +└── models + ├── gpt.py + └── __init__.py +``` + +Above is the directory structure of the code in the part of Model Definition. + +We defined all the layers needed in GPT-2 and used BMTrain's `DistributedModule` and `DistributedParameter` to enable ZeRO optimization. + +### Part 2: BMTrain Initialization + +```python +bmtrain.init_distributed(seed=0) + +model = GPT( + num_layers=8, + vocab_size=10240, + dim_model=2560, + dim_head=80, + num_heads=32, + dim_ff=8192, + max_distance=1024, + bias=True, + dtype=torch.half +) + +bmtrain.init_parameters(model) # or loading checkpoint use `bmtrain.load` + +# ... other initialization (dataset) ... +``` + +`bmtrain.init_distributed(seed=0)` is used to initialize the distributed training environment and set the random seed for reproducibility. + +`bmtrain.init_parameters(model)` is used to initialize the distributed parameters of the model. + +### Part 3: Intialization of the Optimizer and LR Scheduler + +```python +loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) +optimizer = bmtrain.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) +lr_scheduler = bmtrain.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) +``` + +BMTrain supports *all* the PyTorch native optimizers and loss functions, and you can also use the fused optimizer provided by BMTrain for mixed-precision training. + +In addition, BMTrain also provides the common LRScheduler in the `bmtrain.lr_scheduler` module. + +### Part 4: Training Loop + +```python +# create a new instance of optimizer manager +optim_manager = bmtrain.optim.OptimManager(loss_scale=1024) +# let optim_manager handle all the optimizer and (optional) their corresponding lr_scheduler +optim_manager.add_optimizer(optimizer, lr_scheduler) +# add_optimizer can be called multiple times to add other optimizers. + +for iteration in range(1000): + # ... load data for each rank ... + + # forward pass and calculate loss + pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + logits = model( + enc_input, + pos, + pos < enc_length[:, None] + ) + batch, seq_len, vocab_out_size = logits.size() + + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + + global_loss = bmtrain.sum_loss(loss).item() # sum the loss across all ranks. This is only used for the training log + + # zero grad + optim_manager.zero_grad() # calling zero_grad for each optimizer + + # loss scale and backward + optim_manager.backward(loss) + + # clip grad norm + grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0) + + # optimizer step + optim_manager.step() + + # ... save checkpoint or print logs ... +``` + +The training loop part will be slightly longer, but just like a normal training loop, you don't need to adapt much to distributed training. + +You can follow the comments in the code to get an idea of what each section of code is doing. + +The only additional note is `optimizer`. After using BMTrain, some details in optimizers should be adjusted. We have implemented all those details needed in `optim_manager`. What you need is just letting `optim_manager` to handle all the optimizers by `add_optimizer`, and letting `optim_manager` do `zero_grad()`, `backward()`, `clip_grad_norm()` and `step()` instead. + +If you are not using the mixed-precision training, you can train without `loss_scale`. Just set `loss_scale` to None in the `__init__` function of `OptimManager(loss_scale=None)`, which is also the default. + +If you are using mixed-precision training, *loss scale* is the technique widely used in mixed precision training to prevent gradient underflow. By using `optim_manager.backward(loss)` to scale the `loss` before backward and set `loss_scale` to some floating number in the `__init__` function of `OptimManager`。The `loss_scale` would be adjusted adaptively based on the gradient during training. + +<div id="performance"></div> + +## Performance + +We trained a GPT-2 model with 13B parameters using 4 servers with 8 V100s on each server, and measured the throughput of each GPU during the training process (samples per GPU per second). + +Model structure: +* 40 layers +* 128 attention heads +* 5120 hidden dimension +* 512 sequence length + + +| batch size | 8 | 16 | 24 | 32 | +|-------------|-------|-------|:------|:------| +| BMTrain | 24.15 | 26.94 | 29.42 | 28.28 | +| ZeRO3(mp=1) | 14.88 | 21.69 | 24.38 | - | +| ZeRO3(mp=4) | 15.51 | - | - | - | +| ZeRO3(mp=8) | 15.51 | - | - | - | +| ZeRO2(mp=1) | - | - | - | - | +| ZeRO2(mp=4) | 22.85 | - | - | - | +| ZeRO2(mp=8) | 21.33 | - | - | - | + +**ZeROa(mp=b)** means DeepSpeed + Megatron ZeRO stage a and model parallelism = b. + +**-** in the table means out of memory. + +## Supported Models + +We have migrated most of the common models in NLP to the BMTrain. You can find the list of supported models in the repo [ModelCenter](https://github.com/OpenBMB/ModelCenter). + +## Community +We welcome everyone to contribute codes following our [contributing guidelines](https://github.com/OpenBMB/BMTrain/blob/master/CONTRIBUTING.md). + +You can also find us on other platforms: +- QQ Group: 735930538 +- Website: https://www.openbmb.org +- Weibo: http://weibo.cn/OpenBMB +- Twitter: https://twitter.com/OpenBMB + +## License +The package is released under the [Apache 2.0](https://github.com/OpenBMB/BMTrain/blob/master/LICENSE) License. + +## Other Notes + +`BMTrain` makes underlying changes to PyTorch, so if your program outputs unexpected results, you can submit information about it in an issue. + diff --git a/examples/BMTrain/Release.txt b/examples/BMTrain/Release.txt new file mode 100644 index 00000000..7c8a41be --- /dev/null +++ b/examples/BMTrain/Release.txt @@ -0,0 +1,9 @@ +## What's Changed +* Using pytorch's hook mechanism to refactor ZeRO, checkpoint, pipeline, communication implementation by @zkh2016 in #128 #159 +* Add Bf16 support by @Achazwl in #136 +* Tensor parallel implementation by @Achazwl @zkh2016 @MayDomine in #153 +* Async save state_dict by @zkh2016 in #171 +* `AdamOffloadOptimizer` can save whole gathered state by @MayDomine in #184 +* New test for new version's bmtrain by @Achazwl @JerryYin777 @MayDomine +**Full Changelog**: https://github.com/OpenBMB/BMTrain/compare/0.2.3...1.0.0 + diff --git a/examples/BMTrain/bmtrain/__init__.py b/examples/BMTrain/bmtrain/__init__.py new file mode 100644 index 00000000..f4ac3642 --- /dev/null +++ b/examples/BMTrain/bmtrain/__init__.py @@ -0,0 +1,26 @@ +from .utils import print_block, print_dict, print_rank, see_memory, load_nccl_pypi +try: + from . import nccl +except: + load_nccl_pypi() +from .global_var import config, world_size, rank +from .init import init_distributed + +from .parameter import DistributedParameter, ParameterInitializer +from .layer import DistributedModule +from .param_init import init_parameters, grouped_parameters +from .synchronize import synchronize, sum_loss, wait_loader, gather_result +from .block_layer import Block, TransformerBlockList +from .wrapper import BMTrainModelWrapper +from .pipe_layer import PipelineTransformerBlockList +from . import debug +from .store import save, load + +from . import loss +from . import distributed +from . import nn +from . import optim +from . import inspect +from . import lr_scheduler + +CheckpointBlock = Block diff --git a/examples/BMTrain/bmtrain/benchmark/__init__.py b/examples/BMTrain/bmtrain/benchmark/__init__.py new file mode 100644 index 00000000..571d621f --- /dev/null +++ b/examples/BMTrain/bmtrain/benchmark/__init__.py @@ -0,0 +1,3 @@ +from .all_gather import all_gather +from .reduce_scatter import reduce_scatter +from .send_recv import send_recv \ No newline at end of file diff --git a/examples/BMTrain/bmtrain/benchmark/all_gather.py b/examples/BMTrain/bmtrain/benchmark/all_gather.py new file mode 100644 index 00000000..b2f2ee7c --- /dev/null +++ b/examples/BMTrain/bmtrain/benchmark/all_gather.py @@ -0,0 +1,28 @@ +from .. import nccl +from .shape import SHAPES +from ..global_var import config +from ..utils import round_up, print_rank +from .utils import format_size +import torch + +def all_gather(): + current_stream = torch.cuda.current_stream() + for shape in SHAPES: + global_size = round_up(shape, config['world_size'] * 2) + partition_size = global_size // config['world_size'] + + partition_tensor = torch.empty( partition_size // 2, dtype=torch.half, device="cuda" ) + global_tensor = torch.empty( global_size // 2, dtype=torch.half, device="cuda" ) + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + current_stream.record_event(start_evt) + nccl.allGather(partition_tensor.storage(), global_tensor.storage(), config['comm']) + current_stream.record_event(end_evt) + current_stream.synchronize() + time_usage = start_evt.elapsed_time(end_evt) + + bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage + print_rank("All gather:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw)) + diff --git a/examples/BMTrain/bmtrain/benchmark/reduce_scatter.py b/examples/BMTrain/bmtrain/benchmark/reduce_scatter.py new file mode 100644 index 00000000..75733556 --- /dev/null +++ b/examples/BMTrain/bmtrain/benchmark/reduce_scatter.py @@ -0,0 +1,28 @@ +from .. import nccl +from .shape import SHAPES +from ..global_var import config +from ..utils import round_up, print_rank +from .utils import format_size +import torch + +def reduce_scatter(): + current_stream = torch.cuda.current_stream() + for shape in SHAPES: + global_size = round_up(shape, config['world_size']) + partition_size = global_size // config['world_size'] + + partition_tensor = torch.empty( partition_size // 2, dtype=torch.half, device="cuda" ) + global_tensor = torch.empty( global_size // 2, dtype=torch.half, device="cuda" ) + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + current_stream.record_event(start_evt) + nccl.reduceScatter(global_tensor.storage(), partition_tensor.storage(), 'avg', config['comm']) + current_stream.record_event(end_evt) + current_stream.synchronize() + time_usage = start_evt.elapsed_time(end_evt) + + bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage + print_rank("Reduce Scatter:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw)) + diff --git a/examples/BMTrain/bmtrain/benchmark/send_recv.py b/examples/BMTrain/bmtrain/benchmark/send_recv.py new file mode 100644 index 00000000..e3c971e4 --- /dev/null +++ b/examples/BMTrain/bmtrain/benchmark/send_recv.py @@ -0,0 +1,31 @@ +from .. import nccl +from .shape import SHAPES +from ..global_var import config +from ..utils import print_rank +from .utils import format_size +import torch +def send_recv(): + current_stream = torch.cuda.current_stream() + for shape in SHAPES: + send_size = shape + + send_buffer = torch.empty( send_size // 2, dtype=torch.half, device="cuda" ) + recv_buffer = torch.empty( send_size // 2, dtype=torch.half, device="cuda" ) + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + current_stream.record_event(start_evt) + nccl.groupStart() + if config['rank'] in [0,2,4,6]: + nccl.send(send_buffer.storage(), config['rank']+1, config['comm']) + else: + nccl.recv(recv_buffer.storage(), config['rank']-1, config['comm']) + nccl.groupEnd() + current_stream.record_event(end_evt) + current_stream.synchronize() + time_usage = start_evt.elapsed_time(end_evt) + + bw = shape / 1024 / 1024 / 1024 * 1000 / time_usage + print_rank("Send Recv:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(send_size), time_usage, bw)) + diff --git a/examples/BMTrain/bmtrain/benchmark/shape.py b/examples/BMTrain/bmtrain/benchmark/shape.py new file mode 100644 index 00000000..0699e8cd --- /dev/null +++ b/examples/BMTrain/bmtrain/benchmark/shape.py @@ -0,0 +1,3 @@ +SHAPES = [ + (2**i) for i in range(10, 33) +] \ No newline at end of file diff --git a/examples/BMTrain/bmtrain/benchmark/utils.py b/examples/BMTrain/bmtrain/benchmark/utils.py new file mode 100644 index 00000000..dbc4a70c --- /dev/null +++ b/examples/BMTrain/bmtrain/benchmark/utils.py @@ -0,0 +1,11 @@ +def format_size_(x): + if x < 1024: + return "{:d}B".format(x) + if x < 1024 * 1024: + return "{:4.2f}KB".format(x / 1024) + if x < 1024 * 1024 * 1024: + return "{:4.2f}MB".format(x / 1024 / 1024) + return "{:4.2f}GB".format(x / 1024 / 1024 / 1024) + +def format_size(x): + return "{:.6s}".format(format_size_(x)) \ No newline at end of file diff --git a/examples/BMTrain/bmtrain/block_layer.py b/examples/BMTrain/bmtrain/block_layer.py new file mode 100644 index 00000000..216d77b2 --- /dev/null +++ b/examples/BMTrain/bmtrain/block_layer.py @@ -0,0 +1,726 @@ +from typing import Dict, Iterable, Iterator, Union, List + +from .utils import round_up, tp_split_tensor +from .global_var import config +import torch +from . import nccl +from .parameter import DistributedParameter, OpAllGather +from .zero_context import ZeroContext +from . import hook_func +import inspect +from torch.utils.checkpoint import checkpoint + + +def storage_type_cuda(storage_type): + """Convert storage_type to cuda storage_type.""" + STORAGE_MAP = { + torch.FloatStorage: torch.cuda.FloatStorage, + torch.DoubleStorage: torch.cuda.DoubleStorage, + torch.HalfStorage: torch.cuda.HalfStorage, + torch.BFloat16Storage: torch.cuda.BFloat16Storage, + torch.CharStorage: torch.cuda.CharStorage, + torch.ByteStorage: torch.cuda.ByteStorage, + torch.ShortStorage: torch.cuda.ShortStorage, + torch.IntStorage: torch.cuda.IntStorage, + torch.cuda.FloatStorage: torch.cuda.FloatStorage, + torch.cuda.DoubleStorage: torch.cuda.DoubleStorage, + torch.cuda.HalfStorage: torch.cuda.HalfStorage, + torch.cuda.BFloat16Storage: torch.cuda.BFloat16Storage, + torch.cuda.CharStorage: torch.cuda.CharStorage, + torch.cuda.ByteStorage: torch.cuda.ByteStorage, + torch.cuda.ShortStorage: torch.cuda.ShortStorage, + torch.cuda.IntStorage: torch.cuda.IntStorage, + } + if storage_type not in STORAGE_MAP: + raise ValueError("Unknown storage type: {}".format(storage_type)) + return STORAGE_MAP[storage_type] + + +def _get_param_kw(param: DistributedParameter): + """Get DistributedParameter kw name.""" + type_name = str(param.dtype).split(".")[-1] + grad_name = "_grad" if param.requires_grad else "_nograd" + group_name = "" + if param.group is not None: + group_name = "_g_" + param.group + return type_name + grad_name + group_name + + +class Block(torch.nn.Module): + """A block containing two memory-saving methods of ZeRO and checkpoint. + For details please refer to `ZeRO <https://arxiv.org/abs/1910.02054v3>`_ and + `Checkpointing <https://pytorch.org/docs/stable/checkpoint.html>`_ . + + Args: + inner_module (torch.nn.Module): The module to reduce memory usage. All kinds of modules are supported. + use_checkpoint (boolean): use checkpoint or not. Default True. + zero_level (int): 2 (ZeRO-2) indicates that optimizer states and gradients are partitioned across the process, + 3 (ZeRO-3) means that the parameters are partitioned one the basis of ZeRO-2. Default 3. + initialized (bool): initialized parameter storage. Default False. + mode (str): the mode shouled be "PIPE" when runing in pipeline mode, otherwise mode="BLOCK". Default "BLOCK" + + Examples: + >>> transformer_block = TransformerBlock(...) + >>> block = Block(transformer_block) + >>> y1, ... = block(x) + >>> y2, ... = transformer_block(x) + >>> assert torch.allclose(y1, y2) + """ + + def __init__( + self, + inner_module: torch.nn.Module, + use_checkpoint=True, + zero_level=3, + initialized=False, + mode="BLOCK", + ): + super().__init__() + self._module = inner_module + self._inputs = None + self._layer_dict = {} + self._forward_block_ctx = None + self._backward_block_ctx = None + + self._param_info = [] + self._storage_params: Dict[str, torch.nn.Parameter] = {} + self._storage_info = {} + self._ready = False + + self._use_checkpoint = use_checkpoint + self._is_first_layer = True + self._is_last_layer = True + self._need_release = True + self._next_module = None # save the next module of self + self._pre_module = None # save the pre module of self + self._mode = mode # BLOCK or PIPE + self.all_input_no_grad = False + self.all_param_no_grad = False + self._zero_level = zero_level + if not initialized: + self.init_param_storage() + + def reference(self, block): + """Make this block be a reference of the input Block.""" + self._param_info = block._param_info + self._storage_params = block._storage_params + self._storage_info = block._storage_info + self._layer_dict = block._layer_dict + self._initialized = True + self._need_release = False + + def init_param_storage(self): + """Init param storage.""" + # sort parameters by name + ordered_parameters = list(self._module.named_parameters()) + + # calc total number of parameters + for name, param in ordered_parameters: + if not isinstance(param, DistributedParameter): + raise ValueError( + "All parameters in checkpoint block must be DistributedParameter." + ) + + storage_type = storage_type_cuda(param.storage_type()) + kw_name = _get_param_kw(param) + + if kw_name not in self._storage_info: + if self._mode == "PIPE" and param._tp_mode: + zero_comm = config["pp_tp_zero_comm"] + elif self._mode != "PIPE" and param._tp_mode: + zero_comm = config["tp_zero_comm"] + elif self._mode == "PIPE" and not param._tp_mode: + zero_comm = config["pp_zero_comm"] + else: + zero_comm = config["zero_comm"] + + self._storage_info[kw_name] = { + "total": 0, + "storage_type": storage_type, + "requires_grad": param.requires_grad, + "group": param.group, + "zero_comm": zero_comm, + } + + param_shape = param._original_shape + + self._storage_info[kw_name]["total"] = round_up( + self._storage_info[kw_name]["total"] + param_shape.numel(), + 512 // param.element_size(), + # 512 bytes aligned + ) + + offsets = {} + # intialize storage buffers + for kw, val in self._storage_info.items(): + comm = val["zero_comm"] + world_size = nccl.commCount(comm) + rank = nccl.commRank(comm) + val["world_size"] = world_size + partition_size = ( + round_up(val["total"], val["world_size"]) // val["world_size"] + ) + val["partition_size"] = partition_size + val["begin"] = rank * partition_size + val["end"] = (rank + 1) * partition_size + offsets[kw] = 0 + + storage_type = val["storage_type"] + + storage_param_buffer = storage_type(partition_size) + + dtype = storage_param_buffer.dtype + device = storage_param_buffer.device + + # bind storage to buffer tensor + storage_param = torch.nn.Parameter( + torch.tensor([], dtype=dtype, device=device).set_(storage_param_buffer) + ) + if val["requires_grad"]: + storage_param.requires_grad_(True) + else: + storage_param.requires_grad_(False) + + self._storage_params[kw] = storage_param + + # initialize parameters in module + for name, param in ordered_parameters: + param_shape = param._original_shape + kw_name = _get_param_kw(param) + + param_st = offsets[kw_name] + offsets[kw_name] += param_shape.numel() + param_end = offsets[kw_name] + offsets[kw_name] = round_up(offsets[kw_name], 512 // param.element_size()) + + self._param_info.append( + { + "parameter": param, + "name": name, + "offset": param_st, + "size": param_shape.numel(), + "shape": param_shape, + "kw_name": kw_name, + } + ) + + # copy values to buffer for normal parameter + storage_st = self._storage_info[kw_name]["begin"] + storage_end = self._storage_info[kw_name]["end"] + + # make parameter contiguous in storage + with torch.no_grad(): + contiguous_param = OpAllGather.apply(param) + + if not (param_st >= storage_end or param_end <= storage_st): + # copy offset in parameter storage + offset_st = max(storage_st - param_st, 0) + offset_end = min(storage_end - param_st, contiguous_param.numel()) + assert offset_st < offset_end + + # copy to offset in buffer storage + to_offset_st = offset_st + param_st - storage_st + to_offset_end = offset_end + param_st - storage_st + + # copy to buffer + # PyTorch 1.11 changed the API of storage.__getitem__ + d_dtype = self._storage_params[kw_name].dtype + d_device = self._storage_params[kw_name].device + param.data = torch.tensor( + [], dtype=param.dtype, device=param.device + ).set_( + self._storage_params[kw_name].storage(), + to_offset_st, + (to_offset_end - to_offset_st,), + ) + self._param_info[-1]["begin"] = to_offset_st + self._param_info[-1]["end"] = (to_offset_end - to_offset_st,) + setattr(param, "_start_partition", offset_st) + setattr(param, "_end_partition", offset_end) + param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_( + contiguous_param.storage(), offset_st, (offset_end - offset_st,) + )[:] + del contiguous_param + else: + param.data = torch.tensor([], dtype=param.dtype, device=param.device) + setattr(param, "_start_partition", None) + setattr(param, "_end_partition", 0) + # clear parameter data, but keep the dtype and device + setattr(param, "_in_block", True) + + for kw in offsets.keys(): + assert offsets[kw] == self._storage_info[kw]["total"] + + def set_pre_module(self, pre_module): + """Set pre module for current Block.""" + if pre_module is not None: + self._pre_module = pre_module + pre_module._next_module = self + + def pre_module(self): + """Return pre module of current Block.""" + return self._pre_module if not self._is_first_layer else None + + def next_module(self): + """Return next module of current Block.""" + return self._next_module if not self._is_last_layer else None + + def release_next_module(self, flag): + """Release next module of current Block.""" + if self.next_module() is not None: + self.next_module().release(flag) + + def release(self, flag): + """Release cuurent block ctx.""" + if self._need_release and self._backward_block_ctx is not None: + self._backward_block_ctx.exit(flag, True) + config["load_stream"].record_event(config["load_event"]) + + def pre_hook(self, *args): + """Hook function before forward.""" + grad_tensors = [] + grad_index = [] + arg_list = list(args) + for i, arg in enumerate(args): + if arg is not None and isinstance(arg, torch.Tensor) and arg.requires_grad: + grad_tensors.append(arg) + grad_index.append(i) + grad_tensors = tuple(grad_tensors) + + pre_out = hook_func.PreHookFunc.apply(self, *grad_tensors) + for i in range(len(grad_index)): + arg_list[grad_index[i]] = pre_out[i] + + if self._mode != "PIPE" and len(grad_tensors) == 0: + self.all_param_no_grad = True + for param in self._param_info: + if param["parameter"].requires_grad: + self.all_param_no_grad = False + break + self.all_input_no_grad = True + else: + self.all_input_no_grad = False + return arg_list + + def post_hook(self, out): + """Hook function after forward.""" + tuple_out = (out,) if isinstance(out, torch.Tensor) else out + post_out = hook_func.PostHookFunc.apply(self, *tuple_out) + if isinstance(out, torch.Tensor) and isinstance(post_out, tuple): + return post_out[0] + post_out = tuple(post_out) + return post_out + + def forward(self, *args, **kwargs): + signature = inspect.signature(self._module.forward) + bound_args = signature.bind(*args, **kwargs) + args = bound_args.args + arg_list = self.pre_hook(*args) + + + if self.all_input_no_grad and not self.all_param_no_grad: + placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled()) + return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list) + + if self._use_checkpoint: + out = checkpoint( + self._module, *arg_list, use_reentrant=not self.all_input_no_grad + ) + else: + out = self._module(*arg_list) + + return self.post_hook(out) + + def __getattr__(self, name: str): + if name == "_module": + return self._module + return getattr(self._module, name) + + def __setattr__(self, name, value): + object.__setattr__(self, name, value) + + def __getattribute__(self, name: str): + if name == "_parameters": + return self._module._parameters + return super().__getattribute__(name) + + def __delattr__(self, name): + object.__delattr__(self, name) + + def _save_to_state_dict(self, destination, prefix, keep_vars): + raise RuntimeError("._save_to_state_dict() of Block should not be called") + + def state_dict(self, destination=None, prefix="", keep_vars=False): + # gather here + with torch.no_grad(): + with ZeroContext(self): + return self._module.state_dict( + destination=destination, prefix=prefix, keep_vars=keep_vars + ) + + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + all_keys = [] + for it in self._param_info: + key = prefix + it["name"] + all_keys.append(key) + if key in state_dict: + # load here + input_param = state_dict[key] + param = it["parameter"] + tp_mode = param._tp_mode + if input_param.__class__.__name__ == "DistributedTensorWrapper": + input_param = input_param.broadcast() + + verify_shape = torch.Size( + it["shape"] if not tp_mode else param._tp_original_shape + ) + if input_param.shape != verify_shape: + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format( + key, input_param.shape, verify_shape + ) + ) + continue + + param_st = it["offset"] + param_end = it["offset"] + it["size"] + kw_name = it["kw_name"] + + # not in this partition + storage_st = self._storage_info[kw_name]["begin"] + storage_end = self._storage_info[kw_name]["end"] + if param_st >= storage_end: + continue + if param_end <= storage_st: + continue + + # copy to buffer + verify_size = verify_shape.numel() + assert input_param.numel() == verify_size + + contiguous_param = ( + input_param.to(it["parameter"].dtype).cuda().contiguous() + ) + + tp_split_dim = param._tp_split_dim + if tp_mode and tp_split_dim >= 0: + contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim) + + offset_st = max(storage_st - param_st, 0) + offset_end = min(storage_end - param_st, contiguous_param.numel()) + assert offset_st < offset_end + + to_offset_st = offset_st + param_st - storage_st + to_offset_end = offset_end + param_st - storage_st + + # copy to buffer + # PyTorch 1.11 changed the API of storage.__getitem__ + d_dtype = self._storage_params[kw_name].dtype + d_device = self._storage_params[kw_name].device + torch.tensor([], dtype=d_dtype, device=d_device).set_( + self._storage_params[kw_name].storage(), + to_offset_st, + (to_offset_end - to_offset_st,), + )[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_( + contiguous_param.storage(), offset_st, (offset_end - offset_st,) + )[ + : + ] + del contiguous_param + elif strict: + missing_keys.append(key) + + for name, param in self.named_parameters(): + if isinstance(param, DistributedParameter) and not param._in_block: + key = prefix + name + all_keys.append(key) + if key in state_dict: + input_param = state_dict[key] + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if ( + not is_param_lazy + and len(param.shape) == 0 + and len(input_param.shape) == 1 + ): + input_param = input_param[0] + + if ( + not is_param_lazy + and not isinstance(param, DistributedParameter) + and input_param.shape != param.shape + ): + # local shape should match the one in checkpoint + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format( + key, input_param.shape, param.shape + ) + ) + continue + if ( + not is_param_lazy + and isinstance(param, DistributedParameter) + and input_param.shape != param._original_shape + ): + error_msgs.append( + "size mismatch for {}: copying a param with shape {} from checkpoint, " + "the shape in current model is {}.".format( + key, input_param.shape, param.shape + ) + ) + try: + with torch.no_grad(): + param._copy_data(input_param) + except Exception as ex: + error_msgs.append( + 'While copying the parameter named "{}", ' + "whose dimensions in the model are {} and " + "whose dimensions in the checkpoint are {}, " + "an exception occurred : {}.".format( + key, param.size(), input_param.size(), ex.args + ) + ) + elif strict: + missing_keys.append(key) + + if strict: + all_keys = set(all_keys) + for key in state_dict.keys(): + if key.startswith(prefix) and key not in all_keys: + unexpected_keys.append(key) + + def grouped_parameters(self): + """ + Yield group params in storage params. + """ + ret = {} + for kw, val in self._storage_info.items(): + if val["group"] not in ret: + ret[val["group"]] = [] + ret[val["group"]].append(self._storage_params[kw]) + for kw, val in ret.items(): + yield kw, val + + def init_parameters(self): + """ + Initialize distributed parameters in this block. + """ + for it in self._param_info: + param = it["parameter"] + if ( + isinstance(param, DistributedParameter) + and param._init_method is not None + ): + # initialzie here + tmp_tensor = torch.empty( + param._tp_original_shape, device=param.device, dtype=param.dtype + ) + param._init_method(tmp_tensor) + param_st = it["offset"] + param_end = it["offset"] + it["size"] + kw_name = it["kw_name"] + + # not in this partition + storage_st = self._storage_info[kw_name]["begin"] + storage_end = self._storage_info[kw_name]["end"] + if param_st >= storage_end: + continue + if param_end <= storage_st: + continue + + if param._tp_mode and param._tp_split_dim >= 0: + tmp_tensor = tp_split_tensor(tmp_tensor, param._tp_split_dim) + # copy to buffer + assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel() + + offset_st = max(storage_st - param_st, 0) + offset_end = min(storage_end - param_st, tmp_tensor.numel()) + assert offset_st < offset_end + + # copy to buffer + # PyTorch 1.11 changed the API of storage.__getitem__ + d_dtype = self._storage_params[kw_name].dtype + d_device = self._storage_params[kw_name].device + param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_( + tmp_tensor.storage(), offset_st, (offset_end - offset_st,) + )[:] + del tmp_tensor + + def _named_members(self, get_members_fn, prefix="", recurse=True, **kwargs): + r"""Helper method for yielding various names + members of modules.""" + + # compitibity with torch 2.0 + if ( + "remove_duplicate" + in inspect.signature(torch.nn.Module._named_members).parameters + and "remove_duplicate" not in kwargs + ): + kwargs["remove_duplicate"] = True + return self._module._named_members(get_members_fn, prefix, recurse, **kwargs) + + def named_modules(self, memo=None, prefix: str = "", remove_duplicate: bool = True): + r"""Returns an iterator over all modules in the network, yielding + both the name of the module as well as the module itself. + + Args: + memo: a memo to store the set of modules already added to the result + prefix: a prefix that will be added to the name of the module + remove_duplicate: whether to remove the duplicated module instances in the result + or not + + Yields: + (string, Module): Tuple of name and module + + Note: + Duplicate modules are returned only once. In the following + example, ``l`` will be returned only once. + + Example:: + + >>> l = nn.Linear(2, 2) + >>> net = nn.Sequential(l, l) + >>> for idx, m in enumerate(net.named_modules()): + print(idx, '->', m) + + 0 -> ('', Sequential( + (0): Linear(in_features=2, out_features=2, bias=True) + (1): Linear(in_features=2, out_features=2, bias=True) + )) + 1 -> ('0', Linear(in_features=2, out_features=2, bias=True)) + + """ + + if memo is None: + memo = set() + if self not in memo: + if remove_duplicate: + memo.add(self) + yield prefix, self + for name, module in self._module._modules.items(): + if module is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + for m in module.named_modules(memo, submodule_prefix, remove_duplicate): + yield m + + def named_children(self): + return self._module.named_children() + + def train(self, mode: bool = True): + self._module.train(mode) + + def eval(self): + self._module.eval() + + def __repr__(self): + return self._module.__repr__() + + +def _block_wrapper(module, module_dict: dict, mode="BLOCK"): + if not isinstance(module, Block): + in_block = id(module) in module_dict + new_module = Block(module, initialized=in_block, mode=mode) + if in_block: + new_module.reference(module_dict[id(module)]) + else: + module_dict[id(module)] = new_module + else: + if mode == "PIPE" and module._mode != "PIPE": + assert ( + False + ), 'You must be set mode="PIPE" in bmt.Block when use PipelineTransformerBlockList!' + if id(module._module) in module_dict: + assert False, "Duplicate bmt.Block not supported in same block list!" + else: + new_module = module + module_dict[id(module._module)] = new_module + return new_module + + +class TransformerBlockList(torch.nn.Module): + r""" + TransformerBlockList is a list of bmt.Block. + + This is designed to reduce the communication overhead by overlapping the computation and reduce_scatter operation during backward pass. + + It is similar to `torch.nn.ModuleList` but with the difference when calling .forward() and .backward(). + + Example: + >>> module_list = [ ... ] + >>> normal_module_list = torch.nn.ModuleList(module_list) + >>> transformer_module_list = TransformerBlockList(module_list) + >>> # Calling normal module list + >>> for layer in normal_module_list: + >>> hidden_state = layer.forward(hidden_state, ...) + >>> # Calling transformer module list + >>> hidden_state = transformer_module_list(hidden_state, ...) + + """ + + _modules: Dict[str, Block] + + def __init__(self, modules: Iterable[Block], num_hidden=1) -> None: + super().__init__() + + self._modules = {} + pre_module = None + module_dict = {} + module_dict = {} + for i, module in enumerate(modules): + module = _block_wrapper(module, module_dict) + module.set_pre_module(pre_module) + pre_module = module + module._is_first_layer = False + module._is_last_layer = False + self._modules[str(i)] = module + self.add_module(str(i), module) + + self._modules[str(0)]._is_first_layer = True + self._modules[str(len(modules) - 1)]._is_last_layer = True + + self.num_hidden = num_hidden + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[Block]: + return iter(self._modules.values()) + + def __getitem__(self, index: Union[int, str]) -> Block: + return self._modules[str(index)] + + def forward(self, *args, return_hidden_states=False): + self.return_hidden_states = return_hidden_states + hidden_states = [] + for i in range(len(self)): + if return_hidden_states: + for hidden_state in args[: self.num_hidden]: + hidden_states.append(hidden_state) + outputs = self._modules[str(i)]._call_impl(*args) + if not isinstance(outputs, tuple): + outputs = (outputs,) + args = outputs + args[self.num_hidden :] + + if return_hidden_states: + hidden_states = [ + torch.stack(hidden_states[i :: self.num_hidden], dim=0) + for i in range(self.num_hidden) + ] + + if return_hidden_states: + return outputs + tuple(hidden_states) + else: + return ( + tuple(outputs[: self.num_hidden]) if self.num_hidden > 1 else outputs[0] + ) diff --git a/examples/BMTrain/bmtrain/debug.py b/examples/BMTrain/bmtrain/debug.py new file mode 100644 index 00000000..de392623 --- /dev/null +++ b/examples/BMTrain/bmtrain/debug.py @@ -0,0 +1,34 @@ +import torch + +DEBUG_VARS = {} + +def clear(key=None): + global DEBUG_VARS + if key is None: + DEBUG_VARS = {} + else: + DEBUG_VARS.pop(key, None) + +def set(key, value): + global DEBUG_VARS + if torch.is_tensor(value): + value = value.detach().cpu() + DEBUG_VARS[key] = value + +def get(key, default=None): + global DEBUG_VARS + if key in DEBUG_VARS: + return DEBUG_VARS[key] + return default + +def append(key, value): + global DEBUG_VARS + if key not in DEBUG_VARS: + DEBUG_VARS[key] = [] + DEBUG_VARS[key].append(value) + +def extend(key, value): + global DEBUG_VARS + if key not in DEBUG_VARS: + DEBUG_VARS[key] = [] + DEBUG_VARS[key].extend(value) \ No newline at end of file diff --git a/examples/BMTrain/bmtrain/distributed/__init__.py b/examples/BMTrain/bmtrain/distributed/__init__.py new file mode 100644 index 00000000..84a4adf8 --- /dev/null +++ b/examples/BMTrain/bmtrain/distributed/__init__.py @@ -0,0 +1 @@ +from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations, reduce_scatter diff --git a/examples/BMTrain/bmtrain/distributed/ops.py b/examples/BMTrain/bmtrain/distributed/ops.py new file mode 100644 index 00000000..d1b489e2 --- /dev/null +++ b/examples/BMTrain/bmtrain/distributed/ops.py @@ -0,0 +1,223 @@ +import torch +from ..global_var import config +from ..nccl import allGather as ncclAllGather, recv +from ..nccl import allReduce as ncclAllReduce +from ..nccl import broadcast as ncclBroadcast +from ..nccl import reduceScatter as ncclReduceScatter +from ..nccl import send as ncclSend +from ..nccl import recv as ncclRecv +from ..nccl import commCount,commRank,NCCLCommunicator +DTYPE_LIST = [ + torch.float64, + torch.float32, + torch.float16, + torch.int64, + torch.int32, + torch.int16, + torch.int8, + torch.bfloat16, + torch.bool +] +def send_activations(hidden_state, next_rank, comm): + send_meta(hidden_state, next_rank, comm) + ncclSend(hidden_state.storage(), next_rank, comm) + +def recv_activations(prev_rank, comm): + dtype, shape = recv_meta(prev_rank, comm) + hidden_state = torch.empty(shape, dtype=dtype, device="cuda") + ncclRecv(hidden_state.storage(), prev_rank, comm) + return hidden_state + +def send_meta(x, next_rank, comm): + meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) + meta_data[0] = len(x.size()) + meta_data[1] = DTYPE_LIST.index(x.dtype) + meta_data[2:len(x.size())+2] = torch.tensor(x.size(), device="cuda", dtype=torch.int) + meta_data = meta_data.contiguous() + ncclSend(meta_data.storage(), next_rank, comm) + +def recv_meta(prev_rank, comm): + meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int) + ncclRecv(meta_data.storage(), prev_rank, comm) + n_dims = meta_data[0].item() + dtype = DTYPE_LIST[meta_data[1].item()] + shape = meta_data[2:n_dims+2].tolist() + return dtype,shape + +class OpBroadcast(torch.autograd.Function): + + @staticmethod + def forward(ctx, src, root, comm = None): + if comm is None: + comm = config["comm"] + ctx.comm = comm + outputs = torch.empty_like(src, dtype = src.dtype, device = src.device) + ncclBroadcast(src.storage(), outputs.storage(), root, comm) + return outputs + + @staticmethod + def backward(ctx, grad_output): + res = all_reduce(grad_output, "sum", ctx.comm) + return res, None, None + +def broadcast(src, root, comm=None): + if not config["initialized"]: + raise RuntimeError("BMTrain is not initialized") + return OpBroadcast.apply(src, root, comm) + +class OpAllGather(torch.autograd.Function): + + @staticmethod + def forward(ctx, input : torch.Tensor, comm = None): + if comm is None: + comm = config["comm"] + world_size = commCount(comm) + if not input.is_contiguous(): + input = input.contiguous() + if input.storage_offset() != 0 or input.storage().size() != input.numel(): + input = input.clone() + output = torch.empty( (world_size,) + input.size(), dtype=input.dtype, device=input.device) + ctx.comm = comm + ncclAllGather( + input.storage(), + output.storage(), + comm + ) + return output + + @staticmethod + def backward(ctx, grad_output): + return grad_output[commRank(ctx.comm)], None + +def all_gather(x : torch.Tensor, comm = None): + """Gathers the input tensor from all processes. + + Args: + x (torch.Tensor): The input tensor of shape (...). + + Returns: + torch.Tensor: The gathered tensor of shape (world_size, ...). + """ + if not config["initialized"]: + raise RuntimeError("BMTrain is not initialized") + + assert x.is_cuda + return OpAllGather.apply(x, comm) + +class OpReduceScatter(torch.autograd.Function): + + @staticmethod + def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None): + if comm is None: + comm = config["comm"] + ctx.comm = comm + rank = commRank(comm) + assert input.shape[0] % commCount(comm) == 0, "The dimension 0 must be divisible by the number of communication processes" + if not input.is_contiguous(): + input = input.contiguous() + if input.storage_offset() != 0 or input.storage().size() != input.numel(): + input = input.clone() + output_shape = (input.shape[0] // commCount(comm), *input.shape[1:]) + output = torch.empty( output_shape, dtype=input.dtype, device=input.device ) + ncclReduceScatter( + input.storage(), + output.storage(), + op, + comm + ) + ctx.op = op + if op in ["sum", "avg"]: + pass + elif op in ["max", "min"]: + ctx.save_for_backward( output != input[rank * input.shape[0]:(rank + 1) * input.shape[0]] ) + else: + ctx.save_for_backward( output / input[rank * input.shape[0]:(rank + 1) * input.shape[0]] ) + return output + + @staticmethod + def backward(ctx, grad_output): + with torch.no_grad(): + grad_output = OpAllGather.apply(grad_output, ctx.comm).flatten(0,1) + if ctx.op in ["max", "min", "prod"]: + raise NotImplementedError("max min operation now do not support backward") + else: + if ctx.op == "avg": + grad_output /= commCount(ctx.comm) + return grad_output, None, None + + +def reduce_scatter(x : torch.Tensor, op : str = "sum", comm = None): + """Reduces the input tensor from all processes. + + Args: + x (torch.Tensor): The input tensor of shape (world_size, ...). + op (str): The reduction operation, one of "sum", "avg", "max", "min", "prod". Default: "sum". + + Returns: + torch.Tensor: The reduced tensor of shape (...). + + """ + if not config["initialized"]: + raise RuntimeError("BMTrain is not initialized") + + assert x.is_cuda + return OpReduceScatter.apply(x, op, comm) + +class OpAllReduce(torch.autograd.Function): + @staticmethod + def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None): + if comm is None: + comm = config["comm"] + ctx.comm = comm + if not input.is_contiguous(): + input = input.contiguous() + if input.storage_offset() != 0 or input.storage().size() != input.numel(): + input = input.clone() + output = torch.empty( input.size(), dtype=input.dtype, device=input.device) + + ncclAllReduce( + input.storage(), + output.storage(), + op, + comm + ) + ctx.op = op + + if op in ["sum", "avg"]: + pass + elif op in ["max", "min"]: + ctx.save_for_backward( input != output ) + else: + ctx.save_for_backward( output / input ) + return output + + @staticmethod + def backward(ctx, grad_output): + if ctx.op == "sum": + return grad_output, None, None + elif ctx.op == "avg": + return grad_output / commCount(ctx.comm), None, None + elif ctx.op in ["max", "min"]: + return torch.masked_fill(grad_output, ctx.saved_tensors[0], 0), None, None + else: + return grad_output * ctx.saved_tensors[0], None, None + +def all_reduce(x : torch.Tensor, op : str = "sum", comm = None): + """Reduces the input tensor from all processes. + + Args: + x (torch.Tensor): The input tensor of shape (...). + op (str): The reduction operation, one of "sum", "avg", "max", "min", "prod". Default: "sum". + + Returns: + torch.Tensor: The reduced tensor of shape (...). + + """ + if not config["initialized"]: + raise RuntimeError("BMTrain is not initialized") + + assert x.is_cuda + return OpAllReduce.apply(x, op, comm) + + + diff --git a/examples/BMTrain/bmtrain/global_var.py b/examples/BMTrain/bmtrain/global_var.py new file mode 100644 index 00000000..137fa9cd --- /dev/null +++ b/examples/BMTrain/bmtrain/global_var.py @@ -0,0 +1,35 @@ +import torch +from typing_extensions import TypedDict +class ConfigMap(TypedDict): + rank : int + local_rank : int + world_size : int + local_size : int + zero_level : int + pipe_size : int + num_micro_batches : int + calc_stream : torch.cuda.Stream + load_stream : torch.cuda.Stream + load_event : torch.cuda.Event + barrier_stream : torch.cuda.Stream + loss_scale_factor : float + loss_scale_steps : int + topology : 'topology' + gradient_inspect : bool + initialized : bool + + comm : 'NCCLCommunicator' + +config = ConfigMap(rank=0, local_rank=0, world_size=1, initialized=False) + +def rank(): + """ + Returns the global rank of the current process. (0 ~ world_size-1) + """ + return config['rank'] + +def world_size(): + """ + Returns the total number of workers across all nodes. + """ + return config['world_size'] diff --git a/examples/BMTrain/bmtrain/hook_func.py b/examples/BMTrain/bmtrain/hook_func.py new file mode 100644 index 00000000..577331a2 --- /dev/null +++ b/examples/BMTrain/bmtrain/hook_func.py @@ -0,0 +1,121 @@ +import torch +from .global_var import config +from .zero_context import ZeroContext + + +def zero_pre_forward(module, inputs): + """Helper function for using ZeroContext to gather parmas before forward.""" + enter = True + pipe = False + if module._mode == "PIPE": + enter = module._micro_idx == 0 + pipe = True + if enter: + zero_level = module._zero_level + forward_flag = 1 if zero_level == 2 else 0 + if zero_level == 2 and not module._need_release: + forward_flag = 2 # repeating forward in same layer + if module.all_param_no_grad: # only forward + forward_flag = 0 + module._forward_block_ctx = ZeroContext(module, module._layer_dict, pipe=pipe) + module._forward_block_ctx.enter(forward_flag) + + +def zero_post_forward(module, inputs, outputs): + """Helper function for module _forwar_block_ctx weather exits after forward.""" + forward_flag = 1 if module._zero_level == 2 else 0 + if module.all_param_no_grad: + forward_flag = 0 + exit = True + if module._mode == "PIPE": + exit = module._micro_idx == config["micros"] - 1 + + if exit: + module._forward_block_ctx.exit(forward_flag) + + +def zero_pre_backward(module, grad_outputs): + """Helper function for using ZeroContext to init grad buffer before backward.""" + backward_flag = 2 if module._zero_level == 2 else 0 + if module._mode != "PIPE": + module._backward_block_ctx = ZeroContext(module, module._layer_dict) + module._backward_block_ctx.enter(backward_flag, True) + module.release_next_module(backward_flag) + else: + if module._micro_idx == config["micros"] - 1: + module._backward_block_ctx = ZeroContext( + module, module._layer_dict, pipe=True + ) + module._backward_block_ctx.enter(backward_flag, True) + + +def zero_post_backward(module, grad_inputs, grad_outputs): + """Helper function for module weather release after backward.""" + backward_flag = 2 if module._zero_level == 2 else 0 + if module._mode != "PIPE": + if module._is_first_layer: + module.release(backward_flag) + else: + if module._micro_idx == 0: + module.release(backward_flag) + module._micro_idx -= 1 + + +class OneStepNoGradFunc(torch.autograd.Function): + """ + Requires_grad = False for all inputs. + """ + + @staticmethod + def forward(ctx, module, placeholder, *x): + ctx.x = x + ctx.module = module + ctx.rng_state = torch.cuda.get_rng_state() + + with torch.no_grad(): + out = module._module(*x) + zero_post_forward(module, None, out) + if not isinstance(out, torch.Tensor): + return tuple(out) + return out + + @staticmethod + def backward(ctx, grads): + zero_pre_backward(ctx.module, grads) + with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True): + torch.cuda.set_rng_state(ctx.rng_state) + x = ctx.x + with torch.enable_grad(): + out = ctx.module._module(*x) + torch.autograd.backward(out, grads) + zero_post_backward(ctx.module, grads, None) + grads = [] + for _ in x: + grads.append(None) + return None, None, *grads + + +class PreHookFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, module, *x): + ctx.module = module + zero_pre_forward(module, x) + return x + + @staticmethod + def backward(ctx, *grads): + zero_post_backward(ctx.module, grads, None) + return None, *grads + + +class PostHookFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, module, *out): + ctx.module = module + zero_post_forward(module, None, out) + return out + + @staticmethod + def backward(ctx, *grads): + zero_pre_backward(ctx.module, grads) + return None, *grads diff --git a/examples/BMTrain/bmtrain/init.py b/examples/BMTrain/bmtrain/init.py new file mode 100644 index 00000000..601d617e --- /dev/null +++ b/examples/BMTrain/bmtrain/init.py @@ -0,0 +1,258 @@ +import datetime +import torch +import random +import torch.distributed as dist +import os +from .utils import print_dict +import ctypes +from .global_var import config + +from . import nccl +from .synchronize import synchronize + + +def init_distributed( + init_method: str = "env://", + seed: int = 0, + pipe_size: int = -1, + num_micro_batches: int = None, + tp_size: int = 1, +): + """Initialize distributed training. + This function will initialize the distributed training, set the random seed and global configurations. + It must be called before any other distributed functions. + + Args: + seed (int): The random seed. + pipe_size (int) : pipe_size means that all processes will be divided into pipe_size groups + num_micro_batches (int) : means that the input batchs will be divided into num_micro_batches small batches. used in pipeline mode. + tp_size (int) : tp_size means the size of each of tensor parallel group + + **init_distributed** reads the following environment variables: + + * `WORLD_SIZE`: The total number gpus in the distributed training. + * `RANK`: The global rank of the current gpu. From 0 to `WORLD_SIZE - 1`. + * `MASTER_ADDR`: The address of the master node. + * `MASTER_PORT`: The port of the master node. + * `LOCAL_RANK`: The local rank of the current gpu. + + Normally, all the environments variables above are setted by the pytorch distributed launcher. + + **Note**: Do not use any functions in torch.distributed package including `torch.distributed.init_process_group` . + + **Note**: If your training script is stuck here , it means some of your distributed workers are not connected to the master node. + + """ + torch.backends.cudnn.enabled = False + + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + rank = int(os.environ.get("RANK", "0")) + world_size = int(os.environ.get("WORLD_SIZE", "1")) + local_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1")) + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = "10010" + addr = os.environ["MASTER_ADDR"] + port = os.environ["MASTER_PORT"] + master = addr + ":" + port + timeout = datetime.timedelta(seconds=1800) + rendezvous_iterator = dist.rendezvous( + init_method, rank, world_size, timeout=timeout + ) + + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + store = dist.PrefixStore("bmtrain", store) + torch.cuda.set_device(local_rank) + config["initialized"] = True + config["pipe_size"] = pipe_size if pipe_size > 0 else 1 + config["pipe_enabled"] = pipe_size > 0 + config["local_rank"] = local_rank + config["local_size"] = local_size + config["rank"] = rank + config["world_size"] = world_size + config["calc_stream"] = torch.cuda.current_stream() + config["load_stream"] = torch.cuda.Stream(priority=-1) + config["tp_comm_stream"] = torch.cuda.Stream(priority=-1) + config["pp_comm_stream"] = torch.cuda.Stream(priority=-1) + config["barrier_stream"] = torch.cuda.Stream() + config["load_event"] = torch.cuda.Event() + config["tp_size"] = tp_size if tp_size > 0 else 1 + config["topology"] = topology(config) + config["zero_rank"] = config["topology"].get_group_rank("zero") + config["tp_rank"] = config["topology"].get_group_rank("tp") + config["tp_zero_rank"] = config["topology"].get_group_rank("tp_zero") + config["save_param_to_cpu"] = True + cpus_this_worker = None + + all_available_cpus = sorted(list(os.sched_getaffinity(0))) + + cpus_per_worker = len(all_available_cpus) // local_size + + if cpus_per_worker < 1: + cpus_this_worker = all_available_cpus + torch.set_num_threads(1) + else: + cpus_this_worker = all_available_cpus[ + local_rank * cpus_per_worker : (local_rank + 1) * cpus_per_worker + ] + os.sched_setaffinity(0, cpus_this_worker) + torch.set_num_threads(len(cpus_this_worker)) + + torch.manual_seed(seed) + random.seed(seed) + try: + import numpy as np + + np.random.seed(seed) + except ModuleNotFoundError: + pass + + if rank == 0: + unique_id: bytes = nccl.getUniqueId() + store.set("BMTRAIN_UNIQUE_ID", unique_id.hex()) + + unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode()) + config["comm"] = nccl.commInitRank(unique_id, world_size, rank) + topo = config["topology"] + + if config["pipe_enabled"]: + config["micros"] = ( + num_micro_batches if num_micro_batches else config["pipe_size"] + ) + if topo.stage_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode()) + config["pipe_comm"] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id) + + if topo.pp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex()) + unique_id = bytes.fromhex( + store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode() + ) + config["pp_zero_comm"] = nccl.commInitRank( + unique_id, world_size // config["pipe_size"], topo.pp_zero_id + ) + + if config["tp_size"] > 1: + if topo.tp_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex()) + unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode()) + config["tp_comm"] = nccl.commInitRank(unique_id, tp_size, topo.tp_id) + + if topo.tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex()) + unique_id = bytes.fromhex( + store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode() + ) + config["tp_zero_comm"] = nccl.commInitRank( + unique_id, world_size // config["tp_size"], topo.tp_zero_id + ) + + if config["pipe_size"] > 1 and config["tp_size"] > 1: + if topo.pp_tp_zero_id == 0: + unique_id = nccl.getUniqueId() + store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex()) + unique_id = bytes.fromhex( + store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode() + ) + config["pp_tp_zero_comm"] = nccl.commInitRank( + unique_id, + world_size // (config["pipe_size"] * config["tp_size"]), + topo.pp_tp_zero_id, + ) + + config["zero_comm"] = config["comm"] + + for i in range(world_size): + if i == rank: + print_dict( + "Initialization", + { + "rank": rank, + "local_rank": local_rank, + "world_size": world_size, + "local_size": local_size, + "master": master, + "device": torch.cuda.current_device(), + "cpus": cpus_this_worker, + }, + ) + synchronize() + + +class topology: + """A helper class to keep parallel information when using different parallel methods together.""" + + def __init__(self, config): + # pipe_idx is the idx of the pipeline in the group + self.rank = config["rank"] + pp_size = config["pipe_size"] + tp_size = config["tp_size"] + world_size = config["world_size"] + assert ( + world_size % (pp_size * tp_size) == 0 + ), "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size" + + dp_size = world_size // (pp_size * tp_size) + config["tp_zero_size"] = dp_size + config["zero_size"] = world_size // pp_size + self.stages = config["pipe_size"] + + stage_size = world_size // pp_size + for i in range(world_size): + self.pipe_idx = self.rank % stage_size + self.stage_id = self.rank // stage_size + self.tp_id = self.rank % tp_size + self.tp_idx = self.rank // tp_size + # pp->zero + self.pp_zero_idx = self.stage_id + self.pp_zero_id = self.pipe_idx + # tp->zero + self.tp_zero_idx = self.tp_id + self.tp_zero_id = self.tp_idx + # pp->tp->zero + self.pp_tp_zero_idx = self.stage_id * tp_size + self.tp_id + self.pp_tp_zero_id = self.pipe_idx // tp_size + # only zero + self.zero_idx = 0 + self.zero_id = self.rank + + def get_group_id(self, group_name): + """Get group id of different parallel group. + + Args: + group_name (str): must be one of "pipe", "zero", "tp_zero" or "tp". + """ + if group_name == "pipe": + return self.pipe_idx + elif group_name == "zero": + return self.zero_idx + elif group_name == "tp_zero": + return self.tp_zero_idx + elif group_name == "tp": + return self.tp_idx + + def get_group_rank(self, group_name): + """Get group rank of different parallel group. + + Args: + group_name (str): must be one of "pipe", "zero", "tp_zero" or "tp". + """ + if group_name == "pipe": + return self.stage_id + elif group_name == "zero": + return self.zero_id + elif group_name == "tp_zero": + return self.tp_zero_id + elif group_name == "tp": + return self.tp_id + + +def is_initialized() -> bool: + return config["initialized"] diff --git a/examples/BMTrain/bmtrain/inspect/__init__.py b/examples/BMTrain/bmtrain/inspect/__init__.py new file mode 100644 index 00000000..2b6d2d26 --- /dev/null +++ b/examples/BMTrain/bmtrain/inspect/__init__.py @@ -0,0 +1,3 @@ +from .format import format_summary +from .model import inspect_model +from .tensor import inspect_tensor, record_tensor \ No newline at end of file diff --git a/examples/BMTrain/bmtrain/inspect/format.py b/examples/BMTrain/bmtrain/inspect/format.py new file mode 100644 index 00000000..79b3a1e5 --- /dev/null +++ b/examples/BMTrain/bmtrain/inspect/format.py @@ -0,0 +1,64 @@ +from typing import Any, Dict, List + +def align_str(s : str, align : int, left : bool) -> str: + if left: + return s + " " * (align - len(s)) + else: + return " " * (align - len(s)) + s + +def format_line(strs : List[str], length : List[int]): + ret = "" + for v, l in zip(strs, length): + if len(v) + 1 > l: + v = " " + v[:l - 1] + else: + v = " " + v + ret += align_str(v, l, True) + return ret + +def item_formater(x) -> str: + if isinstance(x, float): + return "{:.4f}".format(x) + else: + return str(x) + +def format_summary(summary : List[Dict[str, Any]]) -> str: + """Format summary to string. + + Args: + summary (List[Dict[str, Any]]): The summary to format. + + Returns: + str: The formatted summary. + + """ + ret = [] + + max_name_len = max([len("name")] + [len(item["name"]) for item in summary]) + 4 + headers = [ + "name", + "shape", + "max", + "min", + "std", + "mean", + "grad_std", + "grad_mean", + ] + headers_length = [ + max_name_len, + 20, + 10, + 10, + 10, + 10, + 10, + 10 + ] + ret.append( format_line(headers, headers_length) ) + for item in summary: + values = [ item_formater(item[name]) for name in headers ] + ret.append( format_line(values, headers_length) ) + return "\n".join(ret) + + \ No newline at end of file diff --git a/examples/BMTrain/bmtrain/inspect/model.py b/examples/BMTrain/bmtrain/inspect/model.py new file mode 100644 index 00000000..fc54f0d6 --- /dev/null +++ b/examples/BMTrain/bmtrain/inspect/model.py @@ -0,0 +1,246 @@ +import torch +from ..store import broadcast_object +from ..pipe_layer import PipelineTransformerBlockList +from ..block_layer import Block +from ..parameter import DistributedParameter +from .. import nccl +from ..global_var import config +import fnmatch + +def _gather_value(value : torch.Tensor, partition_size, origin_size): + global_size = partition_size * config['world_size'] + + storage = value.storage_type()(global_size) + + if value.storage().size() != partition_size: + tmp_buf = torch.zeros(partition_size, dtype=value.dtype, device=value.device) + tmp_buf[:value.numel()] = value[:] + nccl.allGather( + tmp_buf.storage(), + storage, + config['comm'] + ) + else: + nccl.allGather( + value.storage(), + storage, + config['comm'] + ) + + output_tensor = torch.tensor([], dtype=value.dtype, device="cuda") + output_tensor.set_(storage, 0, origin_size) + + return output_tensor + +def inspect_pipeline_transformer_block_list(pipe_model: PipelineTransformerBlockList, param_name : str, _prefix : str = ''): + ret = [] + for name, model in pipe_model._modules.items(): + idx = int(name) + prefix = _prefix + name + '.' + + # fast check + pass_fast_check = False + for param in model._param_info: + abs_name = prefix + param["name"] + if fnmatch.fnmatch(abs_name, param_name): + pass_fast_check = True + break + if not pass_fast_check: + continue + + if idx in pipe_model.layer_ids: + _param_buffer = {} + _grad_buffer = {} + for kw, val in model._storage_info.items(): + storage_type = model._storage_params[kw].storage_type() + + _param_buffer[kw] = storage_type(val["partition_size"] * val['world_size']) + if model._storage_params[kw].grad is not None: + _grad_buffer[kw] = storage_type(val["partition_size"] * val['world_size']) + + nccl.groupStart() + for kw, val in model._storage_info.items(): + nccl.allGather( + model._storage_params[kw].storage(), + _param_buffer[kw], + val["zero_comm"] + ) + if model._storage_params[kw].grad is not None: + nccl.allGather( + model._storage_params[kw].grad.storage(), + _grad_buffer[kw], + val["zero_comm"] + ) + + nccl.groupEnd() + for param in model._param_info: + abs_name = prefix + param["name"] + if fnmatch.fnmatch(abs_name, param_name): + kw_name = param["kw_name"] + dtype = _param_buffer[kw_name].dtype + device = _param_buffer[kw_name].device + offset = param["offset"] + shape = param["shape"] + p = torch.tensor([], dtype=dtype, device=device).set_(_param_buffer[kw_name], offset, shape) + if kw_name in _grad_buffer: + g = torch.tensor([], dtype=dtype, device=device).set_(_grad_buffer[kw_name], offset, shape) + info = { + "name": abs_name, + "shape": tuple(shape), + "std": p.std().cpu().item(), + "mean": p.mean().cpu().item(), + "grad_std": g.std().cpu().item(), + "grad_mean": g.mean().cpu().item(), + "max": p.max().cpu().item(), + "min": p.min().cpu().item(), + } + else: + info = { + "name": abs_name, + "shape": tuple(shape), + "std": p.std().cpu().item(), + "mean": p.mean().cpu().item(), + "grad_std": 0., + "grad_mean": 0., + "max": p.max().cpu().item(), + "min": p.min().cpu().item(), + } + broadcast_object(info, config["pipe_comm"], pipe_model.get_stage_by_layer_id(idx)) + ret.append(info) + else: + for param in model._param_info: + abs_name = prefix + param["name"] + if fnmatch.fnmatch(abs_name, param_name): + info = broadcast_object({}, config["pipe_comm"], pipe_model.get_stage_by_layer_id(idx)) + ret.append(info) + + return ret + + +def inspect_block(model : Block, param_name : str, prefix : str = ''): + # fast check + pass_fast_check = False + for param in model._param_info: + abs_name = prefix + param["name"] + if fnmatch.fnmatch(abs_name, param_name): + pass_fast_check = True + break + if not pass_fast_check: + return [] + + _param_buffer = {} + _grad_buffer = {} + for kw, val in model._storage_info.items(): + storage_type = model._storage_params[kw].storage_type() + + _param_buffer[kw] = storage_type(val["partition_size"] * config['world_size']) + if model._storage_params[kw].grad is not None: + _grad_buffer[kw] = storage_type(val["partition_size"] * config['world_size']) + + nccl.groupStart() + for kw, val in model._storage_info.items(): + nccl.allGather( + model._storage_params[kw].storage(), + _param_buffer[kw], + config["comm"] + ) + if model._storage_params[kw].grad is not None: + nccl.allGather( + model._storage_params[kw].grad.storage(), + _grad_buffer[kw], + config["comm"] + ) + + nccl.groupEnd() + ret = [] + for param in model._param_info: + abs_name = prefix + param["name"] + if fnmatch.fnmatch(abs_name, param_name): + kw_name = param["kw_name"] + dtype = _param_buffer[kw_name].dtype + device = _param_buffer[kw_name].device + offset = param["offset"] + shape = param["shape"] + p = torch.tensor([], dtype=dtype, device=device).set_(_param_buffer[kw_name], offset, shape) + if kw_name in _grad_buffer: + g = torch.tensor([], dtype=dtype, device=device).set_(_grad_buffer[kw_name], offset, shape) + ret.append({ + "name": abs_name, + "shape": tuple(shape), + "std": p.std().cpu().item(), + "mean": p.mean().cpu().item(), + "grad_std": g.std().cpu().item(), + "grad_mean": g.mean().cpu().item(), + "max": p.max().cpu().item(), + "min": p.min().cpu().item(), + }) + else: + ret.append({ + "name": abs_name, + "shape": tuple(shape), + "std": p.std().cpu().item(), + "mean": p.mean().cpu().item(), + "grad_std": 0., + "grad_mean": 0., + "max": p.max().cpu().item(), + "min": p.min().cpu().item(), + }) + return ret + +@torch.no_grad() +def inspect_model(model : torch.nn.Module, param_name : str, prefix : str = ''): + """Inspect the model and return the summary of the parameters. + + Args: + model (torch.nn.Module): The model to be inspected. + param_name (str): The name of the parameter to be inspected. The wildcard '*' can be used to match multiple parameters. + prefix (str): The prefix of the parameter name. + + Returns: + list: The summary of the parameters. + + Example: + >>> result_linear = bmt.inspect.inspect_model(model, "*.linear*") + >>> result_layernorm = bmt.inspect.inspect_model(model, "*.layernorm*") + >>> text_summray = bmt.inspect.format_summary(result_linear + result_layernorm) + >>> bmt.print_rank(text_summary) + name shape max min std mean grad_std grad_mean + ... + + """ + if isinstance(model, PipelineTransformerBlockList): + return inspect_pipeline_transformer_block_list(model, param_name, prefix) + elif isinstance(model, Block): + return inspect_block(model, param_name, prefix) + else: + ret = [] + for name, param in model._parameters.items(): + if fnmatch.fnmatch(prefix + name, param_name): + if isinstance(param, DistributedParameter): + p = _gather_value(param.data, param.storage().size(), param._original_shape) + else: + p = param + if p is None: + continue + stats = { + 'name': prefix + name, + 'shape': tuple(p.size()), + "std": p.std().cpu().item(), + "mean": p.mean().cpu().item(), + "max": p.max().cpu().item(), + "min": p.min().cpu().item(), + } + if param.grad is not None: + if isinstance(param, DistributedParameter): + g = _gather_value(param.grad.data, param.storage().size(), param._original_shape) + else: + g = param.grad + stats["grad_std"] = g.std().cpu().item() + stats["grad_mean"] = g.mean().cpu().item() + else: + stats["grad_std"] = 0. + stats["grad_mean"] = 0. + ret.append(stats) + for name, module in model._modules.items(): + ret.extend(inspect_model(module, param_name, prefix + name + '.')) + return ret diff --git a/examples/BMTrain/bmtrain/inspect/tensor.py b/examples/BMTrain/bmtrain/inspect/tensor.py new file mode 100644 index 00000000..9d003f82 --- /dev/null +++ b/examples/BMTrain/bmtrain/inspect/tensor.py @@ -0,0 +1,383 @@ +from typing import Optional +import torch +from .. import debug +from .. import nccl +from ..global_var import config +from ..store import broadcast_object +from ..distributed import broadcast +import math + + +class InspectTensor: + """This object is returned by `InspectTensorManager`. + + You can get the tensors recorded by `record_tensor`. + + """ + + def __init__(self): + self.summary = [] + + def _set_summary(self, summary): + self._summary = summary + for item in summary: + item["prefix"] = "" if item["group"] is None else f'{item["group"]}.' + + self.summary = [] + + kw_cnt = {} + i = 0 + while i < len(summary): + item = summary[i] + if item["inside_pipe"] is not None: + before_len = len(self.summary) + + assert item["inside_pipe"]["st"] + pipe_cnt = {} + j = i + while j < len(summary): + item = summary[j] + kw = f'{item["prefix"]}{item["name"]}' + + assert item["inside_pipe"] is not None + stage_id = item["inside_pipe"]["stage_id"] + stages = item["inside_pipe"]["stages"] + st = item["inside_pipe"]["st"] + ed = item["inside_pipe"]["ed"] + + if kw not in pipe_cnt: + pipe_cnt[kw] = 0 + pipe_cnt[kw] += 1 + + j += 1 + if ed: + break + + for stage in range(stages): + if stage_id == stage: + broadcast_object(pipe_cnt, config["pipe_comm"], src=stage) + for k in range(i, j): + item = summary[k] + kw = f'{item["prefix"]}{item["name"]}' + if kw not in kw_cnt: + kw_cnt[kw] = 0 + tensor = torch.cat( + [ + summary[k + m * (j - i)]["tensor"] + for m in range(config["micros"]) + ], + dim=0, + ) + grad = ( + torch.cat( + [ + summary[k + m * (j - i)]["tensor"].grad + for m in range(config["micros"]) + ], + dim=0, + ) + if item["requires_grad"] + and item["tensor"].grad is not None + else None + ) + self.summary.append( + { + "name": item["name"], + "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}', + "group": item["group"], + "min": None, + "max": None, + "mean": None, + "std": None, + "shape": (item["shape"][0] * config["micros"],) + + item["shape"][1:], + "grad_mean": None, + "grad_std": None, + "tensor": tensor, + "grad": grad, + "requires_grad": item["requires_grad"], + "inside_pipe": {"stage_id": stage}, + } + ) + kw_cnt[kw] += 1 + else: + cnt = broadcast_object({}, config["pipe_comm"], src=stage) + for kw, val in cnt.items(): + if kw not in kw_cnt: + kw_cnt[kw] = 0 + for _ in range(val): + self.summary.append( + { + "name": item["name"], + "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}', + "group": None, + "min": None, + "max": None, + "mean": None, + "std": None, + "shape": None, + "grad_mean": None, + "grad_std": None, + "tensor": None, + "grad": None, + "requires_grad": None, + "inside_pipe": {"stage_id": stage}, + } + ) + kw_cnt[kw] += 1 + + after_len = len(self.summary) + with torch.enable_grad(): + for it in self.summary[before_len:after_len]: + if it["tensor"] is not None: + has_grad = it["grad"] is not None + info = { + "group": it["group"], + "shape": it["shape"], + "requires_grad": it["requires_grad"], + "has_grad": has_grad, + } + broadcast_object( + info, + config["pipe_comm"], + src=it["inside_pipe"]["stage_id"], + ) + tensor = it["tensor"] + tensor = broadcast( + tensor, + it["inside_pipe"]["stage_id"], + config["pipe_comm"], + ) + grad = it["grad"] + else: + info = broadcast_object( + {}, + config["pipe_comm"], + src=it["inside_pipe"]["stage_id"], + ) + has_grad = info.pop("has_grad") + it.update(info) + tensor = torch.empty(it["shape"]).cuda().requires_grad_() + tensor = broadcast( + tensor, + it["inside_pipe"]["stage_id"], + config["pipe_comm"], + ) + if has_grad: + grad = torch.empty(it["shape"]).cuda() + tensor = tensor.chunk(stages, dim=0)[stage_id].clone() + it["tensor"] = tensor + if has_grad: + grad = broadcast( + grad, it["inside_pipe"]["stage_id"], config["pipe_comm"] + ) + grad = grad.chunk(stages, dim=0)[stage_id].clone() + tensor.grad = grad + it["shape"] = (it["shape"][0] // config["pipe_size"],) + it[ + "shape" + ][1:] + + i = i + config["micros"] * (j - i) + else: + kw = f'{item["prefix"]}{item["name"]}' + if kw not in kw_cnt: + kw_cnt[kw] = 0 + self.summary.append( + { + "name": item["name"], + "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}', + "group": item["group"], + "min": None, + "max": None, + "mean": None, + "std": None, + "shape": item["shape"], + "grad_mean": None, + "grad_std": None, + "tensor": item["tensor"], + "requires_grad": item["requires_grad"], + "inside_pipe": None, + } + ) + kw_cnt[kw] += 1 + i = i + 1 + + def get_summary(self): + r"""Get the summary of the tensors recorded by `record_tensor`. + + Returns: + A list of dicts. Each dict contains the following keys: + - name: The name of the tensor. + - min: The minimum value of the tensor. + - max: The maximum value of the tensor. + - mean: The mean value of the tensor. + - std: The standard deviation of the tensor. + - shape: The shape of the tensor. + - grad_mean: The mean value of the gradient of the tensor. + - grad_std: The standard deviation of the gradient of the tensor. + + **Note:** This method must be called outside of the `with` block. + + """ + self._set_summary(self._summary) + ret = [] + for item in self.summary: + comm = config["comm"] + + if not item["requires_grad"] or item["tensor"].grad is None: + x = item["tensor"] + info = torch.empty(2, dtype=x.dtype, device=x.device) + info[0] = x.mean() + info[1] = x.var() + nccl.allReduce(info.storage(), info.storage(), "sum", comm) + info = info / nccl.commCount(comm) + x_mean = info[0].cpu().item() + x_std = math.sqrt(info[1].cpu().item()) + grad_mean = None + grad_std = None + else: + x = item["tensor"] + info = torch.empty(4, dtype=x.dtype, device=x.device) + info[0] = x.mean() + info[1] = x.var() + info[2] = x.grad.mean() + info[3] = x.grad.var() + nccl.allReduce(info.storage(), info.storage(), "sum", comm) + info = info / nccl.commCount(comm) + x_mean = info[0].cpu().item() + x_std = math.sqrt(info[1].cpu().item()) + grad_mean = info[2].cpu().item() + grad_std = math.sqrt(info[3].cpu().item()) + + info[0] = x.max() + info[1] = -x.min() + nccl.allReduce(info.storage(), info.storage(), "max", comm) + x_max = info[0].cpu().item() + x_min = -info[1].cpu().item() + + summary = { + "name": item["summary_name"], + "min": x_min, + "max": x_max, + "mean": x_mean, + "std": x_std, + "shape": tuple( + (item["shape"][0] * config["world_size"],) + item["shape"][1:] + ), + "grad_mean": grad_mean, + "grad_std": grad_std, + } + + ret.append(summary) + return ret + + def get_tensor( + self, name: str, group: Optional[str] = None, index: Optional[int] = None + ) -> torch.Tensor: + """Get the tensor recorded by `record_tensor` by name, group and index. + + Args: + name (str): The name of the tensor. + group (Optional[str]): The group of the tensor. + index (Optional[int]): The index of the tensor. + + Returns: + The tensor if found, otherwise None. + + """ + group_name_prefix = f"{group}." if group is not None else "" + + all_names = [] + if index is None: + all_names.append(f"{group_name_prefix}{name}") + all_names.append(f"{group_name_prefix}0.{name}") + else: + all_names.append(f"{group_name_prefix}{index}.{name}") + + for item in self.summary: + if item["name"] in all_names: + return item["tensor"] + return None + + +class InspectTensorManager: + def __init__(self) -> None: + self._inspector = None + + def __enter__(self) -> InspectTensor: + self.prev_val = debug.get("_inspect_tensor", False) + if not self.prev_val: + debug.set("_inspect_tensor", True) + self._inspector = InspectTensor() + return self._inspector + else: + raise RuntimeError("InspectTensorManager is already in use") + + def __exit__(self, *args): + if not self.prev_val: + debug.set("_inspect_tensor", self.prev_val) + summary = debug.get("_inspect_hidden_states", []) + self._inspector._set_summary(summary) + self._inspector = None + debug.set("_inspect_hidden_states", []) + + +def inspect_tensor() -> InspectTensorManager: + """**inspect_tensor** returns a context manager that can be used to get the intermediate results of the model computations and their gradients. + + Example: + >>> with bmt.inspect.inspect_tensor() as inspector: + >>> loss = model(inputs) + >>> loss.backward() + >>> summary = inspector.get_summary() + >>> text_summary = bmt.inspect.format_summary(summary) + >>> bmt.print_rank(text_summary) + name shape max min std mean grad_std grad_mean + ... + + **Note:** loss.backward() must be called inside the context manager, otherwise the gradients will not be recorded. + **Note:** Calling get_summary() has significant overhead. + + """ + + return InspectTensorManager() + + +def record_tensor(x: torch.Tensor, name: str, group=None): + """Record the tensor for inspection. + + Args: + x (torch.Tensor): The tensor to be recorded. + name (str): The name of the tensor. + group (str): The group name of the tensor. + + **Note:** This function is only available in inspect_tensor context. + **Note:** Recording too many tensors may cause memory issues. + + """ + if isinstance(x, torch.nn.Parameter): + raise RuntimeError("Cannot inspect Parameter") + + if not debug.get("_inspect_tensor", False): + # do nothing + return + + if x.requires_grad: + x.retain_grad() + debug.append( + "_inspect_hidden_states", + { + "name": name, + "group": group, + "min": None, + "max": None, + "mean": None, + "std": None, + "shape": x.shape, + "grad_mean": None, + "grad_std": None, + "tensor": x, + "requires_grad": x.requires_grad, + "inside_pipe": None, + }, + ) diff --git a/examples/BMTrain/bmtrain/layer.py b/examples/BMTrain/bmtrain/layer.py new file mode 100644 index 00000000..e071e01b --- /dev/null +++ b/examples/BMTrain/bmtrain/layer.py @@ -0,0 +1,143 @@ +import torch +from .parameter import DistributedParameter +from .global_var import config +import itertools +from .utils import tp_split_tensor + +class DistributedModule(torch.nn.Module): + """ + DistributedModule is a subclass of torch.nn.Module that overrides the `__getattr__` method to gather distributed parameters automatically. + + """ + + def __getattr__(self, name: str): + ret = super().__getattr__(name) + # gather distributed parameters if not in bmt.Block + if isinstance(ret, DistributedParameter) and not ret._in_block: + return ret.gather() + return ret + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + if isinstance(param, DistributedParameter):#and not param._in_block: + if param._in_block: + destination[prefix + name] = param.tp_gather().detach() # sync operation + else: + destination[prefix + name] = param.gather_all().detach() # sync operation + if config['save_param_to_cpu']: + destination[prefix + name] = destination[prefix + name].cpu() + else: + if config['save_param_to_cpu']: + destination[prefix + name] = param if keep_vars else param.detach().cpu() + else: + destination[prefix + name] = param if keep_vars else param.detach() + + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + if key in state_dict: + tp_mode = param._tp_mode + input_param = state_dict[key] + if input_param.__class__.__name__ == "DistributedTensorWrapper": + input_param = input_param.broadcast() + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and not isinstance(param, DistributedParameter) and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.' + .format(key, input_param.shape, param.shape)) + continue + verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape) + if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape: + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.' + .format(key, input_param.shape, verify_shape)) + try: + with torch.no_grad(): + if isinstance(param, DistributedParameter): + tp_split_dim = param._tp_split_dim + if tp_mode and tp_split_dim >= 0: + input_param = tp_split_tensor(input_param, tp_split_dim) + param._copy_data(input_param) + else: + param.copy_(input_param) + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.' + .format(key, param.size(), input_param.size(), ex.args)) + elif strict: + missing_keys.append(key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix): + input_name = key[len(prefix):] + input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) + diff --git a/examples/BMTrain/bmtrain/loss/__init__.py b/examples/BMTrain/bmtrain/loss/__init__.py new file mode 100644 index 00000000..daa731fd --- /dev/null +++ b/examples/BMTrain/bmtrain/loss/__init__.py @@ -0,0 +1 @@ +from .cross_entropy import FusedCrossEntropy \ No newline at end of file diff --git a/examples/BMTrain/bmtrain/loss/_function.py b/examples/BMTrain/bmtrain/loss/_function.py new file mode 100644 index 00000000..6ff3c471 --- /dev/null +++ b/examples/BMTrain/bmtrain/loss/_function.py @@ -0,0 +1,182 @@ +from .. import C +import torch + +CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda + + +def has_inf_nan(g_half: torch.Tensor, out: torch.Tensor) -> None: + assert out.dtype == torch.uint8, "out must be a uint8 tensor" + assert CHECK_INPUT(g_half), "g_fp16 must be contiguous and on cuda" + assert CHECK_INPUT(out), "out must be contiguous and on cuda" + mid = torch.zeros(1024, device=out.device, dtype=out.dtype) + stream = torch.cuda.current_stream().cuda_stream + if g_half.dtype == torch.float16: + C.has_nan_inf_fp16_launcher( + g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream + ) + elif g_half.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.has_nan_inf_bf16_launcher( + g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream + ) + else: + raise ValueError(f"has_inf_nan not supported for dtype {g_half.dtype}") + + +def cross_entropy_forward( + m: int, + n: int, + input: torch.Tensor, + target: torch.Tensor, + softmax: torch.Tensor, + output: torch.Tensor, + ignore_index: int, +) -> None: + CHECK_INPUT(input) + CHECK_INPUT(target) + CHECK_INPUT(softmax) + CHECK_INPUT(output) + assert target.dtype == torch.int32, "target must be an int tensor" + assert output.dtype == torch.float32, "output must be a float tensor" + assert ( + input.numel() == softmax.numel() + ), "input and softmax must have the same number of elements" + assert ( + target.numel() == output.numel() + ), "target and output must have the same number of elements" + input_ptr = input.data_ptr() + target_ptr = target.data_ptr() + softmax_ptr = softmax.data_ptr() + output_ptr = output.data_ptr() + cuda_stream = torch.cuda.current_stream().cuda_stream + if input.dtype == torch.float16: + C.cross_entropy_forward_fp16_launcher( + m, + n, + input_ptr, + target_ptr, + softmax_ptr, + output_ptr, + ignore_index, + cuda_stream, + ) + elif input.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.cross_entropy_forward_bf16_launcher( + m, + n, + input_ptr, + target_ptr, + softmax_ptr, + output_ptr, + ignore_index, + cuda_stream, + ) + else: + raise ValueError(f"cross_entropy_forward not supported for dtype {input.dtype}") + + +def cross_entropy_backward_inplace( + m: int, + n: int, + grad_output: torch.Tensor, + target: torch.Tensor, + x: torch.Tensor, + ignore_index: int, +) -> None: + CHECK_INPUT(grad_output) + CHECK_INPUT(target) + CHECK_INPUT(x) + assert grad_output.dtype == torch.float32, "grad_output must be a float tensor" + assert target.dtype == torch.int32, "target must be an int tensor" + assert ( + target.numel() == grad_output.numel() + ), "target and grad_output must have the same number of elements" + cuda_stream = torch.cuda.current_stream().cuda_stream + grad_output_ptr = grad_output.data_ptr() + target_ptr = target.data_ptr() + x_ptr = x.data_ptr() + + if x.dtype == torch.float16: + C.cross_entropy_backward_inplace_fp16_launcher( + m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream + ) + elif x.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.cross_entropy_backward_inplace_bf16_launcher( + m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream + ) + else: + raise ValueError( + f"cross_entropy_backward not supported for dtype {input.dtype}" + ) + + +def fused_sumexp(logits: torch.Tensor, max_logits: torch.Tensor) -> torch.Tensor: + CHECK_INPUT(logits) + CHECK_INPUT(max_logits) + assert max_logits.dtype == torch.float32, "max_logits must be float tensor" + assert max_logits.size(0) == logits.size( + 0 + ), "max_logits must have same size(0) as logits" + sum_exp_logits = torch.empty( + logits.size(0), dtype=torch.float32, device=logits.device + ) + m = logits.size(0) + n = logits.size(1) + cuda_stream = torch.cuda.current_stream().cuda_stream + logits_ptr = logits.data_ptr() + max_logits_ptr = max_logits.data_ptr() + sum_exp_logits_ptr = sum_exp_logits.data_ptr() + if logits.dtype == torch.float16: + C.fused_sumexp_fp16_launcher( + m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream + ) + elif logits.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.fused_sumexp_bf16_launcher( + m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream + ) + else: + raise ValueError(f"fused_sumexp not supported for dtype {logits.dtype}") + return sum_exp_logits + + +def fused_softmax_inplace( + logits: torch.Tensor, max_logits: torch.Tensor, sum_exp_logits: torch.Tensor +) -> None: + CHECK_INPUT(logits) + CHECK_INPUT(max_logits) + CHECK_INPUT(sum_exp_logits) + assert max_logits.dtype == torch.float32, "max_logits must be float tensor" + assert sum_exp_logits.dtype == torch.float32, "sum_exp_logits must be float tensor" + assert max_logits.size(0) == logits.size( + 0 + ), "max_logits must have same size(0) as logits" + assert sum_exp_logits.size(0) == logits.size( + 0 + ), "sum_exp_logits must have same size(0) as logits" + m = logits.size(0) + n = logits.size(1) + cuda_stream = torch.cuda.current_stream().cuda_stream + logits_ptr = logits.data_ptr() + max_logits_ptr = max_logits.data_ptr() + sum_exp_logits_ptr = sum_exp_logits.data_ptr() + if logits.dtype == torch.float16: + C.fused_softmax_inplace_fp16_launcher( + m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream + ) + elif logits.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.fused_softmax_inplace_bf16_launcher( + m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream + ) + else: + raise ValueError( + f"fused_softmax_inplace not supported for dtype {logits.dtype}" + ) diff --git a/examples/BMTrain/bmtrain/loss/cross_entropy.py b/examples/BMTrain/bmtrain/loss/cross_entropy.py new file mode 100644 index 00000000..5be07665 --- /dev/null +++ b/examples/BMTrain/bmtrain/loss/cross_entropy.py @@ -0,0 +1,260 @@ +from typing import Optional +import torch +from . import _function as F +from bmtrain.global_var import config +from bmtrain.distributed import all_gather, all_reduce + +class OpFusedCrossEntropy(torch.autograd.Function): + """ + CrossEntropy dim = 1 + """ + @staticmethod + def forward(ctx, x : torch.Tensor, target : torch.Tensor, ignore_index: int): + assert x.ndim == 2 + softmax = torch.empty(x.size(), device=x.device, dtype=x.dtype) + out = torch.empty(x.size(0), device=x.device, dtype=torch.float) + F.cross_entropy_forward( + x.size(0), x.size(1), + x, target, + softmax, out, + ignore_index, + ) + ctx.ignore_index = ignore_index + ctx.save_for_backward(softmax, target) + return out # float tensor + + @staticmethod + def backward(ctx, grad_output : torch.Tensor): + grad_output = grad_output.contiguous() + softmax, target = ctx.saved_tensors + F.cross_entropy_backward_inplace( + softmax.size(0), softmax.size(1), + grad_output, target, + softmax, + ctx.ignore_index, + ) + return (softmax, None, None) + +class VPFusedCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, logits : torch.Tensor, target : torch.Tensor): + comm = config['tp_comm'] + rank = config['tp_rank'] + world_size = config['tp_size'] + + max_logits = torch.max(logits, dim=-1)[0].float() + max_logits = all_reduce(max_logits, op="max", comm=comm) + + partition_vocab_size = logits.size()[-1] + vocab_start_index = rank * partition_vocab_size + vocab_end_index = (rank + 1) * partition_vocab_size + + # Create a mask of valid vocab ids (1 means it needs to be masked). + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target.clone() - vocab_start_index + masked_target[target_mask] = 0 + + logits_2d = logits.view(-1, partition_vocab_size) + masked_target_1d = masked_target.view(-1) + arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device) + predicted_logits_1d = logits_2d[arange_1d, masked_target_1d].contiguous() # (-1,) + predicted_logits = predicted_logits_1d.view_as(target) + predicted_logits[target_mask] = 0.0 # if target=-100, it will also be 0 + + # All reduce is needed to get the chunks from other GPUs. + predicted_logits = all_reduce(predicted_logits.float(), op="sum", comm=comm) + predicted_logits = predicted_logits - max_logits + # Sum of exponential of logits along vocab dimension across all GPUs. + + sum_exp_logits = torch.empty(logits.size(0), device=logits.device, dtype=torch.float) + sum_exp_logits = F.fused_sumexp(logits, max_logits) # float + sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) + 1e-10 # avoid nan + + softmax = logits.clone() + F.fused_softmax_inplace(softmax, max_logits, sum_exp_logits) # logits -> softmax + # logits = logits.float() - max_logits.unsqueeze(dim=-1).float() + # exp_logits = logits + # torch.exp(logits, out=exp_logits) + # sum_exp_logits = exp_logits.sum(dim=-1) + # exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1)) + + loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits + + # Normalize + ctx.save_for_backward(softmax, target_mask, masked_target_1d) + + return loss + + @staticmethod + def backward(ctx, grad_output): + # Retreive tensors from the forward path. + softmax, target_mask, masked_target_1d = ctx.saved_tensors + # All the inputs have softmax as thier gradient. + grad_input = softmax + # For simplicity, work with the 2D gradient. + partition_vocab_size = softmax.size()[-1] + grad_2d = grad_input.view(-1, partition_vocab_size) + + # Add the gradient from matching classes. + arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device) + + softmax_update = 1.0 - target_mask.view(-1).float() + + grad_2d[arange_1d, masked_target_1d] -= softmax_update + grad_input.mul_(grad_output.view(*grad_input.shape[:-1]).unsqueeze(dim=-1)) + + return grad_input, None + +class FusedCrossEntropy(torch.nn.Module): + r"""This criterion computes the cross entropy loss between input and target. + + It is useful when training a classification problem with `C` classes. + If provided, the optional argument :attr:`weight` should be a 1D `Tensor` + assigning weight to each of the classes. + This is particularly useful when you have an unbalanced training set. + + The `input` is expected to contain raw, unnormalized scores for each class. + `input` has to be a Tensor of size :math:`(minibatch, C)`. + + The `target` that this criterion expects should contain either: + + - Class indices in the range :math:`[0, C-1]` where :math:`C` is the number of classes; if + `ignore_index` is specified, this loss also accepts this class index (this index + may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction` + set to ``'none'``) loss for this case can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})} + \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\} + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + Note that this case is equivalent to the combination of :class:`~torch.nn.LogSoftmax` and + :class:`~torch.nn.NLLLoss`. + + - Probabilities for each class; useful when labels beyond a single class per minibatch item + are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with + :attr:`reduction` set to ``'none'``) loss for this case can be described as: + + .. math:: + \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad + l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\exp(\sum_{i=1}^C x_{n,i})} y_{n,c} + + where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight, + :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension. If + :attr:`reduction` is not ``'none'`` (default ``'mean'``), then + + .. math:: + \ell(x, y) = \begin{cases} + \frac{\sum_{n=1}^N l_n}{N}, & + \text{if reduction} = \text{`mean';}\\ + \sum_{n=1}^N l_n, & + \text{if reduction} = \text{`sum'.} + \end{cases} + + .. note:: + The performance of this criterion is generally better when `target` contains class + indices, as this allows for optimized computation. Consider providing `target` as + class probabilities only when a single class label per minibatch item is too restrictive. + + Args: + weight (Tensor, optional): a manual rescaling weight given to each class. + If given, has to be a Tensor of size `C` + size_average (bool, optional): Deprecated (see :attr:`reduction`). By default, + the losses are averaged over each loss element in the batch. Note that for + some losses, there are multiple elements per sample. If the field :attr:`size_average` + is set to ``False``, the losses are instead summed for each minibatch. Ignored + when :attr:`reduce` is ``False``. Default: ``True`` + ignore_index (int, optional): Specifies a target value that is ignored + and does not contribute to the input gradient. When :attr:`size_average` is + ``True``, the loss is averaged over non-ignored targets. Note that + :attr:`ignore_index` is only applicable when the target contains class indices. + reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the + losses are averaged or summed over observations for each minibatch depending + on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per + batch element instead and ignores :attr:`size_average`. Default: ``True`` + reduction (string, optional): Specifies the reduction to apply to the output: + ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will + be applied, ``'mean'``: the weighted mean of the output is taken, + ``'sum'``: the output will be summed. Note: :attr:`size_average` + and :attr:`reduce` are in the process of being deprecated, and in + the meantime, specifying either of those two args will override + :attr:`reduction`. Default: ``'mean'`` + label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount + of smoothing when computing the loss, where 0.0 means no smoothing. The targets + become a mixture of the original ground truth and a uniform distribution as described in + `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`. + + Shape: + - Input: :math:`(N, C)` where `C = number of classes`. + - Target: If containing class indices, shape :math:`(N)` where each value is + :math:`0 \leq \text{targets}[i] \leq C-1`. If containing class probabilities, + same shape as the input. + - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)`. + Otherwise, scalar. + + Examples:: + + >>> # Example of target with class indices + >>> loss_func = bmt.loss.FusedCrossEntropy() + >>> input = torch.randn(32, 100).half() + >>> target = torch.randint(0, 100, (32,)).long() + >>> loss = loss_func(input, target) + >>> loss.backward() + """ + def __init__(self, + weight: Optional[torch.Tensor] = None, + ignore_index: int = -100, + reduction: str = 'mean', + label_smoothing: float = 0.0, # TODO not supported yet + parallel: bool = False, + ) -> None: + super().__init__() + self.weight = weight + self.ignore_index = ignore_index + self.reduction = reduction + self.label_smoothing = label_smoothing + self.parallel = parallel + + def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + if self.parallel: + ret = VPFusedCrossEntropy.apply(input, target.long()) + else: + if input.dtype == torch.float32: + return torch.nn.functional.cross_entropy( + input, + target.long(), + weight=self.weight, + ignore_index=self.ignore_index, + reduction=self.reduction, + label_smoothing=self.label_smoothing) + + ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor + + if self.weight is not None: + if self.weight.dim() != 1 or self.weight.size(0) != input.size(1): + raise ValueError("weight should be a 1D tensor of size C"); + w = self.weight[torch.where(target==self.ignore_index, 0, target)].float() + w[target==self.ignore_index] = 0 + else: + w = (target != self.ignore_index).int() + + ret = w * ret + + if self.reduction == "none": + return ret + elif self.reduction == "sum": + return ret.sum() + elif self.reduction == "mean": + return ret.sum() / w.sum().float() diff --git a/examples/BMTrain/bmtrain/lr_scheduler/__init__.py b/examples/BMTrain/bmtrain/lr_scheduler/__init__.py new file mode 100644 index 00000000..0d9a0596 --- /dev/null +++ b/examples/BMTrain/bmtrain/lr_scheduler/__init__.py @@ -0,0 +1,6 @@ +from .warmup import WarmupLRScheduler +from .no_decay import NoDecay +from .noam import Noam +from .linear import Linear +from .cosine import Cosine +from .exponential import Exponential \ No newline at end of file diff --git a/examples/BMTrain/bmtrain/lr_scheduler/cosine.py b/examples/BMTrain/bmtrain/lr_scheduler/cosine.py new file mode 100644 index 00000000..3aed034d --- /dev/null +++ b/examples/BMTrain/bmtrain/lr_scheduler/cosine.py @@ -0,0 +1,18 @@ +import math +from .warmup import WarmupLRScheduler + + +class Cosine(WarmupLRScheduler): + r""" + After a warmup period during which learning rate increases linearly between 0 and the start_lr, + The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}` + """ + + def get_lr_warmup(self, num_iter) -> float: + return self.start_lr * num_iter / self.warmup_iter + + def get_lr_decay(self, num_iter) -> float: + progress = (num_iter - self.warmup_iter) / max( + 1, (self.end_iter - self.warmup_iter) + ) + return max(0.0, self.start_lr * 0.5 * (1.0 + math.cos(progress * math.pi))) diff --git a/examples/BMTrain/bmtrain/lr_scheduler/exponential.py b/examples/BMTrain/bmtrain/lr_scheduler/exponential.py new file mode 100644 index 00000000..6cf3240e --- /dev/null +++ b/examples/BMTrain/bmtrain/lr_scheduler/exponential.py @@ -0,0 +1,20 @@ +from .warmup import WarmupLRScheduler + + +class Exponential(WarmupLRScheduler): + r""" + After a warmup period during which learning rate increases linearly between 0 and the start_lr, + The decay period performs :math:`\text{lr}=\text{start_lr}\times \gamma ^ {\left(\text{num_iter}-\text{warmup_iter}\right)}` + """ + + def __init__( + self, optimizer, start_lr, warmup_iter, end_iter, num_iter, gamma=0.95 + ) -> None: + super().__init__(optimizer, start_lr, warmup_iter, end_iter, num_iter) + self.gamma = gamma + + def get_lr_warmup(self, num_iter) -> float: + return self.start_lr * num_iter / self.warmup_iter + + def get_lr_decay(self, num_iter) -> float: + return max(0.0, self.start_lr * self.gamma ** (num_iter - self.warmup_iter)) diff --git a/examples/BMTrain/bmtrain/lr_scheduler/linear.py b/examples/BMTrain/bmtrain/lr_scheduler/linear.py new file mode 100644 index 00000000..af193dd8 --- /dev/null +++ b/examples/BMTrain/bmtrain/lr_scheduler/linear.py @@ -0,0 +1,19 @@ +from .warmup import WarmupLRScheduler + + +class Linear(WarmupLRScheduler): + r""" + After a warmup period during which learning rate increases linearly between 0 and the start_lr, + The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{end_iter}-\text{num_iter}}{\text{end_iter}-\text{warmup_iter}}` + """ + + def get_lr_warmup(self, num_iter) -> float: + return self.start_lr * num_iter / self.warmup_iter + + def get_lr_decay(self, num_iter) -> float: + return max( + 0.0, + self.start_lr + * (self.end_iter - num_iter) + / (self.end_iter - self.warmup_iter), + ) diff --git a/examples/BMTrain/bmtrain/lr_scheduler/no_decay.py b/examples/BMTrain/bmtrain/lr_scheduler/no_decay.py new file mode 100644 index 00000000..6f85bf0a --- /dev/null +++ b/examples/BMTrain/bmtrain/lr_scheduler/no_decay.py @@ -0,0 +1,14 @@ +from .warmup import WarmupLRScheduler + + +class NoDecay(WarmupLRScheduler): + r""" + After a warmup period during which learning rate increases linearly between 0 and the start_lr, + The decay period performs :math:`\text{lr}=\text{start_lr}` + """ + + def get_lr_warmup(self, num_iter) -> float: + return self.start_lr * num_iter / self.warmup_iter + + def get_lr_decay(self, num_iter) -> float: + return self.start_lr diff --git a/examples/BMTrain/bmtrain/lr_scheduler/noam.py b/examples/BMTrain/bmtrain/lr_scheduler/noam.py new file mode 100644 index 00000000..8954a64d --- /dev/null +++ b/examples/BMTrain/bmtrain/lr_scheduler/noam.py @@ -0,0 +1,15 @@ +import math +from .warmup import WarmupLRScheduler + + +class Noam(WarmupLRScheduler): + r""" + After a warmup period during which performs :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{num_iter}}{\text{warmup_iter}^{3/2}}`, + The decay period performs :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{1}}{\sqrt{\text{num_iter}}}` + """ + + def get_lr_warmup(self, num_iter) -> float: + return self.start_lr / math.sqrt(self.warmup_iter) * num_iter / self.warmup_iter + + def get_lr_decay(self, num_iter) -> float: + return self.start_lr / math.sqrt(num_iter) diff --git a/examples/BMTrain/bmtrain/lr_scheduler/warmup.py b/examples/BMTrain/bmtrain/lr_scheduler/warmup.py new file mode 100644 index 00000000..1f9ccc8e --- /dev/null +++ b/examples/BMTrain/bmtrain/lr_scheduler/warmup.py @@ -0,0 +1,72 @@ +import torch + + +class WarmupLRScheduler: + r"""Base class for learning rate schedulers with warmup. + + Args: + optimizer (torch.optim.Optimizer): optimizer used for training + start_lr (float): starting learning rate + warmup_iter (int): number of iterations to linearly increase learning rate + end_iter (int): number of iterations to stop training + num_iter (int): current iteration number + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + start_lr, + warmup_iter, + end_iter, + num_iter=0, + ) -> None: + self.start_lr = start_lr + self.warmup_iter = warmup_iter + self.end_iter = end_iter + self.optimizer = optimizer + self.num_iter = num_iter + self._current_lr = None + + self.step(self.num_iter) + + def get_lr_warmup(self, num_iter) -> float: ... + + def get_lr_decay(self, num_iter) -> float: ... + + def get_lr(self): + assert self.num_iter >= 0 + + if self.num_iter < self.warmup_iter: + return self.get_lr_warmup(self.num_iter) + else: + return self.get_lr_decay(self.num_iter) + + @property + def current_lr(self): + return self._current_lr + + def step(self, num_iter=None) -> None: + if num_iter is None: + num_iter = self.num_iter + 1 + self.num_iter = num_iter + + lr = self.get_lr() + self._current_lr = lr + for group in self.optimizer.param_groups: + group["lr"] = lr + + def state_dict(self): + return { + "start_lr": self.start_lr, + "warmup_iter": self.warmup_iter, + "end_iter": self.end_iter, + "num_iter": self.num_iter, + } + + def load_state_dict(self, state_dict): + self.start_lr = state_dict["start_lr"] + self.warmup_iter = state_dict["warmup_iter"] + self.end_iter = state_dict["end_iter"] + self.num_iter = state_dict["num_iter"] + + self.step(self.num_iter) diff --git a/examples/BMTrain/bmtrain/nccl/__init__.py b/examples/BMTrain/bmtrain/nccl/__init__.py new file mode 100644 index 00000000..0f4129d5 --- /dev/null +++ b/examples/BMTrain/bmtrain/nccl/__init__.py @@ -0,0 +1,336 @@ + +from typing_extensions import Literal +import torch +from .. import C +from .enums import * + +class NCCLCommunicator: + """ + NCCL communicator stores the communicator handle. + """ + + def __init__(self, ptr) -> None: + self.__ptr = ptr + + @property + def ptr(self): + """ + Returns the communicator handle. + """ + if self.__ptr == -1: + raise RuntimeError("NCCL Communicator is already destroyed") + return self.__ptr + + def _destroy_ptr(self): + self.__ptr = -1 + +# utils + +def dtype2nccl(dtype : torch.dtype) -> int: + MAP = { + torch.int8: ncclInt8, + torch.uint8 : ncclUint8, + torch.int32 : ncclInt32, + torch.int : ncclInt32, + torch.int64 : ncclInt64, + torch.float16 : ncclFloat16, + torch.half : ncclHalf, + torch.bfloat16 : ncclBFloat16, + torch.float32 : ncclFloat32, + torch.float : ncclFloat, + torch.float64 : ncclFloat64, + torch.double : ncclDouble, + torch.bool : ncclBool + } + if dtype not in MAP: + raise TypeError("Unsupport dtype %s" % dtype) + return MAP[dtype] + +def op2nccl( + op : Literal["sum", "prod", "max", "min", "avg"] +): + if op == "sum": + return ncclSum + if op == "prod": + return ncclProd + if op == "max": + return ncclMax + if op == "min": + return ncclMin + if op == "avg": + return ncclAvg + raise ValueError("Unknown gather op %s") + +# wrappers + +def getUniqueId() -> bytes: + """ + NCCL API: `ncclGetUniqueId <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclgetuniqueid>`_ + + """ + return C.ncclGetUniqueId() + +def commInitRank(unique_id : bytes, world_size : int, rank : int) -> NCCLCommunicator: + """ + NCCL API: `ncclCommInitRank <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcomminitrank>`_ + + """ + assert rank >= 0 and rank < world_size, "rank must be between 0 and world_size-1" + return NCCLCommunicator(C.ncclCommInitRank(unique_id, world_size, rank)) + +def commDestroy(comm : NCCLCommunicator): + """ + NCCL API: `ncclCommDestroy <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommdestroy>`_ + + """ + C.ncclCommDestroy(comm.ptr) + comm._destroy_ptr() +def commCount(comm : NCCLCommunicator): + """NCCL API: `ncclCommCount <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommcount>`_ + + Args: + comm (NCCLCommunicator): NCCL communicator. + """ + return C.ncclCommCount(comm.ptr) +### collective +def commRank(comm : NCCLCommunicator): + """NCCL API: `ncclCommUserRank <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclCommUserRank>`_ + + Args: + comm (NCCLCommunicator): NCCL communicator. + """ + return C.ncclCommUserRank(comm.ptr) +def allReduce( + src : torch.storage._StorageBase, + dst : torch.storage._StorageBase, + op : Literal["sum", "prod", "max", "min", "avg"], + comm : NCCLCommunicator + ): + """NCCL API: `ncclAllReduce <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclallreduce>`_ + + Args: + src (torch.storage._StorageBase): Source buffer. + dst (torch.storage._StorageBase): Destination buffer. + op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation. + comm (NCCLCommunicator): NCCL communicator. + + The src and dst buffers must be the same size, type and on the same device. + + If src == dst, the operation is performed in-place. + + """ + assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.is_cuda and dst.is_cuda + + sendbuff = src.data_ptr() + recvbuff = dst.data_ptr() + count = src.size() + datatype = dtype2nccl(src.dtype) + operator = op2nccl(op) + + assert src.size() == dst.size(), "Buffer size not aligned" + C.ncclAllReduce( + sendbuff, + recvbuff, + count, + datatype, + operator, + comm.ptr, + torch.cuda.current_stream().cuda_stream + ) +def send(src : torch.storage._StorageBase, + peer : int, + comm : NCCLCommunicator + ): + """NCCL API: `ncclsend <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend>`_ + + Args: + src (torch.storage._StorageBase): Source buffer. + peer (int): rank peer needs to call ncclRecv + comm (NCCLCommunicator): NCCL communicator. + """ + + sendbuff = src.data_ptr() + count = src.size() + datatype = dtype2nccl(src.dtype) + C.ncclSend( + sendbuff, + count, + datatype, + peer, + comm.ptr, + torch.cuda.current_stream().cuda_stream + ) +def recv(dst : torch.storage._StorageBase, + peer : int, + comm : NCCLCommunicator + ): + recvbuff = dst.data_ptr() + count = dst.size() + datatype = dtype2nccl(dst.dtype) + C.ncclRecv( + recvbuff, + count, + datatype, + peer, + comm.ptr, + torch.cuda.current_stream().cuda_stream + ) + +def broadcast( + src : torch.storage._StorageBase, + dst : torch.storage._StorageBase, + root : int, + comm : NCCLCommunicator + ): + """NCCL API: `ncclBroadcast <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclbroadcast>`_ + + Args: + src (torch.storage._StorageBase): Source buffer. + dst (torch.storage._StorageBase): Destination buffer. + root (int): Rank of the root. + comm (NCCLCommunicator): NCCL communicator. + + The src and dst buffers must be the same size, type and on the same device. + + If src == dst, the operation is performed in-place. + + """ + + assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.is_cuda and dst.is_cuda + + sendbuff = src.data_ptr() + recvbuff = dst.data_ptr() + count = src.size() + datatype = dtype2nccl(src.dtype) + + assert dst.size() == src.size(), "Buffer size not aligned" + C.ncclBroadcast( + sendbuff, + recvbuff, + count, + datatype, + root, + comm.ptr, + torch.cuda.current_stream().cuda_stream + ) + +def reduce( + src : torch.storage._StorageBase, + dst : torch.storage._StorageBase, + op : Literal["sum", "prod", "max", "min", "avg"], + root : int, + comm : NCCLCommunicator + ): + """NCCL API: `ncclReduce <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclreduce>`_ + + Args: + src (torch.storage._StorageBase): Source buffer. + dst (torch.storage._StorageBase): Destination buffer. + op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation. + root (int): Rank of the root. + comm (NCCLCommunicator): NCCL communicator. + + The src and dst buffers must be the same size, type and on the same device. + + If src == dst, the operation is performed in-place. + + """ + assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.is_cuda and dst.is_cuda + + sendbuff = src.data_ptr() + recvbuff = dst.data_ptr() + count = src.size() + datatype = dtype2nccl(src.dtype) + operator = op2nccl(op) + + assert dst.size() == src.size(), "Buffer size not aligned" + C.ncclReduce(sendbuff, recvbuff, count, datatype, operator, root, comm.ptr, torch.cuda.current_stream().cuda_stream) + +def allGather( + src : torch.storage._StorageBase, + dst : torch.storage._StorageBase, + comm : NCCLCommunicator + ): + """NCCL API: `ncclAllGather <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclallgather>`_ + + Args: + src (torch.storage._StorageBase): Source buffer. + dst (torch.storage._StorageBase): Destination buffer. + comm (NCCLCommunicator): NCCL communicator. + + The size of the dst buffer must be equal to the size of src buffer * world_size. + + The dst buffer is only used on rank root. + + """ + assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.is_cuda and dst.is_cuda + + sendbuff = src.data_ptr() + recvbuff = dst.data_ptr() + sendcount = src.size() + datatype = dtype2nccl(src.dtype) + assert dst.size() % sendcount == 0, "Buffer size not aligned" + C.ncclAllGather( + sendbuff, + recvbuff, + sendcount, + datatype, + comm.ptr, + torch.cuda.current_stream().cuda_stream + ) + + +def reduceScatter( + src : torch.storage._StorageBase, + dst : torch.storage._StorageBase, + op : Literal["sum", "prod", "max", "min", "avg"], + comm : NCCLCommunicator + ): + """NCCL API: `ncclReduceScatter <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/colls.html#ncclreducescatter>`_ + + Args: + src (torch.storage._StorageBase): Source buffer. + dst (torch.storage._StorageBase): Destination buffer. + op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation. + comm (NCCLCommunicator): NCCL communicator. + + The size of the dst buffer must be equal to the size of src buffer / world_size. + + The dst buffer on rank `i` will contail the i-th block of the reduced result. + + """ + assert src.dtype == dst.dtype, "send and recv buffers must be the same time" + assert src.is_cuda and dst.is_cuda + + sendbuff = src.data_ptr() + recvbuff = dst.data_ptr() + recvcount = dst.size() + datatype = dtype2nccl(src.dtype) + operator = op2nccl(op) + + assert src.size() % recvcount == 0, "Buffer size not aligned" + C.ncclReduceScatter( + sendbuff, + recvbuff, + recvcount, + datatype, + operator, + comm.ptr, + torch.cuda.current_stream().cuda_stream + ) + +def groupStart(): + """ + NCCL API: `ncclGroupStart <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/group.html#ncclgroupstart>`_ + """ + C.ncclGroupStart() + +def groupEnd(): + """ + NCCL API: `ncclGroupEnd <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/group.html#ncclgroupend>`_ + """ + C.ncclGroupEnd() diff --git a/examples/BMTrain/bmtrain/nccl/enums.py b/examples/BMTrain/bmtrain/nccl/enums.py new file mode 100644 index 00000000..67411f0e --- /dev/null +++ b/examples/BMTrain/bmtrain/nccl/enums.py @@ -0,0 +1,27 @@ + +### ncclDataType_t + +ncclInt8 = 0 +ncclChar = 0 +ncclBool = 0 +ncclUint8 = 1 +ncclInt32 = 2 +ncclInt = 2 +ncclUint32 = 3 +ncclInt64 = 4 +ncclUint64 = 5 +ncclFloat16 = 6 +ncclHalf = 6 +ncclFloat32 = 7 +ncclFloat = 7 +ncclFloat64 = 8 +ncclDouble = 8 +ncclBFloat16 = 9 + +### ncclRedOp_t + +ncclSum = 0 +ncclProd = 1 +ncclMax = 2 +ncclMin = 3 +ncclAvg = 4 \ No newline at end of file diff --git a/examples/BMTrain/bmtrain/nn/__init__.py b/examples/BMTrain/bmtrain/nn/__init__.py new file mode 100644 index 00000000..60fed663 --- /dev/null +++ b/examples/BMTrain/bmtrain/nn/__init__.py @@ -0,0 +1,5 @@ +from .linear import Linear, OpLinear +from .column_parallel_linear import ColumnParallelLinear +from .row_parallel_linear import RowParallelLinear +from .parallel_embedding import VPEmbedding +from .parallel_linear_func import OpParallelLinear diff --git a/examples/BMTrain/bmtrain/nn/column_parallel_linear.py b/examples/BMTrain/bmtrain/nn/column_parallel_linear.py new file mode 100644 index 00000000..e1ede115 --- /dev/null +++ b/examples/BMTrain/bmtrain/nn/column_parallel_linear.py @@ -0,0 +1,80 @@ +import torch +from torch.nn.parameter import Parameter + +import bmtrain as bmt +from bmtrain.global_var import config +from .parallel_linear_func import OpParallelLinear, ReduceType + + +class ColumnParallelLinear(bmt.DistributedModule): + """Tensor Parallel use cloumn partition for Linear. + + Args: + in_features (int): in_features size. + out_features (int): out_features size. + bias (bool): whether use bias. + dtype : data type. + gather_ouput (bool): whether gather output after compute. + gather_input (bool): whether gather input before compute. + async_gather_chunks (int): chunk size for async gathering data. + + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype=None, + gather_output=False, + gather_input=True, + async_gather_chunks=2, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + self.gather_input = gather_input + self.async_gather_chunks = async_gather_chunks + tp_size = config["tp_size"] + assert out_features % tp_size == 0 + self.out_features_per_partition = out_features // tp_size + self.weight = bmt.DistributedParameter( + torch.empty( + self.out_features_per_partition, in_features, dtype=dtype, device="cuda" + ), + init_method=torch.nn.init.xavier_normal_, + tp_split_dim=0, + tp_mode=True, + ) + if bias: + self.bias = bmt.DistributedParameter( + torch.empty( + self.out_features_per_partition, dtype=dtype, device="cuda" + ), + init_method=torch.nn.init.zeros_, + tp_split_dim=0, + tp_mode=True, + ) + else: + self.register_parameter("bias", None) + + def forward(self, input): + gather_input = self.gather_input + split_input = False + reduce_output_type = None + return OpParallelLinear.apply( + input, + self.weight, + self.bias, + gather_input, + self.gather_output, + split_input, + reduce_output_type, + self.async_gather_chunks, + ) + + def extra_repr(self) -> str: + return "in_features={}, out_features={}, bias={}".format( + self.in_features, self.out_features_per_partitions, self.bias is not None + ) diff --git a/examples/BMTrain/bmtrain/nn/linear.py b/examples/BMTrain/bmtrain/nn/linear.py new file mode 100644 index 00000000..8afb1d89 --- /dev/null +++ b/examples/BMTrain/bmtrain/nn/linear.py @@ -0,0 +1,56 @@ +import torch +import torch.nn.functional as F +import bmtrain as bmt + + +class OpLinear(torch.autograd.Function): + @staticmethod + def forward(ctx, x, weight, bias=None): + ctx.save_for_backward(x, weight, bias) + return F.linear(x, weight, bias) + + @staticmethod + def backward(ctx, grad_output): + x, weight, bias = ctx.saved_tensors + grad_x = grad_weight = grad_bias = None + if x.requires_grad: + grad_x = grad_output.matmul(weight) + if weight.requires_grad: + dim = grad_output.dim() + grad_weight = ( + grad_output.reshape(-1, grad_output.shape[-1]) + .t() + .matmul(x.reshape(-1, x.shape[-1])) + ) + if bias is not None and bias.requires_grad: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + return grad_x, grad_weight, grad_bias + + +class Linear(bmt.DistributedModule): + def __init__( + self, in_features: int, out_features: int, bias: bool = True, dtype=None + ) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.weight = bmt.DistributedParameter( + torch.empty(out_features, in_features, dtype=dtype, device="cuda"), + init_method=torch.nn.init.xavier_normal_, + ) + if bias: + self.bias = bmt.DistributedParameter( + torch.empty(out_features, dtype=dtype, device="cuda"), + init_method=torch.nn.init.zeros_, + ) + else: + self.register_parameter("bias", None) + + def forward(self, input): + return OpLinear.apply(input, self.weight, self.bias) + + def extra_repr(self) -> str: + return "in_features={}, out_features={}, bias={}".format( + self.in_features, self.out_features, self.bias is not None + ) diff --git a/examples/BMTrain/bmtrain/nn/parallel_embedding.py b/examples/BMTrain/bmtrain/nn/parallel_embedding.py new file mode 100644 index 00000000..3bdc4e56 --- /dev/null +++ b/examples/BMTrain/bmtrain/nn/parallel_embedding.py @@ -0,0 +1,59 @@ +import torch +from torch.nn.parameter import Parameter +import torch.nn.functional as F +import math + +import bmtrain as bmt +from bmtrain.global_var import config +from bmtrain.distributed import all_reduce, all_gather +from .parallel_linear_func import OpParallelLinear + + +class VPEmbedding(bmt.DistributedModule): + """Vocab Parallel Embedding. + + Args: + vocab_size (int required): vocab size. + embedding_size (int required): embedding size. + dtype (torch.dtype): data type. + init_mean (float optional): mean for weight init. + init_std (float optional): std for weight init. + + """ + + def __init__( + self, + vocab_size: int, + embedding_size: int, + dtype: torch.dtype = torch.half, + init_mean: float = 0.0, + init_std: float = 1, + ): + super().__init__() + + self.dim_model = embedding_size + assert vocab_size % bmt.config["tp_size"] == 0 + self.vocab_size_per_partition = vocab_size // bmt.config["tp_size"] + self.start_index = bmt.config["tp_rank"] * self.vocab_size_per_partition + self.end_index = (bmt.config["tp_rank"] + 1) * self.vocab_size_per_partition + self.weight = bmt.DistributedParameter( + torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype), + init_method=bmt.ParameterInitializer( + torch.nn.init.normal_, mean=init_mean, std=init_std + ), + tp_split_dim=0, + tp_mode=True, + ) + + def forward(self, x: torch.Tensor, projection=False): + if not projection: + weight = all_gather(self.weight, comm=config["tp_comm"]).flatten(0, 1) + out = F.embedding(x, weight) + return out + else: + x = bmt.distributed.all_gather(x, comm=bmt.config["tp_comm"]).view( + x.shape[0], -1, x.shape[-1] + ) + return bmt.nn.OpParallelLinear.apply( + x, self.weight, None, False, False, False, None, 1 + ) diff --git a/examples/BMTrain/bmtrain/nn/parallel_linear_func.py b/examples/BMTrain/bmtrain/nn/parallel_linear_func.py new file mode 100644 index 00000000..e389cde6 --- /dev/null +++ b/examples/BMTrain/bmtrain/nn/parallel_linear_func.py @@ -0,0 +1,352 @@ +import torch +import torch.nn.functional as F +from bmtrain.global_var import config +from ..distributed import all_gather, all_reduce +from .. import nccl +import bmtrain as bmt +from enum import Enum + + +class ReduceType(Enum): + ALL_REDUCE = 1 + REDUCE_SCATTER = 2 + + +def preprocess_input(input, gather_input, split_input): + if gather_input: + input = all_gather(input, config["tp_comm"]) + input = input.flatten(0, 1) + + if split_input: + all_input_list = input.chunk(config["tp_size"], dim=-1) + input = all_input_list[config["topology"].tp_id] + return input + + +def async_all_gather_linear_func(input, weight, bias, async_chunks=2): + dim = input.dim() + shape = list(input.shape) + if dim > 2: + input = input.view(-1, input.shape[-1]) + tp_size = config["tp_size"] + current_stream = torch.cuda.current_stream() + comm_stream = config["tp_comm_stream"] + + rounds = async_chunks + inputs = input.chunk(rounds, dim=0) + comm_stream.wait_stream(current_stream) + outputs = [None] * tp_size * rounds + + input = all_gather(inputs[0], config["tp_comm"]) + input = input.flatten(0, 1) + out = F.linear(input, weight, bias) + outs = out.chunk(tp_size, dim=0) + for i in range(tp_size): + outputs[i * rounds] = outs[i] + + # async all_gather and overalap with linear + for i in range(rounds - 1): + with torch.cuda.stream(comm_stream): + inputs[i + 1].record_stream(comm_stream) + input = all_gather(inputs[i + 1], config["tp_comm"]) + input = input.flatten(0, 1) + + current_stream.wait_stream(comm_stream) + out = F.linear(input, weight, bias) + outs = out.chunk(tp_size, dim=0) + for j in range(tp_size): + outputs[(i + 1) + j * rounds] = outs[j] + + out = torch.cat(outputs, dim=0) + if dim > 2: + out_shape = list(out.shape) + shape[-1] = out_shape[-1] + shape[0] = shape[0] * tp_size + out = out.view(shape) + return out + + +def async_reduce_scatter_linear_func(input, weight, bias, async_chunks=2): + tp_size = config["tp_size"] + comm_stream = config["tp_comm_stream"] + rounds = async_chunks + input_shape = list(input.shape) + dim = input.dim() + if dim > 2: + input = input.view(-1, input.shape[-1]) + inputs = input.chunk(rounds * tp_size, dim=0) + current_stream = torch.cuda.current_stream() + + outputs = [None] * rounds + for i in range(rounds): + input = [None] * tp_size + for j in range(tp_size): + input[j] = inputs[j * rounds + i] + input = torch.cat(input, dim=0) + out = F.linear(input, weight, bias) + with torch.cuda.stream(comm_stream): + comm_stream.wait_stream(current_stream) + out.record_stream(comm_stream) + shape = list(out.shape) + shape[0] = shape[0] // config["tp_size"] + outputs[i] = torch.empty(shape, dtype=out.dtype, device=out.device) + nccl.reduceScatter( + out.storage(), outputs[i].storage(), "sum", config["tp_comm"] + ) + + current_stream.wait_stream(comm_stream) + out = torch.cat(outputs, dim=0) + if dim > 2: + out_shape = list(out.shape) + input_shape[-1] = out_shape[-1] + input_shape[0] = input_shape[0] // tp_size + out = out.view(input_shape) + + return out + + +def async_all_gather_linear_backward_func( + grad_out, input, weight, bias, async_chunks=2 +): + tp_size = config["tp_size"] + current_stream = torch.cuda.current_stream() + comm_stream = config["tp_comm_stream"] + input_require_grad = input.requires_grad + dim = input.dim() + input_shape = input.shape + if dim > 2: + input = input.view(-1, input_shape[-1]) + grad_out = grad_out.view(-1, grad_out.shape[-1]) + + rounds = async_chunks + grad_inputs = [None] * tp_size * rounds + grad_weights = [None] * tp_size * rounds + grad_outs = [None] * tp_size * rounds + local_grad_outs = grad_out.chunk(rounds, dim=0) + + inputs = [None] * rounds + comm_stream.wait_stream(current_stream) + if weight.requires_grad: + with torch.cuda.stream(comm_stream): + input.record_stream(comm_stream) + input_list = [None] * tp_size * rounds + tp_inputs = input.chunk(tp_size, dim=0) + for i in range(tp_size): + chunk_inputs = tp_inputs[i].chunk(rounds, dim=0) + for j in range(rounds): + input_list[j * tp_size + i] = chunk_inputs[j] + start = 0 + end = tp_size + for i in range(rounds): + inputs[i] = torch.cat(input_list[start:end], dim=0) + start = end + end += tp_size + + grad_input = grad_weight = grad_bias = None + + grad_out = all_gather(local_grad_outs[0], config["tp_comm"]) + for j in range(tp_size): + grad_outs[j * rounds] = grad_out[j] + grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n) + if input_require_grad: + grad_input = grad_out.matmul( + weight + ) # (tp_size * (m/rounds), n) * (n, k/tp_size) + tmp_grad_inputs = grad_input.chunk(tp_size, dim=0) + for j in range(tp_size): + grad_inputs[j * rounds] = tmp_grad_inputs[j] + + if weight.requires_grad: + grad_weight = ( + grad_out.reshape(-1, grad_out.shape[-1]) + .t() + .matmul(inputs[0].reshape(-1, inputs[0].shape[-1])) + ) + + # async all_gather and overalap with matmul + for i in range(rounds - 1): + with torch.cuda.stream(comm_stream): + local_grad_outs[i + 1].record_stream(comm_stream) + grad_out = all_gather(local_grad_outs[i + 1], config["tp_comm"]) + for j in range(tp_size): + grad_outs[j * rounds + i + 1] = grad_out[j] + grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n) + + current_stream.wait_stream(comm_stream) + if input_require_grad: + grad_input = grad_out.matmul( + weight + ) # (tp_size * (m/rounds), n) * (n, k/tp_size) + tmp_grad_inputs = grad_input.chunk(tp_size, dim=0) + for j in range(tp_size): + grad_inputs[j * rounds + i + 1] = tmp_grad_inputs[j] + + if weight.requires_grad: + dim = grad_out.dim() + grad_weight += ( + grad_out.reshape(-1, grad_out.shape[-1]) + .t() + .matmul(inputs[i + 1].reshape(-1, inputs[i + 1].shape[-1])) + ) + + if input_require_grad: + grad_input = torch.cat(grad_inputs, dim=0) + grad_input = grad_input.view(input_shape) + + if bias is not None and bias.requires_grad: + grad_out = torch.cat(grad_outs, dim=0) + grad_bias = grad_out.reshape(-1, grad_out.shape[-1]).sum(0) + + return grad_input, grad_weight, grad_bias + + +class OpParallelLinear(torch.autograd.Function): + """OpParallelLinear is a subclass of torch.autograd.Function. + It gathers the input tensor when needed, and all reduce or reduece scatter the output when needed. + + """ + + @staticmethod + def forward( + ctx, + input, + weight, + bias=None, + gather_input=False, + gather_output=False, + split_input=False, + reduce_output_type=None, + async_gather_chunks=2, + ): + if reduce_output_type is not None: + reduce_output_type = ReduceType(reduce_output_type) + + ctx.save_for_backward(input, weight, bias) + ctx.gather_output = gather_output + ctx.split_input = split_input + ctx.gather_input = gather_input + ctx.reduce_output_type = reduce_output_type + ctx.async_gather_chunks = async_gather_chunks + + if ( + gather_input + and config["tp_size"] > 1 + and async_gather_chunks > 1 + and split_input == False + ): + out = async_all_gather_linear_func(input, weight, bias, async_gather_chunks) + elif reduce_output_type == ReduceType.REDUCE_SCATTER: + return async_reduce_scatter_linear_func( + input, weight, bias, async_gather_chunks + ) + else: + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + out = F.linear(all_input, weight, bias) + + if gather_output: + all_output_list = all_gather(out, config["tp_comm"]) + all_output_list = all_output_list.chunk(config["tp_size"], dim=0) + out = torch.cat(all_output_list, dim=all_output_list[0].dim() - 1).flatten( + 0, 1 + ) + + if reduce_output_type is None: + return out + + if reduce_output_type == ReduceType.ALL_REDUCE: + nccl.allReduce(out.storage(), out.storage(), "sum", config["tp_comm"]) + return out + else: + assert False, "no support reduce type{}".format(reduce_output_type) + + @staticmethod + def backward(ctx, grad_output): + input, weight, bias = ctx.saved_tensors + gather_output = ctx.gather_output + + if ctx.reduce_output_type == ReduceType.REDUCE_SCATTER: + if input.requires_grad or weight.requires_grad: + grad_input, grad_weight, grad_bias = ( + async_all_gather_linear_backward_func( + grad_output, input, weight, bias, ctx.async_gather_chunks + ) + ) + return grad_input, grad_weight, grad_bias, None, None, None, None, None + else: + grad_output = all_gather(grad_output, config["tp_comm"]) + grad_output = grad_output.flatten(0, 1) + + if gather_output: + tp_size = config["tp_size"] + tp_id = config["topology"].tp_id + grad_output_list = grad_output.chunk(tp_size, dim=-1) + grad_output = grad_output_list[tp_id] + + grad_input = grad_weight = grad_bias = None + + current_stream = torch.cuda.current_stream() + if input.requires_grad or weight.requires_grad: + if ctx.gather_input: + # async the all_gather + with torch.cuda.stream(config["tp_comm_stream"]): + input.record_stream(config["tp_comm_stream"]) + config["tp_comm_stream"].wait_stream(current_stream) + all_input = preprocess_input( + input, ctx.gather_input, ctx.split_input + ) + # use event to solve two streams waiting for each other + gather_event = config["tp_comm_stream"].record_event() + else: + all_input = preprocess_input(input, ctx.gather_input, ctx.split_input) + + if input.requires_grad: + grad_all_input = grad_output.matmul(weight) + grad_input = torch.zeros_like(input) + if ctx.gather_input: + # async the reduce_scatter + with torch.cuda.stream(config["tp_comm_stream"]): + config["tp_comm_stream"].wait_stream(current_stream) + grad_input.record_stream(config["tp_comm_stream"]) + grad_all_input.record_stream(config["tp_comm_stream"]) + nccl.reduceScatter( + grad_all_input.storage(), + grad_input.storage(), + "sum", + config["tp_comm"], + ) + elif ctx.reduce_output_type is None: + with torch.cuda.stream(config["tp_comm_stream"]): + config["tp_comm_stream"].wait_stream(current_stream) + grad_input.record_stream(config["tp_comm_stream"]) + nccl.allReduce( + grad_all_input.storage(), + grad_all_input.storage(), + "sum", + config["tp_comm"], + ) + grad_input = grad_all_input + else: + grad_input = grad_all_input + + if ctx.split_input: + with torch.cuda.stream(config["tp_comm_stream"]): + config["tp_comm_stream"].wait_stream(current_stream) + grad_input.record_stream(config["tp_comm_stream"]) + grad_input = all_gather(grad_input, config["tp_comm"]) + + # wait all_gather + if ctx.gather_input: + current_stream.wait_event(gather_event) + if weight.requires_grad: + grad_weight = ( + grad_output.reshape(-1, grad_output.shape[-1]) + .t() + .matmul(all_input.reshape(-1, all_input.shape[-1])) + ) + + if bias is not None and bias.requires_grad: + grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0) + + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config["tp_comm_stream"]) + return grad_input, grad_weight, grad_bias, None, None, None, None, None diff --git a/examples/BMTrain/bmtrain/nn/row_parallel_linear.py b/examples/BMTrain/bmtrain/nn/row_parallel_linear.py new file mode 100644 index 00000000..ee4610cc --- /dev/null +++ b/examples/BMTrain/bmtrain/nn/row_parallel_linear.py @@ -0,0 +1,88 @@ +import torch +from torch.nn.parameter import Parameter + +import bmtrain as bmt +from bmtrain.global_var import config +from .parallel_linear_func import OpParallelLinear, ReduceType + + +class RowParallelLinear(bmt.DistributedModule): + """Tensor Parallel use row partition for Linear. + + Args: + in_features (int): in_features size. + out_features (int): out_features size. + bias (bool): whether use bias. + dtype : data type. + split_input (bool): whether split input before compute. + all_reduce_output (bool): if true use all_reduce data after compute, or use reduce_scatter. + async_chunks (int): chunk size for async. + + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + dtype=None, + split_input=False, + all_reduce_output=False, + async_chunks=2, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.split_input = split_input + self.all_reduce_output = all_reduce_output + self.async_chunks = async_chunks + tp_size = config["tp_size"] + assert in_features % tp_size == 0 + self.in_features_per_partition = in_features // tp_size + self.weight = bmt.DistributedParameter( + torch.empty( + self.out_features, + self.in_features_per_partition, + dtype=dtype, + device="cuda", + ), + init_method=torch.nn.init.xavier_normal_, + tp_split_dim=1, + tp_mode=True, + ) + if bias: + self.bias = bmt.DistributedParameter( + torch.empty(self.out_features, dtype=dtype, device="cuda"), + init_method=torch.nn.init.zeros_, + tp_split_dim=-1, + tp_mode=True, + ) + else: + self.register_parameter("bias", None) + + def forward(self, input): + gather_input = self.split_input + gather_output = False + reduce_output_type = ( + ReduceType.ALL_REDUCE + if self.all_reduce_output + else ReduceType.REDUCE_SCATTER + ) + out = OpParallelLinear.apply( + input, + self.weight, + None, + gather_input, + gather_output, + self.split_input, + reduce_output_type, + self.async_chunks, + ) + if self.bias is not None: + out = out + self.bias + return out + + def extra_repr(self) -> str: + return "in_features={}, out_features={}, bias={}".format( + self.in_features_per_partition, self.out_features, self.bias is not None + ) diff --git a/examples/BMTrain/bmtrain/optim/__init__.py b/examples/BMTrain/bmtrain/optim/__init__.py new file mode 100644 index 00000000..15206328 --- /dev/null +++ b/examples/BMTrain/bmtrain/optim/__init__.py @@ -0,0 +1,3 @@ +from .adam import AdamOptimizer +from .adam_offload import AdamOffloadOptimizer +from .optim_manager import OptimManager \ No newline at end of file diff --git a/examples/BMTrain/bmtrain/optim/_distributed.py b/examples/BMTrain/bmtrain/optim/_distributed.py new file mode 100644 index 00000000..df8f2f3e --- /dev/null +++ b/examples/BMTrain/bmtrain/optim/_distributed.py @@ -0,0 +1,40 @@ +import torch +from ..distributed import all_reduce, all_gather + + +def state_dict_gather(state_dict): + param_key = [ + p for param_group in state_dict["param_groups"] for p in param_group["params"] + ] + for k, v in state_dict["state"].items(): + if "step" in v: + step = v["step"] + + for k in param_key: + if k not in state_dict["state"]: + state_dict["state"][k] = { + "exp_avg": torch.tensor([], device="cuda", dtype=torch.float32), + "exp_avg_sq": torch.tensor([], device="cuda", dtype=torch.float32), + "_param_fp32": torch.tensor([], device="cuda", dtype=torch.float32), + "step": step, + } + v = state_dict["state"][k] + for name, dtype in [ + ("exp_avg", torch.float32), + ("exp_avg_sq", torch.float32), + ("_param_fp32", torch.float32), + ]: + if name in v: + with torch.no_grad(): + numel = torch.tensor( + v[name].numel(), device="cuda", dtype=torch.long + ) + max_numel = all_reduce(numel, op="max") + v_p = torch.nn.functional.pad( + v[name], (0, max_numel - numel), value=-1e15 + ) + if max_numel > 0: + whole_state = all_gather(v_p.cuda()).flatten() + whole_state = whole_state[whole_state != -1e15] + v[name] = whole_state.contiguous().cpu() + return state_dict diff --git a/examples/BMTrain/bmtrain/optim/_function.py b/examples/BMTrain/bmtrain/optim/_function.py new file mode 100644 index 00000000..f9e0ce9d --- /dev/null +++ b/examples/BMTrain/bmtrain/optim/_function.py @@ -0,0 +1,218 @@ +from .. import C +import torch + +CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda + + +def bf16_from_fp32(param_fp32): + param_bf16 = torch.empty_like(param_fp32, dtype=torch.bfloat16) + C.to_bf16_from_fp32( + param_fp32.numel(), param_fp32.data_ptr(), param_bf16.data_ptr() + ) + return param_bf16 + + +def fp16_from_fp32(param_fp32): + param_fp16 = torch.empty_like(param_fp32, dtype=torch.float16) + C.to_fp16_from_fp32( + param_fp32.numel(), param_fp32.data_ptr(), param_fp16.data_ptr() + ) + return param_fp16 + + +def adam_cpu( + param_fp32: torch.Tensor, + param_fp16: torch.Tensor, + delta_info: torch.Tensor, + g_fp16: torch.Tensor, + m_fp32: torch.Tensor, + v_fp32: torch.Tensor, + beta1: float, + beta2: float, + eps: float, + lr: float, + scale: float, + weight_decay: float, + step: int, +) -> None: + assert param_fp32.is_contiguous(), "param_fp32 must be contiguous" + assert param_fp16.is_contiguous(), "param_fp16 must be contiguous" + assert g_fp16.is_contiguous(), "g_fp16 must be contiguous" + assert m_fp32.is_contiguous(), "m_fp32 must be contiguous" + assert v_fp32.is_contiguous(), "v_fp32 must be contiguous" + assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor" + assert ( + param_fp16.dtype == torch.float16 or param_fp16.dtype == torch.bfloat16 + ), "param_fp16 must be float16/bfloat16 tensor" + assert ( + g_fp16.dtype == torch.float16 or g_fp16.dtype == torch.bfloat16 + ), "g_fp16 must be float16/bfloat16 tensor" + assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor" + assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" + assert param_fp32.device == torch.device("cpu"), "param_fp32 must be a cpu tensor" + assert param_fp16.device == torch.device("cpu"), "param_fp16 must be a cpu tensor" + assert g_fp16.device == torch.device("cpu"), "g_fp16 must be a cpu tensor" + assert m_fp32.device == torch.device("cpu"), "m_fp32 must be a cpu tensor" + assert v_fp32.device == torch.device("cpu"), "v_fp32 must be a cpu tensor" + assert ( + param_fp32.numel() == param_fp16.numel() + ), "param_fp32 and param_fp16 must have the same number of elements" + assert ( + param_fp32.numel() == g_fp16.numel() + ), "param_fp32 and g_fp16 must have the same number of elements" + assert ( + param_fp32.numel() == m_fp32.numel() + ), "param_fp32 and m_fp32 must have the same number of elements" + assert ( + param_fp32.numel() == v_fp32.numel() + ), "param_fp32 and v_fp32 must have the same number of elements" + if delta_info is not None: + assert delta_info.is_contiguous(), "delta_info must be contiguous" + assert delta_info.dtype == torch.float32, "delta_info must be float32 tensor" + assert delta_info.device == torch.device( + "cpu" + ), "delta_info must be a cpu tensor" + assert delta_info.numel() == 4, "delta_info have a length of 4" + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + if g_fp16.dtype == torch.float16: + launcher = C.adam_cpu_fp16_launcher + elif g_fp16.dtype == torch.bfloat16: + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + launcher = C.adam_cpu_bf16_launcher + launcher( + param_fp32.numel(), + param_fp32.data_ptr(), + param_fp16.data_ptr(), + delta_info.data_ptr() if delta_info is not None else 0, + g_fp16.data_ptr(), + m_fp32.data_ptr(), + v_fp32.data_ptr(), + beta1, + beta2, + eps, + lr, + scale, + weight_decay, + bias_correction1, + bias_correction2, + ) + + +def adam_fp16( + param_fp32: torch.Tensor, + param_fp16: torch.Tensor, + g_fp16: torch.Tensor, + m_fp16: torch.Tensor, + v_fp32: torch.Tensor, + beta1: float, + beta2: float, + eps: float, + lr: float, + scale: float, + weight_decay: float, + step: int, +) -> None: + assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda" + assert CHECK_INPUT(param_fp16), "param_fp16 must be contiguous and on cuda" + assert CHECK_INPUT(g_fp16), "g_fp16 must be contiguous and on cuda" + assert CHECK_INPUT(m_fp16), "m_fp32 must be contiguous and on cuda" + assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda" + assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor" + assert param_fp16.dtype == torch.float16, "param_fp16 must be float16 tensor" + assert g_fp16.dtype == torch.float16, "g_fp16 must be float16 tensor" + assert m_fp16.dtype == torch.float16, "m_fp16 must be float16 tensor" + assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" + assert ( + param_fp32.numel() == param_fp16.numel() + ), "param_fp32 and param_fp16 must have the same number of elements" + assert ( + param_fp32.numel() == g_fp16.numel() + ), "param_fp32 and g_fp16 must have the same number of elements" + assert ( + param_fp32.numel() == m_fp16.numel() + ), "param_fp32 and m_fp32 must have the same number of elements" + assert ( + param_fp32.numel() == v_fp32.numel() + ), "param_fp32 and v_fp32 must have the same number of elements" + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + stream = torch.cuda.current_stream().cuda_stream + C.adam_fp16_launcher( + param_fp32.numel(), + param_fp32.data_ptr(), + param_fp16.data_ptr(), + g_fp16.data_ptr(), + m_fp16.data_ptr(), + v_fp32.data_ptr(), + beta1, + beta2, + eps, + lr, + scale, + weight_decay, + bias_correction1, + bias_correction2, + stream, + ) + + +def adam_bf16( + param_fp32: torch.Tensor, + param_bf16: torch.Tensor, + g_bf16: torch.Tensor, + m_fp32: torch.Tensor, + v_fp32: torch.Tensor, + beta1: float, + beta2: float, + eps: float, + lr: float, + scale: float, + weight_decay: float, + step: int, +) -> None: + assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda" + assert CHECK_INPUT(param_bf16), "param_bf16 must be contiguous and on cuda" + assert CHECK_INPUT(g_bf16), "g_bf16 must be contiguous and on cuda" + assert CHECK_INPUT(m_fp32), "m_fp32 must be contiguous and on cuda" + assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda" + assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor" + assert param_bf16.dtype == torch.bfloat16, "param_fp16 must be float16 tensor" + assert g_bf16.dtype == torch.bfloat16, "g_bf16 must be bfloat16 tensor" + assert m_fp32.dtype == torch.float32, "m_fp32 must be bfloat16 tensor" + assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor" + assert ( + param_fp32.numel() == param_bf16.numel() + ), "param_fp32 and param_bf16 must have the same number of elements" + assert ( + param_fp32.numel() == g_bf16.numel() + ), "param_fp32 and g_fp16 must have the same number of elements" + assert ( + param_fp32.numel() == m_fp32.numel() + ), "param_fp32 and m_m_fp32 must have the same number of elements" + assert ( + param_fp32.numel() == v_fp32.numel() + ), "param_fp32 and v_fp32 must have the same number of elements" + bias_correction1 = 1 - beta1**step + bias_correction2 = 1 - beta2**step + stream = torch.cuda.current_stream().cuda_stream + if not C.is_bf16_supported(): + raise NotImplementedError(f"bfloat16 is not supported on current GPU") + C.adam_bf16_launcher( + param_fp32.numel(), + param_fp32.data_ptr(), + param_bf16.data_ptr(), + g_bf16.data_ptr(), + m_fp32.data_ptr(), + v_fp32.data_ptr(), + beta1, + beta2, + eps, + lr, + scale, + weight_decay, + bias_correction1, + bias_correction2, + stream, + ) diff --git a/examples/BMTrain/bmtrain/optim/adam.py b/examples/BMTrain/bmtrain/optim/adam.py new file mode 100644 index 00000000..f99c483c --- /dev/null +++ b/examples/BMTrain/bmtrain/optim/adam.py @@ -0,0 +1,252 @@ +import torch +from ..global_var import config +from . import _function as F +import torch.optim._functional +from .. import C +from .. import nccl +import inspect +from ..utils import check_torch_version +from copy import deepcopy +from itertools import chain +from collections import defaultdict + + +class AdamOptimizer(torch.optim.Optimizer): + """ + Adam optimizer support fp16 and bf16. + """ + + _bmtrain_optimizer = True + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + hold_steps=0, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super().__init__(params, defaults) + + self._hold_steps = hold_steps + + def _on_justify_scale(self, old_scale, new_scale): + delta = new_scale / old_scale + for group in self.param_groups: + for p in group["params"]: + if p in self.state: + state = self.state[p] + if len(state) > 0: + if p.dtype == torch.float16: + state["exp_avg"] *= delta + state["exp_avg_sq"] *= delta + + @torch.no_grad() + def step(self, closure=None, scale=1): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. + """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # update parameters + for group in self.param_groups: + for p in group["params"]: + if p.grad is not None and p.requires_grad: + if p.grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + if p.dtype not in [torch.float32, torch.half, torch.bfloat16]: + raise RuntimeError( + "Adam only supports fp32, fp16 and bf16 gradients" + ) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + if p.dtype == torch.float16: + state["exp_avg"] = torch.zeros( + p.size(), dtype=torch.float16, device=p.device + ) # on device + else: + state["exp_avg"] = torch.zeros( + p.size(), dtype=torch.float32, device=p.device + ) # on device + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros( + p.size(), dtype=torch.float32, device=p.device + ) # on device + + if p.dtype != torch.float32: + state["_param_fp32"] = torch.empty( + p.size(), dtype=torch.float32, device=p.device + ) # on device + state["_param_fp32"].copy_(p) + + # update the steps for each param group update + if ("maximize" in group) and (group["maximize"] is True): + grad = -p.grad + else: + grad = p.grad + + if p.dtype == torch.float32: + other_kwargs = {} + if ( + "maximize" + in inspect.signature( + torch.optim._functional.adam + ).parameters + ): + other_kwargs["maximize"] = False + torch.optim._functional.adam( + [p], + [grad / scale], + [state["exp_avg"]], + [state["exp_avg_sq"]], + [], + ( + [state["step"]] + if check_torch_version("1.12.0") < 0 + else [torch.tensor(state["step"])] + ), + amsgrad=False, + beta1=group["betas"][0], + beta2=group["betas"][1], + lr=0.0 if state["step"] < self._hold_steps else group["lr"], + weight_decay=group["weight_decay"], + eps=group["eps"], + **other_kwargs + ) + state["step"] += 1 + else: + f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16 + state["step"] += 1 + f( + state["_param_fp32"], # fp32 + p, # fp16 + grad, # fp16 + state["exp_avg"], # fp16: m + state["exp_avg_sq"], # fp32: v + group["betas"][0], + group["betas"][1], + group["eps"], + 0.0 if state["step"] < self._hold_steps else group["lr"], + scale, + group["weight_decay"], + state["step"], + ) + + return loss + + def get_avg_delta(): + + raise NotImplementedError( + "get delta info is not supported in Adam optimizer , try bmt.optim.AdamOffloadOptimizer" + ) + + def get_var_delta(): + + raise NotImplementedError( + "get delta info is not supported in Adam optimizer , try bmt.optim.AdamOffloadOptimizer" + ) + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = deepcopy(state_dict) + # Validate the state_dict + groups = self.param_groups + saved_groups = state_dict["param_groups"] + + if len(groups) != len(saved_groups): + raise ValueError( + "loaded state dict has a different number of " "parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = { + old_id: p + for old_id, p in zip( + chain.from_iterable((g["params"] for g in saved_groups)), + chain.from_iterable((g["params"] for g in groups)), + ) + } + + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state = defaultdict(dict) + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + + if param.dtype != torch.float32 and "_param_fp32" not in v: + v["_param_fp32"] = torch.empty( + param.size(), dtype=torch.float32, device=param.device + ) + v["_param_fp32"].copy_(param) + + for name, dtype in [ + ( + "exp_avg", + ( + torch.float16 + if param.dtype == torch.float16 + else torch.float32 + ), + ), + ("exp_avg_sq", torch.float32), + ("_param_fp32", torch.float32), + ]: + if name in v: + v[name] = v[name].to(param.device).to(dtype) + + state[param] = v + else: + state[k] = v + + # Update parameter groups, setting their 'params' value + def update_group(group, new_group): + new_group["params"] = group["params"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + # TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu + def zero_grad(self, set_to_none: bool = False): + super().zero_grad(set_to_none=set_to_none) diff --git a/examples/BMTrain/bmtrain/optim/adam_offload.py b/examples/BMTrain/bmtrain/optim/adam_offload.py new file mode 100644 index 00000000..f6ea97ba --- /dev/null +++ b/examples/BMTrain/bmtrain/optim/adam_offload.py @@ -0,0 +1,386 @@ +import torch +from ..global_var import config +from . import _function as F +from .. import nccl +import inspect +from ..utils import check_torch_version +from copy import deepcopy +from itertools import chain +from collections import defaultdict +from ._distributed import state_dict_gather + + +class AdamOffloadOptimizer(torch.optim.Optimizer): + """ + Adam optimizer using optimizer offload. + """ + + _bmtrain_optimizer = True + + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + hold_steps=0, + record_delta=False, + ): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= weight_decay: + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + self.avg_delta = 0 + self.var_delta = 0 + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + super().__init__(params, defaults) + self._hold_steps = hold_steps + self._events = {} + self.record_delta = record_delta + if self.record_delta: + for group in self.param_groups: + for p in group["params"]: + setattr( + p, + "_delta_info", + ( + torch.tensor( + [0 for i in range(4)], dtype=torch.float32, device="cpu" + ) + ), + ) + + @torch.no_grad() + def step(self, closure=None, scale=1): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes. + """ + + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + # parameters to be updated + update_params = [] + + for group in self.param_groups: + for p in group["params"]: + if p.grad is not None and p.requires_grad: + if p.grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + if p.dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise RuntimeError( + "Adam only supports fp32, fp16 and bf16 gradients" + ) + + state = self.state[p] + # Lazy state initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros( + p.size(), dtype=torch.float32, device="cpu" + ) # on host + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros( + p.size(), dtype=torch.float32, device="cpu" + ) # on host + + if p.dtype == torch.float32: + state["_param_fp32"] = torch.empty( + p.size(), dtype=torch.float32, pin_memory=True + ) # on host + state["_param_fp32"].copy_(p) + + # placeholder + state["_grad_fp32"] = torch.empty( + p.size(), dtype=torch.float32, pin_memory=True + ) # on host + else: + state["_param_fp32"] = torch.empty( + p.size(), dtype=torch.float32, device="cpu" + ) # on host + state["_param_fp32"].copy_(p) + + # placeholder + state["_param_fp16"] = torch.empty( + p.size(), dtype=p.dtype, pin_memory=True + ) # on host + state["_grad_fp16"] = torch.empty( + p.size(), dtype=p.dtype, pin_memory=True + ) # on host + + if p not in self._events: + self._events[p] = torch.cuda.Event() + + update_params.append( + ( + p, + state, + self._events[p], + group["betas"][0], + group["betas"][1], + group["eps"], + group["lr"], + group["weight_decay"], + ) + ) + + # transfer parameters to host asynchronously + for param, state, event, _, _, _, _, _ in update_params: + if param.dtype == torch.float32: + state["_grad_fp32"].copy_(param.grad, non_blocking=True) + else: + state["_grad_fp16"].copy_(param.grad, non_blocking=True) + torch.cuda.current_stream().record_event(event) + sum_delta = 0 + sum_sq_delta = 0 + total_numel = 0 + for param, state, event, beta1, beta2, eps, lr, weight_decay in update_params: + # wait for transfer to host + event.synchronize() + + # update parameters + if param.dtype == torch.float32: + state["_grad_fp32"].mul_(1.0 / scale) + if ("maximize" in group) and (group["maximize"] is True): + grad = -state["_grad_fp32"] + else: + grad = state["_grad_fp32"] + other_kwargs = {} + if ( + "maximize" + in inspect.signature(torch.optim._functional.adam).parameters + ): + other_kwargs["maximize"] = False + torch.optim._functional.adam( + [state["_param_fp32"]], + [grad], + [state["exp_avg"]], + [state["exp_avg_sq"]], + [], + ( + [state["step"]] + if check_torch_version("1.12.0") < 0 + else [torch.tensor(state["step"])] + ), + amsgrad=False, + beta1=beta1, + beta2=beta2, + lr=0.0 if state["step"] < self._hold_steps else lr, + weight_decay=weight_decay, + eps=eps, + **other_kwargs + ) + # transfer parameters back to device asynchronously + param.copy_(state["_param_fp32"], non_blocking=True) + state["step"] += 1 + else: + state["step"] += 1 + if ("maximize" in group) and (group["maximize"] is True): + grad = -state["_grad_fp16"] + else: + grad = state["_grad_fp16"] + F.adam_cpu( + state["_param_fp32"].view(-1), + state["_param_fp16"].view(-1), + param._delta_info if self.record_delta else None, + grad.view(-1), + state["exp_avg"].view(-1), + state["exp_avg_sq"].view(-1), + beta1, + beta2, + eps, + 0.0 if state["step"] < self._hold_steps else lr, + scale, + weight_decay, + state["step"], + ) + total_numel += state["_param_fp16"].numel() + if self.record_delta: + sum_delta += param._delta_info[2].item() + sum_sq_delta += param._delta_info[3].item() + # transfer parameters back to device asynchronously + param.copy_(state["_param_fp16"], non_blocking=True) + if self.record_delta: + self.avg_delta = sum_delta / total_numel + self.var_delta = sum_sq_delta / total_numel - self.avg_delta**2 + + return loss + + def get_avg_delta(self) -> None: + return self.avg_delta if self.record_delta else 0 + + def get_var_delta(self) -> None: + return self.var_delta if self.record_delta else 0 + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the optimizer state. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + + state_dict = deepcopy(state_dict) + # Validate the state_dict + groups = self.param_groups + saved_groups = state_dict["param_groups"] + + if len(groups) != len(saved_groups): + raise ValueError( + "loaded state dict has a different number of " "parameter groups" + ) + param_lens = (len(g["params"]) for g in groups) + saved_lens = (len(g["params"]) for g in saved_groups) + if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)): + raise ValueError( + "loaded state dict contains a parameter group " + "that doesn't match the size of optimizer's group" + ) + + # Update the state + id_map = { + old_id: p + for old_id, p in zip( + chain.from_iterable((g["params"] for g in saved_groups)), + chain.from_iterable((g["params"] for g in groups)), + ) + } + + # _param_start_end = chain.from_iterable((g["params_start_end"] for g in saved_groups)) + # Copy state assigned to params (and cast tensors to appropriate types). + # State that is not assigned to params is copied as is (needed for + # backward compatibility). + state = defaultdict(dict) + is_whole = False if "is_whole" not in state_dict else state_dict["is_whole"] + pop_key = [] + for k, v in state_dict["state"].items(): + if k in id_map: + param = id_map[k] + if is_whole and param._start_partition is not None: + for key in ["_param_fp32", "exp_avg_sq", "exp_avg"]: + if key in v: + v[key] = v[key][ + param._start_partition : param._end_partition + ] + elif is_whole and param._start_partition is None: + pop_key.append(param) + + if "_param_fp32" not in v: + with torch.no_grad(): + v["_param_fp32"] = torch.empty( + param.size(), dtype=torch.float32, device="cpu" + ) + v["_param_fp32"].copy_(param) + + for name, dtype in [ + ("exp_avg", torch.float32), + ("exp_avg_sq", torch.float32), + ("_param_fp32", torch.float32), + ]: + if name in v: + v[name] = v[name].to("cpu").to(dtype) + + state[param] = v + if param.dtype == torch.float32: + state[param]["_param_fp32"] = state[param][ + "_param_fp32" + ].pin_memory() # on host + # initialize placeholders + state[param]["_grad_fp32"] = torch.empty( + param.size(), dtype=torch.float32, pin_memory=True + ) # on host + else: + # initialize placeholders + state[param]["_param_fp16"] = torch.empty( + param.size(), dtype=param.dtype, pin_memory=True + ) # on host + state[param]["_grad_fp16"] = torch.empty( + param.size(), dtype=param.dtype, pin_memory=True + ) # on host + else: + state[k] = v + for k in pop_key: + state.pop(k) + + # Update parameter groups, setting their 'params' value + def update_group(group, new_group): + new_group["params"] = group["params"] + return new_group + + param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] + self.__setstate__({"state": state, "param_groups": param_groups}) + + def state_dict(self, gather=False) -> dict: + r"""Returns the state of the optimizer as a :class:`dict`. + + It contains two entries: + + * state - a dict holding current optimization state. Its content + differs between optimizer classes. + * param_groups - a list containing all parameter groups where each + parameter group is a dict + """ + + # Save order indices instead of Tensors + param_mappings = {} + start_index = 0 + + def pack_group(group): + nonlocal start_index + packed = {k: v for k, v in group.items() if k != "params"} + param_mappings.update( + { + id(p): i + for i, p in enumerate(group["params"], start_index) + if id(p) not in param_mappings + } + ) + packed["params"] = [param_mappings[id(p)] for p in group["params"]] + start_index += len(packed["params"]) + return packed + + def cut_states(state): + return { + "step": state["step"], + "exp_avg": state["exp_avg"], + "exp_avg_sq": state["exp_avg_sq"], + "_param_fp32": state["_param_fp32"], + } + + param_groups = [pack_group(g) for g in self.param_groups] + # Remap state to use order indices as keys + packed_state = { + (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): cut_states(v) + for k, v in self.state.items() + } + states = { + "state": packed_state, + "param_groups": param_groups, + } + if gather: + states = state_dict_gather(states) + states["is_whole"] = True + else: + states["is_whole"] = False + + return states + + # TODO zero_grad(set_to_none=True) makes optimizer crashed, maybe the reason of grad accu + def zero_grad(self, set_to_none: bool = False): + super().zero_grad(set_to_none=set_to_none) diff --git a/examples/BMTrain/bmtrain/optim/optim_manager.py b/examples/BMTrain/bmtrain/optim/optim_manager.py new file mode 100644 index 00000000..1a98ed92 --- /dev/null +++ b/examples/BMTrain/bmtrain/optim/optim_manager.py @@ -0,0 +1,226 @@ +from typing import Optional, Union, List, Dict, Tuple +import torch +from ..loss._function import has_inf_nan +from ..utils import print_rank +from ..lr_scheduler.warmup import WarmupLRScheduler +from .. import nccl +from ..global_var import config + +def check_overflow(param_groups): + # check overflow + has_inf_or_nan = torch.zeros(1, dtype=torch.uint8, device="cuda")[0] + for group in param_groups: + for p in group['params']: + if p.grad is not None: + if p.dtype != torch.float: + has_inf_nan(p.grad, has_inf_or_nan) + if "comm" in config: + nccl.allReduce(has_inf_or_nan.storage(), has_inf_or_nan.storage(), "max", config["comm"]) + + if has_inf_or_nan > 0: + raise OverflowError("Gradient overflow") + +def grad_rescale(param_groups, scale): + for group in param_groups: + for p in group['params']: + if p.grad is not None and p.requires_grad: + p.grad /= scale + +class OptimManager: + """wait cuda stream. Optional: add loss scaler for mix-precision training + + Args: + loss_scale (float): The initial loss scale. Default to None for not using loss scaling. + loss_scale_factor (float): The loss scale factor. + loss_scale_steps (int): The loss scale steps. + + Examples: + >>> optim_manager = bmt.optim.OptimManager(loss_scale=1024) + >>> optim_manager.add_optimizer(optimizer1) + >>> optim_manager.add_optimizer(optimizer2, lr_scheduler2) + >>> for data in dataset: + >>> # forward pass and calculate loss + >>> optim_manager.zero_grad() + >>> optim_manager.backward(loss) + >>> optim_manager.clip_grad_norm(optimizer1.param_groups, max_norm=1.0, norm_type=2) + >>> optim_manager.clip_grad_norm(optimizer2.param_groups, max_norm=2.0, norm_type=2) + >>> optim_manager.step() + """ + def __init__(self, + loss_scale : Optional[float] = None, + loss_scale_factor : float = 2, + loss_scale_steps : int = 1024, + min_loss_scale = 1, + max_loss_scale = float("inf"), + grad_scale : Optional[int] = None, + ): + if loss_scale is not None: + self.loss_scale = loss_scale + self.loss_scale_enabled = True + else: + self.loss_scale = 1 + self.loss_scale_enabled = False + self.steps_since_last_scale = 0 + self.loss_scale_factor = loss_scale_factor if loss_scale_factor > 1 else 1 / loss_scale_factor + self.loss_scale_steps = loss_scale_steps + self.min_loss_scale = min_loss_scale + self.max_loss_scale = max_loss_scale + if grad_scale is None: + grad_scale = config['zero_size'] + self.grad_scale = grad_scale + + self.optimizers = [] + self.lr_schedulers = [] + + def add_optimizer( + self, + optimizer: torch.optim.Optimizer, + lr_scheduler: Optional[WarmupLRScheduler] = None, + ): + """Add optimizer and (optional) its corresponding lr_scheduler into optim_manager. + All optimizers in the same optim_manager share the same loss scale. + + Args: + optim (torch.optim.Optimizer): A pytorch optimizer, e.g. torch.optim.Adam, torch.optim.SGD or bmtrain.optim.AdamOffloadOptimizer + lr_scheduler (Optional[WarmupLRScheduler]): A warmup lr scheduler, e.g. bmt.lr_scheduler.Noam + """ + self.optimizers.append(optimizer) + self.lr_schedulers.append(lr_scheduler) + + def scale_loss(self, loss : torch.Tensor) -> torch.Tensor: + + return loss * ( self.loss_scale / self.grad_scale ) # loss scale + + def backward(self, loss : torch.Tensor): + """ + Backward with loss scale. + + Args: + loss (torch.Tensor): loss + """ + loss = self.scale_loss(loss) + loss.backward() + # some reduce ops of distributed parameter were launched on load stream + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config['load_stream']) + + def zero_grad(self): + """ + This is a helper function to call optimizer.zero_grad() + """ + for optimizer in self.optimizers: + optimizer.zero_grad(set_to_none=False) + + def step(self): + """ + Backward with loss scale. + Synchronize streams before optimizer steps. + + This is a helper function to call optimizer.step() and lr_scheduler.step() and synchronize streams. + + This function can also handle gradient overflow by reducing the loss scale when it occurs. + """ + if self.loss_scale_enabled: + has_overflow = False + for optimizer in self.optimizers: + try: + check_overflow(optimizer.param_groups) + except OverflowError: + has_overflow = True + break + if has_overflow: + print_rank("Gradient overflow, change scale from %lf to %lf" % (self.loss_scale, self.loss_scale / self.loss_scale_factor)) + with torch.no_grad(): + if self.loss_scale > self.min_loss_scale: + self._justify_scale(self.loss_scale / self.loss_scale_factor) + self.zero_grad() + return + for optimizer, lr_scheduler in zip(self.optimizers, self.lr_schedulers): + if hasattr(optimizer, "_bmtrain_optimizer") and optimizer._bmtrain_optimizer: + optimizer.step(scale=self.loss_scale) + else: + if self.loss_scale_enabled: + grad_rescale(optimizer.param_groups, self.loss_scale) + optimizer.step() + + if lr_scheduler is not None: + lr_scheduler.step() + + if self.loss_scale_enabled: + self.steps_since_last_scale += 1 + + if self.steps_since_last_scale >= self.loss_scale_steps and self.loss_scale < self.max_loss_scale: + self._justify_scale(self.loss_scale * self.loss_scale_factor) + + current_stream = torch.cuda.current_stream() + config['load_stream'].wait_stream(current_stream) + + def clip_grad_norm(self, param_groups, max_norm, norm_type=2, eps=1e-6): + """Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized. + max_norm (float or int): max norm of the gradients. + norm_type (float or int): type of the used p-norm. Can be 'inf' for infinity norm. + eps (float): epsilon used to avoid zero division. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + scale = self.loss_scale + grads = [] + parameters = [p for group in param_groups for p in group['params']] + for p in parameters: + if p.grad is not None: + grads.append(p.grad.data) + else: + grads.append(torch.zeros_like(p.data)) + + if norm_type == 'inf': + total_norm_cuda = max(g.data.abs().max() for g in grads).detach() + nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "max", config["comm"]) + total_norm = total_norm_cuda + else: + norm_type = float(norm_type) + total_norm_cuda = torch.cuda.FloatTensor([0]) + for index, g in enumerate(grads): + param_norm = g.data.float().norm(norm_type) + total_norm_cuda += param_norm ** norm_type + nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "sum", config["comm"]) + total_norm = total_norm_cuda[0] ** (1. / norm_type) + # total_norm = total_norm / scale + # clip_coef = float(max_norm) / (total_norm + eps) + clip_coef = float(max_norm * scale) / (total_norm + eps) + if clip_coef < 1: + for p in parameters: + if p.grad is not None: + p.grad.data.mul_(clip_coef) + return total_norm / scale + + @torch.no_grad() + def _justify_scale(self, scale): + for optimizer in self.optimizers: + if hasattr(optimizer, "_on_justify_scale"): + optimizer._on_justify_scale(self.loss_scale, scale) + self.loss_scale = scale + self.steps_since_last_scale = 0 + + def state_dict(self, gather_opt=False) -> dict: + return { + "optimizers": [opt.state_dict(gather_opt) for opt in self.optimizers], + "lr_schedulers": [lrs.state_dict() if lrs else None for lrs in self.lr_schedulers], + "loss_scale": self.loss_scale, + "loss_scale_enabled": self.loss_scale_enabled, + } + + def load_state_dict(self, state_dict: dict) -> None: + assert len(self.optimizers) == len(state_dict["optimizers"]) + assert len(self.lr_schedulers) == len(state_dict["lr_schedulers"]) + for opt, opt_st in zip(self.optimizers, state_dict["optimizers"]): + opt.load_state_dict(opt_st) + for lrs, lrs_st in zip(self.lr_schedulers, state_dict["lr_schedulers"]): + lrs.load_state_dict(lrs_st) + self.loss_scale = state_dict["loss_scale"] + self.loss_scale_enabled = state_dict["loss_scale_enabled"] diff --git a/examples/BMTrain/bmtrain/param_init.py b/examples/BMTrain/bmtrain/param_init.py new file mode 100644 index 00000000..21f95f25 --- /dev/null +++ b/examples/BMTrain/bmtrain/param_init.py @@ -0,0 +1,105 @@ +from typing import Generator, Iterable, List, Tuple +import torch +from .block_layer import Block +from .parameter import DistributedParameter +from .global_var import config + + +def init_distributed_parameter(params: Iterable[torch.nn.Parameter]): + """Init param of params which is instance of DistributedParameter using param._init_method. + + Args: + params (Iterable[torch.nn.Parameter]): parameter tensors. + + """ + for param in params: + if not isinstance(param, DistributedParameter): + continue + if param._init_method is None: + continue + with torch.no_grad(): + partition_size = param.storage().size() + global_size = partition_size * config["tp_zero_size"] * config["tp_size"] + tmp_storage = param.storage_type()(global_size) + tmp_tensor = torch.tensor([], dtype=param.dtype, device="cuda") + tmp_tensor.set_(tmp_storage, 0, param._tp_original_shape) + + param._init_method(tmp_tensor) + if param._tp_mode and param._tp_split_dim >= 0: + tensor_list = tmp_tensor.chunk( + config["tp_size"], dim=param._tp_split_dim + ) + sub_tensor = tensor_list[config["topology"].tp_id].contiguous() + tmp_tensor = torch.empty( + sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype + ) + tmp_tensor.copy_(sub_tensor) + + if param._tp_mode: + begin = config["tp_zero_rank"] + else: + begin = config["zero_rank"] + end = begin + 1 + + # Pytorch 1.11 changed the API of storage.__getitem__ + torch.tensor([], dtype=param.dtype, device=param.device).set_( + param.storage() + )[:] = torch.tensor([], dtype=param.dtype, device=param.device).set_( + tmp_tensor.storage() + )[ + partition_size * begin : partition_size * end + ] + # param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)]) + + +def iterate_parameters(model: torch.nn.Module): + """ + Itterate over the parameters of the model. + """ + for kw, val in model._parameters.items(): + if hasattr(val, "_in_block") and val._in_block: + return [] + yield val + + +def init_parameters(model: torch.nn.Module): + """ + Initialize the parameters of the model by calling the init_method of the distributed parameters. + """ + + modules = model.named_modules() + for module_prefix, module in modules: + if isinstance(module, Block): + module.init_parameters() + else: + init_distributed_parameter(iterate_parameters(module)) + + current_stream = torch.cuda.current_stream() + config["load_stream"].wait_stream(current_stream) + + +def grouped_parameters( + model: torch.nn.Module, +) -> Generator[Tuple[str, List[torch.nn.Parameter]], None, None]: + """ + Iterate over the parameters of the model grouped by the group name. + This is similar to `torch.nn.Module.named_parameters()` . + """ + + ret: List[torch.nn.Parameter] = {} + for module in model.modules(): + if isinstance(module, Block): + for kw, params in module.grouped_parameters(): + if kw not in ret: + ret[kw] = [] + ret[kw].extend(params) + else: + for param in module._parameters.values(): + group = None + if isinstance(param, DistributedParameter): + group = param.group + if group not in ret: + ret[group] = [] + ret[group].append(param) + for kw, val in ret.items(): + yield kw, val diff --git a/examples/BMTrain/bmtrain/parameter.py b/examples/BMTrain/bmtrain/parameter.py new file mode 100644 index 00000000..2dad4a3d --- /dev/null +++ b/examples/BMTrain/bmtrain/parameter.py @@ -0,0 +1,206 @@ +from typing import Callable, Iterable, Optional +import torch +from .utils import round_up +from .global_var import config +from . import nccl +from .distributed import all_gather + + +class DistributedParameter(torch.nn.Parameter): + r""" + DistributedParameter is a subclass of torch.nn.Parameter. + + It scatters the tensor to all the nodes and gathers them when needed. + + Args: + data (Tensor): parameter tensor. + requires_grad (bool, optional): if the parameter requires gradient. + init_method (Callable[['DistributedParameter'], None], optional): the method to initialize the parameter. + group (str, optional): the group name of the parameter. + + **Note**: DistributedParameter must be on the CUDA device. It will transfer the data to device automatically when `__init__` called. + + """ + + _original_shape: torch.Size + _start_partition: int + _end_partition: int + _init_method: Optional[Callable[["DistributedParameter"], None]] + _in_block: bool + _group: Optional[str] + + def __new__( + cls, + data: torch.Tensor, + requires_grad: bool = True, + init_method: Optional[Callable[["DistributedParameter"], None]] = None, + group: Optional[str] = None, + tp_mode: bool = False, + tp_split_dim: int = -1, + ): + if not config["initialized"]: + raise RuntimeError("BMTrain is not initialized") + + num_of_elements = data.numel() + + cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda") + if tp_mode: + comm = config["tp_zero_comm"] + else: + comm = config["zero_comm"] + world_size = nccl.commCount(comm) + rank = nccl.commRank(comm) + cuda_storage_size = round_up(num_of_elements, world_size) // world_size + + original_shape = data.size() + tp_original_shape = original_shape + if tp_mode and tp_split_dim >= 0: + tp_original_shape = list(original_shape) + tp_original_shape[tp_split_dim] *= config["tp_size"] + + cuda_storage = cuda_tensor.storage_type()(cuda_storage_size) + + start_of_partition = cuda_storage_size * rank + end_of_partition = min(num_of_elements, cuda_storage_size * (rank + 1)) + + # FX: cuda_tensor_size < 0 if num_of_elements is too small + cuda_tensor_size = max(end_of_partition - start_of_partition, 0) + + cuda_tensor.set_(cuda_storage, 0, (cuda_tensor_size,)) + cuda_tensor.copy_(data.view(-1)[start_of_partition:end_of_partition]) + ret = torch.Tensor._make_subclass(cls, cuda_tensor, requires_grad) + + setattr(ret, "_original_shape", original_shape) + setattr(ret, "_start_partition", start_of_partition) + setattr(ret, "_end_partition", end_of_partition) + setattr(ret, "_init_method", init_method) + setattr(ret, "_in_block", False) + setattr(ret, "_group", group if not tp_mode else "tp") + + setattr(ret, "_tp_mode", tp_mode) + setattr(ret, "_zero_comm", comm) + setattr(ret, "_tp_split_dim", tp_split_dim) + setattr(ret, "_tp_original_shape", tp_original_shape) + return ret + + @property + def group(self): + """The group name of the distributed parameter.""" + + return self._group + + def gather(self) -> torch.Tensor: + """Gather the data from ZeRO distributed nodes. + + Return: + torch.Tensor: The gathered data. + + """ + with torch.cuda.stream(config["load_stream"]): + output_tensor = OpAllGather.apply(self) + current_stream = torch.cuda.current_stream() + output_tensor.record_stream(current_stream) + current_stream.wait_stream(config["load_stream"]) + return output_tensor + + def gather_all(self) -> torch.tensor: + """Gather the data from ZeRO and Tensor Parallel distributed nodes. + + Return: + torch.Tensor: The gathered data. + + """ + zero_param = self.gather() + if config["tp_size"] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(zero_param, config["tp_comm"]) + if self._tp_split_dim == 1: + output_list = output_tensor.chunk(config["tp_size"], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim() - 1).flatten( + 0, 1 + ) + return output + else: + return output_tensor.flatten(0, 1) + else: + return zero_param + + def tp_gather(self) -> torch.tensor: + """Gather the data from Tensor Parallel distributed nodes. + + Return: + torch.Tensor: The gathered data. + + """ + if config["tp_size"] > 1 and self._tp_split_dim >= 0: + output_tensor = all_gather(self, config["tp_comm"]) + if self._tp_split_dim == 1: + output_list = output_tensor.chunk(config["tp_size"], dim=0) + output = torch.cat(output_list, dim=output_list[0].dim() - 1).flatten( + 0, 1 + ) + return output + else: + return output_tensor.flatten(0, 1) + else: + return self + + def _copy_data(self, data: torch.Tensor): + """Copy data to self.data.""" + self.data.copy_(data.view(-1)[self._start_partition : self._end_partition]) + + +class OpAllGather(torch.autograd.Function): + @staticmethod + def forward(ctx, value: DistributedParameter): + assert isinstance(value, DistributedParameter) + comm = value._zero_comm # config['zero_comm'] + world_size = nccl.commCount(comm) + ctx.comm = comm + ctx.world_size = world_size + + partition_size = value.storage().size() + global_size = partition_size * world_size + + storage = value.storage_type()(global_size) + + nccl.allGather(value.storage(), storage, comm) + + output_tensor = torch.tensor([], dtype=value.dtype, device="cuda") + output_tensor.set_(storage, 0, value._original_shape) + + ctx.partition_size = partition_size + ctx.tensor_size = value.size(0) + return output_tensor + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + if not grad_output.is_contiguous(): + grad_output = grad_output.contiguous() + + grad_storage = grad_output.storage_type()(ctx.partition_size) + grad_output_storage = grad_output.storage() + if grad_output_storage.size() == ctx.partition_size * ctx.world_size: + pass + else: + grad_output_storage.resize_(ctx.partition_size * ctx.world_size) + nccl.reduceScatter(grad_output_storage, grad_storage, "sum", ctx.comm) + grad_tensor = torch.tensor([], dtype=grad_output.dtype, device="cuda") + grad_tensor.set_(grad_storage, 0, (ctx.tensor_size,)) + return grad_tensor + + +class ParameterInitializer: + """ + ParameterInitializer is a helper class that is used to initialize the distributed parameters. + + Similar to functools.partial . + + """ + + def __init__(self, func: Callable, *args, **kwargs) -> None: + self.func = func + self._args = args + self._kwargs = kwargs + + def __call__(self, param: DistributedParameter): + self.func(param, *self._args, **self._kwargs) diff --git a/examples/BMTrain/bmtrain/pipe_layer.py b/examples/BMTrain/bmtrain/pipe_layer.py new file mode 100644 index 00000000..4d3b17ad --- /dev/null +++ b/examples/BMTrain/bmtrain/pipe_layer.py @@ -0,0 +1,314 @@ +from collections import OrderedDict +import copy +import torch +import copy +from typing import Dict, Iterable, Iterator, Tuple, Union, List +import torch + +from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations +from .global_var import config +from . import nccl +from .zero_context import ( + ZeroContext +) +from . import debug +from .block_layer import Block, round_up, _get_param_kw, _block_wrapper + +class PipePreFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, hidden_state, *args): + hidden_state_list = all_gather(hidden_state.clone(), config["pipe_comm"]) + hidden_state_list.requires_grad_() + + batch_related = args[-1] + batch_related_origin = [True if i in args[-1] else False for i in range(len(args[:-1]))] + batch_related_rule = [] + args = args[:-1] + + batch_size = hidden_state.shape[0] + num_micros = config["micros"] + args_list = [[] for _ in range(num_micros)] + input_requires_grad = [] + for arg in args: + if torch.is_tensor(arg): + arg_all = all_gather(arg, config['pipe_comm']) + if arg.dim() == hidden_state.dim() and arg.shape[0] == batch_size: + batch_related_rule.append(True) + arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0) + arg_all = [tensor.requires_grad_(arg.requires_grad) for tensor in arg_all] + else: + batch_related_rule.append(False) + arg_all = [arg_all[0].requires_grad_(arg.requires_grad) for i in range(num_micros)] + input_requires_grad.append(arg.requires_grad) + else: + batch_related_rule.append(False) + arg_all = [arg for _ in range(num_micros)] + input_requires_grad.append(False) + for i in range(num_micros): + args_list[i].append(arg_all[i]) + ctx.input_requires_grad = input_requires_grad + ctx.args_list = args_list + if len(batch_related) == 0: + ctx.batch_related = batch_related_rule + else: + ctx.batch_related = batch_related_origin + return hidden_state_list, args_list + + @staticmethod + def backward(ctx, grads, arg_grads): + grads = broadcast(grads, 0, config['pipe_comm']) + topo = config['topology'] + arg_grads = [] + num_micros = config['micros'] + for idx,requires_grad in enumerate(ctx.input_requires_grad): + if requires_grad: + grad = torch.cat([ctx.args_list[m][idx].grad for m in range(num_micros)], dim=0) + grad = all_reduce(grad, "sum", config["pipe_comm"]) + split_size = topo.stages if ctx.batch_related[idx] else num_micros + grad = grad.chunk(split_size) + if ctx.batch_related[idx]: + arg_grads.append(grad[topo.stage_id]) + else: + arg_grads.append(grad[0]) + else: + arg_grads.append(None) + arg_grads.append(None) #for append(batch_related) + return grads.chunk(topo.stages, dim=0)[topo.stage_id], *arg_grads + +class PipePostFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, last_hidden, hidden_states=None, forward_stage_ranges=None, backward_stage_ranges=None, last_hidden_shape=None, return_hidden_states=False): + topo = config['topology'] + ctx.return_hidden_states = return_hidden_states + last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"]) + last_hidden = last_hidden.chunk(topo.stages, dim=0) + output = last_hidden[topo.stage_id] + output.requires_grad_() + + if return_hidden_states: + ctx.stage_id = topo.stage_id + ctx.stages = topo.stages + ctx.backward_stage_ranges = backward_stage_ranges + middle_hiddens = [] + for stage_id in range(ctx.stages): + if ctx.stage_id == stage_id: + middle_hidden = hidden_states + else: + middle_shape = (forward_stage_ranges[stage_id],) + last_hidden_shape + middle_hidden = torch.zeros(middle_shape, device=hidden_states.device, dtype=hidden_states.dtype) + middle_hidden = broadcast(middle_hidden, stage_id, config["pipe_comm"]) + middle_hidden = middle_hidden.chunk(ctx.stages, dim=1) + middle_hidden = middle_hidden[ctx.stage_id].clone() + middle_hiddens.append(middle_hidden) + middle_hiddens = torch.cat(middle_hiddens, dim=0) + middle_hiddens.requires_grad_() + return output, middle_hiddens + else: + return output + + @staticmethod + def backward(ctx, grads, grad_middle=None): + grad_list = all_gather(grads, config["pipe_comm"]) + grad_list = grad_list.flatten(start_dim=0, end_dim=1) + + if ctx.return_hidden_states: + for stage_id in range(ctx.stages): + layer_range = ctx.backward_stage_ranges[stage_id] + grad_middle_state = grad_middle[layer_range] + grad_middle_state = all_gather(grad_middle_state.transpose(0,1), config["pipe_comm"]) + grad_middle_state = grad_middle_state.flatten(start_dim=0, end_dim=1).transpose(0, 1) + if ctx.stage_id == stage_id: + grad_hidden_state_list = grad_middle_state + return grad_list, grad_hidden_state_list, None, None, None, None + else: + return grad_list + +class StagePreFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, stage_id): + ctx.stage_id = stage_id + ctx.is_first_stage = stage_id == 0 + ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + if not ctx.is_first_stage: + input = recv_activations(stage_id - 1, config['pipe_comm']) + input.requires_grad_() + return input + return input + + @staticmethod + def backward(ctx, grad_outputs): + if not ctx.is_first_stage: + send_data = grad_outputs[0] if isinstance(grad_outputs, tuple) else grad_outputs + current_stream = torch.cuda.current_stream() + with torch.cuda.stream(config['pp_comm_stream']): + config['pp_comm_stream'].wait_stream(current_stream) + send_data.record_stream(config['pp_comm_stream']) + send_activations(send_data, ctx.stage_id - 1, config['pipe_comm']) + return grad_outputs, None + +class StagePostFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, outputs, stage_id): + ctx.stage_id = stage_id + ctx.is_first_stage = stage_id == 0 + ctx.is_last_stage = stage_id == config['pipe_size'] - 1 + if not ctx.is_last_stage: + send_data = outputs[0] if isinstance(outputs, tuple) else outputs + current_stream = torch.cuda.current_stream() + with torch.cuda.stream(config['pp_comm_stream']): + config['pp_comm_stream'].wait_stream(current_stream) + send_data.record_stream(config['pp_comm_stream']) + send_activations(send_data.detach(), stage_id + 1, config['pipe_comm']) + return outputs + + @staticmethod + def backward(ctx, grad_outputs): + if not ctx.is_last_stage: + pre_grad_inputs = recv_activations(ctx.stage_id + 1, config['pipe_comm']) + return pre_grad_inputs, None + return grad_outputs, None + + +class PipelineTransformerBlockList(torch.nn.Module): + r""" + TransformerBlockList is a list of Blocks. + + This is designed to reduce the communication overhead by overlapping the computation and reduce_scatter operation during backward pass. + + It is similar to `torch.nn.ModuleList` but with the difference when calling .forward() and .backward(). + + Example: + >>> module_list = [ ... ] + >>> normal_module_list = torch.nn.ModuleList(module_list) + >>> transformer_module_list = PipelineTransformerBlockList(module_list) + >>> # Calling normal module list + >>> for layer in normal_module_list: + >>> hidden_state = layer.forward(hidden_state, ...) + >>> # Calling transformer module list + >>> hidden_state = transformer_module_list(hidden_state, ...) + + """ + _modules: Dict[str, Block] + + def __init__(self, modules: Iterable[torch.nn.Module], num_hidden=1) -> None: + super().__init__() + self.num_hidden = num_hidden + self._modules = {} + self.layer_ids = [] + topo = config["topology"] + self.stages = topo.stages + self.stage_id = topo.stage_id + self.pipe_idx = topo.pipe_idx + module_dict = {} + for idx, module in enumerate(modules): + module = _block_wrapper(module, module_dict, "PIPE") + module._zero_level = 2 #currently, only support ZeRO-2 in pipeline mode + self._modules[str(idx)] = module + + self.layer_ids = self.get_range_by_stage_id(self.stage_id) + + pre_module = None + for i,layer_id in enumerate(self.layer_ids): + module = self._modules[str(layer_id)] + module.set_pre_module(pre_module) + pre_module = module + module._is_first_layer = False + module._is_last_layer = False + + self._modules[str(self.layer_ids[0])]._is_first_layer = True + self._modules[str(self.layer_ids[-1])]._is_last_layer = True + + def __len__(self) -> int: + return len(self._modules) + + def __iter__(self) -> Iterator[Block]: + return iter(self._modules.values()) + + def __getitem__(self, index: Union[int, str]) -> Block: + return self._modules[str(index)] + + def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=False): + self.return_hidden_states = return_hidden_states + batch_size = hidden_state.shape[0] + num_micros = config["micros"] + args = args + (batch_related, ) + hidden_state.requires_grad_() + hidden_state_list, args_list = PipePreFunction.apply(hidden_state, *args) + + hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0) + outputs = [] + hidden_states = [] + + for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)): + micro_hidden_states = [] + + hidden_state = StagePreFunction.apply(hidden_state, self.stage_id) + + for idx,layer_id in enumerate(self.layer_ids): + self._modules[str(layer_id)]._micro_idx = micro_idx + if return_hidden_states: + micro_hidden_states.append(hidden_state) + hidden_state = self._modules[str(layer_id)](hidden_state, *arg) + hidden_state = StagePostFunction.apply(hidden_state, self.stage_id) + + outputs.append(hidden_state) + if return_hidden_states: + hidden_states.append(torch.stack(micro_hidden_states, dim=0)) + + last_hidden = torch.cat(outputs, dim=0) + last_hidden_shape = last_hidden.shape + + if return_hidden_states: + hidden_states = torch.cat(hidden_states, dim=1) + forward_stage_ranges = [] + backward_stage_ranges = [] + for stage_id in range(self.stages): + forward_stage_ranges.append(self.get_part_len_by_stage_id(stage_id)) + backward_stage_ranges.append(self.get_range_by_stage_id(stage_id)) + outputs, hidden_states = PipePostFunction.apply(last_hidden, hidden_states, forward_stage_ranges, backward_stage_ranges, last_hidden_shape, return_hidden_states) + return outputs, hidden_states + else: + outputs = PipePostFunction.apply(last_hidden) + return outputs + + def get_range_by_stage_id(self, stage_id : int) -> List[int]: + part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)] + start = sum(part_lens[:stage_id+1]) + end = start + part_lens[stage_id+1] + return range(start, end) + + def get_part_len_by_stage_id(self, stage_id : int) -> int: + return len(self) // self.stages + (stage_id < (len(self) % self.stages)) + + def get_stage_by_layer_id(self, layer_id : int) -> int: + part_len = len(self) // self.stages + rest = len(self) % self.stages + if layer_id // (part_len + 1) < rest: + return layer_id // (part_len + 1) + else: + return rest + (layer_id - rest * (part_len+1)) // part_len + + def _save_to_state_dict(self, destination, prefix, keep_vars): + for name, module in self._modules.items(): + idx = int(name) + name = prefix + name + '.' + + dst = OrderedDict() # creates an temporary ordered dict + dst._metadata = OrderedDict() + + if idx in self.layer_ids: + with torch.no_grad(): + with ZeroContext(module, pipe=True): + module._module.state_dict(destination=dst, prefix=name, keep_vars=False) + + if config["topology"].pp_zero_id == 0: + if config["rank"] == 0: + destination.update(dst) + else: + assert list(dst.keys()) == [name+n for n, parameter in module._module.named_parameters()] + for key, tensor in dst.items(): + send_activations(tensor.cuda(), 0, config['pipe_comm']) + if config['rank'] == 0 and idx not in self.layer_ids: + for n, parameter in module._module.named_parameters(): + destination[name+n] = recv_activations(self.get_stage_by_layer_id(idx), config['pipe_comm']).cpu() + diff --git a/examples/BMTrain/bmtrain/store.py b/examples/BMTrain/bmtrain/store.py new file mode 100644 index 00000000..2a3ee02c --- /dev/null +++ b/examples/BMTrain/bmtrain/store.py @@ -0,0 +1,325 @@ +from collections import OrderedDict +from typing import Dict +import torch + +from .pipe_layer import PipelineTransformerBlockList +from .block_layer import TransformerBlockList +from .global_var import config +from .block_layer import Block +from . import nccl +import io, pickle +from typing import Mapping +import threading +import bmtrain as bmt + +def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix): + if isinstance(model, Block): + if rank != 0: + destination = OrderedDict() # creates an temporary ordered dict + destination._metadata = OrderedDict() + model.state_dict(destination=destination, prefix=prefix, keep_vars=False) + else: + if rank != 0: + destination = OrderedDict() # creates an temporary ordered dict + destination._metadata = OrderedDict() + model._save_to_state_dict(destination, prefix, False) + +def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''): + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version) + _save_to_state_dict(model, config['local_rank'], destination, prefix) + for name, module in model._modules.items(): + if module is not None: + _save_to_local_rank0(module, destination, prefix + name + '.') + for hook in model._state_dict_hooks.values(): + hook_result = hook(model, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''): + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version) + if not isinstance(model, PipelineTransformerBlockList): + _save_to_state_dict(model, config['rank'], destination, prefix) + for name, module in model._modules.items(): + if module is not None: + _save_to_rank0(module, destination, prefix + name + '.') + for hook in model._state_dict_hooks.values(): + hook_result = hook(model, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + else: + model._save_to_state_dict(destination, prefix, False) + return destination + +def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None, prefix=''): + config['save_param_to_cpu'] = False + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version) + _save_to_state_dict(model, config['local_rank'], destination, prefix) + for name, module in model._modules.items(): + if module is not None: + if isinstance(module, TransformerBlockList): + for local_name, local_module in module._modules.items(): + local_state_dict = _save_to_local_rank0(local_module, None, prefix + name + "." + local_name + '.') + if config['local_rank'] == 0: + infer_model.load_layer_state_dict(local_state_dict) + else: + _save_to_infer_model(module, infer_model, destination, prefix + name + '.') + for hook in model._state_dict_hooks.values(): + hook_result = hook(model, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + + if config['local_rank'] == 0: + infer_model.load_layer_state_dict(destination) + + +def async_save_to_file(state_dict, file_path): + torch.save(state_dict, file_path) + config['finish_save'] = True + print("finish save state_dict to ", file_path) + +def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False): + """Saves the model to the file. + + Similar to torch.save, but it used for distributed modules. + + Args: + model (torch.nn.Module): The model to be saved. + file_name (str): The file name of the checkpoint. + non_blocking (bool): Whether to asynchronously save state_dict to file + + + Examples: + >>> bmtrain.save(model, "model.pt") + """ + torch.cuda.synchronize() + state_dict = _save_to_rank0(model) + if config["rank"] == 0: + if non_blocking is False: + torch.save(state_dict, file_name) + else: + if 'finish_save' not in config: + config['finish_save'] = True + + if config['finish_save'] is False: + config['save_thread'].join() + + config['finish_save'] = False + config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name)) + config['save_thread'].start() + bmt.synchronize() + +DTYPE_LIST = [ + torch.float64, + torch.float32, + torch.float16, + torch.int64, + torch.int32, + torch.int16, + torch.int8, + torch.bfloat16, + torch.bool +] + +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + +def allgather_objects(obj): + if bmt.world_size() == 1: + return [obj] + + with torch.no_grad(): + data_bytes: bytes = pickle.dumps(obj) + data_length: int = len(data_bytes) + + gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long) + gathered_length = bmt.distributed.all_gather(gpu_data_length).view(-1).cpu() + max_data_length = gathered_length.max().item() + + gpu_data_bytes = torch.zeros(max_data_length, dtype=torch.uint8, device="cuda") + byte_storage = torch.ByteStorage.from_buffer(data_bytes) + gpu_data_bytes[:data_length] = torch.ByteTensor(byte_storage) + + gathered_data = bmt.distributed.all_gather(gpu_data_bytes).cpu() + + ret = [] + for i in range(gathered_data.size(0)): + data_bytes = gathered_data[i, : gathered_length[i].item()].numpy().tobytes() + ret.append(pickle.loads(data_bytes)) + return ret + +def broadcast_object(obj, comm, src = 0): + if nccl.commRank(comm) == src: + f = io.BytesIO() + _pickler(f).dump(obj) + byte_storage = torch.ByteStorage.from_buffer(f.getvalue()) + # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype. + # Otherwise, it will casue 100X slowdown. + # See: https://github.com/pytorch/pytorch/issues/65696 + byte_tensor = torch.ByteTensor(byte_storage).cuda() + local_size = torch.LongTensor([byte_tensor.numel()]).cuda() + + nccl.broadcast( + local_size.storage(), + local_size.storage(), + src, + comm + ) + nccl.broadcast( + byte_tensor.storage(), + byte_tensor.storage(), + src, + comm + ) + else: + local_size = torch.LongTensor([0]).cuda() + nccl.broadcast( + local_size.storage(), + local_size.storage(), + src, + comm + ) + byte_tensor_size = local_size[0].item() + byte_tensor = torch.empty(int(byte_tensor_size), dtype=torch.uint8, device="cuda") + nccl.broadcast( + byte_tensor.storage(), + byte_tensor.storage(), + src, + comm + ) + buf = byte_tensor.cpu().numpy().tobytes() + obj = _unpickler(io.BytesIO(buf)).load() + return obj + +# Must be a Mapping after pytorch 1.12.0 +class DistributedTensorWrapper: + def __init__(self, tensor, shape=None): + self._dtype = tensor.dtype + self._device = tensor.device + self.shape = shape + self.tensor = tensor + + def broadcast(self): + output_param = torch.empty(self.shape, dtype=self._dtype, device="cuda") + if config['rank'] == 0: + input_param = self.tensor + if input_param.is_cuda: + input_param = input_param.clone().contiguous() + else: + input_param = input_param.cuda().contiguous() + + nccl.broadcast( + input_param.storage(), + output_param.storage(), + 0, + config['comm'] + ) + else: + nccl.broadcast( + output_param.storage(), + output_param.storage(), + 0, + config['comm'] + ) + return output_param + + def copy(self): + return self.tensor + + def __getattribute__(self, name): + if name == "tensor" or name == "shape": + return object.__getattribute__(self, name) + else: + try: + return object.__getattribute__(self, name) + except AttributeError: + pass + + return getattr(self.tensor, name) + +class DistributedStateDictWrapper(Mapping): + def __init__(self, state_dict : Dict) -> None: + self._state_dict = state_dict + self._metadata = broadcast_object(getattr(state_dict, "_metadata", None), config["comm"]) + + def __getitem__(self, key : str): + tmp_shape = torch.zeros(32, device="cuda", dtype=torch.int32) + if config['rank'] == 0: + input_param : torch.Tensor = self._state_dict[key] + shape_list = torch.tensor(list(input_param.size()), device="cuda", dtype=torch.int32) + dtype_idx = DTYPE_LIST.index(input_param.dtype) + + assert dtype_idx != -1, "Unknown data type %s" % input_param.dtype + + tmp_shape[0] = shape_list.size(0) + tmp_shape[1] = dtype_idx + tmp_shape[2:2 + shape_list.size(0)] = shape_list + + nccl.broadcast( + tmp_shape.storage(), + tmp_shape.storage(), + 0, + config['comm'] + ) + + shape_list_size = tmp_shape[0].item() + dtype_idx = tmp_shape[1].item() + shape_list = torch.Size(tmp_shape[2: 2 + shape_list_size].tolist()) + + if config['rank'] != 0: + return DistributedTensorWrapper(torch.tensor([], dtype=DTYPE_LIST[dtype_idx], device="cuda"), shape=shape_list) + else: + return DistributedTensorWrapper(self._state_dict[key], shape=shape_list) + + + + def copy(self): + return self + + def __len__(self): + return broadcast_object(len(self._state_dict), config["comm"]) + + def __contains__(self, key : str): + return broadcast_object(key in self._state_dict, config["comm"]) + + def keys(self): + return broadcast_object(list(self._state_dict.keys()),config["comm"]) + + def __iter__(self): + # pytorch 1.12.0 updated the load_state_dict method, which needs the state_dict to be a `Mapping`. + return iter(self.keys()) + +def load(model : torch.nn.Module, file_name : str, strict : bool = True): + """Loads the model from the file. + + Similar to torch.load, but it uses less memory when loading large models. + + Args: + model (torch.nn.Module): The model to be loaded. + file_name (str): The file name of the checkpoint. + strict (bool): Strict option of `load_state_dict`. + + Example: + >>> bmtrain.load(model, "model.pt", strict=True) + """ + if config['rank'] == 0: + state_dict = DistributedStateDictWrapper(torch.load(file_name)) + else: + state_dict = DistributedStateDictWrapper({}) + + ret = model.load_state_dict( + state_dict, + strict = strict + ) + torch.cuda.synchronize() + return ret diff --git a/examples/BMTrain/bmtrain/synchronize.py b/examples/BMTrain/bmtrain/synchronize.py new file mode 100644 index 00000000..87619159 --- /dev/null +++ b/examples/BMTrain/bmtrain/synchronize.py @@ -0,0 +1,73 @@ +import torch +from . import distributed, nccl +from .global_var import config +import warnings +from typing import Optional + + +def synchronize(): + """ + Synchronize all the workers across all nodes. (both CPU and GPU are synchronized) + """ + if not config["initialized"]: + raise RuntimeError("BMTrain is not initialized") + + with torch.cuda.stream(config["barrier_stream"]): + barrier = torch.cuda.FloatTensor([1]) + nccl.allReduce(barrier.storage(), barrier.storage(), "sum", config["comm"]) + config["barrier_stream"].synchronize() + + +def wait_loader(): + """ + Clac_stream (normally current stream) wait latest loader event, and set a new one. + """ + if not config["initialized"]: + raise RuntimeError("BMTrain is not initialized") + + config["load_event"].synchronize() + config["calc_stream"].record_event(config["load_event"]) + + +def sum_loss(loss: torch.Tensor, comm: Optional[nccl.NCCLCommunicator] = None): + """ + Sum the loss across all workers. + + This is a helper function to reduce the loss across all workers. + """ + if comm is None: + comm = config["comm"] + warnings.warn( + "bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.", + DeprecationWarning, + ) + + return distributed.all_reduce(loss, "avg", comm) + + +def gather_result(result: torch.Tensor): + """ + Gather result across all workers. + """ + warnings.warn( + "bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.", + DeprecationWarning, + ) + if result.storage_offset() != 0 or result.storage().size() != result.numel(): + # Create a clone of the original tensor if it's a slice + result = result.clone() + + output_cuda = True + if not result.is_cuda: + result = result.cuda() + output_cuda = False + ret = torch.empty( + (result.shape[0] * config["world_size"], *list(result.shape[1:])), + device=result.device, + dtype=result.dtype, + ) + nccl.allGather(result.storage(), ret.storage(), config["comm"]) + if output_cuda: + return ret + else: + return ret.cpu() diff --git a/examples/BMTrain/bmtrain/utils.py b/examples/BMTrain/bmtrain/utils.py new file mode 100644 index 00000000..daa4c595 --- /dev/null +++ b/examples/BMTrain/bmtrain/utils.py @@ -0,0 +1,184 @@ +import torch +import sys +from typing import Any, Dict, Iterable, Optional +from .global_var import config +import os +import ctypes + +ALIGN = 4 +ROW_WIDTH = 60 + + +def check_torch_version(version_str): + """ + Checks if the current torch version is greater than or equal to the given version. + version_str (str): The version to compare with, in the format of "x.y.z" ,and the func will convert it into a int value of x*100+y*10+z. + """ + version_int_arr = [int(v) for v in version_str.split(".")] + + version_int = ( + version_int_arr[0] * 10000 + version_int_arr[1] * 100 + version_int_arr[2] + ) + torch_version = torch.__version__.split("+")[0] + current_version_int_arr = [int(v) for v in torch_version.split(".")] + current_version_int = ( + current_version_int_arr[0] * 10000 + + current_version_int_arr[1] * 100 + + current_version_int_arr[2] + ) + return current_version_int - version_int + + +def load_nccl_pypi(): + """ + Check if current nccl is avaliable. + """ + try: + import nvidia.nccl + except: + raise ImportError("Run pip install nvidia-nccl-cu11 >=2.14.3 first") + + path = os.path.join(os.path.dirname(nvidia.nccl.__file__), "lib") + for file_so in os.listdir(path): + file_split = file_so.split(".") + if file_split[-1] == "so" or (len(file_split) > 1 and file_split[-2] == "so"): + ctypes.CDLL(os.path.join(path, file_so)) + + +def round_up(x, d): + """ + Return (x + d - 1) // d * d + """ + return (x + d - 1) // d * d + + +def print_dict(title: str, content: Dict[str, Any], file=sys.stdout): + """ + Print Dict to file. + """ + max_kw_len = max([len(kw) for kw in content.keys()]) + max_kw_len = round_up(max_kw_len + 3, 4) + + raw_content = "" + + for kw, val in content.items(): + raw_content += kw + " :" + " " * (max_kw_len - len(kw) - 2) + raw_val = "%s" % val + + len_val_row = ROW_WIDTH - max_kw_len + st = 0 + if len(raw_val) == 0: + raw_val = " " + while st < len(raw_val): + if st > 0: + raw_content += " " * max_kw_len + raw_content += raw_val[st : st + len_val_row] + "\n" + st += len_val_row + + print_block(title, raw_content, file) + + +def print_block(title: str, content: Optional[str] = None, file=sys.stdout): + """ + Print content to file. + """ + left_title = (ROW_WIDTH - len(title) - 2) // 2 + right_title = ROW_WIDTH - len(title) - 2 - left_title + + print("=" * left_title + " " + title + " " + "=" * right_title, file=file) + if content is not None: + print(content, file=file) + + +def print_rank(*args, rank=0, **kwargs): + """ + Prints the message only on the `rank` of the process. + + Args: + *args: The arguments to be printed. + rank (int): The rank id of the process to print. + **kwargs: The keyword arguments to be printed. + + """ + if config["rank"] == rank: + print(*args, **kwargs) + + +def see_memory(message, detail=False): + """ + Outputs a message followed by GPU memory status summary on rank 0. + At the end of the function, the starting point in tracking maximum GPU memory will be reset. + + Args: + message (str): The message to be printed. It can be used to distinguish between other outputs. + detail (bool): Whether to print memory status in a detailed way or in a concise way. Default to false. + + Example: + >>> bmt.see_memory("before forward") + >>> # forward_step() + >>> bmt.see_memory("after forward") + + """ + print_rank(message) + if detail: + print_rank(torch.cuda.memory_summary()) + else: + print_rank( + f""" + ======================================================================================= + memory_allocated {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB + max_memory_allocated {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB + ======================================================================================= + """ + ) + torch.cuda.reset_peak_memory_stats() + + +def tp_split_tensor(tensor, split_dim): + """ + Outpus the tensor with config["toplogy"].tp_id split at split dim. + + Args: + tensor (torch.tensor): The tensor to be splited. + split_dim (int): The dim to split the input tensor. + + """ + tensor_list = tensor.chunk(config["tp_size"], dim=split_dim) + sub_tensor = tensor_list[config["topology"].tp_id].contiguous() + tmp_tensor = torch.empty( + sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype + ) + tmp_tensor.copy_(sub_tensor) + return tmp_tensor + + +class AverageRecorder: + """A utility class to record the average value of a quantity over time. + + Args: + alpha (float): The decay factor of the average. + start_value (float): The initial value of the average. + + Use `.value` to get the current average value. + It is calculated as `alpha * old_value + (1 - alpha) * new_value`. + + """ + + def __init__(self, alpha=0.9, start_value=0): + self._value = start_value + self.alpha = alpha + self._steps = 0 + + def record(self, v): + """Records a new value. + Args: + v (float): The new value. + """ + self._value = self._value * self.alpha + v * (1 - self.alpha) + self._steps += 1 + + @property + def value(self): + if self._steps <= 0: + return self._value + return self._value / (1 - pow(self.alpha, self._steps)) diff --git a/examples/BMTrain/bmtrain/wrapper.py b/examples/BMTrain/bmtrain/wrapper.py new file mode 100644 index 00000000..e64fd5ba --- /dev/null +++ b/examples/BMTrain/bmtrain/wrapper.py @@ -0,0 +1,54 @@ +import torch +from .block_layer import Block, TransformerBlockList +from .layer import DistributedModule, DistributedParameter + + +def make_distributed(model: torch.nn.Module): + for kw in list(model._parameters.keys()): + if model._parameters[kw] is not None: + if not isinstance(model._parameters[kw], DistributedParameter): + model._parameters[kw] = DistributedParameter( + model._parameters[kw], + requires_grad=model._parameters[kw].requires_grad, + ) + + for kw in list(model._buffers.keys()): + if model._buffers[kw] is not None: + model._buffers[kw] = model._buffers[kw].cuda() + is_module_list = isinstance(model, torch.nn.ModuleList) + pre_module = None + for kw in list(model._modules.keys()): + if is_module_list: + if not isinstance(model._modules[kw], Block): + model._modules[kw] = Block(model_wrapper_dispatch(model._modules[kw])) + if pre_module is not None: + model._modules[kw].set_pre_module(pre_module) + pre_module = model._modules[kw] + else: + model._modules[kw] = model_wrapper_dispatch(model._modules[kw]) + + model.__class__ = type( + "bmtrain.Distributed" + model.__class__.__name__, + (model.__class__, DistributedModule), + {}, + ) + return model + + +def model_wrapper_dispatch(model: torch.nn.Module): + if isinstance(model, TransformerBlockList): + return model + elif isinstance(model, DistributedModule): + return model + elif isinstance(model, Block): + return model + else: + return make_distributed(model) + + +def BMTrainModelWrapper(model: torch.nn.Module) -> torch.nn.Module: + """ + Automatically wrap a model in a BMTrain model. + Replaces all parameters with DistributedParameter, all modules with DistributedModule, and modules in ModuleList with Block. + """ + return model_wrapper_dispatch(model) diff --git a/examples/BMTrain/bmtrain/zero_context.py b/examples/BMTrain/bmtrain/zero_context.py new file mode 100644 index 00000000..8a74b3f8 --- /dev/null +++ b/examples/BMTrain/bmtrain/zero_context.py @@ -0,0 +1,203 @@ +import torch +from . import nccl +from .global_var import config +from .synchronize import wait_loader + + +class ZeroContext: + """ZeroContext is a helper class to Gather parameters before module forward and reduce scatter + gradients after module backward. + + Args: + block (BLock): Input Block. + ctx_dict (dict): block._layer_dict. + pipe (bool): True if use pipe parallel. + + """ + + def __init__(self, block: "Block", ctx_dict: dict = None, pipe=False) -> None: + self.block = block + self.ctx_dict = ctx_dict + self._param_buffer = {} + self._grad_buffer = {} + self._param_tensor = {} + self._grad_tensor = {} + self._need_release = False + + def enter(self, flag=0, requires_grad=False): + """ + Gather parameters before module forward and init grad buffer before backward. + """ + if self.block._ready: + return + self.block._ready = True + self._need_release = True + + wait_loader() + with torch.cuda.stream(config["load_stream"]): + for kw, val in self.block._storage_info.items(): + assert self.block._storage_params[kw].is_cuda + assert kw not in self._grad_buffer + assert kw not in self._param_buffer + local_param = self.block._storage_params[kw] + + storage_type = local_param.storage_type() + if flag != 2: + self._param_buffer[kw] = storage_type( + val["partition_size"] * val["world_size"] + ) + self._param_tensor[kw] = torch.tensor( + [], + dtype=self._param_buffer[kw].dtype, + device=self._param_buffer[kw].device, + ).set_(self._param_buffer[kw]) + + if requires_grad and local_param.requires_grad: + self._grad_buffer[kw] = storage_type( + val["partition_size"] * val["world_size"] + ) + self._grad_tensor[kw] = ( + torch.tensor( + [], + dtype=self._grad_buffer[kw].dtype, + device=self._grad_buffer[kw].device, + ) + .set_(self._grad_buffer[kw]) + .zero_() + ) + if flag != 2: + nccl.groupStart() + for kw, val in self.block._storage_info.items(): + nccl.allGather( + self.block._storage_params[kw].storage(), + self._param_buffer[kw], + val["zero_comm"], + ) + nccl.groupEnd() + + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(config["load_stream"]) + + # set wait stream for each storage + for kw in self.block._storage_info.keys(): + if flag != 2: + self._param_tensor[kw].record_stream(current_stream) + if requires_grad and kw in self._grad_tensor: + self._grad_tensor[kw].record_stream(current_stream) + + # update parameters in block + for param in self.block._param_info: + kw_name = param["kw_name"] + offset = param["offset"] + shape = param["shape"] + + if flag != 2: + dtype = self._param_buffer[kw_name].dtype + device = self._param_buffer[kw_name].device + param["parameter"].data = torch.tensor( + [], dtype=dtype, device=device + ).set_(self._param_buffer[kw_name], offset, shape) + else: + dtype = param["parameter"].data.dtype + device = param["parameter"].data.device + param["parameter"].data = torch.tensor( + [], dtype=dtype, device=device + ).set_(self.ctx_dict[kw_name], offset, shape) + + if ( + requires_grad + and kw_name in self._grad_buffer + and param["parameter"].requires_grad + ): + param["parameter"].grad = torch.tensor( + [], dtype=dtype, device=device + ).set_(self._grad_buffer[kw_name], offset, shape) + + def __enter__(self): + self.enter() + + def exit(self, flag=0, backward=False): + """ + Reduce scatter gradients when backward and release all parameters from buffer to block_storge when forward is done. + """ + if not self._need_release: + return + self._need_release = False + self.block._ready = False + if backward: + for kw, val in self.block._storage_info.items(): + local_param = self.block._storage_params[kw] + + # accumulate previous gradient + if local_param.requires_grad: + if local_param.grad is None: + grad_storage = val["storage_type"]( + val["partition_size"] + ) # initialize gradient if not exist + local_param.grad = ( + torch.tensor( + [], dtype=grad_storage.dtype, device=grad_storage.device + ) + .set_(grad_storage) + .zero_() + ) + else: + self._grad_tensor[kw][ + val["begin"] : val["end"] + ] += local_param.grad + + current_stream = torch.cuda.current_stream() + config["load_stream"].wait_stream(current_stream) # wait for backward + + with torch.cuda.stream(config["load_stream"]): + nccl.groupStart() + for kw, val in self.block._storage_info.items(): + local_param = self.block._storage_params[kw] + + # scatter gradient + if local_param.requires_grad: + nccl.reduceScatter( + self._grad_buffer[kw], + local_param.grad.storage(), + "sum", + val["zero_comm"], + ) + nccl.groupEnd() + + # set wait stream for each storage + for kw in self._grad_tensor.keys(): + # grads can not be freed until reduce ops finish + self._grad_tensor[kw].record_stream(config["load_stream"]) + + # Release all parameters from buffer to block_storge + for param in self.block._param_info: + kw_name = param["kw_name"] + dtype = self.block._storage_params[kw_name].dtype + device = self.block._storage_params[kw_name].device + if "begin" not in param: + param["parameter"].data = torch.tensor([], dtype=dtype, device=device) + param["parameter"].grad = None + continue + begin = param["begin"] + end = param["end"] + param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_( + self.block._storage_params[kw_name].storage(), begin, end + ) + if ( + param["parameter"].requires_grad + and self.block._storage_params[kw_name].grad is not None + ): + param["parameter"].grad = torch.tensor( + [], dtype=dtype, device=device + ).set_(self.block._storage_params[kw_name].grad.storage(), begin, end) + if flag == 1: + for i in self._param_buffer: + self.ctx_dict[i] = self._param_buffer[i] + self._grad_tensor = {} + self._param_tensor = {} + self._grad_buffer = {} + self._param_buffer = {} + + def __exit__(self, exc_type, exc_val, exc_tb): + # reduce scatter gradients + self.exit() diff --git a/examples/BMTrain/cmake/FindNCCL.cmake b/examples/BMTrain/cmake/FindNCCL.cmake new file mode 100644 index 00000000..2af8e3b9 --- /dev/null +++ b/examples/BMTrain/cmake/FindNCCL.cmake @@ -0,0 +1,100 @@ +list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR}) +if(DEFINED ENV{NCCL_ROOT_DIR}) + set(NCCL_ROOT_DIR $ENV{NCCL_ROOT_DIR}) + set(NCCL_INCLUDE_DIR "${NCCL_ROOT_DIR}/include" CACHE PATH "Folder contains NVIDIA NCCL headers") + set(NCCL_LIB_DIR "${NCCL_ROOT_DIR}/lib" CACHE PATH "Folder contains NVIDIA NCCL libraries") +else() + set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers") + set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries") +endif() + +# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12. +if(NOT NCCL_INCLUDE_DIR OR NOT NCCL_LIB_DIR) + execute_process( + COMMAND python -c "import nvidia.nccl;import os; print(os.path.dirname(nvidia.nccl.__file__))" + OUTPUT_VARIABLE NCCL_PIP_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + ) + list(APPEND NCCL_ROOT $ENV{NCCL_PIP_DIR}) + if(NOT NCCL_INCLUDE_DIR) + set(NCCL_INCLUDE_DIR "${NCCL_PIP_DIR}/include") + endif() + if(NOT NCCL_LIB_DIR) + set(NCCL_LIB_DIR "${NCCL_PIP_DIR}/lib") + endif() + find_library(NCCL_LIBRARIES + NAMES ${NCCL_LIBNAME} + HINTS ${NCCL_LIB_DIR}) +endif() + +list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT}) +find_path(NCCL_INCLUDE_DIRS + NAMES nccl.h + HINTS ${NCCL_INCLUDE_DIR}) + + + +if (USE_STATIC_NCCL) + MESSAGE(STATUS "USE_STATIC_NCCL is set. Linking with static NCCL library.") + SET(NCCL_LIBNAME "nccl_static") + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() +else() + SET(NCCL_LIBNAME "nccl") + + if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified + message(STATUS "NCCL version: ${NCCL_VERSION}") + set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + else() + set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.2" ${CMAKE_FIND_LIBRARY_SUFFIXES}) + endif() + +endif() + +find_library(NCCL_LIBRARIES + NAMES ${NCCL_LIBNAME} + HINTS ${NCCL_LIB_DIR}) + + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES) + +if(NCCL_FOUND) # obtaining NCCL version and some sanity checks + set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h") + message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...") + set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES}) + list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS}) + include(CheckCXXSymbolExists) + check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED) + + if (NCCL_VERSION_DEFINED) + set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc") + file(WRITE ${file} " + #include <iostream> + #include <nccl.h> + int main() + { + std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl; + int x; + ncclGetVersion(&x); + return x == NCCL_VERSION_CODE; + } +") + try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file} + RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER + CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}" + LINK_LIBRARIES ${NCCL_LIBRARIES}) + if (NOT NCCL_VERSION_MATCHED) + message(FATAL_ERROR "Found NCCL header version and library version do not match! \ +(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.") + endif() + message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}") + else() + # message(STATUS "NCCL version < 2.3.5-5") + endif () + set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES}) + + message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})") + mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES) +endif() diff --git a/examples/BMTrain/csrc/bind.cpp b/examples/BMTrain/csrc/bind.cpp new file mode 100644 index 00000000..b8f6fa85 --- /dev/null +++ b/examples/BMTrain/csrc/bind.cpp @@ -0,0 +1,35 @@ +#include "include/bind.hpp" + +PYBIND11_MODULE(C, m) { + m.def("to_fp16_from_fp32", &fp16_from_fp32_value_launcher, "convert"); + m.def("to_bf16_from_fp32", &bf16_from_fp32_value_launcher, "convert"); + m.def("is_bf16_supported", &is_bf16_supported, "whether bf16 supported"); + m.def("has_nan_inf_fp16_launcher", &has_nan_inf_fp16_launcher, "has nan inf"); + m.def("has_nan_inf_bf16_launcher", &has_nan_inf_bf16_launcher, "has nan inf bf16"); + m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function cpu"); + m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu"); + m.def("adam_cpu_fp16_launcher", &adam_cpu_fp16_launcher, "adam function cpu"); + m.def("adam_cpu_bf16_launcher", &adam_cpu_bf16_launcher, "adam function cpu"); + m.def("cross_entropy_forward_fp16_launcher", &cross_entropy_forward_fp16_launcher, "cross entropy forward"); + m.def("cross_entropy_forward_bf16_launcher", &cross_entropy_forward_bf16_launcher, "cross entropy forward"); + m.def("cross_entropy_backward_inplace_fp16_launcher", &cross_entropy_backward_inplace_fp16_launcher, "cross entropy backward inplace"); + m.def("cross_entropy_backward_inplace_bf16_launcher", &cross_entropy_backward_inplace_bf16_launcher, "cross entropy backward inplace"); + m.def("fused_sumexp_fp16_launcher", &fused_sumexp_fp16_launcher, "sum exp"); + m.def("fused_sumexp_bf16_launcher", &fused_sumexp_bf16_launcher, "sum exp"); + m.def("fused_softmax_inplace_fp16_launcher", &fused_softmax_inplace_fp16_launcher, "softmax inplace"); + m.def("fused_softmax_inplace_bf16_launcher", &fused_softmax_inplace_bf16_launcher, "softmax inplace"); + m.def("ncclGetUniqueId", &pyNCCLGetUniqueID, "nccl get unique ID"); + m.def("ncclCommInitRank", &pyNCCLCommInitRank, "nccl init rank"); + m.def("ncclCommDestroy", &pyNCCLCommDestroy, "nccl delete rank"); + m.def("ncclAllGather", &pyNCCLAllGather, "nccl all gather"); + m.def("ncclAllReduce", &pyNCCLAllReduce, "nccl all reduce"); + m.def("ncclBroadcast", &pyNCCLBroadcast, "nccl broadcast"); + m.def("ncclReduce", &pyNCCLReduce, "nccl reduce"); + m.def("ncclReduceScatter", &pyNCCLReduceScatter, "nccl reduce scatter"); + m.def("ncclGroupStart", &pyNCCLGroupStart, "nccl group start"); + m.def("ncclGroupEnd", &pyNCCLGroupEnd, "nccl group end"); + m.def("ncclSend", &pyNCCLSend, "nccl send"); + m.def("ncclRecv", &pyNCCLRecv, "nccl recv"); + m.def("ncclCommCount", &pyNCCLCommCount, "nccl comm count"); + m.def("ncclCommUserRank", &pyNCCLCommUserRank, "nccl comm user rank"); +} diff --git a/examples/BMTrain/csrc/cuda/adam_cuda.cu b/examples/BMTrain/csrc/cuda/adam_cuda.cu new file mode 100644 index 00000000..0510ac12 --- /dev/null +++ b/examples/BMTrain/csrc/cuda/adam_cuda.cu @@ -0,0 +1,126 @@ +#include <cstdint> +#include <cuda.h> +#include <cuda_fp16.h> +#include "bfloat16.cuh" + +namespace { +// blocks <n // 1024>, threads<min(n, 1024)> +__global__ void adam_fp32_accum( + int32_t n, + const half *g, // (n) + half *m, // (n) + float *v, // (n) + float *param, // (n) + half *param_h, // (n) + float beta1, + float beta2, + float eps, + float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +) { + int32_t col = blockIdx.x * blockDim.x + threadIdx.x; + if (col < n) { + float local_g = __half2float(g[col]); // real_g * scale + float local_m = beta1 * __half2float(m[col]) + (1 - beta1) * local_g; // real_m * scale + float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g / scale; // real_v * scale + float local_p = param[col]; + local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v * scale / bias_correction2) + eps * scale) - lr * weight_decay * local_p; + + param_h[col] = __float2half(local_p); + param[col] = local_p; + v[col] = local_v; + m[col] = __float2half(local_m); + } +} + +__global__ void adam_fp32_accum_bf16( + int32_t n, + const std::uintptr_t g_ptr, // (n) + float *m, // (n) + float *v, // (n) + float *param, // (n) + std::uintptr_t param_h_ptr, // (n) + float beta1, + float beta2, + float eps, + float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +) { +#ifdef BF16_SUPPORT + const __nv_bfloat16* g = reinterpret_cast<const __nv_bfloat16*>(g_ptr); + __nv_bfloat16* param_h = reinterpret_cast<__nv_bfloat16*>(param_h_ptr); + int32_t col = blockIdx.x * blockDim.x + threadIdx.x; + if (col < n) { + float local_g = __bfloat162float(g[col]) / scale; // real_g + float local_m = beta1 * m[col] + (1 - beta1) * local_g; // real_m + float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g; // real_v + float local_p = param[col]; + local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v / bias_correction2) + eps) - lr * weight_decay * local_p; + + param_h[col] = __float2bfloat16(local_p); + param[col] = local_p; + v[col] = local_v; + m[col] = local_m; + } +#endif +} + +} + +void adam_fp16_launcher( + int n, + std::uintptr_t param_fp32, + std::uintptr_t param_fp16, + std::uintptr_t g_fp16, + std::uintptr_t m_fp16, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2, + uintptr_t stream +) { + if (n <= 0) return; + auto g_ptr = reinterpret_cast<half*>(g_fp16); + auto m_ptr = reinterpret_cast<half*>(m_fp16); + auto param_h_ptr = reinterpret_cast<half*>(param_fp16); + auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32); + auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32); + int32_t threads = 1024; + dim3 block_size = dim3(threads, 1, 1); + dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); + adam_fp32_accum<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, param_h_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); +} + +void adam_bf16_launcher( + int n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16, + std::uintptr_t g_bf16, + std::uintptr_t m_fp32, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2, + uintptr_t stream +) { + if (n <= 0) return; + auto m_ptr = reinterpret_cast<float*>(m_fp32); + auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32); + auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32); + int32_t threads = 1024; + dim3 block_size = dim3(threads, 1, 1); + dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); + adam_fp32_accum_bf16<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_bf16, m_ptr, v_fp32_ptr, param_fp32_ptr, param_bf16, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); +} diff --git a/examples/BMTrain/csrc/cuda/bfloat16.cuh b/examples/BMTrain/csrc/cuda/bfloat16.cuh new file mode 100644 index 00000000..564d8bec --- /dev/null +++ b/examples/BMTrain/csrc/cuda/bfloat16.cuh @@ -0,0 +1,5 @@ +#include <cuda.h> +#if defined(__CUDACC__) && CUDA_VERSION >= 11000 +#include <cuda_bf16.h> +#define BF16_SUPPORT +#endif \ No newline at end of file diff --git a/examples/BMTrain/csrc/cuda/cross_entropy.cu b/examples/BMTrain/csrc/cuda/cross_entropy.cu new file mode 100644 index 00000000..177c3b77 --- /dev/null +++ b/examples/BMTrain/csrc/cuda/cross_entropy.cu @@ -0,0 +1,315 @@ +#include "reduce.cuh" +#include <cstdint> +#include <cuda.h> +#include <cuda_fp16.h> +#include "bfloat16.cuh" + +namespace { +// blocks <m>, threads<1024> +__global__ void cross_entropy_forward_fp16( + int64_t n, + const half *input, // (m, n) + const int32_t *target, // (m) + half *softmax, // (m, n) + float *output, // (m) + int32_t ignore_index +) { + int64_t base_idx = blockIdx.x * n; + + float local_max = -INFINITY; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + local_max = fmaxf(__half2float(input[base_idx + i]), local_max); + } + + local_max = fmaxf(block_allreduce_max(local_max), -1e6); + + float local_sum = 0; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + local_sum += expf(__half2float(input[base_idx + i]) - local_max); + } + local_sum = block_allreduce_sum(local_sum) + 1e-10; // avoid nan + + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + softmax[base_idx + i] = __float2half( expf(__half2float(input[base_idx + i]) - local_max) / local_sum ); + } + + if (threadIdx.x == 0) { + if (target[blockIdx.x] != ignore_index) { + output[blockIdx.x] = -__half2float(input[base_idx + target[blockIdx.x]]) + local_max + logf(local_sum); + } else { + output[blockIdx.x] = 0; + } + } +} + +// blocks <m>, threads<1024> +__global__ void cross_entropy_backward_inplace_fp16( + int64_t n, + const float *grad_output, // (m) + const int32_t *target, // (m) + half *x, // (m, n) + int32_t ignore_index +) { + int64_t base_idx = blockIdx.x * n; + + int32_t t = target[blockIdx.x]; + float v = grad_output[blockIdx.x]; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + x[base_idx + i] = __float2half(i==t ? (__half2float(x[base_idx + i])-1)*v : __half2float(x[base_idx + i])*v); + } +} + +// blocks <m>, threads<1024> +__global__ void cross_entropy_forward_bf16( + int64_t n, + const std::uintptr_t input_ptr, // (m, n) + const int32_t *target, // (m) + std::uintptr_t softmax_ptr, // (m, n) + float *output, // (m) + int32_t ignore_index +) { +#ifdef BF16_SUPPORT + const __nv_bfloat16* input = reinterpret_cast<const __nv_bfloat16*>(input_ptr); + __nv_bfloat16* softmax = reinterpret_cast<__nv_bfloat16*>(softmax_ptr); + int64_t base_idx = blockIdx.x * n; + + float local_max = -INFINITY; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + local_max = fmaxf(__bfloat162float(input[base_idx + i]), local_max); + } + + local_max = fmaxf(block_allreduce_max(local_max), -1e6); + + float local_sum = 0; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + local_sum += expf(__bfloat162float(input[base_idx + i]) - local_max); + } + local_sum = block_allreduce_sum(local_sum) + 1e-10; // avoid nan + + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + softmax[base_idx + i] = __float2bfloat16( expf(__bfloat162float(input[base_idx + i]) - local_max) / local_sum ); + } + + if (threadIdx.x == 0) { + if (target[blockIdx.x] != ignore_index) { + output[blockIdx.x] = -__bfloat162float(input[base_idx + target[blockIdx.x]]) + local_max + logf(local_sum); + } else { + output[blockIdx.x] = 0; + } + } +#endif +} + +// blocks <m>, threads<1024> +__global__ void cross_entropy_backward_inplace_bf16( + int64_t n, + const float *grad_output, // (m) + const int32_t *target, // (m) + std::uintptr_t x_ptr, // (m, n) + int32_t ignore_index +) { +#ifdef BF16_SUPPORT + __nv_bfloat16* x = reinterpret_cast<__nv_bfloat16*>(x_ptr); + int64_t base_idx = blockIdx.x * n; + + int32_t t = target[blockIdx.x]; + float v = grad_output[blockIdx.x]; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + x[base_idx + i] = __float2bfloat16(i==t ? (__bfloat162float(x[base_idx + i])-1)*v : __bfloat162float(x[base_idx + i])*v); + } +#endif +} + +// blocks <m>, threads<1024> +__global__ void fused_sumexp_fp16( + int64_t n, + const half *input, // (m, n) + const float *global_max, // (m) + float *global_sum // (m) +) { + int64_t base_idx = blockIdx.x * n; + float local_max = global_max[blockIdx.x]; + + float local_sum = 0; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + local_sum += expf(__half2float(input[base_idx + i]) - local_max); + } + local_sum = block_allreduce_sum(local_sum); + if (threadIdx.x == 0) { + global_sum[blockIdx.x] = local_sum; + } +} + +// blocks <m>, threads<1024> +__global__ void fused_sumexp_bf16( + int64_t n, + const std::uintptr_t input_ptr, // (m, n) + const float *global_max, // (m) + float *global_sum // (m) +) { +#ifdef BF16_SUPPORT + const __nv_bfloat16* input = reinterpret_cast<const __nv_bfloat16*>(input_ptr); + int64_t base_idx = blockIdx.x * n; + float local_max = global_max[blockIdx.x]; + + float local_sum = 0; + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + local_sum += expf(__bfloat162float(input[base_idx + i]) - local_max); + } + local_sum = block_allreduce_sum(local_sum); + if (threadIdx.x == 0) { + global_sum[blockIdx.x] = local_sum; + } +#endif +} + +// blocks <m>, threads<1024> +__global__ void fused_softmax_inplace_fp16( + int64_t n, + half *softmax, // (m, n) + const float *global_max, // (m) + const float *global_sum // (m) +) { + int64_t base_idx = blockIdx.x * n; + float local_max = global_max[blockIdx.x]; + float local_sum = global_sum[blockIdx.x]; + + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + softmax[base_idx + i] = __float2half( expf(__half2float(softmax[base_idx + i]) - local_max) / local_sum ); + } +} + +// blocks <m>, threads<1024> +__global__ void fused_softmax_inplace_bf16( + int64_t n, + std::uintptr_t softmax_ptr, // (m, n) + const float *global_max, // (m) + const float *global_sum // (m) +) { +#ifdef BF16_SUPPORT + __nv_bfloat16* softmax = reinterpret_cast<__nv_bfloat16*>(softmax_ptr); + int64_t base_idx = blockIdx.x * n; + float local_max = global_max[blockIdx.x]; + float local_sum = global_sum[blockIdx.x]; + + for (int64_t i = threadIdx.x; i < n; i += blockDim.x) { + softmax[base_idx + i] = __float2bfloat16( expf(__bfloat162float(softmax[base_idx + i]) - local_max) / local_sum ); + } +#endif +} +} + +void cross_entropy_forward_fp16_launcher( + int32_t m, int32_t n, + std::uintptr_t input, + std::uintptr_t target, + std::uintptr_t softmax, + std::uintptr_t output, + int32_t ignore_index, + std::uintptr_t stream +) { + auto input_ptr = reinterpret_cast<half*>(input); + auto target_ptr = reinterpret_cast<int32_t*>(target); + auto softmax_ptr = reinterpret_cast<half*>(softmax); + auto output_ptr = reinterpret_cast<float*>(output); + int32_t threads = 1024; + cross_entropy_forward_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index); +} + +void cross_entropy_backward_inplace_fp16_launcher( + int32_t m, int32_t n, + std::uintptr_t grad_output, + std::uintptr_t target, + std::uintptr_t x, + int32_t ignore_index, + std::uintptr_t stream +) { + auto output_ptr = reinterpret_cast<float*>(grad_output); + auto target_ptr = reinterpret_cast<int32_t*>(target); + auto x_ptr = reinterpret_cast<half*>(x); + int32_t threads = 1024; + cross_entropy_backward_inplace_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, output_ptr, target_ptr, x_ptr, ignore_index); +} + +void cross_entropy_forward_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t input, + std::uintptr_t target, + std::uintptr_t softmax, + std::uintptr_t output, + int32_t ignore_index, + std::uintptr_t stream +) { + auto target_ptr = reinterpret_cast<int32_t*>(target); + auto output_ptr = reinterpret_cast<float*>(output); + int32_t threads = 1024; + cross_entropy_forward_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, input, target_ptr, softmax, output_ptr, ignore_index); +} + +void cross_entropy_backward_inplace_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t grad_output, + std::uintptr_t target, + std::uintptr_t x, + int32_t ignore_index, + std::uintptr_t stream +) { + auto output_ptr = reinterpret_cast<float*>(grad_output); + auto target_ptr = reinterpret_cast<int32_t*>(target); + int32_t threads = 1024; + cross_entropy_backward_inplace_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, output_ptr, target_ptr, x, ignore_index); +} + +void fused_sumexp_fp16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +) { + auto logits_ptr = reinterpret_cast<half*>(logits); + auto max_logits_ptr = reinterpret_cast<float*>(max_logits); + auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits); + int32_t threads = 1024; + fused_sumexp_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr); +} + +void fused_sumexp_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +) { + auto max_logits_ptr = reinterpret_cast<float*>(max_logits); + auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits); + int32_t threads = 1024; + fused_sumexp_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits, max_logits_ptr, sum_exp_logits_ptr); +} + +void fused_softmax_inplace_fp16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +) { + auto logits_ptr = reinterpret_cast<half*>(logits); + auto max_logits_ptr = reinterpret_cast<float*>(max_logits); + auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits); + int32_t threads = 1024; + fused_softmax_inplace_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr); +} + +void fused_softmax_inplace_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +) { + auto max_logits_ptr = reinterpret_cast<float*>(max_logits); + auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits); + int32_t threads = 1024; + fused_softmax_inplace_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits, max_logits_ptr, sum_exp_logits_ptr); +} \ No newline at end of file diff --git a/examples/BMTrain/csrc/cuda/has_inf_nan.cu b/examples/BMTrain/csrc/cuda/has_inf_nan.cu new file mode 100644 index 00000000..32bc5a5f --- /dev/null +++ b/examples/BMTrain/csrc/cuda/has_inf_nan.cu @@ -0,0 +1,145 @@ +#include <cstdint> +#include <cstdio> +#include <cuda.h> +#include <cuda_fp16.h> +#include "bfloat16.cuh" + +namespace{ +__inline__ __device__ bool isnan_(half v) { + #if __CUDA_ARCH__ >= 700 || __CUDA_ARCH__ == 600 + return __hisnan(v); + #else + return !__heq(v, v); + #endif +} + +__inline__ __device__ int8_t warpReduceAny(int8_t x) { + for (int offset = warpSize/2; offset > 0; offset /= 2) + x |= __shfl_down_sync(0xFFFFFFFF, x, offset); + return x; +} + +__inline__ __device__ float blockReduceAny(int8_t x) { + static __shared__ float shared[32]; + int lane = threadIdx.x % warpSize; + int wid = threadIdx.x / warpSize; + x = warpReduceAny(x); + if (lane == 0) shared[wid] = x; + __syncthreads(); + x = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0; + if (wid == 0) x = warpReduceAny(x); + return x; +} + +// grid <min(ceil(n/1024), 1024)>, thread<1024> +__global__ void bmt_has_nan_inf_fp16( + int32_t n, + const half* inp, // (n,) + uint8_t* mid // (1024,) +) { + int32_t gid = blockIdx.x * blockDim.x + threadIdx.x; + int32_t span = blockDim.x * gridDim.x; + + int8_t r = 0; + for (int i = gid; i < n; i += span) { + half v = inp[i]; + if (__hisinf(v) || isnan_(v)) { + r = 1; + break; + } + } + r = blockReduceAny(r); + if (threadIdx.x == 0) { + mid[blockIdx.x] = r; + } +} + +// grid <1>, thread<1024> +__global__ void bmt_has_nan_inf_reduce( + const uint8_t* mid, // (1024,) + uint8_t* out +) { + int tid = threadIdx.x; + int8_t r = blockReduceAny(mid[tid]); + if (tid == 0 && r > 0) { + out[0] = 1; + } +} + +// grid <min(ceil(n/1024), 1024)>, thread<1024> +__global__ void bmt_has_nan_inf_bf16( + int32_t n, + const uintptr_t inp, // (n,) + uint8_t* mid // (1024,) +) { +#ifdef BF16_SUPPORT + const __nv_bfloat16* bf_inp = reinterpret_cast<const __nv_bfloat16*>(inp); + int32_t gid = blockIdx.x * blockDim.x + threadIdx.x; + int32_t span = blockDim.x * gridDim.x; + + int8_t r = 0; + for (int i = gid; i < n; i += span) { + __nv_bfloat16 v = bf_inp[i]; + #if __CUDA_ARCH__ >= 800 + if (__hisinf(v) || __hisnan(v)) { + #else + if (isinf(__bfloat162float(v)) || isnan(__bfloat162float(v))) { + #endif + r = 1; + break; + } + } + r = blockReduceAny(r); + if (threadIdx.x == 0) { + mid[blockIdx.x] = r; + } +#endif +} + +} + +void has_nan_inf_fp16_launcher( + int32_t n, + std::uintptr_t g_fp16, + std::uintptr_t mid, + std::uintptr_t out, + std::uintptr_t stream +) { + if (n <= 0) return; + auto g_ptr = reinterpret_cast<half*>(g_fp16); + auto mid_ptr = reinterpret_cast<uint8_t*>(mid); + auto out_ptr = reinterpret_cast<uint8_t*>(out); + int32_t threads = 1024; + dim3 block_size = dim3(threads, 1, 1); + dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); + dim3 clamp_grid_size = dim3(min((n + threads - 1) / threads, 1024), 1, 1); + + bmt_has_nan_inf_fp16<<<clamp_grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, mid_ptr); + bmt_has_nan_inf_reduce<<<1, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(mid_ptr, out_ptr); +} + +void has_nan_inf_bf16_launcher( + int32_t n, + std::uintptr_t g_bf16, + std::uintptr_t mid, + std::uintptr_t out, + std::uintptr_t stream +) { + if (n <= 0) return; + auto mid_ptr = reinterpret_cast<uint8_t*>(mid); + auto out_ptr = reinterpret_cast<uint8_t*>(out); + int32_t threads = 1024; + dim3 block_size = dim3(threads, 1, 1); + dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1); + dim3 clamp_grid_size = dim3(min((n + threads - 1) / threads, 1024), 1, 1); + + bmt_has_nan_inf_bf16<<<clamp_grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_bf16, mid_ptr); + bmt_has_nan_inf_reduce<<<1, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(mid_ptr, out_ptr); +} + +int is_bf16_supported() { +#ifdef BF16_SUPPORT + return 1; +#endif + return 0; +} \ No newline at end of file diff --git a/examples/BMTrain/csrc/cuda/reduce.cuh b/examples/BMTrain/csrc/cuda/reduce.cuh new file mode 100644 index 00000000..a9c4c15b --- /dev/null +++ b/examples/BMTrain/csrc/cuda/reduce.cuh @@ -0,0 +1,114 @@ +namespace { +const int WARP_SZ = 32; + +// blocks <block_size>, threads<1024> +__device__ float block_reduce_sum(float val) { + static __shared__ float s_x[WARP_SZ]; + // int gid = threadIdx.x + blockIdx.x * blockDim.x; + int tid = threadIdx.x; + int lid = threadIdx.x % WARP_SZ; + int wid = threadIdx.x / WARP_SZ; + + // reduce intra warp + + for (int offset = WARP_SZ/2; offset > 0; offset >>= 1) + val += __shfl_down_sync(0xFFFFFFFF, val, offset); + + if (lid == 0) s_x[wid] = val; + __syncthreads(); + + // reduce inter warp + val = (tid < WARP_SZ) ? s_x[lid] : 0; + if (wid == 0) { + for (int offset = WARP_SZ/2; offset > 0; offset >>= 1) + val += __shfl_down_sync(0xFFFFFFFF, val, offset); + } + return val; +} + +// blocks <block_size>, threads<1024> +__device__ float block_reduce_max(float val) { + static __shared__ float s_x[WARP_SZ]; + // int gid = threadIdx.x + blockIdx.x * blockDim.x; + int tid = threadIdx.x; + int lid = threadIdx.x % WARP_SZ; + int wid = threadIdx.x / WARP_SZ; + + // reduce intra warp + + for (int offset = WARP_SZ/2; offset > 0; offset >>= 1) + val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset)); + + if (lid == 0) s_x[wid] = val; + __syncthreads(); + + // reduce inter warp + val = (tid < WARP_SZ) ? s_x[lid] : -INFINITY; + if (wid == 0) { + for (int offset = WARP_SZ/2; offset > 0; offset >>= 1) + val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset)); + } + return val; +} + +// blocks <block_size>, threads<1024> +__device__ float block_allreduce_sum(float val) { + static __shared__ float s_x[WARP_SZ]; + // int gid = threadIdx.x + blockIdx.x * blockDim.x; + int tid = threadIdx.x; + int lid = threadIdx.x % WARP_SZ; + int wid = threadIdx.x / WARP_SZ; + + // reduce intra warp + + for (int offset = WARP_SZ/2; offset > 0; offset >>= 1) + val += __shfl_down_sync(0xFFFFFFFF, val, offset); + + if (lid == 0) s_x[wid] = val; + __syncthreads(); + + // reduce inter warp + val = (tid < WARP_SZ) ? s_x[lid] : 0; + if (wid == 0) { + for (int offset = WARP_SZ/2; offset > 0; offset >>= 1) + val += __shfl_down_sync(0xFFFFFFFF, val, offset); + } + + if (tid == 0) { + s_x[0] = val; + } + __syncthreads(); + return s_x[0]; +} + +// blocks <block_size>, threads<1024> +__device__ float block_allreduce_max(float val) { + static __shared__ float s_x[WARP_SZ]; + // int gid = threadIdx.x + blockIdx.x * blockDim.x; + int tid = threadIdx.x; + int lid = threadIdx.x % WARP_SZ; + int wid = threadIdx.x / WARP_SZ; + + // reduce intra warp + + for (int offset = WARP_SZ/2; offset > 0; offset >>= 1) + val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset)); + + if (lid == 0) s_x[wid] = val; + __syncthreads(); + + // reduce inter warp + val = (tid < WARP_SZ) ? s_x[lid] : -INFINITY; + if (wid == 0) { + for (int offset = WARP_SZ/2; offset > 0; offset >>= 1) + val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset)); + } + + if (tid == 0) { + s_x[0] = val; + } + __syncthreads(); + return s_x[0]; +} + +} \ No newline at end of file diff --git a/examples/BMTrain/csrc/include/adam_cpu.hpp b/examples/BMTrain/csrc/include/adam_cpu.hpp new file mode 100644 index 00000000..52575d69 --- /dev/null +++ b/examples/BMTrain/csrc/include/adam_cpu.hpp @@ -0,0 +1,557 @@ +#include <emmintrin.h> +#include <immintrin.h> +#include <cmath> +#include <cstdint> +#include <sched.h> +#include <pybind11/pybind11.h> +#include <iostream> +#include <mutex> +#include <vector> +#include <thread> +#include <algorithm> +#include "cpu_info.h" +#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") + +static inline float _mm256_reduce_add_ps(__m256 x) { + /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ + const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x)); + /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */ + const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */ + const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + /* Conversion to float is a no-op on x86-64 */ + return _mm_cvtss_f32(x32); +} + +inline float fp32_from_bits(uint32_t w) { + union { + uint32_t as_bits; + float as_value; + } fp32 = {w}; + return fp32.as_value; +} + +inline uint32_t fp32_to_bits(float f) { + union { + float as_value; + uint32_t as_bits; + } fp32 = {f}; + return fp32.as_bits; +} + +template <class F> +inline void parallel_for(int64_t begin, int64_t end, int64_t grain_size, const F& f) { + // Number of iterations + int64_t numiter = end - begin; + + // Number of threads to use + int64_t num_threads = 1; // Default to serial execution + + if (grain_size > 0) { + num_threads = std::max((numiter+grain_size-1) / grain_size, static_cast<int64_t>(1)); + } + else{ + cpu_set_t cpu_set; + CPU_ZERO(&cpu_set); + sched_getaffinity(0, sizeof(cpu_set), &cpu_set); + num_threads = CPU_COUNT(&cpu_set); + grain_size = std::max((numiter+num_threads-1) / num_threads, static_cast<int64_t>(1)); + + } + + // Check if parallel execution is feasible + if (num_threads > 1) { + py::gil_scoped_release release; // Release the GIL + std::vector<std::thread> threads(num_threads); + for (int64_t t = 0; t < num_threads; ++t) { + threads[t] = std::thread([&, t]() { + int64_t left = std::min(begin + t * grain_size, end); + int64_t right = std::min(begin + (t + 1) * grain_size, end); + f(left, right); + }); + } + for (auto& thread : threads) { + thread.join(); + } + } else { + // If not feasible or grain_size is 0, perform the operation serially + f(begin, end); + } +} + +// fp32 -> fp16 +inline uint16_t fp16_ieee_from_fp32_value(float f) { + // const float scale_to_inf = 0x1.0p+112f; + // const float scale_to_zero = 0x1.0p-110f; + uint32_t scale_to_inf_bits = (uint32_t) 239 << 23; + uint32_t scale_to_zero_bits = (uint32_t) 17 << 23; + float scale_to_inf_val, scale_to_zero_val; + std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val)); + std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val)); + const float scale_to_inf = scale_to_inf_val; + const float scale_to_zero = scale_to_zero_val; + + float base = (fabsf(f) * scale_to_inf) * scale_to_zero; + + const uint32_t w = (uint32_t)fp32_to_bits(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } + + base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = (uint32_t)fp32_to_bits(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return static_cast<uint16_t>( + (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign) + ); +} + +// fp16 -> fp32 +inline float fp16_ieee_to_fp32_value(uint16_t h) { + const uint32_t w = (uint32_t)h << 16; + const uint32_t sign = w & UINT32_C(0x80000000); + const uint32_t two_w = w + w; + + const uint32_t exp_offset = UINT32_C(0xE0) << 23; + const float exp_scale = 0x1.0p-112f; + const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale; + + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias; + + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = + sign | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) + : fp32_to_bits(normalized_value)); + return fp32_from_bits(result); +} + +inline uint16_t bf16_from_fp32_value(float f){ + return *reinterpret_cast<uint32_t*>(&f) >> 16; +} +// fp32 -> bf16 +void bf16_from_fp32_value_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16 +){ + int span = 1; + auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32); + auto param_bf16_ptr = reinterpret_cast<uint16_t*>(param_bf16); + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + for (int64_t j = start; j < end; j += span) { + for (int64_t i = j; i < end; i++) { + float p = param_fp32_ptr[i]; + param_bf16_ptr[i] = bf16_from_fp32_value(p); + } + break; // must break here + } + }); +} + +void fp16_from_fp32_value_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_fp16 +){ + int span = 1; + auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32); + auto param_fp16_ptr = reinterpret_cast<uint16_t*>(param_fp16); + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + for (int64_t j = start; j < end; j += span) { + for (int64_t i = j; i < end; i++) { + float p = param_fp32_ptr[i]; + param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); + } + break; // must break here + } + }); +} +// bf16 -> fp32 +inline float bf16_to_fp32_value(uint16_t h){ + uint32_t src = h; + src <<= 16; + return *reinterpret_cast<float*>(&src); +} + +void adam_cpu_0( + int64_t n, + float* param_fp32_ptr, + uint16_t* param_fp16_ptr, + float* delta_info_ptr, + uint16_t* g_fp16_ptr, + float* m_fp32_ptr, + float* v_fp32_ptr, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +){ + int64_t span = 1; + float sum_sq_delta = 0; + float sum_delta = 0; + std::mutex delta_mutex; + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + float sum_sq_delta_i = 0; + float sum_delta_i = 0; + for (int64_t j = start; j < end; j += span) { + for (int64_t i = j; i < end; i++) { + float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale; + float m = m_fp32_ptr[i]; + float v = v_fp32_ptr[i]; + float p = param_fp32_ptr[i]; + m = beta1 * m + (1 - beta1) * g; + v = beta2 * v + (1 - beta2) * g * g; + if (delta_info_ptr != NULL){ + float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p; + sum_delta_i += delta; + sum_sq_delta_i += delta * delta; + } + p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; + param_fp32_ptr[i] = p; + param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); + m_fp32_ptr[i] = m; + v_fp32_ptr[i] = v; + } + break; // must break here + } + if (delta_info_ptr != NULL){ + delta_mutex.lock(); + sum_delta += sum_delta_i; + sum_sq_delta += sum_sq_delta_i; + delta_mutex.unlock(); + } + }); + if (delta_info_ptr != NULL){ + delta_info_ptr[0] = sum_delta / n; + delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2 + delta_info_ptr[2] = sum_delta; + delta_info_ptr[3] = sum_sq_delta; + } +} + +void adam_cpu_bf16_0( + int64_t n, + float* param_fp32_ptr, + uint16_t* param_bf16_ptr, + float* delta_info_ptr, + uint16_t* g_bf16_ptr, + float* m_fp32_ptr, + float* v_fp32_ptr, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +){ + int64_t span = 1; + float sum_sq_delta = 0; + float sum_delta = 0; + std::mutex delta_mutex; + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + float sum_sq_delta_i = 0; + float sum_delta_i = 0; + for (int64_t j = start; j < end; j += span) { + for (int64_t i = j; i < end; i++) { + float g = bf16_to_fp32_value(g_bf16_ptr[i]) / scale; + float m = m_fp32_ptr[i]; + float v = v_fp32_ptr[i]; + float p = param_fp32_ptr[i]; + m = beta1 * m + (1 - beta1) * g; + v = beta2 * v + (1 - beta2) * g * g; + if (delta_info_ptr != NULL){ + float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p; + sum_delta_i += delta; + sum_sq_delta_i += delta * delta; + } + p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; + param_fp32_ptr[i] = p; + param_bf16_ptr[i] = bf16_from_fp32_value(p); + m_fp32_ptr[i] = m; + v_fp32_ptr[i] = v; + } + break; // must break here + } + if (delta_info_ptr != NULL){ + delta_mutex.lock(); + sum_delta += sum_delta_i; + sum_sq_delta += sum_sq_delta_i; + delta_mutex.unlock(); + } + }); + if (delta_info_ptr != NULL){ + delta_info_ptr[0] = sum_delta / n; + delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2 + delta_info_ptr[2] = sum_delta; + delta_info_ptr[3] = sum_sq_delta; + } +} + +static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1( + int64_t n, + float* param_fp32_ptr, + uint16_t* param_fp16_ptr, + float* delta_info_ptr, + uint16_t* g_fp16_ptr, + float* m_fp32_ptr, + float* v_fp32_ptr, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +){ + float sum_sq_delta = 0; + float sum_delta = 0; + std::mutex delta_mutex; + auto avx_beta1 = _mm256_set1_ps(beta1); + auto avx_beta2 = _mm256_set1_ps(beta2); + auto avx_beta1_1 = _mm256_set1_ps(1 - beta1); + auto avx_beta2_1 = _mm256_set1_ps(1 - beta2); + auto avx_eps = _mm256_set1_ps(eps); + auto avx_neg_lr = _mm256_set1_ps(-lr); + auto avx_scale = _mm256_set1_ps(scale); + auto avx_weight_decay = _mm256_set1_ps(weight_decay); + auto avx_bias_correction1 = _mm256_set1_ps(bias_correction1); + auto avx_bias_correction2 = _mm256_set1_ps(bias_correction2); + int64_t span = 8; + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + float sum_sq_delta_i = 0; + float sum_delta_i = 0; + for (int64_t j = start; j < end; j += span) { + if (j + span > end) { + for (int64_t i = j; i < end; i++) { + float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale; + float m = m_fp32_ptr[i]; + float v = v_fp32_ptr[i]; + float p = param_fp32_ptr[i]; + m = beta1 * m + (1 - beta1) * g; + v = beta2 * v + (1 - beta2) * g * g; + if (delta_info_ptr != NULL){ + float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p; + sum_delta_i += delta; + sum_sq_delta_i += delta * delta; + } + p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; + param_fp32_ptr[i] = p; + param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); + m_fp32_ptr[i] = m; + v_fp32_ptr[i] = v; + } + break; // must break here + } else { + auto g = _mm256_div_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)&g_fp16_ptr[j])), avx_scale); + auto m = _mm256_loadu_ps(&m_fp32_ptr[j]); + auto v = _mm256_loadu_ps(&v_fp32_ptr[j]); + auto p = _mm256_loadu_ps(¶m_fp32_ptr[j]); + m = _mm256_fmadd_ps(avx_beta1, m, _mm256_mul_ps(avx_beta1_1, g)); + v = _mm256_fmadd_ps(avx_beta2, v, _mm256_mul_ps(avx_beta2_1, _mm256_mul_ps(g, g))); + if (delta_info_ptr != NULL){ + auto delta_256 = _mm256_add_ps( + _mm256_div_ps( + _mm256_div_ps(m, avx_bias_correction1), // m / bias_correction1 + _mm256_add_ps(_mm256_sqrt_ps(_mm256_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps + ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + _mm256_mul_ps(avx_weight_decay, p) // weight_decay * p + ); // delta = m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + weight_decay * p + sum_delta_i += _mm256_reduce_add_ps(delta_256); + sum_sq_delta_i += _mm256_reduce_add_ps(_mm256_mul_ps(delta_256, delta_256)); + } + p = _mm256_fmadd_ps(avx_neg_lr, _mm256_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p + p = _mm256_fmadd_ps( + avx_neg_lr, + _mm256_div_ps( + _mm256_div_ps(m, avx_bias_correction1), // m / bias_correction1 + _mm256_add_ps(_mm256_sqrt_ps(_mm256_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps + ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + p + ); // p = p - lr * m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + _mm256_storeu_ps(¶m_fp32_ptr[j], p); + _mm_storeu_si128((__m128i*)¶m_fp16_ptr[j], _mm256_cvtps_ph(p, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + _mm256_storeu_ps(&m_fp32_ptr[j], m); + _mm256_storeu_ps(&v_fp32_ptr[j], v); + } + } + if (delta_info_ptr != NULL){ + delta_mutex.lock(); + sum_delta += sum_delta_i; + sum_sq_delta += sum_sq_delta_i; + delta_mutex.unlock(); + } + }); + if (delta_info_ptr != NULL){ + delta_info_ptr[0] = sum_delta / n; + delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2 + delta_info_ptr[2] = sum_delta; + delta_info_ptr[3] = sum_sq_delta; + } +} + +static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2( + int64_t n, + float* param_fp32_ptr, + uint16_t* param_fp16_ptr, + float* delta_info_ptr, + uint16_t* g_fp16_ptr, + float* m_fp32_ptr, + float* v_fp32_ptr, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +){ + float sum_sq_delta = 0; + float sum_delta = 0; + std::mutex delta_mutex; + auto avx_beta1 = _mm512_set1_ps(beta1); + auto avx_beta2 = _mm512_set1_ps(beta2); + auto avx_beta1_1 = _mm512_set1_ps(1 - beta1); + auto avx_beta2_1 = _mm512_set1_ps(1 - beta2); + auto avx_eps = _mm512_set1_ps(eps); + auto avx_neg_lr = _mm512_set1_ps(-lr); + auto avx_scale = _mm512_set1_ps(scale); + auto avx_weight_decay = _mm512_set1_ps(weight_decay); + auto avx_bias_correction1 = _mm512_set1_ps(bias_correction1); + auto avx_bias_correction2 = _mm512_set1_ps(bias_correction2); + int64_t span = 16; + parallel_for(0, n, 0, [&](int64_t start, int64_t end) { + float sum_sq_delta_i = 0; + float sum_delta_i = 0; + for (int64_t j = start; j < end; j += span) { + if (j + span > end) { + for (int64_t i = j; i < end; i++) { + float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale; + float m = m_fp32_ptr[i]; + float v = v_fp32_ptr[i]; + float p = param_fp32_ptr[i]; + m = beta1 * m + (1 - beta1) * g; + v = beta2 * v + (1 - beta2) * g * g; + if (delta_info_ptr != NULL){ + float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p; + sum_delta_i += delta; + sum_sq_delta_i += delta * delta; + } + p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p; + param_fp32_ptr[i] = p; + param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p); + m_fp32_ptr[i] = m; + v_fp32_ptr[i] = v; + } + break; // must break here + }else{ + auto g = _mm512_div_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)&g_fp16_ptr[j])), avx_scale); + auto m = _mm512_loadu_ps(&m_fp32_ptr[j]); + auto v = _mm512_loadu_ps(&v_fp32_ptr[j]); + auto p = _mm512_loadu_ps(¶m_fp32_ptr[j]); + m = _mm512_fmadd_ps(avx_beta1, m, _mm512_mul_ps(avx_beta1_1, g)); + v = _mm512_fmadd_ps(avx_beta2, v, _mm512_mul_ps(avx_beta2_1, _mm512_mul_ps(g, g))); + if (delta_info_ptr != NULL){ + auto delta_512 = _mm512_add_ps( + _mm512_div_ps( + _mm512_div_ps(m, avx_bias_correction1), // m / bias_correction1 + _mm512_add_ps(_mm512_sqrt_ps(_mm512_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps + ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + _mm512_mul_ps(avx_weight_decay, p) // weight_decay * p + ); // delta = m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + weight_decay * p + sum_delta_i += _mm512_reduce_add_ps(delta_512); + sum_sq_delta_i += _mm512_reduce_add_ps(_mm512_mul_ps(delta_512, delta_512)); + } + p = _mm512_fmadd_ps(avx_neg_lr, _mm512_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p + p = _mm512_fmadd_ps( + avx_neg_lr, + _mm512_div_ps( + _mm512_div_ps(m, avx_bias_correction1), // m / bias_correction1 + _mm512_add_ps( + _mm512_sqrt_ps(_mm512_div_ps(v, avx_bias_correction2)), + avx_eps + ) // sqrt(v / bias_correction2) + eps + ), + p + ); // p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + _mm512_storeu_ps(¶m_fp32_ptr[j], p); + _mm256_storeu_si256((__m256i*)¶m_fp16_ptr[j], _mm512_cvtps_ph(p, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + _mm512_storeu_ps(&m_fp32_ptr[j], m); + _mm512_storeu_ps(&v_fp32_ptr[j], v); + } + } + if (delta_info_ptr != NULL){ + delta_mutex.lock(); + sum_delta += sum_delta_i; + sum_sq_delta += sum_sq_delta_i; + delta_mutex.unlock(); + } + }); + if (delta_info_ptr != NULL){ + delta_info_ptr[0] = sum_delta / n; + delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2 + delta_info_ptr[2] = sum_delta; + delta_info_ptr[3] = sum_sq_delta; + } +} + +void adam_cpu_fp16_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_fp16, + std::uintptr_t delta_info, + std::uintptr_t g_fp16, + std::uintptr_t m_fp32, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +) { + auto delta_info_ptr = reinterpret_cast<float*>(delta_info); + auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32); + auto m_fp32_ptr = reinterpret_cast<float*>(m_fp32); + auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32); + auto param_fp16_ptr = reinterpret_cast<uint16_t*>(param_fp16); + auto g_fp16_ptr = reinterpret_cast<uint16_t*>(g_fp16); + int cpu_level = get_cpu_level(); + if (cpu_level == 0 ){ + adam_cpu_0(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); + }else if(cpu_level == 1){ + adam_cpu_1(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); + }else{ + adam_cpu_2(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); + } +} + +void adam_cpu_bf16_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16, + std::uintptr_t delta_info, + std::uintptr_t g_bf16, + std::uintptr_t m_fp32, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2 +) { + auto delta_info_ptr = reinterpret_cast<float*>(delta_info); + auto m_fp32_ptr = reinterpret_cast<float*>(m_fp32); + auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32); + auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32); + auto param_bf16_ptr = reinterpret_cast<uint16_t*>(param_bf16); + auto g_bf16_ptr = reinterpret_cast<uint16_t*>(g_bf16); + adam_cpu_bf16_0(n, param_fp32_ptr, param_bf16_ptr, delta_info_ptr, g_bf16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2); +} diff --git a/examples/BMTrain/csrc/include/bind.hpp b/examples/BMTrain/csrc/include/bind.hpp new file mode 100644 index 00000000..3ff967fd --- /dev/null +++ b/examples/BMTrain/csrc/include/bind.hpp @@ -0,0 +1,111 @@ +#include <pybind11/pybind11.h> +#include "nccl.hpp" +#include "adam_cpu.hpp" + +int is_bf16_supported(); + +void has_nan_inf_fp16_launcher(int32_t n, std::uintptr_t g_fp16, std::uintptr_t mid, std::uintptr_t out, std::uintptr_t stream); +void has_nan_inf_bf16_launcher(int32_t n, std::uintptr_t g_bf16, std::uintptr_t mid, std::uintptr_t out, std::uintptr_t stream); + +void fp16_from_fp32_value_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_fp16 +); +void bf16_from_fp32_value_launcher( + int64_t n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16 +); +void cross_entropy_forward_fp16_launcher( + int32_t m, int32_t n, + std::uintptr_t input, + std::uintptr_t target, + std::uintptr_t softmax, + std::uintptr_t output, + int32_t ignore_index, + std::uintptr_t stream +); +void cross_entropy_backward_inplace_fp16_launcher( + int32_t m, int32_t n, + std::uintptr_t grad_output, + std::uintptr_t target, + std::uintptr_t x, + int32_t ignore_index, + std::uintptr_t stream +); +void cross_entropy_forward_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t input, + std::uintptr_t target, + std::uintptr_t softmax, + std::uintptr_t output, + int32_t ignore_index, + std::uintptr_t stream +); +void cross_entropy_backward_inplace_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t grad_output, + std::uintptr_t target, + std::uintptr_t x, + int32_t ignore_index, + std::uintptr_t stream +); +void fused_sumexp_fp16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +); +void fused_sumexp_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +); +void fused_softmax_inplace_fp16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +); +void fused_softmax_inplace_bf16_launcher( + int32_t m, int32_t n, + std::uintptr_t logits, + std::uintptr_t max_logits, + std::uintptr_t sum_exp_logits, + std::uintptr_t stream +); +void adam_fp16_launcher( + int n, + std::uintptr_t param_fp32, + std::uintptr_t param_fp16, + std::uintptr_t g_fp16, + std::uintptr_t m_fp16, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2, + uintptr_t stream +); +void adam_bf16_launcher( + int n, + std::uintptr_t param_fp32, + std::uintptr_t param_bf16, + std::uintptr_t g_bf16, + std::uintptr_t m_fp32, + std::uintptr_t v_fp32, + float beta1, float beta2, + float eps, float lr, + float scale, + float weight_decay, + float bias_correction1, + float bias_correction2, + uintptr_t stream +); diff --git a/examples/BMTrain/csrc/include/cpu_info.h b/examples/BMTrain/csrc/include/cpu_info.h new file mode 100644 index 00000000..53ed48f8 --- /dev/null +++ b/examples/BMTrain/csrc/include/cpu_info.h @@ -0,0 +1,38 @@ +#include <cpuid.h> + +static void cpuid(int info[4], int InfoType){ + __cpuid_count(InfoType, 0, info[0], info[1], info[2], info[3]); +} + +int get_cpu_level() { + // SIMD: 128-bit + bool HW_F16C; + + // SIMD: 256-bit + bool HW_AVX; + bool HW_FMA; + + // SIMD: 512-bit + bool HW_AVX512F; // AVX512 Foundation + + int info[4]; + cpuid(info, 0); + int nIds = info[0]; + + // Detect Features + if (nIds >= 0x00000001){ + cpuid(info,0x00000001); + HW_AVX = (info[2] & ((int)1 << 28)) != 0; + HW_FMA = (info[2] & ((int)1 << 12)) != 0; + HW_F16C = (info[2] & ((int)1 << 29)) != 0; + } + if (nIds >= 0x00000007){ + cpuid(info,0x00000007); + HW_AVX512F = (info[1] & ((int)1 << 16)) != 0; + } + + int ret = 0; + if (HW_AVX && HW_FMA && HW_F16C) ret = 1; + if (HW_AVX512F) ret = 2; + return ret; +} diff --git a/examples/BMTrain/csrc/include/nccl.hpp b/examples/BMTrain/csrc/include/nccl.hpp new file mode 100644 index 00000000..bba0278b --- /dev/null +++ b/examples/BMTrain/csrc/include/nccl.hpp @@ -0,0 +1,188 @@ +#include <cstdint> +#include <string> +#include <pybind11/pybind11.h> + +namespace py = pybind11; +#include <nccl.h> + +void checkNCCLStatus(ncclResult_t result) { + if (result == ncclSuccess) return; + throw std::logic_error( + std::string("NCCL Error: ") + + ncclGetErrorString(result) + ); +} + +py::bytes pyNCCLGetUniqueID() { + ncclUniqueId uniqueID; + checkNCCLStatus(ncclGetUniqueId(&uniqueID)); + return py::bytes(uniqueID.internal, NCCL_UNIQUE_ID_BYTES); +} + +std::uintptr_t pyNCCLCommInitRank(py::bytes byteUniqueID, int world_size, int rank) { + ncclUniqueId uniqueID; + std::memcpy(uniqueID.internal, std::string(byteUniqueID).c_str(), NCCL_UNIQUE_ID_BYTES); + ncclComm_t comm; + checkNCCLStatus(ncclCommInitRank(&comm, world_size, uniqueID, rank)); + return reinterpret_cast<std::uintptr_t>(comm); +} + +void pyNCCLCommDestroy(std::uintptr_t ptrcomm) { + ncclComm_t comm = reinterpret_cast<ncclComm_t>(ptrcomm); + checkNCCLStatus(ncclCommDestroy(comm)); +} + +void pyNCCLAllGather( + std::uintptr_t sendbuff, + std::uintptr_t recvbuff, + size_t sendcount, + int datatype, + std::uintptr_t comm, + std::uintptr_t stream +) { + checkNCCLStatus(ncclAllGather( + reinterpret_cast<void*>(sendbuff), + reinterpret_cast<void*>(recvbuff), + sendcount, + static_cast<ncclDataType_t>(datatype), + reinterpret_cast<ncclComm_t>(comm), + reinterpret_cast<cudaStream_t>(stream) + )); +} + +void pyNCCLAllReduce( + std::uintptr_t sendbuff, + std::uintptr_t recvbuff, + size_t count, + int data_type, + int op, + std::uintptr_t comm, + std::uintptr_t stream +) { + checkNCCLStatus(ncclAllReduce( + reinterpret_cast<void*>(sendbuff), + reinterpret_cast<void*>(recvbuff), + count, + static_cast<ncclDataType_t>(data_type), + static_cast<ncclRedOp_t>(op), + reinterpret_cast<ncclComm_t>(comm), + reinterpret_cast<cudaStream_t>(stream) + )); +} + +void pyNCCLBroadcast( + std::uintptr_t sendbuff, + std::uintptr_t recvbuff, + size_t count, + int datatype, + int root, + std::uintptr_t comm, + std::uintptr_t stream +) { + checkNCCLStatus(ncclBroadcast( + reinterpret_cast<void*>(sendbuff), + reinterpret_cast<void*>(recvbuff), + count, + static_cast<ncclDataType_t>(datatype), + root, + reinterpret_cast<ncclComm_t>(comm), + reinterpret_cast<cudaStream_t>(stream) + )); +} + +void pyNCCLReduce( + std::uintptr_t sendbuff, + std::uintptr_t recvbuff, + size_t count, + int datatype, + int op, + int root, + std::uintptr_t comm, + std::uintptr_t stream +) { + checkNCCLStatus(ncclReduce( + reinterpret_cast<void*>(sendbuff), + reinterpret_cast<void*>(recvbuff), + count, + static_cast<ncclDataType_t>(datatype), + static_cast<ncclRedOp_t>(op), + root, + reinterpret_cast<ncclComm_t>(comm), + reinterpret_cast<cudaStream_t>(stream) + )); +} + +void pyNCCLReduceScatter( + std::uintptr_t sendbuff, + std::uintptr_t recvbuff, + size_t recvcount, + int datatype, + int op, + std::uintptr_t comm, + std::uintptr_t stream +) { + checkNCCLStatus(ncclReduceScatter( + reinterpret_cast<void*>(sendbuff), + reinterpret_cast<void*>(recvbuff), + recvcount, + static_cast<ncclDataType_t>(datatype), + static_cast<ncclRedOp_t>(op), + reinterpret_cast<ncclComm_t>(comm), + reinterpret_cast<cudaStream_t>(stream) + )); +} +void pyNCCLSend( + std::uintptr_t sendbuff, + size_t sendcount, + int data_type, + int peer, + std::uintptr_t comm, + std::uintptr_t stream +) { + checkNCCLStatus(ncclSend( + reinterpret_cast<void*>(sendbuff), + sendcount, + static_cast<ncclDataType_t>(data_type), + peer, + reinterpret_cast<ncclComm_t>(comm), + reinterpret_cast<cudaStream_t>(stream) + )); +} +void pyNCCLRecv( + std::uintptr_t recvbuff, + size_t recvcount, + int data_type, + int peer, + std::uintptr_t comm, + std::uintptr_t stream +) { + checkNCCLStatus(ncclRecv( + reinterpret_cast<void*>(recvbuff), + recvcount, + static_cast<ncclDataType_t>(data_type), + peer, + reinterpret_cast<ncclComm_t>(comm), + reinterpret_cast<cudaStream_t>(stream) + )); +} +void pyNCCLGroupStart() { + checkNCCLStatus(ncclGroupStart()); +} + +void pyNCCLGroupEnd() { + checkNCCLStatus(ncclGroupEnd()); +} +int pyNCCLCommCount( + std::uintptr_t comm +){ + int res; + checkNCCLStatus(ncclCommCount(reinterpret_cast<ncclComm_t>(comm),&res)); + return res; +} +int pyNCCLCommUserRank( + std::uintptr_t comm +){ + int rank; + checkNCCLStatus(ncclCommUserRank(reinterpret_cast<ncclComm_t>(comm),&rank)); + return rank; +} diff --git a/examples/BMTrain/doc_requirements.txt b/examples/BMTrain/doc_requirements.txt new file mode 100644 index 00000000..79d22ca0 --- /dev/null +++ b/examples/BMTrain/doc_requirements.txt @@ -0,0 +1,5 @@ +sphinx>=4.0.0 +recommonmark +sphinx_markdown_tables +sphinx_rtd_theme>=0.3.0 +torch \ No newline at end of file diff --git a/examples/BMTrain/docs/Makefile b/examples/BMTrain/docs/Makefile new file mode 100644 index 00000000..4f2fbe66 --- /dev/null +++ b/examples/BMTrain/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source-en +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/examples/BMTrain/docs/UPDATE_0.2.0.md b/examples/BMTrain/docs/UPDATE_0.2.0.md new file mode 100644 index 00000000..92819afd --- /dev/null +++ b/examples/BMTrain/docs/UPDATE_0.2.0.md @@ -0,0 +1,79 @@ +# Update Log 0.2.0 + +## What's New + +### 1. Added an `Optimizer Manager` to support various optimizer algorithms. + +Before 0.2.0, the `optimizer` was strongly coupled to the "loss scaler". This results in users cannot use multiple optimizers at the same time when training model in fp16. + +**======= Before 0.2.0 =======** + +```python +for iteration in range(1000): + # zero grad + optimizer.zero_grad() + + # ... + # loss scale and backward + loss = optimizer.loss_scale(loss) + loss.backward() + + # optimizer step + bmtrain.optim_step(optimizer, lr_scheduler) +``` + +The `bmtrain.optim_step` allows only one `optimizer` and at most one `lr_schduler`, which cannot handle some more complex scenarios. + + +**======= After 0.2.0 =======** + +```python +# create a new instance of optimizer manager +optim_manager = bmtrain.optim.OptimManager(loss_scale=1024) +# let optim_manager handle all the optimizer and (optional) their corresponding lr_scheduler +optim_manager.add_optimizer(optimizer, lr_scheduler) +# add_optimizer can be called multiple times to add other optimizers. + +for iteration in range(1000): + # zero grad + optim_manager.zero_grad() # calling zero_grad for each optimizer + + # ... + # loss scale and backward + optim_manager.backward(loss) + + # optimizer step + optim_manager.step() +``` + +Starting from BMTrain 0.2.0, we provide "OptimManager" to manage optimizers and loss scales. +`OptimManager` supports managing multiple optimizers and lr_schedulers at the same time, and allows setting the loss scale independently. +`OptimManager` can also manage pytorch native optimizers, such as SGD, AdamW, etc. + +### 2. Pipeline Parallelism + +In this version, BMTrain has added a new kind of parallel algorithm: pipeline parallelism. +To enable pipeline parallelism, one line of code needs to be modified. + +**======= ZeRO =======** +```python +layers = bmt.TransformerBlockList([ + # ... +]) +``` + +**======= Pipeline =======** +```python +layers = bmt.PipelineTransformerBlockList([ + # ... +]) +``` + +Replacing TransformerBlockList with PipelineTransformerBlockList allows the parallel algorithm to switch from ZeRO to pipeline parallelism. +The number of stages in the pipeline can be set by passing the `pipe_size` parameter to bmtrain.init_distributed. + +### 3. Others + +* Supports BF16. +* Tensors recorded in inspector supports backward propagation. +* Adds new tests. diff --git a/examples/BMTrain/docs/UPDATE_0.2.3.md b/examples/BMTrain/docs/UPDATE_0.2.3.md new file mode 100644 index 00000000..e95c6867 --- /dev/null +++ b/examples/BMTrain/docs/UPDATE_0.2.3.md @@ -0,0 +1,26 @@ +# Update Log 0.2.3 + +**Full Changelog**: https://github.com/OpenBMB/BMTrain/compare/0.2.0...0.2.3 + + +## What's New + +### 1. Get rid of torch cpp extension when compiling + +Before 0.2.3, the installation of BMTrain requires the torch cpp extension, which is not friendly to some users (it requires CUDA Runtime fits with torch). Now we get rid of the torch cpp extension when compiling BMTrain, which makes the source-code way installation of BMTrain more convenient. +Just run `pip install .` to install BMTrain using source code. + +### 2. CICD + +In 0.2.3, we bring the Github action CICD to BMTrain. Now we can run the CI/CD pipeline on Github to ensure the quality of the code. CICD will run the test cases and compile the source code into wheel packages. + +### 3. Loss scale management + +In 0.2.3, we add the min and max loss scale to the loss scale manager. The loss scale manager can adjust the loss scale dynamically according to the loss scale's min and max value. This feature can help users to avoid the loss scale being too large or too small. + + +### 3. Others + +* Fix `bmt.load(model)` OOM when meets torch >= 1.12 +* `AdamOffloadOptimizer` can choose avx flag automatically in runtime +* Now BMTrain is fully compatible with torch 2.0 diff --git a/examples/BMTrain/docs/UPDATE_1.0.0.md b/examples/BMTrain/docs/UPDATE_1.0.0.md new file mode 100644 index 00000000..da9fe86e --- /dev/null +++ b/examples/BMTrain/docs/UPDATE_1.0.0.md @@ -0,0 +1,72 @@ +# Update Log 1.0.0 + +**Full Changelog**: https://github.com/OpenBMB/BMTrain/compare/0.2.3...1.0.0 + +## What's New + +### 1. Using pytorch's hook mechanism to refactor ZeRO, checkpoint, pipeline, communication implementation + +Now user can specify zero level of each `bmt.CheckpointBlock`. + +**======= Before 1.0.0 =======** + +```python +import bmtrain as bmt +bmt.init_distributed(zero_level=3) + +``` + +The zero level setting can only set globally and computation checkpointing can not be disabled. +For `bmt.TransformerBlockList`, it has to call a blocklist forward instead of a loop way + +**======= After 1.0.0 =======** + +```python +import bmtrain as bmt +bmt.init_distributed() +# construct block +class Transformer(bmt.DistributedModule): + def __init__(self, + num_layers : int) -> None: + super().__init__() + + self.transformers = bmt.TransformerBlockList([ + bmt.Block( + TransformerEncoder( + dim_model, dim_head, num_heads, dim_ff, bias, dtype + ), use_checkpoint=True, zero_level=3 + ) + for _ in range(num_layers) + ]) + + def forward(self): + # return self.transformers(x) v0.2.3 can only forward in this way + for block in self.transformers: + x = block(x) + return x + +``` + +You can specify the zero level of each `bmt.CheckpointBlock` (alias of `bmt.Block`) and computation checkpointing can be disabled by setting `use_checkpoint=False` . For `bmt.TransformerBlockList`, it can be called in a loop way. + + +### 2. Add Bf16 support + +Now BMTrain supports Bf16 training. You can simply use `dtype=torch.bfloat16' in your model construction method and BMTrain will handle the rest. + +### 3. Tensor parallel implementation + +For this part, BMTrain only provides a series of parallel ops for Tensor parallel implementation, including `bmt.nn.OpParallelLinear` and `bmt.nn.VPEmbedding` . We also provide a Tensor Parallel training example in our training example. You can simply use `bmt.init_distributed(tp_size=4)` to enable a 4-way tensor parallel training. + +### 4. `AdamOffloadOptimizer` can save whole gathered state + +Now `AdamOffloadOptimizer` can save whole gathered state. This feature can help users to save the whole gathered state of the optimizer, which can be used to resume training from the saved state. For better performance, we provide async-way save state_dict to overlap I/O and computation. +```python +import bmtrain as bmt +# you can enbale this feature in two ways: Optimmanager's or optimizer's interface +global_ckpt = bmt.optim.Optimmanager.state_dict(gather_opt=True) +global_ckpt = optimizer.state_dict(gather=True) +``` +### Others + +* New test for new version BMTrain \ No newline at end of file diff --git a/examples/BMTrain/docs/logo.png b/examples/BMTrain/docs/logo.png new file mode 100644 index 00000000..2dd2f1cb Binary files /dev/null and b/examples/BMTrain/docs/logo.png differ diff --git a/examples/BMTrain/docs/make.bat b/examples/BMTrain/docs/make.bat new file mode 100644 index 00000000..6fcf05b4 --- /dev/null +++ b/examples/BMTrain/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/examples/BMTrain/docs/source-en/_static/css/custom.css b/examples/BMTrain/docs/source-en/_static/css/custom.css new file mode 100644 index 00000000..1e3a643c --- /dev/null +++ b/examples/BMTrain/docs/source-en/_static/css/custom.css @@ -0,0 +1,124 @@ +a, +.wy-menu-vertical header, +.wy-menu-vertical p.caption, +.wy-nav-top .fa-bars, +.wy-menu-vertical a:hover, + +.rst-content code.literal, .rst-content tt.literal + +{ + color: #315EFE !important; +} + +/* inspired by sphinx press theme */ +.wy-menu.wy-menu-vertical li.toctree-l1.current > a { + border-left: solid 8px #315EFE !important; + border-top: none; + border-bottom: none; +} + +.wy-menu.wy-menu-vertical li.toctree-l1.current > ul { + border-left: solid 8px #315EFE !important; +} +/* inspired by sphinx press theme */ + +.wy-nav-side { + color: unset !important; + background: unset !important; + border-right: solid 1px #ccc !important; +} + +.wy-side-nav-search, +.wy-nav-top, +.wy-menu-vertical li, +.wy-menu-vertical li a:hover, +.wy-menu-vertical li a +{ + background: unset !important; +} + +.wy-menu-vertical li.current a { + border-right: unset !important; +} + +.wy-side-nav-search div, +.wy-menu-vertical a { + color: #404040 !important; +} + +.wy-menu-vertical button.toctree-expand { + color: #333 !important; +} + +.wy-nav-content { + max-width: unset; +} + +.rst-content { + max-width: 900px; +} + +.wy-nav-content .icon-home:before { + content: "Docs"; +} + +.wy-side-nav-search .icon-home:before { + content: ""; +} + +dl.field-list { + display: block !important; +} + +dl.field-list > dt:after { + content: "" !important; +} + +:root { + --dark-blue: #3260F7; + --light-blue: rgba(194, 233, 248, 0.1) ; +} + +dl.field-list > dt { + display: table; + padding-left: 6px !important; + padding-right: 6px !important; + margin-bottom: 4px !important; + padding-bottom: 1px !important; + background: var(--light-blue); + border-left: solid 2px var(--dark-blue); +} + + +dl.py.class>dt +{ + color: rgba(17, 16, 17, 0.822) !important; + background: var(--light-blue) !important; + border-top: solid 2px var(--dark-blue) !important; +} + +dl.py.method>dt +{ + background: var(--light-blue) !important; + border-left: solid 2px var(--dark-blue) !important; +} + +dl.py.attribute>dt, +dl.py.property>dt +{ + background: var(--light-blue) !important; + border-left: solid 2px var(--dark-blue) !important; +} + +.fa-plus-square-o::before, .wy-menu-vertical li button.toctree-expand::before, +.fa-minus-square-o::before, .wy-menu-vertical li.current > a button.toctree-expand::before, .wy-menu-vertical li.on a button.toctree-expand::before +{ + content: ""; +} + +.rst-content .viewcode-back, +.rst-content .viewcode-link +{ + color:#58b5cc; + font-size: 120%; +} \ No newline at end of file diff --git a/examples/BMTrain/docs/source-en/_static/js/custom.js b/examples/BMTrain/docs/source-en/_static/js/custom.js new file mode 100644 index 00000000..489b7d5c --- /dev/null +++ b/examples/BMTrain/docs/source-en/_static/js/custom.js @@ -0,0 +1,7 @@ +document.addEventListener("DOMContentLoaded", function(event) { + document.querySelectorAll(".wy-menu.wy-menu-vertical > ul.current > li > a").forEach(a => a.addEventListener("click", e=>{ + f = document.querySelector(".wy-menu.wy-menu-vertical > ul.current > li > ul") + if (f.style.display=='none') { f.style.display='block'; } else f.style.display = 'none' + })); + document.querySelectorAll(".headerlink").forEach(a => a.text="\u{1F517}"); +}); \ No newline at end of file diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.benchmark.rst_bk b/examples/BMTrain/docs/source-en/api/bmtrain.benchmark.rst_bk new file mode 100644 index 00000000..f8b2902d --- /dev/null +++ b/examples/BMTrain/docs/source-en/api/bmtrain.benchmark.rst_bk @@ -0,0 +1,53 @@ +bmtrain.benchmark package +========================= + +Submodules +---------- + +bmtrain.benchmark.all\_gather module +------------------------------------ + +.. automodule:: bmtrain.benchmark.all_gather + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.reduce\_scatter module +---------------------------------------- + +.. automodule:: bmtrain.benchmark.reduce_scatter + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.send\_recv module +----------------------------------- + +.. automodule:: bmtrain.benchmark.send_recv + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.shape module +------------------------------ + +.. automodule:: bmtrain.benchmark.shape + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.utils module +------------------------------ + +.. automodule:: bmtrain.benchmark.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.benchmark + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.distributed.rst_bk b/examples/BMTrain/docs/source-en/api/bmtrain.distributed.rst_bk new file mode 100644 index 00000000..ef41db07 --- /dev/null +++ b/examples/BMTrain/docs/source-en/api/bmtrain.distributed.rst_bk @@ -0,0 +1,21 @@ +bmtrain.distributed package +=========================== + +Submodules +---------- + +bmtrain.distributed.ops module +------------------------------ + +.. automodule:: bmtrain.distributed.ops + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.distributed + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.inspect.rst b/examples/BMTrain/docs/source-en/api/bmtrain.inspect.rst new file mode 100644 index 00000000..c57195ad --- /dev/null +++ b/examples/BMTrain/docs/source-en/api/bmtrain.inspect.rst @@ -0,0 +1,37 @@ +bmtrain.inspect package +======================= + +Submodules +---------- + +bmtrain.inspect.format module +----------------------------- + +.. automodule:: bmtrain.inspect.format + :members: format_summary + :undoc-members: + :show-inheritance: + +bmtrain.inspect.model module +---------------------------- + +.. automodule:: bmtrain.inspect.model + :members: inspect_model + :undoc-members: + :show-inheritance: + +bmtrain.inspect.tensor module +----------------------------- + +.. automodule:: bmtrain.inspect.tensor + :members: inspect_tensor, InspectTensor + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.inspect + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.loss.rst b/examples/BMTrain/docs/source-en/api/bmtrain.loss.rst new file mode 100644 index 00000000..03b65646 --- /dev/null +++ b/examples/BMTrain/docs/source-en/api/bmtrain.loss.rst @@ -0,0 +1,21 @@ +bmtrain.loss package +==================== + +Submodules +---------- + +bmtrain.loss.cross\_entropy module +---------------------------------- + +.. automodule:: bmtrain.loss.cross_entropy + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.loss + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.lr_scheduler.rst b/examples/BMTrain/docs/source-en/api/bmtrain.lr_scheduler.rst new file mode 100644 index 00000000..0ba033af --- /dev/null +++ b/examples/BMTrain/docs/source-en/api/bmtrain.lr_scheduler.rst @@ -0,0 +1,61 @@ +bmtrain.lr\_scheduler package +============================= + +Submodules +---------- + +bmtrain.lr\_scheduler.cosine module +----------------------------------- + +.. automodule:: bmtrain.lr_scheduler.cosine + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.exponential module +---------------------------------------- + +.. automodule:: bmtrain.lr_scheduler.exponential + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.linear module +----------------------------------- + +.. automodule:: bmtrain.lr_scheduler.linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.no\_decay module +-------------------------------------- + +.. automodule:: bmtrain.lr_scheduler.no_decay + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.noam module +--------------------------------- + +.. automodule:: bmtrain.lr_scheduler.noam + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.warmup module +----------------------------------- + +.. automodule:: bmtrain.lr_scheduler.warmup + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.lr_scheduler + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.nccl.rst_bk b/examples/BMTrain/docs/source-en/api/bmtrain.nccl.rst_bk new file mode 100644 index 00000000..3755d9ef --- /dev/null +++ b/examples/BMTrain/docs/source-en/api/bmtrain.nccl.rst_bk @@ -0,0 +1,21 @@ +bmtrain.nccl package +==================== + +Submodules +---------- + +bmtrain.nccl.enums module +------------------------- + +.. automodule:: bmtrain.nccl.enums + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.nccl + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.nn.rst b/examples/BMTrain/docs/source-en/api/bmtrain.nn.rst new file mode 100644 index 00000000..8e2a531f --- /dev/null +++ b/examples/BMTrain/docs/source-en/api/bmtrain.nn.rst @@ -0,0 +1,53 @@ +bmtrain.nn package +================== + +Submodules +---------- + +bmtrain.nn.column\_parallel\_linear module +------------------------------------------ + +.. automodule:: bmtrain.nn.column_parallel_linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.linear module +------------------------ + +.. automodule:: bmtrain.nn.linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.parallel\_embedding module +------------------------------------- + +.. automodule:: bmtrain.nn.parallel_embedding + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.parallel\_linear\_func module +---------------------------------------- + +.. automodule:: bmtrain.nn.parallel_linear_func + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.row\_parallel\_linear module +--------------------------------------- + +.. automodule:: bmtrain.nn.row_parallel_linear + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.nn + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.optim.rst b/examples/BMTrain/docs/source-en/api/bmtrain.optim.rst new file mode 100644 index 00000000..2d47a3dd --- /dev/null +++ b/examples/BMTrain/docs/source-en/api/bmtrain.optim.rst @@ -0,0 +1,37 @@ +bmtrain.optim package +===================== + +Submodules +---------- + +bmtrain.optim.adam module +------------------------- + +.. automodule:: bmtrain.optim.adam + :members: + :undoc-members: + :show-inheritance: + +bmtrain.optim.adam\_offload module +---------------------------------- + +.. automodule:: bmtrain.optim.adam_offload + :members: + :undoc-members: + :show-inheritance: + +bmtrain.optim.optim\_manager module +----------------------------------- + +.. automodule:: bmtrain.optim.optim_manager + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.optim + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.rst b/examples/BMTrain/docs/source-en/api/bmtrain.rst new file mode 100644 index 00000000..8445e5f0 --- /dev/null +++ b/examples/BMTrain/docs/source-en/api/bmtrain.rst @@ -0,0 +1,140 @@ +bmtrain package +=============== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + bmtrain.benchmark + bmtrain.distributed + bmtrain.inspect + bmtrain.loss + bmtrain.lr_scheduler + bmtrain.nccl + bmtrain.nn + bmtrain.optim + +Submodules +---------- + +bmtrain.block\_layer module +--------------------------- + +.. automodule:: bmtrain.block_layer + :members: + :undoc-members: + :show-inheritance: + +.. bmtrain.debug module +.. -------------------- + +.. .. automodule:: bmtrain.debug +.. :members: +.. :undoc-members: +.. :show-inheritance: + +bmtrain.global\_var module +-------------------------- + +.. automodule:: bmtrain.global_var + :members: + :undoc-members: + :show-inheritance: + +.. bmtrain.hook\_func module +.. ------------------------- + +.. .. automodule:: bmtrain.hook_func +.. :members: +.. :undoc-members: +.. :show-inheritance: + +bmtrain.init module +------------------- + +.. automodule:: bmtrain.init + :members: + :undoc-members: + :show-inheritance: + +bmtrain.layer module +-------------------- + +.. automodule:: bmtrain.layer + :members: + :undoc-members: + :show-inheritance: + +bmtrain.param\_init module +-------------------------- + +.. automodule:: bmtrain.param_init + :members: + :undoc-members: + :show-inheritance: + +bmtrain.parameter module +------------------------ + +.. automodule:: bmtrain.parameter + :members: DistributedParameter, ParameterInitializer + :undoc-members: + :show-inheritance: + +bmtrain.pipe\_layer module +-------------------------- + +.. automodule:: bmtrain.pipe_layer + :members: PipelineTransformerBlockList + :undoc-members: + :show-inheritance: + +bmtrain.store module +-------------------- + +.. automodule:: bmtrain.store + :members: save, load + :undoc-members: + :show-inheritance: + +bmtrain.synchronize module +-------------------------- + +.. automodule:: bmtrain.synchronize + :members: + :undoc-members: + :show-inheritance: + +bmtrain.utils module +-------------------- + +.. automodule:: bmtrain.utils + :members: + :undoc-members: + :show-inheritance: + +bmtrain.wrapper module +---------------------- + +.. automodule:: bmtrain.wrapper + :members: BMTrainModelWrapper + :undoc-members: + :show-inheritance: + +bmtrain.zero\_context module +---------------------------- + +.. automodule:: bmtrain.zero_context + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source-en/api/modules.rst b/examples/BMTrain/docs/source-en/api/modules.rst new file mode 100644 index 00000000..4350b5d7 --- /dev/null +++ b/examples/BMTrain/docs/source-en/api/modules.rst @@ -0,0 +1,7 @@ +bmtrain +======= + +.. toctree:: + :maxdepth: 4 + + bmtrain diff --git a/examples/BMTrain/docs/source-en/conf.py b/examples/BMTrain/docs/source-en/conf.py new file mode 100644 index 00000000..6351767a --- /dev/null +++ b/examples/BMTrain/docs/source-en/conf.py @@ -0,0 +1,79 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +sys.path.insert(0, os.path.abspath("../../..")) + +import recommonmark +from recommonmark.transform import AutoStructify + + +# -- Project information ----------------------------------------------------- + +project = "BMTrain" +copyright = "2022, OpenBMB" +author = "BMTrain Team" +autodoc_mock_imports = [ + "numpy", + "tensorboard", + "bmtrain.nccl._C", + "bmtrain.optim._cpu", + "bmtrain.optim._cuda", + "bmtrain.loss._cuda", +] +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.napoleon", + "sphinx.ext.mathjax", + "recommonmark", + "sphinx_markdown_tables", +] + +source_suffix = [".rst", ".md"] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = "en" + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "sphinx_rtd_theme" + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ["_static"] +# html_stype="css/custom.css" +html_css_files = ["css/custom.css"] +html_js_files = ["js/custom.js"] diff --git a/examples/BMTrain/docs/source-en/index.rst b/examples/BMTrain/docs/source-en/index.rst new file mode 100644 index 00000000..a2ec24f9 --- /dev/null +++ b/examples/BMTrain/docs/source-en/index.rst @@ -0,0 +1,39 @@ +.. bmtrain-doc documentation master file, created by + sphinx-quickstart on Sat Mar 5 17:05:02 2022. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +BMTrain 文档 +=============================== + +**BMTrain** 是一个高效的大模型训练工具包,可以用于训练数百亿参数的大模型。BMTrain 可以在分布式训练模型的同时,能够保持代码的简洁性。 + +======================================= + +.. toctree:: + :maxdepth: 2 + :caption: Getting Started + + notes/installation.md + notes/quickstart.md + notes/tech.md + +.. toctree:: + :maxdepth: 2 + :caption: Package Reference + + api/bmtrain.rst + api/bmtrain.benchmark.rst + api/bmtrain.distributed.rst + api/bmtrain.inspect.rst + api/bmtrain.loss.rst + api/bmtrain.lr_scheduler.rst + api/bmtrain.nccl.rst + api/bmtrain.optim.rst + +API +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/examples/BMTrain/docs/source-en/notes/image/ZeRO3.png b/examples/BMTrain/docs/source-en/notes/image/ZeRO3.png new file mode 100644 index 00000000..38d52ed5 Binary files /dev/null and b/examples/BMTrain/docs/source-en/notes/image/ZeRO3.png differ diff --git a/examples/BMTrain/docs/source-en/notes/image/communication_example.png b/examples/BMTrain/docs/source-en/notes/image/communication_example.png new file mode 100644 index 00000000..0e549b08 Binary files /dev/null and b/examples/BMTrain/docs/source-en/notes/image/communication_example.png differ diff --git a/examples/BMTrain/docs/source-en/notes/image/communication_fig.png b/examples/BMTrain/docs/source-en/notes/image/communication_fig.png new file mode 100644 index 00000000..b7636240 Binary files /dev/null and b/examples/BMTrain/docs/source-en/notes/image/communication_fig.png differ diff --git a/examples/BMTrain/docs/source-en/notes/image/cpu.png b/examples/BMTrain/docs/source-en/notes/image/cpu.png new file mode 100644 index 00000000..451f397c Binary files /dev/null and b/examples/BMTrain/docs/source-en/notes/image/cpu.png differ diff --git a/examples/BMTrain/docs/source-en/notes/image/zero3_example.png b/examples/BMTrain/docs/source-en/notes/image/zero3_example.png new file mode 100644 index 00000000..40f0e8cc Binary files /dev/null and b/examples/BMTrain/docs/source-en/notes/image/zero3_example.png differ diff --git a/examples/BMTrain/docs/source-en/notes/installation.md b/examples/BMTrain/docs/source-en/notes/installation.md new file mode 100644 index 00000000..330bfb21 --- /dev/null +++ b/examples/BMTrain/docs/source-en/notes/installation.md @@ -0,0 +1,45 @@ +# Installation + +## Install BMTrain + +### 1. From PyPI (Recommend) + +```shell +$ pip install bmtrain +``` + +### 2. From Source + +```shell +$ git clone https://github.com/OpenBMB/BMTrain.git +$ cd BMTrain +$ python3 setup.py install +``` + +## Compilation Options + +By setting environment variables, you can configure the compilation options of BMTrain (by default, the compilation environment will be automatically adapted). + +### AVX Instructions + +* Force the use of AVX instructions: `BMT_AVX256=ON` +* Force the use of AVX512 instructions: `BMT_AVX512=ON` + +### CUDA Compute Capability + +`TORCH_CUDA_ARCH_LIST=6.0 6.1 7.0 7.5 8.0+PTX` + +## Recommended Configuration + +* Network:Infiniband 100Gbps / RoCE 100Gbps +* GPU:NVIDIA Tesla V100 / NVIDIA Tesla A100 / RTX 3090 +* CPU:CPU that supports AVX512 instructions, 32 cores or above +* RAM:256GB or above + +## FAQ + +If the following error message is reported during compilation, try using a newer version of the gcc compiler. +``` +error: invalid static_cast from type `const torch::OrderdDict<...>` +``` + diff --git a/examples/BMTrain/docs/source-en/notes/quickstart.md b/examples/BMTrain/docs/source-en/notes/quickstart.md new file mode 100644 index 00000000..3ede3d13 --- /dev/null +++ b/examples/BMTrain/docs/source-en/notes/quickstart.md @@ -0,0 +1,159 @@ +# Quick Start + +## Step 1: Initialize BMTrain + +Before you can use BMTrain, you need to initialize it at the beginning of your code. Just like using the distributed module of PyTorch requires the use of **init_process_group** at the beginning of the code, using BMTrain requires the use of **init_distributed** at the beginning of the code. + +```python +import bmtrain as bmt +bmt.init_distributed( + seed=0, + # ... +) +``` + +**NOTE:** Do not use PyTorch's distributed module and its associated communication functions when using BMTrain. + +## Step 2: Enable ZeRO-3 Optimization + +To enable ZeRO-3 optimization, you need to make some simple replacements to the original model's code. + +* `torch.nn.Module` -> `bmtrain.DistributedModule` +* `torch.nn.Parameter` -> `bmtrain.DistributedParameter` + +And wrap the transformer blocks with `bmtrain.CheckpointBlock`. + +Here is an example. + +**Original** + +```python +import torch +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.empty(1024)) + self.module_list = torch.nn.ModuleList([ + SomeTransformerBlock(), + SomeTransformerBlock(), + SomeTransformerBlock() + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +**Replaced** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): # changed here + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) # changed here + self.module_list = torch.nn.ModuleList([ + bmt.CheckpointBlock(SomeTransformerBlock()), # changed here + bmt.CheckpointBlock(SomeTransformerBlock()), # changed here + bmt.CheckpointBlock(SomeTransformerBlock()) # changed here + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +## Step 3: Enable Communication Optimization + + +To further reduce the extra overhead of communication and overlap communication with computing time, `TransformerBlockList` can be used for optimization. + +You can enable them by making the following substitutions to the code: + +* `torch.nn.ModuleList` -> `bmtrain.TransformerBlockList` +* `for module in self.module_list: x = module(x, ...)` -> `x = self.module_list(x, ...)` + +**Original** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) + self.module_list = torch.nn.ModuleList([ + bmt.CheckpointBlock(SomeTransformerBlock()), + bmt.CheckpointBlock(SomeTransformerBlock()), + bmt.CheckpointBlock(SomeTransformerBlock()) + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +**Replaced** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) + self.module_list = bmt.TransformerBlockList([ # changed here + bmt.CheckpointBlock(SomeTransformerBlock()), + bmt.CheckpointBlock(SomeTransformerBlock()), + bmt.CheckpointBlock(SomeTransformerBlock()) + ]) + + def forward(self): + x = self.param + x = self.module_list(x, 1, 2, 3) # changed here + return x + +``` + +## Step 4: Launch Distributed Training + +BMTrain uses the same launch command as the distributed module of PyTorch. + +You can choose one of them depending on your version of PyTorch. + +* `${MASTER_ADDR}` means the IP address of the master node. +* `${MASTER_PORT}` means the port of the master node. +* `${NNODES}` means the total number of nodes. +* `${GPU_PER_NODE}` means the number of GPUs per node. +* `${NODE_RANK}` means the rank of this node. + +### torch.distributed.launch +```shell +$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py +``` + +### torchrun + +```shell +$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py +``` + + +For more information, please refer to the [documentation](https://pytorch.org/docs/stable/distributed.html#launch-utility). + +## Other Notes + +`BMTrain` makes underlying changes to PyTorch, so if your program outputs unexpected results, you can submit information about it in an issue. + +For more examples, please refer to the *examples* folder. + diff --git a/examples/BMTrain/docs/source-en/notes/tech.md b/examples/BMTrain/docs/source-en/notes/tech.md new file mode 100644 index 00000000..fc6c704a --- /dev/null +++ b/examples/BMTrain/docs/source-en/notes/tech.md @@ -0,0 +1,11 @@ +# Introduction to Core Technology + +## ZeRO-3 Optimization +![](image/ZeRO3.png) + +## Overlap Communication and Computation +![](image/communication_fig.png) + +## CPU Offload +![](image/cpu.png) + diff --git a/examples/BMTrain/docs/source/_static/css/custom.css b/examples/BMTrain/docs/source/_static/css/custom.css new file mode 100644 index 00000000..1e3a643c --- /dev/null +++ b/examples/BMTrain/docs/source/_static/css/custom.css @@ -0,0 +1,124 @@ +a, +.wy-menu-vertical header, +.wy-menu-vertical p.caption, +.wy-nav-top .fa-bars, +.wy-menu-vertical a:hover, + +.rst-content code.literal, .rst-content tt.literal + +{ + color: #315EFE !important; +} + +/* inspired by sphinx press theme */ +.wy-menu.wy-menu-vertical li.toctree-l1.current > a { + border-left: solid 8px #315EFE !important; + border-top: none; + border-bottom: none; +} + +.wy-menu.wy-menu-vertical li.toctree-l1.current > ul { + border-left: solid 8px #315EFE !important; +} +/* inspired by sphinx press theme */ + +.wy-nav-side { + color: unset !important; + background: unset !important; + border-right: solid 1px #ccc !important; +} + +.wy-side-nav-search, +.wy-nav-top, +.wy-menu-vertical li, +.wy-menu-vertical li a:hover, +.wy-menu-vertical li a +{ + background: unset !important; +} + +.wy-menu-vertical li.current a { + border-right: unset !important; +} + +.wy-side-nav-search div, +.wy-menu-vertical a { + color: #404040 !important; +} + +.wy-menu-vertical button.toctree-expand { + color: #333 !important; +} + +.wy-nav-content { + max-width: unset; +} + +.rst-content { + max-width: 900px; +} + +.wy-nav-content .icon-home:before { + content: "Docs"; +} + +.wy-side-nav-search .icon-home:before { + content: ""; +} + +dl.field-list { + display: block !important; +} + +dl.field-list > dt:after { + content: "" !important; +} + +:root { + --dark-blue: #3260F7; + --light-blue: rgba(194, 233, 248, 0.1) ; +} + +dl.field-list > dt { + display: table; + padding-left: 6px !important; + padding-right: 6px !important; + margin-bottom: 4px !important; + padding-bottom: 1px !important; + background: var(--light-blue); + border-left: solid 2px var(--dark-blue); +} + + +dl.py.class>dt +{ + color: rgba(17, 16, 17, 0.822) !important; + background: var(--light-blue) !important; + border-top: solid 2px var(--dark-blue) !important; +} + +dl.py.method>dt +{ + background: var(--light-blue) !important; + border-left: solid 2px var(--dark-blue) !important; +} + +dl.py.attribute>dt, +dl.py.property>dt +{ + background: var(--light-blue) !important; + border-left: solid 2px var(--dark-blue) !important; +} + +.fa-plus-square-o::before, .wy-menu-vertical li button.toctree-expand::before, +.fa-minus-square-o::before, .wy-menu-vertical li.current > a button.toctree-expand::before, .wy-menu-vertical li.on a button.toctree-expand::before +{ + content: ""; +} + +.rst-content .viewcode-back, +.rst-content .viewcode-link +{ + color:#58b5cc; + font-size: 120%; +} \ No newline at end of file diff --git a/examples/BMTrain/docs/source/_static/js/custom.js b/examples/BMTrain/docs/source/_static/js/custom.js new file mode 100644 index 00000000..489b7d5c --- /dev/null +++ b/examples/BMTrain/docs/source/_static/js/custom.js @@ -0,0 +1,7 @@ +document.addEventListener("DOMContentLoaded", function(event) { + document.querySelectorAll(".wy-menu.wy-menu-vertical > ul.current > li > a").forEach(a => a.addEventListener("click", e=>{ + f = document.querySelector(".wy-menu.wy-menu-vertical > ul.current > li > ul") + if (f.style.display=='none') { f.style.display='block'; } else f.style.display = 'none' + })); + document.querySelectorAll(".headerlink").forEach(a => a.text="\u{1F517}"); +}); \ No newline at end of file diff --git a/examples/BMTrain/docs/source/api/bmtrain.benchmark.rst_bk b/examples/BMTrain/docs/source/api/bmtrain.benchmark.rst_bk new file mode 100644 index 00000000..f8b2902d --- /dev/null +++ b/examples/BMTrain/docs/source/api/bmtrain.benchmark.rst_bk @@ -0,0 +1,53 @@ +bmtrain.benchmark package +========================= + +Submodules +---------- + +bmtrain.benchmark.all\_gather module +------------------------------------ + +.. automodule:: bmtrain.benchmark.all_gather + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.reduce\_scatter module +---------------------------------------- + +.. automodule:: bmtrain.benchmark.reduce_scatter + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.send\_recv module +----------------------------------- + +.. automodule:: bmtrain.benchmark.send_recv + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.shape module +------------------------------ + +.. automodule:: bmtrain.benchmark.shape + :members: + :undoc-members: + :show-inheritance: + +bmtrain.benchmark.utils module +------------------------------ + +.. automodule:: bmtrain.benchmark.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.benchmark + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source/api/bmtrain.distributed.rst_bk b/examples/BMTrain/docs/source/api/bmtrain.distributed.rst_bk new file mode 100644 index 00000000..ef41db07 --- /dev/null +++ b/examples/BMTrain/docs/source/api/bmtrain.distributed.rst_bk @@ -0,0 +1,21 @@ +bmtrain.distributed package +=========================== + +Submodules +---------- + +bmtrain.distributed.ops module +------------------------------ + +.. automodule:: bmtrain.distributed.ops + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.distributed + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source/api/bmtrain.inspect.rst b/examples/BMTrain/docs/source/api/bmtrain.inspect.rst new file mode 100644 index 00000000..c57195ad --- /dev/null +++ b/examples/BMTrain/docs/source/api/bmtrain.inspect.rst @@ -0,0 +1,37 @@ +bmtrain.inspect package +======================= + +Submodules +---------- + +bmtrain.inspect.format module +----------------------------- + +.. automodule:: bmtrain.inspect.format + :members: format_summary + :undoc-members: + :show-inheritance: + +bmtrain.inspect.model module +---------------------------- + +.. automodule:: bmtrain.inspect.model + :members: inspect_model + :undoc-members: + :show-inheritance: + +bmtrain.inspect.tensor module +----------------------------- + +.. automodule:: bmtrain.inspect.tensor + :members: inspect_tensor, InspectTensor + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.inspect + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source/api/bmtrain.loss.rst b/examples/BMTrain/docs/source/api/bmtrain.loss.rst new file mode 100644 index 00000000..03b65646 --- /dev/null +++ b/examples/BMTrain/docs/source/api/bmtrain.loss.rst @@ -0,0 +1,21 @@ +bmtrain.loss package +==================== + +Submodules +---------- + +bmtrain.loss.cross\_entropy module +---------------------------------- + +.. automodule:: bmtrain.loss.cross_entropy + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.loss + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source/api/bmtrain.lr_scheduler.rst b/examples/BMTrain/docs/source/api/bmtrain.lr_scheduler.rst new file mode 100644 index 00000000..0ba033af --- /dev/null +++ b/examples/BMTrain/docs/source/api/bmtrain.lr_scheduler.rst @@ -0,0 +1,61 @@ +bmtrain.lr\_scheduler package +============================= + +Submodules +---------- + +bmtrain.lr\_scheduler.cosine module +----------------------------------- + +.. automodule:: bmtrain.lr_scheduler.cosine + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.exponential module +---------------------------------------- + +.. automodule:: bmtrain.lr_scheduler.exponential + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.linear module +----------------------------------- + +.. automodule:: bmtrain.lr_scheduler.linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.no\_decay module +-------------------------------------- + +.. automodule:: bmtrain.lr_scheduler.no_decay + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.noam module +--------------------------------- + +.. automodule:: bmtrain.lr_scheduler.noam + :members: + :undoc-members: + :show-inheritance: + +bmtrain.lr\_scheduler.warmup module +----------------------------------- + +.. automodule:: bmtrain.lr_scheduler.warmup + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.lr_scheduler + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source/api/bmtrain.nccl.rst_bk b/examples/BMTrain/docs/source/api/bmtrain.nccl.rst_bk new file mode 100644 index 00000000..3755d9ef --- /dev/null +++ b/examples/BMTrain/docs/source/api/bmtrain.nccl.rst_bk @@ -0,0 +1,21 @@ +bmtrain.nccl package +==================== + +Submodules +---------- + +bmtrain.nccl.enums module +------------------------- + +.. automodule:: bmtrain.nccl.enums + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.nccl + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source/api/bmtrain.nn.rst b/examples/BMTrain/docs/source/api/bmtrain.nn.rst new file mode 100644 index 00000000..8e2a531f --- /dev/null +++ b/examples/BMTrain/docs/source/api/bmtrain.nn.rst @@ -0,0 +1,53 @@ +bmtrain.nn package +================== + +Submodules +---------- + +bmtrain.nn.column\_parallel\_linear module +------------------------------------------ + +.. automodule:: bmtrain.nn.column_parallel_linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.linear module +------------------------ + +.. automodule:: bmtrain.nn.linear + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.parallel\_embedding module +------------------------------------- + +.. automodule:: bmtrain.nn.parallel_embedding + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.parallel\_linear\_func module +---------------------------------------- + +.. automodule:: bmtrain.nn.parallel_linear_func + :members: + :undoc-members: + :show-inheritance: + +bmtrain.nn.row\_parallel\_linear module +--------------------------------------- + +.. automodule:: bmtrain.nn.row_parallel_linear + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.nn + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source/api/bmtrain.optim.rst b/examples/BMTrain/docs/source/api/bmtrain.optim.rst new file mode 100644 index 00000000..2d47a3dd --- /dev/null +++ b/examples/BMTrain/docs/source/api/bmtrain.optim.rst @@ -0,0 +1,37 @@ +bmtrain.optim package +===================== + +Submodules +---------- + +bmtrain.optim.adam module +------------------------- + +.. automodule:: bmtrain.optim.adam + :members: + :undoc-members: + :show-inheritance: + +bmtrain.optim.adam\_offload module +---------------------------------- + +.. automodule:: bmtrain.optim.adam_offload + :members: + :undoc-members: + :show-inheritance: + +bmtrain.optim.optim\_manager module +----------------------------------- + +.. automodule:: bmtrain.optim.optim_manager + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain.optim + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source/api/bmtrain.rst b/examples/BMTrain/docs/source/api/bmtrain.rst new file mode 100644 index 00000000..8445e5f0 --- /dev/null +++ b/examples/BMTrain/docs/source/api/bmtrain.rst @@ -0,0 +1,140 @@ +bmtrain package +=============== + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + bmtrain.benchmark + bmtrain.distributed + bmtrain.inspect + bmtrain.loss + bmtrain.lr_scheduler + bmtrain.nccl + bmtrain.nn + bmtrain.optim + +Submodules +---------- + +bmtrain.block\_layer module +--------------------------- + +.. automodule:: bmtrain.block_layer + :members: + :undoc-members: + :show-inheritance: + +.. bmtrain.debug module +.. -------------------- + +.. .. automodule:: bmtrain.debug +.. :members: +.. :undoc-members: +.. :show-inheritance: + +bmtrain.global\_var module +-------------------------- + +.. automodule:: bmtrain.global_var + :members: + :undoc-members: + :show-inheritance: + +.. bmtrain.hook\_func module +.. ------------------------- + +.. .. automodule:: bmtrain.hook_func +.. :members: +.. :undoc-members: +.. :show-inheritance: + +bmtrain.init module +------------------- + +.. automodule:: bmtrain.init + :members: + :undoc-members: + :show-inheritance: + +bmtrain.layer module +-------------------- + +.. automodule:: bmtrain.layer + :members: + :undoc-members: + :show-inheritance: + +bmtrain.param\_init module +-------------------------- + +.. automodule:: bmtrain.param_init + :members: + :undoc-members: + :show-inheritance: + +bmtrain.parameter module +------------------------ + +.. automodule:: bmtrain.parameter + :members: DistributedParameter, ParameterInitializer + :undoc-members: + :show-inheritance: + +bmtrain.pipe\_layer module +-------------------------- + +.. automodule:: bmtrain.pipe_layer + :members: PipelineTransformerBlockList + :undoc-members: + :show-inheritance: + +bmtrain.store module +-------------------- + +.. automodule:: bmtrain.store + :members: save, load + :undoc-members: + :show-inheritance: + +bmtrain.synchronize module +-------------------------- + +.. automodule:: bmtrain.synchronize + :members: + :undoc-members: + :show-inheritance: + +bmtrain.utils module +-------------------- + +.. automodule:: bmtrain.utils + :members: + :undoc-members: + :show-inheritance: + +bmtrain.wrapper module +---------------------- + +.. automodule:: bmtrain.wrapper + :members: BMTrainModelWrapper + :undoc-members: + :show-inheritance: + +bmtrain.zero\_context module +---------------------------- + +.. automodule:: bmtrain.zero_context + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- + +.. automodule:: bmtrain + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/BMTrain/docs/source/api/modules.rst b/examples/BMTrain/docs/source/api/modules.rst new file mode 100644 index 00000000..4350b5d7 --- /dev/null +++ b/examples/BMTrain/docs/source/api/modules.rst @@ -0,0 +1,7 @@ +bmtrain +======= + +.. toctree:: + :maxdepth: 4 + + bmtrain diff --git a/examples/BMTrain/docs/source/conf.py b/examples/BMTrain/docs/source/conf.py new file mode 100644 index 00000000..066680f7 --- /dev/null +++ b/examples/BMTrain/docs/source/conf.py @@ -0,0 +1,72 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys +sys.path.insert(0, os.path.abspath('../../..')) + +import recommonmark +from recommonmark.transform import AutoStructify + + + +# -- Project information ----------------------------------------------------- + +project = 'BMTrain' +copyright = '2022, OpenBMB' +author = 'BMTrain Team' +autodoc_mock_imports = ["numpy", "tensorboard", "bmtrain.nccl._C", "bmtrain.optim._cpu", "bmtrain.optim._cuda", "bmtrain.loss._cuda"] +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.napoleon', + 'sphinx.ext.mathjax', + 'recommonmark', + 'sphinx_markdown_tables', +] + +source_suffix = ['.rst', '.md'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = 'zh_CN' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] +#html_stype="css/custom.css" +html_css_files=['css/custom.css' ] +html_js_files= ['js/custom.js' ] diff --git a/examples/BMTrain/docs/source/index.rst b/examples/BMTrain/docs/source/index.rst new file mode 100644 index 00000000..a2ec24f9 --- /dev/null +++ b/examples/BMTrain/docs/source/index.rst @@ -0,0 +1,39 @@ +.. bmtrain-doc documentation master file, created by + sphinx-quickstart on Sat Mar 5 17:05:02 2022. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +BMTrain 文档 +=============================== + +**BMTrain** 是一个高效的大模型训练工具包,可以用于训练数百亿参数的大模型。BMTrain 可以在分布式训练模型的同时,能够保持代码的简洁性。 + +======================================= + +.. toctree:: + :maxdepth: 2 + :caption: Getting Started + + notes/installation.md + notes/quickstart.md + notes/tech.md + +.. toctree:: + :maxdepth: 2 + :caption: Package Reference + + api/bmtrain.rst + api/bmtrain.benchmark.rst + api/bmtrain.distributed.rst + api/bmtrain.inspect.rst + api/bmtrain.loss.rst + api/bmtrain.lr_scheduler.rst + api/bmtrain.nccl.rst + api/bmtrain.optim.rst + +API +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/examples/BMTrain/docs/source/notes/image/ZeRO3.png b/examples/BMTrain/docs/source/notes/image/ZeRO3.png new file mode 100644 index 00000000..38d52ed5 Binary files /dev/null and b/examples/BMTrain/docs/source/notes/image/ZeRO3.png differ diff --git a/examples/BMTrain/docs/source/notes/image/communication_example.png b/examples/BMTrain/docs/source/notes/image/communication_example.png new file mode 100644 index 00000000..0e549b08 Binary files /dev/null and b/examples/BMTrain/docs/source/notes/image/communication_example.png differ diff --git a/examples/BMTrain/docs/source/notes/image/communication_fig.png b/examples/BMTrain/docs/source/notes/image/communication_fig.png new file mode 100644 index 00000000..b7636240 Binary files /dev/null and b/examples/BMTrain/docs/source/notes/image/communication_fig.png differ diff --git a/examples/BMTrain/docs/source/notes/image/cpu.png b/examples/BMTrain/docs/source/notes/image/cpu.png new file mode 100644 index 00000000..451f397c Binary files /dev/null and b/examples/BMTrain/docs/source/notes/image/cpu.png differ diff --git a/examples/BMTrain/docs/source/notes/image/zero3_example.png b/examples/BMTrain/docs/source/notes/image/zero3_example.png new file mode 100644 index 00000000..40f0e8cc Binary files /dev/null and b/examples/BMTrain/docs/source/notes/image/zero3_example.png differ diff --git a/examples/BMTrain/docs/source/notes/installation.md b/examples/BMTrain/docs/source/notes/installation.md new file mode 100644 index 00000000..3fff4eb6 --- /dev/null +++ b/examples/BMTrain/docs/source/notes/installation.md @@ -0,0 +1,45 @@ +# 安装 + +## 安装方法 + +### 1. 用 pip 安装 (推荐) + +```shell +$ pip install bmtrain +``` + +### 2. 从源代码安装 + +```shell +$ git clone https://github.com/OpenBMB/BMTrain.git +$ cd BMTrain +$ python3 setup.py install +``` + +## 编译选项 + +通过设置环境变量,你可以控制BMTrain的编译选项(默认会自动适配编译环境): + +### AVX指令集 + +* 强制使用AVX指令集: `BMT_AVX256=ON` +* 强制使用AVX512指令集: `BMT_AVX512=ON` + +### CUDA计算兼容性 + +`TORCH_CUDA_ARCH_LIST=6.0 6.1 7.0 7.5 8.0+PTX` + +## 推荐配置 + +* 网络:Infiniband 100Gbps / RoCE 100Gbps +* GPU:NVIDIA Tesla V100 / NVIDIA Tesla A100 / RTX 3090 +* CPU:支持AVX512指令集的CPU,32核心以上 +* RAM:256GB以上 + +## 常见问题 + +如果在编译过程中如下的报错信息,请尝试使用更新版本的gcc编译器。 +``` +error: invalid static_cast from type `const torch::OrderdDict<...>` +``` + diff --git a/examples/BMTrain/docs/source/notes/quickstart.md b/examples/BMTrain/docs/source/notes/quickstart.md new file mode 100644 index 00000000..f139fd33 --- /dev/null +++ b/examples/BMTrain/docs/source/notes/quickstart.md @@ -0,0 +1,146 @@ +# 快速入门 + +## Step 1: 启用 BMTrain + +要使用BMTrain需要在代码中引入`bmtrain`工具包,并在代码的开头使用`bmtrain.init_distributed` + +```python +import bmtrain as bmt +bmt.init_distributed( + seed=0, + # ... +) +``` + +**注意:** 使用BMTrain时请不要使用PyTorch自带的`distributed`模块,包括`torch.distributed.init_process_group`以及相关通信函数。 + +## Step 2: 使用 ZeRO-3 优化 + +使用ZeRO-3优化需要对模型代码进行简单替换: + +* `torch.nn.Module` -> `bmtrain.DistributedModule` +* `torch.nn.Parameter` -> `bmtrain.DistributedParameter` + +并在合适的模块上使用`Checkpointing`。 + +**原始代码:** + +```python +import torch +class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.empty(1024)) + self.module_list = torch.nn.ModuleList([ + SomeTransformerBlock(), + SomeTransformerBlock(), + SomeTransformerBlock() + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +**替换后代码:** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) + self.module_list = torch.nn.ModuleList([ + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()) + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +## Step 3: 通信优化 + +为了进一步缩短通信额外开销,将通信与运算时间重叠,可以使用`TransformerBlockList`来进一步优化。 +在使用时需要对代码进行简单替换: + +* `torch.nn.ModuleList` -> `bmtrain.TransformerBlockList` +* `for module in self.module_list: x = module(x, ...)` -> `x = self.module_list(x, ...)` + +**原始代码:** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) + self.module_list = torch.nn.ModuleList([ + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()) + ]) + + def forward(self): + x = self.param + for module in self.module_list: + x = module(x, 1, 2, 3) + return x + +``` + +**替换后代码:** + +```python +import torch +import bmtrain as bmt +class MyModule(bmt.DistributedModule): + def __init__(self): + super().__init__() + self.param = bmt.DistributedParameter(torch.empty(1024)) + self.module_list = bmt.TransformerBlockList([ + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()), + bmt.Block(SomeTransformerBlock()) + ]) + + def forward(self): + x = self.param + x = self.module_list(x, 1, 2, 3) + return x + +``` + +## Step 4: 运行分布式训练代码 + +BMTrain支持PyTorch原生的分布式训练启动器,不需要额外的参数: + +### torch.distributed.launch +```shell +$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py +``` + +### torchrun + +```shell +$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py +``` + +更多信息请参考PyTorch官方文档:[Launch utility](https://pytorch.org/docs/stable/distributed.html#launch-utility) + +## 其它说明 + +`BMTrain`工具包对PyTorch进行了底层修改,如果你的程序输出了意料之外的结果,可以在issue中提交相关信息。 + +更多例子请参考 *examples* 文件夹。 + diff --git a/examples/BMTrain/docs/source/notes/tech.md b/examples/BMTrain/docs/source/notes/tech.md new file mode 100644 index 00000000..9104fde3 --- /dev/null +++ b/examples/BMTrain/docs/source/notes/tech.md @@ -0,0 +1,11 @@ +# 核心技术简介 + +## ZeRO-3 优化 +![](image/ZeRO3.png) + +## 通信运算重叠 +![](image/communication_fig.png) + +## CPU Offload +![](image/cpu.png) + diff --git a/examples/BMTrain/example/README.md b/examples/BMTrain/example/README.md new file mode 100644 index 00000000..395b5e64 --- /dev/null +++ b/examples/BMTrain/example/README.md @@ -0,0 +1,5 @@ +# Example + +This is an example of BMTrain's implementation of GPT-2. + +For more model implementations, please refer to [Model Center](https://github.com/OpenBMB/ModelCenter). \ No newline at end of file diff --git a/examples/BMTrain/example/benchmark.py b/examples/BMTrain/example/benchmark.py new file mode 100644 index 00000000..8a7092d9 --- /dev/null +++ b/examples/BMTrain/example/benchmark.py @@ -0,0 +1,12 @@ +import bmtrain as bmt + +def main(): + bmt.init_distributed() + bmt.print_rank("======= All Gather =======") + bmt.benchmark.all_gather() + bmt.print_rank("===== Reduce Scatter =====") + bmt.benchmark.reduce_scatter() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/examples/BMTrain/example/layers/__init__.py b/examples/BMTrain/example/layers/__init__.py new file mode 100644 index 00000000..ef4617c0 --- /dev/null +++ b/examples/BMTrain/example/layers/__init__.py @@ -0,0 +1,5 @@ +from .embedding import Embedding +from .feedforward import Feedforward +from .layernorm import Layernorm +from .attention import Attention +from .transformer import TransformerEncoder diff --git a/examples/BMTrain/example/layers/attention.py b/examples/BMTrain/example/layers/attention.py new file mode 100644 index 00000000..0f5155d4 --- /dev/null +++ b/examples/BMTrain/example/layers/attention.py @@ -0,0 +1,118 @@ +from typing import Optional +import torch +import bmtrain as bmt +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear, +) +import math +from bmtrain.global_var import config +from bmtrain.distributed import all_gather + +class Attention(bmt.DistributedModule): + def __init__(self, + dim_model : int, dim_head : int, + num_heads : int, bias : bool = True, + dtype = None + ) -> None: + super().__init__() + + if config['tp_size'] > 1: + self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False) + self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + else: + self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + + + self.softmax = torch.nn.Softmax(dim=-1) + self.num_heads = num_heads + self.dim_head = dim_head + self.dim_model = dim_model + + def forward(self, + hidden_q : torch.Tensor, # (batch_size, seq_q, dim_model) + hidden_kv : torch.Tensor, # (batch_size, seq_kv, dim_model) + mask : torch.BoolTensor, # (batch_size, seq_q, seq_kv) + position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv) + ) -> torch.Tensor: + batch_size = hidden_q.size()[0] + + assert hidden_q.data_ptr() == hidden_kv.data_ptr() + + if config['tp_size'] > 1: + hidden_q = bmt.nn.OpParallelLinear.apply( + hidden_q, + torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0), + torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0), + True, False, + False, None + ) + hidden_q = hidden_q.view(batch_size, -1, hidden_q.shape[-1]) + h_q, h_k, h_v = hidden_q.chunk(3, dim=-1) + else: + h_q : torch.Tensor = self.project_q(hidden_q) + h_k : torch.Tensor = self.project_k(hidden_kv) + h_v : torch.Tensor = self.project_v(hidden_kv) + + seq_q = h_q.size()[1] + seq_kv = h_k.size(1) + + h_q = h_q.view(batch_size, seq_q, -1, self.dim_head) + h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head) + h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head) + + h_q = h_q.permute(0, 2, 1, 3).contiguous() + h_k = h_k.permute(0, 2, 1, 3).contiguous() + h_v = h_v.permute(0, 2, 1, 3).contiguous() + + h_q = h_q.view(-1, seq_q, self.dim_head) + h_k = h_k.view(-1, seq_kv, self.dim_head) + h_v = h_v.view(-1, seq_kv, self.dim_head) + + score = torch.bmm( + h_q, h_k.transpose(1, 2) + ) + score = score / math.sqrt(self.dim_head) + + score = score.view(batch_size, -1, seq_q, seq_kv) + + if position_bias is not None: + score = score + position_bias.view(batch_size, -1, seq_q, seq_kv) + + score = torch.where( + mask.view(batch_size, 1, seq_q, seq_kv), + score, + torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype) + ) + + score = torch.where( + mask.view(batch_size, 1, seq_q, seq_kv), + self.softmax(score), + torch.scalar_tensor(0, device=score.device, dtype=score.dtype) + ) + + score = score.view(-1, seq_q, seq_kv) + + h_out = torch.bmm( + score, h_v + ) + h_out = h_out.view(batch_size, -1, seq_q, self.dim_head) + h_out = h_out.permute(0, 2, 1, 3).contiguous() + h_out = h_out.view(batch_size, seq_q, -1) + if config['tp_size'] > 1: + h_out = h_out.view(h_out.shape[0] * bmt.config["tp_size"], -1, h_out.shape[-1]) + + attn_out = self.project_out(h_out) + + return attn_out + + + + + diff --git a/examples/BMTrain/example/layers/embedding.py b/examples/BMTrain/example/layers/embedding.py new file mode 100644 index 00000000..f62151c4 --- /dev/null +++ b/examples/BMTrain/example/layers/embedding.py @@ -0,0 +1,102 @@ +import math +from typing import Optional +import torch +import torch.nn.functional as F +import bmtrain as bmt + + +class Embedding(bmt.DistributedModule): + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False, + sparse: bool = False, _weight: Optional[torch.Tensor] = None, + dtype=None): + super().__init__() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if padding_idx is not None: + if padding_idx > 0: + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' + elif padding_idx < 0: + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' + padding_idx = self.num_embeddings + padding_idx + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + if _weight is None: + self.weight = bmt.DistributedParameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device="cuda"), init_method=torch.nn.init.normal_) + else: + self.weight = bmt.DistributedParameter(_weight) + + self.sparse = sparse + + @classmethod + def from_pretrained(cls, embeddings, freeze=True, padding_idx=None, + max_norm=None, norm_type=2., scale_grad_by_freq=False, + sparse=False): + r"""Creates Embedding instance from given 2-dimensional FloatTensor. + + Args: + embeddings (Tensor): FloatTensor containing weights for the Embedding. + First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``. + freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process. + Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True`` + padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient; + therefore, the embedding vector at :attr:`padding_idx` is not updated during training, + i.e. it remains as a fixed "pad". + max_norm (float, optional): See module initialization documentation. + norm_type (float, optional): See module initialization documentation. Default ``2``. + scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``. + sparse (bool, optional): See module initialization documentation. + + Examples:: + + >>> # FloatTensor containing pretrained weights + >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]]) + >>> embedding = nn.Embedding.from_pretrained(weight) + >>> # Get embeddings for index 1 + >>> input = torch.LongTensor([1]) + >>> embedding(input) + tensor([[ 4.0000, 5.1000, 6.3000]]) + """ + assert embeddings.dim() == 2, \ + 'Embeddings parameter is expected to be 2-dimensional' + rows, cols = embeddings.shape + embedding = cls( + num_embeddings=rows, + embedding_dim=cols, + _weight=embeddings, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse) + embedding.weight.requires_grad = not freeze + return embedding + + def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor: + if not projection: + out = F.embedding( + input, self.weight, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + return out + else: + out = F.linear(input, self.weight) + return out + + def extra_repr(self) -> str: + s = '{num_embeddings}, {embedding_dim}' + if self.padding_idx is not None: + s += ', padding_idx={padding_idx}' + if self.max_norm is not None: + s += ', max_norm={max_norm}' + if self.norm_type != 2: + s += ', norm_type={norm_type}' + if self.scale_grad_by_freq is not False: + s += ', scale_grad_by_freq={scale_grad_by_freq}' + if self.sparse is not False: + s += ', sparse=True' + return s.format(**self.__dict__) + + diff --git a/examples/BMTrain/example/layers/feedforward.py b/examples/BMTrain/example/layers/feedforward.py new file mode 100644 index 00000000..e88d2495 --- /dev/null +++ b/examples/BMTrain/example/layers/feedforward.py @@ -0,0 +1,23 @@ +import torch +import bmtrain as bmt +from bmtrain.nn import ( + Linear, + ColumnParallelLinear, + RowParallelLinear) +from bmtrain.global_var import config + +class Feedforward(bmt.DistributedModule): + def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: + super().__init__() + + if config['tp_size'] > 1: + self.w_in = ColumnParallelLinear(dim_model, dim_ff, bias = bias, dtype=dtype) + self.w_out = RowParallelLinear(dim_ff, dim_model, bias = bias, dtype=dtype) + else: + self.w_in = Linear(dim_model, dim_ff, bias=bias, dtype=dtype) + self.w_out = Linear(dim_ff, dim_model, bias=bias, dtype=dtype) + + self.relu = torch.nn.ReLU() + + def forward(self, input : torch.Tensor) -> torch.Tensor: + return self.w_out(self.relu(self.w_in(input))) diff --git a/examples/BMTrain/example/layers/layernorm.py b/examples/BMTrain/example/layers/layernorm.py new file mode 100644 index 00000000..9f3e3bc2 --- /dev/null +++ b/examples/BMTrain/example/layers/layernorm.py @@ -0,0 +1,34 @@ +from typing import Tuple +import torch +import torch.nn.functional as F +import bmtrain as bmt + +class Layernorm(bmt.DistributedModule): + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + normalized_shape: Tuple[int, ...] + eps: float + elementwise_affine: bool + + def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True, + dtype=None) -> None: + super().__init__() + if isinstance(normalized_shape, int): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = bmt.DistributedParameter(torch.empty(self.normalized_shape, dtype=dtype, device="cuda"), init_method=torch.nn.init.ones_) + self.bias = bmt.DistributedParameter(torch.empty(self.normalized_shape, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps) + + def extra_repr(self) -> str: + return '{normalized_shape}, eps={eps}, ' \ + 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) diff --git a/examples/BMTrain/example/layers/transformer.py b/examples/BMTrain/example/layers/transformer.py new file mode 100644 index 00000000..4cbff59b --- /dev/null +++ b/examples/BMTrain/example/layers/transformer.py @@ -0,0 +1,34 @@ +from typing import Optional +import torch +import bmtrain as bmt +from layers import Layernorm, Feedforward, Attention + +class TransformerEncoder(bmt.DistributedModule): + def __init__(self, + dim_model : int, dim_head : int, num_heads : int, dim_ff : int, + bias : bool = True, dtype = None + ) -> None: + super().__init__() + + self.ln_attn = Layernorm(dim_model, dtype=dtype) + self.attn = Attention(dim_model, dim_head, num_heads, bias=bias, dtype=dtype) + + self.ln_ff = Layernorm(dim_model, dtype=dtype) + self.ff = Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype) + + def forward(self, + hidden : torch.Tensor, # (batch, seq_len, dim_model) + mask : torch.BoolTensor, # (batch, seq_len, dim_model) + position_bias : Optional[torch.Tensor] = None, # (batch, num_head, seq_len, seq_len) + ): + bmt.inspect.record_tensor(hidden, "hidden") + x = self.ln_attn(hidden) + x = self.attn(x, x, mask, position_bias) + hidden = hidden + x + + x = self.ln_ff(hidden) + x = self.ff(x) + hidden = hidden + x.view_as(hidden) + + return hidden + diff --git a/examples/BMTrain/example/models/__init__.py b/examples/BMTrain/example/models/__init__.py new file mode 100644 index 00000000..e7d1dcc9 --- /dev/null +++ b/examples/BMTrain/example/models/__init__.py @@ -0,0 +1 @@ +from .gpt import GPT \ No newline at end of file diff --git a/examples/BMTrain/example/models/gpt.py b/examples/BMTrain/example/models/gpt.py new file mode 100644 index 00000000..ed604382 --- /dev/null +++ b/examples/BMTrain/example/models/gpt.py @@ -0,0 +1,64 @@ +import torch +import bmtrain as bmt +from layers import TransformerEncoder, Layernorm, Embedding, TransformerEncoder +from bmtrain.global_var import config + +class GPT(bmt.DistributedModule): + def __init__(self, + num_layers : int, vocab_size : int, + dim_model : int, dim_head : int, num_heads : int, dim_ff : int, + max_distance : int, + bias : bool = True, dtype = None + ) -> None: + super().__init__() + + self.max_distance = max_distance + + if config["tp_size"] > 1: + self.word_emb = bmt.nn.VPEmbedding(vocab_size, dim_model, dtype=dtype) + else: + self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype) + self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype) + + if config['pipe_size'] > 1: + self.transformers = bmt.PipelineTransformerBlockList([ + bmt.Block( + TransformerEncoder( + dim_model, dim_head, num_heads, dim_ff, bias, dtype + ) + , mode="PIPE" + ) + for _ in range(num_layers) + ]) + else: + self.transformers = bmt.TransformerBlockList([ + bmt.Block( + TransformerEncoder( + dim_model, dim_head, num_heads, dim_ff, bias, dtype + ) + ) + for _ in range(num_layers) + ]) + + self.layernorm = Layernorm(dim_model, dtype=dtype) + + def forward(self, + input : torch.LongTensor, # (batch, seq_len) + pos : torch.LongTensor, # (batch, seq_len) + mask : torch.BoolTensor, # (batch, seq_len) + ) -> torch.Tensor: + + mask_2d = mask[:, None, :] & mask[:, :, None] # (batch, seq_len, seq_len) + mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None]) + if config["tp_size"] > 1: + input = input.chunk(config["tp_size"], dim=1)[config["tp_rank"]] + pos = pos.chunk(config["tp_size"], dim=1)[config["tp_rank"]] + out = self.pos_emb(pos) + self.word_emb(input) + + # for layer in self.transformers: + out = self.transformers(out, mask_2d, None) + out = self.layernorm(out) + logits = self.word_emb(out, projection=True) + bmt.inspect.record_tensor(logits, "logits") + + return logits diff --git a/examples/BMTrain/example/run.sh b/examples/BMTrain/example/run.sh new file mode 100644 index 00000000..542e5252 --- /dev/null +++ b/examples/BMTrain/example/run.sh @@ -0,0 +1,3 @@ +export NCCL_P2P_DISABLE=1 +export CUDA_LAUNCH_BLOCKING=1 +torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost train.py diff --git a/examples/BMTrain/example/sbatch.sh b/examples/BMTrain/example/sbatch.sh new file mode 100644 index 00000000..3b93d99b --- /dev/null +++ b/examples/BMTrain/example/sbatch.sh @@ -0,0 +1,20 @@ +#!/bin/bash +#SBATCH --job-name=cpm2-test +#SBATCH --partition=rtx2080 + +#SBATCH --nodes=1 +#SBATCH --gpus-per-node=8 + + +MASTER_PORT=30123 +MASTER_HOST=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) + +# load python virtualenv if you have +# source /path/to/python/virtualenv/bin/activate + +# uncomment to print nccl debug info +# export NCCL_DEBUG=info + +srun torchrun --nnodes=$SLURM_JOB_NUM_NODES --nproc_per_node=$SLURM_GPUS_PER_NODE --rdzv_id=$SLURM_JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$MASTER_HOST:$MASTER_PORT train.py + + diff --git a/examples/BMTrain/example/train.py b/examples/BMTrain/example/train.py new file mode 100644 index 00000000..d5906a06 --- /dev/null +++ b/examples/BMTrain/example/train.py @@ -0,0 +1,138 @@ +import torch +import bmtrain as bmt +from models import GPT +import time +from bmtrain import optim +from bmtrain.global_var import config +from bmtrain import inspect + +def main(): + bmt.init_distributed( + seed=0, + tp_size=2, + ) + + model = GPT( + num_layers=8, + vocab_size=10240, + dim_model=2560, + dim_head=80, + num_heads=32, + dim_ff=8192, + max_distance=1024, + bias=True, + dtype=torch.half + ) + + bmt.init_parameters(model) + + bmt.print_rank("Model memory") + bmt.print_rank(torch.cuda.memory_summary()) + bmt.synchronize() + + # data + # generate dummy data for each rank + torch.manual_seed(1234) + + batch_size = 2 + seq_len = 512 + world_size = bmt.config["world_size"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_size"] + r = bmt.config["rank"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_rank"] + + for i in range(world_size): + sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) + enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() + enc_input = sent[:, :-1].long().cuda() + targets = sent[:, 1:].long().cuda() + mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] + targets = torch.where( + mask, + targets, + torch.full_like(targets, -100, dtype=torch.long) + ) + + if i == r: + break + + if config['tp_size'] > 1: + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) + else: + loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100) + + optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2) + lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) + + optim_manager = optim.OptimManager(loss_scale=2**20) + optim_manager.add_optimizer(optimizer, lr_scheduler) + + bmt.synchronize() + + avg_time_recorder = bmt.utils.AverageRecorder() + avg_loss_recorder = bmt.utils.AverageRecorder() + + for iteration in range(1000): + # load data + st = time.time() + + with inspect.inspect_tensor() as inspector: + pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + logits = model( + enc_input, + pos, + pos < enc_length[:, None] + ) + batch, seq_len, vocab_out_size = logits.size() + + if config['tp_size'] > 1: + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + else: + loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + + global_loss = bmt.sum_loss(loss).item() + + optim_manager.zero_grad() + + optim_manager.backward(loss) + + # print inspected tensors in the forward & backward pass + # print parameters of the model + if iteration % 100 == 0: + bmt.print_rank( + inspect.format_summary( + inspector.get_summary() + ) + ) + bmt.print_rank( + inspect.format_summary( + inspect.inspect_model(model, "*") + ) + ) + + optim_manager.step() + + # record time and loss + iteration_time = time.time() - st + + avg_time_recorder.record(iteration_time) + avg_loss_recorder.record(global_loss) + + # print time and loss + bmt.print_rank( + "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format( + iteration, + global_loss, + avg_loss_recorder.value, + lr_scheduler.current_lr, + optim_manager.loss_scale, + avg_time_recorder.value + ) + ) + + # save model + if iteration % 1000 == 0: + bmt.save(model, "ckpt-%d.pt" % iteration) + + bmt.save(model, "checkpoint.pt") + +if __name__ == '__main__': + main() diff --git a/examples/BMTrain/other_requirements.txt b/examples/BMTrain/other_requirements.txt new file mode 100644 index 00000000..6654b1ac --- /dev/null +++ b/examples/BMTrain/other_requirements.txt @@ -0,0 +1,6 @@ +tqdm +cpm_kernels>=1.0.11 +jieba +tensorboard +setuptools_rust +transformers \ No newline at end of file diff --git a/examples/BMTrain/pyproject.toml b/examples/BMTrain/pyproject.toml new file mode 100644 index 00000000..b563eb32 --- /dev/null +++ b/examples/BMTrain/pyproject.toml @@ -0,0 +1,8 @@ +[build-system] +requires = [ + "setuptools", + "pybind11", + "nvidia-nccl-cu11 >= 2.14.3", + "cmake > 3.27.0" +] +build-backend = "setuptools.build_meta" diff --git a/examples/BMTrain/setup.py b/examples/BMTrain/setup.py new file mode 100644 index 00000000..70752ff6 --- /dev/null +++ b/examples/BMTrain/setup.py @@ -0,0 +1,113 @@ +import os +import shutil +from setuptools.command.build_ext import build_ext +from setuptools import setup, find_packages, Extension +import setuptools +import warnings +import sys +import subprocess +class CMakeExtension(Extension): + def __init__(self, name, sourcedir=""): + Extension.__init__(self, name, sources=[]) + self.sourcedir = os.path.abspath(sourcedir) + + +def is_ninja_available(): + r''' + Returns ``True`` if the `ninja <https://ninja-build.org/>`_ build system is + available on the system, ``False`` otherwise. + ''' + with open(os.devnull, 'wb') as devnull: + try: + subprocess.check_call('ninja --version'.split(), stdout=devnull) + except OSError: + return False + else: + return True + + +class CMakeBuild(build_ext): + + def build_extension(self, ext): + extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) + + # required for auto-detection & inclusion of auxiliary "native" libs + if not extdir.endswith(os.path.sep): + extdir += os.path.sep + + debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug + cfg = "Debug" if debug else "Release" + + # CMake lets you override the generator - we need to check this. + # Can be set with Conda-Build, for example. + cmake_generator = os.environ.get("CMAKE_GENERATOR", "") + + # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON + # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code + # from Python. + cmake_args = [ + f"-DCMAKE_CXX_STANDARD=14", + f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}", + f"-DPYTHON_EXECUTABLE={sys.executable}", + f"-DPYTHON_VERSION={sys.version_info.major}.{sys.version_info.minor}", + f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm + ] + + build_args = [] + if "CMAKE_ARGS" in os.environ: + cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item] + + cmake_args += [f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"] + + + if not cmake_generator or cmake_generator == "Ninja": + try: + import ninja # noqa: F401 + + ninja_executable_path = os.path.join(ninja.BIN_DIR, "ninja") + cmake_args += [ + "-GNinja", + f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}", + ] + except ImportError: + pass + + if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ: + # self.parallel is a Python 3 only way to set parallel jobs by hand + # using -j in the build_ext call, not supported by pip or PyPA-build. + if hasattr(self, "parallel") and self.parallel: + # CMake 3.12+ only. + build_args += [f"-j{self.parallel}"] + + build_temp = os.path.join(self.build_temp, ext.name) + if os.path.exists(build_temp): + shutil.rmtree(build_temp) + os.makedirs(build_temp) + + cmake_args += ["-DPython_ROOT_DIR=" + os.path.dirname(os.path.dirname(sys.executable))] + subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp) + subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp) + +ext_modules = [ + CMakeExtension("bmtrain.C"), +] +setup( + name='bmtrain', + version='1.0.1', + author="Guoyang Zeng", + author_email="qbjooo@qq.com", + description="A toolkit for training big models", + packages=find_packages(), + install_requires=[ + "numpy", + "nvidia-nccl-cu11>=2.14.3" + ], + setup_requires=[ + "pybind11", + "nvidia-nccl-cu11>=2.14.3" + ], + ext_modules=ext_modules, + cmdclass={ + 'build_ext': CMakeBuild + }) + diff --git a/examples/BMTrain/tests/test_all.py b/examples/BMTrain/tests/test_all.py new file mode 100644 index 00000000..db5d2dd4 --- /dev/null +++ b/examples/BMTrain/tests/test_all.py @@ -0,0 +1,43 @@ +import subprocess +from tqdm import tqdm + + +tq = tqdm([ + ("different_output_shape", 1), + ("load_ckpt", 1), + ("init_parameters", 1), + ("synchronize", 4), + ("init_parameters_multi_gpu", 4), + ("optim_state", 4), + + ("requires_grad", 1), + ("requires_grad_multi_gpu", 2), + ("has_inf_nan", 1), + ("dropout", 1), + ("loss_func", 1), + + ("optim", 1), + + ("multi_return", 2), + ("middle_hidden", 4), + ("other_hidden", 4), + + ("model_wrapper", 4), + + ("send_recv", 4), + ("nccl_backward", 4), + ("no_grad", 1), + ("column_parallel_linear", 2), + ("row_parallel_linear", 2), + ("parallel_projection", 4), + + ("training", 4), +]) + +for t, num_gpu in tq: + PREFIX = f"python3 -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node={num_gpu} --master_addr=localhost --master_port=32123" + SUFFIX = f"> test_log.txt 2>&1" + command = f"{PREFIX} test_{t}.py {SUFFIX}" + completedProc = subprocess.run(command, shell=True) + assert completedProc.returncode == 0, f"test {t} failed, see test_log.txt for more details." + print(f"finish testing {t}") diff --git a/examples/BMTrain/tests/test_column_parallel_linear.py b/examples/BMTrain/tests/test_column_parallel_linear.py new file mode 100644 index 00000000..5f2fdad1 --- /dev/null +++ b/examples/BMTrain/tests/test_column_parallel_linear.py @@ -0,0 +1,74 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np + +def run_bmt(x, gather_input, gather_output, ckp_path, tp_size=2): + linear = bmt.nn.ColumnParallelLinear(8,8, gather_input=gather_input, gather_output=gather_output) + linear = bmt.Block(linear) + bmt.init_parameters(linear) + y = linear(x) + y.sum().backward() + bmt.save(linear, ckp_path) + bmt.synchronize() + return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad + +def run_torch(x, ckp_path): + linear = torch.nn.Linear(8, 8) + linear_dict = torch.load(ckp_path) + linear.load_state_dict(linear_dict) + linear = linear.cuda() + linear.weight.requires_grad_() + y = linear(x) + y.sum().backward() + return y, linear.weight.grad, linear.bias.grad + +def run(gather_input, gather_output, ckp_path): + torch.cuda.manual_seed(100) + tp_size = config["tp_size"] + tp_rank = config['topology'].tp_id + x = torch.randn(8, 8, 8, device='cuda') + bmt_x = x.clone() + if gather_input: + rank_x = bmt_x.chunk(tp_size, dim=0)[tp_rank] + else: + rank_x = bmt_x + rank_x.requires_grad_() + x.requires_grad_() + y1, weight_grad1, bias_grad1 = run_bmt(rank_x, gather_input, gather_output, ckp_path) + y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path) + tp_rank = config['topology'].tp_id + if gather_output: + assert np.allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy()) + else: + torch_out_list = torch.split(y2, y2.size()[-1] // tp_size, dim=-1) + assert np.allclose(y1.detach().cpu().numpy(), torch_out_list[tp_rank].detach().cpu().numpy()) + + weight_grad_list = weight_grad2.chunk(tp_size, dim=0) + assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy()) + + bias_grad_list = bias_grad2.chunk(tp_size, dim=0) + assert np.allclose(bias_grad1.reshape(bias_grad_list[tp_rank].shape).cpu().numpy(), bias_grad_list[tp_rank].cpu().numpy()) + + if gather_input: + x_grad_list = x.grad.chunk(tp_size, dim=0) + np.testing.assert_allclose(rank_x.grad.cpu().numpy(), x_grad_list[tp_rank].cpu().numpy(), atol=1e-4, rtol=1e-4) + else: + np.testing.assert_allclose(rank_x.grad.cpu().numpy(), x.grad.cpu().numpy(), atol=1e-4, rtol=1e-4) + +def test_gather_output(): + run(True, True, 'linear.ckp') + +def test_no_gather_output(): + run(True, False, 'linear_no_gather.ckp') + +def test_no_gather_input(): + run(False, True, 'linear.ckp') + + +if __name__ == "__main__": + bmt.init_distributed(tp_size=2) + test_gather_output() + test_no_gather_output() + test_no_gather_input() + diff --git a/examples/BMTrain/tests/test_different_output_shape.py b/examples/BMTrain/tests/test_different_output_shape.py new file mode 100644 index 00000000..bb8ab7fa --- /dev/null +++ b/examples/BMTrain/tests/test_different_output_shape.py @@ -0,0 +1,50 @@ +import torch +import bmtrain as bmt + +class Block0(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + +class Block1(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return (x,) + +class Block2(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return (x, x) + +class Block10(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return [x, x, x, x, x, x, x, x, x, x] + +if __name__ == "__main__": + bmt.init_distributed() + x = torch.tensor([1,2,3.]) + + b = bmt.Block(Block0()) + y = b(x) + assert isinstance(y, torch.Tensor) + + b = bmt.Block(Block1()) + y = b(x) + assert isinstance(y, tuple) and len(y)==1 + + b = bmt.Block(Block2()) + y = b(x) + assert isinstance(y, tuple) and len(y)==2 + + b = bmt.Block(Block10()) + y = b(x) + assert isinstance(y, tuple) and len(y)==10 diff --git a/examples/BMTrain/tests/test_dropout.py b/examples/BMTrain/tests/test_dropout.py new file mode 100644 index 00000000..d29240a1 --- /dev/null +++ b/examples/BMTrain/tests/test_dropout.py @@ -0,0 +1,45 @@ +from utils import * + +import torch +import bmtrain as bmt + +class InnerModule(bmt.DistributedModule): + def __init__(self): + super().__init__() + + self.drop = torch.nn.Dropout(p=0.5) + + def forward(self, x): + return self.drop(x) + +class OutterModule(bmt.DistributedModule): + def __init__(self) -> None: + super().__init__() + + self.blk = bmt.TransformerBlockList([ + bmt.Block(InnerModule()) + for _ in range(5) + ]) + + def forward(self, x): + return self.blk(x) + +def test_main(): + model = OutterModule() + + for _ in range(5): + model.train() + x = torch.ones(32, device="cuda") + y = model(x) + print(y) + assert_neq(x.numel()-y.nonzero().size(0), 0) + + model.eval() + x = torch.ones(32, device="cuda") + y = model(x) + print(y) + assert_eq(x.numel()-y.nonzero().size(0), 0) + +if __name__ == "__main__": + bmt.init_distributed() + test_main() diff --git a/examples/BMTrain/tests/test_grad_accu.py b/examples/BMTrain/tests/test_grad_accu.py new file mode 100644 index 00000000..48dbe8ac --- /dev/null +++ b/examples/BMTrain/tests/test_grad_accu.py @@ -0,0 +1,80 @@ +import bmtrain as bmt +import torch +from bmtrain import config +from bmtrain.block_layer import CheckpointBlockContext, CheckpointBlock, TransformerBlockList +from bmtrain.pipe_layer import PipelineTransformerBlockList +from typing import List +import torch.nn.functional as F +def print_rank0(s): + if bmt.rank() == 0: + print(s) +class Linear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.out = {} + if init_weight: + self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features)) + else: + self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_) + + if init_bias: + self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,)) + else: + self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_) + + def forward(self, input): + ret = F.linear(input, self.weight, self.bias) + return ret + +def test_grad_accu(): + # normal distribute module + m = Linear(256, 256) + inp = torch.randn((1, 10, 256), device="cuda") + logits = m(inp) + loss = logits.sum() + loss.backward() + grad1 = m._parameters["weight"].grad.clone() + logits = m(inp) + loss = logits.sum() + loss.backward() + grad2 = m._parameters["weight"].grad + assert torch.allclose(grad1*2, grad2) + print_rank0("grad accumulation for distribute module passed") + # checkpoint block + m = CheckpointBlock(Linear(256, 256)) + inp = torch.randn((1, 10, 256), device="cuda") + logits = m(inp) + loss = logits.sum() + loss.backward() + bmt.synchronize() + grad1 = m.weight.grad.clone() + logits = m(inp) + loss = logits.sum() + loss.backward() + bmt.synchronize() + grad2 = m.weight.grad.clone() + assert torch.allclose(grad1*2, grad2) + print_rank0("grad accumulation for checkpointblock passed") + # transformer block list + m = TransformerBlockList([CheckpointBlock(Linear(256, 256))]) + inp = torch.randn((1, 10, 256), device="cuda") + logits = m(inp) + loss = logits.sum() + loss.backward() + bmt.synchronize() + grad1 = m[0].weight.grad.clone() + logits = m(inp) + loss = logits.sum() + loss.backward() + bmt.synchronize() + grad2 = m[0].weight.grad + assert torch.allclose(grad1*2, grad2) + print_rank0("grad accumulation for TransformerBlockList passed") + + +if __name__ == "__main__": + bmt.init_distributed() + test_grad_accu() \ No newline at end of file diff --git a/examples/BMTrain/tests/test_has_inf_nan.py b/examples/BMTrain/tests/test_has_inf_nan.py new file mode 100644 index 00000000..93ac8118 --- /dev/null +++ b/examples/BMTrain/tests/test_has_inf_nan.py @@ -0,0 +1,37 @@ +from utils import * +import torch +import bmtrain.loss._function as F +import random + +def check(x, v): + out = torch.zeros(1, dtype=torch.uint8, device="cuda")[0] + F.has_inf_nan(x, out) + assert_eq(out.item(), v) + +def test_main(dtype): + for i in list(range(1, 100)) + [1000]*10 + [10000]*10 + [100000]*10 + [1000000]*10: + x = torch.rand((i,)).to(dtype).cuda() + check(x, 0) + p = random.randint(0, i-1) + x[p] = x[p] / 0 + check(x, 1) + x[p] = 2 + check(x, 0) + p = random.randint(0, i-1) + x[p] = 0 + x[p] = x[p] / 0 + check(x, 1) + p = random.randint(0, i-1) + x[p] = x[p] / 0 + p = random.randint(0, i-1) + x[p] = x[p] / 0 + check(x, 1) + print("That's right") + +if __name__ == "__main__": + test_main(torch.float16) + print("==============================================================================") + try: + test_main(torch.bfloat16) + except NotImplementedError: + pass diff --git a/examples/BMTrain/tests/test_init_parameters.py b/examples/BMTrain/tests/test_init_parameters.py new file mode 100644 index 00000000..b67431f2 --- /dev/null +++ b/examples/BMTrain/tests/test_init_parameters.py @@ -0,0 +1,223 @@ +from utils import * + +import torch +import torch.nn.functional as F +import bmtrain as bmt + +def manual_seed(seed=33): + torch.manual_seed(seed) + import random as random + random.seed(seed) + try: + import numpy as np + np.random.seed(seed) + except ModuleNotFoundError: + pass + +class Linear_NormalInitBefore(torch.nn.Module): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + w = torch.empty(out_features, in_features, dtype=dtype, device="cuda") # use cuda to match random algorithm + torch.nn.init.xavier_normal_(w) + self.weight = torch.nn.Parameter(w) + if bias: + b = torch.empty(out_features, dtype=dtype, device="cuda") # use cuda to match random algorithm + torch.nn.init.zeros_(b) + self.bias = torch.nn.Parameter(b) + else: + self.register_parameter('bias', None) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + +class Linear_NormalInitAfter(torch.nn.Module): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm + torch.nn.init.xavier_normal_(self.weight) + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + +class Linear_BMTInitializer(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype), init_method=torch.nn.init.xavier_normal_) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=dtype), init_method=torch.nn.init.zeros_) + else: + self.register_parameter('bias', None) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + +class Linear_ManualInitBefore(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + w = torch.empty(out_features, in_features, dtype=dtype, device="cuda") # use cuda to match random algorithm + torch.nn.init.xavier_normal_(w) + self.weight = bmt.DistributedParameter(w) + if bias: + b = torch.empty(out_features, dtype=dtype, device="cuda") # use cuda to match random algorithm + torch.nn.init.zeros_(b) + self.bias = bmt.DistributedParameter(b) + else: + self.register_parameter('bias', None) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + +class Linear_ManualInitAfter(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm + # torch.nn.init.xavier_normal_(self.weight) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm + # torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + + def extra_repr(self) -> str: + return 'in_features={}, out_features={}, bias={}'.format( + self.in_features, self.out_features, self.bias is not None + ) + +class Linear_Pipeline(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.l = bmt.PipelineTransformerBlockList([ + Linear_BMTInitializer(in_features, out_features, bias, dtype), + Linear_BMTInitializer(in_features, out_features, bias, dtype), + ]) + + def forward(self, input): + return self.l(input) + +class Linear_BlockList(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.l = bmt.TransformerBlockList([ + Linear_BMTInitializer(in_features, out_features, bias, dtype), + Linear_BMTInitializer(in_features, out_features, bias, dtype), + ]) + + def forward(self, input): + return self.l(input) + +class Linear_CheckpointList(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.l = torch.nn.ModuleList([ + bmt.CheckpointBlock(Linear_BMTInitializer(in_features, out_features, bias, dtype)), + bmt.CheckpointBlock(Linear_BMTInitializer(in_features, out_features, bias, dtype)), + ]) + + def forward(self, input): + return self.l(input) + +class Linear_Checkpoint(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.l = bmt.CheckpointBlock(Linear_BMTInitializer(in_features, out_features, bias, dtype)) + + def forward(self, input): + return self.l(input) + +def test_main(): + shape = [3, 5] + # torch + m = [None] * 10 + ret = [None] * 10 + manual_seed(33) + m[0] = Linear_NormalInitBefore(*shape) + ret[0] = (m[0].weight.data, m[0].bias.data) + + manual_seed(33) + m[1] = Linear_NormalInitAfter(*shape) + ret[1] = (m[1].weight.data, m[1].bias.data) + + # bmtrain + manual_seed(33) + m[2] = Linear_BMTInitializer(*shape) + bmt.init_parameters(m[2]) + bmt.synchronize() + ret[2] = (m[2].weight.data, m[2].bias.data) + + manual_seed(33) + m[3] = Linear_ManualInitBefore(*shape) + bmt.synchronize() + ret[3] = (m[3].weight.data, m[3].bias.data) + + # manual_seed(33) + # mw = Linear_ManualInitAfter(*shape) # not supported + # print(mw.weight.data, mw.bias.data) + + manual_seed(33) + m[4] = bmt.BMTrainModelWrapper(m[0]) + ret[4] = (m[4].weight.data, m[4].bias.data) + + manual_seed(33) + m[5] = bmt.BMTrainModelWrapper(m[1]) + ret[5] = (m[5].weight.data, m[5].bias.data) + + manual_seed(33) + m[6] = Linear_Pipeline(*shape) + bmt.init_parameters(m[6]) + ret[6] = (m[6].l[0].weight.data, m[6].l[0].bias.data) + + manual_seed(33) + m[7] = Linear_BlockList(*shape) + bmt.init_parameters(m[7]) + ret[7] = (m[7].l[0].weight.data, m[7].l[0].bias.data) + + manual_seed(33) + m[8] = Linear_CheckpointList(*shape) + bmt.init_parameters(m[8]) + ret[8] = (m[8].l[0].weight.data, m[8].l[0].bias.data) + + manual_seed(33) + m[9] = Linear_Checkpoint(*shape) + bmt.init_parameters(m[9]) + ret[9] = (m[9].l.weight.data, m[9].l.bias.data) + + for i in range(10): + ret[i] = ( ret[i][0].view(-1), ret[i][1].view(-1) ) + print(ret[i]) + for i in range(10): + for j in range(10): + print(i, j) + assert_all_eq(ret[i][0], ret[j][0]) + assert_all_eq(ret[i][1], ret[j][1]) + +if __name__ == "__main__": + bmt.init_distributed(pipe_size=1) + + test_main() \ No newline at end of file diff --git a/examples/BMTrain/tests/test_init_parameters_multi_gpu.py b/examples/BMTrain/tests/test_init_parameters_multi_gpu.py new file mode 100644 index 00000000..1e61568c --- /dev/null +++ b/examples/BMTrain/tests/test_init_parameters_multi_gpu.py @@ -0,0 +1,146 @@ +from utils import * + +import os +import torch +import torch.nn.functional as F +import bmtrain as bmt + +def manual_seed(seed=33): + torch.manual_seed(seed) + import random as random + random.seed(seed) + try: + import numpy as np + np.random.seed(seed) + except ModuleNotFoundError: + pass + +class Linear_Normal(torch.nn.Module): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm + torch.nn.init.xavier_normal_(self.weight) + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + +class Linear_BMTInitializer(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype), init_method=torch.nn.init.xavier_normal_) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=dtype), init_method=torch.nn.init.zeros_) + else: + self.register_parameter('bias', None) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + +class Linear_NormalList(torch.nn.Module): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.l = torch.nn.ModuleList([ + Linear_Normal(in_features, out_features, bias, dtype), + Linear_Normal(in_features, out_features, bias, dtype), + ]) + + def forward(self, input): + return self.l(input) + +class Linear_Pipeline(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.l = bmt.PipelineTransformerBlockList([ + Linear_BMTInitializer(in_features, out_features, bias, dtype), + Linear_BMTInitializer(in_features, out_features, bias, dtype), + ]) + + def forward(self, input): + return self.l(input) + +class Linear_BlockList(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.l = bmt.TransformerBlockList([ + Linear_BMTInitializer(in_features, out_features, bias, dtype), + Linear_BMTInitializer(in_features, out_features, bias, dtype), + ]) + + def forward(self, input): + return self.l(input) + +class Linear_CheckpointList(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.l = torch.nn.ModuleList([ + bmt.CheckpointBlock(Linear_BMTInitializer(in_features, out_features, bias, dtype)), + bmt.CheckpointBlock(Linear_BMTInitializer(in_features, out_features, bias, dtype)), + ]) + + def forward(self, input): + return self.l(input) + +def check(ckpt_path, ckpt_path_ref): + if bmt.rank() == 0: + ckpt1 = torch.load(ckpt_path) + ckpt2 = torch.load(ckpt_path_ref) + for (k1, v1), (k2, v2) in zip(ckpt1.items(), ckpt2.items()): + assert_eq(k1, k2) + print(v1, v2) + assert_all_eq(v1.cuda(), v2.cuda()) + +def test_main(): + ckpt_path_ref = "test_ckpt_ref.pt" + ckpt_path = "test_ckpt.pt" + shape = [3, 5] + # torch + m = [None] * 4 + ret = [None] * 4 + + manual_seed(33) + m[0] = Linear_NormalList(*shape) + if bmt.rank() == 0: + torch.save(m[0].state_dict(), ckpt_path_ref) + + manual_seed(33) + m[1] = Linear_Pipeline(*shape) + bmt.init_parameters(m[1]) + bmt.save(m[1], ckpt_path) + check(ckpt_path, ckpt_path_ref) + + # bmtrain + manual_seed(33) + m[2] = Linear_BlockList(*shape) + bmt.init_parameters(m[2]) + bmt.save(m[2], ckpt_path) + check(ckpt_path, ckpt_path_ref) + + manual_seed(33) + m[3] = Linear_CheckpointList(*shape) + bmt.init_parameters(m[3]) + bmt.save(m[3], ckpt_path) + check(ckpt_path, ckpt_path_ref) + + if bmt.rank() == 0: + os.remove(ckpt_path) + os.remove(ckpt_path_ref) + +if __name__ == "__main__": + bmt.init_distributed(pipe_size=2) + + test_main() \ No newline at end of file diff --git a/examples/BMTrain/tests/test_inspector_hidden.py b/examples/BMTrain/tests/test_inspector_hidden.py new file mode 100644 index 00000000..62e03656 --- /dev/null +++ b/examples/BMTrain/tests/test_inspector_hidden.py @@ -0,0 +1,241 @@ +from utils import * + +import bmtrain as bmt +import random +import torch +from bmtrain import config +from bmtrain.block_layer import Block, TransformerBlockList +from bmtrain.pipe_layer import PipelineTransformerBlockList +import torch.nn.functional as F +from bmtrain import inspect + +class Linear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.out = {} + if init_weight: + self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features)) + else: + self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_) + + if init_bias: + self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,)) + else: + self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_) + + def forward(self, input): + ret = F.linear(input, self.weight, self.bias) + return ret + +class L2(bmt.DistributedModule): + def __init__(self, dim : int): + super().__init__() + self.m1 = Linear(dim, dim) + self.m2 = Linear(dim, dim) + + def forward(self, x): + x = self.m1(x) + x = self.m2(x) + return x + +class L2_record(bmt.DistributedModule): + def __init__(self, dim : int): + super().__init__() + self.m1 = Linear(dim, dim) + self.m2 = Linear(dim, dim) + + def forward(self, x): + x = self.m1(x) + inspect.record_tensor(x, "hidden") + x = self.m2(x) + return x + +class Model_ZERO(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = TransformerBlockList([ + Block(m) + for m in ms + ]) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + if return_hidden_states: + x, o = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x), o + else: + x = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x) + +class Model_PIPE(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = PipelineTransformerBlockList([ + Block(m) + for m in ms + ]) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + if return_hidden_states: + x, o = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x), o + else: + x = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x) + +class Model_BLOCK(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = torch.nn.ModuleList([ + Block(m) + for m in ms + ]) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + o = [] + y = x + for m in self.ms: + o.append(y) + y = m(y) + if return_hidden_states: + return self.post(y), o + else: + return self.post(y) + +class Model_NORMAL(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = torch.nn.ModuleList(ms) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + o = [] + y = x + for m in self.ms: + o.append(y) + y = m(y) + if return_hidden_states: + return self.post(y), o + else: + return self.post(y) + +def manual_seed(seed=33): + torch.manual_seed(seed) + random.seed(seed) + try: + import numpy as np + np.random.seed(seed) + except ModuleNotFoundError: + pass + +def sub_run(name, cls, num_layer, dim, batch, seq_len): + manual_seed() + + pre = Linear(dim, dim) + post = Linear(dim, dim) + ms = [L2_record(dim) if i%2==0 else L2(dim) for i in range(num_layer)] + + inp = torch.randn((batch, seq_len, dim)).cuda() + last_weight = torch.randn((batch, seq_len, dim)).cuda() + middle_weight = [ + torch.randn((batch, seq_len, dim)).cuda() + for i in range(len(ms)//2) + ] + + bmt.init_parameters(pre) + bmt.init_parameters(post) + for m in ms: + bmt.init_parameters(m) + m = cls(pre, [m for m in ms], post) + ret = "" + with inspect.inspect_tensor() as inspector: + logits = m(inp) + + ret += inspect.format_summary( + inspector.get_summary() + ) + "\n" + + loss = 0 + for i in range(len(ms)//2): + loss = loss + (inspector.summary[i]['tensor'] * middle_weight[i]).sum() + + with inspect.inspect_tensor(): + loss.backward() + + ret += inspect.format_summary( + inspect.inspect_model(m, '*') + ) + "\n" + ret += inspect.format_summary( + inspector.get_summary() + ) + "\n" + + with inspect.inspect_tensor() as inspector: + logits = m(inp) + + ret += inspect.format_summary( + inspector.get_summary() + ) + "\n" + + loss = (logits * last_weight).sum() + + with inspect.inspect_tensor(): + loss.backward() + + ret += inspect.format_summary( + inspect.inspect_model(m, '*') + ) + "\n" + ret += inspect.format_summary( + inspector.get_summary() + ) + "\n" + + return ret + "\n" # replace for matching None grad with zero_grad + +def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256): + ret = "" + ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len) + bmt.synchronize() + return ret + +def test_main(): + ret = {} + ret["normal"] = run("normal", Model_NORMAL) + ret["block"] = run("block", Model_BLOCK) + ret["zero"] = run("zero", Model_ZERO) + ret["pipe"] = run("pipe", Model_PIPE) + for k, r in ret.items(): + bmt.print_rank(f"============={k}============") + bmt.print_rank(r) + for r in ret.values(): + for r2 in ret.values(): + lines, lines2 = r.split('\n'), r2.split('\n') + assert len(lines) == len(lines2) + for line, line2 in zip(lines, lines2): + words, words2 = line.split(), line2.split() + assert len(words) == len(words2) + for w, w2 in zip(words, words2): + try: + is_float = isinstance(eval(w), float) + except: + is_float = False + if is_float: + assert_lt(abs(float(w)-float(w2)), 0.00011) + else: + assert_eq(w, w2) + +if __name__ == "__main__": + bmt.init_distributed(pipe_size=2) + + test_main() diff --git a/examples/BMTrain/tests/test_load_ckpt.py b/examples/BMTrain/tests/test_load_ckpt.py new file mode 100644 index 00000000..0eb4f95f --- /dev/null +++ b/examples/BMTrain/tests/test_load_ckpt.py @@ -0,0 +1,78 @@ +from utils import * +import torch +import torch.nn.functional as F +import bmtrain as bmt +import os + +class Linear_Normal(torch.nn.Module): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda")) + torch.nn.init.xavier_normal_(self.weight) + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device="cuda")) + torch.nn.init.zeros_(self.bias) + else: + self.register_parameter('bias', None) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + +class Linear_BMT(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype), init_method=torch.nn.init.xavier_normal_) + if bias: + self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=dtype), init_method=torch.nn.init.zeros_) + else: + self.register_parameter('bias', None) + + def forward(self, input): + return F.linear(input, self.weight, self.bias) + + +def test_main(): + ckpt_path = "test_ckpt.pt" + # Transformer BlockList + m = Linear_Normal(256, 256).cuda() + m2 = bmt.TransformerBlockList([bmt.Block(Linear_BMT(256, 256))]) + if bmt.rank() == 0: + torch.save(m.state_dict(), ckpt_path) + dic2 = m.state_dict() + dic2["0.weight"] = dic2.pop("weight") + dic2["0.bias"] = dic2.pop("bias") + m2.load_state_dict(dic2) + for key in m.state_dict(): + bmt_key = f"0.{key}" + assert bmt_key in m2.state_dict(), "wrong key in bmtrain model" + assert (m2.state_dict()[bmt_key].cuda() == m.state_dict()[key]).all() , "wrong param in bmtrain model" + if bmt.rank() == 0: + os.remove(ckpt_path) + print("Transformer Blocklist load_state_dict and state_dict test passed") + + # Block + m3 = bmt.Block(Linear_BMT(256, 256)) + m3.load_state_dict(m.state_dict()) + for key in m.state_dict(): + assert key in m3.state_dict(), "wrong key in bmtrain model" + assert (m.state_dict()[key] == m3.state_dict()[key].cuda()).all(), "wrong param in bmtrain model" + print("Block load_state_dict and state_dict test passed") + + # normal Distributed module + m4 = Linear_BMT(256, 256) + m4.load_state_dict(m.state_dict()) + for key in m.state_dict(): + assert key in m4.state_dict(), "wrong key in bmtrain model" + assert (m.state_dict()[key] == m4.state_dict()[key].cuda()).all(), "wrong param in bmtrain model" + print("bmt.distributedmodule load_state_dict and state_dict test passed") + +if __name__ == "__main__": + bmt.init_distributed() + + test_main() diff --git a/examples/BMTrain/tests/test_loss_func.py b/examples/BMTrain/tests/test_loss_func.py new file mode 100644 index 00000000..a448b6d1 --- /dev/null +++ b/examples/BMTrain/tests/test_loss_func.py @@ -0,0 +1,79 @@ +from utils import * + +import torch +import bmtrain as bmt +import torch +import random +import copy + +def run(x, tgt, loss_func, bigmodel=None, scale=32768, use_float=False): + x = x.clone().detach() + bigmodel = copy.deepcopy(bigmodel) + if use_float: + x = x.float() + if bigmodel is not None: + bigmodel = bigmodel.float() + x = x.requires_grad_() + if bigmodel is None: + loss = loss_func(x, tgt) + else: + t = bigmodel(x) + loss = loss_func(t, tgt) + (loss * scale).backward() + return loss, x.grad + +def check(x, tgt, loss_func1, loss_func2, bigmodel=None): + loss_1, grad_1 = run(x, tgt, loss_func1, bigmodel=bigmodel) + loss_2, grad_2 = run(x, tgt, loss_func2, bigmodel=bigmodel, use_float=True) + assert_eq(grad_1.isnan().sum(), 0) + assert_eq(grad_2.isnan().sum(), 0) + print(f"{(loss_1 - loss_2).abs().item():.6f} {(grad_1 - grad_2).abs().max().item():.6f}") + assert_lt((loss_1 - loss_2).abs().item(), 1e-5) + assert_lt((grad_1 - grad_2).abs().max().item(), 1e-1) + +def test_simple(dtype): + loss_func1 = bmt.loss.FusedCrossEntropy() + loss_func2 = torch.nn.CrossEntropyLoss() + + N = 32 * 512 + for i in range(1, 10): + C = i * 10 + x = torch.randn(N, C).cuda().to(dtype) + tgt = torch.randint(0, C, (N,)).cuda().long() + check(x, tgt, loss_func1, loss_func2) + for i in range(1, 10): + C = i * 100 + x = torch.randn(N, C).cuda().to(dtype) + tgt = torch.randint(0, C, (N,)).cuda().long() + check(x, tgt, loss_func1, loss_func2) + for i in range(1, 31): + C = i * 1000 + x = torch.randn(N, C).cuda().to(dtype) + tgt = torch.randint(0, C, (N,)).cuda().long() + check(x, tgt, loss_func1, loss_func2) + +def test_other(dtype): + N = 32 * 512 + for i in range(1, 11): + C = i * 10 + weight = [i+1 for i in range(C)] + random.shuffle(weight) + weight = torch.tensor(weight, device="cuda") + loss_func1 = bmt.loss.FusedCrossEntropy(weight=weight.clone().to(dtype)) + loss_func2 = torch.nn.CrossEntropyLoss(weight=weight.clone().float()) + + x = torch.randn(N, C).cuda().to(dtype) + tgt = torch.randint(0, C, (N,)).cuda().long() + mask = torch.randint(0, 2, (N,)).cuda().bool() + tgt[mask] = -100 + check(x, tgt, loss_func1, loss_func2) + +if __name__ == "__main__": + test_other(torch.float16) + test_simple(torch.float16) + print("==============================================================================") + try: + test_other(torch.bfloat16) + test_simple(torch.bfloat16) + except NotImplementedError: + pass \ No newline at end of file diff --git a/examples/BMTrain/tests/test_middle_hidden.py b/examples/BMTrain/tests/test_middle_hidden.py new file mode 100644 index 00000000..2a93efe0 --- /dev/null +++ b/examples/BMTrain/tests/test_middle_hidden.py @@ -0,0 +1,212 @@ +from utils import * + +import bmtrain as bmt +import random +import torch +from bmtrain.block_layer import Block, TransformerBlockList +from bmtrain.pipe_layer import PipelineTransformerBlockList +import torch.nn.functional as F +from bmtrain import inspect + +class Linear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.out = {} + if init_weight: + self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features)) + else: + self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_) + + if init_bias: + self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,)) + else: + self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_) + + def forward(self, input): + ret = F.linear(input, self.weight, self.bias) + return ret + +class Model_ZERO(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = TransformerBlockList([ + Block(m) + for m in ms + ]) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + if return_hidden_states: + x, o = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x), o + else: + x = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x) + +class Model_PIPE(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = PipelineTransformerBlockList([ + Block(m) + for m in ms + ]) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + if return_hidden_states: + x, o = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x), o + else: + x = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x) + +class Model_BLOCK(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = torch.nn.ModuleList([ + Block(m) + for m in ms + ]) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + o = [] + y = x + for m in self.ms: + o.append(y) + y = m(y) + if return_hidden_states: + return self.post(y), o + else: + return self.post(y) + +class Model_NORMAL(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = torch.nn.ModuleList(ms) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + o = [] + y = x + for m in self.ms: + o.append(y) + y = m(y) + if return_hidden_states: + return self.post(y), o + else: + return self.post(y) + +def manual_seed(seed=33): + torch.manual_seed(seed) + random.seed(seed) + try: + import numpy as np + np.random.seed(seed) + except ModuleNotFoundError: + pass + +def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_middle=False, mix_test=False): + manual_seed() + + pre = Linear(dim, dim) + post = Linear(dim, dim) + ms = [Linear(dim, dim) for i in range(num_layer)] + + inp = torch.randn((batch, seq_len, dim)).cuda() + last_weight = torch.randn((batch, seq_len, dim)).cuda() + middle_weight = [ + torch.randn((batch, seq_len, dim)).cuda() + for i in range(len(ms)) + ] + + bmt.init_parameters(pre) + bmt.init_parameters(post) + for m in ms: + bmt.init_parameters(m) + m = cls(pre, [m for m in ms], post) + + ret = "" + if only_last: + logits = m(inp) + loss = (logits * last_weight).sum() + loss.backward() + ret += f"========================only last========================\n" + ret += inspect.format_summary( + inspect.inspect_model(m, '*') + ) + if only_middle: + logits, hidden_states = m(inp, return_hidden_states=True) + loss = sum([ + (hidden_state * middle_weight[i]).sum() + for i, hidden_state in enumerate(hidden_states) + ]) + loss.backward() + ret += f"========================only middle========================\n" + ret += inspect.format_summary( + inspect.inspect_model(m, '*') + ) + if mix_test: + logits, hidden_states = m(inp, return_hidden_states=True) + loss = sum([ + (hidden_state * middle_weight[i]).sum() + for i, hidden_state in enumerate(hidden_states) + ]) + (logits * last_weight).sum() + loss.backward() + ret += f"========================mix========================\n" + ret += inspect.format_summary( + inspect.inspect_model(m, '*') + ) + return ret + "\n" # replace for matching None grad with zero_grad + +def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256): + ret = "" + ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, only_last=True) + bmt.synchronize() + ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, only_middle=True) + bmt.synchronize() + ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, mix_test=True) + bmt.synchronize() + return ret + +def test_main(): + ret = {} + ret["normal"] = run("normal", Model_NORMAL) + ret["block"] = run("block", Model_BLOCK) + ret["zero"] = run("zero", Model_ZERO) + # ret["pipe"] = run("pipe", Model_PIPE) + for k, r in ret.items(): + bmt.print_rank(f"============={k}============") + bmt.print_rank(r) + for r in ret.values(): + for r2 in ret.values(): + lines, lines2 = r.split('\n'), r2.split('\n') + assert len(lines) == len(lines2) + for line, line2 in zip(lines, lines2): + words, words2 = line.split(), line2.split() + assert len(words) == len(words2) + for w, w2 in zip(words, words2): + try: + is_float = isinstance(eval(w), float) + except: + is_float = False + if is_float: + assert_lt(abs(float(w)-float(w2)), 2.) + else: + assert_eq(w, w2) + +if __name__ == "__main__": + bmt.init_distributed(pipe_size=1) + + test_main() diff --git a/examples/BMTrain/tests/test_model_wrapper.py b/examples/BMTrain/tests/test_model_wrapper.py new file mode 100644 index 00000000..6f913d3c --- /dev/null +++ b/examples/BMTrain/tests/test_model_wrapper.py @@ -0,0 +1,221 @@ +from utils import * + +from typing import Optional +import torch +import math +import torch.nn.functional as F +import bmtrain as bmt +import time + +class Attention(torch.nn.Module): + def __init__(self, + dim_model : int, dim_head : int, + num_heads : int, bias : bool = True, + dtype = None + ) -> None: + super().__init__() + + self.project_q = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + + self.project_out = torch.nn.Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + + self.softmax = torch.nn.Softmax(dim=-1) + self.num_heads = num_heads + self.dim_head = dim_head + self.dim_model = dim_model + + def forward(self, + hidden_q : torch.Tensor, # (batch_size, seq_q, dim_model) + hidden_kv : torch.Tensor, # (batch_size, seq_kv, dim_model) + mask : torch.BoolTensor, # (batch_size, seq_q, seq_kv) + position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv) + ) -> torch.Tensor: + batch_size, seq_q, dim_model = hidden_q.size() + seq_kv = hidden_kv.size(1) + + h_q : torch.Tensor = self.project_q(hidden_q) + h_k : torch.Tensor = self.project_k(hidden_kv) + h_v : torch.Tensor = self.project_v(hidden_kv) + + h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) + h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) + h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) + + h_q = h_q.permute(0, 2, 1, 3).contiguous() + h_k = h_k.permute(0, 2, 1, 3).contiguous() + h_v = h_v.permute(0, 2, 1, 3).contiguous() + + h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) + h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) + h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) + + score = torch.bmm( + h_q, h_k.transpose(1, 2) + ) + score = score / math.sqrt(self.dim_head) + + score = score.view(batch_size, self.num_heads, seq_q, seq_kv) + + if position_bias is not None: + score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) + + score = torch.where( + mask.view(batch_size, 1, seq_q, seq_kv), + score, + torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype) + ) + + score = torch.where( + mask.view(batch_size, 1, seq_q, seq_kv), + self.softmax(score), + torch.scalar_tensor(0, device=score.device, dtype=score.dtype) + ) + + score = score.view(batch_size * self.num_heads, seq_q, seq_kv) + + h_out = torch.bmm( + score, h_v + ) + h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) + h_out = h_out.permute(0, 2, 1, 3).contiguous() + h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) + + attn_out = self.project_out(h_out) + return attn_out + +class Feedforward(torch.nn.Module): + def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: + super().__init__() + + self.w_in = torch.nn.Linear(dim_model, dim_ff, bias = bias, dtype=dtype) + self.w_out = torch.nn.Linear(dim_ff, dim_model, bias = bias, dtype=dtype) + + self.relu = torch.nn.ReLU() + + def forward(self, input : torch.Tensor) -> torch.Tensor: + return self.w_out(self.relu(self.w_in(input))) + + +class TransformerEncoder(torch.nn.Module): + def __init__(self, + dim_model : int, dim_head : int, num_heads : int, dim_ff : int, + bias : bool = True, dtype = None + ) -> None: + super().__init__() + + self.ln_attn = torch.nn.LayerNorm(dim_model, dtype=dtype) + self.attn = Attention(dim_model, dim_head, num_heads, bias=bias, dtype=dtype) + + self.ln_ff = torch.nn.LayerNorm(dim_model, dtype=dtype) + self.ff = Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype) + + def forward(self, + hidden : torch.Tensor, # (batch, seq_len, dim_model) + mask : torch.BoolTensor, # (batch, seq_len, dim_model) + position_bias : Optional[torch.Tensor] = None, # (batch, num_head, seq_len, seq_len) + ): + x = self.ln_attn(hidden) + x = self.attn(x, x, mask, position_bias) + hidden = hidden + x + + x = self.ln_ff(hidden) + x = self.ff(x) + hidden = hidden + x + + return hidden + + +class GPT(torch.nn.Module): + def __init__(self, + num_layers : int, vocab_size : int, + dim_model : int, dim_head : int, num_heads : int, dim_ff : int, + max_distance : int, + bias : bool = True, dtype = None + ) -> None: + super().__init__() + + self.max_distance = max_distance + + self.word_emb = torch.nn.Embedding(vocab_size, dim_model, dtype=dtype) + self.pos_emb = torch.nn.Embedding(max_distance, dim_model, dtype=dtype) + self.dim_model = dim_model + + self.transformers = torch.nn.ModuleList([ + TransformerEncoder( + dim_model, dim_head, num_heads, dim_ff, bias, dtype + ) + for _ in range(num_layers) + ]) + + self.layernorm = torch.nn.LayerNorm(dim_model, dtype=dtype) + + def forward(self, + input : torch.LongTensor, # (batch, seq_len) + pos : torch.LongTensor, # (batch, seq_len) + mask : torch.BoolTensor, # (batch, seq_len) + ) -> torch.Tensor: + + mask_2d = mask[:, None, :] & mask[:, :, None] # (batch, seq_len, seq_len) + mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None]) + + input_emb = self.pos_emb(pos) + self.word_emb(input) + + out = input_emb + for layer in self.transformers: + out = layer(out, mask_2d) + out = self.layernorm(out) + + logits = F.linear(out, self.word_emb.weight) / math.sqrt(self.dim_model) + + return logits + +def test_main(): + model = GPT( + num_layers=8, + vocab_size=10240, + dim_model=2560, + dim_head=80, + num_heads=32, + dim_ff=8192, + max_distance=1024, + bias=True, + dtype=torch.half + ) + + bmt_model = bmt.BMTrainModelWrapper(model) + model = model.cuda() + + # use the break if i == bmt.rank() to generate different data on different rank + torch.manual_seed(1234) + batch_size = 2 + seq_len = 512 + + for i in range(bmt.world_size()): + sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) + enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() + enc_input = sent[:, :-1].long().cuda() + targets = sent[:, 1:].long().cuda() + mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] + targets = torch.where( + mask, + targets, + torch.full_like(targets, -100, dtype=torch.long) + ) + + if i == bmt.rank(): + break + + pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + logits = model(enc_input, pos, pos < enc_length[:, None]) + bmt_logits = bmt_model(enc_input, pos, pos < enc_length[:, None]) + print(logits) + bmt.synchronize() + print(bmt_logits) + assert_all_eq(logits, bmt_logits) + +if __name__ == '__main__': + bmt.init_distributed(seed=0) + + test_main() diff --git a/examples/BMTrain/tests/test_multi_return.py b/examples/BMTrain/tests/test_multi_return.py new file mode 100644 index 00000000..f4a5d79f --- /dev/null +++ b/examples/BMTrain/tests/test_multi_return.py @@ -0,0 +1,126 @@ +from utils import * + +import bmtrain as bmt +import torch +import random +from bmtrain import config +from bmtrain.block_layer import Block, TransformerBlockList +from bmtrain.pipe_layer import PipelineTransformerBlockList +import torch.nn.functional as F + +class MultiInputReturn(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, b, c, d, e): + return a*2, b+d, c*4+e*5 + +class Model_ZERO(torch.nn.Module): + def __init__(self, ms) -> None: + super().__init__() + self.ms = TransformerBlockList([ + Block(m) + for m in ms + ], num_hidden=3) + + def forward(self, x): + y = self.ms(*x) + return y + +class Model_PIPE(torch.nn.Module): + def __init__(self, ms) -> None: + super().__init__() + self.ms = PipelineTransformerBlockList([ + Block(m) + for m in ms + ], num_hidden=3) + + def forward(self, x): + y = self.ms(*x) + return y + +class Model_BLOCK(torch.nn.Module): + def __init__(self, ms) -> None: + super().__init__() + self.ms = torch.nn.ModuleList([ + Block(m) + for m in ms + ]) + + def forward(self, x): + y = x[:3] + other = x[3:] + for m in self.ms: + y = m(*y, *other) + return y + +class Model_NORMAL(torch.nn.Module): + def __init__(self, ms) -> None: + super().__init__() + self.ms = torch.nn.ModuleList(ms) + + def forward(self, x): + y = x[:3] + other = x[3:] + for m in self.ms: + y = m(*y, *other) + return y + +def manual_seed(seed=33): + torch.manual_seed(seed) + random.seed(seed) + try: + import numpy as np + np.random.seed(seed) + except ModuleNotFoundError: + pass + +def run(name, cls, num_layer=4, dim=4096): + manual_seed() + + ms = [MultiInputReturn() for i in range(num_layer)] + + inps = ( + torch.randn((dim,)).cuda(), + torch.randn((dim,)).cuda(), + torch.randn((dim,)).cuda(), + torch.randn((dim,)).cuda(), + torch.randn((dim,)).cuda(), + ) + last_weights = ( + torch.randn((dim,)).cuda(), + torch.randn((dim,)).cuda(), + torch.randn((dim,)).cuda(), + ) + + for inp in inps: + inp.requires_grad_(True) + m = cls(ms) + + ret = "" + logits = m(inps) + loss = (logits[0]*last_weights[0] + logits[1]*last_weights[1] + logits[2]*last_weights[2]).sum() + loss.backward() + return list(logits) + [ + inp.grad + for inp in inps + ] + +def test_main(): + ret = {} + ret["normal"] = run("normal", Model_NORMAL) + ret["block"] = run("block", Model_BLOCK) + ret["zero"] = run("zero", Model_ZERO) + # ret["pipe"] = run("pipe", Model_PIPE) # TODO pipeline not support multiple input-output yet + for k, r in ret.items(): + bmt.print_rank(f"============={k}============") + bmt.print_rank(r) + for r in ret.values(): + for r2 in ret.values(): + for i in range(len(r)): + assert_lt((r[i]-r2[i]).abs().max(), 1e-5) + +if __name__ == "__main__": + bmt.init_distributed(pipe_size=1) + + test_main() diff --git a/examples/BMTrain/tests/test_nccl_backward.py b/examples/BMTrain/tests/test_nccl_backward.py new file mode 100644 index 00000000..5950fb5c --- /dev/null +++ b/examples/BMTrain/tests/test_nccl_backward.py @@ -0,0 +1,43 @@ +from utils import * + +import bmtrain as bmt +import torch + +def test_main(dtype): + x = torch.full((1,), bmt.rank() + 1, dtype=dtype, device="cuda").requires_grad_(True) + y = bmt.distributed.all_reduce(x, "prod").view(-1) + loss = (y * y).sum() / 2 + loss.backward() + ref = y + for i in range(bmt.world_size()): + if i != bmt.rank(): ref *= i+1 + assert_eq(x.grad, ref) + +def test_reducescatter(): + world_size = bmt.world_size() + for shape in [(128,), (128,128)]: + tensors = torch.randn(world_size, *shape, dtype=torch.half, device="cuda").requires_grad_(True) + local_tensor = tensors[bmt.rank()] + x = local_tensor.detach().clone().requires_grad_(True) + y = bmt.distributed.reduce_scatter(x, "sum") + ref = tensors.sum(0) + partition = x.shape[0] // bmt.world_size() + ref_p = ref[bmt.rank() * partition:(bmt.rank() + 1) * partition] + if bmt.rank() == 0: + print(ref_p) + print(y) + assert torch.allclose(ref_p, y, atol=1e-2, rtol=1e-3) + g = torch.randn_like(y) + grad = torch.autograd.grad(y, x, g)[0] + pgrad = grad[bmt.rank() * y.shape[0]: (bmt.rank() + 1) * y.shape[0]] + ref_g = g + if bmt.rank() == 0: + print(ref_g) + print(pgrad) + assert torch.allclose(ref_g, pgrad, atol=1e-3, rtol=1e-3) + +if __name__ == "__main__": + bmt.init_distributed() + test_reducescatter() + test_main(torch.half) + test_main(torch.bfloat16) diff --git a/examples/BMTrain/tests/test_no_grad.py b/examples/BMTrain/tests/test_no_grad.py new file mode 100644 index 00000000..7851c670 --- /dev/null +++ b/examples/BMTrain/tests/test_no_grad.py @@ -0,0 +1,90 @@ +import torch +import bmtrain as bmt + +class Layer(torch.nn.Module): + def __init__(self): + super(Layer, self).__init__() + self.linear = bmt.nn.Linear(32, 32) + self.count = 0 + def forward(self, x): + self.count += 1 + return self.linear(x) + +def test_no_grad(): + x = torch.randn(32, 32, device='cuda') + + layer1 = bmt.Block(Layer()) + layer2 = bmt.Block(Layer()) + layer1.linear.weight.requires_grad_(False) + layer1.linear.bias.requires_grad_(False) + y = layer1(x) + assert y.requires_grad == False + y = layer2(y) + y.sum().backward() + assert layer1.count == 1 + assert layer2.count == 2 + +def test_multi_layer_no_grad(): + x = torch.randn(32, 32, device='cuda') + + layers = [] + for i in range(10): + layer = bmt.Block(Layer()) + layer.linear.weight.requires_grad_(i > 4) + layer.linear.bias.requires_grad_(i > 4) + layers.append(layer) + + y = x + for layer in layers: + y = layer(y) + y.sum().backward() + for i in range(len(layers)): + assert layers[i].count == (1 if i <=4 else 2) + +def test_all_input_no_grad(): + linear1 = bmt.nn.Linear(32, 32) + linear2 = bmt.nn.Linear(32, 32) + + x = torch.randn(32,32, device='cuda') + + linear1 = bmt.Block(linear1) + linear2 = bmt.Block(linear2) + y = linear1(x) + y = linear2(y) + y.sum().backward() + assert linear1.weight.grad is not None + assert linear1.bias.grad is not None + assert x.grad is None + +def test_same_layer(): + layer = Layer() + block_list = bmt.TransformerBlockList([layer, layer]) + assert id(block_list[0]) != id(block_list[1]) + +def test_no_grad_error(): + layer = bmt.Block(Layer()) + + try: + block_list = bmt.TransformerBlockList([layer, layer]) + raise ValueError("test failed") + except AssertionError as e: + return + +def test_no_grad_error2(): + layer = bmt.Block(Layer()) + + try: + block_list = bmt.PipelineTransformerBlockList([layer]) + raise ValueError("test failed") + except AssertionError as e: + return + +if __name__ == '__main__': + bmt.init_distributed() + + test_no_grad() + test_multi_layer_no_grad() + test_all_input_no_grad() + test_same_layer() + test_no_grad_error() + test_no_grad_error2() diff --git a/examples/BMTrain/tests/test_optim.py b/examples/BMTrain/tests/test_optim.py new file mode 100644 index 00000000..0aca8c31 --- /dev/null +++ b/examples/BMTrain/tests/test_optim.py @@ -0,0 +1,94 @@ +from utils import * +import torch +import bmtrain as bmt +from bmtrain import optim + +class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + self.fc1 = torch.nn.Linear(128, 128, bias=False) + self.fc2 = torch.nn.Linear(128, 128) + self.fc3 = torch.nn.Linear(128, 128) + self.fc4 = torch.nn.Linear(128, 128) + self.fc5 = torch.nn.Linear(128, 128) + self.param = torch.nn.Parameter(torch.empty(1237)) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + x = self.fc4(x) + x = self.fc5(x) + return x + +def main(dtype): + model1 = TestModule() + model2 = TestModule() + model3 = TestModule() + model4 = TestModule() + model5 = TestModule() + + state_dict = model1.state_dict() + for kw in state_dict.keys(): + state_dict[kw] = torch.randn_like(state_dict[kw]) + + model1.load_state_dict(state_dict) + model2.load_state_dict(state_dict) + model3.load_state_dict(state_dict) + model4.load_state_dict(state_dict) + model5.load_state_dict(state_dict) + + model1 = model1.cuda().to(dtype) + model2 = model2.cuda().to(dtype) + model3 = model3.cuda() + model4 = model4.cuda() + model5 = model5.cuda() + + opt1 = bmt.optim.AdamOptimizer(model1.parameters(), lr=1) + opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), lr=1) + opt3 = torch.optim.Adam(model3.parameters(), lr=1) + opt4 = bmt.optim.AdamOptimizer(model4.parameters(), lr=1) + opt5 = bmt.optim.AdamOffloadOptimizer(model5.parameters(), lr=1) + + optim_manager = bmt.optim.OptimManager(loss_scale=4) + optim_manager.add_optimizer(opt1) + optim_manager.add_optimizer(opt2) + optim_manager.add_optimizer(opt3) + optim_manager.add_optimizer(opt4) + optim_manager.add_optimizer(opt5) + + for _ in range(100): + optim_manager.zero_grad() + + for p1, p2, p3, p4, p5 in zip(model1.parameters(), model2.parameters(), model3.parameters(), model4.parameters(), model5.parameters()): + grad = torch.randn_like(p1) + p1.grad = grad.to(dtype) + p2.grad = grad.to(dtype) + p3.grad = grad.float() + p4.grad = grad.float() + p5.grad = grad.float() + + optim_manager.step() + torch.cuda.synchronize() + + for p1, p2, p3, p4, p5 in zip(model1.parameters(), model2.parameters(), model3.parameters(), model4.parameters(), model5.parameters()): + diff1 = torch.abs(p1 - p2).max().item() + diff2 = torch.abs(p1 - p3).max().item() + diff3 = torch.abs(p2 - p3).max().item() + diff4 = torch.abs(p3 - p4).max().item() + diff5 = torch.abs(p3 - p5).max().item() + print(f"{diff1:.6f}, {diff2:.6f}, {diff3:.6f}, {diff4:.6f}, {diff5:.6f}") + assert_lt(diff1, 1) + assert_lt(diff2, 1) + assert_lt(diff3, 1) + assert_eq(diff4, 0) + assert_lt(diff5, 0.00001) + +if __name__ == "__main__": + bmt.init_distributed() + main(torch.float16) + print("==============================================================================") + try: + main(torch.bfloat16) + except NotImplementedError: + pass diff --git a/examples/BMTrain/tests/test_optim_state.py b/examples/BMTrain/tests/test_optim_state.py new file mode 100644 index 00000000..57d5d0e3 --- /dev/null +++ b/examples/BMTrain/tests/test_optim_state.py @@ -0,0 +1,135 @@ +import torch +import bmtrain as bmt +import os +from copy import deepcopy +from bmtrain import optim, lr_scheduler + +class TestSubModule(bmt.DistributedModule): + def __init__(self): + super(TestSubModule, self).__init__() + self.fc1 = bmt.BMTrainModelWrapper(torch.nn.Linear(768, 3072)) + self.fc2 = bmt.BMTrainModelWrapper(torch.nn.Linear(3072, 1024)) + self.fc3 = bmt.BMTrainModelWrapper(torch.nn.Linear(1024, 768)) + self.param = bmt.DistributedParameter(torch.zeros(1237)) + self.fc4 = bmt.BMTrainModelWrapper(torch.nn.Linear(768, 300)) + self.fc5 = bmt.BMTrainModelWrapper(torch.nn.Linear(300, 768)) + self.dropout = torch.nn.Dropout(0.0) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + x = self.fc3(x) + x = self.dropout(x) + x = self.fc4(x) + x = self.fc5(x) + return x + +class TestModule(torch.nn.Module): + def __init__(self): + super(TestModule, self).__init__() + self.layer1 = TestSubModule() + self.layer2 = bmt.CheckpointBlock(TestSubModule()) + + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + return x + +def train(model1, model2, model3, optim_manager): + x = torch.randn((4, 768)).cuda() + for _ in range(10): + optim_manager.zero_grad() + + y1, y2, y3 = model1(x), model2(x), model3(x) + w = torch.randn_like(y1) + loss = (y1*w).sum() + (y2*w).sum() + (y3*w).sum() + optim_manager.backward(loss) + + optim_manager.step() + +def manual_seed(seed=33): + torch.manual_seed(seed) + import random + random.seed(seed) + try: + import numpy as np + np.random.seed(seed) + except ModuleNotFoundError: + pass + +def main(): + model1 = TestModule() + model2 = TestModule() + model3 = TestModule() + + bmt.save(model1, "test_optim_state_model1.pt") + + bmt.load(model1, f"test_optim_state_model1.pt") + bmt.load(model2, f"test_optim_state_model1.pt") + bmt.load(model3, f"test_optim_state_model1.pt") + + opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3) + opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3) + opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-3) + lrs1 = lr_scheduler.Noam(opt1, start_lr=20, warmup_iter=0, end_iter=300, num_iter=1) + lrs2 = lr_scheduler.Noam(opt2, start_lr=20, warmup_iter=0, end_iter=300, num_iter=1) + lrs3 = lr_scheduler.Noam(opt3, start_lr=20, warmup_iter=0, end_iter=300, num_iter=1) + optim_manager = optim.OptimManager(loss_scale=256) + optim_manager.add_optimizer(opt1, lrs1) + optim_manager.add_optimizer(opt2, lrs2) + optim_manager.add_optimizer(opt3, lrs3) + + train(model1, model2, model3, optim_manager) + + bmt.save(model1, f"test_optim_state_model1.pt") + bmt.save(model2, f"test_optim_state_model2.pt") + bmt.save(model3, f"test_optim_state_model3.pt") + + torch.save(optim_manager.state_dict(), f"test_optim_manager_{bmt.rank()}.opt") + + manual_seed() + train(model1, model2, model3, optim_manager) + state_2 = deepcopy([list(model1.parameters()), list(model2.parameters()), list(model3.parameters())]) + + model1 = TestModule() + model2 = TestModule() + model3 = TestModule() + bmt.load(model1, f"test_optim_state_model1.pt") + bmt.load(model2, f"test_optim_state_model2.pt") + bmt.load(model3, f"test_optim_state_model3.pt") + + opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-8, betas=(0.3, 0.333), eps=1e-3) + opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-7, betas=(0.4, 0.456), eps=1e-1) + opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-6, betas=(0.9, 0.777), eps=1e-2) + lrs1 = lr_scheduler.Noam(opt1, start_lr=200, warmup_iter=30, end_iter=500, num_iter=3) + lrs2 = lr_scheduler.Noam(opt2, start_lr=20, warmup_iter=40, end_iter=600, num_iter=1) + lrs3 = lr_scheduler.Noam(opt3, start_lr=10, warmup_iter=50, end_iter=700, num_iter=2) + optim_manager = optim.OptimManager(loss_scale=10485760) + optim_manager.add_optimizer(opt1, lrs1) + optim_manager.add_optimizer(opt2, lrs2) + optim_manager.add_optimizer(opt3, lrs3) + optim_manager.load_state_dict(torch.load(f"test_optim_manager_{bmt.rank()}.opt")) + + manual_seed() + train(model1, model2, model3, optim_manager) + state_1_plus_1 = deepcopy([list(model1.parameters()), list(model2.parameters()), list(model3.parameters())]) + + for i, kind in [ + (0, "BMTAdam"), + (1, "BMTAdamOffload"), + (2, "TorchAdam") + ]: + ref = state_2[i] + chk = state_1_plus_1[i] + for rp, p in zip(ref, chk): + assert (rp==p).all(), f"{kind} state load error" + + if bmt.rank() == 0: + os.remove(f"test_optim_state_model1.pt") + os.remove(f"test_optim_state_model2.pt") + os.remove(f"test_optim_state_model3.pt") + os.remove(f"test_optim_manager_{bmt.rank()}.opt") + +if __name__ == "__main__": + bmt.init_distributed() + main() diff --git a/examples/BMTrain/tests/test_other_hidden.py b/examples/BMTrain/tests/test_other_hidden.py new file mode 100644 index 00000000..27736aa7 --- /dev/null +++ b/examples/BMTrain/tests/test_other_hidden.py @@ -0,0 +1,189 @@ +from utils import * + +import bmtrain as bmt +import random +import torch +from bmtrain import config +from bmtrain.block_layer import Block, TransformerBlockList +from bmtrain.pipe_layer import PipelineTransformerBlockList +import torch.nn.functional as F +from bmtrain import inspect + +class Linear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.out = {} + if init_weight: + self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features)) + else: + self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_) + + if init_bias: + self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,)) + else: + self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_) + + def forward(self, input): + ret = F.linear(input, self.weight, self.bias) + return ret + +class Model_ZERO(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = TransformerBlockList([ + Block(m) + for m in ms + ]) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + if return_hidden_states: + x, o = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x), o + else: + x = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x) + +class Model_PIPE(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = PipelineTransformerBlockList([ + Block(m) + for m in ms + ]) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + if return_hidden_states: + x, o = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x), o + else: + x = self.ms(x, return_hidden_states=return_hidden_states) + return self.post(x) + +class Model_BLOCK(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = torch.nn.ModuleList([ + Block(m) + for m in ms + ]) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + o = [] + y = x + for m in self.ms: + o.append(y) + y = m(y) + if return_hidden_states: + return self.post(y), o + else: + return self.post(y) + +class Model_NORMAL(torch.nn.Module): + def __init__(self, pre, ms, post) -> None: + super().__init__() + self.pre = pre + self.ms = torch.nn.ModuleList(ms) + self.post = post + + def forward(self, x, return_hidden_states=False): + x = self.pre(x) + o = [] + y = x + for m in self.ms: + o.append(y) + y = m(y) + if return_hidden_states: + return self.post(y), o + else: + return self.post(y) + +def manual_seed(seed=33): + torch.manual_seed(seed) + random.seed(seed) + try: + import numpy as np + np.random.seed(seed) + except ModuleNotFoundError: + pass + +def sub_run(name, cls, num_layer, dim, batch, seq_len, only_pre=False, only_post=False, mix_test=False): + manual_seed() + + pre = Linear(dim, dim) + post = Linear(dim, dim) + ms = [Linear(dim, dim) for i in range(num_layer)] + + inp = torch.randn((batch, seq_len, dim)).cuda() + last_weight = torch.randn(pre.weight.shape).cuda()*10 + middle_weight = [ + torch.randn((batch, seq_len, dim)).cuda() + for i in range(len(ms)) + ] + + bmt.init_parameters(pre) + bmt.init_parameters(post) + for m in ms: + bmt.init_parameters(m) + m = cls(pre, [m for m in ms], post) + + ret = "" + if only_pre: + loss = (pre.weight * last_weight).sum() + loss.backward() + ret += f"========================only last========================\n" + ret += inspect.format_summary( + inspect.inspect_model(m, '*') + ) + if only_post: + loss = (post.weight * last_weight).sum() + loss.backward() + ret += f"========================only middle========================\n" + ret += inspect.format_summary( + inspect.inspect_model(m, '*') + ) + if mix_test: + loss = (pre.weight * last_weight).sum() + (post.weight * last_weight).sum() + loss.backward() + ret += f"========================mix========================\n" + ret += inspect.format_summary( + inspect.inspect_model(m, '*') + ) + return ret + "\n" # replace for matching None grad with zero_grad + +def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256): + ret = "" + ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, only_pre=True) + bmt.synchronize() + ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, only_post=True) + bmt.synchronize() + ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, mix_test=True) + bmt.synchronize() + return ret + +def test_main(): + ret = [] + ret.append( run("normal", Model_NORMAL) ) + ret.append( run("block", Model_BLOCK) ) + ret.append( run("zero", Model_ZERO) ) + # ret.append( run("pipe", Model_PIPE) ) + for r in ret: + bmt.print_rank(r) + for r in ret: + for r2 in ret: + assert_eq(r, r2) + +if __name__ == "__main__": + bmt.init_distributed(pipe_size=1) + test_main() diff --git a/examples/BMTrain/tests/test_parallel_projection.py b/examples/BMTrain/tests/test_parallel_projection.py new file mode 100644 index 00000000..dc1e874d --- /dev/null +++ b/examples/BMTrain/tests/test_parallel_projection.py @@ -0,0 +1,55 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np +import os + +def run_normal(x, t, ckp_path, dtype): + proj = bmt.nn.Projection(100, 64, dtype=dtype) + bmt.init_parameters(proj) + bmt.save(proj, ckp_path) + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=False) + y = proj(x) + y = y.detach().requires_grad_() + loss = loss_func(y, t) + loss.backward() + return y, loss, y.grad + +def run_vp(x, t, ckp_path, dtype): + proj = bmt.nn.VPProjection(100, 64, dtype=dtype) + bmt.load(proj, ckp_path) + loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True) + y = proj(x) + y = y.detach().requires_grad_() + loss = loss_func(y, t) + loss.backward() + return y, loss, y.grad + +def run(dtype): + ckp_path = 'embedding.pt' + torch.cuda.manual_seed(100) + tp_size = config["tp_size"] + tp_rank = config['tp_rank'] + x = torch.randn(110, 64, device='cuda', dtype=dtype) + t = torch.cat([torch.arange(100).view(10, 10), torch.ones((10, 1))*-100], dim=-1).view(110).int().cuda() + y1, loss1, grad1 = run_normal(x, t, ckp_path, dtype) + y2, loss2, grad2 = run_vp(x, t, ckp_path, dtype) + y1 = y1.chunk(tp_size, dim=-1)[tp_rank] + grad1 = grad1.chunk(tp_size, dim=-1)[tp_rank] + for r in range(tp_size): + if bmt.rank() == r: + print((y1-y2).abs().max()) + print((loss1-loss2).abs().max()) + print((grad1-grad2).abs().max()) + assert (y1-y2).abs().max() < 1e-4 + assert (loss1-loss2).abs().max() < 1e-4 + assert (grad1-grad2).abs().max() < 1e-4 + bmt.synchronize() + if bmt.rank() == 0: + os.remove(f"embedding.pt") + +if __name__ == "__main__": + bmt.init_distributed(tp_size=4) + run(torch.half) + run(torch.bfloat16) + diff --git a/examples/BMTrain/tests/test_requires_grad.py b/examples/BMTrain/tests/test_requires_grad.py new file mode 100644 index 00000000..9a443bd3 --- /dev/null +++ b/examples/BMTrain/tests/test_requires_grad.py @@ -0,0 +1,107 @@ +from utils import * + +import bmtrain as bmt +import torch +from bmtrain import config +from bmtrain.block_layer import Block, TransformerBlockList +from bmtrain.pipe_layer import PipelineTransformerBlockList +from typing import List +import torch.nn.functional as F +from bmtrain import inspect + +class Linear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.out = {} + if init_weight: + self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features)) + else: + self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_) + + if init_bias: + self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,)) + else: + self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_) + + def forward(self, input, other_bias): + ret = F.linear(input, self.weight, self.bias) + ret += other_bias + return ret + +def run(m, a, b): + inp = torch.rand((1, 10, 256)).cuda()*100 + bias = torch.rand(256).cuda()*100 + logits = m(inp, bias) + loss = logits.sum() + loss.backward() + bmt.synchronize() + sm = inspect.format_summary( + inspect.inspect_model(m, '*') + ) + assert_eq(bias.requires_grad, False) + return a.weight.grad is None, a.bias.grad is None, sm + +def test_main(): + a = Linear(256, 256) + b = Linear(256, 256) + m = TransformerBlockList([Block(a), Block(b)]) + bmt.init_parameters(m) + + a.bias.requires_grad_(False) + awg, abg, sm1 = run(m, a, b) + print(awg, abg, sm1) + assert_eq((awg, abg), (False, True)) + assert_eq(sm1.split('\n')[2].split()[-2:], ["0.0000", "0.0000"]) + + a.weight.requires_grad_(False) + a.bias.requires_grad_(True) + awg, abg, sm2 = run(m, a, b) + print(awg, abg, sm2) + assert_eq((awg, abg), (False, False)) + assert_eq(sm1.split('\n')[1], sm2.split('\n')[1]) + assert_neq(sm1.split('\n')[2], sm2.split('\n')[2]) + + a.weight.requires_grad_(True) + a.bias.requires_grad_(False) + awg, abg, sm3 = run(m, a, b) + print(awg, abg, sm3) + assert_eq((awg, abg), (False, False)) + assert_neq(sm2.split('\n')[1], sm3.split('\n')[1]) + assert_eq(sm2.split('\n')[2], sm3.split('\n')[2]) + +def test_main_pipe(): + a = Linear(256, 256) + b = Linear(256, 256) + m = PipelineTransformerBlockList([Block(a), Block(b)]) + bmt.init_parameters(m) + + a.bias.requires_grad_(False) + awg, abg, sm1 = run(m, a, b) + print(awg, abg, sm1) + assert_eq((awg, abg), (False, True)) + assert_eq(sm1.split('\n')[2].split()[-2:], ["0.0000", "0.0000"]) + + a.weight.requires_grad_(False) + a.bias.requires_grad_(True) + awg, abg, sm2 = run(m, a, b) + print(awg, abg, sm2) + assert_eq((awg, abg), (False, False)) + assert_eq(sm1.split('\n')[1], sm2.split('\n')[1]) + assert_neq(sm1.split('\n')[2], sm2.split('\n')[2]) + + a.weight.requires_grad_(True) + a.bias.requires_grad_(False) + awg, abg, sm3 = run(m, a, b) + print(awg, abg, sm3) + assert_eq((awg, abg), (False, False)) + assert_neq(sm2.split('\n')[1], sm3.split('\n')[1]) + assert_eq(sm2.split('\n')[2], sm3.split('\n')[2]) + +if __name__ == "__main__": + bmt.init_distributed(pipe_size=1) + + test_main() + # test_main_pipe() diff --git a/examples/BMTrain/tests/test_requires_grad_multi_gpu.py b/examples/BMTrain/tests/test_requires_grad_multi_gpu.py new file mode 100644 index 00000000..2eedf7b6 --- /dev/null +++ b/examples/BMTrain/tests/test_requires_grad_multi_gpu.py @@ -0,0 +1,96 @@ +from utils import * + +import bmtrain as bmt +import torch +from bmtrain.block_layer import Block, TransformerBlockList +from bmtrain.pipe_layer import PipelineTransformerBlockList +from typing import List +import torch.nn.functional as F +from bmtrain import inspect + +class Linear(bmt.DistributedModule): + def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None: + super().__init__() + + self.in_features = in_features + self.out_features = out_features + self.out = {} + if init_weight: + self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features)) + else: + self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_) + + if init_bias: + self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,)) + else: + self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_) + + def forward(self, input): + ret = F.linear(input, self.weight, self.bias) + return ret + +def run(m, a, b): + inp = torch.rand((1, 10, 256)).cuda()*100 + logits = m(inp) + loss = logits.sum() + loss.backward() + + sm = inspect.format_summary( + inspect.inspect_model(m, '*') + ) + return sm + +def test_main(): + a = Linear(256, 256) + b = Linear(256, 256) + m = TransformerBlockList([Block(a), Block(b)]) + bmt.init_parameters(m) + + a.bias.requires_grad_(False) + sm1 = run(m, a, b) + print(sm1) + assert_eq(sm1.split('\n')[2].split()[-2:], ["0.0000", "0.0000"]) + + a.weight.requires_grad_(False) + a.bias.requires_grad_(True) + sm2 = run(m, a, b) + print(sm2) + assert_eq(sm1.split('\n')[1], sm2.split('\n')[1]) + assert_neq(sm1.split('\n')[2], sm2.split('\n')[2]) + + a.weight.requires_grad_(True) + a.bias.requires_grad_(False) + sm3 = run(m, a, b) + assert_neq(sm2.split('\n')[1], sm3.split('\n')[1]) + assert_eq(sm2.split('\n')[2], sm3.split('\n')[2]) + +def test_main_pipe(): + a = Linear(256, 256) + b = Linear(256, 256) + m = PipelineTransformerBlockList([Block(a), Block(b)]) + bmt.init_parameters(m) + + a.bias.requires_grad_(False) + sm1 = run(m, a, b) + print(sm1) + assert_eq(sm1.split('\n')[2].split()[-2:], ["0.0000", "0.0000"]) + + a.weight.requires_grad_(False) + a.bias.requires_grad_(True) + sm2 = run(m, a, b) + print(sm2) + assert_eq(sm1.split('\n')[1], sm2.split('\n')[1]) + assert_neq(sm1.split('\n')[2], sm2.split('\n')[2]) + + a.weight.requires_grad_(True) + a.bias.requires_grad_(False) + sm3 = run(m, a, b) + print(sm3) + assert_neq(sm2.split('\n')[1], sm3.split('\n')[1]) + assert_eq(sm2.split('\n')[2], sm3.split('\n')[2]) + +if __name__ == "__main__": + bmt.init_distributed(pipe_size=1) + + test_main() + # test_main_pipe() diff --git a/examples/BMTrain/tests/test_row_parallel_linear.py b/examples/BMTrain/tests/test_row_parallel_linear.py new file mode 100644 index 00000000..23dce8b2 --- /dev/null +++ b/examples/BMTrain/tests/test_row_parallel_linear.py @@ -0,0 +1,54 @@ +import torch +import bmtrain as bmt +from bmtrain.global_var import config +import numpy as np + +def run_bmt(x, ckp_path, split_input=True, use_checkpoint_block=True): + linear = bmt.nn.RowParallelLinear(8,8, split_input=split_input, all_reduce_output=True) + if use_checkpoint_block: + linear = bmt.Block(linear) + bmt.init_parameters(linear) + y = linear(x) + y.sum().backward() + bmt.save(linear, ckp_path) + bmt.synchronize() + return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad + +def run_torch(x, ckp_path): + linear = torch.nn.Linear(8, 8) + linear_dict = torch.load(ckp_path) + linear.load_state_dict(linear_dict) + linear = linear.cuda() + linear.weight.requires_grad_() + y = linear(x) + y.sum().backward() + return y, linear.weight.grad, linear.bias.grad + +def run(split_input, use_checkpoint_block, ckp_path): + tp_size = bmt.config['tp_size'] + torch.cuda.manual_seed(100) + tp_rank = config['topology'].tp_id + x = torch.randn(8,8, device='cuda').requires_grad_() + rank_x = x.chunk(tp_size, dim=0 if split_input else 1)[tp_rank] + y1, weight_grad1, bias_grad1 = run_bmt(rank_x, ckp_path, split_input, use_checkpoint_block) + y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path) + np.testing.assert_allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy(), atol=1e-5) + + weight_grad_list = weight_grad2.chunk(tp_size, dim=1) + assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy()) + + assert np.allclose(bias_grad1.cpu().numpy(), bias_grad2.cpu().numpy()) + +def test_split_input(): + run(True, False, 'row_parallel_linear.ckp') + run(True, True, 'row_parallel_linear.ckp') + +def test_no_split_input(): + run(False, False, 'row_parallel_linear_no_split.ckp') + run(False, True, 'row_parallel_linear_no_split.ckp') + +if __name__ == "__main__": + bmt.init_distributed(tp_size=2) + test_no_split_input() + test_split_input() + diff --git a/examples/BMTrain/tests/test_send_recv.py b/examples/BMTrain/tests/test_send_recv.py new file mode 100644 index 00000000..f933b0c2 --- /dev/null +++ b/examples/BMTrain/tests/test_send_recv.py @@ -0,0 +1,22 @@ +from utils import * + +import torch +import bmtrain as bmt +from bmtrain.global_var import config + +def test_send_recv(): + if config["topology"].stage_id == 0: + a = torch.ones((2,1)) * (config["topology"].pp_zero_id+1) + a = a.cuda() + print(f"send {a}") + bmt.distributed.send_activations(a, 1, config["pipe_comm"]) + else: + ref = torch.ones((2,1)) * (config["topology"].pp_zero_id+1) + a = bmt.distributed.recv_activations(0, config["pipe_comm"]) + print(f"recv {a}") + assert_all_eq(a, ref.cuda()) + +if __name__ == '__main__': + bmt.init_distributed(pipe_size=2) + + test_send_recv() diff --git a/examples/BMTrain/tests/test_store.py b/examples/BMTrain/tests/test_store.py new file mode 100644 index 00000000..cb427d5b --- /dev/null +++ b/examples/BMTrain/tests/test_store.py @@ -0,0 +1,13 @@ +import bmtrain as bmt +from bmtrain.store import allgather_object + +def test_allgather_object(): + + res = allgather_object(bmt.rank(), bmt.config["comm"]) + ref = [i for i in range(bmt.world_size())] + assert res == ref + +if __name__ == "__main__": + bmt.init_distributed() + test_allgather_object() + diff --git a/examples/BMTrain/tests/test_synchronize.py b/examples/BMTrain/tests/test_synchronize.py new file mode 100644 index 00000000..bea48b04 --- /dev/null +++ b/examples/BMTrain/tests/test_synchronize.py @@ -0,0 +1,26 @@ +import torch +import bmtrain as bmt + +from bmtrain.global_var import config +from bmtrain import nccl, distributed +from bmtrain.synchronize import gather_result + +def test_main(): + + ref_result = torch.ones(5 * bmt.world_size(), 5) + tensor = ref_result.chunk(bmt.world_size(), dim=0)[bmt.rank()] + real_result = bmt.gather_result(tensor) + assert torch.allclose(ref_result, real_result, atol=1e-6), "Assertion failed for real gather result error" + + for i in range(4): + size = i + 1 + tensor_slice = tensor[:size, :size] + result_slice = bmt.gather_result(tensor_slice) + test_slice = torch.chunk(result_slice, bmt.world_size(), dim=0)[i] + assert torch.allclose(tensor_slice, test_slice), f"Assertion failed for tensor_slice_{i}" + +print("All test passed") + +if __name__ == '__main__': + bmt.init_distributed(pipe_size=1) + test_main() diff --git a/examples/BMTrain/tests/test_training.py b/examples/BMTrain/tests/test_training.py new file mode 100644 index 00000000..46389802 --- /dev/null +++ b/examples/BMTrain/tests/test_training.py @@ -0,0 +1,516 @@ +from bmtrain.optim import optim_manager +from utils import * + +from typing import Optional +import torch +import math +import torch.nn.functional as F +import bmtrain as bmt +import os +from bmtrain import inspect + +def clip_grad_norm(loss_scale, param_groups, max_norm, norm_type=2, eps=1e-6, is_torch=False): + """Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized. + max_norm (float or int): max norm of the gradients. + norm_type (float or int): type of the used p-norm. Can be 'inf' for infinity norm. + eps (float): epsilon used to avoid zero division. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + scale = loss_scale + grads = [] + parameters = [p for group in param_groups for p in group['params']] + for p in parameters: + if p.grad is not None: + grads.append(p.grad.data) + else: + grads.append(torch.zeros_like(p.data)) + + if norm_type == 'inf': + total_norm_cuda = max(g.data.abs().max() for g in grads).detach() + if not is_torch: + bmt.nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "max", bmt.config["comm"]) + total_norm = total_norm_cuda + else: + norm_type = float(norm_type) + total_norm_cuda = torch.cuda.FloatTensor([0]) + for index, g in enumerate(grads): + param_norm = g.data.float().norm(norm_type) + total_norm_cuda += param_norm ** norm_type + if not is_torch: + bmt.nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "sum", bmt.config["comm"]) + total_norm = total_norm_cuda[0] ** (1. / norm_type) + clip_coef = float(max_norm * scale) / (total_norm + eps) + if clip_coef < 1: + for p in parameters: + if p.grad is not None: + p.grad.data.mul_(clip_coef) + return total_norm / scale + +class Attention(torch.nn.Module): + def __init__(self, + dim_model : int, dim_head : int, + num_heads : int, bias : bool = True, + dtype = None + ) -> None: + super().__init__() + + self.project_q = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_k = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + self.project_v = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype) + + self.project_out = torch.nn.Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype) + + self.softmax = torch.nn.Softmax(dim=-1) + self.num_heads = num_heads + self.dim_head = dim_head + self.dim_model = dim_model + + def forward(self, + hidden_q : torch.Tensor, # (batch_size, seq_q, dim_model) + hidden_kv : torch.Tensor, # (batch_size, seq_kv, dim_model) + mask : torch.BoolTensor, # (batch_size, seq_q, seq_kv) + position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv) + ) -> torch.Tensor: + batch_size, seq_q, dim_model = hidden_q.size() + seq_kv = hidden_kv.size(1) + + h_q : torch.Tensor = self.project_q(hidden_q) + h_k : torch.Tensor = self.project_k(hidden_kv) + h_v : torch.Tensor = self.project_v(hidden_kv) + + h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head) + h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head) + h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head) + + h_q = h_q.permute(0, 2, 1, 3).contiguous() + h_k = h_k.permute(0, 2, 1, 3).contiguous() + h_v = h_v.permute(0, 2, 1, 3).contiguous() + + h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head) + h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head) + h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head) + + score = torch.bmm( + h_q, h_k.transpose(1, 2) + ) + score = score / math.sqrt(self.dim_head) + + score = score.view(batch_size, self.num_heads, seq_q, seq_kv) + + if position_bias is not None: + score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv) + + score = torch.where( + mask.view(batch_size, 1, seq_q, seq_kv), + score, + torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype) + ) + + score = torch.where( + mask.view(batch_size, 1, seq_q, seq_kv), + self.softmax(score), + torch.scalar_tensor(0, device=score.device, dtype=score.dtype) + ) + + score = score.view(batch_size * self.num_heads, seq_q, seq_kv) + + h_out = torch.bmm( + score, h_v + ) + h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head) + h_out = h_out.permute(0, 2, 1, 3).contiguous() + h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head) + + attn_out = self.project_out(h_out) + return attn_out + +class Feedforward(torch.nn.Module): + def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None: + super().__init__() + + self.w_in = torch.nn.Linear(dim_model, dim_ff, bias = bias, dtype=dtype) + self.w_out = torch.nn.Linear(dim_ff, dim_model, bias = bias, dtype=dtype) + + self.relu = torch.nn.ReLU() + + def forward(self, input : torch.Tensor) -> torch.Tensor: + return self.w_out(self.relu(self.w_in(input))) + + +class TransformerEncoder(torch.nn.Module): + def __init__(self, + dim_model : int, dim_head : int, num_heads : int, dim_ff : int, + bias : bool = True, dtype = None + ) -> None: + super().__init__() + + self.ln_attn = torch.nn.LayerNorm(dim_model, dtype=dtype) + self.attn = Attention(dim_model, dim_head, num_heads, bias=bias, dtype=dtype) + + self.ln_ff = torch.nn.LayerNorm(dim_model, dtype=dtype) + self.ff = Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype) + + def forward(self, + hidden : torch.Tensor, # (batch, seq_len, dim_model) + mask : torch.BoolTensor, # (batch, seq_len, dim_model) + position_bias : Optional[torch.Tensor] = None, # (batch, num_head, seq_len, seq_len) + ): + x = self.ln_attn(hidden) + x = self.attn(x, x, mask, position_bias) + hidden = hidden + x + + x = self.ln_ff(hidden) + x = self.ff(x) + hidden = hidden + x + + return hidden + + +class GPT(torch.nn.Module): + def __init__(self, + num_layers : int, vocab_size : int, + dim_model : int, dim_head : int, num_heads : int, dim_ff : int, + max_distance : int, + bias : bool = True, dtype = None + ) -> None: + super().__init__() + + self.dtype = dtype + self.max_distance = max_distance + + self.word_emb = torch.nn.Embedding(vocab_size, dim_model, dtype=dtype) + self.pos_emb = torch.nn.Embedding(max_distance, dim_model, dtype=dtype) + self.dim_model = dim_model + + self.transformers = torch.nn.ModuleList([ + TransformerEncoder( + dim_model, dim_head, num_heads, dim_ff, bias, dtype + ) + for _ in range(num_layers) + ]) + self.run_unroll = False + + self.layernorm = torch.nn.LayerNorm(dim_model, dtype=dtype) + + def forward(self, + input : torch.LongTensor, # (batch, seq_len) + pos : torch.LongTensor, # (batch, seq_len) + mask : torch.BoolTensor, # (batch, seq_len) + ) -> torch.Tensor: + + mask_2d = mask[:, None, :] & mask[:, :, None] # (batch, seq_len, seq_len) + mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None]) + + input_emb = self.pos_emb(pos) + self.word_emb(input) + + out = input_emb + if isinstance(self.transformers, torch.nn.ModuleList) or self.run_unroll: + for layer in self.transformers: + out = layer(out, mask_2d, None) + else: + out = self.transformers(out, mask_2d, None) + out = self.layernorm(out) + + logits = F.linear(out, self.word_emb.weight) / math.sqrt(self.dim_model) + + return logits + +def sub_train_torch(model, loss_func_cls, optimizer_cls): + loss_func = loss_func_cls(ignore_index=-100) + optimizer = optimizer_cls(model.parameters(), weight_decay=1e-2) + lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) + + optim_manager = bmt.optim.OptimManager(loss_scale=2**20 if model.dtype == torch.half else None) + optim_manager.add_optimizer(optimizer, lr_scheduler) + + # use the break if i == bmt.rank() to generate different data on different rank + torch.manual_seed(2333) + batch_size = 2 + seq_len = 512 + + sents = [] + enc_lengths = [] + enc_inputs = [] + targetss = [] + masks = [] + inps = [] + for i in range(bmt.world_size()): + sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) + enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() + enc_input = sent[:, :-1].long().cuda() + targets = sent[:, 1:].long().cuda() + mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] + targets = torch.where( + mask, + targets, + torch.full_like(targets, -100, dtype=torch.long) + ) + + sents.append(sent) + enc_lengths.append(enc_length) + enc_inputs.append(enc_input) + targetss.append(targets) + masks.append(mask) + inps.append((sent,enc_length,enc_input,targets,mask)) + + logs = [] + for iter in range(100): + + optim_manager.zero_grad() + global_loss = 0 + for inp in inps: + sent, enc_length, enc_input, targets, mask = inp + pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + logits = model(enc_input, pos, pos < enc_length[:, None]) + + batch, seq_len, vocab_out_size = logits.size() + + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) / len(inps) + + global_loss += loss.item() + + loss = optim_manager.loss_scale * loss + loss.backward() + current_stream = torch.cuda.current_stream() + current_stream.wait_stream(bmt.config['load_stream']) + grad_norm = clip_grad_norm(optim_manager.loss_scale, optimizer.param_groups, max_norm=10.0, is_torch = True) + + optim_manager.step() + + bmt.print_rank("| Iter: {:6d} | loss: {:.4f} {:.4f} | lr: {:.4e} scale: {:10.4f} | grad_norm: {:.4f} |".format( + iter, + global_loss, + loss, + lr_scheduler.current_lr, + optim_manager.loss_scale, + grad_norm, + )) + logs.append(global_loss) + + summary = inspect.inspect_model(model, "*") + return logs, summary + +def sub_train(model, loss_func_cls, optimizer_cls): + loss_func = loss_func_cls(ignore_index=-100) + optimizer = optimizer_cls(model.parameters(), weight_decay=1e-2) + lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0) + + optim_manager = bmt.optim.OptimManager(loss_scale=2**20 if model.dtype == torch.half else None) + optim_manager.add_optimizer(optimizer, lr_scheduler) + + # use the break if i == bmt.rank() to generate different data on different rank + torch.manual_seed(2333) + batch_size = 2 + seq_len = 512 + sents = [] + enc_lengths = [] + enc_inputs = [] + targetss = [] + masks = [] + for i in range(bmt.world_size()): + sent = torch.randint(0, 10240, (batch_size, seq_len + 1)) + enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda() + enc_input = sent[:, :-1].long().cuda() + targets = sent[:, 1:].long().cuda() + mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None] + targets = torch.where( + mask, + targets, + torch.full_like(targets, -100, dtype=torch.long) + ) + + sents.append(sent) + enc_lengths.append(enc_length) + enc_inputs.append(enc_input) + targetss.append(targets) + masks.append(mask) + # sent = torch.cat(sents, dim=0) + # enc_length = torch.cat(enc_lengths, dim=0) + # enc_input = torch.cat(enc_inputs, dim=0) + # targets = torch.cat(targetss, dim=0) + # mask = torch.cat(masks, dim=0) + sent = sents[bmt.rank()] + enc_length = enc_lengths[bmt.rank()] + enc_input = enc_inputs[bmt.rank()] + targets = targetss[bmt.rank()] + mask = masks[bmt.rank()] + + + logs = [] + pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1) + for iter in range(100): + optim_manager.zero_grad() + logits = model(enc_input, pos, pos < enc_length[:, None]) + + batch, seq_len, vocab_out_size = logits.size() + + loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) + + global_loss = bmt.sum_loss(loss).item() + + optim_manager.backward(loss) + grad_norm = clip_grad_norm(optim_manager.loss_scale, optimizer.param_groups, max_norm=10.0) + + optim_manager.step() + bmt.print_rank("| Iter: {:6d} | loss: {:.4f} {:.4f} | lr: {:.4e} scale: {:10.4f} | grad_norm: {:.4f} |".format( + iter, + global_loss, + loss, + lr_scheduler.current_lr, + optim_manager.loss_scale, + grad_norm, + )) + logs.append(global_loss) + + summary = inspect.inspect_model(model, "*") + return logs, summary + +def train(model, loss_func, optimizer): + key = f"{model[0]}*{loss_func[0]}*{optimizer[0]}" + model = model[1]() + if key.startswith("torch"): + ret = sub_train_torch(model, loss_func[1], optimizer[1]) + else: + ret = sub_train(model, loss_func[1], optimizer[1]) + del model + bmt.print_rank(f"finished {key}") + return key, ret + +def test_main(test_fp16=True, test_fp32=True): + ckpt_path = "test_ckpt.pt" + + kwargs = { + "num_layers": 8, + "vocab_size": 10240, + "dim_model": 2560, + "dim_head": 80, + "num_heads": 32, + "dim_ff": 8192, + "max_distance": 1024, + "bias": True, + "dtype": None, + } + + def make_ref_ckpt(): + model = GPT(**kwargs) + if bmt.rank() == 0: + torch.save(model.state_dict(), ckpt_path) + bmt.synchronize() + del model + + ret = {} + def torch_model(): + model = GPT(**kwargs) + model.load_state_dict(torch.load(ckpt_path)) + model = model.cuda() + return model + + def wrap_model(): + model = GPT(**kwargs) + wrap_model = bmt.BMTrainModelWrapper(model) + bmt.load(wrap_model, ckpt_path) + return model + + def list_model(): + model = GPT(**kwargs) + list_model = bmt.BMTrainModelWrapper(model) + list_model.transformers = bmt.TransformerBlockList([m for m in list_model.transformers]) + bmt.load(list_model, ckpt_path) + return model + + def pipe_model(): + model = GPT(**kwargs) + pipe_model = bmt.BMTrainModelWrapper(model) + for m in pipe_model.transformers: + assert isinstance(m, bmt.Block) + pipe_model.transformers = bmt.PipelineTransformerBlockList([m for m in pipe_model.transformers]) + bmt.load(pipe_model, ckpt_path) + return model + + def unroll_list_model(): + model = GPT(**kwargs) + list_model = bmt.BMTrainModelWrapper(model) + list_model.transformers = bmt.TransformerBlockList([m for m in list_model.transformers]) + bmt.load(list_model, ckpt_path) + model.run_unroll = True + return model + + models = { + "torch": torch_model, + "wrapper": wrap_model, + "blocklist": list_model, + # "pipelist": pipe_model, + "unroll_blocklist": unroll_list_model, + } + loss_funcs = { + "bmt_entropy": bmt.loss.FusedCrossEntropy, + "torch_entropy": torch.nn.CrossEntropyLoss, + } + optimizers = { + "bmt_adam": bmt.optim.AdamOptimizer, + "bmt_adam_offload": bmt.optim.AdamOffloadOptimizer, + "torch_adam": torch.optim.Adam, + } + + ret = {} + def add_to_check_list(m, l, o): + key, value = train((m, models[m]), (l, loss_funcs[l]), (o, optimizers[o])) + ret[key] = value + + if test_fp16: + kwargs["dtype"] = torch.half + make_ref_ckpt() + add_to_check_list("torch", "bmt_entropy", "bmt_adam") + add_to_check_list("wrapper", "bmt_entropy", "bmt_adam") + add_to_check_list("blocklist", "bmt_entropy", "bmt_adam") + # add_to_check_list("pipelist", "bmt_entropy", "bmt_adam") + add_to_check_list("blocklist", "torch_entropy", "bmt_adam") + add_to_check_list("blocklist", "bmt_entropy", "bmt_adam_offload") + add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") + if bmt.rank() == 0: + os.remove(ckpt_path) + check(ret) + + if test_fp32: + kwargs["dtype"] = torch.float + make_ref_ckpt() + add_to_check_list("torch", "torch_entropy", "bmt_adam") + add_to_check_list("wrapper", "torch_entropy", "bmt_adam") + add_to_check_list("blocklist", "torch_entropy", "bmt_adam") + # add_to_check_list("pipelist", "torch_entropy", "bmt_adam") + add_to_check_list("blocklist", "torch_entropy", "bmt_adam_offload") + add_to_check_list("blocklist", "torch_entropy", "torch_adam") + add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam") + if bmt.rank() == 0: + os.remove(ckpt_path) + check(ret) + +def check(ret): + if bmt.rank() == 0: + for k1, v1 in ret.items(): + for k2, v2 in ret.items(): + if k1 != k2: + print(f"checking {k1} vs. {k2}") + check_param(v1[1], v2[1]) + bmt.synchronize() + ret.clear() + +def check_param(info1, info2): + for l1, l2 in zip(info1, info2): + for key in ['std', 'mean', 'max', 'min']: + v1 = l1[key] + v2 = l2[key] + assert_lt(abs(v1-v2), 1e-2) + +if __name__ == '__main__': + bmt.init_distributed(pipe_size=1) + + + test_main(test_fp16=True, test_fp32=True) diff --git a/examples/BMTrain/tests/utils.py b/examples/BMTrain/tests/utils.py new file mode 100644 index 00000000..62831fea --- /dev/null +++ b/examples/BMTrain/tests/utils.py @@ -0,0 +1,14 @@ +def assert_eq(a, b): + assert a == b, f"{a} != {b}" + +def assert_neq(a, b): + assert a != b, f"{a} == {b}" + +def assert_lt(a, b): + assert a < b, f"{a} >= {b}" + +def assert_gt(a, b): + assert a > b, f"{a} <= {b}" + +def assert_all_eq(a, b): + assert_eq((a==b).all(), True) \ No newline at end of file diff --git a/examples/CPM.cu/.arsync b/examples/CPM.cu/.arsync new file mode 100644 index 00000000..f9635cc6 --- /dev/null +++ b/examples/CPM.cu/.arsync @@ -0,0 +1,10 @@ +auto_sync_up 0 +local_options -var +local_path /Users/tachicoma/projects/CPM.cu +remote_host sa.km +remote_options -var +remote_or_local remote +remote_path /home/sunao/cpm.cu +remote_port 0 +rsync_flags ["--max-size=100m"] +backend rsync diff --git a/examples/CPM.cu/.gitignore b/examples/CPM.cu/.gitignore new file mode 100644 index 00000000..acdb2c98 --- /dev/null +++ b/examples/CPM.cu/.gitignore @@ -0,0 +1,222 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + + +build/ +# Cuda +*.i +*.ii +*.gpu +*.ptx +*.cubin +*.fatbin + +# Prerequisites +*.d + +# Compiled Object files +*.slo +*.lo +*.o +*.obj + +# Precompiled Headers +*.gch +*.pch + +# Compiled Dynamic libraries +*.so +*.dylib +*.dll + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai +*.la +*.a +*.lib + +# Executables +*.exe +*.out +*.app + +.DS_Store +.vscode + +# Checkpoints +checkpoints/* +!checkpoints/.gitkeep + +# Cursor Rules +.mdc + +# FR-Spec Index +fr_index + +# Prompt Files +tests/*.txt +prompt.txt \ No newline at end of file diff --git a/examples/CPM.cu/.gitmodules b/examples/CPM.cu/.gitmodules new file mode 100644 index 00000000..a00c8278 --- /dev/null +++ b/examples/CPM.cu/.gitmodules @@ -0,0 +1,3 @@ +[submodule "src/cutlass"] + path = src/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/examples/CPM.cu/LICENSE b/examples/CPM.cu/LICENSE new file mode 100644 index 00000000..99908f2a --- /dev/null +++ b/examples/CPM.cu/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2024 OpenBMB + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/examples/CPM.cu/README.md b/examples/CPM.cu/README.md new file mode 100644 index 00000000..c4fb84f5 --- /dev/null +++ b/examples/CPM.cu/README.md @@ -0,0 +1,167 @@ +# CPM.cu + +<strong>[中文版本](./README_ZH.md) | English</strong> + +CPM.cu is a lightweight, high-performance CUDA implementation for LLMs, optimized for end-device inference and featuring cutting-edge techniques in **sparse architecture**, **speculative sampling** and **quantization**. + +<div id="news"></div> + +## 🔥 Project Updates + +- [2025.06.06] Optimized for [MiniCPM4](https://github.com/openbmb/minicpm). + - Support InfLLM-v2 attention kernel + - Support sliding-window for the MTP layer, optimized for long context + - Support quantization for the MTP layer +- [2025.05.29] Support Quantization at [SpecMQuant](https://github.com/AI9Stars/SpecMQuant). + - Support Marlin GPTQ kernel for the LLM + - Support Speculative Sampling for quantized LLM +- [2025.03.01] Release the first version at [FR-Spec](https://github.com/thunlp/FR-Spec). + - SOTA Speculative Sampling Implementation + - Support FR-Spec: Frequency-Ranked Speculative Sampling + - Support Tree-based verification of Speculative Sampling in Flash-Attention + - Support Static memory management and memory reuse + - Support Fused kernels + - Support Chunked prefill + - Support CUDA Graph + +<div id="demo"></div> + +## Demo + +https://github.com/user-attachments/assets/ab36fd7a-485b-4707-b72f-b80b5c43d024 + +<div id="getstart"></div> + +## Getting Started + +- [Installation](#install) +- [Model Weights](#modelweights) +- [Quick Start](#example) + +<div id="install"></div> + +## Installation + +### Install from source + +```bash +git clone https://github.com/OpenBMB/CPM.cu.git --recursive +cd CPM.cu +python3 setup.py install +``` + +<div id="modelweights"></div> + +## Prepare Model + +Please follow [MiniCPM4's README](https://github.com/openbmb/minicpm) to download the model weights. + +<div id="example"></div> + +## Quick Start + +We provide a simple example to show how to use CPM.cu to generate text. + +```bash +python3 tests/test_generate.py --prompt-file <your prompt file> +``` + +If you don't ​​specify​​ the model path, the scripts will load the model from ​​OpenBMB's Hugging Face repository​​. +If you want to use local paths, we recommend keeping all model filenames unchanged and placing them in the same directory. This way, you can run the model by specifying the directory with the -p parameter. Otherwise, we suggest modifying the paths in the code accordingly. + +If you don't ​​specify​​ the prompt file, a default ​​Haystack task​​ with ​​15K context length​​ will be provided. +You can use --help to learn more ​​about the script's features​​. + +We also provide a script, `tests/long_prompt_gen.py`, to generate ​​long code summarization. +This script ​​automatically collects code from this repository​​ and prompts ​​the model to "Summarize the code."​ + +```bash +python3 tests/long_prompt_gen.py # generate prompt.txt (for more details, use --help) +python3 tests/test_generate.py --prompt-file prompt.txt +``` + +The output should be of the following format: + +```bash +Generated text (streaming output): +-------------------------------------------------- +Prefilling: 100.0% (106850/106850 tokens) @ 6565.3 tokens/s - Complete! + +<Generated Output HERE> +================================================== +Stream Generation Summary: +================================================== +Prefill length: 106850 +Prefill time: 16.36 s +Prefill tokens/s: 6530.77 +Mean accept length: 2.50 +Decode length: 118 +Decode time: 0.76 s +Decode tokens/s: 154.59 +``` + +Where: + +- the `Prefill` and `Decode` speed are output by (length, time and token/s). +- the `Mean accept length` is the average length of the accepted tokens when using Speculative Sampling. + +## Code Structure + +```bash +CPM.cu/ +├── src/ +│ ├── flash_attn/ # attention kernels: sparse, tree-verification, etc. +│ ├── model/ +│ │ ├── minicpm4/ # minicpm4 model +│ │ ├── w4a16_gptq_marlin/ # marlin kernel +│ │ └── ... # common layers +│ ├── entry.cu # pybind: bind cuda and python +│ └── ... +├── cpmcu/ # python interface +└── ... +``` + +## More +We provide a word frequency generation script for FR-Spec, located at "scripts/fr_spec/gen_fr_index.py". You can run it as follows: +``` +python scripts/fr_spec/gen_fr_index.py --model_path <your_model_path> +``` +You can modify the code to use your own dataset. If your task is in a specific vertical domain, constructing word frequencies tailored to that domain can significantly improve processing speed. + + +## Acknowledgments + +Our `src/flash_attn` folder modified based on [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/v2.6.3/csrc/flash_attn). + +We have drawn inspiration from the following repositories: + +- [EAGLE](https://github.com/SafeAILab/EAGLE) +- [Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention) +- [vLLM](https://github.com/vllm-project/vllm) +- [SGLang](https://github.com/sgl-project/sglang) + +## Citation + +Please cite our paper if you find our work valuable. + +``` +@article{zhao2025fr, + title={FR-Spec: Accelerating Large-Vocabulary Language Models via Frequency-Ranked Speculative Sampling}, + author={Zhao, Weilin and Pan, Tengyu and Han, Xu and Zhang, Yudi and Sun, Ao and Huang, Yuxiang and Zhang, Kaihuo and Zhao, Weilun and Li, Yuxuan and Wang, Jianyong and others}, + journal={arXiv preprint arXiv:2502.14856}, + year={2025} +} + +@article{zhang2025specmqaunt, + title={Speculative Decoding Meets Quantization: Compatibility Evaluation and Hierarchical Framework Design}, + author={Zhang, Yudi and Zhao, Weilin and Han, Xu and Zhao, Tiejun and Xu, Wang and Cao, Hailong and Zhu, Conghui}, + journal={arXiv preprint arXiv:2505.22179}, + year={2025} +} + +@article{minicpm4, + title={MiniCPM4: Ultra-Efficient LLMs on End Devices}, + author={MiniCPM}, + year={2025} +} +``` \ No newline at end of file diff --git a/examples/CPM.cu/README_ZH.md b/examples/CPM.cu/README_ZH.md new file mode 100644 index 00000000..7f364ad3 --- /dev/null +++ b/examples/CPM.cu/README_ZH.md @@ -0,0 +1,165 @@ +# CPM.cu + +<strong>中文 | [English Version](./README.md)</strong> + +CPM.cu 是一个针对端侧大模型推理设计的轻量、高效的 CUDA 推理框架,核心支持 **稀疏架构**、**投机采样** 和 **低位宽量化** 等前沿技术创新。 + +<div id="news"></div> + +## 🔥 项目进展 + +- [2025.06.06] 为 [MiniCPM4](https://github.com/openbmb/minicpm) 优化。 + - 支持 InfLLM-v2 注意力内核 + - 支持 MTP 层的滑动窗口,优化长上下文处理 + - 支持 MTP 层的量化 +- [2025.05.29] 支持 [SpecMQuant](https://github.com/AI9Stars/SpecMQuant) 的量化。 + - 支持 LLM 的 Marlin GPTQ 内核 + - 支持量化 LLM 的投机采样 +- [2025.03.01] 在 [FR-Spec](https://github.com/thunlp/FR-Spec) 发布首个版本。 + - 速度最快的投机采样实现 + - 支持 FR-Spec, 基于词频优化的投机采样 + - 支持 Flash-Attention 中的树状投机采样验证 + - 支持静态内存管理和内存复用 + - 支持计算融合内核 + - 支持分块预填充 + - 支持 CUDA Graph + +<div id="demo"></div> + +## 效果演示 + +https://github.com/user-attachments/assets/ab36fd7a-485b-4707-b72f-b80b5c43d024 + +<div id="getstart"></div> + +## 快速开始 + +- [框架安装](#install) +- [模型权重](#modelweights) +- [运行示例](#example) + +<div id="install"></div> + +## 框架安装 + +### 从源码安装 + +```bash +git clone https://github.com/OpenBMB/cpm.cu.git --recursive +cd cpm.cu +python3 setup.py install +``` + +<div id="modelweights"></div> + +## 准备模型 + +请按照 [MiniCPM4 的 README](https://github.com/openbmb/minicpm) 的说明下载模型权重。 + +<div id="example"></div> + +## 运行示例 + +我们提供了一个简单的示例来展示如何使用 CPM.cu。 + +```bash +python3 tests/test_generate.py --prompt-file <输入文件路径> +``` + +如果您不指定模型路径,脚本将从 OpenBMB 的 Hugging Face 仓库加载模型。 +如果你想使用本地路径,我们推荐不修改所有模型文件名并放在同一目录下,这样可以通过-p指定该目录运行模型。否则建议修改代码中的路径。 + +如果您不指定输入文件,将提供一个默认的 Haystack 任务,上下文长度为 15K。 +您可以使用 --help 了解更多关于脚本的功能。 + +我们还有一个脚本,`tests/long_prompt_gen.py`,用于生成长代码总结。 +这个脚本会自动从本仓库中收集代码,并提示模型“总结代码”。 + +```bash +python3 tests/long_prompt_gen.py # 生成 prompt.txt (更多细节请见 --help) +python3 tests/test_generate.py --prompt-file prompt.txt +``` + +输出应为如下格式: + +```bash +Generated text (streaming output): +-------------------------------------------------- +Prefilling: 100.0% (106850/106850 tokens) @ 6565.3 tokens/s - Complete! + +<Generated Output HERE> +================================================== +Stream Generation Summary: +================================================== +Prefill length: 106850 +Prefill time: 16.36 s +Prefill tokens/s: 6530.77 +Mean accept length: 2.50 +Decode length: 118 +Decode time: 0.76 s +Decode tokens/s: 154.59 +``` + +其中: + +- `Prefill` (输入) 和 `Decode` (输出) 速度通过(长度、时间和 token/s)输出。 +- `Mean accept length` (平均接受长度) 是使用投机采样时接受 token 的平均长度。 + +## 代码结构 + +```bash +cpm.cu/ +├── src/ +│ ├── flash_attn/ # attention: 稀疏, 投机树状验证等 +│ ├── model/ +│ │ ├── minicpm4/ # minicpm4 模型 +│ │ ├── w4a16_gptq_marlin/ # Marlin GPTQ 计算内核 +│ │ └── ... # 通用层 +│ ├── entry.cu # pybind: 绑定 CUDA 和 Python +│ └── ... +├── cpmcu/ # Python 接口 +└── ... +``` +## 更多 +我们提供了FR-Spec的词频生成脚本,位于"scripts/fr_spec/gen_fr_index.py",运行方式如下: +``` +python scripts/fr_spec/gen_fr_index.py --model_path <your modelpath> +``` +你可以修改代码使用自己的数据集。如果你的任务是特定垂直领域,根据领域构造词频对速度提升有显著收益。 + +## 致谢 + +我们的 `src/flash_attn` 文件夹基于 [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/v2.6.3/csrc/flash_attn) 并进行了修改。 + +我们从以下仓库中获取了实现灵感: + +- [EAGLE](https://github.com/SafeAILab/EAGLE) +- [Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention) +- [vLLM](https://github.com/vllm-project/vllm) +- [SGLang](https://github.com/sgl-project/sglang) + +## 引用 + +如果您觉得我们的工作有价值,请引用我们的论文。 + +``` +@article{zhao2025fr, + title={FR-Spec: Accelerating Large-Vocabulary Language Models via Frequency-Ranked Speculative Sampling}, + author={Zhao, Weilin and Pan, Tengyu and Han, Xu and Zhang, Yudi and Sun, Ao and Huang, Yuxiang and Zhang, Kaihuo and Zhao, Weilun and Li, Yuxuan and Wang, Jianyong and others}, + journal={arXiv preprint arXiv:2502.14856}, + year={2025} +} + +@article{zhang2025specmqaunt, + title={Speculative Decoding Meets Quantization: Compatibility Evaluation and Hierarchical Framework Design}, + author={Zhang, Yudi and Zhao, Weilin and Han, Xu and Zhao, Tiejun and Xu, Wang and Cao, Hailong and Zhu, Conghui}, + journal={arXiv preprint arXiv:2505.22179}, + year={2025} +} + +@article{minicpm4, + title={MiniCPM4: Ultra-Efficient LLMs on End Devices}, + author={MiniCPM}, + year={2025} +} +``` \ No newline at end of file diff --git a/examples/CPM.cu/cpmcu/__init__.py b/examples/CPM.cu/cpmcu/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/CPM.cu/cpmcu/llm.py b/examples/CPM.cu/cpmcu/llm.py new file mode 100644 index 00000000..fb594c6b --- /dev/null +++ b/examples/CPM.cu/cpmcu/llm.py @@ -0,0 +1,422 @@ +from . import C + +import os, json, glob +import torch +from transformers import AutoTokenizer, AutoConfig +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from safetensors.torch import load_file +import time, math +import torch.nn.functional as F + +dtype_map = { + torch.float16: 0, + torch.bfloat16: 1, +} + +def dtype_to_int(dtype): + ret = dtype_map.get(dtype, -1) + if ret == -1: + raise ValueError(f"Unsupported dtype: {dtype}") + return ret + +class LLM(torch.nn.Module): + def __init__(self, + path: str, # hf model path + memory_limit: float = 0.8, + chunk_length: int = 1024, + dtype: torch.dtype = None, + cuda_graph: bool = False, + apply_sparse: bool = False, + sink_window_size: int = 1, + block_window_size: int = 32, + sparse_topk_k: int = 32, + sparse_switch: int = 8192, + apply_compress_lse: bool = False, + use_enter: bool = False, + use_decode_enter: bool = False, + temperature: float = 0.0, + random_seed: int = None, + ): + super().__init__() + + self.path = path + self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + self.config = AutoConfig.from_pretrained(path, trust_remote_code=True) + self.dtype = dtype if dtype is not None else self.config.torch_dtype + self.dtype_int = dtype_to_int(self.dtype) + self.cuda_graph = cuda_graph + self.use_enter = use_enter + self.use_decode_enter = use_decode_enter + self.temperature = temperature + + self.chunk_length = chunk_length + # Flag for showing prefill progress (used in stream mode) + self._show_prefill_progress = False + + # Initialize random generator if random_seed is provided + if random_seed is not None: + self.generator = torch.Generator(device="cuda") + self.generator.manual_seed(random_seed) + else: + self.generator = None + + if not hasattr(self.config, "head_dim"): + self.config.head_dim = self.config.hidden_size // self.config.num_attention_heads + scale_embed = self.config.scale_emb if hasattr(self.config, "scale_emb") else 1.0 + scale_lmhead = (self.config.dim_model_base / self.config.hidden_size) if hasattr(self.config, "dim_model_base") else 1.0 + scale_residual = self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) if hasattr(self.config, "scale_depth") else 1.0 + + if apply_sparse: + C.init_minicpm4_model( + memory_limit, + self.config.vocab_size, + self.config.num_hidden_layers, + self.config.hidden_size, + self.config.intermediate_size, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + self.config.rms_norm_eps, + self.dtype_int, + self.chunk_length, + scale_embed, + scale_lmhead, + scale_residual, + sink_window_size, + block_window_size, + sparse_topk_k, + sparse_switch, + apply_compress_lse, + ) + else: + C.init_base_model( + memory_limit, + self.config.vocab_size, + self.config.num_hidden_layers, + self.config.hidden_size, + self.config.intermediate_size, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + self.config.rms_norm_eps, + self.dtype_int, + self.chunk_length, + scale_embed, + scale_lmhead, + scale_residual, + ) + + self.logits = torch.empty((64, self.config.vocab_size), dtype=self.dtype, device="cuda") + + def init_storage(self): + self.max_total_length = C.init_storage() + print("max supported length under current memory limit: ", self.max_total_length) + + def _load(self, name, param, dtype=None, cls=None): + if dtype is None: + if 'rotary_emb' in name: + dtype = torch.float32 + else: + dtype = self.dtype + + if 'gate_up_proj' in name: + self._load(name.replace("gate_up_proj", "gate_proj"), param[:param.shape[0]//2], dtype) + self._load(name.replace("gate_up_proj", "up_proj"), param[param.shape[0]//2:]) + elif 'qkv_proj' in name: + self._load(name.replace("qkv_proj", "q_proj"), param[:self.config.num_attention_heads * self.config.head_dim]) + self._load(name.replace("qkv_proj", "k_proj"), param[self.config.num_attention_heads * self.config.head_dim:(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim]) + self._load(name.replace("qkv_proj", "v_proj"), param[(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim:]) + else: + param = param.contiguous().to(dtype) + C.load_model(name, param.data_ptr()) + + if "embed_tokens" in name and hasattr(self.config, "tie_word_embeddings") and self.config.tie_word_embeddings: + self._load("lm_head.weight", param) + + def _load_from_ckpt(self, path, cls=None): + supported_suffix_1 = ["bin.index.json", "safetensors.index.json"] + supported_suffix_2 = ["bin", "safetensors", "pt"] + file = None + for suffix in supported_suffix_1: + files = glob.glob(os.path.join(path, f"*.{suffix}")) + if len(files) > 1: + raise ValueError(f"Multiple files with suffix {suffix} found in {path}") + elif len(files) == 1: + file = files[0] + break + else: + for suffix in supported_suffix_2: + files = glob.glob(os.path.join(path, f"*.{suffix}")) + if len(files) > 1: + raise ValueError(f"Multiple files with suffix {suffix} found in {path}") + elif len(files) == 1: + file = files[0] + break + else: + raise ValueError(f"No supported checkpoint file found in {path}, supported suffixes: {supported_suffix_1 + supported_suffix_2}") + + if file.endswith(".index.json"): + with open(file, "r") as f: + file_list = set(json.load(f)["weight_map"].values()) + file_list = [os.path.join(path, file) for file in file_list] + else: + file_list = [file] + + for file in file_list: + print(f"load from {file}") + if file.endswith(".bin") or file.endswith(".pt"): + ckpt = torch.load(file, map_location="cpu") + elif file.endswith(".safetensors"): + ckpt = load_file(file) + for name, param in ckpt.items(): + self._load(name, param, cls=cls) + + def load_from_hf(self): + with torch.no_grad(): + self._load_from_ckpt(self.path) + + # rope + if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + rope_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type")) + if rope_type == "longrope" and not hasattr(self.config.rope_scaling, "factor"): + self.config.rope_scaling["factor"] = 1.0 + else: + rope_type = "default" + # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor + inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](self.config, "cpu", seq_len=self.max_total_length) + # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu") + self._load("model.rotary_emb.inv_freq", inv_freq, dtype=torch.float32) + # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32) + + def prefill(self, input_ids, position_ids): + assert input_ids.dtype == torch.int32 + # Check if input length exceeds maximum supported length + if input_ids.numel() > self.max_total_length: + raise ValueError(f"Input token count ({input_ids.numel()}) exceeds maximum supported length ({self.max_total_length}) under current memory limit") + + total_length = input_ids.numel() + num_chunks = (total_length + self.chunk_length - 1) // self.chunk_length + + prefill_start_time = None + actual_prefill_start = None + + # User interaction logic only when use_enter is True + if self._show_prefill_progress and self.use_enter: + # Clear screen and move cursor to top, then show prompt and wait for user input + print("\033[2J\033[H", end="", flush=True) # Clear screen and move to top + print("Please Press Enter to Start Prefilling...", end="", flush=True) + input() # Wait for Enter key + + # Replace the prompt with [Prefilling] - clear entire line first + print("\r" + " " * 50 + "\r[Prefilling]", flush=True) + # Start timing after user presses Enter + prefill_start_time = time.time() + actual_prefill_start = prefill_start_time + + # Initialize progress display for stream mode (always when _show_prefill_progress is True) + if self._show_prefill_progress: + if prefill_start_time is None: # Only set start time if not already set above + prefill_start_time = time.time() + if not self.use_enter: + print("Prefilling: 0.0% (0/{} tokens) @ 0.0 tokens/s".format(total_length), end="", flush=True) + else: + print("Prefilling: 0.0% (0/{} tokens) @ 0.0 tokens/s".format(total_length), end="", flush=True) + + # Record actual computation start time if not set yet + if actual_prefill_start is None: + actual_prefill_start = time.time() + + for chunk_idx, i in enumerate(range(0, input_ids.numel(), self.chunk_length)): + # torch.cuda.nvtx.range_push(f"chunk from {i}") + C.prefill( + min(input_ids.numel() - i, self.chunk_length), i, + input_ids.view(-1)[i:].data_ptr(), position_ids.view(-1)[i:].data_ptr(), + self.logits.data_ptr() + ) + # torch.cuda.nvtx.range_pop() + + # Show progress for stream mode - always when _show_prefill_progress is True + if self._show_prefill_progress and prefill_start_time is not None: + current_tokens = min(i + self.chunk_length, total_length) + elapsed_time = time.time() - prefill_start_time + progress = (current_tokens * 100.0) / total_length + tokens_per_sec = current_tokens / elapsed_time if elapsed_time > 0 else 0.0 + print(f"\rPrefilling: {progress:.1f}% ({current_tokens}/{total_length} tokens) @ {tokens_per_sec:.1f} tokens/s", end="", flush=True) + + # Calculate actual prefill time + actual_prefill_time = time.time() - actual_prefill_start + + # Final completion status for stream mode + if self._show_prefill_progress: + if prefill_start_time is not None: + final_elapsed_time = time.time() - prefill_start_time + final_tokens_per_sec = total_length / final_elapsed_time if final_elapsed_time > 0 else 0.0 + print(f"\rPrefilling: 100.0% ({total_length}/{total_length} tokens) @ {final_tokens_per_sec:.1f} tokens/s - Complete!") + if self.use_enter: + print("\n[Decoding]") # Show decoding status and move to next line only with use_enter + else: + print() # Just a newline for normal mode + + # Store the actual prefill time for use in generate method + self._last_prefill_time = actual_prefill_time + + return self.logits[:1].clone() + + def decode(self, input_ids, position_ids, cache_length, mask_2d = None): + assert input_ids.dtype == torch.int32 + assert position_ids.dtype == torch.int32 + assert cache_length.dtype == torch.int32 + if mask_2d is not None: + # assert mask_2d.dtype == torch.uint64 + assert input_ids.numel() == mask_2d.shape[0] + + # torch.cuda.nvtx.range_push(f"decode") + cache_length += input_ids.numel() # temparary add for convinience in flash_attn + padded_length = (cache_length[0].item() + 128 - 1) // 128 * 128 + C.decode( + input_ids.numel(), padded_length, + input_ids.data_ptr(), position_ids.data_ptr(), cache_length.data_ptr(), + mask_2d.data_ptr() if mask_2d is not None else 0, + self.logits.data_ptr(), + self.cuda_graph + ) + cache_length -= input_ids.numel() + # torch.cuda.nvtx.range_pop() + return self.logits[:input_ids.numel()].clone() + + def generate(self, input_ids, generation_length=100, teminators=[], use_stream=False): + """ + Generate text with optional streaming output. + Returns (tokens, decode_time, prefill_time) if use_stream=False, or generator yielding {'token', 'text', 'is_finished', 'prefill_time', 'decode_time'} if use_stream=True. + """ + assert input_ids.dtype == torch.int32 + + prefix_length = input_ids.numel() + position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda") + + # Set progress flag before prefill for stream mode + if use_stream: + self._show_prefill_progress = True + + # Measure prefill time + if self.use_enter and use_stream: + # In use_enter mode, timing will be handled inside prefill method + logits = self.prefill(input_ids, position_ids) + prefill_time = getattr(self, '_last_prefill_time', 0.0) # Get actual prefill time + else: + torch.cuda.synchronize() + prefill_start = time.time() + logits = self.prefill(input_ids, position_ids) + torch.cuda.synchronize() + prefill_time = time.time() - prefill_start + + if self.temperature > 0.0: + token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0].item() + else: + token = logits[0].argmax(dim=-1).item() + + # Wait for user input before decode phase if use_decode_enter is enabled + if self.use_decode_enter: + if use_stream and self.use_enter: + # In stream mode with use_enter, we already showed [Decoding], just wait for input + print("Please Press Enter to Start Decoding...", end="", flush=True) + input() # Wait for Enter key + print("\r" + " " * 50 + "\r", end="", flush=True) # Clear the prompt without showing [Decoding] again + else: + # In other modes, show prompt and wait + print("Please Press Enter to Start Decoding...", end="", flush=True) + input() # Wait for Enter key + print("\r" + " " * 50 + "\r[Decoding]", flush=True) # Show [Decoding] only when use_enter is not enabled + + if not hasattr(self, "input_ids"): + self.input_ids = torch.tensor([0], dtype=torch.int32, device="cuda") + self.position_ids = torch.tensor([0], dtype=torch.int32, device="cuda") + self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda") + + if use_stream: + # Stream generation (optimized) + def _stream_generator(): + nonlocal token + # Keep minimal context for correct spacing + prev_token = token + + # yield first token + text = self.tokenizer.decode([token], skip_special_tokens=False) + + yield { + 'token': token, + 'text': text, + 'is_finished': token in teminators, + 'prefill_time': prefill_time, + 'decode_time': 0.0 # First token comes from prefill + } + + if token in teminators: + return + + decode_start_time = time.time() + + for i in range(generation_length-1): + self.input_ids[0] = token + self.position_ids[0] = prefix_length + i + self.cache_length[0] = prefix_length + i + + logits = self.decode(self.input_ids, self.position_ids, self.cache_length) + if self.temperature > 0.0: + token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0].item() + else: + token = logits[0].argmax(dim=-1).item() + + # For correct spacing, decode with previous token context + if prev_token is not None: + context_tokens = [prev_token, token] + context_text = self.tokenizer.decode(context_tokens, skip_special_tokens=False) + prev_text = self.tokenizer.decode([prev_token], skip_special_tokens=False) + text = context_text[len(prev_text):] + else: + text = self.tokenizer.decode([token], skip_special_tokens=False) + + is_finished = token in teminators or i == generation_length - 2 + + # Calculate time only when needed to reduce overhead + decode_time = time.time() - decode_start_time + + yield { + 'token': token, + 'text': text, + 'is_finished': is_finished, + 'prefill_time': 0.0, # Only report prefill_time for first token + 'decode_time': decode_time + } + + if token in teminators: + break + + # Update prev_token + prev_token = token + + return _stream_generator() + else: + # Original batch generation + tokens = [token] + torch.cuda.synchronize() + decode_start = time.time() + for i in range(generation_length-1): + self.input_ids[0] = token + self.position_ids[0] = prefix_length + i + self.cache_length[0] = prefix_length + i + + logits = self.decode(self.input_ids, self.position_ids, self.cache_length) + if self.temperature > 0.0: + token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0].item() + else: + token = logits[0].argmax(dim=-1).item() + tokens.append(token) + if token in teminators: + break + torch.cuda.synchronize() + decode_time = time.time() - decode_start + return tokens, decode_time, prefill_time + + def print_perf_summary(self): + C.print_perf_summary() \ No newline at end of file diff --git a/examples/CPM.cu/cpmcu/llm_w4a16_gptq_marlin.py b/examples/CPM.cu/cpmcu/llm_w4a16_gptq_marlin.py new file mode 100644 index 00000000..d80a869f --- /dev/null +++ b/examples/CPM.cu/cpmcu/llm_w4a16_gptq_marlin.py @@ -0,0 +1,434 @@ +from . import C + +import os, json, glob +import torch +from transformers import AutoTokenizer, AutoConfig +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from safetensors.torch import load_file +import time, math +import torch.nn.functional as F + +dtype_map = { + torch.float16: 0, + torch.bfloat16: 1, +} + +def dtype_to_int(dtype): + ret = dtype_map.get(dtype, -1) + if ret == -1: + raise ValueError(f"Unsupported dtype: {dtype}") + return ret + +class W4A16GPTQMarlinLLM(torch.nn.Module): + def __init__(self, + path: str, # hf model path + memory_limit: float = 0.8, + chunk_length: int = 1024, + dtype: torch.dtype = None, + cuda_graph: bool = False, + apply_sparse: bool = False, + sink_window_size: int = 1, + block_window_size: int = 32, + sparse_topk_k: int = 32, + sparse_switch: int = 8192, + apply_compress_lse: bool = False, + use_enter: bool = False, + use_decode_enter: bool = False, + temperature: float = 0.0, + random_seed: int = None, + ): + super().__init__() + + self.path = path + self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + self.config = AutoConfig.from_pretrained(path, trust_remote_code=True) + self.dtype = dtype if dtype is not None else self.config.torch_dtype + self.dtype_int = dtype_to_int(self.dtype) + self.cuda_graph = cuda_graph + self.use_enter = use_enter + self.use_decode_enter = use_decode_enter + self.temperature = temperature + self.chunk_length = chunk_length + # Flag for showing prefill progress (used in stream mode) + self._show_prefill_progress = False + + # Initialize random generator if random_seed is provided + if random_seed is not None: + self.generator = torch.Generator(device="cuda") + self.generator.manual_seed(random_seed) + else: + self.generator = None + + if not hasattr(self.config, "head_dim"): + self.config.head_dim = self.config.hidden_size // self.config.num_attention_heads + + self.group_size = self.config.quantization_config['group_size'] + scale_embed = self.config.scale_emb if hasattr(self.config, "scale_emb") else 1.0 + scale_lmhead = (self.config.dim_model_base / self.config.hidden_size) if hasattr(self.config, "dim_model_base") else 1.0 + scale_residual = self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) if hasattr(self.config, "scale_depth") else 1.0 + + if apply_sparse: + C.init_w4a16_gptq_marlin_minicpm4_model( + memory_limit, + self.config.vocab_size, + self.config.num_hidden_layers, + self.config.hidden_size, + self.config.intermediate_size, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + self.config.rms_norm_eps, + self.group_size, + self.dtype_int, + self.chunk_length, + scale_embed, + scale_lmhead, + scale_residual, + sink_window_size, + block_window_size, + sparse_topk_k, + sparse_switch, + apply_compress_lse, + ) + else: + C.init_w4a16_gptq_marlin_base_model( + memory_limit, + self.config.vocab_size, + self.config.num_hidden_layers, + self.config.hidden_size, + self.config.intermediate_size, + self.config.num_attention_heads, + self.config.num_key_value_heads, + self.config.head_dim, + self.config.rms_norm_eps, + self.group_size, + self.dtype_int, + self.chunk_length, + scale_embed, + scale_lmhead, + scale_residual, + ) + + self.logits = torch.empty((64, self.config.vocab_size), dtype=self.dtype, device="cuda") + + def init_storage(self): + self.max_total_length = C.init_storage() + print("max supported length under current memory limit: ", self.max_total_length) + + def _load(self, name, param, dtype=None, cls=None): + # if ".q_proj." in name or ".k_proj." in name or ".v_proj." in name or ".gate_proj." in name or ".up_proj." in name: + # return + if dtype is None: + if 'rotary_emb' in name: + dtype = torch.float32 + else: + dtype = self.dtype + + # if 'gate_up_proj' in name: + # self._load(name.replace("gate_up_proj", "gate_proj"), param[:param.shape[0]//2], dtype) + # self._load(name.replace("gate_up_proj", "up_proj"), param[param.shape[0]//2:]) + # elif 'qkv_proj' in name: + # self._load(name.replace("qkv_proj", "q_proj"), param[:self.config.num_attention_heads * self.config.head_dim]) + # self._load(name.replace("qkv_proj", "k_proj"), param[self.config.num_attention_heads * self.config.head_dim:(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim]) + # self._load(name.replace("qkv_proj", "v_proj"), param[(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim:]) + # else: + param = param.contiguous() + if param.dtype not in [torch.int8, torch.int16, torch.int32]: + param = param.to(dtype) + C.load_model(name, param.data_ptr()) + + if "embed_tokens" in name and hasattr(self.config, "tie_word_embeddings") and self.config.tie_word_embeddings: + self._load("lm_head.weight", param) + + def _load_from_ckpt(self, path, cls=None): + supported_suffix_1 = ["bin.index.json", "safetensors.index.json"] + supported_suffix_2 = ["bin", "safetensors", "pt"] + file = None + for suffix in supported_suffix_1: + files = glob.glob(os.path.join(path, f"*.{suffix}")) + if len(files) > 1: + raise ValueError(f"Multiple files with suffix {suffix} found in {path}") + elif len(files) == 1: + file = files[0] + break + else: + for suffix in supported_suffix_2: + files = glob.glob(os.path.join(path, f"*.{suffix}")) + if len(files) > 1: + print(files) + if path + "/model_gptq_marlin.safetensors" in files: + file = path + "/model_gptq_marlin.safetensors" + else: + raise ValueError(f"Autogptq models not found in {path}") + break + elif len(files) == 1: + file = files[0] + break + else: + raise ValueError(f"No supported checkpoint file found in {path}, supported suffixes: {supported_suffix_1 + supported_suffix_2}") + + if file.endswith(".index.json"): + with open(file, "r") as f: + file_list = set(json.load(f)["weight_map"].values()) + file_list = [os.path.join(path, file) for file in file_list] + else: + file_list = [file] + + for file in file_list: + print(f"load from {file}") + if file.endswith(".bin") or file.endswith(".pt"): + ckpt = torch.load(file, map_location="cpu") + elif file.endswith(".safetensors"): + ckpt = load_file(file) + for name, param in ckpt.items(): + self._load(name, param, cls=cls) + + def load_from_hf(self): + with torch.no_grad(): + self._load_from_ckpt(self.path) + + # rope + if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + rope_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type")) + if rope_type == "longrope" and not hasattr(self.config.rope_scaling, "factor"): + self.config.rope_scaling["factor"] = 1.0 + else: + rope_type = "default" + # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor + inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](self.config, "cpu", seq_len=self.max_total_length) + # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu") + self._load("model.rotary_emb.inv_freq", inv_freq, dtype=torch.float32) + # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32) + + def prefill(self, input_ids, position_ids): + assert input_ids.dtype == torch.int32 + # Check if input length exceeds maximum supported length + if input_ids.numel() > self.max_total_length: + raise ValueError(f"Input token count ({input_ids.numel()}) exceeds maximum supported length ({self.max_total_length}) under current memory limit") + + total_length = input_ids.numel() + num_chunks = (total_length + self.chunk_length - 1) // self.chunk_length + + prefill_start_time = None + actual_prefill_start = None + + # User interaction logic only when use_enter is True + if self._show_prefill_progress and self.use_enter: + # Clear screen and move cursor to top, then show prompt and wait for user input + print("\033[2J\033[H", end="", flush=True) # Clear screen and move to top + print("Please Press Enter to Start Prefilling...", end="", flush=True) + input() # Wait for Enter key + + # Replace the prompt with [Prefilling] - clear entire line first + print("\r" + " " * 50 + "\r[Prefilling]", flush=True) + # Start timing after user presses Enter + prefill_start_time = time.time() + actual_prefill_start = prefill_start_time + + # Initialize progress display for stream mode (always when _show_prefill_progress is True) + if self._show_prefill_progress: + if prefill_start_time is None: # Only set start time if not already set above + prefill_start_time = time.time() + if not self.use_enter: + print("Prefilling: 0.0% (0/{} tokens) @ 0.0 tokens/s".format(total_length), end="", flush=True) + else: + print("Prefilling: 0.0% (0/{} tokens) @ 0.0 tokens/s".format(total_length), end="", flush=True) + + # Record actual computation start time if not set yet + if actual_prefill_start is None: + actual_prefill_start = time.time() + + for chunk_idx, i in enumerate(range(0, input_ids.numel(), self.chunk_length)): + # torch.cuda.nvtx.range_push(f"chunk from {i}") + C.prefill( + min(input_ids.numel() - i, self.chunk_length), i, + input_ids.view(-1)[i:].data_ptr(), position_ids.view(-1)[i:].data_ptr(), + self.logits.data_ptr() + ) + # torch.cuda.nvtx.range_pop() + + # Show progress for stream mode - always when _show_prefill_progress is True + if self._show_prefill_progress and prefill_start_time is not None: + current_tokens = min(i + self.chunk_length, total_length) + elapsed_time = time.time() - prefill_start_time + progress = (current_tokens * 100.0) / total_length + tokens_per_sec = current_tokens / elapsed_time if elapsed_time > 0 else 0.0 + print(f"\rPrefilling: {progress:.1f}% ({current_tokens}/{total_length} tokens) @ {tokens_per_sec:.1f} tokens/s", end="", flush=True) + + # Calculate actual prefill time + actual_prefill_time = time.time() - actual_prefill_start + + # Final completion status for stream mode + if self._show_prefill_progress: + if prefill_start_time is not None: + final_elapsed_time = time.time() - prefill_start_time + final_tokens_per_sec = total_length / final_elapsed_time if final_elapsed_time > 0 else 0.0 + print(f"\rPrefilling: 100.0% ({total_length}/{total_length} tokens) @ {final_tokens_per_sec:.1f} tokens/s - Complete!") + if self.use_enter: + print("\n[Decoding]") # Show decoding status and move to next line only with use_enter + else: + print() # Just a newline for normal mode + + # Store the actual prefill time for use in generate method + self._last_prefill_time = actual_prefill_time + + return self.logits[:1].clone() + + def decode(self, input_ids, position_ids, cache_length, mask_2d = None): + assert input_ids.dtype == torch.int32 + assert position_ids.dtype == torch.int32 + assert cache_length.dtype == torch.int32 + if mask_2d is not None: + # assert mask_2d.dtype == torch.uint64 + assert input_ids.numel() == mask_2d.shape[0] + + # torch.cuda.nvtx.range_push(f"decode") + cache_length += input_ids.numel() # temparary add for convinience in flash_attn + padded_length = (cache_length[0].item() + 128 - 1) // 128 * 128 + C.decode( + input_ids.numel(), padded_length, + input_ids.data_ptr(), position_ids.data_ptr(), cache_length.data_ptr(), + mask_2d.data_ptr() if mask_2d is not None else 0, + self.logits.data_ptr(), + self.cuda_graph + ) + cache_length -= input_ids.numel() + # torch.cuda.nvtx.range_pop() + return self.logits[:input_ids.numel()].clone() + + def generate(self, input_ids, generation_length=100, teminators=[], use_stream=False): + """ + Generate text with optional streaming output. + Returns (tokens, decode_time, prefill_time) if use_stream=False, or generator yielding {'token', 'text', 'is_finished', 'prefill_time', 'decode_time'} if use_stream=True. + """ + assert input_ids.dtype == torch.int32 + + prefix_length = input_ids.numel() + position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda") + + # Set progress flag before prefill for stream mode + if use_stream: + self._show_prefill_progress = True + + # Measure prefill time + if self.use_enter and use_stream: + # In use_enter mode, timing will be handled inside prefill method + logits = self.prefill(input_ids, position_ids) + prefill_time = getattr(self, '_last_prefill_time', 0.0) # Get actual prefill time + else: + torch.cuda.synchronize() + prefill_start = time.time() + logits = self.prefill(input_ids, position_ids) + torch.cuda.synchronize() + prefill_time = time.time() - prefill_start + + if self.temperature > 0: + token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0] + else: + token = logits[0].argmax(dim=-1).item() + + # Wait for user input before decode phase if use_decode_enter is enabled + if self.use_decode_enter: + if use_stream and self.use_enter: + # In stream mode with use_enter, we already showed [Decoding], just wait for input + print("Please Press Enter to Start Decoding...", end="", flush=True) + input() # Wait for Enter key + print("\r" + " " * 50 + "\r", end="", flush=True) # Clear the prompt without showing [Decoding] again + else: + # In other modes, show prompt and wait + print("Please Press Enter to Start Decoding...", end="", flush=True) + input() # Wait for Enter key + print("\r" + " " * 50 + "\r[Decoding]", flush=True) # Show [Decoding] only when use_enter is not enabled + + if not hasattr(self, "input_ids"): + self.input_ids = torch.tensor([0], dtype=torch.int32, device="cuda") + self.position_ids = torch.tensor([0], dtype=torch.int32, device="cuda") + self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda") + + if use_stream: + # Stream generation (optimized) + def _stream_generator(): + nonlocal token + # Keep minimal context for correct spacing + prev_token = token + + # yield first token + text = self.tokenizer.decode([token], skip_special_tokens=False) + + yield { + 'token': token, + 'text': text, + 'is_finished': token in teminators, + 'prefill_time': prefill_time, + 'decode_time': 0.0 # First token comes from prefill + } + + if token in teminators: + return + + decode_start_time = time.time() + + for i in range(generation_length-1): + self.input_ids[0] = token + self.position_ids[0] = prefix_length + i + self.cache_length[0] = prefix_length + i + + logits = self.decode(self.input_ids, self.position_ids, self.cache_length) + if self.temperature > 0: + token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0] + else: + token = logits[0].argmax(dim=-1).item() + + # For correct spacing, decode with previous token context + if prev_token is not None: + context_tokens = [prev_token, token] + context_text = self.tokenizer.decode(context_tokens, skip_special_tokens=False) + prev_text = self.tokenizer.decode([prev_token], skip_special_tokens=False) + text = context_text[len(prev_text):] + else: + text = self.tokenizer.decode([token], skip_special_tokens=False) + + is_finished = token in teminators or i == generation_length - 2 + + # Calculate time only when needed to reduce overhead + decode_time = time.time() - decode_start_time + + yield { + 'token': token, + 'text': text, + 'is_finished': is_finished, + 'prefill_time': 0.0, # Only report prefill_time for first token + 'decode_time': decode_time + } + + if token in teminators: + break + + # Update prev_token + prev_token = token + + return _stream_generator() + else: + # Original batch generation + tokens = [token] + torch.cuda.synchronize() + decode_start = time.time() + for i in range(generation_length-1): + self.input_ids[0] = token + self.position_ids[0] = prefix_length + i + self.cache_length[0] = prefix_length + i + + logits = self.decode(self.input_ids, self.position_ids, self.cache_length) + if self.temperature > 0: + token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0].item() + else: + token = logits[0].argmax(dim=-1).item() + tokens.append(token) + if token in teminators: + break + torch.cuda.synchronize() + decode_time = time.time() - decode_start + return tokens, decode_time, prefill_time + + def print_perf_summary(self): + C.print_perf_summary() \ No newline at end of file diff --git a/examples/CPM.cu/cpmcu/speculative/__init__.py b/examples/CPM.cu/cpmcu/speculative/__init__.py new file mode 100644 index 00000000..e1dcad20 --- /dev/null +++ b/examples/CPM.cu/cpmcu/speculative/__init__.py @@ -0,0 +1 @@ +from .eagle import LLM_with_eagle \ No newline at end of file diff --git a/examples/CPM.cu/cpmcu/speculative/eagle.py b/examples/CPM.cu/cpmcu/speculative/eagle.py new file mode 100644 index 00000000..83ef7f7e --- /dev/null +++ b/examples/CPM.cu/cpmcu/speculative/eagle.py @@ -0,0 +1,99 @@ +from .. import C +from .tree_drafter import LLM_with_tree_drafter +import math, torch +from transformers import PretrainedConfig + +class EagleConfig(PretrainedConfig): + def __init__( + self, + num_hidden_layers=1, + **kwargs, + ): + super().__init__(**kwargs) + self.eagle_num_layers = num_hidden_layers + +class LLM_with_eagle(LLM_with_tree_drafter): + def __init__(self, + eagle_path, + base_path, + num_iter=6, + topk_per_iter=10, + tree_size=60, + eagle_window_size=0, + frspec_vocab_size=0, + apply_eagle_quant: bool=False, + use_rope: bool=False, + use_input_norm: bool=False, + use_attn_norm: bool=False, + **kwargs): + super().__init__( + "eagle", eagle_path, base_path, + tree_size = tree_size, + **kwargs + ) + + self.eagle_path = eagle_path + self.eagle_config = EagleConfig.from_pretrained(eagle_path) + # Ensure presence consistency and equality for scale_depth, dim_model_base, and scale_emb + for attr in ("scale_depth", "dim_model_base", "scale_emb"): + base_has = hasattr(self.config, attr) + eagle_has = hasattr(self.eagle_config, attr) + assert base_has == eagle_has, f"{attr} presence mismatch between base and eagle config" + if base_has: + assert getattr(self.config, attr) == getattr(self.eagle_config, attr), f"{attr} in base config and eagle config should be the same" + scale_residual = self.config.scale_depth / math.sqrt(self.config.num_hidden_layers + 1) if hasattr(self.config, "scale_depth") else 1.0 + self.apply_eagle_quant = apply_eagle_quant + if apply_eagle_quant and hasattr(self.eagle_config, "quantization_config"): + self.group_size = self.eagle_config.quantization_config.get('group_size', 0) + else: + self.group_size = 0 + assert self.group_size == 128 or self.group_size == 0, "only group_size 128 is supported in quantization mode" + + if not use_rope and not use_input_norm and not use_attn_norm and not apply_eagle_quant: + C.init_eagle_model( + self.eagle_config.eagle_num_layers, + num_iter, + topk_per_iter, + self.tree_size, + self.dtype_int + ) + else: + C.init_minicpm4_eagle_model( + self.eagle_config.eagle_num_layers, + num_iter, + topk_per_iter, + self.tree_size, + self.dtype_int, + apply_eagle_quant, + self.group_size, + eagle_window_size, + frspec_vocab_size, + scale_residual, + use_input_norm, + use_attn_norm + ) + + def _load(self, name, param, dtype=None, cls=None): + if cls == self.drafter_type: + if name == "token_id_remap": + C.load_model(f"{cls}.{name}", param.data_ptr()) + return + if dtype is None: + dtype = self.dtype + param = param.contiguous() + if not self.apply_eagle_quant: + param = param.to(dtype) + if 'embed_tokens' in name: + return + if 'fc' in name: + if 'weight' in name: + param1 = param[..., :param.shape[-1] // 2].contiguous() + param2 = param[..., param.shape[-1] // 2:].contiguous() + C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param1.data_ptr()) + C.load_model(f"{cls}.{name.replace('fc', 'fc2')}", param2.data_ptr()) + else: # bias + C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param.data_ptr()) + else: + C.load_model(f"{cls}.{name}", param.data_ptr()) + else: + super()._load(name, param, dtype) diff --git a/examples/CPM.cu/cpmcu/speculative/eagle_base_quant/__init__.py b/examples/CPM.cu/cpmcu/speculative/eagle_base_quant/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/CPM.cu/cpmcu/speculative/eagle_base_quant/eagle_base_w4a16_marlin_gptq.py b/examples/CPM.cu/cpmcu/speculative/eagle_base_quant/eagle_base_w4a16_marlin_gptq.py new file mode 100644 index 00000000..3df755b8 --- /dev/null +++ b/examples/CPM.cu/cpmcu/speculative/eagle_base_quant/eagle_base_w4a16_marlin_gptq.py @@ -0,0 +1,103 @@ +from ... import C +from ..eagle import EagleConfig +from ..tree_drafter_base_quant.tree_drafter_w4a16_gptq_marlin import W4A16GPTQMarlinLLM_with_tree_drafter + +import math, torch + + +class W4A16GPTQMarlinLLM_with_eagle(W4A16GPTQMarlinLLM_with_tree_drafter): + def __init__(self, + eagle_path, + base_path, + num_iter=6, + topk_per_iter=10, + tree_size=60, + eagle_window_size=0, + frspec_vocab_size=0, + apply_eagle_quant: bool=False, + use_rope: bool=False, + use_input_norm: bool=False, + use_attn_norm: bool=False, + use_rotation: bool=False, + **kwargs): + super().__init__( + "eagle", eagle_path, base_path, + tree_size = tree_size, + **kwargs + ) + + self.eagle_path = eagle_path + self.eagle_config = EagleConfig.from_pretrained(eagle_path) + # Ensure presence consistency and equality for scale_depth, dim_model_base, and scale_emb + for attr in ("scale_depth", "dim_model_base", "scale_emb"): + base_has = hasattr(self.config, attr) + eagle_has = hasattr(self.eagle_config, attr) + assert base_has == eagle_has, f"{attr} presence mismatch between base and eagle config" + if base_has: + assert getattr(self.config, attr) == getattr(self.eagle_config, attr), f"{attr} in base config and eagle config should be the same" + scale_residual = self.config.scale_depth / math.sqrt(self.config.num_hidden_layers + 1) if hasattr(self.config, "scale_depth") else 1.0 + self.use_rotation = use_rotation + self.apply_eagle_quant = apply_eagle_quant + if apply_eagle_quant and hasattr(self.eagle_config, "quantization_config"): + self.group_size = self.eagle_config.quantization_config.get('group_size', 0) + else: + self.group_size = 0 + assert self.group_size == 128 or self.group_size == 0, "only group_size 128 is supported in quantization mode" + + if not use_rope and not use_input_norm and not use_attn_norm and not apply_eagle_quant: + if not use_rotation: + C.init_eagle_model( + self.eagle_config.eagle_num_layers, + num_iter, + topk_per_iter, + self.tree_size, + self.dtype_int + ) + else: + C.init_eagle_w4a16_gptq_marlin_rot_model( + self.eagle_config.eagle_num_layers, + num_iter, + topk_per_iter, + self.tree_size, + self.dtype_int + ) + else: + C.init_minicpm4_eagle_model( + self.eagle_config.eagle_num_layers, + num_iter, + topk_per_iter, + self.tree_size, + self.dtype_int, + apply_eagle_quant, + self.group_size, + eagle_window_size, + frspec_vocab_size, + scale_residual, + use_input_norm, + use_attn_norm + ) + + def _load(self, name, param, dtype=None, cls=None): + if cls == self.drafter_type: + if name == "token_id_remap": + C.load_model(f"{cls}.{name}", param.data_ptr()) + return + if dtype is None: + dtype = self.dtype + param = param.contiguous() + if not self.apply_eagle_quant: + param = param.to(dtype) + if (not self.use_rotation) and 'embed_tokens' in name: + return + if 'fc' in name: + if 'weight' in name or "scales" in name: + param1 = param[..., :param.shape[-1] // 2].contiguous() + param2 = param[..., param.shape[-1] // 2:].contiguous() + C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param1.data_ptr()) + C.load_model(f"{cls}.{name.replace('fc', 'fc2')}", param2.data_ptr()) + else: # bias + C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param.data_ptr()) + else: + C.load_model(f"{cls}.{name}", param.data_ptr()) + else: + super()._load(name, param, dtype) diff --git a/examples/CPM.cu/cpmcu/speculative/hier_spec_quant/__init__.py b/examples/CPM.cu/cpmcu/speculative/hier_spec_quant/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/CPM.cu/cpmcu/speculative/hier_spec_quant/hier_eagle_w4a16_gm_spec_w4a16_gm.py b/examples/CPM.cu/cpmcu/speculative/hier_spec_quant/hier_eagle_w4a16_gm_spec_w4a16_gm.py new file mode 100644 index 00000000..e4f9fc5c --- /dev/null +++ b/examples/CPM.cu/cpmcu/speculative/hier_spec_quant/hier_eagle_w4a16_gm_spec_w4a16_gm.py @@ -0,0 +1,268 @@ +from ... import C +from ...llm_w4a16_gptq_marlin import W4A16GPTQMarlinLLM + +import numpy as np +import torch +from ..tree_drafter import * +import time +from transformers import PretrainedConfig, AutoTokenizer, AutoConfig +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS + +def pack_draft_mask(mask_2d): + ''' + for static masks, pack them into a uint64 per row + ''' + mask_2d_packed = torch.zeros((mask_2d.shape[0]), dtype=torch.uint16, device="cuda") + for i in range(mask_2d.shape[0]): + mask_1 = 0 + for j in range(i + 1): + mask_1 |= (mask_2d[i][j].item() << j ) + mask_2d_packed[i] = mask_1 + mask_2d_packed = mask_2d_packed.view(torch.uint16).view(-1) + return mask_2d_packed + +class EagleConfig(PretrainedConfig): + def __init__( + self, + num_hidden_layers=1, + **kwargs, + ): + super().__init__(**kwargs) + self.eagle_num_layers = num_hidden_layers + +class HierEagleW4A16GMSpecW4A16GM(W4A16GPTQMarlinLLM): + def __init__(self, + drafter_path: str, + base_path: str, + min_draft_length: int, + draft_cuda_graph: bool, + tree_path: str, + ea_num_iter=6, + ea_topk_per_iter=10, + tree_size=60, + draft_model_start=False, + rotation=False, + **kwargs): + + super().__init__(base_path, **kwargs) + + # eagle config + self.tree_drafter_type = 'eagle' + self.eagle_path = tree_path + self.ea_num_iter = ea_num_iter + self.ea_topk_per_iter = ea_topk_per_iter + self.tree_size = tree_size + + self.tree_draft_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.tree_position_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.tree_gt_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.tree_attn_mask = torch.empty((tree_size), dtype=torch.uint64, device="cuda") + self.tree_parent = torch.empty((tree_size), dtype=torch.int32, device="cuda") + + + self.eagle_config = EagleConfig.from_pretrained(self.eagle_path) + self.rotation = rotation + + + # draft config + self.drafter_type = 'draft' + self.drafter_path = drafter_path + self.drafter_tokenizer = AutoTokenizer.from_pretrained(drafter_path) + self.drafter_config = AutoConfig.from_pretrained(drafter_path) + + self.min_draft_length = min_draft_length + self.max_draft_length = min_draft_length + ea_num_iter + self.draft_ids = torch.empty((self.max_draft_length+1), dtype=torch.int32, device="cuda") + self.draft_position_ids = torch.empty((self.max_draft_length+1), dtype=torch.int32, device="cuda") + self.draft_gt_ids = torch.empty((self.max_draft_length+1), dtype=torch.int32, device="cuda") + self.draft_attn_mask = pack_draft_mask( + torch.tril(torch.ones(self.max_draft_length+1, self.max_draft_length+1, dtype=torch.bool)).to("cuda") + ) + + # eagle accept list + self.draft_ea_accept_list = torch.empty((1024,), dtype=torch.int32, device="cuda") + + self.draft_logits = torch.empty((64, self.config.vocab_size), dtype=self.dtype, device="cuda") + self.draft_cache_length = torch.tensor([0], dtype=torch.int32, device="cuda") + self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda") + self.draft_cuda_graph = draft_cuda_graph + + self.draft_model_start = draft_model_start + + self.draft_group_size = self.drafter_config.quantization_config['group_size'] + + if self.rotation: + C.init_hier_eagle_w4a16_gm_rot_spec_w4a16_gm_model( + self.drafter_config.vocab_size, + self.drafter_config.num_hidden_layers, + self.drafter_config.hidden_size, + self.drafter_config.intermediate_size, + self.drafter_config.num_attention_heads, + self.drafter_config.num_key_value_heads, + self.drafter_config.head_dim, + self.drafter_config.rms_norm_eps, + self.draft_group_size, + self.min_draft_length, + self.draft_cuda_graph, + self.eagle_config.eagle_num_layers, + self.ea_num_iter, + self.ea_topk_per_iter, + self.tree_size, + self.draft_model_start, + 0, + ) + else: + C.init_hier_eagle_w4a16_gm_spec_w4a16_gm_model( + self.drafter_config.vocab_size, + self.drafter_config.num_hidden_layers, + self.drafter_config.hidden_size, + self.drafter_config.intermediate_size, + self.drafter_config.num_attention_heads, + self.drafter_config.num_key_value_heads, + self.drafter_config.head_dim, + self.drafter_config.rms_norm_eps, + self.draft_group_size, + self.min_draft_length, + self.draft_cuda_graph, + self.eagle_config.eagle_num_layers, + self.ea_num_iter, + self.ea_topk_per_iter, + self.tree_size, + self.draft_model_start, + 0, + ) + + # def load_from_hf(self): + # self._load_from_ckpt(self.eagle_path, cls=self.tree_drafter_type) + # self._load_from_ckpt(self.drafter_path, cls=self.drafter_type) + # super().load_from_hf() + + def _load(self, name, param, dtype=None, cls=None): + if cls == self.tree_drafter_type: + if dtype is None: + dtype = self.dtype + param = param.contiguous().to(dtype) + if (not self.rotation) and 'embed_tokens' in name: + return + if 'fc' in name: + if 'weight' in name: + param1 = param[..., :param.shape[-1] // 2].contiguous() + param2 = param[..., param.shape[-1] // 2:].contiguous() + C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param1.data_ptr()) + C.load_model(f"{cls}.{name.replace('fc', 'fc2')}", param2.data_ptr()) + else: # bias + C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param.data_ptr()) + else: + C.load_model(f"{cls}.{name}", param.data_ptr()) + elif cls == self.drafter_type: + if dtype is None: + if 'rotary_emb' in name: + dtype = torch.float32 + else: + dtype = self.dtype + + # if 'gate_up_proj' in name: + # self._load(name.replace("gate_up_proj", "gate_proj"), param[:param.shape[0]//2], dtype, cls=cls) + # self._load(name.replace("gate_up_proj", "up_proj"), param[param.shape[0]//2:], cls=cls) + # elif 'qkv_proj' in name: + # self._load(name.replace("qkv_proj", "q_proj"), param[:self.config.num_attention_heads * self.config.head_dim], cls=cls) + # self._load(name.replace("qkv_proj", "k_proj"), param[self.config.num_attention_heads * self.config.head_dim:(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim], cls=cls) + # self._load(name.replace("qkv_proj", "v_proj"), param[(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim:], cls=cls) + # else: + param = param.contiguous() + if param.dtype not in [torch.int8, torch.int16, torch.int32]: + param = param.to(dtype) + C.load_model(f"{cls}.{name}", param.data_ptr()) + + if "embed_tokens" in name and hasattr(self.config, "tie_word_embeddings") and self.config.tie_word_embeddings: + self._load("lm_head.weight", param, cls) + else: + super()._load(name, param, dtype) + + def load_from_hf(self): + with torch.no_grad(): + # ealge load + self._load_from_ckpt(self.eagle_path, cls=self.tree_drafter_type) + + self._load_from_ckpt(self.drafter_path, cls=self.drafter_type) + # rope + if hasattr(self.drafter_config, "rope_scaling") and self.drafter_config.rope_scaling is not None: + draft_rope_type = self.drafter_config.rope_scaling.get("rope_type", self.drafter_config.rope_scaling.get("type")) + else: + draft_rope_type = "default" + # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor + draft_inv_freq, draft_attention_scaling = ROPE_INIT_FUNCTIONS[draft_rope_type](self.drafter_config, "cpu", seq_len=self.max_total_length) + # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu") + self._load(f"{self.drafter_type}.model.rotary_emb.inv_freq", draft_inv_freq, dtype=torch.float32, cls=self.drafter_type) + # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32) + + super().load_from_hf() + + + def generate(self, input_ids, generation_length=100, teminators=[]): + assert input_ids.dtype == torch.int32 + + prefix_length = input_ids.shape[1] + + position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda") + logits = self.prefill(input_ids, position_ids) + self.draft_ids[:1].copy_(logits[0].argmax(dim=-1)) + + tokens = torch.empty((generation_length), dtype=torch.int32, device="cuda") + tokens[0].copy_(self.draft_ids[0]) + accept_lengths = [] + i = 0 + model_step = 0 + terminal = False + torch.cuda.synchronize() + start_time = time.time() + + + while i < generation_length-1 and not terminal: + self.cache_length[0] = prefix_length + i + self.draft_position_ids[0] = prefix_length + i + + + # step 1: draft model prefill and eagle input prepare + C.draft( + self.draft_ids.data_ptr(), + self.draft_position_ids.data_ptr(), + self.cache_length.data_ptr(), + self.draft_attn_mask.data_ptr(), + self.draft_ea_accept_list.data_ptr(), + ) + + + # step 2: target model decode (length fixed for cuda graph) + logits = self.decode(self.draft_ids, self.draft_position_ids, self.cache_length, mask_2d=self.draft_attn_mask) + self.draft_gt_ids.copy_(logits.argmax(dim=-1)) + + # step 6: verify and fix target model and eagle input + accept_length = C.verify_and_fix( + self.draft_ids.numel(), + self.draft_ids.data_ptr(), + self.draft_gt_ids.data_ptr(), + self.draft_position_ids.data_ptr(), + self.cache_length.data_ptr(), + self.draft_attn_mask.data_ptr(), + self.draft_ea_accept_list.data_ptr(), + ) + + model_step += 1 + accept_lengths.append(accept_length) + for temin in teminators: + if temin in self.draft_gt_ids[:accept_length]: + terminal = True + append_length = min(accept_length, generation_length - 1 - i) + tokens[1+i:1+i+append_length].copy_(self.draft_gt_ids[:append_length]) + self.draft_ids[0] = self.draft_gt_ids[accept_length - 1] + i += accept_length + + + # print(f"ea accept avg:", np.mean(self.draft_ea_accept_list[1:ea_acc_nums+1].cpu().numpy())) + + torch.cuda.synchronize() + decode_time = time.time() - start_time + ea_acc_nums = self.draft_ea_accept_list[0].item() + tokens = tokens[:i+1].tolist() + return tokens, accept_lengths, model_step, decode_time, self.draft_ea_accept_list[1:1+ea_acc_nums].clone() \ No newline at end of file diff --git a/examples/CPM.cu/cpmcu/speculative/spec_quant/__init__.py b/examples/CPM.cu/cpmcu/speculative/spec_quant/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/CPM.cu/cpmcu/speculative/spec_quant/spec_w4a16_gm_for_w4a16_gm_model.py b/examples/CPM.cu/cpmcu/speculative/spec_quant/spec_w4a16_gm_for_w4a16_gm_model.py new file mode 100644 index 00000000..15b22414 --- /dev/null +++ b/examples/CPM.cu/cpmcu/speculative/spec_quant/spec_w4a16_gm_for_w4a16_gm_model.py @@ -0,0 +1,160 @@ +from ... import C +from transformers import AutoTokenizer, AutoConfig +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +from safetensors.torch import load_file + +from ...llm_w4a16_gptq_marlin import W4A16GPTQMarlinLLM + +import torch +import time + +def pack_draft_mask(mask_2d): + ''' + for static masks, pack them into a uint64 per row + ''' + mask_2d_packed = torch.zeros((mask_2d.shape[0]), dtype=torch.uint16, device="cuda") + for i in range(mask_2d.shape[0]): + mask_1 = 0 + for j in range(i + 1): + mask_1 |= (mask_2d[i][j].item() << j ) + mask_2d_packed[i] = mask_1 + mask_2d_packed = mask_2d_packed.view(torch.uint16).view(-1) + return mask_2d_packed + + +class W4A16GMSpecW4A16GM(W4A16GPTQMarlinLLM): + def __init__(self, + drafter_path: str, + base_path: str, + draft_num: int, + draft_cuda_graph: bool, + **kwargs): + super().__init__(base_path, **kwargs) + + self.drafter_type = 'draft' + self.drafter_path = drafter_path + self.drafter_tokenizer = AutoTokenizer.from_pretrained(drafter_path) + self.drafter_config = AutoConfig.from_pretrained(drafter_path) + + self.draft_num = draft_num + self.draft_ids = torch.empty((self.draft_num+1), dtype=torch.int32, device="cuda") + self.draft_position_ids = torch.empty((self.draft_num+1), dtype=torch.int32, device="cuda") + self.draft_gt_ids = torch.empty((self.draft_num+1), dtype=torch.int32, device="cuda") + self.draft_attn_mask = pack_draft_mask( + torch.tril(torch.ones(draft_num+1, draft_num+1, dtype=torch.bool)).to("cuda") + ) + self.draft_parent = torch.tensor([], dtype=torch.int32, device="cuda") + self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda") + + self.draft_cuda_graph = draft_cuda_graph + + self.draft_group_size = self.drafter_config.quantization_config['group_size'] + + C.init_w4a16_gm_spec_w4a16_gm_model( + self.drafter_config.vocab_size, + self.drafter_config.num_hidden_layers, + self.drafter_config.hidden_size, + self.drafter_config.intermediate_size, + self.drafter_config.num_attention_heads, + self.drafter_config.num_key_value_heads, + self.drafter_config.head_dim, + self.drafter_config.rms_norm_eps, + self.draft_group_size, + self.draft_num, + self.draft_cuda_graph, + self.dtype_int, + ) + + def _load(self, name, param, dtype=None, cls=None): + if cls == self.drafter_type: + if dtype is None: + if 'rotary_emb' in name: + dtype = torch.float32 + else: + dtype = self.dtype + + # if 'gate_up_proj' in name: + # self._load(name.replace("gate_up_proj", "gate_proj"), param[:param.shape[0]//2], dtype, cls=cls) + # self._load(name.replace("gate_up_proj", "up_proj"), param[param.shape[0]//2:], cls=cls) + # elif 'qkv_proj' in name: + # self._load(name.replace("qkv_proj", "q_proj"), param[:self.config.num_attention_heads * self.config.head_dim], cls=cls) + # self._load(name.replace("qkv_proj", "k_proj"), param[self.config.num_attention_heads * self.config.head_dim:(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim], cls=cls) + # self._load(name.replace("qkv_proj", "v_proj"), param[(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim:], cls=cls) + # else: + param = param.contiguous() + if param.dtype not in [torch.int8, torch.int16, torch.int32]: + param = param.to(dtype) + C.load_model(f"{cls}.{name}", param.data_ptr()) + + if "embed_tokens" in name and hasattr(self.config, "tie_word_embeddings") and self.config.tie_word_embeddings: + self._load("lm_head.weight", param, cls) + else: + super()._load(name, param, dtype) + + + def load_from_hf(self): + with torch.no_grad(): + self._load_from_ckpt(self.drafter_path, cls=self.drafter_type) + # rope + if hasattr(self.drafter_config, "rope_scaling") and self.drafter_config.rope_scaling is not None: + draft_rope_type = self.drafter_config.rope_scaling.get("rope_type", self.drafter_config.rope_scaling.get("type")) + else: + draft_rope_type = "default" + # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor + draft_inv_freq, draft_attention_scaling = ROPE_INIT_FUNCTIONS[draft_rope_type](self.drafter_config, "cpu", seq_len=self.max_total_length) + # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu") + self._load(f"{self.drafter_type}.model.rotary_emb.inv_freq", draft_inv_freq, dtype=torch.float32, cls=self.drafter_type) + # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32) + super().load_from_hf() + + + def generate(self, input_ids, generation_length=100, teminators=[]): + assert input_ids.dtype == torch.int32 + + prefix_length = input_ids.numel() + position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda") + logits = self.prefill(input_ids, position_ids) + self.draft_ids[:1].copy_(logits[0].argmax(dim=-1)) + + tokens = torch.empty((generation_length), dtype=torch.int32, device="cuda") + tokens[0].copy_(self.draft_ids[0]) + accept_lengths = [] + i = 0 + model_step = 0 + terminal = False + torch.cuda.synchronize() + start_time = time.time() + while i < generation_length-1 and not terminal: + self.cache_length[0] = prefix_length + i + self.draft_position_ids[0] = prefix_length + i + + # torch.cuda.nvtx.range_push(f"draft") + C.draft(self.draft_ids.data_ptr(), self.draft_position_ids.data_ptr(), self.cache_length.data_ptr(), self.draft_attn_mask.data_ptr(), self.draft_parent.data_ptr()) + # torch.cuda.nvtx.range_pop() + + + logits = self.decode(self.draft_ids, self.draft_position_ids, self.cache_length, mask_2d=self.draft_attn_mask) + self.draft_gt_ids.copy_(logits.argmax(dim=-1)) + + # torch.cuda.nvtx.range_push(f"verify") + accept_length = C.verify_and_fix( + self.draft_ids.numel(), self.draft_ids.data_ptr(), self.draft_gt_ids.data_ptr(), + self.draft_position_ids.data_ptr(), self.cache_length.data_ptr(), + self.draft_attn_mask.data_ptr(), self.draft_parent.data_ptr() + ) + # torch.cuda.nvtx.range_pop() + + model_step += 1 + accept_lengths.append(accept_length) + for temin in teminators: + if temin in self.draft_gt_ids[:accept_length]: + terminal = True + append_length = min(accept_length, generation_length - 1 - i) + tokens[1+i:1+i+append_length].copy_(self.draft_gt_ids[:append_length]) + self.draft_ids[0] = self.draft_gt_ids[accept_length - 1] + i += accept_length + torch.cuda.synchronize() + decode_time = time.time() - start_time + tokens = tokens[:1+i].tolist() + + return tokens, accept_lengths, model_step, decode_time \ No newline at end of file diff --git a/examples/CPM.cu/cpmcu/speculative/tree_drafter.py b/examples/CPM.cu/cpmcu/speculative/tree_drafter.py new file mode 100644 index 00000000..428c9023 --- /dev/null +++ b/examples/CPM.cu/cpmcu/speculative/tree_drafter.py @@ -0,0 +1,255 @@ +from .. import C +from ..llm import LLM + +import torch +import torch.nn.functional as F +import time +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS + +def pack_mask(mask_2d): + ''' + for static masks, pack them into a uint64 per row + ''' + mask_2d_packed = torch.zeros((mask_2d.shape[0], 2), dtype=torch.uint32, device="cuda") + for i in range(mask_2d.shape[0]): + mask_1 = 0 + mask_2 = 0 + for j in range(i + 1): + if j < 32: + mask_1 |= (mask_2d[i][j].item() << j) + else: + mask_2 |= (mask_2d[i][j].item() << (j - 32)) + mask_2d_packed[i][0] = mask_1 + mask_2d_packed[i][1] = mask_2 + mask_2d_packed = mask_2d_packed.view(torch.uint64).view(-1) + return mask_2d_packed + +class LLM_with_tree_drafter(LLM): + def __init__(self, + drafter_type, drafter_path, base_path, + tree_size, + use_rope: bool=False, + **kwargs): + super().__init__(base_path, **kwargs) + + self.drafter_type = drafter_type + self.drafter_path = drafter_path + self.base_path = base_path + self.use_rope = use_rope + + self.tree_size = tree_size + self.tree_draft_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.tree_position_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.tree_gt_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.tree_attn_mask = torch.empty((tree_size), dtype=torch.uint64, device="cuda") + self.tree_parent = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.tree_position_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda") + + self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda") + + def load_from_hf(self): + with torch.no_grad(): + self._load_from_ckpt(self.drafter_path, cls=self.drafter_type) + + if self.use_rope: + if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + rope_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type")) + else: + rope_type = "default" + # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor + inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](self.config, "cpu", seq_len=self.max_total_length) + # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu") + self._load(f"{self.drafter_type}.rotary_emb.inv_freq", inv_freq, dtype=torch.float32) + # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32) + + super().load_from_hf() + + def generate(self, input_ids, generation_length=100, teminators=[], use_stream=False): + """ + Generate text with optional streaming output for tree drafter. + Returns (tokens, accept_lengths, decode_time, prefill_time) if use_stream=False, or generator yielding {'token', 'text', 'is_finished', 'accept_length', 'prefill_time', 'decode_time'} if use_stream=True. + """ + assert input_ids.dtype == torch.int32 + + prefix_length = input_ids.numel() + # Check if input length exceeds maximum supported length + if prefix_length > self.max_total_length: + raise ValueError(f"Input token count ({prefix_length}) exceeds maximum supported length ({self.max_total_length}) under current memory limit") + + position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda") + + # Set progress flag before prefill for stream mode + if use_stream: + self._show_prefill_progress = True + else: + self._show_prefill_progress = False + + # Measure prefill time + if self.use_enter and use_stream: + # In use_enter mode, timing will be handled inside prefill method + logits = self.prefill(input_ids, position_ids) + prefill_time = getattr(self, '_last_prefill_time', 0.0) # Get actual prefill time + else: + torch.cuda.synchronize() + prefill_start = time.time() + logits = self.prefill(input_ids, position_ids) + torch.cuda.synchronize() + prefill_time = time.time() - prefill_start + + if self.temperature > 0.0: + self.tree_draft_ids[:1].copy_(torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)) + else: + self.tree_draft_ids[:1].copy_(logits[0].argmax(dim=-1)) + + # Wait for user input before decode phase if use_decode_enter is enabled + if self.use_decode_enter: + if use_stream and self.use_enter: + # In stream mode with use_enter, we already showed [Decoding], just wait for input + print("Please Press Enter to Start Decoding...", end="", flush=True) + input() # Wait for Enter key + print("\r" + " " * 50 + "\r", end="", flush=True) # Clear the prompt without showing [Decoding] again + else: + # In other modes, show prompt and wait + print("Please Press Enter to Start Decoding...", end="", flush=True) + input() # Wait for Enter key + print("\r" + " " * 50 + "\r[Decoding]", flush=True) # Show [Decoding] only when use_enter is not enabled + + if use_stream: + # Stream generation for tree drafter (optimized) + def _stream_generator(): + # Keep minimal context for correct spacing + prev_token = None + + # yield first token + token = self.tree_draft_ids[0].item() + text = self.tokenizer.decode([token], skip_special_tokens=False) + prev_token = token + + yield { + 'token': token, + 'text': text, + 'is_finished': token in teminators, + 'accept_length': 1, + 'prefill_time': prefill_time, + 'decode_time': 0.0 # First token comes from prefill + } + + if token in teminators: + return + + decode_start_time = time.time() + + i = 0 + while i < generation_length-1: + self.cache_length[0] = prefix_length + i + + # draft step + C.draft(self.tree_draft_ids.data_ptr(), self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr()) + + logits = self.decode(self.tree_draft_ids, self.tree_position_ids, self.cache_length, mask_2d=self.tree_attn_mask) + if self.temperature > 0.0: + self.tree_gt_ids.copy_(torch.multinomial(F.softmax(logits/self.temperature, dim=-1), num_samples=1, generator=self.generator).squeeze(-1)) + else: + self.tree_gt_ids.copy_(logits.argmax(dim=-1)) + + # verify step + accept_length = C.verify_and_fix( + self.tree_draft_ids.numel(), self.tree_draft_ids.data_ptr(), self.tree_gt_ids.data_ptr(), + self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), + self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr() + ) + + # yield accepted tokens (optimized with minimal context) + if accept_length > 0: + accepted_tokens = self.tree_draft_ids[:accept_length].tolist() + + # For correct spacing, decode with previous token context + if prev_token is not None: + context_tokens = [prev_token] + accepted_tokens + context_text = self.tokenizer.decode(context_tokens, skip_special_tokens=False) + prev_text = self.tokenizer.decode([prev_token], skip_special_tokens=False) + new_text = context_text[len(prev_text):] + else: + new_text = self.tokenizer.decode(accepted_tokens, skip_special_tokens=False) + + # Yield tokens with batch text for first token, empty for others + for j in range(accept_length): + if i + j >= generation_length - 1: + break + + token = accepted_tokens[j] + + # Give all new text to first token, empty to others + if j == 0: + text = new_text + else: + text = "" + + terminal = token in teminators + is_finished = terminal or (i + j == generation_length - 2) + + # Only calculate time for the last token in the batch to reduce overhead + decode_time = time.time() - decode_start_time if j == accept_length - 1 else 0.0 + + yield { + 'token': token, + 'text': text, + 'is_finished': is_finished, + 'accept_length': accept_length if j == 0 else 0, # only report accept_length for first token in batch + 'prefill_time': 0.0, # Only report prefill_time for first token + 'decode_time': decode_time + } + + if terminal: + return + + # Update prev_token to the last accepted token + prev_token = accepted_tokens[-1] + + self.tree_draft_ids[0] = self.tree_draft_ids[accept_length - 1] + i += accept_length + + return _stream_generator() + else: + # Original batch generation + tokens = torch.empty((generation_length), dtype=torch.int32, device="cuda") + tokens[0].copy_(self.tree_draft_ids[0]) + accept_lengths = [] + i = 0 + terminal = False + torch.cuda.synchronize() + decode_start = time.time() + while i < generation_length-1 and not terminal: + self.cache_length[0] = prefix_length + i + + # torch.cuda.nvtx.range_push(f"draft") + C.draft(self.tree_draft_ids.data_ptr(), self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr()) + # torch.cuda.nvtx.range_pop() + + logits = self.decode(self.tree_draft_ids, self.tree_position_ids, self.cache_length, mask_2d=self.tree_attn_mask) + if self.temperature > 0.0: + self.tree_gt_ids.copy_(torch.multinomial(F.softmax(logits/self.temperature, dim=-1), num_samples=1, generator=self.generator).squeeze(-1)) + else: + self.tree_gt_ids.copy_(logits.argmax(dim=-1)) + + # torch.cuda.nvtx.range_push(f"verify") + accept_length = C.verify_and_fix( + self.tree_draft_ids.numel(), self.tree_draft_ids.data_ptr(), self.tree_gt_ids.data_ptr(), + self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), + self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr() + ) + # torch.cuda.nvtx.range_pop() + + accept_lengths.append(accept_length) + for temin in teminators: + if temin in self.tree_draft_ids[:accept_length]: + terminal = True + append_length = min(accept_length, generation_length - 1 - i) + tokens[1+i:1+i+append_length].copy_(self.tree_draft_ids[:append_length]) + self.tree_draft_ids[0] = self.tree_draft_ids[accept_length - 1] + i += accept_length + torch.cuda.synchronize() + decode_time = time.time() - decode_start + tokens = tokens[:1+i].tolist() + + return tokens, accept_lengths, decode_time, prefill_time \ No newline at end of file diff --git a/examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/__init__.py b/examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/tree_drafter_w4a16_gptq_marlin.py b/examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/tree_drafter_w4a16_gptq_marlin.py new file mode 100644 index 00000000..63ccf306 --- /dev/null +++ b/examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/tree_drafter_w4a16_gptq_marlin.py @@ -0,0 +1,240 @@ +from ... import C +from ...llm_w4a16_gptq_marlin import W4A16GPTQMarlinLLM + +import torch +from ..tree_drafter import * +import time +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS +import torch.nn.functional as F + +class W4A16GPTQMarlinLLM_with_tree_drafter(W4A16GPTQMarlinLLM): + def __init__(self, + drafter_type, drafter_path, base_path, + tree_size, + use_rope: bool=False, + temperature: float=0.0, + **kwargs): + super().__init__(base_path, **kwargs) + + self.drafter_type = drafter_type + self.drafter_path = drafter_path + self.base_path = base_path + self.use_rope = use_rope + + self.tree_size = tree_size + self.tree_draft_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.tree_position_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.tree_gt_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.tree_attn_mask = torch.empty((tree_size), dtype=torch.uint64, device="cuda") + self.tree_parent = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.tree_position_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda") + self.temperature = temperature + + self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda") + + def load_from_hf(self): + with torch.no_grad(): + self._load_from_ckpt(self.drafter_path, cls=self.drafter_type) + + if self.use_rope: + if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None: + rope_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type")) + else: + rope_type = "default" + # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor + inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](self.config, "cpu", seq_len=self.max_total_length) + # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu") + self._load(f"{self.drafter_type}.rotary_emb.inv_freq", inv_freq, dtype=torch.float32) + # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32) + + super().load_from_hf() + + def generate(self, input_ids, generation_length=100, teminators=[], use_stream=False): + """ + Generate text with optional streaming output for quantized tree drafter. + Returns (tokens, accept_lengths, decode_time, prefill_time) if use_stream=False, or generator yielding {'token', 'text', 'is_finished', 'accept_length', 'prefill_time', 'decode_time'} if use_stream=True. + """ + assert input_ids.dtype == torch.int32 + + prefix_length = input_ids.numel() + # Check if input length exceeds maximum supported length + if prefix_length > self.max_total_length: + raise ValueError(f"Input token count ({prefix_length}) exceeds maximum supported length ({self.max_total_length}) under current memory limit") + + position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda") + + # Set progress flag before prefill for stream mode + if use_stream: + self._show_prefill_progress = True + else: + self._show_prefill_progress = False + + # Measure prefill time + if self.use_enter and use_stream: + # In use_enter mode, timing will be handled inside prefill method + logits = self.prefill(input_ids, position_ids) + prefill_time = getattr(self, '_last_prefill_time', 0.0) # Get actual prefill time + else: + torch.cuda.synchronize() + prefill_start = time.time() + logits = self.prefill(input_ids, position_ids) + torch.cuda.synchronize() + prefill_time = time.time() - prefill_start + + if self.temperature > 0.0: + self.tree_draft_ids[:1].copy_(torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)) + else: + self.tree_draft_ids[:1].copy_(logits[0].argmax(dim=-1)) + + # Wait for user input before decode phase if use_decode_enter is enabled + if self.use_decode_enter: + if use_stream and self.use_enter: + # In stream mode with use_enter, we already showed [Decoding], just wait for input + print("Please Press Enter to Start Decoding...", end="", flush=True) + input() # Wait for Enter key + print("\r" + " " * 50 + "\r", end="", flush=True) # Clear the prompt without showing [Decoding] again + else: + # In other modes, show prompt and wait + print("Please Press Enter to Start Decoding...", end="", flush=True) + input() # Wait for Enter key + print("\r" + " " * 50 + "\r[Decoding]", flush=True) # Show [Decoding] only when use_enter is not enabled + + if use_stream: + # Stream generation for quantized tree drafter (optimized) + def _stream_generator(): + # Keep minimal context for correct spacing + prev_token = None + + # yield first token + token = self.tree_draft_ids[0].item() + text = self.tokenizer.decode([token], skip_special_tokens=False) + prev_token = token + + yield { + 'token': token, + 'text': text, + 'is_finished': token in teminators, + 'accept_length': 1, + 'prefill_time': prefill_time, + 'decode_time': 0.0 # First token comes from prefill + } + + if token in teminators: + return + + decode_start_time = time.time() + + i = 0 + while i < generation_length-1: + self.cache_length[0] = prefix_length + i + + # draft step + C.draft(self.tree_draft_ids.data_ptr(), self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr()) + + logits = self.decode(self.tree_draft_ids, self.tree_position_ids, self.cache_length, mask_2d=self.tree_attn_mask) + if self.temperature > 0.0: + self.tree_gt_ids.copy_(torch.multinomial(F.softmax(logits/self.temperature, dim=-1), num_samples=1, generator=self.generator).squeeze(-1)) + else: + self.tree_gt_ids.copy_(logits.argmax(dim=-1)) + + # verify step + accept_length = C.verify_and_fix( + self.tree_draft_ids.numel(), self.tree_draft_ids.data_ptr(), self.tree_gt_ids.data_ptr(), + self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), + self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr() + ) + + # yield accepted tokens (optimized with minimal context) + if accept_length > 0: + accepted_tokens = self.tree_draft_ids[:accept_length].tolist() + + # For correct spacing, decode with previous token context + if prev_token is not None: + context_tokens = [prev_token] + accepted_tokens + context_text = self.tokenizer.decode(context_tokens, skip_special_tokens=False) + prev_text = self.tokenizer.decode([prev_token], skip_special_tokens=False) + new_text = context_text[len(prev_text):] + else: + new_text = self.tokenizer.decode(accepted_tokens, skip_special_tokens=False) + + # Yield tokens with batch text for first token, empty for others + for j in range(accept_length): + if i + j >= generation_length - 1: + break + + token = accepted_tokens[j] + + # Give all new text to first token, empty to others + if j == 0: + text = new_text + else: + text = "" + + terminal = token in teminators + is_finished = terminal or (i + j == generation_length - 2) + + # Only calculate time for the last token in the batch to reduce overhead + decode_time = time.time() - decode_start_time if j == accept_length - 1 else 0.0 + + yield { + 'token': token, + 'text': text, + 'is_finished': is_finished, + 'accept_length': accept_length if j == 0 else 0, # only report accept_length for first token in batch + 'prefill_time': 0.0, # Only report prefill_time for first token + 'decode_time': decode_time + } + + if terminal: + return + + # Update prev_token to the last accepted token + prev_token = accepted_tokens[-1] + + self.tree_draft_ids[0] = self.tree_draft_ids[accept_length - 1] + i += accept_length + + return _stream_generator() + else: + # Original batch generation + tokens = torch.empty((generation_length), dtype=torch.int32, device="cuda") + tokens[0].copy_(self.tree_draft_ids[0]) + accept_lengths = [] + i = 0 + terminal = False + torch.cuda.synchronize() + decode_start = time.time() + while i < generation_length-1 and not terminal: + self.cache_length[0] = prefix_length + i + + # torch.cuda.nvtx.range_push(f"draft") + C.draft(self.tree_draft_ids.data_ptr(), self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr()) + # torch.cuda.nvtx.range_pop() + + logits = self.decode(self.tree_draft_ids, self.tree_position_ids, self.cache_length, mask_2d=self.tree_attn_mask) + if self.temperature > 0.0: + self.tree_gt_ids.copy_(torch.multinomial(F.softmax(logits/self.temperature, dim=-1), num_samples=1, generator=self.generator).squeeze(-1)) + else: + self.tree_gt_ids.copy_(logits.argmax(dim=-1)) + + # torch.cuda.nvtx.range_push(f"verify") + accept_length = C.verify_and_fix( + self.tree_draft_ids.numel(), self.tree_draft_ids.data_ptr(), self.tree_gt_ids.data_ptr(), + self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), + self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr() + ) + # torch.cuda.nvtx.range_pop() + + accept_lengths.append(accept_length) + for temin in teminators: + if temin in self.tree_draft_ids[:accept_length]: + terminal = True + append_length = min(accept_length, generation_length - 1 - i) + tokens[1+i:1+i+append_length].copy_(self.tree_draft_ids[:append_length]) + self.tree_draft_ids[0] = self.tree_draft_ids[accept_length - 1] + i += accept_length + torch.cuda.synchronize() + decode_time = time.time() - decode_start + tokens = tokens[:1+i].tolist() + + return tokens, accept_lengths, decode_time, prefill_time \ No newline at end of file diff --git a/examples/CPM.cu/model_convert/convert_llama_format.py b/examples/CPM.cu/model_convert/convert_llama_format.py new file mode 100644 index 00000000..ececfcb2 --- /dev/null +++ b/examples/CPM.cu/model_convert/convert_llama_format.py @@ -0,0 +1,44 @@ +from transformers import AutoModelForCausalLM, AutoTokenizer +import torch +import math + +torch.manual_seed(0) + +def llm_load(path): + tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True) + + return model, tokenizer + +def convert_llm(): + # model.embed_tokens.weight * scale_emb + state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"] * scale_emb + + # lm_head.weight / (hidden_size / dim_model_base) + state_dict["lm_head.weight"] = state_dict["lm_head.weight"] / (hidden_size / dim_model_base) + + for i in range(num_layers): + attn_out_name = f"model.layers.{i}.self_attn.o_proj.weight" + state_dict[attn_out_name] = state_dict[attn_out_name] * (scale_depth / math.sqrt(num_layers)) + + ffn_down_proj_name = f"model.layers.{i}.mlp.down_proj.weight" + state_dict[ffn_down_proj_name] = state_dict[ffn_down_proj_name] * (scale_depth / math.sqrt(num_layers)) + + torch.save(state_dict, "./pytorch_model.bin") + +if __name__ == "__main__": + model, tokenizer = llm_load("/DATA/disk0/zhaoweilun/minicpm4/models/stable_7T_decay_700B_decay2_300B_longdecay_1sw1fa_sft_50B_release") + + scale_emb = model.config.scale_emb + dim_model_base = model.config.dim_model_base + scale_depth = model.config.scale_depth + num_layers = model.config.num_hidden_layers + hidden_size = model.config.hidden_size + print(f"scale_emb = {scale_emb}") + print(f"dim_model_base = {dim_model_base}") + print(f"scale_depth = {scale_depth}") + print(f"num_layers = {num_layers}") + print(f"hidden_size = {hidden_size}") + + state_dict = model.state_dict() + convert_llm() \ No newline at end of file diff --git a/examples/CPM.cu/model_convert/convert_w4a16.py b/examples/CPM.cu/model_convert/convert_w4a16.py new file mode 100644 index 00000000..d2ed26c0 --- /dev/null +++ b/examples/CPM.cu/model_convert/convert_w4a16.py @@ -0,0 +1,287 @@ +import torch +from safetensors.torch import load_file, save_file +import os, glob, shutil +import re +from typing import List +import argparse +import numpy as np +from transformers import AutoConfig + +parser = argparse.ArgumentParser() + +parser.add_argument("--model-path", type=str, required=True, help="Path to the original model") +parser.add_argument("--quant-path", type=str, required=True, help="Path to the AutoGPTQ model") +parser.add_argument("--output-path", type=str, required=True, help="Path to save the converted model") + +# Copied from https://github.com/AutoGPTQ/AutoGPTQ/blob/9f7d37072917ab3a7545835f23e808294a542153/auto_gptq/nn_modules/qlinear/qlinear_marlin.py +def get_perms(): + perm = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm) + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + perm = perm.reshape((-1, 8))[:, interleave].ravel() + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + +PERM, SCALE_PERM, SCALE_PERM_SINGLE = get_perms() + +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, group_size: int) -> torch.Tensor: + + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(SCALE_PERM)))[:, SCALE_PERM] + else: + s = s.reshape((-1, len(SCALE_PERM_SINGLE)))[:, SCALE_PERM_SINGLE] + s = s.reshape((-1, size_n)).contiguous() + + return s + +def marlin_repack_qweight(qweight: torch.Tensor, bits: int, size_k: int, size_n: int, tile: int = 16) -> torch.Tensor: + + # unpack + compress_factor = 32 // bits + mask = (1 << bits) - 1 + qweight = qweight.cpu().numpy().astype(np.uint32) + unpacked_qweight = np.zeros((size_k, size_n), dtype=np.uint32) + unpacked_offset = np.arange(size_k) // compress_factor + unpacked_shift = (np.arange(size_k) % compress_factor) * bits + unpacked_qweight = (qweight[unpacked_offset, :] >> unpacked_shift[:, None]) & mask + + # permute + unpacked_qweight = torch.from_numpy(unpacked_qweight.astype(np.int32)) + unpacked_qweight = unpacked_qweight.reshape((size_k // tile, tile, size_n // tile, tile)) + unpacked_qweight = unpacked_qweight.permute(0, 2, 1, 3) + unpacked_qweight = unpacked_qweight.reshape((size_k // tile, size_n * tile)).contiguous() + unpacked_qweight = unpacked_qweight.reshape((-1, PERM.numel()))[:, PERM].reshape(unpacked_qweight.shape) + + # repack + repacked_qweight = np.zeros((unpacked_qweight.shape[0], unpacked_qweight.shape[1] // compress_factor), dtype=np.uint32) + unpacked_qweight = unpacked_qweight.cpu().numpy().astype(np.uint32) + for i in range(compress_factor): + repacked_qweight |= unpacked_qweight[:, i::compress_factor] << (bits * i) + repacked_qweight = torch.from_numpy(repacked_qweight.astype(np.int32)) + + return repacked_qweight + +def convert_w4a16_checkpoint(orig_model_path, quant_path, output_path): + + config = AutoConfig.from_pretrained(quant_path) + + group_size = config.quantization_config['group_size'] + assert group_size in [-1, 128], "Only group_size -1 and 128 are supported for marlin" + + bits = config.quantization_config['bits'] + assert bits == 4, "Only 4-bit quantization is supported for marlin" + + model_path = glob.glob(os.path.join(quant_path, "*.safetensors"))[0] + + autogptq_weigths = load_file(model_path) + + gptq_convert_dict = { + "model.layers.{}.self_attn.q_proj.qweight": ["model.layers.{}.self_attn.q_proj.scales", "model.layers.{}.self_attn.q_proj.g_idx", "model.layers.{}.self_attn.q_proj.qzeros"], + "model.layers.{}.self_attn.k_proj.qweight":["model.layers.{}.self_attn.k_proj.scales", "model.layers.{}.self_attn.k_proj.g_idx", "model.layers.{}.self_attn.k_proj.qzeros"], + "model.layers.{}.self_attn.v_proj.qweight":["model.layers.{}.self_attn.v_proj.scales", "model.layers.{}.self_attn.v_proj.g_idx", "model.layers.{}.self_attn.v_proj.qzeros"], + "model.layers.{}.self_attn.o_proj.qweight":["model.layers.{}.self_attn.o_proj.scales", "model.layers.{}.self_attn.o_proj.g_idx", "model.layers.{}.self_attn.o_proj.qzeros"], + "model.layers.{}.mlp.gate_proj.qweight":["model.layers.{}.mlp.gate_proj.scales", "model.layers.{}.mlp.gate_proj.g_idx", "model.layers.{}.mlp.gate_proj.qzeros"], + "model.layers.{}.mlp.up_proj.qweight": ["model.layers.{}.mlp.up_proj.scales", "model.layers.{}.mlp.up_proj.g_idx", "model.layers.{}.mlp.up_proj.qzeros"], + "model.layers.{}.mlp.down_proj.qweight": ["model.layers.{}.mlp.down_proj.scales", "model.layers.{}.mlp.down_proj.g_idx", "model.layers.{}.mlp.down_proj.qzeros"], + "fc.qweight": ["fc.scales", "fc.g_idx", "fc.qzeros"], + } + + convert_checkpoint = {} + processed_keys = set() + + for gptq_key in autogptq_weigths: + if gptq_key in processed_keys: + continue + elif "layers" in gptq_key: + abstract_key = re.sub(r'(\d+)', '{}', gptq_key) + layer_num = re.search(r'\d+', gptq_key).group(0) + if "q_proj" in abstract_key: + if abstract_key.endswith('qweight'): + k_key = gptq_key.replace('q_proj', 'k_proj') + v_key = gptq_key.replace('q_proj', 'v_proj') + + q_weight = autogptq_weigths[gptq_key].clone().cuda() + k_weight = autogptq_weigths[k_key].clone().cuda() + v_weight = autogptq_weigths[v_key].clone().cuda() + x = torch.cat([q_weight, k_weight, v_weight], dim=-1) + + shape_0 = x.shape[0] * 8 + shape_1 = x.shape[1] + x = marlin_repack_qweight(x, bits, shape_0, shape_1) + + convert_checkpoint[gptq_key.replace("q_proj", "qkv_proj")] = x.cpu() + + processed_keys.add(gptq_key) + processed_keys.add(k_key) + processed_keys.add(v_key) + + for q_keys in gptq_convert_dict[abstract_key]: + if q_keys.endswith("scales"): + k_q_keys = q_keys.replace("q_proj", "k_proj") + v_q_keys = q_keys.replace("q_proj", "v_proj") + + scales_x_q = autogptq_weigths[q_keys.format(layer_num)].clone().cuda() + scales_x_k = autogptq_weigths[k_q_keys.format(layer_num)].clone().cuda() + scales_x_v = autogptq_weigths[v_q_keys.format(layer_num)].clone().cuda() + scales_x = torch.cat([scales_x_q, scales_x_k, scales_x_v], dim=-1) + scales_x.data = marlin_permute_scales(scales_x.data.contiguous(), + size_k=shape_0, + size_n=shape_1, + group_size=group_size) + convert_checkpoint[q_keys.format(layer_num).replace("q_proj", "qkv_proj")] = scales_x.cpu() + + processed_keys.add(q_keys.format(layer_num)) + processed_keys.add(q_keys.replace("q_proj", "k_proj").format(layer_num)) + processed_keys.add(q_keys.replace("q_proj", "v_proj").format(layer_num)) + elif "gate_proj" in abstract_key: + if abstract_key.endswith('qweight'): + up_key = gptq_key.replace('gate_proj', 'up_proj') + + gate_weight = autogptq_weigths[gptq_key].clone().cuda() + up_weight = autogptq_weigths[up_key].clone().cuda() + + x = torch.cat([gate_weight, up_weight], dim=-1) + + shape_0 = x.shape[0] * 8 + shape_1 = x.shape[1] + x = marlin_repack_qweight(x, bits, shape_0, shape_1) + + convert_checkpoint[gptq_key.replace("gate_proj", "gate_up_proj")] = x.cpu() + + processed_keys.add(gptq_key) + processed_keys.add(up_key) + + for q_keys in gptq_convert_dict[abstract_key]: + if q_keys.endswith("scales"): + up_q_keys = q_keys.replace("gate_proj", "up_proj") + + scales_x_gate = autogptq_weigths[q_keys.format(layer_num)].clone().cuda() + scales_x_up = autogptq_weigths[up_q_keys.format(layer_num)].clone().cuda() + scales_x = torch.cat([scales_x_gate, scales_x_up], dim=-1) + scales_x.data = marlin_permute_scales(scales_x.data.contiguous(), + size_k=shape_0, + size_n=shape_1, + group_size=group_size) + convert_checkpoint[q_keys.format(layer_num).replace("gate_proj", "gate_up_proj")] = scales_x.cpu() + + processed_keys.add(q_keys.format(layer_num)) + processed_keys.add(q_keys.replace("gate_proj", "up_proj").format(layer_num)) + + elif "down_proj" in abstract_key or "o_proj" in abstract_key: + if abstract_key.endswith('qweight'): + x = autogptq_weigths[gptq_key].clone().cuda() + + shape_0 = x.shape[0] * 8 + shape_1 = x.shape[1] + x = marlin_repack_qweight(x, bits, shape_0, shape_1) + + convert_checkpoint[gptq_key] = x.cpu() + + processed_keys.add(gptq_key) + + for q_keys in gptq_convert_dict[abstract_key]: + if q_keys.endswith("scales"): + + scales_x = autogptq_weigths[q_keys.format(layer_num)].clone().cuda() + scales_x.data = marlin_permute_scales(scales_x.data.contiguous(), + size_k=shape_0, + size_n=shape_1, + group_size=group_size) + convert_checkpoint[q_keys.format(layer_num)] = scales_x.cpu() + + processed_keys.add(q_keys.format(layer_num)) + + elif "post_attention_layernorm" in gptq_key or "input_layernorm" in gptq_key: + convert_checkpoint[gptq_key] = autogptq_weigths[gptq_key].clone() + elif "fc" in gptq_key and autogptq_weigths[gptq_key].dtype == torch.int32: + if gptq_key.endswith('qweight'): + fc_qweight = autogptq_weigths[gptq_key].clone().cuda() + packed_in_features_x_2, out_features = fc_qweight.shape + packed_in_features = packed_in_features_x_2 // 2 + in_features = packed_in_features * 32 // bits + fc1_weight = fc_qweight[:packed_in_features, :].contiguous() + fc2_weight = fc_qweight[packed_in_features:, :].contiguous() + + fc1_weight = marlin_repack_qweight(fc1_weight, bits, in_features, out_features) + fc2_weight = marlin_repack_qweight(fc2_weight, bits, in_features, out_features) + + convert_checkpoint[gptq_key] = torch.cat([fc1_weight, fc2_weight], dim=-1).cpu() + processed_keys.add(gptq_key) + + for fc_key in gptq_convert_dict[gptq_key]: + if fc_key.endswith("scales"): + fc_scales = autogptq_weigths[gptq_key.replace("qweight", "scales")].clone().cuda() + fc_scales_1 = fc_scales[:in_features // group_size, :].contiguous() + fc_scales_2 = fc_scales[in_features // group_size:, :].contiguous() + + fc_scales_1 = marlin_permute_scales( + fc_scales_1.data.contiguous(), + size_k=in_features, + size_n=out_features, + group_size=group_size + ) + fc_scales_2 = marlin_permute_scales( + fc_scales_2.data.contiguous(), + size_k=in_features, + size_n=out_features, + group_size=group_size + ) + # convert_checkpoint[q_keys.format(layer_num).replace("gate_proj", "gate_up_proj")] = scales_x.cpu() + convert_checkpoint[gptq_key.replace("qweight", "scales")] = torch.cat([fc_scales_1, fc_scales_2], dim=-1).cpu() + processed_keys.add(gptq_key.replace("qweight", "scales")) + else: + convert_checkpoint[gptq_key] = autogptq_weigths[gptq_key].clone() + + save_file(convert_checkpoint, os.path.join(output_path, f"model_gptq.safetensors")) + # copy quantization config + config_list = glob.glob(os.path.join(quant_path, "*config.json")) + for config_file in config_list: + # copy config to output path + config_filename = os.path.basename(config_file) + dst_path = os.path.join(output_path, config_filename) + shutil.copy2(config_file, dst_path) + + # copy tokenizer + tokenizer_list = glob.glob(os.path.join(orig_model_path, "tokenizer*")) + for tokenizer_file in tokenizer_list: + # copy config to output path + tokenizer_filename = os.path.basename(tokenizer_file) + dst_path = os.path.join(output_path, tokenizer_filename) + shutil.copy2(tokenizer_file, dst_path) + + # copy "special_tokens_map.json" + special_tokens_map_file = glob.glob(os.path.join(orig_model_path, "special_tokens_map.json"))[0] + special_tokens_map_basename = os.path.basename(special_tokens_map_file) + dst_path = os.path.join(output_path, special_tokens_map_basename) + shutil.copy2(special_tokens_map_file, dst_path) + +if __name__=="__main__": + + args = parser.parse_args() + orig_model_path = args.model_path + quant_path = args.quant_path + output_path = args.output_path + + os.makedirs(output_path, exist_ok=True) + + convert_w4a16_checkpoint(orig_model_path, quant_path, output_path) \ No newline at end of file diff --git a/examples/CPM.cu/model_convert/post_process_w4a16_eagle.py b/examples/CPM.cu/model_convert/post_process_w4a16_eagle.py new file mode 100644 index 00000000..f9a940d3 --- /dev/null +++ b/examples/CPM.cu/model_convert/post_process_w4a16_eagle.py @@ -0,0 +1,48 @@ +import os +import torch +import argparse +from safetensors.torch import load_file, save_file + +parser = argparse.ArgumentParser() +parser.add_argument("--fp-model-path", type=str, required=True, help="Path to the fp model") +parser.add_argument("--quant-model-path", type=str, required=True, help="Path to the quantized model") +parser.add_argument("--output-path", type=str, required=True, help="Path to save the converted model") + +def post_process_eagle_w4a16_ckpt(fp_model_path, quant_model_path, output_path): + fp_model = torch.load(os.path.join(fp_model_path, "pytorch_model.bin")) + quant_model = load_file(os.path.join(quant_model_path, "model_gptq.safetensors")) + + new_state_dict = {} + + assert (fp_model["embed_tokens.weight"].to(torch.float16) == quant_model["model.embed_tokens.weight"].cuda().to(torch.float16)).all(), "embed_tokens.weight mismatch" + new_state_dict["embed_tokens.weight"] = fp_model["embed_tokens.weight"].to(torch.float16) + + if "fc.weight" in quant_model.keys(): + assert (fp_model["fc.weight"].to(torch.float16) == quant_model["fc.weight"].cuda().to(torch.float16)).all(), "fc.weight mismatch" + new_state_dict["fc.weight"] = fp_model["fc.weight"].to(torch.float16) + elif "fc.qweight" in quant_model.keys(): + new_state_dict["fc.qweight"] = quant_model["fc.qweight"] + new_state_dict["fc.scales"] = quant_model["fc.scales"] + + new_state_dict["input_norm1.weight"] = fp_model["input_norm1.weight"].to(torch.float16) + new_state_dict["input_norm2.weight"] = fp_model["input_norm2.weight"].to(torch.float16) + + for key, value in quant_model.items(): + if "model.layers." in key: + new_key = key.replace("model.", "") + new_state_dict[new_key] = value + + save_file(new_state_dict, os.path.join(output_path, f"model_gptq.safetensors")) + + os.system(f"cp {quant_model_path}/*.json {output_path}") + +if __name__ == "__main__": + args = parser.parse_args() + fp_model_path = args.fp_model_path + quant_model_path = args.quant_model_path + output_path = args.output_path + + os.makedirs(output_path, exist_ok=True) + + post_process_eagle_w4a16_ckpt(fp_model_path, quant_model_path, output_path) + \ No newline at end of file diff --git a/examples/CPM.cu/scripts/fr_spec/gen_fr_index.py b/examples/CPM.cu/scripts/fr_spec/gen_fr_index.py new file mode 100644 index 00000000..b556ca95 --- /dev/null +++ b/examples/CPM.cu/scripts/fr_spec/gen_fr_index.py @@ -0,0 +1,89 @@ +from datasets import load_dataset +from transformers import AutoTokenizer +from collections import Counter +from tqdm import tqdm +import torch +import argparse +import os + +def main(args): + print(f"Generating FR index for {args.model_name} with vocab size {args.vocab_size}") + print(f"This may take about 5-10 minutes with 1M lines on good network connection") + print("Loading dataset...") + ds = load_dataset('Salesforce/wikitext', 'wikitext-103-raw-v1', streaming=True)['train'] + # Only take the number of samples we need to process + ds = ds.take(args.num_lines + 1) # +1 to account for 0-indexing + print(f"Dataset limited to {args.num_lines + 1} samples") + + print("Loading tokenizer...") + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + print("Tokenizer loaded successfully") + + token_counter = Counter() + num_lines = args.num_lines + num_tokens = 0 + print("Starting to process data...") + for i, d in tqdm(enumerate(ds)): + tokens = tokenizer.encode(d['text']) + token_counter.update(tokens) + num_tokens += len(tokens) + if i == num_lines: + break + + sort_by_freq = sorted(token_counter.items(), key=lambda x: x[1], reverse=True) + ids, frequencies = zip(*sort_by_freq) + ids = list(ids) + + print(f"processed {num_lines} items") + print(f"processed {num_tokens} tokens") + print(f"unique tokens: {len(ids)}") + + os.makedirs(f'fr_index/{args.model_name}', exist_ok=True) + + for r in args.vocab_size: + eos_id = tokenizer.encode(tokenizer.special_tokens_map['eos_token']) + if eos_id not in ids[:r]: + not_in_ids = len(set(eos_id) - set(ids[:r])) + freq_ids = ids[:r - not_in_ids] + eos_id + else: + freq_ids = ids[:r] + if (r != len(freq_ids)): + print(f"Warning: requested vocab_size {r} but actual size: {len(freq_ids)}, file not saved") + else: + pt_path = f'fr_index/{args.model_name}/freq_{r}.pt' + print(f'save {pt_path}, actual size: {len(freq_ids)}') + with open(pt_path, 'wb') as f: + torch.save(freq_ids, f) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument( + '--model_name', + type=str, + default='MiniCPM4-8B', + help='The name of the model.' + ) + parser.add_argument( + '--model_path', + type=str, + default='openbmb/MiniCPM4-8B', + help='The path to the model.' + ) + parser.add_argument( + '--num_lines', + type=int, + default=1000000, + help='The number of lines to process.' + ) + parser.add_argument( + '--vocab_size', + nargs='+', + type=int, + default=[8192, 16384, 32768], + help='The vocab sizes to process.' + ) + + args = parser.parse_args() + print(args) + main(args) diff --git a/examples/CPM.cu/scripts/model_convert/convert_w4a16.sh b/examples/CPM.cu/scripts/model_convert/convert_w4a16.sh new file mode 100644 index 00000000..15320d1f --- /dev/null +++ b/examples/CPM.cu/scripts/model_convert/convert_w4a16.sh @@ -0,0 +1,9 @@ +Model_Path=/yourpath/minicpm4_mupformat +Quant_Path=/yourpath/minicpm4_autogptq +Output_Path=/yourpath/minicpm4_marlin + + +python model_convert/convert_w4a16.py \ + --model-path $Model_Path \ + --quant-path $Quant_Path \ + --output-path $Output_Path \ No newline at end of file diff --git a/examples/CPM.cu/setup.py b/examples/CPM.cu/setup.py new file mode 100644 index 00000000..e19f9d7b --- /dev/null +++ b/examples/CPM.cu/setup.py @@ -0,0 +1,317 @@ +import os, glob +from setuptools import setup, find_packages + +this_dir = os.path.dirname(os.path.abspath(__file__)) + +def detect_cuda_arch(): + """Automatically detect current CUDA architecture""" + # 1. First check if environment variable specifies architecture + env_arch = os.getenv("CPMCU_CUDA_ARCH") + if env_arch: + # Only support simple comma-separated format, e.g., "80,86" + arch_list = [] + tokens = env_arch.split(',') + + for token in tokens: + token = token.strip() + if not token: + continue + + # Check format: must be pure digits + if not token.isdigit(): + raise ValueError( + f"Invalid CUDA architecture format: '{token}'. " + f"CPMCU_CUDA_ARCH should only contain comma-separated numbers like '80,86'" + ) + + arch_list.append(token) + + if arch_list: + print(f"Using CUDA architectures from environment variable: {arch_list}") + return arch_list + + # 2. Check if torch library is available, if so, auto-detect + try: + import torch + except ImportError: + # 3. If no environment variable and no torch, throw error + raise RuntimeError( + "CUDA architecture detection failed. Please either:\n" + "1. Set environment variable CPMCU_CUDA_ARCH (e.g., export CPMCU_CUDA_ARCH=90), or\n" + "2. Install PyTorch (pip install torch) for automatic detection.\n" + "Common CUDA architectures: 70 (V100), 75 (T4), 80 (A100), 86 (RTX 30xx), 89 (RTX 40xx), 90 (H100)" + ) + + # Use torch to auto-detect all GPU architectures + try: + if torch.cuda.is_available(): + arch_set = set() + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + arch = f"{major}{minor}" + arch_set.add(arch) + + arch_list = sorted(list(arch_set)) # Sort for consistency + print(f"Detected CUDA architectures: {arch_list} (from {device_count} GPU devices)") + return arch_list + else: + raise RuntimeError( + "No CUDA devices detected. Please either:\n" + "1. Set environment variable CPMCU_CUDA_ARCH (e.g., export CPMCU_CUDA_ARCH=90), or\n" + "2. Ensure CUDA devices are available and properly configured.\n" + "Common CUDA architectures: 70 (V100), 75 (T4), 80 (A100), 86 (RTX 30xx), 89 (RTX 40xx), 90 (H100)" + ) + except Exception as e: + raise RuntimeError( + f"CUDA architecture detection failed: {e}\n" + "Please set environment variable CPMCU_CUDA_ARCH (e.g., export CPMCU_CUDA_ARCH=90).\n" + "Common CUDA architectures: 70 (V100), 75 (T4), 80 (A100), 86 (RTX 30xx), 89 (RTX 40xx), 90 (H100)" + ) + +def append_nvcc_threads(nvcc_extra_args): + nvcc_threads = os.getenv("NVCC_THREADS") or "16" + return nvcc_extra_args + ["--threads", nvcc_threads] + +def get_compile_args(): + """Return different compilation arguments based on debug mode""" + debug_mode = os.getenv("CPMCU_DEBUG", "0").lower() in ("1", "true", "yes") + perf_mode = os.getenv("CPMCU_PERF", "0").lower() in ("1", "true", "yes") + + # Check precision type environment variable + dtype_env = os.getenv("CPMCU_DTYPE", "fp16").lower() + + # Parse precision type list + dtype_list = [dtype.strip() for dtype in dtype_env.split(',')] + + # Validate precision types + valid_dtypes = {"fp16", "bf16"} + invalid_dtypes = [dtype for dtype in dtype_list if dtype not in valid_dtypes] + if invalid_dtypes: + raise ValueError( + f"Invalid CPMCU_DTYPE values: {invalid_dtypes}. " + f"Supported values: 'fp16', 'bf16', 'fp16,bf16'" + ) + + # Deduplicate and generate compilation definitions + dtype_set = set(dtype_list) + dtype_defines = [] + if "fp16" in dtype_set: + dtype_defines.append("-DENABLE_DTYPE_FP16") + if "bf16" in dtype_set: + dtype_defines.append("-DENABLE_DTYPE_BF16") + + # Print compilation information + if len(dtype_set) == 1: + dtype_name = list(dtype_set)[0].upper() + print(f"Compiling with {dtype_name} support only") + else: + dtype_names = [dtype.upper() for dtype in sorted(dtype_set)] + print(f"Compiling with {' and '.join(dtype_names)} support") + + # Common compilation arguments + common_cxx_args = ["-std=c++17"] + dtype_defines + common_nvcc_args = [ + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ] + dtype_defines + + if debug_mode: + print("Debug mode enabled (CPMCU_DEBUG=1)") + cxx_args = common_cxx_args + [ + "-g3", # Most detailed debug information + "-O0", # Disable optimization + "-DDISABLE_MEMPOOL", + "-DDEBUG", + "-fno-inline", # Disable inlining + "-fno-omit-frame-pointer", # Keep frame pointer + ] + nvcc_base_args = common_nvcc_args + [ + "-O0", + "-g", # Host-side debug information + "-lineinfo", # Generate line number information + "-DDISABLE_MEMPOOL", + "-DDEBUG", + "-DCUDA_DEBUG", + "-Xcompiler", "-g3", # Pass to host compiler + "-Xcompiler", "-fno-inline", # Disable inlining + "-Xcompiler", "-fno-omit-frame-pointer", # Keep frame pointer + ] + # Debug mode linking arguments + link_args = ["-g", "-rdynamic"] + else: + print("Release mode enabled") + cxx_args = common_cxx_args + ["-O3"] + nvcc_base_args = common_nvcc_args + [ + "-O3", + "--use_fast_math", + ] + # Release mode linking arguments + link_args = [] + + # Add performance testing control + if perf_mode: + print("Performance monitoring enabled (CPMCU_PERF=1)") + cxx_args.append("-DENABLE_PERF") + nvcc_base_args.append("-DENABLE_PERF") + else: + print("Performance monitoring disabled (CPMCU_PERF=0)") + + return cxx_args, nvcc_base_args, link_args, dtype_set + +def get_all_headers(): + """Get all header files for dependency tracking""" + header_patterns = [ + "src/**/*.h", + "src/**/*.hpp", + "src/**/*.cuh", + "src/cutlass/include/**/*.h", + "src/cutlass/include/**/*.hpp", + "src/flash_attn/**/*.h", + "src/flash_attn/**/*.hpp", + "src/flash_attn/**/*.cuh", + ] + + headers = [] + for pattern in header_patterns: + abs_headers = glob.glob(os.path.join(this_dir, pattern), recursive=True) + # Convert to relative paths + rel_headers = [os.path.relpath(h, this_dir) for h in abs_headers] + headers.extend(rel_headers) + + # Filter out non-existent files (check absolute path but return relative path) + headers = [h for h in headers if os.path.exists(os.path.join(this_dir, h))] + + return headers + +def get_flash_attn_sources(enabled_dtypes): + """Get flash attention source files based on enabled data types""" + sources = [] + + for dtype in enabled_dtypes: + if dtype == "fp16": + # sources.extend(glob.glob("src/flash_attn/src/*hdim64_fp16*.cu")) + sources.extend(glob.glob("src/flash_attn/src/*hdim128_fp16*.cu")) + elif dtype == "bf16": + # sources.extend(glob.glob("src/flash_attn/src/*hdim64_bf16*.cu")) + sources.extend(glob.glob("src/flash_attn/src/*hdim128_bf16*.cu")) + + return sources + +# Try to build extension modules +ext_modules = [] +cmdclass = {} + +try: + # Try to import torch-related modules + from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + # Get CUDA architecture + arch_list = detect_cuda_arch() + + # Get compilation arguments + cxx_args, nvcc_base_args, link_args, dtype_set = get_compile_args() + + # Get header files + all_headers = get_all_headers() + + # Generate gencode arguments for each architecture + gencode_args = [] + arch_defines = [] + for arch in arch_list: + gencode_args.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"]) + arch_defines.append(f"-D_ARCH{arch}") + + print(f"Using CUDA architecture compile flags: {arch_list}") + + flash_attn_sources = get_flash_attn_sources(dtype_set) + + # Create Ninja build extension class + class NinjaBuildExtension(BuildExtension): + def __init__(self, *args, **kwargs) -> None: + # do not override env MAX_JOBS if already exists + if not os.environ.get("MAX_JOBS"): + import psutil + # calculate the maximum allowed NUM_JOBS based on cores + max_num_jobs_cores = max(1, os.cpu_count() // 2) + # calculate the maximum allowed NUM_JOBS based on free memory + free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB + max_num_jobs_memory = int(free_memory_gb / 9) # each JOB peak memory cost is ~8-9GB when threads = 4 + # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation + max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory)) + os.environ["MAX_JOBS"] = str(max_jobs) + super().__init__(*args, **kwargs) + + # Configure extension module + ext_modules = [ + CUDAExtension( + name='cpmcu.C', + sources = [ + "src/entry.cu", + "src/utils.cu", + "src/signal_handler.cu", + "src/perf.cu", + *glob.glob("src/qgemm/gptq_marlin/*cu"), + *flash_attn_sources, + ], + libraries=["cublas", "dl"], + depends=all_headers, + extra_compile_args={ + "cxx": cxx_args, + "nvcc": append_nvcc_threads( + nvcc_base_args + + gencode_args + + arch_defines + [ + # Add dependency file generation options + "-MMD", "-MP", + ] + ), + }, + extra_link_args=link_args, + include_dirs=[ + f"{this_dir}/src/flash_attn", + f"{this_dir}/src/flash_attn/src", + f"{this_dir}/src/cutlass/include", + f"{this_dir}/src/", + ], + ) + ] + + cmdclass = {'build_ext': NinjaBuildExtension} + +except Exception as e: + print(f"Warning: Unable to configure CUDA extension module: {e}") + print("Skipping extension module build, installing Python package only...") + +setup( + name='cpmcu', + version='1.0.0', + author_email="acha131441373@gmail.com", + description="cpm cuda implementation", + packages=find_packages(), + setup_requires=[ + "pybind11", + "psutil", + "ninja", + "torch", + ], + install_requires=[ + "transformers==4.46.2", + "accelerate==0.26.0", + "datasets", + "fschat", + "openai", + "anthropic", + "human_eval", + "zstandard", + "tree_sitter", + "tree-sitter-python" + ], + ext_modules=ext_modules, + cmdclass=cmdclass, +) \ No newline at end of file diff --git a/examples/CPM.cu/src/entry.cu b/examples/CPM.cu/src/entry.cu new file mode 100644 index 00000000..41b3acc3 --- /dev/null +++ b/examples/CPM.cu/src/entry.cu @@ -0,0 +1,534 @@ +#include <pybind11/pybind11.h> +#include <cuda_runtime.h> +#include <type_traits> +#include <stdexcept> + +#include "utils.cuh" +#include "trait.cuh" +#include "perf.cuh" +// base model +#include "model/model.cuh" +#include "model/w4a16_gptq_marlin/w4a16_gptq_marlin_model.cuh" +#include "model/minicpm4/minicpm4_model.cuh" +#include "model/minicpm4/minicpm4_w4a16_gptq_marlin_model.cuh" + +// eagle +#include "model/eagle.cuh" +#include "model/minicpm4/minicpm4_eagle.cuh" +#include "model/eagle_base_quant/eagle_base_w4a16_gptq_marlin.cuh" + +// spec model +#include "model/spec_quant/w4a16_gm_spec_w4a16_gm.cuh" + +// hier +#include "model/hier_spec_quant/hier_ea_w4a16_gm_spec_w4a16_gm.cuh" +#include "model/hier_spec_quant/hier_ea_w4a16_gm_rot_spec_w4a16_gm.cuh" + + +#if defined(ENABLE_DTYPE_FP16) && defined(ENABLE_DTYPE_BF16) +#define DTYPE_SWITCH(COND, ...) \ + [&] { \ + if (COND == 0) { \ + using elem_type = __half; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = __nv_bfloat16; \ + return __VA_ARGS__(); \ + } \ + }() +#elif defined(ENABLE_DTYPE_FP16) +#define DTYPE_SWITCH(COND, ...) \ + [&] { \ + if (COND != 0) { \ + throw std::runtime_error("BF16 support not compiled. Please recompile with CPMCU_DTYPE=bf16 or CPMCU_DTYPE=fp16,bf16"); \ + } \ + using elem_type = __half; \ + return __VA_ARGS__(); \ + }() +#elif defined(ENABLE_DTYPE_BF16) +#define DTYPE_SWITCH(COND, ...) \ + [&] { \ + if (COND == 0) { \ + throw std::runtime_error("FP16 support not compiled. Please recompile with CPMCU_DTYPE=fp16 or CPMCU_DTYPE=fp16,bf16"); \ + } \ + using elem_type = __nv_bfloat16; \ + return __VA_ARGS__(); \ + }() +#else +#error "At least one of ENABLE_DTYPE_FP16 or ENABLE_DTYPE_BF16 must be defined" +#endif + +#define MODEL_TYPE_SWITCH(MODEL_PTR, T, ...) \ + [&] { \ + if (dynamic_cast<MiniCPM4Impl<T>*>(MODEL_PTR)) { \ + using ModelType = MiniCPM4Impl<T>; \ + auto* typed_model = static_cast<MiniCPM4Impl<T>*>(MODEL_PTR); \ + return __VA_ARGS__(); \ + } else if (dynamic_cast<ModelImpl<T>*>(MODEL_PTR)) { \ + using ModelType = ModelImpl<T>; \ + auto* typed_model = static_cast<ModelImpl<T>*>(MODEL_PTR); \ + return __VA_ARGS__(); \ + } \ + else if (dynamic_cast<W4A16GPTQMarlinModelImpl<T>*>(MODEL_PTR)) { \ + using ModelType = W4A16GPTQMarlinModelImpl<T>; \ + auto* typed_model = static_cast<W4A16GPTQMarlinModelImpl<T>*>(MODEL_PTR); \ + return __VA_ARGS__(); \ + } else if (dynamic_cast<MiniCPM4W4A16GPTQMarlinModelImpl<T>*>(MODEL_PTR)) { \ + using ModelType = MiniCPM4W4A16GPTQMarlinModelImpl<T>; \ + auto* typed_model = static_cast<MiniCPM4W4A16GPTQMarlinModelImpl<T>*>(MODEL_PTR); \ + return __VA_ARGS__(); \ + } \ + }() + +#define EAGLE_QUANT_SWITCH(COND, T, ...) \ + [&] { \ + if (COND == true) { \ + using LayerType = W4A16GPTQMarlinLayer<T>; \ + using Fc1Type = W4A16GPTQMarlinLinear<T, true, true>; \ + using Fc2Type = W4A16GPTQMarlinLinear<T>; \ + return __VA_ARGS__(); \ + } else { \ + using LayerType = Layer<T>; \ + using Fc1Type = Linear<T, true, true>; \ + using Fc2Type = Linear<T>; \ + return __VA_ARGS__(); \ + } \ + }() + +Model* model; + +void init_base_model( + float memory_limit, + int vocab_size, + int num_hidden_layers, + int hidden_size, + int intermediate_size, + int num_attention_heads, + int num_key_value_heads, + int head_dim, + float rms_norm_eps, + int torch_dtype, + int chunk_length, + float scale_embed, + float scale_lmhead, + float scale_residual +) { + init_resources(); + + DTYPE_SWITCH(torch_dtype, [&] { + model = new ModelImpl<elem_type>( + memory_limit, + vocab_size, + num_hidden_layers, + hidden_size, + intermediate_size, + num_attention_heads, + num_key_value_heads, + head_dim, + rms_norm_eps, + chunk_length, + scale_embed, + scale_lmhead, + scale_residual + ); + }); + +} + +void init_minicpm4_model( + float memory_limit, + int vocab_size, + int num_hidden_layers, + int hidden_size, + int intermediate_size, + int num_attention_heads, + int num_key_value_heads, + int head_dim, + float rms_norm_eps, + int torch_dtype, + int chunk_length, + float scale_embed, + float scale_lmhead, + float scale_residual, + int sink_window_size, + int block_window_size, + int sparse_topk_k, + int sparse_switch, + bool apply_compress_lse +) { + init_resources(); + + DTYPE_SWITCH(torch_dtype, [&] { + model = new MiniCPM4Impl<elem_type>( + memory_limit, + vocab_size, + num_hidden_layers, + hidden_size, + intermediate_size, + num_attention_heads, + num_key_value_heads, + head_dim, + rms_norm_eps, + chunk_length, + scale_embed, + scale_lmhead, + scale_residual, + sink_window_size, + block_window_size, + sparse_topk_k, + sparse_switch, + apply_compress_lse + ); + }); + +} + +void init_w4a16_gptq_marlin_base_model( + float memory_limit, + int vocab_size, + int num_hidden_layers, + int hidden_size, + int intermediate_size, + int num_attention_heads, + int num_key_value_heads, + int head_dim, + float rms_norm_eps, + int group_size, + int torch_dtype, + int chunk_length, + float scale_embed, + float scale_lmhead, + float scale_residual +) { + init_resources(); + + DTYPE_SWITCH(torch_dtype, [&] { + model = new W4A16GPTQMarlinModelImpl<elem_type>( + memory_limit, + vocab_size, + num_hidden_layers, + hidden_size, + intermediate_size, + num_attention_heads, + num_key_value_heads, + head_dim, + rms_norm_eps, + group_size, + chunk_length, + scale_embed, + scale_lmhead, + scale_residual + ); + }); + +} + +void init_w4a16_gptq_marlin_minicpm4_model( + float memory_limit, + int vocab_size, + int num_hidden_layers, + int hidden_size, + int intermediate_size, + int num_attention_heads, + int num_key_value_heads, + int head_dim, + float rms_norm_eps, + int group_size, + int torch_dtype, + int chunk_length, + float scale_embed, + float scale_lmhead, + float scale_residual, + int sink_window_size, + int block_window_size, + int sparse_topk_k, + int sparse_switch, + bool apply_compress_lse +) { + init_resources(); + + DTYPE_SWITCH(torch_dtype, [&] { + model = new MiniCPM4W4A16GPTQMarlinModelImpl<elem_type>( + memory_limit, + vocab_size, + num_hidden_layers, + hidden_size, + intermediate_size, + num_attention_heads, + num_key_value_heads, + head_dim, + rms_norm_eps, + group_size, + chunk_length, + scale_embed, + scale_lmhead, + scale_residual, + sink_window_size, + block_window_size, + sparse_topk_k, + sparse_switch, + apply_compress_lse + ); + }); + +} + +// eagle model +void init_eagle_model( + int num_layers, + int num_iter, + int topk_per_iter, + int tree_size, + int torch_dtype +) { + bool dispatch_model = false; + DTYPE_SWITCH(torch_dtype, [&] { + MODEL_TYPE_SWITCH(model, elem_type, [&] { + dispatch_model = true; + model = new EagleImpl<elem_type, ModelType>( + typed_model, + num_layers, + num_iter, + topk_per_iter, + tree_size + ); + }); + }); + if (!dispatch_model) { + printf("Model type failed to dispatch: %s\n", typeid(*model).name()); + } +} + +void init_minicpm4_eagle_model( + int num_layers, + int num_iter, + int topk_per_iter, + int tree_size, + int torch_dtype, + bool apply_eagle_quant, + int group_size, + int eagle_window_size, + int frspec_vocab_size, + float residual_scale, + bool use_input_norm, + bool use_attn_norm +) { + bool dispatch_model = false; + DTYPE_SWITCH(torch_dtype, [&] { + MODEL_TYPE_SWITCH(model, elem_type, [&] { + dispatch_model = true; + EAGLE_QUANT_SWITCH(apply_eagle_quant, elem_type, [&] { + model = new MiniCPM4EagleImpl<elem_type, ModelType, LayerType, Fc1Type, Fc2Type>( + typed_model, + num_layers, + num_iter, + topk_per_iter, + tree_size, + group_size, + eagle_window_size, + frspec_vocab_size, + residual_scale, + use_input_norm, + use_attn_norm + ); + }); + }); + }); + if (!dispatch_model) { + printf("Model type failed to dispatch: %s\n", typeid(*model).name()); + } +} + +// spec model +void init_w4a16_gm_spec_w4a16_gm_model( + int draft_vocab_size, + int draft_num_hidden_layers, + int draft_hidden_size, + int draft_intermediate_size, + int draft_num_attention_heads, + int draft_num_key_value_heads, + int draft_head_dim, + float draft_rms_norm_eps, + int draft_group_size, + int num_iter, + bool draft_cuda_graph, + int torch_dtype +) { + DTYPE_SWITCH(torch_dtype, [&] { + model = new W4A16GMSpecW4A16GMImpl<elem_type>( + (W4A16GPTQMarlinModelImpl<elem_type>*)model, + draft_vocab_size, + draft_num_hidden_layers, + draft_hidden_size, + draft_intermediate_size, + draft_num_attention_heads, + draft_num_key_value_heads, + draft_head_dim, + draft_rms_norm_eps, + draft_group_size, + num_iter, + draft_cuda_graph + ); + }); +} + +// hier spec model +void init_hier_eagle_w4a16_gm_spec_w4a16_gm_model( + int draft_vocab_size, + int draft_num_hidden_layers, + int draft_hidden_size, + int draft_intermediate_size, + int draft_num_attention_heads, + int draft_num_key_value_heads, + int draft_head_dim, + float draft_rms_norm_eps, + int draft_group_size, + int min_draft_length, + bool draft_cuda_graph, + int ea_num_layers, + int ea_num_iter, + int ea_topk_per_iter, + int ea_tree_size, + bool draft_model_start, + int torch_dtype +) { + DTYPE_SWITCH(torch_dtype, [&] { + model = new HierEagleW4A16GMSpecW4A16GMImpl<elem_type>( + (W4A16GPTQMarlinModelImpl<elem_type>*)model, + draft_vocab_size, + draft_num_hidden_layers, + draft_hidden_size, + draft_intermediate_size, + draft_num_attention_heads, + draft_num_key_value_heads, + draft_head_dim, + draft_rms_norm_eps, + draft_group_size, + min_draft_length, + draft_cuda_graph, + ea_num_layers, + ea_num_iter, + ea_topk_per_iter, + ea_tree_size, + draft_model_start + ); + }); +} + +void init_hier_eagle_w4a16_gm_rot_spec_w4a16_gm_model( + int draft_vocab_size, + int draft_num_hidden_layers, + int draft_hidden_size, + int draft_intermediate_size, + int draft_num_attention_heads, + int draft_num_key_value_heads, + int draft_head_dim, + float draft_rms_norm_eps, + int draft_group_size, + int min_draft_length, + bool draft_cuda_graph, + int ea_num_layers, + int ea_num_iter, + int ea_topk_per_iter, + int ea_tree_size, + bool draft_model_start, + int torch_dtype +) { + DTYPE_SWITCH(torch_dtype, [&] { + model = new HierEagleW4A16GMRotSpecW4A16GMImpl<elem_type>( + (W4A16GPTQMarlinModelImpl<elem_type>*)model, + draft_vocab_size, + draft_num_hidden_layers, + draft_hidden_size, + draft_intermediate_size, + draft_num_attention_heads, + draft_num_key_value_heads, + draft_head_dim, + draft_rms_norm_eps, + draft_group_size, + min_draft_length, + draft_cuda_graph, + ea_num_layers, + ea_num_iter, + ea_topk_per_iter, + ea_tree_size, + draft_model_start + ); + }); +} + + +int init_storage() { + return model->init_storage(); +} + +void load_model(std::string name, std::uintptr_t param) { + model->load_to_storage(name, reinterpret_cast<void*>(param)); +} + +void prefill(int input_length, int history_length, std::uintptr_t input, std::uintptr_t position_ids, std::uintptr_t output) { + model->prefill(input_length, history_length, reinterpret_cast<int32_t*>(input), reinterpret_cast<int32_t*>(position_ids), (void*)(output)); +} + +void decode(int input_length, int padded_length, std::uintptr_t input, std::uintptr_t position_ids, std::uintptr_t cache_length, std::uintptr_t mask_2d, std::uintptr_t output, bool cuda_graph) { + if (cuda_graph) { + if (graphCreated_padding_length != padded_length || graphCreated_input_length != input_length) { + if (graphExec != nullptr) { + cudaGraphExecDestroy(graphExec); + graphExec = nullptr; + } + if (graph != nullptr) { + cudaGraphDestroy(graph); + graph = nullptr; + } + cudaStreamBeginCapture(calc_stream.stream, cudaStreamCaptureModeGlobal); + model->decode(input_length, padded_length, reinterpret_cast<int32_t*>(input), reinterpret_cast<int32_t*>(position_ids), reinterpret_cast<int32_t*>(cache_length), reinterpret_cast<uint64_t*>(mask_2d), reinterpret_cast<void*>(output)); + cudaStreamEndCapture(calc_stream.stream, &graph); + cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0); + graphCreated_padding_length = padded_length; + graphCreated_input_length = input_length; + } + cudaGraphLaunch(graphExec, calc_stream.stream); + } else { + model->decode(input_length, padded_length, reinterpret_cast<int32_t*>(input), reinterpret_cast<int32_t*>(position_ids), reinterpret_cast<int32_t*>(cache_length), reinterpret_cast<uint64_t*>(mask_2d), reinterpret_cast<void*>(output)); + } +} + +void draft(std::uintptr_t tree_draft_ids, std::uintptr_t tree_position_ids, std::uintptr_t cache_length, std::uintptr_t attn_mask, std::uintptr_t tree_parent) { + model->draft(reinterpret_cast<int32_t*>(tree_draft_ids), reinterpret_cast<int32_t*>(tree_position_ids), reinterpret_cast<int32_t*>(cache_length), reinterpret_cast<uint64_t*>(attn_mask), reinterpret_cast<int32_t*>(tree_parent)); +} + +int verify_and_fix(int num_tokens, std::uintptr_t pred, std::uintptr_t gt, std::uintptr_t position_ids, std::uintptr_t cache_length, std::uintptr_t attn_mask, std::uintptr_t tree_parent) { + return model->verify(num_tokens, reinterpret_cast<int32_t*>(pred), reinterpret_cast<int32_t*>(gt), reinterpret_cast<int32_t*>(position_ids), reinterpret_cast<int32_t*>(cache_length), reinterpret_cast<uint64_t*>(attn_mask), reinterpret_cast<int32_t*>(tree_parent)); +} + +void print_perf_summary() { + perf_summary(); +} + +PYBIND11_MODULE(C, m) { + // base bind + m.def("init_base_model", &init_base_model, "Init base model"); + m.def("init_minicpm4_model", &init_minicpm4_model, "Init minicpm4 model"); + m.def("init_w4a16_gptq_marlin_base_model", &init_w4a16_gptq_marlin_base_model, "Init W4A16 base model"); + m.def("init_w4a16_gptq_marlin_minicpm4_model", &init_w4a16_gptq_marlin_minicpm4_model, "Init W4A16 base model"); + + // eagle bind + m.def("init_eagle_model", &init_eagle_model, "Init eagle model"); + m.def("init_minicpm4_eagle_model", &init_minicpm4_eagle_model, "Init minicpm4 eagle model"); + // spec bind + m.def("init_w4a16_gm_spec_w4a16_gm_model", &init_w4a16_gm_spec_w4a16_gm_model, "Init w4a16 spec v1 model"); + + // hier spec bind + m.def("init_hier_eagle_w4a16_gm_spec_w4a16_gm_model", &init_hier_eagle_w4a16_gm_spec_w4a16_gm_model, "init hier eagle gm spec gm model"); + m.def("init_hier_eagle_w4a16_gm_rot_spec_w4a16_gm_model", &init_hier_eagle_w4a16_gm_rot_spec_w4a16_gm_model, "init hier eagle rot gm spec gm model"); + + // interface + m.def("init_storage", &init_storage, "Init storage"); + m.def("load_model", &load_model, "Load model"); + m.def("prefill", &prefill, "Prefill"); + m.def("decode", &decode, "Decode"); + m.def("draft", &draft, "Draft"); + m.def("verify_and_fix", &verify_and_fix, "Verify and fix"); + m.def("print_perf_summary", &print_perf_summary, "Print perf summary"); +} \ No newline at end of file diff --git a/examples/CPM.cu/src/flash_attn/flash_api.hpp b/examples/CPM.cu/src/flash_attn/flash_api.hpp new file mode 100644 index 00000000..f0ee07e6 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/flash_api.hpp @@ -0,0 +1,392 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#include <cutlass/numeric_types.h> + +#include "flash.h" +#include "static_switch.h" +#include "../model/mask.cuh" + +void set_params_fprop(Flash_fwd_params ¶ms, + bool is_bf16, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + void* q, + void* k, + void* v, + void* out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_k, + void *p_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + const float softcap, + bool seqlenq_ngroups_swapped=false, + const bool unpadded_lse=false) { + + // Reset the parameters + params = {}; + + params.is_bf16 = is_bf16; + + // Set the pointers and strides. + params.q_ptr = q; + params.k_ptr = k; + params.v_ptr = v; + // All stride are in elements, not bytes. + params.q_row_stride = h * d; + params.k_row_stride = h_k * d; + params.v_row_stride = h_k * d; + params.q_head_stride = d; + params.k_head_stride = d; + params.v_head_stride = d; + params.o_ptr = out; + params.o_row_stride = h * d; + params.o_head_stride = d; + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = seqlen_q * h * d; + params.k_batch_stride = seqlen_k * h_k * d; + params.v_batch_stride = seqlen_k * h_k * d; + params.o_batch_stride = seqlen_q * h * d; + if (seqlenq_ngroups_swapped) { + params.q_batch_stride *= seqlen_q; + params.o_batch_stride *= seqlen_q; + } + } + + params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d); + params.seqused_k = static_cast<int *>(seqused_k); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. + if (softcap > 0.0) { + params.softcap = softmax_scale / softcap; + params.scale_softmax = softcap; + params.scale_softmax_log2 = softcap * M_LOG2E; + } else{ + // Remove potential NaN + params.softcap = 0.0; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + } + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + params.is_causal = window_size_left < 0 && window_size_right == 0; + + if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } + if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + params.is_seqlens_k_cumulative = true; + + params.unpadded_lse = unpadded_lse; + params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; +} + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { + FP16_SWITCH(!params.is_bf16, [&] { + HEADDIM_SWITCH(params.d, [&] { + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream); + } else { + run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream); + } + }); + }); + }); +} + +// Find the number of splits that maximizes the occupancy. For example, if we have +// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is +// better than having 3 splits (efficiency = 0.67). However, we also don't want too many +// splits as that would incur more HBM reads/writes. +// So we find the best efficiency, then find the smallest number of splits that gets 85% +// of the best efficiency. +inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) { + // If we have enough to almost fill the SMs, then just use 1 split + if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; } + max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + float max_efficiency = 0.f; + std::vector<float> efficiency; + efficiency.reserve(max_splits); + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, + // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks + // (i.e. it's 11 splits anyway). + // So we check if the number of blocks per split is the same as the previous num_splits. + auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { + return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); + }; + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { + efficiency.push_back(0.f); + } else { + float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if (eff > max_efficiency) { max_efficiency = eff; } + efficiency.push_back(eff); + } + } + for (int num_splits = 1; num_splits <= max_splits; num_splits++) { + if (!is_split_eligible(num_splits)) { continue; } + if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) { + // printf("num_splits chosen = %d\n", num_splits); + return num_splits; + } + } + return 1; +} + +int set_params_splitkv(const int batch_size, const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q) { + cudaDeviceProp dprops; + cudaGetDeviceProperties(&dprops, 0); + + // This needs to match with run_mha_fwd_splitkv_dispatch + const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64); + const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n; + // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel. + // In any case we don't expect seqlen_q to be larger than 64 for inference. + const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64; + + // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block. + int num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops.multiProcessorCount * 2, num_n_blocks, 128); + + return num_splits; +} + +void mha_fwd_stage1( + bool is_bf16, + int batch_size, + int seqlen_q, + int seqlen_k, + int seqlen_c, + int num_heads, + int num_heads_k, + int head_size, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size + int* seqlens_k, // batch_size + void* p, // batch_size x seqlen_q x seqlen_k + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + cudaStream_t stream, + int& q_round, + int& k_round +) { + // causal=true is the same as causal=false in this case + if (seqlen_q == 1) { is_causal = false; } + if (is_causal) { window_size_right = 0; } + + seqlen_q *= 16; + num_heads /= 16; + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + q_round = seqlen_q_rounded / 16; + k_round = seqlen_k_rounded; + + Flash_fwd_params params; + set_params_fprop(params, + is_bf16, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, kcache, vcache, nullptr, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + /*p_ptr=*/p, + /*softmax_lse=*/nullptr, + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + softcap + ); + + params.blockmask = nullptr; + params.block_window_size = 0; + params.m_block_dim = 16; + params.n_block_dim = 1; + + params.seqlen_c = seqlen_c; + + params.mask_2d = nullptr; + params.mask_q_range = 0; + params.mask_k_range = 0; + + params.rotary_dim = 0; + + params.page_block_size = 1; + + params.alibi_slopes_ptr = nullptr; + + if (seqlens_k != nullptr) { + params.cu_seqlens_k = seqlens_k; + params.is_seqlens_k_cumulative = false; + } + params.num_splits = 1; + + run_mha_fwd(params, stream, true); +} + +void mha_fwd_kvcache( + bool is_bf16, + int batch_size, + int seqlen_q, + int seqlen_k, + int num_heads, + int num_heads_k, + int head_size, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* kcache, // batch_size x seqlen_k x num_heads_k x head_size + void* vcache, // batch_size x seqlen_k x num_heads_k x head_size + int* seqlens_k, // batch_size + const Mask& mask, // batch_size x seqlen_q x seqlen_k_range + void* out, // batch_size x seqlen_q x num_heads x head_size + float* softmax_lse, // batch_size x num_heads x seqlen_q + float* softmax_lse_accum, // num_splits x batch_size x num_heads x seqlen_q + float* oaccum, // num_splits x batch_size x num_heads x seqlen_q x head_size + const float softmax_scale, + bool is_causal, + int window_size_left, + int window_size_right, + const float softcap, + cudaStream_t stream, + uint64_t* blockmask = nullptr, + int block_window_size = 0 +) { + // causal=true is the same as causal=false in this case + if (seqlen_q == 1) { is_causal = false; } + if (is_causal) { window_size_right = 0; } + + if (blockmask != nullptr) { // TODO improve this + // if (blockmask != nullptr || block_window_size > 0) { + seqlen_q *= 16; + num_heads /= 16; + } + + if (window_size_left >= seqlen_k) { window_size_left = -1; } + if (window_size_right >= seqlen_k) { window_size_right = -1; } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256; + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + Flash_fwd_params params; + set_params_fprop(params, + is_bf16, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, kcache, vcache, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + /*p_ptr=*/nullptr, + (void*)softmax_lse, + /*p_dropout=*/0.f, + softmax_scale, + window_size_left, + window_size_right, + softcap + ); + + params.blockmask = blockmask; + if (blockmask != nullptr) { // TODO improve this + // if (blockmask != nullptr || block_window_size > 0) { + params.m_block_dim = 16; + params.n_block_dim = 64; + params.num_blocks_m = (seqlen_q + 16 - 1) / 16; + params.num_blocks_n = (seqlen_k + 64 - 1) / 64; + } else { + params.m_block_dim = 1; + params.n_block_dim = 1; + } + params.block_window_size = block_window_size; + + params.mask_2d = mask.ptr; + params.mask_q_range = mask.mask_q_range; + params.mask_k_range = mask.mask_k_range; + + params.rotary_dim = 0; + + params.softmax_lseaccum_ptr = (void*)softmax_lse_accum; + params.oaccum_ptr = (void*)oaccum; + + params.page_block_size = 1; + + params.alibi_slopes_ptr = nullptr; + + if (seqlens_k != nullptr) { + params.cu_seqlens_k = seqlens_k; + params.is_seqlens_k_cumulative = false; + params.num_splits = 16; + } else { + params.num_splits = 1; + } + + run_mha_fwd(params, stream, true); +} \ No newline at end of file diff --git a/examples/CPM.cu/src/flash_attn/src/alibi.h b/examples/CPM.cu/src/flash_attn/src/alibi.h new file mode 100644 index 00000000..e714233e --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/alibi.h @@ -0,0 +1,74 @@ +#include <cmath> + +#include <cute/tensor.hpp> + +#include <cutlass/cutlass.h> +#include <cutlass/array.h> + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <bool Is_causal> +struct Alibi { + + const float alibi_slope; + const int max_seqlen_k, max_seqlen_q; + + __forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q) + : alibi_slope(alibi_slope) + , max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) { + }; + + + template <typename Engine, typename Layout> + __forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + } + } + } else { // Bias depends on both row_idx and col_idx + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx); + } + } + } + } + } + } + +}; + +} // namespace flash diff --git a/examples/CPM.cu/src/flash_attn/src/block_info.h b/examples/CPM.cu/src/flash_attn/src/block_info.h new file mode 100644 index 00000000..a33811a0 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/block_info.h @@ -0,0 +1,54 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<bool Varlen=true> +struct BlockInfo { + + template<typename Params> + __device__ BlockInfo(const Params ¶ms, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) + , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + { + } + + template <typename index_t> + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template <typename index_t> + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride; + } + + + template <typename index_t> + inline __device__ index_t blockmask_q_offset(const index_t m_block_dim, const int bidb) const { + return sum_s_q == -1 ? bidb * (actual_seqlen_q / m_block_dim) : uint32_t(sum_s_q) / m_block_dim; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int leftpad_k; + const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/examples/CPM.cu/src/flash_attn/src/dropout.h b/examples/CPM.cu/src/flash_attn/src/dropout.h new file mode 100644 index 00000000..4882f97d --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/dropout.h @@ -0,0 +1,94 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "philox.cuh" +#include "utils.h" + +namespace flash { + +struct Dropout { + + const unsigned long long seed, offset; + const uint8_t p_dropout_in_uint8_t; + + __forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset, + const uint8_t p_dropout_in_uint8_t, + const int bid, const int hid, const int tid, const int nheads) + : seed(seed) + , offset(offset + (bid * nheads + hid) * 32 + tid % 32) + , p_dropout_in_uint8_t(p_dropout_in_uint8_t) { + } + + template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout> + __forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_, + int block_row_start, int block_col_start, int block_row_stride) { + // convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2) + Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_dropout(tensor_.layout())); + using T = typename Engine::value_type; + auto encode_dropout = [](bool keep, T val) { + return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0)); + }; + static_assert(decltype(size<2>(tensor))::value % 2 == 0); + const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t); + const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t); + // if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); } + #pragma unroll + for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) { + uint2 rowcol = make_uint2(block_row_start, block_col_start); + #pragma unroll + for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) { + // if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));} + uint4 random_uint4 = flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset); + // if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);} + uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4); + // Special implementation for 16-bit types: we duplicate the threshold to the + // low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction + // to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000, + // and the high 16 bits will be either 0xffff or 0x0000, depending on whether + // the random value is less than the threshold. + // We then do a bit-wise AND between the mask and the original value (in 32-bit). + // We're exploiting the fact that floating point comparison is equivalent to integer + // comparison, since we're comparing unsigned integers whose top 8-bits are zero. + if (!encode_dropout_in_sign_bit + && (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) { + uint16_t rnd_16[16]; + #pragma unroll + for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); } + uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16); + #pragma unroll + for (int j = 0; j < 2; j++) { + Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + #pragma unroll + for (int i = 0; i < 4; i++) { + uint32_t mask; + asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t)); + tensor_uint32(i) &= mask; + } + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } else { + #pragma unroll + for (int j = 0; j < 2; j++) { + #pragma unroll + for (int i = 0; i < 8; i++) { + tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j)); + } + Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j)); + // if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); } + } + } + // // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { + // // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w); + // // } + } + } + } + +}; + +} // namespace flash diff --git a/examples/CPM.cu/src/flash_attn/src/flash.h b/examples/CPM.cu/src/flash_attn/src/flash.h new file mode 100644 index 00000000..3943b6a9 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash.h @@ -0,0 +1,193 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include <cuda.h> +#include <vector> + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; + int seqlen_c; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + int * __restrict__ leftpad_k; + + // If provided, the actual length of each k sequence. + int * __restrict__ seqused_k; + + uint64_t *__restrict__ blockmask; + int m_block_dim, n_block_dim; + int num_blocks_m, num_blocks_n; + int block_window_size; + + // The mask_2d matrix. + uint64_t *__restrict__ mask_2d; + int mask_q_range; + int mask_k_range; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int * __restrict__ cache_batch_idx; + + // Paged KV cache + int * __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Local window size + int window_size_left, window_size_right; + float softcap; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t * rng_state; + + bool is_bf16; + bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; + + bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. + bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + + // The dO and dQKV matrices. + void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. + void *__restrict__ dsoftmax_sum; + + bool deterministic; + index_t dq_accum_split_stride; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template<typename T, int Headdim, bool Is_causal> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); + +template<typename T, int Headdim, bool Is_causal> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_blockmask.h b/examples/CPM.cu/src/flash_attn/src/flash_blockmask.h new file mode 100644 index 00000000..fe1782c3 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_blockmask.h @@ -0,0 +1,108 @@ +#pragma once + +namespace flash { + +class fwdIterator{ + public: + template<typename Params, typename BlockInfo> + __device__ fwdIterator(const Params ¶ms, const BlockInfo &binfo, const int kBlockM, const int kBlockN, const int batch_idx, const int head_idx, const int loop_step_idx, int n_block_min, int n_block_max) {//row first + this->cache_seqlen_k = binfo.actual_seqlen_k - binfo.actual_seqlen_q / params.m_block_dim; + this->max_block_idx = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + this->n_block_min = n_block_min; + this->n_block_max = n_block_max; + this->batch_idx = batch_idx; // Store batch_idx for debugging + this->head_idx = head_idx; + + const int q_block_idx = loop_step_idx + cache_seqlen_k; + if (params.blockmask != nullptr) { + // Calculate the offset for the uint64 blockmask + const int num_blocks_m = params.num_blocks_m; + const int num_blocks_n = params.num_blocks_n; + const int uint64_per_row = (num_blocks_n + 64 - 1) / 64; + const int row_offset = params.cu_seqlens_q != nullptr ? binfo.blockmask_q_offset(params.m_block_dim, batch_idx) : batch_idx * params.h_k * params.num_blocks_m; + + blockmask_ptr = params.blockmask + + head_idx * params.num_blocks_m * uint64_per_row + + row_offset * uint64_per_row + + loop_step_idx * uint64_per_row; + + this->k_window_left = params.block_window_size > 0 ? (q_block_idx + kBlockN - 1) / kBlockN - params.block_window_size : 2147483647; + } else { + blockmask_ptr = nullptr; + this->k_window_left = params.block_window_size > 0 ? (q_block_idx + kBlockN - 1) / kBlockN - params.block_window_size / kBlockN : -1; + } + } + + __device__ int _max_no_larger(int target) const { + if(max_block_idx == 0){ + return -1; + }; + if (target < 0) return -1; + + if (blockmask_ptr == nullptr) { + if (k_window_left <= target) return target; // sliding window + return -1; + } + + if (k_window_left <= target) { + return target; + } + + // 目标值不能超过最大块索引 + target = min(target, max_block_idx - 1); + + // 计算相对于当前q_bit_position的实际位置 + int target_bit_pos = target; + + // 确定此块在哪个uint64中 + int uint64_offset = target_bit_pos / 64; + + // 确定此块在uint64中的哪一位 + int bit_pos = target_bit_pos % 64; + + // 创建一个掩码,保留target及更低位的所有位 + uint64_t mask = bit_pos != 63 ? (1ULL << (bit_pos + 1)) - 1 : 0xFFFFFFFFFFFFFFFFULL; + + // 检查当前uint64中target及以下的位 + uint64_t value = blockmask_ptr[uint64_offset] & mask; + + // 如果当前uint64中有设置的位 + if (value != 0) { + // 找到最高位的1(即不大于target的最大设置位) + int highest_bit = 63 - __clzll(value); // __clzll计算前导0的数量 + int result = highest_bit + (uint64_offset * 64); + return result; + } + + // 如果当前uint64中没有找到,检查更低的uint64块 + for (int i = uint64_offset - 1; i >= 0; i--) { + value = blockmask_ptr[i]; + if (value != 0) { + // 找到最高位的1 + int highest_bit = 63 - __clzll(value); + // 计算相对于q_bit_position的偏移 + int result = highest_bit + (i * 64); + return result; + } + } + + // 没有找到设置位 + return -1; + } + + __device__ int max_no_larger(int target) const { + int res = _max_no_larger(target); + return res < this->n_block_min ? -1 : res; + } + + uint64_t *blockmask_ptr; + int row_offset; // 行偏移量 + int uint64_per_row; // 每行使用的uint64数量 + int cache_seqlen_k; + int max_block_idx; + int n_block_min, n_block_max; + int batch_idx, head_idx; + int k_window_left; +}; + +} // namespace flash \ No newline at end of file diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu new file mode 100644 index 00000000..9383c102 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 128, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128<cutlass::bfloat16_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu new file mode 100644 index 00000000..f03abda4 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 128, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128<cutlass::bfloat16_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu new file mode 100644 index 00000000..c616628c --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 128, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128<cutlass::half_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 00000000..4ff6b9fb --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 128, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128<cutlass::half_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu new file mode 100644 index 00000000..d6d4371b --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 160, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160<cutlass::bfloat16_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu new file mode 100644 index 00000000..5af68ac3 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 160, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160<cutlass::bfloat16_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu new file mode 100644 index 00000000..1ef511a6 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 160, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160<cutlass::half_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu new file mode 100644 index 00000000..96abfbd8 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 160, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160<cutlass::half_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu new file mode 100644 index 00000000..077d25d0 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 192, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192<cutlass::bfloat16_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu new file mode 100644 index 00000000..ea5f265f --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 192, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192<cutlass::bfloat16_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu new file mode 100644 index 00000000..a4a7bc24 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 192, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192<cutlass::half_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu new file mode 100644 index 00000000..c30c4a14 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 192, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192<cutlass::half_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu new file mode 100644 index 00000000..a12a5f4a --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 224, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224<cutlass::bfloat16_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu new file mode 100644 index 00000000..8690bdb1 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 224, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224<cutlass::bfloat16_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu new file mode 100644 index 00000000..f01dad09 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 224, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224<cutlass::half_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu new file mode 100644 index 00000000..7ec1e16b --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 224, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224<cutlass::half_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu new file mode 100644 index 00000000..f84e978c --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 256, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256<cutlass::bfloat16_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu new file mode 100644 index 00000000..c52f0417 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 256, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256<cutlass::bfloat16_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu new file mode 100644 index 00000000..f96f7edc --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 256, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256<cutlass::half_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 00000000..9c7c6b93 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 256, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256<cutlass::half_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu new file mode 100644 index 00000000..e21d0408 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 32, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32<cutlass::bfloat16_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu new file mode 100644 index 00000000..f377a5b8 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 32, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32<cutlass::bfloat16_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu new file mode 100644 index 00000000..74e4d66a --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 32, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32<cutlass::half_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu new file mode 100644 index 00000000..e85db18e --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 32, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32<cutlass::half_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu new file mode 100644 index 00000000..9297e8bb --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 64, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64<cutlass::bfloat16_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu new file mode 100644 index 00000000..8364b1e7 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 64, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64<cutlass::bfloat16_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu new file mode 100644 index 00000000..1c6ed7ef --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 64, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64<cutlass::half_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 00000000..3c87573b --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 64, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64<cutlass::half_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu new file mode 100644 index 00000000..49fae856 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 96, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96<cutlass::bfloat16_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu new file mode 100644 index 00000000..c5af1cf6 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::bfloat16_t, 96, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96<cutlass::bfloat16_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu new file mode 100644 index 00000000..b0d6c992 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 96, true>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96<cutlass::half_t, true>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 00000000..c97aa33f --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<cutlass::half_t, 96, false>(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96<cutlass::half_t, false>(params, stream); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_kernel.h b/examples/CPM.cu/src/flash_attn/src/flash_fwd_kernel.h new file mode 100644 index 00000000..a1ad2adc --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_kernel.h @@ -0,0 +1,2503 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include <cute/tensor.hpp> + +#include <cutlass/cutlass.h> +#include <cutlass/array.h> +#include <cutlass/numeric_types.h> + +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" +#include "softmax.h" +#include "mask.h" +#include "dropout.h" +#include "rotary.h" +#include "flash_blockmask.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +__forceinline__ __device__ float minusinf_to_zero(float &input) { + if (input == -INFINITY) { + return 0; + } else { + return input; + } +} + +template<int thr_offset, typename Engine0, typename Layout0, typename Operator> +__forceinline__ __device__ void thread_element_wise_reduce_(Tensor<Engine0, Layout0> &tensor, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ni++) { + tensor(mi, ni) = op( + minusinf_to_zero(tensor(mi, ni)), + __shfl_xor_sync(uint32_t(-1), minusinf_to_zero(tensor(mi, ni)), thr_offset) + ); + } + } +} + +template <typename Element, typename E1, typename L1, typename E2, typename L2> +__forceinline__ __device__ void hdim16_reduce( + Tensor<E1, L1> &acc_S, + Tensor<E2, L2> &g_Sh, + const int col_idx_offset_, + const int row_idx_offset_, + const int warp_row_stride) { + + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) (or (_2,_2),_2,_16) for D=32) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + auto tensor = make_tensor(acc_S.data(), flash::convert_layout_acc_rowcol(acc_S.layout())); + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; + // const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + // const int row_idx_offset = row_idx_offset_ + warp_id * 16 + lane_id / 4; + const int col_idx_offset = 0 + (lane_id % 4) * 2; + const int row_idx_offset = 0 + warp_id * 16 + lane_id / 4; + + // step 1: 线程内部求和 (v0 + v2) + using TensorT = decltype(make_tensor<float>(Shape< Int<size<0, 1>(tensor)>, Int<size<1>(tensor)> >{})); + TensorT v02; + clear(v02); + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx_v02 = j * size<1, 1>(tensor) + nj; + const float v0 = minusinf_to_zero(tensor(make_coord(0, mi), make_coord(j, nj))); + const float v2 = minusinf_to_zero(tensor(make_coord(1, mi), make_coord(j, nj))); + v02(mi, col_idx_v02) = v0 + v2; + } + } + } + + // step 2: warp 内部蝶形求和 + SumOp<float> sum_op; + thread_element_wise_reduce_<16>(v02, sum_op); + thread_element_wise_reduce_<8> (v02, sum_op); + thread_element_wise_reduce_<4> (v02, sum_op); + + // step 3: copy 到 global mem + cutlass::NumericConverter<Element, float> converter; + if (lane_id < 4) { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + const int col_idx_v02 = j * size<1, 1>(tensor) + nj; + g_Sh(row_idx_base/16, col_idx) = converter(v02(mi, col_idx_v02)); // ignore /16 since it's too slow + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN> +__forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bidb, const int bidh, const int m_block, const BlockInfo</*Varlen=*/!Is_even_MN> &binfo) { + // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path. + // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick. + // Otherwise, it's written as (h, b, seqlen_q). + const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped; + auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0; + auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + lse_offset); + + auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q); + auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : ( + params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1) + ); + + auto lse_layout = make_layout(lse_shape, lse_stride); + Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout); + auto mLSE_slice = varlen_q ? mLSE(0, bidh, _) : mLSE(bidb, bidh, _); + return local_tile(mLSE_slice, Shape<Int<kBlockM>>{}, make_coord(m_block)); +} + + +template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params> +inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + // do not support dropout + flash::Dropout dropout(0, 0, params.p_dropout_in_uint8_t, + bidb, bidh, tidx, params.h); + + // Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might + // exit early and no one saves the rng states. + if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) { + params.rng_state[0] = 0; + params.rng_state[1] = 0; + } + + const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right, kBlockN)); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { + // printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max); + // } + } + // We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0. + // Otherwise we might read OOB elements from gK and gV. + if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) { + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + + Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_tensor<Element>(shape(tOgO)); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gO), size<1>(gO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgO); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSE(row) = INFINITY; } + } + return; + } + // if (tidx == 0) { printf("m_block = %d, n_block_min = %d, n_block_max = %d\n", m_block, n_block_min, n_block_max); } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) + + binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.k_row_stride, params.k_head_stride, _1{})); + Tensor gK = local_tile(mK(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) + + binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)), + make_shape(binfo.actual_seqlen_k, params.h_k, params.d), + make_stride(params.v_row_stride, params.v_head_stride, _1{})); + Tensor gV = local_tile(mV(_, bidh / params.h_h_k_ratio, _), Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_coord(_, 0)); // (kBlockN, kHeadDim, nblocksN) + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p), + Shape<Int<kBlockM>, Int<kBlockN>>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K, nblocksN) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K, nblocksN) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor tSgS = thr_mma.partition_C(gP); + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + // if (cute::thread0()) {smem_thr_copy_Q.print_all();} + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");} + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + // Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K) + // if (cute::thread0()) { + // print(tScQ.layout()); printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<0>(tScQ(i))); + // } + // printf("\n"); + // for (int i = 0; i < size(tScQ); ++i) { + // printf("%d ", get<1>(tScQ(i))); + // } + // printf("\n"); + // } + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); } + + // // if (cute::thread(1, 0)) { print(tQsQ); } + // // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{}); + // // if (cute::thread0()) { print(sQNoSwizzle); } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block), tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); } + // __syncthreads(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask<Is_causal, Is_local, Has_alibi, /*Mask_2d=*/false> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>( + gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + mask.template apply_mask<Is_causal, Is_even_MN>( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2); + + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type<Element>(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor rP_drop = make_fragment_like(rP); + cute::copy(rP, rP_drop); + dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>( + rP_drop, block_row_idx, block_col_idx, kNWarps + ); + cute::copy(rP_drop, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + if (Is_dropout) { + dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout())); + // if (cute::thread0()) { print(tOrP); } + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + // if (cute::thread0()) { print(scores); } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= n_block_min; --n_block) { + Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV(_, _, _, n_block), tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > n_block_min) { + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK(_, _, _, n_block - 1), tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask</*Causal_mask=*/false>( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type<Element>(acc_s); + int block_row_idx = m_block * (kBlockM / 16) + tidx / 32; + int block_col_idx = n_block * (kBlockN / 32); + if (Return_softmax) { + Tensor rP_drop = make_fragment_like(rP); + cute::copy(rP, rP_drop); + dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>( + rP_drop, block_row_idx, block_col_idx, kNWarps + ); + cute::copy(rP_drop, tSgS); + tSgS.data() = tSgS.data() + (-kBlockN); + } + if (Is_dropout) { + dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps); + } + + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout())); + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout); + + // Convert acc_o from fp32 to fp16/bf16 + Tensor rO = flash::convert_type<Element>(acc_o); + Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); + Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + Tensor mO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr) + + binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.o_row_stride, params.o_head_stride, _1{})); + Tensor gO = local_tile(mO(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor gLSE = get_lse_tile<ElementAccum, Params, kBlockM, Is_even_MN>(params, bidb, bidh, m_block, binfo); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + Tensor tOrO = make_tensor<Element>(shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params> +inline __device__ void compute_attn_1rowblock_splitkv_headlevel(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyO, + typename Kernel_traits::GmemTiledCopyOaccum + >; + using ElementO = std::conditional_t<!Split, Element, ElementAccum>; + + const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM / params.m_block_dim + binfo.actual_seqlen_k - (binfo.actual_seqlen_q / params.m_block_dim) + params.window_size_right, kBlockN)); + } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride * kBlockM; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_stride(Split ? kHeadDim * params.seqlen_q / kBlockM : params.o_head_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape<Int<kBlockM>>{}, make_stride(params.seqlen_q / kBlockM)); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; + const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; + const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; + const index_t row_offset_k = block_table == nullptr + ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = block_table == nullptr + ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q / kBlockM, params.h, kBlockM, params.d), + make_stride(params.q_row_stride * kBlockM, params.q_head_stride * kBlockM, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(m_block, bidh, _, _), Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_coord(0, 0)); // (kBlockM, kHeadDim) + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), + Shape<Int<kBlockN>, Int<kHeadDim / 2>>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin), + Shape<Int<kBlockN>, Int<kHeadDim / 2>>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + const index_t row_offset_knew = bidb * params.knew_batch_stride + + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + const index_t row_offset_vnew = bidb * params.vnew_batch_stride + + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr) + + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr) + + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx<Is_even_K>( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx<Is_even_K>( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + + } + } + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + if (n_block > n_block_copy_min) { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; + const int offset_diff = block_table_offset_next - block_table_offset_cur; + tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; + tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + } + } + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; + } + + // Read Q from gmem to smem, optionally apply rotary embedding. + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim / 2>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim / 2>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved<Is_even_K>( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } else { + flash::copy_rotary_contiguous<Is_even_K>( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } + // __syncthreads(); + + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask<Is_causal, Is_local, Has_alibi, /*Mask_2d=*/true> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope, params.mask_2d, params.mask_q_range, params.mask_k_range, params.m_block_dim); + fwdIterator blockmask(params, binfo, kBlockM, kBlockN, bidb, bidh, m_block, n_block_min, n_block_max); + int next_block_idx = blockmask.max_no_larger(n_block_max-1); + int leap = 0; + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + const bool skip = (n_block != next_block_idx); + Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + if (!skip) { + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + mask.template apply_mask<Is_causal, Is_even_MN>( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + next_block_idx = blockmask.max_no_larger(n_block-1); + } else { + mask.all_mask(acc_s); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + leap = (masking_step + 1 == n_masking_steps) ? n_block - next_block_idx : 1; + + if (n_block > n_block_min && next_block_idx != -1) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * leap * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - leap) * kBlockN / params.page_block_size; + const int block_table_offset_next =(n_block - leap) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + if (!skip) { + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type<Element>(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout())); + + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + leap = n_block - next_block_idx + 1; + + // These are the iterations where we don't need masking on S + for (n_block = next_block_idx; n_block != -1 && n_block >= n_block_min; n_block = next_block_idx) { + next_block_idx = blockmask.max_no_larger(n_block - 1); + + Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * leap * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + leap) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + leap) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + + leap = n_block - next_block_idx; + if (next_block_idx != -1) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * leap * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - leap) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - leap) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask</*Causal_mask=*/false>( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type<Element>(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout())); + + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax); + // if (cute::thread0()) { print(lse); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type<ElementO>(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { __syncthreads(); } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride * kBlockM; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block) * params.d_rounded; + const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ? + ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb) + ) + m_block; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_stride(Split ? kHeadDim * params.seqlen_q / kBlockM : params.o_head_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape<Int<kBlockM>>{}, make_stride(params.seqlen_q / kBlockM)); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params> +inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + using GmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::GmemTiledCopyO, + typename Kernel_traits::GmemTiledCopyOaccum + >; + using ElementO = std::conditional_t<!Split, Element, ElementAccum>; + + const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal || Is_local) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM / params.m_block_dim + binfo.actual_seqlen_k - (binfo.actual_seqlen_q / params.m_block_dim) + params.window_size_right, kBlockN)); + } + if (n_block_min >= n_block_max) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape<Int<kBlockM>>{}, Stride<_1>{}); + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum)); + clear(tOrOaccum); + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(gOaccum), size<1>(gOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); + Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOgOaccum); ++m) { + const int row = get<0>(tOcO(0, m, 0)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM && get<1>(tOcO(0, m, 0)) == 0) { gLSEaccum(row) = Split ? -INFINITY : INFINITY; } + } + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; + const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; + const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; + const index_t row_offset_k = block_table == nullptr + ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = block_table == nullptr + ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q, params.h, params.d), + make_stride(params.q_row_stride, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(_, bidh, _), Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_coord(m_block, 0)); // (kBlockM, kHeadDim) + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.k_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.v_row_stride, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); + Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + if constexpr (Append_KV) { + // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to + // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. + // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2); + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), + Shape<Int<kBlockN>, Int<kHeadDim / 2>>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin), + Shape<Int<kBlockN>, Int<kHeadDim / 2>>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.rotary_dim / 2, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); } + // if (cute::thread(8, 0)) { print_tensor(gCos); } + // if (cute::thread(0, 0)) { print_tensor(tRgCos); } + + // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + const index_t row_offset_knew = bidb * params.knew_batch_stride + + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; + // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + const index_t row_offset_vnew = bidb * params.vnew_batch_stride + + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; + // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, + // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. + // This maps to accessing the first 64 rows of knew_ptr. + Tensor gKnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.knew_ptr) + + row_offset_knew - binfo.seqlen_k_cache * params.knew_row_stride), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.knew_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("knew_ptr = %p, row_offset_knew = %d, gKnew_ptr = %p\n", params.knew_ptr, row_offset_knew, gKnew.data()); } + Tensor gVnew = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.vnew_ptr) + + row_offset_vnew - binfo.seqlen_k_cache * params.vnew_row_stride), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.vnew_row_stride, _1{})); + Tensor tKgKnew = gmem_thr_copy_QKV.partition_S(gKnew); // (KCPY, KCPY_N, KCPY_K) + Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K) + + const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN); + auto tKgK_data = tKgK.data(); + auto tVgV_data = tVgV.data(); + for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) { + flash::copy_w_min_idx<Is_even_K>( + tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride)); + if (params.rotary_dim == 0) { + flash::copy_w_min_idx<Is_even_K>( + tKgKnew, tKgK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN + ); + } else { + if (params.is_rotary_interleaved) { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_interleaved<Is_even_K, /*Clear_OOB_K=*/false>( + tKgKnew, tKgK, tRgCos, tRgSin, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCos.data() = tRgCos.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSin.data() = tRgSin.data() + (-int(kBlockN * params.rotary_dim / 2)); + } else { + // Don't clear OOB_K because we're writing to global memory + flash::copy_rotary_contiguous<Is_even_K, /*Clear_OOB_K=*/false>( + tKgKnew, tKgK, tRgCosCont, tRgSinCont, tKVcKV, binfo.actual_seqlen_k - n_block * kBlockN, + binfo.seqlen_k_cache - n_block * kBlockN, params.d, params.rotary_dim + ); + tRgCosCont.data() = tRgCosCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + tRgSinCont.data() = tRgSinCont.data() + (-int(kBlockN * params.rotary_dim / 2)); + + } + } + tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride)); + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + } else { + if (n_block > n_block_copy_min) { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size; + const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur]; + const int offset_diff = block_table_offset_next - block_table_offset_cur; + tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride; + tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride; + } + } + } + // Need this before we can read in K again, so that we'll see the updated K values. + __syncthreads(); + tKgK.data() = tKgK_data; + tVgV.data() = tVgV_data; + } + + // Read Q from gmem to smem, optionally apply rotary embedding. + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim / 2>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim / 2>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved<Is_even_K>( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } else { + flash::copy_rotary_contiguous<Is_even_K>( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + // flash::cp_async_wait<0>(); + // __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tKsK); } + // __syncthreads(); + + clear(acc_o); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask<Is_causal, Is_local, Has_alibi, /*Mask_2d=*/true> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope, params.mask_2d, params.mask_q_range, params.mask_k_range, params.m_block_dim); + fwdIterator blockmask(params, binfo, kBlockM, kBlockN, bidb, bidh, m_block * kBlockM, n_block_min, n_block_max); + int next_block_idx = blockmask.max_no_larger(n_block_max-1); + int leap = 0; + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + const bool skip = (n_block != next_block_idx); + Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + } + cute::cp_async_fence(); + + if (!skip) { + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + mask.template apply_mask<Is_causal, Is_even_MN>( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + next_block_idx = blockmask.max_no_larger(n_block-1); + } else { + mask.all_mask(acc_s); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + leap = (masking_step + 1 == n_masking_steps) ? n_block - next_block_idx : 1; + + if (n_block > n_block_min && next_block_idx != -1) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * leap * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - leap) * kBlockN / params.page_block_size; + const int block_table_offset_next =(n_block - leap) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2) + : softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + if (!skip) { + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type<Element>(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout())); + + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + leap = n_block - next_block_idx + 1; + + // These are the iterations where we don't need masking on S + for (n_block = next_block_idx; n_block != -1 && n_block >= n_block_min; n_block = next_block_idx) { + next_block_idx = blockmask.max_no_larger(n_block - 1); + + Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + if (block_table == nullptr) { + tVgV.data() = tVgV.data() + (-int(kBlockN * leap * params.v_row_stride)); + } else { + const int block_table_idx_cur = (n_block + leap) * kBlockN / params.page_block_size; + const int block_table_offset_cur = (n_block + leap) * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = n_block * kBlockN / params.page_block_size; + const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size; + tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + + leap = n_block - next_block_idx; + if (next_block_idx != -1) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * leap * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - leap) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - leap) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask</*Causal_mask=*/false>( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2); + + Tensor rP = flash::convert_type<Element>(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout())); + + flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax); + // if (cute::thread0()) { print(lse); } + + Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + using SmemTiledCopyO = std::conditional_t< + !Split, + typename Kernel_traits::SmemCopyAtomO, + typename Kernel_traits::SmemCopyAtomOaccum + >; + auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma); + auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor rO = flash::convert_type<ElementO>(acc_o); + Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sOaccum is larger than sQ, so we need to syncthreads here + // TODO: allocate enough smem for sOaccum + if constexpr (Split) { __syncthreads(); } + + cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + + m_block * kBlockM) * params.d_rounded; + const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ? + ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb) + ) + m_block * kBlockM; + + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_stride(Split ? kHeadDim : params.o_row_stride, _1{})); + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + row_offset_lseaccum), + Shape<Int<kBlockM>>{}, Stride<_1>{}); + // if (tidx == 0) { printf("row_offset_o = %d, bidh = %d, gOaccum = %p\n", row_offset_o, bidh, gOaccum.data()); } + + GmemTiledCopyO gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum); + + __syncthreads(); + + Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum)); + cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum); + + Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); } + } + } + + // Construct identity layout for sO + Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>( + gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params> +inline __device__ void compute_attn_1rowblock_splitkv_stage1(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + + const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); } + // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); } + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits; + const int n_block_min = !Is_local + ? n_split_idx * n_blocks_per_split + : std::max(n_split_idx * n_blocks_per_split, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN); + int n_block_max = std::min(cute::ceil_div(binfo.actual_seqlen_k, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + int n_block_max_c = std::min(cute::ceil_div(params.seqlen_c, kBlockN), (n_split_idx + 1) * n_blocks_per_split); + if (Is_causal || Is_local) { + const int max_q = (m_block + 1) * kBlockM / params.m_block_dim - 1; + const int max_k = (max_q - 16 + 1) / 16; + n_block_max = std::min(n_block_max, cute::ceil_div(max_k, kBlockN)); + const int max_c = (max_q - 64 + 1) / 64; + n_block_max_c = std::min(n_block_max_c, cute::ceil_div(max_c, kBlockN)); + } + if (n_block_min >= n_block_max_c) { // This also covers the case where n_block_max <= 0 + // We exit early and write 0 to gOaccum and -inf to gLSEaccum. + // Otherwise we might read OOB elements from gK and gV, + // or get wrong results when we combine gOaccum from different blocks. + return; + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + // We move K and V to the last block. + const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb]; + const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride; + const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size; + const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size; + const index_t row_offset_k = block_table == nullptr + ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache) + + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride + : block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = block_table == nullptr + ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache) + + (n_block_max_c - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride + : block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded/16 // TODO 16 is m_block_dim + + m_block * kBlockM/16) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)), + make_shape(binfo.actual_seqlen_q / kBlockM, params.h, kBlockM, params.d), + make_stride(params.q_row_stride * kBlockM, params.q_head_stride * kBlockM, params.q_head_stride, _1{})); + Tensor gQ = local_tile(mQ(m_block, bidh, _, _), Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_coord(0, 0)); // (kBlockM, kHeadDim) + Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.k_row_stride, _1{})); + Tensor gC = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v), + Shape<Int<kBlockN>, Int<kHeadDim>>{}, + make_stride(params.v_row_stride, _1{})); + // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); } + Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p), + Shape<Int<kBlockM/16>, Int<kBlockN>>{}, // TODO 16 is m_block_dim + make_stride(params.seqlen_k_rounded, _1{})); + + Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + + Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + Tensor tKgK = gmem_thr_copy_QKV.partition_S(gC); // (KCPY, KCPY_N, KCPY_K) + Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + + Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + // PREDICATES + // + + // // Allocate predicate tensors for m and n + // Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{}); + // Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{}); + + // Construct identity layout for sQ and sK + Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ))); + Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; } + #pragma unroll + for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; } + } + + // Prologue + + // Copy from Knew to K, optionally apply rotary embedding. + typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary; + auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont; + auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx); + + // Read Q from gmem to smem, optionally apply rotary embedding. + if (!Append_KV || params.rotary_dim == 0) { + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + } else { + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. + // We do this by setting the row stride of gCos / gSin to 0. + Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim / 2>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSin = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim / 2>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gCosCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_cos_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + make_stride(Is_causal || Is_local ? params.rotary_dim / 2 : 0, _1{})); + Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos); + Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin); + Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont); + Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont); + if (params.is_rotary_interleaved) { + flash::copy_rotary_interleaved<Is_even_K>( + tQgQ, tQsQ, tRgCos, tRgSin, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } else { + flash::copy_rotary_contiguous<Is_even_K>( + tQgQ, tQsQ, tRgCosCont, tRgSinCont, tQcQ, binfo.actual_seqlen_q - m_block * kBlockM, + 0, params.d, params.rotary_dim + ); + } + } + + int n_block = n_block_max_c - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + params.seqlen_c - n_block * kBlockN); + cute::cp_async_fence(); + + flash::Softmax<2 * size<1>(acc_o)> softmax; + + const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax; + flash::Mask<Is_causal, Is_local, Has_alibi> mask(params.seqlen_c, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope, nullptr, 0, 0, params.m_block_dim); + + fwdIterator blockmask(params, binfo, kBlockM, kBlockN, bidb, bidh, m_block, n_block_min, n_block_max_c); + int next_block_idx = blockmask.max_no_larger(n_block_max_c-1); + int leap = 0; + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = (!Is_causal && !Is_local) + ? 1 + : ((Is_even_MN && Is_causal) ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + const bool skip = (n_block != next_block_idx); + Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + if (!skip) { + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + mask.template apply_mask_stage1<Is_causal, Is_even_MN>( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + next_block_idx = blockmask.max_no_larger(n_block-1); + } else { + mask.all_mask(acc_s); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + leap = (masking_step + 1 == n_masking_steps) ? n_block - next_block_idx : 1; + + if (n_block > n_block_min && next_block_idx != -1) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * leap * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - leap) * kBlockN / params.page_block_size; + const int block_table_offset_next =(n_block - leap) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + masking_step == 0 + ? softmax.template softmax_rescale_simple</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, params.scale_softmax_log2) + : softmax.template softmax_rescale_simple</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + if (!skip) { + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type<Element>(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + leap = n_block - next_block_idx + 1; + + // These are the iterations where we don't need masking on S + for (n_block = next_block_idx; n_block != -1 && n_block >= n_block_min; n_block = next_block_idx) { + next_block_idx = blockmask.max_no_larger(n_block - 1); + + Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + + leap = n_block - next_block_idx; + if (next_block_idx != -1) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * leap * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - leap) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - leap) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask_stage1</*Causal_mask=*/false>( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + softmax.template softmax_rescale_simple</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, params.scale_softmax_log2); + } + + // Epilogue + + softmax.get_row_sum(); + + { // second time + tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + + n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope, nullptr, 0, 0, params.m_block_dim); + fwdIterator blockmask(params, binfo, kBlockM, kBlockN, bidb, bidh, m_block, n_block_min, n_block_max); + + next_block_idx = blockmask.max_no_larger(n_block_max-1); + leap = 0; + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + const bool skip = (n_block != next_block_idx); + Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + if (!skip) { + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + mask.template apply_mask_stage1<Is_causal, Is_even_MN>( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + + next_block_idx = blockmask.max_no_larger(n_block-1); + } else { + mask.all_mask(acc_s); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + // if (tidx == 0 && blockIdx.y == 0 && blockIdx.z == 0) { print(tVsV); } + // __syncthreads(); + + leap = (masking_step + 1 == n_masking_steps) ? n_block - next_block_idx : 1; + + if (n_block > n_block_min && next_block_idx != -1) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * leap * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - leap) * kBlockN / params.page_block_size; + const int block_table_offset_next =(n_block - leap) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // We have key_padding_mask so we'll need to Check_inf + softmax.template softmax_rescale_gt(acc_s, params.scale_softmax_log2); + // if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); } + + if (!skip) { + // Convert acc_s from fp32 to fp16/bf16 + Tensor rP = flash::convert_type<Element>(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + + if (params.p_ptr != nullptr) { + hdim16_reduce<Element>( + acc_s, + gP, + n_block * kBlockN, + m_block * kBlockM, + kNWarps * 16 + ); + gP.data() = gP.data() + (-kBlockN); + } + } + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= n_block_min) { + --n_block; + break; + } + } + + leap = n_block - next_block_idx + 1; + + // These are the iterations where we don't need masking on S + for (n_block = next_block_idx; n_block != -1 && n_block >= n_block_min; n_block = next_block_idx) { + next_block_idx = blockmask.max_no_larger(n_block - 1); + + Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K + ); + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + + leap = n_block - next_block_idx; + if (next_block_idx != -1) { + // Advance gK + if (block_table == nullptr) { + tKgK.data() = tKgK.data() + (-int(kBlockN * leap * params.k_row_stride)); + } else { + const int block_table_idx_cur = n_block * kBlockN / params.page_block_size; + const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size; + const int block_table_idx_next = (n_block - leap) * kBlockN / params.page_block_size; + const int block_table_offset_next = (n_block - leap) * kBlockN - block_table_idx_next * params.page_block_size; + tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride; + } + flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + mask.template apply_mask_stage1</*Causal_mask=*/false>( + acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 + ); + softmax.template softmax_rescale_gt(acc_s, params.scale_softmax_log2); + + Tensor rP = flash::convert_type<Element>(acc_s); + // Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8. + + if (params.p_ptr != nullptr) { + hdim16_reduce<Element>( + acc_s, + gP, + n_block * kBlockN, + m_block * kBlockM, + kNWarps * 16 + ); + gP.data() = gP.data() + (-kBlockN); + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax, typename Params> +inline __device__ void compute_attn(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params> +inline __device__ void compute_attn_splitkv(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; + // The block index for the head. + const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; + const int n_split_idx = Split ? blockIdx.y : 0; + const int num_n_splits = Split ? gridDim.y : 1; + if (params.m_block_dim == 1) { + flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + } else { + flash::compute_attn_1rowblock_splitkv_headlevel<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + } +} + +template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV, typename Params> +inline __device__ void compute_attn_splitkv_stage1(const Params ¶ms) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = Split ? blockIdx.z / params.h : blockIdx.y; + // The block index for the head. + const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; + const int n_split_idx = Split ? blockIdx.y : 0; + const int num_n_splits = Split ? gridDim.y : 1; + flash::compute_attn_1rowblock_splitkv_stage1<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K, typename Params> +inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + constexpr int kMaxSplits = 1 << Log_max_splits; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNThreads = Kernel_traits::kNThreads; + + static_assert(kMaxSplits <= 128, "kMaxSplits must be <= 128"); + static_assert(kBlockM == 4 || kBlockM == 8 || kBlockM == 16 || kBlockM == 32, "kBlockM must be 4, 8, 16 or 32"); + + // Shared memory. + // kBlockM + 1 instead of kBlockM to reduce bank conflicts. + __shared__ ElementAccum sLSE[kMaxSplits][kBlockM + 1]; + + // The thread and block index. + const int tidx = threadIdx.x; + const int bidx = blockIdx.x; + + const index_t lse_size = params.b * params.h * params.seqlen_q; + + const index_t row_offset_lse = bidx * kBlockM; + Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lse), + Shape<Int<kMaxSplits>, Int<kBlockM>>{}, + make_stride(lse_size, _1{})); + + // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. + // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. + Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse), + Shape<Int<kBlockM>>{}, Stride<_1>{}); + + // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. + Layout flat_layout = make_layout(lse_size); + Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); + auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); + Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); + Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); + + Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr)), final_layout); + + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; + + // Read the LSE values from gmem and store them in shared memory, then transpose them. + constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadLSE + tidx / kBlockM; + const int col = tidx % kBlockM; + ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + if (row < kMaxSplits) { sLSE[row][col] = lse; } + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } + } + // if (bidx == 1 && tidx < 32) { printf("tidx = %d, row_offset_lse = %d, lse = %f\n", tidx, row_offset_lse, lse_accum(0)); } + __syncthreads(); + Tensor lse_accum = make_tensor<ElementAccum>(Shape<Int<kNLsePerThread>>{}); + constexpr int kRowsPerLoadTranspose = std::min(kRowsPerLoadLSE, kMaxSplits); + // To make sure that kMaxSplits is within 1 warp: we decide how many elements within kMaxSplits + // each thread should hold. If kMaxSplits = 16, then each thread holds 2 elements (128 threads, + // kBlockM rows, so each time we load we can load 128 / kBlockM rows). + // constexpr int kThreadsPerSplit = kMaxSplits / kRowsPerLoadTranspose; + // static_assert(kThreadsPerSplit <= 32); + static_assert(kRowsPerLoadTranspose <= 32); + static_assert(kNLsePerThread * kRowsPerLoadTranspose <= kMaxSplits); + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + lse_accum(l) = (row < kMaxSplits && col < kBlockM) ? sLSE[row][col] : -INFINITY; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse_accum(l)); } + } + + // Compute the logsumexp of the LSE along the split dimension. + ElementAccum lse_max = lse_accum(0); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_max = max(lse_max, lse_accum(l)); } + MaxOp<float> max_op; + lse_max = Allreduce<kRowsPerLoadTranspose>::run(lse_max, max_op); + lse_max = lse_max == -INFINITY ? 0.0f : lse_max; // In case all local LSEs are -inf + float lse_sum = expf(lse_accum(0) - lse_max); + #pragma unroll + for (int l = 1; l < kNLsePerThread; ++l) { lse_sum += expf(lse_accum(l) - lse_max); } + SumOp<float> sum_op; + lse_sum = Allreduce<kRowsPerLoadTranspose>::run(lse_sum, sum_op); + // For the case where all local lse == -INFINITY, we want to set lse_logsum to INFINITY. Otherwise + // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. + ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? INFINITY : logf(lse_sum) + lse_max; + // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + if (params.unpadded_lse) { + const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; + if (lse_offset < lse_size) { + gLSE_unpadded(lse_offset) = lse_logsum; + } + } else { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } + } + // Store the scales exp(lse - lse_logsum) in shared memory. + #pragma unroll + for (int l = 0; l < kNLsePerThread; ++l) { + const int row = l * kRowsPerLoadTranspose + tidx % kRowsPerLoadTranspose; + const int col = tidx / kRowsPerLoadTranspose; + if (row < params.num_splits && col < kBlockM) { sLSE[row][col] = expf(lse_accum(l) - lse_logsum); } + } + __syncthreads(); + + const index_t row_offset_oaccum = bidx * kBlockM * params.d_rounded; + Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum), + Shape<Int<kBlockM>, Int<kHeadDim>>{}, + Stride<Int<kHeadDim>, _1>{}); + constexpr int kBlockN = kNThreads / kBlockM; + using GmemLayoutAtomOaccum = Layout<Shape<Int<kBlockM>, Int<kBlockN>>, Stride<Int<kBlockN>, _1>>; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store + GmemTiledCopyOaccum gmem_tiled_copy_Oaccum; + auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx); + Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum); + Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum)); + Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum)); + clear(tOrO); + + // Predicates + Tensor cOaccum = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); + // Repeat the partitioning with identity layouts + Tensor tOcOaccum = gmem_thr_copy_Oaccum.partition_S(cOaccum); + Tensor tOpOaccum = make_tensor<bool>(make_shape(size<2>(tOgOaccum))); + if (!Is_even_K) { + #pragma unroll + for (int k = 0; k < size(tOpOaccum); ++k) { tOpOaccum(k) = get<1>(tOcOaccum(0, 0, k)) < params.d; } + } + // Load Oaccum in then scale and accumulate to O + for (int split = 0; split < params.num_splits; ++split) { + flash::copy</*Is_even_MN=*/false, Is_even_K>( + gmem_tiled_copy_Oaccum, tOgOaccum, tOrOaccum, tOcOaccum, tOpOaccum, params.b * params.h * params.seqlen_q - bidx * kBlockM + ); + #pragma unroll + for (int m = 0; m < size<1>(tOrOaccum); ++m) { + int row = get<0>(tOcOaccum(0, m, 0)); + ElementAccum lse_scale = sLSE[split][row]; + #pragma unroll + for (int k = 0; k < size<2>(tOrOaccum); ++k) { + #pragma unroll + for (int i = 0; i < size<0>(tOrOaccum); ++i) { + tOrO(i, m, k) += lse_scale * tOrOaccum(i, m, k); + } + } + // if (cute::thread0()) { printf("lse_scale = %f, %f\n", sLSE[split][0], sLSE[split][1]); print(tOrOaccum); } + } + tOgOaccum.data() = tOgOaccum.data() + params.b * params.h * params.seqlen_q * params.d_rounded; + } + // if (cute::thread0()) { print_tensor(tOrO); } + + Tensor rO = flash::convert_type<Element>(tOrO); + // Write to gO + #pragma unroll + for (int m = 0; m < size<1>(rO); ++m) { + const int idx = bidx * kBlockM + get<0>(tOcOaccum(0, m, 0)); + if (idx < params.b * params.h * params.seqlen_q) { + const int batch_idx = idx / (params.h * params.seqlen_q); + const int head_idx = (idx - batch_idx * (params.h * params.seqlen_q)) / params.seqlen_q; + // The index to the rows of Q + const int row = idx - batch_idx * (params.h * params.seqlen_q) - head_idx * params.seqlen_q; + auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride + + head_idx * params.o_head_stride + row * params.o_row_stride; + #pragma unroll + for (int k = 0; k < size<2>(rO); ++k) { + if (Is_even_K || tOpOaccum(k)) { + const int col = get<1>(tOcOaccum(0, m, k)); + Tensor gO = make_tensor(make_gmem_ptr(o_ptr + col), + Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{}); + // TODO: Should check if this is using vectorized store, but it seems pretty fast + copy(rO(_, m, k), gO); + // if (bidx == 0 && tidx == 0) { printf("tidx = %d, idx = %d, batch_idx = %d, head_idx = %d, row = %d, col = %d\n", tidx, idx, batch_idx, head_idx, row, col); print(rO(_, m, k)); print(gO); } + // reinterpret_cast<uint64_t *>(o_ptr)[col / 4] = recast<uint64_t>(rO)(0, m, k); + } + } + } + } +} + +} // namespace flash diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_launch_template.h b/examples/CPM.cu/src/flash_attn/src/flash_fwd_launch_template.h new file mode 100644 index 00000000..cdb50ba7 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_launch_template.h @@ -0,0 +1,382 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "static_switch.h" +#include "flash.h" +#include "flash_fwd_kernel.h" + +// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#define ARCH_SUPPORTS_FLASH +#define KERNEL_PARAM_MODIFIER __grid_constant__ +#else +#define KERNEL_PARAM_MODIFIER +#endif + +// Define a macro for unsupported architecture handling to centralize the error message +#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!"); + +// Use a macro to clean up kernel definitions +#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \ +template<typename Kernel_traits, __VA_ARGS__> \ +__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { + #if defined(ARCH_SUPPORTS_FLASH) + static_assert(!(Is_causal && Is_local)); // Enforce constraints + flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Return_softmax>(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { + #if defined(ARCH_SUPPORTS_FLASH) + flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_stage1_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { + #if defined(ARCH_SUPPORTS_FLASH) + flash::compute_attn_splitkv_stage1<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, Split, Append_KV>(params); + #else + FLASH_UNSUPPORTED_ARCH + #endif +} + +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) { + static_assert(Log_max_splits >= 1); + flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params); +} + +template<typename Kernel_traits, bool Is_dropout, bool Is_causal> +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + // printf("smem_size = %d\n", smem_size); + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + const bool return_softmax = params.p_ptr != nullptr; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout && !Is_softcap, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, ReturnSoftmaxConst && Is_dropout && !Is_softcap>; + // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); + }); + }); + }); + }); + }); + }); +} + +template<typename Kernel_traits, bool Is_causal> +void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + constexpr static bool Is_local = false; { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + constexpr static bool Append_KV = false; { + constexpr static bool Has_alibi = false; { + constexpr static bool Is_softcap = false; { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>; + // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>; + // auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); + } + } + } + }); + } + }); + }); + if (params.num_splits > 1) { + // We want kBlockM to be as small as possible for more parallelism. + // With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4. + // If headdim is divisible by 64, then we set kBlockM = 8, etc. + constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16); + dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM); + params.seqlen_q /= params.m_block_dim; + params.h *= params.m_block_dim; + params.o_row_stride *= params.m_block_dim; + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + if (params.num_splits <= 2) { + flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params); + } else if (params.num_splits <= 4) { + flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params); + } else if (params.num_splits <= 8) { + flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params); + } else if (params.num_splits <= 16) { + flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params); + } else if (params.num_splits <= 32) { + flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params); + } else if (params.num_splits <= 64) { + flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params); + } else if (params.num_splits <= 128) { + flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params); + } + }); + } +} + +template<typename Kernel_traits, bool Is_causal> +void run_flash_splitkv_fwd_stage1(Flash_fwd_params ¶ms, cudaStream_t stream) { + static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); + static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); + constexpr size_t smem_size = Kernel_traits::kSmemSize; + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h); + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + constexpr static bool Is_local = false; { + constexpr static bool Split = false; { + constexpr static bool Append_KV = false; { + constexpr static bool Has_alibi = false; { + constexpr static bool Is_softcap = false; { + // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_splitkv_stage1_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap, Split, Append_KV>; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params); + } + } + } + } + } + }); + }); +} + +template<typename T, int Headdim, bool Is_causal> +void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { + if (params.blockmask == nullptr) { // TODO improve this + // if (params.blockmask == nullptr && params.block_window_size == 0) { + constexpr static int kBlockM = 64; // Fixed for all head dimensions + // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, + // and for headdim 192 with block size 64 x 128. + // Also for headdim 160 with block size 64 x 128 after the rotary addition. + constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64); + if (params.m_block_dim == 1) { + run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream); + } else { + run_flash_splitkv_fwd_stage1<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>, Is_causal>(params, stream); + } + } else { + constexpr static int kBlockM = 16; + constexpr static int kBlockN = 64; + run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 1, false, false, T>, Is_causal>(params, stream); + } +} + +template<typename T, bool Is_causal> +void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 32; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + }); +} + +template<typename T, bool Is_causal> +void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + } + }); +} + +template<typename T, bool Is_causal> +void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 96; + cudaDeviceProp props; + cudaGetDeviceProperties(&props, 0); + bool is_sm8x = props.major == 8 && props.minor > 0; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream); + }); +} + +template<typename T, bool Is_causal> +void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + cudaDeviceProp props; + cudaGetDeviceProperties(&props, 0); + bool is_sm8x = props.major == 8 && props.minor > 0; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + } else { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream); + } + }); +} + +template<typename T, bool Is_causal> +void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 160; + cudaDeviceProp props; + cudaGetDeviceProperties(&props, 0); + bool is_sm8x = props.major == 8 && props.minor > 0; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + } + } else { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream); + }); +} + +template<typename T, bool Is_causal> +void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 192; + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + if constexpr(!Is_dropout) { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream); + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream); + }); +} + +template<typename T, bool Is_causal> +void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 256; + int device; + cudaGetDevice(&device); + int max_smem_per_sm, max_smem_per_block; + cudaError status_ = cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device); + status_ = cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + if (status_ != cudaSuccess) { + } + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream); + }); +} diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu new file mode 100644 index 00000000..a959c9ce --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu new file mode 100644 index 00000000..e608e308 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu new file mode 100644 index 00000000..3dd74e27 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu new file mode 100644 index 00000000..addacedf --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu new file mode 100644 index 00000000..8ace7bda --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu new file mode 100644 index 00000000..1e133ec1 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu new file mode 100644 index 00000000..1723c69e --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu new file mode 100644 index 00000000..892d2352 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu new file mode 100644 index 00000000..d07ee0af --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu new file mode 100644 index 00000000..23cfa59d --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu new file mode 100644 index 00000000..273a2844 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu new file mode 100644 index 00000000..0f588d1f --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu new file mode 100644 index 00000000..ea024d9a --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu new file mode 100644 index 00000000..b06ae5ac --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu new file mode 100644 index 00000000..b217f378 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu new file mode 100644 index 00000000..8cf2eabe --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu new file mode 100644 index 00000000..370fe9ca --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu new file mode 100644 index 00000000..508f07f7 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu new file mode 100644 index 00000000..019ded67 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu new file mode 100644 index 00000000..708f5542 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu new file mode 100644 index 00000000..5a205b7e --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu new file mode 100644 index 00000000..2c576f11 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu new file mode 100644 index 00000000..484a15e9 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu new file mode 100644 index 00000000..5474ae89 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu new file mode 100644 index 00000000..8c7da41d --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu new file mode 100644 index 00000000..93f29dea --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu new file mode 100644 index 00000000..1e2e12b8 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu new file mode 100644 index 00000000..16c34ed3 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu new file mode 100644 index 00000000..50080c47 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu new file mode 100644 index 00000000..ae56ddd4 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu new file mode 100644 index 00000000..ed305767 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, true>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu new file mode 100644 index 00000000..02206465 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96, false>(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/examples/CPM.cu/src/flash_attn/src/generate_kernels.py b/examples/CPM.cu/src/flash_attn/src/generate_kernels.py new file mode 100644 index 00000000..119e3495 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/generate_kernels.py @@ -0,0 +1,108 @@ +# Copied from Driss Guessous's PR in PyTorch: https://github.com/pytorch/pytorch/pull/105602 + +# This file is run to generate the kernel instantiations for the flash_attn kernels +# They are written to several files in order to speed up compilation + +import argparse +import itertools +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional + +DTYPE_MAP = { + "fp16": "cutlass::half_t", + "bf16": "cutlass::bfloat16_t", +} + +SM = [80] # Sm80 kernels support up to +HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 256] +IS_CAUSAL = ["false", "true"] +KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ + run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +}} +""" + +KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream); +""" + +KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_bwd_params ¶ms, cudaStream_t stream) {{ + run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); +}} +""" + + +@dataclass +class Kernel: + sm: int + dtype: str + head_dim: int + is_causal: bool + direction: str + + @property + def template(self) -> str: + if self.direction == "fwd": + return KERNEL_IMPL_TEMPLATE_FWD.format( + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal + ) + elif self.direction == "bwd": + return KERNEL_IMPL_TEMPLATE_BWD.format( + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal + ) + else: + return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format( + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal + ) + + @property + def filename(self) -> str: + return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu" + + +def get_all_kernels() -> List[Kernel]: + for direction in ["fwd", "fwd_split", "bwd"]: + for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM): + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction) + + +def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: + prelude = """// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py"\n +""" + (autogen_dir / kernel.filename).write_text(prelude + kernel.template) + + +def main(output_dir: Optional[str]) -> None: + if output_dir is None: + output_dir = Path(__file__).parent + else: + output_dir = Path(output_dir) + + for kernel in get_all_kernels(): + write_kernel(kernel, output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="generate_kernels", + description="Generate the flash_attention kernels template instantiations", + ) + # Set an optional output directory + parser.add_argument( + "-o", + "--output_dir", + required=False, + help="Where to generate the kernels " + " will default to the current directory ", + ) + args = parser.parse_args() + main(args.output_dir) diff --git a/examples/CPM.cu/src/flash_attn/src/kernel_traits.h b/examples/CPM.cu/src/flash_attn/src/kernel_traits.h new file mode 100644 index 00000000..91178dc3 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/kernel_traits.h @@ -0,0 +1,349 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include <cutlass/numeric_types.h> + +using namespace cute; + +template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t> +struct Flash_kernel_traits { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = int64_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v<elem_type, cutlass::half_t>, + MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>, + MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN> + >; +#else + using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>; + using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>; +#else + using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>; + using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t, + typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group + Tile<Int<16 * kNWarps>, _16, _16>>; + + using SmemLayoutAtomQ = decltype( + composition(Swizzle<kSwizzle, 3, 3>{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout<Shape<_8, Int<kBlockKSmem>>, + Stride<Int<kBlockKSmem>, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape<Int<kBlockM>, Int<kHeadDim>>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape<Int<kBlockN>, Int<kHeadDim>>{})); + + // https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434 + using SmemLayoutVtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{}))); + using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{})); + + using SmemLayoutAtomO = decltype( + composition(Swizzle<kSwizzle, 3, 3>{}, + Layout<Shape<Int<8>, Int<kBlockKSmem>>, + Stride<Int<kBlockKSmem>, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape<Int<kBlockM>, Int<kHeadDim>>{})); + using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>; + using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>; + + static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, + Stride<Int<kGmemThreadsPerRow>, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{}, + GmemLayoutAtom{}, + Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, + GmemLayoutAtom{}, + Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store + + using GmemLayoutAtomOaccum = std::conditional_t< + kNThreads == 32, // Special case for 1 warp (splitkv scenario) + Layout<Shape <_4, _8>, // 4 rows x 8 cols = 32 threads, matches combine kernel + Stride< _8, _1>>, + std::conditional_t< + kBlockKSmem == 32, + Layout<Shape <_16, _8>, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout<Shape <_8, _16>, // Thread layout, 16 threads per row + Stride< _16, _1>> + > + >; + using GmemTiledCopyOaccum = decltype( + make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, + GmemLayoutAtomOaccum{}, + Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store + using GmemLayoutAtomRotcossin = GmemLayoutAtom; + using GmemTiledCopyRotcossin = decltype( + make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load + using GmemTiledCopyRotcossinCont = decltype( + make_tiled_copy(Copy_Atom<DefaultCopy, Element>{}, + GmemLayoutAtomRotcossin{}, + Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, + int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2, + bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t, + typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>, + Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>; + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>, + Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>; + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group + Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>; + + using SmemLayoutAtomQdO = decltype( + composition(Swizzle<kSwizzle, 3, 3>{}, + Layout<Shape<_8, Int<kBlockKSmem>>, + Stride<Int<kBlockKSmem>, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + make_shape(Int<kBlockM>{}, Int<kHeadDim>{}))); + + using SmemLayoutAtomKV = decltype( + composition(Swizzle<kSwizzle, 3, 3>{}, + Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>, + Stride<Int<kBlockKSmem>, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + make_shape(Int<kBlockN>{}, Int<kHeadDim>{}))); + + using SmemLayoutKtransposed = decltype( + composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{}))); + using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{})); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + // Temporarily disabling this for hdim 256 on sm86 and sm89 + // static_assert(kBlockN >= 64); + static_assert(kBlockN >= 32); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype( + composition(Swizzle<kSwizzlePdS, 3, 3>{}, + Layout<Shape<Int<kBlockM>, Int<kPBlockN>>, + Stride<Int<kPBlockN>, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + make_shape(Int<kBlockM>{}, Int<kBlockN>{}))); + using SmemLayoutPdStransposed = decltype( + composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{})); + + using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>; + + using SmemLayoutQdOtransposed = decltype( + composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{})); + + using SmemLayoutAtomdKV = decltype( + composition(Swizzle<kSwizzle, 3, 3>{}, + Layout<Shape<_8, Int<kBlockKSmem>>, + Stride<Int<kBlockKSmem>, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int<kBlockN>{}, Int<kHeadDim>{}))); + using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>; + + using SmemLayoutAtomdQ = decltype( + composition(Swizzle<kSwizzle, 3, 3>{}, + Layout<Shape<_8, Int<kBlockKSmem>>, + Stride<Int<kBlockKSmem>, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int<kBlockM>{}, Int<kHeadDim>{}))); + using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>; + + // Double buffer for sQ + static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element); + static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element); + static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + static constexpr int kSmemSize = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>, + Stride<Int<kGmemThreadsPerRow>, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, + DefaultCopy + >; + using GmemTiledCopyQKV = decltype( + make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{}, + GmemLayoutAtom{}, + Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype( + make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, + GmemLayoutAtom{}, + Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype( + make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, + GmemLayoutAtom{}, + Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype( + make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{}, + GmemLayoutAtom{}, + Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout<Shape <_32, _8>, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout<Shape <_16, _16>, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopydQaccum = decltype( + make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, + GmemLayoutAtomdQaccum{}, + Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype( + make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{}, + Layout<Shape <_8, _32>, // Thread layout, 8 threads per row + Stride<_32, _1>>{}, + Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/CPM.cu/src/flash_attn/src/mask.h b/examples/CPM.cu/src/flash_attn/src/mask.h new file mode 100644 index 00000000..026f2578 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/mask.h @@ -0,0 +1,328 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include <cute/tensor.hpp> + +namespace flash { + +using namespace cute; + +template <typename Engine, typename Layout> +__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { + // Without the "make_coord" we get wrong results + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template <bool HasWSLeft=true, typename Engine, typename Layout> +__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride, + const int window_size_left, const int window_size_right) { + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template <typename Engine, typename Layout> +__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset, + const int max_seqlen_q, const int warp_row_stride) { + // Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0 + apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset, + max_seqlen_q, warp_row_stride, -1, 0); +} + +template <typename Engine0, typename Layout0, typename Engine1, typename Layout1> +__forceinline__ __device__ void apply_mask_causal_w_idx( + Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) +{ + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0))); + #pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +template <bool Is_causal, bool Is_local, bool Has_alibi, bool Mask_2d=false> +struct Mask { + + const int max_seqlen_k, max_seqlen_q; + const int window_size_left, window_size_right; + const uint64_t *mask_2d; + const int mask_q_range, mask_k_begin; + const float alibi_slope; + const int m_block_dim; + + __forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q, + const int window_size_left, const int window_size_right, + const float alibi_slope=0.f, + const uint64_t *mask_2d = nullptr, + const int mask_q_range = 0, + const int mask_k_range = 0, + const int m_block_dim = 1) + : max_seqlen_k(max_seqlen_k) + , max_seqlen_q(max_seqlen_q) + , window_size_left(window_size_left) + , window_size_right(window_size_right) + , mask_2d(mask_2d) + , mask_q_range(mask_q_range) + , mask_k_begin(max_seqlen_k - mask_k_range) + , alibi_slope(!Has_alibi ? 0.0 : alibi_slope) + , m_block_dim(m_block_dim) { + }; + + // Causal_mask: whether this particular iteration needs causal masking + template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout> + __forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? + static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Col_idx_only) { + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } + } + } + } + } + } else { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int orig_row_idx = row_idx / this->m_block_dim; + const int orig_max_seqlen_q = max_seqlen_q / this->m_block_dim; + + const int col_idx_limit_left = std::max(0, orig_row_idx + max_seqlen_k - orig_max_seqlen_q - window_size_left); + const int col_idx_limit_right = std::min(max_seqlen_k, orig_row_idx + 1 + max_seqlen_k - orig_max_seqlen_q + window_size_right); + const uint64_t mask = (Mask_2d && orig_row_idx < mask_q_range) ? mask_2d[orig_row_idx] : 0; + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Mask_2d) { + if (col_idx >= mask_k_begin && (mask >> (col_idx - mask_k_begin) & 1) == 0) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (Has_alibi) { + if constexpr (Is_causal) { + tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx; + } else { + tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(orig_row_idx + max_seqlen_k - orig_max_seqlen_q - col_idx); + + } + } + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (Is_local) { + if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + } + } + }; + + // Causal_mask: whether this particular iteration needs causal masking + template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout> + __forceinline__ __device__ void apply_mask_stage1(Tensor<Engine, Layout> &tensor_, + const int col_idx_offset_, + const int row_idx_offset, + const int warp_row_stride) { + static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local"); + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN; + // if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); } + if constexpr (Need_masking) { + // Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); + // Do we need both row and column indices, or just column incides? + static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask; + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; + if constexpr (Col_idx_only) { + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // No causal, no local + if constexpr (Has_alibi) { + tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx; + } + if constexpr (!Is_even_MN) { + if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; } + } + } + } + } + } else { + #pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; + #pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + + const int orig_row_idx = row_idx / this->m_block_dim; + + const int col_idx_limit_left = 0; + const int col_idx_limit_right = std::min(max_seqlen_k, (orig_row_idx - /*TODO stride=*/16 + 1) / /*TODO stride=*/16 + window_size_right); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if constexpr (Causal_mask) { + if (col_idx >= col_idx_limit_right) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (Is_local) { + if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + if constexpr (!Causal_mask && !Is_local && !Is_even_MN) { + // Causal and Local already handles MN masking + if (col_idx >= max_seqlen_k) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + } + } + } + } + } + }; + + template <typename Engine, typename Layout> + __forceinline__ __device__ void all_mask(Tensor<Engine, Layout> &tensor_) { + static_assert(Layout::rank == 3, "Only support 3D Tensor"); + static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4"); + Tensor tensor = make_tensor(tensor_.data(), flash::convert_layout_acc_rowcol(tensor_.layout())); + #pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + #pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + }; +}; + +} // namespace flash diff --git a/examples/CPM.cu/src/flash_attn/src/philox.cuh b/examples/CPM.cu/src/flash_attn/src/philox.cuh new file mode 100644 index 00000000..cd7e4d2f --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/philox.cuh @@ -0,0 +1,51 @@ +// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h +#pragma once +// Philox CUDA. + +namespace flash { + +struct ull2 { + unsigned long long x; + unsigned long long y; +}; + +__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) { + uint2 *res; + unsigned long long tmp; + asm ("mul.wide.u32 %0, %1, %2;\n\t" + : "=l"(tmp) + : "r"(a), "r"(b)); + res = (uint2*)(&tmp); + return *res; +} + +__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) { + constexpr unsigned long kPhiloxSA = 0xD2511F53; + constexpr unsigned long kPhiloxSB = 0xCD9E8D57; + uint2 res0 = mulhilo32(kPhiloxSA, ctr.x); + uint2 res1 = mulhilo32(kPhiloxSB, ctr.z); + uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x}; + return ret; +} + +__forceinline__ __device__ uint4 philox(unsigned long long seed, + unsigned long long subsequence, + unsigned long long offset) { + constexpr unsigned long kPhilox10A = 0x9E3779B9; + constexpr unsigned long kPhilox10B = 0xBB67AE85; + uint2 key = reinterpret_cast<uint2&>(seed); + uint4 counter; + ull2 *tmp = reinterpret_cast<ull2*>(&counter); + tmp->x = offset; + tmp->y = subsequence; + #pragma unroll + for (int i = 0; i < 6; i++) { + counter = philox_single_round(counter, key); + key.x += (kPhilox10A); + key.y += (kPhilox10B); + } + uint4 output = philox_single_round(counter, key); + return output; +} + +} // namespace flash diff --git a/examples/CPM.cu/src/flash_attn/src/rotary.h b/examples/CPM.cu/src/flash_attn/src/rotary.h new file mode 100644 index 00000000..7f1614ad --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/rotary.h @@ -0,0 +1,152 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include <cute/tensor.hpp> + +#include "utils.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <bool Is_even_K=true, bool Clear_OOB_K=true, + typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3> +__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S, + Tensor<Engine1, Layout1> &D, + Tensor<Engine2, Layout2> const &Cos, + Tensor<Engine2, Layout2> const &Sin, + Tensor<Engine3, Layout3> const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K + static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + cute::copy(Cos(_, m, k), rCos(_, m, k)); + cute::copy(Sin(_, m, k), rSin(_, m, k)); + Tensor S_fp32 = convert_type<float>(rS(_, m, k)); + Tensor cos_fp32 = convert_type<float>(rCos(_, m, k)); + Tensor sin_fp32 = convert_type<float>(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS) / 2; ++i) { + float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i); + float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i); + S_fp32(2 * i) = real; + S_fp32(2 * i + 1) = imag; + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type<T>(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <bool Is_even_K=true, bool Clear_OOB_K=true, + typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3> +__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S, + Tensor<Engine1, Layout1> &D, + Tensor<Engine2, Layout2> const &Cos, + Tensor<Engine2, Layout2> const &Sin, + Tensor<Engine3, Layout3> const &identity_MN, + const int max_MN, const int min_MN, + const int dim, const int rotary_dim) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA + CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); + static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32 + Tensor rCos = make_fragment_like(Cos); + Tensor rSin = make_fragment_like(Sin); + Tensor rS = make_fragment_like(S); + Tensor rS_other = make_fragment_like(rS(_, 0, 0)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) { + cute::copy(S(_, m, k), rS(_, m, k)); + if (get<1>(identity_MN(0, 0, k)) < rotary_dim) { + const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2; + Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout()); + cute::copy(gS_other, rS_other); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); } + Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout()); + Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout()); + cute::copy(gCos, rCos(_, m, k)); + cute::copy(gSin, rSin(_, m, k)); + // if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); } + Tensor S_fp32 = convert_type<float>(rS(_, m, k)); + Tensor S_other_fp32 = convert_type<float>(rS_other); + Tensor cos_fp32 = convert_type<float>(rCos(_, m, k)); + Tensor sin_fp32 = convert_type<float>(rSin(_, m, k)); + #pragma unroll + for (int i = 0; i < size<0>(rS); ++i) { + S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i)); + } + // Idk but I need to copy for the convert_type to work + Tensor S_fp32_copy = make_fragment_like(S_fp32); + cute::copy(S_fp32, S_fp32_copy); + using T = typename Engine0::value_type; + Tensor S_og_type = convert_type<T>(S_fp32_copy); + cute::copy(S_og_type, rS(_, m, k)); + // if (cute::thread0()) { print_tensor(rS(_, m, k)); } + } + cute::copy(rS(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/examples/CPM.cu/src/flash_attn/src/softmax.h b/examples/CPM.cu/src/flash_attn/src/softmax.h new file mode 100644 index 00000000..d420f752 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/softmax.h @@ -0,0 +1,259 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include <cmath> + +#include <cute/tensor.hpp> + +#include <cutlass/numeric_types.h> + +#include "philox.cuh" +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator> +__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator> +__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator> +__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) { + thread_reduce_<zero_init>(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1> +__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){ + MaxOp<float> max_op; + reduce_<zero_init>(tensor, max, max_op); +} + +template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1> +__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){ + SumOp<float> sum_op; + thread_reduce_<zero_init>(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1> +__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + // The following macro will disable the use of fma. + // See: https://github.com/pytorch/pytorch/issues/121558 for more details + // This macro is set in PyTorch and not FlashAttention + #ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled); + #else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + #endif + } + } +} + +// Apply the exp to all the elements. +template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1> +__forceinline__ __device__ void get_softmax(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, Tensor<Engine1, Layout1> const &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + const float sum_scaled = 1. / sum(mi); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + // The following macro will disable the use of fma. + // See: https://github.com/pytorch/pytorch/issues/121558 for more details + // This macro is set in PyTorch and not FlashAttention + #ifdef UNFUSE_FMA + tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled) * sum_scaled; + #else + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled) * sum_scaled; + #endif + } + } +} + +// Apply the exp to all the elements. +template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1> +__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp<float> max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp<float> sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <int kNRows> +struct Softmax { + + using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{})); + TensorT row_max, row_sum; + + __forceinline__ __device__ Softmax() {}; + + template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1> + __forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + flash::template reduce_max</*zero_init=*/true>(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum</*zero_init=*/true>(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max</*zero_init=*/false>(scores, row_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; } + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum</*zero_init=*/false>(scores, row_sum); + } + }; + + template<bool Is_first, bool Check_inf=false, typename Tensor0> + __forceinline__ __device__ void softmax_rescale_simple(Tensor0 &acc_s, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + if (Is_first) { + flash::template reduce_max</*zero_init=*/true>(scores, row_max); + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum</*zero_init=*/true>(scores, row_sum); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max</*zero_init=*/false>(scores, row_max); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale; + } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum</*zero_init=*/false>(scores, row_sum); + } + }; + + template<typename Tensor0> + __forceinline__ __device__ void softmax_rescale_gt(Tensor0 &acc_s, float softmax_scale_log2) { + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + flash::get_softmax(scores, row_max, row_sum, softmax_scale_log2); + }; + + __forceinline__ __device__ void get_row_sum() { + SumOp<float> sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + } + + template<bool Is_dropout=false, bool Split=false, typename Tensor0> + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) { + SumOp<float> sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT lse = make_fragment_like(row_sum); + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); + float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout; + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; } + } + return lse; + }; +}; + +} // namespace flash diff --git a/examples/CPM.cu/src/flash_attn/src/static_switch.h b/examples/CPM.cu/src/flash_attn/src/static_switch.h new file mode 100644 index 00000000..55b265fa --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/static_switch.h @@ -0,0 +1,142 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +#include <stdexcept> + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function<BoolConst>(...); +/// }); +/// ``` + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#ifdef FLASHATTENTION_DISABLE_DROPOUT + #define DROPOUT_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define DROPOUT_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_ALIBI + #define ALIBI_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define ALIBI_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_UNEVEN_K + #define EVENK_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + }() +#else + #define EVENK_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_SOFTCAP + #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SOFTCAP_SWITCH BOOL_SWITCH +#endif + +#ifdef FLASHATTENTION_DISABLE_LOCAL + #define LOCAL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define LOCAL_SWITCH BOOL_SWITCH +#endif + +#if defined(ENABLE_DTYPE_FP16) && defined(ENABLE_DTYPE_BF16) +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() +#elif defined(ENABLE_DTYPE_FP16) +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (!(COND)) { \ + throw std::runtime_error("BF16 support not compiled. Please recompile with CPMCU_DTYPE=bf16 or CPMCU_DTYPE=fp16,bf16"); \ + } \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + }() +#elif defined(ENABLE_DTYPE_BF16) +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + throw std::runtime_error("FP16 support not compiled. Please recompile with CPMCU_DTYPE=fp16 or CPMCU_DTYPE=fp16,bf16"); \ + } \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + }() +#else +#error "At least one of ENABLE_DTYPE_FP16 or ENABLE_DTYPE_BF16 must be defined" +#endif + +// TODO only compile 64 for debug +#define HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + }() + // if (HEADDIM <= 32) { \ + // constexpr static int kHeadDim = 32; \ + // return __VA_ARGS__(); \ + // } else if (HEADDIM <= 64) { \ + // constexpr static int kHeadDim = 64; \ + // return __VA_ARGS__(); \ + // } else if (HEADDIM <= 96) { \ + // constexpr static int kHeadDim = 96; \ + // return __VA_ARGS__(); \ + // } else if (HEADDIM <= 128) { \ + // constexpr static int kHeadDim = 128; \ + // return __VA_ARGS__(); \ + // } else if (HEADDIM <= 160) { \ + // constexpr static int kHeadDim = 160; \ + // return __VA_ARGS__(); \ + // } else if (HEADDIM <= 192) { \ + // constexpr static int kHeadDim = 192; \ + // return __VA_ARGS__(); \ + // } else if (HEADDIM <= 256) { \ + // constexpr static int kHeadDim = 256; \ + // return __VA_ARGS__(); \ + // } \ + // }() diff --git a/examples/CPM.cu/src/flash_attn/src/utils.h b/examples/CPM.cu/src/flash_attn/src/utils.h new file mode 100644 index 00000000..b7408ec4 --- /dev/null +++ b/examples/CPM.cu/src/flash_attn/src/utils.h @@ -0,0 +1,411 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include <assert.h> +#include <stdint.h> +#include <stdlib.h> + +#include <cuda_fp16.h> + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include <cuda_bf16.h> +#endif + +#include <cute/tensor.hpp> + +#include <cutlass/array.h> +#include <cutlass/cutlass.h> +#include <cutlass/numeric_conversion.h> +#include <cutlass/numeric_types.h> + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename T> +__forceinline__ __device__ uint32_t relu2(const uint32_t x); + +template<> +__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); +#else + asm volatile( \ + "{\n" \ + "\t .reg .f16x2 sela;\n" \ + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \ + "\t and.b32 %0, sela, %1;\n" + "}\n" : "=r"(res) : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template<> +__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template<typename T> +__forceinline__ __device__ uint32_t convert_relu2(const float2 x); + +template<> +__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast<const uint32_t&>(x.x); + const uint32_t b = reinterpret_cast<const uint32_t&>(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +template<> +__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast<const uint32_t&>(x.x); + const uint32_t b = reinterpret_cast<const uint32_t&>(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename T> +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp<float> { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename T> +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<int THREADS> +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template<typename T, typename Operator> + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce<OFFSET>::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template<typename T, typename Operator> +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1, + typename Tensor2, typename Tensor3, typename Tensor4, + typename TiledMma, typename TiledCopyA, typename TiledCopyB, + typename ThrCopyA, typename ThrCopyB> +__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); } + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); } + if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, + typename TiledMma, typename TiledCopy, typename ThrCopy> +__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + #pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template<typename Layout> +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +template<typename MMA_traits, typename Layout> +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) +template<typename Layout> +__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) { + using X = Underscore; + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <typename To_type, typename Engine, typename Layout> +__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data())); + return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <typename Engine, typename Layout> +__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast<uint32_t>(tensor); + #pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2<value_t>(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template <typename To_type, typename Engine, typename Layout> +__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>); + static_assert(std::is_same_v<float, From_type>); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast<float2>(tensor); + Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout()); + #pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2<To_type>(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout()); +#else + Tensor out = flash::convert_type<To_type>(tensor); + flash::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template <int N> +CUTE_HOST_DEVICE +void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true, + typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3> +__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S, + Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN, + Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } + // TD [2023-04-13]: Strange that the code below can cause race condition. + // I think it's because the copies are under an if statement. + // if (Is_even_K) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, _), D(_, m, _)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, _)); + // } + // } + // } else { // It's slightly faster in this case if iterate over K first + // #pragma unroll + // for (int k = 0; k < size<2>(S); ++k) { + // if (predicate_K(k)) { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + // copy(tiled_copy, S(_, m, k), D(_, m, k)); + // } else if (Clear_OOB_MN) { + // clear(D(_, m, k)); + // } + // } + // } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN + // if (Clear_OOB_MN || Is_even_MN) { + // clear(D(_, _, k)); + // } else { + // #pragma unroll + // for (int m = 0; m < size<1>(S); ++m) { + // if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) { + // clear(D(_, m, k)); + // } + // } + // } + // } + // } + // } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <bool Is_even_K=true, + typename Engine0, typename Layout0, typename Engine1, typename Layout1, + typename Engine2, typename Layout2, typename Engine3, typename Layout3> +__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S, + Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN, + Tensor<Engine3, Layout3> const &predicate_K, + const int max_MN=0, const int min_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); } + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) { + // if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); } + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(S(_, m, k), D(_, m, k)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <typename Engine, typename Layout> +__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){ + #pragma unroll + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); + } +} + +template <typename Engine0, typename Layout0, typename Engine1, typename Layout1> +__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){ + #pragma unroll + for (int i = 0; i < size(src_tensor); ++i) { + dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/examples/CPM.cu/src/model/activation.cuh b/examples/CPM.cu/src/model/activation.cuh new file mode 100644 index 00000000..f9c9d8ce --- /dev/null +++ b/examples/CPM.cu/src/model/activation.cuh @@ -0,0 +1,83 @@ +#pragma once +#include "../trait.cuh" +#include <cuda_runtime.h> + +namespace { +template <typename T> +__global__ void gated_silu_interleaved_kernel(int dim, const T* src, T* tgt) { + int row_offset = blockIdx.x * dim; + int row_offset_2 = row_offset * 2; + int col = blockIdx.y * blockDim.x + threadIdx.x; + int col2 = col + dim; + if (col < dim) { + float g = float(src[row_offset_2 + col]); + float u = float(src[row_offset_2 + col2]); + float s = 1.0f / (1.0f + expf(-g)); + tgt[row_offset + col] = T(g * s * u); + } +} + +template<typename T> +__global__ void gated_silu_kernel(int dim, const T* src, T* tgt) { + int row_offset = blockIdx.x * dim; + int col = blockIdx.y * blockDim.x + threadIdx.x; + if (col < dim) { + float g = float(src[row_offset + col]); + float u = float(tgt[row_offset + col]); + float s = 1.0f / (1.0f + expf(-g)); + tgt[row_offset + col] = T(g * s * u); + } +} + +template<typename T> +__global__ void silu_kernel(int dim, const T* src, T* tgt) { + int row_offset = blockIdx.x * dim; + int col = blockIdx.y * blockDim.x + threadIdx.x; + if (col < dim) { + float g = float(src[row_offset + col]); + float s = 1.0f / (1.0f + expf(-g)); + tgt[row_offset + col] = T(g * s); + } +} + +template<typename T> +__global__ void relu_kernel(int dim, const T* src, T* tgt) { + int row_offset = blockIdx.x * dim; + int col = blockIdx.y * blockDim.x + threadIdx.x; + if (col < dim) { + T v = src[row_offset + col]; + tgt[row_offset + col] = v > T(0) ? v : T(0); + } +} +} + +template <typename T> +void gated_silu_interleaved(const Stream& stream, int num_tokens, int dim, const T* src, T* tgt) { + gated_silu_interleaved_kernel<T><<<dim3(num_tokens, CEIL_DIV(dim, 256)), 256, 0, stream.stream>>>(dim, src, tgt); +} + +template <typename T> +void gated_silu(const Stream& stream, int num_tokens, int dim, const T* src, T* tgt) { + gated_silu_kernel<T><<<dim3(num_tokens, CEIL_DIV(dim, 256)), 256, 0, stream.stream>>>(dim, src, tgt); +} + +template<typename T> +void silu(const Stream& stream, int num_tokens, int dim, const T* src, T* tgt) { + silu_kernel<T><<<dim3(num_tokens, CEIL_DIV(dim, 256)), 256, 0, stream.stream>>>(dim, src, tgt); +} + +template <typename T> +void silu_inplace(const Stream& stream, int num_tokens, int dim, T* x) { + silu(stream, num_tokens, dim, x, x); +} + + +template<typename T> +void relu(const Stream& stream, int num_tokens, int dim, const T* src, T* tgt) { + relu_kernel<T><<<dim3(num_tokens, CEIL_DIV(dim, 256)), 256, 0, stream.stream>>>(dim, src, tgt); +} + +template <typename T> +void relu_inplace(const Stream& stream, int num_tokens, int dim, T* x) { + relu(stream, num_tokens, dim, x, x); +} \ No newline at end of file diff --git a/examples/CPM.cu/src/model/attn.cuh b/examples/CPM.cu/src/model/attn.cuh new file mode 100644 index 00000000..9988742f --- /dev/null +++ b/examples/CPM.cu/src/model/attn.cuh @@ -0,0 +1,235 @@ +#pragma once +#include "../trait.cuh" +#include "../utils.cuh" +#include "../flash_attn/flash_api.hpp" +#include "perf.cuh" +#include "norm.cuh" +#include "linear.cuh" +#include "rotary.cuh" +#include "kvcache.cuh" +#include "mask.cuh" +#include <cuda_runtime.h> + +namespace { +__global__ void copy_to_kvcache_kernel(int num_tokens, int dim, int total, float4* k, float4* v, float4* k_cache, float4* v_cache, int* cache_length) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int offset = idx + (cache_length[0] - num_tokens) * dim; + if (idx < total) { + k_cache[offset] = k[idx]; + v_cache[offset] = v[idx]; + } +} + +__global__ void permute_kernel(int num_tokens, int a, int b, float4* input, float4* output) { + int row = blockIdx.x; + int input_offset = row * (a + b + b); + int output_offset = row * a; + for (int i = threadIdx.x; i < a; i += blockDim.x) { + output[output_offset + i] = input[input_offset + i]; + } + input_offset += a; + output_offset = num_tokens * a + row * b; + for (int i = threadIdx.x; i < b; i += blockDim.x) { + output[output_offset + i] = input[input_offset + i]; + } + input_offset += b; + output_offset = num_tokens * (a + b) + row * b; + for (int i = threadIdx.x; i < b; i += blockDim.x) { + output[output_offset + i] = input[input_offset + i]; + } +} + +template <typename T> +void copy_to_kvcache(const Stream& stream, int num_tokens, T* k, T* v, KVCache<T>* kv_cache, int* cache_length) { + int dim = kv_cache->dim / (16/sizeof(T)); + int total = num_tokens * dim; + copy_to_kvcache_kernel<<<CEIL_DIV(total, 256), 256, 0, stream.stream>>>(num_tokens, dim, total, (float4*)k, (float4*)v, (float4*)kv_cache->k_cache, (float4*)kv_cache->v_cache, cache_length); +} + +template <typename T> +void permute(const Stream& stream, int num_tokens, int a, int b, T* input, T* output) { + a = a / (16/sizeof(T)); + b = b / (16/sizeof(T)); + permute_kernel<<<num_tokens, 512, 0, stream.stream>>>(num_tokens, a, b, (float4*)input, (float4*)output); +} + +} + +template <typename T> +struct Attention { + int hidden_size; + int num_attention_heads; + int num_key_value_heads; + int head_dim; + float rms_norm_eps; + + Norm<T> *attn_norm; + Linear<T> *q_proj, *k_proj, *v_proj; + Linear<T> *o_proj; + T* output; + + T* attn_output; + float *softmax_lse, *softmax_lse_accum, *oaccum; + + int window_size; + + Attention(int hidden_size, int num_attention_heads, int num_key_value_heads, int head_dim, float rms_norm_eps, int window_size = 0) { + this->hidden_size = hidden_size; + this->num_attention_heads = num_attention_heads; + this->num_key_value_heads = num_key_value_heads; + this->head_dim = head_dim; + this->rms_norm_eps = rms_norm_eps; + + this->attn_norm = new RMSNorm<T>(hidden_size, rms_norm_eps); + this->q_proj = new Linear<T>(hidden_size, num_attention_heads * head_dim); + this->k_proj = new Linear<T>(hidden_size, num_key_value_heads * head_dim); + this->v_proj = new Linear<T>(hidden_size, num_key_value_heads * head_dim); + this->o_proj = new Linear<T>(hidden_size, num_attention_heads * head_dim); + + this->window_size = window_size; + } + + void init_weight_ptr(Memory* memory) { + this->attn_norm->init_weight_ptr(memory); + this->q_proj->init_weight_ptr(memory); + this->k_proj->init_weight_ptr(memory); + this->v_proj->init_weight_ptr(memory); + this->o_proj->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t attn_norm_end = this->attn_norm->init_output_ptr(memory, num_tokens, offset); + int64_t q_proj_end = this->q_proj->init_output_ptr(memory, num_tokens, attn_norm_end); + int64_t k_proj_end = this->k_proj->init_output_ptr(memory, num_tokens, q_proj_end); + int64_t v_proj_end = this->v_proj->init_output_ptr(memory, num_tokens, k_proj_end); + + int64_t attn_output_end = memory->allocate((void**)&this->attn_output, offset, num_tokens * this->num_attention_heads * this->head_dim * sizeof(T)); + int64_t softmax_lse_end = memory->allocate((void**)&this->softmax_lse, v_proj_end, num_tokens * this->num_attention_heads * sizeof(float)); + const int max_num_splits = 128; // Maximum number of splits for attention computation + const int max_spec_tree_size = 64; // Maximum size of speculative decoding tree + int64_t softmax_lse_accum_end = memory->allocate((void**)&this->softmax_lse_accum, softmax_lse_end, max(max_num_splits * max_spec_tree_size, num_tokens) * this->num_attention_heads * sizeof(float)); + int64_t oaccum_end = memory->allocate((void**)&this->oaccum, softmax_lse_accum_end, max(max_num_splits * max_spec_tree_size, num_tokens) * this->num_attention_heads * this->head_dim * sizeof(float)); + + int64_t o_proj_end = this->o_proj->init_output_ptr(memory, num_tokens, v_proj_end); + this->output = this->o_proj->output; + + return std::max(oaccum_end, o_proj_end); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("q_proj") != std::string::npos) { + this->q_proj->load_to_storage(name, ptr); + } else if (name.find("k_proj") != std::string::npos) { + this->k_proj->load_to_storage(name, ptr); + } else if (name.find("v_proj") != std::string::npos) { + this->v_proj->load_to_storage(name, ptr); + } else if (name.find("o_proj") != std::string::npos) { + this->o_proj->load_to_storage(name, ptr); + } else if (name.find("input_layernorm") != std::string::npos) { + this->attn_norm->load_to_storage(name, ptr); + } else { + throw std::invalid_argument("Unsupported name " + name); + } + } + + void prefill(const Stream& stream, int32_t num_tokens, int32_t num_history_tokens, T* input, T* prev_output, int32_t* position_ids, KVCache<T>* kv_cache) { + T* k_cache = kv_cache->offset_k(num_history_tokens); + T* v_cache = kv_cache->offset_v(num_history_tokens); + + this->attn_norm->prefill(stream, num_tokens, input, prev_output); + this->q_proj->prefill(stream, num_tokens, this->attn_norm->output); + this->k_proj->prefill(stream, num_tokens, this->attn_norm->output, k_cache); + this->v_proj->prefill(stream, num_tokens, this->attn_norm->output, v_cache); + kv_cache->rotary_embedding->prefill(stream, num_tokens, this->num_attention_heads, this->num_key_value_heads, this->q_proj->output, k_cache, position_ids); + + cuda_perf_start_on_stream_f(PREFILL_ATTN_CORE, stream.stream); + mha_fwd_kvcache( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + num_history_tokens+num_tokens, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + this->q_proj->output, + kv_cache->k_cache, + kv_cache->v_cache, + nullptr, + Mask(nullptr), + this->attn_output, + this->softmax_lse, + this->softmax_lse_accum, + this->oaccum, + rsqrtf(float(this->head_dim)), + true, + -1, + -1, + 0, + stream.stream, + nullptr, + this->window_size + ); + cuda_perf_stop_on_stream_f(PREFILL_ATTN_CORE, stream.stream); + + // flash attention and put output to attn_norm->output + this->o_proj->prefill(stream, num_tokens, this->attn_output); + } + + void decode(const Stream& stream, int32_t num_tokens, int32_t padded_length, T* input, T* prev_output, int32_t* position_ids, int32_t* cache_length, const Mask& mask, KVCache<T>* kv_cache) { + this->attn_norm->prefill(stream, num_tokens, input, prev_output); + T *q = nullptr; +#ifdef DISABLE_MEMPOOL + this->q_proj->prefill(stream, num_tokens, this->attn_norm->output); + this->k_proj->prefill(stream, num_tokens, this->attn_norm->output); + this->v_proj->prefill(stream, num_tokens, this->attn_norm->output); + q = this->q_proj->output; + kv_cache->rotary_embedding->prefill(stream, num_tokens, this->num_attention_heads, this->num_key_value_heads, this->q_proj->output, this->k_proj->output, position_ids); + copy_to_kvcache(stream, num_tokens, this->k_proj->output, this->v_proj->output, kv_cache, cache_length); +#else + int merge_dim_out = (this->num_attention_heads + 2 * this->num_key_value_heads) * this->head_dim; + if (num_tokens > 1) { + linear<T>(stream, num_tokens, this->hidden_size, merge_dim_out, this->attn_norm->output, this->q_proj->weight, this->v_proj->output); + permute(stream, num_tokens, this->num_attention_heads * this->head_dim, this->num_key_value_heads * this->head_dim, this->v_proj->output, this->q_proj->output); + } else { + linear<T>(stream, num_tokens, this->hidden_size, merge_dim_out, this->attn_norm->output, this->q_proj->weight, this->q_proj->output); + } + q = this->q_proj->output; + T* k = q + num_tokens * this->num_attention_heads * this->head_dim; + T* v = k + num_tokens * this->num_key_value_heads * this->head_dim; + kv_cache->rotary_embedding->prefill(stream, num_tokens, this->num_attention_heads, this->num_key_value_heads, q, k, position_ids); + copy_to_kvcache(stream, num_tokens, k, v, kv_cache, cache_length); +#endif + + cuda_perf_start_on_stream_f(DECODE_ATTN_CORE, stream.stream); + mha_fwd_kvcache( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + padded_length, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + q, + kv_cache->k_cache, + kv_cache->v_cache, + cache_length, + mask, + this->attn_output, + this->softmax_lse, + this->softmax_lse_accum, + this->oaccum, + rsqrtf(float(this->head_dim)), + true, + -1, + -1, + 0, + stream.stream, + nullptr, + this->window_size + ); + cuda_perf_stop_on_stream_f(DECODE_ATTN_CORE, stream.stream); + + // flash attention and put output to attn_norm->output + this->o_proj->prefill(stream, num_tokens, this->attn_output); + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/drafter.cuh b/examples/CPM.cu/src/model/drafter.cuh new file mode 100644 index 00000000..6bf1ff4d --- /dev/null +++ b/examples/CPM.cu/src/model/drafter.cuh @@ -0,0 +1,48 @@ +#pragma once +#include <cuda_runtime.h> + +namespace { + +__global__ void seq_verify_kernel(int num_tokens, int32_t* pred, const int32_t* gt, const uint16_t* attn_mask, int32_t* d_best) { + int i = threadIdx.x; + + __shared__ uint16_t s_correct_mask; + uint16_t correct_mask = 1; + if (0 < i && i < num_tokens && pred[i] == gt[i-1]) correct_mask |= 1ULL << i; + // only 16 threads + correct_mask |= __shfl_down_sync(0x0000ffff, correct_mask, 8); + correct_mask |= __shfl_down_sync(0x0000ffff, correct_mask, 4); + correct_mask |= __shfl_down_sync(0x0000ffff, correct_mask, 2); + correct_mask |= __shfl_down_sync(0x0000ffff, correct_mask, 1); + if (i == 0) s_correct_mask = correct_mask; + __syncthreads(); + correct_mask = s_correct_mask; + + __shared__ int32_t mx[16]; + // int prefix_length = cache_length[0]; + if (i < num_tokens && ((correct_mask & attn_mask[i]) == attn_mask[i])) { + mx[i] = i + 1; + } else { + mx[i] = 1; + } + __syncthreads(); + for (int offset = 8; offset > 0; offset >>= 1) { + if (i < offset && mx[i + offset] > mx[i]) { + mx[i] = mx[i + offset]; + } + __syncthreads(); + } + if (i == 0) { + d_best[0] = mx[0]; + } + __syncthreads(); + +} + +} + + +void verify_seq_draft(const Stream& stream, int num_tokens, int32_t* pred, const int32_t* gt, const uint16_t* attn_mask, int32_t* best) { + // each SM has 32 threads + seq_verify_kernel<<<1, 16, 0, stream.stream>>>(num_tokens, pred, gt, attn_mask, best); +} \ No newline at end of file diff --git a/examples/CPM.cu/src/model/eagle.cuh b/examples/CPM.cu/src/model/eagle.cuh new file mode 100644 index 00000000..f426e25b --- /dev/null +++ b/examples/CPM.cu/src/model/eagle.cuh @@ -0,0 +1,517 @@ +#pragma once +#include "tree_drafter.cuh" +#include "model.cuh" +#include "topk.cuh" +#include "layer.cuh" +#include "kvcache.cuh" +#include "norm.cuh" +#include "elementwise.cuh" + +namespace { +__global__ void add_kernel(int32_t* ptr, int32_t value) { + ptr[threadIdx.x] += value; +} + +__global__ void repeat_kernel(int32_t dim, int32_t pos, const float4* input, float4* output) { + int row = blockIdx.x; + int col = threadIdx.x; + for (int i = col; i < dim; i += blockDim.x) { + output[row * dim + i] = input[pos * dim + i]; + } +} + +template<typename T> +__global__ void repeat_kernel_2(int32_t pos, const T* input, T* output) { + int col = threadIdx.x; + output[col] = input[pos]; +} + +template<typename T> +__global__ void log_softmax_kernel(int32_t dim, T* input) { + int base_idx = blockIdx.x * dim; + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + + __shared__ float s_val[32]; + float mx = -TypeTraits<T>::inf(); + for (int i = threadIdx.x; i < dim; i += blockDim.x) { + mx = fmaxf(float(input[base_idx + i]), mx); + } + mx = fmaxf(__shfl_down_sync(0xffffffff, mx, 16), mx); + mx = fmaxf(__shfl_down_sync(0xffffffff, mx, 8), mx); + mx = fmaxf(__shfl_down_sync(0xffffffff, mx, 4), mx); + mx = fmaxf(__shfl_down_sync(0xffffffff, mx, 2), mx); + mx = fmaxf(__shfl_down_sync(0xffffffff, mx, 1), mx); + if (lane_id == 0) s_val[warp_id] = mx; + __syncthreads(); + if (threadIdx.x < 32) { + mx = s_val[threadIdx.x]; + mx = fmaxf(__shfl_down_sync(0xffffffff, mx, 16), mx); + mx = fmaxf(__shfl_down_sync(0xffffffff, mx, 8), mx); + mx = fmaxf(__shfl_down_sync(0xffffffff, mx, 4), mx); + mx = fmaxf(__shfl_down_sync(0xffffffff, mx, 2), mx); + mx = fmaxf(__shfl_down_sync(0xffffffff, mx, 1), mx); + } + if (threadIdx.x == 0) { + s_val[0] = mx; + } + __syncthreads(); + mx = s_val[0]; + + float sum = 0; + for (int i = threadIdx.x; i < dim; i += blockDim.x) { + sum += expf(float(input[base_idx + i]) - mx); + } + sum += __shfl_down_sync(0xffffffff, sum, 16); + sum += __shfl_down_sync(0xffffffff, sum, 8); + sum += __shfl_down_sync(0xffffffff, sum, 4); + sum += __shfl_down_sync(0xffffffff, sum, 2); + sum += __shfl_down_sync(0xffffffff, sum, 1); + if (lane_id == 0) s_val[warp_id] = sum; + __syncthreads(); + if (threadIdx.x < 32) { + sum = s_val[threadIdx.x]; + sum += __shfl_down_sync(0xffffffff, sum, 16); + sum += __shfl_down_sync(0xffffffff, sum, 8); + sum += __shfl_down_sync(0xffffffff, sum, 4); + sum += __shfl_down_sync(0xffffffff, sum, 2); + sum += __shfl_down_sync(0xffffffff, sum, 1); + } + if (threadIdx.x == 0) { + s_val[0] = sum; + } + __syncthreads(); + sum = s_val[0]; + + for (int i = threadIdx.x; i < dim; i += blockDim.x) { + input[base_idx + i] = T( float(input[base_idx + i]) - mx - logf(sum) ); + } +} + +__global__ void init_tree_kernel(uint64_t* mask_2d) { + mask_2d[threadIdx.x] = 1ULL << threadIdx.x; +} + +__global__ void set_parent_kernel(int32_t num_tokens, int32_t* parent, const int32_t* pos, int32_t offset) { + parent[threadIdx.x] = pos[threadIdx.x] + offset; +} + +__global__ void update_tree_kernel(int32_t num_tokens, int32_t offset, uint64_t* mask_2d, const uint64_t* tmp_mask_2d, const int32_t* topk_pos) { + mask_2d[threadIdx.x] = tmp_mask_2d[topk_pos[threadIdx.x] / num_tokens] | (1ULL << (offset + threadIdx.x)); +} + +template<typename T> +__global__ void cumsum_kernel(int32_t dim, T* input, const T* weight) { + input[blockIdx.x * dim + threadIdx.x] += weight[blockIdx.x]; +} + +__global__ void remap_hidden_kernel(int32_t scale, int32_t dim, const int32_t* id_map, const float4* real_hidden, float4* output) { + int row = blockIdx.x; + int col = threadIdx.x; + int real_row = id_map[row] / scale; + for (int i = col; i < dim; i += blockDim.x) { + output[row * dim + i] = real_hidden[real_row * dim + i]; + } +} + +__global__ void remap_id_kernel(const int32_t* id_map, const int32_t* real_id, int32_t* output) { + output[threadIdx.x] = real_id[id_map[threadIdx.x]]; +} + +__global__ void remap_id_kernel(const int32_t* id_map, const int32_t* real_id, const int32_t* token_id_remap, int32_t* output) { + output[threadIdx.x] = token_id_remap[real_id[id_map[threadIdx.x]]]; +} + +__global__ void make_arange_kernel(int32_t* offset, int32_t* output) { + output[threadIdx.x] = threadIdx.x + offset[0]; +} + +} // namespace + +void add(const Stream& stream, int32_t num_tokens, int32_t* ptr, int32_t value) { + add_kernel<<<1, num_tokens, 0, stream.stream>>>(ptr, value); +} + +template<typename T> +void repeat(const Stream& stream, int32_t num_tokens, int32_t dim, int32_t pos, T* input, T* output=nullptr) { + if (output == nullptr) output = input; + if (dim > 1) { + dim = dim / (16 / sizeof(T)); + repeat_kernel<<<num_tokens, 512, 0, stream.stream>>>(dim, pos, (float4*)input, (float4*)output); + } else { + repeat_kernel_2<<<1, num_tokens, 0, stream.stream>>>(pos, input, output); + } +} + +template<typename T> +void log_softmax(const Stream& stream, int32_t num_tokens, int32_t dim, T* input) { + log_softmax_kernel<<<num_tokens, 1024, 0, stream.stream>>>(dim, input); +} + +void init_tree(const Stream& stream, int32_t num_tokens, uint64_t* mask_2d) { + init_tree_kernel<<<1, num_tokens, 0, stream.stream>>>(mask_2d); +} + +void set_parent(const Stream& stream, int32_t num_tokens, int32_t* parent, const int32_t* pos, int32_t offset) { + set_parent_kernel<<<1, num_tokens, 0, stream.stream>>>(num_tokens, parent, pos, offset); +} + +void update_tree(const Stream& stream, int32_t num_tokens, int32_t offset, uint64_t* mask_2d, const uint64_t* tmp_mask_2d, const int32_t* topk_pos) { + update_tree_kernel<<<1, num_tokens, 0, stream.stream>>>(num_tokens, offset, mask_2d, tmp_mask_2d, topk_pos); +} + +template<typename T> +void cumsum(const Stream& stream, int32_t num_tokens, int32_t dim, T* input, const T* weight) { + cumsum_kernel<<<num_tokens, dim, 0, stream.stream>>>(dim, input, weight); +} + +template<typename T> +void remap_hidden(const Stream& stream, int32_t num_tokens, int32_t dim, const int32_t* id_map, const T* real_hidden, T* output, int32_t scale=1) { + dim = dim / (16 / sizeof(T)); + remap_hidden_kernel<<<num_tokens, 512, 0, stream.stream>>>(scale, dim, id_map, (float4*)real_hidden, (float4*)output); +} + +void remap_id(const Stream& stream, int32_t num_tokens, int32_t* id_map, const int32_t* real_id, int32_t* output=nullptr) { + if (output == nullptr) output = id_map; + remap_id_kernel<<<1, num_tokens, 0, stream.stream>>>(id_map, real_id, output); +} + +void remap_id_fr(const Stream& stream, int32_t num_tokens, int32_t* id_map, const int32_t* real_id, const int32_t* token_id_remap, int32_t* output=nullptr) { + if (output == nullptr) output = id_map; + remap_id_kernel<<<1, num_tokens, 0, stream.stream>>>(id_map, real_id, token_id_remap, output); +} + +void make_arange(const Stream& stream, int32_t range, int32_t* offset, int32_t* output) { + make_arange_kernel<<<1, range, 0, stream.stream>>>(offset, output); +} + +__global__ void build_dynamic_tree_kernel(int32_t tree_size, int32_t pos_offset, int32_t topk_per_iter, const int32_t* tried_history_parent, const int32_t* topk_pos, int32_t* tree_pos, uint64_t* tree_mask, int32_t* tree_parent) { + __shared__ int32_t reverse_tree_id[1024]; + int tid = threadIdx.x; + if (tid != 0) { + reverse_tree_id[topk_pos[tid-1]] = tid; + } + __syncthreads(); + if (tid == 0) { + tree_mask[0] = 1; + tree_pos[0] = pos_offset; + for (int i = 1; i < tree_size; i++) { + int p = topk_pos[i-1]; + tree_pos[i] = pos_offset + ((p < topk_per_iter) ? 1 : (p - topk_per_iter) / (topk_per_iter * topk_per_iter) + 2); + tree_mask[i] = 1ULL << reverse_tree_id[p];; + + if (p < topk_per_iter) { + p = -1; + } else { + p = p - topk_per_iter; + if (p < topk_per_iter * topk_per_iter) { + p = p / topk_per_iter; + } else { + p = tried_history_parent[(p - topk_per_iter * topk_per_iter) / topk_per_iter]; + } + } + int parent = p == -1 ? 0 : reverse_tree_id[p]; + tree_parent[i] = parent; + tree_mask[i] |= tree_mask[parent]; + } + } +} + +void build_dynamic_tree(const Stream& stream, int32_t tree_size, int32_t pos_offset, int32_t topk_per_iter, const int32_t* tried_history_parent, const int32_t* topk_pos, int32_t* tree_pos, uint64_t* tree_mask, int32_t* tree_parent) { + +// TODO: remove this after fixing the bug +// #ifdef DEBUG + // Check if all elements in topk_pos are the same + int32_t h_topk_pos[tree_size]; + cudaMemcpy(h_topk_pos, topk_pos, tree_size * sizeof(int32_t), cudaMemcpyDeviceToHost); + for (int i = 1; i < tree_size; i++) { + if (h_topk_pos[i] != h_topk_pos[0]) break; + if (i == tree_size - 1) throw std::runtime_error("All topk_pos elements are identical"); + } +// #endif + build_dynamic_tree_kernel<<<1, tree_size, 0, stream.stream>>>(tree_size, pos_offset, topk_per_iter, tried_history_parent, topk_pos, tree_pos, tree_mask, tree_parent); +} + +template<typename T> +struct Skip : Norm<T> { + int dim; + + Skip(int dim) { + this->dim = dim; + } + + void init_weight_ptr(Memory* memory) {} + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + return memory->allocate((void**)&this->output, offset, num_tokens * dim * sizeof(T)); + } + + void load_to_storage(std::string name, void* ptr) {} + + void prefill(const Stream& stream, int32_t num_tokens, T* input, T* prev_output, T* tgt=nullptr) { + if (tgt == nullptr) tgt = this->output; + if (prev_output == nullptr) { + cudaMemcpy(tgt, input, sizeof(T) * this->dim * num_tokens, cudaMemcpyDeviceToDevice); + } else { + elementwise_add(stream, num_tokens, this->dim, input, prev_output, tgt); + } + } +}; + +template<typename T, class ModelType> +struct EagleImpl : Model { + int num_layers; + int num_iter; + int topk_per_iter; + int tree_size; + int total_tried; + + ModelType* model; + KVCacheManager<T>* kv_caches; + std::vector<Layer<T>*> layers; + Linear<T, true, true> *fc1; + Linear<T> *fc2; + functions::TopK<T>* topk_func; + functions::TopK<T>* topk_func_2; + + T *prev_hidden_state, *prev_embed; + int num_prev, num_history_tokens; + int32_t *eagle_position_ids, *eagle_cache_length; + int *eagle_original_length, eagle_padded_length; + uint64_t *eagle_mask_2d, *tmp_mask_2d; + T* eagle_logits; + T* tried_history_val; int32_t* tried_history_pos; + int32_t* tried_history_parent; + bool is_first_draft; + + int32_t *h_best, *d_best; + + T* tmp_kvcache; + + EagleImpl( + ModelType* model, + int num_layers, + int num_iter, + int topk_per_iter, + int tree_size + ) { + this->model = model; + this->num_layers = num_layers; + this->num_iter = num_iter; + this->topk_per_iter = topk_per_iter; + this->tree_size = tree_size; + assert(this->tree_size <= 64); // tree_size must be <= 64 + this->total_tried = topk_per_iter * topk_per_iter * (num_iter - 1) + topk_per_iter; + + kv_caches = new KVCacheManager<T>(num_layers, this->model->num_key_value_heads, this->model->head_dim); + fc1 = new Linear<T, true, true>(this->model->hidden_size, this->model->hidden_size); + fc2 = new Linear<T>(this->model->hidden_size, this->model->hidden_size); + for (int i = 0; i < num_layers; i++) { + layers.push_back(new Layer<T>(this->model->hidden_size, this->model->intermediate_size, this->model->num_attention_heads, this->model->num_key_value_heads, this->model->head_dim, this->model->rms_norm_eps)); + } + + assert(topk_per_iter <= this->tree_size-1); + + topk_func = new functions::TopK<T>(model->vocab_size, topk_per_iter); + topk_func_2 = new functions::TopK<T>(total_tried, this->tree_size-1); + } + + void init_weight_ptr(Memory* memory) { + fc1->init_weight_ptr(memory); + fc2->init_weight_ptr(memory); + for (int i = 0; i < num_layers; i++) { + layers[i]->init_weight_ptr(memory); + } + layers[0]->attn->attn_norm = new Skip<T>(this->model->hidden_size); + kv_caches->rotary_embedding = this->model->kv_caches->rotary_embedding; + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + offset = fc1->init_output_ptr(memory, num_tokens, offset); + offset = fc2->init_output_ptr(memory, num_tokens, offset); + int64_t layer_end = 0; + for (int i = 0; i < num_layers; i++) { + layer_end = layers[i]->init_output_ptr(memory, num_tokens, offset); + } + offset = layer_end; + offset = memory->allocate((void**)&eagle_logits, offset, this->topk_per_iter * this->model->vocab_size * sizeof(T)); + offset = memory->allocate((void**)&eagle_mask_2d, offset, this->topk_per_iter * sizeof(uint64_t)); + offset = memory->allocate((void**)&tmp_mask_2d, offset, this->topk_per_iter * sizeof(uint64_t)); + offset = memory->allocate((void**)&tried_history_val, offset, this->total_tried * sizeof(T)); + offset = memory->allocate((void**)&tried_history_pos, offset, this->total_tried * sizeof(int32_t)); + offset = memory->allocate((void**)&tried_history_parent, offset, this->topk_per_iter * (this->num_iter - 1) * sizeof(int32_t)); + cudaMallocHost(&eagle_original_length, sizeof(int32_t)); + + offset = topk_func->init_output_ptr(memory, this->topk_per_iter, offset); + offset = topk_func_2->init_output_ptr(memory, 1, offset); + + offset = memory->allocate((void**)&prev_hidden_state, offset, num_tokens * this->model->hidden_size * sizeof(T)); + offset = memory->allocate((void**)&prev_embed, offset, num_tokens * this->model->hidden_size * sizeof(T)); + offset = memory->allocate((void**)&eagle_position_ids, offset, num_tokens * sizeof(int32_t)); + offset = memory->allocate((void**)&eagle_cache_length, offset, sizeof(int32_t)); + + offset = memory->allocate((void**)&d_best, offset, 2 * sizeof(int32_t)); + cudaMallocHost(&h_best, 2 * sizeof(int32_t)); + offset = memory->allocate((void**)&tmp_kvcache, offset, 64 * this->model->kv_caches->num_hidden_layers * 2 * this->model->kv_caches->dim * sizeof(T)); + return offset; + } + + int init_storage() { + this->model->init_weight_ptr(this->model->memory); + this->init_weight_ptr(this->model->memory); + int64_t offset = this->model->init_output_ptr(this->model->memory, this->model->chunk_length, this->model->memory->model_offset); + int64_t kv_cache_offset = this->init_output_ptr(this->model->memory, this->model->chunk_length, offset); + float ratio = float(this->model->num_hidden_layers) / (this->model->num_hidden_layers + this->num_layers); + kv_cache_offset = this->model->kv_caches->init_output_ptr(this->model->memory, kv_cache_offset, ratio); + kv_caches->init_output_ptr(this->model->memory, kv_cache_offset); + return min(kv_caches->budget + 1, this->model->kv_caches->budget); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.substr(0, 5) == "eagle") { + if (name.substr(0, 9) == "eagle.fc1") { + fc1->load_to_storage(name, ptr); + } else if (name.substr(0, 9) == "eagle.fc2") { + fc2->load_to_storage(name, ptr); + } else { + std::regex layer_regex("eagle\\.layers\\.(\\d+)\\.(.*)"); + std::smatch matches; + if (std::regex_search(name, matches, layer_regex)) { + int layer_idx = std::stoi(matches[1]); + layers[layer_idx]->load_to_storage(matches[2], ptr); + } else { + throw std::invalid_argument("Unsupported name (layer_idx not found): " + name); + } + } + } else { + this->model->load_to_storage(name, ptr); + } + } + + void eagle_prefill(int num_history_tokens) { + cudaMemcpy(this->prev_embed + (num_prev - 1) * this->model->hidden_size, this->model->embedding->output, this->model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->fc1->prefill(calc_stream, num_prev, this->prev_embed); + this->fc2->prefill(calc_stream, num_prev, this->prev_hidden_state); + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc1->output, this->fc2->output, this->fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < num_layers; i++) { + this->layers[i]->prefill(num_prev, num_history_tokens, this->fc2->output, layer_output, this->eagle_position_ids, this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc2->output, layer_output, this->fc2->output); + } + + void eagle_decode(int32_t* cache_length) { + this->fc1->prefill(calc_stream, num_prev, this->prev_embed); + this->fc2->prefill(calc_stream, num_prev, this->prev_hidden_state); + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc1->output, this->fc2->output, this->fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < num_layers; i++) { + this->layers[i]->decode(num_prev, this->eagle_padded_length, this->fc2->output, layer_output, this->eagle_position_ids, cache_length, Mask(nullptr), this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc2->output, layer_output, this->fc2->output); + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, int32_t* input, int32_t* position_ids, void* output) { + this->model->embedding->prefill(calc_stream, num_tokens, input); + if (num_history_tokens > 0) { + this->eagle_prefill(this->num_history_tokens); + } + + cudaMemcpy(this->prev_embed, this->model->embedding->output + this->model->hidden_size, (num_tokens - 1) * this->model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->model->prefill_embed(num_tokens, num_history_tokens, this->model->embedding->output, position_ids, output); + this->prev_hidden_state = this->model->norm->output; + cudaMemcpy(this->eagle_position_ids, position_ids, num_tokens * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->num_prev = num_tokens; + + this->num_history_tokens = num_history_tokens; + this->is_first_draft = true; + } + + void decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + this->model->decode(num_tokens, padded_length, input, position_ids, cache_length, mask_2d, output); + } + + void draft(int32_t* tree_draft_ids, int32_t* tree_position_ids, int32_t* cache_length, uint64_t* tree_attn_mask, int32_t* tree_parent) { + cudaMemcpy(this->eagle_original_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->eagle_padded_length = (this->eagle_original_length[0] + 256 - 1) / 128 * 128; + + + if (this->is_first_draft) { + this->model->embedding->prefill(calc_stream, 1, tree_draft_ids); + this->eagle_prefill(this->num_history_tokens); + } else { + this->eagle_decode(cache_length); + } + cudaMemcpy(this->eagle_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->eagle_position_ids, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + repeat(calc_stream, topk_per_iter, 1, 0, this->eagle_position_ids); + + { // d = 0 + this->model->lm_head->prefill(calc_stream, 1, this->fc2->output + (num_prev - 1) * this->model->hidden_size, this->eagle_logits); + log_softmax(calc_stream, 1, this->model->vocab_size, this->eagle_logits); + this->topk_func->prefill(calc_stream, 1, this->eagle_logits); + cudaMemcpy(this->topk_func_2->topk_pos, this->topk_func->topk_pos, topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->topk_func_2->topk_val, this->topk_func->topk_val, topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->tried_history_val, this->topk_func->topk_val, topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->tried_history_pos, this->topk_func->topk_pos, topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + repeat(calc_stream, topk_per_iter, this->model->hidden_size, num_prev-1, this->fc2->output, this->fc1->output); + init_tree(calc_stream, topk_per_iter, this->eagle_mask_2d); + } + for (int d = 1; d < this->num_iter; ++d) { + add(calc_stream, 1, this->eagle_cache_length, topk_per_iter); + this->model->embedding->prefill(calc_stream, topk_per_iter, this->topk_func_2->topk_pos); + this->fc2->prefill(calc_stream, topk_per_iter, this->fc1->output); + this->fc1->prefill(calc_stream, topk_per_iter, this->model->embedding->output); + elementwise_add(calc_stream, topk_per_iter, this->model->hidden_size, this->fc1->output, this->fc2->output, this->fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < num_layers; i++) { + this->layers[i]->decode(topk_per_iter, this->eagle_padded_length, this->fc2->output, layer_output, this->eagle_position_ids, this->eagle_cache_length, Mask(eagle_mask_2d, topk_per_iter, topk_per_iter * d), this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_add(calc_stream, topk_per_iter, this->model->hidden_size, this->fc2->output, layer_output, this->fc2->output); + add(calc_stream, topk_per_iter, this->eagle_position_ids, 1); + + this->model->lm_head->prefill(calc_stream, topk_per_iter, this->fc2->output, this->eagle_logits); + log_softmax(calc_stream, topk_per_iter, this->model->vocab_size, this->eagle_logits); + this->topk_func->prefill(calc_stream, topk_per_iter, this->eagle_logits); + cumsum(calc_stream, topk_per_iter, topk_per_iter, this->topk_func->topk_val, this->topk_func_2->topk_val); + cudaMemcpy(this->tried_history_val + topk_per_iter + (d - 1) * topk_per_iter * topk_per_iter, this->topk_func->topk_val, topk_per_iter * topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->tried_history_pos + topk_per_iter + (d - 1) * topk_per_iter * topk_per_iter, this->topk_func->topk_pos, topk_per_iter * topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->topk_func_2->prefill(calc_stream, 1, this->topk_func->topk_val, topk_per_iter * topk_per_iter, topk_per_iter); + + cudaMemcpy(this->tmp_mask_2d, this->eagle_mask_2d, topk_per_iter * sizeof(uint64_t), cudaMemcpyDeviceToDevice); + set_parent(calc_stream, topk_per_iter, this->tried_history_parent + (d - 1) * topk_per_iter, this->topk_func_2->topk_pos, topk_per_iter + (d - 1) * topk_per_iter * topk_per_iter); + update_tree(calc_stream, topk_per_iter, topk_per_iter * d, this->eagle_mask_2d, this->tmp_mask_2d, this->topk_func_2->topk_pos); + remap_hidden(calc_stream, topk_per_iter, this->model->hidden_size, this->topk_func_2->topk_pos, this->fc2->output, this->fc1->output, topk_per_iter); + remap_id(calc_stream, topk_per_iter, this->topk_func_2->topk_pos, this->topk_func->topk_pos); + } + + this->topk_func_2->prefill(calc_stream, 1, this->tried_history_val); + + // build tree + build_dynamic_tree(calc_stream, this->tree_size, this->eagle_original_length[0], this->topk_per_iter, this->tried_history_parent, this->topk_func_2->topk_pos, tree_position_ids, tree_attn_mask, tree_parent); + remap_id(calc_stream, this->tree_size-1, this->topk_func_2->topk_pos, this->tried_history_pos, tree_draft_ids + 1); + + this->is_first_draft = false; + } + + int verify(int32_t num_tokens, int32_t* pred, int32_t* gt, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, int32_t* tree_parent) { + verify_draft(calc_stream, num_tokens, pred, gt, position_ids, cache_length, mask_2d, tree_parent, this->d_best); + cudaMemcpyAsync(this->h_best, this->d_best, 2 * sizeof(int32_t), cudaMemcpyDeviceToHost, calc_stream.stream); + cudaStreamSynchronize(calc_stream.stream); + + this->num_prev = h_best[0]; + remap_hidden(calc_stream, this->num_prev, this->model->hidden_size, pred, this->model->norm->output, this->prev_hidden_state); + + fix_kv_cache(calc_stream, h_best[0], this->model->kv_caches->num_hidden_layers * 2, this->model->kv_caches->dim, pred, gt, cache_length, this->model->kv_caches->d_flat_caches, this->tmp_kvcache); + + this->model->embedding->prefill(calc_stream, this->num_prev, pred); + cudaMemcpy(this->prev_embed, this->model->embedding->output, this->num_prev * this->model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + make_arange(calc_stream, this->num_prev, cache_length, this->eagle_position_ids); + + return h_best[0]; + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/eagle_base_quant/eagle_base_w4a16_gptq_marlin.cuh b/examples/CPM.cu/src/model/eagle_base_quant/eagle_base_w4a16_gptq_marlin.cuh new file mode 100644 index 00000000..bb6a9c82 --- /dev/null +++ b/examples/CPM.cu/src/model/eagle_base_quant/eagle_base_w4a16_gptq_marlin.cuh @@ -0,0 +1,260 @@ +#pragma once +#include "../w4a16_gptq_marlin/w4a16_gptq_marlin_model.cuh" +#include "../eagle.cuh" + +template<typename T> +struct EagleImplBaseW4A16GPTQMarlin : Model { + int num_layers; + int num_iter; + int topk_per_iter; + int tree_size; + int total_tried; + + W4A16GPTQMarlinModelImpl<T>* model; + KVCacheManager<T>* kv_caches; + std::vector<Layer<T>*> layers; + Linear<T, true, true> *fc1; + Linear<T> *fc2; + functions::TopK<T>* topk_func; + functions::TopK<T>* topk_func_2; + + T *prev_hidden_state, *prev_embed; + int num_prev, num_history_tokens; + int32_t *eagle_position_ids, *eagle_cache_length; + int *eagle_original_length, eagle_padded_length; + uint64_t *eagle_mask_2d, *tmp_mask_2d; + T* eagle_logits; + T* tried_history_val; int32_t* tried_history_pos; + int32_t* tried_history_parent; + bool is_first_draft; + + int32_t *h_best, *d_best; + + T* tmp_kvcache; + + EagleImplBaseW4A16GPTQMarlin( + W4A16GPTQMarlinModelImpl<T>* model, + int num_layers, + int num_iter, + int topk_per_iter, + int tree_size + ) { + this->model = model; + this->num_layers = num_layers; + this->num_iter = num_iter; + this->topk_per_iter = topk_per_iter; + this->tree_size = tree_size; + this->total_tried = topk_per_iter * topk_per_iter * (num_iter - 1) + topk_per_iter; + + kv_caches = new KVCacheManager<T>(num_layers, this->model->num_key_value_heads, this->model->head_dim); + fc1 = new Linear<T, true, true>(this->model->hidden_size, this->model->hidden_size); + fc2 = new Linear<T>(this->model->hidden_size, this->model->hidden_size); + for (int i = 0; i < num_layers; i++) { + layers.push_back(new Layer<T>(this->model->hidden_size, this->model->intermediate_size, this->model->num_attention_heads, this->model->num_key_value_heads, this->model->head_dim, this->model->rms_norm_eps)); + } + + assert(topk_per_iter <= this->tree_size-1); + + topk_func = new functions::TopK<T>(model->vocab_size, topk_per_iter); + topk_func_2 = new functions::TopK<T>(total_tried, this->tree_size-1); + } + + void init_weight_ptr(Memory* memory) { + fc1->init_weight_ptr(memory); + fc2->init_weight_ptr(memory); + for (int i = 0; i < num_layers; i++) { + layers[i]->init_weight_ptr(memory); + } + layers[0]->attn->attn_norm = new Skip<T>(this->model->hidden_size); + kv_caches->rotary_embedding = this->model->kv_caches->rotary_embedding; + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + offset = fc1->init_output_ptr(memory, num_tokens, offset); + offset = fc2->init_output_ptr(memory, num_tokens, offset); + int64_t layer_end = 0; + for (int i = 0; i < num_layers; i++) { + layer_end = layers[i]->init_output_ptr(memory, num_tokens, offset); + } + offset = layer_end; + offset = memory->allocate((void**)&eagle_logits, offset, this->topk_per_iter * this->model->vocab_size * sizeof(T)); + offset = memory->allocate((void**)&eagle_mask_2d, offset, this->topk_per_iter * sizeof(uint64_t)); + offset = memory->allocate((void**)&tmp_mask_2d, offset, this->topk_per_iter * sizeof(uint64_t)); + offset = memory->allocate((void**)&tried_history_val, offset, this->total_tried * sizeof(T)); + offset = memory->allocate((void**)&tried_history_pos, offset, this->total_tried * sizeof(int32_t)); + offset = memory->allocate((void**)&tried_history_parent, offset, this->topk_per_iter * (this->num_iter - 1) * sizeof(int32_t)); + cudaMallocHost(&eagle_original_length, sizeof(int32_t)); + + offset = topk_func->init_output_ptr(memory, this->topk_per_iter, offset); + offset = topk_func_2->init_output_ptr(memory, 1, offset); + + offset = memory->allocate((void**)&prev_hidden_state, offset, num_tokens * this->model->hidden_size * sizeof(T)); + offset = memory->allocate((void**)&prev_embed, offset, num_tokens * this->model->hidden_size * sizeof(T)); + offset = memory->allocate((void**)&eagle_position_ids, offset, num_tokens * sizeof(int32_t)); + offset = memory->allocate((void**)&eagle_cache_length, offset, sizeof(int32_t)); + + offset = memory->allocate((void**)&d_best, offset, 2 * sizeof(int32_t)); + cudaMallocHost(&h_best, 2 * sizeof(int32_t)); + offset = memory->allocate((void**)&tmp_kvcache, offset, 64 * this->model->kv_caches->num_hidden_layers * 2 * this->model->kv_caches->dim * sizeof(T)); + return offset; + } + + int init_storage() { + this->model->init_weight_ptr(this->model->memory); + this->init_weight_ptr(this->model->memory); + int64_t offset = this->model->init_output_ptr(this->model->memory, this->model->chunk_length, this->model->memory->model_offset); + int64_t kv_cache_offset = this->init_output_ptr(this->model->memory, this->model->chunk_length, offset); + float ratio = float(this->model->num_hidden_layers) / (this->model->num_hidden_layers + this->num_layers); + kv_cache_offset = this->model->kv_caches->init_output_ptr(this->model->memory, kv_cache_offset, ratio); + kv_caches->init_output_ptr(this->model->memory, kv_cache_offset); + return min(kv_caches->budget + 1, this->model->kv_caches->budget); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.substr(0, 5) == "eagle") { + if (name.substr(0, 9) == "eagle.fc1") { + fc1->load_to_storage(name, ptr); + } else if (name.substr(0, 9) == "eagle.fc2") { + fc2->load_to_storage(name, ptr); + } else { + std::regex layer_regex("eagle\\.layers\\.(\\d+)\\.(.*)"); + std::smatch matches; + if (std::regex_search(name, matches, layer_regex)) { + int layer_idx = std::stoi(matches[1]); + layers[layer_idx]->load_to_storage(matches[2], ptr); + } else { + throw std::invalid_argument("Unsupported name (layer_idx not found): " + name); + } + } + } else { + this->model->load_to_storage(name, ptr); + } + } + + void eagle_prefill(int num_history_tokens) { + cudaMemcpy(this->prev_embed + (num_prev - 1) * this->model->hidden_size, this->model->embedding->output, this->model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->fc1->prefill(calc_stream, num_prev, this->prev_embed); + this->fc2->prefill(calc_stream, num_prev, this->prev_hidden_state); + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc1->output, this->fc2->output, this->fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < num_layers; i++) { + this->layers[i]->prefill(num_prev, num_history_tokens, this->fc2->output, layer_output, this->eagle_position_ids, this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc2->output, layer_output, this->fc2->output); + } + + void eagle_decode(int32_t* cache_length) { + this->fc1->prefill(calc_stream, num_prev, this->prev_embed); + this->fc2->prefill(calc_stream, num_prev, this->prev_hidden_state); + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc1->output, this->fc2->output, this->fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < num_layers; i++) { + this->layers[i]->decode(num_prev, this->eagle_padded_length, this->fc2->output, layer_output, this->eagle_position_ids, cache_length, Mask(nullptr), this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc2->output, layer_output, this->fc2->output); + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, int32_t* input, int32_t* position_ids, void* output) { + this->model->embedding->prefill(calc_stream, num_tokens, input); + if (num_history_tokens > 0) { + this->eagle_prefill(this->num_history_tokens); + } + + cudaMemcpy(this->prev_embed, this->model->embedding->output + this->model->hidden_size, (num_tokens - 1) * this->model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->model->prefill_embed(num_tokens, num_history_tokens, this->model->embedding->output, position_ids, output); + this->prev_hidden_state = this->model->norm->output; + cudaMemcpy(this->eagle_position_ids, position_ids, num_tokens * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->num_prev = num_tokens; + + this->num_history_tokens = num_history_tokens; + this->is_first_draft = true; + } + + void decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + this->model->decode(num_tokens, padded_length, input, position_ids, cache_length, mask_2d, output); + } + + void draft(int32_t* tree_draft_ids, int32_t* tree_position_ids, int32_t* cache_length, uint64_t* tree_attn_mask, int32_t* tree_parent) { + cudaMemcpy(this->eagle_original_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->eagle_padded_length = (this->eagle_original_length[0] + 256 - 1) / 128 * 128; + + + if (this->is_first_draft) { + this->model->embedding->prefill(calc_stream, 1, tree_draft_ids); + this->eagle_prefill(this->num_history_tokens); + } else { + this->eagle_decode(cache_length); + } + cudaMemcpy(this->eagle_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->eagle_position_ids, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + repeat(calc_stream, topk_per_iter, 1, 0, this->eagle_position_ids); + + { // d = 0 + this->model->lm_head->prefill(calc_stream, 1, this->fc2->output + (num_prev - 1) * this->model->hidden_size, this->eagle_logits); + log_softmax(calc_stream, 1, this->model->vocab_size, this->eagle_logits); + this->topk_func->prefill(calc_stream, 1, this->eagle_logits); + cudaMemcpy(this->topk_func_2->topk_pos, this->topk_func->topk_pos, topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->topk_func_2->topk_val, this->topk_func->topk_val, topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->tried_history_val, this->topk_func->topk_val, topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->tried_history_pos, this->topk_func->topk_pos, topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + repeat(calc_stream, topk_per_iter, this->model->hidden_size, num_prev-1, this->fc2->output, this->fc1->output); + init_tree(calc_stream, topk_per_iter, this->eagle_mask_2d); + } + for (int d = 1; d < this->num_iter; ++d) { + add(calc_stream, 1, this->eagle_cache_length, topk_per_iter); + this->model->embedding->prefill(calc_stream, topk_per_iter, this->topk_func_2->topk_pos); + this->fc2->prefill(calc_stream, topk_per_iter, this->fc1->output); + this->fc1->prefill(calc_stream, topk_per_iter, this->model->embedding->output); + elementwise_add(calc_stream, topk_per_iter, this->model->hidden_size, this->fc1->output, this->fc2->output, this->fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < num_layers; i++) { + this->layers[i]->decode(topk_per_iter, this->eagle_padded_length, this->fc2->output, layer_output, this->eagle_position_ids, this->eagle_cache_length, Mask(eagle_mask_2d, topk_per_iter, topk_per_iter * d), this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_add(calc_stream, topk_per_iter, this->model->hidden_size, this->fc2->output, layer_output, this->fc2->output); + add(calc_stream, topk_per_iter, this->eagle_position_ids, 1); + + this->model->lm_head->prefill(calc_stream, topk_per_iter, this->fc2->output, this->eagle_logits); + log_softmax(calc_stream, topk_per_iter, this->model->vocab_size, this->eagle_logits); + this->topk_func->prefill(calc_stream, topk_per_iter, this->eagle_logits); + cumsum(calc_stream, topk_per_iter, topk_per_iter, this->topk_func->topk_val, this->topk_func_2->topk_val); + cudaMemcpy(this->tried_history_val + topk_per_iter + (d - 1) * topk_per_iter * topk_per_iter, this->topk_func->topk_val, topk_per_iter * topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->tried_history_pos + topk_per_iter + (d - 1) * topk_per_iter * topk_per_iter, this->topk_func->topk_pos, topk_per_iter * topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->topk_func_2->prefill(calc_stream, 1, this->topk_func->topk_val, topk_per_iter * topk_per_iter, topk_per_iter); + + cudaMemcpy(this->tmp_mask_2d, this->eagle_mask_2d, topk_per_iter * sizeof(uint64_t), cudaMemcpyDeviceToDevice); + set_parent(calc_stream, topk_per_iter, this->tried_history_parent + (d - 1) * topk_per_iter, this->topk_func_2->topk_pos, topk_per_iter + (d - 1) * topk_per_iter * topk_per_iter); + update_tree(calc_stream, topk_per_iter, topk_per_iter * d, this->eagle_mask_2d, this->tmp_mask_2d, this->topk_func_2->topk_pos); + remap_hidden(calc_stream, topk_per_iter, this->model->hidden_size, this->topk_func_2->topk_pos, this->fc2->output, this->fc1->output, topk_per_iter); + remap_id(calc_stream, topk_per_iter, this->topk_func_2->topk_pos, this->topk_func->topk_pos); + } + + this->topk_func_2->prefill(calc_stream, 1, this->tried_history_val); + + // build tree + build_dynamic_tree(calc_stream, this->tree_size, this->eagle_original_length[0], this->topk_per_iter, this->tried_history_parent, this->topk_func_2->topk_pos, tree_position_ids, tree_attn_mask, tree_parent); + remap_id(calc_stream, this->tree_size-1, this->topk_func_2->topk_pos, this->tried_history_pos, tree_draft_ids + 1); + + this->is_first_draft = false; + } + + int verify(int32_t num_tokens, int32_t* pred, int32_t* gt, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, int32_t* tree_parent) { + verify_draft(calc_stream, num_tokens, pred, gt, position_ids, cache_length, mask_2d, tree_parent, this->d_best); + cudaMemcpyAsync(this->h_best, this->d_best, 2 * sizeof(int32_t), cudaMemcpyDeviceToHost, calc_stream.stream); + cudaStreamSynchronize(calc_stream.stream); + + this->num_prev = h_best[0]; + remap_hidden(calc_stream, this->num_prev, this->model->hidden_size, pred, this->model->norm->output, this->prev_hidden_state); + + fix_kv_cache(calc_stream, h_best[0], this->model->kv_caches->num_hidden_layers * 2, this->model->kv_caches->dim, pred, gt, cache_length, this->model->kv_caches->d_flat_caches, this->tmp_kvcache); + + this->model->embedding->prefill(calc_stream, this->num_prev, pred); + cudaMemcpy(this->prev_embed, this->model->embedding->output, this->num_prev * this->model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + make_arange(calc_stream, this->num_prev, cache_length, this->eagle_position_ids); + + return h_best[0]; + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/elementwise.cuh b/examples/CPM.cu/src/model/elementwise.cuh new file mode 100644 index 00000000..1e3ff8f9 --- /dev/null +++ b/examples/CPM.cu/src/model/elementwise.cuh @@ -0,0 +1,87 @@ +#pragma once +#include <cuda_runtime.h> +#include "../trait.cuh" +#include "../utils.cuh" + +namespace { +template <typename T2> +__global__ void batched_add_kernel(int dim, const T2* a, const T2* b, T2* c) { + int row = blockIdx.x * dim; + int col = blockIdx.y * blockDim.x + threadIdx.x; + if (col < dim) { + c[row + col] = a[row + col] + b[col]; + } +} + +template <typename T2> +__global__ void elementwise_add_kernel(int dim, const T2* a, const T2* b, T2* c) { + int row = blockIdx.x * dim; + int col = blockIdx.y * blockDim.x + threadIdx.x; + if (col < dim) { + c[row + col] = a[row + col] + b[row + col]; + } +} + +template <typename T2> +__global__ void elementwise_add3_kernel(int dim, const T2* a, const T2* b, const T2*c, T2* d) { + int row = blockIdx.x * dim; + int col = blockIdx.y * blockDim.x + threadIdx.x; + if (col < dim) { + d[row + col] = a[row + col] + b[row + col] + c[row+col]; + } +} + +template <typename T, typename T2> +__global__ void elementwise_scale_kernel(int dim, const T2* a, float v, T2* b) { + int row = blockIdx.x * dim; + int col = blockIdx.y * blockDim.x + threadIdx.x; + if (col < dim) { + b[row + col] = a[row + col] * T2(T(v), T(v)); + } +} + +template <typename T> +__global__ void batched_mul_kernel(int dim, const T* a, const T* b, T* c) { + int row = blockIdx.x; + int col = threadIdx.x; + T bv = b[row]; + for (int i = col; i < dim; i += blockDim.x) { + c[row * dim + i] = a[row * dim + i] * bv; + } +} +} // namespace + +template <typename T> +void batched_add(const Stream& stream, int num_tokens, int dim, const T* a, const T* b, T* c) { + using T2 = typename TypeTraits<T>::half2; + dim = dim / 2; + batched_add_kernel<T2><<<dim3(num_tokens, CEIL_DIV(dim, 512)), 512, 0, stream.stream>>>(dim, (T2*)a, (T2*)b, (T2*)c); +} + +template <typename T> +void elementwise_add(const Stream& stream, int num_tokens, int dim, const T* a, const T* b, T* c) { + using T2 = typename TypeTraits<T>::half2; + dim = dim / 2; + elementwise_add_kernel<T2><<<dim3(num_tokens, CEIL_DIV(dim, 512)), 512, 0, stream.stream>>>(dim, (T2*)a, (T2*)b, (T2*)c); +} + +template <typename T> +void elementwise_add3(const Stream& stream, int num_tokens, int dim, const T* a, const T* b, const T* c, T* d) { + using T2 = typename TypeTraits<T>::half2; + dim = dim / 2; + elementwise_add3_kernel<T2><<<dim3(num_tokens, CEIL_DIV(dim, 512)), 512, 0, stream.stream>>>(dim, (T2*)a, (T2*)b, (T2*)c, (T2*)d); +} + +template <typename T> +void elementwise_scale(const Stream& stream, int num_tokens, int dim, T* a, float v, T* b = nullptr) { + if (v == 1.0 && b == nullptr) return; + if (b == nullptr) b = a; + using T2 = typename TypeTraits<T>::half2; + dim = dim / 2; + elementwise_scale_kernel<T, T2><<<dim3(num_tokens, CEIL_DIV(dim, 512)), 512, 0, stream.stream>>>(dim, (T2*)a, v, (T2*)b); +} + +template <typename T> +void batched_mul(const Stream& stream, int num_tokens, int dim, const T* a, const T* b, T* c) { + batched_mul_kernel<<<num_tokens, 128, 0, stream.stream>>>(dim, (T*)a, (T*)b, (T*)c); +} \ No newline at end of file diff --git a/examples/CPM.cu/src/model/embedding.cuh b/examples/CPM.cu/src/model/embedding.cuh new file mode 100644 index 00000000..12eb1184 --- /dev/null +++ b/examples/CPM.cu/src/model/embedding.cuh @@ -0,0 +1,53 @@ +#pragma once +#include <cuda_runtime.h> +#include "../utils.cuh" + +namespace { +template <typename T> +__global__ void embedding_kernel(int32_t num_cols, const int32_t* input, const float4* weight, float4* output) { + int row = blockIdx.x; + int col = threadIdx.x; + int offset_output = row * num_cols; + int offset_weight = input[row] * num_cols; + for (int i = col; i < num_cols; i += blockDim.x) { + output[offset_output + i] = weight[offset_weight + i]; + } +} + +template <typename T> +void embedding(const Stream& stream, int32_t num_tokens, int32_t hidden_size, const int32_t* input, const T* weight, T* output) { + embedding_kernel<T><<<num_tokens, 256, 0, stream.stream>>>(hidden_size/(16/sizeof(T)), input, (float4*)weight, (float4*)output); +} +} // namespace + +template <typename T> +struct Embedding { + int vocab_size; + int hidden_size; + T* weight; + T* output; + float embed_scale; + + Embedding(int vocab_size, int hidden_size, float embed_scale = 1.0f) { + this->vocab_size = vocab_size; + this->hidden_size = hidden_size; + this->embed_scale = embed_scale; + } + + void init_weight_ptr(Memory* memory) { + weight = (T*)memory->allocate_for_model(vocab_size * hidden_size * sizeof(T)); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + return memory->allocate((void**)&this->output, offset, num_tokens * hidden_size * sizeof(T)); + } + + void load_to_storage(std::string name, void* ptr) { + cudaMemcpy((void*)weight, ptr, vocab_size * hidden_size * sizeof(T), cudaMemcpyHostToDevice); + } + + void prefill(const Stream& stream, int32_t num_tokens, int32_t* input) { + embedding(stream, num_tokens, this->hidden_size, input, this->weight, this->output); + elementwise_scale(stream, num_tokens, this->hidden_size, this->output, this->embed_scale); + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/ffn.cuh b/examples/CPM.cu/src/model/ffn.cuh new file mode 100644 index 00000000..02512c68 --- /dev/null +++ b/examples/CPM.cu/src/model/ffn.cuh @@ -0,0 +1,91 @@ +#pragma once +#include "../trait.cuh" +#include "norm.cuh" +#include "linear.cuh" +#include "activation.cuh" +#include <cuda_runtime.h> + +template <typename T> +struct FFN { + T* output; + virtual void init_weight_ptr(Memory* memory) = 0; + virtual int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) = 0; + virtual void load_to_storage(std::string name, void* ptr) = 0; + virtual void prefill(const Stream& stream, int32_t num_tokens, T* input, T* prev_output) = 0; + virtual void decode(const Stream& stream, int32_t num_tokens, T* input, T* prev_output) = 0; +}; + +template <typename T> +struct GatedFFN : FFN<T> { + int hidden_size; + int intermediate_size; + float rms_norm_eps; + + Norm<T> *ffn_norm; + Linear<T> *gate_proj, *up_proj; + Linear<T> *down_proj; + + T* gated_up; + + GatedFFN(int hidden_size, int intermediate_size, float rms_norm_eps) { + this->hidden_size = hidden_size; + this->intermediate_size = intermediate_size; + this->rms_norm_eps = rms_norm_eps; + + this->ffn_norm = new RMSNorm<T>(hidden_size, rms_norm_eps); + this->gate_proj = new Linear<T>(hidden_size, intermediate_size); + this->up_proj = new Linear<T>(hidden_size, intermediate_size); + this->down_proj = new Linear<T>(intermediate_size, hidden_size); + } + + void init_weight_ptr(Memory* memory) { + this->ffn_norm->init_weight_ptr(memory); + this->gate_proj->init_weight_ptr(memory); + this->up_proj->init_weight_ptr(memory); + this->down_proj->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t ffn_norm_end = this->ffn_norm->init_output_ptr(memory, num_tokens, offset); + int64_t gate_proj_end = this->gate_proj->init_output_ptr(memory, num_tokens, ffn_norm_end); + int64_t up_proj_end = this->up_proj->init_output_ptr(memory, num_tokens, gate_proj_end); + int64_t gated_up_end = memory->allocate((void**)&this->gated_up, up_proj_end, num_tokens * intermediate_size * sizeof(T)); + int64_t down_proj_end = this->down_proj->init_output_ptr(memory, num_tokens, gated_up_end); + this->output = this->down_proj->output; + return down_proj_end; + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("gate_proj") != std::string::npos) { + this->gate_proj->load_to_storage(name, ptr); + } else if (name.find("up_proj") != std::string::npos) { + this->up_proj->load_to_storage(name, ptr); + } else if (name.find("down_proj") != std::string::npos) { + this->down_proj->load_to_storage(name, ptr); + } else if (name.find("post_attention_layernorm") != std::string::npos) { + this->ffn_norm->load_to_storage(name, ptr); + } else { + throw std::invalid_argument("Unsupported name " + name); + } + } + + void prefill(const Stream& stream, int32_t num_tokens, T* input, T* prev_output) { + this->ffn_norm->prefill(stream, num_tokens, input, prev_output); + +#ifdef DISABLE_MEMPOOL + this->gate_proj->prefill(stream, num_tokens, this->ffn_norm->output); + this->up_proj->prefill(stream, num_tokens, this->ffn_norm->output); + cudaMemcpy(this->gated_up, this->up_proj->output, num_tokens * this->intermediate_size * sizeof(T), cudaMemcpyDeviceToDevice); + gated_silu<T>(stream, num_tokens, this->intermediate_size, this->gate_proj->output, this->gated_up); +#else + linear<T>(stream, num_tokens, this->hidden_size, this->intermediate_size*2, this->ffn_norm->output, this->gate_proj->weight, this->gate_proj->output); + gated_silu_interleaved<T>(stream, num_tokens, this->intermediate_size, this->gate_proj->output, this->gated_up); +#endif + + this->down_proj->prefill(stream, num_tokens, this->gated_up); + } + + void decode(const Stream& stream, int32_t num_tokens, T* input, T* prev_output) { + prefill(stream, num_tokens, input, prev_output); + } +}; diff --git a/examples/CPM.cu/src/model/hier_spec_quant/hier_ea_w4a16_gm_rot_spec_w4a16_gm.cuh b/examples/CPM.cu/src/model/hier_spec_quant/hier_ea_w4a16_gm_rot_spec_w4a16_gm.cuh new file mode 100644 index 00000000..dcebb582 --- /dev/null +++ b/examples/CPM.cu/src/model/hier_spec_quant/hier_ea_w4a16_gm_rot_spec_w4a16_gm.cuh @@ -0,0 +1,706 @@ +#pragma once +#include "../w4a16_gptq_marlin/w4a16_gptq_marlin_model.cuh" +#include "../eagle.cuh" +#include "../drafter.cuh" +#include "../w4a16_gptq_marlin/w4a16_gptq_marlin_layer.cuh" + + +template <typename T> +struct HierEagleW4A16GMRotSpecW4A16GMImpl: Model { + + // eagle + int ea_num_layers; + int ea_num_iter; + int ea_topk_per_iter; + int ea_tree_size; + int ea_total_tried; + + KVCacheManager<T>* ea_kv_caches; + // new embedding + Embedding<T>* ea_embedding; + std::vector<Layer<T>*> ea_layers; + Linear<T> * ea_rms_norm_rotation; + Linear<T, true, true> *ea_fc1; + Linear<T> *ea_fc2; + Linear<T> *ea_lm_head; + functions::TopK<T>* ea_topk_func; + functions::TopK<T>* ea_topk_func_2; + + T *ea_prev_hidden_state, *ea_prev_embed; + int ea_num_prev, ea_num_history_tokens; + int32_t *eagle_position_ids, *eagle_cache_length; + int *eagle_original_length, eagle_padded_length; + uint64_t *eagle_mask_2d, *ea_tmp_mask_2d; + T* eagle_logits; + T* ea_tried_history_val; int32_t* ea_tried_history_pos; + int32_t* ea_tried_history_parent; + bool ea_is_first_draft; + + + int32_t *ea_h_best, *ea_d_best; + + T* ea_tmp_kvcache; + + int32_t* ea_tree_draft_ids, *ea_tree_position_ids, *ea_tree_cache_length, *ea_tree_parent; + uint64_t* ea_tree_attn_mask; + + // draft & target + + W4A16GPTQMarlinModelImpl<T>* draft_model; + W4A16GPTQMarlinModelImpl<T>* model; + + // draft args + int32_t *draft_input; + int32_t *draft_position_ids, *draft_cache_length; + int * host_draft_cache_length; + int draft_padded_length; + T* draft_logits; + bool is_first_draft; + functions::TopK<T>* topk_func; + int32_t *draft_tmp; + int32_t *h_best, *d_best; + int num_prev, num_history_tokens; + + // draft mask always nullptr + uint64_t* draft_mask_2d; + + // graph + bool draft_cuda_graph; + int draft_graphCreated_padding_length; + int draft_graphCreated_input_length; + cudaGraph_t draft_graph; + cudaGraphExec_t draft_graphExec; + + // cascade vars + int cur_draft_length; + int min_draft_length; + T * draft_tmp_hidden_state; + bool draft_model_start; // start from draft model for num_prev == 1 + + int32_t* ea_accept_nums; + int ea_accept_nums_size; + int cur_ea_accept_nums_size; + + HierEagleW4A16GMRotSpecW4A16GMImpl( + W4A16GPTQMarlinModelImpl<T>* model, + int draft_vocab_size, + int draft_num_hidden_layers, + int draft_hidden_size, + int draft_intermediate_size, + int draft_num_attention_heads, + int draft_num_key_value_heads, + int draft_head_dim, + float draft_rms_norm_eps, + int draft_group_size, + // int num_iter, + int min_draft_length, + bool draft_cuda_graph, + // eagle args + int ea_num_layers, + int ea_num_iter, + int ea_topk_per_iter, + int ea_tree_size, + bool draft_model_start + ) { + this->model = model; + this->draft_model = new W4A16GPTQMarlinModelImpl<T>( + 0, + draft_vocab_size, + draft_num_hidden_layers, + draft_hidden_size, + draft_intermediate_size, + draft_num_attention_heads, + draft_num_key_value_heads, + draft_head_dim, + draft_rms_norm_eps, + draft_group_size, + this->model->chunk_length + ); + + // draft config + this->draft_mask_2d = 0; + topk_func = new functions::TopK<T>(model->vocab_size, 1); // greedy sample + + this->draft_cuda_graph = draft_cuda_graph; + this->draft_graphCreated_padding_length = -1; + this->draft_graphCreated_input_length = -1; + this->draft_graph = nullptr; + this->draft_graphExec = nullptr; + + this->min_draft_length = min_draft_length; + this->draft_model_start = draft_model_start; + + // eagle config + this->ea_num_layers = ea_num_layers; + this->ea_num_iter = ea_num_iter; + this->ea_topk_per_iter = ea_topk_per_iter; + this->ea_tree_size = ea_tree_size; + this->ea_total_tried = ea_topk_per_iter * ea_topk_per_iter * (ea_num_iter-1) + ea_topk_per_iter; + + // ea model + ea_embedding = new Embedding<T>(this->draft_model->vocab_size, this->draft_model->hidden_size); + + ea_kv_caches = new KVCacheManager<T>(ea_num_layers, this->draft_model->num_key_value_heads, this->draft_model->head_dim); + ea_rms_norm_rotation = new Linear<T>(this->draft_model->hidden_size, this->draft_model->hidden_size); + ea_fc1 = new Linear<T, true, true>(this->draft_model->hidden_size, this->draft_model->hidden_size); + ea_fc2 = new Linear<T>(this->draft_model->hidden_size, this->draft_model->hidden_size); + for (int i = 0; i < ea_num_layers; i++) { + ea_layers.push_back(new Layer<T>(this->draft_model->hidden_size, this->draft_model->intermediate_size, this->draft_model->num_attention_heads, this->draft_model->num_key_value_heads, this->draft_model->head_dim, this->draft_model->rms_norm_eps)); + } + ea_lm_head = new Linear<T>(this->draft_model->hidden_size, this->draft_model->vocab_size); + + ea_topk_func = new functions::TopK<T>(this->draft_model->vocab_size, ea_topk_per_iter); + ea_topk_func_2 = new functions::TopK<T>(ea_total_tried, this->ea_tree_size-1); + + this->ea_accept_nums_size = 0; + this->cur_ea_accept_nums_size = 0; + + } + + void init_weight_ptr(Memory* memory) { + ea_embedding->init_weight_ptr(memory); + ea_rms_norm_rotation->init_weight_ptr(memory); + ea_fc1->init_weight_ptr(memory); + ea_fc2->init_weight_ptr(memory); + for (int i = 0; i < ea_num_layers; i++) { + ea_layers[i]->init_weight_ptr(memory); + } + ea_lm_head->init_weight_ptr(memory); + ea_layers[0]->attn->attn_norm = new Skip<T>(this->draft_model->hidden_size); + ea_kv_caches->rotary_embedding = this->draft_model->kv_caches->rotary_embedding; + } + + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + // init eagle output + offset = ea_embedding->init_output_ptr(memory, num_tokens, offset); + offset = ea_rms_norm_rotation->init_output_ptr(memory, num_tokens, offset); + offset = ea_fc1->init_output_ptr(memory, num_tokens, offset); + offset = ea_fc2->init_output_ptr(memory, num_tokens, offset); + int64_t layer_end = 0; + for (int i = 0; i < ea_num_layers; i++) { + layer_end = ea_layers[i]->init_output_ptr(memory, num_tokens, offset); + } + offset = ea_lm_head->init_output_ptr(memory, 64, layer_end); + offset = memory->allocate((void**)&eagle_logits, offset, this->ea_topk_per_iter * this->draft_model->vocab_size * sizeof(T)); + offset = memory->allocate((void**)&eagle_mask_2d, offset, this->ea_topk_per_iter * sizeof(uint64_t)); + offset = memory->allocate((void**)&ea_tmp_mask_2d, offset, this->ea_topk_per_iter * sizeof(uint64_t)); + offset = memory->allocate((void**)&ea_tried_history_val, offset, this->ea_total_tried * sizeof(T)); + offset = memory->allocate((void**)&ea_tried_history_pos, offset, this->ea_total_tried * sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tried_history_parent, offset, this->ea_topk_per_iter * (this->ea_num_iter - 1) * sizeof(int32_t)); + cudaMallocHost(&eagle_original_length, sizeof(int32_t)); + + offset = ea_topk_func->init_output_ptr(memory, this->ea_topk_per_iter, offset); + offset = ea_topk_func_2->init_output_ptr(memory, 1, offset); + + offset = memory->allocate((void**)&ea_prev_hidden_state, offset, num_tokens * this->draft_model->hidden_size * sizeof(T)); + offset = memory->allocate((void**)&ea_prev_embed, offset, num_tokens * this->draft_model->hidden_size * sizeof(T)); + offset = memory->allocate((void**)&eagle_position_ids, offset, num_tokens * sizeof(int32_t)); + offset = memory->allocate((void**)&eagle_cache_length, offset, sizeof(int32_t)); + + offset = memory->allocate((void**)&ea_d_best, offset, 2 * sizeof(int32_t)); + cudaMallocHost(&ea_h_best, 2 * sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tmp_kvcache, offset, 64 * this->draft_model->kv_caches->num_hidden_layers * 2 * this->draft_model->kv_caches->dim * sizeof(T)); + + // to allocate ealge draft some states + offset = memory->allocate((void**)&ea_tree_draft_ids, offset, this->ea_tree_size * sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tree_position_ids, offset, this->ea_tree_size * sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tree_cache_length, offset, sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tree_parent, offset, this->ea_tree_size * sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tree_attn_mask, offset, this->ea_tree_size * sizeof(uint64_t)); + + + + // init draft output + int64_t lm_head_end = this->draft_model->init_output_ptr(memory, num_tokens, offset); + offset = lm_head_end; + + offset = memory->allocate((void**)&draft_input, offset, num_tokens * sizeof(int32_t)); + offset = memory->allocate((void**)&draft_position_ids, offset, num_tokens * sizeof(int32_t)); + offset = memory->allocate((void**)&draft_cache_length, offset, sizeof(int32_t)); + cudaMallocHost(&host_draft_cache_length, sizeof(int32_t)); + + + offset = memory->allocate((void**)&draft_logits, offset, 64 * this->draft_model->vocab_size * sizeof(T)); + offset = topk_func->init_output_ptr(memory, 64, offset); + + offset = memory->allocate((void**)&draft_tmp, offset, (this->min_draft_length + ea_num_iter + 1)*sizeof(int32_t)); + offset = memory->allocate((void**)&d_best, offset, sizeof(int32_t)); + cudaMallocHost(&h_best, sizeof(int32_t)); + + // cascade vars + offset = memory->allocate((void**)&draft_tmp_hidden_state, offset, (this->min_draft_length + ea_num_iter + 1) * this->draft_model->hidden_size * sizeof(T)); + // cudaMallocHost(&host_ea_accept_nums, 1024 * sizeof(int)); + offset = memory->allocate((void**)&ea_accept_nums, offset, 1024 * sizeof(int32_t)); + return offset; + } + + int init_storage() { + + this->model->init_weight_ptr(this->model->memory); + // this->init_weight_ptr(this->model->memory); + this->draft_model->init_weight_ptr(this->model->memory); + this->init_weight_ptr(this->model->memory); + + int64_t offset = this->model->init_output_ptr(this->model->memory, this->model->chunk_length, this->model->memory->model_offset); + int64_t kv_cache_offset = init_output_ptr(this->model->memory, this->model->chunk_length, offset); + + int model_kv_size = (this->model->num_hidden_layers*this->model->num_key_value_heads*this->model->head_dim); + int draft_kv_size = (this->draft_model->num_hidden_layers*this->draft_model->num_key_value_heads*this->draft_model->head_dim); + int ea_kv_size = this->ea_num_layers * this->draft_model->num_key_value_heads * this->draft_model->head_dim; + float ratio = float(model_kv_size)/float(model_kv_size + draft_kv_size + ea_kv_size); + kv_cache_offset = this->model->kv_caches->init_output_ptr(this->model->memory, kv_cache_offset, ratio); + ratio = float(draft_kv_size)/float(draft_kv_size + ea_kv_size); + kv_cache_offset = this->draft_model->kv_caches->init_output_ptr(this->model->memory, kv_cache_offset, ratio); + this->ea_kv_caches->init_output_ptr(this->model->memory, kv_cache_offset, 1.0); + return min(min(this->draft_model->kv_caches->budget, this->model->kv_caches->budget), this->ea_kv_caches->budget + 1); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.substr(0, 5) == "eagle") { + if (name.substr(0, 23) == "eagle.rms_norm_rotation") { + ea_rms_norm_rotation->load_to_storage(name, ptr); + } else if (name.substr(0, 18) == "eagle.embed_tokens") { + ea_embedding->load_to_storage(name, ptr); + } else if (name.substr(0, 9) == "eagle.fc1") { + ea_fc1->load_to_storage(name, ptr); + } else if (name.substr(0, 9) == "eagle.fc2") { + ea_fc2->load_to_storage(name, ptr); + } else if (name.substr(0, 13) == "eagle.lm_head") { + ea_lm_head->load_to_storage(name, ptr); + } else { + std::regex layer_regex("eagle\\.layers\\.(\\d+)\\.(.*)"); + std::smatch matches; + if (std::regex_search(name, matches, layer_regex)) { + int layer_idx = std::stoi(matches[1]); + ea_layers[layer_idx]->load_to_storage(matches[2], ptr); + } else { + throw std::invalid_argument("Unsupported name (layer_idx not found): " + name); + } + } + } else if (name.substr(0, 5) == "draft"){ + std::string draft_name = name.substr(6); + this->draft_model->load_to_storage(draft_name, ptr); + } else { + this->model->load_to_storage(name, ptr); + } + } + + + + void eagle_prefill(int num_history_tokens) { + cudaMemcpy(this->ea_prev_embed + (ea_num_prev - 1) * this->draft_model->hidden_size, this->ea_embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_fc1->prefill(calc_stream, ea_num_prev, this->ea_prev_embed); + this->ea_rms_norm_rotation->prefill(calc_stream, ea_num_prev, this->ea_prev_hidden_state); + this->ea_fc2->prefill(calc_stream, ea_num_prev, this->ea_rms_norm_rotation->output); + elementwise_add(calc_stream, ea_num_prev, this->draft_model->hidden_size, this->ea_fc1->output, this->ea_fc2->output, this->ea_fc2->output); + T* layer_output = nullptr; + + for (int i = 0; i < ea_num_layers; i++) { + this->ea_layers[i]->prefill(num_prev, num_history_tokens, this->ea_fc2->output, layer_output, this->eagle_position_ids, this->ea_kv_caches->caches[i]); + layer_output = this->ea_layers[i]->output; + } + elementwise_add(calc_stream, ea_num_prev, this->draft_model->hidden_size, this->ea_fc2->output, layer_output, this->ea_fc2->output); + } + + void eagle_decode(int32_t* cache_length) { + + this->ea_fc1->prefill(calc_stream, ea_num_prev, this->ea_prev_embed); + this->ea_rms_norm_rotation->prefill(calc_stream, ea_num_prev, this->ea_prev_hidden_state); + this->ea_fc2->prefill(calc_stream, ea_num_prev, this->ea_rms_norm_rotation->output); + elementwise_add(calc_stream, ea_num_prev, this->draft_model->hidden_size, this->ea_fc1->output, this->ea_fc2->output, this->ea_fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < ea_num_layers; i++) { + this->ea_layers[i]->decode(ea_num_prev, this->eagle_padded_length, this->ea_fc2->output, layer_output, this->eagle_position_ids, cache_length, Mask(nullptr), this->ea_kv_caches->caches[i]); + layer_output = this->ea_layers[i]->output; + } + elementwise_add(calc_stream, ea_num_prev, this->draft_model->hidden_size, this->ea_fc2->output, layer_output, this->ea_fc2->output); + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, int32_t* input, int32_t* position_ids, void* output) { + this->model->prefill(num_tokens, num_history_tokens, input, position_ids, output); + if (num_history_tokens > 0) { + this->draft_model->embedding->prefill(calc_stream, this->num_prev, this->draft_input); + + // new embedding + this->ea_embedding->prefill(calc_stream, this->num_prev, this->draft_input); + cudaMemcpy(this->ea_prev_embed, this->ea_embedding->output+this->draft_model->hidden_size, (this->num_prev-1) * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + this->draft_model->prefill_embed(this->num_prev, this->num_history_tokens, this->draft_model->embedding->output, this->draft_position_ids, (void*)this->draft_logits); + + cudaMemcpy(this->ea_prev_hidden_state, this->draft_model->norm->output, this->num_prev * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + this->ea_embedding->prefill(calc_stream, 1, input); + this->eagle_prefill(this->ea_num_history_tokens); + + // this->draft_model-> + } + + cudaMemcpy(this->draft_input, input, num_tokens * sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_position_ids, position_ids, num_tokens * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->num_prev = num_tokens; + this->num_history_tokens = num_history_tokens; + this->is_first_draft = true; + + // eagle + cudaMemcpy(this->eagle_position_ids, this->draft_position_ids, num_tokens * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->ea_num_prev = num_tokens; + this->ea_num_history_tokens = num_history_tokens; + this->ea_is_first_draft = true; + + + this->ea_accept_nums_size = 0; + cudaMemcpy(this->ea_accept_nums, &this->ea_accept_nums_size, sizeof(int), cudaMemcpyHostToDevice); + } + + + void draft_decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + throw std::runtime_error("Draft decode is not supported"); + } + + void draft_decode_with_graph_control(int32_t d_num_tokens, int32_t d_padded_length, int32_t* d_input, int32_t* d_position_ids, int32_t* d_cache_length, uint64_t* d_mask_2d, void* d_output) { + if (this->draft_cuda_graph) { + if (this->draft_graphCreated_padding_length != d_padded_length || this->draft_graphCreated_input_length != d_num_tokens) { + if (this->draft_graphExec != nullptr) { + cudaGraphExecDestroy(this->draft_graphExec); + this->draft_graphExec = nullptr; + } + if (this->draft_graph != nullptr) { + cudaGraphDestroy(this->draft_graph); + this->draft_graph = nullptr; + } + cudaStreamBeginCapture(calc_stream.stream, cudaStreamCaptureModeGlobal); + this->draft_model->decode(d_num_tokens, d_padded_length, d_input, d_position_ids, d_cache_length, d_mask_2d, d_output); + cudaStreamEndCapture(calc_stream.stream, &(this->draft_graph)); + cudaGraphInstantiate(&(this->draft_graphExec), this->draft_graph, nullptr, nullptr, 0); + this->draft_graphCreated_padding_length = d_padded_length; + this->draft_graphCreated_input_length = d_num_tokens; + } + cudaGraphLaunch(this->draft_graphExec, calc_stream.stream); + } else { + this->draft_model->decode(d_num_tokens, d_padded_length, d_input, d_position_ids, d_cache_length, d_mask_2d, d_output); + } + } + + void decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + this->model->decode(num_tokens, padded_length, input, position_ids, cache_length, nullptr, output); + } + + + void draft_with_eagle(int32_t* ea_tree_draft_ids, int32_t* ea_tree_position_ids, int32_t* ea_cache_length, uint64_t* ea_tree_attn_mask, int32_t* ea_tree_parent) { + cudaMemcpy(this->eagle_original_length, ea_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->eagle_padded_length = (this->eagle_original_length[0] + 256 - 1) / 128 * 128; + if (this->ea_is_first_draft) { + // prefill hidden states and embedding have been cpy + this->ea_embedding->prefill(calc_stream, 1, ea_tree_draft_ids); + this->eagle_prefill(this->ea_num_history_tokens); + } else { + this->eagle_decode(ea_cache_length); + } + + cudaMemcpy(this->eagle_cache_length, ea_cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->eagle_position_ids, ea_cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + repeat(calc_stream, ea_topk_per_iter, 1, 0, this->eagle_position_ids); + + { // d = 0 + this->ea_lm_head->prefill(calc_stream, 1, this->ea_fc2->output + (ea_num_prev - 1) * this->draft_model->hidden_size, this->eagle_logits); + log_softmax(calc_stream, 1, this->draft_model->vocab_size, this->eagle_logits); + this->ea_topk_func->prefill(calc_stream, 1, this->eagle_logits); + cudaMemcpy(this->ea_topk_func_2->topk_pos, this->ea_topk_func->topk_pos, ea_topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_topk_func_2->topk_val, this->ea_topk_func->topk_val, ea_topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tried_history_val, this->ea_topk_func->topk_val, ea_topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tried_history_pos, this->ea_topk_func->topk_pos, ea_topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + repeat(calc_stream, ea_topk_per_iter, this->draft_model->hidden_size, ea_num_prev-1, this->ea_fc2->output, this->ea_fc1->output); + init_tree(calc_stream, ea_topk_per_iter, this->eagle_mask_2d); + } + + for (int d = 1; d < this->ea_num_iter; ++d) { + add(calc_stream, 1, this->eagle_cache_length, ea_topk_per_iter); + this->ea_embedding->prefill(calc_stream, ea_topk_per_iter, this->ea_topk_func_2->topk_pos); + this->ea_fc2->prefill(calc_stream, ea_topk_per_iter, this->ea_fc1->output); + this->ea_fc1->prefill(calc_stream, ea_topk_per_iter, this->ea_embedding->output); + elementwise_add(calc_stream, ea_topk_per_iter, this->draft_model->hidden_size, this->ea_fc1->output, this->ea_fc2->output, this->ea_fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < ea_num_layers; i++) { + this->ea_layers[i]->decode(ea_topk_per_iter, this->eagle_padded_length, this->ea_fc2->output, layer_output, this->eagle_position_ids, this->eagle_cache_length, Mask(eagle_mask_2d, ea_topk_per_iter, ea_topk_per_iter * d), this->ea_kv_caches->caches[i]); + layer_output = this->ea_layers[i]->output; + } + elementwise_add(calc_stream, ea_topk_per_iter, this->draft_model->hidden_size, this->ea_fc2->output, layer_output, this->ea_fc2->output); + add(calc_stream, ea_topk_per_iter, this->eagle_position_ids, 1); + + this->ea_lm_head->prefill(calc_stream, ea_topk_per_iter, this->ea_fc2->output, this->eagle_logits); + log_softmax(calc_stream, ea_topk_per_iter, this->draft_model->vocab_size, this->eagle_logits); + this->ea_topk_func->prefill(calc_stream, ea_topk_per_iter, this->eagle_logits); + cumsum(calc_stream, ea_topk_per_iter, ea_topk_per_iter, this->ea_topk_func->topk_val, this->ea_topk_func_2->topk_val); + cudaMemcpy(this->ea_tried_history_val + ea_topk_per_iter + (d - 1) * ea_topk_per_iter * ea_topk_per_iter, this->ea_topk_func->topk_val, ea_topk_per_iter * ea_topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tried_history_pos + ea_topk_per_iter + (d - 1) * ea_topk_per_iter * ea_topk_per_iter, this->ea_topk_func->topk_pos, ea_topk_per_iter * ea_topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->ea_topk_func_2->prefill(calc_stream, 1, this->ea_topk_func->topk_val, ea_topk_per_iter * ea_topk_per_iter, ea_topk_per_iter); + + cudaMemcpy(this->ea_tmp_mask_2d, this->eagle_mask_2d, ea_topk_per_iter * sizeof(uint64_t), cudaMemcpyDeviceToDevice); + set_parent(calc_stream, ea_topk_per_iter, this->ea_tried_history_parent + (d - 1) * ea_topk_per_iter, this->ea_topk_func_2->topk_pos, 10 + (d - 1) * ea_topk_per_iter * ea_topk_per_iter); + update_tree(calc_stream, ea_topk_per_iter, ea_topk_per_iter * d, this->eagle_mask_2d, this->ea_tmp_mask_2d, this->ea_topk_func_2->topk_pos); + remap_hidden(calc_stream, ea_topk_per_iter, this->draft_model->hidden_size, this->ea_topk_func_2->topk_pos, this->ea_fc2->output, this->ea_fc1->output, ea_topk_per_iter); + remap_id(calc_stream, ea_topk_per_iter, this->ea_topk_func_2->topk_pos, this->ea_topk_func->topk_pos); + } + + this->ea_topk_func_2->prefill(calc_stream, 1, this->ea_tried_history_val); + + // build tree + build_dynamic_tree(calc_stream, this->ea_tree_size, this->eagle_original_length[0], this->ea_topk_per_iter, this->ea_tried_history_parent, this->ea_topk_func_2->topk_pos, ea_tree_position_ids, ea_tree_attn_mask, ea_tree_parent); + remap_id(calc_stream, this->ea_tree_size-1, this->ea_topk_func_2->topk_pos, this->ea_tried_history_pos, ea_tree_draft_ids + 1); + + this->ea_is_first_draft = false; + } + + void draft(int32_t *tree_draft_ids, int32_t *tree_position_ids, int32_t *cache_length, uint64_t*, int32_t*) { + // reset cur draft length + this->cur_draft_length = 0; + this->cur_ea_accept_nums_size = 0; + if (this->is_first_draft) { + // append tree draft ids to draft input + cudaMemcpy(this->draft_input+this->num_prev, tree_draft_ids, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_position_ids+this->num_prev, tree_position_ids, sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->num_prev += 1; + + this->draft_model->embedding->prefill(calc_stream, this->num_prev, this->draft_input); + + // new embedding + this->ea_embedding->prefill(calc_stream, this->num_prev, this->draft_input); + cudaMemcpy(this->ea_prev_embed, this->ea_embedding->output+ this->draft_model->hidden_size, (this->num_prev-1) * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + this->draft_model->prefill_embed(this->num_prev, this->num_history_tokens, this->draft_model->embedding->output, this->draft_position_ids, (void*)this->draft_logits); + + // eagle prepare for draft_with_eagle function + // ea_is_first_draft is True + cudaMemcpy(this->eagle_position_ids + (this->ea_num_prev), tree_position_ids, sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->ea_num_prev = this->num_prev; + + cudaMemcpy(this->ea_prev_hidden_state, this->draft_model->norm->output, this->num_prev * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->topk_func->prefill(calc_stream, 1, this->draft_logits); + + + // prepare for draft with eagle + cudaMemcpy(this->ea_tree_draft_ids, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tree_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->ea_tree_cache_length, 1); + + cudaMemcpy(this->draft_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_cache_length, 1); + cudaMemcpy(this->host_draft_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + + // draft model has forward one time + cudaMemcpy(this->draft_tmp, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_tmp_hidden_state, this->draft_model->norm->output + (this->num_prev-1) * this->draft_model->hidden_size, this->draft_model->hidden_size*sizeof(T), cudaMemcpyDeviceToDevice); + this->cur_draft_length += 1; + + } else if (this->num_prev == 2){ + this->draft_model->embedding->prefill(calc_stream, this->num_prev, this->draft_input); + + // new embedding + this->ea_embedding->prefill(calc_stream, this->num_prev, this->draft_input); + cudaMemcpy(this->ea_prev_embed + (ea_num_prev)* this->draft_model->hidden_size, this->ea_embedding->output + this->draft_model->hidden_size, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + this->draft_model->decode_embed(this->num_prev, this->draft_padded_length, this->draft_model->embedding->output, this->draft_position_ids, this->draft_cache_length, nullptr, (void*)this->draft_logits); + + + this->topk_func->prefill(calc_stream, 1, this->draft_logits+(this->draft_model->vocab_size)); + + + // prepare for the eagle input + cudaMemcpy(this->ea_prev_hidden_state + (ea_num_prev)* this->draft_model->hidden_size, this->draft_model->norm->output, this->num_prev * this->draft_model->hidden_size*sizeof(T), cudaMemcpyDeviceToDevice); + + // new embeddding + // this->draft_model->embedding->prefill(calc_stream, 1, this->topk_func->topk_pos); + // cudaMemcpy(this->ea_prev_embed + (ea_num_prev+1)* this->draft_model->hidden_size, this->draft_model->embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_embedding->prefill(calc_stream, 1, this->topk_func->topk_pos); + cudaMemcpy(this->ea_prev_embed + (ea_num_prev+1)* this->draft_model->hidden_size, this->ea_embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_num_prev += this->num_prev; + + + // prepare for draft with eagle + cudaMemcpy(this->ea_tree_draft_ids, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tree_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + + // draft model has forward one time + cudaMemcpy(this->draft_tmp, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_tmp_hidden_state, this->draft_model->norm->output + (this->num_prev-1) * this->draft_model->hidden_size, this->draft_model->hidden_size*sizeof(T), cudaMemcpyDeviceToDevice); + this->cur_draft_length += 1; + } else if (this->draft_model_start) { + // num_prev == 1 + this->draft_model->decode(this->num_prev, this->draft_padded_length, this->draft_input, this->draft_position_ids, this->draft_cache_length, nullptr, (void*)this->draft_logits); + this->topk_func->prefill(calc_stream, 1, this->draft_logits); + + // prepare for the eagle input + cudaMemcpy(this->ea_prev_hidden_state + (ea_num_prev)*this->draft_model->hidden_size, this->draft_model->norm->output, this->num_prev * this->draft_model->hidden_size*sizeof(T), cudaMemcpyDeviceToDevice); + + // new embeddding + // this->draft_model->embedding->prefill(calc_stream, 1, this->topk_func->topk_pos); + // cudaMemcpy(this->ea_prev_embed + (ea_num_prev)* this->draft_model->hidden_size, this->draft_model->embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_embedding->prefill(calc_stream, 1, this->topk_func->topk_pos); + cudaMemcpy(this->ea_prev_embed + (ea_num_prev)* this->draft_model->hidden_size, this->ea_embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_num_prev += this->num_prev; + + // prepare for draft with eagle + cudaMemcpy(this->ea_tree_draft_ids, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tree_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + + + // draft model has forward one time + cudaMemcpy(this->draft_tmp, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_tmp_hidden_state, this->draft_model->norm->output + (this->num_prev-1) * this->draft_model->hidden_size, this->draft_model->hidden_size*sizeof(T), cudaMemcpyDeviceToDevice); + this->cur_draft_length += 1; + + } else { + cudaMemcpy(this->ea_tree_draft_ids, tree_draft_ids, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tree_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + } + + + + while (this->cur_draft_length < min_draft_length){ + + // eagle draft + this->draft_with_eagle( + this->ea_tree_draft_ids, + this->ea_tree_position_ids, + this->ea_tree_cache_length, + this->ea_tree_attn_mask, + this->ea_tree_parent + ); + + + cudaMemcpy(this->draft_cache_length, this->ea_tree_cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_cache_length, this->ea_tree_size); + cudaMemcpy(this->host_draft_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->draft_padded_length = (this->host_draft_cache_length[0]+ 128 -1) / 128*128;; + + this->draft_decode_with_graph_control( + this->ea_tree_size, + this->draft_padded_length, + this->ea_tree_draft_ids, + this->ea_tree_position_ids, + this->draft_cache_length, + this->ea_tree_attn_mask, + (void*) this->draft_logits + ); + this->topk_func->prefill(calc_stream, this->ea_tree_size, this->draft_logits); + + + this->draft_verify( + this->ea_tree_size, + this->ea_tree_draft_ids, + this->topk_func->topk_pos, + this->ea_tree_position_ids, + this->ea_tree_cache_length, + this->ea_tree_attn_mask, + this->ea_tree_parent + ); + + cudaMemcpy(this->ea_accept_nums + (this->ea_accept_nums_size+this->cur_ea_accept_nums_size+1), this->ea_d_best, sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->cur_ea_accept_nums_size += 1; + + + // accept return to ea_h_best[0] + cudaMemcpy(this->ea_tree_draft_ids, this->ea_tree_draft_ids + (this->ea_h_best[0]-1), sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->ea_tree_cache_length, this->ea_h_best[0]); + + } + + + + cudaMemcpy(tree_draft_ids + 1, this->draft_tmp, this->cur_draft_length*sizeof(int32_t), cudaMemcpyDeviceToDevice); + make_arange(calc_stream, this->cur_draft_length+1, cache_length, tree_position_ids); + this->is_first_draft = false; + } + + int draft_verify(int32_t ea_num_tokens, int32_t* ea_pred, int32_t* ea_gt, int32_t* ea_position_ids, int32_t* ea_cache_length, uint64_t* ea_mask_2d, int32_t* ea_tree_parent) { + verify_draft(calc_stream, ea_num_tokens, ea_pred, ea_gt, ea_position_ids, ea_cache_length, ea_mask_2d, ea_tree_parent, this->ea_d_best); + + cudaMemcpyAsync(this->ea_h_best, this->ea_d_best, 2 * sizeof(int32_t), cudaMemcpyDeviceToHost, calc_stream.stream); + cudaStreamSynchronize(calc_stream.stream); + + this->ea_num_prev = ea_h_best[0]; + remap_hidden(calc_stream, this->ea_num_prev, this->draft_model->hidden_size, ea_pred, this->draft_model->norm->output, this->ea_prev_hidden_state); + + fix_kv_cache(calc_stream, ea_h_best[0], this->draft_model->kv_caches->num_hidden_layers * 2, this->draft_model->kv_caches->dim, ea_pred, ea_gt, ea_cache_length, this->draft_model->kv_caches->d_flat_caches, this->ea_tmp_kvcache); + + this->ea_embedding->prefill(calc_stream, this->ea_num_prev, ea_pred); + cudaMemcpy(this->ea_prev_embed, this->ea_embedding->output, this->ea_num_prev * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + make_arange(calc_stream, this->ea_num_prev, ea_cache_length, this->eagle_position_ids); + + cudaMemcpy(this->draft_tmp_hidden_state + (this->cur_draft_length*this->draft_model->hidden_size), this->ea_prev_hidden_state, this->ea_num_prev * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + cudaMemcpy(this->draft_tmp + this->cur_draft_length, ea_pred, this->ea_num_prev * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->cur_draft_length += this->ea_num_prev; + + return ea_h_best[0]; + } + + int verify(int32_t num_tokens, int32_t* pred, int32_t* gt, int32_t* position_ids, int32_t* cache_length, uint64_t* attn_mask, int32_t* tree_parent) { + verify_seq_draft(calc_stream, num_tokens, pred, gt, (uint16_t*)attn_mask, this->d_best); + cudaMemcpyAsync(this->h_best, this->d_best, 1 * sizeof(int32_t), cudaMemcpyDeviceToHost, calc_stream.stream); + cudaStreamSynchronize(calc_stream.stream); + + if (h_best[0]>(this->cur_draft_length+1)) { + h_best[0] = this->cur_draft_length+1; + } + + if (h_best[0]==(this->cur_draft_length+1)) { + // full accept + this->num_prev = 2; + cudaMemcpy(this->draft_input, gt + (this->cur_draft_length-1), 2*sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_cache_length, this->h_best[0]+1); + make_arange(calc_stream, 2, cache_length, this->draft_position_ids); + add(calc_stream, 2, this->draft_position_ids, this->cur_draft_length); + cudaMemcpy(this->host_draft_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->draft_padded_length = (this->host_draft_cache_length[0]+ 128 -1) / 128*128; + } else { + this->num_prev = 1; + cudaMemcpy(this->draft_input, gt + (this->h_best[0]-1), sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_cache_length, this->h_best[0]+1); + cudaMemcpy(this->draft_position_ids, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_position_ids, this->h_best[0]); + cudaMemcpy(this->host_draft_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->draft_padded_length = (this->host_draft_cache_length[0]+ 128 -1) / 128*128; + + // adapt eagle draft ptr + // conidtion 1: eagle last start postion is larger than accept position + if (host_draft_cache_length[0] > this->eagle_original_length[0] + 1) { + this->ea_num_prev = host_draft_cache_length[0] - (this->eagle_original_length[0] + 1); + // keep ea_prev_hidden_state and update ea_prev_embed + // new embedding + // this->draft_model->embedding->prefill(calc_stream, 1, this->draft_input); + // cudaMemcpy(this->ea_prev_embed+(this->ea_num_prev-1), this->draft_model->embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_embedding->prefill(calc_stream, 1, this->draft_input); + cudaMemcpy(this->ea_prev_embed+(this->ea_num_prev-1), this->ea_embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + } else { + // condition 2: eagle last start position is less than accept position + // index from the kepted draft model hidden state + cudaMemcpy(this->ea_prev_hidden_state, this->draft_tmp_hidden_state + (h_best[0]-1) *this->draft_model->hidden_size, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_num_prev = 1; + // new embedding + // this->draft_model->embedding->prefill(calc_stream, 1, this->draft_input); + // cudaMemcpy(this->ea_prev_embed, this->draft_model->embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_embedding->prefill(calc_stream, 1, this->draft_input); + cudaMemcpy(this->ea_prev_embed, this->ea_embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->eagle_position_ids, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->eagle_position_ids, this->h_best[0]-1); + } + + + } + + + this->ea_accept_nums_size += this->cur_ea_accept_nums_size; + add(calc_stream, 1, this->ea_accept_nums, this->cur_ea_accept_nums_size); + cudaMemcpy(tree_parent, this->ea_accept_nums, (this->ea_accept_nums_size+1) * sizeof(int32_t), cudaMemcpyDeviceToDevice); + + + + return h_best[0]; + + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/hier_spec_quant/hier_ea_w4a16_gm_spec_w4a16_gm.cuh b/examples/CPM.cu/src/model/hier_spec_quant/hier_ea_w4a16_gm_spec_w4a16_gm.cuh new file mode 100644 index 00000000..6a1ed7b8 --- /dev/null +++ b/examples/CPM.cu/src/model/hier_spec_quant/hier_ea_w4a16_gm_spec_w4a16_gm.cuh @@ -0,0 +1,662 @@ +#pragma once +#include "../w4a16_gptq_marlin/w4a16_gptq_marlin_model.cuh" +#include "../eagle.cuh" +#include "../drafter.cuh" +#include "../w4a16_gptq_marlin/w4a16_gptq_marlin_layer.cuh" + + +template <typename T> +struct HierEagleW4A16GMSpecW4A16GMImpl: Model { + + // eagle + int ea_num_layers; + int ea_num_iter; + int ea_topk_per_iter; + int ea_tree_size; + int ea_total_tried; + + KVCacheManager<T>* ea_kv_caches; + std::vector<Layer<T>*> ea_layers; + Linear<T, true, true> *ea_fc1; + Linear<T> *ea_fc2; + functions::TopK<T>* ea_topk_func; + functions::TopK<T>* ea_topk_func_2; + + T *ea_prev_hidden_state, *ea_prev_embed; + int ea_num_prev, ea_num_history_tokens; + int32_t *eagle_position_ids, *eagle_cache_length; + int *eagle_original_length, eagle_padded_length; + uint64_t *eagle_mask_2d, *ea_tmp_mask_2d; + T* eagle_logits; + T* ea_tried_history_val; int32_t* ea_tried_history_pos; + int32_t* ea_tried_history_parent; + bool ea_is_first_draft; + + + int32_t *ea_h_best, *ea_d_best; + + T* ea_tmp_kvcache; + + int32_t* ea_tree_draft_ids, *ea_tree_position_ids, *ea_tree_cache_length, *ea_tree_parent; + uint64_t* ea_tree_attn_mask; + + // draft & target + + W4A16GPTQMarlinModelImpl<T>* draft_model; + W4A16GPTQMarlinModelImpl<T>* model; + + // draft args + int32_t *draft_input; + int32_t *draft_position_ids, *draft_cache_length; + int * host_draft_cache_length; + int draft_padded_length; + T* draft_logits; + bool is_first_draft; + functions::TopK<T>* topk_func; + int32_t *draft_tmp; + int32_t *h_best, *d_best; + int num_prev, num_history_tokens; + + // draft mask always nullptr + uint64_t* draft_mask_2d; + + // graph + bool draft_cuda_graph; + int draft_graphCreated_padding_length; + int draft_graphCreated_input_length; + cudaGraph_t draft_graph; + cudaGraphExec_t draft_graphExec; + + // cascade vars + int cur_draft_length; + int min_draft_length; + T * draft_tmp_hidden_state; + bool draft_model_start; // start from draft model for num_prev == 1 + + int32_t* ea_accept_nums; + int ea_accept_nums_size; + int cur_ea_accept_nums_size; + + HierEagleW4A16GMSpecW4A16GMImpl( + W4A16GPTQMarlinModelImpl<T>* model, + int draft_vocab_size, + int draft_num_hidden_layers, + int draft_hidden_size, + int draft_intermediate_size, + int draft_num_attention_heads, + int draft_num_key_value_heads, + int draft_head_dim, + float draft_rms_norm_eps, + int draft_group_size, + // int num_iter, + int min_draft_length, + bool draft_cuda_graph, + // eagle args + int ea_num_layers, + int ea_num_iter, + int ea_topk_per_iter, + int ea_tree_size, + bool draft_model_start + ) { + this->model = model; + this->draft_model = new W4A16GPTQMarlinModelImpl<T>( + 0, + draft_vocab_size, + draft_num_hidden_layers, + draft_hidden_size, + draft_intermediate_size, + draft_num_attention_heads, + draft_num_key_value_heads, + draft_head_dim, + draft_rms_norm_eps, + draft_group_size, + this->model->chunk_length + ); + + // draft config + this->draft_mask_2d = 0; + topk_func = new functions::TopK<T>(model->vocab_size, 1); // greedy sample + + this->draft_cuda_graph = draft_cuda_graph; + this->draft_graphCreated_padding_length = -1; + this->draft_graphCreated_input_length = -1; + this->draft_graph = nullptr; + this->draft_graphExec = nullptr; + + this->min_draft_length = min_draft_length; + this->draft_model_start = draft_model_start; + + // eagle config + this->ea_num_layers = ea_num_layers; + this->ea_num_iter = ea_num_iter; + this->ea_topk_per_iter = ea_topk_per_iter; + this->ea_tree_size = ea_tree_size; + this->ea_total_tried = ea_topk_per_iter * ea_topk_per_iter * (ea_num_iter-1) + ea_topk_per_iter; + + // ea model + ea_kv_caches = new KVCacheManager<T>(ea_num_layers, this->draft_model->num_key_value_heads, this->draft_model->head_dim); + ea_fc1 = new Linear<T, true, true>(this->draft_model->hidden_size, this->draft_model->hidden_size); + ea_fc2 = new Linear<T>(this->draft_model->hidden_size, this->draft_model->hidden_size); + for (int i = 0; i < ea_num_layers; i++) { + ea_layers.push_back(new Layer<T>(this->draft_model->hidden_size, this->draft_model->intermediate_size, this->draft_model->num_attention_heads, this->draft_model->num_key_value_heads, this->draft_model->head_dim, this->draft_model->rms_norm_eps)); + } + + ea_topk_func = new functions::TopK<T>(this->draft_model->vocab_size, ea_topk_per_iter); + ea_topk_func_2 = new functions::TopK<T>(ea_total_tried, this->ea_tree_size-1); + + this->ea_accept_nums_size = 0; + this->cur_ea_accept_nums_size = 0; + + } + + void init_weight_ptr(Memory* memory) { + ea_fc1->init_weight_ptr(memory); + ea_fc2->init_weight_ptr(memory); + for (int i = 0; i < ea_num_layers; i++) { + ea_layers[i]->init_weight_ptr(memory); + } + ea_layers[0]->attn->attn_norm = new Skip<T>(this->draft_model->hidden_size); + ea_kv_caches->rotary_embedding = this->draft_model->kv_caches->rotary_embedding; + } + + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + // init eagle output + offset = ea_fc1->init_output_ptr(memory, num_tokens, offset); + offset = ea_fc2->init_output_ptr(memory, num_tokens, offset); + int64_t layer_end = 0; + for (int i = 0; i < ea_num_layers; i++) { + layer_end = ea_layers[i]->init_output_ptr(memory, num_tokens, offset); + } + offset = layer_end; + offset = memory->allocate((void**)&eagle_logits, offset, this->ea_topk_per_iter * this->draft_model->vocab_size * sizeof(T)); + offset = memory->allocate((void**)&eagle_mask_2d, offset, this->ea_topk_per_iter * sizeof(uint64_t)); + offset = memory->allocate((void**)&ea_tmp_mask_2d, offset, this->ea_topk_per_iter * sizeof(uint64_t)); + offset = memory->allocate((void**)&ea_tried_history_val, offset, this->ea_total_tried * sizeof(T)); + offset = memory->allocate((void**)&ea_tried_history_pos, offset, this->ea_total_tried * sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tried_history_parent, offset, this->ea_topk_per_iter * (this->ea_num_iter - 1) * sizeof(int32_t)); + cudaMallocHost(&eagle_original_length, sizeof(int32_t)); + + offset = ea_topk_func->init_output_ptr(memory, this->ea_topk_per_iter, offset); + offset = ea_topk_func_2->init_output_ptr(memory, 1, offset); + + offset = memory->allocate((void**)&ea_prev_hidden_state, offset, num_tokens * this->draft_model->hidden_size * sizeof(T)); + offset = memory->allocate((void**)&ea_prev_embed, offset, num_tokens * this->draft_model->hidden_size * sizeof(T)); + offset = memory->allocate((void**)&eagle_position_ids, offset, num_tokens * sizeof(int32_t)); + offset = memory->allocate((void**)&eagle_cache_length, offset, sizeof(int32_t)); + + offset = memory->allocate((void**)&ea_d_best, offset, 2 * sizeof(int32_t)); + cudaMallocHost(&ea_h_best, 2 * sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tmp_kvcache, offset, 64 * this->draft_model->kv_caches->num_hidden_layers * 2 * this->draft_model->kv_caches->dim * sizeof(T)); + + // to allocate ealge draft some states + offset = memory->allocate((void**)&ea_tree_draft_ids, offset, this->ea_tree_size * sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tree_position_ids, offset, this->ea_tree_size * sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tree_cache_length, offset, sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tree_parent, offset, this->ea_tree_size * sizeof(int32_t)); + offset = memory->allocate((void**)&ea_tree_attn_mask, offset, this->ea_tree_size * sizeof(uint64_t)); + + + + // init draft output + int64_t lm_head_end = this->draft_model->init_output_ptr(memory, num_tokens, offset); + offset = lm_head_end; + + offset = memory->allocate((void**)&draft_input, offset, num_tokens * sizeof(int32_t)); + offset = memory->allocate((void**)&draft_position_ids, offset, num_tokens * sizeof(int32_t)); + offset = memory->allocate((void**)&draft_cache_length, offset, sizeof(int32_t)); + cudaMallocHost(&host_draft_cache_length, sizeof(int32_t)); + + + offset = memory->allocate((void**)&draft_logits, offset, 64 * this->draft_model->vocab_size * sizeof(T)); + offset = topk_func->init_output_ptr(memory, 64, offset); + + offset = memory->allocate((void**)&draft_tmp, offset, (this->min_draft_length + ea_num_iter + 1)*sizeof(int32_t)); + offset = memory->allocate((void**)&d_best, offset, sizeof(int32_t)); + cudaMallocHost(&h_best, sizeof(int32_t)); + + // cascade vars + offset = memory->allocate((void**)&draft_tmp_hidden_state, offset, (this->min_draft_length + ea_num_iter + 1) * this->draft_model->hidden_size * sizeof(T)); + // cudaMallocHost(&host_ea_accept_nums, 1024 * sizeof(int)); + offset = memory->allocate((void**)&ea_accept_nums, offset, 1024 * sizeof(int32_t)); + return offset; + } + + int init_storage() { + + this->model->init_weight_ptr(this->model->memory); + // this->init_weight_ptr(this->model->memory); + this->draft_model->init_weight_ptr(this->model->memory); + this->init_weight_ptr(this->model->memory); + + int64_t offset = this->model->init_output_ptr(this->model->memory, this->model->chunk_length, this->model->memory->model_offset); + int64_t kv_cache_offset = init_output_ptr(this->model->memory, this->model->chunk_length, offset); + + int model_kv_size = (this->model->num_hidden_layers*this->model->num_key_value_heads*this->model->head_dim); + int draft_kv_size = (this->draft_model->num_hidden_layers*this->draft_model->num_key_value_heads*this->draft_model->head_dim); + int ea_kv_size = this->ea_num_layers * this->draft_model->num_key_value_heads * this->draft_model->head_dim; + float ratio = float(model_kv_size)/float(model_kv_size + draft_kv_size + ea_kv_size); + kv_cache_offset = this->model->kv_caches->init_output_ptr(this->model->memory, kv_cache_offset, ratio); + ratio = float(draft_kv_size)/float(draft_kv_size + ea_kv_size); + kv_cache_offset = this->draft_model->kv_caches->init_output_ptr(this->model->memory, kv_cache_offset, ratio); + this->ea_kv_caches->init_output_ptr(this->model->memory, kv_cache_offset, 1.0); + return min(min(this->draft_model->kv_caches->budget, this->model->kv_caches->budget), this->ea_kv_caches->budget + 1); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.substr(0, 5) == "eagle") { + if (name.substr(0, 9) == "eagle.fc1") { + ea_fc1->load_to_storage(name, ptr); + } else if (name.substr(0, 9) == "eagle.fc2") { + ea_fc2->load_to_storage(name, ptr); + } else { + std::regex layer_regex("eagle\\.layers\\.(\\d+)\\.(.*)"); + std::smatch matches; + if (std::regex_search(name, matches, layer_regex)) { + int layer_idx = std::stoi(matches[1]); + ea_layers[layer_idx]->load_to_storage(matches[2], ptr); + } else { + throw std::invalid_argument("Unsupported name (layer_idx not found): " + name); + } + } + } else if (name.substr(0, 5) == "draft"){ + std::string draft_name = name.substr(6); + this->draft_model->load_to_storage(draft_name, ptr); + } else { + this->model->load_to_storage(name, ptr); + } + } + + + + void eagle_prefill(int num_history_tokens) { + cudaMemcpy(this->ea_prev_embed + (ea_num_prev - 1) * this->draft_model->hidden_size, this->draft_model->embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_fc1->prefill(calc_stream, ea_num_prev, this->ea_prev_embed); + this->ea_fc2->prefill(calc_stream, ea_num_prev, this->ea_prev_hidden_state); + elementwise_add(calc_stream, ea_num_prev, this->draft_model->hidden_size, this->ea_fc1->output, this->ea_fc2->output, this->ea_fc2->output); + T* layer_output = nullptr; + + for (int i = 0; i < ea_num_layers; i++) { + this->ea_layers[i]->prefill(num_prev, num_history_tokens, this->ea_fc2->output, layer_output, this->eagle_position_ids, this->ea_kv_caches->caches[i]); + layer_output = this->ea_layers[i]->output; + } + elementwise_add(calc_stream, ea_num_prev, this->draft_model->hidden_size, this->ea_fc2->output, layer_output, this->ea_fc2->output); + } + + void eagle_decode(int32_t* cache_length) { + + this->ea_fc1->prefill(calc_stream, ea_num_prev, this->ea_prev_embed); + this->ea_fc2->prefill(calc_stream, ea_num_prev, this->ea_prev_hidden_state); + elementwise_add(calc_stream, ea_num_prev, this->draft_model->hidden_size, this->ea_fc1->output, this->ea_fc2->output, this->ea_fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < ea_num_layers; i++) { + this->ea_layers[i]->decode(ea_num_prev, this->eagle_padded_length, this->ea_fc2->output, layer_output, this->eagle_position_ids, cache_length, Mask(nullptr), this->ea_kv_caches->caches[i]); + layer_output = this->ea_layers[i]->output; + } + elementwise_add(calc_stream, ea_num_prev, this->draft_model->hidden_size, this->ea_fc2->output, layer_output, this->ea_fc2->output); + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, int32_t* input, int32_t* position_ids, void* output) { + this->model->prefill(num_tokens, num_history_tokens, input, position_ids, output); + if (num_history_tokens > 0) { + this->draft_model->embedding->prefill(calc_stream, this->num_prev, this->draft_input); + cudaMemcpy(this->ea_prev_embed, this->draft_model->embedding->output+this->draft_model->hidden_size, (this->num_prev-1) * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->draft_model->prefill_embed(this->num_prev, this->num_history_tokens, this->draft_model->embedding->output, this->draft_position_ids, (void*)this->draft_logits); + + cudaMemcpy(this->ea_prev_hidden_state, this->draft_model->norm->output, this->num_prev * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + this->draft_model->embedding->prefill(calc_stream, 1, input); + this->eagle_prefill(this->ea_num_history_tokens); + + // this->draft_model-> + } + + cudaMemcpy(this->draft_input, input, num_tokens * sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_position_ids, position_ids, num_tokens * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->num_prev = num_tokens; + this->num_history_tokens = num_history_tokens; + this->is_first_draft = true; + + // eagle + cudaMemcpy(this->eagle_position_ids, this->draft_position_ids, num_tokens * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->ea_num_prev = num_tokens; + this->ea_num_history_tokens = num_history_tokens; + this->ea_is_first_draft = true; + + + this->ea_accept_nums_size = 0; + cudaMemcpy(this->ea_accept_nums, &this->ea_accept_nums_size, sizeof(int), cudaMemcpyHostToDevice); + } + + + void draft_decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + throw std::runtime_error("Draft decode is not supported"); + } + + void draft_decode_with_graph_control(int32_t d_num_tokens, int32_t d_padded_length, int32_t* d_input, int32_t* d_position_ids, int32_t* d_cache_length, uint64_t* d_mask_2d, void* d_output) { + if (this->draft_cuda_graph) { + if (this->draft_graphCreated_padding_length != d_padded_length || this->draft_graphCreated_input_length != d_num_tokens) { + if (this->draft_graphExec != nullptr) { + cudaGraphExecDestroy(this->draft_graphExec); + this->draft_graphExec = nullptr; + } + if (this->draft_graph != nullptr) { + cudaGraphDestroy(this->draft_graph); + this->draft_graph = nullptr; + } + cudaStreamBeginCapture(calc_stream.stream, cudaStreamCaptureModeGlobal); + this->draft_model->decode(d_num_tokens, d_padded_length, d_input, d_position_ids, d_cache_length, d_mask_2d, d_output); + cudaStreamEndCapture(calc_stream.stream, &(this->draft_graph)); + cudaGraphInstantiate(&(this->draft_graphExec), this->draft_graph, nullptr, nullptr, 0); + this->draft_graphCreated_padding_length = d_padded_length; + this->draft_graphCreated_input_length = d_num_tokens; + } + cudaGraphLaunch(this->draft_graphExec, calc_stream.stream); + } else { + this->draft_model->decode(d_num_tokens, d_padded_length, d_input, d_position_ids, d_cache_length, d_mask_2d, d_output); + } + } + + void decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + this->model->decode(num_tokens, padded_length, input, position_ids, cache_length, nullptr, output); + } + + + void draft_with_eagle(int32_t* ea_tree_draft_ids, int32_t* ea_tree_position_ids, int32_t* ea_cache_length, uint64_t* ea_tree_attn_mask, int32_t* ea_tree_parent) { + cudaMemcpy(this->eagle_original_length, ea_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->eagle_padded_length = (this->eagle_original_length[0] + 256 - 1) / 128 * 128; + if (this->ea_is_first_draft) { + // prefill hidden states and embedding have been cpy + this->draft_model->embedding->prefill(calc_stream, 1, ea_tree_draft_ids); + this->eagle_prefill(this->ea_num_history_tokens); + } else { + this->eagle_decode(ea_cache_length); + } + + cudaMemcpy(this->eagle_cache_length, ea_cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->eagle_position_ids, ea_cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + repeat(calc_stream, ea_topk_per_iter, 1, 0, this->eagle_position_ids); + + { // d = 0 + this->draft_model->lm_head->prefill(calc_stream, 1, this->ea_fc2->output + (ea_num_prev - 1) * this->draft_model->hidden_size, this->eagle_logits); + log_softmax(calc_stream, 1, this->draft_model->vocab_size, this->eagle_logits); + this->ea_topk_func->prefill(calc_stream, 1, this->eagle_logits); + cudaMemcpy(this->ea_topk_func_2->topk_pos, this->ea_topk_func->topk_pos, ea_topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_topk_func_2->topk_val, this->ea_topk_func->topk_val, ea_topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tried_history_val, this->ea_topk_func->topk_val, ea_topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tried_history_pos, this->ea_topk_func->topk_pos, ea_topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + repeat(calc_stream, ea_topk_per_iter, this->draft_model->hidden_size, ea_num_prev-1, this->ea_fc2->output, this->ea_fc1->output); + init_tree(calc_stream, ea_topk_per_iter, this->eagle_mask_2d); + } + + for (int d = 1; d < this->ea_num_iter; ++d) { + add(calc_stream, 1, this->eagle_cache_length, ea_topk_per_iter); + this->draft_model->embedding->prefill(calc_stream, ea_topk_per_iter, this->ea_topk_func_2->topk_pos); + this->ea_fc2->prefill(calc_stream, ea_topk_per_iter, this->ea_fc1->output); + this->ea_fc1->prefill(calc_stream, ea_topk_per_iter, this->draft_model->embedding->output); + elementwise_add(calc_stream, ea_topk_per_iter, this->draft_model->hidden_size, this->ea_fc1->output, this->ea_fc2->output, this->ea_fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < ea_num_layers; i++) { + this->ea_layers[i]->decode(ea_topk_per_iter, this->eagle_padded_length, this->ea_fc2->output, layer_output, this->eagle_position_ids, this->eagle_cache_length, Mask(eagle_mask_2d, ea_topk_per_iter, ea_topk_per_iter * d), this->ea_kv_caches->caches[i]); + layer_output = this->ea_layers[i]->output; + } + elementwise_add(calc_stream, ea_topk_per_iter, this->draft_model->hidden_size, this->ea_fc2->output, layer_output, this->ea_fc2->output); + add(calc_stream, ea_topk_per_iter, this->eagle_position_ids, 1); + + this->draft_model->lm_head->prefill(calc_stream, ea_topk_per_iter, this->ea_fc2->output, this->eagle_logits); + log_softmax(calc_stream, ea_topk_per_iter, this->draft_model->vocab_size, this->eagle_logits); + this->ea_topk_func->prefill(calc_stream, ea_topk_per_iter, this->eagle_logits); + cumsum(calc_stream, ea_topk_per_iter, ea_topk_per_iter, this->ea_topk_func->topk_val, this->ea_topk_func_2->topk_val); + cudaMemcpy(this->ea_tried_history_val + ea_topk_per_iter + (d - 1) * ea_topk_per_iter * ea_topk_per_iter, this->ea_topk_func->topk_val, ea_topk_per_iter * ea_topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tried_history_pos + ea_topk_per_iter + (d - 1) * ea_topk_per_iter * ea_topk_per_iter, this->ea_topk_func->topk_pos, ea_topk_per_iter * ea_topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->ea_topk_func_2->prefill(calc_stream, 1, this->ea_topk_func->topk_val, ea_topk_per_iter * ea_topk_per_iter, ea_topk_per_iter); + + cudaMemcpy(this->ea_tmp_mask_2d, this->eagle_mask_2d, ea_topk_per_iter * sizeof(uint64_t), cudaMemcpyDeviceToDevice); + set_parent(calc_stream, ea_topk_per_iter, this->ea_tried_history_parent + (d - 1) * ea_topk_per_iter, this->ea_topk_func_2->topk_pos, 10 + (d - 1) * ea_topk_per_iter * ea_topk_per_iter); + update_tree(calc_stream, ea_topk_per_iter, ea_topk_per_iter * d, this->eagle_mask_2d, this->ea_tmp_mask_2d, this->ea_topk_func_2->topk_pos); + remap_hidden(calc_stream, ea_topk_per_iter, this->draft_model->hidden_size, this->ea_topk_func_2->topk_pos, this->ea_fc2->output, this->ea_fc1->output, ea_topk_per_iter); + remap_id(calc_stream, ea_topk_per_iter, this->ea_topk_func_2->topk_pos, this->ea_topk_func->topk_pos); + } + + this->ea_topk_func_2->prefill(calc_stream, 1, this->ea_tried_history_val); + + // build tree + build_dynamic_tree(calc_stream, this->ea_tree_size, this->eagle_original_length[0], this->ea_topk_per_iter, this->ea_tried_history_parent, this->ea_topk_func_2->topk_pos, ea_tree_position_ids, ea_tree_attn_mask, ea_tree_parent); + remap_id(calc_stream, this->ea_tree_size-1, this->ea_topk_func_2->topk_pos, this->ea_tried_history_pos, ea_tree_draft_ids + 1); + + this->ea_is_first_draft = false; + } + + void draft(int32_t *tree_draft_ids, int32_t *tree_position_ids, int32_t *cache_length, uint64_t*, int32_t*) { + // reset cur draft length + this->cur_draft_length = 0; + this->cur_ea_accept_nums_size = 0; + if (this->is_first_draft) { + // append tree draft ids to draft input + cudaMemcpy(this->draft_input+this->num_prev, tree_draft_ids, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_position_ids+this->num_prev, tree_position_ids, sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->num_prev += 1; + + this->draft_model->embedding->prefill(calc_stream, this->num_prev, this->draft_input); + cudaMemcpy(this->ea_prev_embed, this->draft_model->embedding->output+ this->draft_model->hidden_size, (this->num_prev-1) * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + this->draft_model->prefill_embed(this->num_prev, this->num_history_tokens, this->draft_model->embedding->output, this->draft_position_ids, (void*)this->draft_logits); + + // eagle prepare for draft_with_eagle function + // ea_is_first_draft is True + cudaMemcpy(this->eagle_position_ids + (this->ea_num_prev), tree_position_ids, sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->ea_num_prev = this->num_prev; + + cudaMemcpy(this->ea_prev_hidden_state, this->draft_model->norm->output, this->num_prev * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->topk_func->prefill(calc_stream, 1, this->draft_logits); + + + // prepare for draft with eagle + cudaMemcpy(this->ea_tree_draft_ids, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tree_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->ea_tree_cache_length, 1); + + cudaMemcpy(this->draft_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_cache_length, 1); + cudaMemcpy(this->host_draft_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + + // draft model has forward one time + cudaMemcpy(this->draft_tmp, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_tmp_hidden_state, this->draft_model->norm->output + (this->num_prev-1) * this->draft_model->hidden_size, this->draft_model->hidden_size*sizeof(T), cudaMemcpyDeviceToDevice); + this->cur_draft_length += 1; + + } else if (this->num_prev == 2){ + this->draft_model->embedding->prefill(calc_stream, this->num_prev, this->draft_input); + cudaMemcpy(this->ea_prev_embed + (ea_num_prev)* this->draft_model->hidden_size, this->draft_model->embedding->output + this->draft_model->hidden_size, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + this->draft_model->decode_embed(this->num_prev, this->draft_padded_length, this->draft_model->embedding->output, this->draft_position_ids, this->draft_cache_length, nullptr, (void*)this->draft_logits); + + + this->topk_func->prefill(calc_stream, 1, this->draft_logits+(this->draft_model->vocab_size)); + + + // prepare for the eagle input + cudaMemcpy(this->ea_prev_hidden_state + (ea_num_prev)* this->draft_model->hidden_size, this->draft_model->norm->output, this->num_prev * this->draft_model->hidden_size*sizeof(T), cudaMemcpyDeviceToDevice); + this->draft_model->embedding->prefill(calc_stream, 1, this->topk_func->topk_pos); + cudaMemcpy(this->ea_prev_embed + (ea_num_prev+1)* this->draft_model->hidden_size, this->draft_model->embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_num_prev += this->num_prev; + + + // prepare for draft with eagle + cudaMemcpy(this->ea_tree_draft_ids, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tree_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + + // draft model has forward one time + cudaMemcpy(this->draft_tmp, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_tmp_hidden_state, this->draft_model->norm->output + (this->num_prev-1) * this->draft_model->hidden_size, this->draft_model->hidden_size*sizeof(T), cudaMemcpyDeviceToDevice); + this->cur_draft_length += 1; + } else if (this->draft_model_start) { + // num_prev == 1 + this->draft_model->decode(this->num_prev, this->draft_padded_length, this->draft_input, this->draft_position_ids, this->draft_cache_length, nullptr, (void*)this->draft_logits); + this->topk_func->prefill(calc_stream, 1, this->draft_logits); + + // prepare for the eagle input + cudaMemcpy(this->ea_prev_hidden_state + (ea_num_prev)*this->draft_model->hidden_size, this->draft_model->norm->output, this->num_prev * this->draft_model->hidden_size*sizeof(T), cudaMemcpyDeviceToDevice); + + this->draft_model->embedding->prefill(calc_stream, 1, this->topk_func->topk_pos); + cudaMemcpy(this->ea_prev_embed + (ea_num_prev)* this->draft_model->hidden_size, this->draft_model->embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_num_prev += this->num_prev; + + // prepare for draft with eagle + cudaMemcpy(this->ea_tree_draft_ids, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tree_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + + + // draft model has forward one time + cudaMemcpy(this->draft_tmp, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_tmp_hidden_state, this->draft_model->norm->output + (this->num_prev-1) * this->draft_model->hidden_size, this->draft_model->hidden_size*sizeof(T), cudaMemcpyDeviceToDevice); + this->cur_draft_length += 1; + + } else { + cudaMemcpy(this->ea_tree_draft_ids, tree_draft_ids, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->ea_tree_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + } + + + + while (this->cur_draft_length < min_draft_length){ + + // eagle draft + this->draft_with_eagle( + this->ea_tree_draft_ids, + this->ea_tree_position_ids, + this->ea_tree_cache_length, + this->ea_tree_attn_mask, + this->ea_tree_parent + ); + + + cudaMemcpy(this->draft_cache_length, this->ea_tree_cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_cache_length, this->ea_tree_size); + cudaMemcpy(this->host_draft_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->draft_padded_length = (this->host_draft_cache_length[0]+ 128 -1) / 128*128;; + + this->draft_decode_with_graph_control( + this->ea_tree_size, + this->draft_padded_length, + this->ea_tree_draft_ids, + this->ea_tree_position_ids, + this->draft_cache_length, + this->ea_tree_attn_mask, + (void*) this->draft_logits + ); + this->topk_func->prefill(calc_stream, this->ea_tree_size, this->draft_logits); + + + this->draft_verify( + this->ea_tree_size, + this->ea_tree_draft_ids, + this->topk_func->topk_pos, + this->ea_tree_position_ids, + this->ea_tree_cache_length, + this->ea_tree_attn_mask, + this->ea_tree_parent + ); + + cudaMemcpy(this->ea_accept_nums + (this->ea_accept_nums_size+this->cur_ea_accept_nums_size+1), this->ea_d_best, sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->cur_ea_accept_nums_size += 1; + + + // accept return to ea_h_best[0] + cudaMemcpy(this->ea_tree_draft_ids, this->ea_tree_draft_ids + (this->ea_h_best[0]-1), sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->ea_tree_cache_length, this->ea_h_best[0]); + + } + + + + cudaMemcpy(tree_draft_ids + 1, this->draft_tmp, this->cur_draft_length*sizeof(int32_t), cudaMemcpyDeviceToDevice); + make_arange(calc_stream, this->cur_draft_length+1, cache_length, tree_position_ids); + this->is_first_draft = false; + } + + int draft_verify(int32_t ea_num_tokens, int32_t* ea_pred, int32_t* ea_gt, int32_t* ea_position_ids, int32_t* ea_cache_length, uint64_t* ea_mask_2d, int32_t* ea_tree_parent) { + verify_draft(calc_stream, ea_num_tokens, ea_pred, ea_gt, ea_position_ids, ea_cache_length, ea_mask_2d, ea_tree_parent, this->ea_d_best); + + cudaMemcpyAsync(this->ea_h_best, this->ea_d_best, 2 * sizeof(int32_t), cudaMemcpyDeviceToHost, calc_stream.stream); + cudaStreamSynchronize(calc_stream.stream); + + this->ea_num_prev = ea_h_best[0]; + remap_hidden(calc_stream, this->ea_num_prev, this->draft_model->hidden_size, ea_pred, this->draft_model->norm->output, this->ea_prev_hidden_state); + + fix_kv_cache(calc_stream, ea_h_best[0], this->draft_model->kv_caches->num_hidden_layers * 2, this->draft_model->kv_caches->dim, ea_pred, ea_gt, ea_cache_length, this->draft_model->kv_caches->d_flat_caches, this->ea_tmp_kvcache); + + this->draft_model->embedding->prefill(calc_stream, this->ea_num_prev, ea_pred); + cudaMemcpy(this->ea_prev_embed, this->draft_model->embedding->output, this->ea_num_prev * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + make_arange(calc_stream, this->ea_num_prev, ea_cache_length, this->eagle_position_ids); + + cudaMemcpy(this->draft_tmp_hidden_state + (this->cur_draft_length*this->draft_model->hidden_size), this->ea_prev_hidden_state, this->ea_num_prev * this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + cudaMemcpy(this->draft_tmp + this->cur_draft_length, ea_pred, this->ea_num_prev * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->cur_draft_length += this->ea_num_prev; + + return ea_h_best[0]; + } + + int verify(int32_t num_tokens, int32_t* pred, int32_t* gt, int32_t* position_ids, int32_t* cache_length, uint64_t* attn_mask, int32_t* tree_parent) { + verify_seq_draft(calc_stream, num_tokens, pred, gt, (uint16_t*)attn_mask, this->d_best); + cudaMemcpyAsync(this->h_best, this->d_best, 1 * sizeof(int32_t), cudaMemcpyDeviceToHost, calc_stream.stream); + cudaStreamSynchronize(calc_stream.stream); + + if (h_best[0]>(this->cur_draft_length+1)) { + h_best[0] = this->cur_draft_length+1; + } + + if (h_best[0]==(this->cur_draft_length+1)) { + // full accept + this->num_prev = 2; + cudaMemcpy(this->draft_input, gt + (this->cur_draft_length-1), 2*sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_cache_length, this->h_best[0]+1); + make_arange(calc_stream, 2, cache_length, this->draft_position_ids); + add(calc_stream, 2, this->draft_position_ids, this->cur_draft_length); + cudaMemcpy(this->host_draft_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->draft_padded_length = (this->host_draft_cache_length[0]+ 128 -1) / 128*128; + } else { + this->num_prev = 1; + cudaMemcpy(this->draft_input, gt + (this->h_best[0]-1), sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_cache_length, this->h_best[0]+1); + cudaMemcpy(this->draft_position_ids, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_position_ids, this->h_best[0]); + cudaMemcpy(this->host_draft_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->draft_padded_length = (this->host_draft_cache_length[0]+ 128 -1) / 128*128; + + // adapt eagle draft ptr + // conidtion 1: eagle last start postion is larger than accept position + if (host_draft_cache_length[0] > this->eagle_original_length[0] + 1) { + this->ea_num_prev = host_draft_cache_length[0] - (this->eagle_original_length[0] + 1); + // keep ea_prev_hidden_state and update ea_prev_embed + this->draft_model->embedding->prefill(calc_stream, 1, this->draft_input); + cudaMemcpy(this->ea_prev_embed+(this->ea_num_prev-1), this->draft_model->embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + } else { + // condition 2: eagle last start position is less than accept position + // index from the kepted draft model hidden state + cudaMemcpy(this->ea_prev_hidden_state, this->draft_tmp_hidden_state + (h_best[0]-1) *this->draft_model->hidden_size, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->ea_num_prev = 1; + this->draft_model->embedding->prefill(calc_stream, 1, this->draft_input); + cudaMemcpy(this->ea_prev_embed, this->draft_model->embedding->output, this->draft_model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->eagle_position_ids, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->eagle_position_ids, this->h_best[0]-1); + } + + + } + + + this->ea_accept_nums_size += this->cur_ea_accept_nums_size; + add(calc_stream, 1, this->ea_accept_nums, this->cur_ea_accept_nums_size); + cudaMemcpy(tree_parent, this->ea_accept_nums, (this->ea_accept_nums_size+1) * sizeof(int32_t), cudaMemcpyDeviceToDevice); + + + + return h_best[0]; + + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/kvcache.cuh b/examples/CPM.cu/src/model/kvcache.cuh new file mode 100644 index 00000000..5ba14353 --- /dev/null +++ b/examples/CPM.cu/src/model/kvcache.cuh @@ -0,0 +1,65 @@ +#pragma once +#include "../trait.cuh" +#include "rotary.cuh" +#include <vector> +#include <cuda_runtime.h> + +template <typename T> +struct KVCache { + int dim; + T *k_cache, *v_cache; + RotaryEmbedding<T> *rotary_embedding; + + KVCache(int dim, RotaryEmbedding<T> *rotary_embedding) { + this->dim = dim; + this->rotary_embedding = rotary_embedding; + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + offset = memory->allocate((void**)&this->k_cache, offset, num_tokens * dim * sizeof(T)); + offset = memory->allocate((void**)&this->v_cache, offset, num_tokens * dim * sizeof(T)); + return offset; + } + + T* offset_k(int offset) { return k_cache + offset * dim; } + T* offset_v(int offset) { return v_cache + offset * dim; } +}; + +template <typename T> +struct KVCacheManager { + int num_hidden_layers; + int dim; + int budget; + std::vector<KVCache<T>*> caches; + T **h_flat_caches, **d_flat_caches; + RotaryEmbedding<T> *rotary_embedding; + + KVCacheManager(int num_hidden_layers, int num_key_value_heads, int head_dim) { + this->num_hidden_layers = num_hidden_layers; + this->dim = num_key_value_heads * head_dim; + this->rotary_embedding = new RotaryEmbedding<T>(head_dim); + } + + void init_weight_ptr(Memory* memory) { + this->rotary_embedding->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int64_t offset, float ratio=1.0) { + offset = memory->allocate((void**)&this->d_flat_caches, offset, num_hidden_layers * 2 * sizeof(T*)); + + budget = int64_t(memory->get_remaining_memory(offset) * ratio * 0.999) / (this->num_hidden_layers * 2 * this->dim * sizeof(T)) - 1; + for (int i = 0; i < this->num_hidden_layers; i++) { + caches.push_back(new KVCache<T>(this->dim, this->rotary_embedding)); + } + for (int i = 0; i < this->num_hidden_layers; i++) { + offset = caches[i]->init_output_ptr(memory, budget, offset); + } + this->h_flat_caches = new T*[num_hidden_layers * 2]; + for (int i = 0; i < num_hidden_layers; i++) { + this->h_flat_caches[i * 2] = caches[i]->k_cache; + this->h_flat_caches[i * 2 + 1] = caches[i]->v_cache; + } + cudaMemcpy(this->d_flat_caches, this->h_flat_caches, num_hidden_layers * 2 * sizeof(T*), cudaMemcpyHostToDevice); + return offset; + } +}; diff --git a/examples/CPM.cu/src/model/layer.cuh b/examples/CPM.cu/src/model/layer.cuh new file mode 100644 index 00000000..56b6f818 --- /dev/null +++ b/examples/CPM.cu/src/model/layer.cuh @@ -0,0 +1,90 @@ +#pragma once +#include "perf.cuh" +#include "norm.cuh" +#include "attn.cuh" +#include "ffn.cuh" +#include "kvcache.cuh" +#include "mask.cuh" +#include <cuda_runtime.h> + +template <typename T> +struct Layer { + Attention<T> *attn; + FFN<T> *ffn; + T* output; + int hidden_size; + float residual_scale; + + Layer(int hidden_size, int intermediate_size, int num_attention_heads, int num_key_value_heads, int head_dim, float rms_norm_eps, float residual_scale = 1.0, int window_size = 0) { + this->attn = new Attention<T>(hidden_size, num_attention_heads, num_key_value_heads, head_dim, rms_norm_eps, window_size); + this->ffn = new GatedFFN<T>(hidden_size, intermediate_size, rms_norm_eps); + this->hidden_size = hidden_size; + this->residual_scale = residual_scale; + } + + void init_weight_ptr(Memory* memory) { + this->attn->init_weight_ptr(memory); + this->ffn->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t attn_end = this->attn->init_output_ptr(memory, num_tokens, offset); + int64_t ffn_end = this->ffn->init_output_ptr(memory, num_tokens, offset); + this->output = this->ffn->output; + return std::max(attn_end, ffn_end); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("attn") != std::string::npos || name.find("input_layernorm") != std::string::npos) { + this->attn->load_to_storage(name, ptr); + } else if (name.find("mlp") != std::string::npos || name.find("post_attention_layernorm") != std::string::npos) { + this->ffn->load_to_storage(name, ptr); + } else { + throw std::invalid_argument("Unsupported name " + name); + } + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, T* input, T* prev_output, int32_t* position_ids, KVCache<T>* kv_cache, T* prev_layer_states=nullptr) { + if (prev_output != nullptr) { + elementwise_scale(calc_stream, num_tokens, this->hidden_size, prev_output, this->residual_scale); + } + cuda_perf_start_on_stream_f(PREFILL_ATTN, calc_stream.stream); + this->attn->prefill(calc_stream, num_tokens, num_history_tokens, input, prev_output, position_ids, kv_cache); + cuda_perf_stop_on_stream_f(PREFILL_ATTN, calc_stream.stream); + if (prev_layer_states != nullptr) { + cudaMemcpyAsync( + prev_layer_states, // dst + input, // src + num_tokens * this->attn->hidden_size * sizeof(T), + cudaMemcpyDeviceToDevice, + calc_stream.stream + ); + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, this->attn->output, this->residual_scale); + cuda_perf_start_on_stream_f(PREFILL_FFN, calc_stream.stream); + this->ffn->prefill(calc_stream, num_tokens, input, this->attn->output); + cuda_perf_stop_on_stream_f(PREFILL_FFN, calc_stream.stream); + } + + void decode(int32_t num_tokens, int32_t padded_length, T* input, T* prev_output, int32_t* position_ids, int32_t* cache_length, const Mask& mask, KVCache<T>* kv_cache, T* prev_layer_states=nullptr) { + if (prev_output != nullptr) { + elementwise_scale(calc_stream, num_tokens, this->hidden_size, prev_output, this->residual_scale); + } + cuda_perf_start_on_stream_f(DECODE_ATTN, calc_stream.stream); + this->attn->decode(calc_stream, num_tokens, padded_length, input, prev_output, position_ids, cache_length, mask, kv_cache); + cuda_perf_stop_on_stream_f(DECODE_ATTN, calc_stream.stream); + if (prev_layer_states != nullptr) { + cudaMemcpyAsync( + prev_layer_states, // dst + input, // src + num_tokens * this->attn->hidden_size * sizeof(T), + cudaMemcpyDeviceToDevice, + calc_stream.stream + ); + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, this->attn->output, this->residual_scale); + cuda_perf_start_on_stream_f(DECODE_FFN, calc_stream.stream); + this->ffn->decode(calc_stream, num_tokens, input, this->attn->output); + cuda_perf_stop_on_stream_f(DECODE_FFN, calc_stream.stream); + } +}; diff --git a/examples/CPM.cu/src/model/linear.cuh b/examples/CPM.cu/src/model/linear.cuh new file mode 100644 index 00000000..5f838314 --- /dev/null +++ b/examples/CPM.cu/src/model/linear.cuh @@ -0,0 +1,101 @@ +#pragma once +#include <cuda_runtime.h> +#include <cublas_v2.h> +#include "../trait.cuh" +#include "../utils.cuh" +#include "elementwise.cuh" + +template <typename T, bool transposed=true> +void linear(const Stream& stream, int num_tokens, int dim_in, int dim_out, const T* input, const T* weight, T* output, bool inplace=false) { + float alpha = 1.0f; + float beta = inplace ? 1.0f : 0.0f; + if constexpr (transposed) { + cublasCheck(cublasGemmEx(stream.cublas_handle, + CUBLAS_OP_T, CUBLAS_OP_N, + dim_out, num_tokens, dim_in, + &alpha, + weight, TypeTraits<T>::cublas_type(), dim_in, + input, TypeTraits<T>::cublas_type(), dim_in, + &beta, + output, TypeTraits<T>::cublas_type(), dim_out, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT + )); + } else { + cublasCheck(cublasGemmEx(stream.cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + dim_out, num_tokens, dim_in, + &alpha, + weight, TypeTraits<T>::cublas_type(), dim_out, + input, TypeTraits<T>::cublas_type(), dim_in, + &beta, + output, TypeTraits<T>::cublas_type(), dim_out, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT + )); + } +} + +template <typename T, bool transposed=true, bool has_bias=false> +struct Linear { + int dim_in; + int dim_out; + T* output; + T* weight; + T* bias; + + Linear(int dim_in, int dim_out) { + this->dim_in = dim_in; + this->dim_out = dim_out; + } + + void init_weight_ptr(Memory* memory) { + weight = (T*)memory->allocate_for_model(dim_in * dim_out * sizeof(T)); + if constexpr (has_bias) { + bias = (T*)memory->allocate_for_model(dim_out * sizeof(T)); + } + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + return memory->allocate((void**)&this->output, offset, num_tokens * dim_out * sizeof(T)); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("weight") != std::string::npos) { + cudaMemcpy((void*)weight, ptr, dim_in * dim_out * sizeof(T), cudaMemcpyHostToDevice); + } else if (name.find("bias") != std::string::npos) { + cudaMemcpy((void*)bias, ptr, dim_out * sizeof(T), cudaMemcpyHostToDevice); + } else { + throw std::invalid_argument("Unsupported name " + name); + } + } + + void prefill(const Stream& stream, int32_t num_tokens, T* input, T* tgt=nullptr, bool inplace=false) { + if (tgt == nullptr) tgt = this->output; + linear<T, transposed>(stream, num_tokens, dim_in, dim_out, input, weight, tgt, inplace); + if constexpr (has_bias) { + batched_add<T>(stream, num_tokens, dim_out, tgt, bias, tgt); + } + } +}; + +template <typename T> +struct LMHead : Linear<T> { + T* tmp_hidden_size; + float head_scale; + + LMHead(int dim_in, int dim_out, float head_scale = 1.0) : Linear<T>(dim_in, dim_out) { + this->head_scale = head_scale; + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + offset = Linear<T>::init_output_ptr(memory, num_tokens, offset); + offset = memory->allocate((void**)&this->tmp_hidden_size, offset, num_tokens * this->dim_in * sizeof(T)); + return offset; + } + + void prefill(const Stream& stream, int32_t num_tokens, T* input, T* tgt=nullptr, bool inplace=false) { + elementwise_scale(stream, num_tokens, this->dim_in, input, head_scale, tmp_hidden_size); + Linear<T>::prefill(stream, num_tokens, tmp_hidden_size, tgt, inplace); + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/mask.cuh b/examples/CPM.cu/src/model/mask.cuh new file mode 100644 index 00000000..73d3c59c --- /dev/null +++ b/examples/CPM.cu/src/model/mask.cuh @@ -0,0 +1,18 @@ +#pragma once + +#include <cuda_runtime.h> + +struct Mask { + uint64_t* ptr; + int mask_q_range; + int mask_k_range; + + Mask(uint64_t* ptr = nullptr, int mask_q_range = 0, int mask_k_range = 0) : ptr(ptr) { + if (ptr == nullptr) { + mask_q_range = 0; + mask_k_range = 0; + } + this->mask_q_range = mask_q_range; + this->mask_k_range = mask_k_range; + } +}; diff --git a/examples/CPM.cu/src/model/memory.cuh b/examples/CPM.cu/src/model/memory.cuh new file mode 100644 index 00000000..b2f1b857 --- /dev/null +++ b/examples/CPM.cu/src/model/memory.cuh @@ -0,0 +1,183 @@ +#pragma once +#include "../utils.cuh" +#include <cuda_runtime.h> +#include "../signal_handler.cuh" +#ifdef DISABLE_MEMPOOL +#include <vector> +#endif + +#define ALIGN_SIZE 256 + +// TODO: refactor this for better encapsulation +struct Memory { + int64_t memory_limit; + int64_t model_offset; +#ifndef DISABLE_MEMPOOL + uint8_t* memory_pool; +#else + int64_t allocated_size; + std::vector<void*> allocated_ptrs; +#endif + + Memory(float memory_limit) { + // Get GPU total memory size + size_t free_memory, total_memory; + cudaError_t err = cudaMemGetInfo(&free_memory, &total_memory); + if (err != cudaSuccess) { + fprintf(stderr, "Error: cudaMemGetInfo failed: %s\n\n", cudaGetErrorString(err)); + this->memory_limit = 0; + this->model_offset = 0; +#ifndef DISABLE_MEMPOOL + this->memory_pool = nullptr; +#else + this->allocated_size = 0; +#endif + return; + } + + // Calculate actual memory size + this->memory_limit = (int64_t)(total_memory * memory_limit); + +#ifndef DISABLE_MEMPOOL + printf("Use Pre-allocated Memory Pool\n"); +#else + printf("Use Dynamic Memory Allocation, this is for debug\n"); +#endif + printf("GPU Total Memory: %ld bytes (%.2f GB), ", total_memory, (double)total_memory / (1024*1024*1024)); + printf("Set Allocatable Memory Limit: %ld bytes (%.2f GB), ratio: %.1f%%\n", + this->memory_limit, (double)this->memory_limit / (1024*1024*1024), memory_limit * 100); + + this->model_offset = 0; +#ifndef DISABLE_MEMPOOL + err = cudaMalloc(reinterpret_cast<void**>(&this->memory_pool), this->memory_limit); + if (err != cudaSuccess) { + print_stack_trace(5); + fprintf(stderr, "Error: cudaMalloc failed in Memory constructor: %s, size: %ld\n\n", cudaGetErrorString(err), this->memory_limit); + this->memory_pool = nullptr; + } +#else + // In DISABLE_MEMPOOL mode, don't pre-allocate memory + this->allocated_size = 0; +#endif + } + + // Add destructor to prevent memory leak + ~Memory() { +#ifndef DISABLE_MEMPOOL + if (memory_pool != nullptr) { + cudaError_t err = cudaFree(memory_pool); + if (err != cudaSuccess) { + fprintf(stderr, "Warning: cudaFree failed in Memory destructor: %s\n\n", cudaGetErrorString(err)); + } + } +#else + // In DISABLE_MEMPOOL mode, free all individually allocated memory + for (void* ptr : allocated_ptrs) { + if (ptr != nullptr) { + cudaError_t err = cudaFree(ptr); + if (err != cudaSuccess) { + fprintf(stderr, "Warning: cudaFree failed in Memory destructor: %s\n\n", cudaGetErrorString(err)); + } + } + } + allocated_ptrs.clear(); +#endif + } + + // Get remaining available memory from a specific offset + int64_t get_remaining_memory(int64_t offset) const { +#ifndef DISABLE_MEMPOOL + return this->memory_limit - offset; +#else + return this->memory_limit - this->allocated_size; +#endif + } + +#ifndef DISABLE_MEMPOOL + void* allocate_for_model(size_t size) { + uint8_t* ret = memory_pool + model_offset; + model_offset += size; + model_offset = ROUND_UP(model_offset, ALIGN_SIZE); + if (model_offset > this->memory_limit) { + print_stack_trace(5); + fprintf(stderr, "Error: memory limit exceeded, offset: %ld, size: %ld, memory_limit: %ld\n\n", model_offset, size, this->memory_limit); + return nullptr; + } + return (void*)ret; + } + int64_t allocate(void** ptr, int64_t offset, size_t size = 0) { + if (size == 0) { + print_stack_trace(5); + fprintf(stderr, "Error: size is 0\n\n"); + return -1; + } + *ptr = memory_pool + offset; + offset += size; + offset = ROUND_UP(offset, ALIGN_SIZE); + if (offset > this->memory_limit) { + print_stack_trace(5); + fprintf(stderr, "Error: memory limit exceeded, offset: %ld, size: %ld, memory_limit: %ld\n\n", offset, size, this->memory_limit); + *ptr = nullptr; + return -1; + } + return offset; + } +#else + void* allocate_for_model(size_t size) { + void* ptr; + size_t aligned_size = ROUND_UP(size, ALIGN_SIZE); + + // Check if allocation would exceed memory limit + if (allocated_size + aligned_size > this->memory_limit) { + print_stack_trace(5); + fprintf(stderr, "Error: memory limit exceeded, allocated_size: %ld, new_size: %ld, memory_limit: %ld\n\n", + allocated_size, aligned_size, this->memory_limit); + return nullptr; + } + + cudaError_t err = cudaMalloc(&ptr, aligned_size); + if (err != cudaSuccess) { + print_stack_trace(5); + fprintf(stderr, "Error: cudaMalloc failed: %s, size: %ld\n\n", cudaGetErrorString(err), size); + return nullptr; + } + + allocated_ptrs.push_back(ptr); + allocated_size += aligned_size; + return ptr; + } + int64_t allocate(void** ptr, int64_t offset, size_t size = 0) { // 0 for reuse previous allocated memory, just need start offset, return value is useless + if (size == 0) { + print_stack_trace(5); + fprintf(stderr, "Error: size is 0\n\n"); + return -1; + } + + size_t aligned_size = ROUND_UP(size, ALIGN_SIZE); + + // Check if allocation would exceed memory limit + if (allocated_size + aligned_size > this->memory_limit) { + print_stack_trace(5); + fprintf(stderr, "Error: memory limit exceeded, allocated_size: %ld, new_size: %ld, memory_limit: %ld\n\n", + allocated_size, aligned_size, this->memory_limit); + *ptr = nullptr; + return -1; + } + + cudaError_t err = cudaMalloc(ptr, aligned_size); + if (err != cudaSuccess) { + print_stack_trace(5); + fprintf(stderr, "Error: cudaMalloc failed: %s, size: %ld\n\n", cudaGetErrorString(err), size); + *ptr = nullptr; + return -1; + } + + allocated_ptrs.push_back(*ptr); + allocated_size += aligned_size; + + // Update max_output_offset for tracking purposes + offset += aligned_size; + return offset; + } +#endif +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/minicpm4/minicpm4_attn.cuh b/examples/CPM.cu/src/model/minicpm4/minicpm4_attn.cuh new file mode 100644 index 00000000..a6227c6a --- /dev/null +++ b/examples/CPM.cu/src/model/minicpm4/minicpm4_attn.cuh @@ -0,0 +1,324 @@ +#pragma once +#include "../attn.cuh" +#include "minicpm4_kvcache.cuh" + +template <typename T> +struct MiniCPM4Attention { + int hidden_size; + int num_attention_heads; + int num_key_value_heads; + int head_dim; + float rms_norm_eps; + + Norm<T> *attn_norm; + Linear<T> *q_proj, *k_proj, *v_proj; + Linear<T> *o_proj; + T* output; + + T* attn_output; + float *softmax_lse, *softmax_lse_accum, *oaccum; + + int sink_window_size; + int block_window_size; + int sparse_switch; + bool apply_compress_lse; + + MiniCPM4Attention(int hidden_size, int num_attention_heads, int num_key_value_heads, int head_dim, float rms_norm_eps, int sink_window_size, int block_window_size, int sparse_switch, bool apply_compress_lse) { + this->hidden_size = hidden_size; + this->num_attention_heads = num_attention_heads; + this->num_key_value_heads = num_key_value_heads; + this->head_dim = head_dim; + this->rms_norm_eps = rms_norm_eps; + + this->attn_norm = new RMSNorm<T>(hidden_size, rms_norm_eps); + this->q_proj = new Linear<T>(hidden_size, num_attention_heads * head_dim); + this->k_proj = new Linear<T>(hidden_size, num_key_value_heads * head_dim); + this->v_proj = new Linear<T>(hidden_size, num_key_value_heads * head_dim); + this->o_proj = new Linear<T>(hidden_size, num_attention_heads * head_dim); + + this->sink_window_size = sink_window_size; + this->block_window_size = block_window_size; + this->sparse_switch = sparse_switch; + this->apply_compress_lse = apply_compress_lse; + } + + void init_weight_ptr(Memory* memory) { + this->attn_norm->init_weight_ptr(memory); + this->q_proj->init_weight_ptr(memory); + this->k_proj->init_weight_ptr(memory); + this->v_proj->init_weight_ptr(memory); + this->o_proj->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t attn_norm_end = this->attn_norm->init_output_ptr(memory, num_tokens, offset); + int64_t q_proj_end = this->q_proj->init_output_ptr(memory, num_tokens, attn_norm_end); + int64_t k_proj_end = this->k_proj->init_output_ptr(memory, num_tokens, q_proj_end); + int64_t v_proj_end = this->v_proj->init_output_ptr(memory, num_tokens, k_proj_end); + + int64_t attn_output_end = memory->allocate((void**)&this->attn_output, offset, num_tokens * this->num_attention_heads * this->head_dim * sizeof(T)); + int64_t softmax_lse_end = memory->allocate((void**)&this->softmax_lse, v_proj_end, num_tokens * this->num_attention_heads * sizeof(float)); // TODO minicpm4 support larger num_splits + const int max_num_splits = 128; // Maximum number of splits for attention computation + const int max_spec_tree_size = 64; // Maximum size of speculative decoding tree + int64_t softmax_lse_accum_end = memory->allocate((void**)&this->softmax_lse_accum, softmax_lse_end, max(max_num_splits * max_spec_tree_size, num_tokens) * this->num_attention_heads * sizeof(float)); + int64_t oaccum_end = memory->allocate((void**)&this->oaccum, softmax_lse_accum_end, max(max_num_splits * max_spec_tree_size, num_tokens) * this->num_attention_heads * this->head_dim * sizeof(float)); + + int64_t o_proj_end = this->o_proj->init_output_ptr(memory, num_tokens, v_proj_end); + this->output = this->o_proj->output; + + return std::max(oaccum_end, o_proj_end); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("q_proj") != std::string::npos) { + this->q_proj->load_to_storage(name, ptr); + } else if (name.find("k_proj") != std::string::npos) { + this->k_proj->load_to_storage(name, ptr); + } else if (name.find("v_proj") != std::string::npos) { + this->v_proj->load_to_storage(name, ptr); + } else if (name.find("o_proj") != std::string::npos) { + this->o_proj->load_to_storage(name, ptr); + } else if (name.find("input_layernorm") != std::string::npos) { + this->attn_norm->load_to_storage(name, ptr); + } else { + throw std::invalid_argument("Unsupported name " + name); + } + } + + void prefill(const Stream& stream, int32_t num_tokens, int32_t num_history_tokens, T* input, T* prev_output, int32_t* position_ids, MiniCPM4KVCache<T>* kv_cache) { + T* k_cache = kv_cache->offset_k(num_history_tokens); + T* v_cache = kv_cache->offset_v(num_history_tokens); + + this->attn_norm->prefill(stream, num_tokens, input, prev_output); + this->q_proj->prefill(stream, num_tokens, this->attn_norm->output); + this->k_proj->prefill(stream, num_tokens, this->attn_norm->output, k_cache); + this->v_proj->prefill(stream, num_tokens, this->attn_norm->output, v_cache); + kv_cache->rotary_embedding->prefill(stream, num_tokens, this->num_attention_heads, this->num_key_value_heads, this->q_proj->output, k_cache, position_ids); + + cuda_perf_start_on_stream_f(M4_PREFILL_ATTN_CORE, stream.stream); + cuda_perf_start_on_stream_f(M4_PREFILL_ATTN_STAGE1, stream.stream); + if (num_history_tokens == 0) { + kv_cache->init(); + } else { + kv_cache->compress(stream); + } + + uint64_t *blockmask = nullptr; + if ((!apply_compress_lse && kv_cache->c1_len * kv_cache->c1_stride >= this->sparse_switch) || (apply_compress_lse && kv_cache->c2_len * kv_cache->c2_stride >= this->sparse_switch)) { + int q_round, k_round, out_len; + cuda_perf_start_on_stream_f(M4_PREFILL_ATTN_STAGE1_CORE, stream.stream); + mha_fwd_stage1( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + kv_cache->c1_len, + apply_compress_lse ? kv_cache->c2_len : kv_cache->c1_len, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + this->q_proj->output, + kv_cache->c1_cache, + apply_compress_lse ? kv_cache->c2_cache : kv_cache->c1_cache, + nullptr, + kv_cache->stage1_score, + rsqrtf(float(this->head_dim)), + false, + -1, + -1, + 0, + stream.stream, + q_round, + k_round + ); + cuda_perf_stop_on_stream_f(M4_PREFILL_ATTN_STAGE1_CORE, stream.stream); + maxpooling_func( + stream.stream, + kv_cache->stage1_score, + kv_cache->pool_score, + this->num_key_value_heads, + num_tokens, + q_round, + k_round, + kv_cache->next_kv_length, + this->sink_window_size, + this->block_window_size, + out_len + ); + kv_cache->topk_func->prefill( + stream, + this->num_key_value_heads*num_tokens, + kv_cache->pool_score, + out_len + ); + topk_to_uint64_func( + stream.stream, + kv_cache->topk_func->topk_pos, + kv_cache->blockmask, + this->num_key_value_heads*num_tokens, + kv_cache->topk_func->top, + num_history_tokens+num_tokens + ); + blockmask = kv_cache->blockmask; + } + cuda_perf_stop_on_stream_f(M4_PREFILL_ATTN_STAGE1, stream.stream); + + cuda_perf_start_on_stream_f(M4_PREFILL_ATTN_STAGE2, stream.stream); + mha_fwd_kvcache( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + num_history_tokens+num_tokens, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + this->q_proj->output, + kv_cache->k_cache, + kv_cache->v_cache, + nullptr, + Mask(nullptr), + this->attn_output, + this->softmax_lse, + this->softmax_lse_accum, + this->oaccum, + rsqrtf(float(this->head_dim)), + true, + -1, + -1, + 0, + stream.stream, + blockmask, + blockmask ? this->block_window_size : 0 // TODO fix this condition + ); + cuda_perf_stop_on_stream_f(M4_PREFILL_ATTN_STAGE2, stream.stream); + cuda_perf_stop_on_stream_f(M4_PREFILL_ATTN_CORE, stream.stream); + + // flash attention and put output to attn_norm->output + this->o_proj->prefill(stream, num_tokens, this->attn_output); + + kv_cache->next_kv_length = kv_cache->next_kv_length + num_tokens; + } + + void decode(const Stream& stream, int32_t num_tokens, int32_t padded_length, T* input, T* prev_output, int32_t* position_ids, int32_t* cache_length, const Mask& mask, MiniCPM4KVCache<T>* kv_cache) { + this->attn_norm->prefill(stream, num_tokens, input, prev_output); + T *q = nullptr; +#ifdef DISABLE_MEMPOOL + this->q_proj->prefill(stream, num_tokens, this->attn_norm->output); + this->k_proj->prefill(stream, num_tokens, this->attn_norm->output); + this->v_proj->prefill(stream, num_tokens, this->attn_norm->output); + q = this->q_proj->output; + kv_cache->rotary_embedding->prefill(stream, num_tokens, this->num_attention_heads, this->num_key_value_heads, this->q_proj->output, this->k_proj->output, position_ids); + copy_to_kvcache(stream, num_tokens, this->k_proj->output, this->v_proj->output, kv_cache, cache_length); +#else + int merge_dim_out = (this->num_attention_heads + 2 * this->num_key_value_heads) * this->head_dim; + if (num_tokens > 1) { + linear<T>(stream, num_tokens, this->hidden_size, merge_dim_out, this->attn_norm->output, this->q_proj->weight, this->v_proj->output); + permute(stream, num_tokens, this->num_attention_heads * this->head_dim, this->num_key_value_heads * this->head_dim, this->v_proj->output, this->q_proj->output); + } else { + linear<T>(stream, num_tokens, this->hidden_size, merge_dim_out, this->attn_norm->output, this->q_proj->weight, this->q_proj->output); + } + q = this->q_proj->output; + T* k = q + num_tokens * this->num_attention_heads * this->head_dim; + T* v = k + num_tokens * this->num_key_value_heads * this->head_dim; + kv_cache->rotary_embedding->prefill(stream, num_tokens, this->num_attention_heads, this->num_key_value_heads, q, k, position_ids); + copy_to_kvcache(stream, num_tokens, k, v, kv_cache, cache_length); +#endif + + cuda_perf_start_on_stream_f(M4_DECODE_ATTN_CORE, stream.stream); + cuda_perf_start_on_stream_f(M4_DECODE_ATTN_STAGE1, stream.stream); + kv_cache->compress(stream); + + uint64_t *blockmask = nullptr; + if ((!apply_compress_lse && kv_cache->c1_len * kv_cache->c1_stride >= this->sparse_switch) || (apply_compress_lse && kv_cache->c2_len * kv_cache->c2_stride >= this->sparse_switch)) { + int q_round, k_round, out_len; + cuda_perf_start_on_stream_f(M4_DECODE_ATTN_STAGE1_CORE, stream.stream); + mha_fwd_stage1( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + kv_cache->c1_len, + apply_compress_lse ? kv_cache->c2_len : kv_cache->c1_len, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + this->q_proj->output, + kv_cache->c1_cache, + apply_compress_lse ? kv_cache->c2_cache : kv_cache->c1_cache, + nullptr, + kv_cache->stage1_score, + rsqrtf(float(this->head_dim)), + false, + -1, + -1, + 0, + stream.stream, + q_round, + k_round + ); + cuda_perf_stop_on_stream_f(M4_DECODE_ATTN_STAGE1_CORE, stream.stream); + maxpooling_func( + stream.stream, + kv_cache->stage1_score, + kv_cache->pool_score, + this->num_key_value_heads, + num_tokens, + q_round, + k_round, + kv_cache->next_kv_length, + this->sink_window_size, + this->block_window_size, + out_len + ); + kv_cache->topk_func->prefill( + stream, + this->num_key_value_heads*num_tokens, + kv_cache->pool_score, + out_len + ); + topk_to_uint64_func( + stream.stream, + kv_cache->topk_func->topk_pos, + kv_cache->blockmask, + this->num_key_value_heads*num_tokens, + kv_cache->topk_func->top, + padded_length + ); + blockmask = kv_cache->blockmask; + } + cuda_perf_stop_on_stream_f(M4_DECODE_ATTN_STAGE1, stream.stream); + + cuda_perf_start_on_stream_f(M4_DECODE_ATTN_STAGE2, stream.stream); + mha_fwd_kvcache( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + padded_length, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + q, + kv_cache->k_cache, + kv_cache->v_cache, + cache_length, + mask, + this->attn_output, + this->softmax_lse, + this->softmax_lse_accum, + this->oaccum, + rsqrtf(float(this->head_dim)), + true, + -1, + -1, + 0, + stream.stream, + blockmask, + blockmask ? this->block_window_size : 0 // TODO fix this condition + ); + cuda_perf_stop_on_stream_f(M4_DECODE_ATTN_STAGE2, stream.stream); + cuda_perf_stop_on_stream_f(M4_DECODE_ATTN_CORE, stream.stream); + + // flash attention and put output to attn_norm->output + this->o_proj->prefill(stream, num_tokens, this->attn_output); + + kv_cache->next_kv_length = kv_cache->next_kv_length + 1; + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/minicpm4/minicpm4_eagle.cuh b/examples/CPM.cu/src/model/minicpm4/minicpm4_eagle.cuh new file mode 100644 index 00000000..1b3bf5a6 --- /dev/null +++ b/examples/CPM.cu/src/model/minicpm4/minicpm4_eagle.cuh @@ -0,0 +1,419 @@ +#pragma once +#include <type_traits> +#include "../tree_drafter.cuh" +#include "../eagle.cuh" +#include "minicpm4_model.cuh" +#include "minicpm4_w4a16_gptq_marlin_model.cuh" +#include "../w4a16_gptq_marlin/w4a16_gptq_marlin_layer.cuh" +#include "../w4a16_gptq_marlin/w4a16_gptq_marlin_linear.cuh" + +template<typename T, class ModelType, class LayerType, class Fc1Type, class Fc2Type> +struct MiniCPM4EagleImpl : Model { + int num_layers; + int num_iter; + int topk_per_iter; + int tree_size; + int total_tried; + float residual_scale; + bool use_input_norm; + bool use_attn_norm; + + ModelType* model = nullptr; + KVCacheManager<T>* kv_caches = nullptr; + std::vector<LayerType*> layers; + Fc1Type *fc1 = nullptr; + Fc2Type *fc2 = nullptr; + Linear<T>* lm_head = nullptr; + int32_t* token_id_remap = nullptr; + RMSNorm<T> *input_norm1 = nullptr; + RMSNorm<T> *input_norm2 = nullptr; + functions::TopK<T>* topk_func = nullptr; + functions::TopK<T>* topk_func_2 = nullptr; + + T *prev_hidden_state, *prev_embed; + int num_prev, num_history_tokens; + int32_t *eagle_position_ids, *eagle_cache_length; + int *eagle_original_length, eagle_padded_length; + uint64_t *eagle_mask_2d, *tmp_mask_2d; + T* eagle_logits; + T* tried_history_val; int32_t* tried_history_pos; + int32_t* tried_history_parent = nullptr; + bool is_first_draft; + int frspec_vocab_size; + + int32_t *h_best, *d_best; + + T* tmp_kvcache; + + T* a_tmp = nullptr; + float* c_tmp = nullptr; + + MiniCPM4EagleImpl( + ModelType* model, + int num_layers, + int num_iter, + int topk_per_iter, + int tree_size, + int group_size = 128, + int eagle_window_size = 0, + int frspec_vocab_size = 0, + float residual_scale = 1.0f, + bool use_input_norm = true, + bool use_attn_norm = true + ) { + this->model = model; + this->num_layers = num_layers; + this->num_iter = num_iter; + this->topk_per_iter = topk_per_iter; + this->tree_size = tree_size; + assert(this->tree_size <= 64); + this->total_tried = topk_per_iter * topk_per_iter * (num_iter - 1) + topk_per_iter; + this->frspec_vocab_size = frspec_vocab_size > 0 ? frspec_vocab_size : this->model->vocab_size; + this->residual_scale = residual_scale; + this->use_input_norm = use_input_norm; + this->use_attn_norm = use_attn_norm; + + kv_caches = new KVCacheManager<T>(num_layers, this->model->num_key_value_heads, this->model->head_dim); + if constexpr (std::is_same_v<Fc1Type, W4A16GPTQMarlinLinear<T, true, true>>) { + fc1 = new W4A16GPTQMarlinLinear<T, true, true>(this->model->hidden_size, this->model->hidden_size, group_size); + fc2 = new W4A16GPTQMarlinLinear<T>(this->model->hidden_size, this->model->hidden_size, group_size); + } else { + fc1 = new Linear<T, true, true>(this->model->hidden_size, this->model->hidden_size); + fc2 = new Linear<T>(this->model->hidden_size, this->model->hidden_size); + } + if (use_input_norm) { + input_norm1 = new RMSNorm<T>(this->model->hidden_size, this->model->rms_norm_eps); + input_norm2 = new RMSNorm<T>(this->model->hidden_size, this->model->rms_norm_eps); + } + for (int i = 0; i < num_layers; i++) { + if constexpr (std::is_same_v<LayerType, W4A16GPTQMarlinLayer<T>>) { + layers.push_back(new W4A16GPTQMarlinLayer<T>(this->model->hidden_size, this->model->intermediate_size, this->model->num_attention_heads, this->model->num_key_value_heads, this->model->head_dim, this->model->rms_norm_eps, group_size, this->residual_scale, eagle_window_size)); + } else { + layers.push_back(new Layer<T>(this->model->hidden_size, this->model->intermediate_size, this->model->num_attention_heads, this->model->num_key_value_heads, this->model->head_dim, this->model->rms_norm_eps, this->residual_scale, eagle_window_size)); + } + } + if (this->frspec_vocab_size != this->model->vocab_size) { + lm_head = new Linear<T>(this->model->hidden_size, this->frspec_vocab_size); + } else { + lm_head = this->model->lm_head; + } + + assert(this->topk_per_iter <= this->tree_size-1); + + topk_func = new functions::TopK<T>(this->frspec_vocab_size, this->topk_per_iter); + topk_func_2 = new functions::TopK<T>(this->total_tried, this->tree_size-1); + } + + void init_weight_ptr(Memory* memory) { + fc1->init_weight_ptr(memory); + fc2->init_weight_ptr(memory); + if (use_input_norm) { + input_norm1->init_weight_ptr(memory); + input_norm2->init_weight_ptr(memory); + } + for (int i = 0; i < num_layers; i++) { + layers[i]->init_weight_ptr(memory); + } + if (this->frspec_vocab_size != this->model->vocab_size) { + lm_head->init_weight_ptr(memory); + } + if (!use_attn_norm) { + layers[0]->attn->attn_norm = new Skip<T>(this->model->hidden_size); + } + kv_caches->rotary_embedding = this->model->kv_caches->rotary_embedding; + token_id_remap = (int32_t*)memory->allocate_for_model(this->frspec_vocab_size * sizeof(int32_t)); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + if constexpr (std::is_same_v<Fc1Type, W4A16GPTQMarlinLinear<T, true, true>>) { + offset = memory->allocate((void**)&this->a_tmp, offset, 2 * num_tokens * this->model->hidden_size * sizeof(T)); + int reduce_max_m = marlin::determine_reduce_max_m(num_tokens, marlin::max_par); + int reduce_n = 2 * this->model->hidden_size; + offset = memory->allocate((void**)&this->c_tmp, offset, reduce_max_m * reduce_n * sizeof(float)); + } + offset = fc1->init_output_ptr(memory, num_tokens, offset); + offset = fc2->init_output_ptr(memory, num_tokens, offset); + if (use_input_norm) { + offset = input_norm1->init_output_ptr(memory, num_tokens, offset); + offset = input_norm2->init_output_ptr(memory, num_tokens, offset); + } + int64_t layer_end = 0; + for (int i = 0; i < num_layers; i++) { + layer_end = layers[i]->init_output_ptr(memory, num_tokens, offset); + } + offset = layer_end; + if (this->frspec_vocab_size != this->model->vocab_size) { + offset = lm_head->init_output_ptr(memory, 64, offset); + } + offset = memory->allocate((void**)&eagle_logits, offset, this->topk_per_iter * this->frspec_vocab_size * sizeof(T)); + offset = memory->allocate((void**)&eagle_mask_2d, offset, this->topk_per_iter * sizeof(uint64_t)); + offset = memory->allocate((void**)&tmp_mask_2d, offset, this->topk_per_iter * sizeof(uint64_t)); + offset = memory->allocate((void**)&tried_history_val, offset, this->total_tried * sizeof(T)); + offset = memory->allocate((void**)&tried_history_pos, offset, this->total_tried * sizeof(int32_t)); + if (this->num_iter > 1) { + offset = memory->allocate((void**)&tried_history_parent, offset, this->topk_per_iter * (this->num_iter - 1) * sizeof(int32_t)); + } + cudaMallocHost(&eagle_original_length, sizeof(int32_t)); + + offset = topk_func->init_output_ptr(memory, this->topk_per_iter, offset); + offset = topk_func_2->init_output_ptr(memory, 1*16, offset); + + offset = memory->allocate((void**)&prev_hidden_state, offset, num_tokens * this->model->hidden_size * sizeof(T)); + offset = memory->allocate((void**)&prev_embed, offset, num_tokens * this->model->hidden_size * sizeof(T)); + offset = memory->allocate((void**)&eagle_position_ids, offset, num_tokens * sizeof(int32_t)); + offset = memory->allocate((void**)&eagle_cache_length, offset, sizeof(int32_t)); + + offset = memory->allocate((void**)&d_best, offset, 2 * sizeof(int32_t)); + cudaMallocHost(&h_best, 2 * sizeof(int32_t)); + offset = memory->allocate((void**)&tmp_kvcache, offset, 64 * this->model->kv_caches->num_hidden_layers * 2 * this->model->kv_caches->dim * sizeof(T)); + return offset; + } + + int init_storage() { + this->model->init_weight_ptr(this->model->memory); + this->init_weight_ptr(this->model->memory); + int64_t offset = this->model->init_output_ptr(this->model->memory, this->model->chunk_length, this->model->memory->model_offset); + int64_t kv_cache_offset = this->init_output_ptr(this->model->memory, this->model->chunk_length, offset); + float ratio = float(this->model->num_hidden_layers) / (this->model->num_hidden_layers + this->num_layers); + if constexpr (std::is_same_v<ModelType, MiniCPM4Impl<T>> || std::is_same_v<ModelType, MiniCPM4W4A16GPTQMarlinModelImpl<T>>) { + kv_cache_offset = this->model->kv_caches->init_output_ptr(this->model->memory, this->model->chunk_length, kv_cache_offset, ratio); + } else { + kv_cache_offset = this->model->kv_caches->init_output_ptr(this->model->memory, kv_cache_offset, ratio); + } + kv_caches->init_output_ptr(this->model->memory, kv_cache_offset); + return min(kv_caches->budget, this->model->kv_caches->budget); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.substr(0, 5) == "eagle") { + if (name.substr(0, 9) == "eagle.fc1") { + fc1->load_to_storage(name, ptr); + } else if (name.substr(0, 9) == "eagle.fc2") { + fc2->load_to_storage(name, ptr); + } else if (name.substr(0, 20) == "eagle.token_id_remap") { + cudaMemcpy((void*)token_id_remap, ptr, this->frspec_vocab_size * sizeof(int32_t), cudaMemcpyHostToDevice); + } else if (name.find("eagle.input_norm1") != std::string::npos) { + if (!use_input_norm) throw std::invalid_argument("norm is not used, but input_norm1 is found"); + input_norm1->load_to_storage(name, ptr); + } else if (name.find("eagle.input_norm2") != std::string::npos) { + if (!use_input_norm) throw std::invalid_argument("norm is not used, but input_norm2 is found"); + input_norm2->load_to_storage(name, ptr); + } else if (name.find("eagle.rotary_emb") != std::string::npos) { + kv_caches->rotary_embedding->load_to_storage(name, ptr); + } else { + std::regex layer_regex("eagle\\.layers\\.(\\d+)\\.(.*)"); + std::smatch matches; + if (std::regex_search(name, matches, layer_regex)) { + int layer_idx = std::stoi(matches[1]); + layers[layer_idx]->load_to_storage(matches[2], ptr); + } else { + throw std::invalid_argument("Unsupported name (layer_idx not found): " + name); + } + } + } else { + this->model->load_to_storage(name, ptr); + if (name.substr(0, 7) == "lm_head") { + if (this->frspec_vocab_size != this->model->vocab_size) { + remap_copy(calc_stream, this->model->lm_head->weight, this->lm_head->weight, this->model->hidden_size, this->frspec_vocab_size, this->token_id_remap); + } + } + } + } + + void eagle_prefill(int num_history_tokens) { + cudaMemcpy(this->prev_embed + (num_prev - 1) * this->model->hidden_size, this->model->embedding->output, this->model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + if (use_input_norm) { + this->input_norm1->prefill(calc_stream, num_prev, this->prev_embed, nullptr); + this->input_norm2->prefill(calc_stream, num_prev, this->prev_hidden_state, nullptr); + if constexpr (std::is_same_v<Fc1Type, W4A16GPTQMarlinLinear<T, true, true>>) { + this->fc1->prefill(calc_stream, num_prev, this->input_norm1->output, this->a_tmp, this->c_tmp); + this->fc2->prefill(calc_stream, num_prev, this->input_norm2->output, this->a_tmp, this->c_tmp); + } else { + this->fc1->prefill(calc_stream, num_prev, this->input_norm1->output); + this->fc2->prefill(calc_stream, num_prev, this->input_norm2->output); + } + } else { + if constexpr (std::is_same_v<Fc1Type, W4A16GPTQMarlinLinear<T, true, true>>) { + this->fc1->prefill(calc_stream, num_prev, this->prev_embed, this->a_tmp, this->c_tmp); + this->fc2->prefill(calc_stream, num_prev, this->prev_hidden_state, this->a_tmp, this->c_tmp); + } else { + this->fc1->prefill(calc_stream, num_prev, this->prev_embed); + this->fc2->prefill(calc_stream, num_prev, this->prev_hidden_state); + } + } + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc1->output, this->fc2->output, this->fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < num_layers; i++) { + this->layers[i]->prefill(num_prev, num_history_tokens, this->fc2->output, layer_output, this->eagle_position_ids, this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_scale(calc_stream, num_prev, this->model->hidden_size, layer_output, this->residual_scale); + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc2->output, layer_output, this->fc2->output); + } + + void eagle_decode(int32_t* cache_length) { + if (use_input_norm) { + this->input_norm1->prefill(calc_stream, num_prev, this->prev_embed, nullptr); + this->input_norm2->prefill(calc_stream, num_prev, this->prev_hidden_state, nullptr); + if constexpr (std::is_same_v<Fc1Type, W4A16GPTQMarlinLinear<T, true, true>>) { + this->fc1->prefill(calc_stream, num_prev, this->input_norm1->output, this->a_tmp, this->c_tmp); + this->fc2->prefill(calc_stream, num_prev, this->input_norm2->output, this->a_tmp, this->c_tmp); + } else { + this->fc1->prefill(calc_stream, num_prev, this->input_norm1->output); + this->fc2->prefill(calc_stream, num_prev, this->input_norm2->output); + } + } else { + if constexpr (std::is_same_v<Fc1Type, W4A16GPTQMarlinLinear<T, true, true>>) { + this->fc1->prefill(calc_stream, num_prev, this->prev_embed, this->a_tmp, this->c_tmp); + this->fc2->prefill(calc_stream, num_prev, this->prev_hidden_state, this->a_tmp, this->c_tmp); + } else { + this->fc1->prefill(calc_stream, num_prev, this->prev_embed); + this->fc2->prefill(calc_stream, num_prev, this->prev_hidden_state); + } + } + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc1->output, this->fc2->output, this->fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < num_layers; i++) { + this->layers[i]->decode(num_prev, this->eagle_padded_length, this->fc2->output, layer_output, this->eagle_position_ids, cache_length, Mask(nullptr), this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_scale(calc_stream, num_prev, this->model->hidden_size, layer_output, this->residual_scale); + elementwise_add(calc_stream, num_prev, this->model->hidden_size, this->fc2->output, layer_output, this->fc2->output); + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, int32_t* input, int32_t* position_ids, void* output) { + this->model->embedding->prefill(calc_stream, num_tokens, input); + if (num_history_tokens > 0) { + this->eagle_prefill(num_history_tokens); + } + + cudaMemcpy(this->prev_embed, this->model->embedding->output + this->model->hidden_size, (num_tokens - 1) * this->model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + this->model->prefill_embed(num_tokens, num_history_tokens, this->model->embedding->output, position_ids, output); + this->prev_hidden_state = this->model->norm->output; + cudaMemcpy(this->eagle_position_ids, position_ids, num_tokens * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->num_prev = num_tokens; + + this->num_history_tokens = num_history_tokens; + this->is_first_draft = true; + } + + void decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + this->model->decode(num_tokens, padded_length, input, position_ids, cache_length, mask_2d, output); + } + + void draft(int32_t* tree_draft_ids, int32_t* tree_position_ids, int32_t* cache_length, uint64_t* tree_attn_mask, int32_t* tree_parent) { + cudaMemcpy(this->eagle_original_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->eagle_padded_length = (this->eagle_original_length[0] + 256 - 1) / 128 * 128; + + + if (this->is_first_draft) { + this->model->embedding->prefill(calc_stream, 1, tree_draft_ids); + this->eagle_prefill(this->num_history_tokens); + } else { + this->eagle_decode(cache_length); + } + cudaMemcpy(this->eagle_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->eagle_position_ids, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + repeat(calc_stream, topk_per_iter, 1, 0, this->eagle_position_ids); + + { // d = 0 + lm_head->prefill(calc_stream, 1, this->fc2->output + (num_prev - 1) * this->model->hidden_size, this->eagle_logits); + log_softmax(calc_stream, 1, this->frspec_vocab_size, this->eagle_logits); + this->topk_func->prefill(calc_stream, 1, this->eagle_logits); + cudaMemcpy(this->tried_history_val, this->topk_func->topk_val, topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->tried_history_pos, this->topk_func->topk_pos, topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + if (this->frspec_vocab_size != this->model->vocab_size) { + remap(calc_stream, topk_per_iter, this->topk_func->topk_pos, this->topk_func_2->topk_pos, this->token_id_remap); + } else { + cudaMemcpy(this->topk_func_2->topk_pos, this->topk_func->topk_pos, topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + } + cudaMemcpy(this->topk_func_2->topk_val, this->topk_func->topk_val, topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + repeat(calc_stream, topk_per_iter, this->model->hidden_size, num_prev-1, this->fc2->output, this->fc1->output); + init_tree(calc_stream, topk_per_iter, this->eagle_mask_2d); + } + for (int d = 1; d < this->num_iter; ++d) { + add(calc_stream, 1, this->eagle_cache_length, topk_per_iter); + this->model->embedding->prefill(calc_stream, topk_per_iter, this->topk_func_2->topk_pos); + if (use_input_norm) { + this->input_norm1->prefill(calc_stream, topk_per_iter, this->model->embedding->output, nullptr); + this->input_norm2->prefill(calc_stream, topk_per_iter, this->fc1->output, nullptr); + if constexpr (std::is_same_v<Fc1Type, W4A16GPTQMarlinLinear<T, true, true>>) { + this->fc1->prefill(calc_stream, topk_per_iter, this->input_norm1->output, this->a_tmp, this->c_tmp); + this->fc2->prefill(calc_stream, topk_per_iter, this->input_norm2->output, this->a_tmp, this->c_tmp); + } else { + this->fc1->prefill(calc_stream, topk_per_iter, this->input_norm1->output); + this->fc2->prefill(calc_stream, topk_per_iter, this->input_norm2->output); + } + } else { + if constexpr (std::is_same_v<Fc1Type, W4A16GPTQMarlinLinear<T, true, true>>) { + this->fc2->prefill(calc_stream, topk_per_iter, this->fc1->output, this->a_tmp, this->c_tmp); + this->fc1->prefill(calc_stream, topk_per_iter, this->model->embedding->output, this->a_tmp, this->c_tmp); + } else { + this->fc2->prefill(calc_stream, topk_per_iter, this->fc1->output); + this->fc1->prefill(calc_stream, topk_per_iter, this->model->embedding->output); + } + } + elementwise_add(calc_stream, topk_per_iter, this->model->hidden_size, this->fc1->output, this->fc2->output, this->fc2->output); + T* layer_output = nullptr; + for (int i = 0; i < num_layers; i++) { + this->layers[i]->decode(topk_per_iter, this->eagle_padded_length, this->fc2->output, layer_output, this->eagle_position_ids, this->eagle_cache_length, Mask(eagle_mask_2d, topk_per_iter, topk_per_iter * d), this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_scale(calc_stream, topk_per_iter, this->model->hidden_size, layer_output, this->residual_scale); + elementwise_add(calc_stream, topk_per_iter, this->model->hidden_size, this->fc2->output, layer_output, this->fc2->output); + add(calc_stream, topk_per_iter, this->eagle_position_ids, 1); + + lm_head->prefill(calc_stream, topk_per_iter, this->fc2->output, this->eagle_logits); + log_softmax(calc_stream, topk_per_iter, this->frspec_vocab_size, this->eagle_logits); + this->topk_func->prefill(calc_stream, topk_per_iter, this->eagle_logits); + cumsum(calc_stream, topk_per_iter, topk_per_iter, this->topk_func->topk_val, this->topk_func_2->topk_val); + cudaMemcpy(this->tried_history_val + topk_per_iter + (d - 1) * topk_per_iter * topk_per_iter, this->topk_func->topk_val, topk_per_iter * topk_per_iter * sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->tried_history_pos + topk_per_iter + (d - 1) * topk_per_iter * topk_per_iter, this->topk_func->topk_pos, topk_per_iter * topk_per_iter * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->topk_func_2->prefill(calc_stream, 1, this->topk_func->topk_val, topk_per_iter * topk_per_iter, topk_per_iter); + + cudaMemcpy(this->tmp_mask_2d, this->eagle_mask_2d, topk_per_iter * sizeof(uint64_t), cudaMemcpyDeviceToDevice); + set_parent(calc_stream, topk_per_iter, this->tried_history_parent + (d - 1) * topk_per_iter, this->topk_func_2->topk_pos, topk_per_iter + (d - 1) * topk_per_iter * topk_per_iter); + update_tree(calc_stream, topk_per_iter, topk_per_iter * d, this->eagle_mask_2d, this->tmp_mask_2d, this->topk_func_2->topk_pos); + remap_hidden(calc_stream, topk_per_iter, this->model->hidden_size, this->topk_func_2->topk_pos, this->fc2->output, this->fc1->output, topk_per_iter); + if (this->frspec_vocab_size != this->model->vocab_size) { + remap_id_fr(calc_stream, topk_per_iter, this->topk_func_2->topk_pos, this->topk_func->topk_pos, this->token_id_remap); + } else { + remap_id(calc_stream, topk_per_iter, this->topk_func_2->topk_pos, this->topk_func->topk_pos); + } + } + + this->topk_func_2->prefill(calc_stream, 1, this->tried_history_val); + + // build tree + build_dynamic_tree(calc_stream, this->tree_size, this->eagle_original_length[0], this->topk_per_iter, this->tried_history_parent, this->topk_func_2->topk_pos, tree_position_ids, tree_attn_mask, tree_parent); + if (this->frspec_vocab_size != this->model->vocab_size) { + remap_id_fr(calc_stream, this->tree_size-1, this->topk_func_2->topk_pos, this->tried_history_pos, this->token_id_remap, tree_draft_ids + 1); + } else { + remap_id(calc_stream, this->tree_size-1, this->topk_func_2->topk_pos, this->tried_history_pos, tree_draft_ids + 1); + } + + this->is_first_draft = false; + } + + int verify(int32_t num_tokens, int32_t* pred, int32_t* gt, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, int32_t* tree_parent) { + verify_draft(calc_stream, num_tokens, pred, gt, position_ids, cache_length, mask_2d, tree_parent, this->d_best); + cudaMemcpyAsync(this->h_best, this->d_best, 2 * sizeof(int32_t), cudaMemcpyDeviceToHost, calc_stream.stream); + cudaStreamSynchronize(calc_stream.stream); + + this->num_prev = h_best[0]; + remap_hidden(calc_stream, this->num_prev, this->model->hidden_size, pred, this->model->norm->output, this->prev_hidden_state); + + fix_kv_cache(calc_stream, h_best[0], this->model->kv_caches->num_hidden_layers * 2, this->model->kv_caches->dim, pred, gt, cache_length, this->model->kv_caches->d_flat_caches, this->tmp_kvcache); + + this->model->embedding->prefill(calc_stream, this->num_prev, pred); + cudaMemcpy(this->prev_embed, this->model->embedding->output, this->num_prev * this->model->hidden_size * sizeof(T), cudaMemcpyDeviceToDevice); + + make_arange(calc_stream, this->num_prev, cache_length, this->eagle_position_ids); + + if constexpr (std::is_same_v<ModelType, MiniCPM4Impl<T>> || std::is_same_v<ModelType, MiniCPM4W4A16GPTQMarlinModelImpl<T>>) { + this->model->kv_caches->add_length(h_best[0] - 1); + } + + return h_best[0]; + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/minicpm4/minicpm4_kvcache.cuh b/examples/CPM.cu/src/model/minicpm4/minicpm4_kvcache.cuh new file mode 100644 index 00000000..686d1a47 --- /dev/null +++ b/examples/CPM.cu/src/model/minicpm4/minicpm4_kvcache.cuh @@ -0,0 +1,316 @@ +#pragma once +#include "../kvcache.cuh" +#include "../topk.cuh" + +namespace { +template <typename T> +__global__ void meanpooling_16_kernel(int left, int dim, T* compressed, const T* k_cache) { + __shared__ T s[32][33]; + + int idx = blockIdx.x + left; + int orig_left = idx * 16; + T* c = compressed + idx * dim; + const T* k = k_cache + orig_left * dim; + int i = threadIdx.x / 32; + int j = threadIdx.x % 32; + + for (int offset = 0; offset < dim; offset += 32) { + s[i][j] = k[i * dim + offset + j]; + __syncthreads(); + float v = s[j][i]; + v += __shfl_down_sync(0xffffffff, v, 16); + v += __shfl_down_sync(0xffffffff, v, 8); + v += __shfl_down_sync(0xffffffff, v, 4); + v += __shfl_down_sync(0xffffffff, v, 2); + v += __shfl_down_sync(0xffffffff, v, 1); + if (j == 0) { + c[offset + i] = T(v / 32.0f); + } + } +} + +template <typename T> +__global__ void meanpooling_64_kernel(int left, int dim, T* compressed, const T* k_cache) { + __shared__ T s[32][33]; + + int idx = blockIdx.x + left; + int orig_left = idx * 64; + T* c = compressed + idx * dim; + const T* k = k_cache + orig_left * dim; + int i = threadIdx.x / 32; + int j = threadIdx.x % 32; + + for (int offset = 0; offset < dim; offset += 32) { + float v_sum[32] = {0}; + for (int offset_row = 0; offset_row < 128; offset_row += 32) { + s[i][j] = k[(i + offset_row) * dim + offset + j]; + __syncthreads(); + float v = s[j][i]; + v += __shfl_down_sync(0xffffffff, v, 16); + v += __shfl_down_sync(0xffffffff, v, 8); + v += __shfl_down_sync(0xffffffff, v, 4); + v += __shfl_down_sync(0xffffffff, v, 2); + v += __shfl_down_sync(0xffffffff, v, 1); + if (j == 0) { + v_sum[i] += v; + } + } + if (j == 0) { + c[offset + i] = T(v_sum[i] / 128.0f); + } + } +} + +template <typename T> +__global__ void maxpooling_kernel( + const T* input, + T* output, + int num_heads, + int q_len, + int q_round, + int k_len, + int out_len, + int cache_len, + int init_blocks, + int local_blocks, + int kernel_size, + int stride, + int padding, + int block_size +) { + int bidh = blockIdx.y; + int bidq = blockIdx.x; + const T* in = input + bidh * q_round * k_len + bidq * k_len; + T* out = output + bidh * q_len * out_len + bidq * out_len; + int q_block = (bidq + cache_len) / block_size; + + for (int k = threadIdx.x; k < out_len; k += blockDim.x) { + int start = k * stride - padding; + int end = start + kernel_size; + start = max(start, 0); + end = min(end, k_len); + + T max_val; + if (k < init_blocks) { + max_val = TypeTraits<T>::inf(); + } else if (q_block - local_blocks < k) { + max_val = -TypeTraits<T>::inf(); + } else { + max_val = in[start]; + for (int i = start + 1; i < end; i++) { + if (in[i] > max_val) { + max_val = in[i]; + } + } + } + out[k] = max_val; + } +} + +__global__ void kernel_topk_to_uint64( + const int* topk_idx, + uint64_t* result, + int batch_size, + int k, + int k_blocks, + int n_uint64_per_row +) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + int col = blockIdx.y; + + if (row >= batch_size || col >= n_uint64_per_row) return; + + int out_idx = row * n_uint64_per_row + col; + + int bit_start = col * 64; + + uint64_t packed_value = 0; + + for (int i = 0; i < k; i++) { + int idx_offset = row * k + i; + int idx = topk_idx[idx_offset]; + + if (idx == -1) continue; + + if (idx >= bit_start && idx < bit_start + 64) { + int local_bit = idx - bit_start; + packed_value |= (1ULL << local_bit); + } + } + + result[out_idx] = packed_value; +} + +template <typename T> +void meanpooling(const Stream& stream, int left, int right, int dim, T* compressed, const T* k_cache, int stride) { + if (left == right) return; + if (stride == 16) { + meanpooling_16_kernel<<<right-left, 1024, 0, stream.stream>>>(left, dim, compressed, k_cache); + } else if (stride == 64) { + meanpooling_64_kernel<<<right-left, 1024, 0, stream.stream>>>(left, dim, compressed, k_cache); + } else { + throw std::runtime_error("Unsupported meanpooling stride: " + std::to_string(stride)); + } +} + +template <typename T> +void maxpooling_func( + cudaStream_t stream, + const T* input, + T* output, + int num_heads, + int q_len, + int q_round, + int k_len, + int cache_len, + int init_blocks, + int local_blocks, + int &out_len, + int kernel_size=5, + int stride=4, + int padding=1, + int block_size=64 +) { + out_len = (cache_len + block_size - 1) / block_size; + maxpooling_kernel<<<dim3(q_len, num_heads), 256, 0, stream>>>( + input, output, num_heads, q_len, q_round, k_len, out_len, cache_len, init_blocks, local_blocks, kernel_size, stride, padding, block_size + ); +} + +void topk_to_uint64_func( // TODO not necessary now, since topk is small + cudaStream_t stream, + const int* topk_idx, // Input topk indices + uint64_t* result, // Output uint64 array + int batch_size, // num_heads x q_len + int topk, // Number of topk values per row + int k_len, // k_len + int block_size = 64 +) { + int k_blocks = (k_len + block_size - 1) / block_size; + int n_uint64_per_row = (k_blocks + block_size - 1) / block_size; + + const int threads_per_block = 256; + const int blocks_per_row = (batch_size + threads_per_block - 1) / threads_per_block; + + dim3 grid(blocks_per_row, n_uint64_per_row); + dim3 block(threads_per_block, 1); + + kernel_topk_to_uint64<<<grid, block, 0, stream>>>( + topk_idx, result, batch_size, topk, k_blocks, n_uint64_per_row + ); +} +} + +template <typename T> +struct MiniCPM4KVCache : KVCache<T> { + uint64_t *blockmask; + T *stage1_score, *pool_score; + functions::TopK<T> *topk_func; + T *c1_cache, *c2_cache; + int c1_stride, c2_stride; + int c1_len, c2_len; + int prev_kv_length; + int next_kv_length; + bool apply_compress_lse; + + MiniCPM4KVCache(int dim, RotaryEmbedding<T> *rotary_embedding, uint64_t *blockmask, T* stage1_score, T* pool_score, functions::TopK<T> *topk_func, bool apply_compress_lse) : KVCache<T>(dim, rotary_embedding) { + this->blockmask = blockmask; + this->stage1_score = stage1_score; + this->pool_score = pool_score; + this->topk_func = topk_func; + c1_stride = 16; + c2_stride = 64; + assert(this->dim % 32 == 0); + this->apply_compress_lse = apply_compress_lse; + } + + void init() { + this->prev_kv_length = 0; + this->next_kv_length = 0; + this->c1_len = 0; + this->c2_len = 0; + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int32_t num_c1, int32_t num_c2, int64_t offset) { + offset = KVCache<T>::init_output_ptr(memory, num_tokens, offset); + offset = memory->allocate((void**)&this->c1_cache, offset, num_c1 * this->dim * sizeof(T)); + if (apply_compress_lse) { + offset = memory->allocate((void**)&this->c2_cache, offset, num_c2 * this->dim * sizeof(T)); + } + return offset; + } + + void compress(const Stream& stream) { + int prev_pos; + prev_pos = c1_len; + c1_len = max((this->next_kv_length - c1_stride) / c1_stride, 0); + meanpooling(stream, prev_pos, c1_len, this->dim, this->c1_cache, this->k_cache, c1_stride); + if (apply_compress_lse) { + prev_pos = c2_len; + c2_len = max((this->next_kv_length - c2_stride) / c2_stride, 0); + meanpooling(stream, prev_pos, c2_len, this->dim, this->c2_cache, this->k_cache, c2_stride); + } + this->prev_kv_length = this->next_kv_length; + } +}; + +template <typename T> +struct MiniCPM4KVCacheManager { + int num_hidden_layers; + int dim; + int budget; + int budget_c1, budget_c2; + std::vector<MiniCPM4KVCache<T>*> caches; + T **h_flat_caches, **d_flat_caches; + RotaryEmbedding<T> *rotary_embedding; + uint64_t *blockmask; + T* stage1_score, *pool_score; + functions::TopK<T> *topk_func; + bool apply_compress_lse; + + MiniCPM4KVCacheManager(int num_hidden_layers, int num_key_value_heads, int head_dim, int sparse_topk_k, bool apply_compress_lse) { + this->num_hidden_layers = num_hidden_layers; + this->dim = num_key_value_heads * head_dim; + this->rotary_embedding = new RotaryEmbedding<T>(head_dim); + this->topk_func = new functions::TopK<T>(4096, sparse_topk_k); // 256k/64 + this->apply_compress_lse = apply_compress_lse; + } + + void init_weight_ptr(Memory* memory) { + this->rotary_embedding->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int num_tokens, int64_t offset, float ratio=1.0) { + // 2 = num_heads + offset = memory->allocate((void**)&this->blockmask, offset, 2 * num_tokens * 64 * sizeof(uint64_t)); // 256k/64/64 + offset = memory->allocate((void**)&this->stage1_score, offset, 2 * num_tokens * 16384 * sizeof(T)); // 256k/16 + offset = memory->allocate((void**)&this->pool_score, offset, 2 * num_tokens * 4096 * sizeof(T)); // 256k/64 + offset = topk_func->init_output_ptr(memory, 2 * num_tokens, offset); + + offset = memory->allocate((void**)&this->d_flat_caches, offset, num_hidden_layers * 2 * sizeof(T*)); + + budget = int64_t(memory->get_remaining_memory(offset) * ratio * 0.999) / (this->num_hidden_layers * 2 * this->dim * sizeof(T)) - 1; + for (int i = 0; i < this->num_hidden_layers; i++) { + caches.push_back(new MiniCPM4KVCache<T>(this->dim, this->rotary_embedding, this->blockmask, stage1_score, pool_score, topk_func, apply_compress_lse)); + } + budget_c2 = (int)(budget / 69.0); // 1 + 4 + 64 + budget_c1 = budget_c2 * 4; + budget = budget_c1 * 16; + for (int i = 0; i < this->num_hidden_layers; i++) { + offset = caches[i]->init_output_ptr(memory, budget, budget_c1, budget_c2, offset); + } + this->h_flat_caches = new T*[num_hidden_layers * 2]; + for (int i = 0; i < num_hidden_layers; i++) { + this->h_flat_caches[i * 2] = caches[i]->k_cache; + this->h_flat_caches[i * 2 + 1] = caches[i]->v_cache; + } + cudaMemcpy(this->d_flat_caches, this->h_flat_caches, num_hidden_layers * 2 * sizeof(T*), cudaMemcpyHostToDevice); + return offset; + } + + void add_length(int length) { + for (int i = 0; i < this->num_hidden_layers; i++) { + caches[i]->next_kv_length += length; + } + } +}; diff --git a/examples/CPM.cu/src/model/minicpm4/minicpm4_layer.cuh b/examples/CPM.cu/src/model/minicpm4/minicpm4_layer.cuh new file mode 100644 index 00000000..9e6bea74 --- /dev/null +++ b/examples/CPM.cu/src/model/minicpm4/minicpm4_layer.cuh @@ -0,0 +1,85 @@ +#pragma once +#include "../layer.cuh" +#include "minicpm4_attn.cuh" + +template <typename T> +struct MiniCPM4Layer { + MiniCPM4Attention<T> *attn; + FFN<T> *ffn; + T* output; + int hidden_size; + float residual_scale; + + MiniCPM4Layer(int hidden_size, int intermediate_size, int num_attention_heads, int num_key_value_heads, int head_dim, float rms_norm_eps, float residual_scale = 1.0, int sink_window_size = 1, int block_window_size = 32, int sparse_switch = 8192, bool apply_compress_lse = false) { + this->attn = new MiniCPM4Attention<T>(hidden_size, num_attention_heads, num_key_value_heads, head_dim, rms_norm_eps, sink_window_size, block_window_size, sparse_switch, apply_compress_lse); + this->ffn = new GatedFFN<T>(hidden_size, intermediate_size, rms_norm_eps); + this->hidden_size = hidden_size; + this->residual_scale = residual_scale; + } + + void init_weight_ptr(Memory* memory) { + this->attn->init_weight_ptr(memory); + this->ffn->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t attn_end = this->attn->init_output_ptr(memory, num_tokens, offset); + int64_t ffn_end = this->ffn->init_output_ptr(memory, num_tokens, offset); + this->output = this->ffn->output; + return std::max(attn_end, ffn_end); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("attn") != std::string::npos || name.find("input_layernorm") != std::string::npos) { + this->attn->load_to_storage(name, ptr); + } else if (name.find("mlp") != std::string::npos || name.find("post_attention_layernorm") != std::string::npos) { + this->ffn->load_to_storage(name, ptr); + } else { + throw std::invalid_argument("Unsupported name " + name); + } + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, T* input, T* prev_output, int32_t* position_ids, KVCache<T>* kv_cache, T* prev_layer_states=nullptr) { + if (prev_output != nullptr) { + elementwise_scale(calc_stream, num_tokens, this->hidden_size, prev_output, this->residual_scale); + } + cuda_perf_start_on_stream_f(M4_PREFILL_ATTN, calc_stream.stream); + this->attn->prefill(calc_stream, num_tokens, num_history_tokens, input, prev_output, position_ids, static_cast<MiniCPM4KVCache<T>*>(kv_cache)); + cuda_perf_stop_on_stream_f(M4_PREFILL_ATTN, calc_stream.stream); + if (prev_layer_states != nullptr) { + cudaMemcpyAsync( + prev_layer_states, // dst + input, // src + num_tokens * this->attn->hidden_size * sizeof(T), + cudaMemcpyDeviceToDevice, + calc_stream.stream + ); + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, this->attn->output, this->residual_scale); + cuda_perf_start_on_stream_f(M4_PREFILL_FFN, calc_stream.stream); + this->ffn->prefill(calc_stream, num_tokens, input, this->attn->output); + cuda_perf_stop_on_stream_f(M4_PREFILL_FFN, calc_stream.stream); + } + + void decode(int32_t num_tokens, int32_t padded_length, T* input, T* prev_output, int32_t* position_ids, int32_t* cache_length, const Mask& mask, KVCache<T>* kv_cache, T* prev_layer_states=nullptr) { + if (prev_output != nullptr) { + elementwise_scale(calc_stream, num_tokens, this->hidden_size, prev_output, this->residual_scale); + } + cuda_perf_start_on_stream_f(M4_DECODE_ATTN, calc_stream.stream); + this->attn->decode(calc_stream, num_tokens, padded_length, input, prev_output, position_ids, cache_length, mask, static_cast<MiniCPM4KVCache<T>*>(kv_cache)); + cuda_perf_stop_on_stream_f(M4_DECODE_ATTN, calc_stream.stream); + if (prev_layer_states != nullptr) { + cudaMemcpyAsync( + prev_layer_states, // dst + input, // src + num_tokens * this->attn->hidden_size * sizeof(T), + cudaMemcpyDeviceToDevice, + calc_stream.stream + ); + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, this->attn->output, this->residual_scale); + cuda_perf_start_on_stream_f(M4_DECODE_FFN, calc_stream.stream); + this->ffn->decode(calc_stream, num_tokens, input, this->attn->output); + cuda_perf_stop_on_stream_f(M4_DECODE_FFN, calc_stream.stream); + } +}; diff --git a/examples/CPM.cu/src/model/minicpm4/minicpm4_model.cuh b/examples/CPM.cu/src/model/minicpm4/minicpm4_model.cuh new file mode 100644 index 00000000..b30c5a9b --- /dev/null +++ b/examples/CPM.cu/src/model/minicpm4/minicpm4_model.cuh @@ -0,0 +1,160 @@ +#pragma once +#include "../model.cuh" +#include "minicpm4_layer.cuh" +#include "minicpm4_kvcache.cuh" + +template <typename T> +struct MiniCPM4Impl : Model { + Memory* memory; + + int vocab_size; + int num_hidden_layers; + int hidden_size; + int intermediate_size; + int num_attention_heads; + int num_key_value_heads; + int head_dim; + float rms_norm_eps; + + int chunk_length; + + MiniCPM4KVCacheManager<T>* kv_caches; + + Embedding<T>* embedding; + std::vector<MiniCPM4Layer<T>*> layers; + RMSNorm<T>* norm; + LMHead<T>* lm_head; + float residual_scale; + + MiniCPM4Impl( + float memory_limit, + int vocab_size, + int num_hidden_layers, + int hidden_size, + int intermediate_size, + int num_attention_heads, + int num_key_value_heads, + int head_dim, + float rms_norm_eps, + int chunk_length, + float scale_embed = 1.0f, + float scale_lmhead = 1.0f, + float scale_residual = 1.0f, + int sink_window_size = 1, + int block_window_size = 32, + int sparse_topk_k = 32, + int sparse_switch = 8192, + bool apply_compress_lse = false + ) { + this->vocab_size = vocab_size; + this->num_hidden_layers = num_hidden_layers; + this->hidden_size = hidden_size; + this->intermediate_size = intermediate_size; + this->num_attention_heads = num_attention_heads; + this->num_key_value_heads = num_key_value_heads; + this->head_dim = head_dim; + this->rms_norm_eps = rms_norm_eps; + + this->chunk_length = chunk_length; + this->residual_scale = scale_residual; + + memory = new Memory(memory_limit); + + kv_caches = new MiniCPM4KVCacheManager<T>(num_hidden_layers, num_key_value_heads, head_dim, sparse_topk_k, apply_compress_lse); + + embedding = new Embedding<T>(vocab_size, hidden_size, scale_embed); + for (int i = 0; i < num_hidden_layers; i++) { + layers.push_back(new MiniCPM4Layer<T>(hidden_size, intermediate_size, num_attention_heads, num_key_value_heads, head_dim, rms_norm_eps, residual_scale, sink_window_size, block_window_size, sparse_switch, apply_compress_lse)); + } + norm = new RMSNorm<T>(hidden_size, rms_norm_eps); + lm_head = new LMHead<T>(hidden_size, vocab_size, scale_lmhead); + } + + void init_weight_ptr(Memory* memory) { + embedding->init_weight_ptr(memory); + for (int i = 0; i < num_hidden_layers; i++) { + layers[i]->init_weight_ptr(memory); + } + norm->init_weight_ptr(memory); + lm_head->init_weight_ptr(memory); + kv_caches->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t embedding_end = embedding->init_output_ptr(memory, num_tokens, offset); + int64_t layer_end = 0; + for (int i = 0; i < num_hidden_layers; i++) { + layer_end = layers[i]->init_output_ptr(memory, num_tokens, embedding_end); + } + // norm and lm_head are not used in prefill + int64_t norm_end = norm->init_output_ptr(memory, num_tokens, layer_end); + int64_t lm_head_end = lm_head->init_output_ptr(memory, 64, norm_end); + return lm_head_end; + } + + int init_storage() { + init_weight_ptr(memory); + int64_t kv_cache_offset = init_output_ptr(memory, chunk_length, memory->model_offset); + kv_cache_offset = kv_caches->init_output_ptr(memory, chunk_length, kv_cache_offset); + return this->kv_caches->budget; + } + + void load_to_storage(std::string name, void* ptr) { + if (name.substr(0, 18) == "model.embed_tokens") { + embedding->load_to_storage(name, ptr); + } else if (name.substr(0, 10) == "model.norm") { + norm->load_to_storage(name, ptr); + } else if (name.substr(0, 7) == "lm_head") { + lm_head->load_to_storage(name, ptr); + } else if (name.find("rotary_emb") != std::string::npos) { + kv_caches->rotary_embedding->load_to_storage(name, ptr); + } else if (name.substr(0, 12) == "model.layers") { // e.g. model.layers.20.attn.q_proj.weight + std::regex layer_regex("model\\.layers\\.(\\d+)\\.(.*)"); + std::smatch matches; + if (std::regex_search(name, matches, layer_regex)) { + int layer_idx = std::stoi(matches[1]); + layers[layer_idx]->load_to_storage(matches[2], ptr); + } else { + throw std::invalid_argument("Unsupported name (layer_idx not found): " + name); + } + } else { + throw std::invalid_argument("Unsupported name " + name); + } + } + + void prefill_embed(int32_t num_tokens, int32_t num_history_tokens, T* embed, int32_t* position_ids, void* output) { + T* layer_output = nullptr; + for (int i = 0; i < num_hidden_layers; i++) { + this->layers[i]->prefill(num_tokens, num_history_tokens, embed, layer_output, position_ids, this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, layer_output, this->residual_scale); + this->norm->prefill(calc_stream, num_tokens, embed, layer_output); + this->lm_head->prefill(calc_stream, 1, this->norm->output + (num_tokens - 1) * hidden_size, (T*)output); + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, int32_t* input, int32_t* position_ids, void* output) { + this->embedding->prefill(calc_stream, num_tokens, input); + prefill_embed(num_tokens, num_history_tokens, this->embedding->output, position_ids, output); + } + + void decode_embed(int32_t num_tokens, int32_t padded_length, T* embed, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + Mask mask(mask_2d, num_tokens, num_tokens); + T* layer_output = nullptr; + for (int i = 0; i < num_hidden_layers; i++) { + this->layers[i]->decode(num_tokens, padded_length, this->embedding->output, layer_output, position_ids, cache_length, mask, this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, layer_output, this->residual_scale); + this->norm->prefill(calc_stream, num_tokens, this->embedding->output, layer_output); + this->lm_head->prefill(calc_stream, num_tokens, this->norm->output, (T*)output); + } + + void decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + this->embedding->prefill(calc_stream, num_tokens, input); + decode_embed(num_tokens, padded_length, this->embedding->output, position_ids, cache_length, mask_2d, output); + } + + void draft(int32_t *tree_draft_ids, int32_t *tree_position_ids, int32_t *cache_length, uint64_t* attn_mask, int32_t* tree_parent) { throw std::runtime_error("Draft is not supported"); } + int verify(int32_t num_tokens, int32_t* pred, int32_t* gt, int32_t* position_ids, int32_t* cache_length, uint64_t* attn_mask, int32_t* tree_parent) { throw std::runtime_error("Verify is not supported"); } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_attn.cuh b/examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_attn.cuh new file mode 100644 index 00000000..579e2e83 --- /dev/null +++ b/examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_attn.cuh @@ -0,0 +1,319 @@ +#pragma once +#include "../attn.cuh" +#include "../w4a16_gptq_marlin/w4a16_gptq_marlin_linear.cuh" +#include "minicpm4_kvcache.cuh" + +template <typename T> +struct MiniCPM4W4A16GPTQMarlinAttention { + int hidden_size; + int num_attention_heads; + int num_key_value_heads; + int head_dim; + float rms_norm_eps; + + Norm<T> *attn_norm; + W4A16GPTQMarlinLinear<T> *qkv_proj; + W4A16GPTQMarlinLinear<T> *o_proj; + T* output; + + T* attn_output; + float *softmax_lse, *softmax_lse_accum, *oaccum; + + T* q_proj_output, *v_proj_output, *k_proj_output; + T* permute_qkv_output; + + int sink_window_size; + int block_window_size; + int sparse_switch; + bool apply_compress_lse; + + MiniCPM4W4A16GPTQMarlinAttention(int hidden_size, int num_attention_heads, int num_key_value_heads, int head_dim, float rms_norm_eps, int group_size, int sink_window_size, int block_window_size, int sparse_switch, bool apply_compress_lse) { + this->hidden_size = hidden_size; + this->num_attention_heads = num_attention_heads; + this->num_key_value_heads = num_key_value_heads; + this->head_dim = head_dim; + this->rms_norm_eps = rms_norm_eps; + + this->attn_norm = new RMSNorm<T>(hidden_size, rms_norm_eps); + + this->qkv_proj = new W4A16GPTQMarlinLinear<T>(hidden_size, (num_attention_heads + 2*num_key_value_heads) * head_dim, group_size); + this->o_proj = new W4A16GPTQMarlinLinear<T>(hidden_size, num_attention_heads * head_dim, group_size); + + this->sink_window_size = sink_window_size; + this->block_window_size = block_window_size; + this->sparse_switch = sparse_switch; + this->apply_compress_lse = apply_compress_lse; + } + + void init_weight_ptr(Memory* memory) { + this->attn_norm->init_weight_ptr(memory); + this->qkv_proj->init_weight_ptr(memory); + this->o_proj->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t attn_norm_end = this->attn_norm->init_output_ptr(memory, num_tokens, offset); + int64_t qkv_proj_end = this->qkv_proj->init_output_ptr(memory, num_tokens, attn_norm_end); + + this->q_proj_output = this->qkv_proj->output; + this->k_proj_output = this->qkv_proj->output + num_tokens * this->num_attention_heads * this->head_dim; + this->v_proj_output = this->qkv_proj->output + num_tokens * (this->num_attention_heads+this->num_key_value_heads) * this->head_dim; + int64_t qkv_permute_end = memory->allocate((void**)&this->permute_qkv_output, qkv_proj_end, num_tokens * (this->num_attention_heads + 2*this->num_key_value_heads) * this->head_dim * sizeof(T)); + + int64_t attn_output_end = memory->allocate((void**)&this->attn_output, offset, num_tokens * this->num_attention_heads * this->head_dim * sizeof(T)); + int64_t softmax_lse_end = memory->allocate((void**)&this->softmax_lse, qkv_permute_end, num_tokens * this->num_attention_heads * sizeof(float)); + const int max_num_splits = 128; // Maximum number of splits for attention computation + const int max_spec_tree_size = 64; // Maximum size of speculative decoding tree + int64_t softmax_lse_accum_end = memory->allocate((void**)&this->softmax_lse_accum, softmax_lse_end, max(max_num_splits * max_spec_tree_size, num_tokens) * this->num_attention_heads * sizeof(float)); + int64_t oaccum_end = memory->allocate((void**)&this->oaccum, softmax_lse_accum_end, max(max_num_splits * max_spec_tree_size, num_tokens) * this->num_attention_heads * this->head_dim * sizeof(float)); + + int64_t o_proj_end = this->o_proj->init_output_ptr(memory, num_tokens, qkv_permute_end); + this->output = this->o_proj->output; + + return std::max(oaccum_end, o_proj_end); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("qkv_proj") != std::string::npos) { + this->qkv_proj->load_to_storage(name, ptr); + } else if (name.find("o_proj") != std::string::npos) { + this->o_proj->load_to_storage(name, ptr); + } else if (name.find("input_layernorm") != std::string::npos) { + this->attn_norm->load_to_storage(name, ptr); + } else { + throw std::invalid_argument("Unsupported name " + name); + } + } + + void prefill(const Stream& stream, int32_t num_tokens, int32_t num_history_tokens, T* input, T* prev_output, int32_t* position_ids, MiniCPM4KVCache<T>* kv_cache, T* a_tmp, float* c_tmp) { + T* k_cache = kv_cache->offset_k(num_history_tokens); + T* v_cache = kv_cache->offset_v(num_history_tokens); + + this->attn_norm->prefill(stream, num_tokens, input, prev_output); + this->qkv_proj->prefill(stream, num_tokens, this->attn_norm->output, a_tmp, c_tmp); + permute(stream, num_tokens, this->num_attention_heads * this->head_dim, this->num_key_value_heads * this->head_dim, this->qkv_proj->output, this->permute_qkv_output); + cudaMemcpy(k_cache, this->permute_qkv_output + num_tokens*this->num_attention_heads*this->head_dim, num_tokens*this->num_key_value_heads*this->head_dim*sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(v_cache, this->permute_qkv_output + num_tokens*( this->num_attention_heads + this->num_key_value_heads)*this->head_dim, num_tokens*this->num_key_value_heads*this->head_dim*sizeof(T), cudaMemcpyDeviceToDevice); + kv_cache->rotary_embedding->prefill(stream, num_tokens, this->num_attention_heads, this->num_key_value_heads, this->permute_qkv_output, k_cache, position_ids); + + cuda_perf_start_on_stream_f(M4Q_PREFILL_ATTN_CORE, stream.stream); + cuda_perf_start_on_stream_f(M4Q_PREFILL_ATTN_STAGE1, stream.stream); + if (num_history_tokens == 0) { + kv_cache->init(); + } else { + kv_cache->compress(stream); + } + + uint64_t *blockmask = nullptr; + if ((!apply_compress_lse && kv_cache->c1_len * kv_cache->c1_stride >= this->sparse_switch) || (apply_compress_lse && kv_cache->c2_len * kv_cache->c2_stride >= this->sparse_switch)) { + int q_round, k_round, out_len; + cuda_perf_start_on_stream_f(M4Q_PREFILL_ATTN_STAGE1_CORE, stream.stream); + mha_fwd_stage1( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + kv_cache->c1_len, + apply_compress_lse ? kv_cache->c2_len : kv_cache->c1_len, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + this->permute_qkv_output, + kv_cache->c1_cache, + apply_compress_lse ? kv_cache->c2_cache : kv_cache->c1_cache, + nullptr, + kv_cache->stage1_score, + rsqrtf(float(this->head_dim)), + false, + -1, + -1, + 0, + stream.stream, + q_round, + k_round + ); + cuda_perf_stop_on_stream_f(M4Q_PREFILL_ATTN_STAGE1_CORE, stream.stream); + maxpooling_func( + stream.stream, + kv_cache->stage1_score, + kv_cache->pool_score, + this->num_key_value_heads, + num_tokens, + q_round, + k_round, + kv_cache->next_kv_length, + this->sink_window_size, + this->block_window_size, + out_len + ); + kv_cache->topk_func->prefill( + stream, + this->num_key_value_heads*num_tokens, + kv_cache->pool_score, + out_len + ); + topk_to_uint64_func( + stream.stream, + kv_cache->topk_func->topk_pos, + kv_cache->blockmask, + this->num_key_value_heads*num_tokens, + kv_cache->topk_func->top, + num_history_tokens+num_tokens + ); + blockmask = kv_cache->blockmask; + } + cuda_perf_stop_on_stream_f(M4Q_PREFILL_ATTN_STAGE1, stream.stream); + + cuda_perf_start_on_stream_f(M4Q_PREFILL_ATTN_STAGE2, stream.stream); + mha_fwd_kvcache( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + num_history_tokens+num_tokens, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + this->permute_qkv_output, + kv_cache->k_cache, + kv_cache->v_cache, + nullptr, + Mask(nullptr), + this->attn_output, + this->softmax_lse, + this->softmax_lse_accum, + this->oaccum, + rsqrtf(float(this->head_dim)), + true, + -1, + -1, + 0, + stream.stream, + blockmask, + blockmask ? this->block_window_size : 0 // TODO fix this condition + ); + cuda_perf_stop_on_stream_f(M4Q_PREFILL_ATTN_STAGE2, stream.stream); + cuda_perf_stop_on_stream_f(M4Q_PREFILL_ATTN_CORE, stream.stream); + + // flash attention and put output to attn_norm->output + this->o_proj->prefill(stream, num_tokens, this->attn_output, a_tmp, c_tmp); + + kv_cache->next_kv_length = kv_cache->next_kv_length + num_tokens; + } + + void decode(const Stream& stream, int32_t num_tokens, int32_t padded_length, T* input, T* prev_output, int32_t* position_ids, int32_t* cache_length, const Mask& mask, MiniCPM4KVCache<T>* kv_cache, T* a_tmp, float* c_tmp) { + this->attn_norm->prefill(stream, num_tokens, input, prev_output); + T *q, *k, *v; + + if (num_tokens > 1) { + this->qkv_proj->prefill(stream, num_tokens, this->attn_norm->output, a_tmp, c_tmp); + permute(stream, num_tokens, this->num_attention_heads * this->head_dim, this->num_key_value_heads * this->head_dim, this->qkv_proj->output, this->permute_qkv_output); // TODO: Double check + q = this->permute_qkv_output; + } else { + this->qkv_proj->prefill(stream, num_tokens, this->attn_norm->output, a_tmp, c_tmp); + q = this->qkv_proj->output; + } + k = q + num_tokens * this->num_attention_heads * this->head_dim; + v = k + num_tokens * this->num_key_value_heads * this->head_dim; + kv_cache->rotary_embedding->prefill(stream, num_tokens, this->num_attention_heads, this->num_key_value_heads, q, k, position_ids); + + cuda_perf_start_on_stream_f(M4Q_DECODE_ATTN_CORE, stream.stream); + cuda_perf_start_on_stream_f(M4Q_DECODE_ATTN_STAGE1, stream.stream); + + copy_to_kvcache(stream, num_tokens, k, v, kv_cache, cache_length); + + kv_cache->compress(stream); + + uint64_t *blockmask = nullptr; + if ((!apply_compress_lse && kv_cache->c1_len * kv_cache->c1_stride >= this->sparse_switch) || (apply_compress_lse && kv_cache->c2_len * kv_cache->c2_stride >= this->sparse_switch)) { + int q_round, k_round, out_len; + cuda_perf_start_on_stream_f(M4Q_DECODE_ATTN_STAGE1_CORE, stream.stream); + mha_fwd_stage1( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + kv_cache->c1_len, + apply_compress_lse ? kv_cache->c2_len : kv_cache->c1_len, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + q, + kv_cache->c1_cache, + apply_compress_lse ? kv_cache->c2_cache : kv_cache->c1_cache, + nullptr, + kv_cache->stage1_score, + rsqrtf(float(this->head_dim)), + false, + -1, + -1, + 0, + stream.stream, + q_round, + k_round + ); + cuda_perf_stop_on_stream_f(M4Q_DECODE_ATTN_STAGE1_CORE, stream.stream); + maxpooling_func( + stream.stream, + kv_cache->stage1_score, + kv_cache->pool_score, + this->num_key_value_heads, + num_tokens, + q_round, + k_round, + kv_cache->next_kv_length, + this->sink_window_size, + this->block_window_size, + out_len + ); + kv_cache->topk_func->prefill( + stream, + this->num_key_value_heads*num_tokens, + kv_cache->pool_score, + out_len + ); + topk_to_uint64_func( + stream.stream, + kv_cache->topk_func->topk_pos, + kv_cache->blockmask, + this->num_key_value_heads*num_tokens, + kv_cache->topk_func->top, + padded_length + ); + blockmask = kv_cache->blockmask; + } + cuda_perf_stop_on_stream_f(M4Q_DECODE_ATTN_STAGE1, stream.stream); + + cuda_perf_start_on_stream_f(M4Q_DECODE_ATTN_STAGE2, stream.stream); + mha_fwd_kvcache( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + padded_length, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + q, + kv_cache->k_cache, + kv_cache->v_cache, + cache_length, + mask, + this->attn_output, + this->softmax_lse, + this->softmax_lse_accum, + this->oaccum, + rsqrtf(float(this->head_dim)), + true, + -1, + -1, + 0, + stream.stream, + blockmask, + blockmask ? this->block_window_size : 0 // TODO fix this condition + ); + cuda_perf_stop_on_stream_f(M4Q_DECODE_ATTN_STAGE2, stream.stream); + cuda_perf_stop_on_stream_f(M4Q_DECODE_ATTN_CORE, stream.stream); + + // flash attention and put output to attn_norm->output + this->o_proj->prefill(stream, num_tokens, this->attn_output, a_tmp, c_tmp); + + kv_cache->next_kv_length = kv_cache->next_kv_length + 1; + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_layer.cuh b/examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_layer.cuh new file mode 100644 index 00000000..fb5e0e11 --- /dev/null +++ b/examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_layer.cuh @@ -0,0 +1,98 @@ +#pragma once +#include "../layer.cuh" +#include "../w4a16_gptq_marlin/w4a16_gptq_marlin_ffn.cuh" +#include "../../qgemm/gptq_marlin/gptq_marlin.cuh" +#include "minicpm4_w4a16_gptq_marlin_attn.cuh" + +template <typename T> +struct MiniCPM4W4A16GPTQMarlinLayer { + MiniCPM4W4A16GPTQMarlinAttention<T> *attn; + W4A16GPTQMarlinGatedFFN<T> *ffn; + T* output; + + // marlin for gptq marlin + int intermediate_size; + T* a_tmp; + float* c_tmp; + float residual_scale; + int hidden_size; + + MiniCPM4W4A16GPTQMarlinLayer(int hidden_size, int intermediate_size, int num_attention_heads, int num_key_value_heads, int head_dim, float rms_norm_eps, int group_size, float residual_scale = 1.0, int sink_window_size = 1, int block_window_size = 32, int sparse_switch = 8192, bool apply_compress_lse = false) { + this->attn = new MiniCPM4W4A16GPTQMarlinAttention<T>(hidden_size, num_attention_heads, num_key_value_heads, head_dim, rms_norm_eps, group_size, sink_window_size, block_window_size, sparse_switch, apply_compress_lse); + this->ffn = new W4A16GPTQMarlinGatedFFN<T>(hidden_size, intermediate_size, rms_norm_eps, group_size); + this->hidden_size = hidden_size; + this->intermediate_size = intermediate_size; + this->residual_scale = residual_scale; + } + + void init_weight_ptr(Memory* memory) { + this->attn->init_weight_ptr(memory); + this->ffn->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t a_tmp_offset = memory->allocate((void**)&this->a_tmp, offset, 2* num_tokens * intermediate_size * sizeof(T)); + int reduce_max_m = marlin::determine_reduce_max_m(num_tokens, marlin::max_par); + int reduce_n = 2*intermediate_size; + int64_t c_tmp_offset = memory->allocate((void**)&this->c_tmp, a_tmp_offset, reduce_max_m * reduce_n * sizeof(float)); + + int64_t attn_end = this->attn->init_output_ptr(memory, num_tokens, c_tmp_offset); + int64_t ffn_end = this->ffn->init_output_ptr(memory, num_tokens, c_tmp_offset); + this->output = this->ffn->output; + return std::max(attn_end, ffn_end); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("attn") != std::string::npos || name.find("input_layernorm") != std::string::npos) { + this->attn->load_to_storage(name, ptr); + } else if (name.find("mlp") != std::string::npos || name.find("post_attention_layernorm") != std::string::npos) { + this->ffn->load_to_storage(name, ptr); + } else { + throw std::invalid_argument("Unsupported name " + name); + } + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, T* input, T* prev_output, int32_t* position_ids, KVCache<T>* kv_cache, T* prev_layer_states=nullptr) { + if (prev_output != nullptr) { + elementwise_scale(calc_stream, num_tokens, this->hidden_size, prev_output, this->residual_scale); + } + cuda_perf_start_on_stream_f(M4Q_PREFILL_ATTN, calc_stream.stream); + this->attn->prefill(calc_stream, num_tokens, num_history_tokens, input, prev_output, position_ids, static_cast<MiniCPM4KVCache<T>*>(kv_cache), a_tmp, c_tmp); + cuda_perf_stop_on_stream_f(M4Q_PREFILL_ATTN, calc_stream.stream); + if (prev_layer_states != nullptr) { + cudaMemcpyAsync( + prev_layer_states, // dst + input, // src + num_tokens * this->attn->hidden_size * sizeof(T), + cudaMemcpyDeviceToDevice, + calc_stream.stream + ); + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, this->attn->output, this->residual_scale); + cuda_perf_start_on_stream_f(M4Q_PREFILL_FFN, calc_stream.stream); + this->ffn->prefill(calc_stream, num_tokens, input, this->attn->output, a_tmp, c_tmp); + cuda_perf_stop_on_stream_f(M4Q_PREFILL_FFN, calc_stream.stream); + } + + void decode(int32_t num_tokens, int32_t padded_length, T* input, T* prev_output, int32_t* position_ids, int32_t* cache_length, const Mask& mask, KVCache<T>* kv_cache, T* prev_layer_states=nullptr) { + if (prev_output != nullptr) { + elementwise_scale(calc_stream, num_tokens, this->hidden_size, prev_output, this->residual_scale); + } + cuda_perf_start_on_stream_f(M4Q_DECODE_ATTN, calc_stream.stream); + this->attn->decode(calc_stream, num_tokens, padded_length, input, prev_output, position_ids, cache_length, mask, static_cast<MiniCPM4KVCache<T>*>(kv_cache), a_tmp, c_tmp); + cuda_perf_stop_on_stream_f(M4Q_DECODE_ATTN, calc_stream.stream); + if (prev_layer_states != nullptr) { + cudaMemcpyAsync( + prev_layer_states, // dst + input, // src + num_tokens * this->attn->hidden_size * sizeof(T), + cudaMemcpyDeviceToDevice, + calc_stream.stream + ); + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, this->attn->output, this->residual_scale); + cuda_perf_start_on_stream_f(M4Q_DECODE_FFN, calc_stream.stream); + this->ffn->decode(calc_stream, num_tokens, input, this->attn->output, a_tmp, c_tmp); + cuda_perf_stop_on_stream_f(M4Q_DECODE_FFN, calc_stream.stream); + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_model.cuh b/examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_model.cuh new file mode 100644 index 00000000..528d94c9 --- /dev/null +++ b/examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_model.cuh @@ -0,0 +1,161 @@ +#pragma once +#include "../model.cuh" +#include "minicpm4_w4a16_gptq_marlin_layer.cuh" +#include "minicpm4_kvcache.cuh" + +template <typename T> +struct MiniCPM4W4A16GPTQMarlinModelImpl : Model { + Memory* memory; + + int vocab_size; + int num_hidden_layers; + int hidden_size; + int intermediate_size; + int num_attention_heads; + int num_key_value_heads; + int head_dim; + float rms_norm_eps; + + int chunk_length; + + MiniCPM4KVCacheManager<T>* kv_caches; + + Embedding<T>* embedding; + std::vector<MiniCPM4W4A16GPTQMarlinLayer<T>*> layers; + RMSNorm<T>* norm; + LMHead<T>* lm_head; + float residual_scale; + + MiniCPM4W4A16GPTQMarlinModelImpl( + float memory_limit, + int vocab_size, + int num_hidden_layers, + int hidden_size, + int intermediate_size, + int num_attention_heads, + int num_key_value_heads, + int head_dim, + float rms_norm_eps, + int group_size, + int chunk_length, + float scale_embed = 1.0f, + float scale_lmhead = 1.0f, + float scale_residual = 1.0f, + int sink_window_size = 1, + int block_window_size = 32, + int sparse_topk_k = 32, + int sparse_switch = 8192, + bool apply_compress_lse = false + ) { + this->vocab_size = vocab_size; + this->num_hidden_layers = num_hidden_layers; + this->hidden_size = hidden_size; + this->intermediate_size = intermediate_size; + this->num_attention_heads = num_attention_heads; + this->num_key_value_heads = num_key_value_heads; + this->head_dim = head_dim; + this->rms_norm_eps = rms_norm_eps; + + this->chunk_length = chunk_length; + this->residual_scale = scale_residual; + + memory = new Memory(memory_limit); + + kv_caches = new MiniCPM4KVCacheManager<T>(num_hidden_layers, num_key_value_heads, head_dim, sparse_topk_k, apply_compress_lse); + + embedding = new Embedding<T>(vocab_size, hidden_size, scale_embed); + for (int i = 0; i < num_hidden_layers; i++) { + layers.push_back(new MiniCPM4W4A16GPTQMarlinLayer<T>(hidden_size, intermediate_size, num_attention_heads, num_key_value_heads, head_dim, rms_norm_eps, group_size, residual_scale, sink_window_size, block_window_size, sparse_switch, apply_compress_lse)); + } + norm = new RMSNorm<T>(hidden_size, rms_norm_eps); + lm_head = new LMHead<T>(hidden_size, vocab_size, scale_lmhead); + } + + void init_weight_ptr(Memory* memory) { + embedding->init_weight_ptr(memory); + for (int i = 0; i < num_hidden_layers; i++) { + layers[i]->init_weight_ptr(memory); + } + norm->init_weight_ptr(memory); + lm_head->init_weight_ptr(memory); + kv_caches->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t embedding_end = embedding->init_output_ptr(memory, num_tokens, offset); + int64_t layer_end = 0; + for (int i = 0; i < num_hidden_layers; i++) { + layer_end = layers[i]->init_output_ptr(memory, num_tokens, embedding_end); + } + // norm and lm_head are not used in prefill + int64_t norm_end = norm->init_output_ptr(memory, num_tokens, layer_end); + int64_t lm_head_end = lm_head->init_output_ptr(memory, 64, norm_end); + return lm_head_end; + } + + int init_storage() { + init_weight_ptr(memory); + int64_t kv_cache_offset = init_output_ptr(memory, chunk_length, memory->model_offset); + kv_cache_offset = kv_caches->init_output_ptr(memory, chunk_length, kv_cache_offset); + return this->kv_caches->budget; + } + + void load_to_storage(std::string name, void* ptr) { + if (name.substr(0, 18) == "model.embed_tokens") { + embedding->load_to_storage(name, ptr); + } else if (name.substr(0, 10) == "model.norm") { + norm->load_to_storage(name, ptr); + } else if (name.substr(0, 7) == "lm_head") { + lm_head->load_to_storage(name, ptr); + } else if (name.find("rotary_emb") != std::string::npos) { + kv_caches->rotary_embedding->load_to_storage(name, ptr); + } else if (name.substr(0, 12) == "model.layers") { // e.g. model.layers.20.attn.q_proj.weight + std::regex layer_regex("model\\.layers\\.(\\d+)\\.(.*)"); + std::smatch matches; + if (std::regex_search(name, matches, layer_regex)) { + int layer_idx = std::stoi(matches[1]); + layers[layer_idx]->load_to_storage(matches[2], ptr); + } else { + throw std::invalid_argument("Unsupported name (layer_idx not found): " + name); + } + } else { + throw std::invalid_argument("Unsupported name " + name); + } + } + + void prefill_embed(int32_t num_tokens, int32_t num_history_tokens, T* embed, int32_t* position_ids, void* output) { + T* layer_output = nullptr; + for (int i = 0; i < num_hidden_layers; i++) { + this->layers[i]->prefill(num_tokens, num_history_tokens, embed, layer_output, position_ids, this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, layer_output, this->residual_scale); + this->norm->prefill(calc_stream, num_tokens, embed, layer_output); + this->lm_head->prefill(calc_stream, 1, this->norm->output + (num_tokens - 1) * hidden_size, (T*)output); + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, int32_t* input, int32_t* position_ids, void* output) { + this->embedding->prefill(calc_stream, num_tokens, input); + prefill_embed(num_tokens, num_history_tokens, this->embedding->output, position_ids, output); + } + + void decode_embed(int32_t num_tokens, int32_t padded_length, T* embed, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + Mask mask(mask_2d, num_tokens, num_tokens); + T* layer_output = nullptr; + for (int i = 0; i < num_hidden_layers; i++) { + this->layers[i]->decode(num_tokens, padded_length, this->embedding->output, layer_output, position_ids, cache_length, mask, this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, layer_output, this->residual_scale); + this->norm->prefill(calc_stream, num_tokens, this->embedding->output, layer_output); + this->lm_head->prefill(calc_stream, num_tokens, this->norm->output, (T*)output); + } + + void decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + this->embedding->prefill(calc_stream, num_tokens, input); + decode_embed(num_tokens, padded_length, this->embedding->output, position_ids, cache_length, mask_2d, output); + } + + void draft(int32_t *tree_draft_ids, int32_t *tree_position_ids, int32_t *cache_length, uint64_t* attn_mask, int32_t* tree_parent) { throw std::runtime_error("Draft is not supported"); } + int verify(int32_t num_tokens, int32_t* pred, int32_t* gt, int32_t* position_ids, int32_t* cache_length, uint64_t* attn_mask, int32_t* tree_parent) { throw std::runtime_error("Verify is not supported"); } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/model.cuh b/examples/CPM.cu/src/model/model.cuh new file mode 100644 index 00000000..68f2ac51 --- /dev/null +++ b/examples/CPM.cu/src/model/model.cuh @@ -0,0 +1,174 @@ +#pragma once +#include "memory.cuh" +#include "embedding.cuh" +#include "norm.cuh" +#include "linear.cuh" +#include "layer.cuh" +#include "kvcache.cuh" +#include "mask.cuh" +#include <algorithm> +#include <cuda_runtime.h> +#include <vector> +#include <regex> + +struct Model { + virtual int init_storage() = 0; + virtual void load_to_storage(std::string name, void* ptr) = 0; + virtual void prefill(int32_t num_tokens, int32_t num_history_tokens, int32_t* input, int32_t* position_ids, void* output) = 0; + virtual void decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) = 0; + + virtual void draft(int32_t *tree_draft_ids, int32_t *tree_position_ids, int32_t *cache_length, uint64_t* attn_mask, int32_t* tree_parent) = 0; + virtual int verify(int32_t num_tokens, int32_t* pred, int32_t* gt, int32_t* position_ids, int32_t* cache_length, uint64_t* attn_mask, int32_t* tree_parent) = 0; + /* verify should find max accept length (based on tree_parent and position_ids) and return, fix kvcache (based on position_ids), and make pred[:accept_length] the accept path (based on attn_mask and position_ids) */ +}; + +template <typename T> +struct ModelImpl : Model { + Memory* memory; + + int vocab_size; + int num_hidden_layers; + int hidden_size; + int intermediate_size; + int num_attention_heads; + int num_key_value_heads; + int head_dim; + float rms_norm_eps; + + int chunk_length; + + KVCacheManager<T>* kv_caches; + + Embedding<T>* embedding; + std::vector<Layer<T>*> layers; + RMSNorm<T>* norm; + LMHead<T>* lm_head; + float residual_scale; + + ModelImpl( + float memory_limit, + int vocab_size, + int num_hidden_layers, + int hidden_size, + int intermediate_size, + int num_attention_heads, + int num_key_value_heads, + int head_dim, + float rms_norm_eps, + int chunk_length, + float scale_embed = 1.0f, + float scale_lmhead = 1.0f, + float scale_residual = 1.0f + ) { + this->vocab_size = vocab_size; + this->num_hidden_layers = num_hidden_layers; + this->hidden_size = hidden_size; + this->intermediate_size = intermediate_size; + this->num_attention_heads = num_attention_heads; + this->num_key_value_heads = num_key_value_heads; + this->head_dim = head_dim; + this->rms_norm_eps = rms_norm_eps; + + this->chunk_length = chunk_length; + this->residual_scale = scale_residual; + + memory = new Memory(memory_limit); + + kv_caches = new KVCacheManager<T>(num_hidden_layers, num_key_value_heads, head_dim); + + embedding = new Embedding<T>(vocab_size, hidden_size, scale_embed); + for (int i = 0; i < num_hidden_layers; i++) { + layers.push_back(new Layer<T>(hidden_size, intermediate_size, num_attention_heads, num_key_value_heads, head_dim, rms_norm_eps, residual_scale)); + } + norm = new RMSNorm<T>(hidden_size, rms_norm_eps); + lm_head = new LMHead<T>(hidden_size, vocab_size, scale_lmhead); + } + + void init_weight_ptr(Memory* memory) { + embedding->init_weight_ptr(memory); + for (int i = 0; i < num_hidden_layers; i++) { + layers[i]->init_weight_ptr(memory); + } + norm->init_weight_ptr(memory); + lm_head->init_weight_ptr(memory); + kv_caches->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t embedding_end = embedding->init_output_ptr(memory, num_tokens, offset); + int64_t layer_end = 0; + for (int i = 0; i < num_hidden_layers; i++) { + layer_end = layers[i]->init_output_ptr(memory, num_tokens, embedding_end); + } + // norm and lm_head are not used in prefill + int64_t norm_end = norm->init_output_ptr(memory, num_tokens, layer_end); + int64_t lm_head_end = lm_head->init_output_ptr(memory, 64, norm_end); + return lm_head_end; + } + + int init_storage() { + init_weight_ptr(memory); + int64_t kv_cache_offset = init_output_ptr(memory, chunk_length, memory->model_offset); + kv_cache_offset = kv_caches->init_output_ptr(memory, kv_cache_offset); + return this->kv_caches->budget; + } + + void load_to_storage(std::string name, void* ptr) { + if (name.substr(0, 18) == "model.embed_tokens") { + embedding->load_to_storage(name, ptr); + } else if (name.substr(0, 10) == "model.norm") { + norm->load_to_storage(name, ptr); + } else if (name.substr(0, 7) == "lm_head") { + lm_head->load_to_storage(name, ptr); + } else if (name.find("rotary_emb") != std::string::npos) { + kv_caches->rotary_embedding->load_to_storage(name, ptr); + } else if (name.substr(0, 12) == "model.layers") { // e.g. model.layers.20.attn.q_proj.weight + std::regex layer_regex("model\\.layers\\.(\\d+)\\.(.*)"); + std::smatch matches; + if (std::regex_search(name, matches, layer_regex)) { + int layer_idx = std::stoi(matches[1]); + layers[layer_idx]->load_to_storage(matches[2], ptr); + } else { + throw std::invalid_argument("Unsupported name (layer_idx not found): " + name); + } + } else { + throw std::invalid_argument("Unsupported name " + name); + } + } + + void prefill_embed(int32_t num_tokens, int32_t num_history_tokens, T* embed, int32_t* position_ids, void* output) { + T* layer_output = nullptr; + for (int i = 0; i < num_hidden_layers; i++) { + this->layers[i]->prefill(num_tokens, num_history_tokens, embed, layer_output, position_ids, this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, layer_output, this->residual_scale); + this->norm->prefill(calc_stream, num_tokens, embed, layer_output); + this->lm_head->prefill(calc_stream, 1, this->norm->output + (num_tokens - 1) * hidden_size, (T*)output); + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, int32_t* input, int32_t* position_ids, void* output) { + this->embedding->prefill(calc_stream, num_tokens, input); + prefill_embed(num_tokens, num_history_tokens, this->embedding->output, position_ids, output); + } + + void decode_embed(int32_t num_tokens, int32_t padded_length, T* embed, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + Mask mask(mask_2d, num_tokens, num_tokens); + T* layer_output = nullptr; + for (int i = 0; i < num_hidden_layers; i++) { + this->layers[i]->decode(num_tokens, padded_length, this->embedding->output, layer_output, position_ids, cache_length, mask, this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, layer_output, this->residual_scale); + this->norm->prefill(calc_stream, num_tokens, this->embedding->output, layer_output); + this->lm_head->prefill(calc_stream, num_tokens, this->norm->output, (T*)output); + } + + void decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + this->embedding->prefill(calc_stream, num_tokens, input); + decode_embed(num_tokens, padded_length, this->embedding->output, position_ids, cache_length, mask_2d, output); + } + + void draft(int32_t *tree_draft_ids, int32_t *tree_position_ids, int32_t *cache_length, uint64_t* attn_mask, int32_t* tree_parent) { throw std::runtime_error("Draft is not supported"); } + int verify(int32_t num_tokens, int32_t* pred, int32_t* gt, int32_t* position_ids, int32_t* cache_length, uint64_t* attn_mask, int32_t* tree_parent) { throw std::runtime_error("Verify is not supported"); } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/norm.cuh b/examples/CPM.cu/src/model/norm.cuh new file mode 100644 index 00000000..d953521f --- /dev/null +++ b/examples/CPM.cu/src/model/norm.cuh @@ -0,0 +1,154 @@ +#pragma once +#include <cuda_runtime.h> +#include "../trait.cuh" +#include "../utils.cuh" + +namespace { +template <typename T, typename T2> +__global__ void rms_norm_kernel(int dim, const T2* input, const T2* weight, T2* output, float eps) { + // __shared__ T2 s_input[2048]; + __shared__ float shared_sum; + __shared__ float warp_sum[16]; + int row = blockIdx.x; + int col = threadIdx.x; + float sum1 = 0.0f, sum2 = 0.0f; + for (int i = col; i < dim; i += blockDim.x) { + T2 val = input[row * dim + i]; + // s_input[i] = val; + float val1 = float(val.x); + float val2 = float(val.y); + sum1 += val1 * val1; + sum2 += val2 * val2; + } + float sum = sum1 + sum2; + sum += __shfl_down_sync(0xffffffff, sum, 16); + sum += __shfl_down_sync(0xffffffff, sum, 8); + sum += __shfl_down_sync(0xffffffff, sum, 4); + sum += __shfl_down_sync(0xffffffff, sum, 2); + sum += __shfl_down_sync(0xffffffff, sum, 1); + if (col % 32 == 0) warp_sum[col / 32] = sum; + __syncthreads(); + if (col < 16) { + sum = warp_sum[col]; + sum += __shfl_down_sync(0x0000ffff, sum, 8); + sum += __shfl_down_sync(0x0000ffff, sum, 4); + sum += __shfl_down_sync(0x0000ffff, sum, 2); + sum += __shfl_down_sync(0x0000ffff, sum, 1); + } + if (col == 0) { + shared_sum = rsqrtf(sum / (2*dim) + eps); + } + __syncthreads(); + sum = shared_sum; + for (int i = col; i < dim; i += blockDim.x) { + T2 inp = input[row * dim + i]; + T2 w = weight[i]; + output[row * dim + i] = T2( + T(sum * float(inp.x) * float(w.x)), + T(sum * float(inp.y) * float(w.y)) + ); + } +} + +template <typename T, typename T2> +__global__ void add_and_rms_norm_kernel(int dim, T2* input, const T2* prev_output, const T2* weight, T2* output, float eps) { + // __shared__ T2 s_input[2048]; + __shared__ float shared_sum; + __shared__ float warp_sum[16]; + int row = blockIdx.x; + int col = threadIdx.x; + float sum1 = 0.0f, sum2 = 0.0f; + for (int i = col; i < dim; i += blockDim.x) { + T2 val = input[row * dim + i]; + T2 prev = prev_output[row * dim + i]; + val = val + prev; + input[row * dim + i] = val; + float val1 = float(val.x); + float val2 = float(val.y); + sum1 += val1 * val1; + sum2 += val2 * val2; + } + float sum = sum1 + sum2; + sum += __shfl_down_sync(0xffffffff, sum, 16); + sum += __shfl_down_sync(0xffffffff, sum, 8); + sum += __shfl_down_sync(0xffffffff, sum, 4); + sum += __shfl_down_sync(0xffffffff, sum, 2); + sum += __shfl_down_sync(0xffffffff, sum, 1); + if (col % 32 == 0) warp_sum[col / 32] = sum; + __syncthreads(); + if (col < 16) { + sum = warp_sum[col]; + sum += __shfl_down_sync(0x0000ffff, sum, 8); + sum += __shfl_down_sync(0x0000ffff, sum, 4); + sum += __shfl_down_sync(0x0000ffff, sum, 2); + sum += __shfl_down_sync(0x0000ffff, sum, 1); + } + if (col == 0) { + shared_sum = rsqrtf(sum / (2*dim) + eps); + } + __syncthreads(); + sum = shared_sum; + for (int i = col; i < dim; i += blockDim.x) { + T2 inp = input[row * dim + i]; + T2 w = weight[i]; + output[row * dim + i] = T2( + T(sum * float(inp.x) * float(w.x)), + T(sum * float(inp.y) * float(w.y)) + ); + } +} + +template <typename T> +void rms_norm(const Stream& stream, int num_tokens, int dim, const T* input, const T* weight, T* output, float eps) { + using T2 = typename TypeTraits<T>::half2; + rms_norm_kernel<T, T2><<<num_tokens, 512, 0, stream.stream>>>(dim/2, (T2*)input, (T2*)weight, (T2*)output, eps); +} + +template <typename T> +void add_and_rms_norm(const Stream& stream, int num_tokens, int dim, T* input, const T* prev_output, const T* weight, T* output, float eps) { + using T2 = typename TypeTraits<T>::half2; + add_and_rms_norm_kernel<T, T2><<<num_tokens, 512, 0, stream.stream>>>(dim/2, (T2*)input, (T2*)prev_output, (T2*)weight, (T2*)output, eps); +} +} + +template <typename T> +struct Norm { + T* output; + virtual void init_weight_ptr(Memory* memory) = 0; + virtual int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) = 0; + virtual void load_to_storage(std::string name, void* ptr) = 0; + virtual void prefill(const Stream& stream, int32_t num_tokens, T* input, T* prev_output, T* tgt=nullptr) = 0; +}; + +template <typename T> +struct RMSNorm : Norm<T> { + int dim; + float eps; + T* weight; + + RMSNorm(int dim, float eps) { + this->dim = dim; + this->eps = eps; + } + + void init_weight_ptr(Memory* memory) { + weight = (T*)memory->allocate_for_model(dim * sizeof(T)); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + return memory->allocate((void**)&this->output, offset, num_tokens * dim * sizeof(T)); + } + + void load_to_storage(std::string name, void* ptr) { + cudaMemcpy((void*)weight, ptr, dim * sizeof(T), cudaMemcpyHostToDevice); + } + + void prefill(const Stream& stream, int32_t num_tokens, T* input, T* prev_output, T* tgt=nullptr) { + if (tgt == nullptr) tgt = this->output; + if (prev_output == nullptr) { + rms_norm(stream, num_tokens, this->dim, input, this->weight, tgt, this->eps); + } else { + add_and_rms_norm(stream, num_tokens, this->dim, input, prev_output, this->weight, tgt, this->eps); + } + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/rotary.cuh b/examples/CPM.cu/src/model/rotary.cuh new file mode 100644 index 00000000..431f190d --- /dev/null +++ b/examples/CPM.cu/src/model/rotary.cuh @@ -0,0 +1,68 @@ +#pragma once +#include <cuda_runtime.h> +#include "../utils.cuh" + +namespace { +template<typename T> +__global__ void rotary_embedding_kernel(int num_heads, int num_heads_kv, int half_dim, const float *inv_freq, const int* pos, T* q, T* k) { + int tid = threadIdx.x; + + int p = pos[blockIdx.x]; + + for (int i = tid; i < num_heads * half_dim; i += blockDim.x) { + int row = i / half_dim; + int col = i % half_dim; + int offset = blockIdx.x * num_heads * half_dim * 2 + row * half_dim * 2; + float freq = p * inv_freq[col]; + float cos_freq = cos(freq), sin_freq = sin(freq); + float a = float(q[offset + col]); + float b = float(q[offset + col + half_dim]); + q[offset + col] = T(a * cos_freq - b * sin_freq); + q[offset + col + half_dim] = T(a * sin_freq + b * cos_freq); + } + for (int i = tid; i < num_heads_kv * half_dim; i += blockDim.x) { + int row = i / half_dim; + int col = i % half_dim; + int offset = blockIdx.x * num_heads_kv * half_dim * 2 + row * half_dim * 2; + float freq = p * inv_freq[col]; + float cos_freq = cos(freq), sin_freq = sin(freq); + float a = float(k[offset + col]); + float b = float(k[offset + col + half_dim]); + k[offset + col] = T(a * cos_freq - b * sin_freq); + k[offset + col + half_dim] = T(a * sin_freq + b * cos_freq); + } +} +} + +template<typename T> +void rotary_embedding(const Stream& stream, int num_tokens, int num_heads, int num_heads_kv, int half_dim, const float *inv_freq, const int* pos, T* q, T* k) { + rotary_embedding_kernel<T><<<num_tokens, 512, 0, stream.stream>>>(num_heads, num_heads_kv, half_dim, inv_freq, pos, q, k); +} + +template <typename T> +struct RotaryEmbedding { + int half_dim; + + float *inv_freq; + // float attention_scaling; + + RotaryEmbedding(int head_dim) { + this->half_dim = head_dim / 2; + } + + void init_weight_ptr(Memory* memory) { + this->inv_freq = (float*)memory->allocate_for_model(half_dim * sizeof(float)); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("inv_freq") != std::string::npos) { + cudaMemcpy((void*)inv_freq, ptr, half_dim * sizeof(float), cudaMemcpyHostToDevice); + } else { + throw std::runtime_error("Unsupported rotary embedding weight name: " + name); + } + } + + void prefill(const Stream& stream, int32_t num_tokens, int num_heads, int num_heads_kv, T* q, T* k, int32_t* position_ids) { + rotary_embedding(stream, num_tokens, num_heads, num_heads_kv, this->half_dim, this->inv_freq, position_ids, q, k); + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/spec_quant/w4a16_gm_spec_w4a16_gm.cuh b/examples/CPM.cu/src/model/spec_quant/w4a16_gm_spec_w4a16_gm.cuh new file mode 100644 index 00000000..47c39821 --- /dev/null +++ b/examples/CPM.cu/src/model/spec_quant/w4a16_gm_spec_w4a16_gm.cuh @@ -0,0 +1,247 @@ +#pragma once +#include "../w4a16_gptq_marlin/w4a16_gptq_marlin_model.cuh" +#include "../eagle.cuh" +#include "../drafter.cuh" + + +template <typename T> +struct W4A16GMSpecW4A16GMImpl: Model { + + + W4A16GPTQMarlinModelImpl<T>* draft_model; + W4A16GPTQMarlinModelImpl<T>* model; + + // draft args + int32_t *draft_input; + int32_t *draft_position_ids, *draft_cache_length; + int * host_draft_cache_length; + int draft_padded_length; + T* draft_logits; + bool is_first_draft; + functions::TopK<T>* topk_func; + int32_t *draft_tmp; + int32_t *h_best, *d_best; + int num_iter; + int num_prev, num_history_tokens; + + // draft mask always nullptr + uint64_t* draft_mask_2d; + + // graph + bool draft_cuda_graph; + int draft_graphCreated_padding_length; + int draft_graphCreated_input_length; + cudaGraph_t draft_graph; + cudaGraphExec_t draft_graphExec; + + W4A16GMSpecW4A16GMImpl( + W4A16GPTQMarlinModelImpl<T>* model, + int draft_vocab_size, + int draft_num_hidden_layers, + int draft_hidden_size, + int draft_intermediate_size, + int draft_num_attention_heads, + int draft_num_key_value_heads, + int draft_head_dim, + float draft_rms_norm_eps, + int draft_group_size, + int num_iter, + bool draft_cuda_graph + ) { + this->model = model; + this->draft_model = new W4A16GPTQMarlinModelImpl<T>( + 0, + draft_vocab_size, + draft_num_hidden_layers, + draft_hidden_size, + draft_intermediate_size, + draft_num_attention_heads, + draft_num_key_value_heads, + draft_head_dim, + draft_rms_norm_eps, + draft_group_size, + this->model->chunk_length + ); + + this->num_iter = num_iter; + + this->draft_mask_2d = 0; + + + topk_func = new functions::TopK<T>(model->vocab_size, 1); // greedy sample + + this->draft_cuda_graph = draft_cuda_graph; + this->draft_graphCreated_padding_length = -1; + this->draft_graphCreated_input_length = -1; + this->draft_graph = nullptr; + this->draft_graphExec = nullptr; + } + + + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t lm_head_end = this->draft_model->init_output_ptr(memory, num_tokens, offset); + offset = lm_head_end; + + offset = memory->allocate((void**)&draft_input, offset, num_tokens * sizeof(int32_t)); + offset = memory->allocate((void**)&draft_position_ids, offset, num_tokens * sizeof(int32_t)); + offset = memory->allocate((void**)&draft_cache_length, offset, sizeof(int32_t)); + cudaMallocHost(&host_draft_cache_length, sizeof(int32_t)); + + + offset = memory->allocate((void**)&draft_logits, offset, 64 * this->draft_model->vocab_size * sizeof(T)); + offset = topk_func->init_output_ptr(memory, 1, offset); + + offset = memory->allocate((void**)&draft_tmp, offset, 16*sizeof(int32_t)); + offset = memory->allocate((void**)&d_best, offset, sizeof(int32_t)); + cudaMallocHost(&h_best, sizeof(int32_t)); + return offset; + } + + int init_storage() { + this->model->init_weight_ptr(this->model->memory); + // this->init_weight_ptr(this->model->memory); + this->draft_model->init_weight_ptr(this->model->memory); + + int64_t offset = this->model->init_output_ptr(this->model->memory, this->model->chunk_length, this->model->memory->model_offset); + int64_t kv_cache_offset = init_output_ptr(this->model->memory, this->model->chunk_length, offset); + + int model_kv_size = (this->model->num_hidden_layers*this->model->num_key_value_heads*this->model->head_dim); + int draft_kv_size = (this->draft_model->num_hidden_layers*this->draft_model->num_key_value_heads*this->draft_model->head_dim); + float ratio = float(model_kv_size)/float(model_kv_size + draft_kv_size); + kv_cache_offset = this->model->kv_caches->init_output_ptr(this->model->memory, kv_cache_offset, ratio); + this->draft_model->kv_caches->init_output_ptr(this->model->memory, kv_cache_offset); + return min(this->draft_model->kv_caches->budget, this->model->kv_caches->budget); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.substr(0, 5) == "draft"){ + std::string draft_name = name.substr(6); + this->draft_model->load_to_storage(draft_name, ptr); + } else { + this->model->load_to_storage(name, ptr); + } + } + + + + void draft_decode_with_graph_control(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + if (this->draft_cuda_graph) { + if (this->draft_graphCreated_padding_length != padded_length || this->draft_graphCreated_input_length != num_tokens) { + if (this->draft_graphExec != nullptr) { + cudaGraphExecDestroy(this->draft_graphExec); + this->draft_graphExec = nullptr; + } + if (this->draft_graph != nullptr) { + cudaGraphDestroy(this->draft_graph); + this->draft_graph = nullptr; + } + cudaStreamBeginCapture(calc_stream.stream, cudaStreamCaptureModeGlobal); + // this->draft_decode(num_tokens, padded_length, output); + this->draft_model->decode(num_tokens, padded_length, input, position_ids, cache_length, mask_2d, output); + cudaStreamEndCapture(calc_stream.stream, &(this->draft_graph)); + cudaGraphInstantiate(&(this->draft_graphExec), this->draft_graph, nullptr, nullptr, 0); + this->draft_graphCreated_padding_length = padded_length; + this->draft_graphCreated_input_length = num_tokens; + } + cudaGraphLaunch(this->draft_graphExec, calc_stream.stream); + } else { + // this->draft_decode(num_tokens, padded_length, output); + this->draft_model->decode(num_tokens, padded_length, input, position_ids, cache_length, mask_2d, output); + } + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, int32_t* input, int32_t* position_ids, void* output) { + this->model->prefill(num_tokens, num_history_tokens, input, position_ids, output); + if (num_history_tokens > 0) { + this->draft_model->prefill(this->num_prev, this->num_history_tokens, this->draft_input, this->draft_position_ids, this->draft_logits); + } + + cudaMemcpy(this->draft_input, input, num_tokens * sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_position_ids, position_ids, num_tokens * sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->num_prev = num_tokens; + this->num_history_tokens = num_history_tokens; + this->is_first_draft = true; + } + + void decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + this->model->decode(num_tokens, padded_length, input, position_ids, cache_length, nullptr, output); + } + + void draft(int32_t *tree_draft_ids, int32_t *tree_position_ids, int32_t *cache_length, uint64_t*, int32_t*) { + if (this->is_first_draft) { + // append tree draft ids to draft input + cudaMemcpy(this->draft_input+this->num_prev, tree_draft_ids, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_position_ids+this->num_prev, tree_position_ids, sizeof(int32_t), cudaMemcpyDeviceToDevice); + this->num_prev += 1; + this->draft_model->prefill(this->num_prev, this->num_history_tokens, this->draft_input, this->draft_position_ids, (void*)this->draft_logits); + + + cudaMemcpy(this->draft_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_cache_length, 1); + cudaMemcpy(this->draft_position_ids, tree_position_ids, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->host_draft_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + // this->draft_padded_length = (this->host_draft_cache_length[0]+ 128 -1) / 128*128; + this->topk_func->prefill(calc_stream, 1, this->draft_logits); + } else if (this->num_prev == 2){ + // this->draft_decode(this->num_prev, this->draft_padded_length, this->draft_logits); + this->draft_model->decode(this->num_prev, this->draft_padded_length, this->draft_input, this->draft_position_ids, this->draft_cache_length, nullptr, (void*)this->draft_logits); + this->topk_func->prefill(calc_stream, 1, this->draft_logits+(this->draft_model->vocab_size)); + add(calc_stream, 1, this->draft_position_ids, 1); + } else { + // num_prev == 1 + this->draft_decode_with_graph_control(this->num_prev, this->draft_padded_length, this->draft_input, this->draft_position_ids, this->draft_cache_length, nullptr, (void*)this->draft_logits); + this->topk_func->prefill(calc_stream, 1, this->draft_logits); + } + + cudaMemcpy(this->draft_input, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_tmp, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + + + for (int d = 1; d < this->num_iter; ++d){ + add(calc_stream, 1, this->draft_cache_length, 1); + add(calc_stream, 1, this->draft_position_ids, 1); + + this->host_draft_cache_length[0] += 1; + this->draft_padded_length = (this->host_draft_cache_length[0]+ 128 -1) / 128*128;; + this->draft_decode_with_graph_control(1, this->draft_padded_length, this->draft_input, this->draft_position_ids, this->draft_cache_length, nullptr, (void*)this->draft_logits); + this->topk_func->prefill(calc_stream, 1, this->draft_logits); + cudaMemcpy(this->draft_input, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_tmp + d, this->topk_func->topk_pos, sizeof(int32_t), cudaMemcpyDeviceToDevice); + } + + cudaMemcpy(tree_draft_ids + 1, this->draft_tmp, num_iter*sizeof(int32_t), cudaMemcpyDeviceToDevice); + make_arange(calc_stream, this->num_iter+1, cache_length, tree_position_ids); + this->is_first_draft = false; + } + + int verify(int32_t num_tokens, int32_t* pred, int32_t* gt, int32_t* position_ids, int32_t* cache_length, uint64_t* attn_mask, int32_t* tree_parent) { + verify_seq_draft(calc_stream, num_tokens, pred, gt, (uint16_t*)attn_mask, this->d_best); + cudaMemcpyAsync(this->h_best, this->d_best, 1 * sizeof(int32_t), cudaMemcpyDeviceToHost, calc_stream.stream); + cudaStreamSynchronize(calc_stream.stream); + + if (h_best[0]==(num_iter+1)) { + // full accept + this->num_prev = 2; + cudaMemcpy(this->draft_input, gt + (num_iter-1), 2*sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_cache_length, this->h_best[0]+1); + make_arange(calc_stream, 2, cache_length, this->draft_position_ids); + add(calc_stream, 2, this->draft_position_ids, num_iter); + cudaMemcpy(this->host_draft_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->draft_padded_length = (this->host_draft_cache_length[0]+ 128 -1) / 128*128; + } else { + this->num_prev = 1; + cudaMemcpy(this->draft_input, gt + (this->h_best[0]-1), sizeof(int32_t), cudaMemcpyDeviceToDevice); + cudaMemcpy(this->draft_cache_length, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_cache_length, this->h_best[0]+1); + cudaMemcpy(this->draft_position_ids, cache_length, sizeof(int32_t), cudaMemcpyDeviceToDevice); + add(calc_stream, 1, this->draft_position_ids, this->h_best[0]); + cudaMemcpy(this->host_draft_cache_length, this->draft_cache_length, sizeof(int32_t), cudaMemcpyDeviceToHost); + this->draft_padded_length = (this->host_draft_cache_length[0]+ 128 -1) / 128*128; + } + + return h_best[0]; + + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/topk.cuh b/examples/CPM.cu/src/model/topk.cuh new file mode 100644 index 00000000..b623380a --- /dev/null +++ b/examples/CPM.cu/src/model/topk.cuh @@ -0,0 +1,292 @@ +#pragma once +#include "../utils.cuh" +#include "../trait.cuh" +namespace functions { +namespace { +template<typename T, int N> +static __device__ inline void warpBitonicSort(T& v1, int& pos, bool asc) { + int lane_id = threadIdx.x & (N - 1); + #pragma unroll + for (int k = 2; k <= N; k *= 2) { + bool desc = ((lane_id & k) == 0) ^ asc; + #pragma unroll + for (int j = k / 2; j > 0; j /= 2) { + T v2 = __shfl_xor_sync(0xFFFFFFFF, v1, j); + int pos2 = __shfl_xor_sync(0xFFFFFFFF, pos, j); + bool upper = (lane_id & j) != 0; + if (desc ^ (v1 > v2 || (v1 == v2 && pos < pos2)) ^ upper) { + v1 = v2; + pos = pos2; + } + } + } +} +template<typename T, int N> +static __device__ inline void warpBitonicMerge(T& v1, int& pos1, T& v2, int& pos2) { + if (v1 < v2 || (v1 == v2 && pos1 > pos2)) { + v1 = v2; + pos1 = pos2; + } + int lane_id = threadIdx.x & (N - 1); + // resort + #pragma unroll + for (int j = N / 2; j > 0; j /= 2) { + v2 = __shfl_xor_sync(0xFFFFFFFF, v1, j); + int pos2 = __shfl_xor_sync(0xFFFFFFFF, pos1, j); + bool upper = (lane_id & j) != 0; + if ((v1 < v2 || (v1 == v2 && pos1 > pos2)) ^ upper) { + v1 = v2; + pos1 = pos2; + } + } +} +template<typename T, int N> +static __device__ inline void blockBitonicReduce(T& v, int& pos) { + __shared__ T shared_val[1024]; + __shared__ int shared_pos[1024]; + // block reduce + shared_val[threadIdx.x] = v; + shared_pos[threadIdx.x] = pos; + // inter warp reduce + #pragma unroll + for (int i = 512; i >= 32; i >>= 1) { + if (blockDim.x > i) { + __syncthreads(); + if (threadIdx.x < i) { + int idx_next = (i << 1) - threadIdx.x - 1; + T nw_v = (idx_next < blockDim.x) ? shared_val[idx_next] : T(-TypeTraits<T>::inf()); + int nw_pos = (idx_next < blockDim.x) ? shared_pos[idx_next] : -1; + warpBitonicMerge<T, N>(v, pos, nw_v, nw_pos); // merge and rebuild in desc order + shared_val[threadIdx.x] = v; + shared_pos[threadIdx.x] = pos; + } + } + } + // intra warp reduce + if (threadIdx.x < 32) { + warpBitonicSort<T, 32>(v, pos, false); + } +} +template<typename T, int N> +static __global__ void kernel_bitonic_topk( + int n, int top, + T *inp, // (batch, n) + float *out, // (batch, top) + int *idx // (batch, top) +) { + int offset_inp = blockIdx.x * n; + int offset_out = blockIdx.x * top; + T local_v = threadIdx.x < n ? inp[offset_inp + threadIdx.x] : -TypeTraits<T>::inf(); + int local_pos = threadIdx.x; + warpBitonicSort<T, N>(local_v, local_pos, false); // local sort in desc order + for (int i = blockDim.x; i < n; i += blockDim.x) { + T nw_v = (i + threadIdx.x) < n ? inp[offset_inp + i + threadIdx.x] : -TypeTraits<T>::inf(); + int nw_pos = i + threadIdx.x; + // step.1: local sort + warpBitonicSort<T, N>(nw_v, nw_pos, true); // local sort in asc order + // step.2&3: merge and rebuild + warpBitonicMerge<T, N>(local_v, local_pos, nw_v, nw_pos); // merge and rebuild in desc order + } + blockBitonicReduce<T, N>(local_v, local_pos); + if (threadIdx.x < top) { + out[offset_out + threadIdx.x] = local_v; + idx[offset_out + threadIdx.x] = local_pos; + } +} +// intra-block topk +// gridDim(batch, n / 1024, 1), threadDim(1024, 1, 1) +template<typename T, int N, bool ordered> +static __global__ void kernel_bitonic_topk_multiblock( + int n, + const T *inp, // (batch, n) + const int *idx_inp, // (batch, n) + T *out, // (batch, n / 1024 * N) + int *idx // (batch, n / 1024 * N) +) { + int offset_col = blockIdx.y * blockDim.x + threadIdx.x; + int offset_inp = blockIdx.x * n + offset_col; + int offset_out = blockIdx.x * (gridDim.y * N) + blockIdx.y * N + threadIdx.x; + T local_v = (offset_col < n) ? inp[offset_inp] : T(-TypeTraits<T>::inf()); + int local_pos = (idx_inp == nullptr) ? offset_col : ((offset_col < n) ? idx_inp[offset_inp] : -1); + if (!ordered) warpBitonicSort<T, N>(local_v, local_pos, false); // local sort in desc order + blockBitonicReduce<T, N>(local_v, local_pos); + if (threadIdx.x < N) { + out[offset_out] = local_v; + idx[offset_out] = local_pos; + } +} +// copy kernel +// gridDim(batch, 1, 1), blockDim(top, 1, 1) +template<typename T> +static __global__ void kernel_bitonic_topk_multiblock_copy ( + int n, int top, + const T *inp, // (batch, n) + const int *idx_inp, // (batch, n) + T *out, // (batch, top) + int *idx // (batch, top) +) { + int offset_inp = blockIdx.x * n + threadIdx.x; + int offset_out = blockIdx.x * top + threadIdx.x; + if (threadIdx.x < top) { + out[offset_out] = inp[offset_inp]; + idx[offset_out] = idx_inp[offset_inp]; + } +} +#define TOPK_SIZE_DISPATCH(top, ...) \ + do { \ + const int &top_v = top; \ + if (top_v > 16) { \ + const int top_size = 32; \ + __VA_ARGS__ \ + } else if (top_v > 8) { \ + const int top_size = 16; \ + __VA_ARGS__ \ + } else if (top_v > 4) { \ + const int top_size = 8; \ + __VA_ARGS__ \ + } else if (top_v > 2) { \ + const int top_size = 4; \ + __VA_ARGS__ \ + } else if (top_v > 1) { \ + const int top_size = 2; \ + __VA_ARGS__ \ + } else { \ + const int top_size = 1; \ + __VA_ARGS__ \ + } \ + } while(0) +template <typename T> +void bitonic_topk( + const Stream& stream, + const int batch, + const int n, + const int top, + const T* x, + T* out, + int* pos, + T* buf_val, + int* buf_pos, + T* nw_buf_val, + int* nw_buf_pos +) { + TOPK_SIZE_DISPATCH(top, { + bool first = true; + dim3 blockDim(1024, 1, 1); + unsigned int tmp_n = n; + do { + dim3 gridDim(batch, CEIL_DIV(tmp_n, 1024), 1); + if (first) { + first = false; + kernel_bitonic_topk_multiblock<T, top_size, false><<<gridDim, blockDim, 0, stream.stream>>>( + tmp_n, + x, + nullptr, + buf_val, + buf_pos + ); + } else { + kernel_bitonic_topk_multiblock<T, top_size, false><<<gridDim, blockDim, 0, stream.stream>>>( + tmp_n, + buf_val, + buf_pos, + nw_buf_val, + nw_buf_pos + ); + buf_val = nw_buf_val; + buf_pos = nw_buf_pos; + } + tmp_n = CEIL_DIV(tmp_n, 1024) * top_size; + } while (tmp_n > top_size); + // copy to output tensor + { + dim3 gridDim(batch, 1, 1); + blockDim = dim3(top_size, 1, 1); + kernel_bitonic_topk_multiblock_copy<T><<<gridDim, blockDim, 0, stream.stream>>>( + top_size, top, + buf_val, + buf_pos, + out, + pos + ); + } + }); +} + +template<typename T> +static __global__ void set_topk_to_neg_inf_kernel(int dim, int top, T* x, const int* topk_pos) { + x[blockIdx.x * dim + topk_pos[blockIdx.x * top + threadIdx.x]] = -TypeTraits<T>::inf(); +} +} // namespace + +template<typename T> +void set_topk_to_neg_inf(const Stream& stream, int num_tokens, int dim, int top, int num, T* x, const int* topk_pos) { + set_topk_to_neg_inf_kernel<<<num_tokens, num, 0, stream.stream>>>(dim, top, x, topk_pos); +} + +template <typename T> +struct TopK { +private: + T *buf_val, *nw_buf_val; + int *buf_pos, *nw_buf_pos; +public: + int dim, top; + T* topk_val; + int* topk_pos; + T* tmp_x; + + TopK(const int dim, const int top) { + this->dim = dim; + this->top = top; + } + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + TOPK_SIZE_DISPATCH(top, { + offset = memory->allocate((void**)&buf_val, offset, num_tokens * CEIL_DIV(dim, 1024) * top_size * sizeof(T)); + offset = memory->allocate((void**)&buf_pos, offset, num_tokens * CEIL_DIV(dim, 1024) * top_size * sizeof(int)); + offset = memory->allocate((void**)&nw_buf_val, offset, num_tokens * CEIL_DIV(dim, 1024) * top_size * sizeof(T)); + offset = memory->allocate((void**)&nw_buf_pos, offset, num_tokens * CEIL_DIV(dim, 1024) * top_size * sizeof(int)); + }); + if (top > 32) { // tmp fix + offset = memory->allocate((void**)&tmp_x, offset, num_tokens * dim * sizeof(T)); + } + offset = memory->allocate((void**)&topk_val, offset, num_tokens * top * sizeof(T)); + offset = memory->allocate((void**)&topk_pos, offset, num_tokens * top * sizeof(int)); + return offset; + } + void prefill( + const Stream& stream, + int num_tokens, + const T* input, + int dim = -1, + int top = -1 + ) { + assert(dim == -1 || dim <= this->dim); + assert(top == -1 || top <= this->top); + if (dim == -1) dim = this->dim; + if (top == -1) top = this->top; + bitonic_topk<T>( + stream, + num_tokens, + dim, top, + input, + this->topk_val, this->topk_pos, + this->buf_val, this->buf_pos, + this->nw_buf_val, this->nw_buf_pos + ); + if (top > 32) { + for (int i = 32; i < top; i += 32) { + cudaMemcpyAsync(this->tmp_x, input, num_tokens * dim * sizeof(T), cudaMemcpyDeviceToDevice, stream.stream); + set_topk_to_neg_inf(stream, num_tokens, dim, top, 32, this->tmp_x, this->topk_pos + (i - 32)); + bitonic_topk<T>( + stream, + num_tokens, + dim, top, + this->tmp_x, + this->topk_val + i, this->topk_pos + i, + this->buf_val, this->buf_pos, + this->nw_buf_val, this->nw_buf_pos + ); + } + } + } +}; +} // namespace functions \ No newline at end of file diff --git a/examples/CPM.cu/src/model/tree_drafter.cuh b/examples/CPM.cu/src/model/tree_drafter.cuh new file mode 100644 index 00000000..4e5e6379 --- /dev/null +++ b/examples/CPM.cu/src/model/tree_drafter.cuh @@ -0,0 +1,111 @@ +#pragma once +#include <cuda_runtime.h> + +namespace { +__global__ void verify_kernel(int num_tokens, int32_t* pred, const int32_t* gt, const int32_t* position_ids, const int32_t* cache_length, const uint64_t* attn_mask, const int32_t* tree_parent, int32_t* d_best) { + int i = threadIdx.x; + + __shared__ uint64_t s_correct_mask[2]; + uint64_t correct_mask = 1; + if (0 < i && i < num_tokens && pred[i] == gt[tree_parent[i]]) correct_mask |= 1ULL << i; + correct_mask |= __shfl_down_sync(0xffffffff, correct_mask, 16); + correct_mask |= __shfl_down_sync(0xffffffff, correct_mask, 8); + correct_mask |= __shfl_down_sync(0xffffffff, correct_mask, 4); + correct_mask |= __shfl_down_sync(0xffffffff, correct_mask, 2); + correct_mask |= __shfl_down_sync(0xffffffff, correct_mask, 1); + if (i % 32 == 0) s_correct_mask[i / 32] = correct_mask; + __syncthreads(); + if (i == 0) s_correct_mask[0] |= s_correct_mask[1]; + __syncthreads(); + correct_mask = s_correct_mask[0]; + + __shared__ int32_t mx[64], mx_idx[64]; + int prefix_length = cache_length[0]; + if (i < num_tokens && ((correct_mask & attn_mask[i]) == attn_mask[i])) { + mx[i] = position_ids[i] - prefix_length + 1; mx_idx[i] = i; + } else { + mx[i] = 1; mx_idx[i] = 0; + } + __syncthreads(); + for (int offset = 32; offset > 0; offset >>= 1) { + if (i < offset && mx[i + offset] > mx[i]) { + mx[i] = mx[i + offset]; + mx_idx[i] = mx_idx[i + offset]; + } + __syncthreads(); + } + if (i == 0) { + d_best[0] = mx[0]; d_best[1] = mx_idx[0]; + } + __syncthreads(); + + int p = mx_idx[0]; + if (i < num_tokens && (attn_mask[p] >> i & 1)) { + pred[position_ids[i] - prefix_length] = i; + } +} + +template<typename T> +__global__ void fix_kvcache_kernel_1(int num_caches, int dim, int32_t* pred, const int32_t* gt, const int32_t* cache_length, const T* const* flat_caches, float4* tmp_kvcache) { + int i = blockIdx.x; + int j = threadIdx.x; + int k = blockIdx.y; + int prefix_length = cache_length[0]; + int real_i = pred[i] + prefix_length; + float4* tmp = tmp_kvcache + i * num_caches * dim; + const float4* flat = (const float4*)flat_caches[k]; + for (int d = j; d < dim; d += blockDim.x) { + tmp[k * dim + d] = flat[real_i * dim + d]; + } +} + +template<typename T> +__global__ void fix_kvcache_kernel_2(int num_caches, int dim, int32_t* pred, const int32_t* gt, const int32_t* cache_length, T** flat_caches, const float4* tmp_kvcache) { + int i = blockIdx.x; + int j = threadIdx.x; + int k = blockIdx.y; + int prefix_length = cache_length[0]; + int real_i = i + prefix_length; + const float4* tmp = tmp_kvcache + i * num_caches * dim; + float4* flat = (float4*)flat_caches[k]; + for (int d = j; d < dim; d += blockDim.x) { + flat[real_i * dim + d] = tmp[k * dim + d]; + } + if (j == 0 && k == 0) { + pred[i] = gt[pred[i]]; + } +} + +template<typename T> +__global__ void remap_copy_kernel(int32_t dim, const T* src, T* dst, const int32_t* token_id_remap) { + int row = blockIdx.x; + int real_row = token_id_remap[row]; + for (int i = threadIdx.x; i < dim; i += blockDim.x) { + dst[row * dim + i] = src[real_row * dim + i]; + } +} + +__global__ void remap_kernel(int32_t num_tokens, const int32_t* input, int32_t* output, const int32_t* token_id_remap) { + output[threadIdx.x] = token_id_remap[input[threadIdx.x]]; +} +} + +void verify_draft(const Stream& stream, int num_tokens, int32_t* pred, const int32_t* gt, const int32_t* position_ids, const int32_t* cache_length, const uint64_t* attn_mask, const int32_t* tree_parent, int32_t* best) { + verify_kernel<<<1, 64, 0, stream.stream>>>(num_tokens, pred, gt, position_ids, cache_length, attn_mask, tree_parent, best); +} + +template<typename T> +void fix_kv_cache(const Stream& stream, int accept_length, int num_caches, int dim, int32_t* pred, const int32_t* gt, const int32_t* cache_length, T** flat_caches, T* tmp_kvcache) { + fix_kvcache_kernel_1<T><<<dim3(accept_length, num_caches, 1), 256, 0, stream.stream>>>(num_caches, dim/(16/sizeof(T)), pred, gt, cache_length, flat_caches, (float4*)tmp_kvcache); + fix_kvcache_kernel_2<T><<<dim3(accept_length, num_caches, 1), 256, 0, stream.stream>>>(num_caches, dim/(16/sizeof(T)), pred, gt, cache_length, flat_caches, (float4*)tmp_kvcache); +} + +template<typename T> +void remap_copy(const Stream& stream, const T* src, T* dst, int32_t dim, int32_t num_tokens, const int32_t* token_id_remap) { + dim = dim / (16 / sizeof(T)); + remap_copy_kernel<<<num_tokens, 512, 0, stream.stream>>>(dim, (float4*)src, (float4*)dst, token_id_remap); +} + +void remap(const Stream& stream, int32_t num_tokens, const int32_t* input, int32_t* output, const int32_t* token_id_remap) { + remap_kernel<<<1, num_tokens, 0, stream.stream>>>(num_tokens, input, output, token_id_remap); +} \ No newline at end of file diff --git a/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_attn.cuh b/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_attn.cuh new file mode 100644 index 00000000..6b8a78bd --- /dev/null +++ b/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_attn.cuh @@ -0,0 +1,178 @@ +#pragma once +#include "../attn.cuh" +#include "w4a16_gptq_marlin_linear.cuh" + + +template <typename T> +struct W4A16GPTQMarlinAttention { + int hidden_size; + int num_attention_heads; + int num_key_value_heads; + int head_dim; + float rms_norm_eps; + + Norm<T> *attn_norm; + W4A16GPTQMarlinLinear<T> *qkv_proj; + W4A16GPTQMarlinLinear<T> *o_proj; + T* output; + + T* attn_output; + float *softmax_lse, *softmax_lse_accum, *oaccum; + + int window_size; + + T* q_proj_output, *v_proj_output, *k_proj_output; + T* permute_qkv_output; + + W4A16GPTQMarlinAttention(int hidden_size, int num_attention_heads, int num_key_value_heads, int head_dim, float rms_norm_eps, int group_size, int window_size = 0) { + this->hidden_size = hidden_size; + this->num_attention_heads = num_attention_heads; + this->num_key_value_heads = num_key_value_heads; + this->head_dim = head_dim; + this->rms_norm_eps = rms_norm_eps; + + this->attn_norm = new RMSNorm<T>(hidden_size, rms_norm_eps); + + this->qkv_proj = new W4A16GPTQMarlinLinear<T>(hidden_size, (num_attention_heads + 2*num_key_value_heads) * head_dim, group_size); + + this->o_proj = new W4A16GPTQMarlinLinear<T>(hidden_size, num_attention_heads * head_dim, group_size); + + this->window_size = window_size; + } + + void init_weight_ptr(Memory* memory) { + this->attn_norm->init_weight_ptr(memory); + this->qkv_proj->init_weight_ptr(memory); + this->o_proj->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t attn_norm_end = this->attn_norm->init_output_ptr(memory, num_tokens, offset); + int64_t qkv_proj_end = this->qkv_proj->init_output_ptr(memory, num_tokens, attn_norm_end); + + this->q_proj_output = this->qkv_proj->output; + this->k_proj_output = this->qkv_proj->output + num_tokens * this->num_attention_heads * this->head_dim; + this->v_proj_output = this->qkv_proj->output + num_tokens * (this->num_attention_heads+this->num_key_value_heads) * this->head_dim; + int64_t qkv_permute_end = memory->allocate((void**)&this->permute_qkv_output, qkv_proj_end, num_tokens * (this->num_attention_heads + 2*this->num_key_value_heads) * this->head_dim * sizeof(T)); + + int64_t attn_output_end = memory->allocate((void**)&this->attn_output, offset, num_tokens * this->num_attention_heads * this->head_dim * sizeof(T)); + int64_t softmax_lse_end = memory->allocate((void**)&this->softmax_lse, qkv_permute_end, num_tokens * this->num_attention_heads * sizeof(float)); + const int max_num_splits = 128; // Maximum number of splits for attention computation + const int max_spec_tree_size = 64; // Maximum size of speculative decoding tree + int64_t softmax_lse_accum_end = memory->allocate((void**)&this->softmax_lse_accum, softmax_lse_end, max(max_num_splits * max_spec_tree_size, num_tokens) * this->num_attention_heads * sizeof(float)); + int64_t oaccum_end = memory->allocate((void**)&this->oaccum, softmax_lse_accum_end, max(max_num_splits * max_spec_tree_size, num_tokens) * this->num_attention_heads * this->head_dim * sizeof(float)); + + int64_t o_proj_end = this->o_proj->init_output_ptr(memory, num_tokens, qkv_permute_end); + this->output = this->o_proj->output; + + return std::max(oaccum_end, o_proj_end); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("qkv_proj") != std::string::npos) { + this->qkv_proj->load_to_storage(name, ptr); + } else if (name.find("o_proj") != std::string::npos) { + this->o_proj->load_to_storage(name, ptr); + } else if (name.find("input_layernorm") != std::string::npos) { + this->attn_norm->load_to_storage(name, ptr); + } else { + throw std::invalid_argument("Attn Unsupported name " + name); + } + } + + void prefill(const Stream& stream, int32_t num_tokens, int32_t num_history_tokens, T* input, T* prev_output, int32_t* position_ids, KVCache<T>* kv_cache, T* a_tmp, float* c_tmp) { + T* k_cache = kv_cache->offset_k(num_history_tokens); + T* v_cache = kv_cache->offset_v(num_history_tokens); + + this->attn_norm->prefill(stream, num_tokens, input, prev_output); + this->qkv_proj->prefill(stream, num_tokens, this->attn_norm->output, a_tmp, c_tmp); + permute(stream, num_tokens, this->num_attention_heads * this->head_dim, this->num_key_value_heads * this->head_dim, this->qkv_proj->output, this->permute_qkv_output); + cudaMemcpy(k_cache, this->permute_qkv_output + num_tokens*this->num_attention_heads*this->head_dim, num_tokens*this->num_key_value_heads*this->head_dim*sizeof(T), cudaMemcpyDeviceToDevice); + cudaMemcpy(v_cache, this->permute_qkv_output + num_tokens*( this->num_attention_heads + this->num_key_value_heads)*this->head_dim, num_tokens*this->num_key_value_heads*this->head_dim*sizeof(T), cudaMemcpyDeviceToDevice); + + + kv_cache->rotary_embedding->prefill(stream, num_tokens, this->num_attention_heads, this->num_key_value_heads, this->permute_qkv_output, k_cache, position_ids); + + cuda_perf_start_on_stream_f(Q_PREFILL_ATTN_CORE, stream.stream); + mha_fwd_kvcache( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + num_history_tokens+num_tokens, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + this->permute_qkv_output, + kv_cache->k_cache, + kv_cache->v_cache, + nullptr, + Mask(nullptr), + this->attn_output, + this->softmax_lse, + this->softmax_lse_accum, + this->oaccum, + rsqrtf(float(this->head_dim)), + true, + -1, + -1, + 0, + stream.stream, + nullptr, + this->window_size + ); + cuda_perf_stop_on_stream_f(Q_PREFILL_ATTN_CORE, stream.stream); + + // flash attention and put output to attn_norm->output + this->o_proj->prefill(stream, num_tokens, this->attn_output, a_tmp, c_tmp); + } + + void decode(const Stream& stream, int32_t num_tokens, int32_t padded_length, T* input, T* prev_output, int32_t* position_ids, int32_t* cache_length, const Mask& mask, KVCache<T>* kv_cache, T* a_tmp, float* c_tmp) { + this->attn_norm->prefill(stream, num_tokens, input, prev_output); + T *q, *k, *v; + + if (num_tokens > 1) { + this->qkv_proj->prefill(stream, num_tokens, this->attn_norm->output, a_tmp, c_tmp); + permute(stream, num_tokens, this->num_attention_heads * this->head_dim, this->num_key_value_heads * this->head_dim, this->qkv_proj->output, this->permute_qkv_output); // TODO: Double check + q = this->permute_qkv_output; + } else { + this->qkv_proj->prefill(stream, num_tokens, this->attn_norm->output, a_tmp, c_tmp); + q = this->qkv_proj->output; + } + k = q + num_tokens * this->num_attention_heads * this->head_dim; + v = k + num_tokens * this->num_key_value_heads * this->head_dim; + kv_cache->rotary_embedding->prefill(stream, num_tokens, this->num_attention_heads, this->num_key_value_heads, q, k, position_ids); + + copy_to_kvcache(stream, num_tokens, k, v, kv_cache, cache_length); + + cuda_perf_start_on_stream_f(Q_DECODE_ATTN_CORE, stream.stream); + mha_fwd_kvcache( + TypeTraits<T>::type_code()==1, + 1, + num_tokens, + padded_length, + this->num_attention_heads, + this->num_key_value_heads, + this->head_dim, + q, + kv_cache->k_cache, + kv_cache->v_cache, + cache_length, + mask, + this->attn_output, + this->softmax_lse, + this->softmax_lse_accum, + this->oaccum, + rsqrtf(float(this->head_dim)), + true, + -1, + -1, + 0, + stream.stream, + nullptr, + this->window_size + ); + cuda_perf_stop_on_stream_f(Q_DECODE_ATTN_CORE, stream.stream); + // flash attention and put output to attn_norm->output + this->o_proj->prefill(stream, num_tokens, this->attn_output, a_tmp, c_tmp); + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_ffn.cuh b/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_ffn.cuh new file mode 100644 index 00000000..f5e0d06c --- /dev/null +++ b/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_ffn.cuh @@ -0,0 +1,70 @@ +#pragma once +#include "../ffn.cuh" +#include "w4a16_gptq_marlin_linear.cuh" +#include "../activation.cuh" +#include <cuda_runtime.h> + + +template <typename T> +struct W4A16GPTQMarlinGatedFFN { + int hidden_size; + int intermediate_size; + float rms_norm_eps; + + Norm<T> *ffn_norm; + W4A16GPTQMarlinLinear<T> *gate_up_proj; + W4A16GPTQMarlinLinear<T> *down_proj; + + T* output; + T* gated_up; + + W4A16GPTQMarlinGatedFFN(int hidden_size, int intermediate_size, float rms_norm_eps, int group_size) { + this->hidden_size = hidden_size; + this->intermediate_size = intermediate_size; + this->rms_norm_eps = rms_norm_eps; + + this->ffn_norm = new RMSNorm<T>(hidden_size, rms_norm_eps); + this->gate_up_proj = new W4A16GPTQMarlinLinear<T>(hidden_size, intermediate_size*2, group_size); + this->down_proj = new W4A16GPTQMarlinLinear<T>(intermediate_size, hidden_size, group_size); + } + + void init_weight_ptr(Memory* memory) { + this->ffn_norm->init_weight_ptr(memory); + this->gate_up_proj->init_weight_ptr(memory); + this->down_proj->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t ffn_norm_end = this->ffn_norm->init_output_ptr(memory, num_tokens, offset); + int64_t gate_up_proj_end = this->gate_up_proj->init_output_ptr(memory, num_tokens, ffn_norm_end); + int64_t gated_up_end = memory->allocate((void**)&this->gated_up, gate_up_proj_end, num_tokens * intermediate_size * sizeof(T)); + int64_t down_proj_end = this->down_proj->init_output_ptr(memory, num_tokens, gated_up_end); + this->output = this->down_proj->output; + return down_proj_end; + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("gate_up_proj") != std::string::npos) { + this->gate_up_proj->load_to_storage(name, ptr); + } else if (name.find("down_proj") != std::string::npos) { + this->down_proj->load_to_storage(name, ptr); + } else if (name.find("post_attention_layernorm") != std::string::npos) { + this->ffn_norm->load_to_storage(name, ptr); + } else { + throw std::invalid_argument("FFN Unsupported name " + name); + } + } + + void prefill(const Stream& stream, int32_t num_tokens, T* input, T* prev_output, T* a_tmp, float* c_tmp) { + this->ffn_norm->prefill(stream, num_tokens, input, prev_output); + + this->gate_up_proj->prefill(stream, num_tokens, this->ffn_norm->output, a_tmp, c_tmp); + gated_silu_interleaved<T>(stream, num_tokens, this->intermediate_size, this->gate_up_proj->output, this->gated_up); + this->down_proj->prefill(stream, num_tokens, this->gated_up, a_tmp, c_tmp); + + } + + void decode(const Stream& stream, int32_t num_tokens, T* input, T* prev_output, T* a_tmp, float* c_tmp) { + prefill(stream, num_tokens, input, prev_output, a_tmp, c_tmp); + } +}; diff --git a/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_layer.cuh b/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_layer.cuh new file mode 100644 index 00000000..5ac8c4c8 --- /dev/null +++ b/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_layer.cuh @@ -0,0 +1,98 @@ +#pragma once +#include "w4a16_gptq_marlin_attn.cuh" +#include "w4a16_gptq_marlin_ffn.cuh" +#include "../layer.cuh" +#include "../../qgemm/gptq_marlin/gptq_marlin.cuh" + +template <typename T> +struct W4A16GPTQMarlinLayer { + W4A16GPTQMarlinAttention<T> *attn; + W4A16GPTQMarlinGatedFFN<T> *ffn; + T* output; + + // marlin for gptq marlin + int intermediate_size; + T* a_tmp; + float* c_tmp; + float residual_scale; + int hidden_size; + + W4A16GPTQMarlinLayer(int hidden_size, int intermediate_size, int num_attention_heads, int num_key_value_heads, int head_dim, float rms_norm_eps, int group_size, float residual_scale = 1.0f, int window_size = 0) { + this->intermediate_size = intermediate_size; + this->attn = new W4A16GPTQMarlinAttention<T>(hidden_size, num_attention_heads, num_key_value_heads, head_dim, rms_norm_eps, group_size, window_size); + this->ffn = new W4A16GPTQMarlinGatedFFN<T>(hidden_size, intermediate_size, rms_norm_eps, group_size); + this->residual_scale = residual_scale; + this->hidden_size = hidden_size; + } + + void init_weight_ptr(Memory* memory) { + this->attn->init_weight_ptr(memory); + this->ffn->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t a_tmp_offset = memory->allocate((void**)&this->a_tmp, offset, 2* num_tokens * intermediate_size * sizeof(T)); + int reduce_max_m = marlin::determine_reduce_max_m(num_tokens, marlin::max_par); + int reduce_n = 2*intermediate_size; + int64_t c_tmp_offset = memory->allocate((void**)&this->c_tmp, a_tmp_offset, reduce_max_m * reduce_n * sizeof(float)); + + int64_t attn_end = this->attn->init_output_ptr(memory, num_tokens, c_tmp_offset); + int64_t ffn_end = this->ffn->init_output_ptr(memory, num_tokens, c_tmp_offset); + this->output = this->ffn->output; + return std::max(attn_end, ffn_end); + } + + void load_to_storage(std::string name, void* ptr) { + if (name.find("attn") != std::string::npos || name.find("input_layernorm") != std::string::npos) { + this->attn->load_to_storage(name, ptr); + } else if (name.find("mlp") != std::string::npos || name.find("post_attention_layernorm") != std::string::npos) { + this->ffn->load_to_storage(name, ptr); + } else { + throw std::invalid_argument("Layer Unsupported name " + name); + } + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, T* input, T* prev_output, int32_t* position_ids, KVCache<T>* kv_cache, T* prev_layer_states=nullptr) { + if (prev_output != nullptr) { + elementwise_scale(calc_stream, num_tokens, this->hidden_size, prev_output, this->residual_scale); + } + cuda_perf_start_on_stream_f(Q_PREFILL_ATTN, calc_stream.stream); + this->attn->prefill(calc_stream, num_tokens, num_history_tokens, input, prev_output, position_ids, kv_cache, a_tmp, c_tmp); + cuda_perf_stop_on_stream_f(Q_PREFILL_ATTN, calc_stream.stream); + if (prev_layer_states != nullptr) { + cudaMemcpyAsync( + prev_layer_states, // dst + input, // src + num_tokens * this->attn->hidden_size * sizeof(T), + cudaMemcpyDeviceToDevice, + calc_stream.stream + ); + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, this->attn->output, this->residual_scale); + cuda_perf_start_on_stream_f(Q_PREFILL_FFN, calc_stream.stream); + this->ffn->prefill(calc_stream, num_tokens, input, this->attn->output, a_tmp, c_tmp); + cuda_perf_stop_on_stream_f(Q_PREFILL_FFN, calc_stream.stream); + } + + void decode(int32_t num_tokens, int32_t padded_length, T* input, T* prev_output, int32_t* position_ids, int32_t* cache_length, const Mask& mask, KVCache<T>* kv_cache, T* prev_layer_states=nullptr) { + if (prev_output != nullptr) { + elementwise_scale(calc_stream, num_tokens, this->hidden_size, prev_output, this->residual_scale); + } + cuda_perf_start_on_stream_f(Q_DECODE_ATTN, calc_stream.stream); + this->attn->decode(calc_stream, num_tokens, padded_length, input, prev_output, position_ids, cache_length, mask, kv_cache, a_tmp, c_tmp); + cuda_perf_stop_on_stream_f(Q_DECODE_ATTN, calc_stream.stream); + if (prev_layer_states != nullptr) { + cudaMemcpyAsync( + prev_layer_states, // dst + input, // src + num_tokens * this->attn->hidden_size * sizeof(T), + cudaMemcpyDeviceToDevice, + calc_stream.stream + ); + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, this->attn->output, this->residual_scale); + cuda_perf_start_on_stream_f(Q_DECODE_FFN, calc_stream.stream); + this->ffn->decode(calc_stream, num_tokens, input, this->attn->output, a_tmp, c_tmp); + cuda_perf_stop_on_stream_f(Q_DECODE_FFN, calc_stream.stream); + } +}; diff --git a/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_linear.cuh b/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_linear.cuh new file mode 100644 index 00000000..99a698d8 --- /dev/null +++ b/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_linear.cuh @@ -0,0 +1,143 @@ +#pragma once +#include <cuda_runtime.h> +#include <cublas_v2.h> +#include <cuda_fp16.h> +#include "../linear.cuh" +#include "../../qgemm/gptq_marlin/marlin.cuh" +#include "../../qgemm/gptq_marlin/gptq_marlin.cuh" +#include "../../qgemm/gptq_marlin/core/scalar_type.hpp" + +template <typename T, bool transposed=true, bool has_bias=false> +struct W4A16GPTQMarlinLinear { + int dim_in; + int dim_out; + + T* output; + int32_t* weight; + T* bias; + T* scales; + + // just placeholder + int32_t* qzeros; + int32_t* g_idx; + int32_t* perm; + int32_t* workspace; + + // new added + const vllm::ScalarType weight_scalar_dtype; + bool is_k_full; + bool use_fp32_reduce; // be true is better + + int num_groups; + int group_size; + bool has_act_order; + + + W4A16GPTQMarlinLinear(int dim_in, int dim_out, int group_size) + :weight_scalar_dtype(static_cast<uint8_t>(0), + static_cast<uint8_t>(4), + false, + static_cast<int32_t>(8), + false) // Initialize weight_scalar_dtype in the constructor + { + this->dim_in = dim_in; + this->dim_out = dim_out; + + // place holder + this->qzeros = 0; + this->g_idx = 0; + this->perm = 0; + + this->is_k_full = true; + this->use_fp32_reduce = true; + this->group_size = group_size; + if (this->group_size == 128){ + this->num_groups = (dim_in) / group_size; + } else if (this->group_size == -1){ + this->num_groups = 1; + } else { + throw std::invalid_argument("Unsupported group size"); + } + + this->has_act_order = false; + + } + + void init_weight_ptr(Memory* memory) { + const int w_size = this->dim_in * this->dim_out / 8; + weight = (int32_t*)memory->allocate_for_model(w_size*sizeof(int32_t)); + const int s_size = this->num_groups * this->dim_out ; + scales = (T*)memory->allocate_for_model(s_size * sizeof(T)); + + const int workspace_size = (this->dim_out / 64)*16; + workspace = (int32_t*)memory->allocate_for_model(workspace_size * sizeof(int32_t)); + + if constexpr (has_bias) { + bias = (T*)memory->allocate_for_model(dim_out * sizeof(T)); + } + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t output_offset = memory->allocate((void**)&this->output, offset, num_tokens * dim_out * sizeof(T)); + return output_offset; + } + + + void load_to_storage(std::string name, void* ptr) { + if (name.find("scales") != std::string::npos) { + const int s_size = this->num_groups * this->dim_out; + cudaMemcpy((void*)scales, ptr, s_size*sizeof(T), cudaMemcpyHostToDevice); + } else if (name.find("qweight") != std::string::npos) { + const int w_size = this->dim_in * this->dim_out / 8; + cudaMemcpy((void*)weight, ptr, w_size*sizeof(int32_t), cudaMemcpyHostToDevice); + } else if (name.find("bias") != std::string::npos) { + cudaMemcpy((void*)bias, ptr, dim_out * sizeof(T), cudaMemcpyHostToDevice); + } else { + throw std::invalid_argument("Linear Unsupported name " + name); + } + } + + void prefill(const Stream& stream, int32_t num_tokens, T* input, T* a_tmp, float* c_tmp, T* tgt=nullptr, bool inplace=false) { + T* tgt_temp; + if (tgt == nullptr) { + tgt_temp = this->output; + tgt = tgt_temp; + } else if (inplace && tgt) { + tgt_temp = this->output; + } + else if (!inplace && tgt) { + tgt_temp = tgt; + } + gptq_marlin_gemm<T>( + input, + weight, + scales, + qzeros, + g_idx, + perm, + workspace, + weight_scalar_dtype, + num_tokens, + dim_out, + dim_in, + is_k_full, + false, + use_fp32_reduce, + tgt_temp, + num_groups, + group_size, + 2*dim_out, + has_act_order, + stream.stream, + a_tmp, + c_tmp + ); + + if (inplace) { + elementwise_add<T>(stream, num_tokens, this->dim_out, tgt, tgt_temp, tgt); + } + if constexpr (has_bias) { + batched_add<T>(stream, num_tokens, this->dim_out, tgt, this->bias, tgt); + } + } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_model.cuh b/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_model.cuh new file mode 100644 index 00000000..5454de0d --- /dev/null +++ b/examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_model.cuh @@ -0,0 +1,157 @@ +#pragma once +#include "../model.cuh" +#include "w4a16_gptq_marlin_layer.cuh" + + +template <typename T> +struct W4A16GPTQMarlinModelImpl: Model { + Memory* memory; + + int vocab_size; + int num_hidden_layers; + int hidden_size; + int intermediate_size; + int num_attention_heads; + int num_key_value_heads; + int head_dim; + float rms_norm_eps; + + int chunk_length; + + KVCacheManager<T>* kv_caches; + + Embedding<T>* embedding; + std::vector<W4A16GPTQMarlinLayer<T>*> layers; + RMSNorm<T>* norm; + LMHead<T>* lm_head; + float residual_scale; + + + W4A16GPTQMarlinModelImpl( + float memory_limit, + int vocab_size, + int num_hidden_layers, + int hidden_size, + int intermediate_size, + int num_attention_heads, + int num_key_value_heads, + int head_dim, + float rms_norm_eps, + int group_size, + int chunk_length, + float scale_embed = 1.0f, + float scale_lmhead = 1.0f, + float scale_residual = 1.0f + ) { + this->vocab_size = vocab_size; + this->num_hidden_layers = num_hidden_layers; + this->hidden_size = hidden_size; + this->intermediate_size = intermediate_size; + this->num_attention_heads = num_attention_heads; + this->num_key_value_heads = num_key_value_heads; + this->head_dim = head_dim; + this->rms_norm_eps = rms_norm_eps; + + this->chunk_length = chunk_length; + this->residual_scale = scale_residual; + + memory = new Memory(memory_limit); + + kv_caches = new KVCacheManager<T>(num_hidden_layers, num_key_value_heads, head_dim); + + embedding = new Embedding<T>(vocab_size, hidden_size, scale_embed); + for (int i = 0; i < num_hidden_layers; i++) { + layers.push_back(new W4A16GPTQMarlinLayer<T>(hidden_size, intermediate_size, num_attention_heads, num_key_value_heads, head_dim, rms_norm_eps, group_size, residual_scale)); + } + norm = new RMSNorm<T>(hidden_size, rms_norm_eps); + lm_head = new LMHead<T>(hidden_size, vocab_size, scale_lmhead); + } + + void init_weight_ptr(Memory* memory) { + embedding->init_weight_ptr(memory); + for (int i = 0; i < num_hidden_layers; i++) { + layers[i]->init_weight_ptr(memory); + } + norm->init_weight_ptr(memory); + lm_head->init_weight_ptr(memory); + kv_caches->init_weight_ptr(memory); + } + + int64_t init_output_ptr(Memory* memory, int32_t num_tokens, int64_t offset) { + int64_t embedding_end = embedding->init_output_ptr(memory, num_tokens, offset); + int64_t layer_end = 0; + for (int i = 0; i < num_hidden_layers; i++) { + layer_end = layers[i]->init_output_ptr(memory, num_tokens, embedding_end); + } + // norm and lm_head are not used in prefill + int64_t norm_end = norm->init_output_ptr(memory, num_tokens, layer_end); + int64_t lm_head_end = lm_head->init_output_ptr(memory, 64, norm_end); + return lm_head_end; + } + + int init_storage() { + init_weight_ptr(memory); + int64_t kv_cache_offset = init_output_ptr(memory, chunk_length, memory->model_offset); + kv_cache_offset = kv_caches->init_output_ptr(memory, kv_cache_offset); + return this->kv_caches->budget; + } + + void load_to_storage(std::string name, void* ptr) { + if (name.substr(0, 18) == "model.embed_tokens") { + embedding->load_to_storage(name, ptr); + } else if (name.substr(0, 10) == "model.norm") { + norm->load_to_storage(name, ptr); + } else if (name.substr(0, 7) == "lm_head") { + lm_head->load_to_storage(name, ptr); + } else if (name.find("rotary_emb") != std::string::npos) { + kv_caches->rotary_embedding->load_to_storage(name, ptr); + } else if (name.substr(0, 12) == "model.layers") { // e.g. model.layers.20.attn.q_proj.weight + std::regex layer_regex("model\\.layers\\.(\\d+)\\.(.*)"); + std::smatch matches; + if (std::regex_search(name, matches, layer_regex)) { + int layer_idx = std::stoi(matches[1]); + layers[layer_idx]->load_to_storage(matches[2], ptr); + } else { + throw std::invalid_argument("Model Layer Unsupported name (layer_idx not found): " + name); + } + } else { + throw std::invalid_argument("Model Unsupported name " + name); + } + } + + void prefill_embed(int32_t num_tokens, int32_t num_history_tokens, T* embed, int32_t* position_ids, void* output) { + T* layer_output = nullptr; + for (int i = 0; i < num_hidden_layers; i++) { + this->layers[i]->prefill(num_tokens, num_history_tokens, embed, layer_output, position_ids, this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, layer_output, this->residual_scale); + this->norm->prefill(calc_stream, num_tokens, embed, layer_output); + this->lm_head->prefill(calc_stream, 1, this->norm->output + (num_tokens - 1) * hidden_size, (T*)output); + } + + void prefill(int32_t num_tokens, int32_t num_history_tokens, int32_t* input, int32_t* position_ids, void* output) { + this->embedding->prefill(calc_stream, num_tokens, input); + prefill_embed(num_tokens, num_history_tokens, this->embedding->output, position_ids, output); + } + + void decode_embed(int32_t num_tokens, int32_t padded_length, T* embed, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + Mask mask(mask_2d, num_tokens, num_tokens); + T* layer_output = nullptr; + for (int i = 0; i < num_hidden_layers; i++) { + this->layers[i]->decode(num_tokens, padded_length, this->embedding->output, layer_output, position_ids, cache_length, mask, this->kv_caches->caches[i]); + layer_output = this->layers[i]->output; + } + elementwise_scale(calc_stream, num_tokens, this->hidden_size, layer_output, this->residual_scale); + this->norm->prefill(calc_stream, num_tokens, this->embedding->output, layer_output); + this->lm_head->prefill(calc_stream, num_tokens, this->norm->output, (T*)output); + } + + void decode(int32_t num_tokens, int32_t padded_length, int32_t* input, int32_t* position_ids, int32_t* cache_length, uint64_t* mask_2d, void* output) { + this->embedding->prefill(calc_stream, num_tokens, input); + decode_embed(num_tokens, padded_length, this->embedding->output, position_ids, cache_length, mask_2d, output); + } + + void draft(int32_t *tree_draft_ids, int32_t *tree_position_ids, int32_t *cache_length, uint64_t* attn_mask, int32_t* tree_parent) { throw std::runtime_error("Draft is not supported"); } + int verify(int32_t num_tokens, int32_t* pred, int32_t* gt, int32_t* position_ids, int32_t* cache_length, uint64_t* attn_mask, int32_t* tree_parent) { throw std::runtime_error("Verify is not supported"); } +}; \ No newline at end of file diff --git a/examples/CPM.cu/src/perf.cu b/examples/CPM.cu/src/perf.cu new file mode 100644 index 00000000..7b70f5e9 --- /dev/null +++ b/examples/CPM.cu/src/perf.cu @@ -0,0 +1,9 @@ +#include <string> +#include <unordered_map> +#include "perf.cuh" + +// 全局统一性能数据存储的实现 +std::unordered_map<std::string, PerfData>& get_perf_data() { + static std::unordered_map<std::string, PerfData> g_perf_data; + return g_perf_data; +} diff --git a/examples/CPM.cu/src/perf.cuh b/examples/CPM.cu/src/perf.cuh new file mode 100644 index 00000000..86a59ff0 --- /dev/null +++ b/examples/CPM.cu/src/perf.cuh @@ -0,0 +1,291 @@ +#ifndef PERF_H +#define PERF_H + +#include <string> +#include <unordered_map> +#include <chrono> +#include <iostream> +#include <iomanip> +#include <cmath> // For std::isnan + +#ifdef __CUDACC__ +#include <cuda_runtime.h> +#include <cuda_profiler_api.h> +#endif + +// #define ENABLE_PERF +// 性能测量开关,可以通过编译时定义ENABLE_PERF来启用 +#ifdef ENABLE_PERF +#define PERF_ENABLED 1 +#else +#define PERF_ENABLED 0 +#endif + +struct PerfData { + std::chrono::high_resolution_clock::time_point start_time; + double total_time = 0.0; + int count = 0; + bool is_running = false; + std::string type = "CPU"; // "CPU" 或 "CUDA" + +#ifdef __CUDACC__ + cudaEvent_t start_event; + cudaEvent_t stop_event; + bool events_created = false; + + ~PerfData() { + if (events_created) { + cudaEventDestroy(start_event); + cudaEventDestroy(stop_event); + } + } +#endif +}; + +#if PERF_ENABLED + +// 前向声明 - 在perf.cpp中实现 +std::unordered_map<std::string, PerfData>& get_perf_data(); + +// 统一初始化性能测量系统 +#define perf_init() \ + do { \ + auto& perf_data = get_perf_data(); \ + for (auto& pair : perf_data) { \ + auto& data = pair.second; \ + if (data.type == "CUDA") { \ + if (data.events_created) { \ + cudaEventDestroy(data.start_event); \ + cudaEventDestroy(data.stop_event); \ + } \ + } \ + } \ + perf_data.clear(); \ + } while(0) + +// CPU性能测量开始 +#define perf_startf(label) \ + do { \ + auto& data = get_perf_data()[#label]; \ + data.type = "CPU"; \ + if (!data.is_running) { \ + data.start_time = std::chrono::high_resolution_clock::now(); \ + data.is_running = true; \ + } \ + } while(0) + +// CPU性能测量停止 +#define perf_stopf(label) \ + do { \ + auto& data = get_perf_data()[#label]; \ + if (data.is_running && data.type == "CPU") { \ + auto end_time = std::chrono::high_resolution_clock::now(); \ + auto duration = std::chrono::duration<double, std::milli>(end_time - data.start_time).count(); \ + data.total_time += duration; \ + data.count++; \ + data.is_running = false; \ + } \ + } while(0) + +#ifdef __CUDACC__ +// CUDA性能测量开始 +#define cuda_perf_startf(label) \ + do { \ + auto& data = get_perf_data()[#label]; \ + data.type = "CUDA"; \ + if (!data.events_created) { \ + cudaEventCreate(&data.start_event); \ + cudaEventCreate(&data.stop_event); \ + data.events_created = true; \ + } \ + if (!data.is_running) { \ + cudaEventRecord(data.start_event); \ + data.is_running = true; \ + } \ + } while(0) + +// CUDA性能测量停止 +#define cuda_perf_stopf(label) \ + do { \ + auto& data = get_perf_data()[#label]; \ + if (data.is_running && data.type == "CUDA" && data.events_created) { \ + cudaEventRecord(data.stop_event); \ + cudaEventSynchronize(data.stop_event); \ + float elapsed_time; \ + cudaEventElapsedTime(&elapsed_time, data.start_event, data.stop_event); \ + data.total_time += elapsed_time; \ + data.count++; \ + data.is_running = false; \ + } \ + } while(0) + +// 获取GPU内存使用情况 +#define cuda_get_memory_usage(free_mem, total_mem) \ + do { \ + size_t free_bytes, total_bytes; \ + cudaMemGetInfo(&free_bytes, &total_bytes); \ + free_mem = free_bytes; \ + total_mem = total_bytes; \ + } while(0) + +// CUDA作用域自动计时器 +#define cuda_perf_scope(label) \ + struct CudaPerfScope_##label { \ + CudaPerfScope_##label() { cuda_perf_startf(label); } \ + ~CudaPerfScope_##label() { cuda_perf_stopf(label); } \ + } cuda_perf_scope_##label + +// 新增: 用于在指定流上进行CUDA性能测量的宏 +#define cuda_perf_start_on_stream_f(label, stream_val) \ + do { \ + auto& data = get_perf_data()[#label]; \ + data.type = "CUDA"; \ + if (!data.events_created) { \ + cudaEventCreate(&data.start_event); \ + cudaEventCreate(&data.stop_event); \ + data.events_created = true; \ + } \ + if (!data.is_running) { \ + cudaEventRecord(data.start_event, stream_val); \ + data.is_running = true; \ + } \ + } while(0) + +#define cuda_perf_stop_on_stream_f(label, stream_val) \ + do { \ + auto& data = get_perf_data()[#label]; \ + if (data.is_running && data.type == "CUDA" && data.events_created) { \ + cudaEventRecord(data.stop_event, stream_val); \ + cudaEventSynchronize(data.stop_event); \ + float elapsed_time; \ + cudaEventElapsedTime(&elapsed_time, data.start_event, data.stop_event); \ + data.total_time += elapsed_time; \ + data.count++; \ + data.is_running = false; \ + } \ + } while(0) + +#define cuda_perf_scope_on_stream(label, stream_val) \ + struct CudaPerfScopeOnStream_##label { \ + cudaStream_t s_val; \ + CudaPerfScopeOnStream_##label(cudaStream_t stream_arg) : s_val(stream_arg) { cuda_perf_start_on_stream_f(label, s_val); } \ + ~CudaPerfScopeOnStream_##label() { cuda_perf_stop_on_stream_f(label, s_val); } \ + } cuda_perf_scope_on_stream_##label(stream_val) + +#else // __CUDACC__ (即没有CUDA支持时) +// 当没有CUDA支持时,CUDA宏变成空操作 +#define cuda_perf_startf(label) do {} while(0) +#define cuda_perf_stopf(label) do {} while(0) +#define cuda_get_memory_usage(free_mem, total_mem) do {} while(0) +#define cuda_perf_scope(label) do {} while(0) +#define cuda_perf_start_on_stream_f(label, stream_val) do {} while(0) +#define cuda_perf_stop_on_stream_f(label, stream_val) do {} while(0) +#define cuda_perf_scope_on_stream(label, stream_val) do {} while(0) + +#endif // __CUDACC__ + +// 统一输出性能统计摘要 +#define perf_summary() \ + do { \ + cudaDeviceSynchronize(); /* 在总结前同步所有CUDA操作 */ \ + auto& perf_data = get_perf_data(); \ + std::cout << "\n=== Performance Summary ===" << std::endl; \ + std::cout << std::left << std::setw(30) << "Label" \ + << std::setw(8) << "Type" \ + << std::setw(8) << "Count" \ + << std::setw(15) << "Total(ms)" \ + << std::setw(15) << "Average(ms)" << std::endl; \ + std::cout << std::string(76, '-') << std::endl; \ + \ + for (auto& pair : perf_data) { \ + const auto& name = pair.first; \ + auto& data = pair.second; \ + if (data.count > 0) { \ + double avg = (data.count > 0) ? (data.total_time / data.count) : 0.0; \ + std::cout << std::left << std::setw(30) << name \ + << std::setw(8) << data.type \ + << std::setw(8) << data.count \ + << std::setw(15) << std::fixed << std::setprecision(3) << data.total_time \ + << std::setw(15) << std::fixed << std::setprecision(3) << avg << std::endl; \ + } \ + } \ + \ + bool has_cuda = false; \ + for (const auto& pair : perf_data) { \ + if (pair.second.type == "CUDA") { \ + has_cuda = true; \ + break; \ + } \ + } \ + \ + if (has_cuda) { \ + size_t free_mem, total_mem; \ + cuda_get_memory_usage(free_mem, total_mem); \ + std::cout << std::string(76, '-') << std::endl; \ + std::cout << "GPU Memory: " << (total_mem - free_mem) / (1024*1024) << "MB used / " \ + << total_mem / (1024*1024) << "MB total" << std::endl; \ + } \ + std::cout << "============================" << std::endl; \ + } while(0) + +// 获取指定标签的总时间(毫秒) +#define perf_get_total_time(label) \ + (get_perf_data().count(#label) ? get_perf_data()[#label].total_time : 0.0) + +// 获取指定标签的平均时间(毫秒) +#define perf_get_avg_time(label) \ + (get_perf_data().count(#label) && get_perf_data()[#label].count > 0 ? \ + get_perf_data()[#label].total_time / get_perf_data()[#label].count : 0.0) + +// 获取指定标签的调用次数 +#define perf_get_count(label) \ + (get_perf_data().count(#label) ? get_perf_data()[#label].count : 0) + +// 重置指定标签的性能数据 +#define perf_reset(label) \ + do { \ + auto& perf_data = get_perf_data(); \ + if (perf_data.count(#label)) { \ + auto& data = perf_data[#label]; \ + if (data.type == "CUDA" && data.events_created) { \ + cudaEventDestroy(data.start_event); \ + cudaEventDestroy(data.stop_event); \ + } \ + perf_data[#label] = PerfData{}; \ + } \ + } while(0) + +// CPU作用域自动计时器(RAII风格) +#define perf_scope(label) \ + struct PerfScope_##label { \ + PerfScope_##label() { perf_startf(label); } \ + ~PerfScope_##label() { perf_stopf(label); } \ + } perf_scope_##label + +#else // PERF_ENABLED (即性能测量被禁用时) + +// 当性能测量被禁用时,所有宏都变成空操作 +#define perf_init() do {} while(0) +#define perf_startf(label) do {} while(0) +#define perf_stopf(label) do {} while(0) +#define perf_summary() do {} while(0) +#define perf_get_total_time(label) 0.0 +#define perf_get_avg_time(label) 0.0 +#define perf_get_count(label) 0 +#define perf_reset(label) do {} while(0) +#define perf_scope(label) do {} while(0) + +// CUDA相关的空操作宏 (因为PERF_ENABLED为false,所有CUDA相关的宏也是空操作) +#define cuda_perf_startf(label) do {} while(0) +#define cuda_perf_stopf(label) do {} while(0) +#define cuda_get_memory_usage(free_mem, total_mem) do {} while(0) +#define cuda_perf_scope(label) do {} while(0) +#define cuda_perf_start_on_stream_f(label, stream_val) do {} while(0) +#define cuda_perf_stop_on_stream_f(label, stream_val) do {} while(0) +#define cuda_perf_scope_on_stream(label, stream_val) do {} while(0) + +#endif // PERF_ENABLED + +#endif // PERF_H + + diff --git a/examples/CPM.cu/src/qgemm/gptq_marlin/core/scalar_type.hpp b/examples/CPM.cu/src/qgemm/gptq_marlin/core/scalar_type.hpp new file mode 100644 index 00000000..5ad89899 --- /dev/null +++ b/examples/CPM.cu/src/qgemm/gptq_marlin/core/scalar_type.hpp @@ -0,0 +1,348 @@ +#pragma once + +// For TORCH_CHECK +#include <torch/library.h> +#include <variant> + +namespace vllm { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that torch.dtype currently does not support). +// +// The type definitions on the Python side can be found in: vllm/scalar_type.py +// these type definitions should be kept up to date with any Python API changes +// here. +// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, + int32_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr){}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); + } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); + } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, + uint8_t mantissa) { + TORCH_CHECK(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, + bool finite_values_only, + NanRepr nan_repr) { + TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); + TORCH_CHECK(mantissa > 0 && exponent > 0); + TORCH_CHECK(nan_repr != NAN_IEEE_754, + "use `float_IEEE754` constructor for floating point types that " + "follow IEEE 754 conventions"); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, + nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template <typename T_> + static constexpr size_t member_id_field_width() { + using T = std::decay_t<T_>; + return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8; + } + + template <typename Fn, typename Init, typename Member, typename... Rest> + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, + Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template <typename Fn, typename Init> + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, + finite_values_only, nan_repr); + }; + + template <typename Fn, typename Init> + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { + return acc + member_id_field_width<decltype(member)>(); + }, + 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, + "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair<Id, uint32_t> result, + auto member) -> std::pair<Id, uint32_t> { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width<decltype(member)>(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) + << bit_offset, + bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width<T>(); + auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) & + ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, + std::pair<std::tuple<>, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, + tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { return signed_; } + constexpr bool is_integer() const { return exponent == 0; } + constexpr bool is_floating_point() const { return exponent > 0; } + constexpr bool is_ieee_754() const { + return is_floating_point() && finite_values_only == false && + nan_repr == NAN_IEEE_754; + } + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { + return is_floating_point() && finite_values_only == false; + } + constexpr bool has_bias() const { return bias != 0; } + + private: + double _floating_point_max() const { + TORCH_CHECK(mantissa <= 52 && exponent <= 11, + "Cannot represent max/min as a double for type ", str()); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { + max_mantissa -= 1; + } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + TORCH_CHECK(exponent < 11, + "Cannot represent max/min as a double for type ", str()); + max_exponent += 1; + } + + // adjust the exponent to match that of a double + // for now we assume the exponent bias is the standard 2^(e-1) -1, (where e + // is the exponent bits), there is some precedent for non-standard biases, + // example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes + // but to avoid premature over complication we are just assuming the + // standard exponent bias until there is a need to support non-standard + // biases + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = + max_exponent - exponent_bias + exponent_bias_double; + + // shift the mantissa into the position for a double and + // the exponent + uint64_t double_raw = + (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast<double*>(&double_raw); + } + + constexpr std::variant<int64_t, double> _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(), + "Cannot represent max as a int64_t"); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant<int64_t, double> _raw_min() const { + if (is_floating_point()) { + TORCH_CHECK(is_signed(), + "We currently assume all floating point types are signed"); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast<double*>(&min_raw)}; + } else { + TORCH_CHECK(!is_signed() || size_bits() <= 64, + "Cannot represent min as a int64_t"); + if (is_signed()) { + // set the top bit to 1 (i.e. INT64_MIN) and the rest to 0 + // then perform an arithmetic shift right to set all the bits above + // (size_bits() - 1) to 1 + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + // Max representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant<int64_t, double> max() const { + return std::visit( + [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, + _raw_max()); + } + + // Min representable value for this scalar type. + // (accounting for bias if there is one) + constexpr std::variant<int64_t, double> min() const { + return std::visit( + [this](auto x) -> std::variant<int64_t, double> { return {x - bias}; }, + _raw_min()); + } + + std::string str() const { + /* naming generally follows: https://github.com/jax-ml/ml_dtypes + * for floating point types (leading f) the scheme is: + * `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]` + * flags: + * - no-flags: means it follows IEEE 754 conventions + * - f: means finite values only (no infinities) + * - n: means nans are supported (non-standard encoding) + * for integer types the scheme is: + * `[u]int<size_bits>[b<bias>]` + * - if bias is not present it means its zero + */ + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { + ret += "f"; + } + if (nan_repr != NAN_NONE) { + ret += "n"; + } + } + return ret; + } else { + auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { + ret += "b" + std::to_string(bias); + } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && + bias == other.bias && signed_ == other.signed_ && + finite_values_only == other.finite_values_only && + nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70 +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE3M2f = + ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = + ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names, generally following: +// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57 +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +}; // namespace vllm diff --git a/examples/CPM.cu/src/qgemm/gptq_marlin/gptq_marlin.cu b/examples/CPM.cu/src/qgemm/gptq_marlin/gptq_marlin.cu new file mode 100644 index 00000000..9eab8072 --- /dev/null +++ b/examples/CPM.cu/src/qgemm/gptq_marlin/gptq_marlin.cu @@ -0,0 +1,2175 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#include <cuda.h> +#include "marlin.cuh" +#include "marlin_dtypes.cuh" +#include "core/scalar_type.hpp" +#include "gptq_marlin.cuh" + + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same<scalar_t, half>::value || \ + std::is_same<scalar_t, nv_bfloat16>::value, \ + "only float16 and bfloat16 is supported"); + +template <typename T> +inline std::string str(T x) { + return std::to_string(x); +} + +namespace marlin { + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template <typename scalar_t> +__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag, + const typename ScalarType<scalar_t>::FragB& frag_b, + typename ScalarType<scalar_t>::FragC& frag_c) { + const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag); + const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b); + float* c = reinterpret_cast<float*>(&frag_c); + if constexpr (std::is_same<scalar_t, half>::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template <typename scalar_t> +__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a); + uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template <int lut> +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template <int start_byte, int mask> +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template <typename scalar_t, vllm::ScalarTypeId w_type_id> +__device__ inline typename ScalarType<scalar_t>::FragB dequant(int q); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template <> +__device__ inline typename ScalarType<half>::FragB +dequant<half, vllm::kU4B8.id()>(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + typename ScalarType<half>::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), + *reinterpret_cast<const half2*>(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi), + *reinterpret_cast<const half2*>(&MUL), + *reinterpret_cast<const half2*>(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType<nv_bfloat16>::FragB +dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType<nv_bfloat16>::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC308C308; + + frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo), + *reinterpret_cast<const nv_bfloat162*>(&MUL), + *reinterpret_cast<const nv_bfloat162*>(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi), + *reinterpret_cast<const nv_bfloat162*>(&MUL), + *reinterpret_cast<const nv_bfloat162*>(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType<half>::FragB +dequant<half, vllm::kU4.id()>(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + typename ScalarType<half>::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), + *reinterpret_cast<const half2*>(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi), + *reinterpret_cast<const half2*>(&MUL), + *reinterpret_cast<const half2*>(&ADD)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType<nv_bfloat16>::FragB +dequant<nv_bfloat16, vllm::kU4.id()>(int q) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + typename ScalarType<nv_bfloat16>::FragB frag_b; + static constexpr uint32_t MUL = 0x3F803F80; + static constexpr uint32_t ADD = 0xC300C300; + + frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo), + *reinterpret_cast<const nv_bfloat162*>(&MUL), + *reinterpret_cast<const nv_bfloat162*>(&ADD)); + frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi), + *reinterpret_cast<const nv_bfloat162*>(&MUL), + *reinterpret_cast<const nv_bfloat162*>(&ADD)); + return frag_b; +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline typename ScalarType<half>::FragB +dequant<half, vllm::kU8B128.id()>(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q); + uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + typename ScalarType<half>::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), + *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi), + *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType<nv_bfloat16>::FragB +dequant<nv_bfloat16, vllm::kU8B128.id()>(int q) { + typename ScalarType<nv_bfloat16>::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast<uint32_t*>(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + +template <> +__device__ inline typename ScalarType<half>::FragB +dequant<half, vllm::kU8.id()>(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q); + uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + typename ScalarType<half>::FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo), + *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi), + *reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline typename ScalarType<nv_bfloat16>::FragB +dequant<nv_bfloat16, vllm::kU8.id()>(int q) { + typename ScalarType<nv_bfloat16>::FragB frag_b; + + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast<uint32_t*>(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); + + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template <typename scalar_t> +__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b, + typename ScalarType<scalar_t>::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; + scalar_t2 s = + ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +template <typename scalar_t> +__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b, + typename ScalarType<scalar_t>::scalar_t2& frag_zp, + int i) { + using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; + scalar_t2 zp = + ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Same as above, but for act_order (each K is multiplied individually) +template <typename scalar_t> +__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b, + typename ScalarType<scalar_t>::FragS& frag_s_1, + typename ScalarType<scalar_t>::FragS& frag_s_2, + typename ScalarType<scalar_t>::FragS& frag_s_3, + typename ScalarType<scalar_t>::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<scalar_t*>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<scalar_t*>(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<scalar_t*>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<scalar_t*>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template <typename scalar_t> +__device__ inline void scale_float(float* c, + typename ScalarType<scalar_t>::FragS& s) { + scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s); + c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset); + half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +template <typename scalar_t, // compute dtype, half or nv_float16 + const vllm::ScalarTypeId w_type_id, // weight ScalarType id + const int threads, // number of threads in a threadblock + const int thread_m_blocks, // number of 16x16 blocks in the m + // dimension (batchsize) of the + // threadblock + const int thread_n_blocks, // same for n dimension (output) + const int thread_k_blocks, // same for k dimension (reduction) + const int stages, // number of stages for the async global->shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType<scalar_t>; + using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2; + using FragA = typename ScalarType<scalar_t>::FragA; + using FragB = typename ScalarType<scalar_t>::FragB; + using FragC = typename ScalarType<scalar_t>::FragC; + using FragS = typename ScalarType<scalar_t>::FragS; + using FragZP = typename ScalarType<scalar_t>::FragZP; + + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + + constexpr int pack_factor = 32 / w_type.size_bits(); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + par_id++; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + int4* sh_red = sh_s + (stages * s_sh_stage); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast<float*>(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast<int4 const*>(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + // Only fetch scales if this tile starts a new group + if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) { + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait<stages - 2>(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4<scalar_t>(frag_a[k % 2][i], + &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + reinterpret_cast<int4*>(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast<int4*>(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast<int4*>(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast<int*>(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast<int4*>(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i]; + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + if constexpr (has_zp) { + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = frag_qzp[k % 2][1]; + } + + frag_zp_0 = dequant<scalar_t, w_type_id>(zp_quant_0); + frag_zp_1 = dequant<scalar_t, w_type_id>(zp_quant_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + frag_b0 = dequant<scalar_t, w_type_id>(b_quant_0); + frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1); + + // Apply zero-point to frag_b0 + if constexpr (has_zp) { + sub_zp<scalar_t>(frag_b0, frag_zp[j], 0); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4<scalar_t>(frag_b0, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale<scalar_t>(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply zero-point to frag_b1 + if constexpr (has_zp) { + sub_zp<scalar_t>(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j], + act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], + act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale<scalar_t>(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast<float*>( + &sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = + reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<float*>( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<scalar_t*>(&c)[j] = + Dtype::float2num(reinterpret_cast<float*>( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + int par_offset = c_size * n_tiles * par_id; + int slice_offset = c_size * slice_col; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = par_offset + slice_offset; + + if (!is_th_active) { + return; + } + + if (!first) { + float* frag_c_ptr = reinterpret_cast<float*>(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k++) { + sh_red[threadIdx.x] = + C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float* sh_c_ptr = reinterpret_cast<float*>(&sh_red[threadIdx.x]); + #pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k++) { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { + res = __hmul2(res, s[0]); + } + + ((scalar_t2*)sh_red)[idx] = res; + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh_red[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && group_blocks == -1) { + if (i == 0) { + fetch_zp_to_shared(); + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float<scalar_t>( + reinterpret_cast<float*>(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float<scalar_t>( + reinterpret_cast<float*>(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float<scalar_t>( + reinterpret_cast<float*>(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float<scalar_t>( + reinterpret_cast<float*>(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + + #define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \ + HAS_ZP, GROUP_BLOCKS>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \ + HAS_ZP, GROUP_BLOCKS> \ + <<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \ + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \ + num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}, +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}, + +}; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = div_ceil(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * pipe_stages; + + float reduce_size = max(th_config.num_threads * 32 * 4, + (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2); + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + + return true; +} + +int determine_reduce_max_m(int prob_m, int max_par) { + constexpr int tile_m_size = 16; + + if (prob_m <= tile_m_size) { + return tile_m_size; + + } else if (prob_m <= tile_m_size * 2) { + return tile_m_size * 2; + + } else if (prob_m <= tile_m_size * 3) { + return tile_m_size * 3; + + } else if (prob_m <= tile_m_size * 4) { + return tile_m_size * 4; + + } else { + int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par); + return tile_m_size * 4 * cur_par; + } +} + +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage + } + + return exec_config_t{0, {-1, -1, -1}}; +} + + #define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + + #define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \ + \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) + +template <typename scalar_t> +void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, + void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m, + int prob_n, int prob_k, void* workspace, + vllm::ScalarType const& q_type, bool has_act_order, + bool is_k_full, bool has_zp, int num_groups, int group_size, + int dev, cudaStream_t stream, int thread_k, int thread_n, + int sms, int max_par, bool use_fp32_reduce) { + if (has_zp) { + TORCH_CHECK( + q_type == vllm::kU4 || q_type == vllm::kU8, + "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + TORCH_CHECK( + q_type == vllm::kU4B8 || q_type == vllm::kU8B128, + "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", + q_type.str()); + } + + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + // TODO: remove alias when we start supporting other 8bit types + int num_bits = q_type.size_bits(); + int tot_m = prob_m; + int tot_m_blocks = div_ceil(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + exec_config_t exec_cfg; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; + } else { + // Auto config + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); + } + + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* C_tmp_ptr = (int4*)C_tmp; + const int4* s_ptr = (const int4*)s; + const int4* zp_ptr = (const int4*)zp; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, blocks); + permute_cols_kernel<<<blocks, default_threads, 0, stream>>>( + A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by having + // a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + // Main loop + for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > exec_cfg.max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * exec_cfg.max_m_blocks) * par; + i += exec_cfg.max_m_blocks * (par - 1); + thread_m_blocks = exec_cfg.max_m_blocks; + } + + if (false) { + } + GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256) + GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256) + GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128) + GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128) + GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256) + GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256) + GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128) + GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128) + + AWQ_CALL_IF(vllm::kU4, 16, 4, 256) + AWQ_CALL_IF(vllm::kU4, 8, 8, 256) + AWQ_CALL_IF(vllm::kU4, 8, 4, 128) + AWQ_CALL_IF(vllm::kU4, 4, 8, 128) + AWQ_CALL_IF(vllm::kU8, 16, 4, 256) + AWQ_CALL_IF(vllm::kU8, 8, 8, 256) + AWQ_CALL_IF(vllm::kU8, 8, 4, 128) + AWQ_CALL_IF(vllm::kU8, 4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, + ", ", prob_k, "]", ", has_act_order = ", has_act_order, + ", num_groups = ", num_groups, ", group_size = ", group_size, + ", thread_m_blocks = ", thread_m_blocks, + ", thread_n_blocks = ", thread_n_blocks, + ", thread_k_blocks = ", thread_k_blocks, + ", num_bits = ", num_bits); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +} // namespace marlin + +// torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, +// torch::Tensor& b_scales, torch::Tensor& b_zeros, +// torch::Tensor& g_idx, torch::Tensor& perm, +// torch::Tensor& workspace, +// vllm::ScalarType const& b_q_type, // init in linear +// int64_t size_m, int64_t size_n, int64_t size_k, +// bool is_k_full, bool has_zp, +// bool use_fp32_reduce, +// // new input +// void* c +// ) { + +template <typename T> +void gptq_marlin_gemm(T* a, int32_t* b_q_weight, + T* b_scales, int32_t* b_zeros, + int32_t* g_idx, int32_t* perm, + int32_t* workspace, + vllm::ScalarType const& b_q_type, // init in linear + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool has_zp, + bool use_fp32_reduce, + // TODO: new args + T* c, + int num_groups, int group_size, + int b_q_weight_size1, + bool has_act_order, + cudaStream_t stream, + T* a_tmp, + float* c_tmp + ) { + + int pack_factor = 32 / b_q_type.size_bits(); + + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Verify workspace size + TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, + ", is not divisible by min_thread_n = ", marlin::min_thread_n); + + // int dev = a.get_device(); + int dev = 0; // 选择第一个 GPU(设备 0) + cudaSetDevice(dev); // 设置当前 CUDA 设备 + marlin::marlin_mm<T>( + a, b_q_weight, c, + c_tmp, b_scales, + b_zeros, g_idx, perm, + a_tmp, size_m, size_n, size_k, + workspace, b_q_type, has_act_order, is_k_full, has_zp, + num_groups, group_size, dev, stream, + thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce); + + return; +} + +// Explicit template instantiations +#ifdef ENABLE_DTYPE_FP16 +template void gptq_marlin_gemm<__half>(__half* a, int32_t* b_q_weight, + __half* b_scales, int32_t* b_zeros, + int32_t* g_idx, int32_t* perm, + int32_t* workspace, + vllm::ScalarType const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool has_zp, + bool use_fp32_reduce, + __half* c, + int num_groups, int group_size, + int b_q_weight_size1, + bool has_act_order, + cudaStream_t stream, + __half* a_tmp, + float* c_tmp); +#endif + +#ifdef ENABLE_DTYPE_BF16 +template void gptq_marlin_gemm<__nv_bfloat16>(__nv_bfloat16* a, int32_t* b_q_weight, + __nv_bfloat16* b_scales, int32_t* b_zeros, + int32_t* g_idx, int32_t* perm, + int32_t* workspace, + vllm::ScalarType const& b_q_type, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool has_zp, + bool use_fp32_reduce, + __nv_bfloat16* c, + int num_groups, int group_size, + int b_q_weight_size1, + bool has_act_order, + cudaStream_t stream, + __nv_bfloat16* a_tmp, + float* c_tmp); +#endif diff --git a/examples/CPM.cu/src/qgemm/gptq_marlin/gptq_marlin.cuh b/examples/CPM.cu/src/qgemm/gptq_marlin/gptq_marlin.cuh new file mode 100644 index 00000000..c85e6c03 --- /dev/null +++ b/examples/CPM.cu/src/qgemm/gptq_marlin/gptq_marlin.cuh @@ -0,0 +1,27 @@ +#pragma once +#include <cuda_fp16.h> +#include "core/scalar_type.hpp" + +namespace marlin { + +int determine_reduce_max_m(int prob_m, int max_par); + +} + +template <typename T> +void gptq_marlin_gemm(T* a, int32_t* b_q_weight, + T* b_scales, int32_t* b_zeros, + int32_t* g_idx, int32_t* perm, + int32_t* workspace, + vllm::ScalarType const& b_q_type, // init in linear + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool has_zp, + bool use_fp32_reduce, + T* c, + int num_groups, int group_size, + int b_q_weight_size1, + bool has_act_order, + cudaStream_t stream, + T* a_tmp, + float* c_tmp + ); \ No newline at end of file diff --git a/examples/CPM.cu/src/qgemm/gptq_marlin/marlin.cuh b/examples/CPM.cu/src/qgemm/gptq_marlin/marlin.cuh new file mode 100644 index 00000000..74ccbac5 --- /dev/null +++ b/examples/CPM.cu/src/qgemm/gptq_marlin/marlin.cuh @@ -0,0 +1,87 @@ +#pragma once + +#include <torch/all.h> + +#include <ATen/cuda/CUDAContext.h> +#include <c10/cuda/CUDAGuard.h> +#include <cuda.h> +#include <cuda_fp16.h> +#include <cuda_runtime.h> +#include <iostream> + +namespace marlin { + +// Marlin params + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = + 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +// Repack params +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +// Helpers +template <typename T, int n> +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec<int, 4>; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async +#else + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template <int n> +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +} // namespace marlin diff --git a/examples/CPM.cu/src/qgemm/gptq_marlin/marlin_dtypes.cuh b/examples/CPM.cu/src/qgemm/gptq_marlin/marlin_dtypes.cuh new file mode 100644 index 00000000..4f4522b2 --- /dev/null +++ b/examples/CPM.cu/src/qgemm/gptq_marlin/marlin_dtypes.cuh @@ -0,0 +1,75 @@ + +#define _data_types_cuh +#include "marlin.cuh" +#include <cuda_fp16.h> +#include <cuda_bf16.h> + +namespace marlin { + +template <typename scalar_t> +class ScalarType {}; + +template <> +class ScalarType<half> { + public: + using scalar_t = half; + using scalar_t2 = half2; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec<half2, 4>; + using FragB = Vec<half2, 2>; + using FragC = Vec<float, 4>; + using FragS = Vec<half2, 1>; + using FragZP = Vec<half2, 4>; + + static __device__ float inline num2float(const half x) { + return __half2float(x); + } + + static __device__ half2 inline num2num2(const half x) { + return __half2half2(x); + } + + static __device__ half2 inline nums2num2(const half x1, const half x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ half inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> +class ScalarType<nv_bfloat16> { + public: + using scalar_t = nv_bfloat16; + using scalar_t2 = nv_bfloat162; + + using FragA = Vec<nv_bfloat162, 4>; + using FragB = Vec<nv_bfloat162, 2>; + using FragC = Vec<float, 4>; + using FragS = Vec<nv_bfloat162, 1>; + using FragZP = Vec<nv_bfloat162, 4>; + + static __device__ float inline num2float(const nv_bfloat16 x) { + return __bfloat162float(x); + } + + static __device__ nv_bfloat162 inline num2num2(const nv_bfloat16 x) { + return __bfloat162bfloat162(x); + } + + static __device__ nv_bfloat162 inline nums2num2(const nv_bfloat16 x1, + const nv_bfloat16 x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ nv_bfloat16 inline float2num(const float x) { + return __float2bfloat16(x); + } +}; + +} // namespace marlin + diff --git a/examples/CPM.cu/src/signal_handler.cu b/examples/CPM.cu/src/signal_handler.cu new file mode 100644 index 00000000..6d26a661 --- /dev/null +++ b/examples/CPM.cu/src/signal_handler.cu @@ -0,0 +1,126 @@ +#include "signal_handler.cuh" +#include <dlfcn.h> + +// 保存原有信号处理器的全局变量 +std::map<int, void(*)(int)> original_handlers; + +void init_signal_handlers() { + // 保存并设置信号处理器 + original_handlers[SIGSEGV] = signal(SIGSEGV, signal_handler); // 段错误 + original_handlers[SIGABRT] = signal(SIGABRT, signal_handler); // 异常终止 + original_handlers[SIGFPE] = signal(SIGFPE, signal_handler); // 浮点异常 + original_handlers[SIGILL] = signal(SIGILL, signal_handler); // 非法指令 +#ifdef SIGBUS + original_handlers[SIGBUS] = signal(SIGBUS, signal_handler); // 总线错误 (某些系统可能没有) +#endif + original_handlers[SIGTERM] = signal(SIGTERM, signal_handler); // 终止信号 + original_handlers[SIGINT] = signal(SIGINT, signal_handler); // 中断信号 (Ctrl+C) + + std::cout << "Signal handlers initialized for common exceptions" << std::endl; +} + +// TODO 修复和python traceback的协作 +void signal_handler(int sig) { + const char* signal_name = "Unknown"; + + switch (sig) { + case SIGSEGV: signal_name = "SIGSEGV (Segmentation fault)"; break; + case SIGABRT: signal_name = "SIGABRT (Abort)"; break; + case SIGFPE: signal_name = "SIGFPE (Floating point exception)"; break; + case SIGILL: signal_name = "SIGILL (Illegal instruction)"; break; +#ifdef SIGBUS + case SIGBUS: signal_name = "SIGBUS (Bus error)"; break; +#endif + case SIGTERM: signal_name = "SIGTERM (Termination)"; break; + case SIGINT: signal_name = "SIGINT (Interrupt)"; break; + } + + std::cerr << "\n=== SIGNAL CAUGHT ===" << std::endl; + std::cerr << "Signal: " << signal_name << " (" << sig << ")" << std::endl; + std::cerr << "Process ID: " << getpid() << std::endl; + std::cerr << "====================" << std::endl; + + // 打印栈帧信息 + print_stack_trace(50); + + std::cerr << "\nProgram terminated due to signal " << sig << std::endl; + std::cerr.flush(); + std::cout.flush(); + + // 查找并调用原有的信号处理器 + auto it = original_handlers.find(sig); + if (it != original_handlers.end() && it->second != SIG_DFL && it->second != SIG_IGN) { + std::cerr << "Calling original signal handler..." << std::endl; + it->second(sig); + } + + // 恢复默认信号处理并重新发送信号 + std::cerr << "Restoring default handler..." << std::endl; + signal(sig, SIG_DFL); + raise(sig); +} + +void print_stack_trace(int max_frames) { + void **array = new void*[max_frames]; + + // 获取调用栈 + int size = backtrace(array, max_frames); + char **strings = backtrace_symbols(array, size); + + if (strings == nullptr) { + std::cerr << "Failed to get backtrace symbols (backtrace may not be available on this system)" << std::endl; + delete[] array; + return; + } + + // 添加基地址信息 + Dl_info dl_info; + if (dladdr((void*)print_stack_trace, &dl_info)) { + std::cerr << "\n=== MODULE INFO ===" << std::endl; + std::cerr << "Base address: " << dl_info.dli_fbase << std::endl; + std::cerr << "Module path: " << dl_info.dli_fname << std::endl; + } + + std::cerr << "=== STACK TRACE ===" << std::endl; + std::cerr << "Call stack (" << size << " frames):" << std::endl; + + for (int i = 0; i < size; i++) { + std::string symbol_info = get_symbol_name(strings[i]); + std::cerr << "[" << i << "] " << symbol_info << std::endl; + } + + std::cerr << "==================" << std::endl; + + free(strings); + delete[] array; +} + +std::string get_symbol_name(const char* symbol) { + std::string result(symbol); + + // 查找函数名的开始和结束位置 + char *start = strstr((char*)symbol, "("); + char *end = strstr((char*)symbol, "+"); + + if (start && end && start < end) { + *end = '\0'; + char *function_name = start + 1; + + // 尝试demangle C++符号名 + int status; + char *demangled = abi::__cxa_demangle(function_name, 0, 0, &status); + + if (status == 0 && demangled) { + // 成功demangle + std::string prefix(symbol, start - symbol + 1); + std::string suffix = end + 1; + result = prefix + demangled + "+" + suffix; + free(demangled); + } else { + // demangle失败,恢复原始字符串 + *end = '+'; + } + } + + return result; +} \ No newline at end of file diff --git a/examples/CPM.cu/src/signal_handler.cuh b/examples/CPM.cu/src/signal_handler.cuh new file mode 100644 index 00000000..226f9849 --- /dev/null +++ b/examples/CPM.cu/src/signal_handler.cuh @@ -0,0 +1,26 @@ +#pragma once + +#include <signal.h> +#include <cstdio> +#include <cstdlib> +#include <unistd.h> +#include <execinfo.h> +#include <cxxabi.h> +#include <string> +#include <iostream> +#include <map> + +// 保存原有信号处理器 +extern std::map<int, void(*)(int)> original_handlers; + +// 初始化signal处理器 +void init_signal_handlers(); + +// signal处理函数 +void signal_handler(int sig); + +// 打印栈帧信息 +void print_stack_trace(int max_frames = 50); + +// 获取符号名称(带demangling) +std::string get_symbol_name(const char* symbol); \ No newline at end of file diff --git a/examples/CPM.cu/src/trait.cuh b/examples/CPM.cu/src/trait.cuh new file mode 100644 index 00000000..188a4986 --- /dev/null +++ b/examples/CPM.cu/src/trait.cuh @@ -0,0 +1,37 @@ +#pragma once +#include <cuda_runtime.h> +#include <cuda_fp16.h> +#include <cuda_bf16.h> + +template <typename T> +struct TypeTraits; + +template <> +struct TypeTraits<__half> { + using half2 = __half2; + + static __inline__ cudaDataType_t cublas_type() { + return CUDA_R_16F; + } + + static __inline__ int type_code() { + return 0; + } + + static __host__ __device__ __inline__ constexpr __half inf() { const short v = 0x7c00; return *(reinterpret_cast<const __half *>(&(v))); } +}; + +template <> +struct TypeTraits<__nv_bfloat16> { + using half2 = __nv_bfloat162; + + static __inline__ cudaDataType_t cublas_type() { + return CUDA_R_16BF; + } + + static __inline__ int type_code() { + return 1; + } + + static __host__ __device__ __inline__ constexpr __nv_bfloat16 inf() { const short v = 0x7f80; return *(reinterpret_cast<const __nv_bfloat16 *>(&(v))); } +}; diff --git a/examples/CPM.cu/src/utils.cu b/examples/CPM.cu/src/utils.cu new file mode 100644 index 00000000..fb19554e --- /dev/null +++ b/examples/CPM.cu/src/utils.cu @@ -0,0 +1,25 @@ +#include "perf.cuh" +#include "utils.cuh" +#include "signal_handler.cuh" + +bool initialized = false; + +Stream calc_stream; + +int graphCreated_padding_length = -1; +int graphCreated_input_length = -1; +cudaGraph_t graph; +cudaGraphExec_t graphExec; + +void init_resources() { + if (initialized) return; + + // 初始化signal处理器 + init_signal_handlers(); + perf_init(); + + cudaCheck(cudaStreamCreate(&calc_stream.stream)); + cublasCheck(cublasCreate(&calc_stream.cublas_handle)); + cublasCheck(cublasSetStream(calc_stream.cublas_handle, calc_stream.stream)); + initialized = true; +} diff --git a/examples/CPM.cu/src/utils.cuh b/examples/CPM.cu/src/utils.cuh new file mode 100644 index 00000000..2acf7291 --- /dev/null +++ b/examples/CPM.cu/src/utils.cuh @@ -0,0 +1,60 @@ +#pragma once +#include <cmath> +#include <cstdio> +#include <cstdlib> +#include <iostream> +#include <iomanip> +#include <cuda_runtime.h> +#include <cublas_v2.h> +#include "signal_handler.cuh" + +extern bool initialized; + +struct Stream { + cudaStream_t stream; + cublasHandle_t cublas_handle; +}; + +extern Stream calc_stream; + +extern int graphCreated_padding_length; +extern int graphCreated_input_length; +extern cudaGraph_t graph; +extern cudaGraphExec_t graphExec; + +#define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) +#define ROUND_UP(M, N) (((M) + (N) - 1) / (N) * (N)) + +inline const char* cublasGetErrorString(cublasStatus_t status) { + switch(status) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + default: return "Unknown cuBLAS error"; + } +} + +#define cudaCheck(err) \ + if (err != cudaSuccess) { \ + std::cerr << "cuda error at " << __FILE__ << ":" << __LINE__ << std::endl; \ + std::cerr << cudaGetErrorString(err) << std::endl; \ + print_stack_trace(); \ + exit(EXIT_FAILURE); \ + } + +#define cublasCheck(err) \ + if (err != CUBLAS_STATUS_SUCCESS) { \ + std::cerr << "cuBLAS error at " << __FILE__ << ":" << __LINE__ << std::endl; \ + std::cerr << "Error code: " << err << " (" << cublasGetErrorString(err) << ")" << std::endl; \ + print_stack_trace(); \ + exit(EXIT_FAILURE); \ + } + +void init_resources(); diff --git a/examples/CPM.cu/src/utilsq.cuh b/examples/CPM.cu/src/utilsq.cuh new file mode 100644 index 00000000..137e4901 --- /dev/null +++ b/examples/CPM.cu/src/utilsq.cuh @@ -0,0 +1,421 @@ +#pragma once +#include <cuda_runtime.h> +#include <cuda_fp16.h> +#include <cuda_bf16.h> + +template<typename T> struct num_elems; +template <> struct num_elems<float> { static constexpr int value = 1; }; +template <> struct num_elems<float2> { static constexpr int value = 2; }; +template <> struct num_elems<float4> { static constexpr int value = 4; }; +template <> struct num_elems<half> { static constexpr int value = 1; }; +template <> struct num_elems<half2> { static constexpr int value = 2; }; +template <> struct num_elems<__nv_bfloat16> { static constexpr int value = 1; }; +template <> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; }; + +template<typename T, int num> struct packed_as; +template<typename T> struct packed_as<T, 1> { using type = T; }; +template<> struct packed_as<half, 2> { using type = half2; }; +template<> struct packed_as<float, 2> { using type = float2; }; +template<> struct packed_as<int8_t, 2> { using type = int16_t; }; +template<> struct packed_as<int32_t, 2> { using type = int2; }; +template<> struct packed_as<half2, 1> { using type = half; }; +template<> struct packed_as<float2, 1> { using type = float; }; +// BF16 +template<> struct packed_as<__nv_bfloat16, 2> { using type = __nv_bfloat162; }; +template<> struct packed_as<__nv_bfloat162, 1> { using type = __nv_bfloat16; }; + +inline __device__ float2 operator*(float2 a, float2 b) { return make_float2(a.x * b.x, a.y * b.y); } +inline __device__ float2 operator+(float2 a, float2 b) { return make_float2(a.x + b.x, a.y + b.y); } +inline __device__ float2 operator-(float2 a, float2 b) { return make_float2(a.x - b.x, a.y - b.y); } + +inline __device__ float2 operator*(float2 a, float b) { return make_float2(a.x * b, a.y * b); } +inline __device__ float2 operator+(float2 a, float b) { return make_float2(a.x + b, a.y + b); } +inline __device__ float2 operator-(float2 a, float b) { return make_float2(a.x - b, a.y - b); } + + +static inline __device__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast<const int8_t&>(dst); +} + +template<typename T> +inline __device__ T ldg(const T* val) { + return __ldg(val); +} + + + +#define bf1622float2 __bfloat1622float2 +#define float22bf162 __float22bfloat162_rn +#define bf162bf162 __bfloat162bfloat162 +inline __device__ int16_t bf1622int16(__nv_bfloat162 val) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = max(min(__low2float(val), 127.f), -128.f); + f_val.y = max(min(__high2float(val), 127.f), -128.f); + + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x)); + int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y)); + return int16; +#else + val = __hmin2(val, make_bfloat162(127., 127.)); + val = __hmax2(val, make_bfloat162(-128., -128.)); + + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = static_cast<int8_t>(static_cast<short>(val.x)); + int8[1] = static_cast<int8_t>(static_cast<short>(val.y)); + return int16; +} + +template<> +inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} + +template<> +inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return val[0]; +#else + return __ldg(val); +#endif +} + +template <typename T_OUT, typename T_IN> +__device__ inline T_OUT cuda_cast(T_IN val) +{ + return val; +} + +template <> +__device__ inline float2 cuda_cast<float2, int2>(int2 val) +{ + return make_float2(val.x, val.y); +} + +template <> +__device__ inline float2 cuda_cast<float2, float>(float val) +{ + return make_float2(val, val); +} + +template <> +__device__ inline float2 cuda_cast<float2, half2>(half2 val) +{ + return __half22float2(val); +} + +template <> +__device__ inline half2 cuda_cast<half2, float2>(float2 val) +{ + return __float22half2_rn(val); +} + +template <> +__device__ inline half2 cuda_cast<half2, float>(float val) +{ + return __float2half2_rn(val); +} + +template <> +__device__ inline half2 cuda_cast<half2, half>(half val) +{ + return __half2half2(val); +} + +template <> +__device__ inline int8_t cuda_cast<int8_t, half>(half val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + union + { + half fp16; + int16_t int16_in; + }; + + fp16 = val; + asm volatile("cvt.rni.sat.s8.f16 %0, %1;" : "=h"(int16) : "h"(int16_in)); + return int8[0]; +} + +template <> +__device__ inline int16_t cuda_cast<int16_t, half2>(half2 val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = cuda_cast<int8_t>(val.x); + int8[1] = cuda_cast<int8_t>(val.y); + return int16; +} + +template <> +__device__ inline int8_t cuda_cast<int8_t, float>(float val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); + return int8[0]; +} + +template <> +__device__ inline int16_t cuda_cast<int16_t, float2>(float2 val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int8[0] = cuda_cast<int8_t>(val.x); + int8[1] = cuda_cast<int8_t>(val.y); + return int16; +} + +template <> +__device__ inline half2 cuda_cast<half2, int16_t>(int16_t val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int16 = val; + return make_half2(int8[0], int8[1]); +} + +template <> +__device__ inline float2 cuda_cast<float2, int16_t>(int16_t val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int16 = val; + return make_float2(int8[0], int8[1]); +} + +// #ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat16 cuda_cast(int32_t val) +{ + return static_cast<float>(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast(int8_t val) +{ + return static_cast<float>(val); +} + +template <> +__device__ inline int8_t cuda_cast(__nv_bfloat16 val) +{ + return static_cast<float>(val); +} + +template <> +__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) +{ + return __bfloat162float(val); +} + +template <> +__device__ inline float2 cuda_cast<float2, __nv_bfloat162>(__nv_bfloat162 val) +{ + return bf1622float2(val); +} + +template <> +__device__ inline half cuda_cast<half, __nv_bfloat16>(__nv_bfloat16 val) +{ + return __float2half(__bfloat162float(val)); +} + +template <> +__device__ inline int16_t cuda_cast<int16_t, __nv_bfloat162>(__nv_bfloat162 val) +{ + return bf1622int16(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, float>(float val) +{ + return __float2bfloat16(val); +} + +template <> +__device__ inline __nv_bfloat16 cuda_cast<__nv_bfloat16, half>(half val) +{ + return __float2bfloat16(__half2float(val)); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, __nv_bfloat16>(__nv_bfloat16 val) +{ + return bf162bf162(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float>(float val) +{ + return __float2bfloat162_rn(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, float2>(float2 val) +{ + return float22bf162(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, int16_t>(int16_t val) +{ + union + { + int8_t int8[2]; + int16_t int16; + }; + + int16 = val; + __nv_bfloat162 res; + res.x = cuda_cast<__nv_bfloat16>(int8[0]); + res.y = cuda_cast<__nv_bfloat16>(int8[1]); + return res; +} + +template <> +__device__ inline __nv_bfloat162 cuda_cast<__nv_bfloat162, half2>(half2 val) +{ + return float22bf162(__half22float2(val)); +} + +// #endif // ENABLE BF16 + +template <typename To, typename Ti> +__device__ inline To cuda_sum(Ti val) +{ + return cuda_cast<To>(val); +}; + +template <typename To> +__device__ inline To cuda_sum(float2 val) +{ + return cuda_cast<To>(val.x + val.y); +}; + +// Unary maximum: compute the max of a vector type +template <typename To, typename Ti> +__device__ inline To cuda_max(Ti val) +{ + return cuda_cast<To>(val); +}; + +template <> +__device__ inline float cuda_max(float2 val) +{ + return fmaxf(val.x, val.y); +} + +template <> +__device__ inline half cuda_max(half2 val) +{ + return __hmax(val.x, val.y); +} + +// #ifdef ENABLE_BF16 +template <> +__device__ inline __nv_bfloat16 cuda_max(__nv_bfloat162 val) +{ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + return __hmax(val.x, val.y); +#endif +} +#endif + +// Binary maximum: compute the max of two scalar types +template <typename T> +__device__ inline T cuda_max(T val1, T val2) +{ + return (val1 > val2) ? val1 : val2; +} + +template <typename T> +__device__ inline T cuda_abs(T val) +{ + assert(false); + return {}; +} + +template <> +__device__ inline float cuda_abs(float val) +{ + return fabs(val); +} + +template <> +__device__ inline float2 cuda_abs(float2 val) +{ + return make_float2(fabs(val.x), fabs(val.y)); +} + +template <> +__device__ inline half cuda_abs(half val) +{ + return __habs(val); +} + +template <> +__device__ inline half2 cuda_abs(half2 val) +{ + return __habs2(val); +} + +// #ifdef ENABLE_BF16 + +#if __CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__) +template <> +__device__ inline __nv_bfloat16 cuda_abs(__nv_bfloat16 val) +{ + return __habs(val); +} + +template <> +__device__ inline __nv_bfloat162 cuda_abs(__nv_bfloat162 val) +{ + return __habs2(val); +} +#endif \ No newline at end of file diff --git a/examples/CPM.cu/tests/long_prompt_gen.py b/examples/CPM.cu/tests/long_prompt_gen.py new file mode 100644 index 00000000..5eaa6ddf --- /dev/null +++ b/examples/CPM.cu/tests/long_prompt_gen.py @@ -0,0 +1,110 @@ +import random +import string +import os +import glob +import argparse + +def collect_code_files(extensions=('.cpp', '.c', '.h', '.hpp'), + dirs=('src', 'include', 'ggml/src', 'examples', 'tools'), + max_length=500000, + verbose=True): + """ + Collect code file contents from repository + :param extensions: File extensions to collect + :param dirs: Directories to search + :param max_length: Maximum character count + :param verbose: Whether to print detailed information + :return: Concatenated code text and total character count + """ + all_content = [] + total_chars = 0 + file_count = 0 + + # Find all matching file paths + all_files = [] + for directory in dirs: + if not os.path.exists(directory): + if verbose: + print(f"Directory does not exist: {directory}") + continue + + for ext in extensions: + pattern = os.path.join(directory, f'**/*{ext}') + for file_path in glob.glob(pattern, recursive=True): + all_files.append(file_path) + + # Randomly shuffle file order + random.shuffle(all_files) + + # Read file contents + for file_path in all_files: + try: + with open(file_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read() + file_header = f"\n\n----- FILE: {file_path} -----\n\n" + + # Check if adding this file would exceed maximum length limit + remaining_chars = max_length - total_chars if max_length > 0 else float('inf') + + if len(file_header) + len(content) > remaining_chars: + # Need to truncate file + if remaining_chars > len(file_header) + 100: # Ensure at least 100 characters of content + truncated_content = content[:remaining_chars - len(file_header)] + truncated_content += "\n\n... (file content truncated) ..." + + all_content.append(file_header) + all_content.append(truncated_content) + total_chars += len(file_header) + len(truncated_content) + file_count += 1 + + if verbose: + print(f"Added file (truncated): {file_path}, current total chars: {total_chars}, file count: {file_count}") + break # Reached maximum length, exit loop + else: + # Remaining space too small to add meaningful content + if verbose: + print(f"Skipped file: {file_path}, insufficient remaining space") + continue + else: + # Add file completely + all_content.append(file_header) + all_content.append(content) + total_chars += len(file_header) + len(content) + file_count += 1 + + if verbose: + print(f"Added file: {file_path}, current total chars: {total_chars}, file count: {file_count}") + + if max_length > 0 and total_chars >= max_length: + break # Reached maximum length, exit loop + except Exception as e: + if verbose: + print(f"Error reading file {file_path}: {e}") + + return ''.join(all_content), total_chars, file_count + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Generate long code prompt file') + parser.add_argument('--output', '-o', type=str, default="prompt.txt", help='Output file name') + parser.add_argument('--length', '-l', type=int, default=350000, help='Maximum character count') + parser.add_argument('--question', "-q", type=str, default="\nSummarize the code", help='Question to ask') + + args = parser.parse_args() + + # Use default values directly + dirs = ['src', 'cpmcu', 'scripts'] + extensions = ['.cpp', '.c', '.h', '.hpp', '.py'] + + # Collect code files + code_content, content_length, file_count = collect_code_files( + extensions=extensions, + dirs=dirs, + max_length=args.length, + ) + + # Add question + final_content = code_content + f"\n\n{args.question}" + + with open(args.output, "w", encoding="utf-8") as f: + f.write(final_content) + print(f"Generated {args.output}, total length: {len(final_content)} characters, contains {file_count} files") \ No newline at end of file diff --git a/examples/CPM.cu/tests/test_generate.py b/examples/CPM.cu/tests/test_generate.py new file mode 100644 index 00000000..d4012f5c --- /dev/null +++ b/examples/CPM.cu/tests/test_generate.py @@ -0,0 +1,544 @@ +import torch +from cpmcu.llm import LLM +from cpmcu.llm_w4a16_gptq_marlin import W4A16GPTQMarlinLLM +from cpmcu.speculative import LLM_with_eagle +from cpmcu.speculative.eagle_base_quant.eagle_base_w4a16_marlin_gptq import W4A16GPTQMarlinLLM_with_eagle +from transformers import AutoTokenizer +import time +import numpy as np +import argparse +import sys +import os +from huggingface_hub import snapshot_download + +# Default Configuration +default_config = { + 'test_minicpm4': True, + 'use_stream': True, + 'apply_eagle': True, + 'apply_quant': True, + 'apply_sparse': True, + 'apply_eagle_quant': True, + "minicpm4_yarn": True, # TODO default is True for long context test, better implementation + 'frspec_vocab_size': 32768, + 'eagle_window_size': 8 * 128, + 'eagle_num_iter': 2, + 'eagle_topk_per_iter': 10, + 'eagle_tree_size': 12, + 'apply_compress_lse': True, + 'sink_window_size': 1, + 'block_window_size': 8, + 'sparse_topk_k': 64, + "sparse_switch": 1, + 'num_generate': 256, + 'chunk_length': 2048, + 'memory_limit': 0.9, + 'cuda_graph': True, + 'dtype': torch.float16, + 'use_terminators': True, + "temperature": 0.0, + "random_seed": None, +} + +# Demo Configuration: Only for MiniCPM4 demo, will be deleted after release +demo_config = { + 'use_enter': False, + 'use_decode_enter': False, +} + +# Combined Default Configurations +default_config = {**default_config, **demo_config} + +def create_argument_parser(): + """Create and configure argument parser""" + parser = argparse.ArgumentParser(description='Generate text using LLM models') + + # Basic arguments + parser.add_argument('--path-prefix', '--path_prefix', '-p', type=str, default='openbmb', + help='Path prefix for model directories, you can use openbmb to download models, or your own path (default: openbmb)') + + # Prompt arguments + parser.add_argument('--prompt-file', '--prompt_file', type=str, default=None, + help='Path to prompt file (default: None)') + parser.add_argument('--prompt-text', '--prompt_text', type=str, default=None, + help='Direct prompt text (default: None)') + parser.add_argument('--prompt-haystack', '--prompt_haystack', type=int, default=15, + help='Generate haystack prompt with specified length in thousands (e.g., 120 for 120k tokens)') + + # Model configuration boolean arguments + parser.add_argument('--test-minicpm4', '--test_minicpm4', action='store_true', + help='Use MiniCPM4 model') + parser.add_argument('--no-test-minicpm4', '--no_test_minicpm4', action='store_false', dest='test_minicpm4', + help='Do not use MiniCPM4 model') + parser.add_argument('--use-stream', '--use_stream', action='store_true', + help='Use stream generation') + parser.add_argument('--no-use-stream', '--no_use_stream', action='store_false', dest='use_stream', + help='Do not use stream generation') + parser.add_argument('--apply-eagle', '--apply_eagle', action='store_true', + help='Use Eagle speculative decoding') + parser.add_argument('--no-apply-eagle', '--no_apply_eagle', action='store_false', dest='apply_eagle', + help='Do not use Eagle speculative decoding') + parser.add_argument('--apply-quant', '--apply_quant', action='store_true', + help='Use quantized model') + parser.add_argument('--no-apply-quant', '--no_apply_quant', action='store_false', dest='apply_quant', + help='Do not use quantized model') + parser.add_argument('--apply-sparse', '--apply_sparse', action='store_true', + help='Use sparse attention') + parser.add_argument('--no-apply-sparse', '--no_apply_sparse', action='store_false', dest='apply_sparse', + help='Do not use sparse attention') + parser.add_argument('--apply-eagle-quant', '--apply_eagle_quant', action='store_true', + help='Use quantized Eagle model') + parser.add_argument('--no-apply-eagle-quant', '--no_apply_eagle_quant', action='store_false', dest='apply_eagle_quant', + help='Do not use quantized Eagle model') + parser.add_argument('--apply-compress-lse', '--apply_compress_lse', action='store_true', + help='Apply LSE compression, only support on sparse attention, this will compress the stage 1 kv twice for LSE pre-computing') + parser.add_argument('--no-apply-compress-lse', '--no_apply_compress_lse', action='store_false', dest='apply_compress_lse', + help='Do not apply LSE compression') + parser.add_argument('--cuda-graph', '--cuda_graph', action='store_true', + help='Use CUDA graph optimization') + parser.add_argument('--no-cuda-graph', '--no_cuda_graph', action='store_false', dest='cuda_graph', + help='Do not use CUDA graph optimization') + parser.add_argument('--use-teminators', '--use_terminators', action='store_true', + help='Use teminators, if not specified, the generation will not be interrupted') + parser.add_argument('--no-use-teminators', '--no_use_terminators', action='store_false', dest='use_terminators', + help='Do not use teminators') + parser.add_argument('--minicpm4-yarn', '--minicpm4_yarn', action='store_true', + help='Use MiniCPM4 YARN, this is for very long context, such as > 32/64k tokens') + parser.add_argument('--no-minicpm4-yarn', '--no_minicpm4_yarn', action='store_false', dest='minicpm4_yarn', + help='Do not use MiniCPM4 YARN') + + # Model configuration numeric arguments + parser.add_argument('--frspec-vocab-size', '--frspec_vocab_size', type=int, default=None, + help='Frequent speculation vocab size (default: from default_config)') + parser.add_argument('--eagle-window-size', '--eagle_window_size', type=int, default=None, + help='Eagle window size (default: from default_config)') + parser.add_argument('--eagle-num-iter', '--eagle_num_iter', type=int, default=None, + help='Eagle number of iterations (default: from default_config)') + parser.add_argument('--eagle-topk-per-iter', '--eagle_topk_per_iter', type=int, default=None, + help='Eagle top-k per iteration (default: from default_config)') + parser.add_argument('--eagle-tree-size', '--eagle_tree_size', type=int, default=None, + help='Eagle tree size (default: from default_config)') + parser.add_argument('--sink-window-size', '--sink_window_size', type=int, default=None, + help='Sink window size of sparse attention (default: from default_config)') + parser.add_argument('--block-window-size', '--block_window_size', type=int, default=None, + help='Block window size of sparse attention (default: from default_config)') + parser.add_argument('--sparse-topk-k', '--sparse_topk_k', type=int, default=None, + help='Sparse attention top-k (default: from default_config)') + parser.add_argument('--sparse-switch', '--sparse_switch', type=int, default=None, + help='Context length of dense and sparse attention switch (default: from default_config)') + parser.add_argument('--num-generate', '--num_generate', type=int, default=None, + help='Number of tokens to generate (default: from default_config)') + parser.add_argument('--chunk-length', '--chunk_length', type=int, default=None, + help='Chunk length for prefilling (default: from default_config)') + parser.add_argument('--memory-limit', '--memory_limit', type=float, default=None, + help='Memory limit for use (default: from default_config)') + parser.add_argument('--temperature', '--temperature', type=float, default=None, + help='Temperature for processing (default: from default_config)') + parser.add_argument('--dtype', type=str, default=None, choices=['float16', 'bfloat16'], + help='Model dtype (default: from default_config)') + parser.add_argument('--random-seed', '--random_seed', type=int, default=None, + help='Random seed for processing (default: from default_config)') + # Demo arguments + parser.add_argument('--use-enter', '--use_enter', action='store_true', + help='Use enter to generate') + parser.add_argument('--no-use-enter', '--no_use_enter', action='store_false', dest='use_enter', + help='Do not use enter to generate') + parser.add_argument('--use-decode-enter', '--use_decode_enter', action='store_true', + help='Use enter before decode phase') + parser.add_argument('--no-use-decode-enter', '--no_use_decode_enter', action='store_false', dest='use_decode_enter', + help='Do not use enter before decode phase') + + return parser + +def parse_and_merge_config(default_config): + """Parse arguments and merge with default configuration""" + parser = create_argument_parser() + args = parser.parse_args() + + # Set default values to None for boolean arguments that weren't specified + bool_args = [key for key, value in default_config.items() if isinstance(value, bool)] + for arg in bool_args: + # Convert underscores to hyphens for command line argument names + arg_hyphen = arg.replace('_', '-') + # Check for both formats (hyphen and underscore) + arg_specified = (f'--{arg_hyphen}' in sys.argv or f'--no-{arg_hyphen}' in sys.argv or + f'--{arg}' in sys.argv or f'--no-{arg}' in sys.argv) + if not arg_specified: + setattr(args, arg, None) + + # Override default config with command line arguments + config = default_config.copy() + + # Define parameter mappings for automatic override (exclude dtype which needs special handling) + auto_override_params = [key for key in default_config.keys() if key != 'dtype'] + + # Override config values if arguments are provided + for param in auto_override_params: + arg_value = getattr(args, param) + if arg_value is not None: + config[param] = arg_value + + # Handle dtype separately due to type conversion + if args.dtype is not None: + config['dtype'] = torch.float16 if args.dtype == 'float16' else torch.bfloat16 + + return args, config + +def check_or_download_model(path): + if os.path.exists(path): + return path + else: + cache_dir = snapshot_download(path) + return cache_dir + +def get_model_paths(path_prefix, config): + """Get model paths based on configuration""" + if config['test_minicpm4']: + if config['apply_eagle_quant']: + eagle_repo_id = f"{path_prefix}/MiniCPM4-8B-Eagle-FRSpec-QAT-cpmcu" + else: + eagle_repo_id = f"{path_prefix}/MiniCPM4-8B-Eagle-FRSpec" + else: + eagle_repo_id = f"{path_prefix}/EAGLE-LLaMA3-Instruct-8B" + + if not config['apply_quant']: + if config['test_minicpm4']: + base_repo_id = f"{path_prefix}/MiniCPM4-8B" + else: + base_repo_id = f"{path_prefix}/Meta-Llama-3-8B-Instruct" + else: + base_repo_id = f"{path_prefix}/MiniCPM4-8B-marlin-cpmcu" + + eagle_path = check_or_download_model(eagle_repo_id) + base_path = check_or_download_model(base_repo_id) + + return eagle_path, base_path, eagle_repo_id, base_repo_id + +def apply_minicpm4_yarn_config(llm, config): + """Apply MiniCPM4 YARN configuration to model config""" + yarn_factors = [ + 0.9977997200264581, 1.014658295992452, 1.0349680404997148, 1.059429246056193, + 1.0888815016813513, 1.1243301355211495, 1.166977103606075, 1.2182568066927284, + 1.2798772354275727, 1.3538666751582975, 1.4426259039919596, 1.5489853358570191, + 1.6762658237220625, 1.8283407612492941, 2.0096956085876183, 2.225478927469756, + 2.481536379650452, 2.784415934557119, 3.1413289096347365, 3.560047844772632, + 4.048719380066383, 4.752651957515948, 5.590913044973868, 6.584005926629993, + 7.7532214876576155, 9.119754865903639, 10.704443927019176, 12.524994176518703, + 14.59739595363613, 16.93214476166354, 19.53823297353041, 22.417131025031697, + 25.568260840911098, 28.991144156566317, 32.68408069090375, 36.65174474170465, + 40.90396065611201, 45.4664008671033, 50.37147343433591, 55.6804490772103, + 61.470816952306556, 67.8622707390618, 75.00516023410414, 83.11898235973767, + 92.50044360202462, 103.57086856690864, 116.9492274587385, 118.16074567836519, + 119.18497548708795, 120.04810876261652, 120.77352815196981, 121.38182790207875, + 121.89094985353891, 122.31638758099915, 122.6714244963338, 122.9673822552567, + 123.21386397019609, 123.41898278254268, 123.58957065488238, 123.73136519024158, + 123.84917421274221, 123.94701903496814, 124.02825801299717, 124.09569231686116 + ] + + # Create or modify rope_scaling configuration + if not hasattr(llm.config, 'rope_scaling') or llm.config.rope_scaling is None: + llm.config.rope_scaling = {} + + llm.config.rope_scaling['rope_type'] = 'longrope' + llm.config.rope_scaling['long_factor'] = yarn_factors + llm.config.rope_scaling['short_factor'] = yarn_factors + print("Forcing MiniCPM4 YARN rope_scaling parameters") + +def create_model(eagle_path, base_path, config): + """Create model instance based on configuration""" + common_kwargs = { + 'dtype': config['dtype'], + 'chunk_length': config['chunk_length'], + 'cuda_graph': config['cuda_graph'], + 'apply_sparse': config['apply_sparse'], + 'sink_window_size': config['sink_window_size'], + 'block_window_size': config['block_window_size'], + 'sparse_topk_k': config['sparse_topk_k'], + 'sparse_switch': config['sparse_switch'], + 'apply_compress_lse': config['apply_compress_lse'], + 'memory_limit': config['memory_limit'], + 'use_enter': config['use_enter'], + 'use_decode_enter': config['use_decode_enter'], + 'temperature': config['temperature'], + 'random_seed': config['random_seed'], + } + + eagle_kwargs = { + 'num_iter': config['eagle_num_iter'], + 'topk_per_iter': config['eagle_topk_per_iter'], + 'tree_size': config['eagle_tree_size'], + 'eagle_window_size': config['eagle_window_size'], + 'frspec_vocab_size': config['frspec_vocab_size'], + 'apply_eagle_quant': config['apply_eagle_quant'], + 'use_rope': config['test_minicpm4'], + 'use_input_norm': config['test_minicpm4'], + 'use_attn_norm': config['test_minicpm4'] + } + + if config['apply_quant']: + if config['apply_eagle']: + return W4A16GPTQMarlinLLM_with_eagle(eagle_path, base_path, **common_kwargs, **eagle_kwargs) + else: + return W4A16GPTQMarlinLLM(base_path, **common_kwargs) + else: + if config['apply_eagle']: + return LLM_with_eagle(eagle_path, base_path, **common_kwargs, **eagle_kwargs) + else: + return LLM(base_path, **common_kwargs) + +def make_input(tokenizer, args, prompt_content=None): + """Prepare input tokens from prompt content or file""" + + def make_haystack_prompt(digits, target_length_k): + """Generate haystack prompt with pass key hidden in context""" + # Simple calculation based on target length + a = target_length_k * 16 # Scale factor for before text + b = target_length_k * 33 # Scale factor for after text + + head = "There is a pass key hidden in the context. Find it and remember it. I will quiz you about it later. " + before = "The sky is blue. The tree is green. The flower is red. The sun is yellow. " * a + needle = f"The pass key is {digits}. Remember it. The pass key is {digits}" + after = "The sky is blue. The tree is green. The flower is red. The sun is yellow. " * b + query = "Now, give me the exact number of the pass key. The pass key is " + return head + before + needle + after + query + + if prompt_content is None: + # Check if file or text was specified + file_specified = args.prompt_file is not None + text_specified = args.prompt_text is not None + + if not file_specified and not text_specified: + # Case 1: Neither file nor text specified, use haystack with default value + print(f"Using haystack prompt with {args.prompt_haystack}k tokens (default)") + prompt_content = make_haystack_prompt(681725493, args.prompt_haystack) + else: + # Case 2 & 3: At least one of file or text specified, ignore haystack + prompt_content = "" + + # Load from file if specified + if file_specified: + try: + with open(args.prompt_file, 'r', encoding='utf-8') as f: + file_content = f.read().strip() + prompt_content += file_content + print(f"Loaded prompt from file: {args.prompt_file}") + except FileNotFoundError: + print(f"Warning: {args.prompt_file} not found, skipping file content") + except Exception as e: + print(f"Error reading {args.prompt_file}: {e}, skipping file content") + + # Append text if specified + if text_specified: + if file_specified and prompt_content: + # Case 3: Both specified, append text to file content + prompt_content += "\n" + args.prompt_text + print(f"Appended prompt text to file content") + else: + # Case 2: Only text specified + prompt_content = args.prompt_text + print(f"Using direct prompt text input") + + # Fallback if no content was loaded + if not prompt_content: + print(f"No valid content found, using default Chinese prompt") + prompt_content = "北京有哪些好玩的地方" + + if config['test_minicpm4'] and not file_specified and not text_specified: # TODO: haystack need w/o chat template, may be a bug + prompt = prompt_content + else: + prompt = tokenizer.apply_chat_template([{"role": "user", "content": prompt_content}], tokenize=False, add_generation_prompt=True) + input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda().int() + + print(f"Input token count: {input_ids.shape[1]}") + if input_ids.shape[1] <= 100: # Only show input_ids for short prompts + print(f"Input_ids: {input_ids}") + + return input_ids + +def print_generation_summary(mode, prefill_stats, decode_stats, config): + """Print unified generation summary for both modes""" + print("\n" + "=" * 50) + print(f"{mode} Generation Summary:") + print("=" * 50) + + # Prefill statistics + print(f"Prefill length: {prefill_stats['length']}") + print(f"Prefill time: {prefill_stats['time']:.2f} s") + print(f"Prefill tokens/s: {prefill_stats['tokens_per_sec']:.2f}") + + # Eagle-specific statistics + if config['apply_eagle'] and 'mean_accept_length' in decode_stats: + print(f"Mean accept length: {decode_stats['mean_accept_length']:.2f}") + # print(f"Decode token/s when acc = 1: {decode_stats['tokens_per_sec'] / decode_stats['mean_accept_length']:.2f}") + + # Decode statistics + print(f"Decode length: {decode_stats['length']}") + print(f"Decode time: {decode_stats['time']:.2f} s") + print(f"Decode tokens/s: {decode_stats['tokens_per_sec']:.2f}") + +def run_stream_generation(llm, input_ids, config, teminators, tokenizer): + """Run streaming generation and display results""" + print("\nGenerated text (streaming output):") + print("-" * 50) + + # Statistics tracking + prefill_length = input_ids.shape[1] + prefill_time = 0.0 + total_decode_time = 0.0 + + generated_text = "" + total_decode_tokens = 0 + accept_lengths = [] + + try: + for result in llm.generate(input_ids, config['num_generate'], teminators=teminators, use_stream=True): + token = result['token'] + text = result['text'] + is_finished = result['is_finished'] + + # Track timing statistics + if 'prefill_time' in result and result['prefill_time'] > 0: + prefill_time = result['prefill_time'] + if 'decode_time' in result and result['decode_time'] > 0: + total_decode_time = result['decode_time'] + + generated_text += text + total_decode_tokens += 1 + + # Track accept lengths for eagle models + if 'accept_length' in result and result['accept_length'] > 0: + accept_lengths.append(result['accept_length']) + + print(text, end='', flush=True) + + if is_finished: + break + + except KeyboardInterrupt: + print("\n\nGeneration interrupted by user.") + + prefill_stats = { + 'length': prefill_length, + 'time': prefill_time, + 'tokens_per_sec': prefill_length / prefill_time if prefill_time > 0 else 0 + } + + decode_stats = { + 'length': total_decode_tokens, + 'time': total_decode_time, + 'tokens_per_sec': total_decode_tokens / total_decode_time if total_decode_time > 0 else 0 + } + + if config['apply_eagle'] and accept_lengths: + decode_stats['mean_accept_length'] = np.mean(accept_lengths) + + print_generation_summary("Stream", prefill_stats, decode_stats, config) + +def run_non_stream_generation(llm, input_ids, config, teminators, tokenizer): + """Run non-stream generation and display results""" + prefill_length = input_ids.shape[1] + + torch.cuda.synchronize() + start_time = time.time() + gen_result = llm.generate(input_ids, config['num_generate'], teminators=teminators, use_stream=False) + torch.cuda.synchronize() + end_time = time.time() + + # Handle different return formats based on model type + if config['apply_eagle']: + # Eagle models return: (tokens, accept_lengths, decode_time, prefill_time) + tokens, accept_lengths, decode_time, prefill_time = gen_result + decode_length = len(tokens) + mean_accept_length = np.mean(accept_lengths) + else: + # Base models return: (tokens, decode_time, prefill_time) + tokens, decode_time, prefill_time = gen_result + decode_length = len(tokens) + mean_accept_length = None + + print("\n[Generated Result]") + print(tokenizer.decode(tokens).strip()) + print("\n") + + prefill_stats = { + 'length': prefill_length, + 'time': prefill_time, + 'tokens_per_sec': prefill_length / prefill_time if prefill_time > 0 else 0 + } + + decode_stats = { + 'length': decode_length, + 'time': decode_time, + 'tokens_per_sec': decode_length / decode_time if decode_time > 0 else 0 + } + + if mean_accept_length is not None: + decode_stats['mean_accept_length'] = mean_accept_length + + print_generation_summary("Non-Stream", prefill_stats, decode_stats, config) + +def print_config(config, use_stream): + """Print all configuration parameters""" + print("=" * 50) + print("Configuration Parameters:") + print("=" * 50) + print(f"Features: eagle={config['apply_eagle']}, quant={config['apply_quant']}, sparse={config['apply_sparse']}") + print(f"Generation: num_generate={config['num_generate']}, chunk_length={config['chunk_length']}, use_terminators={config['use_terminators']}, use_stream={config['use_stream']}") + print(f"Sampling: temperature={config['temperature']}, random_seed={config['random_seed']}") + print(f"Demo: use_enter={config['use_enter']}, use_decode_enter={config['use_decode_enter']}") + print(f"Others: dtype={config['dtype']}, minicpm4_yarn={config['minicpm4_yarn']}, cuda_graph={config['cuda_graph']}, memory_limit={config['memory_limit']}") + if config['apply_sparse']: + print(f"Sparse Attention: sink_window={config['sink_window_size']}, block_window={config['block_window_size']}, sparse_topk_k={config['sparse_topk_k']}, sparse_switch={config['sparse_switch']}, compress_lse={config['apply_compress_lse']}") + if config['apply_eagle']: + print(f"Eagle: eagle_num_iter={config['eagle_num_iter']}, eagle_topk_per_iter={config['eagle_topk_per_iter']}, eagle_tree_size={config['eagle_tree_size']}, apply_eagle_quant={config['apply_eagle_quant']}, window_size={config['eagle_window_size']}, frspec_vocab_size={config['frspec_vocab_size']}") + print("=" * 50) + print() + +def main(args, config): + if not config['test_minicpm4']: + print(f"test_minicpm4 is False, set apply_sparse to False") + config['apply_sparse'] = False + + print_config(config, config['use_stream']) + + # Get model paths and create model + eagle_path, base_path, eagle_repo_id, base_repo_id = get_model_paths(args.path_prefix, config) + tokenizer = AutoTokenizer.from_pretrained(base_path, trust_remote_code=True) + llm = create_model(eagle_path, base_path, config) + + # Prepare input + input_ids = make_input(tokenizer, args) + teminators = [] if not config['use_terminators'] else [tokenizer.eos_token_id] + + # Initialize model + llm.init_storage() + + # Apply MiniCPM4 YARN configuration if enabled + if config['test_minicpm4'] and config['minicpm4_yarn']: + apply_minicpm4_yarn_config(llm, config) + + if config['apply_eagle'] and config['frspec_vocab_size'] > 0: + fr_path = f'{eagle_path}/freq_{config["frspec_vocab_size"]}.pt' + if not os.path.exists(fr_path): + cache_dir = snapshot_download( + eagle_repo_id, + ignore_patterns=["*.bin", "*.safetensors"], + ) + fr_path = os.path.join(cache_dir, f'freq_{config["frspec_vocab_size"]}.pt') + + with open(fr_path, 'rb') as f: + token_id_remap = torch.tensor(torch.load(f, weights_only=True), dtype=torch.int32, device="cpu") + llm._load("token_id_remap", token_id_remap, cls="eagle") + llm.load_from_hf() + + # Run generation + if config['use_stream']: + run_stream_generation(llm, input_ids, config, teminators, tokenizer) + else: + run_non_stream_generation(llm, input_ids, config, teminators, tokenizer) + + llm.print_perf_summary() + +if __name__ == "__main__": + args, config = parse_and_merge_config(default_config) + main(args, config)