+
+# What is Fine-grained Activation Offloading?
+
+Memory capacity is increasingly important with the rise of extremely sparse MoE models such as DeepSeek-V3 and Qwen3-235B. Fine-grained activation offloading moves activations to CPU memory at the granularity of individual modules, so the amount of offloaded activation memory can be tuned to maximize training throughput.
+
+# Quick Start
+
+```bash
+# Enable fine-grained activation offloading
+--fine-grained-activation-offloading
+
+# Specify which modules are going to be offloaded
+# Choices: "qkv_linear", "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".
+--offload-modules core_attn
+```
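
For example, to offload several modules at once, as measured in the Performance section below (illustrative fragment; this assumes `--offload-modules` accepts a space-separated list, as is conventional for Megatron-LM multi-value flags):

```bash
--fine-grained-activation-offloading \
--offload-modules expert_fc1 moe_act attn_norm mlp_norm
```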
+
+# Current status
+## Features
+* Supports PP=1, PP, and interleaved PP
+* Compatible with fine-grained recomputation
+* Supports FP8
+* Supports MTP
+* Supports mixed dense & MoE layers
+* Supports A2A overlap
+* Supports CUDA Graph
+  * (Temporary) the CUDA graph scope cannot contain the offloading modules
+
+## Known issues
+* We explicitly resize some tensors to 0 to release their memory immediately, which can sometimes lead to illegal memory access. If you run into this issue, remove the released tensors in `group_prefetch_offload_commit`.
+
+## WIP items
+* Code refactor
+* Benchmark
+
+# Methodology
+
+## Offload/Reload the input of one module to/from CPU
+Let's take the attention projection module as an example:
+```python
+nvtx_range_push(suffix="linear_proj")
+offload_context = contextlib.nullcontext()
+if self.offload_attn_proj:
+    core_attn_out = group_prefetch_offload_start(core_attn_out, name="attn_proj")
+    offload_context = PipelineOffloadManager.get_instance()
+with offload_context:
+    output, bias = self.linear_proj(core_attn_out)
+if self.offload_attn_proj:
+    output, bias = group_prefetch_offload_commit(output, bias, name="attn_proj", release_tensors=[core_attn_out])
+    offload_context = contextlib.nullcontext()
+nvtx_range_pop(suffix="linear_proj")
+```
+The code snippet above can be divided into three parts, in order:
+1. Mark the starting point of offloading a new module;
+2. Record the save_for_backward tensors in fprop and push them to a tensor buffer;
+3. Offload the recorded tensors after the module's fprop finishes.
+
+In bprop, the same three parts:
+1. Make sure the offloaded tensors are reloaded back to the GPU;
+2. Pop the corresponding tensors from the tensor buffer;
+3. Reload the corresponding tensors of the next module.
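
The fprop/bprop lifecycle above can be sketched without any framework code. The class below is illustrative only: names loosely mirror this PR's `ChunkOffloadHandler`, and dict moves stand in for the real D2H/H2D copies.

```python
# Toy model of the start / push / commit / pop lifecycle (illustrative only).
from collections import deque

class ToyChunkHandler:
    def __init__(self):
        self.group_index = 0
        self.groups = deque()   # groups committed in fprop, consumed in bprop
        self.gpu = {}           # tag -> tensor (stand-in for device memory)
        self.cpu = {}           # tag -> tensor (stand-in for host memory)

    # 1. Mark the starting point of offloading a new module.
    def start(self, name):
        self.group_index += 1
        self.count = 0
        self.current = (self.group_index, name)

    # 2. Record a save_for_backward tensor and push it to the buffer.
    def push(self, tensor):
        tag = (self.group_index, self.count)
        self.count += 1
        self.gpu[tag] = tensor
        return tag

    # 3. After the module's fprop, offload the recorded tensors.
    def commit(self):
        gid, name = self.current
        for tag in [t for t in self.gpu if t[0] == gid]:
            self.cpu[tag] = self.gpu.pop(tag)   # D2H copy in the real code
        self.groups.append((gid, name))

    # bprop: make sure the group is back on GPU, then pop a tensor by tag.
    def pop(self, tag):
        gid = tag[0]
        for t in [t for t in self.cpu if t[0] == gid]:
            self.gpu[t] = self.cpu.pop(t)       # H2D reload in the real code
        return self.gpu.pop(tag)

handler = ToyChunkHandler()
handler.start("attn_proj")
tag = handler.push("core_attn_out")
handler.commit()
assert not handler.gpu and handler.groups[-1][1] == "attn_proj"
assert handler.pop(tag) == "core_attn_out"
```

In the real implementation the offload and reload run on dedicated CUDA streams and are synchronized with events, which this sketch omits.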
+
+## Compatibility with PP & Interleaved PP
+
+`PipelineOffloadManager` manages the offload handlers of the different model chunks across fprop and bprop.
+Before `model.forward()` starts, `PipelineOffloadManager.get_instance().reset_chunk_handler` is executed. In fprop, this method creates a `ChunkOffloadHandler` that handles the offloading context of one model chunk and pushes it to a buffer; in bprop, the handlers are popped from that buffer in a specific order.
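
The ordering can be illustrated with a toy model (illustrative only; it mirrors the `flush`/`reset_chunk_handler` logic of this PR using plain lists, with microbatch counts in place of real handlers):

```python
# Toy model of how chunks are queued in backward order under interleaved PP.
from collections import deque

class ToyManager:
    def __init__(self, vpp):
        self.vpp = vpp
        self.stages = [[] for _ in range(vpp)]  # chunks cached per virtual stage
        self.queue = deque()                    # bprop consumption order

    def flush(self):
        # Rewind only once every virtual stage has seen the same number of chunks.
        if len(self.stages[0]) == len(self.stages[-1]):
            self.stages[-1] = []                # last-stage chunks were pushed eagerly
            for chunks in reversed(self.stages):
                self.queue.extend(chunks)       # later virtual stages go first
            self.stages = [[] for _ in range(self.vpp)]

    def reset_chunk_handler(self, vp_stage, chunk):
        if vp_stage == self.vpp - 1:
            self.flush()
        self.stages[vp_stage].append(chunk)
        if vp_stage == self.vpp - 1:
            # Chunks of the last virtual stage enter bprop first: push immediately.
            self.queue.append(chunk)
            self.flush()

mgr = ToyManager(vpp=2)
for vp, chunk in [(0, "c0"), (0, "c1"), (1, "c2"), (1, "c3")]:
    mgr.reset_chunk_handler(vp, chunk)
# bprop pops the last virtual stage's chunks before the first stage's.
assert list(mgr.queue) == ["c2", "c3", "c0", "c1"]
```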
+
+

+
+
+## Compatible with fine-grained recomputation
+
+

+
+
+## A special case: attn_norm/mlp_norm
+
+# Performance
+
+## H100
+
+### DeepSeek-V3-Proxy
+#### Model structure
+* Layer parameters are the same as the DeepSeek-V3 model
+* The number of layers is reduced to 14
+* The first 3 dense layers are replaced with 3 MoE layers
+
+#### Key Hyper-parameters
+* TP1PP4EP16VPP1CP1-MBS1GBS512
+* bf16 training
+* DeepEP dispatcher
+* `--cross-entropy-loss-fusion` and `--cross-entropy-fusion-impl te`
+* `--moe-permute-fusion`
+* `--moe-router-fusion`
+* `--enable-experimental`
+
+#### Throughput and correctness
+
+

+

+
+
+#### Memory consumption
+
+Baseline (no offloading)
+```
+[Rank 0] (after 10 iterations) memory (MB) | allocated: 24761.02978515625 | max allocated: 65203.93359375 | reserved: 64438.0 | max reserved: 74306.0
+[Rank 16] (after 10 iterations) memory (MB) | allocated: 18907.728515625 | max allocated: 52228.1533203125 | reserved: 58770.0 | max reserved: 58770.0
+[Rank 32] (after 10 iterations) memory (MB) | allocated: 18907.7529296875 | max allocated: 45200.8349609375 | reserved: 51772.0 | max reserved: 51772.0
+[Rank 48] (after 10 iterations) memory (MB) | allocated: 29006.82275390625 | max allocated: 48166.263671875 | reserved: 56328.0 | max reserved: 56328.0
+```
+With offloading of expert_fc1, moe_act, attn_norm and mlp_norm
+```
+[Rank 0] (after 10 iterations) memory (MB) | allocated: 24705.02978515625 | max allocated: 48544.70849609375 | reserved: 61046.0 | max reserved: 61046.0
+[Rank 16] (after 10 iterations) memory (MB) | allocated: 18795.728515625 | max allocated: 38760.3876953125 | reserved: 46330.0 | max reserved: 46330.0
+[Rank 32] (after 10 iterations) memory (MB) | allocated: 18795.7529296875 | max allocated: 34950.2509765625 | reserved: 42452.0 | max reserved: 42452.0
+[Rank 48] (after 10 iterations) memory (MB) | allocated: 28950.82275390625 | max allocated: 41310.798828125 | reserved: 50408.0 | max reserved: 50408.0
+```
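
From the two logs above, the reduction in peak ("max allocated") memory can be computed directly:

```python
# "max allocated" values (MB) copied from the two logs above.
baseline = {0: 65203.93, 16: 52228.15, 32: 45200.83, 48: 48166.26}
offload = {0: 48544.71, 16: 38760.39, 32: 34950.25, 48: 41310.80}

savings = {rank: baseline[rank] - offload[rank] for rank in baseline}
percent = {rank: 100 * savings[rank] / baseline[rank] for rank in baseline}

assert round(savings[0]) == 16659                  # ~16.7 GB saved on rank 0
assert all(10 < p < 30 for p in percent.values())  # roughly 14-26% peak reduction
```

Rank 0 (the first PP stage) sees the largest absolute reduction, about 16.7 GB of peak allocated memory.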
+
+### Qwen3-30B-A3B
+#### Model structure
+* Same as the Qwen3-30B-A3B model structure
+
+#### Results
+
+| Model | Mapping | Sequence length | Recompute | Offload | Throughput (TFLOPS) | Memory (MB) |
+|---------------|--------------------------|-----------------|-----------|------------|---------------------|-------------|
+| Qwen3-30B-A3B | TP1PP1EP8VPP1_MBS1GBS256 | 4096 | / | / | 194 | 65308 |
+| | TP1PP1EP8VPP1_MBS1GBS256 | 8192 | full | / | 230 | 59566 |
+| | TP1PP2EP8VPP4_MBS1GBS256 | 8192 | layernorm | expert_fc1 | 255 | 64962 |
+
+
+
+## GB200
+
+# Acknowledgement
+
+This work builds on prior work from Kuaishou: https://www.usenix.org/conference/atc24/presentation/yuan
diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
index 6b2e898be6c..7a283ba6b46 100644
--- a/megatron/core/transformer/attention.py
+++ b/megatron/core/transformer/attention.py
@@ -6,6 +6,7 @@
import torch
from torch import Tensor
+import contextlib
from megatron.core import tensor_parallel
from megatron.core.inference.contexts import BaseInferenceContext
@@ -37,6 +38,11 @@
from .enums import AttnMaskType
from .transformer_config import TransformerConfig
+from megatron.core.transformer.cpu_offload import (
+ PipelineOffloadManager,
+ group_prefetch_offload_start,
+ group_prefetch_offload_commit,
+)
try:
from einops import rearrange
@@ -177,6 +183,21 @@ def __init__(
and "core_attn" in self.config.recompute_modules
)
+ self.offload_qkv_linear = (
+ self.config.fine_grained_activation_offloading
+ and "qkv_linear" in self.config.offload_modules
+ )
+
+ self.offload_core_attention = (
+ self.config.fine_grained_activation_offloading
+ and "core_attn" in self.config.offload_modules
+ )
+
+ self.offload_attn_proj = (
+ self.config.fine_grained_activation_offloading
+ and "attn_proj" in self.config.offload_modules
+ )
+
# Output.
self.linear_proj = build_module(
submodules.linear_proj,
@@ -668,7 +689,16 @@ def forward(
# Get the query, key and value tensors based on the type of attention -
# self or cross attn.
nvtx_range_push(suffix="qkv")
- query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
+ if self.offload_qkv_linear:
+ if not hidden_states.is_contiguous():
+ hidden_states = hidden_states.contiguous()
+ hidden_states = group_prefetch_offload_start(hidden_states, name="qkv_linear")
+ hidden_states.offloading_activation = True
+ with PipelineOffloadManager.get_instance():
+ query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
+ query, key, value = group_prefetch_offload_commit(query, key, value, name="qkv_linear", release_tensors=[hidden_states])
+ else:
+ query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
nvtx_range_pop(suffix="qkv")
# ===================================================
@@ -797,17 +827,22 @@ def forward(
packed_seq_params=packed_seq_params,
)
else:
+ offload_context = contextlib.nullcontext()
+ if self.offload_core_attention and self.training:
+ query = group_prefetch_offload_start(query, name="core_attn")
+ offload_context = PipelineOffloadManager.get_instance()
if inference_context is None or inference_context.is_static_batching():
# Static batching attention kernel.
- core_attn_out = self.core_attention(
- query,
- key,
- value,
- attention_mask,
- attn_mask_type=attn_mask_type,
- attention_bias=attention_bias,
- packed_seq_params=packed_seq_params,
- )
+ with offload_context:
+ core_attn_out = self.core_attention(
+ query,
+ key,
+ value,
+ attention_mask,
+ attn_mask_type=attn_mask_type,
+ attention_bias=attention_bias,
+ packed_seq_params=packed_seq_params,
+ )
else:
# Dynamic batching attention kernel.
@@ -827,6 +862,9 @@ def forward(
block_table,
)
core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)')
+ if self.offload_core_attention and self.training:
+ core_attn_out, = group_prefetch_offload_commit(core_attn_out, name="core_attn", release_tensors=[query, key, value])
+ offload_context = contextlib.nullcontext()
if packed_seq_params is not None and packed_seq_params.qkv_format == 'thd':
# reshape to same output shape as unpacked case
@@ -841,7 +879,15 @@ def forward(
# =================
nvtx_range_push(suffix="linear_proj")
- output, bias = self.linear_proj(core_attn_out)
+ offload_context = contextlib.nullcontext()
+ if self.offload_attn_proj:
+ core_attn_out = group_prefetch_offload_start(core_attn_out, name="attn_proj")
+ offload_context = PipelineOffloadManager.get_instance()
+ with offload_context:
+ output, bias = self.linear_proj(core_attn_out)
+ if self.offload_attn_proj:
+ output, bias = group_prefetch_offload_commit(output, bias, name="attn_proj", release_tensors=[core_attn_out])
+ offload_context = contextlib.nullcontext()
nvtx_range_pop(suffix="linear_proj")
return output, bias
diff --git a/megatron/core/transformer/cpu_offload.py b/megatron/core/transformer/cpu_offload.py
new file mode 100644
index 00000000000..8dd5f139884
--- /dev/null
+++ b/megatron/core/transformer/cpu_offload.py
@@ -0,0 +1,513 @@
+from collections import deque, defaultdict
+import torch
+from typing import Any
+from transformer_engine.pytorch.float8_tensor import Float8Tensor
+from transformer_engine.pytorch.cpu_offload import AsyncDoubleBufferGroupOffloadHandler
+
+# cpu offload for pipeline
+DEBUG = False
+DEBUG_RANK = 0
+MIN_OFFLOADED_TENSOR_SIZE = 1024 * 1024
+
+def print_rank(message):
+ assert torch.distributed.is_initialized()
+ if DEBUG and torch.distributed.get_rank() == DEBUG_RANK:
+ print(message, flush=True)
+
+def set_ideal_affinity_for_current_gpu():
+ import cuda.cuda
+ import cuda.cudart
+ import pynvml
+ import uuid
+ err, device_id = cuda.cudart.cudaGetDevice()
+ assert err == cuda.cudart.cudaError_t.cudaSuccess
+ err, device_uuid = cuda.cuda.cuDeviceGetUuid(device_id)
+ assert err == cuda.cuda.CUresult.CUDA_SUCCESS
+ pynvml.nvmlInit()
+ handle = pynvml.nvmlDeviceGetHandleByUUID("GPU-" + str(uuid.UUID(bytes=device_uuid.bytes)))
+ pynvml.nvmlDeviceSetCpuAffinity(handle)
+
+class PipelineOffloadManager:
+ OFFLOAD_MGR = None
+ @classmethod
+ def get_instance(cls):
+ if cls.OFFLOAD_MGR is None:
+ cls.OFFLOAD_MGR = PipelineOffloadManager()
+ return cls.OFFLOAD_MGR
+
+ def __init__(self):
+ from megatron.core import parallel_state
+ self._queue = deque()
+ if parallel_state.get_virtual_pipeline_model_parallel_world_size() is None:
+ self._vpp = 1
+ else:
+ self._vpp = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+
+ # cache vpp - 1 stages
+ self._stages = [[] for _ in range(self._vpp)]
+ # allocate streams and events for synchronization
+ self._d2h_stream = torch.cuda.Stream()
+ self._h2d_stream = torch.cuda.Stream()
+ self.reset()
+
+ @property
+ def d2h_stream(self):
+ return self._d2h_stream
+
+ @property
+ def h2d_stream(self):
+ return self._h2d_stream
+
+ def reset(self):
+ set_ideal_affinity_for_current_gpu()
+ self.inside_context = False
+ self._cur_forward_chunk = None
+ self._cur_backward_chunk = None
+ self._first_last_vpp_rank = True
+
+ def flush(self):
+ # put into the queue in the backward order
+ if len(self._stages[0]) == len(self._stages[-1]):
+ lens = [len(e) for e in self._stages]
+ assert min(lens) == max(lens)
+ self._stages[-1] = []
+ for chunks in reversed(self._stages):
+ for chunk in chunks:
+ self.push(chunk)
+ for i in range(self._vpp):
+ self._stages[i] = []
+
+ def push(self, handler):
+ print_rank(f"pushing handler {handler}")
+ self._queue.append(handler)
+
+ def pop(self):
+ assert self.size()
+ while self._queue:
+ self._cur_backward_chunk = self._queue.popleft()
+ if not isinstance(self._cur_backward_chunk, NullChunkOffloadHandler):
+ break
+ print_rank(f"popping handler {self._cur_backward_chunk}")
+
+ def front(self):
+ if not len(self._queue):
+ return None
+ for chunk_handler in self._queue:
+ if not isinstance(chunk_handler, NullChunkOffloadHandler):
+ return chunk_handler
+ return None
+
+ def size(self):
+ return len(self._queue)
+
+ def reset_chunk_handler(self, num_layer, vp_stage, offload=True, num_dense_layer=0, last_stage_is_loss=False):
+ if vp_stage is None:
+ cur_vpp_rank = 0
+ else:
+ cur_vpp_rank = vp_stage
+
+ if last_stage_is_loss:
+ from megatron.core import parallel_state
+ vpp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+ # skip the last stage
+ if cur_vpp_rank == vpp_size - 1:
+ return
+ # reduce the vpp size
+ if self._vpp == vpp_size:
+ self._vpp -= 1
+ self._stages = self._stages[:-1]
+
+ first_last_vpp_rank = self._first_last_vpp_rank
+ # rewind
+ if cur_vpp_rank == self._vpp - 1:
+ self.flush()
+ first_last_vpp_rank = first_last_vpp_rank and (cur_vpp_rank == self._vpp - 1)
+ # If the model chunk contains only the dense layers, initialize a null chunk handler.
+ if num_layer <= num_dense_layer:
+ cur_chunk = NullChunkOffloadHandler(num_layer, first_last_vpp_rank, offload)
+ else:
+ cur_chunk = ChunkOffloadHandler(num_layer, first_last_vpp_rank, offload)
+ # save for latter push
+ self._stages[cur_vpp_rank].append(cur_chunk)
+ if cur_vpp_rank == self._vpp - 1:
+ self._first_last_vpp_rank = False
+ self.push(cur_chunk)
+ self.flush()
+ self._cur_forward_chunk = cur_chunk
+ cur_chunk.vpp_rank = cur_vpp_rank
+
+ def set_last_layer(self, is_last_layer):
+ self._cur_forward_chunk.is_last_layer = is_last_layer
+
+ def cur_forward_chunk(self):
+ return self._cur_forward_chunk
+
+ def cur_backward_chunk(self):
+ return self._cur_backward_chunk
+
+ def __enter__(self):
+ print_rank("__enter__")
+ self.inside_context = True
+
+ if not isinstance(self.cur_forward_chunk(), NullChunkOffloadHandler):
+ torch._C._autograd._push_saved_tensors_default_hooks(
+ self.on_save_for_backward, self.on_get_saved_tensor
+ )
+
+ def __exit__(self, *args: Any):
+ print_rank("__exit__")
+ self.inside_context = False
+ if not isinstance(self.cur_forward_chunk(), NullChunkOffloadHandler):
+ torch._C._autograd._pop_saved_tensors_default_hooks()
+
+ def on_save_for_backward(self, tensor: torch.Tensor) -> Any:
+ print_rank(f"on_save_for_backward {tensor.shape}")
+ assert self.inside_context
+ return self.cur_forward_chunk().tensor_push(tensor)
+
+ def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor:
+ print_rank(f"on_get_saved_tensor {saved_state}")
+ return self.cur_backward_chunk().tensor_pop(saved_state)
+
+
+class ChunkOffloadHandler(AsyncDoubleBufferGroupOffloadHandler):
+ @staticmethod
+ def offload(src_tensor, pin_memory=True):
+ """Offload."""
+ print_rank("offload")
+ fp8_offload = isinstance(src_tensor, Float8Tensor)
+
+ cpu_backup = torch.empty(
+ src_tensor.size(),
+ dtype=torch.uint8 if fp8_offload else src_tensor.dtype,
+ layout=src_tensor.layout,
+ device="cpu",
+ pin_memory=pin_memory,
+ )
+
+ if fp8_offload:
+ cpu_backup = Float8Tensor.make_like(src_tensor, data=cpu_backup)
+
+ if not src_tensor.is_contiguous():
+ src_tensor = src_tensor.contiguous()
+
+ cpu_backup.copy_(src_tensor, non_blocking=pin_memory)
+ state = (src_tensor.device, cpu_backup)
+ return state
+
+ @staticmethod
+ def reload(state, non_blocking=None):
+ """Reload."""
+ print_rank("reload")
+ dev, cpu_backup = state
+ if non_blocking is None:
+ non_blocking = cpu_backup.is_pinned()
+ return cpu_backup.to(dev, non_blocking=non_blocking)
+
+ def __init__(self, num_layer, is_first_last_vpp_chunk, offload=True):
+ self._num_layers = num_layer
+ # Data Structure to maintain reference to activation tensors
+ self._tensor_tag_to_state = {}
+ # Tracking the number of layers offloaded
+ # self._offloaded_group_count = 0
+ self._is_first_last_vpp_chunk = is_first_last_vpp_chunk
+
+ self._offloaded_group_index = 0
+ self._groups_to_offload = []
+ self._groups_to_reload = []
+ self._layer_index = 0
+ self._tensor_count_current_group = 0
+ self.multi_input_offload_count = False
+ # self.offload_count_per_layer = defaultdict(int)
+
+ self.torch_tensor_count = 0
+ self.d2h_stream = PipelineOffloadManager.get_instance().d2h_stream
+ self.h2d_stream = PipelineOffloadManager.get_instance().h2d_stream
+ self._offload_events = {}
+ self._reload_events = {}
+ self.do_offload = offload
+ self.is_last_layer = False
+
+
+ def is_first_last_layer(self):
+ """Do not offload the last layer of the last pp stage."""
+ print_rank(f"is_first_last_layer {self._is_first_last_vpp_chunk} {self.is_last_layer}")
+ return self._is_first_last_vpp_chunk and self.is_last_layer
+
+ def tensor_push(self, tensor):
+ print_rank("tensor_push")
+ torch_stray_tensor = isinstance(
+ tensor,
+ (
+ torch._subclasses.fake_tensor.FakeTensor,
+ torch._subclasses.functional_tensor.FunctionalTensor,
+ ),
+ )
+
+ if not torch_stray_tensor:
+ # obtain a unique tensor tag
+ tensor_tag = (self._offloaded_group_index, self._tensor_count_current_group)
+ self._tensor_count_current_group += 1
+ assert tensor_tag not in self._tensor_tag_to_state
+ self._tensor_tag_to_state[tensor_tag] = tensor
+ else:
+ tensor_tag = (-1, self.torch_tensor_count)
+ self.torch_tensor_count += 1
+ self._tensor_tag_to_state[tensor_tag] = tensor
+ print_rank(f"tensor_push {tensor_tag}")
+ return tensor_tag
+
+ def tensor_pop(self, tensor_tag):
+ print_rank(f"tensor_pop {tensor_tag}")
+ assert tensor_tag in self._tensor_tag_to_state, f"{tensor_tag}, {self._tensor_tag_to_state.keys()}"
+ tensor = self._tensor_tag_to_state.pop(tensor_tag)
+ assert not isinstance(tensor, tuple)
+ print_rank(f"tensor_pop {tensor.shape}")
+ return tensor
+
+ def tensor_need_offloading_checker(self, tensor):
+ """Check if the tensor needs to be offloaded."""
+ if tensor.numel() < MIN_OFFLOADED_TENSOR_SIZE:
+ return False
+ if hasattr(tensor, "offloading_activation") and not tensor.offloading_activation:
+ return False
+ return True
+
+ def bulk_offload_group(self, group_to_offload):
+ """Offload a group of tensors recorded in tensor_push()."""
+ print_rank("bulk_offload_group")
+ if not self.do_offload:
+ return
+ assert not self.is_first_last_layer()
+ group_id_to_offload, name = group_to_offload
+ torch.cuda.nvtx.range_push(name)
+ with torch.cuda.stream(self.d2h_stream):
+ for tensor_tag, state in self._tensor_tag_to_state.items():
+ group_id, _ = tensor_tag
+ if group_id == group_id_to_offload:
+ print_rank(f"tensor_tag {tensor_tag}")
+ print_rank(f"group_to_offload {group_to_offload}")
+ assert not isinstance(state, tuple)
+ tensor_on_device = state
+ if self.tensor_need_offloading_checker(tensor_on_device):
+ state = self.offload(tensor_on_device)
+ # TODO: check if we really need it.
+ # Record the last offloading event for this group,
+ # which is used to avoid reloading before offloading.
+ event = torch.cuda.Event()
+ event.record(self.d2h_stream)
+ self._offload_events[name] = event
+ tensor_on_device.record_stream(self.d2h_stream)
+ self._tensor_tag_to_state[tensor_tag] = state
+ print_rank("exit bulk_offload_group")
+ torch.cuda.nvtx.range_pop()
+
+ def get_offload_event(self, name):
+ if name in self._offload_events:
+ return self._offload_events[name]
+ else:
+ return None
+
+ def get_reload_event(self, name):
+ if name in self._reload_events:
+ return self._reload_events[name]
+ else:
+ return None
+
+ def bulk_reload_group(self, group_to_reload):
+ """Bulk reload group."""
+ print_rank("bulk_reload_group")
+ if not self.do_offload:
+ return
+ found_reload_group = False
+ group_id_to_reload, name = group_to_reload
+ torch.cuda.nvtx.range_push(name)
+ with torch.cuda.stream(self.h2d_stream):
+ # move back tensors
+ for tensor_label, state in self._tensor_tag_to_state.items():
+ group_id, _ = tensor_label
+ if group_id == group_id_to_reload:
+ print_rank(f"tensor_label {tensor_label}")
+ found_reload_group = True
+ event = self.get_offload_event(name)
+ if isinstance(state, tuple):
+ # make sure the tensor is already offloaded to cpu before reloading it.
+ torch.cuda.current_stream().wait_event(event)
+ recovered_tensor = self.reload(state)
+ event.record(self.h2d_stream)
+ self._reload_events[name] = event
+ print_rank(f"recovered_tensor {recovered_tensor.shape}")
+ self._tensor_tag_to_state[tensor_label] = recovered_tensor
+ torch.cuda.nvtx.range_pop()
+ return found_reload_group
+
+ def pre_reload_last_layer(self):
+ """Pre-reload the last layer of the next model chunk."""
+ print_rank("pre_reload_last_layer")
+ if not self.do_offload:
+ return
+ assert not self._is_first_last_vpp_chunk
+ print_rank(f"len(self._groups_to_reload) {len(self._groups_to_reload)}")
+ if len(self._groups_to_reload) > 0:
+ if self.bulk_reload_group(self._groups_to_reload[-1]):
+ self._groups_to_reload.pop()
+
+ def should_bulk_offload(self):
+ """Check if the chunk should be offloaded."""
+ if not self.do_offload:
+ return False
+ # first backward chunk
+ if self.is_first_last_layer():
+ return False
+
+ # if next backward chunk is this chunk (for last pp stage)
+ next_backward_chunk = PipelineOffloadManager.get_instance().front()
+ if next_backward_chunk is not None and next_backward_chunk is self:
+ if self.is_last_layer:
+ return False
+
+ return True
+
+ def bulk_offload(self, release_tensors):
+ print_rank("bulk_offload")
+ if self.should_bulk_offload():
+ group_to_offload = self._groups_to_offload.pop()
+ name = group_to_offload[1]
+ self._groups_to_reload.append(group_to_offload)
+ self.bulk_offload_group(group_to_offload)
+ if len(release_tensors) > 0:
+ cur_stream = torch.cuda.current_stream()
+ for release_tensor in release_tensors:
+ release_tensor.record_stream(cur_stream)
+ release_tensor.untyped_storage().resize_(0)
+ print_rank("exit bulk_offload")
+
+ def on_group_commit_forward(self, release_tensors):
+ """Offload a group of tensors."""
+ print_rank("on_group_commit_forward")
+ # wait for the compute stream for offloading
+ self.d2h_stream.wait_stream(torch.cuda.current_stream())
+ self.bulk_offload(release_tensors)
+ print_rank("exit on_group_commit_forward")
+
+ def bulk_reload(self):
+ print_rank("bulk_reload")
+ if len(self._groups_to_reload) > 0:
+ # load next layer
+ if self.bulk_reload_group(self._groups_to_reload[-1]):
+ print_rank(f"bulk_reload_group {self._groups_to_reload}")
+ self._groups_to_reload.pop()
+ else:
+ # load the last layer of one backward chunk in advance
+ next_backward_chunk = PipelineOffloadManager.get_instance().front()
+ if next_backward_chunk is not None:
+ next_backward_chunk.pre_reload_last_layer()
+
+ def on_group_commit_backward(self, name):
+ """Prepare for reloading the next group of tensors."""
+ print_rank("on_group_commit_backward")
+ cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk()
+ if not cur_backward_chunk is self:
+ PipelineOffloadManager.get_instance().pop()
+ cur_backward_chunk = PipelineOffloadManager.get_instance().cur_backward_chunk()
+ assert cur_backward_chunk is self
+ # make sure the reloading jobs for current computation are done.
+ event = self.get_reload_event(name)
+ if event is not None:
+ torch.cuda.current_stream().wait_event(event)
+ self._offloaded_group_index = self._offloaded_group_index - 1
+
+ def on_group_start_forward(self, name):
+ """Prepare for offloading the next group of tensors."""
+ print_rank(f"on_group_start_forward {self._layer_index} {self._num_layers}")
+ self._offloaded_group_index = self._offloaded_group_index + 1
+ self._tensor_count_current_group = 0
+ self._groups_to_offload.append((self._offloaded_group_index, name))
+
+ def on_group_start_backward(self):
+ """Reload the next group of tensors."""
+ print_rank("on_group_start_backward")
+ self.h2d_stream.wait_stream(torch.cuda.current_stream())
+ self.bulk_reload()
+
+class NullChunkOffloadHandler(ChunkOffloadHandler):
+ pass
+
+class GroupCommitFunction(torch.autograd.Function):
+ """This is a dummy op whose output is identical to its input.
+ However, it is necessary for marking a timepoint at which the offload handler
+ accomplishes all synchronizations. Implementing it as a function is necessary
+ because we need to take actions in both forward and backward.
+ """
+
+ @staticmethod
+ def forward(ctx, *args):
+ # pylint: disable=missing-function-docstring
+ print_rank("GroupCommitFunction forward")
+
+ release_tensors = args[-1]
+ name = args[-2]
+ cpu_offload_handler = args[-3]
+ tensor = args[:-3]
+ if not isinstance(cpu_offload_handler, NullChunkOffloadHandler):
+ cpu_offload_handler.on_group_commit_forward(release_tensors)
+ ctx.cpu_offload_handler = cpu_offload_handler
+ ctx.name = name
+
+ # return the identical tensor
+ return tensor
+
+ @staticmethod
+ def backward(ctx, *grad_output):
+ # pylint: disable=missing-function-docstring
+ print_rank("GroupCommitFunction backward")
+
+ cpu_offload_handler = ctx.cpu_offload_handler
+ if not isinstance(cpu_offload_handler, NullChunkOffloadHandler):
+ cpu_offload_handler.on_group_commit_backward(ctx.name)
+ return grad_output + (None, None, None)
+
+
+def group_prefetch_offload_commit(*tensor, name, release_tensors=()):
+ """Specify the tensors to be released after offloading.
+ release_tensors is a list of tensors to be released after offloading;
+ they are freed via untyped_storage().resize_(0) once the offload completes.
+ Note: specify tensors only when they are not automatically released by torch gc.
+ """
+ cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk()
+ return GroupCommitFunction.apply(*tensor, cur_forward_chunk, name, release_tensors)
+
+
+class GroupStartFunction(torch.autograd.Function):
+ """this is a dummy op with output identical to input.
+ However, it is necessary for marking a timepoint for offload handler to
+ accomplish all synchronizations. Implementing it as a function is necessary
+ because we need to actions in both forward and backward.
+ """
+
+ @staticmethod
+ def forward(ctx, tensor, cpu_offload_handler, name):
+ # pylint: disable=missing-function-docstring
+ ctx.cpu_offload_handler = cpu_offload_handler
+ print_rank("GroupStartFunction forward")
+
+ if not isinstance(cpu_offload_handler, NullChunkOffloadHandler):
+ cpu_offload_handler.on_group_start_forward(name)
+ # return the identical tensor
+ return tensor
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ print_rank("GroupStartFunction backward")
+ # pylint: disable=missing-function-docstring
+ cpu_offload_handler = ctx.cpu_offload_handler
+ if not isinstance(cpu_offload_handler, NullChunkOffloadHandler):
+ cpu_offload_handler.on_group_start_backward()
+ return grad_output, None, None
+
+
+def group_prefetch_offload_start(tensor, name=None):
+ cur_forward_chunk = PipelineOffloadManager.get_instance().cur_forward_chunk()
+ return GroupStartFunction.apply(tensor, cur_forward_chunk, name)
diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py
index fc741aa46f3..eddf61e6f8a 100644
--- a/megatron/core/transformer/moe/experts.py
+++ b/megatron/core/transformer/moe/experts.py
@@ -10,6 +10,7 @@
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
+import contextlib
from megatron.core import parallel_state, tensor_parallel
from megatron.core.activations import squared_relu
@@ -41,6 +42,11 @@
make_sharded_object_for_checkpoint,
sharded_state_dict_default,
)
+from megatron.core.transformer.cpu_offload import (
+ PipelineOffloadManager,
+ group_prefetch_offload_start,
+ group_prefetch_offload_commit,
+)
try:
import transformer_engine as te # pylint: disable=unused-import
@@ -805,6 +811,16 @@ def __init__(
tp_group=pg_collection.expt_tp,
)
+ self.offload_expert_fc1 = (
+ self.config.fine_grained_activation_offloading
+ and "expert_fc1" in self.config.offload_modules
+ )
+
+ self.offload_moe_act = (
+ self.config.fine_grained_activation_offloading
+ and "moe_act" in self.config.offload_modules
+ )
+
self.activation_recompute = (
self.config.recompute_granularity == 'selective'
and "moe_act" in self.config.recompute_modules
@@ -813,6 +829,11 @@ def __init__(
from megatron.core.extensions.transformer_engine import set_save_original_input
set_save_original_input(self.linear_fc2)
+
+ # This is to avoid the CPU overhead of multiple d2h copies
+ if self.offload_expert_fc1:
+ from megatron.core.extensions.transformer_engine import set_save_original_input
+ set_save_original_input(self.linear_fc1)
if self.config.fp8:
assert HAVE_TE, "FP8 requires TE."
@@ -858,9 +879,17 @@ def forward(
# Probs already applied, so reset to 1.
permuted_probs = torch.ones_like(permuted_probs)
- intermediate_parallel, bias_parallel = self.linear_fc1(
- permuted_local_hidden_states, tokens_per_expert
+ offload_context = contextlib.nullcontext()
+ if self.offload_expert_fc1:
+ permuted_local_hidden_states = group_prefetch_offload_start(permuted_local_hidden_states, name="expert_fc1")
+ offload_context = PipelineOffloadManager.get_instance()
+ with offload_context:
+ fc1_output, bias_parallel = self.linear_fc1(
+ permuted_local_hidden_states, tokens_per_expert
)
+ if self.offload_expert_fc1:
+ fc1_output, bias_parallel = group_prefetch_offload_commit(fc1_output, bias_parallel, name="expert_fc1", release_tensors=[])
+ offload_context = contextlib.nullcontext()
def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs):
if self.config.use_te_activation_func:
@@ -931,18 +960,29 @@ def glu(x):
intermediate_parallel = intermediate_parallel.to(original_dtype)
return intermediate_parallel
+ if self.offload_moe_act:
+ fc1_output = group_prefetch_offload_start(fc1_output, name="moe_act")
+ offload_context = PipelineOffloadManager.get_instance()
+
if self.activation_recompute:
self.activation_checkpoint = tensor_parallel.CheckpointWithoutOutput()
- intermediate_parallel = self.activation_checkpoint.checkpoint(
- bias_act_func, intermediate_parallel, bias_parallel, permuted_probs
- )
- output, output_bias = self.linear_fc2(intermediate_parallel, tokens_per_expert)
- self.activation_checkpoint.discard_output_and_register_recompute(output)
+ with offload_context:
+ bias_act_output = self.activation_checkpoint.checkpoint(
+ bias_act_func, fc1_output, bias_parallel, permuted_probs
+ )
else:
- intermediate_parallel = bias_act_func(
- intermediate_parallel, bias_parallel, permuted_probs
- )
- output, output_bias = self.linear_fc2(intermediate_parallel, tokens_per_expert)
+ with offload_context:
+ bias_act_output = bias_act_func(
+ fc1_output, bias_parallel, permuted_probs
+ )
+
+ output, output_bias = self.linear_fc2(bias_act_output, tokens_per_expert)
+ if self.activation_recompute:
+ self.activation_checkpoint.discard_output_and_register_recompute(output)
+ if self.offload_moe_act:
+ output, = group_prefetch_offload_commit(output, name="moe_act", release_tensors=[])
+ offload_context = contextlib.nullcontext()
+
# upad and concat the output
if self.config.fp8:
diff --git a/megatron/core/transformer/multi_latent_attention.py b/megatron/core/transformer/multi_latent_attention.py
index 9e6b46fd4e5..3a16f0b5377 100644
--- a/megatron/core/transformer/multi_latent_attention.py
+++ b/megatron/core/transformer/multi_latent_attention.py
@@ -37,6 +37,12 @@
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import MLATransformerConfig
from megatron.core.utils import deprecate_inference_params, is_te_min_version
+from megatron.core.transformer.cpu_offload import (
+ PipelineOffloadManager,
+ group_prefetch_offload_start,
+ group_prefetch_offload_commit,
+)
+import contextlib
try:
from megatron.core.fusions.fused_mla_yarn_rope_apply import (
@@ -264,8 +270,14 @@ def forward(
query, key, value, attention_mask, packed_seq_params=packed_seq_params
)
else:
+ offload_context = contextlib.nullcontext()
+ if self.offload_core_attention and self.training:
+ query = group_prefetch_offload_start(query, name="core_attn")
+ offload_context = PipelineOffloadManager.get_instance()
+
if inference_context is None or inference_context.is_static_batching():
- core_attn_out = self.core_attention(
+ with offload_context:
+ core_attn_out = self.core_attention(
query,
key,
value,
@@ -293,6 +305,9 @@ def forward(
# Only rearrange if not in absorption mode (Flash MLA handles format correctly)
if not inference_context.is_decode_only():
core_attn_out = rearrange(core_attn_out, 's b h d -> s b (h d)')
+ if self.offload_core_attention and self.training:
+ core_attn_out, = group_prefetch_offload_commit(core_attn_out, name="core_attn", release_tensors=[query, key, value])
+ offload_context = contextlib.nullcontext()
# We are doing absorption with cache mla latents and decode mode.
if self.cache_mla_latents and inference_context.is_decode_only():
@@ -318,7 +333,15 @@ def forward(
# =================
# Output. [sq, b, h]
# =================
- output, bias = self.linear_proj(core_attn_out)
+ offload_context = contextlib.nullcontext()
+ if self.offload_attn_proj:
+ core_attn_out = group_prefetch_offload_start(core_attn_out, name="attn_proj")
+ offload_context = PipelineOffloadManager.get_instance()
+ with offload_context:
+ output, bias = self.linear_proj(core_attn_out)
+ if self.offload_attn_proj:
+ output, bias = group_prefetch_offload_commit(output, bias, name="attn_proj", release_tensors=[core_attn_out])
+ offload_context = contextlib.nullcontext()
return output, bias
diff --git a/megatron/core/transformer/transformer_block.py b/megatron/core/transformer/transformer_block.py
index c14f8ea1f74..867b3689a1d 100755
--- a/megatron/core/transformer/transformer_block.py
+++ b/megatron/core/transformer/transformer_block.py
@@ -33,6 +33,7 @@
get_pg_rank,
make_viewless_tensor,
)
+from megatron.core.transformer.cpu_offload import PipelineOffloadManager
try:
import transformer_engine.pytorch as te # pylint: disable=unused-import
@@ -324,6 +325,12 @@ def __init__(
self._build_layers()
self.num_layers_per_pipeline_rank = len(self.layers)
+ self.num_dense_layer = 0
+ from megatron.core.transformer.moe.moe_layer import MoELayer
+ for layer in self.layers:
+ if not isinstance(layer.mlp, MoELayer):
+ self.num_dense_layer += 1
+
def _build_layers(self):
# Transformer layers.
# @jcasper can we improve how we deal with layer_number?
@@ -639,6 +646,11 @@ def forward(
inner_quantization_context = nullcontext()
else:
inner_quantization_context = nullcontext()
+
+ PipelineOffloadManager.get_instance().set_last_layer(l_no == self.num_layers_per_pipeline_rank - 1)
with self.offload_context, inner_quantization_context:
hidden_states, context = layer(
diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
index 7272c7a136a..561226d1ec2 100644
--- a/megatron/core/transformer/transformer_config.py
+++ b/megatron/core/transformer/transformer_config.py
@@ -700,6 +700,26 @@ class TransformerConfig(ModelParallelConfig):
"""Transformer implementation to use.
Options are 'transformer_engine' for Transformer Engine and 'local' for MCore."""
+ #####################################
+ # Fine-grained Activation Offloading
+ #####################################
+ fine_grained_activation_offloading: bool = False
+ """If True, offload the activation to the CPU."""
+
+ offload_modules: Optional[list[str]] = None
+ """The submodules to offload.
+ choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".
+ default: ["core_attn"].
+ "attn_norm": offload the input of the normalization in the attention part.
+ "core_attn": offload the input of the core attention part.
+ "mlp_norm": offload the input of the normalization in the mlp part.
+ "attn_proj": offload the input of the attn linear projection part.
+ "expert_fc1": offload the input of the expert fc1 part.
+ "moe_act": offload the input of the moe act part.
+ """
+ last_vp_stage_is_loss: bool = False
+ """If True, the last virtual pipeline stage is the loss stage."""
+
def __post_init__(self):
"""Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
@@ -990,6 +1010,34 @@ def __post_init__(self):
if "moe" not in self.recompute_modules:
self.recompute_modules.append("moe")
+ if self.offload_modules is None:
+ self.offload_modules = []
+
+ if self.offload_modules:
+ allowed_modules = {
+ "core_attn", "attn_proj", "expert_fc1", "moe_act", "attn_norm", "mlp_norm"
+ }
+ invalid_modules = set(self.offload_modules) - allowed_modules
+ assert not invalid_modules, (
+ f'Invalid choices for offload_modules: {invalid_modules}. '
+ f'Allowed modules are: {allowed_modules}'
+ )
+ if "attn_proj" in self.offload_modules and "core_attn" not in self.offload_modules:
+ raise ValueError(
+ "attn_proj cannot be set to offload_modules alone without core_attn "
+ "because the input of attn_proj is the output of core_attn, "
+ "which is needed in core_attn.backward()."
+ )
+
+ if isinstance(self.moe_layer_freq, int):
+ assert self.moe_layer_freq == 1, "moe_layer_freq cannot be an integer other than 1 when offload_modules is set."
+ elif isinstance(self.moe_layer_freq, list):
+ if 0 in self.moe_layer_freq:
+ warnings.warn(
+ "Activation of dense layer won't be offloaded at all for mixed dense and moe layer."
+ )
+
if (
self.num_layers_in_first_pipeline_stage is not None
or self.num_layers_in_last_pipeline_stage is not None
@@ -1171,6 +1219,13 @@ def __post_init__(self):
f"{self.virtual_pipeline_model_parallel_size}"
)
+ if self.offload_modules:
+ if self.pipeline_model_parallel_layout is not None:
+ from megatron.core.transformer.pipeline_parallel_layer_layout import LayerType
+ if (len(self.pipeline_model_parallel_layout.layout[-1][-1]) == 1 and
+ self.pipeline_model_parallel_layout.layout[-1][-1][0] is LayerType.loss):
+ self.last_vp_stage_is_loss = True
+
if self.apply_query_key_layer_scaling:
self.attention_softmax_in_fp32 = True
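The `offload_modules` checks added to `__post_init__` above can be restated as a stand-alone helper (the function name is ours, for illustration). The key constraint is the `attn_proj`/`core_attn` dependency: `attn_proj`'s input is `core_attn`'s output, which `core_attn.backward()` still needs, so offloading `attn_proj` alone would keep that tensor alive on GPU anyway.

```python
# Allowed module names, as listed in the transformer_config docstring above.
ALLOWED_OFFLOAD_MODULES = {
    "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act"
}


def validate_offload_modules(offload_modules):
    """Reject unknown names and enforce the attn_proj -> core_attn dependency."""
    invalid = set(offload_modules) - ALLOWED_OFFLOAD_MODULES
    if invalid:
        raise ValueError(f"Invalid choices for offload_modules: {invalid}. "
                         f"Allowed modules are: {ALLOWED_OFFLOAD_MODULES}")
    if "attn_proj" in offload_modules and "core_attn" not in offload_modules:
        # attn_proj's input is core_attn's output, needed in core_attn.backward().
        raise ValueError("attn_proj cannot be offloaded without core_attn")
    return list(offload_modules)
```

For example, `validate_offload_modules(["core_attn", "attn_proj"])` passes, while `["attn_proj"]` alone or any unknown name raises.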
diff --git a/megatron/core/transformer/transformer_layer.py b/megatron/core/transformer/transformer_layer.py
index 42d077e2e9f..ed5ace3fa6c 100644
--- a/megatron/core/transformer/transformer_layer.py
+++ b/megatron/core/transformer/transformer_layer.py
@@ -9,6 +9,7 @@
import torch
import torch.distributed
from torch import Tensor
+import contextlib
from megatron.core import parallel_state, tensor_parallel
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
@@ -31,6 +32,11 @@
nvtx_range_pop,
nvtx_range_push,
)
+from megatron.core.transformer.cpu_offload import (
+ PipelineOffloadManager,
+ group_prefetch_offload_start,
+ group_prefetch_offload_commit,
+)
logger = logging.getLogger(__name__)
@@ -427,6 +433,20 @@ def __init__(
if "mlp" in self.config.recompute_modules:
if not isinstance(self.mlp, MoELayer):
self.recompute_mlp = True
+ self.offload_self_attn = (
+ self.config.fine_grained_activation_offloading
+ and "self_attn" in self.config.offload_modules
+ )
+ self.offload_attn_norm = (
+ self.config.fine_grained_activation_offloading
+ and "attn_norm" in self.config.offload_modules
+ and not isinstance(self.input_layernorm, IdentityOp)
+ )
+ self.offload_mlp_norm = (
+ self.config.fine_grained_activation_offloading
+ and "mlp_norm" in self.config.offload_modules
+ and not isinstance(self.pre_mlp_layernorm, IdentityOp)
+ )
# @jcasper how should we handle nvfuser?
# Set bias+dropout+add fusion grad_enable execution handler.
@@ -510,18 +530,29 @@ def _forward_attention(
# Residual connection.
residual = hidden_states
+ offload_context = contextlib.nullcontext()
+ if self.offload_attn_norm:
+ hidden_states = group_prefetch_offload_start(hidden_states, name="attn_norm")
+ offload_context = PipelineOffloadManager.get_instance()
# Optional Input Layer norm
if self.recompute_input_layernorm:
self.input_layernorm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
- input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
- self.input_layernorm, hidden_states
- )
+ with offload_context:
+ input_layernorm_output = self.input_layernorm_checkpoint.checkpoint(
+ self.input_layernorm, hidden_states
+ )
else:
- input_layernorm_output = self.input_layernorm(hidden_states)
+ with offload_context:
+ input_layernorm_output = self.input_layernorm(hidden_states)
# Self attention.
nvtx_range_push(suffix="self_attention")
- attention_output_with_bias = self.self_attention(
+ offload_context = contextlib.nullcontext()
+ if self.offload_self_attn:
+ input_layernorm_output = group_prefetch_offload_start(input_layernorm_output, name="self_attn")
+ offload_context = PipelineOffloadManager.get_instance()
+ with offload_context:
+ attention_output_with_bias = self.self_attention(
input_layernorm_output,
attention_mask=attention_mask,
inference_context=inference_context,
@@ -532,6 +563,9 @@ def _forward_attention(
packed_seq_params=packed_seq_params,
sequence_len_offset=sequence_len_offset,
)
+ if self.offload_self_attn:
+ attention_output_with_bias, = group_prefetch_offload_commit(attention_output_with_bias, name="self_attn", release_tensors=[input_layernorm_output])
+ offload_context = contextlib.nullcontext()
nvtx_range_pop(suffix="self_attention")
if self.recompute_input_layernorm:
@@ -550,6 +584,10 @@ def _forward_attention(
)
nvtx_range_pop(suffix="self_attn_bda")
+ if self.offload_attn_norm:
+ hidden_states, = group_prefetch_offload_commit(hidden_states, name="attn_norm", release_tensors=[residual])
+ offload_context = contextlib.nullcontext()
+
# Residual connection.
residual = hidden_states
@@ -590,14 +628,20 @@ def _forward_mlp(self, hidden_states, inference_context=None):
# Residual connection.
residual = hidden_states
+ offload_context = contextlib.nullcontext()
+ if self.offload_mlp_norm:
+ hidden_states = group_prefetch_offload_start(hidden_states, name="mlp_norm")
+ offload_context = PipelineOffloadManager.get_instance()
# Optional Layer norm post the cross-attention.
if self.recompute_pre_mlp_layernorm:
self.pre_mlp_norm_checkpoint = tensor_parallel.CheckpointWithoutOutput()
- pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
- self.pre_mlp_layernorm, hidden_states
- )
+ with offload_context:
+ pre_mlp_layernorm_output = self.pre_mlp_norm_checkpoint.checkpoint(
+ self.pre_mlp_layernorm, hidden_states
+ )
else:
- pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
+ with offload_context:
+ pre_mlp_layernorm_output = self.pre_mlp_layernorm(hidden_states)
nvtx_range_push(suffix="mlp")
# Potentially chunk the MLP computation during prefill to minimize the peak activation size
@@ -657,6 +701,9 @@ def _forward_mlp(self, hidden_states, inference_context=None):
mlp_output_with_bias, residual, self.hidden_dropout
)
nvtx_range_pop(suffix="mlp_bda")
+ if self.offload_mlp_norm:
+ hidden_states, = group_prefetch_offload_commit(hidden_states, name="mlp_norm", release_tensors=[residual])
+ offload_context = contextlib.nullcontext()
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
diff --git a/megatron/core/utils.py b/megatron/core/utils.py
index e66dab8abd8..7e1b9d9251a 100644
--- a/megatron/core/utils.py
+++ b/megatron/core/utils.py
@@ -136,7 +136,7 @@ def wrapped_func(*args, **kwargs):
if config.is_experimental_enabled() is not True:
raise ExperimentalNotEnabledError(f"Flag config.ENABLE_EXPERIMENTAL not enabled.")
- logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.")
+ # logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.")
return func(*args, **kwargs)
@@ -213,7 +213,7 @@ def guard(super: super, attr: str):
f"Flag config.ENABLE_EXPERIMENTAL not enabled."
)
- logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.")
+ # logger.info("Setting ENABLE_EXPERIMENTAL=True will run experimental code.")
return super.__getattribute__(attr)
class ClassInterceptor(type):
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index c8102c48cf4..5b4ab369840 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -2189,7 +2189,10 @@ def _add_training_args(parser):
help='The communicator group names to use high priority streams.')
group.add_argument('--use-te-activation-func', action='store_true',
help='Use activation function kernel from Transformer Engine in MLP module.')
-
+ group.add_argument('--fine-grained-activation-offloading', action='store_true',
+ help='Offload the activation to the CPU.')
+ group.add_argument('--offload-modules', nargs='*', type=str, default=[],
+ help='The submodules to offload. Choices: "attn_norm", "core_attn", "attn_proj", "mlp_norm", "expert_fc1", "moe_act".')
return parser
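The two new flags parse as a boolean plus a space-separated list. A minimal stand-alone parser with the same definitions as in `_add_training_args` above shows how the CLI values come through (argparse converts the dashes to underscores on the namespace):

```python
import argparse

# Same two flag definitions as added to _add_training_args above.
parser = argparse.ArgumentParser()
parser.add_argument('--fine-grained-activation-offloading', action='store_true',
                    help='Offload the activation to the CPU.')
parser.add_argument('--offload-modules', nargs='*', type=str, default=[],
                    help='The submodules to offload.')

args = parser.parse_args(['--fine-grained-activation-offloading',
                          '--offload-modules', 'core_attn', 'attn_proj'])
# args.fine_grained_activation_offloading -> True
# args.offload_modules -> ['core_attn', 'attn_proj']
```

With neither flag given, offloading stays disabled and `offload_modules` is the empty list, matching the `default=[]` above.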
diff --git a/megatron/training/training.py b/megatron/training/training.py
index f4a3eb9bef6..bc141e0c4dd 100644
--- a/megatron/training/training.py
+++ b/megatron/training/training.py
@@ -589,6 +589,11 @@ def pretrain(
args = get_args()
timers = get_timers()
+ if args.profile and torch.distributed.get_rank() in args.profile_ranks:
+ torch.cuda.memory._record_memory_history(
+ max_entries=1000000,
+ )
+
if args.log_progress:
append_to_progress_log("Starting job")
@@ -1625,7 +1630,7 @@ def training_log(
total_loss_dict[skipped_iters_key] = 0
total_loss_dict[nan_iters_key] = 0
print_rank_last(log_string)
- if report_memory_flag:
+ if report_memory_flag and iteration == 10:
# Report memory after optimizer state has been initialized.
if torch.distributed.get_rank() == 0:
num_microbatches = get_num_microbatches()
@@ -2173,6 +2178,15 @@ def get_e2e_base_metrics():
# Run training iterations till done.
buffered_rollouts = None
while iteration < args.train_iters:
+ if args.profile and torch.distributed.get_rank() in args.profile_ranks and iteration == 3:
+ try:
+ comment = os.getenv("COMMENT")
+ model_name = os.getenv("MODEL")
+ memory_snapshot_path = f"/lustre/fsw/coreai_devtech_all/hongbinl/cpu_offloading/megatron-moe-scripts/pyt_profile/{model_name}_{comment}"
+ torch.cuda.memory._dump_snapshot(f"{memory_snapshot_path}.pickle")
+ print_rank_0(f"Captured memory snapshot at {memory_snapshot_path}.pickle")
+ except Exception as e:
+ print_rank_0(f"Failed to capture memory snapshot {e}")
if args.profile and torch.distributed.get_rank() in args.profile_ranks:
if args.use_pytorch_profiler:
prof.step()