Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 63 additions & 4 deletions crates/braintrust-llm-router/src/providers/bedrock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ use bytes::Bytes;
use http::Request as HttpRequest;
use lingua::serde_json::Value;
use reqwest::header::{HeaderMap, HeaderValue, CONTENT_TYPE};

use reqwest::{Client, StatusCode, Url};

use crate::auth::AuthConfig;
use crate::catalog::ModelSpec;
use crate::client::{default_client, ClientSettings};
use crate::error::{Error, Result, UpstreamHttpError};
use crate::providers::ClientHeaders;
use crate::streaming::{bedrock_event_stream, single_bytes_stream, RawResponseStream};
use crate::streaming::{
bedrock_event_stream, bedrock_messages_event_stream, single_bytes_stream, RawResponseStream,
};
use lingua::ProviderFormat;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -91,7 +94,13 @@ impl BedrockProvider {
}

fn invoke_url(&self, model: &str, stream: bool) -> Result<Url> {
let path = if stream {
let path = if lingua::is_bedrock_anthropic_model(Some(model)) {
if stream {
format!("model/{model}/invoke-with-response-stream")
} else {
format!("model/{model}/invoke")
}
} else if stream {
format!("model/{model}/converse-stream")
} else {
format!("model/{model}/converse")
Expand Down Expand Up @@ -267,7 +276,6 @@ impl crate::providers::Provider for BedrockProvider {
return Ok(single_bytes_stream(response));
}

// Router should have already added stream options to payload
let url = self.invoke_url(&spec.model, true)?;

#[cfg(feature = "tracing")]
Expand Down Expand Up @@ -317,7 +325,11 @@ impl crate::providers::Provider for BedrockProvider {
});
}

Ok(bedrock_event_stream(response))
if lingua::is_bedrock_anthropic_model(Some(&spec.model)) {
Ok(bedrock_messages_event_stream(response))
} else {
Ok(bedrock_event_stream(response))
}
}

async fn health_check(&self, auth: &AuthConfig) -> Result<()> {
Expand Down Expand Up @@ -358,3 +370,50 @@ fn extract_retry_after(status: StatusCode, _body: &str) -> Option<Duration> {
None
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn invoke_url_converse_mode() {
let provider = BedrockProvider::new(BedrockConfig::default()).unwrap();

let url = provider
.invoke_url("amazon.nova-micro-v1:0", false)
.unwrap();
assert!(url
.as_str()
.contains("/model/amazon.nova-micro-v1:0/converse"));

let url = provider
.invoke_url("amazon.nova-micro-v1:0", true)
.unwrap();
assert!(url
.as_str()
.contains("/model/amazon.nova-micro-v1:0/converse-stream"));
}

#[test]
fn invoke_url_anthropic_messages_mode() {
let provider = BedrockProvider::new(BedrockConfig::default()).unwrap();

let url = provider
.invoke_url("anthropic.claude-3-5-sonnet-20241022-v2:0", false)
.unwrap();
assert!(url
.as_str()
.contains("/model/anthropic.claude-3-5-sonnet-20241022-v2:0/invoke"));
assert!(
!url.as_str().contains("converse"),
"Anthropic messages mode should not use converse endpoints"
);

let url = provider
.invoke_url("anthropic.claude-3-5-sonnet-20241022-v2:0", true)
.unwrap();
assert!(url.as_str().contains(
"/model/anthropic.claude-3-5-sonnet-20241022-v2:0/invoke-with-response-stream"
));
}
}
9 changes: 4 additions & 5 deletions crates/braintrust-llm-router/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,7 @@ impl Router {
client_headers: &ClientHeaders,
) -> Result<Bytes> {
let (provider, auth, spec, strategy) = self.resolve_provider(model)?;
let payload = match lingua::transform_request(body.clone(), provider.format(), Some(model))
{
let payload = match lingua::transform_request(body.clone(), spec.format, Some(model)) {
Ok(TransformResult::PassThrough(bytes)) => bytes,
Ok(TransformResult::Transformed { bytes, .. }) => bytes,
Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(),
Expand Down Expand Up @@ -205,8 +204,7 @@ impl Router {
client_headers: &ClientHeaders,
) -> Result<ResponseStream> {
let (provider, auth, spec, _) = self.resolve_provider(model)?;
let payload = match lingua::transform_request(body.clone(), provider.format(), Some(model))
{
let payload = match lingua::transform_request(body.clone(), spec.format, Some(model)) {
Ok(TransformResult::PassThrough(bytes)) => bytes,
Ok(TransformResult::Transformed { bytes, .. }) => bytes,
Err(TransformError::UnsupportedTargetFormat(_)) => body.clone(),
Expand All @@ -222,7 +220,8 @@ impl Router {

pub fn provider_alias(&self, model: &str) -> Result<String> {
let (_, format, alias) = self.resolver.resolve(model)?;
Ok(self.formats.get(&format).cloned().unwrap_or(alias))
let alias = self.formats.get(&format).cloned().unwrap_or(alias);
Ok(alias)
}

fn resolve_provider(&self, model: &str) -> Result<ResolvedRoute<'_>> {
Expand Down
210 changes: 210 additions & 0 deletions crates/braintrust-llm-router/src/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use futures::Stream;
use reqwest::Response;

use crate::error::{Error, Result};
#[cfg(feature = "provider-bedrock")]
use lingua::serde_json::Value;
use lingua::ProviderFormat;
use lingua::TransformResult;

Expand Down Expand Up @@ -305,6 +307,142 @@ pub fn bedrock_event_stream(response: Response) -> RawResponseStream {
Box::pin(RawBedrockEventStream::new(response.bytes_stream()))
}

/// Bedrock Messages API event stream that yields raw Anthropic JSON payloads.
///
/// Uses the same AWS binary event stream decoder as the Converse stream but
/// emits the payload bytes directly without wrapping in `{"eventType": payload}`.
/// The payloads are already valid Anthropic streaming JSON events.
#[cfg(feature = "provider-bedrock")]
struct RawBedrockMessagesEventStream<S>
where
S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Unpin + Send + 'static,
{
inner: S,
buffer: BytesMut,
decoder: aws_smithy_eventstream::frame::MessageFrameDecoder,
finished: bool,
}

#[cfg(feature = "provider-bedrock")]
impl<S> RawBedrockMessagesEventStream<S>
where
S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Unpin + Send + 'static,
{
fn new(inner: S) -> Self {
Self {
inner,
buffer: BytesMut::new(),
decoder: aws_smithy_eventstream::frame::MessageFrameDecoder::new(),
finished: false,
}
}
}

#[cfg(feature = "provider-bedrock")]
impl<S> Stream for RawBedrockMessagesEventStream<S>
where
S: Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Unpin + Send + 'static,
{
type Item = Result<Bytes>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
use aws_smithy_eventstream::frame::DecodedFrame;

let this = self.get_mut();

if this.finished {
return Poll::Ready(None);
}

loop {
match this.decoder.decode_frame(&mut this.buffer) {
Ok(DecodedFrame::Complete(message)) => {
let payload = message.payload();
if payload.is_empty() {
continue;
}

// The invoke-with-response-stream payload is
// {"bytes":"<base64-encoded Anthropic JSON>"}
// We need to extract and decode the bytes field.
let json_bytes = match extract_bedrock_invoke_payload(payload) {
Ok(Some(decoded)) => decoded,
Ok(None) => continue,
Err(e) => return Poll::Ready(Some(Err(e))),
};

return Poll::Ready(Some(Ok(json_bytes)));
}
Ok(DecodedFrame::Incomplete) => {
// Need more data, fall through to poll inner stream
}
Err(e) => {
return Poll::Ready(Some(Err(Error::Provider {
provider: "bedrock".to_string(),
source: anyhow::anyhow!("Event stream decode error: {}", e),
retry_after: None,
http: None,
})));
}
}

match Pin::new(&mut this.inner).poll_next(cx) {
Poll::Ready(Some(Ok(bytes))) => {
this.buffer.extend_from_slice(&bytes);
}
Poll::Ready(Some(Err(err))) => return Poll::Ready(Some(Err(err.into()))),
Poll::Ready(None) => {
this.finished = true;
return Poll::Ready(None);
}
Poll::Pending => return Poll::Pending,
}
}
}
}

/// Create a Bedrock Messages API event stream that yields raw Anthropic JSON payloads.
///
/// Uses the AWS binary event stream decoder but emits payloads directly
/// (no `{"eventType": payload}` wrapping). For use with Anthropic models on Bedrock.
#[cfg(feature = "provider-bedrock")]
pub fn bedrock_messages_event_stream(response: Response) -> RawResponseStream {
Box::pin(RawBedrockMessagesEventStream::new(response.bytes_stream()))
}

/// Extract the Anthropic JSON payload from a Bedrock invoke-with-response-stream event.
///
/// The event payload has the shape `{"bytes":"<base64-encoded JSON>"}`.
/// Returns `Ok(Some(decoded_bytes))` on success, `Ok(None)` if the payload
/// should be skipped, or `Err` on decode failure.
#[cfg(feature = "provider-bedrock")]
fn extract_bedrock_invoke_payload(raw: &[u8]) -> Result<Option<Bytes>> {
use base64::Engine;

let wrapper: Value = lingua::serde_json::from_slice(raw).map_err(|e| Error::Provider {
provider: "bedrock".to_string(),
source: anyhow::anyhow!("failed to parse invoke stream event: {}", e),
retry_after: None,
http: None,
})?;

let b64 = match wrapper.get("bytes").and_then(Value::as_str) {
Some(s) => s,
None => return Ok(None),
};

let decoded = base64::engine::general_purpose::STANDARD
.decode(b64)
.map_err(|e| Error::Provider {
provider: "bedrock".to_string(),
source: anyhow::anyhow!("failed to base64-decode invoke stream event: {}", e),
retry_after: None,
http: None,
})?;

Ok(Some(Bytes::from(decoded)))
}

fn split_event(buffer: &BytesMut) -> Option<(Bytes, BytesMut)> {
// Check for \r\n\r\n first (4-byte CRLF delimiter)
if let Some(index) = buffer.windows(4).position(|w| w == b"\r\n\r\n") {
Expand Down Expand Up @@ -350,4 +488,76 @@ mod tests {
buffer = rest;
assert!(!buffer.is_empty());
}

#[cfg(feature = "provider-bedrock")]
mod bedrock_messages_stream {
use super::*;

#[test]
fn extract_bedrock_invoke_payload_decodes_base64() {
use base64::Engine;

let inner_json = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}"#;
let encoded = base64::engine::general_purpose::STANDARD.encode(inner_json);
let wrapper = format!(r#"{{"bytes":"{}"}}"#, encoded);

let result = extract_bedrock_invoke_payload(wrapper.as_bytes()).unwrap();
assert!(result.is_some());
let decoded = result.unwrap();
assert_eq!(decoded.as_ref(), inner_json.as_bytes());
}

#[test]
fn extract_bedrock_invoke_payload_returns_none_without_bytes_field() {
let payload = br#"{"other_field": "value"}"#;
let result = extract_bedrock_invoke_payload(payload).unwrap();
assert!(result.is_none());
}

#[test]
fn extract_bedrock_invoke_payload_errors_on_invalid_json() {
let payload = b"not json";
let result = extract_bedrock_invoke_payload(payload);
assert!(result.is_err());
}

#[test]
fn extract_bedrock_invoke_payload_errors_on_invalid_base64() {
let payload = br#"{"bytes": "!!!not-valid-base64!!!"}"#;
let result = extract_bedrock_invoke_payload(payload);
assert!(result.is_err());
}

#[test]
fn extract_bedrock_invoke_payload_handles_message_start_event() {
use base64::Engine;

let inner_json = r#"{"type":"message_start","message":{"id":"msg_123","type":"message","role":"assistant","content":[],"model":"claude-3-5-sonnet","stop_reason":null,"usage":{"input_tokens":10,"output_tokens":1}}}"#;
let encoded = base64::engine::general_purpose::STANDARD.encode(inner_json);
let wrapper = format!(r#"{{"bytes":"{}"}}"#, encoded);

let result = extract_bedrock_invoke_payload(wrapper.as_bytes()).unwrap();
assert!(result.is_some());
let decoded = result.unwrap();
let decoded_str = std::str::from_utf8(&decoded).unwrap();
assert!(decoded_str.contains("message_start"));
assert!(decoded_str.contains("msg_123"));
}

#[test]
fn extract_bedrock_invoke_payload_handles_message_stop_event() {
use base64::Engine;

let inner_json = r#"{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":15}}"#;
let encoded = base64::engine::general_purpose::STANDARD.encode(inner_json);
let wrapper = format!(r#"{{"bytes":"{}"}}"#, encoded);

let result = extract_bedrock_invoke_payload(wrapper.as_bytes()).unwrap();
assert!(result.is_some());
let decoded = result.unwrap();
let decoded_str = std::str::from_utf8(&decoded).unwrap();
assert!(decoded_str.contains("message_delta"));
assert!(decoded_str.contains("end_turn"));
}
}
}
Loading