diff --git a/config/plano_config_schema.yaml b/config/plano_config_schema.yaml index 5190fecf7..e011ae72b 100644 --- a/config/plano_config_schema.yaml +++ b/config/plano_config_schema.yaml @@ -43,6 +43,8 @@ properties: - streamable-http tool: type: string + streaming: + type: boolean additionalProperties: false required: - id diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index 5d986ffa5..183138210 100644 --- a/crates/brightstaff/Cargo.toml +++ b/crates/brightstaff/Cargo.toml @@ -26,7 +26,7 @@ opentelemetry-stdout = "0.31" opentelemetry_sdk = { version = "0.31", features = ["rt-tokio"] } pretty_assertions = "1.4.1" rand = "0.9.2" -reqwest = { version = "0.12.15", features = ["stream"] } +reqwest = { version = "0.12.15", features = ["stream", "http2"] } serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" serde_with = "3.13.0" diff --git a/crates/brightstaff/src/handlers/agent_selector.rs b/crates/brightstaff/src/handlers/agent_selector.rs index 2341e156c..44e69904f 100644 --- a/crates/brightstaff/src/handlers/agent_selector.rs +++ b/crates/brightstaff/src/handlers/agent_selector.rs @@ -210,6 +210,7 @@ mod tests { url: "http://localhost:8080".to_string(), tool: None, transport: None, + streaming: None, } } diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index c5bfb1b26..dc3a452fe 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -52,6 +52,7 @@ mod tests { url: "http://localhost:8081".to_string(), tool: None, transport: None, + streaming: None, }, Agent { id: "terminal-agent".to_string(), @@ -59,6 +60,7 @@ mod tests { url: "http://localhost:8082".to_string(), tool: None, transport: None, + streaming: None, }, ]; diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index 4cb8531f6..193090254 100644 
--- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -10,6 +10,9 @@ use hermesllm::{ProviderRequest, ProviderRequestType}; use hyper::header::HeaderMap; use opentelemetry::global; use opentelemetry_http::HeaderInjector; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::StreamExt; use tracing::{debug, info, instrument, warn}; use crate::handlers::jsonrpc::{ @@ -50,6 +53,18 @@ pub enum PipelineError { }, } +/// A live streaming filter pipeline. LLM chunks go into `input_tx`; +/// processed chunks come out of `output_rx`. Each filter in the chain +/// is connected via a single bidirectional streaming HTTP connection. +#[derive(Debug)] +pub struct StreamingFilterPipeline { + pub input_tx: mpsc::Sender<bytes::Bytes>, + pub output_rx: mpsc::Receiver<bytes::Bytes>, + pub handles: Vec<tokio::task::JoinHandle<()>>, +} + +const STREAMING_PIPELINE_BUFFER: usize = 16; + /// Service for processing agent pipelines pub struct PipelineProcessor { client: reqwest::Client, @@ -429,6 +444,130 @@ impl PipelineProcessor { session_id } + /// Build headers for an HTTP raw filter request (shared by per-chunk and streaming paths).
+ fn build_raw_filter_headers( + request_headers: &HeaderMap, + agent_id: &str, + ) -> Result<HeaderMap, PipelineError> { + let mut headers = request_headers.clone(); + headers.remove(hyper::header::CONTENT_LENGTH); + + headers.remove(TRACE_PARENT_HEADER); + global::get_text_map_propagator(|propagator| { + let cx = + tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); + propagator.inject_context(&cx, &mut HeaderInjector(&mut headers)); + }); + + headers.insert( + ARCH_UPSTREAM_HOST_HEADER, + hyper::header::HeaderValue::from_str(agent_id) + .map_err(|_| PipelineError::AgentNotFound(agent_id.to_string()))?, + ); + headers.insert( + ENVOY_RETRY_HEADER, + hyper::header::HeaderValue::from_str("3").unwrap(), + ); + headers.insert( + "Accept", + hyper::header::HeaderValue::from_static("application/json, text/event-stream"), + ); + headers.insert( + "Content-Type", + hyper::header::HeaderValue::from_static("application/octet-stream"), + ); + + Ok(headers) + } + + /// Set up a bidirectional streaming output filter pipeline. + /// + /// Opens one streaming POST per filter (using chunked transfer encoding) + /// and chains them: the response stream of filter N feeds the request body + /// of filter N+1. Returns a pipeline where the caller pushes LLM chunks + /// into `input_tx` and reads processed chunks from `output_rx`.
+ pub async fn start_streaming_output_pipeline( + agents: &[&Agent], + request_headers: &HeaderMap, + request_path: &str, + ) -> Result<StreamingFilterPipeline, PipelineError> { + let client = reqwest::Client::builder() + .build() + .map_err(PipelineError::RequestFailed)?; + + let (input_tx, first_rx) = mpsc::channel::<bytes::Bytes>(STREAMING_PIPELINE_BUFFER); + let mut current_rx = first_rx; + let mut handles = Vec::new(); + + for agent in agents { + let url = format!("{}{}", agent.url, request_path); + let headers = Self::build_raw_filter_headers(request_headers, &agent.id)?; + + let body_stream = ReceiverStream::new(current_rx).map(Ok::<_, std::io::Error>); + let body = reqwest::Body::wrap_stream(body_stream); + + debug!(agent = %agent.id, url = %url, "opening streaming filter connection"); + + let response = client.post(&url).headers(headers).body(body).send().await?; + + let http_status = response.status(); + if !http_status.is_success() { + let error_body = response + .text() + .await + .unwrap_or_else(|_| "".to_string()); + return Err(if http_status.is_client_error() { + PipelineError::ClientError { + agent: agent.id.clone(), + status: http_status.as_u16(), + body: error_body, + } + } else { + PipelineError::ServerError { + agent: agent.id.clone(), + status: http_status.as_u16(), + body: error_body, + } + }); + } + + let (next_tx, next_rx) = mpsc::channel::<bytes::Bytes>(STREAMING_PIPELINE_BUFFER); + let agent_id = agent.id.clone(); + + let handle = tokio::spawn(async move { + let mut resp_stream = response.bytes_stream(); + while let Some(item) = resp_stream.next().await { + match item { + Ok(chunk) => { + if next_tx.send(chunk).await.is_err() { + debug!(agent = %agent_id, "streaming pipeline receiver dropped"); + break; + } + } + Err(e) => { + warn!(agent = %agent_id, error = %e, "streaming filter response error"); + break; + } + } + } + debug!(agent = %agent_id, "streaming filter stage completed"); + }); + + handles.push(handle); + current_rx = next_rx; + } + + info!( + filter_count = agents.len(), + "streaming output
filter pipeline established" + ); + Ok(StreamingFilterPipeline { + input_tx, + output_rx: current_rx, + handles, + }) + } + /// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back. /// Used for input and output filters where the full raw request/response is passed through. /// No MCP protocol wrapping; agent_type is ignored. @@ -454,25 +593,7 @@ impl PipelineProcessor { span.update_name(format!("execute_raw_filter ({})", agent.id)); }); - let mut agent_headers = request_headers.clone(); - agent_headers.remove(hyper::header::CONTENT_LENGTH); - - agent_headers.remove(TRACE_PARENT_HEADER); - global::get_text_map_propagator(|propagator| { - let cx = - tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); - propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers)); - }); - - agent_headers.insert( - ARCH_UPSTREAM_HOST_HEADER, - hyper::header::HeaderValue::from_str(&agent.id) - .map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?, - ); - agent_headers.insert( - ENVOY_RETRY_HEADER, - hyper::header::HeaderValue::from_str("3").unwrap(), - ); + let mut agent_headers = Self::build_raw_filter_headers(request_headers, &agent.id)?; agent_headers.insert( "Accept", hyper::header::HeaderValue::from_static("application/json"), @@ -482,9 +603,6 @@ impl PipelineProcessor { hyper::header::HeaderValue::from_static("application/json"), ); - // Append the original request path so the filter endpoint encodes the API format. - // e.g. 
agent.url="http://host/anonymize" + request_path="/v1/chat/completions" - // -> POST http://host/anonymize/v1/chat/completions let url = format!("{}{}", agent.url, request_path); debug!(agent = %agent.id, url = %url, "sending raw filter request"); @@ -682,6 +800,7 @@ mod tests { tool: None, url: server_url, agent_type: None, + streaming: None, }; let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]}); @@ -722,6 +841,7 @@ mod tests { tool: None, url: server_url, agent_type: None, + streaming: None, }; let body = serde_json::json!({"messages": [{"role": "user", "content": "Ping"}]}); @@ -775,6 +895,7 @@ mod tests { tool: None, url: server_url, agent_type: None, + streaming: None, }; let body = serde_json::json!({"messages": [{"role": "user", "content": "Hi"}]}); @@ -793,4 +914,29 @@ mod tests { _ => panic!("Expected client error when isError flag is set"), } } + + #[tokio::test] + async fn test_streaming_pipeline_connection_refused() { + let agent = Agent { + id: "unreachable".to_string(), + transport: None, + tool: None, + url: "http://127.0.0.1:1".to_string(), + agent_type: Some("http".to_string()), + streaming: Some(true), + }; + let headers = HeaderMap::new(); + let result = PipelineProcessor::start_streaming_output_pipeline( + &[&agent], + &headers, + "/v1/chat/completions", + ) + .await; + + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + PipelineError::RequestFailed(_) + )); + } } diff --git a/crates/brightstaff/src/handlers/streaming.rs b/crates/brightstaff/src/handlers/streaming.rs index 0cea182e1..af5cded2b 100644 --- a/crates/brightstaff/src/handlers/streaming.rs +++ b/crates/brightstaff/src/handlers/streaming.rs @@ -294,11 +294,13 @@ where } /// Creates a streaming response that processes each raw chunk through output filters. -/// Filters receive the raw LLM response bytes and request path (any API shape; not limited to -/// chat completions). 
On filter error mid-stream the original chunk is passed through (headers already sent). +/// +/// If all filters in the chain have `streaming: true`, uses a single bidirectional +/// HTTP/2 connection per filter (no per-chunk overhead). Otherwise falls back to +/// per-chunk HTTP requests (the original behavior). pub fn create_streaming_response_with_output_filter( - mut byte_stream: S, - mut inner_processor: P, + byte_stream: S, + inner_processor: P, output_chain: ResolvedFilterChain, request_headers: HeaderMap, request_path: String, @@ -307,92 +309,251 @@ where S: StreamExt> + Send + Unpin + 'static, P: StreamProcessor, { + let use_streaming = output_chain.all_support_streaming(); let (tx, rx) = mpsc::channel::(STREAM_BUFFER_SIZE); let current_span = tracing::Span::current(); - let processor_handle = tokio::spawn( - async move { - let mut is_first_chunk = true; - let mut pipeline_processor = PipelineProcessor::default(); - let chain = output_chain.to_agent_filter_chain("output_filter"); + let processor_handle = if use_streaming { + info!("using bidirectional streaming output filter pipeline"); + spawn_streaming_output_filter( + byte_stream, + inner_processor, + output_chain, + request_headers, + request_path, + tx, + current_span, + ) + } else { + debug!("using per-chunk output filter pipeline"); + spawn_per_chunk_output_filter( + byte_stream, + inner_processor, + output_chain, + request_headers, + request_path, + tx, + current_span, + ) + }; - while let Some(item) = byte_stream.next().await { - let chunk = match item { - Ok(chunk) => chunk, - Err(err) => { - let err_msg = format!("Error receiving chunk: {:?}", err); - warn!(error = %err_msg, "stream error"); - inner_processor.on_error(&err_msg); - break; - } - }; + let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk))); + let stream_body = BoxBody::new(StreamBody::new(stream)); - if is_first_chunk { - inner_processor.on_first_bytes(); - is_first_chunk = false; - } + 
StreamingResponse { + body: stream_body, + processor_handle, + } +} - // Pass raw chunk bytes through the output filter chain - let processed_chunk = match pipeline_processor - .process_raw_filter_chain( - &chunk, - &chain, - &output_chain.agents, - &request_headers, - &request_path, +/// Bidirectional streaming path: one HTTP/2 connection per filter for the entire +/// LLM response. Falls back to per-chunk mode if the pipeline fails to establish. +fn spawn_streaming_output_filter( + mut byte_stream: S, + mut inner_processor: P, + output_chain: ResolvedFilterChain, + request_headers: HeaderMap, + request_path: String, + tx: mpsc::Sender, + current_span: tracing::Span, +) -> tokio::task::JoinHandle<()> +where + S: StreamExt> + Send + Unpin + 'static, + P: StreamProcessor, +{ + tokio::spawn( + async move { + let agents = output_chain.streaming_agents(); + let pipeline = PipelineProcessor::start_streaming_output_pipeline( + &agents, + &request_headers, + &request_path, + ) + .await; + + let pipeline = match pipeline { + Ok(p) => p, + Err(e) => { + warn!(error = %e, "failed to establish streaming pipeline, falling back to per-chunk"); + run_per_chunk_loop( + byte_stream, + inner_processor, + output_chain, + request_headers, + request_path, + tx, ) - .await - { - Ok(filtered) => filtered, - Err(PipelineError::ClientError { - agent, - status, - body, - }) => { - warn!( - agent = %agent, - status = %status, - body = %body, - "output filter client error, passing through original chunk" - ); - chunk + .await; + return; + } + }; + + let input_tx = pipeline.input_tx; + let mut output_rx = pipeline.output_rx; + let _handles = pipeline.handles; + let mut is_first_chunk = true; + + // Writer: LLM chunks → pipeline input + let writer = async { + while let Some(item) = byte_stream.next().await { + match item { + Ok(chunk) => { + if input_tx.send(chunk).await.is_err() { + debug!("streaming pipeline input closed"); + break; + } + } + Err(err) => { + warn!(error = 
%format!("{err:?}"), "LLM stream error"); + break; + } } - Err(e) => { - warn!(error = %e, "output filter error, passing through original chunk"); - chunk + } + drop(input_tx); + }; + + // Reader: pipeline output → client + let reader = async { + while let Some(processed) = output_rx.recv().await { + if is_first_chunk { + inner_processor.on_first_bytes(); + is_first_chunk = false; } - }; - // Pass through inner processor for metrics/observability - match inner_processor.process_chunk(processed_chunk) { - Ok(Some(final_chunk)) => { - if tx.send(final_chunk).await.is_err() { - warn!("receiver dropped"); + match inner_processor.process_chunk(processed) { + Ok(Some(final_chunk)) => { + if tx.send(final_chunk).await.is_err() { + warn!("client receiver dropped"); + break; + } + } + Ok(None) => continue, + Err(err) => { + warn!("processor error: {}", err); + inner_processor.on_error(&err); break; } } - Ok(None) => continue, - Err(err) => { - warn!("processor error: {}", err); - inner_processor.on_error(&err); - break; - } } - } + }; + tokio::join!(writer, reader); inner_processor.on_complete(); - debug!("output filter streaming completed"); + debug!("streaming output filter pipeline completed"); } .instrument(current_span), - ); + ) +} - let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk))); - let stream_body = BoxBody::new(StreamBody::new(stream)); +/// Per-chunk path: one HTTP request per chunk per filter (original behavior). 
+fn spawn_per_chunk_output_filter( + byte_stream: S, + inner_processor: P, + output_chain: ResolvedFilterChain, + request_headers: HeaderMap, + request_path: String, + tx: mpsc::Sender, + current_span: tracing::Span, +) -> tokio::task::JoinHandle<()> +where + S: StreamExt> + Send + Unpin + 'static, + P: StreamProcessor, +{ + tokio::spawn( + async move { + run_per_chunk_loop( + byte_stream, + inner_processor, + output_chain, + request_headers, + request_path, + tx, + ) + .await; + } + .instrument(current_span), + ) +} - StreamingResponse { - body: stream_body, - processor_handle, +async fn run_per_chunk_loop( + mut byte_stream: S, + mut inner_processor: P, + output_chain: ResolvedFilterChain, + request_headers: HeaderMap, + request_path: String, + tx: mpsc::Sender, +) where + S: StreamExt> + Send + Unpin + 'static, + P: StreamProcessor, +{ + let mut is_first_chunk = true; + let mut pipeline_processor = PipelineProcessor::default(); + let chain = output_chain.to_agent_filter_chain("output_filter"); + + while let Some(item) = byte_stream.next().await { + let chunk = match item { + Ok(chunk) => chunk, + Err(err) => { + let err_msg = format!("Error receiving chunk: {:?}", err); + warn!(error = %err_msg, "stream error"); + inner_processor.on_error(&err_msg); + break; + } + }; + + if is_first_chunk { + inner_processor.on_first_bytes(); + is_first_chunk = false; + } + + let processed_chunk = match pipeline_processor + .process_raw_filter_chain( + &chunk, + &chain, + &output_chain.agents, + &request_headers, + &request_path, + ) + .await + { + Ok(filtered) => filtered, + Err(PipelineError::ClientError { + agent, + status, + body, + }) => { + warn!( + agent = %agent, + status = %status, + body = %body, + "output filter client error, passing through original chunk" + ); + chunk + } + Err(e) => { + warn!(error = %e, "output filter error, passing through original chunk"); + chunk + } + }; + + match inner_processor.process_chunk(processed_chunk) { + Ok(Some(final_chunk)) => { + 
if tx.send(final_chunk).await.is_err() { + warn!("receiver dropped"); + break; + } + } + Ok(None) => continue, + Err(err) => { + warn!("processor error: {}", err); + inner_processor.on_error(&err); + break; + } + } } + + inner_processor.on_complete(); + debug!("output filter streaming completed"); } /// Truncates a message to the specified maximum length, adding "..." if truncated. diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index df1790594..85c4879d5 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -20,6 +20,7 @@ pub struct Agent { pub url: String, #[serde(rename = "type")] pub agent_type: Option<String>, + pub streaming: Option<bool>, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -43,6 +44,29 @@ impl ResolvedFilterChain { self.filter_ids.is_empty() } + /// True when every filter in the chain is an HTTP filter with `streaming: true`. + /// MCP filters and filters without the streaming flag use per-chunk mode. + pub fn all_support_streaming(&self) -> bool { + !self.filter_ids.is_empty() + && self.filter_ids.iter().all(|id| { + self.agents + .get(id) + .map(|a| { + a.streaming.unwrap_or(false) + && a.agent_type.as_deref().unwrap_or("mcp") != "mcp" + }) + .unwrap_or(false) + }) + } + + /// Returns references to the ordered agents for the streaming pipeline.
+ pub fn streaming_agents(&self) -> Vec<&Agent> { + self.filter_ids + .iter() + .filter_map(|id| self.agents.get(id)) + .collect() + } + pub fn to_agent_filter_chain(&self, id: &str) -> AgentFilterChain { AgentFilterChain { id: id.to_string(), @@ -542,7 +566,7 @@ mod test { use pretty_assertions::assert_eq; use std::fs; - use super::{IntoModels, LlmProvider, LlmProviderType}; + use super::{Agent, IntoModels, LlmProvider, LlmProviderType, ResolvedFilterChain}; use crate::api::open_ai::ToolType; #[test] @@ -663,4 +687,81 @@ mod test { assert!(!model_ids.contains(&"arch-router".to_string())); assert!(!model_ids.contains(&"plano-orchestrator".to_string())); } + + fn make_agent(id: &str, agent_type: Option<&str>, streaming: Option<bool>) -> Agent { + Agent { + id: id.to_string(), + url: format!("http://localhost:10501/{id}"), + agent_type: agent_type.map(String::from), + transport: None, + tool: None, + streaming, + } + } + + #[test] + fn test_all_support_streaming_all_http_streaming() { + let chain = ResolvedFilterChain { + filter_ids: vec!["a".into(), "b".into()], + agents: [ + ("a".into(), make_agent("a", Some("http"), Some(true))), + ("b".into(), make_agent("b", Some("http"), Some(true))), + ] + .into(), + }; + assert!(chain.all_support_streaming()); + } + + #[test] + fn test_all_support_streaming_one_missing_flag() { + let chain = ResolvedFilterChain { + filter_ids: vec!["a".into(), "b".into()], + agents: [ + ("a".into(), make_agent("a", Some("http"), Some(true))), + ("b".into(), make_agent("b", Some("http"), None)), + ] + .into(), + }; + assert!(!chain.all_support_streaming()); + } + + #[test] + fn test_all_support_streaming_mcp_filter() { + let chain = ResolvedFilterChain { + filter_ids: vec!["a".into()], + agents: [("a".into(), make_agent("a", Some("mcp"), Some(true)))].into(), + }; + assert!(!chain.all_support_streaming()); + } + + #[test] + fn test_all_support_streaming_default_type_is_mcp() { + let chain = ResolvedFilterChain { + filter_ids: vec!["a".into()], +
agents: [("a".into(), make_agent("a", None, Some(true)))].into(), + }; + assert!(!chain.all_support_streaming()); + } + + #[test] + fn test_all_support_streaming_empty_chain() { + let chain = ResolvedFilterChain::default(); + assert!(!chain.all_support_streaming()); + } + + #[test] + fn test_streaming_agents_ordered() { + let chain = ResolvedFilterChain { + filter_ids: vec!["b".into(), "a".into()], + agents: [ + ("a".into(), make_agent("a", Some("http"), Some(true))), + ("b".into(), make_agent("b", Some("http"), Some(true))), + ] + .into(), + }; + let agents = chain.streaming_agents(); + assert_eq!(agents.len(), 2); + assert_eq!(agents[0].id, "b"); + assert_eq!(agents[1].id, "a"); + } } diff --git a/demos/filter_chains/pii_anonymizer/config.yaml b/demos/filter_chains/pii_anonymizer/config.yaml index b183379f8..78cf87067 100644 --- a/demos/filter_chains/pii_anonymizer/config.yaml +++ b/demos/filter_chains/pii_anonymizer/config.yaml @@ -7,6 +7,7 @@ filters: - id: pii_deanonymizer url: http://localhost:10501/deanonymize type: http + streaming: true model_providers: - model: openai/gpt-4o-mini diff --git a/demos/filter_chains/pii_anonymizer/pii_anonymizer.py b/demos/filter_chains/pii_anonymizer/pii_anonymizer.py index 9c1450cc0..a0108dec2 100644 --- a/demos/filter_chains/pii_anonymizer/pii_anonymizer.py +++ b/demos/filter_chains/pii_anonymizer/pii_anonymizer.py @@ -21,10 +21,16 @@ from typing import Any, Dict from fastapi import FastAPI, Request -from fastapi.responses import Response +from fastapi.responses import Response, StreamingResponse from pii import anonymize_text, anonymize_message_content -from store import get_mapping, store_mapping, deanonymize_sse, deanonymize_json +from store import ( + get_mapping, + store_mapping, + deanonymize_sse, + deanonymize_sse_stream, + deanonymize_json, +) logging.basicConfig( level=logging.INFO, @@ -105,11 +111,36 @@ async def deanonymize(path: str, request: Request) -> Response: /deanonymize/v1/chat/completions — OpenAI chat 
completions /deanonymize/v1/messages — Anthropic messages /deanonymize/v1/responses — OpenAI responses API + + Supports two modes: + - Bidirectional streaming: request body is streamed (Content-Type: application/octet-stream). + Reads via request.stream(), processes SSE events incrementally, returns StreamingResponse. + - Per-chunk / full body: reads entire body, processes, returns complete Response. """ endpoint = f"/{path}" is_anthropic = endpoint == "/v1/messages" request_id = request.headers.get("x-request-id", "unknown") mapping = get_mapping(request_id) + + content_type = request.headers.get("content-type", "") + is_streaming = "application/octet-stream" in content_type + + if is_streaming: + if not mapping: + logger.info("request_id=%s streaming, no mapping — passthrough", request_id) + + async def passthrough(): + async for chunk in request.stream(): + yield chunk + + return StreamingResponse(passthrough(), media_type="text/event-stream") + + logger.info("request_id=%s streaming deanonymize", request_id) + return StreamingResponse( + deanonymize_sse_stream(request_id, request.stream(), mapping, is_anthropic), + media_type="text/event-stream", + ) + raw_body = await request.body() if not mapping: diff --git a/demos/filter_chains/pii_anonymizer/store.py b/demos/filter_chains/pii_anonymizer/store.py index 74cbd01cf..7247373b8 100644 --- a/demos/filter_chains/pii_anonymizer/store.py +++ b/demos/filter_chains/pii_anonymizer/store.py @@ -4,7 +4,7 @@ import logging import threading import time -from typing import Dict, Optional, Tuple +from typing import AsyncIterator, Dict, Optional, Tuple from fastapi.responses import Response @@ -59,36 +59,71 @@ def restore_streaming(request_id: str, content: str, mapping: Dict[str, str]) -> def deanonymize_sse( request_id: str, body_str: str, mapping: Dict[str, str], is_anthropic: bool ) -> Response: - result_lines = [] - for line in body_str.split("\n"): - stripped = line.strip() - if not (stripped.startswith("data: ") and 
stripped[6:] != "[DONE]"): - result_lines.append(line) - continue - try: - chunk = json.loads(stripped[6:]) - if is_anthropic: - # {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "..."}} - if chunk.get("type") == "content_block_delta": - delta = chunk.get("delta", {}) - if delta.get("type") == "text_delta" and delta.get("text"): - delta["text"] = restore_streaming( - request_id, delta["text"], mapping - ) - else: - # {"choices": [{"delta": {"content": "..."}}]} - for choice in chunk.get("choices", []): - delta = choice.get("delta", {}) - if delta.get("content"): - delta["content"] = restore_streaming( - request_id, delta["content"], mapping - ) - result_lines.append("data: " + json.dumps(chunk)) - except json.JSONDecodeError: - result_lines.append(line) + result_lines = [ + _process_sse_line(request_id, line, mapping, is_anthropic) + for line in body_str.split("\n") + ] return Response(content="\n".join(result_lines), media_type="text/plain") +def _process_sse_line( + request_id: str, line: str, mapping: Dict[str, str], is_anthropic: bool +) -> str: + """Process a single SSE line, restoring PII in data payloads.""" + stripped = line.strip() + if not (stripped.startswith("data: ") and stripped[6:] != "[DONE]"): + return line + try: + chunk = json.loads(stripped[6:]) + if is_anthropic: + if chunk.get("type") == "content_block_delta": + delta = chunk.get("delta", {}) + if delta.get("type") == "text_delta" and delta.get("text"): + delta["text"] = restore_streaming( + request_id, delta["text"], mapping + ) + else: + for choice in chunk.get("choices", []): + delta = choice.get("delta", {}) + if delta.get("content"): + delta["content"] = restore_streaming( + request_id, delta["content"], mapping + ) + return "data: " + json.dumps(chunk) + except json.JSONDecodeError: + return line + + +async def deanonymize_sse_stream( + request_id: str, + byte_stream: AsyncIterator[bytes], + mapping: Dict[str, str], + is_anthropic: bool, +): + """Async generator 
that reads SSE events from a streaming request body, + de-anonymizes them, and yields processed events as they become complete. + Buffers partial data and splits on SSE event boundaries (blank lines). + """ + buffer = "" + async for raw_chunk in byte_stream: + buffer += raw_chunk.decode("utf-8", errors="replace") + # Yield each complete SSE event (delimited by double newline) + while "\n\n" in buffer: + event, buffer = buffer.split("\n\n", 1) + processed_lines = [ + _process_sse_line(request_id, line, mapping, is_anthropic) + for line in event.split("\n") + ] + yield "\n".join(processed_lines) + "\n\n" + # Flush any trailing data + if buffer.strip(): + processed_lines = [ + _process_sse_line(request_id, line, mapping, is_anthropic) + for line in buffer.split("\n") + ] + yield "\n".join(processed_lines) + + def deanonymize_json( request_id: str, raw_body: bytes,