diff --git a/Cargo.lock b/Cargo.lock index 814f82e..fad43b0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -243,6 +243,12 @@ dependencies = [ "syn", ] +[[package]] +name = "dotenvy" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" + [[package]] name = "encode_unicode" version = "1.0.0" @@ -616,6 +622,7 @@ dependencies = [ "clap_complete", "dialoguer", "directories", + "dotenvy", "keyring", "reqwest", "semver", diff --git a/Cargo.toml b/Cargo.toml index 172b32c..687ac0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ async-trait = "0.1" keyring = "3" semver = "1" tempfile = "3" +dotenvy = "0.15" [dev-dependencies] tokio = { version = "1", features = ["full"] } diff --git a/src/cli/app.rs b/src/cli/app.rs index e8d969e..a9d96be 100644 --- a/src/cli/app.rs +++ b/src/cli/app.rs @@ -58,22 +58,22 @@ pub enum AppSubcommands { }, } -pub async fn handle(cmd: AppCommands, json: bool) -> Result<()> { +pub async fn handle(cmd: AppCommands, json: bool, load_env: bool) -> Result<()> { match cmd.command { AppSubcommands::List { category, provider, page, per_page, - } => list_apps(category, provider, page, per_page, json).await, + } => list_apps(category, provider, page, per_page, json, load_env).await, AppSubcommands::Run { app, input, input_file, stream, output, - } => run_app(app, input, input_file, stream, output, json).await, - AppSubcommands::Show { app } => show_app(app, json).await, + } => run_app(app, input, input_file, stream, output, json, load_env).await, + AppSubcommands::Show { app } => show_app(app, json, load_env).await, } } @@ -83,9 +83,10 @@ async fn list_apps( page: usize, per_page: usize, json: bool, + load_env: bool, ) -> Result<()> { let registry = build_registry(); - let app_config = config::load_config()?; + let app_config = config::load_config_with_env(load_env)?; let catalog = Catalog::new(®istry, &app_config); let apps = if let Some(provider_id) = &provider_filter { @@ -177,6 +178,7 @@ async fn run_app( stream: bool, output: Option, json: bool, + load_env: bool, ) -> Result<()> { let app_id = AppId::parse(&app_str)?; @@ -201,7 +203,7 @@ async fn run_app( let registry = build_registry(); let provider = registry.find_provider(&app_id.provider)?; - let app_config = config::load_config()?; + let app_config = config::load_config_with_env(load_env)?; let prov_config = app_config .providers .get(&app_id.provider) @@ -299,11 +301,11 @@ async fn save_images(urls: &[String], base_path: &std::path::Path) -> Result<()> Ok(()) } -async fn show_app(app_str: String, json: bool) -> Result<()> { +async fn show_app(app_str: String, json: bool, load_env: bool) -> Result<()> { let app_id = AppId::parse(&app_str)?; let registry = build_registry(); - let app_config = config::load_config()?; + let app_config = config::load_config_with_env(load_env)?; let catalog = Catalog::new(®istry, &app_config); let app = catalog diff --git a/src/cli/doctor.rs b/src/cli/doctor.rs index 1fd51dd..f084d89 100644 --- a/src/cli/doctor.rs +++ b/src/cli/doctor.rs @@ -3,7 +3,7 @@ use crate::providers::registry::build_registry; use crate::types::ProviderConnectionStatus; use anyhow::Result; -pub async fn handle() -> Result<()> { +pub async fn handle(load_env: bool) -> Result<()> { println!("infs Doctor"); println!("==========="); println!(); @@ -23,7 +23,7 @@ pub async fn handle() -> Result<()> { // Check providers let registry = build_registry(); - let app_config = config::load_config()?; + let app_config = config::load_config_with_env(load_env)?; println!("Providers:"); for provider in registry.list_providers() { diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 3734ce9..f0610c7 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -24,6 +24,10 @@ pub struct Cli { #[arg(long, global = true)] pub json: bool, + /// Skip loading .env files from current directory and parent directories + #[arg(long, global = true)] + pub no_env: bool, + #[command(subcommand)] pub command: Commands, } diff --git a/src/cli/provider.rs b/src/cli/provider.rs index 5e622e2..99f0931 100644 --- a/src/cli/provider.rs +++ b/src/cli/provider.rs @@ -32,18 +32,18 @@ pub enum ProviderSubcommands { }, } -pub async fn handle(cmd: ProviderCommands, json: bool) -> Result<()> { +pub async fn handle(cmd: ProviderCommands, json: bool, load_env: bool) -> Result<()> { match cmd.command { - ProviderSubcommands::List => list_providers(json).await, + ProviderSubcommands::List => list_providers(json, load_env).await, ProviderSubcommands::Connect { provider } => connect_provider(&provider).await, ProviderSubcommands::Disconnect { provider } => disconnect_provider(&provider).await, - ProviderSubcommands::Show { provider } => show_provider(&provider, json).await, + ProviderSubcommands::Show { provider } => show_provider(&provider, json, load_env).await, } } -async fn list_providers(json: bool) -> Result<()> { +async fn list_providers(json: bool, load_env: bool) -> Result<()> { let registry = build_registry(); - let app_config = config::load_config()?; + let app_config = config::load_config_with_env(load_env)?; let mut rows = Vec::new(); for provider in registry.list_providers() { @@ -137,11 +137,11 @@ async fn disconnect_provider(provider_id: &str) -> Result<()> { Ok(()) } -async fn show_provider(provider_id: &str, json: bool) -> Result<()> { +async fn show_provider(provider_id: &str, json: bool, load_env: bool) -> Result<()> { let registry = build_registry(); let provider = registry.find_provider(provider_id)?; let d = provider.descriptor(); - let app_config = config::load_config()?; + let app_config = config::load_config_with_env(load_env)?; let prov_config = app_config.providers.get(&d.id); let status = if prov_config.is_some_and(|c| c.connected && c.get_api_key().is_some()) { diff --git a/src/config/mod.rs b/src/config/mod.rs index ea68520..5a7be57 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -7,6 +7,18 @@ use std::path::PathBuf; /// Service name used for all keyring entries. const KEYRING_SERVICE: &str = "infs"; +/// Maximum number of parent directories to search for .env files. +const MAX_ENV_PARENT_DEPTH: usize = 3; + +/// Environment variable patterns for provider credentials. +/// Format: (provider_id, env_var_prefix, credential_key) +const PROVIDER_ENV_PATTERNS: &[(&str, &str, &str)] = &[ + ("openrouter", "OPENROUTER", "api_key"), + ("falai", "FALAI", "api_key"), + ("replicate", "REPLICATE", "api_token"), + ("wavespeed", "WAVESPEED", "api_key"), +]; + #[derive(Debug, Clone, Serialize, Deserialize, Default)] pub struct AppConfig { #[serde(default)] @@ -46,6 +58,77 @@ pub fn get_credentials_path() -> Result { Ok(get_config_dir()?.join("credentials.toml")) } +// --------------------------------------------------------------------------- +// .env file loading +// --------------------------------------------------------------------------- + +/// Load .env files from the current directory and up to MAX_ENV_PARENT_DEPTH parent directories. +/// Returns the path of the .env file that was loaded, if any. +pub fn load_dotenv() -> Option { + let cwd = std::env::current_dir().ok()?; + let mut current = cwd.as_path(); + + for _ in 0..=MAX_ENV_PARENT_DEPTH { + let env_path = current.join(".env"); + if env_path.exists() && env_path.is_file() { + match dotenvy::from_path(&env_path) { + Ok(()) => { + tracing::debug!("Loaded .env from: {:?}", env_path); + return Some(env_path); + } + Err(e) => { + tracing::warn!("Failed to load .env from {:?}: {}", env_path, e); + } + } + } + + current = current.parent()?; + } + + None +} + +/// Extract provider credentials from environment variables. +/// Returns a map of provider_id -> (credential_key -> value). +pub fn credentials_from_env() -> HashMap> { + let mut result: HashMap> = HashMap::new(); + + for (provider_id, prefix, cred_key) in PROVIDER_ENV_PATTERNS { + let env_var = format!("{}_{}", prefix, cred_key.to_uppercase()); + if let Ok(value) = std::env::var(&env_var) { + if !value.is_empty() { + result + .entry(provider_id.to_string()) + .or_default() + .insert(cred_key.to_string(), value); + } + } + } + + result +} + +fn merge_env_credentials( + config: &mut AppConfig, + env_creds: HashMap>, +) { + for (provider_id, creds) in env_creds { + let provider_config = config.providers.entry(provider_id).or_default(); + for (key, value) in creds { + provider_config.credentials.entry(key).or_insert(value); + } + } +} + +fn merge_file_credentials(config: &mut AppConfig, file_creds: HashMap) { + for (provider_id, cred_config) in file_creds { + let provider_config = config.providers.entry(provider_id).or_default(); + for (key, value) in cred_config.credentials { + provider_config.credentials.insert(key, value); + } + } +} + // --------------------------------------------------------------------------- // Keyring helpers // --------------------------------------------------------------------------- @@ -135,6 +218,10 @@ pub fn keyring_delete(provider_id: &str, cred_key: &str) -> Result<(), InfsError // --------------------------------------------------------------------------- pub fn load_config() -> Result { + load_config_with_env(false) +} + +pub fn load_config_with_env(load_env: bool) -> Result { let config_path = get_config_path()?; let mut config = if config_path.exists() { @@ -146,7 +233,12 @@ pub fn load_config() -> Result { AppConfig::default() }; - // Load credentials: keychain first (for keys recorded in keychain_credentials), + // Load credentials from environment variables first (lowest priority). + if load_env { + merge_env_credentials(&mut config, credentials_from_env()); + } + + // Load credentials: keychain next (for keys recorded in keychain_credentials), // then fall back to credentials.toml for anything not yet migrated. for (provider_id, provider_config) in config.providers.iter_mut() { for cred_key in &provider_config.keychain_credentials { @@ -165,14 +257,7 @@ pub fn load_config() -> Result { let creds: HashMap = toml::from_str(&creds_content) .map_err(|e| InfsError::ConfigError(format!("Failed to parse credentials: {}", e)))?; - - for (provider_id, cred_config) in creds { - let provider_config = config.providers.entry(provider_id).or_default(); - for (key, value) in cred_config.credentials { - // Don't overwrite a value already loaded from keychain. - provider_config.credentials.entry(key).or_insert(value); - } - } + merge_file_credentials(&mut config, creds); } Ok(config) @@ -324,6 +409,46 @@ pub fn remove_provider_credentials(provider_id: &str) -> Result<(), InfsError> { #[cfg(test)] mod tests { use super::*; + use std::path::PathBuf; + use std::sync::{Mutex, MutexGuard, OnceLock}; + + fn test_env_lock() -> MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())).lock().unwrap() + } + + struct TestEnvGuard { + original_cwd: PathBuf, + original_vars: HashMap>, + } + + impl TestEnvGuard { + fn new(vars: &[&str]) -> Self { + let original_cwd = std::env::current_dir().unwrap(); + let original_vars = vars + .iter() + .map(|key| ((*key).to_string(), std::env::var(key).ok())) + .collect(); + + Self { + original_cwd, + original_vars, + } + } + } + + impl Drop for TestEnvGuard { + fn drop(&mut self) { + let _ = std::env::set_current_dir(&self.original_cwd); + + for (key, value) in &self.original_vars { + match value { + Some(value) => std::env::set_var(key, value), + None => std::env::remove_var(key), + } + } + } + } #[test] fn test_app_config_default() { @@ -452,4 +577,158 @@ mod tests { ); } } + + #[test] + fn test_credentials_from_env_empty() { + let _lock = test_env_lock(); + let _guard = TestEnvGuard::new(&[ + "OPENROUTER_API_KEY", + "FALAI_API_KEY", + "REPLICATE_API_TOKEN", + "WAVESPEED_API_KEY", + ]); + + std::env::remove_var("OPENROUTER_API_KEY"); + std::env::remove_var("FALAI_API_KEY"); + std::env::remove_var("REPLICATE_API_TOKEN"); + std::env::remove_var("WAVESPEED_API_KEY"); + + let creds = credentials_from_env(); + assert!(creds.is_empty()); + } + + #[test] + fn test_credentials_from_env_with_values() { + let _lock = test_env_lock(); + let _guard = TestEnvGuard::new(&[ + "OPENROUTER_API_KEY", + "FALAI_API_KEY", + "REPLICATE_API_TOKEN", + "WAVESPEED_API_KEY", + ]); + + std::env::set_var("OPENROUTER_API_KEY", "test-openrouter-key"); + std::env::set_var("FALAI_API_KEY", "test-falai-key"); + + let creds = credentials_from_env(); + assert_eq!( + creds.get("openrouter").and_then(|c| c.get("api_key")), + Some(&"test-openrouter-key".to_string()) + ); + assert_eq!( + creds.get("falai").and_then(|c| c.get("api_key")), + Some(&"test-falai-key".to_string()) + ); + } + + #[test] + fn test_credentials_from_env_ignores_empty() { + let _lock = test_env_lock(); + let _guard = TestEnvGuard::new(&[ + "OPENROUTER_API_KEY", + "FALAI_API_KEY", + "REPLICATE_API_TOKEN", + "WAVESPEED_API_KEY", + ]); + + std::env::set_var("OPENROUTER_API_KEY", ""); + + let creds = credentials_from_env(); + assert!(!creds.contains_key("openrouter")); + } + + #[test] + fn test_credentials_file_overrides_env() { + let _lock = test_env_lock(); + let _guard = TestEnvGuard::new(&[ + "OPENROUTER_API_KEY", + "FALAI_API_KEY", + "REPLICATE_API_TOKEN", + "WAVESPEED_API_KEY", + ]); + + let mut config = AppConfig::default(); + merge_env_credentials( + &mut config, + HashMap::from([( + "openrouter".to_string(), + HashMap::from([("api_key".to_string(), "from-env".to_string())]), + )]), + ); + + merge_file_credentials( + &mut config, + HashMap::from([( + "openrouter".to_string(), + ProviderConfig { + credentials: HashMap::from([("api_key".to_string(), "from-file".to_string())]), + ..Default::default() + }, + )]), + ); + + assert_eq!( + config + .providers + .get("openrouter") + .and_then(|provider| provider.credentials.get("api_key")), + Some(&"from-file".to_string()) + ); + } + + #[test] + fn test_load_dotenv_finds_file_in_current_dir() { + use std::io::Write; + let _lock = test_env_lock(); + let _guard = TestEnvGuard::new(&["TEST_VAR"]); + let temp_dir = tempfile::tempdir().unwrap(); + let env_path = temp_dir.path().join(".env"); + let mut file = std::fs::File::create(&env_path).unwrap(); + writeln!(file, "TEST_VAR=test_value").unwrap(); + std::env::set_current_dir(temp_dir.path()).unwrap(); + + std::env::remove_var("TEST_VAR"); + let loaded_path = load_dotenv(); + + assert!(loaded_path.is_some()); + assert_eq!( + std::env::var("TEST_VAR").ok(), + Some("test_value".to_string()) + ); + } + + #[test] + fn test_load_dotenv_searches_parent_dirs() { + use std::io::Write; + let _lock = test_env_lock(); + let _guard = TestEnvGuard::new(&["PARENT_TEST_VAR"]); + let temp_dir = tempfile::tempdir().unwrap(); + let env_path = temp_dir.path().join(".env"); + let mut file = std::fs::File::create(&env_path).unwrap(); + writeln!(file, "PARENT_TEST_VAR=parent_value").unwrap(); + + let child_dir = temp_dir.path().join("child"); + std::fs::create_dir(&child_dir).unwrap(); + std::env::set_current_dir(&child_dir).unwrap(); + + std::env::remove_var("PARENT_TEST_VAR"); + let loaded_path = load_dotenv(); + + assert!(loaded_path.is_some()); + assert_eq!( + std::env::var("PARENT_TEST_VAR").ok(), + Some("parent_value".to_string()) + ); + } + + #[test] + fn test_load_dotenv_returns_none_when_no_file() { + let _lock = test_env_lock(); + let _guard = TestEnvGuard::new(&[]); + let temp_dir = tempfile::tempdir().unwrap(); + std::env::set_current_dir(temp_dir.path()).unwrap(); + + let loaded_path = load_dotenv(); + assert!(loaded_path.is_none()); + } } diff --git a/src/main.rs b/src/main.rs index bf6ac82..f9ae8cd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,8 +22,15 @@ async fn main() -> Result<()> { .init(); let cli = Cli::parse(); + let load_env = !cli.no_env; - if let Err(e) = run(cli).await { + if load_env { + if let Some(env_path) = config::load_dotenv() { + tracing::info!("Loaded .env from: {:?}", env_path); + } + } + + if let Err(e) = run(cli, load_env).await { eprintln!("Error: {}", e); std::process::exit(1); } @@ -31,13 +38,13 @@ async fn main() -> Result<()> { Ok(()) } -async fn run(cli: Cli) -> Result<()> { +async fn run(cli: Cli, load_env: bool) -> Result<()> { let json = cli.json; match cli.command { - Commands::Provider(cmd) => cli::provider::handle(cmd, json).await, - Commands::App(cmd) => cli::app::handle(cmd, json).await, + Commands::Provider(cmd) => cli::provider::handle(cmd, json, load_env).await, + Commands::App(cmd) => cli::app::handle(cmd, json, load_env).await, Commands::Config(cmd) => cli::config::handle(cmd).await, - Commands::Doctor => cli::doctor::handle().await, + Commands::Doctor => cli::doctor::handle(load_env).await, Commands::Completions { shell } => cli::completions::handle(shell), Commands::SelfCmd(cmd) => cli::update::handle_update_command(cmd, json).await, }