From 7ab493e7ae4850a6d04ce758cefe3dd116b5f928 Mon Sep 17 00:00:00 2001 From: Naz Quadri Date: Fri, 19 Dec 2025 09:38:47 -0500 Subject: [PATCH] feat: add StreamChunk::Usage variant for streaming token usage Emits Usage chunk from both Anthropic and OpenAI streaming responses: - Anthropic: extracts usage from message_delta event - OpenAI: extracts usage from final chunk (requires stream_options.include_usage) - Usage is emitted immediately before Done chunk - Includes cache token support via prompt_tokens_details --- src/backends/anthropic.rs | 224 ++++++++++++++++++++--------- src/chat/mod.rs | 3 + src/providers/openai_compatible.rs | 80 +++++++++++ tests/test_backends.rs | 3 + 4 files changed, 241 insertions(+), 69 deletions(-) diff --git a/src/backends/anthropic.rs b/src/backends/anthropic.rs index 89fb600..afcbc39 100644 --- a/src/backends/anthropic.rs +++ b/src/backends/anthropic.rs @@ -175,6 +175,8 @@ struct AnthropicStreamResponse { content_block: Option, /// Delta for content_block_delta and message_delta events delta: Option, + /// Usage information (for message_delta events) + usage: Option, } /// Content block within an Anthropic streaming content_block_start event. @@ -821,8 +823,7 @@ fn create_anthropic_tool_stream( buffer.drain(..pos + 2); match parse_anthropic_sse_chunk_with_tools(&event, tool_states) { - Ok(Some(chunk)) => results.push(Ok(chunk)), - Ok(None) => {} + Ok(chunks) => results.extend(chunks.into_iter().map(Ok)), Err(e) => results.push(Err(e)), } } @@ -986,13 +987,33 @@ struct ToolUseState { json_buffer: String, } +/// Converts AnthropicUsage to the common Usage type. +fn convert_anthropic_usage(anthropic_usage: &AnthropicUsage) -> Usage { + let cached_tokens = anthropic_usage.cache_creation_input_tokens.unwrap_or(0) + + anthropic_usage.cache_read_input_tokens.unwrap_or(0); + Usage { + prompt_tokens: anthropic_usage.input_tokens, + completion_tokens: anthropic_usage.output_tokens, + total_tokens: anthropic_usage.input_tokens + anthropic_usage.output_tokens, + completion_tokens_details: None, + prompt_tokens_details: if cached_tokens > 0 { + Some(crate::chat::PromptTokensDetails { + cached_tokens: Some(cached_tokens), + audio_tokens: None, + }) + } else { + None + }, + } +} + /// Parses Anthropic SSE chunks with tool use support. /// /// This parser handles all Anthropic streaming event types including: /// - `content_block_start` with `type: "text"` or `type: "tool_use"` /// - `content_block_delta` with `type: "text_delta"` or `type: "input_json_delta"` /// - `content_block_stop` -/// - `message_delta` with `stop_reason` +/// - `message_delta` with `stop_reason` and `usage` /// /// # Arguments /// @@ -1001,13 +1022,14 @@ struct ToolUseState { /// /// # Returns /// -/// * `Ok(Some(StreamChunk))` - A stream chunk if one was parsed -/// * `Ok(None)` - If chunk should be skipped +/// * `Ok(Vec)` - Zero or more stream chunks parsed from the event /// * `Err(LLMError)` - If parsing fails fn parse_anthropic_sse_chunk_with_tools( chunk: &str, tool_states: &mut HashMap, -) -> Result, LLMError> { +) -> Result, LLMError> { + let mut results = Vec::new(); + for line in chunk.lines() { let line = line.trim(); if let Some(data) = line.strip_prefix("data: ") { @@ -1032,7 +1054,7 @@ fn parse_anthropic_sse_chunk_with_tools( }, ); - return Ok(Some(StreamChunk::ToolUseStart { index, id, name })); + results.push(StreamChunk::ToolUseStart { index, id, name }); } // For text blocks, we just wait for content_block_delta } @@ -1043,7 +1065,7 @@ fn parse_anthropic_sse_chunk_with_tools( match delta.delta_type.as_deref() { Some("text_delta") => { if let Some(text) = delta.text { - return Ok(Some(StreamChunk::Text(text))); + results.push(StreamChunk::Text(text)); } } Some("input_json_delta") => { @@ -1052,10 +1074,10 @@ fn parse_anthropic_sse_chunk_with_tools( if let Some(state) = tool_states.get_mut(&index) { state.json_buffer.push_str(&partial_json); } - return Ok(Some(StreamChunk::ToolUseInputDelta { + results.push(StreamChunk::ToolUseInputDelta { index, partial_json, - })); + }); } } _ => {} @@ -1082,29 +1104,29 @@ fn parse_anthropic_sse_chunk_with_tools( arguments, }, }; - return Ok(Some(StreamChunk::ToolUseComplete { - index, - tool_call, - })); + results.push(StreamChunk::ToolUseComplete { index, tool_call }); } } } "message_delta" => { + // Emit Usage before Done if present + if let Some(ref usage) = response.usage { + results.push(StreamChunk::Usage(convert_anthropic_usage(usage))); + } if let Some(delta) = response.delta { if let Some(stop_reason) = delta.stop_reason { - return Ok(Some(StreamChunk::Done { stop_reason })); + results.push(StreamChunk::Done { stop_reason }); } } } _ => {} } - return Ok(None); } Err(_) => continue, } } } - Ok(None) + Ok(results) } #[cfg(test)] @@ -1118,11 +1140,12 @@ data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta" "#; let mut tool_states = HashMap::new(); - let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); + let results = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); - match result { - Some(StreamChunk::Text(text)) => assert_eq!(text, "Hello"), - _ => panic!("Expected Text chunk, got {:?}", result), + assert_eq!(results.len(), 1); + match &results[0] { + StreamChunk::Text(text) => assert_eq!(text, "Hello"), + _ => panic!("Expected Text chunk, got {:?}", results[0]), } } @@ -1133,15 +1156,16 @@ data: {"type": "content_block_start", "index": 1, "content_block": {"type": "too "#; let mut tool_states = HashMap::new(); - let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); + let results = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); - match result { - Some(StreamChunk::ToolUseStart { index, id, name }) => { - assert_eq!(index, 1); + assert_eq!(results.len(), 1); + match &results[0] { + StreamChunk::ToolUseStart { index, id, name } => { + assert_eq!(*index, 1); assert_eq!(id, "toolu_01ABC"); assert_eq!(name, "get_weather"); } - _ => panic!("Expected ToolUseStart chunk, got {:?}", result), + _ => panic!("Expected ToolUseStart chunk, got {:?}", results[0]), } // Verify state was stored @@ -1167,17 +1191,18 @@ data: {"type": "content_block_delta", "index": 1, "delta": {"type": "input_json_ }, ); - let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); + let results = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); - match result { - Some(StreamChunk::ToolUseInputDelta { + assert_eq!(results.len(), 1); + match &results[0] { + StreamChunk::ToolUseInputDelta { index, partial_json, - }) => { - assert_eq!(index, 1); + } => { + assert_eq!(*index, 1); assert_eq!(partial_json, "{\"location\":"); } - _ => panic!("Expected ToolUseInputDelta chunk, got {:?}", result), + _ => panic!("Expected ToolUseInputDelta chunk, got {:?}", results[0]), } // Verify JSON was accumulated @@ -1201,16 +1226,17 @@ data: {"type": "content_block_stop", "index": 1} }, ); - let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); + let results = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); - match result { - Some(StreamChunk::ToolUseComplete { index, tool_call }) => { - assert_eq!(index, 1); + assert_eq!(results.len(), 1); + match &results[0] { + StreamChunk::ToolUseComplete { index, tool_call } => { + assert_eq!(*index, 1); assert_eq!(tool_call.id, "toolu_01ABC"); assert_eq!(tool_call.function.name, "get_weather"); assert_eq!(tool_call.function.arguments, r#"{"location": "Paris"}"#); } - _ => panic!("Expected ToolUseComplete chunk, got {:?}", result), + _ => panic!("Expected ToolUseComplete chunk, got {:?}", results[0]), } // Verify state was removed @@ -1237,11 +1263,12 @@ data: {"type": "content_block_stop", "index": 1} }, ); - let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); + let results = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); - match result { - Some(StreamChunk::ToolUseComplete { index, tool_call }) => { - assert_eq!(index, 1); + assert_eq!(results.len(), 1); + match &results[0] { + StreamChunk::ToolUseComplete { index, tool_call } => { + assert_eq!(*index, 1); assert_eq!(tool_call.id, "toolu_01XYZ"); assert_eq!(tool_call.function.name, "get_current_time"); // CRITICAL: arguments must be "{}" not "" for Anthropic API compatibility @@ -1250,7 +1277,7 @@ data: {"type": "content_block_stop", "index": 1} "Empty arguments should default to '{{}}' not empty string" ); } - _ => panic!("Expected ToolUseComplete chunk, got {:?}", result), + _ => panic!("Expected ToolUseComplete chunk, got {:?}", results[0]), } // Verify state was removed @@ -1264,13 +1291,14 @@ data: {"type": "message_delta", "delta": {"stop_reason": "tool_use"}} "#; let mut tool_states = HashMap::new(); - let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); + let results = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); - match result { - Some(StreamChunk::Done { stop_reason }) => { + assert_eq!(results.len(), 1); + match &results[0] { + StreamChunk::Done { stop_reason } => { assert_eq!(stop_reason, "tool_use"); } - _ => panic!("Expected Done chunk, got {:?}", result), + _ => panic!("Expected Done chunk, got {:?}", results[0]), } } @@ -1281,13 +1309,14 @@ data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}} "#; let mut tool_states = HashMap::new(); - let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); + let results = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); - match result { - Some(StreamChunk::Done { stop_reason }) => { + assert_eq!(results.len(), 1); + match &results[0] { + StreamChunk::Done { stop_reason } => { assert_eq!(stop_reason, "end_turn"); } - _ => panic!("Expected Done chunk, got {:?}", result), + _ => panic!("Expected Done chunk, got {:?}", results[0]), } } @@ -1300,8 +1329,9 @@ data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}} data: {"type": "content_block_start", "index": 1, "content_block": {"type": "tool_use", "id": "toolu_01ABC", "name": "get_weather", "input": {}}} "#; - let result = parse_anthropic_sse_chunk_with_tools(start_chunk, &mut tool_states).unwrap(); - assert!(matches!(result, Some(StreamChunk::ToolUseStart { .. }))); + let results = parse_anthropic_sse_chunk_with_tools(start_chunk, &mut tool_states).unwrap(); + assert_eq!(results.len(), 1); + assert!(matches!(&results[0], StreamChunk::ToolUseStart { .. })); // 2. Input JSON deltas let delta1 = r#"event: content_block_delta @@ -1324,10 +1354,11 @@ data: {"type": "content_block_delta", "index": 1, "delta": {"type": "input_json_ data: {"type": "content_block_stop", "index": 1} "#; - let result = parse_anthropic_sse_chunk_with_tools(stop_chunk, &mut tool_states).unwrap(); + let results = parse_anthropic_sse_chunk_with_tools(stop_chunk, &mut tool_states).unwrap(); - match result { - Some(StreamChunk::ToolUseComplete { tool_call, .. }) => { + assert_eq!(results.len(), 1); + match &results[0] { + StreamChunk::ToolUseComplete { tool_call, .. } => { assert_eq!(tool_call.function.arguments, "{\"location\": \"Paris\"}"); } _ => panic!("Expected ToolUseComplete"), @@ -1338,12 +1369,13 @@ data: {"type": "content_block_stop", "index": 1} data: {"type": "message_delta", "delta": {"stop_reason": "tool_use"}} "#; - let result = parse_anthropic_sse_chunk_with_tools(done_chunk, &mut tool_states).unwrap(); + let results = parse_anthropic_sse_chunk_with_tools(done_chunk, &mut tool_states).unwrap(); + assert_eq!(results.len(), 1); assert!(matches!( - result, - Some(StreamChunk::Done { + &results[0], + StreamChunk::Done { stop_reason - }) if stop_reason == "tool_use" + } if stop_reason == "tool_use" )); } @@ -1356,18 +1388,18 @@ data: {"type": "message_delta", "delta": {"stop_reason": "tool_use"}} data: {"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "I'll check the weather"}} "#; - let result = parse_anthropic_sse_chunk_with_tools(text_chunk, &mut tool_states).unwrap(); - assert!(matches!(result, Some(StreamChunk::Text(t)) if t == "I'll check the weather")); + let results = parse_anthropic_sse_chunk_with_tools(text_chunk, &mut tool_states).unwrap(); + assert_eq!(results.len(), 1); + assert!(matches!(&results[0], StreamChunk::Text(t) if t == "I'll check the weather")); // Then tool use let tool_start = r#"event: content_block_start data: {"type": "content_block_start", "index": 1, "content_block": {"type": "tool_use", "id": "toolu_01XYZ", "name": "weather", "input": {}}} "#; - let result = parse_anthropic_sse_chunk_with_tools(tool_start, &mut tool_states).unwrap(); - assert!( - matches!(result, Some(StreamChunk::ToolUseStart { name, .. }) if name == "weather") - ); + let results = parse_anthropic_sse_chunk_with_tools(tool_start, &mut tool_states).unwrap(); + assert_eq!(results.len(), 1); + assert!(matches!(&results[0], StreamChunk::ToolUseStart { name, .. } if name == "weather")); } #[test] @@ -1377,8 +1409,8 @@ data: {"type": "message_start", "message": {"id": "msg_123", "type": "message", "#; let mut tool_states = HashMap::new(); - let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); - assert!(result.is_none()); + let results = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); + assert!(results.is_empty()); } #[test] @@ -1388,7 +1420,61 @@ data: {"type": "ping"} "#; let mut tool_states = HashMap::new(); - let result = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); - assert!(result.is_none()); + let results = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); + assert!(results.is_empty()); + } + + #[test] + fn test_parse_stream_message_delta_with_usage() { + let chunk = r#"event: message_delta +data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"input_tokens": 25, "output_tokens": 150}} + +"#; + let mut tool_states = HashMap::new(); + let results = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); + + // Should emit Usage before Done + assert_eq!(results.len(), 2); + + match &results[0] { + StreamChunk::Usage(usage) => { + assert_eq!(usage.prompt_tokens, 25); + assert_eq!(usage.completion_tokens, 150); + assert_eq!(usage.total_tokens, 175); + } + _ => panic!("Expected Usage chunk first, got {:?}", results[0]), + } + + match &results[1] { + StreamChunk::Done { stop_reason } => { + assert_eq!(stop_reason, "end_turn"); + } + _ => panic!("Expected Done chunk second, got {:?}", results[1]), + } + } + + #[test] + fn test_parse_stream_message_delta_with_usage_and_cache() { + let chunk = r#"event: message_delta +data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"input_tokens": 100, "output_tokens": 50, "cache_read_input_tokens": 80}} + +"#; + let mut tool_states = HashMap::new(); + let results = parse_anthropic_sse_chunk_with_tools(chunk, &mut tool_states).unwrap(); + + assert_eq!(results.len(), 2); + + match &results[0] { + StreamChunk::Usage(usage) => { + assert_eq!(usage.prompt_tokens, 100); + assert_eq!(usage.completion_tokens, 50); + assert_eq!(usage.total_tokens, 150); + // Cache tokens should be in prompt_tokens_details + assert!(usage.prompt_tokens_details.is_some()); + let details = usage.prompt_tokens_details.as_ref().unwrap(); + assert_eq!(details.cached_tokens, Some(80)); + } + _ => panic!("Expected Usage chunk, got {:?}", results[0]), + } } } diff --git a/src/chat/mod.rs b/src/chat/mod.rs index 331ca95..1a72711 100644 --- a/src/chat/mod.rs +++ b/src/chat/mod.rs @@ -103,6 +103,9 @@ pub enum StreamChunk { /// The reason the stream stopped (e.g., "end_turn", "tool_use") stop_reason: String, }, + + /// Token usage information (emitted before Done) + Usage(Usage), } /// Breakdown of completion tokens. diff --git a/src/providers/openai_compatible.rs b/src/providers/openai_compatible.rs index 8aac07e..afd87de 100644 --- a/src/providers/openai_compatible.rs +++ b/src/providers/openai_compatible.rs @@ -865,6 +865,11 @@ fn parse_openai_sse_chunk_with_tools( } } + // Emit Usage before Done if present + if let Some(ref usage) = chunk.usage { + results.push(ChatStreamChunk::Usage(usage.clone())); + } + let stop_reason = match finish_reason.as_str() { "tool_calls" => "tool_use", "stop" => "end_turn", @@ -875,6 +880,14 @@ fn parse_openai_sse_chunk_with_tools( }); } } + + // Handle usage in final chunk (may have empty choices) + // This handles the case where OpenAI sends usage in a separate final chunk + if chunk.choices.is_empty() { + if let Some(ref usage) = chunk.usage { + results.push(ChatStreamChunk::Usage(usage.clone())); + } + } } } } @@ -886,6 +899,8 @@ fn parse_openai_sse_chunk_with_tools( #[derive(Debug, Deserialize)] struct OpenAIToolStreamChunk { choices: Vec, + /// Usage information (present in final chunk when stream_options.include_usage is true) + usage: Option, } #[derive(Debug, Deserialize)] @@ -1417,4 +1432,69 @@ mod tests { results[0] ); } + + #[test] + fn test_parse_openai_stream_with_usage() { + // Test usage in final chunk with finish_reason + let event = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":20,"total_tokens":30}}"#; + let mut tool_states = HashMap::new(); + let results = parse_openai_sse_chunk_with_tools(event, &mut tool_states).unwrap(); + + // Should have Usage then Done + assert_eq!(results.len(), 2); + + match &results[0] { + ChatStreamChunk::Usage(usage) => { + assert_eq!(usage.prompt_tokens, 10); + assert_eq!(usage.completion_tokens, 20); + assert_eq!(usage.total_tokens, 30); + } + _ => panic!("Expected Usage chunk first, got {:?}", results[0]), + } + + match &results[1] { + ChatStreamChunk::Done { stop_reason } => { + assert_eq!(stop_reason, "end_turn"); + } + _ => panic!("Expected Done chunk second, got {:?}", results[1]), + } + } + + #[test] + fn test_parse_openai_stream_usage_separate_chunk() { + // OpenAI sometimes sends usage in a separate chunk with empty choices + let event = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[],"usage":{"prompt_tokens":25,"completion_tokens":100,"total_tokens":125}}"#; + let mut tool_states = HashMap::new(); + let results = parse_openai_sse_chunk_with_tools(event, &mut tool_states).unwrap(); + + assert_eq!(results.len(), 1); + match &results[0] { + ChatStreamChunk::Usage(usage) => { + assert_eq!(usage.prompt_tokens, 25); + assert_eq!(usage.completion_tokens, 100); + assert_eq!(usage.total_tokens, 125); + } + _ => panic!("Expected Usage chunk, got {:?}", results[0]), + } + } + + #[test] + fn test_parse_openai_stream_usage_with_details() { + // Test usage with prompt_tokens_details (cached_tokens) + let event = r#"data: {"id":"chatcmpl-123","object":"chat.completion.chunk","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150,"prompt_tokens_details":{"cached_tokens":80}}}"#; + let mut tool_states = HashMap::new(); + let results = parse_openai_sse_chunk_with_tools(event, &mut tool_states).unwrap(); + + assert_eq!(results.len(), 2); + match &results[0] { + ChatStreamChunk::Usage(usage) => { + assert_eq!(usage.prompt_tokens, 100); + assert_eq!(usage.completion_tokens, 50); + assert!(usage.prompt_tokens_details.is_some()); + let details = usage.prompt_tokens_details.as_ref().unwrap(); + assert_eq!(details.cached_tokens, Some(80)); + } + _ => panic!("Expected Usage chunk, got {:?}", results[0]), + } + } } diff --git a/tests/test_backends.rs b/tests/test_backends.rs index bf31329..5902b56 100644 --- a/tests/test_backends.rs +++ b/tests/test_backends.rs @@ -1043,6 +1043,9 @@ async fn test_anthropic_chat_stream_with_tools() { StreamChunk::ToolUseInputDelta { .. } => { // These are intermediate chunks, we don't need to collect them } + StreamChunk::Usage(_) => { + // Usage information, tracked separately if needed + } } } Err(e) => panic!("Stream error: {e}"),