[Dev][feat] Support CUDA Graph capture offloading modules #3219
# Fine-Grained Activation Offloading

Memory capacity is increasingly important with the rise of extremely sparse MoE models such as DeepSeek-V3 and Qwen3-235B. Fine-grained recomputation reduces the memory footprint at the cost of extra compute, while offloading can exploit host-device bandwidth to achieve nearly zero overhead. Fine-grained activation offloading reduces GPU memory by asynchronously transferring activations to CPU at the granularity of individual submodules within a transformer layer. Unlike layer-level offloading, it allows precise control over which activations to offload, enabling a tradeoff between memory savings and PCIe bandwidth overhead.
## User Guide

### Basic Usage

```bash
# Enable fine-grained activation offloading
--fine-grained-activation-offloading

# Specify which modules to offload (can combine multiple)
# Choices: attn_norm, qkv_linear, core_attn, attn_proj, mlp_norm, expert_fc1, moe_act
--offload-modules core_attn attn_proj expert_fc1
```

### Offloadable Modules

Each module offloads its **input** activation to CPU during forward and reloads it before backward:

| Module | Description | Notes |
|---|---|---|
| `attn_norm` | Input layernorm of attention | Skipped if using `IdentityOp` |
| `qkv_linear` | QKV linear projection | |
| `core_attn` | Core attention (softmax + matmul) | |
| `attn_proj` | Output projection of attention | Must be used together with `core_attn` |
| `mlp_norm` | Pre-MLP layernorm | Skipped if using `IdentityOp` |
| `expert_fc1` | First FC layer in MoE experts | MoE models only |
| `moe_act` | Activation function in MoE experts | MoE models only |

### Tuning Parameters

```bash
# Minimum tensor size (in elements) to offload. Smaller tensors are skipped.
# Default: 1048576 (1M elements)
--min-offloaded-tensor-size 1048576

# Fraction of activations to offload, range [0, 1]. Default: 1.0
# Useful for partial offloading when PCIe bandwidth is a bottleneck.
--activation-offload-fraction 0.8

# Reduce the offload amount on higher PP ranks (in bytes). Default: 0
# Higher PP ranks have fewer microbatches in flight, so offloading less
# reduces overhead without increasing peak memory.
--delta-offload-bytes-across-pp-ranks 1073741824
```
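To make the size threshold concrete, here is a small, torch-free sketch; the function and constant names are hypothetical illustrations, not the actual Megatron-LM API:

```python
# Illustrative only: how a threshold like --min-offloaded-tensor-size
# might filter offload candidates. Names here are hypothetical.

MIN_OFFLOADED_TENSOR_SIZE = 1_048_576  # default: 1M elements

def should_offload(numel: int, min_size: int = MIN_OFFLOADED_TENSOR_SIZE) -> bool:
    """Skip small tensors: their transfer-launch overhead can outweigh
    the memory saved by moving them to CPU."""
    return numel >= min_size

# A [4096, 512] activation (2M elements) is offloaded;
# a [1024, 512] activation (0.5M elements) stays on GPU.
print(should_offload(4096 * 512), should_offload(1024 * 512))  # True False
```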

### CUDA Graph Integration

Fine-grained offloading is compatible with CUDA graphs. When CUDA graph is enabled, the following constraints apply:

- `attn_norm` and `mlp_norm` **cannot** be offloaded (they cross CUDA graph boundaries).
- `cuda_graph_scope` must include `attn` and `moe_router`.

> **Contributor:** Can I use the "moe" scope if I'm in a drop-pad MoE? Can I offload attention-part modules if my cuda graph scope is only "moe_router"? This may be needed since some cases have dynamic-shaped attention, so only the router part can be captured.
>
> **Contributor (author):** I removed this hard limitation, now the scope could be
- `cuda_graph_impl` must be `transformer_engine`.
- Requires `torch >= 2.9.0` and `transformer_engine >= 2.13.0`.

```bash
# Delay offloading until CUDA graph launch to hide CPU overhead
--delay-offload-until-cuda-graph
```

### Combining with Fine-Grained Recomputation

Offloading and recomputation are complementary:

- Use **recomputation** for lightweight modules (e.g., layernorm, activation functions) with negligible compute overhead.
- Use **offloading** for heavy modules (e.g., `core_attn`, `expert_fc1`) where recomputation would be too costly.
- Make sure the offload/reload traffic can be overlapped with compute.

```bash
--recompute-granularity selective
--recompute-modules layernorm moe_act
--fine-grained-activation-offloading
--offload-modules core_attn attn_proj expert_fc1
```

![fine_grained_activation_offloading](../images/activation_offloading/fine_grained_activation_offloading.png)

### Compatibility

| Feature | Supported |
|---|---|
| PP / Interleaved PP / PP=1 | Yes |
| Fine-grained recomputation | Yes |
| FP8 training | Yes |
| MTP (Multi-Token Prediction) | Yes |
| Mixed dense & MoE layers | Yes |
| A2A overlap (EP) | Yes |
| CUDA Graph (TE impl) | Yes |

---

## How It Works

### Architecture Overview

The implementation consists of three layers:

1. **`PipelineOffloadManager`** (singleton): Global coordinator that manages CUDA streams, CPU tensor pools, and chunk lifecycle across pipeline stages.
2. **`ChunkOffloadHandler`**: Per-microbatch handler that tracks tensor groups, executes D2H/H2D transfers, and decides which groups to actually offload.
3. **`FineGrainedActivationOffloadingInterface`**: Lightweight interface used by transformer modules (attention, MoE, etc.) to mark offload boundaries.
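The three-layer structure can be sketched as a minimal, torch-free skeleton; the class names follow the list above, but the bodies are simplified stand-ins, not the real Megatron-LM implementation:

```python
# Structural sketch only: real classes manage CUDA streams and transfers.

class PipelineOffloadManager:
    """Singleton coordinating streams, CPU pools, and chunk lifecycle."""
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.chunks = []  # live ChunkOffloadHandler objects
        return cls._instance

    def new_chunk(self):
        chunk = ChunkOffloadHandler(manager=self)
        self.chunks.append(chunk)
        return chunk

class ChunkOffloadHandler:
    """Per-microbatch handler: tracks tensor groups and their transfers."""
    def __init__(self, manager):
        self.manager = manager
        self.groups = []

    def register_group(self, name):
        self.groups.append(name)

class FineGrainedActivationOffloadingInterface:
    """Thin interface the transformer modules call to mark boundaries."""
    @staticmethod
    def group_start(name):
        PipelineOffloadManager().chunks[-1].register_group(name)

mgr = PipelineOffloadManager()
assert mgr is PipelineOffloadManager()  # singleton: always the same object
chunk = mgr.new_chunk()                 # one handler per microbatch
FineGrainedActivationOffloadingInterface.group_start("core_attn")
print(chunk.groups)  # ['core_attn']
```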

### Offload/Reload Flow

```
Forward pass (Layer N):                      Backward pass (Layer N):
┌─────────────────────┐                      ┌───────────────────────┐
│ group_start(input)  │── register ────►     │                       │
│                     │   tensor group       │ group_commit_backward │
│ module.forward()    │                      │   wait H2D complete   │
│                     │                      │   pop tensors from    │
│ group_offload(out)  │── D2H async ───►     │   CPU → GPU           │
│   on d2h_stream     │   to pinned CPU      │   on h2d_stream       │
└─────────────────────┘                      └───────────────────────┘
```

1. **`group_start`**: Registers a new tensor group and hooks into `saved_tensors_hooks` to intercept `save_for_backward`.
2. **Forward execution**: All tensors saved by autograd within the group are captured.
3. **`group_offload`**: Triggers an asynchronous D2H copy on a dedicated CUDA stream (`d2h_stream`) and optionally releases the GPU storage of the input tensors.
4. **Backward**: Before the group's backward, tensors are reloaded from CPU to GPU on `h2d_stream`, and the compute stream waits for the transfer to complete.
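The save/reload mechanics can be simulated without torch. In the real code, `torch.autograd.graph.saved_tensors_hooks` supplies the pack/unpack callbacks and the copies run asynchronously on dedicated streams; everything below is a simplified stand-in:

```python
# Torch-free simulation: "pack" runs when autograd saves a tensor,
# "unpack" runs when backward needs it again.
from contextlib import contextmanager

cpu_store = {}  # stands in for pinned CPU buffers
next_key = 0

def pack(gpu_tensor):
    """Save-time hook: stash the tensor on 'CPU', return a lookup key."""
    global next_key
    key = next_key
    next_key += 1
    cpu_store[key] = gpu_tensor  # real code: async D2H copy on d2h_stream
    return key

def unpack(key):
    """Backward-time hook: reload the tensor from 'CPU'."""
    return cpu_store.pop(key)    # real code: wait for H2D on h2d_stream

@contextmanager
def offload_group():
    # Analogue of group_start/group_offload bracketing a module's forward.
    yield (pack, unpack)

with offload_group() as (save, load):
    key = save([1.0, 2.0, 3.0])        # forward: activation leaves the GPU
assert load(key) == [1.0, 2.0, 3.0]    # backward: activation comes back
print("reload ok, store empty:", len(cpu_store) == 0)
```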

### Warmup and Adaptive Offloading

The first training iteration serves as a **warmup phase** in which the manager records tensor groups, their sizes, and the execution order. After warmup, a `post_warmup_callback` runs to:

> **Contributor:** So we cannot capture CUDA graphs on the first training iteration? If so, we should assert `cuda_graph_warmup_steps > 0` when offloading is enabled.

1. **Reserve margin**: The last N groups (by deduplication count) are kept on GPU so that reloading them never blocks the compute stream.
2. **Apply PP rank delta**: Higher PP ranks offload fewer bytes (controlled by `delta_offload_bytes_across_pp_ranks`).
3. **Apply fraction**: Only a fraction of the eligible groups are actually offloaded (controlled by `activation_offload_fraction`).
4. **Print summary table**: An ASCII table of per-rank offload bytes is printed for debugging.
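A hedged, torch-free sketch of how steps 1-3 might combine; the function name and exact selection policy are illustrative, not the actual implementation:

```python
# Illustrative group-selection sketch for the post-warmup callback.
def post_warmup_plan(group_bytes, pp_rank, *, margin_groups=2,
                     delta_per_rank=0, fraction=1.0):
    """Pick which recorded groups to offload on this PP rank.

    group_bytes: per-group sizes recorded during warmup, in forward order.
    """
    # 1. Reserve margin: keep the last N groups on GPU so their reload
    #    never blocks the compute stream.
    candidates = group_bytes[:len(group_bytes) - margin_groups]
    # 2+3. PP rank delta and fraction shrink the byte budget.
    budget = fraction * sum(candidates) - pp_rank * delta_per_rank
    plan, used = [], 0
    for i, b in enumerate(candidates):
        if used + b <= budget:
            plan.append(i)
            used += b
    return plan

# Higher PP ranks offload less (fewer microbatches in flight):
print(post_warmup_plan([100, 100, 100, 100, 100], 0, delta_per_rank=150))
print(post_warmup_plan([100, 100, 100, 100, 100], 3, delta_per_rank=150))
```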

### CPU Tensor Pool

> **Contributor:** GPU Tensor Pool?

A `GPUTensorPool` (which, despite its name, holds pinned CPU memory) caches allocated tensors by `(shape, dtype)`. This avoids repeated `cudaMallocHost`/`cudaFreeHost` calls and reduces D2H latency after the first iteration.
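The pooling idea can be illustrated with a small pure-Python cache keyed by `(shape, dtype)`; the class and buffers below are stand-ins, not the real `GPUTensorPool`:

```python
# Sketch of a buffer pool: reuse host buffers instead of calling
# cudaMallocHost/cudaFreeHost on every offload.
from collections import defaultdict

class BufferPool:
    def __init__(self):
        self._free = defaultdict(list)  # (shape, dtype) -> reusable buffers
        self.allocations = 0            # counts "cudaMallocHost" calls

    def get(self, shape, dtype):
        bucket = self._free[(shape, dtype)]
        if bucket:
            return bucket.pop()         # cache hit: no new allocation
        self.allocations += 1           # cache miss: allocate a new buffer
        return bytearray(8)             # stand-in for a real pinned buffer

    def put(self, shape, dtype, buf):
        self._free[(shape, dtype)].append(buf)  # return buffer for reuse

pool = BufferPool()
b = pool.get((1024, 1024), "bf16")      # iteration 1: real allocation
pool.put((1024, 1024), "bf16", b)
pool.get((1024, 1024), "bf16")          # iteration 2: served from the pool
print(pool.allocations)  # 1
```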

### CUDA Graph Support

When offloaded modules are captured inside a CUDA graph:

- A dedicated `cuda_graph_stream` runs the captured computation, while `d2h_stream` overlaps D2H transfers.
- During CUDA graph **warmup**, offloading is disabled (`pre_warmup_hook` / `post_warmup_hook`).
- The `delay_offload_until_cuda_graph` option defers D2H launches until graph replay, using the CPU idle time during `cudaGraphLaunch` to issue offload commands with near-zero CPU overhead.

> **Contributor:** unless using the "moe" cudagraph scope in a drop-pad or sync-free MoE.
>
> **Contributor:** What if we only capture `moe_router` or `moe_preprocess`? Is it still true?
>
> **Contributor (author):** I think so. If we only capture `moe_router`, `mlp_norm` works as the input buffer of the graph, so it is not offloadable. The only exception is when we use the attn+moe scope for a drop-pad MoE; then `mlp_norm` is totally inside the graph, so it is offloadable.
>
> **Contributor (author):** By the way, you cannot capture only `moe_preprocess`; `moe_preprocess` must go together with `moe_router`.