diff --git a/Cargo.lock b/Cargo.lock index 215952a..9d62e10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -282,12 +282,19 @@ name = "arcan-provider" version = "0.2.0" dependencies = [ "arcan-core", + "base64", + "dirs", + "open", + "rand 0.9.2", "reqwest", "rig-core", "serde", "serde_json", + "sha2", "thiserror 2.0.18", "tokio", + "tracing", + "url", "uuid", ] @@ -1422,6 +1429,25 @@ dependencies = [ "serde", ] +[[package]] +name = "is-docker" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3" +dependencies = [ + "once_cell", +] + +[[package]] +name = "is-wsl" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5" +dependencies = [ + "is-docker", + "once_cell", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.2" @@ -1840,6 +1866,17 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "open" +version = "5.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43bb73a7fa3799b198970490a51174027ba0d4ec504b03cd08caf513d40024bc" +dependencies = [ + "is-wsl", + "libc", + "pathdiff", +] + [[package]] name = "openssl" version = "0.10.75" @@ -1934,6 +1971,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b867cad97c0791bbd3aaa6472142568c6c9e8f71937e98379f584cfb0cf35bec" +[[package]] +name = "pathdiff" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" + [[package]] name = "percent-encoding" version = "2.3.2" diff --git a/crates/arcan-core/src/error.rs b/crates/arcan-core/src/error.rs index 732bcb6..c15215b 100644 --- a/crates/arcan-core/src/error.rs +++ b/crates/arcan-core/src/error.rs @@ -12,4 +12,6 @@ pub enum CoreError { Middleware(String), #[error("state patch failed: {0}")] State(String), + #[error("auth error: {0}")] + Auth(String), } diff --git a/crates/arcan-provider/Cargo.toml b/crates/arcan-provider/Cargo.toml index d565e74..65311c5 100644 --- a/crates/arcan-provider/Cargo.toml +++ b/crates/arcan-provider/Cargo.toml @@ -16,10 +16,17 @@ workspace = true [dependencies] arcan-core = { path = "../arcan-core", version = "0.2.0" } +base64 = "0.22" +dirs = "6" +open = "5" +rand = "0.9" reqwest.workspace = true rig-core.workspace = true serde.workspace = true serde_json.workspace = true +sha2 = "0.10" thiserror.workspace = true tokio.workspace = true +tracing.workspace = true +url = "2" uuid.workspace = true diff --git a/crates/arcan-provider/src/anthropic.rs b/crates/arcan-provider/src/anthropic.rs index 4fa0ca7..fbeff39 100644 --- a/crates/arcan-provider/src/anthropic.rs +++ b/crates/arcan-provider/src/anthropic.rs @@ -6,16 +6,29 @@ use arcan_core::protocol::{ use arcan_core::runtime::{Provider, ProviderRequest}; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; +use std::sync::Arc; + +use crate::credential::{AnthropicApiKeyCredential, Credential}; /// Configuration for the Anthropic provider. -#[derive(Debug, Clone)] pub struct AnthropicConfig { - pub api_key: String, + pub credential: Arc, pub model: String, pub max_tokens: u32, pub base_url: String, } +impl std::fmt::Debug for AnthropicConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AnthropicConfig") + .field("credential", &self.credential.kind()) + .field("model", &self.model) + .field("max_tokens", &self.max_tokens) + .field("base_url", &self.base_url) + .finish() + } +} + impl AnthropicConfig { pub fn from_env() -> Result { let api_key = std::env::var("ANTHROPIC_API_KEY").map_err(|_| { @@ -34,7 +47,7 @@ impl AnthropicConfig { .unwrap_or_else(|_| "https://api.anthropic.com".to_string()); Ok(Self { - api_key, + credential: Arc::new(AnthropicApiKeyCredential::new(api_key)), model, max_tokens, base_url, @@ -73,7 +86,7 @@ impl AnthropicConfig { .unwrap_or_else(|| "https://api.anthropic.com".to_string()); Ok(Self { - api_key, + credential: Arc::new(crate::credential::AnthropicApiKeyCredential::new(api_key)), model, max_tokens, base_url, @@ -214,10 +227,16 @@ impl Provider for AnthropicProvider { let url = format!("{}/v1/messages", self.config.base_url); + let api_key = self + .config + .credential + .auth_header() + .map_err(|e| CoreError::Provider(format!("credential error: {e}")))?; + let response = self .client .post(&url) - .header("x-api-key", &self.config.api_key) + .header("x-api-key", &api_key) .header("anthropic-version", "2023-06-01") .header("content-type", "application/json") .json(&body) @@ -310,15 +329,18 @@ enum ResponseBlock { mod tests { use super::*; - #[test] - fn builds_messages_with_system_prompt() { - let config = AnthropicConfig { - api_key: "test".to_string(), + fn test_config() -> AnthropicConfig { + AnthropicConfig { + credential: Arc::new(AnthropicApiKeyCredential::new("test".to_string())), model: "test-model".to_string(), max_tokens: 1024, base_url: "http://localhost".to_string(), - }; - let provider = AnthropicProvider::new(config); + } + } + + #[test] + fn builds_messages_with_system_prompt() { + let provider = AnthropicProvider::new(test_config()); let messages = vec![ ChatMessage::system("You are helpful."), @@ -333,13 +355,7 @@ mod tests { #[test] fn parses_text_response() { - let config = AnthropicConfig { - api_key: "test".to_string(), - model: "test-model".to_string(), - max_tokens: 1024, - base_url: "http://localhost".to_string(), - }; - let provider = AnthropicProvider::new(config); + let provider = AnthropicProvider::new(test_config()); let response = ApiResponse { content: vec![ResponseBlock::Text { @@ -357,13 +373,7 @@ mod tests { #[test] fn parses_tool_use_response() { - let config = AnthropicConfig { - api_key: "test".to_string(), - model: "test-model".to_string(), - max_tokens: 1024, - base_url: "http://localhost".to_string(), - }; - let provider = AnthropicProvider::new(config); + let provider = AnthropicProvider::new(test_config()); let response = ApiResponse { content: vec![ diff --git a/crates/arcan-provider/src/credential.rs b/crates/arcan-provider/src/credential.rs new file mode 100644 index 0000000..c43df59 --- /dev/null +++ b/crates/arcan-provider/src/credential.rs @@ -0,0 +1,159 @@ +use arcan_core::error::CoreError; +use std::fmt; + +/// A credential that can produce HTTP authorization headers. +/// +/// Implementations handle API keys, OAuth tokens with refresh, etc. +pub trait Credential: Send + Sync + fmt::Debug { + /// Returns the authorization header value (e.g. `"Bearer "`). + fn auth_header(&self) -> Result; + + /// Returns the credential kind for display/logging. + fn kind(&self) -> &str; + + /// Whether this credential needs periodic refresh (OAuth tokens do, API keys don't). + fn needs_refresh(&self) -> bool { + false + } + + /// Refresh the credential if needed. No-op for static credentials. + fn refresh(&self) -> Result<(), CoreError> { + Ok(()) + } +} + +/// A static API key credential that produces `Bearer ` headers. +/// +/// Used for OpenAI, Ollama, and other Bearer-token APIs. +#[derive(Clone)] +pub struct ApiKeyCredential { + api_key: String, +} + +impl fmt::Debug for ApiKeyCredential { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ApiKeyCredential") + .field("api_key", &"[REDACTED]") + .finish() + } +} + +impl ApiKeyCredential { + pub fn new(api_key: String) -> Self { + Self { api_key } + } + + /// Returns the raw API key (for providers that need non-Bearer auth). + pub fn raw_key(&self) -> &str { + &self.api_key + } + + /// Whether the underlying key is empty (e.g. Ollama local servers). + pub fn is_empty(&self) -> bool { + self.api_key.is_empty() + } +} + +impl Credential for ApiKeyCredential { + fn auth_header(&self) -> Result { + if self.api_key.is_empty() { + return Err(CoreError::Auth("API key is empty".to_string())); + } + Ok(format!("Bearer {}", self.api_key)) + } + + fn kind(&self) -> &str { + "api_key" + } +} + +/// A static API key credential that produces `x-api-key` style headers. +/// +/// Used specifically for Anthropic which uses a custom header instead of Bearer. +#[derive(Clone)] +pub struct AnthropicApiKeyCredential { + api_key: String, +} + +impl fmt::Debug for AnthropicApiKeyCredential { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AnthropicApiKeyCredential") + .field("api_key", &"[REDACTED]") + .finish() + } +} + +impl AnthropicApiKeyCredential { + pub fn new(api_key: String) -> Self { + Self { api_key } + } + + /// Returns the raw API key for direct use in `x-api-key` header. + pub fn raw_key(&self) -> &str { + &self.api_key + } +} + +impl Credential for AnthropicApiKeyCredential { + fn auth_header(&self) -> Result { + if self.api_key.is_empty() { + return Err(CoreError::Auth("Anthropic API key is empty".to_string())); + } + // Anthropic uses `x-api-key` header directly, but we return the raw value + // so the provider can set it on the appropriate header. + Ok(self.api_key.clone()) + } + + fn kind(&self) -> &str { + "anthropic_api_key" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn api_key_credential_bearer_header() { + let cred = ApiKeyCredential::new("sk-test-123".to_string()); + assert_eq!(cred.auth_header().unwrap(), "Bearer sk-test-123"); + assert_eq!(cred.kind(), "api_key"); + assert!(!cred.needs_refresh()); + assert!(!cred.is_empty()); + } + + #[test] + fn api_key_credential_empty_returns_error() { + let cred = ApiKeyCredential::new(String::new()); + assert!(cred.auth_header().is_err()); + assert!(cred.is_empty()); + } + + #[test] + fn anthropic_credential_raw_key() { + let cred = AnthropicApiKeyCredential::new("sk-ant-test".to_string()); + assert_eq!(cred.auth_header().unwrap(), "sk-ant-test"); + assert_eq!(cred.kind(), "anthropic_api_key"); + assert_eq!(cred.raw_key(), "sk-ant-test"); + } + + #[test] + fn anthropic_credential_empty_returns_error() { + let cred = AnthropicApiKeyCredential::new(String::new()); + assert!(cred.auth_header().is_err()); + } + + #[test] + fn api_key_debug_redacts_key() { + let cred = ApiKeyCredential::new("secret-key".to_string()); + let debug_output = format!("{cred:?}"); + assert!(!debug_output.contains("secret-key")); + assert!(debug_output.contains("REDACTED")); + } + + #[test] + fn refresh_is_noop_for_static_credentials() { + let cred = ApiKeyCredential::new("test".to_string()); + assert!(cred.refresh().is_ok()); + } +} diff --git a/crates/arcan-provider/src/lib.rs b/crates/arcan-provider/src/lib.rs index 07b15d3..812953e 100644 --- a/crates/arcan-provider/src/lib.rs +++ b/crates/arcan-provider/src/lib.rs @@ -1,3 +1,5 @@ pub mod anthropic; +pub mod credential; +pub mod oauth; pub mod openai; pub mod rig_bridge; diff --git a/crates/arcan-provider/src/oauth.rs b/crates/arcan-provider/src/oauth.rs new file mode 100644 index 0000000..eba4e78 --- /dev/null +++ b/crates/arcan-provider/src/oauth.rs @@ -0,0 +1,722 @@ +use arcan_core::error::CoreError; +use base64::Engine; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::fmt; +use std::path::PathBuf; +use std::sync::RwLock; +use std::time::{SystemTime, UNIX_EPOCH}; + +use crate::credential::Credential; + +// ─── OpenAI Codex OAuth constants ───────────────────────────────── + +const OPENAI_AUTH_URL: &str = "https://auth.openai.com/authorize"; +const OPENAI_TOKEN_URL: &str = "https://auth.openai.com/oauth/token"; +const OPENAI_DEVICE_AUTH_URL: &str = "https://auth.openai.com/oauth/device/code"; +const OPENAI_CLIENT_ID: &str = "app_scp_BIqDzYAUWMiRFEih7bh0N"; +const OPENAI_REDIRECT_URI: &str = "http://127.0.0.1:8769/callback"; +const OPENAI_SCOPE: &str = "openai.public"; + +// ─── Token types ────────────────────────────────────────────────── + +/// Persisted OAuth token set. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + /// Absolute expiry time (seconds since UNIX epoch). + pub expires_at: u64, + pub provider: String, +} + +impl OAuthTokenSet { + pub fn is_expired(&self) -> bool { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + // Treat as expired 60s before actual expiry to avoid edge-case failures. + now >= self.expires_at.saturating_sub(60) + } +} + +/// A credential backed by an OAuth token with automatic refresh. +pub struct OAuthCredential { + tokens: RwLock, + client_id: String, + token_url: String, +} + +impl fmt::Debug for OAuthCredential { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OAuthCredential") + .field("provider", &self.provider_name()) + .field("token_url", &self.token_url) + .finish() + } +} + +impl OAuthCredential { + pub fn new(tokens: OAuthTokenSet, client_id: String, token_url: String) -> Self { + Self { + tokens: RwLock::new(tokens), + client_id, + token_url, + } + } + + /// Create from a stored token set using OpenAI defaults. + pub fn openai(tokens: OAuthTokenSet) -> Self { + Self::new( + tokens, + OPENAI_CLIENT_ID.to_string(), + OPENAI_TOKEN_URL.to_string(), + ) + } + + fn provider_name(&self) -> String { + self.tokens + .read() + .map(|t| t.provider.clone()) + .unwrap_or_else(|_| "unknown".to_string()) + } +} + +impl Credential for OAuthCredential { + fn auth_header(&self) -> Result { + // Auto-refresh if expired. + if self.needs_refresh() { + self.refresh()?; + } + let tokens = self + .tokens + .read() + .map_err(|e| CoreError::Auth(format!("token lock poisoned: {e}")))?; + Ok(format!("Bearer {}", tokens.access_token)) + } + + fn kind(&self) -> &str { + "oauth" + } + + fn needs_refresh(&self) -> bool { + self.tokens.read().map(|t| t.is_expired()).unwrap_or(true) + } + + fn refresh(&self) -> Result<(), CoreError> { + let refresh_token = { + let tokens = self + .tokens + .read() + .map_err(|e| CoreError::Auth(format!("token lock poisoned: {e}")))?; + match &tokens.refresh_token { + Some(rt) => rt.clone(), + None => { + return Err(CoreError::Auth( + "no refresh token available, re-login required".to_string(), + )); + } + } + }; + + let new_tokens = refresh_token_grant(&self.token_url, &self.client_id, &refresh_token)?; + + let mut tokens = self + .tokens + .write() + .map_err(|e| CoreError::Auth(format!("token lock poisoned: {e}")))?; + + tokens.access_token = new_tokens.access_token; + if let Some(rt) = new_tokens.refresh_token { + tokens.refresh_token = Some(rt); + } + tokens.expires_at = new_tokens.expires_at; + + // Persist refreshed tokens. + if let Err(e) = store_tokens(&tokens) { + tracing::warn!(%e, "failed to persist refreshed tokens"); + } + + Ok(()) + } +} + +// ─── Token storage ──────────────────────────────────────────────── + +/// Returns the credentials directory: `~/.arcan/credentials/`. +pub fn credentials_dir() -> Result { + let home = dirs::home_dir() + .ok_or_else(|| CoreError::Auth("could not determine home directory".to_string()))?; + Ok(home.join(".arcan").join("credentials")) +} + +/// Path to the credential file for a given provider. +fn credential_path(provider: &str) -> Result { + Ok(credentials_dir()?.join(format!("{provider}.json"))) +} + +/// Store tokens to `~/.arcan/credentials/.json`. +pub fn store_tokens(tokens: &OAuthTokenSet) -> Result<(), CoreError> { + let dir = credentials_dir()?; + std::fs::create_dir_all(&dir) + .map_err(|e| CoreError::Auth(format!("failed to create credentials dir: {e}")))?; + + let path = dir.join(format!("{}.json", tokens.provider)); + let json = serde_json::to_string_pretty(tokens) + .map_err(|e| CoreError::Auth(format!("failed to serialize tokens: {e}")))?; + + std::fs::write(&path, &json) + .map_err(|e| CoreError::Auth(format!("failed to write credentials: {e}")))?; + + // Set file permissions to 0600 (owner read/write only) on Unix. + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let perms = std::fs::Permissions::from_mode(0o600); + std::fs::set_permissions(&path, perms) + .map_err(|e| CoreError::Auth(format!("failed to set file permissions: {e}")))?; + } + + Ok(()) +} + +/// Load stored tokens for a given provider. +pub fn load_tokens(provider: &str) -> Result { + let path = credential_path(provider)?; + let json = std::fs::read_to_string(&path) + .map_err(|e| CoreError::Auth(format!("no stored credentials for {provider}: {e}")))?; + serde_json::from_str(&json) + .map_err(|e| CoreError::Auth(format!("invalid stored credentials for {provider}: {e}"))) +} + +/// Remove stored tokens for a given provider. +pub fn remove_tokens(provider: &str) -> Result<(), CoreError> { + let path = credential_path(provider)?; + if path.exists() { + std::fs::remove_file(&path) + .map_err(|e| CoreError::Auth(format!("failed to remove credentials: {e}")))?; + } + Ok(()) +} + +/// List providers that have stored credentials. +pub fn list_stored_providers() -> Vec { + let Ok(dir) = credentials_dir() else { + return Vec::new(); + }; + let Ok(entries) = std::fs::read_dir(dir) else { + return Vec::new(); + }; + entries + .filter_map(std::result::Result::ok) + .filter_map(|e| { + let name = e.file_name().to_string_lossy().to_string(); + name.strip_suffix(".json").map(String::from) + }) + .collect() +} + +// ─── Token refresh ──────────────────────────────────────────────── + +/// Standard OAuth 2.0 refresh_token grant. +fn refresh_token_grant( + token_url: &str, + client_id: &str, + refresh_token: &str, +) -> Result { + let client = reqwest::blocking::Client::new(); + let resp = client + .post(token_url) + .form(&[ + ("grant_type", "refresh_token"), + ("client_id", client_id), + ("refresh_token", refresh_token), + ]) + .send() + .map_err(|e| CoreError::Auth(format!("refresh request failed: {e}")))?; + + let status = resp.status(); + let body = resp + .text() + .map_err(|e| CoreError::Auth(format!("failed to read refresh response: {e}")))?; + + if !status.is_success() { + return Err(CoreError::Auth(format!( + "token refresh failed ({status}): {body}" + ))); + } + + parse_token_response(&body, "openai") +} + +// ─── PKCE helpers ───────────────────────────────────────────────── + +/// Generate a random PKCE code verifier (43-128 URL-safe characters). +fn generate_code_verifier() -> String { + let mut rng = rand::rng(); + let bytes: Vec = (0..32).map(|_| rng.random::()).collect(); + URL_SAFE_NO_PAD.encode(&bytes) +} + +/// Compute the S256 code challenge from a verifier. +fn compute_code_challenge(verifier: &str) -> String { + let hash = Sha256::digest(verifier.as_bytes()); + URL_SAFE_NO_PAD.encode(hash) +} + +// ─── OAuth flows ────────────────────────────────────────────────── + +/// Run the PKCE Authorization Code flow for OpenAI. +/// +/// 1. Generate PKCE verifier/challenge +/// 2. Open browser to authorization URL +/// 3. Start local server on port 8769 to receive callback +/// 4. Exchange authorization code for tokens +/// 5. Store tokens to disk +#[allow(clippy::print_stderr)] +pub fn pkce_login_openai() -> Result { + let code_verifier = generate_code_verifier(); + let code_challenge = compute_code_challenge(&code_verifier); + + // Build authorization URL. + let auth_url = format!( + "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256", + OPENAI_AUTH_URL, + urlencoding::encode(OPENAI_CLIENT_ID), + urlencoding::encode(OPENAI_REDIRECT_URI), + urlencoding::encode(OPENAI_SCOPE), + urlencoding::encode(&code_challenge), + ); + + eprintln!("Opening browser for OpenAI authentication..."); + eprintln!("If the browser doesn't open, visit:\n{auth_url}\n"); + + // Try to open browser (best-effort). + let _ = open::that(&auth_url); + + // Start local callback server. + let listener = std::net::TcpListener::bind("127.0.0.1:8769") + .map_err(|e| CoreError::Auth(format!("failed to bind callback server: {e}")))?; + + eprintln!("Waiting for authorization callback on http://127.0.0.1:8769/callback ..."); + + let code = wait_for_callback(&listener)?; + + // Exchange code for tokens. + let client = reqwest::blocking::Client::new(); + let resp = client + .post(OPENAI_TOKEN_URL) + .form(&[ + ("grant_type", "authorization_code"), + ("client_id", OPENAI_CLIENT_ID), + ("code", code.as_str()), + ("redirect_uri", OPENAI_REDIRECT_URI), + ("code_verifier", code_verifier.as_str()), + ]) + .send() + .map_err(|e| CoreError::Auth(format!("token exchange failed: {e}")))?; + + let status = resp.status(); + let body = resp + .text() + .map_err(|e| CoreError::Auth(format!("failed to read token response: {e}")))?; + + if !status.is_success() { + return Err(CoreError::Auth(format!( + "token exchange failed ({status}): {body}" + ))); + } + + let tokens = parse_token_response(&body, "openai")?; + store_tokens(&tokens)?; + + eprintln!("Successfully authenticated with OpenAI!"); + Ok(tokens) +} + +/// Run the Device Code flow for OpenAI (headless/SSH environments). +/// +/// 1. Request device code +/// 2. Display verification URL and user code +/// 3. Poll token endpoint until authorized +/// 4. Store tokens to disk +#[allow(clippy::print_stderr)] +pub fn device_login_openai() -> Result { + let client = reqwest::blocking::Client::new(); + + // Step 1: Request device code. + let resp = client + .post(OPENAI_DEVICE_AUTH_URL) + .form(&[("client_id", OPENAI_CLIENT_ID), ("scope", OPENAI_SCOPE)]) + .send() + .map_err(|e| CoreError::Auth(format!("device auth request failed: {e}")))?; + + let status = resp.status(); + let body = resp + .text() + .map_err(|e| CoreError::Auth(format!("failed to read device auth response: {e}")))?; + + if !status.is_success() { + return Err(CoreError::Auth(format!( + "device authorization failed ({status}): {body}" + ))); + } + + let device_resp: DeviceCodeResponse = serde_json::from_str(&body) + .map_err(|e| CoreError::Auth(format!("invalid device auth response: {e}")))?; + + // Step 2: Display instructions. + eprintln!("\nTo authenticate, visit: {}", device_resp.verification_uri); + if let Some(ref complete_uri) = device_resp.verification_uri_complete { + eprintln!("Or open: {complete_uri}"); + } + eprintln!("Enter code: {}\n", device_resp.user_code); + + // Step 3: Poll for token. + let interval = std::time::Duration::from_secs(device_resp.interval.max(5)); + let deadline = + std::time::Instant::now() + std::time::Duration::from_secs(device_resp.expires_in.min(900)); + + loop { + std::thread::sleep(interval); + + if std::time::Instant::now() > deadline { + return Err(CoreError::Auth( + "device authorization timed out".to_string(), + )); + } + + let resp = client + .post(OPENAI_TOKEN_URL) + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ("client_id", OPENAI_CLIENT_ID), + ("device_code", &device_resp.device_code), + ]) + .send() + .map_err(|e| CoreError::Auth(format!("device token poll failed: {e}")))?; + + let status = resp.status(); + let body = resp + .text() + .map_err(|e| CoreError::Auth(format!("failed to read poll response: {e}")))?; + + if status.is_success() { + let tokens = parse_token_response(&body, "openai")?; + store_tokens(&tokens)?; + eprintln!("Successfully authenticated with OpenAI!"); + return Ok(tokens); + } + + // Check for specific OAuth error codes. + if let Ok(err_resp) = serde_json::from_str::(&body) { + match err_resp.error.as_str() { + "authorization_pending" => continue, + "slow_down" => { + // Back off a bit more. + std::thread::sleep(std::time::Duration::from_secs(5)); + continue; + } + "expired_token" => { + return Err(CoreError::Auth( + "device code expired, please try again".to_string(), + )); + } + "access_denied" => { + return Err(CoreError::Auth("authorization denied by user".to_string())); + } + _ => { + return Err(CoreError::Auth(format!( + "device auth error: {}", + err_resp.error_description.unwrap_or(err_resp.error) + ))); + } + } + } + + return Err(CoreError::Auth(format!( + "unexpected device token response ({status}): {body}" + ))); + } +} + +// ─── Internal helpers ───────────────────────────────────────────── + +/// Wait for the OAuth callback on a local server, extract the `code` parameter. +fn wait_for_callback(listener: &std::net::TcpListener) -> Result { + use std::io::{Read, Write}; + + let (mut stream, _) = listener + .accept() + .map_err(|e| CoreError::Auth(format!("failed to accept callback: {e}")))?; + + let mut buf = [0u8; 4096]; + let n = stream + .read(&mut buf) + .map_err(|e| CoreError::Auth(format!("failed to read callback request: {e}")))?; + let request = String::from_utf8_lossy(&buf[..n]); + + // Extract the request path from "GET /callback?code=...&state=... HTTP/1.1" + let path = request + .lines() + .next() + .and_then(|line| line.split_whitespace().nth(1)) + .ok_or_else(|| CoreError::Auth("malformed callback request".to_string()))?; + + // Parse the URL to extract the code parameter. + let full_url = format!("http://127.0.0.1:8769{path}"); + let parsed = url::Url::parse(&full_url) + .map_err(|e| CoreError::Auth(format!("failed to parse callback URL: {e}")))?; + + // Check for error in callback. + if let Some(error) = parsed.query_pairs().find(|(k, _)| k == "error") { + let desc = parsed + .query_pairs() + .find(|(k, _)| k == "error_description") + .map(|(_, v)| v.to_string()) + .unwrap_or_default(); + // Send error response to browser. + let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n

Authentication Failed

You can close this window.

"; + let _ = stream.write_all(response.as_bytes()); + return Err(CoreError::Auth(format!( + "OAuth error: {} — {desc}", + error.1 + ))); + } + + let code = parsed + .query_pairs() + .find(|(k, _)| k == "code") + .map(|(_, v)| v.to_string()) + .ok_or_else(|| CoreError::Auth("no authorization code in callback".to_string()))?; + + // Send success response to browser. + let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n

Authenticated!

You can close this window and return to the terminal.

"; + let _ = stream.write_all(response.as_bytes()); + + Ok(code) +} + +/// Parse a standard OAuth token response into our `OAuthTokenSet`. +fn parse_token_response(body: &str, provider: &str) -> Result { + let resp: TokenResponse = serde_json::from_str(body) + .map_err(|e| CoreError::Auth(format!("invalid token response: {e}")))?; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + + Ok(OAuthTokenSet { + access_token: resp.access_token, + refresh_token: resp.refresh_token, + expires_at: now + resp.expires_in.unwrap_or(3600), + provider: provider.to_string(), + }) +} + +// ─── Response types ─────────────────────────────────────────────── + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + refresh_token: Option, + expires_in: Option, + #[allow(dead_code)] + token_type: Option, +} + +#[derive(Debug, Deserialize)] +struct DeviceCodeResponse { + device_code: String, + user_code: String, + verification_uri: String, + verification_uri_complete: Option, + expires_in: u64, + interval: u64, +} + +#[derive(Debug, Deserialize)] +struct OAuthErrorResponse { + error: String, + error_description: Option, +} + +// ─── URL encoding helper ────────────────────────────────────────── + +mod urlencoding { + pub fn encode(s: &str) -> String { + let mut encoded = String::new(); + for byte in s.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { + encoded.push(byte as char); + } + _ => { + encoded.push('%'); + encoded.push_str(&format!("{byte:02X}")); + } + } + } + encoded + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn token_set_not_expired() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let tokens = OAuthTokenSet { + access_token: "test".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: now + 3600, + provider: "test".to_string(), + }; + assert!(!tokens.is_expired()); + } + + #[test] + fn token_set_expired() { + let tokens = OAuthTokenSet { + access_token: "test".to_string(), + refresh_token: None, + expires_at: 1000, // Long past. + provider: "test".to_string(), + }; + assert!(tokens.is_expired()); + } + + #[test] + fn token_set_expires_within_buffer() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let tokens = OAuthTokenSet { + access_token: "test".to_string(), + refresh_token: None, + expires_at: now + 30, // 30s from now, within 60s buffer. + provider: "test".to_string(), + }; + assert!(tokens.is_expired()); + } + + #[test] + fn pkce_code_verifier_length() { + let verifier = generate_code_verifier(); + assert!(verifier.len() >= 43); + } + + #[test] + fn pkce_code_challenge_is_base64url() { + let verifier = generate_code_verifier(); + let challenge = compute_code_challenge(&verifier); + // Base64url should not contain +, /, or =. + assert!(!challenge.contains('+')); + assert!(!challenge.contains('/')); + assert!(!challenge.contains('=')); + assert_eq!(challenge.len(), 43); // SHA-256 = 32 bytes → 43 base64url chars. + } + + #[test] + fn parse_token_response_full() { + let body = r#"{ + "access_token": "at-123", + "refresh_token": "rt-456", + "expires_in": 7200, + "token_type": "Bearer" + }"#; + let tokens = parse_token_response(body, "openai").unwrap(); + assert_eq!(tokens.access_token, "at-123"); + assert_eq!(tokens.refresh_token.as_deref(), Some("rt-456")); + assert_eq!(tokens.provider, "openai"); + assert!(!tokens.is_expired()); + } + + #[test] + fn parse_token_response_minimal() { + let body = r#"{"access_token": "at-123"}"#; + let tokens = parse_token_response(body, "test").unwrap(); + assert_eq!(tokens.access_token, "at-123"); + assert!(tokens.refresh_token.is_none()); + } + + #[test] + fn token_storage_roundtrip() { + let dir = std::env::temp_dir().join("arcan-oauth-test"); + let _ = std::fs::remove_dir_all(&dir); + + // Temporarily override home dir by using direct path functions. + let tokens = OAuthTokenSet { + access_token: "at-roundtrip".to_string(), + refresh_token: Some("rt-roundtrip".to_string()), + expires_at: 9999999999, + provider: "test-roundtrip".to_string(), + }; + + // Write directly to a known path. + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("test-roundtrip.json"); + let json = serde_json::to_string_pretty(&tokens).unwrap(); + std::fs::write(&path, &json).unwrap(); + + // Read back. + let loaded: OAuthTokenSet = + serde_json::from_str(&std::fs::read_to_string(&path).unwrap()).unwrap(); + assert_eq!(loaded.access_token, "at-roundtrip"); + assert_eq!(loaded.refresh_token.as_deref(), Some("rt-roundtrip")); + + let _ = std::fs::remove_dir_all(dir); + } + + #[test] + fn oauth_credential_kind() { + let tokens = OAuthTokenSet { + access_token: "at-test".to_string(), + refresh_token: None, + expires_at: 9999999999, + provider: "openai".to_string(), + }; + let cred = OAuthCredential::openai(tokens); + assert_eq!(cred.kind(), "oauth"); + } + + #[test] + fn oauth_credential_auth_header_with_valid_token() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let tokens = OAuthTokenSet { + access_token: "at-valid".to_string(), + refresh_token: None, + expires_at: now + 3600, + provider: "openai".to_string(), + }; + let cred = OAuthCredential::openai(tokens); + assert_eq!(cred.auth_header().unwrap(), "Bearer at-valid"); + } + + #[test] + fn urlencoding_basic() { + assert_eq!(urlencoding::encode("hello"), "hello"); + assert_eq!(urlencoding::encode("hello world"), "hello%20world"); + assert_eq!(urlencoding::encode("a+b"), "a%2Bb"); + } + + #[test] + fn list_stored_providers_empty() { + // This just tests that the function doesn't panic on empty/missing dirs. + // In CI or fresh machines, there may be no stored providers. + let _ = list_stored_providers(); + } +} diff --git a/crates/arcan-provider/src/openai.rs b/crates/arcan-provider/src/openai.rs index 323ba75..c1ce48e 100644 --- a/crates/arcan-provider/src/openai.rs +++ b/crates/arcan-provider/src/openai.rs @@ -6,15 +6,17 @@ use arcan_core::protocol::{ use arcan_core::runtime::{Provider, ProviderRequest}; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; +use std::sync::Arc; + +use crate::credential::{ApiKeyCredential, Credential}; /// Configuration for an OpenAI-compatible provider. /// /// Works with: OpenAI, Ollama, Together, Groq, LM Studio, vLLM, and any /// server that implements the OpenAI chat completions API. -#[derive(Debug, Clone)] pub struct OpenAiConfig { - /// API key (empty string for local servers like Ollama). - pub api_key: String, + /// Credential for API authentication (API key or OAuth token). + pub credential: Arc, /// Model name (e.g., "gpt-4o", "llama3.1", "qwen2.5"). pub model: String, /// Maximum tokens for model response. @@ -27,6 +29,18 @@ pub struct OpenAiConfig { pub enable_streaming: bool, } +impl std::fmt::Debug for OpenAiConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OpenAiConfig") + .field("credential", &self.credential.kind()) + .field("model", &self.model) + .field("max_tokens", &self.max_tokens) + .field("base_url", &self.base_url) + .field("provider_name", &self.provider_name) + .finish() + } +} + impl OpenAiConfig { /// Create config for OpenAI from environment variables. pub fn openai_from_env() -> Result { @@ -43,7 +57,7 @@ impl OpenAiConfig { .unwrap_or_else(|_| "https://api.openai.com".to_string()); Ok(Self { - api_key, + credential: Arc::new(ApiKeyCredential::new(api_key)), model, max_tokens, base_url, @@ -64,7 +78,7 @@ impl OpenAiConfig { .unwrap_or_else(|_| "http://localhost:11434".to_string()); Ok(Self { - api_key: String::new(), + credential: Arc::new(ApiKeyCredential::new(String::new())), model, max_tokens, base_url, @@ -102,7 +116,7 @@ impl OpenAiConfig { .unwrap_or_else(|| "https://api.openai.com".to_string()); Ok(Self { - api_key, + credential: Arc::new(ApiKeyCredential::new(api_key)), model, max_tokens, base_url, @@ -137,7 +151,7 @@ impl OpenAiConfig { .unwrap_or_else(|| "http://localhost:11434".to_string()); Ok(Self { - api_key: String::new(), + credential: Arc::new(ApiKeyCredential::new(String::new())), model, max_tokens, base_url, @@ -145,6 +159,18 @@ impl OpenAiConfig { enable_streaming: enable_streaming_override.unwrap_or(true), }) } + + /// Create config with an OAuth credential. + pub fn from_oauth(credential: Arc, model: String) -> Self { + Self { + credential, + model, + max_tokens: 4096, + base_url: "https://api.openai.com".to_string(), + provider_name: "openai".to_string(), + enable_streaming: false, + } + } } /// Provider implementation for any OpenAI-compatible chat completions API. @@ -256,6 +282,7 @@ impl OpenAiCompatibleProvider { } /// Execute with retry logic for transient errors (429, 5xx). + /// On 401 Unauthorized, attempts to refresh the credential once before failing. fn execute_with_retry( &self, body: &Value, @@ -264,6 +291,7 @@ impl OpenAiCompatibleProvider { ) -> Result { let mut last_error = None; let base_delay = std::time::Duration::from_millis(200); + let mut refreshed_on_401 = false; for attempt in 0..=max_retries { if attempt > 0 { @@ -276,10 +304,9 @@ impl OpenAiCompatibleProvider { .post(url) .header("content-type", "application/json"); - // Only add Authorization header if API key is non-empty - if !self.config.api_key.is_empty() { - request = - request.header("authorization", format!("Bearer {}", self.config.api_key)); + // Use credential for auth header (supports API keys, OAuth tokens, etc.) + if let Ok(header_value) = self.config.credential.auth_header() { + request = request.header("authorization", header_value); } let response = match request.json(body).send() { @@ -296,6 +323,16 @@ impl OpenAiCompatibleProvider { .text() .map_err(|e| CoreError::Provider(format!("failed to read response: {e}")))?; + // On 401, try refreshing the credential once. + if status.as_u16() == 401 + && !refreshed_on_401 + && self.config.credential.needs_refresh() + && self.config.credential.refresh().is_ok() + { + refreshed_on_401 = true; + continue; + } + // Retry on transient errors if (status.as_u16() == 429 || status.is_server_error()) && attempt < max_retries { last_error = Some(format!("{status}: {response_text}")); @@ -351,9 +388,9 @@ impl OpenAiCompatibleProvider { .post(&url) .header("content-type", "application/json"); - if !self.config.api_key.is_empty() { - http_request = - http_request.header("authorization", format!("Bearer {}", self.config.api_key)); + // Add authorization header from credential (API key or OAuth token) + if let Ok(auth_header) = self.config.credential.auth_header() { + http_request = http_request.header("authorization", auth_header); } let response = http_request @@ -644,7 +681,7 @@ mod tests { fn test_config() -> OpenAiConfig { OpenAiConfig { - api_key: "test-key".to_string(), + credential: Arc::new(ApiKeyCredential::new("test-key".to_string())), model: "gpt-4o".to_string(), max_tokens: 4096, base_url: "http://localhost:8080".to_string(), @@ -809,15 +846,17 @@ mod tests { #[test] fn ollama_config_defaults() { // This just tests the config structure, not actual env vars + let cred = ApiKeyCredential::new(String::new()); let config = OpenAiConfig { - api_key: String::new(), + credential: Arc::new(cred), model: "llama3.2".to_string(), max_tokens: 4096, base_url: "http://localhost:11434".to_string(), provider_name: "ollama".to_string(), enable_streaming: true, }; - assert!(config.api_key.is_empty()); + // Empty credential returns error on auth_header (expected for local servers) + assert!(config.credential.auth_header().is_err()); assert_eq!(config.base_url, "http://localhost:11434"); assert!(config.enable_streaming); } diff --git a/crates/arcan/src/main.rs b/crates/arcan/src/main.rs index 283247f..001ec2e 100644 --- a/crates/arcan/src/main.rs +++ b/crates/arcan/src/main.rs @@ -198,6 +198,20 @@ enum Command { Status, /// Stop the running daemon Stop, + /// Authenticate with an LLM provider via OAuth + Login { + /// Provider to authenticate with (e.g. "openai") + provider: String, + + /// Use device code flow instead of browser-based PKCE (for headless environments) + #[arg(long)] + device: bool, + }, + /// Remove stored OAuth credentials for a provider + Logout { + /// Provider to log out of (e.g. "openai") + provider: String, + }, } #[derive(Subcommand)] @@ -292,6 +306,18 @@ async fn resolve_session( "default".to_owned() } +/// Try to create an OpenAI provider from stored OAuth credentials. +fn try_openai_oauth_provider() -> Option> { + let tokens = arcan_provider::oauth::load_tokens("openai").ok()?; + let credential = Arc::new(arcan_provider::oauth::OAuthCredential::openai(tokens)); + let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o".to_string()); + let config = arcan_provider::openai::OpenAiConfig::from_oauth(credential, model.clone()); + tracing::info!(%model, "Provider: OpenAI (OAuth)"); + Some(Arc::new( + arcan_provider::openai::OpenAiCompatibleProvider::new(config), + )) +} + /// Build provider from resolved configuration. fn build_provider(resolved: &ResolvedConfig) -> anyhow::Result> { let pc = resolved.provider_config.as_ref(); @@ -301,7 +327,11 @@ fn build_provider(resolved: &ResolvedConfig) -> anyhow::Result tracing::warn!("Provider: MockProvider (forced via config)"); Ok(Arc::new(MockProvider)) } - "openai" => { + "openai" | "codex" | "openai-codex" => { + // Try OAuth credential first, then fall back to env var / resolved config. + if let Some(p) = try_openai_oauth_provider() { + return Ok(p); + } let config = arcan_provider::openai::OpenAiConfig::openai_from_resolved( resolved.model.as_deref(), pc.and_then(|p| p.base_url.as_deref()), @@ -338,6 +368,8 @@ fn build_provider(resolved: &ResolvedConfig) -> anyhow::Result if let Ok(config) = AnthropicConfig::from_env() { tracing::info!(model = %config.model, "Provider: Anthropic (auto-detected)"); Ok(Arc::new(AnthropicProvider::new(config))) + } else if let Some(p) = try_openai_oauth_provider() { + Ok(p) } else if let Ok(config) = arcan_provider::openai::OpenAiConfig::openai_from_env() { tracing::info!(model = %config.model, "Provider: OpenAI (auto-detected)"); Ok(Arc::new( @@ -670,6 +702,34 @@ async fn run_status(data_dir: &Path, resolved: &ResolvedConfig) -> anyhow::Resul Ok(()) } +fn run_login(provider: &str, device: bool) -> anyhow::Result<()> { + match provider { + "openai" | "codex" | "openai-codex" => { + if device { + arcan_provider::oauth::device_login_openai().map_err(|e| anyhow::anyhow!("{e}"))?; + } else { + arcan_provider::oauth::pkce_login_openai().map_err(|e| anyhow::anyhow!("{e}"))?; + } + Ok(()) + } + _ => Err(anyhow::anyhow!( + "Unknown provider '{provider}'. Supported: openai" + )), + } +} + +#[allow(clippy::print_stderr)] +fn run_logout(provider: &str) -> anyhow::Result<()> { + // Normalize provider name for credential lookup. + let normalized = match provider { + "codex" | "openai-codex" => "openai", + other => other, + }; + arcan_provider::oauth::remove_tokens(normalized).map_err(|e| anyhow::anyhow!("{e}"))?; + eprintln!("Logged out of {provider}"); + Ok(()) +} + fn main() -> anyhow::Result<()> { let cli = Cli::parse(); let data_dir = resolve_data_dir(&cli.data_dir)?; @@ -722,6 +782,8 @@ fn main() -> anyhow::Result<()> { .build()?; runtime.block_on(run_chat(data_dir, &resolved, session, url)) } + Some(Command::Login { provider, device }) => run_login(&provider, device), + Some(Command::Logout { provider }) => run_logout(&provider), Some(Command::Run { message, session,