From c0a3a64a2a3ffa90da96c465f0c4eb156e415d01 Mon Sep 17 00:00:00 2001
From: MayDomine <1583143678@qq.com>
Date: Thu, 31 Jul 2025 18:49:26 +0800
Subject: [PATCH] Add two toolkits (BMTrain, CPM.cu) for training/inference
---
examples/BMTrain/.dockerignore | 147 +
.../.github/ISSUE_TEMPLATE/bug_report.yml | 79 +
.../.github/ISSUE_TEMPLATE/build_err.yml | 94 +
.../ISSUE_TEMPLATE/features_request.yml | 30 +
.../BMTrain/.github/pull_request_template.md | 29 +
examples/BMTrain/.github/workflows/build.yml | 35 +
.../BMTrain/.github/workflows/build_whl.yml | 89 +
.../BMTrain/.github/workflows/publish.yaml | 41 +
.../BMTrain/.github/workflows/release.yml | 44 +
examples/BMTrain/.gitignore | 155 +
examples/BMTrain/CMakeLists.txt | 65 +
examples/BMTrain/CONTRIBUTING.md | 55 +
examples/BMTrain/Dockerfile | 20 +
examples/BMTrain/LICENSE | 201 ++
examples/BMTrain/MANIFEST.in | 4 +
examples/BMTrain/README-ZH.md | 369 +++
examples/BMTrain/README.md | 375 +++
examples/BMTrain/Release.txt | 9 +
examples/BMTrain/bmtrain/__init__.py | 26 +
.../BMTrain/bmtrain/benchmark/__init__.py | 3 +
.../BMTrain/bmtrain/benchmark/all_gather.py | 28 +
.../bmtrain/benchmark/reduce_scatter.py | 28 +
.../BMTrain/bmtrain/benchmark/send_recv.py | 31 +
examples/BMTrain/bmtrain/benchmark/shape.py | 3 +
examples/BMTrain/bmtrain/benchmark/utils.py | 11 +
examples/BMTrain/bmtrain/block_layer.py | 726 +++++
examples/BMTrain/bmtrain/debug.py | 34 +
.../BMTrain/bmtrain/distributed/__init__.py | 1 +
examples/BMTrain/bmtrain/distributed/ops.py | 223 ++
examples/BMTrain/bmtrain/global_var.py | 35 +
examples/BMTrain/bmtrain/hook_func.py | 121 +
examples/BMTrain/bmtrain/init.py | 258 ++
examples/BMTrain/bmtrain/inspect/__init__.py | 3 +
examples/BMTrain/bmtrain/inspect/format.py | 64 +
examples/BMTrain/bmtrain/inspect/model.py | 246 ++
examples/BMTrain/bmtrain/inspect/tensor.py | 383 +++
examples/BMTrain/bmtrain/layer.py | 143 +
examples/BMTrain/bmtrain/loss/__init__.py | 1 +
examples/BMTrain/bmtrain/loss/_function.py | 182 ++
.../BMTrain/bmtrain/loss/cross_entropy.py | 260 ++
.../BMTrain/bmtrain/lr_scheduler/__init__.py | 6 +
.../BMTrain/bmtrain/lr_scheduler/cosine.py | 18 +
.../bmtrain/lr_scheduler/exponential.py | 20 +
.../BMTrain/bmtrain/lr_scheduler/linear.py | 19 +
.../BMTrain/bmtrain/lr_scheduler/no_decay.py | 14 +
examples/BMTrain/bmtrain/lr_scheduler/noam.py | 15 +
.../BMTrain/bmtrain/lr_scheduler/warmup.py | 72 +
examples/BMTrain/bmtrain/nccl/__init__.py | 336 +++
examples/BMTrain/bmtrain/nccl/enums.py | 27 +
examples/BMTrain/bmtrain/nn/__init__.py | 5 +
.../bmtrain/nn/column_parallel_linear.py | 80 +
examples/BMTrain/bmtrain/nn/linear.py | 56 +
.../BMTrain/bmtrain/nn/parallel_embedding.py | 59 +
.../bmtrain/nn/parallel_linear_func.py | 352 +++
.../BMTrain/bmtrain/nn/row_parallel_linear.py | 88 +
examples/BMTrain/bmtrain/optim/__init__.py | 3 +
.../BMTrain/bmtrain/optim/_distributed.py | 40 +
examples/BMTrain/bmtrain/optim/_function.py | 218 ++
examples/BMTrain/bmtrain/optim/adam.py | 252 ++
.../BMTrain/bmtrain/optim/adam_offload.py | 386 +++
.../BMTrain/bmtrain/optim/optim_manager.py | 226 ++
examples/BMTrain/bmtrain/param_init.py | 105 +
examples/BMTrain/bmtrain/parameter.py | 206 ++
examples/BMTrain/bmtrain/pipe_layer.py | 314 +++
examples/BMTrain/bmtrain/store.py | 325 +++
examples/BMTrain/bmtrain/synchronize.py | 73 +
examples/BMTrain/bmtrain/utils.py | 184 ++
examples/BMTrain/bmtrain/wrapper.py | 54 +
examples/BMTrain/bmtrain/zero_context.py | 203 ++
examples/BMTrain/cmake/FindNCCL.cmake | 100 +
examples/BMTrain/csrc/bind.cpp | 35 +
examples/BMTrain/csrc/cuda/adam_cuda.cu | 126 +
examples/BMTrain/csrc/cuda/bfloat16.cuh | 5 +
examples/BMTrain/csrc/cuda/cross_entropy.cu | 315 +++
examples/BMTrain/csrc/cuda/has_inf_nan.cu | 145 +
examples/BMTrain/csrc/cuda/reduce.cuh | 114 +
examples/BMTrain/csrc/include/adam_cpu.hpp | 557 ++++
examples/BMTrain/csrc/include/bind.hpp | 111 +
examples/BMTrain/csrc/include/cpu_info.h | 38 +
examples/BMTrain/csrc/include/nccl.hpp | 188 ++
examples/BMTrain/doc_requirements.txt | 5 +
examples/BMTrain/docs/Makefile | 20 +
examples/BMTrain/docs/UPDATE_0.2.0.md | 79 +
examples/BMTrain/docs/UPDATE_0.2.3.md | 26 +
examples/BMTrain/docs/UPDATE_1.0.0.md | 72 +
examples/BMTrain/docs/logo.png | Bin 0 -> 47918 bytes
examples/BMTrain/docs/make.bat | 35 +
.../docs/source-en/_static/css/custom.css | 124 +
.../docs/source-en/_static/js/custom.js | 7 +
.../source-en/api/bmtrain.benchmark.rst_bk | 53 +
.../source-en/api/bmtrain.distributed.rst_bk | 21 +
.../docs/source-en/api/bmtrain.inspect.rst | 37 +
.../docs/source-en/api/bmtrain.loss.rst | 21 +
.../source-en/api/bmtrain.lr_scheduler.rst | 61 +
.../docs/source-en/api/bmtrain.nccl.rst_bk | 21 +
.../BMTrain/docs/source-en/api/bmtrain.nn.rst | 53 +
.../docs/source-en/api/bmtrain.optim.rst | 37 +
.../BMTrain/docs/source-en/api/bmtrain.rst | 140 +
.../BMTrain/docs/source-en/api/modules.rst | 7 +
examples/BMTrain/docs/source-en/conf.py | 79 +
examples/BMTrain/docs/source-en/index.rst | 39 +
.../docs/source-en/notes/image/ZeRO3.png | Bin 0 -> 79501 bytes
.../notes/image/communication_example.png | Bin 0 -> 805075 bytes
.../notes/image/communication_fig.png | Bin 0 -> 46837 bytes
.../docs/source-en/notes/image/cpu.png | Bin 0 -> 89060 bytes
.../source-en/notes/image/zero3_example.png | Bin 0 -> 795194 bytes
.../docs/source-en/notes/installation.md | 45 +
.../docs/source-en/notes/quickstart.md | 159 ++
examples/BMTrain/docs/source-en/notes/tech.md | 11 +
.../docs/source/_static/css/custom.css | 124 +
.../BMTrain/docs/source/_static/js/custom.js | 7 +
.../docs/source/api/bmtrain.benchmark.rst_bk | 53 +
.../source/api/bmtrain.distributed.rst_bk | 21 +
.../docs/source/api/bmtrain.inspect.rst | 37 +
.../BMTrain/docs/source/api/bmtrain.loss.rst | 21 +
.../docs/source/api/bmtrain.lr_scheduler.rst | 61 +
.../docs/source/api/bmtrain.nccl.rst_bk | 21 +
.../BMTrain/docs/source/api/bmtrain.nn.rst | 53 +
.../BMTrain/docs/source/api/bmtrain.optim.rst | 37 +
examples/BMTrain/docs/source/api/bmtrain.rst | 140 +
examples/BMTrain/docs/source/api/modules.rst | 7 +
examples/BMTrain/docs/source/conf.py | 72 +
examples/BMTrain/docs/source/index.rst | 39 +
.../BMTrain/docs/source/notes/image/ZeRO3.png | Bin 0 -> 79501 bytes
.../notes/image/communication_example.png | Bin 0 -> 805075 bytes
.../source/notes/image/communication_fig.png | Bin 0 -> 46837 bytes
.../BMTrain/docs/source/notes/image/cpu.png | Bin 0 -> 89060 bytes
.../docs/source/notes/image/zero3_example.png | Bin 0 -> 795194 bytes
.../BMTrain/docs/source/notes/installation.md | 45 +
.../BMTrain/docs/source/notes/quickstart.md | 146 +
examples/BMTrain/docs/source/notes/tech.md | 11 +
examples/BMTrain/example/README.md | 5 +
examples/BMTrain/example/benchmark.py | 12 +
examples/BMTrain/example/layers/__init__.py | 5 +
examples/BMTrain/example/layers/attention.py | 118 +
examples/BMTrain/example/layers/embedding.py | 102 +
.../BMTrain/example/layers/feedforward.py | 23 +
examples/BMTrain/example/layers/layernorm.py | 34 +
.../BMTrain/example/layers/transformer.py | 34 +
examples/BMTrain/example/models/__init__.py | 1 +
examples/BMTrain/example/models/gpt.py | 64 +
examples/BMTrain/example/run.sh | 3 +
examples/BMTrain/example/sbatch.sh | 20 +
examples/BMTrain/example/train.py | 138 +
examples/BMTrain/other_requirements.txt | 6 +
examples/BMTrain/pyproject.toml | 8 +
examples/BMTrain/setup.py | 113 +
examples/BMTrain/tests/test_all.py | 43 +
.../tests/test_column_parallel_linear.py | 74 +
.../tests/test_different_output_shape.py | 50 +
examples/BMTrain/tests/test_dropout.py | 45 +
examples/BMTrain/tests/test_grad_accu.py | 80 +
examples/BMTrain/tests/test_has_inf_nan.py | 37 +
.../BMTrain/tests/test_init_parameters.py | 223 ++
.../tests/test_init_parameters_multi_gpu.py | 146 +
.../BMTrain/tests/test_inspector_hidden.py | 241 ++
examples/BMTrain/tests/test_load_ckpt.py | 78 +
examples/BMTrain/tests/test_loss_func.py | 79 +
examples/BMTrain/tests/test_middle_hidden.py | 212 ++
examples/BMTrain/tests/test_model_wrapper.py | 221 ++
examples/BMTrain/tests/test_multi_return.py | 126 +
examples/BMTrain/tests/test_nccl_backward.py | 43 +
examples/BMTrain/tests/test_no_grad.py | 90 +
examples/BMTrain/tests/test_optim.py | 94 +
examples/BMTrain/tests/test_optim_state.py | 135 +
examples/BMTrain/tests/test_other_hidden.py | 189 ++
.../BMTrain/tests/test_parallel_projection.py | 55 +
examples/BMTrain/tests/test_requires_grad.py | 107 +
.../tests/test_requires_grad_multi_gpu.py | 96 +
.../BMTrain/tests/test_row_parallel_linear.py | 54 +
examples/BMTrain/tests/test_send_recv.py | 22 +
examples/BMTrain/tests/test_store.py | 13 +
examples/BMTrain/tests/test_synchronize.py | 26 +
examples/BMTrain/tests/test_training.py | 516 ++++
examples/BMTrain/tests/utils.py | 14 +
examples/CPM.cu/.arsync | 10 +
examples/CPM.cu/.gitignore | 222 ++
examples/CPM.cu/.gitmodules | 3 +
examples/CPM.cu/LICENSE | 201 ++
examples/CPM.cu/README.md | 167 ++
examples/CPM.cu/README_ZH.md | 165 ++
examples/CPM.cu/cpmcu/__init__.py | 0
examples/CPM.cu/cpmcu/llm.py | 422 +++
.../CPM.cu/cpmcu/llm_w4a16_gptq_marlin.py | 434 +++
examples/CPM.cu/cpmcu/speculative/__init__.py | 1 +
examples/CPM.cu/cpmcu/speculative/eagle.py | 99 +
.../speculative/eagle_base_quant/__init__.py | 0
.../eagle_base_w4a16_marlin_gptq.py | 103 +
.../speculative/hier_spec_quant/__init__.py | 0
.../hier_eagle_w4a16_gm_spec_w4a16_gm.py | 268 ++
.../cpmcu/speculative/spec_quant/__init__.py | 0
.../spec_w4a16_gm_for_w4a16_gm_model.py | 160 ++
.../CPM.cu/cpmcu/speculative/tree_drafter.py | 255 ++
.../tree_drafter_base_quant/__init__.py | 0
.../tree_drafter_w4a16_gptq_marlin.py | 240 ++
.../model_convert/convert_llama_format.py | 44 +
.../CPM.cu/model_convert/convert_w4a16.py | 287 ++
.../model_convert/post_process_w4a16_eagle.py | 48 +
.../CPM.cu/scripts/fr_spec/gen_fr_index.py | 89 +
.../scripts/model_convert/convert_w4a16.sh | 9 +
examples/CPM.cu/setup.py | 317 +++
examples/CPM.cu/src/entry.cu | 534 ++++
examples/CPM.cu/src/flash_attn/flash_api.hpp | 392 +++
examples/CPM.cu/src/flash_attn/src/alibi.h | 74 +
.../CPM.cu/src/flash_attn/src/block_info.h | 54 +
examples/CPM.cu/src/flash_attn/src/dropout.h | 94 +
examples/CPM.cu/src/flash_attn/src/flash.h | 193 ++
.../src/flash_attn/src/flash_blockmask.h | 108 +
.../src/flash_fwd_hdim128_bf16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim128_bf16_sm80.cu | 10 +
.../src/flash_fwd_hdim128_fp16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim128_fp16_sm80.cu | 10 +
.../src/flash_fwd_hdim160_bf16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim160_bf16_sm80.cu | 10 +
.../src/flash_fwd_hdim160_fp16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim160_fp16_sm80.cu | 10 +
.../src/flash_fwd_hdim192_bf16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim192_bf16_sm80.cu | 10 +
.../src/flash_fwd_hdim192_fp16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim192_fp16_sm80.cu | 10 +
.../src/flash_fwd_hdim224_bf16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim224_bf16_sm80.cu | 10 +
.../src/flash_fwd_hdim224_fp16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim224_fp16_sm80.cu | 10 +
.../src/flash_fwd_hdim256_bf16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim256_bf16_sm80.cu | 10 +
.../src/flash_fwd_hdim256_fp16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim256_fp16_sm80.cu | 10 +
.../src/flash_fwd_hdim32_bf16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim32_bf16_sm80.cu | 10 +
.../src/flash_fwd_hdim32_fp16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim32_fp16_sm80.cu | 10 +
.../src/flash_fwd_hdim64_bf16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim64_bf16_sm80.cu | 10 +
.../src/flash_fwd_hdim64_fp16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim64_fp16_sm80.cu | 10 +
.../src/flash_fwd_hdim96_bf16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim96_bf16_sm80.cu | 10 +
.../src/flash_fwd_hdim96_fp16_causal_sm80.cu | 10 +
.../src/flash_fwd_hdim96_fp16_sm80.cu | 10 +
.../src/flash_attn/src/flash_fwd_kernel.h | 2503 +++++++++++++++++
.../src/flash_fwd_launch_template.h | 382 +++
...lash_fwd_split_hdim128_bf16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim128_bf16_sm80.cu | 7 +
...lash_fwd_split_hdim128_fp16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim128_fp16_sm80.cu | 7 +
...lash_fwd_split_hdim160_bf16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim160_bf16_sm80.cu | 7 +
...lash_fwd_split_hdim160_fp16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim160_fp16_sm80.cu | 7 +
...lash_fwd_split_hdim192_bf16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim192_bf16_sm80.cu | 7 +
...lash_fwd_split_hdim192_fp16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim192_fp16_sm80.cu | 7 +
...lash_fwd_split_hdim224_bf16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim224_bf16_sm80.cu | 7 +
...lash_fwd_split_hdim224_fp16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim224_fp16_sm80.cu | 7 +
...lash_fwd_split_hdim256_bf16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim256_bf16_sm80.cu | 7 +
...lash_fwd_split_hdim256_fp16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim256_fp16_sm80.cu | 7 +
...flash_fwd_split_hdim32_bf16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim32_bf16_sm80.cu | 7 +
...flash_fwd_split_hdim32_fp16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim32_fp16_sm80.cu | 7 +
...flash_fwd_split_hdim64_bf16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim64_bf16_sm80.cu | 7 +
...flash_fwd_split_hdim64_fp16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim64_fp16_sm80.cu | 7 +
...flash_fwd_split_hdim96_bf16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim96_bf16_sm80.cu | 7 +
...flash_fwd_split_hdim96_fp16_causal_sm80.cu | 7 +
.../src/flash_fwd_split_hdim96_fp16_sm80.cu | 7 +
.../src/flash_attn/src/generate_kernels.py | 108 +
.../CPM.cu/src/flash_attn/src/kernel_traits.h | 349 +++
examples/CPM.cu/src/flash_attn/src/mask.h | 328 +++
examples/CPM.cu/src/flash_attn/src/philox.cuh | 51 +
examples/CPM.cu/src/flash_attn/src/rotary.h | 152 +
examples/CPM.cu/src/flash_attn/src/softmax.h | 259 ++
.../CPM.cu/src/flash_attn/src/static_switch.h | 142 +
examples/CPM.cu/src/flash_attn/src/utils.h | 411 +++
examples/CPM.cu/src/model/activation.cuh | 83 +
examples/CPM.cu/src/model/attn.cuh | 235 ++
examples/CPM.cu/src/model/drafter.cuh | 48 +
examples/CPM.cu/src/model/eagle.cuh | 517 ++++
.../eagle_base_w4a16_gptq_marlin.cuh | 260 ++
examples/CPM.cu/src/model/elementwise.cuh | 87 +
examples/CPM.cu/src/model/embedding.cuh | 53 +
examples/CPM.cu/src/model/ffn.cuh | 91 +
.../hier_ea_w4a16_gm_rot_spec_w4a16_gm.cuh | 706 +++++
.../hier_ea_w4a16_gm_spec_w4a16_gm.cuh | 662 +++++
examples/CPM.cu/src/model/kvcache.cuh | 65 +
examples/CPM.cu/src/model/layer.cuh | 90 +
examples/CPM.cu/src/model/linear.cuh | 101 +
examples/CPM.cu/src/model/mask.cuh | 18 +
examples/CPM.cu/src/model/memory.cuh | 183 ++
.../src/model/minicpm4/minicpm4_attn.cuh | 324 +++
.../src/model/minicpm4/minicpm4_eagle.cuh | 419 +++
.../src/model/minicpm4/minicpm4_kvcache.cuh | 316 +++
.../src/model/minicpm4/minicpm4_layer.cuh | 85 +
.../src/model/minicpm4/minicpm4_model.cuh | 160 ++
.../minicpm4_w4a16_gptq_marlin_attn.cuh | 319 +++
.../minicpm4_w4a16_gptq_marlin_layer.cuh | 98 +
.../minicpm4_w4a16_gptq_marlin_model.cuh | 161 ++
examples/CPM.cu/src/model/model.cuh | 174 ++
examples/CPM.cu/src/model/norm.cuh | 154 +
examples/CPM.cu/src/model/rotary.cuh | 68 +
.../spec_quant/w4a16_gm_spec_w4a16_gm.cuh | 247 ++
examples/CPM.cu/src/model/topk.cuh | 292 ++
examples/CPM.cu/src/model/tree_drafter.cuh | 111 +
.../w4a16_gptq_marlin_attn.cuh | 178 ++
.../w4a16_gptq_marlin_ffn.cuh | 70 +
.../w4a16_gptq_marlin_layer.cuh | 98 +
.../w4a16_gptq_marlin_linear.cuh | 143 +
.../w4a16_gptq_marlin_model.cuh | 157 ++
examples/CPM.cu/src/perf.cu | 9 +
examples/CPM.cu/src/perf.cuh | 291 ++
.../qgemm/gptq_marlin/core/scalar_type.hpp | 348 +++
.../src/qgemm/gptq_marlin/gptq_marlin.cu | 2175 ++++++++++++++
.../src/qgemm/gptq_marlin/gptq_marlin.cuh | 27 +
.../CPM.cu/src/qgemm/gptq_marlin/marlin.cuh | 87 +
.../src/qgemm/gptq_marlin/marlin_dtypes.cuh | 75 +
examples/CPM.cu/src/signal_handler.cu | 126 +
examples/CPM.cu/src/signal_handler.cuh | 26 +
examples/CPM.cu/src/trait.cuh | 37 +
examples/CPM.cu/src/utils.cu | 25 +
examples/CPM.cu/src/utils.cuh | 60 +
examples/CPM.cu/src/utilsq.cuh | 421 +++
examples/CPM.cu/tests/long_prompt_gen.py | 110 +
examples/CPM.cu/tests/test_generate.py | 544 ++++
331 files changed, 37339 insertions(+)
create mode 100644 examples/BMTrain/.dockerignore
create mode 100644 examples/BMTrain/.github/ISSUE_TEMPLATE/bug_report.yml
create mode 100644 examples/BMTrain/.github/ISSUE_TEMPLATE/build_err.yml
create mode 100644 examples/BMTrain/.github/ISSUE_TEMPLATE/features_request.yml
create mode 100644 examples/BMTrain/.github/pull_request_template.md
create mode 100644 examples/BMTrain/.github/workflows/build.yml
create mode 100644 examples/BMTrain/.github/workflows/build_whl.yml
create mode 100644 examples/BMTrain/.github/workflows/publish.yaml
create mode 100644 examples/BMTrain/.github/workflows/release.yml
create mode 100644 examples/BMTrain/.gitignore
create mode 100644 examples/BMTrain/CMakeLists.txt
create mode 100644 examples/BMTrain/CONTRIBUTING.md
create mode 100644 examples/BMTrain/Dockerfile
create mode 100644 examples/BMTrain/LICENSE
create mode 100644 examples/BMTrain/MANIFEST.in
create mode 100644 examples/BMTrain/README-ZH.md
create mode 100644 examples/BMTrain/README.md
create mode 100644 examples/BMTrain/Release.txt
create mode 100644 examples/BMTrain/bmtrain/__init__.py
create mode 100644 examples/BMTrain/bmtrain/benchmark/__init__.py
create mode 100644 examples/BMTrain/bmtrain/benchmark/all_gather.py
create mode 100644 examples/BMTrain/bmtrain/benchmark/reduce_scatter.py
create mode 100644 examples/BMTrain/bmtrain/benchmark/send_recv.py
create mode 100644 examples/BMTrain/bmtrain/benchmark/shape.py
create mode 100644 examples/BMTrain/bmtrain/benchmark/utils.py
create mode 100644 examples/BMTrain/bmtrain/block_layer.py
create mode 100644 examples/BMTrain/bmtrain/debug.py
create mode 100644 examples/BMTrain/bmtrain/distributed/__init__.py
create mode 100644 examples/BMTrain/bmtrain/distributed/ops.py
create mode 100644 examples/BMTrain/bmtrain/global_var.py
create mode 100644 examples/BMTrain/bmtrain/hook_func.py
create mode 100644 examples/BMTrain/bmtrain/init.py
create mode 100644 examples/BMTrain/bmtrain/inspect/__init__.py
create mode 100644 examples/BMTrain/bmtrain/inspect/format.py
create mode 100644 examples/BMTrain/bmtrain/inspect/model.py
create mode 100644 examples/BMTrain/bmtrain/inspect/tensor.py
create mode 100644 examples/BMTrain/bmtrain/layer.py
create mode 100644 examples/BMTrain/bmtrain/loss/__init__.py
create mode 100644 examples/BMTrain/bmtrain/loss/_function.py
create mode 100644 examples/BMTrain/bmtrain/loss/cross_entropy.py
create mode 100644 examples/BMTrain/bmtrain/lr_scheduler/__init__.py
create mode 100644 examples/BMTrain/bmtrain/lr_scheduler/cosine.py
create mode 100644 examples/BMTrain/bmtrain/lr_scheduler/exponential.py
create mode 100644 examples/BMTrain/bmtrain/lr_scheduler/linear.py
create mode 100644 examples/BMTrain/bmtrain/lr_scheduler/no_decay.py
create mode 100644 examples/BMTrain/bmtrain/lr_scheduler/noam.py
create mode 100644 examples/BMTrain/bmtrain/lr_scheduler/warmup.py
create mode 100644 examples/BMTrain/bmtrain/nccl/__init__.py
create mode 100644 examples/BMTrain/bmtrain/nccl/enums.py
create mode 100644 examples/BMTrain/bmtrain/nn/__init__.py
create mode 100644 examples/BMTrain/bmtrain/nn/column_parallel_linear.py
create mode 100644 examples/BMTrain/bmtrain/nn/linear.py
create mode 100644 examples/BMTrain/bmtrain/nn/parallel_embedding.py
create mode 100644 examples/BMTrain/bmtrain/nn/parallel_linear_func.py
create mode 100644 examples/BMTrain/bmtrain/nn/row_parallel_linear.py
create mode 100644 examples/BMTrain/bmtrain/optim/__init__.py
create mode 100644 examples/BMTrain/bmtrain/optim/_distributed.py
create mode 100644 examples/BMTrain/bmtrain/optim/_function.py
create mode 100644 examples/BMTrain/bmtrain/optim/adam.py
create mode 100644 examples/BMTrain/bmtrain/optim/adam_offload.py
create mode 100644 examples/BMTrain/bmtrain/optim/optim_manager.py
create mode 100644 examples/BMTrain/bmtrain/param_init.py
create mode 100644 examples/BMTrain/bmtrain/parameter.py
create mode 100644 examples/BMTrain/bmtrain/pipe_layer.py
create mode 100644 examples/BMTrain/bmtrain/store.py
create mode 100644 examples/BMTrain/bmtrain/synchronize.py
create mode 100644 examples/BMTrain/bmtrain/utils.py
create mode 100644 examples/BMTrain/bmtrain/wrapper.py
create mode 100644 examples/BMTrain/bmtrain/zero_context.py
create mode 100644 examples/BMTrain/cmake/FindNCCL.cmake
create mode 100644 examples/BMTrain/csrc/bind.cpp
create mode 100644 examples/BMTrain/csrc/cuda/adam_cuda.cu
create mode 100644 examples/BMTrain/csrc/cuda/bfloat16.cuh
create mode 100644 examples/BMTrain/csrc/cuda/cross_entropy.cu
create mode 100644 examples/BMTrain/csrc/cuda/has_inf_nan.cu
create mode 100644 examples/BMTrain/csrc/cuda/reduce.cuh
create mode 100644 examples/BMTrain/csrc/include/adam_cpu.hpp
create mode 100644 examples/BMTrain/csrc/include/bind.hpp
create mode 100644 examples/BMTrain/csrc/include/cpu_info.h
create mode 100644 examples/BMTrain/csrc/include/nccl.hpp
create mode 100644 examples/BMTrain/doc_requirements.txt
create mode 100644 examples/BMTrain/docs/Makefile
create mode 100644 examples/BMTrain/docs/UPDATE_0.2.0.md
create mode 100644 examples/BMTrain/docs/UPDATE_0.2.3.md
create mode 100644 examples/BMTrain/docs/UPDATE_1.0.0.md
create mode 100644 examples/BMTrain/docs/logo.png
create mode 100644 examples/BMTrain/docs/make.bat
create mode 100644 examples/BMTrain/docs/source-en/_static/css/custom.css
create mode 100644 examples/BMTrain/docs/source-en/_static/js/custom.js
create mode 100644 examples/BMTrain/docs/source-en/api/bmtrain.benchmark.rst_bk
create mode 100644 examples/BMTrain/docs/source-en/api/bmtrain.distributed.rst_bk
create mode 100644 examples/BMTrain/docs/source-en/api/bmtrain.inspect.rst
create mode 100644 examples/BMTrain/docs/source-en/api/bmtrain.loss.rst
create mode 100644 examples/BMTrain/docs/source-en/api/bmtrain.lr_scheduler.rst
create mode 100644 examples/BMTrain/docs/source-en/api/bmtrain.nccl.rst_bk
create mode 100644 examples/BMTrain/docs/source-en/api/bmtrain.nn.rst
create mode 100644 examples/BMTrain/docs/source-en/api/bmtrain.optim.rst
create mode 100644 examples/BMTrain/docs/source-en/api/bmtrain.rst
create mode 100644 examples/BMTrain/docs/source-en/api/modules.rst
create mode 100644 examples/BMTrain/docs/source-en/conf.py
create mode 100644 examples/BMTrain/docs/source-en/index.rst
create mode 100644 examples/BMTrain/docs/source-en/notes/image/ZeRO3.png
create mode 100644 examples/BMTrain/docs/source-en/notes/image/communication_example.png
create mode 100644 examples/BMTrain/docs/source-en/notes/image/communication_fig.png
create mode 100644 examples/BMTrain/docs/source-en/notes/image/cpu.png
create mode 100644 examples/BMTrain/docs/source-en/notes/image/zero3_example.png
create mode 100644 examples/BMTrain/docs/source-en/notes/installation.md
create mode 100644 examples/BMTrain/docs/source-en/notes/quickstart.md
create mode 100644 examples/BMTrain/docs/source-en/notes/tech.md
create mode 100644 examples/BMTrain/docs/source/_static/css/custom.css
create mode 100644 examples/BMTrain/docs/source/_static/js/custom.js
create mode 100644 examples/BMTrain/docs/source/api/bmtrain.benchmark.rst_bk
create mode 100644 examples/BMTrain/docs/source/api/bmtrain.distributed.rst_bk
create mode 100644 examples/BMTrain/docs/source/api/bmtrain.inspect.rst
create mode 100644 examples/BMTrain/docs/source/api/bmtrain.loss.rst
create mode 100644 examples/BMTrain/docs/source/api/bmtrain.lr_scheduler.rst
create mode 100644 examples/BMTrain/docs/source/api/bmtrain.nccl.rst_bk
create mode 100644 examples/BMTrain/docs/source/api/bmtrain.nn.rst
create mode 100644 examples/BMTrain/docs/source/api/bmtrain.optim.rst
create mode 100644 examples/BMTrain/docs/source/api/bmtrain.rst
create mode 100644 examples/BMTrain/docs/source/api/modules.rst
create mode 100644 examples/BMTrain/docs/source/conf.py
create mode 100644 examples/BMTrain/docs/source/index.rst
create mode 100644 examples/BMTrain/docs/source/notes/image/ZeRO3.png
create mode 100644 examples/BMTrain/docs/source/notes/image/communication_example.png
create mode 100644 examples/BMTrain/docs/source/notes/image/communication_fig.png
create mode 100644 examples/BMTrain/docs/source/notes/image/cpu.png
create mode 100644 examples/BMTrain/docs/source/notes/image/zero3_example.png
create mode 100644 examples/BMTrain/docs/source/notes/installation.md
create mode 100644 examples/BMTrain/docs/source/notes/quickstart.md
create mode 100644 examples/BMTrain/docs/source/notes/tech.md
create mode 100644 examples/BMTrain/example/README.md
create mode 100644 examples/BMTrain/example/benchmark.py
create mode 100644 examples/BMTrain/example/layers/__init__.py
create mode 100644 examples/BMTrain/example/layers/attention.py
create mode 100644 examples/BMTrain/example/layers/embedding.py
create mode 100644 examples/BMTrain/example/layers/feedforward.py
create mode 100644 examples/BMTrain/example/layers/layernorm.py
create mode 100644 examples/BMTrain/example/layers/transformer.py
create mode 100644 examples/BMTrain/example/models/__init__.py
create mode 100644 examples/BMTrain/example/models/gpt.py
create mode 100644 examples/BMTrain/example/run.sh
create mode 100644 examples/BMTrain/example/sbatch.sh
create mode 100644 examples/BMTrain/example/train.py
create mode 100644 examples/BMTrain/other_requirements.txt
create mode 100644 examples/BMTrain/pyproject.toml
create mode 100644 examples/BMTrain/setup.py
create mode 100644 examples/BMTrain/tests/test_all.py
create mode 100644 examples/BMTrain/tests/test_column_parallel_linear.py
create mode 100644 examples/BMTrain/tests/test_different_output_shape.py
create mode 100644 examples/BMTrain/tests/test_dropout.py
create mode 100644 examples/BMTrain/tests/test_grad_accu.py
create mode 100644 examples/BMTrain/tests/test_has_inf_nan.py
create mode 100644 examples/BMTrain/tests/test_init_parameters.py
create mode 100644 examples/BMTrain/tests/test_init_parameters_multi_gpu.py
create mode 100644 examples/BMTrain/tests/test_inspector_hidden.py
create mode 100644 examples/BMTrain/tests/test_load_ckpt.py
create mode 100644 examples/BMTrain/tests/test_loss_func.py
create mode 100644 examples/BMTrain/tests/test_middle_hidden.py
create mode 100644 examples/BMTrain/tests/test_model_wrapper.py
create mode 100644 examples/BMTrain/tests/test_multi_return.py
create mode 100644 examples/BMTrain/tests/test_nccl_backward.py
create mode 100644 examples/BMTrain/tests/test_no_grad.py
create mode 100644 examples/BMTrain/tests/test_optim.py
create mode 100644 examples/BMTrain/tests/test_optim_state.py
create mode 100644 examples/BMTrain/tests/test_other_hidden.py
create mode 100644 examples/BMTrain/tests/test_parallel_projection.py
create mode 100644 examples/BMTrain/tests/test_requires_grad.py
create mode 100644 examples/BMTrain/tests/test_requires_grad_multi_gpu.py
create mode 100644 examples/BMTrain/tests/test_row_parallel_linear.py
create mode 100644 examples/BMTrain/tests/test_send_recv.py
create mode 100644 examples/BMTrain/tests/test_store.py
create mode 100644 examples/BMTrain/tests/test_synchronize.py
create mode 100644 examples/BMTrain/tests/test_training.py
create mode 100644 examples/BMTrain/tests/utils.py
create mode 100644 examples/CPM.cu/.arsync
create mode 100644 examples/CPM.cu/.gitignore
create mode 100644 examples/CPM.cu/.gitmodules
create mode 100644 examples/CPM.cu/LICENSE
create mode 100644 examples/CPM.cu/README.md
create mode 100644 examples/CPM.cu/README_ZH.md
create mode 100644 examples/CPM.cu/cpmcu/__init__.py
create mode 100644 examples/CPM.cu/cpmcu/llm.py
create mode 100644 examples/CPM.cu/cpmcu/llm_w4a16_gptq_marlin.py
create mode 100644 examples/CPM.cu/cpmcu/speculative/__init__.py
create mode 100644 examples/CPM.cu/cpmcu/speculative/eagle.py
create mode 100644 examples/CPM.cu/cpmcu/speculative/eagle_base_quant/__init__.py
create mode 100644 examples/CPM.cu/cpmcu/speculative/eagle_base_quant/eagle_base_w4a16_marlin_gptq.py
create mode 100644 examples/CPM.cu/cpmcu/speculative/hier_spec_quant/__init__.py
create mode 100644 examples/CPM.cu/cpmcu/speculative/hier_spec_quant/hier_eagle_w4a16_gm_spec_w4a16_gm.py
create mode 100644 examples/CPM.cu/cpmcu/speculative/spec_quant/__init__.py
create mode 100644 examples/CPM.cu/cpmcu/speculative/spec_quant/spec_w4a16_gm_for_w4a16_gm_model.py
create mode 100644 examples/CPM.cu/cpmcu/speculative/tree_drafter.py
create mode 100644 examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/__init__.py
create mode 100644 examples/CPM.cu/cpmcu/speculative/tree_drafter_base_quant/tree_drafter_w4a16_gptq_marlin.py
create mode 100644 examples/CPM.cu/model_convert/convert_llama_format.py
create mode 100644 examples/CPM.cu/model_convert/convert_w4a16.py
create mode 100644 examples/CPM.cu/model_convert/post_process_w4a16_eagle.py
create mode 100644 examples/CPM.cu/scripts/fr_spec/gen_fr_index.py
create mode 100644 examples/CPM.cu/scripts/model_convert/convert_w4a16.sh
create mode 100644 examples/CPM.cu/setup.py
create mode 100644 examples/CPM.cu/src/entry.cu
create mode 100644 examples/CPM.cu/src/flash_attn/flash_api.hpp
create mode 100644 examples/CPM.cu/src/flash_attn/src/alibi.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/block_info.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/dropout.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_blockmask.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_kernel.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_launch_template.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu
create mode 100644 examples/CPM.cu/src/flash_attn/src/generate_kernels.py
create mode 100644 examples/CPM.cu/src/flash_attn/src/kernel_traits.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/mask.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/philox.cuh
create mode 100644 examples/CPM.cu/src/flash_attn/src/rotary.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/softmax.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/static_switch.h
create mode 100644 examples/CPM.cu/src/flash_attn/src/utils.h
create mode 100644 examples/CPM.cu/src/model/activation.cuh
create mode 100644 examples/CPM.cu/src/model/attn.cuh
create mode 100644 examples/CPM.cu/src/model/drafter.cuh
create mode 100644 examples/CPM.cu/src/model/eagle.cuh
create mode 100644 examples/CPM.cu/src/model/eagle_base_quant/eagle_base_w4a16_gptq_marlin.cuh
create mode 100644 examples/CPM.cu/src/model/elementwise.cuh
create mode 100644 examples/CPM.cu/src/model/embedding.cuh
create mode 100644 examples/CPM.cu/src/model/ffn.cuh
create mode 100644 examples/CPM.cu/src/model/hier_spec_quant/hier_ea_w4a16_gm_rot_spec_w4a16_gm.cuh
create mode 100644 examples/CPM.cu/src/model/hier_spec_quant/hier_ea_w4a16_gm_spec_w4a16_gm.cuh
create mode 100644 examples/CPM.cu/src/model/kvcache.cuh
create mode 100644 examples/CPM.cu/src/model/layer.cuh
create mode 100644 examples/CPM.cu/src/model/linear.cuh
create mode 100644 examples/CPM.cu/src/model/mask.cuh
create mode 100644 examples/CPM.cu/src/model/memory.cuh
create mode 100644 examples/CPM.cu/src/model/minicpm4/minicpm4_attn.cuh
create mode 100644 examples/CPM.cu/src/model/minicpm4/minicpm4_eagle.cuh
create mode 100644 examples/CPM.cu/src/model/minicpm4/minicpm4_kvcache.cuh
create mode 100644 examples/CPM.cu/src/model/minicpm4/minicpm4_layer.cuh
create mode 100644 examples/CPM.cu/src/model/minicpm4/minicpm4_model.cuh
create mode 100644 examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_attn.cuh
create mode 100644 examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_layer.cuh
create mode 100644 examples/CPM.cu/src/model/minicpm4/minicpm4_w4a16_gptq_marlin_model.cuh
create mode 100644 examples/CPM.cu/src/model/model.cuh
create mode 100644 examples/CPM.cu/src/model/norm.cuh
create mode 100644 examples/CPM.cu/src/model/rotary.cuh
create mode 100644 examples/CPM.cu/src/model/spec_quant/w4a16_gm_spec_w4a16_gm.cuh
create mode 100644 examples/CPM.cu/src/model/topk.cuh
create mode 100644 examples/CPM.cu/src/model/tree_drafter.cuh
create mode 100644 examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_attn.cuh
create mode 100644 examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_ffn.cuh
create mode 100644 examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_layer.cuh
create mode 100644 examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_linear.cuh
create mode 100644 examples/CPM.cu/src/model/w4a16_gptq_marlin/w4a16_gptq_marlin_model.cuh
create mode 100644 examples/CPM.cu/src/perf.cu
create mode 100644 examples/CPM.cu/src/perf.cuh
create mode 100644 examples/CPM.cu/src/qgemm/gptq_marlin/core/scalar_type.hpp
create mode 100644 examples/CPM.cu/src/qgemm/gptq_marlin/gptq_marlin.cu
create mode 100644 examples/CPM.cu/src/qgemm/gptq_marlin/gptq_marlin.cuh
create mode 100644 examples/CPM.cu/src/qgemm/gptq_marlin/marlin.cuh
create mode 100644 examples/CPM.cu/src/qgemm/gptq_marlin/marlin_dtypes.cuh
create mode 100644 examples/CPM.cu/src/signal_handler.cu
create mode 100644 examples/CPM.cu/src/signal_handler.cuh
create mode 100644 examples/CPM.cu/src/trait.cuh
create mode 100644 examples/CPM.cu/src/utils.cu
create mode 100644 examples/CPM.cu/src/utils.cuh
create mode 100644 examples/CPM.cu/src/utilsq.cuh
create mode 100644 examples/CPM.cu/tests/long_prompt_gen.py
create mode 100644 examples/CPM.cu/tests/test_generate.py
diff --git a/examples/BMTrain/.dockerignore b/examples/BMTrain/.dockerignore
new file mode 100644
index 00000000..c1591543
--- /dev/null
+++ b/examples/BMTrain/.dockerignore
@@ -0,0 +1,147 @@
+**/__pycache__/
+**/*.py[cod]
+**/*$py.class
+
+# C extensions
+**/*.so
+
+# Distribution / packaging
+**/.Python
+**/build/
+**/develop-eggs/
+**/dist/
+**/downloads/
+**/eggs/
+**/.eggs/
+**/lib/
+**/lib64/
+**/parts/
+**/sdist/
+**/var/
+**/wheels/
+**/share/python-wheels/
+**/*.egg-info/
+**/.installed.cfg
+**/*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+**/*.manifest
+**/*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+**/htmlcov/
+**/.tox/
+**/.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+**/*.cover
+**/*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+**/.pyre/
+
+# pytype static type analyzer
+**/.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+**/*.pt
+
+**/*.npy
+
+**/.DS_Store
+
+**/log
+**/*.qdrep
+!bmtrain/dist
\ No newline at end of file
diff --git a/examples/BMTrain/.github/ISSUE_TEMPLATE/bug_report.yml b/examples/BMTrain/.github/ISSUE_TEMPLATE/bug_report.yml
new file mode 100644
index 00000000..47c89fe6
--- /dev/null
+++ b/examples/BMTrain/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -0,0 +1,79 @@
+
+name: 🐞 Bug Report
+description: Report a bug/issue related to the PyTorch-based parallel model training toolkit
+title: "[BUG]
"
+labels: ["bug"]
+body:
+- type: checkboxes
+ attributes:
+ label: Is there an existing issue for this?
+ description: Please search to see if an issue already exists for the bug you encountered.
+ options:
+ - label: I have searched the existing issues
+ required: true
+- type: textarea
+ attributes:
+ label: Description of the Bug
+ description: Provide a clear and concise description of what the bug is.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Environment Information
+ description: |
+ Provide details about your environment.
+ Example:
+ - GCC version: 9.3.0
+ - Torch version: 1.9.0
+ - Linux system version: Ubuntu 20.04
+ - CUDA version: 11.4
+      - Torch's CUDA version (as per `torch.version.cuda`): 11.3
+ value: |
+ - GCC version:
+ - Torch version:
+ - Linux system version:
+ - CUDA version:
+      - Torch's CUDA version (as per `torch.version.cuda`):
+ render: markdown
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: To Reproduce
+ description: Provide the steps and details to reproduce the behavior.
+ placeholder: |
+ 1. Describe your environment setup, including any specific version requirements.
+ 2. Clearly state the steps you took to trigger the error, including the specific code you executed.
+      3. Identify the file and line number where the error occurred, along with the full traceback of the error. Make sure to have `NCCL_DEBUG=INFO` and `CUDA_LAUNCH_BLOCKING=1` set to get accurate debug information.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Expected Behavior
+ description: Describe what you expected to happen when you executed the code.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Screenshots
+ description: If applicable, please add screenshots to help explain your problem.
+ validations:
+ required: false
+- type: textarea
+ attributes:
+ label: Additional Information
+ description: |
+ Provide any other relevant context or information about the problem here.
+ Links? References? Anything that will give us more context about the issue you are encountering!
+ Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in.
+ validations:
+ required: false
+- type: checkboxes
+ attributes:
+ label: Confirmation
+ description: Please confirm that you have reviewed all of the above requirements and verified the information provided before submitting this issue.
+    options:
+    - label: I have reviewed and verified all the information provided in this report.
+ validations:
+ required: true
+
diff --git a/examples/BMTrain/.github/ISSUE_TEMPLATE/build_err.yml b/examples/BMTrain/.github/ISSUE_TEMPLATE/build_err.yml
new file mode 100644
index 00000000..940fd7ed
--- /dev/null
+++ b/examples/BMTrain/.github/ISSUE_TEMPLATE/build_err.yml
@@ -0,0 +1,94 @@
+name: 🛠️ Build Error
+description: Report a build error for this project
+title: "[BUILD ERROR] "
+labels: ["Build ERR"]
+body:
+- type: checkboxes
+ id: prev_issue
+ attributes:
+ label: Is there an existing issue for this?
+ description: Please search to see if an issue already exists for the build error you encountered.
+ options:
+ - label: I have searched the existing issues
+ required: true
+- type: textarea
+ attributes:
+ label: Description of the Build Error
+ description: Provide a clear and concise description of what the build error is.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Expected Behavior
+ description: Provide a clear and concise description of what you expected to happen.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: To Reproduce
+ description: Describe the steps you took to trigger the build error. Include any commands you executed or files you modified.
+ placeholder: |
+ 1. Go to '...'
+ 2. Click on '....'
+ 3. Scroll down to '....'
+ 4. See error
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Environment Information
+ description: |
+ Provide details about your environment.
+ Example:
+ - Operating System version: Ubuntu 20.04
+ - GCC version: 9.3.0
+ - Pybind version: 2.8.1
+ - CUDA version: 11.4
+ - NVIDIA NCCL CU11 version: 2.14.3
+ - CMake version: 3.21.2
+ - Pip version: 22.0.0
+ value: |
+ - Operating System version:
+ - GCC version:
+ - Pybind version:
+ - CUDA version:
+ - NVIDIA NCCL CU11 version:
+ - CMake version:
+ - Pip version:
+ render: markdown
+ validations:
+ required: true
+- type: dropdown
+ attributes:
+ label: Installation Method
+ description: Please indicate if the error occurred during source code installation or when using the pip install .whl method.
+ options:
+ - Source Code Installation
+ - Pip Install .whl Method
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Full Error Traceback
+ description: Provide the complete error traceback.
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Additional Information
+ description: |
+ Provide any other relevant context or information about the problem here.
+ Links? References? Anything that will give us more context about the issue you are encountering!
+ Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in.
+ validations:
+ required: false
+- type: checkboxes
+ id: confirm
+ attributes:
+ label: Confirmation
+ description: Please confirm that you have reviewed all of the above requirements and verified the information provided before submitting this report.
+ options:
+ - label: I have reviewed and verified all the information provided in this report.
+ validations:
+ required: true
+
diff --git a/examples/BMTrain/.github/ISSUE_TEMPLATE/features_request.yml b/examples/BMTrain/.github/ISSUE_TEMPLATE/features_request.yml
new file mode 100644
index 00000000..769948c2
--- /dev/null
+++ b/examples/BMTrain/.github/ISSUE_TEMPLATE/features_request.yml
@@ -0,0 +1,30 @@
+name: 🚀Feature Request
+description: Suggest an idea for this project
+title: "[Feature] "
+labels: ["enhancement"]
+assignees: []
+body:
+- type: textarea
+ attributes:
+ label: Is your feature request related to a problem? Please describe.
+ description: "A clear and concise description of what the problem is. Example: I'm always frustrated when..."
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Describe the solution you'd like
+ description: "A clear and concise description of what you want to happen."
+ validations:
+ required: true
+- type: textarea
+ attributes:
+ label: Describe alternatives you've considered
+ description: "A clear and concise description of any alternative solutions or features you've considered."
+ validations:
+ required: false
+- type: textarea
+ attributes:
+ label: Additional context
+ description: "Add any other context or screenshots about the feature request here."
+ validations:
+ required: false
diff --git a/examples/BMTrain/.github/pull_request_template.md b/examples/BMTrain/.github/pull_request_template.md
new file mode 100644
index 00000000..87a5f9b0
--- /dev/null
+++ b/examples/BMTrain/.github/pull_request_template.md
@@ -0,0 +1,29 @@
+## Pull Request Template
+
+### Issue Reference
+Please mention the issue number if applicable, or write "N/A" if it's a new feature.
+
+Issue #...
+
+### Description
+Please describe your changes in detail. If it resolves an issue, please state how it resolves it.
+
+### Type of Change
+- [ ] Bug fix (non-breaking change which fixes an issue)
+- [ ] New feature (non-breaking change which adds functionality)
+- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
+- [ ] This change requires a documentation update
+
+### How Has This Been Tested?
+Please describe the tests that you ran to verify your changes. Provide instructions so we can reproduce.
+
+### Checklist
+- [ ] I have read the [CONTRIBUTING](../../CONTRIBUTING.md) document.
+- [ ] My code follows the code style of this project.
+- [ ] My change requires a change to the documentation.
+- [ ] I have updated the documentation accordingly.
+- [ ] I have added tests to cover my changes.
+- [ ] All new and existing tests passed.
+
+### Additional Information
+Any additional information, configuration, or data that might be necessary for the review.
diff --git a/examples/BMTrain/.github/workflows/build.yml b/examples/BMTrain/.github/workflows/build.yml
new file mode 100644
index 00000000..11aa61f6
--- /dev/null
+++ b/examples/BMTrain/.github/workflows/build.yml
@@ -0,0 +1,35 @@
+name: Build
+
+on:
+ pull_request_target:
+ types: [opened, reopened, synchronize]
+ branches:
+ - 'dev'
+ - 'main'
+ push:
+ branches:
+ - 'dev'
+
+jobs:
+ build-archive-wheel:
+
+ uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main
+ secrets: inherit
+
+ fake-publish:
+ needs: build-archive-wheel
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+
+ - name: Set Up the Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: Download distribution files
+ uses: actions/download-artifact@v4
+ with:
+ name: dist
+ path: dist
diff --git a/examples/BMTrain/.github/workflows/build_whl.yml b/examples/BMTrain/.github/workflows/build_whl.yml
new file mode 100644
index 00000000..9116b598
--- /dev/null
+++ b/examples/BMTrain/.github/workflows/build_whl.yml
@@ -0,0 +1,89 @@
+name: Build wheels in docker and archive
+
+on:
+ workflow_call:
+ secrets:
+ DOCKERHUB_TOKEN:
+ required: true
+ DOCKERHUB_USERNAME:
+ required: true
+
+jobs:
+ build:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ['37', '38', '39', '310', '311']
+
+
+    steps:
+
+    - name: Check the disk space and clear unnecessary libraries
+ run: |
+ rm -rf /home/runner/work/BMTrain/BMTrain/dist
+ sudo rm -rf /usr/share/dotnet
+ sudo rm -rf /opt/ghc
+ sudo rm -rf "/usr/local/share/boost"
+ sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+ df -hl
+
+ - name: Checkout code
+ uses: actions/checkout@v3
+
+ - name: Login to DockerHub
+ uses: docker/login-action@v2
+ with:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+ - name: Pull Docker image
+ run: docker pull pytorch/manylinux-cuda113:latest
+
+ - name: Run Docker image and execute script
+ run: |
+ version=${{ matrix.python-version }}
+ docker run -e BUILD_DOCKER_ENV=1 -e CUDACXX=/usr/local/cuda-11.3/bin/nvcc -e PATH="/opt/rh/devtoolset-9/root/usr/bin:$PATH" -e LD_LIBRARY_PATH="/opt/rh/devtoolset-9/root/usr/lib64:/opt/rh/devtoolset-9/root/usr/lib:/usr/local/nvidia/lib:/usr/local/nvidia/lib64:$LD_LIBRARY_PATH" -v ${{ github.workspace }}:/workspace/BMTrain -i pytorch/manylinux-cuda113:latest /bin/bash -c "cd /workspace/BMTrain;/opt/python/cp${version}*/bin/pip install build; /opt/python/cp${version}*/bin/python -m build .;for file in dist/*-linux_x86_64.whl; do mv \"\$file\" \"\${file//-linux_x86_64/-manylinux2014_x86_64}\"; done"
+
+ - name: Upload wheels as artifacts
+ uses: actions/upload-artifact@v4
+ with:
+ name: wheels_py${{ matrix.python-version }}
+ path: dist/*.whl
+
+ - name: Upload source distribution (only once)
+ if: matrix.python-version == '37' # Only upload source distribution once
+ uses: actions/upload-artifact@v4
+ with:
+ name: source_dist
+ path: dist/*.tar.gz
+
+ archive:
+ runs-on: ubuntu-latest
+ needs: build
+ steps:
+ - name: Download all wheels
+ uses: actions/download-artifact@v4
+ with:
+ path: wheels
+ pattern: wheels_py*
+
+ - name: Download source distribution
+ uses: actions/download-artifact@v4
+ with:
+ path: source_dist
+ name: source_dist
+
+ - name: Combine all wheels into a single directory
+ run: |
+ mkdir -p dist
+ find wheels -name '*.whl' -exec mv {} dist/ \;
+ find source_dist -name '*.tar.gz' -exec mv {} dist/ \;
+
+ - name: Archive distribution files
+ uses: actions/upload-artifact@v4
+ with:
+ name: dist
+ path: |
+ dist/*.tar.gz
+ dist/*.whl
+ overwrite: true
\ No newline at end of file
diff --git a/examples/BMTrain/.github/workflows/publish.yaml b/examples/BMTrain/.github/workflows/publish.yaml
new file mode 100644
index 00000000..fd9b8c50
--- /dev/null
+++ b/examples/BMTrain/.github/workflows/publish.yaml
@@ -0,0 +1,41 @@
+name: Build and Publish to PyPI
+
+on:
+ push:
+ tags:
+
+ - "v*.*.*"
+
+jobs:
+
+ build-archive-wheel:
+ uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main
+ secrets:
+ DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
+ DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
+
+ publish:
+ needs: build-archive-wheel
+ runs-on: ubuntu-latest
+ steps:
+ - name: Set Up the Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: Install twine
+ run: python -m pip install twine
+
+ - name: Download distribution files
+ uses: actions/download-artifact@v4
+ with:
+ name: dist
+ path: dist
+
+ - name: Publish to PyPI
+ env:
+ TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+ TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+ run: |
+ cd dist
+ python -m twine upload *.tar.gz *.whl
diff --git a/examples/BMTrain/.github/workflows/release.yml b/examples/BMTrain/.github/workflows/release.yml
new file mode 100644
index 00000000..bafc2173
--- /dev/null
+++ b/examples/BMTrain/.github/workflows/release.yml
@@ -0,0 +1,44 @@
+name: Publish release in Github
+
+on:
+ push:
+ tags:
+ - "v*.*.*"
+
+jobs:
+
+ build-archive-wheel:
+
+ uses: OpenBMB/BMTrain/.github/workflows/build_whl.yml@main
+ secrets:
+ DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
+ DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
+
+ publish:
+ needs: build-archive-wheel
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout code
+ uses: actions/checkout@v3
+
+ - name: Set Up the Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: 3.9
+
+ - name: Download distribution files
+ uses: actions/download-artifact@v4
+ with:
+ name: dist
+ path: dist
+
+ - name: Upload Distribution Files to Existing Release
+ uses: softprops/action-gh-release@v1
+ with:
+ files: |
+ dist/*.tar.gz
+ dist/*.whl
+        tag_name: ${{ github.ref_name }} # use the tag that triggered this workflow run
+ token: ${{ secrets.RELEASE_TOKEN }}
+ env:
+ GITHUB_REPOSITORY: OpenBMB/BMTrain
diff --git a/examples/BMTrain/.gitignore b/examples/BMTrain/.gitignore
new file mode 100644
index 00000000..75138102
--- /dev/null
+++ b/examples/BMTrain/.gitignore
@@ -0,0 +1,155 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+*.pt
+
+*.npy
+
+bminference/version.py
+
+.DS_Store
+
+log
+*.qdrep
+.vscode
+
+!bmtrain/dist
+tests/test_log.txt
+tests/*.opt
+tests/*.ckp
\ No newline at end of file
diff --git a/examples/BMTrain/CMakeLists.txt b/examples/BMTrain/CMakeLists.txt
new file mode 100644
index 00000000..e027e7da
--- /dev/null
+++ b/examples/BMTrain/CMakeLists.txt
@@ -0,0 +1,65 @@
+cmake_minimum_required(VERSION 3.18)
+project(bmtrain)
+enable_language(C)
+enable_language(CXX)
+set(CMAKE_CUDA_ARCHITECTURES "61;62;70;72;75;80")
+enable_language(CUDA)
+set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD_REQUIRED True)
+set(CMAKE_CUDA_STANDARD 14)
+set(CMAKE_CUDA_STANDARD_REQUIRED True)
+
+set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_61,code=sm_61 -gencode=arch=compute_62,code=sm_62 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_72,code=sm_72 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80")
+
+if(NOT DEFINED ENV{BUILD_DOCKER_ENV} OR "$ENV{BUILD_DOCKER_ENV}" STREQUAL "0")
+ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_86,code=sm_86")
+ set(AVX_FLAGS "${AVX_FLAGS} -march=native")
+else()
+ message("Building in docker environment, skipping compute_86 and enable all avx flag")
+ set(AVX_FLAGS "${AVX_FLAGS} -mavx -mfma -mf16c -mavx512f")
+endif()
+
+set(CMAKE_BUILD_RPATH $ORIGIN)
+set(CMAKE_INSTALL_RPATH $ORIGIN)
+set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/)
+
+find_package(NCCL REQUIRED)
+find_package(Python ${PYTHON_VERSION} EXACT COMPONENTS Interpreter Development.Module REQUIRED)
+message (STATUS "Python_EXECUTABLE: ${Python_EXECUTABLE}")
+execute_process(COMMAND ${Python_EXECUTABLE} "-c"
+ "import pybind11; print(pybind11.get_cmake_dir())"
+ OUTPUT_VARIABLE PYBIND11_CMAKE_DIR
+ OUTPUT_STRIP_TRAILING_WHITESPACE)
+message (STATUS "PYBIND11_CMAKE_DIR: ${PYB
+IND11_CMAKE_DIR}")
+list(APPEND CMAKE_PREFIX_PATH ${PYBIND11_CMAKE_DIR})
+find_package(pybind11 REQUIRED)
+
+message (STATUS "CMAKE_INSTALL_RPATH: ${CMAKE_INSTALL_RPATH}")
+
+file(GLOB_RECURSE SOURCES "csrc/*.cpp")
+file(GLOB_RECURSE CUDA_SOURCES "csrc/cuda/*.cu")
+
+
+pybind11_add_module(C ${SOURCES} ${CUDA_SOURCES})
+
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${AVX_FLAGS}")
+
+target_link_libraries(C PRIVATE
+ "-Wl,-Bsymbolic"
+ "-Wl,-Bsymbolic-functions"
+ ${NCCL_LIBRARIES}
+)
+target_include_directories(C PRIVATE ${NCCL_INCLUDE_DIRS})
+target_compile_definitions(C
+ PRIVATE VERSION_INFO=${EXAMPLE_VERSION_INFO})
+
+set_target_properties(C PROPERTIES CUDA_ARCHITECTURES "61;62;70;72;75;80")
+
+target_include_directories(C
+ PRIVATE "csrc/include"
+ PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
+)
+
+
+
diff --git a/examples/BMTrain/CONTRIBUTING.md b/examples/BMTrain/CONTRIBUTING.md
new file mode 100644
index 00000000..e1c4458f
--- /dev/null
+++ b/examples/BMTrain/CONTRIBUTING.md
@@ -0,0 +1,55 @@
+# Contributing to BMTrain
+
+We welcome everyone's efforts to make the community and the package better. You are welcome to open an issue, make a pull request, or help others in the community. All of these efforts are appreciated!
+
+There are many ways that you can contribute to BMTrain:
+
+- ✉️ Submitting an issue.
+- ⌨️ Making a pull request.
+- 🤝 Serving the community.
+
+## Submitting an issue
+You can submit an issue if you find bugs or require new features and enhancements. Here are some principles:
+
+1. **Language.** It is better to write your issue in English so that more people can understand and help you.
+2. **Search.** It is a good habit to search existing issues using the search bar of GitHub. Make sure there is no duplicate or similar issue before opening yours; if there is one, check its solutions first.
+3. **Format.** It is also very helpful to write the issue in a clear format. We provide templates for common types of issues, and everyone is encouraged to use them. If the templates do not fit your issue, feel free to open a blank one.
+4. **Writing style.** Write your issue in clear and concise language, and provide enough details for others to help. For example, in a bug report it is better to provide your running environment and a minimal code example that reproduces the problem.
+
+## Making a pull request (PR)
+You can also contribute code, for example a bug fix, a new enhancement, or a new running example. Here are the steps to make a pull request:
+
+1. **Tie the PR to an issue.** Let us and others know what you are going to work on. If your code solves an existing issue, comment on that issue and make sure nobody else is working on it. If you are proposing a new enhancement, submit an issue first so we can discuss it with you before you work on it.
+
+2. **Fork the repository.** Fork the repository to your own GitHub space by clicking the "Fork" button. Then clone it to your disk and set the upstream remote:
+```git
+$ git clone https://github.com/<your-username>/BMTrain.git
+$ cd BMTrain
+$ git remote add upstream https://github.com/OpenBMB/BMTrain.git
+```
+
+3. **Write your code.** Change to a new branch to work on your modifications.
+```git
+$ git checkout -b your-branch-name
+```
+You are encouraged to think up a meaningful and descriptive name for your branch.
+
+4. **Make a pull request.** After you finish coding, first rebase your branch and resolve any conflicts with the upstream code:
+```git
+$ git fetch upstream
+$ git rebase upstream/main
+```
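+If the rebase stops because of conflicts, resolve them file by file and continue (a typical flow; adapt it to your situation):
+```git
+$ git add <the-files-you-resolved>
+$ git rebase --continue
+```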
+Then push your code to your own repo:
+```git
+$ git push -u origin your-branch-name
+```
+Finally, open a pull request from your GitHub repo against ours. Your code will be merged into the main repo after our code review.
+
+
+## Serving the community
+
+Besides submitting issues and PRs, you can also join our community and help others. Efforts such as writing documentation, answering questions, and discussing new features are appreciated and welcome. It is also helpful to share your opinions and experience with our package on social media.
+
+We are now developing a reward system and all your contributions will be recorded and rewarded in the future.
+
+
diff --git a/examples/BMTrain/Dockerfile b/examples/BMTrain/Dockerfile
new file mode 100644
index 00000000..8e6cbddf
--- /dev/null
+++ b/examples/BMTrain/Dockerfile
@@ -0,0 +1,20 @@
+FROM nvidia/cuda:10.2-devel
+WORKDIR /build
+RUN apt update && apt install -y --no-install-recommends \
+ build-essential \
+ python3-dev \
+ python3-pip \
+ python3-setuptools \
+ python3-wheel
+RUN pip3 install torch==1.10.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
+RUN pip3 install numpy -i https://pypi.tuna.tsinghua.edu.cn/simple
+RUN apt install iputils-ping opensm libopensm-dev libibverbs1 libibverbs-dev -y --no-install-recommends
+ENV TORCH_CUDA_ARCH_LIST=6.1;7.0;7.5
+ENV BMT_AVX512=1
+ADD other_requirements.txt other_requirements.txt
+RUN pip3 install --upgrade pip && pip3 install -r other_requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
+ADD . .
+RUN python3 setup.py install
+
+WORKDIR /root
+ADD example example
\ No newline at end of file
diff --git a/examples/BMTrain/LICENSE b/examples/BMTrain/LICENSE
new file mode 100644
index 00000000..7ad7f39e
--- /dev/null
+++ b/examples/BMTrain/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2022 OpenBMB
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/examples/BMTrain/MANIFEST.in b/examples/BMTrain/MANIFEST.in
new file mode 100644
index 00000000..a6f97fa4
--- /dev/null
+++ b/examples/BMTrain/MANIFEST.in
@@ -0,0 +1,4 @@
+graft csrc
+include CMakeLists.txt
+graft cmake
+
diff --git a/examples/BMTrain/README-ZH.md b/examples/BMTrain/README-ZH.md
new file mode 100644
index 00000000..d36953a2
--- /dev/null
+++ b/examples/BMTrain/README-ZH.md
@@ -0,0 +1,369 @@
+
+
+## 最新动态
+- 2022/06/14 **BMTrain** [0.1.7](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.7) 发布。支持了ZeRO-2优化!
+- 2022/03/30 **BMTrain** [0.1.2](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.2) 发布。适配了[OpenPrompt](https://github.com/thunlp/OpenPrompt)和 [OpenDelta](https://github.com/thunlp/OpenDelta)工具包。
+- 2022/03/16 **BMTrain** [0.1.1](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.1) 公开发布了第一个稳定版本,修复了 beta 版本中的一些问题。
+- 2022/02/11 **BMTrain** [0.0.15](https://github.com/OpenBMB/BMTrain/releases/tag/0.0.15) 公开发布了第一个 beta 版本。
+
+
+
+## 总览
+
+BMTrain 是一个高效的大模型训练工具包,可以用于训练数百亿参数的大模型。BMTrain 可以在分布式训练模型的同时,保持代码的简洁性。
+
+
+
+## 文档
+我们的[文档](https://bmtrain.readthedocs.io/en/latest/index.html)提供了关于工具包的更多信息。
+
+
+
+## 安装
+
+- 用 pip 安装(推荐): ``pip install bmtrain``
+
+- 从源代码安装: 下载工具包,然后运行 ``pip install .`` (setup.py的安装方式将会在未来被setuptools弃用)
+
+安装 BMTrain 可能需要花费数分钟的时间,因为在安装时需要编译 c/cuda 源代码。
+我们推荐直接在训练环境中编译 BMTrain,以避免不同环境带来的潜在问题。
+
+
+
+
+## 使用说明
+
+### 步骤 1: 启用 BMTrain
+
+首先,你需要在代码开头初始化 BMTrain。正如在使用 PyTorch 的分布式训练模块需要在代码开头使用 **init_process_group** 一样,使用 BMTrain 需要在代码开头使用 **init_distributed**。
+
+```python
+import bmtrain as bmt
+bmt.init_distributed(
+ seed=0,
+ zero_level=3, # 目前支持2和3
+ # ...
+)
+```
+
+**注意:** 使用 BMTrain 时请不要使用 PyTorch 自带的 `distributed` 模块,包括 `torch.distributed.init_process_group` 以及相关通信函数。
+
+### 步骤 2: 使用 ZeRO 优化
+
+使用ZeRO优化需要对模型代码进行简单替换:
+
+* `torch.nn.Module` -> `bmtrain.DistributedModule`
+* `torch.nn.Parameter` -> `bmtrain.DistributedParameter`
+
+并在 transformer 模块上使用 `bmtrain.CheckpointBlock`。
+
+下面是一个例子:
+
+**原始代码**
+
+```python
+import torch
+class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ SomeTransformerBlock(),
+ SomeTransformerBlock(),
+ SomeTransformerBlock()
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**替换后代码**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule): # 修改这里
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024)) # 修改这里
+ self.module_list = torch.nn.ModuleList([
+ bmt.CheckpointBlock(SomeTransformerBlock()), # 修改这里
+ bmt.CheckpointBlock(SomeTransformerBlock()), # 修改这里
+ bmt.CheckpointBlock(SomeTransformerBlock()) # 修改这里
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+### 步骤 3: 通信优化
+
+为了进一步缩短通信额外开销,将通信与运算时间重叠,可以使用 `TransformerBlockList` 来进一步优化。
+
+在使用时需要对代码进行简单替换:
+
+* `torch.nn.ModuleList` -> `bmtrain.TransformerBlockList`
+* `for module in self.module_list: x = module(x, ...)` -> `x = self.module_list(x, ...)`
+
+**原始代码**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**替换后代码**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = bmt.TransformerBlockList([ # 修改这里
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock()),
+ bmt.CheckpointBlock(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ x = self.module_list(x, 1, 2, 3) # 修改这里
+ return x
+
+```
+
+### 步骤 4: 运行分布式训练代码
+
+BMTrain 使用 PyTorch 原生的分布式训练启动器,你可以根据 PyTorch 版本选择下列命令中的一个。
+
+* `${MASTER_ADDR}` 为主节点的 IP 地址
+* `${MASTER_PORT}` 为主节点的端口
+* `${NNODES}` 为节点数量
+* `${GPU_PER_NODE}` 为每个节点的 GPU 数量
+* `${NODE_RANK}` 为本节点的 rank
+
+#### torch.distributed.launch
+```shell
+$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py
+```
+
+#### torchrun
+
+```shell
+$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py
+```
+
+更多信息请参考 PyTorch [官方文档](https://pytorch.org/docs/stable/distributed.html#launch-utility)。
+
+## 样例
+
+我们提供了一个使用 BMTrain 训练 GPT-2 的[样例](https://github.com/OpenBMB/BMTrain/tree/main/example)。
+代码主要包含以下几个部分。
+
+### 第 1 部分: 模型定义
+
+```
+├── layers
+│ ├── attention.py
+│ ├── embedding.py
+│ ├── feedforward.py
+│ ├── __init__.py
+│ ├── layernorm.py
+│ └── linear.py
+└── models
+ ├── gpt.py
+ └── __init__.py
+```
+
+上面是代码的目录结构。
+
+我们定义了 GPT-2 需要的所有模型层,并使用 BMTrain 的 `DistributedModule` 和 `DistributedParameter` 来启用 ZeRO 优化。
+
+### 第 2 部分: 初始化 BMTrain
+
+```python
+bmtrain.init_distributed(seed=0)
+
+model = GPT(
+ num_layers=8,
+ vocab_size=10240,
+ dim_model=2560,
+ dim_head=80,
+ num_heads=32,
+ dim_ff=8192,
+ max_distance=1024,
+ bias=True,
+ dtype=torch.half
+)
+
+bmtrain.init_parameters(model) # 或者使用`bmtrain.load`加载checkpoint
+
+# ... 其他初始化(例如数据集) ...
+```
+
+`bmtrain.init_distributed(seed=0)` 用于初始化分布式训练环境,并设置随机数种子便于复现。
+
+`bmtrain.init_parameters(model)` 用于初始化模型的分布式参数。
+
+### 第 3 部分: 初始化优化器和学习率调整策略
+
+```python
+loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100)
+optimizer = bmtrain.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2)
+lr_scheduler = bmtrain.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0)
+```
+
+BMTrain 支持**所有** PyTorch 原生的优化器和损失函数,同时你也可以使用 BMTrain 提供的融合(fused)优化器用于混合精度训练。
+
+此外,在 `bmtrain.lr_scheduler` 中 BMTrain 也提供了常见的学习率调整策略。
+
+### 第 4 部分: 训练
+
+```python
+# 新建优化器管理器实例
+optim_manager = bmtrain.optim.OptimManager(loss_scale=1024)
+# 将所有的 optimizer 及(可选)其对应的 lr_scheduler 收入优化器管理器管理。
+optim_manager.add_optimizer(optimizer, lr_scheduler)
+# 可以再次调用 add_optimizer 加入其他优化器
+
+for iteration in range(1000):
+ # ... 为每个rank加载数据 ...
+
+ # 前向传播并计算梯度
+ pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
+ logits = model(
+ enc_input,
+ pos,
+ pos < enc_length[:, None]
+ )
+ batch, seq_len, vocab_out_size = logits.size()
+
+ loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
+
+ global_loss = bmtrain.sum_loss(loss).item() # 聚合所有rank上的损失, 仅用于输出训练日志
+
+ # 梯度清零
+ optim_manager.zero_grad() # 为每个 optimizer 调用 zero_grad
+
+ # 损失缩放和反向传播
+ optim_manager.backward(loss)
+
+ # 梯度裁剪
+ grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0)
+
+ # 更新参数
+ optim_manager.step()
+
+ # ... 保存checkpoint、打印日志 ...
+```
+
+这部分代码略有些长,但写起来就像常见的训练代码一样,你不需要为分布式训练调整太多的代码。
+
+你可以根据代码中的注释来了解各部分代码的作用。
+
+唯一需要说明的是 `optim_manager`。在使用 BMTrain 后,优化器的部分相关操作需要有一些细节上的调整。我们在 `optim_manager` 帮你实现了这些细节, 你只需要通过 `add_optimizer` 将优化器和学习率调整策略收入 `optim_manager` 管理,并由 `optim_manger` 代为执行 `zero_grad()`, `backward()`, `clip_grad_norm()` 和 `step()` 等操作。
+
+如果你没有使用混合精度训练,你可以不用损失缩放,只需要将 `OptimManager(loss_scale=None)` 构造函数中 `loss_scale` 置为 None 即可, 这也是 `OptimManager` 的默认构造参数。
+
+如果你使用了混合精度训练,**损失缩放**是混合精度训练中的一项常用技术,我们在 `optim_manager.backward(loss)` 帮你对 `loss` 进行了放缩,用于避免梯度下溢。只需要将 `OptimManager` 构造函数中 `loss_scale` 置为一个浮点数即可。 `loss_scale` 会在训练过程中根据梯度进行自适应的调整。
+
+
+
+## 性能
+
+我们训练了一个有130亿参数的 GPT-2 模型,使用了4台服务器,每台服务器有8张V100显卡。我们测试了训练过程中每个GPU的吞吐量(每个GPU每秒处理的样本数),结果见下表。
+
+模型结构:
+* 40层
+* 128个注意力头
+* 5120的隐藏层维数
+* 512的序列长度
+
+
+| batch size | 8 | 16 | 24 | 32 |
+|-------------|-------|-------|:------|:------|
+| BMTrain | 24.15 | 26.94 | 29.42 | 28.28 |
+| ZeRO3(mp=1) | 14.88 | 21.69 | 24.38 | - |
+| ZeRO3(mp=4) | 15.51 | - | - | - |
+| ZeRO3(mp=8) | 15.51 | - | - | - |
+| ZeRO2(mp=1) | - | - | - | - |
+| ZeRO2(mp=4) | 22.85 | - | - | - |
+| ZeRO2(mp=8) | 21.33 | - | - | - |
+
+**ZeROa(mp=b)** 表示 DeepSpeed + Megatron ZeRO stage a 和 model parallelism = b。
+
+表格中的 **-** 表示超出显存。
+
+## 模型支持
+
+我们已经将大多数常见的 NLP 模型移植到了 BMTrain 中。你可以在 [ModelCenter](https://github.com/OpenBMB/ModelCenter) 项目中找到支持模型的列表。
+
+## 开源社区
+欢迎贡献者参照我们的[贡献指南](https://github.com/OpenBMB/BMTrain/blob/master/CONTRIBUTING.md)贡献相关代码。
+
+您也可以在其他平台与我们沟通交流:
+- QQ群: 735930538
+- 官方网站: https://www.openbmb.org
+- 微博: http://weibo.cn/OpenBMB
+- Twitter: https://twitter.com/OpenBMB
+
+## 开源许可
+
+该工具包使用[Apache 2.0](https://github.com/OpenBMB/BMTrain/blob/main/LICENSE)开源许可证。
+
+## 其他说明
+
+`BMTrain` 工具包对 PyTorch 进行了底层修改,如果你的程序输出了意料之外的结果,可以在 issue 中提交相关信息。
diff --git a/examples/BMTrain/README.md b/examples/BMTrain/README.md
new file mode 100644
index 00000000..134929f5
--- /dev/null
+++ b/examples/BMTrain/README.md
@@ -0,0 +1,375 @@
+
+
+## What's New
+- 2024/02/26 **BMTrain** [1.0.0](https://github.com/OpenBMB/BMTrain/releases/tag/v1.0.0) released. Code refactoring and Tensor parallel support. See the detail in [update log](docs/UPDATE_1.0.0.md)
+- 2023/08/17 **BMTrain** [0.2.3](https://github.com/OpenBMB/BMTrain/releases/tag/v0.2.3) released. See the [update log](docs/UPDATE_0.2.3.md).
+- 2022/12/15 **BMTrain** [0.2.0](https://github.com/OpenBMB/BMTrain/releases/tag/0.2.0) released. See the [update log](docs/UPDATE_0.2.0.md).
+- 2022/06/14 **BMTrain** [0.1.7](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.7) released. ZeRO-2 optimization is supported!
+- 2022/03/30 **BMTrain** [0.1.2](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.2) released. Adapted to [OpenPrompt](https://github.com/thunlp/OpenPrompt) and [OpenDelta](https://github.com/thunlp/OpenDelta).
+- 2022/03/16 **BMTrain** [0.1.1](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.1) released. The first public stable version, fixing many bugs from the beta version.
+- 2022/02/11 **BMTrain** [0.0.15](https://github.com/OpenBMB/BMTrain/releases/tag/0.0.15) released. The first public beta version.
+
+
+
+## Overview
+
+BMTrain is an efficient large model training toolkit that can be used to train large models with tens of billions of parameters. It can train models in a distributed manner while keeping the code as simple as stand-alone training.
+
+
+
+## Documentation
+Our [documentation](https://bmtrain.readthedocs.io/en/latest/index.html) provides more information about the package.
+
+
+
+## Installation
+
+- From pip (recommended): ``pip install bmtrain``
+
+- From source code: download the package and run ``pip install .``
+
+Installing BMTrain may take several minutes, as it compiles the C/CUDA source code during installation.
+We recommend compiling BMTrain directly in the training environment to avoid potential problems caused by environment differences.
+
+
+
+## Usage
+
+### Step 1: Initialize BMTrain
+
+Before you can use BMTrain, you need to initialize it at the beginning of your code. Just as using PyTorch's distributed module requires calling **init_process_group** at the beginning of the code, using BMTrain requires calling **init_distributed** at the beginning of the code.
+
+```python
+import bmtrain as bmt
+bmt.init_distributed(
+ seed=0,
+ # ...
+)
+```
+
+**NOTE:** Do not use PyTorch's distributed module and its associated communication functions when using BMTrain.
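+
+After initialization, the helpers exported by `bmtrain` (for example `world_size`, `rank` and `print_rank`) can be used for rank-aware logging. A minimal sketch, assuming these helpers behave as their names suggest (e.g. `print_rank` prints on a single rank only):
+
+```python
+import bmtrain as bmt
+
+bmt.init_distributed(seed=0)
+bmt.print_rank("world size:", bmt.world_size()) # printed once instead of once per rank
+```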
+
+### Step 2: Enable ZeRO Optimization
+
+To enable ZeRO optimization, you need to make some simple replacements to the original model's code.
+
+* `torch.nn.Module` -> `bmtrain.DistributedModule`
+* `torch.nn.Parameter` -> `bmtrain.DistributedParameter`
+
+And wrap the transformer blocks with `bmtrain.Block`.
+
+Here is an example.
+
+**Original**
+
+```python
+import torch
+class MyModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.param = torch.nn.Parameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ SomeTransformerBlock(),
+ SomeTransformerBlock(),
+ SomeTransformerBlock()
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**Replaced**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule): # changed here
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024)) # changed here
+ self.module_list = torch.nn.ModuleList([
+ bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now
+ bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now
+ bmt.Block(SomeTransformerBlock(), zero_level=3) # changed here, support 2 and 3 now
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+### Step 3: Enable Communication Optimization
+
+
+To further reduce communication overhead and overlap communication with computation, `TransformerBlockList` can be used for optimization.
+
+You can enable them by making the following substitutions to the code:
+
+* `torch.nn.ModuleList` -> `bmtrain.TransformerBlockList`
+* `for module in self.module_list: x = module(x, ...)` -> `x = self.module_list(x, ...)`
+
+**Original**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = torch.nn.ModuleList([
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ for module in self.module_list:
+ x = module(x, 1, 2, 3)
+ return x
+
+```
+
+**Replaced**
+
+```python
+import torch
+import bmtrain as bmt
+class MyModule(bmt.DistributedModule):
+ def __init__(self):
+ super().__init__()
+ self.param = bmt.DistributedParameter(torch.empty(1024))
+ self.module_list = bmt.TransformerBlockList([ # changed here
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock()),
+ bmt.Block(SomeTransformerBlock())
+ ])
+
+ def forward(self):
+ x = self.param
+ x = self.module_list(x, 1, 2, 3) # changed here
+ return x
+
+```
+
+### Step 4: Launch Distributed Training
+
+BMTrain uses the same launch command as the distributed module of PyTorch.
+
+You can choose one of them depending on your version of PyTorch.
+
+* `${MASTER_ADDR}` means the IP address of the master node.
+* `${MASTER_PORT}` means the port of the master node.
+* `${NNODES}` means the total number of nodes.
+* `${GPU_PER_NODE}` means the number of GPUs per node.
+* `${NODE_RANK}` means the rank of this node.
+
+#### torch.distributed.launch
+```shell
+$ python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node ${GPU_PER_NODE} --nnodes ${NNODES} --node_rank ${NODE_RANK} train.py
+```
+
+#### torchrun
+
+```shell
+$ torchrun --nnodes=${NNODES} --nproc_per_node=${GPU_PER_NODE} --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=${MASTER_ADDR}:${MASTER_PORT} train.py
+```
+
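+For instance, a hypothetical single-node run on a machine with 8 GPUs (the port and script name are placeholders) could look like:
+
+```shell
+$ torchrun --nnodes=1 --nproc_per_node=8 --rdzv_id=1 --rdzv_backend=c10d --rdzv_endpoint=localhost:12345 train.py
+```
+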
+
+For more information, please refer to the [documentation](https://pytorch.org/docs/stable/distributed.html#launch-utility).
+
+## Example
+
+We provide an [example](https://github.com/OpenBMB/BMTrain/tree/main/example) of training GPT-2 based on BMTrain.
+The code mainly consists of the following parts.
+
+### Part 1: Model Definition
+
+```
+├── layers
+│ ├── attention.py
+│ ├── embedding.py
+│ ├── feedforward.py
+│ ├── __init__.py
+│ ├── layernorm.py
+│ └── linear.py
+└── models
+ ├── gpt.py
+ └── __init__.py
+```
+
+Above is the directory structure of the model definition part of the example.
+
+We define all the layers needed by GPT-2 and use BMTrain's `DistributedModule` and `DistributedParameter` to enable ZeRO optimization.
+
+### Part 2: BMTrain Initialization
+
+```python
+bmtrain.init_distributed(seed=0)
+
+model = GPT(
+ num_layers=8,
+ vocab_size=10240,
+ dim_model=2560,
+ dim_head=80,
+ num_heads=32,
+ dim_ff=8192,
+ max_distance=1024,
+ bias=True,
+ dtype=torch.half
+)
+
+bmtrain.init_parameters(model) # or loading checkpoint use `bmtrain.load`
+
+# ... other initialization (dataset) ...
+```
+
+`bmtrain.init_distributed(seed=0)` is used to initialize the distributed training environment and set the random seed for reproducibility.
+
+`bmtrain.init_parameters(model)` is used to initialize the distributed parameters of the model.
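+
+As a minimal sketch of the checkpoint path, assuming the `bmtrain.save`/`bmtrain.load` helpers exported by the package take a model and a file path:
+
+```python
+bmtrain.save(model, "checkpoint.pt") # assumed signature; implemented in bmtrain.store
+bmtrain.load(model, "checkpoint.pt") # restore the distributed parameters in place
+```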
+
+### Part 3: Initialization of the Optimizer and LR Scheduler
+
+```python
+loss_func = torch.nn.CrossEntropyLoss(ignore_index=-100)
+optimizer = bmtrain.optim.AdamOffloadOptimizer(model.parameters(), weight_decay=1e-2)
+lr_scheduler = bmtrain.lr_scheduler.Noam(optimizer, start_lr=1e-3, warmup_iter=40, end_iter=1000, num_iter=0)
+```
+
+BMTrain supports *all* the PyTorch native optimizers and loss functions, and you can also use the fused optimizer provided by BMTrain for mixed-precision training.
+
+In addition, BMTrain provides common learning rate schedulers in the `bmtrain.lr_scheduler` module.
+
+### Part 4: Training Loop
+
+```python
+# create a new instance of optimizer manager
+optim_manager = bmtrain.optim.OptimManager(loss_scale=1024)
+# let optim_manager handle all the optimizers and (optionally) their corresponding lr_schedulers
+optim_manager.add_optimizer(optimizer, lr_scheduler)
+# add_optimizer can be called multiple times to add other optimizers.
+
+for iteration in range(1000):
+ # ... load data for each rank ...
+
+ # forward pass and calculate loss
+ pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
+ logits = model(
+ enc_input,
+ pos,
+ pos < enc_length[:, None]
+ )
+ batch, seq_len, vocab_out_size = logits.size()
+
+ loss = loss_func(logits.view(batch * seq_len, vocab_out_size), targets.view(batch * seq_len))
+
+ global_loss = bmtrain.sum_loss(loss).item() # sum the loss across all ranks. This is only used for the training log
+
+ # zero grad
+ optim_manager.zero_grad() # calling zero_grad for each optimizer
+
+ # loss scale and backward
+ optim_manager.backward(loss)
+
+ # clip grad norm
+ grad_norm = optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=1.0)
+
+ # optimizer step
+ optim_manager.step()
+
+ # ... save checkpoint or print logs ...
+```
+
+The training loop is slightly longer, but it reads just like a normal training loop; you do not need to change much for distributed training.
+
+You can follow the comments in the code to see what each section is doing.
+
+The only additional note is `optim_manager`. After switching to BMTrain, some details of the optimizer workflow need to be adjusted. We have implemented all of those details in `optim_manager`: simply register every optimizer (and optionally its lr_scheduler) with `add_optimizer`, and let `optim_manager` perform `zero_grad()`, `backward()`, `clip_grad_norm()` and `step()` for you.
+
+If you are not using mixed-precision training, you can train without loss scaling: just pass `loss_scale=None` to the `OptimManager` constructor, which is also the default.
+
+If you are using mixed-precision training, *loss scaling* is a technique widely used to prevent gradient underflow. `optim_manager.backward(loss)` scales the `loss` before the backward pass; to enable it, set `loss_scale` to a floating-point number in the `OptimManager` constructor. The `loss_scale` is then adjusted adaptively based on the gradients during training.
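+
+A minimal sketch of the two configurations described above:
+
+```python
+# full precision: no loss scaling (the default)
+optim_manager = bmtrain.optim.OptimManager(loss_scale=None)
+
+# mixed precision: start from an initial scale; it is adapted during training
+optim_manager = bmtrain.optim.OptimManager(loss_scale=1024)
+```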
+
+
+
+## Performance
+
+We trained a GPT-2 model with 13B parameters using 4 servers, each with 8 V100 GPUs, and measured the per-GPU throughput during training (samples per GPU per second).
+
+Model structure:
+* 40 layers
+* 128 attention heads
+* 5120 hidden dimension
+* 512 sequence length
+
+
+| batch size | 8 | 16 | 24 | 32 |
+|-------------|-------|-------|:------|:------|
+| BMTrain | 24.15 | 26.94 | 29.42 | 28.28 |
+| ZeRO3(mp=1) | 14.88 | 21.69 | 24.38 | - |
+| ZeRO3(mp=4) | 15.51 | - | - | - |
+| ZeRO3(mp=8) | 15.51 | - | - | - |
+| ZeRO2(mp=1) | - | - | - | - |
+| ZeRO2(mp=4) | 22.85 | - | - | - |
+| ZeRO2(mp=8) | 21.33 | - | - | - |
+
+**ZeROa(mp=b)** means DeepSpeed + Megatron ZeRO stage a and model parallelism = b.
+
+**-** in the table means out of memory.
+
+## Supported Models
+
+We have migrated most of the common NLP models to BMTrain. You can find the list of supported models in the [ModelCenter](https://github.com/OpenBMB/ModelCenter) repo.
+
+## Community
+We welcome everyone to contribute code following our [contributing guidelines](https://github.com/OpenBMB/BMTrain/blob/master/CONTRIBUTING.md).
+
+You can also find us on other platforms:
+- QQ Group: 735930538
+- Website: https://www.openbmb.org
+- Weibo: http://weibo.cn/OpenBMB
+- Twitter: https://twitter.com/OpenBMB
+
+## License
+The package is released under the [Apache 2.0](https://github.com/OpenBMB/BMTrain/blob/master/LICENSE) License.
+
+## Other Notes
+
+`BMTrain` makes low-level changes to PyTorch, so if your program produces unexpected results, please report it in an issue.
+
diff --git a/examples/BMTrain/Release.txt b/examples/BMTrain/Release.txt
new file mode 100644
index 00000000..7c8a41be
--- /dev/null
+++ b/examples/BMTrain/Release.txt
@@ -0,0 +1,9 @@
+## What's Changed
+* Refactor the ZeRO, checkpoint, pipeline, and communication implementations using PyTorch's hook mechanism by @zkh2016 in #128 #159
+* Add Bf16 support by @Achazwl in #136
+* Tensor parallel implementation by @Achazwl @zkh2016 @MayDomine in #153
+* Async save state_dict by @zkh2016 in #171
+* `AdamOffloadOptimizer` can save whole gathered state by @MayDomine in #184
+* New tests for the new version of BMTrain by @Achazwl @JerryYin777 @MayDomine
+**Full Changelog**: https://github.com/OpenBMB/BMTrain/compare/0.2.3...1.0.0
+
diff --git a/examples/BMTrain/bmtrain/__init__.py b/examples/BMTrain/bmtrain/__init__.py
new file mode 100644
index 00000000..f4ac3642
--- /dev/null
+++ b/examples/BMTrain/bmtrain/__init__.py
@@ -0,0 +1,26 @@
+from .utils import print_block, print_dict, print_rank, see_memory, load_nccl_pypi
+try:
+ from . import nccl
+except:
+ load_nccl_pypi()
+from .global_var import config, world_size, rank
+from .init import init_distributed
+
+from .parameter import DistributedParameter, ParameterInitializer
+from .layer import DistributedModule
+from .param_init import init_parameters, grouped_parameters
+from .synchronize import synchronize, sum_loss, wait_loader, gather_result
+from .block_layer import Block, TransformerBlockList
+from .wrapper import BMTrainModelWrapper
+from .pipe_layer import PipelineTransformerBlockList
+from . import debug
+from .store import save, load
+
+from . import loss
+from . import distributed
+from . import nn
+from . import optim
+from . import inspect
+from . import lr_scheduler
+
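+# Backward-compatible alias: earlier releases and some examples refer to bmt.CheckpointBlock.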
+CheckpointBlock = Block
diff --git a/examples/BMTrain/bmtrain/benchmark/__init__.py b/examples/BMTrain/bmtrain/benchmark/__init__.py
new file mode 100644
index 00000000..571d621f
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/__init__.py
@@ -0,0 +1,3 @@
+from .all_gather import all_gather
+from .reduce_scatter import reduce_scatter
+from .send_recv import send_recv
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/benchmark/all_gather.py b/examples/BMTrain/bmtrain/benchmark/all_gather.py
new file mode 100644
index 00000000..b2f2ee7c
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/all_gather.py
@@ -0,0 +1,28 @@
+from .. import nccl
+from .shape import SHAPES
+from ..global_var import config
+from ..utils import round_up, print_rank
+from .utils import format_size
+import torch
+
+def all_gather():
+ current_stream = torch.cuda.current_stream()
+ for shape in SHAPES:
+ global_size = round_up(shape, config['world_size'] * 2)
+ partition_size = global_size // config['world_size']
+
+ partition_tensor = torch.empty( partition_size // 2, dtype=torch.half, device="cuda" )
+ global_tensor = torch.empty( global_size // 2, dtype=torch.half, device="cuda" )
+
+ start_evt = torch.cuda.Event(enable_timing=True)
+ end_evt = torch.cuda.Event(enable_timing=True)
+
+ current_stream.record_event(start_evt)
+ nccl.allGather(partition_tensor.storage(), global_tensor.storage(), config['comm'])
+ current_stream.record_event(end_evt)
+ current_stream.synchronize()
+ time_usage = start_evt.elapsed_time(end_evt)
+
+ bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage
+ print_rank("All gather:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw))
+
diff --git a/examples/BMTrain/bmtrain/benchmark/reduce_scatter.py b/examples/BMTrain/bmtrain/benchmark/reduce_scatter.py
new file mode 100644
index 00000000..75733556
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/reduce_scatter.py
@@ -0,0 +1,28 @@
+from .. import nccl
+from .shape import SHAPES
+from ..global_var import config
+from ..utils import round_up, print_rank
+from .utils import format_size
+import torch
+
+def reduce_scatter():
+ current_stream = torch.cuda.current_stream()
+ for shape in SHAPES:
+ global_size = round_up(shape, config['world_size'])
+ partition_size = global_size // config['world_size']
+
+ partition_tensor = torch.empty( partition_size // 2, dtype=torch.half, device="cuda" )
+ global_tensor = torch.empty( global_size // 2, dtype=torch.half, device="cuda" )
+
+ start_evt = torch.cuda.Event(enable_timing=True)
+ end_evt = torch.cuda.Event(enable_timing=True)
+
+ current_stream.record_event(start_evt)
+ nccl.reduceScatter(global_tensor.storage(), partition_tensor.storage(), 'avg', config['comm'])
+ current_stream.record_event(end_evt)
+ current_stream.synchronize()
+ time_usage = start_evt.elapsed_time(end_evt)
+
+ bw = global_size / 1024 / 1024 / 1024 * 1000 / time_usage
+ print_rank("Reduce Scatter:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(global_size), time_usage, bw))
+
diff --git a/examples/BMTrain/bmtrain/benchmark/send_recv.py b/examples/BMTrain/bmtrain/benchmark/send_recv.py
new file mode 100644
index 00000000..e3c971e4
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/send_recv.py
@@ -0,0 +1,31 @@
+from .. import nccl
+from .shape import SHAPES
+from ..global_var import config
+from ..utils import print_rank
+from .utils import format_size
+import torch
+def send_recv():
+ current_stream = torch.cuda.current_stream()
+ for shape in SHAPES:
+ send_size = shape
+
+ send_buffer = torch.empty( send_size // 2, dtype=torch.half, device="cuda" )
+ recv_buffer = torch.empty( send_size // 2, dtype=torch.half, device="cuda" )
+
+ start_evt = torch.cuda.Event(enable_timing=True)
+ end_evt = torch.cuda.Event(enable_timing=True)
+
+ current_stream.record_event(start_evt)
+ nccl.groupStart()
+ if config['rank'] in [0,2,4,6]:
+ nccl.send(send_buffer.storage(), config['rank']+1, config['comm'])
+ else:
+ nccl.recv(recv_buffer.storage(), config['rank']-1, config['comm'])
+ nccl.groupEnd()
+ current_stream.record_event(end_evt)
+ current_stream.synchronize()
+ time_usage = start_evt.elapsed_time(end_evt)
+
+ bw = shape / 1024 / 1024 / 1024 * 1000 / time_usage
+ print_rank("Send Recv:\tsize {}\ttime: {:4.3f}\tbw: {:2.6f} GB/s".format(format_size(send_size), time_usage, bw))
+
diff --git a/examples/BMTrain/bmtrain/benchmark/shape.py b/examples/BMTrain/bmtrain/benchmark/shape.py
new file mode 100644
index 00000000..0699e8cd
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/shape.py
@@ -0,0 +1,3 @@
+SHAPES = [
+ (2**i) for i in range(10, 33)
+]
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/benchmark/utils.py b/examples/BMTrain/bmtrain/benchmark/utils.py
new file mode 100644
index 00000000..dbc4a70c
--- /dev/null
+++ b/examples/BMTrain/bmtrain/benchmark/utils.py
@@ -0,0 +1,11 @@
+def format_size_(x):
+ if x < 1024:
+ return "{:d}B".format(x)
+ if x < 1024 * 1024:
+ return "{:4.2f}KB".format(x / 1024)
+ if x < 1024 * 1024 * 1024:
+ return "{:4.2f}MB".format(x / 1024 / 1024)
+ return "{:4.2f}GB".format(x / 1024 / 1024 / 1024)
+
+def format_size(x):
+ return "{:.6s}".format(format_size_(x))
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/block_layer.py b/examples/BMTrain/bmtrain/block_layer.py
new file mode 100644
index 00000000..216d77b2
--- /dev/null
+++ b/examples/BMTrain/bmtrain/block_layer.py
@@ -0,0 +1,726 @@
+from typing import Dict, Iterable, Iterator, Union, List
+
+from .utils import round_up, tp_split_tensor
+from .global_var import config
+import torch
+from . import nccl
+from .parameter import DistributedParameter, OpAllGather
+from .zero_context import ZeroContext
+from . import hook_func
+import inspect
+from torch.utils.checkpoint import checkpoint
+
+
+def storage_type_cuda(storage_type):
+ """Convert storage_type to cuda storage_type."""
+ STORAGE_MAP = {
+ torch.FloatStorage: torch.cuda.FloatStorage,
+ torch.DoubleStorage: torch.cuda.DoubleStorage,
+ torch.HalfStorage: torch.cuda.HalfStorage,
+ torch.BFloat16Storage: torch.cuda.BFloat16Storage,
+ torch.CharStorage: torch.cuda.CharStorage,
+ torch.ByteStorage: torch.cuda.ByteStorage,
+ torch.ShortStorage: torch.cuda.ShortStorage,
+ torch.IntStorage: torch.cuda.IntStorage,
+ torch.cuda.FloatStorage: torch.cuda.FloatStorage,
+ torch.cuda.DoubleStorage: torch.cuda.DoubleStorage,
+ torch.cuda.HalfStorage: torch.cuda.HalfStorage,
+ torch.cuda.BFloat16Storage: torch.cuda.BFloat16Storage,
+ torch.cuda.CharStorage: torch.cuda.CharStorage,
+ torch.cuda.ByteStorage: torch.cuda.ByteStorage,
+ torch.cuda.ShortStorage: torch.cuda.ShortStorage,
+ torch.cuda.IntStorage: torch.cuda.IntStorage,
+ }
+ if storage_type not in STORAGE_MAP:
+ raise ValueError("Unknown storage type: {}".format(storage_type))
+ return STORAGE_MAP[storage_type]
+
+
+def _get_param_kw(param: DistributedParameter):
+ """Get DistributedParameter kw name."""
+ type_name = str(param.dtype).split(".")[-1]
+ grad_name = "_grad" if param.requires_grad else "_nograd"
+ group_name = ""
+ if param.group is not None:
+ group_name = "_g_" + param.group
+ return type_name + grad_name + group_name
+
+
+class Block(torch.nn.Module):
+ """A block containing two memory-saving methods of ZeRO and checkpoint.
+ For details please refer to `ZeRO `_ and
+ `Checkpointing `_ .
+
+ Args:
+ inner_module (torch.nn.Module): The module whose memory usage should be reduced. All kinds of modules are supported.
+ use_checkpoint (bool): whether to use activation checkpointing. Default True.
+ zero_level (int): 2 (ZeRO-2) indicates that optimizer states and gradients are partitioned across processes;
+ 3 (ZeRO-3) additionally partitions the parameters on the basis of ZeRO-2. Default 3.
+ initialized (bool): whether the parameter storage has already been initialized. Default False.
+ mode (str): the mode should be "PIPE" when running in pipeline mode, otherwise mode="BLOCK". Default "BLOCK"
+
+ Examples:
+ >>> transformer_block = TransformerBlock(...)
+ >>> block = Block(transformer_block)
+ >>> y1, ... = block(x)
+ >>> y2, ... = transformer_block(x)
+ >>> assert torch.allclose(y1, y2)
+ """
+
+ def __init__(
+ self,
+ inner_module: torch.nn.Module,
+ use_checkpoint=True,
+ zero_level=3,
+ initialized=False,
+ mode="BLOCK",
+ ):
+ super().__init__()
+ self._module = inner_module
+ self._inputs = None
+ self._layer_dict = {}
+ self._forward_block_ctx = None
+ self._backward_block_ctx = None
+
+ self._param_info = []
+ self._storage_params: Dict[str, torch.nn.Parameter] = {}
+ self._storage_info = {}
+ self._ready = False
+
+ self._use_checkpoint = use_checkpoint
+ self._is_first_layer = True
+ self._is_last_layer = True
+ self._need_release = True
+ self._next_module = None # save the next module of self
+ self._pre_module = None # save the pre module of self
+ self._mode = mode # BLOCK or PIPE
+ self.all_input_no_grad = False
+ self.all_param_no_grad = False
+ self._zero_level = zero_level
+ if not initialized:
+ self.init_param_storage()
+
+ def reference(self, block):
+ """Make this block be a reference of the input Block."""
+ self._param_info = block._param_info
+ self._storage_params = block._storage_params
+ self._storage_info = block._storage_info
+ self._layer_dict = block._layer_dict
+ self._initialized = True
+ self._need_release = False
+
+ def init_param_storage(self):
+ """Init param storage."""
+ # sort parameters by name
+ ordered_parameters = list(self._module.named_parameters())
+
+ # calc total number of parameters
+ for name, param in ordered_parameters:
+ if not isinstance(param, DistributedParameter):
+ raise ValueError(
+ "All parameters in checkpoint block must be DistributedParameter."
+ )
+
+ storage_type = storage_type_cuda(param.storage_type())
+ kw_name = _get_param_kw(param)
+
+ if kw_name not in self._storage_info:
+ if self._mode == "PIPE" and param._tp_mode:
+ zero_comm = config["pp_tp_zero_comm"]
+ elif self._mode != "PIPE" and param._tp_mode:
+ zero_comm = config["tp_zero_comm"]
+ elif self._mode == "PIPE" and not param._tp_mode:
+ zero_comm = config["pp_zero_comm"]
+ else:
+ zero_comm = config["zero_comm"]
+
+ self._storage_info[kw_name] = {
+ "total": 0,
+ "storage_type": storage_type,
+ "requires_grad": param.requires_grad,
+ "group": param.group,
+ "zero_comm": zero_comm,
+ }
+
+ param_shape = param._original_shape
+
+ self._storage_info[kw_name]["total"] = round_up(
+ self._storage_info[kw_name]["total"] + param_shape.numel(),
+ 512 // param.element_size(),
+ # 512 bytes aligned
+ )
+
+ offsets = {}
+ # initialize storage buffers
+ for kw, val in self._storage_info.items():
+ comm = val["zero_comm"]
+ world_size = nccl.commCount(comm)
+ rank = nccl.commRank(comm)
+ val["world_size"] = world_size
+ partition_size = (
+ round_up(val["total"], val["world_size"]) // val["world_size"]
+ )
+ val["partition_size"] = partition_size
+ val["begin"] = rank * partition_size
+ val["end"] = (rank + 1) * partition_size
+ offsets[kw] = 0
+
+ storage_type = val["storage_type"]
+
+ storage_param_buffer = storage_type(partition_size)
+
+ dtype = storage_param_buffer.dtype
+ device = storage_param_buffer.device
+
+ # bind storage to buffer tensor
+ storage_param = torch.nn.Parameter(
+ torch.tensor([], dtype=dtype, device=device).set_(storage_param_buffer)
+ )
+ if val["requires_grad"]:
+ storage_param.requires_grad_(True)
+ else:
+ storage_param.requires_grad_(False)
+
+ self._storage_params[kw] = storage_param
+
+ # initialize parameters in module
+ for name, param in ordered_parameters:
+ param_shape = param._original_shape
+ kw_name = _get_param_kw(param)
+
+ param_st = offsets[kw_name]
+ offsets[kw_name] += param_shape.numel()
+ param_end = offsets[kw_name]
+ offsets[kw_name] = round_up(offsets[kw_name], 512 // param.element_size())
+
+ self._param_info.append(
+ {
+ "parameter": param,
+ "name": name,
+ "offset": param_st,
+ "size": param_shape.numel(),
+ "shape": param_shape,
+ "kw_name": kw_name,
+ }
+ )
+
+ # copy values to buffer for normal parameter
+ storage_st = self._storage_info[kw_name]["begin"]
+ storage_end = self._storage_info[kw_name]["end"]
+
+ # make parameter contiguous in storage
+ with torch.no_grad():
+ contiguous_param = OpAllGather.apply(param)
+
+ if not (param_st >= storage_end or param_end <= storage_st):
+ # copy offset in parameter storage
+ offset_st = max(storage_st - param_st, 0)
+ offset_end = min(storage_end - param_st, contiguous_param.numel())
+ assert offset_st < offset_end
+
+ # copy to offset in buffer storage
+ to_offset_st = offset_st + param_st - storage_st
+ to_offset_end = offset_end + param_st - storage_st
+
+ # copy to buffer
+ # PyTorch 1.11 changed the API of storage.__getitem__
+ d_dtype = self._storage_params[kw_name].dtype
+ d_device = self._storage_params[kw_name].device
+ param.data = torch.tensor(
+ [], dtype=param.dtype, device=param.device
+ ).set_(
+ self._storage_params[kw_name].storage(),
+ to_offset_st,
+ (to_offset_end - to_offset_st,),
+ )
+ self._param_info[-1]["begin"] = to_offset_st
+ self._param_info[-1]["end"] = (to_offset_end - to_offset_st,)
+ setattr(param, "_start_partition", offset_st)
+ setattr(param, "_end_partition", offset_end)
+ param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
+ contiguous_param.storage(), offset_st, (offset_end - offset_st,)
+ )[:]
+ del contiguous_param
+ else:
+ param.data = torch.tensor([], dtype=param.dtype, device=param.device)
+ setattr(param, "_start_partition", None)
+ setattr(param, "_end_partition", 0)
+ # clear parameter data, but keep the dtype and device
+ setattr(param, "_in_block", True)
+
+ for kw in offsets.keys():
+ assert offsets[kw] == self._storage_info[kw]["total"]
+
+ def set_pre_module(self, pre_module):
+ """Set pre module for current Block."""
+ if pre_module is not None:
+ self._pre_module = pre_module
+ pre_module._next_module = self
+
+ def pre_module(self):
+ """Return pre module of current Block."""
+ return self._pre_module if not self._is_first_layer else None
+
+ def next_module(self):
+ """Return next module of current Block."""
+ return self._next_module if not self._is_last_layer else None
+
+ def release_next_module(self, flag):
+ """Release next module of current Block."""
+ if self.next_module() is not None:
+ self.next_module().release(flag)
+
+ def release(self, flag):
+ """Release cuurent block ctx."""
+ if self._need_release and self._backward_block_ctx is not None:
+ self._backward_block_ctx.exit(flag, True)
+ config["load_stream"].record_event(config["load_event"])
+
+ def pre_hook(self, *args):
+ """Hook function before forward."""
+ grad_tensors = []
+ grad_index = []
+ arg_list = list(args)
+ for i, arg in enumerate(args):
+ if arg is not None and isinstance(arg, torch.Tensor) and arg.requires_grad:
+ grad_tensors.append(arg)
+ grad_index.append(i)
+ grad_tensors = tuple(grad_tensors)
+
+ pre_out = hook_func.PreHookFunc.apply(self, *grad_tensors)
+ for i in range(len(grad_index)):
+ arg_list[grad_index[i]] = pre_out[i]
+
+ if self._mode != "PIPE" and len(grad_tensors) == 0:
+ self.all_param_no_grad = True
+ for param in self._param_info:
+ if param["parameter"].requires_grad:
+ self.all_param_no_grad = False
+ break
+ self.all_input_no_grad = True
+ else:
+ self.all_input_no_grad = False
+ return arg_list
+
+ def post_hook(self, out):
+ """Hook function after forward."""
+ tuple_out = (out,) if isinstance(out, torch.Tensor) else out
+ post_out = hook_func.PostHookFunc.apply(self, *tuple_out)
+ if isinstance(out, torch.Tensor) and isinstance(post_out, tuple):
+ return post_out[0]
+ post_out = tuple(post_out)
+ return post_out
+
+ def forward(self, *args, **kwargs):
+ signature = inspect.signature(self._module.forward)
+ bound_args = signature.bind(*args, **kwargs)
+ args = bound_args.args
+ arg_list = self.pre_hook(*args)
+
+
+ if self.all_input_no_grad and not self.all_param_no_grad:
+ placeholder = torch.tensor([], requires_grad=torch.is_grad_enabled())
+ return hook_func.OneStepNoGradFunc.apply(self, placeholder, *arg_list)
+
+ if self._use_checkpoint:
+ out = checkpoint(
+ self._module, *arg_list, use_reentrant=not self.all_input_no_grad
+ )
+ else:
+ out = self._module(*arg_list)
+
+ return self.post_hook(out)
+
+ def __getattr__(self, name: str):
+ if name == "_module":
+ return self._module
+ return getattr(self._module, name)
+
+ def __setattr__(self, name, value):
+ object.__setattr__(self, name, value)
+
+ def __getattribute__(self, name: str):
+ if name == "_parameters":
+ return self._module._parameters
+ return super().__getattribute__(name)
+
+ def __delattr__(self, name):
+ object.__delattr__(self, name)
+
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
+ raise RuntimeError("._save_to_state_dict() of Block should not be called")
+
+ def state_dict(self, destination=None, prefix="", keep_vars=False):
+ # gather here
+ with torch.no_grad():
+ with ZeroContext(self):
+ return self._module.state_dict(
+ destination=destination, prefix=prefix, keep_vars=keep_vars
+ )
+
+ def _load_from_state_dict(
+ self,
+ state_dict,
+ prefix,
+ local_metadata,
+ strict,
+ missing_keys,
+ unexpected_keys,
+ error_msgs,
+ ):
+ all_keys = []
+ for it in self._param_info:
+ key = prefix + it["name"]
+ all_keys.append(key)
+ if key in state_dict:
+ # load here
+ input_param = state_dict[key]
+ param = it["parameter"]
+ tp_mode = param._tp_mode
+ if input_param.__class__.__name__ == "DistributedTensorWrapper":
+ input_param = input_param.broadcast()
+
+ verify_shape = torch.Size(
+ it["shape"] if not tp_mode else param._tp_original_shape
+ )
+ if input_param.shape != verify_shape:
+ error_msgs.append(
+ "size mismatch for {}: copying a param with shape {} from checkpoint, "
+ "the shape in current model is {}.".format(
+ key, input_param.shape, verify_shape
+ )
+ )
+ continue
+
+ param_st = it["offset"]
+ param_end = it["offset"] + it["size"]
+ kw_name = it["kw_name"]
+
+ # not in this partition
+ storage_st = self._storage_info[kw_name]["begin"]
+ storage_end = self._storage_info[kw_name]["end"]
+ if param_st >= storage_end:
+ continue
+ if param_end <= storage_st:
+ continue
+
+ # copy to buffer
+ verify_size = verify_shape.numel()
+ assert input_param.numel() == verify_size
+
+ contiguous_param = (
+ input_param.to(it["parameter"].dtype).cuda().contiguous()
+ )
+
+ tp_split_dim = param._tp_split_dim
+ if tp_mode and tp_split_dim >= 0:
+ contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim)
+
+ offset_st = max(storage_st - param_st, 0)
+ offset_end = min(storage_end - param_st, contiguous_param.numel())
+ assert offset_st < offset_end
+
+ to_offset_st = offset_st + param_st - storage_st
+ to_offset_end = offset_end + param_st - storage_st
+
+ # copy to buffer
+ # PyTorch 1.11 changed the API of storage.__getitem__
+ d_dtype = self._storage_params[kw_name].dtype
+ d_device = self._storage_params[kw_name].device
+ torch.tensor([], dtype=d_dtype, device=d_device).set_(
+ self._storage_params[kw_name].storage(),
+ to_offset_st,
+ (to_offset_end - to_offset_st,),
+ )[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
+ contiguous_param.storage(), offset_st, (offset_end - offset_st,)
+ )[
+ :
+ ]
+ del contiguous_param
+ elif strict:
+ missing_keys.append(key)
+
+ for name, param in self.named_parameters():
+ if isinstance(param, DistributedParameter) and not param._in_block:
+ key = prefix + name
+ all_keys.append(key)
+ if key in state_dict:
+ input_param = state_dict[key]
+ is_param_lazy = torch.nn.parameter.is_lazy(param)
+ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+ if (
+ not is_param_lazy
+ and len(param.shape) == 0
+ and len(input_param.shape) == 1
+ ):
+ input_param = input_param[0]
+
+ if (
+ not is_param_lazy
+ and not isinstance(param, DistributedParameter)
+ and input_param.shape != param.shape
+ ):
+ # local shape should match the one in checkpoint
+ error_msgs.append(
+ "size mismatch for {}: copying a param with shape {} from checkpoint, "
+ "the shape in current model is {}.".format(
+ key, input_param.shape, param.shape
+ )
+ )
+ continue
+ if (
+ not is_param_lazy
+ and isinstance(param, DistributedParameter)
+ and input_param.shape != param._original_shape
+ ):
+ error_msgs.append(
+ "size mismatch for {}: copying a param with shape {} from checkpoint, "
+ "the shape in current model is {}.".format(
+ key, input_param.shape, param.shape
+ )
+ )
+ try:
+ with torch.no_grad():
+ param._copy_data(input_param)
+ except Exception as ex:
+ error_msgs.append(
+ 'While copying the parameter named "{}", '
+ "whose dimensions in the model are {} and "
+ "whose dimensions in the checkpoint are {}, "
+ "an exception occurred : {}.".format(
+ key, param.size(), input_param.size(), ex.args
+ )
+ )
+ elif strict:
+ missing_keys.append(key)
+
+ if strict:
+ all_keys = set(all_keys)
+ for key in state_dict.keys():
+ if key.startswith(prefix) and key not in all_keys:
+ unexpected_keys.append(key)
+
+ def grouped_parameters(self):
+ """
+ Yield group params in storage params.
+ """
+ ret = {}
+ for kw, val in self._storage_info.items():
+ if val["group"] not in ret:
+ ret[val["group"]] = []
+ ret[val["group"]].append(self._storage_params[kw])
+ for kw, val in ret.items():
+ yield kw, val
+
+ def init_parameters(self):
+ """
+ Initialize distributed parameters in this block.
+ """
+ for it in self._param_info:
+ param = it["parameter"]
+ if (
+ isinstance(param, DistributedParameter)
+ and param._init_method is not None
+ ):
+                # initialize here
+ tmp_tensor = torch.empty(
+ param._tp_original_shape, device=param.device, dtype=param.dtype
+ )
+ param._init_method(tmp_tensor)
+ param_st = it["offset"]
+ param_end = it["offset"] + it["size"]
+ kw_name = it["kw_name"]
+
+ # not in this partition
+ storage_st = self._storage_info[kw_name]["begin"]
+ storage_end = self._storage_info[kw_name]["end"]
+ if param_st >= storage_end:
+ continue
+ if param_end <= storage_st:
+ continue
+
+ if param._tp_mode and param._tp_split_dim >= 0:
+ tmp_tensor = tp_split_tensor(tmp_tensor, param._tp_split_dim)
+ # copy to buffer
+ assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel()
+
+ offset_st = max(storage_st - param_st, 0)
+ offset_end = min(storage_end - param_st, tmp_tensor.numel())
+ assert offset_st < offset_end
+
+ # copy to buffer
+ # PyTorch 1.11 changed the API of storage.__getitem__
+ d_dtype = self._storage_params[kw_name].dtype
+ d_device = self._storage_params[kw_name].device
+ param.data[:] = torch.tensor([], dtype=d_dtype, device=d_device).set_(
+ tmp_tensor.storage(), offset_st, (offset_end - offset_st,)
+ )[:]
+ del tmp_tensor
+
+ def _named_members(self, get_members_fn, prefix="", recurse=True, **kwargs):
+ r"""Helper method for yielding various names + members of modules."""
+
+        # compatibility with torch 2.0
+ if (
+ "remove_duplicate"
+ in inspect.signature(torch.nn.Module._named_members).parameters
+ and "remove_duplicate" not in kwargs
+ ):
+ kwargs["remove_duplicate"] = True
+ return self._module._named_members(get_members_fn, prefix, recurse, **kwargs)
+
+ def named_modules(self, memo=None, prefix: str = "", remove_duplicate: bool = True):
+ r"""Returns an iterator over all modules in the network, yielding
+ both the name of the module as well as the module itself.
+
+ Args:
+ memo: a memo to store the set of modules already added to the result
+ prefix: a prefix that will be added to the name of the module
+ remove_duplicate: whether to remove the duplicated module instances in the result
+ or not
+
+ Yields:
+ (string, Module): Tuple of name and module
+
+ Note:
+ Duplicate modules are returned only once. In the following
+ example, ``l`` will be returned only once.
+
+ Example::
+
+ >>> l = nn.Linear(2, 2)
+ >>> net = nn.Sequential(l, l)
+ >>> for idx, m in enumerate(net.named_modules()):
+ print(idx, '->', m)
+
+ 0 -> ('', Sequential(
+ (0): Linear(in_features=2, out_features=2, bias=True)
+ (1): Linear(in_features=2, out_features=2, bias=True)
+ ))
+ 1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
+
+ """
+
+ if memo is None:
+ memo = set()
+ if self not in memo:
+ if remove_duplicate:
+ memo.add(self)
+ yield prefix, self
+ for name, module in self._module._modules.items():
+ if module is None:
+ continue
+ submodule_prefix = prefix + ("." if prefix else "") + name
+ for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
+ yield m
+
+ def named_children(self):
+ return self._module.named_children()
+
+ def train(self, mode: bool = True):
+ self._module.train(mode)
+
+ def eval(self):
+ self._module.eval()
+
+ def __repr__(self):
+ return self._module.__repr__()
+
+
+def _block_wrapper(module, module_dict: dict, mode="BLOCK"):
+ if not isinstance(module, Block):
+ in_block = id(module) in module_dict
+ new_module = Block(module, initialized=in_block, mode=mode)
+ if in_block:
+ new_module.reference(module_dict[id(module)])
+ else:
+ module_dict[id(module)] = new_module
+ else:
+ if mode == "PIPE" and module._mode != "PIPE":
+ assert (
+ False
+            ), 'You must set mode="PIPE" in bmt.Block when using PipelineTransformerBlockList!'
+ if id(module._module) in module_dict:
+ assert False, "Duplicate bmt.Block not supported in same block list!"
+ else:
+ new_module = module
+ module_dict[id(module._module)] = new_module
+ return new_module
+
+
+class TransformerBlockList(torch.nn.Module):
+ r"""
+ TransformerBlockList is a list of bmt.Block.
+
+    This is designed to reduce communication overhead by overlapping computation with the reduce_scatter operation during the backward pass.
+
+    It is similar to `torch.nn.ModuleList`, but differs in how .forward() and .backward() are invoked.
+
+ Example:
+ >>> module_list = [ ... ]
+ >>> normal_module_list = torch.nn.ModuleList(module_list)
+ >>> transformer_module_list = TransformerBlockList(module_list)
+ >>> # Calling normal module list
+ >>> for layer in normal_module_list:
+ >>> hidden_state = layer.forward(hidden_state, ...)
+ >>> # Calling transformer module list
+ >>> hidden_state = transformer_module_list(hidden_state, ...)
+
+ """
+
+ _modules: Dict[str, Block]
+
+ def __init__(self, modules: Iterable[Block], num_hidden=1) -> None:
+ super().__init__()
+
+ self._modules = {}
+ pre_module = None
+        module_dict = {}
+ for i, module in enumerate(modules):
+ module = _block_wrapper(module, module_dict)
+ module.set_pre_module(pre_module)
+ pre_module = module
+ module._is_first_layer = False
+ module._is_last_layer = False
+ self._modules[str(i)] = module
+ self.add_module(str(i), module)
+
+ self._modules[str(0)]._is_first_layer = True
+ self._modules[str(len(modules) - 1)]._is_last_layer = True
+
+ self.num_hidden = num_hidden
+
+ def __len__(self) -> int:
+ return len(self._modules)
+
+ def __iter__(self) -> Iterator[Block]:
+ return iter(self._modules.values())
+
+ def __getitem__(self, index: Union[int, str]) -> Block:
+ return self._modules[str(index)]
+
+ def forward(self, *args, return_hidden_states=False):
+ self.return_hidden_states = return_hidden_states
+ hidden_states = []
+ for i in range(len(self)):
+ if return_hidden_states:
+ for hidden_state in args[: self.num_hidden]:
+ hidden_states.append(hidden_state)
+ outputs = self._modules[str(i)]._call_impl(*args)
+ if not isinstance(outputs, tuple):
+ outputs = (outputs,)
+ args = outputs + args[self.num_hidden :]
+
+ if return_hidden_states:
+ hidden_states = [
+ torch.stack(hidden_states[i :: self.num_hidden], dim=0)
+ for i in range(self.num_hidden)
+ ]
+
+ if return_hidden_states:
+ return outputs + tuple(hidden_states)
+ else:
+ return (
+ tuple(outputs[: self.num_hidden]) if self.num_hidden > 1 else outputs[0]
+ )
diff --git a/examples/BMTrain/bmtrain/debug.py b/examples/BMTrain/bmtrain/debug.py
new file mode 100644
index 00000000..de392623
--- /dev/null
+++ b/examples/BMTrain/bmtrain/debug.py
@@ -0,0 +1,34 @@
+import torch
+
+DEBUG_VARS = {}
+
+def clear(key=None):
+ global DEBUG_VARS
+ if key is None:
+ DEBUG_VARS = {}
+ else:
+ DEBUG_VARS.pop(key, None)
+
+def set(key, value):
+ global DEBUG_VARS
+ if torch.is_tensor(value):
+ value = value.detach().cpu()
+ DEBUG_VARS[key] = value
+
+def get(key, default=None):
+ global DEBUG_VARS
+ if key in DEBUG_VARS:
+ return DEBUG_VARS[key]
+ return default
+
+def append(key, value):
+ global DEBUG_VARS
+ if key not in DEBUG_VARS:
+ DEBUG_VARS[key] = []
+ DEBUG_VARS[key].append(value)
+
+def extend(key, value):
+ global DEBUG_VARS
+ if key not in DEBUG_VARS:
+ DEBUG_VARS[key] = []
+ DEBUG_VARS[key].extend(value)
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/distributed/__init__.py b/examples/BMTrain/bmtrain/distributed/__init__.py
new file mode 100644
index 00000000..84a4adf8
--- /dev/null
+++ b/examples/BMTrain/bmtrain/distributed/__init__.py
@@ -0,0 +1 @@
+from .ops import all_gather, all_reduce, broadcast, recv_activations, send_activations, reduce_scatter
diff --git a/examples/BMTrain/bmtrain/distributed/ops.py b/examples/BMTrain/bmtrain/distributed/ops.py
new file mode 100644
index 00000000..d1b489e2
--- /dev/null
+++ b/examples/BMTrain/bmtrain/distributed/ops.py
@@ -0,0 +1,223 @@
+import torch
+from ..global_var import config
+from ..nccl import allGather as ncclAllGather, recv
+from ..nccl import allReduce as ncclAllReduce
+from ..nccl import broadcast as ncclBroadcast
+from ..nccl import reduceScatter as ncclReduceScatter
+from ..nccl import send as ncclSend
+from ..nccl import recv as ncclRecv
+from ..nccl import commCount,commRank,NCCLCommunicator
+DTYPE_LIST = [
+ torch.float64,
+ torch.float32,
+ torch.float16,
+ torch.int64,
+ torch.int32,
+ torch.int16,
+ torch.int8,
+ torch.bfloat16,
+ torch.bool
+]
+def send_activations(hidden_state, next_rank, comm):
+ send_meta(hidden_state, next_rank, comm)
+ ncclSend(hidden_state.storage(), next_rank, comm)
+
+def recv_activations(prev_rank, comm):
+ dtype, shape = recv_meta(prev_rank, comm)
+ hidden_state = torch.empty(shape, dtype=dtype, device="cuda")
+ ncclRecv(hidden_state.storage(), prev_rank, comm)
+ return hidden_state
+
+def send_meta(x, next_rank, comm):
+ meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int)
+ meta_data[0] = len(x.size())
+ meta_data[1] = DTYPE_LIST.index(x.dtype)
+ meta_data[2:len(x.size())+2] = torch.tensor(x.size(), device="cuda", dtype=torch.int)
+ meta_data = meta_data.contiguous()
+ ncclSend(meta_data.storage(), next_rank, comm)
+
+def recv_meta(prev_rank, comm):
+ meta_data = torch.tensor(data=[0]*50, device="cuda", dtype=torch.int)
+ ncclRecv(meta_data.storage(), prev_rank, comm)
+ n_dims = meta_data[0].item()
+ dtype = DTYPE_LIST[meta_data[1].item()]
+ shape = meta_data[2:n_dims+2].tolist()
+ return dtype,shape
+
+class OpBroadcast(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, src, root, comm = None):
+ if comm is None:
+ comm = config["comm"]
+ ctx.comm = comm
+ outputs = torch.empty_like(src, dtype = src.dtype, device = src.device)
+ ncclBroadcast(src.storage(), outputs.storage(), root, comm)
+ return outputs
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ res = all_reduce(grad_output, "sum", ctx.comm)
+ return res, None, None
+
+def broadcast(src, root, comm=None):
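+    """Broadcasts the tensor from the root rank to all other ranks.
+
+    Args:
+        src (torch.Tensor): The tensor to broadcast; the value on rank ``root`` is used.
+        root (int): The rank that provides the data.
+
+    Returns:
+        torch.Tensor: A tensor with the same shape as ``src`` holding the data from ``root``.
+
+    Example (illustrative sketch; assumes ``bmt.init_distributed()`` has already been called):
+        >>> x = torch.full((2,), float(bmt.rank()), device="cuda")
+        >>> y = bmt.distributed.broadcast(x, 0)  # every rank now holds rank 0's values
+    """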
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+ return OpBroadcast.apply(src, root, comm)
+
+class OpAllGather(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input : torch.Tensor, comm = None):
+ if comm is None:
+ comm = config["comm"]
+ world_size = commCount(comm)
+ if not input.is_contiguous():
+ input = input.contiguous()
+ if input.storage_offset() != 0 or input.storage().size() != input.numel():
+ input = input.clone()
+ output = torch.empty( (world_size,) + input.size(), dtype=input.dtype, device=input.device)
+ ctx.comm = comm
+ ncclAllGather(
+ input.storage(),
+ output.storage(),
+ comm
+ )
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ return grad_output[commRank(ctx.comm)], None
+
+def all_gather(x : torch.Tensor, comm = None):
+ """Gathers the input tensor from all processes.
+
+ Args:
+ x (torch.Tensor): The input tensor of shape (...).
+
+ Returns:
+ torch.Tensor: The gathered tensor of shape (world_size, ...).
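+
+    Example (illustrative sketch; assumes ``bmt.init_distributed()`` has already been called):
+        >>> x = torch.ones(2, 2, device="cuda") * bmt.rank()
+        >>> y = bmt.distributed.all_gather(x)  # y has shape (world_size, 2, 2)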
+ """
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ assert x.is_cuda
+ return OpAllGather.apply(x, comm)
+
+class OpReduceScatter(torch.autograd.Function):
+
+ @staticmethod
+ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None):
+ if comm is None:
+ comm = config["comm"]
+ ctx.comm = comm
+ rank = commRank(comm)
+ assert input.shape[0] % commCount(comm) == 0, "The dimension 0 must be divisible by the number of communication processes"
+ if not input.is_contiguous():
+ input = input.contiguous()
+ if input.storage_offset() != 0 or input.storage().size() != input.numel():
+ input = input.clone()
+ output_shape = (input.shape[0] // commCount(comm), *input.shape[1:])
+ output = torch.empty( output_shape, dtype=input.dtype, device=input.device )
+ ncclReduceScatter(
+ input.storage(),
+ output.storage(),
+ op,
+ comm
+ )
+ ctx.op = op
+ if op in ["sum", "avg"]:
+ pass
+ elif op in ["max", "min"]:
+ ctx.save_for_backward( output != input[rank * input.shape[0]:(rank + 1) * input.shape[0]] )
+ else:
+ ctx.save_for_backward( output / input[rank * input.shape[0]:(rank + 1) * input.shape[0]] )
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ with torch.no_grad():
+ grad_output = OpAllGather.apply(grad_output, ctx.comm).flatten(0,1)
+ if ctx.op in ["max", "min", "prod"]:
+ raise NotImplementedError("max min operation now do not support backward")
+ else:
+ if ctx.op == "avg":
+ grad_output /= commCount(ctx.comm)
+ return grad_output, None, None
+
+
+def reduce_scatter(x : torch.Tensor, op : str = "sum", comm = None):
+ """Reduces the input tensor from all processes.
+
+ Args:
+ x (torch.Tensor): The input tensor of shape (world_size, ...).
+ op (str): The reduction operation, one of "sum", "avg", "max", "min", "prod". Default: "sum".
+
+ Returns:
+ torch.Tensor: The reduced tensor of shape (...).
+
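+    Example (illustrative sketch; assumes ``bmt.init_distributed()`` has already been called):
+        >>> x = torch.ones(bmt.world_size(), 4, device="cuda")
+        >>> y = bmt.distributed.reduce_scatter(x, "sum")  # y has shape (1, 4), summed over all ranks
+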
+ """
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ assert x.is_cuda
+ return OpReduceScatter.apply(x, op, comm)
+
+class OpAllReduce(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input : torch.Tensor, op : str, comm : NCCLCommunicator = None):
+ if comm is None:
+ comm = config["comm"]
+ ctx.comm = comm
+ if not input.is_contiguous():
+ input = input.contiguous()
+ if input.storage_offset() != 0 or input.storage().size() != input.numel():
+ input = input.clone()
+ output = torch.empty( input.size(), dtype=input.dtype, device=input.device)
+
+ ncclAllReduce(
+ input.storage(),
+ output.storage(),
+ op,
+ comm
+ )
+ ctx.op = op
+
+ if op in ["sum", "avg"]:
+ pass
+ elif op in ["max", "min"]:
+ ctx.save_for_backward( input != output )
+ else:
+ ctx.save_for_backward( output / input )
+ return output
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ if ctx.op == "sum":
+ return grad_output, None, None
+ elif ctx.op == "avg":
+ return grad_output / commCount(ctx.comm), None, None
+ elif ctx.op in ["max", "min"]:
+ return torch.masked_fill(grad_output, ctx.saved_tensors[0], 0), None, None
+ else:
+ return grad_output * ctx.saved_tensors[0], None, None
+
+def all_reduce(x : torch.Tensor, op : str = "sum", comm = None):
+ """Reduces the input tensor from all processes.
+
+ Args:
+ x (torch.Tensor): The input tensor of shape (...).
+ op (str): The reduction operation, one of "sum", "avg", "max", "min", "prod". Default: "sum".
+
+ Returns:
+ torch.Tensor: The reduced tensor of shape (...).
+
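+    Example (illustrative sketch; assumes ``bmt.init_distributed()`` has already been called):
+        >>> x = torch.full((2,), float(bmt.rank()), device="cuda")
+        >>> y = bmt.distributed.all_reduce(x, "sum")  # same shape as x, summed over all ranks
+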
+ """
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ assert x.is_cuda
+ return OpAllReduce.apply(x, op, comm)
+
+
+
diff --git a/examples/BMTrain/bmtrain/global_var.py b/examples/BMTrain/bmtrain/global_var.py
new file mode 100644
index 00000000..137fa9cd
--- /dev/null
+++ b/examples/BMTrain/bmtrain/global_var.py
@@ -0,0 +1,35 @@
+import torch
+from typing_extensions import TypedDict
+class ConfigMap(TypedDict):
+ rank : int
+ local_rank : int
+ world_size : int
+ local_size : int
+ zero_level : int
+ pipe_size : int
+ num_micro_batches : int
+ calc_stream : torch.cuda.Stream
+ load_stream : torch.cuda.Stream
+ load_event : torch.cuda.Event
+ barrier_stream : torch.cuda.Stream
+ loss_scale_factor : float
+ loss_scale_steps : int
+ topology : 'topology'
+ gradient_inspect : bool
+ initialized : bool
+
+ comm : 'NCCLCommunicator'
+
+config = ConfigMap(rank=0, local_rank=0, world_size=1, initialized=False)
+
+def rank():
+ """
+ Returns the global rank of the current process. (0 ~ world_size-1)
+ """
+ return config['rank']
+
+def world_size():
+ """
+ Returns the total number of workers across all nodes.
+ """
+ return config['world_size']
diff --git a/examples/BMTrain/bmtrain/hook_func.py b/examples/BMTrain/bmtrain/hook_func.py
new file mode 100644
index 00000000..577331a2
--- /dev/null
+++ b/examples/BMTrain/bmtrain/hook_func.py
@@ -0,0 +1,121 @@
+import torch
+from .global_var import config
+from .zero_context import ZeroContext
+
+
+def zero_pre_forward(module, inputs):
+ """Helper function for using ZeroContext to gather parmas before forward."""
+ enter = True
+ pipe = False
+ if module._mode == "PIPE":
+ enter = module._micro_idx == 0
+ pipe = True
+ if enter:
+ zero_level = module._zero_level
+ forward_flag = 1 if zero_level == 2 else 0
+ if zero_level == 2 and not module._need_release:
+ forward_flag = 2 # repeating forward in same layer
+ if module.all_param_no_grad: # only forward
+ forward_flag = 0
+ module._forward_block_ctx = ZeroContext(module, module._layer_dict, pipe=pipe)
+ module._forward_block_ctx.enter(forward_flag)
+
+
+def zero_post_forward(module, inputs, outputs):
+ """Helper function for module _forwar_block_ctx weather exits after forward."""
+ forward_flag = 1 if module._zero_level == 2 else 0
+ if module.all_param_no_grad:
+ forward_flag = 0
+ exit = True
+ if module._mode == "PIPE":
+ exit = module._micro_idx == config["micros"] - 1
+
+ if exit:
+ module._forward_block_ctx.exit(forward_flag)
+
+
+def zero_pre_backward(module, grad_outputs):
+ """Helper function for using ZeroContext to init grad buffer before backward."""
+ backward_flag = 2 if module._zero_level == 2 else 0
+ if module._mode != "PIPE":
+ module._backward_block_ctx = ZeroContext(module, module._layer_dict)
+ module._backward_block_ctx.enter(backward_flag, True)
+ module.release_next_module(backward_flag)
+ else:
+ if module._micro_idx == config["micros"] - 1:
+ module._backward_block_ctx = ZeroContext(
+ module, module._layer_dict, pipe=True
+ )
+ module._backward_block_ctx.enter(backward_flag, True)
+
+
+def zero_post_backward(module, grad_inputs, grad_outputs):
+ """Helper function for module weather release after backward."""
+ backward_flag = 2 if module._zero_level == 2 else 0
+ if module._mode != "PIPE":
+ if module._is_first_layer:
+ module.release(backward_flag)
+ else:
+ if module._micro_idx == 0:
+ module.release(backward_flag)
+ module._micro_idx -= 1
+
+
+class OneStepNoGradFunc(torch.autograd.Function):
+ """
+    Forward/backward function used when none of the inputs requires grad but some parameters do.
+ """
+
+ @staticmethod
+ def forward(ctx, module, placeholder, *x):
+ ctx.x = x
+ ctx.module = module
+ ctx.rng_state = torch.cuda.get_rng_state()
+
+ with torch.no_grad():
+ out = module._module(*x)
+ zero_post_forward(module, None, out)
+ if not isinstance(out, torch.Tensor):
+ return tuple(out)
+ return out
+
+ @staticmethod
+ def backward(ctx, grads):
+ zero_pre_backward(ctx.module, grads)
+ with torch.random.fork_rng(devices=[torch.cuda.current_device()], enabled=True):
+ torch.cuda.set_rng_state(ctx.rng_state)
+ x = ctx.x
+ with torch.enable_grad():
+ out = ctx.module._module(*x)
+ torch.autograd.backward(out, grads)
+ zero_post_backward(ctx.module, grads, None)
+ grads = []
+ for _ in x:
+ grads.append(None)
+ return None, None, *grads
+
+
+class PreHookFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, module, *x):
+ ctx.module = module
+ zero_pre_forward(module, x)
+ return x
+
+ @staticmethod
+ def backward(ctx, *grads):
+ zero_post_backward(ctx.module, grads, None)
+ return None, *grads
+
+
+class PostHookFunc(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, module, *out):
+ ctx.module = module
+ zero_post_forward(module, None, out)
+ return out
+
+ @staticmethod
+ def backward(ctx, *grads):
+ zero_pre_backward(ctx.module, grads)
+ return None, *grads
diff --git a/examples/BMTrain/bmtrain/init.py b/examples/BMTrain/bmtrain/init.py
new file mode 100644
index 00000000..601d617e
--- /dev/null
+++ b/examples/BMTrain/bmtrain/init.py
@@ -0,0 +1,258 @@
+import datetime
+import torch
+import random
+import torch.distributed as dist
+import os
+from .utils import print_dict
+import ctypes
+from .global_var import config
+
+from . import nccl
+from .synchronize import synchronize
+
+
+def init_distributed(
+ init_method: str = "env://",
+ seed: int = 0,
+ pipe_size: int = -1,
+ num_micro_batches: int = None,
+ tp_size: int = 1,
+):
+ """Initialize distributed training.
+    This function initializes distributed training, sets the random seed, and fills in the global configuration.
+ It must be called before any other distributed functions.
+
+ Args:
+ seed (int): The random seed.
+        pipe_size (int) : all processes will be divided into pipe_size pipeline-parallel groups.
+        num_micro_batches (int) : the input batch will be divided into num_micro_batches micro-batches; only used in pipeline mode.
+        tp_size (int) : the size of each tensor parallel group.
+
+ **init_distributed** reads the following environment variables:
+
+    * `WORLD_SIZE`: The total number of GPUs in the distributed training.
+    * `RANK`: The global rank of the current GPU. From 0 to `WORLD_SIZE - 1`.
+    * `MASTER_ADDR`: The address of the master node.
+    * `MASTER_PORT`: The port of the master node.
+    * `LOCAL_RANK`: The local rank of the current GPU.
+
+    Normally, all of the environment variables above are set by the PyTorch distributed launcher.
+
+    **Note**: Do not use any functions in the torch.distributed package, including `torch.distributed.init_process_group`.
+
+    **Note**: If your training script is stuck here, it means some of your distributed workers are not connected to the master node.
+
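+    Example (illustrative sketch; assumes the script is launched with a distributed launcher
+    such as ``torchrun`` so that the environment variables above are set):
+        >>> import bmtrain as bmt
+        >>> bmt.init_distributed(seed=0)
+        >>> bmt.print_rank("Initialization complete")
+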
+ """
+ torch.backends.cudnn.enabled = False
+
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
+ rank = int(os.environ.get("RANK", "0"))
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
+ local_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))
+ if "MASTER_ADDR" not in os.environ:
+ os.environ["MASTER_ADDR"] = "localhost"
+ if "MASTER_PORT" not in os.environ:
+ os.environ["MASTER_PORT"] = "10010"
+ addr = os.environ["MASTER_ADDR"]
+ port = os.environ["MASTER_PORT"]
+ master = addr + ":" + port
+ timeout = datetime.timedelta(seconds=1800)
+ rendezvous_iterator = dist.rendezvous(
+ init_method, rank, world_size, timeout=timeout
+ )
+
+ store, rank, world_size = next(rendezvous_iterator)
+ store.set_timeout(timeout)
+ store = dist.PrefixStore("bmtrain", store)
+ torch.cuda.set_device(local_rank)
+ config["initialized"] = True
+ config["pipe_size"] = pipe_size if pipe_size > 0 else 1
+ config["pipe_enabled"] = pipe_size > 0
+ config["local_rank"] = local_rank
+ config["local_size"] = local_size
+ config["rank"] = rank
+ config["world_size"] = world_size
+ config["calc_stream"] = torch.cuda.current_stream()
+ config["load_stream"] = torch.cuda.Stream(priority=-1)
+ config["tp_comm_stream"] = torch.cuda.Stream(priority=-1)
+ config["pp_comm_stream"] = torch.cuda.Stream(priority=-1)
+ config["barrier_stream"] = torch.cuda.Stream()
+ config["load_event"] = torch.cuda.Event()
+ config["tp_size"] = tp_size if tp_size > 0 else 1
+ config["topology"] = topology(config)
+ config["zero_rank"] = config["topology"].get_group_rank("zero")
+ config["tp_rank"] = config["topology"].get_group_rank("tp")
+ config["tp_zero_rank"] = config["topology"].get_group_rank("tp_zero")
+ config["save_param_to_cpu"] = True
+ cpus_this_worker = None
+
+ all_available_cpus = sorted(list(os.sched_getaffinity(0)))
+
+ cpus_per_worker = len(all_available_cpus) // local_size
+
+ if cpus_per_worker < 1:
+ cpus_this_worker = all_available_cpus
+ torch.set_num_threads(1)
+ else:
+ cpus_this_worker = all_available_cpus[
+ local_rank * cpus_per_worker : (local_rank + 1) * cpus_per_worker
+ ]
+ os.sched_setaffinity(0, cpus_this_worker)
+ torch.set_num_threads(len(cpus_this_worker))
+
+ torch.manual_seed(seed)
+ random.seed(seed)
+ try:
+ import numpy as np
+
+ np.random.seed(seed)
+ except ModuleNotFoundError:
+ pass
+
+ if rank == 0:
+ unique_id: bytes = nccl.getUniqueId()
+ store.set("BMTRAIN_UNIQUE_ID", unique_id.hex())
+
+ unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode())
+ config["comm"] = nccl.commInitRank(unique_id, world_size, rank)
+ topo = config["topology"]
+
+ if config["pipe_enabled"]:
+ config["micros"] = (
+ num_micro_batches if num_micro_batches else config["pipe_size"]
+ )
+ if topo.stage_id == 0:
+ unique_id = nccl.getUniqueId()
+ store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex())
+ unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode())
+ config["pipe_comm"] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id)
+
+ if topo.pp_zero_id == 0:
+ unique_id = nccl.getUniqueId()
+ store.set(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}", unique_id.hex())
+ unique_id = bytes.fromhex(
+ store.get(f"PP_ZERO_UNIQUE_ID{topo.pp_zero_idx}").decode()
+ )
+ config["pp_zero_comm"] = nccl.commInitRank(
+ unique_id, world_size // config["pipe_size"], topo.pp_zero_id
+ )
+
+ if config["tp_size"] > 1:
+ if topo.tp_id == 0:
+ unique_id = nccl.getUniqueId()
+ store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex())
+ unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode())
+ config["tp_comm"] = nccl.commInitRank(unique_id, tp_size, topo.tp_id)
+
+ if topo.tp_zero_id == 0:
+ unique_id = nccl.getUniqueId()
+ store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex())
+ unique_id = bytes.fromhex(
+ store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode()
+ )
+ config["tp_zero_comm"] = nccl.commInitRank(
+ unique_id, world_size // config["tp_size"], topo.tp_zero_id
+ )
+
+ if config["pipe_size"] > 1 and config["tp_size"] > 1:
+ if topo.pp_tp_zero_id == 0:
+ unique_id = nccl.getUniqueId()
+ store.set(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}", unique_id.hex())
+ unique_id = bytes.fromhex(
+ store.get(f"PP_TP_ZERO_UNIQUE_ID{topo.pp_tp_zero_idx}").decode()
+ )
+ config["pp_tp_zero_comm"] = nccl.commInitRank(
+ unique_id,
+ world_size // (config["pipe_size"] * config["tp_size"]),
+ topo.pp_tp_zero_id,
+ )
+
+ config["zero_comm"] = config["comm"]
+
+ for i in range(world_size):
+ if i == rank:
+ print_dict(
+ "Initialization",
+ {
+ "rank": rank,
+ "local_rank": local_rank,
+ "world_size": world_size,
+ "local_size": local_size,
+ "master": master,
+ "device": torch.cuda.current_device(),
+ "cpus": cpus_this_worker,
+ },
+ )
+ synchronize()
+
+
+class topology:
+ """A helper class to keep parallel information when using different parallel methods together."""
+
+ def __init__(self, config):
+ # pipe_idx is the idx of the pipeline in the group
+ self.rank = config["rank"]
+ pp_size = config["pipe_size"]
+ tp_size = config["tp_size"]
+ world_size = config["world_size"]
+ assert (
+ world_size % (pp_size * tp_size) == 0
+ ), "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size"
+
+ dp_size = world_size // (pp_size * tp_size)
+ config["tp_zero_size"] = dp_size
+ config["zero_size"] = world_size // pp_size
+ self.stages = config["pipe_size"]
+
+ stage_size = world_size // pp_size
+ for i in range(world_size):
+ self.pipe_idx = self.rank % stage_size
+ self.stage_id = self.rank // stage_size
+ self.tp_id = self.rank % tp_size
+ self.tp_idx = self.rank // tp_size
+ # pp->zero
+ self.pp_zero_idx = self.stage_id
+ self.pp_zero_id = self.pipe_idx
+ # tp->zero
+ self.tp_zero_idx = self.tp_id
+ self.tp_zero_id = self.tp_idx
+ # pp->tp->zero
+ self.pp_tp_zero_idx = self.stage_id * tp_size + self.tp_id
+ self.pp_tp_zero_id = self.pipe_idx // tp_size
+ # only zero
+ self.zero_idx = 0
+ self.zero_id = self.rank
+
+ def get_group_id(self, group_name):
+ """Get group id of different parallel group.
+
+ Args:
+ group_name (str): must be one of "pipe", "zero", "tp_zero" or "tp".
+ """
+ if group_name == "pipe":
+ return self.pipe_idx
+ elif group_name == "zero":
+ return self.zero_idx
+ elif group_name == "tp_zero":
+ return self.tp_zero_idx
+ elif group_name == "tp":
+ return self.tp_idx
+
+ def get_group_rank(self, group_name):
+ """Get group rank of different parallel group.
+
+ Args:
+ group_name (str): must be one of "pipe", "zero", "tp_zero" or "tp".
+ """
+ if group_name == "pipe":
+ return self.stage_id
+ elif group_name == "zero":
+ return self.zero_id
+ elif group_name == "tp_zero":
+ return self.tp_zero_id
+ elif group_name == "tp":
+ return self.tp_id
+
+
+def is_initialized() -> bool:
+ return config["initialized"]
diff --git a/examples/BMTrain/bmtrain/inspect/__init__.py b/examples/BMTrain/bmtrain/inspect/__init__.py
new file mode 100644
index 00000000..2b6d2d26
--- /dev/null
+++ b/examples/BMTrain/bmtrain/inspect/__init__.py
@@ -0,0 +1,3 @@
+from .format import format_summary
+from .model import inspect_model
+from .tensor import inspect_tensor, record_tensor
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/inspect/format.py b/examples/BMTrain/bmtrain/inspect/format.py
new file mode 100644
index 00000000..79b3a1e5
--- /dev/null
+++ b/examples/BMTrain/bmtrain/inspect/format.py
@@ -0,0 +1,64 @@
+from typing import Any, Dict, List
+
+def align_str(s : str, align : int, left : bool) -> str:
+ if left:
+ return s + " " * (align - len(s))
+ else:
+ return " " * (align - len(s)) + s
+
+def format_line(strs : List[str], length : List[int]):
+ ret = ""
+ for v, l in zip(strs, length):
+ if len(v) + 1 > l:
+ v = " " + v[:l - 1]
+ else:
+ v = " " + v
+ ret += align_str(v, l, True)
+ return ret
+
+def item_formater(x) -> str:
+ if isinstance(x, float):
+ return "{:.4f}".format(x)
+ else:
+ return str(x)
+
+def format_summary(summary : List[Dict[str, Any]]) -> str:
+ """Format summary to string.
+
+ Args:
+ summary (List[Dict[str, Any]]): The summary to format.
+
+ Returns:
+ str: The formatted summary.
+
+ """
+ ret = []
+
+ max_name_len = max([len("name")] + [len(item["name"]) for item in summary]) + 4
+ headers = [
+ "name",
+ "shape",
+ "max",
+ "min",
+ "std",
+ "mean",
+ "grad_std",
+ "grad_mean",
+ ]
+ headers_length = [
+ max_name_len,
+ 20,
+ 10,
+ 10,
+ 10,
+ 10,
+ 10,
+ 10
+ ]
+ ret.append( format_line(headers, headers_length) )
+ for item in summary:
+ values = [ item_formater(item[name]) for name in headers ]
+ ret.append( format_line(values, headers_length) )
+ return "\n".join(ret)
+
+
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/inspect/model.py b/examples/BMTrain/bmtrain/inspect/model.py
new file mode 100644
index 00000000..fc54f0d6
--- /dev/null
+++ b/examples/BMTrain/bmtrain/inspect/model.py
@@ -0,0 +1,246 @@
+import torch
+from ..store import broadcast_object
+from ..pipe_layer import PipelineTransformerBlockList
+from ..block_layer import Block
+from ..parameter import DistributedParameter
+from .. import nccl
+from ..global_var import config
+import fnmatch
+
+def _gather_value(value : torch.Tensor, partition_size, origin_size):
+ global_size = partition_size * config['world_size']
+
+ storage = value.storage_type()(global_size)
+
+ if value.storage().size() != partition_size:
+ tmp_buf = torch.zeros(partition_size, dtype=value.dtype, device=value.device)
+ tmp_buf[:value.numel()] = value[:]
+ nccl.allGather(
+ tmp_buf.storage(),
+ storage,
+ config['comm']
+ )
+ else:
+ nccl.allGather(
+ value.storage(),
+ storage,
+ config['comm']
+ )
+
+ output_tensor = torch.tensor([], dtype=value.dtype, device="cuda")
+ output_tensor.set_(storage, 0, origin_size)
+
+ return output_tensor
+
+def inspect_pipeline_transformer_block_list(pipe_model: PipelineTransformerBlockList, param_name : str, _prefix : str = ''):
+ ret = []
+ for name, model in pipe_model._modules.items():
+ idx = int(name)
+ prefix = _prefix + name + '.'
+
+ # fast check
+ pass_fast_check = False
+ for param in model._param_info:
+ abs_name = prefix + param["name"]
+ if fnmatch.fnmatch(abs_name, param_name):
+ pass_fast_check = True
+ break
+ if not pass_fast_check:
+ continue
+
+ if idx in pipe_model.layer_ids:
+ _param_buffer = {}
+ _grad_buffer = {}
+ for kw, val in model._storage_info.items():
+ storage_type = model._storage_params[kw].storage_type()
+
+ _param_buffer[kw] = storage_type(val["partition_size"] * val['world_size'])
+ if model._storage_params[kw].grad is not None:
+ _grad_buffer[kw] = storage_type(val["partition_size"] * val['world_size'])
+
+ nccl.groupStart()
+ for kw, val in model._storage_info.items():
+ nccl.allGather(
+ model._storage_params[kw].storage(),
+ _param_buffer[kw],
+ val["zero_comm"]
+ )
+ if model._storage_params[kw].grad is not None:
+ nccl.allGather(
+ model._storage_params[kw].grad.storage(),
+ _grad_buffer[kw],
+ val["zero_comm"]
+ )
+
+ nccl.groupEnd()
+ for param in model._param_info:
+ abs_name = prefix + param["name"]
+ if fnmatch.fnmatch(abs_name, param_name):
+ kw_name = param["kw_name"]
+ dtype = _param_buffer[kw_name].dtype
+ device = _param_buffer[kw_name].device
+ offset = param["offset"]
+ shape = param["shape"]
+ p = torch.tensor([], dtype=dtype, device=device).set_(_param_buffer[kw_name], offset, shape)
+ if kw_name in _grad_buffer:
+ g = torch.tensor([], dtype=dtype, device=device).set_(_grad_buffer[kw_name], offset, shape)
+ info = {
+ "name": abs_name,
+ "shape": tuple(shape),
+ "std": p.std().cpu().item(),
+ "mean": p.mean().cpu().item(),
+ "grad_std": g.std().cpu().item(),
+ "grad_mean": g.mean().cpu().item(),
+ "max": p.max().cpu().item(),
+ "min": p.min().cpu().item(),
+ }
+ else:
+ info = {
+ "name": abs_name,
+ "shape": tuple(shape),
+ "std": p.std().cpu().item(),
+ "mean": p.mean().cpu().item(),
+ "grad_std": 0.,
+ "grad_mean": 0.,
+ "max": p.max().cpu().item(),
+ "min": p.min().cpu().item(),
+ }
+ broadcast_object(info, config["pipe_comm"], pipe_model.get_stage_by_layer_id(idx))
+ ret.append(info)
+ else:
+ for param in model._param_info:
+ abs_name = prefix + param["name"]
+ if fnmatch.fnmatch(abs_name, param_name):
+ info = broadcast_object({}, config["pipe_comm"], pipe_model.get_stage_by_layer_id(idx))
+ ret.append(info)
+
+ return ret
+
+
+def inspect_block(model : Block, param_name : str, prefix : str = ''):
+ # fast check
+ pass_fast_check = False
+ for param in model._param_info:
+ abs_name = prefix + param["name"]
+ if fnmatch.fnmatch(abs_name, param_name):
+ pass_fast_check = True
+ break
+ if not pass_fast_check:
+ return []
+
+ _param_buffer = {}
+ _grad_buffer = {}
+ for kw, val in model._storage_info.items():
+ storage_type = model._storage_params[kw].storage_type()
+
+ _param_buffer[kw] = storage_type(val["partition_size"] * config['world_size'])
+ if model._storage_params[kw].grad is not None:
+ _grad_buffer[kw] = storage_type(val["partition_size"] * config['world_size'])
+
+ nccl.groupStart()
+ for kw, val in model._storage_info.items():
+ nccl.allGather(
+ model._storage_params[kw].storage(),
+ _param_buffer[kw],
+ config["comm"]
+ )
+ if model._storage_params[kw].grad is not None:
+ nccl.allGather(
+ model._storage_params[kw].grad.storage(),
+ _grad_buffer[kw],
+ config["comm"]
+ )
+
+ nccl.groupEnd()
+ ret = []
+ for param in model._param_info:
+ abs_name = prefix + param["name"]
+ if fnmatch.fnmatch(abs_name, param_name):
+ kw_name = param["kw_name"]
+ dtype = _param_buffer[kw_name].dtype
+ device = _param_buffer[kw_name].device
+ offset = param["offset"]
+ shape = param["shape"]
+ p = torch.tensor([], dtype=dtype, device=device).set_(_param_buffer[kw_name], offset, shape)
+ if kw_name in _grad_buffer:
+ g = torch.tensor([], dtype=dtype, device=device).set_(_grad_buffer[kw_name], offset, shape)
+ ret.append({
+ "name": abs_name,
+ "shape": tuple(shape),
+ "std": p.std().cpu().item(),
+ "mean": p.mean().cpu().item(),
+ "grad_std": g.std().cpu().item(),
+ "grad_mean": g.mean().cpu().item(),
+ "max": p.max().cpu().item(),
+ "min": p.min().cpu().item(),
+ })
+ else:
+ ret.append({
+ "name": abs_name,
+ "shape": tuple(shape),
+ "std": p.std().cpu().item(),
+ "mean": p.mean().cpu().item(),
+ "grad_std": 0.,
+ "grad_mean": 0.,
+ "max": p.max().cpu().item(),
+ "min": p.min().cpu().item(),
+ })
+ return ret
+
+@torch.no_grad()
+def inspect_model(model : torch.nn.Module, param_name : str, prefix : str = ''):
+ """Inspect the model and return the summary of the parameters.
+
+ Args:
+ model (torch.nn.Module): The model to be inspected.
+ param_name (str): The name of the parameter to be inspected. The wildcard '*' can be used to match multiple parameters.
+ prefix (str): The prefix of the parameter name.
+
+ Returns:
+ list: The summary of the parameters.
+
+ Example:
+ >>> result_linear = bmt.inspect.inspect_model(model, "*.linear*")
+ >>> result_layernorm = bmt.inspect.inspect_model(model, "*.layernorm*")
+        >>> text_summary = bmt.inspect.format_summary(result_linear + result_layernorm)
+ >>> bmt.print_rank(text_summary)
+ name shape max min std mean grad_std grad_mean
+ ...
+
+ """
+ if isinstance(model, PipelineTransformerBlockList):
+ return inspect_pipeline_transformer_block_list(model, param_name, prefix)
+ elif isinstance(model, Block):
+ return inspect_block(model, param_name, prefix)
+ else:
+ ret = []
+ for name, param in model._parameters.items():
+ if fnmatch.fnmatch(prefix + name, param_name):
+ if isinstance(param, DistributedParameter):
+ p = _gather_value(param.data, param.storage().size(), param._original_shape)
+ else:
+ p = param
+ if p is None:
+ continue
+ stats = {
+ 'name': prefix + name,
+ 'shape': tuple(p.size()),
+ "std": p.std().cpu().item(),
+ "mean": p.mean().cpu().item(),
+ "max": p.max().cpu().item(),
+ "min": p.min().cpu().item(),
+ }
+ if param.grad is not None:
+ if isinstance(param, DistributedParameter):
+ g = _gather_value(param.grad.data, param.storage().size(), param._original_shape)
+ else:
+ g = param.grad
+ stats["grad_std"] = g.std().cpu().item()
+ stats["grad_mean"] = g.mean().cpu().item()
+ else:
+ stats["grad_std"] = 0.
+ stats["grad_mean"] = 0.
+ ret.append(stats)
+ for name, module in model._modules.items():
+ ret.extend(inspect_model(module, param_name, prefix + name + '.'))
+ return ret
diff --git a/examples/BMTrain/bmtrain/inspect/tensor.py b/examples/BMTrain/bmtrain/inspect/tensor.py
new file mode 100644
index 00000000..9d003f82
--- /dev/null
+++ b/examples/BMTrain/bmtrain/inspect/tensor.py
@@ -0,0 +1,383 @@
+from typing import Optional
+import torch
+from .. import debug
+from .. import nccl
+from ..global_var import config
+from ..store import broadcast_object
+from ..distributed import broadcast
+import math
+
+
+class InspectTensor:
+ """This object is returned by `InspectTensorManager`.
+
+ You can get the tensors recorded by `record_tensor`.
+
+ """
+
+ def __init__(self):
+ self.summary = []
+
+ def _set_summary(self, summary):
+ self._summary = summary
+ for item in summary:
+ item["prefix"] = "" if item["group"] is None else f'{item["group"]}.'
+
+ self.summary = []
+
+ kw_cnt = {}
+ i = 0
+ while i < len(summary):
+ item = summary[i]
+ if item["inside_pipe"] is not None:
+ before_len = len(self.summary)
+
+ assert item["inside_pipe"]["st"]
+ pipe_cnt = {}
+ j = i
+ while j < len(summary):
+ item = summary[j]
+ kw = f'{item["prefix"]}{item["name"]}'
+
+ assert item["inside_pipe"] is not None
+ stage_id = item["inside_pipe"]["stage_id"]
+ stages = item["inside_pipe"]["stages"]
+ st = item["inside_pipe"]["st"]
+ ed = item["inside_pipe"]["ed"]
+
+ if kw not in pipe_cnt:
+ pipe_cnt[kw] = 0
+ pipe_cnt[kw] += 1
+
+ j += 1
+ if ed:
+ break
+
+ for stage in range(stages):
+ if stage_id == stage:
+ broadcast_object(pipe_cnt, config["pipe_comm"], src=stage)
+ for k in range(i, j):
+ item = summary[k]
+ kw = f'{item["prefix"]}{item["name"]}'
+ if kw not in kw_cnt:
+ kw_cnt[kw] = 0
+ tensor = torch.cat(
+ [
+ summary[k + m * (j - i)]["tensor"]
+ for m in range(config["micros"])
+ ],
+ dim=0,
+ )
+ grad = (
+ torch.cat(
+ [
+ summary[k + m * (j - i)]["tensor"].grad
+ for m in range(config["micros"])
+ ],
+ dim=0,
+ )
+ if item["requires_grad"]
+ and item["tensor"].grad is not None
+ else None
+ )
+ self.summary.append(
+ {
+ "name": item["name"],
+ "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}',
+ "group": item["group"],
+ "min": None,
+ "max": None,
+ "mean": None,
+ "std": None,
+ "shape": (item["shape"][0] * config["micros"],)
+ + item["shape"][1:],
+ "grad_mean": None,
+ "grad_std": None,
+ "tensor": tensor,
+ "grad": grad,
+ "requires_grad": item["requires_grad"],
+ "inside_pipe": {"stage_id": stage},
+ }
+ )
+ kw_cnt[kw] += 1
+ else:
+ cnt = broadcast_object({}, config["pipe_comm"], src=stage)
+ for kw, val in cnt.items():
+ if kw not in kw_cnt:
+ kw_cnt[kw] = 0
+ for _ in range(val):
+ self.summary.append(
+ {
+ "name": item["name"],
+ "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}',
+ "group": None,
+ "min": None,
+ "max": None,
+ "mean": None,
+ "std": None,
+ "shape": None,
+ "grad_mean": None,
+ "grad_std": None,
+ "tensor": None,
+ "grad": None,
+ "requires_grad": None,
+ "inside_pipe": {"stage_id": stage},
+ }
+ )
+ kw_cnt[kw] += 1
+
+ after_len = len(self.summary)
+ with torch.enable_grad():
+ for it in self.summary[before_len:after_len]:
+ if it["tensor"] is not None:
+ has_grad = it["grad"] is not None
+ info = {
+ "group": it["group"],
+ "shape": it["shape"],
+ "requires_grad": it["requires_grad"],
+ "has_grad": has_grad,
+ }
+ broadcast_object(
+ info,
+ config["pipe_comm"],
+ src=it["inside_pipe"]["stage_id"],
+ )
+ tensor = it["tensor"]
+ tensor = broadcast(
+ tensor,
+ it["inside_pipe"]["stage_id"],
+ config["pipe_comm"],
+ )
+ grad = it["grad"]
+ else:
+ info = broadcast_object(
+ {},
+ config["pipe_comm"],
+ src=it["inside_pipe"]["stage_id"],
+ )
+ has_grad = info.pop("has_grad")
+ it.update(info)
+ tensor = torch.empty(it["shape"]).cuda().requires_grad_()
+ tensor = broadcast(
+ tensor,
+ it["inside_pipe"]["stage_id"],
+ config["pipe_comm"],
+ )
+ if has_grad:
+ grad = torch.empty(it["shape"]).cuda()
+ tensor = tensor.chunk(stages, dim=0)[stage_id].clone()
+ it["tensor"] = tensor
+ if has_grad:
+ grad = broadcast(
+ grad, it["inside_pipe"]["stage_id"], config["pipe_comm"]
+ )
+ grad = grad.chunk(stages, dim=0)[stage_id].clone()
+ tensor.grad = grad
+ it["shape"] = (it["shape"][0] // config["pipe_size"],) + it[
+ "shape"
+ ][1:]
+
+ i = i + config["micros"] * (j - i)
+ else:
+ kw = f'{item["prefix"]}{item["name"]}'
+ if kw not in kw_cnt:
+ kw_cnt[kw] = 0
+ self.summary.append(
+ {
+ "name": item["name"],
+ "summary_name": f'{item["prefix"]}{kw_cnt[kw]}.{item["name"]}',
+ "group": item["group"],
+ "min": None,
+ "max": None,
+ "mean": None,
+ "std": None,
+ "shape": item["shape"],
+ "grad_mean": None,
+ "grad_std": None,
+ "tensor": item["tensor"],
+ "requires_grad": item["requires_grad"],
+ "inside_pipe": None,
+ }
+ )
+ kw_cnt[kw] += 1
+ i = i + 1
+
+ def get_summary(self):
+ r"""Get the summary of the tensors recorded by `record_tensor`.
+
+ Returns:
+ A list of dicts. Each dict contains the following keys:
+ - name: The name of the tensor.
+ - min: The minimum value of the tensor.
+ - max: The maximum value of the tensor.
+ - mean: The mean value of the tensor.
+ - std: The standard deviation of the tensor.
+ - shape: The shape of the tensor.
+ - grad_mean: The mean value of the gradient of the tensor.
+ - grad_std: The standard deviation of the gradient of the tensor.
+
+ **Note:** This method must be called outside of the `with` block.
+
+ """
+ self._set_summary(self._summary)
+ ret = []
+ for item in self.summary:
+ comm = config["comm"]
+
+ if not item["requires_grad"] or item["tensor"].grad is None:
+ x = item["tensor"]
+ info = torch.empty(2, dtype=x.dtype, device=x.device)
+ info[0] = x.mean()
+ info[1] = x.var()
+ nccl.allReduce(info.storage(), info.storage(), "sum", comm)
+ info = info / nccl.commCount(comm)
+ x_mean = info[0].cpu().item()
+ x_std = math.sqrt(info[1].cpu().item())
+ grad_mean = None
+ grad_std = None
+ else:
+ x = item["tensor"]
+ info = torch.empty(4, dtype=x.dtype, device=x.device)
+ info[0] = x.mean()
+ info[1] = x.var()
+ info[2] = x.grad.mean()
+ info[3] = x.grad.var()
+ nccl.allReduce(info.storage(), info.storage(), "sum", comm)
+ info = info / nccl.commCount(comm)
+ x_mean = info[0].cpu().item()
+ x_std = math.sqrt(info[1].cpu().item())
+ grad_mean = info[2].cpu().item()
+ grad_std = math.sqrt(info[3].cpu().item())
+
+ info[0] = x.max()
+ info[1] = -x.min()
+ nccl.allReduce(info.storage(), info.storage(), "max", comm)
+ x_max = info[0].cpu().item()
+ x_min = -info[1].cpu().item()
+
+ summary = {
+ "name": item["summary_name"],
+ "min": x_min,
+ "max": x_max,
+ "mean": x_mean,
+ "std": x_std,
+ "shape": tuple(
+ (item["shape"][0] * config["world_size"],) + item["shape"][1:]
+ ),
+ "grad_mean": grad_mean,
+ "grad_std": grad_std,
+ }
+
+ ret.append(summary)
+ return ret
+
+ def get_tensor(
+ self, name: str, group: Optional[str] = None, index: Optional[int] = None
+ ) -> torch.Tensor:
+ """Get the tensor recorded by `record_tensor` by name, group and index.
+
+ Args:
+ name (str): The name of the tensor.
+ group (Optional[str]): The group of the tensor.
+ index (Optional[int]): The index of the tensor.
+
+ Returns:
+ The tensor if found, otherwise None.
+
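+        Example (illustrative sketch; assumes a tensor was recorded under the name "hidden_states" in group "encoder"):
+            >>> x = inspector.get_tensor("hidden_states", group="encoder", index=0)
+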
+ """
+ group_name_prefix = f"{group}." if group is not None else ""
+
+ all_names = []
+ if index is None:
+ all_names.append(f"{group_name_prefix}{name}")
+ all_names.append(f"{group_name_prefix}0.{name}")
+ else:
+ all_names.append(f"{group_name_prefix}{index}.{name}")
+
+ for item in self.summary:
+ if item["name"] in all_names:
+ return item["tensor"]
+ return None
+
+
+class InspectTensorManager:
+ def __init__(self) -> None:
+ self._inspector = None
+
+ def __enter__(self) -> InspectTensor:
+ self.prev_val = debug.get("_inspect_tensor", False)
+ if not self.prev_val:
+ debug.set("_inspect_tensor", True)
+ self._inspector = InspectTensor()
+ return self._inspector
+ else:
+ raise RuntimeError("InspectTensorManager is already in use")
+
+ def __exit__(self, *args):
+ if not self.prev_val:
+ debug.set("_inspect_tensor", self.prev_val)
+ summary = debug.get("_inspect_hidden_states", [])
+ self._inspector._set_summary(summary)
+ self._inspector = None
+ debug.set("_inspect_hidden_states", [])
+
+
+def inspect_tensor() -> InspectTensorManager:
+ """**inspect_tensor** returns a context manager that can be used to get the intermediate results of the model computations and their gradients.
+
+ Example:
+ >>> with bmt.inspect.inspect_tensor() as inspector:
+ >>> loss = model(inputs)
+ >>> loss.backward()
+ >>> summary = inspector.get_summary()
+ >>> text_summary = bmt.inspect.format_summary(summary)
+ >>> bmt.print_rank(text_summary)
+ name shape max min std mean grad_std grad_mean
+ ...
+
+ **Note:** loss.backward() must be called inside the context manager, otherwise the gradients will not be recorded.
+ **Note:** Calling get_summary() has significant overhead.
+
+ """
+
+ return InspectTensorManager()
+
+
+def record_tensor(x: torch.Tensor, name: str, group=None):
+ """Record the tensor for inspection.
+
+ Args:
+ x (torch.Tensor): The tensor to be recorded.
+ name (str): The name of the tensor.
+ group (str): The group name of the tensor.
+
+ **Note:** This function is only available in inspect_tensor context.
+ **Note:** Recording too many tensors may cause memory issues.
+
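+    Example (illustrative sketch; ``model`` and ``inputs`` are hypothetical):
+        >>> with bmt.inspect.inspect_tensor() as inspector:
+        >>>     hidden = model(inputs)  # hypothetical forward pass returning an activation
+        >>>     bmt.inspect.record_tensor(hidden, "hidden_states", group="encoder")
+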
+ """
+ if isinstance(x, torch.nn.Parameter):
+ raise RuntimeError("Cannot inspect Parameter")
+
+ if not debug.get("_inspect_tensor", False):
+ # do nothing
+ return
+
+ if x.requires_grad:
+ x.retain_grad()
+ debug.append(
+ "_inspect_hidden_states",
+ {
+ "name": name,
+ "group": group,
+ "min": None,
+ "max": None,
+ "mean": None,
+ "std": None,
+ "shape": x.shape,
+ "grad_mean": None,
+ "grad_std": None,
+ "tensor": x,
+ "requires_grad": x.requires_grad,
+ "inside_pipe": None,
+ },
+ )
diff --git a/examples/BMTrain/bmtrain/layer.py b/examples/BMTrain/bmtrain/layer.py
new file mode 100644
index 00000000..e071e01b
--- /dev/null
+++ b/examples/BMTrain/bmtrain/layer.py
@@ -0,0 +1,143 @@
+import torch
+from .parameter import DistributedParameter
+from .global_var import config
+import itertools
+from .utils import tp_split_tensor
+
+class DistributedModule(torch.nn.Module):
+ """
+ DistributedModule is a subclass of torch.nn.Module that overrides the `__getattr__` method to gather distributed parameters automatically.
+
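+    Example (illustrative sketch; ``Feedforward``, ``dim_in`` and ``dim_out`` are hypothetical, and
+    ``DistributedParameter`` is assumed to be exported at the package level as ``bmt.DistributedParameter``):
+        >>> class Feedforward(bmt.DistributedModule):  # hypothetical module for illustration
+        >>>     def __init__(self, dim_in, dim_out):
+        >>>         super().__init__()
+        >>>         self.weight = bmt.DistributedParameter(torch.empty(dim_out, dim_in))
+        >>>     def forward(self, x):
+        >>>         return torch.nn.functional.linear(x, self.weight)
+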
+ """
+
+ def __getattr__(self, name: str):
+ ret = super().__getattr__(name)
+ # gather distributed parameters if not in bmt.Block
+ if isinstance(ret, DistributedParameter) and not ret._in_block:
+ return ret.gather()
+ return ret
+
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
+ r"""Saves module state to `destination` dictionary, containing a state
+ of the module, but not its descendants. This is called on every
+ submodule in :meth:`~torch.nn.Module.state_dict`.
+
+ In rare cases, subclasses can achieve class-specific behavior by
+ overriding this method with custom logic.
+
+ Args:
+ destination (dict): a dict where state will be stored
+ prefix (str): the prefix for parameters and buffers used in this
+ module
+ """
+ for name, param in self._parameters.items():
+ if param is not None:
+ if isinstance(param, DistributedParameter):#and not param._in_block:
+ if param._in_block:
+ destination[prefix + name] = param.tp_gather().detach() # sync operation
+ else:
+ destination[prefix + name] = param.gather_all().detach() # sync operation
+ if config['save_param_to_cpu']:
+ destination[prefix + name] = destination[prefix + name].cpu()
+ else:
+ if config['save_param_to_cpu']:
+ destination[prefix + name] = param if keep_vars else param.detach().cpu()
+ else:
+ destination[prefix + name] = param if keep_vars else param.detach()
+
+ for name, buf in self._buffers.items():
+ if buf is not None and name not in self._non_persistent_buffers_set:
+ destination[prefix + name] = buf if keep_vars else buf.detach()
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ r"""Copies parameters and buffers from :attr:`state_dict` into only
+ this module, but not its descendants. This is called on every submodule
+ in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
+ module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
+ For state dicts without metadata, :attr:`local_metadata` is empty.
+ Subclasses can achieve class-specific backward compatible loading using
+ the version number at `local_metadata.get("version", None)`.
+
+ .. note::
+ :attr:`state_dict` is not the same object as the input
+ :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
+ it can be modified.
+
+ Args:
+ state_dict (dict): a dict containing parameters and
+ persistent buffers.
+ prefix (str): the prefix for parameters and buffers used in this
+ module
+ local_metadata (dict): a dict containing the metadata for this module.
+ See
+ strict (bool): whether to strictly enforce that the keys in
+ :attr:`state_dict` with :attr:`prefix` match the names of
+ parameters and buffers in this module
+ missing_keys (list of str): if ``strict=True``, add missing keys to
+ this list
+ unexpected_keys (list of str): if ``strict=True``, add unexpected
+ keys to this list
+ error_msgs (list of str): error messages should be added to this
+ list, and will be reported together in
+ :meth:`~torch.nn.Module.load_state_dict`
+ """
+ for hook in self._load_state_dict_pre_hooks.values():
+ hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+ persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+ local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+ local_state = {k: v for k, v in local_name_params if v is not None}
+
+ for name, param in local_state.items():
+ key = prefix + name
+ if key in state_dict:
+ tp_mode = param._tp_mode
+ input_param = state_dict[key]
+ if input_param.__class__.__name__ == "DistributedTensorWrapper":
+ input_param = input_param.broadcast()
+ # This is used to avoid copying uninitialized parameters into
+                # non-lazy modules, since they don't have the hook to do the checks
+ # in such case, it will error when accessing the .shape attribute.
+ is_param_lazy = torch.nn.parameter.is_lazy(param)
+ # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+ if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+ input_param = input_param[0]
+
+ if not is_param_lazy and not isinstance(param, DistributedParameter) and input_param.shape != param.shape:
+ # local shape should match the one in checkpoint
+ error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+ 'the shape in current model is {}.'
+ .format(key, input_param.shape, param.shape))
+ continue
+ verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape)
+ if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape:
+ error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+ 'the shape in current model is {}.'
+ .format(key, input_param.shape, verify_shape))
+ try:
+ with torch.no_grad():
+ if isinstance(param, DistributedParameter):
+ tp_split_dim = param._tp_split_dim
+ if tp_mode and tp_split_dim >= 0:
+ input_param = tp_split_tensor(input_param, tp_split_dim)
+ param._copy_data(input_param)
+ else:
+ param.copy_(input_param)
+ except Exception as ex:
+ error_msgs.append('While copying the parameter named "{}", '
+ 'whose dimensions in the model are {} and '
+ 'whose dimensions in the checkpoint are {}, '
+ 'an exception occurred : {}.'
+ .format(key, param.size(), input_param.size(), ex.args))
+ elif strict:
+ missing_keys.append(key)
+
+ if strict:
+ for key in state_dict.keys():
+ if key.startswith(prefix):
+ input_name = key[len(prefix):]
+ input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
+ if input_name not in self._modules and input_name not in local_state:
+ unexpected_keys.append(key)
+
diff --git a/examples/BMTrain/bmtrain/loss/__init__.py b/examples/BMTrain/bmtrain/loss/__init__.py
new file mode 100644
index 00000000..daa731fd
--- /dev/null
+++ b/examples/BMTrain/bmtrain/loss/__init__.py
@@ -0,0 +1 @@
+from .cross_entropy import FusedCrossEntropy
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/loss/_function.py b/examples/BMTrain/bmtrain/loss/_function.py
new file mode 100644
index 00000000..6ff3c471
--- /dev/null
+++ b/examples/BMTrain/bmtrain/loss/_function.py
@@ -0,0 +1,182 @@
+from .. import C
+import torch
+
+CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda
+
+
+def has_inf_nan(g_half: torch.Tensor, out: torch.Tensor) -> None:
+ assert out.dtype == torch.uint8, "out must be a uint8 tensor"
+ assert CHECK_INPUT(g_half), "g_fp16 must be contiguous and on cuda"
+ assert CHECK_INPUT(out), "out must be contiguous and on cuda"
+ mid = torch.zeros(1024, device=out.device, dtype=out.dtype)
+ stream = torch.cuda.current_stream().cuda_stream
+ if g_half.dtype == torch.float16:
+ C.has_nan_inf_fp16_launcher(
+ g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream
+ )
+ elif g_half.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.has_nan_inf_bf16_launcher(
+ g_half.numel(), g_half.data_ptr(), mid.data_ptr(), out.data_ptr(), stream
+ )
+ else:
+ raise ValueError(f"has_inf_nan not supported for dtype {g_half.dtype}")
+
+
+def cross_entropy_forward(
+ m: int,
+ n: int,
+ input: torch.Tensor,
+ target: torch.Tensor,
+ softmax: torch.Tensor,
+ output: torch.Tensor,
+ ignore_index: int,
+) -> None:
+    assert CHECK_INPUT(input), "input must be contiguous and on cuda"
+    assert CHECK_INPUT(target), "target must be contiguous and on cuda"
+    assert CHECK_INPUT(softmax), "softmax must be contiguous and on cuda"
+    assert CHECK_INPUT(output), "output must be contiguous and on cuda"
+ assert target.dtype == torch.int32, "target must be an int tensor"
+ assert output.dtype == torch.float32, "output must be a float tensor"
+ assert (
+ input.numel() == softmax.numel()
+ ), "input and softmax must have the same number of elements"
+ assert (
+ target.numel() == output.numel()
+ ), "target and output must have the same number of elements"
+ input_ptr = input.data_ptr()
+ target_ptr = target.data_ptr()
+ softmax_ptr = softmax.data_ptr()
+ output_ptr = output.data_ptr()
+ cuda_stream = torch.cuda.current_stream().cuda_stream
+ if input.dtype == torch.float16:
+ C.cross_entropy_forward_fp16_launcher(
+ m,
+ n,
+ input_ptr,
+ target_ptr,
+ softmax_ptr,
+ output_ptr,
+ ignore_index,
+ cuda_stream,
+ )
+ elif input.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.cross_entropy_forward_bf16_launcher(
+ m,
+ n,
+ input_ptr,
+ target_ptr,
+ softmax_ptr,
+ output_ptr,
+ ignore_index,
+ cuda_stream,
+ )
+ else:
+ raise ValueError(f"cross_entropy_forward not supported for dtype {input.dtype}")
+
+
+def cross_entropy_backward_inplace(
+ m: int,
+ n: int,
+ grad_output: torch.Tensor,
+ target: torch.Tensor,
+ x: torch.Tensor,
+ ignore_index: int,
+) -> None:
+    assert CHECK_INPUT(grad_output), "grad_output must be contiguous and on cuda"
+    assert CHECK_INPUT(target), "target must be contiguous and on cuda"
+    assert CHECK_INPUT(x), "x must be contiguous and on cuda"
+ assert grad_output.dtype == torch.float32, "grad_output must be a float tensor"
+ assert target.dtype == torch.int32, "target must be an int tensor"
+ assert (
+ target.numel() == grad_output.numel()
+ ), "target and grad_output must have the same number of elements"
+ cuda_stream = torch.cuda.current_stream().cuda_stream
+ grad_output_ptr = grad_output.data_ptr()
+ target_ptr = target.data_ptr()
+ x_ptr = x.data_ptr()
+
+ if x.dtype == torch.float16:
+ C.cross_entropy_backward_inplace_fp16_launcher(
+ m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream
+ )
+ elif x.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.cross_entropy_backward_inplace_bf16_launcher(
+ m, n, grad_output_ptr, target_ptr, x_ptr, ignore_index, cuda_stream
+ )
+ else:
+ raise ValueError(
+            f"cross_entropy_backward_inplace not supported for dtype {x.dtype}"
+ )
+
+
+def fused_sumexp(logits: torch.Tensor, max_logits: torch.Tensor) -> torch.Tensor:
+    assert CHECK_INPUT(logits), "logits must be contiguous and on cuda"
+    assert CHECK_INPUT(max_logits), "max_logits must be contiguous and on cuda"
+ assert max_logits.dtype == torch.float32, "max_logits must be float tensor"
+ assert max_logits.size(0) == logits.size(
+ 0
+ ), "max_logits must have same size(0) as logits"
+ sum_exp_logits = torch.empty(
+ logits.size(0), dtype=torch.float32, device=logits.device
+ )
+ m = logits.size(0)
+ n = logits.size(1)
+ cuda_stream = torch.cuda.current_stream().cuda_stream
+ logits_ptr = logits.data_ptr()
+ max_logits_ptr = max_logits.data_ptr()
+ sum_exp_logits_ptr = sum_exp_logits.data_ptr()
+ if logits.dtype == torch.float16:
+ C.fused_sumexp_fp16_launcher(
+ m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream
+ )
+ elif logits.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.fused_sumexp_bf16_launcher(
+ m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream
+ )
+ else:
+ raise ValueError(f"fused_sumexp not supported for dtype {logits.dtype}")
+ return sum_exp_logits
+
+
+def fused_softmax_inplace(
+ logits: torch.Tensor, max_logits: torch.Tensor, sum_exp_logits: torch.Tensor
+) -> None:
+    assert CHECK_INPUT(logits), "logits must be contiguous and on cuda"
+    assert CHECK_INPUT(max_logits), "max_logits must be contiguous and on cuda"
+    assert CHECK_INPUT(sum_exp_logits), "sum_exp_logits must be contiguous and on cuda"
+ assert max_logits.dtype == torch.float32, "max_logits must be float tensor"
+ assert sum_exp_logits.dtype == torch.float32, "sum_exp_logits must be float tensor"
+ assert max_logits.size(0) == logits.size(
+ 0
+ ), "max_logits must have same size(0) as logits"
+ assert sum_exp_logits.size(0) == logits.size(
+ 0
+ ), "sum_exp_logits must have same size(0) as logits"
+ m = logits.size(0)
+ n = logits.size(1)
+ cuda_stream = torch.cuda.current_stream().cuda_stream
+ logits_ptr = logits.data_ptr()
+ max_logits_ptr = max_logits.data_ptr()
+ sum_exp_logits_ptr = sum_exp_logits.data_ptr()
+ if logits.dtype == torch.float16:
+ C.fused_softmax_inplace_fp16_launcher(
+ m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream
+ )
+ elif logits.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.fused_softmax_inplace_bf16_launcher(
+ m, n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr, cuda_stream
+ )
+ else:
+ raise ValueError(
+ f"fused_softmax_inplace not supported for dtype {logits.dtype}"
+ )
diff --git a/examples/BMTrain/bmtrain/loss/cross_entropy.py b/examples/BMTrain/bmtrain/loss/cross_entropy.py
new file mode 100644
index 00000000..5be07665
--- /dev/null
+++ b/examples/BMTrain/bmtrain/loss/cross_entropy.py
@@ -0,0 +1,260 @@
+from typing import Optional
+import torch
+from . import _function as F
+from bmtrain.global_var import config
+from bmtrain.distributed import all_gather, all_reduce
+
+class OpFusedCrossEntropy(torch.autograd.Function):
+ """
+ CrossEntropy dim = 1
+ """
+ @staticmethod
+ def forward(ctx, x : torch.Tensor, target : torch.Tensor, ignore_index: int):
+ assert x.ndim == 2
+ softmax = torch.empty(x.size(), device=x.device, dtype=x.dtype)
+ out = torch.empty(x.size(0), device=x.device, dtype=torch.float)
+ F.cross_entropy_forward(
+ x.size(0), x.size(1),
+ x, target,
+ softmax, out,
+ ignore_index,
+ )
+ ctx.ignore_index = ignore_index
+ ctx.save_for_backward(softmax, target)
+ return out # float tensor
+
+ @staticmethod
+ def backward(ctx, grad_output : torch.Tensor):
+ grad_output = grad_output.contiguous()
+ softmax, target = ctx.saved_tensors
+ F.cross_entropy_backward_inplace(
+ softmax.size(0), softmax.size(1),
+ grad_output, target,
+ softmax,
+ ctx.ignore_index,
+ )
+ return (softmax, None, None)
+
+class VPFusedCrossEntropy(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, logits : torch.Tensor, target : torch.Tensor):
+ comm = config['tp_comm']
+ rank = config['tp_rank']
+ world_size = config['tp_size']
+
+ max_logits = torch.max(logits, dim=-1)[0].float()
+ max_logits = all_reduce(max_logits, op="max", comm=comm)
+
+ partition_vocab_size = logits.size()[-1]
+ vocab_start_index = rank * partition_vocab_size
+ vocab_end_index = (rank + 1) * partition_vocab_size
+
+ # Create a mask of valid vocab ids (1 means it needs to be masked).
+ target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
+ masked_target = target.clone() - vocab_start_index
+ masked_target[target_mask] = 0
+
+ logits_2d = logits.view(-1, partition_vocab_size)
+ masked_target_1d = masked_target.view(-1)
+ arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
+ predicted_logits_1d = logits_2d[arange_1d, masked_target_1d].contiguous() # (-1,)
+ predicted_logits = predicted_logits_1d.view_as(target)
+ predicted_logits[target_mask] = 0.0 # if target=-100, it will also be 0
+
+ # All reduce is needed to get the chunks from other GPUs.
+ predicted_logits = all_reduce(predicted_logits.float(), op="sum", comm=comm)
+ predicted_logits = predicted_logits - max_logits
+ # Sum of exponential of logits along vocab dimension across all GPUs.
+
+        sum_exp_logits = F.fused_sumexp(logits, max_logits)  # float32, shape (m,)
+ sum_exp_logits = all_reduce(sum_exp_logits, op="sum", comm=comm) + 1e-10 # avoid nan
+
+ softmax = logits.clone()
+ F.fused_softmax_inplace(softmax, max_logits, sum_exp_logits) # logits -> softmax
+ # logits = logits.float() - max_logits.unsqueeze(dim=-1).float()
+ # exp_logits = logits
+ # torch.exp(logits, out=exp_logits)
+ # sum_exp_logits = exp_logits.sum(dim=-1)
+ # exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
+
+ loss = torch.log(sum_exp_logits.view(predicted_logits.shape)) - predicted_logits
+
+ # Normalize
+ ctx.save_for_backward(softmax, target_mask, masked_target_1d)
+
+ return loss
+
+ @staticmethod
+ def backward(ctx, grad_output):
+        # Retrieve tensors from the forward pass.
+        softmax, target_mask, masked_target_1d = ctx.saved_tensors
+        # All the inputs have softmax as their gradient.
+ grad_input = softmax
+ # For simplicity, work with the 2D gradient.
+ partition_vocab_size = softmax.size()[-1]
+ grad_2d = grad_input.view(-1, partition_vocab_size)
+
+ # Add the gradient from matching classes.
+ arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
+
+ softmax_update = 1.0 - target_mask.view(-1).float()
+
+ grad_2d[arange_1d, masked_target_1d] -= softmax_update
+ grad_input.mul_(grad_output.view(*grad_input.shape[:-1]).unsqueeze(dim=-1))
+
+ return grad_input, None
+
+class FusedCrossEntropy(torch.nn.Module):
+ r"""This criterion computes the cross entropy loss between input and target.
+
+ It is useful when training a classification problem with `C` classes.
+ If provided, the optional argument :attr:`weight` should be a 1D `Tensor`
+ assigning weight to each of the classes.
+ This is particularly useful when you have an unbalanced training set.
+
+ The `input` is expected to contain raw, unnormalized scores for each class.
+ `input` has to be a Tensor of size :math:`(minibatch, C)`.
+
+ The `target` that this criterion expects should contain either:
+
+ - Class indices in the range :math:`[0, C-1]` where :math:`C` is the number of classes; if
+ `ignore_index` is specified, this loss also accepts this class index (this index
+ may not necessarily be in the class range). The unreduced (i.e. with :attr:`reduction`
+ set to ``'none'``) loss for this case can be described as:
+
+ .. math::
+ \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+ l_n = - w_{y_n} \log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C \exp(x_{n,c})}
+ \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}
+
+ where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight,
+ :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension. If
+ :attr:`reduction` is not ``'none'`` (default ``'mean'``), then
+
+ .. math::
+ \ell(x, y) = \begin{cases}
+ \sum_{n=1}^N \frac{1}{\sum_{n=1}^N w_{y_n} \cdot \mathbb{1}\{y_n \not= \text{ignore\_index}\}} l_n, &
+ \text{if reduction} = \text{`mean';}\\
+ \sum_{n=1}^N l_n, &
+ \text{if reduction} = \text{`sum'.}
+ \end{cases}
+
+ Note that this case is equivalent to the combination of :class:`~torch.nn.LogSoftmax` and
+ :class:`~torch.nn.NLLLoss`.
+
+ - Probabilities for each class; useful when labels beyond a single class per minibatch item
+ are required, such as for blended labels, label smoothing, etc. The unreduced (i.e. with
+ :attr:`reduction` set to ``'none'``) loss for this case can be described as:
+
+ .. math::
+ \ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad
+        l_n = - \sum_{c=1}^C w_c \log \frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} y_{n,c}
+
+ where :math:`x` is the input, :math:`y` is the target, :math:`w` is the weight,
+ :math:`C` is the number of classes, and :math:`N` spans the minibatch dimension. If
+ :attr:`reduction` is not ``'none'`` (default ``'mean'``), then
+
+ .. math::
+ \ell(x, y) = \begin{cases}
+ \frac{\sum_{n=1}^N l_n}{N}, &
+ \text{if reduction} = \text{`mean';}\\
+ \sum_{n=1}^N l_n, &
+ \text{if reduction} = \text{`sum'.}
+ \end{cases}
+
+ .. note::
+ The performance of this criterion is generally better when `target` contains class
+ indices, as this allows for optimized computation. Consider providing `target` as
+ class probabilities only when a single class label per minibatch item is too restrictive.
+
+ Args:
+ weight (Tensor, optional): a manual rescaling weight given to each class.
+ If given, has to be a Tensor of size `C`
+ size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
+ the losses are averaged over each loss element in the batch. Note that for
+ some losses, there are multiple elements per sample. If the field :attr:`size_average`
+ is set to ``False``, the losses are instead summed for each minibatch. Ignored
+ when :attr:`reduce` is ``False``. Default: ``True``
+ ignore_index (int, optional): Specifies a target value that is ignored
+ and does not contribute to the input gradient. When :attr:`size_average` is
+ ``True``, the loss is averaged over non-ignored targets. Note that
+ :attr:`ignore_index` is only applicable when the target contains class indices.
+ reduce (bool, optional): Deprecated (see :attr:`reduction`). By default, the
+ losses are averaged or summed over observations for each minibatch depending
+ on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
+ batch element instead and ignores :attr:`size_average`. Default: ``True``
+ reduction (string, optional): Specifies the reduction to apply to the output:
+ ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will
+ be applied, ``'mean'``: the weighted mean of the output is taken,
+ ``'sum'``: the output will be summed. Note: :attr:`size_average`
+ and :attr:`reduce` are in the process of being deprecated, and in
+ the meantime, specifying either of those two args will override
+ :attr:`reduction`. Default: ``'mean'``
+ label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
+ of smoothing when computing the loss, where 0.0 means no smoothing. The targets
+ become a mixture of the original ground truth and a uniform distribution as described in
+ `Rethinking the Inception Architecture for Computer Vision `__. Default: :math:`0.0`.
+
+ Shape:
+ - Input: :math:`(N, C)` where `C = number of classes`.
+ - Target: If containing class indices, shape :math:`(N)` where each value is
+ :math:`0 \leq \text{targets}[i] \leq C-1`. If containing class probabilities,
+ same shape as the input.
+ - Output: If :attr:`reduction` is ``'none'``, shape :math:`(N)`.
+ Otherwise, scalar.
+
+ Examples::
+
+ >>> # Example of target with class indices
+ >>> loss_func = bmt.loss.FusedCrossEntropy()
+ >>> input = torch.randn(32, 100).half()
+ >>> target = torch.randint(0, 100, (32,)).long()
+ >>> loss = loss_func(input, target)
+ >>> loss.backward()
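+
+    A hedged sketch of vocabulary-parallel usage (``parallel=True``; assumes a tensor-parallel
+    setup where ``local_logits`` holds this rank's vocabulary shard):
+
+    >>> loss_func = bmt.loss.FusedCrossEntropy(parallel=True)
+    >>> loss = loss_func(local_logits, target)
+    >>> loss.backward()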
+ """
+ def __init__(self,
+ weight: Optional[torch.Tensor] = None,
+ ignore_index: int = -100,
+ reduction: str = 'mean',
+ label_smoothing: float = 0.0, # TODO not supported yet
+ parallel: bool = False,
+ ) -> None:
+ super().__init__()
+ self.weight = weight
+ self.ignore_index = ignore_index
+ self.reduction = reduction
+ self.label_smoothing = label_smoothing
+ self.parallel = parallel
+
+ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ if self.parallel:
+ ret = VPFusedCrossEntropy.apply(input, target.long())
+ else:
+ if input.dtype == torch.float32:
+ return torch.nn.functional.cross_entropy(
+ input,
+ target.long(),
+ weight=self.weight,
+ ignore_index=self.ignore_index,
+ reduction=self.reduction,
+ label_smoothing=self.label_smoothing)
+
+ ret = OpFusedCrossEntropy.apply(input, target.int(), self.ignore_index) # return float tensor
+
+ if self.weight is not None:
+ if self.weight.dim() != 1 or self.weight.size(0) != input.size(1):
+                raise ValueError("weight should be a 1D tensor of size C")
+ w = self.weight[torch.where(target==self.ignore_index, 0, target)].float()
+ w[target==self.ignore_index] = 0
+ else:
+ w = (target != self.ignore_index).int()
+
+ ret = w * ret
+
+ if self.reduction == "none":
+ return ret
+ elif self.reduction == "sum":
+ return ret.sum()
+ elif self.reduction == "mean":
+ return ret.sum() / w.sum().float()
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/__init__.py b/examples/BMTrain/bmtrain/lr_scheduler/__init__.py
new file mode 100644
index 00000000..0d9a0596
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/__init__.py
@@ -0,0 +1,6 @@
+from .warmup import WarmupLRScheduler
+from .no_decay import NoDecay
+from .noam import Noam
+from .linear import Linear
+from .cosine import Cosine
+from .exponential import Exponential
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/cosine.py b/examples/BMTrain/bmtrain/lr_scheduler/cosine.py
new file mode 100644
index 00000000..3aed034d
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/cosine.py
@@ -0,0 +1,18 @@
+import math
+from .warmup import WarmupLRScheduler
+
+
+class Cosine(WarmupLRScheduler):
+ r"""
+    After a warmup period during which the learning rate increases linearly from 0 to start_lr,
+    the decay period sets :math:`\text{lr}=\text{start_lr}\times \dfrac{1+\cos \left( \pi \cdot \dfrac{\text{num_iter}-\text{warmup_iter}}{\text{end_iter}-\text{warmup_iter}}\right)}{2}`
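+
+    A minimal usage sketch (``optimizer`` is any optimizer whose param groups should follow this
+    schedule; the numbers are illustrative):
+
+        >>> lr_scheduler = bmt.lr_scheduler.Cosine(
+        ...     optimizer, start_lr=1e-3, warmup_iter=100, end_iter=10000, num_iter=0
+        ... )
+        >>> lr_scheduler.step()  # once per training iteration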
+ """
+
+ def get_lr_warmup(self, num_iter) -> float:
+ return self.start_lr * num_iter / self.warmup_iter
+
+ def get_lr_decay(self, num_iter) -> float:
+ progress = (num_iter - self.warmup_iter) / max(
+ 1, (self.end_iter - self.warmup_iter)
+ )
+ return max(0.0, self.start_lr * 0.5 * (1.0 + math.cos(progress * math.pi)))
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/exponential.py b/examples/BMTrain/bmtrain/lr_scheduler/exponential.py
new file mode 100644
index 00000000..6cf3240e
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/exponential.py
@@ -0,0 +1,20 @@
+from .warmup import WarmupLRScheduler
+
+
+class Exponential(WarmupLRScheduler):
+ r"""
+    After a warmup period during which the learning rate increases linearly from 0 to start_lr,
+    the decay period sets :math:`\text{lr}=\text{start_lr}\times \gamma ^ {\left(\text{num_iter}-\text{warmup_iter}\right)}`
+ """
+
+ def __init__(
+ self, optimizer, start_lr, warmup_iter, end_iter, num_iter, gamma=0.95
+ ) -> None:
+ super().__init__(optimizer, start_lr, warmup_iter, end_iter, num_iter)
+ self.gamma = gamma
+
+ def get_lr_warmup(self, num_iter) -> float:
+ return self.start_lr * num_iter / self.warmup_iter
+
+ def get_lr_decay(self, num_iter) -> float:
+ return max(0.0, self.start_lr * self.gamma ** (num_iter - self.warmup_iter))
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/linear.py b/examples/BMTrain/bmtrain/lr_scheduler/linear.py
new file mode 100644
index 00000000..af193dd8
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/linear.py
@@ -0,0 +1,19 @@
+from .warmup import WarmupLRScheduler
+
+
+class Linear(WarmupLRScheduler):
+ r"""
+    After a warmup period during which the learning rate increases linearly from 0 to start_lr,
+    the decay period sets :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{end_iter}-\text{num_iter}}{\text{end_iter}-\text{warmup_iter}}`
+ """
+
+ def get_lr_warmup(self, num_iter) -> float:
+ return self.start_lr * num_iter / self.warmup_iter
+
+ def get_lr_decay(self, num_iter) -> float:
+ return max(
+ 0.0,
+ self.start_lr
+ * (self.end_iter - num_iter)
+ / (self.end_iter - self.warmup_iter),
+ )
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/no_decay.py b/examples/BMTrain/bmtrain/lr_scheduler/no_decay.py
new file mode 100644
index 00000000..6f85bf0a
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/no_decay.py
@@ -0,0 +1,14 @@
+from .warmup import WarmupLRScheduler
+
+
+class NoDecay(WarmupLRScheduler):
+ r"""
+    After a warmup period during which the learning rate increases linearly from 0 to start_lr,
+    the decay period keeps :math:`\text{lr}=\text{start_lr}`
+ """
+
+ def get_lr_warmup(self, num_iter) -> float:
+ return self.start_lr * num_iter / self.warmup_iter
+
+ def get_lr_decay(self, num_iter) -> float:
+ return self.start_lr
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/noam.py b/examples/BMTrain/bmtrain/lr_scheduler/noam.py
new file mode 100644
index 00000000..8954a64d
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/noam.py
@@ -0,0 +1,15 @@
+import math
+from .warmup import WarmupLRScheduler
+
+
+class Noam(WarmupLRScheduler):
+ r"""
+    During the warmup period :math:`\text{lr}=\text{start_lr}\times \dfrac{\text{num_iter}}{\text{warmup_iter}^{3/2}}`,
+    and the decay period sets :math:`\text{lr}=\text{start_lr}\times \dfrac{1}{\sqrt{\text{num_iter}}}`
+ """
+
+ def get_lr_warmup(self, num_iter) -> float:
+ return self.start_lr / math.sqrt(self.warmup_iter) * num_iter / self.warmup_iter
+
+ def get_lr_decay(self, num_iter) -> float:
+ return self.start_lr / math.sqrt(num_iter)
diff --git a/examples/BMTrain/bmtrain/lr_scheduler/warmup.py b/examples/BMTrain/bmtrain/lr_scheduler/warmup.py
new file mode 100644
index 00000000..1f9ccc8e
--- /dev/null
+++ b/examples/BMTrain/bmtrain/lr_scheduler/warmup.py
@@ -0,0 +1,72 @@
+import torch
+
+
+class WarmupLRScheduler:
+ r"""Base class for learning rate schedulers with warmup.
+
+ Args:
+ optimizer (torch.optim.Optimizer): optimizer used for training
+ start_lr (float): starting learning rate
+ warmup_iter (int): number of iterations to linearly increase learning rate
+ end_iter (int): number of iterations to stop training
+ num_iter (int): current iteration number
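+
+    Subclasses implement ``get_lr_warmup`` and ``get_lr_decay``; a minimal sketch of a subclass
+    that keeps the learning rate constant after warmup (illustrative only):
+
+        >>> class ConstantAfterWarmup(WarmupLRScheduler):
+        ...     def get_lr_warmup(self, num_iter):
+        ...         return self.start_lr * num_iter / self.warmup_iter
+        ...     def get_lr_decay(self, num_iter):
+        ...         return self.start_lr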
+ """
+
+ def __init__(
+ self,
+ optimizer: torch.optim.Optimizer,
+ start_lr,
+ warmup_iter,
+ end_iter,
+ num_iter=0,
+ ) -> None:
+ self.start_lr = start_lr
+ self.warmup_iter = warmup_iter
+ self.end_iter = end_iter
+ self.optimizer = optimizer
+ self.num_iter = num_iter
+ self._current_lr = None
+
+ self.step(self.num_iter)
+
+ def get_lr_warmup(self, num_iter) -> float: ...
+
+ def get_lr_decay(self, num_iter) -> float: ...
+
+ def get_lr(self):
+ assert self.num_iter >= 0
+
+ if self.num_iter < self.warmup_iter:
+ return self.get_lr_warmup(self.num_iter)
+ else:
+ return self.get_lr_decay(self.num_iter)
+
+ @property
+ def current_lr(self):
+ return self._current_lr
+
+ def step(self, num_iter=None) -> None:
+ if num_iter is None:
+ num_iter = self.num_iter + 1
+ self.num_iter = num_iter
+
+ lr = self.get_lr()
+ self._current_lr = lr
+ for group in self.optimizer.param_groups:
+ group["lr"] = lr
+
+ def state_dict(self):
+ return {
+ "start_lr": self.start_lr,
+ "warmup_iter": self.warmup_iter,
+ "end_iter": self.end_iter,
+ "num_iter": self.num_iter,
+ }
+
+ def load_state_dict(self, state_dict):
+ self.start_lr = state_dict["start_lr"]
+ self.warmup_iter = state_dict["warmup_iter"]
+ self.end_iter = state_dict["end_iter"]
+ self.num_iter = state_dict["num_iter"]
+
+ self.step(self.num_iter)
diff --git a/examples/BMTrain/bmtrain/nccl/__init__.py b/examples/BMTrain/bmtrain/nccl/__init__.py
new file mode 100644
index 00000000..0f4129d5
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nccl/__init__.py
@@ -0,0 +1,336 @@
+
+from typing_extensions import Literal
+import torch
+from .. import C
+from .enums import *
+
+class NCCLCommunicator:
+ """
+ NCCL communicator stores the communicator handle.
+ """
+
+ def __init__(self, ptr) -> None:
+ self.__ptr = ptr
+
+ @property
+ def ptr(self):
+ """
+ Returns the communicator handle.
+ """
+ if self.__ptr == -1:
+ raise RuntimeError("NCCL Communicator is already destroyed")
+ return self.__ptr
+
+ def _destroy_ptr(self):
+ self.__ptr = -1
+
+# utils
+
+def dtype2nccl(dtype : torch.dtype) -> int:
+ MAP = {
+ torch.int8: ncclInt8,
+ torch.uint8 : ncclUint8,
+ torch.int32 : ncclInt32,
+ torch.int : ncclInt32,
+ torch.int64 : ncclInt64,
+ torch.float16 : ncclFloat16,
+ torch.half : ncclHalf,
+ torch.bfloat16 : ncclBFloat16,
+ torch.float32 : ncclFloat32,
+ torch.float : ncclFloat,
+ torch.float64 : ncclFloat64,
+ torch.double : ncclDouble,
+ torch.bool : ncclBool
+ }
+ if dtype not in MAP:
+        raise TypeError("Unsupported dtype %s" % dtype)
+ return MAP[dtype]
+
+def op2nccl(
+ op : Literal["sum", "prod", "max", "min", "avg"]
+):
+ if op == "sum":
+ return ncclSum
+ if op == "prod":
+ return ncclProd
+ if op == "max":
+ return ncclMax
+ if op == "min":
+ return ncclMin
+ if op == "avg":
+ return ncclAvg
+    raise ValueError("Unknown reduction op %s" % op)
+
+# wrappers
+
+def getUniqueId() -> bytes:
+ """
+ NCCL API: `ncclGetUniqueId `_
+
+ """
+ return C.ncclGetUniqueId()
+
+def commInitRank(unique_id : bytes, world_size : int, rank : int) -> NCCLCommunicator:
+ """
+ NCCL API: `ncclCommInitRank `_
+
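+    A minimal initialization sketch (assumes the unique id created on rank 0 has already been
+    shared with every rank out of band; ``rank``, ``world_size`` and ``unique_id_from_rank0``
+    are placeholders):
+
+        >>> unique_id = getUniqueId() if rank == 0 else unique_id_from_rank0
+        >>> comm = commInitRank(unique_id, world_size, rank)
+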
+ """
+ assert rank >= 0 and rank < world_size, "rank must be between 0 and world_size-1"
+ return NCCLCommunicator(C.ncclCommInitRank(unique_id, world_size, rank))
+
+def commDestroy(comm : NCCLCommunicator):
+ """
+ NCCL API: `ncclCommDestroy `_
+
+ """
+ C.ncclCommDestroy(comm.ptr)
+ comm._destroy_ptr()
+def commCount(comm : NCCLCommunicator):
+ """NCCL API: `ncclCommCount `_
+
+ Args:
+ comm (NCCLCommunicator): NCCL communicator.
+ """
+ return C.ncclCommCount(comm.ptr)
+def commRank(comm : NCCLCommunicator):
+    """NCCL API: `ncclCommUserRank `_
+
+    Args:
+        comm (NCCLCommunicator): NCCL communicator.
+    """
+    return C.ncclCommUserRank(comm.ptr)
+
+### collective
+
+def allReduce(
+ src : torch.storage._StorageBase,
+ dst : torch.storage._StorageBase,
+ op : Literal["sum", "prod", "max", "min", "avg"],
+ comm : NCCLCommunicator
+ ):
+ """NCCL API: `ncclAllReduce `_
+
+ Args:
+ src (torch.storage._StorageBase): Source buffer.
+ dst (torch.storage._StorageBase): Destination buffer.
+ op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation.
+ comm (NCCLCommunicator): NCCL communicator.
+
+ The src and dst buffers must be the same size, type and on the same device.
+
+ If src == dst, the operation is performed in-place.
+
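+    A minimal sketch of an in-place sum over a CUDA tensor's storage (``comm`` is an already
+    initialized NCCLCommunicator):
+
+        >>> x = torch.ones(1024, device="cuda", dtype=torch.half)
+        >>> allReduce(x.storage(), x.storage(), "sum", comm)
+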
+ """
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
+ assert src.is_cuda and dst.is_cuda
+
+ sendbuff = src.data_ptr()
+ recvbuff = dst.data_ptr()
+ count = src.size()
+ datatype = dtype2nccl(src.dtype)
+ operator = op2nccl(op)
+
+ assert src.size() == dst.size(), "Buffer size not aligned"
+ C.ncclAllReduce(
+ sendbuff,
+ recvbuff,
+ count,
+ datatype,
+ operator,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
+def send(src : torch.storage._StorageBase,
+ peer : int,
+ comm : NCCLCommunicator
+ ):
+    """NCCL API: `ncclSend `_
+
+    Args:
+        src (torch.storage._StorageBase): Source buffer.
+        peer (int): Rank of the peer, which must post a matching recv.
+        comm (NCCLCommunicator): NCCL communicator.
+ """
+
+ sendbuff = src.data_ptr()
+ count = src.size()
+ datatype = dtype2nccl(src.dtype)
+ C.ncclSend(
+ sendbuff,
+ count,
+ datatype,
+ peer,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
+def recv(dst : torch.storage._StorageBase,
+ peer : int,
+ comm : NCCLCommunicator
+ ):
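+    """NCCL API: `ncclRecv `_
+
+    Args:
+        dst (torch.storage._StorageBase): Destination buffer.
+        peer (int): Rank of the peer, which must post a matching send.
+        comm (NCCLCommunicator): NCCL communicator.
+    """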
+ recvbuff = dst.data_ptr()
+ count = dst.size()
+ datatype = dtype2nccl(dst.dtype)
+ C.ncclRecv(
+ recvbuff,
+ count,
+ datatype,
+ peer,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
+
+def broadcast(
+ src : torch.storage._StorageBase,
+ dst : torch.storage._StorageBase,
+ root : int,
+ comm : NCCLCommunicator
+ ):
+ """NCCL API: `ncclBroadcast `_
+
+ Args:
+ src (torch.storage._StorageBase): Source buffer.
+ dst (torch.storage._StorageBase): Destination buffer.
+ root (int): Rank of the root.
+ comm (NCCLCommunicator): NCCL communicator.
+
+ The src and dst buffers must be the same size, type and on the same device.
+
+ If src == dst, the operation is performed in-place.
+
+ """
+
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
+ assert src.is_cuda and dst.is_cuda
+
+ sendbuff = src.data_ptr()
+ recvbuff = dst.data_ptr()
+ count = src.size()
+ datatype = dtype2nccl(src.dtype)
+
+ assert dst.size() == src.size(), "Buffer size not aligned"
+ C.ncclBroadcast(
+ sendbuff,
+ recvbuff,
+ count,
+ datatype,
+ root,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
+
+def reduce(
+ src : torch.storage._StorageBase,
+ dst : torch.storage._StorageBase,
+ op : Literal["sum", "prod", "max", "min", "avg"],
+ root : int,
+ comm : NCCLCommunicator
+ ):
+ """NCCL API: `ncclReduce `_
+
+ Args:
+ src (torch.storage._StorageBase): Source buffer.
+ dst (torch.storage._StorageBase): Destination buffer.
+ op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation.
+ root (int): Rank of the root.
+ comm (NCCLCommunicator): NCCL communicator.
+
+ The src and dst buffers must be the same size, type and on the same device.
+
+ If src == dst, the operation is performed in-place.
+
+ """
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
+ assert src.is_cuda and dst.is_cuda
+
+ sendbuff = src.data_ptr()
+ recvbuff = dst.data_ptr()
+ count = src.size()
+ datatype = dtype2nccl(src.dtype)
+ operator = op2nccl(op)
+
+ assert dst.size() == src.size(), "Buffer size not aligned"
+ C.ncclReduce(sendbuff, recvbuff, count, datatype, operator, root, comm.ptr, torch.cuda.current_stream().cuda_stream)
+
+def allGather(
+ src : torch.storage._StorageBase,
+ dst : torch.storage._StorageBase,
+ comm : NCCLCommunicator
+ ):
+ """NCCL API: `ncclAllGather `_
+
+ Args:
+ src (torch.storage._StorageBase): Source buffer.
+ dst (torch.storage._StorageBase): Destination buffer.
+ comm (NCCLCommunicator): NCCL communicator.
+
+ The size of the dst buffer must be equal to the size of src buffer * world_size.
+
+    Every rank receives the full gathered result in its dst buffer.
+
+ """
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
+ assert src.is_cuda and dst.is_cuda
+
+ sendbuff = src.data_ptr()
+ recvbuff = dst.data_ptr()
+ sendcount = src.size()
+ datatype = dtype2nccl(src.dtype)
+ assert dst.size() % sendcount == 0, "Buffer size not aligned"
+ C.ncclAllGather(
+ sendbuff,
+ recvbuff,
+ sendcount,
+ datatype,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
+
+
+def reduceScatter(
+ src : torch.storage._StorageBase,
+ dst : torch.storage._StorageBase,
+ op : Literal["sum", "prod", "max", "min", "avg"],
+ comm : NCCLCommunicator
+ ):
+ """NCCL API: `ncclReduceScatter `_
+
+ Args:
+ src (torch.storage._StorageBase): Source buffer.
+ dst (torch.storage._StorageBase): Destination buffer.
+ op (Literal["sum", "prod", "max", "min", "avg"]): Reduction operation.
+ comm (NCCLCommunicator): NCCL communicator.
+
+ The size of the dst buffer must be equal to the size of src buffer / world_size.
+
+    The dst buffer on rank `i` will contain the i-th block of the reduced result.
+
+ """
+    assert src.dtype == dst.dtype, "send and recv buffers must be the same type"
+ assert src.is_cuda and dst.is_cuda
+
+ sendbuff = src.data_ptr()
+ recvbuff = dst.data_ptr()
+ recvcount = dst.size()
+ datatype = dtype2nccl(src.dtype)
+ operator = op2nccl(op)
+
+ assert src.size() % recvcount == 0, "Buffer size not aligned"
+ C.ncclReduceScatter(
+ sendbuff,
+ recvbuff,
+ recvcount,
+ datatype,
+ operator,
+ comm.ptr,
+ torch.cuda.current_stream().cuda_stream
+ )
+
+def groupStart():
+ """
+ NCCL API: `ncclGroupStart `_
+ """
+ C.ncclGroupStart()
+
+def groupEnd():
+ """
+ NCCL API: `ncclGroupEnd `_
+ """
+ C.ncclGroupEnd()
diff --git a/examples/BMTrain/bmtrain/nccl/enums.py b/examples/BMTrain/bmtrain/nccl/enums.py
new file mode 100644
index 00000000..67411f0e
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nccl/enums.py
@@ -0,0 +1,27 @@
+
+### ncclDataType_t
+
+ncclInt8 = 0
+ncclChar = 0
+ncclBool = 0
+ncclUint8 = 1
+ncclInt32 = 2
+ncclInt = 2
+ncclUint32 = 3
+ncclInt64 = 4
+ncclUint64 = 5
+ncclFloat16 = 6
+ncclHalf = 6
+ncclFloat32 = 7
+ncclFloat = 7
+ncclFloat64 = 8
+ncclDouble = 8
+ncclBFloat16 = 9
+
+### ncclRedOp_t
+
+ncclSum = 0
+ncclProd = 1
+ncclMax = 2
+ncclMin = 3
+ncclAvg = 4
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/nn/__init__.py b/examples/BMTrain/bmtrain/nn/__init__.py
new file mode 100644
index 00000000..60fed663
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/__init__.py
@@ -0,0 +1,5 @@
+from .linear import Linear, OpLinear
+from .column_parallel_linear import ColumnParallelLinear
+from .row_parallel_linear import RowParallelLinear
+from .parallel_embedding import VPEmbedding
+from .parallel_linear_func import OpParallelLinear
diff --git a/examples/BMTrain/bmtrain/nn/column_parallel_linear.py b/examples/BMTrain/bmtrain/nn/column_parallel_linear.py
new file mode 100644
index 00000000..e1ede115
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/column_parallel_linear.py
@@ -0,0 +1,80 @@
+import torch
+from torch.nn.parameter import Parameter
+
+import bmtrain as bmt
+from bmtrain.global_var import config
+from .parallel_linear_func import OpParallelLinear, ReduceType
+
+
+class ColumnParallelLinear(bmt.DistributedModule):
+    """Linear layer with tensor-parallel column partitioning.
+
+ Args:
+ in_features (int): in_features size.
+ out_features (int): out_features size.
+        bias (bool): whether to use bias.
+        dtype : data type.
+        gather_output (bool): whether to gather the output after the computation.
+        gather_input (bool): whether to gather the input before the computation.
+        async_gather_chunks (int): number of chunks used for asynchronous gathering.
+
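+    A minimal usage sketch (assumes a tensor-parallel run, e.g. initialized with
+    ``bmt.init_distributed(tp_size=...)``; ``x`` and the shapes are illustrative):
+
+        >>> linear = bmt.nn.ColumnParallelLinear(1024, 4096, dtype=torch.half)
+        >>> out = linear(x)  # the weight is partitioned along the output dimension
+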
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ dtype=None,
+ gather_output=False,
+ gather_input=True,
+ async_gather_chunks=2,
+ ) -> None:
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.gather_output = gather_output
+ self.gather_input = gather_input
+ self.async_gather_chunks = async_gather_chunks
+ tp_size = config["tp_size"]
+ assert out_features % tp_size == 0
+ self.out_features_per_partition = out_features // tp_size
+ self.weight = bmt.DistributedParameter(
+ torch.empty(
+ self.out_features_per_partition, in_features, dtype=dtype, device="cuda"
+ ),
+ init_method=torch.nn.init.xavier_normal_,
+ tp_split_dim=0,
+ tp_mode=True,
+ )
+ if bias:
+ self.bias = bmt.DistributedParameter(
+ torch.empty(
+ self.out_features_per_partition, dtype=dtype, device="cuda"
+ ),
+ init_method=torch.nn.init.zeros_,
+ tp_split_dim=0,
+ tp_mode=True,
+ )
+ else:
+ self.register_parameter("bias", None)
+
+ def forward(self, input):
+ gather_input = self.gather_input
+ split_input = False
+ reduce_output_type = None
+ return OpParallelLinear.apply(
+ input,
+ self.weight,
+ self.bias,
+ gather_input,
+ self.gather_output,
+ split_input,
+ reduce_output_type,
+ self.async_gather_chunks,
+ )
+
+ def extra_repr(self) -> str:
+ return "in_features={}, out_features={}, bias={}".format(
+            self.in_features, self.out_features_per_partition, self.bias is not None
+ )
diff --git a/examples/BMTrain/bmtrain/nn/linear.py b/examples/BMTrain/bmtrain/nn/linear.py
new file mode 100644
index 00000000..8afb1d89
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/linear.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn.functional as F
+import bmtrain as bmt
+
+
+class OpLinear(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, x, weight, bias=None):
+ ctx.save_for_backward(x, weight, bias)
+ return F.linear(x, weight, bias)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ x, weight, bias = ctx.saved_tensors
+ grad_x = grad_weight = grad_bias = None
+ if x.requires_grad:
+ grad_x = grad_output.matmul(weight)
+ if weight.requires_grad:
+ dim = grad_output.dim()
+ grad_weight = (
+ grad_output.reshape(-1, grad_output.shape[-1])
+ .t()
+ .matmul(x.reshape(-1, x.shape[-1]))
+ )
+ if bias is not None and bias.requires_grad:
+ grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
+ return grad_x, grad_weight, grad_bias
+
+
+class Linear(bmt.DistributedModule):
+ def __init__(
+ self, in_features: int, out_features: int, bias: bool = True, dtype=None
+ ) -> None:
+ super().__init__()
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.weight = bmt.DistributedParameter(
+ torch.empty(out_features, in_features, dtype=dtype, device="cuda"),
+ init_method=torch.nn.init.xavier_normal_,
+ )
+ if bias:
+ self.bias = bmt.DistributedParameter(
+ torch.empty(out_features, dtype=dtype, device="cuda"),
+ init_method=torch.nn.init.zeros_,
+ )
+ else:
+ self.register_parameter("bias", None)
+
+ def forward(self, input):
+ return OpLinear.apply(input, self.weight, self.bias)
+
+ def extra_repr(self) -> str:
+ return "in_features={}, out_features={}, bias={}".format(
+ self.in_features, self.out_features, self.bias is not None
+ )
diff --git a/examples/BMTrain/bmtrain/nn/parallel_embedding.py b/examples/BMTrain/bmtrain/nn/parallel_embedding.py
new file mode 100644
index 00000000..3bdc4e56
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/parallel_embedding.py
@@ -0,0 +1,59 @@
+import torch
+from torch.nn.parameter import Parameter
+import torch.nn.functional as F
+import math
+
+import bmtrain as bmt
+from bmtrain.global_var import config
+from bmtrain.distributed import all_reduce, all_gather
+from .parallel_linear_func import OpParallelLinear
+
+
+class VPEmbedding(bmt.DistributedModule):
+ """Vocab Parallel Embedding.
+
+ Args:
+        vocab_size (int): vocab size.
+        embedding_size (int): embedding size.
+        dtype (torch.dtype): data type.
+        init_mean (float, optional): mean for weight init.
+        init_std (float, optional): std for weight init.
+
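+    A minimal usage sketch (assumes a tensor-parallel run; ``ids`` holds token ids and the
+    sizes are illustrative):
+
+        >>> emb = bmt.nn.VPEmbedding(32000, 1024, dtype=torch.half)
+        >>> hidden = emb(ids)                      # embedding lookup
+        >>> logits = emb(hidden, projection=True)  # tied output projection
+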
+ """
+
+ def __init__(
+ self,
+ vocab_size: int,
+ embedding_size: int,
+ dtype: torch.dtype = torch.half,
+ init_mean: float = 0.0,
+ init_std: float = 1,
+ ):
+ super().__init__()
+
+ self.dim_model = embedding_size
+ assert vocab_size % bmt.config["tp_size"] == 0
+ self.vocab_size_per_partition = vocab_size // bmt.config["tp_size"]
+ self.start_index = bmt.config["tp_rank"] * self.vocab_size_per_partition
+ self.end_index = (bmt.config["tp_rank"] + 1) * self.vocab_size_per_partition
+ self.weight = bmt.DistributedParameter(
+ torch.empty(self.vocab_size_per_partition, embedding_size, dtype=dtype),
+ init_method=bmt.ParameterInitializer(
+ torch.nn.init.normal_, mean=init_mean, std=init_std
+ ),
+ tp_split_dim=0,
+ tp_mode=True,
+ )
+
+ def forward(self, x: torch.Tensor, projection=False):
+ if not projection:
+ weight = all_gather(self.weight, comm=config["tp_comm"]).flatten(0, 1)
+ out = F.embedding(x, weight)
+ return out
+ else:
+ x = bmt.distributed.all_gather(x, comm=bmt.config["tp_comm"]).view(
+ x.shape[0], -1, x.shape[-1]
+ )
+ return bmt.nn.OpParallelLinear.apply(
+ x, self.weight, None, False, False, False, None, 1
+ )
diff --git a/examples/BMTrain/bmtrain/nn/parallel_linear_func.py b/examples/BMTrain/bmtrain/nn/parallel_linear_func.py
new file mode 100644
index 00000000..e389cde6
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/parallel_linear_func.py
@@ -0,0 +1,352 @@
+import torch
+import torch.nn.functional as F
+from bmtrain.global_var import config
+from ..distributed import all_gather, all_reduce
+from .. import nccl
+import bmtrain as bmt
+from enum import Enum
+
+
+class ReduceType(Enum):
+ ALL_REDUCE = 1
+ REDUCE_SCATTER = 2
+
+
+def preprocess_input(input, gather_input, split_input):
+ if gather_input:
+ input = all_gather(input, config["tp_comm"])
+ input = input.flatten(0, 1)
+
+ if split_input:
+ all_input_list = input.chunk(config["tp_size"], dim=-1)
+ input = all_input_list[config["topology"].tp_id]
+ return input
+
+
+def async_all_gather_linear_func(input, weight, bias, async_chunks=2):
+ dim = input.dim()
+ shape = list(input.shape)
+ if dim > 2:
+ input = input.view(-1, input.shape[-1])
+ tp_size = config["tp_size"]
+ current_stream = torch.cuda.current_stream()
+ comm_stream = config["tp_comm_stream"]
+
+ rounds = async_chunks
+ inputs = input.chunk(rounds, dim=0)
+ comm_stream.wait_stream(current_stream)
+ outputs = [None] * tp_size * rounds
+
+ input = all_gather(inputs[0], config["tp_comm"])
+ input = input.flatten(0, 1)
+ out = F.linear(input, weight, bias)
+ outs = out.chunk(tp_size, dim=0)
+ for i in range(tp_size):
+ outputs[i * rounds] = outs[i]
+
+    # asynchronous all_gather, overlapped with the linear compute
+ for i in range(rounds - 1):
+ with torch.cuda.stream(comm_stream):
+ inputs[i + 1].record_stream(comm_stream)
+ input = all_gather(inputs[i + 1], config["tp_comm"])
+ input = input.flatten(0, 1)
+
+ current_stream.wait_stream(comm_stream)
+ out = F.linear(input, weight, bias)
+ outs = out.chunk(tp_size, dim=0)
+ for j in range(tp_size):
+ outputs[(i + 1) + j * rounds] = outs[j]
+
+ out = torch.cat(outputs, dim=0)
+ if dim > 2:
+ out_shape = list(out.shape)
+ shape[-1] = out_shape[-1]
+ shape[0] = shape[0] * tp_size
+ out = out.view(shape)
+ return out
+
+
+def async_reduce_scatter_linear_func(input, weight, bias, async_chunks=2):
+ tp_size = config["tp_size"]
+ comm_stream = config["tp_comm_stream"]
+ rounds = async_chunks
+ input_shape = list(input.shape)
+ dim = input.dim()
+ if dim > 2:
+ input = input.view(-1, input.shape[-1])
+ inputs = input.chunk(rounds * tp_size, dim=0)
+ current_stream = torch.cuda.current_stream()
+
+ outputs = [None] * rounds
+ for i in range(rounds):
+ input = [None] * tp_size
+ for j in range(tp_size):
+ input[j] = inputs[j * rounds + i]
+ input = torch.cat(input, dim=0)
+ out = F.linear(input, weight, bias)
+ with torch.cuda.stream(comm_stream):
+ comm_stream.wait_stream(current_stream)
+ out.record_stream(comm_stream)
+ shape = list(out.shape)
+ shape[0] = shape[0] // config["tp_size"]
+ outputs[i] = torch.empty(shape, dtype=out.dtype, device=out.device)
+ nccl.reduceScatter(
+ out.storage(), outputs[i].storage(), "sum", config["tp_comm"]
+ )
+
+ current_stream.wait_stream(comm_stream)
+ out = torch.cat(outputs, dim=0)
+ if dim > 2:
+ out_shape = list(out.shape)
+ input_shape[-1] = out_shape[-1]
+ input_shape[0] = input_shape[0] // tp_size
+ out = out.view(input_shape)
+
+ return out
+
+
+def async_all_gather_linear_backward_func(
+ grad_out, input, weight, bias, async_chunks=2
+):
+ tp_size = config["tp_size"]
+ current_stream = torch.cuda.current_stream()
+ comm_stream = config["tp_comm_stream"]
+ input_require_grad = input.requires_grad
+ dim = input.dim()
+ input_shape = input.shape
+ if dim > 2:
+ input = input.view(-1, input_shape[-1])
+ grad_out = grad_out.view(-1, grad_out.shape[-1])
+
+ rounds = async_chunks
+ grad_inputs = [None] * tp_size * rounds
+ grad_weights = [None] * tp_size * rounds
+ grad_outs = [None] * tp_size * rounds
+ local_grad_outs = grad_out.chunk(rounds, dim=0)
+
+ inputs = [None] * rounds
+ comm_stream.wait_stream(current_stream)
+ if weight.requires_grad:
+ with torch.cuda.stream(comm_stream):
+ input.record_stream(comm_stream)
+ input_list = [None] * tp_size * rounds
+ tp_inputs = input.chunk(tp_size, dim=0)
+ for i in range(tp_size):
+ chunk_inputs = tp_inputs[i].chunk(rounds, dim=0)
+ for j in range(rounds):
+ input_list[j * tp_size + i] = chunk_inputs[j]
+ start = 0
+ end = tp_size
+ for i in range(rounds):
+ inputs[i] = torch.cat(input_list[start:end], dim=0)
+ start = end
+ end += tp_size
+
+ grad_input = grad_weight = grad_bias = None
+
+ grad_out = all_gather(local_grad_outs[0], config["tp_comm"])
+ for j in range(tp_size):
+ grad_outs[j * rounds] = grad_out[j]
+ grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n)
+ if input_require_grad:
+ grad_input = grad_out.matmul(
+ weight
+ ) # (tp_size * (m/rounds), n) * (n, k/tp_size)
+ tmp_grad_inputs = grad_input.chunk(tp_size, dim=0)
+ for j in range(tp_size):
+ grad_inputs[j * rounds] = tmp_grad_inputs[j]
+
+ if weight.requires_grad:
+ grad_weight = (
+ grad_out.reshape(-1, grad_out.shape[-1])
+ .t()
+ .matmul(inputs[0].reshape(-1, inputs[0].shape[-1]))
+ )
+
+    # asynchronous all_gather, overlapped with the matmul
+ for i in range(rounds - 1):
+ with torch.cuda.stream(comm_stream):
+ local_grad_outs[i + 1].record_stream(comm_stream)
+ grad_out = all_gather(local_grad_outs[i + 1], config["tp_comm"])
+ for j in range(tp_size):
+ grad_outs[j * rounds + i + 1] = grad_out[j]
+ grad_out = grad_out.flatten(0, 1) # (tp_size * (m/rounds), n)
+
+ current_stream.wait_stream(comm_stream)
+ if input_require_grad:
+ grad_input = grad_out.matmul(
+ weight
+ ) # (tp_size * (m/rounds), n) * (n, k/tp_size)
+ tmp_grad_inputs = grad_input.chunk(tp_size, dim=0)
+ for j in range(tp_size):
+ grad_inputs[j * rounds + i + 1] = tmp_grad_inputs[j]
+
+ if weight.requires_grad:
+ dim = grad_out.dim()
+ grad_weight += (
+ grad_out.reshape(-1, grad_out.shape[-1])
+ .t()
+ .matmul(inputs[i + 1].reshape(-1, inputs[i + 1].shape[-1]))
+ )
+
+ if input_require_grad:
+ grad_input = torch.cat(grad_inputs, dim=0)
+ grad_input = grad_input.view(input_shape)
+
+ if bias is not None and bias.requires_grad:
+ grad_out = torch.cat(grad_outs, dim=0)
+ grad_bias = grad_out.reshape(-1, grad_out.shape[-1]).sum(0)
+
+ return grad_input, grad_weight, grad_bias
+
+
+class OpParallelLinear(torch.autograd.Function):
+ """OpParallelLinear is a subclass of torch.autograd.Function.
+    It gathers the input tensor when needed, and all-reduces or reduce-scatters the output when needed.
+
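+    A minimal sketch of a direct call (``x``, ``weight`` and ``bias`` are placeholder tensors;
+    the positional arguments after ``bias`` are gather_input, gather_output, split_input,
+    reduce_output_type and async_gather_chunks):
+
+        >>> out = OpParallelLinear.apply(x, weight, bias, True, False, False, None, 2)
+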
+ """
+
+ @staticmethod
+ def forward(
+ ctx,
+ input,
+ weight,
+ bias=None,
+ gather_input=False,
+ gather_output=False,
+ split_input=False,
+ reduce_output_type=None,
+ async_gather_chunks=2,
+ ):
+ if reduce_output_type is not None:
+ reduce_output_type = ReduceType(reduce_output_type)
+
+ ctx.save_for_backward(input, weight, bias)
+ ctx.gather_output = gather_output
+ ctx.split_input = split_input
+ ctx.gather_input = gather_input
+ ctx.reduce_output_type = reduce_output_type
+ ctx.async_gather_chunks = async_gather_chunks
+
+ if (
+ gather_input
+ and config["tp_size"] > 1
+ and async_gather_chunks > 1
+ and split_input == False
+ ):
+ out = async_all_gather_linear_func(input, weight, bias, async_gather_chunks)
+ elif reduce_output_type == ReduceType.REDUCE_SCATTER:
+ return async_reduce_scatter_linear_func(
+ input, weight, bias, async_gather_chunks
+ )
+ else:
+ all_input = preprocess_input(input, ctx.gather_input, ctx.split_input)
+ out = F.linear(all_input, weight, bias)
+
+ if gather_output:
+ all_output_list = all_gather(out, config["tp_comm"])
+ all_output_list = all_output_list.chunk(config["tp_size"], dim=0)
+ out = torch.cat(all_output_list, dim=all_output_list[0].dim() - 1).flatten(
+ 0, 1
+ )
+
+ if reduce_output_type is None:
+ return out
+
+ if reduce_output_type == ReduceType.ALL_REDUCE:
+ nccl.allReduce(out.storage(), out.storage(), "sum", config["tp_comm"])
+ return out
+ else:
+            assert False, "unsupported reduce type {}".format(reduce_output_type)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ input, weight, bias = ctx.saved_tensors
+ gather_output = ctx.gather_output
+
+ if ctx.reduce_output_type == ReduceType.REDUCE_SCATTER:
+ if input.requires_grad or weight.requires_grad:
+ grad_input, grad_weight, grad_bias = (
+ async_all_gather_linear_backward_func(
+ grad_output, input, weight, bias, ctx.async_gather_chunks
+ )
+ )
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None
+ else:
+ grad_output = all_gather(grad_output, config["tp_comm"])
+ grad_output = grad_output.flatten(0, 1)
+
+ if gather_output:
+ tp_size = config["tp_size"]
+ tp_id = config["topology"].tp_id
+ grad_output_list = grad_output.chunk(tp_size, dim=-1)
+ grad_output = grad_output_list[tp_id]
+
+ grad_input = grad_weight = grad_bias = None
+
+ current_stream = torch.cuda.current_stream()
+ if input.requires_grad or weight.requires_grad:
+ if ctx.gather_input:
+ # async the all_gather
+ with torch.cuda.stream(config["tp_comm_stream"]):
+ input.record_stream(config["tp_comm_stream"])
+ config["tp_comm_stream"].wait_stream(current_stream)
+ all_input = preprocess_input(
+ input, ctx.gather_input, ctx.split_input
+ )
+ # use event to solve two streams waiting for each other
+ gather_event = config["tp_comm_stream"].record_event()
+ else:
+ all_input = preprocess_input(input, ctx.gather_input, ctx.split_input)
+
+ if input.requires_grad:
+ grad_all_input = grad_output.matmul(weight)
+ grad_input = torch.zeros_like(input)
+ if ctx.gather_input:
+ # async the reduce_scatter
+ with torch.cuda.stream(config["tp_comm_stream"]):
+ config["tp_comm_stream"].wait_stream(current_stream)
+ grad_input.record_stream(config["tp_comm_stream"])
+ grad_all_input.record_stream(config["tp_comm_stream"])
+ nccl.reduceScatter(
+ grad_all_input.storage(),
+ grad_input.storage(),
+ "sum",
+ config["tp_comm"],
+ )
+ elif ctx.reduce_output_type is None:
+ with torch.cuda.stream(config["tp_comm_stream"]):
+ config["tp_comm_stream"].wait_stream(current_stream)
+ grad_input.record_stream(config["tp_comm_stream"])
+ nccl.allReduce(
+ grad_all_input.storage(),
+ grad_all_input.storage(),
+ "sum",
+ config["tp_comm"],
+ )
+ grad_input = grad_all_input
+ else:
+ grad_input = grad_all_input
+
+ if ctx.split_input:
+ with torch.cuda.stream(config["tp_comm_stream"]):
+ config["tp_comm_stream"].wait_stream(current_stream)
+ grad_input.record_stream(config["tp_comm_stream"])
+ grad_input = all_gather(grad_input, config["tp_comm"])
+
+ # wait all_gather
+ if ctx.gather_input:
+ current_stream.wait_event(gather_event)
+ if weight.requires_grad:
+ grad_weight = (
+ grad_output.reshape(-1, grad_output.shape[-1])
+ .t()
+ .matmul(all_input.reshape(-1, all_input.shape[-1]))
+ )
+
+ if bias is not None and bias.requires_grad:
+ grad_bias = grad_output.reshape(-1, grad_output.shape[-1]).sum(0)
+
+ current_stream = torch.cuda.current_stream()
+ current_stream.wait_stream(config["tp_comm_stream"])
+ return grad_input, grad_weight, grad_bias, None, None, None, None, None
diff --git a/examples/BMTrain/bmtrain/nn/row_parallel_linear.py b/examples/BMTrain/bmtrain/nn/row_parallel_linear.py
new file mode 100644
index 00000000..ee4610cc
--- /dev/null
+++ b/examples/BMTrain/bmtrain/nn/row_parallel_linear.py
@@ -0,0 +1,88 @@
+import torch
+from torch.nn.parameter import Parameter
+
+import bmtrain as bmt
+from bmtrain.global_var import config
+from .parallel_linear_func import OpParallelLinear, ReduceType
+
+
+class RowParallelLinear(bmt.DistributedModule):
+    """Linear layer with tensor-parallel row partitioning.
+
+ Args:
+ in_features (int): in_features size.
+ out_features (int): out_features size.
+ bias (bool): whether use bias.
+ dtype : data type.
+        split_input (bool): whether to split the input before the computation.
+        all_reduce_output (bool): if True, all-reduce the output after the computation; otherwise use reduce-scatter.
+        async_chunks (int): number of chunks used for asynchronous communication.
+
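+    A minimal usage sketch (assumes a tensor-parallel run; the weight is partitioned along the
+    input dimension and ``x`` is illustrative):
+
+        >>> linear = bmt.nn.RowParallelLinear(4096, 1024, dtype=torch.half)
+        >>> out = linear(x)
+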
+ """
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ dtype=None,
+ split_input=False,
+ all_reduce_output=False,
+ async_chunks=2,
+ ) -> None:
+ super().__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.split_input = split_input
+ self.all_reduce_output = all_reduce_output
+ self.async_chunks = async_chunks
+ tp_size = config["tp_size"]
+ assert in_features % tp_size == 0
+ self.in_features_per_partition = in_features // tp_size
+ self.weight = bmt.DistributedParameter(
+ torch.empty(
+ self.out_features,
+ self.in_features_per_partition,
+ dtype=dtype,
+ device="cuda",
+ ),
+ init_method=torch.nn.init.xavier_normal_,
+ tp_split_dim=1,
+ tp_mode=True,
+ )
+ if bias:
+ self.bias = bmt.DistributedParameter(
+ torch.empty(self.out_features, dtype=dtype, device="cuda"),
+ init_method=torch.nn.init.zeros_,
+ tp_split_dim=-1,
+ tp_mode=True,
+ )
+ else:
+ self.register_parameter("bias", None)
+
+ def forward(self, input):
+ gather_input = self.split_input
+ gather_output = False
+ reduce_output_type = (
+ ReduceType.ALL_REDUCE
+ if self.all_reduce_output
+ else ReduceType.REDUCE_SCATTER
+ )
+ out = OpParallelLinear.apply(
+ input,
+ self.weight,
+ None,
+ gather_input,
+ gather_output,
+ self.split_input,
+ reduce_output_type,
+ self.async_chunks,
+ )
+ if self.bias is not None:
+ out = out + self.bias
+ return out
+
+ def extra_repr(self) -> str:
+ return "in_features={}, out_features={}, bias={}".format(
+ self.in_features_per_partition, self.out_features, self.bias is not None
+ )
diff --git a/examples/BMTrain/bmtrain/optim/__init__.py b/examples/BMTrain/bmtrain/optim/__init__.py
new file mode 100644
index 00000000..15206328
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/__init__.py
@@ -0,0 +1,3 @@
+from .adam import AdamOptimizer
+from .adam_offload import AdamOffloadOptimizer
+from .optim_manager import OptimManager
\ No newline at end of file
diff --git a/examples/BMTrain/bmtrain/optim/_distributed.py b/examples/BMTrain/bmtrain/optim/_distributed.py
new file mode 100644
index 00000000..df8f2f3e
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/_distributed.py
@@ -0,0 +1,40 @@
+import torch
+from ..distributed import all_reduce, all_gather
+
+
+def state_dict_gather(state_dict):
+ param_key = [
+ p for param_group in state_dict["param_groups"] for p in param_group["params"]
+ ]
+ for k, v in state_dict["state"].items():
+ if "step" in v:
+ step = v["step"]
+
+ for k in param_key:
+ if k not in state_dict["state"]:
+ state_dict["state"][k] = {
+ "exp_avg": torch.tensor([], device="cuda", dtype=torch.float32),
+ "exp_avg_sq": torch.tensor([], device="cuda", dtype=torch.float32),
+ "_param_fp32": torch.tensor([], device="cuda", dtype=torch.float32),
+ "step": step,
+ }
+ v = state_dict["state"][k]
+ for name, dtype in [
+ ("exp_avg", torch.float32),
+ ("exp_avg_sq", torch.float32),
+ ("_param_fp32", torch.float32),
+ ]:
+ if name in v:
+ with torch.no_grad():
+ numel = torch.tensor(
+ v[name].numel(), device="cuda", dtype=torch.long
+ )
+ max_numel = all_reduce(numel, op="max")
+ v_p = torch.nn.functional.pad(
+ v[name], (0, max_numel - numel), value=-1e15
+ )
+ if max_numel > 0:
+ whole_state = all_gather(v_p.cuda()).flatten()
+ whole_state = whole_state[whole_state != -1e15]
+ v[name] = whole_state.contiguous().cpu()
+ return state_dict
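A single-process sketch (no NCCL; `shards` is an illustrative stand-in for per-rank optimizer state) of the pad/gather/filter pattern used above when partitions have different lengths:

import torch

shards = [torch.arange(5, dtype=torch.float32), torch.arange(3, dtype=torch.float32)]
max_numel = max(s.numel() for s in shards)
padded = [torch.nn.functional.pad(s, (0, max_numel - s.numel()), value=-1e15) for s in shards]

whole = torch.cat(padded)        # analogue of all_gather(...).flatten()
whole = whole[whole != -1e15]    # drop the padding sentinel, as state_dict_gather does
assert whole.numel() == sum(s.numel() for s in shards)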
diff --git a/examples/BMTrain/bmtrain/optim/_function.py b/examples/BMTrain/bmtrain/optim/_function.py
new file mode 100644
index 00000000..f9e0ce9d
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/_function.py
@@ -0,0 +1,218 @@
+from .. import C
+import torch
+
+CHECK_INPUT = lambda x: x.is_contiguous() and x.is_cuda
+
+
+def bf16_from_fp32(param_fp32):
+ param_bf16 = torch.empty_like(param_fp32, dtype=torch.bfloat16)
+ C.to_bf16_from_fp32(
+ param_fp32.numel(), param_fp32.data_ptr(), param_bf16.data_ptr()
+ )
+ return param_bf16
+
+
+def fp16_from_fp32(param_fp32):
+ param_fp16 = torch.empty_like(param_fp32, dtype=torch.float16)
+ C.to_fp16_from_fp32(
+ param_fp32.numel(), param_fp32.data_ptr(), param_fp16.data_ptr()
+ )
+ return param_fp16
+
+
+def adam_cpu(
+ param_fp32: torch.Tensor,
+ param_fp16: torch.Tensor,
+ delta_info: torch.Tensor,
+ g_fp16: torch.Tensor,
+ m_fp32: torch.Tensor,
+ v_fp32: torch.Tensor,
+ beta1: float,
+ beta2: float,
+ eps: float,
+ lr: float,
+ scale: float,
+ weight_decay: float,
+ step: int,
+) -> None:
+ assert param_fp32.is_contiguous(), "param_fp32 must be contiguous"
+ assert param_fp16.is_contiguous(), "param_fp16 must be contiguous"
+ assert g_fp16.is_contiguous(), "g_fp16 must be contiguous"
+ assert m_fp32.is_contiguous(), "m_fp32 must be contiguous"
+ assert v_fp32.is_contiguous(), "v_fp32 must be contiguous"
+ assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
+ assert (
+ param_fp16.dtype == torch.float16 or param_fp16.dtype == torch.bfloat16
+ ), "param_fp16 must be float16/bfloat16 tensor"
+ assert (
+ g_fp16.dtype == torch.float16 or g_fp16.dtype == torch.bfloat16
+ ), "g_fp16 must be float16/bfloat16 tensor"
+ assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor"
+ assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
+ assert param_fp32.device == torch.device("cpu"), "param_fp32 must be a cpu tensor"
+ assert param_fp16.device == torch.device("cpu"), "param_fp16 must be a cpu tensor"
+ assert g_fp16.device == torch.device("cpu"), "g_fp16 must be a cpu tensor"
+ assert m_fp32.device == torch.device("cpu"), "m_fp32 must be a cpu tensor"
+ assert v_fp32.device == torch.device("cpu"), "v_fp32 must be a cpu tensor"
+ assert (
+ param_fp32.numel() == param_fp16.numel()
+ ), "param_fp32 and param_fp16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == g_fp16.numel()
+ ), "param_fp32 and g_fp16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == m_fp32.numel()
+ ), "param_fp32 and m_fp32 must have the same number of elements"
+ assert (
+ param_fp32.numel() == v_fp32.numel()
+ ), "param_fp32 and v_fp32 must have the same number of elements"
+ if delta_info is not None:
+ assert delta_info.is_contiguous(), "delta_info must be contiguous"
+ assert delta_info.dtype == torch.float32, "delta_info must be float32 tensor"
+ assert delta_info.device == torch.device(
+ "cpu"
+ ), "delta_info must be a cpu tensor"
+        assert delta_info.numel() == 4, "delta_info must have a length of 4"
+ bias_correction1 = 1 - beta1**step
+ bias_correction2 = 1 - beta2**step
+ if g_fp16.dtype == torch.float16:
+ launcher = C.adam_cpu_fp16_launcher
+ elif g_fp16.dtype == torch.bfloat16:
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ launcher = C.adam_cpu_bf16_launcher
+ launcher(
+ param_fp32.numel(),
+ param_fp32.data_ptr(),
+ param_fp16.data_ptr(),
+ delta_info.data_ptr() if delta_info is not None else 0,
+ g_fp16.data_ptr(),
+ m_fp32.data_ptr(),
+ v_fp32.data_ptr(),
+ beta1,
+ beta2,
+ eps,
+ lr,
+ scale,
+ weight_decay,
+ bias_correction1,
+ bias_correction2,
+ )
+
+
+def adam_fp16(
+ param_fp32: torch.Tensor,
+ param_fp16: torch.Tensor,
+ g_fp16: torch.Tensor,
+ m_fp16: torch.Tensor,
+ v_fp32: torch.Tensor,
+ beta1: float,
+ beta2: float,
+ eps: float,
+ lr: float,
+ scale: float,
+ weight_decay: float,
+ step: int,
+) -> None:
+ assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
+ assert CHECK_INPUT(param_fp16), "param_fp16 must be contiguous and on cuda"
+ assert CHECK_INPUT(g_fp16), "g_fp16 must be contiguous and on cuda"
+    assert CHECK_INPUT(m_fp16), "m_fp16 must be contiguous and on cuda"
+ assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda"
+ assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
+ assert param_fp16.dtype == torch.float16, "param_fp16 must be float16 tensor"
+ assert g_fp16.dtype == torch.float16, "g_fp16 must be float16 tensor"
+ assert m_fp16.dtype == torch.float16, "m_fp16 must be float16 tensor"
+ assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
+ assert (
+ param_fp32.numel() == param_fp16.numel()
+ ), "param_fp32 and param_fp16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == g_fp16.numel()
+ ), "param_fp32 and g_fp16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == m_fp16.numel()
+ ), "param_fp32 and m_fp32 must have the same number of elements"
+ assert (
+ param_fp32.numel() == v_fp32.numel()
+ ), "param_fp32 and v_fp32 must have the same number of elements"
+ bias_correction1 = 1 - beta1**step
+ bias_correction2 = 1 - beta2**step
+ stream = torch.cuda.current_stream().cuda_stream
+ C.adam_fp16_launcher(
+ param_fp32.numel(),
+ param_fp32.data_ptr(),
+ param_fp16.data_ptr(),
+ g_fp16.data_ptr(),
+ m_fp16.data_ptr(),
+ v_fp32.data_ptr(),
+ beta1,
+ beta2,
+ eps,
+ lr,
+ scale,
+ weight_decay,
+ bias_correction1,
+ bias_correction2,
+ stream,
+ )
+
+
+def adam_bf16(
+ param_fp32: torch.Tensor,
+ param_bf16: torch.Tensor,
+ g_bf16: torch.Tensor,
+ m_fp32: torch.Tensor,
+ v_fp32: torch.Tensor,
+ beta1: float,
+ beta2: float,
+ eps: float,
+ lr: float,
+ scale: float,
+ weight_decay: float,
+ step: int,
+) -> None:
+ assert CHECK_INPUT(param_fp32), "param_fp32 must be contiguous and on cuda"
+ assert CHECK_INPUT(param_bf16), "param_bf16 must be contiguous and on cuda"
+ assert CHECK_INPUT(g_bf16), "g_bf16 must be contiguous and on cuda"
+ assert CHECK_INPUT(m_fp32), "m_fp32 must be contiguous and on cuda"
+ assert CHECK_INPUT(v_fp32), "v_fp32 must be contiguous and on cuda"
+ assert param_fp32.dtype == torch.float32, "param_fp32 must be float32 tensor"
+    assert param_bf16.dtype == torch.bfloat16, "param_bf16 must be bfloat16 tensor"
+ assert g_bf16.dtype == torch.bfloat16, "g_bf16 must be bfloat16 tensor"
+    assert m_fp32.dtype == torch.float32, "m_fp32 must be float32 tensor"
+ assert v_fp32.dtype == torch.float32, "v_fp32 must be float32 tensor"
+ assert (
+ param_fp32.numel() == param_bf16.numel()
+ ), "param_fp32 and param_bf16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == g_bf16.numel()
+ ), "param_fp32 and g_fp16 must have the same number of elements"
+ assert (
+ param_fp32.numel() == m_fp32.numel()
+ ), "param_fp32 and m_m_fp32 must have the same number of elements"
+ assert (
+ param_fp32.numel() == v_fp32.numel()
+ ), "param_fp32 and v_fp32 must have the same number of elements"
+ bias_correction1 = 1 - beta1**step
+ bias_correction2 = 1 - beta2**step
+ stream = torch.cuda.current_stream().cuda_stream
+ if not C.is_bf16_supported():
+ raise NotImplementedError(f"bfloat16 is not supported on current GPU")
+ C.adam_bf16_launcher(
+ param_fp32.numel(),
+ param_fp32.data_ptr(),
+ param_bf16.data_ptr(),
+ g_bf16.data_ptr(),
+ m_fp32.data_ptr(),
+ v_fp32.data_ptr(),
+ beta1,
+ beta2,
+ eps,
+ lr,
+ scale,
+ weight_decay,
+ bias_correction1,
+ bias_correction2,
+ stream,
+ )
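For readability, a plain-PyTorch reference of the update these launchers are expected to compute; the decoupled weight-decay form and the handling of `scale` are assumptions about the C/CUDA kernels (the fp16 kernel additionally keeps its momentum in scaled form, which this sketch ignores).

import torch

def adam_step_reference(p_fp32, g, m, v, beta1, beta2, eps, lr, scale, weight_decay, step):
    # one Adam step with the same bias-correction terms passed to the launchers above
    g = g.float() / scale
    m.mul_(beta1).add_(g, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(g, g, value=1 - beta2)
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = (v / bias_correction2).sqrt().add_(eps)
    p_fp32.mul_(1 - lr * weight_decay)                     # assumed decoupled (AdamW-style) decay
    p_fp32.addcdiv_(m, denom, value=-lr / bias_correction1)
    return p_fp32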
diff --git a/examples/BMTrain/bmtrain/optim/adam.py b/examples/BMTrain/bmtrain/optim/adam.py
new file mode 100644
index 00000000..f99c483c
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/adam.py
@@ -0,0 +1,252 @@
+import torch
+from ..global_var import config
+from . import _function as F
+import torch.optim._functional
+from .. import C
+from .. import nccl
+import inspect
+from ..utils import check_torch_version
+from copy import deepcopy
+from itertools import chain
+from collections import defaultdict
+
+
+class AdamOptimizer(torch.optim.Optimizer):
+ """
+    Adam optimizer supporting fp16 and bf16 parameters.
+ """
+
+ _bmtrain_optimizer = True
+
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ hold_steps=0,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+ super().__init__(params, defaults)
+
+ self._hold_steps = hold_steps
+
+ def _on_justify_scale(self, old_scale, new_scale):
+ delta = new_scale / old_scale
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p in self.state:
+ state = self.state[p]
+ if len(state) > 0:
+ if p.dtype == torch.float16:
+ state["exp_avg"] *= delta
+ state["exp_avg_sq"] *= delta
+
+ @torch.no_grad()
+ def step(self, closure=None, scale=1):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+            scale (float): Loss scale; gradients are divided by this value before the update.
+ """
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ # update parameters
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is not None and p.requires_grad:
+ if p.grad.is_sparse:
+ raise RuntimeError(
+ "Adam does not support sparse gradients, please consider SparseAdam instead"
+ )
+ if p.dtype not in [torch.float32, torch.half, torch.bfloat16]:
+ raise RuntimeError(
+ "Adam only supports fp32, fp16 and bf16 gradients"
+ )
+
+ state = self.state[p]
+ # Lazy state initialization
+ if len(state) == 0:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ if p.dtype == torch.float16:
+ state["exp_avg"] = torch.zeros(
+ p.size(), dtype=torch.float16, device=p.device
+ ) # on device
+ else:
+ state["exp_avg"] = torch.zeros(
+ p.size(), dtype=torch.float32, device=p.device
+ ) # on device
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros(
+ p.size(), dtype=torch.float32, device=p.device
+ ) # on device
+
+ if p.dtype != torch.float32:
+ state["_param_fp32"] = torch.empty(
+ p.size(), dtype=torch.float32, device=p.device
+ ) # on device
+ state["_param_fp32"].copy_(p)
+
+ # update the steps for each param group update
+ if ("maximize" in group) and (group["maximize"] is True):
+ grad = -p.grad
+ else:
+ grad = p.grad
+
+ if p.dtype == torch.float32:
+ other_kwargs = {}
+ if (
+ "maximize"
+ in inspect.signature(
+ torch.optim._functional.adam
+ ).parameters
+ ):
+ other_kwargs["maximize"] = False
+ torch.optim._functional.adam(
+ [p],
+ [grad / scale],
+ [state["exp_avg"]],
+ [state["exp_avg_sq"]],
+ [],
+ (
+ [state["step"]]
+ if check_torch_version("1.12.0") < 0
+ else [torch.tensor(state["step"])]
+ ),
+ amsgrad=False,
+ beta1=group["betas"][0],
+ beta2=group["betas"][1],
+ lr=0.0 if state["step"] < self._hold_steps else group["lr"],
+ weight_decay=group["weight_decay"],
+ eps=group["eps"],
+ **other_kwargs
+ )
+ state["step"] += 1
+ else:
+ f = F.adam_fp16 if p.dtype == torch.float16 else F.adam_bf16
+ state["step"] += 1
+ f(
+ state["_param_fp32"], # fp32
+ p, # fp16
+ grad, # fp16
+ state["exp_avg"], # fp16: m
+ state["exp_avg_sq"], # fp32: v
+ group["betas"][0],
+ group["betas"][1],
+ group["eps"],
+ 0.0 if state["step"] < self._hold_steps else group["lr"],
+ scale,
+ group["weight_decay"],
+ state["step"],
+ )
+
+ return loss
+
+    def get_avg_delta(self):
+        raise NotImplementedError(
+            "delta info is not supported in AdamOptimizer; try bmt.optim.AdamOffloadOptimizer"
+        )
+
+    def get_var_delta(self):
+        raise NotImplementedError(
+            "delta info is not supported in AdamOptimizer; try bmt.optim.AdamOffloadOptimizer"
+        )
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ r"""Loads the optimizer state.
+
+ Args:
+ state_dict (dict): optimizer state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+ state_dict = deepcopy(state_dict)
+ # Validate the state_dict
+ groups = self.param_groups
+ saved_groups = state_dict["param_groups"]
+
+ if len(groups) != len(saved_groups):
+ raise ValueError(
+ "loaded state dict has a different number of " "parameter groups"
+ )
+ param_lens = (len(g["params"]) for g in groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
+ if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+ raise ValueError(
+ "loaded state dict contains a parameter group "
+ "that doesn't match the size of optimizer's group"
+ )
+
+ # Update the state
+ id_map = {
+ old_id: p
+ for old_id, p in zip(
+ chain.from_iterable((g["params"] for g in saved_groups)),
+ chain.from_iterable((g["params"] for g in groups)),
+ )
+ }
+
+ # Copy state assigned to params (and cast tensors to appropriate types).
+ # State that is not assigned to params is copied as is (needed for
+ # backward compatibility).
+ state = defaultdict(dict)
+ for k, v in state_dict["state"].items():
+ if k in id_map:
+ param = id_map[k]
+
+ if param.dtype != torch.float32 and "_param_fp32" not in v:
+ v["_param_fp32"] = torch.empty(
+ param.size(), dtype=torch.float32, device=param.device
+ )
+ v["_param_fp32"].copy_(param)
+
+ for name, dtype in [
+ (
+ "exp_avg",
+ (
+ torch.float16
+ if param.dtype == torch.float16
+ else torch.float32
+ ),
+ ),
+ ("exp_avg_sq", torch.float32),
+ ("_param_fp32", torch.float32),
+ ]:
+ if name in v:
+ v[name] = v[name].to(param.device).to(dtype)
+
+ state[param] = v
+ else:
+ state[k] = v
+
+ # Update parameter groups, setting their 'params' value
+ def update_group(group, new_group):
+ new_group["params"] = group["params"]
+ return new_group
+
+ param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ self.__setstate__({"state": state, "param_groups": param_groups})
+
+    # TODO: zero_grad(set_to_none=True) crashes the optimizer, possibly because of gradient accumulation
+ def zero_grad(self, set_to_none: bool = False):
+ super().zero_grad(set_to_none=set_to_none)
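A hypothetical training-step sketch; it assumes `bmt.init_distributed()` has been called under a distributed launcher, and a plain fp16 `nn.Linear` stands in for a bmt-wrapped model.

import torch
import bmtrain as bmt

bmt.init_distributed()
model = torch.nn.Linear(128, 128).half().cuda()        # stand-in for a bmt-wrapped model
optimizer = bmt.optim.AdamOptimizer(model.parameters(), lr=1e-3, weight_decay=1e-2)
optim_manager = bmt.optim.OptimManager(loss_scale=1024)
optim_manager.add_optimizer(optimizer)

x = torch.randn(4, 128, dtype=torch.float16, device="cuda")
loss = model(x).float().pow(2).mean()
optim_manager.zero_grad()
optim_manager.backward(loss)     # scales the loss before backward
optim_manager.step()             # unscales gradients and applies the fp16 Adam update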
diff --git a/examples/BMTrain/bmtrain/optim/adam_offload.py b/examples/BMTrain/bmtrain/optim/adam_offload.py
new file mode 100644
index 00000000..f6ea97ba
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/adam_offload.py
@@ -0,0 +1,386 @@
+import torch
+from ..global_var import config
+from . import _function as F
+from .. import nccl
+import inspect
+from ..utils import check_torch_version
+from copy import deepcopy
+from itertools import chain
+from collections import defaultdict
+from ._distributed import state_dict_gather
+
+
+class AdamOffloadOptimizer(torch.optim.Optimizer):
+ """
+    Adam optimizer with optimizer states offloaded to the host (CPU).
+ """
+
+ _bmtrain_optimizer = True
+
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ hold_steps=0,
+ record_delta=False,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0.0 <= weight_decay:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ self.avg_delta = 0
+ self.var_delta = 0
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+ super().__init__(params, defaults)
+ self._hold_steps = hold_steps
+ self._events = {}
+ self.record_delta = record_delta
+ if self.record_delta:
+ for group in self.param_groups:
+ for p in group["params"]:
+ setattr(
+ p,
+ "_delta_info",
+ (
+ torch.tensor(
+ [0 for i in range(4)], dtype=torch.float32, device="cpu"
+ )
+ ),
+ )
+
+ @torch.no_grad()
+ def step(self, closure=None, scale=1):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+            scale (float): Loss scale; gradients are divided by this value before the update.
+ """
+
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ # parameters to be updated
+ update_params = []
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is not None and p.requires_grad:
+ if p.grad.is_sparse:
+ raise RuntimeError(
+ "Adam does not support sparse gradients, please consider SparseAdam instead"
+ )
+ if p.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
+ raise RuntimeError(
+ "Adam only supports fp32, fp16 and bf16 gradients"
+ )
+
+ state = self.state[p]
+ # Lazy state initialization
+ if len(state) == 0:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros(
+ p.size(), dtype=torch.float32, device="cpu"
+ ) # on host
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros(
+ p.size(), dtype=torch.float32, device="cpu"
+ ) # on host
+
+ if p.dtype == torch.float32:
+ state["_param_fp32"] = torch.empty(
+ p.size(), dtype=torch.float32, pin_memory=True
+ ) # on host
+ state["_param_fp32"].copy_(p)
+
+ # placeholder
+ state["_grad_fp32"] = torch.empty(
+ p.size(), dtype=torch.float32, pin_memory=True
+ ) # on host
+ else:
+ state["_param_fp32"] = torch.empty(
+ p.size(), dtype=torch.float32, device="cpu"
+ ) # on host
+ state["_param_fp32"].copy_(p)
+
+ # placeholder
+ state["_param_fp16"] = torch.empty(
+ p.size(), dtype=p.dtype, pin_memory=True
+ ) # on host
+ state["_grad_fp16"] = torch.empty(
+ p.size(), dtype=p.dtype, pin_memory=True
+ ) # on host
+
+ if p not in self._events:
+ self._events[p] = torch.cuda.Event()
+
+ update_params.append(
+ (
+ p,
+ state,
+ self._events[p],
+ group["betas"][0],
+ group["betas"][1],
+ group["eps"],
+ group["lr"],
+ group["weight_decay"],
+ )
+ )
+
+ # transfer parameters to host asynchronously
+ for param, state, event, _, _, _, _, _ in update_params:
+ if param.dtype == torch.float32:
+ state["_grad_fp32"].copy_(param.grad, non_blocking=True)
+ else:
+ state["_grad_fp16"].copy_(param.grad, non_blocking=True)
+ torch.cuda.current_stream().record_event(event)
+ sum_delta = 0
+ sum_sq_delta = 0
+ total_numel = 0
+ for param, state, event, beta1, beta2, eps, lr, weight_decay in update_params:
+ # wait for transfer to host
+ event.synchronize()
+
+ # update parameters
+ if param.dtype == torch.float32:
+ state["_grad_fp32"].mul_(1.0 / scale)
+ if ("maximize" in group) and (group["maximize"] is True):
+ grad = -state["_grad_fp32"]
+ else:
+ grad = state["_grad_fp32"]
+ other_kwargs = {}
+ if (
+ "maximize"
+ in inspect.signature(torch.optim._functional.adam).parameters
+ ):
+ other_kwargs["maximize"] = False
+ torch.optim._functional.adam(
+ [state["_param_fp32"]],
+ [grad],
+ [state["exp_avg"]],
+ [state["exp_avg_sq"]],
+ [],
+ (
+ [state["step"]]
+ if check_torch_version("1.12.0") < 0
+ else [torch.tensor(state["step"])]
+ ),
+ amsgrad=False,
+ beta1=beta1,
+ beta2=beta2,
+ lr=0.0 if state["step"] < self._hold_steps else lr,
+ weight_decay=weight_decay,
+ eps=eps,
+ **other_kwargs
+ )
+ # transfer parameters back to device asynchronously
+ param.copy_(state["_param_fp32"], non_blocking=True)
+ state["step"] += 1
+ else:
+ state["step"] += 1
+ if ("maximize" in group) and (group["maximize"] is True):
+ grad = -state["_grad_fp16"]
+ else:
+ grad = state["_grad_fp16"]
+ F.adam_cpu(
+ state["_param_fp32"].view(-1),
+ state["_param_fp16"].view(-1),
+ param._delta_info if self.record_delta else None,
+ grad.view(-1),
+ state["exp_avg"].view(-1),
+ state["exp_avg_sq"].view(-1),
+ beta1,
+ beta2,
+ eps,
+ 0.0 if state["step"] < self._hold_steps else lr,
+ scale,
+ weight_decay,
+ state["step"],
+ )
+ total_numel += state["_param_fp16"].numel()
+ if self.record_delta:
+ sum_delta += param._delta_info[2].item()
+ sum_sq_delta += param._delta_info[3].item()
+ # transfer parameters back to device asynchronously
+ param.copy_(state["_param_fp16"], non_blocking=True)
+ if self.record_delta:
+ self.avg_delta = sum_delta / total_numel
+ self.var_delta = sum_sq_delta / total_numel - self.avg_delta**2
+
+ return loss
+
+    def get_avg_delta(self) -> float:
+        return self.avg_delta if self.record_delta else 0
+
+    def get_var_delta(self) -> float:
+        return self.var_delta if self.record_delta else 0
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ r"""Loads the optimizer state.
+
+ Args:
+ state_dict (dict): optimizer state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+
+ state_dict = deepcopy(state_dict)
+ # Validate the state_dict
+ groups = self.param_groups
+ saved_groups = state_dict["param_groups"]
+
+ if len(groups) != len(saved_groups):
+ raise ValueError(
+ "loaded state dict has a different number of " "parameter groups"
+ )
+ param_lens = (len(g["params"]) for g in groups)
+ saved_lens = (len(g["params"]) for g in saved_groups)
+ if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+ raise ValueError(
+ "loaded state dict contains a parameter group "
+ "that doesn't match the size of optimizer's group"
+ )
+
+ # Update the state
+ id_map = {
+ old_id: p
+ for old_id, p in zip(
+ chain.from_iterable((g["params"] for g in saved_groups)),
+ chain.from_iterable((g["params"] for g in groups)),
+ )
+ }
+
+ # _param_start_end = chain.from_iterable((g["params_start_end"] for g in saved_groups))
+ # Copy state assigned to params (and cast tensors to appropriate types).
+ # State that is not assigned to params is copied as is (needed for
+ # backward compatibility).
+ state = defaultdict(dict)
+ is_whole = False if "is_whole" not in state_dict else state_dict["is_whole"]
+ pop_key = []
+ for k, v in state_dict["state"].items():
+ if k in id_map:
+ param = id_map[k]
+ if is_whole and param._start_partition is not None:
+ for key in ["_param_fp32", "exp_avg_sq", "exp_avg"]:
+ if key in v:
+ v[key] = v[key][
+ param._start_partition : param._end_partition
+ ]
+ elif is_whole and param._start_partition is None:
+ pop_key.append(param)
+
+ if "_param_fp32" not in v:
+ with torch.no_grad():
+ v["_param_fp32"] = torch.empty(
+ param.size(), dtype=torch.float32, device="cpu"
+ )
+ v["_param_fp32"].copy_(param)
+
+ for name, dtype in [
+ ("exp_avg", torch.float32),
+ ("exp_avg_sq", torch.float32),
+ ("_param_fp32", torch.float32),
+ ]:
+ if name in v:
+ v[name] = v[name].to("cpu").to(dtype)
+
+ state[param] = v
+ if param.dtype == torch.float32:
+ state[param]["_param_fp32"] = state[param][
+ "_param_fp32"
+ ].pin_memory() # on host
+ # initialize placeholders
+ state[param]["_grad_fp32"] = torch.empty(
+ param.size(), dtype=torch.float32, pin_memory=True
+ ) # on host
+ else:
+ # initialize placeholders
+ state[param]["_param_fp16"] = torch.empty(
+ param.size(), dtype=param.dtype, pin_memory=True
+ ) # on host
+ state[param]["_grad_fp16"] = torch.empty(
+ param.size(), dtype=param.dtype, pin_memory=True
+ ) # on host
+ else:
+ state[k] = v
+ for k in pop_key:
+ state.pop(k)
+
+ # Update parameter groups, setting their 'params' value
+ def update_group(group, new_group):
+ new_group["params"] = group["params"]
+ return new_group
+
+ param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+ self.__setstate__({"state": state, "param_groups": param_groups})
+
+ def state_dict(self, gather=False) -> dict:
+ r"""Returns the state of the optimizer as a :class:`dict`.
+
+ It contains two entries:
+
+ * state - a dict holding current optimization state. Its content
+ differs between optimizer classes.
+ * param_groups - a list containing all parameter groups where each
+ parameter group is a dict
+ """
+
+ # Save order indices instead of Tensors
+ param_mappings = {}
+ start_index = 0
+
+ def pack_group(group):
+ nonlocal start_index
+ packed = {k: v for k, v in group.items() if k != "params"}
+ param_mappings.update(
+ {
+ id(p): i
+ for i, p in enumerate(group["params"], start_index)
+ if id(p) not in param_mappings
+ }
+ )
+ packed["params"] = [param_mappings[id(p)] for p in group["params"]]
+ start_index += len(packed["params"])
+ return packed
+
+ def cut_states(state):
+ return {
+ "step": state["step"],
+ "exp_avg": state["exp_avg"],
+ "exp_avg_sq": state["exp_avg_sq"],
+ "_param_fp32": state["_param_fp32"],
+ }
+
+ param_groups = [pack_group(g) for g in self.param_groups]
+ # Remap state to use order indices as keys
+ packed_state = {
+ (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): cut_states(v)
+ for k, v in self.state.items()
+ }
+ states = {
+ "state": packed_state,
+ "param_groups": param_groups,
+ }
+ if gather:
+ states = state_dict_gather(states)
+ states["is_whole"] = True
+ else:
+ states["is_whole"] = False
+
+ return states
+
+    # TODO: zero_grad(set_to_none=True) crashes the optimizer, possibly because of gradient accumulation
+ def zero_grad(self, set_to_none: bool = False):
+ super().zero_grad(set_to_none=set_to_none)
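A hypothetical checkpointing sketch: `gather=True` collects each rank's ZeRO partition into whole state tensors via `state_dict_gather`, so only rank 0 needs to write the file. The stand-in model and the use of `bmt.rank()` are assumptions.

import torch
import bmtrain as bmt

bmt.init_distributed()
model = torch.nn.Linear(128, 128).half().cuda()                 # stand-in for a bmt-wrapped model
optimizer = bmt.optim.AdamOffloadOptimizer(model.parameters(), lr=1e-3, record_delta=True)

# ... training steps driven by an OptimManager ...

states = optimizer.state_dict(gather=True)                      # whole (un-partitioned) states
if bmt.rank() == 0:
    torch.save(states, "adam_offload_states.pt")
print(optimizer.get_avg_delta(), optimizer.get_var_delta())     # only meaningful with record_delta=True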
diff --git a/examples/BMTrain/bmtrain/optim/optim_manager.py b/examples/BMTrain/bmtrain/optim/optim_manager.py
new file mode 100644
index 00000000..1a98ed92
--- /dev/null
+++ b/examples/BMTrain/bmtrain/optim/optim_manager.py
@@ -0,0 +1,226 @@
+from typing import Optional, Union, List, Dict, Tuple
+import torch
+from ..loss._function import has_inf_nan
+from ..utils import print_rank
+from ..lr_scheduler.warmup import WarmupLRScheduler
+from .. import nccl
+from ..global_var import config
+
+def check_overflow(param_groups):
+ # check overflow
+ has_inf_or_nan = torch.zeros(1, dtype=torch.uint8, device="cuda")[0]
+ for group in param_groups:
+ for p in group['params']:
+ if p.grad is not None:
+ if p.dtype != torch.float:
+ has_inf_nan(p.grad, has_inf_or_nan)
+ if "comm" in config:
+ nccl.allReduce(has_inf_or_nan.storage(), has_inf_or_nan.storage(), "max", config["comm"])
+
+ if has_inf_or_nan > 0:
+ raise OverflowError("Gradient overflow")
+
+def grad_rescale(param_groups, scale):
+ for group in param_groups:
+ for p in group['params']:
+ if p.grad is not None and p.requires_grad:
+ p.grad /= scale
+
+class OptimManager:
+ """wait cuda stream. Optional: add loss scaler for mix-precision training
+
+ Args:
+        loss_scale (float): The initial loss scale. Defaults to None, which disables loss scaling.
+        loss_scale_factor (float): Factor by which the loss scale is decreased on overflow and increased after enough clean steps.
+        loss_scale_steps (int): Number of overflow-free steps before the loss scale is increased.
+
+ Examples:
+ >>> optim_manager = bmt.optim.OptimManager(loss_scale=1024)
+ >>> optim_manager.add_optimizer(optimizer1)
+ >>> optim_manager.add_optimizer(optimizer2, lr_scheduler2)
+ >>> for data in dataset:
+ >>> # forward pass and calculate loss
+ >>> optim_manager.zero_grad()
+ >>> optim_manager.backward(loss)
+ >>> optim_manager.clip_grad_norm(optimizer1.param_groups, max_norm=1.0, norm_type=2)
+ >>> optim_manager.clip_grad_norm(optimizer2.param_groups, max_norm=2.0, norm_type=2)
+ >>> optim_manager.step()
+ """
+ def __init__(self,
+ loss_scale : Optional[float] = None,
+ loss_scale_factor : float = 2,
+ loss_scale_steps : int = 1024,
+ min_loss_scale = 1,
+ max_loss_scale = float("inf"),
+ grad_scale : Optional[int] = None,
+ ):
+ if loss_scale is not None:
+ self.loss_scale = loss_scale
+ self.loss_scale_enabled = True
+ else:
+ self.loss_scale = 1
+ self.loss_scale_enabled = False
+ self.steps_since_last_scale = 0
+ self.loss_scale_factor = loss_scale_factor if loss_scale_factor > 1 else 1 / loss_scale_factor
+ self.loss_scale_steps = loss_scale_steps
+ self.min_loss_scale = min_loss_scale
+ self.max_loss_scale = max_loss_scale
+ if grad_scale is None:
+ grad_scale = config['zero_size']
+ self.grad_scale = grad_scale
+
+ self.optimizers = []
+ self.lr_schedulers = []
+
+ def add_optimizer(
+ self,
+ optimizer: torch.optim.Optimizer,
+ lr_scheduler: Optional[WarmupLRScheduler] = None,
+ ):
+ """Add optimizer and (optional) its corresponding lr_scheduler into optim_manager.
+ All optimizers in the same optim_manager share the same loss scale.
+
+ Args:
+            optimizer (torch.optim.Optimizer): A PyTorch optimizer, e.g. torch.optim.Adam, torch.optim.SGD or bmtrain.optim.AdamOffloadOptimizer
+ lr_scheduler (Optional[WarmupLRScheduler]): A warmup lr scheduler, e.g. bmt.lr_scheduler.Noam
+ """
+ self.optimizers.append(optimizer)
+ self.lr_schedulers.append(lr_scheduler)
+
+ def scale_loss(self, loss : torch.Tensor) -> torch.Tensor:
+
+ return loss * ( self.loss_scale / self.grad_scale ) # loss scale
+
+ def backward(self, loss : torch.Tensor):
+ """
+ Backward with loss scale.
+
+ Args:
+ loss (torch.Tensor): loss
+ """
+ loss = self.scale_loss(loss)
+ loss.backward()
+ # some reduce ops of distributed parameter were launched on load stream
+ current_stream = torch.cuda.current_stream()
+ current_stream.wait_stream(config['load_stream'])
+
+ def zero_grad(self):
+ """
+ This is a helper function to call optimizer.zero_grad()
+ """
+ for optimizer in self.optimizers:
+ optimizer.zero_grad(set_to_none=False)
+
+ def step(self):
+ """
+        Perform the optimizer steps and synchronize streams.
+
+ This is a helper function to call optimizer.step() and lr_scheduler.step() and synchronize streams.
+
+ This function can also handle gradient overflow by reducing the loss scale when it occurs.
+ """
+ if self.loss_scale_enabled:
+ has_overflow = False
+ for optimizer in self.optimizers:
+ try:
+ check_overflow(optimizer.param_groups)
+ except OverflowError:
+ has_overflow = True
+ break
+ if has_overflow:
+ print_rank("Gradient overflow, change scale from %lf to %lf" % (self.loss_scale, self.loss_scale / self.loss_scale_factor))
+ with torch.no_grad():
+ if self.loss_scale > self.min_loss_scale:
+ self._justify_scale(self.loss_scale / self.loss_scale_factor)
+ self.zero_grad()
+ return
+ for optimizer, lr_scheduler in zip(self.optimizers, self.lr_schedulers):
+ if hasattr(optimizer, "_bmtrain_optimizer") and optimizer._bmtrain_optimizer:
+ optimizer.step(scale=self.loss_scale)
+ else:
+ if self.loss_scale_enabled:
+ grad_rescale(optimizer.param_groups, self.loss_scale)
+ optimizer.step()
+
+ if lr_scheduler is not None:
+ lr_scheduler.step()
+
+ if self.loss_scale_enabled:
+ self.steps_since_last_scale += 1
+
+ if self.steps_since_last_scale >= self.loss_scale_steps and self.loss_scale < self.max_loss_scale:
+ self._justify_scale(self.loss_scale * self.loss_scale_factor)
+
+ current_stream = torch.cuda.current_stream()
+ config['load_stream'].wait_stream(current_stream)
+
+ def clip_grad_norm(self, param_groups, max_norm, norm_type=2, eps=1e-6):
+ """Clips gradient norm of an iterable of parameters.
+
+ The norm is computed over all gradients together, as if they were concatenated into a single vector. Gradients are modified in-place.
+
+ Args:
+            param_groups (List[dict]): parameter groups (as in optimizer.param_groups) whose gradients will be normalized.
+ max_norm (float or int): max norm of the gradients.
+ norm_type (float or int): type of the used p-norm. Can be 'inf' for infinity norm.
+ eps (float): epsilon used to avoid zero division.
+
+ Returns:
+ Total norm of the parameters (viewed as a single vector).
+ """
+ scale = self.loss_scale
+ grads = []
+ parameters = [p for group in param_groups for p in group['params']]
+ for p in parameters:
+ if p.grad is not None:
+ grads.append(p.grad.data)
+ else:
+ grads.append(torch.zeros_like(p.data))
+
+ if norm_type == 'inf':
+ total_norm_cuda = max(g.data.abs().max() for g in grads).detach()
+ nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "max", config["comm"])
+ total_norm = total_norm_cuda
+ else:
+ norm_type = float(norm_type)
+ total_norm_cuda = torch.cuda.FloatTensor([0])
+ for index, g in enumerate(grads):
+ param_norm = g.data.float().norm(norm_type)
+ total_norm_cuda += param_norm ** norm_type
+ nccl.allReduce(total_norm_cuda.storage(), total_norm_cuda.storage(), "sum", config["comm"])
+ total_norm = total_norm_cuda[0] ** (1. / norm_type)
+ # total_norm = total_norm / scale
+ # clip_coef = float(max_norm) / (total_norm + eps)
+ clip_coef = float(max_norm * scale) / (total_norm + eps)
+ if clip_coef < 1:
+ for p in parameters:
+ if p.grad is not None:
+ p.grad.data.mul_(clip_coef)
+ return total_norm / scale
+
+ @torch.no_grad()
+ def _justify_scale(self, scale):
+ for optimizer in self.optimizers:
+ if hasattr(optimizer, "_on_justify_scale"):
+ optimizer._on_justify_scale(self.loss_scale, scale)
+ self.loss_scale = scale
+ self.steps_since_last_scale = 0
+
+ def state_dict(self, gather_opt=False) -> dict:
+ return {
+ "optimizers": [opt.state_dict(gather_opt) for opt in self.optimizers],
+ "lr_schedulers": [lrs.state_dict() if lrs else None for lrs in self.lr_schedulers],
+ "loss_scale": self.loss_scale,
+ "loss_scale_enabled": self.loss_scale_enabled,
+ }
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ assert len(self.optimizers) == len(state_dict["optimizers"])
+ assert len(self.lr_schedulers) == len(state_dict["lr_schedulers"])
+ for opt, opt_st in zip(self.optimizers, state_dict["optimizers"]):
+ opt.load_state_dict(opt_st)
+ for lrs, lrs_st in zip(self.lr_schedulers, state_dict["lr_schedulers"]):
+ lrs.load_state_dict(lrs_st)
+ self.loss_scale = state_dict["loss_scale"]
+ self.loss_scale_enabled = state_dict["loss_scale_enabled"]
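A standalone sketch (no BMTrain required) of the dynamic loss-scaling policy that `step()` implements: the scale shrinks by `loss_scale_factor` on overflow and grows by the same factor after `loss_scale_steps` consecutive clean steps.

scale, factor, steps_required = 1024.0, 2.0, 4
clean_steps = 0
for overflowed in [False, False, True, False, False, False, False, True]:
    if overflowed:                  # gradient overflow detected by check_overflow()
        scale /= factor             # the real step() skips shrinking once min_loss_scale is reached
        clean_steps = 0
    else:
        clean_steps += 1
        if clean_steps >= steps_required:
            scale *= factor         # the real step() also caps this at max_loss_scale
            clean_steps = 0
print(scale)                        # 512.0 for this sequence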
diff --git a/examples/BMTrain/bmtrain/param_init.py b/examples/BMTrain/bmtrain/param_init.py
new file mode 100644
index 00000000..21f95f25
--- /dev/null
+++ b/examples/BMTrain/bmtrain/param_init.py
@@ -0,0 +1,105 @@
+from typing import Dict, Generator, Iterable, List, Tuple
+import torch
+from .block_layer import Block
+from .parameter import DistributedParameter
+from .global_var import config
+
+
+def init_distributed_parameter(params: Iterable[torch.nn.Parameter]):
+ """Init param of params which is instance of DistributedParameter using param._init_method.
+
+ Args:
+ params (Iterable[torch.nn.Parameter]): parameter tensors.
+
+ """
+ for param in params:
+ if not isinstance(param, DistributedParameter):
+ continue
+ if param._init_method is None:
+ continue
+ with torch.no_grad():
+ partition_size = param.storage().size()
+ global_size = partition_size * config["tp_zero_size"] * config["tp_size"]
+ tmp_storage = param.storage_type()(global_size)
+ tmp_tensor = torch.tensor([], dtype=param.dtype, device="cuda")
+ tmp_tensor.set_(tmp_storage, 0, param._tp_original_shape)
+
+ param._init_method(tmp_tensor)
+ if param._tp_mode and param._tp_split_dim >= 0:
+ tensor_list = tmp_tensor.chunk(
+ config["tp_size"], dim=param._tp_split_dim
+ )
+ sub_tensor = tensor_list[config["topology"].tp_id].contiguous()
+ tmp_tensor = torch.empty(
+ sub_tensor.shape, device=param.device, dtype=sub_tensor.dtype
+ )
+ tmp_tensor.copy_(sub_tensor)
+
+ if param._tp_mode:
+ begin = config["tp_zero_rank"]
+ else:
+ begin = config["zero_rank"]
+ end = begin + 1
+
+ # Pytorch 1.11 changed the API of storage.__getitem__
+ torch.tensor([], dtype=param.dtype, device=param.device).set_(
+ param.storage()
+ )[:] = torch.tensor([], dtype=param.dtype, device=param.device).set_(
+ tmp_tensor.storage()
+ )[
+ partition_size * begin : partition_size * end
+ ]
+ # param.storage().copy_(tmp_storage[partition_size * config['rank'] : partition_size * (config['rank'] + 1)])
+
+
+def iterate_parameters(model: torch.nn.Module):
+ """
+    Iterate over the module's own parameters, yielding nothing if they are already managed by a Block.
+ """
+ for kw, val in model._parameters.items():
+ if hasattr(val, "_in_block") and val._in_block:
+ return []
+ yield val
+
+
+def init_parameters(model: torch.nn.Module):
+ """
+ Initialize the parameters of the model by calling the init_method of the distributed parameters.
+ """
+
+ modules = model.named_modules()
+ for module_prefix, module in modules:
+ if isinstance(module, Block):
+ module.init_parameters()
+ else:
+ init_distributed_parameter(iterate_parameters(module))
+
+ current_stream = torch.cuda.current_stream()
+ config["load_stream"].wait_stream(current_stream)
+
+
+def grouped_parameters(
+ model: torch.nn.Module,
+) -> Generator[Tuple[str, List[torch.nn.Parameter]], None, None]:
+ """
+ Iterate over the parameters of the model grouped by the group name.
+    This is similar to `torch.nn.Module.named_parameters()`.
+ """
+
+    ret: Dict[str, List[torch.nn.Parameter]] = {}
+ for module in model.modules():
+ if isinstance(module, Block):
+ for kw, params in module.grouped_parameters():
+ if kw not in ret:
+ ret[kw] = []
+ ret[kw].extend(params)
+ else:
+ for param in module._parameters.values():
+ group = None
+ if isinstance(param, DistributedParameter):
+ group = param.group
+ if group not in ret:
+ ret[group] = []
+ ret[group].append(param)
+ for kw, val in ret.items():
+ yield kw, val
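A hypothetical sketch of turning these groups into optimizer parameter groups; it assumes `grouped_parameters` is re-exported at package level and that some DistributedParameters were created with group="no_decay".

import bmtrain as bmt

def build_param_groups(model, weight_decay=0.01):
    groups = []
    for group_name, params in bmt.grouped_parameters(model):
        # parameters created with group="no_decay" are exempted from weight decay
        wd = 0.0 if group_name == "no_decay" else weight_decay
        groups.append({"params": params, "weight_decay": wd})
    return groups

# optimizer = bmt.optim.AdamOffloadOptimizer(build_param_groups(model), lr=1e-3)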
diff --git a/examples/BMTrain/bmtrain/parameter.py b/examples/BMTrain/bmtrain/parameter.py
new file mode 100644
index 00000000..2dad4a3d
--- /dev/null
+++ b/examples/BMTrain/bmtrain/parameter.py
@@ -0,0 +1,206 @@
+from typing import Callable, Iterable, Optional
+import torch
+from .utils import round_up
+from .global_var import config
+from . import nccl
+from .distributed import all_gather
+
+
+class DistributedParameter(torch.nn.Parameter):
+ r"""
+ DistributedParameter is a subclass of torch.nn.Parameter.
+
+ It scatters the tensor to all the nodes and gathers them when needed.
+
+ Args:
+ data (Tensor): parameter tensor.
+ requires_grad (bool, optional): if the parameter requires gradient.
+ init_method (Callable[['DistributedParameter'], None], optional): the method to initialize the parameter.
+ group (str, optional): the group name of the parameter.
+
+    **Note**: DistributedParameter must be on the CUDA device. It will automatically transfer the data to the device when `__init__` is called.
+
+ """
+
+ _original_shape: torch.Size
+ _start_partition: int
+ _end_partition: int
+ _init_method: Optional[Callable[["DistributedParameter"], None]]
+ _in_block: bool
+ _group: Optional[str]
+
+ def __new__(
+ cls,
+ data: torch.Tensor,
+ requires_grad: bool = True,
+ init_method: Optional[Callable[["DistributedParameter"], None]] = None,
+ group: Optional[str] = None,
+ tp_mode: bool = False,
+ tp_split_dim: int = -1,
+ ):
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ num_of_elements = data.numel()
+
+ cuda_tensor = torch.tensor([], dtype=data.dtype, device="cuda")
+ if tp_mode:
+ comm = config["tp_zero_comm"]
+ else:
+ comm = config["zero_comm"]
+ world_size = nccl.commCount(comm)
+ rank = nccl.commRank(comm)
+ cuda_storage_size = round_up(num_of_elements, world_size) // world_size
+
+ original_shape = data.size()
+ tp_original_shape = original_shape
+ if tp_mode and tp_split_dim >= 0:
+ tp_original_shape = list(original_shape)
+ tp_original_shape[tp_split_dim] *= config["tp_size"]
+
+ cuda_storage = cuda_tensor.storage_type()(cuda_storage_size)
+
+ start_of_partition = cuda_storage_size * rank
+ end_of_partition = min(num_of_elements, cuda_storage_size * (rank + 1))
+
+        # FIX: cuda_tensor_size could be < 0 if num_of_elements is too small
+ cuda_tensor_size = max(end_of_partition - start_of_partition, 0)
+
+ cuda_tensor.set_(cuda_storage, 0, (cuda_tensor_size,))
+ cuda_tensor.copy_(data.view(-1)[start_of_partition:end_of_partition])
+ ret = torch.Tensor._make_subclass(cls, cuda_tensor, requires_grad)
+
+ setattr(ret, "_original_shape", original_shape)
+ setattr(ret, "_start_partition", start_of_partition)
+ setattr(ret, "_end_partition", end_of_partition)
+ setattr(ret, "_init_method", init_method)
+ setattr(ret, "_in_block", False)
+ setattr(ret, "_group", group if not tp_mode else "tp")
+
+ setattr(ret, "_tp_mode", tp_mode)
+ setattr(ret, "_zero_comm", comm)
+ setattr(ret, "_tp_split_dim", tp_split_dim)
+ setattr(ret, "_tp_original_shape", tp_original_shape)
+ return ret
+
+ @property
+ def group(self):
+ """The group name of the distributed parameter."""
+
+ return self._group
+
+ def gather(self) -> torch.Tensor:
+ """Gather the data from ZeRO distributed nodes.
+
+ Return:
+ torch.Tensor: The gathered data.
+
+ """
+ with torch.cuda.stream(config["load_stream"]):
+ output_tensor = OpAllGather.apply(self)
+ current_stream = torch.cuda.current_stream()
+ output_tensor.record_stream(current_stream)
+ current_stream.wait_stream(config["load_stream"])
+ return output_tensor
+
+    def gather_all(self) -> torch.Tensor:
+ """Gather the data from ZeRO and Tensor Parallel distributed nodes.
+
+ Return:
+ torch.Tensor: The gathered data.
+
+ """
+ zero_param = self.gather()
+ if config["tp_size"] > 1 and self._tp_split_dim >= 0:
+ output_tensor = all_gather(zero_param, config["tp_comm"])
+ if self._tp_split_dim == 1:
+ output_list = output_tensor.chunk(config["tp_size"], dim=0)
+ output = torch.cat(output_list, dim=output_list[0].dim() - 1).flatten(
+ 0, 1
+ )
+ return output
+ else:
+ return output_tensor.flatten(0, 1)
+ else:
+ return zero_param
+
+    def tp_gather(self) -> torch.Tensor:
+ """Gather the data from Tensor Parallel distributed nodes.
+
+ Return:
+ torch.Tensor: The gathered data.
+
+ """
+ if config["tp_size"] > 1 and self._tp_split_dim >= 0:
+ output_tensor = all_gather(self, config["tp_comm"])
+ if self._tp_split_dim == 1:
+ output_list = output_tensor.chunk(config["tp_size"], dim=0)
+ output = torch.cat(output_list, dim=output_list[0].dim() - 1).flatten(
+ 0, 1
+ )
+ return output
+ else:
+ return output_tensor.flatten(0, 1)
+ else:
+ return self
+
+ def _copy_data(self, data: torch.Tensor):
+ """Copy data to self.data."""
+ self.data.copy_(data.view(-1)[self._start_partition : self._end_partition])
+
+
+class OpAllGather(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, value: DistributedParameter):
+ assert isinstance(value, DistributedParameter)
+ comm = value._zero_comm # config['zero_comm']
+ world_size = nccl.commCount(comm)
+ ctx.comm = comm
+ ctx.world_size = world_size
+
+ partition_size = value.storage().size()
+ global_size = partition_size * world_size
+
+ storage = value.storage_type()(global_size)
+
+ nccl.allGather(value.storage(), storage, comm)
+
+ output_tensor = torch.tensor([], dtype=value.dtype, device="cuda")
+ output_tensor.set_(storage, 0, value._original_shape)
+
+ ctx.partition_size = partition_size
+ ctx.tensor_size = value.size(0)
+ return output_tensor
+
+ @staticmethod
+ def backward(ctx, grad_output: torch.Tensor):
+ if not grad_output.is_contiguous():
+ grad_output = grad_output.contiguous()
+
+ grad_storage = grad_output.storage_type()(ctx.partition_size)
+ grad_output_storage = grad_output.storage()
+ if grad_output_storage.size() == ctx.partition_size * ctx.world_size:
+ pass
+ else:
+ grad_output_storage.resize_(ctx.partition_size * ctx.world_size)
+ nccl.reduceScatter(grad_output_storage, grad_storage, "sum", ctx.comm)
+ grad_tensor = torch.tensor([], dtype=grad_output.dtype, device="cuda")
+ grad_tensor.set_(grad_storage, 0, (ctx.tensor_size,))
+ return grad_tensor
+
+
+class ParameterInitializer:
+ """
+ ParameterInitializer is a helper class that is used to initialize the distributed parameters.
+
+    Similar to functools.partial.
+
+ """
+
+ def __init__(self, func: Callable, *args, **kwargs) -> None:
+ self.func = func
+ self._args = args
+ self._kwargs = kwargs
+
+ def __call__(self, param: DistributedParameter):
+ self.func(param, *self._args, **self._kwargs)
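A hypothetical sketch of declaring a ZeRO-partitioned parameter and gathering it back; it assumes `bmt.init_distributed()` has already been called, that `DistributedParameter` and `ParameterInitializer` are re-exported at package level, and notes that the stored `_init_method` is only applied later by `init_parameters`.

import torch
import bmtrain as bmt

bmt.init_distributed()

weight = bmt.DistributedParameter(
    torch.empty(1024, 1024, dtype=torch.float16),
    init_method=bmt.ParameterInitializer(torch.nn.init.normal_, std=0.02),
)

# Each rank stores only its 1/world_size slice of the flattened tensor;
# gather() rebuilds the full (1024, 1024) view on every rank.
full = weight.gather()
assert full.shape == (1024, 1024)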
diff --git a/examples/BMTrain/bmtrain/pipe_layer.py b/examples/BMTrain/bmtrain/pipe_layer.py
new file mode 100644
index 00000000..4d3b17ad
--- /dev/null
+++ b/examples/BMTrain/bmtrain/pipe_layer.py
@@ -0,0 +1,314 @@
+from collections import OrderedDict
+import copy
+from typing import Dict, Iterable, Iterator, Tuple, Union, List
+
+import torch
+
+from .distributed import all_gather, broadcast, all_reduce, send_activations, recv_activations
+from .global_var import config
+from . import nccl
+from .zero_context import (
+ ZeroContext
+)
+from . import debug
+from .block_layer import Block, round_up, _get_param_kw, _block_wrapper
+
+class PipePreFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, hidden_state, *args):
+ hidden_state_list = all_gather(hidden_state.clone(), config["pipe_comm"])
+ hidden_state_list.requires_grad_()
+
+ batch_related = args[-1]
+ batch_related_origin = [True if i in args[-1] else False for i in range(len(args[:-1]))]
+ batch_related_rule = []
+ args = args[:-1]
+
+ batch_size = hidden_state.shape[0]
+ num_micros = config["micros"]
+ args_list = [[] for _ in range(num_micros)]
+ input_requires_grad = []
+ for arg in args:
+ if torch.is_tensor(arg):
+ arg_all = all_gather(arg, config['pipe_comm'])
+ if arg.dim() == hidden_state.dim() and arg.shape[0] == batch_size:
+ batch_related_rule.append(True)
+ arg_all = arg_all.flatten(0, 1).chunk(num_micros, dim=0)
+ arg_all = [tensor.requires_grad_(arg.requires_grad) for tensor in arg_all]
+ else:
+ batch_related_rule.append(False)
+ arg_all = [arg_all[0].requires_grad_(arg.requires_grad) for i in range(num_micros)]
+ input_requires_grad.append(arg.requires_grad)
+ else:
+ batch_related_rule.append(False)
+ arg_all = [arg for _ in range(num_micros)]
+ input_requires_grad.append(False)
+ for i in range(num_micros):
+ args_list[i].append(arg_all[i])
+ ctx.input_requires_grad = input_requires_grad
+ ctx.args_list = args_list
+ if len(batch_related) == 0:
+ ctx.batch_related = batch_related_rule
+ else:
+ ctx.batch_related = batch_related_origin
+ return hidden_state_list, args_list
+
+ @staticmethod
+ def backward(ctx, grads, arg_grads):
+ grads = broadcast(grads, 0, config['pipe_comm'])
+ topo = config['topology']
+ arg_grads = []
+ num_micros = config['micros']
+ for idx,requires_grad in enumerate(ctx.input_requires_grad):
+ if requires_grad:
+ grad = torch.cat([ctx.args_list[m][idx].grad for m in range(num_micros)], dim=0)
+ grad = all_reduce(grad, "sum", config["pipe_comm"])
+ split_size = topo.stages if ctx.batch_related[idx] else num_micros
+ grad = grad.chunk(split_size)
+ if ctx.batch_related[idx]:
+ arg_grads.append(grad[topo.stage_id])
+ else:
+ arg_grads.append(grad[0])
+ else:
+ arg_grads.append(None)
+        arg_grads.append(None)  # gradient slot for the appended batch_related argument
+ return grads.chunk(topo.stages, dim=0)[topo.stage_id], *arg_grads
+
+class PipePostFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, last_hidden, hidden_states=None, forward_stage_ranges=None, backward_stage_ranges=None, last_hidden_shape=None, return_hidden_states=False):
+ topo = config['topology']
+ ctx.return_hidden_states = return_hidden_states
+ last_hidden = broadcast(last_hidden, config["pipe_size"] - 1, config["pipe_comm"])
+ last_hidden = last_hidden.chunk(topo.stages, dim=0)
+ output = last_hidden[topo.stage_id]
+ output.requires_grad_()
+
+ if return_hidden_states:
+ ctx.stage_id = topo.stage_id
+ ctx.stages = topo.stages
+ ctx.backward_stage_ranges = backward_stage_ranges
+ middle_hiddens = []
+ for stage_id in range(ctx.stages):
+ if ctx.stage_id == stage_id:
+ middle_hidden = hidden_states
+ else:
+ middle_shape = (forward_stage_ranges[stage_id],) + last_hidden_shape
+ middle_hidden = torch.zeros(middle_shape, device=hidden_states.device, dtype=hidden_states.dtype)
+ middle_hidden = broadcast(middle_hidden, stage_id, config["pipe_comm"])
+ middle_hidden = middle_hidden.chunk(ctx.stages, dim=1)
+ middle_hidden = middle_hidden[ctx.stage_id].clone()
+ middle_hiddens.append(middle_hidden)
+ middle_hiddens = torch.cat(middle_hiddens, dim=0)
+ middle_hiddens.requires_grad_()
+ return output, middle_hiddens
+ else:
+ return output
+
+ @staticmethod
+ def backward(ctx, grads, grad_middle=None):
+ grad_list = all_gather(grads, config["pipe_comm"])
+ grad_list = grad_list.flatten(start_dim=0, end_dim=1)
+
+ if ctx.return_hidden_states:
+ for stage_id in range(ctx.stages):
+ layer_range = ctx.backward_stage_ranges[stage_id]
+ grad_middle_state = grad_middle[layer_range]
+ grad_middle_state = all_gather(grad_middle_state.transpose(0,1), config["pipe_comm"])
+ grad_middle_state = grad_middle_state.flatten(start_dim=0, end_dim=1).transpose(0, 1)
+ if ctx.stage_id == stage_id:
+ grad_hidden_state_list = grad_middle_state
+ return grad_list, grad_hidden_state_list, None, None, None, None
+ else:
+ return grad_list
+
+class StagePreFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, input, stage_id):
+ ctx.stage_id = stage_id
+ ctx.is_first_stage = stage_id == 0
+ ctx.is_last_stage = stage_id == config['pipe_size'] - 1
+ if not ctx.is_first_stage:
+ input = recv_activations(stage_id - 1, config['pipe_comm'])
+ input.requires_grad_()
+ return input
+ return input
+
+ @staticmethod
+ def backward(ctx, grad_outputs):
+ if not ctx.is_first_stage:
+ send_data = grad_outputs[0] if isinstance(grad_outputs, tuple) else grad_outputs
+ current_stream = torch.cuda.current_stream()
+ with torch.cuda.stream(config['pp_comm_stream']):
+ config['pp_comm_stream'].wait_stream(current_stream)
+ send_data.record_stream(config['pp_comm_stream'])
+ send_activations(send_data, ctx.stage_id - 1, config['pipe_comm'])
+ return grad_outputs, None
+
+class StagePostFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, outputs, stage_id):
+ ctx.stage_id = stage_id
+ ctx.is_first_stage = stage_id == 0
+ ctx.is_last_stage = stage_id == config['pipe_size'] - 1
+ if not ctx.is_last_stage:
+ send_data = outputs[0] if isinstance(outputs, tuple) else outputs
+ current_stream = torch.cuda.current_stream()
+ with torch.cuda.stream(config['pp_comm_stream']):
+ config['pp_comm_stream'].wait_stream(current_stream)
+ send_data.record_stream(config['pp_comm_stream'])
+ send_activations(send_data.detach(), stage_id + 1, config['pipe_comm'])
+ return outputs
+
+ @staticmethod
+ def backward(ctx, grad_outputs):
+ if not ctx.is_last_stage:
+ pre_grad_inputs = recv_activations(ctx.stage_id + 1, config['pipe_comm'])
+ return pre_grad_inputs, None
+ return grad_outputs, None
+
+
+class PipelineTransformerBlockList(torch.nn.Module):
+ r"""
+    PipelineTransformerBlockList is a list of Blocks executed with pipeline parallelism.
+
+    It is designed to reduce communication overhead by overlapping computation with the reduce_scatter operation during the backward pass.
+
+    It is similar to `torch.nn.ModuleList`, but differs in how .forward() and .backward() are executed.
+
+ Example:
+ >>> module_list = [ ... ]
+ >>> normal_module_list = torch.nn.ModuleList(module_list)
+ >>> transformer_module_list = PipelineTransformerBlockList(module_list)
+ >>> # Calling normal module list
+ >>> for layer in normal_module_list:
+ >>> hidden_state = layer.forward(hidden_state, ...)
+ >>> # Calling transformer module list
+ >>> hidden_state = transformer_module_list(hidden_state, ...)
+
+ """
+ _modules: Dict[str, Block]
+
+ def __init__(self, modules: Iterable[torch.nn.Module], num_hidden=1) -> None:
+ super().__init__()
+ self.num_hidden = num_hidden
+ self._modules = {}
+ self.layer_ids = []
+ topo = config["topology"]
+ self.stages = topo.stages
+ self.stage_id = topo.stage_id
+ self.pipe_idx = topo.pipe_idx
+ module_dict = {}
+ for idx, module in enumerate(modules):
+ module = _block_wrapper(module, module_dict, "PIPE")
+            module._zero_level = 2  # currently, only ZeRO-2 is supported in pipeline mode
+ self._modules[str(idx)] = module
+
+ self.layer_ids = self.get_range_by_stage_id(self.stage_id)
+
+ pre_module = None
+ for i,layer_id in enumerate(self.layer_ids):
+ module = self._modules[str(layer_id)]
+ module.set_pre_module(pre_module)
+ pre_module = module
+ module._is_first_layer = False
+ module._is_last_layer = False
+
+ self._modules[str(self.layer_ids[0])]._is_first_layer = True
+ self._modules[str(self.layer_ids[-1])]._is_last_layer = True
+
+ def __len__(self) -> int:
+ return len(self._modules)
+
+ def __iter__(self) -> Iterator[Block]:
+ return iter(self._modules.values())
+
+ def __getitem__(self, index: Union[int, str]) -> Block:
+ return self._modules[str(index)]
+
+ def forward(self, hidden_state, *args, batch_related=[], return_hidden_states=False):
+ self.return_hidden_states = return_hidden_states
+ batch_size = hidden_state.shape[0]
+ num_micros = config["micros"]
+ args = args + (batch_related, )
+ hidden_state.requires_grad_()
+ hidden_state_list, args_list = PipePreFunction.apply(hidden_state, *args)
+
+ hidden_state_list = hidden_state_list.flatten(0, 1).chunk(num_micros, dim=0)
+ outputs = []
+ hidden_states = []
+
+ for micro_idx, (hidden_state, arg) in enumerate(zip(hidden_state_list, args_list)):
+ micro_hidden_states = []
+
+ hidden_state = StagePreFunction.apply(hidden_state, self.stage_id)
+
+ for idx,layer_id in enumerate(self.layer_ids):
+ self._modules[str(layer_id)]._micro_idx = micro_idx
+ if return_hidden_states:
+ micro_hidden_states.append(hidden_state)
+ hidden_state = self._modules[str(layer_id)](hidden_state, *arg)
+ hidden_state = StagePostFunction.apply(hidden_state, self.stage_id)
+
+ outputs.append(hidden_state)
+ if return_hidden_states:
+ hidden_states.append(torch.stack(micro_hidden_states, dim=0))
+
+ last_hidden = torch.cat(outputs, dim=0)
+ last_hidden_shape = last_hidden.shape
+
+ if return_hidden_states:
+ hidden_states = torch.cat(hidden_states, dim=1)
+ forward_stage_ranges = []
+ backward_stage_ranges = []
+ for stage_id in range(self.stages):
+ forward_stage_ranges.append(self.get_part_len_by_stage_id(stage_id))
+ backward_stage_ranges.append(self.get_range_by_stage_id(stage_id))
+ outputs, hidden_states = PipePostFunction.apply(last_hidden, hidden_states, forward_stage_ranges, backward_stage_ranges, last_hidden_shape, return_hidden_states)
+ return outputs, hidden_states
+ else:
+ outputs = PipePostFunction.apply(last_hidden)
+ return outputs
+
+ def get_range_by_stage_id(self, stage_id : int) -> List[int]:
+ part_lens = [0]+[self.get_part_len_by_stage_id(i) for i in range(stage_id+1)]
+ start = sum(part_lens[:stage_id+1])
+ end = start + part_lens[stage_id+1]
+ return range(start, end)
+
+ def get_part_len_by_stage_id(self, stage_id : int) -> int:
+ return len(self) // self.stages + (stage_id < (len(self) % self.stages))
+
+ def get_stage_by_layer_id(self, layer_id : int) -> int:
+ part_len = len(self) // self.stages
+ rest = len(self) % self.stages
+ if layer_id // (part_len + 1) < rest:
+ return layer_id // (part_len + 1)
+ else:
+ return rest + (layer_id - rest * (part_len+1)) // part_len
+
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
+ for name, module in self._modules.items():
+ idx = int(name)
+ name = prefix + name + '.'
+
+ dst = OrderedDict() # creates a temporary ordered dict
+ dst._metadata = OrderedDict()
+
+ if idx in self.layer_ids:
+ with torch.no_grad():
+ with ZeroContext(module, pipe=True):
+ module._module.state_dict(destination=dst, prefix=name, keep_vars=False)
+
+ if config["topology"].pp_zero_id == 0:
+ if config["rank"] == 0:
+ destination.update(dst)
+ else:
+ assert list(dst.keys()) == [name+n for n, parameter in module._module.named_parameters()]
+ for key, tensor in dst.items():
+ send_activations(tensor.cuda(), 0, config['pipe_comm'])
+ if config['rank'] == 0 and idx not in self.layer_ids:
+ for n, parameter in module._module.named_parameters():
+ destination[name+n] = recv_activations(self.get_stage_by_layer_id(idx), config['pipe_comm']).cpu()
+
diff --git a/examples/BMTrain/bmtrain/store.py b/examples/BMTrain/bmtrain/store.py
new file mode 100644
index 00000000..2a3ee02c
--- /dev/null
+++ b/examples/BMTrain/bmtrain/store.py
@@ -0,0 +1,325 @@
+from collections import OrderedDict
+from typing import Dict
+import torch
+
+from .pipe_layer import PipelineTransformerBlockList
+from .block_layer import TransformerBlockList
+from .global_var import config
+from .block_layer import Block
+from . import nccl
+import io, pickle
+from typing import Mapping
+import threading
+import bmtrain as bmt
+
+def _save_to_state_dict(model : torch.nn.Module, rank, destination, prefix):
+ if isinstance(model, Block):
+ if rank != 0:
+ destination = OrderedDict() # creates a temporary ordered dict
+ destination._metadata = OrderedDict()
+ model.state_dict(destination=destination, prefix=prefix, keep_vars=False)
+ else:
+ if rank != 0:
+ destination = OrderedDict() # creates a temporary ordered dict
+ destination._metadata = OrderedDict()
+ model._save_to_state_dict(destination, prefix, False)
+
+def _save_to_local_rank0(model : torch.nn.Module, destination=None, prefix=''):
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+ _save_to_state_dict(model, config['local_rank'], destination, prefix)
+ for name, module in model._modules.items():
+ if module is not None:
+ _save_to_local_rank0(module, destination, prefix + name + '.')
+ for hook in model._state_dict_hooks.values():
+ hook_result = hook(model, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ return destination
+
+
+def _save_to_rank0(model : torch.nn.Module, destination=None, prefix=''):
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+ if not isinstance(model, PipelineTransformerBlockList):
+ _save_to_state_dict(model, config['rank'], destination, prefix)
+ for name, module in model._modules.items():
+ if module is not None:
+ _save_to_rank0(module, destination, prefix + name + '.')
+ for hook in model._state_dict_hooks.values():
+ hook_result = hook(model, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ else:
+ model._save_to_state_dict(destination, prefix, False)
+ return destination
+
+def _save_to_infer_model(model : torch.nn.Module, infer_model, destination=None, prefix=''):
+ config['save_param_to_cpu'] = False
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(version=model._version)
+ _save_to_state_dict(model, config['local_rank'], destination, prefix)
+ for name, module in model._modules.items():
+ if module is not None:
+ if isinstance(module, TransformerBlockList):
+ for local_name, local_module in module._modules.items():
+ local_state_dict = _save_to_local_rank0(local_module, None, prefix + name + "." + local_name + '.')
+ if config['local_rank'] == 0:
+ infer_model.load_layer_state_dict(local_state_dict)
+ else:
+ _save_to_infer_model(module, infer_model, destination, prefix + name + '.')
+ for hook in model._state_dict_hooks.values():
+ hook_result = hook(model, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+
+ if config['local_rank'] == 0:
+ infer_model.load_layer_state_dict(destination)
+
+
+def async_save_to_file(state_dict, file_path):
+ torch.save(state_dict, file_path)
+ config['finish_save'] = True
+ print("finish save state_dict to ", file_path)
+
+def save(model : torch.nn.Module, file_name : str, non_blocking : bool=False):
+ """Saves the model to the file.
+
+ Similar to torch.save, but it is intended for distributed modules.
+
+ Args:
+ model (torch.nn.Module): The model to be saved.
+ file_name (str): The file name of the checkpoint.
+ non_blocking (bool): Whether to asynchronously save state_dict to file
+
+
+ Examples:
+ >>> bmtrain.save(model, "model.pt")
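+ >>> bmtrain.save(model, "model.pt", non_blocking=True)  # saves in a background thread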
+ """
+ torch.cuda.synchronize()
+ state_dict = _save_to_rank0(model)
+ if config["rank"] == 0:
+ if non_blocking is False:
+ torch.save(state_dict, file_name)
+ else:
+ if 'finish_save' not in config:
+ config['finish_save'] = True
+
+ if config['finish_save'] is False:
+ config['save_thread'].join()
+
+ config['finish_save'] = False
+ config['save_thread'] = threading.Thread(target=async_save_to_file, args=(state_dict, file_name))
+ config['save_thread'].start()
+ bmt.synchronize()
+
+DTYPE_LIST = [
+ torch.float64,
+ torch.float32,
+ torch.float16,
+ torch.int64,
+ torch.int32,
+ torch.int16,
+ torch.int8,
+ torch.bfloat16,
+ torch.bool
+]
+
+_pickler = pickle.Pickler
+_unpickler = pickle.Unpickler
+
+def allgather_objects(obj):
+ if bmt.world_size() == 1:
+ return [obj]
+
+ with torch.no_grad():
+ data_bytes: bytes = pickle.dumps(obj)
+ data_length: int = len(data_bytes)
+
+ gpu_data_length = torch.tensor([data_length], device="cuda", dtype=torch.long)
+ gathered_length = bmt.distributed.all_gather(gpu_data_length).view(-1).cpu()
+ max_data_length = gathered_length.max().item()
+
+ gpu_data_bytes = torch.zeros(max_data_length, dtype=torch.uint8, device="cuda")
+ byte_storage = torch.ByteStorage.from_buffer(data_bytes)
+ gpu_data_bytes[:data_length] = torch.ByteTensor(byte_storage)
+
+ gathered_data = bmt.distributed.all_gather(gpu_data_bytes).cpu()
+
+ ret = []
+ for i in range(gathered_data.size(0)):
+ data_bytes = gathered_data[i, : gathered_length[i].item()].numpy().tobytes()
+ ret.append(pickle.loads(data_bytes))
+ return ret
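+
+# Example (illustrative): every rank contributes an arbitrary picklable object and
+# receives the full list, e.g.
+#   stats = allgather_objects({"rank": bmt.rank(), "loss": 0.1})
+#   # stats is a list with one entry per rank, identical on all ranks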
+
+def broadcast_object(obj, comm, src = 0):
+ if nccl.commRank(comm) == src:
+ f = io.BytesIO()
+ _pickler(f).dump(obj)
+ byte_storage = torch.ByteStorage.from_buffer(f.getvalue())
+ # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and a specified dtype.
+ # Otherwise, it will cause a 100X slowdown.
+ # See: https://github.com/pytorch/pytorch/issues/65696
+ byte_tensor = torch.ByteTensor(byte_storage).cuda()
+ local_size = torch.LongTensor([byte_tensor.numel()]).cuda()
+
+ nccl.broadcast(
+ local_size.storage(),
+ local_size.storage(),
+ src,
+ comm
+ )
+ nccl.broadcast(
+ byte_tensor.storage(),
+ byte_tensor.storage(),
+ src,
+ comm
+ )
+ else:
+ local_size = torch.LongTensor([0]).cuda()
+ nccl.broadcast(
+ local_size.storage(),
+ local_size.storage(),
+ src,
+ comm
+ )
+ byte_tensor_size = local_size[0].item()
+ byte_tensor = torch.empty(int(byte_tensor_size), dtype=torch.uint8, device="cuda")
+ nccl.broadcast(
+ byte_tensor.storage(),
+ byte_tensor.storage(),
+ src,
+ comm
+ )
+ buf = byte_tensor.cpu().numpy().tobytes()
+ obj = _unpickler(io.BytesIO(buf)).load()
+ return obj
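+
+# Example (illustrative): broadcast a small Python object from rank 0 to all ranks
+# over the global communicator, e.g.
+#   vocab = broadcast_object(vocab if config["rank"] == 0 else None, config["comm"])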
+
+# Must be a Mapping after pytorch 1.12.0
+class DistributedTensorWrapper:
+ def __init__(self, tensor, shape=None):
+ self._dtype = tensor.dtype
+ self._device = tensor.device
+ self.shape = shape
+ self.tensor = tensor
+
+ def broadcast(self):
+ output_param = torch.empty(self.shape, dtype=self._dtype, device="cuda")
+ if config['rank'] == 0:
+ input_param = self.tensor
+ if input_param.is_cuda:
+ input_param = input_param.clone().contiguous()
+ else:
+ input_param = input_param.cuda().contiguous()
+
+ nccl.broadcast(
+ input_param.storage(),
+ output_param.storage(),
+ 0,
+ config['comm']
+ )
+ else:
+ nccl.broadcast(
+ output_param.storage(),
+ output_param.storage(),
+ 0,
+ config['comm']
+ )
+ return output_param
+
+ def copy(self):
+ return self.tensor
+
+ def __getattribute__(self, name):
+ if name == "tensor" or name == "shape":
+ return object.__getattribute__(self, name)
+ else:
+ try:
+ return object.__getattribute__(self, name)
+ except AttributeError:
+ pass
+
+ return getattr(self.tensor, name)
+
+class DistributedStateDictWrapper(Mapping):
+ def __init__(self, state_dict : Dict) -> None:
+ self._state_dict = state_dict
+ self._metadata = broadcast_object(getattr(state_dict, "_metadata", None), config["comm"])
+
+ def __getitem__(self, key : str):
+ tmp_shape = torch.zeros(32, device="cuda", dtype=torch.int32)
+ if config['rank'] == 0:
+ input_param : torch.Tensor = self._state_dict[key]
+ shape_list = torch.tensor(list(input_param.size()), device="cuda", dtype=torch.int32)
+ dtype_idx = DTYPE_LIST.index(input_param.dtype)
+
+ assert dtype_idx != -1, "Unknown data type %s" % input_param.dtype
+
+ tmp_shape[0] = shape_list.size(0)
+ tmp_shape[1] = dtype_idx
+ tmp_shape[2:2 + shape_list.size(0)] = shape_list
+
+ nccl.broadcast(
+ tmp_shape.storage(),
+ tmp_shape.storage(),
+ 0,
+ config['comm']
+ )
+
+ shape_list_size = tmp_shape[0].item()
+ dtype_idx = tmp_shape[1].item()
+ shape_list = torch.Size(tmp_shape[2: 2 + shape_list_size].tolist())
+
+ if config['rank'] != 0:
+ return DistributedTensorWrapper(torch.tensor([], dtype=DTYPE_LIST[dtype_idx], device="cuda"), shape=shape_list)
+ else:
+ return DistributedTensorWrapper(self._state_dict[key], shape=shape_list)
+
+
+
+ def copy(self):
+ return self
+
+ def __len__(self):
+ return broadcast_object(len(self._state_dict), config["comm"])
+
+ def __contains__(self, key : str):
+ return broadcast_object(key in self._state_dict, config["comm"])
+
+ def keys(self):
+ return broadcast_object(list(self._state_dict.keys()),config["comm"])
+
+ def __iter__(self):
+ # pytorch 1.12.0 updated the load_state_dict method, which needs the state_dict to be a `Mapping`.
+ return iter(self.keys())
+
+def load(model : torch.nn.Module, file_name : str, strict : bool = True):
+ """Loads the model from the file.
+
+ Similar to torch.load, but it uses less memory when loading large models.
+
+ Args:
+ model (torch.nn.Module): The model to be loaded.
+ file_name (str): The file name of the checkpoint.
+ strict (bool): Strict option of `load_state_dict`.
+
+ Example:
+ >>> bmtrain.load(model, "model.pt", strict=True)
+ """
+ if config['rank'] == 0:
+ state_dict = DistributedStateDictWrapper(torch.load(file_name))
+ else:
+ state_dict = DistributedStateDictWrapper({})
+
+ ret = model.load_state_dict(
+ state_dict,
+ strict = strict
+ )
+ torch.cuda.synchronize()
+ return ret
diff --git a/examples/BMTrain/bmtrain/synchronize.py b/examples/BMTrain/bmtrain/synchronize.py
new file mode 100644
index 00000000..87619159
--- /dev/null
+++ b/examples/BMTrain/bmtrain/synchronize.py
@@ -0,0 +1,73 @@
+import torch
+from . import distributed, nccl
+from .global_var import config
+import warnings
+from typing import Optional
+
+
+def synchronize():
+ """
+ Synchronize all the workers across all nodes. (both CPU and GPU are synchronized)
+ """
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ with torch.cuda.stream(config["barrier_stream"]):
+ barrier = torch.cuda.FloatTensor([1])
+ nccl.allReduce(barrier.storage(), barrier.storage(), "sum", config["comm"])
+ config["barrier_stream"].synchronize()
+
+
+def wait_loader():
+ """
+ The calc stream (normally the current stream) waits for the latest loader event, then records a new one.
+ """
+ if not config["initialized"]:
+ raise RuntimeError("BMTrain is not initialized")
+
+ config["load_event"].synchronize()
+ config["calc_stream"].record_event(config["load_event"])
+
+
+def sum_loss(loss: torch.Tensor, comm: Optional[nccl.NCCLCommunicator] = None):
+ """
+ Sum the loss across all workers.
+
+ This is a helper function to reduce the loss across all workers.
+ """
+ if comm is None:
+ comm = config["comm"]
+ warnings.warn(
+ "bmtrain.sum_loss is deprecated and will be removed in later version. Use bmtrain.distributed.all_reduce instead.",
+ DeprecationWarning,
+ )
+
+ return distributed.all_reduce(loss, "avg", comm)
+
+
+def gather_result(result: torch.Tensor):
+ """
+ Gather result across all workers.
+ """
+ warnings.warn(
+ "bmtrain.gather_result is deprecated and will be removed in later version. Use bmtrain.distributed.all_gather instead.",
+ DeprecationWarning,
+ )
+ if result.storage_offset() != 0 or result.storage().size() != result.numel():
+ # Create a clone of the original tensor if it's a slice
+ result = result.clone()
+
+ output_cuda = True
+ if not result.is_cuda:
+ result = result.cuda()
+ output_cuda = False
+ ret = torch.empty(
+ (result.shape[0] * config["world_size"], *list(result.shape[1:])),
+ device=result.device,
+ dtype=result.dtype,
+ )
+ nccl.allGather(result.storage(), ret.storage(), config["comm"])
+ if output_cuda:
+ return ret
+ else:
+ return ret.cpu()
diff --git a/examples/BMTrain/bmtrain/utils.py b/examples/BMTrain/bmtrain/utils.py
new file mode 100644
index 00000000..daa4c595
--- /dev/null
+++ b/examples/BMTrain/bmtrain/utils.py
@@ -0,0 +1,184 @@
+import torch
+import sys
+from typing import Any, Dict, Iterable, Optional
+from .global_var import config
+import os
+import ctypes
+
+ALIGN = 4
+ROW_WIDTH = 60
+
+
+def check_torch_version(version_str):
+ """
+ Checks if the current torch version is greater than or equal to the given version.
+ version_str (str): The version to compare with, in the format "x.y.z"; it is converted to an integer value of x*10000 + y*100 + z, and the function returns the difference between the current torch version's value and this one.
+ """
+ version_int_arr = [int(v) for v in version_str.split(".")]
+
+ version_int = (
+ version_int_arr[0] * 10000 + version_int_arr[1] * 100 + version_int_arr[2]
+ )
+ torch_version = torch.__version__.split("+")[0]
+ current_version_int_arr = [int(v) for v in torch_version.split(".")]
+ current_version_int = (
+ current_version_int_arr[0] * 10000
+ + current_version_int_arr[1] * 100
+ + current_version_int_arr[2]
+ )
+ return current_version_int - version_int
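+
+# Example (illustrative): check_torch_version("1.12.0") returns a non-negative value
+# when the installed torch is at least 1.12.0, and a negative value otherwise.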
+
+
+def load_nccl_pypi():
+ """
+ Check whether NCCL from the pip package is available and load its shared libraries.
+ """
+ try:
+ import nvidia.nccl
+ except ImportError:
+ raise ImportError("Run `pip install nvidia-nccl-cu11>=2.14.3` first")
+
+ path = os.path.join(os.path.dirname(nvidia.nccl.__file__), "lib")
+ for file_so in os.listdir(path):
+ file_split = file_so.split(".")
+ if file_split[-1] == "so" or (len(file_split) > 1 and file_split[-2] == "so"):
+ ctypes.CDLL(os.path.join(path, file_so))
+
+
+def round_up(x, d):
+ """
+ Return (x + d - 1) // d * d
+ """
+ return (x + d - 1) // d * d
+
+
+def print_dict(title: str, content: Dict[str, Any], file=sys.stdout):
+ """
+ Print Dict to file.
+ """
+ max_kw_len = max([len(kw) for kw in content.keys()])
+ max_kw_len = round_up(max_kw_len + 3, 4)
+
+ raw_content = ""
+
+ for kw, val in content.items():
+ raw_content += kw + " :" + " " * (max_kw_len - len(kw) - 2)
+ raw_val = "%s" % val
+
+ len_val_row = ROW_WIDTH - max_kw_len
+ st = 0
+ if len(raw_val) == 0:
+ raw_val = " "
+ while st < len(raw_val):
+ if st > 0:
+ raw_content += " " * max_kw_len
+ raw_content += raw_val[st : st + len_val_row] + "\n"
+ st += len_val_row
+
+ print_block(title, raw_content, file)
+
+
+def print_block(title: str, content: Optional[str] = None, file=sys.stdout):
+ """
+ Print content to file.
+ """
+ left_title = (ROW_WIDTH - len(title) - 2) // 2
+ right_title = ROW_WIDTH - len(title) - 2 - left_title
+
+ print("=" * left_title + " " + title + " " + "=" * right_title, file=file)
+ if content is not None:
+ print(content, file=file)
+
+
+def print_rank(*args, rank=0, **kwargs):
+ """
+ Prints the message only on the `rank` of the process.
+
+ Args:
+ *args: The arguments to be printed.
+ rank (int): The rank id of the process to print.
+ **kwargs: The keyword arguments to be printed.
+
+ """
+ if config["rank"] == rank:
+ print(*args, **kwargs)
+
+
+def see_memory(message, detail=False):
+ """
+ Outputs a message followed by a GPU memory status summary on rank 0.
+ At the end of the function, the peak GPU memory statistics are reset.
+
+ Args:
+ message (str): The message to be printed. It can be used to distinguish between other outputs.
+ detail (bool): Whether to print memory status in a detailed way or in a concise way. Defaults to False.
+
+ Example:
+ >>> bmt.see_memory("before forward")
+ >>> # forward_step()
+ >>> bmt.see_memory("after forward")
+
+ """
+ print_rank(message)
+ if detail:
+ print_rank(torch.cuda.memory_summary())
+ else:
+ print_rank(
+ f"""
+ =======================================================================================
+ memory_allocated {round(torch.cuda.memory_allocated() / (1024 * 1024 * 1024),2 )} GB
+ max_memory_allocated {round(torch.cuda.max_memory_allocated() / (1024 * 1024 * 1024),2)} GB
+ =======================================================================================
+ """
+ )
+ torch.cuda.reset_peak_memory_stats()
+
+
+def tp_split_tensor(tensor, split_dim):
+ """
+ Outputs this rank's chunk (config["topology"].tp_id) of the tensor split along split_dim.
+
+ Args:
+ tensor (torch.Tensor): The tensor to be split.
+ split_dim (int): The dimension along which to split the input tensor.
+
+ """
+ tensor_list = tensor.chunk(config["tp_size"], dim=split_dim)
+ sub_tensor = tensor_list[config["topology"].tp_id].contiguous()
+ tmp_tensor = torch.empty(
+ sub_tensor.shape, device=sub_tensor.device, dtype=sub_tensor.dtype
+ )
+ tmp_tensor.copy_(sub_tensor)
+ return tmp_tensor
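+
+# Example (illustrative): with tp_size == 2, a weight of shape (1024, 4096) split at
+# dim 1 yields a (1024, 2048) chunk owned by this rank's tensor-parallel slot.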
+
+
+class AverageRecorder:
+ """A utility class to record the average value of a quantity over time.
+
+ Args:
+ alpha (float): The decay factor of the average.
+ start_value (float): The initial value of the average.
+
+ Use `.value` to get the current average value.
+ It is calculated as `alpha * old_value + (1 - alpha) * new_value`.
+
+ """
+
+ def __init__(self, alpha=0.9, start_value=0):
+ self._value = start_value
+ self.alpha = alpha
+ self._steps = 0
+
+ def record(self, v):
+ """Records a new value.
+ Args:
+ v (float): The new value.
+ """
+ self._value = self._value * self.alpha + v * (1 - self.alpha)
+ self._steps += 1
+
+ @property
+ def value(self):
+ if self._steps <= 0:
+ return self._value
+ return self._value / (1 - pow(self.alpha, self._steps))
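+
+# Example usage (illustrative):
+#   recorder = AverageRecorder(alpha=0.9)
+#   recorder.record(2.0)
+#   recorder.record(4.0)
+#   recorder.value  # bias-corrected exponential moving average of the recorded values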
diff --git a/examples/BMTrain/bmtrain/wrapper.py b/examples/BMTrain/bmtrain/wrapper.py
new file mode 100644
index 00000000..e64fd5ba
--- /dev/null
+++ b/examples/BMTrain/bmtrain/wrapper.py
@@ -0,0 +1,54 @@
+import torch
+from .block_layer import Block, TransformerBlockList
+from .layer import DistributedModule, DistributedParameter
+
+
+def make_distributed(model: torch.nn.Module):
+ for kw in list(model._parameters.keys()):
+ if model._parameters[kw] is not None:
+ if not isinstance(model._parameters[kw], DistributedParameter):
+ model._parameters[kw] = DistributedParameter(
+ model._parameters[kw],
+ requires_grad=model._parameters[kw].requires_grad,
+ )
+
+ for kw in list(model._buffers.keys()):
+ if model._buffers[kw] is not None:
+ model._buffers[kw] = model._buffers[kw].cuda()
+ is_module_list = isinstance(model, torch.nn.ModuleList)
+ pre_module = None
+ for kw in list(model._modules.keys()):
+ if is_module_list:
+ if not isinstance(model._modules[kw], Block):
+ model._modules[kw] = Block(model_wrapper_dispatch(model._modules[kw]))
+ if pre_module is not None:
+ model._modules[kw].set_pre_module(pre_module)
+ pre_module = model._modules[kw]
+ else:
+ model._modules[kw] = model_wrapper_dispatch(model._modules[kw])
+
+ model.__class__ = type(
+ "bmtrain.Distributed" + model.__class__.__name__,
+ (model.__class__, DistributedModule),
+ {},
+ )
+ return model
+
+
+def model_wrapper_dispatch(model: torch.nn.Module):
+ if isinstance(model, TransformerBlockList):
+ return model
+ elif isinstance(model, DistributedModule):
+ return model
+ elif isinstance(model, Block):
+ return model
+ else:
+ return make_distributed(model)
+
+
+def BMTrainModelWrapper(model: torch.nn.Module) -> torch.nn.Module:
+ """
+ Automatically wrap a model in a BMTrain model.
+ Replaces all parameters with DistributedParameter, all modules with DistributedModule, and modules in ModuleList with Block.
+ """
+ return model_wrapper_dispatch(model)
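+
+# Example (illustrative; MyTorchModel is a hypothetical torch.nn.Module):
+#   model = BMTrainModelWrapper(MyTorchModel())
+#   # parameters become DistributedParameter and ModuleList entries become Block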
diff --git a/examples/BMTrain/bmtrain/zero_context.py b/examples/BMTrain/bmtrain/zero_context.py
new file mode 100644
index 00000000..8a74b3f8
--- /dev/null
+++ b/examples/BMTrain/bmtrain/zero_context.py
@@ -0,0 +1,203 @@
+import torch
+from . import nccl
+from .global_var import config
+from .synchronize import wait_loader
+
+
+class ZeroContext:
+ """ZeroContext is a helper class to Gather parameters before module forward and reduce scatter
+ gradients after module backward.
+
+ Args:
+ block (BLock): Input Block.
+ ctx_dict (dict): block._layer_dict.
+ pipe (bool): True if use pipe parallel.
+
+ """
+
+ def __init__(self, block: "Block", ctx_dict: dict = None, pipe=False) -> None:
+ self.block = block
+ self.ctx_dict = ctx_dict
+ self._param_buffer = {}
+ self._grad_buffer = {}
+ self._param_tensor = {}
+ self._grad_tensor = {}
+ self._need_release = False
+
+ def enter(self, flag=0, requires_grad=False):
+ """
+ Gather parameters before module forward and init grad buffer before backward.
+ """
+ if self.block._ready:
+ return
+ self.block._ready = True
+ self._need_release = True
+
+ wait_loader()
+ with torch.cuda.stream(config["load_stream"]):
+ for kw, val in self.block._storage_info.items():
+ assert self.block._storage_params[kw].is_cuda
+ assert kw not in self._grad_buffer
+ assert kw not in self._param_buffer
+ local_param = self.block._storage_params[kw]
+
+ storage_type = local_param.storage_type()
+ if flag != 2:
+ self._param_buffer[kw] = storage_type(
+ val["partition_size"] * val["world_size"]
+ )
+ self._param_tensor[kw] = torch.tensor(
+ [],
+ dtype=self._param_buffer[kw].dtype,
+ device=self._param_buffer[kw].device,
+ ).set_(self._param_buffer[kw])
+
+ if requires_grad and local_param.requires_grad:
+ self._grad_buffer[kw] = storage_type(
+ val["partition_size"] * val["world_size"]
+ )
+ self._grad_tensor[kw] = (
+ torch.tensor(
+ [],
+ dtype=self._grad_buffer[kw].dtype,
+ device=self._grad_buffer[kw].device,
+ )
+ .set_(self._grad_buffer[kw])
+ .zero_()
+ )
+ if flag != 2:
+ nccl.groupStart()
+ for kw, val in self.block._storage_info.items():
+ nccl.allGather(
+ self.block._storage_params[kw].storage(),
+ self._param_buffer[kw],
+ val["zero_comm"],
+ )
+ nccl.groupEnd()
+
+ current_stream = torch.cuda.current_stream()
+ current_stream.wait_stream(config["load_stream"])
+
+ # set wait stream for each storage
+ for kw in self.block._storage_info.keys():
+ if flag != 2:
+ self._param_tensor[kw].record_stream(current_stream)
+ if requires_grad and kw in self._grad_tensor:
+ self._grad_tensor[kw].record_stream(current_stream)
+
+ # update parameters in block
+ for param in self.block._param_info:
+ kw_name = param["kw_name"]
+ offset = param["offset"]
+ shape = param["shape"]
+
+ if flag != 2:
+ dtype = self._param_buffer[kw_name].dtype
+ device = self._param_buffer[kw_name].device
+ param["parameter"].data = torch.tensor(
+ [], dtype=dtype, device=device
+ ).set_(self._param_buffer[kw_name], offset, shape)
+ else:
+ dtype = param["parameter"].data.dtype
+ device = param["parameter"].data.device
+ param["parameter"].data = torch.tensor(
+ [], dtype=dtype, device=device
+ ).set_(self.ctx_dict[kw_name], offset, shape)
+
+ if (
+ requires_grad
+ and kw_name in self._grad_buffer
+ and param["parameter"].requires_grad
+ ):
+ param["parameter"].grad = torch.tensor(
+ [], dtype=dtype, device=device
+ ).set_(self._grad_buffer[kw_name], offset, shape)
+
+ def __enter__(self):
+ self.enter()
+
+ def exit(self, flag=0, backward=False):
+ """
+ Reduce-scatter gradients during backward, and release all parameters from the buffer back to block storage when forward is done.
+ """
+ if not self._need_release:
+ return
+ self._need_release = False
+ self.block._ready = False
+ if backward:
+ for kw, val in self.block._storage_info.items():
+ local_param = self.block._storage_params[kw]
+
+ # accumulate previous gradient
+ if local_param.requires_grad:
+ if local_param.grad is None:
+ grad_storage = val["storage_type"](
+ val["partition_size"]
+ ) # initialize gradient if not exist
+ local_param.grad = (
+ torch.tensor(
+ [], dtype=grad_storage.dtype, device=grad_storage.device
+ )
+ .set_(grad_storage)
+ .zero_()
+ )
+ else:
+ self._grad_tensor[kw][
+ val["begin"] : val["end"]
+ ] += local_param.grad
+
+ current_stream = torch.cuda.current_stream()
+ config["load_stream"].wait_stream(current_stream) # wait for backward
+
+ with torch.cuda.stream(config["load_stream"]):
+ nccl.groupStart()
+ for kw, val in self.block._storage_info.items():
+ local_param = self.block._storage_params[kw]
+
+ # scatter gradient
+ if local_param.requires_grad:
+ nccl.reduceScatter(
+ self._grad_buffer[kw],
+ local_param.grad.storage(),
+ "sum",
+ val["zero_comm"],
+ )
+ nccl.groupEnd()
+
+ # set wait stream for each storage
+ for kw in self._grad_tensor.keys():
+ # grads can not be freed until reduce ops finish
+ self._grad_tensor[kw].record_stream(config["load_stream"])
+
+ # Release all parameters from buffer back to block storage
+ for param in self.block._param_info:
+ kw_name = param["kw_name"]
+ dtype = self.block._storage_params[kw_name].dtype
+ device = self.block._storage_params[kw_name].device
+ if "begin" not in param:
+ param["parameter"].data = torch.tensor([], dtype=dtype, device=device)
+ param["parameter"].grad = None
+ continue
+ begin = param["begin"]
+ end = param["end"]
+ param["parameter"].data = torch.tensor([], dtype=dtype, device=device).set_(
+ self.block._storage_params[kw_name].storage(), begin, end
+ )
+ if (
+ param["parameter"].requires_grad
+ and self.block._storage_params[kw_name].grad is not None
+ ):
+ param["parameter"].grad = torch.tensor(
+ [], dtype=dtype, device=device
+ ).set_(self.block._storage_params[kw_name].grad.storage(), begin, end)
+ if flag == 1:
+ for i in self._param_buffer:
+ self.ctx_dict[i] = self._param_buffer[i]
+ self._grad_tensor = {}
+ self._param_tensor = {}
+ self._grad_buffer = {}
+ self._param_buffer = {}
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ # reduce scatter gradients
+ self.exit()
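+
+# Sketch of the intended usage inside a Block forward (illustrative only):
+#   with ZeroContext(block):
+#       out = block._module(*inputs)  # parameters are gathered for the duration of the call
+#   # on __exit__ the gathered buffers are released back to the partitioned storage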
diff --git a/examples/BMTrain/cmake/FindNCCL.cmake b/examples/BMTrain/cmake/FindNCCL.cmake
new file mode 100644
index 00000000..2af8e3b9
--- /dev/null
+++ b/examples/BMTrain/cmake/FindNCCL.cmake
@@ -0,0 +1,100 @@
+list(APPEND NCCL_ROOT $ENV{NCCL_ROOT_DIR} ${CUDA_TOOLKIT_ROOT_DIR})
+if(DEFINED ENV{NCCL_ROOT_DIR})
+ set(NCCL_ROOT_DIR $ENV{NCCL_ROOT_DIR})
+ set(NCCL_INCLUDE_DIR "${NCCL_ROOT_DIR}/include" CACHE PATH "Folder contains NVIDIA NCCL headers")
+ set(NCCL_LIB_DIR "${NCCL_ROOT_DIR}/lib" CACHE PATH "Folder contains NVIDIA NCCL libraries")
+else()
+ set(NCCL_INCLUDE_DIR $ENV{NCCL_INCLUDE_DIR} CACHE PATH "Folder contains NVIDIA NCCL headers")
+ set(NCCL_LIB_DIR $ENV{NCCL_LIB_DIR} CACHE PATH "Folder contains NVIDIA NCCL libraries")
+endif()
+
+# Compatible layer for CMake <3.12. NCCL_ROOT will be accounted in for searching paths and libraries for CMake >=3.12.
+if(NOT NCCL_INCLUDE_DIR OR NOT NCCL_LIB_DIR)
+ execute_process(
+ COMMAND python -c "import nvidia.nccl;import os; print(os.path.dirname(nvidia.nccl.__file__))"
+ OUTPUT_VARIABLE NCCL_PIP_DIR
+ OUTPUT_STRIP_TRAILING_WHITESPACE
+ )
+ list(APPEND NCCL_ROOT $ENV{NCCL_PIP_DIR})
+ if(NOT NCCL_INCLUDE_DIR)
+ set(NCCL_INCLUDE_DIR "${NCCL_PIP_DIR}/include")
+ endif()
+ if(NOT NCCL_LIB_DIR)
+ set(NCCL_LIB_DIR "${NCCL_PIP_DIR}/lib")
+ endif()
+ find_library(NCCL_LIBRARIES
+ NAMES ${NCCL_LIBNAME}
+ HINTS ${NCCL_LIB_DIR})
+endif()
+
+list(APPEND CMAKE_PREFIX_PATH ${NCCL_ROOT})
+find_path(NCCL_INCLUDE_DIRS
+ NAMES nccl.h
+ HINTS ${NCCL_INCLUDE_DIR})
+
+
+
+if (USE_STATIC_NCCL)
+ MESSAGE(STATUS "USE_STATIC_NCCL is set. Linking with static NCCL library.")
+ SET(NCCL_LIBNAME "nccl_static")
+ if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified
+ set(CMAKE_FIND_LIBRARY_SUFFIXES ".a.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
+ endif()
+else()
+ SET(NCCL_LIBNAME "nccl")
+
+ if (NCCL_VERSION) # Prefer the versioned library if a specific NCCL version is specified
+ message(STATUS "NCCL version: ${NCCL_VERSION}")
+ set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.${NCCL_VERSION}" ${CMAKE_FIND_LIBRARY_SUFFIXES})
+ else()
+ set(CMAKE_FIND_LIBRARY_SUFFIXES ".so.2" ${CMAKE_FIND_LIBRARY_SUFFIXES})
+ endif()
+
+endif()
+
+find_library(NCCL_LIBRARIES
+ NAMES ${NCCL_LIBNAME}
+ HINTS ${NCCL_LIB_DIR})
+
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
+
+if(NCCL_FOUND) # obtaining NCCL version and some sanity checks
+ set (NCCL_HEADER_FILE "${NCCL_INCLUDE_DIRS}/nccl.h")
+ message (STATUS "Determining NCCL version from ${NCCL_HEADER_FILE}...")
+ set (OLD_CMAKE_REQUIRED_INCLUDES ${CMAKE_REQUIRED_INCLUDES})
+ list (APPEND CMAKE_REQUIRED_INCLUDES ${NCCL_INCLUDE_DIRS})
+ include(CheckCXXSymbolExists)
+ check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED)
+
+ if (NCCL_VERSION_DEFINED)
+ set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc")
+ file(WRITE ${file} "
+ #include <iostream>
+ #include <nccl.h>
+ int main()
+ {
+ std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl;
+ int x;
+ ncclGetVersion(&x);
+ return x == NCCL_VERSION_CODE;
+ }
+")
+ try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file}
+ RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER
+ CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}"
+ LINK_LIBRARIES ${NCCL_LIBRARIES})
+ if (NOT NCCL_VERSION_MATCHED)
+ message(FATAL_ERROR "Found NCCL header version and library version do not match! \
+(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.")
+ endif()
+ message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}")
+ else()
+ # message(STATUS "NCCL version < 2.3.5-5")
+ endif ()
+ set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES})
+
+ message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
+ mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)
+endif()
diff --git a/examples/BMTrain/csrc/bind.cpp b/examples/BMTrain/csrc/bind.cpp
new file mode 100644
index 00000000..b8f6fa85
--- /dev/null
+++ b/examples/BMTrain/csrc/bind.cpp
@@ -0,0 +1,35 @@
+#include "include/bind.hpp"
+
+PYBIND11_MODULE(C, m) {
+ m.def("to_fp16_from_fp32", &fp16_from_fp32_value_launcher, "convert");
+ m.def("to_bf16_from_fp32", &bf16_from_fp32_value_launcher, "convert");
+ m.def("is_bf16_supported", &is_bf16_supported, "whether bf16 supported");
+ m.def("has_nan_inf_fp16_launcher", &has_nan_inf_fp16_launcher, "has nan inf");
+ m.def("has_nan_inf_bf16_launcher", &has_nan_inf_bf16_launcher, "has nan inf bf16");
+ m.def("adam_fp16_launcher", &adam_fp16_launcher, "adam function cpu");
+ m.def("adam_bf16_launcher", &adam_bf16_launcher, "adam function cpu");
+ m.def("adam_cpu_fp16_launcher", &adam_cpu_fp16_launcher, "adam function cpu");
+ m.def("adam_cpu_bf16_launcher", &adam_cpu_bf16_launcher, "adam function cpu");
+ m.def("cross_entropy_forward_fp16_launcher", &cross_entropy_forward_fp16_launcher, "cross entropy forward");
+ m.def("cross_entropy_forward_bf16_launcher", &cross_entropy_forward_bf16_launcher, "cross entropy forward");
+ m.def("cross_entropy_backward_inplace_fp16_launcher", &cross_entropy_backward_inplace_fp16_launcher, "cross entropy backward inplace");
+ m.def("cross_entropy_backward_inplace_bf16_launcher", &cross_entropy_backward_inplace_bf16_launcher, "cross entropy backward inplace");
+ m.def("fused_sumexp_fp16_launcher", &fused_sumexp_fp16_launcher, "sum exp");
+ m.def("fused_sumexp_bf16_launcher", &fused_sumexp_bf16_launcher, "sum exp");
+ m.def("fused_softmax_inplace_fp16_launcher", &fused_softmax_inplace_fp16_launcher, "softmax inplace");
+ m.def("fused_softmax_inplace_bf16_launcher", &fused_softmax_inplace_bf16_launcher, "softmax inplace");
+ m.def("ncclGetUniqueId", &pyNCCLGetUniqueID, "nccl get unique ID");
+ m.def("ncclCommInitRank", &pyNCCLCommInitRank, "nccl init rank");
+ m.def("ncclCommDestroy", &pyNCCLCommDestroy, "nccl delete rank");
+ m.def("ncclAllGather", &pyNCCLAllGather, "nccl all gather");
+ m.def("ncclAllReduce", &pyNCCLAllReduce, "nccl all reduce");
+ m.def("ncclBroadcast", &pyNCCLBroadcast, "nccl broadcast");
+ m.def("ncclReduce", &pyNCCLReduce, "nccl reduce");
+ m.def("ncclReduceScatter", &pyNCCLReduceScatter, "nccl reduce scatter");
+ m.def("ncclGroupStart", &pyNCCLGroupStart, "nccl group start");
+ m.def("ncclGroupEnd", &pyNCCLGroupEnd, "nccl group end");
+ m.def("ncclSend", &pyNCCLSend, "nccl send");
+ m.def("ncclRecv", &pyNCCLRecv, "nccl recv");
+ m.def("ncclCommCount", &pyNCCLCommCount, "nccl comm count");
+ m.def("ncclCommUserRank", &pyNCCLCommUserRank, "nccl comm user rank");
+}
diff --git a/examples/BMTrain/csrc/cuda/adam_cuda.cu b/examples/BMTrain/csrc/cuda/adam_cuda.cu
new file mode 100644
index 00000000..0510ac12
--- /dev/null
+++ b/examples/BMTrain/csrc/cuda/adam_cuda.cu
@@ -0,0 +1,126 @@
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cmath>
+#include "bfloat16.cuh"
+
+namespace {
+// blocks <n/1024>, threads <1024>
+__global__ void adam_fp32_accum(
+ int32_t n,
+ const half *g, // (n)
+ half *m, // (n)
+ float *v, // (n)
+ float *param, // (n)
+ half *param_h, // (n)
+ float beta1,
+ float beta2,
+ float eps,
+ float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+) {
+ int32_t col = blockIdx.x * blockDim.x + threadIdx.x;
+ if (col < n) {
+ float local_g = __half2float(g[col]); // real_g * scale
+ float local_m = beta1 * __half2float(m[col]) + (1 - beta1) * local_g; // real_m * scale
+ float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g / scale; // real_v * scale
+ float local_p = param[col];
+ local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v * scale / bias_correction2) + eps * scale) - lr * weight_decay * local_p;
+
+ param_h[col] = __float2half(local_p);
+ param[col] = local_p;
+ v[col] = local_v;
+ m[col] = __float2half(local_m);
+ }
+}
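+
+// Note on the fp16 path (follows from the inline comments above): g, m and v all hold
+// their real values multiplied by `scale`, so the denominator uses sqrtf(local_v * scale)
+// and eps * scale; the scale factor then cancels out of the parameter update, which is
+// computed in fp32 and written back to both the fp32 master copy and the fp16 copy.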
+
+__global__ void adam_fp32_accum_bf16(
+ int32_t n,
+ const std::uintptr_t g_ptr, // (n)
+ float *m, // (n)
+ float *v, // (n)
+ float *param, // (n)
+ std::uintptr_t param_h_ptr, // (n)
+ float beta1,
+ float beta2,
+ float eps,
+ float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+) {
+#ifdef BF16_SUPPORT
+ const __nv_bfloat16* g = reinterpret_cast<const __nv_bfloat16*>(g_ptr);
+ __nv_bfloat16* param_h = reinterpret_cast<__nv_bfloat16*>(param_h_ptr);
+ int32_t col = blockIdx.x * blockDim.x + threadIdx.x;
+ if (col < n) {
+ float local_g = __bfloat162float(g[col]) / scale; // real_g
+ float local_m = beta1 * m[col] + (1 - beta1) * local_g; // real_m
+ float local_v = beta2 * v[col] + (1 - beta2) * local_g * local_g; // real_v
+ float local_p = param[col];
+ local_p = local_p - lr * local_m / bias_correction1 / (sqrtf(local_v / bias_correction2) + eps) - lr * weight_decay * local_p;
+
+ param_h[col] = __float2bfloat16(local_p);
+ param[col] = local_p;
+ v[col] = local_v;
+ m[col] = local_m;
+ }
+#endif
+}
+
+}
+
+void adam_fp16_launcher(
+ int n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_fp16,
+ std::uintptr_t g_fp16,
+ std::uintptr_t m_fp16,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2,
+ uintptr_t stream
+) {
+ if (n <= 0) return;
+ auto g_ptr = reinterpret_cast<half*>(g_fp16);
+ auto m_ptr = reinterpret_cast<half*>(m_fp16);
+ auto param_h_ptr = reinterpret_cast<half*>(param_fp16);
+ auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+ auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
+ int32_t threads = 1024;
+ dim3 block_size = dim3(threads, 1, 1);
+ dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
+ adam_fp32_accum<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, m_ptr, v_fp32_ptr, param_fp32_ptr, param_h_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+}
+
+void adam_bf16_launcher(
+ int n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_bf16,
+ std::uintptr_t g_bf16,
+ std::uintptr_t m_fp32,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2,
+ uintptr_t stream
+) {
+ if (n <= 0) return;
+ auto m_ptr = reinterpret_cast<float*>(m_fp32);
+ auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+ auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
+ int32_t threads = 1024;
+ dim3 block_size = dim3(threads, 1, 1);
+ dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
+ adam_fp32_accum_bf16<<<grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_bf16, m_ptr, v_fp32_ptr, param_fp32_ptr, param_bf16, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+}
diff --git a/examples/BMTrain/csrc/cuda/bfloat16.cuh b/examples/BMTrain/csrc/cuda/bfloat16.cuh
new file mode 100644
index 00000000..564d8bec
--- /dev/null
+++ b/examples/BMTrain/csrc/cuda/bfloat16.cuh
@@ -0,0 +1,5 @@
+#include <cuda.h>
+#if defined(__CUDACC__) && CUDA_VERSION >= 11000
+#include <cuda_bf16.h>
+#define BF16_SUPPORT
+#endif
\ No newline at end of file
diff --git a/examples/BMTrain/csrc/cuda/cross_entropy.cu b/examples/BMTrain/csrc/cuda/cross_entropy.cu
new file mode 100644
index 00000000..177c3b77
--- /dev/null
+++ b/examples/BMTrain/csrc/cuda/cross_entropy.cu
@@ -0,0 +1,315 @@
+#include "reduce.cuh"
+#include <cuda_fp16.h>
+#include <cstdint>
+#include <cmath>
+#include "bfloat16.cuh"
+
+namespace {
+// blocks <m>, threads <1024>
+__global__ void cross_entropy_forward_fp16(
+ int64_t n,
+ const half *input, // (m, n)
+ const int32_t *target, // (m)
+ half *softmax, // (m, n)
+ float *output, // (m)
+ int32_t ignore_index
+) {
+ int64_t base_idx = blockIdx.x * n;
+
+ float local_max = -INFINITY;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_max = fmaxf(__half2float(input[base_idx + i]), local_max);
+ }
+
+ local_max = fmaxf(block_allreduce_max(local_max), -1e6);
+
+ float local_sum = 0;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_sum += expf(__half2float(input[base_idx + i]) - local_max);
+ }
+ local_sum = block_allreduce_sum(local_sum) + 1e-10; // avoid nan
+
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ softmax[base_idx + i] = __float2half( expf(__half2float(input[base_idx + i]) - local_max) / local_sum );
+ }
+
+ if (threadIdx.x == 0) {
+ if (target[blockIdx.x] != ignore_index) {
+ output[blockIdx.x] = -__half2float(input[base_idx + target[blockIdx.x]]) + local_max + logf(local_sum);
+ } else {
+ output[blockIdx.x] = 0;
+ }
+ }
+}
+
+// blocks <m>, threads <1024>
+__global__ void cross_entropy_backward_inplace_fp16(
+ int64_t n,
+ const float *grad_output, // (m)
+ const int32_t *target, // (m)
+ half *x, // (m, n)
+ int32_t ignore_index
+) {
+ int64_t base_idx = blockIdx.x * n;
+
+ int32_t t = target[blockIdx.x];
+ float v = grad_output[blockIdx.x];
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ x[base_idx + i] = __float2half(i==t ? (__half2float(x[base_idx + i])-1)*v : __half2float(x[base_idx + i])*v);
+ }
+}
+
+// blocks <m>, threads <1024>
+__global__ void cross_entropy_forward_bf16(
+ int64_t n,
+ const std::uintptr_t input_ptr, // (m, n)
+ const int32_t *target, // (m)
+ std::uintptr_t softmax_ptr, // (m, n)
+ float *output, // (m)
+ int32_t ignore_index
+) {
+#ifdef BF16_SUPPORT
+ const __nv_bfloat16* input = reinterpret_cast<const __nv_bfloat16*>(input_ptr);
+ __nv_bfloat16* softmax = reinterpret_cast<__nv_bfloat16*>(softmax_ptr);
+ int64_t base_idx = blockIdx.x * n;
+
+ float local_max = -INFINITY;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_max = fmaxf(__bfloat162float(input[base_idx + i]), local_max);
+ }
+
+ local_max = fmaxf(block_allreduce_max(local_max), -1e6);
+
+ float local_sum = 0;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_sum += expf(__bfloat162float(input[base_idx + i]) - local_max);
+ }
+ local_sum = block_allreduce_sum(local_sum) + 1e-10; // avoid nan
+
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ softmax[base_idx + i] = __float2bfloat16( expf(__bfloat162float(input[base_idx + i]) - local_max) / local_sum );
+ }
+
+ if (threadIdx.x == 0) {
+ if (target[blockIdx.x] != ignore_index) {
+ output[blockIdx.x] = -__bfloat162float(input[base_idx + target[blockIdx.x]]) + local_max + logf(local_sum);
+ } else {
+ output[blockIdx.x] = 0;
+ }
+ }
+#endif
+}
+
+// blocks <m>, threads <1024>
+__global__ void cross_entropy_backward_inplace_bf16(
+ int64_t n,
+ const float *grad_output, // (m)
+ const int32_t *target, // (m)
+ std::uintptr_t x_ptr, // (m, n)
+ int32_t ignore_index
+) {
+#ifdef BF16_SUPPORT
+ __nv_bfloat16* x = reinterpret_cast<__nv_bfloat16*>(x_ptr);
+ int64_t base_idx = blockIdx.x * n;
+
+ int32_t t = target[blockIdx.x];
+ float v = grad_output[blockIdx.x];
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ x[base_idx + i] = __float2bfloat16(i==t ? (__bfloat162float(x[base_idx + i])-1)*v : __bfloat162float(x[base_idx + i])*v);
+ }
+#endif
+}
+
+// blocks <m>, threads <1024>
+__global__ void fused_sumexp_fp16(
+ int64_t n,
+ const half *input, // (m, n)
+ const float *global_max, // (m)
+ float *global_sum // (m)
+) {
+ int64_t base_idx = blockIdx.x * n;
+ float local_max = global_max[blockIdx.x];
+
+ float local_sum = 0;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_sum += expf(__half2float(input[base_idx + i]) - local_max);
+ }
+ local_sum = block_allreduce_sum(local_sum);
+ if (threadIdx.x == 0) {
+ global_sum[blockIdx.x] = local_sum;
+ }
+}
+
+// blocks <m>, threads <1024>
+__global__ void fused_sumexp_bf16(
+ int64_t n,
+ const std::uintptr_t input_ptr, // (m, n)
+ const float *global_max, // (m)
+ float *global_sum // (m)
+) {
+#ifdef BF16_SUPPORT
+ const __nv_bfloat16* input = reinterpret_cast<const __nv_bfloat16*>(input_ptr);
+ int64_t base_idx = blockIdx.x * n;
+ float local_max = global_max[blockIdx.x];
+
+ float local_sum = 0;
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ local_sum += expf(__bfloat162float(input[base_idx + i]) - local_max);
+ }
+ local_sum = block_allreduce_sum(local_sum);
+ if (threadIdx.x == 0) {
+ global_sum[blockIdx.x] = local_sum;
+ }
+#endif
+}
+
+// blocks <m>, threads <1024>
+__global__ void fused_softmax_inplace_fp16(
+ int64_t n,
+ half *softmax, // (m, n)
+ const float *global_max, // (m)
+ const float *global_sum // (m)
+) {
+ int64_t base_idx = blockIdx.x * n;
+ float local_max = global_max[blockIdx.x];
+ float local_sum = global_sum[blockIdx.x];
+
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ softmax[base_idx + i] = __float2half( expf(__half2float(softmax[base_idx + i]) - local_max) / local_sum );
+ }
+}
+
+// blocks <m>, threads <1024>
+__global__ void fused_softmax_inplace_bf16(
+ int64_t n,
+ std::uintptr_t softmax_ptr, // (m, n)
+ const float *global_max, // (m)
+ const float *global_sum // (m)
+) {
+#ifdef BF16_SUPPORT
+ __nv_bfloat16* softmax = reinterpret_cast<__nv_bfloat16*>(softmax_ptr);
+ int64_t base_idx = blockIdx.x * n;
+ float local_max = global_max[blockIdx.x];
+ float local_sum = global_sum[blockIdx.x];
+
+ for (int64_t i = threadIdx.x; i < n; i += blockDim.x) {
+ softmax[base_idx + i] = __float2bfloat16( expf(__bfloat162float(softmax[base_idx + i]) - local_max) / local_sum );
+ }
+#endif
+}
+}
+
+void cross_entropy_forward_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t input,
+ std::uintptr_t target,
+ std::uintptr_t softmax,
+ std::uintptr_t output,
+ int32_t ignore_index,
+ std::uintptr_t stream
+) {
+ auto input_ptr = reinterpret_cast<half*>(input);
+ auto target_ptr = reinterpret_cast<int32_t*>(target);
+ auto softmax_ptr = reinterpret_cast<half*>(softmax);
+ auto output_ptr = reinterpret_cast<float*>(output);
+ int32_t threads = 1024;
+ cross_entropy_forward_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, input_ptr, target_ptr, softmax_ptr, output_ptr, ignore_index);
+}
+
+void cross_entropy_backward_inplace_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t grad_output,
+ std::uintptr_t target,
+ std::uintptr_t x,
+ int32_t ignore_index,
+ std::uintptr_t stream
+) {
+ auto output_ptr = reinterpret_cast<float*>(grad_output);
+ auto target_ptr = reinterpret_cast<int32_t*>(target);
+ auto x_ptr = reinterpret_cast<half*>(x);
+ int32_t threads = 1024;
+ cross_entropy_backward_inplace_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, output_ptr, target_ptr, x_ptr, ignore_index);
+}
+
+void cross_entropy_forward_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t input,
+ std::uintptr_t target,
+ std::uintptr_t softmax,
+ std::uintptr_t output,
+ int32_t ignore_index,
+ std::uintptr_t stream
+) {
+ auto target_ptr = reinterpret_cast<int32_t*>(target);
+ auto output_ptr = reinterpret_cast<float*>(output);
+ int32_t threads = 1024;
+ cross_entropy_forward_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, input, target_ptr, softmax, output_ptr, ignore_index);
+}
+
+void cross_entropy_backward_inplace_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t grad_output,
+ std::uintptr_t target,
+ std::uintptr_t x,
+ int32_t ignore_index,
+ std::uintptr_t stream
+) {
+ auto output_ptr = reinterpret_cast<float*>(grad_output);
+ auto target_ptr = reinterpret_cast<int32_t*>(target);
+ int32_t threads = 1024;
+ cross_entropy_backward_inplace_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, output_ptr, target_ptr, x, ignore_index);
+}
+
+void fused_sumexp_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+) {
+ auto logits_ptr = reinterpret_cast<half*>(logits);
+ auto max_logits_ptr = reinterpret_cast<float*>(max_logits);
+ auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits);
+ int32_t threads = 1024;
+ fused_sumexp_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr);
+}
+
+void fused_sumexp_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+) {
+ auto max_logits_ptr = reinterpret_cast<float*>(max_logits);
+ auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits);
+ int32_t threads = 1024;
+ fused_sumexp_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits, max_logits_ptr, sum_exp_logits_ptr);
+}
+
+void fused_softmax_inplace_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+) {
+ auto logits_ptr = reinterpret_cast<half*>(logits);
+ auto max_logits_ptr = reinterpret_cast<float*>(max_logits);
+ auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits);
+ int32_t threads = 1024;
+ fused_softmax_inplace_fp16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits_ptr, max_logits_ptr, sum_exp_logits_ptr);
+}
+
+void fused_softmax_inplace_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+) {
+ auto max_logits_ptr = reinterpret_cast<float*>(max_logits);
+ auto sum_exp_logits_ptr = reinterpret_cast<float*>(sum_exp_logits);
+ int32_t threads = 1024;
+ fused_softmax_inplace_bf16<<<m, threads, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, logits, max_logits_ptr, sum_exp_logits_ptr);
+}
\ No newline at end of file
diff --git a/examples/BMTrain/csrc/cuda/has_inf_nan.cu b/examples/BMTrain/csrc/cuda/has_inf_nan.cu
new file mode 100644
index 00000000..32bc5a5f
--- /dev/null
+++ b/examples/BMTrain/csrc/cuda/has_inf_nan.cu
@@ -0,0 +1,145 @@
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+#include <cstdint>
+#include <cmath>
+#include "bfloat16.cuh"
+
+namespace{
+__inline__ __device__ bool isnan_(half v) {
+ #if __CUDA_ARCH__ >= 700 || __CUDA_ARCH__ == 600
+ return __hisnan(v);
+ #else
+ return !__heq(v, v);
+ #endif
+}
+
+__inline__ __device__ int8_t warpReduceAny(int8_t x) {
+ for (int offset = warpSize/2; offset > 0; offset /= 2)
+ x |= __shfl_down_sync(0xFFFFFFFF, x, offset);
+ return x;
+}
+
+__inline__ __device__ float blockReduceAny(int8_t x) {
+ static __shared__ float shared[32];
+ int lane = threadIdx.x % warpSize;
+ int wid = threadIdx.x / warpSize;
+ x = warpReduceAny(x);
+ if (lane == 0) shared[wid] = x;
+ __syncthreads();
+ x = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;
+ if (wid == 0) x = warpReduceAny(x);
+ return x;
+}
+
+// grid <min(n/1024, 1024)>, thread <1024>
+__global__ void bmt_has_nan_inf_fp16(
+ int32_t n,
+ const half* inp, // (n,)
+ uint8_t* mid // (1024,)
+) {
+ int32_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+ int32_t span = blockDim.x * gridDim.x;
+
+ int8_t r = 0;
+ for (int i = gid; i < n; i += span) {
+ half v = inp[i];
+ if (__hisinf(v) || isnan_(v)) {
+ r = 1;
+ break;
+ }
+ }
+ r = blockReduceAny(r);
+ if (threadIdx.x == 0) {
+ mid[blockIdx.x] = r;
+ }
+}
+
+// grid <1>, thread<1024>
+__global__ void bmt_has_nan_inf_reduce(
+ const uint8_t* mid, // (1024,)
+ uint8_t* out
+) {
+ int tid = threadIdx.x;
+ int8_t r = blockReduceAny(mid[tid]);
+ if (tid == 0 && r > 0) {
+ out[0] = 1;
+ }
+}
+
+// grid <min(n/1024, 1024)>, thread <1024>
+__global__ void bmt_has_nan_inf_bf16(
+ int32_t n,
+ const uintptr_t inp, // (n,)
+ uint8_t* mid // (1024,)
+) {
+#ifdef BF16_SUPPORT
+ const __nv_bfloat16* bf_inp = reinterpret_cast<const __nv_bfloat16*>(inp);
+ int32_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+ int32_t span = blockDim.x * gridDim.x;
+
+ int8_t r = 0;
+ for (int i = gid; i < n; i += span) {
+ __nv_bfloat16 v = bf_inp[i];
+ #if __CUDA_ARCH__ >= 800
+ if (__hisinf(v) || __hisnan(v)) {
+ #else
+ if (isinf(__bfloat162float(v)) || isnan(__bfloat162float(v))) {
+ #endif
+ r = 1;
+ break;
+ }
+ }
+ r = blockReduceAny(r);
+ if (threadIdx.x == 0) {
+ mid[blockIdx.x] = r;
+ }
+#endif
+}
+
+}
+
+void has_nan_inf_fp16_launcher(
+ int32_t n,
+ std::uintptr_t g_fp16,
+ std::uintptr_t mid,
+ std::uintptr_t out,
+ std::uintptr_t stream
+) {
+ if (n <= 0) return;
+ auto g_ptr = reinterpret_cast<half*>(g_fp16);
+ auto mid_ptr = reinterpret_cast<uint8_t*>(mid);
+ auto out_ptr = reinterpret_cast<uint8_t*>(out);
+ int32_t threads = 1024;
+ dim3 block_size = dim3(threads, 1, 1);
+ dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
+ dim3 clamp_grid_size = dim3(min((n + threads - 1) / threads, 1024), 1, 1);
+
+ bmt_has_nan_inf_fp16<<<clamp_grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_ptr, mid_ptr);
+ bmt_has_nan_inf_reduce<<<1, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(mid_ptr, out_ptr);
+}
+
+void has_nan_inf_bf16_launcher(
+ int32_t n,
+ std::uintptr_t g_bf16,
+ std::uintptr_t mid,
+ std::uintptr_t out,
+ std::uintptr_t stream
+) {
+ if (n <= 0) return;
+ auto mid_ptr = reinterpret_cast<uint8_t*>(mid);
+ auto out_ptr = reinterpret_cast<uint8_t*>(out);
+ int32_t threads = 1024;
+ dim3 block_size = dim3(threads, 1, 1);
+ dim3 grid_size = dim3((n + threads - 1) / threads, 1, 1);
+ dim3 clamp_grid_size = dim3(min((n + threads - 1) / threads, 1024), 1, 1);
+
+ bmt_has_nan_inf_bf16<<<clamp_grid_size, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(n, g_bf16, mid_ptr);
+ bmt_has_nan_inf_reduce<<<1, block_size, 0, reinterpret_cast<cudaStream_t>(stream)>>>(mid_ptr, out_ptr);
+}
+
+int is_bf16_supported() {
+#ifdef BF16_SUPPORT
+ return 1;
+#endif
+ return 0;
+}
\ No newline at end of file
diff --git a/examples/BMTrain/csrc/cuda/reduce.cuh b/examples/BMTrain/csrc/cuda/reduce.cuh
new file mode 100644
index 00000000..a9c4c15b
--- /dev/null
+++ b/examples/BMTrain/csrc/cuda/reduce.cuh
@@ -0,0 +1,114 @@
+namespace {
+const int WARP_SZ = 32;
+
+// blocks , threads<1024>
+__device__ float block_reduce_sum(float val) {
+ static __shared__ float s_x[WARP_SZ];
+ // int gid = threadIdx.x + blockIdx.x * blockDim.x;
+ int tid = threadIdx.x;
+ int lid = threadIdx.x % WARP_SZ;
+ int wid = threadIdx.x / WARP_SZ;
+
+ // reduce intra warp
+
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val += __shfl_down_sync(0xFFFFFFFF, val, offset);
+
+ if (lid == 0) s_x[wid] = val;
+ __syncthreads();
+
+ // reduce inter warp
+ val = (tid < WARP_SZ) ? s_x[lid] : 0;
+ if (wid == 0) {
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val += __shfl_down_sync(0xFFFFFFFF, val, offset);
+ }
+ return val;
+}
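+
+// Note: block_reduce_sum / block_reduce_max leave the full reduction only on thread 0
+// of the block; the *_allreduce_* variants below additionally broadcast the result to
+// every thread through shared memory.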
+
+// blocks , threads<1024>
+__device__ float block_reduce_max(float val) {
+ static __shared__ float s_x[WARP_SZ];
+ // int gid = threadIdx.x + blockIdx.x * blockDim.x;
+ int tid = threadIdx.x;
+ int lid = threadIdx.x % WARP_SZ;
+ int wid = threadIdx.x / WARP_SZ;
+
+ // reduce intra warp
+
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
+
+ if (lid == 0) s_x[wid] = val;
+ __syncthreads();
+
+ // reduce inter warp
+ val = (tid < WARP_SZ) ? s_x[lid] : -INFINITY;
+ if (wid == 0) {
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
+ }
+ return val;
+}
+
+// blocks , threads<1024>
+__device__ float block_allreduce_sum(float val) {
+ static __shared__ float s_x[WARP_SZ];
+ // int gid = threadIdx.x + blockIdx.x * blockDim.x;
+ int tid = threadIdx.x;
+ int lid = threadIdx.x % WARP_SZ;
+ int wid = threadIdx.x / WARP_SZ;
+
+ // reduce intra warp
+
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val += __shfl_down_sync(0xFFFFFFFF, val, offset);
+
+ if (lid == 0) s_x[wid] = val;
+ __syncthreads();
+
+ // reduce inter warp
+ val = (tid < WARP_SZ) ? s_x[lid] : 0;
+ if (wid == 0) {
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val += __shfl_down_sync(0xFFFFFFFF, val, offset);
+ }
+
+ if (tid == 0) {
+ s_x[0] = val;
+ }
+ __syncthreads();
+ return s_x[0];
+}
+
+// blocks , threads<1024>
+__device__ float block_allreduce_max(float val) {
+ static __shared__ float s_x[WARP_SZ];
+ // int gid = threadIdx.x + blockIdx.x * blockDim.x;
+ int tid = threadIdx.x;
+ int lid = threadIdx.x % WARP_SZ;
+ int wid = threadIdx.x / WARP_SZ;
+
+ // reduce intra warp
+
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
+
+ if (lid == 0) s_x[wid] = val;
+ __syncthreads();
+
+ // reduce inter warp
+ val = (tid < WARP_SZ) ? s_x[lid] : -INFINITY;
+ if (wid == 0) {
+ for (int offset = WARP_SZ/2; offset > 0; offset >>= 1)
+ val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
+ }
+
+ if (tid == 0) {
+ s_x[0] = val;
+ }
+ __syncthreads();
+ return s_x[0];
+}
+
+}
\ No newline at end of file
diff --git a/examples/BMTrain/csrc/include/adam_cpu.hpp b/examples/BMTrain/csrc/include/adam_cpu.hpp
new file mode 100644
index 00000000..52575d69
--- /dev/null
+++ b/examples/BMTrain/csrc/include/adam_cpu.hpp
@@ -0,0 +1,557 @@
+// include targets below are inferred from the code in this header (SIMD intrinsics,
+// std::thread/std::mutex, sched_getaffinity, pybind11 GIL helpers, AT_ASSERTM)
+#include <immintrin.h>
+#include <cstring>
+#include <cstdint>
+#include <cmath>
+#include <algorithm>
+#include <vector>
+#include <thread>
+#include <mutex>
+#include <sched.h>
+#include <pybind11/pybind11.h>
+#include <torch/extension.h>
+#include "cpu_info.h"
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+
+static inline float _mm256_reduce_add_ps(__m256 x) {
+ /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */
+ const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x));
+ /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */
+ const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128));
+ /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */
+ const __m128 x32 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55));
+ /* Conversion to float is a no-op on x86-64 */
+ return _mm_cvtss_f32(x32);
+}
+
+inline float fp32_from_bits(uint32_t w) {
+ union {
+ uint32_t as_bits;
+ float as_value;
+ } fp32 = {w};
+ return fp32.as_value;
+}
+
+inline uint32_t fp32_to_bits(float f) {
+ union {
+ float as_value;
+ uint32_t as_bits;
+ } fp32 = {f};
+ return fp32.as_bits;
+}
+
+template <typename F>
+inline void parallel_for(int64_t begin, int64_t end, int64_t grain_size, const F& f) {
+ // Number of iterations
+ int64_t numiter = end - begin;
+
+ // Number of threads to use
+ int64_t num_threads = 1; // Default to serial execution
+
+ if (grain_size > 0) {
+ num_threads = std::max((numiter+grain_size-1) / grain_size, static_cast<int64_t>(1));
+ }
+ else{
+ cpu_set_t cpu_set;
+ CPU_ZERO(&cpu_set);
+ sched_getaffinity(0, sizeof(cpu_set), &cpu_set);
+ num_threads = CPU_COUNT(&cpu_set);
+ grain_size = std::max((numiter+num_threads-1) / num_threads, static_cast<int64_t>(1));
+
+ }
+
+ // Check if parallel execution is feasible
+ if (num_threads > 1) {
+ py::gil_scoped_release release; // Release the GIL
+ std::vector<std::thread> threads(num_threads);
+ for (int64_t t = 0; t < num_threads; ++t) {
+ threads[t] = std::thread([&, t]() {
+ int64_t left = std::min(begin + t * grain_size, end);
+ int64_t right = std::min(begin + (t + 1) * grain_size, end);
+ f(left, right);
+ });
+ }
+ for (auto& thread : threads) {
+ thread.join();
+ }
+ } else {
+ // If not feasible or grain_size is 0, perform the operation serially
+ f(begin, end);
+ }
+}
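+
+// Example (illustrative): scale a buffer in parallel, letting parallel_for derive the
+// thread count from the CPU affinity mask when grain_size == 0:
+//   parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+//       for (int64_t i = start; i < end; ++i) out[i] = in[i] * 2.0f;
+//   });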
+
+// fp32 -> fp16
+inline uint16_t fp16_ieee_from_fp32_value(float f) {
+ // const float scale_to_inf = 0x1.0p+112f;
+ // const float scale_to_zero = 0x1.0p-110f;
+ uint32_t scale_to_inf_bits = (uint32_t) 239 << 23;
+ uint32_t scale_to_zero_bits = (uint32_t) 17 << 23;
+ float scale_to_inf_val, scale_to_zero_val;
+ std::memcpy(&scale_to_inf_val, &scale_to_inf_bits, sizeof(scale_to_inf_val));
+ std::memcpy(&scale_to_zero_val, &scale_to_zero_bits, sizeof(scale_to_zero_val));
+ const float scale_to_inf = scale_to_inf_val;
+ const float scale_to_zero = scale_to_zero_val;
+
+ float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+ const uint32_t w = (uint32_t)fp32_to_bits(f);
+ const uint32_t shl1_w = w + w;
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+ if (bias < UINT32_C(0x71000000)) {
+ bias = UINT32_C(0x71000000);
+ }
+
+ base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+ const uint32_t bits = (uint32_t)fp32_to_bits(base);
+ const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+ const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+ const uint32_t nonsign = exp_bits + mantissa_bits;
+  return static_cast<uint16_t>(
+ (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign)
+ );
+}
+
+// fp16 -> fp32
+inline float fp16_ieee_to_fp32_value(uint16_t h) {
+ const uint32_t w = (uint32_t)h << 16;
+ const uint32_t sign = w & UINT32_C(0x80000000);
+ const uint32_t two_w = w + w;
+
+ const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+ const float exp_scale = 0x1.0p-112f;
+ const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+ const uint32_t magic_mask = UINT32_C(126) << 23;
+ const float magic_bias = 0.5f;
+ const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+ const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+ const uint32_t result =
+ sign | (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value)
+ : fp32_to_bits(normalized_value));
+ return fp32_from_bits(result);
+}
+
+inline uint16_t bf16_from_fp32_value(float f){
+  return *reinterpret_cast<uint32_t*>(&f) >> 16;
+}
+// fp32 -> bf16
+void bf16_from_fp32_value_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_bf16
+){
+ int span = 1;
+  auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+  auto param_bf16_ptr = reinterpret_cast<uint16_t*>(param_bf16);
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ for (int64_t j = start; j < end; j += span) {
+ for (int64_t i = j; i < end; i++) {
+ float p = param_fp32_ptr[i];
+ param_bf16_ptr[i] = bf16_from_fp32_value(p);
+ }
+ break; // must break here
+ }
+ });
+}
+
+void fp16_from_fp32_value_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_fp16
+){
+ int span = 1;
+  auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+  auto param_fp16_ptr = reinterpret_cast<uint16_t*>(param_fp16);
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ for (int64_t j = start; j < end; j += span) {
+ for (int64_t i = j; i < end; i++) {
+ float p = param_fp32_ptr[i];
+ param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
+ }
+ break; // must break here
+ }
+ });
+}
+// bf16 -> fp32
+inline float bf16_to_fp32_value(uint16_t h){
+ uint32_t src = h;
+ src <<= 16;
+  return *reinterpret_cast<float*>(&src);
+}
+
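+// Scalar (no-SIMD) Adam step for fp16 parameters: updates the fp32 master
+// weights, their fp16 copies and the m/v moment buffers in place, and, when
+// delta_info_ptr is non-NULL, writes [mean, variance, sum, sum of squares]
+// of the per-element update into delta_info_ptr.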
+void adam_cpu_0(
+ int64_t n,
+ float* param_fp32_ptr,
+ uint16_t* param_fp16_ptr,
+ float* delta_info_ptr,
+ uint16_t* g_fp16_ptr,
+ float* m_fp32_ptr,
+ float* v_fp32_ptr,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+){
+ int64_t span = 1;
+ float sum_sq_delta = 0;
+ float sum_delta = 0;
+ std::mutex delta_mutex;
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ float sum_sq_delta_i = 0;
+ float sum_delta_i = 0;
+ for (int64_t j = start; j < end; j += span) {
+ for (int64_t i = j; i < end; i++) {
+ float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale;
+ float m = m_fp32_ptr[i];
+ float v = v_fp32_ptr[i];
+ float p = param_fp32_ptr[i];
+ m = beta1 * m + (1 - beta1) * g;
+ v = beta2 * v + (1 - beta2) * g * g;
+ if (delta_info_ptr != NULL){
+ float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+ sum_delta_i += delta;
+ sum_sq_delta_i += delta * delta;
+ }
+ p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
+ param_fp32_ptr[i] = p;
+ param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
+ m_fp32_ptr[i] = m;
+ v_fp32_ptr[i] = v;
+ }
+ break; // must break here
+ }
+ if (delta_info_ptr != NULL){
+ delta_mutex.lock();
+ sum_delta += sum_delta_i;
+ sum_sq_delta += sum_sq_delta_i;
+ delta_mutex.unlock();
+ }
+ });
+ if (delta_info_ptr != NULL){
+ delta_info_ptr[0] = sum_delta / n;
+ delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2
+ delta_info_ptr[2] = sum_delta;
+ delta_info_ptr[3] = sum_sq_delta;
+ }
+}
+
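+// bf16 counterpart of adam_cpu_0: the same scalar Adam step, but it reads the
+// gradient from bf16 and writes the bf16 parameter copy.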
+void adam_cpu_bf16_0(
+ int64_t n,
+ float* param_fp32_ptr,
+ uint16_t* param_bf16_ptr,
+ float* delta_info_ptr,
+ uint16_t* g_bf16_ptr,
+ float* m_fp32_ptr,
+ float* v_fp32_ptr,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+){
+ int64_t span = 1;
+ float sum_sq_delta = 0;
+ float sum_delta = 0;
+ std::mutex delta_mutex;
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ float sum_sq_delta_i = 0;
+ float sum_delta_i = 0;
+ for (int64_t j = start; j < end; j += span) {
+ for (int64_t i = j; i < end; i++) {
+ float g = bf16_to_fp32_value(g_bf16_ptr[i]) / scale;
+ float m = m_fp32_ptr[i];
+ float v = v_fp32_ptr[i];
+ float p = param_fp32_ptr[i];
+ m = beta1 * m + (1 - beta1) * g;
+ v = beta2 * v + (1 - beta2) * g * g;
+ if (delta_info_ptr != NULL){
+ float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+ sum_delta_i += delta;
+ sum_sq_delta_i += delta * delta;
+ }
+ p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
+ param_fp32_ptr[i] = p;
+ param_bf16_ptr[i] = bf16_from_fp32_value(p);
+ m_fp32_ptr[i] = m;
+ v_fp32_ptr[i] = v;
+ }
+ break; // must break here
+ }
+ if (delta_info_ptr != NULL){
+ delta_mutex.lock();
+ sum_delta += sum_delta_i;
+ sum_sq_delta += sum_sq_delta_i;
+ delta_mutex.unlock();
+ }
+ });
+ if (delta_info_ptr != NULL){
+ delta_info_ptr[0] = sum_delta / n;
+ delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2
+ delta_info_ptr[2] = sum_delta;
+ delta_info_ptr[3] = sum_sq_delta;
+ }
+}
+
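+// AVX/FMA/F16C (256-bit) variant of the fp16 Adam step: processes 8 elements
+// per iteration and falls back to the scalar path for the tail of each chunk.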
+static void __attribute__ ((__target__ ("avx,fma,f16c"))) adam_cpu_1(
+ int64_t n,
+ float* param_fp32_ptr,
+ uint16_t* param_fp16_ptr,
+ float* delta_info_ptr,
+ uint16_t* g_fp16_ptr,
+ float* m_fp32_ptr,
+ float* v_fp32_ptr,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+){
+ float sum_sq_delta = 0;
+ float sum_delta = 0;
+ std::mutex delta_mutex;
+ auto avx_beta1 = _mm256_set1_ps(beta1);
+ auto avx_beta2 = _mm256_set1_ps(beta2);
+ auto avx_beta1_1 = _mm256_set1_ps(1 - beta1);
+ auto avx_beta2_1 = _mm256_set1_ps(1 - beta2);
+ auto avx_eps = _mm256_set1_ps(eps);
+ auto avx_neg_lr = _mm256_set1_ps(-lr);
+ auto avx_scale = _mm256_set1_ps(scale);
+ auto avx_weight_decay = _mm256_set1_ps(weight_decay);
+ auto avx_bias_correction1 = _mm256_set1_ps(bias_correction1);
+ auto avx_bias_correction2 = _mm256_set1_ps(bias_correction2);
+ int64_t span = 8;
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ float sum_sq_delta_i = 0;
+ float sum_delta_i = 0;
+ for (int64_t j = start; j < end; j += span) {
+ if (j + span > end) {
+ for (int64_t i = j; i < end; i++) {
+ float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale;
+ float m = m_fp32_ptr[i];
+ float v = v_fp32_ptr[i];
+ float p = param_fp32_ptr[i];
+ m = beta1 * m + (1 - beta1) * g;
+ v = beta2 * v + (1 - beta2) * g * g;
+ if (delta_info_ptr != NULL){
+ float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+ sum_delta_i += delta;
+ sum_sq_delta_i += delta * delta;
+ }
+ p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
+ param_fp32_ptr[i] = p;
+ param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
+ m_fp32_ptr[i] = m;
+ v_fp32_ptr[i] = v;
+ }
+ break; // must break here
+ } else {
+ auto g = _mm256_div_ps(_mm256_cvtph_ps(_mm_loadu_si128((const __m128i*)&g_fp16_ptr[j])), avx_scale);
+ auto m = _mm256_loadu_ps(&m_fp32_ptr[j]);
+ auto v = _mm256_loadu_ps(&v_fp32_ptr[j]);
+        auto p = _mm256_loadu_ps(&param_fp32_ptr[j]);
+ m = _mm256_fmadd_ps(avx_beta1, m, _mm256_mul_ps(avx_beta1_1, g));
+ v = _mm256_fmadd_ps(avx_beta2, v, _mm256_mul_ps(avx_beta2_1, _mm256_mul_ps(g, g)));
+ if (delta_info_ptr != NULL){
+ auto delta_256 = _mm256_add_ps(
+ _mm256_div_ps(
+ _mm256_div_ps(m, avx_bias_correction1), // m / bias_correction1
+ _mm256_add_ps(_mm256_sqrt_ps(_mm256_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps
+ ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps)
+ _mm256_mul_ps(avx_weight_decay, p) // weight_decay * p
+ ); // delta = m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + weight_decay * p
+ sum_delta_i += _mm256_reduce_add_ps(delta_256);
+ sum_sq_delta_i += _mm256_reduce_add_ps(_mm256_mul_ps(delta_256, delta_256));
+ }
+ p = _mm256_fmadd_ps(avx_neg_lr, _mm256_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p
+ p = _mm256_fmadd_ps(
+ avx_neg_lr,
+ _mm256_div_ps(
+ _mm256_div_ps(m, avx_bias_correction1), // m / bias_correction1
+ _mm256_add_ps(_mm256_sqrt_ps(_mm256_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps
+ ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps)
+ p
+ ); // p = p - lr * m / bias_correction1 / (sqrt(v / bias_correction2) + eps)
+        _mm256_storeu_ps(&param_fp32_ptr[j], p);
+        _mm_storeu_si128((__m128i*)&param_fp16_ptr[j], _mm256_cvtps_ph(p, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ _mm256_storeu_ps(&m_fp32_ptr[j], m);
+ _mm256_storeu_ps(&v_fp32_ptr[j], v);
+ }
+ }
+ if (delta_info_ptr != NULL){
+ delta_mutex.lock();
+ sum_delta += sum_delta_i;
+ sum_sq_delta += sum_sq_delta_i;
+ delta_mutex.unlock();
+ }
+ });
+ if (delta_info_ptr != NULL){
+ delta_info_ptr[0] = sum_delta / n;
+ delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2
+ delta_info_ptr[2] = sum_delta;
+ delta_info_ptr[3] = sum_sq_delta;
+ }
+}
+
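+// AVX-512 variant of the fp16 Adam step: processes 16 elements per iteration
+// and falls back to the scalar path for the tail of each chunk.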
+static void __attribute__ ((__target__ ("avx512f"))) adam_cpu_2(
+ int64_t n,
+ float* param_fp32_ptr,
+ uint16_t* param_fp16_ptr,
+ float* delta_info_ptr,
+ uint16_t* g_fp16_ptr,
+ float* m_fp32_ptr,
+ float* v_fp32_ptr,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+){
+ float sum_sq_delta = 0;
+ float sum_delta = 0;
+ std::mutex delta_mutex;
+ auto avx_beta1 = _mm512_set1_ps(beta1);
+ auto avx_beta2 = _mm512_set1_ps(beta2);
+ auto avx_beta1_1 = _mm512_set1_ps(1 - beta1);
+ auto avx_beta2_1 = _mm512_set1_ps(1 - beta2);
+ auto avx_eps = _mm512_set1_ps(eps);
+ auto avx_neg_lr = _mm512_set1_ps(-lr);
+ auto avx_scale = _mm512_set1_ps(scale);
+ auto avx_weight_decay = _mm512_set1_ps(weight_decay);
+ auto avx_bias_correction1 = _mm512_set1_ps(bias_correction1);
+ auto avx_bias_correction2 = _mm512_set1_ps(bias_correction2);
+ int64_t span = 16;
+ parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
+ float sum_sq_delta_i = 0;
+ float sum_delta_i = 0;
+ for (int64_t j = start; j < end; j += span) {
+ if (j + span > end) {
+ for (int64_t i = j; i < end; i++) {
+ float g = fp16_ieee_to_fp32_value(g_fp16_ptr[i]) / scale;
+ float m = m_fp32_ptr[i];
+ float v = v_fp32_ptr[i];
+ float p = param_fp32_ptr[i];
+ m = beta1 * m + (1 - beta1) * g;
+ v = beta2 * v + (1 - beta2) * g * g;
+ if (delta_info_ptr != NULL){
+ float delta = m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) + weight_decay * p;
+ sum_delta_i += delta;
+ sum_sq_delta_i += delta * delta;
+ }
+ p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps) - lr * weight_decay * p;
+ param_fp32_ptr[i] = p;
+ param_fp16_ptr[i] = fp16_ieee_from_fp32_value(p);
+ m_fp32_ptr[i] = m;
+ v_fp32_ptr[i] = v;
+ }
+ break; // must break here
+ }else{
+ auto g = _mm512_div_ps(_mm512_cvtph_ps(_mm256_loadu_si256((const __m256i*)&g_fp16_ptr[j])), avx_scale);
+ auto m = _mm512_loadu_ps(&m_fp32_ptr[j]);
+ auto v = _mm512_loadu_ps(&v_fp32_ptr[j]);
+        auto p = _mm512_loadu_ps(&param_fp32_ptr[j]);
+ m = _mm512_fmadd_ps(avx_beta1, m, _mm512_mul_ps(avx_beta1_1, g));
+ v = _mm512_fmadd_ps(avx_beta2, v, _mm512_mul_ps(avx_beta2_1, _mm512_mul_ps(g, g)));
+ if (delta_info_ptr != NULL){
+ auto delta_512 = _mm512_add_ps(
+ _mm512_div_ps(
+ _mm512_div_ps(m, avx_bias_correction1), // m / bias_correction1
+ _mm512_add_ps(_mm512_sqrt_ps(_mm512_div_ps(v, avx_bias_correction2)), avx_eps) // sqrt(v / bias_correction2) + eps
+ ), // m / bias_correction1 / (sqrt(v / bias_correction2) + eps)
+ _mm512_mul_ps(avx_weight_decay, p) // weight_decay * p
+ ); // delta = m / bias_correction1 / (sqrt(v / bias_correction2) + eps) + weight_decay * p
+ sum_delta_i += _mm512_reduce_add_ps(delta_512);
+ sum_sq_delta_i += _mm512_reduce_add_ps(_mm512_mul_ps(delta_512, delta_512));
+ }
+ p = _mm512_fmadd_ps(avx_neg_lr, _mm512_mul_ps(avx_weight_decay, p), p); // p = p - lr * weight_decay * p
+ p = _mm512_fmadd_ps(
+ avx_neg_lr,
+ _mm512_div_ps(
+ _mm512_div_ps(m, avx_bias_correction1), // m / bias_correction1
+ _mm512_add_ps(
+ _mm512_sqrt_ps(_mm512_div_ps(v, avx_bias_correction2)),
+ avx_eps
+ ) // sqrt(v / bias_correction2) + eps
+ ),
+ p
+ ); // p = p - lr * m / bias_correction1 / (sqrtf(v / bias_correction2) + eps)
+        _mm512_storeu_ps(&param_fp32_ptr[j], p);
+        _mm256_storeu_si256((__m256i*)&param_fp16_ptr[j], _mm512_cvtps_ph(p, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+ _mm512_storeu_ps(&m_fp32_ptr[j], m);
+ _mm512_storeu_ps(&v_fp32_ptr[j], v);
+ }
+ }
+ if (delta_info_ptr != NULL){
+ delta_mutex.lock();
+ sum_delta += sum_delta_i;
+ sum_sq_delta += sum_sq_delta_i;
+ delta_mutex.unlock();
+ }
+ });
+ if (delta_info_ptr != NULL){
+ delta_info_ptr[0] = sum_delta / n;
+ delta_info_ptr[1] = sum_sq_delta / n - sum_delta * sum_delta / (n * n);// var = E(x^2) - E(x)^2
+ delta_info_ptr[2] = sum_delta;
+ delta_info_ptr[3] = sum_sq_delta;
+ }
+}
+
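+// Entry point for the fp16 CPU Adam step: casts the raw tensor addresses to
+// typed pointers and dispatches to the scalar / AVX / AVX-512 kernel based on
+// the runtime CPU feature level reported by get_cpu_level().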
+void adam_cpu_fp16_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_fp16,
+ std::uintptr_t delta_info,
+ std::uintptr_t g_fp16,
+ std::uintptr_t m_fp32,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+) {
+  auto delta_info_ptr = reinterpret_cast<float*>(delta_info);
+  auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+  auto m_fp32_ptr = reinterpret_cast<float*>(m_fp32);
+  auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
+  auto param_fp16_ptr = reinterpret_cast<uint16_t*>(param_fp16);
+  auto g_fp16_ptr = reinterpret_cast<uint16_t*>(g_fp16);
+ int cpu_level = get_cpu_level();
+ if (cpu_level == 0 ){
+ adam_cpu_0(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+ }else if(cpu_level == 1){
+ adam_cpu_1(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+ }else{
+ adam_cpu_2(n, param_fp32_ptr, param_fp16_ptr, delta_info_ptr, g_fp16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+ }
+}
+
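+// Entry point for the bf16 CPU Adam step; only the scalar kernel is provided.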
+void adam_cpu_bf16_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_bf16,
+ std::uintptr_t delta_info,
+ std::uintptr_t g_bf16,
+ std::uintptr_t m_fp32,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2
+) {
+  auto delta_info_ptr = reinterpret_cast<float*>(delta_info);
+  auto m_fp32_ptr = reinterpret_cast<float*>(m_fp32);
+  auto v_fp32_ptr = reinterpret_cast<float*>(v_fp32);
+  auto param_fp32_ptr = reinterpret_cast<float*>(param_fp32);
+  auto param_bf16_ptr = reinterpret_cast<uint16_t*>(param_bf16);
+  auto g_bf16_ptr = reinterpret_cast<uint16_t*>(g_bf16);
+ adam_cpu_bf16_0(n, param_fp32_ptr, param_bf16_ptr, delta_info_ptr, g_bf16_ptr, m_fp32_ptr, v_fp32_ptr, beta1, beta2, eps, lr, scale, weight_decay, bias_correction1, bias_correction2);
+}
diff --git a/examples/BMTrain/csrc/include/bind.hpp b/examples/BMTrain/csrc/include/bind.hpp
new file mode 100644
index 00000000..3ff967fd
--- /dev/null
+++ b/examples/BMTrain/csrc/include/bind.hpp
@@ -0,0 +1,111 @@
+#include <cstdint>
+#include "nccl.hpp"
+#include "adam_cpu.hpp"
+
+int is_bf16_supported();
+
+void has_nan_inf_fp16_launcher(int32_t n, std::uintptr_t g_fp16, std::uintptr_t mid, std::uintptr_t out, std::uintptr_t stream);
+void has_nan_inf_bf16_launcher(int32_t n, std::uintptr_t g_bf16, std::uintptr_t mid, std::uintptr_t out, std::uintptr_t stream);
+
+void fp16_from_fp32_value_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_fp16
+);
+void bf16_from_fp32_value_launcher(
+ int64_t n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_bf16
+);
+void cross_entropy_forward_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t input,
+ std::uintptr_t target,
+ std::uintptr_t softmax,
+ std::uintptr_t output,
+ int32_t ignore_index,
+ std::uintptr_t stream
+);
+void cross_entropy_backward_inplace_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t grad_output,
+ std::uintptr_t target,
+ std::uintptr_t x,
+ int32_t ignore_index,
+ std::uintptr_t stream
+);
+void cross_entropy_forward_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t input,
+ std::uintptr_t target,
+ std::uintptr_t softmax,
+ std::uintptr_t output,
+ int32_t ignore_index,
+ std::uintptr_t stream
+);
+void cross_entropy_backward_inplace_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t grad_output,
+ std::uintptr_t target,
+ std::uintptr_t x,
+ int32_t ignore_index,
+ std::uintptr_t stream
+);
+void fused_sumexp_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+);
+void fused_sumexp_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+);
+void fused_softmax_inplace_fp16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+);
+void fused_softmax_inplace_bf16_launcher(
+ int32_t m, int32_t n,
+ std::uintptr_t logits,
+ std::uintptr_t max_logits,
+ std::uintptr_t sum_exp_logits,
+ std::uintptr_t stream
+);
+void adam_fp16_launcher(
+ int n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_fp16,
+ std::uintptr_t g_fp16,
+ std::uintptr_t m_fp16,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2,
+ uintptr_t stream
+);
+void adam_bf16_launcher(
+ int n,
+ std::uintptr_t param_fp32,
+ std::uintptr_t param_bf16,
+ std::uintptr_t g_bf16,
+ std::uintptr_t m_fp32,
+ std::uintptr_t v_fp32,
+ float beta1, float beta2,
+ float eps, float lr,
+ float scale,
+ float weight_decay,
+ float bias_correction1,
+ float bias_correction2,
+ uintptr_t stream
+);
diff --git a/examples/BMTrain/csrc/include/cpu_info.h b/examples/BMTrain/csrc/include/cpu_info.h
new file mode 100644
index 00000000..53ed48f8
--- /dev/null
+++ b/examples/BMTrain/csrc/include/cpu_info.h
@@ -0,0 +1,38 @@
+#include <cpuid.h>
+
+static void cpuid(int info[4], int InfoType){
+ __cpuid_count(InfoType, 0, info[0], info[1], info[2], info[3]);
+}
+
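+// Returns the SIMD capability level of the host CPU:
+// 0 = baseline, 1 = AVX + FMA + F16C available, 2 = AVX-512F available.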
+int get_cpu_level() {
+  // SIMD: 128-bit
+  bool HW_F16C = false;
+
+  // SIMD: 256-bit
+  bool HW_AVX = false;
+  bool HW_FMA = false;
+
+  // SIMD: 512-bit
+  bool HW_AVX512F = false; // AVX512 Foundation
+
+ int info[4];
+ cpuid(info, 0);
+ int nIds = info[0];
+
+ // Detect Features
+ if (nIds >= 0x00000001){
+ cpuid(info,0x00000001);
+ HW_AVX = (info[2] & ((int)1 << 28)) != 0;
+ HW_FMA = (info[2] & ((int)1 << 12)) != 0;
+ HW_F16C = (info[2] & ((int)1 << 29)) != 0;
+ }
+ if (nIds >= 0x00000007){
+ cpuid(info,0x00000007);
+ HW_AVX512F = (info[1] & ((int)1 << 16)) != 0;
+ }
+
+ int ret = 0;
+ if (HW_AVX && HW_FMA && HW_F16C) ret = 1;
+ if (HW_AVX512F) ret = 2;
+ return ret;
+}
diff --git a/examples/BMTrain/csrc/include/nccl.hpp b/examples/BMTrain/csrc/include/nccl.hpp
new file mode 100644
index 00000000..bba0278b
--- /dev/null
+++ b/examples/BMTrain/csrc/include/nccl.hpp
@@ -0,0 +1,188 @@
+#include <nccl.h>
+#include <cstring>
+#include <string>
+#include <stdexcept>
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+void checkNCCLStatus(ncclResult_t result) {
+ if (result == ncclSuccess) return;
+ throw std::logic_error(
+ std::string("NCCL Error: ") +
+ ncclGetErrorString(result)
+ );
+}
+
+py::bytes pyNCCLGetUniqueID() {
+ ncclUniqueId uniqueID;
+ checkNCCLStatus(ncclGetUniqueId(&uniqueID));
+ return py::bytes(uniqueID.internal, NCCL_UNIQUE_ID_BYTES);
+}
+
+std::uintptr_t pyNCCLCommInitRank(py::bytes byteUniqueID, int world_size, int rank) {
+ ncclUniqueId uniqueID;
+ std::memcpy(uniqueID.internal, std::string(byteUniqueID).c_str(), NCCL_UNIQUE_ID_BYTES);
+ ncclComm_t comm;
+ checkNCCLStatus(ncclCommInitRank(&comm, world_size, uniqueID, rank));
+  return reinterpret_cast<std::uintptr_t>(comm);
+}
+
+void pyNCCLCommDestroy(std::uintptr_t ptrcomm) {
+  ncclComm_t comm = reinterpret_cast<ncclComm_t>(ptrcomm);
+ checkNCCLStatus(ncclCommDestroy(comm));
+}
+
+void pyNCCLAllGather(
+ std::uintptr_t sendbuff,
+ std::uintptr_t recvbuff,
+ size_t sendcount,
+ int datatype,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclAllGather(
+    reinterpret_cast<void*>(sendbuff),
+    reinterpret_cast<void*>(recvbuff),
+    sendcount,
+    static_cast<ncclDataType_t>(datatype),
+    reinterpret_cast<ncclComm_t>(comm),
+    reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+
+void pyNCCLAllReduce(
+ std::uintptr_t sendbuff,
+ std::uintptr_t recvbuff,
+ size_t count,
+ int data_type,
+ int op,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclAllReduce(
+    reinterpret_cast<void*>(sendbuff),
+    reinterpret_cast<void*>(recvbuff),
+    count,
+    static_cast<ncclDataType_t>(data_type),
+    static_cast<ncclRedOp_t>(op),
+    reinterpret_cast<ncclComm_t>(comm),
+    reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+
+void pyNCCLBroadcast(
+ std::uintptr_t sendbuff,
+ std::uintptr_t recvbuff,
+ size_t count,
+ int datatype,
+ int root,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclBroadcast(
+    reinterpret_cast<void*>(sendbuff),
+    reinterpret_cast<void*>(recvbuff),
+    count,
+    static_cast<ncclDataType_t>(datatype),
+    root,
+    reinterpret_cast<ncclComm_t>(comm),
+    reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+
+void pyNCCLReduce(
+ std::uintptr_t sendbuff,
+ std::uintptr_t recvbuff,
+ size_t count,
+ int datatype,
+ int op,
+ int root,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclReduce(
+    reinterpret_cast<void*>(sendbuff),
+    reinterpret_cast<void*>(recvbuff),
+    count,
+    static_cast<ncclDataType_t>(datatype),
+    static_cast<ncclRedOp_t>(op),
+    root,
+    reinterpret_cast<ncclComm_t>(comm),
+    reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+
+void pyNCCLReduceScatter(
+ std::uintptr_t sendbuff,
+ std::uintptr_t recvbuff,
+ size_t recvcount,
+ int datatype,
+ int op,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclReduceScatter(
+    reinterpret_cast<void*>(sendbuff),
+    reinterpret_cast<void*>(recvbuff),
+    recvcount,
+    static_cast<ncclDataType_t>(datatype),
+    static_cast<ncclRedOp_t>(op),
+    reinterpret_cast<ncclComm_t>(comm),
+    reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+void pyNCCLSend(
+ std::uintptr_t sendbuff,
+ size_t sendcount,
+ int data_type,
+ int peer,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclSend(
+    reinterpret_cast<void*>(sendbuff),
+    sendcount,
+    static_cast<ncclDataType_t>(data_type),
+    peer,
+    reinterpret_cast<ncclComm_t>(comm),
+    reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+void pyNCCLRecv(
+ std::uintptr_t recvbuff,
+ size_t recvcount,
+ int data_type,
+ int peer,
+ std::uintptr_t comm,
+ std::uintptr_t stream
+) {
+ checkNCCLStatus(ncclRecv(
+    reinterpret_cast<void*>(recvbuff),
+    recvcount,
+    static_cast<ncclDataType_t>(data_type),
+    peer,
+    reinterpret_cast<ncclComm_t>(comm),
+    reinterpret_cast<cudaStream_t>(stream)
+ ));
+}
+void pyNCCLGroupStart() {
+ checkNCCLStatus(ncclGroupStart());
+}
+
+void pyNCCLGroupEnd() {
+ checkNCCLStatus(ncclGroupEnd());
+}
+int pyNCCLCommCount(
+ std::uintptr_t comm
+){
+ int res;
+  checkNCCLStatus(ncclCommCount(reinterpret_cast<ncclComm_t>(comm),&res));
+ return res;
+}
+int pyNCCLCommUserRank(
+ std::uintptr_t comm
+){
+ int rank;
+  checkNCCLStatus(ncclCommUserRank(reinterpret_cast<ncclComm_t>(comm),&rank));
+ return rank;
+}
diff --git a/examples/BMTrain/doc_requirements.txt b/examples/BMTrain/doc_requirements.txt
new file mode 100644
index 00000000..79d22ca0
--- /dev/null
+++ b/examples/BMTrain/doc_requirements.txt
@@ -0,0 +1,5 @@
+sphinx>=4.0.0
+recommonmark
+sphinx_markdown_tables
+sphinx_rtd_theme>=0.3.0
+torch
\ No newline at end of file
diff --git a/examples/BMTrain/docs/Makefile b/examples/BMTrain/docs/Makefile
new file mode 100644
index 00000000..4f2fbe66
--- /dev/null
+++ b/examples/BMTrain/docs/Makefile
@@ -0,0 +1,20 @@
+# Minimal makefile for Sphinx documentation
+#
+
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= sphinx-build
+SOURCEDIR = source-en
+BUILDDIR = build
+
+# Put it first so that "make" without argument is like "make help".
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+.PHONY: help Makefile
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+%: Makefile
+ @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/examples/BMTrain/docs/UPDATE_0.2.0.md b/examples/BMTrain/docs/UPDATE_0.2.0.md
new file mode 100644
index 00000000..92819afd
--- /dev/null
+++ b/examples/BMTrain/docs/UPDATE_0.2.0.md
@@ -0,0 +1,79 @@
+# Update Log 0.2.0
+
+## What's New
+
+### 1. Added an `Optimizer Manager` to support various optimizer algorithms.
+
+Before 0.2.0, the `optimizer` was strongly coupled to the "loss scaler". As a result, users could not use multiple optimizers at the same time when training a model in fp16.
+
+**======= Before 0.2.0 =======**
+
+```python
+for iteration in range(1000):
+ # zero grad
+ optimizer.zero_grad()
+
+ # ...
+ # loss scale and backward
+ loss = optimizer.loss_scale(loss)
+ loss.backward()
+
+ # optimizer step
+ bmtrain.optim_step(optimizer, lr_scheduler)
+```
+
+The `bmtrain.optim_step` function allows only one `optimizer` and at most one `lr_scheduler`, which cannot handle more complex scenarios.
+
+
+**======= After 0.2.0 =======**
+
+```python
+# create a new instance of optimizer manager
+optim_manager = bmtrain.optim.OptimManager(loss_scale=1024)
+# let optim_manager handle all the optimizers and (optionally) their corresponding lr_schedulers
+optim_manager.add_optimizer(optimizer, lr_scheduler)
+# add_optimizer can be called multiple times to add other optimizers.
+
+for iteration in range(1000):
+ # zero grad
+ optim_manager.zero_grad() # calling zero_grad for each optimizer
+
+ # ...
+ # loss scale and backward
+ optim_manager.backward(loss)
+
+ # optimizer step
+ optim_manager.step()
+```
+
+Starting from BMTrain 0.2.0, we provide `OptimManager` to manage optimizers and loss scales.
+`OptimManager` supports managing multiple optimizers and lr_schedulers at the same time, and allows setting the loss scale independently.
+`OptimManager` can also manage PyTorch native optimizers, such as SGD and AdamW.
+
+### 2. Pipeline Parallelism
+
+In this version, BMTrain has added a new kind of parallel algorithm: pipeline parallelism.
+To enable pipeline parallelism, one line of code needs to be modified.
+
+**======= ZeRO =======**
+```python
+layers = bmt.TransformerBlockList([
+ # ...
+])
+```
+
+**======= Pipeline =======**
+```python
+layers = bmt.PipelineTransformerBlockList([
+ # ...
+])
+```
+
+Replacing `TransformerBlockList` with `PipelineTransformerBlockList` switches the parallel algorithm from ZeRO to pipeline parallelism.
+The number of stages in the pipeline can be set by passing the `pipe_size` parameter to `bmtrain.init_distributed`.
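+
+A minimal sketch of switching on pipeline parallelism (the stage count of 4 is just an example):
+
+```python
+import bmtrain as bmt
+
+bmt.init_distributed(pipe_size=4)  # split the model into 4 pipeline stages
+
+layers = bmt.PipelineTransformerBlockList([
+    # ...
+])
+```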
+
+### 3. Others
+
+* Supports BF16.
+* Tensors recorded in the inspector support backward propagation.
+* Adds new tests.
diff --git a/examples/BMTrain/docs/UPDATE_0.2.3.md b/examples/BMTrain/docs/UPDATE_0.2.3.md
new file mode 100644
index 00000000..e95c6867
--- /dev/null
+++ b/examples/BMTrain/docs/UPDATE_0.2.3.md
@@ -0,0 +1,26 @@
+# Update Log 0.2.3
+
+**Full Changelog**: https://github.com/OpenBMB/BMTrain/compare/0.2.0...0.2.3
+
+
+## What's New
+
+### 1. Get rid of torch cpp extension when compiling
+
+Before 0.2.3, installing BMTrain required the torch cpp extension, which is not friendly to some users (it requires a CUDA runtime that matches the one torch was built with). Now we get rid of the torch cpp extension when compiling BMTrain, which makes installing BMTrain from source more convenient.
+Just run `pip install .` to install BMTrain from source.
+
+### 2. CICD
+
+In 0.2.3, we bring GitHub Actions CI/CD to BMTrain. Now we run the CI/CD pipeline on GitHub to ensure code quality: it runs the test cases and compiles the source code into wheel packages.
+
+### 3. Loss scale management
+
+In 0.2.3, we add a min and max loss scale to the loss scale manager. The manager still adjusts the loss scale dynamically, but keeps it within the configured min and max bounds, which helps users avoid a loss scale that grows too large or shrinks too small.
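+
+A sketch of how this is typically configured; the keyword names `min_loss_scale` and `max_loss_scale` are assumptions for illustration, so check the `OptimManager` signature in your version:
+
+```python
+import bmtrain as bmt
+
+optim_manager = bmt.optim.OptimManager(
+    loss_scale=1024,
+    min_loss_scale=1,        # assumed keyword: lower bound for dynamic loss scaling
+    max_loss_scale=2 ** 20,  # assumed keyword: upper bound for dynamic loss scaling
+)
+```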
+
+
+### 4. Others
+
+* Fix `bmt.load(model)` OOM with torch >= 1.12
+* `AdamOffloadOptimizer` can choose the AVX flag automatically at runtime
+* Now BMTrain is fully compatible with torch 2.0
diff --git a/examples/BMTrain/docs/UPDATE_1.0.0.md b/examples/BMTrain/docs/UPDATE_1.0.0.md
new file mode 100644
index 00000000..da9fe86e
--- /dev/null
+++ b/examples/BMTrain/docs/UPDATE_1.0.0.md
@@ -0,0 +1,72 @@
+# Update Log 1.0.0
+
+**Full Changelog**: https://github.com/OpenBMB/BMTrain/compare/0.2.3...1.0.0
+
+## What's New
+
+### 1. Using PyTorch's hook mechanism to refactor the ZeRO, checkpointing, pipeline, and communication implementations
+
+Now users can specify the ZeRO level of each `bmt.CheckpointBlock`.
+
+**======= Before 1.0.0 =======**
+
+```python
+import bmtrain as bmt
+bmt.init_distributed(zero_level=3)
+
+```
+
+The ZeRO level could only be set globally, and computation checkpointing could not be disabled.
+A `bmt.TransformerBlockList` had to be called through its block-list forward instead of looping over its blocks.
+
+**======= After 1.0.0 =======**
+
+```python
+import bmtrain as bmt
+bmt.init_distributed()
+# construct block
+class Transformer(bmt.DistributedModule):
+ def __init__(self,
+ num_layers : int) -> None:
+ super().__init__()
+
+ self.transformers = bmt.TransformerBlockList([
+ bmt.Block(
+ TransformerEncoder(
+ dim_model, dim_head, num_heads, dim_ff, bias, dtype
+ ), use_checkpoint=True, zero_level=3
+ )
+ for _ in range(num_layers)
+ ])
+
+    def forward(self, x):
+        # in v0.2.3 you could only forward via: return self.transformers(x)
+ for block in self.transformers:
+ x = block(x)
+ return x
+
+```
+
+You can specify the ZeRO level of each `bmt.CheckpointBlock` (alias of `bmt.Block`), and computation checkpointing can be disabled by setting `use_checkpoint=False`. A `bmt.TransformerBlockList` can now be called in a loop over its blocks.
+
+
+### 2. Add Bf16 support
+
+Now BMTrain supports BF16 training. You can simply use `dtype=torch.bfloat16` in your model construction method and BMTrain will handle the rest.
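+
+For example, reusing the `TransformerEncoder` placeholder module (and its hyperparameters) from the snippet in section 1:
+
+```python
+import torch
+import bmtrain as bmt
+
+block = bmt.Block(
+    TransformerEncoder(
+        dim_model, dim_head, num_heads, dim_ff, bias, dtype=torch.bfloat16
+    )
+)
+```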
+
+### 3. Tensor parallel implementation
+
+For this part, BMTrain only provides a series of parallel ops for tensor parallel implementations, including `bmt.nn.OpParallelLinear` and `bmt.nn.VPEmbedding`. We also provide a tensor parallel training example among our training examples. You can simply use `bmt.init_distributed(tp_size=4)` to enable 4-way tensor parallel training.
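+
+A minimal sketch of turning it on (the value 4 is just an example; see the tensor parallel training example for a complete model):
+
+```python
+import bmtrain as bmt
+
+# 4-way tensor parallelism; the parallel ops (bmt.nn.OpParallelLinear,
+# bmt.nn.VPEmbedding) operate within this tensor parallel group
+bmt.init_distributed(tp_size=4)
+```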
+
+### 4. `AdamOffloadOptimizer` can save whole gathered state
+
+Now `AdamOffloadOptimizer` can save the whole gathered optimizer state. This state can be used to resume training from the saved checkpoint. For better performance, we provide an asynchronous way to save the `state_dict` that overlaps I/O with computation.
+```python
+import bmtrain as bmt
+# you can enable this feature in two ways: through OptimManager's or the optimizer's interface
+global_ckpt = optim_manager.state_dict(gather_opt=True)
+global_ckpt = optimizer.state_dict(gather=True)
+```
+### Others
+
+* New tests for the new version of BMTrain
\ No newline at end of file
diff --git a/examples/BMTrain/docs/logo.png b/examples/BMTrain/docs/logo.png
new file mode 100644
index 0000000000000000000000000000000000000000..2dd2f1cbf0e44358b5bad8c84e2a3bd6e0489662
GIT binary patch
literal 47918
zcmZ5|Wmr^S)bB7bFbp-o(49jK-Q_TJcY}1Qv@}R}H_|PFbm!11B?yjW6t}fpxHt?)(HKcr8tA0HDe)!wh;gv<>^J#L$XAZ+8?O8cl
zoSRw52r(oHV;Ba%koduAasI%`Wc}~I6qhaK<#oM1J$0!VLOnuzf#+k1Z~XbJXsKK@
zzF(0~nlrkG_ksTZ4w-prlkFUlic;^FxwY*-59g^j=UY}?hs8m&Vgr}SsCPy)S4PiL
z%-#D%9Aq!n&;2K)*1%UJ#vHJgrt1CBO`ude=+p9!AJ_HB_DKc9Z|D2(z2^h9pPzq!
z*7<%b3SdR2xTwA-PA$Z~My6PotyjwIT;uZK(_1trT9E8cmN-WG(e>wF^^C8>10fG{
zUm-T_H#671W_XWiE-W)_xK^P!FuW-0Vt^}UZ>r1W57Xpru{hULJVhOD*sd!Eq-?uA
zgF;w`(Exk41iOFCs4f8VDFuRS)l^?bgYoVjz#a~yQ
znOth7?u9kyMb^)>tIlpNDd?o{&N0#=nns%Z?`7P>i#`F@I{m1v3WkZWIxj9ph*CAQ
zv<6xho%`7nb~|B-iy~Ee$FAUK0ttH%h0R)wx7ZUg*T=f$A`<#wU<|E>G61KC(RB2M
zfa;Qv0Y0VHqw&Ju&Bw>bPWPVU>T9srB;JBrlU@pwY(&|(+~~6=UarY!2%fejGemL3
zPd%Jh^I>z+JjpKTuWE{GQiCs)@L*z*M}*Z>sl)AN$93Gnxc6K%MB~M;aZ4+1!4?E0?
z(MB$ierHpmiA1etS637MdA)rN%bKH&5jx)8-34iwJRKf&0CsKcBbqo7+qXY8Ij!N)
z{SS9DR)4%+o^SSd=FWWl%gZ~?7n$}&CrLm#EItBdUob=nEp*s!Pde~jsd|t@d;(&)
zaDW*xyCBi5E^IHsGBLzwq&oOk^5YIB>|*7ommIvCgcpLt6=5m1SSp~DY9+j8@{8&;
zVG!sN5jAL3mp$mUjSfEqnB|to1j~QE^q!aTgBs0zSE}&DyrabUOJ>4F;n_T!88=TW
z>2oCPKS{Qnw*E?L>UXILSOCizpYFFSM?^t6DvAju%rzgXQV4%I`HMh~9!E9+qVdQl
zRhHfC$XMhq9m!i{F;8~bP2d&L;*zv>^+1GP#DZTok;BlX`kJe#@6g36!5g;8Ek*EN
z?F7TC-~;cH2Me$B{@ae8G*QZc-FP@xo>gsWQXSJHj`h$6pjz3}(57EQPJ)MtuP8r1
zKL(Yj{5wa76+wSAU?K6m>AsULgrlYDto>*NU6CpqsBcmGNDCW!@lkIdLCgvrrX1;T
zg6Gu9Gq*uMmGd`k0D$GtzOW+pCe$2`>;cIdlpV}JN~0F0XonK=%^PxG)Fm{T1n?#}
zk9<;D&qdl%3*`>33m@$0Wdxs0rn68;JvKi1nq&P6f(k{by~4R$!{K%nyW=en=0$gm
z?y8!#g6BnOrdaa$B;sEO14qid^9t_geyPt3_@VInwh>sem)1Am8(@$X)E&__m;M#-
z*!;%)#|#m1uzyT5B~_S54931~oUQsqVdOgL$7h$}cPN|;a7ioa-Ci6sT}6I+#}dn!
zaFOc7L`ei-KC(f&DiXCq3rfDK)vr;q!SLWDpJ^ET?P)(Q2M7{K0DblL0tJG<;o_s~
z$P`WZIKmwe&Sdqc>SZj6S?|qB_#la
z)T2V5hDgAML$cL=hVTa{UntOrzYH6A|3}j*YPl|1^JKD>OoOOYxoI5*m562Z{umyi
z;l$K*lgG%K8g%M{+NrW#yb1@BjWfl;yHF9`5!^?ruQehUW_MA0Rui5NiA^K}y~I@)
zkJs~qv|fLh+hzu_=}bTyu{PEU&Cxw{l%;mI-*}3tI5P0w89ithy#JKz*0s~4?Q6HX
zib*7`)BEjh9pT7p{yR;`zSn)gDZZrw@x8P{Bldbsn})f&{O5wFtfHW7M<=J$vQ~N}
zn>=I^V~ANje)OB9H{NhDB%I+%##pi{>tr3CelFs$S|`L!!jl=1fZ=>%^L>7gZzbqY^bW
zFPc^V*-qn%1fPLdp18!!#6T7&eRh49peRr#nF`;wQnHQgP^u^;?BI#b$QUSlfjjhJ1;U&5w^DgwS(u{e!
z{rS1Q59-;y3#@05XGwCXh1tOj3~!?3u+TB|pp#_wZE&udwov1Aq7HuoEPF@3w@XvT
zlVXc|{fvJ_Vn}RZ8WXD{uAeHqu5sWOK~t%loN2`Zr-Lg!=hN@h>97fQ!RImy)v#Z0
zjJPN^m!Ef}!Za_rj>3s-w03b#fGajrKdftE{S0e%B$D?k%|#imWgkVI62_Khhld~G
z>3-Ci`3&;ElDOG!JIX*$n;<(wI1C%+S{cti=-n%UCB=Sw(7H)=Z!HXb4LIhCGN2%L
zk$8(k$YE8OzgKgYMxHkBiK;wL>-dW>=VTyJmF}XdUNHrrX%o4ziFGIDXb;V#*}IY!{N8-@$oXTWtFASSy@;-?x^S?epGe@f4IS;xoDuE{0P>p
zN42PA7JKxA?-|!@|KmR2E8lNd`|C}^g1DcaLI5T%juflagrf!uG&DnyWUqIkYi~Pw
zM*=y9d&K{owu(Le?7tAz#2H~X5J-f!m`@ucB
zTZX~EpZ?Q!c)a3pzizNq5^EP&CGZUQy~<4rMWA;j1C#o(p!H#~8BsB=|9e=hmN=PNc+?H66E%XRQjbATbhssioR74QsZA@I@P3-8_)D)k1S6pFFl(%C0XKzf1JyNN$aY
z8YvHXRALTDj{nJU)hzbY)?dM|$p=Iv&F%p=7PGki9zjQCu}TvY6Su~$?N->|nr)c^
z>4JICFf|T}Yiv4hE|8TE56yj-xQW~|RXh83;-^uDE9BO(kCZ&hJRaHYdWc*BSlLe<
z8@q%ECMFm={bS|(F}&v{|LM>-h8_x+D&VQ
z-Pd%;z{f*=gGAqyO64vnH!mO3N|m1^iH`Yx*UG@O=JU&?Xj5e!daj%x%c9YW@yjQ}
zI`WTYRDu{TxOkR(`;uqa9ha#D<>JCYB)VCKHx{;tlPg>i687Zov~~Xb?@47@L9ME|
zcXOuwRO9$htd3qQFm;kw90hVrgz)coTlOSG*m|(MUGj2Jj|jul9?soU>H^}eVXNHG
z`QRA1(U^DQ5$7W6-Y7-)j@0_87Hpib#Vt!vDg#~s!+$RjJT#3IO7FWRnj>F5Ab5T6
z#x%9HwXLfI#m2|qed+Zx=CynsLe%p`-Dc#E*qHcR>&RpiU{mwrL;abWRpIs7q*1`0
z{^C#Tz@wq078WTUP+O6?P)qTqDl~iKjW=+&61T`o?^vFBR4vWclsqb3OZQzdK$@h!
zboV%+rWr}TcD+FGq006L;G(%0qZjzf;>yDb7m^ow;1v#bP{
zB7+N;?k2|P?Q>WRKg6Fc+Qhx-I?B?oCD$>>?a&-bh)-l|qhTlFT8&c`2eim6!OMI9DHGlXR+uFKI&_EPfizdyRorhm}jz{S~V`|
z*e<{`wh_*$0mL?n+$wH3>7kM$2ItPRo^7v8tSsP&;bT)|?y0AGry$X|V7+mzavi`G~D-oyK%tD{i=~ze7)h|y8x{~yp7bGu=
zqB}VwSccvd*|Ny5X{`SY^rlKlEr}It1pV?!c+c%$k=1>Cw5Y7TVan5Ch0k_j1RhRY
zz`3mmVy9%WvoKqYwiF`s4F=;x@!Gx~V$DUENpY5!!_px9>$O#ft=HxzzSjB1=r1pW
zREbfuuqGbsitoJq4Js!)b7AU{VLfqGh?%sE;cgk(DmL2?3Rc(YCjy7MVp<1_%oF3>
z(#h@^>te$|e~E9uDSbslK~m5scr$?B&br%e&L{D{a%U2L!1>FvpWkNb(qtP)gfx9LK-gV^DmuM`b>l~=P046TslbrB-~^|)!u*l|Ei;H2&6
z8`js{`0Wvh&L}#F_eb#oTI0G6{q$ahTR68ClLw!)=;B5^95-h=;T6%t>Fsa{nlwEl
z%rDu6U^QhHPW#XsvQ)9B3)pVfp5CP7y~P-wDn6-?+G`BsEpUJu_72+n{?(|nH$@X_
z!Iw2rm3Tk}jL~bDL!v*PjxurDD|4KH#|Yh}k=9L^dLGMA8)G{L5*$MvFnRM6x>?5i%^k*GuukTrp
ztG)foydCCjsC3m$S)&^cbyz2$bBWFJBBElUW|OFkb#qnUl=-@h}oJ%
zhYOJ^98Cu&uBlSabSh34*}b98h_mrdo7J=7e4~+oY)QgK@PIj<>Sc(InBx2QbSiMp
z&hzF4q`fiVUNwLfKw^lV`!z79!DltYQVg*z(zgxC^CpFBHPB$fSPyX{fi$9}8iy5U
zkqV_cV$T!cQFxoMsOkI{@^&;?6|@weUn8z@F=lpET7CNlUn%vw;n7ZvW=z8p`UYv6;?immCNCh!QIKYHq|sqI4-xCi6=He
znoY&&%^NPem-7lwgO}69hmsJ@TBBq3looeN|Axb;1H7ra!bJh;mxKsf0%yX|mGTMV
zlEd5Y6of_QJ(}+<2w&K0F^%zu(e%HebjH=W8v=vdn_sr%p=|c{xeIXgf}cbTpHn)x
zNr{PV1;*DV%Oo?SU$E0ShvTI$_jeTWYouujrKf7n
zrO2vz2;S%wt6S$iCZtjnM}ZxpqT{PDQ&UY62?9^K~P-3Gy|gK>Xn?5Du|w)DW~RYzzCVxN-Q9N(PS&`r=5X`)>04NUQ+I)ED12
zG+7z6tM*{2Ns2xL4s>Cp?^!B2Mjs%g%1w`~?I(8qAj-)L(FuO=>zQ}X2Xo)-~U^$VI->WLr{|A+b9b4RvXQ541Km&j@lI8uk&NVvwn{Tg^A>kQ=C7u5Ku%qAw$09M$
zI%>Lxu{b#E)m<9iNQCd}!}?lT-?7&X
z^CK|LqzYBqQrec&g=XqBKT7}nOxqxHgcnlh--2rrrWVwU^gstd=)da0uD7%XJwBF#
z5;4x{^Ash>cq2lG9`K@X;^X6+=Yf|%(Ky7|z85K^EP-hX7=ssE3Vy39wWjY}vzNza
z0uT~+l>K8=YapI%nV{WwmAyi+FE0Kfd4r)I!AC>EXEZrL&Oa<9dHN=HlPtP;H#6^*_Rw`38C6e}qW2Ea*U3Ol!N+z)mxqLql%lXTo&mfoz;XbZnib}zJ$
zykFr9IaV(y`L$ewzgdRFWgm7Ag!HO%ix=#UBwn{>t{O_(aP4=a>I|RUZX6$$UeMfr
z@!P6OT9Ix-Ae!u8K&$a^5j0Wm>BQ%~%&Is3EJ>#I70N5tE$$?L#6_tM6r}54@X?u=
zCl{F(Uf*Mlt=})DcN1Y4wcroB*tU)+!)RlXp5-j~RV!Rrs?E5B8JzhnE%4ZxrFEns1YwPEfER6D(rdM4P~R
zG)Hj}pjc|^C0ix7id^ci
zKjBpBz^fRU;k6WD;E*w9+_(vt&2Rf>^PN>mw~O
zLnXqE$Hp1RSK1=r!9sHg^ekfxex%BTjlR1}9tN&_gp@2+eg*;j0h%@uj{-MyW?%X;
z)pKPu8Z{%R2;`rd&;;*pasf-1mh`I7IlMN60LJ-bqnSaxzb=K3kA`>My%FZt>o6b!
zvv4+u1?$KAqi?sni~b_BJJgbWnOy5{Lo20f%mF-YXBYU19E@sX=yo^Xi>MsntR-sd
zMrn8`W7`ZlzzeXNe02#dnaWn_$z&hm2K^;{vMtq1{O0i^m*E;#gTc1}WuyDuzh}Uqm@xQxyIz>0~Y^HOs
zV&)0V7}7YNR+d2MTB6R(%H94pE-bg!R&H?5VDk;Qi2_^u2}g|R)BQ0EAX^b>iXgJo
z+T#ttWn?)I>OAtirdxwMtjd8cgO6A~y++Fhm?k&%VV$Ifv63q*TQO_uN#c+V@RrOZ
zn(>-u(uN+6e1|FQ5zjjU&OhtKEu_%NBdjSB9@-(ueS`Unu!KeMJF4Dj=1i@-FRsJY
zrb{qkxWcH_ko`sQ$kS;DDoRb^A=3(zWUcV+R`g%+UC6iXp-3XR+IWORz|^_1E8g>N
zls|ys1WoL;VqJgls&f1URnvn_8vp1B?5UY_Sjf>@81=TnDHFKyll=LDQXm%%_nXFSziuc*yDVAD;jmYZR>!0*6pT=Wlynipa249+
zNr$bE+f~KQz>m!4WrlS|j?I4CtUnf6D7AuxTtxUHb5Fu!yvf%1xRk_02}vdWO=}4i
z3e)rGQ}?cf39AC}^>HYk{wR;s-t!w(Ciij(6AhBR%o!5F1K$?mlJRVW0>XQWBf=fkFaWYx^2he*fMM@b%?WV(=Wy`kDYvXiWvi
zS+yT2LqWL1NT0XS+owm;isL5vKjqG#GOP=q)okcTdh{b2&(Px|6lM%vli)-`XNqWbWpfw(v19=(DoJoe0oKo3DLNlcr+
z`~%{D4Nr0JhMVms?qtTX7+!RTF(EN#(AgA&cVoj4*_`!(%W)d>pqta3GB(zt!DnGg
z{!BdxL-5ga7RuJ$OslH81%aJC{o`P1Jw~IK9hrW3giBv9H5J~ULv0{0!>gW`?3zKY
z@4Sd6PL^?O$->B=&XJnkL@{j7SVUCS=3pedFX{For`f43##54gzaSjAsDg;C?wJS{
z%{4sZwf?S3G9qTqb&WY`;>sUsD+&pj2oa#(w~QD$qs%5}Y{3*e2DB@4+jZ}AWgK7X
z>$+k1CtpR9YzY3lQOv@a$F*psJ$mpQH^j4+IS5pUC7y5FnbYCa+1@fF&IXjwXo`;`
zos|)w8Xu$l`(iZaUD|FfcxU3o4>+>0|>vW$(1hA2-li@|T-MLzxcud;i0P?7kjmxl9IoFdW3ciS|O
zuwiZKPYjLZt4n3Lr_bhAb7uF|0Xxrd!?;*0vcG{$u;FE3S>3vipW56@^&RO0nL5!q
zoLKo%dFEeNq8k&Q5$`Y1J#Rr>>p>P~-W|p(7(Y3{Yx(;%dx)4zLo`dHMRB#*32>(%PR!Zc7BSWYtsJLBx
z^!5)A$z@9#1v?Ax{5-sZP!7g;sN7VJVng2pPsDXeXX3!076S?`e6VY@-~Th_yu0tu
z%{Xou8O%R|EgYpiE*4Ij@V*;O;%v6>UpNE(Yg!G{
z$V1#bmk}6IBA?f3rS|68#PMLTe&f&%5Gyt}n1oslZN}38u3|g6Xd{0kj3kX*8!d$i_nVE5I(4#El+H8G%SE}$N#JCYz6ZycN(1;mM
zLm-OOj%la)xSlabu1a~J>228k*v*nk`=gV|K?uG#=)NGjiV;uu#Qzp^M^ta
zKn5J-^xqvlmalTX2lt$^(i^idS?w_A=mYe9WR>Q05$)
zXA5EeK7_e)DsNKX!#NyR9Gt#`=RoQi<)*vG$lxr!CTC9cVUE-DylId7EQ%gz-~hlN
zhGhFY>a!QmG-$s|Ql&Y{Q4$7<&QecpBbOXL>9WnYUP(`5p6nN=u$C3VT8S=m%F7^7
zf9=nNsfN2H&uLqc)>kZ$ShmQM#vi`=2irvav-9#`*q1U2tB19fgROs_OyO1}c~j(P
zM$bImLT{uT)r(Gu-4WU$>9Y`9NwWXR)HkF2`isZcZ_mi6C)m|>6s~oQ
zD4H%$B>IGXX|F}CMxI9Ec=Rl?tg3qejdYB9s`9q3$(s;qxQFbbV8Sue!Z7AMp6cF<
zr+miWmm&%yV&-_YbyiQT{LwtU4RDJ?GjlRV>I4+38U!&Jjw3de{oYXg>4s6qC
zr2D$np%<$&eB`JRIw@rB(J!_@ZVQPs
z565FD0BAm9;W$2%zruoBksu2nJ;jG}+fT>gbB1~M!Y-lYYF{O2?H^4!x=l|Xy1KgZ
z3P<7qi%$0WfQu}X#OU&^zg0xW%>G@sGR07IZf|;Wz-@Nmv$qJ|kqb+G%fMPUKRuy`}
z;9{oSy`|irvd3{fzf0LF>Ay|0)jeZwt%+zlhJ;&eyBIm|G=)QcvHQ)F2hd~u;=jG~
zkS&qI6F$zmCZt6eM&R+0l*kuXy8rClsWx@^GvU@CSb>6DuS
zK9hJXVB~4>FmVo>X*Ju73|5wH11lJ}jB6U`dElTZK^x6HAgPWg(1PzU#H<27w}`J0
zB)!v+B1tiY`RAm*S~lL^zCKa0rgBpBs=|6(PkF64_gTI5=XqA$;sZ#5`Q?*JBX>wB7M;lx!$k|NC
z{@X~lVS>z(nPS^Nt%QB63W2KBCliJVDT0i&aESmfT{Q+f>^I<9x;97dvFUs>Pl|cO
ztmP&UvlCry4leooc=-PA?oOkE*h=f-7p;aH2l)B&(C4X}T6Lkq2qi@T!?3EdBGT!c
zB+A<4iRZJ^LU|~6xVBaSXy)_U^{-{4AwDHGVn@gz5D)(%ix-X?O~090opx33RlYlE
z-A?<9XW1#u#gsPRHv+xn#H0rWmBQi-SkPs{zS-
zRpW7yOJ;j5oF59Gt&!P#iv{(qyc<1aWrAZJC(dN0eo}YA6^_BVkZ=ls#(SOc_3{WCAxE^1z%E8x43A$2|hdQ4$8G_5o3hyxKGe
zZc6*{0hTHLSIWUwt1sRzRkaTpCx9_8DTWR|6dwmq>%X~=I{X~(!()CUg4RnXgVv-m
zezzrHy(727yC4WQCf-5{T;CZ;4K(M8eiJ8^&>To1i^jQZSpu#Nk(}B}rhY{%41z>M
zCXwU=e}A9zWzY2Tp6Am_V3h4i;rBAl7J^6eUEJOlM@m_UMIC>T5wA3QT?1MduceKWd)H7;+iXEB}4BGZY$G
zg!V=;>OHIC&TJNG;`;^gs2nrSKKoTw!Pp@g&r{>9&-j*=F~aR**alhAvA^Jnl5OZu
zO`k-7gnwWu6DyYXhhy__r(zClE-VFcgWbpQ;;^K%jctZ=BQ$dvX$~uvv@=6-Lm_i2
zB#Jcc(Z{TRGfM&s7M{wGji-Y(=K3y>^oSOC^paAxaLk8bUOgDg;pzeBeyQTwhYj=o
z|CmmTHaqLNS^Z9)Ld?-tw{J!>;TNx}6;O2-C%wb|=bJQ6x591)Q}H_!s?vowP_9xk
z?UD{ew1*GO;BpFgILRw`PMw_z6hoS)1G%y>slnp-Q62|oWIcVqh-~=nJ&iOnhO*Nc
zWs$~T5gTf8B@w;&%l#|XN~@P%;1K+VOCU)SvHrEf-Cv{+sV2yEMSxrV%q1BC0ZEPR|
zS9z7CJ$G}t@F1NMOG-(!>q4;J%s-9sPZ{eYuzZH?djDanJX)6T0(i*iBOUl-{>*^J
z_X~Zh>fh~|acopoM6yrmA)N!jG6-cRfe9p&BX5RybtlR;MC9!L3y}9*M8L>B*Jg6&
z8)(jO_jp=?;?VgSvBnJ6Tuv(wQfD1(+P^dF*-a2}zis!77Rns2Ptqy$*Q{|g
z*($gNJ%R&O7JitF@92R-42>yCsGk;*C!!Y~)c-Q!1ZjCzlmD3Q|HmxZaMkC}6-Rj!
zO0Fz>N52ihOtJR^;GYL-NDl0f4MAu@8k-Y_vT`NliaFJ#MGhc6QV$~!j
zE)iGGIU#=5=rqh6oL2lY${buCkp2)=LV+u>+U5OOf?K$zA^1*ZJ9Vg`pX|#oMtj#N
zB_7a_>=)=k{hDOq*&Uq%H?>OD6oMR}4@;1Z9vd&3RkMh~#TDLHeA@kqhz&hovC@5C
zv&rv2x*I3lU}GZ0P$C1Dc!IVq^;768b8v!R(Eb7g}mt<`VT0tGH=J|HLVXU7
zVjK5bb|Hv*8DC1ShXLEkBG8H>Tf;q&G*SN)J*JOq18fb`OcLmjO693KrkqolB#RJr
z`+oa8TvHd!@aG~{xpDB%uf@gY<`aQe{&YQ_RO?P($Cw`M+kWc}YXO%hSCng_y``n5
zbZjQ03Q}#w8%`1n)5wd_QLXm(x+D$fkv(u1BM&C&D_i1{BMyMJ*#lSeka4nl@K|%a
z!>lCvqcY8zuSSOOik+W*&~z1w{JTV%eAy^1+syVjlPp@op5vN(hV~(nLCOzp)?}eP
z#&YoXM7Kk(mL+11P=l-r+oEua6&cn1K-<+%75i?FB{{n7rGzUP??0W~OHR7*nA3@F
zA*VE9()jy9OiSbaY`;_WXX=uj6bb1Rj?Lg->6G}-No6j2+wV@=7o&|gZ>RODdJ1f#
z_*&HIc{~zoz@)jG95_8#hu)p#I#6(b@%?B!FbsezW%G7^!gJR08e-dpyAzfE!J
z!m!iA<&Tv6Det?pD(usfQG<+XRqDjEs$IR4C6QZJ+#9_rsIwMU&iz|a7-q?-f#BD0
ztvfa=@FW!Db=V=-D%=+$+NBW}kGx<%kipogFVp5yQ_bjgBX8m6OpKF^)N4k#+#_o_
z%;#RJlo8$CeAu6wYOZw;IWfP_IjG;i`Uiafx6`DeBLDfN$D_g5?^Zm?ZAND0IWJDb
zN%*)xhx(IzPkwjv%mu%<%QMdjp*?uFPTXlERGsiFY6Z+o`Cy1&MAaamwY5*$5Z~P%
zP(O^=tzVR&WgY#O2N*KTmBl_xLR%lA|oX7U~J7D&Ye|$pn&m_+*
zJ<+et%+C-699W-e$G*hz3j@p>(2{6nOYV!Qw;GN;(=A0D;~h7WMAkz0l{Ky3v_D4t
ze3xDN3B1DtD!Psko6)9&XJos)Ih{=sb%OuSUCxY|q
zd5%)@LBK@vyK?ZjN&W$X5}TAFiKo@9DtX)hf4!@pGy(e*o)5B-XA9f{J+6|`kQ%yt
z;Jg$!CX8YvZ$g+j3q?F&p@SvRAeaTZWu{<~*$AMt8^jn~gDoCzVoqu%WM;|LHF@j+
z#f!%iUz-dhLL62fa6K4L^xS7em_7Oup^-2IX5R}yUHp2h4H^1vsrF$KCJi=*dSt|0
zq5%8{FkO)iX)$)_p^y!V<@^*b#vg0^zKkeClAXV1sV$h;iF2;vH^K(={&ycJ;=gNi
zt|7;O3vdDb*FJom%6=JKDOe)2%8B*sbkP8Nmv$f-e}~sWLXugcL)vCRjb0G(ggH1P
z$Mc+P*JInTi`bg-9w{`1m4{>Qlevcf7fl21Dn(sOQr;Y=;X>q$#HGRiVg
z&3V)lT&Ae8k2!$-v$jB$3~j<&%7xiYXMh@7o_7qeTYNP
ze$K&*E2Ye6ZjrkA`l_Vm8l6B6-Tg>n++J$&`Evg@*0LWoO`p4H_y-(*w})T*IF3d^!tii&mn*Xx7Oqjl~c%^@z<*nfA9nf
zVejpW)9;wmY+y{SLLNy}w{u-e&e6;-hkFiCM*KI5+A
zbp>-y)GcbLK;X?8m*QDV$I1*LJ~wQ90jh3PtL*dVm~H4^WWcDjkgVUvznUOb>`!3SMkSj0C}1?XFgBo3@A
z4P+=!w(-gkOUc8=VZzE*;E7#L$ukS9_zP2!;8Y;aj4dYctA6o{e-ZT&ajt~K(