From 543dbd0019502e970b3a37da8011bae5c9226f9a Mon Sep 17 00:00:00 2001
From: lukifer23 <56565060+lukifer23@users.noreply.github.com>
Date: Tue, 23 Sep 2025 08:38:26 -0500
Subject: [PATCH] Ensure training memory monitor cleans up callbacks

---
 azchess/training/train.py          |  42 +++++-
 azchess/utils/__init__.py          |   2 +
 azchess/utils/memory_monitor.py    |  36 +++--
 tests/test_train_memory_monitor.py | 212 +++++++++++++++++++++++++++++
 4 files changed, 277 insertions(+), 15 deletions(-)
 create mode 100644 tests/test_train_memory_monitor.py

diff --git a/azchess/training/train.py b/azchess/training/train.py
index f78358e..8e44502 100644
--- a/azchess/training/train.py
+++ b/azchess/training/train.py
@@ -35,8 +35,9 @@
 from azchess.training.npz_dataset import build_training_dataloader
 from azchess.utils import (add_memory_alert_callback, clear_memory_cache,
                            emergency_memory_cleanup, get_memory_usage,
-                           log_tensor_stats, safe_config_get,
-                           start_memory_monitoring)
+                           log_tensor_stats, remove_memory_alert_callback,
+                           safe_config_get, start_memory_monitoring,
+                           stop_memory_monitoring)
 
 # Setup logging
 logger = setup_logging(level=logging.INFO)
@@ -1076,15 +1077,25 @@ def train_comprehensive(
     memory_warning_threshold = safe_config_get(cfg, 'memory_warning_threshold', 0.85, section='training')  # 85% default
     memory_critical_threshold = safe_config_get(cfg, 'memory_critical_threshold', 0.95, section='training')  # 95% default
 
+    # Track monitor lifecycle for cleanup and fallback warnings
+    memory_monitor_started = False
+    training_callback_registered = False
+    training_memory_alert_callback = None
+    last_memory_warning = 0
+    memory_warning_cooldown = 300  # 5 minutes between warnings (fallback path)
+
     # Start comprehensive memory monitoring
     try:
-        start_memory_monitoring(
+        memory_monitor_started = start_memory_monitoring(
             device=device,
             warning_threshold=memory_warning_threshold,
             critical_threshold=memory_critical_threshold,
             check_interval=30.0  # Check every 30 seconds
         )
-        logger.info("Advanced memory monitoring system started")
+        if memory_monitor_started:
+            logger.info("Advanced memory monitoring system started")
+        else:
+            logger.info("Advanced memory monitoring system already running; reusing existing monitor")
 
         # Add custom alert callback for training-specific actions
         def training_memory_alert_callback(alert):
@@ -1094,12 +1105,14 @@ def training_memory_alert_callback(alert):
             logger.warning(f"HIGH MEMORY: Monitor training stability. Memory: {alert.memory_usage_gb:.2f}GB")
 
         add_memory_alert_callback(training_memory_alert_callback)
+        training_callback_registered = True
     except Exception as e:
         logger.warning(f"Could not start advanced memory monitoring: {e}")
         # Fallback to basic monitoring
-        last_memory_warning = 0
-        memory_warning_cooldown = 300  # 5 minutes between warnings
+        memory_monitor_started = False
+        training_memory_alert_callback = None
+        training_callback_registered = False
 
     def get_system_memory_usage():
         """Get current system memory usage in GB for heartbeat monitoring."""
@@ -1545,6 +1558,23 @@ def _finite(x):
     finally:
         pbar.close()
 
+    # Clean up memory monitoring callbacks and threads
+    if training_callback_registered and training_memory_alert_callback is not None:
+        try:
+            removed = remove_memory_alert_callback(training_memory_alert_callback)
+            if removed:
+                logger.debug("Removed training memory alert callback")
+        except Exception as callback_error:
+            logger.warning(f"Failed to remove training memory alert callback: {callback_error}")
+
+    if memory_monitor_started:
+        try:
+            stopped = stop_memory_monitoring()
+            if stopped:
+                logger.debug("Advanced memory monitoring system stopped")
+        except Exception as monitor_error:
+            logger.warning(f"Failed to stop memory monitoring: {monitor_error}")
+
     # Save final checkpoint with enhanced prefix, but handle KeyboardInterrupt gracefully
     if not locals().get('interrupted', False):
         try:
diff --git a/azchess/utils/__init__.py b/azchess/utils/__init__.py
index d6a16f5..0c669d1 100644
--- a/azchess/utils/__init__.py
+++ b/azchess/utils/__init__.py
@@ -31,6 +31,7 @@
                            emergency_memory_cleanup, get_memory_usage)
 from .memory_monitor import (MemoryAlert, MemoryMonitor,
                              add_memory_alert_callback, get_memory_stats,
+                             remove_memory_alert_callback,
                              start_memory_monitoring, stop_memory_monitoring)
 from .model_loader import load_model_and_mcts
 from .performance_utils import (PerformanceMetric, PerformanceMonitor,
@@ -63,6 +64,7 @@
     "stop_memory_monitoring",
     "get_memory_stats",
     "add_memory_alert_callback",
+    "remove_memory_alert_callback",
 
     # Device management
     "DeviceManager",
diff --git a/azchess/utils/memory_monitor.py b/azchess/utils/memory_monitor.py
index bc462c2..9fbc2da 100644
--- a/azchess/utils/memory_monitor.py
+++ b/azchess/utils/memory_monitor.py
@@ -68,8 +68,8 @@
-    def start_monitoring(self, device: str = 'auto') -> None:
+    def start_monitoring(self, device: str = 'auto') -> bool:
         """Start the memory monitoring thread."""
         if self.is_monitoring:
             logger.warning("Memory monitoring is already running")
-            return
+            return False
 
         self.is_monitoring = True
         self.stop_event.clear()
@@ -81,11 +81,12 @@
         )
         self.monitor_thread.start()
         logger.info(f"Memory monitoring started for device: {device}")
+        return True
 
-    def stop_monitoring(self) -> None:
+    def stop_monitoring(self) -> bool:
         """Stop the memory monitoring thread."""
-        if not self.is_monitoring:
-            return
+        if not self.is_monitoring and not (self.monitor_thread and self.monitor_thread.is_alive()):
+            return False
 
         self.is_monitoring = False
         self.stop_event.set()
@@ -93,7 +94,9 @@
         if self.monitor_thread and self.monitor_thread.is_alive():
             self.monitor_thread.join(timeout=5.0)
+            self.monitor_thread = None
 
         logger.info("Memory monitoring stopped")
+        return True
 
     def _monitor_loop(self, device: str) -> None:
         """Main monitoring loop."""
@@ -202,6 +205,16 @@ def add_alert_callback(self, callback: Callable[[MemoryAlert], None]) -> None:
         """Add a callback function to be called when alerts are generated."""
         self.alert_callbacks.append(callback)
 
+    def remove_alert_callback(self, callback: Callable[[MemoryAlert], None]) -> bool:
+        """Remove a previously registered alert callback."""
+        try:
+            self.alert_callbacks.remove(callback)
+            logger.debug("Removed memory alert callback: %s", getattr(callback, "__name__", repr(callback)))
+            return True
+        except ValueError:
+            logger.debug("Attempted to remove unregistered memory alert callback: %s", getattr(callback, "__name__", repr(callback)))
+            return False
+
     def get_memory_stats(self) -> Dict[str, Any]:
         """Get comprehensive memory statistics."""
         stats = {
@@ -235,19 +248,19 @@ def set_device_limit(self, device: str, limit_gb: float) -> None:
 memory_monitor = MemoryMonitor()
 
 
-def start_memory_monitoring(device: str = 'auto', **kwargs) -> None:
+def start_memory_monitoring(device: str = 'auto', **kwargs) -> bool:
     """Convenience function to start memory monitoring."""
     # Update monitor settings if provided
     for key, value in kwargs.items():
         if hasattr(memory_monitor, key):
             setattr(memory_monitor, key, value)
 
-    memory_monitor.start_monitoring(device)
+    return memory_monitor.start_monitoring(device)
 
 
-def stop_memory_monitoring() -> None:
+def stop_memory_monitoring() -> bool:
     """Convenience function to stop memory monitoring."""
-    memory_monitor.stop_monitoring()
+    return memory_monitor.stop_monitoring()
 
 
 def get_memory_stats() -> Dict[str, Any]:
@@ -258,3 +271,8 @@ def get_memory_stats() -> Dict[str, Any]:
 def add_memory_alert_callback(callback: Callable[[MemoryAlert], None]) -> None:
     """Convenience function to add alert callback."""
     memory_monitor.add_alert_callback(callback)
+
+
+def remove_memory_alert_callback(callback: Callable[[MemoryAlert], None]) -> bool:
+    """Convenience function to remove alert callback."""
+    return memory_monitor.remove_alert_callback(callback)
diff --git a/tests/test_train_memory_monitor.py b/tests/test_train_memory_monitor.py
new file mode 100644
index 0000000..e98b525
--- /dev/null
+++ b/tests/test_train_memory_monitor.py
@@ -0,0 +1,212 @@
+from __future__ import annotations
+
+import time
+from types import SimpleNamespace
+
+import torch
+
+from azchess.utils.memory_monitor import MemoryAlert, memory_monitor
+
+
+class DummyModel(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.linear = torch.nn.Linear(1, 1)
+
+    def forward(self, x):  # pragma: no cover - not used in this test
+        return self.linear(x)
+
+    def enable_gradient_checkpointing(self, strategy: str = "adaptive") -> None:
+        self._checkpoint_strategy = strategy
+
+    def get_memory_usage(self):
+        return {"parameters_gb": 0.0}
+
+    def enable_memory_optimization(self) -> None:
+        self._memory_optimization_enabled = True
+
+
+class DummyPolicyValueNet:
+    @classmethod
+    def from_config(cls, _cfg):
+        return DummyModel()
+
+
+class DummyDataManager:
+    def __init__(self, *_, **__):
+        self._stats = SimpleNamespace(total_samples=1, total_shards=1)
+
+    def get_stats(self):
+        return self._stats
+
+    def get_external_data_stats(self):
+        return {"external_total": 0, "tactical_samples": 0, "openings_samples": 0}
+
+    def get_curriculum_batch(self, *_args, **_kwargs):
+        return None
+
+    def get_training_batch(self, *_args, **_kwargs):
+        while False:
+            yield None
+
+
+class DummySummaryWriter:
+    def __init__(self, *_args, **_kwargs) -> None:
+        self.scalars = []
+
+    def add_scalar(self, *_args, **_kwargs) -> None:
+        pass
+
+    def close(self) -> None:
+        pass
+
+
+class DummyConfig:
+    def __init__(self, log_dir: str) -> None:
+        self._root = {}
+        self._training = {
"gradient_accumulation_steps": None, + "compile": False, + "compile_mode": "default", + "steps_per_epoch": 1, + "use_curriculum": False, + "curriculum_phases": [], + "dataloader_workers": 0, + "prefetch_factor": 0, + "memory_limit_gb": 16, + "memory_warning_threshold": 0.85, + "memory_critical_threshold": 0.95, + "ssl_weight": 0.0, + "policy_label_smoothing": 0.0, + "value_loss": "mse", + "huber_delta": 1.0, + "policy_masking": False, + "ssl_warmup_steps": 0, + "ssl_target_weight": 1.0, + "ssl_targets_provider": "auto", + "wdl_weight": 0.0, + "wdl_margin": 0.25, + "precision": "fp32", + "ssl_every_n": 1, + "ssl_chunk_size": 0, + "log_dir": log_dir, + } + self._model = {"self_supervised": False, "wdl": False, "ssl_curriculum": False} + + def training(self): + return self._training + + def model(self): + return self._model + + def get(self, key, default=None): + return self._root.get(key, default) + + +def test_train_comprehensive_memory_monitor_cleanup(monkeypatch, tmp_path): + from azchess.training import train + + # Ensure monitor is not running before the test + memory_monitor.stop_monitoring() + memory_monitor.alert_callbacks.clear() + + dummy_log_dir = tmp_path / "logs" + dummy_ckpt_dir = tmp_path / "ckpts" + + # Patch heavy dependencies with lightweight stand-ins + monkeypatch.setattr(train, "PolicyValueNet", DummyPolicyValueNet) + monkeypatch.setattr(train, "DataManager", DummyDataManager) + monkeypatch.setattr(train, "build_training_dataloader", lambda *a, **k: None) + monkeypatch.setattr(train, "SummaryWriter", DummySummaryWriter) + monkeypatch.setattr(train, "clear_memory_cache", lambda *a, **k: None) + monkeypatch.setattr(train, "get_memory_usage", lambda *a, **k: {"memory_gb": 1.0}) + monkeypatch.setattr(train, "emergency_memory_cleanup", lambda *a, **k: None) + monkeypatch.setattr(train, "save_checkpoint", lambda *a, **k: None) + + def fake_config_load(_path): + return DummyConfig(str(dummy_log_dir)) + + monkeypatch.setattr(train.Config, "load", staticmethod(fake_config_load)) + + warning_messages: list[str] = [] + critical_messages: list[str] = [] + + def capture_warning(msg, *args, **kwargs): + text = msg % args if args else msg + warning_messages.append(text) + + def capture_critical(msg, *args, **kwargs): + text = msg % args if args else msg + critical_messages.append(text) + + monkeypatch.setattr(train.logger, "warning", capture_warning) + monkeypatch.setattr(train.logger, "critical", capture_critical) + + callback_lengths: list[int] = [] + alert_fire_counts: list[int] = [] + + def tracked_add_callback(callback): + pre_len = len(memory_monitor.alert_callbacks) + callback_lengths.append(pre_len) + memory_monitor.add_alert_callback(callback) + post_len = len(memory_monitor.alert_callbacks) + callback_lengths.append(post_len) + + before_warning = len(warning_messages) + alert = MemoryAlert( + alert_type="warning", + message="unit-test alert", + memory_usage_gb=1.0, + memory_limit_gb=2.0, + timestamp=time.time(), + device="cpu", + ) + memory_monitor._send_alert(alert) + high_memory_logs = [ + msg for msg in warning_messages[before_warning:] if msg.startswith("HIGH MEMORY:") + ] + alert_fire_counts.append(len(high_memory_logs)) + + monkeypatch.setattr(train, "add_memory_alert_callback", tracked_add_callback) + + run_kwargs = dict( + config_path="dummy", + total_steps=0, + batch_size=1, + learning_rate=0.001, + weight_decay=0.0, + ema_decay=0.0, + grad_clip_norm=1.0, + accum_steps=1, + warmup_steps=0, + checkpoint_dir=str(dummy_ckpt_dir), + log_dir=str(dummy_log_dir), + 
device="cpu", + use_amp=False, + augment=False, + precision="fp32", + epochs=0, + steps_per_epoch=0, + init_checkpoint=None, + resume=False, + data_mode=None, + dataloader_workers=0, + prefetch_factor=0, + ) + + # Run twice to ensure cleanup between runs + train.train_comprehensive(**run_kwargs) + assert not memory_monitor.is_monitoring + assert not (memory_monitor.monitor_thread and memory_monitor.monitor_thread.is_alive()) + assert len(memory_monitor.alert_callbacks) == 0 + + train.train_comprehensive(**run_kwargs) + assert not memory_monitor.is_monitoring + assert not (memory_monitor.monitor_thread and memory_monitor.monitor_thread.is_alive()) + assert len(memory_monitor.alert_callbacks) == 0 + + # Verify callbacks were not accumulated and only fired once per registration + assert callback_lengths == [0, 1, 0, 1] + assert alert_fire_counts == [1, 1] + assert all(msg.startswith("HIGH MEMORY:") for msg in warning_messages if "HIGH MEMORY:" in msg) + assert critical_messages == []