From c03d390c71a9cdf7b378beea655032a90e7d75a1 Mon Sep 17 00:00:00 2001 From: Abhijit Paithankar Date: Tue, 23 Dec 2025 15:18:25 -0800 Subject: [PATCH 1/2] Add retry logic with configurable delays to checkpoint write operations Add tests for testing ckpt save retry --- .../strategies/filesystem_async.py | 98 ++++--- .../dist_checkpointing/test_async_save.py | 246 ++++++++++++++++++ 2 files changed, 310 insertions(+), 34 deletions(-) diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py index b23c4e9893d..991ded06748 100644 --- a/megatron/core/dist_checkpointing/strategies/filesystem_async.py +++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py @@ -8,6 +8,7 @@ import os import pickle import queue +import time as time_module from functools import partial from heapq import heappop, heappush from itertools import chain @@ -366,45 +367,74 @@ def write_preloaded_data( mem_before = _process_memory() use_msc = kwargs.get("use_msc", False) + # Retry configuration + max_retries = kwargs.get("max_ckpt_save_retries", 3) + retry_delay = kwargs.get("ckpt_save_retry_delay", 10.0) # seconds + local_results = [] - try: - file_name, storage_key, (bytes_data, tensor_data) = write_bucket - extra_kwargs = {} - if "serialization_format" in inspect.signature(_write_item).parameters: - from torch.distributed.checkpoint.filesystem import SerializationFormat - - extra_kwargs["serialization_format"] = SerializationFormat.TORCH_SAVE - if use_msc: - import multistorageclient as msc - - open_file = msc.open - else: - open_file = open - with open_file(file_name, "wb") as stream: - for write_item, data in bytes_data: - local_results.append( - _write_item( - *transform_list, stream, data, write_item, storage_key, **extra_kwargs + local_output = None + + for attempt in range(max_retries): + try: + file_name, storage_key, (bytes_data, tensor_data) = write_bucket + extra_kwargs = {} + if 
"serialization_format" in inspect.signature(_write_item).parameters: + from torch.distributed.checkpoint.filesystem import SerializationFormat + + extra_kwargs["serialization_format"] = SerializationFormat.TORCH_SAVE + if use_msc: + import multistorageclient as msc + + open_file = msc.open + else: + open_file = open + + # Reset results for each retry attempt + local_results = [] + + with open_file(file_name, "wb") as stream: + for write_item, data in bytes_data: + local_results.append( + _write_item( + *transform_list, stream, data, write_item, storage_key, **extra_kwargs + ) ) - ) - for write_item, tensor in tensor_data: - assert tensor.is_cpu - local_results.append( - _write_item( - *transform_list, stream, tensor, write_item, storage_key, **extra_kwargs + for write_item, tensor in tensor_data: + assert tensor.is_cpu + local_results.append( + _write_item( + *transform_list, stream, tensor, write_item, storage_key, **extra_kwargs + ) ) - ) - if use_fsync: - if use_msc: - stream.fsync() - else: - os.fsync(stream.fileno()) - local_output = (local_proc_idx, local_results) - except Exception as e: - logger.debug(f"{local_proc_idx} failed") - local_output = (local_proc_idx, e) # type: ignore[assignment] + if use_fsync: + if use_msc: + stream.fsync() + else: + os.fsync(stream.fileno()) + + local_output = (local_proc_idx, local_results) + logger.debug(f"{local_proc_idx} completed successfully on attempt {attempt + 1}") + break # Success, exit retry loop + + except Exception as e: + is_last_attempt = (attempt == max_retries - 1) + + if is_last_attempt: + logger.error( + f"{local_proc_idx} failed after {max_retries} attempts. " + f"Last error: {type(e).__name__}: {str(e)}" + ) + local_output = (local_proc_idx, e) # type: ignore[assignment] + else: + logger.warning( + f"{local_proc_idx} failed on attempt {attempt + 1}/{max_retries} " + f"with error: {type(e).__name__}: {str(e)}. " + f"Retrying in {retry_delay:.2f} seconds..." 
+ ) + # TODO: Use exponential backoff for retry delay + time_module.sleep(retry_delay) results_queue.put(local_output) # Signal this process is done. diff --git a/tests/unit_tests/dist_checkpointing/test_async_save.py b/tests/unit_tests/dist_checkpointing/test_async_save.py index 523342883b3..5ef9cf45328 100644 --- a/tests/unit_tests/dist_checkpointing/test_async_save.py +++ b/tests/unit_tests/dist_checkpointing/test_async_save.py @@ -1,8 +1,11 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import logging +import time from unittest import mock import pytest import torch +from torch import multiprocessing as mp from torch.distributed.checkpoint import CheckpointException from megatron.core.dist_checkpointing import ShardedTensor, load, save @@ -107,3 +110,246 @@ def test_errors_are_reported(self, tmp_path_dist_ckpt, async_save, worker_fn): FileSystemWriterAsync.write_preloaded_data = orig_fn Utils.destroy_model_parallel() + + +class TestRetryLogic: + + def test_retry_on_transient_error_with_eventual_success(self, tmp_path_dist_ckpt, caplog): + # Track call attempts + call_count = {'count': 0} + + def mock_write_item_fail_twice(*args, **kwargs): + # Raise exception two times, then succeed + call_count['count'] += 1 + if call_count['count'] <= 2: + raise OSError(f"Transient error on attempt {call_count['count']}") + # Return None since the result just needs to be picklable for multiprocessing + return None + + # Create mock queues + results_queue = mp.SimpleQueue() + count_queue = mp.JoinableQueue() + count_queue.put(None) # Add item for worker to consume + + write_bucket = ( + tmp_path_dist_ckpt / 'test_file.pt', + 'storage_key', + ([(mock.MagicMock(), b'test')], []), # (bytes_data, tensor_data) + ) + + with caplog.at_level(logging.WARNING): + with mock.patch( + 'megatron.core.dist_checkpointing.strategies.filesystem_async._write_item', + side_effect=mock_write_item_fail_twice, + ): + FileSystemWriterAsync.write_preloaded_data( + 
transform_list=[], + local_proc_idx=0, + write_bucket=write_bucket, + results_queue=results_queue, + count_queue=count_queue, + use_fsync=False, + max_ckpt_save_retries=3, + ckpt_save_retry_delay=0.1, + ) + + assert call_count['count'] == 3, "Should have attempted 3 times (2 failures + 1 success)" + + result = results_queue.get() + proc_idx, results = result + assert proc_idx == 0 + assert not isinstance(results, Exception), "Should have succeeded after retries" + + warning_logs = [record for record in caplog.records if record.levelname == 'WARNING'] + assert len(warning_logs) == 2, "Should have logged 2 warnings for failed attempts" + assert 'failed on attempt 1/3' in warning_logs[0].message + assert 'Retrying in 0.10 seconds' in warning_logs[0].message + assert 'OSError: Transient error on attempt 1' in warning_logs[0].message + + def test_retry_exhaustion_logs_error(self, tmp_path_dist_ckpt, caplog): + + def mock_write_item_always_fail(*args, **kwargs): + # Always fail + raise ConnectionError("Persistent connection error") + + results_queue = mp.SimpleQueue() + count_queue = mp.JoinableQueue() + count_queue.put(None) + + write_bucket = ( + tmp_path_dist_ckpt / 'test_file.pt', + 'storage_key', + ([(mock.MagicMock(), b'test')], []), + ) + + with caplog.at_level(logging.WARNING): + with mock.patch( + 'megatron.core.dist_checkpointing.strategies.filesystem_async._write_item', + side_effect=mock_write_item_always_fail, + ): + FileSystemWriterAsync.write_preloaded_data( + transform_list=[], + local_proc_idx=1, + write_bucket=write_bucket, + results_queue=results_queue, + count_queue=count_queue, + use_fsync=False, + max_ckpt_save_retries=3, + ckpt_save_retry_delay=0.05, + ) + + result = results_queue.get() + proc_idx, exception = result + assert proc_idx == 1 + assert isinstance(exception, ConnectionError) + + error_logs = [record for record in caplog.records if record.levelname == 'ERROR'] + assert len(error_logs) == 1, "Should have logged 1 error for final failure" 
+ assert 'failed after 3 attempts' in error_logs[0].message + assert 'ConnectionError: Persistent connection error' in error_logs[0].message + + warning_logs = [record for record in caplog.records if record.levelname == 'WARNING'] + assert len(warning_logs) == 2, "Should have logged 2 warnings before final failure" + + def test_verify_retry_delay(self, tmp_path_dist_ckpt): + call_times = [] + + def mock_write_item_track_time(*args, **kwargs): + call_times.append(time.time()) + if len(call_times) <= 2: + raise TimeoutError(f"Timeout on attempt {len(call_times)}") + return None + + results_queue = mp.SimpleQueue() + count_queue = mp.JoinableQueue() + count_queue.put(None) + + write_bucket = ( + tmp_path_dist_ckpt / 'test_file.pt', + 'storage_key', + ([(mock.MagicMock(), b'test')], []), + ) + + retry_delay = 0.2 + + with mock.patch( + 'megatron.core.dist_checkpointing.strategies.filesystem_async._write_item', + side_effect=mock_write_item_track_time, + ): + FileSystemWriterAsync.write_preloaded_data( + transform_list=[], + local_proc_idx=2, + write_bucket=write_bucket, + results_queue=results_queue, + count_queue=count_queue, + use_fsync=False, + max_ckpt_save_retries=3, + ckpt_save_retry_delay=retry_delay, + ) + + assert len(call_times) == 3 + time_diff_1 = call_times[1] - call_times[0] + time_diff_2 = call_times[2] - call_times[1] + + assert ( + time_diff_1 >= retry_delay * 0.9 + ), f"First retry delay too short: {time_diff_1}s < {retry_delay}s" + assert ( + time_diff_2 >= retry_delay * 0.9 + ), f"Second retry delay too short: {time_diff_2}s < {retry_delay}s" + + def test_success_no_retry(self, tmp_path_dist_ckpt, caplog): + call_count = {'count': 0} + + def mock_write_item_succeed(*args, **kwargs): + """Mock that succeeds immediately""" + call_count['count'] += 1 + return None + + results_queue = mp.SimpleQueue() + count_queue = mp.JoinableQueue() + count_queue.put(None) + + write_bucket = ( + tmp_path_dist_ckpt / 'test_file.pt', + 'storage_key', + 
([(mock.MagicMock(), b'test')], []), + ) + + with caplog.at_level(logging.WARNING): + with mock.patch( + 'megatron.core.dist_checkpointing.strategies.filesystem_async._write_item', + side_effect=mock_write_item_succeed, + ): + FileSystemWriterAsync.write_preloaded_data( + transform_list=[], + local_proc_idx=3, + write_bucket=write_bucket, + results_queue=results_queue, + count_queue=count_queue, + use_fsync=False, + max_ckpt_save_retries=3, + ckpt_save_retry_delay=0.1, + ) + + # Verify only one attempt + assert call_count['count'] == 1, "Should have attempted only once on success" + + # Verify no warning or error logs + warning_logs = [record for record in caplog.records if record.levelname == 'WARNING'] + error_logs = [record for record in caplog.records if record.levelname == 'ERROR'] + assert len(warning_logs) == 0, "Should have no warnings on immediate success" + assert len(error_logs) == 0, "Should have no errors on immediate success" + + # Verify successful result + result = results_queue.get() + proc_idx, results = result + assert proc_idx == 3 + assert not isinstance(results, Exception) + + def test_local_results_reset_between_retries(self, tmp_path_dist_ckpt): + results_per_attempt = [] + + def mock_write_item_collect_results(*args, **kwargs): + # Use a simple dict instead of MagicMock since it needs to be picklable + result = {'value': f"attempt_{len(results_per_attempt) + 1}"} + + if len(results_per_attempt) == 0: + results_per_attempt.append([result]) + raise RuntimeError("First attempt failure") + else: + results_per_attempt.append([result]) + return result + + results_queue = mp.SimpleQueue() + count_queue = mp.JoinableQueue() + count_queue.put(None) + + write_bucket = ( + tmp_path_dist_ckpt / 'test_file.pt', + 'storage_key', + ([(mock.MagicMock(), b'test')], []), + ) + + with mock.patch( + 'megatron.core.dist_checkpointing.strategies.filesystem_async._write_item', + side_effect=mock_write_item_collect_results, + ): + 
FileSystemWriterAsync.write_preloaded_data( + transform_list=[], + local_proc_idx=4, + write_bucket=write_bucket, + results_queue=results_queue, + count_queue=count_queue, + use_fsync=False, + max_ckpt_save_retries=2, + ckpt_save_retry_delay=0.05, + ) + + result = results_queue.get() + proc_idx, final_results = result + assert proc_idx == 4 + assert len(final_results) == 1, "Should only have results from successful attempt" + assert ( + final_results[0]['value'] == "attempt_2" + ), "Should have results from second attempt only" From 4e2a97ba3113398c72adeda47a91f6632ea2556f Mon Sep 17 00:00:00 2001 From: Abhijit Paithankar Date: Wed, 24 Dec 2025 15:31:13 -0800 Subject: [PATCH 2/2] Refactor retry logic from file-level to item-level to avoid rewriting the entire file when a single item write fails --- .../strategies/filesystem_async.py | 172 ++++++--- .../dist_checkpointing/test_async_save.py | 347 +++++++++++++----- 2 files changed, 366 insertions(+), 153 deletions(-) diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py index 991ded06748..be00ec4f44c 100644 --- a/megatron/core/dist_checkpointing/strategies/filesystem_async.py +++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py @@ -347,6 +347,8 @@ def write_preloaded_data( results_queue: mp.SimpleQueue, count_queue: mp.JoinableQueue, use_fsync: bool, + max_item_retries: int = 3, + item_retry_delay: float = 10.0, **kwargs, ) -> None: """ @@ -367,74 +369,142 @@ def write_preloaded_data( mem_before = _process_memory() use_msc = kwargs.get("use_msc", False) - # Retry configuration - max_retries = kwargs.get("max_ckpt_save_retries", 3) - retry_delay = kwargs.get("ckpt_save_retry_delay", 10.0) # seconds - local_results = [] local_output = None - for attempt in range(max_retries): - try: - file_name, storage_key, (bytes_data, tensor_data) = write_bucket - extra_kwargs = {} - if "serialization_format" in
inspect.signature(_write_item).parameters: - from torch.distributed.checkpoint.filesystem import SerializationFormat + def write_item_with_retry( + transform_list, + stream, + data, + write_item, + storage_key, + use_fsync, + use_msc, + max_item_retries, + item_retry_delay, + **extra_kwargs + ): + """ + Wraps _write_item with retry logic + + Args: + transform_list: List of storage writer transforms + stream: File stream to write to + data: Data to write (bytes or tensor) + write_item: WriteItem containing metadata + storage_key: Storage key for the item + use_fsync: Whether to call fsync after writing + use_msc: Whether using multistorageclient + max_item_retries: Maximum number of retry attempts for this item + item_retry_delay: Delay in seconds between retries + **extra_kwargs: Additional arguments for _write_item + + Returns: + WriteResult from _write_item + + Raises: + Exception: Re-raises the last exception if all retries fail + """ + last_exception = None + for attempt in range(max_item_retries): + try: + result = _write_item( + *transform_list, stream, data, write_item, storage_key, **extra_kwargs + ) - extra_kwargs["serialization_format"] = SerializationFormat.TORCH_SAVE - if use_msc: - import multistorageclient as msc + # Perform fsync if requested and write was successful + if use_fsync: + try: + if use_msc: + stream.fsync() + else: + os.fsync(stream.fileno()) + except Exception as fsync_err: + logger.warning( + f"fsync failed for item {write_item.index}: {type(fsync_err).__name__}: {str(fsync_err)}" + ) + # Continue despite fsync failure, but log it - open_file = msc.open - else: - open_file = open + return result - # Reset results for each retry attempt - local_results = [] + except Exception as e: + last_exception = e + is_last_attempt = (attempt == max_item_retries - 1) - with open_file(file_name, "wb") as stream: - for write_item, data in bytes_data: - local_results.append( - _write_item( - *transform_list, stream, data, write_item, storage_key, 
**extra_kwargs - ) + if is_last_attempt: + logger.error( + f"Failed to write item {write_item.index} after {max_item_retries} attempts. " + f"Last error: {type(e).__name__}: {str(e)}" + ) + raise + else: + logger.warning( + f"Write item {write_item.index} failed on attempt {attempt + 1}/{max_item_retries}. " + f"Error: {type(e).__name__}: {str(e)}. Retrying in {item_retry_delay}s..." ) + time_module.sleep(item_retry_delay) + + # Should not reach here, but just in case + if last_exception: + raise last_exception + + try: + file_name, storage_key, (bytes_data, tensor_data) = write_bucket + extra_kwargs = {} + if "serialization_format" in inspect.signature(_write_item).parameters: + from torch.distributed.checkpoint.filesystem import SerializationFormat + + extra_kwargs["serialization_format"] = SerializationFormat.TORCH_SAVE + if use_msc: + import multistorageclient as msc + + open_file = msc.open + else: + open_file = open + + # Start with an empty result list (retries now happen per item, not per file) + local_results = [] + + with open_file(file_name, "wb") as stream: + for write_item, data in bytes_data: + local_results.append( + write_item_with_retry( + transform_list, stream, data, write_item, storage_key, + use_fsync, use_msc, max_item_retries, item_retry_delay, **extra_kwargs + ) + ) - for write_item, tensor in tensor_data: - assert tensor.is_cpu - local_results.append( - _write_item( - *transform_list, stream, tensor, write_item, storage_key, **extra_kwargs - ) + for write_item, tensor in tensor_data: + assert tensor.is_cpu + local_results.append( + write_item_with_retry( + transform_list, stream, tensor, write_item, storage_key, + use_fsync, use_msc, max_item_retries, item_retry_delay, **extra_kwargs ) + ) - if use_fsync: + # Note: fsync is now handled inside write_item_with_retry for each item + # but we can still do a final fsync here if needed + if use_fsync: + try: if use_msc: stream.fsync() else: os.fsync(stream.fileno()) + except Exception as fsync_err: + logger.warning( + f"fsync failed
for file {file_name}: {type(fsync_err).__name__}: {str(fsync_err)}" + ) + # Continue despite fsync failure, but log it - local_output = (local_proc_idx, local_results) - logger.debug(f"{local_proc_idx} completed successfully on attempt {attempt + 1}") - break # Success, exit retry loop - - except Exception as e: - is_last_attempt = (attempt == max_retries - 1) + local_output = (local_proc_idx, local_results) + logger.debug(f"{local_proc_idx} completed successfully") - if is_last_attempt: - logger.error( - f"{local_proc_idx} failed after {max_retries} attempts. " - f"Last error: {type(e).__name__}: {str(e)}" - ) - local_output = (local_proc_idx, e) # type: ignore[assignment] - else: - logger.warning( - f"{local_proc_idx} failed on attempt {attempt + 1}/{max_retries} " - f"with error: {type(e).__name__}: {str(e)}. " - f"Retrying in {retry_delay:.2f} seconds..." - ) - # TODO: Use exponential backoff for retry delay - time_module.sleep(retry_delay) + except Exception as e: + logger.error( + f"{local_proc_idx} failed with {type(e).__name__}: {str(e)}" + ) + local_output = (local_proc_idx, e) # type: ignore[assignment] results_queue.put(local_output) # Signal this process is done. 
diff --git a/tests/unit_tests/dist_checkpointing/test_async_save.py b/tests/unit_tests/dist_checkpointing/test_async_save.py index 5ef9cf45328..09a7adb0481 100644 --- a/tests/unit_tests/dist_checkpointing/test_async_save.py +++ b/tests/unit_tests/dist_checkpointing/test_async_save.py @@ -112,29 +112,82 @@ def test_errors_are_reported(self, tmp_path_dist_ckpt, async_save, worker_fn): Utils.destroy_model_parallel() -class TestRetryLogic: +class TestWriteItemWithRetry: + """Tests for the write_item_with_retry local function""" - def test_retry_on_transient_error_with_eventual_success(self, tmp_path_dist_ckpt, caplog): - # Track call attempts + def test_write_item_retry_success_first_attempt(self, tmp_path_dist_ckpt, caplog): + """Test that write_item_with_retry succeeds on first attempt""" + call_count = {'count': 0} + + def mock_write_item_succeed(*args, **kwargs): + call_count['count'] += 1 + return None + + results_queue = mp.SimpleQueue() + count_queue = mp.JoinableQueue() + count_queue.put(None) + + # Create mock WriteItem with index attribute + mock_write_item_obj = mock.MagicMock() + mock_write_item_obj.index = 'test_key_0' + + write_bucket = ( + tmp_path_dist_ckpt / 'test_file.pt', + 'storage_key', + ([(mock_write_item_obj, b'test_data')], []), # (bytes_data, tensor_data) + ) + + with caplog.at_level(logging.WARNING): + with mock.patch( + 'megatron.core.dist_checkpointing.strategies.filesystem_async._write_item', + side_effect=mock_write_item_succeed, + ): + FileSystemWriterAsync.write_preloaded_data( + transform_list=[], + local_proc_idx=0, + write_bucket=write_bucket, + results_queue=results_queue, + count_queue=count_queue, + use_fsync=False, + max_item_retries=3, + item_retry_delay=1.0, + ) + + # Verify only one attempt per item + assert call_count['count'] == 1, "Should call _write_item once for success" + + # Verify no warnings + warning_logs = [record for record in caplog.records if record.levelname == 'WARNING'] + assert len(warning_logs) == 0, "Should 
have no warnings on immediate success" + + # Verify successful result + result = results_queue.get() + proc_idx, results = result + assert proc_idx == 0 + assert not isinstance(results, Exception) + assert len(results) == 1 + + def test_write_item_retry_transient_failure(self, tmp_path_dist_ckpt, caplog): + """Test that write_item_with_retry retries on transient failures""" call_count = {'count': 0} def mock_write_item_fail_twice(*args, **kwargs): - # Raise exception two times, then succeed call_count['count'] += 1 if call_count['count'] <= 2: - raise OSError(f"Transient error on attempt {call_count['count']}") - # Return None since the result just needs to be picklable for multiprocessing + raise IOError(f"Transient I/O error on attempt {call_count['count']}") return None - # Create mock queues results_queue = mp.SimpleQueue() count_queue = mp.JoinableQueue() - count_queue.put(None) # Add item for worker to consume + count_queue.put(None) + + mock_write_item_obj = mock.MagicMock() + mock_write_item_obj.index = 'test_key_1' write_bucket = ( tmp_path_dist_ckpt / 'test_file.pt', 'storage_key', - ([(mock.MagicMock(), b'test')], []), # (bytes_data, tensor_data) + ([(mock_write_item_obj, b'test_data')], []), ) with caplog.at_level(logging.WARNING): @@ -149,37 +202,42 @@ def mock_write_item_fail_twice(*args, **kwargs): results_queue=results_queue, count_queue=count_queue, use_fsync=False, - max_ckpt_save_retries=3, - ckpt_save_retry_delay=0.1, + max_item_retries=3, + item_retry_delay=1.0, ) - assert call_count['count'] == 3, "Should have attempted 3 times (2 failures + 1 success)" + # Verify 3 attempts (2 failures + 1 success) + assert call_count['count'] == 3, "Should retry twice and succeed on third attempt" + # Verify warning logs + warning_logs = [record for record in caplog.records if record.levelname == 'WARNING'] + assert len(warning_logs) == 2, "Should have 2 warnings for the 2 failed attempts" + assert 'Write item test_key_1 failed on attempt 1/3' in 
warning_logs[0].message + assert 'Retrying in 1.0s' in warning_logs[0].message + + # Verify successful result result = results_queue.get() proc_idx, results = result assert proc_idx == 0 - assert not isinstance(results, Exception), "Should have succeeded after retries" - - warning_logs = [record for record in caplog.records if record.levelname == 'WARNING'] - assert len(warning_logs) == 2, "Should have logged 2 warnings for failed attempts" - assert 'failed on attempt 1/3' in warning_logs[0].message - assert 'Retrying in 0.10 seconds' in warning_logs[0].message - assert 'OSError: Transient error on attempt 1' in warning_logs[0].message + assert not isinstance(results, Exception) - def test_retry_exhaustion_logs_error(self, tmp_path_dist_ckpt, caplog): + def test_write_item_retry_exhaustion(self, tmp_path_dist_ckpt, caplog): + """Test that write_item_with_retry fails after exhausting retries""" def mock_write_item_always_fail(*args, **kwargs): - # Always fail - raise ConnectionError("Persistent connection error") + raise PermissionError("Permission denied - persistent error") results_queue = mp.SimpleQueue() count_queue = mp.JoinableQueue() count_queue.put(None) + mock_write_item_obj = mock.MagicMock() + mock_write_item_obj.index = 'test_key_2' + write_bucket = ( tmp_path_dist_ckpt / 'test_file.pt', 'storage_key', - ([(mock.MagicMock(), b'test')], []), + ([(mock_write_item_obj, b'test_data')], []), ) with caplog.at_level(logging.WARNING): @@ -189,167 +247,252 @@ def mock_write_item_always_fail(*args, **kwargs): ): FileSystemWriterAsync.write_preloaded_data( transform_list=[], - local_proc_idx=1, + local_proc_idx=0, write_bucket=write_bucket, results_queue=results_queue, count_queue=count_queue, use_fsync=False, - max_ckpt_save_retries=3, - ckpt_save_retry_delay=0.05, + max_item_retries=3, + item_retry_delay=1.0, ) + # Verify error result result = results_queue.get() proc_idx, exception = result - assert proc_idx == 1 - assert isinstance(exception, ConnectionError) + 
assert proc_idx == 0 + assert isinstance(exception, PermissionError) + # Verify error log error_logs = [record for record in caplog.records if record.levelname == 'ERROR'] - assert len(error_logs) == 1, "Should have logged 1 error for final failure" - assert 'failed after 3 attempts' in error_logs[0].message - assert 'ConnectionError: Persistent connection error' in error_logs[0].message - + assert len(error_logs) >= 1, "Should have at least one error log" + # Find the specific error about write item failure + item_error_logs = [log for log in error_logs if 'Failed to write item test_key_2' in log.message] + assert len(item_error_logs) == 1, "Should have error log for write_item failure" + assert 'after 3 attempts' in item_error_logs[0].message + assert 'PermissionError' in item_error_logs[0].message + + # Verify warning logs (should have 2 warnings for non-final attempts) warning_logs = [record for record in caplog.records if record.levelname == 'WARNING'] - assert len(warning_logs) == 2, "Should have logged 2 warnings before final failure" + item_warning_logs = [log for log in warning_logs if 'Write item test_key_2' in log.message] + assert len(item_warning_logs) == 2, "Should have 2 warnings before final failure" - def test_verify_retry_delay(self, tmp_path_dist_ckpt): - call_times = [] + def test_write_item_retry_with_fsync(self, tmp_path_dist_ckpt, caplog): + """Test that write_item_with_retry calls fsync when use_fsync=True""" + fsync_call_count = {'count': 0} - def mock_write_item_track_time(*args, **kwargs): - call_times.append(time.time()) - if len(call_times) <= 2: - raise TimeoutError(f"Timeout on attempt {len(call_times)}") - return None + def mock_write_item_succeed(*args, **kwargs): + return + + def mock_fsync(fileno): + fsync_call_count['count'] += 1 results_queue = mp.SimpleQueue() count_queue = mp.JoinableQueue() count_queue.put(None) + mock_write_item_obj = mock.MagicMock() + mock_write_item_obj.index = 'test_key_3' + write_bucket = ( 
tmp_path_dist_ckpt / 'test_file.pt', 'storage_key', - ([(mock.MagicMock(), b'test')], []), + ([(mock_write_item_obj, b'test_data')], []), ) - retry_delay = 0.2 - with mock.patch( 'megatron.core.dist_checkpointing.strategies.filesystem_async._write_item', - side_effect=mock_write_item_track_time, + side_effect=mock_write_item_succeed, ): - FileSystemWriterAsync.write_preloaded_data( - transform_list=[], - local_proc_idx=2, - write_bucket=write_bucket, - results_queue=results_queue, - count_queue=count_queue, - use_fsync=False, - max_ckpt_save_retries=3, - ckpt_save_retry_delay=retry_delay, - ) + with mock.patch('os.fsync', side_effect=mock_fsync): + FileSystemWriterAsync.write_preloaded_data( + transform_list=[], + local_proc_idx=0, + write_bucket=write_bucket, + results_queue=results_queue, + count_queue=count_queue, + use_fsync=True, # Enable fsync + max_item_retries=3, + item_retry_delay=1.0, + ) - assert len(call_times) == 3 - time_diff_1 = call_times[1] - call_times[0] - time_diff_2 = call_times[2] - call_times[1] + # Verify fsync was called (once per item + once at end) + assert fsync_call_count['count'] == 2, "Should call fsync twice (per item + final)" - assert ( - time_diff_1 >= retry_delay * 0.9 - ), f"First retry delay too short: {time_diff_1}s < {retry_delay}s" - assert ( - time_diff_2 >= retry_delay * 0.9 - ), f"Second retry delay too short: {time_diff_2}s < {retry_delay}s" + # Verify successful result + result = results_queue.get() + proc_idx, results = result + assert proc_idx == 0 + assert not isinstance(results, Exception) - def test_success_no_retry(self, tmp_path_dist_ckpt, caplog): - call_count = {'count': 0} + def test_write_item_retry_fsync_failure_logged(self, tmp_path_dist_ckpt, caplog): + """Test that fsync failures are logged but don't fail the write""" def mock_write_item_succeed(*args, **kwargs): - """Mock that succeeds immediately""" - call_count['count'] += 1 return None + def mock_fsync_fail(fileno): + raise OSError("fsync failed - 
disk full") + results_queue = mp.SimpleQueue() count_queue = mp.JoinableQueue() count_queue.put(None) + mock_write_item_obj = mock.MagicMock() + mock_write_item_obj.index = 'test_key_4' + write_bucket = ( tmp_path_dist_ckpt / 'test_file.pt', 'storage_key', - ([(mock.MagicMock(), b'test')], []), + ([(mock_write_item_obj, b'test_data')], []), ) with caplog.at_level(logging.WARNING): with mock.patch( 'megatron.core.dist_checkpointing.strategies.filesystem_async._write_item', side_effect=mock_write_item_succeed, + ): + with mock.patch('os.fsync', side_effect=mock_fsync_fail): + FileSystemWriterAsync.write_preloaded_data( + transform_list=[], + local_proc_idx=0, + write_bucket=write_bucket, + results_queue=results_queue, + count_queue=count_queue, + use_fsync=True, + use_msc=False, + max_item_retries=3, + item_retry_delay=1.0, + ) + + # Verify write succeeded despite fsync failure + result = results_queue.get() + proc_idx, results = result + assert proc_idx == 0 + assert not isinstance(results, Exception), "Write should succeed despite fsync failure" + + # Verify fsync failure was logged + warning_logs = [record for record in caplog.records if record.levelname == 'WARNING'] + fsync_warnings = [log for log in warning_logs if 'fsync failed' in log.message] + assert len(fsync_warnings) >= 1, "Should have warning about fsync failure" + assert 'test_key_4' in fsync_warnings[0].message + assert 'OSError' in fsync_warnings[0].message + + def test_write_item_retry_multiple_items(self, tmp_path_dist_ckpt, caplog): + """Test that write_item_with_retry handles multiple items correctly""" + call_counts = {'item_0': 0, 'item_1': 0, 'item_2': 0} + + def mock_write_item_mixed_failures(*args, **kwargs): + # args[2] is write_item in the signature + write_item = args[2] + item_key = write_item.index + call_counts[item_key] += 1 + + # First item fails once, second succeeds, third fails twice + if item_key == 'item_0' and call_counts[item_key] <= 1: + raise ConnectionError(f"{item_key} 
transient error") + elif item_key == 'item_2' and call_counts[item_key] <= 2: + raise TimeoutError(f"{item_key} timeout") + + return None + + results_queue = mp.SimpleQueue() + count_queue = mp.JoinableQueue() + count_queue.put(None) + + # Create multiple write items + mock_items = [] + for i in range(3): + mock_item = mock.MagicMock() + mock_item.index = f'item_{i}' + mock_items.append((mock_item, f'data_{i}'.encode())) + + write_bucket = ( + tmp_path_dist_ckpt / 'test_file.pt', + 'storage_key', + (mock_items, []), # bytes_data with 3 items + ) + + with caplog.at_level(logging.WARNING): + with mock.patch( + 'megatron.core.dist_checkpointing.strategies.filesystem_async._write_item', + side_effect=mock_write_item_mixed_failures, ): FileSystemWriterAsync.write_preloaded_data( transform_list=[], - local_proc_idx=3, + local_proc_idx=0, write_bucket=write_bucket, results_queue=results_queue, count_queue=count_queue, use_fsync=False, - max_ckpt_save_retries=3, - ckpt_save_retry_delay=0.1, + max_item_retries=3, + item_retry_delay=1.0, ) - # Verify only one attempt - assert call_count['count'] == 1, "Should have attempted only once on success" - - # Verify no warning or error logs - warning_logs = [record for record in caplog.records if record.levelname == 'WARNING'] - error_logs = [record for record in caplog.records if record.levelname == 'ERROR'] - assert len(warning_logs) == 0, "Should have no warnings on immediate success" - assert len(error_logs) == 0, "Should have no errors on immediate success" + # Verify call counts: item_0 fails once (2 attempts), item_1 succeeds (1 attempt), item_2 fails twice (3 attempts) + assert call_counts['item_0'] == 2, "Item 0 should have 2 attempts" + assert call_counts['item_1'] == 1, "Item 1 should have 1 attempt" + assert call_counts['item_2'] == 3, "Item 2 should have 3 attempts" # Verify successful result result = results_queue.get() proc_idx, results = result - assert proc_idx == 3 + assert proc_idx == 0 assert not 
isinstance(results, Exception) + assert len(results) == 3, "Should have 3 results for 3 items" - def test_local_results_reset_between_retries(self, tmp_path_dist_ckpt): - results_per_attempt = [] + # Verify warning logs for failed attempts + warning_logs = [record for record in caplog.records if record.levelname == 'WARNING'] + item_0_warnings = [log for log in warning_logs if 'item_0' in log.message] + item_2_warnings = [log for log in warning_logs if 'item_2' in log.message] + assert len(item_0_warnings) == 1, "Item 0 should have 1 warning" + assert len(item_2_warnings) == 2, "Item 2 should have 2 warnings" - def mock_write_item_collect_results(*args, **kwargs): - # Use a simple dict instead of MagicMock since it needs to be picklable - result = {'value': f"attempt_{len(results_per_attempt) + 1}"} + def test_write_item_with_tensor_data(self, tmp_path_dist_ckpt): + """Test that write_item_with_retry works with tensor data""" + call_count = {'count': 0} - if len(results_per_attempt) == 0: - results_per_attempt.append([result]) - raise RuntimeError("First attempt failure") - else: - results_per_attempt.append([result]) - return result + def mock_write_item_succeed(*args, **kwargs): + call_count['count'] += 1 + return None results_queue = mp.SimpleQueue() count_queue = mp.JoinableQueue() count_queue.put(None) + mock_write_item_obj = mock.MagicMock() + mock_write_item_obj.index = 'test_tensor_0' + + # Create a CPU tensor + test_tensor = torch.ones(3, 4) + assert test_tensor.is_cpu, "Test tensor must be on CPU" + write_bucket = ( tmp_path_dist_ckpt / 'test_file.pt', 'storage_key', - ([(mock.MagicMock(), b'test')], []), + ([], [(mock_write_item_obj, test_tensor)]), # Empty bytes_data, one tensor ) with mock.patch( 'megatron.core.dist_checkpointing.strategies.filesystem_async._write_item', - side_effect=mock_write_item_collect_results, + side_effect=mock_write_item_succeed, ): FileSystemWriterAsync.write_preloaded_data( transform_list=[], - local_proc_idx=4, + 
local_proc_idx=0, write_bucket=write_bucket, results_queue=results_queue, count_queue=count_queue, use_fsync=False, - max_ckpt_save_retries=2, - ckpt_save_retry_delay=0.05, + max_item_retries=3, + item_retry_delay=1.0, ) + # Verify write was called + assert call_count['count'] == 1, "Should write tensor once" + + # Verify successful result result = results_queue.get() - proc_idx, final_results = result - assert proc_idx == 4 - assert len(final_results) == 1, "Should only have results from successful attempt" - assert ( - final_results[0]['value'] == "attempt_2" - ), "Should have results from second attempt only" + proc_idx, results = result + assert proc_idx == 0 + assert not isinstance(results, Exception) + assert len(results) == 1