diff --git a/crates/Cargo.lock b/crates/Cargo.lock index fbf817e70..e32dcf486 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -344,6 +344,7 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", + "urlencoding", "uuid", ] diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index 5d986ffa5..d627ba962 100644 --- a/crates/brightstaff/Cargo.toml +++ b/crates/brightstaff/Cargo.toml @@ -38,6 +38,7 @@ tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] } tokio-stream = "0.1" time = { version = "0.3", features = ["formatting", "macros"] } tracing = "0.1" +urlencoding = "2.1.3" tracing-opentelemetry = "0.32.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.0", features = ["v4", "serde"] } diff --git a/crates/brightstaff/src/handlers/routing_service.rs b/crates/brightstaff/src/handlers/routing_service.rs index ec09f06fb..f389bda8b 100644 --- a/crates/brightstaff/src/handlers/routing_service.rs +++ b/crates/brightstaff/src/handlers/routing_service.rs @@ -13,26 +13,43 @@ use tracing::{debug, info, info_span, warn, Instrument}; use super::extract_or_generate_traceparent; use crate::handlers::llm::model_selection::router_chat_get_upstream_model; use crate::router::llm::RouterService; +use crate::router::policy_provider::PolicyProvider; use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name}; const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120; -/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes -/// and parsed preferences. The `routing_policy` field is removed from the JSON -/// before re-serializing so downstream parsers don't see the non-standard field. +/// Extracted routing metadata from a request body. +#[derive(Debug, Default)] +pub struct RoutingMetadata { + /// Inline routing policy (highest priority). + pub inline_policy: Option>, + /// Policy ID for external policy provider lookup. + pub policy_id: Option, + /// Revision for revision-aware caching. + pub revision: Option, +} + +/// Extracts routing metadata from a JSON body, returning the cleaned body bytes +/// and parsed metadata. +/// +/// Fields removed from the JSON before re-serializing: +/// - `routing_policy`: Inline routing preferences +/// - `policy_id`: External policy identifier +/// - `revision`: Policy revision for cache invalidation /// /// If `warn_on_size` is true, logs a warning when the serialized policy exceeds 5KB. -pub fn extract_routing_policy( +pub fn extract_routing_metadata( raw_bytes: &[u8], warn_on_size: bool, -) -> Result<(Bytes, Option>), String> { +) -> Result<(Bytes, RoutingMetadata), String> { let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes) .map_err(|err| format!("Failed to parse JSON: {}", err))?; - let preferences = json_body - .as_object_mut() - .and_then(|obj| obj.remove("routing_policy")) - .and_then(|policy_value| { + let mut metadata = RoutingMetadata::default(); + + if let Some(obj) = json_body.as_object_mut() { + // Extract inline routing_policy (highest priority) + if let Some(policy_value) = obj.remove("routing_policy") { if warn_on_size { let policy_str = serde_json::to_string(&policy_value).unwrap_or_default(); if policy_str.len() > ROUTING_POLICY_SIZE_WARNING_BYTES { @@ -49,17 +66,81 @@ pub fn extract_routing_policy( num_models = prefs.len(), "using inline routing_policy from request body" ); - Some(prefs) + metadata.inline_policy = Some(prefs); } Err(err) => { warn!(error = %err, "failed to parse routing_policy"); - None } } - }); + } + + // Extract policy_id for external policy provider + if let Some(policy_id_value) = obj.remove("policy_id") { + if let Some(policy_id) = policy_id_value.as_str() { + debug!(policy_id = %policy_id, "extracted policy_id from request"); + metadata.policy_id = Some(policy_id.to_string()); + } + } + + // Extract revision for revision-aware caching + if let Some(revision_value) = obj.remove("revision") { + if let Some(revision) = revision_value.as_u64() { + debug!(revision = revision, "extracted revision from request"); + metadata.revision = Some(revision); + } + } + } let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap()); - Ok((bytes, preferences)) + Ok((bytes, metadata)) +} + +/// Resolves routing preferences using the following priority: +/// 1. Inline `routing_policy` in request payload (highest priority) +/// 2. `policy_id` + `revision` → HTTP policy provider (with cache) +/// 3. None (fallback to default routing) +pub async fn resolve_routing_preferences( + metadata: RoutingMetadata, + policy_provider: Option<&PolicyProvider>, +) -> Option> { + // Priority 1: Inline policy + if let Some(inline) = metadata.inline_policy { + return Some(inline); + } + + // Priority 2: External policy provider + if let (Some(provider), Some(policy_id)) = (policy_provider, &metadata.policy_id) { + match provider.get_policy(policy_id, metadata.revision).await { + Ok(Some(policy)) => { + info!( + policy_id = %policy_id, + num_models = policy.len(), + "using policy from external provider" + ); + return Some(policy); + } + Ok(None) => { + warn!(policy_id = %policy_id, "policy not found from external provider"); + } + Err(err) => { + warn!(error = %err, policy_id = %policy_id, "failed to fetch policy from external provider"); + } + } + } + + // Priority 3: No preferences (fallback to default) + None +} + +/// Backward-compatible function that only extracts inline routing_policy. +/// Deprecated: Use `extract_routing_metadata` instead. +#[deprecated(note = "Use extract_routing_metadata instead")] +pub fn extract_routing_policy( + raw_bytes: &[u8], + warn_on_size: bool, +) -> Result<(Bytes, Option>), String> { + let (bytes, metadata) = extract_routing_metadata(raw_bytes, warn_on_size)?; + Ok((bytes, metadata.inline_policy)) } #[derive(serde::Serialize)] diff --git a/crates/brightstaff/src/router/mod.rs b/crates/brightstaff/src/router/mod.rs index b010d80c9..14cafeab5 100644 --- a/crates/brightstaff/src/router/mod.rs +++ b/crates/brightstaff/src/router/mod.rs @@ -3,5 +3,6 @@ pub mod llm; pub mod orchestrator; pub mod orchestrator_model; pub mod orchestrator_model_v1; +pub mod policy_provider; pub mod router_model; pub mod router_model_v1; diff --git a/crates/brightstaff/src/router/policy_provider.rs b/crates/brightstaff/src/router/policy_provider.rs new file mode 100644 index 000000000..4f37dea0f --- /dev/null +++ b/crates/brightstaff/src/router/policy_provider.rs @@ -0,0 +1,261 @@ +//! External HTTP routing policy provider. +//! +//! Fetches routing policies from an external HTTP endpoint with caching support. +//! Policies are cached by `policy_id` with revision-aware invalidation. + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; + +use common::configuration::ModelUsagePreference; +use reqwest::header::HeaderMap; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use tokio::sync::RwLock; +use tracing::{debug, warn}; + +/// Configuration for the external HTTP policy provider. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PolicyProviderConfig { + /// URL of the external policy endpoint. + pub url: String, + /// Optional headers to include in requests (e.g., Authorization). + #[serde(default)] + pub headers: HashMap, + /// Cache TTL in seconds. Defaults to 300 (5 minutes). + #[serde(default = "default_ttl_seconds")] + pub ttl_seconds: u64, +} + +fn default_ttl_seconds() -> u64 { + 300 +} + +impl Default for PolicyProviderConfig { + fn default() -> Self { + Self { + url: String::new(), + headers: HashMap::new(), + ttl_seconds: 300, + } + } +} + +/// Response from the external policy endpoint. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PolicyResponse { + pub policy_id: String, + pub revision: u64, + pub schema_version: String, + pub routing_preferences: Vec, +} + +/// Cached policy entry with revision and expiration. +#[derive(Debug, Clone)] +struct CachedPolicy { + policy: Vec, + revision: u64, + cached_at: Instant, + ttl: Duration, +} + +impl CachedPolicy { + fn is_expired(&self) -> bool { + self.cached_at.elapsed() > self.ttl + } +} + +#[derive(Debug, Error)] +pub enum PolicyProviderError { + #[error("HTTP request failed: {0}")] + Http(#[from] reqwest::Error), + + #[error("Failed to parse policy response: {0}")] + Parse(String), + + #[error("Unsupported schema version: {0}")] + UnsupportedSchemaVersion(String), + + #[error("Policy ID mismatch: expected {expected}, got {actual}")] + PolicyIdMismatch { expected: String, actual: String }, + + #[error("No policy provider configured")] + NotConfigured, +} + +/// External HTTP routing policy provider with caching. +pub struct PolicyProvider { + config: PolicyProviderConfig, + client: reqwest::Client, + cache: RwLock>, +} + +impl PolicyProvider { + pub fn new(config: PolicyProviderConfig) -> Self { + Self { + config, + client: reqwest::Client::new(), + cache: RwLock::new(HashMap::new()), + } + } + + /// Fetches routing policy for the given policy_id and revision. + /// + /// Resolution order: + /// 1. If cached and cached revision >= requested revision, use cache + /// 2. Otherwise, fetch from external endpoint + /// + /// Returns `None` if no policy_id is provided or if the provider is not configured. + pub async fn get_policy( + &self, + policy_id: &str, + revision: Option, + ) -> Result>, PolicyProviderError> { + if self.config.url.is_empty() { + return Err(PolicyProviderError::NotConfigured); + } + + let revision = revision.unwrap_or(0); + + // Check cache first + { + let cache = self.cache.read().await; + if let Some(cached) = cache.get(policy_id) { + if !cached.is_expired() && cached.revision >= revision { + debug!( + policy_id = %policy_id, + cached_revision = cached.revision, + requested_revision = revision, + "using cached policy" + ); + return Ok(Some(cached.policy.clone())); + } + } + } + + // Fetch from external endpoint + let policy = self.fetch_policy(policy_id, revision).await?; + + // Update cache + { + let mut cache = self.cache.write().await; + cache.insert( + policy_id.to_string(), + CachedPolicy { + policy: policy.routing_preferences.clone(), + revision: policy.revision, + cached_at: Instant::now(), + ttl: Duration::from_secs(self.config.ttl_seconds), + }, + ); + } + + debug!( + policy_id = %policy_id, + revision = policy.revision, + num_models = policy.routing_preferences.len(), + "fetched and cached policy from external endpoint" + ); + + Ok(Some(policy.routing_preferences)) + } + + async fn fetch_policy( + &self, + policy_id: &str, + revision: u64, + ) -> Result { + let url = format!( + "{}?policy_id={}&revision={}", + self.config.url, + urlencoding::encode(policy_id), + revision + ); + + let mut headers = HeaderMap::new(); + for (key, value) in &self.config.headers { + if let Ok(header_name) = key.parse() { + if let Ok(header_value) = value.parse() { + headers.insert(header_name, header_value); + } + } + } + + debug!(url = %url, "fetching policy from external endpoint"); + + let response = self.client.get(&url).headers(headers).send().await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + return Err(PolicyProviderError::Parse(format!( + "HTTP {} from policy endpoint: {}", + status, body + ))); + } + + let policy: PolicyResponse = response.json().await.map_err(|e| { + PolicyProviderError::Parse(format!("Failed to parse JSON response: {}", e)) + })?; + + // Validate schema version + if policy.schema_version != "v1" { + return Err(PolicyProviderError::UnsupportedSchemaVersion( + policy.schema_version, + )); + } + + // Validate policy_id matches + if policy.policy_id != policy_id { + return Err(PolicyProviderError::PolicyIdMismatch { + expected: policy_id.to_string(), + actual: policy.policy_id, + }); + } + + Ok(policy) + } + + /// Clears the cache for a specific policy_id or all policies. + pub async fn clear_cache(&self, policy_id: Option<&str>) { + let mut cache = self.cache.write().await; + match policy_id { + Some(id) => { + cache.remove(id); + } + None => { + cache.clear(); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_default_ttl() { + let config = PolicyProviderConfig::default(); + assert_eq!(config.ttl_seconds, 300); + } + + #[test] + fn test_cached_policy_expiration() { + let cached = CachedPolicy { + policy: vec![], + revision: 1, + cached_at: Instant::now() - Duration::from_secs(400), + ttl: Duration::from_secs(300), + }; + assert!(cached.is_expired()); + + let fresh = CachedPolicy { + policy: vec![], + revision: 1, + cached_at: Instant::now(), + ttl: Duration::from_secs(300), + }; + assert!(!fresh.is_expired()); + } +}