diff --git a/crates/openfang-api/src/routes.rs b/crates/openfang-api/src/routes.rs index 5b873175c..7b74c7b77 100644 --- a/crates/openfang-api/src/routes.rs +++ b/crates/openfang-api/src/routes.rs @@ -639,6 +639,70 @@ pub async fn kill_agent( } } +/// POST /api/agents/{id}/restart — Restart a crashed/stuck agent. +/// +/// Cancels any active task, resets agent state to Running, and updates last_active. +/// Returns the agent's new state. +pub async fn restart_agent( + State(state): State>, + Path(id): Path, +) -> impl IntoResponse { + let agent_id: AgentId = match id.parse() { + Ok(id) => id, + Err(_) => { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({"error": "Invalid agent ID"})), + ); + } + }; + + // Check agent exists + let entry = match state.kernel.registry.get(agent_id) { + Some(e) => e, + None => { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "Agent not found"})), + ); + } + }; + + let agent_name = entry.name.clone(); + let previous_state = format!("{:?}", entry.state); + drop(entry); + + // Cancel any running task + let was_running = state + .kernel + .stop_agent_run(agent_id) + .unwrap_or(false); + + // Reset state to Running (also updates last_active) + let _ = state + .kernel + .registry + .set_state(agent_id, openfang_types::agent::AgentState::Running); + + tracing::info!( + agent = %agent_name, + previous_state = %previous_state, + task_cancelled = was_running, + "Agent restarted via API" + ); + + ( + StatusCode::OK, + Json(serde_json::json!({ + "status": "restarted", + "agent": agent_name, + "agent_id": id, + "previous_state": previous_state, + "task_cancelled": was_running, + })), + ) +} + /// GET /api/status — Kernel status. pub async fn status(State(state): State>) -> impl IntoResponse { let agents: Vec = state diff --git a/crates/openfang-api/src/server.rs b/crates/openfang-api/src/server.rs index 47c9b91fe..d5495e967 100644 --- a/crates/openfang-api/src/server.rs +++ b/crates/openfang-api/src/server.rs @@ -146,6 +146,14 @@ pub async fn build_router( axum::routing::put(routes::set_agent_mode), ) .route("/api/profiles", axum::routing::get(routes::list_profiles)) + .route( + "/api/agents/{id}/restart", + axum::routing::post(routes::restart_agent), + ) + .route( + "/api/agents/{id}/start", + axum::routing::post(routes::restart_agent), + ) .route( "/api/agents/{id}/message", axum::routing::post(routes::send_message), diff --git a/crates/openfang-kernel/src/heartbeat.rs b/crates/openfang-kernel/src/heartbeat.rs index be3d10fc2..f682157f9 100644 --- a/crates/openfang-kernel/src/heartbeat.rs +++ b/crates/openfang-kernel/src/heartbeat.rs @@ -4,9 +4,14 @@ //! each running agent's `last_active` timestamp. If an agent hasn't been active //! for longer than 2x its heartbeat interval, a `HealthCheckFailed` event is //! published to the event bus. +//! +//! Crashed agents are tracked for auto-recovery: the heartbeat will attempt to +//! reset crashed agents back to Running up to `max_recovery_attempts` times. +//! After exhausting attempts, agents are marked as Terminated (dead). use crate::registry::AgentRegistry; use chrono::Utc; +use dashmap::DashMap; use openfang_types::agent::{AgentId, AgentState}; use tracing::{debug, warn}; @@ -17,6 +22,12 @@ const DEFAULT_CHECK_INTERVAL_SECS: u64 = 30; /// multiples of its heartbeat interval. const UNRESPONSIVE_MULTIPLIER: u64 = 2; +/// Default maximum recovery attempts before giving up. +const DEFAULT_MAX_RECOVERY_ATTEMPTS: u32 = 3; + +/// Default cooldown between recovery attempts (seconds). +const DEFAULT_RECOVERY_COOLDOWN_SECS: u64 = 60; + /// Result of a heartbeat check. #[derive(Debug, Clone)] pub struct HeartbeatStatus { @@ -28,6 +39,8 @@ pub struct HeartbeatStatus { pub inactive_secs: i64, /// Whether the agent is considered unresponsive. pub unresponsive: bool, + /// Current agent state. + pub state: AgentState, } /// Heartbeat monitor configuration. @@ -38,18 +51,82 @@ pub struct HeartbeatConfig { /// Default threshold for unresponsiveness (seconds). /// Overridden per-agent by AutonomousConfig.heartbeat_interval_secs. pub default_timeout_secs: u64, + /// Maximum recovery attempts before marking agent as Terminated. + pub max_recovery_attempts: u32, + /// Minimum seconds between recovery attempts for the same agent. + pub recovery_cooldown_secs: u64, } impl Default for HeartbeatConfig { fn default() -> Self { Self { check_interval_secs: DEFAULT_CHECK_INTERVAL_SECS, - default_timeout_secs: DEFAULT_CHECK_INTERVAL_SECS * UNRESPONSIVE_MULTIPLIER, + // 180s default: browser tasks and complex LLM calls can take 1-3 minutes + default_timeout_secs: 180, + max_recovery_attempts: DEFAULT_MAX_RECOVERY_ATTEMPTS, + recovery_cooldown_secs: DEFAULT_RECOVERY_COOLDOWN_SECS, } } } -/// Check all running agents and return their heartbeat status. +/// Tracks per-agent recovery state across heartbeat cycles. +#[derive(Debug)] +pub struct RecoveryTracker { + /// Per-agent recovery state: (consecutive_failures, last_attempt_epoch_secs). + state: DashMap, +} + +impl RecoveryTracker { + /// Create a new recovery tracker. + pub fn new() -> Self { + Self { + state: DashMap::new(), + } + } + + /// Record a recovery attempt for an agent. + /// Returns the current attempt number (1-indexed). + pub fn record_attempt(&self, agent_id: AgentId) -> u32 { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + let mut entry = self.state.entry(agent_id).or_insert((0, 0)); + entry.0 += 1; + entry.1 = now; + entry.0 + } + + /// Check if enough time has passed since the last recovery attempt. + pub fn can_attempt(&self, agent_id: AgentId, cooldown_secs: u64) -> bool { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + match self.state.get(&agent_id) { + Some(entry) => now.saturating_sub(entry.1) >= cooldown_secs, + None => true, // No prior attempts + } + } + + /// Get the current failure count for an agent. + pub fn failure_count(&self, agent_id: AgentId) -> u32 { + self.state.get(&agent_id).map(|e| e.0).unwrap_or(0) + } + + /// Reset recovery state for an agent (e.g. after successful recovery). + pub fn reset(&self, agent_id: AgentId) { + self.state.remove(&agent_id); + } +} + +impl Default for RecoveryTracker { + fn default() -> Self { + Self::new() + } +} + +/// Check all running and crashed agents and return their heartbeat status. /// /// This is a pure function — it doesn't start a background task. /// The caller (kernel) can run this periodically or in a background task. @@ -58,9 +135,10 @@ pub fn check_agents(registry: &AgentRegistry, config: &HeartbeatConfig) -> Vec {} + _ => continue, } let inactive_secs = (now - entry_ref.last_active).num_seconds(); @@ -73,15 +151,22 @@ pub fn check_agents(registry: &AgentRegistry, config: &HeartbeatConfig) -> Vec timeout_secs; + // Crashed agents are always considered unresponsive + let unresponsive = entry_ref.state == AgentState::Crashed || inactive_secs > timeout_secs; - if unresponsive { + if unresponsive && entry_ref.state == AgentState::Running { warn!( agent = %entry_ref.name, inactive_secs, timeout_secs, "Agent is unresponsive" ); + } else if entry_ref.state == AgentState::Crashed { + warn!( + agent = %entry_ref.name, + inactive_secs, + "Agent is crashed — eligible for recovery" + ); } else { debug!( agent = %entry_ref.name, @@ -95,6 +180,7 @@ pub fn check_agents(registry: &AgentRegistry, config: &HeartbeatConfig) -> Vec) { - use crate::heartbeat::{check_agents, is_quiet_hours, HeartbeatConfig}; + use crate::heartbeat::{check_agents, is_quiet_hours, HeartbeatConfig, RecoveryTracker}; let kernel = Arc::clone(self); let config = HeartbeatConfig::default(); let interval_secs = config.check_interval_secs; + let recovery_tracker = RecoveryTracker::new(); tokio::spawn(async move { let mut interval = @@ -4206,7 +4207,102 @@ impl OpenFangKernel { } } - if status.unresponsive { + // --- Auto-recovery for crashed agents --- + if status.state == AgentState::Crashed { + let failures = recovery_tracker.failure_count(status.agent_id); + + if failures >= config.max_recovery_attempts { + // Already exhausted recovery attempts — mark Terminated + // (only do this once, check current state) + if let Some(entry) = kernel.registry.get(status.agent_id) { + if entry.state == AgentState::Crashed { + let _ = kernel + .registry + .set_state(status.agent_id, AgentState::Terminated); + warn!( + agent = %status.name, + attempts = failures, + "Agent exhausted all recovery attempts — marked Terminated. Manual restart required." + ); + // Publish event for notification channels + let event = Event::new( + status.agent_id, + EventTarget::System, + EventPayload::System(SystemEvent::HealthCheckFailed { + agent_id: status.agent_id, + unresponsive_secs: status.inactive_secs as u64, + }), + ); + kernel.event_bus.publish(event).await; + } + } + continue; + } + + // Check cooldown + if !recovery_tracker.can_attempt( + status.agent_id, + config.recovery_cooldown_secs, + ) { + debug!( + agent = %status.name, + "Recovery cooldown active, skipping" + ); + continue; + } + + // Attempt recovery: reset state to Running + let attempt = recovery_tracker.record_attempt(status.agent_id); + info!( + agent = %status.name, + attempt = attempt, + max = config.max_recovery_attempts, + "Auto-recovering crashed agent (attempt {}/{})", + attempt, + config.max_recovery_attempts + ); + let _ = kernel + .registry + .set_state(status.agent_id, AgentState::Running); + + // Publish recovery event + let event = Event::new( + status.agent_id, + EventTarget::System, + EventPayload::System(SystemEvent::HealthCheckFailed { + agent_id: status.agent_id, + unresponsive_secs: 0, // 0 signals recovery attempt + }), + ); + kernel.event_bus.publish(event).await; + continue; + } + + // --- Running agent that recovered successfully --- + // If agent is Running and was previously in recovery, clear the tracker + if status.state == AgentState::Running + && !status.unresponsive + && recovery_tracker.failure_count(status.agent_id) > 0 + { + info!( + agent = %status.name, + "Agent recovered successfully — resetting recovery tracker" + ); + recovery_tracker.reset(status.agent_id); + } + + // --- Unresponsive Running agent --- + if status.unresponsive && status.state == AgentState::Running { + // Mark as Crashed so next cycle triggers recovery + let _ = kernel + .registry + .set_state(status.agent_id, AgentState::Crashed); + warn!( + agent = %status.name, + inactive_secs = status.inactive_secs, + "Unresponsive Running agent marked as Crashed for recovery" + ); + let event = Event::new( status.agent_id, EventTarget::System, diff --git a/crates/openfang-runtime/src/drivers/claude_code.rs b/crates/openfang-runtime/src/drivers/claude_code.rs index 1cdfe3b44..986224eed 100644 --- a/crates/openfang-runtime/src/drivers/claude_code.rs +++ b/crates/openfang-runtime/src/drivers/claude_code.rs @@ -4,13 +4,18 @@ //! which is non-interactive and handles its own authentication. //! This allows users with Claude Code installed to use it as an LLM provider //! without needing a separate API key. +//! +//! Tracks active subprocess PIDs and enforces message timeouts to prevent +//! hung CLI processes from blocking agents indefinitely. use crate::llm_driver::{CompletionRequest, CompletionResponse, LlmDriver, LlmError, StreamEvent}; use async_trait::async_trait; +use dashmap::DashMap; use openfang_types::message::{ContentBlock, Role, StopReason, TokenUsage}; use serde::Deserialize; -use tokio::io::AsyncBufReadExt; -use tracing::{debug, warn}; +use std::sync::Arc; +use tokio::io::{AsyncBufReadExt, AsyncReadExt}; +use tracing::{debug, info, warn}; /// Environment variable names (and suffixes) to strip from the subprocess /// to prevent leaking API keys from other providers. We keep the full env @@ -44,10 +49,18 @@ const SENSITIVE_ENV_EXACT: &[&str] = &[ /// unless it starts with `CLAUDE_`. const SENSITIVE_SUFFIXES: &[&str] = &["_SECRET", "_TOKEN", "_PASSWORD"]; +/// Default subprocess timeout in seconds (5 minutes). +const DEFAULT_MESSAGE_TIMEOUT_SECS: u64 = 300; + /// LLM driver that delegates to the Claude Code CLI. pub struct ClaudeCodeDriver { cli_path: String, skip_permissions: bool, + /// Active subprocess PIDs keyed by a caller-provided label (e.g. agent name). + /// Allows external code to check if a subprocess is running and kill it. + active_pids: Arc>, + /// Message timeout in seconds. CLI subprocesses that exceed this are killed. + message_timeout_secs: u64, } impl ClaudeCodeDriver { @@ -70,9 +83,32 @@ impl ClaudeCodeDriver { .filter(|s| !s.is_empty()) .unwrap_or_else(|| "claude".to_string()), skip_permissions, + active_pids: Arc::new(DashMap::new()), + message_timeout_secs: DEFAULT_MESSAGE_TIMEOUT_SECS, } } + /// Create a new Claude Code driver with a custom timeout. + pub fn with_timeout(cli_path: Option, skip_permissions: bool, timeout_secs: u64) -> Self { + let mut driver = Self::new(cli_path, skip_permissions); + driver.message_timeout_secs = timeout_secs; + driver + } + + /// Get a snapshot of active subprocess PIDs. + /// Returns a vec of (label, pid) pairs. + pub fn active_pids(&self) -> Vec<(String, u32)> { + self.active_pids + .iter() + .map(|entry| (entry.key().clone(), *entry.value())) + .collect() + } + + /// Get the shared PID map for external monitoring. + pub fn pid_map(&self) -> Arc> { + Arc::clone(&self.active_pids) + } + /// Detect if the Claude Code CLI is available on PATH. pub fn detect() -> Option { let output = std::process::Command::new("claude") @@ -220,20 +256,78 @@ impl LlmDriver for ClaudeCodeDriver { debug!(cli = %self.cli_path, skip_permissions = self.skip_permissions, "Spawning Claude Code CLI"); - let output = cmd - .output() - .await + // Spawn child process instead of cmd.output() so we can track PID and timeout + let mut child = cmd + .spawn() .map_err(|e| LlmError::Http(format!( "Claude Code CLI not found or failed to start ({}). \ Install: npm install -g @anthropic-ai/claude-code && claude auth", e )))?; - if !output.status.success() { - let stderr = String::from_utf8_lossy(&output.stderr).trim().to_string(); - let stdout = String::from_utf8_lossy(&output.stdout).trim().to_string(); - let detail = if !stderr.is_empty() { &stderr } else { &stdout }; - let code = output.status.code().unwrap_or(1); + // Track the PID using the model name as label (best identifier available) + let pid_label = request.model.clone(); + if let Some(pid) = child.id() { + self.active_pids.insert(pid_label.clone(), pid); + debug!(pid = pid, model = %pid_label, "Claude Code CLI subprocess started"); + } + + // Read stdout/stderr before waiting (take ownership of pipes) + let child_stdout = child.stdout.take(); + let child_stderr = child.stderr.take(); + + // Wait with timeout + let timeout_duration = std::time::Duration::from_secs(self.message_timeout_secs); + let wait_result = tokio::time::timeout(timeout_duration, child.wait()).await; + + // Clear PID tracking regardless of outcome + self.active_pids.remove(&pid_label); + + let status = match wait_result { + Ok(Ok(status)) => status, + Ok(Err(e)) => { + warn!(error = %e, model = %pid_label, "Claude Code CLI subprocess failed"); + return Err(LlmError::Http(format!( + "Claude Code CLI subprocess failed: {e}" + ))); + } + Err(_elapsed) => { + // Timeout — kill the process + warn!( + timeout_secs = self.message_timeout_secs, + model = %pid_label, + "Claude Code CLI subprocess timed out, killing process" + ); + let _ = child.kill().await; + return Err(LlmError::Http(format!( + "Claude Code CLI subprocess timed out after {}s — process killed", + self.message_timeout_secs + ))); + } + }; + + // Read captured output from pipes + let mut stdout_bytes = Vec::new(); + let mut stderr_bytes = Vec::new(); + if let Some(mut out) = child_stdout { + let _ = out.read_to_end(&mut stdout_bytes).await; + } + if let Some(mut err) = child_stderr { + let _ = err.read_to_end(&mut stderr_bytes).await; + }; + + if !status.success() { + let stderr = String::from_utf8_lossy(&stderr_bytes).trim().to_string(); + let stdout_str = String::from_utf8_lossy(&stdout_bytes).trim().to_string(); + let detail = if !stderr.is_empty() { &stderr } else { &stdout_str }; + let code = status.code().unwrap_or(1); + + warn!( + exit_code = code, + model = %pid_label, + stderr = %detail, + "Claude Code CLI exited with error" + ); // Provide actionable error messages let message = if detail.contains("not authenticated") @@ -261,7 +355,9 @@ impl LlmDriver for ClaudeCodeDriver { }); } - let stdout = String::from_utf8_lossy(&output.stdout); + info!(model = %pid_label, "Claude Code CLI subprocess completed successfully"); + + let stdout = String::from_utf8_lossy(&stdout_bytes); // Try JSON parse first if let Ok(parsed) = serde_json::from_str::(&stdout) { @@ -322,7 +418,7 @@ impl LlmDriver for ClaudeCodeDriver { cmd.stdout(std::process::Stdio::piped()); cmd.stderr(std::process::Stdio::piped()); - debug!(cli = %self.cli_path, skip_permissions = self.skip_permissions, "Spawning Claude Code CLI (streaming)"); + debug!(cli = %self.cli_path, "Spawning Claude Code CLI (streaming)"); let mut child = cmd .spawn() @@ -332,10 +428,20 @@ impl LlmDriver for ClaudeCodeDriver { e )))?; + // Track PID + let pid_label = format!("{}-stream", request.model); + if let Some(pid) = child.id() { + self.active_pids.insert(pid_label.clone(), pid); + debug!(pid = pid, model = %pid_label, "Claude Code CLI streaming subprocess started"); + } + let stdout = child .stdout .take() - .ok_or_else(|| LlmError::Http("No stdout from claude CLI".to_string()))?; + .ok_or_else(|| { + self.active_pids.remove(&pid_label); + LlmError::Http("No stdout from claude CLI".to_string()) + })?; let reader = tokio::io::BufReader::new(stdout); let mut lines = reader.lines(); @@ -346,64 +452,84 @@ impl LlmDriver for ClaudeCodeDriver { output_tokens: 0, }; - while let Ok(Some(line)) = lines.next_line().await { - if line.trim().is_empty() { - continue; - } + let timeout_duration = std::time::Duration::from_secs(self.message_timeout_secs); + let stream_result = tokio::time::timeout(timeout_duration, async { + while let Ok(Some(line)) = lines.next_line().await { + if line.trim().is_empty() { + continue; + } - match serde_json::from_str::(&line) { - Ok(event) => { - match event.r#type.as_str() { - "content" | "text" | "assistant" | "content_block_delta" => { - if let Some(ref content) = event.content { - full_text.push_str(content); - let _ = tx - .send(StreamEvent::TextDelta { - text: content.clone(), - }) - .await; - } - } - "result" | "done" | "complete" => { - if let Some(ref result) = event.result { - if full_text.is_empty() { - full_text = result.clone(); + match serde_json::from_str::(&line) { + Ok(event) => { + match event.r#type.as_str() { + "content" | "text" | "assistant" | "content_block_delta" => { + if let Some(ref content) = event.content { + full_text.push_str(content); let _ = tx .send(StreamEvent::TextDelta { - text: result.clone(), + text: content.clone(), }) .await; } } - if let Some(usage) = event.usage { - final_usage = TokenUsage { - input_tokens: usage.input_tokens, - output_tokens: usage.output_tokens, - }; + "result" | "done" | "complete" => { + if let Some(ref result) = event.result { + if full_text.is_empty() { + full_text = result.clone(); + let _ = tx + .send(StreamEvent::TextDelta { + text: result.clone(), + }) + .await; + } + } + if let Some(usage) = event.usage { + final_usage = TokenUsage { + input_tokens: usage.input_tokens, + output_tokens: usage.output_tokens, + }; + } } - } - _ => { - // Unknown event type — try content field as fallback - if let Some(ref content) = event.content { - full_text.push_str(content); - let _ = tx - .send(StreamEvent::TextDelta { - text: content.clone(), - }) - .await; + _ => { + // Unknown event type — try content field as fallback + if let Some(ref content) = event.content { + full_text.push_str(content); + let _ = tx + .send(StreamEvent::TextDelta { + text: content.clone(), + }) + .await; + } } } } - } - Err(e) => { - // Not valid JSON — treat as raw text - warn!(line = %line, error = %e, "Non-JSON line from Claude CLI"); - full_text.push_str(&line); - let _ = tx - .send(StreamEvent::TextDelta { text: line }) - .await; + Err(e) => { + // Not valid JSON — treat as raw text + warn!(line = %line, error = %e, "Non-JSON line from Claude CLI"); + full_text.push_str(&line); + let _ = tx + .send(StreamEvent::TextDelta { text: line }) + .await; + } } } + }) + .await; + + // Clear PID tracking + self.active_pids.remove(&pid_label); + + if stream_result.is_err() { + warn!( + timeout_secs = self.message_timeout_secs, + model = %pid_label, + "Claude Code CLI streaming subprocess timed out, killing process" + ); + let _ = child.kill().await; + return Err(LlmError::Http(format!( + "Claude Code CLI streaming subprocess timed out after {}s — process killed", + self.message_timeout_secs + ))); } // Wait for process to finish @@ -413,7 +539,28 @@ impl LlmDriver for ClaudeCodeDriver { .map_err(|e| LlmError::Http(format!("Claude CLI wait failed: {e}")))?; if !status.success() { - warn!(code = ?status.code(), "Claude CLI exited with error"); + let code = status.code().unwrap_or(1); + // Read stderr for diagnostic info + let stderr_text = if let Some(mut err) = child.stderr.take() { + let mut buf = Vec::new(); + let _ = err.read_to_end(&mut buf).await; + String::from_utf8_lossy(&buf).trim().to_string() + } else { + String::new() + }; + warn!( + exit_code = code, + model = %pid_label, + stderr = %stderr_text, + "Claude Code CLI streaming subprocess exited with error" + ); + return Err(LlmError::Api { + status: code as u16, + message: format!( + "Claude Code CLI streaming exited with code {code}: {}", + if stderr_text.is_empty() { "no stderr" } else { &stderr_text } + ), + }); } let _ = tx @@ -517,7 +664,8 @@ mod tests { fn test_new_defaults_to_claude() { let driver = ClaudeCodeDriver::new(None, true); assert_eq!(driver.cli_path, "claude"); - assert!(driver.skip_permissions); + assert_eq!(driver.message_timeout_secs, DEFAULT_MESSAGE_TIMEOUT_SECS); + assert!(driver.active_pids().is_empty()); } #[test] @@ -533,9 +681,19 @@ mod tests { } #[test] - fn test_skip_permissions_disabled() { - let driver = ClaudeCodeDriver::new(None, false); - assert!(!driver.skip_permissions); + fn test_with_timeout() { + let driver = ClaudeCodeDriver::with_timeout(None, true, 600); + assert_eq!(driver.message_timeout_secs, 600); + assert_eq!(driver.cli_path, "claude"); + } + + #[test] + fn test_pid_map_shared() { + let driver = ClaudeCodeDriver::new(None, true); + let map = driver.pid_map(); + map.insert("test-agent".to_string(), 12345); + assert_eq!(driver.active_pids().len(), 1); + assert_eq!(driver.active_pids()[0], ("test-agent".to_string(), 12345)); } #[test]