202 changes: 160 additions & 42 deletions crates/brightstaff/src/handlers/llm.rs
@@ -90,7 +90,7 @@ async fn llm_chat_inner(
state_storage: Option<Arc<dyn StateStorage>>,
request_id: String,
request_path: String,
mut request_headers: hyper::HeaderMap,
request_headers: hyper::HeaderMap,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations
set_service_name(operation_component::LLM);
@@ -274,9 +274,6 @@ async fn llm_chat_inner(
}
}

// Serialize request for upstream BEFORE router consumes it
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();

// Determine routing using the dedicated router_chat module
// This gets its own span for latency and error tracking
let routing_span = info_span!(
@@ -293,7 +290,7 @@ async fn llm_chat_inner(
set_service_name(operation_component::ROUTING);
router_chat_get_upstream_model(
router_service,
client_request, // Pass the original request - router_chat will convert it
client_request.clone(), // Clone here to preserve for retries
&traceparent,
&request_path,
&request_id,
@@ -334,49 +331,93 @@ async fn llm_chat_inner(
span.update_name(span_name.clone());
});

debug!(
url = %full_qualified_llm_provider_url,
provider_hint = %resolved_model,
upstream_model = %model_name_only,
"Routing to upstream"
);
// Capture start time right before sending request to upstream
let request_start_time = std::time::Instant::now();
let _request_start_system_time = std::time::SystemTime::now();

request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&resolved_model).unwrap(),
);
let mut current_resolved_model = resolved_model.clone();
let mut current_client_request = client_request;
let mut attempts = 0;
let max_attempts = 2; // Original + 1 retry

let llm_response = loop {
attempts += 1;

// Handle provider/model slug format (e.g., "openai/gpt-4")
// Extract just the model name for upstream (providers don't understand the slug)
let current_model_name_only = if let Some((_, model)) = current_resolved_model.split_once('/') {
model.to_string()
} else {
current_resolved_model.clone()
};

debug!(
url = %full_qualified_llm_provider_url,
provider_hint = %current_resolved_model,
upstream_model = %current_model_name_only,
attempt = attempts,
"Routing to upstream"
);

request_headers.insert(
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
);
// remove content-length header if it exists
request_headers.remove(header::CONTENT_LENGTH);
// Set the model to just the model name (without provider prefix)
current_client_request.set_model(current_model_name_only.clone());

// Inject current LLM span's trace context so upstream spans are children of plano(llm)
global::get_text_map_propagator(|propagator| {
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
});
// Serialize request for upstream
let current_request_bytes = ProviderRequestType::to_bytes(&current_client_request).unwrap();

// Capture start time right before sending request to upstream
let request_start_time = std::time::Instant::now();
let _request_start_system_time = std::time::SystemTime::now();
let mut current_request_headers = request_headers.clone();
current_request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&current_resolved_model).unwrap(),
);

let llm_response = match reqwest::Client::new()
.post(&full_qualified_llm_provider_url)
.headers(request_headers)
.body(client_request_bytes_for_upstream)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
current_request_headers.insert(
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
);
// remove content-length header if it exists
current_request_headers.remove(header::CONTENT_LENGTH);

// Inject current LLM span's trace context so upstream spans are children of plano(llm)
global::get_text_map_propagator(|propagator| {
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut current_request_headers));
});

let res = match reqwest::Client::new()
.post(&full_qualified_llm_provider_url)
.headers(current_request_headers)
.body(current_request_bytes)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};

if res.status() == StatusCode::TOO_MANY_REQUESTS && attempts < max_attempts {
let providers = llm_providers.read().await;
if let Some(provider) = providers.get(&current_resolved_model) {
if provider.retry_on_ratelimit == Some(true) {
if let Some(alt_provider) = providers.get_alternative(&current_resolved_model) {
info!(
request_id = %request_id,
current_model = %current_resolved_model,
alt_model = %alt_provider.name,
"429 received, retrying with alternative model"
);
current_resolved_model = alt_provider.name.clone();
continue;
}
}
}
}
break res;
};

// copy over the headers and status code from the original response
@@ -391,6 +432,7 @@ async fn llm_chat_inner(
// Build LLM span with actual status code using constants
let byte_stream = llm_response.bytes_stream();


// Create base processor for metrics and tracing
let base_processor = ObservableStreamProcessor::new(
operation_component::LLM,
@@ -441,6 +483,82 @@ async fn llm_chat_inner(
}
}

#[cfg(test)]
mod tests {
use common::configuration::{LlmProvider, LlmProviderType};
use common::llm_providers::LlmProviders;

// We can't easily create Request<Incoming> in tests without a full server setup.
// So we'll skip the functional test of llm_chat and rely on unit tests of get_alternative.

#[tokio::test]
async fn test_llm_providers_get_alternative() {
let primary = LlmProvider {
name: "primary".to_string(),
provider_interface: LlmProviderType::OpenAI,
model: Some("gpt-4".to_string()),
default: Some(false),
..Default::default()
};

let secondary = LlmProvider {
name: "secondary".to_string(),
provider_interface: LlmProviderType::OpenAI,
model: Some("gpt-4-alt".to_string()),
default: Some(true),
..Default::default()
};

let providers_vec = vec![primary.clone(), secondary.clone()];
let llm_providers = LlmProviders::try_from(providers_vec).unwrap();

let alt = llm_providers.get_alternative("primary");
assert!(alt.is_some());
assert_eq!(alt.unwrap().name, "secondary");

let alt_none = llm_providers.get_alternative("secondary");
assert!(alt_none.is_some());
assert_eq!(alt_none.unwrap().name, "primary");
}

#[tokio::test]
async fn test_llm_providers_get_alternative_internal_skipped() {
let primary = LlmProvider {
name: "primary".to_string(),
provider_interface: LlmProviderType::OpenAI,
model: Some("gpt-4".to_string()),
default: Some(true),
..Default::default()
};

let internal = LlmProvider {
name: "internal".to_string(),
provider_interface: LlmProviderType::Arch,
model: Some("router".to_string()),
internal: Some(true),
default: Some(false),
..Default::default()
};

let secondary = LlmProvider {
name: "secondary".to_string(),
provider_interface: LlmProviderType::OpenAI,
model: Some("gpt-4-alt".to_string()),
default: Some(false),
..Default::default()
};

let providers_vec = vec![primary, internal, secondary];
let llm_providers = LlmProviders::try_from(providers_vec).unwrap();

let alt = llm_providers.get_alternative("primary");
assert!(alt.is_some());
assert_eq!(alt.unwrap().name, "secondary");
}
}



/// Resolves model aliases by looking up the requested model in the model_aliases map.
/// Returns the target model if an alias is found, otherwise returns the original model.
fn resolve_model_alias(
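The retry path in handlers/llm.rs above is interleaved with header and tracing setup, so the control flow it adds can be hard to follow from the diff alone. Below is a minimal, self-contained sketch of just that flow; Provider, Providers, send_to_provider, and send_with_retry are hypothetical stand-ins for LlmProvider, LlmProviders, the reqwest call, and the handler loop, so this illustrates the intended behavior rather than the handler's actual API.

#[derive(Clone)]
struct Provider {
    name: String,
    retry_on_ratelimit: bool,
}

struct Providers(Vec<Provider>);

impl Providers {
    fn get(&self, name: &str) -> Option<&Provider> {
        self.0.iter().find(|p| p.name == name)
    }

    // Stand-in for LlmProviders::get_alternative: any provider other than the current one.
    fn get_alternative(&self, current: &str) -> Option<&Provider> {
        self.0.iter().find(|p| p.name != current)
    }
}

// Stand-in for the upstream POST; returns an HTTP status code.
fn send_to_provider(provider: &Provider) -> u16 {
    if provider.name == "primary" { 429 } else { 200 }
}

fn send_with_retry(providers: &Providers, initial: &str) -> u16 {
    let mut current = initial.to_string();
    let mut attempts = 0;
    let max_attempts = 2; // original request + one retry, as in the handler

    loop {
        attempts += 1;
        let provider = providers.get(&current).expect("provider exists");
        let status = send_to_provider(provider);

        // Retry only on 429, only once, and only if the current provider opted in.
        if status == 429 && attempts < max_attempts && provider.retry_on_ratelimit {
            if let Some(alt) = providers.get_alternative(&current) {
                current = alt.name.clone();
                continue;
            }
        }
        return status;
    }
}

fn main() {
    let providers = Providers(vec![
        Provider { name: "primary".into(), retry_on_ratelimit: true },
        Provider { name: "secondary".into(), retry_on_ratelimit: false },
    ]);
    // The first attempt hits the rate-limited primary; the retry succeeds on secondary.
    assert_eq!(send_with_retry(&providers, "primary"), 200);
}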
Empty file modified crates/build.sh
100644 → 100755
2 changes: 2 additions & 0 deletions crates/common/src/configuration.rs
@@ -328,6 +328,7 @@ pub struct LlmProvider {
pub base_url_path_prefix: Option<String>,
pub internal: Option<bool>,
pub passthrough_auth: Option<bool>,
pub retry_on_ratelimit: Option<bool>,
}

pub trait IntoModels {
@@ -372,6 +373,7 @@ impl Default for LlmProvider {
base_url_path_prefix: None,
internal: None,
passthrough_auth: None,
retry_on_ratelimit: None,
}
}
}
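The new retry_on_ratelimit field is opt-in per provider and defaults to None (off). A minimal sketch of constructing a provider with it enabled follows; the name and model values are illustrative, not a recommended configuration.

use common::configuration::{LlmProvider, LlmProviderType};

fn example_provider() -> LlmProvider {
    LlmProvider {
        name: "openai-primary".to_string(),
        provider_interface: LlmProviderType::OpenAI,
        model: Some("gpt-4".to_string()),
        // Opt this provider in to the 429 fallback added in handlers/llm.rs.
        retry_on_ratelimit: Some(true),
        ..Default::default()
    }
}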
76 changes: 76 additions & 0 deletions crates/common/src/llm_providers.rs
@@ -80,6 +80,29 @@ impl LlmProviders {

None
}

/// Get an alternative provider that is not the one specified by current_name.
/// Prefers the default provider if it's different, otherwise picks the first non-internal provider.
pub fn get_alternative(&self, current_name: &str) -> Option<Arc<LlmProvider>> {
// Try to find a default provider that is not the current one
if let Some(default_provider) = &self.default {
if default_provider.name != current_name {
return Some(Arc::clone(default_provider));
}
}

// Otherwise just pick the first canonical non-internal provider that is not the current one
self.providers.iter().find_map(|(key, provider)| {
if provider.internal != Some(true)
&& provider.name != current_name
&& key == &provider.name
{
Some(Arc::clone(provider))
} else {
None
}
})
}
}

#[derive(thiserror::Error, Debug)]
@@ -278,6 +301,7 @@ mod tests {
internal: None,
stream: None,
passthrough_auth: None,
retry_on_ratelimit: None,
}
}

@@ -334,4 +358,56 @@ mod tests {
.wildcard_providers
.contains_key("custom-provider"));
}

#[test]
fn test_get_alternative_prefers_default() {
let primary = create_test_provider("primary", Some("gpt-4".to_string()));
let mut secondary = create_test_provider("secondary", Some("gpt-4-alt".to_string()));
secondary.default = Some(true);
let tertiary = create_test_provider("tertiary", Some("gpt-4-other".to_string()));

let providers = vec![primary, secondary, tertiary];
let llm_providers = LlmProviders::try_from(providers).unwrap();

// If we are at primary, should return secondary (default)
let alt = llm_providers.get_alternative("primary");
assert_eq!(alt.unwrap().name, "secondary");

// If we are at tertiary, should return secondary (default)
let alt = llm_providers.get_alternative("tertiary");
assert_eq!(alt.unwrap().name, "secondary");

// If we are at secondary (the default), should return something else (primary or tertiary)
let alt = llm_providers.get_alternative("secondary");
let alt_name = alt.unwrap().name.clone();
assert!(alt_name == "primary" || alt_name == "tertiary");
}

#[test]
fn test_get_alternative_skips_internal() {
let primary = create_test_provider("primary", Some("gpt-4".to_string()));
let mut internal = create_test_provider("internal", Some("router".to_string()));
internal.internal = Some(true);
let secondary = create_test_provider("secondary", Some("gpt-4-alt".to_string()));

let providers = vec![primary, internal, secondary];
let llm_providers = LlmProviders::try_from(providers).unwrap();

// Should return secondary, NOT internal
let alt = llm_providers.get_alternative("primary");
assert_eq!(alt.unwrap().name, "secondary");
}

#[test]
fn test_get_alternative_returns_none_if_no_other_available() {
let primary = create_test_provider("primary", Some("gpt-4".to_string()));
let mut internal = create_test_provider("internal", Some("router".to_string()));
internal.internal = Some(true);

let providers = vec![primary, internal];
let llm_providers = LlmProviders::try_from(providers).unwrap();

let alt = llm_providers.get_alternative("primary");
assert!(alt.is_none());
}
}