Race Condition in Checkpoint Saving When Using Multi-Node Training with Shared NFS #894

@JiaoShuai

Description

Environment:

  • Training setup: 2 nodes, 8 A100 GPUs per node
  • File system: shared NFS across nodes, used for checkpointing

Problem Description

NH-DC-NM129-B01-20U-GPU-32:1134869:1185464 [4] NCCL INFO NCCL_SHM_DISABLE set by environment to 0.
NH-DC-NM129-B01-20U-GPU-32:1134868:1185467 [3] NCCL INFO NCCL_SHM_DISABLE set by environment to 0.
NH-DC-NM129-B01-20U-GPU-32:1134871:1185466 [5] NCCL INFO NCCL_SHM_DISABLE set by environment to 0.
NH-DC-NM129-B01-20U-GPU-32:1134866:1185468 [2] NCCL INFO NCCL_SHM_DISABLE set by environment to 0.
NH-DC-NM129-B01-20U-GPU-32:1134874:1185469 [7] NCCL INFO NCCL_SHM_DISABLE set by environment to 0.
NH-DC-NM129-B01-20U-GPU-32:1134869:1185464 [4] NCCL INFO Channel 01/1 : 12[4] -> 0[0] [send] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134865:1185465 [1] NCCL INFO Channel 01/1 : 9[1] -> 0[0] [send] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134868:1185467 [3] NCCL INFO Channel 01/1 : 11[3] -> 0[0] [send] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134871:1185466 [5] NCCL INFO Channel 01/1 : 13[5] -> 0[0] [send] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134872:1185471 [6] NCCL INFO NCCL_SHM_DISABLE set by environment to 0.
NH-DC-NM129-B01-20U-GPU-32:1134866:1185468 [2] NCCL INFO Channel 01/1 : 10[2] -> 0[0] [send] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134864:1185470 [0] NCCL INFO Channel 01/1 : 8[0] -> 0[0] [send] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134874:1185469 [7] NCCL INFO Channel 01/1 : 15[7] -> 0[0] [send] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134872:1185471 [6] NCCL INFO Channel 01/1 : 14[6] -> 0[0] [send] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134872:1185634 [6] NCCL INFO Channel 01/1 : 0[0] -> 14[6] [receive] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134866:1185632 [2] NCCL INFO Channel 01/1 : 0[0] -> 10[2] [receive] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134868:1185633 [3] NCCL INFO Channel 01/1 : 0[0] -> 11[3] [receive] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134869:1185631 [4] NCCL INFO Channel 01/1 : 0[0] -> 12[4] [receive] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134871:1185635 [5] NCCL INFO Channel 01/1 : 0[0] -> 13[5] [receive] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134874:1185638 [7] NCCL INFO Channel 01/1 : 0[0] -> 15[7] [receive] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134865:1185636 [1] NCCL INFO Channel 01/1 : 0[0] -> 9[1] [receive] via NET/Socket/0/Shared
NH-DC-NM129-B01-20U-GPU-32:1134864:1185637 [0] NCCL INFO Channel 01/1 : 0[0] -> 8[0] [receive] via NET/Socket/0/Shared
2025-11-01 07:52:20.638 NH-DC-NM129-B01-20U-GPU-32:0    olmo.util:168   CRITICAL        Uncaught FileNotFoundError: [Errno 2] No such file or directory: 'save/olmo_output/step0-tmp' -> 'save/olmo_output/step0'
Traceback (most recent call last):
  File "/mnt/si001960aoia/default/Workspace/jiaoshuai/olmorun_1b_qwen3d_fineweb_edu_mix_pure_2nodes_a100/OLMo/scripts/train.py", line 436, in <module>
    main(cfg)
  File "/mnt/si001960aoia/default/Workspace/jiaoshuai/olmorun_1b_qwen3d_fineweb_edu_mix_pure_2nodes_a100/OLMo/scripts/train.py", line 336, in main
    checkpoint_path, local_checkpoint_cache = trainer.save_checkpoint(checkpoint_type=checkpoint_type)
                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 615, in save_checkpoint
    result = self.save_sharded_checkpoint()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 523, in save_sharded_checkpoint
    result = self._save_checkpoint(checkpointer, CheckpointType.sharded)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/OLMo/olmo/train.py", line 483, in _save_checkpoint
    checkpointer.save_checkpoint(
  File "/workspace/OLMo/olmo/checkpoint.py", line 1923, in save_checkpoint
    with self._temporary_wd(dir) as checkpoint_dir:
  File "/opt/conda/lib/python3.11/contextlib.py", line 144, in __exit__
    next(self.gen)
  File "/workspace/OLMo/olmo/checkpoint.py", line 589, in _temporary_wd
    checkpoint_dir_tmp.replace(checkpoint_dir)
  File "/opt/conda/lib/python3.11/pathlib.py", line 1188, in replace
    os.replace(self, target)
FileNotFoundError: [Errno 2] No such file or directory: 'save/olmo_output/step0-tmp' -> 'save/olmo_output/step0'
[rank8]:[W1101 07:52:22.399796673 ProcessGroupNCCL.cpp:1496] Warning: WARNING: destroy_process_group() was not called before program exit, which can leak resources. For more info, please see https://pytorch.org/docs/stable/distributed.html#shutdown (function operator())
NH-DC-NM129-B01-20U-GPU-32:1134864:1142022 [0] NCCL INFO [Service thread] Connection closed by localRank 0
NH-DC-NM129-B01-20U-GPU-32:1134864:1192435 [0] NCCL INFO comm 0x561f604c13d0 rank 8 nranks 16 cudaDev 0 busId 10000 - Abort COMPLETE
W1101 07:52:23.154000 1134339 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1134865 closing signal SIGTERM
W1101 07:52:23.155000 1134339 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1134866 closing signal SIGTERM
W1101 07:52:23.155000 1134339 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1134868 closing signal SIGTERM
W1101 07:52:23.156000 1134339 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1134869 closing signal SIGTERM
W1101 07:52:23.156000 1134339 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1134871 closing signal SIGTERM
W1101 07:52:23.156000 1134339 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1134872 closing signal SIGTERM
W1101 07:52:23.157000 1134339 site-packages/torch/distributed/elastic/multiprocessing/api.py:897] Sending process 1134874 closing signal SIGTERM
/opt/conda/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
/opt/conda/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
/opt/conda/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
/opt/conda/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
/opt/conda/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
/opt/conda/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
/opt/conda/lib/python3.11/multiprocessing/resource_tracker.py:254: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
E1101 07:52:24.451000 1134339 site-packages/torch/distributed/elastic/multiprocessing/api.py:869] failed (exitcode: 1) local_rank: 0 (pid: 1134864) of binary: /opt/conda/bin/python3.11
Traceback (most recent call last):
  File "/opt/conda/bin/torchrun", line 8, in <module>
    sys.exit(main())
             ^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/run.py", line 918, in main
    run(args)
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/run.py", line 909, in run
    elastic_launch(
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 138, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/distributed/launcher/api.py", line 269, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
scripts/train.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2025-11-01_07:52:23
  host      : NH-DC-NM129-B01-20U-GPU-32.inspur.com
  rank      : 8 (local_rank: 0)
  exitcode  : 1 (pid: 1134864)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

When training OLMo on multiple nodes with a shared NFS directory for saving checkpoints, I encountered a FileExistsError or FileNotFoundError during the checkpoint saving phase, even when save_overwrite=True.

The error typically appears on nodes other than node 0 and looks like one of:

FileExistsError: [Errno 17] File exists: '/path/to/checkpoints/step-0'

or

FileNotFoundError: [Errno 2] No such file or directory: '/path/to/checkpoints/step-0-tmp' -> '/path/to/checkpoints/step-0'

This happens because multiple local_rank=0 processes (one per node) are attempting to manage the checkpoint directory, leading to race conditions.
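
As a concrete rank layout for this 2-node, 8-GPU-per-node job (an illustration using the standard torchrun environment variables, not OLMo's helpers), note that a node-local rank-0 guard passes on two different global ranks:

import os

# torchrun exports, per worker process:
#   node 0: RANK = 0..7,  LOCAL_RANK = 0..7
#   node 1: RANK = 8..15, LOCAL_RANK = 0..7
# A guard on the node-local rank therefore passes on global ranks 0 AND 8,
# which matches the failing process in the traceback above (rank 8, local_rank 0).
global_rank = int(os.environ.get("RANK", "0"))
local_rank = int(os.environ.get("LOCAL_RANK", "0"))
print(f"global rank {global_rank}, node-local rank {local_rank}")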


The issue lies in the _temporary_wd context manager in the checkpointing logic, specifically in how it determines which process should perform directory operations.

Relevant file:
olmo/checkpoint.py (per the traceback above)

Relevant lines:

if get_fs_local_rank() == 0:
    shutil.rmtree(checkpoint_dir, ignore_errors=True)

and

if get_fs_local_rank() == 0:
    checkpoint_dir_tmp.mkdir(...)

and

if get_fs_local_rank() == 0:
    checkpoint_dir_tmp.replace(checkpoint_dir)

Problem: get_fs_local_rank() returns 0 for the first process on each node. So in a 2-node setup, both node0:rank0 and node1:rank0 will try to:

  • Delete the checkpoint dir
  • Create the temp dir
  • Replace the temp dir

This leads to race conditions, especially over NFS where directory visibility is eventually consistent.
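
To make the interleaving concrete, here is a minimal, self-contained reproduction sketch (plain Python multiprocessing, not OLMo code; all names are illustrative) in which two processes stand in for the two node-local rank-0 writers sharing one checkpoint path. Whichever process calls os.replace() second fails with the same FileNotFoundError seen in the traceback:

import multiprocessing as mp
import os
import shutil
import tempfile


def fake_node_rank0(base, barrier):
    # Both "node-local rank 0" processes operate on the same shared paths.
    tmp = os.path.join(base, "step0-tmp")
    final = os.path.join(base, "step0")
    shutil.rmtree(final, ignore_errors=True)  # both clean up any old checkpoint
    os.makedirs(tmp, exist_ok=True)           # both (re)create the temp dir
    barrier.wait()                            # line the processes up to maximize overlap
    try:
        os.replace(tmp, final)                # only the first rename can succeed
    except FileNotFoundError as exc:
        print(f"pid {os.getpid()} lost the race: {exc}")


if __name__ == "__main__":
    base = tempfile.mkdtemp()
    barrier = mp.Barrier(2)
    workers = [mp.Process(target=fake_node_rank0, args=(base, barrier)) for _ in range(2)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()

On a real cluster the two writers additionally observe each other's changes with NFS lag, which is what also makes the FileExistsError variant possible.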


Why Is This Intermittent?

The issue occurs non-deterministically because it depends on the timing of file-system visibility over NFS.

When one node (e.g., node 0) creates or deletes a directory, it may take a short but variable amount of time for that change to propagate to the other nodes via NFS. During this window:

  • another node may still see the old directory as existing → FileExistsError
  • or may not yet see the newly created directory → FileNotFoundError during replace()

This explains why the problem:

  • only happens sometimes
  • is more likely under high I/O load or network latency
  • cannot be easily reproduced in single-node setups

The existing wait_for(lambda: ..., timeout=10.0) calls already acknowledge this eventual consistency, but they cannot fully prevent the race when multiple nodes are making conflicting changes.
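
For reference, a wait_for-style poll is just a bounded retry loop; a minimal sketch (the name, signature, and defaults below are assumptions, not OLMo's exact helper) looks like this:

import time


def wait_for(condition, timeout: float = 10.0, interval: float = 0.5) -> None:
    # Poll `condition` until it returns True or `timeout` seconds have elapsed.
    deadline = time.monotonic() + timeout
    while not condition():
        if time.monotonic() >= deadline:
            raise TimeoutError("condition was not met before the timeout")
        time.sleep(interval)

Such a poll can wait out NFS propagation for a single writer, but it cannot arbitrate between two writers that both believe they own the directory.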

Would the following modification handle the problem?

Replace all instances of get_fs_local_rank() with get_global_rank() in the _temporary_wd context manager, so that only the global rank 0 process performs these critical file operations.

Patch:

- if get_fs_local_rank() == 0:
+ if get_global_rank() == 0:

Applied to three locations in _temporary_wd:

  1. Initial cleanup of the checkpoint directory

  2. Creation of the temporary directory

  3. Final replace() operation

This would ensure that only one process in the entire training job performs directory management, eliminating the race (see the sketch below). The existing barrier() calls already make all processes wait appropriately before and after these operations, so correctness should be preserved, and the wait_for(...) logic for NFS visibility remains necessary and unchanged.
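
The sketch below shows the guarded pattern for illustration only; it is not OLMo's actual _temporary_wd, the torch.distributed-based helpers stand in for OLMo's get_global_rank() and barrier(), and error handling is omitted:

import shutil
from contextlib import contextmanager
from pathlib import Path

import torch.distributed as dist


def get_global_rank() -> int:
    # Stand-in for OLMo's helper: the rank across the whole job, not per node.
    return dist.get_rank() if dist.is_initialized() else 0


def barrier() -> None:
    if dist.is_initialized():
        dist.barrier()


@contextmanager
def temporary_wd(checkpoint_dir: Path):
    checkpoint_dir_tmp = checkpoint_dir.with_name(checkpoint_dir.name + "-tmp")
    if get_global_rank() == 0:
        # Exactly one process in the entire job prepares the directories.
        shutil.rmtree(checkpoint_dir, ignore_errors=True)
        shutil.rmtree(checkpoint_dir_tmp, ignore_errors=True)
        checkpoint_dir_tmp.mkdir(parents=True, exist_ok=True)
    barrier()  # all ranks wait until the temp dir exists
    yield checkpoint_dir_tmp
    barrier()  # all ranks have finished writing their shards
    if get_global_rank() == 0:
        checkpoint_dir_tmp.replace(checkpoint_dir)  # single atomic rename
    barrier()

Non-zero ranks still need to see checkpoint_dir_tmp before writing their shards, which is why the wait_for(...) visibility checks stay in place.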


Additional Notes

  • This issue only manifests in multi-node setups with shared storage.
  • Single-node training is unaffected.
  • The current code comment about "another (file-system) local rank 0" already hints at this problem.
