diff --git a/megatron/core/optimizer/distrib_optimizer.py b/megatron/core/optimizer/distrib_optimizer.py
index 48a3ad47aa7..c6283db9ded 100644
--- a/megatron/core/optimizer/distrib_optimizer.py
+++ b/megatron/core/optimizer/distrib_optimizer.py
@@ -771,7 +771,7 @@ def make_needed_groups(param_group):
     # Allocate dummy tensors.
     numel = len(param_range_map["gbuf_world"])
     init_shard = lambda dtype=torch.float32: torch.empty(
-        (numel,), dtype=dtype, device=torch.cuda.current_device()
+        (numel,), dtype=dtype, device="cpu"
     )
 
     # For precision_aware_optimizer, the empty tensors should also be
diff --git a/megatron/core/transformer/mlp.py b/megatron/core/transformer/mlp.py
index 23cbc16b788..cb4526ea4cf 100644
--- a/megatron/core/transformer/mlp.py
+++ b/megatron/core/transformer/mlp.py
@@ -340,17 +340,8 @@ def sh_ten_build_fn(
 
         def sh_ten_merge_fn(sub_state_dict):
             with torch.no_grad():
-                try:
-                    return torch.cat(sub_state_dict)
-                except (RuntimeError, torch.cuda.OutOfMemoryError) as e:
-                    logger.warning(
-                        f"CUDA OutOfMemoryError encountered during tensors merging."
-                        f" Switching to CPU merge. (Error: {e})"
-                    )
-                    merged_sub_state_dict = torch.cat([t.cpu() for t in sub_state_dict])
-                    gc.collect()
-                    torch.cuda.empty_cache()
-                    return merged_sub_state_dict
+                merged_sub_state_dict = torch.cat([t.cpu() for t in sub_state_dict])
+                return merged_sub_state_dict
 
         return ShardedTensorFactory(
             original_sh_ten.key,
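
For reviewers, the net effect of both hunks is that checkpoint-related buffers stay off the GPU: the optimizer's dummy shards are allocated directly on the host, and sharded tensors are always concatenated on the CPU instead of attempting an on-device torch.cat and falling back on torch.cuda.OutOfMemoryError. Below is a minimal sketch of the new merge behavior, assuming a list of on-device shards; the helper name merge_on_cpu and the toy sizes are illustrative, not part of the patch.

```python
import torch

def merge_on_cpu(shards):
    """Concatenate tensor shards in host memory so the merged buffer
    never needs to coexist with the shards in GPU memory.
    (Illustrative helper, not from the patch.)"""
    with torch.no_grad():
        # .cpu() copies each shard to the host; torch.cat then
        # allocates the merged tensor on the CPU as well.
        return torch.cat([t.cpu() for t in shards])

# Illustrative usage with shards on whatever device is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
shards = [torch.full((4,), i, device=device) for i in range(3)]
merged = merge_on_cpu(shards)
assert merged.device.type == "cpu" and merged.numel() == 12
```

The trade-off is one host-to-device copy per shard in exchange for predictable memory behavior: the removed try/except path could still fragment GPU memory before the fallback fired, which is presumably why the unconditional CPU merge (and the now-unneeded gc.collect() / torch.cuda.empty_cache()) replaces it.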