Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions azchess/training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@
from azchess.training.npz_dataset import build_training_dataloader
from azchess.utils import (add_memory_alert_callback, clear_memory_cache,
emergency_memory_cleanup, get_memory_usage,
log_tensor_stats, safe_config_get,
start_memory_monitoring)
log_tensor_stats, remove_memory_alert_callback,
safe_config_get, start_memory_monitoring,
stop_memory_monitoring)

# Setup logging
logger = setup_logging(level=logging.INFO)
Expand Down Expand Up @@ -1076,15 +1077,25 @@ def train_comprehensive(
memory_warning_threshold = safe_config_get(cfg, 'memory_warning_threshold', 0.85, section='training') # 85% default
memory_critical_threshold = safe_config_get(cfg, 'memory_critical_threshold', 0.95, section='training') # 95% default

# Track monitor lifecycle for cleanup and fallback warnings
memory_monitor_started = False
training_callback_registered = False
training_memory_alert_callback = None
last_memory_warning = 0
memory_warning_cooldown = 300 # 5 minutes between warnings (fallback path)

# Start comprehensive memory monitoring
try:
start_memory_monitoring(
memory_monitor_started = start_memory_monitoring(
device=device,
warning_threshold=memory_warning_threshold,
critical_threshold=memory_critical_threshold,
check_interval=30.0 # Check every 30 seconds
)
logger.info("Advanced memory monitoring system started")
if memory_monitor_started:
logger.info("Advanced memory monitoring system started")
else:
logger.info("Advanced memory monitoring system already running; reusing existing monitor")

# Add custom alert callback for training-specific actions
def training_memory_alert_callback(alert):
Expand All @@ -1094,12 +1105,14 @@ def training_memory_alert_callback(alert):
logger.warning(f"HIGH MEMORY: Monitor training stability. Memory: {alert.memory_usage_gb:.2f}GB")

add_memory_alert_callback(training_memory_alert_callback)
training_callback_registered = True

except Exception as e:
logger.warning(f"Could not start advanced memory monitoring: {e}")
# Fallback to basic monitoring
last_memory_warning = 0
memory_warning_cooldown = 300 # 5 minutes between warnings
memory_monitor_started = False
training_memory_alert_callback = None
training_callback_registered = False

def get_system_memory_usage():
"""Get current system memory usage in GB for heartbeat monitoring."""
Expand Down Expand Up @@ -1545,6 +1558,23 @@ def _finite(x):
finally:
pbar.close()

# Clean up memory monitoring callbacks and threads
if training_callback_registered and training_memory_alert_callback is not None:
try:
removed = remove_memory_alert_callback(training_memory_alert_callback)
if removed:
logger.debug("Removed training memory alert callback")
except Exception as callback_error:
logger.warning(f"Failed to remove training memory alert callback: {callback_error}")

if memory_monitor_started:
try:
stopped = stop_memory_monitoring()
if stopped:
logger.debug("Advanced memory monitoring system stopped")
except Exception as monitor_error:
logger.warning(f"Failed to stop memory monitoring: {monitor_error}")

# Save final checkpoint with enhanced prefix, but handle KeyboardInterrupt gracefully
if not locals().get('interrupted', False):
try:
Expand Down
2 changes: 2 additions & 0 deletions azchess/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
emergency_memory_cleanup, get_memory_usage)
from .memory_monitor import (MemoryAlert, MemoryMonitor,
add_memory_alert_callback, get_memory_stats,
remove_memory_alert_callback,
start_memory_monitoring, stop_memory_monitoring)
from .model_loader import load_model_and_mcts
from .performance_utils import (PerformanceMetric, PerformanceMonitor,
Expand Down Expand Up @@ -63,6 +64,7 @@
"stop_memory_monitoring",
"get_memory_stats",
"add_memory_alert_callback",
"remove_memory_alert_callback",

# Device management
"DeviceManager",
Expand Down
34 changes: 26 additions & 8 deletions azchess/utils/memory_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def start_monitoring(self, device: str = 'auto') -> None:
"""Start the memory monitoring thread."""
if self.is_monitoring:
logger.warning("Memory monitoring is already running")
return
return False

self.is_monitoring = True
self.stop_event.clear()
Expand All @@ -81,19 +81,22 @@ def start_monitoring(self, device: str = 'auto') -> None:
)
self.monitor_thread.start()
logger.info(f"Memory monitoring started for device: {device}")
return True

def stop_monitoring(self) -> None:
def stop_monitoring(self) -> bool:
"""Stop the memory monitoring thread."""
if not self.is_monitoring:
return
if not self.is_monitoring and not (self.monitor_thread and self.monitor_thread.is_alive()):
Copy link

Copilot AI Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] This condition check is confusing. Consider simplifying to check if there's nothing to stop: if not self.is_monitoring and (not self.monitor_thread or not self.monitor_thread.is_alive()):

Suggested change
if not self.is_monitoring and not (self.monitor_thread and self.monitor_thread.is_alive()):
if not self.is_monitoring and (not self.monitor_thread or not self.monitor_thread.is_alive()):

Copilot uses AI. Check for mistakes.
return False

self.is_monitoring = False
self.stop_event.set()

if self.monitor_thread and self.monitor_thread.is_alive():
self.monitor_thread.join(timeout=5.0)

self.monitor_thread = None
logger.info("Memory monitoring stopped")
return True

def _monitor_loop(self, device: str) -> None:
"""Main monitoring loop."""
Expand Down Expand Up @@ -202,6 +205,16 @@ def add_alert_callback(self, callback: Callable[[MemoryAlert], None]) -> None:
"""Add a callback function to be called when alerts are generated."""
self.alert_callbacks.append(callback)

def remove_alert_callback(self, callback: Callable[[MemoryAlert], None]) -> bool:
    """Remove a previously registered alert callback."""
    # Resolve a printable identity for the callback once, up front.
    callback_name = getattr(callback, "__name__", repr(callback))
    if callback in self.alert_callbacks:
        self.alert_callbacks.remove(callback)
        logger.debug("Removed memory alert callback: %s", callback_name)
        return True
    # Nothing to remove: the callback was never registered (or already gone).
    logger.debug("Attempted to remove unregistered memory alert callback: %s", callback_name)
    return False

def get_memory_stats(self) -> Dict[str, Any]:
"""Get comprehensive memory statistics."""
stats = {
Expand Down Expand Up @@ -235,19 +248,19 @@ def set_device_limit(self, device: str, limit_gb: float) -> None:
memory_monitor = MemoryMonitor()


def start_memory_monitoring(device: str = 'auto', **kwargs) -> bool:
    """Convenience function to start memory monitoring.

    Any keyword argument matching an existing attribute of the module-level
    ``memory_monitor`` (e.g. thresholds, check interval) is applied before
    starting. Returns True if a new monitor thread was started, False if
    monitoring was already running (see ``MemoryMonitor.start_monitoring``).
    """
    # Update monitor settings if provided; unknown keys are silently ignored.
    for key, value in kwargs.items():
        if hasattr(memory_monitor, key):
            setattr(memory_monitor, key, value)

    return memory_monitor.start_monitoring(device)


def stop_memory_monitoring() -> bool:
    """Convenience function to stop memory monitoring.

    Returns True if a running monitor was stopped, False if there was
    nothing to stop (delegates to ``MemoryMonitor.stop_monitoring``).
    """
    return memory_monitor.stop_monitoring()


def get_memory_stats() -> Dict[str, Any]:
Expand All @@ -258,3 +271,8 @@ def get_memory_stats() -> Dict[str, Any]:
def add_memory_alert_callback(callback: Callable[[MemoryAlert], None]) -> None:
    """Convenience function to add alert callback.

    Delegates to the shared module-level ``memory_monitor`` instance; the
    callback receives a ``MemoryAlert`` when the monitor emits one.
    """
    memory_monitor.add_alert_callback(callback)


def remove_memory_alert_callback(callback: Callable[[MemoryAlert], None]) -> bool:
    """Convenience function to remove alert callback.

    Delegates to the shared module-level ``memory_monitor``; returns True if
    the callback was registered and removed, False otherwise.
    """
    return memory_monitor.remove_alert_callback(callback)
212 changes: 212 additions & 0 deletions tests/test_train_memory_monitor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
from __future__ import annotations

import time
from types import SimpleNamespace

import torch

from azchess.utils.memory_monitor import MemoryAlert, memory_monitor


class DummyModel(torch.nn.Module):
    """Tiny torch module standing in for the real network during tests.

    Exposes the memory/checkpointing hooks the training loop may call,
    recording invocations on private attributes instead of doing real work.
    """

    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):  # pragma: no cover - not used in this test
        return self.linear(x)

    def enable_gradient_checkpointing(self, strategy: str = "adaptive") -> None:
        # Remember the requested strategy so a test could inspect it.
        self._checkpoint_strategy = strategy

    def enable_memory_optimization(self) -> None:
        # Record that optimization was requested; no actual effect.
        self._memory_optimization_enabled = True

    def get_memory_usage(self):
        """Report a fixed, zero parameter footprint."""
        return {"parameters_gb": 0.0}


class DummyPolicyValueNet:
    """Factory stub matching the real PolicyValueNet's ``from_config`` entry point."""

    @classmethod
    def from_config(cls, _cfg):
        # The config is ignored; always return a fresh lightweight model.
        return DummyModel()


class DummyDataManager:
    """Lightweight stand-in for the project's DataManager used in tests.

    Provides the minimal read-only surface the training loop touches:
    fixed stats, zeroed external-data counters, and empty batch sources.
    """

    def __init__(self, *_, **__):
        # Static stats namespace; callers only read these attributes.
        self._stats = SimpleNamespace(total_samples=1, total_shards=1)

    def get_stats(self):
        """Return the fixed stats namespace."""
        return self._stats

    def get_external_data_stats(self):
        """Return zeroed external-data counters."""
        return {"external_total": 0, "tactical_samples": 0, "openings_samples": 0}

    def get_curriculum_batch(self, *_args, **_kwargs):
        """No curriculum data is available in tests."""
        return None

    def get_training_batch(self, *_args, **_kwargs):
        """Return an empty generator (yields nothing).

        The original used the obscure ``while False: yield None`` trick;
        ``yield from ()`` states the intent directly and behaves identically.
        """
        yield from ()
Comment on lines +49 to +50
Copy link

Copilot AI Sep 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This generator will never yield any values since the while condition is False. This should be while True: or use a different approach to create an empty generator.

Suggested change
while False:
yield None
if False:
yield

Copilot uses AI. Check for mistakes.


class DummySummaryWriter:
    """No-op replacement for TensorBoard's SummaryWriter used in tests."""

    def __init__(self, *_args, **_kwargs) -> None:
        # Kept for interface parity; stays empty since add_scalar is a no-op.
        self.scalars = []

    def add_scalar(self, *_args, **_kwargs) -> None:
        # Intentionally discard all scalar logging.
        pass

    def close(self) -> None:
        # Nothing to flush or release.
        pass


class DummyConfig:
    """Minimal configuration object mimicking the project's Config interface.

    Exposes ``training()`` and ``model()`` section accessors plus a
    root-level ``get``; every value is chosen so the training loop takes
    its fastest, feature-free path.
    """

    def __init__(self, log_dir: str) -> None:
        self._root = {}
        # Training section: all optional features disabled, fp32 precision.
        self._training = dict(
            gradient_accumulation_steps=None,
            compile=False,
            compile_mode="default",
            steps_per_epoch=1,
            use_curriculum=False,
            curriculum_phases=[],
            dataloader_workers=0,
            prefetch_factor=0,
            memory_limit_gb=16,
            memory_warning_threshold=0.85,
            memory_critical_threshold=0.95,
            ssl_weight=0.0,
            policy_label_smoothing=0.0,
            value_loss="mse",
            huber_delta=1.0,
            policy_masking=False,
            ssl_warmup_steps=0,
            ssl_target_weight=1.0,
            ssl_targets_provider="auto",
            wdl_weight=0.0,
            wdl_margin=0.25,
            precision="fp32",
            ssl_every_n=1,
            ssl_chunk_size=0,
            log_dir=log_dir,
        )
        # Model section: SSL and WDL heads disabled.
        self._model = dict(self_supervised=False, wdl=False, ssl_curriculum=False)

    def training(self):
        """Return the training-section mapping."""
        return self._training

    def model(self):
        """Return the model-section mapping."""
        return self._model

    def get(self, key, default=None):
        """Root-level lookup with a default (root is always empty here)."""
        return self._root.get(key, default)


def test_train_comprehensive_memory_monitor_cleanup(monkeypatch, tmp_path):
    """Verify train_comprehensive fully tears down the global memory monitor.

    Runs the training entry point twice with all heavy dependencies stubbed
    out, asserting that after each run the monitor thread is stopped and the
    training alert callback has been deregistered — i.e. repeated runs do
    not leak threads or accumulate callbacks.
    """
    from azchess.training import train

    # Ensure monitor is not running before the test
    memory_monitor.stop_monitoring()
    memory_monitor.alert_callbacks.clear()

    dummy_log_dir = tmp_path / "logs"
    dummy_ckpt_dir = tmp_path / "ckpts"

    # Patch heavy dependencies with lightweight stand-ins
    monkeypatch.setattr(train, "PolicyValueNet", DummyPolicyValueNet)
    monkeypatch.setattr(train, "DataManager", DummyDataManager)
    monkeypatch.setattr(train, "build_training_dataloader", lambda *a, **k: None)
    monkeypatch.setattr(train, "SummaryWriter", DummySummaryWriter)
    monkeypatch.setattr(train, "clear_memory_cache", lambda *a, **k: None)
    monkeypatch.setattr(train, "get_memory_usage", lambda *a, **k: {"memory_gb": 1.0})
    monkeypatch.setattr(train, "emergency_memory_cleanup", lambda *a, **k: None)
    monkeypatch.setattr(train, "save_checkpoint", lambda *a, **k: None)

    def fake_config_load(_path):
        # Any config path resolves to the lightweight DummyConfig.
        return DummyConfig(str(dummy_log_dir))

    monkeypatch.setattr(train.Config, "load", staticmethod(fake_config_load))

    warning_messages: list[str] = []
    critical_messages: list[str] = []

    def capture_warning(msg, *args, **kwargs):
        # Mirror logging's lazy %-formatting so captured text matches output.
        text = msg % args if args else msg
        warning_messages.append(text)

    def capture_critical(msg, *args, **kwargs):
        text = msg % args if args else msg
        critical_messages.append(text)

    monkeypatch.setattr(train.logger, "warning", capture_warning)
    monkeypatch.setattr(train.logger, "critical", capture_critical)

    callback_lengths: list[int] = []
    alert_fire_counts: list[int] = []

    def tracked_add_callback(callback):
        # Record callback-list length before and after registration so we can
        # later assert each run starts from an empty list ([0, 1] per run).
        pre_len = len(memory_monitor.alert_callbacks)
        callback_lengths.append(pre_len)
        memory_monitor.add_alert_callback(callback)
        post_len = len(memory_monitor.alert_callbacks)
        callback_lengths.append(post_len)

        # Fire a synthetic warning alert right away and count resulting
        # "HIGH MEMORY:" log lines; exactly one per registration proves
        # callbacks are not stacking across runs.
        before_warning = len(warning_messages)
        alert = MemoryAlert(
            alert_type="warning",
            message="unit-test alert",
            memory_usage_gb=1.0,
            memory_limit_gb=2.0,
            timestamp=time.time(),
            device="cpu",
        )
        memory_monitor._send_alert(alert)
        high_memory_logs = [
            msg for msg in warning_messages[before_warning:] if msg.startswith("HIGH MEMORY:")
        ]
        alert_fire_counts.append(len(high_memory_logs))

    monkeypatch.setattr(train, "add_memory_alert_callback", tracked_add_callback)

    # Zero steps/epochs: training exits immediately after setup, exercising
    # only the monitor start/teardown paths under test.
    run_kwargs = dict(
        config_path="dummy",
        total_steps=0,
        batch_size=1,
        learning_rate=0.001,
        weight_decay=0.0,
        ema_decay=0.0,
        grad_clip_norm=1.0,
        accum_steps=1,
        warmup_steps=0,
        checkpoint_dir=str(dummy_ckpt_dir),
        log_dir=str(dummy_log_dir),
        device="cpu",
        use_amp=False,
        augment=False,
        precision="fp32",
        epochs=0,
        steps_per_epoch=0,
        init_checkpoint=None,
        resume=False,
        data_mode=None,
        dataloader_workers=0,
        prefetch_factor=0,
    )

    # Run twice to ensure cleanup between runs
    train.train_comprehensive(**run_kwargs)
    assert not memory_monitor.is_monitoring
    assert not (memory_monitor.monitor_thread and memory_monitor.monitor_thread.is_alive())
    assert len(memory_monitor.alert_callbacks) == 0

    train.train_comprehensive(**run_kwargs)
    assert not memory_monitor.is_monitoring
    assert not (memory_monitor.monitor_thread and memory_monitor.monitor_thread.is_alive())
    assert len(memory_monitor.alert_callbacks) == 0

    # Verify callbacks were not accumulated and only fired once per registration
    assert callback_lengths == [0, 1, 0, 1]
    assert alert_fire_counts == [1, 1]
    assert all(msg.startswith("HIGH MEMORY:") for msg in warning_messages if "HIGH MEMORY:" in msg)
    assert critical_messages == []
Loading