diff --git a/cli/planoai/config_generator.py b/cli/planoai/config_generator.py index 522968c94..347de5d62 100644 --- a/cli/planoai/config_generator.py +++ b/cli/planoai/config_generator.py @@ -185,6 +185,41 @@ def validate_and_render_schema(): f"Invalid opentracing_grpc_endpoint {opentracing_grpc_endpoint}, path must be empty" ) + routing = config_yaml.get("routing", {}) + policy_provider = routing.get("policy_provider") + if policy_provider: + policy_url = policy_provider.get("url") + if not policy_url: + raise Exception( + "routing.policy_provider.url is required when policy_provider is set" + ) + if "$" in policy_url: + policy_url = os.path.expandvars(policy_url) + policy_url_result = urlparse(policy_url) + if ( + policy_url_result.scheme not in ["http", "https"] + or not policy_url_result.hostname + ): + raise Exception( + f"Invalid routing.policy_provider.url {policy_provider.get('url')}, must be a valid http/https URL" + ) + + ttl_seconds = policy_provider.get("ttl_seconds") + if ttl_seconds is not None and ttl_seconds <= 0: + raise Exception( + "routing.policy_provider.ttl_seconds must be greater than 0" + ) + + headers = policy_provider.get("headers") + if headers is not None: + if not isinstance(headers, dict): + raise Exception("routing.policy_provider.headers must be an object") + for key, value in headers.items(): + if not isinstance(key, str) or not isinstance(value, str): + raise Exception( + "routing.policy_provider.headers must contain string keys and string values" + ) + llms_with_endpoint = [] llms_with_endpoint_cluster_names = set() updated_model_providers = [] diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index b63cb8244..e202a705c 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -411,10 +411,27 @@ properties: routing: type: object properties: + model_provider: + type: string llm_provider: type: string model: type: string + policy_provider: + type: object + properties: + url: + type: string + headers: + type: object + additionalProperties: + type: string + ttl_seconds: + type: integer + minimum: 1 + additionalProperties: false + required: + - url additionalProperties: false state_storage: type: object diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index 67afebff3..62fc03b5f 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use tokio::sync::RwLock; use tracing::{debug, info, info_span, warn, Instrument}; +use crate::handlers::policy_provider::PolicyProviderClient; use crate::handlers::router_chat::router_chat_get_upstream_model; use crate::handlers::utils::{ create_streaming_response, truncate_message, ObservableStreamProcessor, @@ -34,9 +35,11 @@ use crate::tracing::{ use common::errors::BrightStaffError; +#[allow(clippy::too_many_arguments)] pub async fn llm_chat( request: Request, router_service: Arc, + policy_provider: Option>, full_qualified_llm_provider_url: String, model_aliases: Arc>>, llm_providers: Arc>, @@ -73,6 +76,7 @@ pub async fn llm_chat( llm_chat_inner( request, router_service, + policy_provider, full_qualified_llm_provider_url, model_aliases, llm_providers, @@ -90,6 +94,7 @@ pub async fn llm_chat( async fn llm_chat_inner( request: Request, router_service: Arc, + policy_provider: Option>, full_qualified_llm_provider_url: String, model_aliases: Arc>>, llm_providers: Arc>, @@ -134,7 +139,7 @@ async fn llm_chat_inner( ); // Extract routing_policy from request body if present - let (chat_request_bytes, inline_routing_policy) = + let (chat_request_bytes, inline_routing_policy, policy_id) = match crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false) { Ok(result) => result, Err(err) => { @@ -355,6 +360,8 @@ async fn llm_chat_inner( &request_path, &request_id, inline_routing_policy, + policy_id, + policy_provider, ) .await } diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index 9c602e93d..82c58ec48 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -5,6 +5,7 @@ pub mod jsonrpc; pub mod llm; pub mod models; pub mod pipeline_processor; +pub mod policy_provider; pub mod response_handler; pub mod router_chat; pub mod routing_service; diff --git a/crates/brightstaff/src/handlers/policy_provider.rs b/crates/brightstaff/src/handlers/policy_provider.rs new file mode 100644 index 000000000..8a1f4f170 --- /dev/null +++ b/crates/brightstaff/src/handlers/policy_provider.rs @@ -0,0 +1,291 @@ +use std::sync::Arc; +use std::time::Duration; + +use common::configuration::{ModelUsagePreference, RoutingPolicyProvider}; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue}; +use serde::Deserialize; +use tracing::warn; + +use crate::state::policy_cache::PolicyCache; + +const DEFAULT_POLICY_TTL_SECONDS: u64 = 300; + +#[derive(Debug, Deserialize)] +struct ExternalPolicyResponse { + policy_id: String, + routing_preferences: Vec, +} + +#[derive(Debug)] +pub enum PolicyFetchError { + Transient(String), + Invalid(String), +} + +impl PolicyFetchError { + pub fn is_transient(&self) -> bool { + matches!(self, PolicyFetchError::Transient(_)) + } + + pub fn message(&self) -> &str { + match self { + PolicyFetchError::Transient(msg) | PolicyFetchError::Invalid(msg) => msg, + } + } +} + +pub struct PolicyProviderClient { + client: reqwest::Client, + config: RoutingPolicyProvider, + cache: Arc, + ttl: Duration, +} + +impl PolicyProviderClient { + pub fn new(config: RoutingPolicyProvider, cache: Arc) -> Self { + let ttl = Duration::from_secs(config.ttl_seconds.unwrap_or(DEFAULT_POLICY_TTL_SECONDS)); + Self { + client: reqwest::Client::new(), + config, + cache, + ttl, + } + } + + pub async fn fetch_policy( + &self, + policy_id: &str, + ) -> Result, PolicyFetchError> { + if let Some(cached) = self.cache.get_valid(policy_id).await { + return Ok(cached); + } + + let headers = self.build_headers()?; + let response = self + .client + .get(&self.config.url) + .query(&[("policy_id", policy_id)]) + .headers(headers) + .send() + .await + .map_err(|err| PolicyFetchError::Transient(format!("policy fetch failed: {}", err)))?; + + if !response.status().is_success() { + return if response.status().is_server_error() { + Err(PolicyFetchError::Transient(format!( + "policy provider returned {}", + response.status() + ))) + } else { + Err(PolicyFetchError::Invalid(format!( + "policy provider returned non-success status {}", + response.status() + ))) + }; + } + + let payload: ExternalPolicyResponse = response + .json() + .await + .map_err(|err| PolicyFetchError::Invalid(format!("invalid policy payload: {}", err)))?; + + if payload.policy_id != policy_id { + return Err(PolicyFetchError::Invalid(format!( + "policy_id mismatch in provider response: expected '{}', got '{}'", + policy_id, payload.policy_id + ))); + } + + if payload.routing_preferences.is_empty() { + warn!( + policy_id, + "policy provider returned empty routing preferences" + ); + } + + self.cache + .insert( + policy_id.to_string(), + payload.routing_preferences.clone(), + self.ttl, + ) + .await; + Ok(payload.routing_preferences) + } + + fn build_headers(&self) -> Result { + let mut headers = HeaderMap::new(); + if let Some(configured_headers) = &self.config.headers { + for (name, value) in configured_headers { + let header_name = HeaderName::from_bytes(name.as_bytes()).map_err(|err| { + PolicyFetchError::Invalid(format!( + "invalid header name '{}' in routing.policy_provider.headers: {}", + name, err + )) + })?; + let header_value = HeaderValue::from_str(value).map_err(|err| { + PolicyFetchError::Invalid(format!( + "invalid header value for '{}' in routing.policy_provider.headers: {}", + name, err + )) + })?; + headers.insert(header_name, header_value); + } + } + Ok(headers) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::sync::Arc; + use std::time::Duration; + + use common::configuration::RoutingPolicyProvider; + use mockito::{Matcher, Server}; + + use crate::handlers::policy_provider::{PolicyFetchError, PolicyProviderClient}; + use crate::state::policy_cache::PolicyCache; + + fn provider_config(url: String, ttl_seconds: Option) -> RoutingPolicyProvider { + RoutingPolicyProvider { + url, + headers: None, + ttl_seconds, + } + } + + #[tokio::test] + async fn fetches_policy_and_populates_cache() { + let mut server = Server::new_async().await; + let _mock = server + .mock("GET", "/v1/routing-policy") + .match_query(Matcher::UrlEncoded( + "policy_id".to_string(), + "customer-abc".to_string(), + )) + .with_status(200) + .with_header("content-type", "application/json") + .with_body( + r#"{ + "policy_id":"customer-abc", + "routing_preferences":[ + { + "model":"openai/gpt-4o", + "routing_preferences":[{"name":"quick response","description":"fast"}] + } + ] + }"#, + ) + .expect(1) + .create_async() + .await; + + let cache = Arc::new(PolicyCache::new()); + let client = PolicyProviderClient::new( + provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)), + cache, + ); + + let first = client.fetch_policy("customer-abc").await.unwrap(); + let second = client.fetch_policy("customer-abc").await.unwrap(); + assert_eq!(first.len(), 1); + assert_eq!(second[0].model, "openai/gpt-4o"); + } + + #[tokio::test] + async fn returns_invalid_on_policy_id_mismatch() { + let mut server = Server::new_async().await; + let _mock = server + .mock("GET", "/v1/routing-policy") + .match_query(Matcher::Any) + .with_status(200) + .with_header("content-type", "application/json") + .with_body( + r#"{ + "policy_id":"different-id", + "routing_preferences":[] + }"#, + ) + .create_async() + .await; + + let cache = Arc::new(PolicyCache::new()); + let client = PolicyProviderClient::new( + provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)), + cache, + ); + + let err = client.fetch_policy("customer-abc").await.unwrap_err(); + assert!(matches!(err, PolicyFetchError::Invalid(_))); + } + + #[tokio::test] + async fn returns_transient_on_server_error() { + let mut server = Server::new_async().await; + let _mock = server + .mock("GET", "/v1/routing-policy") + .match_query(Matcher::Any) + .with_status(500) + .create_async() + .await; + + let cache = Arc::new(PolicyCache::new()); + let client = PolicyProviderClient::new( + provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)), + cache, + ); + + let err = client.fetch_policy("customer-abc").await.unwrap_err(); + assert!(err.is_transient()); + } + + #[tokio::test] + async fn returns_invalid_on_client_error_status() { + let mut server = Server::new_async().await; + let _mock = server + .mock("GET", "/v1/routing-policy") + .match_query(Matcher::Any) + .with_status(404) + .create_async() + .await; + + let cache = Arc::new(PolicyCache::new()); + let client = PolicyProviderClient::new( + provider_config(format!("{}/v1/routing-policy", server.url()), Some(300)), + cache, + ); + + let err = client.fetch_policy("customer-abc").await.unwrap_err(); + assert!(matches!(err, PolicyFetchError::Invalid(_))); + } + + #[tokio::test] + async fn supports_headers() { + let mut server = Server::new_async().await; + let _mock = server + .mock("GET", "/v1/routing-policy") + .match_header("authorization", "Bearer token") + .match_query(Matcher::Any) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"policy_id":"customer-abc","routing_preferences":[]}"#) + .create_async() + .await; + + let mut headers = HashMap::new(); + headers.insert("Authorization".to_string(), "Bearer token".to_string()); + let cache = Arc::new(PolicyCache::new()); + let client = PolicyProviderClient::new( + RoutingPolicyProvider { + url: format!("{}/v1/routing-policy", server.url()), + headers: Some(headers), + ttl_seconds: Some(Duration::from_secs(300).as_secs()), + }, + cache, + ); + + let _ = client.fetch_policy("customer-abc").await.unwrap(); + } +} diff --git a/crates/brightstaff/src/handlers/router_chat.rs b/crates/brightstaff/src/handlers/router_chat.rs index 910e5408e..2737174c7 100644 --- a/crates/brightstaff/src/handlers/router_chat.rs +++ b/crates/brightstaff/src/handlers/router_chat.rs @@ -2,9 +2,12 @@ use common::configuration::ModelUsagePreference; use hermesllm::clients::endpoints::SupportedUpstreamAPIs; use hermesllm::{ProviderRequest, ProviderRequestType}; use hyper::StatusCode; +use serde_json::Value; +use std::collections::HashMap; use std::sync::Arc; use tracing::{debug, info, warn}; +use crate::handlers::policy_provider::PolicyProviderClient; use crate::router::llm_router::RouterService; use crate::tracing::routing; @@ -13,6 +16,7 @@ pub struct RoutingResult { pub route_name: Option, } +#[derive(Debug)] pub struct RoutingError { pub message: String, pub status_code: StatusCode, @@ -25,6 +29,60 @@ impl RoutingError { status_code: StatusCode::INTERNAL_SERVER_ERROR, } } + + pub fn bad_request(message: String) -> Self { + Self { + message, + status_code: StatusCode::BAD_REQUEST, + } + } +} + +async fn resolve_usage_preferences( + inline_usage_preferences: Option>, + policy_id: Option<&str>, + policy_provider: Option<&PolicyProviderClient>, + routing_metadata: Option<&HashMap>, +) -> Result>, RoutingError> { + if let Some(inline_preferences) = inline_usage_preferences { + info!("using inline routing_policy from request body"); + return Ok(Some(inline_preferences)); + } + + if let (Some(policy_id), Some(policy_provider_client)) = (policy_id, policy_provider) { + match policy_provider_client.fetch_policy(policy_id).await { + Ok(preferences) => { + info!( + policy_id, + "using routing policy from external policy provider" + ); + return Ok(Some(preferences)); + } + Err(err) if err.is_transient() => { + warn!( + policy_id, + error = %err.message(), + "policy provider fetch failed, falling back to metadata/config routing preferences" + ); + } + Err(err) => { + return Err(RoutingError::bad_request(format!( + "Failed to load routing policy for policy_id '{}': {}", + policy_id, + err.message() + ))); + } + } + } + + let usage_preferences_str: Option = routing_metadata.and_then(|metadata| { + metadata + .get("plano_preference_config") + .map(|value| value.to_string()) + }); + Ok(usage_preferences_str + .as_ref() + .and_then(|s| serde_yaml::from_str(s).ok())) } /// Determines the routing decision if @@ -32,6 +90,7 @@ impl RoutingError { /// # Returns /// * `Ok(RoutingResult)` - Contains the selected model name and span ID /// * `Err(RoutingError)` - Contains error details and optional span ID +#[allow(clippy::too_many_arguments)] pub async fn router_chat_get_upstream_model( router_service: Arc, client_request: ProviderRequestType, @@ -39,6 +98,8 @@ pub async fn router_chat_get_upstream_model( request_path: &str, request_id: &str, inline_usage_preferences: Option>, + policy_id: Option, + policy_provider: Option>, ) -> Result { // Clone metadata for routing before converting (which consumes client_request) let routing_metadata = client_request.metadata().clone(); @@ -77,21 +138,13 @@ pub async fn router_chat_get_upstream_model( "router request" ); - // Use inline preferences if provided, otherwise fall back to metadata extraction - let usage_preferences: Option> = if inline_usage_preferences.is_some() - { - inline_usage_preferences - } else { - let usage_preferences_str: Option = - routing_metadata.as_ref().and_then(|metadata| { - metadata - .get("plano_preference_config") - .map(|value| value.to_string()) - }); - usage_preferences_str - .as_ref() - .and_then(|s| serde_yaml::from_str(s).ok()) - }; + let usage_preferences = resolve_usage_preferences( + inline_usage_preferences, + policy_id.as_deref(), + policy_provider.as_deref(), + routing_metadata.as_ref(), + ) + .await?; // Prepare log message with latest message from chat request let latest_message_for_log = chat_request @@ -168,3 +221,109 @@ pub async fn router_chat_get_upstream_model( } } } + +#[cfg(test)] +mod tests { + use super::resolve_usage_preferences; + use crate::handlers::policy_provider::PolicyProviderClient; + use crate::state::policy_cache::PolicyCache; + use common::configuration::{ModelUsagePreference, RoutingPolicyProvider, RoutingPreference}; + use mockito::{Matcher, Server}; + use serde_json::json; + use std::collections::HashMap; + use std::sync::Arc; + + fn inline_policy(name: &str) -> Vec { + vec![ModelUsagePreference { + model: "openai/gpt-4o".to_string(), + routing_preferences: vec![RoutingPreference { + name: name.to_string(), + description: "desc".to_string(), + }], + }] + } + + #[tokio::test] + async fn resolve_usage_preferences_prioritizes_inline_policy() { + let inline = inline_policy("inline"); + let mut metadata = HashMap::new(); + metadata.insert( + "plano_preference_config".to_string(), + json!( + [{"model":"openai/gpt-4o-mini","routing_preferences":[{"name":"metadata","description":"desc"}]}] + ), + ); + + let result = resolve_usage_preferences( + Some(inline.clone()), + Some("policy-a"), + None, + Some(&metadata), + ) + .await + .unwrap(); + assert_eq!(result.unwrap()[0].routing_preferences[0].name, "inline"); + } + + #[tokio::test] + async fn resolve_usage_preferences_falls_back_to_metadata_on_transient_policy_error() { + let mut server = Server::new_async().await; + let _mock = server + .mock("GET", "/policy") + .match_query(Matcher::Any) + .with_status(500) + .create_async() + .await; + + let provider = PolicyProviderClient::new( + RoutingPolicyProvider { + url: format!("{}/policy", server.url()), + headers: None, + ttl_seconds: Some(60), + }, + Arc::new(PolicyCache::new()), + ); + let mut metadata = HashMap::new(); + metadata.insert( + "plano_preference_config".to_string(), + json!( + [{"model":"openai/gpt-4o-mini","routing_preferences":[{"name":"metadata","description":"desc"}]}] + ), + ); + + let result = + resolve_usage_preferences(None, Some("customer-a"), Some(&provider), Some(&metadata)) + .await + .unwrap() + .unwrap(); + + assert_eq!(result[0].routing_preferences[0].name, "metadata"); + } + + #[tokio::test] + async fn resolve_usage_preferences_returns_bad_request_on_policy_mismatch() { + let mut server = Server::new_async().await; + let _mock = server + .mock("GET", "/policy") + .match_query(Matcher::Any) + .with_status(200) + .with_header("content-type", "application/json") + .with_body(r#"{"policy_id":"different","routing_preferences":[]}"#) + .create_async() + .await; + + let provider = PolicyProviderClient::new( + RoutingPolicyProvider { + url: format!("{}/policy", server.url()), + headers: None, + ttl_seconds: Some(60), + }, + Arc::new(PolicyCache::new()), + ); + + let err = resolve_usage_preferences(None, Some("expected"), Some(&provider), None) + .await + .unwrap_err(); + assert_eq!(err.status_code, hyper::StatusCode::BAD_REQUEST); + } +} diff --git a/crates/brightstaff/src/handlers/routing_service.rs b/crates/brightstaff/src/handlers/routing_service.rs index 4eae4685e..5e0c420ec 100644 --- a/crates/brightstaff/src/handlers/routing_service.rs +++ b/crates/brightstaff/src/handlers/routing_service.rs @@ -10,11 +10,13 @@ use hyper::{Request, Response, StatusCode}; use std::sync::Arc; use tracing::{debug, info, info_span, warn, Instrument}; +use crate::handlers::policy_provider::PolicyProviderClient; use crate::handlers::router_chat::router_chat_get_upstream_model; use crate::router::llm_router::RouterService; use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name}; const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120; +type ExtractRoutingPolicyResult = (Bytes, Option>, Option); /// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes /// and parsed preferences. The `routing_policy` field is removed from the JSON @@ -24,10 +26,20 @@ const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120; pub fn extract_routing_policy( raw_bytes: &[u8], warn_on_size: bool, -) -> Result<(Bytes, Option>), String> { +) -> Result { let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes) .map_err(|err| format!("Failed to parse JSON: {}", err))?; + let policy_id = json_body + .as_object_mut() + .and_then(|obj| obj.remove("policy_id")) + .map(|policy_id_value| match policy_id_value { + serde_json::Value::String(policy_id) if !policy_id.trim().is_empty() => Ok(policy_id), + serde_json::Value::String(_) => Err("policy_id cannot be empty".to_string()), + _ => Err("policy_id must be a string".to_string()), + }) + .transpose()?; + let preferences = json_body .as_object_mut() .and_then(|obj| obj.remove("routing_policy")) @@ -58,7 +70,7 @@ pub fn extract_routing_policy( }); let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap()); - Ok((bytes, preferences)) + Ok((bytes, preferences, policy_id)) } #[derive(serde::Serialize)] @@ -71,6 +83,7 @@ struct RoutingDecisionResponse { pub async fn routing_decision( request: Request, router_service: Arc, + policy_provider: Option>, request_path: String, span_attributes: Arc>, ) -> Result>, hyper::Error> { @@ -95,6 +108,7 @@ pub async fn routing_decision( routing_decision_inner( request, router_service, + policy_provider, request_id, request_path, request_headers, @@ -107,6 +121,7 @@ pub async fn routing_decision( async fn routing_decision_inner( request: Request, router_service: Arc, + policy_provider: Option>, request_id: String, request_path: String, request_headers: hyper::HeaderMap, @@ -153,17 +168,18 @@ async fn routing_decision_inner( ); // Extract routing_policy from request body before parsing as ProviderRequestType - let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) { - Ok(result) => result, - Err(err) => { - warn!(error = %err, "failed to parse request JSON"); - return Ok(BrightStaffError::InvalidRequest(format!( - "Failed to parse request JSON: {}", - err - )) - .into_response()); - } - }; + let (chat_request_bytes, inline_preferences, policy_id) = + match extract_routing_policy(&raw_bytes, true) { + Ok(result) => result, + Err(err) => { + warn!(error = %err, "failed to parse request JSON"); + return Ok(BrightStaffError::InvalidRequest(format!( + "Failed to parse request JSON: {}", + err + )) + .into_response()); + } + }; let client_request = match ProviderRequestType::try_from(( &chat_request_bytes[..], @@ -188,6 +204,8 @@ async fn routing_decision_inner( &request_path, &request_id, inline_preferences, + policy_id, + policy_provider, ) .await; @@ -218,7 +236,11 @@ async fn routing_decision_inner( } Err(err) => { warn!(error = %err.message, "routing decision failed"); - Ok(BrightStaffError::InternalServerError(err.message).into_response()) + Ok(BrightStaffError::ForwardedError { + status_code: err.status_code, + message: err.message, + } + .into_response()) } } } @@ -243,9 +265,10 @@ mod tests { #[test] fn extract_routing_policy_no_policy() { let body = make_chat_body(""); - let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap(); + let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap(); assert!(prefs.is_none()); + assert!(policy_id.is_none()); let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap(); assert_eq!(cleaned_json["model"], "gpt-4o-mini"); assert!(cleaned_json.get("routing_policy").is_none()); @@ -268,7 +291,7 @@ mod tests { } ]"#; let body = make_chat_body(policy); - let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap(); + let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap(); let prefs = prefs.expect("should have parsed preferences"); assert_eq!(prefs.len(), 2); @@ -280,6 +303,7 @@ mod tests { // routing_policy should be stripped from cleaned body let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap(); assert!(cleaned_json.get("routing_policy").is_none()); + assert!(policy_id.is_none()); assert_eq!(cleaned_json["model"], "gpt-4o-mini"); } @@ -288,13 +312,14 @@ mod tests { // routing_policy is present but has wrong shape let policy = r#""routing_policy": "not-an-array""#; let body = make_chat_body(policy); - let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap(); + let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap(); // Invalid policy should be ignored (returns None), not error assert!(prefs.is_none()); // routing_policy should still be stripped from cleaned body let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap(); assert!(cleaned_json.get("routing_policy").is_none()); + assert!(policy_id.is_none()); } #[test] @@ -309,23 +334,44 @@ mod tests { fn extract_routing_policy_empty_array() { let policy = r#""routing_policy": []"#; let body = make_chat_body(policy); - let (_, prefs) = extract_routing_policy(&body, false).unwrap(); + let (_, prefs, policy_id) = extract_routing_policy(&body, false).unwrap(); let prefs = prefs.expect("empty array is valid"); assert_eq!(prefs.len(), 0); + assert!(policy_id.is_none()); } #[test] fn extract_routing_policy_preserves_other_fields() { let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#; let body = make_chat_body(policy); - let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap(); + let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap(); assert!(prefs.is_some()); let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap(); assert_eq!(cleaned_json["temperature"], 0.5); assert_eq!(cleaned_json["max_tokens"], 100); assert!(cleaned_json.get("routing_policy").is_none()); + assert!(policy_id.is_none()); + } + + #[test] + fn extract_routing_policy_extracts_and_strips_policy_id() { + let body = make_chat_body(r#""policy_id": "customer-abc-123""#); + let (cleaned, prefs, policy_id) = extract_routing_policy(&body, false).unwrap(); + + assert!(prefs.is_none()); + assert_eq!(policy_id, Some("customer-abc-123".to_string())); + let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap(); + assert!(cleaned_json.get("policy_id").is_none()); + } + + #[test] + fn extract_routing_policy_rejects_non_string_policy_id() { + let body = make_chat_body(r#""policy_id": 123"#); + let result = extract_routing_policy(&body, false); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("policy_id must be a string")); } #[test] diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 51c9127f4..75d81c731 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -2,10 +2,12 @@ use brightstaff::handlers::agent_chat_completions::agent_chat; use brightstaff::handlers::function_calling::function_calling_chat_handler; use brightstaff::handlers::llm::llm_chat; use brightstaff::handlers::models::list_models; +use brightstaff::handlers::policy_provider::PolicyProviderClient; use brightstaff::handlers::routing_service::routing_decision; use brightstaff::router::llm_router::RouterService; use brightstaff::router::plano_orchestrator::OrchestratorService; use brightstaff::state::memory::MemoryConversationalStorage; +use brightstaff::state::policy_cache::PolicyCache; use brightstaff::state::postgresql::PostgreSQLConversationStorage; use brightstaff::state::StateStorage; use brightstaff::utils::tracing::init_tracer; @@ -108,6 +110,16 @@ async fn main() -> Result<(), Box> { routing_model_name, routing_llm_provider, )); + let policy_provider: Option> = plano_config + .routing + .as_ref() + .and_then(|routing| routing.policy_provider.clone()) + .map(|policy_provider_config| { + Arc::new(PolicyProviderClient::new( + policy_provider_config, + Arc::new(PolicyCache::new()), + )) + }); let orchestrator_service: Arc = Arc::new(OrchestratorService::new( format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"), @@ -172,6 +184,7 @@ async fn main() -> Result<(), Box> { let router_service: Arc = Arc::clone(&router_service); let orchestrator_service: Arc = Arc::clone(&orchestrator_service); + let policy_provider = policy_provider.clone(); let model_aliases: Arc< Option>, > = Arc::clone(&model_aliases); @@ -185,6 +198,7 @@ async fn main() -> Result<(), Box> { let service = service_fn(move |req| { let router_service = Arc::clone(&router_service); let orchestrator_service = Arc::clone(&orchestrator_service); + let policy_provider = policy_provider.clone(); let parent_cx = extract_context_from_request(&req); let llm_provider_url = llm_provider_url.clone(); let llm_providers = llm_providers.clone(); @@ -227,6 +241,7 @@ async fn main() -> Result<(), Box> { return routing_decision( req, router_service, + policy_provider, stripped_path, span_attributes, ) @@ -243,6 +258,7 @@ async fn main() -> Result<(), Box> { llm_chat( req, router_service, + policy_provider, fully_qualified_url, model_aliases, llm_providers, diff --git a/crates/brightstaff/src/state/mod.rs b/crates/brightstaff/src/state/mod.rs index 3d59f359a..84cdde9b5 100644 --- a/crates/brightstaff/src/state/mod.rs +++ b/crates/brightstaff/src/state/mod.rs @@ -9,6 +9,7 @@ use std::sync::Arc; use tracing::debug; pub mod memory; +pub mod policy_cache; pub mod postgresql; pub mod response_state_processor; diff --git a/crates/brightstaff/src/state/policy_cache.rs b/crates/brightstaff/src/state/policy_cache.rs new file mode 100644 index 000000000..c8bf2bdd2 --- /dev/null +++ b/crates/brightstaff/src/state/policy_cache.rs @@ -0,0 +1,108 @@ +use common::configuration::ModelUsagePreference; +use std::collections::HashMap; +use std::time::{Duration, Instant}; +use tokio::sync::RwLock; + +#[derive(Clone)] +struct CachedPolicy { + preferences: Vec, + expires_at: Instant, +} + +pub struct PolicyCache { + entries: RwLock>, +} + +impl Default for PolicyCache { + fn default() -> Self { + Self::new() + } +} + +impl PolicyCache { + pub fn new() -> Self { + Self { + entries: RwLock::new(HashMap::new()), + } + } + + pub async fn get_valid(&self, policy_id: &str) -> Option> { + let now = Instant::now(); + let cached = { + let entries = self.entries.read().await; + entries.get(policy_id).cloned() + }; + + let cached = cached?; + if cached.expires_at > now { + return Some(cached.preferences); + } + + self.entries.write().await.remove(policy_id); + None + } + + pub async fn insert( + &self, + policy_id: String, + preferences: Vec, + ttl: Duration, + ) { + let expires_at = Instant::now() + ttl; + self.entries.write().await.insert( + policy_id, + CachedPolicy { + preferences, + expires_at, + }, + ); + } +} + +#[cfg(test)] +mod tests { + use super::PolicyCache; + use common::configuration::{ModelUsagePreference, RoutingPreference}; + use std::time::Duration; + + fn sample_preferences() -> Vec { + vec![ModelUsagePreference { + model: "openai/gpt-4o".to_string(), + routing_preferences: vec![RoutingPreference { + name: "quick response".to_string(), + description: "fast lightweight responses".to_string(), + }], + }] + } + + #[tokio::test] + async fn returns_cached_policy_before_expiry() { + let cache = PolicyCache::new(); + cache + .insert( + "customer-a".to_string(), + sample_preferences(), + Duration::from_secs(10), + ) + .await; + + let cached = cache.get_valid("customer-a").await; + assert!(cached.is_some()); + assert_eq!(cached.unwrap()[0].model, "openai/gpt-4o"); + } + + #[tokio::test] + async fn expires_cached_policy_after_ttl() { + let cache = PolicyCache::new(); + cache + .insert( + "customer-a".to_string(), + sample_preferences(), + Duration::from_millis(5), + ) + .await; + + tokio::time::sleep(Duration::from_millis(10)).await; + assert!(cache.get_valid("customer-a").await.is_none()); + } +} diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index f4e2b7b41..640d2387e 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -9,8 +9,17 @@ use crate::api::open_ai::{ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Routing { + #[serde(alias = "llm_provider")] pub model_provider: Option, pub model: Option, + pub policy_provider: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RoutingPolicyProvider { + pub url: String, + pub headers: Option>, + pub ttl_seconds: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -270,7 +279,7 @@ impl LlmProviderType { } } -#[derive(Serialize, Deserialize, Debug)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelUsagePreference { pub model: String, pub routing_preferences: Vec, diff --git a/docs/source/guides/llm_router.rst b/docs/source/guides/llm_router.rst index 188b1e300..d6a609428 100644 --- a/docs/source/guides/llm_router.rst +++ b/docs/source/guides/llm_router.rst @@ -193,6 +193,65 @@ Clients can let the router decide or still specify aliases: # No model specified - router will analyze and choose claude-sonnet-4-5 ) +External Policy Provider (policy_id) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For multitenant deployments, Plano can fetch routing preferences from an external HTTP endpoint using a ``policy_id`` provided by the caller. + +Resolution order is: + +1. Inline ``routing_policy`` in request payload +2. ``policy_id`` lookup via ``routing.policy_provider`` +3. Metadata ``plano_preference_config`` +4. Config-file ``routing_preferences`` + +.. code-block:: yaml + :caption: External Policy Provider Configuration + + routing: + model_provider: arch-router + model: Arch-Router + policy_provider: + url: https://my-service.internal/v1/routing-policy + headers: + Authorization: Bearer $POLICY_API_KEY + ttl_seconds: 300 + +When ``policy_id`` is provided and no inline ``routing_policy`` is present, Plano fetches: + +.. code-block:: text + + GET https://my-service.internal/v1/routing-policy?policy_id=customer-abc-123 + +.. code-block:: json + :caption: Routing request with policy_id + + { + "messages": [{"role": "user", "content": "Help me summarize this"}], + "policy_id": "customer-abc-123" + } + +.. code-block:: json + :caption: Expected response from external policy endpoint + + { + "policy_id": "customer-abc-123", + "routing_preferences": [ + { + "model": "openai/gpt-4o", + "routing_preferences": [ + {"name": "quick response", "description": "fast lightweight responses"} + ] + }, + { + "model": "anthropic/claude-sonnet-4-0", + "routing_preferences": [ + {"name": "deep analysis", "description": "comprehensive detailed analysis"} + ] + } + ] + } + Arch-Router ----------- diff --git a/docs/source/resources/includes/plano_config_full_reference.yaml b/docs/source/resources/includes/plano_config_full_reference.yaml index a650baea3..f420d1c4d 100644 --- a/docs/source/resources/includes/plano_config_full_reference.yaml +++ b/docs/source/resources/includes/plano_config_full_reference.yaml @@ -47,6 +47,18 @@ model_aliases: smart-llm: target: gpt-4o +# Optional routing policy provider for multitenant preference-based routing. +# If policy_id is included in the request and inline routing_policy is absent, +# Plano fetches routing preferences from this endpoint and caches by policy_id. +routing: + model_provider: arch-router + model: Arch-Router + policy_provider: + url: https://my-service.internal/v1/routing-policy + headers: + Authorization: Bearer $POLICY_API_KEY + ttl_seconds: 300 + # HTTP listeners - entry points for agent routing, prompt targets, and direct LLM access listeners: # Agent listener for routing requests to multiple agents