diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py index d0f2945c684..e323e46ec15 100644 --- a/megatron/training/checkpointing.py +++ b/megatron/training/checkpointing.py @@ -788,9 +788,11 @@ def onelogger_finalize_fn(): else: onelogger_finalize_fn() - # Additional callback for wandb (last rank) + # Additional callback for wandb + # The wandb artifact requires the tracker file to be present, so we need to ensure + # that rank 0 has already saved it before proceeding with wandb operations if not torch.distributed.is_initialized() \ - or is_last_rank(): + or torch.distributed.get_rank() == 0: def wandb_finalize_fn(): wandb_utils.on_save_checkpoint_success(checkpoint_name, get_checkpoint_tracker_filename(save_dir), save_dir, iteration) if args.async_save: