From af236f11182bfe10848dc1b3e82461fd2c1478c9 Mon Sep 17 00:00:00 2001 From: geyuhong Date: Tue, 19 Aug 2025 01:42:46 +0800 Subject: [PATCH 01/35] zero-overhead activation offload --- megatron/core/model_parallel_config.py | 3 + .../core/pipeline_parallel/cpu_offload.py | 419 ++++++++++++++++++ megatron/core/pipeline_parallel/schedules.py | 10 +- megatron/core/transformer/attention.py | 162 ++++++- megatron/core/transformer/moe/moe_layer.py | 5 + megatron/core/transformer/moe/moe_utils.py | 47 ++ megatron/core/transformer/moe/router.py | 1 + .../core/transformer/moe/token_dispatcher.py | 47 ++ .../core/transformer/transformer_config.py | 51 +++ .../core/transformer/transformer_layer.py | 57 ++- 10 files changed, 788 insertions(+), 14 deletions(-) create mode 100644 megatron/core/pipeline_parallel/cpu_offload.py diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 17efbc7b843..31465ea534b 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -315,6 +315,9 @@ class ModelParallelConfig: rank 1 | 0 1 2 0 1 2 3 4 3 4 """ + offload_mlp_input: bool = False + """If true, offloads the MLP input to CPU. 
This is useful for large.""" + ################### # CPU Offloading ################### diff --git a/megatron/core/pipeline_parallel/cpu_offload.py b/megatron/core/pipeline_parallel/cpu_offload.py new file mode 100644 index 00000000000..1ccef70febf --- /dev/null +++ b/megatron/core/pipeline_parallel/cpu_offload.py @@ -0,0 +1,419 @@ +from collections import deque +import torch +from megatron.core import parallel_state +from typing import Any +from transformer_engine.pytorch.float8_tensor import Float8Tensor + +# cpu offload for pipeline + +class PipelineOffloadManager: + OFFLOAD_MGR = None + @classmethod + def get_instance(cls): + if cls.OFFLOAD_MGR is None: + cls.OFFLOAD_MGR = PipelineOffloadManager() + return cls.OFFLOAD_MGR + + def __init__(self): + self._queue = deque() + self._vpp = parallel_state.get_virtual_pipeline_model_parallel_world_size() + + # cache vpp - 1 stages + self._stages = [[] for _ in range(self._vpp)] + # allocate streams and events for synchronization + self._d2h_stream = torch.cuda.Stream() + self._h2d_stream = torch.cuda.Stream() + self._f_event = torch.cuda.Event() + self._b_event = torch.cuda.Event() + self._f_event.record(self._d2h_stream) + self._b_event.record(self._h2d_stream) + self.reset() + + + @property + def d2h_stream(self): + return self._d2h_stream + + @property + def h2d_stream(self): + return self._h2d_stream + + def reset(self): + self._inside_context = False + self._cur_forward_chunk = None + self._cur_backward_chunk = None + self._first_last_vpp_rank = True + + def flush(self): + # put into the queue in the backward order + if len(self._stages[0]) == len(self._stages[-1]): + lens = [len(e) for e in self._stages] + assert min(lens) == max(lens) + self._stages[-1] = [] + for chunks in reversed(self._stages): + for chunk in chunks: + self.push(chunk) + for i in range(self._vpp): + self._stages[i] = [] + + def push(self, handler): + self._queue.append(handler) + + def pop(self): + assert self.size() + self._cur_backward_chunk 
= self._queue.popleft() + + + def front(self): + if not len(self._queue): + return None + f = self._queue.popleft() + self._queue.appendleft(f) + return f + + def size(self): + return len(self._queue) + + def reset_chunk_handler(self, num_layer, offload_mlp_input=True): + cur_vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + + first_last_vpp_rank = self._first_last_vpp_rank + # rewind + if cur_vpp_rank == self._vpp - 1: + self.flush() + first_last_vpp_rank = first_last_vpp_rank and (cur_vpp_rank == self._vpp - 1) + cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload_mlp_input) + # save for latter push + self._stages[cur_vpp_rank].append(cur_chunk) + if cur_vpp_rank == self._vpp - 1: + self._first_last_vpp_rank = False + self.push(cur_chunk) + self.flush() + self._cur_forward_chunk = cur_chunk + cur_chunk.vpp_rank = cur_vpp_rank + + + def cur_forward_chunk(self): + return self._cur_forward_chunk + + def cur_backward_chunk(self): + return self._cur_backward_chunk + + def __enter__(self): + self.OFFLOAD_MGR + self.inside_context = True + + torch._C._autograd._push_saved_tensors_default_hooks( + self.on_save_for_backward, self.on_get_saved_tensor + ) + + def __exit__(self, *args: Any): + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + assert self.inside_context + if self.cur_forward_chunk().is_registered_tensor(tensor.data_ptr()): + tensor.offloading_mlp_input = True + return self.cur_forward_chunk().tensor_push(tensor) + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + return self.cur_backward_chunk().tensor_pop(saved_state) + + + + + +class ChunkOffloadHandler: + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + fp8_offload = isinstance(src_tensor, Float8Tensor) + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=torch.uint8 if fp8_offload else src_tensor.dtype, + 
layout=src_tensor.layout, + device="cpu", + pin_memory=pin_memory, + ) + + if fp8_offload: + cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup) + + cpu_backup.copy_(src_tensor, non_blocking=pin_memory) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True): + self._num_layers = num_layer + # Data Structure to maintain reference to activation tensors + self._tensor_tag_to_state = {} + # Tracking the number of layers offloaded + self._offloaded_group_count = 0 + self._is_first_last_vpp_chunk = is_first_last_vpp_chunk + + self._layer_index = 0 + self._tensor_count_current_layer = 0 + self.cur_backward_tensor_count = 0 + + self.tensor_need_offloading_checker = None + self.torch_tensor_count = 0 + self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream + self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream + self._f_event = PipelineOffloadManager.get_instance()._f_event + self._b_event = PipelineOffloadManager.get_instance()._b_event + self.do_offload = offload + + self._offload_tensor_ptrs = deque() + + def is_first_last_layer(self): + return self._is_first_last_vpp_chunk and self.is_last_layer() + + def is_last_layer(self): + return (self._layer_index == self._num_layers - 1) + + def tensor_push(self, tensor): + torch_stray_tensor = isinstance( + tensor, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + ) + + if not torch_stray_tensor:# True + # obtain a unique tensor tag + tensor_tag = (self._layer_index, self._tensor_count_current_layer) + self._tensor_count_current_layer += 1 + assert tensor_tag not in self._tensor_tag_to_state + self._tensor_tag_to_state[tensor_tag] = tensor 
+ else: + tensor_tag = (-1, self.torch_tensor_count) + self.torch_tensor_count += 1 + self._tensor_tag_to_state[tensor_tag] = tensor + return tensor_tag + + def tensor_pop(self, tensor_tag): + assert tensor_tag in self._tensor_tag_to_state, f"{tensor_tag}, {self._tensor_tag_to_state.keys()}" + tensor = self._tensor_tag_to_state.pop(tensor_tag) + assert not isinstance(tensor, tuple) + return tensor + + def set_offloading_checker(self, check_func): + self.tensor_need_offloading_checker = check_func + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + if not self.do_offload: + return + assert not self.is_first_last_layer() + with torch.cuda.stream(self.d2h_stream): + for tensor_tag, state in self._tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + tensor_on_device = state + # if offload, return the reference to cpu copy + if self.tensor_need_offloading_checker is not None and self.tensor_need_offloading_checker(tensor_on_device): + #print(f"offload {group_to_offload}") + state = self.offload(tensor_on_device) + tensor_on_device.record_stream(self.d2h_stream) + self._tensor_tag_to_state[tensor_tag] = state + self._offloaded_group_count = group_to_offload + 1 + self._f_event.record(self.d2h_stream) + + + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + if not self.do_offload: + return + no_tensors_cur_layer = True + with torch.cuda.stream(self.h2d_stream): + # move back tensors + # self._tensor_tag_to_state -> {tensor_tag: state} = {(_layer_index, _tensor_count_current_layer): tensor_on_device} + for tensor_label, state in self._tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload: + if isinstance(state, tuple): + recovered_tensor = self.reload(state) + self._tensor_tag_to_state[tensor_label] = recovered_tensor + break + for tensor_label, state in self._tensor_tag_to_state.items(): + group_id, _ = 
tensor_label + if group_id == group_to_reload: + if isinstance(state, tuple): + no_tensors_cur_layer = False + break + if no_tensors_cur_layer: + self._offloaded_group_count = group_to_reload + self._b_event.record(self.h2d_stream) + + def pre_reload_last_layer(self): + if not self.do_offload: + return + assert not self._is_first_last_vpp_chunk + if self._num_layers == self._offloaded_group_count: + self.bulk_reload_group(self._num_layers - 1) + # assert self._num_layers - 1 == self._offloaded_group_count + + + def should_bulk_offload(self): + if not self.do_offload: + return False + # first backward chunk + if self.is_first_last_layer(): + return False + + # if next backward chunk is this chunk (for last pp stage) + next_backward_chunk = PipelineOffloadManager.get_instance().get_instance().front() + if next_backward_chunk is not None and next_backward_chunk is self: + if self.is_last_layer(): + return False + + return True + + def forward_sync(self): + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + self._f_event.wait(torch.cuda.current_stream()) + #torch.cuda.empty_cache() + + + def bulk_offload(self, offloaded_call_back): + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + #torch.cuda.empty_cache() + if self.should_bulk_offload(): + self.bulk_offload_group(self._layer_index) + if offloaded_call_back is not None: + offloaded_call_back() + + + + + def on_group_commit_forward(self, offloaded_call_back): + # wait each other + self.forward_sync() + self.bulk_offload(offloaded_call_back) + self._layer_index = self._layer_index + 1 + self.cur_backward_tensor_count = self._tensor_count_current_layer + self._tensor_count_current_layer = 0 + + + def bulk_reload(self): + if self.do_offload: + assert self._layer_index == self._offloaded_group_count + if self._layer_index: + # load next layer + self.bulk_reload_group(self._layer_index - 1) + else: + next_backward_chunk = PipelineOffloadManager.get_instance().front() + if next_backward_chunk is not None: + 
next_backward_chunk.pre_reload_last_layer() + + def backward_sync(self): + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + self._b_event.wait(torch.cuda.current_stream()) + + + def on_group_commit_backward(self): + cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() + if not cur_backward_chunk is self: + PipelineOffloadManager.get_instance().pop() + cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() + assert cur_backward_chunk is self + self._layer_index = self._layer_index - 1 + self.backward_sync() + # layer index already loaded back + # self.bulk_reload() + + def on_group_start_forward(self): + pass + + def on_group_start_backward(self): + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + self.bulk_reload() + + def register_offload_tensor(self, tensor): + self._offload_tensor_ptrs.append(tensor.data_ptr()) + + def is_registered_tensor(self, tensor_ptr: int) -> bool: + if len(self._offload_tensor_ptrs) == 0: + return False + is_registered = tensor_ptr == self._offload_tensor_ptrs[0] + if is_registered: + self._offload_tensor_ptrs.popleft() + return is_registered + + + + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. 
+ """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler, offloaded_call_back): + # pylint: disable=missing-function-docstring + + cpu_offload_handler.on_group_commit_forward(offloaded_call_back) + ctx.cpu_offload_handler = cpu_offload_handler + + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None, None + + +def group_prefetch_offload_commit(tensor, offloaded_call_back=None): + cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() + return GroupCommitFunction.apply(tensor, cur_forward_chunk, offloaded_call_back) + + +class GroupStartFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. 
+ """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring + # cpu_offload_handler.on_group_start_forward() + ctx.cpu_offload_handler = cpu_offload_handler + + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_start_backward() + return grad_output, None + +def group_prefetch_offload_start(tensor): + cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() + return GroupStartFunction.apply(tensor, cur_forward_chunk) + + +def offloading_checker(tensor): + return hasattr(tensor, 'offloading_mlp_input') and tensor.offloading_mlp_input diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index fb1ff086ecb..1262868db04 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -18,7 +18,10 @@ ) from megatron.core.process_groups_config import GradFinalizeProcessGroups from megatron.core.transformer.cuda_graphs import create_cudagraphs -from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler +from megatron.core.transformer.moe.router import ( + MoEAuxLossAutoScaler, + MoEPositiveAuxLossAutoScaler, +) from megatron.core.utils import ( drain_embedding_wgrad_compute, get_attr_wrapped_model, @@ -32,6 +35,7 @@ combined_1f1b_schedule_for_interleaved_pipelining, combined_1f1b_schedule_for_no_pipelining, ) +from .cpu_offload import PipelineOffloadManager, offloading_checker # Types Shape = Union[List[int], torch.Size] @@ -266,6 +270,8 @@ def forward_step_calc_loss( if config.calculate_per_token_loss: MoEAuxLossAutoScaler.set_loss_scale(loss_scale) else: + if config.offload_activation: + MoEPositiveAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches) MoEAuxLossAutoScaler.set_loss_scale(loss_scale / 
num_microbatches) # Set the loss scale for Multi-Token Prediction (MTP) loss. @@ -903,6 +909,8 @@ def forward_backward_pipelining_with_interleaving( adjust_tensor_shapes_fn is None ), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism" + if not forward_only: + PipelineOffloadManager.get_instance().reset() if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index c749bac4373..8f7e138ab5f 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -37,6 +37,11 @@ from .enums import AttnMaskType from .transformer_config import TransformerConfig +from megatron.core.pipeline_parallel.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, +) try: from einops import rearrange @@ -179,6 +184,21 @@ def __init__( and "core_attn" in self.config.recompute_modules ) + self.offload_qkv_linear = ( + self.config.offload_activation + and "qkv_linear" in self.config.offload_modules + ) + + self.offload_core_attention = ( + self.config.offload_activation + and "core_attn" in self.config.offload_modules + ) + + self.offload_attn_linear = ( + self.config.offload_activation + and "attn_linear" in self.config.offload_modules + ) + # Output. self.linear_proj = build_module( submodules.linear_proj, @@ -247,6 +267,128 @@ def custom_forward(*inputs): return hidden_states + def _offload_qkv_linear_forward( + self, + hidden_states, + key_value_states, + ): + """====== [todo] weights lose 'main_grad' in backward pass. under debugging. 
======""" + """Forward method with qkv linear activation offloading.""" + if not hidden_states.is_contiguous(): + hidden_states = hidden_states.contiguous() + + hidden_states = group_prefetch_offload_start(hidden_states) + + handler = PipelineOffloadManager.get_instance().cur_forward_chunk() + handler.register_offload_tensor(hidden_states) + + hidden_states.offloading_mlp_input = True + + with PipelineOffloadManager.get_instance(): + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + + def call_back(): + cur_stream = torch.cuda.current_stream() + hidden_states.record_stream(cur_stream) + hidden_states.untyped_storage().resize_(0) + + query, key, value = group_prefetch_offload_commit(query, key, value, call_back) + return query, key, value + + def _offload_core_attention_forward( + self, + query, + key, + value, + attention_mask, + rotary_pos_emb=None, + attn_mask_type=None, + attention_bias=None, + packed_seq_params=None, + ): + """Forward method with attention activation offloading.""" + + def custom_forward(*inputs): + query = inputs[0] + key = inputs[1] + value = inputs[2] + attention_mask = inputs[3] + attn_mask_type = inputs[5] + attn_mask_type = AttnMaskType(attn_mask_type.item()) + output_ = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + return output_ + + if attn_mask_type is None: + attn_mask_type = self.attn_mask_type + attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int) + + value = value.contiguous() + + query = group_prefetch_offload_start(query) + key = group_prefetch_offload_start(key) + value = group_prefetch_offload_start(value) + + handler = PipelineOffloadManager.get_instance().cur_forward_chunk() + handler.register_offload_tensor(query) + handler.register_offload_tensor(key) + handler.register_offload_tensor(value) + + query.offloading_mlp_input = True + 
key.offloading_mlp_input = True + value.offloading_mlp_input = True + + with PipelineOffloadManager.get_instance(): + hidden_states = custom_forward( + query, key, value, attention_mask, rotary_pos_emb, attn_mask_type + ) + + def call_back(): + cur_stream = torch.cuda.current_stream() + query.record_stream(cur_stream) + key.record_stream(cur_stream) + value.record_stream(cur_stream) + query.untyped_storage().resize_(0) + key.untyped_storage().resize_(0) + value.untyped_storage().resize_(0) + + hidden_states = group_prefetch_offload_commit(hidden_states, call_back) + return hidden_states + + def _offload_attn_linear_forward( + self, + hidden_states, + ): + """====== [todo] weights lose 'main_grad' in backward pass. under debugging. ======""" + """Forward method with attention linear projection activation offloading.""" + if not hidden_states.is_contiguous(): + hidden_states = hidden_states.contiguous() + + hidden_states = group_prefetch_offload_start(hidden_states) + + handler = PipelineOffloadManager.get_instance().cur_forward_chunk() + handler.register_offload_tensor(hidden_states) + + hidden_states.offloading_mlp_input = True + + with PipelineOffloadManager.get_instance(): + output, bias = self.linear_proj(hidden_states) + + def call_back(): + cur_stream = torch.cuda.current_stream() + hidden_states.record_stream(cur_stream) + hidden_states.untyped_storage().resize_(0) + + output, bias = group_prefetch_offload_commit(output, bias, call_back) + return output, bias + def _allocate_memory(self, inference_max_sequence_length, batch_size, dim, dtype): """Allocate memory to store kv cache during inference.""" @@ -670,7 +812,10 @@ def forward( # Get the query, key and value tensors based on the type of attention - # self or cross attn. 
nvtx_range_push(suffix="qkv") - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + if self.offload_qkv_linear: + query, key, value = self._offload_qkv_linear_forward(hidden_states, key_value_states) + else: + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) nvtx_range_pop(suffix="qkv") # =================================================== @@ -798,6 +943,16 @@ def forward( attention_bias=attention_bias, packed_seq_params=packed_seq_params, ) + elif self.offload_core_attention and self.training: + core_attn_out = self._offload_core_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) else: if inference_context is None or inference_context.is_static_batching(): # Static batching attention kernel. @@ -843,7 +998,10 @@ def forward( # ================= nvtx_range_push(suffix="linear_proj") - output, bias = self.linear_proj(core_attn_out) + if self.offload_attn_linear: + output, bias = self._offload_attn_linear_forward(core_attn_out) + else: + output, bias = self.linear_proj(core_attn_out) nvtx_range_pop(suffix="linear_proj") return output, bias diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index e7dd9d4e56c..07dbda60e76 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -19,6 +19,11 @@ ) from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.pipeline_parallel.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, +) try: import transformer_engine as te # pylint: disable=unused-import diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 
235b6f6af0c..91be5859664 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -215,6 +215,53 @@ def set_loss_scale(scale: torch.Tensor): MoEAuxLossAutoScaler.main_loss_backward_scale.copy_(scale) +class MoEPositiveAuxLossAutoScaler(torch.autograd.Function): + """An AutoScaler that compute and scales the grad for positive auxiliary loss.""" + + main_loss_backward_scale: torch.Tensor = torch.tensor(1.0) + + @staticmethod + def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor): + """Preserve the aux_loss by storing it in the context to avoid garbage collection. + + Args: + output (torch.Tensor): The output tensor. + aux_loss (torch.Tensor): The auxiliary loss tensor. + + Returns: + torch.Tensor: The output tensor. + """ + ctx.save_for_backward(aux_loss) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + """Compute and scale the gradient for positive auxiliary loss.. + + Args: + grad_output (torch.Tensor): The gradient of the output. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled positive + auxiliary loss gradient. + """ + (aux_loss,) = ctx.saved_tensors + aux_loss_backward_scale = MoEPositiveAuxLossAutoScaler.main_loss_backward_scale + aux_loss_backward_scale = aux_loss_backward_scale * (aux_loss > 0.0) + scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale + return grad_output, scaled_aux_loss_grad + + @staticmethod + def set_loss_scale(scale: torch.Tensor): + """set the scale of the aux loss. + + Args: + scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in + matches the scale of the main_loss. 
+ """ + MoEPositiveAuxLossAutoScaler.main_loss_backward_scale = scale + + def permute( tokens, routing_map, diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 6b20b862274..72bc9201748 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -10,6 +10,7 @@ from megatron.core.transformer.moe.moe_utils import ( ModelCommProcessGroups, MoEAuxLossAutoScaler, + MoEPositiveAuxLossAutoScaler, apply_random_logits, apply_router_token_dropping, compute_routing_scores_for_aux_loss, diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 05b73a9ee49..631e69cf1c9 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -932,6 +932,7 @@ def __init__( router_topk: int, num_experts: int, config: TransformerConfig, + offload_activation: bool = False, ): """ Initialize the DeepEP dispatcher. @@ -953,6 +954,7 @@ def __init__( self.router_dtype = config.moe_router_dtype self.capacity_factor = config.moe_expert_capacity_factor self.permute_fusion = config.moe_permute_fusion + self.offload_activation = offload_activation # Metadata self.token_indices: Optional[torch.Tensor] = None @@ -1020,6 +1022,9 @@ def _indices_to_multihot(self, indices, probs): A tuple of (routing_map, probs), where routing_map is the multihot vector and probs is the multihot probabilities. 
""" + if self.offload_activation: + routing_map_vectorized, probs_map_vectorized = self._indices_to_multihot_vectorized(indices, probs) + return routing_map_vectorized, probs_map_vectorized batch_size = indices.shape[0] multihot_routing_map = torch.zeros( (batch_size, self.num_local_experts), dtype=torch.long, device=indices.device @@ -1038,6 +1043,47 @@ def _indices_to_multihot(self, indices, probs): multihot_probs[row_indices, valid_indices] = probs[mask] return multihot_routing_map.bool(), multihot_probs + def _indices_to_multihot_vectorized(self, indices, probs): + """ + Converts a tensor of indices to a multihot vector efficiently in PyTorch when enabling + offload_activation. + + Args: + indices (torch.Tensor): [num_tokens, topk] token indices, where -1 means masked out. + The max value of indices is local_num_experts - 1. + probs (torch.Tensor): [num_tokens, topk] token probabilities. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - routing_map: Multihot vector. + - probs: Multihot probabilities. 
+ """ + batch_size, topk = indices.shape + + # Create mask for valid indices + mask = indices != -1 + + # Replace -1 with a valid index (will be masked out anyway) + safe_indices = torch.where(mask, indices, 0) + + # Create one-hot encoding for all positions + # Shape: [batch_size, topk, num_local_experts] + one_hot = torch.nn.functional.one_hot(safe_indices, num_classes=self.num_local_experts).float() + + # Apply mask to zero out invalid positions + # Expand mask to match one_hot dimensions + mask_expanded = mask.unsqueeze(-1).float() + one_hot = one_hot * mask_expanded + + # Sum along topk dimension to get multihot representation + multihot_routing_map = (one_hot.sum(dim=1) > 0).bool() + + # For probabilities, multiply by probs and sum + probs_expanded = probs.unsqueeze(-1) + multihot_probs = (one_hot * probs_expanded).sum(dim=1) + + return multihot_routing_map, multihot_probs + def get_dispached_metadata(self) -> torch.Tensor: return self.dispatched_indices, self.dispatched_probs @@ -1173,6 +1219,7 @@ def __init__( router_topk=self.tp_size * self.config.moe_router_topk, num_experts=self.tp_size * self.config.num_moe_experts, config=self.config, + offload_activation=self.config.offload_activation, ) def set_shared_experts(self, shared_experts): diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index de3df5a6e1f..858ae47e0f6 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -316,6 +316,27 @@ class TransformerConfig(ModelParallelConfig): "core_attn", "mlp", "moe", and "shared_experts" use normal checkpointing. """ + #################### + # activation offloading + #################### + offload_activation: bool = False + """If True, offload the activation to the CPU.""" + + offload_modules: Optional[List[str]] = None + """The submodules to offload. 
+ choices: "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "router_fc2", + "shared_fc1", "shared_fc2". + default: ["core_attn"]. + "self_attn": offload the self_attn part of the transformer layer. + "qkv_linear": offload the qkv_linear part of the transformer layer. + "core_attn": offload the core attention part of the transformer layer. + "attn_linear": offload the attn linear projection part of the transformer layer. + "router_fc1": offload the moe router_fc1 part of the transformer layer. + "router_fc2": offload the moe router_fc2 part of the transformer layer. + "shared_fc1": offload the shared_fc1 part of the transformer layer. + "shared_fc2": offload the shared_fc2 part of the transformer layer. + """ + #################### # fp8 related #################### @@ -939,6 +960,36 @@ def __post_init__(self): if "moe" not in self.recompute_modules: self.recompute_modules.append("moe") + if self.offload_modules is None: + self.offload_modules = ["core_attn"] + + if len(self.offload_modules) > 0: + allowed_modules = { + "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "router_fc2", + "shared_fc1", "shared_fc2" + } + invalid_modules = set(self.offload_modules) - allowed_modules + assert not invalid_modules, ( + f'Invalid choices for offload_modules: {invalid_modules}. ' + f'Allowed modules are: {allowed_modules}' + ) + + if "self_attn" in self.offload_modules: + if "qkv_linear" in self.offload_modules: + self.offload_modules.remove("qkv_linear") + if "core_attn" in self.offload_modules: + self.offload_modules.remove("core_attn") + if "attn_linear" in self.offload_modules: + self.offload_modules.remove("attn_linear") + + if "core_attn" in self.offload_modules: + warnings.warn( + "If you are using transformer_engine as the transformer implementation, " + "the core_attn is from transformer_engine and may be the fused version. " + "For fused attention, you have no need to set 'core_attn' to offload. 
" + "Please check that the core_attn offload is really needed." + ) + if ( self.num_layers_in_first_pipeline_stage is not None or self.num_layers_in_last_pipeline_stage is not None diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 84f22bdeac1..dc12dfc952d 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -425,6 +425,10 @@ def __init__( if "mlp" in self.config.recompute_modules: if not isinstance(self.mlp, MoELayer): self.recompute_mlp = True + self.offload_self_attn = ( + self.config.offload_activation + and "self_attn" in self.config.offload_modules + ) # @jcasper how should we handle nvfuser? # Set bias+dropout+add fusion grad_enable execution handler. @@ -515,17 +519,48 @@ def _forward_attention( # Self attention. nvtx_range_push(suffix="self_attention") - attention_output_with_bias = self.self_attention( - input_layernorm_output, - attention_mask=attention_mask, - inference_context=inference_context, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - ) + if self.offload_self_attn: + from megatron.core.pipeline_parallel.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, + ) + if not input_layernorm_output.is_contiguous(): + input_layernorm_output = input_layernorm_output.contiguous() + input_layernorm_output = group_prefetch_offload_start(input_layernorm_output) + handler = PipelineOffloadManager.get_instance().cur_forward_chunk() + handler.register_offload_tensor(input_layernorm_output) + input_layernorm_output.offloading_self_attn = True + with PipelineOffloadManager.get_instance(): + attention_output_with_bias = self.self_attention( + input_layernorm_output, + attention_mask=attention_mask, + 
inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + def call_back(): + cur_stream = torch.cuda.current_stream() + input_layernorm_output.record_stream(cur_stream) + input_layernorm_output.untyped_storage().resize_(0) + + attention_output_with_bias = group_prefetch_offload_commit(attention_output_with_bias, call_back) + else: + attention_output_with_bias = self.self_attention( + input_layernorm_output, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) nvtx_range_pop(suffix="self_attention") if self.recompute_input_layernorm: From 7168ccdff1cb05789f0d8a72adf4a6359c12586d Mon Sep 17 00:00:00 2001 From: geyuhong Date: Wed, 3 Sep 2025 15:11:38 +0800 Subject: [PATCH 02/35] bugfix main_grad info and bitwise --- megatron/core/models/gpt/gpt_model.py | 8 ++ megatron/core/models/gpt/utils.py | 20 ++++ .../core/pipeline_parallel/cpu_offload.py | 93 +++++++------------ megatron/core/pipeline_parallel/schedules.py | 2 +- megatron/core/transformer/attention.py | 26 ++---- megatron/core/transformer/moe/experts.py | 56 ++++++++++- .../core/transformer/transformer_layer.py | 4 +- 7 files changed, 128 insertions(+), 81 deletions(-) create mode 100644 megatron/core/models/gpt/utils.py diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 6aec66e6dca..f9448c74849 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -33,6 +33,8 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import 
TransformerConfig from megatron.core.utils import WrappedTensor, deprecate_inference_params +from megatron.core.models.gpt.utils import offloading_checker, get_first_layer_index +from megatron.core.pipeline_parallel.cpu_offload import PipelineOffloadManager class GPTModel(LanguageModule): @@ -366,6 +368,12 @@ def forward( runtime_gather_output (bool): Gather output at runtime. Default None means `parallel_output` arg in the constructor will be used. """ + first_layer_index = get_first_layer_index(self.config, self.decoder.num_layers_per_pipeline_rank) + PipelineOffloadManager.get_instance().reset_chunk_handler( + self.decoder.num_layers_per_pipeline_rank, + self.config.offload_activation, + first_layer_index, + ) inference_context = deprecate_inference_params(inference_context, inference_params) diff --git a/megatron/core/models/gpt/utils.py b/megatron/core/models/gpt/utils.py new file mode 100644 index 00000000000..e858b286e13 --- /dev/null +++ b/megatron/core/models/gpt/utils.py @@ -0,0 +1,20 @@ +from megatron.core.parallel_state import ( + get_pipeline_model_parallel_rank, + get_virtual_pipeline_model_parallel_rank, + get_pipeline_model_parallel_world_size +) + +def offloading_checker(tensor): + return hasattr(tensor, "offloading_activation") and tensor.offloading_activation + +def get_first_layer_index(config, num_layers_per_pipeline_rank): + if 'core_attn' in config.offload_modules: + return 0 + pp_rank = get_pipeline_model_parallel_rank() + pp_size = get_pipeline_model_parallel_world_size() + vpp_rank = get_virtual_pipeline_model_parallel_rank() + layer_index_start = num_layers_per_pipeline_rank * (pp_size * vpp_rank + pp_rank) + if config.first_k_dense_replace > layer_index_start: + return config.first_k_dense_replace - layer_index_start + else: + return 0 \ No newline at end of file diff --git a/megatron/core/pipeline_parallel/cpu_offload.py b/megatron/core/pipeline_parallel/cpu_offload.py index 1ccef70febf..936abf211b6 100644 --- 
a/megatron/core/pipeline_parallel/cpu_offload.py +++ b/megatron/core/pipeline_parallel/cpu_offload.py @@ -3,6 +3,10 @@ from megatron.core import parallel_state from typing import Any from transformer_engine.pytorch.float8_tensor import Float8Tensor +from megatron.core.parallel_state import ( + get_pipeline_model_parallel_rank, + get_virtual_pipeline_model_parallel_rank, +) # cpu offload for pipeline @@ -29,7 +33,6 @@ def __init__(self): self._b_event.record(self._h2d_stream) self.reset() - @property def d2h_stream(self): return self._d2h_stream @@ -62,7 +65,6 @@ def push(self, handler): def pop(self): assert self.size() self._cur_backward_chunk = self._queue.popleft() - def front(self): if not len(self._queue): @@ -74,7 +76,7 @@ def front(self): def size(self): return len(self._queue) - def reset_chunk_handler(self, num_layer, offload_mlp_input=True): + def reset_chunk_handler(self, num_layer, offload_mlp_input=True, first_layer_index=0): cur_vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() first_last_vpp_rank = self._first_last_vpp_rank @@ -82,7 +84,7 @@ def reset_chunk_handler(self, num_layer, offload_mlp_input=True): if cur_vpp_rank == self._vpp - 1: self.flush() first_last_vpp_rank = first_last_vpp_rank and (cur_vpp_rank == self._vpp - 1) - cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload_mlp_input) + cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload_mlp_input, first_layer_index) # save for latter push self._stages[cur_vpp_rank].append(cur_chunk) if cur_vpp_rank == self._vpp - 1: @@ -92,7 +94,6 @@ def reset_chunk_handler(self, num_layer, offload_mlp_input=True): self._cur_forward_chunk = cur_chunk cur_chunk.vpp_rank = cur_vpp_rank - def cur_forward_chunk(self): return self._cur_forward_chunk @@ -114,18 +115,14 @@ def __exit__(self, *args: Any): def on_save_for_backward(self, tensor: torch.Tensor) -> Any: assert self.inside_context if 
self.cur_forward_chunk().is_registered_tensor(tensor.data_ptr()): - tensor.offloading_mlp_input = True + tensor.offloading_activation = True return self.cur_forward_chunk().tensor_push(tensor) def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: return self.cur_backward_chunk().tensor_pop(saved_state) - - - class ChunkOffloadHandler: - @staticmethod def offload(src_tensor, pin_memory=True): """Offload.""" @@ -154,7 +151,7 @@ def reload(state, non_blocking=None): non_blocking = cpu_backup.is_pinned() return cpu_backup.to(dev, non_blocking=non_blocking) - def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True): + def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer_index=0): self._num_layers = num_layer # Data Structure to maintain reference to activation tensors self._tensor_tag_to_state = {} @@ -162,9 +159,10 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True): self._offloaded_group_count = 0 self._is_first_last_vpp_chunk = is_first_last_vpp_chunk - self._layer_index = 0 + self._layer_index = first_layer_index + self.first_layer_index = first_layer_index self._tensor_count_current_layer = 0 - self.cur_backward_tensor_count = 0 + self.multi_input_offload_count = 0 self.tensor_need_offloading_checker = None self.torch_tensor_count = 0 @@ -178,7 +176,7 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True): def is_first_last_layer(self): return self._is_first_last_vpp_chunk and self.is_last_layer() - + def is_last_layer(self): return (self._layer_index == self._num_layers - 1) @@ -225,36 +223,29 @@ def bulk_offload_group(self, group_to_offload): tensor_on_device = state # if offload, return the reference to cpu copy if self.tensor_need_offloading_checker is not None and self.tensor_need_offloading_checker(tensor_on_device): - #print(f"offload {group_to_offload}") state = self.offload(tensor_on_device) tensor_on_device.record_stream(self.d2h_stream) 
self._tensor_tag_to_state[tensor_tag] = state self._offloaded_group_count = group_to_offload + 1 self._f_event.record(self.d2h_stream) - def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" if not self.do_offload: return - no_tensors_cur_layer = True with torch.cuda.stream(self.h2d_stream): # move back tensors - # self._tensor_tag_to_state -> {tensor_tag: state} = {(_layer_index, _tensor_count_current_layer): tensor_on_device} for tensor_label, state in self._tensor_tag_to_state.items(): group_id, _ = tensor_label if group_id == group_to_reload: if isinstance(state, tuple): recovered_tensor = self.reload(state) self._tensor_tag_to_state[tensor_label] = recovered_tensor + self.multi_input_offload_count -= 1 + if self.multi_input_offload_count >= 1: break - for tensor_label, state in self._tensor_tag_to_state.items(): - group_id, _ = tensor_label - if group_id == group_to_reload: - if isinstance(state, tuple): - no_tensors_cur_layer = False - break - if no_tensors_cur_layer: + if self.multi_input_offload_count < 1: + self.multi_input_offload_count = 0 self._offloaded_group_count = group_to_reload self._b_event.record(self.h2d_stream) @@ -265,64 +256,54 @@ def pre_reload_last_layer(self): if self._num_layers == self._offloaded_group_count: self.bulk_reload_group(self._num_layers - 1) # assert self._num_layers - 1 == self._offloaded_group_count - - + def should_bulk_offload(self): if not self.do_offload: return False # first backward chunk if self.is_first_last_layer(): return False - + # if next backward chunk is this chunk (for last pp stage) next_backward_chunk = PipelineOffloadManager.get_instance().get_instance().front() if next_backward_chunk is not None and next_backward_chunk is self: if self.is_last_layer(): return False - + return True - + def forward_sync(self): self.d2h_stream.wait_stream(torch.cuda.current_stream()) self._f_event.wait(torch.cuda.current_stream()) - #torch.cuda.empty_cache() - - + def bulk_offload(self, 
offloaded_call_back): self.d2h_stream.wait_stream(torch.cuda.current_stream()) - #torch.cuda.empty_cache() if self.should_bulk_offload(): self.bulk_offload_group(self._layer_index) if offloaded_call_back is not None: offloaded_call_back() - - - def on_group_commit_forward(self, offloaded_call_back): # wait each other self.forward_sync() self.bulk_offload(offloaded_call_back) self._layer_index = self._layer_index + 1 - self.cur_backward_tensor_count = self._tensor_count_current_layer self._tensor_count_current_layer = 0 - - + def bulk_reload(self): if self.do_offload: assert self._layer_index == self._offloaded_group_count - if self._layer_index: + if self._layer_index > self.first_layer_index: # load next layer self.bulk_reload_group(self._layer_index - 1) else: next_backward_chunk = PipelineOffloadManager.get_instance().front() if next_backward_chunk is not None: next_backward_chunk.pre_reload_last_layer() - + def backward_sync(self): self.h2d_stream.wait_stream(torch.cuda.current_stream()) self._b_event.wait(torch.cuda.current_stream()) - def on_group_commit_backward(self): cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() @@ -337,12 +318,13 @@ def on_group_commit_backward(self): def on_group_start_forward(self): pass - + def on_group_start_backward(self): self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() - + def register_offload_tensor(self, tensor): + self.multi_input_offload_count += 1 self._offload_tensor_ptrs.append(tensor.data_ptr()) def is_registered_tensor(self, tensor_ptr: int) -> bool: @@ -354,8 +336,6 @@ def is_registered_tensor(self, tensor_ptr: int) -> bool: return is_registered - - class GroupCommitFunction(torch.autograd.Function): """this is a dummy op with output identical to input. 
However, it is necessary for marking a timepoint for offload handler to @@ -364,9 +344,12 @@ class GroupCommitFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, tensor, cpu_offload_handler, offloaded_call_back): + def forward(ctx, *args): # pylint: disable=missing-function-docstring + offloaded_call_back = args[-1] + cpu_offload_handler = args[-2] + tensor = args[:-2] cpu_offload_handler.on_group_commit_forward(offloaded_call_back) ctx.cpu_offload_handler = cpu_offload_handler @@ -374,17 +357,17 @@ def forward(ctx, tensor, cpu_offload_handler, offloaded_call_back): return tensor @staticmethod - def backward(ctx, grad_output): + def backward(ctx, *grad_output): # pylint: disable=missing-function-docstring cpu_offload_handler = ctx.cpu_offload_handler cpu_offload_handler.on_group_commit_backward() - return grad_output, None, None + return grad_output + (None, None) -def group_prefetch_offload_commit(tensor, offloaded_call_back=None): +def group_prefetch_offload_commit(*tensor, offloaded_call_back=None): cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() - return GroupCommitFunction.apply(tensor, cur_forward_chunk, offloaded_call_back) + return GroupCommitFunction.apply(*tensor, cur_forward_chunk, offloaded_call_back) class GroupStartFunction(torch.autograd.Function): @@ -397,7 +380,6 @@ class GroupStartFunction(torch.autograd.Function): @staticmethod def forward(ctx, tensor, cpu_offload_handler): # pylint: disable=missing-function-docstring - # cpu_offload_handler.on_group_start_forward() ctx.cpu_offload_handler = cpu_offload_handler # return the identical tensor @@ -410,10 +392,7 @@ def backward(ctx, grad_output): cpu_offload_handler.on_group_start_backward() return grad_output, None + def group_prefetch_offload_start(tensor): cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() return GroupStartFunction.apply(tensor, cur_forward_chunk) - - -def offloading_checker(tensor): - return 
hasattr(tensor, 'offloading_mlp_input') and tensor.offloading_mlp_input diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 1262868db04..0f2dbe99d72 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -35,7 +35,7 @@ combined_1f1b_schedule_for_interleaved_pipelining, combined_1f1b_schedule_for_no_pipelining, ) -from .cpu_offload import PipelineOffloadManager, offloading_checker +from .cpu_offload import PipelineOffloadManager # Types Shape = Union[List[int], torch.Size] diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 8f7e138ab5f..8b614c59ad8 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -272,17 +272,13 @@ def _offload_qkv_linear_forward( hidden_states, key_value_states, ): - """====== [todo] weights lose 'main_grad' in backward pass. under debugging. ======""" """Forward method with qkv linear activation offloading.""" if not hidden_states.is_contiguous(): hidden_states = hidden_states.contiguous() hidden_states = group_prefetch_offload_start(hidden_states) - handler = PipelineOffloadManager.get_instance().cur_forward_chunk() - handler.register_offload_tensor(hidden_states) - - hidden_states.offloading_mlp_input = True + hidden_states.offloading_activation = True with PipelineOffloadManager.get_instance(): query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) @@ -292,7 +288,7 @@ def call_back(): hidden_states.record_stream(cur_stream) hidden_states.untyped_storage().resize_(0) - query, key, value = group_prefetch_offload_commit(query, key, value, call_back) + query, key, value = group_prefetch_offload_commit(query, key, value, offloaded_call_back=call_back) return query, key, value def _offload_core_attention_forward( @@ -341,9 +337,9 @@ def custom_forward(*inputs): handler.register_offload_tensor(key) 
handler.register_offload_tensor(value) - query.offloading_mlp_input = True - key.offloading_mlp_input = True - value.offloading_mlp_input = True + query.offloading_activation = True + key.offloading_activation = True + value.offloading_activation = True with PipelineOffloadManager.get_instance(): hidden_states = custom_forward( @@ -359,24 +355,20 @@ def call_back(): key.untyped_storage().resize_(0) value.untyped_storage().resize_(0) - hidden_states = group_prefetch_offload_commit(hidden_states, call_back) - return hidden_states + hidden_states = group_prefetch_offload_commit(hidden_states, offloaded_call_back=call_back) + return hidden_states[0] def _offload_attn_linear_forward( self, hidden_states, ): - """====== [todo] weights lose 'main_grad' in backward pass. under debugging. ======""" """Forward method with attention linear projection activation offloading.""" if not hidden_states.is_contiguous(): hidden_states = hidden_states.contiguous() hidden_states = group_prefetch_offload_start(hidden_states) - handler = PipelineOffloadManager.get_instance().cur_forward_chunk() - handler.register_offload_tensor(hidden_states) - - hidden_states.offloading_mlp_input = True + hidden_states.offloading_activation = True with PipelineOffloadManager.get_instance(): output, bias = self.linear_proj(hidden_states) @@ -386,7 +378,7 @@ def call_back(): hidden_states.record_stream(cur_stream) hidden_states.untyped_storage().resize_(0) - output, bias = group_prefetch_offload_commit(output, bias, call_back) + output, bias = group_prefetch_offload_commit(output, bias, offloaded_call_back=call_back) return output, bias def _allocate_memory(self, inference_max_sequence_length, batch_size, dim, dtype): diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 98bc8912292..ac1d1d99b94 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -40,6 +40,11 @@ make_sharded_object_for_checkpoint, 
sharded_state_dict_default, ) +from megatron.core.pipeline_parallel.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, +) try: import transformer_engine as te # pylint: disable=unused-import @@ -804,6 +809,16 @@ def __init__( tp_group=parallel_state.get_expert_tensor_parallel_group(), ) + self.offload_router_fc1 = ( + self.config.offload_activation + and "router_fc1" in self.config.offload_modules + ) + + self.offload_router_fc2 = ( + self.config.offload_activation + and "router_fc2" in self.config.offload_modules + ) + self.activation_recompute = ( self.config.recompute_granularity == 'selective' and "moe_act" in self.config.recompute_modules @@ -818,6 +833,36 @@ def __init__( self.fp8_padding = Fp8Padding(self.num_local_experts) self.fp8_unpadding = Fp8Unpadding(self.num_local_experts) + def _offload_router_fc1_forward( + self, + permuted_local_hidden_states, + tokens_per_expert, + ): + """Forward method with router fc1 activation offloading.""" + if not permuted_local_hidden_states.is_contiguous(): + permuted_local_hidden_states = permuted_local_hidden_states.contiguous() + + permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states) + + permuted_local_hidden_states.offloading_activation = True + + with PipelineOffloadManager.get_instance(): + intermediate_parallel, bias_parallel = self.linear_fc1( + permuted_local_hidden_states, tokens_per_expert + ) + + def call_back(): + cur_stream = torch.cuda.current_stream() + permuted_local_hidden_states.record_stream(cur_stream) + permuted_local_hidden_states.untyped_storage().resize_(0) + + intermediate_parallel, bias_parallel = group_prefetch_offload_commit( + intermediate_parallel, + bias_parallel, + offloaded_call_back=call_back + ) + return intermediate_parallel, bias_parallel + def forward( self, permuted_local_hidden_states: torch.Tensor, @@ -857,9 +902,14 @@ def forward( # Probs already applied, so reset to 1. 
permuted_probs = torch.ones_like(permuted_probs) - intermediate_parallel, bias_parallel = self.linear_fc1( - permuted_local_hidden_states, tokens_per_expert - ) + if self.offload_router_fc1: + intermediate_parallel, bias_parallel = self._offload_router_fc1_forward( + permuted_local_hidden_states, tokens_per_expert + ) + else: + intermediate_parallel, bias_parallel = self.linear_fc1( + permuted_local_hidden_states, tokens_per_expert + ) def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): if self.config.use_te_activation_func: diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index dc12dfc952d..c92d64d5004 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -528,9 +528,7 @@ def _forward_attention( if not input_layernorm_output.is_contiguous(): input_layernorm_output = input_layernorm_output.contiguous() input_layernorm_output = group_prefetch_offload_start(input_layernorm_output) - handler = PipelineOffloadManager.get_instance().cur_forward_chunk() - handler.register_offload_tensor(input_layernorm_output) - input_layernorm_output.offloading_self_attn = True + input_layernorm_output.offloading_activation = True with PipelineOffloadManager.get_instance(): attention_output_with_bias = self.self_attention( input_layernorm_output, From 1555e6d750b34a79d8df6d745c33e64a8d5c30d2 Mon Sep 17 00:00:00 2001 From: geyuhong Date: Wed, 3 Sep 2025 16:52:35 +0800 Subject: [PATCH 03/35] remove offload_mlp_input arg --- megatron/core/model_parallel_config.py | 19 +++++++++++++++-- .../core/pipeline_parallel/cpu_offload.py | 21 ++++++++++--------- .../core/transformer/transformer_config.py | 21 ------------------- 3 files changed, 28 insertions(+), 33 deletions(-) diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 31465ea534b..251c1945c3f 100644 --- a/megatron/core/model_parallel_config.py +++ 
b/megatron/core/model_parallel_config.py @@ -315,8 +315,23 @@ class ModelParallelConfig: rank 1 | 0 1 2 0 1 2 3 4 3 4 """ - offload_mlp_input: bool = False - """If true, offloads the MLP input to CPU. This is useful for large.""" + offload_activation: bool = False + """If True, offload the activation to the CPU.""" + + offload_modules: Optional[list[str]] = None + """The submodules to offload. + choices: "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "router_fc2", + "shared_fc1", "shared_fc2". + default: ["core_attn"]. + "self_attn": offload the self_attn part of the transformer layer. + "qkv_linear": offload the qkv_linear part of the transformer layer. + "core_attn": offload the core attention part of the transformer layer. + "attn_linear": offload the attn linear projection part of the transformer layer. + "router_fc1": offload the moe router_fc1 part of the transformer layer. + "router_fc2": offload the moe router_fc2 part of the transformer layer. + "shared_fc1": offload the shared_fc1 part of the transformer layer. + "shared_fc2": offload the shared_fc2 part of the transformer layer. 
+ """ ################### # CPU Offloading diff --git a/megatron/core/pipeline_parallel/cpu_offload.py b/megatron/core/pipeline_parallel/cpu_offload.py index 936abf211b6..e4465848c31 100644 --- a/megatron/core/pipeline_parallel/cpu_offload.py +++ b/megatron/core/pipeline_parallel/cpu_offload.py @@ -1,4 +1,4 @@ -from collections import deque +from collections import deque, defaultdict import torch from megatron.core import parallel_state from typing import Any @@ -76,7 +76,7 @@ def front(self): def size(self): return len(self._queue) - def reset_chunk_handler(self, num_layer, offload_mlp_input=True, first_layer_index=0): + def reset_chunk_handler(self, num_layer, offload=True, first_layer_index=0): cur_vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() first_last_vpp_rank = self._first_last_vpp_rank @@ -84,7 +84,7 @@ def reset_chunk_handler(self, num_layer, offload_mlp_input=True, first_layer_ind if cur_vpp_rank == self._vpp - 1: self.flush() first_last_vpp_rank = first_last_vpp_rank and (cur_vpp_rank == self._vpp - 1) - cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload_mlp_input, first_layer_index) + cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload, first_layer_index) # save for latter push self._stages[cur_vpp_rank].append(cur_chunk) if cur_vpp_rank == self._vpp - 1: @@ -162,7 +162,8 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer self._layer_index = first_layer_index self.first_layer_index = first_layer_index self._tensor_count_current_layer = 0 - self.multi_input_offload_count = 0 + self.multi_input_offload_count = False + self.offload_count_per_layer = defaultdict(int) self.tensor_need_offloading_checker = None self.torch_tensor_count = 0 @@ -225,6 +226,7 @@ def bulk_offload_group(self, group_to_offload): if self.tensor_need_offloading_checker is not None and self.tensor_need_offloading_checker(tensor_on_device): state = self.offload(tensor_on_device) 
tensor_on_device.record_stream(self.d2h_stream) + self.offload_count_per_layer[group_to_offload] += 1 self._tensor_tag_to_state[tensor_tag] = state self._offloaded_group_count = group_to_offload + 1 self._f_event.record(self.d2h_stream) @@ -241,11 +243,10 @@ def bulk_reload_group(self, group_to_reload): if isinstance(state, tuple): recovered_tensor = self.reload(state) self._tensor_tag_to_state[tensor_label] = recovered_tensor - self.multi_input_offload_count -= 1 - if self.multi_input_offload_count >= 1: - break - if self.multi_input_offload_count < 1: - self.multi_input_offload_count = 0 + self.offload_count_per_layer[group_to_reload] -= 1 + if self.offload_count_per_layer[group_to_reload] > 0 and self.multi_input_offload_count: + break + if self.offload_count_per_layer[group_to_reload] == 0: self._offloaded_group_count = group_to_reload self._b_event.record(self.h2d_stream) @@ -324,7 +325,7 @@ def on_group_start_backward(self): self.bulk_reload() def register_offload_tensor(self, tensor): - self.multi_input_offload_count += 1 + self.multi_input_offload_count = True self._offload_tensor_ptrs.append(tensor.data_ptr()) def is_registered_tensor(self, tensor_ptr: int) -> bool: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 858ae47e0f6..b11594a4f86 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -316,27 +316,6 @@ class TransformerConfig(ModelParallelConfig): "core_attn", "mlp", "moe", and "shared_experts" use normal checkpointing. """ - #################### - # activation offloading - #################### - offload_activation: bool = False - """If True, offload the activation to the CPU.""" - - offload_modules: Optional[List[str]] = None - """The submodules to offload. - choices: "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "router_fc2", - "shared_fc1", "shared_fc2". - default: ["core_attn"]. 
- "self_attn": offload the self_attn part of the transformer layer. - "qkv_linear": offload the qkv_linear part of the transformer layer. - "core_attn": offload the core attention part of the transformer layer. - "attn_linear": offload the attn linear projection part of the transformer layer. - "router_fc1": offload the moe router_fc1 part of the transformer layer. - "router_fc2": offload the moe router_fc2 part of the transformer layer. - "shared_fc1": offload the shared_fc1 part of the transformer layer. - "shared_fc2": offload the shared_fc2 part of the transformer layer. - """ - #################### # fp8 related #################### From c9f00c72fa26c545c320497bc9333d2b926824cb Mon Sep 17 00:00:00 2001 From: geyuhong Date: Mon, 8 Sep 2025 03:23:35 +0800 Subject: [PATCH 04/35] replace get_virtual_pipeline_model_parallel_rank with vp_stage --- megatron/core/models/gpt/gpt_model.py | 8 +++++++- megatron/core/models/gpt/utils.py | 4 +--- megatron/core/pipeline_parallel/cpu_offload.py | 8 ++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index f9448c74849..b9ef3e3ea3c 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -368,12 +368,18 @@ def forward( runtime_gather_output (bool): Gather output at runtime. Default None means `parallel_output` arg in the constructor will be used. 
""" - first_layer_index = get_first_layer_index(self.config, self.decoder.num_layers_per_pipeline_rank) + first_layer_index = get_first_layer_index( + self.config, + self.decoder.num_layers_per_pipeline_rank, + self.vp_stage + ) PipelineOffloadManager.get_instance().reset_chunk_handler( self.decoder.num_layers_per_pipeline_rank, + self.vp_stage, self.config.offload_activation, first_layer_index, ) + PipelineOffloadManager.get_instance().cur_forward_chunk().set_offloading_checker(offloading_checker) inference_context = deprecate_inference_params(inference_context, inference_params) diff --git a/megatron/core/models/gpt/utils.py b/megatron/core/models/gpt/utils.py index e858b286e13..eb02fca0c6b 100644 --- a/megatron/core/models/gpt/utils.py +++ b/megatron/core/models/gpt/utils.py @@ -1,18 +1,16 @@ from megatron.core.parallel_state import ( get_pipeline_model_parallel_rank, - get_virtual_pipeline_model_parallel_rank, get_pipeline_model_parallel_world_size ) def offloading_checker(tensor): return hasattr(tensor, "offloading_activation") and tensor.offloading_activation -def get_first_layer_index(config, num_layers_per_pipeline_rank): +def get_first_layer_index(config, num_layers_per_pipeline_rank, vpp_rank): if 'core_attn' in config.offload_modules: return 0 pp_rank = get_pipeline_model_parallel_rank() pp_size = get_pipeline_model_parallel_world_size() - vpp_rank = get_virtual_pipeline_model_parallel_rank() layer_index_start = num_layers_per_pipeline_rank * (pp_size * vpp_rank + pp_rank) if config.first_k_dense_replace > layer_index_start: return config.first_k_dense_replace - layer_index_start diff --git a/megatron/core/pipeline_parallel/cpu_offload.py b/megatron/core/pipeline_parallel/cpu_offload.py index e4465848c31..4892499ef98 100644 --- a/megatron/core/pipeline_parallel/cpu_offload.py +++ b/megatron/core/pipeline_parallel/cpu_offload.py @@ -3,10 +3,6 @@ from megatron.core import parallel_state from typing import Any from transformer_engine.pytorch.float8_tensor 
import Float8Tensor -from megatron.core.parallel_state import ( - get_pipeline_model_parallel_rank, - get_virtual_pipeline_model_parallel_rank, -) # cpu offload for pipeline @@ -76,8 +72,8 @@ def front(self): def size(self): return len(self._queue) - def reset_chunk_handler(self, num_layer, offload=True, first_layer_index=0): - cur_vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + def reset_chunk_handler(self, num_layer, vp_stage, offload=True, first_layer_index=0): + cur_vpp_rank = vp_stage first_last_vpp_rank = self._first_last_vpp_rank # rewind From 4b0d3f1518682b7ee93e17ee80eede3f2a9774fc Mon Sep 17 00:00:00 2001 From: geyuhong Date: Mon, 8 Sep 2025 03:26:27 +0800 Subject: [PATCH 05/35] remove all MoEPositiveAuxLossAutoScaler --- megatron/core/pipeline_parallel/schedules.py | 7 +-- megatron/core/transformer/moe/moe_utils.py | 47 -------------------- megatron/core/transformer/moe/router.py | 1 - 3 files changed, 1 insertion(+), 54 deletions(-) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index 0f2dbe99d72..e0cc1d48b86 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -18,10 +18,7 @@ ) from megatron.core.process_groups_config import GradFinalizeProcessGroups from megatron.core.transformer.cuda_graphs import create_cudagraphs -from megatron.core.transformer.moe.router import ( - MoEAuxLossAutoScaler, - MoEPositiveAuxLossAutoScaler, -) +from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler from megatron.core.utils import ( drain_embedding_wgrad_compute, get_attr_wrapped_model, @@ -270,8 +267,6 @@ def forward_step_calc_loss( if config.calculate_per_token_loss: MoEAuxLossAutoScaler.set_loss_scale(loss_scale) else: - if config.offload_activation: - MoEPositiveAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches) MoEAuxLossAutoScaler.set_loss_scale(loss_scale / num_microbatches) # Set the loss scale for 
Multi-Token Prediction (MTP) loss. diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 91be5859664..235b6f6af0c 100644 --- a/megatron/core/transformer/moe/moe_utils.py +++ b/megatron/core/transformer/moe/moe_utils.py @@ -215,53 +215,6 @@ def set_loss_scale(scale: torch.Tensor): MoEAuxLossAutoScaler.main_loss_backward_scale.copy_(scale) -class MoEPositiveAuxLossAutoScaler(torch.autograd.Function): - """An AutoScaler that compute and scales the grad for positive auxiliary loss.""" - - main_loss_backward_scale: torch.Tensor = torch.tensor(1.0) - - @staticmethod - def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor): - """Preserve the aux_loss by storing it in the context to avoid garbage collection. - - Args: - output (torch.Tensor): The output tensor. - aux_loss (torch.Tensor): The auxiliary loss tensor. - - Returns: - torch.Tensor: The output tensor. - """ - ctx.save_for_backward(aux_loss) - return output - - @staticmethod - def backward(ctx, grad_output: torch.Tensor): - """Compute and scale the gradient for positive auxiliary loss.. - - Args: - grad_output (torch.Tensor): The gradient of the output. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled positive - auxiliary loss gradient. - """ - (aux_loss,) = ctx.saved_tensors - aux_loss_backward_scale = MoEPositiveAuxLossAutoScaler.main_loss_backward_scale - aux_loss_backward_scale = aux_loss_backward_scale * (aux_loss > 0.0) - scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale - return grad_output, scaled_aux_loss_grad - - @staticmethod - def set_loss_scale(scale: torch.Tensor): - """set the scale of the aux loss. - - Args: - scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in - matches the scale of the main_loss. 
- """ - MoEPositiveAuxLossAutoScaler.main_loss_backward_scale = scale - - def permute( tokens, routing_map, diff --git a/megatron/core/transformer/moe/router.py b/megatron/core/transformer/moe/router.py index 72bc9201748..6b20b862274 100644 --- a/megatron/core/transformer/moe/router.py +++ b/megatron/core/transformer/moe/router.py @@ -10,7 +10,6 @@ from megatron.core.transformer.moe.moe_utils import ( ModelCommProcessGroups, MoEAuxLossAutoScaler, - MoEPositiveAuxLossAutoScaler, apply_random_logits, apply_router_token_dropping, compute_routing_scores_for_aux_loss, From b00acbcfd949979e95ddb8f42f36f521c8421ee3 Mon Sep 17 00:00:00 2001 From: geyuhong Date: Mon, 8 Sep 2025 03:56:01 +0800 Subject: [PATCH 06/35] reduce modular PipeOffloadManager functions --- .../core/pipeline_parallel/cpu_offload.py | 8 +- megatron/core/transformer/attention.py | 82 ++++++------------- megatron/core/transformer/moe/experts.py | 48 ++++------- megatron/core/transformer/moe/moe_layer.py | 5 -- 4 files changed, 47 insertions(+), 96 deletions(-) diff --git a/megatron/core/pipeline_parallel/cpu_offload.py b/megatron/core/pipeline_parallel/cpu_offload.py index 4892499ef98..adf85417009 100644 --- a/megatron/core/pipeline_parallel/cpu_offload.py +++ b/megatron/core/pipeline_parallel/cpu_offload.py @@ -320,9 +320,13 @@ def on_group_start_backward(self): self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() - def register_offload_tensor(self, tensor): + def register_offload_tensor(self, tensors): self.multi_input_offload_count = True - self._offload_tensor_ptrs.append(tensor.data_ptr()) + if isinstance(tensors, list): + for tensor in tensors: + self._offload_tensor_ptrs.append(tensor.data_ptr()) + else: + self._offload_tensor_ptrs.append(tensors.data_ptr()) def is_registered_tensor(self, tensor_ptr: int) -> bool: if len(self._offload_tensor_ptrs) == 0: diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 8b614c59ad8..0c527911b45 
100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -267,30 +267,6 @@ def custom_forward(*inputs): return hidden_states - def _offload_qkv_linear_forward( - self, - hidden_states, - key_value_states, - ): - """Forward method with qkv linear activation offloading.""" - if not hidden_states.is_contiguous(): - hidden_states = hidden_states.contiguous() - - hidden_states = group_prefetch_offload_start(hidden_states) - - hidden_states.offloading_activation = True - - with PipelineOffloadManager.get_instance(): - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) - - def call_back(): - cur_stream = torch.cuda.current_stream() - hidden_states.record_stream(cur_stream) - hidden_states.untyped_storage().resize_(0) - - query, key, value = group_prefetch_offload_commit(query, key, value, offloaded_call_back=call_back) - return query, key, value - def _offload_core_attention_forward( self, query, @@ -325,7 +301,6 @@ def custom_forward(*inputs): if attn_mask_type is None: attn_mask_type = self.attn_mask_type attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int) - value = value.contiguous() query = group_prefetch_offload_start(query) @@ -333,19 +308,14 @@ def custom_forward(*inputs): value = group_prefetch_offload_start(value) handler = PipelineOffloadManager.get_instance().cur_forward_chunk() - handler.register_offload_tensor(query) - handler.register_offload_tensor(key) - handler.register_offload_tensor(value) - + handler.register_offload_tensor([query, key, value]) query.offloading_activation = True key.offloading_activation = True value.offloading_activation = True - with PipelineOffloadManager.get_instance(): hidden_states = custom_forward( query, key, value, attention_mask, rotary_pos_emb, attn_mask_type ) - def call_back(): cur_stream = torch.cuda.current_stream() query.record_stream(cur_stream) @@ -354,33 +324,9 @@ def call_back(): query.untyped_storage().resize_(0) 
key.untyped_storage().resize_(0) value.untyped_storage().resize_(0) - hidden_states = group_prefetch_offload_commit(hidden_states, offloaded_call_back=call_back) return hidden_states[0] - def _offload_attn_linear_forward( - self, - hidden_states, - ): - """Forward method with attention linear projection activation offloading.""" - if not hidden_states.is_contiguous(): - hidden_states = hidden_states.contiguous() - - hidden_states = group_prefetch_offload_start(hidden_states) - - hidden_states.offloading_activation = True - - with PipelineOffloadManager.get_instance(): - output, bias = self.linear_proj(hidden_states) - - def call_back(): - cur_stream = torch.cuda.current_stream() - hidden_states.record_stream(cur_stream) - hidden_states.untyped_storage().resize_(0) - - output, bias = group_prefetch_offload_commit(output, bias, offloaded_call_back=call_back) - return output, bias - def _allocate_memory(self, inference_max_sequence_length, batch_size, dim, dtype): """Allocate memory to store kv cache during inference.""" @@ -805,7 +751,17 @@ def forward( # self or cross attn. 
nvtx_range_push(suffix="qkv") if self.offload_qkv_linear: - query, key, value = self._offload_qkv_linear_forward(hidden_states, key_value_states) + if not hidden_states.is_contiguous(): + hidden_states = hidden_states.contiguous() + hidden_states = group_prefetch_offload_start(hidden_states) + hidden_states.offloading_activation = True + with PipelineOffloadManager.get_instance(): + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + def call_back(): + cur_stream = torch.cuda.current_stream() + hidden_states.record_stream(cur_stream) + hidden_states.untyped_storage().resize_(0) + query, key, value = group_prefetch_offload_commit(query, key, value, offloaded_call_back=call_back) else: query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) nvtx_range_pop(suffix="qkv") @@ -936,6 +892,8 @@ def forward( packed_seq_params=packed_seq_params, ) elif self.offload_core_attention and self.training: + + core_attn_out = self._offload_core_attention_forward( query, key, @@ -991,7 +949,17 @@ def forward( nvtx_range_push(suffix="linear_proj") if self.offload_attn_linear: - output, bias = self._offload_attn_linear_forward(core_attn_out) + if not core_attn_out.is_contiguous(): + core_attn_out = core_attn_out.contiguous() + core_attn_out = group_prefetch_offload_start(core_attn_out) + core_attn_out.offloading_activation = True + with PipelineOffloadManager.get_instance(): + output, bias = self.linear_proj(core_attn_out) + def call_back(): + cur_stream = torch.cuda.current_stream() + core_attn_out.record_stream(cur_stream) + core_attn_out.untyped_storage().resize_(0) + output, bias = group_prefetch_offload_commit(output, bias, offloaded_call_back=call_back) else: output, bias = self.linear_proj(core_attn_out) nvtx_range_pop(suffix="linear_proj") diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index ac1d1d99b94..62613b0d0ad 100644 --- 
a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -833,36 +833,6 @@ def __init__( self.fp8_padding = Fp8Padding(self.num_local_experts) self.fp8_unpadding = Fp8Unpadding(self.num_local_experts) - def _offload_router_fc1_forward( - self, - permuted_local_hidden_states, - tokens_per_expert, - ): - """Forward method with router fc1 activation offloading.""" - if not permuted_local_hidden_states.is_contiguous(): - permuted_local_hidden_states = permuted_local_hidden_states.contiguous() - - permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states) - - permuted_local_hidden_states.offloading_activation = True - - with PipelineOffloadManager.get_instance(): - intermediate_parallel, bias_parallel = self.linear_fc1( - permuted_local_hidden_states, tokens_per_expert - ) - - def call_back(): - cur_stream = torch.cuda.current_stream() - permuted_local_hidden_states.record_stream(cur_stream) - permuted_local_hidden_states.untyped_storage().resize_(0) - - intermediate_parallel, bias_parallel = group_prefetch_offload_commit( - intermediate_parallel, - bias_parallel, - offloaded_call_back=call_back - ) - return intermediate_parallel, bias_parallel - def forward( self, permuted_local_hidden_states: torch.Tensor, @@ -903,8 +873,22 @@ def forward( permuted_probs = torch.ones_like(permuted_probs) if self.offload_router_fc1: - intermediate_parallel, bias_parallel = self._offload_router_fc1_forward( - permuted_local_hidden_states, tokens_per_expert + if not permuted_local_hidden_states.is_contiguous(): + permuted_local_hidden_states = permuted_local_hidden_states.contiguous() + permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states) + permuted_local_hidden_states.offloading_activation = True + with PipelineOffloadManager.get_instance(): + intermediate_parallel, bias_parallel = self.linear_fc1( + permuted_local_hidden_states, tokens_per_expert + ) + def call_back(): + cur_stream = 
torch.cuda.current_stream() + permuted_local_hidden_states.record_stream(cur_stream) + permuted_local_hidden_states.untyped_storage().resize_(0) + intermediate_parallel, bias_parallel = group_prefetch_offload_commit( + intermediate_parallel, + bias_parallel, + offloaded_call_back=call_back ) else: intermediate_parallel, bias_parallel = self.linear_fc1( diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 07dbda60e76..e7dd9d4e56c 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -19,11 +19,6 @@ ) from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig -from megatron.core.pipeline_parallel.cpu_offload import ( - PipelineOffloadManager, - group_prefetch_offload_start, - group_prefetch_offload_commit, -) try: import transformer_engine as te # pylint: disable=unused-import From b2c99f725ae4ae73394dcc0a1655cfe6d16e5870 Mon Sep 17 00:00:00 2001 From: geyuhong Date: Mon, 8 Sep 2025 04:10:29 +0800 Subject: [PATCH 07/35] remove call_back function --- .../core/pipeline_parallel/cpu_offload.py | 21 ++++++++++-------- megatron/core/transformer/attention.py | 22 +++---------------- megatron/core/transformer/moe/experts.py | 6 +---- .../core/transformer/transformer_layer.py | 7 +----- 4 files changed, 17 insertions(+), 39 deletions(-) diff --git a/megatron/core/pipeline_parallel/cpu_offload.py b/megatron/core/pipeline_parallel/cpu_offload.py index adf85417009..f3aa3a175c3 100644 --- a/megatron/core/pipeline_parallel/cpu_offload.py +++ b/megatron/core/pipeline_parallel/cpu_offload.py @@ -273,17 +273,20 @@ def forward_sync(self): self.d2h_stream.wait_stream(torch.cuda.current_stream()) self._f_event.wait(torch.cuda.current_stream()) - def bulk_offload(self, offloaded_call_back): + def bulk_offload(self, release_tensors): self.d2h_stream.wait_stream(torch.cuda.current_stream()) if 
self.should_bulk_offload(): self.bulk_offload_group(self._layer_index) - if offloaded_call_back is not None: - offloaded_call_back() + if len(release_tensors) > 0: + cur_stream = torch.cuda.current_stream() + for release_tensor in release_tensors: + release_tensor.record_stream(cur_stream) + release_tensor.untyped_storage().resize_(0) - def on_group_commit_forward(self, offloaded_call_back): + def on_group_commit_forward(self, release_tensors): # wait each other self.forward_sync() - self.bulk_offload(offloaded_call_back) + self.bulk_offload(release_tensors) self._layer_index = self._layer_index + 1 self._tensor_count_current_layer = 0 @@ -348,10 +351,10 @@ class GroupCommitFunction(torch.autograd.Function): def forward(ctx, *args): # pylint: disable=missing-function-docstring - offloaded_call_back = args[-1] + release_tensors = args[-1] cpu_offload_handler = args[-2] tensor = args[:-2] - cpu_offload_handler.on_group_commit_forward(offloaded_call_back) + cpu_offload_handler.on_group_commit_forward(release_tensors) ctx.cpu_offload_handler = cpu_offload_handler # return the identical tensor @@ -366,9 +369,9 @@ def backward(ctx, *grad_output): return grad_output + (None, None) -def group_prefetch_offload_commit(*tensor, offloaded_call_back=None): +def group_prefetch_offload_commit(*tensor, release_tensors=[]): cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() - return GroupCommitFunction.apply(*tensor, cur_forward_chunk, offloaded_call_back) + return GroupCommitFunction.apply(*tensor, cur_forward_chunk, release_tensors) class GroupStartFunction(torch.autograd.Function): diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 0c527911b45..8464e2ffd85 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -316,15 +316,7 @@ def custom_forward(*inputs): hidden_states = custom_forward( query, key, value, attention_mask, rotary_pos_emb, attn_mask_type ) - def 
call_back(): - cur_stream = torch.cuda.current_stream() - query.record_stream(cur_stream) - key.record_stream(cur_stream) - value.record_stream(cur_stream) - query.untyped_storage().resize_(0) - key.untyped_storage().resize_(0) - value.untyped_storage().resize_(0) - hidden_states = group_prefetch_offload_commit(hidden_states, offloaded_call_back=call_back) + hidden_states = group_prefetch_offload_commit(hidden_states, release_tensors=[query, key, value]) return hidden_states[0] def _allocate_memory(self, inference_max_sequence_length, batch_size, dim, dtype): @@ -757,11 +749,7 @@ def forward( hidden_states.offloading_activation = True with PipelineOffloadManager.get_instance(): query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) - def call_back(): - cur_stream = torch.cuda.current_stream() - hidden_states.record_stream(cur_stream) - hidden_states.untyped_storage().resize_(0) - query, key, value = group_prefetch_offload_commit(query, key, value, offloaded_call_back=call_back) + query, key, value = group_prefetch_offload_commit(query, key, value, release_tensors=[hidden_states]) else: query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) nvtx_range_pop(suffix="qkv") @@ -955,11 +943,7 @@ def call_back(): core_attn_out.offloading_activation = True with PipelineOffloadManager.get_instance(): output, bias = self.linear_proj(core_attn_out) - def call_back(): - cur_stream = torch.cuda.current_stream() - core_attn_out.record_stream(cur_stream) - core_attn_out.untyped_storage().resize_(0) - output, bias = group_prefetch_offload_commit(output, bias, offloaded_call_back=call_back) + output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out]) else: output, bias = self.linear_proj(core_attn_out) nvtx_range_pop(suffix="linear_proj") diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 62613b0d0ad..fde855c3f63 100644 --- 
a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -881,14 +881,10 @@ def forward( intermediate_parallel, bias_parallel = self.linear_fc1( permuted_local_hidden_states, tokens_per_expert ) - def call_back(): - cur_stream = torch.cuda.current_stream() - permuted_local_hidden_states.record_stream(cur_stream) - permuted_local_hidden_states.untyped_storage().resize_(0) intermediate_parallel, bias_parallel = group_prefetch_offload_commit( intermediate_parallel, bias_parallel, - offloaded_call_back=call_back + release_tensors=[permuted_local_hidden_states] ) else: intermediate_parallel, bias_parallel = self.linear_fc1( diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index c92d64d5004..95b49e97518 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -541,12 +541,7 @@ def _forward_attention( packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, ) - def call_back(): - cur_stream = torch.cuda.current_stream() - input_layernorm_output.record_stream(cur_stream) - input_layernorm_output.untyped_storage().resize_(0) - - attention_output_with_bias = group_prefetch_offload_commit(attention_output_with_bias, call_back) + attention_output_with_bias = group_prefetch_offload_commit(attention_output_with_bias, release_tensors=[input_layernorm_output]) else: attention_output_with_bias = self.self_attention( input_layernorm_output, From e845344a7079d1c3de375a0aac82b306b68e243d Mon Sep 17 00:00:00 2001 From: geyuhong Date: Mon, 8 Sep 2025 05:13:29 +0800 Subject: [PATCH 08/35] polish all event sync --- megatron/core/pipeline_parallel/cpu_offload.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/megatron/core/pipeline_parallel/cpu_offload.py b/megatron/core/pipeline_parallel/cpu_offload.py index f3aa3a175c3..61aaa2f22e6 100644 --- 
a/megatron/core/pipeline_parallel/cpu_offload.py +++ b/megatron/core/pipeline_parallel/cpu_offload.py @@ -3,6 +3,7 @@ from megatron.core import parallel_state from typing import Any from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.cpu_offload import AsyncDoubleBufferGroupOffloadHandler # cpu offload for pipeline @@ -23,10 +24,6 @@ def __init__(self): # allocate streams and events for synchronization self._d2h_stream = torch.cuda.Stream() self._h2d_stream = torch.cuda.Stream() - self._f_event = torch.cuda.Event() - self._b_event = torch.cuda.Event() - self._f_event.record(self._d2h_stream) - self._b_event.record(self._h2d_stream) self.reset() @property @@ -118,7 +115,7 @@ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: return self.cur_backward_chunk().tensor_pop(saved_state) -class ChunkOffloadHandler: +class ChunkOffloadHandler(AsyncDoubleBufferGroupOffloadHandler): @staticmethod def offload(src_tensor, pin_memory=True): """Offload.""" @@ -165,8 +162,6 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer self.torch_tensor_count = 0 self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream - self._f_event = PipelineOffloadManager.get_instance()._f_event - self._b_event = PipelineOffloadManager.get_instance()._b_event self.do_offload = offload self._offload_tensor_ptrs = deque() @@ -225,7 +220,6 @@ def bulk_offload_group(self, group_to_offload): self.offload_count_per_layer[group_to_offload] += 1 self._tensor_tag_to_state[tensor_tag] = state self._offloaded_group_count = group_to_offload + 1 - self._f_event.record(self.d2h_stream) def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" @@ -271,7 +265,7 @@ def should_bulk_offload(self): def forward_sync(self): self.d2h_stream.wait_stream(torch.cuda.current_stream()) - self._f_event.wait(torch.cuda.current_stream()) + 
torch.cuda.current_stream().wait_stream(self.d2h_stream) def bulk_offload(self, release_tensors): self.d2h_stream.wait_stream(torch.cuda.current_stream()) @@ -303,7 +297,7 @@ def bulk_reload(self): def backward_sync(self): self.h2d_stream.wait_stream(torch.cuda.current_stream()) - self._b_event.wait(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.h2d_stream) def on_group_commit_backward(self): cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() @@ -313,8 +307,6 @@ def on_group_commit_backward(self): assert cur_backward_chunk is self self._layer_index = self._layer_index - 1 self.backward_sync() - # layer index already loaded back - # self.bulk_reload() def on_group_start_forward(self): pass From 81f44c7ee16dc6d729f7d880f1205ac24ec4715b Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 8 Sep 2025 22:18:29 -0700 Subject: [PATCH 09/35] add arguments.py and minor fix, OOTB runable now. Signed-off-by: Hongbin Liu --- megatron/core/models/gpt/gpt_model.py | 9 ++------- megatron/core/models/gpt/utils.py | 11 ----------- megatron/core/pipeline_parallel/cpu_offload.py | 1 - megatron/core/transformer/transformer_config.py | 6 ++++-- megatron/training/arguments.py | 8 +++++++- 5 files changed, 13 insertions(+), 22 deletions(-) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index b9ef3e3ea3c..4048bf93068 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -33,7 +33,7 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import WrappedTensor, deprecate_inference_params -from megatron.core.models.gpt.utils import offloading_checker, get_first_layer_index +from megatron.core.models.gpt.utils import offloading_checker from megatron.core.pipeline_parallel.cpu_offload import PipelineOffloadManager @@ -368,16 +368,11 
@@ def forward( runtime_gather_output (bool): Gather output at runtime. Default None means `parallel_output` arg in the constructor will be used. """ - first_layer_index = get_first_layer_index( - self.config, - self.decoder.num_layers_per_pipeline_rank, - self.vp_stage - ) PipelineOffloadManager.get_instance().reset_chunk_handler( self.decoder.num_layers_per_pipeline_rank, self.vp_stage, self.config.offload_activation, - first_layer_index, + 0, ) PipelineOffloadManager.get_instance().cur_forward_chunk().set_offloading_checker(offloading_checker) diff --git a/megatron/core/models/gpt/utils.py b/megatron/core/models/gpt/utils.py index eb02fca0c6b..ceb572b98fa 100644 --- a/megatron/core/models/gpt/utils.py +++ b/megatron/core/models/gpt/utils.py @@ -5,14 +5,3 @@ def offloading_checker(tensor): return hasattr(tensor, "offloading_activation") and tensor.offloading_activation - -def get_first_layer_index(config, num_layers_per_pipeline_rank, vpp_rank): - if 'core_attn' in config.offload_modules: - return 0 - pp_rank = get_pipeline_model_parallel_rank() - pp_size = get_pipeline_model_parallel_world_size() - layer_index_start = num_layers_per_pipeline_rank * (pp_size * vpp_rank + pp_rank) - if config.first_k_dense_replace > layer_index_start: - return config.first_k_dense_replace - layer_index_start - else: - return 0 \ No newline at end of file diff --git a/megatron/core/pipeline_parallel/cpu_offload.py b/megatron/core/pipeline_parallel/cpu_offload.py index 61aaa2f22e6..07eba0440fc 100644 --- a/megatron/core/pipeline_parallel/cpu_offload.py +++ b/megatron/core/pipeline_parallel/cpu_offload.py @@ -238,7 +238,6 @@ def bulk_reload_group(self, group_to_reload): break if self.offload_count_per_layer[group_to_reload] == 0: self._offloaded_group_count = group_to_reload - self._b_event.record(self.h2d_stream) def pre_reload_last_layer(self): if not self.do_offload: diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py 
index b11594a4f86..e33f51459c8 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -944,9 +944,11 @@ def __post_init__(self): if len(self.offload_modules) > 0: allowed_modules = { - "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "router_fc2", - "shared_fc1", "shared_fc2" + "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1" } + if self.multi_latent_attention: + flag = "self_attn" in self.offload_modules or "qkv_linear" in self.offload_modules or "core_attn" in self.offload_modules or "attn_linear" in self.offload_modules + assert not flag, "(Temporary) self_attn, qkv_linear, core_attn, attn_linear must not be in offload_modules when multi_latent_attention is True" invalid_modules = set(self.offload_modules) - allowed_modules assert not invalid_modules, ( f'Invalid choices for offload_modules: {invalid_modules}. ' diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index c510dd6d179..bde739c1d12 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1112,6 +1112,9 @@ def validate_args(args, defaults={}): "disabling gradient_accumulation_fusion is only supported with TE >= 2.7.0 " "when enabling delay_wgrad_compute" ) + + if args.offload_activation: + assert not args.overlap_grad_reduce, "overlap_grad_reduce is not supported with offload_activation" if args.mtp_num_layers: assert not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP)." 
@@ -2113,7 +2116,10 @@ def _add_training_args(parser): help='The communicator group names to use high priority streams.') group.add_argument('--use-te-activation-func', action='store_true', help='Use activation function kernel from Transformer Engine in MLP module.') - + group.add_argument('--offload-activation', action='store_true', + help='Offload the activation to the CPU.') + group.add_argument('--offload-modules', nargs='*', type=str, default=[], + help='The submodules to offload. Choices: "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "router_fc2", "shared_fc1", "shared_fc2".') return parser From 7a52582a09b3a33efe3376ff2bc65508cfbda186 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Tue, 9 Sep 2025 03:56:34 -0700 Subject: [PATCH 10/35] support activation offloading at PP=1&PP&VPP Signed-off-by: Hongbin Liu --- megatron/core/pipeline_parallel/cpu_offload.py | 11 ++++++++--- megatron/core/pipeline_parallel/schedules.py | 7 +++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/megatron/core/pipeline_parallel/cpu_offload.py b/megatron/core/pipeline_parallel/cpu_offload.py index 07eba0440fc..bdd32dc1bf4 100644 --- a/megatron/core/pipeline_parallel/cpu_offload.py +++ b/megatron/core/pipeline_parallel/cpu_offload.py @@ -17,7 +17,10 @@ def get_instance(cls): def __init__(self): self._queue = deque() - self._vpp = parallel_state.get_virtual_pipeline_model_parallel_world_size() + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is None: + self._vpp = 1 + else: + self._vpp = parallel_state.get_virtual_pipeline_model_parallel_world_size() # cache vpp - 1 stages self._stages = [[] for _ in range(self._vpp)] @@ -70,7 +73,10 @@ def size(self): return len(self._queue) def reset_chunk_handler(self, num_layer, vp_stage, offload=True, first_layer_index=0): - cur_vpp_rank = vp_stage + if vp_stage is None: + cur_vpp_rank = 0 + else: + cur_vpp_rank = vp_stage first_last_vpp_rank = self._first_last_vpp_rank # rewind @@ 
-267,7 +273,6 @@ def forward_sync(self): torch.cuda.current_stream().wait_stream(self.d2h_stream) def bulk_offload(self, release_tensors): - self.d2h_stream.wait_stream(torch.cuda.current_stream()) if self.should_bulk_offload(): self.bulk_offload_group(self._layer_index) if len(release_tensors) > 0: diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index e0cc1d48b86..ae1e5a1e36c 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -564,6 +564,9 @@ def forward_backward_no_pipelining( adjust_tensor_shapes_fn is None ), "adjust_tensor_shapes_fn is not supported for non-pipeline-parallel schedule" + if not forward_only: + PipelineOffloadManager.get_instance().reset() + config = get_model_config(model) if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) @@ -906,6 +909,7 @@ def forward_backward_pipelining_with_interleaving( if not forward_only: PipelineOffloadManager.get_instance().reset() + if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") @@ -2056,6 +2060,9 @@ def forward_backward_pipelining_without_interleaving( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + if not forward_only: + PipelineOffloadManager.get_instance().reset() + # Disable async grad reductions no_sync_func = config.no_sync_func if no_sync_func is None: From 29b084d17bf646ca12c317b091021b76346495ca Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 17 Sep 2025 01:13:15 -0700 Subject: [PATCH 11/35] support offloading moe_act/router_fc1/layernorm simultaneously Signed-off-by: Hongbin Liu --- megatron/core/model_parallel_config.py | 6 +- megatron/core/models/gpt/gpt_model.py | 2 - megatron/core/models/gpt/utils.py | 7 - .../core/pipeline_parallel/cpu_offload.py | 172 
++++++++++++++---- megatron/core/tensor_parallel/random.py | 17 +- megatron/core/transformer/moe/experts.py | 68 ++++--- megatron/core/transformer/moe/moe_layer.py | 4 + .../transformer/multi_latent_attention.py | 9 + .../core/transformer/transformer_config.py | 6 +- .../core/transformer/transformer_layer.py | 78 ++++---- 10 files changed, 249 insertions(+), 120 deletions(-) delete mode 100644 megatron/core/models/gpt/utils.py diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 251c1945c3f..115429a5f2e 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -320,17 +320,15 @@ class ModelParallelConfig: offload_modules: Optional[list[str]] = None """The submodules to offload. - choices: "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "router_fc2", - "shared_fc1", "shared_fc2". + choices: "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "moe_act", "router_fc2". default: ["core_attn"]. "self_attn": offload the self_attn part of the transformer layer. "qkv_linear": offload the qkv_linear part of the transformer layer. "core_attn": offload the core attention part of the transformer layer. "attn_linear": offload the attn linear projection part of the transformer layer. "router_fc1": offload the moe router_fc1 part of the transformer layer. + "moe_act": offload the moe act part of the transformer layer. "router_fc2": offload the moe router_fc2 part of the transformer layer. - "shared_fc1": offload the shared_fc1 part of the transformer layer. - "shared_fc2": offload the shared_fc2 part of the transformer layer. 
""" ################### diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 4048bf93068..d031fef015f 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -33,7 +33,6 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import WrappedTensor, deprecate_inference_params -from megatron.core.models.gpt.utils import offloading_checker from megatron.core.pipeline_parallel.cpu_offload import PipelineOffloadManager @@ -374,7 +373,6 @@ def forward( self.config.offload_activation, 0, ) - PipelineOffloadManager.get_instance().cur_forward_chunk().set_offloading_checker(offloading_checker) inference_context = deprecate_inference_params(inference_context, inference_params) diff --git a/megatron/core/models/gpt/utils.py b/megatron/core/models/gpt/utils.py deleted file mode 100644 index ceb572b98fa..00000000000 --- a/megatron/core/models/gpt/utils.py +++ /dev/null @@ -1,7 +0,0 @@ -from megatron.core.parallel_state import ( - get_pipeline_model_parallel_rank, - get_pipeline_model_parallel_world_size -) - -def offloading_checker(tensor): - return hasattr(tensor, "offloading_activation") and tensor.offloading_activation diff --git a/megatron/core/pipeline_parallel/cpu_offload.py b/megatron/core/pipeline_parallel/cpu_offload.py index bdd32dc1bf4..ee8d9a308a1 100644 --- a/megatron/core/pipeline_parallel/cpu_offload.py +++ b/megatron/core/pipeline_parallel/cpu_offload.py @@ -1,11 +1,13 @@ from collections import deque, defaultdict import torch -from megatron.core import parallel_state from typing import Any from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.cpu_offload import AsyncDoubleBufferGroupOffloadHandler # cpu offload for pipeline +DEBUG = False +DEBUG_RANK = 0 +MIN_OFFLOADED_TENSOR_SIZE = 1024 * 1024 class 
PipelineOffloadManager: OFFLOAD_MGR = None @@ -16,6 +18,7 @@ def get_instance(cls): return cls.OFFLOAD_MGR def __init__(self): + from megatron.core import parallel_state self._queue = deque() if parallel_state.get_virtual_pipeline_model_parallel_world_size() is None: self._vpp = 1 @@ -56,11 +59,15 @@ def flush(self): self._stages[i] = [] def push(self, handler): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("pushing handler") self._queue.append(handler) def pop(self): assert self.size() self._cur_backward_chunk = self._queue.popleft() + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("popping handler", self._cur_backward_chunk) def front(self): if not len(self._queue): @@ -100,6 +107,8 @@ def cur_backward_chunk(self): return self._cur_backward_chunk def __enter__(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print(f"__enter__") self.OFFLOAD_MGR self.inside_context = True @@ -108,16 +117,22 @@ def __enter__(self): ) def __exit__(self, *args: Any): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print(f"__exit__") self.inside_context = False torch._C._autograd._pop_saved_tensors_default_hooks() def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("on_save_for_backward", tensor.shape) assert self.inside_context if self.cur_forward_chunk().is_registered_tensor(tensor.data_ptr()): tensor.offloading_activation = True return self.cur_forward_chunk().tensor_push(tensor) def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("on_get_saved_tensor", saved_state) return self.cur_backward_chunk().tensor_pop(saved_state) @@ -125,6 +140,8 @@ class ChunkOffloadHandler(AsyncDoubleBufferGroupOffloadHandler): @staticmethod def offload(src_tensor, pin_memory=True): """Offload.""" + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("offload") 
fp8_offload = isinstance(src_tensor, Float8Tensor) cpu_backup = torch.empty( @@ -138,6 +155,8 @@ def offload(src_tensor, pin_memory=True): if fp8_offload: cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup) + if not src_tensor.is_contiguous(): + src_tensor = src_tensor.contiguous() cpu_backup.copy_(src_tensor, non_blocking=pin_memory) state = (src_tensor.device, cpu_backup) return state @@ -145,6 +164,8 @@ def offload(src_tensor, pin_memory=True): @staticmethod def reload(state, non_blocking=None): """Reload.""" + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("reload") dev, cpu_backup = state if non_blocking is None: non_blocking = cpu_backup.is_pinned() @@ -155,30 +176,33 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer # Data Structure to maintain reference to activation tensors self._tensor_tag_to_state = {} # Tracking the number of layers offloaded - self._offloaded_group_count = 0 + # self._offloaded_group_count = 0 self._is_first_last_vpp_chunk = is_first_last_vpp_chunk - self._layer_index = first_layer_index + self._offloaded_group_index = 0 + self._groups_to_offload = [] + self._groups_to_reload = [] self.first_layer_index = first_layer_index - self._tensor_count_current_layer = 0 + self._tensor_count_current_group = 0 self.multi_input_offload_count = False - self.offload_count_per_layer = defaultdict(int) + # self.offload_count_per_layer = defaultdict(int) - self.tensor_need_offloading_checker = None self.torch_tensor_count = 0 self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream self.do_offload = offload + self.is_last_layer = False self._offload_tensor_ptrs = deque() def is_first_last_layer(self): - return self._is_first_last_vpp_chunk and self.is_last_layer() - - def is_last_layer(self): - return (self._layer_index == self._num_layers - 1) + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + 
print("is_first_last_layer", self._is_first_last_vpp_chunk, self.is_last_layer) + return self._is_first_last_vpp_chunk and self.is_last_layer def tensor_push(self, tensor): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_push") torch_stray_tensor = isinstance( tensor, ( @@ -189,27 +213,41 @@ def tensor_push(self, tensor): if not torch_stray_tensor:# True # obtain a unique tensor tag - tensor_tag = (self._layer_index, self._tensor_count_current_layer) - self._tensor_count_current_layer += 1 + tensor_tag = (self._offloaded_group_index, self._tensor_count_current_group) + self._tensor_count_current_group += 1 assert tensor_tag not in self._tensor_tag_to_state self._tensor_tag_to_state[tensor_tag] = tensor else: tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 self._tensor_tag_to_state[tensor_tag] = tensor + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_push", tensor.shape) + print("tensor_tag", tensor_tag) return tensor_tag def tensor_pop(self, tensor_tag): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_pop") + print("tensor_tag", tensor_tag) assert tensor_tag in self._tensor_tag_to_state, f"{tensor_tag}, {self._tensor_tag_to_state.keys()}" tensor = self._tensor_tag_to_state.pop(tensor_tag) assert not isinstance(tensor, tuple) + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_pop", tensor.shape) return tensor - def set_offloading_checker(self, check_func): - self.tensor_need_offloading_checker = check_func + def tensor_need_offloading_checker(self, tensor): + if tensor.numel() < MIN_OFFLOADED_TENSOR_SIZE: + return False + if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation: + return False + return True def bulk_offload_group(self, group_to_offload): """Bulk offload group.""" + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("bulk_offload_group") if not self.do_offload: return assert not 
self.is_first_last_layer() @@ -217,40 +255,60 @@ def bulk_offload_group(self, group_to_offload): for tensor_tag, state in self._tensor_tag_to_state.items(): group_id, _ = tensor_tag if group_id == group_to_offload: + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_tag", tensor_tag) + print("group_to_offload", group_to_offload) assert not isinstance(state, tuple) tensor_on_device = state # if offload, return the reference to cpu copy - if self.tensor_need_offloading_checker is not None and self.tensor_need_offloading_checker(tensor_on_device): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_need_offloading_checker", self.tensor_need_offloading_checker(tensor_on_device)) + print("tensor_on_device", tensor_on_device.shape) + print("hasattr(tensor_on_device, 'offloading_activation')", hasattr(tensor_on_device, 'offloading_activation')) + if self.tensor_need_offloading_checker(tensor_on_device): state = self.offload(tensor_on_device) tensor_on_device.record_stream(self.d2h_stream) - self.offload_count_per_layer[group_to_offload] += 1 + # self.offload_count_per_layer[group_to_offload] += 1 self._tensor_tag_to_state[tensor_tag] = state - self._offloaded_group_count = group_to_offload + 1 + # self._offloaded_group_count = group_to_offload + 1 + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("exit bulk_offload_group") def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("bulk_reload_group") if not self.do_offload: return + found_reload_group = False with torch.cuda.stream(self.h2d_stream): # move back tensors for tensor_label, state in self._tensor_tag_to_state.items(): group_id, _ = tensor_label if group_id == group_to_reload: + found_reload_group = True if isinstance(state, tuple): recovered_tensor = self.reload(state) self._tensor_tag_to_state[tensor_label] = recovered_tensor - 
self.offload_count_per_layer[group_to_reload] -= 1 - if self.offload_count_per_layer[group_to_reload] > 0 and self.multi_input_offload_count: - break - if self.offload_count_per_layer[group_to_reload] == 0: - self._offloaded_group_count = group_to_reload + # self.offload_count_per_layer[group_to_reload] -= 1 + # if self.offload_count_per_layer[group_to_reload] > 0 and self.multi_input_offload_count: + # break + # if self.offload_count_per_layer[group_to_reload] == 0: + # self._offloaded_group_count = group_to_reload + return found_reload_group def pre_reload_last_layer(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("pre_reload_last_layer") if not self.do_offload: return assert not self._is_first_last_vpp_chunk - if self._num_layers == self._offloaded_group_count: - self.bulk_reload_group(self._num_layers - 1) + # TODO: check if this is correct + if len(self._groups_to_reload) > 0: + if self.bulk_reload_group(self._groups_to_reload[-1]): + self._groups_to_reload.pop() + # if self._num_layers == self._offloaded_group_count: + # self.bulk_reload_group(self._num_layers - 1) # assert self._num_layers - 1 == self._offloaded_group_count def should_bulk_offload(self): @@ -263,59 +321,83 @@ def should_bulk_offload(self): # if next backward chunk is this chunk (for last pp stage) next_backward_chunk = PipelineOffloadManager.get_instance().get_instance().front() if next_backward_chunk is not None and next_backward_chunk is self: - if self.is_last_layer(): + if self.is_last_layer: return False return True def forward_sync(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("forward_sync") self.d2h_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.d2h_stream) + # torch.cuda.current_stream().wait_stream(self.d2h_stream) def bulk_offload(self, release_tensors): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("bulk_offload") if self.should_bulk_offload(): - 
self.bulk_offload_group(self._layer_index) + group_to_offload = self._groups_to_offload.pop() + self._groups_to_reload.append(group_to_offload) + self.bulk_offload_group(group_to_offload) if len(release_tensors) > 0: cur_stream = torch.cuda.current_stream() for release_tensor in release_tensors: release_tensor.record_stream(cur_stream) release_tensor.untyped_storage().resize_(0) + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("exit bulk_offload") def on_group_commit_forward(self, release_tensors): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("on_group_commit_forward") # wait each other self.forward_sync() self.bulk_offload(release_tensors) - self._layer_index = self._layer_index + 1 - self._tensor_count_current_layer = 0 + self._offloaded_group_per_layer = self._offloaded_group_index // self._num_layers + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("exit on_group_commit_forward") def bulk_reload(self): - if self.do_offload: - assert self._layer_index == self._offloaded_group_count - if self._layer_index > self.first_layer_index: + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("bulk_reload") + # if self.do_offload: + # assert self._layer_index == self._offloaded_group_count, f"{self._layer_index}, {self._offloaded_group_count}" + if len(self._groups_to_reload) > 0: # load next layer - self.bulk_reload_group(self._layer_index - 1) + if self.bulk_reload_group(self._groups_to_reload[-1]): + self._groups_to_reload.pop() else: next_backward_chunk = PipelineOffloadManager.get_instance().front() if next_backward_chunk is not None: next_backward_chunk.pre_reload_last_layer() def backward_sync(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("backward_sync") self.h2d_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.h2d_stream) + # computation kernels wait until the offloaded groups of one layer are fully reloaded. 
+ if self._offloaded_group_index % self._offloaded_group_per_layer == 0: + torch.cuda.current_stream().wait_stream(self.h2d_stream) def on_group_commit_backward(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("on_group_commit_backward") cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() if not cur_backward_chunk is self: PipelineOffloadManager.get_instance().pop() cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() assert cur_backward_chunk is self - self._layer_index = self._layer_index - 1 self.backward_sync() + self._offloaded_group_index = self._offloaded_group_index - 1 def on_group_start_forward(self): - pass + self._offloaded_group_index = self._offloaded_group_index + 1 + self._tensor_count_current_group = 0 + self._groups_to_offload.append(self._offloaded_group_index) def on_group_start_backward(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("on_group_start_backward") self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() @@ -346,6 +428,8 @@ class GroupCommitFunction(torch.autograd.Function): @staticmethod def forward(ctx, *args): # pylint: disable=missing-function-docstring + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("GroupCommitFunction forward") release_tensors = args[-1] cpu_offload_handler = args[-2] @@ -359,6 +443,8 @@ def forward(ctx, *args): @staticmethod def backward(ctx, *grad_output): # pylint: disable=missing-function-docstring + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("GroupCommitFunction backward") cpu_offload_handler = ctx.cpu_offload_handler cpu_offload_handler.on_group_commit_backward() @@ -378,21 +464,27 @@ class GroupStartFunction(torch.autograd.Function): """ @staticmethod - def forward(ctx, tensor, cpu_offload_handler): + def forward(ctx, tensor, cpu_offload_handler, is_last_layer): # pylint: disable=missing-function-docstring ctx.cpu_offload_handler = 
cpu_offload_handler + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("GroupStartFunction forward", is_last_layer) + ctx.cpu_offload_handler.is_last_layer = is_last_layer + cpu_offload_handler.on_group_start_forward() # return the identical tensor return tensor @staticmethod def backward(ctx, grad_output): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("GroupStartFunction backward") # pylint: disable=missing-function-docstring cpu_offload_handler = ctx.cpu_offload_handler cpu_offload_handler.on_group_start_backward() - return grad_output, None + return grad_output, None, None -def group_prefetch_offload_start(tensor): +def group_prefetch_offload_start(tensor, is_last_layer=False): cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() - return GroupStartFunction.apply(tensor, cur_forward_chunk) + return GroupStartFunction.apply(tensor, cur_forward_chunk, is_last_layer) diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 54cac0e41e3..9d96dfaf815 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -510,10 +510,11 @@ def forward(ctx, run_function, checkpoint_without_output_obj, *args): @staticmethod def backward(ctx, *args): """Backward pass.""" - inputs = ctx.saved_tensors + inputs = ctx.inputs outputs = ctx.outputs torch.autograd.backward(outputs, args) ctx.outputs = None + ctx.inputs = None grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs) return (None, None) + grads @@ -573,8 +574,19 @@ def _recompute(self, _): recompute_ctx = contextlib.nullcontext() fp8_ctx = contextlib.nullcontext() + inputs = self.ctx.saved_tensors + # do not know why, if saved_tensors is handled by saved_tensor_hook, grad of inputs will be None (not nan) + # detach it to bypass + def detach(t): + if isinstance(t, torch.Tensor): + requires_grad = t.requires_grad + t = t.detach() + 
t.requires_grad_(requires_grad) + return t + + inputs = tuple(detach(t) for t in inputs) with torch.enable_grad(), fp8_ctx, recompute_ctx: - outputs = self.run_function(*self.ctx.saved_tensors) + outputs = self.run_function(*inputs) self.run_function = None self.rng_states = None @@ -590,6 +602,7 @@ def _recompute(self, _): output.untyped_storage().copy_(recomputation_output.untyped_storage()) self.ctx.outputs = outputs + self.ctx.inputs = inputs self.outputs = None self.ctx = None diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index fde855c3f63..36fbb854c79 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -10,6 +10,7 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter +import contextlib from megatron.core import parallel_state, tensor_parallel from megatron.core.activations import squared_relu @@ -111,6 +112,7 @@ def __init__( self, num_local_experts: int, config: TransformerConfig, + layer_number: Optional[int] = None, model_comm_pgs: Optional[ModelCommProcessGroups] = None, ): super().__init__(config=config) @@ -758,11 +760,13 @@ def __init__( num_local_experts, config: TransformerConfig, submodules: MLPSubmodules, + layer_number: Optional[int] = None, model_comm_pgs: Optional[ModelCommProcessGroups] = None, ): super().__init__(config=config) self.num_local_experts = num_local_experts self.input_size = self.config.hidden_size + self.layer_number = layer_number assert ( config.add_bias_linear == False ), "bias not supported in TEGroupedMLP yet, please set '--disable-bias-linear' instead." 
@@ -814,6 +818,11 @@ def __init__( and "router_fc1" in self.config.offload_modules ) + self.offload_moe_act = ( + self.config.offload_activation + and "moe_act" in self.config.offload_modules + ) + self.offload_router_fc2 = ( self.config.offload_activation and "router_fc2" in self.config.offload_modules @@ -833,6 +842,10 @@ def __init__( self.fp8_padding = Fp8Padding(self.num_local_experts) self.fp8_unpadding = Fp8Unpadding(self.num_local_experts) + def set_layer_number(self, layer_number: int): + """Set the layer number for the TEGroupedMLP.""" + self.layer_number = layer_number + def forward( self, permuted_local_hidden_states: torch.Tensor, @@ -872,24 +885,17 @@ def forward( # Probs already applied, so reset to 1. permuted_probs = torch.ones_like(permuted_probs) + offload_context = contextlib.nullcontext() if self.offload_router_fc1: - if not permuted_local_hidden_states.is_contiguous(): - permuted_local_hidden_states = permuted_local_hidden_states.contiguous() - permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states) - permuted_local_hidden_states.offloading_activation = True - with PipelineOffloadManager.get_instance(): - intermediate_parallel, bias_parallel = self.linear_fc1( - permuted_local_hidden_states, tokens_per_expert - ) - intermediate_parallel, bias_parallel = group_prefetch_offload_commit( - intermediate_parallel, - bias_parallel, - release_tensors=[permuted_local_hidden_states] - ) - else: - intermediate_parallel, bias_parallel = self.linear_fc1( + permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states, is_last_layer=(self.layer_number == self.config.num_layers)) + offload_context = PipelineOffloadManager.get_instance() + with offload_context: + fc1_output, bias_parallel = self.linear_fc1( permuted_local_hidden_states, tokens_per_expert - ) + ) + if self.offload_router_fc1: + fc1_output, bias_parallel = group_prefetch_offload_commit(fc1_output, bias_parallel, 
release_tensors=[permuted_local_hidden_states]) + offload_context = contextlib.nullcontext() def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): if self.config.use_te_activation_func: @@ -946,18 +952,29 @@ def glu(x): intermediate_parallel = intermediate_parallel.to(original_dtype) return intermediate_parallel + if self.offload_moe_act: + fc1_output = group_prefetch_offload_start(fc1_output, is_last_layer=(self.layer_number == self.config.num_layers)) + offload_context = PipelineOffloadManager.get_instance() + if self.activation_recompute: self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput() - intermediate_parallel = self.activation_checkpoint.checkpoint( - bias_act_func, intermediate_parallel, bias_parallel, permuted_probs - ) - output, output_bias = self.linear_fc2(intermediate_parallel, tokens_per_expert) - self.activation_checkpoint.discard_output_and_register_recompute(output) + with offload_context: + bias_act_output = self.activation_checkpoint.checkpoint( + bias_act_func, fc1_output, bias_parallel, permuted_probs + ) else: - intermediate_parallel = bias_act_func( - intermediate_parallel, bias_parallel, permuted_probs - ) - output, output_bias = self.linear_fc2(intermediate_parallel, tokens_per_expert) + with offload_context: + bias_act_output = bias_act_func( + fc1_output, bias_parallel, permuted_probs + ) + + output, output_bias = self.linear_fc2(bias_act_output, tokens_per_expert) + if self.activation_recompute: + self.activation_checkpoint.discard_output_and_register_recompute(output) + if self.offload_moe_act: + output, = group_prefetch_offload_commit(output, release_tensors=[fc1_output]) + offload_context = contextlib.nullcontext() + # upad and concat the output if self.config.fp8: @@ -1025,6 +1042,7 @@ def __init__( num_local_experts, config: TransformerConfig, submodules: MLPSubmodules, + layer_number: Optional[int] = None, model_comm_pgs: Optional[ModelCommProcessGroups] = None, ): diff --git 
a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index e7dd9d4e56c..772900e17fd 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -87,6 +87,10 @@ def set_layer_number(self, layer_number: int): """Set the layer number for the MoE layer.""" self.layer_number = layer_number self.router.set_layer_number(layer_number) + if hasattr(self.experts, 'set_layer_number'): + self.experts.set_layer_number(layer_number) + if hasattr(self.shared_experts, 'set_layer_number'): + self.shared_experts.set_layer_number(layer_number) class MoELayer(BaseMoELayer): diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 24afd6d63b7..071e9dc67f1 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -263,6 +263,15 @@ def forward( core_attn_out = self._checkpointed_attention_forward( query, key, value, attention_mask, packed_seq_params=packed_seq_params ) + elif self.offload_core_attention and self.training: + core_attn_out = self._offload_core_attention_forward( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + packed_seq_params=packed_seq_params, + ) else: if inference_context is None or inference_context.is_static_batching(): core_attn_out = self.core_attention( diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index e33f51459c8..54fb950c234 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -944,11 +944,11 @@ def __post_init__(self): if len(self.offload_modules) > 0: allowed_modules = { - "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1" + "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "moe_act", "layernorm" } if self.multi_latent_attention: - flag = 
"self_attn" in self.offload_modules or "qkv_linear" in self.offload_modules or "core_attn" in self.offload_modules or "attn_linear" in self.offload_modules - assert not flag, "(Temporary) self_attn, qkv_linear, core_attn, attn_linear must not be in offload_modules when multi_latent_attention is True" + flag = "self_attn" in self.offload_modules or "qkv_linear" in self.offload_modules or "attn_linear" in self.offload_modules + assert not flag, "(Temporary) self_attn, qkv_linear, attn_linear must not be in offload_modules when multi_latent_attention is True" invalid_modules = set(self.offload_modules) - allowed_modules assert not invalid_modules, ( f'Invalid choices for offload_modules: {invalid_modules}. ' diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 95b49e97518..619a02c4100 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -9,6 +9,7 @@ import torch import torch.distributed from torch import Tensor +import contextlib from megatron.core import parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict @@ -429,6 +430,10 @@ def __init__( self.config.offload_activation and "self_attn" in self.config.offload_modules ) + self.offload_layernorm = ( + self.config.offload_activation + and "layernorm" in self.config.offload_modules + ) # @jcasper how should we handle nvfuser? # Set bias+dropout+add fusion grad_enable execution handler. @@ -519,41 +524,26 @@ def _forward_attention( # Self attention. 
nvtx_range_push(suffix="self_attention") + offload_context = contextlib.nullcontext() if self.offload_self_attn: - from megatron.core.pipeline_parallel.cpu_offload import ( - PipelineOffloadManager, - group_prefetch_offload_start, - group_prefetch_offload_commit, - ) - if not input_layernorm_output.is_contiguous(): - input_layernorm_output = input_layernorm_output.contiguous() - input_layernorm_output = group_prefetch_offload_start(input_layernorm_output) - input_layernorm_output.offloading_activation = True - with PipelineOffloadManager.get_instance(): - attention_output_with_bias = self.self_attention( - input_layernorm_output, - attention_mask=attention_mask, - inference_context=inference_context, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - ) - attention_output_with_bias = group_prefetch_offload_commit(attention_output_with_bias, release_tensors=[input_layernorm_output]) - else: + input_layernorm_output = group_prefetch_offload_start(input_layernorm_output, + is_last_layer=(self.layer_number == self.config.num_layers)) + offload_context = PipelineOffloadManager.get_instance() + with offload_context: attention_output_with_bias = self.self_attention( - input_layernorm_output, - attention_mask=attention_mask, - inference_context=inference_context, - rotary_pos_emb=rotary_pos_emb, - rotary_pos_cos=rotary_pos_cos, - rotary_pos_sin=rotary_pos_sin, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - sequence_len_offset=sequence_len_offset, - ) + input_layernorm_output, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + ) + if 
self.offload_self_attn: + attention_output_with_bias, = group_prefetch_offload_commit(attention_output_with_bias, release_tensors=[input_layernorm_output]) + offload_context = contextlib.nullcontext() nvtx_range_pop(suffix="self_attention") if self.recompute_input_layernorm: @@ -612,14 +602,25 @@ def _forward_mlp(self, hidden_states, inference_context=None): # Residual connection. residual = hidden_states + offload_context = contextlib.nullcontext() + if self.offload_layernorm: + from megatron.core.pipeline_parallel.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, + ) + hidden_states = group_prefetch_offload_start(hidden_states, is_last_layer=(self.layer_number == self.config.num_layers)) + offload_context = PipelineOffloadManager.get_instance() # Optional Layer norm post the cross-attention. if self.recompute_pre_mlp_layernorm: self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( - self.pre_mlp_layernorm, hidden_states - ) + with offload_context: + pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( + self.pre_mlp_layernorm, hidden_states + ) else: - pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) + with offload_context: + pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) nvtx_range_push(suffix="mlp") # Potentially chunk the MLP computation during prefill to minimize the peak activation size @@ -679,6 +680,9 @@ def _forward_mlp(self, hidden_states, inference_context=None): mlp_output_with_bias, residual, self.hidden_dropout ) nvtx_range_pop(suffix="mlp_bda") + if self.offload_layernorm: + hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual]) + offload_context = contextlib.nullcontext() # Jit compiled function creates 'view' tensor. 
This tensor # potentially gets saved in the MPU checkpoint function context, From 83ab849eab03cd1dce016f9a645a56612aeb5eb0 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 17 Sep 2025 23:54:55 -0700 Subject: [PATCH 12/35] support offloading core_attn/attn_proj and code refactoring Signed-off-by: Hongbin Liu --- megatron/core/model_parallel_config.py | 2 + megatron/core/models/gpt/gpt_model.py | 3 +- .../core/pipeline_parallel/cpu_offload.py | 490 ------------------ megatron/core/pipeline_parallel/schedules.py | 2 +- megatron/core/transformer/attention.py | 32 +- megatron/core/transformer/moe/experts.py | 6 +- .../transformer/multi_latent_attention.py | 36 +- .../core/transformer/transformer_config.py | 47 +- .../core/transformer/transformer_layer.py | 45 +- 9 files changed, 99 insertions(+), 564 deletions(-) delete mode 100644 megatron/core/pipeline_parallel/cpu_offload.py diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 115429a5f2e..2ccf42743fd 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -330,6 +330,8 @@ class ModelParallelConfig: "moe_act": offload the moe act part of the transformer layer. "router_fc2": offload the moe router_fc2 part of the transformer layer. """ + offload_module_count_per_layer: Optional[int] = 0 + """The number of modules to offload per layer. 
default: 0.""" ################### # CPU Offloading diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index d031fef015f..2aab9912a63 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -33,7 +33,7 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import WrappedTensor, deprecate_inference_params -from megatron.core.pipeline_parallel.cpu_offload import PipelineOffloadManager +from megatron.core.transformer.cpu_offload import PipelineOffloadManager class GPTModel(LanguageModule): @@ -372,6 +372,7 @@ def forward( self.vp_stage, self.config.offload_activation, 0, + self.config.offload_module_count_per_layer, ) inference_context = deprecate_inference_params(inference_context, inference_params) diff --git a/megatron/core/pipeline_parallel/cpu_offload.py b/megatron/core/pipeline_parallel/cpu_offload.py deleted file mode 100644 index ee8d9a308a1..00000000000 --- a/megatron/core/pipeline_parallel/cpu_offload.py +++ /dev/null @@ -1,490 +0,0 @@ -from collections import deque, defaultdict -import torch -from typing import Any -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.cpu_offload import AsyncDoubleBufferGroupOffloadHandler - -# cpu offload for pipeline -DEBUG = False -DEBUG_RANK = 0 -MIN_OFFLOADED_TENSOR_SIZE = 1024 * 1024 - -class PipelineOffloadManager: - OFFLOAD_MGR = None - @classmethod - def get_instance(cls): - if cls.OFFLOAD_MGR is None: - cls.OFFLOAD_MGR = PipelineOffloadManager() - return cls.OFFLOAD_MGR - - def __init__(self): - from megatron.core import parallel_state - self._queue = deque() - if parallel_state.get_virtual_pipeline_model_parallel_world_size() is None: - self._vpp = 1 - else: - self._vpp = parallel_state.get_virtual_pipeline_model_parallel_world_size() - - # cache vpp - 1 stages - 
self._stages = [[] for _ in range(self._vpp)] - # allocate streams and events for synchronization - self._d2h_stream = torch.cuda.Stream() - self._h2d_stream = torch.cuda.Stream() - self.reset() - - @property - def d2h_stream(self): - return self._d2h_stream - - @property - def h2d_stream(self): - return self._h2d_stream - - def reset(self): - self._inside_context = False - self._cur_forward_chunk = None - self._cur_backward_chunk = None - self._first_last_vpp_rank = True - - def flush(self): - # put into the queue in the backward order - if len(self._stages[0]) == len(self._stages[-1]): - lens = [len(e) for e in self._stages] - assert min(lens) == max(lens) - self._stages[-1] = [] - for chunks in reversed(self._stages): - for chunk in chunks: - self.push(chunk) - for i in range(self._vpp): - self._stages[i] = [] - - def push(self, handler): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("pushing handler") - self._queue.append(handler) - - def pop(self): - assert self.size() - self._cur_backward_chunk = self._queue.popleft() - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("popping handler", self._cur_backward_chunk) - - def front(self): - if not len(self._queue): - return None - f = self._queue.popleft() - self._queue.appendleft(f) - return f - - def size(self): - return len(self._queue) - - def reset_chunk_handler(self, num_layer, vp_stage, offload=True, first_layer_index=0): - if vp_stage is None: - cur_vpp_rank = 0 - else: - cur_vpp_rank = vp_stage - - first_last_vpp_rank = self._first_last_vpp_rank - # rewind - if cur_vpp_rank == self._vpp - 1: - self.flush() - first_last_vpp_rank = first_last_vpp_rank and (cur_vpp_rank == self._vpp - 1) - cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload, first_layer_index) - # save for latter push - self._stages[cur_vpp_rank].append(cur_chunk) - if cur_vpp_rank == self._vpp - 1: - self._first_last_vpp_rank = False - self.push(cur_chunk) - self.flush() - 
self._cur_forward_chunk = cur_chunk - cur_chunk.vpp_rank = cur_vpp_rank - - def cur_forward_chunk(self): - return self._cur_forward_chunk - - def cur_backward_chunk(self): - return self._cur_backward_chunk - - def __enter__(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print(f"__enter__") - self.OFFLOAD_MGR - self.inside_context = True - - torch._C._autograd._push_saved_tensors_default_hooks( - self.on_save_for_backward, self.on_get_saved_tensor - ) - - def __exit__(self, *args: Any): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print(f"__exit__") - self.inside_context = False - torch._C._autograd._pop_saved_tensors_default_hooks() - - def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("on_save_for_backward", tensor.shape) - assert self.inside_context - if self.cur_forward_chunk().is_registered_tensor(tensor.data_ptr()): - tensor.offloading_activation = True - return self.cur_forward_chunk().tensor_push(tensor) - - def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("on_get_saved_tensor", saved_state) - return self.cur_backward_chunk().tensor_pop(saved_state) - - -class ChunkOffloadHandler(AsyncDoubleBufferGroupOffloadHandler): - @staticmethod - def offload(src_tensor, pin_memory=True): - """Offload.""" - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("offload") - fp8_offload = isinstance(src_tensor, Float8Tensor) - - cpu_backup = torch.empty( - src_tensor.size(), - dtype=torch.uint8 if fp8_offload else src_tensor.dtype, - layout=src_tensor.layout, - device="cpu", - pin_memory=pin_memory, - ) - - if fp8_offload: - cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup) - - if not src_tensor.is_contiguous(): - src_tensor = src_tensor.contiguous() - cpu_backup.copy_(src_tensor, non_blocking=pin_memory) - state = (src_tensor.device, cpu_backup) - 
return state - - @staticmethod - def reload(state, non_blocking=None): - """Reload.""" - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("reload") - dev, cpu_backup = state - if non_blocking is None: - non_blocking = cpu_backup.is_pinned() - return cpu_backup.to(dev, non_blocking=non_blocking) - - def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer_index=0): - self._num_layers = num_layer - # Data Structure to maintain reference to activation tensors - self._tensor_tag_to_state = {} - # Tracking the number of layers offloaded - # self._offloaded_group_count = 0 - self._is_first_last_vpp_chunk = is_first_last_vpp_chunk - - self._offloaded_group_index = 0 - self._groups_to_offload = [] - self._groups_to_reload = [] - self.first_layer_index = first_layer_index - self._tensor_count_current_group = 0 - self.multi_input_offload_count = False - # self.offload_count_per_layer = defaultdict(int) - - self.torch_tensor_count = 0 - self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream - self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream - self.do_offload = offload - self.is_last_layer = False - - self._offload_tensor_ptrs = deque() - - def is_first_last_layer(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("is_first_last_layer", self._is_first_last_vpp_chunk, self.is_last_layer) - return self._is_first_last_vpp_chunk and self.is_last_layer - - def tensor_push(self, tensor): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("tensor_push") - torch_stray_tensor = isinstance( - tensor, - ( - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ), - ) - - if not torch_stray_tensor:# True - # obtain a unique tensor tag - tensor_tag = (self._offloaded_group_index, self._tensor_count_current_group) - self._tensor_count_current_group += 1 - assert tensor_tag not in self._tensor_tag_to_state - 
self._tensor_tag_to_state[tensor_tag] = tensor - else: - tensor_tag = (-1, self.torch_tensor_count) - self.torch_tensor_count += 1 - self._tensor_tag_to_state[tensor_tag] = tensor - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("tensor_push", tensor.shape) - print("tensor_tag", tensor_tag) - return tensor_tag - - def tensor_pop(self, tensor_tag): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("tensor_pop") - print("tensor_tag", tensor_tag) - assert tensor_tag in self._tensor_tag_to_state, f"{tensor_tag}, {self._tensor_tag_to_state.keys()}" - tensor = self._tensor_tag_to_state.pop(tensor_tag) - assert not isinstance(tensor, tuple) - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("tensor_pop", tensor.shape) - return tensor - - def tensor_need_offloading_checker(self, tensor): - if tensor.numel() < MIN_OFFLOADED_TENSOR_SIZE: - return False - if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation: - return False - return True - - def bulk_offload_group(self, group_to_offload): - """Bulk offload group.""" - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("bulk_offload_group") - if not self.do_offload: - return - assert not self.is_first_last_layer() - with torch.cuda.stream(self.d2h_stream): - for tensor_tag, state in self._tensor_tag_to_state.items(): - group_id, _ = tensor_tag - if group_id == group_to_offload: - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("tensor_tag", tensor_tag) - print("group_to_offload", group_to_offload) - assert not isinstance(state, tuple) - tensor_on_device = state - # if offload, return the reference to cpu copy - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("tensor_need_offloading_checker", self.tensor_need_offloading_checker(tensor_on_device)) - print("tensor_on_device", tensor_on_device.shape) - print("hasattr(tensor_on_device, 'offloading_activation')", hasattr(tensor_on_device, 
'offloading_activation')) - if self.tensor_need_offloading_checker(tensor_on_device): - state = self.offload(tensor_on_device) - tensor_on_device.record_stream(self.d2h_stream) - # self.offload_count_per_layer[group_to_offload] += 1 - self._tensor_tag_to_state[tensor_tag] = state - # self._offloaded_group_count = group_to_offload + 1 - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("exit bulk_offload_group") - - def bulk_reload_group(self, group_to_reload): - """Bulk reload group.""" - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("bulk_reload_group") - if not self.do_offload: - return - found_reload_group = False - with torch.cuda.stream(self.h2d_stream): - # move back tensors - for tensor_label, state in self._tensor_tag_to_state.items(): - group_id, _ = tensor_label - if group_id == group_to_reload: - found_reload_group = True - if isinstance(state, tuple): - recovered_tensor = self.reload(state) - self._tensor_tag_to_state[tensor_label] = recovered_tensor - # self.offload_count_per_layer[group_to_reload] -= 1 - # if self.offload_count_per_layer[group_to_reload] > 0 and self.multi_input_offload_count: - # break - # if self.offload_count_per_layer[group_to_reload] == 0: - # self._offloaded_group_count = group_to_reload - return found_reload_group - - def pre_reload_last_layer(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("pre_reload_last_layer") - if not self.do_offload: - return - assert not self._is_first_last_vpp_chunk - # TODO: check if this is correct - if len(self._groups_to_reload) > 0: - if self.bulk_reload_group(self._groups_to_reload[-1]): - self._groups_to_reload.pop() - # if self._num_layers == self._offloaded_group_count: - # self.bulk_reload_group(self._num_layers - 1) - # assert self._num_layers - 1 == self._offloaded_group_count - - def should_bulk_offload(self): - if not self.do_offload: - return False - # first backward chunk - if self.is_first_last_layer(): - return False - - # if 
next backward chunk is this chunk (for last pp stage) - next_backward_chunk = PipelineOffloadManager.get_instance().get_instance().front() - if next_backward_chunk is not None and next_backward_chunk is self: - if self.is_last_layer: - return False - - return True - - def forward_sync(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("forward_sync") - self.d2h_stream.wait_stream(torch.cuda.current_stream()) - # torch.cuda.current_stream().wait_stream(self.d2h_stream) - - def bulk_offload(self, release_tensors): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("bulk_offload") - if self.should_bulk_offload(): - group_to_offload = self._groups_to_offload.pop() - self._groups_to_reload.append(group_to_offload) - self.bulk_offload_group(group_to_offload) - if len(release_tensors) > 0: - cur_stream = torch.cuda.current_stream() - for release_tensor in release_tensors: - release_tensor.record_stream(cur_stream) - release_tensor.untyped_storage().resize_(0) - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("exit bulk_offload") - - def on_group_commit_forward(self, release_tensors): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("on_group_commit_forward") - # wait each other - self.forward_sync() - self.bulk_offload(release_tensors) - self._offloaded_group_per_layer = self._offloaded_group_index // self._num_layers - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("exit on_group_commit_forward") - - def bulk_reload(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("bulk_reload") - # if self.do_offload: - # assert self._layer_index == self._offloaded_group_count, f"{self._layer_index}, {self._offloaded_group_count}" - if len(self._groups_to_reload) > 0: - # load next layer - if self.bulk_reload_group(self._groups_to_reload[-1]): - self._groups_to_reload.pop() - else: - next_backward_chunk = PipelineOffloadManager.get_instance().front() - if 
next_backward_chunk is not None: - next_backward_chunk.pre_reload_last_layer() - - def backward_sync(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("backward_sync") - self.h2d_stream.wait_stream(torch.cuda.current_stream()) - # computation kernels wait until the offloaded groups of one layer are fully reloaded. - if self._offloaded_group_index % self._offloaded_group_per_layer == 0: - torch.cuda.current_stream().wait_stream(self.h2d_stream) - - def on_group_commit_backward(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("on_group_commit_backward") - cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() - if not cur_backward_chunk is self: - PipelineOffloadManager.get_instance().pop() - cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() - assert cur_backward_chunk is self - self.backward_sync() - self._offloaded_group_index = self._offloaded_group_index - 1 - - def on_group_start_forward(self): - self._offloaded_group_index = self._offloaded_group_index + 1 - self._tensor_count_current_group = 0 - self._groups_to_offload.append(self._offloaded_group_index) - - def on_group_start_backward(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("on_group_start_backward") - self.h2d_stream.wait_stream(torch.cuda.current_stream()) - self.bulk_reload() - - def register_offload_tensor(self, tensors): - self.multi_input_offload_count = True - if isinstance(tensors, list): - for tensor in tensors: - self._offload_tensor_ptrs.append(tensor.data_ptr()) - else: - self._offload_tensor_ptrs.append(tensors.data_ptr()) - - def is_registered_tensor(self, tensor_ptr: int) -> bool: - if len(self._offload_tensor_ptrs) == 0: - return False - is_registered = tensor_ptr == self._offload_tensor_ptrs[0] - if is_registered: - self._offload_tensor_ptrs.popleft() - return is_registered - - -class GroupCommitFunction(torch.autograd.Function): - """this is a dummy op 
with output identical to input. - However, it is necessary for marking a timepoint for offload handler to - accomplish all synchronizations. Implementing it as a function is necessary - because we need to actions in both forward and backward. - """ - - @staticmethod - def forward(ctx, *args): - # pylint: disable=missing-function-docstring - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("GroupCommitFunction forward") - - release_tensors = args[-1] - cpu_offload_handler = args[-2] - tensor = args[:-2] - cpu_offload_handler.on_group_commit_forward(release_tensors) - ctx.cpu_offload_handler = cpu_offload_handler - - # return the identical tensor - return tensor - - @staticmethod - def backward(ctx, *grad_output): - # pylint: disable=missing-function-docstring - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("GroupCommitFunction backward") - - cpu_offload_handler = ctx.cpu_offload_handler - cpu_offload_handler.on_group_commit_backward() - return grad_output + (None, None) - - -def group_prefetch_offload_commit(*tensor, release_tensors=[]): - cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() - return GroupCommitFunction.apply(*tensor, cur_forward_chunk, release_tensors) - - -class GroupStartFunction(torch.autograd.Function): - """this is a dummy op with output identical to input. - However, it is necessary for marking a timepoint for offload handler to - accomplish all synchronizations. Implementing it as a function is necessary - because we need to actions in both forward and backward. 
- """ - - @staticmethod - def forward(ctx, tensor, cpu_offload_handler, is_last_layer): - # pylint: disable=missing-function-docstring - ctx.cpu_offload_handler = cpu_offload_handler - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("GroupStartFunction forward", is_last_layer) - - ctx.cpu_offload_handler.is_last_layer = is_last_layer - cpu_offload_handler.on_group_start_forward() - # return the identical tensor - return tensor - - @staticmethod - def backward(ctx, grad_output): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("GroupStartFunction backward") - # pylint: disable=missing-function-docstring - cpu_offload_handler = ctx.cpu_offload_handler - cpu_offload_handler.on_group_start_backward() - return grad_output, None, None - - -def group_prefetch_offload_start(tensor, is_last_layer=False): - cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() - return GroupStartFunction.apply(tensor, cur_forward_chunk, is_last_layer) diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index ae1e5a1e36c..99cb0f339d6 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -19,6 +19,7 @@ from megatron.core.process_groups_config import GradFinalizeProcessGroups from megatron.core.transformer.cuda_graphs import create_cudagraphs from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler +from megatron.core.transformer.cpu_offload import PipelineOffloadManager from megatron.core.utils import ( drain_embedding_wgrad_compute, get_attr_wrapped_model, @@ -32,7 +33,6 @@ combined_1f1b_schedule_for_interleaved_pipelining, combined_1f1b_schedule_for_no_pipelining, ) -from .cpu_offload import PipelineOffloadManager # Types Shape = Union[List[int], torch.Size] diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 8464e2ffd85..70e1e4dfa1a 100644 --- 
a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -6,6 +6,7 @@ import torch from torch import Tensor +import contextlib from megatron.core import tensor_parallel from megatron.core.inference.contexts import BaseInferenceContext @@ -37,7 +38,7 @@ from .enums import AttnMaskType from .transformer_config import TransformerConfig -from megatron.core.pipeline_parallel.cpu_offload import ( +from megatron.core.transformer.cpu_offload import ( PipelineOffloadManager, group_prefetch_offload_start, group_prefetch_offload_commit, @@ -194,9 +195,9 @@ def __init__( and "core_attn" in self.config.offload_modules ) - self.offload_attn_linear = ( + self.offload_attn_proj = ( self.config.offload_activation - and "attn_linear" in self.config.offload_modules + and "attn_proj" in self.config.offload_modules ) # Output. @@ -304,14 +305,12 @@ def custom_forward(*inputs): value = value.contiguous() query = group_prefetch_offload_start(query) - key = group_prefetch_offload_start(key) - value = group_prefetch_offload_start(value) - - handler = PipelineOffloadManager.get_instance().cur_forward_chunk() - handler.register_offload_tensor([query, key, value]) - query.offloading_activation = True - key.offloading_activation = True - value.offloading_activation = True + + # handler = PipelineOffloadManager.get_instance().cur_forward_chunk() + # handler.register_offload_tensor([query, key, value]) + # query.offloading_activation = True + # key.offloading_activation = True + # value.offloading_activation = True with PipelineOffloadManager.get_instance(): hidden_states = custom_forward( query, key, value, attention_mask, rotary_pos_emb, attn_mask_type @@ -936,16 +935,7 @@ def forward( # ================= nvtx_range_push(suffix="linear_proj") - if self.offload_attn_linear: - if not core_attn_out.is_contiguous(): - core_attn_out = core_attn_out.contiguous() - core_attn_out = group_prefetch_offload_start(core_attn_out) - core_attn_out.offloading_activation = True 
- with PipelineOffloadManager.get_instance(): - output, bias = self.linear_proj(core_attn_out) - output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out]) - else: - output, bias = self.linear_proj(core_attn_out) + output, bias = self.linear_proj(core_attn_out) nvtx_range_pop(suffix="linear_proj") return output, bias diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 36fbb854c79..639d46b62e6 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -41,7 +41,7 @@ make_sharded_object_for_checkpoint, sharded_state_dict_default, ) -from megatron.core.pipeline_parallel.cpu_offload import ( +from megatron.core.transformer.cpu_offload import ( PipelineOffloadManager, group_prefetch_offload_start, group_prefetch_offload_commit, @@ -887,7 +887,7 @@ def forward( offload_context = contextlib.nullcontext() if self.offload_router_fc1: - permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states, is_last_layer=(self.layer_number == self.config.num_layers)) + permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states, name="router_fc1") offload_context = PipelineOffloadManager.get_instance() with offload_context: fc1_output, bias_parallel = self.linear_fc1( @@ -953,7 +953,7 @@ def glu(x): return intermediate_parallel if self.offload_moe_act: - fc1_output = group_prefetch_offload_start(fc1_output, is_last_layer=(self.layer_number == self.config.num_layers)) + fc1_output = group_prefetch_offload_start(fc1_output, name="moe_act") offload_context = PipelineOffloadManager.get_instance() if self.activation_recompute: diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 071e9dc67f1..8f01a0cb8ac 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -37,6 
+37,12 @@ from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import MLATransformerConfig from megatron.core.utils import deprecate_inference_params, is_te_min_version +from megatron.core.transformer.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, +) +import contextlib try: from megatron.core.fusions.fused_mla_yarn_rope_apply import ( @@ -263,18 +269,15 @@ def forward( core_attn_out = self._checkpointed_attention_forward( query, key, value, attention_mask, packed_seq_params=packed_seq_params ) - elif self.offload_core_attention and self.training: - core_attn_out = self._offload_core_attention_forward( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - packed_seq_params=packed_seq_params, - ) else: + offload_context = contextlib.nullcontext() + if self.offload_core_attention and self.training: + query = group_prefetch_offload_start(query, name="core_attn") + offload_context = PipelineOffloadManager.get_instance() + if inference_context is None or inference_context.is_static_batching(): - core_attn_out = self.core_attention( + with offload_context: + core_attn_out = self.core_attention( query, key, value, @@ -302,6 +305,9 @@ def forward( # Only rearrange if not in absorption mode (Flash MLA handles format correctly) if not inference_context.is_decode_only(): core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') + if self.offload_core_attention and self.training: + core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[query, key, value]) + offload_context = contextlib.nullcontext() # We are doing absorption with cache mla latents and decode mode. if self.cache_mla_latents and inference_context.is_decode_only(): @@ -327,7 +333,15 @@ def forward( # ================= # Output. 
[sq, b, h] # ================= - output, bias = self.linear_proj(core_attn_out) + offload_context = contextlib.nullcontext() + if self.offload_attn_proj: + core_attn_out = group_prefetch_offload_start(core_attn_out, name="attn_proj") + offload_context = PipelineOffloadManager.get_instance() + with offload_context: + output, bias = self.linear_proj(core_attn_out) + if self.offload_attn_proj: + output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out]) + offload_context = contextlib.nullcontext() return output, bias diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 54fb950c234..c7c5a175a6d 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -939,37 +939,42 @@ def __post_init__(self): if "moe" not in self.recompute_modules: self.recompute_modules.append("moe") - if self.offload_modules is None: - self.offload_modules = ["core_attn"] + # if self.offload_modules is None: + # self.offload_modules = ["core_attn"] if len(self.offload_modules) > 0: + self.offload_modules = list(set(self.offload_modules)) allowed_modules = { - "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "moe_act", "layernorm" + "core_attn", "attn_proj", "router_fc1", "moe_act", "attn_norm", "mlp_norm" } - if self.multi_latent_attention: - flag = "self_attn" in self.offload_modules or "qkv_linear" in self.offload_modules or "attn_linear" in self.offload_modules - assert not flag, "(Temporary) self_attn, qkv_linear, attn_linear must not be in offload_modules when multi_latent_attention is True" invalid_modules = set(self.offload_modules) - allowed_modules assert not invalid_modules, ( f'Invalid choices for offload_modules: {invalid_modules}. 
' f'Allowed modules are: {allowed_modules}' ) + self.offload_module_count_per_layer = len(self.offload_modules) + if "attn_proj" in self.offload_modules and "core_attn" not in self.offload_modules: + raise ValueError( + "attn_proj cannot be set to offload_modules alone without core_attn " + "because the input of attn_proj is the output of core_attn, " + "which is needed in core_attn.backward()." + ) - if "self_attn" in self.offload_modules: - if "qkv_linear" in self.offload_modules: - self.offload_modules.remove("qkv_linear") - if "core_attn" in self.offload_modules: - self.offload_modules.remove("core_attn") - if "attn_linear" in self.offload_modules: - self.offload_modules.remove("attn_linear") - - if "core_attn" in self.offload_modules: - warnings.warn( - "If you are using transformer_engine as the transformer implementation, " - "the core_attn is from transformer_engine and may be the fused version. " - "For fused attention, you have no need to set 'core_attn' to offload. " - "Please check that the core_attn offload is really needed." - ) + # if "self_attn" in self.offload_modules: + # if "qkv_linear" in self.offload_modules: + # self.offload_modules.remove("qkv_linear") + # if "core_attn" in self.offload_modules: + # self.offload_modules.remove("core_attn") + # if "attn_linear" in self.offload_modules: + # self.offload_modules.remove("attn_linear") + + # if "core_attn" in self.offload_modules: + # warnings.warn( + # "If you are using transformer_engine as the transformer implementation, " + # "the core_attn is from transformer_engine and may be the fused version. " + # "For fused attention, you have no need to set 'core_attn' to offload. " + # "Please check that the core_attn offload is really needed." 
+ # ) if ( self.num_layers_in_first_pipeline_stage is not None diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 619a02c4100..3384fdd9cf6 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -31,6 +31,11 @@ nvtx_range_pop, nvtx_range_push, ) +from megatron.core.transformer.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, +) logger = logging.getLogger(__name__) @@ -430,9 +435,13 @@ def __init__( self.config.offload_activation and "self_attn" in self.config.offload_modules ) - self.offload_layernorm = ( + self.offload_attn_norm = ( + self.config.offload_activation + and "attn_norm" in self.config.offload_modules + ) + self.offload_mlp_norm = ( self.config.offload_activation - and "layernorm" in self.config.offload_modules + and "mlp_norm" in self.config.offload_modules ) # @jcasper how should we handle nvfuser? @@ -513,21 +522,26 @@ def _forward_attention( # Residual connection. residual = hidden_states + offload_context = contextlib.nullcontext() + if self.offload_attn_norm: + hidden_states = group_prefetch_offload_start(hidden_states, name="attn_norm") + offload_context = PipelineOffloadManager.get_instance() # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( - self.input_layernorm, hidden_states - ) + with offload_context: + input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( + self.input_layernorm, hidden_states + ) else: - input_layernorm_output = self.input_layernorm(hidden_states) + with offload_context: + input_layernorm_output = self.input_layernorm(hidden_states) # Self attention. 
nvtx_range_push(suffix="self_attention") offload_context = contextlib.nullcontext() if self.offload_self_attn: - input_layernorm_output = group_prefetch_offload_start(input_layernorm_output, - is_last_layer=(self.layer_number == self.config.num_layers)) + input_layernorm_output = group_prefetch_offload_start(input_layernorm_output, name="self_attn") offload_context = PipelineOffloadManager.get_instance() with offload_context: attention_output_with_bias = self.self_attention( @@ -562,6 +576,10 @@ def _forward_attention( ) nvtx_range_pop(suffix="self_attn_bda") + if self.offload_attn_norm: + hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual]) + offload_context = contextlib.nullcontext() + # Residual connection. residual = hidden_states @@ -603,13 +621,8 @@ def _forward_mlp(self, hidden_states, inference_context=None): residual = hidden_states offload_context = contextlib.nullcontext() - if self.offload_layernorm: - from megatron.core.pipeline_parallel.cpu_offload import ( - PipelineOffloadManager, - group_prefetch_offload_start, - group_prefetch_offload_commit, - ) - hidden_states = group_prefetch_offload_start(hidden_states, is_last_layer=(self.layer_number == self.config.num_layers)) + if self.offload_mlp_norm: + hidden_states = group_prefetch_offload_start(hidden_states, name="mlp_norm") offload_context = PipelineOffloadManager.get_instance() # Optional Layer norm post the cross-attention. 
if self.recompute_pre_mlp_layernorm: @@ -680,7 +693,7 @@ def _forward_mlp(self, hidden_states, inference_context=None): mlp_output_with_bias, residual, self.hidden_dropout ) nvtx_range_pop(suffix="mlp_bda") - if self.offload_layernorm: + if self.offload_mlp_norm: hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual]) offload_context = contextlib.nullcontext() From bee10600ede9d78ad2fec274e42f21dea89d5f6c Mon Sep 17 00:00:00 2001 From: root Date: Thu, 18 Sep 2025 02:32:43 -0700 Subject: [PATCH 13/35] add new cpu_offload.py Signed-off-by: root --- megatron/core/transformer/cpu_offload.py | 519 +++++++++++++++++++++++ 1 file changed, 519 insertions(+) create mode 100644 megatron/core/transformer/cpu_offload.py diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py new file mode 100644 index 00000000000..adfe47e09e1 --- /dev/null +++ b/megatron/core/transformer/cpu_offload.py @@ -0,0 +1,519 @@ +from collections import deque, defaultdict +import torch +from typing import Any +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.cpu_offload import AsyncDoubleBufferGroupOffloadHandler + +# cpu offload for pipeline +DEBUG = True +DEBUG_RANK = 5 +MIN_OFFLOADED_TENSOR_SIZE = 1024 * 1024 + +def set_ideal_affinity_for_current_gpu(): + import cuda.cuda + import cuda.cudart + import pynvml + import uuid + err, device_id = cuda.cudart.cudaGetDevice() + assert err == cuda.cudart.cudaError_t.cudaSuccess + err, device_uuid = cuda.cuda.cuDeviceGetUuid(device_id) + assert err == cuda.cuda.CUresult.CUDA_SUCCESS + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByUUID("GPU-" + str(uuid.UUID(bytes=device_uuid.bytes))) + pynvml.nvmlDeviceSetCpuAffinity(handle) + +class PipelineOffloadManager: + OFFLOAD_MGR = None + @classmethod + def get_instance(cls): + if cls.OFFLOAD_MGR is None: + cls.OFFLOAD_MGR = PipelineOffloadManager() + return cls.OFFLOAD_MGR + + def 
__init__(self): + from megatron.core import parallel_state + self._queue = deque() + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is None: + self._vpp = 1 + else: + self._vpp = parallel_state.get_virtual_pipeline_model_parallel_world_size() + + # cache vpp - 1 stages + self._stages = [[] for _ in range(self._vpp)] + # allocate streams and events for synchronization + self._d2h_stream = torch.cuda.Stream() + self._h2d_stream = torch.cuda.Stream() + self.reset() + + @property + def d2h_stream(self): + return self._d2h_stream + + @property + def h2d_stream(self): + return self._h2d_stream + + def reset(self): + set_ideal_affinity_for_current_gpu() + self._inside_context = False + self._cur_forward_chunk = None + self._cur_backward_chunk = None + self._first_last_vpp_rank = True + + def flush(self): + # put into the queue in the backward order + if len(self._stages[0]) == len(self._stages[-1]): + lens = [len(e) for e in self._stages] + assert min(lens) == max(lens) + self._stages[-1] = [] + for chunks in reversed(self._stages): + for chunk in chunks: + self.push(chunk) + for i in range(self._vpp): + self._stages[i] = [] + + def push(self, handler): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("pushing handler") + self._queue.append(handler) + + def pop(self): + assert self.size() + self._cur_backward_chunk = self._queue.popleft() + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("popping handler", self._cur_backward_chunk) + + def front(self): + if not len(self._queue): + return None + f = self._queue.popleft() + self._queue.appendleft(f) + return f + + def size(self): + return len(self._queue) + + def reset_chunk_handler(self, num_layer, vp_stage, offload=True, first_layer_index=0, offloaded_groups_count_per_layer=0): + if vp_stage is None: + cur_vpp_rank = 0 + else: + cur_vpp_rank = vp_stage + + first_last_vpp_rank = self._first_last_vpp_rank + # rewind + if cur_vpp_rank == self._vpp - 1: + self.flush() + 
first_last_vpp_rank = first_last_vpp_rank and (cur_vpp_rank == self._vpp - 1) + cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload, first_layer_index, offloaded_groups_count_per_layer) + # save for latter push + self._stages[cur_vpp_rank].append(cur_chunk) + if cur_vpp_rank == self._vpp - 1: + self._first_last_vpp_rank = False + self.push(cur_chunk) + self.flush() + self._cur_forward_chunk = cur_chunk + cur_chunk.vpp_rank = cur_vpp_rank + + def cur_forward_chunk(self): + return self._cur_forward_chunk + + def cur_backward_chunk(self): + return self._cur_backward_chunk + + def __enter__(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print(f"__enter__") + self.OFFLOAD_MGR + self.inside_context = True + + torch._C._autograd._push_saved_tensors_default_hooks( + self.on_save_for_backward, self.on_get_saved_tensor + ) + + def __exit__(self, *args: Any): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print(f"__exit__") + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("on_save_for_backward", tensor.shape) + assert self.inside_context + if self.cur_forward_chunk().is_registered_tensor(tensor.data_ptr()): + tensor.offloading_activation = True + return self.cur_forward_chunk().tensor_push(tensor) + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("on_get_saved_tensor", saved_state) + return self.cur_backward_chunk().tensor_pop(saved_state) + + +class ChunkOffloadHandler(AsyncDoubleBufferGroupOffloadHandler): + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("offload") + fp8_offload = isinstance(src_tensor, Float8Tensor) + + cpu_backup = torch.empty( + src_tensor.size(), + 
dtype=torch.uint8 if fp8_offload else src_tensor.dtype, + layout=src_tensor.layout, + device="cpu", + pin_memory=pin_memory, + ) + + if fp8_offload: + cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup) + + if not src_tensor.is_contiguous(): + src_tensor = src_tensor.contiguous() + + cpu_backup.copy_(src_tensor, non_blocking=pin_memory) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("reload") + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer_index=0, offloaded_groups_count_per_layer=0): + self._num_layers = num_layer + # Data Structure to maintain reference to activation tensors + self._tensor_tag_to_state = {} + # Tracking the number of layers offloaded + # self._offloaded_group_count = 0 + self._is_first_last_vpp_chunk = is_first_last_vpp_chunk + + self._offloaded_group_index = 0 + self._groups_to_offload = [] + self._groups_to_reload = [] + self.first_layer_index = first_layer_index + self._tensor_count_current_group = 0 + self.multi_input_offload_count = False + self.offloaded_groups_count_per_layer = offloaded_groups_count_per_layer + # self.offload_count_per_layer = defaultdict(int) + + self.torch_tensor_count = 0 + self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream + self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream + self.do_offload = offload + self.is_last_layer = False + + self._offload_tensor_ptrs = deque() + + def is_first_last_layer(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("is_first_last_layer", self._is_first_last_vpp_chunk, self.is_last_layer) + return self._is_first_last_vpp_chunk and self.is_last_layer + + def tensor_push(self, tensor): + if 
torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_push") + torch_stray_tensor = isinstance( + tensor, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + ) + + if not torch_stray_tensor:# True + # obtain a unique tensor tag + tensor_tag = (self._offloaded_group_index, self._tensor_count_current_group) + self._tensor_count_current_group += 1 + assert tensor_tag not in self._tensor_tag_to_state + self._tensor_tag_to_state[tensor_tag] = tensor + else: + tensor_tag = (-1, self.torch_tensor_count) + self.torch_tensor_count += 1 + self._tensor_tag_to_state[tensor_tag] = tensor + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_push", tensor.shape) + print("tensor_tag", tensor_tag) + return tensor_tag + + def tensor_pop(self, tensor_tag): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_pop") + print("tensor_tag", tensor_tag) + assert tensor_tag in self._tensor_tag_to_state, f"{tensor_tag}, {self._tensor_tag_to_state.keys()}" + tensor = self._tensor_tag_to_state.pop(tensor_tag) + assert not isinstance(tensor, tuple) + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_pop", tensor.shape) + # print("tensor", tensor) + return tensor + + def tensor_need_offloading_checker(self, tensor): + if tensor.numel() < MIN_OFFLOADED_TENSOR_SIZE: + return False + if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation: + return False + return True + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("bulk_offload_group") + if not self.do_offload: + return + assert not self.is_first_last_layer() + group_id_to_offload, name = group_to_offload + torch.cuda.nvtx.range_push(name) + with torch.cuda.stream(self.d2h_stream): + for tensor_tag, state in self._tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == 
group_id_to_offload: + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_tag", tensor_tag) + print("group_to_offload", group_to_offload) + assert not isinstance(state, tuple) + tensor_on_device = state + # if offload, return the reference to cpu copy + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("tensor_need_offloading_checker", self.tensor_need_offloading_checker(tensor_on_device)) + print("tensor_on_device", tensor_on_device.shape) + if self.tensor_need_offloading_checker(tensor_on_device): + state = self.offload(tensor_on_device) + tensor_on_device.record_stream(self.d2h_stream) + # self.offload_count_per_layer[group_to_offload] += 1 + self._tensor_tag_to_state[tensor_tag] = state + # self._offloaded_group_count = group_to_offload + 1 + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("exit bulk_offload_group") + torch.cuda.nvtx.range_pop() + + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("bulk_reload_group") + if not self.do_offload: + return + found_reload_group = False + group_id_to_reload, name = group_to_reload + torch.cuda.nvtx.range_push(name) + with torch.cuda.stream(self.h2d_stream): + # move back tensors + for tensor_label, state in self._tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_id_to_reload: + found_reload_group = True + if isinstance(state, tuple): + recovered_tensor = self.reload(state) + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("recovered_tensor", recovered_tensor.shape) + self._tensor_tag_to_state[tensor_label] = recovered_tensor + # self.offload_count_per_layer[group_to_reload] -= 1 + # if self.offload_count_per_layer[group_to_reload] > 0 and self.multi_input_offload_count: + # break + # if self.offload_count_per_layer[group_to_reload] == 0: + # self._offloaded_group_count = group_to_reload + torch.cuda.nvtx.range_pop() + return 
found_reload_group + + def pre_reload_last_layer(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("pre_reload_last_layer") + if not self.do_offload: + return + assert not self._is_first_last_vpp_chunk + # TODO: check if this is correct + if len(self._groups_to_reload) > 0: + if self.bulk_reload_group(self._groups_to_reload[-1]): + self._groups_to_reload.pop() + # if self._num_layers == self._offloaded_group_count: + # self.bulk_reload_group(self._num_layers - 1) + # assert self._num_layers - 1 == self._offloaded_group_count + + def should_bulk_offload(self): + if not self.do_offload: + return False + # first backward chunk + if self.is_first_last_layer(): + return False + + # if next backward chunk is this chunk (for last pp stage) + next_backward_chunk = PipelineOffloadManager.get_instance().get_instance().front() + if next_backward_chunk is not None and next_backward_chunk is self: + if self.is_last_layer: + return False + + return True + + def forward_sync(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("forward_sync") + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + + def bulk_offload(self, release_tensors): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("bulk_offload") + if self.should_bulk_offload(): + group_to_offload = self._groups_to_offload.pop() + self._groups_to_reload.append(group_to_offload) + self.bulk_offload_group(group_to_offload) + if len(release_tensors) > 0: + cur_stream = torch.cuda.current_stream() + for release_tensor in release_tensors: + release_tensor.record_stream(cur_stream) + release_tensor.untyped_storage().resize_(0) + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("exit bulk_offload") + + def on_group_commit_forward(self, release_tensors): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("on_group_commit_forward") + # wait each other + self.forward_sync() + self.bulk_offload(release_tensors) + if 
torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("exit on_group_commit_forward") + + def bulk_reload(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("bulk_reload") + # if self.do_offload: + # assert self._layer_index == self._offloaded_group_count, f"{self._layer_index}, {self._offloaded_group_count}" + if len(self._groups_to_reload) > 0: + # load next layer + if self.bulk_reload_group(self._groups_to_reload[-1]): + self._groups_to_reload.pop() + else: + next_backward_chunk = PipelineOffloadManager.get_instance().front() + if next_backward_chunk is not None: + next_backward_chunk.pre_reload_last_layer() + + def backward_sync(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("backward_sync") + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + # computation kernels wait until the offloaded groups of one layer are fully reloaded. + if self._offloaded_group_index % self.offloaded_groups_count_per_layer == 0: + torch.cuda.current_stream().wait_stream(self.h2d_stream) + + def on_group_commit_backward(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("on_group_commit_backward") + cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() + if not cur_backward_chunk is self: + PipelineOffloadManager.get_instance().pop() + cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() + assert cur_backward_chunk is self + self.backward_sync() + self._offloaded_group_index = self._offloaded_group_index - 1 + + def on_group_start_forward(self, name): + if self._offloaded_group_index % self.offloaded_groups_count_per_layer == 0: + torch.cuda.current_stream().wait_stream(self.d2h_stream) + if self._offloaded_group_index // self.offloaded_groups_count_per_layer == self._num_layers - 1: + self.is_last_layer = True + else: + self.is_last_layer = False + self._offloaded_group_index = self._offloaded_group_index + 1 + self._tensor_count_current_group 
= 0 + self._groups_to_offload.append((self._offloaded_group_index, name)) + # wait for the offloaded groups of one layer are fully offloaded. + # This is not necessary but good to have. + + def on_group_start_backward(self): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("on_group_start_backward") + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + self.bulk_reload() + + def register_offload_tensor(self, tensors): + self.multi_input_offload_count = True + if isinstance(tensors, list): + for tensor in tensors: + self._offload_tensor_ptrs.append(tensor.data_ptr()) + else: + self._offload_tensor_ptrs.append(tensors.data_ptr()) + + def is_registered_tensor(self, tensor_ptr: int) -> bool: + if len(self._offload_tensor_ptrs) == 0: + return False + is_registered = tensor_ptr == self._offload_tensor_ptrs[0] + if is_registered: + self._offload_tensor_ptrs.popleft() + return is_registered + + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. 
+ """ + + @staticmethod + def forward(ctx, *args): + # pylint: disable=missing-function-docstring + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("GroupCommitFunction forward") + + release_tensors = args[-1] + cpu_offload_handler = args[-2] + tensor = args[:-2] + cpu_offload_handler.on_group_commit_forward(release_tensors) + ctx.cpu_offload_handler = cpu_offload_handler + + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, *grad_output): + # pylint: disable=missing-function-docstring + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("GroupCommitFunction backward") + + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output + (None, None) + + +def group_prefetch_offload_commit(*tensor, release_tensors=[]): + cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() + return GroupCommitFunction.apply(*tensor, cur_forward_chunk, release_tensors) + + +class GroupStartFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. 
+ """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler, name): + # pylint: disable=missing-function-docstring + ctx.cpu_offload_handler = cpu_offload_handler + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("GroupStartFunction forward") + + cpu_offload_handler.on_group_start_forward("activation offloading " + name) + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print("GroupStartFunction backward") + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_start_backward() + return grad_output, None, None + + +def group_prefetch_offload_start(tensor, name=None): + cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() + return GroupStartFunction.apply(tensor, cur_forward_chunk, name) From 2b574c2abfabb89b1ebe6d424ce6bdb5981867d5 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 18 Sep 2025 06:16:02 -0700 Subject: [PATCH 14/35] minor fix Signed-off-by: Hongbin Liu --- megatron/core/model_parallel_config.py | 1 + megatron/core/transformer/attention.py | 98 +++++-------------- megatron/core/transformer/cpu_offload.py | 2 +- .../core/transformer/transformer_config.py | 19 +--- 4 files changed, 32 insertions(+), 88 deletions(-) diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 2ccf42743fd..7269c6a64df 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -5,6 +5,7 @@ import torch +import warnings @dataclass class ModelParallelConfig: diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 70e1e4dfa1a..6d79d62ac8e 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -268,56 +268,6 @@ def custom_forward(*inputs): return hidden_states - def 
_offload_core_attention_forward( - self, - query, - key, - value, - attention_mask, - rotary_pos_emb=None, - attn_mask_type=None, - attention_bias=None, - packed_seq_params=None, - ): - """Forward method with attention activation offloading.""" - - def custom_forward(*inputs): - query = inputs[0] - key = inputs[1] - value = inputs[2] - attention_mask = inputs[3] - attn_mask_type = inputs[5] - attn_mask_type = AttnMaskType(attn_mask_type.item()) - output_ = self.core_attention( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) - return output_ - - if attn_mask_type is None: - attn_mask_type = self.attn_mask_type - attn_mask_type = torch.tensor([attn_mask_type.value], dtype=torch.int) - value = value.contiguous() - - query = group_prefetch_offload_start(query) - - # handler = PipelineOffloadManager.get_instance().cur_forward_chunk() - # handler.register_offload_tensor([query, key, value]) - # query.offloading_activation = True - # key.offloading_activation = True - # value.offloading_activation = True - with PipelineOffloadManager.get_instance(): - hidden_states = custom_forward( - query, key, value, attention_mask, rotary_pos_emb, attn_mask_type - ) - hidden_states = group_prefetch_offload_commit(hidden_states, release_tensors=[query, key, value]) - return hidden_states[0] - def _allocate_memory(self, inference_max_sequence_length, batch_size, dim, dtype): """Allocate memory to store kv cache during inference.""" @@ -878,30 +828,23 @@ def forward( attention_bias=attention_bias, packed_seq_params=packed_seq_params, ) - elif self.offload_core_attention and self.training: - - - core_attn_out = self._offload_core_attention_forward( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) else: + offload_context = contextlib.nullcontext() + if self.offload_core_attention and 
self.training: + query = group_prefetch_offload_start(query, name="core_attn") + offload_context = PipelineOffloadManager.get_instance() if inference_context is None or inference_context.is_static_batching(): # Static batching attention kernel. - core_attn_out = self.core_attention( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) + with offload_context: + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) else: # Dynamic batching attention kernel. @@ -921,6 +864,9 @@ def forward( block_table, ) core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') + if self.offload_core_attention and self.training: + core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[query, key, value]) + offload_context = contextlib.nullcontext() if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': # reshape to same output shape as unpacked case @@ -935,7 +881,15 @@ def forward( # ================= nvtx_range_push(suffix="linear_proj") - output, bias = self.linear_proj(core_attn_out) + offload_context = contextlib.nullcontext() + if self.offload_attn_proj: + core_attn_out = group_prefetch_offload_start(core_attn_out, name="attn_proj") + offload_context = PipelineOffloadManager.get_instance() + with offload_context: + output, bias = self.linear_proj(core_attn_out) + if self.offload_attn_proj: + output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out]) + offload_context = contextlib.nullcontext() nvtx_range_pop(suffix="linear_proj") return output, bias diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py index adfe47e09e1..2614fddef88 100644 --- a/megatron/core/transformer/cpu_offload.py +++ 
b/megatron/core/transformer/cpu_offload.py @@ -5,7 +5,7 @@ from transformer_engine.pytorch.cpu_offload import AsyncDoubleBufferGroupOffloadHandler # cpu offload for pipeline -DEBUG = True +DEBUG = False DEBUG_RANK = 5 MIN_OFFLOADED_TENSOR_SIZE = 1024 * 1024 diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index c7c5a175a6d..2f26468a259 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -959,22 +959,11 @@ def __post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) + if "router_fc1" in self.offload_modules and self.tensor_model_parallel_size > 1: + raise ValueError( + "(Bug) router_fc1 cannot be set to offload_modules when tensor_model_parallel_size > 1." + ) - # if "self_attn" in self.offload_modules: - # if "qkv_linear" in self.offload_modules: - # self.offload_modules.remove("qkv_linear") - # if "core_attn" in self.offload_modules: - # self.offload_modules.remove("core_attn") - # if "attn_linear" in self.offload_modules: - # self.offload_modules.remove("attn_linear") - - # if "core_attn" in self.offload_modules: - # warnings.warn( - # "If you are using transformer_engine as the transformer implementation, " - # "the core_attn is from transformer_engine and may be the fused version. " - # "For fused attention, you have no need to set 'core_attn' to offload. " - # "Please check that the core_attn offload is really needed." 
- # ) if ( self.num_layers_in_first_pipeline_stage is not None From 1f03ceb77c3af8cbce3cb778aafd6d10f66e65d1 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 18 Sep 2025 06:24:02 -0700 Subject: [PATCH 15/35] code clean Signed-off-by: Hongbin Liu --- megatron/core/transformer/moe/experts.py | 13 ----- megatron/core/transformer/moe/moe_layer.py | 4 -- .../core/transformer/moe/token_dispatcher.py | 47 ------------------- 3 files changed, 64 deletions(-) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 639d46b62e6..97b3bf528fa 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -112,7 +112,6 @@ def __init__( self, num_local_experts: int, config: TransformerConfig, - layer_number: Optional[int] = None, model_comm_pgs: Optional[ModelCommProcessGroups] = None, ): super().__init__(config=config) @@ -760,13 +759,11 @@ def __init__( num_local_experts, config: TransformerConfig, submodules: MLPSubmodules, - layer_number: Optional[int] = None, model_comm_pgs: Optional[ModelCommProcessGroups] = None, ): super().__init__(config=config) self.num_local_experts = num_local_experts self.input_size = self.config.hidden_size - self.layer_number = layer_number assert ( config.add_bias_linear == False ), "bias not supported in TEGroupedMLP yet, please set '--disable-bias-linear' instead." 
@@ -823,11 +820,6 @@ def __init__( and "moe_act" in self.config.offload_modules ) - self.offload_router_fc2 = ( - self.config.offload_activation - and "router_fc2" in self.config.offload_modules - ) - self.activation_recompute = ( self.config.recompute_granularity == 'selective' and "moe_act" in self.config.recompute_modules @@ -842,10 +834,6 @@ def __init__( self.fp8_padding = Fp8Padding(self.num_local_experts) self.fp8_unpadding = Fp8Unpadding(self.num_local_experts) - def set_layer_number(self, layer_number: int): - """Set the layer number for the TEGroupedMLP.""" - self.layer_number = layer_number - def forward( self, permuted_local_hidden_states: torch.Tensor, @@ -1042,7 +1030,6 @@ def __init__( num_local_experts, config: TransformerConfig, submodules: MLPSubmodules, - layer_number: Optional[int] = None, model_comm_pgs: Optional[ModelCommProcessGroups] = None, ): diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 772900e17fd..e7dd9d4e56c 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -87,10 +87,6 @@ def set_layer_number(self, layer_number: int): """Set the layer number for the MoE layer.""" self.layer_number = layer_number self.router.set_layer_number(layer_number) - if hasattr(self.experts, 'set_layer_number'): - self.experts.set_layer_number(layer_number) - if hasattr(self.shared_experts, 'set_layer_number'): - self.shared_experts.set_layer_number(layer_number) class MoELayer(BaseMoELayer): diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 631e69cf1c9..05b73a9ee49 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -932,7 +932,6 @@ def __init__( router_topk: int, num_experts: int, config: TransformerConfig, - offload_activation: bool = False, ): """ Initialize the DeepEP dispatcher. 
@@ -954,7 +953,6 @@ def __init__( self.router_dtype = config.moe_router_dtype self.capacity_factor = config.moe_expert_capacity_factor self.permute_fusion = config.moe_permute_fusion - self.offload_activation = offload_activation # Metadata self.token_indices: Optional[torch.Tensor] = None @@ -1022,9 +1020,6 @@ def _indices_to_multihot(self, indices, probs): A tuple of (routing_map, probs), where routing_map is the multihot vector and probs is the multihot probabilities. """ - if self.offload_activation: - routing_map_vectorized, probs_map_vectorized = self._indices_to_multihot_vectorized(indices, probs) - return routing_map_vectorized, probs_map_vectorized batch_size = indices.shape[0] multihot_routing_map = torch.zeros( (batch_size, self.num_local_experts), dtype=torch.long, device=indices.device @@ -1043,47 +1038,6 @@ def _indices_to_multihot(self, indices, probs): multihot_probs[row_indices, valid_indices] = probs[mask] return multihot_routing_map.bool(), multihot_probs - def _indices_to_multihot_vectorized(self, indices, probs): - """ - Converts a tensor of indices to a multihot vector efficiently in PyTorch when enabling - offload_activation. - - Args: - indices (torch.Tensor): [num_tokens, topk] token indices, where -1 means masked out. - The max value of indices is local_num_experts - 1. - probs (torch.Tensor): [num_tokens, topk] token probabilities. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: - - routing_map: Multihot vector. - - probs: Multihot probabilities. 
- """ - batch_size, topk = indices.shape - - # Create mask for valid indices - mask = indices != -1 - - # Replace -1 with a valid index (will be masked out anyway) - safe_indices = torch.where(mask, indices, 0) - - # Create one-hot encoding for all positions - # Shape: [batch_size, topk, num_local_experts] - one_hot = torch.nn.functional.one_hot(safe_indices, num_classes=self.num_local_experts).float() - - # Apply mask to zero out invalid positions - # Expand mask to match one_hot dimensions - mask_expanded = mask.unsqueeze(-1).float() - one_hot = one_hot * mask_expanded - - # Sum along topk dimension to get multihot representation - multihot_routing_map = (one_hot.sum(dim=1) > 0).bool() - - # For probabilities, multiply by probs and sum - probs_expanded = probs.unsqueeze(-1) - multihot_probs = (one_hot * probs_expanded).sum(dim=1) - - return multihot_routing_map, multihot_probs - def get_dispached_metadata(self) -> torch.Tensor: return self.dispatched_indices, self.dispatched_probs @@ -1219,7 +1173,6 @@ def __init__( router_topk=self.tp_size * self.config.moe_router_topk, num_experts=self.tp_size * self.config.num_moe_experts, config=self.config, - offload_activation=self.config.offload_activation, ) def set_shared_experts(self, shared_experts): From 2ff9f6edd64d405aedf39239651dc9d9871903dd Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 18 Sep 2025 07:06:51 -0700 Subject: [PATCH 16/35] add interfaces to TE modules Signed-off-by: Hongbin Liu --- .../core/extensions/transformer_engine.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 99c3edc05ab..1cf1287191c 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -295,6 +295,14 @@ def __init__( extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute else: raise RuntimeError("Only TE with version >=2.3.0 supports 
delay_wgrad_compute now.") + if self.config.offload_activation: + te_version = get_te_version() + if te_version == PkgVersion("2.8.0.dev0+74a5f77b"): + extra_kwargs["offload_activation"] = self.config.offload_activation + else: + raise ValueError( + f"Transformer Engine v{te_version} does not support offload_activation." + ) if ( self.config.tp_comm_overlap and tp_comm_buffer_name @@ -506,6 +514,15 @@ def __init__( else: raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.") + if self.config.offload_activation: + te_version = get_te_version() + if te_version == PkgVersion("2.8.0.dev0+74a5f77b"): + extra_kwargs["offload_activation"] = self.config.offload_activation + else: + raise ValueError( + f"Transformer Engine v{te_version} does not support offload_activation." + ) + # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` if is_te_min_version("0.11.0"): extra_kwargs["normalization"] = self.config.normalization @@ -1095,6 +1112,14 @@ def __init__( raise RuntimeError( "Only TE with version >=2.3.0 supports delay_wgrad_compute now." ) + if self.config.offload_activation: + te_version = get_te_version() + if te_version == PkgVersion("2.8.0.dev0+74a5f77b"): + extra_kwargs["offload_activation"] = self.config.offload_activation + else: + raise ValueError( + f"Transformer Engine v{te_version} does not support offload_activation." 
+ ) extra_kwargs["ub_name"] = tp_comm_buffer_name From a293701bcc04757295d7335efb12890690a7076f Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 18 Sep 2025 22:40:33 -0700 Subject: [PATCH 17/35] renaming Signed-off-by: Hongbin Liu --- megatron/core/extensions/transformer_engine.py | 18 +++++++++--------- megatron/core/model_parallel_config.py | 17 ++++++++--------- megatron/core/models/gpt/gpt_model.py | 2 +- megatron/core/transformer/attention.py | 6 +++--- megatron/core/transformer/moe/experts.py | 14 +++++++------- .../core/transformer/transformer_config.py | 6 +++--- megatron/core/transformer/transformer_layer.py | 6 +++--- megatron/training/arguments.py | 8 ++++---- 8 files changed, 38 insertions(+), 39 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 1cf1287191c..5fcf94b4b68 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -295,13 +295,13 @@ def __init__( extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute else: raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.") - if self.config.offload_activation: + if self.config.fine_grained_activation_offloading: te_version = get_te_version() if te_version == PkgVersion("2.8.0.dev0+74a5f77b"): - extra_kwargs["offload_activation"] = self.config.offload_activation + extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading else: raise ValueError( - f"Transformer Engine v{te_version} does not support offload_activation." + f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." 
) if ( self.config.tp_comm_overlap @@ -514,13 +514,13 @@ def __init__( else: raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.") - if self.config.offload_activation: + if self.config.fine_grained_activation_offloading: te_version = get_te_version() if te_version == PkgVersion("2.8.0.dev0+74a5f77b"): - extra_kwargs["offload_activation"] = self.config.offload_activation + extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading else: raise ValueError( - f"Transformer Engine v{te_version} does not support offload_activation." + f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." ) # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` @@ -1112,13 +1112,13 @@ def __init__( raise RuntimeError( "Only TE with version >=2.3.0 supports delay_wgrad_compute now." ) - if self.config.offload_activation: + if self.config.fine_grained_activation_offloading: te_version = get_te_version() if te_version == PkgVersion("2.8.0.dev0+74a5f77b"): - extra_kwargs["offload_activation"] = self.config.offload_activation + extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading else: raise ValueError( - f"Transformer Engine v{te_version} does not support offload_activation." + f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." ) extra_kwargs["ub_name"] = tp_comm_buffer_name diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 7269c6a64df..551f2268486 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -316,20 +316,19 @@ class ModelParallelConfig: rank 1 | 0 1 2 0 1 2 3 4 3 4 """ - offload_activation: bool = False + fine_grained_activation_offloading: bool = False """If True, offload the activation to the CPU.""" offload_modules: Optional[list[str]] = None """The submodules to offload. 
- choices: "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "moe_act", "router_fc2". + choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act". default: ["core_attn"]. - "self_attn": offload the self_attn part of the transformer layer. - "qkv_linear": offload the qkv_linear part of the transformer layer. - "core_attn": offload the core attention part of the transformer layer. - "attn_linear": offload the attn linear projection part of the transformer layer. - "router_fc1": offload the moe router_fc1 part of the transformer layer. - "moe_act": offload the moe act part of the transformer layer. - "router_fc2": offload the moe router_fc2 part of the transformer layer. + "attn_norm": offload the input of the normalization in the attention part. + "core_attn": offload the input of the core attention part. + "mlp_norm": offload the input of the normalization in the mlp part. + "attn_proj": offload the input of the attn linear projection part. + "expert_fc1": offload the input of the expert fc1 part. + "moe_act": offload the input of the moe act part. """ offload_module_count_per_layer: Optional[int] = 0 """The number of modules to offload per layer. 
default: 0.""" diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 2aab9912a63..3912bccd793 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -370,7 +370,7 @@ def forward( PipelineOffloadManager.get_instance().reset_chunk_handler( self.decoder.num_layers_per_pipeline_rank, self.vp_stage, - self.config.offload_activation, + self.config.fine_grained_activation_offloading, 0, self.config.offload_module_count_per_layer, ) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 6d79d62ac8e..cd849816aaa 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -186,17 +186,17 @@ def __init__( ) self.offload_qkv_linear = ( - self.config.offload_activation + self.config.fine_grained_activation_offloading and "qkv_linear" in self.config.offload_modules ) self.offload_core_attention = ( - self.config.offload_activation + self.config.fine_grained_activation_offloading and "core_attn" in self.config.offload_modules ) self.offload_attn_proj = ( - self.config.offload_activation + self.config.fine_grained_activation_offloading and "attn_proj" in self.config.offload_modules ) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 97b3bf528fa..116f5861ae0 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -810,13 +810,13 @@ def __init__( tp_group=parallel_state.get_expert_tensor_parallel_group(), ) - self.offload_router_fc1 = ( - self.config.offload_activation - and "router_fc1" in self.config.offload_modules + self.offload_expert_fc1 = ( + self.config.fine_grained_activation_offloading + and "expert_fc1" in self.config.offload_modules ) self.offload_moe_act = ( - self.config.offload_activation + self.config.fine_grained_activation_offloading and "moe_act" in self.config.offload_modules ) @@ -874,14 +874,14 
@@ def forward( permuted_probs = torch.ones_like(permuted_probs) offload_context = contextlib.nullcontext() - if self.offload_router_fc1: - permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states, name="router_fc1") + if self.offload_expert_fc1: + permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states, name="expert_fc1") offload_context = PipelineOffloadManager.get_instance() with offload_context: fc1_output, bias_parallel = self.linear_fc1( permuted_local_hidden_states, tokens_per_expert ) - if self.offload_router_fc1: + if self.offload_expert_fc1: fc1_output, bias_parallel = group_prefetch_offload_commit(fc1_output, bias_parallel, release_tensors=[permuted_local_hidden_states]) offload_context = contextlib.nullcontext() diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 2f26468a259..ba15a936f69 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -945,7 +945,7 @@ def __post_init__(self): if len(self.offload_modules) > 0: self.offload_modules = list(set(self.offload_modules)) allowed_modules = { - "core_attn", "attn_proj", "router_fc1", "moe_act", "attn_norm", "mlp_norm" + "core_attn", "attn_proj", "expert_fc1", "moe_act", "attn_norm", "mlp_norm" } invalid_modules = set(self.offload_modules) - allowed_modules assert not invalid_modules, ( @@ -959,9 +959,9 @@ def __post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) - if "router_fc1" in self.offload_modules and self.tensor_model_parallel_size > 1: + if "expert_fc1" in self.offload_modules and self.tensor_model_parallel_size > 1: raise ValueError( - "(Bug) router_fc1 cannot be set to offload_modules when tensor_model_parallel_size > 1." + "(Bug) expert_fc1 cannot be set to offload_modules when tensor_model_parallel_size > 1." 
) diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 3384fdd9cf6..c23e2425d46 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -432,15 +432,15 @@ def __init__( if not isinstance(self.mlp, MoELayer): self.recompute_mlp = True self.offload_self_attn = ( - self.config.offload_activation + self.config.fine_grained_activation_offloading and "self_attn" in self.config.offload_modules ) self.offload_attn_norm = ( - self.config.offload_activation + self.config.fine_grained_activation_offloading and "attn_norm" in self.config.offload_modules ) self.offload_mlp_norm = ( - self.config.offload_activation + self.config.fine_grained_activation_offloading and "mlp_norm" in self.config.offload_modules ) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index bde739c1d12..e6f462e442d 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1113,8 +1113,8 @@ def validate_args(args, defaults={}): "when enabling delay_wgrad_compute" ) - if args.offload_activation: - assert not args.overlap_grad_reduce, "overlap_grad_reduce is not supported with offload_activation" + if args.fine_grained_activation_offloading: + assert not args.overlap_grad_reduce, "overlap_grad_reduce is not supported with fine_grained_activation_offloading" if args.mtp_num_layers: assert not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP)." 
@@ -2116,10 +2116,10 @@ def _add_training_args(parser): help='The communicator group names to use high priority streams.') group.add_argument('--use-te-activation-func', action='store_true', help='Use activation function kernel from Transformer Engine in MLP module.') - group.add_argument('--offload-activation', action='store_true', + group.add_argument('--fine-grained-activation-offloading', action='store_true', help='Offload the activation to the CPU.') group.add_argument('--offload-modules', nargs='*', type=str, default=[], - help='The submodules to offload. Choices: "self_attn", "qkv_linear", "core_attn", "attn_linear", "router_fc1", "router_fc2", "shared_fc1", "shared_fc2".') + help='The submodules to offload. Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".') return parser From aa628c02e1e3cfdf11b821128d8dd8a46534c601 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 18 Sep 2025 22:47:15 -0700 Subject: [PATCH 18/35] minor fix Signed-off-by: Hongbin Liu --- megatron/core/model_parallel_config.py | 17 ---------------- .../core/transformer/transformer_config.py | 20 +++++++++++++++++++ 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 551f2268486..d0933f863ca 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -315,23 +315,6 @@ class ModelParallelConfig: rank 0 | 0 1 2 0 1 2 3 4 3 4 rank 1 | 0 1 2 0 1 2 3 4 3 4 """ - - fine_grained_activation_offloading: bool = False - """If True, offload the activation to the CPU.""" - - offload_modules: Optional[list[str]] = None - """The submodules to offload. - choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act". - default: ["core_attn"]. - "attn_norm": offload the input of the normalization in the attention part. - "core_attn": offload the input of the core attention part. 
- "mlp_norm": offload the input of the normalization in the mlp part. - "attn_proj": offload the input of the attn linear projection part. - "expert_fc1": offload the input of the expert fc1 part. - "moe_act": offload the input of the moe act part. - """ - offload_module_count_per_layer: Optional[int] = 0 - """The number of modules to offload per layer. default: 0.""" ################### # CPU Offloading diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index ba15a936f69..29aa52f8468 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -666,6 +666,26 @@ class TransformerConfig(ModelParallelConfig): """Transformer implementation to use. Options are 'transformer_engine' for Transformer Engine and 'local' for MCore.""" + ##################################### + # Fine-grained Activation Offloading + ##################################### + fine_grained_activation_offloading: bool = False + """If True, offload the activation to the CPU.""" + + offload_modules: Optional[list[str]] = None + """The submodules to offload. + choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act". + default: ["core_attn"]. + "attn_norm": offload the input of the normalization in the attention part. + "core_attn": offload the input of the core attention part. + "mlp_norm": offload the input of the normalization in the mlp part. + "attn_proj": offload the input of the attn linear projection part. + "expert_fc1": offload the input of the expert fc1 part. + "moe_act": offload the input of the moe act part. + """ + offload_module_count_per_layer: Optional[int] = 0 + """The number of modules to offload per layer. default: 0.""" + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. 
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more From ecfbc87e01251f9009c8ed59bf5339453c5e8e8b Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 19 Sep 2025 00:16:49 -0700 Subject: [PATCH 19/35] add README Signed-off-by: Hongbin Liu --- megatron/core/transformer/README.md | 68 +++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) create mode 100644 megatron/core/transformer/README.md diff --git a/megatron/core/transformer/README.md b/megatron/core/transformer/README.md new file mode 100644 index 00000000000..44d0cf3388c --- /dev/null +++ b/megatron/core/transformer/README.md @@ -0,0 +1,68 @@ +
+ +Fine-grained Activation Offloading +============= + +
+ +## Quick Start + +```bash +# Enable fine-grained activation offloading +--fine-grained-activation-offloading + +# Specify which modules are going to be offloaded +# Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act". +--offload-modules core_attn +``` + +# Current statue +## Features +* Support PP=1/PP/Interleaved PP +* Compatible with fine-grained recomputation + +## Known issues +* `--offload-modules expert_fc1` doesn't work with TP>1 + +## WIP items +* FP8 support +* Code refactor +* Benchmark + +# Methodology + +## Offload/Reload the input of one module to/from CPU +Let's take the attention projection module as an example: +``` +nvtx_range_push(suffix="linear_proj") +offload_context = contextlib.nullcontext() +if self.offload_attn_proj: + core_attn_out = group_prefetch_offload_start(core_attn_out, name="attn_proj") + offload_context = PipelineOffloadManager.get_instance() +with offload_context: + output, bias = self.linear_proj(core_attn_out) +if self.offload_attn_proj: + output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out]) + offload_context = contextlib.nullcontext() +nvtx_range_pop(suffix="linear_proj") +``` +The above code snippet could be divided into three parts in order: +1. Mark the starting point of offloading a new module; +2. Record the save_for_backward tensors in fprop and push it to a tensor buffer; +3. Offload the recorded tensors after the module's fprop finishes; + +In bprop, the three parts above will: +1. Make sure the offloaded tensors are reloaded back to GPU; +2. Pop the corresponding tensors from the tensor buffer; +3. Reload the corresponding tensors of next module; + +## Compatible with PP&Interleaved PP + +`PipelineOffloadManager` is used to manage the chunks across different model chunks in fprop and bprop. +Before the model.forward() start, the `PipelineOffloadManager.get_instance().reset_chunk_handler` will be executed. 
In the fprop of this methond, we create a `ChunkOffloadHandler` to handle the offloading context of one model chunk and then push it to a buffer, which will be poped out in a specific order in bprop. + +## Compatible with fine-grained recomputation + +## A special case: attn_norm/mlp_norm + +# Performance (WIP) \ No newline at end of file From 0f99ca606fa4cf6ae9f21536a59a41122df38511 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 19 Sep 2025 15:25:22 +0800 Subject: [PATCH 20/35] Update README.md --- megatron/core/transformer/README.md | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/megatron/core/transformer/README.md b/megatron/core/transformer/README.md index 44d0cf3388c..3b767679754 100644 --- a/megatron/core/transformer/README.md +++ b/megatron/core/transformer/README.md @@ -2,10 +2,10 @@ Fine-grained Activation Offloading ============= - +

NVIDIA, XiaoHongShu

-## Quick Start +# Quick Start ```bash # Enable fine-grained activation offloading @@ -16,7 +16,7 @@ Fine-grained Activation Offloading --offload-modules core_attn ``` -# Current statue +# Current status ## Features * Support PP=1/PP/Interleaved PP * Compatible with fine-grained recomputation @@ -59,10 +59,20 @@ In bprop, the three parts above will: ## Compatible with PP&Interleaved PP `PipelineOffloadManager` is used to manage the chunks across different model chunks in fprop and bprop. -Before the model.forward() start, the `PipelineOffloadManager.get_instance().reset_chunk_handler` will be executed. In the fprop of this methond, we create a `ChunkOffloadHandler` to handle the offloading context of one model chunk and then push it to a buffer, which will be poped out in a specific order in bprop. +Before the model.forward() start, the `PipelineOffloadManager.get_instance().reset_chunk_handler` will be executed. In the fprop of this method, we create a `ChunkOffloadHandler` to handle the offloading context of one model chunk and then push it to a buffer, which will be popped out in a specific order in bprop. 
+ +image + ## Compatible with fine-grained recomputation +offload_and_recompute + + ## A special case: attn_norm/mlp_norm -# Performance (WIP) \ No newline at end of file +# Performance (WIP) + +# Acknowledgement + +This work refers to the previous work from Kuaishou: https://www.usenix.org/conference/atc24/presentation/yuan From e780b94d5ea309271f8e59269178345836bf0f95 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 19 Sep 2025 05:41:48 -0700 Subject: [PATCH 21/35] remove forward sync per layer Signed-off-by: Hongbin Liu --- megatron/core/transformer/README.md | 2 +- megatron/core/transformer/cpu_offload.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/megatron/core/transformer/README.md b/megatron/core/transformer/README.md index 3b767679754..50b0c79395d 100644 --- a/megatron/core/transformer/README.md +++ b/megatron/core/transformer/README.md @@ -2,7 +2,7 @@ Fine-grained Activation Offloading ============= -

NVIDIA, XiaoHongShu

+

NVIDIA, rednote

# Quick Start diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py index 2614fddef88..c3e3d079bcd 100644 --- a/megatron/core/transformer/cpu_offload.py +++ b/megatron/core/transformer/cpu_offload.py @@ -413,8 +413,10 @@ def on_group_commit_backward(self): self._offloaded_group_index = self._offloaded_group_index - 1 def on_group_start_forward(self, name): - if self._offloaded_group_index % self.offloaded_groups_count_per_layer == 0: - torch.cuda.current_stream().wait_stream(self.d2h_stream) + # # wait for the offloaded groups of one layer are fully offloaded. + # # This is not necessary but good to have. + # if self._offloaded_group_index % self.offloaded_groups_count_per_layer == 0: + # torch.cuda.current_stream().wait_stream(self.d2h_stream) if self._offloaded_group_index // self.offloaded_groups_count_per_layer == self._num_layers - 1: self.is_last_layer = True else: @@ -422,8 +424,6 @@ def on_group_start_forward(self, name): self._offloaded_group_index = self._offloaded_group_index + 1 self._tensor_count_current_group = 0 self._groups_to_offload.append((self._offloaded_group_index, name)) - # wait for the offloaded groups of one layer are fully offloaded. - # This is not necessary but good to have. 
def on_group_start_backward(self): if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: From 20c40298725bf12d30d34af0d8738648f35b865b Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 22 Sep 2025 04:26:50 -0700 Subject: [PATCH 22/35] support FP8&MTP Signed-off-by: Hongbin Liu --- megatron/core/models/gpt/gpt_model.py | 5 +- megatron/core/transformer/README.md | 5 +- megatron/core/transformer/cpu_offload.py | 162 ++++++++++-------- .../core/transformer/transformer_layer.py | 2 + 4 files changed, 95 insertions(+), 79 deletions(-) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 3912bccd793..1f6ab958f19 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -367,8 +367,11 @@ def forward( runtime_gather_output (bool): Gather output at runtime. Default None means `parallel_output` arg in the constructor will be used. """ + num_layers = self.decoder.num_layers_per_pipeline_rank + if self.mtp_process: + num_layers = num_layers + self.config.mtp_num_layers PipelineOffloadManager.get_instance().reset_chunk_handler( - self.decoder.num_layers_per_pipeline_rank, + num_layers, self.vp_stage, self.config.fine_grained_activation_offloading, 0, diff --git a/megatron/core/transformer/README.md b/megatron/core/transformer/README.md index 50b0c79395d..3c16dfe4843 100644 --- a/megatron/core/transformer/README.md +++ b/megatron/core/transformer/README.md @@ -20,13 +20,14 @@ Fine-grained Activation Offloading ## Features * Support PP=1/PP/Interleaved PP * Compatible with fine-grained recomputation - +* Support FP8 +* Support MTP ## Known issues * `--offload-modules expert_fc1` doesn't work with TP>1 ## WIP items -* FP8 support * Code refactor +* Support MTP * Benchmark # Methodology diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py index c3e3d079bcd..0ba6344001c 100644 --- a/megatron/core/transformer/cpu_offload.py +++ 
b/megatron/core/transformer/cpu_offload.py @@ -6,9 +6,14 @@ # cpu offload for pipeline DEBUG = False -DEBUG_RANK = 5 +DEBUG_RANK = 0 MIN_OFFLOADED_TENSOR_SIZE = 1024 * 1024 +def print_rank(message): + assert torch.distributed.is_initialized() + if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + print(message, flush=True) + def set_ideal_affinity_for_current_gpu(): import cuda.cuda import cuda.cudart @@ -73,15 +78,13 @@ def flush(self): self._stages[i] = [] def push(self, handler): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("pushing handler") + print_rank("pushing handler") self._queue.append(handler) def pop(self): assert self.size() self._cur_backward_chunk = self._queue.popleft() - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("popping handler", self._cur_backward_chunk) + print_rank(f"popping handler {self._cur_backward_chunk}") def front(self): if not len(self._queue): @@ -121,8 +124,7 @@ def cur_backward_chunk(self): return self._cur_backward_chunk def __enter__(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print(f"__enter__") + print_rank("__enter__") self.OFFLOAD_MGR self.inside_context = True @@ -131,22 +133,19 @@ def __enter__(self): ) def __exit__(self, *args: Any): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print(f"__exit__") + print_rank("__exit__") self.inside_context = False torch._C._autograd._pop_saved_tensors_default_hooks() def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("on_save_for_backward", tensor.shape) + print_rank(f"on_save_for_backward {tensor.shape}") assert self.inside_context if self.cur_forward_chunk().is_registered_tensor(tensor.data_ptr()): tensor.offloading_activation = True return self.cur_forward_chunk().tensor_push(tensor) def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - 
print("on_get_saved_tensor", saved_state) + print_rank(f"on_get_saved_tensor {saved_state}") return self.cur_backward_chunk().tensor_pop(saved_state) @@ -154,8 +153,7 @@ class ChunkOffloadHandler(AsyncDoubleBufferGroupOffloadHandler): @staticmethod def offload(src_tensor, pin_memory=True): """Offload.""" - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("offload") + print_rank("offload") fp8_offload = isinstance(src_tensor, Float8Tensor) cpu_backup = torch.empty( @@ -179,8 +177,7 @@ def offload(src_tensor, pin_memory=True): @staticmethod def reload(state, non_blocking=None): """Reload.""" - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("reload") + print_rank("reload") dev, cpu_backup = state if non_blocking is None: non_blocking = cpu_backup.is_pinned() @@ -198,6 +195,7 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer self._groups_to_offload = [] self._groups_to_reload = [] self.first_layer_index = first_layer_index + self._layer_index = 0 self._tensor_count_current_group = 0 self.multi_input_offload_count = False self.offloaded_groups_count_per_layer = offloaded_groups_count_per_layer @@ -206,19 +204,18 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer self.torch_tensor_count = 0 self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream + self._offload_events = [] self.do_offload = offload self.is_last_layer = False self._offload_tensor_ptrs = deque() def is_first_last_layer(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("is_first_last_layer", self._is_first_last_vpp_chunk, self.is_last_layer) + print_rank(f"is_first_last_layer {self._is_first_last_vpp_chunk} {self.is_last_layer}") return self._is_first_last_vpp_chunk and self.is_last_layer def tensor_push(self, tensor): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("tensor_push") + 
print_rank("tensor_push") torch_stray_tensor = isinstance( tensor, ( @@ -237,21 +234,15 @@ def tensor_push(self, tensor): tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 self._tensor_tag_to_state[tensor_tag] = tensor - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("tensor_push", tensor.shape) - print("tensor_tag", tensor_tag) + print_rank(f"tensor_push {tensor.shape}") return tensor_tag def tensor_pop(self, tensor_tag): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("tensor_pop") - print("tensor_tag", tensor_tag) + print_rank(f"tensor_pop {tensor_tag}") assert tensor_tag in self._tensor_tag_to_state, f"{tensor_tag}, {self._tensor_tag_to_state.keys()}" tensor = self._tensor_tag_to_state.pop(tensor_tag) assert not isinstance(tensor, tuple) - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("tensor_pop", tensor.shape) - # print("tensor", tensor) + print_rank(f"tensor_pop {tensor.shape}") return tensor def tensor_need_offloading_checker(self, tensor): @@ -263,8 +254,7 @@ def tensor_need_offloading_checker(self, tensor): def bulk_offload_group(self, group_to_offload): """Bulk offload group.""" - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("bulk_offload_group") + print_rank("bulk_offload_group") if not self.do_offload: return assert not self.is_first_last_layer() @@ -274,29 +264,25 @@ def bulk_offload_group(self, group_to_offload): for tensor_tag, state in self._tensor_tag_to_state.items(): group_id, _ = tensor_tag if group_id == group_id_to_offload: - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("tensor_tag", tensor_tag) - print("group_to_offload", group_to_offload) + print_rank(f"tensor_tag {tensor_tag}") + print_rank(f"group_to_offload {group_to_offload}") assert not isinstance(state, tuple) tensor_on_device = state # if offload, return the reference to cpu copy - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - 
print("tensor_need_offloading_checker", self.tensor_need_offloading_checker(tensor_on_device)) - print("tensor_on_device", tensor_on_device.shape) + print_rank(f"tensor_need_offloading_checker {self.tensor_need_offloading_checker(tensor_on_device)}") + print_rank(f"tensor_on_device {tensor_on_device.shape}") if self.tensor_need_offloading_checker(tensor_on_device): state = self.offload(tensor_on_device) tensor_on_device.record_stream(self.d2h_stream) # self.offload_count_per_layer[group_to_offload] += 1 self._tensor_tag_to_state[tensor_tag] = state # self._offloaded_group_count = group_to_offload + 1 - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("exit bulk_offload_group") + print_rank("exit bulk_offload_group") torch.cuda.nvtx.range_pop() def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("bulk_reload_group") + print_rank("bulk_reload_group") if not self.do_offload: return found_reload_group = False @@ -308,10 +294,11 @@ def bulk_reload_group(self, group_to_reload): group_id, _ = tensor_label if group_id == group_id_to_reload: found_reload_group = True + event = self._offload_events[-1] if isinstance(state, tuple): + torch.cuda.current_stream().wait_event(event) recovered_tensor = self.reload(state) - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("recovered_tensor", recovered_tensor.shape) + print_rank(f"recovered_tensor {recovered_tensor.shape}") self._tensor_tag_to_state[tensor_label] = recovered_tensor # self.offload_count_per_layer[group_to_reload] -= 1 # if self.offload_count_per_layer[group_to_reload] > 0 and self.multi_input_offload_count: @@ -319,15 +306,18 @@ def bulk_reload_group(self, group_to_reload): # if self.offload_count_per_layer[group_to_reload] == 0: # self._offloaded_group_count = group_to_reload torch.cuda.nvtx.range_pop() + if found_reload_group: + self._offload_events.pop() return found_reload_group def 
pre_reload_last_layer(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("pre_reload_last_layer") + print_rank("pre_reload_last_layer") if not self.do_offload: return assert not self._is_first_last_vpp_chunk # TODO: check if this is correct + print_rank(f"len(self._groups_to_reload) {len(self._groups_to_reload)}") + print_rank(f"len(self._offload_events) {len(self._offload_events)}") if len(self._groups_to_reload) > 0: if self.bulk_reload_group(self._groups_to_reload[-1]): self._groups_to_reload.pop() @@ -351,39 +341,38 @@ def should_bulk_offload(self): return True def forward_sync(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("forward_sync") + print_rank("forward_sync") self.d2h_stream.wait_stream(torch.cuda.current_stream()) def bulk_offload(self, release_tensors): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("bulk_offload") + print_rank("bulk_offload") if self.should_bulk_offload(): group_to_offload = self._groups_to_offload.pop() self._groups_to_reload.append(group_to_offload) self.bulk_offload_group(group_to_offload) + event = torch.cuda.Event() + event.record(self.d2h_stream) + self._offload_events.append(event) if len(release_tensors) > 0: cur_stream = torch.cuda.current_stream() for release_tensor in release_tensors: release_tensor.record_stream(cur_stream) - release_tensor.untyped_storage().resize_(0) - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("exit bulk_offload") + # release_tensor.untyped_storage().resize_(0) + print_rank("exit bulk_offload") def on_group_commit_forward(self, release_tensors): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("on_group_commit_forward") + print_rank("on_group_commit_forward") # wait each other self.forward_sync() self.bulk_offload(release_tensors) - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("exit on_group_commit_forward") + print_rank("exit on_group_commit_forward") def 
bulk_reload(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("bulk_reload") + print_rank("bulk_reload") # if self.do_offload: # assert self._layer_index == self._offloaded_group_count, f"{self._layer_index}, {self._offloaded_group_count}" + print_rank(f"len(self._groups_to_reload) {len(self._groups_to_reload)}") + print_rank(f"len(self._offload_events) {len(self._offload_events)}") if len(self._groups_to_reload) > 0: # load next layer if self.bulk_reload_group(self._groups_to_reload[-1]): @@ -394,16 +383,14 @@ def bulk_reload(self): next_backward_chunk.pre_reload_last_layer() def backward_sync(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("backward_sync") + print_rank("backward_sync") self.h2d_stream.wait_stream(torch.cuda.current_stream()) # computation kernels wait until the offloaded groups of one layer are fully reloaded. - if self._offloaded_group_index % self.offloaded_groups_count_per_layer == 0: - torch.cuda.current_stream().wait_stream(self.h2d_stream) + # if self._offloaded_group_index % self.offloaded_groups_count_per_layer == 0: + # torch.cuda.current_stream().wait_stream(self.h2d_stream) def on_group_commit_backward(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("on_group_commit_backward") + print_rank("on_group_commit_backward") cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() if not cur_backward_chunk is self: PipelineOffloadManager.get_instance().pop() @@ -417,7 +404,8 @@ def on_group_start_forward(self, name): # # This is not necessary but good to have. 
# if self._offloaded_group_index % self.offloaded_groups_count_per_layer == 0: # torch.cuda.current_stream().wait_stream(self.d2h_stream) - if self._offloaded_group_index // self.offloaded_groups_count_per_layer == self._num_layers - 1: + print_rank(f"on_group_start_forward {self._layer_index} {self._num_layers}") + if self._layer_index == self._num_layers: self.is_last_layer = True else: self.is_last_layer = False @@ -426,8 +414,7 @@ def on_group_start_forward(self, name): self._groups_to_offload.append((self._offloaded_group_index, name)) def on_group_start_backward(self): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("on_group_start_backward") + print_rank("on_group_start_backward") self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() @@ -446,6 +433,14 @@ def is_registered_tensor(self, tensor_ptr: int) -> bool: if is_registered: self._offload_tensor_ptrs.popleft() return is_registered + + def on_layer_start_forward(self): + print_rank("on_layer_start_forward") + self._layer_index = self._layer_index + 1 + + def on_layer_start_backward(self): + print_rank("on_layer_start_backward") + torch.cuda.current_stream().wait_stream(self.h2d_stream) class GroupCommitFunction(torch.autograd.Function): @@ -458,8 +453,7 @@ class GroupCommitFunction(torch.autograd.Function): @staticmethod def forward(ctx, *args): # pylint: disable=missing-function-docstring - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("GroupCommitFunction forward") + print_rank("GroupCommitFunction forward") release_tensors = args[-1] cpu_offload_handler = args[-2] @@ -473,8 +467,7 @@ def forward(ctx, *args): @staticmethod def backward(ctx, *grad_output): # pylint: disable=missing-function-docstring - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("GroupCommitFunction backward") + print_rank("GroupCommitFunction backward") cpu_offload_handler = ctx.cpu_offload_handler cpu_offload_handler.on_group_commit_backward() @@ -497,8 
+490,7 @@ class GroupStartFunction(torch.autograd.Function): def forward(ctx, tensor, cpu_offload_handler, name): # pylint: disable=missing-function-docstring ctx.cpu_offload_handler = cpu_offload_handler - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("GroupStartFunction forward") + print_rank("GroupStartFunction forward") cpu_offload_handler.on_group_start_forward("activation offloading " + name) # return the identical tensor @@ -506,8 +498,7 @@ def forward(ctx, tensor, cpu_offload_handler, name): @staticmethod def backward(ctx, grad_output): - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: - print("GroupStartFunction backward") + print_rank("GroupStartFunction backward") # pylint: disable=missing-function-docstring cpu_offload_handler = ctx.cpu_offload_handler cpu_offload_handler.on_group_start_backward() @@ -517,3 +508,22 @@ def backward(ctx, grad_output): def group_prefetch_offload_start(tensor, name=None): cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() return GroupStartFunction.apply(tensor, cur_forward_chunk, name) + +class MarkLayerStartFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + ctx.cpu_offload_handler = cpu_offload_handler + print_rank("MarkLayerStartFunction forward") + cpu_offload_handler.on_layer_start_forward() + return tensor + + @staticmethod + def backward(ctx, grad_output): + print_rank("MarkLayerStartFunction backward") + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_layer_start_backward() + return grad_output, None, None + +def mark_layer_start(tensor): + cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() + return MarkLayerStartFunction.apply(tensor, cur_forward_chunk) \ No newline at end of file diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index c23e2425d46..d3717cce8b0 100644 --- 
a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -35,6 +35,7 @@ PipelineOffloadManager, group_prefetch_offload_start, group_prefetch_offload_commit, + mark_layer_start, ) logger = logging.getLogger(__name__) @@ -473,6 +474,7 @@ def forward(self, *args, **kwargs): This method calls the core computation of a transformer layer, including self-attention, cross-attention (if applicable), and feed-forward operations. """ + kwargs["hidden_states"] = mark_layer_start(kwargs["hidden_states"]) hidden_states, context = self._forward_attention(*args, **kwargs) output = self._forward_mlp(hidden_states, kwargs.get("inference_context", None)) return output, context From ae494b55ca0bc25010f587d9e160dbe1bd25fb77 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 22 Sep 2025 04:27:22 -0700 Subject: [PATCH 23/35] minor fix Signed-off-by: Hongbin Liu --- megatron/core/transformer/cpu_offload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py index 0ba6344001c..94208b05e7f 100644 --- a/megatron/core/transformer/cpu_offload.py +++ b/megatron/core/transformer/cpu_offload.py @@ -357,7 +357,7 @@ def bulk_offload(self, release_tensors): cur_stream = torch.cuda.current_stream() for release_tensor in release_tensors: release_tensor.record_stream(cur_stream) - # release_tensor.untyped_storage().resize_(0) + release_tensor.untyped_storage().resize_(0) print_rank("exit bulk_offload") def on_group_commit_forward(self, release_tensors): From ba17d780d9f9fd30842483dcc1852127712afeac Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 22 Sep 2025 08:00:07 -0700 Subject: [PATCH 24/35] code refactor and bug fix Signed-off-by: Hongbin Liu --- megatron/core/models/gpt/gpt_model.py | 1 - megatron/core/transformer/attention.py | 4 +- megatron/core/transformer/cpu_offload.py | 82 ++++++------------- megatron/core/transformer/moe/experts.py | 2 +- 
.../transformer/multi_latent_attention.py | 4 +- .../core/transformer/transformer_config.py | 4 - 6 files changed, 32 insertions(+), 65 deletions(-) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 1f6ab958f19..ea3a138abdd 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -375,7 +375,6 @@ def forward( self.vp_stage, self.config.fine_grained_activation_offloading, 0, - self.config.offload_module_count_per_layer, ) inference_context = deprecate_inference_params(inference_context, inference_params) diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index cd849816aaa..9e3226fc2f8 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -865,7 +865,7 @@ def forward( ) core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') if self.offload_core_attention and self.training: - core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[query, key, value]) + core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[]) offload_context = contextlib.nullcontext() if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': @@ -888,7 +888,7 @@ def forward( with offload_context: output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: - output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out]) + output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[]) offload_context = contextlib.nullcontext() nvtx_range_pop(suffix="linear_proj") diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py index 94208b05e7f..08a42512a12 100644 --- a/megatron/core/transformer/cpu_offload.py +++ b/megatron/core/transformer/cpu_offload.py @@ -96,7 +96,7 @@ def front(self): def size(self): return len(self._queue) - def reset_chunk_handler(self, 
num_layer, vp_stage, offload=True, first_layer_index=0, offloaded_groups_count_per_layer=0): + def reset_chunk_handler(self, num_layer, vp_stage, offload=True, first_layer_index=0): if vp_stage is None: cur_vpp_rank = 0 else: @@ -107,7 +107,7 @@ def reset_chunk_handler(self, num_layer, vp_stage, offload=True, first_layer_ind if cur_vpp_rank == self._vpp - 1: self.flush() first_last_vpp_rank = first_last_vpp_rank and (cur_vpp_rank == self._vpp - 1) - cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload, first_layer_index, offloaded_groups_count_per_layer) + cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload, first_layer_index) # save for latter push self._stages[cur_vpp_rank].append(cur_chunk) if cur_vpp_rank == self._vpp - 1: @@ -140,8 +140,6 @@ def __exit__(self, *args: Any): def on_save_for_backward(self, tensor: torch.Tensor) -> Any: print_rank(f"on_save_for_backward {tensor.shape}") assert self.inside_context - if self.cur_forward_chunk().is_registered_tensor(tensor.data_ptr()): - tensor.offloading_activation = True return self.cur_forward_chunk().tensor_push(tensor) def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: @@ -183,7 +181,7 @@ def reload(state, non_blocking=None): non_blocking = cpu_backup.is_pinned() return cpu_backup.to(dev, non_blocking=non_blocking) - def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer_index=0, offloaded_groups_count_per_layer=0): + def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer_index=0): self._num_layers = num_layer # Data Structure to maintain reference to activation tensors self._tensor_tag_to_state = {} @@ -198,7 +196,6 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer self._layer_index = 0 self._tensor_count_current_group = 0 self.multi_input_offload_count = False - self.offloaded_groups_count_per_layer = offloaded_groups_count_per_layer # self.offload_count_per_layer = 
defaultdict(int) self.torch_tensor_count = 0 @@ -208,9 +205,9 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer self.do_offload = offload self.is_last_layer = False - self._offload_tensor_ptrs = deque() def is_first_last_layer(self): + """Do not offload the last layer of the last pp stage.""" print_rank(f"is_first_last_layer {self._is_first_last_vpp_chunk} {self.is_last_layer}") return self._is_first_last_vpp_chunk and self.is_last_layer @@ -246,6 +243,7 @@ def tensor_pop(self, tensor_tag): return tensor def tensor_need_offloading_checker(self, tensor): + """Check if the tensor needs to be offloaded.""" if tensor.numel() < MIN_OFFLOADED_TENSOR_SIZE: return False if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation: @@ -253,7 +251,8 @@ def tensor_need_offloading_checker(self, tensor): return True def bulk_offload_group(self, group_to_offload): - """Bulk offload group.""" + """offload a group of tensors recorded in tensor_push(). + """ print_rank("bulk_offload_group") if not self.do_offload: return @@ -296,36 +295,30 @@ def bulk_reload_group(self, group_to_reload): found_reload_group = True event = self._offload_events[-1] if isinstance(state, tuple): + # make sure the tensor is already offloaded to cpu before reloading it. 
torch.cuda.current_stream().wait_event(event) recovered_tensor = self.reload(state) print_rank(f"recovered_tensor {recovered_tensor.shape}") self._tensor_tag_to_state[tensor_label] = recovered_tensor - # self.offload_count_per_layer[group_to_reload] -= 1 - # if self.offload_count_per_layer[group_to_reload] > 0 and self.multi_input_offload_count: - # break - # if self.offload_count_per_layer[group_to_reload] == 0: - # self._offloaded_group_count = group_to_reload torch.cuda.nvtx.range_pop() if found_reload_group: self._offload_events.pop() return found_reload_group def pre_reload_last_layer(self): + """Pre-reload the last layer of the next model chunk.""" print_rank("pre_reload_last_layer") if not self.do_offload: return assert not self._is_first_last_vpp_chunk - # TODO: check if this is correct print_rank(f"len(self._groups_to_reload) {len(self._groups_to_reload)}") print_rank(f"len(self._offload_events) {len(self._offload_events)}") if len(self._groups_to_reload) > 0: if self.bulk_reload_group(self._groups_to_reload[-1]): self._groups_to_reload.pop() - # if self._num_layers == self._offloaded_group_count: - # self.bulk_reload_group(self._num_layers - 1) - # assert self._num_layers - 1 == self._offloaded_group_count def should_bulk_offload(self): + """Check if the chunk should be offloaded.""" if not self.do_offload: return False # first backward chunk @@ -340,10 +333,6 @@ def should_bulk_offload(self): return True - def forward_sync(self): - print_rank("forward_sync") - self.d2h_stream.wait_stream(torch.cuda.current_stream()) - def bulk_offload(self, release_tensors): print_rank("bulk_offload") if self.should_bulk_offload(): @@ -352,6 +341,7 @@ def bulk_offload(self, release_tensors): self.bulk_offload_group(group_to_offload) event = torch.cuda.Event() event.record(self.d2h_stream) + # TODO: check if we really need it. 
self._offload_events.append(event) if len(release_tensors) > 0: cur_stream = torch.cuda.current_stream() @@ -361,49 +351,38 @@ def bulk_offload(self, release_tensors): print_rank("exit bulk_offload") def on_group_commit_forward(self, release_tensors): + """Offload a group of tensors.""" print_rank("on_group_commit_forward") - # wait each other - self.forward_sync() + # wait for the compute stream for offloading + self.d2h_stream.wait_stream(torch.cuda.current_stream()) self.bulk_offload(release_tensors) print_rank("exit on_group_commit_forward") def bulk_reload(self): print_rank("bulk_reload") - # if self.do_offload: - # assert self._layer_index == self._offloaded_group_count, f"{self._layer_index}, {self._offloaded_group_count}" - print_rank(f"len(self._groups_to_reload) {len(self._groups_to_reload)}") - print_rank(f"len(self._offload_events) {len(self._offload_events)}") if len(self._groups_to_reload) > 0: # load next layer if self.bulk_reload_group(self._groups_to_reload[-1]): self._groups_to_reload.pop() else: + # load the last layer of one backward chunk in advance next_backward_chunk = PipelineOffloadManager.get_instance().front() if next_backward_chunk is not None: next_backward_chunk.pre_reload_last_layer() - def backward_sync(self): - print_rank("backward_sync") - self.h2d_stream.wait_stream(torch.cuda.current_stream()) - # computation kernels wait until the offloaded groups of one layer are fully reloaded. 
-        # if self._offloaded_group_index % self.offloaded_groups_count_per_layer == 0:
-        #     torch.cuda.current_stream().wait_stream(self.h2d_stream)
-
     def on_group_commit_backward(self):
+        """Prepare for reloading the next group of tensors."""
         print_rank("on_group_commit_backward")
         cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk()
         if not cur_backward_chunk is self:
             PipelineOffloadManager.get_instance().pop()
             cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk()
         assert cur_backward_chunk is self
-        self.backward_sync()
+        # self.h2d_stream.wait_stream(torch.cuda.current_stream())
         self._offloaded_group_index = self._offloaded_group_index - 1

     def on_group_start_forward(self, name):
-        # # wait for the offloaded groups of one layer are fully offloaded.
-        # # This is not necessary but good to have.
-        # if self._offloaded_group_index % self.offloaded_groups_count_per_layer == 0:
-        #     torch.cuda.current_stream().wait_stream(self.d2h_stream)
+        """Prepare for offloading the next group of tensors."""
         print_rank(f"on_group_start_forward {self._layer_index} {self._num_layers}")
         if self._layer_index == self._num_layers:
             self.is_last_layer = True
@@ -414,31 +393,19 @@ def on_group_start_forward(self, name):
         self._groups_to_offload.append((self._offloaded_group_index, name))

     def on_group_start_backward(self):
+        """Reload the next group of tensors."""
         print_rank("on_group_start_backward")
         self.h2d_stream.wait_stream(torch.cuda.current_stream())
         self.bulk_reload()
-
-    def register_offload_tensor(self, tensors):
-        self.multi_input_offload_count = True
-        if isinstance(tensors, list):
-            for tensor in tensors:
-                self._offload_tensor_ptrs.append(tensor.data_ptr())
-        else:
-            self._offload_tensor_ptrs.append(tensors.data_ptr())
-
-    def is_registered_tensor(self, tensor_ptr: int) -> bool:
-        if len(self._offload_tensor_ptrs) == 0:
-            return False
-        is_registered = tensor_ptr == self._offload_tensor_ptrs[0]
-        if is_registered:
-
self._offload_tensor_ptrs.popleft() - return is_registered def on_layer_start_forward(self): + """Increment the layer index.""" print_rank("on_layer_start_forward") self._layer_index = self._layer_index + 1 def on_layer_start_backward(self): + """When the bprop of one layer finishes, make sure the reloading jobs on h2d stream are done. + """ print_rank("on_layer_start_backward") torch.cuda.current_stream().wait_stream(self.h2d_stream) @@ -475,6 +442,11 @@ def backward(ctx, *grad_output): def group_prefetch_offload_commit(*tensor, release_tensors=[]): + """Specify the tensors to be released after offloading. + release_tensors is a list of tensors to be released after offloading. + The tensors will be untyped_storage().resize_(0) after offloading. + Note: specify the tensors only when they are not automatically released by torch gc. + """ cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() return GroupCommitFunction.apply(*tensor, cur_forward_chunk, release_tensors) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 116f5861ae0..6bd5ffffc71 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -882,7 +882,7 @@ def forward( permuted_local_hidden_states, tokens_per_expert ) if self.offload_expert_fc1: - fc1_output, bias_parallel = group_prefetch_offload_commit(fc1_output, bias_parallel, release_tensors=[permuted_local_hidden_states]) + fc1_output, bias_parallel = group_prefetch_offload_commit(fc1_output, bias_parallel) offload_context = contextlib.nullcontext() def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 8f01a0cb8ac..a00a23b8ccb 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -306,7 +306,7 @@ def forward( if not 
inference_context.is_decode_only(): core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') if self.offload_core_attention and self.training: - core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[query, key, value]) + core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[]) offload_context = contextlib.nullcontext() # We are doing absorption with cache mla latents and decode mode. @@ -340,7 +340,7 @@ def forward( with offload_context: output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: - output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out]) + output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[]) offload_context = contextlib.nullcontext() return output, bias diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 29aa52f8468..c32147783ae 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -979,10 +979,6 @@ def __post_init__(self): "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) - if "expert_fc1" in self.offload_modules and self.tensor_model_parallel_size > 1: - raise ValueError( - "(Bug) expert_fc1 cannot be set to offload_modules when tensor_model_parallel_size > 1." 
- ) if ( From dfaa62067db718c1dfc29c985356cd3c314d0bfd Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 22 Sep 2025 08:01:21 -0700 Subject: [PATCH 25/35] update README Signed-off-by: Hongbin Liu --- megatron/core/transformer/README.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/megatron/core/transformer/README.md b/megatron/core/transformer/README.md index 3c16dfe4843..99429a55288 100644 --- a/megatron/core/transformer/README.md +++ b/megatron/core/transformer/README.md @@ -23,11 +23,10 @@ Fine-grained Activation Offloading * Support FP8 * Support MTP ## Known issues -* `--offload-modules expert_fc1` doesn't work with TP>1 ## WIP items * Code refactor -* Support MTP +* Support mixed dense & moe layer * Benchmark # Methodology From 7d867ab9632a808c785a733d76af16176835be10 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Tue, 23 Sep 2025 07:22:41 -0700 Subject: [PATCH 26/35] avoid multiple d2h copies for expert_fc1 and update README Signed-off-by: Hongbin Liu --- megatron/core/transformer/README.md | 40 ++++++++++++++++++- megatron/core/transformer/cpu_offload.py | 8 ++-- megatron/core/transformer/moe/experts.py | 9 ++++- .../core/transformer/transformer_layer.py | 8 ++-- 4 files changed, 54 insertions(+), 11 deletions(-) diff --git a/megatron/core/transformer/README.md b/megatron/core/transformer/README.md index 99429a55288..79b033bf783 100644 --- a/megatron/core/transformer/README.md +++ b/megatron/core/transformer/README.md @@ -71,7 +71,45 @@ Before the model.forward() start, the `PipelineOffloadManager.get_instance().res ## A special case: attn_norm/mlp_norm -# Performance (WIP) +# Performance + +## H100 +### DeepSeek-V3-Proxy +#### Model structure +* Layer parameters are same as DeepSeek-V3 model +* Layer number is cut off to 14 layers +* Replace the first 3 dense layers with 3 moe layers + +#### Key Hyper-parameters +* TP1PP4EP16VPP1CP1-MBS1GBS512 +* bf16 training +* DeepEP dispatcher +* `--cross-entropy-loss-fusion` and 
`--cross-entropy-fusion-impl te` +* `--moe-permute-fusion` +* `--moe-router-fusion` +* `--enable-experimental` + +#### Throughput and correctness + +#### Memory consumption + +Baseline (no offloading) +``` +[Rank 0] (after 10 iterations) memory (MB) | allocated: 24761.02978515625 | max allocated: 65203.93359375 | reserved: 64438.0 | max reserved: 74306.0 +[Rank 16] (after 10 iterations) memory (MB) | allocated: 18907.728515625 | max allocated: 52228.1533203125 | reserved: 58770.0 | max reserved: 58770.0 +[Rank 32] (after 10 iterations) memory (MB) | allocated: 18907.7529296875 | max allocated: 45200.8349609375 | reserved: 51772.0 | max reserved: 51772.0 +[Rank 48] (after 10 iterations) memory (MB) | allocated: 29006.82275390625 | max allocated: 48166.263671875 | reserved: 56328.0 | max reserved: 56328.0 +``` +With offloading expert_fc1, moe_act, act_norm and mlp_norm +``` +[Rank 0] (after 10 iterations) memory (MB) | allocated: 24705.02978515625 | max allocated: 48544.70849609375 | reserved: 61046.0 | max reserved: 61046.0 +[Rank 16] (after 10 iterations) memory (MB) | allocated: 18795.728515625 | max allocated: 38760.3876953125 | reserved: 46330.0 | max reserved: 46330.0 +[Rank 32] (after 10 iterations) memory (MB) | allocated: 18795.7529296875 | max allocated: 34950.2509765625 | reserved: 42452.0 | max reserved: 42452.0 +[Rank 48] (after 10 iterations) memory (MB) | allocated: 28950.82275390625 | max allocated: 41310.798828125 | reserved: 50408.0 | max reserved: 50408.0 +``` + + +## GB200 # Acknowledgement diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py index 08a42512a12..7fef5cff7e6 100644 --- a/megatron/core/transformer/cpu_offload.py +++ b/megatron/core/transformer/cpu_offload.py @@ -263,13 +263,13 @@ def bulk_offload_group(self, group_to_offload): for tensor_tag, state in self._tensor_tag_to_state.items(): group_id, _ = tensor_tag if group_id == group_id_to_offload: - print_rank(f"tensor_tag {tensor_tag}") - 
print_rank(f"group_to_offload {group_to_offload}") + # print_rank(f"tensor_tag {tensor_tag}") + # print_rank(f"group_to_offload {group_to_offload}") assert not isinstance(state, tuple) tensor_on_device = state # if offload, return the reference to cpu copy - print_rank(f"tensor_need_offloading_checker {self.tensor_need_offloading_checker(tensor_on_device)}") - print_rank(f"tensor_on_device {tensor_on_device.shape}") + # print_rank(f"tensor_need_offloading_checker {self.tensor_need_offloading_checker(tensor_on_device)}") + # print_rank(f"tensor_on_device {tensor_on_device.shape}") if self.tensor_need_offloading_checker(tensor_on_device): state = self.offload(tensor_on_device) tensor_on_device.record_stream(self.d2h_stream) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 6bd5ffffc71..f87754974a3 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -828,6 +828,11 @@ def __init__( from megatron.core.extensions.transformer_engine import set_save_original_input set_save_original_input(self.linear_fc2) + + # This is to avoid the CPU overhead of multiple d2h copies + if self.offload_expert_fc1: + from megatron.core.extensions.transformer_engine import set_save_original_input + set_save_original_input(self.linear_fc1) if self.config.fp8: assert HAVE_TE, "FP8 requires TE." 
@@ -875,8 +880,8 @@ def forward( offload_context = contextlib.nullcontext() if self.offload_expert_fc1: - permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states, name="expert_fc1") - offload_context = PipelineOffloadManager.get_instance() + permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states, name="expert_fc1") + offload_context = PipelineOffloadManager.get_instance() with offload_context: fc1_output, bias_parallel = self.linear_fc1( permuted_local_hidden_states, tokens_per_expert diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index d3717cce8b0..8173e1ebe84 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -525,7 +525,7 @@ def _forward_attention( residual = hidden_states offload_context = contextlib.nullcontext() - if self.offload_attn_norm: + if self.offload_attn_norm and not isinstance(self.input_layernorm, IdentityOp): hidden_states = group_prefetch_offload_start(hidden_states, name="attn_norm") offload_context = PipelineOffloadManager.get_instance() # Optional Input Layer norm @@ -578,7 +578,7 @@ def _forward_attention( ) nvtx_range_pop(suffix="self_attn_bda") - if self.offload_attn_norm: + if self.offload_attn_norm and not isinstance(self.input_layernorm, IdentityOp): hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual]) offload_context = contextlib.nullcontext() @@ -623,7 +623,7 @@ def _forward_mlp(self, hidden_states, inference_context=None): residual = hidden_states offload_context = contextlib.nullcontext() - if self.offload_mlp_norm: + if self.offload_mlp_norm and not isinstance(self.pre_mlp_layernorm, IdentityOp): hidden_states = group_prefetch_offload_start(hidden_states, name="mlp_norm") offload_context = PipelineOffloadManager.get_instance() # Optional Layer norm post the cross-attention. 
@@ -695,7 +695,7 @@ def _forward_mlp(self, hidden_states, inference_context=None): mlp_output_with_bias, residual, self.hidden_dropout ) nvtx_range_pop(suffix="mlp_bda") - if self.offload_mlp_norm: + if self.offload_mlp_norm and not isinstance(self.pre_mlp_layernorm, IdentityOp): hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual]) offload_context = contextlib.nullcontext() From 7a7af1c49c2be0390710e53a42703c5edfaf7426 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Tue, 23 Sep 2025 22:24:54 +0800 Subject: [PATCH 27/35] Update README.md --- megatron/core/transformer/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/megatron/core/transformer/README.md b/megatron/core/transformer/README.md index 79b033bf783..18a07bd648f 100644 --- a/megatron/core/transformer/README.md +++ b/megatron/core/transformer/README.md @@ -91,6 +91,10 @@ Before the model.forward() start, the `PipelineOffloadManager.get_instance().res #### Throughput and correctness +image +image + + #### Memory consumption Baseline (no offloading) From 5b28cb29d4759cefb4953ea8ae595c31ff8cdfb3 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Wed, 24 Sep 2025 22:02:05 -0700 Subject: [PATCH 28/35] support mixed dense&moe layer and a2a overlap Signed-off-by: Hongbin Liu --- .../core/models/gpt/fine_grained_callables.py | 24 +++++-- megatron/core/models/gpt/gpt_model.py | 25 ++++--- megatron/core/transformer/README.md | 18 ++++- megatron/core/transformer/attention.py | 4 +- megatron/core/transformer/cpu_offload.py | 67 ++++++++++++------- megatron/core/transformer/moe/experts.py | 2 +- .../transformer/multi_latent_attention.py | 4 +- .../core/transformer/transformer_block.py | 6 ++ .../core/transformer/transformer_config.py | 14 ++-- .../core/transformer/transformer_layer.py | 13 ++-- 10 files changed, 126 insertions(+), 51 deletions(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index 
fbecc047682..f8d27932d97 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -16,6 +16,12 @@ get_mtp_layer_offset, ) from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor +from megatron.core.transformer.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, + mark_layer_start, +) def weak_method(method): @@ -331,6 +337,7 @@ def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor): """ Performs same attention forward logic as GPT Model. """ + hidden_states = mark_layer_start(hidden_states) hidden_states, _ = layer._forward_attention( hidden_states=hidden_states, attention_mask=node.chunk_state.attention_mask, @@ -347,13 +354,20 @@ def submodule_post_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor) Run forward pass for computations between attention and dispatch: pre mlp layernorm->router->dispatch preprocess """ + offload_context = nullcontext() + if layer.offload_mlp_norm: + hidden_states = group_prefetch_offload_start(hidden_states, name="mlp_norm") + offload_context = PipelineOffloadManager.get_instance() if layer.recompute_pre_mlp_layernorm: layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint( - layer.pre_mlp_layernorm, hidden_states - ) + with offload_context: + pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint( + layer.pre_mlp_layernorm, hidden_states + ) else: - pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states) + with offload_context: + pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states) + offload_context = nullcontext() local_tokens, probs, _ = layer.mlp.router_and_preprocess(pre_mlp_layernorm_output) @@ -433,6 +447,8 @@ def submodule_combine_forward( hidden_states = layer.mlp_bda(layer.training, 
layer.config.bias_dropout_fusion)( mlp_output_with_bias, residual, layer.hidden_dropout ) + if layer.offload_mlp_norm: + hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual]) output = make_viewless_tensor( inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index ea3a138abdd..fc0fc63ae59 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -342,6 +342,18 @@ def _preprocess( return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset + def initialize_model_chunk_offload_handler(self): + num_layers = self.decoder.num_layers_per_pipeline_rank + if self.mtp_process: + num_layers = num_layers + self.config.mtp_num_layers + # TODO: will be an issue when dense layer is placed across different pipeline stages + PipelineOffloadManager.get_instance().reset_chunk_handler( + num_layers, + self.vp_stage, + self.config.fine_grained_activation_offloading, + self.decoder.num_dense_layer, + ) + def forward( self, input_ids: Tensor, @@ -367,15 +379,8 @@ def forward( runtime_gather_output (bool): Gather output at runtime. Default None means `parallel_output` arg in the constructor will be used. """ - num_layers = self.decoder.num_layers_per_pipeline_rank - if self.mtp_process: - num_layers = num_layers + self.config.mtp_num_layers - PipelineOffloadManager.get_instance().reset_chunk_handler( - num_layers, - self.vp_stage, - self.config.fine_grained_activation_offloading, - 0, - ) + if self.config.fine_grained_activation_offloading: + self.initialize_model_chunk_offload_handler() inference_context = deprecate_inference_params(inference_context, inference_params) @@ -637,6 +642,8 @@ def build_schedule_plan( TransformerModelChunkSchedulePlan: The model chunk schedule plan. 
""" + if self.config.fine_grained_activation_offloading: + self.initialize_model_chunk_offload_handler() from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan return TransformerModelChunkSchedulePlan( diff --git a/megatron/core/transformer/README.md b/megatron/core/transformer/README.md index 18a07bd648f..1306c783009 100644 --- a/megatron/core/transformer/README.md +++ b/megatron/core/transformer/README.md @@ -22,11 +22,13 @@ Fine-grained Activation Offloading * Compatible with fine-grained recomputation * Support FP8 * Support MTP +* Support mixed dense & moe layer +* Support A2A Overlap + ## Known issues ## WIP items * Code refactor -* Support mixed dense & moe layer * Benchmark # Methodology @@ -74,6 +76,7 @@ Before the model.forward() start, the `PipelineOffloadManager.get_instance().res # Performance ## H100 + ### DeepSeek-V3-Proxy #### Model structure * Layer parameters are same as DeepSeek-V3 model @@ -112,6 +115,19 @@ With offloading expert_fc1, moe_act, act_norm and mlp_norm [Rank 48] (after 10 iterations) memory (MB) | allocated: 28950.82275390625 | max allocated: 41310.798828125 | reserved: 50408.0 | max reserved: 50408.0 ``` +### Qwen3-30B-A3B +#### Model structure +* Same as Qwen-30B model structure + +#### Results + +| Model | Mapping | Sequence length | Recompute | Offload | Throughput (tflops) | Memory (MB) | +|---------------|--------------------------|-----------------|-----------|------------|---------------------|-------------| +| Qwen3-30B-A3B | TP1PP1EP8VPP1_MBS1GBS256 | 4096 | / | / | 194 | 65308 | +| | TP1PP1EP8VPP1_MBS1GBS256 | 8192 | full | / | 230 | 59566 | +| | TP1PP2EP8VPP4_MBS1GBS256 | 8192 | layernorm | expert_fc1 | 255 | 64962 | + + ## GB200 diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 9e3226fc2f8..cd849816aaa 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -865,7 +865,7 @@ def forward( ) core_attn_out 
= rearrange(core_attn_out, 's b h d -> s b (h d)') if self.offload_core_attention and self.training: - core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[]) + core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[query, key, value]) offload_context = contextlib.nullcontext() if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': @@ -888,7 +888,7 @@ def forward( with offload_context: output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: - output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[]) + output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out]) offload_context = contextlib.nullcontext() nvtx_range_pop(suffix="linear_proj") diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py index 7fef5cff7e6..30d6a26e9ec 100644 --- a/megatron/core/transformer/cpu_offload.py +++ b/megatron/core/transformer/cpu_offload.py @@ -11,7 +11,7 @@ def print_rank(message): assert torch.distributed.is_initialized() - if torch.distributed.get_rank() == DEBUG_RANK and DEBUG: + if DEBUG and torch.distributed.get_rank() == DEBUG_RANK: print(message, flush=True) def set_ideal_affinity_for_current_gpu(): @@ -83,20 +83,24 @@ def push(self, handler): def pop(self): assert self.size() - self._cur_backward_chunk = self._queue.popleft() + while self._queue: + self._cur_backward_chunk = self._queue.popleft() + if not isinstance(self._cur_backward_chunk, NullChunkOffloadHandler): + break print_rank(f"popping handler {self._cur_backward_chunk}") def front(self): if not len(self._queue): return None - f = self._queue.popleft() - self._queue.appendleft(f) - return f + for chunk_handler in self._queue: + if not isinstance(chunk_handler, NullChunkOffloadHandler): + return chunk_handler + return None def size(self): return len(self._queue) - def reset_chunk_handler(self, num_layer, vp_stage, offload=True, 
first_layer_index=0): + def reset_chunk_handler(self, num_layer, vp_stage, offload=True, num_dense_layer=0): if vp_stage is None: cur_vpp_rank = 0 else: @@ -107,7 +111,11 @@ def reset_chunk_handler(self, num_layer, vp_stage, offload=True, first_layer_ind if cur_vpp_rank == self._vpp - 1: self.flush() first_last_vpp_rank = first_last_vpp_rank and (cur_vpp_rank == self._vpp - 1) - cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload, first_layer_index) + # If the model chunk contains only the dense layers, initialize a null chunk handler. + if num_layer <= num_dense_layer: + cur_chunk = NullChunkOffloadHandler(num_layer, first_last_vpp_rank, offload, num_dense_layer) + else: + cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload, num_dense_layer) # save for later push self._stages[cur_vpp_rank].append(cur_chunk) if cur_vpp_rank == self._vpp - 1: @@ -128,14 +136,16 @@ def __enter__(self): self.OFFLOAD_MGR self.inside_context = True - torch._C._autograd._push_saved_tensors_default_hooks( - self.on_save_for_backward, self.on_get_saved_tensor - ) + if not isinstance(self.cur_forward_chunk(), NullChunkOffloadHandler): + torch._C._autograd._push_saved_tensors_default_hooks( + self.on_save_for_backward, self.on_get_saved_tensor + ) def __exit__(self, *args: Any): print_rank("__exit__") self.inside_context = False - torch._C._autograd._pop_saved_tensors_default_hooks() + if not isinstance(self.cur_forward_chunk(), NullChunkOffloadHandler): + torch._C._autograd._pop_saved_tensors_default_hooks() def on_save_for_backward(self, tensor: torch.Tensor) -> Any: print_rank(f"on_save_for_backward {tensor.shape}") @@ -181,7 +191,7 @@ def reload(state, non_blocking=None): non_blocking = cpu_backup.is_pinned() return cpu_backup.to(dev, non_blocking=non_blocking) - def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer_index=0): + def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, num_dense_layer=0): 
self._num_layers = num_layer # Data Structure to maintain reference to activation tensors self._tensor_tag_to_state = {} @@ -192,7 +202,7 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, first_layer self._offloaded_group_index = 0 self._groups_to_offload = [] self._groups_to_reload = [] - self.first_layer_index = first_layer_index + self.num_dense_layer = num_dense_layer self._layer_index = 0 self._tensor_count_current_group = 0 self.multi_input_offload_count = False @@ -231,7 +241,7 @@ def tensor_push(self, tensor): tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 self._tensor_tag_to_state[tensor_tag] = tensor - print_rank(f"tensor_push {tensor.shape}") + print_rank(f"tensor_push {tensor_tag}") return tensor_tag def tensor_pop(self, tensor_tag): @@ -263,8 +273,8 @@ def bulk_offload_group(self, group_to_offload): for tensor_tag, state in self._tensor_tag_to_state.items(): group_id, _ = tensor_tag if group_id == group_id_to_offload: - # print_rank(f"tensor_tag {tensor_tag}") - # print_rank(f"group_to_offload {group_to_offload}") + print_rank(f"tensor_tag {tensor_tag}") + print_rank(f"group_to_offload {group_to_offload}") assert not isinstance(state, tuple) tensor_on_device = state # if offload, return the reference to cpu copy @@ -321,6 +331,8 @@ def should_bulk_offload(self): """Check if the chunk should be offloaded.""" if not self.do_offload: return False + if self._layer_index < self.num_dense_layer: + return False # first backward chunk if self.is_first_last_layer(): return False @@ -360,7 +372,7 @@ def on_group_commit_forward(self, release_tensors): def bulk_reload(self): print_rank("bulk_reload") - if len(self._groups_to_reload) > 0: + if len(self._groups_to_reload) > 0 and self._layer_index > self.num_dense_layer: # load next layer if self.bulk_reload_group(self._groups_to_reload[-1]): self._groups_to_reload.pop() @@ -407,8 +419,11 @@ def on_layer_start_backward(self): """When the bprop of one layer finishes, 
make sure the reloading jobs on h2d stream are done. """ print_rank("on_layer_start_backward") + self._layer_index = self._layer_index - 1 torch.cuda.current_stream().wait_stream(self.h2d_stream) +class NullChunkOffloadHandler(ChunkOffloadHandler): + pass class GroupCommitFunction(torch.autograd.Function): """this is a dummy op with output identical to input. @@ -425,7 +440,8 @@ def forward(ctx, *args): release_tensors = args[-1] cpu_offload_handler = args[-2] tensor = args[:-2] - cpu_offload_handler.on_group_commit_forward(release_tensors) + if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): + cpu_offload_handler.on_group_commit_forward(release_tensors) ctx.cpu_offload_handler = cpu_offload_handler # return the identical tensor @@ -437,7 +453,8 @@ def backward(ctx, *grad_output): print_rank("GroupCommitFunction backward") cpu_offload_handler = ctx.cpu_offload_handler - cpu_offload_handler.on_group_commit_backward() + if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): + cpu_offload_handler.on_group_commit_backward() return grad_output + (None, None) @@ -464,7 +481,8 @@ def forward(ctx, tensor, cpu_offload_handler, name): ctx.cpu_offload_handler = cpu_offload_handler print_rank("GroupStartFunction forward") - cpu_offload_handler.on_group_start_forward("activation offloading " + name) + if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): + cpu_offload_handler.on_group_start_forward("activation offloading " + name) # return the identical tensor return tensor @@ -473,7 +491,8 @@ def backward(ctx, grad_output): print_rank("GroupStartFunction backward") # pylint: disable=missing-function-docstring cpu_offload_handler = ctx.cpu_offload_handler - cpu_offload_handler.on_group_start_backward() + if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): + cpu_offload_handler.on_group_start_backward() return grad_output, None, None @@ -486,14 +505,16 @@ class MarkLayerStartFunction(torch.autograd.Function): def forward(ctx, 
tensor, cpu_offload_handler): ctx.cpu_offload_handler = cpu_offload_handler print_rank("MarkLayerStartFunction forward") - cpu_offload_handler.on_layer_start_forward() + if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): + cpu_offload_handler.on_layer_start_forward() return tensor @staticmethod def backward(ctx, grad_output): print_rank("MarkLayerStartFunction backward") cpu_offload_handler = ctx.cpu_offload_handler - cpu_offload_handler.on_layer_start_backward() + if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): + cpu_offload_handler.on_layer_start_backward() return grad_output, None, None def mark_layer_start(tensor): diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index f87754974a3..1e4bf8a88e9 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -887,7 +887,7 @@ def forward( permuted_local_hidden_states, tokens_per_expert ) if self.offload_expert_fc1: - fc1_output, bias_parallel = group_prefetch_offload_commit(fc1_output, bias_parallel) + fc1_output, bias_parallel = group_prefetch_offload_commit(fc1_output, bias_parallel, release_tensors=[]) offload_context = contextlib.nullcontext() def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index a00a23b8ccb..8f01a0cb8ac 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -306,7 +306,7 @@ def forward( if not inference_context.is_decode_only(): core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') if self.offload_core_attention and self.training: - core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[]) + core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[query, key, value]) offload_context = contextlib.nullcontext() 
# We are doing absorption with cache mla latents and decode mode. @@ -340,7 +340,7 @@ def forward( with offload_context: output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: - output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[]) + output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out]) offload_context = contextlib.nullcontext() return output, bias diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 534351b737f..bb04674f250 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -312,6 +312,12 @@ def __init__( self._build_layers() self.num_layers_per_pipeline_rank = len(self.layers) + self.num_dense_layer = 0 + from megatron.core.transformer.moe.moe_layer import MoELayer + for layer in self.layers: + if not isinstance(layer.mlp, MoELayer): + self.num_dense_layer += 1 + def _build_layers(self): # Transformer layers. # @jcasper can we improve how we deal with layer_number? diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index c32147783ae..ec2dca28ca4 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -683,8 +683,6 @@ class TransformerConfig(ModelParallelConfig): "expert_fc1": offload the input of the expert fc1 part. "moe_act": offload the input of the moe act part. """ - offload_module_count_per_layer: Optional[int] = 0 - """The number of modules to offload per layer. default: 0.""" def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. 
@@ -963,7 +961,6 @@ def __post_init__(self): # self.offload_modules = ["core_attn"] if len(self.offload_modules) > 0: - self.offload_modules = list(set(self.offload_modules)) allowed_modules = { "core_attn", "attn_proj", "expert_fc1", "moe_act", "attn_norm", "mlp_norm" } @@ -972,13 +969,22 @@ def __post_init__(self): f'Invalid choices for offload_modules: {invalid_modules}. ' f'Allowed modules are: {allowed_modules}' ) - self.offload_module_count_per_layer = len(self.offload_modules) if "attn_proj" in self.offload_modules and "core_attn" not in self.offload_modules: raise ValueError( "attn_proj cannot be set to offload_modules alone without core_attn " "because the input of attn_proj is the output of core_attn, " "which is needed in core_attn.backward()." ) + + if isinstance(self.moe_layer_freq, int): + raise ValueError( + "moe_layer_freq cannot be an integer when offload_modules is set." + ) + elif isinstance(self.moe_layer_freq, list): + if 0 in self.moe_layer_freq: + warnings.warn( + "Activation of dense layer won't be offloaded at all for mixed dense and moe layer." + ) if ( diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 8173e1ebe84..0694d896c5e 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -439,10 +439,12 @@ def __init__( self.offload_attn_norm = ( self.config.fine_grained_activation_offloading and "attn_norm" in self.config.offload_modules + and not isinstance(self.input_layernorm, IdentityOp) ) self.offload_mlp_norm = ( self.config.fine_grained_activation_offloading and "mlp_norm" in self.config.offload_modules + and not isinstance(self.pre_mlp_layernorm, IdentityOp) ) # @jcasper how should we handle nvfuser? @@ -474,7 +476,8 @@ def forward(self, *args, **kwargs): This method calls the core computation of a transformer layer, including self-attention, cross-attention (if applicable), and feed-forward operations. 
""" - kwargs["hidden_states"] = mark_layer_start(kwargs["hidden_states"]) + if self.config.fine_grained_activation_offloading: + kwargs["hidden_states"] = mark_layer_start(kwargs["hidden_states"]) hidden_states, context = self._forward_attention(*args, **kwargs) output = self._forward_mlp(hidden_states, kwargs.get("inference_context", None)) return output, context @@ -525,7 +528,7 @@ def _forward_attention( residual = hidden_states offload_context = contextlib.nullcontext() - if self.offload_attn_norm and not isinstance(self.input_layernorm, IdentityOp): + if self.offload_attn_norm: hidden_states = group_prefetch_offload_start(hidden_states, name="attn_norm") offload_context = PipelineOffloadManager.get_instance() # Optional Input Layer norm @@ -578,7 +581,7 @@ def _forward_attention( ) nvtx_range_pop(suffix="self_attn_bda") - if self.offload_attn_norm and not isinstance(self.input_layernorm, IdentityOp): + if self.offload_attn_norm: hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual]) offload_context = contextlib.nullcontext() @@ -623,7 +626,7 @@ def _forward_mlp(self, hidden_states, inference_context=None): residual = hidden_states offload_context = contextlib.nullcontext() - if self.offload_mlp_norm and not isinstance(self.pre_mlp_layernorm, IdentityOp): + if self.offload_mlp_norm: hidden_states = group_prefetch_offload_start(hidden_states, name="mlp_norm") offload_context = PipelineOffloadManager.get_instance() # Optional Layer norm post the cross-attention. 
@@ -695,7 +698,7 @@ def _forward_mlp(self, hidden_states, inference_context=None): mlp_output_with_bias, residual, self.hidden_dropout ) nvtx_range_pop(suffix="mlp_bda") - if self.offload_mlp_norm and not isinstance(self.pre_mlp_layernorm, IdentityOp): + if self.offload_mlp_norm: hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual]) offload_context = contextlib.nullcontext() From 4c3b2c5a0a79f73421677cb425f53633c2e9f808 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 25 Sep 2025 01:27:56 -0700 Subject: [PATCH 29/35] minor fix Signed-off-by: Hongbin Liu --- megatron/core/models/gpt/fine_grained_callables.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index f8d27932d97..1552168a53a 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -337,7 +337,8 @@ def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor): """ Performs same attnention forward logic as GPT Model. 
""" - hidden_states = mark_layer_start(hidden_states) + if layer.config.fine_grained_activation_offloading: + hidden_states = mark_layer_start(hidden_states) hidden_states, _ = layer._forward_attention( hidden_states=hidden_states, attention_mask=node.chunk_state.attention_mask, From a5d194cb2955eadb24c1aa990179b73abfb64cb0 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 25 Sep 2025 22:28:28 -0700 Subject: [PATCH 30/35] bug fix Signed-off-by: Hongbin Liu --- megatron/core/transformer/README.md | 1 + megatron/core/transformer/moe/experts.py | 2 +- megatron/core/transformer/transformer_config.py | 4 +--- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/megatron/core/transformer/README.md b/megatron/core/transformer/README.md index 1306c783009..22221b92de7 100644 --- a/megatron/core/transformer/README.md +++ b/megatron/core/transformer/README.md @@ -26,6 +26,7 @@ Fine-grained Activation Offloading * Support A2A Overlap ## Known issues +* We explicitly resize some tensors to 0 to release the memory space immediately, which sometimes leads to illegal memory access. Please remove the released tensors in `group_prefetch_offload_commit` if you run into the issue. 
## WIP items * Code refactor diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index e15411f27f2..627ace53475 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -980,7 +980,7 @@ def glu(x): if self.activation_recompute: self.activation_checkpoint.discard_output_and_register_recompute(output) if self.offload_moe_act: - output, = group_prefetch_offload_commit(output, release_tensors=[fc1_output]) + output, = group_prefetch_offload_commit(output, release_tensors=[]) offload_context = contextlib.nullcontext() diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 23f3e3e9085..046166df2da 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -1028,9 +1028,7 @@ def __post_init__(self): ) if isinstance(self.moe_layer_freq, int): - raise ValueError( - "moe_layer_freq cannot be an integer when offload_modules is set." - ) + assert self.moe_layer_freq == 1, "moe_layer_freq cannot be an integer other than 1 when offload_modules is set." 
elif isinstance(self.moe_layer_freq, list): if 0 in self.moe_layer_freq: warnings.warn( From c26eb8aadc20a9f953daa7cde0dabdf277a5ba5d Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 26 Sep 2025 01:07:53 -0700 Subject: [PATCH 31/35] temp fix to enable --overlap-grad-reduce Signed-off-by: Hongbin Liu --- megatron/training/arguments.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 6a7670e4786..76aac0ea63e 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1139,8 +1139,8 @@ def validate_args(args, defaults={}): "when enabling delay_wgrad_compute" ) - if args.fine_grained_activation_offloading: - assert not args.overlap_grad_reduce, "overlap_grad_reduce is not supported with fine_grained_activation_offloading" + # if args.fine_grained_activation_offloading: + # assert not args.overlap_grad_reduce, "overlap_grad_reduce is not supported with fine_grained_activation_offloading" if args.mtp_num_layers: assert not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP)." 
From 0d845de67f17524c518edff2ffd1ffdf56e256c4 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Sun, 28 Sep 2025 23:49:47 -0700 Subject: [PATCH 32/35] fix to enable --overlap-grad-reduce and allow placing loss layer only into the last stage Signed-off-by: Hongbin Liu --- megatron/core/extensions/transformer_engine.py | 6 +++--- megatron/core/models/gpt/gpt_model.py | 4 ++++ megatron/core/transformer/cpu_offload.py | 17 +++++++++++++++-- megatron/core/transformer/transformer_config.py | 9 +++++++++ megatron/core/utils.py | 4 ++-- megatron/training/arguments.py | 3 --- megatron/training/training.py | 16 +++++++++++++++- 7 files changed, 48 insertions(+), 11 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 4864af8e5f8..940b6a8dd39 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -296,7 +296,7 @@ def __init__( raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.") if self.config.fine_grained_activation_offloading: te_version = get_te_version() - if te_version == PkgVersion("2.8.0.dev0+74a5f77b"): + if te_version == PkgVersion("2.8.0.dev0+93a67af"): extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading else: raise ValueError( @@ -515,7 +515,7 @@ def __init__( if self.config.fine_grained_activation_offloading: te_version = get_te_version() - if te_version == PkgVersion("2.8.0.dev0+74a5f77b"): + if te_version == PkgVersion("2.8.0.dev0+93a67af"): extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading else: raise ValueError( @@ -1118,7 +1118,7 @@ def __init__( ) if self.config.fine_grained_activation_offloading: te_version = get_te_version() - if te_version == PkgVersion("2.8.0.dev0+74a5f77b"): + if te_version == PkgVersion("2.8.0.dev0+93a67af"): extra_kwargs["fine_grained_activation_offloading"] = 
self.config.fine_grained_activation_offloading else: raise ValueError( diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 74c6e1cd101..001463ca488 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -346,12 +346,16 @@ def initialize_model_chunk_offload_handler(self): num_layers = self.decoder.num_layers_per_pipeline_rank if self.mtp_process: num_layers = num_layers + self.config.mtp_num_layers + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + pp_size = parallel_state.get_pipeline_model_parallel_world_size() + last_stage_is_loss = (pp_rank == pp_size - 1) and self.config.last_vp_stage_is_loss # TODO: will be an issue when dense layer is placed across different pipeline stages PipelineOffloadManager.get_instance().reset_chunk_handler( num_layers, self.vp_stage, self.config.fine_grained_activation_offloading, self.decoder.num_dense_layer, + last_stage_is_loss, ) def forward( diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py index 30d6a26e9ec..275260fccb5 100644 --- a/megatron/core/transformer/cpu_offload.py +++ b/megatron/core/transformer/cpu_offload.py @@ -78,7 +78,7 @@ def flush(self): self._stages[i] = [] def push(self, handler): - print_rank("pushing handler") + print_rank(f"pushing handler {handler}") self._queue.append(handler) def pop(self): @@ -100,12 +100,23 @@ def front(self): def size(self): return len(self._queue) - def reset_chunk_handler(self, num_layer, vp_stage, offload=True, num_dense_layer=0): + def reset_chunk_handler(self, num_layer, vp_stage, offload=True, num_dense_layer=0, last_stage_is_loss=False): if vp_stage is None: cur_vpp_rank = 0 else: cur_vpp_rank = vp_stage + if last_stage_is_loss: + from megatron.core import parallel_state + vpp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + # skip the last stage + if cur_vpp_rank == vpp_size - 1: + return + # reduce the vpp 
size + if self._vpp == vpp_size: + self._vpp -= 1 + self._stages = self._stages[:-1] + first_last_vpp_rank = self._first_last_vpp_rank # rewind if cur_vpp_rank == self._vpp - 1: @@ -302,6 +313,7 @@ def bulk_reload_group(self, group_to_reload): for tensor_label, state in self._tensor_tag_to_state.items(): group_id, _ = tensor_label if group_id == group_id_to_reload: + print_rank(f"tensor_label {tensor_label}") found_reload_group = True event = self._offload_events[-1] if isinstance(state, tuple): @@ -375,6 +387,7 @@ def bulk_reload(self): if len(self._groups_to_reload) > 0 and self._layer_index > self.num_dense_layer: # load next layer if self.bulk_reload_group(self._groups_to_reload[-1]): + print_rank(f"bulk_reload_group {self._groups_to_reload}") self._groups_to_reload.pop() else: # load the last layer of one backward chunk in advance diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 046166df2da..561226d1ec2 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -717,6 +717,8 @@ class TransformerConfig(ModelParallelConfig): "expert_fc1": offload the input of the expert fc1 part. "moe_act": offload the input of the moe act part. """ + last_vp_stage_is_loss: bool = False + """If True, the last virtual pipeline stage is the loss stage.""" def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. 
@@ -1217,6 +1219,13 @@ def __post_init__(self): f"{self.virtual_pipeline_model_parallel_size}" ) + if len(self.offload_modules) > 0: + if self.pipeline_model_parallel_layout is not None: + from megatron.core.transformer.pipeline_parallel_layer_layout import LayerType + if (len(self.pipeline_model_parallel_layout.layout[-1][-1]) == 1 and + self.pipeline_model_parallel_layout.layout[-1][-1][0] is LayerType.loss): + self.last_vp_stage_is_loss = True + if self.apply_query_key_layer_scaling: self.attention_softmax_in_fp32 = True diff --git a/megatron/core/utils.py b/megatron/core/utils.py index e66dab8abd8..7e1b9d9251a 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -136,7 +136,7 @@ def wrapped_func(*args, **kwargs): if config.is_experimental_enabled() is not True: raise ExperimentalNotEnabledError(f"Flag config.ENABLE_EXPERIMENTAL not enabled.") - logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.") + # logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.") return func(*args, **kwargs) @@ -213,7 +213,7 @@ def guard(super: super, attr: str): f"Flag config.ENABLE_EXPERIMENTAL not enabled." 
) - logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.") + # logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.") return super.__getattribute__(attr) class ClassInterceptor(type): diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 76aac0ea63e..5b4ab369840 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1138,9 +1138,6 @@ def validate_args(args, defaults={}): "disabling gradient_accumulation_fusion is only supported with TE >= 2.7.0 " "when enabling delay_wgrad_compute" ) - - # if args.fine_grained_activation_offloading: - # assert not args.overlap_grad_reduce, "overlap_grad_reduce is not supported with fine_grained_activation_offloading" if args.mtp_num_layers: assert not args.use_legacy_models, "The legacy Megatron models does not support Multi-Token Prediction (MTP)." diff --git a/megatron/training/training.py b/megatron/training/training.py index f4a3eb9bef6..bc141e0c4dd 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -589,6 +589,11 @@ def pretrain( args = get_args() timers = get_timers() + if args.profile and torch.distributed.get_rank() in args.profile_ranks: + torch.cuda.memory._record_memory_history( + max_entries=1000000, + ) + if args.log_progress: append_to_progress_log("Starting job") @@ -1625,7 +1630,7 @@ def training_log( total_loss_dict[skipped_iters_key] = 0 total_loss_dict[nan_iters_key] = 0 print_rank_last(log_string) - if report_memory_flag: + if report_memory_flag and iteration == 10: # Report memory after optimizer state has been initialized. if torch.distributed.get_rank() == 0: num_microbatches = get_num_microbatches() @@ -2173,6 +2178,15 @@ def get_e2e_base_metrics(): # Run training iterations till done. 
buffered_rollouts = None while iteration < args.train_iters: + if args.profile and torch.distributed.get_rank() in args.profile_ranks and iteration == 3: + try: + comment = os.getenv("COMMENT") + model_name = os.getenv("MODEL") + memory_snapshot_path = f"/lustre/fsw/coreai_devtech_all/hongbinl/cpu_offloading/megatron-moe-scripts/pyt_profile/{model_name}_{comment}" + torch.cuda.memory._dump_snapshot(f"{memory_snapshot_path}.pickle") + print_rank_0(f"Captured memory snapshot at {memory_snapshot_path}.pickle") + except Exception as e: + print_rank_0(f"Failed to capture memory snapshot {e}") if args.profile and torch.distributed.get_rank() in args.profile_ranks: if args.use_pytorch_profiler: prof.step() From 17d14d88bf7dcf14f4879f7e31c8b8a98ee5ea14 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 9 Oct 2025 01:55:58 -0700 Subject: [PATCH 33/35] 1. remove TE version checking during the development stage; 2. refine the sync mechanism; 3. remove mark_layer_start; 4. support activation offload for dense layer; 5. 
support cuda graph but the cuda graph scope cannot contain the offloading module Signed-off-by: Hongbin Liu --- .../core/extensions/transformer_engine.py | 48 ++++---- megatron/core/transformer/attention.py | 8 +- megatron/core/transformer/cpu_offload.py | 112 +++++++----------- megatron/core/transformer/moe/experts.py | 4 +- .../transformer/multi_latent_attention.py | 8 +- .../core/transformer/transformer_block.py | 6 + .../core/transformer/transformer_layer.py | 9 +- 7 files changed, 86 insertions(+), 109 deletions(-) diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 940b6a8dd39..86c717359e3 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -294,14 +294,14 @@ def __init__( extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute else: raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.") - if self.config.fine_grained_activation_offloading: - te_version = get_te_version() - if te_version == PkgVersion("2.8.0.dev0+93a67af"): - extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading - else: - raise ValueError( - f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." - ) + # if self.config.fine_grained_activation_offloading: + # te_version = get_te_version() + # if te_version == PkgVersion("2.8.0.dev0+93a67af"): + extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading + # else: + # raise ValueError( + # f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." 
+ # ) if ( self.config.tp_comm_overlap and tp_comm_buffer_name @@ -513,14 +513,14 @@ def __init__( else: raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.") - if self.config.fine_grained_activation_offloading: - te_version = get_te_version() - if te_version == PkgVersion("2.8.0.dev0+93a67af"): - extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading - else: - raise ValueError( - f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." - ) + # if self.config.fine_grained_activation_offloading: + # te_version = get_te_version() + # if te_version == PkgVersion("2.8.0.dev0+93a67af"): + extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading + # else: + # raise ValueError( + # f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." + # ) # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` if is_te_min_version("0.11.0"): @@ -1116,14 +1116,14 @@ def __init__( raise RuntimeError( "Only TE with version >=2.3.0 supports delay_wgrad_compute now." ) - if self.config.fine_grained_activation_offloading: - te_version = get_te_version() - if te_version == PkgVersion("2.8.0.dev0+93a67af"): - extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading - else: - raise ValueError( - f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." - ) + # if self.config.fine_grained_activation_offloading: + # te_version = get_te_version() + # if te_version == PkgVersion("2.8.0.dev0+93a67af"): + extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading + # else: + # raise ValueError( + # f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." 
+ # ) extra_kwargs["ub_name"] = tp_comm_buffer_name diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index f565f1abee7..7a283ba6b46 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -692,11 +692,11 @@ def forward( if self.offload_qkv_linear: if not hidden_states.is_contiguous(): hidden_states = hidden_states.contiguous() - hidden_states = group_prefetch_offload_start(hidden_states) + hidden_states = group_prefetch_offload_start(hidden_states, name="qkv_linear") hidden_states.offloading_activation = True with PipelineOffloadManager.get_instance(): query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) - query, key, value = group_prefetch_offload_commit(query, key, value, release_tensors=[hidden_states]) + query, key, value = group_prefetch_offload_commit(query, key, value, name="qkv_linear", release_tensors=[hidden_states]) else: query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) nvtx_range_pop(suffix="qkv") @@ -863,7 +863,7 @@ def forward( ) core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') if self.offload_core_attention and self.training: - core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[query, key, value]) + core_attn_out, = group_prefetch_offload_commit(core_attn_out, name="core_attn", release_tensors=[query, key, value]) offload_context = contextlib.nullcontext() if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': @@ -886,7 +886,7 @@ def forward( with offload_context: output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: - output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out]) + output, bias = group_prefetch_offload_commit(output, bias, name="attn_proj", release_tensors=[core_attn_out]) offload_context = contextlib.nullcontext() nvtx_range_pop(suffix="linear_proj") diff --git 
a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py index 275260fccb5..01bc9bb65cf 100644 --- a/megatron/core/transformer/cpu_offload.py +++ b/megatron/core/transformer/cpu_offload.py @@ -122,11 +122,7 @@ def reset_chunk_handler(self, num_layer, vp_stage, offload=True, num_dense_layer if cur_vpp_rank == self._vpp - 1: self.flush() first_last_vpp_rank = first_last_vpp_rank and (cur_vpp_rank == self._vpp - 1) - # If the model chunk contains only the dense layers, initialize a null chunk handler. - if num_layer <= num_dense_layer: - cur_chunk = NullChunkOffloadHandler(num_layer, first_last_vpp_rank, offload, num_dense_layer) - else: - cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload, num_dense_layer) + cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload) # save for latter push self._stages[cur_vpp_rank].append(cur_chunk) if cur_vpp_rank == self._vpp - 1: @@ -136,6 +132,9 @@ def reset_chunk_handler(self, num_layer, vp_stage, offload=True, num_dense_layer self._cur_forward_chunk = cur_chunk cur_chunk.vpp_rank = cur_vpp_rank + def set_last_layer(self, is_last_layer): + self._cur_forward_chunk.is_last_layer = is_last_layer + def cur_forward_chunk(self): return self._cur_forward_chunk @@ -202,7 +201,7 @@ def reload(state, non_blocking=None): non_blocking = cpu_backup.is_pinned() return cpu_backup.to(dev, non_blocking=non_blocking) - def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, num_dense_layer=0): + def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True): self._num_layers = num_layer # Data Structure to maintain reference to activation tensors self._tensor_tag_to_state = {} @@ -213,7 +212,6 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, num_dense_l self._offloaded_group_index = 0 self._groups_to_offload = [] self._groups_to_reload = [] - self.num_dense_layer = num_dense_layer self._layer_index = 0 self._tensor_count_current_group = 0 
self.multi_input_offload_count = False @@ -222,7 +220,8 @@ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True, num_dense_l self.torch_tensor_count = 0 self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream - self._offload_events = [] + self._offload_events = {} + self._reload_events = {} self.do_offload = offload self.is_last_layer = False @@ -288,18 +287,31 @@ def bulk_offload_group(self, group_to_offload): print_rank(f"group_to_offload {group_to_offload}") assert not isinstance(state, tuple) tensor_on_device = state - # if offload, return the reference to cpu copy - # print_rank(f"tensor_need_offloading_checker {self.tensor_need_offloading_checker(tensor_on_device)}") - # print_rank(f"tensor_on_device {tensor_on_device.shape}") if self.tensor_need_offloading_checker(tensor_on_device): state = self.offload(tensor_on_device) + # TODO: check if we really need it. + # Record the last offloading event for this group, + # which is used to avoid reloading before offloading. 
+ event = torch.cuda.Event() + event.record(self.d2h_stream) + self._offload_events[name] = event tensor_on_device.record_stream(self.d2h_stream) - # self.offload_count_per_layer[group_to_offload] += 1 self._tensor_tag_to_state[tensor_tag] = state - # self._offloaded_group_count = group_to_offload + 1 print_rank("exit bulk_offload_group") torch.cuda.nvtx.range_pop() + def get_offload_event(self, name): + if name in self._offload_events: + return self._offload_events[name] + else: + return None + + def get_reload_event(self, name): + if name in self._reload_events: + return self._reload_events[name] + else: + return None + def bulk_reload_group(self, group_to_reload): """Bulk reload group.""" print_rank("bulk_reload_group") @@ -315,16 +327,16 @@ def bulk_reload_group(self, group_to_reload): if group_id == group_id_to_reload: print_rank(f"tensor_label {tensor_label}") found_reload_group = True - event = self._offload_events[-1] + event = self.get_offload_event(name) if isinstance(state, tuple): # make sure the tensor is already offloaded to cpu before reloading it. 
torch.cuda.current_stream().wait_event(event) recovered_tensor = self.reload(state) + event.record(self.h2d_stream) + self._reload_events[name] = event print_rank(f"recovered_tensor {recovered_tensor.shape}") self._tensor_tag_to_state[tensor_label] = recovered_tensor torch.cuda.nvtx.range_pop() - if found_reload_group: - self._offload_events.pop() return found_reload_group def pre_reload_last_layer(self): @@ -334,7 +346,6 @@ def pre_reload_last_layer(self): return assert not self._is_first_last_vpp_chunk print_rank(f"len(self._groups_to_reload) {len(self._groups_to_reload)}") - print_rank(f"len(self._offload_events) {len(self._offload_events)}") if len(self._groups_to_reload) > 0: if self.bulk_reload_group(self._groups_to_reload[-1]): self._groups_to_reload.pop() @@ -343,8 +354,6 @@ def should_bulk_offload(self): """Check if the chunk should be offloaded.""" if not self.do_offload: return False - if self._layer_index < self.num_dense_layer: - return False # first backward chunk if self.is_first_last_layer(): return False @@ -361,12 +370,9 @@ def bulk_offload(self, release_tensors): print_rank("bulk_offload") if self.should_bulk_offload(): group_to_offload = self._groups_to_offload.pop() + name = group_to_offload[1] self._groups_to_reload.append(group_to_offload) self.bulk_offload_group(group_to_offload) - event = torch.cuda.Event() - event.record(self.d2h_stream) - # TODO: check if we really need it. 
- self._offload_events.append(event) if len(release_tensors) > 0: cur_stream = torch.cuda.current_stream() for release_tensor in release_tensors: @@ -384,7 +390,7 @@ def on_group_commit_forward(self, release_tensors): def bulk_reload(self): print_rank("bulk_reload") - if len(self._groups_to_reload) > 0 and self._layer_index > self.num_dense_layer: + if len(self._groups_to_reload) > 0: # load next layer if self.bulk_reload_group(self._groups_to_reload[-1]): print_rank(f"bulk_reload_group {self._groups_to_reload}") @@ -395,7 +401,7 @@ def bulk_reload(self): if next_backward_chunk is not None: next_backward_chunk.pre_reload_last_layer() - def on_group_commit_backward(self): + def on_group_commit_backward(self, name): """Prepare for reloadingthe next group of tensors.""" print_rank("on_group_commit_backward") cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() @@ -403,16 +409,15 @@ def on_group_commit_backward(self): PipelineOffloadManager.get_instance().pop() cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() assert cur_backward_chunk is self - # self.h2d_stream.wait_stream(torch.cuda.current_stream()) + # make sure the reloading jobs for current computation are done. 
+ event = self.get_reload_event(name) + if event is not None: + torch.cuda.current_stream().wait_event(event) self._offloaded_group_index = self._offloaded_group_index - 1 def on_group_start_forward(self, name): """Prepare for offloading the next group of tensors.""" print_rank(f"on_group_start_forward {self._layer_index} {self._num_layers}") - if self._layer_index == self._num_layers: - self.is_last_layer = True - else: - self.is_last_layer = False self._offloaded_group_index = self._offloaded_group_index + 1 self._tensor_count_current_group = 0 self._groups_to_offload.append((self._offloaded_group_index, name)) @@ -422,18 +427,6 @@ def on_group_start_backward(self): print_rank("on_group_start_backward") self.h2d_stream.wait_stream(torch.cuda.current_stream()) self.bulk_reload() - - def on_layer_start_forward(self): - """Increment the layer index.""" - print_rank("on_layer_start_forward") - self._layer_index = self._layer_index + 1 - - def on_layer_start_backward(self): - """When the bprop of one layer finishes, make sure the reloading jobs on h2d stream are done. 
- """ - print_rank("on_layer_start_backward") - self._layer_index = self._layer_index - 1 - torch.cuda.current_stream().wait_stream(self.h2d_stream) class NullChunkOffloadHandler(ChunkOffloadHandler): pass @@ -451,11 +444,13 @@ def forward(ctx, *args): print_rank("GroupCommitFunction forward") release_tensors = args[-1] - cpu_offload_handler = args[-2] - tensor = args[:-2] + name = args[-2] + cpu_offload_handler = args[-3] + tensor = args[:-3] if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): cpu_offload_handler.on_group_commit_forward(release_tensors) ctx.cpu_offload_handler = cpu_offload_handler + ctx.name = name # return the identical tensor return tensor @@ -467,18 +462,18 @@ def backward(ctx, *grad_output): cpu_offload_handler = ctx.cpu_offload_handler if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): - cpu_offload_handler.on_group_commit_backward() - return grad_output + (None, None) + cpu_offload_handler.on_group_commit_backward(ctx.name) + return grad_output + (None, None, None) -def group_prefetch_offload_commit(*tensor, release_tensors=[]): +def group_prefetch_offload_commit(*tensor, name, release_tensors=[]): """Specify the tensors to be released after offloading. release_tensors is a list of tensors to be released after offloading. The tensors will be untyped_storage().resize_(0) after offloading. Note: specify the tensors only when they are not automatically released by torch gc. 
""" cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() - return GroupCommitFunction.apply(*tensor, cur_forward_chunk, release_tensors) + return GroupCommitFunction.apply(*tensor, cur_forward_chunk, name, release_tensors) class GroupStartFunction(torch.autograd.Function): @@ -512,24 +507,3 @@ def backward(ctx, grad_output): def group_prefetch_offload_start(tensor, name=None): cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() return GroupStartFunction.apply(tensor, cur_forward_chunk, name) - -class MarkLayerStartFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, tensor, cpu_offload_handler): - ctx.cpu_offload_handler = cpu_offload_handler - print_rank("MarkLayerStartFunction forward") - if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): - cpu_offload_handler.on_layer_start_forward() - return tensor - - @staticmethod - def backward(ctx, grad_output): - print_rank("MarkLayerStartFunction backward") - cpu_offload_handler = ctx.cpu_offload_handler - if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): - cpu_offload_handler.on_layer_start_backward() - return grad_output, None, None - -def mark_layer_start(tensor): - cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() - return MarkLayerStartFunction.apply(tensor, cur_forward_chunk) \ No newline at end of file diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 627ace53475..eddf61e6f8a 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -888,7 +888,7 @@ def forward( permuted_local_hidden_states, tokens_per_expert ) if self.offload_expert_fc1: - fc1_output, bias_parallel = group_prefetch_offload_commit(fc1_output, bias_parallel, release_tensors=[]) + fc1_output, bias_parallel = group_prefetch_offload_commit(fc1_output, bias_parallel, name="expert_fc1", release_tensors=[]) offload_context = 
contextlib.nullcontext() def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): @@ -980,7 +980,7 @@ def glu(x): if self.activation_recompute: self.activation_checkpoint.discard_output_and_register_recompute(output) if self.offload_moe_act: - output, = group_prefetch_offload_commit(output, release_tensors=[]) + output, = group_prefetch_offload_commit(output, name="moe_act", release_tensors=[]) offload_context = contextlib.nullcontext() diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 1985cbee4cc..3a16f0b5377 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -305,9 +305,9 @@ def forward( # Only rearrange if not in absorption mode (Flash MLA handles format correctly) if not inference_context.is_decode_only(): core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') - if self.offload_core_attention and self.training: - core_attn_out, = group_prefetch_offload_commit(core_attn_out, release_tensors=[query, key, value]) - offload_context = contextlib.nullcontext() + if self.offload_core_attention and self.training: + core_attn_out, = group_prefetch_offload_commit(core_attn_out, name="core_attn", release_tensors=[query, key, value]) + offload_context = contextlib.nullcontext() # We are doing absorption with cache mla latents and decode mode. 
if self.cache_mla_latents and inference_context.is_decode_only(): @@ -340,7 +340,7 @@ def forward( with offload_context: output, bias = self.linear_proj(core_attn_out) if self.offload_attn_proj: - output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out]) + output, bias = group_prefetch_offload_commit(output, bias, name="attn_proj", release_tensors=[core_attn_out]) offload_context = contextlib.nullcontext() return output, bias diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index 3ca054787e7..867b3689a1d 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -33,6 +33,7 @@ get_pg_rank, make_viewless_tensor, ) +from megatron.core.transformer.cpu_offload import PipelineOffloadManager try: import transformer_engine.pytorch as te # pylint: disable=unused-import @@ -645,6 +646,11 @@ def forward( inner_quantization_context = nullcontext() else: inner_quantization_context = nullcontext() + + if l_no == self.num_layers_per_pipeline_rank - 1: + PipelineOffloadManager.get_instance().set_last_layer(True) + else: + PipelineOffloadManager.get_instance().set_last_layer(False) with self.offload_context, inner_quantization_context: hidden_states, context = layer( diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 855a6a4f75e..ed5ace3fa6c 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -36,7 +36,6 @@ PipelineOffloadManager, group_prefetch_offload_start, group_prefetch_offload_commit, - mark_layer_start, ) logger = logging.getLogger(__name__) @@ -482,8 +481,6 @@ def forward(self, *args, **kwargs): # this is only used to uniquely identify decode and non-decode cuda graph # runners in the cuda graph manager kwargs.pop("dynamic_inference_decode_only", None) - if 
self.config.fine_grained_activation_offloading: - kwargs["hidden_states"] = mark_layer_start(kwargs["hidden_states"]) hidden_states, context = self._forward_attention(*args, **kwargs) output = self._forward_mlp(hidden_states, kwargs.get("inference_context", None)) return output, context @@ -567,7 +564,7 @@ def _forward_attention( sequence_len_offset=sequence_len_offset, ) if self.offload_self_attn: - attention_output_with_bias, = group_prefetch_offload_commit(attention_output_with_bias, release_tensors=[input_layernorm_output]) + attention_output_with_bias, = group_prefetch_offload_commit(attention_output_with_bias, name="self_attn", release_tensors=[input_layernorm_output]) offload_context = contextlib.nullcontext() nvtx_range_pop(suffix="self_attention") @@ -588,7 +585,7 @@ def _forward_attention( nvtx_range_pop(suffix="self_attn_bda") if self.offload_attn_norm: - hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual]) + hidden_states, = group_prefetch_offload_commit(hidden_states, name="attn_norm", release_tensors=[residual]) offload_context = contextlib.nullcontext() # Residual connection. @@ -705,7 +702,7 @@ def _forward_mlp(self, hidden_states, inference_context=None): ) nvtx_range_pop(suffix="mlp_bda") if self.offload_mlp_norm: - hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual]) + hidden_states, = group_prefetch_offload_commit(hidden_states, name="mlp_norm", release_tensors=[residual]) offload_context = contextlib.nullcontext() # Jit compiled function creates 'view' tensor. 
This tensor From 22dabcf8e454b6365efcb1b84c8d90a720425946 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 9 Oct 2025 03:42:19 -0700 Subject: [PATCH 34/35] minor fix Signed-off-by: Hongbin Liu --- megatron/core/transformer/cpu_offload.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py index 01bc9bb65cf..8dd5f139884 100644 --- a/megatron/core/transformer/cpu_offload.py +++ b/megatron/core/transformer/cpu_offload.py @@ -122,7 +122,11 @@ def reset_chunk_handler(self, num_layer, vp_stage, offload=True, num_dense_layer if cur_vpp_rank == self._vpp - 1: self.flush() first_last_vpp_rank = first_last_vpp_rank and (cur_vpp_rank == self._vpp - 1) - cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload) + # If the model chunk contains only the dense layers, initialize a null chunk handler. + if num_layer <= num_dense_layer: + cur_chunk = NullChunkOffloadHandler(num_layer, first_last_vpp_rank, offload) + else: + cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload) # save for latter push self._stages[cur_vpp_rank].append(cur_chunk) if cur_vpp_rank == self._vpp - 1: From 5c020240dd830bb32ceec033d765022bd3f5df31 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Thu, 9 Oct 2025 06:02:45 -0700 Subject: [PATCH 35/35] update README Signed-off-by: Hongbin Liu --- megatron/core/transformer/README.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/megatron/core/transformer/README.md b/megatron/core/transformer/README.md index 22221b92de7..5c16c5f85be 100644 --- a/megatron/core/transformer/README.md +++ b/megatron/core/transformer/README.md @@ -5,6 +5,10 @@ Fine-grained Activation Offloading

NVIDIA, rednote

+# What is Fine-grained Activation Offloading?
+
+Memory capacity is becoming more and more important with the rise of extremely sparse MoE models like DeepSeek-V3 and Qwen3-235B. Fine-grained Activation Offloading aims to offload activations at the granularity of specific modules, so that we can calibrate the amount of offloaded activations to maximize the training throughput.
+
 # Quick Start
 
 ```bash
@@ -24,6 +28,8 @@ Fine-grained Activation Offloading
 * Support MTP
 * Support mixed dense & moe layer
 * Support A2A Overlap
+* Support CUDA Graph
+  * (Temporary) the CUDA graph scope cannot contain the offloading modules
 
 ## Known issues
 * We explicitly resize some tensors to 0 to release the memory space immediately, which sometimes leads to illegal memory access. Please remove the released tensors in `group_prefetch_offload_commit` if you run into the issue.