Closed

43 commits
af236f1
zero-overhead activation offload
GeYuhong Aug 18, 2025
7168ccd
bugfix main_grad info and bitwise
GeYuhong Sep 3, 2025
1555e6d
remove offload_mlp_input arg
GeYuhong Sep 3, 2025
c9f00c7
replace get_virtual_pipeline_model_parallel_rank with vp_stage
GeYuhong Sep 7, 2025
4b0d3f1
remove all MoEPositiveAuxLossAutoScaler
GeYuhong Sep 7, 2025
b00acbc
reduce modular PipeOffloadManager functions
GeYuhong Sep 7, 2025
b2c99f7
remove call_back function
GeYuhong Sep 7, 2025
e845344
polish all event sync
GeYuhong Sep 7, 2025
81f44c7
add arguments.py and minor fix, OOTB runable now.
lhb8125 Sep 9, 2025
e1a3ad6
Merge pull request #1 from lhb8125/hongbinl/activation_offloading
GeYuhong Sep 9, 2025
7a52582
support activation offloading at PP=1&PP&VPP
lhb8125 Sep 9, 2025
31ab477
Merge pull request #2 from lhb8125/hongbinl/activation_offloading
lhb8125 Sep 9, 2025
29b084d
support offloading moe_act/router_fc1/layernorm simultaneously
lhb8125 Sep 17, 2025
83ab849
support offloading core_attn/attn_proj and code refactoring
lhb8125 Sep 18, 2025
bee1060
add new cpu_offload.py
Sep 18, 2025
2b574c2
minor fix
lhb8125 Sep 18, 2025
1f03ceb
code clean
lhb8125 Sep 18, 2025
2ff9f6e
add interfaces to TE modules
lhb8125 Sep 18, 2025
a293701
renaming
lhb8125 Sep 19, 2025
aa628c0
minor fix
lhb8125 Sep 19, 2025
ecfbc87
add README
lhb8125 Sep 19, 2025
0f99ca6
Update README.md
lhb8125 Sep 19, 2025
e780b94
remove forward sync per layer
lhb8125 Sep 19, 2025
20c4029
support FP8&MTP
lhb8125 Sep 22, 2025
ae494b5
minor fix
lhb8125 Sep 22, 2025
ba17d78
code refactor and bug fix
lhb8125 Sep 22, 2025
dfaa620
update README
lhb8125 Sep 22, 2025
7d867ab
avoid multiple d2h copies for expert_fc1 and update README
lhb8125 Sep 23, 2025
7a7af1c
Update README.md
lhb8125 Sep 23, 2025
b9f0a3f
Merge pull request #3 from lhb8125/hongbinl/activation_offloading
lhb8125 Sep 24, 2025
5b28cb2
support mixed dense&moe layer and a2a overlap
lhb8125 Sep 25, 2025
4c3b2c5
minor fix
lhb8125 Sep 25, 2025
5cc4b69
Merge pull request #4 from lhb8125/hongbinl/activation_offloading
lhb8125 Sep 25, 2025
2bfc5cc
Merge branch 'main' into hongbinl/activation_offloading
lhb8125 Sep 26, 2025
a5d194c
bug fix
lhb8125 Sep 26, 2025
1853279
Merge pull request #5 from lhb8125/hongbinl/activation_offloading
lhb8125 Sep 26, 2025
c26eb8a
temp fix to enable --overlap-grad-reduce
Sep 26, 2025
0d845de
fix to enable --overlap-grad-reduce and allow placing loss layer only…
lhb8125 Sep 29, 2025
ff3a852
Merge pull request #6 from lhb8125/hongbinl/activation_offloading
lhb8125 Sep 29, 2025
17d14d8
1. remove TE version checking in developing stage
lhb8125 Oct 9, 2025
22dabcf
minor fix
lhb8125 Oct 9, 2025
5c02024
update README
lhb8125 Oct 9, 2025
086620e
Merge pull request #7 from lhb8125/hongbinl/activation_offloading
lhb8125 Oct 9, 2025
25 changes: 25 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,14 @@ def __init__(
extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute
else:
raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.")
# if self.config.fine_grained_activation_offloading:
# te_version = get_te_version()
# if te_version == PkgVersion("2.8.0.dev0+93a67af"):
extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading
# else:
# raise ValueError(
# f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading."
# )
if (
self.config.tp_comm_overlap
and tp_comm_buffer_name
Expand Down Expand Up @@ -505,6 +513,15 @@ def __init__(
else:
raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.")

# if self.config.fine_grained_activation_offloading:
# te_version = get_te_version()
# if te_version == PkgVersion("2.8.0.dev0+93a67af"):
extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading
# else:
# raise ValueError(
# f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading."
# )

# Only Transformer-Engine version >= 0.11.0 supports `RMSNorm`
if is_te_min_version("0.11.0"):
extra_kwargs["normalization"] = self.config.normalization
Expand Down Expand Up @@ -1099,6 +1116,14 @@ def __init__(
raise RuntimeError(
"Only TE with version >=2.3.0 supports delay_wgrad_compute now."
)
# if self.config.fine_grained_activation_offloading:
# te_version = get_te_version()
# if te_version == PkgVersion("2.8.0.dev0+93a67af"):
extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading
# else:
# raise ValueError(
# f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading."
# )

extra_kwargs["ub_name"] = tp_comm_buffer_name

Expand Down
3 changes: 2 additions & 1 deletion megatron/core/model_parallel_config.py
Expand Up @@ -5,6 +5,7 @@

import torch

import warnings

@dataclass
class ModelParallelConfig:
Expand Down Expand Up @@ -314,7 +315,7 @@ class ModelParallelConfig:
rank 0 | 0 1 2 0 1 2 3 4 3 4
rank 1 | 0 1 2 0 1 2 3 4 3 4
"""

###################
# CPU Offloading
###################
Expand Down
25 changes: 21 additions & 4 deletions megatron/core/models/gpt/fine_grained_callables.py
Expand Up @@ -16,6 +16,12 @@
get_mtp_layer_offset,
)
from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor
from megatron.core.transformer.cpu_offload import (
PipelineOffloadManager,
group_prefetch_offload_start,
group_prefetch_offload_commit,
mark_layer_start,
)


def weak_method(method):
Expand Down Expand Up @@ -331,6 +337,8 @@ def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor):
"""
Performs the same attention forward logic as GPT Model.
"""
if layer.config.fine_grained_activation_offloading:
hidden_states = mark_layer_start(hidden_states)
hidden_states, _ = layer._forward_attention(
hidden_states=hidden_states,
attention_mask=node.chunk_state.attention_mask,
Expand All @@ -347,13 +355,20 @@ def submodule_post_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor)
Run forward pass for computations between attention and dispatch:
pre mlp layernorm->router->dispatch preprocess
"""
offload_context = nullcontext()
if layer.offload_mlp_norm:
hidden_states = group_prefetch_offload_start(hidden_states, name="mlp_norm")
offload_context = PipelineOffloadManager.get_instance()
if layer.recompute_pre_mlp_layernorm:
layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint(
layer.pre_mlp_layernorm, hidden_states
)
with offload_context:
pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint(
layer.pre_mlp_layernorm, hidden_states
)
else:
pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states)
with offload_context:
pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states)
offload_context = nullcontext()

local_tokens, probs, _ = layer.mlp.router_and_preprocess(pre_mlp_layernorm_output)

Expand Down Expand Up @@ -433,6 +448,8 @@ def submodule_combine_forward(
hidden_states = layer.mlp_bda(layer.training, layer.config.bias_dropout_fusion)(
mlp_output_with_bias, residual, layer.hidden_dropout
)
if layer.offload_mlp_norm:
hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual])
output = make_viewless_tensor(
inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True
)
Expand Down
21 changes: 21 additions & 0 deletions megatron/core/models/gpt/gpt_model.py
Expand Up @@ -33,6 +33,7 @@
from megatron.core.transformer.transformer_block import TransformerBlock
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import WrappedTensor, deprecate_inference_params
from megatron.core.transformer.cpu_offload import PipelineOffloadManager


class GPTModel(LanguageModule):
Expand Down Expand Up @@ -341,6 +342,22 @@ def _preprocess(

return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset

def initialize_model_chunk_offload_handler(self):
num_layers = self.decoder.num_layers_per_pipeline_rank
if self.mtp_process:
num_layers = num_layers + self.config.mtp_num_layers
pp_rank = parallel_state.get_pipeline_model_parallel_rank()
pp_size = parallel_state.get_pipeline_model_parallel_world_size()
last_stage_is_loss = (pp_rank == pp_size - 1) and self.config.last_vp_stage_is_loss
# TODO: will be an issue when dense layer is placed across different pipeline stages
PipelineOffloadManager.get_instance().reset_chunk_handler(
num_layers,
self.vp_stage,
self.config.fine_grained_activation_offloading,
self.decoder.num_dense_layer,
last_stage_is_loss,
)

def forward(
self,
input_ids: Tensor,
Expand All @@ -366,6 +383,8 @@ def forward(
runtime_gather_output (bool): Gather output at runtime. Default None means
`parallel_output` arg in the constructor will be used.
"""
if self.config.fine_grained_activation_offloading:
self.initialize_model_chunk_offload_handler()

inference_context = deprecate_inference_params(inference_context, inference_params)

Expand Down Expand Up @@ -627,6 +646,8 @@ def build_schedule_plan(
TransformerModelChunkSchedulePlan: The model chunk schedule plan.
"""

if self.config.fine_grained_activation_offloading:
self.initialize_model_chunk_offload_handler()
from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan

return TransformerModelChunkSchedulePlan(
Expand Down
10 changes: 10 additions & 0 deletions megatron/core/pipeline_parallel/schedules.py
Expand Up @@ -19,6 +19,7 @@
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.cuda_graphs import create_cudagraphs
from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler
from megatron.core.transformer.cpu_offload import PipelineOffloadManager
from megatron.core.utils import (
drain_embedding_wgrad_compute,
get_attr_wrapped_model,
Expand Down Expand Up @@ -558,6 +559,9 @@ def forward_backward_no_pipelining(
adjust_tensor_shapes_fn is None
), "adjust_tensor_shapes_fn is not supported for non-pipeline-parallel schedule"

if not forward_only:
PipelineOffloadManager.get_instance().reset()

config = get_model_config(model)
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)
Expand Down Expand Up @@ -898,6 +902,9 @@ def forward_backward_pipelining_with_interleaving(
adjust_tensor_shapes_fn is None
), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism"

if not forward_only:
PipelineOffloadManager.get_instance().reset()

if config.overlap_p2p_comm and config.batch_p2p_comm:
raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm")

Expand Down Expand Up @@ -2043,6 +2050,9 @@ def forward_backward_pipelining_without_interleaving(
if config.timers is not None:
config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time)

if not forward_only:
PipelineOffloadManager.get_instance().reset()

# Disable async grad reductions
no_sync_func = config.no_sync_func
if no_sync_func is None:
Expand Down
17 changes: 15 additions & 2 deletions megatron/core/tensor_parallel/random.py
Expand Up @@ -510,10 +510,11 @@ def forward(ctx, run_function, checkpoint_without_output_obj, *args):
@staticmethod
def backward(ctx, *args):
"""Backward pass."""
inputs = ctx.saved_tensors
inputs = ctx.inputs
outputs = ctx.outputs
torch.autograd.backward(outputs, args)
ctx.outputs = None
ctx.inputs = None
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs)
return (None, None) + grads

Expand Down Expand Up @@ -573,8 +574,19 @@ def _recompute(self, _):
recompute_ctx = contextlib.nullcontext()
fp8_ctx = contextlib.nullcontext()

inputs = self.ctx.saved_tensors
# Unknown cause: if saved_tensors are handled by a saved_tensor_hook, the grads of the inputs become None (not NaN).
# Detach the tensors to bypass this.
def detach(t):
if isinstance(t, torch.Tensor):
requires_grad = t.requires_grad
t = t.detach()
t.requires_grad_(requires_grad)
return t

inputs = tuple(detach(t) for t in inputs)
with torch.enable_grad(), fp8_ctx, recompute_ctx:
outputs = self.run_function(*self.ctx.saved_tensors)
outputs = self.run_function(*inputs)

self.run_function = None
self.rng_states = None
Expand All @@ -590,6 +602,7 @@ def _recompute(self, _):
output.untyped_storage().copy_(recomputation_output.untyped_storage())

self.ctx.outputs = outputs
self.ctx.inputs = inputs
self.outputs = None
self.ctx = None

Expand Down
143 changes: 143 additions & 0 deletions megatron/core/transformer/README.md
@@ -0,0 +1,143 @@
<div align="center">

Fine-grained Activation Offloading
=============
<h4>NVIDIA, rednote</h4>
<div align="left">

# What is Fine-grained Activation Offloading?

Memory capacity is increasingly important with the rise of extremely sparse MoE models such as DeepSeek-V3 and Qwen3-235B. Fine-grained Activation Offloading offloads activations at the granularity of specific modules, so that the amount of offloaded activation can be calibrated to maximize training throughput.

# Quick Start

```bash
# Enable fine-grained activation offloading
--fine-grained-activation-offloading

# Specify which modules are going to be offloaded
# Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".
--offload-modules core_attn
```
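
For illustration, a hypothetical launch fragment combining both flags (the script name, launcher arguments, and the elided training flags are placeholders, not taken from this PR):

```shell
# Hypothetical sketch: offload the core-attention input; all other
# training flags are elided and must come from your own recipe.
torchrun --nproc_per_node=8 pretrain_gpt.py \
    ... \
    --fine-grained-activation-offloading \
    --offload-modules core_attn
```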

# Current status
## Features
* Support PP=1/PP/Interleaved PP
* Compatible with fine-grained recomputation
* Support FP8
* Support MTP
* Support mixed dense & moe layer
* Support A2A Overlap
* Support CUDA Graph
  * (Temporary) the CUDA graph scope cannot contain the offloaded modules

## Known issues
* We explicitly resize some tensors to 0 to release their memory immediately, which sometimes leads to illegal memory access. If you run into this issue, remove the affected tensors from the `release_tensors` passed to `group_prefetch_offload_commit`.

## WIP items
* Code refactor
* Benchmark

# Methodology

## Offload/Reload the input of one module to/from CPU
Let's take the attention projection module as an example:
```
nvtx_range_push(suffix="linear_proj")
offload_context = contextlib.nullcontext()
if self.offload_attn_proj:
core_attn_out = group_prefetch_offload_start(core_attn_out, name="attn_proj")
offload_context = PipelineOffloadManager.get_instance()
with offload_context:
output, bias = self.linear_proj(core_attn_out)
if self.offload_attn_proj:
output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out])
offload_context = contextlib.nullcontext()
nvtx_range_pop(suffix="linear_proj")
```
The above code snippet can be divided into three parts, in order:
1. Mark the starting point of offloading a new module.
2. Record the save_for_backward tensors during fprop and push them to a tensor buffer.
3. Offload the recorded tensors once the module's fprop finishes.

In bprop, the same three parts:
1. Make sure the offloaded tensors have been reloaded back to the GPU.
2. Pop the corresponding tensors from the tensor buffer.
3. Reload (prefetch) the corresponding tensors of the next module.
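
The push/pop bookkeeping behind this protocol can be sketched in pure Python. This is a toy model with hypothetical names (`OffloadBuffer`, `start`/`record`/`commit`/`reload_last`), not the PR's implementation: the real code moves CUDA tensors to host memory on a side stream, while a deque here only models the ordering.

```python
from collections import deque

class OffloadBuffer:
    """Toy sketch of the start/record/commit protocol (hypothetical names).

    Real offloading copies CUDA tensors to pinned host memory asynchronously;
    this class only models the group bookkeeping and its ordering."""

    def __init__(self):
        self._groups = deque()  # groups committed during fprop
        self._current = None

    def start(self, name):
        # 1. Mark the starting point of offloading a new module.
        self._current = {"name": name, "saved": []}

    def record(self, tensor):
        # 2. Record a save-for-backward tensor into the current group.
        self._current["saved"].append(tensor)

    def commit(self):
        # 3. "Offload" the whole group once the module's fprop finishes.
        self._groups.append(self._current)
        self._current = None

    def reload_last(self):
        # bprop consumes groups in reverse (LIFO) order within a microbatch.
        return self._groups.pop()

buf = OffloadBuffer()
for name in ("core_attn", "attn_proj", "mlp_norm"):
    buf.start(name)
    buf.record(f"{name}_activation")
    buf.commit()

# Backward touches the most recently offloaded module first.
assert buf.reload_last()["name"] == "mlp_norm"
assert buf.reload_last()["name"] == "attn_proj"
```

The LIFO order falls out of autograd itself: the last module to run forward is the first to run backward, so its tensors must be resident on GPU first.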

## Compatible with PP&Interleaved PP

`PipelineOffloadManager` manages offloading across the different model chunks in fprop and bprop.
Before `model.forward()` starts, `PipelineOffloadManager.get_instance().reset_chunk_handler` is executed. In fprop, this method creates a `ChunkOffloadHandler` that handles the offloading context of one model chunk and pushes it to a buffer; the handlers are popped from that buffer in a schedule-specific order during bprop.
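
A minimal sketch of that handler buffer, with hypothetical class and method names (only `get_instance`, `reset`, and `reset_chunk_handler` mirror identifiers that appear in this PR; the FIFO pop is a simplification — the real pop order depends on the pipeline schedule):

```python
from collections import deque

class ChunkOffloadHandler:
    """Hypothetical stand-in for the per-model-chunk offload state."""
    def __init__(self, num_layers, vp_stage):
        self.num_layers = num_layers
        self.vp_stage = vp_stage

class PipelineOffloadManagerSketch:
    """Toy singleton: one handler is pushed per forward model chunk and
    popped when the matching backward runs. FIFO here for simplicity;
    the real order is schedule-dependent (1F1B vs. interleaved VPP)."""
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        self._handlers = deque()

    def reset(self):
        # Called once per forward-backward step, before any chunk runs.
        self._handlers.clear()

    def reset_chunk_handler(self, num_layers, vp_stage):
        # Called at the start of each model chunk's forward pass.
        self._handlers.append(ChunkOffloadHandler(num_layers, vp_stage))

    def pop_backward_handler(self):
        # Backward retrieves the handler of the chunk it is about to run.
        return self._handlers.popleft()

mgr = PipelineOffloadManagerSketch.get_instance()
mgr.reset()
mgr.reset_chunk_handler(num_layers=7, vp_stage=0)
mgr.reset_chunk_handler(num_layers=7, vp_stage=1)
assert mgr.pop_backward_handler().vp_stage == 0
```

This mirrors why the schedules call `PipelineOffloadManager.get_instance().reset()` at the top of each non-forward-only step: stale handlers from the previous step must not leak into the next one.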

<img width="1182" height="537" alt="image" src="https://github.com/user-attachments/assets/9d1655cc-d6d4-44de-acaf-35099cb902c2" />


## Compatible with fine-grained recomputation

<img width="2873" height="1494" alt="offload_and_recompute" src="https://github.com/user-attachments/assets/b857112f-4cf6-480f-aaf8-496bfe821faa" />


## A special case: attn_norm/mlp_norm

# Performance

## H100

### DeepSeek-V3-Proxy
#### Model structure
* Layer parameters are the same as in the DeepSeek-V3 model
* The layer count is cut down to 14 layers
* The first 3 dense layers are replaced with 3 MoE layers

#### Key Hyper-parameters
* TP1PP4EP16VPP1CP1-MBS1GBS512
* bf16 training
* DeepEP dispatcher
* `--cross-entropy-loss-fusion` and `--cross-entropy-fusion-impl te`
* `--moe-permute-fusion`
* `--moe-router-fusion`
* `--enable-experimental`

#### Throughput and correctness

<img width="1245" height="845" alt="image" src="https://github.com/user-attachments/assets/51e8e0d1-b03a-4723-a90e-4cbd5c661550" />
<img width="1291" height="832" alt="image" src="https://github.com/user-attachments/assets/73eb5c86-bd69-4dcd-a477-b0194225aa1e" />


#### Memory consumption

Baseline (no offloading)
```
[Rank 0] (after 10 iterations) memory (MB) | allocated: 24761.02978515625 | max allocated: 65203.93359375 | reserved: 64438.0 | max reserved: 74306.0
[Rank 16] (after 10 iterations) memory (MB) | allocated: 18907.728515625 | max allocated: 52228.1533203125 | reserved: 58770.0 | max reserved: 58770.0
[Rank 32] (after 10 iterations) memory (MB) | allocated: 18907.7529296875 | max allocated: 45200.8349609375 | reserved: 51772.0 | max reserved: 51772.0
[Rank 48] (after 10 iterations) memory (MB) | allocated: 29006.82275390625 | max allocated: 48166.263671875 | reserved: 56328.0 | max reserved: 56328.0
```
With offloading expert_fc1, moe_act, act_norm and mlp_norm
```
[Rank 0] (after 10 iterations) memory (MB) | allocated: 24705.02978515625 | max allocated: 48544.70849609375 | reserved: 61046.0 | max reserved: 61046.0
[Rank 16] (after 10 iterations) memory (MB) | allocated: 18795.728515625 | max allocated: 38760.3876953125 | reserved: 46330.0 | max reserved: 46330.0
[Rank 32] (after 10 iterations) memory (MB) | allocated: 18795.7529296875 | max allocated: 34950.2509765625 | reserved: 42452.0 | max reserved: 42452.0
[Rank 48] (after 10 iterations) memory (MB) | allocated: 28950.82275390625 | max allocated: 41310.798828125 | reserved: 50408.0 | max reserved: 50408.0
```

### Qwen3-30B-A3B
#### Model structure
* Same as Qwen-30B model structure

#### Results

| Model | Mapping | Sequence length | Recompute | Offload | Throughput (tflops) | Memory (MB) |
|---------------|--------------------------|-----------------|-----------|------------|---------------------|-------------|
| Qwen3-30B-A3B | TP1PP1EP8VPP1_MBS1GBS256 | 4096 | / | / | 194 | 65308 |
| | TP1PP1EP8VPP1_MBS1GBS256 | 8192 | full | / | 230 | 59566 |
| | TP1PP2EP8VPP4_MBS1GBS256 | 8192 | layernorm | expert_fc1 | 255 | 64962 |



## GB200

# Acknowledgement

This work builds on prior work from Kuaishou: https://www.usenix.org/conference/atc24/presentation/yuan