diff --git a/crates/braintrust-llm-router/src/providers/bedrock.rs b/crates/braintrust-llm-router/src/providers/bedrock.rs
index 862b9386..34327444 100644
--- a/crates/braintrust-llm-router/src/providers/bedrock.rs
+++ b/crates/braintrust-llm-router/src/providers/bedrock.rs
@@ -11,6 +11,7 @@ use bytes::Bytes;
 use http::Request as HttpRequest;
 use lingua::serde_json::Value;
 use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};
+use reqwest::{Client, StatusCode, Url};
 
 use crate::auth::AuthConfig;
@@ -18,7 +19,9 @@ use crate::catalog::ModelSpec;
 use crate::client::{default_client, ClientSettings};
 use crate::error::{Error, Result, UpstreamHttpError};
 use crate::providers::ClientHeaders;
-use crate::streaming::{bedrock_event_stream, single_bytes_stream, RawResponseStream};
+use crate::streaming::{
+    bedrock_event_stream, bedrock_messages_event_stream, single_bytes_stream, RawResponseStream,
+};
 use lingua::ProviderFormat;
 
 #[derive(Debug, Clone)]
@@ -91,7 +94,13 @@ impl BedrockProvider {
     }
 
     fn invoke_url(&self, model: &str, stream: bool) -> Result<Url> {
-        let path = if stream {
+        let path = if lingua::is_bedrock_anthropic_model(Some(model)) {
+            if stream {
+                format!("model/{model}/invoke-with-response-stream")
+            } else {
+                format!("model/{model}/invoke")
+            }
+        } else if stream {
             format!("model/{model}/converse-stream")
         } else {
             format!("model/{model}/converse")
@@ -267,7 +276,6 @@ impl crate::providers::Provider for BedrockProvider {
             return Ok(single_bytes_stream(response));
         }
 
-        // Router should have already added stream options to payload
         let url = self.invoke_url(&spec.model, true)?;
 
         #[cfg(feature = "tracing")]
@@ -317,7 +325,11 @@ impl crate::providers::Provider for BedrockProvider {
             });
         }
 
-        Ok(bedrock_event_stream(response))
+        if lingua::is_bedrock_anthropic_model(Some(&spec.model)) {
+            Ok(bedrock_messages_event_stream(response))
+        } else {
+            Ok(bedrock_event_stream(response))
+        }
     }
 
     async fn health_check(&self, auth: &AuthConfig) -> Result<()> {
@@ -358,3 +370,50 @@ fn extract_retry_after(status: StatusCode, _body: &str) -> Option<Duration> {
         None
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn invoke_url_converse_mode() {
+        let provider = BedrockProvider::new(BedrockConfig::default()).unwrap();
+
+        let url = provider
+            .invoke_url("amazon.nova-micro-v1:0", false)
+            .unwrap();
+        assert!(url
+            .as_str()
+            .contains("/model/amazon.nova-micro-v1:0/converse"));
+
+        let url = provider
+            .invoke_url("amazon.nova-micro-v1:0", true)
+            .unwrap();
+        assert!(url
+            .as_str()
+            .contains("/model/amazon.nova-micro-v1:0/converse-stream"));
+    }
+
+    #[test]
+    fn invoke_url_anthropic_messages_mode() {
+        let provider = BedrockProvider::new(BedrockConfig::default()).unwrap();
+
+        let url = provider
+            .invoke_url("anthropic.claude-3-5-sonnet-20241022-v2:0", false)
+            .unwrap();
+        assert!(url
+            .as_str()
+            .contains("/model/anthropic.claude-3-5-sonnet-20241022-v2:0/invoke"));
+        assert!(
+            !url.as_str().contains("converse"),
+            "Anthropic messages mode should not use converse endpoints"
+        );
+
+        let url = provider
+            .invoke_url("anthropic.claude-3-5-sonnet-20241022-v2:0", true)
+            .unwrap();
+        assert!(url.as_str().contains(
+            "/model/anthropic.claude-3-5-sonnet-20241022-v2:0/invoke-with-response-stream"
+        ));
+    }
+}
diff --git a/crates/braintrust-llm-router/src/router.rs b/crates/braintrust-llm-router/src/router.rs
index 5fca83ea..181f4117 100644
--- a/crates/braintrust-llm-router/src/router.rs
+++ b/crates/braintrust-llm-router/src/router.rs
@@ -149,8 +149,7 @@ impl Router {
         client_headers: &ClientHeaders,
     ) -> Result<Bytes> {
         let (provider, auth, spec, strategy) = self.resolve_provider(model)?;
-        let payload = match lingua::transform_request(body.clone(), provider.format(), Some(model))
-        {
+        let payload = match lingua::transform_request(body.clone(), spec.format, Some(model)) {
             Ok(TransformResult::PassThrough(bytes)) => bytes,
             Ok(TransformResult::Transformed { bytes, .. }) => bytes,
             Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(),
@@ -205,8 +204,7 @@ impl Router {
         client_headers: &ClientHeaders,
     ) -> Result<RawResponseStream> {
         let (provider, auth, spec, _) = self.resolve_provider(model)?;
-        let payload = match lingua::transform_request(body.clone(), provider.format(), Some(model))
-        {
+        let payload = match lingua::transform_request(body.clone(), spec.format, Some(model)) {
             Ok(TransformResult::PassThrough(bytes)) => bytes,
             Ok(TransformResult::Transformed { bytes, .. }) => bytes,
             Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(),
@@ -222,7 +220,8 @@ impl Router {
 
     pub fn provider_alias(&self, model: &str) -> Result<String> {
         let (_, format, alias) = self.resolver.resolve(model)?;
-        Ok(self.formats.get(&format).cloned().unwrap_or(alias))
+        let alias = self.formats.get(&format).cloned().unwrap_or(alias);
+        Ok(alias)
     }
 
     fn resolve_provider(&self, model: &str) -> Result> {
diff --git a/crates/braintrust-llm-router/src/streaming.rs b/crates/braintrust-llm-router/src/streaming.rs
index 848654f0..17213444 100644
--- a/crates/braintrust-llm-router/src/streaming.rs
+++ b/crates/braintrust-llm-router/src/streaming.rs
@@ -6,6 +6,8 @@ use futures::Stream;
 use reqwest::Response;
 
 use crate::error::{Error, Result};
+#[cfg(feature = "provider-bedrock")]
+use lingua::serde_json::Value;
 use lingua::ProviderFormat;
 use lingua::TransformResult;
@@ -305,6 +307,142 @@ pub fn bedrock_event_stream(response: Response) -> RawResponseStream {
     Box::pin(RawBedrockEventStream::new(response.bytes_stream()))
 }
 
+/// Bedrock Messages API event stream that yields raw Anthropic JSON payloads.
+///
+/// Uses the same AWS binary event stream decoder as the Converse stream but
+/// emits the payload bytes directly without wrapping in `{"eventType": payload}`.
+/// The payloads are already valid Anthropic streaming JSON events.
+#[cfg(feature = "provider-bedrock")]
+struct RawBedrockMessagesEventStream<S>
+where
+    S: Stream<Item = reqwest::Result<Bytes>> + Unpin + Send + 'static,
+{
+    inner: S,
+    buffer: BytesMut,
+    decoder: aws_smithy_eventstream::frame::MessageFrameDecoder,
+    finished: bool,
+}
+
+#[cfg(feature = "provider-bedrock")]
+impl<S> RawBedrockMessagesEventStream<S>
+where
+    S: Stream<Item = reqwest::Result<Bytes>> + Unpin + Send + 'static,
+{
+    fn new(inner: S) -> Self {
+        Self {
+            inner,
+            buffer: BytesMut::new(),
+            decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(),
+            finished: false,
+        }
+    }
+}
+
+#[cfg(feature = "provider-bedrock")]
+impl<S> Stream for RawBedrockMessagesEventStream<S>
+where
+    S: Stream<Item = reqwest::Result<Bytes>> + Unpin + Send + 'static,
+{
+    type Item = Result<Bytes>;
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        use aws_smithy_eventstream::frame::DecodedFrame;
+
+        let this = self.get_mut();
+
+        if this.finished {
+            return Poll::Ready(None);
+        }
+
+        loop {
+            match this.decoder.decode_frame(&mut this.buffer) {
+                Ok(DecodedFrame::Complete(message)) => {
+                    let payload = message.payload();
+                    if payload.is_empty() {
+                        continue;
+                    }
+
+                    // The invoke-with-response-stream payload is
+                    //   {"bytes":"<base64-encoded Anthropic event JSON>"}
+                    // We need to extract and decode the bytes field.
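+                    // For example, a content_block_delta event arrives as
+                    //   {"bytes":"eyJ0eXBlIjoi..."}
+                    // and decodes to the plain Anthropic event:
+                    //   {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}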
+                    let json_bytes = match extract_bedrock_invoke_payload(payload) {
+                        Ok(Some(decoded)) => decoded,
+                        Ok(None) => continue,
+                        Err(e) => return Poll::Ready(Some(Err(e))),
+                    };
+
+                    return Poll::Ready(Some(Ok(json_bytes)));
+                }
+                Ok(DecodedFrame::Incomplete) => {
+                    // Need more data; fall through to poll the inner stream.
+                }
+                Err(e) => {
+                    return Poll::Ready(Some(Err(Error::Provider {
+                        provider: "bedrock".to_string(),
+                        source: anyhow::anyhow!("Event stream decode error: {}", e),
+                        retry_after: None,
+                        http: None,
+                    })));
+                }
+            }
+
+            match Pin::new(&mut this.inner).poll_next(cx) {
+                Poll::Ready(Some(Ok(bytes))) => {
+                    this.buffer.extend_from_slice(&bytes);
+                }
+                Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
+                Poll::Ready(None) => {
+                    this.finished = true;
+                    return Poll::Ready(None);
+                }
+                Poll::Pending => return Poll::Pending,
+            }
+        }
+    }
+}
+
+/// Create a Bedrock Messages API event stream that yields raw Anthropic JSON payloads.
+///
+/// Uses the AWS binary event stream decoder but emits payloads directly
+/// (no `{"eventType": payload}` wrapping). For use with Anthropic models on Bedrock.
+#[cfg(feature = "provider-bedrock")]
+pub fn bedrock_messages_event_stream(response: Response) -> RawResponseStream {
+    Box::pin(RawBedrockMessagesEventStream::new(response.bytes_stream()))
+}
+
+/// Extract the Anthropic JSON payload from a Bedrock invoke-with-response-stream event.
+///
+/// The event payload has the shape `{"bytes":"<base64-encoded Anthropic event JSON>"}`.
+/// Returns `Ok(Some(decoded_bytes))` on success, `Ok(None)` if the payload
+/// should be skipped, or `Err` on decode failure.
+#[cfg(feature = "provider-bedrock")]
+fn extract_bedrock_invoke_payload(raw: &[u8]) -> Result<Option<Bytes>> {
+    use base64::Engine;
+
+    let wrapper: Value = lingua::serde_json::from_slice(raw).map_err(|e| Error::Provider {
+        provider: "bedrock".to_string(),
+        source: anyhow::anyhow!("failed to parse invoke stream event: {}", e),
+        retry_after: None,
+        http: None,
+    })?;
+
+    let b64 = match wrapper.get("bytes").and_then(Value::as_str) {
+        Some(s) => s,
+        None => return Ok(None),
+    };
+
+    let decoded = base64::engine::general_purpose::STANDARD
+        .decode(b64)
+        .map_err(|e| Error::Provider {
+            provider: "bedrock".to_string(),
+            source: anyhow::anyhow!("failed to base64-decode invoke stream event: {}", e),
+            retry_after: None,
+            http: None,
+        })?;
+
+    Ok(Some(Bytes::from(decoded)))
+}
+
 fn split_event(buffer: &BytesMut) -> Option<(Bytes, BytesMut)> {
     // Check for \r\n\r\n first (4-byte CRLF delimiter)
     if let Some(index) = buffer.windows(4).position(|w| w == b"\r\n\r\n") {
@@ -350,4 +488,76 @@ mod tests {
         buffer = rest;
         assert!(!buffer.is_empty());
     }
+
+    #[cfg(feature = "provider-bedrock")]
+    mod bedrock_messages_stream {
+        use super::*;
+
+        #[test]
+        fn extract_bedrock_invoke_payload_decodes_base64() {
+            use base64::Engine;
+
+            let inner_json = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
+            let encoded = base64::engine::general_purpose::STANDARD.encode(inner_json);
+            let wrapper = format!(r#"{{"bytes":"{}"}}"#, encoded);
+
+            let result = extract_bedrock_invoke_payload(wrapper.as_bytes()).unwrap();
+            assert!(result.is_some());
+            let decoded = result.unwrap();
+            assert_eq!(decoded.as_ref(), inner_json.as_bytes());
+        }
+
+        #[test]
+        fn extract_bedrock_invoke_payload_returns_none_without_bytes_field() {
+            let payload = br#"{"other_field": "value"}"#;
+            let result = extract_bedrock_invoke_payload(payload).unwrap();
+            assert!(result.is_none());
+        }
+
+        #[test]
+        fn extract_bedrock_invoke_payload_errors_on_invalid_json() {
+            let payload = b"not json";
+            let result = extract_bedrock_invoke_payload(payload);
+            assert!(result.is_err());
+        }
+
+        #[test]
+        fn extract_bedrock_invoke_payload_errors_on_invalid_base64() {
+            let payload = br#"{"bytes": "!!!not-valid-base64!!!"}"#;
+            let result = extract_bedrock_invoke_payload(payload);
+            assert!(result.is_err());
+        }
+
+        #[test]
+        fn extract_bedrock_invoke_payload_handles_message_start_event() {
+            use base64::Engine;
+
+            let inner_json = r#"{"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","content":[],"model":"claude-3-5-sonnet","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":1}}}"#;
+            let encoded = base64::engine::general_purpose::STANDARD.encode(inner_json);
+            let wrapper = format!(r#"{{"bytes":"{}"}}"#, encoded);
+
+            let result = extract_bedrock_invoke_payload(wrapper.as_bytes()).unwrap();
+            assert!(result.is_some());
+            let decoded = result.unwrap();
+            let decoded_str = std::str::from_utf8(&decoded).unwrap();
+            assert!(decoded_str.contains("message_start"));
+            assert!(decoded_str.contains("msg_123"));
+        }
+
+        #[test]
+        fn extract_bedrock_invoke_payload_handles_message_delta_event() {
+            use base64::Engine;
+
+            let inner_json = r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":15}}"#;
+            let encoded = base64::engine::general_purpose::STANDARD.encode(inner_json);
+            let wrapper = format!(r#"{{"bytes":"{}"}}"#, encoded);
+
+            let result = extract_bedrock_invoke_payload(wrapper.as_bytes()).unwrap();
+            assert!(result.is_some());
+            let decoded = result.unwrap();
+            let decoded_str = std::str::from_utf8(&decoded).unwrap();
+            assert!(decoded_str.contains("message_delta"));
+            assert!(decoded_str.contains("end_turn"));
+        }
+    }
 }
diff --git a/crates/braintrust-llm-router/tests/router.rs b/crates/braintrust-llm-router/tests/router.rs
index 605ec3df..2f87d0e8 100644
--- a/crates/braintrust-llm-router/tests/router.rs
+++ b/crates/braintrust-llm-router/tests/router.rs
@@ -383,3 +383,127 @@ async fn router_retries_and_propagates_terminal_error() {
     assert!(matches!(err, Error::Timeout));
     assert_eq!(attempts.load(Ordering::SeqCst), 3);
 }
+
+#[derive(Clone)]
+struct PayloadCapturingProvider {
+    received: Arc<std::sync::Mutex<Option<Bytes>>>,
+}
+
+#[async_trait]
+impl Provider for PayloadCapturingProvider {
+    fn id(&self) -> &'static str {
+        "capturing"
+    }
+
+    fn format(&self) -> ProviderFormat {
+        ProviderFormat::OpenAI
+    }
+
+    async fn complete(
+        &self,
+        payload: Bytes,
+        _auth: &AuthConfig,
+        _spec: &ModelSpec,
+        _client_headers: &ClientHeaders,
+    ) -> braintrust_llm_router::Result<Bytes> {
+        *self.received.lock().unwrap() = Some(payload);
+
+        let response = json!({
+            "id": "test",
+            "type": "message",
+            "role": "assistant",
+            "content": [{"type": "text", "text": "hello"}],
+            "model": "test",
+            "stop_reason": "end_turn",
+            "usage": {"input_tokens": 1, "output_tokens": 1}
+        });
+        Ok(Bytes::from(
+            braintrust_llm_router::serde_json::to_vec(&response).unwrap(),
+        ))
+    }
+
+    async fn complete_stream(
+        &self,
+        _payload: Bytes,
+        _auth: &AuthConfig,
+        _spec: &ModelSpec,
+        _client_headers: &ClientHeaders,
+    ) -> braintrust_llm_router::Result<RawResponseStream> {
+        Ok(Box::pin(tokio_stream::empty()))
+    }
+
+    async fn health_check(&self, _auth: &AuthConfig) -> braintrust_llm_router::Result<()> {
+        Ok(())
+    }
+}
+
+#[tokio::test]
+async fn router_transforms_to_target_format() {
+    let mut catalog = ModelCatalog::empty();
+    catalog.insert(
+        "openai-model".into(),
+        ModelSpec {
+            model: "openai-model".into(),
+            format: ProviderFormat::OpenAI,
+            flavor: ModelFlavor::Chat,
+            display_name: None,
+            parent: None,
+            input_cost_per_mil_tokens: None,
+            output_cost_per_mil_tokens: None,
+            input_cache_read_cost_per_mil_tokens: None,
+            multimodal: None,
+            reasoning: None,
+            max_input_tokens: None,
+            max_output_tokens: None,
+            supports_streaming: true,
+            extra: Default::default(),
+        },
+    );
+    let catalog = Arc::new(catalog);
+
+    let received = Arc::new(std::sync::Mutex::new(None));
+    let provider = PayloadCapturingProvider {
+        received: Arc::clone(&received),
+    };
+
+    let router = RouterBuilder::new()
+        .with_catalog(catalog)
+        .add_provider("openai", provider)
+        .add_auth(
+            "openai",
+            AuthConfig::ApiKey {
+                key: "test".into(),
+                header: None,
+                prefix: None,
+            },
+        )
+        .build()
+        .expect("router builds");
+
+    // Send an OpenAI-format request; spec.format is also OpenAI, so it should pass through.
+    let body = to_body(json!({
+        "model": "openai-model",
+        "messages": [{"role": "user", "content": "hello"}]
+    }));
+
+    let _ = router
+        .complete(
+            body,
+            "openai-model",
+            ProviderFormat::OpenAI,
+            &ClientHeaders::default(),
+        )
+        .await;
+
+    let payload_bytes = received
+        .lock()
+        .unwrap()
+        .take()
+        .expect("provider received payload");
+    let payload: Value = braintrust_llm_router::serde_json::from_slice(&payload_bytes).unwrap();
+
+    assert!(
+        payload.get("model").is_some(),
+        "payload should be in OpenAI format"
+    );
+}
diff --git a/crates/lingua/src/lib.rs b/crates/lingua/src/lib.rs
index 7112f18e..e24b0669 100644
--- a/crates/lingua/src/lib.rs
+++ b/crates/lingua/src/lib.rs
@@ -33,8 +33,9 @@ pub use capabilities::ProviderFormat;
 
 // Re-export key processing functions (bytes-based API)
 pub use processing::{
-    extract_model, parse_stream_event, response_to_universal, sanitize_payload, transform_request,
-    transform_response, transform_stream_chunk, ParsedStreamEvent, TransformError, TransformResult,
+    extract_model, is_bedrock_anthropic_model, parse_stream_event, response_to_universal,
+    sanitize_payload, transform_request, transform_response, transform_stream_chunk,
+    ParsedStreamEvent, TransformError, TransformResult,
 };
 
 // Re-export universal types
diff --git a/crates/lingua/src/processing/mod.rs b/crates/lingua/src/processing/mod.rs
index fd8d9696..c6c33598 100644
--- a/crates/lingua/src/processing/mod.rs
+++ b/crates/lingua/src/processing/mod.rs
@@ -10,6 +10,7 @@ pub use adapters::{
 pub use dedup::deduplicate_messages;
 pub use import::{import_and_deduplicate_messages, import_messages_from_spans, Span};
 pub use transform::{
-    extract_model, parse_stream_event, response_to_universal, sanitize_payload, transform_request,
-    transform_response, transform_stream_chunk, ParsedStreamEvent, TransformError, TransformResult,
+    extract_model, is_bedrock_anthropic_model, parse_stream_event, response_to_universal,
+    sanitize_payload, transform_request, transform_response, transform_stream_chunk,
+    ParsedStreamEvent, TransformError, TransformResult,
 };
diff --git a/crates/lingua/src/processing/transform.rs b/crates/lingua/src/processing/transform.rs
index 4fb3b489..d71fde24 100644
--- a/crates/lingua/src/processing/transform.rs
+++ b/crates/lingua/src/processing/transform.rs
@@ -79,6 +79,10 @@ impl TransformError {
             | TransformError::FromUniversalFailed(_)
         )
     }
+
+    pub fn is_unsupported_target_format(&self) -> bool {
+        matches!(self, TransformError::UnsupportedTargetFormat(_))
+    }
 }
 
 impl From for TransformError {
@@ -225,28 +229,37 @@ pub fn transform_request(
     let source_adapter = detect_adapter(&payload, DetectKind::Request)?;
 
-    if source_adapter.format() == target_format
-        && !needs_forced_translation(&payload, model, target_format)
+    // Bedrock Anthropic models are cataloged as Converse but use the Anthropic wire format
+    let effective_format =
+        if target_format == ProviderFormat::Converse && is_bedrock_anthropic_model(model) {
+            ProviderFormat::Anthropic
+        } else {
+            target_format
+        };
+
+    if source_adapter.format() == effective_format
+        && !needs_forced_translation(&payload, model, effective_format)
     {
         return Ok(TransformResult::PassThrough(input));
    }
 
     let source_format = source_adapter.format();
-    let target_adapter = adapter_for_format(target_format)
-        .ok_or(TransformError::UnsupportedTargetFormat(target_format))?;
+    let target_adapter = adapter_for_format(effective_format)
+        .ok_or(TransformError::UnsupportedTargetFormat(effective_format))?;
 
     let mut universal = source_adapter.request_to_universal(payload)?;
 
-    // Inject model from parameter if not present
     if model.is_some() && universal.model.is_none() {
         universal.model = model.map(String::from);
     }
 
-    // Apply target provider defaults (e.g., Anthropic's required max_tokens)
     target_adapter.apply_defaults(&mut universal);
 
-    // Convert to target format (validation happens in adapter)
-    let transformed = target_adapter.request_from_universal(&universal)?;
+    let mut transformed = target_adapter.request_from_universal(&universal)?;
+
+    if effective_format == ProviderFormat::Anthropic && is_bedrock_anthropic_model(model) {
+        apply_bedrock_anthropic_mutations(&mut transformed);
+    }
 
     let bytes = crate::serde_json::to_vec(&transformed)
         .map_err(|e| TransformError::SerializationFailed(e.to_string()))?;
@@ -501,8 +514,63 @@ fn detect_adapter(
         .ok_or(TransformError::UnableToDetectFormat)
 }
 
+/// Check if a model name indicates a Bedrock Anthropic model
+/// (e.g. `us.anthropic.claude-3-5-sonnet-20241022-v2:0` or `anthropic.claude-3-haiku-20240307-v1:0`).
+pub fn is_bedrock_anthropic_model(model: Option<&str>) -> bool {
+    model.is_some_and(|m| m.starts_with("anthropic.") || m.contains(".anthropic."))
+}
+
+/// Apply Bedrock-specific mutations to an Anthropic-format payload.
+///
+/// Bedrock's Anthropic Messages API requires:
+/// - No `model` field (the model is identified by the URL path)
+/// - No `stream` field (streaming is determined by endpoint choice)
+/// - An `anthropic_version` field injected into the body
+///
+/// # Lossy: whitespace-only stop sequences
+///
+/// Bedrock's Anthropic Messages API has an undocumented restriction that rejects
+/// stop sequences containing only whitespace (e.g. `"\n"`, `"\t"`, `" "`). The
+/// direct Anthropic API and OpenAI both accept these -- `"\n"` is commonly used
+/// to get single-line responses.
+///
+/// There is no Bedrock equivalent or workaround, so we silently strip
+/// whitespace-only entries. This is a lossy transformation: the caller asked the
+/// model to stop on newline, but Bedrock simply cannot honor that request.
+///
+/// Bedrock Anthropic models are cataloged as Converse but routed to the Anthropic
+/// adapter via `effective_format`. These mutations finalize the payload for Bedrock's
+/// Anthropic Messages API (model/stream stripped, anthropic_version injected).
+fn apply_bedrock_anthropic_mutations(value: &mut Value) {
+    if let Some(obj) = value.as_object_mut() {
+        obj.remove("model");
+        obj.remove("stream");
+        obj.insert(
+            "anthropic_version".to_string(),
+            Value::String("bedrock-2023-05-31".to_string()),
+        );
+
+        // See doc comment above -- Bedrock rejects whitespace-only stop sequences.
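+        // For example, stop_sequences ["\n", "end", "\t"] becomes ["end"], and
+        // ["\n", "\t", " "] is dropped entirely (matching the tests below).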
+        if let Some(Value::Array(seqs)) = obj.get_mut("stop_sequences") {
+            seqs.retain(|s| {
+                s.as_str()
+                    .is_some_and(|s| s.contains(|c: char| !c.is_whitespace()))
+            });
+            if seqs.is_empty() {
+                obj.remove("stop_sequences");
+            }
+        }
+    }
+}
+
 /// Check if a request needs forced translation even when source == target format.
 fn needs_forced_translation(payload: &Value, model: Option<&str>, target: ProviderFormat) -> bool {
+    // Bedrock Anthropic models use the Anthropic wire format (via `effective_format`) and
+    // always need translation so we can apply Bedrock-specific mutations.
+    if target == ProviderFormat::Anthropic && is_bedrock_anthropic_model(model) {
+        return true;
+    }
+
     if target != ProviderFormat::OpenAI {
         return false;
     }
@@ -756,4 +824,152 @@ mod tests {
             "Non-reasoning models should passthrough"
         );
     }
+
+    #[test]
+    #[cfg(feature = "anthropic")]
+    fn test_bedrock_anthropic_mutations_applied() {
+        let payload = json!({
+            "model": "us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+            "max_tokens": 1024,
+            "stream": true,
+            "messages": [{"role": "user", "content": "Hello"}]
+        });
+        let input = to_bytes(&payload);
+        let model = "us.anthropic.claude-3-5-sonnet-20241022-v2:0";
+
+        // Target is Converse (from catalog), but effective_format promotes to Anthropic
+        let result = transform_request(input, ProviderFormat::Converse, Some(model)).unwrap();
+
+        assert!(
+            !result.is_passthrough(),
+            "Bedrock Anthropic models should force translation"
+        );
+
+        let output: Value = crate::serde_json::from_slice(result.as_bytes()).unwrap();
+        assert!(
+            output.get("model").is_none(),
+            "model should be stripped for Bedrock Anthropic"
+        );
+        assert!(
+            output.get("stream").is_none(),
+            "stream should be stripped for Bedrock Anthropic"
+        );
+        assert_eq!(
+            output.get("anthropic_version").and_then(Value::as_str),
+            Some("bedrock-2023-05-31"),
+            "anthropic_version should be injected"
+        );
+        assert_eq!(
+            output.get("max_tokens").and_then(Value::as_i64),
+            Some(1024),
+            "max_tokens should be preserved"
+        );
+    }
+
+    #[test]
+    #[cfg(feature = "anthropic")]
+    fn test_non_bedrock_anthropic_passthrough() {
+        let payload = json!({
+            "model": "claude-3-5-sonnet",
+            "max_tokens": 1024,
+            "messages": [{"role": "user", "content": "Hello"}]
+        });
+        let input = to_bytes(&payload);
+
+        let result = transform_request(input, ProviderFormat::Anthropic, None).unwrap();
+
+        assert!(
+            result.is_passthrough(),
+            "Non-Bedrock Anthropic models should passthrough"
+        );
+    }
+
+    #[test]
+    #[cfg(feature = "anthropic")]
+    fn test_bedrock_anthropic_strips_whitespace_stop_sequences() {
+        let payload = json!({
+            "model": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
+            "max_tokens": 50,
+            "messages": [{"role": "user", "content": "Hello"}],
+            "stop_sequences": ["\n", "end", "\t"]
+        });
+        let input = to_bytes(&payload);
+        let model = "us.anthropic.claude-haiku-4-5-20251001-v1:0";
+
+        let result = transform_request(input, ProviderFormat::Converse, Some(model)).unwrap();
+        let output: Value = crate::serde_json::from_slice(result.as_bytes()).unwrap();
+
+        let stop_seqs = output
+            .get("stop_sequences")
+            .and_then(Value::as_array)
+            .expect("stop_sequences should be present");
+        assert_eq!(stop_seqs.len(), 1);
+        assert_eq!(stop_seqs[0].as_str(), Some("end"));
+    }
+
+    #[test]
+    #[cfg(feature = "anthropic")]
+    fn test_bedrock_anthropic_removes_all_whitespace_stop_sequences() {
+        let payload = json!({
+            "model": "us.anthropic.claude-haiku-4-5-20251001-v1:0",
+            "max_tokens": 50,
+            "messages": [{"role": "user", "content": "Hello"}],
+            "stop_sequences": ["\n", "\t", " "]
+        });
+        let input = to_bytes(&payload);
+        let model = "us.anthropic.claude-haiku-4-5-20251001-v1:0";
+
+        let result = transform_request(input, ProviderFormat::Converse, Some(model)).unwrap();
+        let output: Value = crate::serde_json::from_slice(result.as_bytes()).unwrap();
+
+        assert!(
+            output.get("stop_sequences").is_none(),
+            "stop_sequences should be removed when all entries are whitespace-only"
+        );
+    }
+
+    #[test]
+    #[cfg(all(feature = "openai", feature = "bedrock"))]
+    fn test_non_anthropic_converse_uses_converse_adapter() {
+        let payload = json!({
+            "model": "amazon.nova-micro-v1:0",
+            "messages": [{"role": "user", "content": "Hello"}]
+        });
+        let input = to_bytes(&payload);
+        let model = "amazon.nova-micro-v1:0";
+
+        let result = transform_request(input, ProviderFormat::Converse, Some(model)).unwrap();
+
+        assert!(
+            !result.is_passthrough(),
+            "OpenAI -> Converse should transform"
+        );
+
+        let output: Value = crate::serde_json::from_slice(result.as_bytes()).unwrap();
+        assert!(
+            output.get("modelId").is_some(),
+            "non-Anthropic Converse model should use the Converse adapter (has modelId)"
+        );
+        assert!(
+            output.get("anthropic_version").is_none(),
+            "non-Anthropic model should not get anthropic_version"
+        );
+    }
+
+    #[test]
+    fn test_is_bedrock_anthropic_model() {
+        assert!(super::is_bedrock_anthropic_model(Some(
+            "us.anthropic.claude-3-5-sonnet-20241022-v2:0"
+        )));
+        assert!(super::is_bedrock_anthropic_model(Some(
+            "anthropic.claude-3-haiku-20240307-v1:0"
+        )));
+        assert!(!super::is_bedrock_anthropic_model(Some(
+            "claude-3-5-sonnet"
+        )));
+        assert!(!super::is_bedrock_anthropic_model(Some(
+            "amazon.nova-micro-v1:0"
+        )));
+        assert!(!super::is_bedrock_anthropic_model(None));
+    }
 }
diff --git a/crates/lingua/src/universal/reasoning.rs b/crates/lingua/src/universal/reasoning.rs
index 24c2e169..6b54d920 100644
--- a/crates/lingua/src/universal/reasoning.rs
+++ b/crates/lingua/src/universal/reasoning.rs
@@ -311,7 +311,7 @@ impl ReasoningConfig {
             ProviderFormat::OpenAI => Ok(to_openai_chat(self, max_tokens).map(Value::String)),
             ProviderFormat::Responses => Ok(to_openai_responses(self, max_tokens)),
             ProviderFormat::Anthropic => Ok(to_anthropic(self, max_tokens)),
-            ProviderFormat::Converse => Ok(to_anthropic(self, max_tokens)), // Bedrock uses same format as Anthropic
+            ProviderFormat::Converse => Ok(to_anthropic(self, max_tokens)), // Bedrock Converse uses same thinking format as Anthropic
             ProviderFormat::Google => Ok(to_google(self, max_tokens)),
             _ => Ok(None),
         }
diff --git a/payloads/cases/advanced.ts b/payloads/cases/advanced.ts
index 2ea18699..d3a4d243 100644
--- a/payloads/cases/advanced.ts
+++ b/payloads/cases/advanced.ts
@@ -4,7 +4,7 @@ import {
   OPENAI_CHAT_COMPLETIONS_MODEL,
   OPENAI_RESPONSES_MODEL,
   ANTHROPIC_MODEL,
-  BEDROCK_MODEL,
+  BEDROCK_ANTH_MODEL,
 } from "./models";
 
 const IMAGE_BASE64 =
@@ -100,7 +100,7 @@ export const advancedCases: TestCaseCollection = {
     },
 
     bedrock: {
-      modelId: BEDROCK_MODEL,
+      modelId: BEDROCK_ANTH_MODEL,
       messages: [
         {
           role: "user",
@@ -178,7 +178,7 @@ export const advancedCases: TestCaseCollection = {
     },
 
     bedrock: {
-      modelId: BEDROCK_MODEL,
+      modelId: BEDROCK_ANTH_MODEL,
       messages: [
         {
           role: "user",
@@ -236,7 +236,7 @@ export const advancedCases: TestCaseCollection = {
     },
 
     bedrock: {
-      modelId: BEDROCK_MODEL,
+      modelId: BEDROCK_ANTH_MODEL,
      messages: [
         {
           role: "user",
@@ -364,7 +364,7 @@ export const advancedCases: TestCaseCollection = {
     },
 
     bedrock: {
-      modelId: BEDROCK_MODEL,
+      modelId: BEDROCK_ANTH_MODEL,
       messages: [
         {
           role: "user",
diff --git a/payloads/cases/models.ts b/payloads/cases/models.ts
index e42a3eb5..ecca1d52 100644
--- a/payloads/cases/models.ts
+++ b/payloads/cases/models.ts
@@ -7,4 +7,5 @@ export const ANTHROPIC_MODEL = "claude-sonnet-4-20250514";
 // For Anthropic structured outputs (requires Sonnet 4.5+ for JSON schema output_format)
 export const ANTHROPIC_STRUCTURED_OUTPUT_MODEL = "claude-sonnet-4-5-20250929";
 export const GOOGLE_MODEL = "gemini-2.5-flash";
-export const BEDROCK_MODEL = "us.anthropic.claude-haiku-4-5-20251001-v1:0";
+export const BEDROCK_ANTH_MODEL = "us.anthropic.claude-haiku-4-5-20251001-v1:0";
+export const BEDROCK_CONVERSE_MODEL = "amazon.nova-micro-v1:0";
diff --git a/payloads/cases/simple.ts b/payloads/cases/simple.ts
index 8fbdf0f6..bdab119b 100644
--- a/payloads/cases/simple.ts
+++ b/payloads/cases/simple.ts
@@ -4,7 +4,7 @@ import {
   OPENAI_CHAT_COMPLETIONS_MODEL,
   OPENAI_RESPONSES_MODEL,
   ANTHROPIC_MODEL,
-  BEDROCK_MODEL,
+  BEDROCK_ANTH_MODEL,
 } from "./models";
 
 // Simple test cases - basic functionality testing
@@ -54,7 +54,7 @@ export const simpleCases: TestCaseCollection = {
     },
 
     bedrock: {
-      modelId: BEDROCK_MODEL,
+      modelId: BEDROCK_ANTH_MODEL,
       messages: [
         {
           role: "user",
@@ -113,7 +113,7 @@ export const simpleCases: TestCaseCollection = {
     },
 
     bedrock: {
-      modelId: BEDROCK_MODEL,
+      modelId: BEDROCK_ANTH_MODEL,
       messages: [
         {
           role: "user",
@@ -181,7 +181,7 @@ export const simpleCases: TestCaseCollection = {
     },
 
     bedrock: {
-      modelId: BEDROCK_MODEL,
+      modelId: BEDROCK_ANTH_MODEL,
       messages: [
         {
           role: "user",
@@ -313,7 +313,7 @@ export const simpleCases: TestCaseCollection = {
     },
 
     bedrock: {
-      modelId: BEDROCK_MODEL,
+      modelId: BEDROCK_ANTH_MODEL,
       messages: [
         {
           role: "user",
diff --git a/payloads/scripts/validation/index.ts b/payloads/scripts/validation/index.ts
index 91b569ff..5f1788a6 100644
--- a/payloads/scripts/validation/index.ts
+++ b/payloads/scripts/validation/index.ts
@@ -14,7 +14,8 @@ import {
   OPENAI_CHAT_COMPLETIONS_MODEL,
   ANTHROPIC_STRUCTURED_OUTPUT_MODEL,
   GOOGLE_MODEL,
-  BEDROCK_MODEL,
+  BEDROCK_ANTH_MODEL,
+  BEDROCK_CONVERSE_MODEL,
 } from "../../cases/models";
 import {
   proxyCases,
@@ -226,7 +227,8 @@ const PROVIDER_REGISTRY: Record<string, string> = {
   openai: OPENAI_CHAT_COMPLETIONS_MODEL,
   anthropic: ANTHROPIC_STRUCTURED_OUTPUT_MODEL,
   google: GOOGLE_MODEL,
-  bedrock: BEDROCK_MODEL,
+  bedrock: BEDROCK_ANTH_MODEL,
+  "bedrock-converse": BEDROCK_CONVERSE_MODEL,
 };
 
 /**
@@ -451,22 +453,16 @@ export async function runValidation(
     return result;
   }
 
-  // Override model only for cross-provider testing
-  // OpenAI formats (chat-completions, responses) with non-OpenAI providers
+  // Override model for cross-provider testing (any format with a non-default provider)
   if (
     providerAlias !== "default" &&
     providerAlias !== "openai" && // Don't override for OpenAI - tests have correct models
    PROVIDER_REGISTRY[providerAlias]
   ) {
-    const isOpenAIFormat =
-      format === "chat-completions" || format === "responses";
-    if (isOpenAIFormat) {
-      // Override for cross-provider translation testing
-      request = {
-        ...request,
-        model: PROVIDER_REGISTRY[providerAlias],
-      };
-    }
+    request = {
+      ...request,
+      model: PROVIDER_REGISTRY[providerAlias],
+    };
   }
 
   // Execute through proxy