From 5dab8980ef2740c57a354f155c42e089b463f108 Mon Sep 17 00:00:00 2001
From: mrveiss
Date: Thu, 26 Mar 2026 20:26:22 +0200
Subject: [PATCH] fix(backend): add failover monitor for dead worker task migration (#1769)

---
 .../services/npu_worker_manager.py | 163 +++++++++++++++++-
 1 file changed, 159 insertions(+), 4 deletions(-)

diff --git a/autobot-backend/services/npu_worker_manager.py b/autobot-backend/services/npu_worker_manager.py
index 356a30fb4..e53b82eb9 100644
--- a/autobot-backend/services/npu_worker_manager.py
+++ b/autobot-backend/services/npu_worker_manager.py
@@ -8,6 +8,7 @@
 """

 import asyncio
+import json
 import logging
 import time
 from datetime import datetime
@@ -66,6 +67,7 @@ def __init__(self, config_file: Path = None, redis_client=None):
         self._health_check_task: Optional[asyncio.Task] = None
         self._running = False
         self._load_balancing_config = LoadBalancingConfig()
+        self._failover_monitor_task: Optional[asyncio.Task] = None

         # Issue #699: Track consecutive failures for exponential backoff
         self._worker_failure_counts: Dict[str, int] = {}
@@ -151,17 +153,18 @@ async def _save_workers_to_config(self):
             raise

     async def start_health_monitoring(self):
-        """Start background health monitoring task"""
+        """Start background health monitoring and failover monitor tasks."""
         if self._running:
             logger.warning("Health monitoring already running")
             return

         self._running = True
         self._health_check_task = asyncio.create_task(self._health_check_loop())
-        logger.info("Started NPU worker health monitoring")
+        self._failover_monitor_task = asyncio.create_task(self.failover_monitor())
+        logger.info("Started NPU worker health monitoring and failover monitor")

     async def stop_health_monitoring(self):
-        """Stop background health monitoring task"""
+        """Stop background health monitoring and failover monitor tasks."""
         self._running = False

         if self._health_check_task:
@@ -171,12 +174,19 @@ async def stop_health_monitoring(self):
             except asyncio.CancelledError:
                 logger.debug("Health check task cancelled")

+        if self._failover_monitor_task:
+            self._failover_monitor_task.cancel()
+            try:
+                await self._failover_monitor_task
+            except asyncio.CancelledError:
+                logger.debug("Failover monitor task cancelled")
+
         # Close all worker clients
         for client in self._worker_clients.values():
             await client.close()
         self._worker_clients.clear()

-        logger.info("Stopped NPU worker health monitoring")
+        logger.info("Stopped NPU worker health monitoring and failover monitor")

     async def _check_single_worker_health(self, worker_id: str) -> None:
         """Check health of a single worker with error handling (Issue #315: extracted helper).
@@ -668,6 +678,151 @@ async def _get_worker_metrics(self, worker_id: str) -> Optional[NPUWorkerMetrics

         return None

+    async def _is_worker_heartbeat_expired(self, worker_id: str) -> bool:
+        """Return True if the worker's heartbeat key has expired or is missing (#1769).
+
+        Args:
+            worker_id: ID of the worker to check
+        """
+        if not self.redis_client:
+            return False
+        try:
+            key = f"npu:worker:{worker_id}:status"
+            ttl = await self.redis_client.ttl(key)
+            # ttl == -2 means key does not exist; ttl == -1 means no expiry (treat as alive)
+            return ttl == -2
+        except Exception as e:
+            logger.error(
+                "Failed to check heartbeat TTL for worker %s: %s", worker_id, e
+            )
+            return False
+
+    async def _migrate_running_task(
+        self,
+        task_id: str,
+        running_key: str,
+        pending_key: str,
+        failed_key: str,
+        tasks_hash: str,
+        max_retries: int,
+    ) -> None:
+        """Move one task from running back to pending or to failed (#1769).
+
+        Args:
+            task_id: Task identifier
+            running_key: Redis key for the running ZSET
+            pending_key: Redis key for the pending ZSET
+            failed_key: Redis key for the failed ZSET
+            tasks_hash: Redis hash storing task data
+            max_retries: Maximum allowed retries before moving to failed
+        """
+        try:
+            raw = await self.redis_client.hget(tasks_hash, task_id)
+            task_data = json.loads(raw) if raw else {}
+            metadata = task_data.get("metadata") or {}
+            retry_count = metadata.get("retry_count", 0) + 1
+
+            await self.redis_client.zrem(running_key, task_id)
+
+            if retry_count <= max_retries:
+                metadata["retry_count"] = retry_count
+                task_data["metadata"] = metadata
+                await self.redis_client.hset(tasks_hash, task_id, json.dumps(task_data))
+                score = int(time.time())
+                await self.redis_client.zadd(pending_key, {task_id: score})
+                logger.info(
+                    "Failover: re-queued task %s (retry %d/%d)",
+                    task_id,
+                    retry_count,
+                    max_retries,
+                )
+            else:
+                await self.redis_client.zadd(failed_key, {task_id: int(time.time())})
+                logger.warning(
+                    "Failover: task %s exceeded max_retries (%d), moved to failed",
+                    task_id,
+                    max_retries,
+                )
+        except Exception as e:
+            logger.error("Failed to migrate task %s during failover: %s", task_id, e)
+
+    async def _failover_dead_worker(
+        self,
+        worker_id: str,
+        queue_name: str,
+        max_retries: int,
+    ) -> None:
+        """Re-queue all running tasks for a dead worker and remove it (#1769).
+
+        Args:
+            worker_id: ID of the dead worker
+            queue_name: Base name for the task queue Redis keys
+            max_retries: Maximum retries before a task is moved to failed
+        """
+        running_key = f"{queue_name}:running"
+        pending_key = f"{queue_name}:pending"
+        failed_key = f"{queue_name}:failed"
+        tasks_hash = f"{queue_name}:tasks"
+
+        task_ids = await self.redis_client.zrange(running_key, 0, -1)
+        if not task_ids:
+            logger.info("Failover: dead worker %s had no running tasks", worker_id)
+        else:
+            for task_id in task_ids:
+                await self._migrate_running_task(
+                    task_id,
+                    running_key,
+                    pending_key,
+                    failed_key,
+                    tasks_hash,
+                    max_retries,
+                )
+
+        # Remove worker status key and registry entry
+        await self.redis_client.delete(f"npu:worker:{worker_id}:status")
+        await self.redis_client.delete(f"npu:worker:{worker_id}:metrics")
+        self._workers.pop(worker_id, None)
+        logger.info("Failover: removed dead worker %s from registry", worker_id)
+
+    async def failover_monitor(
+        self,
+        queue_name: str = "autobot_tasks",
+        check_interval: int = 30,
+        max_retries: int = 3,
+    ) -> None:
+        """Periodically detect dead workers and re-queue their tasks (#1769).
+
+        Scans registered workers for expired heartbeat keys. When a heartbeat
+        has expired the worker is considered dead: its running tasks are moved
+        back to pending (up to max_retries) or to failed when retries are
+        exhausted.
+
+        Args:
+            queue_name: Base name for the task queue Redis keys
+            check_interval: Seconds between scans
+            max_retries: Maximum re-queue attempts before marking a task failed
+        """
+        logger.info(
+            "Failover monitor started (queue=%s, interval=%ds)",
+            queue_name,
+            check_interval,
+        )
+        while True:
+            try:
+                worker_ids = list(self._workers.keys())
+                for worker_id in worker_ids:
+                    if await self._is_worker_heartbeat_expired(worker_id):
+                        logger.warning(
+                            "Failover: worker %s heartbeat expired, starting failover",
+                            worker_id,
+                        )
+                        await self._failover_dead_worker(
+                            worker_id, queue_name, max_retries
+                        )
+            except Exception:
+                logger.exception("Failover monitor encountered an unexpected error")
+            await asyncio.sleep(check_interval)
+
     def get_load_balancing_config(self) -> LoadBalancingConfig:
         """Get current load balancing configuration"""
         return self._load_balancing_config
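
Note on the heartbeat contract this patch relies on: _is_worker_heartbeat_expired() treats TTL == -2 on
npu:worker:{worker_id}:status as "worker is dead", which only works if workers write that key with an
expiry and keep refreshing it. The worker side is not part of this patch, so the following is a minimal
sketch of the assumed contract using redis.asyncio; the payload fields, the 30-second refresh interval,
and the 90-second TTL are illustrative assumptions, not values taken from the repository.

import asyncio
import json
import time

import redis.asyncio as aioredis


async def heartbeat_loop(worker_id: str, interval: int = 30, ttl: int = 90) -> None:
    # Illustrative worker-side heartbeat. The key name matches the one read by
    # _is_worker_heartbeat_expired(); interval and ttl are assumed values.
    client = aioredis.Redis(decode_responses=True)
    key = f"npu:worker:{worker_id}:status"
    while True:
        payload = json.dumps({"status": "online", "ts": int(time.time())})
        # SET ... EX gives the key a TTL. If the worker process dies, the key
        # expires and failover_monitor() sees TTL == -2 on its next scan.
        await client.set(key, payload, ex=ttl)
        await asyncio.sleep(interval)

Keeping the refresh interval well below the TTL avoids declaring a worker dead because a single
heartbeat write was delayed.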
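
A quick way to exercise the migration path end to end against a local Redis instance. This is a sketch,
not part of the patch: the import path and the NPUWorkerManager class name are assumptions based on this
file's location, and it assumes the manager can be constructed with just a Redis client. The queue key
layout (autobot_tasks:running/pending/tasks) comes from the defaults introduced above.

import asyncio
import json
import time

import redis.asyncio as aioredis

# Hypothetical import path and class name; adjust to the actual package layout.
from services.npu_worker_manager import NPUWorkerManager


async def main() -> None:
    client = aioredis.Redis(decode_responses=True)
    manager = NPUWorkerManager(redis_client=client)

    # Pretend a registered worker claimed task-42; we call the failover path
    # directly instead of waiting for the monitor to detect an expired heartbeat.
    manager._workers["npu-worker-1"] = {}  # placeholder registry entry
    await client.hset("autobot_tasks:tasks", "task-42", json.dumps({"metadata": {}}))
    await client.zadd("autobot_tasks:running", {"task-42": int(time.time())})

    await manager._failover_dead_worker("npu-worker-1", "autobot_tasks", max_retries=3)

    # Expect task-42 back in pending, with retry_count == 1 in its metadata.
    print(await client.zrange("autobot_tasks:pending", 0, -1))
    print(json.loads(await client.hget("autobot_tasks:tasks", "task-42")))


asyncio.run(main())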