16 changes: 16 additions & 0 deletions config/plano_config_schema.yaml
@@ -422,6 +422,22 @@ properties:
enum:
- llm
- prompt
routing:
type: object
properties:
llm_provider:
type: string
model:
type: string
session_ttl_seconds:
type: integer
minimum: 1
description: TTL in seconds for session-pinned routing cache entries. Default 600 (10 minutes).
session_max_entries:
type: integer
minimum: 1
description: Maximum number of session-pinned routing cache entries. Default 10000.
additionalProperties: false
state_storage:
type: object
properties:
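The schema only constrains the shape; for illustration, a config that opts into session pinning might look like the snippet below. The llm_provider and model values are placeholders, not defaults, and the comments simply restate the defaults documented in the schema (600 seconds, 10000 entries):

    routing:
      llm_provider: openai        # provider used for routing decisions (illustrative value)
      model: gpt-4o-mini          # routing model (illustrative value)
      session_ttl_seconds: 600    # pinned decisions expire after 10 minutes (schema default)
      session_max_entries: 10000  # cap on cached sessions (schema default)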
110 changes: 72 additions & 38 deletions crates/brightstaff/src/handlers/llm/mod.rs
@@ -1,6 +1,6 @@
use bytes::Bytes;
use common::configuration::{FilterPipeline, ModelAlias};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, SESSION_ID_HEADER};
use common::llm_providers::LlmProviders;
use hermesllm::apis::openai::Message;
use hermesllm::apis::openai_responses::InputParam;
@@ -92,6 +92,21 @@ async fn llm_chat_inner(

let traceparent = extract_or_generate_traceparent(&request_headers);

// Session pinning: extract session ID and check cache before routing
let session_id: Option<String> = request_headers
.get(SESSION_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());
let pinned_model: Option<String> = if let Some(ref sid) = session_id {
state
.router_service
.get_cached_route(sid)
.await
.map(|c| c.model_name)
} else {
None
};

let full_qualified_llm_provider_url = format!("{}{}", state.llm_provider_url, request_path);

// --- Phase 1: Parse and validate the incoming request ---
Expand Down Expand Up @@ -242,46 +257,65 @@ async fn llm_chat_inner(
}
};

// --- Phase 3: Route the request ---
let routing_span = info_span!(
"routing",
component = "routing",
http.method = "POST",
http.target = %request_path,
model.requested = %model_from_request,
model.alias_resolved = %alias_resolved_model,
route.selected_model = tracing::field::Empty,
routing.determination_ms = tracing::field::Empty,
);
let routing_result = match async {
set_service_name(operation_component::ROUTING);
router_chat_get_upstream_model(
Arc::clone(&state.router_service),
client_request,
&traceparent,
&request_path,
&request_id,
inline_routing_policy,
)
// --- Phase 3: Route the request (or use pinned model from session cache) ---
let resolved_model = if let Some(cached_model) = pinned_model {
info!(
session_id = %session_id.as_deref().unwrap_or(""),
model = %cached_model,
"using pinned routing decision from cache"
);
cached_model
} else {
let routing_span = info_span!(
"routing",
component = "routing",
http.method = "POST",
http.target = %request_path,
model.requested = %model_from_request,
model.alias_resolved = %alias_resolved_model,
route.selected_model = tracing::field::Empty,
routing.determination_ms = tracing::field::Empty,
);
let routing_result = match async {
set_service_name(operation_component::ROUTING);
router_chat_get_upstream_model(
Arc::clone(&state.router_service),
client_request,
&traceparent,
&request_path,
&request_id,
inline_routing_policy,
)
.await
}
.instrument(routing_span)
.await
}
.instrument(routing_span)
.await
{
Ok(result) => result,
Err(err) => {
let mut internal_error = Response::new(full(err.message));
*internal_error.status_mut() = err.status_code;
return Ok(internal_error);
{
Ok(result) => result,
Err(err) => {
let mut internal_error = Response::new(full(err.message));
*internal_error.status_mut() = err.status_code;
return Ok(internal_error);
}
};

let (router_selected_model, route_name) =
(routing_result.model_name, routing_result.route_name);
let model = if router_selected_model != "none" {
router_selected_model
} else {
alias_resolved_model.clone()
};

// Cache the routing decision so subsequent requests with the same session ID are pinned
if let Some(ref sid) = session_id {
state
.router_service
.cache_route(sid.clone(), model.clone(), route_name)
.await;
}
};

// Determine final model (router returns "none" when it doesn't select a specific model)
let router_selected_model = routing_result.model_name;
let resolved_model = if router_selected_model != "none" {
router_selected_model
} else {
alias_resolved_model.clone()
model
};
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());

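The handler above only shows the call sites (get_cached_route, cache_route); the RouterService cache itself is not part of this diff. A minimal sketch of what a session-pinned route cache with a TTL and an entry cap could look like, assuming a tokio RwLock over a HashMap keyed by session ID — the SessionRouteCache/CachedRoute names and the capacity policy are illustrative, not the crate's actual types:

    use std::collections::HashMap;
    use std::time::{Duration, Instant};
    use tokio::sync::RwLock;

    #[derive(Clone)]
    pub struct CachedRoute {
        pub model_name: String,
        pub route_name: Option<String>,
        inserted_at: Instant,
    }

    pub struct SessionRouteCache {
        entries: RwLock<HashMap<String, CachedRoute>>,
        ttl: Duration,
        max_entries: usize,
    }

    impl SessionRouteCache {
        pub fn new(ttl_seconds: u64, max_entries: usize) -> Self {
            Self {
                entries: RwLock::new(HashMap::new()),
                ttl: Duration::from_secs(ttl_seconds),
                max_entries,
            }
        }

        // Returns the pinned decision for a session, treating expired entries as misses.
        pub async fn get_cached_route(&self, session_id: &str) -> Option<CachedRoute> {
            let entries = self.entries.read().await;
            entries
                .get(session_id)
                .filter(|c| c.inserted_at.elapsed() < self.ttl)
                .cloned()
        }

        // Stores (or refreshes) a routing decision for a session.
        pub async fn cache_route(&self, session_id: String, model_name: String, route_name: Option<String>) {
            let mut entries = self.entries.write().await;
            // Simplest possible capacity policy: refuse new sessions once the cap is hit.
            if entries.len() >= self.max_entries && !entries.contains_key(&session_id) {
                return;
            }
            entries.insert(
                session_id,
                CachedRoute { model_name, route_name, inserted_at: Instant::now() },
            );
        }

        // Called from the periodic task spawned in main.rs (further down in this diff).
        pub async fn cleanup_expired_sessions(&self) {
            let mut entries = self.entries.write().await;
            entries.retain(|_, c| c.inserted_at.elapsed() < self.ttl);
        }
    }

With the default 600-second TTL and the 5-minute sweep spawned in main.rs, an expired entry lingers at most about five minutes before removal, though in this sketch lookups already treat it as a miss.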
63 changes: 61 additions & 2 deletions crates/brightstaff/src/handlers/routing_service.rs
@@ -1,6 +1,6 @@
use bytes::Bytes;
use common::configuration::{ModelUsagePreference, SpanAttributes};
use common::consts::REQUEST_ID_HEADER;
use common::consts::{REQUEST_ID_HEADER, SESSION_ID_HEADER};
use common::errors::BrightStaffError;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::ProviderRequestType;
@@ -67,6 +67,9 @@ struct RoutingDecisionResponse {
model: String,
route: Option<String>,
trace_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
session_id: Option<String>,
pinned: bool,
}

pub async fn routing_decision(
@@ -82,6 +85,11 @@
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

let session_id: Option<String> = request_headers
.get(SESSION_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string());

let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref());

let request_span = info_span!(
@@ -99,6 +107,7 @@
request_path,
request_headers,
custom_attrs,
session_id,
)
.instrument(request_span)
.await
@@ -111,6 +120,7 @@
request_path: String,
request_headers: hyper::HeaderMap,
custom_attrs: std::collections::HashMap<String, String>,
session_id: Option<String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
set_service_name(operation_component::ROUTING);
opentelemetry::trace::get_active_span(|span| {
@@ -128,6 +138,34 @@
.unwrap_or("unknown")
.to_string();

// Session pinning: check cache before doing any routing work
if let Some(ref sid) = session_id {
if let Some(cached) = router_service.get_cached_route(sid).await {
info!(
session_id = %sid,
model = %cached.model_name,
route = ?cached.route_name,
"returning pinned routing decision from cache"
);
let response = RoutingDecisionResponse {
model: cached.model_name,
route: cached.route_name,
trace_id,
session_id: Some(sid.clone()),
pinned: true,
};
let json = serde_json::to_string(&response).unwrap();
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap());
}
}

// Parse request body
let raw_bytes = request.collect().await?.to_bytes();

@@ -166,7 +204,7 @@

// Call the existing routing logic with inline preferences
let routing_result = router_chat_get_upstream_model(
router_service,
Arc::clone(&router_service),
client_request,
&traceparent,
&request_path,
@@ -177,10 +215,23 @@

match routing_result {
Ok(result) => {
// Cache the result if session_id is present
if let Some(ref sid) = session_id {
router_service
.cache_route(
sid.clone(),
result.model_name.clone(),
result.route_name.clone(),
)
.await;
}

let response = RoutingDecisionResponse {
model: result.model_name,
route: result.route_name,
trace_id,
session_id,
pinned: false,
};

info!(
@@ -318,12 +369,16 @@ mod tests {
model: "openai/gpt-4o".to_string(),
route: Some("code_generation".to_string()),
trace_id: "abc123".to_string(),
session_id: Some("sess-abc".to_string()),
pinned: true,
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["model"], "openai/gpt-4o");
assert_eq!(parsed["route"], "code_generation");
assert_eq!(parsed["trace_id"], "abc123");
assert_eq!(parsed["session_id"], "sess-abc");
assert_eq!(parsed["pinned"], true);
}

#[test]
@@ -332,10 +387,14 @@
model: "none".to_string(),
route: None,
trace_id: "abc123".to_string(),
session_id: None,
pinned: false,
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["model"], "none");
assert!(parsed["route"].is_null());
assert!(parsed.get("session_id").is_none());
assert_eq!(parsed["pinned"], false);
}
}
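For reference, the serialized form these tests assert against would look roughly like this on the wire (session_id is dropped entirely when no session header was sent, per the skip_serializing_if attribute above):

    {"model":"openai/gpt-4o","route":"code_generation","trace_id":"abc123","session_id":"sess-abc","pinned":true}

and for a fresh, unpinned decision without a session:

    {"model":"none","route":null,"trace_id":"abc123","pinned":false}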
18 changes: 18 additions & 0 deletions crates/brightstaff/src/main.rs
@@ -162,13 +162,31 @@ async fn init_app_state(
.map(|p| p.name.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());

let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);

let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);

let router_service = Arc::new(RouterService::new(
config.model_providers.clone(),
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
routing_model_name,
routing_llm_provider,
session_ttl_seconds,
session_max_entries,
));

// Spawn background task to clean up expired session cache entries every 5 minutes
{
let router_service = Arc::clone(&router_service);
tokio::spawn(async move {
let mut interval = tokio::time::interval(std::time::Duration::from_secs(300));
loop {
interval.tick().await;
router_service.cleanup_expired_sessions().await;
}
});
}

let orchestrator_model_name: String = overrides
.agent_orchestration_model
.as_deref()
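RouterService::new now receives the two optional session settings, but how it falls back to the schema's documented defaults is not shown in this diff. One plausible wiring, assuming the fields arrive as Option<u64>/Option<usize> and reusing the SessionRouteCache sketched earlier:

    // Illustrative only — the real RouterService constructor is not part of this diff.
    let ttl_seconds = session_ttl_seconds.unwrap_or(600);     // schema default: 10 minutes
    let max_entries = session_max_entries.unwrap_or(10_000);  // schema default: 10000
    let session_cache = SessionRouteCache::new(ttl_seconds, max_entries);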