diff --git a/src-tauri/src/cli/commands/check.rs b/src-tauri/src/cli/commands/check.rs new file mode 100644 index 0000000..a7a6fee --- /dev/null +++ b/src-tauri/src/cli/commands/check.rs @@ -0,0 +1,567 @@ +use crate::cli::ui::{create_table, error, highlight, info, success, warning}; +use crate::error::AppError; +use crate::t; +use clap::Subcommand; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::process::Command; + +/// CLI tools configuration for version checking +#[derive(Debug, Clone, Serialize, Deserialize)] +struct CliTool { + id: String, + label: String, + npm_package: String, +} + +/// Version check result for a CLI tool +#[derive(Debug, Clone, Serialize)] +struct VersionCheckResult { + id: String, + label: String, + current: Option, + latest: Option, + status: VersionStatus, + upgrade_cmd: Option, + #[serde(skip_serializing_if = "Option::is_none")] + error: Option, +} + +#[derive(Debug, Clone, Serialize, PartialEq)] +enum VersionStatus { + #[serde(rename = "latest")] + Latest, + #[serde(rename = "upgradable")] + Upgradable, + #[serde(rename = "not_installed")] + NotInstalled, + #[serde(rename = "unknown")] + Unknown, + #[serde(rename = "fetch_failed")] + FetchFailed, +} + +impl VersionStatus { + fn display(&self) -> &'static str { + match self { + VersionStatus::Latest => t!("Up to date", "最新"), + VersionStatus::Upgradable => t!("Upgradable", "可升级"), + VersionStatus::NotInstalled => t!("Not installed", "未安装"), + VersionStatus::Unknown => t!("Unknown", "未知"), + VersionStatus::FetchFailed => t!("Fetch failed", "获取失败"), + } + } +} + +#[derive(Subcommand)] +pub enum CheckCommand { + /// Check for CLI tool updates (Claude Code, Codex, Gemini, etc.) + #[command(alias = "update")] + Updates { + /// Tool ID to check (e.g., claude, codex, gemini). If not specified, checks all. 
+ tool: Option, + + /// Skip fetching latest versions (offline mode) + #[arg(long)] + offline: bool, + + /// Output in JSON format + #[arg(long)] + json: bool, + }, + + /// Upgrade CLI tools to latest version + Upgrade { + /// Tool ID to upgrade (e.g., claude, codex, gemini). If not specified, upgrades all. + tool: Option, + + /// Actually execute the upgrade (without this flag, only shows what would be done) + #[arg(long, short)] + yes: bool, + }, +} + +pub fn execute(cmd: CheckCommand, _app: Option) -> Result<(), AppError> { + match cmd { + CheckCommand::Updates { tool, offline, json } => check_updates(tool, offline, json), + CheckCommand::Upgrade { tool, yes } => upgrade_tools(tool, yes), + } +} + +/// Get the list of CLI tools to check +fn get_cli_tools() -> Vec { + vec![ + CliTool { + id: "claude".to_string(), + label: "Claude Code".to_string(), + npm_package: "@anthropic-ai/claude-code".to_string(), + }, + CliTool { + id: "codex".to_string(), + label: "Codex".to_string(), + npm_package: "@openai/codex".to_string(), + }, + CliTool { + id: "gemini".to_string(), + label: "Gemini".to_string(), + npm_package: "@google/gemini-cli".to_string(), + }, + CliTool { + id: "opencode".to_string(), + label: "OpenCode".to_string(), + npm_package: "opencode-ai".to_string(), + }, + CliTool { + id: "qwen".to_string(), + label: "Qwen Code".to_string(), + npm_package: "@qwen-code/qwen-code".to_string(), + }, + ] +} + +/// Get globally installed npm packages and their versions +fn get_npm_globals() -> HashMap { + let mut map = HashMap::new(); + + let output = Command::new("npm") + .args(["ls", "-g", "--depth=0", "--json"]) + .output(); + + let Ok(output) = output else { + return map; + }; + + if !output.status.success() { + return map; + } + + let Ok(stdout) = String::from_utf8(output.stdout) else { + return map; + }; + + #[derive(Deserialize)] + struct NpmLsOutput { + dependencies: Option>, + } + + #[derive(Deserialize)] + struct NpmPackageInfo { + version: Option, + } + + if let 
Ok(parsed) = serde_json::from_str::<NpmLsOutput>(&stdout) {
        if let Some(deps) = parsed.dependencies {
            for (name, info) in deps {
                if let Some(version) = info.version {
                    map.insert(name, version);
                }
            }
        }
    }

    map
}

/// Get the latest version of an npm package, as reported by `npm view`.
fn get_npm_latest_version(package: &str) -> Result<String, String> {
    // Try npmmirror first (faster in China), then fall back to the official registry.
    let registries = [
        "https://registry.npmmirror.com",
        "https://registry.npmjs.org",
    ];

    for registry in registries {
        let output = Command::new("npm")
            .args(["view", package, "version", "--registry", registry])
            .output();

        if let Ok(output) = output {
            if output.status.success() {
                let version = String::from_utf8_lossy(&output.stdout).trim().to_string();
                if !version.is_empty() {
                    return Ok(version);
                }
            }
        }
    }

    Err("Failed to fetch latest version".to_string())
}

/// Compare two semver versions.
/// Returns: -1 if a < b, 0 if a == b, 1 if a > b.
/// Handles pre-release suffixes: 1.0.0-beta.1 < 1.0.0.
fn compare_versions(a: &str, b: &str) -> i32 {
    let a = a.trim_start_matches('v');
    let b = b.trim_start_matches('v');

    // Split into core version and pre-release parts.
    let (a_ver, a_pre) = a.split_once('-').map(|(v, p)| (v, Some(p))).unwrap_or((a, None));
    let (b_ver, b_pre) = b.split_once('-').map(|(v, p)| (v, Some(p))).unwrap_or((b, None));

    let parse = |v: &str| -> Vec<u64> {
        v.split('.')
            .filter_map(|s| s.parse::<u64>().ok())
            .collect()
    };

    let a_parts = parse(a_ver);
    let b_parts = parse(b_ver);

    // Compare major.minor.patch numerically; missing components count as 0.
    for i in 0..3 {
        let a_val = a_parts.get(i).copied().unwrap_or(0);
        let b_val = b_parts.get(i).copied().unwrap_or(0);
        if a_val < b_val {
            return -1;
        }
        if a_val > b_val {
            return 1;
        }
    }

    // Same core version: a pre-release sorts before the release (1.0.0-beta < 1.0.0).
    match (a_pre, b_pre) {
        (Some(_), None) => -1,
        (None, Some(_)) => 1,
        // BUG FIX: plain lexical `cmp` ordered "beta.10" before "beta.9".
        // Compare dot-separated identifiers per semver precedence instead.
        (Some(a_p), Some(b_p)) => compare_prerelease(a_p, b_p),
        (None, None) => 0,
    }
}

/// Compare two pre-release strings per semver §11: numeric identifiers
/// compare numerically, numeric < alphanumeric, otherwise ASCII order;
/// a shorter identifier list has lower precedence when all else is equal.
fn compare_prerelease(a: &str, b: &str) -> i32 {
    let mut a_ids = a.split('.');
    let mut b_ids = b.split('.');
    loop {
        match (a_ids.next(), b_ids.next()) {
            (None, None) => return 0,
            (None, Some(_)) => return -1,
            (Some(_), None) => return 1,
            (Some(x), Some(y)) => {
                let ord = match (x.parse::<u64>(), y.parse::<u64>()) {
                    (Ok(xn), Ok(yn)) => xn.cmp(&yn),
                    (Ok(_), Err(_)) => std::cmp::Ordering::Less,
                    (Err(_), Ok(_)) => std::cmp::Ordering::Greater,
                    (Err(_), Err(_)) => x.cmp(y),
                };
                if ord != std::cmp::Ordering::Equal {
                    // Ordering is a fieldless enum with values -1/0/1.
                    return ord as i32;
                }
            }
        }
    }
}

fn check_updates(tool_id:
Option<String>, offline: bool, json_output: bool) -> Result<(), AppError> {
    let all_tools = get_cli_tools();
    // Match on either the short id or the case-insensitive label.
    let tools: Vec<CliTool> = if let Some(ref id) = tool_id {
        all_tools
            .into_iter()
            .filter(|t| t.id == *id || t.label.to_lowercase() == id.to_lowercase())
            .collect()
    } else {
        all_tools
    };

    if tools.is_empty() {
        println!(
            "{}",
            error(&format!(
                "Tool '{}' not found. Available: claude, codex, gemini, opencode, qwen",
                tool_id.unwrap_or_default()
            ))
        );
        return Ok(());
    }

    let npm_globals = get_npm_globals();

    let mut results: Vec<VersionCheckResult> = Vec::new();

    if !json_output {
        println!("\n{}", highlight("AI CLI Tools Version Check"));
        println!("{}", "═".repeat(60));
        println!();
    }

    // Use indicatif for progress if not JSON output
    let pb = if !json_output && !offline {
        let pb = indicatif::ProgressBar::new(tools.len() as u64);
        if let Ok(style) = indicatif::ProgressStyle::default_bar()
            .template("{spinner:.cyan} [{pos}/{len}] Checking {msg}...")
        {
            pb.set_style(style);
        }
        Some(pb)
    } else {
        None
    };

    for tool in &tools {
        if let Some(ref pb) = pb {
            pb.set_message(tool.label.clone());
        }

        let current = npm_globals.get(&tool.npm_package).cloned();

        let (latest, fetch_error) = if offline {
            (None, None)
        } else {
            match get_npm_latest_version(&tool.npm_package) {
                Ok(v) => (Some(v), None),
                Err(e) => (None, Some(e)),
            }
        };

        let status = if current.is_none() {
            VersionStatus::NotInstalled
        } else if offline {
            VersionStatus::Unknown
        } else if fetch_error.is_some() {
            VersionStatus::FetchFailed
        // FIX(mojibake): `&current` had been corrupted to `¤t` by
        // HTML-entity decoding of `&curren`.
        } else if let (Some(ref curr), Some(ref lat)) = (&current, &latest) {
            if compare_versions(curr, lat) < 0 {
                VersionStatus::Upgradable
            } else {
                VersionStatus::Latest
            }
        } else {
            VersionStatus::Unknown
        };

        // NOTE(review): `NotInstalled` implies `current.is_none()`, so this
        // condition is always true today; kept to preserve behavior and intent.
        let upgrade_cmd = if current.is_some() || status == VersionStatus::NotInstalled {
            Some(format!("npm i -g {}@latest", tool.npm_package))
        } else {
            None
        };

        results.push(VersionCheckResult {
id: tool.id.clone(),
            label: tool.label.clone(),
            current,
            latest,
            status,
            upgrade_cmd,
            error: fetch_error,
        });

        if let Some(ref pb) = pb {
            pb.inc(1);
        }
    }

    if let Some(pb) = pb {
        pb.finish_and_clear();
    }

    // Output results
    if json_output {
        let output = serde_json::json!({
            "title": "AI CLI Tools",
            "results": results
        });
        // Serializing an in-memory `json!` value cannot realistically fail;
        // `expect` documents that invariant instead of a bare `unwrap`.
        println!(
            "{}",
            serde_json::to_string_pretty(&output).expect("JSON value is always serializable")
        );
        return Ok(());
    }

    // Display table
    let mut table = create_table();
    table.set_header(vec!["CLI Tool", "Current", "Latest", "Status"]);

    for result in &results {
        let current_display = result.current.clone().unwrap_or_else(|| "-".to_string());
        let latest_display = result.latest.clone().unwrap_or_else(|| "-".to_string());
        let status_display = result.status.display();

        // FIX(mojibake): `&current_display` had been corrupted to `¤t_display`.
        table.add_row(vec![
            result.label.as_str(),
            &current_display,
            &latest_display,
            status_display,
        ]);
    }

    println!("{}", table);

    // Summary
    let upgradable_count = results
        .iter()
        .filter(|r| r.status == VersionStatus::Upgradable)
        .count();
    let not_installed_count = results
        .iter()
        .filter(|r| r.status == VersionStatus::NotInstalled)
        .count();
    let fetch_failed_count = results
        .iter()
        .filter(|r| r.status == VersionStatus::FetchFailed)
        .count();

    println!();

    if upgradable_count > 0 {
        println!(
            "{}",
            warning(&format!("⬆ {} tool(s) can be upgraded", upgradable_count))
        );
        println!(
            "{}",
            info("  Run `cc-switch check upgrade --yes` to upgrade all")
        );
    }

    if not_installed_count > 0 {
        println!(
            "{}",
            info(&format!("📦 {} tool(s) not installed", not_installed_count))
        );
    }

    if fetch_failed_count > 0 {
        println!(
            "{}",
            error(&format!(
                "⚠ {} tool(s) failed to fetch latest version",
                fetch_failed_count
            ))
        );
        println!("{}", info("  Check your network or npm registry settings"));
    }

    if upgradable_count == 0 && fetch_failed_count == 0 && !offline {
        let installed_count = results
            .iter()
            .filter(|r|
r.status == VersionStatus::Latest)
            .count();
        if installed_count > 0 {
            println!("{}", success("✓ All installed tools are up to date"));
        }
    }

    if offline {
        println!();
        println!(
            "{}",
            info("ℹ Offline mode: latest versions not checked. Remove --offline to check for updates.")
        );
    }

    Ok(())
}

/// Upgrade the selected tool (or all tools) to their latest npm versions.
///
/// Without `--yes` this is a dry run that only reports what would be done.
// NOTE(review): `tool_id`'s type parameter was lost in transit; restored as
// `Option<String>` to match `check_updates`.
fn upgrade_tools(tool_id: Option<String>, yes: bool) -> Result<(), AppError> {
    let tools = get_cli_tools();
    let npm_globals = get_npm_globals();

    // Filter tools to upgrade
    let tools_to_check: Vec<&CliTool> = if let Some(ref id) = tool_id {
        tools
            .iter()
            .filter(|t| t.id == *id || t.label.to_lowercase() == id.to_lowercase())
            .collect()
    } else {
        tools.iter().collect()
    };

    if tools_to_check.is_empty() {
        println!(
            "{}",
            error(&format!(
                "Tool '{}' not found. Available: claude, codex, gemini, opencode, qwen",
                tool_id.unwrap_or_default()
            ))
        );
        return Ok(());
    }

    // Check which tools need upgrade: (tool, current version, latest version).
    let mut upgradable: Vec<(&CliTool, String, String)> = Vec::new();

    println!("\n{}", highlight("Checking for upgrades..."));
    println!();

    for tool in tools_to_check {
        let current = npm_globals.get(&tool.npm_package);

        if current.is_none() {
            println!(
                "{}",
                info(&format!("  {} - not installed, will install", tool.label))
            );
            upgradable.push((tool, "-".to_string(), "latest".to_string()));
            continue;
        }

        match get_npm_latest_version(&tool.npm_package) {
            Ok(latest) => {
                let curr = current.unwrap();
                if compare_versions(curr, &latest) < 0 {
                    println!(
                        "{}",
                        warning(&format!(
                            "  {} - {} → {} (upgradable)",
                            tool.label, curr, latest
                        ))
                    );
                    upgradable.push((tool, curr.clone(), latest));
                } else {
                    println!(
                        "{}",
                        success(&format!("  {} - {} (up to date)", tool.label, curr))
                    );
                }
            }
            Err(_) => {
                println!(
                    "{}",
                    error(&format!(
                        "  {} - failed to check latest version",
                        tool.label
                    ))
                );
            }
        }
    }

    println!();

    if upgradable.is_empty() {
        println!("{}", success("✓
Nothing to upgrade")); + return Ok(()); + } + + println!( + "{}", + highlight(&format!("{} tool(s) to upgrade:", upgradable.len())) + ); + for (tool, _, _) in &upgradable { + println!(" - {} (npm i -g {}@latest)", tool.label, tool.npm_package); + } + println!(); + + if !yes { + println!( + "{}", + warning("Add --yes flag to actually execute the upgrades") + ); + return Ok(()); + } + + // Execute upgrades with registry fallback (npmmirror first, then official) + let registries = [ + ("npmmirror", "https://registry.npmmirror.com"), + ("npmjs.org", "https://registry.npmjs.org"), + ]; + + for (tool, _, _) in &upgradable { + println!("{}", info(&format!("Upgrading {}...", tool.label))); + + let pkg = format!("{}@latest", tool.npm_package); + let mut upgraded = false; + + for (reg_name, reg_url) in ®istries { + let status = Command::new("npm") + .args(["i", "-g", &pkg, "--registry", reg_url]) + .status(); + + match status { + Ok(s) if s.success() => { + println!("{}", success(&format!(" ✓ {} upgraded successfully", tool.label))); + upgraded = true; + break; + } + _ => { + println!( + "{}", + warning(&format!(" ⚠ {} registry failed, trying next...", reg_name)) + ); + } + } + } + + if !upgraded { + println!("{}", error(&format!(" ✗ Failed to upgrade {} (all registries failed)", tool.label))); + } + } + + println!(); + println!("{}", success("Done!")); + + Ok(()) +} diff --git a/src-tauri/src/cli/commands/memory.rs b/src-tauri/src/cli/commands/memory.rs new file mode 100644 index 0000000..ab8ecaf --- /dev/null +++ b/src-tauri/src/cli/commands/memory.rs @@ -0,0 +1,473 @@ +use clap::{Subcommand, ValueEnum}; +use std::io::{self, BufRead}; + +use crate::app_config::AppType; +use crate::cli::ui::{create_table, highlight, info, success, warning}; +use crate::error::AppError; +use crate::services::memory::{MemoryService, NewObservation, ObservationType}; + +#[derive(Subcommand)] +pub enum MemoryCommand { + /// Add a new observation + Add { + /// Title of the observation + title: 
String, + /// Content of the observation + #[arg(short, long)] + content: Option, + /// Type of observation (decision, error, pattern, preference, general) + #[arg(short = 't', long, value_enum, default_value = "general")] + r#type: ObservationTypeArg, + /// Comma-separated tags + #[arg(long)] + tags: Option, + /// Project directory this observation relates to + #[arg(short, long)] + project: Option, + }, + /// List observations + List { + /// Maximum number of observations to show + #[arg(short, long, default_value = "20")] + limit: i64, + /// Filter by type + #[arg(short = 't', long, value_enum)] + r#type: Option, + /// Filter by project directory + #[arg(short, long)] + project: Option, + }, + /// Show a specific observation + Show { + /// Observation ID + id: i64, + }, + /// Search observations using full-text search + Search { + /// Search query + query: String, + /// Maximum results + #[arg(short, long, default_value = "10")] + limit: i64, + }, + /// Delete an observation + Delete { + /// Observation ID + id: i64, + }, + /// Show memory statistics + Stats, + /// Get context with progressive disclosure + Context { + /// Optional search query + query: Option, + /// Maximum tokens in context + #[arg(long, default_value = "4000")] + max_tokens: i32, + /// Project directory for context + #[arg(short, long)] + project: Option, + }, + /// Manage Claude Code hooks integration + #[command(subcommand)] + Hooks(HooksCommand), + /// List recent sessions + Sessions { + /// Maximum number of sessions to show + #[arg(short, long, default_value = "10")] + limit: i64, + }, +} + +#[derive(Clone, Copy, ValueEnum)] +pub enum HookType { + SessionStart, + PostToolUse, +} + +#[derive(Subcommand)] +pub enum HooksCommand { + /// Register hooks in Claude Code settings + Register, + /// Unregister hooks from Claude Code settings + Unregister, + /// Check hook registration status + Status, + /// Process hook event (internal use, called by Claude Code hooks) + Ingest { + /// Which hook 
triggered this
        #[arg(long)]
        hook: HookType,
    },
}

#[derive(Clone, Copy, ValueEnum)]
pub enum ObservationTypeArg {
    Decision,
    Error,
    Pattern,
    Preference,
    General,
}

// NOTE(review): the source impl parameter was lost in transit; restored from
// the usage `r#type.into()` / `r#type.map(Into::into)` in `execute` below.
impl From<ObservationTypeArg> for ObservationType {
    fn from(arg: ObservationTypeArg) -> Self {
        match arg {
            ObservationTypeArg::Decision => ObservationType::Decision,
            ObservationTypeArg::Error => ObservationType::Error,
            ObservationTypeArg::Pattern => ObservationType::Pattern,
            ObservationTypeArg::Preference => ObservationType::Preference,
            ObservationTypeArg::General => ObservationType::General,
        }
    }
}

/// Entry point for `cc-switch memory …`; dispatches to the sub-handlers.
pub fn execute(cmd: MemoryCommand, _app: Option<AppType>) -> Result<(), AppError> {
    match cmd {
        MemoryCommand::Add {
            title,
            content,
            r#type,
            tags,
            project,
        } => add_observation(title, content, r#type.into(), tags, project),
        MemoryCommand::List {
            limit,
            r#type,
            project,
        } => list_observations(limit, r#type.map(Into::into), project),
        MemoryCommand::Show { id } => show_observation(id),
        MemoryCommand::Search { query, limit } => search_observations(&query, limit),
        MemoryCommand::Delete { id } => delete_observation(id),
        MemoryCommand::Stats => show_stats(),
        MemoryCommand::Context {
            query,
            max_tokens,
            project,
        } => show_context(query.as_deref(), max_tokens, project.as_deref()),
        MemoryCommand::Hooks(hooks_cmd) => execute_hooks(hooks_cmd),
        MemoryCommand::Sessions { limit } => list_sessions(limit),
    }
}

/// Create a new observation from CLI arguments and report its id.
fn add_observation(
    title: String,
    content: Option<String>,
    observation_type: ObservationType,
    tags: Option<String>,
    project: Option<String>,
) -> Result<(), AppError> {
    let content = content.unwrap_or_default();
    // Comma-separated tag list → trimmed Vec<String>.
    let tags: Vec<String> = tags
        .map(|t| t.split(',').map(|s| s.trim().to_string()).collect())
        .unwrap_or_default();

    let obs = MemoryService::add_observation(NewObservation {
        session_id: None,
        title: title.clone(),
        content,
        observation_type,
        tags,
        project_dir: project,
    })?;

    println!(
        "{}",
        success(&format!("Added observation #{}
'{}'", obs.id, title)) + ); + Ok(()) +} + +fn list_observations( + limit: i64, + observation_type: Option, + project: Option, +) -> Result<(), AppError> { + let observations = + MemoryService::list_observations(Some(limit), observation_type, project.as_deref())?; + + if observations.is_empty() { + println!("{}", info("No observations found.")); + return Ok(()); + } + + let mut table = create_table(); + table.set_header(vec!["ID", "Type", "Title", "Tokens", "Created"]); + + for obs in observations { + table.add_row(vec![ + obs.id.to_string(), + obs.observation_type.to_string(), + truncate(&obs.title, 40), + obs.tokens.to_string(), + obs.created_at.format("%Y-%m-%d %H:%M").to_string(), + ]); + } + + println!("{}", table); + Ok(()) +} + +fn show_observation(id: i64) -> Result<(), AppError> { + let obs = MemoryService::get_observation(id)?; + + match obs { + Some(obs) => { + println!("{}", highlight(&format!("Observation #{}", obs.id))); + println!("Title: {}", obs.title); + println!("Type: {}", obs.observation_type); + println!("Tokens: {}", obs.tokens); + println!("Created: {}", obs.created_at.format("%Y-%m-%d %H:%M:%S")); + if !obs.tags.is_empty() { + println!("Tags: {}", obs.tags.join(", ")); + } + if let Some(ref proj) = obs.project_dir { + println!("Project: {}", proj); + } + println!("\n{}", highlight("Content:")); + println!("{}", obs.content); + Ok(()) + } + None => { + println!("{}", warning(&format!("Observation #{} not found", id))); + Ok(()) + } + } +} + +fn search_observations(query: &str, limit: i64) -> Result<(), AppError> { + let results = MemoryService::search(query, Some(limit))?; + + if results.is_empty() { + println!("{}", info(&format!("No results for '{}'", query))); + return Ok(()); + } + + println!( + "{}", + highlight(&format!("Found {} result(s) for '{}':", results.len(), query)) + ); + println!(); + + let mut table = create_table(); + table.set_header(vec!["ID", "Type", "Title", "Tokens"]); + + for obs in results { + table.add_row(vec![ + 
obs.id.to_string(),
            obs.observation_type.to_string(),
            truncate(&obs.title, 50),
            obs.tokens.to_string(),
        ]);
    }

    println!("{}", table);
    Ok(())
}

/// Remove an observation by id, reporting whether anything was deleted.
fn delete_observation(id: i64) -> Result<(), AppError> {
    match MemoryService::delete_observation(id)? {
        true => println!("{}", success(&format!("Deleted observation #{}", id))),
        false => println!("{}", warning(&format!("Observation #{} not found", id))),
    }
    Ok(())
}

/// Print aggregate statistics about the memory database.
fn show_stats() -> Result<(), AppError> {
    let stats = MemoryService::stats()?;

    println!("{}", highlight("Memory Statistics"));
    println!();
    println!("Total observations: {}", stats.total_observations);
    println!("Total sessions: {}", stats.total_sessions);
    println!("Total tokens: {}", stats.total_tokens);

    if !stats.observations_by_type.is_empty() {
        println!();
        println!("{}", highlight("By Type:"));
        for (obs_type, count) in &stats.observations_by_type {
            println!("  {}: {}", obs_type, count);
        }
    }

    if let Some(oldest) = stats.oldest_observation {
        println!();
        println!("Oldest: {}", oldest.format("%Y-%m-%d %H:%M:%S"));
    }
    if let Some(newest) = stats.newest_observation {
        println!("Newest: {}", newest.format("%Y-%m-%d %H:%M:%S"));
    }

    Ok(())
}

/// Render progressively-disclosed context items for the given query/project.
fn show_context(query: Option<&str>, max_tokens: i32, project: Option<&str>) -> Result<(), AppError> {
    let context = MemoryService::get_context(query, max_tokens, project)?;

    if context.is_empty() {
        println!("{}", info("No relevant context found."));
        return Ok(());
    }

    let total_tokens: i32 = context.iter().map(|c| c.observation.tokens).sum();

    println!(
        "{}",
        highlight(&format!(
            "Context ({} items, {} tokens):",
            context.len(),
            total_tokens
        ))
    );
    println!();

    for (i, item) in context.iter().enumerate() {
        // Priority 1 = FTS match, 2 = project match, anything else = recency.
        let priority_label = match item.priority {
            1 => "[FTS]",
            2 => "[Project]",
            _ => "[Recent]",
        };

        println!(
            "{}. 
{} [{}] {}", + i + 1, + priority_label, + item.observation.observation_type, + item.observation.title + ); + println!( + " {}", + truncate(&item.observation.content.replace('\n', " "), 80) + ); + println!(); + } + + Ok(()) +} + +fn execute_hooks(cmd: HooksCommand) -> Result<(), AppError> { + match cmd { + HooksCommand::Register => { + MemoryService::register_hooks()?; + println!("{}", success("Hooks registered in Claude Code settings")); + Ok(()) + } + HooksCommand::Unregister => { + MemoryService::unregister_hooks()?; + println!("{}", success("Hooks unregistered from Claude Code settings")); + Ok(()) + } + HooksCommand::Status => { + let status = MemoryService::hooks_status()?; + + println!("{}", highlight("Hook Status")); + println!( + "Registered: {}", + if status.registered { "Yes" } else { "No" } + ); + println!( + "SessionStart: {}", + if status.session_start { "Yes" } else { "No" } + ); + println!( + "PostToolUse: {}", + if status.post_tool_use { "Yes" } else { "No" } + ); + Ok(()) + } + HooksCommand::Ingest { hook } => { + // Read event JSON from stdin + let stdin = io::stdin(); + let mut input = String::new(); + for line in stdin.lock().lines() { + let line = line.map_err(|e| AppError::Message(format!("Failed to read stdin: {e}")))?; + input.push_str(&line); + } + + if input.trim().is_empty() { + return Ok(()); + } + + match MemoryService::ingest_hook_event(&input, hook) { + Ok(Some(output)) => { + // Output context to stdout for Claude to see + print!("{}", output); + } + Ok(None) => {} + Err(e) => { + // Write to stderr so the user can see hook failures + // without interfering with stdout used by Claude + eprintln!("[cc-switch memory] hook ingest error: {}", e); + log::warn!("Hook ingest error: {}", e); + } + } + Ok(()) + } + } +} + +fn list_sessions(limit: i64) -> Result<(), AppError> { + let sessions = MemoryService::list_sessions(Some(limit))?; + + if sessions.is_empty() { + println!("{}", info("No sessions found.")); + return Ok(()); + } + + let mut 
table = create_table();
    table.set_header(vec!["ID", "App", "Project", "Started", "Ended", "Summary"]);

    for session in sessions {
        table.add_row(vec![
            session.id.to_string(),
            session.app,
            session
                .project_dir
                .map(|p| truncate(&p, 30))
                .unwrap_or_else(|| "-".to_string()),
            session.started_at.format("%Y-%m-%d %H:%M").to_string(),
            session
                .ended_at
                .map(|e| e.format("%H:%M").to_string())
                .unwrap_or_else(|| "ongoing".to_string()),
            session
                .summary
                .map(|s| truncate(&s, 30))
                .unwrap_or_else(|| "-".to_string()),
        ]);
    }

    println!("{}", table);
    Ok(())
}

/// Truncate `s` to at most `max_len` bytes, appending "..." when cut.
/// Always cuts on a char boundary, so multi-byte text never panics.
fn truncate(s: &str, max_len: usize) -> String {
    if s.len() <= max_len {
        return s.to_string();
    }

    let ellipsis = "...";
    if max_len <= ellipsis.len() {
        return ellipsis.chars().take(max_len).collect();
    }

    // Walk back from the byte budget to the nearest char boundary; this is
    // equivalent to the original char_indices scan but simpler and O(1)-ish.
    let mut end = max_len - ellipsis.len();
    while !s.is_char_boundary(end) {
        end -= 1;
    }

    format!("{}{}", &s[..end], ellipsis)
}
diff --git a/src-tauri/src/cli/commands/mod.rs b/src-tauri/src/cli/commands/mod.rs
index 0c13533..7cbb906 100644
--- a/src-tauri/src/cli/commands/mod.rs
+++ b/src-tauri/src/cli/commands/mod.rs
@@ -1,6 +1,8 @@
+pub mod check;
 pub mod config;
 pub mod env;
 pub mod mcp;
+pub mod memory;
 pub mod prompts;
 pub mod provider;
 pub mod provider_input;
diff --git a/src-tauri/src/cli/commands/skills.rs b/src-tauri/src/cli/commands/skills.rs
index d64add5..503a5e2 100644
--- a/src-tauri/src/cli/commands/skills.rs
+++ b/src-tauri/src/cli/commands/skills.rs
@@ -37,6 +37,10 @@ pub enum SkillsCommand {
         /// Skill directory or id
         spec: String,
     },
+    /// Enable all installed skills for the selected app
+    EnableAll,
+    /// Disable all skills for the selected app
+    DisableAll,
     /// Sync enabled skills to app skills dirs
     Sync,
     /// Scan unmanaged skills in app skills dirs
@@ -88,6 +92,8 @@ pub fn execute(cmd: SkillsCommand, app: Option<AppType>) -> Result<(), AppError>
SkillsCommand::Uninstall { spec } => uninstall_skill(&spec), SkillsCommand::Enable { spec } => toggle_skill(&app_type, &spec, true), SkillsCommand::Disable { spec } => toggle_skill(&app_type, &spec, false), + SkillsCommand::EnableAll => enable_all(&app_type), + SkillsCommand::DisableAll => disable_all(&app_type), SkillsCommand::Sync => sync_skills(app.as_ref()), SkillsCommand::ScanUnmanaged => scan_unmanaged(), SkillsCommand::ImportFromApps { directories } => import_from_apps(directories), @@ -192,6 +198,58 @@ fn toggle_skill(app_type: &AppType, spec: &str, enabled: bool) -> Result<(), App Ok(()) } +fn enable_all(app_type: &AppType) -> Result<(), AppError> { + let skills = SkillService::list_installed()?; + if skills.is_empty() { + println!("{}", info("No installed skills found.")); + return Ok(()); + } + + let mut count = 0; + for skill in &skills { + if !skill.apps.is_enabled_for(app_type) { + SkillService::toggle_app(&skill.directory, app_type, true)?; + count += 1; + } + } + + println!( + "{}", + success(&format!( + "✓ Enabled {} skill(s) for {}", + count, + app_type.as_str() + )) + ); + Ok(()) +} + +fn disable_all(app_type: &AppType) -> Result<(), AppError> { + let skills = SkillService::list_installed()?; + if skills.is_empty() { + println!("{}", info("No installed skills found.")); + return Ok(()); + } + + let mut count = 0; + for skill in &skills { + if skill.apps.is_enabled_for(app_type) { + SkillService::toggle_app(&skill.directory, app_type, false)?; + count += 1; + } + } + + println!( + "{}", + success(&format!( + "✓ Disabled {} skill(s) for {}", + count, + app_type.as_str() + )) + ); + Ok(()) +} + fn sync_skills(app: Option<&AppType>) -> Result<(), AppError> { SkillService::sync_all_enabled(app)?; println!("{}", success("✓ Skills synced successfully")); diff --git a/src-tauri/src/cli/mod.rs b/src-tauri/src/cli/mod.rs index 7a4ca04..9121073 100644 --- a/src-tauri/src/cli/mod.rs +++ b/src-tauri/src/cli/mod.rs @@ -59,6 +59,14 @@ pub enum Commands { /// 
Update cc-switch binary to latest release Update(commands::update::UpdateCommand), + /// Manage memory (observations, context, hooks) + #[command(subcommand)] + Memory(commands::memory::MemoryCommand), + + /// Check for CLI tool updates (Claude Code, Codex, Gemini) + #[command(subcommand)] + Check(commands::check::CheckCommand), + /// Enter interactive mode #[command(alias = "ui")] Interactive, diff --git a/src-tauri/src/lib.rs b/src-tauri/src/lib.rs index 40d613b..1412ff5 100644 --- a/src-tauri/src/lib.rs +++ b/src-tauri/src/lib.rs @@ -42,8 +42,8 @@ pub use mcp::{ }; pub use provider::{Provider, ProviderMeta}; pub use services::{ - ConfigService, EndpointLatency, McpService, PromptService, ProviderService, SkillService, - SpeedtestService, SyncDecision, WebDavSyncService, WebDavSyncSummary, + ConfigService, EndpointLatency, McpService, MemoryService, PromptService, ProviderService, + SkillService, SpeedtestService, SyncDecision, WebDavSyncService, WebDavSyncSummary, }; pub use settings::{ get_skip_claude_onboarding, get_webdav_sync_settings, set_skip_claude_onboarding, diff --git a/src-tauri/src/main.rs b/src-tauri/src/main.rs index 72af700..d79c121 100644 --- a/src-tauri/src/main.rs +++ b/src-tauri/src/main.rs @@ -37,6 +37,8 @@ fn run(cli: Cli) -> Result<(), AppError> { Some(Commands::Config(cmd)) => cc_switch_lib::cli::commands::config::execute(cmd, cli.app), Some(Commands::Env(cmd)) => cc_switch_lib::cli::commands::env::execute(cmd, cli.app), Some(Commands::Update(cmd)) => cc_switch_lib::cli::commands::update::execute(cmd), + Some(Commands::Memory(cmd)) => cc_switch_lib::cli::commands::memory::execute(cmd, cli.app), + Some(Commands::Check(cmd)) => cc_switch_lib::cli::commands::check::execute(cmd, cli.app), Some(Commands::Completions { shell }) => { cc_switch_lib::cli::generate_completions(shell); Ok(()) diff --git a/src-tauri/src/services/memory.rs b/src-tauri/src/services/memory.rs new file mode 100644 index 0000000..50dfb7d --- /dev/null +++ 
b/src-tauri/src/services/memory.rs @@ -0,0 +1,1266 @@ +//! Memory service for session context capture and semantic search. +//! +//! Uses SQLite + FTS5 for full-text search. Database is stored in the app config directory +//! (e.g. `~/.cc-switch/memory.db`) as determined by `get_app_config_dir()`. + +use chrono::{DateTime, TimeZone, Utc}; +use rusqlite::{params, Connection, OptionalExtension}; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; +use std::sync::{Mutex, Once, OnceLock}; + +use crate::app_config::AppType; +use crate::config::get_app_config_dir; +use crate::error::AppError; + +// ============================================================================ +// Data Structures +// ============================================================================ + +/// Observation types for categorizing memories +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum ObservationType { + /// Important decisions made during development + Decision, + /// Errors encountered and their solutions + Error, + /// Code patterns and conventions discovered + Pattern, + /// User preferences learned from interactions + Preference, + /// General observations + #[default] + General, +} + +impl ObservationType { + pub fn as_str(&self) -> &'static str { + match self { + Self::Decision => "decision", + Self::Error => "error", + Self::Pattern => "pattern", + Self::Preference => "preference", + Self::General => "general", + } + } + + pub fn from_str(s: &str) -> Self { + match s.to_lowercase().as_str() { + "decision" => Self::Decision, + "error" => Self::Error, + "pattern" => Self::Pattern, + "preference" => Self::Preference, + _ => Self::General, + } + } +} + +impl std::fmt::Display for ObservationType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_str()) + } +} + +impl std::str::FromStr for ObservationType { + type Err = (); + + fn from_str(s: &str) 
-> Result { + Ok(Self::from_str(s)) + } +} + +/// A session record +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Session { + pub id: i64, + pub app: String, + pub project_dir: Option, + pub started_at: DateTime, + pub ended_at: Option>, + pub summary: Option, +} + +/// An observation record +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Observation { + pub id: i64, + pub session_id: Option, + pub title: String, + pub content: String, + pub observation_type: ObservationType, + pub tags: Vec, + pub tokens: i32, + /// Reserved for future relevance ranking. Currently always 1.0. + pub relevance_score: f64, + pub created_at: DateTime, + pub project_dir: Option, +} + +/// Input for creating a new observation +#[derive(Debug, Clone)] +pub struct NewObservation { + pub session_id: Option, + pub title: String, + pub content: String, + pub observation_type: ObservationType, + pub tags: Vec, + pub project_dir: Option, +} + +/// Context item for progressive disclosure +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ContextItem { + pub observation: Observation, + pub priority: u8, // 1 = highest (FTS match), 2 = project match, 3 = recent + pub match_reason: String, +} + +/// Statistics about the memory database +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct MemoryStats { + pub total_observations: i64, + pub total_sessions: i64, + pub total_tokens: i64, + pub observations_by_type: Vec<(String, i64)>, + pub oldest_observation: Option>, + pub newest_observation: Option>, +} + +// ============================================================================ +// Database Connection Management +// ============================================================================ + +fn get_db_path() -> PathBuf { + get_app_config_dir().join("memory.db") +} + +static DB_CONNECTION: OnceLock> = OnceLock::new(); +static DB_INIT: Once = Once::new(); +static DB_INIT_ERROR: OnceLock = OnceLock::new(); + +fn get_connection() -> Result<&'static 
Mutex, AppError> { + DB_INIT.call_once(|| { + let result = (|| -> Result { + let path = get_db_path(); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent).map_err(|e| AppError::io(parent, e))?; + } + let conn = Connection::open(&path) + .map_err(|e| AppError::Message(format!("Failed to open memory database: {e}")))?; + init_schema(&conn)?; + Ok(conn) + })(); + + match result { + Ok(conn) => { let _ = DB_CONNECTION.set(Mutex::new(conn)); } + Err(e) => { let _ = DB_INIT_ERROR.set(e.to_string()); } + } + }); + + DB_CONNECTION.get().ok_or_else(|| { + let msg = DB_INIT_ERROR.get().map(|s| s.as_str()).unwrap_or("Unknown error"); + AppError::Message(format!("Memory database initialization failed: {msg}")) + }) +} + +fn init_schema(conn: &Connection) -> Result<(), AppError> { + conn.execute_batch( + r#" + -- Sessions table + CREATE TABLE IF NOT EXISTS sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + app TEXT NOT NULL, + project_dir TEXT, + started_at INTEGER NOT NULL, + ended_at INTEGER, + summary TEXT + ); + + -- Observations table + CREATE TABLE IF NOT EXISTS observations ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + session_id INTEGER REFERENCES sessions(id), + title TEXT NOT NULL, + content TEXT NOT NULL, + observation_type TEXT NOT NULL DEFAULT 'general', + tags TEXT NOT NULL DEFAULT '', + tokens INTEGER NOT NULL DEFAULT 0, + relevance_score REAL NOT NULL DEFAULT 1.0, + created_at INTEGER NOT NULL, + project_dir TEXT + ); + + -- Create indexes + CREATE INDEX IF NOT EXISTS idx_observations_session ON observations(session_id); + CREATE INDEX IF NOT EXISTS idx_observations_type ON observations(observation_type); + CREATE INDEX IF NOT EXISTS idx_observations_created ON observations(created_at DESC); + CREATE INDEX IF NOT EXISTS idx_observations_project ON observations(project_dir); + CREATE INDEX IF NOT EXISTS idx_sessions_app ON sessions(app); + CREATE INDEX IF NOT EXISTS idx_sessions_started ON sessions(started_at DESC); + + -- FTS5 virtual table 
for full-text search + CREATE VIRTUAL TABLE IF NOT EXISTS observations_fts USING fts5( + title, + content, + tags, + content='observations', + content_rowid='id' + ); + + -- Triggers to keep FTS5 in sync + CREATE TRIGGER IF NOT EXISTS observations_ai AFTER INSERT ON observations BEGIN + INSERT INTO observations_fts(rowid, title, content, tags) + VALUES (new.id, new.title, new.content, new.tags); + END; + + CREATE TRIGGER IF NOT EXISTS observations_ad AFTER DELETE ON observations BEGIN + INSERT INTO observations_fts(observations_fts, rowid, title, content, tags) + VALUES ('delete', old.id, old.title, old.content, old.tags); + END; + + CREATE TRIGGER IF NOT EXISTS observations_au AFTER UPDATE ON observations BEGIN + INSERT INTO observations_fts(observations_fts, rowid, title, content, tags) + VALUES ('delete', old.id, old.title, old.content, old.tags); + INSERT INTO observations_fts(rowid, title, content, tags) + VALUES (new.id, new.title, new.content, new.tags); + END; + "#, + ) + .map_err(|e| AppError::Message(format!("Failed to initialize memory schema: {e}")))?; + + Ok(()) +} + +// ============================================================================ +// Token Estimation +// ============================================================================ + +/// Safely convert a Unix timestamp to DateTime, returning an error for invalid values. +fn timestamp_to_datetime(ts: i64) -> Result, rusqlite::Error> { + Utc.timestamp_opt(ts, 0) + .single() + .ok_or_else(|| rusqlite::Error::IntegralValueOutOfRange(0, ts)) +} + +/// Estimate token count for text. +/// +/// Uses a rough heuristic of ~4 bytes per token. This is an approximation and may +/// be inaccurate for CJK text or code. Callers should not rely on exact counts. 
+fn estimate_tokens(text: &str) -> i32 { + (text.len() as f64 / 4.0).ceil() as i32 +} + +// ============================================================================ +// MemoryService +// ============================================================================ + +pub struct MemoryService; + +impl MemoryService { + // ------------------------------------------------------------------------- + // Observation CRUD + // ------------------------------------------------------------------------- + + /// Add a new observation + pub fn add_observation(obs: NewObservation) -> Result { + let conn = get_connection()?.lock()?; + let now = Utc::now(); + let tags_str = obs.tags.join(","); + let full_text = format!("{} {}", obs.title, obs.content); + let tokens = estimate_tokens(&full_text); + + conn.execute( + r#" + INSERT INTO observations (session_id, title, content, observation_type, tags, tokens, relevance_score, created_at, project_dir) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9) + "#, + params![ + obs.session_id, + obs.title, + obs.content, + obs.observation_type.as_str(), + tags_str, + tokens, + 1.0, + now.timestamp(), + obs.project_dir, + ], + ) + .map_err(|e| AppError::Message(format!("Failed to add observation: {e}")))?; + + let id = conn.last_insert_rowid(); + + Ok(Observation { + id, + session_id: obs.session_id, + title: obs.title, + content: obs.content, + observation_type: obs.observation_type, + tags: obs.tags, + tokens, + relevance_score: 1.0, + created_at: now, + project_dir: obs.project_dir, + }) + } + + /// Get an observation by ID + pub fn get_observation(id: i64) -> Result, AppError> { + let conn = get_connection()?.lock()?; + + conn.query_row( + r#" + SELECT id, session_id, title, content, observation_type, tags, tokens, relevance_score, created_at, project_dir + FROM observations WHERE id = ?1 + "#, + params![id], + |row| { + Ok(Observation { + id: row.get(0)?, + session_id: row.get(1)?, + title: row.get(2)?, + content: row.get(3)?, + 
observation_type: ObservationType::from_str(&row.get::<_, String>(4)?), + tags: row + .get::<_, String>(5)? + .split(',') + .filter(|s| !s.is_empty()) + .map(String::from) + .collect(), + tokens: row.get(6)?, + relevance_score: row.get(7)?, + created_at: timestamp_to_datetime(row.get(8)?)?, + project_dir: row.get(9)?, + }) + }, + ) + .optional() + .map_err(|e| AppError::Message(format!("Failed to get observation: {e}"))) + } + + /// List observations with optional filters + pub fn list_observations( + limit: Option, + observation_type: Option, + project_dir: Option<&str>, + ) -> Result, AppError> { + let conn = get_connection()?.lock()?; + let limit = limit.unwrap_or(50); + + let mut sql = String::from( + r#" + SELECT id, session_id, title, content, observation_type, tags, tokens, relevance_score, created_at, project_dir + FROM observations + WHERE 1=1 + "#, + ); + + let mut params_vec: Vec> = vec![]; + + if let Some(ref obs_type) = observation_type { + sql.push_str(" AND observation_type = ?"); + params_vec.push(Box::new(obs_type.as_str().to_string())); + } + + if let Some(ref proj) = project_dir { + sql.push_str(" AND project_dir = ?"); + params_vec.push(Box::new(proj.to_string())); + } + + sql.push_str(" ORDER BY created_at DESC LIMIT ?"); + params_vec.push(Box::new(limit)); + + let mut stmt = conn + .prepare(&sql) + .map_err(|e| AppError::Message(format!("Failed to prepare query: {e}")))?; + + let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(|p| p.as_ref()).collect(); + + let rows = stmt + .query_map(params_refs.as_slice(), |row| { + Ok(Observation { + id: row.get(0)?, + session_id: row.get(1)?, + title: row.get(2)?, + content: row.get(3)?, + observation_type: ObservationType::from_str(&row.get::<_, String>(4)?), + tags: row + .get::<_, String>(5)? 
+ .split(',') + .filter(|s| !s.is_empty()) + .map(String::from) + .collect(), + tokens: row.get(6)?, + relevance_score: row.get(7)?, + created_at: timestamp_to_datetime(row.get(8)?)?, + project_dir: row.get(9)?, + }) + }) + .map_err(|e| AppError::Message(format!("Failed to list observations: {e}")))?; + + let mut results = Vec::new(); + for row in rows { + results.push(row.map_err(|e| AppError::Message(format!("Row error: {e}")))?); + } + + Ok(results) + } + + /// Delete an observation + pub fn delete_observation(id: i64) -> Result { + let conn = get_connection()?.lock()?; + + let rows = conn + .execute("DELETE FROM observations WHERE id = ?1", params![id]) + .map_err(|e| AppError::Message(format!("Failed to delete observation: {e}")))?; + + Ok(rows > 0) + } + + // ------------------------------------------------------------------------- + // FTS5 Search + // ------------------------------------------------------------------------- + + /// Search observations using FTS5 + pub fn search(query: &str, limit: Option) -> Result, AppError> { + let conn = get_connection()?.lock()?; + let limit = limit.unwrap_or(20); + + // Sanitize query for FTS5: strip special characters and wrap each token in quotes + let fts_query: String = query + .split_whitespace() + .map(|token| { + let clean: String = token.chars().filter(|c| c.is_alphanumeric() || *c == '_' || *c == '-').collect(); + format!("\"{}\"", clean) + }) + .filter(|t| t != "\"\"") + .collect::>() + .join(" "); + + if fts_query.is_empty() { + return Ok(Vec::new()); + } + + let mut stmt = conn + .prepare( + r#" + SELECT o.id, o.session_id, o.title, o.content, o.observation_type, o.tags, + o.tokens, o.relevance_score, o.created_at, o.project_dir + FROM observations o + JOIN observations_fts fts ON o.id = fts.rowid + WHERE observations_fts MATCH ?1 + ORDER BY rank + LIMIT ?2 + "#, + ) + .map_err(|e| AppError::Message(format!("Failed to prepare search: {e}")))?; + + let rows = stmt + .query_map(params![fts_query, limit], 
|row| { + Ok(Observation { + id: row.get(0)?, + session_id: row.get(1)?, + title: row.get(2)?, + content: row.get(3)?, + observation_type: ObservationType::from_str(&row.get::<_, String>(4)?), + tags: row + .get::<_, String>(5)? + .split(',') + .filter(|s| !s.is_empty()) + .map(String::from) + .collect(), + tokens: row.get(6)?, + relevance_score: row.get(7)?, + created_at: timestamp_to_datetime(row.get(8)?)?, + project_dir: row.get(9)?, + }) + }) + .map_err(|e| AppError::Message(format!("Search failed: {e}")))?; + + let mut results = Vec::new(); + for row in rows { + results.push(row.map_err(|e| AppError::Message(format!("Row error: {e}")))?); + } + + Ok(results) + } + + // ------------------------------------------------------------------------- + // Progressive Disclosure + // ------------------------------------------------------------------------- + + /// Get context with progressive disclosure and token budget + pub fn get_context( + query: Option<&str>, + max_tokens: i32, + project_dir: Option<&str>, + ) -> Result, AppError> { + let mut items: Vec = Vec::new(); + let mut used_tokens = 0; + let mut seen_ids = std::collections::HashSet::new(); + + // Layer 1: FTS5 matches (highest priority) + if let Some(q) = query { + if !q.trim().is_empty() { + let search_results = Self::search(q, Some(10))?; + for obs in search_results { + if used_tokens + obs.tokens > max_tokens { + continue; + } + if seen_ids.contains(&obs.id) { + continue; + } + seen_ids.insert(obs.id); + used_tokens += obs.tokens; + items.push(ContextItem { + observation: obs, + priority: 1, + match_reason: "FTS match".to_string(), + }); + } + } + } + + // Layer 2: Project-specific observations + if let Some(proj) = project_dir { + let project_obs = Self::list_observations(Some(20), None, Some(proj))?; + for obs in project_obs { + if used_tokens + obs.tokens > max_tokens { + continue; + } + if seen_ids.contains(&obs.id) { + continue; + } + seen_ids.insert(obs.id); + used_tokens += obs.tokens; + 
items.push(ContextItem { + observation: obs, + priority: 2, + match_reason: "Project match".to_string(), + }); + } + } + + // Layer 3: Recent observations (lowest priority) + let recent = Self::list_observations(Some(50), None, None)?; + for obs in recent { + if used_tokens + obs.tokens > max_tokens { + continue; + } + if seen_ids.contains(&obs.id) { + continue; + } + seen_ids.insert(obs.id); + used_tokens += obs.tokens; + items.push(ContextItem { + observation: obs, + priority: 3, + match_reason: "Recent".to_string(), + }); + } + + // Sort by priority (lower number = higher priority) + items.sort_by(|a, b| a.priority.cmp(&b.priority)); + + Ok(items) + } + + // ------------------------------------------------------------------------- + // Sessions + // ------------------------------------------------------------------------- + + /// Start a new session + pub fn start_session(app: &AppType, project_dir: Option<&str>) -> Result { + let conn = get_connection()?.lock()?; + let now = Utc::now(); + + conn.execute( + r#" + INSERT INTO sessions (app, project_dir, started_at) + VALUES (?1, ?2, ?3) + "#, + params![app.as_str(), project_dir, now.timestamp()], + ) + .map_err(|e| AppError::Message(format!("Failed to start session: {e}")))?; + + let id = conn.last_insert_rowid(); + + Ok(Session { + id, + app: app.as_str().to_string(), + project_dir: project_dir.map(String::from), + started_at: now, + ended_at: None, + summary: None, + }) + } + + /// End a session + pub fn end_session(id: i64, summary: Option<&str>) -> Result { + let conn = get_connection()?.lock()?; + let now = Utc::now(); + + let rows = conn + .execute( + r#" + UPDATE sessions SET ended_at = ?1, summary = ?2 WHERE id = ?3 + "#, + params![now.timestamp(), summary, id], + ) + .map_err(|e| AppError::Message(format!("Failed to end session: {e}")))?; + + Ok(rows > 0) + } + + /// List recent sessions + pub fn list_sessions(limit: Option) -> Result, AppError> { + let conn = get_connection()?.lock()?; + let limit = 
limit.unwrap_or(10); + + let mut stmt = conn + .prepare( + r#" + SELECT id, app, project_dir, started_at, ended_at, summary + FROM sessions + ORDER BY started_at DESC + LIMIT ?1 + "#, + ) + .map_err(|e| AppError::Message(format!("Failed to prepare query: {e}")))?; + + let rows = stmt + .query_map(params![limit], |row| { + Ok(Session { + id: row.get(0)?, + app: row.get(1)?, + project_dir: row.get(2)?, + started_at: timestamp_to_datetime(row.get(3)?)?, + ended_at: row + .get::<_, Option>(4)? + .map(|ts| timestamp_to_datetime(ts)) + .transpose()?, + summary: row.get(5)?, + }) + }) + .map_err(|e| AppError::Message(format!("Failed to list sessions: {e}")))?; + + let mut results = Vec::new(); + for row in rows { + results.push(row.map_err(|e| AppError::Message(format!("Row error: {e}")))?); + } + + Ok(results) + } + + // ------------------------------------------------------------------------- + // Statistics + // ------------------------------------------------------------------------- + + /// Get memory statistics + pub fn stats() -> Result { + let conn = get_connection()?.lock()?; + + let total_observations: i64 = conn + .query_row("SELECT COUNT(*) FROM observations", [], |row| row.get(0)) + .unwrap_or(0); + + let total_sessions: i64 = conn + .query_row("SELECT COUNT(*) FROM sessions", [], |row| row.get(0)) + .unwrap_or(0); + + let total_tokens: i64 = conn + .query_row( + "SELECT COALESCE(SUM(tokens), 0) FROM observations", + [], + |row| row.get(0), + ) + .unwrap_or(0); + + let mut stmt = conn + .prepare("SELECT observation_type, COUNT(*) FROM observations GROUP BY observation_type") + .map_err(|e| AppError::Message(format!("Failed to prepare stats query: {e}")))?; + + let by_type: Vec<(String, i64)> = stmt + .query_map([], |row| Ok((row.get(0)?, row.get(1)?))) + .map_err(|e| AppError::Message(format!("Stats query failed: {e}")))? 
+ .filter_map(|r| r.ok()) + .collect(); + + let oldest: Option> = conn + .query_row( + "SELECT MIN(created_at) FROM observations", + [], + |row| row.get::<_, Option>(0), + ) + .ok() + .flatten() + .and_then(|ts| timestamp_to_datetime(ts).ok()); + + let newest: Option> = conn + .query_row( + "SELECT MAX(created_at) FROM observations", + [], + |row| row.get::<_, Option>(0), + ) + .ok() + .flatten() + .and_then(|ts| timestamp_to_datetime(ts).ok()); + + Ok(MemoryStats { + total_observations, + total_sessions, + total_tokens, + observations_by_type: by_type, + oldest_observation: oldest, + newest_observation: newest, + }) + } + + // ------------------------------------------------------------------------- + // Hook Integration + // ------------------------------------------------------------------------- + + /// Check if hooks are registered in Claude Code settings + pub fn hooks_status() -> Result { + let settings_path = crate::config::get_claude_settings_path(); + if !settings_path.exists() { + return Ok(HooksStatus { + registered: false, + session_start: false, + post_tool_use: false, + }); + } + + let content = + std::fs::read_to_string(&settings_path).map_err(|e| AppError::io(&settings_path, e))?; + + let value: serde_json::Value = + serde_json::from_str(&content).map_err(|e| AppError::json(&settings_path, e))?; + + let hooks = value.get("hooks"); + let session_start = hooks + .and_then(|h| h.get("SessionStart")) + .map(|v| has_cc_switch_hook(v)) + .unwrap_or(false); + let post_tool_use = hooks + .and_then(|h| h.get("PostToolUse")) + .map(|v| has_cc_switch_hook(v)) + .unwrap_or(false); + + Ok(HooksStatus { + registered: session_start || post_tool_use, + session_start, + post_tool_use, + }) + } + + /// Register hooks in Claude Code settings + pub fn register_hooks() -> Result<(), AppError> { + let settings_path = crate::config::get_claude_settings_path(); + + let mut value: serde_json::Value = if settings_path.exists() { + let content = 
std::fs::read_to_string(&settings_path) + .map_err(|e| AppError::io(&settings_path, e))?; + serde_json::from_str(&content).map_err(|e| AppError::json(&settings_path, e))? + } else { + serde_json::json!({}) + }; + + let hooks = value + .as_object_mut() + .ok_or_else(|| AppError::Message("Settings is not an object".to_string()))? + .entry("hooks") + .or_insert(serde_json::json!({})); + + let hooks_obj = hooks + .as_object_mut() + .ok_or_else(|| AppError::Message("Hooks is not an object".to_string()))?; + + // SessionStart hook - outputs context to Claude + let session_start_hook = serde_json::json!([{ + "matcher": "", + "hooks": [{ + "type": "command", + "command": "cc-switch memory hooks ingest --hook session-start" + }] + }]); + + // PostToolUse hook - captures observations + let post_tool_use_hook = serde_json::json!([{ + "matcher": "", + "hooks": [{ + "type": "command", + "command": "cc-switch memory hooks ingest --hook post-tool-use" + }] + }]); + + // Check and add hooks if not present + if !has_cc_switch_hook(hooks_obj.get("SessionStart").unwrap_or(&serde_json::Value::Null)) { + hooks_obj.insert("SessionStart".to_string(), session_start_hook); + } + + if !has_cc_switch_hook(hooks_obj.get("PostToolUse").unwrap_or(&serde_json::Value::Null)) { + hooks_obj.insert("PostToolUse".to_string(), post_tool_use_hook); + } + + crate::config::write_json_file(&settings_path, &value)?; + Ok(()) + } + + /// Unregister hooks from Claude Code settings + pub fn unregister_hooks() -> Result<(), AppError> { + let settings_path = crate::config::get_claude_settings_path(); + + if !settings_path.exists() { + return Ok(()); + } + + let content = + std::fs::read_to_string(&settings_path).map_err(|e| AppError::io(&settings_path, e))?; + + let mut value: serde_json::Value = + serde_json::from_str(&content).map_err(|e| AppError::json(&settings_path, e))?; + + if let Some(hooks) = value.get_mut("hooks").and_then(|h| h.as_object_mut()) { + // Remove our hooks from SessionStart + if let 
Some(session_start) = hooks.get_mut("SessionStart") { + remove_cc_switch_hook(session_start); + } + + // Remove our hooks from PostToolUse + if let Some(post_tool_use) = hooks.get_mut("PostToolUse") { + remove_cc_switch_hook(post_tool_use); + } + } + + crate::config::write_json_file(&settings_path, &value)?; + Ok(()) + } + + /// Process incoming hook event (called from hooks ingest) + pub fn ingest_hook_event( + event_json: &str, + hook_type: crate::cli::commands::memory::HookType, + ) -> Result, AppError> { + let event: serde_json::Value = serde_json::from_str(event_json) + .map_err(|e| AppError::Message(format!("Invalid hook event JSON: {e}")))?; + + use crate::cli::commands::memory::HookType; + match hook_type { + HookType::SessionStart => Self::handle_session_start(&event), + HookType::PostToolUse => Self::handle_post_tool_use(&event), + } + } + + fn handle_session_start(event: &serde_json::Value) -> Result, AppError> { + // Claude Code sends: { "session_id": "...", "cwd": "..." } + let project_dir = event + .get("cwd") + .and_then(|v| v.as_str()) + .map(String::from); + + // Start a new session + let _ = Self::start_session(&AppType::Claude, project_dir.as_deref())?; + + // Get context for this session + let context = Self::get_context(None, 4000, project_dir.as_deref())?; + + if context.is_empty() { + return Ok(None); + } + + // Format context for output + let mut output = String::from("## Memory Context\n\n"); + for item in context.iter().take(5) { + output.push_str(&format!( + "### {} ({})\n{}\n\n", + item.observation.title, + item.observation.observation_type, + item.observation.content + )); + } + + Ok(Some(output)) + } + + fn handle_post_tool_use(event: &serde_json::Value) -> Result, AppError> { + // Claude Code sends: { "session_id", "cwd", "tool_name", "tool_input", "tool_response" } + let tool_name = event + .get("tool_name") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + let tool_input = event.get("tool_input"); + let tool_output = 
event.get("tool_response"); + + // Filter for interesting events + let (observation_type, title) = match tool_name { + "Write" | "Edit" => (Some(ObservationType::Pattern), format!("{} operation", tool_name)), + "Bash" => { + // Check exit code first if available, then fall back to output heuristics + let has_error_exit = event + .get("exit_code") + .and_then(|v| v.as_i64()) + .map(|code| code != 0) + .unwrap_or(false); + + if has_error_exit { + (Some(ObservationType::Error), "Bash error".to_string()) + } else if let Some(output) = tool_output.and_then(|v| v.as_str()) { + // Use line-start patterns to reduce false positives + let has_error = output.lines().any(|line| { + let trimmed = line.trim(); + trimmed.starts_with("error:") + || trimmed.starts_with("Error:") + || trimmed.starts_with("ERROR:") + || trimmed.starts_with("FAILED") + || trimmed.starts_with("fatal:") + || trimmed.starts_with("panic:") + }); + if has_error { + (Some(ObservationType::Error), "Bash error".to_string()) + } else { + (None, String::new()) + } + } else { + (None, String::new()) + } + } + // GitHub MCP tools + "mcp__github__create_pull_request" => { + let pr_title = tool_input + .and_then(|v| v.get("title")) + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + (Some(ObservationType::Decision), format!("PR created: {}", pr_title)) + } + "mcp__github__merge_pull_request" => { + let pr_num = tool_input + .and_then(|v| v.get("pull_number")) + .and_then(|v| v.as_i64()) + .map(|n| n.to_string()) + .unwrap_or_else(|| "unknown".to_string()); + (Some(ObservationType::Decision), format!("PR #{} merged", pr_num)) + } + "mcp__github__create_issue" => { + let issue_title = tool_input + .and_then(|v| v.get("title")) + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + (Some(ObservationType::General), format!("Issue created: {}", issue_title)) + } + "mcp__github__create_branch" => { + let branch = tool_input + .and_then(|v| v.get("branch")) + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + 
(Some(ObservationType::Pattern), format!("Branch created: {}", branch)) + } + "mcp__github__push_files" | "mcp__github__create_or_update_file" => { + let msg = tool_input + .and_then(|v| v.get("message")) + .and_then(|v| v.as_str()) + .unwrap_or("unknown"); + (Some(ObservationType::Pattern), format!("Pushed: {}", msg)) + } + _ => (None, String::new()), + }; + + let Some(obs_type) = observation_type else { + return Ok(None); + }; + let content = format!( + "Tool: {}\nInput: {}\nOutput: {}", + tool_name, + tool_input + .map(|v| v.to_string()) + .unwrap_or_else(|| "N/A".to_string()), + tool_output + .map(|v| v.to_string()) + .unwrap_or_else(|| "N/A".to_string()) + ); + + let project_dir = event + .get("cwd") + .and_then(|v| v.as_str()) + .map(String::from); + + Self::add_observation(NewObservation { + session_id: None, + title, + content, + observation_type: obs_type, + tags: vec![tool_name.to_string()], + project_dir, + })?; + + Ok(None) + } + + // ------------------------------------------------------------------------- + // SQL Export / Import (for WebDAV sync) + // ------------------------------------------------------------------------- + + /// Export memory database as SQL bytes. Returns `None` if the database is empty. 
+ pub fn export_sql_bytes() -> Result>, AppError> { + let conn = get_connection()?.lock()?; + + let total: i64 = conn + .query_row( + "SELECT (SELECT COUNT(*) FROM sessions) + (SELECT COUNT(*) FROM observations)", + [], + |row| row.get(0), + ) + .unwrap_or(0); + + if total == 0 { + return Ok(None); + } + + let mut output = String::new(); + output.push_str("-- CC Switch Memory 导出\n"); + output.push_str("PRAGMA foreign_keys=OFF;\nBEGIN TRANSACTION;\n"); + + // Schema for sessions and observations (regular tables only) + let mut stmt = conn + .prepare( + "SELECT type, name, sql FROM sqlite_master \ + WHERE sql NOT NULL AND name IN ('sessions','observations') \ + ORDER BY type='table' DESC, name", + ) + .map_err(|e| AppError::Message(format!("Memory export schema query failed: {e}")))?; + + let mut tables = Vec::new(); + let mut rows = stmt.query([]).map_err(|e| AppError::Message(e.to_string()))?; + while let Some(row) = rows.next().map_err(|e| AppError::Message(e.to_string()))? { + let obj_type: String = row.get(0).map_err(|e| AppError::Message(e.to_string()))?; + let name: String = row.get(1).map_err(|e| AppError::Message(e.to_string()))?; + let sql: String = row.get(2).map_err(|e| AppError::Message(e.to_string()))?; + output.push_str(&sql); + output.push_str(";\n"); + if obj_type == "table" { + tables.push(name); + } + } + + // Export indexes for these tables + let mut idx_stmt = conn + .prepare( + "SELECT sql FROM sqlite_master \ + WHERE type='index' AND sql NOT NULL AND tbl_name IN ('sessions','observations')", + ) + .map_err(|e| AppError::Message(e.to_string()))?; + let mut idx_rows = idx_stmt.query([]).map_err(|e| AppError::Message(e.to_string()))?; + while let Some(row) = idx_rows.next().map_err(|e| AppError::Message(e.to_string()))? 
{ + let sql: String = row.get(0).map_err(|e| AppError::Message(e.to_string()))?; + output.push_str(&sql); + output.push_str(";\n"); + } + + // Export triggers + let mut trig_stmt = conn + .prepare( + "SELECT sql FROM sqlite_master \ + WHERE type='trigger' AND sql NOT NULL AND tbl_name='observations'", + ) + .map_err(|e| AppError::Message(e.to_string()))?; + let mut trig_rows = trig_stmt.query([]).map_err(|e| AppError::Message(e.to_string()))?; + while let Some(row) = trig_rows.next().map_err(|e| AppError::Message(e.to_string()))? { + let sql: String = row.get(0).map_err(|e| AppError::Message(e.to_string()))?; + output.push_str(&sql); + output.push_str(";\n"); + } + + // FTS5 virtual table + output.push_str( + "CREATE VIRTUAL TABLE IF NOT EXISTS observations_fts USING fts5(\ + title,content,tags,content='observations',content_rowid='id');\n", + ); + + // Data + for table in &tables { + let col_names = memory_table_columns(&conn, table)?; + if col_names.is_empty() { + continue; + } + let mut data_stmt = conn + .prepare(&format!("SELECT * FROM \"{table}\"")) + .map_err(|e| AppError::Message(e.to_string()))?; + let mut data_rows = data_stmt.query([]).map_err(|e| AppError::Message(e.to_string()))?; + while let Some(row) = data_rows.next().map_err(|e| AppError::Message(e.to_string()))? 
{ + let mut values = Vec::with_capacity(col_names.len()); + for idx in 0..col_names.len() { + let val = row + .get_ref(idx) + .map_err(|e| AppError::Message(e.to_string()))?; + values.push(format_memory_sql_value(val)); + } + let cols = col_names + .iter() + .map(|c| format!("\"{c}\"")) + .collect::>() + .join(", "); + output.push_str(&format!( + "INSERT INTO \"{table}\" ({cols}) VALUES ({});\n", + values.join(", ") + )); + } + } + + // Rebuild FTS5 index + output.push_str( + "INSERT INTO observations_fts(observations_fts) VALUES('rebuild');\n", + ); + output.push_str("COMMIT;\nPRAGMA foreign_keys=ON;\n"); + + Ok(Some(output.into_bytes())) + } + + /// Import memory database from SQL bytes (full replace). + pub fn import_sql_bytes(bytes: &[u8]) -> Result<(), AppError> { + let sql = std::str::from_utf8(bytes) + .map_err(|e| AppError::Message(format!("Memory SQL is not valid UTF-8: {e}")))?; + + if !sql.trim_start().starts_with("-- CC Switch Memory") { + return Err(AppError::Message( + "Invalid memory SQL export format".to_string(), + )); + } + + let conn = get_connection()?.lock()?; + + // Drop existing data + conn.execute_batch( + "DELETE FROM observations;\n\ + DELETE FROM sessions;\n\ + INSERT INTO observations_fts(observations_fts) VALUES('rebuild');", + ) + .map_err(|e| AppError::Message(format!("Failed to clear memory tables: {e}")))?; + + // Execute import — the SQL already contains CREATE TABLE IF NOT EXISTS + INSERT + conn.execute_batch(sql) + .map_err(|e| AppError::Message(format!("Memory SQL import failed: {e}")))?; + + Ok(()) + } + + /// Format context for display + pub fn format_context(items: &[ContextItem]) -> String { + if items.is_empty() { + return String::from("No relevant context found."); + } + + let mut output = String::new(); + for (i, item) in items.iter().enumerate() { + output.push_str(&format!( + "{}. 
[{}] {} ({})\n {}\n\n", + i + 1, + item.observation.observation_type, + item.observation.title, + item.match_reason, + item.observation + .content + .lines() + .next() + .unwrap_or("") + .chars() + .take(100) + .collect::() + )); + } + output + } +} + +// ============================================================================ +// Hook Helpers +// ============================================================================ + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct HooksStatus { + pub registered: bool, + pub session_start: bool, + pub post_tool_use: bool, +} + +fn has_cc_switch_hook(value: &serde_json::Value) -> bool { + if let Some(arr) = value.as_array() { + for item in arr { + if let Some(hooks) = item.get("hooks").and_then(|h| h.as_array()) { + for hook in hooks { + if let Some(cmd) = hook.get("command").and_then(|c| c.as_str()) { + if cmd.contains("cc-switch memory hooks ingest") { + return true; + } + } + } + } + } + } + false +} + +fn remove_cc_switch_hook(value: &mut serde_json::Value) { + if let Some(arr) = value.as_array_mut() { + // Remove cc-switch hooks from each item, then drop items with empty hooks + for item in arr.iter_mut() { + if let Some(hooks) = item.get_mut("hooks").and_then(|h| h.as_array_mut()) { + hooks.retain(|hook| { + hook.get("command") + .and_then(|c| c.as_str()) + .map(|cmd| !cmd.contains("cc-switch memory hooks ingest")) + .unwrap_or(true) + }); + } + } + // Remove items whose hooks array is now empty + arr.retain(|item| { + item.get("hooks") + .and_then(|h| h.as_array()) + .map(|hooks| !hooks.is_empty()) + .unwrap_or(true) + }); + } +} + +// ============================================================================ +// SQL Export Helpers +// ============================================================================ + +fn memory_table_columns(conn: &Connection, table: &str) -> Result, AppError> { + let mut stmt = conn + .prepare(&format!("PRAGMA table_info(\"{table}\")")) + .map_err(|e| 
AppError::Message(e.to_string()))?; + let iter = stmt + .query_map([], |row| row.get::<_, String>(1)) + .map_err(|e| AppError::Message(e.to_string()))?; + let mut cols = Vec::new(); + for c in iter { + cols.push(c.map_err(|e| AppError::Message(e.to_string()))?); + } + Ok(cols) +} + +fn format_memory_sql_value(value: rusqlite::types::ValueRef<'_>) -> String { + use rusqlite::types::ValueRef; + match value { + ValueRef::Null => "NULL".to_string(), + ValueRef::Integer(i) => i.to_string(), + ValueRef::Real(f) => f.to_string(), + ValueRef::Text(t) => { + let text = std::str::from_utf8(t).unwrap_or(""); + let escaped = text.replace('\'', "''"); + format!("'{escaped}'") + } + ValueRef::Blob(bytes) => { + let mut s = String::from("X'"); + for b in bytes { + use std::fmt::Write; + let _ = write!(&mut s, "{b:02X}"); + } + s.push('\''); + s + } + } +} diff --git a/src-tauri/src/services/mod.rs b/src-tauri/src/services/mod.rs index 9a09ba6..80e378a 100644 --- a/src-tauri/src/services/mod.rs +++ b/src-tauri/src/services/mod.rs @@ -3,6 +3,7 @@ pub mod env_checker; pub mod env_manager; pub mod local_env_check; pub mod mcp; +pub mod memory; pub mod prompt; pub mod provider; pub mod skill; @@ -11,6 +12,7 @@ pub mod webdav_sync; pub use config::ConfigService; pub use mcp::McpService; +pub use memory::MemoryService; pub use prompt::PromptService; pub use provider::ProviderService; pub use skill::SkillService; diff --git a/src-tauri/src/services/webdav_sync.rs b/src-tauri/src/services/webdav_sync.rs index 45d8a15..3ed1449 100644 --- a/src-tauri/src/services/webdav_sync.rs +++ b/src-tauri/src/services/webdav_sync.rs @@ -15,6 +15,7 @@ use zip::{write::SimpleFileOptions, DateTime}; use crate::config::atomic_write; use crate::database::Database; use crate::error::AppError; +use crate::services::memory::MemoryService; use crate::services::skill::SkillService; use crate::settings::{ get_settings, get_webdav_sync_settings, set_webdav_sync_settings, update_settings, @@ -27,6 +28,21 @@ const 
REMOTE_DB_SQL: &str = "db.sql"; const REMOTE_SKILLS_ZIP: &str = "skills.zip"; const REMOTE_SETTINGS_SYNC: &str = "settings.sync.json"; const REMOTE_MANIFEST: &str = "manifest.json"; +const REMOTE_CLAUDE_ZIP: &str = "claude.zip"; +const REMOTE_CODEX_ZIP: &str = "codex.zip"; +const REMOTE_GEMINI_ZIP: &str = "gemini.zip"; +const REMOTE_MEMORY_SQL: &str = "memory.sql"; + +const CLAUDE_EXCLUDES: &[&str] = &[ + "debug", "cache", "paste-cache", "telemetry", "statsig", + "session-env", "shell-snapshots", "tasks", "plugins", + "usage-data", "stats-cache.json", "statusline-command.sh", +]; +const CODEX_EXCLUDES: &[&str] = &["log", "tmp"]; +const GEMINI_EXCLUDES: &[&str] = &[ + "antigravity", "antigravity-browser-profile", + "oauth_creds.json", "google_accounts.json", "tmp", +]; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum SyncDecision { @@ -64,6 +80,14 @@ struct ManifestArtifacts { db_sql: ManifestArtifact, skills_zip: ManifestArtifact, settings_sync: ManifestArtifact, + #[serde(default, skip_serializing_if = "Option::is_none")] + claude_zip: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + codex_zip: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + gemini_zip: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + memory_sql: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -84,6 +108,10 @@ struct LocalSnapshot { db_sql: Vec, skills_zip: Vec, settings_sync: Vec, + claude_zip: Option>, + codex_zip: Option>, + gemini_zip: Option>, + memory_sql: Option>, manifest: WebDavManifest, manifest_bytes: Vec, manifest_hash: String, @@ -137,7 +165,44 @@ impl WebDavSyncService { let settings_sync = download_and_verify_artifact(&settings, &remote.manifest.artifacts.settings_sync)?; - apply_downloaded_snapshot(&db_sql, &skills_zip, &settings_sync)?; + let claude_zip = remote + .manifest + .artifacts + .claude_zip + .as_ref() + .map(|a| download_and_verify_artifact(&settings, a)) + .transpose()?; + let 
codex_zip = remote + .manifest + .artifacts + .codex_zip + .as_ref() + .map(|a| download_and_verify_artifact(&settings, a)) + .transpose()?; + let gemini_zip = remote + .manifest + .artifacts + .gemini_zip + .as_ref() + .map(|a| download_and_verify_artifact(&settings, a)) + .transpose()?; + let memory_sql = remote + .manifest + .artifacts + .memory_sql + .as_ref() + .map(|a| download_and_verify_artifact(&settings, a)) + .transpose()?; + + apply_downloaded_snapshot( + &db_sql, + &skills_zip, + &settings_sync, + claude_zip.as_deref(), + codex_zip.as_deref(), + gemini_zip.as_deref(), + memory_sql.as_deref(), + )?; settings.status.last_sync_at = Some(Utc::now().timestamp()); settings.status.last_error = None; @@ -183,9 +248,21 @@ fn build_local_snapshot(settings: &WebDavSyncSettings) -> Result Result, + codex_zip: Option<&[u8]>, + gemini_zip: Option<&[u8]>, + memory_sql: Option<&[u8]>, ) -> Result<(), AppError> { let tmp = tempdir().map_err(|e| AppError::IoContext { context: "创建 WebDAV 下载临时目录失败".to_string(), @@ -252,6 +364,23 @@ fn apply_downloaded_snapshot( apply_syncable_settings(settings_sync)?; restore_skills_zip(skills_zip)?; + + // Restore CLI directories (merge mode) + if let Some(bytes) = claude_zip { + restore_cli_zip(bytes, &crate::config::get_claude_config_dir())?; + } + if let Some(bytes) = codex_zip { + restore_cli_zip(bytes, &crate::codex_config::get_codex_config_dir())?; + } + if let Some(bytes) = gemini_zip { + restore_cli_zip(bytes, &crate::gemini_config::get_gemini_dir())?; + } + + // Restore memory + if let Some(bytes) = memory_sql { + MemoryService::import_sql_bytes(bytes)?; + } + Ok(()) } @@ -331,6 +460,20 @@ fn upload_snapshot( &snapshot.skills_zip, "application/zip", )?; + + if let Some(ref bytes) = snapshot.claude_zip { + put_remote_bytes(settings, REMOTE_CLAUDE_ZIP, bytes, "application/zip")?; + } + if let Some(ref bytes) = snapshot.codex_zip { + put_remote_bytes(settings, REMOTE_CODEX_ZIP, bytes, "application/zip")?; + } + if let Some(ref 
bytes) = snapshot.gemini_zip { + put_remote_bytes(settings, REMOTE_GEMINI_ZIP, bytes, "application/zip")?; + } + if let Some(ref bytes) = snapshot.memory_sql { + put_remote_bytes(settings, REMOTE_MEMORY_SQL, bytes, "application/sql")?; + } + let _ = &snapshot.manifest; put_remote_bytes( settings, @@ -385,11 +528,39 @@ fn snapshot_identity_from_manifest(manifest: &WebDavManifest) -> String { &manifest.artifacts.db_sql.sha256, &manifest.artifacts.skills_zip.sha256, &manifest.artifacts.settings_sync.sha256, + manifest.artifacts.claude_zip.as_ref().map(|a| a.sha256.as_str()), + manifest.artifacts.codex_zip.as_ref().map(|a| a.sha256.as_str()), + manifest.artifacts.gemini_zip.as_ref().map(|a| a.sha256.as_str()), + manifest.artifacts.memory_sql.as_ref().map(|a| a.sha256.as_str()), ) } -fn snapshot_identity_from_hashes(db_hash: &str, skills_hash: &str, settings_hash: &str) -> String { - let combined = format!("{db_hash}:{skills_hash}:{settings_hash}"); +fn snapshot_identity_from_hashes( + db_hash: &str, + skills_hash: &str, + settings_hash: &str, + claude_hash: Option<&str>, + codex_hash: Option<&str>, + gemini_hash: Option<&str>, + memory_hash: Option<&str>, +) -> String { + let mut combined = format!("{db_hash}:{skills_hash}:{settings_hash}"); + if let Some(h) = claude_hash { + combined.push(':'); + combined.push_str(h); + } + if let Some(h) = codex_hash { + combined.push(':'); + combined.push_str(h); + } + if let Some(h) = gemini_hash { + combined.push(':'); + combined.push_str(h); + } + if let Some(h) = memory_hash { + combined.push(':'); + combined.push_str(h); + } sha256_hex(combined.as_bytes()) } @@ -753,6 +924,128 @@ fn zip_file_options() -> SimpleFileOptions { .last_modified_time(DateTime::default()) } +/// Zip a CLI config directory, skipping entries whose name matches any exclude pattern. +/// Returns `Ok(None)` if the directory does not exist or is empty. 
+fn zip_cli_dir(dir: &Path, excludes: &[&str]) -> Result>, AppError> { + if !dir.exists() { + return Ok(None); + } + + let tmp = tempdir().map_err(|e| AppError::IoContext { + context: format!("创建 CLI zip 临时目录失败: {}", dir.display()), + source: e, + })?; + let zip_path = tmp.path().join("cli.zip"); + + let file = fs::File::create(&zip_path).map_err(|e| AppError::io(&zip_path, e))?; + let mut writer = zip::ZipWriter::new(file); + let options = zip_file_options(); + + zip_dir_recursive_filtered(dir, dir, &mut writer, options, excludes)?; + + writer + .finish() + .map_err(|e| AppError::Message(format!("写入 CLI zip 失败: {e}")))?; + + let bytes = fs::read(&zip_path).map_err(|e| AppError::io(&zip_path, e))?; + // An empty zip with no entries is ~22 bytes; treat as None + if bytes.len() <= 22 { + return Ok(None); + } + Ok(Some(bytes)) +} + +fn zip_dir_recursive_filtered( + root: &Path, + current: &Path, + writer: &mut zip::ZipWriter, + options: SimpleFileOptions, + excludes: &[&str], +) -> Result<(), AppError> { + let mut entries = fs::read_dir(current) + .map_err(|e| AppError::io(current, e))? 
+ .collect::, _>>() + .map_err(|e| AppError::io(current, e))?; + entries.sort_by_key(|entry| entry.file_name()); + + for entry in entries { + let path = entry.path(); + let name = entry.file_name(); + let name_str = name.to_string_lossy(); + + // Check excludes against the file/dir name + if excludes.iter().any(|ex| name_str == *ex) { + continue; + } + + let rel = path + .strip_prefix(root) + .map_err(|e| AppError::Message(format!("生成 ZIP 相对路径失败: {e}")))?; + let rel_str = rel.to_string_lossy().replace('\\', "/"); + + if path.is_dir() { + writer + .add_directory(format!("{rel_str}/"), options) + .map_err(|e| AppError::Message(format!("写入 ZIP 目录失败: {e}")))?; + zip_dir_recursive_filtered(root, &path, writer, options, excludes)?; + } else { + writer + .start_file(&*rel_str, options) + .map_err(|e| AppError::Message(format!("写入 ZIP 文件头失败: {e}")))?; + let mut file = fs::File::open(&path).map_err(|e| AppError::io(&path, e))?; + let mut buf = Vec::new(); + file.read_to_end(&mut buf) + .map_err(|e| AppError::io(&path, e))?; + writer + .write_all(&buf) + .map_err(|e| AppError::Message(format!("写入 ZIP 文件内容失败: {e}")))?; + } + } + Ok(()) +} + +/// Restore a CLI zip in merge mode: extract into temp dir, then copy over target dir +/// without deleting files that are not in the zip (preserves excluded / local-only files). 
+fn restore_cli_zip(raw: &[u8], target_dir: &Path) -> Result<(), AppError> { + let tmp = tempdir().map_err(|e| AppError::IoContext { + context: "创建 CLI zip 解压临时目录失败".to_string(), + source: e, + })?; + let zip_path = tmp.path().join("cli.zip"); + atomic_write(&zip_path, raw)?; + + let file = fs::File::open(&zip_path).map_err(|e| AppError::io(&zip_path, e))?; + let mut archive = zip::ZipArchive::new(file) + .map_err(|e| AppError::Message(format!("解析 CLI zip 失败: {e}")))?; + + let extracted = tmp.path().join("extracted"); + fs::create_dir_all(&extracted).map_err(|e| AppError::io(&extracted, e))?; + + for idx in 0..archive.len() { + let mut entry = archive + .by_index(idx) + .map_err(|e| AppError::Message(format!("读取 ZIP 项失败: {e}")))?; + let Some(safe_name) = entry.enclosed_name() else { + continue; + }; + let out_path = extracted.join(safe_name); + if entry.is_dir() { + fs::create_dir_all(&out_path).map_err(|e| AppError::io(&out_path, e))?; + continue; + } + if let Some(parent) = out_path.parent() { + fs::create_dir_all(parent).map_err(|e| AppError::io(parent, e))?; + } + let mut out = fs::File::create(&out_path).map_err(|e| AppError::io(&out_path, e))?; + std::io::copy(&mut entry, &mut out).map_err(|e| AppError::io(&out_path, e))?; + } + + // Merge: copy extracted files over target, creating dirs as needed + fs::create_dir_all(target_dir).map_err(|e| AppError::io(target_dir, e))?; + copy_dir_recursive(&extracted, target_dir)?; + Ok(()) +} + fn zip_dir_recursive( root: &Path, current: &Path, @@ -906,6 +1199,10 @@ mod tests { sha256: "settings-hash".to_string(), size: 3, }, + claude_zip: None, + codex_zip: None, + gemini_zip: None, + memory_sql: None, }, }; let manifest_b = WebDavManifest {