diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index 5190fecf7..7b18eb027 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -422,6 +422,22 @@ properties: enum: - llm - prompt + routing: + type: object + properties: + llm_provider: + type: string + model: + type: string + session_ttl_seconds: + type: integer + minimum: 1 + description: TTL in seconds for session-pinned routing cache entries. Default 600 (10 minutes). + session_max_entries: + type: integer + minimum: 1 + description: Maximum number of session-pinned routing cache entries. Default 10000. + additionalProperties: false state_storage: type: object properties: diff --git a/crates/brightstaff/src/handlers/llm/mod.rs b/crates/brightstaff/src/handlers/llm/mod.rs index 9d4a2dfb3..2c3a01240 100644 --- a/crates/brightstaff/src/handlers/llm/mod.rs +++ b/crates/brightstaff/src/handlers/llm/mod.rs @@ -1,6 +1,6 @@ use bytes::Bytes; use common::configuration::{FilterPipeline, ModelAlias}; -use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER}; +use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, SESSION_ID_HEADER}; use common::llm_providers::LlmProviders; use hermesllm::apis::openai::Message; use hermesllm::apis::openai_responses::InputParam; @@ -92,6 +92,21 @@ async fn llm_chat_inner( let traceparent = extract_or_generate_traceparent(&request_headers); + // Session pinning: extract session ID and check cache before routing + let session_id: Option = request_headers + .get(SESSION_ID_HEADER) + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()); + let pinned_model: Option = if let Some(ref sid) = session_id { + state + .router_service + .get_cached_route(sid) + .await + .map(|c| c.model_name) + } else { + None + }; + let full_qualified_llm_provider_url = format!("{}{}", state.llm_provider_url, request_path); // --- Phase 1: Parse and validate the incoming request --- @@ -242,46 +257,65 @@ async fn llm_chat_inner( } }; - // --- Phase 3: Route the request --- - let routing_span = info_span!( - "routing", - component = "routing", - http.method = "POST", - http.target = %request_path, - model.requested = %model_from_request, - model.alias_resolved = %alias_resolved_model, - route.selected_model = tracing::field::Empty, - routing.determination_ms = tracing::field::Empty, - ); - let routing_result = match async { - set_service_name(operation_component::ROUTING); - router_chat_get_upstream_model( - Arc::clone(&state.router_service), - client_request, - &traceparent, - &request_path, - &request_id, - inline_routing_policy, - ) + // --- Phase 3: Route the request (or use pinned model from session cache) --- + let resolved_model = if let Some(cached_model) = pinned_model { + info!( + session_id = %session_id.as_deref().unwrap_or(""), + model = %cached_model, + "using pinned routing decision from cache" + ); + cached_model + } else { + let routing_span = info_span!( + "routing", + component = "routing", + http.method = "POST", + http.target = %request_path, + model.requested = %model_from_request, + model.alias_resolved = %alias_resolved_model, + route.selected_model = tracing::field::Empty, + routing.determination_ms = tracing::field::Empty, + ); + let routing_result = match async { + set_service_name(operation_component::ROUTING); + router_chat_get_upstream_model( + Arc::clone(&state.router_service), + client_request, + &traceparent, + &request_path, + &request_id, + inline_routing_policy, + ) + .await + } + .instrument(routing_span) .await - } - .instrument(routing_span) - .await - { - Ok(result) => result, - Err(err) => { - let mut internal_error = Response::new(full(err.message)); - *internal_error.status_mut() = err.status_code; - return Ok(internal_error); + { + Ok(result) => result, + Err(err) => { + let mut internal_error = Response::new(full(err.message)); + *internal_error.status_mut() = err.status_code; + return Ok(internal_error); + } + }; + + let (router_selected_model, route_name) = + (routing_result.model_name, routing_result.route_name); + let model = if router_selected_model != "none" { + router_selected_model + } else { + alias_resolved_model.clone() + }; + + // Cache the routing decision so subsequent requests with the same session ID are pinned + if let Some(ref sid) = session_id { + state + .router_service + .cache_route(sid.clone(), model.clone(), route_name) + .await; } - }; - // Determine final model (router returns "none" when it doesn't select a specific model) - let router_selected_model = routing_result.model_name; - let resolved_model = if router_selected_model != "none" { - router_selected_model - } else { - alias_resolved_model.clone() + model }; tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str()); diff --git a/crates/brightstaff/src/handlers/routing_service.rs b/crates/brightstaff/src/handlers/routing_service.rs index ec09f06fb..5eb3c6d6d 100644 --- a/crates/brightstaff/src/handlers/routing_service.rs +++ b/crates/brightstaff/src/handlers/routing_service.rs @@ -1,6 +1,6 @@ use bytes::Bytes; use common::configuration::{ModelUsagePreference, SpanAttributes}; -use common::consts::REQUEST_ID_HEADER; +use common::consts::{REQUEST_ID_HEADER, SESSION_ID_HEADER}; use common::errors::BrightStaffError; use hermesllm::clients::SupportedAPIsFromClient; use hermesllm::ProviderRequestType; @@ -67,6 +67,9 @@ struct RoutingDecisionResponse { model: String, route: Option, trace_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + session_id: Option, + pinned: bool, } pub async fn routing_decision( @@ -82,6 +85,11 @@ pub async fn routing_decision( .map(|s| s.to_string()) .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + let session_id: Option = request_headers + .get(SESSION_ID_HEADER) + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()); + let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref()); let request_span = info_span!( @@ -99,6 +107,7 @@ pub async fn routing_decision( request_path, request_headers, custom_attrs, + session_id, ) .instrument(request_span) .await @@ -111,6 +120,7 @@ async fn routing_decision_inner( request_path: String, request_headers: hyper::HeaderMap, custom_attrs: std::collections::HashMap, + session_id: Option, ) -> Result>, hyper::Error> { set_service_name(operation_component::ROUTING); opentelemetry::trace::get_active_span(|span| { @@ -128,6 +138,34 @@ async fn routing_decision_inner( .unwrap_or("unknown") .to_string(); + // Session pinning: check cache before doing any routing work + if let Some(ref sid) = session_id { + if let Some(cached) = router_service.get_cached_route(sid).await { + info!( + session_id = %sid, + model = %cached.model_name, + route = ?cached.route_name, + "returning pinned routing decision from cache" + ); + let response = RoutingDecisionResponse { + model: cached.model_name, + route: cached.route_name, + trace_id, + session_id: Some(sid.clone()), + pinned: true, + }; + let json = serde_json::to_string(&response).unwrap(); + let body = Full::new(Bytes::from(json)) + .map_err(|never| match never {}) + .boxed(); + return Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(body) + .unwrap()); + } + } + // Parse request body let raw_bytes = request.collect().await?.to_bytes(); @@ -166,7 +204,7 @@ async fn routing_decision_inner( // Call the existing routing logic with inline preferences let routing_result = router_chat_get_upstream_model( - router_service, + Arc::clone(&router_service), client_request, &traceparent, &request_path, @@ -177,10 +215,23 @@ async fn routing_decision_inner( match routing_result { Ok(result) => { + // Cache the result if session_id is present + if let Some(ref sid) = session_id { + router_service + .cache_route( + sid.clone(), + result.model_name.clone(), + result.route_name.clone(), + ) + .await; + } + let response = RoutingDecisionResponse { model: result.model_name, route: result.route_name, trace_id, + session_id, + pinned: false, }; info!( @@ -318,12 +369,16 @@ mod tests { model: "openai/gpt-4o".to_string(), route: Some("code_generation".to_string()), trace_id: "abc123".to_string(), + session_id: Some("sess-abc".to_string()), + pinned: true, }; let json = serde_json::to_string(&response).unwrap(); let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); assert_eq!(parsed["model"], "openai/gpt-4o"); assert_eq!(parsed["route"], "code_generation"); assert_eq!(parsed["trace_id"], "abc123"); + assert_eq!(parsed["session_id"], "sess-abc"); + assert_eq!(parsed["pinned"], true); } #[test] @@ -332,10 +387,14 @@ mod tests { model: "none".to_string(), route: None, trace_id: "abc123".to_string(), + session_id: None, + pinned: false, }; let json = serde_json::to_string(&response).unwrap(); let parsed: serde_json::Value = serde_json::from_str(&json).unwrap(); assert_eq!(parsed["model"], "none"); assert!(parsed["route"].is_null()); + assert!(parsed.get("session_id").is_none()); + assert_eq!(parsed["pinned"], false); } } diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 60a69bca6..bdf54ab6d 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -162,13 +162,31 @@ async fn init_app_state( .map(|p| p.name.clone()) .unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string()); + let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds); + + let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries); + let router_service = Arc::new(RouterService::new( config.model_providers.clone(), format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"), routing_model_name, routing_llm_provider, + session_ttl_seconds, + session_max_entries, )); + // Spawn background task to clean up expired session cache entries every 5 minutes + { + let router_service = Arc::clone(&router_service); + tokio::spawn(async move { + let mut interval = tokio::time::interval(std::time::Duration::from_secs(300)); + loop { + interval.tick().await; + router_service.cleanup_expired_sessions().await; + } + }); + } + let orchestrator_model_name: String = overrides .agent_orchestration_model .as_deref() diff --git a/crates/brightstaff/src/router/llm.rs b/crates/brightstaff/src/router/llm.rs index 7d27e80a2..8eb2f0dba 100644 --- a/crates/brightstaff/src/router/llm.rs +++ b/crates/brightstaff/src/router/llm.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, sync::Arc}; +use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant}; use common::{ configuration::{LlmProvider, ModelUsagePreference, RoutingPreference}, @@ -7,6 +7,7 @@ use common::{ use hermesllm::apis::openai::Message; use hyper::header; use thiserror::Error; +use tokio::sync::RwLock; use tracing::{debug, info}; use super::http::{self, post_and_extract_content}; @@ -14,12 +15,25 @@ use super::router_model::RouterModel; use crate::router::router_model_v1; +const DEFAULT_SESSION_TTL_SECONDS: u64 = 600; +const DEFAULT_SESSION_MAX_ENTRIES: usize = 10_000; + +#[derive(Clone, Debug)] +pub struct CachedRoute { + pub model_name: String, + pub route_name: Option, + pub cached_at: Instant, +} + pub struct RouterService { router_url: String, client: reqwest::Client, router_model: Arc, routing_provider_name: String, llm_usage_defined: bool, + session_cache: RwLock>, + session_ttl: Duration, + session_max_entries: usize, } #[derive(Debug, Error)] @@ -39,6 +53,8 @@ impl RouterService { router_url: String, routing_model_name: String, routing_provider_name: String, + session_ttl_seconds: Option, + session_max_entries: Option, ) -> Self { let providers_with_usage = providers .iter() @@ -62,12 +78,74 @@ impl RouterService { router_model_v1::MAX_TOKEN_LEN, )); + let session_ttl = + Duration::from_secs(session_ttl_seconds.unwrap_or(DEFAULT_SESSION_TTL_SECONDS)); + let session_max_entries = session_max_entries.unwrap_or(DEFAULT_SESSION_MAX_ENTRIES); + RouterService { router_url, client: reqwest::Client::new(), router_model, routing_provider_name, llm_usage_defined: !providers_with_usage.is_empty(), + session_cache: RwLock::new(HashMap::new()), + session_ttl, + session_max_entries, + } + } + + /// Look up a cached routing decision by session ID. + /// Returns None if not found or expired. + pub async fn get_cached_route(&self, session_id: &str) -> Option { + let cache = self.session_cache.read().await; + if let Some(entry) = cache.get(session_id) { + if entry.cached_at.elapsed() < self.session_ttl { + return Some(entry.clone()); + } + } + None + } + + /// Store a routing decision in the session cache. + /// If at max capacity, evicts the oldest entry. + pub async fn cache_route( + &self, + session_id: String, + model_name: String, + route_name: Option, + ) { + let mut cache = self.session_cache.write().await; + if cache.len() >= self.session_max_entries && !cache.contains_key(&session_id) { + if let Some(oldest_key) = cache + .iter() + .min_by_key(|(_, v)| v.cached_at) + .map(|(k, _)| k.clone()) + { + cache.remove(&oldest_key); + } + } + cache.insert( + session_id, + CachedRoute { + model_name, + route_name, + cached_at: Instant::now(), + }, + ); + } + + /// Remove all expired entries from the session cache. + pub async fn cleanup_expired_sessions(&self) { + let mut cache = self.session_cache.write().await; + let before = cache.len(); + cache.retain(|_, entry| entry.cached_at.elapsed() < self.session_ttl); + let removed = before - cache.len(); + if removed > 0 { + info!( + removed = removed, + remaining = cache.len(), + "cleaned up expired session cache entries" + ); } } @@ -146,3 +224,101 @@ impl RouterService { Ok(parsed) } } + +#[cfg(test)] +mod tests { + use super::*; + + fn make_router_service(ttl_seconds: u64, max_entries: usize) -> RouterService { + RouterService::new( + vec![], + "http://localhost:12001/v1/chat/completions".to_string(), + "Arch-Router".to_string(), + "arch-router".to_string(), + Some(ttl_seconds), + Some(max_entries), + ) + } + + #[tokio::test] + async fn test_cache_miss_returns_none() { + let svc = make_router_service(600, 100); + assert!(svc.get_cached_route("unknown-session").await.is_none()); + } + + #[tokio::test] + async fn test_cache_hit_returns_cached_route() { + let svc = make_router_service(600, 100); + svc.cache_route( + "s1".to_string(), + "gpt-4o".to_string(), + Some("code".to_string()), + ) + .await; + + let cached = svc.get_cached_route("s1").await.unwrap(); + assert_eq!(cached.model_name, "gpt-4o"); + assert_eq!(cached.route_name, Some("code".to_string())); + } + + #[tokio::test] + async fn test_cache_expired_entry_returns_none() { + let svc = make_router_service(0, 100); + svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None) + .await; + assert!(svc.get_cached_route("s1").await.is_none()); + } + + #[tokio::test] + async fn test_cleanup_removes_expired() { + let svc = make_router_service(0, 100); + svc.cache_route("s1".to_string(), "gpt-4o".to_string(), None) + .await; + svc.cache_route("s2".to_string(), "claude".to_string(), None) + .await; + + svc.cleanup_expired_sessions().await; + + let cache = svc.session_cache.read().await; + assert!(cache.is_empty()); + } + + #[tokio::test] + async fn test_cache_evicts_oldest_when_full() { + let svc = make_router_service(600, 2); + svc.cache_route("s1".to_string(), "model-a".to_string(), None) + .await; + tokio::time::sleep(Duration::from_millis(10)).await; + svc.cache_route("s2".to_string(), "model-b".to_string(), None) + .await; + + svc.cache_route("s3".to_string(), "model-c".to_string(), None) + .await; + + let cache = svc.session_cache.read().await; + assert_eq!(cache.len(), 2); + assert!(!cache.contains_key("s1")); + assert!(cache.contains_key("s2")); + assert!(cache.contains_key("s3")); + } + + #[tokio::test] + async fn test_cache_update_existing_session_does_not_evict() { + let svc = make_router_service(600, 2); + svc.cache_route("s1".to_string(), "model-a".to_string(), None) + .await; + svc.cache_route("s2".to_string(), "model-b".to_string(), None) + .await; + + svc.cache_route( + "s1".to_string(), + "model-a-updated".to_string(), + Some("route".to_string()), + ) + .await; + + let cache = svc.session_cache.read().await; + assert_eq!(cache.len(), 2); + assert_eq!(cache.get("s1").unwrap().model_name, "model-a-updated"); + } +} diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index df1790594..929efdbed 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -7,6 +7,14 @@ use crate::api::open_ai::{ ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType, }; +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Routing { + pub model_provider: Option, + pub model: Option, + pub session_ttl_seconds: Option, + pub session_max_entries: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelAlias { pub target: String, @@ -111,6 +119,7 @@ pub struct Configuration { pub model_providers: Vec, pub model_aliases: Option>, pub overrides: Option, + pub routing: Option, pub system_prompt: Option, pub prompt_guards: Option, pub prompt_targets: Option>, diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index dbd0bc417..179c66ac1 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -22,6 +22,7 @@ pub const X_ARCH_TOOL_CALL: &str = "x-arch-tool-call-message"; pub const X_ARCH_FC_MODEL_RESPONSE: &str = "x-arch-fc-model-response"; pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function"; pub const REQUEST_ID_HEADER: &str = "x-request-id"; +pub const SESSION_ID_HEADER: &str = "x-session-id"; pub const ENVOY_ORIGINAL_PATH_HEADER: &str = "x-envoy-original-path"; pub const TRACE_PARENT_HEADER: &str = "traceparent"; pub const ARCH_INTERNAL_CLUSTER_NAME: &str = "arch_internal"; diff --git a/demos/llm_routing/model_routing_service/README.md b/demos/llm_routing/model_routing_service/README.md index 72b672f32..ce73224d6 100644 --- a/demos/llm_routing/model_routing_service/README.md +++ b/demos/llm_routing/model_routing_service/README.md @@ -103,6 +103,63 @@ Response: The response tells you which model would handle this request and which route was matched, without actually making the LLM call. +## Session Pinning + +Send an `X-Session-Id` header to pin the routing decision for a session. Once a model is selected, all subsequent requests with the same session ID return the same model without re-running routing. + +```bash +# First call — runs routing, caches result +curl http://localhost:12000/routing/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "X-Session-Id: my-session-123" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Write a Python function for binary search"}] + }' +``` + +Response (first call): +```json +{ + "model": "anthropic/claude-sonnet-4-20250514", + "route": "code_generation", + "trace_id": "c16d1096c1af4a17abb48fb182918a88", + "session_id": "my-session-123", + "pinned": false +} +``` + +```bash +# Second call — same session, returns cached result +curl http://localhost:12000/routing/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "X-Session-Id: my-session-123" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [{"role": "user", "content": "Now explain merge sort"}] + }' +``` + +Response (pinned): +```json +{ + "model": "anthropic/claude-sonnet-4-20250514", + "route": "code_generation", + "trace_id": "a1b2c3d4e5f6...", + "session_id": "my-session-123", + "pinned": true +} +``` + +Session TTL and max cache size are configurable in `config.yaml`: +```yaml +routing: + session_ttl_seconds: 600 # default: 600 (10 minutes) + session_max_entries: 10000 # default: 10000 +``` + +Without the `X-Session-Id` header, routing runs fresh every time (no breaking change). + ## Kubernetes Deployment (Self-hosted Arch-Router on GPU) To run Arch-Router in-cluster using vLLM instead of the default hosted endpoint: @@ -199,5 +256,33 @@ kubectl rollout restart deployment/plano "trace_id": "26be822bbdf14a3ba19fe198e55ea4a9" } +--- 7. Session pinning - first call (fresh routing decision) --- +{ + "model": "anthropic/claude-sonnet-4-20250514", + "route": "code_generation", + "trace_id": "f1a2b3c4d5e6f7a8b9c0d1e2f3a4b5c6", + "session_id": "demo-session-001", + "pinned": false +} + +--- 8. Session pinning - second call (same session, pinned) --- + Notice: same model returned with "pinned": true, routing was skipped +{ + "model": "anthropic/claude-sonnet-4-20250514", + "route": "code_generation", + "trace_id": "a9b8c7d6e5f4a3b2c1d0e9f8a7b6c5d4", + "session_id": "demo-session-001", + "pinned": true +} + +--- 9. Different session gets its own fresh routing --- +{ + "model": "openai/gpt-4o", + "route": "complex_reasoning", + "trace_id": "1a2b3c4d5e6f7a8b9c0d1e2f3a4b5c6d", + "session_id": "demo-session-002", + "pinned": false +} + === Demo Complete === ``` diff --git a/demos/llm_routing/model_routing_service/demo.sh b/demos/llm_routing/model_routing_service/demo.sh index 0c3fdc5d6..1e3d3b6cb 100755 --- a/demos/llm_routing/model_routing_service/demo.sh +++ b/demos/llm_routing/model_routing_service/demo.sh @@ -117,4 +117,47 @@ curl -s "$PLANO_URL/routing/v1/messages" \ }' | python3 -m json.tool echo "" +# --- Example 7: Session pinning - first call (fresh routing) --- +echo "--- 7. Session pinning - first call (fresh routing decision) ---" +echo "" +curl -s "$PLANO_URL/routing/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "X-Session-Id: demo-session-001" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Write a Python function that implements binary search on a sorted array"} + ] + }' | python3 -m json.tool +echo "" + +# --- Example 8: Session pinning - second call (pinned result) --- +echo "--- 8. Session pinning - second call (same session, pinned) ---" +echo " Notice: same model returned with \"pinned\": true, routing was skipped" +echo "" +curl -s "$PLANO_URL/routing/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "X-Session-Id: demo-session-001" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Now explain how merge sort works and when to prefer it over quicksort"} + ] + }' | python3 -m json.tool +echo "" + +# --- Example 9: Different session gets fresh routing --- +echo "--- 9. Different session gets its own fresh routing ---" +echo "" +curl -s "$PLANO_URL/routing/v1/chat/completions" \ + -H "Content-Type: application/json" \ + -H "X-Session-Id: demo-session-002" \ + -d '{ + "model": "gpt-4o-mini", + "messages": [ + {"role": "user", "content": "Explain the trade-offs between microservices and monolithic architectures"} + ] + }' | python3 -m json.tool +echo "" + echo "=== Demo Complete ===" diff --git a/demos/llm_routing/session_pinning/README.md b/demos/llm_routing/session_pinning/README.md new file mode 100644 index 000000000..a84d440e4 --- /dev/null +++ b/demos/llm_routing/session_pinning/README.md @@ -0,0 +1,156 @@ +# Session Pinning Demo + +> Consistent model selection for agentic loops using `X-Session-Id`. + +## Why Session Pinning? + +When an agent runs in a loop — research → analyse → implement → evaluate → summarise — each step hits Plano's router independently. Because prompts vary in intent, the router may select **different models** for each step, fragmenting context mid-session. + +**Session pinning** solves this: send an `X-Session-Id` header and the first request runs routing as usual, caching the decision. Every subsequent request with the same session ID returns the **same model**, without re-running the router. + +``` +Without pinning With pinning (X-Session-Id) +───────────────── ────────────────────────── +Step 1 → claude-sonnet (code_gen) Step 1 → claude-sonnet ← routed +Step 2 → gpt-4o (reasoning) Step 2 → claude-sonnet ← pinned ✓ +Step 3 → claude-sonnet (code_gen) Step 3 → claude-sonnet ← pinned ✓ +Step 4 → gpt-4o (reasoning) Step 4 → claude-sonnet ← pinned ✓ +Step 5 → claude-sonnet (code_gen) Step 5 → claude-sonnet ← pinned ✓ + ↑ model switches every step ↑ one model, start to finish +``` + +--- + +## Quick Start + +```bash +# 1. Set API keys +export OPENAI_API_KEY= +export ANTHROPIC_API_KEY= + +# 2. Start Plano +cd demos/llm_routing/session_pinning +planoai up config.yaml + +# 3. Run the demo (uv manages dependencies automatically) +./demo.sh # or: uv run demo.py +``` + +--- + +## What the Demo Does + +A **Database Research Agent** investigates whether to use PostgreSQL or MongoDB +for an e-commerce platform. It runs 5 steps, each building on prior findings via +accumulated message history. Steps alternate between `code_generation` and +`complex_reasoning` intents so Plano routes to different models without pinning. + +| Step | Task | Intent | +|:----:|------|--------| +| 1 | List technical requirements | code_generation → claude-sonnet | +| 2 | Compare PostgreSQL vs MongoDB | complex_reasoning → gpt-4o | +| 3 | Write schema (CREATE TABLE) | code_generation → claude-sonnet | +| 4 | Assess scalability trade-offs | complex_reasoning → gpt-4o | +| 5 | Write final recommendation report | code_generation → claude-sonnet | + +The demo runs the loop **twice** against `/v1/chat/completions` using the +[OpenAI SDK](https://github.com/openai/openai-python): + +1. **Without pinning** — no `X-Session-Id`; models alternate per step +2. **With pinning** — `X-Session-Id` header included; model is pinned from step 1 + +Each step makes real LLM calls. Step 5's report explicitly references findings +from earlier steps, demonstrating why coherent context requires a consistent model. + +### Expected Output + +``` + Run 1: WITHOUT Session Pinning + ───────────────────────────────────────────────────────────────────── + step 1 [claude-sonnet-4-20250514] List requirements + "Critical requirements: 1. ACID transactions for order integrity…" + + step 2 [gpt-4o ] Compare databases ← switched + "PostgreSQL excels at joins and ACID guarantees…" + + step 3 [claude-sonnet-4-20250514] Write schema ← switched + "CREATE TABLE orders (\n id SERIAL PRIMARY KEY…" + + step 4 [gpt-4o ] Assess scalability ← switched + "At high write volume, PostgreSQL row-level locking…" + + step 5 [claude-sonnet-4-20250514] Write report ← switched + "RECOMMENDATION: PostgreSQL is the right choice…" + + ✗ Without pinning: model switched 4 time(s) — gpt-4o, claude-sonnet-4-20250514 + + + Run 2: WITH Session Pinning (X-Session-Id: a1b2c3d4…) + ───────────────────────────────────────────────────────────────────── + step 1 [claude-sonnet-4-20250514] List requirements + "Critical requirements: 1. ACID transactions for order integrity…" + + step 2 [claude-sonnet-4-20250514] Compare databases + "Building on the requirements I just outlined: PostgreSQL…" + + step 3 [claude-sonnet-4-20250514] Write schema + "Following the comparison above, here is the PostgreSQL schema…" + + step 4 [claude-sonnet-4-20250514] Assess scalability + "Given the schema I designed, PostgreSQL's row-level locking…" + + step 5 [claude-sonnet-4-20250514] Write report + "RECOMMENDATION: Based on my analysis of requirements, comparison…" + + ✓ With pinning: claude-sonnet-4-20250514 held for all 5 steps + + ══ Final Report (pinned session) ═════════════════════════════════════ + RECOMMENDATION: Based on my analysis of requirements, the head-to-head + comparison, the schema I designed, and the scalability trade-offs… + ══════════════════════════════════════════════════════════════════════ +``` + +### How It Works + +Session pinning is implemented in brightstaff. When `X-Session-Id` is present: + +1. **First request** — routing runs normally, result is cached keyed by session ID +2. **Subsequent requests** — cache hit skips routing and returns the cached model instantly + +The `X-Session-Id` header is forwarded transparently; no changes to your OpenAI +SDK calls beyond adding the header. + +```python +from openai import OpenAI + +client = OpenAI(base_url="http://localhost:12000/v1", api_key="EMPTY") + +session_id = str(uuid.uuid4()) + +response = client.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": prompt}], + extra_headers={"X-Session-Id": session_id}, # pin the session +) +``` + +--- + +## Configuration + +Session pinning is configurable in `config.yaml`: + +```yaml +routing: + session_ttl_seconds: 600 # How long a pinned session lasts (default: 10 min) + session_max_entries: 10000 # Max cached sessions before LRU eviction +``` + +Without the `X-Session-Id` header, routing runs fresh every time — no breaking +change to existing clients. + +--- + +## See Also + +- [Model Routing Service Demo](../model_routing_service/) — curl-based examples of the routing endpoint diff --git a/demos/llm_routing/session_pinning/agent.py b/demos/llm_routing/session_pinning/agent.py new file mode 100644 index 000000000..ffb553d36 --- /dev/null +++ b/demos/llm_routing/session_pinning/agent.py @@ -0,0 +1,429 @@ +#!/usr/bin/env -S uv run --script +# /// script +# requires-python = ">=3.12" +# dependencies = ["fastapi>=0.115", "uvicorn>=0.30", "openai>=1.0.0"] +# /// +""" +Research Agent — FastAPI service exposing /v1/chat/completions. + +For each incoming request the agent runs 3 independent research tasks, +each with its own tool-calling loop. The tasks deliberately alternate between +code_generation and complex_reasoning intents so Plano's preference-based +router selects different models for each task. + +If the client sends X-Session-Id, the agent forwards it on every outbound +call to Plano. The first task pins the model; all subsequent tasks skip the +router and reuse it — keeping the whole session on one consistent model. + +Run standalone: + uv run agent.py + PLANO_URL=http://myhost:12000 AGENT_PORT=8000 uv run agent.py +""" + +import json +import logging +import os +import uuid + +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletionMessageParam + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [AGENT] %(levelname)s %(message)s", +) +log = logging.getLogger(__name__) + +PLANO_URL = os.environ.get("PLANO_URL", "http://localhost:12000") +PORT = int(os.environ.get("AGENT_PORT", "8000")) + +# --------------------------------------------------------------------------- +# Tasks — each has its own conversation so Plano routes each independently. +# Intent alternates: code_generation → complex_reasoning → code_generation. +# --------------------------------------------------------------------------- + +TASKS = [ + { + "name": "generate_comparison", + # Triggers code_generation routing preference (write/generate output) + "prompt": ( + "Use the tools to fetch benchmark data for PostgreSQL and MongoDB " + "under a mixed workload. Then generate a compact Markdown comparison " + "table with columns: metric, PostgreSQL, MongoDB. Cover read QPS, " + "write QPS, p99 latency ms, ACID support, and horizontal scaling." + ), + }, + { + "name": "analyse_tradeoffs", + # Triggers complex_reasoning routing preference (analyse/reason/evaluate) + "prompt": ( + "Context from prior research:\n{context}\n\n" + "Perform a deep analysis: for a high-traffic e-commerce platform that " + "requires ACID guarantees for order processing but flexible schemas for " + "product attributes, carefully reason through and evaluate the long-term " + "architectural trade-offs of each database. Consider consistency " + "guarantees, operational complexity, and scalability risks." + ), + }, + { + "name": "write_schema", + # Triggers code_generation routing preference (write SQL / generate code) + "prompt": ( + "Context from prior research:\n{context}\n\n" + "Write the CREATE TABLE SQL schema for the database you would recommend " + "from the analysis above. Include: orders, order_items, products, and " + "users tables with appropriate primary keys, foreign keys, and indexes." + ), + }, +] + +SYSTEM_PROMPT = ( + "You are a database selection analyst for an e-commerce platform. " + "Use the available tools when you need data. " + "Be concise — each response should be a compact table, code block, " + "or 3–5 clear sentences." +) + +# --------------------------------------------------------------------------- +# Tool definitions +# --------------------------------------------------------------------------- + +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_db_benchmarks", + "description": ( + "Fetch performance benchmark data for a database. " + "Returns read/write throughput, latency, and scaling characteristics." + ), + "parameters": { + "type": "object", + "properties": { + "database": { + "type": "string", + "enum": ["postgresql", "mongodb"], + }, + "workload": { + "type": "string", + "enum": ["read_heavy", "write_heavy", "mixed"], + }, + }, + "required": ["database", "workload"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_case_studies", + "description": "Retrieve e-commerce case studies for a database.", + "parameters": { + "type": "object", + "properties": { + "database": {"type": "string", "enum": ["postgresql", "mongodb"]}, + }, + "required": ["database"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "check_feature_support", + "description": ( + "Check whether a database supports a specific feature " + "(e.g. ACID transactions, horizontal sharding, JSON documents)." + ), + "parameters": { + "type": "object", + "properties": { + "database": {"type": "string", "enum": ["postgresql", "mongodb"]}, + "feature": {"type": "string"}, + }, + "required": ["database", "feature"], + }, + }, + }, +] + +# --------------------------------------------------------------------------- +# Tool implementations (simulated — no external calls) +# --------------------------------------------------------------------------- + +_BENCHMARKS = { + ("postgresql", "read_heavy"): { + "read_qps": 55_000, + "write_qps": 18_000, + "p99_ms": 4, + "notes": "Excellent for complex joins; connection pooling via pgBouncer recommended", + }, + ("postgresql", "write_heavy"): { + "read_qps": 30_000, + "write_qps": 24_000, + "p99_ms": 8, + "notes": "WAL overhead increases at very high write volume; partitioning helps", + }, + ("postgresql", "mixed"): { + "read_qps": 42_000, + "write_qps": 21_000, + "p99_ms": 6, + "notes": "Solid all-round; MVCC keeps reads non-blocking", + }, + ("mongodb", "read_heavy"): { + "read_qps": 85_000, + "write_qps": 30_000, + "p99_ms": 2, + "notes": "Atlas Search built-in; sharding distributes read load well", + }, + ("mongodb", "write_heavy"): { + "read_qps": 40_000, + "write_qps": 65_000, + "p99_ms": 3, + "notes": "WiredTiger compression reduces I/O; journal writes are async-safe", + }, + ("mongodb", "mixed"): { + "read_qps": 60_000, + "write_qps": 50_000, + "p99_ms": 3, + "notes": "Flexible schema accelerates feature iteration", + }, +} + +_CASE_STUDIES = { + "postgresql": [ + { + "company": "Shopify", + "scale": "100 B+ req/day", + "notes": "Moved critical order tables back to Postgres for ACID guarantees", + }, + { + "company": "Zalando", + "scale": "50 M customers", + "notes": "Uses Postgres + Citus for sharded order processing", + }, + { + "company": "Instacart", + "scale": "10 M orders/mo", + "notes": "Postgres for inventory; strict consistency required for stock levels", + }, + ], + "mongodb": [ + { + "company": "eBay", + "scale": "1.5 B listings", + "notes": "Product catalogue in MongoDB for flexible attribute schemas", + }, + { + "company": "Alibaba", + "scale": "billions of docs", + "notes": "Session and cart data in MongoDB; high write throughput", + }, + { + "company": "Foursquare", + "scale": "10 B+ check-ins", + "notes": "Geospatial queries and flexible location schemas", + }, + ], +} + +_FEATURES = { + ("postgresql", "acid transactions"): { + "supported": True, + "notes": "Full ACID with serialisable isolation", + }, + ("postgresql", "horizontal sharding"): { + "supported": True, + "notes": "Via Citus extension or manual partitioning; not native", + }, + ("postgresql", "json documents"): { + "supported": True, + "notes": "JSONB with indexing; flexible but slower than native doc store", + }, + ("postgresql", "full-text search"): { + "supported": True, + "notes": "Built-in tsvector/tsquery; Elasticsearch for advanced use cases", + }, + ("postgresql", "multi-document transactions"): { + "supported": True, + "notes": "Native cross-table ACID", + }, + ("mongodb", "acid transactions"): { + "supported": True, + "notes": "Multi-document ACID since v4.0; single-doc always atomic", + }, + ("mongodb", "horizontal sharding"): { + "supported": True, + "notes": "Native sharding; auto-balancing across shards", + }, + ("mongodb", "json documents"): { + "supported": True, + "notes": "Native BSON document model; schema-free by default", + }, + ("mongodb", "full-text search"): { + "supported": True, + "notes": "Atlas Search (Lucene-based) for advanced full-text", + }, + ("mongodb", "multi-document transactions"): { + "supported": True, + "notes": "Available but adds latency; best avoided on hot paths", + }, +} + + +def _dispatch(name: str, args: dict) -> str: + if name == "get_db_benchmarks": + key = (args["database"].lower(), args["workload"].lower()) + return json.dumps(_BENCHMARKS.get(key, {"error": f"no data for {key}"})) + + if name == "get_case_studies": + db = args["database"].lower() + return json.dumps(_CASE_STUDIES.get(db, {"error": f"unknown db '{db}'"})) + + if name == "check_feature_support": + key = (args["database"].lower(), args["feature"].lower()) + for k, v in _FEATURES.items(): + if k[0] == key[0] and k[1] in key[1]: + return json.dumps(v) + return json.dumps({"error": f"feature '{args['feature']}' not in dataset"}) + + return json.dumps({"error": f"unknown tool '{name}'"}) + + +# --------------------------------------------------------------------------- +# Task runner — one independent conversation per task +# --------------------------------------------------------------------------- + + +async def run_task( + client: AsyncOpenAI, + task_name: str, + prompt: str, + session_id: str | None, +) -> tuple[str, str]: + """ + Run a single research task with its own tool-calling loop. + + Each task is an independent conversation so the router sees only + this task's intent — not the accumulated context of previous tasks. + Session pinning via X-Session-Id pins the model from the first task + onward, so all tasks stay on the same model. + + Returns (answer, first_model_used). + """ + headers = {"X-Session-Id": session_id} if session_id else {} + messages: list[ChatCompletionMessageParam] = [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": prompt}, + ] + first_model: str | None = None + + while True: + resp = await client.chat.completions.create( + model="gpt-4o-mini", # Plano's router overrides this via routing_preferences + messages=messages, + tools=TOOLS, + tool_choice="auto", + max_completion_tokens=600, + extra_headers=headers or None, + ) + if first_model is None: + first_model = resp.model + + log.info( + "task=%s model=%s finish=%s", + task_name, + resp.model, + resp.choices[0].finish_reason, + ) + + choice = resp.choices[0] + if choice.finish_reason == "tool_calls" and choice.message.tool_calls: + messages.append(choice.message) + for tc in choice.message.tool_calls: + args = json.loads(tc.function.arguments or "{}") + result = _dispatch(tc.function.name, args) + log.info(" tool %s(%s)", tc.function.name, args) + messages.append( + {"role": "tool", "content": result, "tool_call_id": tc.id} + ) + else: + return (choice.message.content or "").strip(), first_model or "unknown" + + +# --------------------------------------------------------------------------- +# Research loop — runs all tasks, threading context forward +# --------------------------------------------------------------------------- + + +async def run_research_loop( + client: AsyncOpenAI, + session_id: str | None, +) -> tuple[str, list[dict]]: + """ + Run all 3 research tasks in sequence, passing each task's output as + context to the next. Returns (final_answer, routing_trace). + """ + context = "" + trace: list[dict] = [] + final_answer = "" + + for task in TASKS: + prompt = task["prompt"].format(context=context) + answer, model = await run_task(client, task["name"], prompt, session_id) + trace.append({"task": task["name"], "model": model}) + context += f"\n### {task['name']}\n{answer}\n" + final_answer = answer + + return final_answer, trace + + +# --------------------------------------------------------------------------- +# FastAPI app +# --------------------------------------------------------------------------- + +app = FastAPI(title="Research Agent", version="1.0.0") + + +@app.post("/v1/chat/completions") +async def chat(request: Request) -> JSONResponse: + body = await request.json() + session_id: str | None = request.headers.get("x-session-id") + + log.info("request session_id=%s", session_id or "none") + + client = AsyncOpenAI(base_url=f"{PLANO_URL}/v1", api_key="EMPTY") + answer, trace = await run_research_loop(client, session_id) + + return JSONResponse( + { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": answer}, + "finish_reason": "stop", + } + ], + "routing_trace": trace, + "session_id": session_id, + } + ) + + +@app.get("/health") +async def health() -> dict: + return {"status": "ok", "plano_url": PLANO_URL} + + +# --------------------------------------------------------------------------- +# Entry point +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + log.info("starting on port %d plano=%s", PORT, PLANO_URL) + uvicorn.run(app, host="0.0.0.0", port=PORT, log_level="warning") diff --git a/demos/llm_routing/session_pinning/config.yaml b/demos/llm_routing/session_pinning/config.yaml new file mode 100644 index 000000000..7b98b25b7 --- /dev/null +++ b/demos/llm_routing/session_pinning/config.yaml @@ -0,0 +1,27 @@ +version: v0.3.0 + +listeners: + - type: model + name: model_listener + port: 12000 + +model_providers: + + - model: openai/gpt-4o-mini + access_key: $OPENAI_API_KEY + default: true + + - model: openai/gpt-4o + access_key: $OPENAI_API_KEY + routing_preferences: + - name: complex_reasoning + description: complex reasoning tasks, multi-step analysis, or detailed explanations + + - model: anthropic/claude-sonnet-4-20250514 + access_key: $ANTHROPIC_API_KEY + routing_preferences: + - name: code_generation + description: generating new code, writing functions, or creating boilerplate + +tracing: + random_sampling: 100 diff --git a/demos/llm_routing/session_pinning/demo.py b/demos/llm_routing/session_pinning/demo.py new file mode 100644 index 000000000..fdf7634b5 --- /dev/null +++ b/demos/llm_routing/session_pinning/demo.py @@ -0,0 +1,174 @@ +#!/usr/bin/env -S uv run --script +# /// script +# requires-python = ">=3.12" +# dependencies = ["httpx>=0.27"] +# /// +""" +Session Pinning Demo — Research Agent client + +Sends the same query to the Research Agent twice — once without a session ID +and once with one — and compares the routing trace to show how session pinning +keeps the model consistent across the LLM's tool-calling loop. + +Requires the agent to already be running (start it with ./start_agents.sh). + +Usage: + uv run demo.py + AGENT_URL=http://localhost:8000 uv run demo.py +""" + +import asyncio +import os +import uuid + +import httpx + +AGENT_URL = os.environ.get("AGENT_URL", "http://localhost:8000") + +QUERY = ( + "Should we use PostgreSQL or MongoDB for a high-traffic e-commerce backend " + "that needs strong consistency for orders but flexible schemas for products?" +) + + +# --------------------------------------------------------------------------- +# Client helpers +# --------------------------------------------------------------------------- + + +async def wait_for_agent(timeout: int = 30) -> bool: + async with httpx.AsyncClient() as client: + for _ in range(timeout * 2): + try: + r = await client.get(f"{AGENT_URL}/health", timeout=1.0) + if r.status_code == 200: + return True + except Exception: + pass + await asyncio.sleep(0.5) + return False + + +async def ask_agent(query: str, session_id: str | None = None) -> dict: + headers: dict[str, str] = {} + if session_id: + headers["X-Session-Id"] = session_id + + async with httpx.AsyncClient(timeout=120.0) as client: + r = await client.post( + f"{AGENT_URL}/v1/chat/completions", + headers=headers, + json={"messages": [{"role": "user", "content": query}]}, + ) + r.raise_for_status() + return r.json() + + +# --------------------------------------------------------------------------- +# Display helpers +# --------------------------------------------------------------------------- + + +def _short(model: str) -> str: + return model.split("/")[-1] if "/" in model else model + + +def _print_trace(result: dict) -> None: + trace = result.get("routing_trace", []) + if not trace: + print(" (no trace)") + return + + prev: str | None = None + for t in trace: + short = _short(t["model"]) + switch = " ← switched" if (prev and t["model"] != prev) else "" + prev = t["model"] + print(f" {t['task']:<26} [{short}]{switch}") + + +def _print_summary(label: str, result: dict) -> None: + models = [t["model"] for t in result.get("routing_trace", [])] + if not models: + print(f" ? {label}: no routing data") + return + unique = set(models) + if len(unique) == 1: + print(f" ✓ {label}: {_short(next(iter(unique)))} for all {len(models)} turns") + else: + switched = sum(1 for a, b in zip(models, models[1:]) if a != b) + names = ", ".join(sorted(_short(m) for m in unique)) + print(f" ✗ {label}: model switched {switched} time(s) — {names}") + + +# --------------------------------------------------------------------------- +# Demo +# --------------------------------------------------------------------------- + + +async def main() -> None: + print() + print(" ╔══════════════════════════════════════════════════════════════╗") + print(" ║ Session Pinning Demo — Research Agent ║") + print(" ╚══════════════════════════════════════════════════════════════╝") + print() + print(f" Agent : {AGENT_URL}") + print(f" Query : \"{QUERY[:72]}…\"") + print() + print(" The agent uses a tool-calling loop (get_db_benchmarks,") + print(" get_case_studies, check_feature_support) to research the") + print(" question. Each LLM turn hits Plano's preference-based router.") + print() + + print(f" Waiting for agent at {AGENT_URL}…", end=" ", flush=True) + if not await wait_for_agent(): + print("FAILED — agent did not respond within 30 s") + return + print("ready.") + print() + + sid = str(uuid.uuid4()) + print(" Sending queries (running concurrently)…") + print() + without, with_pin = await asyncio.gather( + ask_agent(QUERY, session_id=None), + ask_agent(QUERY, session_id=sid), + ) + + # ── Run 1 ──────────────────────────────────────────────────────────── + print(" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + print(" Run 1: WITHOUT Session Pinning") + print(" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + print() + print(" LLM turns inside the agent loop:") + print() + _print_trace(without) + print() + _print_summary("Without pinning", without) + print() + + # ── Run 2 ──────────────────────────────────────────────────────────── + print(" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + print(f" Run 2: WITH Session Pinning (X-Session-Id: {sid[:8]}…)") + print(" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━") + print() + print(" LLM turns inside the agent loop:") + print() + _print_trace(with_pin) + print() + _print_summary("With pinning ", with_pin) + print() + + # ── Final answer ───────────────────────────────────────────────────── + answer = with_pin["choices"][0]["message"]["content"] + print(" ══ Agent recommendation (pinned session) ═════════════════════") + print() + for line in answer.splitlines(): + print(f" {line}") + print() + print(" ══════════════════════════════════════════════════════════════") + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/demos/llm_routing/session_pinning/demo.sh b/demos/llm_routing/session_pinning/demo.sh new file mode 100755 index 000000000..210fd1361 --- /dev/null +++ b/demos/llm_routing/session_pinning/demo.sh @@ -0,0 +1,19 @@ +#!/bin/bash +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +export PLANO_URL="${PLANO_URL:-http://localhost:12000}" +export AGENT_PORT="${AGENT_PORT:-8000}" +export AGENT_URL="http://localhost:$AGENT_PORT" + +cleanup() { + [ -n "$AGENT_PID" ] && kill "$AGENT_PID" 2>/dev/null +} +trap cleanup EXIT INT TERM + +# Start the agent in the background +"$SCRIPT_DIR/start_agents.sh" & +AGENT_PID=$! + +# Run the demo client +uv run "$SCRIPT_DIR/demo.py" diff --git a/demos/llm_routing/session_pinning/start_agents.sh b/demos/llm_routing/session_pinning/start_agents.sh new file mode 100755 index 000000000..5baaa3785 --- /dev/null +++ b/demos/llm_routing/session_pinning/start_agents.sh @@ -0,0 +1,28 @@ +#!/bin/bash +set -e + +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +PIDS=() + +log() { echo "$(date '+%F %T') - $*"; } + +cleanup() { + log "Stopping agents..." + for PID in "${PIDS[@]}"; do + kill "$PID" 2>/dev/null && log "Stopped process $PID" + done + exit 0 +} + +trap cleanup EXIT INT TERM + +export PLANO_URL="${PLANO_URL:-http://localhost:12000}" +export AGENT_PORT="${AGENT_PORT:-8000}" + +log "Starting research_agent on port $AGENT_PORT..." +uv run "$SCRIPT_DIR/agent.py" & +PIDS+=($!) + +for PID in "${PIDS[@]}"; do + wait "$PID" +done