diff --git a/docs/api-guide/fine_grained_activation_offloading.md b/docs/api-guide/fine_grained_activation_offloading.md
index 53211d1d06c..91edec48d68 100644
--- a/docs/api-guide/fine_grained_activation_offloading.md
+++ b/docs/api-guide/fine_grained_activation_offloading.md
@@ -1,31 +1,141 @@
-# Fine-grained Activation Offloading (collaborated with rednote)
+# Fine-Grained Activation Offloading
 
-Memory capacity is more and more important with the rising of extreme sparse MoE models like DeepSeek-V3 and Qwen3-235B. Fine-grained recomputing reduces the memory footprint at the cost of extra recomputation, while offloading could utilize the host-device bandwidth to achieve nearly zero-overhead. Fine-grained Activation Offloading targets at offloading the activation at the granularity of specific modules, so that we can calibrate the amount of offloading activation to maximize the training throughput.
+Fine-grained activation offloading reduces GPU memory by asynchronously transferring activations to CPU at the granularity of individual submodules within a transformer layer. Unlike layer-level offloading, it allows precise control over which activations to offload, enabling a tradeoff between memory savings and PCIe bandwidth overhead.
 
-Currently, the supported offloading modules are `"attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act"`, which could work with fine-grained recomputation to release almost all activations of a transformer layer.
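The adaptive selection that the guide's tuning flags drive (see "Tuning Parameters" and "Warmup and Adaptive Offloading" below) can be sketched in pure Python. This is a hypothetical sketch: `OffloadGroup` and `plan_offload` are illustrative names, not the Megatron-Core API.

```python
from dataclasses import dataclass

@dataclass
class OffloadGroup:
    """One recorded activation group from the warmup iteration (illustrative)."""
    name: str
    numel: int    # elements in the group's saved tensor
    nbytes: int   # bytes the group would offload
    offload: bool = True

def plan_offload(groups, pp_rank, min_offloaded_tensor_size=1 << 20,
                 activation_offload_fraction=1.0,
                 delta_offload_bytes_across_pp_ranks=0):
    """Decide which recorded groups to offload, mirroring the post-warmup pass."""
    # 1) Skip tensors below the size threshold (--min-offloaded-tensor-size).
    for g in groups:
        if g.numel < min_offloaded_tensor_size:
            g.offload = False
    # 2) Higher PP ranks keep more bytes on GPU (--delta-offload-bytes-across-pp-ranks).
    keep_on_gpu = pp_rank * delta_offload_bytes_across_pp_ranks
    for g in groups:
        if g.offload and keep_on_gpu > 0:
            keep_on_gpu -= g.nbytes
            g.offload = False
    # 3) Keep only a fraction of the remaining groups (--activation-offload-fraction),
    #    disabling from the tail so groups produced early in forward stay offloaded.
    eligible = [g for g in groups if g.offload]
    to_disable = round(len(eligible) * (1 - activation_offload_fraction))
    for g in reversed(eligible):
        if to_disable <= 0:
            break
        g.offload = False
        to_disable -= 1
    return [g.name for g in groups if g.offload]
```

Disabling from the tail matters because the last groups of a forward pass are the first ones backward needs, so keeping them resident avoids stalling the compute stream on a reload.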
+## User Guide
 
-**Features**
-* Support PP=1/PP/Interleaved PP
-* Compatible with fine-grained recomputation
-* Support FP8
-* Support MTP
-* Support mixed dense & moe layer
-* Support A2A Overlap
-* Support CUDA Graph
-  * (Temporary) cuda graph scope cannot contains the offloading modules
+### Basic Usage
 
-**Usage**
 ```bash
 # Enable fine-grained activation offloading
 --fine-grained-activation-offloading
-# Specify which modules are going to offload its input
-# Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".
---offload-modules expert_fc1
+# Specify which modules to offload (can combine multiple)
+# Choices: attn_norm, qkv_linear, core_attn, attn_proj, mlp_norm, expert_fc1, moe_act
+--offload-modules core_attn attn_proj expert_fc1
+```
+
+### Offloadable Modules
+
+Each module offloads its **input** activation to CPU during forward and reloads it before backward:
+
+| Module | Description | Notes |
+|---|---|---|
+| `attn_norm` | Input layernorm of attention | Skipped if using `IdentityOp` |
+| `qkv_linear` | QKV linear projection | |
+| `core_attn` | Core attention (softmax + matmul) | |
+| `attn_proj` | Output projection of attention | Must be used together with `core_attn` |
+| `mlp_norm` | Pre-MLP layernorm | Skipped if using `IdentityOp` |
+| `expert_fc1` | First FC layer in MoE experts | MoE models only |
+| `moe_act` | Activation function in MoE experts | MoE models only |
+
+### Tuning Parameters
+
+```bash
+# Minimum tensor size (in elements) to offload. Smaller tensors are skipped.
+# Default: 1048576 (1M elements)
+--min-offloaded-tensor-size 1048576
+
+# Fraction of activations to offload, range [0, 1]. Default: 1.0
+# Useful for partial offloading when PCIe bandwidth is a bottleneck.
+--activation-offload-fraction 0.8
+
+# Reduce offload amount on higher PP ranks (in bytes). Default: 0
+# Higher PP ranks have fewer microbatches in flight, so offloading less
+# reduces overhead without increasing peak memory.
+--delta-offload-bytes-across-pp-ranks 1073741824
+```
+
+### CUDA Graph Integration
+
+Fine-grained offloading is compatible with CUDA graphs. When CUDA graph is enabled, the following constraints apply:
+
+- `attn_norm` and `mlp_norm` **cannot** be offloaded (they cross CUDA graph boundaries).
+- `cuda_graph_scope` must include `attn` and `moe_router`.
+- `cuda_graph_impl` must be `transformer_engine`.
+- Requires `torch >= 2.9.0` and `transformer_engine >= 2.13.0`.
+
+```bash
+# Delay offloading until CUDA graph launch to hide CPU overhead
+--delay-offload-until-cuda-graph
+```
+
+### Combining with Fine-Grained Recomputation
+
+Offloading and recomputation are complementary:
+- Use **recomputation** for lightweight modules (e.g., layernorm, activation functions) with negligible compute overhead.
+- Use **offloading** for heavy modules (e.g., core_attn, expert_fc1) where recomputation would be too costly.
+
+```bash
+--recompute-granularity selective
+--recompute-modules layernorm moe_act
+--fine-grained-activation-offloading
+--offload-modules core_attn attn_proj expert_fc1
 ```
 
-**Compatible with Fine-grained Recomputation**
-- For modules with minor perf overhead like layernorm or moe_act, use recomputing to reduce memory footprint;
-- For other modules, use offloading to reduce memory footprint;
-- Make sure the offloading/reloading could be overlapped with computing;
 ![Fine-grained Activation Offloading and Fine-grained Recomputation](../../images/fine_grained_activation_offloading/offloading_and_recomputing.png)
+
+
+### Compatibility
+
+| Feature | Supported |
+|---|---|
+| PP / Interleaved PP / PP=1 | Yes |
+| Fine-grained recomputation | Yes |
+| FP8 training | Yes |
+| MTP (Multi-Token Prediction) | Yes |
+| Mixed dense & MoE layers | Yes |
+| A2A overlap (EP) | Yes |
+| CUDA Graph (TE impl) | Yes |
+
+---
+
+## How It Works
+
+### Architecture Overview
+
+The implementation consists of three layers:
+
+1. **`PipelineOffloadManager`** (singleton): Global coordinator that manages CUDA streams, CPU tensor pools, and chunk lifecycle across pipeline stages.
+2. **`ChunkOffloadHandler`**: Per-microbatch handler that tracks tensor groups, executes D2H/H2D transfers, and decides which groups to actually offload.
+3. **`FineGrainedActivationOffloadingInterface`**: Lightweight interface used by transformer modules (attention, MoE, etc.) to mark offload boundaries.
+
+### Offload/Reload Flow
+
+```
+Forward pass (Layer N):                      Backward pass (Layer N):
+┌─────────────────────┐                      ┌───────────────────────┐
+│ group_start(input)  │ ─── register ──►     │                       │
+│                     │     tensor group     │ group_commit_backward │
+│ module.forward()    │                      │ wait H2D complete     │
+│                     │                      │ pop tensors from      │
+│ group_offload(out)  │ ─── D2H async ──►    │ CPU → GPU             │
+│ on d2h_stream       │     to pinned CPU    │ on h2d_stream         │
+└─────────────────────┘                      └───────────────────────┘
+```
+
+1. **`group_start`**: Registers a new tensor group and hooks into `saved_tensors_hooks` to intercept `save_for_backward`.
+2. **Forward execution**: All tensors saved by autograd within the group are captured.
+3. **`group_offload`**: Triggers asynchronous D2H copy on a dedicated CUDA stream (`d2h_stream`), optionally releases GPU storage of input tensors.
+4. **Backward**: Before the group's backward, tensors are reloaded from CPU to GPU on `h2d_stream`, and the compute stream waits for the transfer to complete.
+
+### Warmup and Adaptive Offloading
+
+The first training iteration serves as a **warmup phase** where the manager records tensor groups, their sizes, and the execution order. After warmup, a `post_warmup_callback` runs to:
+
+1. **Reserve margin**: The last N groups (by deduplication count) are kept on GPU to avoid reload blocking the compute stream.
+2. **Apply PP rank delta**: Higher PP ranks offload fewer bytes (controlled by `delta_offload_bytes_across_pp_ranks`).
+3. **Apply fraction**: Only a fraction of eligible groups are actually offloaded (controlled by `activation_offload_fraction`).
+4. **Print summary table**: An ASCII table of per-rank offload bytes is printed for debugging.
+
+### CPU Tensor Pool
+
+A `GPUTensorPool` (on CPU with pinned memory) caches allocated tensors by `(shape, dtype)`. This avoids repeated `cudaMallocHost` / `cudaFreeHost` calls and reduces D2H latency after the first iteration.
+
+### CUDA Graph Support
+
+When offloading modules captured inside a CUDA graph:
+
+- A dedicated `cuda_graph_stream` runs the captured computation, while `d2h_stream` overlaps D2H transfers.
+- During CUDA graph **warmup**, offloading is disabled (`pre_warmup_hook` / `post_warmup_hook`).
+- The `delay_offload_until_cuda_graph` option defers D2H launches until graph replay, utilizing the CPU idle time during `cudaGraphLaunch` to issue offload commands with near-zero CPU overhead.
diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py
index 6658b6363ea..4615b62d456 100644
--- a/megatron/core/models/gpt/fine_grained_callables.py
+++ b/megatron/core/models/gpt/fine_grained_callables.py
@@ -476,18 +476,16 @@ def forward_func(
         )
         if not isinstance(layer.mlp, MoELayer):
             return hidden_states, None, None, None
+        mlp_norm_manager = off_interface(layer.offload_mlp_norm, hidden_states, "mlp_norm")
+        node.layer_state.mlp_norm_manager = mlp_norm_manager
         if layer.recompute_pre_mlp_layernorm:
             layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
-            with off_interface(
-                layer.offload_mlp_norm, hidden_states, "mlp_norm"
-            ) as hidden_states:
+            with mlp_norm_manager as hidden_states:
                 pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint(
                     apply_module(layer.pre_mlp_layernorm), hidden_states
                 )
         else:
-            with off_interface(
-                layer.offload_mlp_norm, hidden_states, "mlp_norm"
-            ) as hidden_states:
+            with mlp_norm_manager as hidden_states:
pre_mlp_layernorm_output = apply_module(layer.pre_mlp_layernorm)( hidden_states ) @@ -589,10 +587,12 @@ def submodule_combine_forward(node: ScheduleNode, output: torch.Tensor): ) # Delay the offload of the mlp norm until after the mlp_bda has been computed # because the residual is needed in the mlp_bda. - if layer.offload_mlp_norm: - hidden_states = off_interface.group_commit( - hidden_states, name="mlp_norm", forced_released_tensors=[residual] + mlp_norm_manager = getattr(node.layer_state, 'mlp_norm_manager', None) + if mlp_norm_manager is not None: + hidden_states = mlp_norm_manager.group_offload( + hidden_states, forced_released_tensors=[residual] ) + node.layer_state.mlp_norm_manager = None output = make_viewless_tensor( inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 27b62f91c34..5cc5a64e1d0 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -458,19 +458,22 @@ def _preprocess( def preprocess_for_fine_grained_offloading(self): """Preprocess for fine-grained activation offloading.""" off_interface.init_chunk_handler( + pp_rank=self.pg_collection.pp.rank(), vp_size=self.config.virtual_pipeline_model_parallel_size, vp_stage=self.vp_stage, min_offloaded_tensor_size=self.config.min_offloaded_tensor_size, + delta_offload_bytes_across_pp_ranks=self.config.delta_offload_bytes_across_pp_ranks, + activation_offload_fraction=self.config.activation_offload_fraction, ) if self.disable_param_offloading: for param in self.decoder.parameters(): - off_interface.mark_not_offloadable(param) + off_interface.mark_not_offload(param) if self.mtp_process: for param in self.mtp.parameters(): - off_interface.mark_not_offloadable(param) + off_interface.mark_not_offload(param) if self.post_process: for param in self.output_layer.parameters(): - off_interface.mark_not_offloadable(param) + 
off_interface.mark_not_offload(param) self.disable_param_offloading = False def forward( diff --git a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py index 1d2545b682d..6e66d5ddffb 100644 --- a/megatron/core/pipeline_parallel/fine_grained_activation_offload.py +++ b/megatron/core/pipeline_parallel/fine_grained_activation_offload.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Tuple import torch +from torch.autograd.graph import saved_tensors_hooks # CPU offload implementation for pipeline parallelism DEBUG = False @@ -410,11 +411,16 @@ def __init__(self): # allocate streams and events for synchronization self._d2h_stream = torch.cuda.Stream() self._h2d_stream = torch.cuda.Stream() + # CUDA graph stream and event for offloading modules in cuda graph + self._cuda_graph_stream = torch.cuda.Stream() + self._cuda_graph_event = torch.cuda.Event(external=True) # Shared CPU tensor pool for all chunks to improve reuse efficiency self._cpu_tensor_pool = GPUTensorPool(device="cpu", pin_memory=True) # Whether the manager is in warmup phase. self._is_warmup = True + # Whether the manager is in CUDA graph replay phase. + self._in_replay = False # Cache OffloadChunkHandler objects for each virtual pipeline stage and each forward pass. self._cached_chunks_forward = [] # Cache OffloadChunkHandler objects for each virtual pipeline stage and each backward pass. 
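The hunk above moves the hook registration onto the public `torch.autograd.graph.saved_tensors_hooks` API. How such pack/unpack hooks intercept activations can be illustrated with a minimal, CUDA-free toy; this is an illustrative sketch, not the Megatron-Core implementation, which performs the copies asynchronously on dedicated streams with pinned memory:

```python
import torch
from torch.autograd.graph import saved_tensors_hooks

offloaded = []  # stands in for the pinned-memory CPU tensor pool

def pack(tensor):
    # "Offload": copy the activation autograd wants to save to CPU.
    cpu_copy = tensor.detach().cpu()
    offloaded.append(cpu_copy)
    return cpu_copy

def unpack(cpu_copy):
    # "Reload": hand the saved activation back for the backward pass.
    # The real handler copies back to GPU on a dedicated h2d stream.
    return cpu_copy

x = torch.randn(4, 4, requires_grad=True)
with saved_tensors_hooks(pack, unpack):
    y = (x * x).sum()  # mul saves its operands through the hooks
y.backward()

assert offloaded, "pack() should have intercepted at least one saved tensor"
assert torch.allclose(x.grad, 2 * x)
```

Because the hooks only affect where saved tensors live, not the autograd graph itself, gradients are unchanged; the handler in this file layers grouping, size thresholds, and stream synchronization on top of this basic mechanism.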
@@ -433,6 +439,10 @@ def __init__(self): self._delayed_offload_groups = [] self.reset() + self._saved_tensors_hooks = saved_tensors_hooks( + self.on_save_for_backward, self.on_get_saved_tensor + ) + @property def d2h_stream(self): """Get the device-to-host (GPU to CPU) transfer stream.""" @@ -443,22 +453,32 @@ def h2d_stream(self): """Get the host-to-device (CPU to GPU) transfer stream.""" return self._h2d_stream + @property + def cuda_graph_stream(self): + """Get the CUDA graph stream.""" + return self._cuda_graph_stream + + @property + def cuda_graph_event(self): + """Get the CUDA graph event.""" + return self._cuda_graph_event + @property def cpu_tensor_pool(self): """Get the shared CPU tensor pool.""" return self._cpu_tensor_pool - def push_offload_groups(self, group_hook, forced_released_tensors): + def push_offload_groups(self, group_hook, name, forced_released_tensors): """Push the offload groups to the delayed queue.""" debug_rank(f"pushing offload groups to the delayed queue") - self._delayed_offload_groups.append((group_hook, forced_released_tensors)) + self._delayed_offload_groups.append((group_hook, name, forced_released_tensors)) def flush_delayed_groups(self): """Flush the delayed groups.""" debug_rank("flushing delayed groups") # Flush the delayed groups in reverse order to maintain the order of the groups. - for group_hook, forced_released_tensors in reversed(self._delayed_offload_groups): - group_hook(forced_released_tensors) + for group_hook, name, forced_released_tensors in self._delayed_offload_groups: + group_hook(name, forced_released_tensors) self._delayed_offload_groups = [] def reset(self): @@ -549,13 +569,39 @@ def post_warmup_callback(self): debug_rank(f"setting offload to false for group {name} at chunk index {chunk_idx}") else: break - debug_rank(f"offload margin {self._offload_margin}") assert self._offload_margin == 0, "Offload margin is not 0" + # Disable the groups to meet the delta offload bytes across PP ranks. 
+ keep_on_gpu_bytes = self._pp_rank * self._delta_offload_bytes_across_pp_ranks + for chunk in self._cached_chunks_backward: + for group in chunk.offload_groups: + if group.offload and keep_on_gpu_bytes > 0: + debug_rank( + f"group {group._name} offload {group.offload} \ + keep_on_gpu_bytes {keep_on_gpu_bytes}" + ) + keep_on_gpu_bytes -= group.total_offload_bytes + group.offload = False + # Disable the groups to meet the activation offload fraction. + for chunk in self._cached_chunks_backward: + offloaded_groups_count = 0 + for group in chunk.offload_groups: + if group.offload: + offloaded_groups_count += 1 + disabled_groups_count = offloaded_groups_count * (1 - self._activation_offload_fraction) + debug_rank(f"Disabled {disabled_groups_count}/{offloaded_groups_count} groups") + for group in reversed(chunk.offload_groups): + if group.offload: + if disabled_groups_count > 0: + disabled_groups_count -= 1 + group.offload = False + else: + break # Dump the offload information total_tensor_count = {} total_offload_bytes = {} for chunk in self._cached_chunks_forward: for group in chunk.offload_groups: + debug_rank(f"chunk {chunk} group {group} offload {group.offload}") if group.offload: if group._name not in total_tensor_count: total_tensor_count[group._name] = 0 @@ -567,6 +613,8 @@ def post_warmup_callback(self): # where the memory cost will not increase anymore. if chunk is self._cached_chunks_backward[0]: break + debug_rank(f"total_tensor_count {total_tensor_count}") + debug_rank(f"total_offload_bytes {total_offload_bytes}") # Cache summary for downstream consumers (e.g., unit tests). 
self._offload_summary_bytes = dict(total_offload_bytes) self._offload_summary_total_bytes = int(sum(total_offload_bytes.values())) @@ -607,15 +655,24 @@ def front_backward_chunk(self, name=None): return None def init_model_chunk_offload_handler( - self, vp_size, vp_stage, min_offloaded_tensor_size=1024 * 1024 + self, + pp_rank, + vp_size, + vp_stage, + min_offloaded_tensor_size=1024 * 1024, + delta_offload_bytes_across_pp_ranks=0, + activation_offload_fraction: float = 1.0, ): """ Initialize a chunk offload handler for a model chunk (microbatch). Args: + pp_rank: Pipeline parallel rank vp_size: Virtual pipeline size vp_stage: Virtual pipeline stage index (None means stage 0) min_offloaded_tensor_size: Minimum tensor size (in elements) to offload + delta_offload_bytes_across_pp_ranks: + Difference of offload bytes across PP ranks to balance the offload load. """ if not self._is_warmup: return @@ -625,6 +682,10 @@ def init_model_chunk_offload_handler( self._vpp = vp_size self._stages = [[] for _ in range(vp_size)] + self._delta_offload_bytes_across_pp_ranks = delta_offload_bytes_across_pp_ranks + self._pp_rank = pp_rank + self._activation_offload_fraction = activation_offload_fraction + if vp_stage is None: cur_vpp_rank = 0 else: @@ -670,10 +731,10 @@ def cur_backward_chunk(self): """Get the current backward pass chunk handler.""" return self._cur_backward_chunk - def mark_not_offloadable(self, tensor: torch.Tensor): + def mark_not_offload(self, tensor: torch.Tensor): """Mark the current forward chunk as not offloadable.""" if tensor is not None: - tensor.offloading_activation = False + tensor._do_not_offload = True def __enter__(self): """Enter context manager to enable activation offloading hooks.""" @@ -687,10 +748,7 @@ def __enter__(self): else: raise RuntimeError("TE CPU offload is not available") self.inside_context = True - - torch._C._autograd._push_saved_tensors_default_hooks( - self.on_save_for_backward, self.on_get_saved_tensor - ) + 
self._saved_tensors_hooks.__enter__() def __exit__(self, *args: Any): """Exit context manager and restore original tensor saving behavior.""" @@ -704,7 +762,7 @@ def __exit__(self, *args: Any): else: raise RuntimeError("TE CPU offload is not available") self.inside_context = False - torch._C._autograd._pop_saved_tensors_default_hooks() + self._saved_tensors_hooks.__exit__() def on_save_for_backward(self, tensor: torch.Tensor) -> Any: """ @@ -794,17 +852,17 @@ def reset(self): self._tensor_count_current_group = 0 self._reloading_group = [] - def find_group_with_name(self, name: str, start_index: int = 0): + def find_group_with_name( + self, groups: list[OffloadTensorGroup], name: str, start_index: int = 0 + ): """Find the group with the given name starting from the given index.""" - return next( - (group for group in self.offload_groups[start_index:] if group._name == name), None - ) + return next((group for group in groups[start_index:] if group._name == name), None) def is_empty_chunk(self, name=None): """Check if this chunk has no tensors to manage.""" debug_rank(f"------is_empty_chunk {self._max_group_size}") if name is not None: - return self.find_group_with_name(name) is None + return self.find_group_with_name(self.offload_groups, name) is None return self._max_group_size == 0 def finish_all_groups(self, name=None) -> bool: @@ -821,12 +879,15 @@ def finish_all_groups(self, name=None) -> bool: ): return True assert name is not None, "Name is required" - return self.find_group_with_name(name, self._offloaded_group_index) is None + return ( + self.find_group_with_name(self.offload_groups, name, self._offloaded_group_index) + is None + ) def find_next_group(self, name=None): """Find the next group with the given name.""" assert name is not None, "Name is required" - return self.find_group_with_name(name, self._offloaded_group_index) + return self.find_group_with_name(self.offload_groups, name, self._offloaded_group_index) def tensor_push(self, tensor): """Push 
tensor to the offload handler.""" @@ -859,20 +920,19 @@ def tensor_pop(self, tensor_tag): def tensor_need_offloading_checker(self, tensor): """Check if the tensor needs to be offloaded.""" - debug_rank( - f"tensor_need_offloading_checker {getattr(tensor, 'offloading_activation', None)}" - ) + debug_rank("tensor_need_offloading_checker") if tensor.numel() < self.min_offloaded_tensor_size: return False # Respect tensor's offload preference if specified - if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation: + if getattr(tensor, "_TE_do_not_offload", False) or getattr( + tensor, "_do_not_offload", False + ): return False return True - def bulk_offload_group(self): + def bulk_offload_group(self, group_to_offload): """offload a group of tensors recorded in tensor_push().""" debug_rank("------bulk_offload_group") - group_to_offload = self._groups_to_offload[-1] torch.cuda.nvtx.range_push("activation offloading " + group_to_offload._name) with torch.cuda.stream(self.d2h_stream): for tensor_tag, tensor_on_device in group_to_offload._tensors.items(): @@ -885,7 +945,6 @@ def bulk_offload_group(self): tensor_on_device.record_stream(self.d2h_stream) group_to_offload.push_tensor(tensor_tag, state) group_to_offload.record_offload_event(self.d2h_stream) - self._groups_to_offload.pop() torch.cuda.nvtx.range_pop() def get_max_deduplicated_groups(self): @@ -925,10 +984,11 @@ def pre_reload_last_layer(self): # Reload the last group (last layer) early self.bulk_reload_group() - def should_bulk_offload(self): + def should_bulk_offload(self, name): """Determine if the current group should be offloaded.""" assert len(self._groups_to_offload) > 0, "No groups to offload" - group = self._groups_to_offload[-1] + group = self.find_group_with_name(self._groups_to_offload, name) + assert group is not None, f"Group {name} not found in {self._groups_to_offload}" debug_rank(f"should_bulk_offload {self.is_warmup} {group.offload}") # Don't offload if the chunk is not in 
warmup stage if self.is_warmup: @@ -949,12 +1009,17 @@ def should_bulk_offload(self): return True - def bulk_offload(self, forced_released_tensors): + def bulk_offload(self, name, forced_released_tensors): """Offload a group of tensors and optionally release their GPU memory.""" debug_rank("----bulk_offload") - if self.should_bulk_offload(): - self._groups_to_reload.append(self._groups_to_offload[-1]) - self.bulk_offload_group() + if self.should_bulk_offload(name): + group_to_offload = self.find_group_with_name(self._groups_to_offload, name) + assert ( + group_to_offload is not None + ), f"Group {name} not found in {self._groups_to_offload}" + self._groups_to_reload.append(group_to_offload) + self.bulk_offload_group(group_to_offload) + self._groups_to_offload.remove(group_to_offload) # Manually release tensors not auto-freed by torch GC if len(forced_released_tensors) > 0: cur_stream = torch.cuda.current_stream() @@ -964,14 +1029,14 @@ def bulk_offload(self, forced_released_tensors): release_tensor.record_stream(cur_stream) release_tensor.untyped_storage().resize_(0) - def on_group_commit_forward(self, forced_released_tensors): + def on_group_commit_forward(self, name, forced_released_tensors): """Called at the end of a layer group's forward pass to trigger offloading.""" if not self.do_offload: return - debug_rank("--on_group_commit_forward") + debug_rank(f"--on_group_commit_forward {name}") # Wait for compute to finish before starting offload self.d2h_stream.wait_stream(torch.cuda.current_stream()) - self.bulk_offload(forced_released_tensors) + self.bulk_offload(name, forced_released_tensors) def bulk_reload(self): """Reload the next group of tensors from CPU to GPU.""" @@ -1070,12 +1135,12 @@ def forward(ctx, tensor, cur_forward_chunk, name, forced_released_tensors, delay # pylint: disable=missing-function-docstring debug_rank("FineGrainedOffloadingGroupCommitFunction forward") - if delay_offload: + if delay_offload and 
PipelineOffloadManager.get_instance()._in_replay: PipelineOffloadManager.get_instance().push_offload_groups( - cur_forward_chunk.on_group_commit_forward, forced_released_tensors + cur_forward_chunk.on_group_commit_forward, name, forced_released_tensors ) else: - cur_forward_chunk.on_group_commit_forward(forced_released_tensors) + cur_forward_chunk.on_group_commit_forward(name, forced_released_tensors) ctx.cpu_offload_handler = cur_forward_chunk ctx.name = name return tensor @@ -1172,13 +1237,6 @@ def fine_grained_offloading_group_start(tensor, name=None): return FineGrainedOffloadingGroupStartFunction.apply(tensor, cur_forward_chunk, name) -def fine_grained_offloading_forward_record(event: torch.cuda.Event) -> None: - """Record the forward event for cuda graph capture.""" - d2h_stream = PipelineOffloadManager.get_instance().d2h_stream - torch.cuda.current_stream().record_event(event) - torch.cuda.current_stream().wait_stream(d2h_stream) - - class FineGrainedOffloadingBackwardRecordFunction(torch.autograd.Function): """ Identity operation that marks the end of a layer group for offload synchronization. 
@@ -1186,23 +1244,19 @@ class FineGrainedOffloadingBackwardRecordFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, tensor, event: torch.cuda.Event) -> torch.Tensor: + def forward(ctx, tensor) -> torch.Tensor: """Forward pass for cuda graph capture.""" - ctx.event = event + debug_rank("FineGrainedOffloadingBackwardRecordFunction forward") return tensor @staticmethod def backward(ctx, grad_output): """Record the backward event and wait for the h2d stream on cuda graph stream.""" - h2d_stream = PipelineOffloadManager.get_instance().h2d_stream - torch.cuda.current_stream().record_event(ctx.event) - torch.cuda.current_stream().wait_stream(h2d_stream) - return grad_output, None - - -def fine_grained_offloading_backward_record(tensor, event: torch.cuda.Event) -> torch.Tensor: - """Record the backward event for cuda graph capture.""" - return FineGrainedOffloadingBackwardRecordFunction.apply(tensor, event) + debug_rank("FineGrainedOffloadingBackwardRecordFunction backward") + mgr = PipelineOffloadManager.get_instance() + torch.cuda.current_stream().record_event(mgr.cuda_graph_event) + torch.cuda.current_stream().wait_stream(mgr.h2d_stream) + return (grad_output,) class FineGrainedActivationOffloadingInterface: @@ -1226,10 +1280,32 @@ def __exit__(self, *args: Any): PipelineOffloadManager.get_instance().__exit__() @staticmethod - def init_chunk_handler(vp_size, vp_stage, min_offloaded_tensor_size): + def cuda_graph_stream(): + """Get the CUDA graph stream.""" + return PipelineOffloadManager.get_instance().cuda_graph_stream + + @staticmethod + def cuda_graph_event(): + """Get the CUDA graph event.""" + return PipelineOffloadManager.get_instance().cuda_graph_event + + @staticmethod + def init_chunk_handler( + pp_rank, + vp_size, + vp_stage, + min_offloaded_tensor_size, + delta_offload_bytes_across_pp_ranks, + activation_offload_fraction, + ): """Initialize the chunk handler, called at the start of a microbatch forward pass.""" 
PipelineOffloadManager.get_instance().init_model_chunk_offload_handler( - vp_size, vp_stage, min_offloaded_tensor_size + pp_rank, + vp_size, + vp_stage, + min_offloaded_tensor_size, + delta_offload_bytes_across_pp_ranks, + activation_offload_fraction, ) @staticmethod @@ -1237,25 +1313,32 @@ def get_context(flag): """Get the fine-grained offload context""" return PipelineOffloadManager.get_instance() if flag else nullcontext() - @staticmethod - def group_commit(tensor, name, forced_released_tensors=None, delay_offload=False): - """Group commit the tensors.""" - return fine_grained_offloading_group_commit( - tensor, name, forced_released_tensors, delay_offload - ) + def group_offload(self, tensor, forced_released_tensors=None, delay_offload=False): + """Group offload the tensors.""" + if self.offload: + return fine_grained_offloading_group_commit( + tensor, self.name, forced_released_tensors, delay_offload + ) + return tensor @staticmethod - def mark_not_offloadable(tensor: torch.Tensor): + def mark_not_offload(tensor: torch.Tensor): """Mark the tensor as not offloadable.""" - PipelineOffloadManager.get_instance().mark_not_offloadable(tensor) + PipelineOffloadManager.get_instance().mark_not_offload(tensor) @staticmethod - def forward_record(event: torch.cuda.Event) -> None: + def forward_record() -> None: """Record the forward event for cuda graph capture.""" - d2h_stream = PipelineOffloadManager.get_instance().d2h_stream - torch.cuda.current_stream().record_event(event) - torch.cuda.current_stream().wait_stream(d2h_stream) + mgr = PipelineOffloadManager.get_instance() + torch.cuda.current_stream().record_event(mgr.cuda_graph_event) + torch.cuda.current_stream().wait_stream(mgr.d2h_stream) + @staticmethod + def backward_record(tensor) -> torch.Tensor: + """Record the backward event for cuda graph capture.""" + return FineGrainedOffloadingBackwardRecordFunction.apply(tensor) + + @staticmethod def reset(): """Reset the chunk handler.""" 
PipelineOffloadManager.get_instance().reset() @@ -1264,3 +1347,28 @@ def reset(): def reset_instance(): """Reset the singleton instance.""" PipelineOffloadManager.reset_instance() + + @staticmethod + def flush_delayed_groups(): + """Flush the delayed groups.""" + PipelineOffloadManager.get_instance().flush_delayed_groups() + + @staticmethod + def disable_offload(): + """Disable the offload.""" + PipelineOffloadManager.get_instance().disable_offload() + + @staticmethod + def enable_offload(): + """Enable the offload.""" + PipelineOffloadManager.get_instance().enable_offload() + + @staticmethod + def enter_replay(): + """Enter CUDA graph replay mode to enable delayed offloading.""" + PipelineOffloadManager.get_instance()._in_replay = True + + @staticmethod + def exit_replay(): + """Exit CUDA graph replay mode.""" + PipelineOffloadManager.get_instance()._in_replay = False diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index e903f392bf0..a142956068d 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -689,7 +689,7 @@ def forward_backward_no_pipelining( force_all_reduce=force_all_reduce, ) - if not forward_only and config.fine_grained_activation_offloading: + if getattr(config, 'fine_grained_activation_offloading', False): off_interface.reset() if config.timers is not None: @@ -2054,7 +2054,7 @@ def pp_post_backward(input_tensor_grad, vp_stage=None): force_all_reduce=force_all_reduce, ) - if not forward_only and config.fine_grained_activation_offloading: + if getattr(config, 'fine_grained_activation_offloading', False): off_interface.reset() # Restore config.grad_sync_func and config.param_sync_func. 
if forward_only: @@ -2442,7 +2442,7 @@ def enable_grad_sync(): force_all_reduce=force_all_reduce, ) - if not forward_only and config.fine_grained_activation_offloading: + if getattr(config, 'fine_grained_activation_offloading', False): off_interface.reset() if config.timers is not None: diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index b8d9ef69443..72b4854a727 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -982,18 +982,16 @@ def forward( if output_gate: assert split_qkv, "output_gate is not supported for unsplit mixed_qkv tensor." - with off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") as hidden_states: + qkv_linear_manager = off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") + with qkv_linear_manager as hidden_states: qkv_output = self.get_query_key_value_tensors( hidden_states, key_value_states, split_qkv=split_qkv, output_gate=self.config.attention_output_gate, ) - if self.offload_qkv_linear: - # `qkv_output` may be a tuple; commit supports tuple/list and will keep structure. - qkv_output = off_interface.group_commit( - qkv_output, name="qkv_linear", forced_released_tensors=[] - ) + # `qkv_output` may be a tuple; commit supports tuple/list and will keep structure. + qkv_output = qkv_linear_manager.group_offload(qkv_output, forced_released_tensors=[]) attn_mask_type = self.attn_mask_type block_table = None gate = None @@ -1136,6 +1134,9 @@ def forward( # ================================== nvtx_range_push(suffix="core_attention") + core_attn_manager = off_interface( + self.offload_core_attention and self.training, query, "core_attn" + ) if self.checkpoint_core_attention and self.training: core_attn_out = self._checkpointed_attention_forward( query, @@ -1149,9 +1150,7 @@ def forward( else: if inference_context is None or inference_context.is_static_batching(): # Static batching attention kernel. 
-                with off_interface(
-                    self.offload_core_attention and self.training, query, "core_attn"
-                ) as query:
+                with core_attn_manager as query:
                     core_attn_out = apply_module(self.core_attention)(
                         query,
                         key,
@@ -1187,10 +1186,9 @@ def forward(
                 if is_using_quantization_scales(self.config):
                     core_attn_out[inference_context.padding_slice] = 0.0

-        if self.offload_core_attention and self.training:
-            core_attn_out = off_interface.group_commit(
-                core_attn_out, name="core_attn", forced_released_tensors=[query, key, value]
-            )
+        core_attn_out = core_attn_manager.group_offload(
+            core_attn_out, forced_released_tensors=[query, key, value]
+        )

         if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
             # reshape to same output shape as unpacked case
@@ -1210,12 +1208,10 @@ def forward(
         # Output. [sq, b, h]
         # =================
         nvtx_range_push(suffix="linear_proj")
-        with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out:
+        attn_proj_manager = off_interface(self.offload_attn_proj, core_attn_out, "attn_proj")
+        with attn_proj_manager as core_attn_out:
             output, bias = self.linear_proj(core_attn_out)
-        if self.offload_attn_proj:
-            output = off_interface.group_commit(
-                output, name="attn_proj", forced_released_tensors=[core_attn_out]
-            )
+        output = attn_proj_manager.group_offload(output, forced_released_tensors=[core_attn_out])
         nvtx_range_pop(suffix="linear_proj")

         return output, bias
diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py
index f7b2bc79cab..b05ea82696a 100644
--- a/megatron/core/transformer/cuda_graphs.py
+++ b/megatron/core/transformer/cuda_graphs.py
@@ -2176,6 +2176,15 @@ def _get_fp8_enabled():
                 )
             else:
                 kwargs['fp8_enabled'] = False
+
+            from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
+                FineGrainedActivationOffloadingInterface as off_interface,
+            )
+
+            # Disable offloading before the CUDA graph warmup stage and re-enable it afterward.
+            if self.config.fine_grained_activation_offloading:
+                kwargs['pre_warmup_hook'] = off_interface.disable_offload
+                kwargs['post_warmup_hook'] = off_interface.enable_offload
             return kwargs

         kwargs = get_make_graphed_callables_kwargs()
@@ -2210,6 +2219,12 @@ def _finish_capturing(self, start_time):
         _set_capture_end()

         from megatron.core.distributed.finalize_model_grads import reset_model_temporary_tensors
+        from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
+            FineGrainedActivationOffloadingInterface as off_interface,
+        )
+
+        if self.config.fine_grained_activation_offloading:
+            off_interface.reset()

         torch.distributed.barrier()
         for model_chunk in self.model:
diff --git a/megatron/core/transformer/module.py b/megatron/core/transformer/module.py
index 6539ee36105..2d588262676 100644
--- a/megatron/core/transformer/module.py
+++ b/megatron/core/transformer/module.py
@@ -322,6 +322,15 @@ def _get_te_cuda_graph_replay_args(self, *args, **kwargs):
         cudagraph_kwargs = kwargs.copy()
         cudagraph_kwargs['is_first_microbatch'] = getattr(self, 'current_microbatch', 0) == 0
+        if self.config.fine_grained_activation_offloading and getattr(
+            self, 'offload_module_in_cuda_graph', False
+        ):
+            from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
+                FineGrainedActivationOffloadingInterface as off_interface,
+            )
+
+            cudagraph_kwargs['cuda_graph_stream'] = off_interface.cuda_graph_stream()
+            cudagraph_kwargs['cuda_graph_event'] = off_interface.cuda_graph_event()
         return cudagraph_args, cudagraph_kwargs

     def _should_call_local_cudagraph(self, *args, **kwargs):
diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py
index 8a271ab4fb9..e0e4f2b0f12 100644
--- a/megatron/core/transformer/moe/experts.py
+++ b/megatron/core/transformer/moe/experts.py
@@ -707,7 +707,7 @@ def __init__(
             set_save_original_input(self.linear_fc2)

         # This is to avoid the CPU overhead of multiple d2h copies
-        if self.offload_expert_fc1:
+        if self.offload_expert_fc1 and not self.config.fp8:
             from megatron.core.extensions.transformer_engine import set_save_original_input

             set_save_original_input(self.linear_fc1)
@@ -776,18 +776,18 @@ def forward(
             # Probs already applied, so reset to 1.
             permuted_probs = torch.ones_like(permuted_probs)

-        with off_interface(
+        expert_fc1_manager = off_interface(
             self.offload_expert_fc1, permuted_local_hidden_states, "expert_fc1"
-        ) as permuted_local_hidden_states:
+        )
+        with expert_fc1_manager as permuted_local_hidden_states:
             fc1_output, bias_parallel = apply_module(self.linear_fc1)(
                 permuted_local_hidden_states, tokens_per_expert
             )
-        if self.offload_expert_fc1:
-            fc1_output = off_interface.group_commit(
-                fc1_output,
-                name="expert_fc1",
-                forced_released_tensors=[permuted_local_hidden_states],
-            )
+        fc1_output = expert_fc1_manager.group_offload(
+            fc1_output,
+            forced_released_tensors=[permuted_local_hidden_states],
+            delay_offload=self.config.delay_offload_until_cuda_graph,
+        )

         def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs):
             if self.config.use_te_activation_func:
@@ -847,14 +847,15 @@ def glu(x):
                     intermediate_parallel = intermediate_parallel.to(original_dtype)
             return intermediate_parallel

+        moe_act_manager = off_interface(self.offload_moe_act, fc1_output, "moe_act")
         if self.activation_recompute:
             self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput()
-            with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output:
+            with moe_act_manager as fc1_output:
                 bias_act_output = self.activation_checkpoint.checkpoint(
                     bias_act_func, fc1_output, bias_parallel, permuted_probs
                 )
         else:
-            with off_interface(self.offload_moe_act, fc1_output, "moe_act") as fc1_output:
+            with moe_act_manager as fc1_output:
                 bias_act_output = bias_act_func(fc1_output, bias_parallel, permuted_probs)

         output, output_bias = apply_module(self.linear_fc2)(bias_act_output, tokens_per_expert)
@@ -863,10 +864,11 @@ def glu(x):
         # Delay the offload of the moe act until after the linear_fc2 has been computed
         # to make sure the fc1_output is reloaded to GPU before recomputing moe_act.
-        if self.offload_moe_act:
-            output = off_interface.group_commit(
-                output, name="moe_act", forced_released_tensors=[fc1_output]
-            )
+        output = moe_act_manager.group_offload(
+            output,
+            forced_released_tensors=[fc1_output],
+            delay_offload=self.config.delay_offload_until_cuda_graph,
+        )

         output = self._apply_bias(output, output_bias, tokens_per_expert, permuted_probs)
         # upad and concat the output
diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py
index a9cdc697cc8..a8bc121ca97 100644
--- a/megatron/core/transformer/multi_latent_attention.py
+++ b/megatron/core/transformer/multi_latent_attention.py
@@ -245,7 +245,8 @@ def forward(
         # Get the query, key and value tensors based on the type of attention -
         # self or cross attn.
         # query: [96, 1, 16, 128], key:[96, 1, 16, 128], value:[96, 1, 16, 128]
-        with off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear") as hidden_states:
+        qkv_linear_manager = off_interface(self.offload_qkv_linear, hidden_states, "qkv_linear")
+        with qkv_linear_manager as hidden_states:
             query, key, value, q_compressed, kv_compressed = self.get_query_key_value_tensors(
                 hidden_states,
                 key_value_states,
@@ -253,10 +254,7 @@ def forward(
                 packed_seq_params,
                 inference_context=inference_context,
             )
-        if self.offload_qkv_linear:
-            query = off_interface.group_commit(
-                query, name="qkv_linear", forced_released_tensors=[hidden_states]
-            )
+        query = qkv_linear_manager.group_offload(query, forced_released_tensors=[])

         # ===================================================
         # Adjust key, value for inference
@@ -278,6 +276,9 @@ def forward(
         # core attention computation
         # ==================================
         # Need corresponding TE change
+        core_attn_manager = off_interface(
+            self.offload_core_attention and self.training, query, "core_attn"
+        )
         if self.checkpoint_core_attention and self.training:
             core_attn_out = self._checkpointed_attention_forward(
                 query, key, value, attention_mask, packed_seq_params=packed_seq_params
@@ -290,9 +291,7 @@ def forward(
                 # query representation.
                 extra_kwargs["x"] = hidden_states
                 extra_kwargs["qr"] = q_compressed
-            with off_interface(
-                self.offload_core_attention and self.training, query, "core_attn"
-            ) as query:
+            with core_attn_manager as query:
                 core_attn_out = self.core_attention(
                     query,
                     key,
@@ -322,10 +321,9 @@ def forward(
                 # Only rearrange if not in absorption mode (Flash MLA handles format correctly)
                 if not inference_context.is_decode_only():
                     core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)')
-        if self.offload_core_attention and self.training:
-            core_attn_out = off_interface.group_commit(
-                core_attn_out, name="core_attn", forced_released_tensors=[query, key, value]
-            )
+        core_attn_out = core_attn_manager.group_offload(
+            core_attn_out, forced_released_tensors=[query, key, value]
+        )

         # We are doing absorption with cache mla latents and decode mode.
         if self.cache_mla_latents and inference_context.is_decode_only():
@@ -351,12 +349,10 @@ def forward(
         # =================
         # Output. [sq, b, h]
         # =================
-        with off_interface(self.offload_attn_proj, core_attn_out, "attn_proj") as core_attn_out:
+        attn_proj_manager = off_interface(self.offload_attn_proj, core_attn_out, "attn_proj")
+        with attn_proj_manager as core_attn_out:
             output, bias = self.linear_proj(core_attn_out)
-        if self.offload_attn_proj:
-            output = off_interface.group_commit(
-                output, name="attn_proj", forced_released_tensors=[core_attn_out]
-            )
+        output = attn_proj_manager.group_offload(output, forced_released_tensors=[core_attn_out])

         return output, bias
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index d48e29c1e71..3f6056e9ae6 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -943,6 +943,15 @@ class TransformerConfig(ModelParallelConfig):
     min_offloaded_tensor_size: int = 1024 * 1024
     """The minimum size of the tensor to be offloaded."""

+    delay_offload_until_cuda_graph: bool = False
+    """If True, delay offloading until CUDA graph replay to minimize CPU overhead."""
+
+    delta_offload_bytes_across_pp_ranks: int = 0
+    """Difference in offloaded bytes between consecutive PP ranks, used to balance the
+    offloading load."""
+
+    activation_offload_fraction: float = 1.0
+    """The fraction of activations to offload, in range [0, 1]."""
+
     def __post_init__(self):
         """Python dataclass method that is used to modify attributes after initialization. See
         https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
@@ -1352,6 +1361,41 @@ def __post_init__(self):
                 "because the input of attn_proj is the output of core_attn, "
                 "which is needed in core_attn.backward()."
             )
+        if self.delay_offload_until_cuda_graph:
+            assert (
+                self.transformer_impl == "transformer_engine"
+            ), "delay_offload_until_cuda_graph requires transformer_impl='transformer_engine'."
+        assert (
+            self.min_offloaded_tensor_size >= 0
+        ), "min_offloaded_tensor_size must be non-negative."
+        assert (
+            self.activation_offload_fraction >= 0 and self.activation_offload_fraction <= 1
+        ), "activation_offload_fraction must be in range [0, 1]."
+        assert (
+            self.delta_offload_bytes_across_pp_ranks >= 0
+        ), "delta_offload_bytes_across_pp_ranks must be non-negative."
+        if self.external_cuda_graph or self.enable_cuda_graph:
+            assert (
+                self.cuda_graph_impl == "transformer_engine"
+            ), "cuda_graph_impl must be transformer_engine when enabling offloading."
+        if self.cuda_graph_impl == "transformer_engine":
+            assert (
+                self.cuda_graph_scope is not None
+            ), "cuda_graph_scope must be set when enabling offloading."
+            if (
+                "attn" in self.cuda_graph_scope
+                or "moe_router" in self.cuda_graph_scope
+                or "moe_preprocess" in self.cuda_graph_scope
+                or CudaGraphScope.attn in self.cuda_graph_scope
+                or CudaGraphScope.moe_router in self.cuda_graph_scope
+                or CudaGraphScope.moe_preprocess in self.cuda_graph_scope
+            ):
+                assert (
+                    "attn_norm" not in self.offload_modules
+                ), "attn_norm is the entry point of the cuda graph, so it can't be offloaded."
+                assert (
+                    "mlp_norm" not in self.offload_modules
+                ), "mlp_norm crosses the cuda graph boundary, so it can't be offloaded."
         if (
             self.num_layers_in_first_pipeline_stage is not None
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index 58fe690c553..f9609165f24 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -30,6 +30,7 @@
     deprecate_inference_params,
     get_pg_rank,
     is_te_min_version,
+    is_torch_min_version,
     log_single_rank,
     make_viewless_tensor,
     nvtx_range_pop,
@@ -461,17 +462,9 @@ def can_recompute_pre_mlp_layernorm_for_cudagraph():
         if "mlp" in self.config.recompute_modules:
             if not self.is_moe_layer:
                 self.recompute_mlp = True
-        self.offload_attn_norm = (
-            self.config.fine_grained_activation_offloading
-            and "attn_norm" in self.config.offload_modules
-            and not isinstance(self.input_layernorm, IdentityOp)
-        )
-        self.offload_mlp_norm = (
-            self.config.fine_grained_activation_offloading
-            and "mlp_norm" in self.config.offload_modules
-            and not isinstance(self.pre_mlp_layernorm, IdentityOp)
-        )
+        self._set_offload_modules()
+        self.mlp_norm_manager = None

         # @jcasper how should we handle nvfuser?
         # Set bias+dropout+add fusion grad_enable execution handler.
        # TORCH_MAJOR = int(torch.__version__.split('.')[0])
@@ -572,14 +565,15 @@ def _forward_attention(
                 residual = residual.float()

         # Optional Input Layer norm
+        attn_norm_manager = off_interface(self.offload_attn_norm, hidden_states, "attn_norm")
         if self.recompute_input_layernorm:
             self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
-            with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states:
+            with attn_norm_manager as hidden_states:
                 input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
                     apply_module(self.input_layernorm), hidden_states
                 )
         else:
-            with off_interface(self.offload_attn_norm, hidden_states, "attn_norm") as hidden_states:
+            with attn_norm_manager as hidden_states:
                 input_layernorm_output = apply_module(self.input_layernorm)(hidden_states)

         using_fused_tp_inference_kernel = (not self.training) and (
@@ -631,10 +625,9 @@ def _forward_attention(

         # Delay the offload of the attention norm until after the self_attn_bda has been computed
         # because the residual is needed in the self_attn_bda.
-        if self.offload_attn_norm:
-            hidden_states = off_interface.group_commit(
-                hidden_states, name="attn_norm", forced_released_tensors=[residual]
-            )
+        hidden_states = attn_norm_manager.group_offload(
+            hidden_states, forced_released_tensors=[residual]
+        )

         # Residual connection.
         residual = hidden_states
@@ -687,14 +680,15 @@ def _forward_pre_mlp_layernorm(self, hidden_states: Tensor):
             FineGrainedActivationOffloadingInterface as off_interface,
         )

+        self.mlp_norm_manager = off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm")
         if self.recompute_pre_mlp_layernorm:
             self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
-            with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states:
+            with self.mlp_norm_manager as hidden_states:
                 pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
                     apply_module(self.pre_mlp_layernorm), hidden_states
                 )
         else:
-            with off_interface(self.offload_mlp_norm, hidden_states, "mlp_norm") as hidden_states:
+            with self.mlp_norm_manager as hidden_states:
                 pre_mlp_layernorm_output = apply_module(self.pre_mlp_layernorm)(hidden_states)

         return pre_mlp_layernorm_output
@@ -814,9 +808,6 @@ def _forward_post_mlp(
         Returns:
             output (Tensor): Transformed hidden states of shape [s, b, h].
         """
-        from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
-            FineGrainedActivationOffloadingInterface as off_interface,
-        )

         using_fused_tp_inference_kernel = (not self.training) and (
             self.config.inference_fuse_tp_communication
@@ -845,10 +836,11 @@ def _forward_post_mlp(
         nvtx_range_pop(suffix="mlp_bda")
         # Delay the offload of the mlp norm until after the mlp_bda has been computed
         # because the residual is needed in the mlp_bda.
-        if self.offload_mlp_norm:
-            hidden_states = off_interface.group_commit(
-                hidden_states, name="mlp_norm", forced_released_tensors=[residual]
+        if self.mlp_norm_manager is not None:
+            hidden_states = self.mlp_norm_manager.group_offload(
+                hidden_states, forced_released_tensors=[residual]
             )
+            self.mlp_norm_manager = None

         # Jit compiled function creates 'view' tensor. This tensor
         # potentially gets saved in the MPU checkpoint function context,
@@ -1003,6 +995,22 @@ def _te_cuda_graph_capture(self, *args, **kwargs):
            attribute can be set to control the scope of the CUDA graph.
         2. If context is None, it cannot be returned as output.
         """
+        from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
+            FineGrainedActivationOffloadingInterface as off_interface,
+        )
+
+        # Record the backward event on cuda graph stream in backward pass.
+        # This is to ensure the main stream waits for computing on cuda graph stream to complete,
+        # and overlaps with the H2D transfer on reload stream.
+        if self.offload_module_in_cuda_graph:
+            if len(args) > 0:
+                hidden_states = args[0]
+                hidden_states = off_interface.backward_record(hidden_states)
+                args = (hidden_states,) + args[1:]
+            else:
+                hidden_states = kwargs.pop("hidden_states")
+                hidden_states = off_interface.backward_record(hidden_states)
+                kwargs["hidden_states"] = hidden_states
         context = None
         if not self.config.cuda_graph_scope or CudaGraphScope.attn in self.config.cuda_graph_scope:
             hidden_states, context = self._forward_attention(*args, **kwargs)
@@ -1030,6 +1038,15 @@ def _te_cuda_graph_capture(self, *args, **kwargs):
             cuda_graph_outputs = list(hidden_states)
         if context is not None:
             cuda_graph_outputs.append(context)
+        # Record the forward event on cuda graph stream for cuda graph capture.
+        # This is to ensure the main stream waits for computing on cuda graph stream to complete,
+        # and overlaps with the D2H transfer on offloading stream.
+        if self.offload_module_in_cuda_graph:
+            from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
+                FineGrainedActivationOffloadingInterface as off_interface,
+            )
+
+            off_interface.forward_record()
         return tuple(cuda_graph_outputs)

     def _te_cuda_graph_replay(self, *args, **kwargs):
@@ -1053,8 +1070,33 @@ def _te_cuda_graph_replay(self, *args, **kwargs):
                 "For inference cuda graph, please use cuda_graph_impl=local instead."
             )

+        if self.config.delay_offload_until_cuda_graph:
+            from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
+                FineGrainedActivationOffloadingInterface as off_interface,
+            )
+
+            off_interface.enter_replay()
+
+        try:
+            return self._te_cuda_graph_replay_impl(args, kwargs, context)
+        finally:
+            if self.config.delay_offload_until_cuda_graph:
+                off_interface.exit_replay()
+
+    def _te_cuda_graph_replay_impl(self, args, kwargs, context):
+        """Implementation of _te_cuda_graph_replay, separated for replay mode cleanup."""
         cuda_graph_output = list(super()._te_cuda_graph_replay(*args, **kwargs))

+        # Flush delayed offload groups from previous layers after graph replay.
+        # The CPU is idle during the sync between graph replay and a2a comm,
+        # so we use that time to execute the delayed offload operations.
+        if self.config.delay_offload_until_cuda_graph:
+            from megatron.core.pipeline_parallel.fine_grained_activation_offload import (
+                FineGrainedActivationOffloadingInterface as off_interface,
+            )
+
+            off_interface.flush_delayed_groups()
+
         if kwargs.get('context') is not None:
             context = cuda_graph_output.pop()
@@ -1253,6 +1295,47 @@ def __call__(self, *args, **kwargs):

         return super().__call__(*args, **kwargs)

+    def _set_offload_modules(self):
+        """Set the offload modules for the transformer layer."""
+        if self.config.fine_grained_activation_offloading:
+            self.offload_attn_norm = "attn_norm" in self.config.offload_modules and not isinstance(
+                self.input_layernorm, IdentityOp
+            )
+            self.offload_qkv_linear = "qkv_linear" in self.config.offload_modules
+            self.offload_core_attn = "core_attn" in self.config.offload_modules
+            self.offload_attn_proj = "attn_proj" in self.config.offload_modules
+            self.offload_mlp_norm = "mlp_norm" in self.config.offload_modules and not isinstance(
+                self.pre_mlp_layernorm, IdentityOp
+            )
+            self.offload_expert_fc1 = "expert_fc1" in self.config.offload_modules
+            self.offload_moe_act = "moe_act" in self.config.offload_modules
+        else:
+            self.offload_attn_norm = False
+            self.offload_qkv_linear = False
+            self.offload_core_attn = False
+            self.offload_attn_proj = False
+            self.offload_mlp_norm = False
+            self.offload_expert_fc1 = False
+            self.offload_moe_act = False
+        # Set the offload module in cuda graph flag.
+        self.offload_module_in_cuda_graph = False
+        if CudaGraphScope.attn in self.config.cuda_graph_scope:
+            if self.offload_core_attn or self.offload_attn_proj or self.offload_qkv_linear:
+                self.offload_module_in_cuda_graph = True
+        if not self.is_moe_layer and CudaGraphScope.mlp in self.config.cuda_graph_scope:
+            if self.offload_mlp_norm:
+                self.offload_module_in_cuda_graph = True
+        if self.offload_module_in_cuda_graph:
+            assert is_torch_min_version(
+                "2.9.0a0"
+            ), "Offloading modules captured in cuda graph requires torch>=2.9.0."
+            assert is_te_min_version(
+                "2.13.0"
+            ), "Offloading modules captured in cuda graph requires TE>=2.13.0."
+            assert (
+                self.config.cuda_graph_warmup_steps > 0
+            ), "Fine-grained activation offloading needs cuda_graph_warmup_steps > 0."
+
     def get_layer_norm_weights(self):
         """
         Get the weights of all layernorms (attention and MLP) in the transformer layer.
diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json
index d5ced620365..8fbe219530d 100644
--- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json
+++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_fine_grained_offloading/golden_values_dev_dgx_h100.json
@@ -341,4 +341,4 @@
         "50": 1.89832
     }
   }
-}
\ No newline at end of file
+}
diff --git a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json
index 57848f8130e..03c8cb800c9 100644
--- a/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json
+++ b/tests/functional_tests/test_cases/moe/gpt3_moe_mcore_te_tp2_pp2_ep4_etp1_no_mtp_no_a2a_ovlp_fine_grained_offloading/golden_values_dev_dgx_h100.json
@@ -284,4 +284,4 @@
         "50": 1.93018
     }
   }
-}
\ No newline at end of file
+}
diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py
index 5ecd4e92d80..00687b8cdac 100644
--- a/tests/unit_tests/models/test_mamba_moe_model.py
+++ b/tests/unit_tests/models/test_mamba_moe_model.py
@@ -274,6 +274,9 @@
     "fine_grained_activation_offloading": False,
     "min_offloaded_tensor_size": 1024 * 1024,
     "offload_modules": [],
+    "delay_offload_until_cuda_graph": False,
+    "delta_offload_bytes_across_pp_ranks": 0,
+    "activation_offload_fraction": 1.0,
     "hybrid_context_parallel": False,
     "max_seqlen_per_dp_cp_rank": None,
     "sequence_packing_scheduler": None,
diff --git a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py
index 558c6934a0c..41b9391e171 100644
--- a/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py
+++ b/tests/unit_tests/pipeline_parallel/test_fine_grained_activation_offloading.py
@@ -318,7 +318,6 @@ def test_gpt_fine_grained_activation_offloading_correctness_and_memory(
         ("alltoall", True, ["mlp_norm"]),
         ("alltoall", False, ["expert_fc1"]),
         ("alltoall", False, ["moe_act"]),
-        ("alltoall", False, ["mlp_norm", "expert_fc1", "moe_act"]),
         (
             "alltoall",
             True,
@@ -571,3 +570,338 @@ def _run_schedule_1f1b_two_microbatches(
         )
     finally:
         Utils.destroy_model_parallel()
+
+
+# =============================================================================
+# CUDA Graph + Fine-grained Activation Offloading Tests
+# =============================================================================
+
+
+def _build_gpt_model_with_cuda_graph(
+    *,
+    seed: int,
+    num_layers: int,
+    hidden_size: int,
+    num_attention_heads: int,
+    vocab_size: int,
+    seq_length: int,
+    num_experts: Optional[int],
+    fine_grained_activation_offloading: bool,
+    offload_modules: Optional[List[str]],
+    min_offloaded_tensor_size: int,
+    is_mla: bool,
+    cuda_graph_impl: str,
+    cuda_graph_scope: Optional[List[str]],
+    cuda_graph_warmup_steps: int,
+    delay_offload_until_cuda_graph: bool = False,
+    activation_offload_fraction: float = 1.0,
+) -> GPTModel:
+    """Build a GPTModel with CUDA Graph support and fine-grained activation offloading."""
+    model_parallel_cuda_manual_seed(seed)
+    torch.manual_seed(seed)
+    ConfigClass = MLATransformerConfig if is_mla else TransformerConfig
+    transformer_config = ConfigClass(
+        num_layers=num_layers,
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        use_cpu_initialization=True,
+        attention_backend=AttnBackend.unfused,
+        bf16=True,
+        # Recompute
+        recompute_modules=["layernorm", "moe_act"] if num_experts is not None else ["layernorm"],
+        recompute_granularity="selective",
+        # MoE
+        num_moe_experts=num_experts,
+        moe_grouped_gemm=(num_experts is not None),
+        # Fine-grained activation offloading
+        fine_grained_activation_offloading=fine_grained_activation_offloading,
+        offload_modules=offload_modules,
+        min_offloaded_tensor_size=min_offloaded_tensor_size,
+        delay_offload_until_cuda_graph=delay_offload_until_cuda_graph,
+        activation_offload_fraction=activation_offload_fraction,
+        # CUDA Graph settings
+        cuda_graph_impl=cuda_graph_impl,
+        cuda_graph_scope=cuda_graph_scope,
+        cuda_graph_warmup_steps=cuda_graph_warmup_steps,
+        use_te_rng_tracker=True,
+    )
+    gpt_model = GPTModel(
+        config=transformer_config,
+        transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(
+            num_experts=num_experts,
+            moe_grouped_gemm=num_experts is not None,
+            moe_use_legacy_grouped_gemm=False,
+            multi_latent_attention=is_mla,
+        ),
+        vocab_size=vocab_size,
+        max_sequence_length=seq_length,
+    ).bfloat16()
+    return gpt_model
+
+
+def _run_iters_with_cuda_graph(
+    model: GPTModel,
+    *,
+    input_ids: torch.Tensor,
+    position_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    num_warmup_iters: int,
+    num_measure_iters: int,
+    enable_offload_reset: bool,
+) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], int]:
+    """
+    Run multiple forward+backward iterations with CUDA graph capture.
+
+    Returns:
+        - logits from last iteration (CPU float32)
+        - selected grads from last iteration (CPU float32)
+        - peak_memory_allocated (bytes) during measurement iterations
+    """
+    from megatron.core.transformer.cuda_graphs import _CudagraphGlobalRecord, delete_cuda_graphs
+
+    if enable_offload_reset:
+        off_interface.reset()
+
+    # Warmup iterations (before CUDA graph capture)
+    for _ in range(num_warmup_iters):
+        if enable_offload_reset:
+            off_interface.reset()
+        logits = model(
+            input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask
+        )
+        loss = logits.float().sum()
+        loss.backward()
+        # Zero grads for next iteration
+        for p in model.parameters():
+            if p.grad is not None:
+                p.grad.zero_()
+
+    # Trigger post-warmup offload decisions
+    if enable_offload_reset:
+        off_interface.reset()
+
+    # Create CUDA graphs after warmup
+    _CudagraphGlobalRecord.create_cudagraphs()
+
+    # Measurement iterations (with CUDA graph replay)
+    torch.cuda.reset_peak_memory_stats()
+    for i in range(num_measure_iters):
+        if enable_offload_reset:
+            off_interface.reset()
+        logits = model(
+            input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask
+        )
+        loss = logits.float().sum()
+        loss.backward()
+        if i < num_measure_iters - 1:
+            for p in model.parameters():
+                if p.grad is not None:
+                    p.grad.zero_()
+
+    torch.cuda.synchronize()
+    peak_bytes = int(torch.cuda.max_memory_allocated())
+
+    # Capture grads from last iteration
+    grads: Dict[str, torch.Tensor] = {}
+    for name, p in model.named_parameters():
+        grads[name] = p.grad.detach().float().cpu() if p.grad is not None else None
+
+    # Cleanup CUDA graphs
+    delete_cuda_graphs()
+
+    return logits.detach().float().cpu(), grads, peak_bytes
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for offloading tests.")
+@pytest.mark.skipif(
+    not is_te_min_version("2.13.0"), reason="CUDA Graph with TE RNG tracker requires TE >= 2.13.0"
+)
+@pytest.mark.parametrize(
+    "is_mla, offload_modules, cuda_graph_scope, activation_offload_fraction, delay_offload",
+    [
+        # MoE model with attention CUDA graph + attn offloading
+        (False, ["core_attn", "attn_proj"], ["attn", "moe_router"], 1.0, True),
+        (False, ["expert_fc1", "moe_act"], ["attn", "moe_router", "moe_preprocess"], 1.0, True),
+        (False, ["core_attn", "attn_proj", "expert_fc1"], ["attn", "moe_router"], 1.0, True),
+        (
+            False,
+            ["core_attn", "attn_proj", "expert_fc1", "moe_act"],
+            ["attn", "moe_router"],
+            1.0,
+            True,
+        ),
+        (
+            False,
+            ["core_attn", "expert_fc1", "moe_act"],
+            ["attn", "moe_router", "moe_preprocess"],
+            1.0,
+            True,
+        ),
+        (
+            True,
+            ["core_attn", "attn_proj", "expert_fc1", "moe_act"],
+            ["attn", "moe_router", "moe_preprocess"],
+            1.0,
+            True,
+        ),
+        # Test activation_offload_fraction parameter
+        (False, ["core_attn", "attn_proj", "expert_fc1"], ["attn", "moe_router"], 0.0, True),
+        (False, ["core_attn", "attn_proj", "expert_fc1"], ["attn", "moe_router"], 0.5, True),
+        # Test delay_offload_until_cuda_graph parameter
+        (False, ["core_attn", "attn_proj", "expert_fc1"], ["attn", "moe_router"], 1.0, False),
+    ],
+)
+def test_fine_grained_activation_offloading_with_cuda_graph(
+    is_mla: bool,
+    offload_modules: List[str],
+    cuda_graph_scope: List[str],
+    activation_offload_fraction: float,
+    delay_offload: bool,
+):
+    """
+    Test fine-grained activation offloading combined with CUDA graph capture.
+
+    Verifies:
+    - Forward output correctness with CUDA graph + offloading
+    - Backward gradient correctness
+    - Memory savings from offloading are preserved with CUDA graphs
+    - Different activation_offload_fraction values work correctly
+    - Both delay_offload_until_cuda_graph=True/False produce correct results
+    """
+    from megatron.core.tensor_parallel.random import initialize_rng_tracker
+
+    os.environ.pop("NVTE_FUSED_ATTN", None)
+    os.environ.pop("NVTE_FLASH_ATTN", None)
+    os.environ.pop("NVTE_UNFUSED_ATTN", None)
+
+    initialize_rng_tracker(use_te_rng_tracker=True, force_reset=True)
+    Utils.initialize_model_parallel(tensor_model_parallel_size=1, pipeline_model_parallel_size=1)
+
+    seed = 123
+    num_experts = 4  # Always MoE model
+    num_layers = 4  # Smaller for faster test with CUDA graphs
+    hidden_size = 1024
+    num_attention_heads = 8
+    vocab_size = 512
+    seq_length = 512
+    micro_batch_size = 2
+    device = torch.device("cuda")
+    cuda_graph_warmup_steps = 3
+
+    input_ids, position_ids, attention_mask = _make_gpt_inputs(
+        seq_length=seq_length, micro_batch_size=micro_batch_size, device=device
+    )
+
+    off_interface.reset_instance()
+
+    try:
+        # 1) Baseline: CUDA graph enabled, offloading disabled
+        _reset_cuda_memory()
+        base_model = _build_gpt_model_with_cuda_graph(
+            seed=seed,
+            num_layers=num_layers,
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            vocab_size=vocab_size,
+            seq_length=seq_length,
+            num_experts=num_experts,
+            fine_grained_activation_offloading=False,
+            offload_modules=None,
+            min_offloaded_tensor_size=1024 * 1024,
+            is_mla=is_mla,
+            cuda_graph_impl="transformer_engine",
+            cuda_graph_scope=cuda_graph_scope,
+            cuda_graph_warmup_steps=cuda_graph_warmup_steps,
+        ).cuda()
+        base_model.train()
+
+        base_logits, base_grads, base_peak = _run_iters_with_cuda_graph(
+            base_model,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            num_warmup_iters=cuda_graph_warmup_steps,
+            num_measure_iters=2,
+            enable_offload_reset=False,
+        )
+        del base_model
+        _reset_cuda_memory()
+
+        # 2) Test: CUDA graph enabled + offloading enabled
+        off_interface.reset_instance()
+
+        off_model = _build_gpt_model_with_cuda_graph(
+            seed=seed,
+            num_layers=num_layers,
+            hidden_size=hidden_size,
+            num_attention_heads=num_attention_heads,
+            vocab_size=vocab_size,
+            seq_length=seq_length,
+            num_experts=num_experts,
+            fine_grained_activation_offloading=True,
+            offload_modules=offload_modules,
+            min_offloaded_tensor_size=1024,  # Force offloading for determinism
+            is_mla=is_mla,
+            cuda_graph_impl="transformer_engine",
+            cuda_graph_scope=cuda_graph_scope,
+            cuda_graph_warmup_steps=cuda_graph_warmup_steps,
+            delay_offload_until_cuda_graph=delay_offload,
+            activation_offload_fraction=activation_offload_fraction,
+        ).cuda()
+        off_model.train()
+
+        off_logits, off_grads, off_peak = _run_iters_with_cuda_graph(
+            off_model,
+            input_ids=input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            num_warmup_iters=cuda_graph_warmup_steps,
+            num_measure_iters=2,
+            enable_offload_reset=True,
+        )
+        del off_model
+        _reset_cuda_memory()
+
+        # 3) Correctness checks
+        assert torch.allclose(
+            off_logits, base_logits, rtol=1e-2, atol=1e-2
+        ), f"Logits mismatch: max_diff={torch.max(torch.abs(off_logits - base_logits))}"
+        assert set(off_grads.keys()) == set(base_grads.keys())
+        for name, gb in base_grads.items():
+            go = off_grads[name]
+            if gb is None or go is None:
+                assert gb is None and go is None, f"Grad None mismatch for {name}"
+                continue
+            assert torch.allclose(
+                go, gb, rtol=1e-2, atol=1e-2
+            ), f"Grad mismatch for {name}: max_diff={torch.max(torch.abs(go - gb))}"
+
+        # 4) Memory checks - offloading should still reduce memory with CUDA graphs
+        saved_mib = (base_peak - off_peak) / (1024**2)
+        print(
+            f"CUDA Graph + Offload test (fraction={activation_offload_fraction}, delay={delay_offload}): "
+            f"base_peak={base_peak/(1024**2):.2f}MiB, "
+            f"off_peak={off_peak/(1024**2):.2f}MiB, "
f"saved={saved_mib:.2f}MiB" + ) + + # Basic sanity checks + assert not torch.isnan(off_logits).any(), "NaN detected in logits" + assert not torch.isinf(off_logits).any(), "Inf detected in logits" + + # Check gradients are valid + for name, g in off_grads.items(): + if g is not None: + assert not torch.isnan(g).any(), f"NaN detected in grad for {name}" + assert not torch.isinf(g).any(), f"Inf detected in grad for {name}" + + # Note: With CUDA graphs, memory behavior may differ from eager mode. + # We check that offloading doesn't significantly increase memory. + # In some cases, graph capture overhead may offset offload savings. + assert saved_mib >= -DELTA, ( + f"Offloading with CUDA graph significantly increased memory: " + f"saved={saved_mib:.2f}MiB (negative means increase)" + ) + + finally: + Utils.destroy_model_parallel()