diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 357ef42..b967643 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -11,6 +11,7 @@ Prereqs: Clone and install tools: ```bash +mise trust mise install ``` diff --git a/Cargo.lock b/Cargo.lock index 8cd98a6..7fdb126 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -108,6 +108,15 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + [[package]] name = "braintrust-sdk-rust" version = "0.1.0-alpha.2" @@ -133,17 +142,22 @@ name = "bt" version = "0.1.2" dependencies = [ "anyhow", + "base64", "braintrust-sdk-rust", + "chrono", "clap", "crossterm", "dialoguer", "dotenvy", "indicatif", "open", + "rand 0.8.5", "ratatui", "reqwest", "serde", "serde_json", + "serial_test", + "sha2", "strip-ansi-escapes", "tokio", "unicode-width 0.1.14", @@ -292,6 +306,15 @@ version = "0.8.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + [[package]] name = "crossterm" version = "0.28.1" @@ -317,6 +340,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "crypto-common" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "darling" version = "0.23.0" @@ -365,6 +398,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "displaydoc" version = "0.2.5" @@ -535,6 +578,16 @@ dependencies = [ "thread_local", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.17" @@ -1137,7 +1190,7 @@ dependencies = [ "bytes", "getrandom 0.3.4", "lru-slab", - "rand", + "rand 0.9.2", "ring", "rustc-hash", "rustls", @@ -1178,14 +1231,35 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + [[package]] name = "rand" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "rand_chacha", - "rand_core", + "rand_chacha 0.9.0", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", ] [[package]] @@ -1195,7 +1269,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.5", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.17", ] [[package]] @@ -1371,12 +1454,27 @@ version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" +[[package]] +name = "scc" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc" +dependencies = [ + "sdd", +] + [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sdd" +version = "3.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" + [[package]] name = "serde" version = "1.0.228" @@ -1443,6 +1541,43 @@ dependencies = [ "serde", ] +[[package]] +name = "serial_test" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d0b343e184fc3b7bb44dff0705fffcf4b3756ba6aff420dddd8b24ca145e555" +dependencies = [ + "futures-executor", + "futures-util", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f50427f258fb77356e4cd4aa0e87e2bd2c66dbcee41dc405282cae2bfc26c83" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shell-words" version = "1.1.1" @@ -1813,6 +1948,12 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicode-ident" version = "1.0.23" @@ -1896,6 +2037,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + [[package]] name = "vte" version = "0.14.1" diff --git a/Cargo.toml b/Cargo.toml index 9cd362c..4469134 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ repository = "https://github.com/braintrustdata/bt" [dependencies] anyhow = "1.0.89" braintrust-sdk-rust = { git = "https://github.com/braintrustdata/braintrust-sdk-rust", rev = "33ee4c8b8c1e4cd11961f7572100298caa3a39d0" } +chrono = { version = "0.4", features = ["serde"] } clap = { version = "4.5.20", features = ["derive", "env"] } crossterm = "0.28.1" indicatif = "0.17.8" @@ -25,6 +26,12 @@ dialoguer = { version = "0.11", features = ["fuzzy-select"] } dotenvy = "0.15" open = "5" urlencoding = "2" +rand = "0.8" +sha2 = "0.10" +base64 = "0.22" + +[dev-dependencies] +serial_test = "3.2" [profile.dist] inherits = "release" diff --git a/src/args.rs b/src/args.rs index d5e928e..f29cdc9 100644 --- a/src/args.rs +++ b/src/args.rs @@ -11,6 +11,10 @@ pub struct BaseArgs { #[arg(short = 'p', long, env = "BRAINTRUST_DEFAULT_PROJECT")] pub project: Option, + /// Auth profile to use + #[arg(long, env = "BRAINTRUST_PROFILE", default_value = "DEFAULT")] + pub profile: String, + /// Override stored API key (or via BRAINTRUST_API_KEY) #[arg(long, env = "BRAINTRUST_API_KEY")] pub api_key: Option, diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..1b25ed3 --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,924 @@ +use anyhow::{anyhow, Context, Result}; +use chrono::{Duration, Utc}; +use clap::{Args, Subcommand}; +use dialoguer::Input; +use serde::{Deserialize, Serialize}; +use std::net::TcpListener; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpListener as TokioTcpListener; + +use crate::config::{self, Profile}; + +#[derive(Debug, Clone, Args)] +pub struct AuthArgs { + #[command(subcommand)] + pub command: AuthSubcommand, +} + +#[derive(Debug, Clone, Subcommand)] +pub enum AuthSubcommand { + /// Log in to Braintrust using OAuth2 or API key + Login(LoginArgs), + /// Display current access token and TTL + Token(TokenArgs), + /// Log out (remove profile) + Logout(LogoutArgs), +} + +#[derive(Debug, Clone, Args)] +pub struct LoginArgs { + /// Profile name to use + #[arg(long, default_value = "DEFAULT")] + pub profile: String, + + /// API URL (defaults to https://api.braintrust.dev) + #[arg(long)] + pub api_url: Option, + + /// Use API key instead of OAuth2 (for headless/CI) + #[arg(long)] + pub api_key: bool, + + /// Optional default project for this profile + #[arg(long)] + pub project: Option, +} + +#[derive(Debug, Clone, Args)] +pub struct TokenArgs { + /// Profile name to use + #[arg(long, default_value = "DEFAULT")] + pub profile: String, + + /// Output as JSON + #[arg(short = 'j', long)] + pub json: bool, + + /// Show full token (default: masked) + #[arg(long)] + pub show: bool, +} + +#[derive(Debug, Clone, Args)] +pub struct LogoutArgs { + /// Profile name to remove + #[arg(long, default_value = "DEFAULT")] + pub profile: String, +} + +#[derive(Debug, Deserialize)] +struct OAuthDiscovery { + authorization_endpoint: String, + token_endpoint: String, +} + +#[derive(Debug, Serialize)] +struct TokenRequest { + grant_type: String, + code: String, + redirect_uri: String, + client_id: String, + code_verifier: String, +} + +#[derive(Debug, Serialize)] +struct RefreshTokenRequest { + grant_type: String, + refresh_token: String, + client_id: String, +} + +#[derive(Debug, Deserialize)] +pub struct TokenResponse { + pub access_token: String, + #[serde(default)] + pub refresh_token: Option, + #[serde(default)] + pub expires_in: Option, +} + +pub async fn run(args: AuthArgs) -> Result<()> { + match args.command { + AuthSubcommand::Login(args) => run_login(args).await, + AuthSubcommand::Token(args) => run_token(args), + AuthSubcommand::Logout(args) => run_logout(args), + } +} + +async fn run_login(args: LoginArgs) -> Result<()> { + // Prompt for API URL if not provided + let api_url = if let Some(url) = args.api_url { + url + } else { + Input::::new() + .with_prompt("API URL") + .default("https://api.braintrust.dev".to_string()) + .interact_text()? + }; + + let api_url = api_url.trim_end_matches('/').to_string(); + + // Choose authentication method + if args.api_key { + login_with_api_key(&args.profile, &api_url, args.project.as_deref()).await?; + } else { + login_with_oauth2(&args.profile, &api_url, args.project.as_deref()).await?; + } + + println!("✓ Successfully logged in to profile '{}'", args.profile); + Ok(()) +} + +async fn login_with_api_key( + profile_name: &str, + api_url: &str, + project: Option<&str>, +) -> Result<()> { + let api_key: String = dialoguer::Password::new() + .with_prompt("Enter your API key") + .interact()?; + + let api_key = api_key.trim().to_string(); + + if api_key.is_empty() { + anyhow::bail!("API key cannot be empty"); + } + + // Try to fetch org info + let org_name = fetch_org_name(api_url, &api_key).await.ok(); + + let profile = Profile { + api_url: api_url.to_string(), + access_token: api_key, + refresh_token: None, + expires_at: None, + org_name, + project: project.map(|s| s.to_string()), + }; + + config::save_profile(profile_name, profile)?; + Ok(()) +} + +async fn login_with_oauth2(profile_name: &str, api_url: &str, project: Option<&str>) -> Result<()> { + // Discover OAuth endpoints + let discovery = discover_oauth_endpoints(api_url).await?; + + // Generate PKCE challenge + let code_verifier = generate_code_verifier(); + let code_challenge = generate_code_challenge(&code_verifier); + + // Start local server for callback + let listener = TcpListener::bind("127.0.0.1:0") + .context("failed to bind local server for OAuth callback")?; + let port = listener.local_addr()?.port(); + let redirect_uri = format!("http://127.0.0.1:{}/callback", port); + + // Build authorization URL + let client_id = "bt-cli"; // TODO: Use proper client ID from Braintrust + let auth_url = format!( + "{}?response_type=code&client_id={}&redirect_uri={}&code_challenge={}&code_challenge_method=S256&scope=openid%20profile%20email", + discovery.authorization_endpoint, + urlencoding::encode(client_id), + urlencoding::encode(&redirect_uri), + urlencoding::encode(&code_challenge) + ); + + println!("Opening browser for authentication..."); + println!("If browser doesn't open, visit: {}", auth_url); + + // Open browser + if let Err(e) = open::that(&auth_url) { + eprintln!("Warning: failed to open browser: {}", e); + } + + // Wait for callback + println!("Waiting for authentication callback..."); + let code = receive_callback(listener).await?; + + // Exchange code for tokens + let token_response = exchange_code_for_token( + &discovery.token_endpoint, + &code, + &redirect_uri, + client_id, + &code_verifier, + ) + .await?; + + // Calculate expiry + let expires_at = token_response + .expires_in + .map(|secs| Utc::now() + Duration::seconds(secs)); + + // OAuth2 JWTs work without x-bt-org-name header, so we leave org_name as None + let profile = Profile { + api_url: api_url.to_string(), + access_token: token_response.access_token, + refresh_token: token_response.refresh_token, + expires_at, + org_name: None, // OAuth2 tokens don't need this + project: project.map(|s| s.to_string()), + }; + + config::save_profile(profile_name, profile)?; + Ok(()) +} + +async fn discover_oauth_endpoints(api_url: &str) -> Result { + let discovery_url = format!("{}/.well-known/oauth-authorization-server", api_url); + + let client = reqwest::Client::new(); + let response = client + .get(&discovery_url) + .send() + .await + .context("failed to discover OAuth endpoints")?; + + if !response.status().is_success() { + anyhow::bail!( + "OAuth discovery failed ({}): {}", + response.status(), + response.text().await.unwrap_or_default() + ); + } + + response + .json::() + .await + .context("failed to parse OAuth discovery response") +} + +fn generate_code_verifier() -> String { + use rand::Rng; + const CHARSET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~"; + let mut rng = rand::thread_rng(); + (0..128) + .map(|_| { + let idx = rng.gen_range(0..CHARSET.len()); + CHARSET[idx] as char + }) + .collect() +} + +fn generate_code_challenge(verifier: &str) -> String { + use sha2::{Digest, Sha256}; + let mut hasher = Sha256::new(); + hasher.update(verifier.as_bytes()); + let hash = hasher.finalize(); + base64_url_encode(&hash) +} + +fn base64_url_encode(input: &[u8]) -> String { + use base64::Engine; + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(input) +} + +async fn receive_callback(listener: TcpListener) -> Result { + listener.set_nonblocking(true)?; + let listener = TokioTcpListener::from_std(listener)?; + + let (mut stream, _) = listener.accept().await?; + + let mut buffer = vec![0u8; 4096]; + let n = stream.read(&mut buffer).await?; + let request = String::from_utf8_lossy(&buffer[..n]); + + // Parse the code from the request + let code = parse_code_from_request(&request) + .ok_or_else(|| anyhow!("failed to parse authorization code from callback"))?; + + // Send success response + let response = "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\n\r\n\ +

Authentication successful!

\ +

You can close this window and return to the terminal.

"; + stream.write_all(response.as_bytes()).await?; + stream.flush().await?; + + Ok(code) +} + +fn parse_code_from_request(request: &str) -> Option { + // Parse GET /callback?code=... HTTP/1.1 + let first_line = request.lines().next()?; + let parts: Vec<&str> = first_line.split_whitespace().collect(); + if parts.len() < 2 { + return None; + } + + let path = parts[1]; + if let Some(query_start) = path.find('?') { + let query = &path[query_start + 1..]; + for param in query.split('&') { + if let Some((key, value)) = param.split_once('=') { + if key == "code" { + return Some(urlencoding::decode(value).ok()?.into_owned()); + } + } + } + } + None +} + +async fn exchange_code_for_token( + token_endpoint: &str, + code: &str, + redirect_uri: &str, + client_id: &str, + code_verifier: &str, +) -> Result { + let client = reqwest::Client::new(); + + let params = TokenRequest { + grant_type: "authorization_code".to_string(), + code: code.to_string(), + redirect_uri: redirect_uri.to_string(), + client_id: client_id.to_string(), + code_verifier: code_verifier.to_string(), + }; + + let response = client + .post(token_endpoint) + .form(¶ms) + .send() + .await + .context("failed to exchange code for token")?; + + if !response.status().is_success() { + anyhow::bail!( + "token exchange failed ({}): {}", + response.status(), + response.text().await.unwrap_or_default() + ); + } + + response + .json::() + .await + .context("failed to parse token response") +} + +pub async fn refresh_token_if_needed(profile_name: &str, mut profile: Profile) -> Result { + // Check if token is expired + if let Some(expires_at) = profile.expires_at { + if Utc::now() >= expires_at { + // Token is expired, try to refresh + if let Some(refresh_token) = &profile.refresh_token { + println!("Access token expired, refreshing..."); + let token_response = refresh_access_token(&profile.api_url, refresh_token).await?; + + profile.access_token = token_response.access_token; + if let Some(new_refresh) = token_response.refresh_token { + profile.refresh_token = Some(new_refresh); + } + if let Some(expires_in) = token_response.expires_in { + profile.expires_at = Some(Utc::now() + Duration::seconds(expires_in)); + } + + // Save updated profile + config::save_profile(profile_name, profile.clone())?; + println!("✓ Token refreshed successfully"); + } else { + anyhow::bail!("Access token expired and no refresh token available. Please run `bt auth login --profile {}`", profile_name); + } + } + } + + Ok(profile) +} + +async fn refresh_access_token(api_url: &str, refresh_token: &str) -> Result { + // Discover token endpoint + let discovery = discover_oauth_endpoints(api_url).await?; + + let client = reqwest::Client::new(); + let client_id = "bt-cli"; // TODO: Use proper client ID + + let params = RefreshTokenRequest { + grant_type: "refresh_token".to_string(), + refresh_token: refresh_token.to_string(), + client_id: client_id.to_string(), + }; + + let response = client + .post(&discovery.token_endpoint) + .form(¶ms) + .send() + .await + .context("failed to refresh token")?; + + if !response.status().is_success() { + anyhow::bail!( + "token refresh failed ({}): {}", + response.status(), + response.text().await.unwrap_or_default() + ); + } + + response + .json::() + .await + .context("failed to parse refresh token response") +} + +async fn fetch_org_name(api_url: &str, token: &str) -> Result { + // Try to discover the user's org name automatically + if let Ok(org) = discover_user_org(api_url, token).await { + println!("✓ Discovered organization: {}", org); + return Ok(org); + } + + // Try to get a list of available orgs + let orgs = fetch_available_orgs(api_url, token).await; + + if !orgs.is_empty() { + if orgs.len() == 1 { + // Only one org found, use it automatically + println!("✓ Found organization: {}", orgs[0]); + return Ok(orgs[0].clone()); + } + + // Multiple orgs found, let user select + println!("Found {} organizations:", orgs.len()); + let selection = dialoguer::Select::new() + .with_prompt("Select your organization") + .items(&orgs) + .default(0) + .interact()?; + + return Ok(orgs[selection].clone()); + } + + // Fallback: prompt user for org name with helpful guidance + eprintln!("\nCould not automatically discover your organization."); + eprintln!("You can find your org name in the Braintrust web app URL:"); + eprintln!(" https://www.braintrust.dev/app/YOUR-ORG-NAME/..."); + eprintln!(); + + let org_name: String = dialoguer::Input::new() + .with_prompt("Organization name") + .allow_empty(false) + .interact_text()?; + + Ok(org_name.trim().to_string()) +} + +async fn discover_user_org(api_url: &str, token: &str) -> Result { + let client = reqwest::Client::new(); + + // Strategy 1: Fetch API keys to get org_id (works as org_name in headers!) + let response = client + .get(format!("{}/v1/api_key", api_url)) + .bearer_auth(token) + .send() + .await?; + + if response.status().is_success() { + #[derive(Deserialize)] + struct ApiKeyResponse { + objects: Vec, + } + + #[derive(Deserialize)] + struct ApiKeyObject { + org_id: String, + } + + if let Ok(keys) = response.json::().await { + if let Some(key) = keys.objects.first() { + // The org_id can be used in the x-bt-org-name header! + return Ok(key.org_id.clone()); + } + } + } + + // Strategy 2: Try to fetch a project - error messages sometimes include [user_org=...] + let response = client + .get(format!( + "{}/v1/organization/00000000-0000-0000-0000-000000000000", + api_url + )) + .bearer_auth(token) + .send() + .await?; + + let error_text = response.text().await.unwrap_or_default(); + + // Extract org name from error message like [user_org=Braintrust Demos] + if let Some(org_start) = error_text.find("user_org=") { + let org_part = &error_text[org_start + 9..]; + if let Some(org_end) = org_part.find(']') { + return Ok(org_part[..org_end].to_string()); + } + } + + anyhow::bail!("Could not extract org from API response") +} + +async fn fetch_available_orgs(api_url: &str, token: &str) -> Vec { + let client = reqwest::Client::new(); + let mut orgs = std::collections::HashSet::new(); + + // Try to fetch projects and extract org names + let response = client + .get(format!("{}/v1/project?limit=100", api_url)) + .bearer_auth(token) + .send() + .await; + + if let Ok(resp) = response { + if resp.status().is_success() { + #[derive(Deserialize)] + struct ProjectResponse { + objects: Vec, + } + + #[derive(Deserialize)] + struct ProjectObject { + org_id: String, + } + + if let Ok(projects) = resp.json::().await { + // For each unique org_id, try to get the org name + let unique_org_ids: std::collections::HashSet<_> = + projects.objects.iter().map(|p| p.org_id.clone()).collect(); + + for org_id in unique_org_ids { + // Try to fetch org - even if it fails, we might get org name from error + if let Ok(org_resp) = client + .get(format!("{}/v1/organization/{}", api_url, org_id)) + .bearer_auth(token) + .send() + .await + { + if let Ok(error_text) = org_resp.text().await { + // Try to extract from error message + if let Some(org_start) = error_text.find("user_org=") { + let org_part = &error_text[org_start + 9..]; + if let Some(org_end) = org_part.find(']') { + orgs.insert(org_part[..org_end].to_string()); + } + } + } + } + } + } + } + } + + let mut org_list: Vec = orgs.into_iter().collect(); + org_list.sort(); + org_list +} + +fn run_token(args: TokenArgs) -> Result<()> { + let profile = config::get_profile(&args.profile)?.ok_or_else(|| { + anyhow!( + "Profile '{}' not found. Run `bt auth login --profile {}`", + args.profile, + args.profile + ) + })?; + + if args.json { + let ttl_seconds = profile + .expires_at + .map(|exp| (exp - Utc::now()).num_seconds()); + + let token_value = if args.show { + profile.access_token.clone() + } else { + mask_token(&profile.access_token) + }; + + let output = serde_json::json!({ + "token": token_value, + "expires_at": profile.expires_at, + "ttl_seconds": ttl_seconds, + }); + + println!("{}", serde_json::to_string_pretty(&output)?); + } else { + let token_display = if args.show { + &profile.access_token + } else { + &mask_token(&profile.access_token) + }; + + println!("Token: {}", token_display); + + if let Some(expires_at) = profile.expires_at { + let ttl = expires_at - Utc::now(); + if ttl.num_seconds() > 0 { + println!("Expires: {} (in {} seconds)", expires_at, ttl.num_seconds()); + } else { + println!("Expires: {} (EXPIRED)", expires_at); + } + } else { + println!("Expires: Never"); + } + } + + Ok(()) +} + +fn run_logout(args: LogoutArgs) -> Result<()> { + let removed = config::delete_profile(&args.profile)?; + + if removed { + println!("✓ Logged out from profile '{}'", args.profile); + } else { + println!("Profile '{}' not found (already logged out)", args.profile); + } + + Ok(()) +} + +fn mask_token(token: &str) -> String { + if token.len() <= 8 { + return "***".to_string(); + } + format!("{}...{}", &token[..4], &token[token.len() - 4..]) +} + +#[cfg(test)] +mod tests { + use super::*; + use serial_test::serial; + use std::time::{SystemTime, UNIX_EPOCH}; + + #[test] + fn test_generate_code_verifier() { + let verifier = generate_code_verifier(); + // PKCE verifier should be 43-128 characters + assert!(verifier.len() >= 43 && verifier.len() <= 128); + // Should be URL-safe characters only (RFC 7636 allows: A-Z a-z 0-9 - . _ ~) + assert!(verifier.chars().all(|c| c.is_alphanumeric() + || c == '-' + || c == '.' + || c == '_' + || c == '~')); + } + + #[test] + fn test_generate_code_verifier_randomness() { + let v1 = generate_code_verifier(); + let v2 = generate_code_verifier(); + // Two calls should produce different verifiers + assert_ne!(v1, v2); + } + + #[test] + fn test_generate_code_challenge() { + let verifier = "test_verifier_string"; + let challenge = generate_code_challenge(verifier); + + // Should be base64url encoded (no padding) + assert!(!challenge.contains('=')); + assert!(!challenge.contains('+')); + assert!(!challenge.contains('/')); + + // SHA256 hash should produce 32 bytes -> 43 base64url chars (without padding) + assert_eq!(challenge.len(), 43); + } + + #[test] + fn test_generate_code_challenge_deterministic() { + let verifier = "same_verifier"; + let c1 = generate_code_challenge(verifier); + let c2 = generate_code_challenge(verifier); + // Same verifier should produce same challenge + assert_eq!(c1, c2); + } + + #[test] + fn test_base64_url_encode() { + let input = b"hello world"; + let encoded = base64_url_encode(input); + + // Should not contain standard base64 chars + assert!(!encoded.contains('+')); + assert!(!encoded.contains('/')); + assert!(!encoded.contains('=')); + } + + #[test] + fn test_base64_url_encode_known_value() { + // SHA256 of "test" in base64url + let input = b"\x9f\x86\xd0\x81\x88\x4c\x7d\x65\x9a\x2f\xea\xa0\xc5\x5a\xd0\x15\xa3\xbf\x4f\x1b\x2b\x0b\x82\x2c\xd1\x5d\x6c\x15\xb0\xf0\x0a\x08"; + let encoded = base64_url_encode(input); + assert_eq!(encoded, "n4bQgYhMfWWaL-qgxVrQFaO_TxsrC4Is0V1sFbDwCgg"); + } + + #[test] + fn test_parse_code_from_request_valid() { + let request = "GET /?code=test_code_12345&state=xyz HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let code = parse_code_from_request(request); + assert_eq!(code, Some("test_code_12345".to_string())); + } + + #[test] + fn test_parse_code_from_request_no_code() { + let request = "GET /?state=xyz HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let code = parse_code_from_request(request); + assert_eq!(code, None); + } + + #[test] + fn test_parse_code_from_request_with_multiple_params() { + let request = "GET /?foo=bar&code=my_auth_code&state=abc HTTP/1.1\r\n"; + let code = parse_code_from_request(request); + assert_eq!(code, Some("my_auth_code".to_string())); + } + + #[test] + fn test_parse_code_from_request_invalid() { + let request = "POST / HTTP/1.1\r\nHost: localhost\r\n\r\n"; + let code = parse_code_from_request(request); + assert_eq!(code, None); + } + + #[test] + fn test_mask_token_long() { + let token = "brt_1234567890abcdef"; + let masked = mask_token(token); + assert_eq!(masked, "brt_...cdef"); + } + + #[test] + fn test_mask_token_short() { + let token = "short"; + let masked = mask_token(token); + assert_eq!(masked, "***"); + } + + #[test] + fn test_mask_token_exact_8() { + let token = "12345678"; + let masked = mask_token(token); + assert_eq!(masked, "***"); + } + + #[test] + fn test_mask_token_jwt() { + let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIn0.dozjgNryP4J3jVmNHl0w5N_XgL0n3I9PlFUP0THsR8U"; + let masked = mask_token(token); + assert_eq!(masked, "eyJh...sR8U"); + } + + #[tokio::test] + #[serial] + async fn test_refresh_token_if_needed_not_expired() { + // Create a temp config + let config_path = std::env::temp_dir().join(format!( + "bt-auth-test-{}-{}.json", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() + )); + std::env::set_var("BT_CONFIG", &config_path); + + let future_time = Utc::now() + Duration::hours(1); + let profile = Profile { + api_url: "https://api.test.com".to_string(), + access_token: "valid_token".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: Some(future_time), + org_name: None, + project: None, + }; + + config::save_profile("test", profile.clone()).unwrap(); + + // Should not refresh since token is still valid + let result = refresh_token_if_needed("test", profile.clone()).await; + assert!(result.is_ok()); + let refreshed = result.unwrap(); + + // Token should be unchanged + assert_eq!(refreshed.access_token, "valid_token"); + + // Cleanup + std::fs::remove_file(&config_path).ok(); + std::env::remove_var("BT_CONFIG"); + } + + #[tokio::test] + #[serial] + async fn test_refresh_token_if_needed_expired_no_refresh_token() { + let config_path = std::env::temp_dir().join(format!( + "bt-auth-test-{}-{}.json", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() + )); + std::env::set_var("BT_CONFIG", &config_path); + + let past_time = Utc::now() - Duration::hours(1); + let profile = Profile { + api_url: "https://api.test.com".to_string(), + access_token: "expired_token".to_string(), + refresh_token: None, // No refresh token (API key) + expires_at: Some(past_time), + org_name: Some("test-org".to_string()), + project: None, + }; + + config::save_profile("test", profile.clone()).unwrap(); + + // Should fail since token is expired and no refresh token + let result = refresh_token_if_needed("test", profile).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("expired")); + + // Cleanup + std::fs::remove_file(&config_path).ok(); + std::env::remove_var("BT_CONFIG"); + } + + #[tokio::test] + #[serial] + async fn test_refresh_token_if_needed_no_expiry() { + let config_path = std::env::temp_dir().join(format!( + "bt-auth-test-{}-{}.json", + std::process::id(), + SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos() + )); + std::env::set_var("BT_CONFIG", &config_path); + + let profile = Profile { + api_url: "https://api.test.com".to_string(), + access_token: "api_key_token".to_string(), + refresh_token: None, + expires_at: None, // API keys don't expire + org_name: Some("test-org".to_string()), + project: None, + }; + + config::save_profile("test", profile.clone()).unwrap(); + + // Should pass through without refresh (API keys don't expire) + let result = refresh_token_if_needed("test", profile.clone()).await; + assert!(result.is_ok()); + let refreshed = result.unwrap(); + assert_eq!(refreshed.access_token, "api_key_token"); + + // Cleanup + std::fs::remove_file(&config_path).ok(); + std::env::remove_var("BT_CONFIG"); + } + + #[test] + fn test_oauth_discovery_deserialization() { + let json = r#"{ + "authorization_endpoint": "https://auth.example.com/authorize", + "token_endpoint": "https://auth.example.com/token" + }"#; + + let discovery: Result = serde_json::from_str(json); + assert!(discovery.is_ok()); + let discovery = discovery.unwrap(); + assert_eq!( + discovery.authorization_endpoint, + "https://auth.example.com/authorize" + ); + assert_eq!(discovery.token_endpoint, "https://auth.example.com/token"); + } + + #[test] + fn test_token_response_deserialization() { + let json = r#"{ + "access_token": "at_12345", + "refresh_token": "rt_67890", + "expires_in": 3600 + }"#; + + let response: Result = serde_json::from_str(json); + assert!(response.is_ok()); + let response = response.unwrap(); + assert_eq!(response.access_token, "at_12345"); + assert_eq!(response.refresh_token, Some("rt_67890".to_string())); + assert_eq!(response.expires_in, Some(3600)); + } + + #[test] + fn test_token_response_without_refresh() { + let json = r#"{ + "access_token": "at_only" + }"#; + + let response: Result = serde_json::from_str(json); + assert!(response.is_ok()); + let response = response.unwrap(); + assert_eq!(response.access_token, "at_only"); + assert_eq!(response.refresh_token, None); + assert_eq!(response.expires_in, None); + } +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..b3087a2 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,425 @@ +use std::collections::HashMap; +use std::env; +use std::fs; +use std::path::PathBuf; + +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; + +/// Auth profile containing credentials and optional project +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Profile { + pub api_url: String, + pub access_token: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub refresh_token: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_at: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub org_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub project: Option, +} + +/// Top-level config structure +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct BtConfig { + #[serde(default)] + pub profiles: HashMap, +} + +/// Get the bt config directory path (same as receipt directory) +pub fn bt_config_dir() -> Option { + #[cfg(windows)] + { + env::var_os("APPDATA") + .map(PathBuf::from) + .map(|path| path.join("bt")) + } + #[cfg(not(windows))] + { + if let Some(xdg) = env::var_os("XDG_CONFIG_HOME") { + return Some(PathBuf::from(xdg).join("bt")); + } + env::var_os("HOME") + .map(PathBuf::from) + .map(|path| path.join(".config").join("bt")) + } +} + +/// Get the auth config file path +pub fn auth_config_path() -> Option { + // Allow override via BT_CONFIG env var + if let Ok(path) = env::var("BT_CONFIG") { + return Some(PathBuf::from(path)); + } + + bt_config_dir().map(|dir| dir.join("config.json")) +} + +/// Load the config from disk +pub fn load_config() -> Result { + let path = auth_config_path().context("failed to resolve config directory")?; + + if !path.exists() { + return Ok(BtConfig::default()); + } + + let contents = fs::read_to_string(&path) + .with_context(|| format!("failed to read config from {}", path.display()))?; + + let config: BtConfig = serde_json::from_str(&contents) + .with_context(|| format!("failed to parse config from {}", path.display()))?; + + Ok(config) +} + +/// Get a specific profile from config +pub fn get_profile(name: &str) -> Result> { + let config = load_config()?; + Ok(config.profiles.get(name).cloned()) +} + +/// Save a profile to config (updates existing or creates new) +pub fn save_profile(name: &str, profile: Profile) -> Result<()> { + let path = auth_config_path().context("failed to resolve config directory")?; + + // Create parent directory if it doesn't exist + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("failed to create config directory {}", parent.display()))?; + } + + // Load existing config or create new + let mut config = load_config().unwrap_or_default(); + + // Update profile + config.profiles.insert(name.to_string(), profile); + + // Write to disk + let contents = serde_json::to_string_pretty(&config).context("failed to serialize config")?; + + fs::write(&path, contents) + .with_context(|| format!("failed to write config to {}", path.display()))?; + + Ok(()) +} + +/// Delete a profile from config +pub fn delete_profile(name: &str) -> Result { + let path = auth_config_path().context("failed to resolve config directory")?; + + if !path.exists() { + return Ok(false); + } + + let mut config = load_config()?; + let removed = config.profiles.remove(name).is_some(); + + if removed { + let contents = + serde_json::to_string_pretty(&config).context("failed to serialize config")?; + + fs::write(&path, contents) + .with_context(|| format!("failed to write config to {}", path.display()))?; + } + + Ok(removed) +} + +#[cfg(test)] +mod tests { + use super::*; + use serial_test::serial; + use std::sync::atomic::{AtomicU64, Ordering}; + use std::time::{SystemTime, UNIX_EPOCH}; + + static TEST_COUNTER: AtomicU64 = AtomicU64::new(0); + + fn make_temp_config_path() -> PathBuf { + let counter = TEST_COUNTER.fetch_add(1, Ordering::SeqCst); + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system clock before unix epoch") + .as_nanos(); + let thread_id = std::thread::current().id(); + std::env::temp_dir().join(format!( + "bt-config-test-{}-{:?}-{}-{}.json", + std::process::id(), + thread_id, + now, + counter + )) + } + + #[test] + fn config_serialization_roundtrip() { + let mut config = BtConfig::default(); + config.profiles.insert( + "test".to_string(), + Profile { + api_url: "https://api.braintrust.dev".to_string(), + access_token: "brt_test123".to_string(), + refresh_token: Some("refresh_abc".to_string()), + expires_at: Some(Utc::now()), + org_name: Some("my-org".to_string()), + project: Some("my-project".to_string()), + }, + ); + + let json = serde_json::to_string_pretty(&config).unwrap(); + let parsed: BtConfig = serde_json::from_str(&json).unwrap(); + + assert_eq!(parsed.profiles.len(), 1); + assert!(parsed.profiles.contains_key("test")); + } + + #[test] + fn profile_optional_fields() { + let profile = Profile { + api_url: "https://api.braintrust.dev".to_string(), + access_token: "brt_test".to_string(), + refresh_token: None, + expires_at: None, + org_name: None, + project: None, + }; + + let json = serde_json::to_string(&profile).unwrap(); + assert!(!json.contains("refresh_token")); + assert!(!json.contains("expires_at")); + assert!(!json.contains("org_name")); + assert!(!json.contains("project")); + } + + #[test] + #[serial] + fn save_and_load_profile() { + let config_path = make_temp_config_path(); + + // Use a scoped block to ensure env var is set before operations + { + env::set_var("BT_CONFIG", &config_path); + + let profile = Profile { + api_url: "https://api.test.com".to_string(), + access_token: "test_token".to_string(), + refresh_token: Some("refresh_token".to_string()), + expires_at: None, + org_name: Some("test-org".to_string()), + project: Some("test-project".to_string()), + }; + + // Save profile + save_profile("test_profile", profile.clone()).unwrap(); + + // Verify file was created + assert!(config_path.exists(), "Config file should exist after save"); + + // Load it back + let loaded = get_profile("test_profile").unwrap(); + assert!(loaded.is_some(), "Profile should be loaded"); + let loaded = loaded.unwrap(); + assert_eq!(loaded.api_url, profile.api_url); + assert_eq!(loaded.access_token, profile.access_token); + assert_eq!(loaded.org_name, profile.org_name); + assert_eq!(loaded.project, profile.project); + } + + // Cleanup + fs::remove_file(&config_path).ok(); + env::remove_var("BT_CONFIG"); + } + + #[test] + #[serial] + fn get_nonexistent_profile() { + let config_path = make_temp_config_path(); + env::set_var("BT_CONFIG", &config_path); + + let result = get_profile("nonexistent").unwrap(); + assert!(result.is_none()); + + // Cleanup + env::remove_var("BT_CONFIG"); + } + + #[test] + #[serial] + fn save_multiple_profiles() { + let config_path = make_temp_config_path(); + + { + env::set_var("BT_CONFIG", &config_path); + + let profile1 = Profile { + api_url: "https://api1.com".to_string(), + access_token: "token1".to_string(), + refresh_token: None, + expires_at: None, + org_name: None, + project: None, + }; + + let profile2 = Profile { + api_url: "https://api2.com".to_string(), + access_token: "token2".to_string(), + refresh_token: None, + expires_at: None, + org_name: Some("org2".to_string()), + project: Some("proj2".to_string()), + }; + + save_profile("profile1", profile1).unwrap(); + save_profile("profile2", profile2).unwrap(); + + let config = load_config().unwrap(); + assert_eq!(config.profiles.len(), 2); + assert!(config.profiles.contains_key("profile1")); + assert!(config.profiles.contains_key("profile2")); + } + + // Cleanup + fs::remove_file(&config_path).ok(); + env::remove_var("BT_CONFIG"); + } + + #[test] + #[serial] + fn update_existing_profile() { + let config_path = make_temp_config_path(); + + { + env::set_var("BT_CONFIG", &config_path); + + let profile_v1 = Profile { + api_url: "https://api1.com".to_string(), + access_token: "token1".to_string(), + refresh_token: None, + expires_at: None, + org_name: None, + project: None, + }; + + save_profile("test", profile_v1).unwrap(); + + let profile_v2 = Profile { + api_url: "https://api2.com".to_string(), + access_token: "token2".to_string(), + refresh_token: Some("refresh".to_string()), + expires_at: None, + org_name: Some("new-org".to_string()), + project: Some("new-project".to_string()), + }; + + save_profile("test", profile_v2).unwrap(); + + let loaded = get_profile("test").unwrap(); + assert!(loaded.is_some(), "Updated profile should exist"); + let loaded = loaded.unwrap(); + assert_eq!(loaded.api_url, "https://api2.com"); + assert_eq!(loaded.access_token, "token2"); + assert_eq!(loaded.org_name, Some("new-org".to_string())); + } + + // Cleanup + fs::remove_file(&config_path).ok(); + env::remove_var("BT_CONFIG"); + } + + #[test] + #[serial] + fn delete_existing_profile() { + let config_path = make_temp_config_path(); + + { + env::set_var("BT_CONFIG", &config_path); + + let profile = Profile { + api_url: "https://api.com".to_string(), + access_token: "token".to_string(), + refresh_token: None, + expires_at: None, + org_name: None, + project: None, + }; + + save_profile("to_delete", profile).unwrap(); + assert!( + get_profile("to_delete").unwrap().is_some(), + "Profile should exist after save" + ); + + let deleted = delete_profile("to_delete").unwrap(); + assert!(deleted, "Delete should return true for existing profile"); + assert!( + get_profile("to_delete").unwrap().is_none(), + "Profile should be gone after delete" + ); + } + + // Cleanup + fs::remove_file(&config_path).ok(); + env::remove_var("BT_CONFIG"); + } + + #[test] + #[serial] + fn delete_nonexistent_profile() { + let config_path = make_temp_config_path(); + env::set_var("BT_CONFIG", &config_path); + + let deleted = delete_profile("nonexistent").unwrap(); + assert!(!deleted); + + // Cleanup + env::remove_var("BT_CONFIG"); + } + + #[test] + #[serial] + fn load_config_when_file_missing() { + let config_path = make_temp_config_path(); + env::set_var("BT_CONFIG", &config_path); + + let config = load_config().unwrap(); + assert_eq!(config.profiles.len(), 0); + + // Cleanup + env::remove_var("BT_CONFIG"); + } + + #[test] + fn oauth2_profile_without_org_name() { + let profile = Profile { + api_url: "https://api.braintrust.dev".to_string(), + access_token: "jwt_token_here".to_string(), + refresh_token: Some("refresh_here".to_string()), + expires_at: Some(Utc::now()), + org_name: None, // OAuth2 tokens don't need org_name + project: Some("my-project".to_string()), + }; + + let json = serde_json::to_string(&profile).unwrap(); + assert!(!json.contains("org_name")); + assert!(json.contains("refresh_token")); + assert!(json.contains("project")); + } + + #[test] + fn api_key_profile_with_org_name() { + let profile = Profile { + api_url: "https://api.braintrust.dev".to_string(), + access_token: "brt_apikey".to_string(), + refresh_token: None, // API keys don't have refresh tokens + expires_at: None, + org_name: Some("my-org".to_string()), // API keys need org_name + project: Some("my-project".to_string()), + }; + + let json = serde_json::to_string(&profile).unwrap(); + assert!(json.contains("org_name")); + assert!(!json.contains("refresh_token")); + } +} diff --git a/src/eval.rs b/src/eval.rs index 06ba1ce..12feb50 100644 --- a/src/eval.rs +++ b/src/eval.rs @@ -30,6 +30,7 @@ use ratatui::widgets::{Cell, Row, Table}; use ratatui::Terminal; use crate::args::BaseArgs; +use crate::login::{login, LoginContext}; const MAX_NAME_LENGTH: usize = 40; const WATCH_POLL_INTERVAL: Duration = Duration::from_millis(500); @@ -145,10 +146,12 @@ pub async fn run(base: BaseArgs, args: EvalArgs) -> Result<()> { list: args.list, filter: args.filter, }; + // Resolve credentials and project from profile or explicit flags + let ctx = login(&base).await?; if args.watch { run_eval_files_watch( - &base, + &ctx, args.language, args.runner.clone(), args.files.clone(), @@ -158,7 +161,7 @@ pub async fn run(base: BaseArgs, args: EvalArgs) -> Result<()> { .await } else { let output = run_eval_files_once( - &base, + &ctx, args.language, args.runner.clone(), args.files.clone(), @@ -174,7 +177,7 @@ pub async fn run(base: BaseArgs, args: EvalArgs) -> Result<()> { } async fn run_eval_files_watch( - base: &BaseArgs, + ctx: &LoginContext, language_override: Option, runner_override: Option, files: Vec, @@ -192,7 +195,7 @@ async fn run_eval_files_watch( loop { match run_eval_files_once( - base, + ctx, language_override, runner_override.clone(), files.clone(), @@ -235,7 +238,7 @@ async fn run_eval_files_watch( } async fn run_eval_files_once( - base: &BaseArgs, + ctx: &LoginContext, language_override: Option, runner_override: Option, files: Vec, @@ -284,7 +287,7 @@ async fn run_eval_files_once( EvalLanguage::JavaScript => build_js_command(runner_override, &js_runner, &files)?, }; - cmd.envs(build_env(base)); + cmd.envs(build_env(ctx)); if no_send_logs { cmd.env("BT_EVAL_NO_SEND_LOGS", "1"); cmd.env("BT_EVAL_LOCAL", "1"); @@ -630,17 +633,20 @@ fn format_watch_paths(paths: &[PathBuf]) -> String { } } -fn build_env(base: &BaseArgs) -> Vec<(String, String)> { +fn build_env(ctx: &LoginContext) -> Vec<(String, String)> { let mut envs = Vec::new(); - if let Some(api_key) = base.api_key.as_ref() { - envs.push(("BRAINTRUST_API_KEY".to_string(), api_key.clone())); - } - if let Some(api_url) = base.api_url.as_ref() { - envs.push(("BRAINTRUST_API_URL".to_string(), api_url.clone())); - } - if let Some(project) = base.project.as_ref() { + + // Use resolved API key from LoginContext (from profile or explicit) + envs.push(("BRAINTRUST_API_KEY".to_string(), ctx.api_key().to_string())); + + // Use resolved API URL from LoginContext + envs.push(("BRAINTRUST_API_URL".to_string(), ctx.api_url.clone())); + + // Use resolved project from LoginContext (--project > BRAINTRUST_DEFAULT_PROJECT > profile.project) + if let Some(project) = &ctx.project { envs.push(("BRAINTRUST_DEFAULT_PROJECT".to_string(), project.clone())); } + envs } diff --git a/src/http.rs b/src/http.rs index dff9654..7915380 100644 --- a/src/http.rs +++ b/src/http.rs @@ -21,8 +21,8 @@ impl ApiClient { Ok(Self { http, base_url: ctx.api_url.trim_end_matches('/').to_string(), - api_key: ctx.login.api_key.clone(), - org_name: ctx.login.org_name.clone(), + api_key: ctx.api_key().to_string(), + org_name: ctx.org_name().unwrap_or("").to_string(), }) } diff --git a/src/login.rs b/src/login.rs index c4daa58..c1d8289 100644 --- a/src/login.rs +++ b/src/login.rs @@ -1,19 +1,90 @@ -use anyhow::Result; +use anyhow::{anyhow, Result}; use braintrust_sdk_rust::{BraintrustClient, LoginState}; use crate::args::BaseArgs; +use crate::auth; +use crate::config; +/// Resolved credentials and context for API calls pub struct LoginContext { - pub login: LoginState, + /// Either from SDK LoginState or from profile + pub source: LoginSource, pub api_url: String, pub app_url: String, + /// Resolved project (from --project, env, or profile) + pub project: Option, +} + +pub enum LoginSource { + /// SDK-based login (legacy flow) + Sdk(LoginState), + /// Profile-based login (new OAuth2/API key flow) + Profile { + api_key: String, + org_name: Option, + }, +} + +impl LoginContext { + /// Get the API key for authentication + pub fn api_key(&self) -> &str { + match &self.source { + LoginSource::Sdk(login) => &login.api_key, + LoginSource::Profile { api_key, .. } => api_key, + } + } + + /// Get the organization name if available + pub fn org_name(&self) -> Option<&str> { + match &self.source { + LoginSource::Sdk(login) => Some(login.org_name.as_str()), + LoginSource::Profile { org_name, .. } => org_name.as_deref(), + } + } } pub async fn login(base: &BaseArgs) -> Result { - let mut builder = BraintrustClient::builder().blocking_login(true); + // Priority 1: Explicit API key (--api-key or BRAINTRUST_API_KEY) if let Some(api_key) = &base.api_key { - builder = builder.api_key(api_key); + return login_with_explicit_key(base, api_key).await; + } + + // Priority 2: Profile-based authentication + if let Ok(Some(profile)) = config::get_profile(&base.profile) { + // Refresh token if needed + let profile = auth::refresh_token_if_needed(&base.profile, profile).await?; + + let api_url = profile.api_url.clone(); + let app_url = base.app_url.clone().unwrap_or_else(|| { + api_url + .replace("api.braintrust", "www.braintrust") + .replace("api.braintrustdata", "www.braintrustdata") + }); + + // Resolve effective project: --project > BRAINTRUST_DEFAULT_PROJECT (in base.project) > profile.project + let project = base.project.clone().or_else(|| profile.project.clone()); + + return Ok(LoginContext { + source: LoginSource::Profile { + api_key: profile.access_token, + org_name: profile.org_name, + }, + api_url, + app_url, + project, + }); } + + // Priority 3: Fall back to SDK login (legacy behavior) + // This will likely fail if no credentials are available + login_with_sdk(base).await +} + +async fn login_with_explicit_key(base: &BaseArgs, api_key: &str) -> Result { + let mut builder = BraintrustClient::builder() + .blocking_login(true) + .api_key(api_key); + if let Some(api_url) = &base.api_url { builder = builder.api_url(api_url); } @@ -30,7 +101,6 @@ pub async fn login(base: &BaseArgs) -> Result { .or_else(|| base.api_url.clone()) .unwrap_or_else(|| "https://api.braintrust.dev".to_string()); - // Derive app_url from api_url (api.braintrust.dev -> www.braintrust.dev) let app_url = base.app_url.clone().unwrap_or_else(|| { api_url .replace("api.braintrust", "www.braintrust") @@ -38,8 +108,57 @@ pub async fn login(base: &BaseArgs) -> Result { }); Ok(LoginContext { - login, + source: LoginSource::Sdk(login), + api_url, + app_url, + project: base.project.clone(), + }) +} + +async fn login_with_sdk(base: &BaseArgs) -> Result { + let mut builder = BraintrustClient::builder().blocking_login(true); + + if let Some(api_url) = &base.api_url { + builder = builder.api_url(api_url); + } + if let Some(project) = &base.project { + builder = builder.default_project(project); + } + + let client = builder.build().await; + + // Provide a better error message if SDK login fails + let client = client.map_err(|e| { + anyhow!( + "Failed to authenticate: {}. \ + Try setting BRAINTRUST_API_KEY or run `bt auth login{}`", + e, + if base.profile != "DEFAULT" { + format!(" --profile {}", base.profile) + } else { + String::new() + } + ) + })?; + + let login = client.wait_for_login().await?; + + let api_url = login + .api_url + .clone() + .or_else(|| base.api_url.clone()) + .unwrap_or_else(|| "https://api.braintrust.dev".to_string()); + + let app_url = base.app_url.clone().unwrap_or_else(|| { + api_url + .replace("api.braintrust", "www.braintrust") + .replace("api.braintrustdata", "www.braintrustdata") + }); + + Ok(LoginContext { + source: LoginSource::Sdk(login), api_url, app_url, + project: base.project.clone(), }) } diff --git a/src/main.rs b/src/main.rs index baabc0d..bf2acd6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,8 @@ use clap::{Parser, Subcommand}; use std::ffi::OsString; mod args; +mod auth; +mod config; mod env; #[cfg(unix)] mod eval; @@ -24,6 +26,8 @@ struct Cli { #[derive(Debug, Subcommand)] enum Commands { + /// Authenticate with Braintrust + Auth(auth::AuthArgs), /// Run SQL queries against Braintrust Sql(CLIArgs), #[cfg(unix)] @@ -43,6 +47,7 @@ async fn main() -> Result<()> { let cli = Cli::parse_from(argv); match cli.command { + Commands::Auth(args) => auth::run(args).await?, Commands::Sql(cmd) => sql::run(cmd.base, cmd.args).await?, #[cfg(unix)] Commands::Eval(cmd) => eval::run(cmd.base, cmd.args).await?, diff --git a/src/projects/api.rs b/src/projects/api.rs index fd24213..c858563 100644 --- a/src/projects/api.rs +++ b/src/projects/api.rs @@ -19,13 +19,23 @@ struct ListResponse { } pub async fn list_projects(client: &ApiClient) -> Result> { - let path = format!("/v1/project?org_name={}", encode(client.org_name())); + let path = if client.org_name().is_empty() { + // OAuth2 tokens don't need org_name filtering + "/v1/project".to_string() + } else { + // API keys need org_name for proper filtering + format!("/v1/project?org_name={}", encode(client.org_name())) + }; let list: ListResponse = client.get(&path).await?; Ok(list.objects) } pub async fn create_project(client: &ApiClient, name: &str) -> Result { - let body = serde_json::json!({ "name": name, "org_name": client.org_name() }); + let body = if client.org_name().is_empty() { + serde_json::json!({ "name": name }) + } else { + serde_json::json!({ "name": name, "org_name": client.org_name() }) + }; client.post("/v1/project", &body).await } @@ -35,11 +45,15 @@ pub async fn delete_project(client: &ApiClient, project_id: &str) -> Result<()> } pub async fn get_project_by_name(client: &ApiClient, name: &str) -> Result> { - let path = format!( - "/v1/project?org_name={}&name={}", - encode(client.org_name()), - encode(name) - ); + let path = if client.org_name().is_empty() { + format!("/v1/project?name={}", encode(name)) + } else { + format!( + "/v1/project?org_name={}&name={}", + encode(client.org_name()), + encode(name) + ) + }; let list: ListResponse = client.get(&path).await?; Ok(list.objects.into_iter().next()) } diff --git a/src/projects/mod.rs b/src/projects/mod.rs index b24cc43..f0847a7 100644 --- a/src/projects/mod.rs +++ b/src/projects/mod.rs @@ -76,11 +76,13 @@ pub async fn run(base: BaseArgs, args: ProjectsArgs) -> Result<()> { match args.command { None | Some(ProjectsCommands::List) => { - list::run(&client, &ctx.login.org_name, base.json).await + let org_name = ctx.org_name().unwrap_or(""); + list::run(&client, org_name, base.json).await } Some(ProjectsCommands::Create(a)) => create::run(&client, a.name.as_deref()).await, Some(ProjectsCommands::View(a)) => { - view::run(&client, &ctx.app_url, &ctx.login.org_name, a.name()).await + let org_name = ctx.org_name().unwrap_or(""); + view::run(&client, &ctx.app_url, org_name, a.name()).await } Some(ProjectsCommands::Delete(a)) => delete::run(&client, a.name.as_deref()).await, Some(ProjectsCommands::Switch(a)) => switch::run(&client, a.name.as_deref()).await,