Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
1219a26
renaming golden values
lhb8125 Oct 29, 2025
ce6e661
fix bug: accuracy issu because of recomputing and offloading same module
lhb8125 Nov 4, 2025
d04d741
Merge branch 'dev' into hongbinl/activation_offloading_fix
lhb8125 Nov 4, 2025
2fe4aeb
format
lhb8125 Nov 4, 2025
fb3f7c3
update golden values
lhb8125 Nov 5, 2025
5001e2b
Merge branch 'dev' into hongbinl/activation_offloading_fix
lhb8125 Nov 5, 2025
9937890
update golden values
lhb8125 Nov 5, 2025
6c83118
update model_config and golden values
lhb8125 Nov 6, 2025
33a38f5
format
lhb8125 Nov 6, 2025
6c76b07
update golden values
lhb8125 Nov 6, 2025
4d83f69
Merge branch 'dev' into hongbinl/activation_offloading_fix
lhb8125 Nov 10, 2025
8e72b44
temp save
lhb8125 Nov 18, 2025
1646f04
support offloading+cuda graph
lhb8125 Nov 25, 2025
43973a7
Merge branch 'dev' into hongbinl/activation_offloading_cuda_graph
lhb8125 Nov 25, 2025
a177cf5
support PP=1
lhb8125 Nov 27, 2025
f7cfbba
support VPP
lhb8125 Dec 1, 2025
6d475ad
bug fix
lhb8125 Dec 2, 2025
089da6c
support VPP
lhb8125 Dec 8, 2025
35b0f97
code refactor
lhb8125 Dec 8, 2025
df09b85
big code refactor and format
lhb8125 Dec 8, 2025
06ef4e2
Merge branch 'dev' into hongbinl/activation_offloading_cuda_graph
lhb8125 Dec 8, 2025
12cb8de
minor fix
lhb8125 Dec 8, 2025
3cf19b7
minor fix
lhb8125 Dec 8, 2025
d0fc888
dump offloading information
lhb8125 Dec 8, 2025
bc47650
Merge branch 'dev' into hongbinl/activation_offloading_cuda_graph
lhb8125 Dec 8, 2025
b7c0fba
fix ut
lhb8125 Dec 8, 2025
b18e69b
Merge branch 'hongbinl/activation_offloading_cuda_graph' of https://g…
lhb8125 Dec 8, 2025
ae4e2b5
format
lhb8125 Dec 8, 2025
b797438
fit ut
lhb8125 Dec 8, 2025
6cec22f
delay d2h copies until finishing cuda graph
lhb8125 Dec 17, 2025
5a150c7
minor fix
lhb8125 Dec 17, 2025
e6cf8b5
Merge branch 'dev' into hongbinl/activation_offloading_cuda_graph
lhb8125 Dec 17, 2025
60e3082
format
lhb8125 Dec 17, 2025
256d79d
fix ut
lhb8125 Dec 17, 2025
9475e3d
format
lhb8125 Dec 17, 2025
6daa2a4
Merge branch 'dev' into hongbinl/activation_offloading_cuda_graph
lhb8125 Dec 17, 2025
93c0827
minor fix
lhb8125 Jan 6, 2026
f22a194
remove changes for cuda graph
lhb8125 Jan 6, 2026
a9d6633
bug fix when cuda graph is disabled and fix for dumping offloading info
Jan 7, 2026
9d766e9
refactor and update ut
lhb8125 Jan 8, 2026
9d1fe34
format
lhb8125 Jan 8, 2026
08b46aa
fix ut
lhb8125 Jan 8, 2026
d33b3c4
update ut
lhb8125 Jan 12, 2026
884c335
update ut
lhb8125 Jan 9, 2026
df2e839
fix ut
lhb8125 Jan 9, 2026
569f347
add version check
lhb8125 Jan 12, 2026
7f8109a
minor refactor for fine_grained_activation_offload.py
lhb8125 Jan 12, 2026
19b35c3
format
lhb8125 Jan 12, 2026
e867835
Merge branch 'dev' into hongbinl/activation_offloading_refactor
lhb8125 Jan 12, 2026
dd874a9
support partial cuda graph
lhb8125 Jan 13, 2026
b8c0b79
fix bug when working with a2a overlap and cuda graph
lhb8125 Jan 13, 2026
994fc5a
support offloading less for large pp rank
lhb8125 Jan 13, 2026
2688c7e
fix doc
lhb8125 Jan 14, 2026
aab8455
code refactor
lhb8125 Jan 14, 2026
0a92566
remove group_start() calls
lhb8125 Jan 14, 2026
0df5134
add comments
lhb8125 Jan 14, 2026
62d36f2
fix min_offload_size and update golden values
lhb8125 Jan 14, 2026
17c0eb9
minor fix and format
lhb8125 Jan 15, 2026
d979f1e
Merge branch 'hongbinl/activation_offloading_refactor' into hongbinl/…
lhb8125 Jan 15, 2026
53a21e1
rename group_commit
lhb8125 Jan 15, 2026
600cfe7
fix for graph support
lhb8125 Jan 19, 2026
09956df
refine offloading strategy
lhb8125 Jan 20, 2026
18a2d9e
temp fix for mxfp8
lhb8125 Jan 22, 2026
885164b
minor fix
lhb8125 Jan 24, 2026
688c2ab
support offloading fraction
lhb8125 Jan 25, 2026
2b0276c
free input of mlp when fp8
lhb8125 Feb 3, 2026
0c901a8
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Feb 3, 2026
e606424
minor fix and format
lhb8125 Feb 3, 2026
0cda334
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Feb 3, 2026
0f0d1ed
fix ut
lhb8125 Feb 3, 2026
254a2d2
Merge branch 'hongbinl/activation_offloading_refactor_cuda_graph' of …
lhb8125 Feb 3, 2026
4ba07bd
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Feb 4, 2026
e585f60
Update arguments.py
lhb8125 Feb 4, 2026
1b8050f
fix ut
lhb8125 Feb 4, 2026
410f879
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Feb 4, 2026
162e388
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Feb 5, 2026
504b3a4
update ut and minor refactor
lhb8125 Feb 5, 2026
01e2ad9
Merge branch 'hongbinl/activation_offloading_refactor_cuda_graph' of …
lhb8125 Feb 5, 2026
526fc35
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Feb 6, 2026
6264494
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Feb 24, 2026
7640bf3
minor refactor
lhb8125 Feb 25, 2026
15641c4
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Feb 25, 2026
dfc8161
format
lhb8125 Feb 25, 2026
e362f79
fix ut
lhb8125 Feb 26, 2026
2da15b2
format
lhb8125 Feb 26, 2026
221cc16
minor fix
lhb8125 Feb 26, 2026
726c526
format and minor fix
lhb8125 Feb 26, 2026
f19c4f6
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Feb 26, 2026
998d1b0
1. replace hasattr+delattr with None;
lhb8125 Feb 27, 2026
4bf0085
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Mar 2, 2026
716e12a
format
lhb8125 Mar 2, 2026
cd84623
Merge branch 'hongbinl/activation_offloading_refactor_cuda_graph' of …
lhb8125 Mar 2, 2026
61b589a
bug fix
lhb8125 Mar 2, 2026
0200121
add flag to control flush_delayed_groups in fine_grained_callables.py
lhb8125 Mar 2, 2026
c8bd90d
1. move backward_record() to te_cuda_graph_capture()
lhb8125 Mar 5, 2026
ddd67d2
format
lhb8125 Mar 5, 2026
e989b95
remove the knob forward_only when executing reset()
lhb8125 Mar 5, 2026
b481fa9
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Mar 5, 2026
19fe6b3
fix ut and reviewer's comments
lhb8125 Mar 5, 2026
cf04b4a
Merge branch 'hongbinl/activation_offloading_refactor_cuda_graph' of …
lhb8125 Mar 5, 2026
ce84682
Merge branch 'dev' into hongbinl/activation_offloading_refactor_cuda_…
lhb8125 Mar 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 130 additions & 20 deletions docs/api-guide/fine_grained_activation_offloading.md
Original file line number Diff line number Diff line change
@@ -1,31 +1,141 @@
# Fine-grained Activation Offloading (collaborated with rednote)
# Fine-Grained Activation Offloading

Memory capacity is increasingly important with the rise of extremely sparse MoE models such as DeepSeek-V3 and Qwen3-235B. Fine-grained recomputation reduces the memory footprint at the cost of extra recomputation, while offloading can utilize the host-device bandwidth to achieve nearly zero overhead. Fine-grained activation offloading targets offloading activations at the granularity of specific modules, so that the amount of offloaded activation can be calibrated to maximize training throughput.
Fine-grained activation offloading reduces GPU memory by asynchronously transferring activations to CPU at the granularity of individual submodules within a transformer layer. Unlike layer-level offloading, it allows precise control over which activations to offload, enabling a tradeoff between memory savings and PCIe bandwidth overhead.

Currently, the supported offloading modules are `"attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act"`, which could work with fine-grained recomputation to release almost all activations of a transformer layer.
## User Guide

**Features**
* Support PP=1/PP/Interleaved PP
* Compatible with fine-grained recomputation
* Support FP8
* Support MTP
* Support mixed dense & moe layer
* Support A2A Overlap
* Support CUDA Graph
* (Temporary) the CUDA graph scope cannot contain the offloading modules
### Basic Usage

**Usage**
```bash
# Enable fine-grained activation offloading
--fine-grained-activation-offloading

# Specify which modules will offload their input
# Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".
--offload-modules expert_fc1
# Specify which modules to offload (can combine multiple)
# Choices: attn_norm, qkv_linear, core_attn, attn_proj, mlp_norm, expert_fc1, moe_act
--offload-modules core_attn attn_proj expert_fc1
```

### Offloadable Modules

Each module offloads its **input** activation to CPU during forward and reloads it before backward:

| Module | Description | Notes |
|---|---|---|
| `attn_norm` | Input layernorm of attention | Skipped if using `IdentityOp` |
| `qkv_linear` | QKV linear projection | |
| `core_attn` | Core attention (softmax + matmul) | |
| `attn_proj` | Output projection of attention | Must be used together with `core_attn` |
| `mlp_norm` | Pre-MLP layernorm | Skipped if using `IdentityOp` |
| `expert_fc1` | First FC layer in MoE experts | MoE models only |
| `moe_act` | Activation function in MoE experts | MoE models only |

### Tuning Parameters

```bash
# Minimum tensor size (in elements) to offload. Smaller tensors are skipped.
# Default: 1048576 (1M elements)
--min-offloaded-tensor-size 1048576

# Fraction of activations to offload, range [0, 1]. Default: 1.0
# Useful for partial offloading when PCIe bandwidth is a bottleneck.
--activation-offload-fraction 0.8

# Reduce offload amount on higher PP ranks (in bytes). Default: 0
# Higher PP ranks have fewer microbatches in flight, so offloading less
# reduces overhead without increasing peak memory.
--delta-offload-bytes-across-pp-ranks 1073741824
```

### CUDA Graph Integration

Fine-grained offloading is compatible with CUDA graphs. When CUDA graph is enabled, the following constraints apply:

- `attn_norm` and `mlp_norm` **cannot** be offloaded (they cross CUDA graph boundaries).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unless using the "moe" cudagraph scope in a drop-pad or sync-free MoE.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if we only capture moe_router or moe_preprocess? Is it still true?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so. If we only capture moe_router, mlp_norm works as the input buffer of the graph, so not offloadable. The only exception is that we use attn+moe scope for drop-pad MoE, then the mlp_norm is totally inside the graph, so offloadable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, you cannot capture only `moe_preprocess`; `moe_preprocess` must go together with `moe_router`.

- `cuda_graph_scope` must include `attn` and `moe_router`.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can I use "moe" scope if I'm in a drop-pad MoE?

Can I offload attention part modules if my cuda graph scope is only "moe_router"? This may be needed since some cases have dynamic-shaped attention so only the router part can be captured.

Copy link
Contributor Author

@lhb8125 lhb8125 Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this hard limitation, now the scope could be moe_router alone or moe.

- `cuda_graph_impl` must be `transformer_engine`.
- Requires `torch >= 2.9.0` and `transformer_engine >= 2.13.0`.

```bash
# Delay offloading until CUDA graph launch to hide CPU overhead
--delay-offload-until-cuda-graph
```

### Combining with Fine-Grained Recomputation

Offloading and recomputation are complementary:
- Use **recomputation** for lightweight modules (e.g., layernorm, activation functions) with negligible compute overhead.
- Use **offloading** for heavy modules (e.g., core_attn, expert_fc1) where recomputation would be too costly.

```bash
--recompute-granularity selective
--recompute-modules layernorm moe_act
--fine-grained-activation-offloading
--offload-modules core_attn attn_proj expert_fc1
```
**Compatible with Fine-grained Recomputation**
- For modules with minor perf overhead like layernorm or moe_act, use recomputing to reduce memory footprint;
- For other modules, use offloading to reduce memory footprint;
- Make sure the offloading/reloading can be overlapped with computation;

![Fine-grained Activation Offloading and Fine-grained Recomputation](../../images/fine_grained_activation_offloading/offloading_and_recomputing.png)


### Compatibility

| Feature | Supported |
|---|---|
| PP / Interleaved PP / PP=1 | Yes |
| Fine-grained recomputation | Yes |
| FP8 training | Yes |
| MTP (Multi-Token Prediction) | Yes |
| Mixed dense & MoE layers | Yes |
| A2A overlap (EP) | Yes |
| CUDA Graph (TE impl) | Yes |

---

## How It Works

### Architecture Overview

The implementation consists of three layers:

1. **`PipelineOffloadManager`** (singleton): Global coordinator that manages CUDA streams, CPU tensor pools, and chunk lifecycle across pipeline stages.
2. **`ChunkOffloadHandler`**: Per-microbatch handler that tracks tensor groups, executes D2H/H2D transfers, and decides which groups to actually offload.
3. **`FineGrainedActivationOffloadingInterface`**: Lightweight interface used by transformer modules (attention, MoE, etc.) to mark offload boundaries.

### Offload/Reload Flow

```
Forward pass (Layer N): Backward pass (Layer N):
┌─────────────────────┐ ┌───────────────────────┐
│ group_start(input) │─── register ──► │ │
│ │ tensor group │ group_commit_backward │
│ module.forward() │ │ wait H2D complete │
│ │ │ pop tensors from │
│ group_offload(out) │─── D2H async ──► │ CPU → GPU │
│ on d2h_stream │ to pinned CPU │ on h2d_stream │
└─────────────────────┘ └───────────────────────┘
```

1. **`group_start`**: Registers a new tensor group and hooks into `saved_tensors_hooks` to intercept `save_for_backward`.
2. **Forward execution**: All tensors saved by autograd within the group are captured.
3. **`group_offload`**: Triggers asynchronous D2H copy on a dedicated CUDA stream (`d2h_stream`), optionally releases GPU storage of input tensors.
4. **Backward**: Before the group's backward, tensors are reloaded from CPU to GPU on `h2d_stream`, and the compute stream waits for the transfer to complete.

### Warmup and Adaptive Offloading

The first training iteration serves as a **warmup phase** where the manager records tensor groups, their sizes, and the execution order. After warmup, a `post_warmup_callback` runs to:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we cannot capture cudagraphs on the first training iteration? If so, we should assert cuda_graph_warmup_steps>0 when offloading is enabled.


1. **Reserve margin**: The last N groups (by deduplication count) are kept on GPU to avoid reload blocking the compute stream.
2. **Apply PP rank delta**: Higher PP ranks offload fewer bytes (controlled by `delta_offload_bytes_across_pp_ranks`).
3. **Apply fraction**: Only a fraction of eligible groups are actually offloaded (controlled by `activation_offload_fraction`).
4. **Print summary table**: An ASCII table of per-rank offload bytes is printed for debugging.

### CPU Tensor Pool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GPU Tensor Pool?


A `GPUTensorPool` (despite its name, it holds CPU tensors allocated in pinned memory) caches allocated tensors by `(shape, dtype)`. This avoids repeated `cudaMallocHost` / `cudaFreeHost` calls and reduces D2H latency after the first iteration.

### CUDA Graph Support

When offloading modules captured inside a CUDA graph:

- A dedicated `cuda_graph_stream` runs the captured computation, while `d2h_stream` overlaps D2H transfers.
- During CUDA graph **warmup**, offloading is disabled (`pre_warmup_hook` / `post_warmup_hook`).
- The `delay_offload_until_cuda_graph` option defers D2H launches until graph replay, utilizing the CPU idle time during `cudaGraphLaunch` to issue offload commands with near-zero CPU overhead.
18 changes: 9 additions & 9 deletions megatron/core/models/gpt/fine_grained_callables.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,18 +476,16 @@ def forward_func(
)
if not isinstance(layer.mlp, MoELayer):
return hidden_states, None, None, None
mlp_norm_manager = off_interface(layer.offload_mlp_norm, hidden_states, "mlp_norm")
node.layer_state.mlp_norm_manager = mlp_norm_manager
if layer.recompute_pre_mlp_layernorm:
layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
with off_interface(
layer.offload_mlp_norm, hidden_states, "mlp_norm"
) as hidden_states:
with mlp_norm_manager as hidden_states:
pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint(
apply_module(layer.pre_mlp_layernorm), hidden_states
)
else:
with off_interface(
layer.offload_mlp_norm, hidden_states, "mlp_norm"
) as hidden_states:
with mlp_norm_manager as hidden_states:
pre_mlp_layernorm_output = apply_module(layer.pre_mlp_layernorm)(
hidden_states
)
Expand Down Expand Up @@ -589,10 +587,12 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor):
)
# Delay the offload of the mlp norm until after the mlp_bda has been computed
# because the residual is needed in the mlp_bda.
if layer.offload_mlp_norm:
hidden_states = off_interface.group_commit(
hidden_states, name="mlp_norm", forced_released_tensors=[residual]
mlp_norm_manager = getattr(node.layer_state, 'mlp_norm_manager', None)
if mlp_norm_manager is not None:
hidden_states = mlp_norm_manager.group_offload(
hidden_states, forced_released_tensors=[residual]
)
node.layer_state.mlp_norm_manager = None
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
Expand Down
9 changes: 6 additions & 3 deletions megatron/core/models/gpt/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,19 +458,22 @@ def _preprocess(
def preprocess_for_fine_grained_offloading(self):
"""Preprocess for fine-grained activation offloading."""
off_interface.init_chunk_handler(
pp_rank=self.pg_collection.pp.rank(),
vp_size=self.config.virtual_pipeline_model_parallel_size,
vp_stage=self.vp_stage,
min_offloaded_tensor_size=self.config.min_offloaded_tensor_size,
delta_offload_bytes_across_pp_ranks=self.config.delta_offload_bytes_across_pp_ranks,
activation_offload_fraction=self.config.activation_offload_fraction,
)
if self.disable_param_offloading:
for param in self.decoder.parameters():
off_interface.mark_not_offloadable(param)
off_interface.mark_not_offload(param)
if self.mtp_process:
for param in self.mtp.parameters():
off_interface.mark_not_offloadable(param)
off_interface.mark_not_offload(param)
if self.post_process:
for param in self.output_layer.parameters():
off_interface.mark_not_offloadable(param)
off_interface.mark_not_offload(param)
self.disable_param_offloading = False

def forward(
Expand Down
Loading
Loading