Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion megatron/core/optimizer/distrib_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,7 +771,7 @@ def make_needed_groups(param_group):
# Allocate dummy tensors.
numel = len(param_range_map["gbuf_world"])
init_shard = lambda dtype=torch.float32: torch.empty(
(numel,), dtype=dtype, device=torch.cuda.current_device()
(numel,), dtype=dtype, device="cpu"
)

# For precision_aware_optimizer, the empty tensors should also be
Expand Down
13 changes: 2 additions & 11 deletions megatron/core/transformer/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,17 +340,8 @@ def sh_ten_build_fn(

def sh_ten_merge_fn(sub_state_dict):
with torch.no_grad():
try:
return torch.cat(sub_state_dict)
except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
logger.warning(
f"CUDA OutOfMemoryError encountered during tensors merging."
f" Switching to CPU merge. (Error: {e})"
)
merged_sub_state_dict = torch.cat([t.cpu() for t in sub_state_dict])
gc.collect()
torch.cuda.empty_cache()
return merged_sub_state_dict
merged_sub_state_dict = torch.cat([t.cpu() for t in sub_state_dict])
return merged_sub_state_dict

return ShardedTensorFactory(
original_sh_ten.key,
Expand Down