diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index 5bf25adc0dd..86c717359e3 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -294,6 +294,14 @@ def __init__( extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute else: raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.") + # if self.config.fine_grained_activation_offloading: + # te_version = get_te_version() + # if te_version == PkgVersion("2.8.0.dev0+93a67af"): + extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading + # else: + # raise ValueError( + # f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." + # ) if ( self.config.tp_comm_overlap and tp_comm_buffer_name @@ -505,6 +513,15 @@ def __init__( else: raise RuntimeError("Only TE with version >=2.3.0 supports delay_wgrad_compute now.") + # if self.config.fine_grained_activation_offloading: + # te_version = get_te_version() + # if te_version == PkgVersion("2.8.0.dev0+93a67af"): + extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading + # else: + # raise ValueError( + # f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." + # ) + # Only Transformer-Engine version >= 0.11.0 supports `RMSNorm` if is_te_min_version("0.11.0"): extra_kwargs["normalization"] = self.config.normalization @@ -1099,6 +1116,14 @@ def __init__( raise RuntimeError( "Only TE with version >=2.3.0 supports delay_wgrad_compute now." ) + # if self.config.fine_grained_activation_offloading: + # te_version = get_te_version() + # if te_version == PkgVersion("2.8.0.dev0+93a67af"): + extra_kwargs["fine_grained_activation_offloading"] = self.config.fine_grained_activation_offloading + # else: + # raise ValueError( + # f"Transformer Engine v{te_version} does not support fine_grained_activation_offloading." + # ) extra_kwargs["ub_name"] = tp_comm_buffer_name diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py index 9f5018c26ea..d0677dd9d6c 100644 --- a/megatron/core/model_parallel_config.py +++ b/megatron/core/model_parallel_config.py @@ -5,6 +5,7 @@ import torch +import warnings @dataclass class ModelParallelConfig: @@ -314,7 +315,7 @@ class ModelParallelConfig: rank 0 | 0 1 2 0 1 2 3 4 3 4 rank 1 | 0 1 2 0 1 2 3 4 3 4 """ - + ################### # CPU Offloading ################### diff --git a/megatron/core/models/gpt/fine_grained_callables.py b/megatron/core/models/gpt/fine_grained_callables.py index fbecc047682..1552168a53a 100644 --- a/megatron/core/models/gpt/fine_grained_callables.py +++ b/megatron/core/models/gpt/fine_grained_callables.py @@ -16,6 +16,12 @@ get_mtp_layer_offset, ) from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor +from megatron.core.transformer.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, + mark_layer_start, +) def weak_method(method): @@ -331,6 +337,8 @@ def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor): """ Performs same attnention forward logic as GPT Model. 
""" + if layer.config.fine_grained_activation_offloading: + hidden_states = mark_layer_start(hidden_states) hidden_states, _ = layer._forward_attention( hidden_states=hidden_states, attention_mask=node.chunk_state.attention_mask, @@ -347,13 +355,20 @@ def submodule_post_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor) Run forward pass for computations between attention and dispatch: pre mlp layernorm->router->dispatch preprocess """ + offload_context = nullcontext() + if layer.offload_mlp_norm: + hidden_states = group_prefetch_offload_start(hidden_states, name="mlp_norm") + offload_context = PipelineOffloadManager.get_instance() if layer.recompute_pre_mlp_layernorm: layer.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint( - layer.pre_mlp_layernorm, hidden_states - ) + with offload_context: + pre_mlp_layernorm_output = layer.pre_mlp_norm_checkpoint.checkpoint( + layer.pre_mlp_layernorm, hidden_states + ) else: - pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states) + with offload_context: + pre_mlp_layernorm_output = layer.pre_mlp_layernorm(hidden_states) + offload_context = nullcontext() local_tokens, probs, _ = layer.mlp.router_and_preprocess(pre_mlp_layernorm_output) @@ -433,6 +448,8 @@ def submodule_combine_forward( hidden_states = layer.mlp_bda(layer.training, layer.config.bias_dropout_fusion)( mlp_output_with_bias, residual, layer.hidden_dropout ) + if layer.offload_mlp_norm: + hidden_states, = group_prefetch_offload_commit(hidden_states, release_tensors=[residual]) output = make_viewless_tensor( inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True ) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index e74f93bd1aa..001463ca488 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -33,6 +33,7 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import WrappedTensor, deprecate_inference_params +from megatron.core.transformer.cpu_offload import PipelineOffloadManager class GPTModel(LanguageModule): @@ -341,6 +342,22 @@ def _preprocess( return decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset + def initialize_model_chunk_offload_handler(self): + num_layers = self.decoder.num_layers_per_pipeline_rank + if self.mtp_process: + num_layers = num_layers + self.config.mtp_num_layers + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + pp_size = parallel_state.get_pipeline_model_parallel_world_size() + last_stage_is_loss = (pp_rank == pp_size - 1) and self.config.last_vp_stage_is_loss + # TODO: will be an issue when dense layer is placed across different pipeline stages + PipelineOffloadManager.get_instance().reset_chunk_handler( + num_layers, + self.vp_stage, + self.config.fine_grained_activation_offloading, + self.decoder.num_dense_layer, + last_stage_is_loss, + ) + def forward( self, input_ids: Tensor, @@ -366,6 +383,8 @@ def forward( runtime_gather_output (bool): Gather output at runtime. Default None means `parallel_output` arg in the constructor will be used. 
""" + if self.config.fine_grained_activation_offloading: + self.initialize_model_chunk_offload_handler() inference_context = deprecate_inference_params(inference_context, inference_params) @@ -627,6 +646,8 @@ def build_schedule_plan( TransformerModelChunkSchedulePlan: The model chunk schedule plan. """ + if self.config.fine_grained_activation_offloading: + self.initialize_model_chunk_offload_handler() from ..common.model_chunk_schedule_plan import TransformerModelChunkSchedulePlan return TransformerModelChunkSchedulePlan( diff --git a/megatron/core/pipeline_parallel/schedules.py b/megatron/core/pipeline_parallel/schedules.py index c0d8af394d2..179b7a4f014 100644 --- a/megatron/core/pipeline_parallel/schedules.py +++ b/megatron/core/pipeline_parallel/schedules.py @@ -19,6 +19,7 @@ from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.cuda_graphs import create_cudagraphs from megatron.core.transformer.moe.router import MoEAuxLossAutoScaler +from megatron.core.transformer.cpu_offload import PipelineOffloadManager from megatron.core.utils import ( drain_embedding_wgrad_compute, get_attr_wrapped_model, @@ -558,6 +559,9 @@ def forward_backward_no_pipelining( adjust_tensor_shapes_fn is None ), "adjust_tensor_shapes_fn is not supported for non-pipeline-parallel schedule" + if not forward_only: + PipelineOffloadManager.get_instance().reset() + config = get_model_config(model) if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) @@ -898,6 +902,9 @@ def forward_backward_pipelining_with_interleaving( adjust_tensor_shapes_fn is None ), "adjust_tensor_shapes_fn is not supported for interleaved pipeline parallelism" + if not forward_only: + PipelineOffloadManager.get_instance().reset() + if config.overlap_p2p_comm and config.batch_p2p_comm: raise ValueError("Can not use both overlap_p2p_comm and batch_p2p_comm") @@ -2043,6 +2050,9 @@ def forward_backward_pipelining_without_interleaving( if config.timers is not None: config.timers('forward-backward', log_level=1).start(barrier=config.barrier_with_L1_time) + if not forward_only: + PipelineOffloadManager.get_instance().reset() + # Disable async grad reductions no_sync_func = config.no_sync_func if no_sync_func is None: diff --git a/megatron/core/tensor_parallel/random.py b/megatron/core/tensor_parallel/random.py index 54cac0e41e3..9d96dfaf815 100644 --- a/megatron/core/tensor_parallel/random.py +++ b/megatron/core/tensor_parallel/random.py @@ -510,10 +510,11 @@ def forward(ctx, run_function, checkpoint_without_output_obj, *args): @staticmethod def backward(ctx, *args): """Backward pass.""" - inputs = ctx.saved_tensors + inputs = ctx.inputs outputs = ctx.outputs torch.autograd.backward(outputs, args) ctx.outputs = None + ctx.inputs = None grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp for inp in inputs) return (None, None) + grads @@ -573,8 +574,19 @@ def _recompute(self, _): recompute_ctx = contextlib.nullcontext() fp8_ctx = contextlib.nullcontext() + inputs = self.ctx.saved_tensors + # do not know why, if saved_tensors is handled by saved_tensor_hook, grad of inputs will be None (not nan) + # detach it to bypass + def detach(t): + if isinstance(t, torch.Tensor): + requires_grad = t.requires_grad + t = t.detach() + t.requires_grad_(requires_grad) + return t + + inputs = tuple(detach(t) for t in inputs) with torch.enable_grad(), fp8_ctx, recompute_ctx: - outputs = self.run_function(*self.ctx.saved_tensors) + outputs = 
self.run_function(*inputs) self.run_function = None self.rng_states = None @@ -590,6 +602,7 @@ def _recompute(self, _): output.untyped_storage().copy_(recomputation_output.untyped_storage()) self.ctx.outputs = outputs + self.ctx.inputs = inputs self.outputs = None self.ctx = None diff --git a/megatron/core/transformer/README.md b/megatron/core/transformer/README.md new file mode 100644 index 00000000000..5c16c5f85be --- /dev/null +++ b/megatron/core/transformer/README.md @@ -0,0 +1,143 @@ +
+
+Fine-grained Activation Offloading
+=============
+
+NVIDIA, rednote
+
+
+# What is Fine-grained Activation Offloading?
+
+Memory capacity is becoming more and more important with the rise of extremely sparse MoE models such as DeepSeek-V3 and Qwen3-235B. Fine-grained Activation Offloading offloads activations at the granularity of individual modules, so that we can calibrate how much activation memory is offloaded to maximize training throughput.
+
+# Quick Start
+
+```bash
+# Enable fine-grained activation offloading
+--fine-grained-activation-offloading
+
+# Specify which modules are going to be offloaded
+# Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".
+--offload-modules core_attn
+```
+
+# Current status
+## Features
+* Support PP=1/PP/Interleaved PP
+* Compatible with fine-grained recomputation
+* Support FP8
+* Support MTP
+* Support mixed dense & MoE layers
+* Support A2A overlap
+* Support CUDA Graph
+  * (Temporary) the CUDA graph scope cannot contain the offloaded modules
+
+## Known issues
+* We explicitly resize some tensors to 0 to release their memory immediately, which sometimes leads to illegal memory access. Please remove the released tensors in `group_prefetch_offload_commit` if you run into this issue.
+
+## WIP items
+* Code refactor
+* Benchmark
+
+# Methodology
+
+## Offload/Reload the input of one module to/from CPU
+Let's take the attention projection module as an example:
+```
+nvtx_range_push(suffix="linear_proj")
+offload_context = contextlib.nullcontext()
+if self.offload_attn_proj:
+    core_attn_out = group_prefetch_offload_start(core_attn_out, name="attn_proj")
+    offload_context = PipelineOffloadManager.get_instance()
+with offload_context:
+    output, bias = self.linear_proj(core_attn_out)
+if self.offload_attn_proj:
+    output, bias = group_prefetch_offload_commit(output, bias, release_tensors=[core_attn_out])
+    offload_context = contextlib.nullcontext()
+nvtx_range_pop(suffix="linear_proj")
+```
+The code snippet above can be divided into three parts, in order:
+1. Mark the starting point of offloading a new module;
+2. Record the save_for_backward tensors in fprop and push them to a tensor buffer;
+3. Offload the recorded tensors after the module's fprop finishes.
+
+In bprop, the same three parts will:
+1. Make sure the offloaded tensors are reloaded back to the GPU;
+2. Pop the corresponding tensors from the tensor buffer;
+3. Reload the corresponding tensors of the next module.
+
+## Compatible with PP & Interleaved PP
+
+`PipelineOffloadManager` manages the offloading state across different model chunks in fprop and bprop.
+Before `model.forward()` starts, `PipelineOffloadManager.get_instance().reset_chunk_handler` is executed. In fprop, this method creates a `ChunkOffloadHandler` that owns the offloading context of one model chunk and pushes it into a buffer, which is popped in a specific order in bprop.
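+
+The recording and offloading described above is built on PyTorch's saved-tensor hooks. Below is a minimal, self-contained sketch of that idea, assuming a single compute stream plus a dedicated D2H stream; the helper names (`pack_to_cpu`, `unpack_from_cpu`) and the 1 MiB threshold are illustrative and are not the actual `PipelineOffloadManager` API:
+
+```python
+import torch
+
+d2h_stream = torch.cuda.Stream()
+
+def pack_to_cpu(t: torch.Tensor):
+    # Small tensors are not worth offloading; keep them on the GPU.
+    if t.numel() < 1 << 20:
+        return ("gpu", t)
+    # Copy the saved activation to pinned host memory on a side stream.
+    d2h_stream.wait_stream(torch.cuda.current_stream())
+    with torch.cuda.stream(d2h_stream):
+        cpu_copy = torch.empty(t.size(), dtype=t.dtype, device="cpu", pin_memory=True)
+        cpu_copy.copy_(t, non_blocking=True)
+        t.record_stream(d2h_stream)  # keep t's memory alive until the copy finishes
+    done = torch.cuda.Event()
+    done.record(d2h_stream)
+    return ("cpu", t.device, cpu_copy, done)
+
+def unpack_from_cpu(state):
+    if state[0] == "gpu":
+        return state[1]
+    _, device, cpu_copy, done = state
+    # Make sure the offload has finished before the reload is enqueued.
+    torch.cuda.current_stream().wait_event(done)
+    return cpu_copy.to(device, non_blocking=True)
+
+# Wrap only the module whose saved activations should live on the CPU.
+linear = torch.nn.Linear(4096, 4096).cuda()
+x = torch.randn(8, 4096, device="cuda", requires_grad=True)
+with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
+    y = linear(x)
+y.sum().backward()
+```
+
+The actual implementation additionally batches the saved tensors per module group, prefetches the next group during backward, skips offloading for the group whose backward runs first, and applies extra checks (tensor size, explicit markers) to decide what to offload, but the pack/unpack hooks and the stream/event synchronization follow the same pattern.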
+ +image + + +## Compatible with fine-grained recomputation + +offload_and_recompute + + +## A special case: attn_norm/mlp_norm + +# Performance + +## H100 + +### DeepSeek-V3-Proxy +#### Model structure +* Layer parameters are same as DeepSeek-V3 model +* Layer number is cut off to 14 layers +* Replace the fisrt 3 dense layers with 3 moe layers + +#### Key Hyper-parameters +* TP1PP4EP16VPP1CP1-MBS1GBS512 +* bf16 training +* DeepEP dispatcher +* `--cross-entropy-loss-fusion` and `--cross-entropy-fusion-impl te` +* `--moe-permute-fusion` +* `--moe-router-fusion` +* `--enable-experimental` + +#### Throughput and correctness + +image +image + + +#### Memory consumption + +Baseline (no offloading) +``` +[Rank 0] (after 10 iterations) memory (MB) | allocated: 24761.02978515625 | max allocated: 65203.93359375 | reserved: 64438.0 | max reserved: 74306.0 +[Rank 16] (after 10 iterations) memory (MB) | allocated: 18907.728515625 | max allocated: 52228.1533203125 | reserved: 58770.0 | max reserved: 58770.0 +[Rank 32] (after 10 iterations) memory (MB) | allocated: 18907.7529296875 | max allocated: 45200.8349609375 | reserved: 51772.0 | max reserved: 51772.0 +[Rank 48] (after 10 iterations) memory (MB) | allocated: 29006.82275390625 | max allocated: 48166.263671875 | reserved: 56328.0 | max reserved: 56328.0 +``` +With offloading expert_fc1, moe_act, act_norm and mlp_norm +``` +[Rank 0] (after 10 iterations) memory (MB) | allocated: 24705.02978515625 | max allocated: 48544.70849609375 | reserved: 61046.0 | max reserved: 61046.0 +[Rank 16] (after 10 iterations) memory (MB) | allocated: 18795.728515625 | max allocated: 38760.3876953125 | reserved: 46330.0 | max reserved: 46330.0 +[Rank 32] (after 10 iterations) memory (MB) | allocated: 18795.7529296875 | max allocated: 34950.2509765625 | reserved: 42452.0 | max reserved: 42452.0 +[Rank 48] (after 10 iterations) memory (MB) | allocated: 28950.82275390625 | max allocated: 41310.798828125 | reserved: 50408.0 | max reserved: 50408.0 +``` + +### Qwen3-30B-A3B +#### Model structure +* Same as Qwen-30B model structure + +#### Results + +| Model | Mapping | Sequence length | Recompute | Offload | Throughput (tflops) | Memory (MB) | +|---------------|--------------------------|-----------------|-----------|------------|---------------------|-------------| +| Qwen3-30B-A3B | TP1PP1EP8VPP1_MBS1GBS256 | 4096 | / | / | 194 | 65308 | +| | TP1PP1EP8VPP1_MBS1GBS256 | 8192 | full | / | 230 | 59566 | +| | TP1PP2EP8VPP4_MBS1GBS256 | 8192 | layernorm | expert_fc1 | 255 | 64962 | + + + +## GB200 + +# Acknowledgement + +This work refers to the previous work from Kuaishou: https://www.usenix.org/conference/atc24/presentation/yuan diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py index 6b2e898be6c..7a283ba6b46 100644 --- a/megatron/core/transformer/attention.py +++ b/megatron/core/transformer/attention.py @@ -6,6 +6,7 @@ import torch from torch import Tensor +import contextlib from megatron.core import tensor_parallel from megatron.core.inference.contexts import BaseInferenceContext @@ -37,6 +38,11 @@ from .enums import AttnMaskType from .transformer_config import TransformerConfig +from megatron.core.transformer.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, +) try: from einops import rearrange @@ -177,6 +183,21 @@ def __init__( and "core_attn" in self.config.recompute_modules ) + self.offload_qkv_linear = ( + self.config.fine_grained_activation_offloading + and 
"qkv_linear" in self.config.offload_modules + ) + + self.offload_core_attention = ( + self.config.fine_grained_activation_offloading + and "core_attn" in self.config.offload_modules + ) + + self.offload_attn_proj = ( + self.config.fine_grained_activation_offloading + and "attn_proj" in self.config.offload_modules + ) + # Output. self.linear_proj = build_module( submodules.linear_proj, @@ -668,7 +689,16 @@ def forward( # Get the query, key and value tensors based on the type of attention - # self or cross attn. nvtx_range_push(suffix="qkv") - query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + if self.offload_qkv_linear: + if not hidden_states.is_contiguous(): + hidden_states = hidden_states.contiguous() + hidden_states = group_prefetch_offload_start(hidden_states, name="qkv_linear") + hidden_states.offloading_activation = True + with PipelineOffloadManager.get_instance(): + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) + query, key, value = group_prefetch_offload_commit(query, key, value, name="qkv_linear", release_tensors=[hidden_states]) + else: + query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states) nvtx_range_pop(suffix="qkv") # =================================================== @@ -797,17 +827,22 @@ def forward( packed_seq_params=packed_seq_params, ) else: + offload_context = contextlib.nullcontext() + if self.offload_core_attention and self.training: + query = group_prefetch_offload_start(query, name="core_attn") + offload_context = PipelineOffloadManager.get_instance() if inference_context is None or inference_context.is_static_batching(): # Static batching attention kernel. - core_attn_out = self.core_attention( - query, - key, - value, - attention_mask, - attn_mask_type=attn_mask_type, - attention_bias=attention_bias, - packed_seq_params=packed_seq_params, - ) + with offload_context: + core_attn_out = self.core_attention( + query, + key, + value, + attention_mask, + attn_mask_type=attn_mask_type, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) else: # Dynamic batching attention kernel. 
@@ -827,6 +862,9 @@ def forward( block_table, ) core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') + if self.offload_core_attention and self.training: + core_attn_out, = group_prefetch_offload_commit(core_attn_out, name="core_attn", release_tensors=[query, key, value]) + offload_context = contextlib.nullcontext() if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd': # reshape to same output shape as unpacked case @@ -841,7 +879,15 @@ def forward( # ================= nvtx_range_push(suffix="linear_proj") - output, bias = self.linear_proj(core_attn_out) + offload_context = contextlib.nullcontext() + if self.offload_attn_proj: + core_attn_out = group_prefetch_offload_start(core_attn_out, name="attn_proj") + offload_context = PipelineOffloadManager.get_instance() + with offload_context: + output, bias = self.linear_proj(core_attn_out) + if self.offload_attn_proj: + output, bias = group_prefetch_offload_commit(output, bias, name="attn_proj", release_tensors=[core_attn_out]) + offload_context = contextlib.nullcontext() nvtx_range_pop(suffix="linear_proj") return output, bias diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py new file mode 100644 index 00000000000..8dd5f139884 --- /dev/null +++ b/megatron/core/transformer/cpu_offload.py @@ -0,0 +1,513 @@ +from collections import deque, defaultdict +import torch +from typing import Any +from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.cpu_offload import AsyncDoubleBufferGroupOffloadHandler + +# cpu offload for pipeline +DEBUG = False +DEBUG_RANK = 0 +MIN_OFFLOADED_TENSOR_SIZE = 1024 * 1024 + +def print_rank(message): + assert torch.distributed.is_initialized() + if DEBUG and torch.distributed.get_rank() == DEBUG_RANK: + print(message, flush=True) + +def set_ideal_affinity_for_current_gpu(): + import cuda.cuda + import cuda.cudart + import pynvml + import uuid + err, device_id = cuda.cudart.cudaGetDevice() + assert err == cuda.cudart.cudaError_t.cudaSuccess + err, device_uuid = cuda.cuda.cuDeviceGetUuid(device_id) + assert err == cuda.cuda.CUresult.CUDA_SUCCESS + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByUUID("GPU-" + str(uuid.UUID(bytes=device_uuid.bytes))) + pynvml.nvmlDeviceSetCpuAffinity(handle) + +class PipelineOffloadManager: + OFFLOAD_MGR = None + @classmethod + def get_instance(cls): + if cls.OFFLOAD_MGR is None: + cls.OFFLOAD_MGR = PipelineOffloadManager() + return cls.OFFLOAD_MGR + + def __init__(self): + from megatron.core import parallel_state + self._queue = deque() + if parallel_state.get_virtual_pipeline_model_parallel_world_size() is None: + self._vpp = 1 + else: + self._vpp = parallel_state.get_virtual_pipeline_model_parallel_world_size() + + # cache vpp - 1 stages + self._stages = [[] for _ in range(self._vpp)] + # allocate streams and events for synchronization + self._d2h_stream = torch.cuda.Stream() + self._h2d_stream = torch.cuda.Stream() + self.reset() + + @property + def d2h_stream(self): + return self._d2h_stream + + @property + def h2d_stream(self): + return self._h2d_stream + + def reset(self): + set_ideal_affinity_for_current_gpu() + self._inside_context = False + self._cur_forward_chunk = None + self._cur_backward_chunk = None + self._first_last_vpp_rank = True + + def flush(self): + # put into the queue in the backward order + if len(self._stages[0]) == len(self._stages[-1]): + lens = [len(e) for e in self._stages] + assert min(lens) == max(lens) + self._stages[-1] = [] + 
for chunks in reversed(self._stages): + for chunk in chunks: + self.push(chunk) + for i in range(self._vpp): + self._stages[i] = [] + + def push(self, handler): + print_rank(f"pushing handler {handler}") + self._queue.append(handler) + + def pop(self): + assert self.size() + while self._queue: + self._cur_backward_chunk = self._queue.popleft() + if not isinstance(self._cur_backward_chunk, NullChunkOffloadHandler): + break + print_rank(f"popping handler {self._cur_backward_chunk}") + + def front(self): + if not len(self._queue): + return None + for chunk_handler in self._queue: + if not isinstance(chunk_handler, NullChunkOffloadHandler): + return chunk_handler + return None + + def size(self): + return len(self._queue) + + def reset_chunk_handler(self, num_layer, vp_stage, offload=True, num_dense_layer=0, last_stage_is_loss=False): + if vp_stage is None: + cur_vpp_rank = 0 + else: + cur_vpp_rank = vp_stage + + if last_stage_is_loss: + from megatron.core import parallel_state + vpp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + # skip the last stage + if cur_vpp_rank == vpp_size - 1: + return + # reduce the vpp size + if self._vpp == vpp_size: + self._vpp -= 1 + self._stages = self._stages[:-1] + + first_last_vpp_rank = self._first_last_vpp_rank + # rewind + if cur_vpp_rank == self._vpp - 1: + self.flush() + first_last_vpp_rank = first_last_vpp_rank and (cur_vpp_rank == self._vpp - 1) + # If the model chunk contains only the dense layers, initialize a null chunk handler. + if num_layer <= num_dense_layer: + cur_chunk = NullChunkOffloadHandler(num_layer, first_last_vpp_rank, offload) + else: + cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload) + # save for latter push + self._stages[cur_vpp_rank].append(cur_chunk) + if cur_vpp_rank == self._vpp - 1: + self._first_last_vpp_rank = False + self.push(cur_chunk) + self.flush() + self._cur_forward_chunk = cur_chunk + cur_chunk.vpp_rank = cur_vpp_rank + + def set_last_layer(self, is_last_layer): + self._cur_forward_chunk.is_last_layer = is_last_layer + + def cur_forward_chunk(self): + return self._cur_forward_chunk + + def cur_backward_chunk(self): + return self._cur_backward_chunk + + def __enter__(self): + print_rank("__enter__") + self.OFFLOAD_MGR + self.inside_context = True + + if not isinstance(self.cur_forward_chunk(), NullChunkOffloadHandler): + torch._C._autograd._push_saved_tensors_default_hooks( + self.on_save_for_backward, self.on_get_saved_tensor + ) + + def __exit__(self, *args: Any): + print_rank("__exit__") + self.inside_context = False + if not isinstance(self.cur_forward_chunk(), NullChunkOffloadHandler): + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + print_rank(f"on_save_for_backward {tensor.shape}") + assert self.inside_context + return self.cur_forward_chunk().tensor_push(tensor) + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + print_rank(f"on_get_saved_tensor {saved_state}") + return self.cur_backward_chunk().tensor_pop(saved_state) + + +class ChunkOffloadHandler(AsyncDoubleBufferGroupOffloadHandler): + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + print_rank("offload") + fp8_offload = isinstance(src_tensor, Float8Tensor) + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=torch.uint8 if fp8_offload else src_tensor.dtype, + layout=src_tensor.layout, + device="cpu", + pin_memory=pin_memory, + ) + + if fp8_offload: + cpu_backup = 
Float8Tensor.make_like(src_tensor, data=cpu_backup) + + if not src_tensor.is_contiguous(): + src_tensor = src_tensor.contiguous() + + cpu_backup.copy_(src_tensor, non_blocking=pin_memory) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None): + """Reload.""" + print_rank("reload") + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + return cpu_backup.to(dev, non_blocking=non_blocking) + + def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True): + self._num_layers = num_layer + # Data Structure to maintain reference to activation tensors + self._tensor_tag_to_state = {} + # Tracking the number of layers offloaded + # self._offloaded_group_count = 0 + self._is_first_last_vpp_chunk = is_first_last_vpp_chunk + + self._offloaded_group_index = 0 + self._groups_to_offload = [] + self._groups_to_reload = [] + self._layer_index = 0 + self._tensor_count_current_group = 0 + self.multi_input_offload_count = False + # self.offload_count_per_layer = defaultdict(int) + + self.torch_tensor_count = 0 + self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream + self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream + self._offload_events = {} + self._reload_events = {} + self.do_offload = offload + self.is_last_layer = False + + + def is_first_last_layer(self): + """Do not offload the last layer of the last pp stage.""" + print_rank(f"is_first_last_layer {self._is_first_last_vpp_chunk} {self.is_last_layer}") + return self._is_first_last_vpp_chunk and self.is_last_layer + + def tensor_push(self, tensor): + print_rank("tensor_push") + torch_stray_tensor = isinstance( + tensor, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + ) + + if not torch_stray_tensor:# True + # obtain a unique tensor tag + tensor_tag = (self._offloaded_group_index, self._tensor_count_current_group) + self._tensor_count_current_group += 1 + assert tensor_tag not in self._tensor_tag_to_state + self._tensor_tag_to_state[tensor_tag] = tensor + else: + tensor_tag = (-1, self.torch_tensor_count) + self.torch_tensor_count += 1 + self._tensor_tag_to_state[tensor_tag] = tensor + print_rank(f"tensor_push {tensor_tag}") + return tensor_tag + + def tensor_pop(self, tensor_tag): + print_rank(f"tensor_pop {tensor_tag}") + assert tensor_tag in self._tensor_tag_to_state, f"{tensor_tag}, {self._tensor_tag_to_state.keys()}" + tensor = self._tensor_tag_to_state.pop(tensor_tag) + assert not isinstance(tensor, tuple) + print_rank(f"tensor_pop {tensor.shape}") + return tensor + + def tensor_need_offloading_checker(self, tensor): + """Check if the tensor needs to be offloaded.""" + if tensor.numel() < MIN_OFFLOADED_TENSOR_SIZE: + return False + if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation: + return False + return True + + def bulk_offload_group(self, group_to_offload): + """offload a group of tensors recorded in tensor_push(). 
+ """ + print_rank("bulk_offload_group") + if not self.do_offload: + return + assert not self.is_first_last_layer() + group_id_to_offload, name = group_to_offload + torch.cuda.nvtx.range_push(name) + with torch.cuda.stream(self.d2h_stream): + for tensor_tag, state in self._tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_id_to_offload: + print_rank(f"tensor_tag {tensor_tag}") + print_rank(f"group_to_offload {group_to_offload}") + assert not isinstance(state, tuple) + tensor_on_device = state + if self.tensor_need_offloading_checker(tensor_on_device): + state = self.offload(tensor_on_device) + # TODO: check if we really need it. + # Record the last offloading event for this group, + # which is used to avoid reloading before offloading. + event = torch.cuda.Event() + event.record(self.d2h_stream) + self._offload_events[name] = event + tensor_on_device.record_stream(self.d2h_stream) + self._tensor_tag_to_state[tensor_tag] = state + print_rank("exit bulk_offload_group") + torch.cuda.nvtx.range_pop() + + def get_offload_event(self, name): + if name in self._offload_events: + return self._offload_events[name] + else: + return None + + def get_reload_event(self, name): + if name in self._reload_events: + return self._reload_events[name] + else: + return None + + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + print_rank("bulk_reload_group") + if not self.do_offload: + return + found_reload_group = False + group_id_to_reload, name = group_to_reload + torch.cuda.nvtx.range_push(name) + with torch.cuda.stream(self.h2d_stream): + # move back tensors + for tensor_label, state in self._tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_id_to_reload: + print_rank(f"tensor_label {tensor_label}") + found_reload_group = True + event = self.get_offload_event(name) + if isinstance(state, tuple): + # make sure the tensor is already offloaded to cpu before reloading it. 
+ torch.cuda.current_stream().wait_event(event) + recovered_tensor = self.reload(state) + event.record(self.h2d_stream) + self._reload_events[name] = event + print_rank(f"recovered_tensor {recovered_tensor.shape}") + self._tensor_tag_to_state[tensor_label] = recovered_tensor + torch.cuda.nvtx.range_pop() + return found_reload_group + + def pre_reload_last_layer(self): + """Pre-reload the last layer of the next model chunk.""" + print_rank("pre_reload_last_layer") + if not self.do_offload: + return + assert not self._is_first_last_vpp_chunk + print_rank(f"len(self._groups_to_reload) {len(self._groups_to_reload)}") + if len(self._groups_to_reload) > 0: + if self.bulk_reload_group(self._groups_to_reload[-1]): + self._groups_to_reload.pop() + + def should_bulk_offload(self): + """Check if the chunk should be offloaded.""" + if not self.do_offload: + return False + # first backward chunk + if self.is_first_last_layer(): + return False + + # if next backward chunk is this chunk (for last pp stage) + next_backward_chunk = PipelineOffloadManager.get_instance().get_instance().front() + if next_backward_chunk is not None and next_backward_chunk is self: + if self.is_last_layer: + return False + + return True + + def bulk_offload(self, release_tensors): + print_rank("bulk_offload") + if self.should_bulk_offload(): + group_to_offload = self._groups_to_offload.pop() + name = group_to_offload[1] + self._groups_to_reload.append(group_to_offload) + self.bulk_offload_group(group_to_offload) + if len(release_tensors) > 0: + cur_stream = torch.cuda.current_stream() + for release_tensor in release_tensors: + release_tensor.record_stream(cur_stream) + release_tensor.untyped_storage().resize_(0) + print_rank("exit bulk_offload") + + def on_group_commit_forward(self, release_tensors): + """Offload a group of tensors.""" + print_rank("on_group_commit_forward") + # wait for the compute stream for offloading + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + self.bulk_offload(release_tensors) + print_rank("exit on_group_commit_forward") + + def bulk_reload(self): + print_rank("bulk_reload") + if len(self._groups_to_reload) > 0: + # load next layer + if self.bulk_reload_group(self._groups_to_reload[-1]): + print_rank(f"bulk_reload_group {self._groups_to_reload}") + self._groups_to_reload.pop() + else: + # load the last layer of one backward chunk in advance + next_backward_chunk = PipelineOffloadManager.get_instance().front() + if next_backward_chunk is not None: + next_backward_chunk.pre_reload_last_layer() + + def on_group_commit_backward(self, name): + """Prepare for reloadingthe next group of tensors.""" + print_rank("on_group_commit_backward") + cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() + if not cur_backward_chunk is self: + PipelineOffloadManager.get_instance().pop() + cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk() + assert cur_backward_chunk is self + # make sure the reloading jobs for current computation are done. 
+ event = self.get_reload_event(name) + if event is not None: + torch.cuda.current_stream().wait_event(event) + self._offloaded_group_index = self._offloaded_group_index - 1 + + def on_group_start_forward(self, name): + """Prepare for offloading the next group of tensors.""" + print_rank(f"on_group_start_forward {self._layer_index} {self._num_layers}") + self._offloaded_group_index = self._offloaded_group_index + 1 + self._tensor_count_current_group = 0 + self._groups_to_offload.append((self._offloaded_group_index, name)) + + def on_group_start_backward(self): + """Reload the next group of tensors.""" + print_rank("on_group_start_backward") + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + self.bulk_reload() + +class NullChunkOffloadHandler(ChunkOffloadHandler): + pass + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. + """ + + @staticmethod + def forward(ctx, *args): + # pylint: disable=missing-function-docstring + print_rank("GroupCommitFunction forward") + + release_tensors = args[-1] + name = args[-2] + cpu_offload_handler = args[-3] + tensor = args[:-3] + if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): + cpu_offload_handler.on_group_commit_forward(release_tensors) + ctx.cpu_offload_handler = cpu_offload_handler + ctx.name = name + + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, *grad_output): + # pylint: disable=missing-function-docstring + print_rank("GroupCommitFunction backward") + + cpu_offload_handler = ctx.cpu_offload_handler + if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): + cpu_offload_handler.on_group_commit_backward(ctx.name) + return grad_output + (None, None, None) + + +def group_prefetch_offload_commit(*tensor, name, release_tensors=[]): + """Specify the tensors to be released after offloading. + release_tensors is a list of tensors to be released after offloading. + The tensors will be untyped_storage().resize_(0) after offloading. + Note: specify the tensors only when they are not automatically released by torch gc. + """ + cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() + return GroupCommitFunction.apply(*tensor, cur_forward_chunk, name, release_tensors) + + +class GroupStartFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. 
+ """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler, name): + # pylint: disable=missing-function-docstring + ctx.cpu_offload_handler = cpu_offload_handler + print_rank("GroupStartFunction forward") + + if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): + cpu_offload_handler.on_group_start_forward("activation offloading " + name) + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + print_rank("GroupStartFunction backward") + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + if not isinstance(cpu_offload_handler, NullChunkOffloadHandler): + cpu_offload_handler.on_group_start_backward() + return grad_output, None, None + + +def group_prefetch_offload_start(tensor, name=None): + cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk() + return GroupStartFunction.apply(tensor, cur_forward_chunk, name) diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index fc741aa46f3..eddf61e6f8a 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -10,6 +10,7 @@ import torch import torch.nn.functional as F from torch.nn.parameter import Parameter +import contextlib from megatron.core import parallel_state, tensor_parallel from megatron.core.activations import squared_relu @@ -41,6 +42,11 @@ make_sharded_object_for_checkpoint, sharded_state_dict_default, ) +from megatron.core.transformer.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, +) try: import transformer_engine as te # pylint: disable=unused-import @@ -805,6 +811,16 @@ def __init__( tp_group=pg_collection.expt_tp, ) + self.offload_expert_fc1 = ( + self.config.fine_grained_activation_offloading + and "expert_fc1" in self.config.offload_modules + ) + + self.offload_moe_act = ( + self.config.fine_grained_activation_offloading + and "moe_act" in self.config.offload_modules + ) + self.activation_recompute = ( self.config.recompute_granularity == 'selective' and "moe_act" in self.config.recompute_modules @@ -813,6 +829,11 @@ def __init__( from megatron.core.extensions.transformer_engine import set_save_original_input set_save_original_input(self.linear_fc2) + + # This is to avoid the CPU overhead of multiple d2h copies + if self.offload_expert_fc1: + from megatron.core.extensions.transformer_engine import set_save_original_input + set_save_original_input(self.linear_fc1) if self.config.fp8: assert HAVE_TE, "FP8 requires TE." @@ -858,9 +879,17 @@ def forward( # Probs already applied, so reset to 1. 
permuted_probs = torch.ones_like(permuted_probs) - intermediate_parallel, bias_parallel = self.linear_fc1( - permuted_local_hidden_states, tokens_per_expert + offload_context = contextlib.nullcontext() + if self.offload_expert_fc1: + permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states, name="expert_fc1") + offload_context = PipelineOffloadManager.get_instance() + with offload_context: + fc1_output, bias_parallel = self.linear_fc1( + permuted_local_hidden_states, tokens_per_expert ) + if self.offload_expert_fc1: + fc1_output, bias_parallel = group_prefetch_offload_commit(fc1_output, bias_parallel, name="expert_fc1", release_tensors=[]) + offload_context = contextlib.nullcontext() def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs): if self.config.use_te_activation_func: @@ -931,18 +960,29 @@ def glu(x): intermediate_parallel = intermediate_parallel.to(original_dtype) return intermediate_parallel + if self.offload_moe_act: + fc1_output = group_prefetch_offload_start(fc1_output, name="moe_act") + offload_context = PipelineOffloadManager.get_instance() + if self.activation_recompute: self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput() - intermediate_parallel = self.activation_checkpoint.checkpoint( - bias_act_func, intermediate_parallel, bias_parallel, permuted_probs - ) - output, output_bias = self.linear_fc2(intermediate_parallel, tokens_per_expert) - self.activation_checkpoint.discard_output_and_register_recompute(output) + with offload_context: + bias_act_output = self.activation_checkpoint.checkpoint( + bias_act_func, fc1_output, bias_parallel, permuted_probs + ) else: - intermediate_parallel = bias_act_func( - intermediate_parallel, bias_parallel, permuted_probs - ) - output, output_bias = self.linear_fc2(intermediate_parallel, tokens_per_expert) + with offload_context: + bias_act_output = bias_act_func( + fc1_output, bias_parallel, permuted_probs + ) + + output, output_bias = self.linear_fc2(bias_act_output, tokens_per_expert) + if self.activation_recompute: + self.activation_checkpoint.discard_output_and_register_recompute(output) + if self.offload_moe_act: + output, = group_prefetch_offload_commit(output, name="moe_act", release_tensors=[]) + offload_context = contextlib.nullcontext() + # upad and concat the output if self.config.fp8: diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py index 9e6b46fd4e5..3a16f0b5377 100644 --- a/megatron/core/transformer/multi_latent_attention.py +++ b/megatron/core/transformer/multi_latent_attention.py @@ -37,6 +37,12 @@ from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import MLATransformerConfig from megatron.core.utils import deprecate_inference_params, is_te_min_version +from megatron.core.transformer.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, +) +import contextlib try: from megatron.core.fusions.fused_mla_yarn_rope_apply import ( @@ -264,8 +270,14 @@ def forward( query, key, value, attention_mask, packed_seq_params=packed_seq_params ) else: + offload_context = contextlib.nullcontext() + if self.offload_core_attention and self.training: + query = group_prefetch_offload_start(query, name="core_attn") + offload_context = PipelineOffloadManager.get_instance() + if inference_context is None or inference_context.is_static_batching(): - core_attn_out = 
self.core_attention( + with offload_context: + core_attn_out = self.core_attention( query, key, value, @@ -293,6 +305,9 @@ def forward( # Only rearrange if not in absorption mode (Flash MLA handles format correctly) if not inference_context.is_decode_only(): core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)') + if self.offload_core_attention and self.training: + core_attn_out, = group_prefetch_offload_commit(core_attn_out, name="core_attn", release_tensors=[query, key, value]) + offload_context = contextlib.nullcontext() # We are doing absorption with cache mla latents and decode mode. if self.cache_mla_latents and inference_context.is_decode_only(): @@ -318,7 +333,15 @@ def forward( # ================= # Output. [sq, b, h] # ================= - output, bias = self.linear_proj(core_attn_out) + offload_context = contextlib.nullcontext() + if self.offload_attn_proj: + core_attn_out = group_prefetch_offload_start(core_attn_out, name="attn_proj") + offload_context = PipelineOffloadManager.get_instance() + with offload_context: + output, bias = self.linear_proj(core_attn_out) + if self.offload_attn_proj: + output, bias = group_prefetch_offload_commit(output, bias, name="attn_proj", release_tensors=[core_attn_out]) + offload_context = contextlib.nullcontext() return output, bias diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py index c14f8ea1f74..867b3689a1d 100755 --- a/megatron/core/transformer/transformer_block.py +++ b/megatron/core/transformer/transformer_block.py @@ -33,6 +33,7 @@ get_pg_rank, make_viewless_tensor, ) +from megatron.core.transformer.cpu_offload import PipelineOffloadManager try: import transformer_engine.pytorch as te # pylint: disable=unused-import @@ -324,6 +325,12 @@ def __init__( self._build_layers() self.num_layers_per_pipeline_rank = len(self.layers) + self.num_dense_layer = 0 + from megatron.core.transformer.moe.moe_layer import MoELayer + for layer in self.layers: + if not isinstance(layer.mlp, MoELayer): + self.num_dense_layer += 1 + def _build_layers(self): # Transformer layers. # @jcasper can we improve how we deal with layer_number? @@ -639,6 +646,11 @@ def forward( inner_quantization_context = nullcontext() else: inner_quantization_context = nullcontext() + + if l_no == self.num_layers_per_pipeline_rank - 1: + PipelineOffloadManager.get_instance().set_last_layer(True) + else: + PipelineOffloadManager.get_instance().set_last_layer(False) with self.offload_context, inner_quantization_context: hidden_states, context = layer( diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 7272c7a136a..561226d1ec2 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -700,6 +700,26 @@ class TransformerConfig(ModelParallelConfig): """Transformer implementation to use. Options are 'transformer_engine' for Transformer Engine and 'local' for MCore.""" + ##################################### + # Fine-grained Activation Offloading + ##################################### + fine_grained_activation_offloading: bool = False + """If True, offload the activation to the CPU.""" + + offload_modules: Optional[list[str]] = None + """The submodules to offload. + choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act". + default: ["core_attn"]. + "attn_norm": offload the input of the normalization in the attention part. 
+ "core_attn": offload the input of the core attention part. + "mlp_norm": offload the input of the normalization in the mlp part. + "attn_proj": offload the input of the attn linear projection part. + "expert_fc1": offload the input of the expert fc1 part. + "moe_act": offload the input of the moe act part. + """ + last_vp_stage_is_loss: bool = False + """If True, the last virtual pipeline stage is the loss stage.""" + def __post_init__(self): """Python dataclass method that is used to modify attributes after initialization. See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more @@ -990,6 +1010,34 @@ def __post_init__(self): if "moe" not in self.recompute_modules: self.recompute_modules.append("moe") + # if self.offload_modules is None: + # self.offload_modules = ["core_attn"] + + if len(self.offload_modules) > 0: + allowed_modules = { + "core_attn", "attn_proj", "expert_fc1", "moe_act", "attn_norm", "mlp_norm" + } + invalid_modules = set(self.offload_modules) - allowed_modules + assert not invalid_modules, ( + f'Invalid choices for offload_modules: {invalid_modules}. ' + f'Allowed modules are: {allowed_modules}' + ) + if "attn_proj" in self.offload_modules and "core_attn" not in self.offload_modules: + raise ValueError( + "attn_proj cannot be set to offload_modules alone without core_attn " + "because the input of attn_proj is the output of core_attn, " + "which is needed in core_attn.backward()." + ) + + if isinstance(self.moe_layer_freq, int): + assert self.moe_layer_freq == 1, "moe_layer_freq cannot be an integer other than 1 when offload_modules is set." + elif isinstance(self.moe_layer_freq, list): + if 0 in self.moe_layer_freq: + warnings.warn( + "Activation of dense layer won't be offloaded at all for mixed dense and moe layer." 
+ ) + + if ( self.num_layers_in_first_pipeline_stage is not None or self.num_layers_in_last_pipeline_stage is not None @@ -1171,6 +1219,13 @@ def __post_init__(self): f"{self.virtual_pipeline_model_parallel_size}" ) + if len(self.offload_modules) > 0: + if self.pipeline_model_parallel_layout is not None: + from megatron.core.transformer.pipeline_parallel_layer_layout import LayerType + if (len(self.pipeline_model_parallel_layout.layout[-1][-1]) == 1 and + self.pipeline_model_parallel_layout.layout[-1][-1][0] is LayerType.loss): + self.last_vp_stage_is_loss = True + if self.apply_query_key_layer_scaling: self.attention_softmax_in_fp32 = True diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py index 42d077e2e9f..ed5ace3fa6c 100644 --- a/megatron/core/transformer/transformer_layer.py +++ b/megatron/core/transformer/transformer_layer.py @@ -9,6 +9,7 @@ import torch import torch.distributed from torch import Tensor +import contextlib from megatron.core import parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict @@ -31,6 +32,11 @@ nvtx_range_pop, nvtx_range_push, ) +from megatron.core.transformer.cpu_offload import ( + PipelineOffloadManager, + group_prefetch_offload_start, + group_prefetch_offload_commit, +) logger = logging.getLogger(__name__) @@ -427,6 +433,20 @@ def __init__( if "mlp" in self.config.recompute_modules: if not isinstance(self.mlp, MoELayer): self.recompute_mlp = True + self.offload_self_attn = ( + self.config.fine_grained_activation_offloading + and "self_attn" in self.config.offload_modules + ) + self.offload_attn_norm = ( + self.config.fine_grained_activation_offloading + and "attn_norm" in self.config.offload_modules + and not isinstance(self.input_layernorm, IdentityOp) + ) + self.offload_mlp_norm = ( + self.config.fine_grained_activation_offloading + and "mlp_norm" in self.config.offload_modules + and not isinstance(self.pre_mlp_layernorm, IdentityOp) + ) # @jcasper how should we handle nvfuser? # Set bias+dropout+add fusion grad_enable execution handler. @@ -510,18 +530,29 @@ def _forward_attention( # Residual connection. residual = hidden_states + offload_context = contextlib.nullcontext() + if self.offload_attn_norm: + hidden_states = group_prefetch_offload_start(hidden_states, name="attn_norm") + offload_context = PipelineOffloadManager.get_instance() # Optional Input Layer norm if self.recompute_input_layernorm: self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( - self.input_layernorm, hidden_states - ) + with offload_context: + input_layernorm_output = self.input_layernorm_checkpoint.checkpoint( + self.input_layernorm, hidden_states + ) else: - input_layernorm_output = self.input_layernorm(hidden_states) + with offload_context: + input_layernorm_output = self.input_layernorm(hidden_states) # Self attention. 
nvtx_range_push(suffix="self_attention") - attention_output_with_bias = self.self_attention( + offload_context = contextlib.nullcontext() + if self.offload_self_attn: + input_layernorm_output = group_prefetch_offload_start(input_layernorm_output, name="self_attn") + offload_context = PipelineOffloadManager.get_instance() + with offload_context: + attention_output_with_bias = self.self_attention( input_layernorm_output, attention_mask=attention_mask, inference_context=inference_context, @@ -532,6 +563,9 @@ def _forward_attention( packed_seq_params=packed_seq_params, sequence_len_offset=sequence_len_offset, ) + if self.offload_self_attn: + attention_output_with_bias, = group_prefetch_offload_commit(attention_output_with_bias, name="self_attn", release_tensors=[input_layernorm_output]) + offload_context = contextlib.nullcontext() nvtx_range_pop(suffix="self_attention") if self.recompute_input_layernorm: @@ -550,6 +584,10 @@ def _forward_attention( ) nvtx_range_pop(suffix="self_attn_bda") + if self.offload_attn_norm: + hidden_states, = group_prefetch_offload_commit(hidden_states, name="attn_norm", release_tensors=[residual]) + offload_context = contextlib.nullcontext() + # Residual connection. residual = hidden_states @@ -590,14 +628,20 @@ def _forward_mlp(self, hidden_states, inference_context=None): # Residual connection. residual = hidden_states + offload_context = contextlib.nullcontext() + if self.offload_mlp_norm: + hidden_states = group_prefetch_offload_start(hidden_states, name="mlp_norm") + offload_context = PipelineOffloadManager.get_instance() # Optional Layer norm post the cross-attention. if self.recompute_pre_mlp_layernorm: self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput() - pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( - self.pre_mlp_layernorm, hidden_states - ) + with offload_context: + pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint( + self.pre_mlp_layernorm, hidden_states + ) else: - pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) + with offload_context: + pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states) nvtx_range_push(suffix="mlp") # Potentially chunk the MLP computation during prefill to minimize the peak activation size @@ -657,6 +701,9 @@ def _forward_mlp(self, hidden_states, inference_context=None): mlp_output_with_bias, residual, self.hidden_dropout ) nvtx_range_pop(suffix="mlp_bda") + if self.offload_mlp_norm: + hidden_states, = group_prefetch_offload_commit(hidden_states, name="mlp_norm", release_tensors=[residual]) + offload_context = contextlib.nullcontext() # Jit compiled function creates 'view' tensor. This tensor # potentially gets saved in the MPU checkpoint function context, diff --git a/megatron/core/utils.py b/megatron/core/utils.py index e66dab8abd8..7e1b9d9251a 100644 --- a/megatron/core/utils.py +++ b/megatron/core/utils.py @@ -136,7 +136,7 @@ def wrapped_func(*args, **kwargs): if config.is_experimental_enabled() is not True: raise ExperimentalNotEnabledError(f"Flag config.ENABLE_EXPERIMENTAL not enabled.") - logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.") + # logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.") return func(*args, **kwargs) @@ -213,7 +213,7 @@ def guard(super: super, attr: str): f"Flag config.ENABLE_EXPERIMENTAL not enabled." 
) - logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.") + # logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.") return super.__getattribute__(attr) class ClassInterceptor(type): diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index c8102c48cf4..5b4ab369840 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2189,7 +2189,10 @@ def _add_training_args(parser): help='The communicator group names to use high priority streams.') group.add_argument('--use-te-activation-func', action='store_true', help='Use activation function kernel from Transformer Engine in MLP module.') - + group.add_argument('--fine-grained-activation-offloading', action='store_true', + help='Offload the activation to the CPU.') + group.add_argument('--offload-modules', nargs='*', type=str, default=[], + help='The submodules to offload. Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".') return parser diff --git a/megatron/training/training.py b/megatron/training/training.py index f4a3eb9bef6..bc141e0c4dd 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -589,6 +589,11 @@ def pretrain( args = get_args() timers = get_timers() + if args.profile and torch.distributed.get_rank() in args.profile_ranks: + torch.cuda.memory._record_memory_history( + max_entries=1000000, + ) + if args.log_progress: append_to_progress_log("Starting job") @@ -1625,7 +1630,7 @@ def training_log( total_loss_dict[skipped_iters_key] = 0 total_loss_dict[nan_iters_key] = 0 print_rank_last(log_string) - if report_memory_flag: + if report_memory_flag and iteration == 10: # Report memory after optimizer state has been initialized. if torch.distributed.get_rank() == 0: num_microbatches = get_num_microbatches() @@ -2173,6 +2178,15 @@ def get_e2e_base_metrics(): # Run training iterations till done. buffered_rollouts = None while iteration < args.train_iters: + if args.profile and torch.distributed.get_rank() in args.profile_ranks and iteration == 3: + try: + comment = os.getenv("COMMENT") + model_name = os.getenv("MODEL") + memory_snapshot_path = f"/lustre/fsw/coreai_devtech_all/hongbinl/cpu_offloading/megatron-moe-scripts/pyt_profile/{model_name}_{comment}" + torch.cuda.memory._dump_snapshot(f"{memory_snapshot_path}.pickle") + print_rank_0(f"Captured memory snapshot at {memory_snapshot_path}.pickle") + except Exception as e: + print_rank_0(f"Failed to capture memory snapshot {e}") if args.profile and torch.distributed.get_rank() in args.profile_ranks: if args.use_pytorch_profiler: prof.step()