diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs
index b9fe0ba37..855301d50 100644
--- a/crates/brightstaff/src/handlers/llm.rs
+++ b/crates/brightstaff/src/handlers/llm.rs
@@ -90,7 +90,7 @@ async fn llm_chat_inner(
     state_storage: Option>,
     request_id: String,
     request_path: String,
-    mut request_headers: hyper::HeaderMap,
+    request_headers: hyper::HeaderMap,
 ) -> Result>, hyper::Error> {
     // Set service name for LLM operations
     set_service_name(operation_component::LLM);
@@ -274,9 +274,6 @@ async fn llm_chat_inner(
         }
     }
 
-    // Serialize request for upstream BEFORE router consumes it
-    let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
-
     // Determine routing using the dedicated router_chat module
     // This gets its own span for latency and error tracking
     let routing_span = info_span!(
@@ -293,7 +290,7 @@ async fn llm_chat_inner(
         set_service_name(operation_component::ROUTING);
         router_chat_get_upstream_model(
             router_service,
-            client_request, // Pass the original request - router_chat will convert it
+            client_request.clone(), // Clone here to preserve for retries
            &traceparent,
            &request_path,
            &request_id,
@@ -334,49 +331,93 @@ async fn llm_chat_inner(
         span.update_name(span_name.clone());
     });
 
-    debug!(
-        url = %full_qualified_llm_provider_url,
-        provider_hint = %resolved_model,
-        upstream_model = %model_name_only,
-        "Routing to upstream"
-    );
+    // Capture start time right before sending request to upstream
+    let request_start_time = std::time::Instant::now();
+    let _request_start_system_time = std::time::SystemTime::now();
 
-    request_headers.insert(
-        ARCH_PROVIDER_HINT_HEADER,
-        header::HeaderValue::from_str(&resolved_model).unwrap(),
-    );
+    let mut current_resolved_model = resolved_model.clone();
+    let mut current_client_request = client_request;
+    let mut attempts = 0;
+    let max_attempts = 2; // Original + 1 retry
+
+    let llm_response = loop {
+        attempts += 1;
+
+        // Handle provider/model slug format (e.g., "openai/gpt-4")
+        // Extract just the model name for upstream (providers don't understand the slug)
+        let current_model_name_only = if let Some((_, model)) = current_resolved_model.split_once('/') {
+            model.to_string()
+        } else {
+            current_resolved_model.clone()
+        };
+
+        debug!(
+            url = %full_qualified_llm_provider_url,
+            provider_hint = %current_resolved_model,
+            upstream_model = %current_model_name_only,
+            attempt = attempts,
+            "Routing to upstream"
+        );
 
-    request_headers.insert(
-        header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
-        header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
-    );
-    // remove content-length header if it exists
-    request_headers.remove(header::CONTENT_LENGTH);
+        // Set the model to just the model name (without provider prefix)
+        current_client_request.set_model(current_model_name_only.clone());
 
-    // Inject current LLM span's trace context so upstream spans are children of plano(llm)
-    global::get_text_map_propagator(|propagator| {
-        let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
-        propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
-    });
+        // Serialize request for upstream
+        let current_request_bytes = ProviderRequestType::to_bytes(&current_client_request).unwrap();
 
-    // Capture start time right before sending request to upstream
-    let request_start_time = std::time::Instant::now();
-    let _request_start_system_time = std::time::SystemTime::now();
+        let mut current_request_headers = request_headers.clone();
+        current_request_headers.insert(
+            ARCH_PROVIDER_HINT_HEADER,
+            header::HeaderValue::from_str(&current_resolved_model).unwrap(),
+        );
 
-    let llm_response = match reqwest::Client::new()
-        .post(&full_qualified_llm_provider_url)
-        .headers(request_headers)
-        .body(client_request_bytes_for_upstream)
-        .send()
-        .await
-    {
-        Ok(res) => res,
-        Err(err) => {
-            let err_msg = format!("Failed to send request: {}", err);
-            let mut internal_error = Response::new(full(err_msg));
-            *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
-            return Ok(internal_error);
+        current_request_headers.insert(
+            header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
+            header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
+        );
+        // remove content-length header if it exists
+        current_request_headers.remove(header::CONTENT_LENGTH);
+
+        // Inject current LLM span's trace context so upstream spans are children of plano(llm)
+        global::get_text_map_propagator(|propagator| {
+            let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
+            propagator.inject_context(&cx, &mut HeaderInjector(&mut current_request_headers));
+        });
+
+        let res = match reqwest::Client::new()
+            .post(&full_qualified_llm_provider_url)
+            .headers(current_request_headers)
+            .body(current_request_bytes)
+            .send()
+            .await
+        {
+            Ok(res) => res,
+            Err(err) => {
+                let err_msg = format!("Failed to send request: {}", err);
+                let mut internal_error = Response::new(full(err_msg));
+                *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
+                return Ok(internal_error);
+            }
+        };
+
+        if res.status() == StatusCode::TOO_MANY_REQUESTS && attempts < max_attempts {
+            let providers = llm_providers.read().await;
+            if let Some(provider) = providers.get(&current_resolved_model) {
+                if provider.retry_on_ratelimit == Some(true) {
+                    if let Some(alt_provider) = providers.get_alternative(&current_resolved_model) {
+                        info!(
+                            request_id = %request_id,
+                            current_model = %current_resolved_model,
+                            alt_model = %alt_provider.name,
+                            "429 received, retrying with alternative model"
+                        );
+                        current_resolved_model = alt_provider.name.clone();
+                        continue;
+                    }
+                }
+            }
         }
+        break res;
     };
 
     // copy over the headers and status code from the original response
@@ -391,6 +432,7 @@ async fn llm_chat_inner(
     // Build LLM span with actual status code using constants
     let byte_stream = llm_response.bytes_stream();
 
+    // Create base processor for metrics and tracing
     let base_processor = ObservableStreamProcessor::new(
         operation_component::LLM,
@@ -441,6 +483,82 @@ async fn llm_chat_inner(
     }
 }
 
+#[cfg(test)]
+mod tests {
+    use common::configuration::{LlmProvider, LlmProviderType};
+    use common::llm_providers::LlmProviders;
+
+    // We can't easily create Request in tests without a full server setup.
+    // So we'll skip the functional test of llm_chat and rely on unit tests of get_alternative.
+
+    #[tokio::test]
+    async fn test_llm_providers_get_alternative() {
+        let primary = LlmProvider {
+            name: "primary".to_string(),
+            provider_interface: LlmProviderType::OpenAI,
+            model: Some("gpt-4".to_string()),
+            default: Some(false),
+            ..Default::default()
+        };
+
+        let secondary = LlmProvider {
+            name: "secondary".to_string(),
+            provider_interface: LlmProviderType::OpenAI,
+            model: Some("gpt-4-alt".to_string()),
+            default: Some(true),
+            ..Default::default()
+        };
+
+        let providers_vec = vec![primary.clone(), secondary.clone()];
+        let llm_providers = LlmProviders::try_from(providers_vec).unwrap();
+
+        let alt = llm_providers.get_alternative("primary");
+        assert!(alt.is_some());
+        assert_eq!(alt.unwrap().name, "secondary");
+
+        let alt_none = llm_providers.get_alternative("secondary");
+        assert!(alt_none.is_some());
+        assert_eq!(alt_none.unwrap().name, "primary");
+    }
+
+    #[tokio::test]
+    async fn test_llm_providers_get_alternative_internal_skipped() {
+        let primary = LlmProvider {
+            name: "primary".to_string(),
+            provider_interface: LlmProviderType::OpenAI,
+            model: Some("gpt-4".to_string()),
+            default: Some(true),
+            ..Default::default()
+        };
+
+        let internal = LlmProvider {
+            name: "internal".to_string(),
+            provider_interface: LlmProviderType::Arch,
+            model: Some("router".to_string()),
+            internal: Some(true),
+            default: Some(false),
+            ..Default::default()
+        };
+
+        let secondary = LlmProvider {
+            name: "secondary".to_string(),
+            provider_interface: LlmProviderType::OpenAI,
+            model: Some("gpt-4-alt".to_string()),
+            default: Some(false),
+            ..Default::default()
+        };
+
+        let providers_vec = vec![primary, internal, secondary];
+        let llm_providers = LlmProviders::try_from(providers_vec).unwrap();
+
+        let alt = llm_providers.get_alternative("primary");
+        assert!(alt.is_some());
+        assert_eq!(alt.unwrap().name, "secondary");
+    }
+}
+
+
 /// Resolves model aliases by looking up the requested model in the model_aliases map.
 /// Returns the target model if an alias is found, otherwise returns the original model.
 fn resolve_model_alias(
diff --git a/crates/build.sh b/crates/build.sh
old mode 100644
new mode 100755
diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs
index ccca89c38..5989fcbb3 100644
--- a/crates/common/src/configuration.rs
+++ b/crates/common/src/configuration.rs
@@ -328,6 +328,7 @@ pub struct LlmProvider {
     pub base_url_path_prefix: Option<String>,
     pub internal: Option<bool>,
     pub passthrough_auth: Option<bool>,
+    pub retry_on_ratelimit: Option<bool>,
 }
 
 pub trait IntoModels {
@@ -372,6 +373,7 @@ impl Default for LlmProvider {
             base_url_path_prefix: None,
             internal: None,
             passthrough_auth: None,
+            retry_on_ratelimit: None,
         }
     }
 }
diff --git a/crates/common/src/llm_providers.rs b/crates/common/src/llm_providers.rs
index 3c9d1d68d..49857fbe8 100644
--- a/crates/common/src/llm_providers.rs
+++ b/crates/common/src/llm_providers.rs
@@ -80,6 +80,29 @@ impl LlmProviders {
 
         None
     }
+
+    /// Get an alternative provider that is not the one specified by current_name.
+    /// Prefers the default provider if it's different, otherwise picks the first non-internal provider.
+    pub fn get_alternative(&self, current_name: &str) -> Option<Arc<LlmProvider>> {
+        // Try to find a default provider that is not the current one
+        if let Some(default_provider) = &self.default {
+            if default_provider.name != current_name {
+                return Some(Arc::clone(default_provider));
+            }
+        }
+
+        // Otherwise just pick the first canonical non-internal provider that is not the current one
+        self.providers.iter().find_map(|(key, provider)| {
+            if provider.internal != Some(true)
+                && provider.name != current_name
+                && key == &provider.name
+            {
+                Some(Arc::clone(provider))
+            } else {
+                None
+            }
+        })
+    }
 }
 
 #[derive(thiserror::Error, Debug)]
@@ -278,6 +301,7 @@ mod tests {
             internal: None,
             stream: None,
             passthrough_auth: None,
+            retry_on_ratelimit: None,
         }
     }
 
@@ -334,4 +358,56 @@ mod tests {
             .wildcard_providers
             .contains_key("custom-provider"));
     }
+
+    #[test]
+    fn test_get_alternative_prefers_default() {
+        let primary = create_test_provider("primary", Some("gpt-4".to_string()));
+        let mut secondary = create_test_provider("secondary", Some("gpt-4-alt".to_string()));
+        secondary.default = Some(true);
+        let tertiary = create_test_provider("tertiary", Some("gpt-4-other".to_string()));
+
+        let providers = vec![primary, secondary, tertiary];
+        let llm_providers = LlmProviders::try_from(providers).unwrap();
+
+        // If we are at primary, should return secondary (default)
+        let alt = llm_providers.get_alternative("primary");
+        assert_eq!(alt.unwrap().name, "secondary");
+
+        // If we are at tertiary, should return secondary (default)
+        let alt = llm_providers.get_alternative("tertiary");
+        assert_eq!(alt.unwrap().name, "secondary");
+
+        // If we are at secondary (the default), should return something else (primary or tertiary)
+        let alt = llm_providers.get_alternative("secondary");
+        let alt_name = alt.unwrap().name.clone();
+        assert!(alt_name == "primary" || alt_name == "tertiary");
+    }
+
+    #[test]
+    fn test_get_alternative_skips_internal() {
+        let primary = create_test_provider("primary", Some("gpt-4".to_string()));
+        let mut internal = create_test_provider("internal", Some("router".to_string()));
+        internal.internal = Some(true);
+        let secondary = create_test_provider("secondary", Some("gpt-4-alt".to_string()));
+
+        let providers = vec![primary, internal, secondary];
+        let llm_providers = LlmProviders::try_from(providers).unwrap();
+
+        // Should return secondary, NOT internal
+        let alt = llm_providers.get_alternative("primary");
+        assert_eq!(alt.unwrap().name, "secondary");
+    }
+
+    #[test]
+    fn test_get_alternative_returns_none_if_no_other_available() {
+        let primary = create_test_provider("primary", Some("gpt-4".to_string()));
+        let mut internal = create_test_provider("internal", Some("router".to_string()));
+        internal.internal = Some(true);
+
+        let providers = vec![primary, internal];
+        let llm_providers = LlmProviders::try_from(providers).unwrap();
+
+        let alt = llm_providers.get_alternative("primary");
+        assert!(alt.is_none());
+    }
 }
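
Note on the retry loop in llm_chat_inner: the fallback is strictly opt-in and bounded to a single retry. Distilled to just the decision, the control flow reads as the sketch below. The helper `next_model_on_429` is hypothetical, written for illustration only and not part of the diff; `provider` stands for the entry that answered 429, and `providers` for the guard taken via `llm_providers.read().await`.

    /// Sketch: pick the model to retry with after a 429, or None to pass
    /// the 429 through to the caller. Uses only types from this diff.
    fn next_model_on_429(
        provider: &LlmProvider,
        providers: &LlmProviders,
        attempts: u32,
        max_attempts: u32,
    ) -> Option<String> {
        // Only one retry (max_attempts = 2), and only if the provider opted in.
        if attempts >= max_attempts || provider.retry_on_ratelimit != Some(true) {
            return None;
        }
        // Prefer the default provider, else the first non-internal candidate.
        providers
            .get_alternative(&provider.name)
            .map(|alt| alt.name.clone())
    }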
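
And a minimal end-to-end sketch of get_alternative together with the new config flag. The provider names here ("primary", "fallback") are illustrative; LlmProviders::try_from, the Default impl, and the field names all come from the hunks above.

    use common::configuration::{LlmProvider, LlmProviderType};
    use common::llm_providers::LlmProviders;

    fn main() {
        // The primary provider opts in to the 429 fallback via the new flag.
        let primary = LlmProvider {
            name: "primary".to_string(),
            provider_interface: LlmProviderType::OpenAI,
            model: Some("gpt-4".to_string()),
            retry_on_ratelimit: Some(true),
            ..Default::default()
        };
        // The default provider is the preferred fallback target.
        let fallback = LlmProvider {
            name: "fallback".to_string(),
            provider_interface: LlmProviderType::OpenAI,
            model: Some("gpt-4-alt".to_string()),
            default: Some(true),
            ..Default::default()
        };

        let providers = LlmProviders::try_from(vec![primary, fallback]).unwrap();

        // On a 429 from "primary", the handler asks for an alternative;
        // the default provider wins whenever it is not the one that failed.
        let alt = providers.get_alternative("primary").unwrap();
        assert_eq!(alt.name, "fallback");
    }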