Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 159 additions & 4 deletions autobot-backend/services/npu_worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import asyncio
import json
import logging
import time
from datetime import datetime
Expand Down Expand Up @@ -66,6 +67,7 @@ def __init__(self, config_file: Path = None, redis_client=None):
self._health_check_task: Optional[asyncio.Task] = None
self._running = False
self._load_balancing_config = LoadBalancingConfig()
self._failover_monitor_task: Optional[asyncio.Task] = None

# Issue #699: Track consecutive failures for exponential backoff
self._worker_failure_counts: Dict[str, int] = {}
Expand Down Expand Up @@ -151,17 +153,18 @@ async def _save_workers_to_config(self):
raise

async def start_health_monitoring(self):
    """Launch the background health-check and failover-monitor tasks.

    Idempotent guard: if monitoring is already active this logs a warning
    and returns without spawning duplicate background tasks.
    """
    if self._running:
        logger.warning("Health monitoring already running")
        return

    self._running = True
    # Two independent background loops: periodic health probing of workers,
    # and dead-worker detection / task re-queueing (#1769).
    self._health_check_task = asyncio.create_task(self._health_check_loop())
    self._failover_monitor_task = asyncio.create_task(self.failover_monitor())
    logger.info("Started NPU worker health monitoring and failover monitor")

async def stop_health_monitoring(self):
"""Stop background health monitoring task"""
"""Stop background health monitoring and failover monitor tasks."""
self._running = False

if self._health_check_task:
Expand All @@ -171,12 +174,19 @@ async def stop_health_monitoring(self):
except asyncio.CancelledError:
logger.debug("Health check task cancelled")

if self._failover_monitor_task:
self._failover_monitor_task.cancel()
try:
await self._failover_monitor_task
except asyncio.CancelledError:
logger.debug("Failover monitor task cancelled")

# Close all worker clients
for client in self._worker_clients.values():
await client.close()
self._worker_clients.clear()

logger.info("Stopped NPU worker health monitoring")
logger.info("Stopped NPU worker health monitoring and failover monitor")

async def _check_single_worker_health(self, worker_id: str) -> None:
"""Check health of a single worker with error handling (Issue #315: extracted helper).
Expand Down Expand Up @@ -668,6 +678,151 @@ async def _get_worker_metrics(self, worker_id: str) -> Optional[NPUWorkerMetrics

return None

async def _is_worker_heartbeat_expired(self, worker_id: str) -> bool:
"""Return True if the worker's heartbeat key has expired or is missing (#1769).

Args:
worker_id: ID of the worker to check
"""
if not self.redis_client:
return False
try:
key = f"npu:worker:{worker_id}:status"
ttl = await self.redis_client.ttl(key)
# ttl == -2 means key does not exist; ttl == -1 means no expiry (treat as alive)
return ttl == -2
except Exception as e:
logger.error(
"Failed to check heartbeat TTL for worker %s: %s", worker_id, e
)
return False

async def _migrate_running_task(
self,
task_id: str,
running_key: str,
pending_key: str,
failed_key: str,
tasks_hash: str,
max_retries: int,
) -> None:
"""Move one task from running back to pending or to failed (#1769).

Args:
task_id: Task identifier
running_key: Redis key for the running ZSET
pending_key: Redis key for the pending ZSET
failed_key: Redis key for the failed ZSET
tasks_hash: Redis hash storing task data
max_retries: Maximum allowed retries before moving to failed
"""
try:
raw = await self.redis_client.hget(tasks_hash, task_id)
task_data = json.loads(raw) if raw else {}
metadata = task_data.get("metadata") or {}
retry_count = metadata.get("retry_count", 0) + 1

await self.redis_client.zrem(running_key, task_id)

if retry_count <= max_retries:
metadata["retry_count"] = retry_count
task_data["metadata"] = metadata
await self.redis_client.hset(tasks_hash, task_id, json.dumps(task_data))
score = int(time.time())
await self.redis_client.zadd(pending_key, {task_id: score})
logger.info(
"Failover: re-queued task %s (retry %d/%d)",
task_id,
retry_count,
max_retries,
)
else:
await self.redis_client.zadd(failed_key, {task_id: int(time.time())})
logger.warning(
"Failover: task %s exceeded max_retries (%d), moved to failed",
task_id,
max_retries,
)
except Exception as e:
logger.error("Failed to migrate task %s during failover: %s", task_id, e)

async def _failover_dead_worker(
self,
worker_id: str,
queue_name: str,
max_retries: int,
) -> None:
"""Re-queue all running tasks for a dead worker and remove it (#1769).

Args:
worker_id: ID of the dead worker
queue_name: Base name for the task queue Redis keys
max_retries: Maximum retries before a task is moved to failed
"""
running_key = f"{queue_name}:running"
pending_key = f"{queue_name}:pending"
failed_key = f"{queue_name}:failed"
tasks_hash = f"{queue_name}:tasks"

task_ids = await self.redis_client.zrange(running_key, 0, -1)
if not task_ids:
logger.info("Failover: dead worker %s had no running tasks", worker_id)
else:
for task_id in task_ids:
await self._migrate_running_task(
task_id,
running_key,
pending_key,
failed_key,
tasks_hash,
max_retries,
)

# Remove worker status key and registry entry
await self.redis_client.delete(f"npu:worker:{worker_id}:status")
await self.redis_client.delete(f"npu:worker:{worker_id}:metrics")
self._workers.pop(worker_id, None)
logger.info("Failover: removed dead worker %s from registry", worker_id)

async def failover_monitor(
    self,
    queue_name: str = "autobot_tasks",
    check_interval: int = 30,
    max_retries: int = 3,
) -> None:
    """Periodically detect dead workers and re-queue their tasks (#1769).

    Each pass snapshots the registered worker IDs and checks every one
    for an expired heartbeat key.  A worker whose heartbeat is gone is
    treated as dead: its running tasks are moved back to pending (up to
    max_retries attempts) or to failed once retries are exhausted.
    Unexpected errors are logged and the loop keeps running; task
    cancellation still propagates and stops the loop.

    Args:
        queue_name: Base name for the task queue Redis keys
        check_interval: Seconds between scans
        max_retries: Maximum re-queue attempts before marking a task failed
    """
    logger.info(
        "Failover monitor started (queue=%s, interval=%ds)",
        queue_name,
        check_interval,
    )
    while True:
        try:
            # Snapshot the keys so failover may mutate self._workers mid-scan.
            for candidate in list(self._workers.keys()):
                if not await self._is_worker_heartbeat_expired(candidate):
                    continue
                logger.warning(
                    "Failover: worker %s heartbeat expired, starting failover",
                    candidate,
                )
                await self._failover_dead_worker(
                    candidate, queue_name, max_retries
                )
        except Exception:
            # Keep the monitor alive; CancelledError is not caught here.
            logger.exception("Failover monitor encountered an unexpected error")
        await asyncio.sleep(check_interval)

def get_load_balancing_config(self) -> LoadBalancingConfig:
    """Return the manager's current load-balancing configuration object."""
    return self._load_balancing_config
Expand Down
Loading