From bc4161ce49348e75fd3886b9521c0cb3d9e71e43 Mon Sep 17 00:00:00 2001 From: Carlos Escobar Date: Sat, 28 Feb 2026 23:51:44 -0500 Subject: [PATCH] feat: Add OAuth credential system with OpenAI Codex support Introduce a generic Credential trait abstraction for LLM provider authentication, replacing raw API key strings with pluggable credential types. Implement OpenAI Codex OAuth 2.0 flows (PKCE and Device Code) as the first concrete OAuth provider. - Add Credential trait with auth_header(), refresh(), and needs_refresh() - Add ApiKeyCredential and AnthropicApiKeyCredential implementations - Add OAuthCredential with automatic token refresh via RwLock - Add PKCE Authorization Code flow with local callback server - Add Device Code flow for headless/SSH environments - Add token storage to ~/.arcan/credentials/ with 0600 permissions - Refactor OpenAiConfig and AnthropicConfig to use Arc - Add 401 retry with credential refresh in execute_with_retry - Add `arcan login openai` and `arcan logout openai` CLI commands - Update provider auto-detection to check stored OAuth credentials Co-Authored-By: Claude Opus 4.6 --- Cargo.lock | 43 ++ crates/arcan-core/src/error.rs | 2 + crates/arcan-provider/Cargo.toml | 7 + crates/arcan-provider/src/anthropic.rs | 60 +- crates/arcan-provider/src/credential.rs | 159 ++++++ crates/arcan-provider/src/lib.rs | 2 + crates/arcan-provider/src/oauth.rs | 722 ++++++++++++++++++++++++ crates/arcan-provider/src/openai.rs | 73 ++- crates/arcan/src/main.rs | 64 ++- 9 files changed, 1089 insertions(+), 43 deletions(-) create mode 100644 crates/arcan-provider/src/credential.rs create mode 100644 crates/arcan-provider/src/oauth.rs diff --git a/Cargo.lock b/Cargo.lock index 215952a..9d62e10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -282,12 +282,19 @@ name = "arcan-provider" version = "0.2.0" dependencies = [ "arcan-core", + "base64", + "dirs", + "open", + "rand 0.9.2", "reqwest", "rig-core", "serde", "serde_json", + "sha2", "thiserror 2.0.18", "tokio", + "tracing", + "url", "uuid", ] @@ -1422,6 +1429,25 @@ dependencies = [ "serde", ] +[[package]] +name = "is-docker" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3" +dependencies = [ + "once_cell", +] + +[[package]] +name = "is-wsl" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5" +dependencies = [ + "is-docker", + "once_cell", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.2" @@ -1840,6 +1866,17 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "open" +version = "5.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43bb73a7fa3799b198970490a51174027ba0d4ec504b03cd08caf513d40024bc" +dependencies = [ + "is-wsl", + "libc", + "pathdiff", +] + [[package]] name = "openssl" version = "0.10.75" @@ -1934,6 +1971,12 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b867cad97c0791bbd3aaa6472142568c6c9e8f71937e98379f584cfb0cf35bec" +[[package]] +name = "pathdiff" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" + [[package]] name = "percent-encoding" version = "2.3.2" diff --git a/crates/arcan-core/src/error.rs b/crates/arcan-core/src/error.rs index 732bcb6..c15215b 100644 --- a/crates/arcan-core/src/error.rs +++ b/crates/arcan-core/src/error.rs @@ -12,4 +12,6 @@ pub enum CoreError { Middleware(String), #[error("state patch failed: {0}")] State(String), + #[error("auth error: {0}")] + Auth(String), } diff --git a/crates/arcan-provider/Cargo.toml b/crates/arcan-provider/Cargo.toml index d565e74..65311c5 100644 --- a/crates/arcan-provider/Cargo.toml +++ b/crates/arcan-provider/Cargo.toml @@ -16,10 +16,17 @@ workspace = true [dependencies] arcan-core = { path = "../arcan-core", version = "0.2.0" } +base64 = "0.22" +dirs = "6" +open = "5" +rand = "0.9" reqwest.workspace = true rig-core.workspace = true serde.workspace = true serde_json.workspace = true +sha2 = "0.10" thiserror.workspace = true tokio.workspace = true +tracing.workspace = true +url = "2" uuid.workspace = true diff --git a/crates/arcan-provider/src/anthropic.rs b/crates/arcan-provider/src/anthropic.rs index 4fa0ca7..fbeff39 100644 --- a/crates/arcan-provider/src/anthropic.rs +++ b/crates/arcan-provider/src/anthropic.rs @@ -6,16 +6,29 @@ use arcan_core::protocol::{ use arcan_core::runtime::{Provider, ProviderRequest}; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; +use std::sync::Arc; + +use crate::credential::{AnthropicApiKeyCredential, Credential}; /// Configuration for the Anthropic provider. -#[derive(Debug, Clone)] pub struct AnthropicConfig { - pub api_key: String, + pub credential: Arc, pub model: String, pub max_tokens: u32, pub base_url: String, } +impl std::fmt::Debug for AnthropicConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("AnthropicConfig") + .field("credential", &self.credential.kind()) + .field("model", &self.model) + .field("max_tokens", &self.max_tokens) + .field("base_url", &self.base_url) + .finish() + } +} + impl AnthropicConfig { pub fn from_env() -> Result { let api_key = std::env::var("ANTHROPIC_API_KEY").map_err(|_| { @@ -34,7 +47,7 @@ impl AnthropicConfig { .unwrap_or_else(|_| "https://api.anthropic.com".to_string()); Ok(Self { - api_key, + credential: Arc::new(AnthropicApiKeyCredential::new(api_key)), model, max_tokens, base_url, @@ -73,7 +86,7 @@ impl AnthropicConfig { .unwrap_or_else(|| "https://api.anthropic.com".to_string()); Ok(Self { - api_key, + credential: Arc::new(crate::credential::AnthropicApiKeyCredential::new(api_key)), model, max_tokens, base_url, @@ -214,10 +227,16 @@ impl Provider for AnthropicProvider { let url = format!("{}/v1/messages", self.config.base_url); + let api_key = self + .config + .credential + .auth_header() + .map_err(|e| CoreError::Provider(format!("credential error: {e}")))?; + let response = self .client .post(&url) - .header("x-api-key", &self.config.api_key) + .header("x-api-key", &api_key) .header("anthropic-version", "2023-06-01") .header("content-type", "application/json") .json(&body) @@ -310,15 +329,18 @@ enum ResponseBlock { mod tests { use super::*; - #[test] - fn builds_messages_with_system_prompt() { - let config = AnthropicConfig { - api_key: "test".to_string(), + fn test_config() -> AnthropicConfig { + AnthropicConfig { + credential: Arc::new(AnthropicApiKeyCredential::new("test".to_string())), model: "test-model".to_string(), max_tokens: 1024, base_url: "http://localhost".to_string(), - }; - let provider = AnthropicProvider::new(config); + } + } + + #[test] + fn builds_messages_with_system_prompt() { + let provider = AnthropicProvider::new(test_config()); let messages = vec![ ChatMessage::system("You are helpful."), @@ -333,13 +355,7 @@ mod tests { #[test] fn parses_text_response() { - let config = AnthropicConfig { - api_key: "test".to_string(), - model: "test-model".to_string(), - max_tokens: 1024, - base_url: "http://localhost".to_string(), - }; - let provider = AnthropicProvider::new(config); + let provider = AnthropicProvider::new(test_config()); let response = ApiResponse { content: vec![ResponseBlock::Text { @@ -357,13 +373,7 @@ mod tests { #[test] fn parses_tool_use_response() { - let config = AnthropicConfig { - api_key: "test".to_string(), - model: "test-model".to_string(), - max_tokens: 1024, - base_url: "http://localhost".to_string(), - }; - let provider = AnthropicProvider::new(config); + let provider = AnthropicProvider::new(test_config()); let response = ApiResponse { content: vec![ diff --git a/crates/arcan-provider/src/credential.rs b/crates/arcan-provider/src/credential.rs new file mode 100644 index 0000000..c43df59 --- /dev/null +++ b/crates/arcan-provider/src/credential.rs @@ -0,0 +1,159 @@ +use arcan_core::error::CoreError; +use std::fmt; + +/// A credential that can produce HTTP authorization headers. +/// +/// Implementations handle API keys, OAuth tokens with refresh, etc. +pub trait Credential: Send + Sync + fmt::Debug { + /// Returns the authorization header value (e.g. `"Bearer "`). + fn auth_header(&self) -> Result; + + /// Returns the credential kind for display/logging. + fn kind(&self) -> &str; + + /// Whether this credential needs periodic refresh (OAuth tokens do, API keys don't). + fn needs_refresh(&self) -> bool { + false + } + + /// Refresh the credential if needed. No-op for static credentials. + fn refresh(&self) -> Result<(), CoreError> { + Ok(()) + } +} + +/// A static API key credential that produces `Bearer ` headers. +/// +/// Used for OpenAI, Ollama, and other Bearer-token APIs. +#[derive(Clone)] +pub struct ApiKeyCredential { + api_key: String, +} + +impl fmt::Debug for ApiKeyCredential { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("ApiKeyCredential") + .field("api_key", &"[REDACTED]") + .finish() + } +} + +impl ApiKeyCredential { + pub fn new(api_key: String) -> Self { + Self { api_key } + } + + /// Returns the raw API key (for providers that need non-Bearer auth). + pub fn raw_key(&self) -> &str { + &self.api_key + } + + /// Whether the underlying key is empty (e.g. Ollama local servers). + pub fn is_empty(&self) -> bool { + self.api_key.is_empty() + } +} + +impl Credential for ApiKeyCredential { + fn auth_header(&self) -> Result { + if self.api_key.is_empty() { + return Err(CoreError::Auth("API key is empty".to_string())); + } + Ok(format!("Bearer {}", self.api_key)) + } + + fn kind(&self) -> &str { + "api_key" + } +} + +/// A static API key credential that produces `x-api-key` style headers. +/// +/// Used specifically for Anthropic which uses a custom header instead of Bearer. +#[derive(Clone)] +pub struct AnthropicApiKeyCredential { + api_key: String, +} + +impl fmt::Debug for AnthropicApiKeyCredential { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("AnthropicApiKeyCredential") + .field("api_key", &"[REDACTED]") + .finish() + } +} + +impl AnthropicApiKeyCredential { + pub fn new(api_key: String) -> Self { + Self { api_key } + } + + /// Returns the raw API key for direct use in `x-api-key` header. + pub fn raw_key(&self) -> &str { + &self.api_key + } +} + +impl Credential for AnthropicApiKeyCredential { + fn auth_header(&self) -> Result { + if self.api_key.is_empty() { + return Err(CoreError::Auth("Anthropic API key is empty".to_string())); + } + // Anthropic uses `x-api-key` header directly, but we return the raw value + // so the provider can set it on the appropriate header. + Ok(self.api_key.clone()) + } + + fn kind(&self) -> &str { + "anthropic_api_key" + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn api_key_credential_bearer_header() { + let cred = ApiKeyCredential::new("sk-test-123".to_string()); + assert_eq!(cred.auth_header().unwrap(), "Bearer sk-test-123"); + assert_eq!(cred.kind(), "api_key"); + assert!(!cred.needs_refresh()); + assert!(!cred.is_empty()); + } + + #[test] + fn api_key_credential_empty_returns_error() { + let cred = ApiKeyCredential::new(String::new()); + assert!(cred.auth_header().is_err()); + assert!(cred.is_empty()); + } + + #[test] + fn anthropic_credential_raw_key() { + let cred = AnthropicApiKeyCredential::new("sk-ant-test".to_string()); + assert_eq!(cred.auth_header().unwrap(), "sk-ant-test"); + assert_eq!(cred.kind(), "anthropic_api_key"); + assert_eq!(cred.raw_key(), "sk-ant-test"); + } + + #[test] + fn anthropic_credential_empty_returns_error() { + let cred = AnthropicApiKeyCredential::new(String::new()); + assert!(cred.auth_header().is_err()); + } + + #[test] + fn api_key_debug_redacts_key() { + let cred = ApiKeyCredential::new("secret-key".to_string()); + let debug_output = format!("{cred:?}"); + assert!(!debug_output.contains("secret-key")); + assert!(debug_output.contains("REDACTED")); + } + + #[test] + fn refresh_is_noop_for_static_credentials() { + let cred = ApiKeyCredential::new("test".to_string()); + assert!(cred.refresh().is_ok()); + } +} diff --git a/crates/arcan-provider/src/lib.rs b/crates/arcan-provider/src/lib.rs index 07b15d3..812953e 100644 --- a/crates/arcan-provider/src/lib.rs +++ b/crates/arcan-provider/src/lib.rs @@ -1,3 +1,5 @@ pub mod anthropic; +pub mod credential; +pub mod oauth; pub mod openai; pub mod rig_bridge; diff --git a/crates/arcan-provider/src/oauth.rs b/crates/arcan-provider/src/oauth.rs new file mode 100644 index 0000000..eba4e78 --- /dev/null +++ b/crates/arcan-provider/src/oauth.rs @@ -0,0 +1,722 @@ +use arcan_core::error::CoreError; +use base64::Engine; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use rand::Rng; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::fmt; +use std::path::PathBuf; +use std::sync::RwLock; +use std::time::{SystemTime, UNIX_EPOCH}; + +use crate::credential::Credential; + +// ─── OpenAI Codex OAuth constants ───────────────────────────────── + +const OPENAI_AUTH_URL: &str = "https://auth.openai.com/authorize"; +const OPENAI_TOKEN_URL: &str = "https://auth.openai.com/oauth/token"; +const OPENAI_DEVICE_AUTH_URL: &str = "https://auth.openai.com/oauth/device/code"; +const OPENAI_CLIENT_ID: &str = "app_scp_BIqDzYAUWMiRFEih7bh0N"; +const OPENAI_REDIRECT_URI: &str = "http://127.0.0.1:8769/callback"; +const OPENAI_SCOPE: &str = "openai.public"; + +// ─── Token types ────────────────────────────────────────────────── + +/// Persisted OAuth token set. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OAuthTokenSet { + pub access_token: String, + pub refresh_token: Option, + /// Absolute expiry time (seconds since UNIX epoch). + pub expires_at: u64, + pub provider: String, +} + +impl OAuthTokenSet { + pub fn is_expired(&self) -> bool { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + // Treat as expired 60s before actual expiry to avoid edge-case failures. + now >= self.expires_at.saturating_sub(60) + } +} + +/// A credential backed by an OAuth token with automatic refresh. +pub struct OAuthCredential { + tokens: RwLock, + client_id: String, + token_url: String, +} + +impl fmt::Debug for OAuthCredential { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("OAuthCredential") + .field("provider", &self.provider_name()) + .field("token_url", &self.token_url) + .finish() + } +} + +impl OAuthCredential { + pub fn new(tokens: OAuthTokenSet, client_id: String, token_url: String) -> Self { + Self { + tokens: RwLock::new(tokens), + client_id, + token_url, + } + } + + /// Create from a stored token set using OpenAI defaults. + pub fn openai(tokens: OAuthTokenSet) -> Self { + Self::new( + tokens, + OPENAI_CLIENT_ID.to_string(), + OPENAI_TOKEN_URL.to_string(), + ) + } + + fn provider_name(&self) -> String { + self.tokens + .read() + .map(|t| t.provider.clone()) + .unwrap_or_else(|_| "unknown".to_string()) + } +} + +impl Credential for OAuthCredential { + fn auth_header(&self) -> Result { + // Auto-refresh if expired. + if self.needs_refresh() { + self.refresh()?; + } + let tokens = self + .tokens + .read() + .map_err(|e| CoreError::Auth(format!("token lock poisoned: {e}")))?; + Ok(format!("Bearer {}", tokens.access_token)) + } + + fn kind(&self) -> &str { + "oauth" + } + + fn needs_refresh(&self) -> bool { + self.tokens.read().map(|t| t.is_expired()).unwrap_or(true) + } + + fn refresh(&self) -> Result<(), CoreError> { + let refresh_token = { + let tokens = self + .tokens + .read() + .map_err(|e| CoreError::Auth(format!("token lock poisoned: {e}")))?; + match &tokens.refresh_token { + Some(rt) => rt.clone(), + None => { + return Err(CoreError::Auth( + "no refresh token available, re-login required".to_string(), + )); + } + } + }; + + let new_tokens = refresh_token_grant(&self.token_url, &self.client_id, &refresh_token)?; + + let mut tokens = self + .tokens + .write() + .map_err(|e| CoreError::Auth(format!("token lock poisoned: {e}")))?; + + tokens.access_token = new_tokens.access_token; + if let Some(rt) = new_tokens.refresh_token { + tokens.refresh_token = Some(rt); + } + tokens.expires_at = new_tokens.expires_at; + + // Persist refreshed tokens. + if let Err(e) = store_tokens(&tokens) { + tracing::warn!(%e, "failed to persist refreshed tokens"); + } + + Ok(()) + } +} + +// ─── Token storage ──────────────────────────────────────────────── + +/// Returns the credentials directory: `~/.arcan/credentials/`. +pub fn credentials_dir() -> Result { + let home = dirs::home_dir() + .ok_or_else(|| CoreError::Auth("could not determine home directory".to_string()))?; + Ok(home.join(".arcan").join("credentials")) +} + +/// Path to the credential file for a given provider. +fn credential_path(provider: &str) -> Result { + Ok(credentials_dir()?.join(format!("{provider}.json"))) +} + +/// Store tokens to `~/.arcan/credentials/.json`. +pub fn store_tokens(tokens: &OAuthTokenSet) -> Result<(), CoreError> { + let dir = credentials_dir()?; + std::fs::create_dir_all(&dir) + .map_err(|e| CoreError::Auth(format!("failed to create credentials dir: {e}")))?; + + let path = dir.join(format!("{}.json", tokens.provider)); + let json = serde_json::to_string_pretty(tokens) + .map_err(|e| CoreError::Auth(format!("failed to serialize tokens: {e}")))?; + + std::fs::write(&path, &json) + .map_err(|e| CoreError::Auth(format!("failed to write credentials: {e}")))?; + + // Set file permissions to 0600 (owner read/write only) on Unix. + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let perms = std::fs::Permissions::from_mode(0o600); + std::fs::set_permissions(&path, perms) + .map_err(|e| CoreError::Auth(format!("failed to set file permissions: {e}")))?; + } + + Ok(()) +} + +/// Load stored tokens for a given provider. +pub fn load_tokens(provider: &str) -> Result { + let path = credential_path(provider)?; + let json = std::fs::read_to_string(&path) + .map_err(|e| CoreError::Auth(format!("no stored credentials for {provider}: {e}")))?; + serde_json::from_str(&json) + .map_err(|e| CoreError::Auth(format!("invalid stored credentials for {provider}: {e}"))) +} + +/// Remove stored tokens for a given provider. +pub fn remove_tokens(provider: &str) -> Result<(), CoreError> { + let path = credential_path(provider)?; + if path.exists() { + std::fs::remove_file(&path) + .map_err(|e| CoreError::Auth(format!("failed to remove credentials: {e}")))?; + } + Ok(()) +} + +/// List providers that have stored credentials. +pub fn list_stored_providers() -> Vec { + let Ok(dir) = credentials_dir() else { + return Vec::new(); + }; + let Ok(entries) = std::fs::read_dir(dir) else { + return Vec::new(); + }; + entries + .filter_map(std::result::Result::ok) + .filter_map(|e| { + let name = e.file_name().to_string_lossy().to_string(); + name.strip_suffix(".json").map(String::from) + }) + .collect() +} + +// ─── Token refresh ──────────────────────────────────────────────── + +/// Standard OAuth 2.0 refresh_token grant. +fn refresh_token_grant( + token_url: &str, + client_id: &str, + refresh_token: &str, +) -> Result { + let client = reqwest::blocking::Client::new(); + let resp = client + .post(token_url) + .form(&[ + ("grant_type", "refresh_token"), + ("client_id", client_id), + ("refresh_token", refresh_token), + ]) + .send() + .map_err(|e| CoreError::Auth(format!("refresh request failed: {e}")))?; + + let status = resp.status(); + let body = resp + .text() + .map_err(|e| CoreError::Auth(format!("failed to read refresh response: {e}")))?; + + if !status.is_success() { + return Err(CoreError::Auth(format!( + "token refresh failed ({status}): {body}" + ))); + } + + parse_token_response(&body, "openai") +} + +// ─── PKCE helpers ───────────────────────────────────────────────── + +/// Generate a random PKCE code verifier (43-128 URL-safe characters). +fn generate_code_verifier() -> String { + let mut rng = rand::rng(); + let bytes: Vec = (0..32).map(|_| rng.random::()).collect(); + URL_SAFE_NO_PAD.encode(&bytes) +} + +/// Compute the S256 code challenge from a verifier. +fn compute_code_challenge(verifier: &str) -> String { + let hash = Sha256::digest(verifier.as_bytes()); + URL_SAFE_NO_PAD.encode(hash) +} + +// ─── OAuth flows ────────────────────────────────────────────────── + +/// Run the PKCE Authorization Code flow for OpenAI. +/// +/// 1. Generate PKCE verifier/challenge +/// 2. Open browser to authorization URL +/// 3. Start local server on port 8769 to receive callback +/// 4. Exchange authorization code for tokens +/// 5. Store tokens to disk +#[allow(clippy::print_stderr)] +pub fn pkce_login_openai() -> Result { + let code_verifier = generate_code_verifier(); + let code_challenge = compute_code_challenge(&code_verifier); + + // Build authorization URL. + let auth_url = format!( + "{}?response_type=code&client_id={}&redirect_uri={}&scope={}&code_challenge={}&code_challenge_method=S256", + OPENAI_AUTH_URL, + urlencoding::encode(OPENAI_CLIENT_ID), + urlencoding::encode(OPENAI_REDIRECT_URI), + urlencoding::encode(OPENAI_SCOPE), + urlencoding::encode(&code_challenge), + ); + + eprintln!("Opening browser for OpenAI authentication..."); + eprintln!("If the browser doesn't open, visit:\n{auth_url}\n"); + + // Try to open browser (best-effort). + let _ = open::that(&auth_url); + + // Start local callback server. + let listener = std::net::TcpListener::bind("127.0.0.1:8769") + .map_err(|e| CoreError::Auth(format!("failed to bind callback server: {e}")))?; + + eprintln!("Waiting for authorization callback on http://127.0.0.1:8769/callback ..."); + + let code = wait_for_callback(&listener)?; + + // Exchange code for tokens. + let client = reqwest::blocking::Client::new(); + let resp = client + .post(OPENAI_TOKEN_URL) + .form(&[ + ("grant_type", "authorization_code"), + ("client_id", OPENAI_CLIENT_ID), + ("code", code.as_str()), + ("redirect_uri", OPENAI_REDIRECT_URI), + ("code_verifier", code_verifier.as_str()), + ]) + .send() + .map_err(|e| CoreError::Auth(format!("token exchange failed: {e}")))?; + + let status = resp.status(); + let body = resp + .text() + .map_err(|e| CoreError::Auth(format!("failed to read token response: {e}")))?; + + if !status.is_success() { + return Err(CoreError::Auth(format!( + "token exchange failed ({status}): {body}" + ))); + } + + let tokens = parse_token_response(&body, "openai")?; + store_tokens(&tokens)?; + + eprintln!("Successfully authenticated with OpenAI!"); + Ok(tokens) +} + +/// Run the Device Code flow for OpenAI (headless/SSH environments). +/// +/// 1. Request device code +/// 2. Display verification URL and user code +/// 3. Poll token endpoint until authorized +/// 4. Store tokens to disk +#[allow(clippy::print_stderr)] +pub fn device_login_openai() -> Result { + let client = reqwest::blocking::Client::new(); + + // Step 1: Request device code. + let resp = client + .post(OPENAI_DEVICE_AUTH_URL) + .form(&[("client_id", OPENAI_CLIENT_ID), ("scope", OPENAI_SCOPE)]) + .send() + .map_err(|e| CoreError::Auth(format!("device auth request failed: {e}")))?; + + let status = resp.status(); + let body = resp + .text() + .map_err(|e| CoreError::Auth(format!("failed to read device auth response: {e}")))?; + + if !status.is_success() { + return Err(CoreError::Auth(format!( + "device authorization failed ({status}): {body}" + ))); + } + + let device_resp: DeviceCodeResponse = serde_json::from_str(&body) + .map_err(|e| CoreError::Auth(format!("invalid device auth response: {e}")))?; + + // Step 2: Display instructions. + eprintln!("\nTo authenticate, visit: {}", device_resp.verification_uri); + if let Some(ref complete_uri) = device_resp.verification_uri_complete { + eprintln!("Or open: {complete_uri}"); + } + eprintln!("Enter code: {}\n", device_resp.user_code); + + // Step 3: Poll for token. + let interval = std::time::Duration::from_secs(device_resp.interval.max(5)); + let deadline = + std::time::Instant::now() + std::time::Duration::from_secs(device_resp.expires_in.min(900)); + + loop { + std::thread::sleep(interval); + + if std::time::Instant::now() > deadline { + return Err(CoreError::Auth( + "device authorization timed out".to_string(), + )); + } + + let resp = client + .post(OPENAI_TOKEN_URL) + .form(&[ + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ("client_id", OPENAI_CLIENT_ID), + ("device_code", &device_resp.device_code), + ]) + .send() + .map_err(|e| CoreError::Auth(format!("device token poll failed: {e}")))?; + + let status = resp.status(); + let body = resp + .text() + .map_err(|e| CoreError::Auth(format!("failed to read poll response: {e}")))?; + + if status.is_success() { + let tokens = parse_token_response(&body, "openai")?; + store_tokens(&tokens)?; + eprintln!("Successfully authenticated with OpenAI!"); + return Ok(tokens); + } + + // Check for specific OAuth error codes. + if let Ok(err_resp) = serde_json::from_str::(&body) { + match err_resp.error.as_str() { + "authorization_pending" => continue, + "slow_down" => { + // Back off a bit more. + std::thread::sleep(std::time::Duration::from_secs(5)); + continue; + } + "expired_token" => { + return Err(CoreError::Auth( + "device code expired, please try again".to_string(), + )); + } + "access_denied" => { + return Err(CoreError::Auth("authorization denied by user".to_string())); + } + _ => { + return Err(CoreError::Auth(format!( + "device auth error: {}", + err_resp.error_description.unwrap_or(err_resp.error) + ))); + } + } + } + + return Err(CoreError::Auth(format!( + "unexpected device token response ({status}): {body}" + ))); + } +} + +// ─── Internal helpers ───────────────────────────────────────────── + +/// Wait for the OAuth callback on a local server, extract the `code` parameter. +fn wait_for_callback(listener: &std::net::TcpListener) -> Result { + use std::io::{Read, Write}; + + let (mut stream, _) = listener + .accept() + .map_err(|e| CoreError::Auth(format!("failed to accept callback: {e}")))?; + + let mut buf = [0u8; 4096]; + let n = stream + .read(&mut buf) + .map_err(|e| CoreError::Auth(format!("failed to read callback request: {e}")))?; + let request = String::from_utf8_lossy(&buf[..n]); + + // Extract the request path from "GET /callback?code=...&state=... HTTP/1.1" + let path = request + .lines() + .next() + .and_then(|line| line.split_whitespace().nth(1)) + .ok_or_else(|| CoreError::Auth("malformed callback request".to_string()))?; + + // Parse the URL to extract the code parameter. + let full_url = format!("http://127.0.0.1:8769{path}"); + let parsed = url::Url::parse(&full_url) + .map_err(|e| CoreError::Auth(format!("failed to parse callback URL: {e}")))?; + + // Check for error in callback. + if let Some(error) = parsed.query_pairs().find(|(k, _)| k == "error") { + let desc = parsed + .query_pairs() + .find(|(k, _)| k == "error_description") + .map(|(_, v)| v.to_string()) + .unwrap_or_default(); + // Send error response to browser. + let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n

Authentication Failed

You can close this window.

"; + let _ = stream.write_all(response.as_bytes()); + return Err(CoreError::Auth(format!( + "OAuth error: {} — {desc}", + error.1 + ))); + } + + let code = parsed + .query_pairs() + .find(|(k, _)| k == "code") + .map(|(_, v)| v.to_string()) + .ok_or_else(|| CoreError::Auth("no authorization code in callback".to_string()))?; + + // Send success response to browser. + let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n

Authenticated!

You can close this window and return to the terminal.

"; + let _ = stream.write_all(response.as_bytes()); + + Ok(code) +} + +/// Parse a standard OAuth token response into our `OAuthTokenSet`. +fn parse_token_response(body: &str, provider: &str) -> Result { + let resp: TokenResponse = serde_json::from_str(body) + .map_err(|e| CoreError::Auth(format!("invalid token response: {e}")))?; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .map(|d| d.as_secs()) + .unwrap_or(0); + + Ok(OAuthTokenSet { + access_token: resp.access_token, + refresh_token: resp.refresh_token, + expires_at: now + resp.expires_in.unwrap_or(3600), + provider: provider.to_string(), + }) +} + +// ─── Response types ─────────────────────────────────────────────── + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + refresh_token: Option, + expires_in: Option, + #[allow(dead_code)] + token_type: Option, +} + +#[derive(Debug, Deserialize)] +struct DeviceCodeResponse { + device_code: String, + user_code: String, + verification_uri: String, + verification_uri_complete: Option, + expires_in: u64, + interval: u64, +} + +#[derive(Debug, Deserialize)] +struct OAuthErrorResponse { + error: String, + error_description: Option, +} + +// ─── URL encoding helper ────────────────────────────────────────── + +mod urlencoding { + pub fn encode(s: &str) -> String { + let mut encoded = String::new(); + for byte in s.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { + encoded.push(byte as char); + } + _ => { + encoded.push('%'); + encoded.push_str(&format!("{byte:02X}")); + } + } + } + encoded + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn token_set_not_expired() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let tokens = OAuthTokenSet { + access_token: "test".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: now + 3600, + provider: "test".to_string(), + }; + assert!(!tokens.is_expired()); + } + + #[test] + fn token_set_expired() { + let tokens = OAuthTokenSet { + access_token: "test".to_string(), + refresh_token: None, + expires_at: 1000, // Long past. + provider: "test".to_string(), + }; + assert!(tokens.is_expired()); + } + + #[test] + fn token_set_expires_within_buffer() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let tokens = OAuthTokenSet { + access_token: "test".to_string(), + refresh_token: None, + expires_at: now + 30, // 30s from now, within 60s buffer. + provider: "test".to_string(), + }; + assert!(tokens.is_expired()); + } + + #[test] + fn pkce_code_verifier_length() { + let verifier = generate_code_verifier(); + assert!(verifier.len() >= 43); + } + + #[test] + fn pkce_code_challenge_is_base64url() { + let verifier = generate_code_verifier(); + let challenge = compute_code_challenge(&verifier); + // Base64url should not contain +, /, or =. + assert!(!challenge.contains('+')); + assert!(!challenge.contains('/')); + assert!(!challenge.contains('=')); + assert_eq!(challenge.len(), 43); // SHA-256 = 32 bytes → 43 base64url chars. + } + + #[test] + fn parse_token_response_full() { + let body = r#"{ + "access_token": "at-123", + "refresh_token": "rt-456", + "expires_in": 7200, + "token_type": "Bearer" + }"#; + let tokens = parse_token_response(body, "openai").unwrap(); + assert_eq!(tokens.access_token, "at-123"); + assert_eq!(tokens.refresh_token.as_deref(), Some("rt-456")); + assert_eq!(tokens.provider, "openai"); + assert!(!tokens.is_expired()); + } + + #[test] + fn parse_token_response_minimal() { + let body = r#"{"access_token": "at-123"}"#; + let tokens = parse_token_response(body, "test").unwrap(); + assert_eq!(tokens.access_token, "at-123"); + assert!(tokens.refresh_token.is_none()); + } + + #[test] + fn token_storage_roundtrip() { + let dir = std::env::temp_dir().join("arcan-oauth-test"); + let _ = std::fs::remove_dir_all(&dir); + + // Temporarily override home dir by using direct path functions. + let tokens = OAuthTokenSet { + access_token: "at-roundtrip".to_string(), + refresh_token: Some("rt-roundtrip".to_string()), + expires_at: 9999999999, + provider: "test-roundtrip".to_string(), + }; + + // Write directly to a known path. + std::fs::create_dir_all(&dir).unwrap(); + let path = dir.join("test-roundtrip.json"); + let json = serde_json::to_string_pretty(&tokens).unwrap(); + std::fs::write(&path, &json).unwrap(); + + // Read back. + let loaded: OAuthTokenSet = + serde_json::from_str(&std::fs::read_to_string(&path).unwrap()).unwrap(); + assert_eq!(loaded.access_token, "at-roundtrip"); + assert_eq!(loaded.refresh_token.as_deref(), Some("rt-roundtrip")); + + let _ = std::fs::remove_dir_all(dir); + } + + #[test] + fn oauth_credential_kind() { + let tokens = OAuthTokenSet { + access_token: "at-test".to_string(), + refresh_token: None, + expires_at: 9999999999, + provider: "openai".to_string(), + }; + let cred = OAuthCredential::openai(tokens); + assert_eq!(cred.kind(), "oauth"); + } + + #[test] + fn oauth_credential_auth_header_with_valid_token() { + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + let tokens = OAuthTokenSet { + access_token: "at-valid".to_string(), + refresh_token: None, + expires_at: now + 3600, + provider: "openai".to_string(), + }; + let cred = OAuthCredential::openai(tokens); + assert_eq!(cred.auth_header().unwrap(), "Bearer at-valid"); + } + + #[test] + fn urlencoding_basic() { + assert_eq!(urlencoding::encode("hello"), "hello"); + assert_eq!(urlencoding::encode("hello world"), "hello%20world"); + assert_eq!(urlencoding::encode("a+b"), "a%2Bb"); + } + + #[test] + fn list_stored_providers_empty() { + // This just tests that the function doesn't panic on empty/missing dirs. + // In CI or fresh machines, there may be no stored providers. + let _ = list_stored_providers(); + } +} diff --git a/crates/arcan-provider/src/openai.rs b/crates/arcan-provider/src/openai.rs index 323ba75..c1ce48e 100644 --- a/crates/arcan-provider/src/openai.rs +++ b/crates/arcan-provider/src/openai.rs @@ -6,15 +6,17 @@ use arcan_core::protocol::{ use arcan_core::runtime::{Provider, ProviderRequest}; use serde::{Deserialize, Serialize}; use serde_json::{Value, json}; +use std::sync::Arc; + +use crate::credential::{ApiKeyCredential, Credential}; /// Configuration for an OpenAI-compatible provider. /// /// Works with: OpenAI, Ollama, Together, Groq, LM Studio, vLLM, and any /// server that implements the OpenAI chat completions API. -#[derive(Debug, Clone)] pub struct OpenAiConfig { - /// API key (empty string for local servers like Ollama). - pub api_key: String, + /// Credential for API authentication (API key or OAuth token). + pub credential: Arc, /// Model name (e.g., "gpt-4o", "llama3.1", "qwen2.5"). pub model: String, /// Maximum tokens for model response. @@ -27,6 +29,18 @@ pub struct OpenAiConfig { pub enable_streaming: bool, } +impl std::fmt::Debug for OpenAiConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OpenAiConfig") + .field("credential", &self.credential.kind()) + .field("model", &self.model) + .field("max_tokens", &self.max_tokens) + .field("base_url", &self.base_url) + .field("provider_name", &self.provider_name) + .finish() + } +} + impl OpenAiConfig { /// Create config for OpenAI from environment variables. pub fn openai_from_env() -> Result { @@ -43,7 +57,7 @@ impl OpenAiConfig { .unwrap_or_else(|_| "https://api.openai.com".to_string()); Ok(Self { - api_key, + credential: Arc::new(ApiKeyCredential::new(api_key)), model, max_tokens, base_url, @@ -64,7 +78,7 @@ impl OpenAiConfig { .unwrap_or_else(|_| "http://localhost:11434".to_string()); Ok(Self { - api_key: String::new(), + credential: Arc::new(ApiKeyCredential::new(String::new())), model, max_tokens, base_url, @@ -102,7 +116,7 @@ impl OpenAiConfig { .unwrap_or_else(|| "https://api.openai.com".to_string()); Ok(Self { - api_key, + credential: Arc::new(ApiKeyCredential::new(api_key)), model, max_tokens, base_url, @@ -137,7 +151,7 @@ impl OpenAiConfig { .unwrap_or_else(|| "http://localhost:11434".to_string()); Ok(Self { - api_key: String::new(), + credential: Arc::new(ApiKeyCredential::new(String::new())), model, max_tokens, base_url, @@ -145,6 +159,18 @@ impl OpenAiConfig { enable_streaming: enable_streaming_override.unwrap_or(true), }) } + + /// Create config with an OAuth credential. + pub fn from_oauth(credential: Arc, model: String) -> Self { + Self { + credential, + model, + max_tokens: 4096, + base_url: "https://api.openai.com".to_string(), + provider_name: "openai".to_string(), + enable_streaming: false, + } + } } /// Provider implementation for any OpenAI-compatible chat completions API. @@ -256,6 +282,7 @@ impl OpenAiCompatibleProvider { } /// Execute with retry logic for transient errors (429, 5xx). + /// On 401 Unauthorized, attempts to refresh the credential once before failing. fn execute_with_retry( &self, body: &Value, @@ -264,6 +291,7 @@ impl OpenAiCompatibleProvider { ) -> Result { let mut last_error = None; let base_delay = std::time::Duration::from_millis(200); + let mut refreshed_on_401 = false; for attempt in 0..=max_retries { if attempt > 0 { @@ -276,10 +304,9 @@ impl OpenAiCompatibleProvider { .post(url) .header("content-type", "application/json"); - // Only add Authorization header if API key is non-empty - if !self.config.api_key.is_empty() { - request = - request.header("authorization", format!("Bearer {}", self.config.api_key)); + // Use credential for auth header (supports API keys, OAuth tokens, etc.) + if let Ok(header_value) = self.config.credential.auth_header() { + request = request.header("authorization", header_value); } let response = match request.json(body).send() { @@ -296,6 +323,16 @@ impl OpenAiCompatibleProvider { .text() .map_err(|e| CoreError::Provider(format!("failed to read response: {e}")))?; + // On 401, try refreshing the credential once. + if status.as_u16() == 401 + && !refreshed_on_401 + && self.config.credential.needs_refresh() + && self.config.credential.refresh().is_ok() + { + refreshed_on_401 = true; + continue; + } + // Retry on transient errors if (status.as_u16() == 429 || status.is_server_error()) && attempt < max_retries { last_error = Some(format!("{status}: {response_text}")); @@ -351,9 +388,9 @@ impl OpenAiCompatibleProvider { .post(&url) .header("content-type", "application/json"); - if !self.config.api_key.is_empty() { - http_request = - http_request.header("authorization", format!("Bearer {}", self.config.api_key)); + // Add authorization header from credential (API key or OAuth token) + if let Ok(auth_header) = self.config.credential.auth_header() { + http_request = http_request.header("authorization", auth_header); } let response = http_request @@ -644,7 +681,7 @@ mod tests { fn test_config() -> OpenAiConfig { OpenAiConfig { - api_key: "test-key".to_string(), + credential: Arc::new(ApiKeyCredential::new("test-key".to_string())), model: "gpt-4o".to_string(), max_tokens: 4096, base_url: "http://localhost:8080".to_string(), @@ -809,15 +846,17 @@ mod tests { #[test] fn ollama_config_defaults() { // This just tests the config structure, not actual env vars + let cred = ApiKeyCredential::new(String::new()); let config = OpenAiConfig { - api_key: String::new(), + credential: Arc::new(cred), model: "llama3.2".to_string(), max_tokens: 4096, base_url: "http://localhost:11434".to_string(), provider_name: "ollama".to_string(), enable_streaming: true, }; - assert!(config.api_key.is_empty()); + // Empty credential returns error on auth_header (expected for local servers) + assert!(config.credential.auth_header().is_err()); assert_eq!(config.base_url, "http://localhost:11434"); assert!(config.enable_streaming); } diff --git a/crates/arcan/src/main.rs b/crates/arcan/src/main.rs index 283247f..001ec2e 100644 --- a/crates/arcan/src/main.rs +++ b/crates/arcan/src/main.rs @@ -198,6 +198,20 @@ enum Command { Status, /// Stop the running daemon Stop, + /// Authenticate with an LLM provider via OAuth + Login { + /// Provider to authenticate with (e.g. "openai") + provider: String, + + /// Use device code flow instead of browser-based PKCE (for headless environments) + #[arg(long)] + device: bool, + }, + /// Remove stored OAuth credentials for a provider + Logout { + /// Provider to log out of (e.g. "openai") + provider: String, + }, } #[derive(Subcommand)] @@ -292,6 +306,18 @@ async fn resolve_session( "default".to_owned() } +/// Try to create an OpenAI provider from stored OAuth credentials. +fn try_openai_oauth_provider() -> Option> { + let tokens = arcan_provider::oauth::load_tokens("openai").ok()?; + let credential = Arc::new(arcan_provider::oauth::OAuthCredential::openai(tokens)); + let model = std::env::var("OPENAI_MODEL").unwrap_or_else(|_| "gpt-4o".to_string()); + let config = arcan_provider::openai::OpenAiConfig::from_oauth(credential, model.clone()); + tracing::info!(%model, "Provider: OpenAI (OAuth)"); + Some(Arc::new( + arcan_provider::openai::OpenAiCompatibleProvider::new(config), + )) +} + /// Build provider from resolved configuration. fn build_provider(resolved: &ResolvedConfig) -> anyhow::Result> { let pc = resolved.provider_config.as_ref(); @@ -301,7 +327,11 @@ fn build_provider(resolved: &ResolvedConfig) -> anyhow::Result tracing::warn!("Provider: MockProvider (forced via config)"); Ok(Arc::new(MockProvider)) } - "openai" => { + "openai" | "codex" | "openai-codex" => { + // Try OAuth credential first, then fall back to env var / resolved config. + if let Some(p) = try_openai_oauth_provider() { + return Ok(p); + } let config = arcan_provider::openai::OpenAiConfig::openai_from_resolved( resolved.model.as_deref(), pc.and_then(|p| p.base_url.as_deref()), @@ -338,6 +368,8 @@ fn build_provider(resolved: &ResolvedConfig) -> anyhow::Result if let Ok(config) = AnthropicConfig::from_env() { tracing::info!(model = %config.model, "Provider: Anthropic (auto-detected)"); Ok(Arc::new(AnthropicProvider::new(config))) + } else if let Some(p) = try_openai_oauth_provider() { + Ok(p) } else if let Ok(config) = arcan_provider::openai::OpenAiConfig::openai_from_env() { tracing::info!(model = %config.model, "Provider: OpenAI (auto-detected)"); Ok(Arc::new( @@ -670,6 +702,34 @@ async fn run_status(data_dir: &Path, resolved: &ResolvedConfig) -> anyhow::Resul Ok(()) } +fn run_login(provider: &str, device: bool) -> anyhow::Result<()> { + match provider { + "openai" | "codex" | "openai-codex" => { + if device { + arcan_provider::oauth::device_login_openai().map_err(|e| anyhow::anyhow!("{e}"))?; + } else { + arcan_provider::oauth::pkce_login_openai().map_err(|e| anyhow::anyhow!("{e}"))?; + } + Ok(()) + } + _ => Err(anyhow::anyhow!( + "Unknown provider '{provider}'. Supported: openai" + )), + } +} + +#[allow(clippy::print_stderr)] +fn run_logout(provider: &str) -> anyhow::Result<()> { + // Normalize provider name for credential lookup. + let normalized = match provider { + "codex" | "openai-codex" => "openai", + other => other, + }; + arcan_provider::oauth::remove_tokens(normalized).map_err(|e| anyhow::anyhow!("{e}"))?; + eprintln!("Logged out of {provider}"); + Ok(()) +} + fn main() -> anyhow::Result<()> { let cli = Cli::parse(); let data_dir = resolve_data_dir(&cli.data_dir)?; @@ -722,6 +782,8 @@ fn main() -> anyhow::Result<()> { .build()?; runtime.block_on(run_chat(data_dir, &resolved, session, url)) } + Some(Command::Login { provider, device }) => run_login(&provider, device), + Some(Command::Logout { provider }) => run_logout(&provider), Some(Command::Run { message, session,