diff --git a/examples/BMTrain/.dockerignore b/examples/BMTrain/.dockerignore
new file mode 100644
index 00000000..c1591543
--- /dev/null
+++ b/examples/BMTrain/.dockerignore
@@ -0,0 +1,147 @@
+**/__pycache__/
+**/*.py[cod]
+**/*$py.class
+
+# C extensions
+**/*.so
+
+# Distribution / packaging
+**/.Python
+**/build/
+**/develop-eggs/
+**/dist/
+**/downloads/
+**/eggs/
+**/.eggs/
+**/lib/
+**/lib64/
+**/parts/
+**/sdist/
+**/var/
+**/wheels/
+**/share/python-wheels/
+**/*.egg-info/
+**/.installed.cfg
+**/*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+**/*.manifest
+**/*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+**/htmlcov/
+**/.tox/
+**/.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+**/*.cover
+**/*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+**/.pyre/
+
+# pytype static type analyzer
+**/.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+**/*.pt
+
+**/*.npy
+
+**/.DS_Store
+
+**/log
+**/*.qdrep
+!bmtrain/dist
\ No newline at end of file
diff --git a/examples/BMTrain/.github/ISSUE_TEMPLATE/bug_report.yml b/examples/BMTrain/.github/ISSUE_TEMPLATE/bug_report.yml
new file mode 100644
index 00000000..47c89fe6
--- /dev/null
+++ b/examples/BMTrain/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -0,0 +1,79 @@
+
+name: 🐞 Bug Report
+description: Report a bug/issue related to the PyTorch-based parallel model training toolkit
+title: "[BUG]
"
+labels: ["bug"]
+body:
+- type: checkboxes
+ attributes:
+ label: Is there an existing issue for this?
+ description: Please search to see if an issue already exists for the bug you encountered.
+ options:
+ - label: I have searched the existing issues
+ required: true
+- type: textarea
+ attributes:
+ label: Description of the Bug
+ description: Provide a clear and concise description of what the bug is.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Environment Information
+ description: |
+ Provide details about your environment.
+ Example:
+ - GCC version: 9.3.0
+ - Torch version: 1.9.0
+ - Linux system version: Ubuntu 20.04
+ - CUDA version: 11.4
+ - Torch's CUDA version (as per `torch.version.cuda`): 11.3
+ value: |
+ - GCC version:
+ - Torch version:
+ - Linux system version:
+ - CUDA version:
+ - Torch's CUDA version (as per `torch.version.cuda`):
+ render: markdown
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: To Reproduce
+ description: Provide the steps and details to reproduce the behavior.
+ placeholder: |
+ 1. Describe your environment setup, including any specific version requirements.
+ 2. Clearly state the steps you took to trigger the error, including the specific code you executed.
+ 3. Identify the file and line number where the error occurred, along with the full traceback of the error. Make sure to set `NCCL_DEBUG=INFO` and `CUDA_LAUNCH_BLOCKING=1` to get accurate debug information.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Expected Behavior
+ description: Describe what you expected to happen when you executed the code.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Screenshots
+ description: If applicable, please add screenshots to help explain your problem.
+ validations:
+ required: false
+- type: textarea
+ attributes:
+ label: Additional Information
+ description: |
+ Provide any other relevant context or information about the problem here.
+ Links? References? Anything that will give us more context about the issue you are encountering!
+ Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in.
+ validations:
+ required: false
+- type: checkboxes
+ attributes:
+ label: Confirmation
+ description: Please confirm that you have reviewed all of the above requirements and verified the information provided before submitting this issue.
+ options:
+ - label: I have reviewed and verified all the information provided in this report.
+ required: true
+
diff --git a/examples/BMTrain/.github/ISSUE_TEMPLATE/build_err.yml b/examples/BMTrain/.github/ISSUE_TEMPLATE/build_err.yml
new file mode 100644
index 00000000..940fd7ed
--- /dev/null
+++ b/examples/BMTrain/.github/ISSUE_TEMPLATE/build_err.yml
@@ -0,0 +1,94 @@
+name: 🛠️ Build Error
+description: Report a build error for this project
+title: "[BUILD ERROR] "
+labels: ["Build ERR"]
+body:
+- type: checkboxes
+ id: prev_issue
+ attributes:
+ label: Is there an existing issue for this?
+ description: Please search to see if an issue already exists for the build error you encountered.
+ options:
+ - label: I have searched the existing issues
+ required: true
+- type: textarea
+ attributes:
+ label: Description of the Build Error
+ description: Provide a clear and concise description of what the build error is.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Expected Behavior
+ description: Provide a clear and concise description of what you expected to happen.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: To Reproduce
+ description: Describe the steps you took to trigger the build error. Include any commands you executed or files you modified.
+ placeholder: |
+ 1. Go to '...'
+ 2. Click on '....'
+ 3. Scroll down to '....'
+ 4. See error
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Environment Information
+ description: |
+ Provide details about your environment.
+ Example:
+ - Operating System version: Ubuntu 20.04
+ - GCC version: 9.3.0
+ - Pybind version: 2.8.1
+ - CUDA version: 11.4
+ - NVIDIA NCCL CU11 version: 2.14.3
+ - CMake version: 3.21.2
+ - Pip version: 22.0.0
+ value: |
+ - Operating System version:
+ - GCC version:
+ - Pybind version:
+ - CUDA version:
+ - NVIDIA NCCL CU11 version:
+ - CMake version:
+ - Pip version:
+ render: markdown
+ validations:
+ required: true
+- type: dropdown
+ attributes:
+ label: Installation Method
+ description: Please indicate if the error occurred during source code installation or when using the pip install .whl method.
+ options:
+ - Source Code Installation
+ - Pip Install .whl Method
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Full Error Traceback
+ description: Provide the complete error traceback.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Additional Information
+ description: |
+ Provide any other relevant context or information about the problem here.
+ Links? References? Anything that will give us more context about the issue you are encountering!
+ Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in.
+ validations:
+ required: false
+- type: checkboxes
+ id: confirm
+ attributes:
+ label: Confirmation
+ description: Please confirm that you have reviewed all of the above requirements and verified the information provided before submitting this report.
+ options:
+ - label: I have reviewed and verified all the information provided in this report.
+ required: true
+
diff --git a/examples/BMTrain/.github/ISSUE_TEMPLATE/features_request.yml b/examples/BMTrain/.github/ISSUE_TEMPLATE/features_request.yml
new file mode 100644
index 00000000..769948c2
--- /dev/null
+++ b/examples/BMTrain/.github/ISSUE_TEMPLATE/features_request.yml
@@ -0,0 +1,30 @@
+name: 🚀 Feature Request
+description: Suggest an idea for this project
+title: "[Feature] "
+labels: ["enhancement"]
+assignees: []
+body:
+- type: textarea
+ attributes:
+ label: Is your feature request related to a problem? Please describe.
+ description: "A clear and concise description of what the problem is. Example: I'm always frustrated when..."
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Describe the solution you'd like
+ description: "A clear and concise description of what you want to happen."
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Describe alternatives you've considered
+ description: "A clear and concise description of any alternative solutions or features you've considered."
+ validations:
+ required: false
+- type: textarea
+ attributes:
+ label: Additional context
+ description: "Add any other context or screenshots about the feature request here."
+ validations:
+ required: false
diff --git a/examples/BMTrain/.github/pull_request_template.md b/examples/BMTrain/.github/pull_request_template.md
new file mode 100644
index 00000000..87a5f9b0
--- /dev/null
+++ b/examples/BMTrain/.github/pull_request_template.md
@@ -0,0 +1,29 @@
+## Pull Request Template
+
+### Issue Reference
+Please mention the issue number if applicable, or write "N/A" if it's a new feature.
+
+Issue #...
+
+### Description
+Please describe your changes in detail. If it resolves an issue, please state how it resolves it.
+
+### Type of Change
+- [ ] Bug fix (non-breaking change which fixes an issue)
+- [ ] New feature (non-breaking change which adds functionality)
+- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
+- [ ] This change requires a documentation update
+
+### How Has This Been Tested?
+Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
+
+### Checklist
+- [ ] I have read the [CONTRIBUTING](../../CONTRIBUTING.md) document.
+- [ ] My code follows the code style of this project.
+- [ ] My change requires a change to the documentation.
+- [ ] I have updated the documentation accordingly.
+- [ ] I have added tests to cover my changes.
+- [ ] All new and existing tests passed.
+
+### Additional Information
+Any additional information, configuration, or data that might be necessary for the review.
diff --git a/examples/BMTrain/.github/workflows/build.yml b/examples/BMTrain/.github/workflows/build.yml
new file mode 100644
index 00000000..11aa61f6
--- /dev/null
+++ b/examples/BMTrain/.github/workflows/build.yml
@@ -0,0 +1,35 @@
+name: Build
+
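+# pull_request_target runs in the context of the base repository, so the reusable
+# build workflow called below can use the repository's DockerHub secrets.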
+on:
+ pull_request_target:
+ types: [opened, reopened, synchronize]
+ branches:
+ - 'dev'
+ - 'main'
+ push:
+ branches:
+ - 'dev'
+
+jobs:
+ build-archive-wheel:
+
+ uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main
+ secrets: inherit
+
+ fake-publish:
+ needs: build-archive-wheel
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: Download distribution files
+ uses: actions/download-artifact@v4
+ with:
+ name: dist
+ path: dist
diff --git a/examples/BMTrain/.github/workflows/build_whl.yml b/examples/BMTrain/.github/workflows/build_whl.yml
new file mode 100644
index 00000000..9116b598
--- /dev/null
+++ b/examples/BMTrain/.github/workflows/build_whl.yml
@@ -0,0 +1,89 @@
+name: Build wheels in docker and archive
+
+on:
+ workflow_call:
+ secrets:
+ DOCKERHUB_TOKEN:
+ required: true
+ DOCKERHUB_USERNAME:
+ required: true
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ['37', '38', '39', '310', '311']
+
+
+ steps:
+
+ - name: Check the disk space and clear unnecessary library
+ run: |
+ rm -rf /home/runner/work/BMTrain/BMTrain/dist
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf /opt/ghc
+ sudo rm -rf "/usr/local/share/boost"
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+ df -hl
+
+ - name: Checkout code
+ uses: actions/checkout@v3
+
+ - name: Login to DockerHub
+ uses: docker/login-action@v2
+ with:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+ - name: Pull Docker image
+ run: docker pull pytorch/manylinux-cuda113:latest
+
+ - name: Run Docker image and execute script
+ run: |
+ version=${{ matrix.python-version }}
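+          # build the wheel with "python -m build" inside the manylinux-cuda113 image for this Python version,
+          # then rename *-linux_x86_64.whl files to the manylinux2014_x86_64 platform tag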
+ docker run -e BUILD_DOCKER_ENV=1 -e CUDACXX=/usr/local/cuda-11.3/bin/nvcc -e PATH="/opt/rh/devtoolset-9/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-9/root/usr/lib64:/opt/rh/devtoolset-9/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i pytorch/manylinux-cuda113:latest /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/pip install build; /opt/python/cp${version}*/bin/python -m build .;for file in dist/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done"
+
+ - name: Upload wheels as artifacts
+ uses: actions/upload-artifact@v4
+ with:
+ name: wheels_py${{ matrix.python-version }}
+ path: dist/*.whl
+
+ - name: Upload source distribution (only once)
+ if: matrix.python-version == '37' # Only upload source distribution once
+ uses: actions/upload-artifact@v4
+ with:
+ name: source_dist
+ path: dist/*.tar.gz
+
+ archive:
+ runs-on: ubuntu-latest
+ needs: build
+ steps:
+ - name: Download all wheels
+ uses: actions/download-artifact@v4
+ with:
+ path: wheels
+ pattern: wheels_py*
+
+ - name: Download source distribution
+ uses: actions/download-artifact@v4
+ with:
+ path: source_dist
+ name: source_dist
+
+ - name: Combine all wheels into a single directory
+ run: |
+ mkdir -p dist
+ find wheels -name '*.whl' -exec mv {} dist/ \;
+ find source_dist -name '*.tar.gz' -exec mv {} dist/ \;
+
+ - name: Archive distribution files
+ uses: actions/upload-artifact@v4
+ with:
+ name: dist
+ path: |
+ dist/*.tar.gz
+ dist/*.whl
+ overwrite: true
\ No newline at end of file
diff --git a/examples/BMTrain/.github/workflows/publish.yaml b/examples/BMTrain/.github/workflows/publish.yaml
new file mode 100644
index 00000000..fd9b8c50
--- /dev/null
+++ b/examples/BMTrain/.github/workflows/publish.yaml
@@ -0,0 +1,41 @@
+name: Build and Publish to PyPI
+
+on:
+ push:
+ tags:
+
+ - "v*.*.*"
+
+jobs:
+
+ build-archive-wheel:
+ uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main
+ secrets:
+ DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
+ DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
+
+ publish:
+ needs: build-archive-wheel
+ runs-on: ubuntu-latest
+ steps:
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: Install twine
+ run: python -m pip install twine
+
+ - name: Download distribution files
+ uses: actions/download-artifact@v4
+ with:
+ name: dist
+ path: dist
+
+ - name: Publish to PyPI
+ env:
+ TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+ TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+ run: |
+ cd dist
+ python -m twine upload *.tar.gz *.whl
diff --git a/examples/BMTrain/.github/workflows/release.yml b/examples/BMTrain/.github/workflows/release.yml
new file mode 100644
index 00000000..bafc2173
--- /dev/null
+++ b/examples/BMTrain/.github/workflows/release.yml
@@ -0,0 +1,44 @@
+name: Publish release in Github
+
+on:
+ push:
+ tags:
+ - "v*.*.*"
+
+jobs:
+
+ build-archive-wheel:
+
+ uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main
+ secrets:
+ DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
+ DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
+
+ publish:
+ needs: build-archive-wheel
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: Download distribution files
+ uses: actions/download-artifact@v4
+ with:
+ name: dist
+ path: dist
+
+ - name: Upload Distribution Files to Existing Release
+ uses: softprops/action-gh-release@v1
+ with:
+ files: |
+ dist/*.tar.gz
+ dist/*.whl
+ tag_name: ${{ github.ref_name }} # use the tag that triggered this workflow
+ token: ${{ secrets.RELEASE_TOKEN }}
+ env:
+ GITHUB_REPOSITORY: OpenBMB/BMTrain
diff --git a/examples/BMTrain/.gitignore b/examples/BMTrain/.gitignore
new file mode 100644
index 00000000..75138102
--- /dev/null
+++ b/examples/BMTrain/.gitignore
@@ -0,0 +1,155 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+*.pt
+
+*.npy
+
+bminference/version.py
+
+.DS_Store
+
+log
+*.qdrep
+.vscode
+
+!bmtrain/dist
+tests/test_log.txt
+tests/*.opt
+tests/*.ckp
\ No newline at end of file
diff --git a/examples/BMTrain/CMakeLists.txt b/examples/BMTrain/CMakeLists.txt
new file mode 100644
index 00000000..e027e7da
--- /dev/null
+++ b/examples/BMTrain/CMakeLists.txt
@@ -0,0 +1,65 @@
+cmake_minimum_required(VERSION 3.18)
+project(bmtrain)
+enable_language(C)
+enable_language(CXX)
+set(CMAKE_CUDA_ARCHITECTURES "61;62;70;72;75;80")
+enable_language(CUDA)
+set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD_REQUIRED True)
+set(CMAKE_CUDA_STANDARD 14)
+set(CMAKE_CUDA_STANDARD_REQUIRED True)
+
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_62,code=sm_62 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_72,code=sm_72 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80")
+
+if(NOT DEFINED ENV{BUILD_DOCKER_ENV} OR "$ENV{BUILD_DOCKER_ENV}" STREQUAL "0")
+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_86,code=sm_86")
+ set(AVX_FLAGS "${AVX_FLAGS} -march=native")
+else()
+ message("Building in docker environment, skipping compute_86 and enable all avx flag")
+ set(AVX_FLAGS "${AVX_FLAGS} -mavx -mfma -mf16c -mavx512f")
+endif()
+
+set(CMAKE_BUILD_RPATH $ORIGIN)
+set(CMAKE_INSTALL_RPATH $ORIGIN)
+set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/)
+
+find_package(NCCL REQUIRED)
+find_package(Python ${PYTHON_VERSION} EXACT COMPONENTS Interpreter Development.Module REQUIRED)
+message (STATUS "Python_EXECUTABLE: ${Python_EXECUTABLE}")
+execute_process(COMMAND ${Python_EXECUTABLE} "-c"
+ "import pybind11; print(pybind11.get_cmake_dir())"
+ OUTPUT_VARIABLE PYBIND11_CMAKE_DIR
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+message (STATUS "PYBIND11_CMAKE_DIR: ${PYB
+IND11_CMAKE_DIR}")
+list(APPEND CMAKE_PREFIX_PATH ${PYBIND11_CMAKE_DIR})
+find_package(pybind11 REQUIRED)
+
+message (STATUS "CMAKE_INSTALL_RPATH: ${CMAKE_INSTALL_RPATH}")
+
+file(GLOB_RECURSE SOURCES "csrc/*.cpp")
+file(GLOB_RECURSE CUDA_SOURCES "csrc/cuda/*.cu")
+
+
+pybind11_add_module(C ${SOURCES} ${CUDA_SOURCES})
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${AVX_FLAGS}")
+
+target_link_libraries(C PRIVATE
+ "-Wl,-Bsymbolic"
+ "-Wl,-Bsymbolic-functions"
+ ${NCCL_LIBRARIES}
+)
+target_include_directories(C PRIVATE ${NCCL_INCLUDE_DIRS})
+target_compile_definitions(C
+ PRIVATE VERSION_INFO=${EXAMPLE_VERSION_INFO})
+
+set_target_properties(C PROPERTIES CUDA_ARCHITECTURES "61;62;70;72;75;80")
+
+target_include_directories(C
+ PRIVATE "csrc/include"
+ PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
+)
+
+
+
diff --git a/examples/BMTrain/CONTRIBUTING.md b/examples/BMTrain/CONTRIBUTING.md
new file mode 100644
index 00000000..e1c4458f
--- /dev/null
+++ b/examples/BMTrain/CONTRIBUTING.md
@@ -0,0 +1,55 @@
+# Contributing to BMTrain
+
+We welcome everyone's efforts to make the community and the package better. You are welcome to open an issue, make a pull request, or help others in the community. All contributions are appreciated!
+
+There are many ways that you can contribute to BMTrain:
+
+- ✉️ Submitting an issue.
+- ⌨️ Making a pull request.
+- 🤝 Serving the community.
+
+## Submitting an issue
+You can submit an issue if you find bugs or require new features and enhancements. Here are some principles:
+
+1. **Language.** It is better to write your issue in English so that more people can understand and help you.
+2. **Search.** It is a good habit to search existing issues using GitHub's search bar. Make sure there are no duplicate or similar issues; if there are, check their solutions first.
+3. **Format.** It is also very helpful to write the issue in a good style. We provide templates for common types of issues and everyone is encouraged to use them. If the templates do not fit your issue, feel free to open a blank one.
+4. **Writing style.** Write your issue in clear and concise language. It is also important to provide enough detail for others to help. For example, in a bug report, it is better to provide your running environment and a minimal code snippet that reproduces the problem.
+
+## Making a pull request (PR)
+You can also contribute code, such as a bug fix, a new enhancement, or a new running example. Here are the steps to make a pull request:
+
+1. **Link the PR to an issue.** Let us and others know what you are going to work on. If your code addresses an existing issue, comment on that issue and make sure no one else is working on it. If you are proposing a new enhancement, submit an issue first so we can discuss it with you before you start.
+
+2. **Fork the repository.** Fork the repository to your own GitHub space by clicking the "Fork" button. Then clone it on your disk and set the remote repo:
+```git
+$ git clone https://github.com/<your-username>/BMTrain.git
+$ cd BMTrain
+$ git remote add upstream https://github.com/OpenBMB/BMTrain.git
+```
+
+3. **Write your code.** Switch to a new branch to work on your changes.
+```git
+$ git checkout -b your-branch-name
+```
+You are encouraged to choose a meaningful and descriptive name for your branch.
+
+4. **Make a pull request.** After you finish coding, first rebase your branch and resolve any conflicts with the upstream code:
+```git
+$ git fetch upstream
+$ git rebase upstream/main
+```
+Then push your code to your own repo:
+```git
+$ git push -u origin your-branch-name
+```
+Finally, open the pull request from your GitHub repo. Your code will be merged into the main repo after our code review.
+
+
+## Serving the community
+
+Besides submitting issues and PRs, you can also join our community and help others. Efforts such as writing documentation, answering questions, and discussing new features are appreciated and welcome. It also helps if you share your experience of using our package on social media.
+
+We are now developing a reward system and all your contributions will be recorded and rewarded in the future.
+
+
diff --git a/examples/BMTrain/Dockerfile b/examples/BMTrain/Dockerfile
new file mode 100644
index 00000000..8e6cbddf
--- /dev/null
+++ b/examples/BMTrain/Dockerfile
@@ -0,0 +1,20 @@
+FROM nvidia/cuda:10.2-devel
+WORKDIR /build
+RUN apt update && apt install -y --no-install-recommends \
+ build-essential \
+ python3-dev \
+ python3-pip \
+ python3-setuptools \
+ python3-wheel
+RUN pip3 install torch==1.10.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
+RUN pip3 install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
+RUN apt install iputils-ping opensm libopensm-dev libibverbs1 libibverbs-dev -y --no-install-recommends
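+# Limit the CUDA architectures built for the extension and set the BMT_AVX512 build flag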
+ENV TORCH_CUDA_ARCH_LIST=6.1;7.0;7.5
+ENV BMT_AVX512=1
+ADD other_requirements.txt other_requirements.txt
+RUN pip3 install --upgrade pip && pip3 install -r other_requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+ADD . .
+RUN python3 setup.py install
+
+WORKDIR /root
+ADD example example
\ No newline at end of file
diff --git a/examples/BMTrain/LICENSE b/examples/BMTrain/LICENSE
new file mode 100644
index 00000000..7ad7f39e
--- /dev/null
+++ b/examples/BMTrain/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2022 OpenBMB
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/examples/BMTrain/MANIFEST.in b/examples/BMTrain/MANIFEST.in
new file mode 100644
index 00000000..a6f97fa4
--- /dev/null
+++ b/examples/BMTrain/MANIFEST.in
@@ -0,0 +1,4 @@
+graft csrc
+include CMakeLists.txt
+graft cmake
+
diff --git a/examples/BMTrain/README-ZH.md b/examples/BMTrain/README-ZH.md
new file mode 100644
index 00000000..d36953a2
--- /dev/null
+++ b/examples/BMTrain/README-ZH.md
@@ -0,0 +1,369 @@
+
+
+## 最新动态
+- 2022/06/14 **BMTrain** [0.1.7](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.7) 发布。支持了ZeRO-2优化!
+- 2022/03/30 **BMTrain** [0.1.2](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.2) 发布。适配了[OpenPrompt](https://github.com/thunlp/OpenPrompt)和 [OpenDelta](https://github.com/thunlp/OpenDelta)工具包。
+- 2022/03/16 **BMTrain** [0.1.1](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.1) 公开发布了第一个稳定版本,修复了 beta 版本中的一些问题。
+- 2022/02/11 **BMTrain** [0.0.15](https://github.com/OpenBMB/BMTrain/releases/tag/0.0.15) 公开发布了第一个 beta 版本。
+
+
+
+## 总览
+
+BMTrain 是一个高效的大模型训练工具包,可以用于训练数百亿参数的大模型。BMTrain 可以在分布式训练模型的同时,能够保持代码的简洁性。
+
+
+
+## 文档
+我们的[文档](https://bmtrain.readthedocs.io/en/latest/index.html)提供了关于工具包的更多信息。
+
+
+
+## 安装
+
+- 用 pip 安装(推荐): ``pip install bmtrain``
+
+- 从源代码安装: 下载工具包,然后运行 ``pip install .`` (setup.py的安装方式将会在未来被setuptools弃用)
+
+安装 BMTrain 可能需要花费数分钟的时间,因为在安装时需要编译 c/cuda 源代码。
+我们推荐直接在训练环境中编译 BMTrain,以避免不同环境带来的潜在问题。
+
+
+
+
+## 使用说明
+
+### 步骤 1: 启用 BMTrain
+
+首先,你需要在代码开头初始化 BMTrain。正如在使用 PyTorch 的分布式训练模块需要在代码开头使用 **init_process_group** 一样,使用 BMTrain 需要在代码开头使用 **init_distributed**。
+
+```python
+import bmtrain as bmt
+bmt.init_distributed(
+ seed=0,
+ zero_level=3, # 目前支持2和3
+ # ...
+)
+```
+
+**注意:** 使用 BMTrain 时请不要使用 PyTorch 自带的 `distributed` 模块,包括 `torch.distributed.init_process_group` 以及相关通信函数。
+
+### 步骤 2: 使用 ZeRO 优化
+
+使用ZeRO优化需要对模型代码进行简单替换:
+
+* `torch.nn.Module` -> `bmtrain.DistributedModule`
+* `torch.nn.Parameter` -> `bmtrain.DistributedParameter`
+
+并在 transformer 模块上使用 `bmtrain.CheckpointBlock`。
+
+下面是一个例子:
+
+**原始代码**
+
+```python
+import torch
+class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ SomeTransformerBlock(),
+ SomeTransformerBlock(),
+ SomeTransformerBlock()
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**替换后代码**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule): # 修改这里
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024)) # 修改这里
+ self.module_list = torch.nn.ModuleList([
+ bmt.CheckpointBlock(SomeTransformerBlock()), # 修改这里
+ bmt.CheckpointBlock(SomeTransformerBlock()), # 修改这里
+ bmt.CheckpointBlock(SomeTransformerBlock()) # 修改这里
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+### 步骤 3: 通信优化
+
+为了进一步缩短通信额外开销,将通信与运算时间重叠,可以使用 `TransformerBlockList` 来进一步优化。
+
+在使用时需要对代码进行简单替换:
+
+* `torch.nn.ModuleList` -> `bmtrain.TransformerBlockList`
+* `for module in self.module_list: x = module(x, ...)` -> `x = self.module_list(x, ...)`
+
+**原始代码**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**替换后代码**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = bmt.TransformerBlockList([ # 修改这里
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ x = self.module_list(x, 1, 2, 3) # 修改这里
+ return x
+
+```
+
+### 步骤 4: 运行分布式训练代码
+
+BMTrain 使用 PyTorch 原生的分布式训练启动器,你可以根据 PyTorch 版本选择下列命令中的一个。
+
+* `${MASTER_ADDR}` 为主节点的 IP 地址
+* `${MASTER_PORT}` 为主节点的端口
+* `${NNODES}` 为节点数量
+* `${GPU_PER_NODE}` 为每个节点的 GPU 数量
+* `${NODE_RANK}` 为本节点的 rank
+
+#### torch.distributed.launch
+```shell
+$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py
+```
+
+#### torchrun
+
+```shell
+$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py
+```
+
+更多信息请参考 PyTorch [官方文档](https://pytorch.org/docs/stable/distributed.html#launch-utility)。
+
+## 样例
+
+我们提供了一个使用 BMTrain 训练 GPT-2 的[样例](https://github.com/OpenBMB/BMTrain/tree/main/example)。
+代码主要包含以下几个部分。
+
+### 第 1 部分: 模型定义
+
+```
+├── layers
+│ ├── attention.py
+│ ├── embedding.py
+│ ├── feedforward.py
+│ ├── __init__.py
+│ ├── layernorm.py
+│ └── linear.py
+└── models
+ ├── gpt.py
+ └── __init__.py
+```
+
+上面是代码的目录结构。
+
+我们定义了 GPT-2 需要的所有模型层,并使用 BMTrain 的 `DistributedModule` 和 `DistributedParameter` 来启用 ZeRO 优化。
+
+### 第 2 部分: 初始化 BMTrain
+
+```python
+bmtrain.init_distributed(seed=0)
+
+model = GPT(
+ num_layers=8,
+ vocab_size=10240,
+ dim_model=2560,
+ dim_head=80,
+ num_heads=32,
+ dim_ff=8192,
+ max_distance=1024,
+ bias=True,
+ dtype=torch.half
+)
+
+bmtrain.init_parameters(model) # 或者使用`bmtrain.load`加载checkpoint
+
+# ... 其他初始化(例如数据集) ...
+```
+
+`bmtrain.init_distributed(seed=0)` 用于初始化分布式训练环境,并设置随机数种子便于复现。
+
+`bmtrain.init_parameters(model)` 用于初始化模型的分布式参数。
+
+### 第 3 部分: 初始化优化器和学习率调整策略
+
+```python
+loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100)
+optimizer = bmtrain.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2)
+lr_scheduler = bmtrain.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0)
+```
+
+BMTrain 支持**所有** PyTorch 原生的优化器和损失函数,同时你也可以使用 BMTrain 提供的融合(fused)优化器用于混合精度训练。
+
+此外,在 `bmtrain.lr_scheduler` 中 BMTrain 也提供了常见的学习率调整策略。
+
+### 第 4 部分: 训练
+
+```python
+# 新建优化器管理器实例
+optim_manager = bmtrain.optim.OptimManager(loss_scale=1024)
+# 将所有的 optimzer 及(可选)其对应的 lr_scheduler 收入优化器管理器管理。
+optim_manager.add_optimizer(optimizer, lr_scheduler)
+# 可以再次调用 add_optimizer 加入其他优化器
+
+for iteration in range(1000):
+ # ... 为每个rank加载数据 ...
+
+ # 前向传播并计算梯度
+ pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
+ logits = model(
+ enc_input,
+ pos,
+ pos < enc_length[:, None]
+ )
+ batch, seq_len, vocab_out_size = logits.size()
+
+ loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
+
+ global_loss = bmtrain.sum_loss(loss).item() # 聚合所有rank上的损失, 仅用于输出训练日志
+
+ # 梯度清零
+ optim_manager.zero_grad() # 为每个 optimizer 调用 zero_grad
+
+ # 损失缩放和反向传播
+ optim_manager.backward(loss)
+
+ # 梯度裁剪
+ grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0)
+
+ # 更新参数
+ optim_manager.step()
+
+ # ... 保存checkpoint、打印日志 ...
+```
+
+这部分代码略有些长,但写起来就像常见的训练代码一样,你不需要为分布式训练调整太多的代码。
+
+你可以根据代码中的注释来了解各部分代码的作用。
+
+唯一需要说明的是 `optim_manager`。在使用 BMTrain 后,优化器的部分相关操作需要有一些细节上的调整。我们在 `optim_manager` 帮你实现了这些细节, 你只需要通过 `add_optimizer` 将优化器和学习率调整策略收入 `optim_manager` 管理,并由 `optim_manger` 代为执行 `zero_grad()`, `backward()`, `clip_grad_norm()` 和 `step()` 等操作。
+
+如果你没有使用混合精度训练,你可以不用损失缩放,只需要将 `OptimManger(loss_scale=None)` 构造函数中 `loss_scale` 置为 None 即可, 这也是 `OptimManager` 的默认构造参数。
+
+如果你使用了混合精度训练,**损失缩放**是混合精度训练中的一项常用技术,我们在 `optim_manager.backward(loss)` 帮你对 `loss` 进行了放缩,用于避免梯度下溢。只需要将 `OptimManger` 构造函数中 `loss_scale` 置为一个浮点数即可。 `loss_scale` 会在训练过程中根据梯度进行自适应的调整。
+
+
+
+## 性能
+
+我们训练了一个有130亿参数的 GPT-2 模型,使用了4台服务器,每台服务器有8张V100显卡。我们测试了训练过程中每个GPU的吞吐量(每个GPU每秒处理的样本数),结果见下表。
+
+模型结构:
+* 40层
+* 128个注意力头
+* 5120的隐藏层维数
+* 512的序列长度
+
+
+| batch size | 8 | 16 | 24 | 32 |
+|-------------|-------|-------|:------|:------|
+| BMTrain | 24.15 | 26.94 | 29.42 | 28.28 |
+| ZeRO3(mp=1) | 14.88 | 21.69 | 24.38 | - |
+| ZeRO3(mp=4) | 15.51 | - | - | - |
+| ZeRO3(mp=8) | 15.51 | - | - | - |
+| ZeRO2(mp=1) | - | - | - | - |
+| ZeRO2(mp=4) | 22.85 | - | - | - |
+| ZeRO2(mp=8) | 21.33 | - | - | - |
+
+**ZeROa(mp=b)** 表示 DeepSpeed + Megatron ZeRO stage a 和 model parallelism = b。
+
+表格中的 **-** 表示超出显存。
+
+## 模型支持
+
+我们已经将大多数常见的 NLP 模型移植到了 BMTrain 中。你可以在 [ModelCenter](https://github.com/OpenBMB/ModelCenter) 项目中找到支持模型的列表。
+
+## 开源社区
+欢迎贡献者参照我们的[贡献指南](https://github.com/OpenBMB/BMTrain/blob/master/CONTRIBUTING.md)贡献相关代码。
+
+您也可以在其他平台与我们沟通交流:
+- QQ群: 735930538
+- 官方网站: https://www.openbmb.org
+- 微博: http://weibo.cn/OpenBMB
+- Twitter: https://twitter.com/OpenBMB
+
+## 开源许可
+
+该工具包使用[Apache 2.0](https://github.com/OpenBMB/BMTrain/blob/main/LICENSE)开源许可证。
+
+## 其他说明
+
+`BMTrain` 工具包对 PyTorch 进行了底层修改,如果你的程序输出了意料之外的结果,可以在 issue 中提交相关信息。
diff --git a/examples/BMTrain/README.md b/examples/BMTrain/README.md
new file mode 100644
index 00000000..134929f5
--- /dev/null
+++ b/examples/BMTrain/README.md
@@ -0,0 +1,375 @@
+
+
+## What's New
+- 2024/02/26 **BMTrain** [1.0.0](https://github.com/OpenBMB/BMTrain/releases/tag/v1.0.0) released. Code refactoring and Tensor parallel support. See the detail in [update log](docs/UPDATE_1.0.0.md)
+- 2023/08/17 **BMTrain** [0.2.3](https://github.com/OpenBMB/BMTrain/releases/tag/v0.2.3) released. See the [update log](docs/UPDATE_0.2.3.md).
+- 2022/12/15 **BMTrain** [0.2.0](https://github.com/OpenBMB/BMTrain/releases/tag/0.2.0) released. See the [update log](docs/UPDATE_0.2.0.md).
+- 2022/06/14 **BMTrain** [0.1.7](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.7) released. ZeRO-2 optimization is supported!
+- 2022/03/30 **BMTrain** [0.1.2](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.2) released. Adapted to [OpenPrompt](https://github.com/thunlp/OpenPrompt) and [OpenDelta](https://github.com/thunlp/OpenDelta).
+- 2022/03/16 **BMTrain** [0.1.1](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.1) released, the first public stable version, fixing many bugs found in the beta version.
+- 2022/02/11 **BMTrain** [0.0.15](https://github.com/OpenBMB/BMTrain/releases/tag/0.0.15) released, the first public beta version.
+
+
+
+## Overview
+
+BMTrain is an efficient large model training toolkit that can be used to train large models with tens of billions of parameters. It can train models in a distributed manner while keeping the code as simple as stand-alone training.
+
+
+
+## Documentation
+Our [documentation](https://bmtrain.readthedocs.io/en/latest/index.html) provides more information about the package.
+
+
+
+## Installation
+
+- From pip (recommended): ``pip install bmtrain``
+
+- From source code: download the package and run ``pip install .``
+
+Installing BMTrain may take several minutes, as it compiles the C/CUDA source code during installation.
+We recommend compiling BMTrain directly in the training environment to avoid potential problems caused by environment differences.
+
+
+
+## Usage
+
+### Step 1: Initialize BMTrain
+
+Before you can use BMTrain, you need to initialize it at the beginning of your code. Just as PyTorch's distributed module requires **init_process_group** at the start of the code, BMTrain requires **init_distributed**.
+
+```python
+import bmtrain as bmt
+bmt.init_distributed(
+ seed=0,
+ # ...
+)
+```
+
+**NOTE:** Do not use PyTorch's distributed module and its associated communication functions when using BMTrain.
+
+### Step 2: Enable ZeRO Optimization
+
+To enable ZeRO optimization, you need to make some simple replacements to the original model's code.
+
+* `torch.nn.Module` -> `bmtrain.DistributedModule`
+* `torch.nn.Parameter` -> `bmtrain.DistributedParameter`
+
+And wrap the transformer blocks with `bmtrain.Block`.
+
+Here is an example.
+
+**Original**
+
+```python
+import torch
+class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ SomeTransformerBlock(),
+ SomeTransformerBlock(),
+ SomeTransformerBlock()
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**Replaced**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule): # changed here
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024)) # changed here
+ self.module_list = torch.nn.ModuleList([
+ bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now
+ bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now
+ bmt.Block(SomeTransformerBlock(), zero_level=3) # changed here, support 2 and 3 now
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+### Step 3: Enable Communication Optimization
+
+
+To further reduce communication overhead and overlap communication with computation, `TransformerBlockList` can be used for optimization.
+
+You can enable them by making the following substitutions to the code:
+
+* `torch.nn.ModuleList` -> `bmtrain.TransformerBlockList`
+* `for module in self.module_list: x = module(x, ...)` -> `x = self.module_list(x, ...)`
+
+**Original**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**Replaced**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = bmt.TransformerBlockList([ # changed here
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ x = self.module_list(x, 1, 2, 3) # changed here
+ return x
+
+```
+
+### Step 4: Launch Distributed Training
+
+BMTrain uses the same launch command as the distributed module of PyTorch.
+
+You can choose one of them depending on your version of PyTorch.
+
+* `${MASTER_ADDR}` means the IP address of the master node.
+* `${MASTER_PORT}` means the port of the master node.
+* `${NNODES}` means the total number of nodes.
+* `${GPU_PER_NODE}` means the number of GPUs per node.
+* `${NODE_RANK}` means the rank of this node.
+
+#### torch.distributed.launch
+```shell
+$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py
+```
+
+#### torchrun
+
+```shell
+$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py
+```
+
+
+For more information, please refer to the [documentation](https://pytorch.org/docs/stable/distributed.html#launch-utility).
+
+## Example
+
+We provide an [example](https://github.com/OpenBMB/BMTrain/tree/main/example) of training GPT-2 based on BMTrain.
+The code mainly consists of the following parts.
+
+### Part 1: Model Definition
+
+```
+├── layers
+│ ├── attention.py
+│ ├── embedding.py
+│ ├── feedforward.py
+│ ├── __init__.py
+│ ├── layernorm.py
+│ └── linear.py
+└── models
+ ├── gpt.py
+ └── __init__.py
+```
+
+Above is the directory structure of the model-definition part of the example.
+
+We defined all the layers needed in GPT-2 and used BMTrain's `DistributedModule` and `DistributedParameter` to enable ZeRO optimization.
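+
+As a rough sketch (illustrative only; the class name, shapes, and initialization details here are not the exact code in `layers/linear.py`), a layer in this style looks like:
+
+```python
+import torch
+import torch.nn.functional as F
+import bmtrain as bmt
+
+class Linear(bmt.DistributedModule):
+    def __init__(self, dim_in, dim_out, dtype=torch.half):
+        super().__init__()
+        # DistributedParameter enables ZeRO-style sharding of these tensors across ranks;
+        # bmt.init_parameters(model) initializes them after construction
+        self.weight = bmt.DistributedParameter(torch.empty(dim_out, dim_in, dtype=dtype))
+        self.bias = bmt.DistributedParameter(torch.empty(dim_out, dtype=dtype))
+
+    def forward(self, x):
+        # used like ordinary torch parameters in the forward pass
+        return F.linear(x, self.weight, self.bias)
+```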
+
+### Part 2: BMTrain Initialization
+
+```python
+bmtrain.init_distributed(seed=0)
+
+model = GPT(
+ num_layers=8,
+ vocab_size=10240,
+ dim_model=2560,
+ dim_head=80,
+ num_heads=32,
+ dim_ff=8192,
+ max_distance=1024,
+ bias=True,
+ dtype=torch.half
+)
+
+bmtrain.init_parameters(model) # or loading checkpoint use `bmtrain.load`
+
+# ... other initialization (dataset) ...
+```
+
+`bmtrain.init_distributed(seed=0)` is used to initialize the distributed training environment and set the random seed for reproducibility.
+
+`bmtrain.init_parameters(model)` is used to initialize the distributed parameters of the model.
+
+### Part 3: Initialization of the Optimizer and LR Scheduler
+
+```python
+loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100)
+optimizer = bmtrain.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2)
+lr_scheduler = bmtrain.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0)
+```
+
+BMTrain supports *all* the PyTorch native optimizers and loss functions, and you can also use the fused optimizer provided by BMTrain for mixed-precision training.
+
+In addition, BMTrain also provides the common LRScheduler in the `bmtrain.lr_scheduler` module.
+
+### Part 4: Training Loop
+
+```python
+# create a new instance of optimizer manager
+optim_manager = bmtrain.optim.OptimManager(loss_scale=1024)
+# let optim_manager handle all the optimizer and (optional) their corresponding lr_scheduler
+optim_manager.add_optimizer(optimizer, lr_scheduler)
+# add_optimizer can be called multiple times to add other optimizers.
+
+for iteration in range(1000):
+ # ... load data for each rank ...
+
+ # forward pass and calculate loss
+ pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
+ logits = model(
+ enc_input,
+ pos,
+ pos < enc_length[:, None]
+ )
+ batch, seq_len, vocab_out_size = logits.size()
+
+ loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
+
+ global_loss = bmtrain.sum_loss(loss).item() # sum the loss across all ranks. This is only used for the training log
+
+ # zero grad
+ optim_manager.zero_grad() # calling zero_grad for each optimizer
+
+ # loss scale and backward
+ optim_manager.backward(loss)
+
+ # clip grad norm
+ grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0)
+
+ # optimizer step
+ optim_manager.step()
+
+ # ... save checkpoint or print logs ...
+```
+
+The training loop part is slightly longer, but just like a normal training loop, you don't need to adapt much for distributed training.
+
+You can follow the comments in the code to get an idea of what each section of code is doing.
+
+The only additional note concerns the optimizer. After switching to BMTrain, some details of the optimizer workflow need to be adjusted. We have implemented all of those details in `optim_manager`: simply register every optimizer (and its optional lr_scheduler) with `add_optimizer`, and let `optim_manager` perform `zero_grad()`, `backward()`, `clip_grad_norm()` and `step()` on your behalf.
+
+If you are not using mixed-precision training, you can train without loss scaling: just leave `loss_scale` as None in the `__init__` of `OptimManager` (i.e. `OptimManager(loss_scale=None)`), which is also the default.
+
+If you are using mixed-precision training, *loss scaling* is a technique widely used to prevent gradient underflow. `optim_manager.backward(loss)` scales the `loss` before the backward pass; to enable it, set `loss_scale` to a floating-point number in the `__init__` of `OptimManager`. The `loss_scale` is adjusted adaptively based on the gradients during training.
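+
+As a minimal illustration (using only the constructor shown in Part 4 above):
+
+```python
+import bmtrain as bmt
+
+# without mixed precision: no loss scaling (the default)
+optim_manager = bmt.optim.OptimManager(loss_scale=None)
+
+# with mixed precision: pass an initial scale; it is adjusted adaptively during training
+optim_manager = bmt.optim.OptimManager(loss_scale=1024)
+```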
+
+
+
+## Performance
+
+We trained a GPT-2 model with 13B parameters using 4 servers with 8 V100s on each server, and measured the throughput of each GPU during the training process (samples per GPU per second).
+
+Model structure:
+* 40 layers
+* 128 attention heads
+* 5120 hidden dimension
+* 512 sequence length
+
+
+| batch size | 8 | 16 | 24 | 32 |
+|-------------|-------|-------|-------|-------|
+| BMTrain | 24.15 | 26.94 | 29.42 | 28.28 |
+| ZeRO3(mp=1) | 14.88 | 21.69 | 24.38 | - |
+| ZeRO3(mp=4) | 15.51 | - | - | - |
+| ZeRO3(mp=8) | 15.51 | - | - | - |
+| ZeRO2(mp=1) | - | - | - | - |
+| ZeRO2(mp=4) | 22.85 | - | - | - |
+| ZeRO2(mp=8) | 21.33 | - | - | - |
+
+**ZeROa(mp=b)** means DeepSpeed + Megatron with ZeRO stage *a* and model parallelism *b*.
+
+**-** in the table means out of memory.
+
+## Supported Models
+
+We have migrated most of the common NLP models to BMTrain. You can find the list of supported models in the [ModelCenter](https://github.com/OpenBMB/ModelCenter) repo.
+
+## Community
+We welcome everyone to contribute code by following our [contributing guidelines](https://github.com/OpenBMB/BMTrain/blob/master/CONTRIBUTING.md).
+
+You can also find us on other platforms:
+- QQ Group: 735930538
+- Website: https://www.openbmb.org
+- Weibo: http://weibo.cn/OpenBMB
+- Twitter: https://twitter.com/OpenBMB
+
+## License
+The package is released under the [Apache 2.0](https://github.com/OpenBMB/BMTrain/blob/master/LICENSE) License.
+
+## Other Notes
+
+`BMTrain` makes low-level changes to PyTorch, so if your program produces unexpected results, please report them in an issue.
+
diff --git a/examples/BMTrain/Release.txt b/examples/BMTrain/Release.txt
new file mode 100644
index 00000000..7c8a41be
--- /dev/null
+++ b/examples/BMTrain/Release.txt
@@ -0,0 +1,9 @@
+## What's Changed
+* Refactor the ZeRO, checkpoint, pipeline and communication implementations using PyTorch's hook mechanism by @zkh2016 in #128 #159
+* Add BF16 support by @Achazwl in #136
+* Tensor parallel implementation by @Achazwl @zkh2016 @MayDomine in #153
+* Async save state_dict by @zkh2016 in #171
+* `AdamOffloadOptimizer` can save the whole gathered state by @MayDomine in #184
+* New tests for the new version of BMTrain by @Achazwl @JerryYin777 @MayDomine
+**Full Changelog**: https://github.com/OpenBMB/BMTrain/compare/0.2.3...1.0.0
+
diff --git a/examples/BMTrain/bmtrain/__init__.py b/examples/BMTrain/bmtrain/__init__.py
new file mode 100644
index 00000000..f4ac3642
--- /dev/null
+++ b/examples/BMTrain/bmtrain/__init__.py
@@ -0,0 +1,26 @@
+from .utils import print_block, print_dict, print_rank, see_memory, load_nccl_pypi
+try:
+ from . import nccl
+except Exception:
+ load_nccl_pypi()
+from .global_var import config, world_size, rank
+from .init import init_distributed
+
+from .parameter import DistributedParameter, ParameterInitializer
+from .layer import DistributedModule
+from .param_init import init_parameters, grouped_parameters
+from .synchronize import synchronize, sum_loss, wait_loader, gather_result
+from .block_layer import Block, TransformerBlockList
+from .wrapper import BMTrainModelWrapper
+from .pipe_layer import PipelineTransformerBlockList
+from . import debug
+from .store import save, load
+
+from . import loss
+from . import distributed
+from . import nn
+from . import optim
+from . import inspect
+from . import lr_scheduler
+
+CheckpointBlock = Block
diff --git a/examples/BMTrain/bmtrain/benchmark/__init__.py b/examples/BMTrain/bmtrain/benchmark/__init__.py
new file mode 100644
index 00000000..571d621f
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/__init__.py
@@ -0,0 +1,3 @@
+from .all_gather import all_gather
+from .reduce_scatter import reduce_scatter
+from .send_recv import send_recv
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/benchmark/all_gather.py b/examples/BMTrain/bmtrain/benchmark/all_gather.py
new file mode 100644
index 00000000..b2f2ee7c
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/all_gather.py
@@ -0,0 +1,28 @@
+from .. import nccl
+from .shape import SHAPES
+from ..global_var import config
+from ..utils import round_up, print_rank
+from .utils import format_size
+import torch
+
+def all_gather():
+ current_stream = torch.cuda.current_stream()
+ for shape in SHAPES:
+ global_size = round_up(shape, config['world_size'] * 2)
+ partition_size = global_size // config['world_size']
+
+ partition_tensor = torch.empty( partition_size // 2, dtype=torch.half, device="cuda" )
+ global_tensor = torch.empty( global_size // 2, dtype=torch.half, device="cuda" )
+
+ start_evt = torch.cuda.Event(enable_timing=True)
+ end_evt = torch.cuda.Event(enable_timing=True)
+
+ current_stream.record_event(start_evt)
+ nccl.allGather(partition_tensor.storage(), global_tensor.storage(), config['comm'])
+ current_stream.record_event(end_evt)
+ current_stream.synchronize()
+ time_usage = start_evt.elapsed_time(end_evt)
+
+ bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage
+ print_rank("All gather:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw))
+
diff --git a/examples/BMTrain/bmtrain/benchmark/reduce_scatter.py b/examples/BMTrain/bmtrain/benchmark/reduce_scatter.py
new file mode 100644
index 00000000..75733556
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/reduce_scatter.py
@@ -0,0 +1,28 @@
+from .. import nccl
+from .shape import SHAPES
+from ..global_var import config
+from ..utils import round_up, print_rank
+from .utils import format_size
+import torch
+
+def reduce_scatter():
+ current_stream = torch.cuda.current_stream()
+ for shape in SHAPES:
+ global_size = round_up(shape, config['world_size'])
+ partition_size = global_size // config['world_size']
+
+ partition_tensor = torch.empty( partition_size // 2, dtype=torch.half, device="cuda" )
+ global_tensor = torch.empty( global_size // 2, dtype=torch.half, device="cuda" )
+
+ start_evt = torch.cuda.Event(enable_timing=True)
+ end_evt = torch.cuda.Event(enable_timing=True)
+
+ current_stream.record_event(start_evt)
+ nccl.reduceScatter(global_tensor.storage(), partition_tensor.storage(), 'avg', config['comm'])
+ current_stream.record_event(end_evt)
+ current_stream.synchronize()
+ time_usage = start_evt.elapsed_time(end_evt)
+
+ bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage
+ print_rank("Reduce Scatter:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw))
+
diff --git a/examples/BMTrain/bmtrain/benchmark/send_recv.py b/examples/BMTrain/bmtrain/benchmark/send_recv.py
new file mode 100644
index 00000000..e3c971e4
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/send_recv.py
@@ -0,0 +1,31 @@
+from .. import nccl
+from .shape import SHAPES
+from ..global_var import config
+from ..utils import print_rank
+from .utils import format_size
+import torch
+def send_recv():
+ current_stream = torch.cuda.current_stream()
+ for shape in SHAPES:
+ send_size = shape
+
+ send_buffer = torch.empty( send_size // 2, dtype=torch.half, device="cuda" )
+ recv_buffer = torch.empty( send_size // 2, dtype=torch.half, device="cuda" )
+
+ start_evt = torch.cuda.Event(enable_timing=True)
+ end_evt = torch.cuda.Event(enable_timing=True)
+
+ current_stream.record_event(start_evt)
+ nccl.groupStart()
+ if config['rank'] in [0,2,4,6]:
+ nccl.send(send_buffer.storage(), config['rank']+1, config['comm'])
+ else:
+ nccl.recv(recv_buffer.storage(), config['rank']-1, config['comm'])
+ nccl.groupEnd()
+ current_stream.record_event(end_evt)
+ current_stream.synchronize()
+ time_usage = start_evt.elapsed_time(end_evt)
+
+ bw = shape / 1024 / 1024 / 1024 * 1000 / time_usage
+ print_rank("Send Recv:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(send_size), time_usage, bw))
+
diff --git a/examples/BMTrain/bmtrain/benchmark/shape.py b/examples/BMTrain/bmtrain/benchmark/shape.py
new file mode 100644
index 00000000..0699e8cd
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/shape.py
@@ -0,0 +1,3 @@
+SHAPES = [
+ (2**i) for i in range(10, 33)
+]
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/benchmark/utils.py b/examples/BMTrain/bmtrain/benchmark/utils.py
new file mode 100644
index 00000000..dbc4a70c
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/utils.py
@@ -0,0 +1,11 @@
+def format_size_(x):
+ if x < 1024:
+ return "{:d}B".format(x)
+ if x < 1024 * 1024:
+ return "{:4.2f}KB".format(x / 1024)
+ if x < 1024 * 1024 * 1024:
+ return "{:4.2f}MB".format(x / 1024 / 1024)
+ return "{:4.2f}GB".format(x / 1024 / 1024 / 1024)
+
+def format_size(x):
+ return "{:.6s}".format(format_size_(x))
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/block_layer.py b/examples/BMTrain/bmtrain/block_layer.py
new file mode 100644
index 00000000..216d77b2
--- /dev/null
+++ b/examples/BMTrain/bmtrain/block_layer.py
@@ -0,0 +1,726 @@
+from typing import Dict, Iterable, Iterator, Union, List
+
+from .utils import round_up, tp_split_tensor
+from .global_var import config
+import torch
+from . import nccl
+from .parameter import DistributedParameter, OpAllGather
+from .zero_context import ZeroContext
+from . import hook_func
+import inspect
+from torch.utils.checkpoint import checkpoint
+
+
+def storage_type_cuda(storage_type):
+ """Convert storage_type to cuda storage_type."""
+ STORAGE_MAP = {
+ torch.FloatStorage: torch.cuda.FloatStorage,
+ torch.DoubleStorage: torch.cuda.DoubleStorage,
+ torch.HalfStorage: torch.cuda.HalfStorage,
+ torch.BFloat16Storage: torch.cuda.BFloat16Storage,
+ torch.CharStorage: torch.cuda.CharStorage,
+ torch.ByteStorage: torch.cuda.ByteStorage,
+ torch.ShortStorage: torch.cuda.ShortStorage,
+ torch.IntStorage: torch.cuda.IntStorage,
+ torch.cuda.FloatStorage: torch.cuda.FloatStorage,
+ torch.cuda.DoubleStorage: torch.cuda.DoubleStorage,
+ torch.cuda.HalfStorage: torch.cuda.HalfStorage,
+ torch.cuda.BFloat16Storage: torch.cuda.BFloat16Storage,
+ torch.cuda.CharStorage: torch.cuda.CharStorage,
+ torch.cuda.ByteStorage: torch.cuda.ByteStorage,
+ torch.cuda.ShortStorage: torch.cuda.ShortStorage,
+ torch.cuda.IntStorage: torch.cuda.IntStorage,
+ }
+ if storage_type not in STORAGE_MAP:
+ raise ValueError("Unknown storage type: {}".format(storage_type))
+ return STORAGE_MAP[storage_type]
+
+
+def _get_param_kw(param: DistributedParameter):
+ """Get DistributedParameter kw name."""
+ type_name = str(param.dtype).split(".")[-1]
+ grad_name = "_grad" if param.requires_grad else "_nograd"
+ group_name = ""
+ if param.group is not None:
+ group_name = "_g_" + param.group
+ return type_name + grad_name + group_name
+
+
+class Block(torch.nn.Module):
+ """A block containing two memory-saving methods of ZeRO and checkpoint.
+ For details please refer to `ZeRO `_ and
+ `Checkpointing `_ .
+
+ Args:
+ inner_module (torch.nn.Module): The module to reduce memory usage. All kinds of modules are supported.
+ use_checkpoint (boolean): use checkpoint or not. Default True.
+ zero_level (int): 2 (ZeRO-2) indicates that optimizer states and gradients are partitioned across the process,
+ 3 (ZeRO-3) means that the parameters are partitioned one the basis of ZeRO-2. Default 3.
+ initialized (bool): initialized parameter storage. Default False.
+ mode (str): the mode shouled be "PIPE" when runing in pipeline mode, otherwise mode="BLOCK". Default "BLOCK"
+
+ Examples:
+ >>> transformer_block = TransformerBlock(...)
+ >>> block = Block(transformer_block)
+ >>> y1, ... = block(x)
+ >>> y2, ... = transformer_block(x)
+ >>> assert torch.allclose(y1, y2)
+ """
+
+ def __init__(
+ self,
+ inner_module: torch.nn.Module,
+ use_checkpoint=True,
+ zero_level=3,
+ initialized=False,
+ mode="BLOCK",
+ ):
+ super().__init__()
+ self._module = inner_module
+ self._inputs = None
+ self._layer_dict = {}
+ self._forward_block_ctx = None
+ self._backward_block_ctx = None
+
+ self._param_info = []
+ self._storage_params: Dict[str, torch.nn.Parameter] = {}
+ self._storage_info = {}
+ self._ready = False
+
+ self._use_checkpoint = use_checkpoint
+ self._is_first_layer = True
+ self._is_last_layer = True
+ self._need_release = True
+ self._next_module = None # save the next module of self
+ self._pre_module = None # save the pre module of self
+ self._mode = mode # BLOCK or PIPE
+ self.all_input_no_grad = False
+ self.all_param_no_grad = False
+ self._zero_level = zero_level
+ if not initialized:
+ self.init_param_storage()
+
+ def reference(self, block):
+ """Make this block be a reference of the input Block."""
+ self._param_info = block._param_info
+ self._storage_params = block._storage_params
+ self._storage_info = block._storage_info
+ self._layer_dict = block._layer_dict
+ self._initialized = True
+ self._need_release = False
+
+ def init_param_storage(self):
+ """Init param storage."""
+ # sort parameters by name
+ ordered_parameters = list(self._module.named_parameters())
+
+ # calc total number of parameters
+ for name, param in ordered_parameters:
+ if not isinstance(param, DistributedParameter):
+ raise ValueError(
+ "All parameters in checkpoint block must be DistributedParameter."
+ )
+
+ storage_type = storage_type_cuda(param.storage_type())
+ kw_name = _get_param_kw(param)
+
+ if kw_name not in self._storage_info:
+ if self._mode == "PIPE" and param._tp_mode:
+ zero_comm = config["pp_tp_zero_comm"]
+ elif self._mode != "PIPE" and param._tp_mode:
+ zero_comm = config["tp_zero_comm"]
+ elif self._mode == "PIPE" and not param._tp_mode:
+ zero_comm = config["pp_zero_comm"]
+ else:
+ zero_comm = config["zero_comm"]
+
+ self._storage_info[kw_name] = {
+ "total": 0,
+ "storage_type": storage_type,
+ "requires_grad": param.requires_grad,
+ "group": param.group,
+ "zero_comm": zero_comm,
+ }
+
+ param_shape = param._original_shape
+
+ self._storage_info[kw_name]["total"] = round_up(
+ self._storage_info[kw_name]["total"] + param_shape.numel(),
+ 512 // param.element_size(),
+ # 512 bytes aligned
+ )
+
+ offsets = {}
+ # initialize storage buffers
+ for kw, val in self._storage_info.items():
+ comm = val["zero_comm"]
+ world_size = nccl.commCount(comm)
+ rank = nccl.commRank(comm)
+ val["world_size"] = world_size
+ partition_size = (
+ round_up(val["total"], val["world_size"]) // val["world_size"]
+ )
+ val["partition_size"] = partition_size
+ val["begin"] = rank * partition_size
+ val["end"] = (rank + 1) * partition_size
+ offsets[kw] = 0
+
+ storage_type = val["storage_type"]
+
+ storage_param_buffer = storage_type(partition_size)
+
+ dtype = storage_param_buffer.dtype
+ device = storage_param_buffer.device
+
+ # bind storage to buffer tensor
+ storage_param = torch.nn.Parameter(
+ torch.tensor([], dtype=dtype, device=device).set_(storage_param_buffer)
+ )
+ if val["requires_grad"]:
+ storage_param.requires_grad_(True)
+ else:
+ storage_param.requires_grad_(False)
+
+ self._storage_params[kw] = storage_param
+
+ # initialize parameters in module
+ for name, param in ordered_parameters:
+ param_shape = param._original_shape
+ kw_name = _get_param_kw(param)
+
+ param_st = offsets[kw_name]
+ offsets[kw_name] += param_shape.numel()
+ param_end = offsets[kw_name]
+ offsets[kw_name] = round_up(offsets[kw_name], 512 // param.element_size())
+
+ self._param_info.append(
+ {
+ "parameter": param,
+ "name": name,
+ "offset": param_st,
+ "size": param_shape.numel(),
+ "shape": param_shape,
+ "kw_name": kw_name,
+ }
+ )
+
+ # copy values to buffer for normal parameter
+ storage_st = self._storage_info[kw_name]["begin"]
+ storage_end = self._storage_info[kw_name]["end"]
+
+ # make parameter contiguous in storage
+ with torch.no_grad():
+ contiguous_param = OpAllGather.apply(param)
+
+ if not (param_st >= storage_end or param_end <= storage_st):
+ # copy offset in parameter storage
+ offset_st = max(storage_st - param_st, 0)
+ offset_end = min(storage_end - param_st, contiguous_param.numel())
+ assert offset_st < offset_end
+
+ # copy to offset in buffer storage
+ to_offset_st = offset_st + param_st - storage_st
+ to_offset_end = offset_end + param_st - storage_st
+
+ # copy to buffer
+ # PyTorch 1.11 changed the API of storage.__getitem__
+ d_dtype = self._storage_params[kw_name].dtype
+ d_device = self._storage_params[kw_name].device
+ param.data = torch.tensor(
+ [], dtype=param.dtype, device=param.device
+ ).set_(
+ self._storage_params[kw_name].storage(),
+ to_offset_st,
+ (to_offset_end - to_offset_st,),
+ )
+ self._param_info[-1]["begin"] = to_offset_st
+ self._param_info[-1]["end"] = (to_offset_end - to_offset_st,)
+ setattr(param, "_start_partition", offset_st)
+ setattr(param, "_end_partition", offset_end)
+ param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
+ contiguous_param.storage(), offset_st, (offset_end - offset_st,)
+ )[:]
+ del contiguous_param
+ else:
+ param.data = torch.tensor([], dtype=param.dtype, device=param.device)
+ setattr(param, "_start_partition", None)
+ setattr(param, "_end_partition", 0)
+ # clear parameter data, but keep the dtype and device
+ setattr(param, "_in_block", True)
+
+ for kw in offsets.keys():
+ assert offsets[kw] == self._storage_info[kw]["total"]
+
+ def set_pre_module(self, pre_module):
+ """Set pre module for current Block."""
+ if pre_module is not None:
+ self._pre_module = pre_module
+ pre_module._next_module = self
+
+ def pre_module(self):
+ """Return pre module of current Block."""
+ return self._pre_module if not self._is_first_layer else None
+
+ def next_module(self):
+ """Return next module of current Block."""
+ return self._next_module if not self._is_last_layer else None
+
+ def release_next_module(self, flag):
+ """Release next module of current Block."""
+ if self.next_module() is not None:
+ self.next_module().release(flag)
+
+ def release(self, flag):
+ """Release cuurent block ctx."""
+ if self._need_release and self._backward_block_ctx is not None:
+ self._backward_block_ctx.exit(flag, True)
+ config["load_stream"].record_event(config["load_event"])
+
+ def pre_hook(self, *args):
+ """Hook function before forward."""
+ grad_tensors = []
+ grad_index = []
+ arg_list = list(args)
+ for i, arg in enumerate(args):
+ if arg is not None and isinstance(arg, torch.Tensor) and arg.requires_grad:
+ grad_tensors.append(arg)
+ grad_index.append(i)
+ grad_tensors = tuple(grad_tensors)
+
+ pre_out = hook_func.PreHookFunc.apply(self, *grad_tensors)
+ for i in range(len(grad_index)):
+ arg_list[grad_index[i]] = pre_out[i]
+
+ if self._mode != "PIPE" and len(grad_tensors) == 0:
+ self.all_param_no_grad = True
+ for param in self._param_info:
+ if param["parameter"].requires_grad:
+ self.all_param_no_grad = False
+ break
+ self.all_input_no_grad = True
+ else:
+ self.all_input_no_grad = False
+ return arg_list
+
+ def post_hook(self, out):
+ """Hook function after forward."""
+ tuple_out = (out,) if isinstance(out, torch.Tensor) else out
+ post_out = hook_func.PostHookFunc.apply(self, *tuple_out)
+ if isinstance(out, torch.Tensor) and isinstance(post_out, tuple):
+ return post_out[0]
+ post_out = tuple(post_out)
+ return post_out
+
+ def forward(self, *args, **kwargs):
+ signature = inspect.signature(self._module.forward)
+ bound_args = signature.bind(*args, **kwargs)
+ args = bound_args.args
+ arg_list = self.pre_hook(*args)
+
+
+ if self.all_input_no_grad and not self.all_param_no_grad:
+ placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
+ return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list)
+
+ if self._use_checkpoint:
+ out = checkpoint(
+ self._module, *arg_list, use_reentrant=not self.all_input_no_grad
+ )
+ else:
+ out = self._module(*arg_list)
+
+ return self.post_hook(out)
+
+ def __getattr__(self, name: str):
+ if name == "_module":
+ return self._module
+ return getattr(self._module, name)
+
+ def __setattr__(self, name, value):
+ object.__setattr__(self, name, value)
+
+ def __getattribute__(self, name: str):
+ if name == "_parameters":
+ return self._module._parameters
+ return super().__getattribute__(name)
+
+ def __delattr__(self, name):
+ object.__delattr__(self, name)
+
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
+ raise RuntimeError("._save_to_state_dict() of Block should not be called")
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ # gather here
+ with torch.no_grad():
+ with ZeroContext(self):
+ return self._module.state_dict(
+ destination=destination, prefix=prefix, keep_vars=keep_vars
+ )
+
+ def _load_from_state_dict(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ):
+ all_keys = []
+ for it in self._param_info:
+ key = prefix + it["name"]
+ all_keys.append(key)
+ if key in state_dict:
+ # load here
+ input_param = state_dict[key]
+ param = it["parameter"]
+ tp_mode = param._tp_mode
+ if input_param.__class__.__name__ == "DistributedTensorWrapper":
+ input_param = input_param.broadcast()
+
+ verify_shape = torch.Size(
+ it["shape"] if not tp_mode else param._tp_original_shape
+ )
+ if input_param.shape != verify_shape:
+ error_msgs.append(
+ "size mismatch for {}: copying a param with shape {} from checkpoint, "
+ "the shape in current model is {}.".format(
+ key, input_param.shape, verify_shape
+ )
+ )
+ continue
+
+ param_st = it["offset"]
+ param_end = it["offset"] + it["size"]
+ kw_name = it["kw_name"]
+
+ # not in this partition
+ storage_st = self._storage_info[kw_name]["begin"]
+ storage_end = self._storage_info[kw_name]["end"]
+ if param_st >= storage_end:
+ continue
+ if param_end <= storage_st:
+ continue
+
+ # copy to buffer
+ verify_size = verify_shape.numel()
+ assert input_param.numel() == verify_size
+
+ contiguous_param = (
+ input_param.to(it["parameter"].dtype).cuda().contiguous()
+ )
+
+ tp_split_dim = param._tp_split_dim
+ if tp_mode and tp_split_dim >= 0:
+ contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim)
+
+ offset_st = max(storage_st - param_st, 0)
+ offset_end = min(storage_end - param_st, contiguous_param.numel())
+ assert offset_st < offset_end
+
+ to_offset_st = offset_st + param_st - storage_st
+ to_offset_end = offset_end + param_st - storage_st
+
+ # copy to buffer
+ # PyTorch 1.11 changed the API of storage.__getitem__
+ d_dtype = self._storage_params[kw_name].dtype
+ d_device = self._storage_params[kw_name].device
+ torch.tensor([], dtype=d_dtype, device=d_device).set_(
+ self._storage_params[kw_name].storage(),
+ to_offset_st,
+ (to_offset_end - to_offset_st,),
+ )[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
+ contiguous_param.storage(), offset_st, (offset_end - offset_st,)
+ )[
+ :
+ ]
+ del contiguous_param
+ elif strict:
+ missing_keys.append(key)
+
+ for name, param in self.named_parameters():
+ if isinstance(param, DistributedParameter) and not param._in_block:
+ key = prefix + name
+ all_keys.append(key)
+ if key in state_dict:
+ input_param = state_dict[key]
+ is_param_lazy = torch.nn.parameter.is_lazy(param)
+ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+ if (
+ not is_param_lazy
+ and len(param.shape) == 0
+ and len(input_param.shape) == 1
+ ):
+ input_param = input_param[0]
+
+ if (
+ not is_param_lazy
+ and not isinstance(param, DistributedParameter)
+ and input_param.shape != param.shape
+ ):
+ # local shape should match the one in checkpoint
+ error_msgs.append(
+ "size mismatch for {}: copying a param with shape {} from checkpoint, "
+ "the shape in current model is {}.".format(
+ key, input_param.shape, param.shape
+ )
+ )
+ continue
+ if (
+ not is_param_lazy
+ and isinstance(param, DistributedParameter)
+ and input_param.shape != param._original_shape
+ ):
+ error_msgs.append(
+ "size mismatch for {}: copying a param with shape {} from checkpoint, "
+ "the shape in current model is {}.".format(
+ key, input_param.shape, param.shape
+ )
+ )
+ try:
+ with torch.no_grad():
+ param._copy_data(input_param)
+ except Exception as ex:
+ error_msgs.append(
+ 'While copying the parameter named "{}", '
+ "whose dimensions in the model are {} and "
+ "whose dimensions in the checkpoint are {}, "
+ "an exception occurred : {}.".format(
+ key, param.size(), input_param.size(), ex.args
+ )
+ )
+ elif strict:
+ missing_keys.append(key)
+
+ if strict:
+ all_keys = set(all_keys)
+ for key in state_dict.keys():
+ if key.startswith(prefix) and key not in all_keys:
+ unexpected_keys.append(key)
+
+ def grouped_parameters(self):
+ """
+ Yield grouped parameters from the storage params.
+ """
+ ret = {}
+ for kw, val in self._storage_info.items():
+ if val["group"] not in ret:
+ ret[val["group"]] = []
+ ret[val["group"]].append(self._storage_params[kw])
+ for kw, val in ret.items():
+ yield kw, val
+
+ def init_parameters(self):
+ """
+ Initialize distributed parameters in this block.
+ """
+ for it in self._param_info:
+ param = it["parameter"]
+ if (
+ isinstance(param, DistributedParameter)
+ and param._init_method is not None
+ ):
+ # initialize here
+ tmp_tensor = torch.empty(
+ param._tp_original_shape, device=param.device, dtype=param.dtype
+ )
+ param._init_method(tmp_tensor)
+ param_st = it["offset"]
+ param_end = it["offset"] + it["size"]
+ kw_name = it["kw_name"]
+
+ # not in this partition
+ storage_st = self._storage_info[kw_name]["begin"]
+ storage_end = self._storage_info[kw_name]["end"]
+ if param_st >= storage_end:
+ continue
+ if param_end <= storage_st:
+ continue
+
+ if param._tp_mode and param._tp_split_dim >= 0:
+ tmp_tensor = tp_split_tensor(tmp_tensor, param._tp_split_dim)
+ # copy to buffer
+ assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel()
+
+ offset_st = max(storage_st - param_st, 0)
+ offset_end = min(storage_end - param_st, tmp_tensor.numel())
+ assert offset_st < offset_end
+
+ # copy to buffer
+ # PyTorch 1.11 changed the API of storage.__getitem__
+ d_dtype = self._storage_params[kw_name].dtype
+ d_device = self._storage_params[kw_name].device
+ param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
+ tmp_tensor.storage(), offset_st, (offset_end - offset_st,)
+ )[:]
+ del tmp_tensor
+
+ def _named_members(self, get_members_fn, prefix="", recurse=True, **kwargs):
+ r"""Helper method for yielding various names + members of modules."""
+
+ # compatibility with torch 2.0
+ if (
+ "remove_duplicate"
+ in inspect.signature(torch.nn.Module._named_members).parameters
+ and "remove_duplicate" not in kwargs
+ ):
+ kwargs["remove_duplicate"] = True
+ return self._module._named_members(get_members_fn, prefix, recurse, **kwargs)
+
+ def named_modules(self, memo=None, prefix: str = "", remove_duplicate: bool = True):
+ r"""Returns an iterator over all modules in the network, yielding
+ both the name of the module as well as the module itself.
+
+ Args:
+ memo: a memo to store the set of modules already added to the result
+ prefix: a prefix that will be added to the name of the module
+ remove_duplicate: whether to remove the duplicated module instances in the result
+ or not
+
+ Yields:
+ (string, Module): Tuple of name and module
+
+ Note:
+ Duplicate modules are returned only once. In the following
+ example, ``l`` will be returned only once.
+
+ Example::
+
+ >>> l = nn.Linear(2, 2)
+ >>> net = nn.Sequential(l, l)
+ >>> for idx, m in enumerate(net.named_modules()):
+ print(idx, '->', m)
+
+ 0 -> ('', Sequential(
+ (0): Linear(in_features=2, out_features=2, bias=True)
+ (1): Linear(in_features=2, out_features=2, bias=True)
+ ))
+ 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
+
+ """
+
+ if memo is None:
+ memo = set()
+ if self not in memo:
+ if remove_duplicate:
+ memo.add(self)
+ yield prefix, self
+ for name, module in self._module._modules.items():
+ if module is None:
+ continue
+ submodule_prefix = prefix + ("." if prefix else "") + name
+ for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
+ yield m
+
+ def named_children(self):
+ return self._module.named_children()
+
+ def train(self, mode: bool = True):
+ self._module.train(mode)
+
+ def eval(self):
+ self._module.eval()
+
+ def __repr__(self):
+ return self._module.__repr__()
+
+
+def _block_wrapper(module, module_dict: dict, mode="BLOCK"):
+ if not isinstance(module, Block):
+ in_block = id(module) in module_dict
+ new_module = Block(module, initialized=in_block, mode=mode)
+ if in_block:
+ new_module.reference(module_dict[id(module)])
+ else:
+ module_dict[id(module)] = new_module
+ else:
+ if mode == "PIPE" and module._mode != "PIPE":
+ assert (
+ False
+ ), 'You must set mode="PIPE" in bmt.Block when using PipelineTransformerBlockList!'
+ if id(module._module) in module_dict:
+ assert False, "Duplicate bmt.Block is not supported in the same block list!"
+ else:
+ new_module = module
+ module_dict[id(module._module)] = new_module
+ return new_module
+
+
+class TransformerBlockList(torch.nn.Module):
+ r"""
+ TransformerBlockList is a list of bmt.Block.
+
+ This is designed to reduce the communication overhead by overlapping computation with the reduce_scatter operation during the backward pass.
+
+ It is similar to `torch.nn.ModuleList`, but differs in how .forward() and .backward() are called.
+
+ Example:
+ >>> module_list = [ ... ]
+ >>> normal_module_list = torch.nn.ModuleList(module_list)
+ >>> transformer_module_list = TransformerBlockList(module_list)
+ >>> # Calling normal module list
+ >>> for layer in normal_module_list:
+ >>> hidden_state = layer.forward(hidden_state, ...)
+ >>> # Calling transformer module list
+ >>> hidden_state = transformer_module_list(hidden_state, ...)
+
+ """
+
+ _modules: Dict[str, Block]
+
+ def __init__(self, modules: Iterable[Block], num_hidden=1) -> None:
+ super().__init__()
+
+ self._modules = {}
+ pre_module = None
+ module_dict = {}
+ for i, module in enumerate(modules):
+ module = _block_wrapper(module, module_dict)
+ module.set_pre_module(pre_module)
+ pre_module = module
+ module._is_first_layer = False
+ module._is_last_layer = False
+ self._modules[str(i)] = module
+ self.add_module(str(i), module)
+
+ self._modules[str(0)]._is_first_layer = True
+ self._modules[str(len(modules) - 1)]._is_last_layer = True
+
+ self.num_hidden = num_hidden
+
+ def __len__(self) -> int:
+ return len(self._modules)
+
+ def __iter__(self) -> Iterator[Block]:
+ return iter(self._modules.values())
+
+ def __getitem__(self, index: Union[int, str]) -> Block:
+ return self._modules[str(index)]
+
+ def forward(self, *args, return_hidden_states=False):
+ self.return_hidden_states = return_hidden_states
+ hidden_states = []
+ for i in range(len(self)):
+ if return_hidden_states:
+ for hidden_state in args[: self.num_hidden]:
+ hidden_states.append(hidden_state)
+ outputs = self._modules[str(i)]._call_impl(*args)
+ if not isinstance(outputs, tuple):
+ outputs = (outputs,)
+ args = outputs + args[self.num_hidden :]
+
+ if return_hidden_states:
+ hidden_states = [
+ torch.stack(hidden_states[i :: self.num_hidden], dim=0)
+ for i in range(self.num_hidden)
+ ]
+
+ if return_hidden_states:
+ return outputs + tuple(hidden_states)
+ else:
+ return (
+ tuple(outputs[: self.num_hidden]) if self.num_hidden > 1 else outputs[0]
+ )
diff --git a/examples/BMTrain/bmtrain/debug.py b/examples/BMTrain/bmtrain/debug.py
new file mode 100644
index 00000000..de392623
--- /dev/null
+++ b/examples/BMTrain/bmtrain/debug.py
@@ -0,0 +1,34 @@
+import torch
+
+DEBUG_VARS = {}
+
+def clear(key=None):
+ global DEBUG_VARS
+ if key is None:
+ DEBUG_VARS = {}
+ else:
+ DEBUG_VARS.pop(key, None)
+
+def set(key, value):
+ global DEBUG_VARS
+ if torch.is_tensor(value):
+ value = value.detach().cpu()
+ DEBUG_VARS[key] = value
+
+def get(key, default=None):
+ global DEBUG_VARS
+ if key in DEBUG_VARS:
+ return DEBUG_VARS[key]
+ return default
+
+def append(key, value):
+ global DEBUG_VARS
+ if key not in DEBUG_VARS:
+ DEBUG_VARS[key] = []
+ DEBUG_VARS[key].append(value)
+
+def extend(key, value):
+ global DEBUG_VARS
+ if key not in DEBUG_VARS:
+ DEBUG_VARS[key] = []
+ DEBUG_VARS[key].extend(value)
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/distributed/__init__.py b/examples/BMTrain/bmtrain/distributed/__init__.py
new file mode 100644
index 00000000..84a4adf8
--- /dev/null
+++ b/examples/BMTrain/bmtrain/distributed/__init__.py
@@ -0,0 +1 @@
+from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations, reduce_scatter
diff --git a/examples/BMTrain/bmtrain/distributed/ops.py b/examples/BMTrain/bmtrain/distributed/ops.py
new file mode 100644
index 00000000..d1b489e2
--- /dev/null
+++ b/examples/BMTrain/bmtrain/distributed/ops.py
@@ -0,0 +1,223 @@
+import torch
+from ..global_var import config
+from ..nccl import allGather as ncclAllGather, recv
+from ..nccl import allReduce as ncclAllReduce
+from ..nccl import broadcast as ncclBroadcast
+from ..nccl import reduceScatter as ncclReduceScatter
+from ..nccl import send as ncclSend
+from ..nccl import recv as ncclRecv
+from ..nccl import commCount,commRank,NCCLCommunicator
+DTYPE_LIST = [
+ torch.float64,
+ torch.float32,
+ torch.float16,
+ torch.int64,
+ torch.int32,
+ torch.int16,
+ torch.int8,
+ torch.bfloat16,
+ torch.bool
+]
+def send_activations(hidden_state, next_rank, comm):
+ send_meta(hidden_state, next_rank, comm)
+ ncclSend(hidden_state.storage(), next_rank, comm)
+
+def recv_activations(prev_rank, comm):
+ dtype, shape = recv_meta(prev_rank, comm)
+ hidden_state = torch.empty(shape, dtype=dtype, device="cuda")
+ ncclRecv(hidden_state.storage(), prev_rank, comm)
+ return hidden_state
+
+def send_meta(x, next_rank, comm):
+ meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int)
+ meta_data[0] = len(x.size())
+ meta_data[1] = DTYPE_LIST.index(x.dtype)
+ meta_data[2:len(x.size())+2] = torch.tensor(x.size(), device="cuda", dtype=torch.int)
+ meta_data = meta_data.contiguous()
+ ncclSend(meta_data.storage(), next_rank, comm)
+
+def recv_meta(prev_rank, comm):
+ meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int)
+ ncclRecv(meta_data.storage(), prev_rank, comm)
+ n_dims = meta_data[0].item()
+ dtype = DTYPE_LIST[meta_data[1].item()]
+ shape = meta_data[2:n_dims+2].tolist()
+ return dtype,shape
+
+class OpBroadcast(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, src, root, comm = None):
+ if comm is None:
+ comm = config["comm"]
+ ctx.comm = comm
+ outputs = torch.empty_like(src, dtype = src.dtype, device = src.device)
+ ncclBroadcast(src.storage(), outputs.storage(), root, comm)
+ return outputs
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ res = all_reduce(grad_output, "sum", ctx.comm)
+ return res, None, None
+
+def broadcast(src, root, comm=None):
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+ return OpBroadcast.apply(src, root, comm)
+
+class OpAllGather(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input : torch.Tensor, comm = None):
+ if comm is None:
+ comm = config["comm"]
+ world_size = commCount(comm)
+ if not input.is_contiguous():
+ input = input.contiguous()
+ if input.storage_offset() != 0 or input.storage().size() != input.numel():
+ input = input.clone()
+ output = torch.empty( (world_size,) + input.size(), dtype=input.dtype, device=input.device)
+ ctx.comm = comm
+ ncclAllGather(
+ input.storage(),
+ output.storage(),
+ comm
+ )
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output[commRank(ctx.comm)], None
+
+def all_gather(x : torch.Tensor, comm = None):
+ """Gathers the input tensor from all processes.
+
+ Args:
+ x (torch.Tensor): The input tensor of shape (...).
+
+ Returns:
+ torch.Tensor: The gathered tensor of shape (world_size, ...).
+ """
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ assert x.is_cuda
+ return OpAllGather.apply(x, comm)
+
+class OpReduceScatter(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None):
+ if comm is None:
+ comm = config["comm"]
+ ctx.comm = comm
+ rank = commRank(comm)
+ assert input.shape[0] % commCount(comm) == 0, "The dimension 0 must be divisible by the number of communication processes"
+ if not input.is_contiguous():
+ input = input.contiguous()
+ if input.storage_offset() != 0 or input.storage().size() != input.numel():
+ input = input.clone()
+ output_shape = (input.shape[0] // commCount(comm), *input.shape[1:])
+ output = torch.empty( output_shape, dtype=input.dtype, device=input.device )
+ ncclReduceScatter(
+ input.storage(),
+ output.storage(),
+ op,
+ comm
+ )
+ ctx.op = op
+ if op in ["sum", "avg"]:
+ pass
+ elif op in ["max", "min"]:
+ ctx.save_for_backward( output != input[rank * input.shape[0]:(rank + 1) * input.shape[0]] )
+ else:
+ ctx.save_for_backward( output / input[rank * input.shape[0]:(rank + 1) * input.shape[0]] )
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ with torch.no_grad():
+ grad_output = OpAllGather.apply(grad_output, ctx.comm).flatten(0,1)
+ if ctx.op in ["max", "min", "prod"]:
+ raise NotImplementedError("max min operation now do not support backward")
+ else:
+ if ctx.op == "avg":
+ grad_output /= commCount(ctx.comm)
+ return grad_output, None, None
+
+
+def reduce_scatter(x : torch.Tensor, op : str = "sum", comm = None):
+ """Reduces the input tensor from all processes.
+
+ Args:
+ x (torch.Tensor): The input tensor of shape (world_size, ...).
+ op (str): The reduction operation, one of "sum", "avg", "max", "min", "prod". Default: "sum".
+
+ Returns:
+ torch.Tensor: The reduced tensor of shape (...).
+
+ """
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ assert x.is_cuda
+ return OpReduceScatter.apply(x, op, comm)
+
+class OpAllReduce(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None):
+ if comm is None:
+ comm = config["comm"]
+ ctx.comm = comm
+ if not input.is_contiguous():
+ input = input.contiguous()
+ if input.storage_offset() != 0 or input.storage().size() != input.numel():
+ input = input.clone()
+ output = torch.empty( input.size(), dtype=input.dtype, device=input.device)
+
+ ncclAllReduce(
+ input.storage(),
+ output.storage(),
+ op,
+ comm
+ )
+ ctx.op = op
+
+ if op in ["sum", "avg"]:
+ pass
+ elif op in ["max", "min"]:
+ ctx.save_for_backward( input != output )
+ else:
+ ctx.save_for_backward( output / input )
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.op == "sum":
+ return grad_output, None, None
+ elif ctx.op == "avg":
+ return grad_output / commCount(ctx.comm), None, None
+ elif ctx.op in ["max", "min"]:
+ return torch.masked_fill(grad_output, ctx.saved_tensors[0], 0), None, None
+ else:
+ return grad_output * ctx.saved_tensors[0], None, None
+
+def all_reduce(x : torch.Tensor, op : str = "sum", comm = None):
+ """Reduces the input tensor from all processes.
+
+ Args:
+ x (torch.Tensor): The input tensor of shape (...).
+ op (str): The reduction operation, one of "sum", "avg", "max", "min", "prod". Default: "sum".
+
+ Returns:
+ torch.Tensor: The reduced tensor of shape (...).
+
+ """
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ assert x.is_cuda
+ return OpAllReduce.apply(x, op, comm)
+
+
+
diff --git a/examples/BMTrain/bmtrain/global_var.py b/examples/BMTrain/bmtrain/global_var.py
new file mode 100644
index 00000000..137fa9cd
--- /dev/null
+++ b/examples/BMTrain/bmtrain/global_var.py
@@ -0,0 +1,35 @@
+import torch
+from typing_extensions import TypedDict
+class ConfigMap(TypedDict):
+ rank : int
+ local_rank : int
+ world_size : int
+ local_size : int
+ zero_level : int
+ pipe_size : int
+ num_micro_batches : int
+ calc_stream : torch.cuda.Stream
+ load_stream : torch.cuda.Stream
+ load_event : torch.cuda.Event
+ barrier_stream : torch.cuda.Stream
+ loss_scale_factor : float
+ loss_scale_steps : int
+ topology : 'topology'
+ gradient_inspect : bool
+ initialized : bool
+
+ comm : 'NCCLCommunicator'
+
+config = ConfigMap(rank=0, local_rank=0, world_size=1, initialized=False)
+
+def rank():
+ """
+ Returns the global rank of the current process. (0 ~ world_size-1)
+ """
+ return config['rank']
+
+def world_size():
+ """
+ Returns the total number of workers across all nodes.
+ """
+ return config['world_size']
diff --git a/examples/BMTrain/bmtrain/hook_func.py b/examples/BMTrain/bmtrain/hook_func.py
new file mode 100644
index 00000000..577331a2
--- /dev/null
+++ b/examples/BMTrain/bmtrain/hook_func.py
@@ -0,0 +1,121 @@
+import torch
+from .global_var import config
+from .zero_context import ZeroContext
+
+
+def zero_pre_forward(module, inputs):
+ """Helper function for using ZeroContext to gather parmas before forward."""
+ enter = True
+ pipe = False
+ if module._mode == "PIPE":
+ enter = module._micro_idx == 0
+ pipe = True
+ if enter:
+ zero_level = module._zero_level
+ forward_flag = 1 if zero_level == 2 else 0
+ if zero_level == 2 and not module._need_release:
+ forward_flag = 2 # repeating forward in same layer
+ if module.all_param_no_grad: # only forward
+ forward_flag = 0
+ module._forward_block_ctx = ZeroContext(module, module._layer_dict, pipe=pipe)
+ module._forward_block_ctx.enter(forward_flag)
+
+
+def zero_post_forward(module, inputs, outputs):
+ """Helper function for module _forwar_block_ctx weather exits after forward."""
+ forward_flag = 1 if module._zero_level == 2 else 0
+ if module.all_param_no_grad:
+ forward_flag = 0
+ exit = True
+ if module._mode == "PIPE":
+ exit = module._micro_idx == config["micros"] - 1
+
+ if exit:
+ module._forward_block_ctx.exit(forward_flag)
+
+
+def zero_pre_backward(module, grad_outputs):
+ """Helper function for using ZeroContext to init grad buffer before backward."""
+ backward_flag = 2 if module._zero_level == 2 else 0
+ if module._mode != "PIPE":
+ module._backward_block_ctx = ZeroContext(module, module._layer_dict)
+ module._backward_block_ctx.enter(backward_flag, True)
+ module.release_next_module(backward_flag)
+ else:
+ if module._micro_idx == config["micros"] - 1:
+ module._backward_block_ctx = ZeroContext(
+ module, module._layer_dict, pipe=True
+ )
+ module._backward_block_ctx.enter(backward_flag, True)
+
+
+def zero_post_backward(module, grad_inputs, grad_outputs):
+ """Helper function for module weather release after backward."""
+ backward_flag = 2 if module._zero_level == 2 else 0
+ if module._mode != "PIPE":
+ if module._is_first_layer:
+ module.release(backward_flag)
+ else:
+ if module._micro_idx == 0:
+ module.release(backward_flag)
+ module._micro_idx -= 1
+
+
+class OneStepNoGradFunc(torch.autograd.Function):
+ """
+ Requires_grad = False for all inputs.
+ """
+
+ @staticmethod
+ def forward(ctx, module, placeholder, *x):
+ ctx.x = x
+ ctx.module = module
+ ctx.rng_state = torch.cuda.get_rng_state()
+
+ with torch.no_grad():
+ out = module._module(*x)
+ zero_post_forward(module, None, out)
+ if not isinstance(out, torch.Tensor):
+ return tuple(out)
+ return out
+
+ @staticmethod
+ def backward(ctx, grads):
+ zero_pre_backward(ctx.module, grads)
+ with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True):
+ torch.cuda.set_rng_state(ctx.rng_state)
+ x = ctx.x
+ with torch.enable_grad():
+ out = ctx.module._module(*x)
+ torch.autograd.backward(out, grads)
+ zero_post_backward(ctx.module, grads, None)
+ grads = []
+ for _ in x:
+ grads.append(None)
+ return None, None, *grads
+
+
+class PreHookFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, module, *x):
+ ctx.module = module
+ zero_pre_forward(module, x)
+ return x
+
+ @staticmethod
+ def backward(ctx, *grads):
+ zero_post_backward(ctx.module, grads, None)
+ return None, *grads
+
+
+class PostHookFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, module, *out):
+ ctx.module = module
+ zero_post_forward(module, None, out)
+ return out
+
+ @staticmethod
+ def backward(ctx, *grads):
+ zero_pre_backward(ctx.module, grads)
+ return None, *grads
diff --git a/examples/BMTrain/bmtrain/init.py b/examples/BMTrain/bmtrain/init.py
new file mode 100644
index 00000000..601d617e
--- /dev/null
+++ b/examples/BMTrain/bmtrain/init.py
@@ -0,0 +1,258 @@
+import datetime
+import torch
+import random
+import torch.distributed as dist
+import os
+from .utils import print_dict
+import ctypes
+from .global_var import config
+
+from . import nccl
+from .synchronize import synchronize
+
+
+def init_distributed(
+ init_method: str = "env://",
+ seed: int = 0,
+ pipe_size: int = -1,
+ num_micro_batches: int = None,
+ tp_size: int = 1,
+):
+ """Initialize distributed training.
+ This function will initialize the distributed training, set the random seed and global configurations.
+ It must be called before any other distributed functions.
+
+ Args:
+ seed (int): The random seed.
+ pipe_size (int) : all processes will be divided into pipe_size pipeline-parallel groups
+ num_micro_batches (int) : the input batches will be divided into num_micro_batches micro-batches; used in pipeline mode.
+ tp_size (int) : the size of each tensor parallel group
+
+ **init_distributed** reads the following environment variables:
+
+ * `WORLD_SIZE`: The total number of GPUs in the distributed training.
+ * `RANK`: The global rank of the current GPU. From 0 to `WORLD_SIZE - 1`.
+ * `MASTER_ADDR`: The address of the master node.
+ * `MASTER_PORT`: The port of the master node.
+ * `LOCAL_RANK`: The local rank of the current GPU.
+
+ Normally, all the environment variables above are set by the PyTorch distributed launcher.
+
+ **Note**: Do not use any functions in the torch.distributed package, including `torch.distributed.init_process_group`.
+
+ **Note**: If your training script gets stuck here, it means some of your distributed workers are not connected to the master node.
+
+ """
+ torch.backends.cudnn.enabled = False
+
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+ rank = int(os.environ.get("RANK", "0"))
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
+ local_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
+ if "MASTER_ADDR" not in os.environ:
+ os.environ["MASTER_ADDR"] = "localhost"
+ if "MASTER_PORT" not in os.environ:
+ os.environ["MASTER_PORT"] = "10010"
+ addr = os.environ["MASTER_ADDR"]
+ port = os.environ["MASTER_PORT"]
+ master = addr + ":" + port
+ timeout = datetime.timedelta(seconds=1800)
+ rendezvous_iterator = dist.rendezvous(
+ init_method, rank, world_size, timeout=timeout
+ )
+
+ store, rank, world_size = next(rendezvous_iterator)
+ store.set_timeout(timeout)
+ store = dist.PrefixStore("bmtrain", store)
+ torch.cuda.set_device(local_rank)
+ config["initialized"] = True
+ config["pipe_size"] = pipe_size if pipe_size > 0 else 1
+ config["pipe_enabled"] = pipe_size > 0
+ config["local_rank"] = local_rank
+ config["local_size"] = local_size
+ config["rank"] = rank
+ config["world_size"] = world_size
+ config["calc_stream"] = torch.cuda.current_stream()
+ config["load_stream"] = torch.cuda.Stream(priority=-1)
+ config["tp_comm_stream"] = torch.cuda.Stream(priority=-1)
+ config["pp_comm_stream"] = torch.cuda.Stream(priority=-1)
+ config["barrier_stream"] = torch.cuda.Stream()
+ config["load_event"] = torch.cuda.Event()
+ config["tp_size"] = tp_size if tp_size > 0 else 1
+ config["topology"] = topology(config)
+ config["zero_rank"] = config["topology"].get_group_rank("zero")
+ config["tp_rank"] = config["topology"].get_group_rank("tp")
+ config["tp_zero_rank"] = config["topology"].get_group_rank("tp_zero")
+ config["save_param_to_cpu"] = True
+ cpus_this_worker = None
+
+ all_available_cpus = sorted(list(os.sched_getaffinity(0)))
+
+ cpus_per_worker = len(all_available_cpus) // local_size
+
+ if cpus_per_worker < 1:
+ cpus_this_worker = all_available_cpus
+ torch.set_num_threads(1)
+ else:
+ cpus_this_worker = all_available_cpus[
+ local_rank * cpus_per_worker : (local_rank + 1) * cpus_per_worker
+ ]
+ os.sched_setaffinity(0, cpus_this_worker)
+ torch.set_num_threads(len(cpus_this_worker))
+
+ torch.manual_seed(seed)
+ random.seed(seed)
+ try:
+ import numpy as np
+
+ np.random.seed(seed)
+ except ModuleNotFoundError:
+ pass
+
+ if rank == 0:
+ unique_id: bytes = nccl.getUniqueId()
+ store.set("BMTRAIN_UNIQUE_ID", unique_id.hex())
+
+ unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode())
+ config["comm"] = nccl.commInitRank(unique_id, world_size, rank)
+ topo = config["topology"]
+
+ if config["pipe_enabled"]:
+ config["micros"] = (
+ num_micro_batches if num_micro_batches else config["pipe_size"]
+ )
+ if topo.stage_id == 0:
+ unique_id = nccl.getUniqueId()
+ store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex())
+ unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode())
+ config["pipe_comm"] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id)
+
+ if topo.pp_zero_id == 0:
+ unique_id = nccl.getUniqueId()
+ store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex())
+ unique_id = bytes.fromhex(
+ store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode()
+ )
+ config["pp_zero_comm"] = nccl.commInitRank(
+ unique_id, world_size // config["pipe_size"], topo.pp_zero_id
+ )
+
+ if config["tp_size"] > 1:
+ if topo.tp_id == 0:
+ unique_id = nccl.getUniqueId()
+ store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex())
+ unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode())
+ config["tp_comm"] = nccl.commInitRank(unique_id, tp_size, topo.tp_id)
+
+ if topo.tp_zero_id == 0:
+ unique_id = nccl.getUniqueId()
+ store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex())
+ unique_id = bytes.fromhex(
+ store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()
+ )
+ config["tp_zero_comm"] = nccl.commInitRank(
+ unique_id, world_size // config["tp_size"], topo.tp_zero_id
+ )
+
+ if config["pipe_size"] > 1 and config["tp_size"] > 1:
+ if topo.pp_tp_zero_id == 0:
+ unique_id = nccl.getUniqueId()
+ store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex())
+ unique_id = bytes.fromhex(
+ store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()
+ )
+ config["pp_tp_zero_comm"] = nccl.commInitRank(
+ unique_id,
+ world_size // (config["pipe_size"] * config["tp_size"]),
+ topo.pp_tp_zero_id,
+ )
+
+ config["zero_comm"] = config["comm"]
+
+ for i in range(world_size):
+ if i == rank:
+ print_dict(
+ "Initialization",
+ {
+ "rank": rank,
+ "local_rank": local_rank,
+ "world_size": world_size,
+ "local_size": local_size,
+ "master": master,
+ "device": torch.cuda.current_device(),
+ "cpus": cpus_this_worker,
+ },
+ )
+ synchronize()
+
+
+class topology:
+ """A helper class to keep parallel information when using different parallel methods together."""
+
+ def __init__(self, config):
+ # pipe_idx is the idx of the pipeline in the group
+ self.rank = config["rank"]
+ pp_size = config["pipe_size"]
+ tp_size = config["tp_size"]
+ world_size = config["world_size"]
+ assert (
+ world_size % (pp_size * tp_size) == 0
+ ), "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size"
+
+ dp_size = world_size // (pp_size * tp_size)
+ config["tp_zero_size"] = dp_size
+ config["zero_size"] = world_size // pp_size
+ self.stages = config["pipe_size"]
+
+ stage_size = world_size // pp_size
+ for i in range(world_size):
+ self.pipe_idx = self.rank % stage_size
+ self.stage_id = self.rank // stage_size
+ self.tp_id = self.rank % tp_size
+ self.tp_idx = self.rank // tp_size
+ # pp->zero
+ self.pp_zero_idx = self.stage_id
+ self.pp_zero_id = self.pipe_idx
+ # tp->zero
+ self.tp_zero_idx = self.tp_id
+ self.tp_zero_id = self.tp_idx
+ # pp->tp->zero
+ self.pp_tp_zero_idx = self.stage_id * tp_size + self.tp_id
+ self.pp_tp_zero_id = self.pipe_idx // tp_size
+ # only zero
+ self.zero_idx = 0
+ self.zero_id = self.rank
+
+ def get_group_id(self, group_name):
+ """Get group id of different parallel group.
+
+ Args:
+ group_name (str): must be one of "pipe", "zero", "tp_zero" or "tp".
+ """
+ if group_name == "pipe":
+ return self.pipe_idx
+ elif group_name == "zero":
+ return self.zero_idx
+ elif group_name == "tp_zero":
+ return self.tp_zero_idx
+ elif group_name == "tp":
+ return self.tp_idx
+
+ def get_group_rank(self, group_name):
+ """Get group rank of different parallel group.
+
+ Args:
+ group_name (str): must be one of "pipe", "zero", "tp_zero" or "tp".
+ """
+ if group_name == "pipe":
+ return self.stage_id
+ elif group_name == "zero":
+ return self.zero_id
+ elif group_name == "tp_zero":
+ return self.tp_zero_id
+ elif group_name == "tp":
+ return self.tp_id
+
+
+def is_initialized() -> bool:
+ return config["initialized"]
diff --git a/examples/BMTrain/bmtrain/inspect/__init__.py b/examples/BMTrain/bmtrain/inspect/__init__.py
new file mode 100644
index 00000000..2b6d2d26
--- /dev/null
+++ b/examples/BMTrain/bmtrain/inspect/__init__.py
@@ -0,0 +1,3 @@
+from .format import format_summary
+from .model import inspect_model
+from .tensor import inspect_tensor, record_tensor
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/inspect/format.py b/examples/BMTrain/bmtrain/inspect/format.py
new file mode 100644
index 00000000..79b3a1e5
--- /dev/null
+++ b/examples/BMTrain/bmtrain/inspect/format.py
@@ -0,0 +1,64 @@
+from typing import Any, Dict, List
+
+def align_str(s : str, align : int, left : bool) -> str:
+ if left:
+ return s + " " * (align - len(s))
+ else:
+ return " " * (align - len(s)) + s
+
+def format_line(strs : List[str], length : List[int]):
+ ret = ""
+ for v, l in zip(strs, length):
+ if len(v) + 1 > l:
+ v = " " + v[:l - 1]
+ else:
+ v = " " + v
+ ret += align_str(v, l, True)
+ return ret
+
+def item_formater(x) -> str:
+ if isinstance(x, float):
+ return "{:.4f}".format(x)
+ else:
+ return str(x)
+
+def format_summary(summary : List[Dict[str, Any]]) -> str:
+ """Format summary to string.
+
+ Args:
+ summary (List[Dict[str, Any]]): The summary to format.
+
+ Returns:
+ str: The formatted summary.
+
+ """
+ ret = []
+
+ max_name_len = max([len("name")] + [len(item["name"]) for item in summary]) + 4
+ headers = [
+ "name",
+ "shape",
+ "max",
+ "min",
+ "std",
+ "mean",
+ "grad_std",
+ "grad_mean",
+ ]
+ headers_length = [
+ max_name_len,
+ 20,
+ 10,
+ 10,
+ 10,
+ 10,
+ 10,
+ 10
+ ]
+ ret.append( format_line(headers, headers_length) )
+ for item in summary:
+ values = [ item_formater(item[name]) for name in headers ]
+ ret.append( format_line(values, headers_length) )
+ return "\n".join(ret)
+
+
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/inspect/model.py b/examples/BMTrain/bmtrain/inspect/model.py
new file mode 100644
index 00000000..fc54f0d6
--- /dev/null
+++ b/examples/BMTrain/bmtrain/inspect/model.py
@@ -0,0 +1,246 @@
+import torch
+from ..store import broadcast_object
+from ..pipe_layer import PipelineTransformerBlockList
+from ..block_layer import Block
+from ..parameter import DistributedParameter
+from .. import nccl
+from ..global_var import config
+import fnmatch
+
+def _gather_value(value : torch.Tensor, partition_size, origin_size):
+ global_size = partition_size * config['world_size']
+
+ storage = value.storage_type()(global_size)
+
+ if value.storage().size() != partition_size:
+ tmp_buf = torch.zeros(partition_size, dtype=value.dtype, device=value.device)
+ tmp_buf[:value.numel()] = value[:]
+ nccl.allGather(
+ tmp_buf.storage(),
+ storage,
+ config['comm']
+ )
+ else:
+ nccl.allGather(
+ value.storage(),
+ storage,
+ config['comm']
+ )
+
+ output_tensor = torch.tensor([], dtype=value.dtype, device="cuda")
+ output_tensor.set_(storage, 0, origin_size)
+
+ return output_tensor
+
+def inspect_pipeline_transformer_block_list(pipe_model: PipelineTransformerBlockList, param_name : str, _prefix : str = ''):
+ ret = []
+ for name, model in pipe_model._modules.items():
+ idx = int(name)
+ prefix = _prefix + name + '.'
+
+ # fast check
+ pass_fast_check = False
+ for param in model._param_info:
+ abs_name = prefix + param["name"]
+ if fnmatch.fnmatch(abs_name, param_name):
+ pass_fast_check = True
+ break
+ if not pass_fast_check:
+ continue
+
+ if idx in pipe_model.layer_ids:
+ _param_buffer = {}
+ _grad_buffer = {}
+ for kw, val in model._storage_info.items():
+ storage_type = model._storage_params[kw].storage_type()
+
+ _param_buffer[kw] = storage_type(val["partition_size"] * val['world_size'])
+ if model._storage_params[kw].grad is not None:
+ _grad_buffer[kw] = storage_type(val["partition_size"] * val['world_size'])
+
+ nccl.groupStart()
+ for kw, val in model._storage_info.items():
+ nccl.allGather(
+ model._storage_params[kw].storage(),
+ _param_buffer[kw],
+ val["zero_comm"]
+ )
+ if model._storage_params[kw].grad is not None:
+ nccl.allGather(
+ model._storage_params[kw].grad.storage(),
+ _grad_buffer[kw],
+ val["zero_comm"]
+ )
+
+ nccl.groupEnd()
+ for param in model._param_info:
+ abs_name = prefix + param["name"]
+ if fnmatch.fnmatch(abs_name, param_name):
+ kw_name = param["kw_name"]
+ dtype = _param_buffer[kw_name].dtype
+ device = _param_buffer[kw_name].device
+ offset = param["offset"]
+ shape = param["shape"]
+ p = torch.tensor([], dtype=dtype, device=device).set_(_param_buffer[kw_name], offset, shape)
+ if kw_name in _grad_buffer:
+ g = torch.tensor([], dtype=dtype, device=device).set_(_grad_buffer[kw_name], offset, shape)
+ info = {
+ "name": abs_name,
+ "shape": tuple(shape),
+ "std": p.std().cpu().item(),
+ "mean": p.mean().cpu().item(),
+ "grad_std": g.std().cpu().item(),
+ "grad_mean": g.mean().cpu().item(),
+ "max": p.max().cpu().item(),
+ "min": p.min().cpu().item(),
+ }
+ else:
+ info = {
+ "name": abs_name,
+ "shape": tuple(shape),
+ "std": p.std().cpu().item(),
+ "mean": p.mean().cpu().item(),
+ "grad_std": 0.,
+ "grad_mean": 0.,
+ "max": p.max().cpu().item(),
+ "min": p.min().cpu().item(),
+ }
+ broadcast_object(info, config["pipe_comm"], pipe_model.get_stage_by_layer_id(idx))
+ ret.append(info)
+ else:
+ for param in model._param_info:
+ abs_name = prefix + param["name"]
+ if fnmatch.fnmatch(abs_name, param_name):
+ info = broadcast_object({}, config["pipe_comm"], pipe_model.get_stage_by_layer_id(idx))
+ ret.append(info)
+
+ return ret
+
+
+def inspect_block(model : Block, param_name : str, prefix : str = ''):
+ # fast check
+ pass_fast_check = False
+ for param in model._param_info:
+ abs_name = prefix + param["name"]
+ if fnmatch.fnmatch(abs_name, param_name):
+ pass_fast_check = True
+ break
+ if not pass_fast_check:
+ return []
+
+ _param_buffer = {}
+ _grad_buffer = {}
+ for kw, val in model._storage_info.items():
+ storage_type = model._storage_params[kw].storage_type()
+
+ _param_buffer[kw] = storage_type(val["partition_size"] * config['world_size'])
+ if model._storage_params[kw].grad is not None:
+ _grad_buffer[kw] = storage_type(val["partition_size"] * config['world_size'])
+
+ nccl.groupStart()
+ for kw, val in model._storage_info.items():
+ nccl.allGather(
+ model._storage_params[kw].storage(),
+ _param_buffer[kw],
+ config["comm"]
+ )
+ if model._storage_params[kw].grad is not None:
+ nccl.allGather(
+ model._storage_params[kw].grad.storage(),
+ _grad_buffer[kw],
+ config["comm"]
+ )
+
+ nccl.groupEnd()
+ ret = []
+ for param in model._param_info:
+ abs_name = prefix + param["name"]
+ if fnmatch.fnmatch(abs_name, param_name):
+ kw_name = param["kw_name"]
+ dtype = _param_buffer[kw_name].dtype
+ device = _param_buffer[kw_name].device
+ offset = param["offset"]
+ shape = param["shape"]
+ p = torch.tensor([], dtype=dtype, device=device).set_(_param_buffer[kw_name], offset, shape)
+ if kw_name in _grad_buffer:
+ g = torch.tensor([], dtype=dtype, device=device).set_(_grad_buffer[kw_name], offset, shape)
+ ret.append({
+ "name": abs_name,
+ "shape": tuple(shape),
+ "std": p.std().cpu().item(),
+ "mean": p.mean().cpu().item(),
+ "grad_std": g.std().cpu().item(),
+ "grad_mean": g.mean().cpu().item(),
+ "max": p.max().cpu().item(),
+ "min": p.min().cpu().item(),
+ })
+ else:
+ ret.append({
+ "name": abs_name,
+ "shape": tuple(shape),
+ "std": p.std().cpu().item(),
+ "mean": p.mean().cpu().item(),
+ "grad_std": 0.,
+ "grad_mean": 0.,
+ "max": p.max().cpu().item(),
+ "min": p.min().cpu().item(),
+ })
+ return ret
+
+@torch.no_grad()
+def inspect_model(model : torch.nn.Module, param_name : str, prefix : str = ''):
+ """Inspect the model and return the summary of the parameters.
+
+ Args:
+ model (torch.nn.Module): The model to be inspected.
+ param_name (str): The name of the parameter to be inspected. The wildcard '*' can be used to match multiple parameters.
+ prefix (str): The prefix of the parameter name.
+
+ Returns:
+ list: The summary of the parameters.
+
+ Example:
+ >>> result_linear = bmt.inspect.inspect_model(model, "*.linear*")
+ >>> result_layernorm = bmt.inspect.inspect_model(model, "*.layernorm*")
+        >>> text_summary = bmt.inspect.format_summary(result_linear + result_layernorm)
+ >>> bmt.print_rank(text_summary)
+ name shape max min std mean grad_std grad_mean
+ ...
+
+ """
+ if isinstance(model, PipelineTransformerBlockList):
+ return inspect_pipeline_transformer_block_list(model, param_name, prefix)
+ elif isinstance(model, Block):
+ return inspect_block(model, param_name, prefix)
+ else:
+ ret = []
+ for name, param in model._parameters.items():
+ if fnmatch.fnmatch(prefix + name, param_name):
+ if isinstance(param, DistributedParameter):
+ p = _gather_value(param.data, param.storage().size(), param._original_shape)
+ else:
+ p = param
+ if p is None:
+ continue
+ stats = {
+ 'name': prefix + name,
+ 'shape': tuple(p.size()),
+ "std": p.std().cpu().item(),
+ "mean": p.mean().cpu().item(),
+ "max": p.max().cpu().item(),
+ "min": p.min().cpu().item(),
+ }
+ if param.grad is not None:
+ if isinstance(param, DistributedParameter):
+ g = _gather_value(param.grad.data, param.storage().size(), param._original_shape)
+ else:
+ g = param.grad
+ stats["grad_std"] = g.std().cpu().item()
+ stats["grad_mean"] = g.mean().cpu().item()
+ else:
+ stats["grad_std"] = 0.
+ stats["grad_mean"] = 0.
+ ret.append(stats)
+ for name, module in model._modules.items():
+ ret.extend(inspect_model(module, param_name, prefix + name + '.'))
+ return ret
diff --git a/examples/BMTrain/bmtrain/inspect/tensor.py b/examples/BMTrain/bmtrain/inspect/tensor.py
new file mode 100644
index 00000000..9d003f82
--- /dev/null
+++ b/examples/BMTrain/bmtrain/inspect/tensor.py
@@ -0,0 +1,383 @@
+from typing import Optional
+import torch
+from .. import debug
+from .. import nccl
+from ..global_var import config
+from ..store import broadcast_object
+from ..distributed import broadcast
+import math
+
+
+class InspectTensor:
+ """This object is returned by `InspectTensorManager`.
+
+ You can get the tensors recorded by `record_tensor`.
+
+ """
+
+ def __init__(self):
+ self.summary = []
+
+ def _set_summary(self, summary):
+ self._summary = summary
+ for item in summary:
+ item["prefix"] = "" if item["group"] is None else f'{item["group"]}.'
+
+ self.summary = []
+
+ kw_cnt = {}
+ i = 0
+ while i < len(summary):
+ item = summary[i]
+ if item["inside_pipe"] is not None:
+ before_len = len(self.summary)
+
+ assert item["inside_pipe"]["st"]
+ pipe_cnt = {}
+ j = i
+ while j < len(summary):
+ item = summary[j]
+ kw = f'{item["prefix"]}{item["name"]}'
+
+ assert item["inside_pipe"] is not None
+ stage_id = item["inside_pipe"]["stage_id"]
+ stages = item["inside_pipe"]["stages"]
+ st = item["inside_pipe"]["st"]
+ ed = item["inside_pipe"]["ed"]
+
+ if kw not in pipe_cnt:
+ pipe_cnt[kw] = 0
+ pipe_cnt[kw] += 1
+
+ j += 1
+ if ed:
+ break
+
+ for stage in range(stages):
+ if stage_id == stage:
+ broadcast_object(pipe_cnt, config["pipe_comm"], src=stage)
+ for k in range(i, j):
+ item = summary[k]
+ kw = f'{item["prefix"]}{item["name"]}'
+ if kw not in kw_cnt:
+ kw_cnt[kw] = 0
+ tensor = torch.cat(
+ [
+ summary[k + m * (j - i)]["tensor"]
+ for m in range(config["micros"])
+ ],
+ dim=0,
+ )
+ grad = (
+ torch.cat(
+ [
+ summary[k + m * (j - i)]["tensor"].grad
+ for m in range(config["micros"])
+ ],
+ dim=0,
+ )
+ if item["requires_grad"]
+ and item["tensor"].grad is not None
+ else None
+ )
+ self.summary.append(
+ {
+ "name": item["name"],
+ "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}',
+ "group": item["group"],
+ "min": None,
+ "max": None,
+ "mean": None,
+ "std": None,
+ "shape": (item["shape"][0] * config["micros"],)
+ + item["shape"][1:],
+ "grad_mean": None,
+ "grad_std": None,
+ "tensor": tensor,
+ "grad": grad,
+ "requires_grad": item["requires_grad"],
+ "inside_pipe": {"stage_id": stage},
+ }
+ )
+ kw_cnt[kw] += 1
+ else:
+ cnt = broadcast_object({}, config["pipe_comm"], src=stage)
+ for kw, val in cnt.items():
+ if kw not in kw_cnt:
+ kw_cnt[kw] = 0
+ for _ in range(val):
+ self.summary.append(
+ {
+ "name": item["name"],
+ "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}',
+ "group": None,
+ "min": None,
+ "max": None,
+ "mean": None,
+ "std": None,
+ "shape": None,
+ "grad_mean": None,
+ "grad_std": None,
+ "tensor": None,
+ "grad": None,
+ "requires_grad": None,
+ "inside_pipe": {"stage_id": stage},
+ }
+ )
+ kw_cnt[kw] += 1
+
+ after_len = len(self.summary)
+ with torch.enable_grad():
+ for it in self.summary[before_len:after_len]:
+ if it["tensor"] is not None:
+ has_grad = it["grad"] is not None
+ info = {
+ "group": it["group"],
+ "shape": it["shape"],
+ "requires_grad": it["requires_grad"],
+ "has_grad": has_grad,
+ }
+ broadcast_object(
+ info,
+ config["pipe_comm"],
+ src=it["inside_pipe"]["stage_id"],
+ )
+ tensor = it["tensor"]
+ tensor = broadcast(
+ tensor,
+ it["inside_pipe"]["stage_id"],
+ config["pipe_comm"],
+ )
+ grad = it["grad"]
+ else:
+ info = broadcast_object(
+ {},
+ config["pipe_comm"],
+ src=it["inside_pipe"]["stage_id"],
+ )
+ has_grad = info.pop("has_grad")
+ it.update(info)
+ tensor = torch.empty(it["shape"]).cuda().requires_grad_()
+ tensor = broadcast(
+ tensor,
+ it["inside_pipe"]["stage_id"],
+ config["pipe_comm"],
+ )
+ if has_grad:
+ grad = torch.empty(it["shape"]).cuda()
+ tensor = tensor.chunk(stages, dim=0)[stage_id].clone()
+ it["tensor"] = tensor
+ if has_grad:
+ grad = broadcast(
+ grad, it["inside_pipe"]["stage_id"], config["pipe_comm"]
+ )
+ grad = grad.chunk(stages, dim=0)[stage_id].clone()
+ tensor.grad = grad
+ it["shape"] = (it["shape"][0] // config["pipe_size"],) + it[
+ "shape"
+ ][1:]
+
+ i = i + config["micros"] * (j - i)
+ else:
+ kw = f'{item["prefix"]}{item["name"]}'
+ if kw not in kw_cnt:
+ kw_cnt[kw] = 0
+ self.summary.append(
+ {
+ "name": item["name"],
+ "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}',
+ "group": item["group"],
+ "min": None,
+ "max": None,
+ "mean": None,
+ "std": None,
+ "shape": item["shape"],
+ "grad_mean": None,
+ "grad_std": None,
+ "tensor": item["tensor"],
+ "requires_grad": item["requires_grad"],
+ "inside_pipe": None,
+ }
+ )
+ kw_cnt[kw] += 1
+ i = i + 1
+
+ def get_summary(self):
+ r"""Get the summary of the tensors recorded by `record_tensor`.
+
+ Returns:
+ A list of dicts. Each dict contains the following keys:
+ - name: The name of the tensor.
+ - min: The minimum value of the tensor.
+ - max: The maximum value of the tensor.
+ - mean: The mean value of the tensor.
+ - std: The standard deviation of the tensor.
+ - shape: The shape of the tensor.
+ - grad_mean: The mean value of the gradient of the tensor.
+ - grad_std: The standard deviation of the gradient of the tensor.
+
+ **Note:** This method must be called outside of the `with` block.
+
+ """
+ self._set_summary(self._summary)
+ ret = []
+ for item in self.summary:
+ comm = config["comm"]
+
+ if not item["requires_grad"] or item["tensor"].grad is None:
+ x = item["tensor"]
+ info = torch.empty(2, dtype=x.dtype, device=x.device)
+ info[0] = x.mean()
+ info[1] = x.var()
+ nccl.allReduce(info.storage(), info.storage(), "sum", comm)
+ info = info / nccl.commCount(comm)
+ x_mean = info[0].cpu().item()
+ x_std = math.sqrt(info[1].cpu().item())
+ grad_mean = None
+ grad_std = None
+ else:
+ x = item["tensor"]
+ info = torch.empty(4, dtype=x.dtype, device=x.device)
+ info[0] = x.mean()
+ info[1] = x.var()
+ info[2] = x.grad.mean()
+ info[3] = x.grad.var()
+ nccl.allReduce(info.storage(), info.storage(), "sum", comm)
+ info = info / nccl.commCount(comm)
+ x_mean = info[0].cpu().item()
+ x_std = math.sqrt(info[1].cpu().item())
+ grad_mean = info[2].cpu().item()
+ grad_std = math.sqrt(info[3].cpu().item())
+
+ info[0] = x.max()
+ info[1] = -x.min()
+ nccl.allReduce(info.storage(), info.storage(), "max", comm)
+ x_max = info[0].cpu().item()
+ x_min = -info[1].cpu().item()
+
+ summary = {
+ "name": item["summary_name"],
+ "min": x_min,
+ "max": x_max,
+ "mean": x_mean,
+ "std": x_std,
+ "shape": tuple(
+ (item["shape"][0] * config["world_size"],) + item["shape"][1:]
+ ),
+ "grad_mean": grad_mean,
+ "grad_std": grad_std,
+ }
+
+ ret.append(summary)
+ return ret
+
+ def get_tensor(
+ self, name: str, group: Optional[str] = None, index: Optional[int] = None
+ ) -> torch.Tensor:
+ """Get the tensor recorded by `record_tensor` by name, group and index.
+
+ Args:
+ name (str): The name of the tensor.
+ group (Optional[str]): The group of the tensor.
+ index (Optional[int]): The index of the tensor.
+
+ Returns:
+ The tensor if found, otherwise None.
+
+ """
+ group_name_prefix = f"{group}." if group is not None else ""
+
+ all_names = []
+ if index is None:
+ all_names.append(f"{group_name_prefix}{name}")
+ all_names.append(f"{group_name_prefix}0.{name}")
+ else:
+ all_names.append(f"{group_name_prefix}{index}.{name}")
+
+ for item in self.summary:
+ if item["name"] in all_names:
+ return item["tensor"]
+ return None
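+
+    # Illustrative usage sketch (the tensor name and group are hypothetical and depend
+    # on what the model passed to record_tensor):
+    #   with bmt.inspect.inspect_tensor() as inspector:
+    #       loss = model(inputs)
+    #       loss.backward()
+    #   hidden = inspector.get_tensor("hidden_states", group="encoder", index=0)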
+
+
+class InspectTensorManager:
+ def __init__(self) -> None:
+ self._inspector = None
+
+ def __enter__(self) -> InspectTensor:
+ self.prev_val = debug.get("_inspect_tensor", False)
+ if not self.prev_val:
+ debug.set("_inspect_tensor", True)
+ self._inspector = InspectTensor()
+ return self._inspector
+ else:
+ raise RuntimeError("InspectTensorManager is already in use")
+
+ def __exit__(self, *args):
+ if not self.prev_val:
+ debug.set("_inspect_tensor", self.prev_val)
+ summary = debug.get("_inspect_hidden_states", [])
+ self._inspector._set_summary(summary)
+ self._inspector = None
+ debug.set("_inspect_hidden_states", [])
+
+
+def inspect_tensor() -> InspectTensorManager:
+ """**inspect_tensor** returns a context manager that can be used to get the intermediate results of the model computations and their gradients.
+
+ Example:
+ >>> with bmt.inspect.inspect_tensor() as inspector:
+ >>> loss = model(inputs)
+ >>> loss.backward()
+ >>> summary = inspector.get_summary()
+ >>> text_summary = bmt.inspect.format_summary(summary)
+ >>> bmt.print_rank(text_summary)
+ name shape max min std mean grad_std grad_mean
+ ...
+
+ **Note:** loss.backward() must be called inside the context manager, otherwise the gradients will not be recorded.
+ **Note:** Calling get_summary() has significant overhead.
+
+ """
+
+ return InspectTensorManager()
+
+
+def record_tensor(x: torch.Tensor, name: str, group=None):
+ """Record the tensor for inspection.
+
+ Args:
+ x (torch.Tensor): The tensor to be recorded.
+ name (str): The name of the tensor.
+ group (str): The group name of the tensor.
+
+ **Note:** This function is only available in inspect_tensor context.
+ **Note:** Recording too many tensors may cause memory issues.
+
+ """
+ if isinstance(x, torch.nn.Parameter):
+ raise RuntimeError("Cannot inspect Parameter")
+
+ if not debug.get("_inspect_tensor", False):
+ # do nothing
+ return
+
+ if x.requires_grad:
+ x.retain_grad()
+ debug.append(
+ "_inspect_hidden_states",
+ {
+ "name": name,
+ "group": group,
+ "min": None,
+ "max": None,
+ "mean": None,
+ "std": None,
+ "shape": x.shape,
+ "grad_mean": None,
+ "grad_std": None,
+ "tensor": x,
+ "requires_grad": x.requires_grad,
+ "inside_pipe": None,
+ },
+ )
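+
+# Illustrative usage sketch: record an intermediate activation inside a module's
+# forward pass (the module and tensor names here are hypothetical).
+#
+#   class FeedForward(bmt.DistributedModule):
+#       def forward(self, hidden):
+#           hidden = self.w_out(self.act(self.w_in(hidden)))
+#           bmt.inspect.record_tensor(hidden, "ffn_output", group="encoder")
+#           return hidden
+#
+# Outside of an inspect_tensor() context the call above returns immediately, so the
+# hook can stay in the model code permanently.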
diff --git a/examples/BMTrain/bmtrain/layer.py b/examples/BMTrain/bmtrain/layer.py
new file mode 100644
index 00000000..e071e01b
--- /dev/null
+++ b/examples/BMTrain/bmtrain/layer.py
@@ -0,0 +1,143 @@
+import torch
+from .parameter import DistributedParameter
+from .global_var import config
+import itertools
+from .utils import tp_split_tensor
+
+class DistributedModule(torch.nn.Module):
+ """
+ DistributedModule is a subclass of torch.nn.Module that overrides the `__getattr__` method to gather distributed parameters automatically.
+
+ """
+
+ def __getattr__(self, name: str):
+ ret = super().__getattr__(name)
+ # gather distributed parameters if not in bmt.Block
+ if isinstance(ret, DistributedParameter) and not ret._in_block:
+ return ret.gather()
+ return ret
+
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
+ r"""Saves module state to `destination` dictionary, containing a state
+ of the module, but not its descendants. This is called on every
+ submodule in :meth:`~torch.nn.Module.state_dict`.
+
+ In rare cases, subclasses can achieve class-specific behavior by
+ overriding this method with custom logic.
+
+ Args:
+ destination (dict): a dict where state will be stored
+ prefix (str): the prefix for parameters and buffers used in this
+ module
+ """
+ for name, param in self._parameters.items():
+ if param is not None:
+ if isinstance(param, DistributedParameter):#and not param._in_block:
+ if param._in_block:
+ destination[prefix + name] = param.tp_gather().detach() # sync operation
+ else:
+ destination[prefix + name] = param.gather_all().detach() # sync operation
+ if config['save_param_to_cpu']:
+ destination[prefix + name] = destination[prefix + name].cpu()
+ else:
+ if config['save_param_to_cpu']:
+ destination[prefix + name] = param if keep_vars else param.detach().cpu()
+ else:
+ destination[prefix + name] = param if keep_vars else param.detach()
+
+ for name, buf in self._buffers.items():
+ if buf is not None and name not in self._non_persistent_buffers_set:
+ destination[prefix + name] = buf if keep_vars else buf.detach()
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ r"""Copies parameters and buffers from :attr:`state_dict` into only
+ this module, but not its descendants. This is called on every submodule
+ in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
+ module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
+ For state dicts without metadata, :attr:`local_metadata` is empty.
+ Subclasses can achieve class-specific backward compatible loading using
+ the version number at `local_metadata.get("version", None)`.
+
+ .. note::
+ :attr:`state_dict` is not the same object as the input
+ :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
+ it can be modified.
+
+ Args:
+ state_dict (dict): a dict containing parameters and
+ persistent buffers.
+ prefix (str): the prefix for parameters and buffers used in this
+ module
+ local_metadata (dict): a dict containing the metadata for this module.
+ See
+ strict (bool): whether to strictly enforce that the keys in
+ :attr:`state_dict` with :attr:`prefix` match the names of
+ parameters and buffers in this module
+ missing_keys (list of str): if ``strict=True``, add missing keys to
+ this list
+ unexpected_keys (list of str): if ``strict=True``, add unexpected
+ keys to this list
+ error_msgs (list of str): error messages should be added to this
+ list, and will be reported together in
+ :meth:`~torch.nn.Module.load_state_dict`
+ """
+ for hook in self._load_state_dict_pre_hooks.values():
+ hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+ persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+ local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+ local_state = {k: v for k, v in local_name_params if v is not None}
+
+ for name, param in local_state.items():
+ key = prefix + name
+ if key in state_dict:
+ tp_mode = param._tp_mode
+ input_param = state_dict[key]
+ if input_param.__class__.__name__ == "DistributedTensorWrapper":
+ input_param = input_param.broadcast()
+ # This is used to avoid copying uninitialized parameters into
+                # non-lazy modules, since they don't have the hook to do the checks
+ # in such case, it will error when accessing the .shape attribute.
+ is_param_lazy = torch.nn.parameter.is_lazy(param)
+ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+ if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+ input_param = input_param[0]
+
+ if not is_param_lazy and not isinstance(param, DistributedParameter) and input_param.shape != param.shape:
+ # local shape should match the one in checkpoint
+ error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+ 'the shape in current model is {}.'
+ .format(key, input_param.shape, param.shape))
+ continue
+ verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape)
+ if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape:
+ error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+ 'the shape in current model is {}.'
+ .format(key, input_param.shape, verify_shape))
+ try:
+ with torch.no_grad():
+ if isinstance(param, DistributedParameter):
+ tp_split_dim = param._tp_split_dim
+ if tp_mode and tp_split_dim >= 0:
+ input_param = tp_split_tensor(input_param, tp_split_dim)
+ param._copy_data(input_param)
+ else:
+ param.copy_(input_param)
+ except Exception as ex:
+ error_msgs.append('While copying the parameter named "{}", '
+ 'whose dimensions in the model are {} and '
+ 'whose dimensions in the checkpoint are {}, '
+ 'an exception occurred : {}.'
+ .format(key, param.size(), input_param.size(), ex.args))
+ elif strict:
+ missing_keys.append(key)
+
+ if strict:
+ for key in state_dict.keys():
+ if key.startswith(prefix):
+ input_name = key[len(prefix):]
+ input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
+ if input_name not in self._modules and input_name not in local_state:
+ unexpected_keys.append(key)
+
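+
+# Illustrative usage sketch (hypothetical module, assumes bmt.init_distributed() has
+# been called): DistributedModule makes gathered parameter access and checkpointing
+# transparent.
+#
+#   class Projection(bmt.DistributedModule):
+#       def __init__(self, dim_in, dim_out):
+#           super().__init__()
+#           self.weight = bmt.DistributedParameter(
+#               torch.empty(dim_out, dim_in, dtype=torch.half, device="cuda"),
+#               init_method=torch.nn.init.xavier_normal_,
+#           )
+#
+#       def forward(self, x):
+#           # __getattr__ returns the gathered (full) weight here
+#           return torch.nn.functional.linear(x, self.weight)
+#
+#   state = Projection(1024, 1024).state_dict()   # full, un-partitioned tensors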
diff --git a/examples/BMTrain/bmtrain/loss/__init__.py b/examples/BMTrain/bmtrain/loss/__init__.py
new file mode 100644
index 00000000..daa731fd
--- /dev/null
+++ b/examples/BMTrain/bmtrain/loss/__init__.py
@@ -0,0 +1 @@
+from .cross_entropy import FusedCrossEntropy
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/loss/_function.py b/examples/BMTrain/bmtrain/loss/_function.py
new file mode 100644
index 00000000..6ff3c471
--- /dev/null
+++ b/examples/BMTrain/bmtrain/loss/_function.py
@@ -0,0 +1,182 @@
+from .. import C
+import torch
+
+CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda
+
+
+def has_inf_nan(g_half: torch.Tensor, out: torch.Tensor) -> None:
+ assert out.dtype == torch.uint8, "out must be a uint8 tensor"
+    assert CHECK_INPUT(g_half), "g_half must be contiguous and on cuda"
+ assert CHECK_INPUT(out), "out must be contiguous and on cuda"
+ mid = torch.zeros(1024, device=out.device, dtype=out.dtype)
+ stream = torch.cuda.current_stream().cuda_stream
+ if g_half.dtype == torch.float16:
+ C.has_nan_inf_fp16_launcher(
+ g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream
+ )
+ elif g_half.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.has_nan_inf_bf16_launcher(
+ g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream
+ )
+ else:
+ raise ValueError(f"has_inf_nan not supported for dtype {g_half.dtype}")
+
+
+def cross_entropy_forward(
+ m: int,
+ n: int,
+ input: torch.Tensor,
+ target: torch.Tensor,
+ softmax: torch.Tensor,
+ output: torch.Tensor,
+ ignore_index: int,
+) -> None:
+    assert CHECK_INPUT(input), "input must be contiguous and on cuda"
+    assert CHECK_INPUT(target), "target must be contiguous and on cuda"
+    assert CHECK_INPUT(softmax), "softmax must be contiguous and on cuda"
+    assert CHECK_INPUT(output), "output must be contiguous and on cuda"
+ assert target.dtype == torch.int32, "target must be an int tensor"
+ assert output.dtype == torch.float32, "output must be a float tensor"
+ assert (
+ input.numel() == softmax.numel()
+ ), "input and softmax must have the same number of elements"
+ assert (
+ target.numel() == output.numel()
+ ), "target and output must have the same number of elements"
+ input_ptr = input.data_ptr()
+ target_ptr = target.data_ptr()
+ softmax_ptr = softmax.data_ptr()
+ output_ptr = output.data_ptr()
+ cuda_stream = torch.cuda.current_stream().cuda_stream
+ if input.dtype == torch.float16:
+ C.cross_entropy_forward_fp16_launcher(
+ m,
+ n,
+ input_ptr,
+ target_ptr,
+ softmax_ptr,
+ output_ptr,
+ ignore_index,
+ cuda_stream,
+ )
+ elif input.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.cross_entropy_forward_bf16_launcher(
+ m,
+ n,
+ input_ptr,
+ target_ptr,
+ softmax_ptr,
+ output_ptr,
+ ignore_index,
+ cuda_stream,
+ )
+ else:
+ raise ValueError(f"cross_entropy_forward not supported for dtype {input.dtype}")
+
+
+def cross_entropy_backward_inplace(
+ m: int,
+ n: int,
+ grad_output: torch.Tensor,
+ target: torch.Tensor,
+ x: torch.Tensor,
+ ignore_index: int,
+) -> None:
+    assert CHECK_INPUT(grad_output), "grad_output must be contiguous and on cuda"
+    assert CHECK_INPUT(target), "target must be contiguous and on cuda"
+    assert CHECK_INPUT(x), "x must be contiguous and on cuda"
+ assert grad_output.dtype == torch.float32, "grad_output must be a float tensor"
+ assert target.dtype == torch.int32, "target must be an int tensor"
+ assert (
+ target.numel() == grad_output.numel()
+ ), "target and grad_output must have the same number of elements"
+ cuda_stream = torch.cuda.current_stream().cuda_stream
+ grad_output_ptr = grad_output.data_ptr()
+ target_ptr = target.data_ptr()
+ x_ptr = x.data_ptr()
+
+ if x.dtype == torch.float16:
+ C.cross_entropy_backward_inplace_fp16_launcher(
+ m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream
+ )
+ elif x.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.cross_entropy_backward_inplace_bf16_launcher(
+ m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream
+ )
+    else:
+        raise ValueError(
+            f"cross_entropy_backward_inplace not supported for dtype {x.dtype}"
+        )
+
+
+def fused_sumexp(logits: torch.Tensor, max_logits: torch.Tensor) -> torch.Tensor:
+    assert CHECK_INPUT(logits), "logits must be contiguous and on cuda"
+    assert CHECK_INPUT(max_logits), "max_logits must be contiguous and on cuda"
+ assert max_logits.dtype == torch.float32, "max_logits must be float tensor"
+ assert max_logits.size(0) == logits.size(
+ 0
+ ), "max_logits must have same size(0) as logits"
+ sum_exp_logits = torch.empty(
+ logits.size(0), dtype=torch.float32, device=logits.device
+ )
+ m = logits.size(0)
+ n = logits.size(1)
+ cuda_stream = torch.cuda.current_stream().cuda_stream
+ logits_ptr = logits.data_ptr()
+ max_logits_ptr = max_logits.data_ptr()
+ sum_exp_logits_ptr = sum_exp_logits.data_ptr()
+ if logits.dtype == torch.float16:
+ C.fused_sumexp_fp16_launcher(
+ m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream
+ )
+ elif logits.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.fused_sumexp_bf16_launcher(
+ m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream
+ )
+ else:
+ raise ValueError(f"fused_sumexp not supported for dtype {logits.dtype}")
+ return sum_exp_logits
+
+
+def fused_softmax_inplace(
+ logits: torch.Tensor, max_logits: torch.Tensor, sum_exp_logits: torch.Tensor
+) -> None:
+    assert CHECK_INPUT(logits), "logits must be contiguous and on cuda"
+    assert CHECK_INPUT(max_logits), "max_logits must be contiguous and on cuda"
+    assert CHECK_INPUT(sum_exp_logits), "sum_exp_logits must be contiguous and on cuda"
+ assert max_logits.dtype == torch.float32, "max_logits must be float tensor"
+ assert sum_exp_logits.dtype == torch.float32, "sum_exp_logits must be float tensor"
+ assert max_logits.size(0) == logits.size(
+ 0
+ ), "max_logits must have same size(0) as logits"
+ assert sum_exp_logits.size(0) == logits.size(
+ 0
+ ), "sum_exp_logits must have same size(0) as logits"
+ m = logits.size(0)
+ n = logits.size(1)
+ cuda_stream = torch.cuda.current_stream().cuda_stream
+ logits_ptr = logits.data_ptr()
+ max_logits_ptr = max_logits.data_ptr()
+ sum_exp_logits_ptr = sum_exp_logits.data_ptr()
+ if logits.dtype == torch.float16:
+ C.fused_softmax_inplace_fp16_launcher(
+ m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream
+ )
+ elif logits.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.fused_softmax_inplace_bf16_launcher(
+ m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream
+ )
+ else:
+ raise ValueError(
+ f"fused_softmax_inplace not supported for dtype {logits.dtype}"
+ )
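+
+# Reference semantics of the two fused kernels above (illustrative sketch only):
+#   sum_exp_logits[i] = sum_j exp(logits[i, j] - max_logits[i])                  # fused_sumexp
+#   logits[i, j]      = exp(logits[i, j] - max_logits[i]) / sum_exp_logits[i]    # fused_softmax_inplace
+# An equivalent (unfused) pure-PyTorch computation would be:
+#   shifted = logits.float() - max_logits.unsqueeze(-1)
+#   sum_exp = shifted.exp().sum(dim=-1)
+#   softmax = shifted.exp() / sum_exp.unsqueeze(-1)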
diff --git a/examples/BMTrain/bmtrain/loss/cross_entropy.py b/examples/BMTrain/bmtrain/loss/cross_entropy.py
new file mode 100644
index 00000000..5be07665
--- /dev/null
+++ b/examples/BMTrain/bmtrain/loss/cross_entropy.py
@@ -0,0 +1,260 @@
+from typing import Optional
+import torch
+from . import _function as F
+from bmtrain.global_var import config
+from bmtrain.distributed import all_gather, all_reduce
+
+class OpFusedCrossEntropy(torch.autograd.Function):
+ """
+ CrossEntropy dim = 1
+ """
+ @staticmethod
+ def forward(ctx, x : torch.Tensor, target : torch.Tensor, ignore_index: int):
+ assert x.ndim == 2
+ softmax = torch.empty(x.size(), device=x.device, dtype=x.dtype)
+ out = torch.empty(x.size(0), device=x.device, dtype=torch.float)
+ F.cross_entropy_forward(
+ x.size(0), x.size(1),
+ x, target,
+ softmax, out,
+ ignore_index,
+ )
+ ctx.ignore_index = ignore_index
+ ctx.save_for_backward(softmax, target)
+ return out # float tensor
+
+ @staticmethod
+ def backward(ctx, grad_output : torch.Tensor):
+ grad_output = grad_output.contiguous()
+ softmax, target = ctx.saved_tensors
+ F.cross_entropy_backward_inplace(
+ softmax.size(0), softmax.size(1),
+ grad_output, target,
+ softmax,
+ ctx.ignore_index,
+ )
+ return (softmax, None, None)
+
+class VPFusedCrossEntropy(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, logits : torch.Tensor, target : torch.Tensor):
+ comm = config['tp_comm']
+ rank = config['tp_rank']
+ world_size = config['tp_size']
+
+ max_logits = torch.max(logits, dim=-1)[0].float()
+ max_logits = all_reduce(max_logits, op="max", comm=comm)
+
+ partition_vocab_size = logits.size()[-1]
+ vocab_start_index = rank * partition_vocab_size
+ vocab_end_index = (rank + 1) * partition_vocab_size
+
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
+ target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+ masked_target = target.clone() - vocab_start_index
+ masked_target[target_mask] = 0
+
+ logits_2d = logits.view(-1, partition_vocab_size)
+ masked_target_1d = masked_target.view(-1)
+ arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+ predicted_logits_1d = logits_2d[arange_1d, masked_target_1d].contiguous() # (-1,)
+ predicted_logits = predicted_logits_1d.view_as(target)
+ predicted_logits[target_mask] = 0.0 # if target=-100, it will also be 0
+
+ # All reduce is needed to get the chunks from other GPUs.
+ predicted_logits = all_reduce(predicted_logits.float(), op="sum", comm=comm)
+ predicted_logits = predicted_logits - max_logits
+ # Sum of exponential of logits along vocab dimension across all GPUs.
+
+ sum_exp_logits = torch.empty(logits.size(0), device=logits.device, dtype=torch.float)
+ sum_exp_logits = F.fused_sumexp(logits, max_logits) # float
+ sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) + 1e-10 # avoid nan
+
+ softmax = logits.clone()
+ F.fused_softmax_inplace(softmax, max_logits, sum_exp_logits) # logits -> softmax
+ # logits = logits.float() - max_logits.unsqueeze(dim=-1).float()
+ # exp_logits = logits
+ # torch.exp(logits, out=exp_logits)
+ # sum_exp_logits = exp_logits.sum(dim=-1)
+ # exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+
+ loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits
+
+ # Normalize
+ ctx.save_for_backward(softmax, target_mask, masked_target_1d)
+
+ return loss
+
+ @staticmethod
+ def backward(ctx, grad_output):
+        # Retrieve tensors from the forward path.
+        softmax, target_mask, masked_target_1d = ctx.saved_tensors
+        # All the inputs have softmax as their gradient.
+ grad_input = softmax
+ # For simplicity, work with the 2D gradient.
+ partition_vocab_size = softmax.size()[-1]
+ grad_2d = grad_input.view(-1, partition_vocab_size)
+
+ # Add the gradient from matching classes.
+ arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+
+ softmax_update = 1.0 - target_mask.view(-1).float()
+
+ grad_2d[arange_1d, masked_target_1d] -= softmax_update
+ grad_input.mul_(grad_output.view(*grad_input.shape[:-1]).unsqueeze(dim=-1))
+
+ return grad_input, None
+
+class FusedCrossEntropy(torch.nn.Module):
+ r"""This criterion computes the cross entropy loss between input and target.
+
+ It is useful when training a classification problem with `C` classes.
+ If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
+ assigning weight to each of the classes.
+ This is particularly useful when you have an unbalanced training set.
+
+ The `input` is expected to contain raw, unnormalized scores for each class.
+ `input` has to be a Tensor of size :math:`(minibatch, C)`.
+
+ The `target` that this criterion expects should contain either:
+
+ - Class indices in the range :math:`[0, C-1]` where :math:`C` is the number of classes; if
+ `ignore_index` is specified, this loss also accepts this class index (this index
+ may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction`
+ set to ``'none'``) loss for this case can be described as:
+
+ .. math::
+ \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+ l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
+ \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}
+
+ where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight,
+ :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension. If
+ :attr:`reduction` is not ``'none'`` (default ``'mean'``), then
+
+ .. math::
+ \ell(x, y) = \begin{cases}
+ \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, &
+ \text{if reduction} = \text{`mean';}\\
+ \sum_{n=1}^N l_n, &
+ \text{if reduction} = \text{`sum'.}
+ \end{cases}
+
+ Note that this case is equivalent to the combination of :class:`~torch.nn.LogSoftmax` and
+ :class:`~torch.nn.NLLLoss`.
+
+ - Probabilities for each class; useful when labels beyond a single class per minibatch item
+ are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with
+ :attr:`reduction` set to ``'none'``) loss for this case can be described as:
+
+ .. math::
+ \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+ l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\exp(\sum_{i=1}^C x_{n,i})} y_{n,c}
+
+ where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight,
+ :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension. If
+ :attr:`reduction` is not ``'none'`` (default ``'mean'``), then
+
+ .. math::
+ \ell(x, y) = \begin{cases}
+ \frac{\sum_{n=1}^N l_n}{N}, &
+ \text{if reduction} = \text{`mean';}\\
+ \sum_{n=1}^N l_n, &
+ \text{if reduction} = \text{`sum'.}
+ \end{cases}
+
+ .. note::
+ The performance of this criterion is generally better when `target` contains class
+ indices, as this allows for optimized computation. Consider providing `target` as
+ class probabilities only when a single class label per minibatch item is too restrictive.
+
+ Args:
+ weight (Tensor, optional): a manual rescaling weight given to each class.
+ If given, has to be a Tensor of size `C`
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when :attr:`reduce` is ``False``. Default: ``True``
+ ignore_index (int, optional): Specifies a target value that is ignored
+ and does not contribute to the input gradient. When :attr:`size_average` is
+ ``True``, the loss is averaged over non-ignored targets. Note that
+ :attr:`ignore_index` is only applicable when the target contains class indices.
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
+ be applied, ``'mean'``: the weighted mean of the output is taken,
+ ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in
+ the meantime, specifying either of those two args will override
+ :attr:`reduction`. Default: ``'mean'``
+ label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
+ of smoothing when computing the loss, where 0.0 means no smoothing. The targets
+ become a mixture of the original ground truth and a uniform distribution as described in
+ `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`.
+
+ Shape:
+ - Input: :math:`(N, C)` where `C = number of classes`.
+ - Target: If containing class indices, shape :math:`(N)` where each value is
+ :math:`0 \leq \text{targets}[i] \leq C-1`. If containing class probabilities,
+ same shape as the input.
+ - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)`.
+ Otherwise, scalar.
+
+ Examples::
+
+ >>> # Example of target with class indices
+ >>> loss_func = bmt.loss.FusedCrossEntropy()
+ >>> input = torch.randn(32, 100).half()
+ >>> target = torch.randint(0, 100, (32,)).long()
+ >>> loss = loss_func(input, target)
+ >>> loss.backward()
+ """
+ def __init__(self,
+ weight: Optional[torch.Tensor] = None,
+ ignore_index: int = -100,
+ reduction: str = 'mean',
+ label_smoothing: float = 0.0, # TODO not supported yet
+ parallel: bool = False,
+ ) -> None:
+ super().__init__()
+ self.weight = weight
+ self.ignore_index = ignore_index
+ self.reduction = reduction
+ self.label_smoothing = label_smoothing
+ self.parallel = parallel
+
+ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ if self.parallel:
+ ret = VPFusedCrossEntropy.apply(input, target.long())
+ else:
+ if input.dtype == torch.float32:
+ return torch.nn.functional.cross_entropy(
+ input,
+ target.long(),
+ weight=self.weight,
+ ignore_index=self.ignore_index,
+ reduction=self.reduction,
+ label_smoothing=self.label_smoothing)
+
+ ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor
+
+ if self.weight is not None:
+ if self.weight.dim() != 1 or self.weight.size(0) != input.size(1):
+ raise ValueError("weight should be a 1D tensor of size C");
+ w = self.weight[torch.where(target==self.ignore_index, 0, target)].float()
+ w[target==self.ignore_index] = 0
+ else:
+ w = (target != self.ignore_index).int()
+
+ ret = w * ret
+
+ if self.reduction == "none":
+ return ret
+ elif self.reduction == "sum":
+ return ret.sum()
+ elif self.reduction == "mean":
+ return ret.sum() / w.sum().float()
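+
+# Illustrative sketch of the vocab-parallel path (assumes bmt.init_distributed was
+# called with tp_size > 1 and that `logits` holds this rank's vocabulary shard):
+#   loss_func = bmt.loss.FusedCrossEntropy(parallel=True, reduction="none")
+#   per_token_loss = loss_func(logits, target)   # dispatches to VPFusedCrossEntropy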
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/__init__.py b/examples/BMTrain/bmtrain/lr_scheduler/__init__.py
new file mode 100644
index 00000000..0d9a0596
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/__init__.py
@@ -0,0 +1,6 @@
+from .warmup import WarmupLRScheduler
+from .no_decay import NoDecay
+from .noam import Noam
+from .linear import Linear
+from .cosine import Cosine
+from .exponential import Exponential
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/cosine.py b/examples/BMTrain/bmtrain/lr_scheduler/cosine.py
new file mode 100644
index 00000000..3aed034d
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/cosine.py
@@ -0,0 +1,18 @@
+import math
+from .warmup import WarmupLRScheduler
+
+
+class Cosine(WarmupLRScheduler):
+ r"""
+    After a warmup period during which the learning rate increases linearly from 0 to start_lr,
+    the decay period follows :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}`
+ """
+
+ def get_lr_warmup(self, num_iter) -> float:
+ return self.start_lr * num_iter / self.warmup_iter
+
+ def get_lr_decay(self, num_iter) -> float:
+ progress = (num_iter - self.warmup_iter) / max(
+ 1, (self.end_iter - self.warmup_iter)
+ )
+ return max(0.0, self.start_lr * 0.5 * (1.0 + math.cos(progress * math.pi)))
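+
+# Illustrative usage sketch (optimizer choice and iteration counts are hypothetical):
+#   optimizer = torch.optim.AdamW(model.parameters(), lr=0.0)
+#   lr_scheduler = Cosine(optimizer, start_lr=1e-3, warmup_iter=1_000,
+#                         end_iter=10_000, num_iter=0)
+#   for step in range(10_000):
+#       ...
+#       optimizer.step()
+#       lr_scheduler.step()   # linear warmup for 1k steps, then cosine decay to 0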
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/exponential.py b/examples/BMTrain/bmtrain/lr_scheduler/exponential.py
new file mode 100644
index 00000000..6cf3240e
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/exponential.py
@@ -0,0 +1,20 @@
+from .warmup import WarmupLRScheduler
+
+
+class Exponential(WarmupLRScheduler):
+ r"""
+    After a warmup period during which the learning rate increases linearly from 0 to start_lr,
+    the decay period follows :math:`\text{lr}=\text{start_lr}\times \gamma ^ {\left(\text{num_iter}-\text{warmup_iter}\right)}`
+ """
+
+ def __init__(
+ self, optimizer, start_lr, warmup_iter, end_iter, num_iter, gamma=0.95
+ ) -> None:
+ super().__init__(optimizer, start_lr, warmup_iter, end_iter, num_iter)
+ self.gamma = gamma
+
+ def get_lr_warmup(self, num_iter) -> float:
+ return self.start_lr * num_iter / self.warmup_iter
+
+ def get_lr_decay(self, num_iter) -> float:
+ return max(0.0, self.start_lr * self.gamma ** (num_iter - self.warmup_iter))
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/linear.py b/examples/BMTrain/bmtrain/lr_scheduler/linear.py
new file mode 100644
index 00000000..af193dd8
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/linear.py
@@ -0,0 +1,19 @@
+from .warmup import WarmupLRScheduler
+
+
+class Linear(WarmupLRScheduler):
+ r"""
+    After a warmup period during which the learning rate increases linearly from 0 to start_lr,
+    the decay period follows :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{end_iter}-\text{num_iter}}{\text{end_iter}-\text{warmup_iter}}`
+ """
+
+ def get_lr_warmup(self, num_iter) -> float:
+ return self.start_lr * num_iter / self.warmup_iter
+
+ def get_lr_decay(self, num_iter) -> float:
+ return max(
+ 0.0,
+ self.start_lr
+ * (self.end_iter - num_iter)
+ / (self.end_iter - self.warmup_iter),
+ )
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/no_decay.py b/examples/BMTrain/bmtrain/lr_scheduler/no_decay.py
new file mode 100644
index 00000000..6f85bf0a
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/no_decay.py
@@ -0,0 +1,14 @@
+from .warmup import WarmupLRScheduler
+
+
+class NoDecay(WarmupLRScheduler):
+ r"""
+    After a warmup period during which the learning rate increases linearly from 0 to start_lr,
+    the learning rate is then held constant at :math:`\text{lr}=\text{start_lr}`
+ """
+
+ def get_lr_warmup(self, num_iter) -> float:
+ return self.start_lr * num_iter / self.warmup_iter
+
+ def get_lr_decay(self, num_iter) -> float:
+ return self.start_lr
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/noam.py b/examples/BMTrain/bmtrain/lr_scheduler/noam.py
new file mode 100644
index 00000000..8954a64d
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/noam.py
@@ -0,0 +1,15 @@
+import math
+from .warmup import WarmupLRScheduler
+
+
+class Noam(WarmupLRScheduler):
+ r"""
+    During the warmup period the learning rate follows :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{num_iter}}{\text{warmup_iter}^{3/2}}`,
+    and the decay period follows :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{1}}{\sqrt{\text{num_iter}}}`
+ """
+
+ def get_lr_warmup(self, num_iter) -> float:
+ return self.start_lr / math.sqrt(self.warmup_iter) * num_iter / self.warmup_iter
+
+ def get_lr_decay(self, num_iter) -> float:
+ return self.start_lr / math.sqrt(num_iter)
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/warmup.py b/examples/BMTrain/bmtrain/lr_scheduler/warmup.py
new file mode 100644
index 00000000..1f9ccc8e
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/warmup.py
@@ -0,0 +1,72 @@
+import torch
+
+
+class WarmupLRScheduler:
+ r"""Base class for learning rate schedulers with warmup.
+
+ Args:
+ optimizer (torch.optim.Optimizer): optimizer used for training
+ start_lr (float): starting learning rate
+ warmup_iter (int): number of iterations to linearly increase learning rate
+ end_iter (int): number of iterations to stop training
+ num_iter (int): current iteration number
+ """
+
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ start_lr,
+ warmup_iter,
+ end_iter,
+ num_iter=0,
+ ) -> None:
+ self.start_lr = start_lr
+ self.warmup_iter = warmup_iter
+ self.end_iter = end_iter
+ self.optimizer = optimizer
+ self.num_iter = num_iter
+ self._current_lr = None
+
+ self.step(self.num_iter)
+
+ def get_lr_warmup(self, num_iter) -> float: ...
+
+ def get_lr_decay(self, num_iter) -> float: ...
+
+ def get_lr(self):
+ assert self.num_iter >= 0
+
+ if self.num_iter < self.warmup_iter:
+ return self.get_lr_warmup(self.num_iter)
+ else:
+ return self.get_lr_decay(self.num_iter)
+
+ @property
+ def current_lr(self):
+ return self._current_lr
+
+ def step(self, num_iter=None) -> None:
+ if num_iter is None:
+ num_iter = self.num_iter + 1
+ self.num_iter = num_iter
+
+ lr = self.get_lr()
+ self._current_lr = lr
+ for group in self.optimizer.param_groups:
+ group["lr"] = lr
+
+ def state_dict(self):
+ return {
+ "start_lr": self.start_lr,
+ "warmup_iter": self.warmup_iter,
+ "end_iter": self.end_iter,
+ "num_iter": self.num_iter,
+ }
+
+ def load_state_dict(self, state_dict):
+ self.start_lr = state_dict["start_lr"]
+ self.warmup_iter = state_dict["warmup_iter"]
+ self.end_iter = state_dict["end_iter"]
+ self.num_iter = state_dict["num_iter"]
+
+ self.step(self.num_iter)
diff --git a/examples/BMTrain/bmtrain/nccl/__init__.py b/examples/BMTrain/bmtrain/nccl/__init__.py
new file mode 100644
index 00000000..0f4129d5
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nccl/__init__.py
@@ -0,0 +1,336 @@
+
+from typing_extensions import Literal
+import torch
+from .. import C
+from .enums import *
+
+class NCCLCommunicator:
+ """
+ NCCL communicator stores the communicator handle.
+ """
+
+ def __init__(self, ptr) -> None:
+ self.__ptr = ptr
+
+ @property
+ def ptr(self):
+ """
+ Returns the communicator handle.
+ """
+ if self.__ptr == -1:
+ raise RuntimeError("NCCL Communicator is already destroyed")
+ return self.__ptr
+
+ def _destroy_ptr(self):
+ self.__ptr = -1
+
+# utils
+
+def dtype2nccl(dtype : torch.dtype) -> int:
+ MAP = {
+ torch.int8: ncclInt8,
+ torch.uint8 : ncclUint8,
+ torch.int32 : ncclInt32,
+ torch.int : ncclInt32,
+ torch.int64 : ncclInt64,
+ torch.float16 : ncclFloat16,
+ torch.half : ncclHalf,
+ torch.bfloat16 : ncclBFloat16,
+ torch.float32 : ncclFloat32,
+ torch.float : ncclFloat,
+ torch.float64 : ncclFloat64,
+ torch.double : ncclDouble,
+ torch.bool : ncclBool
+ }
+ if dtype not in MAP:
+        raise TypeError("Unsupported dtype %s" % dtype)
+ return MAP[dtype]
+
+def op2nccl(
+ op : Literal["sum", "prod", "max", "min", "avg"]
+):
+ if op == "sum":
+ return ncclSum
+ if op == "prod":
+ return ncclProd
+ if op == "max":
+ return ncclMax
+ if op == "min":
+ return ncclMin
+ if op == "avg":
+ return ncclAvg
+ raise ValueError("Unknown gather op %s")
+
+# wrappers
+
+def getUniqueId() -> bytes:
+ """
+ NCCL API: `ncclGetUniqueId `_
+
+ """
+ return C.ncclGetUniqueId()
+
+def commInitRank(unique_id : bytes, world_size : int, rank : int) -> NCCLCommunicator:
+ """
+ NCCL API: `ncclCommInitRank `_
+
+ """
+ assert rank >= 0 and rank < world_size, "rank must be between 0 and world_size-1"
+ return NCCLCommunicator(C.ncclCommInitRank(unique_id, world_size, rank))
+
+def commDestroy(comm : NCCLCommunicator):
+ """
+ NCCL API: `ncclCommDestroy `_
+
+ """
+ C.ncclCommDestroy(comm.ptr)
+    comm._destroy_ptr()
+
+def commCount(comm : NCCLCommunicator):
+    """NCCL API: `ncclCommCount `_
+
+    Args:
+        comm (NCCLCommunicator): NCCL communicator.
+    """
+    return C.ncclCommCount(comm.ptr)
+
+def commRank(comm : NCCLCommunicator):
+    """NCCL API: `ncclCommUserRank `_
+
+    Args:
+        comm (NCCLCommunicator): NCCL communicator.
+    """
+    return C.ncclCommUserRank(comm.ptr)
+
+### collective
+
+def allReduce(
+ src : torch.storage._StorageBase,
+ dst : torch.storage._StorageBase,
+ op : Literal["sum", "prod", "max", "min", "avg"],
+ comm : NCCLCommunicator
+ ):
+ """NCCL API: `ncclAllReduce `_
+
+ Args:
+ src (torch.storage._StorageBase): Source buffer.
+ dst (torch.storage._StorageBase): Destination buffer.
+ op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation.
+ comm (NCCLCommunicator): NCCL communicator.
+
+ The src and dst buffers must be the same size, type and on the same device.
+
+ If src == dst, the operation is performed in-place.
+
+ """
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
+ assert src.is_cuda and dst.is_cuda
+
+ sendbuff = src.data_ptr()
+ recvbuff = dst.data_ptr()
+ count = src.size()
+ datatype = dtype2nccl(src.dtype)
+ operator = op2nccl(op)
+
+ assert src.size() == dst.size(), "Buffer size not aligned"
+ C.ncclAllReduce(
+ sendbuff,
+ recvbuff,
+ count,
+ datatype,
+ operator,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
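+
+# Illustrative usage sketch: in-place sum reduction of a CUDA tensor across the
+# communicator (src == dst takes the in-place path described above; `comm` is an
+# already-initialized NCCLCommunicator):
+#   x = torch.ones(4, device="cuda")
+#   allReduce(x.storage(), x.storage(), "sum", comm)
+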
+def send(src : torch.storage._StorageBase,
+ peer : int,
+ comm : NCCLCommunicator
+ ):
+ """NCCL API: `ncclsend `_
+
+ Args:
+ src (torch.storage._StorageBase): Source buffer.
+ peer (int): rank peer needs to call ncclRecv
+ comm (NCCLCommunicator): NCCL communicator.
+ """
+
+ sendbuff = src.data_ptr()
+ count = src.size()
+ datatype = dtype2nccl(src.dtype)
+ C.ncclSend(
+ sendbuff,
+ count,
+ datatype,
+ peer,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
+def recv(dst : torch.storage._StorageBase,
+ peer : int,
+ comm : NCCLCommunicator
+ ):
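+    """NCCL API: ncclRecv
+
+    Args:
+        dst (torch.storage._StorageBase): Destination buffer.
+        peer (int): Rank of the peer, which must call `send` to match this receive.
+        comm (NCCLCommunicator): NCCL communicator.
+    """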
+ recvbuff = dst.data_ptr()
+ count = dst.size()
+ datatype = dtype2nccl(dst.dtype)
+ C.ncclRecv(
+ recvbuff,
+ count,
+ datatype,
+ peer,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
+
+def broadcast(
+ src : torch.storage._StorageBase,
+ dst : torch.storage._StorageBase,
+ root : int,
+ comm : NCCLCommunicator
+ ):
+ """NCCL API: `ncclBroadcast `_
+
+ Args:
+ src (torch.storage._StorageBase): Source buffer.
+ dst (torch.storage._StorageBase): Destination buffer.
+ root (int): Rank of the root.
+ comm (NCCLCommunicator): NCCL communicator.
+
+ The src and dst buffers must be the same size, type and on the same device.
+
+ If src == dst, the operation is performed in-place.
+
+ """
+
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
+ assert src.is_cuda and dst.is_cuda
+
+ sendbuff = src.data_ptr()
+ recvbuff = dst.data_ptr()
+ count = src.size()
+ datatype = dtype2nccl(src.dtype)
+
+ assert dst.size() == src.size(), "Buffer size not aligned"
+ C.ncclBroadcast(
+ sendbuff,
+ recvbuff,
+ count,
+ datatype,
+ root,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
+
+def reduce(
+ src : torch.storage._StorageBase,
+ dst : torch.storage._StorageBase,
+ op : Literal["sum", "prod", "max", "min", "avg"],
+ root : int,
+ comm : NCCLCommunicator
+ ):
+ """NCCL API: `ncclReduce `_
+
+ Args:
+ src (torch.storage._StorageBase): Source buffer.
+ dst (torch.storage._StorageBase): Destination buffer.
+ op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation.
+ root (int): Rank of the root.
+ comm (NCCLCommunicator): NCCL communicator.
+
+ The src and dst buffers must be the same size, type and on the same device.
+
+ If src == dst, the operation is performed in-place.
+
+ """
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
+ assert src.is_cuda and dst.is_cuda
+
+ sendbuff = src.data_ptr()
+ recvbuff = dst.data_ptr()
+ count = src.size()
+ datatype = dtype2nccl(src.dtype)
+ operator = op2nccl(op)
+
+ assert dst.size() == src.size(), "Buffer size not aligned"
+ C.ncclReduce(sendbuff, recvbuff, count, datatype, operator, root, comm.ptr, torch.cuda.current_stream().cuda_stream)
+
+def allGather(
+ src : torch.storage._StorageBase,
+ dst : torch.storage._StorageBase,
+ comm : NCCLCommunicator
+ ):
+ """NCCL API: `ncclAllGather `_
+
+ Args:
+ src (torch.storage._StorageBase): Source buffer.
+ dst (torch.storage._StorageBase): Destination buffer.
+ comm (NCCLCommunicator): NCCL communicator.
+
+ The size of the dst buffer must be equal to the size of src buffer * world_size.
+
+    Every rank receives the full gathered result in the dst buffer.
+
+ """
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
+ assert src.is_cuda and dst.is_cuda
+
+ sendbuff = src.data_ptr()
+ recvbuff = dst.data_ptr()
+ sendcount = src.size()
+ datatype = dtype2nccl(src.dtype)
+ assert dst.size() % sendcount == 0, "Buffer size not aligned"
+ C.ncclAllGather(
+ sendbuff,
+ recvbuff,
+ sendcount,
+ datatype,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
+
+
+def reduceScatter(
+ src : torch.storage._StorageBase,
+ dst : torch.storage._StorageBase,
+ op : Literal["sum", "prod", "max", "min", "avg"],
+ comm : NCCLCommunicator
+ ):
+ """NCCL API: `ncclReduceScatter `_
+
+ Args:
+ src (torch.storage._StorageBase): Source buffer.
+ dst (torch.storage._StorageBase): Destination buffer.
+ op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation.
+ comm (NCCLCommunicator): NCCL communicator.
+
+ The size of the dst buffer must be equal to the size of src buffer / world_size.
+
+    The dst buffer on rank `i` will contain the i-th block of the reduced result.
+
+ """
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
+ assert src.is_cuda and dst.is_cuda
+
+ sendbuff = src.data_ptr()
+ recvbuff = dst.data_ptr()
+ recvcount = dst.size()
+ datatype = dtype2nccl(src.dtype)
+ operator = op2nccl(op)
+
+ assert src.size() % recvcount == 0, "Buffer size not aligned"
+ C.ncclReduceScatter(
+ sendbuff,
+ recvbuff,
+ recvcount,
+ datatype,
+ operator,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
+
+def groupStart():
+ """
+ NCCL API: `ncclGroupStart `_
+ """
+ C.ncclGroupStart()
+
+def groupEnd():
+ """
+ NCCL API: `ncclGroupEnd `_
+ """
+ C.ncclGroupEnd()
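+
+# Illustrative usage sketch: groupStart()/groupEnd() fuse several point-to-point or
+# collective calls into one NCCL group so they can progress concurrently without
+# deadlocking, e.g. a simultaneous send and recv between two ranks (buffers, peer
+# and comm are placeholders):
+#   groupStart()
+#   send(out_buf.storage(), peer, comm)
+#   recv(in_buf.storage(), peer, comm)
+#   groupEnd()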
diff --git a/examples/BMTrain/bmtrain/nccl/enums.py b/examples/BMTrain/bmtrain/nccl/enums.py
new file mode 100644
index 00000000..67411f0e
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nccl/enums.py
@@ -0,0 +1,27 @@
+
+### ncclDataType_t
+
+ncclInt8 = 0
+ncclChar = 0
+ncclBool = 0
+ncclUint8 = 1
+ncclInt32 = 2
+ncclInt = 2
+ncclUint32 = 3
+ncclInt64 = 4
+ncclUint64 = 5
+ncclFloat16 = 6
+ncclHalf = 6
+ncclFloat32 = 7
+ncclFloat = 7
+ncclFloat64 = 8
+ncclDouble = 8
+ncclBFloat16 = 9
+
+### ncclRedOp_t
+
+ncclSum = 0
+ncclProd = 1
+ncclMax = 2
+ncclMin = 3
+ncclAvg = 4
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/nn/__init__.py b/examples/BMTrain/bmtrain/nn/__init__.py
new file mode 100644
index 00000000..60fed663
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/__init__.py
@@ -0,0 +1,5 @@
+from .linear import Linear, OpLinear
+from .column_parallel_linear import ColumnParallelLinear
+from .row_parallel_linear import RowParallelLinear
+from .parallel_embedding import VPEmbedding
+from .parallel_linear_func import OpParallelLinear
diff --git a/examples/BMTrain/bmtrain/nn/column_parallel_linear.py b/examples/BMTrain/bmtrain/nn/column_parallel_linear.py
new file mode 100644
index 00000000..e1ede115
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/column_parallel_linear.py
@@ -0,0 +1,80 @@
+import torch
+from torch.nn.parameter import Parameter
+
+import bmtrain as bmt
+from bmtrain.global_var import config
+from .parallel_linear_func import OpParallelLinear, ReduceType
+
+
+class ColumnParallelLinear(bmt.DistributedModule):
+ """Tensor Parallel use cloumn partition for Linear.
+
+ Args:
+ in_features (int): in_features size.
+ out_features (int): out_features size.
+ bias (bool): whether use bias.
+ dtype : data type.
+ gather_ouput (bool): whether gather output after compute.
+ gather_input (bool): whether gather input before compute.
+ async_gather_chunks (int): chunk size for async gathering data.
+
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ dtype=None,
+ gather_output=False,
+ gather_input=True,
+ async_gather_chunks=2,
+ ) -> None:
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.gather_output = gather_output
+ self.gather_input = gather_input
+ self.async_gather_chunks = async_gather_chunks
+ tp_size = config["tp_size"]
+ assert out_features % tp_size == 0
+ self.out_features_per_partition = out_features // tp_size
+ self.weight = bmt.DistributedParameter(
+ torch.empty(
+ self.out_features_per_partition, in_features, dtype=dtype, device="cuda"
+ ),
+ init_method=torch.nn.init.xavier_normal_,
+ tp_split_dim=0,
+ tp_mode=True,
+ )
+ if bias:
+ self.bias = bmt.DistributedParameter(
+ torch.empty(
+ self.out_features_per_partition, dtype=dtype, device="cuda"
+ ),
+ init_method=torch.nn.init.zeros_,
+ tp_split_dim=0,
+ tp_mode=True,
+ )
+ else:
+ self.register_parameter("bias", None)
+
+ def forward(self, input):
+ gather_input = self.gather_input
+ split_input = False
+ reduce_output_type = None
+ return OpParallelLinear.apply(
+ input,
+ self.weight,
+ self.bias,
+ gather_input,
+ self.gather_output,
+ split_input,
+ reduce_output_type,
+ self.async_gather_chunks,
+ )
+
+ def extra_repr(self) -> str:
+ return "in_features={}, out_features={}, bias={}".format(
+            self.in_features, self.out_features_per_partition, self.bias is not None
+ )
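
A short usage sketch, assuming `bmt.init_distributed()` was called with a tensor-parallel group; with the default `gather_output=False` each rank produces only its slice of the output features, while `gather_output=True` reassembles the full output dimension:

    import torch
    import bmtrain as bmt

    col = bmt.nn.ColumnParallelLinear(1024, 4096, dtype=torch.half, gather_output=False)
    x = torch.randn(2, 128, 1024, dtype=torch.half, device="cuda")
    # The output feature dimension is 4096 // tp_size on each rank; the leading
    # dimensions depend on how activations are sharded across the tensor-parallel group.
    y = col(x)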
diff --git a/examples/BMTrain/bmtrain/nn/linear.py b/examples/BMTrain/bmtrain/nn/linear.py
new file mode 100644
index 00000000..8afb1d89
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/linear.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn.functional as F
+import bmtrain as bmt
+
+
+class OpLinear(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, weight, bias=None):
+ ctx.save_for_backward(x, weight, bias)
+ return F.linear(x, weight, bias)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x, weight, bias = ctx.saved_tensors
+ grad_x = grad_weight = grad_bias = None
+ if x.requires_grad:
+ grad_x = grad_output.matmul(weight)
+ if weight.requires_grad:
+ dim = grad_output.dim()
+ grad_weight = (
+ grad_output.reshape(-1, grad_output.shape[-1])
+ .t()
+ .matmul(x.reshape(-1, x.shape[-1]))
+ )
+ if bias is not None and bias.requires_grad:
+ grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
+ return grad_x, grad_weight, grad_bias
+
+
+class Linear(bmt.DistributedModule):
+ def __init__(
+ self, in_features: int, out_features: int, bias: bool = True, dtype=None
+ ) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = bmt.DistributedParameter(
+ torch.empty(out_features, in_features, dtype=dtype, device="cuda"),
+ init_method=torch.nn.init.xavier_normal_,
+ )
+ if bias:
+ self.bias = bmt.DistributedParameter(
+ torch.empty(out_features, dtype=dtype, device="cuda"),
+ init_method=torch.nn.init.zeros_,
+ )
+ else:
+ self.register_parameter("bias", None)
+
+ def forward(self, input):
+ return OpLinear.apply(input, self.weight, self.bias)
+
+ def extra_repr(self) -> str:
+ return "in_features={}, out_features={}, bias={}".format(
+ self.in_features, self.out_features, self.bias is not None
+ )
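
`OpLinear` mirrors `F.linear` in the forward pass and computes the input, weight, and bias gradients explicitly in the backward pass. A minimal sanity sketch, assuming `bmt.init_distributed()` has been called so `DistributedParameter` can be created:

    import torch
    import bmtrain as bmt

    layer = bmt.nn.Linear(16, 8, dtype=torch.half)
    x = torch.randn(4, 16, dtype=torch.half, device="cuda", requires_grad=True)
    y = layer(x)          # forward goes through OpLinear.apply
    y.sum().backward()    # backward fills x.grad and the distributed weight/bias grads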
diff --git a/examples/BMTrain/bmtrain/nn/parallel_embedding.py b/examples/BMTrain/bmtrain/nn/parallel_embedding.py
new file mode 100644
index 00000000..3bdc4e56
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/parallel_embedding.py
@@ -0,0 +1,59 @@
+import torch
+from torch.nn.parameter import Parameter
+import torch.nn.functional as F
+import math
+
+import bmtrain as bmt
+from bmtrain.global_var import config
+from bmtrain.distributed import all_reduce, all_gather
+from .parallel_linear_func import OpParallelLinear
+
+
+class VPEmbedding(bmt.DistributedModule):
+ """Vocab Parallel Embedding.
+
+ Args:
+        vocab_size (int): vocabulary size; partitioned across tensor-parallel ranks.
+        embedding_size (int): embedding dimension.
+        dtype (torch.dtype): data type of the weight.
+        init_mean (float, optional): mean of the normal weight initialization.
+        init_std (float, optional): std of the normal weight initialization.
+
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ embedding_size: int,
+ dtype: torch.dtype = torch.half,
+ init_mean: float = 0.0,
+ init_std: float = 1,
+ ):
+ super().__init__()
+
+ self.dim_model = embedding_size
+ assert vocab_size % bmt.config["tp_size"] == 0
+ self.vocab_size_per_partition = vocab_size // bmt.config["tp_size"]
+ self.start_index = bmt.config["tp_rank"] * self.vocab_size_per_partition
+ self.end_index = (bmt.config["tp_rank"] + 1) * self.vocab_size_per_partition
+ self.weight = bmt.DistributedParameter(
+ torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype),
+ init_method=bmt.ParameterInitializer(
+ torch.nn.init.normal_, mean=init_mean, std=init_std
+ ),
+ tp_split_dim=0,
+ tp_mode=True,
+ )
+
+ def forward(self, x: torch.Tensor, projection=False):
+ if not projection:
+ weight = all_gather(self.weight, comm=config["tp_comm"]).flatten(0, 1)
+ out = F.embedding(x, weight)
+ return out
+ else:
+ x = bmt.distributed.all_gather(x, comm=bmt.config["tp_comm"]).view(
+ x.shape[0], -1, x.shape[-1]
+ )
+ return bmt.nn.OpParallelLinear.apply(
+ x, self.weight, None, False, False, False, None, 1
+ )
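
`VPEmbedding` plays two roles: a vocabulary-sharded lookup table, and (with `projection=True`) a tied output projection onto the same weight. A sketch of both modes, assuming a tensor-parallel setup; the projected logits stay sharded along the vocabulary dimension:

    import torch
    import bmtrain as bmt

    emb = bmt.nn.VPEmbedding(vocab_size=32000, embedding_size=1024)
    token_ids = torch.randint(0, 32000, (2, 128), device="cuda")
    hidden = emb(token_ids)                 # lookup against the gathered weight
    logits = emb(hidden, projection=True)   # tied head: hidden @ weight.T, sharded over vocab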
diff --git a/examples/BMTrain/bmtrain/nn/parallel_linear_func.py b/examples/BMTrain/bmtrain/nn/parallel_linear_func.py
new file mode 100644
index 00000000..e389cde6
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/parallel_linear_func.py
@@ -0,0 +1,352 @@
+import torch
+import torch.nn.functional as F
+from bmtrain.global_var import config
+from ..distributed import all_gather, all_reduce
+from .. import nccl
+import bmtrain as bmt
+from enum import Enum
+
+
+class ReduceType(Enum):
+ ALL_REDUCE = 1
+ REDUCE_SCATTER = 2
+
+
+def preprocess_input(input, gather_input, split_input):
+ if gather_input:
+ input = all_gather(input, config["tp_comm"])
+ input = input.flatten(0, 1)
+
+ if split_input:
+ all_input_list = input.chunk(config["tp_size"], dim=-1)
+ input = all_input_list[config["topology"].tp_id]
+ return input
+
+
+def async_all_gather_linear_func(input, weight, bias, async_chunks=2):
+ dim = input.dim()
+ shape = list(input.shape)
+ if dim > 2:
+ input = input.view(-1, input.shape[-1])
+ tp_size = config["tp_size"]
+ current_stream = torch.cuda.current_stream()
+ comm_stream = config["tp_comm_stream"]
+
+ rounds = async_chunks
+ inputs = input.chunk(rounds, dim=0)
+ comm_stream.wait_stream(current_stream)
+ outputs = [None] * tp_size * rounds
+
+ input = all_gather(inputs[0], config["tp_comm"])
+ input = input.flatten(0, 1)
+ out = F.linear(input, weight, bias)
+ outs = out.chunk(tp_size, dim=0)
+ for i in range(tp_size):
+ outputs[i * rounds] = outs[i]
+
+    # async all_gather, overlapped with the linear compute
+ for i in range(rounds - 1):
+ with torch.cuda.stream(comm_stream):
+ inputs[i + 1].record_stream(comm_stream)
+ input = all_gather(inputs[i + 1], config["tp_comm"])
+ input = input.flatten(0, 1)
+
+ current_stream.wait_stream(comm_stream)
+ out = F.linear(input, weight, bias)
+ outs = out.chunk(tp_size, dim=0)
+ for j in range(tp_size):
+ outputs[(i + 1) + j * rounds] = outs[j]
+
+ out = torch.cat(outputs, dim=0)
+ if dim > 2:
+ out_shape = list(out.shape)
+ shape[-1] = out_shape[-1]
+ shape[0] = shape[0] * tp_size
+ out = out.view(shape)
+ return out
+
+
+def async_reduce_scatter_linear_func(input, weight, bias, async_chunks=2):
+ tp_size = config["tp_size"]
+ comm_stream = config["tp_comm_stream"]
+ rounds = async_chunks
+ input_shape = list(input.shape)
+ dim = input.dim()
+ if dim > 2:
+ input = input.view(-1, input.shape[-1])
+ inputs = input.chunk(rounds * tp_size, dim=0)
+ current_stream = torch.cuda.current_stream()
+
+ outputs = [None] * rounds
+ for i in range(rounds):
+ input = [None] * tp_size
+ for j in range(tp_size):
+ input[j] = inputs[j * rounds + i]
+ input = torch.cat(input, dim=0)
+ out = F.linear(input, weight, bias)
+ with torch.cuda.stream(comm_stream):
+ comm_stream.wait_stream(current_stream)
+ out.record_stream(comm_stream)
+ shape = list(out.shape)
+ shape[0] = shape[0] // config["tp_size"]
+ outputs[i] = torch.empty(shape, dtype=out.dtype, device=out.device)
+ nccl.reduceScatter(
+ out.storage(), outputs[i].storage(), "sum", config["tp_comm"]
+ )
+
+ current_stream.wait_stream(comm_stream)
+ out = torch.cat(outputs, dim=0)
+ if dim > 2:
+ out_shape = list(out.shape)
+ input_shape[-1] = out_shape[-1]
+ input_shape[0] = input_shape[0] // tp_size
+ out = out.view(input_shape)
+
+ return out
+
+
+def async_all_gather_linear_backward_func(
+ grad_out, input, weight, bias, async_chunks=2
+):
+ tp_size = config["tp_size"]
+ current_stream = torch.cuda.current_stream()
+ comm_stream = config["tp_comm_stream"]
+ input_require_grad = input.requires_grad
+ dim = input.dim()
+ input_shape = input.shape
+ if dim > 2:
+ input = input.view(-1, input_shape[-1])
+ grad_out = grad_out.view(-1, grad_out.shape[-1])
+
+ rounds = async_chunks
+ grad_inputs = [None] * tp_size * rounds
+ grad_weights = [None] * tp_size * rounds
+ grad_outs = [None] * tp_size * rounds
+ local_grad_outs = grad_out.chunk(rounds, dim=0)
+
+ inputs = [None] * rounds
+ comm_stream.wait_stream(current_stream)
+ if weight.requires_grad:
+ with torch.cuda.stream(comm_stream):
+ input.record_stream(comm_stream)
+ input_list = [None] * tp_size * rounds
+ tp_inputs = input.chunk(tp_size, dim=0)
+ for i in range(tp_size):
+ chunk_inputs = tp_inputs[i].chunk(rounds, dim=0)
+ for j in range(rounds):
+ input_list[j * tp_size + i] = chunk_inputs[j]
+ start = 0
+ end = tp_size
+ for i in range(rounds):
+ inputs[i] = torch.cat(input_list[start:end], dim=0)
+ start = end
+ end += tp_size
+
+ grad_input = grad_weight = grad_bias = None
+
+ grad_out = all_gather(local_grad_outs[0], config["tp_comm"])
+ for j in range(tp_size):
+ grad_outs[j * rounds] = grad_out[j]
+ grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n)
+ if input_require_grad:
+ grad_input = grad_out.matmul(
+ weight
+ ) # (tp_size * (m/rounds), n) * (n, k/tp_size)
+ tmp_grad_inputs = grad_input.chunk(tp_size, dim=0)
+ for j in range(tp_size):
+ grad_inputs[j * rounds] = tmp_grad_inputs[j]
+
+ if weight.requires_grad:
+ grad_weight = (
+ grad_out.reshape(-1, grad_out.shape[-1])
+ .t()
+ .matmul(inputs[0].reshape(-1, inputs[0].shape[-1]))
+ )
+
+    # async all_gather, overlapped with the matmul
+ for i in range(rounds - 1):
+ with torch.cuda.stream(comm_stream):
+ local_grad_outs[i + 1].record_stream(comm_stream)
+ grad_out = all_gather(local_grad_outs[i + 1], config["tp_comm"])
+ for j in range(tp_size):
+ grad_outs[j * rounds + i + 1] = grad_out[j]
+ grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n)
+
+ current_stream.wait_stream(comm_stream)
+ if input_require_grad:
+ grad_input = grad_out.matmul(
+ weight
+ ) # (tp_size * (m/rounds), n) * (n, k/tp_size)
+ tmp_grad_inputs = grad_input.chunk(tp_size, dim=0)
+ for j in range(tp_size):
+ grad_inputs[j * rounds + i + 1] = tmp_grad_inputs[j]
+
+ if weight.requires_grad:
+ dim = grad_out.dim()
+ grad_weight += (
+ grad_out.reshape(-1, grad_out.shape[-1])
+ .t()
+ .matmul(inputs[i + 1].reshape(-1, inputs[i + 1].shape[-1]))
+ )
+
+ if input_require_grad:
+ grad_input = torch.cat(grad_inputs, dim=0)
+ grad_input = grad_input.view(input_shape)
+
+ if bias is not None and bias.requires_grad:
+ grad_out = torch.cat(grad_outs, dim=0)
+ grad_bias = grad_out.reshape(-1, grad_out.shape[-1]).sum(0)
+
+ return grad_input, grad_weight, grad_bias
+
+
+class OpParallelLinear(torch.autograd.Function):
+ """OpParallelLinear is a subclass of torch.autograd.Function.
+ It gathers the input tensor when needed, and all reduce or reduece scatter the output when needed.
+
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ input,
+ weight,
+ bias=None,
+ gather_input=False,
+ gather_output=False,
+ split_input=False,
+ reduce_output_type=None,
+ async_gather_chunks=2,
+ ):
+ if reduce_output_type is not None:
+ reduce_output_type = ReduceType(reduce_output_type)
+
+ ctx.save_for_backward(input, weight, bias)
+ ctx.gather_output = gather_output
+ ctx.split_input = split_input
+ ctx.gather_input = gather_input
+ ctx.reduce_output_type = reduce_output_type
+ ctx.async_gather_chunks = async_gather_chunks
+
+ if (
+ gather_input
+ and config["tp_size"] > 1
+ and async_gather_chunks > 1
+ and split_input == False
+ ):
+ out = async_all_gather_linear_func(input, weight, bias, async_gather_chunks)
+ elif reduce_output_type == ReduceType.REDUCE_SCATTER:
+ return async_reduce_scatter_linear_func(
+ input, weight, bias, async_gather_chunks
+ )
+ else:
+ all_input = preprocess_input(input, ctx.gather_input, ctx.split_input)
+ out = F.linear(all_input, weight, bias)
+
+ if gather_output:
+ all_output_list = all_gather(out, config["tp_comm"])
+ all_output_list = all_output_list.chunk(config["tp_size"], dim=0)
+ out = torch.cat(all_output_list, dim=all_output_list[0].dim() - 1).flatten(
+ 0, 1
+ )
+
+ if reduce_output_type is None:
+ return out
+
+ if reduce_output_type == ReduceType.ALL_REDUCE:
+ nccl.allReduce(out.storage(), out.storage(), "sum", config["tp_comm"])
+ return out
+ else:
+ assert False, "no support reduce type{}".format(reduce_output_type)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight, bias = ctx.saved_tensors
+ gather_output = ctx.gather_output
+
+ if ctx.reduce_output_type == ReduceType.REDUCE_SCATTER:
+ if input.requires_grad or weight.requires_grad:
+ grad_input, grad_weight, grad_bias = (
+ async_all_gather_linear_backward_func(
+ grad_output, input, weight, bias, ctx.async_gather_chunks
+ )
+ )
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None
+ else:
+ grad_output = all_gather(grad_output, config["tp_comm"])
+ grad_output = grad_output.flatten(0, 1)
+
+ if gather_output:
+ tp_size = config["tp_size"]
+ tp_id = config["topology"].tp_id
+ grad_output_list = grad_output.chunk(tp_size, dim=-1)
+ grad_output = grad_output_list[tp_id]
+
+ grad_input = grad_weight = grad_bias = None
+
+ current_stream = torch.cuda.current_stream()
+ if input.requires_grad or weight.requires_grad:
+ if ctx.gather_input:
+ # async the all_gather
+ with torch.cuda.stream(config["tp_comm_stream"]):
+ input.record_stream(config["tp_comm_stream"])
+ config["tp_comm_stream"].wait_stream(current_stream)
+ all_input = preprocess_input(
+ input, ctx.gather_input, ctx.split_input
+ )
+ # use event to solve two streams waiting for each other
+ gather_event = config["tp_comm_stream"].record_event()
+ else:
+ all_input = preprocess_input(input, ctx.gather_input, ctx.split_input)
+
+ if input.requires_grad:
+ grad_all_input = grad_output.matmul(weight)
+ grad_input = torch.zeros_like(input)
+ if ctx.gather_input:
+ # async the reduce_scatter
+ with torch.cuda.stream(config["tp_comm_stream"]):
+ config["tp_comm_stream"].wait_stream(current_stream)
+ grad_input.record_stream(config["tp_comm_stream"])
+ grad_all_input.record_stream(config["tp_comm_stream"])
+ nccl.reduceScatter(
+ grad_all_input.storage(),
+ grad_input.storage(),
+ "sum",
+ config["tp_comm"],
+ )
+ elif ctx.reduce_output_type is None:
+ with torch.cuda.stream(config["tp_comm_stream"]):
+ config["tp_comm_stream"].wait_stream(current_stream)
+ grad_input.record_stream(config["tp_comm_stream"])
+ nccl.allReduce(
+ grad_all_input.storage(),
+ grad_all_input.storage(),
+ "sum",
+ config["tp_comm"],
+ )
+ grad_input = grad_all_input
+ else:
+ grad_input = grad_all_input
+
+ if ctx.split_input:
+ with torch.cuda.stream(config["tp_comm_stream"]):
+ config["tp_comm_stream"].wait_stream(current_stream)
+ grad_input.record_stream(config["tp_comm_stream"])
+ grad_input = all_gather(grad_input, config["tp_comm"])
+
+ # wait all_gather
+ if ctx.gather_input:
+ current_stream.wait_event(gather_event)
+ if weight.requires_grad:
+ grad_weight = (
+ grad_output.reshape(-1, grad_output.shape[-1])
+ .t()
+ .matmul(all_input.reshape(-1, all_input.shape[-1]))
+ )
+
+ if bias is not None and bias.requires_grad:
+ grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
+
+ current_stream = torch.cuda.current_stream()
+ current_stream.wait_stream(config["tp_comm_stream"])
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None
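
For reference, the two `ReduceType` modes correspond to the following single-process arithmetic on the per-rank partial outputs; this is only a sketch of the semantics, not the NCCL code path used above:

    import torch

    tp_size, m, n = 4, 8, 16
    partials = [torch.randn(m, n) for _ in range(tp_size)]    # one partial sum per rank

    all_reduce_result = torch.stack(partials).sum(0)          # every rank ends up with the full sum
    reduce_scatter_chunks = all_reduce_result.chunk(tp_size, dim=0)
    # rank r keeps reduce_scatter_chunks[r], i.e. an m // tp_size slice of the summed output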
diff --git a/examples/BMTrain/bmtrain/nn/row_parallel_linear.py b/examples/BMTrain/bmtrain/nn/row_parallel_linear.py
new file mode 100644
index 00000000..ee4610cc
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/row_parallel_linear.py
@@ -0,0 +1,88 @@
+import torch
+from torch.nn.parameter import Parameter
+
+import bmtrain as bmt
+from bmtrain.global_var import config
+from .parallel_linear_func import OpParallelLinear, ReduceType
+
+
+class RowParallelLinear(bmt.DistributedModule):
+ """Tensor Parallel use row partition for Linear.
+
+ Args:
+        in_features (int): size of each input sample; partitioned across tensor-parallel ranks.
+        out_features (int): size of each output sample.
+        bias (bool): whether to add a learnable bias.
+        dtype: data type of the parameters.
+        split_input (bool): whether to split the input along the last dimension before the compute.
+        all_reduce_output (bool): if True, all-reduce the output after the compute; otherwise reduce-scatter it.
+        async_chunks (int): number of chunks used to overlap async communication with compute.
+
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ dtype=None,
+ split_input=False,
+ all_reduce_output=False,
+ async_chunks=2,
+ ) -> None:
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.split_input = split_input
+ self.all_reduce_output = all_reduce_output
+ self.async_chunks = async_chunks
+ tp_size = config["tp_size"]
+ assert in_features % tp_size == 0
+ self.in_features_per_partition = in_features // tp_size
+ self.weight = bmt.DistributedParameter(
+ torch.empty(
+ self.out_features,
+ self.in_features_per_partition,
+ dtype=dtype,
+ device="cuda",
+ ),
+ init_method=torch.nn.init.xavier_normal_,
+ tp_split_dim=1,
+ tp_mode=True,
+ )
+ if bias:
+ self.bias = bmt.DistributedParameter(
+ torch.empty(self.out_features, dtype=dtype, device="cuda"),
+ init_method=torch.nn.init.zeros_,
+ tp_split_dim=-1,
+ tp_mode=True,
+ )
+ else:
+ self.register_parameter("bias", None)
+
+ def forward(self, input):
+ gather_input = self.split_input
+ gather_output = False
+ reduce_output_type = (
+ ReduceType.ALL_REDUCE
+ if self.all_reduce_output
+ else ReduceType.REDUCE_SCATTER
+ )
+ out = OpParallelLinear.apply(
+ input,
+ self.weight,
+ None,
+ gather_input,
+ gather_output,
+ self.split_input,
+ reduce_output_type,
+ self.async_chunks,
+ )
+ if self.bias is not None:
+ out = out + self.bias
+ return out
+
+ def extra_repr(self) -> str:
+ return "in_features={}, out_features={}, bias={}".format(
+ self.in_features_per_partition, self.out_features, self.bias is not None
+ )
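
Column- and row-parallel layers are usually paired in a Megatron-style feed-forward block: the column-parallel layer keeps its output split across ranks, and the row-parallel layer consumes that split input and reduces the partial results. A conceptual sketch; the exact activation sharding is handled inside `OpParallelLinear`:

    import torch
    import bmtrain as bmt

    d_model = 1024
    w_in = bmt.nn.ColumnParallelLinear(d_model, 4 * d_model, gather_output=False, dtype=torch.half)
    w_out = bmt.nn.RowParallelLinear(4 * d_model, d_model, all_reduce_output=True, dtype=torch.half)

    def feed_forward(x):
        return w_out(torch.relu(w_in(x)))   # partial sums are all-reduced inside w_out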
diff --git a/examples/BMTrain/bmtrain/optim/__init__.py b/examples/BMTrain/bmtrain/optim/__init__.py
new file mode 100644
index 00000000..15206328
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/__init__.py
@@ -0,0 +1,3 @@
+from .adam import AdamOptimizer
+from .adam_offload import AdamOffloadOptimizer
+from .optim_manager import OptimManager
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/optim/_distributed.py b/examples/BMTrain/bmtrain/optim/_distributed.py
new file mode 100644
index 00000000..df8f2f3e
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/_distributed.py
@@ -0,0 +1,40 @@
+import torch
+from ..distributed import all_reduce, all_gather
+
+
+def state_dict_gather(state_dict):
+ param_key = [
+ p for param_group in state_dict["param_groups"] for p in param_group["params"]
+ ]
+ for k, v in state_dict["state"].items():
+ if "step" in v:
+ step = v["step"]
+
+ for k in param_key:
+ if k not in state_dict["state"]:
+ state_dict["state"][k] = {
+ "exp_avg": torch.tensor([], device="cuda", dtype=torch.float32),
+ "exp_avg_sq": torch.tensor([], device="cuda", dtype=torch.float32),
+ "_param_fp32": torch.tensor([], device="cuda", dtype=torch.float32),
+ "step": step,
+ }
+ v = state_dict["state"][k]
+ for name, dtype in [
+ ("exp_avg", torch.float32),
+ ("exp_avg_sq", torch.float32),
+ ("_param_fp32", torch.float32),
+ ]:
+ if name in v:
+ with torch.no_grad():
+ numel = torch.tensor(
+ v[name].numel(), device="cuda", dtype=torch.long
+ )
+ max_numel = all_reduce(numel, op="max")
+ v_p = torch.nn.functional.pad(
+ v[name], (0, max_numel - numel), value=-1e15
+ )
+ if max_numel > 0:
+ whole_state = all_gather(v_p.cuda()).flatten()
+ whole_state = whole_state[whole_state != -1e15]
+ v[name] = whole_state.contiguous().cpu()
+ return state_dict
diff --git a/examples/BMTrain/bmtrain/optim/_function.py b/examples/BMTrain/bmtrain/optim/_function.py
new file mode 100644
index 00000000..f9e0ce9d
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/_function.py
@@ -0,0 +1,218 @@
+from .. import C
+import torch
+
+CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda
+
+
+def bf16_from_fp32(param_fp32):
+ param_bf16 = torch.empty_like(param_fp32, dtype=torch.bfloat16)
+ C.to_bf16_from_fp32(
+ param_fp32.numel(), param_fp32.data_ptr(), param_bf16.data_ptr()
+ )
+ return param_bf16
+
+
+def fp16_from_fp32(param_fp32):
+ param_fp16 = torch.empty_like(param_fp32, dtype=torch.float16)
+ C.to_fp16_from_fp32(
+ param_fp32.numel(), param_fp32.data_ptr(), param_fp16.data_ptr()
+ )
+ return param_fp16
+
+
+def adam_cpu(
+ param_fp32: torch.Tensor,
+ param_fp16: torch.Tensor,
+ delta_info: torch.Tensor,
+ g_fp16: torch.Tensor,
+ m_fp32: torch.Tensor,
+ v_fp32: torch.Tensor,
+ beta1: float,
+ beta2: float,
+ eps: float,
+ lr: float,
+ scale: float,
+ weight_decay: float,
+ step: int,
+) -> None:
+ assert param_fp32.is_contiguous(), "param_fp32 must be contiguous"
+ assert param_fp16.is_contiguous(), "param_fp16 must be contiguous"
+ assert g_fp16.is_contiguous(), "g_fp16 must be contiguous"
+ assert m_fp32.is_contiguous(), "m_fp32 must be contiguous"
+ assert v_fp32.is_contiguous(), "v_fp32 must be contiguous"
+ assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
+ assert (
+ param_fp16.dtype == torch.float16 or param_fp16.dtype == torch.bfloat16
+ ), "param_fp16 must be float16/bfloat16 tensor"
+ assert (
+ g_fp16.dtype == torch.float16 or g_fp16.dtype == torch.bfloat16
+ ), "g_fp16 must be float16/bfloat16 tensor"
+ assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor"
+ assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
+ assert param_fp32.device == torch.device("cpu"), "param_fp32 must be a cpu tensor"
+ assert param_fp16.device == torch.device("cpu"), "param_fp16 must be a cpu tensor"
+ assert g_fp16.device == torch.device("cpu"), "g_fp16 must be a cpu tensor"
+ assert m_fp32.device == torch.device("cpu"), "m_fp32 must be a cpu tensor"
+ assert v_fp32.device == torch.device("cpu"), "v_fp32 must be a cpu tensor"
+ assert (
+ param_fp32.numel() == param_fp16.numel()
+ ), "param_fp32 and param_fp16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == g_fp16.numel()
+ ), "param_fp32 and g_fp16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == m_fp32.numel()
+ ), "param_fp32 and m_fp32 must have the same number of elements"
+ assert (
+ param_fp32.numel() == v_fp32.numel()
+ ), "param_fp32 and v_fp32 must have the same number of elements"
+ if delta_info is not None:
+ assert delta_info.is_contiguous(), "delta_info must be contiguous"
+ assert delta_info.dtype == torch.float32, "delta_info must be float32 tensor"
+ assert delta_info.device == torch.device(
+ "cpu"
+ ), "delta_info must be a cpu tensor"
+        assert delta_info.numel() == 4, "delta_info must have a length of 4"
+ bias_correction1 = 1 - beta1**step
+ bias_correction2 = 1 - beta2**step
+ if g_fp16.dtype == torch.float16:
+ launcher = C.adam_cpu_fp16_launcher
+ elif g_fp16.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ launcher = C.adam_cpu_bf16_launcher
+ launcher(
+ param_fp32.numel(),
+ param_fp32.data_ptr(),
+ param_fp16.data_ptr(),
+ delta_info.data_ptr() if delta_info is not None else 0,
+ g_fp16.data_ptr(),
+ m_fp32.data_ptr(),
+ v_fp32.data_ptr(),
+ beta1,
+ beta2,
+ eps,
+ lr,
+ scale,
+ weight_decay,
+ bias_correction1,
+ bias_correction2,
+ )
+
+
+def adam_fp16(
+ param_fp32: torch.Tensor,
+ param_fp16: torch.Tensor,
+ g_fp16: torch.Tensor,
+ m_fp16: torch.Tensor,
+ v_fp32: torch.Tensor,
+ beta1: float,
+ beta2: float,
+ eps: float,
+ lr: float,
+ scale: float,
+ weight_decay: float,
+ step: int,
+) -> None:
+ assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
+ assert CHECK_INPUT(param_fp16), "param_fp16 must be contiguous and on cuda"
+ assert CHECK_INPUT(g_fp16), "g_fp16 must be contiguous and on cuda"
+ assert CHECK_INPUT(m_fp16), "m_fp32 must be contiguous and on cuda"
+ assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda"
+ assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
+ assert param_fp16.dtype == torch.float16, "param_fp16 must be float16 tensor"
+ assert g_fp16.dtype == torch.float16, "g_fp16 must be float16 tensor"
+ assert m_fp16.dtype == torch.float16, "m_fp16 must be float16 tensor"
+ assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
+ assert (
+ param_fp32.numel() == param_fp16.numel()
+ ), "param_fp32 and param_fp16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == g_fp16.numel()
+ ), "param_fp32 and g_fp16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == m_fp16.numel()
+ ), "param_fp32 and m_fp32 must have the same number of elements"
+ assert (
+ param_fp32.numel() == v_fp32.numel()
+ ), "param_fp32 and v_fp32 must have the same number of elements"
+ bias_correction1 = 1 - beta1**step
+ bias_correction2 = 1 - beta2**step
+ stream = torch.cuda.current_stream().cuda_stream
+ C.adam_fp16_launcher(
+ param_fp32.numel(),
+ param_fp32.data_ptr(),
+ param_fp16.data_ptr(),
+ g_fp16.data_ptr(),
+ m_fp16.data_ptr(),
+ v_fp32.data_ptr(),
+ beta1,
+ beta2,
+ eps,
+ lr,
+ scale,
+ weight_decay,
+ bias_correction1,
+ bias_correction2,
+ stream,
+ )
+
+
+def adam_bf16(
+ param_fp32: torch.Tensor,
+ param_bf16: torch.Tensor,
+ g_bf16: torch.Tensor,
+ m_fp32: torch.Tensor,
+ v_fp32: torch.Tensor,
+ beta1: float,
+ beta2: float,
+ eps: float,
+ lr: float,
+ scale: float,
+ weight_decay: float,
+ step: int,
+) -> None:
+ assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
+ assert CHECK_INPUT(param_bf16), "param_bf16 must be contiguous and on cuda"
+ assert CHECK_INPUT(g_bf16), "g_bf16 must be contiguous and on cuda"
+ assert CHECK_INPUT(m_fp32), "m_fp32 must be contiguous and on cuda"
+ assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda"
+ assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
+    assert param_bf16.dtype == torch.bfloat16, "param_bf16 must be bfloat16 tensor"
+ assert g_bf16.dtype == torch.bfloat16, "g_bf16 must be bfloat16 tensor"
+ assert m_fp32.dtype == torch.float32, "m_fp32 must be bfloat16 tensor"
+ assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
+ assert (
+ param_fp32.numel() == param_bf16.numel()
+ ), "param_fp32 and param_bf16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == g_bf16.numel()
+ ), "param_fp32 and g_fp16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == m_fp32.numel()
+ ), "param_fp32 and m_m_fp32 must have the same number of elements"
+ assert (
+ param_fp32.numel() == v_fp32.numel()
+ ), "param_fp32 and v_fp32 must have the same number of elements"
+ bias_correction1 = 1 - beta1**step
+ bias_correction2 = 1 - beta2**step
+ stream = torch.cuda.current_stream().cuda_stream
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.adam_bf16_launcher(
+ param_fp32.numel(),
+ param_fp32.data_ptr(),
+ param_bf16.data_ptr(),
+ g_bf16.data_ptr(),
+ m_fp32.data_ptr(),
+ v_fp32.data_ptr(),
+ beta1,
+ beta2,
+ eps,
+ lr,
+ scale,
+ weight_decay,
+ bias_correction1,
+ bias_correction2,
+ stream,
+ )
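
A pure-PyTorch reference of the update these fused kernels are expected to perform makes the argument list easier to follow; the exact weight-decay placement inside the C kernels is an assumption here (AdamW-style, scaled by the learning rate), so treat this only as a sketch:

    import torch

    def adam_reference(param_fp32, grad, m, v, beta1, beta2, eps, lr, scale, weight_decay, step):
        # m, v: float32 state tensors; grad arrives pre-scaled by the loss scale, so undo it first
        g = grad.float() / scale
        m.mul_(beta1).add_(g, alpha=1 - beta1)
        v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
        m_hat = m / (1 - beta1 ** step)
        v_hat = v / (1 - beta2 ** step)
        param_fp32.add_(-(lr * (m_hat / (v_hat.sqrt() + eps) + weight_decay * param_fp32)))
        # the fused kernels additionally write the result back into the low-precision copy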
diff --git a/examples/BMTrain/bmtrain/optim/adam.py b/examples/BMTrain/bmtrain/optim/adam.py
new file mode 100644
index 00000000..f99c483c
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/adam.py
@@ -0,0 +1,252 @@
+import torch
+from ..global_var import config
+from . import _function as F
+import torch.optim._functional
+from .. import C
+from .. import nccl
+import inspect
+from ..utils import check_torch_version
+from copy import deepcopy
+from itertools import chain
+from collections import defaultdict
+
+
+class AdamOptimizer(torch.optim.Optimizer):
+ """
+    Adam optimizer with fp16 and bf16 support.
+ """
+
+ _bmtrain_optimizer = True
+
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ hold_steps=0,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+ super().__init__(params, defaults)
+
+ self._hold_steps = hold_steps
+
+ def _on_justify_scale(self, old_scale, new_scale):
+ delta = new_scale / old_scale
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p in self.state:
+ state = self.state[p]
+ if len(state) > 0:
+ if p.dtype == torch.float16:
+ state["exp_avg"] *= delta
+ state["exp_avg_sq"] *= delta
+
+ @torch.no_grad()
+ def step(self, closure=None, scale=1):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+            scale (float): the loss scale; gradients are divided by this value before the update.
+ """
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ # update parameters
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is not None and p.requires_grad:
+ if p.grad.is_sparse:
+ raise RuntimeError(
+ "Adam does not support sparse gradients, please consider SparseAdam instead"
+ )
+ if p.dtype not in [torch.float32, torch.half, torch.bfloat16]:
+ raise RuntimeError(
+ "Adam only supports fp32, fp16 and bf16 gradients"
+ )
+
+ state = self.state[p]
+ # Lazy state initialization
+ if len(state) == 0:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ if p.dtype == torch.float16:
+ state["exp_avg"] = torch.zeros(
+ p.size(), dtype=torch.float16, device=p.device
+ ) # on device
+ else:
+ state["exp_avg"] = torch.zeros(
+ p.size(), dtype=torch.float32, device=p.device
+ ) # on device
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros(
+ p.size(), dtype=torch.float32, device=p.device
+ ) # on device
+
+ if p.dtype != torch.float32:
+ state["_param_fp32"] = torch.empty(
+ p.size(), dtype=torch.float32, device=p.device
+ ) # on device
+ state["_param_fp32"].copy_(p)
+
+ # update the steps for each param group update
+ if ("maximize" in group) and (group["maximize"] is True):
+ grad = -p.grad
+ else:
+ grad = p.grad
+
+ if p.dtype == torch.float32:
+ other_kwargs = {}
+ if (
+ "maximize"
+ in inspect.signature(
+ torch.optim._functional.adam
+ ).parameters
+ ):
+ other_kwargs["maximize"] = False
+ torch.optim._functional.adam(
+ [p],
+ [grad / scale],
+ [state["exp_avg"]],
+ [state["exp_avg_sq"]],
+ [],
+ (
+ [state["step"]]
+ if check_torch_version("1.12.0") < 0
+ else [torch.tensor(state["step"])]
+ ),
+ amsgrad=False,
+ beta1=group["betas"][0],
+ beta2=group["betas"][1],
+ lr=0.0 if state["step"] < self._hold_steps else group["lr"],
+ weight_decay=group["weight_decay"],
+ eps=group["eps"],
+ **other_kwargs
+ )
+ state["step"] += 1
+ else:
+ f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16
+ state["step"] += 1
+ f(
+ state["_param_fp32"], # fp32
+ p, # fp16
+ grad, # fp16
+ state["exp_avg"], # fp16: m
+ state["exp_avg_sq"], # fp32: v
+ group["betas"][0],
+ group["betas"][1],
+ group["eps"],
+ 0.0 if state["step"] < self._hold_steps else group["lr"],
+ scale,
+ group["weight_decay"],
+ state["step"],
+ )
+
+ return loss
+
+    def get_avg_delta(self):
+        raise NotImplementedError(
+            "delta info is not supported in AdamOptimizer, try bmt.optim.AdamOffloadOptimizer"
+        )
+
+    def get_var_delta(self):
+        raise NotImplementedError(
+            "delta info is not supported in AdamOptimizer, try bmt.optim.AdamOffloadOptimizer"
+        )
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ r"""Loads the optimizer state.
+
+ Args:
+ state_dict (dict): optimizer state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+ state_dict = deepcopy(state_dict)
+ # Validate the state_dict
+ groups = self.param_groups
+ saved_groups = state_dict["param_groups"]
+
+ if len(groups) != len(saved_groups):
+ raise ValueError(
+ "loaded state dict has a different number of " "parameter groups"
+ )
+ param_lens = (len(g["params"]) for g in groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
+ if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+ raise ValueError(
+ "loaded state dict contains a parameter group "
+ "that doesn't match the size of optimizer's group"
+ )
+
+ # Update the state
+ id_map = {
+ old_id: p
+ for old_id, p in zip(
+ chain.from_iterable((g["params"] for g in saved_groups)),
+ chain.from_iterable((g["params"] for g in groups)),
+ )
+ }
+
+ # Copy state assigned to params (and cast tensors to appropriate types).
+ # State that is not assigned to params is copied as is (needed for
+ # backward compatibility).
+ state = defaultdict(dict)
+ for k, v in state_dict["state"].items():
+ if k in id_map:
+ param = id_map[k]
+
+ if param.dtype != torch.float32 and "_param_fp32" not in v:
+ v["_param_fp32"] = torch.empty(
+ param.size(), dtype=torch.float32, device=param.device
+ )
+ v["_param_fp32"].copy_(param)
+
+ for name, dtype in [
+ (
+ "exp_avg",
+ (
+ torch.float16
+ if param.dtype == torch.float16
+ else torch.float32
+ ),
+ ),
+ ("exp_avg_sq", torch.float32),
+ ("_param_fp32", torch.float32),
+ ]:
+ if name in v:
+ v[name] = v[name].to(param.device).to(dtype)
+
+ state[param] = v
+ else:
+ state[k] = v
+
+ # Update parameter groups, setting their 'params' value
+ def update_group(group, new_group):
+ new_group["params"] = group["params"]
+ return new_group
+
+ param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ self.__setstate__({"state": state, "param_groups": param_groups})
+
+    # TODO: zero_grad(set_to_none=True) crashes the optimizer, possibly because of gradient accumulation
+ def zero_grad(self, set_to_none: bool = False):
+ super().zero_grad(set_to_none=set_to_none)
diff --git a/examples/BMTrain/bmtrain/optim/adam_offload.py b/examples/BMTrain/bmtrain/optim/adam_offload.py
new file mode 100644
index 00000000..f6ea97ba
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/adam_offload.py
@@ -0,0 +1,386 @@
+import torch
+from ..global_var import config
+from . import _function as F
+from .. import nccl
+import inspect
+from ..utils import check_torch_version
+from copy import deepcopy
+from itertools import chain
+from collections import defaultdict
+from ._distributed import state_dict_gather
+
+
+class AdamOffloadOptimizer(torch.optim.Optimizer):
+ """
+    Adam optimizer that offloads optimizer states and the Adam update to the CPU.
+ """
+
+ _bmtrain_optimizer = True
+
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ hold_steps=0,
+ record_delta=False,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ self.avg_delta = 0
+ self.var_delta = 0
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+ super().__init__(params, defaults)
+ self._hold_steps = hold_steps
+ self._events = {}
+ self.record_delta = record_delta
+ if self.record_delta:
+ for group in self.param_groups:
+ for p in group["params"]:
+ setattr(
+ p,
+ "_delta_info",
+ (
+ torch.tensor(
+ [0 for i in range(4)], dtype=torch.float32, device="cpu"
+ )
+ ),
+ )
+
+ @torch.no_grad()
+ def step(self, closure=None, scale=1):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+            scale (float): the loss scale; gradients are divided by this value before the update.
+ """
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ # parameters to be updated
+ update_params = []
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is not None and p.requires_grad:
+ if p.grad.is_sparse:
+ raise RuntimeError(
+ "Adam does not support sparse gradients, please consider SparseAdam instead"
+ )
+ if p.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+ raise RuntimeError(
+ "Adam only supports fp32, fp16 and bf16 gradients"
+ )
+
+ state = self.state[p]
+ # Lazy state initialization
+ if len(state) == 0:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros(
+ p.size(), dtype=torch.float32, device="cpu"
+ ) # on host
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros(
+ p.size(), dtype=torch.float32, device="cpu"
+ ) # on host
+
+ if p.dtype == torch.float32:
+ state["_param_fp32"] = torch.empty(
+ p.size(), dtype=torch.float32, pin_memory=True
+ ) # on host
+ state["_param_fp32"].copy_(p)
+
+ # placeholder
+ state["_grad_fp32"] = torch.empty(
+ p.size(), dtype=torch.float32, pin_memory=True
+ ) # on host
+ else:
+ state["_param_fp32"] = torch.empty(
+ p.size(), dtype=torch.float32, device="cpu"
+ ) # on host
+ state["_param_fp32"].copy_(p)
+
+ # placeholder
+ state["_param_fp16"] = torch.empty(
+ p.size(), dtype=p.dtype, pin_memory=True
+ ) # on host
+ state["_grad_fp16"] = torch.empty(
+ p.size(), dtype=p.dtype, pin_memory=True
+ ) # on host
+
+ if p not in self._events:
+ self._events[p] = torch.cuda.Event()
+
+ update_params.append(
+ (
+ p,
+ state,
+ self._events[p],
+ group["betas"][0],
+ group["betas"][1],
+ group["eps"],
+ group["lr"],
+ group["weight_decay"],
+ )
+ )
+
+ # transfer parameters to host asynchronously
+ for param, state, event, _, _, _, _, _ in update_params:
+ if param.dtype == torch.float32:
+ state["_grad_fp32"].copy_(param.grad, non_blocking=True)
+ else:
+ state["_grad_fp16"].copy_(param.grad, non_blocking=True)
+ torch.cuda.current_stream().record_event(event)
+ sum_delta = 0
+ sum_sq_delta = 0
+ total_numel = 0
+ for param, state, event, beta1, beta2, eps, lr, weight_decay in update_params:
+ # wait for transfer to host
+ event.synchronize()
+
+ # update parameters
+ if param.dtype == torch.float32:
+ state["_grad_fp32"].mul_(1.0 / scale)
+ if ("maximize" in group) and (group["maximize"] is True):
+ grad = -state["_grad_fp32"]
+ else:
+ grad = state["_grad_fp32"]
+ other_kwargs = {}
+ if (
+ "maximize"
+ in inspect.signature(torch.optim._functional.adam).parameters
+ ):
+ other_kwargs["maximize"] = False
+ torch.optim._functional.adam(
+ [state["_param_fp32"]],
+ [grad],
+ [state["exp_avg"]],
+ [state["exp_avg_sq"]],
+ [],
+ (
+ [state["step"]]
+ if check_torch_version("1.12.0") < 0
+ else [torch.tensor(state["step"])]
+ ),
+ amsgrad=False,
+ beta1=beta1,
+ beta2=beta2,
+ lr=0.0 if state["step"] < self._hold_steps else lr,
+ weight_decay=weight_decay,
+ eps=eps,
+ **other_kwargs
+ )
+ # transfer parameters back to device asynchronously
+ param.copy_(state["_param_fp32"], non_blocking=True)
+ state["step"] += 1
+ else:
+ state["step"] += 1
+ if ("maximize" in group) and (group["maximize"] is True):
+ grad = -state["_grad_fp16"]
+ else:
+ grad = state["_grad_fp16"]
+ F.adam_cpu(
+ state["_param_fp32"].view(-1),
+ state["_param_fp16"].view(-1),
+ param._delta_info if self.record_delta else None,
+ grad.view(-1),
+ state["exp_avg"].view(-1),
+ state["exp_avg_sq"].view(-1),
+ beta1,
+ beta2,
+ eps,
+ 0.0 if state["step"] < self._hold_steps else lr,
+ scale,
+ weight_decay,
+ state["step"],
+ )
+ total_numel += state["_param_fp16"].numel()
+ if self.record_delta:
+ sum_delta += param._delta_info[2].item()
+ sum_sq_delta += param._delta_info[3].item()
+ # transfer parameters back to device asynchronously
+ param.copy_(state["_param_fp16"], non_blocking=True)
+ if self.record_delta:
+ self.avg_delta = sum_delta / total_numel
+ self.var_delta = sum_sq_delta / total_numel - self.avg_delta**2
+
+ return loss
+
+    def get_avg_delta(self) -> float:
+        return self.avg_delta if self.record_delta else 0
+
+    def get_var_delta(self) -> float:
+        return self.var_delta if self.record_delta else 0
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ r"""Loads the optimizer state.
+
+ Args:
+ state_dict (dict): optimizer state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+
+ state_dict = deepcopy(state_dict)
+ # Validate the state_dict
+ groups = self.param_groups
+ saved_groups = state_dict["param_groups"]
+
+ if len(groups) != len(saved_groups):
+ raise ValueError(
+ "loaded state dict has a different number of " "parameter groups"
+ )
+ param_lens = (len(g["params"]) for g in groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
+ if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+ raise ValueError(
+ "loaded state dict contains a parameter group "
+ "that doesn't match the size of optimizer's group"
+ )
+
+ # Update the state
+ id_map = {
+ old_id: p
+ for old_id, p in zip(
+ chain.from_iterable((g["params"] for g in saved_groups)),
+ chain.from_iterable((g["params"] for g in groups)),
+ )
+ }
+
+ # _param_start_end = chain.from_iterable((g["params_start_end"] for g in saved_groups))
+ # Copy state assigned to params (and cast tensors to appropriate types).
+ # State that is not assigned to params is copied as is (needed for
+ # backward compatibility).
+ state = defaultdict(dict)
+ is_whole = False if "is_whole" not in state_dict else state_dict["is_whole"]
+ pop_key = []
+ for k, v in state_dict["state"].items():
+ if k in id_map:
+ param = id_map[k]
+ if is_whole and param._start_partition is not None:
+ for key in ["_param_fp32", "exp_avg_sq", "exp_avg"]:
+ if key in v:
+ v[key] = v[key][
+ param._start_partition : param._end_partition
+ ]
+ elif is_whole and param._start_partition is None:
+ pop_key.append(param)
+
+ if "_param_fp32" not in v:
+ with torch.no_grad():
+ v["_param_fp32"] = torch.empty(
+ param.size(), dtype=torch.float32, device="cpu"
+ )
+ v["_param_fp32"].copy_(param)
+
+ for name, dtype in [
+ ("exp_avg", torch.float32),
+ ("exp_avg_sq", torch.float32),
+ ("_param_fp32", torch.float32),
+ ]:
+ if name in v:
+ v[name] = v[name].to("cpu").to(dtype)
+
+ state[param] = v
+ if param.dtype == torch.float32:
+ state[param]["_param_fp32"] = state[param][
+ "_param_fp32"
+ ].pin_memory() # on host
+ # initialize placeholders
+ state[param]["_grad_fp32"] = torch.empty(
+ param.size(), dtype=torch.float32, pin_memory=True
+ ) # on host
+ else:
+ # initialize placeholders
+ state[param]["_param_fp16"] = torch.empty(
+ param.size(), dtype=param.dtype, pin_memory=True
+ ) # on host
+ state[param]["_grad_fp16"] = torch.empty(
+ param.size(), dtype=param.dtype, pin_memory=True
+ ) # on host
+ else:
+ state[k] = v
+ for k in pop_key:
+ state.pop(k)
+
+ # Update parameter groups, setting their 'params' value
+ def update_group(group, new_group):
+ new_group["params"] = group["params"]
+ return new_group
+
+ param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ self.__setstate__({"state": state, "param_groups": param_groups})
+
+ def state_dict(self, gather=False) -> dict:
+ r"""Returns the state of the optimizer as a :class:`dict`.
+
+ It contains two entries:
+
+ * state - a dict holding current optimization state. Its content
+ differs between optimizer classes.
+ * param_groups - a list containing all parameter groups where each
+ parameter group is a dict
+ """
+
+ # Save order indices instead of Tensors
+ param_mappings = {}
+ start_index = 0
+
+ def pack_group(group):
+ nonlocal start_index
+ packed = {k: v for k, v in group.items() if k != "params"}
+ param_mappings.update(
+ {
+ id(p): i
+ for i, p in enumerate(group["params"], start_index)
+ if id(p) not in param_mappings
+ }
+ )
+ packed["params"] = [param_mappings[id(p)] for p in group["params"]]
+ start_index += len(packed["params"])
+ return packed
+
+ def cut_states(state):
+ return {
+ "step": state["step"],
+ "exp_avg": state["exp_avg"],
+ "exp_avg_sq": state["exp_avg_sq"],
+ "_param_fp32": state["_param_fp32"],
+ }
+
+ param_groups = [pack_group(g) for g in self.param_groups]
+ # Remap state to use order indices as keys
+ packed_state = {
+ (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): cut_states(v)
+ for k, v in self.state.items()
+ }
+ states = {
+ "state": packed_state,
+ "param_groups": param_groups,
+ }
+ if gather:
+ states = state_dict_gather(states)
+ states["is_whole"] = True
+ else:
+ states["is_whole"] = False
+
+ return states
+
+    # TODO: zero_grad(set_to_none=True) crashes the optimizer, possibly because of gradient accumulation
+ def zero_grad(self, set_to_none: bool = False):
+ super().zero_grad(set_to_none=set_to_none)
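
A usage sketch; `model` and the training loop are placeholders. The offload optimizer plugs into `OptimManager` like the on-device one, and `record_delta=True` exposes per-step update statistics through `get_avg_delta()` / `get_var_delta()`:

    import bmtrain as bmt

    optimizer = bmt.optim.AdamOffloadOptimizer(
        model.parameters(), lr=1e-4, weight_decay=0.01, record_delta=True
    )
    optim_manager = bmt.optim.OptimManager(loss_scale=1024)
    optim_manager.add_optimizer(optimizer)

    # inside the training loop, after optim_manager.backward(loss):
    optim_manager.step()
    print(optimizer.get_avg_delta(), optimizer.get_var_delta())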
diff --git a/examples/BMTrain/bmtrain/optim/optim_manager.py b/examples/BMTrain/bmtrain/optim/optim_manager.py
new file mode 100644
index 00000000..1a98ed92
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/optim_manager.py
@@ -0,0 +1,226 @@
+from typing import Optional, Union, List, Dict, Tuple
+import torch
+from ..loss._function import has_inf_nan
+from ..utils import print_rank
+from ..lr_scheduler.warmup import WarmupLRScheduler
+from .. import nccl
+from ..global_var import config
+
+def check_overflow(param_groups):
+ # check overflow
+ has_inf_or_nan = torch.zeros(1, dtype=torch.uint8, device="cuda")[0]
+ for group in param_groups:
+ for p in group['params']:
+ if p.grad is not None:
+ if p.dtype != torch.float:
+ has_inf_nan(p.grad, has_inf_or_nan)
+ if "comm" in config:
+ nccl.allReduce(has_inf_or_nan.storage(), has_inf_or_nan.storage(), "max", config["comm"])
+
+ if has_inf_or_nan > 0:
+ raise OverflowError("Gradient overflow")
+
+def grad_rescale(param_groups, scale):
+ for group in param_groups:
+ for p in group['params']:
+ if p.grad is not None and p.requires_grad:
+ p.grad /= scale
+
+class OptimManager:
+ """wait cuda stream. Optional: add loss scaler for mix-precision training
+
+ Args:
+        loss_scale (float): The initial loss scale. Defaults to None, which disables loss scaling.
+        loss_scale_factor (float): factor by which the loss scale is raised or lowered.
+        loss_scale_steps (int): number of overflow-free steps before the loss scale is raised.
+
+ Examples:
+ >>> optim_manager = bmt.optim.OptimManager(loss_scale=1024)
+ >>> optim_manager.add_optimizer(optimizer1)
+ >>> optim_manager.add_optimizer(optimizer2, lr_scheduler2)
+ >>> for data in dataset:
+ >>> # forward pass and calculate loss
+ >>> optim_manager.zero_grad()
+ >>> optim_manager.backward(loss)
+ >>> optim_manager.clip_grad_norm(optimizer1.param_groups, max_norm=1.0, norm_type=2)
+ >>> optim_manager.clip_grad_norm(optimizer2.param_groups, max_norm=2.0, norm_type=2)
+ >>> optim_manager.step()
+ """
+ def __init__(self,
+ loss_scale : Optional[float] = None,
+ loss_scale_factor : float = 2,
+ loss_scale_steps : int = 1024,
+ min_loss_scale = 1,
+ max_loss_scale = float("inf"),
+ grad_scale : Optional[int] = None,
+ ):
+ if loss_scale is not None:
+ self.loss_scale = loss_scale
+ self.loss_scale_enabled = True
+ else:
+ self.loss_scale = 1
+ self.loss_scale_enabled = False
+ self.steps_since_last_scale = 0
+ self.loss_scale_factor = loss_scale_factor if loss_scale_factor > 1 else 1 / loss_scale_factor
+ self.loss_scale_steps = loss_scale_steps
+ self.min_loss_scale = min_loss_scale
+ self.max_loss_scale = max_loss_scale
+ if grad_scale is None:
+ grad_scale = config['zero_size']
+ self.grad_scale = grad_scale
+
+ self.optimizers = []
+ self.lr_schedulers = []
+
+ def add_optimizer(
+ self,
+ optimizer: torch.optim.Optimizer,
+ lr_scheduler: Optional[WarmupLRScheduler] = None,
+ ):
+ """Add optimizer and (optional) its corresponding lr_scheduler into optim_manager.
+ All optimizers in the same optim_manager share the same loss scale.
+
+ Args:
+            optimizer (torch.optim.Optimizer): A PyTorch optimizer, e.g. torch.optim.Adam, torch.optim.SGD, or bmtrain.optim.AdamOffloadOptimizer
+ lr_scheduler (Optional[WarmupLRScheduler]): A warmup lr scheduler, e.g. bmt.lr_scheduler.Noam
+ """
+ self.optimizers.append(optimizer)
+ self.lr_schedulers.append(lr_scheduler)
+
+ def scale_loss(self, loss : torch.Tensor) -> torch.Tensor:
+
+ return loss * ( self.loss_scale / self.grad_scale ) # loss scale
+
+ def backward(self, loss : torch.Tensor):
+ """
+ Backward with loss scale.
+
+ Args:
+ loss (torch.Tensor): loss
+ """
+ loss = self.scale_loss(loss)
+ loss.backward()
+        # some reduce ops of distributed parameters are launched on the load stream
+ current_stream = torch.cuda.current_stream()
+ current_stream.wait_stream(config['load_stream'])
+
+ def zero_grad(self):
+ """
+ This is a helper function to call optimizer.zero_grad()
+ """
+ for optimizer in self.optimizers:
+ optimizer.zero_grad(set_to_none=False)
+
+ def step(self):
+ """
+        Step all registered optimizers and their lr_schedulers, synchronizing
+        CUDA streams before the optimizer steps.
+
+        This is a helper function that calls optimizer.step() and lr_scheduler.step() for every registered pair.
+
+ This function can also handle gradient overflow by reducing the loss scale when it occurs.
+ """
+ if self.loss_scale_enabled:
+ has_overflow = False
+ for optimizer in self.optimizers:
+ try:
+ check_overflow(optimizer.param_groups)
+ except OverflowError:
+ has_overflow = True
+ break
+ if has_overflow:
+ print_rank("Gradient overflow, change scale from %lf to %lf" % (self.loss_scale, self.loss_scale / self.loss_scale_factor))
+ with torch.no_grad():
+ if self.loss_scale > self.min_loss_scale:
+ self._justify_scale(self.loss_scale / self.loss_scale_factor)
+ self.zero_grad()
+ return
+ for optimizer, lr_scheduler in zip(self.optimizers, self.lr_schedulers):
+ if hasattr(optimizer, "_bmtrain_optimizer") and optimizer._bmtrain_optimizer:
+ optimizer.step(scale=self.loss_scale)
+ else:
+ if self.loss_scale_enabled:
+ grad_rescale(optimizer.param_groups, self.loss_scale)
+ optimizer.step()
+
+ if lr_scheduler is not None:
+ lr_scheduler.step()
+
+ if self.loss_scale_enabled:
+ self.steps_since_last_scale += 1
+
+ if self.steps_since_last_scale >= self.loss_scale_steps and self.loss_scale < self.max_loss_scale:
+ self._justify_scale(self.loss_scale * self.loss_scale_factor)
+
+ current_stream = torch.cuda.current_stream()
+ config['load_stream'].wait_stream(current_stream)
+
+ def clip_grad_norm(self, param_groups, max_norm, norm_type=2, eps=1e-6):
+ """Clips gradient norm of an iterable of parameters.
+
+ The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.
+
+ Args:
+            param_groups (List[dict]): parameter groups (e.g. ``optimizer.param_groups``) whose gradients will be clipped.
+ max_norm (float or int): max norm of the gradients.
+ norm_type (float or int): type of the used p-norm. Can be 'inf' for infinity norm.
+ eps (float): epsilon used to avoid zero division.
+
+ Returns:
+ Total norm of the parameters (viewed as a single vector).
+ """
+ scale = self.loss_scale
+ grads = []
+ parameters = [p for group in param_groups for p in group['params']]
+ for p in parameters:
+ if p.grad is not None:
+ grads.append(p.grad.data)
+ else:
+ grads.append(torch.zeros_like(p.data))
+
+ if norm_type == 'inf':
+ total_norm_cuda = max(g.data.abs().max() for g in grads).detach()
+ nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "max", config["comm"])
+ total_norm = total_norm_cuda
+ else:
+ norm_type = float(norm_type)
+ total_norm_cuda = torch.cuda.FloatTensor([0])
+ for index, g in enumerate(grads):
+ param_norm = g.data.float().norm(norm_type)
+ total_norm_cuda += param_norm ** norm_type
+ nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "sum", config["comm"])
+ total_norm = total_norm_cuda[0] ** (1. / norm_type)
+ # total_norm = total_norm / scale
+ # clip_coef = float(max_norm) / (total_norm + eps)
+ clip_coef = float(max_norm * scale) / (total_norm + eps)
+ if clip_coef < 1:
+ for p in parameters:
+ if p.grad is not None:
+ p.grad.data.mul_(clip_coef)
+ return total_norm / scale
+
+ @torch.no_grad()
+ def _justify_scale(self, scale):
+ for optimizer in self.optimizers:
+ if hasattr(optimizer, "_on_justify_scale"):
+ optimizer._on_justify_scale(self.loss_scale, scale)
+ self.loss_scale = scale
+ self.steps_since_last_scale = 0
+
+ def state_dict(self, gather_opt=False) -> dict:
+ return {
+ "optimizers": [opt.state_dict(gather_opt) for opt in self.optimizers],
+ "lr_schedulers": [lrs.state_dict() if lrs else None for lrs in self.lr_schedulers],
+ "loss_scale": self.loss_scale,
+ "loss_scale_enabled": self.loss_scale_enabled,
+ }
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ assert len(self.optimizers) == len(state_dict["optimizers"])
+ assert len(self.lr_schedulers) == len(state_dict["lr_schedulers"])
+ for opt, opt_st in zip(self.optimizers, state_dict["optimizers"]):
+ opt.load_state_dict(opt_st)
+ for lrs, lrs_st in zip(self.lr_schedulers, state_dict["lr_schedulers"]):
+ lrs.load_state_dict(lrs_st)
+ self.loss_scale = state_dict["loss_scale"]
+ self.loss_scale_enabled = state_dict["loss_scale_enabled"]
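
Checkpointing follows the usual PyTorch pattern, assuming the registered optimizers are BMTrain optimizers (their `state_dict` accepts the gather flag); the file name below is just a placeholder:

    import torch

    state = optim_manager.state_dict()
    torch.save(state, "optim_manager.pt")

    # later, on resume
    optim_manager.load_state_dict(torch.load("optim_manager.pt"))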
diff --git a/examples/BMTrain/bmtrain/param_init.py b/examples/BMTrain/bmtrain/param_init.py
new file mode 100644
index 00000000..21f95f25
--- /dev/null
+++ b/examples/BMTrain/bmtrain/param_init.py
@@ -0,0 +1,105 @@
+from typing import Dict, Generator, Iterable, List, Tuple
+import torch
+from .block_layer import Block
+from .parameter import DistributedParameter
+from .global_var import config
+
+
+def init_distributed_parameter(params: Iterable[torch.nn.Parameter]):
+ """Init param of params which is instance of DistributedParameter using param._init_method.
+
+ Args:
+ params (Iterable[torch.nn.Parameter]): parameter tensors.
+
+ """
+ for param in params:
+ if not isinstance(param, DistributedParameter):
+ continue
+ if param._init_method is None:
+ continue
+ with torch.no_grad():
+ partition_size = param.storage().size()
+ global_size = partition_size * config["tp_zero_size"] * config["tp_size"]
+ tmp_storage = param.storage_type()(global_size)
+ tmp_tensor = torch.tensor([], dtype=param.dtype, device="cuda")
+ tmp_tensor.set_(tmp_storage, 0, param._tp_original_shape)
+
+ param._init_method(tmp_tensor)
+ if param._tp_mode and param._tp_split_dim >= 0:
+ tensor_list = tmp_tensor.chunk(
+ config["tp_size"], dim=param._tp_split_dim
+ )
+ sub_tensor = tensor_list[config["topology"].tp_id].contiguous()
+ tmp_tensor = torch.empty(
+ sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype
+ )
+ tmp_tensor.copy_(sub_tensor)
+
+ if param._tp_mode:
+ begin = config["tp_zero_rank"]
+ else:
+ begin = config["zero_rank"]
+ end = begin + 1
+
+ # Pytorch 1.11 changed the API of storage.__getitem__
+ torch.tensor([], dtype=param.dtype, device=param.device).set_(
+ param.storage()
+ )[:] = torch.tensor([], dtype=param.dtype, device=param.device).set_(
+ tmp_tensor.storage()
+ )[
+ partition_size * begin : partition_size * end
+ ]
+ # param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)])
+
+
+def iterate_parameters(model: torch.nn.Module):
+ """
+    Iterate over a module's direct parameters, yielding nothing if any of them is managed by a Block.
+ """
+ for kw, val in model._parameters.items():
+ if hasattr(val, "_in_block") and val._in_block:
+ return []
+ yield val
+
+
+def init_parameters(model: torch.nn.Module):
+ """
+ Initialize the parameters of the model by calling the init_method of the distributed parameters.
+ """
+
+ modules = model.named_modules()
+ for module_prefix, module in modules:
+ if isinstance(module, Block):
+ module.init_parameters()
+ else:
+ init_distributed_parameter(iterate_parameters(module))
+
+ current_stream = torch.cuda.current_stream()
+ config["load_stream"].wait_stream(current_stream)
+
+
+def grouped_parameters(
+ model: torch.nn.Module,
+) -> Generator[Tuple[str, List[torch.nn.Parameter]], None, None]:
+ """
+ Iterate over the parameters of the model grouped by the group name.
+    This is similar to `torch.nn.Module.named_parameters()`.
+ """
+
+    ret: Dict[str, List[torch.nn.Parameter]] = {}
+ for module in model.modules():
+ if isinstance(module, Block):
+ for kw, params in module.grouped_parameters():
+ if kw not in ret:
+ ret[kw] = []
+ ret[kw].extend(params)
+ else:
+ for param in module._parameters.values():
+ group = None
+ if isinstance(param, DistributedParameter):
+ group = param.group
+ if group not in ret:
+ ret[group] = []
+ ret[group].append(param)
+ for kw, val in ret.items():
+ yield kw, val
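
Typical call order, assuming a model composed of BMTrain layers (`MyModel` is a hypothetical module): build the module after `bmt.init_distributed()`, then run `init_parameters` once so every `DistributedParameter` is initialized at its full shape and re-partitioned:

    import bmtrain as bmt
    from bmtrain.param_init import init_parameters

    bmt.init_distributed()
    model = MyModel()        # hypothetical module built from bmt.nn layers
    init_parameters(model)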
diff --git a/examples/BMTrain/bmtrain/parameter.py b/examples/BMTrain/bmtrain/parameter.py
new file mode 100644
index 00000000..2dad4a3d
--- /dev/null
+++ b/examples/BMTrain/bmtrain/parameter.py
@@ -0,0 +1,206 @@
+from typing import Callable, Iterable, Optional
+import torch
+from .utils import round_up
+from .global_var import config
+from . import nccl
+from .distributed import all_gather
+
+
+class DistributedParameter(torch.nn.Parameter):
+ r"""
+ DistributedParameter is a subclass of torch.nn.Parameter.
+
+ It scatters the tensor to all the nodes and gathers them when needed.
+
+ Args:
+ data (Tensor): parameter tensor.
+ requires_grad (bool, optional): whether the parameter requires gradient.
+ init_method (Callable[['DistributedParameter'], None], optional): the method used to initialize the parameter.
+ group (str, optional): the group name of the parameter.
+ tp_mode (bool, optional): whether the parameter participates in tensor parallelism.
+ tp_split_dim (int, optional): the dimension along which the parameter is split in tensor parallel mode; -1 means no split.
+
+ **Note**: DistributedParameter must be on the CUDA device. It will be moved to the device automatically when `__init__` is called.
+
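+ Example:
+ >>> # a minimal sketch; assumes bmtrain.init_distributed() has been called
+ >>> weight = bmtrain.DistributedParameter(torch.empty(1024, 1024), init_method=torch.nn.init.xavier_normal_)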
+ """
+
+ _original_shape: torch.Size
+ _start_partition: int
+ _end_partition: int
+ _init_method: Optional[Callable[["DistributedParameter"], None]]
+ _in_block: bool
+ _group: Optional[str]
+
+ def __new__(
+ cls,
+ data: torch.Tensor,
+ requires_grad: bool = True,
+ init_method: Optional[Callable[["DistributedParameter"], None]] = None,
+ group: Optional[str] = None,
+ tp_mode: bool = False,
+ tp_split_dim: int = -1,
+ ):
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ num_of_elements = data.numel()
+
+ cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda")
+ if tp_mode:
+ comm = config["tp_zero_comm"]
+ else:
+ comm = config["zero_comm"]
+ world_size = nccl.commCount(comm)
+ rank = nccl.commRank(comm)
+ cuda_storage_size = round_up(num_of_elements, world_size) // world_size
+
+ original_shape = data.size()
+ tp_original_shape = original_shape
+ if tp_mode and tp_split_dim >= 0:
+ tp_original_shape = list(original_shape)
+ tp_original_shape[tp_split_dim] *= config["tp_size"]
+
+ cuda_storage = cuda_tensor.storage_type()(cuda_storage_size)
+
+ start_of_partition = cuda_storage_size * rank
+ end_of_partition = min(num_of_elements, cuda_storage_size * (rank + 1))
+
+ # FX: cuda_tensor_size < 0 if num_of_elements is too small
+ cuda_tensor_size = max(end_of_partition - start_of_partition, 0)
+
+ cuda_tensor.set_(cuda_storage, 0, (cuda_tensor_size,))
+ cuda_tensor.copy_(data.view(-1)[start_of_partition:end_of_partition])
+ ret = torch.Tensor._make_subclass(cls, cuda_tensor, requires_grad)
+
+ setattr(ret, "_original_shape", original_shape)
+ setattr(ret, "_start_partition", start_of_partition)
+ setattr(ret, "_end_partition", end_of_partition)
+ setattr(ret, "_init_method", init_method)
+ setattr(ret, "_in_block", False)
+ setattr(ret, "_group", group if not tp_mode else "tp")
+
+ setattr(ret, "_tp_mode", tp_mode)
+ setattr(ret, "_zero_comm", comm)
+ setattr(ret, "_tp_split_dim", tp_split_dim)
+ setattr(ret, "_tp_original_shape", tp_original_shape)
+ return ret
+
+ @property
+ def group(self):
+ """The group name of the distributed parameter."""
+
+ return self._group
+
+ def gather(self) -> torch.Tensor:
+ """Gather the data from ZeRO distributed nodes.
+
+ Return:
+ torch.Tensor: The gathered data.
+
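+ Example:
+ >>> # sketch: each rank stores only a partition; gather() rebuilds the full tensor on every rank
+ >>> full_weight = weight.gather()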
+ """
+ with torch.cuda.stream(config["load_stream"]):
+ output_tensor = OpAllGather.apply(self)
+ current_stream = torch.cuda.current_stream()
+ output_tensor.record_stream(current_stream)
+ current_stream.wait_stream(config["load_stream"])
+ return output_tensor
+
+ def gather_all(self) -> torch.Tensor:
+ """Gather the data from ZeRO and Tensor Parallel distributed nodes.
+
+ Return:
+ torch.Tensor: The gathered data.
+
+ """
+ zero_param = self.gather()
+ if config["tp_size"] > 1 and self._tp_split_dim >= 0:
+ output_tensor = all_gather(zero_param, config["tp_comm"])
+ if self._tp_split_dim == 1:
+ output_list = output_tensor.chunk(config["tp_size"], dim=0)
+ output = torch.cat(output_list, dim=output_list[0].dim() - 1).flatten(
+ 0, 1
+ )
+ return output
+ else:
+ return output_tensor.flatten(0, 1)
+ else:
+ return zero_param
+
+ def tp_gather(self) -> torch.Tensor:
+ """Gather the data from Tensor Parallel distributed nodes.
+
+ Return:
+ torch.Tensor: The gathered data.
+
+ """
+ if config["tp_size"] > 1 and self._tp_split_dim >= 0:
+ output_tensor = all_gather(self, config["tp_comm"])
+ if self._tp_split_dim == 1:
+ output_list = output_tensor.chunk(config["tp_size"], dim=0)
+ output = torch.cat(output_list, dim=output_list[0].dim() - 1).flatten(
+ 0, 1
+ )
+ return output
+ else:
+ return output_tensor.flatten(0, 1)
+ else:
+ return self
+
+ def _copy_data(self, data: torch.Tensor):
+ """Copy data to self.data."""
+ self.data.copy_(data.view(-1)[self._start_partition : self._end_partition])
+
+
+class OpAllGather(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, value: DistributedParameter):
+ assert isinstance(value, DistributedParameter)
+ comm = value._zero_comm # config['zero_comm']
+ world_size = nccl.commCount(comm)
+ ctx.comm = comm
+ ctx.world_size = world_size
+
+ partition_size = value.storage().size()
+ global_size = partition_size * world_size
+
+ storage = value.storage_type()(global_size)
+
+ nccl.allGather(value.storage(), storage, comm)
+
+ output_tensor = torch.tensor([], dtype=value.dtype, device="cuda")
+ output_tensor.set_(storage, 0, value._original_shape)
+
+ ctx.partition_size = partition_size
+ ctx.tensor_size = value.size(0)
+ return output_tensor
+
+ @staticmethod
+ def backward(ctx, grad_output: torch.Tensor):
+ if not grad_output.is_contiguous():
+ grad_output = grad_output.contiguous()
+
+ grad_storage = grad_output.storage_type()(ctx.partition_size)
+ grad_output_storage = grad_output.storage()
+ if grad_output_storage.size() == ctx.partition_size * ctx.world_size:
+ pass
+ else:
+ grad_output_storage.resize_(ctx.partition_size * ctx.world_size)
+ nccl.reduceScatter(grad_output_storage, grad_storage, "sum", ctx.comm)
+ grad_tensor = torch.tensor([], dtype=grad_output.dtype, device="cuda")
+ grad_tensor.set_(grad_storage, 0, (ctx.tensor_size,))
+ return grad_tensor
+
+
+class ParameterInitializer:
+ """
+ ParameterInitializer is a helper class that is used to initialize the distributed parameters.
+
+ Similar to `functools.partial`.
+
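+ Example:
+ >>> # a minimal sketch using a standard torch initializer
+ >>> init = bmtrain.ParameterInitializer(torch.nn.init.normal_, std=0.02)
+ >>> weight = bmtrain.DistributedParameter(torch.empty(1024, 1024), init_method=init)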
+ """
+
+ def __init__(self, func: Callable, *args, **kwargs) -> None:
+ self.func = func
+ self._args = args
+ self._kwargs = kwargs
+
+ def __call__(self, param: DistributedParameter):
+ self.func(param, *self._args, **self._kwargs)
diff --git a/examples/BMTrain/bmtrain/pipe_layer.py b/examples/BMTrain/bmtrain/pipe_layer.py
new file mode 100644
index 00000000..4d3b17ad
--- /dev/null
+++ b/examples/BMTrain/bmtrain/pipe_layer.py
@@ -0,0 +1,314 @@
+from collections import OrderedDict
+import copy
+from typing import Dict, Iterable, Iterator, Tuple, Union, List
+import torch
+
+from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations
+from .global_var import config
+from . import nccl
+from .zero_context import (
+ ZeroContext
+)
+from . import debug
+from .block_layer import Block, round_up, _get_param_kw, _block_wrapper
+
+class PipePreFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, hidden_state, *args):
+ hidden_state_list = all_gather(hidden_state.clone(), config["pipe_comm"])
+ hidden_state_list.requires_grad_()
+
+ batch_related = args[-1]
+ batch_related_origin = [True if i in args[-1] else False for i in range(len(args[:-1]))]
+ batch_related_rule = []
+ args = args[:-1]
+
+ batch_size = hidden_state.shape[0]
+ num_micros = config["micros"]
+ args_list = [[] for _ in range(num_micros)]
+ input_requires_grad = []
+ for arg in args:
+ if torch.is_tensor(arg):
+ arg_all = all_gather(arg, config['pipe_comm'])
+ if arg.dim() == hidden_state.dim() and arg.shape[0] == batch_size:
+ batch_related_rule.append(True)
+ arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0)
+ arg_all = [tensor.requires_grad_(arg.requires_grad) for tensor in arg_all]
+ else:
+ batch_related_rule.append(False)
+ arg_all = [arg_all[0].requires_grad_(arg.requires_grad) for i in range(num_micros)]
+ input_requires_grad.append(arg.requires_grad)
+ else:
+ batch_related_rule.append(False)
+ arg_all = [arg for _ in range(num_micros)]
+ input_requires_grad.append(False)
+ for i in range(num_micros):
+ args_list[i].append(arg_all[i])
+ ctx.input_requires_grad = input_requires_grad
+ ctx.args_list = args_list
+ if len(batch_related) == 0:
+ ctx.batch_related = batch_related_rule
+ else:
+ ctx.batch_related = batch_related_origin
+ return hidden_state_list, args_list
+
+ @staticmethod
+ def backward(ctx, grads, arg_grads):
+ grads = broadcast(grads, 0, config['pipe_comm'])
+ topo = config['topology']
+ arg_grads = []
+ num_micros = config['micros']
+ for idx,requires_grad in enumerate(ctx.input_requires_grad):
+ if requires_grad:
+ grad = torch.cat([ctx.args_list[m][idx].grad for m in range(num_micros)], dim=0)
+ grad = all_reduce(grad, "sum", config["pipe_comm"])
+ split_size = topo.stages if ctx.batch_related[idx] else num_micros
+ grad = grad.chunk(split_size)
+ if ctx.batch_related[idx]:
+ arg_grads.append(grad[topo.stage_id])
+ else:
+ arg_grads.append(grad[0])
+ else:
+ arg_grads.append(None)
+ arg_grads.append(None)  # placeholder gradient for the appended batch_related argument
+ return grads.chunk(topo.stages, dim=0)[topo.stage_id], *arg_grads
+
+class PipePostFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, last_hidden, hidden_states=None, forward_stage_ranges=None, backward_stage_ranges=None, last_hidden_shape=None, return_hidden_states=False):
+ topo = config['topology']
+ ctx.return_hidden_states = return_hidden_states
+ last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"])
+ last_hidden = last_hidden.chunk(topo.stages, dim=0)
+ output = last_hidden[topo.stage_id]
+ output.requires_grad_()
+
+ if return_hidden_states:
+ ctx.stage_id = topo.stage_id
+ ctx.stages = topo.stages
+ ctx.backward_stage_ranges = backward_stage_ranges
+ middle_hiddens = []
+ for stage_id in range(ctx.stages):
+ if ctx.stage_id == stage_id:
+ middle_hidden = hidden_states
+ else:
+ middle_shape = (forward_stage_ranges[stage_id],) + last_hidden_shape
+ middle_hidden = torch.zeros(middle_shape, device=hidden_states.device, dtype=hidden_states.dtype)
+ middle_hidden = broadcast(middle_hidden, stage_id, config["pipe_comm"])
+ middle_hidden = middle_hidden.chunk(ctx.stages, dim=1)
+ middle_hidden = middle_hidden[ctx.stage_id].clone()
+ middle_hiddens.append(middle_hidden)
+ middle_hiddens = torch.cat(middle_hiddens, dim=0)
+ middle_hiddens.requires_grad_()
+ return output, middle_hiddens
+ else:
+ return output
+
+ @staticmethod
+ def backward(ctx, grads, grad_middle=None):
+ grad_list = all_gather(grads, config["pipe_comm"])
+ grad_list = grad_list.flatten(start_dim=0, end_dim=1)
+
+ if ctx.return_hidden_states:
+ for stage_id in range(ctx.stages):
+ layer_range = ctx.backward_stage_ranges[stage_id]
+ grad_middle_state = grad_middle[layer_range]
+ grad_middle_state = all_gather(grad_middle_state.transpose(0,1), config["pipe_comm"])
+ grad_middle_state = grad_middle_state.flatten(start_dim=0, end_dim=1).transpose(0, 1)
+ if ctx.stage_id == stage_id:
+ grad_hidden_state_list = grad_middle_state
+ return grad_list, grad_hidden_state_list, None, None, None, None
+ else:
+ return grad_list
+
+class StagePreFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, stage_id):
+ ctx.stage_id = stage_id
+ ctx.is_first_stage = stage_id == 0
+ ctx.is_last_stage = stage_id == config['pipe_size'] - 1
+ if not ctx.is_first_stage:
+ input = recv_activations(stage_id - 1, config['pipe_comm'])
+ input.requires_grad_()
+ return input
+ return input
+
+ @staticmethod
+ def backward(ctx, grad_outputs):
+ if not ctx.is_first_stage:
+ send_data = grad_outputs[0] if isinstance(grad_outputs, tuple) else grad_outputs
+ current_stream = torch.cuda.current_stream()
+ with torch.cuda.stream(config['pp_comm_stream']):
+ config['pp_comm_stream'].wait_stream(current_stream)
+ send_data.record_stream(config['pp_comm_stream'])
+ send_activations(send_data, ctx.stage_id - 1, config['pipe_comm'])
+ return grad_outputs, None
+
+class StagePostFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, outputs, stage_id):
+ ctx.stage_id = stage_id
+ ctx.is_first_stage = stage_id == 0
+ ctx.is_last_stage = stage_id == config['pipe_size'] - 1
+ if not ctx.is_last_stage:
+ send_data = outputs[0] if isinstance(outputs, tuple) else outputs
+ current_stream = torch.cuda.current_stream()
+ with torch.cuda.stream(config['pp_comm_stream']):
+ config['pp_comm_stream'].wait_stream(current_stream)
+ send_data.record_stream(config['pp_comm_stream'])
+ send_activations(send_data.detach(), stage_id + 1, config['pipe_comm'])
+ return outputs
+
+ @staticmethod
+ def backward(ctx, grad_outputs):
+ if not ctx.is_last_stage:
+ pre_grad_inputs = recv_activations(ctx.stage_id + 1, config['pipe_comm'])
+ return pre_grad_inputs, None
+ return grad_outputs, None
+
+
+class PipelineTransformerBlockList(torch.nn.Module):
+ r"""
+ PipelineTransformerBlockList is a list of Blocks executed with pipeline parallelism.
+
+ This is designed to reduce the communication overhead by overlapping the computation and the reduce_scatter operation during the backward pass.
+
+ It is similar to `torch.nn.ModuleList`, but differs in how `.forward()` and `.backward()` are called.
+
+ Example:
+ >>> module_list = [ ... ]
+ >>> normal_module_list = torch.nn.ModuleList(module_list)
+ >>> transformer_module_list = PipelineTransformerBlockList(module_list)
+ >>> # Calling normal module list
+ >>> for layer in normal_module_list:
+ >>> hidden_state = layer.forward(hidden_state, ...)
+ >>> # Calling transformer module list
+ >>> hidden_state = transformer_module_list(hidden_state, ...)
+
+ """
+ _modules: Dict[str, Block]
+
+ def __init__(self, modules: Iterable[torch.nn.Module], num_hidden=1) -> None:
+ super().__init__()
+ self.num_hidden = num_hidden
+ self._modules = {}
+ self.layer_ids = []
+ topo = config["topology"]
+ self.stages = topo.stages
+ self.stage_id = topo.stage_id
+ self.pipe_idx = topo.pipe_idx
+ module_dict = {}
+ for idx, module in enumerate(modules):
+ module = _block_wrapper(module, module_dict, "PIPE")
+ module._zero_level = 2  # currently, only ZeRO-2 is supported in pipeline mode
+ self._modules[str(idx)] = module
+
+ self.layer_ids = self.get_range_by_stage_id(self.stage_id)
+
+ pre_module = None
+ for i,layer_id in enumerate(self.layer_ids):
+ module = self._modules[str(layer_id)]
+ module.set_pre_module(pre_module)
+ pre_module = module
+ module._is_first_layer = False
+ module._is_last_layer = False
+
+ self._modules[str(self.layer_ids[0])]._is_first_layer = True
+ self._modules[str(self.layer_ids[-1])]._is_last_layer = True
+
+ def __len__(self) -> int:
+ return len(self._modules)
+
+ def __iter__(self) -> Iterator[Block]:
+ return iter(self._modules.values())
+
+ def __getitem__(self, index: Union[int, str]) -> Block:
+ return self._modules[str(index)]
+
+ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=False):
+ self.return_hidden_states = return_hidden_states
+ batch_size = hidden_state.shape[0]
+ num_micros = config["micros"]
+ args = args + (batch_related, )
+ hidden_state.requires_grad_()
+ hidden_state_list, args_list = PipePreFunction.apply(hidden_state, *args)
+
+ hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0)
+ outputs = []
+ hidden_states = []
+
+ for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)):
+ micro_hidden_states = []
+
+ hidden_state = StagePreFunction.apply(hidden_state, self.stage_id)
+
+ for idx,layer_id in enumerate(self.layer_ids):
+ self._modules[str(layer_id)]._micro_idx = micro_idx
+ if return_hidden_states:
+ micro_hidden_states.append(hidden_state)
+ hidden_state = self._modules[str(layer_id)](hidden_state, *arg)
+ hidden_state = StagePostFunction.apply(hidden_state, self.stage_id)
+
+ outputs.append(hidden_state)
+ if return_hidden_states:
+ hidden_states.append(torch.stack(micro_hidden_states, dim=0))
+
+ last_hidden = torch.cat(outputs, dim=0)
+ last_hidden_shape = last_hidden.shape
+
+ if return_hidden_states:
+ hidden_states = torch.cat(hidden_states, dim=1)
+ forward_stage_ranges = []
+ backward_stage_ranges = []
+ for stage_id in range(self.stages):
+ forward_stage_ranges.append(self.get_part_len_by_stage_id(stage_id))
+ backward_stage_ranges.append(self.get_range_by_stage_id(stage_id))
+ outputs, hidden_states = PipePostFunction.apply(last_hidden, hidden_states, forward_stage_ranges, backward_stage_ranges, last_hidden_shape, return_hidden_states)
+ return outputs, hidden_states
+ else:
+ outputs = PipePostFunction.apply(last_hidden)
+ return outputs
+
+ def get_range_by_stage_id(self, stage_id : int) -> List[int]:
+ part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)]
+ start = sum(part_lens[:stage_id+1])
+ end = start + part_lens[stage_id+1]
+ return list(range(start, end))
+
+ def get_part_len_by_stage_id(self, stage_id : int) -> int:
+ return len(self) // self.stages + (stage_id < (len(self) % self.stages))
+
+ def get_stage_by_layer_id(self, layer_id : int) -> int:
+ part_len = len(self) // self.stages
+ rest = len(self) % self.stages
+ if layer_id // (part_len + 1) < rest:
+ return layer_id // (part_len + 1)
+ else:
+ return rest + (layer_id - rest * (part_len+1)) // part_len
+
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
+ for name, module in self._modules.items():
+ idx = int(name)
+ name = prefix + name + '.'
+
+ dst = OrderedDict()  # create a temporary ordered dict
+ dst._metadata = OrderedDict()
+
+ if idx in self.layer_ids:
+ with torch.no_grad():
+ with ZeroContext(module, pipe=True):
+ module._module.state_dict(destination=dst, prefix=name, keep_vars=False)
+
+ if config["topology"].pp_zero_id == 0:
+ if config["rank"] == 0:
+ destination.update(dst)
+ else:
+ assert list(dst.keys()) == [name+n for n, parameter in module._module.named_parameters()]
+ for key, tensor in dst.items():
+ send_activations(tensor.cuda(), 0, config['pipe_comm'])
+ if config['rank'] == 0 and idx not in self.layer_ids:
+ for n, parameter in module._module.named_parameters():
+ destination[name+n] = recv_activations(self.get_stage_by_layer_id(idx), config['pipe_comm']).cpu()
+
diff --git a/examples/BMTrain/bmtrain/store.py b/examples/BMTrain/bmtrain/store.py
new file mode 100644
index 00000000..2a3ee02c
--- /dev/null
+++ b/examples/BMTrain/bmtrain/store.py
@@ -0,0 +1,325 @@
+from collections import OrderedDict
+from typing import Dict
+import torch
+
+from .pipe_layer import PipelineTransformerBlockList
+from .block_layer import TransformerBlockList
+from .global_var import config
+from .block_layer import Block
+from . import nccl
+import io, pickle
+from typing import Mapping
+import threading
+import bmtrain as bmt
+
+def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix):
+ if isinstance(model, Block):
+ if rank != 0:
+ destination = OrderedDict()  # create a temporary ordered dict
+ destination._metadata = OrderedDict()
+ model.state_dict(destination=destination, prefix=prefix, keep_vars=False)
+ else:
+ if rank != 0:
+ destination = OrderedDict()  # create a temporary ordered dict
+ destination._metadata = OrderedDict()
+ model._save_to_state_dict(destination, prefix, False)
+
+def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''):
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+ _save_to_state_dict(model, config['local_rank'], destination, prefix)
+ for name, module in model._modules.items():
+ if module is not None:
+ _save_to_local_rank0(module, destination, prefix + name + '.')
+ for hook in model._state_dict_hooks.values():
+ hook_result = hook(model, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ return destination
+
+
+def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+ if not isinstance(model, PipelineTransformerBlockList):
+ _save_to_state_dict(model, config['rank'], destination, prefix)
+ for name, module in model._modules.items():
+ if module is not None:
+ _save_to_rank0(module, destination, prefix + name + '.')
+ for hook in model._state_dict_hooks.values():
+ hook_result = hook(model, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ else:
+ model._save_to_state_dict(destination, prefix, False)
+ return destination
+
+def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None, prefix=''):
+ config['save_param_to_cpu'] = False
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+ _save_to_state_dict(model, config['local_rank'], destination, prefix)
+ for name, module in model._modules.items():
+ if module is not None:
+ if isinstance(module, TransformerBlockList):
+ for local_name, local_module in module._modules.items():
+ local_state_dict = _save_to_local_rank0(local_module, None, prefix + name + "." + local_name + '.')
+ if config['local_rank'] == 0:
+ infer_model.load_layer_state_dict(local_state_dict)
+ else:
+ _save_to_infer_model(module, infer_model, destination, prefix + name + '.')
+ for hook in model._state_dict_hooks.values():
+ hook_result = hook(model, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+
+ if config['local_rank'] == 0:
+ infer_model.load_layer_state_dict(destination)
+
+
+def async_save_to_file(state_dict, file_path):
+ torch.save(state_dict, file_path)
+ config['finish_save'] = True
+ print("finish save state_dict to ", file_path)
+
+def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False):
+ """Saves the model to the file.
+
+ Similar to torch.save, but it is used for distributed modules.
+
+ Args:
+ model (torch.nn.Module): The model to be saved.
+ file_name (str): The file name of the checkpoint.
+ non_blocking (bool): Whether to save the state_dict to the file asynchronously.
+
+
+ Examples:
+ >>> bmtrain.save(model, "model.pt")
+ """
+ torch.cuda.synchronize()
+ state_dict = _save_to_rank0(model)
+ if config["rank"] == 0:
+ if non_blocking is False:
+ torch.save(state_dict, file_name)
+ else:
+ if 'finish_save' not in config:
+ config['finish_save'] = True
+
+ if config['finish_save'] is False:
+ config['save_thread'].join()
+
+ config['finish_save'] = False
+ config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name))
+ config['save_thread'].start()
+ bmt.synchronize()
+
+DTYPE_LIST = [
+ torch.float64,
+ torch.float32,
+ torch.float16,
+ torch.int64,
+ torch.int32,
+ torch.int16,
+ torch.int8,
+ torch.bfloat16,
+ torch.bool
+]
+
+_pickler = pickle.Pickler
+_unpickler = pickle.Unpickler
+
+def allgather_objects(obj):
+ if bmt.world_size() == 1:
+ return [obj]
+
+ with torch.no_grad():
+ data_bytes: bytes = pickle.dumps(obj)
+ data_length: int = len(data_bytes)
+
+ gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long)
+ gathered_length = bmt.distributed.all_gather(gpu_data_length).view(-1).cpu()
+ max_data_length = gathered_length.max().item()
+
+ gpu_data_bytes = torch.zeros(max_data_length, dtype=torch.uint8, device="cuda")
+ byte_storage = torch.ByteStorage.from_buffer(data_bytes)
+ gpu_data_bytes[:data_length] = torch.ByteTensor(byte_storage)
+
+ gathered_data = bmt.distributed.all_gather(gpu_data_bytes).cpu()
+
+ ret = []
+ for i in range(gathered_data.size(0)):
+ data_bytes = gathered_data[i, : gathered_length[i].item()].numpy().tobytes()
+ ret.append(pickle.loads(data_bytes))
+ return ret
+
+def broadcast_object(obj, comm, src = 0):
+ if nccl.commRank(comm) == src:
+ f = io.BytesIO()
+ _pickler(f).dump(obj)
+ byte_storage = torch.ByteStorage.from_buffer(f.getvalue())
+ # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor(..., dtype=...).
+ # Otherwise, it will cause a 100x slowdown.
+ # See: https://github.com/pytorch/pytorch/issues/65696
+ byte_tensor = torch.ByteTensor(byte_storage).cuda()
+ local_size = torch.LongTensor([byte_tensor.numel()]).cuda()
+
+ nccl.broadcast(
+ local_size.storage(),
+ local_size.storage(),
+ src,
+ comm
+ )
+ nccl.broadcast(
+ byte_tensor.storage(),
+ byte_tensor.storage(),
+ src,
+ comm
+ )
+ else:
+ local_size = torch.LongTensor([0]).cuda()
+ nccl.broadcast(
+ local_size.storage(),
+ local_size.storage(),
+ src,
+ comm
+ )
+ byte_tensor_size = local_size[0].item()
+ byte_tensor = torch.empty(int(byte_tensor_size), dtype=torch.uint8, device="cuda")
+ nccl.broadcast(
+ byte_tensor.storage(),
+ byte_tensor.storage(),
+ src,
+ comm
+ )
+ buf = byte_tensor.cpu().numpy().tobytes()
+ obj = _unpickler(io.BytesIO(buf)).load()
+ return obj
+
+# Must be a Mapping after pytorch 1.12.0
+class DistributedTensorWrapper:
+ def __init__(self, tensor, shape=None):
+ self._dtype = tensor.dtype
+ self._device = tensor.device
+ self.shape = shape
+ self.tensor = tensor
+
+ def broadcast(self):
+ output_param = torch.empty(self.shape, dtype=self._dtype, device="cuda")
+ if config['rank'] == 0:
+ input_param = self.tensor
+ if input_param.is_cuda:
+ input_param = input_param.clone().contiguous()
+ else:
+ input_param = input_param.cuda().contiguous()
+
+ nccl.broadcast(
+ input_param.storage(),
+ output_param.storage(),
+ 0,
+ config['comm']
+ )
+ else:
+ nccl.broadcast(
+ output_param.storage(),
+ output_param.storage(),
+ 0,
+ config['comm']
+ )
+ return output_param
+
+ def copy(self):
+ return self.tensor
+
+ def __getattribute__(self, name):
+ if name == "tensor" or name == "shape":
+ return object.__getattribute__(self, name)
+ else:
+ try:
+ return object.__getattribute__(self, name)
+ except AttributeError:
+ pass
+
+ return getattr(self.tensor, name)
+
+class DistributedStateDictWrapper(Mapping):
+ def __init__(self, state_dict : Dict) -> None:
+ self._state_dict = state_dict
+ self._metadata = broadcast_object(getattr(state_dict, "_metadata", None), config["comm"])
+
+ def __getitem__(self, key : str):
+ tmp_shape = torch.zeros(32, device="cuda", dtype=torch.int32)
+ if config['rank'] == 0:
+ input_param : torch.Tensor = self._state_dict[key]
+ shape_list = torch.tensor(list(input_param.size()), device="cuda", dtype=torch.int32)
+ dtype_idx = DTYPE_LIST.index(input_param.dtype)
+
+ assert dtype_idx != -1, "Unknown data type %s" % input_param.dtype
+
+ tmp_shape[0] = shape_list.size(0)
+ tmp_shape[1] = dtype_idx
+ tmp_shape[2:2 + shape_list.size(0)] = shape_list
+
+ nccl.broadcast(
+ tmp_shape.storage(),
+ tmp_shape.storage(),
+ 0,
+ config['comm']
+ )
+
+ shape_list_size = tmp_shape[0].item()
+ dtype_idx = tmp_shape[1].item()
+ shape_list = torch.Size(tmp_shape[2: 2 + shape_list_size].tolist())
+
+ if config['rank'] != 0:
+ return DistributedTensorWrapper(torch.tensor([], dtype=DTYPE_LIST[dtype_idx], device="cuda"), shape=shape_list)
+ else:
+ return DistributedTensorWrapper(self._state_dict[key], shape=shape_list)
+
+
+
+ def copy(self):
+ return self
+
+ def __len__(self):
+ return broadcast_object(len(self._state_dict), config["comm"])
+
+ def __contains__(self, key : str):
+ return broadcast_object(key in self._state_dict, config["comm"])
+
+ def keys(self):
+ return broadcast_object(list(self._state_dict.keys()),config["comm"])
+
+ def __iter__(self):
+ # pytorch 1.12.0 updated the load_state_dict method, which needs the state_dict to be a `Mapping`.
+ return iter(self.keys())
+
+def load(model : torch.nn.Module, file_name : str, strict : bool = True):
+ """Loads the model from the file.
+
+ Similar to torch.load, but it uses less memory when loading large models.
+
+ Args:
+ model (torch.nn.Module): The model to be loaded.
+ file_name (str): The file name of the checkpoint.
+ strict (bool): Strict option of `load_state_dict`.
+
+ Example:
+ >>> bmtrain.load(model, "model.pt", strict=True)
+ """
+ if config['rank'] == 0:
+ state_dict = DistributedStateDictWrapper(torch.load(file_name))
+ else:
+ state_dict = DistributedStateDictWrapper({})
+
+ ret = model.load_state_dict(
+ state_dict,
+ strict = strict
+ )
+ torch.cuda.synchronize()
+ return ret
diff --git a/examples/BMTrain/bmtrain/synchronize.py b/examples/BMTrain/bmtrain/synchronize.py
new file mode 100644
index 00000000..87619159
--- /dev/null
+++ b/examples/BMTrain/bmtrain/synchronize.py
@@ -0,0 +1,73 @@
+import torch
+from . import distributed, nccl
+from .global_var import config
+import warnings
+from typing import Optional
+
+
+def synchronize():
+ """
+ Synchronize all the workers across all nodes. (both CPU and GPU are synchronized)
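+
+ Example:
+ >>> bmtrain.synchronize() # block until every rank reaches this point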
+ """
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ with torch.cuda.stream(config["barrier_stream"]):
+ barrier = torch.cuda.FloatTensor([1])
+ nccl.allReduce(barrier.storage(), barrier.storage(), "sum", config["comm"])
+ config["barrier_stream"].synchronize()
+
+
+def wait_loader():
+ """
+ Wait for the latest loader event to finish, then record a new one on the calculation stream (normally the current stream).
+ """
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ config["load_event"].synchronize()
+ config["calc_stream"].record_event(config["load_event"])
+
+
+def sum_loss(loss: torch.Tensor, comm: Optional[nccl.NCCLCommunicator] = None):
+ """
+ Sum the loss across all workers.
+
+ This is a helper function to reduce the loss across all workers.
+ """
+ if comm is None:
+ comm = config["comm"]
+ warnings.warn(
+ "bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.",
+ DeprecationWarning,
+ )
+
+ return distributed.all_reduce(loss, "avg", comm)
+
+
+def gather_result(result: torch.Tensor):
+ """
+ Gather result across all workers.
+ """
+ warnings.warn(
+ "bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.",
+ DeprecationWarning,
+ )
+ if result.storage_offset() != 0 or result.storage().size() != result.numel():
+ # Create a clone of the original tensor if it's a slice
+ result = result.clone()
+
+ output_cuda = True
+ if not result.is_cuda:
+ result = result.cuda()
+ output_cuda = False
+ ret = torch.empty(
+ (result.shape[0] * config["world_size"], *list(result.shape[1:])),
+ device=result.device,
+ dtype=result.dtype,
+ )
+ nccl.allGather(result.storage(), ret.storage(), config["comm"])
+ if output_cuda:
+ return ret
+ else:
+ return ret.cpu()
diff --git a/examples/BMTrain/bmtrain/utils.py b/examples/BMTrain/bmtrain/utils.py
new file mode 100644
index 00000000..daa4c595
--- /dev/null
+++ b/examples/BMTrain/bmtrain/utils.py
@@ -0,0 +1,184 @@
+import torch
+import sys
+from typing import Any, Dict, Iterable, Optional
+from .global_var import config
+import os
+import ctypes
+
+ALIGN = 4
+ROW_WIDTH = 60
+
+
+def check_torch_version(version_str):
+ """
+ Checks whether the current torch version is greater than or equal to the given version.
+ version_str (str): The version to compare against, in the "x.y.z" format; it is converted to the integer value x*10000 + y*100 + z, and the function returns the signed difference between the installed version and this value.
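+
+ Example:
+ >>> check_torch_version("1.12.0") >= 0 # True when the installed torch is at least 1.12.0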
+ """
+ version_int_arr = [int(v) for v in version_str.split(".")]
+
+ version_int = (
+ version_int_arr[0] * 10000 + version_int_arr[1] * 100 + version_int_arr[2]
+ )
+ torch_version = torch.__version__.split("+")[0]
+ current_version_int_arr = [int(v) for v in torch_version.split(".")]
+ current_version_int = (
+ current_version_int_arr[0] * 10000
+ + current_version_int_arr[1] * 100
+ + current_version_int_arr[2]
+ )
+ return current_version_int - version_int
+
+
+def load_nccl_pypi():
+ """
+ Check whether the pip-installed NCCL is available and load its shared libraries.
+ """
+ try:
+ import nvidia.nccl
+ except ImportError:
+ raise ImportError("Run pip install nvidia-nccl-cu11 >=2.14.3 first")
+
+ path = os.path.join(os.path.dirname(nvidia.nccl.__file__), "lib")
+ for file_so in os.listdir(path):
+ file_split = file_so.split(".")
+ if file_split[-1] == "so" or (len(file_split) > 1 and file_split[-2] == "so"):
+ ctypes.CDLL(os.path.join(path, file_so))
+
+
+def round_up(x, d):
+ """
+ Return (x + d - 1) // d * d
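+
+ Example:
+ >>> round_up(10, 4)
+ 12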
+ """
+ return (x + d - 1) // d * d
+
+
+def print_dict(title: str, content: Dict[str, Any], file=sys.stdout):
+ """
+ Print Dict to file.
+ """
+ max_kw_len = max([len(kw) for kw in content.keys()])
+ max_kw_len = round_up(max_kw_len + 3, 4)
+
+ raw_content = ""
+
+ for kw, val in content.items():
+ raw_content += kw + " :" + " " * (max_kw_len - len(kw) - 2)
+ raw_val = "%s" % val
+
+ len_val_row = ROW_WIDTH - max_kw_len
+ st = 0
+ if len(raw_val) == 0:
+ raw_val = " "
+ while st < len(raw_val):
+ if st > 0:
+ raw_content += " " * max_kw_len
+ raw_content += raw_val[st : st + len_val_row] + "\n"
+ st += len_val_row
+
+ print_block(title, raw_content, file)
+
+
+def print_block(title: str, content: Optional[str] = None, file=sys.stdout):
+ """
+ Print content to file.
+ """
+ left_title = (ROW_WIDTH - len(title) - 2) // 2
+ right_title = ROW_WIDTH - len(title) - 2 - left_title
+
+ print("=" * left_title + " " + title + " " + "=" * right_title, file=file)
+ if content is not None:
+ print(content, file=file)
+
+
+def print_rank(*args, rank=0, **kwargs):
+ """
+ Prints the message only on the `rank` of the process.
+
+ Args:
+ *args: The arguments to be printed.
+ rank (int): The rank id of the process to print.
+ **kwargs: The keyword arguments to be printed.
+
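+ Example:
+ >>> bmt.print_rank("only rank 0 prints this line")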
+ """
+ if config["rank"] == rank:
+ print(*args, **kwargs)
+
+
+def see_memory(message, detail=False):
+ """
+ Outputs a message followed by GPU memory status summary on rank 0.
+ At the end of the function, the starting point in tracking maximum GPU memory will be reset.
+
+ Args:
+ message (str): The message to be printed. It can be used to distinguish between other outputs.
+ detail (bool): Whether to print the memory status in a detailed or concise way. Defaults to False.
+
+ Example:
+ >>> bmt.see_memory("before forward")
+ >>> # forward_step()
+ >>> bmt.see_memory("after forward")
+
+ """
+ print_rank(message)
+ if detail:
+ print_rank(torch.cuda.memory_summary())
+ else:
+ print_rank(
+ f"""
+ =======================================================================================
+ memory_allocated {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB
+ max_memory_allocated {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB
+ =======================================================================================
+ """
+ )
+ torch.cuda.reset_peak_memory_stats()
+
+
+def tp_split_tensor(tensor, split_dim):
+ """
+ Returns this rank's slice (config["topology"].tp_id) of the tensor after splitting it along `split_dim`.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be split.
+ split_dim (int): The dimension along which to split the input tensor.
+
+ """
+ tensor_list = tensor.chunk(config["tp_size"], dim=split_dim)
+ sub_tensor = tensor_list[config["topology"].tp_id].contiguous()
+ tmp_tensor = torch.empty(
+ sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype
+ )
+ tmp_tensor.copy_(sub_tensor)
+ return tmp_tensor
+
+
+class AverageRecorder:
+ """A utility class to record the average value of a quantity over time.
+
+ Args:
+ alpha (float): The decay factor of the average.
+ start_value (float): The initial value of the average.
+
+ Use `.value` to get the current average value.
+ It is calculated as `alpha * old_value + (1 - alpha) * new_value`.
+
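+ Example:
+ >>> # a minimal sketch; `loss` stands for the current training step's loss tensor
+ >>> avg_loss = AverageRecorder()
+ >>> avg_loss.record(loss.item())
+ >>> print(avg_loss.value)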
+ """
+
+ def __init__(self, alpha=0.9, start_value=0):
+ self._value = start_value
+ self.alpha = alpha
+ self._steps = 0
+
+ def record(self, v):
+ """Records a new value.
+ Args:
+ v (float): The new value.
+ """
+ self._value = self._value * self.alpha + v * (1 - self.alpha)
+ self._steps += 1
+
+ @property
+ def value(self):
+ if self._steps <= 0:
+ return self._value
+ return self._value / (1 - pow(self.alpha, self._steps))
diff --git a/examples/BMTrain/bmtrain/wrapper.py b/examples/BMTrain/bmtrain/wrapper.py
new file mode 100644
index 00000000..e64fd5ba
--- /dev/null
+++ b/examples/BMTrain/bmtrain/wrapper.py
@@ -0,0 +1,54 @@
+import torch
+from .block_layer import Block, TransformerBlockList
+from .layer import DistributedModule, DistributedParameter
+
+
+def make_distributed(model: torch.nn.Module):
+ for kw in list(model._parameters.keys()):
+ if model._parameters[kw] is not None:
+ if not isinstance(model._parameters[kw], DistributedParameter):
+ model._parameters[kw] = DistributedParameter(
+ model._parameters[kw],
+ requires_grad=model._parameters[kw].requires_grad,
+ )
+
+ for kw in list(model._buffers.keys()):
+ if model._buffers[kw] is not None:
+ model._buffers[kw] = model._buffers[kw].cuda()
+ is_module_list = isinstance(model, torch.nn.ModuleList)
+ pre_module = None
+ for kw in list(model._modules.keys()):
+ if is_module_list:
+ if not isinstance(model._modules[kw], Block):
+ model._modules[kw] = Block(model_wrapper_dispatch(model._modules[kw]))
+ if pre_module is not None:
+ model._modules[kw].set_pre_module(pre_module)
+ pre_module = model._modules[kw]
+ else:
+ model._modules[kw] = model_wrapper_dispatch(model._modules[kw])
+
+ model.__class__ = type(
+ "bmtrain.Distributed" + model.__class__.__name__,
+ (model.__class__, DistributedModule),
+ {},
+ )
+ return model
+
+
+def model_wrapper_dispatch(model: torch.nn.Module):
+ if isinstance(model, TransformerBlockList):
+ return model
+ elif isinstance(model, DistributedModule):
+ return model
+ elif isinstance(model, Block):
+ return model
+ else:
+ return make_distributed(model)
+
+
+def BMTrainModelWrapper(model: torch.nn.Module) -> torch.nn.Module:
+ """
+ Automatically wrap a model in a BMTrain model.
+ Replaces all parameters with DistributedParameter, all modules with DistributedModule, and modules in ModuleList with Block.
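+
+ Example:
+ >>> # a minimal sketch; assumes bmtrain.init_distributed() has been called and `net` is a torch.nn.Module
+ >>> net = BMTrainModelWrapper(net)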
+ """
+ return model_wrapper_dispatch(model)
diff --git a/examples/BMTrain/bmtrain/zero_context.py b/examples/BMTrain/bmtrain/zero_context.py
new file mode 100644
index 00000000..8a74b3f8
--- /dev/null
+++ b/examples/BMTrain/bmtrain/zero_context.py
@@ -0,0 +1,203 @@
+import torch
+from . import nccl
+from .global_var import config
+from .synchronize import wait_loader
+
+
+class ZeroContext:
+ """ZeroContext is a helper class to Gather parameters before module forward and reduce scatter
+ gradients after module backward.
+
+ Args:
+ block (Block): Input Block.
+ ctx_dict (dict): block._layer_dict.
+ pipe (bool): True if pipeline parallelism is used.
+
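+ Example:
+ >>> # sketch of the intended usage (normally driven by Block itself); `block` is assumed to be a bmtrain Block
+ >>> with ZeroContext(block, block._layer_dict):
+ >>> ... # full parameters are available inside the context and released on exit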
+ """
+
+ def __init__(self, block: "Block", ctx_dict: dict = None, pipe=False) -> None:
+ self.block = block
+ self.ctx_dict = ctx_dict
+ self._param_buffer = {}
+ self._grad_buffer = {}
+ self._param_tensor = {}
+ self._grad_tensor = {}
+ self._need_release = False
+
+ def enter(self, flag=0, requires_grad=False):
+ """
+ Gather parameters before module forward and init grad buffer before backward.
+ """
+ if self.block._ready:
+ return
+ self.block._ready = True
+ self._need_release = True
+
+ wait_loader()
+ with torch.cuda.stream(config["load_stream"]):
+ for kw, val in self.block._storage_info.items():
+ assert self.block._storage_params[kw].is_cuda
+ assert kw not in self._grad_buffer
+ assert kw not in self._param_buffer
+ local_param = self.block._storage_params[kw]
+
+ storage_type = local_param.storage_type()
+ if flag != 2:
+ self._param_buffer[kw] = storage_type(
+ val["partition_size"] * val["world_size"]
+ )
+ self._param_tensor[kw] = torch.tensor(
+ [],
+ dtype=self._param_buffer[kw].dtype,
+ device=self._param_buffer[kw].device,
+ ).set_(self._param_buffer[kw])
+
+ if requires_grad and local_param.requires_grad:
+ self._grad_buffer[kw] = storage_type(
+ val["partition_size"] * val["world_size"]
+ )
+ self._grad_tensor[kw] = (
+ torch.tensor(
+ [],
+ dtype=self._grad_buffer[kw].dtype,
+ device=self._grad_buffer[kw].device,
+ )
+ .set_(self._grad_buffer[kw])
+ .zero_()
+ )
+ if flag != 2:
+ nccl.groupStart()
+ for kw, val in self.block._storage_info.items():
+ nccl.allGather(
+ self.block._storage_params[kw].storage(),
+ self._param_buffer[kw],
+ val["zero_comm"],
+ )
+ nccl.groupEnd()
+
+ current_stream = torch.cuda.current_stream()
+ current_stream.wait_stream(config["load_stream"])
+
+ # set wait stream for each storage
+ for kw in self.block._storage_info.keys():
+ if flag != 2:
+ self._param_tensor[kw].record_stream(current_stream)
+ if requires_grad and kw in self._grad_tensor:
+ self._grad_tensor[kw].record_stream(current_stream)
+
+ # update parameters in block
+ for param in self.block._param_info:
+ kw_name = param["kw_name"]
+ offset = param["offset"]
+ shape = param["shape"]
+
+ if flag != 2:
+ dtype = self._param_buffer[kw_name].dtype
+ device = self._param_buffer[kw_name].device
+ param["parameter"].data = torch.tensor(
+ [], dtype=dtype, device=device
+ ).set_(self._param_buffer[kw_name], offset, shape)
+ else:
+ dtype = param["parameter"].data.dtype
+ device = param["parameter"].data.device
+ param["parameter"].data = torch.tensor(
+ [], dtype=dtype, device=device
+ ).set_(self.ctx_dict[kw_name], offset, shape)
+
+ if (
+ requires_grad
+ and kw_name in self._grad_buffer
+ and param["parameter"].requires_grad
+ ):
+ param["parameter"].grad = torch.tensor(
+ [], dtype=dtype, device=device
+ ).set_(self._grad_buffer[kw_name], offset, shape)
+
+ def __enter__(self):
+ self.enter()
+
+ def exit(self, flag=0, backward=False):
+ """
+ Reduce-scatter gradients during backward and release all parameters from the buffer back to block storage when forward is done.
+ """
+ if not self._need_release:
+ return
+ self._need_release = False
+ self.block._ready = False
+ if backward:
+ for kw, val in self.block._storage_info.items():
+ local_param = self.block._storage_params[kw]
+
+ # accumulate previous gradient
+ if local_param.requires_grad:
+ if local_param.grad is None:
+ grad_storage = val["storage_type"](
+ val["partition_size"]
+ ) # initialize gradient if not exist
+ local_param.grad = (
+ torch.tensor(
+ [], dtype=grad_storage.dtype, device=grad_storage.device
+ )
+ .set_(grad_storage)
+ .zero_()
+ )
+ else:
+ self._grad_tensor[kw][
+ val["begin"] : val["end"]
+ ] += local_param.grad
+
+ current_stream = torch.cuda.current_stream()
+ config["load_stream"].wait_stream(current_stream) # wait for backward
+
+ with torch.cuda.stream(config["load_stream"]):
+ nccl.groupStart()
+ for kw, val in self.block._storage_info.items():
+ local_param = self.block._storage_params[kw]
+
+ # scatter gradient
+ if local_param.requires_grad:
+ nccl.reduceScatter(
+ self._grad_buffer[kw],
+ local_param.grad.storage(),
+ "sum",
+ val["zero_comm"],
+ )
+ nccl.groupEnd()
+
+ # set wait stream for each storage
+ for kw in self._grad_tensor.keys():
+ # grads can not be freed until reduce ops finish
+ self._grad_tensor[kw].record_stream(config["load_stream"])
+
+ # Release all parameters from the buffer back to block storage
+ for param in self.block._param_info:
+ kw_name = param["kw_name"]
+ dtype = self.block._storage_params[kw_name].dtype
+ device = self.block._storage_params[kw_name].device
+ if "begin" not in param:
+ param["parameter"].data = torch.tensor([], dtype=dtype, device=device)
+ param["parameter"].grad = None
+ continue
+ begin = param["begin"]
+ end = param["end"]
+ param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(
+ self.block._storage_params[kw_name].storage(), begin, end
+ )
+ if (
+ param["parameter"].requires_grad
+ and self.block._storage_params[kw_name].grad is not None
+ ):
+ param["parameter"].grad = torch.tensor(
+ [], dtype=dtype, device=device
+ ).set_(self.block._storage_params[kw_name].grad.storage(), begin, end)
+ if flag == 1:
+ for i in self._param_buffer:
+ self.ctx_dict[i] = self._param_buffer[i]
+ self._grad_tensor = {}
+ self._param_tensor = {}
+ self._grad_buffer = {}
+ self._param_buffer = {}
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ # reduce scatter gradients
+ self.exit()
diff --git a/examples/BMTrain/cmake/FindNCCL.cmake b/examples/BMTrain/cmake/FindNCCL.cmake
new file mode 100644
index 00000000..2af8e3b9
--- /dev/null
+++ b/examples/BMTrain/cmake/FindNCCL.cmake
@@ -0,0 +1,100 @@
+list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})
+if(DEFINED ENV{NCCL_ROOT_DIR})
+ set(NCCL_ROOT_DIR $ENV{NCCL_ROOT_DIR})
+ set(NCCL_INCLUDE_DIR "${NCCL_ROOT_DIR}/include" CACHE PATH "Folder contains NVIDIA NCCL headers")
+ set(NCCL_LIB_DIR "${NCCL_ROOT_DIR}/lib" CACHE PATH "Folder contains NVIDIA NCCL libraries")
+else()
+ set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers")
+ set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries")
+endif()
+
+# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
+if(NOT NCCL_INCLUDE_DIR OR NOT NCCL_LIB_DIR)
+ execute_process(
+ COMMAND python -c "import nvidia.nccl;import os; print(os.path.dirname(nvidia.nccl.__file__))"
+ OUTPUT_VARIABLE NCCL_PIP_DIR
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ )
+ list(APPEND NCCL_ROOT $ENV{NCCL_PIP_DIR})
+ if(NOT NCCL_INCLUDE_DIR)
+ set(NCCL_INCLUDE_DIR "${NCCL_PIP_DIR}/include")
+ endif()
+ if(NOT NCCL_LIB_DIR)
+ set(NCCL_LIB_DIR "${NCCL_PIP_DIR}/lib")
+ endif()
+ find_library(NCCL_LIBRARIES
+ NAMES ${NCCL_LIBNAME}
+ HINTS ${NCCL_LIB_DIR})
+endif()
+
+list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT})
+find_path(NCCL_INCLUDE_DIRS
+ NAMES nccl.h
+ HINTS ${NCCL_INCLUDE_DIR})
+
+
+
+if (USE_STATIC_NCCL)
+ MESSAGE(STATUS "USE_STATIC_NCCL is set. Linking with static NCCL library.")
+ SET(NCCL_LIBNAME "nccl_static")
+ if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified
+ set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
+ endif()
+else()
+ SET(NCCL_LIBNAME "nccl")
+
+ if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified
+ message(STATUS "NCCL version: ${NCCL_VERSION}")
+ set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
+ else()
+ set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.2" ${CMAKE_FIND_LIBRARY_SUFFIXES})
+ endif()
+
+endif()
+
+find_library(NCCL_LIBRARIES
+ NAMES ${NCCL_LIBNAME}
+ HINTS ${NCCL_LIB_DIR})
+
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
+
+if(NCCL_FOUND) # obtaining NCCL version and some sanity checks
+ set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
+ message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...")
+ set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
+ list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS})
+ include(CheckCXXSymbolExists)
+ check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED)
+
+ if (NCCL_VERSION_DEFINED)
+ set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc")
+ file(WRITE ${file} "
+ #include <iostream>
+ #include <nccl.h>
+ int main()
+ {
+ std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl;
+ int x;
+ ncclGetVersion(&x);
+ return x == NCCL_VERSION_CODE;
+ }
+")
+ try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file}
+ RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER
+ CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}"
+ LINK_LIBRARIES ${NCCL_LIBRARIES})
+ if (NOT NCCL_VERSION_MATCHED)
+ message(FATAL_ERROR "Found NCCL header version and library version do not match! \
+(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.")
+ endif()
+ message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}")
+ else()
+ # message(STATUS "NCCL version < 2.3.5-5")
+ endif ()
+ set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES})
+
+ message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
+ mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
+endif()
diff --git a/examples/BMTrain/csrc/bind.cpp b/examples/BMTrain/csrc/bind.cpp
new file mode 100644
index 00000000..b8f6fa85
--- /dev/null
+++ b/examples/BMTrain/csrc/bind.cpp
@@ -0,0 +1,35 @@
+#include "include/bind.hpp"
+
+PYBIND11_MODULE(C, m) {
+ m.def("to_fp16_from_fp32", &fp16_from_fp32_value_launcher, "convert");
+ m.def("to_bf16_from_fp32", &bf16_from_fp32_value_launcher, "convert");
+ m.def("is_bf16_supported", &is_bf16_supported, "whether bf16 supported");
+ m.def("has_nan_inf_fp16_launcher", &has_nan_inf_fp16_launcher, "has nan inf");
+ m.def("has_nan_inf_bf16_launcher", &has_nan_inf_bf16_launcher, "has nan inf bf16");
+ m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function cpu");
+ m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu");
+ m.def("adam_cpu_fp16_launcher", &adam_cpu_fp16_launcher, "adam function cpu");
+ m.def("adam_cpu_bf16_launcher", &adam_cpu_bf16_launcher, "adam function cpu");
+ m.def("cross_entropy_forward_fp16_launcher", &cross_entropy_forward_fp16_launcher, "cross entropy forward");
+ m.def("cross_entropy_forward_bf16_launcher", &cross_entropy_forward_bf16_launcher, "cross entropy forward");
+ m.def("cross_entropy_backward_inplace_fp16_launcher", &cross_entropy_backward_inplace_fp16_launcher, "cross entropy backward inplace");
+ m.def("cross_entropy_backward_inplace_bf16_launcher", &cross_entropy_backward_inplace_bf16_launcher, "cross entropy backward inplace");
+ m.def("fused_sumexp_fp16_launcher", &fused_sumexp_fp16_launcher, "sum exp");
+ m.def("fused_sumexp_bf16_launcher", &fused_sumexp_bf16_launcher, "sum exp");
+ m.def("fused_softmax_inplace_fp16_launcher", &fused_softmax_inplace_fp16_launcher, "softmax inplace");
+ m.def("fused_softmax_inplace_bf16_launcher", &fused_softmax_inplace_bf16_launcher, "softmax inplace");
+ m.def("ncclGetUniqueId", &pyNCCLGetUniqueID, "nccl get unique ID");
+ m.def("ncclCommInitRank", &pyNCCLCommInitRank, "nccl init rank");
+ m.def("ncclCommDestroy", &pyNCCLCommDestroy, "nccl delete rank");
+ m.def("ncclAllGather", &pyNCCLAllGather, "nccl all gather");
+ m.def("ncclAllReduce", &pyNCCLAllReduce, "nccl all reduce");
+ m.def("ncclBroadcast", &pyNCCLBroadcast, "nccl broadcast");
+ m.def("ncclReduce", &pyNCCLReduce, "nccl reduce");
+ m.def("ncclReduceScatter", &pyNCCLReduceScatter, "nccl reduce scatter");
+ m.def("ncclGroupStart", &pyNCCLGroupStart, "nccl group start");
+ m.def("ncclGroupEnd", &pyNCCLGroupEnd, "nccl group end");
+ m.def("ncclSend", &pyNCCLSend, "nccl send");
+ m.def("ncclRecv", &pyNCCLRecv, "nccl recv");
+ m.def("ncclCommCount", &pyNCCLCommCount, "nccl comm count");
+ m.def("ncclCommUserRank", &pyNCCLCommUserRank, "nccl comm user rank");
+}
diff --git a/examples/BMTrain/csrc/cuda/adam_cuda.cu b/examples/BMTrain/csrc/cuda/adam_cuda.cu
new file mode 100644
index 00000000..0510ac12
--- /dev/null
+++ b/examples/BMTrain/csrc/cuda/adam_cuda.cu
@@ -0,0 +1,126 @@
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cmath>
+#include "bfloat16.cuh"
+
+namespace {
+// blocks <n/1024>, threads<1024>
+__global__ void adam_fp32_accum(
+ int32_t n,
+ const half *g, // (n)
+ half *m, // (n)
+ float *v, // (n)
+ float *param, // (n)
+ half *param_h, // (n)
+ float beta1,
+ float beta2,
+ float eps,
+ float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+) {
+ int32_t col = blockIdx.x * blockDim.x + threadIdx.x;
+ if (col < n) {
+ float local_g = __half2float(g[col]); // real_g * scale
+ float local_m = beta1 * __half2float(m[col]) + (1 - beta1) * local_g; // real_m * scale
+ float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g / scale; // real_v * scale
+ float local_p = param[col];
+ local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v * scale / bias_correction2) + eps * scale) - lr * weight_decay * local_p;
+
+ param_h[col] = __float2half(local_p);
+ param[col] = local_p;
+ v[col] = local_v;
+ m[col] = __float2half(local_m);
+ }
+}
+
+__global__ void adam_fp32_accum_bf16(
+ int32_t n,
+ const std::uintptr_t g_ptr, // (n)
+ float *m, // (n)
+ float *v, // (n)
+ float *param, // (n)
+ std::uintptr_t param_h_ptr, // (n)
+ float beta1,
+ float beta2,
+ float eps,
+ float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+) {
+#ifdef BF16_SUPPORT
+ const __nv_bfloat16* g = reinterpret_cast<const __nv_bfloat16*>(g_ptr);
+ __nv_bfloat16* param_h = reinterpret_cast<__nv_bfloat16*>(param_h_ptr);
+ int32_t col = blockIdx.x * blockDim.x + threadIdx.x;
+ if (col < n) {
+ float local_g = __bfloat162float(g[col]) / scale; // real_g
+ float local_m = beta1 * m[col] + (1 - beta1) * local_g; // real_m
+ float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g; // real_v
+ float local_p = param[col];
+ local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v / bias_correction2) + eps) - lr * weight_decay * local_p;
+
+ param_h[col] = __float2bfloat16(local_p);
+ param[col] = local_p;
+ v[col] = local_v;
+ m[col] = local_m;
+ }
+#endif
+}
+
+}
+
+void adam_fp16_launcher(
+ int n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_fp16,
+ std::uintptr_t g_fp16,
+ std::uintptr_t m_fp16,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2,
+ uintptr_t stream
+) {
+ if (n <= 0) return;
+ auto g_ptr = reinterpret_cast<half*>(g_fp16);
+ auto m_ptr = reinterpret_cast<half*>(m_fp16);
+ auto param_h_ptr = reinterpret_cast<half*>(param_fp16);
+ auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+ auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
+ int32_t threads = 1024;
+ dim3 block_size = dim3(threads, 1, 1);
+ dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
+ adam_fp32_accum<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, param_h_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+}
+
+void adam_bf16_launcher(
+ int n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_bf16,
+ std::uintptr_t g_bf16,
+ std::uintptr_t m_fp32,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2,
+ uintptr_t stream
+) {
+ if (n <= 0) return;
+ auto m_ptr = reinterpret_cast<float*>(m_fp32);
+ auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+ auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
+ int32_t threads = 1024;
+ dim3 block_size = dim3(threads, 1, 1);
+ dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
+ adam_fp32_accum_bf16<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_bf16, m_ptr, v_fp32_ptr, param_fp32_ptr, param_bf16, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+}
diff --git a/examples/BMTrain/csrc/cuda/bfloat16.cuh b/examples/BMTrain/csrc/cuda/bfloat16.cuh
new file mode 100644
index 00000000..564d8bec
--- /dev/null
+++ b/examples/BMTrain/csrc/cuda/bfloat16.cuh
@@ -0,0 +1,5 @@
+#include <cuda.h>
+#if defined(__CUDACC__) && CUDA_VERSION >= 11000
+#include <cuda_bf16.h>
+#define BF16_SUPPORT
+#endif
\ No newline at end of file
diff --git a/examples/BMTrain/csrc/cuda/cross_entropy.cu b/examples/BMTrain/csrc/cuda/cross_entropy.cu
new file mode 100644
index 00000000..177c3b77
--- /dev/null
+++ b/examples/BMTrain/csrc/cuda/cross_entropy.cu
@@ -0,0 +1,315 @@
+#include "reduce.cuh"
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cmath>
+#include "bfloat16.cuh"
+
+namespace {
+// blocks <m>, threads<1024>
+__global__ void cross_entropy_forward_fp16(
+ int64_t n,
+ const half *input, // (m, n)
+ const int32_t *target, // (m)
+ half *softmax, // (m, n)
+ float *output, // (m)
+ int32_t ignore_index
+) {
+ int64_t base_idx = blockIdx.x * n;
+
+ float local_max = -INFINITY;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_max = fmaxf(__half2float(input[base_idx + i]), local_max);
+ }
+
+ local_max = fmaxf(block_allreduce_max(local_max), -1e6);
+
+ float local_sum = 0;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_sum += expf(__half2float(input[base_idx + i]) - local_max);
+ }
+ local_sum = block_allreduce_sum(local_sum) + 1e-10; // avoid nan
+
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ softmax[base_idx + i] = __float2half( expf(__half2float(input[base_idx + i]) - local_max) / local_sum );
+ }
+
+ if (threadIdx.x == 0) {
+ if (target[blockIdx.x] != ignore_index) {
+ output[blockIdx.x] = -__half2float(input[base_idx + target[blockIdx.x]]) + local_max + logf(local_sum);
+ } else {
+ output[blockIdx.x] = 0;
+ }
+ }
+}
+
+// blocks <m>, threads<1024>
+__global__ void cross_entropy_backward_inplace_fp16(
+ int64_t n,
+ const float *grad_output, // (m)
+ const int32_t *target, // (m)
+ half *x, // (m, n)
+ int32_t ignore_index
+) {
+ int64_t base_idx = blockIdx.x * n;
+
+ int32_t t = target[blockIdx.x];
+ float v = grad_output[blockIdx.x];
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ x[base_idx + i] = __float2half(i==t ? (__half2float(x[base_idx + i])-1)*v : __half2float(x[base_idx + i])*v);
+ }
+}
+
+// blocks <m>, threads<1024>
+__global__ void cross_entropy_forward_bf16(
+ int64_t n,
+ const std::uintptr_t input_ptr, // (m, n)
+ const int32_t *target, // (m)
+ std::uintptr_t softmax_ptr, // (m, n)
+ float *output, // (m)
+ int32_t ignore_index
+) {
+#ifdef BF16_SUPPORT
+ const __nv_bfloat16* input = reinterpret_cast<const __nv_bfloat16*>(input_ptr);
+ __nv_bfloat16* softmax = reinterpret_cast<__nv_bfloat16*>(softmax_ptr);
+ int64_t base_idx = blockIdx.x * n;
+
+ float local_max = -INFINITY;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_max = fmaxf(__bfloat162float(input[base_idx + i]), local_max);
+ }
+
+ local_max = fmaxf(block_allreduce_max(local_max), -1e6);
+
+ float local_sum = 0;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_sum += expf(__bfloat162float(input[base_idx + i]) - local_max);
+ }
+ local_sum = block_allreduce_sum(local_sum) + 1e-10; // avoid nan
+
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ softmax[base_idx + i] = __float2bfloat16( expf(__bfloat162float(input[base_idx + i]) - local_max) / local_sum );
+ }
+
+ if (threadIdx.x == 0) {
+ if (target[blockIdx.x] != ignore_index) {
+ output[blockIdx.x] = -__bfloat162float(input[base_idx + target[blockIdx.x]]) + local_max + logf(local_sum);
+ } else {
+ output[blockIdx.x] = 0;
+ }
+ }
+#endif
+}
+
+// blocks <m>, threads<1024>
+__global__ void cross_entropy_backward_inplace_bf16(
+ int64_t n,
+ const float *grad_output, // (m)
+ const int32_t *target, // (m)
+ std::uintptr_t x_ptr, // (m, n)
+ int32_t ignore_index
+) {
+#ifdef BF16_SUPPORT
+ __nv_bfloat16* x = reinterpret_cast<__nv_bfloat16*>(x_ptr);
+ int64_t base_idx = blockIdx.x * n;
+
+ int32_t t = target[blockIdx.x];
+ float v = grad_output[blockIdx.x];
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ x[base_idx + i] = __float2bfloat16(i==t ? (__bfloat162float(x[base_idx + i])-1)*v : __bfloat162float(x[base_idx + i])*v);
+ }
+#endif
+}
+
+// blocks <m>, threads<1024>
+__global__ void fused_sumexp_fp16(
+ int64_t n,
+ const half *input, // (m, n)
+ const float *global_max, // (m)
+ float *global_sum // (m)
+) {
+ int64_t base_idx = blockIdx.x * n;
+ float local_max = global_max[blockIdx.x];
+
+ float local_sum = 0;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_sum += expf(__half2float(input[base_idx + i]) - local_max);
+ }
+ local_sum = block_allreduce_sum(local_sum);
+ if (threadIdx.x == 0) {
+ global_sum[blockIdx.x] = local_sum;
+ }
+}
+
+// blocks <m>, threads<1024>
+__global__ void fused_sumexp_bf16(
+ int64_t n,
+ const std::uintptr_t input_ptr, // (m, n)
+ const float *global_max, // (m)
+ float *global_sum // (m)
+) {
+#ifdef BF16_SUPPORT
+ const __nv_bfloat16* input = reinterpret_cast<const __nv_bfloat16*>(input_ptr);
+ int64_t base_idx = blockIdx.x * n;
+ float local_max = global_max[blockIdx.x];
+
+ float local_sum = 0;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_sum += expf(__bfloat162float(input[base_idx + i]) - local_max);
+ }
+ local_sum = block_allreduce_sum(local_sum);
+ if (threadIdx.x == 0) {
+ global_sum[blockIdx.x] = local_sum;
+ }
+#endif
+}
+
+// blocks <m>, threads<1024>
+__global__ void fused_softmax_inplace_fp16(
+ int64_t n,
+ half *softmax, // (m, n)
+ const float *global_max, // (m)
+ const float *global_sum // (m)
+) {
+ int64_t base_idx = blockIdx.x * n;
+ float local_max = global_max[blockIdx.x];
+ float local_sum = global_sum[blockIdx.x];
+
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ softmax[base_idx + i] = __float2half( expf(__half2float(softmax[base_idx + i]) - local_max) / local_sum );
+ }
+}
+
+// blocks <m>, threads<1024>
+__global__ void fused_softmax_inplace_bf16(
+ int64_t n,
+ std::uintptr_t softmax_ptr, // (m, n)
+ const float *global_max, // (m)
+ const float *global_sum // (m)
+) {
+#ifdef BF16_SUPPORT
+ __nv_bfloat16* softmax = reinterpret_cast<__nv_bfloat16*>(softmax_ptr);
+ int64_t base_idx = blockIdx.x * n;
+ float local_max = global_max[blockIdx.x];
+ float local_sum = global_sum[blockIdx.x];
+
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ softmax[base_idx + i] = __float2bfloat16( expf(__bfloat162float(softmax[base_idx + i]) - local_max) / local_sum );
+ }
+#endif
+}
+}
+
+void cross_entropy_forward_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t input,
+ std::uintptr_t target,
+ std::uintptr_t softmax,
+ std::uintptr_t output,
+ int32_t ignore_index,
+ std::uintptr_t stream
+) {
+ auto input_ptr = reinterpret_cast<half*>(input);
+ auto target_ptr = reinterpret_cast<int32_t*>(target);
+ auto softmax_ptr = reinterpret_cast<half*>(softmax);
+ auto output_ptr = reinterpret_cast<float*>(output);
+ int32_t threads = 1024;
+ cross_entropy_forward_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index);
+}
+
+void cross_entropy_backward_inplace_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t grad_output,
+ std::uintptr_t target,
+ std::uintptr_t x,
+ int32_t ignore_index,
+ std::uintptr_t stream
+) {
+ auto output_ptr = reinterpret_cast<float*>(grad_output);
+ auto target_ptr = reinterpret_cast<int32_t*>(target);
+ auto x_ptr = reinterpret_cast<half*>(x);
+ int32_t threads = 1024;
+ cross_entropy_backward_inplace_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, output_ptr, target_ptr, x_ptr, ignore_index);
+}
+
+void cross_entropy_forward_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t input,
+ std::uintptr_t target,
+ std::uintptr_t softmax,
+ std::uintptr_t output,
+ int32_t ignore_index,
+ std::uintptr_t stream
+) {
+ auto target_ptr = reinterpret_cast<int32_t*>(target);
+ auto output_ptr = reinterpret_cast<float*>(output);
+ int32_t threads = 1024;
+ cross_entropy_forward_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, input, target_ptr, softmax, output_ptr, ignore_index);
+}
+
+void cross_entropy_backward_inplace_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t grad_output,
+ std::uintptr_t target,
+ std::uintptr_t x,
+ int32_t ignore_index,
+ std::uintptr_t stream
+) {
+ auto output_ptr = reinterpret_cast<float*>(grad_output);
+ auto target_ptr = reinterpret_cast<int32_t*>(target);
+ int32_t threads = 1024;
+ cross_entropy_backward_inplace_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, output_ptr, target_ptr, x, ignore_index);
+}
+
+void fused_sumexp_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+) {
+ auto logits_ptr = reinterpret_cast<half*>(logits);
+ auto max_logits_ptr = reinterpret_cast<float*>(max_logits);
+ auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits);
+ int32_t threads = 1024;
+ fused_sumexp_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr);
+}
+
+void fused_sumexp_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+) {
+ auto max_logits_ptr = reinterpret_cast<float*>(max_logits);
+ auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits);
+ int32_t threads = 1024;
+ fused_sumexp_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits, max_logits_ptr, sum_exp_logits_ptr);
+}
+
+void fused_softmax_inplace_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+) {
+ auto logits_ptr = reinterpret_cast<half*>(logits);
+ auto max_logits_ptr = reinterpret_cast<float*>(max_logits);
+ auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits);
+ int32_t threads = 1024;
+ fused_softmax_inplace_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr);
+}
+
+void fused_softmax_inplace_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+) {
+ auto max_logits_ptr = reinterpret_cast<float*>(max_logits);
+ auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits);
+ int32_t threads = 1024;
+ fused_softmax_inplace_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits, max_logits_ptr, sum_exp_logits_ptr);
+}
\ No newline at end of file
diff --git a/examples/BMTrain/csrc/cuda/has_inf_nan.cu b/examples/BMTrain/csrc/cuda/has_inf_nan.cu
new file mode 100644
index 00000000..32bc5a5f
--- /dev/null
+++ b/examples/BMTrain/csrc/cuda/has_inf_nan.cu
@@ -0,0 +1,145 @@
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cmath>
+#include "bfloat16.cuh"
+
+namespace{
+__inline__ __device__ bool isnan_(half v) {
+ #if __CUDA_ARCH__ >= 700 || __CUDA_ARCH__ == 600
+ return __hisnan(v);
+ #else
+ return !__heq(v, v);
+ #endif
+}
+
+__inline__ __device__ int8_t warpReduceAny(int8_t x) {
+ for (int offset = warpSize/2; offset > 0; offset /= 2)
+ x |= __shfl_down_sync(0xFFFFFFFF, x, offset);
+ return x;
+}
+
+__inline__ __device__ float blockReduceAny(int8_t x) {
+ static __shared__ float shared[32];
+ int lane = threadIdx.x % warpSize;
+ int wid = threadIdx.x / warpSize;
+ x = warpReduceAny(x);
+ if (lane == 0) shared[wid] = x;
+ __syncthreads();
+ x = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;
+ if (wid == 0) x = warpReduceAny(x);
+ return x;
+}
+
+// grid <min(ceil(n/1024), 1024)>, thread<1024>
+__global__ void bmt_has_nan_inf_fp16(
+ int32_t n,
+ const half* inp, // (n,)
+ uint8_t* mid // (1024,)
+) {
+ int32_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+ int32_t span = blockDim.x * gridDim.x;
+
+ int8_t r = 0;
+ for (int i = gid; i < n; i += span) {
+ half v = inp[i];
+ if (__hisinf(v) || isnan_(v)) {
+ r = 1;
+ break;
+ }
+ }
+ r = blockReduceAny(r);
+ if (threadIdx.x == 0) {
+ mid[blockIdx.x] = r;
+ }
+}
+
+// grid <1>, thread<1024>
+__global__ void bmt_has_nan_inf_reduce(
+ const uint8_t* mid, // (1024,)
+ uint8_t* out
+) {
+ int tid = threadIdx.x;
+ int8_t r = blockReduceAny(mid[tid]);
+ if (tid == 0 && r > 0) {
+ out[0] = 1;
+ }
+}
+
+// grid <min(ceil(n/1024), 1024)>, thread<1024>
+__global__ void bmt_has_nan_inf_bf16(
+ int32_t n,
+ const uintptr_t inp, // (n,)
+ uint8_t* mid // (1024,)
+) {
+#ifdef BF16_SUPPORT
+ const __nv_bfloat16* bf_inp = reinterpret_cast<const __nv_bfloat16*>(inp);
+ int32_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+ int32_t span = blockDim.x * gridDim.x;
+
+ int8_t r = 0;
+ for (int i = gid; i < n; i += span) {
+ __nv_bfloat16 v = bf_inp[i];
+ #if __CUDA_ARCH__ >= 800
+ if (__hisinf(v) || __hisnan(v)) {
+ #else
+ if (isinf(__bfloat162float(v)) || isnan(__bfloat162float(v))) {
+ #endif
+ r = 1;
+ break;
+ }
+ }
+ r = blockReduceAny(r);
+ if (threadIdx.x == 0) {
+ mid[blockIdx.x] = r;
+ }
+#endif
+}
+
+}
+
+void has_nan_inf_fp16_launcher(
+ int32_t n,
+ std::uintptr_t g_fp16,
+ std::uintptr_t mid,
+ std::uintptr_t out,
+ std::uintptr_t stream
+) {
+ if (n <= 0) return;
+ auto g_ptr = reinterpret_cast<half*>(g_fp16);
+ auto mid_ptr = reinterpret_cast<uint8_t*>(mid);
+ auto out_ptr = reinterpret_cast<uint8_t*>(out);
+ int32_t threads = 1024;
+ dim3 block_size = dim3(threads, 1, 1);
+ dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
+ dim3 clamp_grid_size = dim3(min((n + threads - 1) / threads, 1024), 1, 1);
+
+ bmt_has_nan_inf_fp16<<<clamp_grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, mid_ptr);
+ bmt_has_nan_inf_reduce<<<1, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(mid_ptr, out_ptr);
+}
+
+void has_nan_inf_bf16_launcher(
+ int32_t n,
+ std::uintptr_t g_bf16,
+ std::uintptr_t mid,
+ std::uintptr_t out,
+ std::uintptr_t stream
+) {
+ if (n <= 0) return;
+ auto mid_ptr = reinterpret_cast<uint8_t*>(mid);
+ auto out_ptr = reinterpret_cast<uint8_t*>(out);
+ int32_t threads = 1024;
+ dim3 block_size = dim3(threads, 1, 1);
+ dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
+ dim3 clamp_grid_size = dim3(min((n + threads - 1) / threads, 1024), 1, 1);
+
+ bmt_has_nan_inf_bf16<<<clamp_grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_bf16, mid_ptr);
+ bmt_has_nan_inf_reduce<<<1, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(mid_ptr, out_ptr);
+}
+
+int is_bf16_supported() {
+#ifdef BF16_SUPPORT
+ return 1;
+#endif
+ return 0;
+}
\ No newline at end of file
diff --git a/examples/BMTrain/csrc/cuda/reduce.cuh b/examples/BMTrain/csrc/cuda/reduce.cuh
new file mode 100644
index 00000000..a9c4c15b
--- /dev/null
+++ b/examples/BMTrain/csrc/cuda/reduce.cuh
@@ -0,0 +1,114 @@
+namespace {
+const int WARP_SZ = 32;
+
+// blocks <m>, threads<1024>
+__device__ float block_reduce_sum(float val) {
+ static __shared__ float s_x[WARP_SZ];
+ // int gid = threadIdx.x + blockIdx.x * blockDim.x;
+ int tid = threadIdx.x;
+ int lid = threadIdx.x % WARP_SZ;
+ int wid = threadIdx.x / WARP_SZ;
+
+ // reduce intra warp
+
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val += __shfl_down_sync(0xFFFFFFFF, val, offset);
+
+ if (lid == 0) s_x[wid] = val;
+ __syncthreads();
+
+ // reduce inter warp
+ val = (tid < WARP_SZ) ? s_x[lid] : 0;
+ if (wid == 0) {
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val += __shfl_down_sync(0xFFFFFFFF, val, offset);
+ }
+ return val;
+}
+
+// blocks <m>, threads<1024>
+__device__ float block_reduce_max(float val) {
+ static __shared__ float s_x[WARP_SZ];
+ // int gid = threadIdx.x + blockIdx.x * blockDim.x;
+ int tid = threadIdx.x;
+ int lid = threadIdx.x % WARP_SZ;
+ int wid = threadIdx.x / WARP_SZ;
+
+ // reduce intra warp
+
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
+
+ if (lid == 0) s_x[wid] = val;
+ __syncthreads();
+
+ // reduce inter warp
+ val = (tid < WARP_SZ) ? s_x[lid] : -INFINITY;
+ if (wid == 0) {
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
+ }
+ return val;
+}
+
+// blocks <m>, threads<1024>
+__device__ float block_allreduce_sum(float val) {
+ static __shared__ float s_x[WARP_SZ];
+ // int gid = threadIdx.x + blockIdx.x * blockDim.x;
+ int tid = threadIdx.x;
+ int lid = threadIdx.x % WARP_SZ;
+ int wid = threadIdx.x / WARP_SZ;
+
+ // reduce intra warp
+
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val += __shfl_down_sync(0xFFFFFFFF, val, offset);
+
+ if (lid == 0) s_x[wid] = val;
+ __syncthreads();
+
+ // reduce inter warp
+ val = (tid < WARP_SZ) ? s_x[lid] : 0;
+ if (wid == 0) {
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val += __shfl_down_sync(0xFFFFFFFF, val, offset);
+ }
+
+ if (tid == 0) {
+ s_x[0] = val;
+ }
+ __syncthreads();
+ return s_x[0];
+}
+
+// blocks <m>, threads<1024>
+__device__ float block_allreduce_max(float val) {
+ static __shared__ float s_x[WARP_SZ];
+ // int gid = threadIdx.x + blockIdx.x * blockDim.x;
+ int tid = threadIdx.x;
+ int lid = threadIdx.x % WARP_SZ;
+ int wid = threadIdx.x / WARP_SZ;
+
+ // reduce intra warp
+
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
+
+ if (lid == 0) s_x[wid] = val;
+ __syncthreads();
+
+ // reduce inter warp
+ val = (tid < WARP_SZ) ? s_x[lid] : -INFINITY;
+ if (wid == 0) {
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
+ }
+
+ if (tid == 0) {
+ s_x[0] = val;
+ }
+ __syncthreads();
+ return s_x[0];
+}
+
+}
\ No newline at end of file
diff --git a/examples/BMTrain/csrc/include/adam_cpu.hpp b/examples/BMTrain/csrc/include/adam_cpu.hpp
new file mode 100644
index 00000000..52575d69
--- /dev/null
+++ b/examples/BMTrain/csrc/include/adam_cpu.hpp
@@ -0,0 +1,557 @@
+#include <emmintrin.h>
+#include <immintrin.h>
+#include <cstdint>
+#include <cstring>
+#include <cmath>
+#include <algorithm>
+#include <thread>
+#include <vector>
+#include <mutex>
+#include <sched.h>
+#include <pybind11/pybind11.h>
+#include "cpu_info.h"
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+
+static inline float _mm256_reduce_add_ps(__m256 x) {
+ /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */
+ const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
+ /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
+ const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
+ /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
+ const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
+ /* Conversion to float is a no-op on x86-64 */
+ return _mm_cvtss_f32(x32);
+}
+
+inline float fp32_from_bits(uint32_t w) {
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } fp32 = {w};
+ return fp32.as_value;
+}
+
+inline uint32_t fp32_to_bits(float f) {
+ union {
+ float as_value;
+ uint32_t as_bits;
+ } fp32 = {f};
+ return fp32.as_bits;
+}
+
+template <typename F>
+inline void parallel_for(int64_t begin, int64_t end, int64_t grain_size, const F& f) {
+ // Number of iterations
+ int64_t numiter = end - begin;
+
+ // Number of threads to use
+ int64_t num_threads = 1; // Default to serial execution
+
+ if (grain_size > 0) {
+ num_threads = std::max((numiter+grain_size-1) / grain_size, static_cast<int64_t>(1));
+ }
+ else{
+ cpu_set_t cpu_set;
+ CPU_ZERO(&cpu_set);
+ sched_getaffinity(0, sizeof(cpu_set), &cpu_set);
+ num_threads = CPU_COUNT(&cpu_set);
+ grain_size = std::max((numiter+num_threads-1) / num_threads, static_cast<int64_t>(1));
+
+ }
+
+ // Check if parallel execution is feasible
+ if (num_threads > 1) {
+ py::gil_scoped_release release; // Release the GIL
+ std::vector<std::thread> threads(num_threads);
+ for (int64_t t = 0; t < num_threads; ++t) {
+ threads[t] = std::thread([&, t]() {
+ int64_t left = std::min(begin + t * grain_size, end);
+ int64_t right = std::min(begin + (t + 1) * grain_size, end);
+ f(left, right);
+ });
+ }
+ for (auto& thread : threads) {
+ thread.join();
+ }
+ } else {
+ // If not feasible or grain_size is 0, perform the operation serially
+ f(begin, end);
+ }
+}
+
+// fp32 -> fp16
+inline uint16_t fp16_ieee_from_fp32_value(float f) {
+ // const float scale_to_inf = 0x1.0p+112f;
+ // const float scale_to_zero = 0x1.0p-110f;
+ uint32_t scale_to_inf_bits = (uint32_t) 239 << 23;
+ uint32_t scale_to_zero_bits = (uint32_t) 17 << 23;
+ float scale_to_inf_val, scale_to_zero_val;
+ std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
+ std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
+ const float scale_to_inf = scale_to_inf_val;
+ const float scale_to_zero = scale_to_zero_val;
+
+ float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+ const uint32_t w = (uint32_t)fp32_to_bits(f);
+ const uint32_t shl1_w = w + w;
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+ if (bias < UINT32_C(0x71000000)) {
+ bias = UINT32_C(0x71000000);
+ }
+
+ base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+ const uint32_t bits = (uint32_t)fp32_to_bits(base);
+ const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+ const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+ const uint32_t nonsign = exp_bits + mantissa_bits;
+ return static_cast<uint16_t>(
+ (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign)
+ );
+}
+
+// fp16 -> fp32
+inline float fp16_ieee_to_fp32_value(uint16_t h) {
+ const uint32_t w = (uint32_t)h << 16;
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ const uint32_t two_w = w + w;
+
+ const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+ const float exp_scale = 0x1.0p-112f;
+ const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+ const uint32_t magic_mask = UINT32_C(126) << 23;
+ const float magic_bias = 0.5f;
+ const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+ const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+ const uint32_t result =
+ sign | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
+ : fp32_to_bits(normalized_value));
+ return fp32_from_bits(result);
+}
+
+inline uint16_t bf16_from_fp32_value(float f){
+ return *reinterpret_cast<uint32_t*>(&f) >> 16;
+}
+// fp32 -> bf16
+void bf16_from_fp32_value_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_bf16
+){
+ int span = 1;
+ auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+ auto param_bf16_ptr = reinterpret_cast<uint16_t*>(param_bf16);
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ for (int64_t j = start; j < end; j += span) {
+ for (int64_t i = j; i < end; i++) {
+ float p = param_fp32_ptr[i];
+ param_bf16_ptr[i] = bf16_from_fp32_value(p);
+ }
+ break; // must break here
+ }
+ });
+}
+
+void fp16_from_fp32_value_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_fp16
+){
+ int span = 1;
+ auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+ auto param_fp16_ptr = reinterpret_cast<uint16_t*>(param_fp16);
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ for (int64_t j = start; j < end; j += span) {
+ for (int64_t i = j; i < end; i++) {
+ float p = param_fp32_ptr[i];
+ param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
+ }
+ break; // must break here
+ }
+ });
+}
+// bf16 -> fp32
+inline float bf16_to_fp32_value(uint16_t h){
+ uint32_t src = h;
+ src <<= 16;
+ return *reinterpret_cast<float*>(&src);
+}
+
+void adam_cpu_0(
+ int64_t n,
+ float* param_fp32_ptr,
+ uint16_t* param_fp16_ptr,
+ float* delta_info_ptr,
+ uint16_t* g_fp16_ptr,
+ float* m_fp32_ptr,
+ float* v_fp32_ptr,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+){
+ int64_t span = 1;
+ float sum_sq_delta = 0;
+ float sum_delta = 0;
+ std::mutex delta_mutex;
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ float sum_sq_delta_i = 0;
+ float sum_delta_i = 0;
+ for (int64_t j = start; j < end; j += span) {
+ for (int64_t i = j; i < end; i++) {
+ float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale;
+ float m = m_fp32_ptr[i];
+ float v = v_fp32_ptr[i];
+ float p = param_fp32_ptr[i];
+ m = beta1 * m + (1 - beta1) * g;
+ v = beta2 * v + (1 - beta2) * g * g;
+ if (delta_info_ptr != NULL){
+ float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+ sum_delta_i += delta;
+ sum_sq_delta_i += delta * delta;
+ }
+ p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
+ param_fp32_ptr[i] = p;
+ param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
+ m_fp32_ptr[i] = m;
+ v_fp32_ptr[i] = v;
+ }
+ break; // must break here
+ }
+ if (delta_info_ptr != NULL){
+ delta_mutex.lock();
+ sum_delta += sum_delta_i;
+ sum_sq_delta += sum_sq_delta_i;
+ delta_mutex.unlock();
+ }
+ });
+ if (delta_info_ptr != NULL){
+ delta_info_ptr[0] = sum_delta / n;
+ delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2
+ delta_info_ptr[2] = sum_delta;
+ delta_info_ptr[3] = sum_sq_delta;
+ }
+}
+
+void adam_cpu_bf16_0(
+ int64_t n,
+ float* param_fp32_ptr,
+ uint16_t* param_bf16_ptr,
+ float* delta_info_ptr,
+ uint16_t* g_bf16_ptr,
+ float* m_fp32_ptr,
+ float* v_fp32_ptr,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+){
+ int64_t span = 1;
+ float sum_sq_delta = 0;
+ float sum_delta = 0;
+ std::mutex delta_mutex;
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ float sum_sq_delta_i = 0;
+ float sum_delta_i = 0;
+ for (int64_t j = start; j < end; j += span) {
+ for (int64_t i = j; i < end; i++) {
+ float g = bf16_to_fp32_value(g_bf16_ptr[i]) / scale;
+ float m = m_fp32_ptr[i];
+ float v = v_fp32_ptr[i];
+ float p = param_fp32_ptr[i];
+ m = beta1 * m + (1 - beta1) * g;
+ v = beta2 * v + (1 - beta2) * g * g;
+ if (delta_info_ptr != NULL){
+ float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+ sum_delta_i += delta;
+ sum_sq_delta_i += delta * delta;
+ }
+ p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
+ param_fp32_ptr[i] = p;
+ param_bf16_ptr[i] = bf16_from_fp32_value(p);
+ m_fp32_ptr[i] = m;
+ v_fp32_ptr[i] = v;
+ }
+ break; // must break here
+ }
+ if (delta_info_ptr != NULL){
+ delta_mutex.lock();
+ sum_delta += sum_delta_i;
+ sum_sq_delta += sum_sq_delta_i;
+ delta_mutex.unlock();
+ }
+ });
+ if (delta_info_ptr != NULL){
+ delta_info_ptr[0] = sum_delta / n;
+ delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2
+ delta_info_ptr[2] = sum_delta;
+ delta_info_ptr[3] = sum_sq_delta;
+ }
+}
+
+static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1(
+ int64_t n,
+ float* param_fp32_ptr,
+ uint16_t* param_fp16_ptr,
+ float* delta_info_ptr,
+ uint16_t* g_fp16_ptr,
+ float* m_fp32_ptr,
+ float* v_fp32_ptr,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+){
+ float sum_sq_delta = 0;
+ float sum_delta = 0;
+ std::mutex delta_mutex;
+ auto avx_beta1 = _mm256_set1_ps(beta1);
+ auto avx_beta2 = _mm256_set1_ps(beta2);
+ auto avx_beta1_1 = _mm256_set1_ps(1 - beta1);
+ auto avx_beta2_1 = _mm256_set1_ps(1 - beta2);
+ auto avx_eps = _mm256_set1_ps(eps);
+ auto avx_neg_lr = _mm256_set1_ps(-lr);
+ auto avx_scale = _mm256_set1_ps(scale);
+ auto avx_weight_decay = _mm256_set1_ps(weight_decay);
+ auto avx_bias_correction1 = _mm256_set1_ps(bias_correction1);
+ auto avx_bias_correction2 = _mm256_set1_ps(bias_correction2);
+ int64_t span = 8;
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ float sum_sq_delta_i = 0;
+ float sum_delta_i = 0;
+ for (int64_t j = start; j < end; j += span) {
+ if (j + span > end) {
+ for (int64_t i = j; i < end; i++) {
+ float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale;
+ float m = m_fp32_ptr[i];
+ float v = v_fp32_ptr[i];
+ float p = param_fp32_ptr[i];
+ m = beta1 * m + (1 - beta1) * g;
+ v = beta2 * v + (1 - beta2) * g * g;
+ if (delta_info_ptr != NULL){
+ float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+ sum_delta_i += delta;
+ sum_sq_delta_i += delta * delta;
+ }
+ p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
+ param_fp32_ptr[i] = p;
+ param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
+ m_fp32_ptr[i] = m;
+ v_fp32_ptr[i] = v;
+ }
+ break; // must break here
+ } else {
+ auto g = _mm256_div_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)&g_fp16_ptr[j])), avx_scale);
+ auto m = _mm256_loadu_ps(&m_fp32_ptr[j]);
+ auto v = _mm256_loadu_ps(&v_fp32_ptr[j]);
+ auto p = _mm256_loadu_ps(&param_fp32_ptr[j]);
+ m = _mm256_fmadd_ps(avx_beta1, m, _mm256_mul_ps(avx_beta1_1, g));
+ v = _mm256_fmadd_ps(avx_beta2, v, _mm256_mul_ps(avx_beta2_1, _mm256_mul_ps(g, g)));
+ if (delta_info_ptr != NULL){
+ auto delta_256 = _mm256_add_ps(
+ _mm256_div_ps(
+ _mm256_div_ps(m, avx_bias_correction1), // m / bias_correction1
+ _mm256_add_ps(_mm256_sqrt_ps(_mm256_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps
+ ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps)
+ _mm256_mul_ps(avx_weight_decay, p) // weight_decay * p
+ ); // delta = m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + weight_decay * p
+ sum_delta_i += _mm256_reduce_add_ps(delta_256);
+ sum_sq_delta_i += _mm256_reduce_add_ps(_mm256_mul_ps(delta_256, delta_256));
+ }
+ p = _mm256_fmadd_ps(avx_neg_lr, _mm256_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p
+ p = _mm256_fmadd_ps(
+ avx_neg_lr,
+ _mm256_div_ps(
+ _mm256_div_ps(m, avx_bias_correction1), // m / bias_correction1
+ _mm256_add_ps(_mm256_sqrt_ps(_mm256_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps
+ ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps)
+ p
+ ); // p = p - lr * m / bias_correction1 / (sqrt(v / bias_correction2) + eps)
+ _mm256_storeu_ps(&param_fp32_ptr[j], p);
+ _mm_storeu_si128((__m128i*)&param_fp16_ptr[j], _mm256_cvtps_ph(p, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ _mm256_storeu_ps(&m_fp32_ptr[j], m);
+ _mm256_storeu_ps(&v_fp32_ptr[j], v);
+ }
+ }
+ if (delta_info_ptr != NULL){
+ delta_mutex.lock();
+ sum_delta += sum_delta_i;
+ sum_sq_delta += sum_sq_delta_i;
+ delta_mutex.unlock();
+ }
+ });
+ if (delta_info_ptr != NULL){
+ delta_info_ptr[0] = sum_delta / n;
+ delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2
+ delta_info_ptr[2] = sum_delta;
+ delta_info_ptr[3] = sum_sq_delta;
+ }
+}
+
+static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2(
+ int64_t n,
+ float* param_fp32_ptr,
+ uint16_t* param_fp16_ptr,
+ float* delta_info_ptr,
+ uint16_t* g_fp16_ptr,
+ float* m_fp32_ptr,
+ float* v_fp32_ptr,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+){
+ float sum_sq_delta = 0;
+ float sum_delta = 0;
+ std::mutex delta_mutex;
+ auto avx_beta1 = _mm512_set1_ps(beta1);
+ auto avx_beta2 = _mm512_set1_ps(beta2);
+ auto avx_beta1_1 = _mm512_set1_ps(1 - beta1);
+ auto avx_beta2_1 = _mm512_set1_ps(1 - beta2);
+ auto avx_eps = _mm512_set1_ps(eps);
+ auto avx_neg_lr = _mm512_set1_ps(-lr);
+ auto avx_scale = _mm512_set1_ps(scale);
+ auto avx_weight_decay = _mm512_set1_ps(weight_decay);
+ auto avx_bias_correction1 = _mm512_set1_ps(bias_correction1);
+ auto avx_bias_correction2 = _mm512_set1_ps(bias_correction2);
+ int64_t span = 16;
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ float sum_sq_delta_i = 0;
+ float sum_delta_i = 0;
+ for (int64_t j = start; j < end; j += span) {
+ if (j + span > end) {
+ for (int64_t i = j; i < end; i++) {
+ float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale;
+ float m = m_fp32_ptr[i];
+ float v = v_fp32_ptr[i];
+ float p = param_fp32_ptr[i];
+ m = beta1 * m + (1 - beta1) * g;
+ v = beta2 * v + (1 - beta2) * g * g;
+ if (delta_info_ptr != NULL){
+ float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+ sum_delta_i += delta;
+ sum_sq_delta_i += delta * delta;
+ }
+ p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
+ param_fp32_ptr[i] = p;
+ param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
+ m_fp32_ptr[i] = m;
+ v_fp32_ptr[i] = v;
+ }
+ break; // must break here
+ }else{
+ auto g = _mm512_div_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)&g_fp16_ptr[j])), avx_scale);
+ auto m = _mm512_loadu_ps(&m_fp32_ptr[j]);
+ auto v = _mm512_loadu_ps(&v_fp32_ptr[j]);
+ auto p = _mm512_loadu_ps(&param_fp32_ptr[j]);
+ m = _mm512_fmadd_ps(avx_beta1, m, _mm512_mul_ps(avx_beta1_1, g));
+ v = _mm512_fmadd_ps(avx_beta2, v, _mm512_mul_ps(avx_beta2_1, _mm512_mul_ps(g, g)));
+ if (delta_info_ptr != NULL){
+ auto delta_512 = _mm512_add_ps(
+ _mm512_div_ps(
+ _mm512_div_ps(m, avx_bias_correction1), // m / bias_correction1
+ _mm512_add_ps(_mm512_sqrt_ps(_mm512_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps
+ ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps)
+ _mm512_mul_ps(avx_weight_decay, p) // weight_decay * p
+ ); // delta = m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + weight_decay * p
+ sum_delta_i += _mm512_reduce_add_ps(delta_512);
+ sum_sq_delta_i += _mm512_reduce_add_ps(_mm512_mul_ps(delta_512, delta_512));
+ }
+ p = _mm512_fmadd_ps(avx_neg_lr, _mm512_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p
+ p = _mm512_fmadd_ps(
+ avx_neg_lr,
+ _mm512_div_ps(
+ _mm512_div_ps(m, avx_bias_correction1), // m / bias_correction1
+ _mm512_add_ps(
+ _mm512_sqrt_ps(_mm512_div_ps(v, avx_bias_correction2)),
+ avx_eps
+ ) // sqrt(v / bias_correction2) + eps
+ ),
+ p
+ ); // p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps)
+ _mm512_storeu_ps(&param_fp32_ptr[j], p);
+ _mm256_storeu_si256((__m256i*)&param_fp16_ptr[j], _mm512_cvtps_ph(p, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ _mm512_storeu_ps(&m_fp32_ptr[j], m);
+ _mm512_storeu_ps(&v_fp32_ptr[j], v);
+ }
+ }
+ if (delta_info_ptr != NULL){
+ delta_mutex.lock();
+ sum_delta += sum_delta_i;
+ sum_sq_delta += sum_sq_delta_i;
+ delta_mutex.unlock();
+ }
+ });
+ if (delta_info_ptr != NULL){
+ delta_info_ptr[0] = sum_delta / n;
+ delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2
+ delta_info_ptr[2] = sum_delta;
+ delta_info_ptr[3] = sum_sq_delta;
+ }
+}
+
+void adam_cpu_fp16_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_fp16,
+ std::uintptr_t delta_info,
+ std::uintptr_t g_fp16,
+ std::uintptr_t m_fp32,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+) {
+ auto delta_info_ptr = reinterpret_cast<float*>(delta_info);
+ auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+ auto m_fp32_ptr = reinterpret_cast<float*>(m_fp32);
+ auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
+ auto param_fp16_ptr = reinterpret_cast<uint16_t*>(param_fp16);
+ auto g_fp16_ptr = reinterpret_cast<uint16_t*>(g_fp16);
+ int cpu_level = get_cpu_level();
+ if (cpu_level == 0 ){
+ adam_cpu_0(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+ }else if(cpu_level == 1){
+ adam_cpu_1(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+ }else{
+ adam_cpu_2(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+ }
+}
+
+void adam_cpu_bf16_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_bf16,
+ std::uintptr_t delta_info,
+ std::uintptr_t g_bf16,
+ std::uintptr_t m_fp32,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+) {
+ auto delta_info_ptr = reinterpret_cast<float*>(delta_info);
+ auto m_fp32_ptr = reinterpret_cast<float*>(m_fp32);
+ auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
+ auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+ auto param_bf16_ptr = reinterpret_cast<uint16_t*>(param_bf16);
+ auto g_bf16_ptr = reinterpret_cast<uint16_t*>(g_bf16);
+ adam_cpu_bf16_0(n, param_fp32_ptr, param_bf16_ptr, delta_info_ptr, g_bf16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+}
diff --git a/examples/BMTrain/csrc/include/bind.hpp b/examples/BMTrain/csrc/include/bind.hpp
new file mode 100644
index 00000000..3ff967fd
--- /dev/null
+++ b/examples/BMTrain/csrc/include/bind.hpp
@@ -0,0 +1,111 @@
+#include <pybind11/pybind11.h>
+#include "nccl.hpp"
+#include "adam_cpu.hpp"
+
+int is_bf16_supported();
+
+void has_nan_inf_fp16_launcher(int32_t n, std::uintptr_t g_fp16, std::uintptr_t mid, std::uintptr_t out, std::uintptr_t stream);
+void has_nan_inf_bf16_launcher(int32_t n, std::uintptr_t g_bf16, std::uintptr_t mid, std::uintptr_t out, std::uintptr_t stream);
+
+void fp16_from_fp32_value_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_fp16
+);
+void bf16_from_fp32_value_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_bf16
+);
+void cross_entropy_forward_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t input,
+ std::uintptr_t target,
+ std::uintptr_t softmax,
+ std::uintptr_t output,
+ int32_t ignore_index,
+ std::uintptr_t stream
+);
+void cross_entropy_backward_inplace_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t grad_output,
+ std::uintptr_t target,
+ std::uintptr_t x,
+ int32_t ignore_index,
+ std::uintptr_t stream
+);
+void cross_entropy_forward_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t input,
+ std::uintptr_t target,
+ std::uintptr_t softmax,
+ std::uintptr_t output,
+ int32_t ignore_index,
+ std::uintptr_t stream
+);
+void cross_entropy_backward_inplace_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t grad_output,
+ std::uintptr_t target,
+ std::uintptr_t x,
+ int32_t ignore_index,
+ std::uintptr_t stream
+);
+void fused_sumexp_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+);
+void fused_sumexp_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+);
+void fused_softmax_inplace_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+);
+void fused_softmax_inplace_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+);
+void adam_fp16_launcher(
+ int n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_fp16,
+ std::uintptr_t g_fp16,
+ std::uintptr_t m_fp16,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2,
+ uintptr_t stream
+);
+void adam_bf16_launcher(
+ int n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_bf16,
+ std::uintptr_t g_bf16,
+ std::uintptr_t m_fp32,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2,
+ uintptr_t stream
+);
diff --git a/examples/BMTrain/csrc/include/cpu_info.h b/examples/BMTrain/csrc/include/cpu_info.h
new file mode 100644
index 00000000..53ed48f8
--- /dev/null
+++ b/examples/BMTrain/csrc/include/cpu_info.h
@@ -0,0 +1,38 @@
+#include <cpuid.h>
+
+static void cpuid(int info[4], int InfoType){
+ __cpuid_count(InfoType, 0, info[0], info[1], info[2], info[3]);
+}
+
+int get_cpu_level() {
+ // SIMD: 128-bit
+ bool HW_F16C = false;
+
+ // SIMD: 256-bit
+ bool HW_AVX = false;
+ bool HW_FMA = false;
+
+ // SIMD: 512-bit
+ bool HW_AVX512F = false; // AVX512 Foundation
+
+ int info[4];
+ cpuid(info, 0);
+ int nIds = info[0];
+
+ // Detect Features
+ if (nIds >= 0x00000001){
+ cpuid(info,0x00000001);
+ HW_AVX = (info[2] & ((int)1 << 28)) != 0;
+ HW_FMA = (info[2] & ((int)1 << 12)) != 0;
+ HW_F16C = (info[2] & ((int)1 << 29)) != 0;
+ }
+ if (nIds >= 0x00000007){
+ cpuid(info,0x00000007);
+ HW_AVX512F = (info[1] & ((int)1 << 16)) != 0;
+ }
+
+ int ret = 0;
+ if (HW_AVX && HW_FMA && HW_F16C) ret = 1;
+ if (HW_AVX512F) ret = 2;
+ return ret;
+}
diff --git a/examples/BMTrain/csrc/include/nccl.hpp b/examples/BMTrain/csrc/include/nccl.hpp
new file mode 100644
index 00000000..bba0278b
--- /dev/null
+++ b/examples/BMTrain/csrc/include/nccl.hpp
@@ -0,0 +1,188 @@
+#include <pybind11/pybind11.h>
+#include <cstring>
+#include <stdexcept>
+
+namespace py = pybind11;
+#include <nccl.h>
+
+void checkNCCLStatus(ncclResult_t result) {
+ if (result == ncclSuccess) return;
+ throw std::logic_error(
+ std::string("NCCL Error: ") +
+ ncclGetErrorString(result)
+ );
+}
+
+py::bytes pyNCCLGetUniqueID() {
+ ncclUniqueId uniqueID;
+ checkNCCLStatus(ncclGetUniqueId(&uniqueID));
+ return py::bytes(uniqueID.internal, NCCL_UNIQUE_ID_BYTES);
+}
+
+std::uintptr_t pyNCCLCommInitRank(py::bytes byteUniqueID, int world_size, int rank) {
+ ncclUniqueId uniqueID;
+ std::memcpy(uniqueID.internal, std::string(byteUniqueID).c_str(), NCCL_UNIQUE_ID_BYTES);
+ ncclComm_t comm;
+ checkNCCLStatus(ncclCommInitRank(&comm, world_size, uniqueID, rank));
+ return reinterpret_cast<std::uintptr_t>(comm);
+}
+
+void pyNCCLCommDestroy(std::uintptr_t ptrcomm) {
+ ncclComm_t comm = reinterpret_cast<ncclComm_t>(ptrcomm);
+ checkNCCLStatus(ncclCommDestroy(comm));
+}
+
+void pyNCCLAllGather(
+ std::uintptr_t sendbuff,
+ std::uintptr_t recvbuff,
+ size_t sendcount,
+ int datatype,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclAllGather(
+ reinterpret_cast<void*>(sendbuff),
+ reinterpret_cast<void*>(recvbuff),
+ sendcount,
+ static_cast<ncclDataType_t>(datatype),
+ reinterpret_cast<ncclComm_t>(comm),
+ reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+
+void pyNCCLAllReduce(
+ std::uintptr_t sendbuff,
+ std::uintptr_t recvbuff,
+ size_t count,
+ int data_type,
+ int op,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclAllReduce(
+ reinterpret_cast<void*>(sendbuff),
+ reinterpret_cast<void*>(recvbuff),
+ count,
+ static_cast<ncclDataType_t>(data_type),
+ static_cast<ncclRedOp_t>(op),
+ reinterpret_cast<ncclComm_t>(comm),
+ reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+
+void pyNCCLBroadcast(
+ std::uintptr_t sendbuff,
+ std::uintptr_t recvbuff,
+ size_t count,
+ int datatype,
+ int root,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclBroadcast(
+ reinterpret_cast<void*>(sendbuff),
+ reinterpret_cast<void*>(recvbuff),
+ count,
+ static_cast<ncclDataType_t>(datatype),
+ root,
+ reinterpret_cast<ncclComm_t>(comm),
+ reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+
+void pyNCCLReduce(
+ std::uintptr_t sendbuff,
+ std::uintptr_t recvbuff,
+ size_t count,
+ int datatype,
+ int op,
+ int root,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclReduce(
+ reinterpret_cast<void*>(sendbuff),
+ reinterpret_cast<void*>(recvbuff),
+ count,
+ static_cast<ncclDataType_t>(datatype),
+ static_cast<ncclRedOp_t>(op),
+ root,
+ reinterpret_cast<ncclComm_t>(comm),
+ reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+
+void pyNCCLReduceScatter(
+ std::uintptr_t sendbuff,
+ std::uintptr_t recvbuff,
+ size_t recvcount,
+ int datatype,
+ int op,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclReduceScatter(
+ reinterpret_cast<void*>(sendbuff),
+ reinterpret_cast<void*>(recvbuff),
+ recvcount,
+ static_cast<ncclDataType_t>(datatype),
+ static_cast<ncclRedOp_t>(op),
+ reinterpret_cast<ncclComm_t>(comm),
+ reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+void pyNCCLSend(
+ std::uintptr_t sendbuff,
+ size_t sendcount,
+ int data_type,
+ int peer,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclSend(
+ reinterpret_cast<void*>(sendbuff),
+ sendcount,
+ static_cast<ncclDataType_t>(data_type),
+ peer,
+ reinterpret_cast<ncclComm_t>(comm),
+ reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+void pyNCCLRecv(
+ std::uintptr_t recvbuff,
+ size_t recvcount,
+ int data_type,
+ int peer,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclRecv(
+ reinterpret_cast<void*>(recvbuff),
+ recvcount,
+ static_cast<ncclDataType_t>(data_type),
+ peer,
+ reinterpret_cast<ncclComm_t>(comm),
+ reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+void pyNCCLGroupStart() {
+ checkNCCLStatus(ncclGroupStart());
+}
+
+void pyNCCLGroupEnd() {
+ checkNCCLStatus(ncclGroupEnd());
+}
+int pyNCCLCommCount(
+ std::uintptr_t comm
+){
+ int res;
+ checkNCCLStatus(ncclCommCount(reinterpret_cast<ncclComm_t>(comm), &res));
+ return res;
+}
+int pyNCCLCommUserRank(
+ std::uintptr_t comm
+){
+ int rank;
+ checkNCCLStatus(ncclCommUserRank(reinterpret_cast<ncclComm_t>(comm), &rank));
+ return rank;
+}
diff --git a/examples/BMTrain/doc_requirements.txt b/examples/BMTrain/doc_requirements.txt
new file mode 100644
index 00000000..79d22ca0
--- /dev/null
+++ b/examples/BMTrain/doc_requirements.txt
@@ -0,0 +1,5 @@
+sphinx>=4.0.0
+recommonmark
+sphinx_markdown_tables
+sphinx_rtd_theme>=0.3.0
+torch
\ No newline at end of file
diff --git a/examples/BMTrain/docs/Makefile b/examples/BMTrain/docs/Makefile
new file mode 100644
index 00000000..4f2fbe66
--- /dev/null
+++ b/examples/BMTrain/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source-en
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/examples/BMTrain/docs/UPDATE_0.2.0.md b/examples/BMTrain/docs/UPDATE_0.2.0.md
new file mode 100644
index 00000000..92819afd
--- /dev/null
+++ b/examples/BMTrain/docs/UPDATE_0.2.0.md
@@ -0,0 +1,79 @@
+# Update Log 0.2.0
+
+## What's New
+
+### 1. Added an `Optimizer Manager` to support various optimizer algorithms.
+
+Before 0.2.0, the `optimizer` was strongly coupled to the "loss scaler". As a result, users could not use multiple optimizers at the same time when training a model in fp16.
+
+**======= Before 0.2.0 =======**
+
+```python
+for iteration in range(1000):
+ # zero grad
+ optimizer.zero_grad()
+
+ # ...
+ # loss scale and backward
+ loss = optimizer.loss_scale(loss)
+ loss.backward()
+
+ # optimizer step
+ bmtrain.optim_step(optimizer, lr_scheduler)
+```
+
+`bmtrain.optim_step` allows only one `optimizer` and at most one `lr_scheduler`, which cannot handle more complex scenarios.
+
+
+**======= After 0.2.0 =======**
+
+```python
+# create a new instance of optimizer manager
+optim_manager = bmtrain.optim.OptimManager(loss_scale=1024)
+# let optim_manager handle all the optimizers and (optionally) their corresponding lr_schedulers
+optim_manager.add_optimizer(optimizer, lr_scheduler)
+# add_optimizer can be called multiple times to add other optimizers.
+
+for iteration in range(1000):
+ # zero grad
+ optim_manager.zero_grad() # calling zero_grad for each optimizer
+
+ # ...
+ # loss scale and backward
+ optim_manager.backward(loss)
+
+ # optimizer step
+ optim_manager.step()
+```
+
+Starting from BMTrain 0.2.0, we provide `OptimManager` to manage optimizers and loss scales.
+`OptimManager` supports managing multiple optimizers and lr_schedulers at the same time, and allows setting the loss scale independently.
+`OptimManager` can also manage PyTorch native optimizers such as SGD and AdamW, for example:
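+
+A minimal sketch of managing a BMTrain optimizer and a native PyTorch optimizer under one manager; the parameter groups (`model`, `head`) and the split between them are illustrative assumptions:
+
+```python
+import torch
+import bmtrain as bmt
+
+optim_manager = bmt.optim.OptimManager(loss_scale=1024)
+# BMTrain's offloaded Adam for the bulk of the parameters (illustrative split)
+optim_manager.add_optimizer(bmt.optim.AdamOffloadOptimizer(model.parameters(), lr=1e-4))
+# a PyTorch native optimizer is managed the same way
+optim_manager.add_optimizer(torch.optim.SGD(head.parameters(), lr=1e-2))
+```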
+
+### 2. Pipeline Parallelism
+
+In this version, BMTrain has added a new kind of parallel algorithm: pipeline parallelism.
+To enable pipeline parallelism, one line of code needs to be modified.
+
+**======= ZeRO =======**
+```python
+layers = bmt.TransformerBlockList([
+ # ...
+])
+```
+
+**======= Pipeline =======**
+```python
+layers = bmt.PipelineTransformerBlockList([
+ # ...
+])
+```
+
+Replacing `TransformerBlockList` with `PipelineTransformerBlockList` switches the parallel algorithm from ZeRO to pipeline parallelism.
+The number of pipeline stages can be set by passing the `pipe_size` parameter to `bmtrain.init_distributed`, as sketched below.
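+
+A minimal sketch of enabling pipeline parallelism; the value `pipe_size=4` is only an illustrative choice:
+
+```python
+import bmtrain as bmt
+
+# split the model into 4 pipeline stages (illustrative value)
+bmt.init_distributed(pipe_size=4)
+
+layers = bmt.PipelineTransformerBlockList([
+    # ... transformer blocks, as in the ZeRO example above ...
+])
+```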
+
+### 3. Others
+
+* Supports BF16.
+* Tensors recorded in the inspector support backward propagation.
+* Adds new tests.
diff --git a/examples/BMTrain/docs/UPDATE_0.2.3.md b/examples/BMTrain/docs/UPDATE_0.2.3.md
new file mode 100644
index 00000000..e95c6867
--- /dev/null
+++ b/examples/BMTrain/docs/UPDATE_0.2.3.md
@@ -0,0 +1,26 @@
+# Update Log 0.2.3
+
+**Full Changelog**: https://github.com/OpenBMB/BMTrain/compare/0.2.0...0.2.3
+
+
+## What's New
+
+### 1. Get rid of torch cpp extension when compiling
+
+Before 0.2.3, installing BMTrain required the torch C++ extension, which was inconvenient for some users (it requires a CUDA Runtime that matches torch). Starting with 0.2.3, BMTrain no longer depends on the torch C++ extension at compile time, which makes installing BMTrain from source more convenient.
+Just run `pip install .` to install BMTrain from source.
+
+### 2. CICD
+
+In 0.2.3, we bring GitHub Actions CI/CD to BMTrain. The CI/CD pipeline now runs on GitHub to ensure code quality: it runs the test cases and compiles the source code into wheel packages.
+
+### 3. Loss scale management
+
+In 0.2.3, we add minimum and maximum bounds to the loss scale manager. The manager still adjusts the loss scale dynamically, but now keeps it within the configured min and max values, which helps users avoid a loss scale that grows or shrinks without bound.
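+
+A minimal sketch of how the bounds might be configured; the keyword names `min_loss_scale` and `max_loss_scale` are assumptions for illustration and may differ from the actual interface:
+
+```python
+import bmtrain as bmt
+
+optim_manager = bmt.optim.OptimManager(
+    loss_scale=1024,        # initial loss scale
+    min_loss_scale=1,       # assumed name: lower bound for dynamic loss scaling
+    max_loss_scale=2**20,   # assumed name: upper bound for dynamic loss scaling
+)
+```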
+
+
+### 4. Others
+
+* Fix `bmt.load(model)` OOM with torch >= 1.12
+* `AdamOffloadOptimizer` now chooses the AVX flag automatically at runtime
+* Now BMTrain is fully compatible with torch 2.0
diff --git a/examples/BMTrain/docs/UPDATE_1.0.0.md b/examples/BMTrain/docs/UPDATE_1.0.0.md
new file mode 100644
index 00000000..da9fe86e
--- /dev/null
+++ b/examples/BMTrain/docs/UPDATE_1.0.0.md
@@ -0,0 +1,72 @@
+# Update Log 1.0.0
+
+**Full Changelog**: https://github.com/OpenBMB/BMTrain/compare/0.2.3...1.0.0
+
+## What's New
+
+### 1. Using PyTorch's hook mechanism to refactor the ZeRO, checkpointing, pipeline, and communication implementation
+
+Now users can specify the ZeRO level of each `bmt.CheckpointBlock`.
+
+**======= Before 1.0.0 =======**
+
+```python
+import bmtrain as bmt
+bmt.init_distributed(zero_level=3)
+
+```
+
+The ZeRO level can only be set globally, and computation checkpointing cannot be disabled.
+A `bmt.TransformerBlockList` has to be called through its blocklist forward instead of in a loop.
+
+**======= After 1.0.0 =======**
+
+```python
+import bmtrain as bmt
+bmt.init_distributed()
+# construct block
+class Transformer(bmt.DistributedModule):
+ def __init__(self,
+ num_layers : int) -> None:
+ super().__init__()
+
+ self.transformers = bmt.TransformerBlockList([
+ bmt.Block(
+ TransformerEncoder(
+ dim_model, dim_head, num_heads, dim_ff, bias, dtype
+ ), use_checkpoint=True, zero_level=3
+ )
+ for _ in range(num_layers)
+ ])
+
+ def forward(self, x):
+ # return self.transformers(x)  # in v0.2.3, forward could only be called this way
+ for block in self.transformers:
+ x = block(x)
+ return x
+
+```
+
+You can specify the ZeRO level of each `bmt.CheckpointBlock` (an alias of `bmt.Block`), and computation checkpointing can be disabled by setting `use_checkpoint=False`. A `bmt.TransformerBlockList` can now be called in a loop.
+
+
+### 2. Add Bf16 support
+
+Now BMTrain supports BF16 training. Simply use `dtype=torch.bfloat16` when constructing your model and BMTrain will handle the rest.
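+
+A minimal sketch, assuming the usual distributed launch; `torch.nn.Linear` is used only as a stand-in for a real transformer layer:
+
+```python
+import torch
+import bmtrain as bmt
+
+bmt.init_distributed()
+
+# construct the module directly in bfloat16; BMTrain handles the mixed-precision details
+layer = bmt.CheckpointBlock(torch.nn.Linear(1024, 1024, dtype=torch.bfloat16))
+```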
+
+### 3. Tensor parallel implementation
+
+For this part, BMTrain provides a series of parallel ops for tensor parallel implementations, including `bmt.nn.OpParallelLinear` and `bmt.nn.VPEmbedding`. We also provide a tensor parallel training example among our training examples. You can simply use `bmt.init_distributed(tp_size=4)` to enable 4-way tensor parallel training.
+
+### 4. `AdamOffloadOptimizer` can save whole gathered state
+
+Now `AdamOffloadOptimizer` can save the whole gathered state of the optimizer, which can be used to resume training from the saved state. For better performance, we provide an asynchronous way to save the state_dict, overlapping I/O and computation.
+```python
+import bmtrain as bmt
+# you can enable this feature in two ways: via OptimManager's or the optimizer's interface
+global_ckpt = optim_manager.state_dict(gather_opt=True)
+global_ckpt = optimizer.state_dict(gather=True)
+```
+### 5. Others
+
+* New tests for the new version of BMTrain
\ No newline at end of file
diff --git a/examples/BMTrain/docs/logo.png b/examples/BMTrain/docs/logo.png
new file mode 100644
index 00000000..2dd2f1cb
Binary files /dev/null and b/examples/BMTrain/docs/logo.png differ
diff --git a/examples/BMTrain/docs/make.bat b/examples/BMTrain/docs/make.bat
new file mode 100644
index 00000000..6fcf05b4
--- /dev/null
+++ b/examples/BMTrain/docs/make.bat
@@ -0,0 +1,35 @@
+@ECHO OFF
+
+pushd %~dp0
+
+REM Command file for Sphinx documentation
+
+if "%SPHINXBUILD%" == "" (
+ set SPHINXBUILD=sphinx-build
+)
+set SOURCEDIR=source
+set BUILDDIR=build
+
+if "%1" == "" goto help
+
+%SPHINXBUILD% >NUL 2>NUL
+if errorlevel 9009 (
+ echo.
+ echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
+ echo.installed, then set the SPHINXBUILD environment variable to point
+ echo.to the full path of the 'sphinx-build' executable. Alternatively you
+ echo.may add the Sphinx directory to PATH.
+ echo.
+ echo.If you don't have Sphinx installed, grab it from
+ echo.https://www.sphinx-doc.org/
+ exit /b 1
+)
+
+%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+goto end
+
+:help
+%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
+
+:end
+popd
diff --git a/examples/BMTrain/docs/source-en/_static/css/custom.css b/examples/BMTrain/docs/source-en/_static/css/custom.css
new file mode 100644
index 00000000..1e3a643c
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/_static/css/custom.css
@@ -0,0 +1,124 @@
+a,
+.wy-menu-vertical header,
+.wy-menu-vertical p.caption,
+.wy-nav-top .fa-bars,
+.wy-menu-vertical a:hover,
+
+.rst-content code.literal, .rst-content tt.literal
+
+{
+ color: #315EFE !important;
+}
+
+/* inspired by sphinx press theme */
+.wy-menu.wy-menu-vertical li.toctree-l1.current > a {
+ border-left: solid 8px #315EFE !important;
+ border-top: none;
+ border-bottom: none;
+}
+
+.wy-menu.wy-menu-vertical li.toctree-l1.current > ul {
+ border-left: solid 8px #315EFE !important;
+}
+/* inspired by sphinx press theme */
+
+.wy-nav-side {
+ color: unset !important;
+ background: unset !important;
+ border-right: solid 1px #ccc !important;
+}
+
+.wy-side-nav-search,
+.wy-nav-top,
+.wy-menu-vertical li,
+.wy-menu-vertical li a:hover,
+.wy-menu-vertical li a
+{
+ background: unset !important;
+}
+
+.wy-menu-vertical li.current a {
+ border-right: unset !important;
+}
+
+.wy-side-nav-search div,
+.wy-menu-vertical a {
+ color: #404040 !important;
+}
+
+.wy-menu-vertical button.toctree-expand {
+ color: #333 !important;
+}
+
+.wy-nav-content {
+ max-width: unset;
+}
+
+.rst-content {
+ max-width: 900px;
+}
+
+.wy-nav-content .icon-home:before {
+ content: "Docs";
+}
+
+.wy-side-nav-search .icon-home:before {
+ content: "";
+}
+
+dl.field-list {
+ display: block !important;
+}
+
+dl.field-list > dt:after {
+ content: "" !important;
+}
+
+:root {
+ --dark-blue: #3260F7;
+ --light-blue: rgba(194, 233, 248, 0.1) ;
+}
+
+dl.field-list > dt {
+ display: table;
+ padding-left: 6px !important;
+ padding-right: 6px !important;
+ margin-bottom: 4px !important;
+ padding-bottom: 1px !important;
+ background: var(--light-blue);
+ border-left: solid 2px var(--dark-blue);
+}
+
+
+dl.py.class>dt
+{
+ color: rgba(17, 16, 17, 0.822) !important;
+ background: var(--light-blue) !important;
+ border-top: solid 2px var(--dark-blue) !important;
+}
+
+dl.py.method>dt
+{
+ background: var(--light-blue) !important;
+ border-left: solid 2px var(--dark-blue) !important;
+}
+
+dl.py.attribute>dt,
+dl.py.property>dt
+{
+ background: var(--light-blue) !important;
+ border-left: solid 2px var(--dark-blue) !important;
+}
+
+.fa-plus-square-o::before, .wy-menu-vertical li button.toctree-expand::before,
+.fa-minus-square-o::before, .wy-menu-vertical li.current > a button.toctree-expand::before, .wy-menu-vertical li.on a button.toctree-expand::before
+{
+ content: "";
+}
+
+.rst-content .viewcode-back,
+.rst-content .viewcode-link
+{
+ color:#58b5cc;
+ font-size: 120%;
+}
\ No newline at end of file
diff --git a/examples/BMTrain/docs/source-en/_static/js/custom.js b/examples/BMTrain/docs/source-en/_static/js/custom.js
new file mode 100644
index 00000000..489b7d5c
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/_static/js/custom.js
@@ -0,0 +1,7 @@
+document.addEventListener("DOMContentLoaded", function(event) {
+ document.querySelectorAll(".wy-menu.wy-menu-vertical > ul.current > li > a").forEach(a => a.addEventListener("click", e=>{
+ const f = document.querySelector(".wy-menu.wy-menu-vertical > ul.current > li > ul");
+ if (f.style.display=='none') { f.style.display='block'; } else f.style.display = 'none'
+ }));
+ document.querySelectorAll(".headerlink").forEach(a => a.text="\u{1F517}");
+});
\ No newline at end of file
diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.benchmark.rst_bk b/examples/BMTrain/docs/source-en/api/bmtrain.benchmark.rst_bk
new file mode 100644
index 00000000..f8b2902d
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/api/bmtrain.benchmark.rst_bk
@@ -0,0 +1,53 @@
+bmtrain.benchmark package
+=========================
+
+Submodules
+----------
+
+bmtrain.benchmark.all\_gather module
+------------------------------------
+
+.. automodule:: bmtrain.benchmark.all_gather
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.benchmark.reduce\_scatter module
+----------------------------------------
+
+.. automodule:: bmtrain.benchmark.reduce_scatter
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.benchmark.send\_recv module
+-----------------------------------
+
+.. automodule:: bmtrain.benchmark.send_recv
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.benchmark.shape module
+------------------------------
+
+.. automodule:: bmtrain.benchmark.shape
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.benchmark.utils module
+------------------------------
+
+.. automodule:: bmtrain.benchmark.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.benchmark
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.distributed.rst_bk b/examples/BMTrain/docs/source-en/api/bmtrain.distributed.rst_bk
new file mode 100644
index 00000000..ef41db07
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/api/bmtrain.distributed.rst_bk
@@ -0,0 +1,21 @@
+bmtrain.distributed package
+===========================
+
+Submodules
+----------
+
+bmtrain.distributed.ops module
+------------------------------
+
+.. automodule:: bmtrain.distributed.ops
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.distributed
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.inspect.rst b/examples/BMTrain/docs/source-en/api/bmtrain.inspect.rst
new file mode 100644
index 00000000..c57195ad
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/api/bmtrain.inspect.rst
@@ -0,0 +1,37 @@
+bmtrain.inspect package
+=======================
+
+Submodules
+----------
+
+bmtrain.inspect.format module
+-----------------------------
+
+.. automodule:: bmtrain.inspect.format
+ :members: format_summary
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.inspect.model module
+----------------------------
+
+.. automodule:: bmtrain.inspect.model
+ :members: inspect_model
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.inspect.tensor module
+-----------------------------
+
+.. automodule:: bmtrain.inspect.tensor
+ :members: inspect_tensor, InspectTensor
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.inspect
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.loss.rst b/examples/BMTrain/docs/source-en/api/bmtrain.loss.rst
new file mode 100644
index 00000000..03b65646
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/api/bmtrain.loss.rst
@@ -0,0 +1,21 @@
+bmtrain.loss package
+====================
+
+Submodules
+----------
+
+bmtrain.loss.cross\_entropy module
+----------------------------------
+
+.. automodule:: bmtrain.loss.cross_entropy
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.loss
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.lr_scheduler.rst b/examples/BMTrain/docs/source-en/api/bmtrain.lr_scheduler.rst
new file mode 100644
index 00000000..0ba033af
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/api/bmtrain.lr_scheduler.rst
@@ -0,0 +1,61 @@
+bmtrain.lr\_scheduler package
+=============================
+
+Submodules
+----------
+
+bmtrain.lr\_scheduler.cosine module
+-----------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.cosine
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.lr\_scheduler.exponential module
+----------------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.exponential
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.lr\_scheduler.linear module
+-----------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.linear
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.lr\_scheduler.no\_decay module
+--------------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.no_decay
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.lr\_scheduler.noam module
+---------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.noam
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.lr\_scheduler.warmup module
+-----------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.warmup
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.lr_scheduler
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.nccl.rst_bk b/examples/BMTrain/docs/source-en/api/bmtrain.nccl.rst_bk
new file mode 100644
index 00000000..3755d9ef
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/api/bmtrain.nccl.rst_bk
@@ -0,0 +1,21 @@
+bmtrain.nccl package
+====================
+
+Submodules
+----------
+
+bmtrain.nccl.enums module
+-------------------------
+
+.. automodule:: bmtrain.nccl.enums
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.nccl
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.nn.rst b/examples/BMTrain/docs/source-en/api/bmtrain.nn.rst
new file mode 100644
index 00000000..8e2a531f
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/api/bmtrain.nn.rst
@@ -0,0 +1,53 @@
+bmtrain.nn package
+==================
+
+Submodules
+----------
+
+bmtrain.nn.column\_parallel\_linear module
+------------------------------------------
+
+.. automodule:: bmtrain.nn.column_parallel_linear
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.nn.linear module
+------------------------
+
+.. automodule:: bmtrain.nn.linear
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.nn.parallel\_embedding module
+-------------------------------------
+
+.. automodule:: bmtrain.nn.parallel_embedding
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.nn.parallel\_linear\_func module
+----------------------------------------
+
+.. automodule:: bmtrain.nn.parallel_linear_func
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.nn.row\_parallel\_linear module
+---------------------------------------
+
+.. automodule:: bmtrain.nn.row_parallel_linear
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.nn
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.optim.rst b/examples/BMTrain/docs/source-en/api/bmtrain.optim.rst
new file mode 100644
index 00000000..2d47a3dd
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/api/bmtrain.optim.rst
@@ -0,0 +1,37 @@
+bmtrain.optim package
+=====================
+
+Submodules
+----------
+
+bmtrain.optim.adam module
+-------------------------
+
+.. automodule:: bmtrain.optim.adam
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.optim.adam\_offload module
+----------------------------------
+
+.. automodule:: bmtrain.optim.adam_offload
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.optim.optim\_manager module
+-----------------------------------
+
+.. automodule:: bmtrain.optim.optim_manager
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.optim
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source-en/api/bmtrain.rst b/examples/BMTrain/docs/source-en/api/bmtrain.rst
new file mode 100644
index 00000000..8445e5f0
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/api/bmtrain.rst
@@ -0,0 +1,140 @@
+bmtrain package
+===============
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ bmtrain.benchmark
+ bmtrain.distributed
+ bmtrain.inspect
+ bmtrain.loss
+ bmtrain.lr_scheduler
+ bmtrain.nccl
+ bmtrain.nn
+ bmtrain.optim
+
+Submodules
+----------
+
+bmtrain.block\_layer module
+---------------------------
+
+.. automodule:: bmtrain.block_layer
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. bmtrain.debug module
+.. --------------------
+
+.. .. automodule:: bmtrain.debug
+.. :members:
+.. :undoc-members:
+.. :show-inheritance:
+
+bmtrain.global\_var module
+--------------------------
+
+.. automodule:: bmtrain.global_var
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. bmtrain.hook\_func module
+.. -------------------------
+
+.. .. automodule:: bmtrain.hook_func
+.. :members:
+.. :undoc-members:
+.. :show-inheritance:
+
+bmtrain.init module
+-------------------
+
+.. automodule:: bmtrain.init
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.layer module
+--------------------
+
+.. automodule:: bmtrain.layer
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.param\_init module
+--------------------------
+
+.. automodule:: bmtrain.param_init
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.parameter module
+------------------------
+
+.. automodule:: bmtrain.parameter
+ :members: DistributedParameter, ParameterInitializer
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.pipe\_layer module
+--------------------------
+
+.. automodule:: bmtrain.pipe_layer
+ :members: PipelineTransformerBlockList
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.store module
+--------------------
+
+.. automodule:: bmtrain.store
+ :members: save, load
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.synchronize module
+--------------------------
+
+.. automodule:: bmtrain.synchronize
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.utils module
+--------------------
+
+.. automodule:: bmtrain.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.wrapper module
+----------------------
+
+.. automodule:: bmtrain.wrapper
+ :members: BMTrainModelWrapper
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.zero\_context module
+----------------------------
+
+.. automodule:: bmtrain.zero_context
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source-en/api/modules.rst b/examples/BMTrain/docs/source-en/api/modules.rst
new file mode 100644
index 00000000..4350b5d7
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/api/modules.rst
@@ -0,0 +1,7 @@
+bmtrain
+=======
+
+.. toctree::
+ :maxdepth: 4
+
+ bmtrain
diff --git a/examples/BMTrain/docs/source-en/conf.py b/examples/BMTrain/docs/source-en/conf.py
new file mode 100644
index 00000000..6351767a
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/conf.py
@@ -0,0 +1,79 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+
+sys.path.insert(0, os.path.abspath("../../.."))
+
+import recommonmark
+from recommonmark.transform import AutoStructify
+
+
+# -- Project information -----------------------------------------------------
+
+project = "BMTrain"
+copyright = "2022, OpenBMB"
+author = "BMTrain Team"
+autodoc_mock_imports = [
+ "numpy",
+ "tensorboard",
+ "bmtrain.nccl._C",
+ "bmtrain.optim._cpu",
+ "bmtrain.optim._cuda",
+ "bmtrain.loss._cuda",
+]
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ "sphinx.ext.autodoc",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.mathjax",
+ "recommonmark",
+ "sphinx_markdown_tables",
+]
+
+source_suffix = [".rst", ".md"]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ["_templates"]
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = "en"
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = []
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = "sphinx_rtd_theme"
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ["_static"]
+# html_stype="css/custom.css"
+html_css_files = ["css/custom.css"]
+html_js_files = ["js/custom.js"]
diff --git a/examples/BMTrain/docs/source-en/index.rst b/examples/BMTrain/docs/source-en/index.rst
new file mode 100644
index 00000000..a2ec24f9
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/index.rst
@@ -0,0 +1,39 @@
+.. bmtrain-doc documentation master file, created by
+ sphinx-quickstart on Sat Mar 5 17:05:02 2022.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+BMTrain Documentation
+===============================
+
+**BMTrain** is an efficient toolkit for training large models with tens of billions of parameters. It keeps the training code simple while the model is trained in a distributed way.
+
+=======================================
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Getting Started
+
+ notes/installation.md
+ notes/quickstart.md
+ notes/tech.md
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Package Reference
+
+ api/bmtrain.rst
+ api/bmtrain.benchmark.rst
+ api/bmtrain.distributed.rst
+ api/bmtrain.inspect.rst
+ api/bmtrain.loss.rst
+ api/bmtrain.lr_scheduler.rst
+ api/bmtrain.nccl.rst
+ api/bmtrain.optim.rst
+
+API
+==================
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
diff --git a/examples/BMTrain/docs/source-en/notes/image/ZeRO3.png b/examples/BMTrain/docs/source-en/notes/image/ZeRO3.png
new file mode 100644
index 00000000..38d52ed5
Binary files /dev/null and b/examples/BMTrain/docs/source-en/notes/image/ZeRO3.png differ
diff --git a/examples/BMTrain/docs/source-en/notes/image/communication_example.png b/examples/BMTrain/docs/source-en/notes/image/communication_example.png
new file mode 100644
index 00000000..0e549b08
Binary files /dev/null and b/examples/BMTrain/docs/source-en/notes/image/communication_example.png differ
diff --git a/examples/BMTrain/docs/source-en/notes/image/communication_fig.png b/examples/BMTrain/docs/source-en/notes/image/communication_fig.png
new file mode 100644
index 00000000..b7636240
Binary files /dev/null and b/examples/BMTrain/docs/source-en/notes/image/communication_fig.png differ
diff --git a/examples/BMTrain/docs/source-en/notes/image/cpu.png b/examples/BMTrain/docs/source-en/notes/image/cpu.png
new file mode 100644
index 00000000..451f397c
Binary files /dev/null and b/examples/BMTrain/docs/source-en/notes/image/cpu.png differ
diff --git a/examples/BMTrain/docs/source-en/notes/image/zero3_example.png b/examples/BMTrain/docs/source-en/notes/image/zero3_example.png
new file mode 100644
index 00000000..40f0e8cc
Binary files /dev/null and b/examples/BMTrain/docs/source-en/notes/image/zero3_example.png differ
diff --git a/examples/BMTrain/docs/source-en/notes/installation.md b/examples/BMTrain/docs/source-en/notes/installation.md
new file mode 100644
index 00000000..330bfb21
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/notes/installation.md
@@ -0,0 +1,45 @@
+# Installation
+
+## Install BMTrain
+
+### 1. From PyPI (Recommended)
+
+```shell
+$ pip install bmtrain
+```
+
+### 2. From Source
+
+```shell
+$ git clone https://github.com/OpenBMB/BMTrain.git
+$ cd BMTrain
+$ python3 setup.py install
+```
+
+## Compilation Options
+
+You can configure BMTrain's compilation options by setting environment variables (by default, the build automatically adapts to the compilation environment).
+
+### AVX Instructions
+
+* Force the use of AVX instructions: `BMT_AVX256=ON`
+* Force the use of AVX512 instructions: `BMT_AVX512=ON`
+
+### CUDA Compute Capability
+
+`TORCH_CUDA_ARCH_LIST=6.0 6.1 7.0 7.5 8.0+PTX`
+
+## Recommended Configuration
+
+* Network: Infiniband 100Gbps / RoCE 100Gbps
+* GPU: NVIDIA Tesla V100 / NVIDIA Tesla A100 / RTX 3090
+* CPU: a CPU that supports the AVX512 instruction set, 32 cores or above
+* RAM: 256GB or above
+
+## FAQ
+
+If the following error message is reported during compilation, try using a newer version of the gcc compiler.
+```
+error: invalid static_cast from type `const torch::OrderedDict<...>`
+```
+
diff --git a/examples/BMTrain/docs/source-en/notes/quickstart.md b/examples/BMTrain/docs/source-en/notes/quickstart.md
new file mode 100644
index 00000000..3ede3d13
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/notes/quickstart.md
@@ -0,0 +1,159 @@
+# Quick Start
+
+## Step 1: Initialize BMTrain
+
+Before you can use BMTrain, you need to initialize it at the beginning of your code. Just as PyTorch's distributed module requires calling **init_process_group** first, BMTrain requires calling **init_distributed** first.
+
+```python
+import bmtrain as bmt
+bmt.init_distributed(
+ seed=0,
+ # ...
+)
+```
+
+**NOTE:** Do not use PyTorch's distributed module and its associated communication functions when using BMTrain.
+
+## Step 2: Enable ZeRO-3 Optimization
+
+To enable ZeRO-3 optimization, you need to make some simple replacements to the original model's code.
+
+* `torch.nn.Module` -> `bmtrain.DistributedModule`
+* `torch.nn.Parameter` -> `bmtrain.DistributedParameter`
+
+And wrap the transformer blocks with `bmtrain.CheckpointBlock`.
+
+Here is an example.
+
+**Original**
+
+```python
+import torch
+class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ SomeTransformerBlock(),
+ SomeTransformerBlock(),
+ SomeTransformerBlock()
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**Replaced**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule): # changed here
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024)) # changed here
+ self.module_list = torch.nn.ModuleList([
+ bmt.CheckpointBlock(SomeTransformerBlock()), # changed here
+ bmt.CheckpointBlock(SomeTransformerBlock()), # changed here
+ bmt.CheckpointBlock(SomeTransformerBlock()) # changed here
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+## Step 3: Enable Communication Optimization
+
+
+To further reduce the communication overhead and overlap communication with computation, `TransformerBlockList` can be used.
+
+You can enable it by making the following substitutions in the code:
+
+* `torch.nn.ModuleList` -> `bmtrain.TransformerBlockList`
+* `for module in self.module_list: x = module(x, ...)` -> `x = self.module_list(x, ...)`
+
+**Original**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**Replaced**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = bmt.TransformerBlockList([ # changed here
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ x = self.module_list(x, 1, 2, 3) # changed here
+ return x
+
+```
+
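+With the model wrapped as above, the remaining pieces are an optimizer, a learning-rate scheduler and a loss function. The snippet below is only a minimal sketch of how the parts from the Package Reference (`bmtrain.optim`, `bmtrain.lr_scheduler`, `bmtrain.loss`) are typically wired together; the class names (`AdamOffloadOptimizer`, `OptimManager`, `Noam`, `FusedCrossEntropy`), the hyperparameters and `train_dataloader` are illustrative assumptions rather than a verbatim excerpt from the *examples* folder.
+
+```python
+import bmtrain as bmt
+
+bmt.init_distributed(seed=0)
+
+model = MyModule()            # a model built as in Steps 2 and 3; adapt the forward call to its signature
+bmt.init_parameters(model)    # initialize the distributed parameters
+
+loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100)
+optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2)
+lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=100, end_iter=1000)
+
+optim_manager = bmt.optim.OptimManager(loss_scale=1024)
+optim_manager.add_optimizer(optimizer, lr_scheduler)
+
+for iteration, (inputs, targets) in enumerate(train_dataloader):  # placeholder data source
+    logits = model(inputs)
+    loss = loss_func(logits.view(-1, logits.size(-1)), targets.view(-1))
+
+    optim_manager.zero_grad()      # zero the gradients of all managed optimizers
+    optim_manager.backward(loss)   # loss-scaled backward pass
+    optim_manager.step()           # optimizer and learning-rate scheduler step
+
+    bmt.print_rank("iteration {}: loss = {:.4f}".format(iteration, loss.item()))
+```
+
+In this sketch, the `OptimManager` instance owns the loss scaling and steps every optimizer and scheduler registered with it, so the backward pass and the update go through the manager rather than through the optimizer directly.
+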
+## Step 4: Launch Distributed Training
+
+BMTrain uses the same launch commands as PyTorch's distributed module.
+
+Choose one of the launchers below depending on your PyTorch version.
+
+* `${MASTER_ADDR}` means the IP address of the master node.
+* `${MASTER_PORT}` means the port of the master node.
+* `${NNODES}` means the total number of nodes.
+* `${GPU_PER_NODE}` means the number of GPUs per node.
+* `${NODE_RANK}` means the rank of this node.
+
+### torch.distributed.launch
+```shell
+$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py
+```
+
+### torchrun
+
+```shell
+$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py
+```
+
+
+For more information, please refer to the [documentation](https://pytorch.org/docs/stable/distributed.html#launch-utility).
+
+## Other Notes
+
+`BMTrain` makes low-level modifications to PyTorch, so if your program produces unexpected results, please report the details in an issue.
+
+For more examples, please refer to the *examples* folder.
+
diff --git a/examples/BMTrain/docs/source-en/notes/tech.md b/examples/BMTrain/docs/source-en/notes/tech.md
new file mode 100644
index 00000000..fc6c704a
--- /dev/null
+++ b/examples/BMTrain/docs/source-en/notes/tech.md
@@ -0,0 +1,11 @@
+# Introduction to Core Technology
+
+## ZeRO-3 Optimization
+
+
+## Overlap Communication and Computation
+
+
+## CPU Offload
+
+
diff --git a/examples/BMTrain/docs/source/_static/css/custom.css b/examples/BMTrain/docs/source/_static/css/custom.css
new file mode 100644
index 00000000..1e3a643c
--- /dev/null
+++ b/examples/BMTrain/docs/source/_static/css/custom.css
@@ -0,0 +1,124 @@
+a,
+.wy-menu-vertical header,
+.wy-menu-vertical p.caption,
+.wy-nav-top .fa-bars,
+.wy-menu-vertical a:hover,
+
+.rst-content code.literal, .rst-content tt.literal
+
+{
+ color: #315EFE !important;
+}
+
+/* inspired by sphinx press theme */
+.wy-menu.wy-menu-vertical li.toctree-l1.current > a {
+ border-left: solid 8px #315EFE !important;
+ border-top: none;
+ border-bottom: none;
+}
+
+.wy-menu.wy-menu-vertical li.toctree-l1.current > ul {
+ border-left: solid 8px #315EFE !important;
+}
+/* inspired by sphinx press theme */
+
+.wy-nav-side {
+ color: unset !important;
+ background: unset !important;
+ border-right: solid 1px #ccc !important;
+}
+
+.wy-side-nav-search,
+.wy-nav-top,
+.wy-menu-vertical li,
+.wy-menu-vertical li a:hover,
+.wy-menu-vertical li a
+{
+ background: unset !important;
+}
+
+.wy-menu-vertical li.current a {
+ border-right: unset !important;
+}
+
+.wy-side-nav-search div,
+.wy-menu-vertical a {
+ color: #404040 !important;
+}
+
+.wy-menu-vertical button.toctree-expand {
+ color: #333 !important;
+}
+
+.wy-nav-content {
+ max-width: unset;
+}
+
+.rst-content {
+ max-width: 900px;
+}
+
+.wy-nav-content .icon-home:before {
+ content: "Docs";
+}
+
+.wy-side-nav-search .icon-home:before {
+ content: "";
+}
+
+dl.field-list {
+ display: block !important;
+}
+
+dl.field-list > dt:after {
+ content: "" !important;
+}
+
+:root {
+ --dark-blue: #3260F7;
+ --light-blue: rgba(194, 233, 248, 0.1) ;
+}
+
+dl.field-list > dt {
+ display: table;
+ padding-left: 6px !important;
+ padding-right: 6px !important;
+ margin-bottom: 4px !important;
+ padding-bottom: 1px !important;
+ background: var(--light-blue);
+ border-left: solid 2px var(--dark-blue);
+}
+
+
+dl.py.class>dt
+{
+ color: rgba(17, 16, 17, 0.822) !important;
+ background: var(--light-blue) !important;
+ border-top: solid 2px var(--dark-blue) !important;
+}
+
+dl.py.method>dt
+{
+ background: var(--light-blue) !important;
+ border-left: solid 2px var(--dark-blue) !important;
+}
+
+dl.py.attribute>dt,
+dl.py.property>dt
+{
+ background: var(--light-blue) !important;
+ border-left: solid 2px var(--dark-blue) !important;
+}
+
+.fa-plus-square-o::before, .wy-menu-vertical li button.toctree-expand::before,
+.fa-minus-square-o::before, .wy-menu-vertical li.current > a button.toctree-expand::before, .wy-menu-vertical li.on a button.toctree-expand::before
+{
+ content: "";
+}
+
+.rst-content .viewcode-back,
+.rst-content .viewcode-link
+{
+ color:#58b5cc;
+ font-size: 120%;
+}
\ No newline at end of file
diff --git a/examples/BMTrain/docs/source/_static/js/custom.js b/examples/BMTrain/docs/source/_static/js/custom.js
new file mode 100644
index 00000000..489b7d5c
--- /dev/null
+++ b/examples/BMTrain/docs/source/_static/js/custom.js
@@ -0,0 +1,7 @@
+document.addEventListener("DOMContentLoaded", function(event) {
+    document.querySelectorAll(".wy-menu.wy-menu-vertical > ul.current > li > a").forEach(a => a.addEventListener("click", e => {
+        const f = document.querySelector(".wy-menu.wy-menu-vertical > ul.current > li > ul");
+        f.style.display = (f.style.display == 'none') ? 'block' : 'none';
+    }));
+    document.querySelectorAll(".headerlink").forEach(a => a.text = "\u{1F517}");
+});
\ No newline at end of file
diff --git a/examples/BMTrain/docs/source/api/bmtrain.benchmark.rst_bk b/examples/BMTrain/docs/source/api/bmtrain.benchmark.rst_bk
new file mode 100644
index 00000000..f8b2902d
--- /dev/null
+++ b/examples/BMTrain/docs/source/api/bmtrain.benchmark.rst_bk
@@ -0,0 +1,53 @@
+bmtrain.benchmark package
+=========================
+
+Submodules
+----------
+
+bmtrain.benchmark.all\_gather module
+------------------------------------
+
+.. automodule:: bmtrain.benchmark.all_gather
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.benchmark.reduce\_scatter module
+----------------------------------------
+
+.. automodule:: bmtrain.benchmark.reduce_scatter
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.benchmark.send\_recv module
+-----------------------------------
+
+.. automodule:: bmtrain.benchmark.send_recv
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.benchmark.shape module
+------------------------------
+
+.. automodule:: bmtrain.benchmark.shape
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.benchmark.utils module
+------------------------------
+
+.. automodule:: bmtrain.benchmark.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.benchmark
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source/api/bmtrain.distributed.rst_bk b/examples/BMTrain/docs/source/api/bmtrain.distributed.rst_bk
new file mode 100644
index 00000000..ef41db07
--- /dev/null
+++ b/examples/BMTrain/docs/source/api/bmtrain.distributed.rst_bk
@@ -0,0 +1,21 @@
+bmtrain.distributed package
+===========================
+
+Submodules
+----------
+
+bmtrain.distributed.ops module
+------------------------------
+
+.. automodule:: bmtrain.distributed.ops
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.distributed
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source/api/bmtrain.inspect.rst b/examples/BMTrain/docs/source/api/bmtrain.inspect.rst
new file mode 100644
index 00000000..c57195ad
--- /dev/null
+++ b/examples/BMTrain/docs/source/api/bmtrain.inspect.rst
@@ -0,0 +1,37 @@
+bmtrain.inspect package
+=======================
+
+Submodules
+----------
+
+bmtrain.inspect.format module
+-----------------------------
+
+.. automodule:: bmtrain.inspect.format
+ :members: format_summary
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.inspect.model module
+----------------------------
+
+.. automodule:: bmtrain.inspect.model
+ :members: inspect_model
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.inspect.tensor module
+-----------------------------
+
+.. automodule:: bmtrain.inspect.tensor
+ :members: inspect_tensor, InspectTensor
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.inspect
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source/api/bmtrain.loss.rst b/examples/BMTrain/docs/source/api/bmtrain.loss.rst
new file mode 100644
index 00000000..03b65646
--- /dev/null
+++ b/examples/BMTrain/docs/source/api/bmtrain.loss.rst
@@ -0,0 +1,21 @@
+bmtrain.loss package
+====================
+
+Submodules
+----------
+
+bmtrain.loss.cross\_entropy module
+----------------------------------
+
+.. automodule:: bmtrain.loss.cross_entropy
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.loss
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source/api/bmtrain.lr_scheduler.rst b/examples/BMTrain/docs/source/api/bmtrain.lr_scheduler.rst
new file mode 100644
index 00000000..0ba033af
--- /dev/null
+++ b/examples/BMTrain/docs/source/api/bmtrain.lr_scheduler.rst
@@ -0,0 +1,61 @@
+bmtrain.lr\_scheduler package
+=============================
+
+Submodules
+----------
+
+bmtrain.lr\_scheduler.cosine module
+-----------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.cosine
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.lr\_scheduler.exponential module
+----------------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.exponential
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.lr\_scheduler.linear module
+-----------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.linear
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.lr\_scheduler.no\_decay module
+--------------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.no_decay
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.lr\_scheduler.noam module
+---------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.noam
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.lr\_scheduler.warmup module
+-----------------------------------
+
+.. automodule:: bmtrain.lr_scheduler.warmup
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.lr_scheduler
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source/api/bmtrain.nccl.rst_bk b/examples/BMTrain/docs/source/api/bmtrain.nccl.rst_bk
new file mode 100644
index 00000000..3755d9ef
--- /dev/null
+++ b/examples/BMTrain/docs/source/api/bmtrain.nccl.rst_bk
@@ -0,0 +1,21 @@
+bmtrain.nccl package
+====================
+
+Submodules
+----------
+
+bmtrain.nccl.enums module
+-------------------------
+
+.. automodule:: bmtrain.nccl.enums
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.nccl
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source/api/bmtrain.nn.rst b/examples/BMTrain/docs/source/api/bmtrain.nn.rst
new file mode 100644
index 00000000..8e2a531f
--- /dev/null
+++ b/examples/BMTrain/docs/source/api/bmtrain.nn.rst
@@ -0,0 +1,53 @@
+bmtrain.nn package
+==================
+
+Submodules
+----------
+
+bmtrain.nn.column\_parallel\_linear module
+------------------------------------------
+
+.. automodule:: bmtrain.nn.column_parallel_linear
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.nn.linear module
+------------------------
+
+.. automodule:: bmtrain.nn.linear
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.nn.parallel\_embedding module
+-------------------------------------
+
+.. automodule:: bmtrain.nn.parallel_embedding
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.nn.parallel\_linear\_func module
+----------------------------------------
+
+.. automodule:: bmtrain.nn.parallel_linear_func
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.nn.row\_parallel\_linear module
+---------------------------------------
+
+.. automodule:: bmtrain.nn.row_parallel_linear
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.nn
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source/api/bmtrain.optim.rst b/examples/BMTrain/docs/source/api/bmtrain.optim.rst
new file mode 100644
index 00000000..2d47a3dd
--- /dev/null
+++ b/examples/BMTrain/docs/source/api/bmtrain.optim.rst
@@ -0,0 +1,37 @@
+bmtrain.optim package
+=====================
+
+Submodules
+----------
+
+bmtrain.optim.adam module
+-------------------------
+
+.. automodule:: bmtrain.optim.adam
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.optim.adam\_offload module
+----------------------------------
+
+.. automodule:: bmtrain.optim.adam_offload
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.optim.optim\_manager module
+-----------------------------------
+
+.. automodule:: bmtrain.optim.optim_manager
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain.optim
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source/api/bmtrain.rst b/examples/BMTrain/docs/source/api/bmtrain.rst
new file mode 100644
index 00000000..8445e5f0
--- /dev/null
+++ b/examples/BMTrain/docs/source/api/bmtrain.rst
@@ -0,0 +1,140 @@
+bmtrain package
+===============
+
+Subpackages
+-----------
+
+.. toctree::
+ :maxdepth: 4
+
+ bmtrain.benchmark
+ bmtrain.distributed
+ bmtrain.inspect
+ bmtrain.loss
+ bmtrain.lr_scheduler
+ bmtrain.nccl
+ bmtrain.nn
+ bmtrain.optim
+
+Submodules
+----------
+
+bmtrain.block\_layer module
+---------------------------
+
+.. automodule:: bmtrain.block_layer
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. bmtrain.debug module
+.. --------------------
+
+.. .. automodule:: bmtrain.debug
+.. :members:
+.. :undoc-members:
+.. :show-inheritance:
+
+bmtrain.global\_var module
+--------------------------
+
+.. automodule:: bmtrain.global_var
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. bmtrain.hook\_func module
+.. -------------------------
+
+.. .. automodule:: bmtrain.hook_func
+.. :members:
+.. :undoc-members:
+.. :show-inheritance:
+
+bmtrain.init module
+-------------------
+
+.. automodule:: bmtrain.init
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.layer module
+--------------------
+
+.. automodule:: bmtrain.layer
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.param\_init module
+--------------------------
+
+.. automodule:: bmtrain.param_init
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.parameter module
+------------------------
+
+.. automodule:: bmtrain.parameter
+ :members: DistributedParameter, ParameterInitializer
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.pipe\_layer module
+--------------------------
+
+.. automodule:: bmtrain.pipe_layer
+ :members: PipelineTransformerBlockList
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.store module
+--------------------
+
+.. automodule:: bmtrain.store
+ :members: save, load
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.synchronize module
+--------------------------
+
+.. automodule:: bmtrain.synchronize
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.utils module
+--------------------
+
+.. automodule:: bmtrain.utils
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.wrapper module
+----------------------
+
+.. automodule:: bmtrain.wrapper
+ :members: BMTrainModelWrapper
+ :undoc-members:
+ :show-inheritance:
+
+bmtrain.zero\_context module
+----------------------------
+
+.. automodule:: bmtrain.zero_context
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+Module contents
+---------------
+
+.. automodule:: bmtrain
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/examples/BMTrain/docs/source/api/modules.rst b/examples/BMTrain/docs/source/api/modules.rst
new file mode 100644
index 00000000..4350b5d7
--- /dev/null
+++ b/examples/BMTrain/docs/source/api/modules.rst
@@ -0,0 +1,7 @@
+bmtrain
+=======
+
+.. toctree::
+ :maxdepth: 4
+
+ bmtrain
diff --git a/examples/BMTrain/docs/source/conf.py b/examples/BMTrain/docs/source/conf.py
new file mode 100644
index 00000000..066680f7
--- /dev/null
+++ b/examples/BMTrain/docs/source/conf.py
@@ -0,0 +1,72 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+import os
+import sys
+sys.path.insert(0, os.path.abspath('../../..'))
+
+import recommonmark
+from recommonmark.transform import AutoStructify
+
+
+
+# -- Project information -----------------------------------------------------
+
+project = 'BMTrain'
+copyright = '2022, OpenBMB'
+author = 'BMTrain Team'
+autodoc_mock_imports = ["numpy", "tensorboard", "bmtrain.nccl._C", "bmtrain.optim._cpu", "bmtrain.optim._cuda", "bmtrain.loss._cuda"]
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.napoleon',
+ 'sphinx.ext.mathjax',
+ 'recommonmark',
+ 'sphinx_markdown_tables',
+]
+
+source_suffix = ['.rst', '.md']
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = 'zh_CN'
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = []
+
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_rtd_theme'
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
+#html_stype="css/custom.css"
+html_css_files = ['css/custom.css']
+html_js_files = ['js/custom.js']
diff --git a/examples/BMTrain/docs/source/index.rst b/examples/BMTrain/docs/source/index.rst
new file mode 100644
index 00000000..a2ec24f9
--- /dev/null
+++ b/examples/BMTrain/docs/source/index.rst
@@ -0,0 +1,39 @@
+.. bmtrain-doc documentation master file, created by
+ sphinx-quickstart on Sat Mar 5 17:05:02 2022.
+ You can adapt this file completely to your liking, but it should at least
+ contain the root `toctree` directive.
+
+BMTrain Documentation
+===============================
+
+**BMTrain** is an efficient toolkit for training large models with tens of billions of parameters. It keeps the training code simple while the model is trained in a distributed way.
+
+=======================================
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Getting Started
+
+ notes/installation.md
+ notes/quickstart.md
+ notes/tech.md
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Package Reference
+
+ api/bmtrain.rst
+ api/bmtrain.benchmark.rst
+ api/bmtrain.distributed.rst
+ api/bmtrain.inspect.rst
+ api/bmtrain.loss.rst
+ api/bmtrain.lr_scheduler.rst
+ api/bmtrain.nccl.rst
+ api/bmtrain.optim.rst
+
+API
+==================
+
+* :ref:`genindex`
+* :ref:`modindex`
+* :ref:`search`
diff --git a/examples/BMTrain/docs/source/notes/image/ZeRO3.png b/examples/BMTrain/docs/source/notes/image/ZeRO3.png
new file mode 100644
index 00000000..38d52ed5
Binary files /dev/null and b/examples/BMTrain/docs/source/notes/image/ZeRO3.png differ
diff --git a/examples/BMTrain/docs/source/notes/image/communication_example.png b/examples/BMTrain/docs/source/notes/image/communication_example.png
new file mode 100644
index 00000000..0e549b08
Binary files /dev/null and b/examples/BMTrain/docs/source/notes/image/communication_example.png differ
diff --git a/examples/BMTrain/docs/source/notes/image/communication_fig.png b/examples/BMTrain/docs/source/notes/image/communication_fig.png
new file mode 100644
index 00000000..b7636240
Binary files /dev/null and b/examples/BMTrain/docs/source/notes/image/communication_fig.png differ
diff --git a/examples/BMTrain/docs/source/notes/image/cpu.png b/examples/BMTrain/docs/source/notes/image/cpu.png
new file mode 100644
index 00000000..451f397c
Binary files /dev/null and b/examples/BMTrain/docs/source/notes/image/cpu.png differ
diff --git a/examples/BMTrain/docs/source/notes/image/zero3_example.png b/examples/BMTrain/docs/source/notes/image/zero3_example.png
new file mode 100644
index 00000000..40f0e8cc
Binary files /dev/null and b/examples/BMTrain/docs/source/notes/image/zero3_example.png differ
diff --git a/examples/BMTrain/docs/source/notes/installation.md b/examples/BMTrain/docs/source/notes/installation.md
new file mode 100644
index 00000000..3fff4eb6
--- /dev/null
+++ b/examples/BMTrain/docs/source/notes/installation.md
@@ -0,0 +1,45 @@
+# Installation
+
+## Installation Methods
+
+### 1. Install with pip (Recommended)
+
+```shell
+$ pip install bmtrain
+```
+
+### 2. Install from Source
+
+```shell
+$ git clone https://github.com/OpenBMB/BMTrain.git
+$ cd BMTrain
+$ python3 setup.py install
+```
+
+## Compilation Options
+
+By setting environment variables, you can control BMTrain's compilation options (by default, the build adapts to the compilation environment automatically):
+
+### AVX Instructions
+
+* Force the use of AVX instructions: `BMT_AVX256=ON`
+* Force the use of AVX512 instructions: `BMT_AVX512=ON`
+
+### CUDA Compute Capability
+
+`TORCH_CUDA_ARCH_LIST=6.0 6.1 7.0 7.5 8.0+PTX`
+
+## Recommended Configuration
+
+* Network: Infiniband 100Gbps / RoCE 100Gbps
+* GPU: NVIDIA Tesla V100 / NVIDIA Tesla A100 / RTX 3090
+* CPU: a CPU that supports the AVX512 instruction set, 32 cores or above
+* RAM: 256GB or above
+
+## FAQ
+
+If the following error message appears during compilation, try using a newer version of the gcc compiler.
+```
+error: invalid static_cast from type `const torch::OrderedDict<...>`
+```
+
diff --git a/examples/BMTrain/docs/source/notes/quickstart.md b/examples/BMTrain/docs/source/notes/quickstart.md
new file mode 100644
index 00000000..f139fd33
--- /dev/null
+++ b/examples/BMTrain/docs/source/notes/quickstart.md
@@ -0,0 +1,146 @@
+# Quick Start
+
+## Step 1: Initialize BMTrain
+
+To use BMTrain, import the `bmtrain` package and call `bmtrain.init_distributed` at the beginning of your code.
+
+```python
+import bmtrain as bmt
+bmt.init_distributed(
+ seed=0,
+ # ...
+)
+```
+
+**NOTE:** Do not use PyTorch's own `distributed` module when using BMTrain, including `torch.distributed.init_process_group` and the related communication functions.
+
+## Step 2: Enable ZeRO-3 Optimization
+
+Enabling ZeRO-3 optimization only requires a few simple replacements in the model code:
+
+* `torch.nn.Module` -> `bmtrain.DistributedModule`
+* `torch.nn.Parameter` -> `bmtrain.DistributedParameter`
+
+and apply `Checkpointing` to the appropriate modules.
+
+**Original code:**
+
+```python
+import torch
+class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ SomeTransformerBlock(),
+ SomeTransformerBlock(),
+ SomeTransformerBlock()
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**Replaced code:**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+## Step 3: Communication Optimization
+
+To further reduce the communication overhead and overlap communication with computation, `TransformerBlockList` can be used for further optimization.
+Enabling it only requires a few simple replacements in the code:
+
+* `torch.nn.ModuleList` -> `bmtrain.TransformerBlockList`
+* `for module in self.module_list: x = module(x, ...)` -> `x = self.module_list(x, ...)`
+
+**Original code:**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**Replaced code:**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = bmt.TransformerBlockList([
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ x = self.module_list(x, 1, 2, 3)
+ return x
+
+```
+
+## Step 4: Launch Distributed Training
+
+BMTrain works with PyTorch's native distributed training launchers and requires no extra arguments:
+
+### torch.distributed.launch
+```shell
+$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py
+```
+
+### torchrun
+
+```shell
+$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py
+```
+
+For more information, please refer to the official PyTorch documentation: [Launch utility](https://pytorch.org/docs/stable/distributed.html#launch-utility)
+
+## Other Notes
+
+The `BMTrain` toolkit makes low-level modifications to PyTorch; if your program produces unexpected results, please report the details in an issue.
+
+For more examples, please refer to the *examples* folder.
+
diff --git a/examples/BMTrain/docs/source/notes/tech.md b/examples/BMTrain/docs/source/notes/tech.md
new file mode 100644
index 00000000..9104fde3
--- /dev/null
+++ b/examples/BMTrain/docs/source/notes/tech.md
@@ -0,0 +1,11 @@
+# Introduction to Core Technology
+
+## ZeRO-3 Optimization
+
+
+## Overlap Communication and Computation
+
+
+## CPU Offload
+
+
diff --git a/examples/BMTrain/example/README.md b/examples/BMTrain/example/README.md
new file mode 100644
index 00000000..395b5e64
--- /dev/null
+++ b/examples/BMTrain/example/README.md
@@ -0,0 +1,5 @@
+# Example
+
+This is an example implementation of GPT-2 built with BMTrain.
+
+For more model implementations, please refer to [Model Center](https://github.com/OpenBMB/ModelCenter).
\ No newline at end of file
diff --git a/examples/BMTrain/example/benchmark.py b/examples/BMTrain/example/benchmark.py
new file mode 100644
index 00000000..8a7092d9
--- /dev/null
+++ b/examples/BMTrain/example/benchmark.py
@@ -0,0 +1,12 @@
+import bmtrain as bmt
+
+def main():
+ bmt.init_distributed()
+ bmt.print_rank("======= All Gather =======")
+ bmt.benchmark.all_gather()
+ bmt.print_rank("===== Reduce Scatter =====")
+ bmt.benchmark.reduce_scatter()
+
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/examples/BMTrain/example/layers/__init__.py b/examples/BMTrain/example/layers/__init__.py
new file mode 100644
index 00000000..ef4617c0
--- /dev/null
+++ b/examples/BMTrain/example/layers/__init__.py
@@ -0,0 +1,5 @@
+from .embedding import Embedding
+from .feedforward import Feedforward
+from .layernorm import Layernorm
+from .attention import Attention
+from .transformer import TransformerEncoder
diff --git a/examples/BMTrain/example/layers/attention.py b/examples/BMTrain/example/layers/attention.py
new file mode 100644
index 00000000..0f5155d4
--- /dev/null
+++ b/examples/BMTrain/example/layers/attention.py
@@ -0,0 +1,118 @@
+from typing import Optional
+import torch
+import bmtrain as bmt
+from bmtrain.nn import (
+ Linear,
+ ColumnParallelLinear,
+ RowParallelLinear,
+)
+import math
+from bmtrain.global_var import config
+from bmtrain.distributed import all_gather
+
+class Attention(bmt.DistributedModule):
+ def __init__(self,
+ dim_model : int, dim_head : int,
+ num_heads : int, bias : bool = True,
+ dtype = None
+ ) -> None:
+ super().__init__()
+
+ if config['tp_size'] > 1:
+ self.project_q = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
+ self.project_k = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
+ self.project_v = ColumnParallelLinear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype, gather_input=False)
+ self.project_out = RowParallelLinear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
+ else:
+ self.project_q = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+ self.project_k = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+ self.project_v = Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+ self.project_out = Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
+
+
+ self.softmax = torch.nn.Softmax(dim=-1)
+ self.num_heads = num_heads
+ self.dim_head = dim_head
+ self.dim_model = dim_model
+
+ def forward(self,
+ hidden_q : torch.Tensor, # (batch_size, seq_q, dim_model)
+ hidden_kv : torch.Tensor, # (batch_size, seq_kv, dim_model)
+ mask : torch.BoolTensor, # (batch_size, seq_q, seq_kv)
+ position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv)
+ ) -> torch.Tensor:
+ batch_size = hidden_q.size()[0]
+
+ assert hidden_q.data_ptr() == hidden_kv.data_ptr()
+
+ if config['tp_size'] > 1:
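+            # fuse the q, k and v projections into one tensor-parallel matmul by concatenating their weights and biases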
+ hidden_q = bmt.nn.OpParallelLinear.apply(
+ hidden_q,
+ torch.cat([self.project_q.weight, self.project_k.weight, self.project_v.weight], dim=0),
+ torch.cat([self.project_q.bias, self.project_k.bias, self.project_v.bias], dim=0),
+ True, False,
+ False, None
+ )
+ hidden_q = hidden_q.view(batch_size, -1, hidden_q.shape[-1])
+ h_q, h_k, h_v = hidden_q.chunk(3, dim=-1)
+ else:
+ h_q : torch.Tensor = self.project_q(hidden_q)
+ h_k : torch.Tensor = self.project_k(hidden_kv)
+ h_v : torch.Tensor = self.project_v(hidden_kv)
+
+ seq_q = h_q.size()[1]
+ seq_kv = h_k.size(1)
+
+ h_q = h_q.view(batch_size, seq_q, -1, self.dim_head)
+ h_k = h_k.view(batch_size, seq_kv, -1, self.dim_head)
+ h_v = h_v.view(batch_size, seq_kv, -1, self.dim_head)
+
+ h_q = h_q.permute(0, 2, 1, 3).contiguous()
+ h_k = h_k.permute(0, 2, 1, 3).contiguous()
+ h_v = h_v.permute(0, 2, 1, 3).contiguous()
+
+ h_q = h_q.view(-1, seq_q, self.dim_head)
+ h_k = h_k.view(-1, seq_kv, self.dim_head)
+ h_v = h_v.view(-1, seq_kv, self.dim_head)
+
+ score = torch.bmm(
+ h_q, h_k.transpose(1, 2)
+ )
+ score = score / math.sqrt(self.dim_head)
+
+ score = score.view(batch_size, -1, seq_q, seq_kv)
+
+ if position_bias is not None:
+ score = score + position_bias.view(batch_size, -1, seq_q, seq_kv)
+
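+        # masked softmax: masked positions are set to -inf before the softmax and zeroed afterwards, so fully masked rows do not yield NaN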
+ score = torch.where(
+ mask.view(batch_size, 1, seq_q, seq_kv),
+ score,
+ torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype)
+ )
+
+ score = torch.where(
+ mask.view(batch_size, 1, seq_q, seq_kv),
+ self.softmax(score),
+ torch.scalar_tensor(0, device=score.device, dtype=score.dtype)
+ )
+
+ score = score.view(-1, seq_q, seq_kv)
+
+ h_out = torch.bmm(
+ score, h_v
+ )
+ h_out = h_out.view(batch_size, -1, seq_q, self.dim_head)
+ h_out = h_out.permute(0, 2, 1, 3).contiguous()
+ h_out = h_out.view(batch_size, seq_q, -1)
+ if config['tp_size'] > 1:
+ h_out = h_out.view(h_out.shape[0] * bmt.config["tp_size"], -1, h_out.shape[-1])
+
+ attn_out = self.project_out(h_out)
+
+ return attn_out
+
+
+
+
+
diff --git a/examples/BMTrain/example/layers/embedding.py b/examples/BMTrain/example/layers/embedding.py
new file mode 100644
index 00000000..f62151c4
--- /dev/null
+++ b/examples/BMTrain/example/layers/embedding.py
@@ -0,0 +1,102 @@
+import math
+from typing import Optional
+import torch
+import torch.nn.functional as F
+import bmtrain as bmt
+
+
+class Embedding(bmt.DistributedModule):
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
+ max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
+ sparse: bool = False, _weight: Optional[torch.Tensor] = None,
+ dtype=None):
+ super().__init__()
+
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ if padding_idx is not None:
+ if padding_idx > 0:
+ assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
+ elif padding_idx < 0:
+ assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
+ padding_idx = self.num_embeddings + padding_idx
+ self.padding_idx = padding_idx
+ self.max_norm = max_norm
+ self.norm_type = norm_type
+ self.scale_grad_by_freq = scale_grad_by_freq
+ if _weight is None:
+ self.weight = bmt.DistributedParameter(torch.empty(num_embeddings, embedding_dim, dtype=dtype, device="cuda"), init_method=torch.nn.init.normal_)
+ else:
+ self.weight = bmt.DistributedParameter(_weight)
+
+ self.sparse = sparse
+
+ @classmethod
+ def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
+ max_norm=None, norm_type=2., scale_grad_by_freq=False,
+ sparse=False):
+ r"""Creates Embedding instance from given 2-dimensional FloatTensor.
+
+ Args:
+ embeddings (Tensor): FloatTensor containing weights for the Embedding.
+ First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
+ freeze (boolean, optional): If ``True``, the tensor does not get updated in the learning process.
+ Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
+ padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
+ therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
+ i.e. it remains as a fixed "pad".
+ max_norm (float, optional): See module initialization documentation.
+ norm_type (float, optional): See module initialization documentation. Default ``2``.
+ scale_grad_by_freq (boolean, optional): See module initialization documentation. Default ``False``.
+ sparse (bool, optional): See module initialization documentation.
+
+ Examples::
+
+ >>> # FloatTensor containing pretrained weights
+ >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
+ >>> embedding = nn.Embedding.from_pretrained(weight)
+ >>> # Get embeddings for index 1
+ >>> input = torch.LongTensor([1])
+ >>> embedding(input)
+ tensor([[ 4.0000, 5.1000, 6.3000]])
+ """
+ assert embeddings.dim() == 2, \
+ 'Embeddings parameter is expected to be 2-dimensional'
+ rows, cols = embeddings.shape
+ embedding = cls(
+ num_embeddings=rows,
+ embedding_dim=cols,
+ _weight=embeddings,
+ padding_idx=padding_idx,
+ max_norm=max_norm,
+ norm_type=norm_type,
+ scale_grad_by_freq=scale_grad_by_freq,
+ sparse=sparse)
+ embedding.weight.requires_grad = not freeze
+ return embedding
+
+ def forward(self, input: torch.Tensor, projection : bool = False) -> torch.Tensor:
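+        # projection=False performs a normal embedding lookup; projection=True reuses the embedding weight as a tied output projection (logits over the vocabulary)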
+ if not projection:
+ out = F.embedding(
+ input, self.weight, self.padding_idx, self.max_norm,
+ self.norm_type, self.scale_grad_by_freq, self.sparse)
+ return out
+ else:
+ out = F.linear(input, self.weight)
+ return out
+
+ def extra_repr(self) -> str:
+ s = '{num_embeddings}, {embedding_dim}'
+ if self.padding_idx is not None:
+ s += ', padding_idx={padding_idx}'
+ if self.max_norm is not None:
+ s += ', max_norm={max_norm}'
+ if self.norm_type != 2:
+ s += ', norm_type={norm_type}'
+ if self.scale_grad_by_freq is not False:
+ s += ', scale_grad_by_freq={scale_grad_by_freq}'
+ if self.sparse is not False:
+ s += ', sparse=True'
+ return s.format(**self.__dict__)
+
+
diff --git a/examples/BMTrain/example/layers/feedforward.py b/examples/BMTrain/example/layers/feedforward.py
new file mode 100644
index 00000000..e88d2495
--- /dev/null
+++ b/examples/BMTrain/example/layers/feedforward.py
@@ -0,0 +1,23 @@
+import torch
+import bmtrain as bmt
+from bmtrain.nn import (
+ Linear,
+ ColumnParallelLinear,
+ RowParallelLinear)
+from bmtrain.global_var import config
+
+class Feedforward(bmt.DistributedModule):
+ def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None:
+ super().__init__()
+
+ if config['tp_size'] > 1:
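+            # under tensor parallelism, split the feed-forward layer: a column-parallel up-projection followed by a row-parallel down-projection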
+ self.w_in = ColumnParallelLinear(dim_model, dim_ff, bias = bias, dtype=dtype)
+ self.w_out = RowParallelLinear(dim_ff, dim_model, bias = bias, dtype=dtype)
+ else:
+ self.w_in = Linear(dim_model, dim_ff, bias=bias, dtype=dtype)
+ self.w_out = Linear(dim_ff, dim_model, bias=bias, dtype=dtype)
+
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, input : torch.Tensor) -> torch.Tensor:
+ return self.w_out(self.relu(self.w_in(input)))
diff --git a/examples/BMTrain/example/layers/layernorm.py b/examples/BMTrain/example/layers/layernorm.py
new file mode 100644
index 00000000..9f3e3bc2
--- /dev/null
+++ b/examples/BMTrain/example/layers/layernorm.py
@@ -0,0 +1,34 @@
+from typing import Tuple
+import torch
+import torch.nn.functional as F
+import bmtrain as bmt
+
+class Layernorm(bmt.DistributedModule):
+ __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+ normalized_shape: Tuple[int, ...]
+ eps: float
+ elementwise_affine: bool
+
+ def __init__(self, normalized_shape, eps: float = 1e-5, elementwise_affine: bool = True,
+ dtype=None) -> None:
+ super().__init__()
+ if isinstance(normalized_shape, int):
+ # mypy error: incompatible types in assignment
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+ if self.elementwise_affine:
+ self.weight = bmt.DistributedParameter(torch.empty(self.normalized_shape, dtype=dtype, device="cuda"), init_method=torch.nn.init.ones_)
+ self.bias = bmt.DistributedParameter(torch.empty(self.normalized_shape, dtype=dtype, device="cuda"), init_method=torch.nn.init.zeros_)
+ else:
+ self.register_parameter('weight', None)
+ self.register_parameter('bias', None)
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return F.layer_norm(
+ input, self.normalized_shape, self.weight, self.bias, self.eps)
+
+ def extra_repr(self) -> str:
+ return '{normalized_shape}, eps={eps}, ' \
+ 'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
diff --git a/examples/BMTrain/example/layers/transformer.py b/examples/BMTrain/example/layers/transformer.py
new file mode 100644
index 00000000..4cbff59b
--- /dev/null
+++ b/examples/BMTrain/example/layers/transformer.py
@@ -0,0 +1,34 @@
+from typing import Optional
+import torch
+import bmtrain as bmt
+from layers import Layernorm, Feedforward, Attention
+
+class TransformerEncoder(bmt.DistributedModule):
+ def __init__(self,
+ dim_model : int, dim_head : int, num_heads : int, dim_ff : int,
+ bias : bool = True, dtype = None
+ ) -> None:
+ super().__init__()
+
+ self.ln_attn = Layernorm(dim_model, dtype=dtype)
+ self.attn = Attention(dim_model, dim_head, num_heads, bias=bias, dtype=dtype)
+
+ self.ln_ff = Layernorm(dim_model, dtype=dtype)
+ self.ff = Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype)
+
+ def forward(self,
+ hidden : torch.Tensor, # (batch, seq_len, dim_model)
+            mask : torch.BoolTensor,        # (batch, seq_len, seq_len)
+ position_bias : Optional[torch.Tensor] = None, # (batch, num_head, seq_len, seq_len)
+ ):
+ bmt.inspect.record_tensor(hidden, "hidden")
+ x = self.ln_attn(hidden)
+ x = self.attn(x, x, mask, position_bias)
+ hidden = hidden + x
+
+ x = self.ln_ff(hidden)
+ x = self.ff(x)
+ hidden = hidden + x.view_as(hidden)
+
+ return hidden
+
diff --git a/examples/BMTrain/example/models/__init__.py b/examples/BMTrain/example/models/__init__.py
new file mode 100644
index 00000000..e7d1dcc9
--- /dev/null
+++ b/examples/BMTrain/example/models/__init__.py
@@ -0,0 +1 @@
+from .gpt import GPT
\ No newline at end of file
diff --git a/examples/BMTrain/example/models/gpt.py b/examples/BMTrain/example/models/gpt.py
new file mode 100644
index 00000000..ed604382
--- /dev/null
+++ b/examples/BMTrain/example/models/gpt.py
@@ -0,0 +1,64 @@
+import torch
+import bmtrain as bmt
+from layers import TransformerEncoder, Layernorm, Embedding
+from bmtrain.global_var import config
+
+class GPT(bmt.DistributedModule):
+ def __init__(self,
+ num_layers : int, vocab_size : int,
+ dim_model : int, dim_head : int, num_heads : int, dim_ff : int,
+ max_distance : int,
+ bias : bool = True, dtype = None
+ ) -> None:
+ super().__init__()
+
+ self.max_distance = max_distance
+
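+        # Under tensor parallelism the word embedding is replaced by bmt.nn.VPEmbedding,
+        # presumably a vocab-parallel embedding that shards the vocabulary across tp ranks.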
+ if config["tp_size"] > 1:
+ self.word_emb = bmt.nn.VPEmbedding(vocab_size, dim_model, dtype=dtype)
+ else:
+ self.word_emb = Embedding(vocab_size, dim_model, dtype=dtype)
+ self.pos_emb = Embedding(max_distance, dim_model, dtype=dtype)
+
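+        # With pipeline parallelism the encoder layers go into a PipelineTransformerBlockList
+        # (mode="PIPE"); otherwise a plain TransformerBlockList is used.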
+ if config['pipe_size'] > 1:
+ self.transformers = bmt.PipelineTransformerBlockList([
+ bmt.Block(
+ TransformerEncoder(
+ dim_model, dim_head, num_heads, dim_ff, bias, dtype
+ )
+ , mode="PIPE"
+ )
+ for _ in range(num_layers)
+ ])
+ else:
+ self.transformers = bmt.TransformerBlockList([
+ bmt.Block(
+ TransformerEncoder(
+ dim_model, dim_head, num_heads, dim_ff, bias, dtype
+ )
+ )
+ for _ in range(num_layers)
+ ])
+
+ self.layernorm = Layernorm(dim_model, dtype=dtype)
+
+ def forward(self,
+ input : torch.LongTensor, # (batch, seq_len)
+ pos : torch.LongTensor, # (batch, seq_len)
+ mask : torch.BoolTensor, # (batch, seq_len)
+ ) -> torch.Tensor:
+
+ mask_2d = mask[:, None, :] & mask[:, :, None] # (batch, seq_len, seq_len)
+ mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None])
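+        # Under tensor parallelism each rank embeds only its chunk of the sequence;
+        # the full attention mask is still built above from the complete pos/mask.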
+ if config["tp_size"] > 1:
+ input = input.chunk(config["tp_size"], dim=1)[config["tp_rank"]]
+ pos = pos.chunk(config["tp_size"], dim=1)[config["tp_rank"]]
+ out = self.pos_emb(pos) + self.word_emb(input)
+
+        # the block list applies all encoder layers in a single call (no explicit Python loop needed)
+ out = self.transformers(out, mask_2d, None)
+ out = self.layernorm(out)
+ logits = self.word_emb(out, projection=True)
+ bmt.inspect.record_tensor(logits, "logits")
+
+ return logits
diff --git a/examples/BMTrain/example/run.sh b/examples/BMTrain/example/run.sh
new file mode 100644
index 00000000..542e5252
--- /dev/null
+++ b/examples/BMTrain/example/run.sh
@@ -0,0 +1,3 @@
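+# NCCL_P2P_DISABLE=1 turns off peer-to-peer GPU transfers; CUDA_LAUNCH_BLOCKING=1 makes
+# CUDA kernel launches synchronous (handy for debugging, but it slows training).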
+export NCCL_P2P_DISABLE=1
+export CUDA_LAUNCH_BLOCKING=1
+torchrun --nnodes=1 --nproc_per_node=4 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost train.py
diff --git a/examples/BMTrain/example/sbatch.sh b/examples/BMTrain/example/sbatch.sh
new file mode 100644
index 00000000..3b93d99b
--- /dev/null
+++ b/examples/BMTrain/example/sbatch.sh
@@ -0,0 +1,20 @@
+#!/bin/bash
+#SBATCH --job-name=cpm2-test
+#SBATCH --partition=rtx2080
+
+#SBATCH --nodes=1
+#SBATCH --gpus-per-node=8
+
+
+MASTER_PORT=30123
+MASTER_HOST=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+
+# load the python virtualenv if you have one
+# source /path/to/python/virtualenv/bin/activate
+
+# uncomment to print nccl debug info
+# export NCCL_DEBUG=info
+
+srun torchrun --nnodes=$SLURM_JOB_NUM_NODES --nproc_per_node=$SLURM_GPUS_PER_NODE --rdzv_id=$SLURM_JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$MASTER_HOST:$MASTER_PORT train.py
+
+
diff --git a/examples/BMTrain/example/train.py b/examples/BMTrain/example/train.py
new file mode 100644
index 00000000..d5906a06
--- /dev/null
+++ b/examples/BMTrain/example/train.py
@@ -0,0 +1,138 @@
+import torch
+import bmtrain as bmt
+from models import GPT
+import time
+from bmtrain import optim
+from bmtrain.global_var import config
+from bmtrain import inspect
+
+def main():
+ bmt.init_distributed(
+ seed=0,
+ tp_size=2,
+ )
+
+ model = GPT(
+ num_layers=8,
+ vocab_size=10240,
+ dim_model=2560,
+ dim_head=80,
+ num_heads=32,
+ dim_ff=8192,
+ max_distance=1024,
+ bias=True,
+ dtype=torch.half
+ )
+
+ bmt.init_parameters(model)
+
+ bmt.print_rank("Model memory")
+ bmt.print_rank(torch.cuda.memory_summary())
+ bmt.synchronize()
+
+ # data
+ # generate dummy data for each rank
+ torch.manual_seed(1234)
+
+ batch_size = 2
+ seq_len = 512
+ world_size = bmt.config["world_size"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_size"]
+ r = bmt.config["rank"] if bmt.config["tp_size"] == 1 else bmt.config["tp_zero_rank"]
+
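+    # keep sampling until i == r so every rank ends up with a different dummy batch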
+ for i in range(world_size):
+ sent = torch.randint(0, 10240, (batch_size, seq_len + 1))
+ enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda()
+ enc_input = sent[:, :-1].long().cuda()
+ targets = sent[:, 1:].long().cuda()
+ mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None]
+ targets = torch.where(
+ mask,
+ targets,
+ torch.full_like(targets, -100, dtype=torch.long)
+ )
+
+ if i == r:
+ break
+
+ if config['tp_size'] > 1:
+ loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True)
+ else:
+ loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100)
+
+ optimizer = optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2)
+ lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0)
+
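+    # the large loss scale guards fp16 gradients against underflow; backward() and step()
+    # go through the OptimManager so the same scale is applied and removed consistently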
+ optim_manager = optim.OptimManager(loss_scale=2**20)
+ optim_manager.add_optimizer(optimizer, lr_scheduler)
+
+ bmt.synchronize()
+
+ avg_time_recorder = bmt.utils.AverageRecorder()
+ avg_loss_recorder = bmt.utils.AverageRecorder()
+
+ for iteration in range(1000):
+ # load data
+ st = time.time()
+
+ with inspect.inspect_tensor() as inspector:
+ pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
+ logits = model(
+ enc_input,
+ pos,
+ pos < enc_length[:, None]
+ )
+ batch, seq_len, vocab_out_size = logits.size()
+
+ if config['tp_size'] > 1:
+ loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
+ else:
+ loss = loss_func(logits.float().view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
+
+ global_loss = bmt.sum_loss(loss).item()
+
+ optim_manager.zero_grad()
+
+ optim_manager.backward(loss)
+
+ # print inspected tensors in the forward & backward pass
+ # print parameters of the model
+ if iteration % 100 == 0:
+ bmt.print_rank(
+ inspect.format_summary(
+ inspector.get_summary()
+ )
+ )
+ bmt.print_rank(
+ inspect.format_summary(
+ inspect.inspect_model(model, "*")
+ )
+ )
+
+ optim_manager.step()
+
+ # record time and loss
+ iteration_time = time.time() - st
+
+ avg_time_recorder.record(iteration_time)
+ avg_loss_recorder.record(global_loss)
+
+ # print time and loss
+ bmt.print_rank(
+ "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} scale: {:10.4f} | time: {:.4f}".format(
+ iteration,
+ global_loss,
+ avg_loss_recorder.value,
+ lr_scheduler.current_lr,
+ optim_manager.loss_scale,
+ avg_time_recorder.value
+ )
+ )
+
+ # save model
+ if iteration % 1000 == 0:
+ bmt.save(model, "ckpt-%d.pt" % iteration)
+
+ bmt.save(model, "checkpoint.pt")
+
+if __name__ == '__main__':
+ main()
diff --git a/examples/BMTrain/other_requirements.txt b/examples/BMTrain/other_requirements.txt
new file mode 100644
index 00000000..6654b1ac
--- /dev/null
+++ b/examples/BMTrain/other_requirements.txt
@@ -0,0 +1,6 @@
+tqdm
+cpm_kernels>=1.0.11
+jieba
+tensorboard
+setuptools_rust
+transformers
\ No newline at end of file
diff --git a/examples/BMTrain/pyproject.toml b/examples/BMTrain/pyproject.toml
new file mode 100644
index 00000000..b563eb32
--- /dev/null
+++ b/examples/BMTrain/pyproject.toml
@@ -0,0 +1,8 @@
+[build-system]
+requires = [
+ "setuptools",
+ "pybind11",
+ "nvidia-nccl-cu11 >= 2.14.3",
+ "cmake > 3.27.0"
+]
+build-backend = "setuptools.build_meta"
diff --git a/examples/BMTrain/setup.py b/examples/BMTrain/setup.py
new file mode 100644
index 00000000..70752ff6
--- /dev/null
+++ b/examples/BMTrain/setup.py
@@ -0,0 +1,113 @@
+import os
+import shutil
+from setuptools.command.build_ext import build_ext
+from setuptools import setup, find_packages, Extension
+import setuptools
+import warnings
+import sys
+import subprocess
+class CMakeExtension(Extension):
+ def __init__(self, name, sourcedir=""):
+ Extension.__init__(self, name, sources=[])
+ self.sourcedir = os.path.abspath(sourcedir)
+
+
+def is_ninja_available():
+ r'''
+    Returns ``True`` if the `ninja <https://ninja-build.org>`_ build system is
+ available on the system, ``False`` otherwise.
+ '''
+ with open(os.devnull, 'wb') as devnull:
+ try:
+ subprocess.check_call('ninja --version'.split(), stdout=devnull)
+ except OSError:
+ return False
+ else:
+ return True
+
+
+class CMakeBuild(build_ext):
+
+ def build_extension(self, ext):
+ extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name)))
+
+ # required for auto-detection & inclusion of auxiliary "native" libs
+ if not extdir.endswith(os.path.sep):
+ extdir += os.path.sep
+
+ debug = int(os.environ.get("DEBUG", 0)) if self.debug is None else self.debug
+ cfg = "Debug" if debug else "Release"
+
+ # CMake lets you override the generator - we need to check this.
+ # Can be set with Conda-Build, for example.
+ cmake_generator = os.environ.get("CMAKE_GENERATOR", "")
+
+ # Set Python_EXECUTABLE instead if you use PYBIND11_FINDPYTHON
+ # EXAMPLE_VERSION_INFO shows you how to pass a value into the C++ code
+ # from Python.
+ cmake_args = [
+ f"-DCMAKE_CXX_STANDARD=14",
+ f"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={extdir}",
+ f"-DPYTHON_EXECUTABLE={sys.executable}",
+ f"-DPYTHON_VERSION={sys.version_info.major}.{sys.version_info.minor}",
+ f"-DCMAKE_BUILD_TYPE={cfg}", # not used on MSVC, but no harm
+ ]
+
+ build_args = []
+ if "CMAKE_ARGS" in os.environ:
+ cmake_args += [item for item in os.environ["CMAKE_ARGS"].split(" ") if item]
+
+ cmake_args += [f"-DEXAMPLE_VERSION_INFO={self.distribution.get_version()}"]
+
+
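+        # prefer the Ninja generator when the ninja package is importable;
+        # otherwise fall back to CMake's default generator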
+ if not cmake_generator or cmake_generator == "Ninja":
+ try:
+ import ninja # noqa: F401
+
+ ninja_executable_path = os.path.join(ninja.BIN_DIR, "ninja")
+ cmake_args += [
+ "-GNinja",
+ f"-DCMAKE_MAKE_PROGRAM:FILEPATH={ninja_executable_path}",
+ ]
+ except ImportError:
+ pass
+
+ if "CMAKE_BUILD_PARALLEL_LEVEL" not in os.environ:
+ # self.parallel is a Python 3 only way to set parallel jobs by hand
+ # using -j in the build_ext call, not supported by pip or PyPA-build.
+ if hasattr(self, "parallel") and self.parallel:
+ # CMake 3.12+ only.
+ build_args += [f"-j{self.parallel}"]
+
+ build_temp = os.path.join(self.build_temp, ext.name)
+ if os.path.exists(build_temp):
+ shutil.rmtree(build_temp)
+ os.makedirs(build_temp)
+
+ cmake_args += ["-DPython_ROOT_DIR=" + os.path.dirname(os.path.dirname(sys.executable))]
+ subprocess.check_call(["cmake", ext.sourcedir] + cmake_args, cwd=build_temp)
+ subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=build_temp)
+
+ext_modules = [
+ CMakeExtension("bmtrain.C"),
+]
+setup(
+ name='bmtrain',
+ version='1.0.1',
+ author="Guoyang Zeng",
+ author_email="qbjooo@qq.com",
+ description="A toolkit for training big models",
+ packages=find_packages(),
+ install_requires=[
+ "numpy",
+ "nvidia-nccl-cu11>=2.14.3"
+ ],
+ setup_requires=[
+ "pybind11",
+ "nvidia-nccl-cu11>=2.14.3"
+ ],
+ ext_modules=ext_modules,
+ cmdclass={
+ 'build_ext': CMakeBuild
+ })
+
diff --git a/examples/BMTrain/tests/test_all.py b/examples/BMTrain/tests/test_all.py
new file mode 100644
index 00000000..db5d2dd4
--- /dev/null
+++ b/examples/BMTrain/tests/test_all.py
@@ -0,0 +1,43 @@
+import subprocess
+from tqdm import tqdm
+
+
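+# each entry is (test module suffix, number of GPUs to launch it with)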
+tq = tqdm([
+ ("different_output_shape", 1),
+ ("load_ckpt", 1),
+ ("init_parameters", 1),
+ ("synchronize", 4),
+ ("init_parameters_multi_gpu", 4),
+ ("optim_state", 4),
+
+ ("requires_grad", 1),
+ ("requires_grad_multi_gpu", 2),
+ ("has_inf_nan", 1),
+ ("dropout", 1),
+ ("loss_func", 1),
+
+ ("optim", 1),
+
+ ("multi_return", 2),
+ ("middle_hidden", 4),
+ ("other_hidden", 4),
+
+ ("model_wrapper", 4),
+
+ ("send_recv", 4),
+ ("nccl_backward", 4),
+ ("no_grad", 1),
+ ("column_parallel_linear", 2),
+ ("row_parallel_linear", 2),
+ ("parallel_projection", 4),
+
+ ("training", 4),
+])
+
+for t, num_gpu in tq:
+ PREFIX = f"python3 -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node={num_gpu} --master_addr=localhost --master_port=32123"
+ SUFFIX = f"> test_log.txt 2>&1"
+ command = f"{PREFIX} test_{t}.py {SUFFIX}"
+ completedProc = subprocess.run(command, shell=True)
+ assert completedProc.returncode == 0, f"test {t} failed, see test_log.txt for more details."
+ print(f"finish testing {t}")
diff --git a/examples/BMTrain/tests/test_column_parallel_linear.py b/examples/BMTrain/tests/test_column_parallel_linear.py
new file mode 100644
index 00000000..5f2fdad1
--- /dev/null
+++ b/examples/BMTrain/tests/test_column_parallel_linear.py
@@ -0,0 +1,74 @@
+import torch
+import bmtrain as bmt
+from bmtrain.global_var import config
+import numpy as np
+
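+# gather_input: the input is sharded along dim 0 across tp ranks and gathered inside the layer;
+# gather_output: the column-sharded outputs are gathered back to the full hidden size
+# (semantics inferred from how this test slices inputs and outputs)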
+def run_bmt(x, gather_input, gather_output, ckp_path, tp_size=2):
+    linear = bmt.nn.ColumnParallelLinear(8, 8, gather_input=gather_input, gather_output=gather_output)
+ linear = bmt.Block(linear)
+ bmt.init_parameters(linear)
+ y = linear(x)
+ y.sum().backward()
+ bmt.save(linear, ckp_path)
+ bmt.synchronize()
+ return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad
+
+def run_torch(x, ckp_path):
+ linear = torch.nn.Linear(8, 8)
+ linear_dict = torch.load(ckp_path)
+ linear.load_state_dict(linear_dict)
+ linear = linear.cuda()
+ linear.weight.requires_grad_()
+ y = linear(x)
+ y.sum().backward()
+ return y, linear.weight.grad, linear.bias.grad
+
+def run(gather_input, gather_output, ckp_path):
+ torch.cuda.manual_seed(100)
+ tp_size = config["tp_size"]
+ tp_rank = config['topology'].tp_id
+ x = torch.randn(8, 8, 8, device='cuda')
+ bmt_x = x.clone()
+ if gather_input:
+ rank_x = bmt_x.chunk(tp_size, dim=0)[tp_rank]
+ else:
+ rank_x = bmt_x
+ rank_x.requires_grad_()
+ x.requires_grad_()
+ y1, weight_grad1, bias_grad1 = run_bmt(rank_x, gather_input, gather_output, ckp_path)
+ y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path)
+ tp_rank = config['topology'].tp_id
+ if gather_output:
+ assert np.allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy())
+ else:
+ torch_out_list = torch.split(y2, y2.size()[-1] // tp_size, dim=-1)
+ assert np.allclose(y1.detach().cpu().numpy(), torch_out_list[tp_rank].detach().cpu().numpy())
+
+ weight_grad_list = weight_grad2.chunk(tp_size, dim=0)
+ assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy())
+
+ bias_grad_list = bias_grad2.chunk(tp_size, dim=0)
+ assert np.allclose(bias_grad1.reshape(bias_grad_list[tp_rank].shape).cpu().numpy(), bias_grad_list[tp_rank].cpu().numpy())
+
+ if gather_input:
+ x_grad_list = x.grad.chunk(tp_size, dim=0)
+ np.testing.assert_allclose(rank_x.grad.cpu().numpy(), x_grad_list[tp_rank].cpu().numpy(), atol=1e-4, rtol=1e-4)
+ else:
+ np.testing.assert_allclose(rank_x.grad.cpu().numpy(), x.grad.cpu().numpy(), atol=1e-4, rtol=1e-4)
+
+def test_gather_output():
+ run(True, True, 'linear.ckp')
+
+def test_no_gather_output():
+ run(True, False, 'linear_no_gather.ckp')
+
+def test_no_gather_input():
+ run(False, True, 'linear.ckp')
+
+
+if __name__ == "__main__":
+ bmt.init_distributed(tp_size=2)
+ test_gather_output()
+ test_no_gather_output()
+ test_no_gather_input()
+
diff --git a/examples/BMTrain/tests/test_different_output_shape.py b/examples/BMTrain/tests/test_different_output_shape.py
new file mode 100644
index 00000000..bb8ab7fa
--- /dev/null
+++ b/examples/BMTrain/tests/test_different_output_shape.py
@@ -0,0 +1,50 @@
+import torch
+import bmtrain as bmt
+
+class Block0(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return x
+
+class Block1(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return (x,)
+
+class Block2(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return (x, x)
+
+class Block10(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ return [x, x, x, x, x, x, x, x, x, x]
+
+if __name__ == "__main__":
+ bmt.init_distributed()
+ x = torch.tensor([1,2,3.])
+
+ b = bmt.Block(Block0())
+ y = b(x)
+ assert isinstance(y, torch.Tensor)
+
+ b = bmt.Block(Block1())
+ y = b(x)
+ assert isinstance(y, tuple) and len(y)==1
+
+ b = bmt.Block(Block2())
+ y = b(x)
+ assert isinstance(y, tuple) and len(y)==2
+
+ b = bmt.Block(Block10())
+ y = b(x)
+ assert isinstance(y, tuple) and len(y)==10
diff --git a/examples/BMTrain/tests/test_dropout.py b/examples/BMTrain/tests/test_dropout.py
new file mode 100644
index 00000000..d29240a1
--- /dev/null
+++ b/examples/BMTrain/tests/test_dropout.py
@@ -0,0 +1,45 @@
+from utils import *
+
+import torch
+import bmtrain as bmt
+
+class InnerModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+
+ self.drop = torch.nn.Dropout(p=0.5)
+
+ def forward(self, x):
+ return self.drop(x)
+
+class OutterModule(bmt.DistributedModule):
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.blk = bmt.TransformerBlockList([
+ bmt.Block(InnerModule())
+ for _ in range(5)
+ ])
+
+ def forward(self, x):
+ return self.blk(x)
+
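+# in train mode dropout should zero out some activations; in eval mode it acts as the identity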
+def test_main():
+ model = OutterModule()
+
+ for _ in range(5):
+ model.train()
+ x = torch.ones(32, device="cuda")
+ y = model(x)
+ print(y)
+ assert_neq(x.numel()-y.nonzero().size(0), 0)
+
+ model.eval()
+ x = torch.ones(32, device="cuda")
+ y = model(x)
+ print(y)
+ assert_eq(x.numel()-y.nonzero().size(0), 0)
+
+if __name__ == "__main__":
+ bmt.init_distributed()
+ test_main()
diff --git a/examples/BMTrain/tests/test_grad_accu.py b/examples/BMTrain/tests/test_grad_accu.py
new file mode 100644
index 00000000..48dbe8ac
--- /dev/null
+++ b/examples/BMTrain/tests/test_grad_accu.py
@@ -0,0 +1,80 @@
+import bmtrain as bmt
+import torch
+from bmtrain import config
+from bmtrain.block_layer import CheckpointBlockContext, CheckpointBlock, TransformerBlockList
+from bmtrain.pipe_layer import PipelineTransformerBlockList
+from typing import List
+import torch.nn.functional as F
+def print_rank0(s):
+ if bmt.rank() == 0:
+ print(s)
+class Linear(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.out = {}
+ if init_weight:
+ self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features))
+ else:
+ self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_)
+
+ if init_bias:
+ self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,))
+ else:
+ self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_)
+
+ def forward(self, input):
+ ret = F.linear(input, self.weight, self.bias)
+ return ret
+
+def test_grad_accu():
+ # normal distribute module
+ m = Linear(256, 256)
+ inp = torch.randn((1, 10, 256), device="cuda")
+ logits = m(inp)
+ loss = logits.sum()
+ loss.backward()
+ grad1 = m._parameters["weight"].grad.clone()
+ logits = m(inp)
+ loss = logits.sum()
+ loss.backward()
+ grad2 = m._parameters["weight"].grad
+ assert torch.allclose(grad1*2, grad2)
+ print_rank0("grad accumulation for distribute module passed")
+ # checkpoint block
+ m = CheckpointBlock(Linear(256, 256))
+ inp = torch.randn((1, 10, 256), device="cuda")
+ logits = m(inp)
+ loss = logits.sum()
+ loss.backward()
+ bmt.synchronize()
+ grad1 = m.weight.grad.clone()
+ logits = m(inp)
+ loss = logits.sum()
+ loss.backward()
+ bmt.synchronize()
+ grad2 = m.weight.grad.clone()
+ assert torch.allclose(grad1*2, grad2)
+ print_rank0("grad accumulation for checkpointblock passed")
+ # transformer block list
+ m = TransformerBlockList([CheckpointBlock(Linear(256, 256))])
+ inp = torch.randn((1, 10, 256), device="cuda")
+ logits = m(inp)
+ loss = logits.sum()
+ loss.backward()
+ bmt.synchronize()
+ grad1 = m[0].weight.grad.clone()
+ logits = m(inp)
+ loss = logits.sum()
+ loss.backward()
+ bmt.synchronize()
+ grad2 = m[0].weight.grad
+ assert torch.allclose(grad1*2, grad2)
+ print_rank0("grad accumulation for TransformerBlockList passed")
+
+
+if __name__ == "__main__":
+ bmt.init_distributed()
+ test_grad_accu()
\ No newline at end of file
diff --git a/examples/BMTrain/tests/test_has_inf_nan.py b/examples/BMTrain/tests/test_has_inf_nan.py
new file mode 100644
index 00000000..93ac8118
--- /dev/null
+++ b/examples/BMTrain/tests/test_has_inf_nan.py
@@ -0,0 +1,37 @@
+from utils import *
+import torch
+import bmtrain.loss._function as F
+import random
+
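+# out is set to 1 if x contains any inf/nan, otherwise it stays 0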
+def check(x, v):
+ out = torch.zeros(1, dtype=torch.uint8, device="cuda")[0]
+ F.has_inf_nan(x, out)
+ assert_eq(out.item(), v)
+
+def test_main(dtype):
+ for i in list(range(1, 100)) + [1000]*10 + [10000]*10 + [100000]*10 + [1000000]*10:
+ x = torch.rand((i,)).to(dtype).cuda()
+ check(x, 0)
+ p = random.randint(0, i-1)
+ x[p] = x[p] / 0
+ check(x, 1)
+ x[p] = 2
+ check(x, 0)
+ p = random.randint(0, i-1)
+ x[p] = 0
+ x[p] = x[p] / 0
+ check(x, 1)
+ p = random.randint(0, i-1)
+ x[p] = x[p] / 0
+ p = random.randint(0, i-1)
+ x[p] = x[p] / 0
+ check(x, 1)
+ print("That's right")
+
+if __name__ == "__main__":
+ test_main(torch.float16)
+ print("==============================================================================")
+ try:
+ test_main(torch.bfloat16)
+ except NotImplementedError:
+ pass
diff --git a/examples/BMTrain/tests/test_init_parameters.py b/examples/BMTrain/tests/test_init_parameters.py
new file mode 100644
index 00000000..b67431f2
--- /dev/null
+++ b/examples/BMTrain/tests/test_init_parameters.py
@@ -0,0 +1,223 @@
+from utils import *
+
+import torch
+import torch.nn.functional as F
+import bmtrain as bmt
+
+def manual_seed(seed=33):
+ torch.manual_seed(seed)
+ import random as random
+ random.seed(seed)
+ try:
+ import numpy as np
+ np.random.seed(seed)
+ except ModuleNotFoundError:
+ pass
+
+class Linear_NormalInitBefore(torch.nn.Module):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ w = torch.empty(out_features, in_features, dtype=dtype, device="cuda") # use cuda to match random algorithm
+ torch.nn.init.xavier_normal_(w)
+ self.weight = torch.nn.Parameter(w)
+ if bias:
+ b = torch.empty(out_features, dtype=dtype, device="cuda") # use cuda to match random algorithm
+ torch.nn.init.zeros_(b)
+ self.bias = torch.nn.Parameter(b)
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input):
+ return F.linear(input, self.weight, self.bias)
+
+class Linear_NormalInitAfter(torch.nn.Module):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm
+ torch.nn.init.xavier_normal_(self.weight)
+ if bias:
+ self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input):
+ return F.linear(input, self.weight, self.bias)
+
+class Linear_BMTInitializer(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype), init_method=torch.nn.init.xavier_normal_)
+ if bias:
+ self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=dtype), init_method=torch.nn.init.zeros_)
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input):
+ return F.linear(input, self.weight, self.bias)
+
+class Linear_ManualInitBefore(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ w = torch.empty(out_features, in_features, dtype=dtype, device="cuda") # use cuda to match random algorithm
+ torch.nn.init.xavier_normal_(w)
+ self.weight = bmt.DistributedParameter(w)
+ if bias:
+ b = torch.empty(out_features, dtype=dtype, device="cuda") # use cuda to match random algorithm
+ torch.nn.init.zeros_(b)
+ self.bias = bmt.DistributedParameter(b)
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input):
+ return F.linear(input, self.weight, self.bias)
+
+class Linear_ManualInitAfter(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm
+ # torch.nn.init.xavier_normal_(self.weight)
+ if bias:
+ self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm
+ # torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input):
+ return F.linear(input, self.weight, self.bias)
+
+ def extra_repr(self) -> str:
+ return 'in_features={}, out_features={}, bias={}'.format(
+ self.in_features, self.out_features, self.bias is not None
+ )
+
+class Linear_Pipeline(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.l = bmt.PipelineTransformerBlockList([
+ Linear_BMTInitializer(in_features, out_features, bias, dtype),
+ Linear_BMTInitializer(in_features, out_features, bias, dtype),
+ ])
+
+ def forward(self, input):
+ return self.l(input)
+
+class Linear_BlockList(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.l = bmt.TransformerBlockList([
+ Linear_BMTInitializer(in_features, out_features, bias, dtype),
+ Linear_BMTInitializer(in_features, out_features, bias, dtype),
+ ])
+
+ def forward(self, input):
+ return self.l(input)
+
+class Linear_CheckpointList(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.l = torch.nn.ModuleList([
+ bmt.CheckpointBlock(Linear_BMTInitializer(in_features, out_features, bias, dtype)),
+ bmt.CheckpointBlock(Linear_BMTInitializer(in_features, out_features, bias, dtype)),
+ ])
+
+    def forward(self, input):
+        # torch.nn.ModuleList is not callable; apply the wrapped blocks in sequence
+        for m in self.l:
+            input = m(input)
+        return input
+
+class Linear_Checkpoint(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.l = bmt.CheckpointBlock(Linear_BMTInitializer(in_features, out_features, bias, dtype))
+
+ def forward(self, input):
+ return self.l(input)
+
+def test_main():
+ shape = [3, 5]
+ # torch
+ m = [None] * 10
+ ret = [None] * 10
+ manual_seed(33)
+ m[0] = Linear_NormalInitBefore(*shape)
+ ret[0] = (m[0].weight.data, m[0].bias.data)
+
+ manual_seed(33)
+ m[1] = Linear_NormalInitAfter(*shape)
+ ret[1] = (m[1].weight.data, m[1].bias.data)
+
+ # bmtrain
+ manual_seed(33)
+ m[2] = Linear_BMTInitializer(*shape)
+ bmt.init_parameters(m[2])
+ bmt.synchronize()
+ ret[2] = (m[2].weight.data, m[2].bias.data)
+
+ manual_seed(33)
+ m[3] = Linear_ManualInitBefore(*shape)
+ bmt.synchronize()
+ ret[3] = (m[3].weight.data, m[3].bias.data)
+
+ # manual_seed(33)
+ # mw = Linear_ManualInitAfter(*shape) # not supported
+ # print(mw.weight.data, mw.bias.data)
+
+ manual_seed(33)
+ m[4] = bmt.BMTrainModelWrapper(m[0])
+ ret[4] = (m[4].weight.data, m[4].bias.data)
+
+ manual_seed(33)
+ m[5] = bmt.BMTrainModelWrapper(m[1])
+ ret[5] = (m[5].weight.data, m[5].bias.data)
+
+ manual_seed(33)
+ m[6] = Linear_Pipeline(*shape)
+ bmt.init_parameters(m[6])
+ ret[6] = (m[6].l[0].weight.data, m[6].l[0].bias.data)
+
+ manual_seed(33)
+ m[7] = Linear_BlockList(*shape)
+ bmt.init_parameters(m[7])
+ ret[7] = (m[7].l[0].weight.data, m[7].l[0].bias.data)
+
+ manual_seed(33)
+ m[8] = Linear_CheckpointList(*shape)
+ bmt.init_parameters(m[8])
+ ret[8] = (m[8].l[0].weight.data, m[8].l[0].bias.data)
+
+ manual_seed(33)
+ m[9] = Linear_Checkpoint(*shape)
+ bmt.init_parameters(m[9])
+ ret[9] = (m[9].l.weight.data, m[9].l.bias.data)
+
+ for i in range(10):
+ ret[i] = ( ret[i][0].view(-1), ret[i][1].view(-1) )
+ print(ret[i])
+ for i in range(10):
+ for j in range(10):
+ print(i, j)
+ assert_all_eq(ret[i][0], ret[j][0])
+ assert_all_eq(ret[i][1], ret[j][1])
+
+if __name__ == "__main__":
+ bmt.init_distributed(pipe_size=1)
+
+ test_main()
\ No newline at end of file
diff --git a/examples/BMTrain/tests/test_init_parameters_multi_gpu.py b/examples/BMTrain/tests/test_init_parameters_multi_gpu.py
new file mode 100644
index 00000000..1e61568c
--- /dev/null
+++ b/examples/BMTrain/tests/test_init_parameters_multi_gpu.py
@@ -0,0 +1,146 @@
+from utils import *
+
+import os
+import torch
+import torch.nn.functional as F
+import bmtrain as bmt
+
+def manual_seed(seed=33):
+ torch.manual_seed(seed)
+ import random as random
+ random.seed(seed)
+ try:
+ import numpy as np
+ np.random.seed(seed)
+ except ModuleNotFoundError:
+ pass
+
+class Linear_Normal(torch.nn.Module):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm
+ torch.nn.init.xavier_normal_(self.weight)
+ if bias:
+ self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device="cuda")) # use cuda to match random algorithm
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input):
+ return F.linear(input, self.weight, self.bias)
+
+class Linear_BMTInitializer(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype), init_method=torch.nn.init.xavier_normal_)
+ if bias:
+ self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=dtype), init_method=torch.nn.init.zeros_)
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input):
+ return F.linear(input, self.weight, self.bias)
+
+class Linear_NormalList(torch.nn.Module):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.l = torch.nn.ModuleList([
+ Linear_Normal(in_features, out_features, bias, dtype),
+ Linear_Normal(in_features, out_features, bias, dtype),
+ ])
+
+    def forward(self, input):
+        # torch.nn.ModuleList is not callable; apply the linear layers in sequence
+        for m in self.l:
+            input = m(input)
+        return input
+
+class Linear_Pipeline(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.l = bmt.PipelineTransformerBlockList([
+ Linear_BMTInitializer(in_features, out_features, bias, dtype),
+ Linear_BMTInitializer(in_features, out_features, bias, dtype),
+ ])
+
+ def forward(self, input):
+ return self.l(input)
+
+class Linear_BlockList(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.l = bmt.TransformerBlockList([
+ Linear_BMTInitializer(in_features, out_features, bias, dtype),
+ Linear_BMTInitializer(in_features, out_features, bias, dtype),
+ ])
+
+ def forward(self, input):
+ return self.l(input)
+
+class Linear_CheckpointList(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.l = torch.nn.ModuleList([
+ bmt.CheckpointBlock(Linear_BMTInitializer(in_features, out_features, bias, dtype)),
+ bmt.CheckpointBlock(Linear_BMTInitializer(in_features, out_features, bias, dtype)),
+ ])
+
+    def forward(self, input):
+        # torch.nn.ModuleList is not callable; apply the wrapped blocks in sequence
+        for m in self.l:
+            input = m(input)
+        return input
+
+def check(ckpt_path, ckpt_path_ref):
+ if bmt.rank() == 0:
+ ckpt1 = torch.load(ckpt_path)
+ ckpt2 = torch.load(ckpt_path_ref)
+ for (k1, v1), (k2, v2) in zip(ckpt1.items(), ckpt2.items()):
+ assert_eq(k1, k2)
+ print(v1, v2)
+ assert_all_eq(v1.cuda(), v2.cuda())
+
+def test_main():
+ ckpt_path_ref = "test_ckpt_ref.pt"
+ ckpt_path = "test_ckpt.pt"
+ shape = [3, 5]
+ # torch
+ m = [None] * 4
+ ret = [None] * 4
+
+ manual_seed(33)
+ m[0] = Linear_NormalList(*shape)
+ if bmt.rank() == 0:
+ torch.save(m[0].state_dict(), ckpt_path_ref)
+
+ manual_seed(33)
+ m[1] = Linear_Pipeline(*shape)
+ bmt.init_parameters(m[1])
+ bmt.save(m[1], ckpt_path)
+ check(ckpt_path, ckpt_path_ref)
+
+ # bmtrain
+ manual_seed(33)
+ m[2] = Linear_BlockList(*shape)
+ bmt.init_parameters(m[2])
+ bmt.save(m[2], ckpt_path)
+ check(ckpt_path, ckpt_path_ref)
+
+ manual_seed(33)
+ m[3] = Linear_CheckpointList(*shape)
+ bmt.init_parameters(m[3])
+ bmt.save(m[3], ckpt_path)
+ check(ckpt_path, ckpt_path_ref)
+
+ if bmt.rank() == 0:
+ os.remove(ckpt_path)
+ os.remove(ckpt_path_ref)
+
+if __name__ == "__main__":
+ bmt.init_distributed(pipe_size=2)
+
+ test_main()
\ No newline at end of file
diff --git a/examples/BMTrain/tests/test_inspector_hidden.py b/examples/BMTrain/tests/test_inspector_hidden.py
new file mode 100644
index 00000000..62e03656
--- /dev/null
+++ b/examples/BMTrain/tests/test_inspector_hidden.py
@@ -0,0 +1,241 @@
+from utils import *
+
+import bmtrain as bmt
+import random
+import torch
+from bmtrain import config
+from bmtrain.block_layer import Block, TransformerBlockList
+from bmtrain.pipe_layer import PipelineTransformerBlockList
+import torch.nn.functional as F
+from bmtrain import inspect
+
+class Linear(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.out = {}
+ if init_weight:
+ self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features))
+ else:
+ self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_)
+
+ if init_bias:
+ self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,))
+ else:
+ self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_)
+
+ def forward(self, input):
+ ret = F.linear(input, self.weight, self.bias)
+ return ret
+
+class L2(bmt.DistributedModule):
+ def __init__(self, dim : int):
+ super().__init__()
+ self.m1 = Linear(dim, dim)
+ self.m2 = Linear(dim, dim)
+
+ def forward(self, x):
+ x = self.m1(x)
+ x = self.m2(x)
+ return x
+
+class L2_record(bmt.DistributedModule):
+ def __init__(self, dim : int):
+ super().__init__()
+ self.m1 = Linear(dim, dim)
+ self.m2 = Linear(dim, dim)
+
+ def forward(self, x):
+ x = self.m1(x)
+ inspect.record_tensor(x, "hidden")
+ x = self.m2(x)
+ return x
+
+class Model_ZERO(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = TransformerBlockList([
+ Block(m)
+ for m in ms
+ ])
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ if return_hidden_states:
+ x, o = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x), o
+ else:
+ x = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x)
+
+class Model_PIPE(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = PipelineTransformerBlockList([
+ Block(m)
+ for m in ms
+ ])
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ if return_hidden_states:
+ x, o = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x), o
+ else:
+ x = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x)
+
+class Model_BLOCK(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = torch.nn.ModuleList([
+ Block(m)
+ for m in ms
+ ])
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ o = []
+ y = x
+ for m in self.ms:
+ o.append(y)
+ y = m(y)
+ if return_hidden_states:
+ return self.post(y), o
+ else:
+ return self.post(y)
+
+class Model_NORMAL(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = torch.nn.ModuleList(ms)
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ o = []
+ y = x
+ for m in self.ms:
+ o.append(y)
+ y = m(y)
+ if return_hidden_states:
+ return self.post(y), o
+ else:
+ return self.post(y)
+
+def manual_seed(seed=33):
+ torch.manual_seed(seed)
+ random.seed(seed)
+ try:
+ import numpy as np
+ np.random.seed(seed)
+ except ModuleNotFoundError:
+ pass
+
+def sub_run(name, cls, num_layer, dim, batch, seq_len):
+ manual_seed()
+
+ pre = Linear(dim, dim)
+ post = Linear(dim, dim)
+ ms = [L2_record(dim) if i%2==0 else L2(dim) for i in range(num_layer)]
+
+ inp = torch.randn((batch, seq_len, dim)).cuda()
+ last_weight = torch.randn((batch, seq_len, dim)).cuda()
+ middle_weight = [
+ torch.randn((batch, seq_len, dim)).cuda()
+ for i in range(len(ms)//2)
+ ]
+
+ bmt.init_parameters(pre)
+ bmt.init_parameters(post)
+ for m in ms:
+ bmt.init_parameters(m)
+ m = cls(pre, [m for m in ms], post)
+ ret = ""
+ with inspect.inspect_tensor() as inspector:
+ logits = m(inp)
+
+ ret += inspect.format_summary(
+ inspector.get_summary()
+ ) + "\n"
+
+ loss = 0
+ for i in range(len(ms)//2):
+ loss = loss + (inspector.summary[i]['tensor'] * middle_weight[i]).sum()
+
+ with inspect.inspect_tensor():
+ loss.backward()
+
+ ret += inspect.format_summary(
+ inspect.inspect_model(m, '*')
+ ) + "\n"
+ ret += inspect.format_summary(
+ inspector.get_summary()
+ ) + "\n"
+
+ with inspect.inspect_tensor() as inspector:
+ logits = m(inp)
+
+ ret += inspect.format_summary(
+ inspector.get_summary()
+ ) + "\n"
+
+ loss = (logits * last_weight).sum()
+
+ with inspect.inspect_tensor():
+ loss.backward()
+
+ ret += inspect.format_summary(
+ inspect.inspect_model(m, '*')
+ ) + "\n"
+ ret += inspect.format_summary(
+ inspector.get_summary()
+ ) + "\n"
+
+ return ret + "\n" # replace for matching None grad with zero_grad
+
+def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256):
+ ret = ""
+ ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len)
+ bmt.synchronize()
+ return ret
+
+def test_main():
+ ret = {}
+ ret["normal"] = run("normal", Model_NORMAL)
+ ret["block"] = run("block", Model_BLOCK)
+ ret["zero"] = run("zero", Model_ZERO)
+ ret["pipe"] = run("pipe", Model_PIPE)
+ for k, r in ret.items():
+ bmt.print_rank(f"============={k}============")
+ bmt.print_rank(r)
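+    # compare the formatted summaries across parallel strategies token by token:
+    # numeric fields must agree within a small tolerance, everything else must match exactly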
+ for r in ret.values():
+ for r2 in ret.values():
+ lines, lines2 = r.split('\n'), r2.split('\n')
+ assert len(lines) == len(lines2)
+ for line, line2 in zip(lines, lines2):
+ words, words2 = line.split(), line2.split()
+ assert len(words) == len(words2)
+ for w, w2 in zip(words, words2):
+ try:
+ is_float = isinstance(eval(w), float)
+ except:
+ is_float = False
+ if is_float:
+ assert_lt(abs(float(w)-float(w2)), 0.00011)
+ else:
+ assert_eq(w, w2)
+
+if __name__ == "__main__":
+ bmt.init_distributed(pipe_size=2)
+
+ test_main()
diff --git a/examples/BMTrain/tests/test_load_ckpt.py b/examples/BMTrain/tests/test_load_ckpt.py
new file mode 100644
index 00000000..0eb4f95f
--- /dev/null
+++ b/examples/BMTrain/tests/test_load_ckpt.py
@@ -0,0 +1,78 @@
+from utils import *
+import torch
+import torch.nn.functional as F
+import bmtrain as bmt
+import os
+
+class Linear_Normal(torch.nn.Module):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=dtype, device="cuda"))
+ torch.nn.init.xavier_normal_(self.weight)
+ if bias:
+ self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device="cuda"))
+ torch.nn.init.zeros_(self.bias)
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input):
+ return F.linear(input, self.weight, self.bias)
+
+class Linear_BMT(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, bias: bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=dtype), init_method=torch.nn.init.xavier_normal_)
+ if bias:
+ self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=dtype), init_method=torch.nn.init.zeros_)
+ else:
+ self.register_parameter('bias', None)
+
+ def forward(self, input):
+ return F.linear(input, self.weight, self.bias)
+
+
+def test_main():
+ ckpt_path = "test_ckpt.pt"
+ # Transformer BlockList
+ m = Linear_Normal(256, 256).cuda()
+ m2 = bmt.TransformerBlockList([bmt.Block(Linear_BMT(256, 256))])
+ if bmt.rank() == 0:
+ torch.save(m.state_dict(), ckpt_path)
+ dic2 = m.state_dict()
+ dic2["0.weight"] = dic2.pop("weight")
+ dic2["0.bias"] = dic2.pop("bias")
+ m2.load_state_dict(dic2)
+ for key in m.state_dict():
+ bmt_key = f"0.{key}"
+ assert bmt_key in m2.state_dict(), "wrong key in bmtrain model"
+ assert (m2.state_dict()[bmt_key].cuda() == m.state_dict()[key]).all() , "wrong param in bmtrain model"
+ if bmt.rank() == 0:
+ os.remove(ckpt_path)
+ print("Transformer Blocklist load_state_dict and state_dict test passed")
+
+ # Block
+ m3 = bmt.Block(Linear_BMT(256, 256))
+ m3.load_state_dict(m.state_dict())
+ for key in m.state_dict():
+ assert key in m3.state_dict(), "wrong key in bmtrain model"
+ assert (m.state_dict()[key] == m3.state_dict()[key].cuda()).all(), "wrong param in bmtrain model"
+ print("Block load_state_dict and state_dict test passed")
+
+ # normal Distributed module
+ m4 = Linear_BMT(256, 256)
+ m4.load_state_dict(m.state_dict())
+ for key in m.state_dict():
+ assert key in m4.state_dict(), "wrong key in bmtrain model"
+ assert (m.state_dict()[key] == m4.state_dict()[key].cuda()).all(), "wrong param in bmtrain model"
+ print("bmt.distributedmodule load_state_dict and state_dict test passed")
+
+if __name__ == "__main__":
+ bmt.init_distributed()
+
+ test_main()
diff --git a/examples/BMTrain/tests/test_loss_func.py b/examples/BMTrain/tests/test_loss_func.py
new file mode 100644
index 00000000..a448b6d1
--- /dev/null
+++ b/examples/BMTrain/tests/test_loss_func.py
@@ -0,0 +1,79 @@
+from utils import *
+
+import torch
+import bmtrain as bmt
+import torch
+import random
+import copy
+
+def run(x, tgt, loss_func, bigmodel=None, scale=32768, use_float=False):
+ x = x.clone().detach()
+ bigmodel = copy.deepcopy(bigmodel)
+ if use_float:
+ x = x.float()
+ if bigmodel is not None:
+ bigmodel = bigmodel.float()
+ x = x.requires_grad_()
+ if bigmodel is None:
+ loss = loss_func(x, tgt)
+ else:
+ t = bigmodel(x)
+ loss = loss_func(t, tgt)
+ (loss * scale).backward()
+ return loss, x.grad
+
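+# run the fused half-precision loss and a float32 reference, then compare losses and input grads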
+def check(x, tgt, loss_func1, loss_func2, bigmodel=None):
+ loss_1, grad_1 = run(x, tgt, loss_func1, bigmodel=bigmodel)
+ loss_2, grad_2 = run(x, tgt, loss_func2, bigmodel=bigmodel, use_float=True)
+ assert_eq(grad_1.isnan().sum(), 0)
+ assert_eq(grad_2.isnan().sum(), 0)
+ print(f"{(loss_1 - loss_2).abs().item():.6f} {(grad_1 - grad_2).abs().max().item():.6f}")
+ assert_lt((loss_1 - loss_2).abs().item(), 1e-5)
+ assert_lt((grad_1 - grad_2).abs().max().item(), 1e-1)
+
+def test_simple(dtype):
+ loss_func1 = bmt.loss.FusedCrossEntropy()
+ loss_func2 = torch.nn.CrossEntropyLoss()
+
+ N = 32 * 512
+ for i in range(1, 10):
+ C = i * 10
+ x = torch.randn(N, C).cuda().to(dtype)
+ tgt = torch.randint(0, C, (N,)).cuda().long()
+ check(x, tgt, loss_func1, loss_func2)
+ for i in range(1, 10):
+ C = i * 100
+ x = torch.randn(N, C).cuda().to(dtype)
+ tgt = torch.randint(0, C, (N,)).cuda().long()
+ check(x, tgt, loss_func1, loss_func2)
+ for i in range(1, 31):
+ C = i * 1000
+ x = torch.randn(N, C).cuda().to(dtype)
+ tgt = torch.randint(0, C, (N,)).cuda().long()
+ check(x, tgt, loss_func1, loss_func2)
+
+def test_other(dtype):
+ N = 32 * 512
+ for i in range(1, 11):
+ C = i * 10
+ weight = [i+1 for i in range(C)]
+ random.shuffle(weight)
+ weight = torch.tensor(weight, device="cuda")
+ loss_func1 = bmt.loss.FusedCrossEntropy(weight=weight.clone().to(dtype))
+ loss_func2 = torch.nn.CrossEntropyLoss(weight=weight.clone().float())
+
+ x = torch.randn(N, C).cuda().to(dtype)
+ tgt = torch.randint(0, C, (N,)).cuda().long()
+ mask = torch.randint(0, 2, (N,)).cuda().bool()
+ tgt[mask] = -100
+ check(x, tgt, loss_func1, loss_func2)
+
+if __name__ == "__main__":
+ test_other(torch.float16)
+ test_simple(torch.float16)
+ print("==============================================================================")
+ try:
+ test_other(torch.bfloat16)
+ test_simple(torch.bfloat16)
+ except NotImplementedError:
+ pass
\ No newline at end of file
diff --git a/examples/BMTrain/tests/test_middle_hidden.py b/examples/BMTrain/tests/test_middle_hidden.py
new file mode 100644
index 00000000..2a93efe0
--- /dev/null
+++ b/examples/BMTrain/tests/test_middle_hidden.py
@@ -0,0 +1,212 @@
+from utils import *
+
+import bmtrain as bmt
+import random
+import torch
+from bmtrain.block_layer import Block, TransformerBlockList
+from bmtrain.pipe_layer import PipelineTransformerBlockList
+import torch.nn.functional as F
+from bmtrain import inspect
+
+class Linear(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.out = {}
+ if init_weight:
+ self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features))
+ else:
+ self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_)
+
+ if init_bias:
+ self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,))
+ else:
+ self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_)
+
+ def forward(self, input):
+ ret = F.linear(input, self.weight, self.bias)
+ return ret
+
+class Model_ZERO(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = TransformerBlockList([
+ Block(m)
+ for m in ms
+ ])
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ if return_hidden_states:
+ x, o = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x), o
+ else:
+ x = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x)
+
+class Model_PIPE(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = PipelineTransformerBlockList([
+ Block(m)
+ for m in ms
+ ])
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ if return_hidden_states:
+ x, o = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x), o
+ else:
+ x = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x)
+
+class Model_BLOCK(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = torch.nn.ModuleList([
+ Block(m)
+ for m in ms
+ ])
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ o = []
+ y = x
+ for m in self.ms:
+ o.append(y)
+ y = m(y)
+ if return_hidden_states:
+ return self.post(y), o
+ else:
+ return self.post(y)
+
+class Model_NORMAL(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = torch.nn.ModuleList(ms)
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ o = []
+ y = x
+ for m in self.ms:
+ o.append(y)
+ y = m(y)
+ if return_hidden_states:
+ return self.post(y), o
+ else:
+ return self.post(y)
+
+def manual_seed(seed=33):
+ torch.manual_seed(seed)
+ random.seed(seed)
+ try:
+ import numpy as np
+ np.random.seed(seed)
+ except ModuleNotFoundError:
+ pass
+
+def sub_run(name, cls, num_layer, dim, batch, seq_len, only_last=False, only_middle=False, mix_test=False):
+ manual_seed()
+
+ pre = Linear(dim, dim)
+ post = Linear(dim, dim)
+ ms = [Linear(dim, dim) for i in range(num_layer)]
+
+ inp = torch.randn((batch, seq_len, dim)).cuda()
+ last_weight = torch.randn((batch, seq_len, dim)).cuda()
+ middle_weight = [
+ torch.randn((batch, seq_len, dim)).cuda()
+ for i in range(len(ms))
+ ]
+
+ bmt.init_parameters(pre)
+ bmt.init_parameters(post)
+ for m in ms:
+ bmt.init_parameters(m)
+ m = cls(pre, [m for m in ms], post)
+
+ ret = ""
+ if only_last:
+ logits = m(inp)
+ loss = (logits * last_weight).sum()
+ loss.backward()
+ ret += f"========================only last========================\n"
+ ret += inspect.format_summary(
+ inspect.inspect_model(m, '*')
+ )
+ if only_middle:
+ logits, hidden_states = m(inp, return_hidden_states=True)
+ loss = sum([
+ (hidden_state * middle_weight[i]).sum()
+ for i, hidden_state in enumerate(hidden_states)
+ ])
+ loss.backward()
+ ret += f"========================only middle========================\n"
+ ret += inspect.format_summary(
+ inspect.inspect_model(m, '*')
+ )
+ if mix_test:
+ logits, hidden_states = m(inp, return_hidden_states=True)
+ loss = sum([
+ (hidden_state * middle_weight[i]).sum()
+ for i, hidden_state in enumerate(hidden_states)
+ ]) + (logits * last_weight).sum()
+ loss.backward()
+ ret += f"========================mix========================\n"
+ ret += inspect.format_summary(
+ inspect.inspect_model(m, '*')
+ )
+ return ret + "\n" # replace for matching None grad with zero_grad
+
+def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256):
+ ret = ""
+ ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, only_last=True)
+ bmt.synchronize()
+ ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, only_middle=True)
+ bmt.synchronize()
+ ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, mix_test=True)
+ bmt.synchronize()
+ return ret
+
+def test_main():
+ ret = {}
+ ret["normal"] = run("normal", Model_NORMAL)
+ ret["block"] = run("block", Model_BLOCK)
+ ret["zero"] = run("zero", Model_ZERO)
+ # ret["pipe"] = run("pipe", Model_PIPE)
+ for k, r in ret.items():
+ bmt.print_rank(f"============={k}============")
+ bmt.print_rank(r)
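+    # summaries from all strategies must match: floats within tolerance, other tokens exactly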
+ for r in ret.values():
+ for r2 in ret.values():
+ lines, lines2 = r.split('\n'), r2.split('\n')
+ assert len(lines) == len(lines2)
+ for line, line2 in zip(lines, lines2):
+ words, words2 = line.split(), line2.split()
+ assert len(words) == len(words2)
+ for w, w2 in zip(words, words2):
+ try:
+ is_float = isinstance(eval(w), float)
+ except:
+ is_float = False
+ if is_float:
+ assert_lt(abs(float(w)-float(w2)), 2.)
+ else:
+ assert_eq(w, w2)
+
+if __name__ == "__main__":
+ bmt.init_distributed(pipe_size=1)
+
+ test_main()
diff --git a/examples/BMTrain/tests/test_model_wrapper.py b/examples/BMTrain/tests/test_model_wrapper.py
new file mode 100644
index 00000000..6f913d3c
--- /dev/null
+++ b/examples/BMTrain/tests/test_model_wrapper.py
@@ -0,0 +1,221 @@
+from utils import *
+
+from typing import Optional
+import torch
+import math
+import torch.nn.functional as F
+import bmtrain as bmt
+import time
+
+class Attention(torch.nn.Module):
+ def __init__(self,
+ dim_model : int, dim_head : int,
+ num_heads : int, bias : bool = True,
+ dtype = None
+ ) -> None:
+ super().__init__()
+
+ self.project_q = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+ self.project_k = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+ self.project_v = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+
+ self.project_out = torch.nn.Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
+
+ self.softmax = torch.nn.Softmax(dim=-1)
+ self.num_heads = num_heads
+ self.dim_head = dim_head
+ self.dim_model = dim_model
+
+ def forward(self,
+ hidden_q : torch.Tensor, # (batch_size, seq_q, dim_model)
+ hidden_kv : torch.Tensor, # (batch_size, seq_kv, dim_model)
+ mask : torch.BoolTensor, # (batch_size, seq_q, seq_kv)
+ position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv)
+ ) -> torch.Tensor:
+ batch_size, seq_q, dim_model = hidden_q.size()
+ seq_kv = hidden_kv.size(1)
+
+ h_q : torch.Tensor = self.project_q(hidden_q)
+ h_k : torch.Tensor = self.project_k(hidden_kv)
+ h_v : torch.Tensor = self.project_v(hidden_kv)
+
+ h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head)
+ h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head)
+ h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head)
+
+ h_q = h_q.permute(0, 2, 1, 3).contiguous()
+ h_k = h_k.permute(0, 2, 1, 3).contiguous()
+ h_v = h_v.permute(0, 2, 1, 3).contiguous()
+
+ h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head)
+ h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head)
+ h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head)
+
+ score = torch.bmm(
+ h_q, h_k.transpose(1, 2)
+ )
+ score = score / math.sqrt(self.dim_head)
+
+ score = score.view(batch_size, self.num_heads, seq_q, seq_kv)
+
+ if position_bias is not None:
+ score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv)
+
+ score = torch.where(
+ mask.view(batch_size, 1, seq_q, seq_kv),
+ score,
+ torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype)
+ )
+
+ score = torch.where(
+ mask.view(batch_size, 1, seq_q, seq_kv),
+ self.softmax(score),
+ torch.scalar_tensor(0, device=score.device, dtype=score.dtype)
+ )
+
+ score = score.view(batch_size * self.num_heads, seq_q, seq_kv)
+
+ h_out = torch.bmm(
+ score, h_v
+ )
+ h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head)
+ h_out = h_out.permute(0, 2, 1, 3).contiguous()
+ h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head)
+
+ attn_out = self.project_out(h_out)
+ return attn_out
+
+class Feedforward(torch.nn.Module):
+ def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.w_in = torch.nn.Linear(dim_model, dim_ff, bias = bias, dtype=dtype)
+ self.w_out = torch.nn.Linear(dim_ff, dim_model, bias = bias, dtype=dtype)
+
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, input : torch.Tensor) -> torch.Tensor:
+ return self.w_out(self.relu(self.w_in(input)))
+
+
+class TransformerEncoder(torch.nn.Module):
+ def __init__(self,
+ dim_model : int, dim_head : int, num_heads : int, dim_ff : int,
+ bias : bool = True, dtype = None
+ ) -> None:
+ super().__init__()
+
+ self.ln_attn = torch.nn.LayerNorm(dim_model, dtype=dtype)
+ self.attn = Attention(dim_model, dim_head, num_heads, bias=bias, dtype=dtype)
+
+ self.ln_ff = torch.nn.LayerNorm(dim_model, dtype=dtype)
+ self.ff = Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype)
+
+ def forward(self,
+ hidden : torch.Tensor, # (batch, seq_len, dim_model)
+            mask : torch.BoolTensor,        # (batch, seq_len, seq_len)
+ position_bias : Optional[torch.Tensor] = None, # (batch, num_head, seq_len, seq_len)
+ ):
+ x = self.ln_attn(hidden)
+ x = self.attn(x, x, mask, position_bias)
+ hidden = hidden + x
+
+ x = self.ln_ff(hidden)
+ x = self.ff(x)
+ hidden = hidden + x
+
+ return hidden
+
+
+class GPT(torch.nn.Module):
+ def __init__(self,
+ num_layers : int, vocab_size : int,
+ dim_model : int, dim_head : int, num_heads : int, dim_ff : int,
+ max_distance : int,
+ bias : bool = True, dtype = None
+ ) -> None:
+ super().__init__()
+
+ self.max_distance = max_distance
+
+ self.word_emb = torch.nn.Embedding(vocab_size, dim_model, dtype=dtype)
+ self.pos_emb = torch.nn.Embedding(max_distance, dim_model, dtype=dtype)
+ self.dim_model = dim_model
+
+ self.transformers = torch.nn.ModuleList([
+ TransformerEncoder(
+ dim_model, dim_head, num_heads, dim_ff, bias, dtype
+ )
+ for _ in range(num_layers)
+ ])
+
+ self.layernorm = torch.nn.LayerNorm(dim_model, dtype=dtype)
+
+ def forward(self,
+ input : torch.LongTensor, # (batch, seq_len)
+ pos : torch.LongTensor, # (batch, seq_len)
+ mask : torch.BoolTensor, # (batch, seq_len)
+ ) -> torch.Tensor:
+
+ mask_2d = mask[:, None, :] & mask[:, :, None] # (batch, seq_len, seq_len)
+ mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None])
+
+ input_emb = self.pos_emb(pos) + self.word_emb(input)
+
+ out = input_emb
+ for layer in self.transformers:
+ out = layer(out, mask_2d)
+ out = self.layernorm(out)
+
+ logits = F.linear(out, self.word_emb.weight) / math.sqrt(self.dim_model)
+
+ return logits
+
+def test_main():
+ model = GPT(
+ num_layers=8,
+ vocab_size=10240,
+ dim_model=2560,
+ dim_head=80,
+ num_heads=32,
+ dim_ff=8192,
+ max_distance=1024,
+ bias=True,
+ dtype=torch.half
+ )
+
+ bmt_model = bmt.BMTrainModelWrapper(model)
+ model = model.cuda()
+
+    # break when i == bmt.rank() so that each rank generates different data
+ torch.manual_seed(1234)
+ batch_size = 2
+ seq_len = 512
+
+ for i in range(bmt.world_size()):
+ sent = torch.randint(0, 10240, (batch_size, seq_len + 1))
+ enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda()
+ enc_input = sent[:, :-1].long().cuda()
+ targets = sent[:, 1:].long().cuda()
+ mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None]
+ targets = torch.where(
+ mask,
+ targets,
+ torch.full_like(targets, -100, dtype=torch.long)
+ )
+
+ if i == bmt.rank():
+ break
+
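+    # the wrapped model must produce exactly the same logits as the original module on this input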
+ pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
+ logits = model(enc_input, pos, pos < enc_length[:, None])
+ bmt_logits = bmt_model(enc_input, pos, pos < enc_length[:, None])
+ print(logits)
+ bmt.synchronize()
+ print(bmt_logits)
+ assert_all_eq(logits, bmt_logits)
+
+if __name__ == '__main__':
+ bmt.init_distributed(seed=0)
+
+ test_main()
diff --git a/examples/BMTrain/tests/test_multi_return.py b/examples/BMTrain/tests/test_multi_return.py
new file mode 100644
index 00000000..f4a5d79f
--- /dev/null
+++ b/examples/BMTrain/tests/test_multi_return.py
@@ -0,0 +1,126 @@
+from utils import *
+
+import bmtrain as bmt
+import torch
+import random
+from bmtrain import config
+from bmtrain.block_layer import Block, TransformerBlockList
+from bmtrain.pipe_layer import PipelineTransformerBlockList
+import torch.nn.functional as F
+
+class MultiInputReturn(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, a, b, c, d, e):
+ return a*2, b+d, c*4+e*5
+
+class Model_ZERO(torch.nn.Module):
+ def __init__(self, ms) -> None:
+ super().__init__()
+ self.ms = TransformerBlockList([
+ Block(m)
+ for m in ms
+ ], num_hidden=3)
+
+ def forward(self, x):
+ y = self.ms(*x)
+ return y
+
+class Model_PIPE(torch.nn.Module):
+ def __init__(self, ms) -> None:
+ super().__init__()
+ self.ms = PipelineTransformerBlockList([
+ Block(m)
+ for m in ms
+ ], num_hidden=3)
+
+ def forward(self, x):
+ y = self.ms(*x)
+ return y
+
+class Model_BLOCK(torch.nn.Module):
+ def __init__(self, ms) -> None:
+ super().__init__()
+ self.ms = torch.nn.ModuleList([
+ Block(m)
+ for m in ms
+ ])
+
+ def forward(self, x):
+ y = x[:3]
+ other = x[3:]
+ for m in self.ms:
+ y = m(*y, *other)
+ return y
+
+class Model_NORMAL(torch.nn.Module):
+ def __init__(self, ms) -> None:
+ super().__init__()
+ self.ms = torch.nn.ModuleList(ms)
+
+ def forward(self, x):
+ y = x[:3]
+ other = x[3:]
+ for m in self.ms:
+ y = m(*y, *other)
+ return y
+
+def manual_seed(seed=33):
+ torch.manual_seed(seed)
+ random.seed(seed)
+ try:
+ import numpy as np
+ np.random.seed(seed)
+ except ModuleNotFoundError:
+ pass
+
+def run(name, cls, num_layer=4, dim=4096):
+ manual_seed()
+
+ ms = [MultiInputReturn() for i in range(num_layer)]
+
+ inps = (
+ torch.randn((dim,)).cuda(),
+ torch.randn((dim,)).cuda(),
+ torch.randn((dim,)).cuda(),
+ torch.randn((dim,)).cuda(),
+ torch.randn((dim,)).cuda(),
+ )
+ last_weights = (
+ torch.randn((dim,)).cuda(),
+ torch.randn((dim,)).cuda(),
+ torch.randn((dim,)).cuda(),
+ )
+
+ for inp in inps:
+ inp.requires_grad_(True)
+ m = cls(ms)
+
+ ret = ""
+ logits = m(inps)
+ loss = (logits[0]*last_weights[0] + logits[1]*last_weights[1] + logits[2]*last_weights[2]).sum()
+ loss.backward()
+ return list(logits) + [
+ inp.grad
+ for inp in inps
+ ]
+
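+# every variant (plain, Block, TransformerBlockList) must produce the same outputs and
+# input gradients, elementwise within 1e-5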
+def test_main():
+ ret = {}
+ ret["normal"] = run("normal", Model_NORMAL)
+ ret["block"] = run("block", Model_BLOCK)
+ ret["zero"] = run("zero", Model_ZERO)
+ # ret["pipe"] = run("pipe", Model_PIPE) # TODO pipeline not support multiple input-output yet
+ for k, r in ret.items():
+ bmt.print_rank(f"============={k}============")
+ bmt.print_rank(r)
+ for r in ret.values():
+ for r2 in ret.values():
+ for i in range(len(r)):
+ assert_lt((r[i]-r2[i]).abs().max(), 1e-5)
+
+if __name__ == "__main__":
+ bmt.init_distributed(pipe_size=1)
+
+ test_main()
diff --git a/examples/BMTrain/tests/test_nccl_backward.py b/examples/BMTrain/tests/test_nccl_backward.py
new file mode 100644
index 00000000..5950fb5c
--- /dev/null
+++ b/examples/BMTrain/tests/test_nccl_backward.py
@@ -0,0 +1,43 @@
+from utils import *
+
+import bmtrain as bmt
+import torch
+
+def test_main(dtype):
+ x = torch.full((1,), bmt.rank() + 1, dtype=dtype, device="cuda").requires_grad_(True)
+ y = bmt.distributed.all_reduce(x, "prod").view(-1)
+ loss = (y * y).sum() / 2
+ loss.backward()
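+    # y = prod_r x_r with x_r = r + 1 and loss = y^2 / 2, so
+    # dloss/dx_rank = y * prod_{i != rank} (i + 1)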
+    ref = y.clone()  # clone so the in-place multiplies below do not modify y
+ for i in range(bmt.world_size()):
+ if i != bmt.rank(): ref *= i+1
+ assert_eq(x.grad, ref)
+
+def test_reducescatter():
+ world_size = bmt.world_size()
+ for shape in [(128,), (128,128)]:
+ tensors = torch.randn(world_size, *shape, dtype=torch.half, device="cuda").requires_grad_(True)
+ local_tensor = tensors[bmt.rank()]
+ x = local_tensor.detach().clone().requires_grad_(True)
+ y = bmt.distributed.reduce_scatter(x, "sum")
+ ref = tensors.sum(0)
+ partition = x.shape[0] // bmt.world_size()
+ ref_p = ref[bmt.rank() * partition:(bmt.rank() + 1) * partition]
+ if bmt.rank() == 0:
+ print(ref_p)
+ print(y)
+ assert torch.allclose(ref_p, y, atol=1e-2, rtol=1e-3)
+ g = torch.randn_like(y)
+ grad = torch.autograd.grad(y, x, g)[0]
+ pgrad = grad[bmt.rank() * y.shape[0]: (bmt.rank() + 1) * y.shape[0]]
+ ref_g = g
+ if bmt.rank() == 0:
+ print(ref_g)
+ print(pgrad)
+ assert torch.allclose(ref_g, pgrad, atol=1e-3, rtol=1e-3)
+
+if __name__ == "__main__":
+ bmt.init_distributed()
+ test_reducescatter()
+ test_main(torch.half)
+ test_main(torch.bfloat16)
diff --git a/examples/BMTrain/tests/test_no_grad.py b/examples/BMTrain/tests/test_no_grad.py
new file mode 100644
index 00000000..7851c670
--- /dev/null
+++ b/examples/BMTrain/tests/test_no_grad.py
@@ -0,0 +1,90 @@
+import torch
+import bmtrain as bmt
+
+class Layer(torch.nn.Module):
+ def __init__(self):
+ super(Layer, self).__init__()
+ self.linear = bmt.nn.Linear(32, 32)
+ self.count = 0
+ def forward(self, x):
+ self.count += 1
+ return self.linear(x)
+
+def test_no_grad():
+ x = torch.randn(32, 32, device='cuda')
+
+ layer1 = bmt.Block(Layer())
+ layer2 = bmt.Block(Layer())
+ layer1.linear.weight.requires_grad_(False)
+ layer1.linear.bias.requires_grad_(False)
+ y = layer1(x)
+ assert y.requires_grad == False
+ y = layer2(y)
+ y.sum().backward()
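+    # bmt.Block recomputes a wrapped layer's forward during backward (activation checkpointing),
+    # so layer2 runs forward twice; layer1 is frozen, its output needs no grad, and it is not recomputed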
+ assert layer1.count == 1
+ assert layer2.count == 2
+
+def test_multi_layer_no_grad():
+ x = torch.randn(32, 32, device='cuda')
+
+ layers = []
+ for i in range(10):
+ layer = bmt.Block(Layer())
+ layer.linear.weight.requires_grad_(i > 4)
+ layer.linear.bias.requires_grad_(i > 4)
+ layers.append(layer)
+
+ y = x
+ for layer in layers:
+ y = layer(y)
+ y.sum().backward()
+ for i in range(len(layers)):
+        assert layers[i].count == (1 if i <= 4 else 2)
+
+def test_all_input_no_grad():
+ linear1 = bmt.nn.Linear(32, 32)
+ linear2 = bmt.nn.Linear(32, 32)
+
+ x = torch.randn(32,32, device='cuda')
+
+ linear1 = bmt.Block(linear1)
+ linear2 = bmt.Block(linear2)
+ y = linear1(x)
+ y = linear2(y)
+ y.sum().backward()
+ assert linear1.weight.grad is not None
+ assert linear1.bias.grad is not None
+ assert x.grad is None
+
+def test_same_layer():
+ layer = Layer()
+ block_list = bmt.TransformerBlockList([layer, layer])
+ assert id(block_list[0]) != id(block_list[1])
+
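+# reusing the same Block instance inside a block list must raise an AssertionError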
+def test_no_grad_error():
+ layer = bmt.Block(Layer())
+
+ try:
+ block_list = bmt.TransformerBlockList([layer, layer])
+ raise ValueError("test failed")
+ except AssertionError as e:
+ return
+
+def test_no_grad_error2():
+ layer = bmt.Block(Layer())
+
+ try:
+ block_list = bmt.PipelineTransformerBlockList([layer])
+ raise ValueError("test failed")
+ except AssertionError as e:
+ return
+
+if __name__ == '__main__':
+ bmt.init_distributed()
+
+ test_no_grad()
+ test_multi_layer_no_grad()
+ test_all_input_no_grad()
+ test_same_layer()
+ test_no_grad_error()
+ test_no_grad_error2()
diff --git a/examples/BMTrain/tests/test_optim.py b/examples/BMTrain/tests/test_optim.py
new file mode 100644
index 00000000..0aca8c31
--- /dev/null
+++ b/examples/BMTrain/tests/test_optim.py
@@ -0,0 +1,94 @@
+from utils import *
+import torch
+import bmtrain as bmt
+from bmtrain import optim
+
+class TestModule(torch.nn.Module):
+ def __init__(self):
+ super(TestModule, self).__init__()
+ self.fc1 = torch.nn.Linear(128, 128, bias=False)
+ self.fc2 = torch.nn.Linear(128, 128)
+ self.fc3 = torch.nn.Linear(128, 128)
+ self.fc4 = torch.nn.Linear(128, 128)
+ self.fc5 = torch.nn.Linear(128, 128)
+ self.param = torch.nn.Parameter(torch.empty(1237))
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.fc2(x)
+ x = self.fc3(x)
+ x = self.fc4(x)
+ x = self.fc5(x)
+ return x
+
+def main(dtype):
+ model1 = TestModule()
+ model2 = TestModule()
+ model3 = TestModule()
+ model4 = TestModule()
+ model5 = TestModule()
+
+ state_dict = model1.state_dict()
+ for kw in state_dict.keys():
+ state_dict[kw] = torch.randn_like(state_dict[kw])
+
+ model1.load_state_dict(state_dict)
+ model2.load_state_dict(state_dict)
+ model3.load_state_dict(state_dict)
+ model4.load_state_dict(state_dict)
+ model5.load_state_dict(state_dict)
+
+ model1 = model1.cuda().to(dtype)
+ model2 = model2.cuda().to(dtype)
+ model3 = model3.cuda()
+ model4 = model4.cuda()
+ model5 = model5.cuda()
+
+ opt1 = bmt.optim.AdamOptimizer(model1.parameters(), lr=1)
+ opt2 = bmt.optim.AdamOffloadOptimizer(model2.parameters(), lr=1)
+ opt3 = torch.optim.Adam(model3.parameters(), lr=1)
+ opt4 = bmt.optim.AdamOptimizer(model4.parameters(), lr=1)
+ opt5 = bmt.optim.AdamOffloadOptimizer(model5.parameters(), lr=1)
+
+ optim_manager = bmt.optim.OptimManager(loss_scale=4)
+ optim_manager.add_optimizer(opt1)
+ optim_manager.add_optimizer(opt2)
+ optim_manager.add_optimizer(opt3)
+ optim_manager.add_optimizer(opt4)
+ optim_manager.add_optimizer(opt5)
+
+ for _ in range(100):
+ optim_manager.zero_grad()
+
+ for p1, p2, p3, p4, p5 in zip(model1.parameters(), model2.parameters(), model3.parameters(), model4.parameters(), model5.parameters()):
+ grad = torch.randn_like(p1)
+ p1.grad = grad.to(dtype)
+ p2.grad = grad.to(dtype)
+ p3.grad = grad.float()
+ p4.grad = grad.float()
+ p5.grad = grad.float()
+
+ optim_manager.step()
+ torch.cuda.synchronize()
+
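+    # p1 (fp16/bf16 Adam) and p2 (fp16/bf16 offload Adam) are only loosely bounded against the
+    # fp32 references; the fp32 fused Adam (p4) must match torch.optim.Adam (p3) exactly,
+    # and the fp32 offload Adam (p5) must match it to ~1e-5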
+ for p1, p2, p3, p4, p5 in zip(model1.parameters(), model2.parameters(), model3.parameters(), model4.parameters(), model5.parameters()):
+ diff1 = torch.abs(p1 - p2).max().item()
+ diff2 = torch.abs(p1 - p3).max().item()
+ diff3 = torch.abs(p2 - p3).max().item()
+ diff4 = torch.abs(p3 - p4).max().item()
+ diff5 = torch.abs(p3 - p5).max().item()
+ print(f"{diff1:.6f}, {diff2:.6f}, {diff3:.6f}, {diff4:.6f}, {diff5:.6f}")
+ assert_lt(diff1, 1)
+ assert_lt(diff2, 1)
+ assert_lt(diff3, 1)
+ assert_eq(diff4, 0)
+ assert_lt(diff5, 0.00001)
+
+if __name__ == "__main__":
+ bmt.init_distributed()
+ main(torch.float16)
+ print("==============================================================================")
+ try:
+ main(torch.bfloat16)
+ except NotImplementedError:
+ pass
diff --git a/examples/BMTrain/tests/test_optim_state.py b/examples/BMTrain/tests/test_optim_state.py
new file mode 100644
index 00000000..57d5d0e3
--- /dev/null
+++ b/examples/BMTrain/tests/test_optim_state.py
@@ -0,0 +1,135 @@
+import torch
+import bmtrain as bmt
+import os
+from copy import deepcopy
+from bmtrain import optim, lr_scheduler
+
+class TestSubModule(bmt.DistributedModule):
+ def __init__(self):
+ super(TestSubModule, self).__init__()
+ self.fc1 = bmt.BMTrainModelWrapper(torch.nn.Linear(768, 3072))
+ self.fc2 = bmt.BMTrainModelWrapper(torch.nn.Linear(3072, 1024))
+ self.fc3 = bmt.BMTrainModelWrapper(torch.nn.Linear(1024, 768))
+ self.param = bmt.DistributedParameter(torch.zeros(1237))
+ self.fc4 = bmt.BMTrainModelWrapper(torch.nn.Linear(768, 300))
+ self.fc5 = bmt.BMTrainModelWrapper(torch.nn.Linear(300, 768))
+ self.dropout = torch.nn.Dropout(0.0)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.fc2(x)
+ x = self.fc3(x)
+ x = self.dropout(x)
+ x = self.fc4(x)
+ x = self.fc5(x)
+ return x
+
+class TestModule(torch.nn.Module):
+ def __init__(self):
+ super(TestModule, self).__init__()
+ self.layer1 = TestSubModule()
+ self.layer2 = bmt.CheckpointBlock(TestSubModule())
+
+ def forward(self, x):
+ x = self.layer1(x)
+ x = self.layer2(x)
+ return x
+
+def train(model1, model2, model3, optim_manager):
+ x = torch.randn((4, 768)).cuda()
+ for _ in range(10):
+ optim_manager.zero_grad()
+
+ y1, y2, y3 = model1(x), model2(x), model3(x)
+ w = torch.randn_like(y1)
+ loss = (y1*w).sum() + (y2*w).sum() + (y3*w).sum()
+ optim_manager.backward(loss)
+
+ optim_manager.step()
+
+def manual_seed(seed=33):
+ torch.manual_seed(seed)
+ import random
+ random.seed(seed)
+ try:
+ import numpy as np
+ np.random.seed(seed)
+ except ModuleNotFoundError:
+ pass
+
+def main():
+ model1 = TestModule()
+ model2 = TestModule()
+ model3 = TestModule()
+
+ bmt.save(model1, "test_optim_state_model1.pt")
+
+ bmt.load(model1, f"test_optim_state_model1.pt")
+ bmt.load(model2, f"test_optim_state_model1.pt")
+ bmt.load(model3, f"test_optim_state_model1.pt")
+
+ opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-3)
+ opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-3)
+ opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-3)
+ lrs1 = lr_scheduler.Noam(opt1, start_lr=20, warmup_iter=0, end_iter=300, num_iter=1)
+ lrs2 = lr_scheduler.Noam(opt2, start_lr=20, warmup_iter=0, end_iter=300, num_iter=1)
+ lrs3 = lr_scheduler.Noam(opt3, start_lr=20, warmup_iter=0, end_iter=300, num_iter=1)
+ optim_manager = optim.OptimManager(loss_scale=256)
+ optim_manager.add_optimizer(opt1, lrs1)
+ optim_manager.add_optimizer(opt2, lrs2)
+ optim_manager.add_optimizer(opt3, lrs3)
+
+ train(model1, model2, model3, optim_manager)
+
+ bmt.save(model1, f"test_optim_state_model1.pt")
+ bmt.save(model2, f"test_optim_state_model2.pt")
+ bmt.save(model3, f"test_optim_state_model3.pt")
+
+ torch.save(optim_manager.state_dict(), f"test_optim_manager_{bmt.rank()}.opt")
+
+ manual_seed()
+ train(model1, model2, model3, optim_manager)
+ state_2 = deepcopy([list(model1.parameters()), list(model2.parameters()), list(model3.parameters())])
+
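+    # reload the step-10 checkpoints and the saved optimizer state into freshly built models and
+    # optimizers with deliberately different hyperparameters; after load_state_dict, another
+    # 10 steps must reproduce state_2 exactly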
+ model1 = TestModule()
+ model2 = TestModule()
+ model3 = TestModule()
+ bmt.load(model1, f"test_optim_state_model1.pt")
+ bmt.load(model2, f"test_optim_state_model2.pt")
+ bmt.load(model3, f"test_optim_state_model3.pt")
+
+ opt1 = optim.AdamOptimizer(model1.parameters(), weight_decay=1e-8, betas=(0.3, 0.333), eps=1e-3)
+ opt2 = optim.AdamOffloadOptimizer(model2.parameters(), weight_decay=1e-7, betas=(0.4, 0.456), eps=1e-1)
+ opt3 = torch.optim.Adam(model3.parameters(), weight_decay=1e-6, betas=(0.9, 0.777), eps=1e-2)
+ lrs1 = lr_scheduler.Noam(opt1, start_lr=200, warmup_iter=30, end_iter=500, num_iter=3)
+ lrs2 = lr_scheduler.Noam(opt2, start_lr=20, warmup_iter=40, end_iter=600, num_iter=1)
+ lrs3 = lr_scheduler.Noam(opt3, start_lr=10, warmup_iter=50, end_iter=700, num_iter=2)
+ optim_manager = optim.OptimManager(loss_scale=10485760)
+ optim_manager.add_optimizer(opt1, lrs1)
+ optim_manager.add_optimizer(opt2, lrs2)
+ optim_manager.add_optimizer(opt3, lrs3)
+ optim_manager.load_state_dict(torch.load(f"test_optim_manager_{bmt.rank()}.opt"))
+
+ manual_seed()
+ train(model1, model2, model3, optim_manager)
+ state_1_plus_1 = deepcopy([list(model1.parameters()), list(model2.parameters()), list(model3.parameters())])
+
+ for i, kind in [
+ (0, "BMTAdam"),
+ (1, "BMTAdamOffload"),
+ (2, "TorchAdam")
+ ]:
+ ref = state_2[i]
+ chk = state_1_plus_1[i]
+ for rp, p in zip(ref, chk):
+ assert (rp==p).all(), f"{kind} state load error"
+
+    if bmt.rank() == 0:
+        os.remove(f"test_optim_state_model1.pt")
+        os.remove(f"test_optim_state_model2.pt")
+        os.remove(f"test_optim_state_model3.pt")
+    os.remove(f"test_optim_manager_{bmt.rank()}.opt")  # each rank saved and removes its own optimizer state
+
+if __name__ == "__main__":
+ bmt.init_distributed()
+ main()
diff --git a/examples/BMTrain/tests/test_other_hidden.py b/examples/BMTrain/tests/test_other_hidden.py
new file mode 100644
index 00000000..27736aa7
--- /dev/null
+++ b/examples/BMTrain/tests/test_other_hidden.py
@@ -0,0 +1,189 @@
+from utils import *
+
+import bmtrain as bmt
+import random
+import torch
+from bmtrain import config
+from bmtrain.block_layer import Block, TransformerBlockList
+from bmtrain.pipe_layer import PipelineTransformerBlockList
+import torch.nn.functional as F
+from bmtrain import inspect
+
+class Linear(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.out = {}
+ if init_weight:
+ self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features))
+ else:
+ self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_)
+
+ if init_bias:
+ self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,))
+ else:
+ self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_)
+
+ def forward(self, input):
+ ret = F.linear(input, self.weight, self.bias)
+ return ret
+
+class Model_ZERO(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = TransformerBlockList([
+ Block(m)
+ for m in ms
+ ])
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ if return_hidden_states:
+ x, o = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x), o
+ else:
+ x = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x)
+
+class Model_PIPE(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = PipelineTransformerBlockList([
+ Block(m)
+ for m in ms
+ ])
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ if return_hidden_states:
+ x, o = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x), o
+ else:
+ x = self.ms(x, return_hidden_states=return_hidden_states)
+ return self.post(x)
+
+class Model_BLOCK(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = torch.nn.ModuleList([
+ Block(m)
+ for m in ms
+ ])
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ o = []
+ y = x
+ for m in self.ms:
+ o.append(y)
+ y = m(y)
+ if return_hidden_states:
+ return self.post(y), o
+ else:
+ return self.post(y)
+
+class Model_NORMAL(torch.nn.Module):
+ def __init__(self, pre, ms, post) -> None:
+ super().__init__()
+ self.pre = pre
+ self.ms = torch.nn.ModuleList(ms)
+ self.post = post
+
+ def forward(self, x, return_hidden_states=False):
+ x = self.pre(x)
+ o = []
+ y = x
+ for m in self.ms:
+ o.append(y)
+ y = m(y)
+ if return_hidden_states:
+ return self.post(y), o
+ else:
+ return self.post(y)
+
+def manual_seed(seed=33):
+ torch.manual_seed(seed)
+ random.seed(seed)
+ try:
+ import numpy as np
+ np.random.seed(seed)
+ except ModuleNotFoundError:
+ pass
+
+def sub_run(name, cls, num_layer, dim, batch, seq_len, only_pre=False, only_post=False, mix_test=False):
+ manual_seed()
+
+ pre = Linear(dim, dim)
+ post = Linear(dim, dim)
+ ms = [Linear(dim, dim) for i in range(num_layer)]
+
+ inp = torch.randn((batch, seq_len, dim)).cuda()
+ last_weight = torch.randn(pre.weight.shape).cuda()*10
+ middle_weight = [
+ torch.randn((batch, seq_len, dim)).cuda()
+ for i in range(len(ms))
+ ]
+
+ bmt.init_parameters(pre)
+ bmt.init_parameters(post)
+ for m in ms:
+ bmt.init_parameters(m)
+ m = cls(pre, [m for m in ms], post)
+
+ ret = ""
+ if only_pre:
+ loss = (pre.weight * last_weight).sum()
+ loss.backward()
+ ret += f"========================only last========================\n"
+ ret += inspect.format_summary(
+ inspect.inspect_model(m, '*')
+ )
+ if only_post:
+ loss = (post.weight * last_weight).sum()
+ loss.backward()
+ ret += f"========================only middle========================\n"
+ ret += inspect.format_summary(
+ inspect.inspect_model(m, '*')
+ )
+ if mix_test:
+ loss = (pre.weight * last_weight).sum() + (post.weight * last_weight).sum()
+ loss.backward()
+ ret += f"========================mix========================\n"
+ ret += inspect.format_summary(
+ inspect.inspect_model(m, '*')
+ )
+ return ret + "\n" # replace for matching None grad with zero_grad
+
+def run(name, cls, num_layer=4, dim=4096, batch=32, seq_len=256):
+ ret = ""
+ ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, only_pre=True)
+ bmt.synchronize()
+ ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, only_post=True)
+ bmt.synchronize()
+ ret += sub_run(name, cls, num_layer=num_layer, dim=dim, batch=batch, seq_len=seq_len, mix_test=True)
+ bmt.synchronize()
+ return ret
+
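+# build losses that touch only pre, only post, or both, and require the inspect summaries of the
+# normal, Block and TransformerBlockList (ZeRO) variants to match exactly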
+def test_main():
+ ret = []
+ ret.append( run("normal", Model_NORMAL) )
+ ret.append( run("block", Model_BLOCK) )
+ ret.append( run("zero", Model_ZERO) )
+ # ret.append( run("pipe", Model_PIPE) )
+ for r in ret:
+ bmt.print_rank(r)
+ for r in ret:
+ for r2 in ret:
+ assert_eq(r, r2)
+
+if __name__ == "__main__":
+ bmt.init_distributed(pipe_size=1)
+ test_main()
diff --git a/examples/BMTrain/tests/test_parallel_projection.py b/examples/BMTrain/tests/test_parallel_projection.py
new file mode 100644
index 00000000..dc1e874d
--- /dev/null
+++ b/examples/BMTrain/tests/test_parallel_projection.py
@@ -0,0 +1,55 @@
+import torch
+import bmtrain as bmt
+from bmtrain.global_var import config
+import numpy as np
+import os
+
+def run_normal(x, t, ckp_path, dtype):
+ proj = bmt.nn.Projection(100, 64, dtype=dtype)
+ bmt.init_parameters(proj)
+ bmt.save(proj, ckp_path)
+ loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=False)
+ y = proj(x)
+ y = y.detach().requires_grad_()
+ loss = loss_func(y, t)
+ loss.backward()
+ return y, loss, y.grad
+
+def run_vp(x, t, ckp_path, dtype):
+ proj = bmt.nn.VPProjection(100, 64, dtype=dtype)
+ bmt.load(proj, ckp_path)
+ loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True)
+ y = proj(x)
+ y = y.detach().requires_grad_()
+ loss = loss_func(y, t)
+ loss.backward()
+ return y, loss, y.grad
+
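+# compare a replicated projection + standard cross entropy against the vocab-parallel projection
+# + parallel cross entropy: logits and logit gradients must agree on each tensor-parallel slice,
+# and the losses must match (within 1e-4)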
+def run(dtype):
+ ckp_path = 'embedding.pt'
+ torch.cuda.manual_seed(100)
+ tp_size = config["tp_size"]
+ tp_rank = config['tp_rank']
+ x = torch.randn(110, 64, device='cuda', dtype=dtype)
+ t = torch.cat([torch.arange(100).view(10, 10), torch.ones((10, 1))*-100], dim=-1).view(110).int().cuda()
+ y1, loss1, grad1 = run_normal(x, t, ckp_path, dtype)
+ y2, loss2, grad2 = run_vp(x, t, ckp_path, dtype)
+ y1 = y1.chunk(tp_size, dim=-1)[tp_rank]
+ grad1 = grad1.chunk(tp_size, dim=-1)[tp_rank]
+ for r in range(tp_size):
+ if bmt.rank() == r:
+ print((y1-y2).abs().max())
+ print((loss1-loss2).abs().max())
+ print((grad1-grad2).abs().max())
+ assert (y1-y2).abs().max() < 1e-4
+ assert (loss1-loss2).abs().max() < 1e-4
+ assert (grad1-grad2).abs().max() < 1e-4
+ bmt.synchronize()
+ if bmt.rank() == 0:
+ os.remove(f"embedding.pt")
+
+if __name__ == "__main__":
+ bmt.init_distributed(tp_size=4)
+ run(torch.half)
+ run(torch.bfloat16)
+
diff --git a/examples/BMTrain/tests/test_requires_grad.py b/examples/BMTrain/tests/test_requires_grad.py
new file mode 100644
index 00000000..9a443bd3
--- /dev/null
+++ b/examples/BMTrain/tests/test_requires_grad.py
@@ -0,0 +1,107 @@
+from utils import *
+
+import bmtrain as bmt
+import torch
+from bmtrain import config
+from bmtrain.block_layer import Block, TransformerBlockList
+from bmtrain.pipe_layer import PipelineTransformerBlockList
+from typing import List
+import torch.nn.functional as F
+from bmtrain import inspect
+
+class Linear(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.out = {}
+ if init_weight:
+ self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features))
+ else:
+ self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_)
+
+ if init_bias:
+ self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,))
+ else:
+ self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_)
+
+ def forward(self, input, other_bias):
+ ret = F.linear(input, self.weight, self.bias)
+ ret += other_bias
+ return ret
+
+def run(m, a, b):
+ inp = torch.rand((1, 10, 256)).cuda()*100
+ bias = torch.rand(256).cuda()*100
+ logits = m(inp, bias)
+ loss = logits.sum()
+ loss.backward()
+ bmt.synchronize()
+ sm = inspect.format_summary(
+ inspect.inspect_model(m, '*')
+ )
+ assert_eq(bias.requires_grad, False)
+ return a.weight.grad is None, a.bias.grad is None, sm
+
+def test_main():
+ a = Linear(256, 256)
+ b = Linear(256, 256)
+ m = TransformerBlockList([Block(a), Block(b)])
+ bmt.init_parameters(m)
+
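+    # the assertions below rely on summary row 1 being a.weight and row 2 being a.bias;
+    # a frozen parameter reports zero gradient statistics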
+ a.bias.requires_grad_(False)
+ awg, abg, sm1 = run(m, a, b)
+ print(awg, abg, sm1)
+ assert_eq((awg, abg), (False, True))
+ assert_eq(sm1.split('\n')[2].split()[-2:], ["0.0000", "0.0000"])
+
+ a.weight.requires_grad_(False)
+ a.bias.requires_grad_(True)
+ awg, abg, sm2 = run(m, a, b)
+ print(awg, abg, sm2)
+ assert_eq((awg, abg), (False, False))
+ assert_eq(sm1.split('\n')[1], sm2.split('\n')[1])
+ assert_neq(sm1.split('\n')[2], sm2.split('\n')[2])
+
+ a.weight.requires_grad_(True)
+ a.bias.requires_grad_(False)
+ awg, abg, sm3 = run(m, a, b)
+ print(awg, abg, sm3)
+ assert_eq((awg, abg), (False, False))
+ assert_neq(sm2.split('\n')[1], sm3.split('\n')[1])
+ assert_eq(sm2.split('\n')[2], sm3.split('\n')[2])
+
+def test_main_pipe():
+ a = Linear(256, 256)
+ b = Linear(256, 256)
+ m = PipelineTransformerBlockList([Block(a), Block(b)])
+ bmt.init_parameters(m)
+
+ a.bias.requires_grad_(False)
+ awg, abg, sm1 = run(m, a, b)
+ print(awg, abg, sm1)
+ assert_eq((awg, abg), (False, True))
+ assert_eq(sm1.split('\n')[2].split()[-2:], ["0.0000", "0.0000"])
+
+ a.weight.requires_grad_(False)
+ a.bias.requires_grad_(True)
+ awg, abg, sm2 = run(m, a, b)
+ print(awg, abg, sm2)
+ assert_eq((awg, abg), (False, False))
+ assert_eq(sm1.split('\n')[1], sm2.split('\n')[1])
+ assert_neq(sm1.split('\n')[2], sm2.split('\n')[2])
+
+ a.weight.requires_grad_(True)
+ a.bias.requires_grad_(False)
+ awg, abg, sm3 = run(m, a, b)
+ print(awg, abg, sm3)
+ assert_eq((awg, abg), (False, False))
+ assert_neq(sm2.split('\n')[1], sm3.split('\n')[1])
+ assert_eq(sm2.split('\n')[2], sm3.split('\n')[2])
+
+if __name__ == "__main__":
+ bmt.init_distributed(pipe_size=1)
+
+ test_main()
+ # test_main_pipe()
diff --git a/examples/BMTrain/tests/test_requires_grad_multi_gpu.py b/examples/BMTrain/tests/test_requires_grad_multi_gpu.py
new file mode 100644
index 00000000..2eedf7b6
--- /dev/null
+++ b/examples/BMTrain/tests/test_requires_grad_multi_gpu.py
@@ -0,0 +1,96 @@
+from utils import *
+
+import bmtrain as bmt
+import torch
+from bmtrain.block_layer import Block, TransformerBlockList
+from bmtrain.pipe_layer import PipelineTransformerBlockList
+from typing import List
+import torch.nn.functional as F
+from bmtrain import inspect
+
+class Linear(bmt.DistributedModule):
+ def __init__(self, in_features : int, out_features: int, init_weight = None, init_bias = None) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.out = {}
+ if init_weight:
+ self.weight = bmt.DistributedParameter(torch.tensor(init_weight, dtype=torch.float, device="cuda").reshape(out_features, in_features))
+ else:
+ self.weight = bmt.DistributedParameter(torch.empty(out_features, in_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.xavier_normal_)
+
+ if init_bias:
+ self.bias = bmt.DistributedParameter(torch.tensor(init_bias, dtype=torch.float, device="cuda").reshape(out_features,))
+ else:
+ self.bias = bmt.DistributedParameter(torch.empty(out_features, dtype=torch.float, device="cuda"), init_method=torch.nn.init.zeros_)
+
+ def forward(self, input):
+ ret = F.linear(input, self.weight, self.bias)
+ return ret
+
+def run(m, a, b):
+ inp = torch.rand((1, 10, 256)).cuda()*100
+ logits = m(inp)
+ loss = logits.sum()
+ loss.backward()
+
+ sm = inspect.format_summary(
+ inspect.inspect_model(m, '*')
+ )
+ return sm
+
+def test_main():
+ a = Linear(256, 256)
+ b = Linear(256, 256)
+ m = TransformerBlockList([Block(a), Block(b)])
+ bmt.init_parameters(m)
+
+ a.bias.requires_grad_(False)
+ sm1 = run(m, a, b)
+ print(sm1)
+ assert_eq(sm1.split('\n')[2].split()[-2:], ["0.0000", "0.0000"])
+
+ a.weight.requires_grad_(False)
+ a.bias.requires_grad_(True)
+ sm2 = run(m, a, b)
+ print(sm2)
+ assert_eq(sm1.split('\n')[1], sm2.split('\n')[1])
+ assert_neq(sm1.split('\n')[2], sm2.split('\n')[2])
+
+ a.weight.requires_grad_(True)
+ a.bias.requires_grad_(False)
+ sm3 = run(m, a, b)
+ assert_neq(sm2.split('\n')[1], sm3.split('\n')[1])
+ assert_eq(sm2.split('\n')[2], sm3.split('\n')[2])
+
+def test_main_pipe():
+ a = Linear(256, 256)
+ b = Linear(256, 256)
+ m = PipelineTransformerBlockList([Block(a), Block(b)])
+ bmt.init_parameters(m)
+
+ a.bias.requires_grad_(False)
+ sm1 = run(m, a, b)
+ print(sm1)
+ assert_eq(sm1.split('\n')[2].split()[-2:], ["0.0000", "0.0000"])
+
+ a.weight.requires_grad_(False)
+ a.bias.requires_grad_(True)
+ sm2 = run(m, a, b)
+ print(sm2)
+ assert_eq(sm1.split('\n')[1], sm2.split('\n')[1])
+ assert_neq(sm1.split('\n')[2], sm2.split('\n')[2])
+
+ a.weight.requires_grad_(True)
+ a.bias.requires_grad_(False)
+ sm3 = run(m, a, b)
+ print(sm3)
+ assert_neq(sm2.split('\n')[1], sm3.split('\n')[1])
+ assert_eq(sm2.split('\n')[2], sm3.split('\n')[2])
+
+if __name__ == "__main__":
+ bmt.init_distributed(pipe_size=1)
+
+ test_main()
+ # test_main_pipe()
diff --git a/examples/BMTrain/tests/test_row_parallel_linear.py b/examples/BMTrain/tests/test_row_parallel_linear.py
new file mode 100644
index 00000000..23dce8b2
--- /dev/null
+++ b/examples/BMTrain/tests/test_row_parallel_linear.py
@@ -0,0 +1,54 @@
+import torch
+import bmtrain as bmt
+from bmtrain.global_var import config
+import numpy as np
+
+def run_bmt(x, ckp_path, split_input=True, use_checkpoint_block=True):
+ linear = bmt.nn.RowParallelLinear(8,8, split_input=split_input, all_reduce_output=True)
+ if use_checkpoint_block:
+ linear = bmt.Block(linear)
+ bmt.init_parameters(linear)
+ y = linear(x)
+ y.sum().backward()
+ bmt.save(linear, ckp_path)
+ bmt.synchronize()
+ return y, linear._parameters['weight'].grad, linear._parameters['bias'].grad
+
+def run_torch(x, ckp_path):
+ linear = torch.nn.Linear(8, 8)
+ linear_dict = torch.load(ckp_path)
+ linear.load_state_dict(linear_dict)
+ linear = linear.cuda()
+ linear.weight.requires_grad_()
+ y = linear(x)
+ y.sum().backward()
+ return y, linear.weight.grad, linear.bias.grad
+
+def run(split_input, use_checkpoint_block, ckp_path):
+ tp_size = bmt.config['tp_size']
+ torch.cuda.manual_seed(100)
+ tp_rank = config['topology'].tp_id
+ x = torch.randn(8,8, device='cuda').requires_grad_()
+ rank_x = x.chunk(tp_size, dim=0 if split_input else 1)[tp_rank]
+ y1, weight_grad1, bias_grad1 = run_bmt(rank_x, ckp_path, split_input, use_checkpoint_block)
+ y2, weight_grad2, bias_grad2 = run_torch(x, ckp_path)
+ np.testing.assert_allclose(y1.detach().cpu().numpy(), y2.detach().cpu().numpy(), atol=1e-5)
+
+ weight_grad_list = weight_grad2.chunk(tp_size, dim=1)
+ assert np.allclose(weight_grad1.reshape(weight_grad_list[tp_rank].shape).cpu().numpy(), weight_grad_list[tp_rank].cpu().numpy())
+
+ assert np.allclose(bias_grad1.cpu().numpy(), bias_grad2.cpu().numpy())
+
+def test_split_input():
+ run(True, False, 'row_parallel_linear.ckp')
+ run(True, True, 'row_parallel_linear.ckp')
+
+def test_no_split_input():
+ run(False, False, 'row_parallel_linear_no_split.ckp')
+ run(False, True, 'row_parallel_linear_no_split.ckp')
+
+if __name__ == "__main__":
+ bmt.init_distributed(tp_size=2)
+ test_no_split_input()
+ test_split_input()
+
diff --git a/examples/BMTrain/tests/test_send_recv.py b/examples/BMTrain/tests/test_send_recv.py
new file mode 100644
index 00000000..f933b0c2
--- /dev/null
+++ b/examples/BMTrain/tests/test_send_recv.py
@@ -0,0 +1,22 @@
+from utils import *
+
+import torch
+import bmtrain as bmt
+from bmtrain.global_var import config
+
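+# pipeline stage 0 sends a tensor filled with (pp_zero_id + 1); stage 1 receives it
+# and checks the values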
+def test_send_recv():
+ if config["topology"].stage_id == 0:
+ a = torch.ones((2,1)) * (config["topology"].pp_zero_id+1)
+ a = a.cuda()
+ print(f"send {a}")
+ bmt.distributed.send_activations(a, 1, config["pipe_comm"])
+ else:
+ ref = torch.ones((2,1)) * (config["topology"].pp_zero_id+1)
+ a = bmt.distributed.recv_activations(0, config["pipe_comm"])
+ print(f"recv {a}")
+ assert_all_eq(a, ref.cuda())
+
+if __name__ == '__main__':
+ bmt.init_distributed(pipe_size=2)
+
+ test_send_recv()
diff --git a/examples/BMTrain/tests/test_store.py b/examples/BMTrain/tests/test_store.py
new file mode 100644
index 00000000..cb427d5b
--- /dev/null
+++ b/examples/BMTrain/tests/test_store.py
@@ -0,0 +1,13 @@
+import bmtrain as bmt
+from bmtrain.store import allgather_object
+
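+# allgather_object should collect every rank's Python object in rank order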
+def test_allgather_object():
+
+ res = allgather_object(bmt.rank(), bmt.config["comm"])
+ ref = [i for i in range(bmt.world_size())]
+ assert res == ref
+
+if __name__ == "__main__":
+ bmt.init_distributed()
+ test_allgather_object()
+
diff --git a/examples/BMTrain/tests/test_synchronize.py b/examples/BMTrain/tests/test_synchronize.py
new file mode 100644
index 00000000..bea48b04
--- /dev/null
+++ b/examples/BMTrain/tests/test_synchronize.py
@@ -0,0 +1,26 @@
+import torch
+import bmtrain as bmt
+
+from bmtrain.global_var import config
+from bmtrain import nccl, distributed
+from bmtrain.synchronize import gather_result
+
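+# gather_result should reassemble per-rank chunks into the full tensor, including
+# non-contiguous slices of the local chunk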
+def test_main():
+
+ ref_result = torch.ones(5 * bmt.world_size(), 5)
+ tensor = ref_result.chunk(bmt.world_size(), dim=0)[bmt.rank()]
+ real_result = bmt.gather_result(tensor)
+    assert torch.allclose(ref_result, real_result, atol=1e-6), "gathered result does not match the reference"
+
+ for i in range(4):
+ size = i + 1
+ tensor_slice = tensor[:size, :size]
+ result_slice = bmt.gather_result(tensor_slice)
+        test_slice = torch.chunk(result_slice, bmt.world_size(), dim=0)[bmt.rank()]  # this rank's chunk of the gathered result
+ assert torch.allclose(tensor_slice, test_slice), f"Assertion failed for tensor_slice_{i}"
+
+print("All test passed")
+
+if __name__ == '__main__':
+ bmt.init_distributed(pipe_size=1)
+ test_main()
diff --git a/examples/BMTrain/tests/test_training.py b/examples/BMTrain/tests/test_training.py
new file mode 100644
index 00000000..46389802
--- /dev/null
+++ b/examples/BMTrain/tests/test_training.py
@@ -0,0 +1,516 @@
+from bmtrain.optim import optim_manager
+from utils import *
+
+from typing import Optional
+import torch
+import math
+import torch.nn.functional as F
+import bmtrain as bmt
+import os
+from bmtrain import inspect
+
+def clip_grad_norm(loss_scale, param_groups, max_norm, norm_type=2, eps=1e-6, is_torch=False):
+ """Clips gradient norm of an iterable of parameters.
+
+ The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.
+
+ Args:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a single Tensor that will have gradients normalized.
+ max_norm (float or int): max norm of the gradients.
+ norm_type (float or int): type of the used p-norm. Can be 'inf' for infinity norm.
+ eps (float): epsilon used to avoid zero division.
+
+ Returns:
+ Total norm of the parameters (viewed as a single vector).
+ """
+ scale = loss_scale
+ grads = []
+ parameters = [p for group in param_groups for p in group['params']]
+ for p in parameters:
+ if p.grad is not None:
+ grads.append(p.grad.data)
+ else:
+ grads.append(torch.zeros_like(p.data))
+
+ if norm_type == 'inf':
+ total_norm_cuda = max(g.data.abs().max() for g in grads).detach()
+ if not is_torch:
+ bmt.nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "max", bmt.config["comm"])
+ total_norm = total_norm_cuda
+ else:
+ norm_type = float(norm_type)
+ total_norm_cuda = torch.cuda.FloatTensor([0])
+ for index, g in enumerate(grads):
+ param_norm = g.data.float().norm(norm_type)
+ total_norm_cuda += param_norm ** norm_type
+ if not is_torch:
+ bmt.nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "sum", bmt.config["comm"])
+ total_norm = total_norm_cuda[0] ** (1. / norm_type)
+ clip_coef = float(max_norm * scale) / (total_norm + eps)
+ if clip_coef < 1:
+ for p in parameters:
+ if p.grad is not None:
+ p.grad.data.mul_(clip_coef)
+ return total_norm / scale
+
+class Attention(torch.nn.Module):
+ def __init__(self,
+ dim_model : int, dim_head : int,
+ num_heads : int, bias : bool = True,
+ dtype = None
+ ) -> None:
+ super().__init__()
+
+ self.project_q = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+ self.project_k = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+ self.project_v = torch.nn.Linear(dim_model, dim_head * num_heads, bias=bias, dtype=dtype)
+
+ self.project_out = torch.nn.Linear(dim_head * num_heads, dim_model, bias=bias, dtype=dtype)
+
+ self.softmax = torch.nn.Softmax(dim=-1)
+ self.num_heads = num_heads
+ self.dim_head = dim_head
+ self.dim_model = dim_model
+
+ def forward(self,
+ hidden_q : torch.Tensor, # (batch_size, seq_q, dim_model)
+ hidden_kv : torch.Tensor, # (batch_size, seq_kv, dim_model)
+ mask : torch.BoolTensor, # (batch_size, seq_q, seq_kv)
+ position_bias : Optional[torch.Tensor] = None, # (batch, num_heads, seq_q, seq_kv)
+ ) -> torch.Tensor:
+ batch_size, seq_q, dim_model = hidden_q.size()
+ seq_kv = hidden_kv.size(1)
+
+ h_q : torch.Tensor = self.project_q(hidden_q)
+ h_k : torch.Tensor = self.project_k(hidden_kv)
+ h_v : torch.Tensor = self.project_v(hidden_kv)
+
+ h_q = h_q.view(batch_size, seq_q, self.num_heads, self.dim_head)
+ h_k = h_k.view(batch_size, seq_kv, self.num_heads, self.dim_head)
+ h_v = h_v.view(batch_size, seq_kv, self.num_heads, self.dim_head)
+
+ h_q = h_q.permute(0, 2, 1, 3).contiguous()
+ h_k = h_k.permute(0, 2, 1, 3).contiguous()
+ h_v = h_v.permute(0, 2, 1, 3).contiguous()
+
+ h_q = h_q.view(batch_size * self.num_heads, seq_q, self.dim_head)
+ h_k = h_k.view(batch_size * self.num_heads, seq_kv, self.dim_head)
+ h_v = h_v.view(batch_size * self.num_heads, seq_kv, self.dim_head)
+
+ score = torch.bmm(
+ h_q, h_k.transpose(1, 2)
+ )
+ score = score / math.sqrt(self.dim_head)
+
+ score = score.view(batch_size, self.num_heads, seq_q, seq_kv)
+
+ if position_bias is not None:
+ score = score + position_bias.view(batch_size, self.num_heads, seq_q, seq_kv)
+
+ score = torch.where(
+ mask.view(batch_size, 1, seq_q, seq_kv),
+ score,
+ torch.scalar_tensor(float('-inf'), device=score.device, dtype=score.dtype)
+ )
+
+ score = torch.where(
+ mask.view(batch_size, 1, seq_q, seq_kv),
+ self.softmax(score),
+ torch.scalar_tensor(0, device=score.device, dtype=score.dtype)
+ )
+
+ score = score.view(batch_size * self.num_heads, seq_q, seq_kv)
+
+ h_out = torch.bmm(
+ score, h_v
+ )
+ h_out = h_out.view(batch_size, self.num_heads, seq_q, self.dim_head)
+ h_out = h_out.permute(0, 2, 1, 3).contiguous()
+ h_out = h_out.view(batch_size, seq_q, self.num_heads * self.dim_head)
+
+ attn_out = self.project_out(h_out)
+ return attn_out
+
+class Feedforward(torch.nn.Module):
+ def __init__(self, dim_model : int, dim_ff : int, bias : bool = True, dtype = None) -> None:
+ super().__init__()
+
+ self.w_in = torch.nn.Linear(dim_model, dim_ff, bias = bias, dtype=dtype)
+ self.w_out = torch.nn.Linear(dim_ff, dim_model, bias = bias, dtype=dtype)
+
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, input : torch.Tensor) -> torch.Tensor:
+ return self.w_out(self.relu(self.w_in(input)))
+
+
+class TransformerEncoder(torch.nn.Module):
+ def __init__(self,
+ dim_model : int, dim_head : int, num_heads : int, dim_ff : int,
+ bias : bool = True, dtype = None
+ ) -> None:
+ super().__init__()
+
+ self.ln_attn = torch.nn.LayerNorm(dim_model, dtype=dtype)
+ self.attn = Attention(dim_model, dim_head, num_heads, bias=bias, dtype=dtype)
+
+ self.ln_ff = torch.nn.LayerNorm(dim_model, dtype=dtype)
+ self.ff = Feedforward(dim_model, dim_ff, bias=bias, dtype=dtype)
+
+ def forward(self,
+ hidden : torch.Tensor, # (batch, seq_len, dim_model)
+            mask : torch.BoolTensor,        # (batch, seq_len, seq_len)
+ position_bias : Optional[torch.Tensor] = None, # (batch, num_head, seq_len, seq_len)
+ ):
+ x = self.ln_attn(hidden)
+ x = self.attn(x, x, mask, position_bias)
+ hidden = hidden + x
+
+ x = self.ln_ff(hidden)
+ x = self.ff(x)
+ hidden = hidden + x
+
+ return hidden
+
+
+class GPT(torch.nn.Module):
+ def __init__(self,
+ num_layers : int, vocab_size : int,
+ dim_model : int, dim_head : int, num_heads : int, dim_ff : int,
+ max_distance : int,
+ bias : bool = True, dtype = None
+ ) -> None:
+ super().__init__()
+
+ self.dtype = dtype
+ self.max_distance = max_distance
+
+ self.word_emb = torch.nn.Embedding(vocab_size, dim_model, dtype=dtype)
+ self.pos_emb = torch.nn.Embedding(max_distance, dim_model, dtype=dtype)
+ self.dim_model = dim_model
+
+ self.transformers = torch.nn.ModuleList([
+ TransformerEncoder(
+ dim_model, dim_head, num_heads, dim_ff, bias, dtype
+ )
+ for _ in range(num_layers)
+ ])
+ self.run_unroll = False
+
+ self.layernorm = torch.nn.LayerNorm(dim_model, dtype=dtype)
+
+ def forward(self,
+ input : torch.LongTensor, # (batch, seq_len)
+ pos : torch.LongTensor, # (batch, seq_len)
+ mask : torch.BoolTensor, # (batch, seq_len)
+ ) -> torch.Tensor:
+
+ mask_2d = mask[:, None, :] & mask[:, :, None] # (batch, seq_len, seq_len)
+ mask_2d = mask_2d & (pos[:, None, :] >= pos[:, :, None])
+
+ input_emb = self.pos_emb(pos) + self.word_emb(input)
+
+ out = input_emb
+ if isinstance(self.transformers, torch.nn.ModuleList) or self.run_unroll:
+ for layer in self.transformers:
+ out = layer(out, mask_2d, None)
+ else:
+ out = self.transformers(out, mask_2d, None)
+ out = self.layernorm(out)
+
+ logits = F.linear(out, self.word_emb.weight) / math.sqrt(self.dim_model)
+
+ return logits
+
+def sub_train_torch(model, loss_func_cls, optimizer_cls):
+ loss_func = loss_func_cls(ignore_index=-100)
+ optimizer = optimizer_cls(model.parameters(), weight_decay=1e-2)
+ lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0)
+
+ optim_manager = bmt.optim.OptimManager(loss_scale=2**20 if model.dtype == torch.half else None)
+ optim_manager.add_optimizer(optimizer, lr_scheduler)
+
+    # generate one batch per rank; the torch baseline trains on all of them every iteration (no per-rank selection)
+ torch.manual_seed(2333)
+ batch_size = 2
+ seq_len = 512
+
+ sents = []
+ enc_lengths = []
+ enc_inputs = []
+ targetss = []
+ masks = []
+ inps = []
+ for i in range(bmt.world_size()):
+ sent = torch.randint(0, 10240, (batch_size, seq_len + 1))
+ enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda()
+ enc_input = sent[:, :-1].long().cuda()
+ targets = sent[:, 1:].long().cuda()
+ mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None]
+ targets = torch.where(
+ mask,
+ targets,
+ torch.full_like(targets, -100, dtype=torch.long)
+ )
+
+ sents.append(sent)
+ enc_lengths.append(enc_length)
+ enc_inputs.append(enc_input)
+ targetss.append(targets)
+ masks.append(mask)
+ inps.append((sent,enc_length,enc_input,targets,mask))
+
+ logs = []
+ for iter in range(100):
+
+ optim_manager.zero_grad()
+ global_loss = 0
+ for inp in inps:
+ sent, enc_length, enc_input, targets, mask = inp
+ pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
+ logits = model(enc_input, pos, pos < enc_length[:, None])
+
+ batch, seq_len, vocab_out_size = logits.size()
+
+ loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len)) / len(inps)
+
+ global_loss += loss.item()
+
+ loss = optim_manager.loss_scale * loss
+ loss.backward()
+ current_stream = torch.cuda.current_stream()
+ current_stream.wait_stream(bmt.config['load_stream'])
+ grad_norm = clip_grad_norm(optim_manager.loss_scale, optimizer.param_groups, max_norm=10.0, is_torch = True)
+
+ optim_manager.step()
+
+ bmt.print_rank("| Iter: {:6d} | loss: {:.4f} {:.4f} | lr: {:.4e} scale: {:10.4f} | grad_norm: {:.4f} |".format(
+ iter,
+ global_loss,
+ loss,
+ lr_scheduler.current_lr,
+ optim_manager.loss_scale,
+ grad_norm,
+ ))
+ logs.append(global_loss)
+
+ summary = inspect.inspect_model(model, "*")
+ return logs, summary
+
+def sub_train(model, loss_func_cls, optimizer_cls):
+ loss_func = loss_func_cls(ignore_index=-100)
+ optimizer = optimizer_cls(model.parameters(), weight_decay=1e-2)
+ lr_scheduler = bmt.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0)
+
+ optim_manager = bmt.optim.OptimManager(loss_scale=2**20 if model.dtype == torch.half else None)
+ optim_manager.add_optimizer(optimizer, lr_scheduler)
+
+    # generate one batch per rank with a fixed seed; this rank's batch is selected after the loop
+ torch.manual_seed(2333)
+ batch_size = 2
+ seq_len = 512
+ sents = []
+ enc_lengths = []
+ enc_inputs = []
+ targetss = []
+ masks = []
+ for i in range(bmt.world_size()):
+ sent = torch.randint(0, 10240, (batch_size, seq_len + 1))
+ enc_length = torch.randint(128, seq_len, (batch_size,)).long().cuda()
+ enc_input = sent[:, :-1].long().cuda()
+ targets = sent[:, 1:].long().cuda()
+ mask = torch.arange(seq_len).long().cuda()[None, :] < enc_length[:, None]
+ targets = torch.where(
+ mask,
+ targets,
+ torch.full_like(targets, -100, dtype=torch.long)
+ )
+
+ sents.append(sent)
+ enc_lengths.append(enc_length)
+ enc_inputs.append(enc_input)
+ targetss.append(targets)
+ masks.append(mask)
+ # sent = torch.cat(sents, dim=0)
+ # enc_length = torch.cat(enc_lengths, dim=0)
+ # enc_input = torch.cat(enc_inputs, dim=0)
+ # targets = torch.cat(targetss, dim=0)
+ # mask = torch.cat(masks, dim=0)
+ sent = sents[bmt.rank()]
+ enc_length = enc_lengths[bmt.rank()]
+ enc_input = enc_inputs[bmt.rank()]
+ targets = targetss[bmt.rank()]
+ mask = masks[bmt.rank()]
+
+
+ logs = []
+ pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
+ for iter in range(100):
+ optim_manager.zero_grad()
+ logits = model(enc_input, pos, pos < enc_length[:, None])
+
+ batch, seq_len, vocab_out_size = logits.size()
+
+ loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
+
+ global_loss = bmt.sum_loss(loss).item()
+
+ optim_manager.backward(loss)
+ grad_norm = clip_grad_norm(optim_manager.loss_scale, optimizer.param_groups, max_norm=10.0)
+
+ optim_manager.step()
+ bmt.print_rank("| Iter: {:6d} | loss: {:.4f} {:.4f} | lr: {:.4e} scale: {:10.4f} | grad_norm: {:.4f} |".format(
+ iter,
+ global_loss,
+ loss,
+ lr_scheduler.current_lr,
+ optim_manager.loss_scale,
+ grad_norm,
+ ))
+ logs.append(global_loss)
+
+ summary = inspect.inspect_model(model, "*")
+ return logs, summary
+
+def train(model, loss_func, optimizer):
+ key = f"{model[0]}*{loss_func[0]}*{optimizer[0]}"
+ model = model[1]()
+ if key.startswith("torch"):
+ ret = sub_train_torch(model, loss_func[1], optimizer[1])
+ else:
+ ret = sub_train(model, loss_func[1], optimizer[1])
+ del model
+ bmt.print_rank(f"finished {key}")
+ return key, ret
+
+def test_main(test_fp16=True, test_fp32=True):
+ ckpt_path = "test_ckpt.pt"
+
+ kwargs = {
+ "num_layers": 8,
+ "vocab_size": 10240,
+ "dim_model": 2560,
+ "dim_head": 80,
+ "num_heads": 32,
+ "dim_ff": 8192,
+ "max_distance": 1024,
+ "bias": True,
+ "dtype": None,
+ }
+
+ def make_ref_ckpt():
+ model = GPT(**kwargs)
+ if bmt.rank() == 0:
+ torch.save(model.state_dict(), ckpt_path)
+ bmt.synchronize()
+ del model
+
+ ret = {}
+ def torch_model():
+ model = GPT(**kwargs)
+ model.load_state_dict(torch.load(ckpt_path))
+ model = model.cuda()
+ return model
+
+ def wrap_model():
+ model = GPT(**kwargs)
+ wrap_model = bmt.BMTrainModelWrapper(model)
+ bmt.load(wrap_model, ckpt_path)
+ return model
+
+ def list_model():
+ model = GPT(**kwargs)
+ list_model = bmt.BMTrainModelWrapper(model)
+ list_model.transformers = bmt.TransformerBlockList([m for m in list_model.transformers])
+ bmt.load(list_model, ckpt_path)
+ return model
+
+ def pipe_model():
+ model = GPT(**kwargs)
+ pipe_model = bmt.BMTrainModelWrapper(model)
+ for m in pipe_model.transformers:
+ assert isinstance(m, bmt.Block)
+ pipe_model.transformers = bmt.PipelineTransformerBlockList([m for m in pipe_model.transformers])
+ bmt.load(pipe_model, ckpt_path)
+ return model
+
+ def unroll_list_model():
+ model = GPT(**kwargs)
+ list_model = bmt.BMTrainModelWrapper(model)
+ list_model.transformers = bmt.TransformerBlockList([m for m in list_model.transformers])
+ bmt.load(list_model, ckpt_path)
+ model.run_unroll = True
+ return model
+
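+    # model variants under test: plain torch, BMTrain-wrapped, wrapped + TransformerBlockList (ZeRO),
+    # and the same block list executed layer by layer (run_unroll)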
+ models = {
+ "torch": torch_model,
+ "wrapper": wrap_model,
+ "blocklist": list_model,
+ # "pipelist": pipe_model,
+ "unroll_blocklist": unroll_list_model,
+ }
+ loss_funcs = {
+ "bmt_entropy": bmt.loss.FusedCrossEntropy,
+ "torch_entropy": torch.nn.CrossEntropyLoss,
+ }
+ optimizers = {
+ "bmt_adam": bmt.optim.AdamOptimizer,
+ "bmt_adam_offload": bmt.optim.AdamOffloadOptimizer,
+ "torch_adam": torch.optim.Adam,
+ }
+
+ ret = {}
+ def add_to_check_list(m, l, o):
+ key, value = train((m, models[m]), (l, loss_funcs[l]), (o, optimizers[o]))
+ ret[key] = value
+
+ if test_fp16:
+ kwargs["dtype"] = torch.half
+ make_ref_ckpt()
+ add_to_check_list("torch", "bmt_entropy", "bmt_adam")
+ add_to_check_list("wrapper", "bmt_entropy", "bmt_adam")
+ add_to_check_list("blocklist", "bmt_entropy", "bmt_adam")
+ # add_to_check_list("pipelist", "bmt_entropy", "bmt_adam")
+ add_to_check_list("blocklist", "torch_entropy", "bmt_adam")
+ add_to_check_list("blocklist", "bmt_entropy", "bmt_adam_offload")
+ add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam")
+ if bmt.rank() == 0:
+ os.remove(ckpt_path)
+ check(ret)
+
+ if test_fp32:
+ kwargs["dtype"] = torch.float
+ make_ref_ckpt()
+ add_to_check_list("torch", "torch_entropy", "bmt_adam")
+ add_to_check_list("wrapper", "torch_entropy", "bmt_adam")
+ add_to_check_list("blocklist", "torch_entropy", "bmt_adam")
+ # add_to_check_list("pipelist", "torch_entropy", "bmt_adam")
+ add_to_check_list("blocklist", "torch_entropy", "bmt_adam_offload")
+ add_to_check_list("blocklist", "torch_entropy", "torch_adam")
+ add_to_check_list("unroll_blocklist", "bmt_entropy", "bmt_adam")
+ if bmt.rank() == 0:
+ os.remove(ckpt_path)
+ check(ret)
+
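+# every configuration's final parameter statistics (std / mean / max / min from inspect)
+# must agree pairwise within 1e-2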
+def check(ret):
+ if bmt.rank() == 0:
+ for k1, v1 in ret.items():
+ for k2, v2 in ret.items():
+ if k1 != k2:
+ print(f"checking {k1} vs. {k2}")
+ check_param(v1[1], v2[1])
+ bmt.synchronize()
+ ret.clear()
+
+def check_param(info1, info2):
+ for l1, l2 in zip(info1, info2):
+ for key in ['std', 'mean', 'max', 'min']:
+ v1 = l1[key]
+ v2 = l2[key]
+ assert_lt(abs(v1-v2), 1e-2)
+
+if __name__ == '__main__':
+ bmt.init_distributed(pipe_size=1)
+
+ test_main(test_fp16=True, test_fp32=True)
diff --git a/examples/BMTrain/tests/utils.py b/examples/BMTrain/tests/utils.py
new file mode 100644
index 00000000..62831fea
--- /dev/null
+++ b/examples/BMTrain/tests/utils.py
@@ -0,0 +1,14 @@
+def assert_eq(a, b):
+ assert a == b, f"{a} != {b}"
+
+def assert_neq(a, b):
+ assert a != b, f"{a} == {b}"
+
+def assert_lt(a, b):
+ assert a < b, f"{a} >= {b}"
+
+def assert_gt(a, b):
+ assert a > b, f"{a} <= {b}"
+
+def assert_all_eq(a, b):
+ assert_eq((a==b).all(), True)
\ No newline at end of file
diff --git a/examples/CPM.cu/.arsync b/examples/CPM.cu/.arsync
new file mode 100644
index 00000000..f9635cc6
--- /dev/null
+++ b/examples/CPM.cu/.arsync
@@ -0,0 +1,10 @@
+auto_sync_up 0
+local_options -var
+local_path /Users/tachicoma/projects/CPM.cu
+remote_host sa.km
+remote_options -var
+remote_or_local remote
+remote_path /home/sunao/cpm.cu
+remote_port 0
+rsync_flags ["--max-size=100m"]
+backend rsync
diff --git a/examples/CPM.cu/.gitignore b/examples/CPM.cu/.gitignore
new file mode 100644
index 00000000..acdb2c98
--- /dev/null
+++ b/examples/CPM.cu/.gitignore
@@ -0,0 +1,222 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+
+build/
+# Cuda
+*.i
+*.ii
+*.gpu
+*.ptx
+*.cubin
+*.fatbin
+
+# Prerequisites
+*.d
+
+# Compiled Object files
+*.slo
+*.lo
+*.o
+*.obj
+
+# Precompiled Headers
+*.gch
+*.pch
+
+# Compiled Dynamic libraries
+*.so
+*.dylib
+*.dll
+
+# Fortran module files
+*.mod
+*.smod
+
+# Compiled Static libraries
+*.lai
+*.la
+*.a
+*.lib
+
+# Executables
+*.exe
+*.out
+*.app
+
+.DS_Store
+.vscode
+
+# Checkpoints
+checkpoints/*
+!checkpoints/.gitkeep
+
+# Cursor Rules
+.mdc
+
+# FR-Spec Index
+fr_index
+
+# Prompt Files
+tests/*.txt
+prompt.txt
\ No newline at end of file
diff --git a/examples/CPM.cu/.gitmodules b/examples/CPM.cu/.gitmodules
new file mode 100644
index 00000000..a00c8278
--- /dev/null
+++ b/examples/CPM.cu/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "src/cutlass"]
+ path = src/cutlass
+ url = https://github.com/NVIDIA/cutlass.git
diff --git a/examples/CPM.cu/LICENSE b/examples/CPM.cu/LICENSE
new file mode 100644
index 00000000..99908f2a
--- /dev/null
+++ b/examples/CPM.cu/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2024 OpenBMB
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/examples/CPM.cu/README.md b/examples/CPM.cu/README.md
new file mode 100644
index 00000000..c4fb84f5
--- /dev/null
+++ b/examples/CPM.cu/README.md
@@ -0,0 +1,167 @@
+# CPM.cu
+
+[中文版本](./README_ZH.md) | English
+
+CPM.cu is a lightweight, high-performance CUDA implementation for LLMs, optimized for end-device inference and featuring cutting-edge techniques in **sparse architecture**, **speculative sampling** and **quantization**.
+
+
+
+## 🔥 Project Updates
+
+- [2025.06.06] Optimized for [MiniCPM4](https://github.com/openbmb/minicpm).
+ - Support InfLLM-v2 attention kernel
+ - Support sliding-window for the MTP layer, optimized for long context
+ - Support quantization for the MTP layer
+- [2025.05.29] Added quantization support at [SpecMQuant](https://github.com/AI9Stars/SpecMQuant).
+ - Support Marlin GPTQ kernel for the LLM
+ - Support Speculative Sampling for quantized LLM
+- [2025.03.01] Released the first version at [FR-Spec](https://github.com/thunlp/FR-Spec).
+ - SOTA Speculative Sampling Implementation
+ - Support FR-Spec: Frequency-Ranked Speculative Sampling
+ - Support Tree-based verification of Speculative Sampling in Flash-Attention
+ - Support Static memory management and memory reuse
+ - Support Fused kernels
+ - Support Chunked prefill
+ - Support CUDA Graph
+
+
+
+## Demo
+
+https://github.com/user-attachments/assets/ab36fd7a-485b-4707-b72f-b80b5c43d024
+
+
+
+## Getting Started
+
+- [Installation](#installation)
+- [Model Weights](#prepare-model)
+- [Quick Start](#quick-start)
+
+
+
+## Installation
+
+### Install from source
+
+```bash
+git clone https://github.com/OpenBMB/CPM.cu.git --recursive
+cd CPM.cu
+python3 setup.py install
+```
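+
+After the build finishes, a quick way to confirm that the package and its compiled extension are importable is shown below (a minimal sanity check; it assumes the install was done on a CUDA-capable machine):
+
+```python
+# Minimal import check for the cpmcu package and its compiled pybind extension.
+import cpmcu
+from cpmcu import C  # the CUDA/pybind extension used by cpmcu.llm
+print("cpmcu imported:", cpmcu.__name__, "| extension:", C.__name__)
+```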
+
+
+
+## Prepare Model
+
+Please follow [MiniCPM4's README](https://github.com/openbmb/minicpm) to download the model weights.
+
+
+
+## Quick Start
+
+We provide a simple example to show how to use CPM.cu to generate text.
+
+```bash
+python3 tests/test_generate.py --prompt-file <prompt_file_path>
+```
+
+If you don't specify a model path, the script loads the model from OpenBMB's Hugging Face repository.
+If you want to use local paths, we recommend keeping all model filenames unchanged and placing them in the same directory; you can then run the model by passing that directory with the `-p` parameter. Otherwise, adjust the paths in the code accordingly.
+
+If you don't specify a prompt file, a default Haystack task with a 15K-token context is used.
+Use `--help` to learn more about the script's options.
+
+We also provide a script, `tests/long_prompt_gen.py`, that builds a long code-summarization prompt.
+This script automatically collects code from this repository and prompts the model to "Summarize the code."
+
+```bash
+python3 tests/long_prompt_gen.py # generate prompt.txt (for more details, use --help)
+python3 tests/test_generate.py --prompt-file prompt.txt
+```
+
+The output should be of the following format:
+
+```bash
+Generated text (streaming output):
+--------------------------------------------------
+Prefilling: 100.0% (106850/106850 tokens) @ 6565.3 tokens/s - Complete!
+
+
+==================================================
+Stream Generation Summary:
+==================================================
+Prefill length: 106850
+Prefill time: 16.36 s
+Prefill tokens/s: 6530.77
+Mean accept length: 2.50
+Decode length: 118
+Decode time: 0.76 s
+Decode tokens/s: 154.59
+```
+
+Where:
+
+- the `Prefill` and `Decode` sections report the processed length, the elapsed time, and the resulting throughput in tokens/s.
+- the `Mean accept length` is the average number of tokens accepted per draft-verification step when Speculative Sampling is used.
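+
+If you prefer to drive generation from Python instead of the test script, the sketch below shows one way to use the `LLM` class from `cpmcu/llm.py` with streaming output. It is a minimal illustration rather than the official API surface: the model directory is a placeholder, terminator token ids are omitted, and the generation length is kept small.
+
+```python
+# Minimal streaming-generation sketch (assumes a local MiniCPM4-style checkpoint directory).
+import torch
+from cpmcu.llm import LLM
+
+llm = LLM("/path/to/minicpm4", memory_limit=0.8, chunk_length=1024)  # placeholder path
+llm.init_storage()   # allocates static buffers and reports the max supported length
+llm.load_from_hf()   # loads weights from the HF-format checkpoint files
+
+prompt = "Summarize the code."
+input_ids = torch.tensor(llm.tokenizer.encode(prompt), dtype=torch.int32, device="cuda")
+
+decoded_tokens, decode_time = 0, 0.0
+for chunk in llm.generate(input_ids, generation_length=128, use_stream=True):
+    print(chunk["text"], end="", flush=True)
+    decoded_tokens += 1                                   # approximate: includes the first token from prefill
+    decode_time = max(decode_time, chunk["decode_time"])  # decode_time is cumulative across chunks
+    if chunk["is_finished"]:
+        break
+
+if decode_time > 0:
+    print(f"\nDecode tokens/s: {decoded_tokens / decode_time:.2f}")
+```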
+
+## Code Structure
+
+```bash
+CPM.cu/
+├── src/
+│ ├── flash_attn/ # attention kernels: sparse, tree-verification, etc.
+│ ├── model/
+│ │ ├── minicpm4/ # minicpm4 model
+│ │ ├── w4a16_gptq_marlin/ # marlin kernel
+│ │ └── ... # common layers
+│ ├── entry.cu # pybind: bind cuda and python
+│ └── ...
+├── cpmcu/ # python interface
+└── ...
+```
+
+## More
+We provide a word-frequency generation script for FR-Spec at `scripts/fr_spec/gen_fr_index.py`. You can run it as follows:
+```bash
+python scripts/fr_spec/gen_fr_index.py --model_path <model_path>
+```
+You can modify the script to use your own dataset. If your task is in a specific vertical domain, constructing word frequencies tailored to that domain can significantly improve processing speed.
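+
+For illustration only, the sketch below outlines what building such a frequency index involves: tokenize a corpus with the model's tokenizer, count token frequencies, and keep the ids of the most frequent tokens. It is not the repository's `gen_fr_index.py`; the corpus file, output path, and vocabulary cutoff are placeholders.
+
+```python
+# Rough sketch of a frequency-ranked token-id list built from a domain corpus.
+# Not the repo's gen_fr_index.py; paths and the top-k cutoff below are placeholders.
+from collections import Counter
+
+import torch
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("/path/to/minicpm4", trust_remote_code=True)
+counter = Counter()
+with open("domain_corpus.txt", "r", encoding="utf-8") as f:
+    for line in f:
+        counter.update(tokenizer.encode(line, add_special_tokens=False))
+
+top_k = 32768  # reduced draft vocabulary size (placeholder)
+freq_ranked_ids = torch.tensor([tok for tok, _ in counter.most_common(top_k)], dtype=torch.int32)
+torch.save(freq_ranked_ids, "fr_index.pt")  # hypothetical output format
+```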
+
+
+## Acknowledgments
+
+Our `src/flash_attn` folder is modified from [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/v2.6.3/csrc/flash_attn).
+
+We have drawn inspiration from the following repositories:
+
+- [EAGLE](https://github.com/SafeAILab/EAGLE)
+- [Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention)
+- [vLLM](https://github.com/vllm-project/vllm)
+- [SGLang](https://github.com/sgl-project/sglang)
+
+## Citation
+
+Please cite our paper if you find our work valuable.
+
+```
+@article{zhao2025fr,
+ title={FR-Spec: Accelerating Large-Vocabulary Language Models via Frequency-Ranked Speculative Sampling},
+ author={Zhao, Weilin and Pan, Tengyu and Han, Xu and Zhang, Yudi and Sun, Ao and Huang, Yuxiang and Zhang, Kaihuo and Zhao, Weilun and Li, Yuxuan and Wang, Jianyong and others},
+ journal={arXiv preprint arXiv:2502.14856},
+ year={2025}
+}
+
+@article{zhang2025specmqaunt,
+ title={Speculative Decoding Meets Quantization: Compatibility Evaluation and Hierarchical Framework Design},
+ author={Zhang, Yudi and Zhao, Weilin and Han, Xu and Zhao, Tiejun and Xu, Wang and Cao, Hailong and Zhu, Conghui},
+ journal={arXiv preprint arXiv:2505.22179},
+ year={2025}
+}
+
+@article{minicpm4,
+ title={MiniCPM4: Ultra-Efficient LLMs on End Devices},
+ author={MiniCPM},
+ year={2025}
+}
+```
\ No newline at end of file
diff --git a/examples/CPM.cu/README_ZH.md b/examples/CPM.cu/README_ZH.md
new file mode 100644
index 00000000..7f364ad3
--- /dev/null
+++ b/examples/CPM.cu/README_ZH.md
@@ -0,0 +1,165 @@
+# CPM.cu
+
+中文 | [English Version](./README.md)
+
+CPM.cu 是一个针对端侧大模型推理设计的轻量、高效的 CUDA 推理框架,核心支持 **稀疏架构**、**投机采样** 和 **低位宽量化** 等前沿技术创新。
+
+
+
+## 🔥 项目进展
+
+- [2025.06.06] 为 [MiniCPM4](https://github.com/openbmb/minicpm) 优化。
+ - 支持 InfLLM-v2 注意力内核
+ - 支持 MTP 层的滑动窗口,优化长上下文处理
+ - 支持 MTP 层的量化
+- [2025.05.29] 支持 [SpecMQuant](https://github.com/AI9Stars/SpecMQuant) 的量化。
+ - 支持 LLM 的 Marlin GPTQ 内核
+ - 支持量化 LLM 的投机采样
+- [2025.03.01] 在 [FR-Spec](https://github.com/thunlp/FR-Spec) 发布首个版本。
+ - 速度最快的投机采样实现
+ - 支持 FR-Spec, 基于词频优化的投机采样
+ - 支持 Flash-Attention 中的树状投机采样验证
+ - 支持静态内存管理和内存复用
+ - 支持计算融合内核
+ - 支持分块预填充
+ - 支持 CUDA Graph
+
+
+
+## 效果演示
+
+https://github.com/user-attachments/assets/ab36fd7a-485b-4707-b72f-b80b5c43d024
+
+
+
+## 快速开始
+
+- [框架安装](#install)
+- [模型权重](#modelweights)
+- [运行示例](#example)
+
+
+
+## 框架安装
+
+### 从源码安装
+
+```bash
+git clone https://github.com/OpenBMB/cpm.cu.git --recursive
+cd cpm.cu
+python3 setup.py install
+```
+
+
+
+## 准备模型
+
+请按照 [MiniCPM4 的 README](https://github.com/openbmb/minicpm) 的说明下载模型权重。
+
+
+
+## 运行示例
+
+我们提供了一个简单的示例来展示如何使用 CPM.cu。
+
+```bash
+python3 tests/test_generate.py --prompt-file <输入文件路径>
+```
+
+如果您不指定模型路径,脚本将从 OpenBMB 的 Hugging Face 仓库加载模型。
+如果你想使用本地路径,我们推荐不修改所有模型文件名并放在同一目录下,这样可以通过-p指定该目录运行模型。否则建议修改代码中的路径。
+
+如果您不指定输入文件,将提供一个默认的 Haystack 任务,上下文长度为 15K。
+您可以使用 --help 了解更多关于脚本的功能。
+
+我们还有一个脚本,`tests/long_prompt_gen.py`,用于生成长代码总结。
+这个脚本会自动从本仓库中收集代码,并提示模型“总结代码”。
+
+```bash
+python3 tests/long_prompt_gen.py # 生成 prompt.txt (更多细节请见 --help)
+python3 tests/test_generate.py --prompt-file prompt.txt
+```
+
+输出应为如下格式:
+
+```bash
+Generated text (streaming output):
+--------------------------------------------------
+Prefilling: 100.0% (106850/106850 tokens) @ 6565.3 tokens/s - Complete!
+
+
+==================================================
+Stream Generation Summary:
+==================================================
+Prefill length: 106850
+Prefill time: 16.36 s
+Prefill tokens/s: 6530.77
+Mean accept length: 2.50
+Decode length: 118
+Decode time: 0.76 s
+Decode tokens/s: 154.59
+```
+
+其中:
+
+- `Prefill` (输入) 和 `Decode` (输出) 速度通过(长度、时间和 token/s)输出。
+- `Mean accept length` (平均接受长度) 是使用投机采样时接受 token 的平均长度。
+
+## 代码结构
+
+```bash
+cpm.cu/
+├── src/
+│ ├── flash_attn/ # attention: 稀疏, 投机树状验证等
+│ ├── model/
+│ │ ├── minicpm4/ # minicpm4 模型
+│ │ ├── w4a16_gptq_marlin/ # Marlin GPTQ 计算内核
+│ │ └── ... # 通用层
+│ ├── entry.cu # pybind: 绑定 CUDA 和 Python
+│ └── ...
+├── cpmcu/ # Python 接口
+└── ...
+```
+## 更多
+我们提供了FR-Spec的词频生成脚本,位于"scripts/fr_spec/gen_fr_index.py",运行方式如下:
+```
+python scripts/fr_spec/gen_fr_index.py --model_path <模型路径>
+```
+你可以修改代码使用自己的数据集。如果你的任务是特定垂直领域,根据领域构造词频对速度提升有显著收益。
+
+## 致谢
+
+我们的 `src/flash_attn` 文件夹基于 [FlashAttention](https://github.com/Dao-AILab/flash-attention/tree/v2.6.3/csrc/flash_attn) 并进行了修改。
+
+我们从以下仓库中获取了实现灵感:
+
+- [EAGLE](https://github.com/SafeAILab/EAGLE)
+- [Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention)
+- [vLLM](https://github.com/vllm-project/vllm)
+- [SGLang](https://github.com/sgl-project/sglang)
+
+## 引用
+
+如果您觉得我们的工作有价值,请引用我们的论文。
+
+```
+@article{zhao2025fr,
+ title={FR-Spec: Accelerating Large-Vocabulary Language Models via Frequency-Ranked Speculative Sampling},
+ author={Zhao, Weilin and Pan, Tengyu and Han, Xu and Zhang, Yudi and Sun, Ao and Huang, Yuxiang and Zhang, Kaihuo and Zhao, Weilun and Li, Yuxuan and Wang, Jianyong and others},
+ journal={arXiv preprint arXiv:2502.14856},
+ year={2025}
+}
+
+@article{zhang2025specmqaunt,
+ title={Speculative Decoding Meets Quantization: Compatibility Evaluation and Hierarchical Framework Design},
+ author={Zhang, Yudi and Zhao, Weilin and Han, Xu and Zhao, Tiejun and Xu, Wang and Cao, Hailong and Zhu, Conghui},
+ journal={arXiv preprint arXiv:2505.22179},
+ year={2025}
+}
+
+@article{minicpm4,
+ title={MiniCPM4: Ultra-Efficient LLMs on End Devices},
+ author={MiniCPM},
+ year={2025}
+}
+```
\ No newline at end of file
diff --git a/examples/CPM.cu/cpmcu/__init__.py b/examples/CPM.cu/cpmcu/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/CPM.cu/cpmcu/llm.py b/examples/CPM.cu/cpmcu/llm.py
new file mode 100644
index 00000000..fb594c6b
--- /dev/null
+++ b/examples/CPM.cu/cpmcu/llm.py
@@ -0,0 +1,422 @@
+from . import C
+
+import os, json, glob
+import torch
+from transformers import AutoTokenizer, AutoConfig
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from safetensors.torch import load_file
+import time, math
+import torch.nn.functional as F
+
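+# Map torch dtypes to the integer codes expected by the compiled C/CUDA extension.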
+dtype_map = {
+ torch.float16: 0,
+ torch.bfloat16: 1,
+}
+
+def dtype_to_int(dtype):
+ ret = dtype_map.get(dtype, -1)
+ if ret == -1:
+ raise ValueError(f"Unsupported dtype: {dtype}")
+ return ret
+
+class LLM(torch.nn.Module):
+ def __init__(self,
+ path: str, # hf model path
+ memory_limit: float = 0.8,
+ chunk_length: int = 1024,
+ dtype: torch.dtype = None,
+ cuda_graph: bool = False,
+ apply_sparse: bool = False,
+ sink_window_size: int = 1,
+ block_window_size: int = 32,
+ sparse_topk_k: int = 32,
+ sparse_switch: int = 8192,
+ apply_compress_lse: bool = False,
+ use_enter: bool = False,
+ use_decode_enter: bool = False,
+ temperature: float = 0.0,
+ random_seed: int = None,
+ ):
+ super().__init__()
+
+ self.path = path
+ self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+ self.config = AutoConfig.from_pretrained(path, trust_remote_code=True)
+ self.dtype = dtype if dtype is not None else self.config.torch_dtype
+ self.dtype_int = dtype_to_int(self.dtype)
+ self.cuda_graph = cuda_graph
+ self.use_enter = use_enter
+ self.use_decode_enter = use_decode_enter
+ self.temperature = temperature
+
+ self.chunk_length = chunk_length
+ # Flag for showing prefill progress (used in stream mode)
+ self._show_prefill_progress = False
+
+ # Initialize random generator if random_seed is provided
+ if random_seed is not None:
+ self.generator = torch.Generator(device="cuda")
+ self.generator.manual_seed(random_seed)
+ else:
+ self.generator = None
+
+ if not hasattr(self.config, "head_dim"):
+ self.config.head_dim = self.config.hidden_size // self.config.num_attention_heads
+ scale_embed = self.config.scale_emb if hasattr(self.config, "scale_emb") else 1.0
+ scale_lmhead = (self.config.dim_model_base / self.config.hidden_size) if hasattr(self.config, "dim_model_base") else 1.0
+ scale_residual = self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) if hasattr(self.config, "scale_depth") else 1.0
+
+ if apply_sparse:
+ C.init_minicpm4_model(
+ memory_limit,
+ self.config.vocab_size,
+ self.config.num_hidden_layers,
+ self.config.hidden_size,
+ self.config.intermediate_size,
+ self.config.num_attention_heads,
+ self.config.num_key_value_heads,
+ self.config.head_dim,
+ self.config.rms_norm_eps,
+ self.dtype_int,
+ self.chunk_length,
+ scale_embed,
+ scale_lmhead,
+ scale_residual,
+ sink_window_size,
+ block_window_size,
+ sparse_topk_k,
+ sparse_switch,
+ apply_compress_lse,
+ )
+ else:
+ C.init_base_model(
+ memory_limit,
+ self.config.vocab_size,
+ self.config.num_hidden_layers,
+ self.config.hidden_size,
+ self.config.intermediate_size,
+ self.config.num_attention_heads,
+ self.config.num_key_value_heads,
+ self.config.head_dim,
+ self.config.rms_norm_eps,
+ self.dtype_int,
+ self.chunk_length,
+ scale_embed,
+ scale_lmhead,
+ scale_residual,
+ )
+
+ self.logits = torch.empty((64, self.config.vocab_size), dtype=self.dtype, device="cuda")
+
+ def init_storage(self):
+ self.max_total_length = C.init_storage()
+ print("max supported length under current memory limit: ", self.max_total_length)
+
+ def _load(self, name, param, dtype=None, cls=None):
+ if dtype is None:
+ if 'rotary_emb' in name:
+ dtype = torch.float32
+ else:
+ dtype = self.dtype
+
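+        # Fused HF projection weights ('gate_up_proj', 'qkv_proj') are split into the separate gate/up and q/k/v tensors that the C side loads by name.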
+ if 'gate_up_proj' in name:
+ self._load(name.replace("gate_up_proj", "gate_proj"), param[:param.shape[0]//2], dtype)
+ self._load(name.replace("gate_up_proj", "up_proj"), param[param.shape[0]//2:])
+ elif 'qkv_proj' in name:
+ self._load(name.replace("qkv_proj", "q_proj"), param[:self.config.num_attention_heads * self.config.head_dim])
+ self._load(name.replace("qkv_proj", "k_proj"), param[self.config.num_attention_heads * self.config.head_dim:(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim])
+ self._load(name.replace("qkv_proj", "v_proj"), param[(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim:])
+ else:
+ param = param.contiguous().to(dtype)
+ C.load_model(name, param.data_ptr())
+
+ if "embed_tokens" in name and hasattr(self.config, "tie_word_embeddings") and self.config.tie_word_embeddings:
+ self._load("lm_head.weight", param)
+
+ def _load_from_ckpt(self, path, cls=None):
+ supported_suffix_1 = ["bin.index.json", "safetensors.index.json"]
+ supported_suffix_2 = ["bin", "safetensors", "pt"]
+ file = None
+ for suffix in supported_suffix_1:
+ files = glob.glob(os.path.join(path, f"*.{suffix}"))
+ if len(files) > 1:
+ raise ValueError(f"Multiple files with suffix {suffix} found in {path}")
+ elif len(files) == 1:
+ file = files[0]
+ break
+ else:
+ for suffix in supported_suffix_2:
+ files = glob.glob(os.path.join(path, f"*.{suffix}"))
+ if len(files) > 1:
+ raise ValueError(f"Multiple files with suffix {suffix} found in {path}")
+ elif len(files) == 1:
+ file = files[0]
+ break
+ else:
+ raise ValueError(f"No supported checkpoint file found in {path}, supported suffixes: {supported_suffix_1 + supported_suffix_2}")
+
+ if file.endswith(".index.json"):
+ with open(file, "r") as f:
+ file_list = set(json.load(f)["weight_map"].values())
+ file_list = [os.path.join(path, file) for file in file_list]
+ else:
+ file_list = [file]
+
+ for file in file_list:
+ print(f"load from {file}")
+ if file.endswith(".bin") or file.endswith(".pt"):
+ ckpt = torch.load(file, map_location="cpu")
+ elif file.endswith(".safetensors"):
+ ckpt = load_file(file)
+ for name, param in ckpt.items():
+ self._load(name, param, cls=cls)
+
+ def load_from_hf(self):
+ with torch.no_grad():
+ self._load_from_ckpt(self.path)
+
+ # rope
+ if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
+ rope_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type"))
+                if rope_type == "longrope" and "factor" not in self.config.rope_scaling:
+ self.config.rope_scaling["factor"] = 1.0
+ else:
+ rope_type = "default"
+ # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor
+ inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](self.config, "cpu", seq_len=self.max_total_length)
+ # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu")
+ self._load("model.rotary_emb.inv_freq", inv_freq, dtype=torch.float32)
+ # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32)
+
+ def prefill(self, input_ids, position_ids):
+ assert input_ids.dtype == torch.int32
+ # Check if input length exceeds maximum supported length
+ if input_ids.numel() > self.max_total_length:
+ raise ValueError(f"Input token count ({input_ids.numel()}) exceeds maximum supported length ({self.max_total_length}) under current memory limit")
+
+ total_length = input_ids.numel()
+ num_chunks = (total_length + self.chunk_length - 1) // self.chunk_length
+
+ prefill_start_time = None
+ actual_prefill_start = None
+
+ # User interaction logic only when use_enter is True
+ if self._show_prefill_progress and self.use_enter:
+ # Clear screen and move cursor to top, then show prompt and wait for user input
+ print("\033[2J\033[H", end="", flush=True) # Clear screen and move to top
+ print("Please Press Enter to Start Prefilling...", end="", flush=True)
+ input() # Wait for Enter key
+
+ # Replace the prompt with [Prefilling] - clear entire line first
+ print("\r" + " " * 50 + "\r[Prefilling]", flush=True)
+ # Start timing after user presses Enter
+ prefill_start_time = time.time()
+ actual_prefill_start = prefill_start_time
+
+ # Initialize progress display for stream mode (always when _show_prefill_progress is True)
+ if self._show_prefill_progress:
+ if prefill_start_time is None: # Only set start time if not already set above
+ prefill_start_time = time.time()
+ if not self.use_enter:
+ print("Prefilling: 0.0% (0/{} tokens) @ 0.0 tokens/s".format(total_length), end="", flush=True)
+ else:
+ print("Prefilling: 0.0% (0/{} tokens) @ 0.0 tokens/s".format(total_length), end="", flush=True)
+
+ # Record actual computation start time if not set yet
+ if actual_prefill_start is None:
+ actual_prefill_start = time.time()
+
+ for chunk_idx, i in enumerate(range(0, input_ids.numel(), self.chunk_length)):
+ # torch.cuda.nvtx.range_push(f"chunk from {i}")
+ C.prefill(
+ min(input_ids.numel() - i, self.chunk_length), i,
+ input_ids.view(-1)[i:].data_ptr(), position_ids.view(-1)[i:].data_ptr(),
+ self.logits.data_ptr()
+ )
+ # torch.cuda.nvtx.range_pop()
+
+ # Show progress for stream mode - always when _show_prefill_progress is True
+ if self._show_prefill_progress and prefill_start_time is not None:
+ current_tokens = min(i + self.chunk_length, total_length)
+ elapsed_time = time.time() - prefill_start_time
+ progress = (current_tokens * 100.0) / total_length
+ tokens_per_sec = current_tokens / elapsed_time if elapsed_time > 0 else 0.0
+ print(f"\rPrefilling: {progress:.1f}% ({current_tokens}/{total_length} tokens) @ {tokens_per_sec:.1f} tokens/s", end="", flush=True)
+
+ # Calculate actual prefill time
+ actual_prefill_time = time.time() - actual_prefill_start
+
+ # Final completion status for stream mode
+ if self._show_prefill_progress:
+ if prefill_start_time is not None:
+ final_elapsed_time = time.time() - prefill_start_time
+ final_tokens_per_sec = total_length / final_elapsed_time if final_elapsed_time > 0 else 0.0
+ print(f"\rPrefilling: 100.0% ({total_length}/{total_length} tokens) @ {final_tokens_per_sec:.1f} tokens/s - Complete!")
+ if self.use_enter:
+ print("\n[Decoding]") # Show decoding status and move to next line only with use_enter
+ else:
+ print() # Just a newline for normal mode
+
+ # Store the actual prefill time for use in generate method
+ self._last_prefill_time = actual_prefill_time
+
+ return self.logits[:1].clone()
+
+ def decode(self, input_ids, position_ids, cache_length, mask_2d = None):
+ assert input_ids.dtype == torch.int32
+ assert position_ids.dtype == torch.int32
+ assert cache_length.dtype == torch.int32
+ if mask_2d is not None:
+ # assert mask_2d.dtype == torch.uint64
+ assert input_ids.numel() == mask_2d.shape[0]
+
+ # torch.cuda.nvtx.range_push(f"decode")
+        cache_length += input_ids.numel()  # temporarily include the new tokens, for convenience in flash_attn
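+        # Round the padded KV length up to a multiple of 128 (assumed alignment required by the attention kernels).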
+ padded_length = (cache_length[0].item() + 128 - 1) // 128 * 128
+ C.decode(
+ input_ids.numel(), padded_length,
+ input_ids.data_ptr(), position_ids.data_ptr(), cache_length.data_ptr(),
+ mask_2d.data_ptr() if mask_2d is not None else 0,
+ self.logits.data_ptr(),
+ self.cuda_graph
+ )
+ cache_length -= input_ids.numel()
+ # torch.cuda.nvtx.range_pop()
+ return self.logits[:input_ids.numel()].clone()
+
+ def generate(self, input_ids, generation_length=100, teminators=[], use_stream=False):
+ """
+ Generate text with optional streaming output.
+ Returns (tokens, decode_time, prefill_time) if use_stream=False, or generator yielding {'token', 'text', 'is_finished', 'prefill_time', 'decode_time'} if use_stream=True.
+ """
+ assert input_ids.dtype == torch.int32
+
+ prefix_length = input_ids.numel()
+ position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda")
+
+ # Set progress flag before prefill for stream mode
+ if use_stream:
+ self._show_prefill_progress = True
+
+ # Measure prefill time
+ if self.use_enter and use_stream:
+ # In use_enter mode, timing will be handled inside prefill method
+ logits = self.prefill(input_ids, position_ids)
+ prefill_time = getattr(self, '_last_prefill_time', 0.0) # Get actual prefill time
+ else:
+ torch.cuda.synchronize()
+ prefill_start = time.time()
+ logits = self.prefill(input_ids, position_ids)
+ torch.cuda.synchronize()
+ prefill_time = time.time() - prefill_start
+
+ if self.temperature > 0.0:
+ token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0].item()
+ else:
+ token = logits[0].argmax(dim=-1).item()
+
+ # Wait for user input before decode phase if use_decode_enter is enabled
+ if self.use_decode_enter:
+ if use_stream and self.use_enter:
+ # In stream mode with use_enter, we already showed [Decoding], just wait for input
+ print("Please Press Enter to Start Decoding...", end="", flush=True)
+ input() # Wait for Enter key
+ print("\r" + " " * 50 + "\r", end="", flush=True) # Clear the prompt without showing [Decoding] again
+ else:
+ # In other modes, show prompt and wait
+ print("Please Press Enter to Start Decoding...", end="", flush=True)
+ input() # Wait for Enter key
+ print("\r" + " " * 50 + "\r[Decoding]", flush=True) # Show [Decoding] only when use_enter is not enabled
+
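+        # Lazily create persistent one-element CUDA buffers that are reused (mutated in place) across decode steps.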
+ if not hasattr(self, "input_ids"):
+ self.input_ids = torch.tensor([0], dtype=torch.int32, device="cuda")
+ self.position_ids = torch.tensor([0], dtype=torch.int32, device="cuda")
+ self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda")
+
+ if use_stream:
+ # Stream generation (optimized)
+ def _stream_generator():
+ nonlocal token
+ # Keep minimal context for correct spacing
+ prev_token = token
+
+ # yield first token
+ text = self.tokenizer.decode([token], skip_special_tokens=False)
+
+ yield {
+ 'token': token,
+ 'text': text,
+ 'is_finished': token in teminators,
+ 'prefill_time': prefill_time,
+ 'decode_time': 0.0 # First token comes from prefill
+ }
+
+ if token in teminators:
+ return
+
+ decode_start_time = time.time()
+
+ for i in range(generation_length-1):
+ self.input_ids[0] = token
+ self.position_ids[0] = prefix_length + i
+ self.cache_length[0] = prefix_length + i
+
+ logits = self.decode(self.input_ids, self.position_ids, self.cache_length)
+ if self.temperature > 0.0:
+ token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0].item()
+ else:
+ token = logits[0].argmax(dim=-1).item()
+
+ # For correct spacing, decode with previous token context
+ if prev_token is not None:
+ context_tokens = [prev_token, token]
+ context_text = self.tokenizer.decode(context_tokens, skip_special_tokens=False)
+ prev_text = self.tokenizer.decode([prev_token], skip_special_tokens=False)
+ text = context_text[len(prev_text):]
+ else:
+ text = self.tokenizer.decode([token], skip_special_tokens=False)
+
+ is_finished = token in teminators or i == generation_length - 2
+
+ # Calculate time only when needed to reduce overhead
+ decode_time = time.time() - decode_start_time
+
+ yield {
+ 'token': token,
+ 'text': text,
+ 'is_finished': is_finished,
+ 'prefill_time': 0.0, # Only report prefill_time for first token
+ 'decode_time': decode_time
+ }
+
+ if token in teminators:
+ break
+
+ # Update prev_token
+ prev_token = token
+
+ return _stream_generator()
+ else:
+ # Original batch generation
+ tokens = [token]
+ torch.cuda.synchronize()
+ decode_start = time.time()
+ for i in range(generation_length-1):
+ self.input_ids[0] = token
+ self.position_ids[0] = prefix_length + i
+ self.cache_length[0] = prefix_length + i
+
+ logits = self.decode(self.input_ids, self.position_ids, self.cache_length)
+ if self.temperature > 0.0:
+ token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0].item()
+ else:
+ token = logits[0].argmax(dim=-1).item()
+ tokens.append(token)
+ if token in teminators:
+ break
+ torch.cuda.synchronize()
+ decode_time = time.time() - decode_start
+ return tokens, decode_time, prefill_time
+
+ def print_perf_summary(self):
+ C.print_perf_summary()
\ No newline at end of file
diff --git a/examples/CPM.cu/cpmcu/llm_w4a16_gptq_marlin.py b/examples/CPM.cu/cpmcu/llm_w4a16_gptq_marlin.py
new file mode 100644
index 00000000..d80a869f
--- /dev/null
+++ b/examples/CPM.cu/cpmcu/llm_w4a16_gptq_marlin.py
@@ -0,0 +1,434 @@
+from . import C
+
+import os, json, glob
+import torch
+from transformers import AutoTokenizer, AutoConfig
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from safetensors.torch import load_file
+import time, math
+import torch.nn.functional as F
+
+dtype_map = {
+ torch.float16: 0,
+ torch.bfloat16: 1,
+}
+
+def dtype_to_int(dtype):
+ ret = dtype_map.get(dtype, -1)
+ if ret == -1:
+ raise ValueError(f"Unsupported dtype: {dtype}")
+ return ret
+
+class W4A16GPTQMarlinLLM(torch.nn.Module):
+ def __init__(self,
+ path: str, # hf model path
+ memory_limit: float = 0.8,
+ chunk_length: int = 1024,
+ dtype: torch.dtype = None,
+ cuda_graph: bool = False,
+ apply_sparse: bool = False,
+ sink_window_size: int = 1,
+ block_window_size: int = 32,
+ sparse_topk_k: int = 32,
+ sparse_switch: int = 8192,
+ apply_compress_lse: bool = False,
+ use_enter: bool = False,
+ use_decode_enter: bool = False,
+ temperature: float = 0.0,
+ random_seed: int = None,
+ ):
+ super().__init__()
+
+ self.path = path
+ self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+ self.config = AutoConfig.from_pretrained(path, trust_remote_code=True)
+ self.dtype = dtype if dtype is not None else self.config.torch_dtype
+ self.dtype_int = dtype_to_int(self.dtype)
+ self.cuda_graph = cuda_graph
+ self.use_enter = use_enter
+ self.use_decode_enter = use_decode_enter
+ self.temperature = temperature
+ self.chunk_length = chunk_length
+ # Flag for showing prefill progress (used in stream mode)
+ self._show_prefill_progress = False
+
+ # Initialize random generator if random_seed is provided
+ if random_seed is not None:
+ self.generator = torch.Generator(device="cuda")
+ self.generator.manual_seed(random_seed)
+ else:
+ self.generator = None
+
+ if not hasattr(self.config, "head_dim"):
+ self.config.head_dim = self.config.hidden_size // self.config.num_attention_heads
+
+ self.group_size = self.config.quantization_config['group_size']
+ scale_embed = self.config.scale_emb if hasattr(self.config, "scale_emb") else 1.0
+ scale_lmhead = (self.config.dim_model_base / self.config.hidden_size) if hasattr(self.config, "dim_model_base") else 1.0
+ scale_residual = self.config.scale_depth / math.sqrt(self.config.num_hidden_layers) if hasattr(self.config, "scale_depth") else 1.0
+
+ if apply_sparse:
+ C.init_w4a16_gptq_marlin_minicpm4_model(
+ memory_limit,
+ self.config.vocab_size,
+ self.config.num_hidden_layers,
+ self.config.hidden_size,
+ self.config.intermediate_size,
+ self.config.num_attention_heads,
+ self.config.num_key_value_heads,
+ self.config.head_dim,
+ self.config.rms_norm_eps,
+ self.group_size,
+ self.dtype_int,
+ self.chunk_length,
+ scale_embed,
+ scale_lmhead,
+ scale_residual,
+ sink_window_size,
+ block_window_size,
+ sparse_topk_k,
+ sparse_switch,
+ apply_compress_lse,
+ )
+ else:
+ C.init_w4a16_gptq_marlin_base_model(
+ memory_limit,
+ self.config.vocab_size,
+ self.config.num_hidden_layers,
+ self.config.hidden_size,
+ self.config.intermediate_size,
+ self.config.num_attention_heads,
+ self.config.num_key_value_heads,
+ self.config.head_dim,
+ self.config.rms_norm_eps,
+ self.group_size,
+ self.dtype_int,
+ self.chunk_length,
+ scale_embed,
+ scale_lmhead,
+ scale_residual,
+ )
+
+ self.logits = torch.empty((64, self.config.vocab_size), dtype=self.dtype, device="cuda")
+
+ def init_storage(self):
+ self.max_total_length = C.init_storage()
+ print("max supported length under current memory limit: ", self.max_total_length)
+
+ def _load(self, name, param, dtype=None, cls=None):
+ # if ".q_proj." in name or ".k_proj." in name or ".v_proj." in name or ".gate_proj." in name or ".up_proj." in name:
+ # return
+ if dtype is None:
+ if 'rotary_emb' in name:
+ dtype = torch.float32
+ else:
+ dtype = self.dtype
+
+ # if 'gate_up_proj' in name:
+ # self._load(name.replace("gate_up_proj", "gate_proj"), param[:param.shape[0]//2], dtype)
+ # self._load(name.replace("gate_up_proj", "up_proj"), param[param.shape[0]//2:])
+ # elif 'qkv_proj' in name:
+ # self._load(name.replace("qkv_proj", "q_proj"), param[:self.config.num_attention_heads * self.config.head_dim])
+ # self._load(name.replace("qkv_proj", "k_proj"), param[self.config.num_attention_heads * self.config.head_dim:(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim])
+ # self._load(name.replace("qkv_proj", "v_proj"), param[(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim:])
+ # else:
+ param = param.contiguous()
+ if param.dtype not in [torch.int8, torch.int16, torch.int32]:
+ param = param.to(dtype)
+ C.load_model(name, param.data_ptr())
+
+ if "embed_tokens" in name and hasattr(self.config, "tie_word_embeddings") and self.config.tie_word_embeddings:
+ self._load("lm_head.weight", param)
+
+ def _load_from_ckpt(self, path, cls=None):
+ supported_suffix_1 = ["bin.index.json", "safetensors.index.json"]
+ supported_suffix_2 = ["bin", "safetensors", "pt"]
+ file = None
+ for suffix in supported_suffix_1:
+ files = glob.glob(os.path.join(path, f"*.{suffix}"))
+ if len(files) > 1:
+ raise ValueError(f"Multiple files with suffix {suffix} found in {path}")
+ elif len(files) == 1:
+ file = files[0]
+ break
+ else:
+ for suffix in supported_suffix_2:
+ files = glob.glob(os.path.join(path, f"*.{suffix}"))
+ if len(files) > 1:
+ print(files)
+ if path + "/model_gptq_marlin.safetensors" in files:
+ file = path + "/model_gptq_marlin.safetensors"
+ else:
+ raise ValueError(f"Autogptq models not found in {path}")
+ break
+ elif len(files) == 1:
+ file = files[0]
+ break
+ else:
+ raise ValueError(f"No supported checkpoint file found in {path}, supported suffixes: {supported_suffix_1 + supported_suffix_2}")
+
+ if file.endswith(".index.json"):
+ with open(file, "r") as f:
+ file_list = set(json.load(f)["weight_map"].values())
+ file_list = [os.path.join(path, file) for file in file_list]
+ else:
+ file_list = [file]
+
+ for file in file_list:
+ print(f"load from {file}")
+ if file.endswith(".bin") or file.endswith(".pt"):
+ ckpt = torch.load(file, map_location="cpu")
+ elif file.endswith(".safetensors"):
+ ckpt = load_file(file)
+ for name, param in ckpt.items():
+ self._load(name, param, cls=cls)
+
+ def load_from_hf(self):
+ with torch.no_grad():
+ self._load_from_ckpt(self.path)
+
+ # rope
+ if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
+ rope_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type"))
+                if rope_type == "longrope" and "factor" not in self.config.rope_scaling:
+ self.config.rope_scaling["factor"] = 1.0
+ else:
+ rope_type = "default"
+ # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor
+ inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](self.config, "cpu", seq_len=self.max_total_length)
+ # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu")
+ self._load("model.rotary_emb.inv_freq", inv_freq, dtype=torch.float32)
+ # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32)
+
+ def prefill(self, input_ids, position_ids):
+ assert input_ids.dtype == torch.int32
+ # Check if input length exceeds maximum supported length
+ if input_ids.numel() > self.max_total_length:
+ raise ValueError(f"Input token count ({input_ids.numel()}) exceeds maximum supported length ({self.max_total_length}) under current memory limit")
+
+ total_length = input_ids.numel()
+ num_chunks = (total_length + self.chunk_length - 1) // self.chunk_length
+
+ prefill_start_time = None
+ actual_prefill_start = None
+
+ # User interaction logic only when use_enter is True
+ if self._show_prefill_progress and self.use_enter:
+ # Clear screen and move cursor to top, then show prompt and wait for user input
+ print("\033[2J\033[H", end="", flush=True) # Clear screen and move to top
+ print("Please Press Enter to Start Prefilling...", end="", flush=True)
+ input() # Wait for Enter key
+
+ # Replace the prompt with [Prefilling] - clear entire line first
+ print("\r" + " " * 50 + "\r[Prefilling]", flush=True)
+ # Start timing after user presses Enter
+ prefill_start_time = time.time()
+ actual_prefill_start = prefill_start_time
+
+ # Initialize progress display for stream mode (always when _show_prefill_progress is True)
+ if self._show_prefill_progress:
+ if prefill_start_time is None: # Only set start time if not already set above
+ prefill_start_time = time.time()
+ if not self.use_enter:
+ print("Prefilling: 0.0% (0/{} tokens) @ 0.0 tokens/s".format(total_length), end="", flush=True)
+ else:
+ print("Prefilling: 0.0% (0/{} tokens) @ 0.0 tokens/s".format(total_length), end="", flush=True)
+
+ # Record actual computation start time if not set yet
+ if actual_prefill_start is None:
+ actual_prefill_start = time.time()
+
+ for chunk_idx, i in enumerate(range(0, input_ids.numel(), self.chunk_length)):
+ # torch.cuda.nvtx.range_push(f"chunk from {i}")
+ C.prefill(
+ min(input_ids.numel() - i, self.chunk_length), i,
+ input_ids.view(-1)[i:].data_ptr(), position_ids.view(-1)[i:].data_ptr(),
+ self.logits.data_ptr()
+ )
+ # torch.cuda.nvtx.range_pop()
+
+ # Show progress for stream mode - always when _show_prefill_progress is True
+ if self._show_prefill_progress and prefill_start_time is not None:
+ current_tokens = min(i + self.chunk_length, total_length)
+ elapsed_time = time.time() - prefill_start_time
+ progress = (current_tokens * 100.0) / total_length
+ tokens_per_sec = current_tokens / elapsed_time if elapsed_time > 0 else 0.0
+ print(f"\rPrefilling: {progress:.1f}% ({current_tokens}/{total_length} tokens) @ {tokens_per_sec:.1f} tokens/s", end="", flush=True)
+
+ # Calculate actual prefill time
+ actual_prefill_time = time.time() - actual_prefill_start
+
+ # Final completion status for stream mode
+ if self._show_prefill_progress:
+ if prefill_start_time is not None:
+ final_elapsed_time = time.time() - prefill_start_time
+ final_tokens_per_sec = total_length / final_elapsed_time if final_elapsed_time > 0 else 0.0
+ print(f"\rPrefilling: 100.0% ({total_length}/{total_length} tokens) @ {final_tokens_per_sec:.1f} tokens/s - Complete!")
+ if self.use_enter:
+ print("\n[Decoding]") # Show decoding status and move to next line only with use_enter
+ else:
+ print() # Just a newline for normal mode
+
+ # Store the actual prefill time for use in generate method
+ self._last_prefill_time = actual_prefill_time
+
+ return self.logits[:1].clone()
+
+ def decode(self, input_ids, position_ids, cache_length, mask_2d = None):
+ assert input_ids.dtype == torch.int32
+ assert position_ids.dtype == torch.int32
+ assert cache_length.dtype == torch.int32
+ if mask_2d is not None:
+ # assert mask_2d.dtype == torch.uint64
+ assert input_ids.numel() == mask_2d.shape[0]
+
+ # torch.cuda.nvtx.range_push(f"decode")
+        cache_length += input_ids.numel()  # temporarily include the new tokens, for convenience in flash_attn
+ padded_length = (cache_length[0].item() + 128 - 1) // 128 * 128
+ C.decode(
+ input_ids.numel(), padded_length,
+ input_ids.data_ptr(), position_ids.data_ptr(), cache_length.data_ptr(),
+ mask_2d.data_ptr() if mask_2d is not None else 0,
+ self.logits.data_ptr(),
+ self.cuda_graph
+ )
+ cache_length -= input_ids.numel()
+ # torch.cuda.nvtx.range_pop()
+ return self.logits[:input_ids.numel()].clone()
+
+ def generate(self, input_ids, generation_length=100, teminators=[], use_stream=False):
+ """
+ Generate text with optional streaming output.
+ Returns (tokens, decode_time, prefill_time) if use_stream=False, or generator yielding {'token', 'text', 'is_finished', 'prefill_time', 'decode_time'} if use_stream=True.
+ """
+ assert input_ids.dtype == torch.int32
+
+ prefix_length = input_ids.numel()
+ position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda")
+
+ # Set progress flag before prefill for stream mode
+ if use_stream:
+ self._show_prefill_progress = True
+
+ # Measure prefill time
+ if self.use_enter and use_stream:
+ # In use_enter mode, timing will be handled inside prefill method
+ logits = self.prefill(input_ids, position_ids)
+ prefill_time = getattr(self, '_last_prefill_time', 0.0) # Get actual prefill time
+ else:
+ torch.cuda.synchronize()
+ prefill_start = time.time()
+ logits = self.prefill(input_ids, position_ids)
+ torch.cuda.synchronize()
+ prefill_time = time.time() - prefill_start
+
+ if self.temperature > 0:
+            token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0].item()
+ else:
+ token = logits[0].argmax(dim=-1).item()
+
+ # Wait for user input before decode phase if use_decode_enter is enabled
+ if self.use_decode_enter:
+ if use_stream and self.use_enter:
+ # In stream mode with use_enter, we already showed [Decoding], just wait for input
+ print("Please Press Enter to Start Decoding...", end="", flush=True)
+ input() # Wait for Enter key
+ print("\r" + " " * 50 + "\r", end="", flush=True) # Clear the prompt without showing [Decoding] again
+ else:
+ # In other modes, show prompt and wait
+ print("Please Press Enter to Start Decoding...", end="", flush=True)
+ input() # Wait for Enter key
+ print("\r" + " " * 50 + "\r[Decoding]", flush=True) # Show [Decoding] only when use_enter is not enabled
+
+ if not hasattr(self, "input_ids"):
+ self.input_ids = torch.tensor([0], dtype=torch.int32, device="cuda")
+ self.position_ids = torch.tensor([0], dtype=torch.int32, device="cuda")
+ self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda")
+
+ if use_stream:
+ # Stream generation (optimized)
+ def _stream_generator():
+ nonlocal token
+ # Keep minimal context for correct spacing
+ prev_token = token
+
+ # yield first token
+ text = self.tokenizer.decode([token], skip_special_tokens=False)
+
+ yield {
+ 'token': token,
+ 'text': text,
+ 'is_finished': token in teminators,
+ 'prefill_time': prefill_time,
+ 'decode_time': 0.0 # First token comes from prefill
+ }
+
+ if token in teminators:
+ return
+
+ decode_start_time = time.time()
+
+ for i in range(generation_length-1):
+ self.input_ids[0] = token
+ self.position_ids[0] = prefix_length + i
+ self.cache_length[0] = prefix_length + i
+
+ logits = self.decode(self.input_ids, self.position_ids, self.cache_length)
+ if self.temperature > 0:
+                        token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0].item()
+ else:
+ token = logits[0].argmax(dim=-1).item()
+
+ # For correct spacing, decode with previous token context
+ if prev_token is not None:
+ context_tokens = [prev_token, token]
+ context_text = self.tokenizer.decode(context_tokens, skip_special_tokens=False)
+ prev_text = self.tokenizer.decode([prev_token], skip_special_tokens=False)
+ text = context_text[len(prev_text):]
+ else:
+ text = self.tokenizer.decode([token], skip_special_tokens=False)
+
+ is_finished = token in teminators or i == generation_length - 2
+
+ # Calculate time only when needed to reduce overhead
+ decode_time = time.time() - decode_start_time
+
+ yield {
+ 'token': token,
+ 'text': text,
+ 'is_finished': is_finished,
+ 'prefill_time': 0.0, # Only report prefill_time for first token
+ 'decode_time': decode_time
+ }
+
+ if token in teminators:
+ break
+
+ # Update prev_token
+ prev_token = token
+
+ return _stream_generator()
+ else:
+ # Original batch generation
+ tokens = [token]
+ torch.cuda.synchronize()
+ decode_start = time.time()
+ for i in range(generation_length-1):
+ self.input_ids[0] = token
+ self.position_ids[0] = prefix_length + i
+ self.cache_length[0] = prefix_length + i
+
+ logits = self.decode(self.input_ids, self.position_ids, self.cache_length)
+ if self.temperature > 0:
+ token = torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator)[0].item()
+ else:
+ token = logits[0].argmax(dim=-1).item()
+ tokens.append(token)
+ if token in teminators:
+ break
+ torch.cuda.synchronize()
+ decode_time = time.time() - decode_start
+ return tokens, decode_time, prefill_time
+
+ def print_perf_summary(self):
+ C.print_perf_summary()
\ No newline at end of file
diff --git a/examples/CPM.cu/cpmcu/speculative/__init__.py b/examples/CPM.cu/cpmcu/speculative/__init__.py
new file mode 100644
index 00000000..e1dcad20
--- /dev/null
+++ b/examples/CPM.cu/cpmcu/speculative/__init__.py
@@ -0,0 +1 @@
+from .eagle import LLM_with_eagle
\ No newline at end of file
diff --git a/examples/CPM.cu/cpmcu/speculative/eagle.py b/examples/CPM.cu/cpmcu/speculative/eagle.py
new file mode 100644
index 00000000..83ef7f7e
--- /dev/null
+++ b/examples/CPM.cu/cpmcu/speculative/eagle.py
@@ -0,0 +1,99 @@
+from .. import C
+from .tree_drafter import LLM_with_tree_drafter
+import math, torch
+from transformers import PretrainedConfig
+
+class EagleConfig(PretrainedConfig):
+ def __init__(
+ self,
+ num_hidden_layers=1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.eagle_num_layers = num_hidden_layers
+
+class LLM_with_eagle(LLM_with_tree_drafter):
+ def __init__(self,
+ eagle_path,
+ base_path,
+ num_iter=6,
+ topk_per_iter=10,
+ tree_size=60,
+ eagle_window_size=0,
+ frspec_vocab_size=0,
+ apply_eagle_quant: bool=False,
+ use_rope: bool=False,
+ use_input_norm: bool=False,
+ use_attn_norm: bool=False,
+ **kwargs):
+ super().__init__(
+ "eagle", eagle_path, base_path,
+ tree_size = tree_size,
+ **kwargs
+ )
+
+ self.eagle_path = eagle_path
+ self.eagle_config = EagleConfig.from_pretrained(eagle_path)
+ # Ensure presence consistency and equality for scale_depth, dim_model_base, and scale_emb
+ for attr in ("scale_depth", "dim_model_base", "scale_emb"):
+ base_has = hasattr(self.config, attr)
+ eagle_has = hasattr(self.eagle_config, attr)
+ assert base_has == eagle_has, f"{attr} presence mismatch between base and eagle config"
+ if base_has:
+ assert getattr(self.config, attr) == getattr(self.eagle_config, attr), f"{attr} in base config and eagle config should be the same"
+ scale_residual = self.config.scale_depth / math.sqrt(self.config.num_hidden_layers + 1) if hasattr(self.config, "scale_depth") else 1.0
+ self.apply_eagle_quant = apply_eagle_quant
+ if apply_eagle_quant and hasattr(self.eagle_config, "quantization_config"):
+ self.group_size = self.eagle_config.quantization_config.get('group_size', 0)
+ else:
+ self.group_size = 0
+        assert self.group_size == 128 or self.group_size == 0, "only group_size 128 is supported when eagle quantization is enabled (group_size is 0 otherwise)"
+
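+        # The plain EAGLE initializer covers the default setup; custom RoPE / norm
+        # handling or EAGLE quantization goes through the MiniCPM4-specific initializer.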
+ if not use_rope and not use_input_norm and not use_attn_norm and not apply_eagle_quant:
+ C.init_eagle_model(
+ self.eagle_config.eagle_num_layers,
+ num_iter,
+ topk_per_iter,
+ self.tree_size,
+ self.dtype_int
+ )
+ else:
+ C.init_minicpm4_eagle_model(
+ self.eagle_config.eagle_num_layers,
+ num_iter,
+ topk_per_iter,
+ self.tree_size,
+ self.dtype_int,
+ apply_eagle_quant,
+ self.group_size,
+ eagle_window_size,
+ frspec_vocab_size,
+ scale_residual,
+ use_input_norm,
+ use_attn_norm
+ )
+
+ def _load(self, name, param, dtype=None, cls=None):
+ if cls == self.drafter_type:
+ if name == "token_id_remap":
+ C.load_model(f"{cls}.{name}", param.data_ptr())
+ return
+ if dtype is None:
+ dtype = self.dtype
+ param = param.contiguous()
+ if not self.apply_eagle_quant:
+ param = param.to(dtype)
+ if 'embed_tokens' in name:
+ return
+ if 'fc' in name:
+ if 'weight' in name:
+ param1 = param[..., :param.shape[-1] // 2].contiguous()
+ param2 = param[..., param.shape[-1] // 2:].contiguous()
+ C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param1.data_ptr())
+ C.load_model(f"{cls}.{name.replace('fc', 'fc2')}", param2.data_ptr())
+ else: # bias
+ C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param.data_ptr())
+ else:
+ C.load_model(f"{cls}.{name}", param.data_ptr())
+ else:
+ super()._load(name, param, dtype)
diff --git a/examples/CPM.cu/cpmcu/speculative/eagle_base_quant/__init__.py b/examples/CPM.cu/cpmcu/speculative/eagle_base_quant/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/CPM.cu/cpmcu/speculative/eagle_base_quant/eagle_base_w4a16_marlin_gptq.py b/examples/CPM.cu/cpmcu/speculative/eagle_base_quant/eagle_base_w4a16_marlin_gptq.py
new file mode 100644
index 00000000..3df755b8
--- /dev/null
+++ b/examples/CPM.cu/cpmcu/speculative/eagle_base_quant/eagle_base_w4a16_marlin_gptq.py
@@ -0,0 +1,103 @@
+from ... import C
+from ..eagle import EagleConfig
+from ..tree_drafter_base_quant.tree_drafter_w4a16_gptq_marlin import W4A16GPTQMarlinLLM_with_tree_drafter
+
+import math, torch
+
+
+class W4A16GPTQMarlinLLM_with_eagle(W4A16GPTQMarlinLLM_with_tree_drafter):
+ def __init__(self,
+ eagle_path,
+ base_path,
+ num_iter=6,
+ topk_per_iter=10,
+ tree_size=60,
+ eagle_window_size=0,
+ frspec_vocab_size=0,
+ apply_eagle_quant: bool=False,
+ use_rope: bool=False,
+ use_input_norm: bool=False,
+ use_attn_norm: bool=False,
+ use_rotation: bool=False,
+ **kwargs):
+ super().__init__(
+ "eagle", eagle_path, base_path,
+ tree_size = tree_size,
+ **kwargs
+ )
+
+ self.eagle_path = eagle_path
+ self.eagle_config = EagleConfig.from_pretrained(eagle_path)
+ # Ensure presence consistency and equality for scale_depth, dim_model_base, and scale_emb
+ for attr in ("scale_depth", "dim_model_base", "scale_emb"):
+ base_has = hasattr(self.config, attr)
+ eagle_has = hasattr(self.eagle_config, attr)
+ assert base_has == eagle_has, f"{attr} presence mismatch between base and eagle config"
+ if base_has:
+ assert getattr(self.config, attr) == getattr(self.eagle_config, attr), f"{attr} in base config and eagle config should be the same"
+ scale_residual = self.config.scale_depth / math.sqrt(self.config.num_hidden_layers + 1) if hasattr(self.config, "scale_depth") else 1.0
+ self.use_rotation = use_rotation
+ self.apply_eagle_quant = apply_eagle_quant
+ if apply_eagle_quant and hasattr(self.eagle_config, "quantization_config"):
+ self.group_size = self.eagle_config.quantization_config.get('group_size', 0)
+ else:
+ self.group_size = 0
+        assert self.group_size == 128 or self.group_size == 0, "only group_size 128 is supported when eagle quantization is enabled (group_size is 0 otherwise)"
+
+ if not use_rope and not use_input_norm and not use_attn_norm and not apply_eagle_quant:
+ if not use_rotation:
+ C.init_eagle_model(
+ self.eagle_config.eagle_num_layers,
+ num_iter,
+ topk_per_iter,
+ self.tree_size,
+ self.dtype_int
+ )
+ else:
+ C.init_eagle_w4a16_gptq_marlin_rot_model(
+ self.eagle_config.eagle_num_layers,
+ num_iter,
+ topk_per_iter,
+ self.tree_size,
+ self.dtype_int
+ )
+ else:
+ C.init_minicpm4_eagle_model(
+ self.eagle_config.eagle_num_layers,
+ num_iter,
+ topk_per_iter,
+ self.tree_size,
+ self.dtype_int,
+ apply_eagle_quant,
+ self.group_size,
+ eagle_window_size,
+ frspec_vocab_size,
+ scale_residual,
+ use_input_norm,
+ use_attn_norm
+ )
+
+ def _load(self, name, param, dtype=None, cls=None):
+ if cls == self.drafter_type:
+ if name == "token_id_remap":
+ C.load_model(f"{cls}.{name}", param.data_ptr())
+ return
+ if dtype is None:
+ dtype = self.dtype
+ param = param.contiguous()
+ if not self.apply_eagle_quant:
+ param = param.to(dtype)
+ if (not self.use_rotation) and 'embed_tokens' in name:
+ return
+ if 'fc' in name:
+ if 'weight' in name or "scales" in name:
+ param1 = param[..., :param.shape[-1] // 2].contiguous()
+ param2 = param[..., param.shape[-1] // 2:].contiguous()
+ C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param1.data_ptr())
+ C.load_model(f"{cls}.{name.replace('fc', 'fc2')}", param2.data_ptr())
+ else: # bias
+ C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param.data_ptr())
+ else:
+ C.load_model(f"{cls}.{name}", param.data_ptr())
+ else:
+ super()._load(name, param, dtype)
diff --git a/examples/CPM.cu/cpmcu/speculative/hier_spec_quant/__init__.py b/examples/CPM.cu/cpmcu/speculative/hier_spec_quant/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/CPM.cu/cpmcu/speculative/hier_spec_quant/hier_eagle_w4a16_gm_spec_w4a16_gm.py b/examples/CPM.cu/cpmcu/speculative/hier_spec_quant/hier_eagle_w4a16_gm_spec_w4a16_gm.py
new file mode 100644
index 00000000..e4f9fc5c
--- /dev/null
+++ b/examples/CPM.cu/cpmcu/speculative/hier_spec_quant/hier_eagle_w4a16_gm_spec_w4a16_gm.py
@@ -0,0 +1,268 @@
+from ... import C
+from ...llm_w4a16_gptq_marlin import W4A16GPTQMarlinLLM
+
+import numpy as np
+import torch
+from ..tree_drafter import *
+import time
+from transformers import PretrainedConfig, AutoTokenizer, AutoConfig
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+
+def pack_draft_mask(mask_2d):
+ '''
+    for static masks, pack each row into a uint16 bitmask
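+    e.g. a 3x3 lower-triangular mask packs row-wise to 0b001, 0b011, 0b111, i.e. [1, 3, 7]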
+ '''
+ mask_2d_packed = torch.zeros((mask_2d.shape[0]), dtype=torch.uint16, device="cuda")
+ for i in range(mask_2d.shape[0]):
+ mask_1 = 0
+ for j in range(i + 1):
+ mask_1 |= (mask_2d[i][j].item() << j )
+ mask_2d_packed[i] = mask_1
+ mask_2d_packed = mask_2d_packed.view(torch.uint16).view(-1)
+ return mask_2d_packed
+
+class EagleConfig(PretrainedConfig):
+ def __init__(
+ self,
+ num_hidden_layers=1,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.eagle_num_layers = num_hidden_layers
+
+class HierEagleW4A16GMSpecW4A16GM(W4A16GPTQMarlinLLM):
+ def __init__(self,
+ drafter_path: str,
+ base_path: str,
+ min_draft_length: int,
+ draft_cuda_graph: bool,
+ tree_path: str,
+ ea_num_iter=6,
+ ea_topk_per_iter=10,
+ tree_size=60,
+ draft_model_start=False,
+ rotation=False,
+ **kwargs):
+
+ super().__init__(base_path, **kwargs)
+
+ # eagle config
+ self.tree_drafter_type = 'eagle'
+ self.eagle_path = tree_path
+ self.ea_num_iter = ea_num_iter
+ self.ea_topk_per_iter = ea_topk_per_iter
+ self.tree_size = tree_size
+
+ self.tree_draft_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+ self.tree_position_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+ self.tree_gt_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+ self.tree_attn_mask = torch.empty((tree_size), dtype=torch.uint64, device="cuda")
+ self.tree_parent = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+
+
+ self.eagle_config = EagleConfig.from_pretrained(self.eagle_path)
+ self.rotation = rotation
+
+
+ # draft config
+ self.drafter_type = 'draft'
+ self.drafter_path = drafter_path
+ self.drafter_tokenizer = AutoTokenizer.from_pretrained(drafter_path)
+ self.drafter_config = AutoConfig.from_pretrained(drafter_path)
+
+ self.min_draft_length = min_draft_length
+ self.max_draft_length = min_draft_length + ea_num_iter
+ self.draft_ids = torch.empty((self.max_draft_length+1), dtype=torch.int32, device="cuda")
+ self.draft_position_ids = torch.empty((self.max_draft_length+1), dtype=torch.int32, device="cuda")
+ self.draft_gt_ids = torch.empty((self.max_draft_length+1), dtype=torch.int32, device="cuda")
+ self.draft_attn_mask = pack_draft_mask(
+ torch.tril(torch.ones(self.max_draft_length+1, self.max_draft_length+1, dtype=torch.bool)).to("cuda")
+ )
+
+ # eagle accept list
+ self.draft_ea_accept_list = torch.empty((1024,), dtype=torch.int32, device="cuda")
+
+ self.draft_logits = torch.empty((64, self.config.vocab_size), dtype=self.dtype, device="cuda")
+ self.draft_cache_length = torch.tensor([0], dtype=torch.int32, device="cuda")
+ self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda")
+ self.draft_cuda_graph = draft_cuda_graph
+
+ self.draft_model_start = draft_model_start
+
+ self.draft_group_size = self.drafter_config.quantization_config['group_size']
+
+ if self.rotation:
+ C.init_hier_eagle_w4a16_gm_rot_spec_w4a16_gm_model(
+ self.drafter_config.vocab_size,
+ self.drafter_config.num_hidden_layers,
+ self.drafter_config.hidden_size,
+ self.drafter_config.intermediate_size,
+ self.drafter_config.num_attention_heads,
+ self.drafter_config.num_key_value_heads,
+ self.drafter_config.head_dim,
+ self.drafter_config.rms_norm_eps,
+ self.draft_group_size,
+ self.min_draft_length,
+ self.draft_cuda_graph,
+ self.eagle_config.eagle_num_layers,
+ self.ea_num_iter,
+ self.ea_topk_per_iter,
+ self.tree_size,
+ self.draft_model_start,
+ 0,
+ )
+ else:
+ C.init_hier_eagle_w4a16_gm_spec_w4a16_gm_model(
+ self.drafter_config.vocab_size,
+ self.drafter_config.num_hidden_layers,
+ self.drafter_config.hidden_size,
+ self.drafter_config.intermediate_size,
+ self.drafter_config.num_attention_heads,
+ self.drafter_config.num_key_value_heads,
+ self.drafter_config.head_dim,
+ self.drafter_config.rms_norm_eps,
+ self.draft_group_size,
+ self.min_draft_length,
+ self.draft_cuda_graph,
+ self.eagle_config.eagle_num_layers,
+ self.ea_num_iter,
+ self.ea_topk_per_iter,
+ self.tree_size,
+ self.draft_model_start,
+ 0,
+ )
+
+ # def load_from_hf(self):
+ # self._load_from_ckpt(self.eagle_path, cls=self.tree_drafter_type)
+ # self._load_from_ckpt(self.drafter_path, cls=self.drafter_type)
+ # super().load_from_hf()
+
+ def _load(self, name, param, dtype=None, cls=None):
+ if cls == self.tree_drafter_type:
+ if dtype is None:
+ dtype = self.dtype
+ param = param.contiguous().to(dtype)
+ if (not self.rotation) and 'embed_tokens' in name:
+ return
+ if 'fc' in name:
+ if 'weight' in name:
+ param1 = param[..., :param.shape[-1] // 2].contiguous()
+ param2 = param[..., param.shape[-1] // 2:].contiguous()
+ C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param1.data_ptr())
+ C.load_model(f"{cls}.{name.replace('fc', 'fc2')}", param2.data_ptr())
+ else: # bias
+ C.load_model(f"{cls}.{name.replace('fc', 'fc1')}", param.data_ptr())
+ else:
+ C.load_model(f"{cls}.{name}", param.data_ptr())
+ elif cls == self.drafter_type:
+ if dtype is None:
+ if 'rotary_emb' in name:
+ dtype = torch.float32
+ else:
+ dtype = self.dtype
+
+ # if 'gate_up_proj' in name:
+ # self._load(name.replace("gate_up_proj", "gate_proj"), param[:param.shape[0]//2], dtype, cls=cls)
+ # self._load(name.replace("gate_up_proj", "up_proj"), param[param.shape[0]//2:], cls=cls)
+ # elif 'qkv_proj' in name:
+ # self._load(name.replace("qkv_proj", "q_proj"), param[:self.config.num_attention_heads * self.config.head_dim], cls=cls)
+ # self._load(name.replace("qkv_proj", "k_proj"), param[self.config.num_attention_heads * self.config.head_dim:(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim], cls=cls)
+ # self._load(name.replace("qkv_proj", "v_proj"), param[(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim:], cls=cls)
+ # else:
+ param = param.contiguous()
+ if param.dtype not in [torch.int8, torch.int16, torch.int32]:
+ param = param.to(dtype)
+ C.load_model(f"{cls}.{name}", param.data_ptr())
+
+ if "embed_tokens" in name and hasattr(self.config, "tie_word_embeddings") and self.config.tie_word_embeddings:
+ self._load("lm_head.weight", param, cls)
+ else:
+ super()._load(name, param, dtype)
+
+ def load_from_hf(self):
+ with torch.no_grad():
+            # eagle load
+ self._load_from_ckpt(self.eagle_path, cls=self.tree_drafter_type)
+
+ self._load_from_ckpt(self.drafter_path, cls=self.drafter_type)
+ # rope
+ if hasattr(self.drafter_config, "rope_scaling") and self.drafter_config.rope_scaling is not None:
+ draft_rope_type = self.drafter_config.rope_scaling.get("rope_type", self.drafter_config.rope_scaling.get("type"))
+ else:
+ draft_rope_type = "default"
+ # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor
+ draft_inv_freq, draft_attention_scaling = ROPE_INIT_FUNCTIONS[draft_rope_type](self.drafter_config, "cpu", seq_len=self.max_total_length)
+ # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu")
+ self._load(f"{self.drafter_type}.model.rotary_emb.inv_freq", draft_inv_freq, dtype=torch.float32, cls=self.drafter_type)
+ # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32)
+
+ super().load_from_hf()
+
+
+ def generate(self, input_ids, generation_length=100, teminators=[]):
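+        """
+        Hierarchical speculative decoding loop: C.draft proposes a batch of draft tokens
+        (draft model + EAGLE), the target model scores them in a single decode call, and
+        C.verify_and_fix reports how many of them are accepted.
+        Returns (tokens, accept_lengths, model_step, decode_time, eagle_accept_lengths).
+        """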
+ assert input_ids.dtype == torch.int32
+
+ prefix_length = input_ids.shape[1]
+
+ position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda")
+ logits = self.prefill(input_ids, position_ids)
+ self.draft_ids[:1].copy_(logits[0].argmax(dim=-1))
+
+ tokens = torch.empty((generation_length), dtype=torch.int32, device="cuda")
+ tokens[0].copy_(self.draft_ids[0])
+ accept_lengths = []
+ i = 0
+ model_step = 0
+ terminal = False
+ torch.cuda.synchronize()
+ start_time = time.time()
+
+
+ while i < generation_length-1 and not terminal:
+ self.cache_length[0] = prefix_length + i
+ self.draft_position_ids[0] = prefix_length + i
+
+
+ # step 1: draft model prefill and eagle input prepare
+ C.draft(
+ self.draft_ids.data_ptr(),
+ self.draft_position_ids.data_ptr(),
+ self.cache_length.data_ptr(),
+ self.draft_attn_mask.data_ptr(),
+ self.draft_ea_accept_list.data_ptr(),
+ )
+
+
+ # step 2: target model decode (length fixed for cuda graph)
+ logits = self.decode(self.draft_ids, self.draft_position_ids, self.cache_length, mask_2d=self.draft_attn_mask)
+ self.draft_gt_ids.copy_(logits.argmax(dim=-1))
+
+            # step 3: verify and fix target model and eagle input
+ accept_length = C.verify_and_fix(
+ self.draft_ids.numel(),
+ self.draft_ids.data_ptr(),
+ self.draft_gt_ids.data_ptr(),
+ self.draft_position_ids.data_ptr(),
+ self.cache_length.data_ptr(),
+ self.draft_attn_mask.data_ptr(),
+ self.draft_ea_accept_list.data_ptr(),
+ )
+
+ model_step += 1
+ accept_lengths.append(accept_length)
+ for temin in teminators:
+ if temin in self.draft_gt_ids[:accept_length]:
+ terminal = True
+ append_length = min(accept_length, generation_length - 1 - i)
+ tokens[1+i:1+i+append_length].copy_(self.draft_gt_ids[:append_length])
+ self.draft_ids[0] = self.draft_gt_ids[accept_length - 1]
+ i += accept_length
+
+
+ # print(f"ea accept avg:", np.mean(self.draft_ea_accept_list[1:ea_acc_nums+1].cpu().numpy()))
+
+ torch.cuda.synchronize()
+ decode_time = time.time() - start_time
+ ea_acc_nums = self.draft_ea_accept_list[0].item()
+ tokens = tokens[:i+1].tolist()
+ return tokens, accept_lengths, model_step, decode_time, self.draft_ea_accept_list[1:1+ea_acc_nums].clone()
\ No newline at end of file
diff --git a/examples/CPM.cu/cpmcu/speculative/spec_quant/__init__.py b/examples/CPM.cu/cpmcu/speculative/spec_quant/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/CPM.cu/cpmcu/speculative/spec_quant/spec_w4a16_gm_for_w4a16_gm_model.py b/examples/CPM.cu/cpmcu/speculative/spec_quant/spec_w4a16_gm_for_w4a16_gm_model.py
new file mode 100644
index 00000000..15b22414
--- /dev/null
+++ b/examples/CPM.cu/cpmcu/speculative/spec_quant/spec_w4a16_gm_for_w4a16_gm_model.py
@@ -0,0 +1,160 @@
+from ... import C
+from transformers import AutoTokenizer, AutoConfig
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+from safetensors.torch import load_file
+
+from ...llm_w4a16_gptq_marlin import W4A16GPTQMarlinLLM
+
+import torch
+import time
+
+def pack_draft_mask(mask_2d):
+ '''
+    for static masks, pack each row into a uint16 bitmask
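+    (a uint16 row encodes at most 16 draft positions, i.e. draft_num + 1 must be <= 16)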
+ '''
+ mask_2d_packed = torch.zeros((mask_2d.shape[0]), dtype=torch.uint16, device="cuda")
+ for i in range(mask_2d.shape[0]):
+ mask_1 = 0
+ for j in range(i + 1):
+ mask_1 |= (mask_2d[i][j].item() << j )
+ mask_2d_packed[i] = mask_1
+ mask_2d_packed = mask_2d_packed.view(torch.uint16).view(-1)
+ return mask_2d_packed
+
+
+class W4A16GMSpecW4A16GM(W4A16GPTQMarlinLLM):
+ def __init__(self,
+ drafter_path: str,
+ base_path: str,
+ draft_num: int,
+ draft_cuda_graph: bool,
+ **kwargs):
+ super().__init__(base_path, **kwargs)
+
+ self.drafter_type = 'draft'
+ self.drafter_path = drafter_path
+ self.drafter_tokenizer = AutoTokenizer.from_pretrained(drafter_path)
+ self.drafter_config = AutoConfig.from_pretrained(drafter_path)
+
+ self.draft_num = draft_num
+ self.draft_ids = torch.empty((self.draft_num+1), dtype=torch.int32, device="cuda")
+ self.draft_position_ids = torch.empty((self.draft_num+1), dtype=torch.int32, device="cuda")
+ self.draft_gt_ids = torch.empty((self.draft_num+1), dtype=torch.int32, device="cuda")
+ self.draft_attn_mask = pack_draft_mask(
+ torch.tril(torch.ones(draft_num+1, draft_num+1, dtype=torch.bool)).to("cuda")
+ )
+ self.draft_parent = torch.tensor([], dtype=torch.int32, device="cuda")
+ self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda")
+
+ self.draft_cuda_graph = draft_cuda_graph
+
+ self.draft_group_size = self.drafter_config.quantization_config['group_size']
+
+ C.init_w4a16_gm_spec_w4a16_gm_model(
+ self.drafter_config.vocab_size,
+ self.drafter_config.num_hidden_layers,
+ self.drafter_config.hidden_size,
+ self.drafter_config.intermediate_size,
+ self.drafter_config.num_attention_heads,
+ self.drafter_config.num_key_value_heads,
+ self.drafter_config.head_dim,
+ self.drafter_config.rms_norm_eps,
+ self.draft_group_size,
+ self.draft_num,
+ self.draft_cuda_graph,
+ self.dtype_int,
+ )
+
+ def _load(self, name, param, dtype=None, cls=None):
+ if cls == self.drafter_type:
+ if dtype is None:
+ if 'rotary_emb' in name:
+ dtype = torch.float32
+ else:
+ dtype = self.dtype
+
+ # if 'gate_up_proj' in name:
+ # self._load(name.replace("gate_up_proj", "gate_proj"), param[:param.shape[0]//2], dtype, cls=cls)
+ # self._load(name.replace("gate_up_proj", "up_proj"), param[param.shape[0]//2:], cls=cls)
+ # elif 'qkv_proj' in name:
+ # self._load(name.replace("qkv_proj", "q_proj"), param[:self.config.num_attention_heads * self.config.head_dim], cls=cls)
+ # self._load(name.replace("qkv_proj", "k_proj"), param[self.config.num_attention_heads * self.config.head_dim:(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim], cls=cls)
+ # self._load(name.replace("qkv_proj", "v_proj"), param[(self.config.num_attention_heads + self.config.num_key_value_heads) * self.config.head_dim:], cls=cls)
+ # else:
+ param = param.contiguous()
+ if param.dtype not in [torch.int8, torch.int16, torch.int32]:
+ param = param.to(dtype)
+ C.load_model(f"{cls}.{name}", param.data_ptr())
+
+ if "embed_tokens" in name and hasattr(self.config, "tie_word_embeddings") and self.config.tie_word_embeddings:
+ self._load("lm_head.weight", param, cls)
+ else:
+ super()._load(name, param, dtype)
+
+
+ def load_from_hf(self):
+ with torch.no_grad():
+ self._load_from_ckpt(self.drafter_path, cls=self.drafter_type)
+ # rope
+ if hasattr(self.drafter_config, "rope_scaling") and self.drafter_config.rope_scaling is not None:
+ draft_rope_type = self.drafter_config.rope_scaling.get("rope_type", self.drafter_config.rope_scaling.get("type"))
+ else:
+ draft_rope_type = "default"
+ # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor
+ draft_inv_freq, draft_attention_scaling = ROPE_INIT_FUNCTIONS[draft_rope_type](self.drafter_config, "cpu", seq_len=self.max_total_length)
+ # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu")
+ self._load(f"{self.drafter_type}.model.rotary_emb.inv_freq", draft_inv_freq, dtype=torch.float32, cls=self.drafter_type)
+ # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32)
+ super().load_from_hf()
+
+
+ def generate(self, input_ids, generation_length=100, teminators=[]):
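+        """
+        Speculative decoding loop: C.draft proposes a batch of draft tokens, the target
+        model scores them in a single decode call, and C.verify_and_fix reports how many
+        of them are accepted.
+        Returns (tokens, accept_lengths, model_step, decode_time).
+        """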
+ assert input_ids.dtype == torch.int32
+
+ prefix_length = input_ids.numel()
+ position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda")
+ logits = self.prefill(input_ids, position_ids)
+ self.draft_ids[:1].copy_(logits[0].argmax(dim=-1))
+
+ tokens = torch.empty((generation_length), dtype=torch.int32, device="cuda")
+ tokens[0].copy_(self.draft_ids[0])
+ accept_lengths = []
+ i = 0
+ model_step = 0
+ terminal = False
+ torch.cuda.synchronize()
+ start_time = time.time()
+ while i < generation_length-1 and not terminal:
+ self.cache_length[0] = prefix_length + i
+ self.draft_position_ids[0] = prefix_length + i
+
+ # torch.cuda.nvtx.range_push(f"draft")
+ C.draft(self.draft_ids.data_ptr(), self.draft_position_ids.data_ptr(), self.cache_length.data_ptr(), self.draft_attn_mask.data_ptr(), self.draft_parent.data_ptr())
+ # torch.cuda.nvtx.range_pop()
+
+
+ logits = self.decode(self.draft_ids, self.draft_position_ids, self.cache_length, mask_2d=self.draft_attn_mask)
+ self.draft_gt_ids.copy_(logits.argmax(dim=-1))
+
+ # torch.cuda.nvtx.range_push(f"verify")
+ accept_length = C.verify_and_fix(
+ self.draft_ids.numel(), self.draft_ids.data_ptr(), self.draft_gt_ids.data_ptr(),
+ self.draft_position_ids.data_ptr(), self.cache_length.data_ptr(),
+ self.draft_attn_mask.data_ptr(), self.draft_parent.data_ptr()
+ )
+ # torch.cuda.nvtx.range_pop()
+
+ model_step += 1
+ accept_lengths.append(accept_length)
+ for temin in teminators:
+ if temin in self.draft_gt_ids[:accept_length]:
+ terminal = True
+ append_length = min(accept_length, generation_length - 1 - i)
+ tokens[1+i:1+i+append_length].copy_(self.draft_gt_ids[:append_length])
+ self.draft_ids[0] = self.draft_gt_ids[accept_length - 1]
+ i += accept_length
+ torch.cuda.synchronize()
+ decode_time = time.time() - start_time
+ tokens = tokens[:1+i].tolist()
+
+ return tokens, accept_lengths, model_step, decode_time
\ No newline at end of file
diff --git a/examples/CPM.cu/cpmcu/speculative/tree_drafter.py b/examples/CPM.cu/cpmcu/speculative/tree_drafter.py
new file mode 100644
index 00000000..428c9023
--- /dev/null
+++ b/examples/CPM.cu/cpmcu/speculative/tree_drafter.py
@@ -0,0 +1,255 @@
+from .. import C
+from ..llm import LLM
+
+import torch
+import torch.nn.functional as F
+import time
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+
+def pack_mask(mask_2d):
+ '''
+ for static masks, pack them into a uint64 per row
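+    e.g. a 3x3 lower-triangular mask packs to rows 1, 3, 7; each row holds up to 64
+    positions as two uint32 halves reinterpreted as one uint64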
+ '''
+ mask_2d_packed = torch.zeros((mask_2d.shape[0], 2), dtype=torch.uint32, device="cuda")
+ for i in range(mask_2d.shape[0]):
+ mask_1 = 0
+ mask_2 = 0
+ for j in range(i + 1):
+ if j < 32:
+ mask_1 |= (mask_2d[i][j].item() << j)
+ else:
+ mask_2 |= (mask_2d[i][j].item() << (j - 32))
+ mask_2d_packed[i][0] = mask_1
+ mask_2d_packed[i][1] = mask_2
+ mask_2d_packed = mask_2d_packed.view(torch.uint64).view(-1)
+ return mask_2d_packed
+
+class LLM_with_tree_drafter(LLM):
+ def __init__(self,
+ drafter_type, drafter_path, base_path,
+ tree_size,
+ use_rope: bool=False,
+ **kwargs):
+ super().__init__(base_path, **kwargs)
+
+ self.drafter_type = drafter_type
+ self.drafter_path = drafter_path
+ self.base_path = base_path
+ self.use_rope = use_rope
+
+ self.tree_size = tree_size
+ self.tree_draft_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+ self.tree_position_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+ self.tree_gt_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+ self.tree_attn_mask = torch.empty((tree_size), dtype=torch.uint64, device="cuda")
+ self.tree_parent = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+
+ self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda")
+
+ def load_from_hf(self):
+ with torch.no_grad():
+ self._load_from_ckpt(self.drafter_path, cls=self.drafter_type)
+
+ if self.use_rope:
+ if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
+ rope_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type"))
+ else:
+ rope_type = "default"
+ # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor
+ inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](self.config, "cpu", seq_len=self.max_total_length)
+ # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu")
+ self._load(f"{self.drafter_type}.rotary_emb.inv_freq", inv_freq, dtype=torch.float32)
+ # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32)
+
+ super().load_from_hf()
+
+ def generate(self, input_ids, generation_length=100, teminators=[], use_stream=False):
+ """
+ Generate text with optional streaming output for tree drafter.
+ Returns (tokens, accept_lengths, decode_time, prefill_time) if use_stream=False, or generator yielding {'token', 'text', 'is_finished', 'accept_length', 'prefill_time', 'decode_time'} if use_stream=True.
+ """
+ assert input_ids.dtype == torch.int32
+
+ prefix_length = input_ids.numel()
+ # Check if input length exceeds maximum supported length
+ if prefix_length > self.max_total_length:
+ raise ValueError(f"Input token count ({prefix_length}) exceeds maximum supported length ({self.max_total_length}) under current memory limit")
+
+ position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda")
+
+ # Set progress flag before prefill for stream mode
+ if use_stream:
+ self._show_prefill_progress = True
+ else:
+ self._show_prefill_progress = False
+
+ # Measure prefill time
+ if self.use_enter and use_stream:
+ # In use_enter mode, timing will be handled inside prefill method
+ logits = self.prefill(input_ids, position_ids)
+ prefill_time = getattr(self, '_last_prefill_time', 0.0) # Get actual prefill time
+ else:
+ torch.cuda.synchronize()
+ prefill_start = time.time()
+ logits = self.prefill(input_ids, position_ids)
+ torch.cuda.synchronize()
+ prefill_time = time.time() - prefill_start
+
+ if self.temperature > 0.0:
+ self.tree_draft_ids[:1].copy_(torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator))
+ else:
+ self.tree_draft_ids[:1].copy_(logits[0].argmax(dim=-1))
+
+ # Wait for user input before decode phase if use_decode_enter is enabled
+ if self.use_decode_enter:
+ if use_stream and self.use_enter:
+ # In stream mode with use_enter, we already showed [Decoding], just wait for input
+ print("Please Press Enter to Start Decoding...", end="", flush=True)
+ input() # Wait for Enter key
+ print("\r" + " " * 50 + "\r", end="", flush=True) # Clear the prompt without showing [Decoding] again
+ else:
+ # In other modes, show prompt and wait
+ print("Please Press Enter to Start Decoding...", end="", flush=True)
+ input() # Wait for Enter key
+ print("\r" + " " * 50 + "\r[Decoding]", flush=True) # Show [Decoding] only when use_enter is not enabled
+
+ if use_stream:
+ # Stream generation for tree drafter (optimized)
+ def _stream_generator():
+ # Keep minimal context for correct spacing
+ prev_token = None
+
+ # yield first token
+ token = self.tree_draft_ids[0].item()
+ text = self.tokenizer.decode([token], skip_special_tokens=False)
+ prev_token = token
+
+ yield {
+ 'token': token,
+ 'text': text,
+ 'is_finished': token in teminators,
+ 'accept_length': 1,
+ 'prefill_time': prefill_time,
+ 'decode_time': 0.0 # First token comes from prefill
+ }
+
+ if token in teminators:
+ return
+
+ decode_start_time = time.time()
+
+ i = 0
+ while i < generation_length-1:
+ self.cache_length[0] = prefix_length + i
+
+ # draft step
+ C.draft(self.tree_draft_ids.data_ptr(), self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr())
+
+ logits = self.decode(self.tree_draft_ids, self.tree_position_ids, self.cache_length, mask_2d=self.tree_attn_mask)
+ if self.temperature > 0.0:
+ self.tree_gt_ids.copy_(torch.multinomial(F.softmax(logits/self.temperature, dim=-1), num_samples=1, generator=self.generator).squeeze(-1))
+ else:
+ self.tree_gt_ids.copy_(logits.argmax(dim=-1))
+
+ # verify step
+ accept_length = C.verify_and_fix(
+ self.tree_draft_ids.numel(), self.tree_draft_ids.data_ptr(), self.tree_gt_ids.data_ptr(),
+ self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(),
+ self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr()
+ )
+
+ # yield accepted tokens (optimized with minimal context)
+ if accept_length > 0:
+ accepted_tokens = self.tree_draft_ids[:accept_length].tolist()
+
+ # For correct spacing, decode with previous token context
+ if prev_token is not None:
+ context_tokens = [prev_token] + accepted_tokens
+ context_text = self.tokenizer.decode(context_tokens, skip_special_tokens=False)
+ prev_text = self.tokenizer.decode([prev_token], skip_special_tokens=False)
+ new_text = context_text[len(prev_text):]
+ else:
+ new_text = self.tokenizer.decode(accepted_tokens, skip_special_tokens=False)
+
+ # Yield tokens with batch text for first token, empty for others
+ for j in range(accept_length):
+ if i + j >= generation_length - 1:
+ break
+
+ token = accepted_tokens[j]
+
+ # Give all new text to first token, empty to others
+ if j == 0:
+ text = new_text
+ else:
+ text = ""
+
+ terminal = token in teminators
+ is_finished = terminal or (i + j == generation_length - 2)
+
+ # Only calculate time for the last token in the batch to reduce overhead
+ decode_time = time.time() - decode_start_time if j == accept_length - 1 else 0.0
+
+ yield {
+ 'token': token,
+ 'text': text,
+ 'is_finished': is_finished,
+ 'accept_length': accept_length if j == 0 else 0, # only report accept_length for first token in batch
+ 'prefill_time': 0.0, # Only report prefill_time for first token
+ 'decode_time': decode_time
+ }
+
+ if terminal:
+ return
+
+ # Update prev_token to the last accepted token
+ prev_token = accepted_tokens[-1]
+
+ self.tree_draft_ids[0] = self.tree_draft_ids[accept_length - 1]
+ i += accept_length
+
+ return _stream_generator()
+ else:
+ # Original batch generation
+ tokens = torch.empty((generation_length), dtype=torch.int32, device="cuda")
+ tokens[0].copy_(self.tree_draft_ids[0])
+ accept_lengths = []
+ i = 0
+ terminal = False
+ torch.cuda.synchronize()
+ decode_start = time.time()
+ while i < generation_length-1 and not terminal:
+ self.cache_length[0] = prefix_length + i
+
+ # torch.cuda.nvtx.range_push(f"draft")
+ C.draft(self.tree_draft_ids.data_ptr(), self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr())
+ # torch.cuda.nvtx.range_pop()
+
+ logits = self.decode(self.tree_draft_ids, self.tree_position_ids, self.cache_length, mask_2d=self.tree_attn_mask)
+ if self.temperature > 0.0:
+ self.tree_gt_ids.copy_(torch.multinomial(F.softmax(logits/self.temperature, dim=-1), num_samples=1, generator=self.generator).squeeze(-1))
+ else:
+ self.tree_gt_ids.copy_(logits.argmax(dim=-1))
+
+ # torch.cuda.nvtx.range_push(f"verify")
+ accept_length = C.verify_and_fix(
+ self.tree_draft_ids.numel(), self.tree_draft_ids.data_ptr(), self.tree_gt_ids.data_ptr(),
+ self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(),
+ self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr()
+ )
+ # torch.cuda.nvtx.range_pop()
+
+ accept_lengths.append(accept_length)
+ for temin in teminators:
+ if temin in self.tree_draft_ids[:accept_length]:
+ terminal = True
+ append_length = min(accept_length, generation_length - 1 - i)
+ tokens[1+i:1+i+append_length].copy_(self.tree_draft_ids[:append_length])
+ self.tree_draft_ids[0] = self.tree_draft_ids[accept_length - 1]
+ i += accept_length
+ torch.cuda.synchronize()
+ decode_time = time.time() - decode_start
+ tokens = tokens[:1+i].tolist()
+
+ return tokens, accept_lengths, decode_time, prefill_time
\ No newline at end of file
diff --git a/examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/__init__.py b/examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/tree_drafter_w4a16_gptq_marlin.py b/examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/tree_drafter_w4a16_gptq_marlin.py
new file mode 100644
index 00000000..63ccf306
--- /dev/null
+++ b/examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/tree_drafter_w4a16_gptq_marlin.py
@@ -0,0 +1,240 @@
+from ... import C
+from ...llm_w4a16_gptq_marlin import W4A16GPTQMarlinLLM
+
+import torch
+from ..tree_drafter import *
+import time
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
+import torch.nn.functional as F
+
+class W4A16GPTQMarlinLLM_with_tree_drafter(W4A16GPTQMarlinLLM):
+ def __init__(self,
+ drafter_type, drafter_path, base_path,
+ tree_size,
+ use_rope: bool=False,
+ temperature: float=0.0,
+ **kwargs):
+ super().__init__(base_path, **kwargs)
+
+ self.drafter_type = drafter_type
+ self.drafter_path = drafter_path
+ self.base_path = base_path
+ self.use_rope = use_rope
+
+ self.tree_size = tree_size
+ self.tree_draft_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+ self.tree_position_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+ self.tree_gt_ids = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+ self.tree_attn_mask = torch.empty((tree_size), dtype=torch.uint64, device="cuda")
+ self.tree_parent = torch.empty((tree_size), dtype=torch.int32, device="cuda")
+ self.temperature = temperature
+
+ self.cache_length = torch.tensor([0], dtype=torch.int32, device="cuda")
+
+ def load_from_hf(self):
+ with torch.no_grad():
+ self._load_from_ckpt(self.drafter_path, cls=self.drafter_type)
+
+ if self.use_rope:
+ if hasattr(self.config, "rope_scaling") and self.config.rope_scaling is not None:
+ rope_type = self.config.rope_scaling.get("rope_type", self.config.rope_scaling.get("type"))
+ else:
+ rope_type = "default"
+ # TODO only support "default", "llama3" or "longrope" with long_factor=short_factor
+ inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](self.config, "cpu", seq_len=self.max_total_length)
+ # attention_scaling = torch.tensor([attention_scaling], dtype=torch.float32, device="cpu")
+ self._load(f"{self.drafter_type}.rotary_emb.inv_freq", inv_freq, dtype=torch.float32)
+ # self._load("model.rotary_emb.attention_scaling", attention_scaling, dtype=torch.float32)
+
+ super().load_from_hf()
+
+ def generate(self, input_ids, generation_length=100, teminators=[], use_stream=False):
+ """
+ Generate text with optional streaming output for quantized tree drafter.
+ Returns (tokens, accept_lengths, decode_time, prefill_time) if use_stream=False, or generator yielding {'token', 'text', 'is_finished', 'accept_length', 'prefill_time', 'decode_time'} if use_stream=True.
+ """
+ assert input_ids.dtype == torch.int32
+
+ prefix_length = input_ids.numel()
+ # Check if input length exceeds maximum supported length
+ if prefix_length > self.max_total_length:
+ raise ValueError(f"Input token count ({prefix_length}) exceeds maximum supported length ({self.max_total_length}) under current memory limit")
+
+ position_ids = torch.arange(prefix_length, dtype=torch.int32, device="cuda")
+
+ # Set progress flag before prefill for stream mode
+ if use_stream:
+ self._show_prefill_progress = True
+ else:
+ self._show_prefill_progress = False
+
+ # Measure prefill time
+ if self.use_enter and use_stream:
+ # In use_enter mode, timing will be handled inside prefill method
+ logits = self.prefill(input_ids, position_ids)
+ prefill_time = getattr(self, '_last_prefill_time', 0.0) # Get actual prefill time
+ else:
+ torch.cuda.synchronize()
+ prefill_start = time.time()
+ logits = self.prefill(input_ids, position_ids)
+ torch.cuda.synchronize()
+ prefill_time = time.time() - prefill_start
+
+ if self.temperature > 0.0:
+ self.tree_draft_ids[:1].copy_(torch.multinomial(F.softmax(logits[0]/self.temperature, dim=-1), num_samples=1, generator=self.generator))
+ else:
+ self.tree_draft_ids[:1].copy_(logits[0].argmax(dim=-1))
+
+ # Wait for user input before decode phase if use_decode_enter is enabled
+ if self.use_decode_enter:
+ if use_stream and self.use_enter:
+ # In stream mode with use_enter, we already showed [Decoding], just wait for input
+ print("Please Press Enter to Start Decoding...", end="", flush=True)
+ input() # Wait for Enter key
+ print("\r" + " " * 50 + "\r", end="", flush=True) # Clear the prompt without showing [Decoding] again
+ else:
+ # In other modes, show prompt and wait
+ print("Please Press Enter to Start Decoding...", end="", flush=True)
+ input() # Wait for Enter key
+ print("\r" + " " * 50 + "\r[Decoding]", flush=True) # Show [Decoding] only when use_enter is not enabled
+
+ if use_stream:
+ # Stream generation for quantized tree drafter (optimized)
+ def _stream_generator():
+ # Keep minimal context for correct spacing
+ prev_token = None
+
+ # yield first token
+ token = self.tree_draft_ids[0].item()
+ text = self.tokenizer.decode([token], skip_special_tokens=False)
+ prev_token = token
+
+ yield {
+ 'token': token,
+ 'text': text,
+ 'is_finished': token in teminators,
+ 'accept_length': 1,
+ 'prefill_time': prefill_time,
+ 'decode_time': 0.0 # First token comes from prefill
+ }
+
+ if token in teminators:
+ return
+
+ decode_start_time = time.time()
+
+ i = 0
+ while i < generation_length-1:
+ self.cache_length[0] = prefix_length + i
+
+ # draft step
+ C.draft(self.tree_draft_ids.data_ptr(), self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr())
+
+ logits = self.decode(self.tree_draft_ids, self.tree_position_ids, self.cache_length, mask_2d=self.tree_attn_mask)
+ if self.temperature > 0.0:
+ self.tree_gt_ids.copy_(torch.multinomial(F.softmax(logits/self.temperature, dim=-1), num_samples=1, generator=self.generator).squeeze(-1))
+ else:
+ self.tree_gt_ids.copy_(logits.argmax(dim=-1))
+
+ # verify step
+ accept_length = C.verify_and_fix(
+ self.tree_draft_ids.numel(), self.tree_draft_ids.data_ptr(), self.tree_gt_ids.data_ptr(),
+ self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(),
+ self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr()
+ )
+
+ # yield accepted tokens (optimized with minimal context)
+ if accept_length > 0:
+ accepted_tokens = self.tree_draft_ids[:accept_length].tolist()
+
+ # For correct spacing, decode with previous token context
+ if prev_token is not None:
+ context_tokens = [prev_token] + accepted_tokens
+ context_text = self.tokenizer.decode(context_tokens, skip_special_tokens=False)
+ prev_text = self.tokenizer.decode([prev_token], skip_special_tokens=False)
+ new_text = context_text[len(prev_text):]
+ else:
+ new_text = self.tokenizer.decode(accepted_tokens, skip_special_tokens=False)
+
+ # Yield tokens with batch text for first token, empty for others
+ for j in range(accept_length):
+ if i + j >= generation_length - 1:
+ break
+
+ token = accepted_tokens[j]
+
+ # Give all new text to first token, empty to others
+ if j == 0:
+ text = new_text
+ else:
+ text = ""
+
+ terminal = token in teminators
+ is_finished = terminal or (i + j == generation_length - 2)
+
+ # Only calculate time for the last token in the batch to reduce overhead
+ decode_time = time.time() - decode_start_time if j == accept_length - 1 else 0.0
+
+ yield {
+ 'token': token,
+ 'text': text,
+ 'is_finished': is_finished,
+ 'accept_length': accept_length if j == 0 else 0, # only report accept_length for first token in batch
+ 'prefill_time': 0.0, # Only report prefill_time for first token
+ 'decode_time': decode_time
+ }
+
+ if terminal:
+ return
+
+ # Update prev_token to the last accepted token
+ prev_token = accepted_tokens[-1]
+
+ self.tree_draft_ids[0] = self.tree_draft_ids[accept_length - 1]
+ i += accept_length
+
+ return _stream_generator()
+ else:
+ # Original batch generation
+ tokens = torch.empty((generation_length), dtype=torch.int32, device="cuda")
+ tokens[0].copy_(self.tree_draft_ids[0])
+ accept_lengths = []
+ i = 0
+ terminal = False
+ torch.cuda.synchronize()
+ decode_start = time.time()
+ while i < generation_length-1 and not terminal:
+ self.cache_length[0] = prefix_length + i
+
+ # torch.cuda.nvtx.range_push(f"draft")
+ C.draft(self.tree_draft_ids.data_ptr(), self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(), self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr())
+ # torch.cuda.nvtx.range_pop()
+
+ logits = self.decode(self.tree_draft_ids, self.tree_position_ids, self.cache_length, mask_2d=self.tree_attn_mask)
+ if self.temperature > 0.0:
+ self.tree_gt_ids.copy_(torch.multinomial(F.softmax(logits/self.temperature, dim=-1), num_samples=1, generator=self.generator).squeeze(-1))
+ else:
+ self.tree_gt_ids.copy_(logits.argmax(dim=-1))
+
+ # torch.cuda.nvtx.range_push(f"verify")
+ accept_length = C.verify_and_fix(
+ self.tree_draft_ids.numel(), self.tree_draft_ids.data_ptr(), self.tree_gt_ids.data_ptr(),
+ self.tree_position_ids.data_ptr(), self.cache_length.data_ptr(),
+ self.tree_attn_mask.data_ptr(), self.tree_parent.data_ptr()
+ )
+ # torch.cuda.nvtx.range_pop()
+
+ accept_lengths.append(accept_length)
+ for temin in teminators:
+ if temin in self.tree_draft_ids[:accept_length]:
+ terminal = True
+ append_length = min(accept_length, generation_length - 1 - i)
+ tokens[1+i:1+i+append_length].copy_(self.tree_draft_ids[:append_length])
+ self.tree_draft_ids[0] = self.tree_draft_ids[accept_length - 1]
+ i += accept_length
+ torch.cuda.synchronize()
+ decode_time = time.time() - decode_start
+ tokens = tokens[:1+i].tolist()
+
+ return tokens, accept_lengths, decode_time, prefill_time
\ No newline at end of file
diff --git a/examples/CPM.cu/model_convert/convert_llama_format.py b/examples/CPM.cu/model_convert/convert_llama_format.py
new file mode 100644
index 00000000..ececfcb2
--- /dev/null
+++ b/examples/CPM.cu/model_convert/convert_llama_format.py
@@ -0,0 +1,44 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+import math
+
+torch.manual_seed(0)
+
+def llm_load(path):
+ tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map='auto', trust_remote_code=True)
+
+ return model, tokenizer
+
+def convert_llm():
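+    """
+    Fold the MiniCPM-style scaling factors into the checkpoint so it can be used as a
+    plain Llama-format model: embeddings are multiplied by scale_emb, lm_head is divided
+    by hidden_size / dim_model_base, and each layer's o_proj / down_proj is scaled by
+    scale_depth / sqrt(num_layers). Relies on the module-level globals set in __main__.
+    """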
+ # model.embed_tokens.weight * scale_emb
+ state_dict["model.embed_tokens.weight"] = state_dict["model.embed_tokens.weight"] * scale_emb
+
+ # lm_head.weight / (hidden_size / dim_model_base)
+ state_dict["lm_head.weight"] = state_dict["lm_head.weight"] / (hidden_size / dim_model_base)
+
+ for i in range(num_layers):
+ attn_out_name = f"model.layers.{i}.self_attn.o_proj.weight"
+ state_dict[attn_out_name] = state_dict[attn_out_name] * (scale_depth / math.sqrt(num_layers))
+
+ ffn_down_proj_name = f"model.layers.{i}.mlp.down_proj.weight"
+ state_dict[ffn_down_proj_name] = state_dict[ffn_down_proj_name] * (scale_depth / math.sqrt(num_layers))
+
+ torch.save(state_dict, "./pytorch_model.bin")
+
+if __name__ == "__main__":
+ model, tokenizer = llm_load("/DATA/disk0/zhaoweilun/minicpm4/models/stable_7T_decay_700B_decay2_300B_longdecay_1sw1fa_sft_50B_release")
+
+ scale_emb = model.config.scale_emb
+ dim_model_base = model.config.dim_model_base
+ scale_depth = model.config.scale_depth
+ num_layers = model.config.num_hidden_layers
+ hidden_size = model.config.hidden_size
+ print(f"scale_emb = {scale_emb}")
+ print(f"dim_model_base = {dim_model_base}")
+ print(f"scale_depth = {scale_depth}")
+ print(f"num_layers = {num_layers}")
+ print(f"hidden_size = {hidden_size}")
+
+ state_dict = model.state_dict()
+ convert_llm()
\ No newline at end of file
diff --git a/examples/CPM.cu/model_convert/convert_w4a16.py b/examples/CPM.cu/model_convert/convert_w4a16.py
new file mode 100644
index 00000000..d2ed26c0
--- /dev/null
+++ b/examples/CPM.cu/model_convert/convert_w4a16.py
@@ -0,0 +1,287 @@
+import torch
+from safetensors.torch import load_file, save_file
+import os, glob, shutil
+import re
+from typing import List
+import argparse
+import numpy as np
+from transformers import AutoConfig
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument("--model-path", type=str, required=True, help="Path to the original model")
+parser.add_argument("--quant-path", type=str, required=True, help="Path to the AutoGPTQ model")
+parser.add_argument("--output-path", type=str, required=True, help="Path to save the converted model")
+
+# Copied from https://github.com/AutoGPTQ/AutoGPTQ/blob/9f7d37072917ab3a7545835f23e808294a542153/auto_gptq/nn_modules/qlinear/qlinear_marlin.py
+def get_perms():
+ perm = []
+ for i in range(32):
+ perm1 = []
+ col = i // 4
+ for block in [0, 1]:
+ for row in [
+ 2 * (i % 4),
+ 2 * (i % 4) + 1,
+ 2 * (i % 4 + 4),
+ 2 * (i % 4 + 4) + 1,
+ ]:
+ perm1.append(16 * row + col + 8 * block)
+ for j in range(4):
+ perm.extend([p + 256 * j for p in perm1])
+
+ perm = np.array(perm)
+ interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
+ perm = perm.reshape((-1, 8))[:, interleave].ravel()
+ perm = torch.from_numpy(perm)
+ scale_perm = []
+ for i in range(8):
+ scale_perm.extend([i + 8 * j for j in range(8)])
+ scale_perm_single = []
+ for i in range(4):
+ scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+ return perm, scale_perm, scale_perm_single
+
+PERM, SCALE_PERM, SCALE_PERM_SINGLE = get_perms()
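+# Marlin permutation tables: PERM, SCALE_PERM and SCALE_PERM_SINGLE hold 1024, 64 and 32
+# indices (weight tile layout, grouped scales, and per-channel scales respectively).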
+
+def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, group_size: int) -> torch.Tensor:
+
+ if group_size < size_k and group_size != -1:
+ s = s.reshape((-1, len(SCALE_PERM)))[:, SCALE_PERM]
+ else:
+ s = s.reshape((-1, len(SCALE_PERM_SINGLE)))[:, SCALE_PERM_SINGLE]
+ s = s.reshape((-1, size_n)).contiguous()
+
+ return s
+
+def marlin_repack_qweight(qweight: torch.Tensor, bits: int, size_k: int, size_n: int, tile: int = 16) -> torch.Tensor:
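+    """
+    Unpack a GPTQ-packed int32 qweight of shape (size_k * bits // 32, size_n), apply the
+    Marlin tile permutation, and repack it into an int32 tensor of shape
+    (size_k // tile, size_n * tile * bits // 32).
+    """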
+
+ # unpack
+ compress_factor = 32 // bits
+ mask = (1 << bits) - 1
+ qweight = qweight.cpu().numpy().astype(np.uint32)
+ unpacked_qweight = np.zeros((size_k, size_n), dtype=np.uint32)
+ unpacked_offset = np.arange(size_k) // compress_factor
+ unpacked_shift = (np.arange(size_k) % compress_factor) * bits
+ unpacked_qweight = (qweight[unpacked_offset, :] >> unpacked_shift[:, None]) & mask
+
+ # permute
+ unpacked_qweight = torch.from_numpy(unpacked_qweight.astype(np.int32))
+ unpacked_qweight = unpacked_qweight.reshape((size_k // tile, tile, size_n // tile, tile))
+ unpacked_qweight = unpacked_qweight.permute(0, 2, 1, 3)
+ unpacked_qweight = unpacked_qweight.reshape((size_k // tile, size_n * tile)).contiguous()
+ unpacked_qweight = unpacked_qweight.reshape((-1, PERM.numel()))[:, PERM].reshape(unpacked_qweight.shape)
+
+ # repack
+ repacked_qweight = np.zeros((unpacked_qweight.shape[0], unpacked_qweight.shape[1] // compress_factor), dtype=np.uint32)
+ unpacked_qweight = unpacked_qweight.cpu().numpy().astype(np.uint32)
+ for i in range(compress_factor):
+ repacked_qweight |= unpacked_qweight[:, i::compress_factor] << (bits * i)
+ repacked_qweight = torch.from_numpy(repacked_qweight.astype(np.int32))
+
+ return repacked_qweight
+
+def convert_w4a16_checkpoint(orig_model_path, quant_path, output_path):
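+    """
+    Convert an AutoGPTQ W4A16 checkpoint into the Marlin layout expected by CPM.cu:
+    q/k/v and gate/up projections are fused, qweights are repacked tile-wise, scales are
+    re-permuted, and the remaining tensors (layernorms, etc.) are copied unchanged.
+    """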
+
+ config = AutoConfig.from_pretrained(quant_path)
+
+ group_size = config.quantization_config['group_size']
+ assert group_size in [-1, 128], "Only group_size -1 and 128 are supported for marlin"
+
+ bits = config.quantization_config['bits']
+ assert bits == 4, "Only 4-bit quantization is supported for marlin"
+
+ model_path = glob.glob(os.path.join(quant_path, "*.safetensors"))[0]
+
+ autogptq_weigths = load_file(model_path)
+
+ gptq_convert_dict = {
+ "model.layers.{}.self_attn.q_proj.qweight": ["model.layers.{}.self_attn.q_proj.scales", "model.layers.{}.self_attn.q_proj.g_idx", "model.layers.{}.self_attn.q_proj.qzeros"],
+ "model.layers.{}.self_attn.k_proj.qweight":["model.layers.{}.self_attn.k_proj.scales", "model.layers.{}.self_attn.k_proj.g_idx", "model.layers.{}.self_attn.k_proj.qzeros"],
+ "model.layers.{}.self_attn.v_proj.qweight":["model.layers.{}.self_attn.v_proj.scales", "model.layers.{}.self_attn.v_proj.g_idx", "model.layers.{}.self_attn.v_proj.qzeros"],
+ "model.layers.{}.self_attn.o_proj.qweight":["model.layers.{}.self_attn.o_proj.scales", "model.layers.{}.self_attn.o_proj.g_idx", "model.layers.{}.self_attn.o_proj.qzeros"],
+ "model.layers.{}.mlp.gate_proj.qweight":["model.layers.{}.mlp.gate_proj.scales", "model.layers.{}.mlp.gate_proj.g_idx", "model.layers.{}.mlp.gate_proj.qzeros"],
+ "model.layers.{}.mlp.up_proj.qweight": ["model.layers.{}.mlp.up_proj.scales", "model.layers.{}.mlp.up_proj.g_idx", "model.layers.{}.mlp.up_proj.qzeros"],
+ "model.layers.{}.mlp.down_proj.qweight": ["model.layers.{}.mlp.down_proj.scales", "model.layers.{}.mlp.down_proj.g_idx", "model.layers.{}.mlp.down_proj.qzeros"],
+ "fc.qweight": ["fc.scales", "fc.g_idx", "fc.qzeros"],
+ }
+
+ convert_checkpoint = {}
+ processed_keys = set()
+
+ for gptq_key in autogptq_weigths:
+ if gptq_key in processed_keys:
+ continue
+ elif "layers" in gptq_key:
+ abstract_key = re.sub(r'(\d+)', '{}', gptq_key)
+ layer_num = re.search(r'\d+', gptq_key).group(0)
+ if "q_proj" in abstract_key:
+ if abstract_key.endswith('qweight'):
+ k_key = gptq_key.replace('q_proj', 'k_proj')
+ v_key = gptq_key.replace('q_proj', 'v_proj')
+
+ q_weight = autogptq_weigths[gptq_key].clone().cuda()
+ k_weight = autogptq_weigths[k_key].clone().cuda()
+ v_weight = autogptq_weigths[v_key].clone().cuda()
+ x = torch.cat([q_weight, k_weight, v_weight], dim=-1)
+
+ shape_0 = x.shape[0] * 8
+ shape_1 = x.shape[1]
+ x = marlin_repack_qweight(x, bits, shape_0, shape_1)
+
+ convert_checkpoint[gptq_key.replace("q_proj", "qkv_proj")] = x.cpu()
+
+ processed_keys.add(gptq_key)
+ processed_keys.add(k_key)
+ processed_keys.add(v_key)
+
+ for q_keys in gptq_convert_dict[abstract_key]:
+ if q_keys.endswith("scales"):
+ k_q_keys = q_keys.replace("q_proj", "k_proj")
+ v_q_keys = q_keys.replace("q_proj", "v_proj")
+
+ scales_x_q = autogptq_weigths[q_keys.format(layer_num)].clone().cuda()
+ scales_x_k = autogptq_weigths[k_q_keys.format(layer_num)].clone().cuda()
+ scales_x_v = autogptq_weigths[v_q_keys.format(layer_num)].clone().cuda()
+ scales_x = torch.cat([scales_x_q, scales_x_k, scales_x_v], dim=-1)
+ scales_x.data = marlin_permute_scales(scales_x.data.contiguous(),
+ size_k=shape_0,
+ size_n=shape_1,
+ group_size=group_size)
+ convert_checkpoint[q_keys.format(layer_num).replace("q_proj", "qkv_proj")] = scales_x.cpu()
+
+ processed_keys.add(q_keys.format(layer_num))
+ processed_keys.add(q_keys.replace("q_proj", "k_proj").format(layer_num))
+ processed_keys.add(q_keys.replace("q_proj", "v_proj").format(layer_num))
+ elif "gate_proj" in abstract_key:
+ if abstract_key.endswith('qweight'):
+ up_key = gptq_key.replace('gate_proj', 'up_proj')
+
+ gate_weight = autogptq_weigths[gptq_key].clone().cuda()
+ up_weight = autogptq_weigths[up_key].clone().cuda()
+
+ x = torch.cat([gate_weight, up_weight], dim=-1)
+
+ shape_0 = x.shape[0] * 8
+ shape_1 = x.shape[1]
+ x = marlin_repack_qweight(x, bits, shape_0, shape_1)
+
+ convert_checkpoint[gptq_key.replace("gate_proj", "gate_up_proj")] = x.cpu()
+
+ processed_keys.add(gptq_key)
+ processed_keys.add(up_key)
+
+ for q_keys in gptq_convert_dict[abstract_key]:
+ if q_keys.endswith("scales"):
+ up_q_keys = q_keys.replace("gate_proj", "up_proj")
+
+ scales_x_gate = autogptq_weigths[q_keys.format(layer_num)].clone().cuda()
+ scales_x_up = autogptq_weigths[up_q_keys.format(layer_num)].clone().cuda()
+ scales_x = torch.cat([scales_x_gate, scales_x_up], dim=-1)
+ scales_x.data = marlin_permute_scales(scales_x.data.contiguous(),
+ size_k=shape_0,
+ size_n=shape_1,
+ group_size=group_size)
+ convert_checkpoint[q_keys.format(layer_num).replace("gate_proj", "gate_up_proj")] = scales_x.cpu()
+
+ processed_keys.add(q_keys.format(layer_num))
+ processed_keys.add(q_keys.replace("gate_proj", "up_proj").format(layer_num))
+
+ elif "down_proj" in abstract_key or "o_proj" in abstract_key:
+ if abstract_key.endswith('qweight'):
+ x = autogptq_weigths[gptq_key].clone().cuda()
+
+ shape_0 = x.shape[0] * 8
+ shape_1 = x.shape[1]
+ x = marlin_repack_qweight(x, bits, shape_0, shape_1)
+
+ convert_checkpoint[gptq_key] = x.cpu()
+
+ processed_keys.add(gptq_key)
+
+ for q_keys in gptq_convert_dict[abstract_key]:
+ if q_keys.endswith("scales"):
+
+ scales_x = autogptq_weigths[q_keys.format(layer_num)].clone().cuda()
+ scales_x.data = marlin_permute_scales(scales_x.data.contiguous(),
+ size_k=shape_0,
+ size_n=shape_1,
+ group_size=group_size)
+ convert_checkpoint[q_keys.format(layer_num)] = scales_x.cpu()
+
+ processed_keys.add(q_keys.format(layer_num))
+
+ elif "post_attention_layernorm" in gptq_key or "input_layernorm" in gptq_key:
+ convert_checkpoint[gptq_key] = autogptq_weigths[gptq_key].clone()
+ elif "fc" in gptq_key and autogptq_weigths[gptq_key].dtype == torch.int32:
+ if gptq_key.endswith('qweight'):
+ fc_qweight = autogptq_weigths[gptq_key].clone().cuda()
+ packed_in_features_x_2, out_features = fc_qweight.shape
+ packed_in_features = packed_in_features_x_2 // 2
+ in_features = packed_in_features * 32 // bits
+ fc1_weight = fc_qweight[:packed_in_features, :].contiguous()
+ fc2_weight = fc_qweight[packed_in_features:, :].contiguous()
+
+ fc1_weight = marlin_repack_qweight(fc1_weight, bits, in_features, out_features)
+ fc2_weight = marlin_repack_qweight(fc2_weight, bits, in_features, out_features)
+
+ convert_checkpoint[gptq_key] = torch.cat([fc1_weight, fc2_weight], dim=-1).cpu()
+ processed_keys.add(gptq_key)
+
+ for fc_key in gptq_convert_dict[gptq_key]:
+ if fc_key.endswith("scales"):
+ fc_scales = autogptq_weigths[gptq_key.replace("qweight", "scales")].clone().cuda()
+ fc_scales_1 = fc_scales[:in_features // group_size, :].contiguous()
+ fc_scales_2 = fc_scales[in_features // group_size:, :].contiguous()
+
+ fc_scales_1 = marlin_permute_scales(
+ fc_scales_1.data.contiguous(),
+ size_k=in_features,
+ size_n=out_features,
+ group_size=group_size
+ )
+ fc_scales_2 = marlin_permute_scales(
+ fc_scales_2.data.contiguous(),
+ size_k=in_features,
+ size_n=out_features,
+ group_size=group_size
+ )
+ # convert_checkpoint[q_keys.format(layer_num).replace("gate_proj", "gate_up_proj")] = scales_x.cpu()
+ convert_checkpoint[gptq_key.replace("qweight", "scales")] = torch.cat([fc_scales_1, fc_scales_2], dim=-1).cpu()
+ processed_keys.add(gptq_key.replace("qweight", "scales"))
+ else:
+ convert_checkpoint[gptq_key] = autogptq_weigths[gptq_key].clone()
+
+ save_file(convert_checkpoint, os.path.join(output_path, f"model_gptq.safetensors"))
+ # copy quantization config
+ config_list = glob.glob(os.path.join(quant_path, "*config.json"))
+ for config_file in config_list:
+ # copy config to output path
+ config_filename = os.path.basename(config_file)
+ dst_path = os.path.join(output_path, config_filename)
+ shutil.copy2(config_file, dst_path)
+
+ # copy tokenizer
+ tokenizer_list = glob.glob(os.path.join(orig_model_path, "tokenizer*"))
+ for tokenizer_file in tokenizer_list:
+        # copy tokenizer file to output path
+ tokenizer_filename = os.path.basename(tokenizer_file)
+ dst_path = os.path.join(output_path, tokenizer_filename)
+ shutil.copy2(tokenizer_file, dst_path)
+
+ # copy "special_tokens_map.json"
+ special_tokens_map_file = glob.glob(os.path.join(orig_model_path, "special_tokens_map.json"))[0]
+ special_tokens_map_basename = os.path.basename(special_tokens_map_file)
+ dst_path = os.path.join(output_path, special_tokens_map_basename)
+ shutil.copy2(special_tokens_map_file, dst_path)
+
+if __name__=="__main__":
+
+ args = parser.parse_args()
+ orig_model_path = args.model_path
+ quant_path = args.quant_path
+ output_path = args.output_path
+
+ os.makedirs(output_path, exist_ok=True)
+
+ convert_w4a16_checkpoint(orig_model_path, quant_path, output_path)
\ No newline at end of file
diff --git a/examples/CPM.cu/model_convert/post_process_w4a16_eagle.py b/examples/CPM.cu/model_convert/post_process_w4a16_eagle.py
new file mode 100644
index 00000000..f9a940d3
--- /dev/null
+++ b/examples/CPM.cu/model_convert/post_process_w4a16_eagle.py
@@ -0,0 +1,48 @@
+import os
+import torch
+import argparse
+from safetensors.torch import load_file, save_file
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--fp-model-path", type=str, required=True, help="Path to the fp model")
+parser.add_argument("--quant-model-path", type=str, required=True, help="Path to the quantized model")
+parser.add_argument("--output-path", type=str, required=True, help="Path to save the converted model")
+
+def post_process_eagle_w4a16_ckpt(fp_model_path, quant_model_path, output_path):
+ fp_model = torch.load(os.path.join(fp_model_path, "pytorch_model.bin"))
+ quant_model = load_file(os.path.join(quant_model_path, "model_gptq.safetensors"))
+
+ new_state_dict = {}
+
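+    # Keep the embedding and the two input norms in fp16 from the original draft
+    # checkpoint; take the (possibly quantized) fc and transformer layers from the
+    # GPTQ export below.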
+ assert (fp_model["embed_tokens.weight"].to(torch.float16) == quant_model["model.embed_tokens.weight"].cuda().to(torch.float16)).all(), "embed_tokens.weight mismatch"
+ new_state_dict["embed_tokens.weight"] = fp_model["embed_tokens.weight"].to(torch.float16)
+
+ if "fc.weight" in quant_model.keys():
+ assert (fp_model["fc.weight"].to(torch.float16) == quant_model["fc.weight"].cuda().to(torch.float16)).all(), "fc.weight mismatch"
+ new_state_dict["fc.weight"] = fp_model["fc.weight"].to(torch.float16)
+ elif "fc.qweight" in quant_model.keys():
+ new_state_dict["fc.qweight"] = quant_model["fc.qweight"]
+ new_state_dict["fc.scales"] = quant_model["fc.scales"]
+
+ new_state_dict["input_norm1.weight"] = fp_model["input_norm1.weight"].to(torch.float16)
+ new_state_dict["input_norm2.weight"] = fp_model["input_norm2.weight"].to(torch.float16)
+
+ for key, value in quant_model.items():
+ if "model.layers." in key:
+ new_key = key.replace("model.", "")
+ new_state_dict[new_key] = value
+
+    save_file(new_state_dict, os.path.join(output_path, "model_gptq.safetensors"))
+
+ os.system(f"cp {quant_model_path}/*.json {output_path}")
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ fp_model_path = args.fp_model_path
+ quant_model_path = args.quant_model_path
+ output_path = args.output_path
+
+ os.makedirs(output_path, exist_ok=True)
+
+ post_process_eagle_w4a16_ckpt(fp_model_path, quant_model_path, output_path)
+
\ No newline at end of file
diff --git a/examples/CPM.cu/scripts/fr_spec/gen_fr_index.py b/examples/CPM.cu/scripts/fr_spec/gen_fr_index.py
new file mode 100644
index 00000000..b556ca95
--- /dev/null
+++ b/examples/CPM.cu/scripts/fr_spec/gen_fr_index.py
@@ -0,0 +1,89 @@
+from datasets import load_dataset
+from transformers import AutoTokenizer
+from collections import Counter
+from tqdm import tqdm
+import torch
+import argparse
+import os
+
+def main(args):
+ print(f"Generating FR index for {args.model_name} with vocab size {args.vocab_size}")
+    print("This may take about 5-10 minutes for 1M lines on a good network connection")
+ print("Loading dataset...")
+ ds = load_dataset('Salesforce/wikitext', 'wikitext-103-raw-v1', streaming=True)['train']
+ # Only take the number of samples we need to process
+ ds = ds.take(args.num_lines + 1) # +1 to account for 0-indexing
+ print(f"Dataset limited to {args.num_lines + 1} samples")
+
+ print("Loading tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained(args.model_path)
+ print("Tokenizer loaded successfully")
+
+ token_counter = Counter()
+ num_lines = args.num_lines
+ num_tokens = 0
+ print("Starting to process data...")
+ for i, d in tqdm(enumerate(ds)):
+ tokens = tokenizer.encode(d['text'])
+ token_counter.update(tokens)
+ num_tokens += len(tokens)
+ if i == num_lines:
+ break
+
+ sort_by_freq = sorted(token_counter.items(), key=lambda x: x[1], reverse=True)
+ ids, frequencies = zip(*sort_by_freq)
+ ids = list(ids)
+
+ print(f"processed {num_lines} items")
+ print(f"processed {num_tokens} tokens")
+ print(f"unique tokens: {len(ids)}")
+
+ os.makedirs(f'fr_index/{args.model_name}', exist_ok=True)
+
+ for r in args.vocab_size:
+        eos_id = tokenizer.encode(tokenizer.special_tokens_map['eos_token'])
+        # make sure every eos token id is kept inside the truncated vocab
+        missing_eos = [t for t in eos_id if t not in ids[:r]]
+        if missing_eos:
+            freq_ids = ids[:r - len(missing_eos)] + missing_eos
+        else:
+            freq_ids = ids[:r]
+ if (r != len(freq_ids)):
+ print(f"Warning: requested vocab_size {r} but actual size: {len(freq_ids)}, file not saved")
+ else:
+ pt_path = f'fr_index/{args.model_name}/freq_{r}.pt'
+ print(f'save {pt_path}, actual size: {len(freq_ids)}')
+ with open(pt_path, 'wb') as f:
+ torch.save(freq_ids, f)
+
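+# Inspection sketch (assumption, not used by the pipeline): each saved freq_<size>.pt
+# is a torch-serialized list of token ids ordered by corpus frequency, e.g.
+#     freq_ids = torch.load('fr_index/MiniCPM4-8B/freq_32768.pt')
+#     print(len(freq_ids), freq_ids[:10])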
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+
+ parser.add_argument(
+ '--model_name',
+ type=str,
+ default='MiniCPM4-8B',
+ help='The name of the model.'
+ )
+ parser.add_argument(
+ '--model_path',
+ type=str,
+ default='openbmb/MiniCPM4-8B',
+ help='The path to the model.'
+ )
+ parser.add_argument(
+ '--num_lines',
+ type=int,
+ default=1000000,
+ help='The number of lines to process.'
+ )
+ parser.add_argument(
+ '--vocab_size',
+ nargs='+',
+ type=int,
+ default=[8192, 16384, 32768],
+ help='The vocab sizes to process.'
+ )
+
+ args = parser.parse_args()
+ print(args)
+ main(args)
diff --git a/examples/CPM.cu/scripts/model_convert/convert_w4a16.sh b/examples/CPM.cu/scripts/model_convert/convert_w4a16.sh
new file mode 100644
index 00000000..15320d1f
--- /dev/null
+++ b/examples/CPM.cu/scripts/model_convert/convert_w4a16.sh
@@ -0,0 +1,9 @@
+Model_Path=/yourpath/minicpm4_mupformat
+Quant_Path=/yourpath/minicpm4_autogptq
+Output_Path=/yourpath/minicpm4_marlin
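+# The three paths above are placeholders: Model_Path is the original model checkpoint,
+# Quant_Path the AutoGPTQ output, and Output_Path receives the Marlin-repacked
+# checkpoint produced by convert_w4a16.py.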
+
+
+python model_convert/convert_w4a16.py \
+ --model-path $Model_Path \
+ --quant-path $Quant_Path \
+ --output-path $Output_Path
\ No newline at end of file
diff --git a/examples/CPM.cu/setup.py b/examples/CPM.cu/setup.py
new file mode 100644
index 00000000..e19f9d7b
--- /dev/null
+++ b/examples/CPM.cu/setup.py
@@ -0,0 +1,317 @@
+import os, glob
+from setuptools import setup, find_packages
+
+this_dir = os.path.dirname(os.path.abspath(__file__))
+
+def detect_cuda_arch():
+ """Automatically detect current CUDA architecture"""
+ # 1. First check if environment variable specifies architecture
+ env_arch = os.getenv("CPMCU_CUDA_ARCH")
+ if env_arch:
+ # Only support simple comma-separated format, e.g., "80,86"
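+        # Example (assumed shell usage): `export CPMCU_CUDA_ARCH=80,86` builds for
+        # sm_80 (A100) and sm_86 (RTX 30xx) without needing a visible GPU at build time.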
+ arch_list = []
+ tokens = env_arch.split(',')
+
+ for token in tokens:
+ token = token.strip()
+ if not token:
+ continue
+
+ # Check format: must be pure digits
+ if not token.isdigit():
+ raise ValueError(
+ f"Invalid CUDA architecture format: '{token}'. "
+ f"CPMCU_CUDA_ARCH should only contain comma-separated numbers like '80,86'"
+ )
+
+ arch_list.append(token)
+
+ if arch_list:
+ print(f"Using CUDA architectures from environment variable: {arch_list}")
+ return arch_list
+
+ # 2. Check if torch library is available, if so, auto-detect
+ try:
+ import torch
+ except ImportError:
+ # 3. If no environment variable and no torch, throw error
+ raise RuntimeError(
+ "CUDA architecture detection failed. Please either:\n"
+ "1. Set environment variable CPMCU_CUDA_ARCH (e.g., export CPMCU_CUDA_ARCH=90), or\n"
+ "2. Install PyTorch (pip install torch) for automatic detection.\n"
+ "Common CUDA architectures: 70 (V100), 75 (T4), 80 (A100), 86 (RTX 30xx), 89 (RTX 40xx), 90 (H100)"
+ )
+
+ # Use torch to auto-detect all GPU architectures
+ try:
+ if torch.cuda.is_available():
+ arch_set = set()
+ device_count = torch.cuda.device_count()
+ for i in range(device_count):
+ major, minor = torch.cuda.get_device_capability(i)
+ arch = f"{major}{minor}"
+ arch_set.add(arch)
+
+ arch_list = sorted(list(arch_set)) # Sort for consistency
+ print(f"Detected CUDA architectures: {arch_list} (from {device_count} GPU devices)")
+ return arch_list
+ else:
+ raise RuntimeError(
+ "No CUDA devices detected. Please either:\n"
+ "1. Set environment variable CPMCU_CUDA_ARCH (e.g., export CPMCU_CUDA_ARCH=90), or\n"
+ "2. Ensure CUDA devices are available and properly configured.\n"
+ "Common CUDA architectures: 70 (V100), 75 (T4), 80 (A100), 86 (RTX 30xx), 89 (RTX 40xx), 90 (H100)"
+ )
+ except Exception as e:
+ raise RuntimeError(
+ f"CUDA architecture detection failed: {e}\n"
+ "Please set environment variable CPMCU_CUDA_ARCH (e.g., export CPMCU_CUDA_ARCH=90).\n"
+ "Common CUDA architectures: 70 (V100), 75 (T4), 80 (A100), 86 (RTX 30xx), 89 (RTX 40xx), 90 (H100)"
+ )
+
+def append_nvcc_threads(nvcc_extra_args):
+ nvcc_threads = os.getenv("NVCC_THREADS") or "16"
+ return nvcc_extra_args + ["--threads", nvcc_threads]
+
+def get_compile_args():
+ """Return different compilation arguments based on debug mode"""
+ debug_mode = os.getenv("CPMCU_DEBUG", "0").lower() in ("1", "true", "yes")
+ perf_mode = os.getenv("CPMCU_PERF", "0").lower() in ("1", "true", "yes")
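+    # Typical invocations (assumed, for illustration only):
+    #   CPMCU_DEBUG=1 pip install -e .   -> -O0/-g build with the memory pool disabled
+    #   CPMCU_PERF=1  pip install -e .   -> adds the -DENABLE_PERF define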
+
+ # Check precision type environment variable
+ dtype_env = os.getenv("CPMCU_DTYPE", "fp16").lower()
+
+ # Parse precision type list
+ dtype_list = [dtype.strip() for dtype in dtype_env.split(',')]
+
+ # Validate precision types
+ valid_dtypes = {"fp16", "bf16"}
+ invalid_dtypes = [dtype for dtype in dtype_list if dtype not in valid_dtypes]
+ if invalid_dtypes:
+ raise ValueError(
+ f"Invalid CPMCU_DTYPE values: {invalid_dtypes}. "
+ f"Supported values: 'fp16', 'bf16', 'fp16,bf16'"
+ )
+
+ # Deduplicate and generate compilation definitions
+ dtype_set = set(dtype_list)
+ dtype_defines = []
+ if "fp16" in dtype_set:
+ dtype_defines.append("-DENABLE_DTYPE_FP16")
+ if "bf16" in dtype_set:
+ dtype_defines.append("-DENABLE_DTYPE_BF16")
+
+ # Print compilation information
+ if len(dtype_set) == 1:
+ dtype_name = list(dtype_set)[0].upper()
+ print(f"Compiling with {dtype_name} support only")
+ else:
+ dtype_names = [dtype.upper() for dtype in sorted(dtype_set)]
+ print(f"Compiling with {' and '.join(dtype_names)} support")
+
+ # Common compilation arguments
+ common_cxx_args = ["-std=c++17"] + dtype_defines
+ common_nvcc_args = [
+ "-std=c++17",
+ "-U__CUDA_NO_HALF_OPERATORS__",
+ "-U__CUDA_NO_HALF_CONVERSIONS__",
+ "-U__CUDA_NO_HALF2_OPERATORS__",
+ "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+ "--expt-relaxed-constexpr",
+ "--expt-extended-lambda",
+ ] + dtype_defines
+
+ if debug_mode:
+ print("Debug mode enabled (CPMCU_DEBUG=1)")
+ cxx_args = common_cxx_args + [
+ "-g3", # Most detailed debug information
+ "-O0", # Disable optimization
+ "-DDISABLE_MEMPOOL",
+ "-DDEBUG",
+ "-fno-inline", # Disable inlining
+ "-fno-omit-frame-pointer", # Keep frame pointer
+ ]
+ nvcc_base_args = common_nvcc_args + [
+ "-O0",
+ "-g", # Host-side debug information
+ "-lineinfo", # Generate line number information
+ "-DDISABLE_MEMPOOL",
+ "-DDEBUG",
+ "-DCUDA_DEBUG",
+ "-Xcompiler", "-g3", # Pass to host compiler
+ "-Xcompiler", "-fno-inline", # Disable inlining
+ "-Xcompiler", "-fno-omit-frame-pointer", # Keep frame pointer
+ ]
+ # Debug mode linking arguments
+ link_args = ["-g", "-rdynamic"]
+ else:
+ print("Release mode enabled")
+ cxx_args = common_cxx_args + ["-O3"]
+ nvcc_base_args = common_nvcc_args + [
+ "-O3",
+ "--use_fast_math",
+ ]
+ # Release mode linking arguments
+ link_args = []
+
+ # Add performance testing control
+ if perf_mode:
+ print("Performance monitoring enabled (CPMCU_PERF=1)")
+ cxx_args.append("-DENABLE_PERF")
+ nvcc_base_args.append("-DENABLE_PERF")
+ else:
+ print("Performance monitoring disabled (CPMCU_PERF=0)")
+
+ return cxx_args, nvcc_base_args, link_args, dtype_set
+
+def get_all_headers():
+ """Get all header files for dependency tracking"""
+ header_patterns = [
+ "src/**/*.h",
+ "src/**/*.hpp",
+ "src/**/*.cuh",
+ "src/cutlass/include/**/*.h",
+ "src/cutlass/include/**/*.hpp",
+ "src/flash_attn/**/*.h",
+ "src/flash_attn/**/*.hpp",
+ "src/flash_attn/**/*.cuh",
+ ]
+
+ headers = []
+ for pattern in header_patterns:
+ abs_headers = glob.glob(os.path.join(this_dir, pattern), recursive=True)
+ # Convert to relative paths
+ rel_headers = [os.path.relpath(h, this_dir) for h in abs_headers]
+ headers.extend(rel_headers)
+
+ # Filter out non-existent files (check absolute path but return relative path)
+ headers = [h for h in headers if os.path.exists(os.path.join(this_dir, h))]
+
+ return headers
+
+def get_flash_attn_sources(enabled_dtypes):
+ """Get flash attention source files based on enabled data types"""
+ sources = []
+
+ for dtype in enabled_dtypes:
+ if dtype == "fp16":
+ # sources.extend(glob.glob("src/flash_attn/src/*hdim64_fp16*.cu"))
+ sources.extend(glob.glob("src/flash_attn/src/*hdim128_fp16*.cu"))
+ elif dtype == "bf16":
+ # sources.extend(glob.glob("src/flash_attn/src/*hdim64_bf16*.cu"))
+ sources.extend(glob.glob("src/flash_attn/src/*hdim128_bf16*.cu"))
+
+ return sources
+
+# Try to build extension modules
+ext_modules = []
+cmdclass = {}
+
+try:
+ # Try to import torch-related modules
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+ # Get CUDA architecture
+ arch_list = detect_cuda_arch()
+
+ # Get compilation arguments
+ cxx_args, nvcc_base_args, link_args, dtype_set = get_compile_args()
+
+ # Get header files
+ all_headers = get_all_headers()
+
+ # Generate gencode arguments for each architecture
+ gencode_args = []
+ arch_defines = []
+ for arch in arch_list:
+ gencode_args.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])
+ arch_defines.append(f"-D_ARCH{arch}")
+
+ print(f"Using CUDA architecture compile flags: {arch_list}")
+
+ flash_attn_sources = get_flash_attn_sources(dtype_set)
+
+ # Create Ninja build extension class
+ class NinjaBuildExtension(BuildExtension):
+ def __init__(self, *args, **kwargs) -> None:
+ # do not override env MAX_JOBS if already exists
+ if not os.environ.get("MAX_JOBS"):
+ import psutil
+ # calculate the maximum allowed NUM_JOBS based on cores
+ max_num_jobs_cores = max(1, os.cpu_count() // 2)
+ # calculate the maximum allowed NUM_JOBS based on free memory
+ free_memory_gb = psutil.virtual_memory().available / (1024 ** 3) # free memory in GB
+ max_num_jobs_memory = int(free_memory_gb / 9) # each JOB peak memory cost is ~8-9GB when threads = 4
+ # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
+ max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
+ os.environ["MAX_JOBS"] = str(max_jobs)
+ super().__init__(*args, **kwargs)
+
+ # Configure extension module
+ ext_modules = [
+ CUDAExtension(
+ name='cpmcu.C',
+ sources = [
+ "src/entry.cu",
+ "src/utils.cu",
+ "src/signal_handler.cu",
+ "src/perf.cu",
+ *glob.glob("src/qgemm/gptq_marlin/*cu"),
+ *flash_attn_sources,
+ ],
+ libraries=["cublas", "dl"],
+ depends=all_headers,
+ extra_compile_args={
+ "cxx": cxx_args,
+ "nvcc": append_nvcc_threads(
+ nvcc_base_args +
+ gencode_args +
+ arch_defines + [
+ # Add dependency file generation options
+ "-MMD", "-MP",
+ ]
+ ),
+ },
+ extra_link_args=link_args,
+ include_dirs=[
+ f"{this_dir}/src/flash_attn",
+ f"{this_dir}/src/flash_attn/src",
+ f"{this_dir}/src/cutlass/include",
+ f"{this_dir}/src/",
+ ],
+ )
+ ]
+
+ cmdclass = {'build_ext': NinjaBuildExtension}
+
+except Exception as e:
+ print(f"Warning: Unable to configure CUDA extension module: {e}")
+ print("Skipping extension module build, installing Python package only...")
+
+setup(
+ name='cpmcu',
+ version='1.0.0',
+ author_email="acha131441373@gmail.com",
+ description="cpm cuda implementation",
+ packages=find_packages(),
+ setup_requires=[
+ "pybind11",
+ "psutil",
+ "ninja",
+ "torch",
+ ],
+ install_requires=[
+ "transformers==4.46.2",
+ "accelerate==0.26.0",
+ "datasets",
+ "fschat",
+ "openai",
+ "anthropic",
+ "human_eval",
+ "zstandard",
+ "tree_sitter",
+ "tree-sitter-python"
+ ],
+ ext_modules=ext_modules,
+ cmdclass=cmdclass,
+)
\ No newline at end of file
diff --git a/examples/CPM.cu/src/entry.cu b/examples/CPM.cu/src/entry.cu
new file mode 100644
index 00000000..41b3acc3
--- /dev/null
+++ b/examples/CPM.cu/src/entry.cu
@@ -0,0 +1,534 @@
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <pybind11/pybind11.h>
+
+#include "utils.cuh"
+#include "trait.cuh"
+#include "perf.cuh"
+// base model
+#include "model/model.cuh"
+#include "model/w4a16_gptq_marlin/w4a16_gptq_marlin_model.cuh"
+#include "model/minicpm4/minicpm4_model.cuh"
+#include "model/minicpm4/minicpm4_w4a16_gptq_marlin_model.cuh"
+
+// eagle
+#include "model/eagle.cuh"
+#include "model/minicpm4/minicpm4_eagle.cuh"
+#include "model/eagle_base_quant/eagle_base_w4a16_gptq_marlin.cuh"
+
+// spec model
+#include "model/spec_quant/w4a16_gm_spec_w4a16_gm.cuh"
+
+// hier
+#include "model/hier_spec_quant/hier_ea_w4a16_gm_spec_w4a16_gm.cuh"
+#include "model/hier_spec_quant/hier_ea_w4a16_gm_rot_spec_w4a16_gm.cuh"
+
+
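+// DTYPE_SWITCH maps the integer torch_dtype flag passed from Python onto a concrete
+// CUDA element type at compile time (0 -> __half, otherwise __nv_bfloat16), so the
+// init_* entry points below can instantiate the templated model implementations for
+// whichever precisions were compiled in.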
+#if defined(ENABLE_DTYPE_FP16) && defined(ENABLE_DTYPE_BF16)
+#define DTYPE_SWITCH(COND, ...) \
+ [&] { \
+ if (COND == 0) { \
+ using elem_type = __half; \
+ return __VA_ARGS__(); \
+ } else { \
+ using elem_type = __nv_bfloat16; \
+ return __VA_ARGS__(); \
+ } \
+ }()
+#elif defined(ENABLE_DTYPE_FP16)
+#define DTYPE_SWITCH(COND, ...) \
+ [&] { \
+ if (COND != 0) { \
+ throw std::runtime_error("BF16 support not compiled. Please recompile with CPMCU_DTYPE=bf16 or CPMCU_DTYPE=fp16,bf16"); \
+ } \
+ using elem_type = __half; \
+ return __VA_ARGS__(); \
+ }()
+#elif defined(ENABLE_DTYPE_BF16)
+#define DTYPE_SWITCH(COND, ...) \
+ [&] { \
+ if (COND == 0) { \
+ throw std::runtime_error("FP16 support not compiled. Please recompile with CPMCU_DTYPE=fp16 or CPMCU_DTYPE=fp16,bf16"); \
+ } \
+ using elem_type = __nv_bfloat16; \
+ return __VA_ARGS__(); \
+ }()
+#else
+#error "At least one of ENABLE_DTYPE_FP16 or ENABLE_DTYPE_BF16 must be defined"
+#endif
+
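+// MODEL_TYPE_SWITCH recovers the concrete implementation class behind the type-erased
+// Model* (plain, MiniCPM4, and their W4A16 GPTQ-Marlin variants) so the Eagle and
+// speculative wrappers below can be constructed around the already-initialized base model.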
+#define MODEL_TYPE_SWITCH(MODEL_PTR, T, ...) \
+    [&] { \
+        if (dynamic_cast<MiniCPM4Impl<T>*>(MODEL_PTR)) { \
+            using ModelType = MiniCPM4Impl<T>; \
+            auto* typed_model = static_cast<MiniCPM4Impl<T>*>(MODEL_PTR); \
+            return __VA_ARGS__(); \
+        } else if (dynamic_cast<ModelImpl<T>*>(MODEL_PTR)) { \
+            using ModelType = ModelImpl<T>; \
+            auto* typed_model = static_cast<ModelImpl<T>*>(MODEL_PTR); \
+            return __VA_ARGS__(); \
+        } \
+        else if (dynamic_cast<W4A16GPTQMarlinModelImpl<T>*>(MODEL_PTR)) { \
+            using ModelType = W4A16GPTQMarlinModelImpl<T>; \
+            auto* typed_model = static_cast<W4A16GPTQMarlinModelImpl<T>*>(MODEL_PTR); \
+            return __VA_ARGS__(); \
+        } else if (dynamic_cast<MiniCPM4W4A16GPTQMarlinModelImpl<T>*>(MODEL_PTR)) { \
+            using ModelType = MiniCPM4W4A16GPTQMarlinModelImpl<T>; \
+            auto* typed_model = static_cast<MiniCPM4W4A16GPTQMarlinModelImpl<T>*>(MODEL_PTR); \
+            return __VA_ARGS__(); \
+        } \
+    }()
+
+#define EAGLE_QUANT_SWITCH(COND, T, ...) \
+    [&] { \
+        if (COND == true) { \
+            using LayerType = W4A16GPTQMarlinLayer<T>; \
+            using Fc1Type = W4A16GPTQMarlinLinear<T>; \
+            using Fc2Type = W4A16GPTQMarlinLinear<T>; \
+            return __VA_ARGS__(); \
+        } else { \
+            using LayerType = Layer<T>; \
+            using Fc1Type = Linear<T>; \
+            using Fc2Type = Linear<T>; \
+            return __VA_ARGS__(); \
+        } \
+    }()
+
+Model* model;
+
+void init_base_model(
+ float memory_limit,
+ int vocab_size,
+ int num_hidden_layers,
+ int hidden_size,
+ int intermediate_size,
+ int num_attention_heads,
+ int num_key_value_heads,
+ int head_dim,
+ float rms_norm_eps,
+ int torch_dtype,
+ int chunk_length,
+ float scale_embed,
+ float scale_lmhead,
+ float scale_residual
+) {
+ init_resources();
+
+ DTYPE_SWITCH(torch_dtype, [&] {
+        model = new ModelImpl<elem_type>(
+ memory_limit,
+ vocab_size,
+ num_hidden_layers,
+ hidden_size,
+ intermediate_size,
+ num_attention_heads,
+ num_key_value_heads,
+ head_dim,
+ rms_norm_eps,
+ chunk_length,
+ scale_embed,
+ scale_lmhead,
+ scale_residual
+ );
+ });
+
+}
+
+void init_minicpm4_model(
+ float memory_limit,
+ int vocab_size,
+ int num_hidden_layers,
+ int hidden_size,
+ int intermediate_size,
+ int num_attention_heads,
+ int num_key_value_heads,
+ int head_dim,
+ float rms_norm_eps,
+ int torch_dtype,
+ int chunk_length,
+ float scale_embed,
+ float scale_lmhead,
+ float scale_residual,
+ int sink_window_size,
+ int block_window_size,
+ int sparse_topk_k,
+ int sparse_switch,
+ bool apply_compress_lse
+) {
+ init_resources();
+
+ DTYPE_SWITCH(torch_dtype, [&] {
+        model = new MiniCPM4Impl<elem_type>(
+ memory_limit,
+ vocab_size,
+ num_hidden_layers,
+ hidden_size,
+ intermediate_size,
+ num_attention_heads,
+ num_key_value_heads,
+ head_dim,
+ rms_norm_eps,
+ chunk_length,
+ scale_embed,
+ scale_lmhead,
+ scale_residual,
+ sink_window_size,
+ block_window_size,
+ sparse_topk_k,
+ sparse_switch,
+ apply_compress_lse
+ );
+ });
+
+}
+
+void init_w4a16_gptq_marlin_base_model(
+ float memory_limit,
+ int vocab_size,
+ int num_hidden_layers,
+ int hidden_size,
+ int intermediate_size,
+ int num_attention_heads,
+ int num_key_value_heads,
+ int head_dim,
+ float rms_norm_eps,
+ int group_size,
+ int torch_dtype,
+ int chunk_length,
+ float scale_embed,
+ float scale_lmhead,
+ float scale_residual
+) {
+ init_resources();
+
+ DTYPE_SWITCH(torch_dtype, [&] {
+        model = new W4A16GPTQMarlinModelImpl<elem_type>(
+ memory_limit,
+ vocab_size,
+ num_hidden_layers,
+ hidden_size,
+ intermediate_size,
+ num_attention_heads,
+ num_key_value_heads,
+ head_dim,
+ rms_norm_eps,
+ group_size,
+ chunk_length,
+ scale_embed,
+ scale_lmhead,
+ scale_residual
+ );
+ });
+
+}
+
+void init_w4a16_gptq_marlin_minicpm4_model(
+ float memory_limit,
+ int vocab_size,
+ int num_hidden_layers,
+ int hidden_size,
+ int intermediate_size,
+ int num_attention_heads,
+ int num_key_value_heads,
+ int head_dim,
+ float rms_norm_eps,
+ int group_size,
+ int torch_dtype,
+ int chunk_length,
+ float scale_embed,
+ float scale_lmhead,
+ float scale_residual,
+ int sink_window_size,
+ int block_window_size,
+ int sparse_topk_k,
+ int sparse_switch,
+ bool apply_compress_lse
+) {
+ init_resources();
+
+ DTYPE_SWITCH(torch_dtype, [&] {
+        model = new MiniCPM4W4A16GPTQMarlinModelImpl<elem_type>(
+ memory_limit,
+ vocab_size,
+ num_hidden_layers,
+ hidden_size,
+ intermediate_size,
+ num_attention_heads,
+ num_key_value_heads,
+ head_dim,
+ rms_norm_eps,
+ group_size,
+ chunk_length,
+ scale_embed,
+ scale_lmhead,
+ scale_residual,
+ sink_window_size,
+ block_window_size,
+ sparse_topk_k,
+ sparse_switch,
+ apply_compress_lse
+ );
+ });
+
+}
+
+// eagle model
+void init_eagle_model(
+ int num_layers,
+ int num_iter,
+ int topk_per_iter,
+ int tree_size,
+ int torch_dtype
+) {
+ bool dispatch_model = false;
+ DTYPE_SWITCH(torch_dtype, [&] {
+ MODEL_TYPE_SWITCH(model, elem_type, [&] {
+ dispatch_model = true;
+ model = new EagleImpl(
+ typed_model,
+ num_layers,
+ num_iter,
+ topk_per_iter,
+ tree_size
+ );
+ });
+ });
+ if (!dispatch_model) {
+ printf("Model type failed to dispatch: %s\n", typeid(*model).name());
+ }
+}
+
+void init_minicpm4_eagle_model(
+ int num_layers,
+ int num_iter,
+ int topk_per_iter,
+ int tree_size,
+ int torch_dtype,
+ bool apply_eagle_quant,
+ int group_size,
+ int eagle_window_size,
+ int frspec_vocab_size,
+ float residual_scale,
+ bool use_input_norm,
+ bool use_attn_norm
+) {
+ bool dispatch_model = false;
+ DTYPE_SWITCH(torch_dtype, [&] {
+ MODEL_TYPE_SWITCH(model, elem_type, [&] {
+ dispatch_model = true;
+ EAGLE_QUANT_SWITCH(apply_eagle_quant, elem_type, [&] {
+ model = new MiniCPM4EagleImpl(
+ typed_model,
+ num_layers,
+ num_iter,
+ topk_per_iter,
+ tree_size,
+ group_size,
+ eagle_window_size,
+ frspec_vocab_size,
+ residual_scale,
+ use_input_norm,
+ use_attn_norm
+ );
+ });
+ });
+ });
+ if (!dispatch_model) {
+ printf("Model type failed to dispatch: %s\n", typeid(*model).name());
+ }
+}
+
+// spec model
+void init_w4a16_gm_spec_w4a16_gm_model(
+ int draft_vocab_size,
+ int draft_num_hidden_layers,
+ int draft_hidden_size,
+ int draft_intermediate_size,
+ int draft_num_attention_heads,
+ int draft_num_key_value_heads,
+ int draft_head_dim,
+ float draft_rms_norm_eps,
+ int draft_group_size,
+ int num_iter,
+ bool draft_cuda_graph,
+ int torch_dtype
+) {
+ DTYPE_SWITCH(torch_dtype, [&] {
+        model = new W4A16GMSpecW4A16GMImpl<elem_type>(
+            (W4A16GPTQMarlinModelImpl<elem_type>*)model,
+ draft_vocab_size,
+ draft_num_hidden_layers,
+ draft_hidden_size,
+ draft_intermediate_size,
+ draft_num_attention_heads,
+ draft_num_key_value_heads,
+ draft_head_dim,
+ draft_rms_norm_eps,
+ draft_group_size,
+ num_iter,
+ draft_cuda_graph
+ );
+ });
+}
+
+// hier spec model
+void init_hier_eagle_w4a16_gm_spec_w4a16_gm_model(
+ int draft_vocab_size,
+ int draft_num_hidden_layers,
+ int draft_hidden_size,
+ int draft_intermediate_size,
+ int draft_num_attention_heads,
+ int draft_num_key_value_heads,
+ int draft_head_dim,
+ float draft_rms_norm_eps,
+ int draft_group_size,
+ int min_draft_length,
+ bool draft_cuda_graph,
+ int ea_num_layers,
+ int ea_num_iter,
+ int ea_topk_per_iter,
+ int ea_tree_size,
+ bool draft_model_start,
+ int torch_dtype
+) {
+ DTYPE_SWITCH(torch_dtype, [&] {
+        model = new HierEagleW4A16GMSpecW4A16GMImpl<elem_type>(
+            (W4A16GPTQMarlinModelImpl<elem_type>*)model,
+ draft_vocab_size,
+ draft_num_hidden_layers,
+ draft_hidden_size,
+ draft_intermediate_size,
+ draft_num_attention_heads,
+ draft_num_key_value_heads,
+ draft_head_dim,
+ draft_rms_norm_eps,
+ draft_group_size,
+ min_draft_length,
+ draft_cuda_graph,
+ ea_num_layers,
+ ea_num_iter,
+ ea_topk_per_iter,
+ ea_tree_size,
+ draft_model_start
+ );
+ });
+}
+
+void init_hier_eagle_w4a16_gm_rot_spec_w4a16_gm_model(
+ int draft_vocab_size,
+ int draft_num_hidden_layers,
+ int draft_hidden_size,
+ int draft_intermediate_size,
+ int draft_num_attention_heads,
+ int draft_num_key_value_heads,
+ int draft_head_dim,
+ float draft_rms_norm_eps,
+ int draft_group_size,
+ int min_draft_length,
+ bool draft_cuda_graph,
+ int ea_num_layers,
+ int ea_num_iter,
+ int ea_topk_per_iter,
+ int ea_tree_size,
+ bool draft_model_start,
+ int torch_dtype
+) {
+ DTYPE_SWITCH(torch_dtype, [&] {
+        model = new HierEagleW4A16GMRotSpecW4A16GMImpl<elem_type>(
+            (W4A16GPTQMarlinModelImpl<elem_type>*)model,
+ draft_vocab_size,
+ draft_num_hidden_layers,
+ draft_hidden_size,
+ draft_intermediate_size,
+ draft_num_attention_heads,
+ draft_num_key_value_heads,
+ draft_head_dim,
+ draft_rms_norm_eps,
+ draft_group_size,
+ min_draft_length,
+ draft_cuda_graph,
+ ea_num_layers,
+ ea_num_iter,
+ ea_topk_per_iter,
+ ea_tree_size,
+ draft_model_start
+ );
+ });
+}
+
+
+int init_storage() {
+ return model->init_storage();
+}
+
+void load_model(std::string name, std::uintptr_t param) {
+ model->load_to_storage(name, reinterpret_cast(param));
+}
+
+void prefill(int input_length, int history_length, std::uintptr_t input, std::uintptr_t position_ids, std::uintptr_t output) {
+ model->prefill(input_length, history_length, reinterpret_cast(input), reinterpret_cast(position_ids), (void*)(output));
+}
+
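+// decode() can optionally run under a CUDA graph: the graph is (re)captured whenever
+// the requested input_length/padded_length differ from the previously captured shape,
+// and replayed otherwise to avoid per-step kernel launch overhead during generation.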
+void decode(int input_length, int padded_length, std::uintptr_t input, std::uintptr_t position_ids, std::uintptr_t cache_length, std::uintptr_t mask_2d, std::uintptr_t output, bool cuda_graph) {
+ if (cuda_graph) {
+ if (graphCreated_padding_length != padded_length || graphCreated_input_length != input_length) {
+ if (graphExec != nullptr) {
+ cudaGraphExecDestroy(graphExec);
+ graphExec = nullptr;
+ }
+ if (graph != nullptr) {
+ cudaGraphDestroy(graph);
+ graph = nullptr;
+ }
+ cudaStreamBeginCapture(calc_stream.stream, cudaStreamCaptureModeGlobal);
+ model->decode(input_length, padded_length, reinterpret_cast(input), reinterpret_cast(position_ids), reinterpret_cast(cache_length), reinterpret_cast(mask_2d), reinterpret_cast(output));
+ cudaStreamEndCapture(calc_stream.stream, &graph);
+ cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0);
+ graphCreated_padding_length = padded_length;
+ graphCreated_input_length = input_length;
+ }
+ cudaGraphLaunch(graphExec, calc_stream.stream);
+ } else {
+ model->decode(input_length, padded_length, reinterpret_cast(input), reinterpret_cast(position_ids), reinterpret_cast(cache_length), reinterpret_cast(mask_2d), reinterpret_cast(output));
+ }
+}
+
+void draft(std::uintptr_t tree_draft_ids, std::uintptr_t tree_position_ids, std::uintptr_t cache_length, std::uintptr_t attn_mask, std::uintptr_t tree_parent) {
+ model->draft(reinterpret_cast(tree_draft_ids), reinterpret_cast(tree_position_ids), reinterpret_cast(cache_length), reinterpret_cast(attn_mask), reinterpret_cast(tree_parent));
+}
+
+int verify_and_fix(int num_tokens, std::uintptr_t pred, std::uintptr_t gt, std::uintptr_t position_ids, std::uintptr_t cache_length, std::uintptr_t attn_mask, std::uintptr_t tree_parent) {
+ return model->verify(num_tokens, reinterpret_cast(pred), reinterpret_cast(gt), reinterpret_cast(position_ids), reinterpret_cast(cache_length), reinterpret_cast(attn_mask), reinterpret_cast(tree_parent));
+}
+
+void print_perf_summary() {
+ perf_summary();
+}
+
+PYBIND11_MODULE(C, m) {
+ // base bind
+ m.def("init_base_model", &init_base_model, "Init base model");
+ m.def("init_minicpm4_model", &init_minicpm4_model, "Init minicpm4 model");
+ m.def("init_w4a16_gptq_marlin_base_model", &init_w4a16_gptq_marlin_base_model, "Init W4A16 base model");
+    m.def("init_w4a16_gptq_marlin_minicpm4_model", &init_w4a16_gptq_marlin_minicpm4_model, "Init W4A16 minicpm4 model");
+
+ // eagle bind
+ m.def("init_eagle_model", &init_eagle_model, "Init eagle model");
+ m.def("init_minicpm4_eagle_model", &init_minicpm4_eagle_model, "Init minicpm4 eagle model");
+ // spec bind
+ m.def("init_w4a16_gm_spec_w4a16_gm_model", &init_w4a16_gm_spec_w4a16_gm_model, "Init w4a16 spec v1 model");
+
+ // hier spec bind
+ m.def("init_hier_eagle_w4a16_gm_spec_w4a16_gm_model", &init_hier_eagle_w4a16_gm_spec_w4a16_gm_model, "init hier eagle gm spec gm model");
+ m.def("init_hier_eagle_w4a16_gm_rot_spec_w4a16_gm_model", &init_hier_eagle_w4a16_gm_rot_spec_w4a16_gm_model, "init hier eagle rot gm spec gm model");
+
+ // interface
+ m.def("init_storage", &init_storage, "Init storage");
+ m.def("load_model", &load_model, "Load model");
+ m.def("prefill", &prefill, "Prefill");
+ m.def("decode", &decode, "Decode");
+ m.def("draft", &draft, "Draft");
+ m.def("verify_and_fix", &verify_and_fix, "Verify and fix");
+ m.def("print_perf_summary", &print_perf_summary, "Print perf summary");
+}
\ No newline at end of file
diff --git a/examples/CPM.cu/src/flash_attn/flash_api.hpp b/examples/CPM.cu/src/flash_attn/flash_api.hpp
new file mode 100644
index 00000000..f0ee07e6
--- /dev/null
+++ b/examples/CPM.cu/src/flash_attn/flash_api.hpp
@@ -0,0 +1,392 @@
+/******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+#include
+
+#include "flash.h"
+#include "static_switch.h"
+#include "../model/mask.cuh"
+
+void set_params_fprop(Flash_fwd_params &params,
+ bool is_bf16,
+ // sizes
+ const size_t b,
+ const size_t seqlen_q,
+ const size_t seqlen_k,
+ const size_t seqlen_q_rounded,
+ const size_t seqlen_k_rounded,
+ const size_t h,
+ const size_t h_k,
+ const size_t d,
+ const size_t d_rounded,
+ // device pointers
+ void* q,
+ void* k,
+ void* v,
+ void* out,
+ void *cu_seqlens_q_d,
+ void *cu_seqlens_k_d,
+ void *seqused_k,
+ void *p_d,
+ void *softmax_lse_d,
+ float p_dropout,
+ float softmax_scale,
+ int window_size_left,
+ int window_size_right,
+ const float softcap,
+ bool seqlenq_ngroups_swapped=false,
+ const bool unpadded_lse=false) {
+
+ // Reset the parameters
+ params = {};
+
+ params.is_bf16 = is_bf16;
+
+ // Set the pointers and strides.
+ params.q_ptr = q;
+ params.k_ptr = k;
+ params.v_ptr = v;
+ // All stride are in elements, not bytes.
+ params.q_row_stride = h * d;
+ params.k_row_stride = h_k * d;
+ params.v_row_stride = h_k * d;
+ params.q_head_stride = d;
+ params.k_head_stride = d;
+ params.v_head_stride = d;
+ params.o_ptr = out;
+ params.o_row_stride = h * d;
+ params.o_head_stride = d;
+
+ if (cu_seqlens_q_d == nullptr) {
+ params.q_batch_stride = seqlen_q * h * d;
+ params.k_batch_stride = seqlen_k * h_k * d;
+ params.v_batch_stride = seqlen_k * h_k * d;
+ params.o_batch_stride = seqlen_q * h * d;
+ if (seqlenq_ngroups_swapped) {
+ params.q_batch_stride *= seqlen_q;
+ params.o_batch_stride *= seqlen_q;
+ }
+ }
+
+    params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
+    params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
+    params.seqused_k = static_cast<int *>(seqused_k);
+
+ // P = softmax(QK^T)
+ params.p_ptr = p_d;
+
+ // Softmax sum
+ params.softmax_lse_ptr = softmax_lse_d;
+
+ // Set the dimensions.
+ params.b = b;
+ params.h = h;
+ params.h_k = h_k;
+ params.h_h_k_ratio = h / h_k;
+ params.seqlen_q = seqlen_q;
+ params.seqlen_k = seqlen_k;
+ params.seqlen_q_rounded = seqlen_q_rounded;
+ params.seqlen_k_rounded = seqlen_k_rounded;
+ params.d = d;
+ params.d_rounded = d_rounded;
+
+ // Set the different scale values.
+ if (softcap > 0.0) {
+ params.softcap = softmax_scale / softcap;
+ params.scale_softmax = softcap;
+ params.scale_softmax_log2 = softcap * M_LOG2E;
+ } else{
+ // Remove potential NaN
+ params.softcap = 0.0;
+ params.scale_softmax = softmax_scale;
+ params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+ }
+
+ // Set this to probability of keeping an element to simplify things.
+ params.p_dropout = 1.f - p_dropout;
+ // Convert p from float to int so we don't have to convert the random uint to float to compare.
+ // [Minor] We want to round down since when we do the comparison we use <= instead of <
+ // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
+ // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
+ params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
+ params.rp_dropout = 1.f / params.p_dropout;
+ params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
+
+ // Causal is the special case where window_size_right == 0 and window_size_left < 0.
+ // Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
+ params.is_causal = window_size_left < 0 && window_size_right == 0;
+
+ if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; }
+ if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; }
+ params.window_size_left = window_size_left;
+ params.window_size_right = window_size_right;
+
+ params.is_seqlens_k_cumulative = true;
+
+ params.unpadded_lse = unpadded_lse;
+ params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
+}
+
+void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
+ FP16_SWITCH(!params.is_bf16, [&] {
+ HEADDIM_SWITCH(params.d, [&] {
+ BOOL_SWITCH(params.is_causal, Is_causal, [&] {
+ if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
+                    run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
+                } else {
+                    run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
+ }
+ });
+ });
+ });
+}
+
+// Find the number of splits that maximizes the occupancy. For example, if we have
+// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
+// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
+// splits as that would incur more HBM reads/writes.
+// So we find the best efficiency, then find the smallest number of splits that gets 85%
+// of the best efficiency.
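+// Worked example for the numbers above: with batch * n_heads = 48 and 108 SMs,
+// 2 splits give n_waves = 96/108 ~ 0.89 and efficiency 0.89 / ceil(0.89) = 0.89,
+// while 3 splits give n_waves = 144/108 ~ 1.33 and efficiency 1.33 / 2 ~ 0.67.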
+inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
+ // If we have enough to almost fill the SMs, then just use 1 split
+ if (batch_nheads_mblocks >= 0.8f * num_SMs) { return 1; }
+ max_splits = std::min({max_splits, num_SMs, num_n_blocks});
+ float max_efficiency = 0.f;
+    std::vector<float> efficiency;
+ efficiency.reserve(max_splits);
+ auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
+ // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
+ // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
+ // (i.e. it's 11 splits anyway).
+ // So we check if the number of blocks per split is the same as the previous num_splits.
+ auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
+ return num_splits == 1 || ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
+ };
+ for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+ if (!is_split_eligible(num_splits)) {
+ efficiency.push_back(0.f);
+ } else {
+ float n_waves = float(batch_nheads_mblocks * num_splits) / num_SMs;
+ float eff = n_waves / ceil(n_waves);
+ // printf("num_splits = %d, eff = %f\n", num_splits, eff);
+ if (eff > max_efficiency) { max_efficiency = eff; }
+ efficiency.push_back(eff);
+ }
+ }
+ for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
+ if (!is_split_eligible(num_splits)) { continue; }
+ if (efficiency[num_splits - 1] >= 0.85 * max_efficiency) {
+ // printf("num_splits chosen = %d\n", num_splits);
+ return num_splits;
+ }
+ }
+ return 1;
+}
+
+int set_params_splitkv(const int batch_size, const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q) {
+ cudaDeviceProp dprops;
+ cudaGetDeviceProperties(&dprops, 0);
+
+ // This needs to match with run_mha_fwd_splitkv_dispatch
+ const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
+ const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
+ // Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
+ // In any case we don't expect seqlen_q to be larger than 64 for inference.
+ const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
+
+ // We multiply number of SMs by 2 to hard-code the fact that we're using 128 threads per block.
+ int num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops.multiProcessorCount * 2, num_n_blocks, 128);
+
+ return num_splits;
+}
+
+void mha_fwd_stage1(
+ bool is_bf16,
+ int batch_size,
+ int seqlen_q,
+ int seqlen_k,
+ int seqlen_c,
+ int num_heads,
+ int num_heads_k,
+ int head_size,
+ void* q, // batch_size x seqlen_q x num_heads x head_size
+ void* kcache, // batch_size x seqlen_k x num_heads_k x head_size
+ void* vcache, // batch_size x seqlen_k x num_heads_k x head_size
+ int* seqlens_k, // batch_size
+ void* p, // batch_size x seqlen_q x seqlen_k
+ const float softmax_scale,
+ bool is_causal,
+ int window_size_left,
+ int window_size_right,
+ const float softcap,
+ cudaStream_t stream,
+ int& q_round,
+ int& k_round
+) {
+ // causal=true is the same as causal=false in this case
+ if (seqlen_q == 1) { is_causal = false; }
+ if (is_causal) { window_size_right = 0; }
+
+ seqlen_q *= 16;
+ num_heads /= 16;
+
+ if (window_size_left >= seqlen_k) { window_size_left = -1; }
+ if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+ const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+ const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+ const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+ q_round = seqlen_q_rounded / 16;
+ k_round = seqlen_k_rounded;
+
+ Flash_fwd_params params;
+ set_params_fprop(params,
+ is_bf16,
+ batch_size,
+ seqlen_q, seqlen_k,
+ seqlen_q_rounded, seqlen_k_rounded,
+ num_heads, num_heads_k,
+ head_size, head_size_rounded,
+ q, kcache, vcache, nullptr,
+ /*cu_seqlens_q_d=*/nullptr,
+ /*cu_seqlens_k_d=*/nullptr,
+ /*seqused_k=*/nullptr,
+ /*p_ptr=*/p,
+ /*softmax_lse=*/nullptr,
+ /*p_dropout=*/0.f,
+ softmax_scale,
+ window_size_left,
+ window_size_right,
+ softcap
+ );
+
+ params.blockmask = nullptr;
+ params.block_window_size = 0;
+ params.m_block_dim = 16;
+ params.n_block_dim = 1;
+
+ params.seqlen_c = seqlen_c;
+
+ params.mask_2d = nullptr;
+ params.mask_q_range = 0;
+ params.mask_k_range = 0;
+
+ params.rotary_dim = 0;
+
+ params.page_block_size = 1;
+
+ params.alibi_slopes_ptr = nullptr;
+
+ if (seqlens_k != nullptr) {
+ params.cu_seqlens_k = seqlens_k;
+ params.is_seqlens_k_cumulative = false;
+ }
+ params.num_splits = 1;
+
+ run_mha_fwd(params, stream, true);
+}
+
+void mha_fwd_kvcache(
+ bool is_bf16,
+ int batch_size,
+ int seqlen_q,
+ int seqlen_k,
+ int num_heads,
+ int num_heads_k,
+ int head_size,
+ void* q, // batch_size x seqlen_q x num_heads x head_size
+ void* kcache, // batch_size x seqlen_k x num_heads_k x head_size
+ void* vcache, // batch_size x seqlen_k x num_heads_k x head_size
+ int* seqlens_k, // batch_size
+ const Mask& mask, // batch_size x seqlen_q x seqlen_k_range
+ void* out, // batch_size x seqlen_q x num_heads x head_size
+ float* softmax_lse, // batch_size x num_heads x seqlen_q
+ float* softmax_lse_accum, // num_splits x batch_size x num_heads x seqlen_q
+ float* oaccum, // num_splits x batch_size x num_heads x seqlen_q x head_size
+ const float softmax_scale,
+ bool is_causal,
+ int window_size_left,
+ int window_size_right,
+ const float softcap,
+ cudaStream_t stream,
+ uint64_t* blockmask = nullptr,
+ int block_window_size = 0
+) {
+ // causal=true is the same as causal=false in this case
+ if (seqlen_q == 1) { is_causal = false; }
+ if (is_causal) { window_size_right = 0; }
+
+ if (blockmask != nullptr) { // TODO improve this
+ // if (blockmask != nullptr || block_window_size > 0) {
+ seqlen_q *= 16;
+ num_heads /= 16;
+ }
+
+ if (window_size_left >= seqlen_k) { window_size_left = -1; }
+ if (window_size_right >= seqlen_k) { window_size_right = -1; }
+
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+ const int head_size_rounded = head_size <= 192 ? round_multiple(head_size, 32) : 256;
+ const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+ const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+ Flash_fwd_params params;
+ set_params_fprop(params,
+ is_bf16,
+ batch_size,
+ seqlen_q, seqlen_k,
+ seqlen_q_rounded, seqlen_k_rounded,
+ num_heads, num_heads_k,
+ head_size, head_size_rounded,
+ q, kcache, vcache, out,
+ /*cu_seqlens_q_d=*/nullptr,
+ /*cu_seqlens_k_d=*/nullptr,
+ /*seqused_k=*/nullptr,
+ /*p_ptr=*/nullptr,
+ (void*)softmax_lse,
+ /*p_dropout=*/0.f,
+ softmax_scale,
+ window_size_left,
+ window_size_right,
+ softcap
+ );
+
+ params.blockmask = blockmask;
+ if (blockmask != nullptr) { // TODO improve this
+ // if (blockmask != nullptr || block_window_size > 0) {
+ params.m_block_dim = 16;
+ params.n_block_dim = 64;
+ params.num_blocks_m = (seqlen_q + 16 - 1) / 16;
+ params.num_blocks_n = (seqlen_k + 64 - 1) / 64;
+ } else {
+ params.m_block_dim = 1;
+ params.n_block_dim = 1;
+ }
+ params.block_window_size = block_window_size;
+
+ params.mask_2d = mask.ptr;
+ params.mask_q_range = mask.mask_q_range;
+ params.mask_k_range = mask.mask_k_range;
+
+ params.rotary_dim = 0;
+
+ params.softmax_lseaccum_ptr = (void*)softmax_lse_accum;
+ params.oaccum_ptr = (void*)oaccum;
+
+ params.page_block_size = 1;
+
+ params.alibi_slopes_ptr = nullptr;
+
+ if (seqlens_k != nullptr) {
+ params.cu_seqlens_k = seqlens_k;
+ params.is_seqlens_k_cumulative = false;
+ params.num_splits = 16;
+ } else {
+ params.num_splits = 1;
+ }
+
+ run_mha_fwd(params, stream, true);
+}
\ No newline at end of file
diff --git a/examples/CPM.cu/src/flash_attn/src/alibi.h b/examples/CPM.cu/src/flash_attn/src/alibi.h
new file mode 100644
index 00000000..e714233e
--- /dev/null
+++ b/examples/CPM.cu/src/flash_attn/src/alibi.h
@@ -0,0 +1,74 @@
+#include <cmath>
+
+#include <cute/tensor.hpp>