diff --git a/app/services/provider_service.py b/app/services/provider_service.py
index 1fd211e..b4d5235 100644
--- a/app/services/provider_service.py
+++ b/app/services/provider_service.py
@@ -673,12 +673,16 @@ async def process_request(
     # https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage
     if isinstance(result, dict) and "usage" in result:
         usage = result.get("usage", {})
-        input_tokens = usage.get("prompt_tokens", 0)
-        output_tokens = usage.get("completion_tokens", 0)
+        input_tokens = usage.get("prompt_tokens", 0) or input_tokens
+        output_tokens = usage.get("completion_tokens", 0) or output_tokens
+        total_tokens = usage.get("total_tokens", 0) or (input_tokens + output_tokens)
         prompt_tokens_details = usage.get("prompt_tokens_details", {}) or {}
         completion_tokens_details = usage.get("completion_tokens_details", {}) or {}
         cached_tokens = prompt_tokens_details.get("cached_tokens", 0)
-        reasoning_tokens = completion_tokens_details.get("reasoning_tokens", 0)
+        reasoning_tokens = completion_tokens_details.get("reasoning_tokens", 0) or (total_tokens - input_tokens - output_tokens)
+
+        # re-calculate output tokens
+        output_tokens = max(output_tokens, total_tokens - input_tokens)
 
         asyncio.create_task(
             update_usage_in_background(
@@ -697,6 +701,7 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
         approximate_output_tokens = 0
         output_tokens = 0
         input_tokens = 0
+        total_tokens = 0
         cached_tokens = 0
         reasoning_tokens = 0
         chunks_processed = 0
@@ -744,14 +749,18 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
                         usage = data.get("usage", {})
                         input_tokens = (
                             usage.get("prompt_tokens", 0) or 0
-                        )
+                        ) or input_tokens
                         output_tokens = (
                             usage.get("completion_tokens", 0) or 0
-                        )
+                        ) or output_tokens
+                        total_tokens = usage.get("total_tokens", 0) or total_tokens or (input_tokens + output_tokens)
                         prompt_tokens_details = usage.get("prompt_tokens_details", {}) or {}
                         completion_tokens_details = usage.get("completion_tokens_details", {}) or {}
-                        cached_tokens = prompt_tokens_details.get("cached_tokens", 0)
-                        reasoning_tokens = completion_tokens_details.get("reasoning_tokens", 0)
+                        cached_tokens = prompt_tokens_details.get("cached_tokens", 0) or cached_tokens
+                        reasoning_tokens = completion_tokens_details.get("reasoning_tokens", 0) or reasoning_tokens or (total_tokens - input_tokens - output_tokens)
+
+                        # re-calculate output tokens
+                        output_tokens = max(output_tokens, total_tokens - input_tokens)
 
                         # Extract content from the chunk based on OpenAI format
                         if "choices" in data:
diff --git a/app/services/providers/anthropic_adapter.py b/app/services/providers/anthropic_adapter.py
index bc051ae..6577e3f 100644
--- a/app/services/providers/anthropic_adapter.py
+++ b/app/services/providers/anthropic_adapter.py
@@ -40,20 +40,21 @@ def provider_name(self) -> str:
         return self._provider_name
 
     @staticmethod
-    def format_anthropic_usage(usage_data: dict[str, Any]) -> dict[str, Any]:
+    def format_anthropic_usage(usage_data: dict[str, Any], token_usage: dict[str, int]) -> dict[str, Any]:
         if not usage_data:
             return None
 
-        input_tokens = usage_data.get("input_tokens", 0)
-        output_tokens = usage_data.get("output_tokens", 0)
-        cached_tokens = usage_data.get("cache_creation_input_tokens", 0) or 0
-        cached_tokens += usage_data.get("cache_read_input_tokens", 0) or 0
+        token_usage['input_tokens'] += usage_data.get("input_tokens", 0)
+        token_usage['output_tokens'] += usage_data.get("output_tokens", 0)
+        token_usage['cached_tokens'] += usage_data.get("cache_creation_input_tokens", 0) or 0
+        token_usage['cached_tokens'] += usage_data.get("cache_read_input_tokens", 0) or 0
+
         return {
-            "prompt_tokens": input_tokens,
-            "completion_tokens": output_tokens,
-            "total_tokens": input_tokens + output_tokens,
+            "prompt_tokens": token_usage['input_tokens'],
+            "completion_tokens": token_usage['output_tokens'],
+            "total_tokens": token_usage['input_tokens'] + token_usage['output_tokens'],
             "prompt_tokens_details": {
-                "cached_tokens": cached_tokens,
+                "cached_tokens": token_usage['cached_tokens'],
             },
         }
 
@@ -445,6 +446,11 @@ async def stream_anthropic_response(
     async def stream_response() -> AsyncGenerator[bytes, None]:
         # Store parts of usage info as they arrive
         request_id = f"chatcmpl-{uuid.uuid4()}"
+        token_usage = {
+            'input_tokens': 0,
+            "output_tokens": 0,
+            "cached_tokens": 0,
+        }
 
         async with (
             aiohttp.ClientSession() as session,
@@ -489,7 +495,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
                     # Capture Input Tokens from message_start
                     if event_type == "message_start":
                         message_data = data.get("message", {})
-                        usage_data = AnthropicAdapter.format_anthropic_usage(message_data.get("usage", {}))
+                        usage_data = AnthropicAdapter.format_anthropic_usage(message_data.get("usage", {}), token_usage)
                         if message_data:
                             openai_chunk = {
                                 "id": request_id,
@@ -512,7 +518,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
 
                         # Handle start of content blocks (text or tool_use)
                         content_block = data.get("content_block", {})
-                        usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}))
+                        usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}), token_usage)
                         if content_block.get("type") == "tool_use":
                             # Start of a tool call
                             openai_chunk = {
@@ -551,7 +557,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
 
                     elif event_type == "content_block_delta":
                         delta = data.get("delta", {})
-                        usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}))
+                        usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}), token_usage)
                         if delta.get("type") == "text_delta":
                             # Text content delta
                             delta_content = delta.get("text", "")
@@ -602,7 +608,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
 
                     elif event_type == "message_delta":
                         delta_data = data.get("delta", {})
-                        usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}))
+                        usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}), token_usage)
                         anthropic_stop_reason = delta_data.get("stop_reason")
 
                         if anthropic_stop_reason:
@@ -634,7 +640,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
                             anthropic_stop_reason, "stop"
                         )
 
-                        usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}))
+                        usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}), token_usage)
 
                     # --- Yielding Logic ---
                     if openai_chunk:
@@ -735,7 +741,12 @@ async def process_regular_response(
         )
 
         # https://docs.anthropic.com/en/api/messages#response-usage
-        usage_data = AnthropicAdapter.format_anthropic_usage(anthropic_response.get("usage", {}))
+        token_usage = {
+            'input_tokens': 0,
+            'output_tokens': 0,
+            'cached_tokens': 0,
+        }
+        usage_data = AnthropicAdapter.format_anthropic_usage(anthropic_response.get("usage", {}), token_usage)
         return {
             "id": completion_id,
             "object": "chat.completion",
diff --git a/app/services/providers/google_adapter.py b/app/services/providers/google_adapter.py
index 041534f..95bf7cc 100644
--- a/app/services/providers/google_adapter.py
+++ b/app/services/providers/google_adapter.py
@@ -344,9 +344,6 @@ async def _stream_google_response(
 
         # Process response in chunks
         buffer = ""
-        # According to https://ai.google.dev/api/generate-content#UsageMetadata
-        # thoughtsTokenCount is output only. We should only record it once
-        logged_thoughts_tokens = False
         async for chunk in response.content.iter_chunks():
             if not chunk[0]:  # Empty chunk
                 continue
@@ -379,10 +376,6 @@ async def _stream_google_response(
                         usage_data = self.format_google_usage(
                             json_obj["usageMetadata"]
                         )
 
-                        if logged_thoughts_tokens:
-                            del usage_data['completion_tokens_details']
-                        elif usage_data.get('completion_tokens_details', {}).get('reasoning_tokens'):
-                            logged_thoughts_tokens = True
 
                         if "candidates" in json_obj:
diff --git a/app/services/providers/usage_tracker_service.py b/app/services/providers/usage_tracker_service.py
index 40bde88..c2fb8a7 100644
--- a/app/services/providers/usage_tracker_service.py
+++ b/app/services/providers/usage_tracker_service.py
@@ -75,6 +75,7 @@ async def update_usage_tracker(
         usage_tracker.input_tokens = input_tokens
         usage_tracker.output_tokens = output_tokens
         usage_tracker.cached_tokens = cached_tokens
+        usage_tracker.reasoning_tokens = reasoning_tokens
         usage_tracker.updated_at = now
         usage_tracker.cost = price_info['total_cost']
         usage_tracker.currency = price_info['currency']
@@ -100,4 +101,4 @@ async def delete_usage_tracker_record(
         await db.commit()
     except Exception as e:
         await db.rollback()
-        logger.error(f"Failed to delete usage tracker record: {e}")
\ No newline at end of file
+        logger.error(f"Failed to delete usage tracker record: {e}")
diff --git a/tests/unit_tests/test_google_provider.py b/tests/unit_tests/test_google_provider.py
index 7383ae1..27e8b73 100644
--- a/tests/unit_tests/test_google_provider.py
+++ b/tests/unit_tests/test_google_provider.py
@@ -121,7 +121,7 @@ async def test_chat_completion_streaming(self):
            result,
            expected_model="models/gemini-1.5-pro-latest",
            expected_message=GOOGLE_STANDARD_CHAT_COMPLETION_RESPONSE,
-            expected_usage={"prompt_tokens": 12, "completion_tokens": 16},
+            expected_usage={"prompt_tokens": 6, "completion_tokens": 16},
        )
        assert mock_session.posted_json[0] == {
            "generationConfig": {
diff --git a/tests/unit_tests/utils/helpers.py b/tests/unit_tests/utils/helpers.py
index a84928a..a62554c 100644
--- a/tests/unit_tests/utils/helpers.py
+++ b/tests/unit_tests/utils/helpers.py
@@ -156,11 +156,11 @@ def process_openai_streaming_response(response: str, result: dict):
            total_tokens = usage.get("total_tokens", 0) or (prompt_tokens + completion_tokens)
            cached_tokens = usage.get("prompt_tokens_details", {}).get("cached_tokens", 0)
            reasoning_tokens = usage.get("completion_tokens_details", {}).get("reasoning_tokens", 0)
-            result_usage["prompt_tokens"] += prompt_tokens
-            result_usage["completion_tokens"] += completion_tokens
-            result_usage["total_tokens"] += total_tokens
-            result_usage["prompt_tokens_details"]["cached_tokens"] += cached_tokens
-            result_usage["completion_tokens_details"]["reasoning_tokens"] += reasoning_tokens
+            result_usage["prompt_tokens"] = prompt_tokens
+            result_usage["completion_tokens"] = completion_tokens
+            result_usage["total_tokens"] = total_tokens
+            result_usage["prompt_tokens_details"]["cached_tokens"] = cached_tokens
+            result_usage["completion_tokens_details"]["reasoning_tokens"] = reasoning_tokens
            result["usage"] = result_usage