23 changes: 16 additions & 7 deletions app/services/provider_service.py
@@ -673,12 +673,16 @@ async def process_request(
             # https://platform.openai.com/docs/api-reference/chat/object#chat/object-usage
             if isinstance(result, dict) and "usage" in result:
                 usage = result.get("usage", {})
-                input_tokens = usage.get("prompt_tokens", 0)
-                output_tokens = usage.get("completion_tokens", 0)
+                input_tokens = usage.get("prompt_tokens", 0) or input_tokens
+                output_tokens = usage.get("completion_tokens", 0) or output_tokens
+                total_tokens = usage.get("total_tokens", 0) or (input_tokens + output_tokens)
                 prompt_tokens_details = usage.get("prompt_tokens_details", {}) or {}
                 completion_tokens_details = usage.get("completion_tokens_details", {}) or {}
                 cached_tokens = prompt_tokens_details.get("cached_tokens", 0)
-                reasoning_tokens = completion_tokens_details.get("reasoning_tokens", 0)
+                reasoning_tokens = completion_tokens_details.get("reasoning_tokens", 0) or (total_tokens - input_tokens - output_tokens)
+
+                # re-calculate output tokens
+                output_tokens = max(output_tokens, total_tokens - input_tokens)
 
                 asyncio.create_task(
                     update_usage_in_background(
@@ -697,6 +701,7 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
         approximate_output_tokens = 0
         output_tokens = 0
         input_tokens = 0
+        total_tokens = 0
         cached_tokens = 0
         reasoning_tokens = 0
         chunks_processed = 0
@@ -744,14 +749,18 @@ async def token_counting_stream() -> AsyncGenerator[bytes, None]:
                                 usage = data.get("usage", {})
                                 input_tokens = (
                                     usage.get("prompt_tokens", 0) or 0
-                                )
+                                ) or input_tokens
                                 output_tokens = (
                                     usage.get("completion_tokens", 0) or 0
-                                )
+                                ) or output_tokens
+                                total_tokens = usage.get("total_tokens", 0) or total_tokens or (input_tokens + output_tokens)
                                 prompt_tokens_details = usage.get("prompt_tokens_details", {}) or {}
                                 completion_tokens_details = usage.get("completion_tokens_details", {}) or {}
-                                cached_tokens = prompt_tokens_details.get("cached_tokens", 0)
-                                reasoning_tokens = completion_tokens_details.get("reasoning_tokens", 0)
+                                cached_tokens = prompt_tokens_details.get("cached_tokens", 0) or cached_tokens
+                                reasoning_tokens = completion_tokens_details.get("reasoning_tokens", 0) or reasoning_tokens or (total_tokens - input_tokens - output_tokens)
+
+                                # re-calculate output tokens
+                                output_tokens = max(output_tokens, total_tokens - input_tokens)
 
                                 # Extract content from the chunk based on OpenAI format
                                 if "choices" in data:
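
Both the non-streaming and streaming branches above now apply the same merge rule: trust the provider's usage fields when present, fall back to previously seen values, and derive anything missing from the totals. A minimal standalone sketch of that rule (the function name and zeroed defaults are illustrative, not part of this PR):

def merge_usage(usage: dict, input_tokens: int = 0, output_tokens: int = 0,
                total_tokens: int = 0) -> tuple[int, int, int]:
    # Keep previously seen counts when the provider omits or zeroes a field.
    input_tokens = usage.get("prompt_tokens", 0) or input_tokens
    output_tokens = usage.get("completion_tokens", 0) or output_tokens
    total_tokens = usage.get("total_tokens", 0) or total_tokens or (input_tokens + output_tokens)
    details = usage.get("completion_tokens_details", {}) or {}
    # If reasoning tokens are not reported explicitly, infer them from the
    # gap between the total and the visible input/output counts.
    reasoning_tokens = details.get("reasoning_tokens", 0) or (total_tokens - input_tokens - output_tokens)
    # Reasoning tokens may be billed in total_tokens but absent from
    # completion_tokens, so output is recomputed as total minus input.
    output_tokens = max(output_tokens, total_tokens - input_tokens)
    return input_tokens, output_tokens, reasoning_tokens

merge_usage({"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 40})
# -> (10, 30, 10): the 10-token gap is treated as reasoning and folded into output.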
41 changes: 26 additions & 15 deletions app/services/providers/anthropic_adapter.py
@@ -40,20 +40,21 @@ def provider_name(self) -> str:
         return self._provider_name
 
     @staticmethod
-    def format_anthropic_usage(usage_data: dict[str, Any]) -> dict[str, Any]:
+    def format_anthropic_usage(usage_data: dict[str, Any], token_usage: dict[str, int]) -> dict[str, Any]:
         if not usage_data:
             return None
 
-        input_tokens = usage_data.get("input_tokens", 0)
-        output_tokens = usage_data.get("output_tokens", 0)
-        cached_tokens = usage_data.get("cache_creation_input_tokens", 0) or 0
-        cached_tokens += usage_data.get("cache_read_input_tokens", 0) or 0
+        token_usage['input_tokens'] += usage_data.get("input_tokens", 0)
+        token_usage['output_tokens'] += usage_data.get("output_tokens", 0)
+        token_usage['cached_tokens'] += usage_data.get("cache_creation_input_tokens", 0) or 0
+        token_usage['cached_tokens'] += usage_data.get("cache_read_input_tokens", 0) or 0
 
         return {
-            "prompt_tokens": input_tokens,
-            "completion_tokens": output_tokens,
-            "total_tokens": input_tokens + output_tokens,
+            "prompt_tokens": token_usage['input_tokens'],
+            "completion_tokens": token_usage['output_tokens'],
+            "total_tokens": token_usage['input_tokens'] + token_usage['output_tokens'],
             "prompt_tokens_details": {
-                "cached_tokens": cached_tokens,
+                "cached_tokens": token_usage['cached_tokens'],
             },
         }

@@ -445,6 +446,11 @@ async def stream_anthropic_response(
     async def stream_response() -> AsyncGenerator[bytes, None]:
         # Store parts of usage info as they arrive
         request_id = f"chatcmpl-{uuid.uuid4()}"
+        token_usage = {
+            'input_tokens': 0,
+            "output_tokens": 0,
+            "cached_tokens": 0,
+        }
 
         async with (
             aiohttp.ClientSession() as session,
@@ -489,7 +495,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
                         # Capture Input Tokens from message_start
                         if event_type == "message_start":
                             message_data = data.get("message", {})
-                            usage_data = AnthropicAdapter.format_anthropic_usage(message_data.get("usage", {}))
+                            usage_data = AnthropicAdapter.format_anthropic_usage(message_data.get("usage", {}), token_usage)
                             if message_data:
                                 openai_chunk = {
                                     "id": request_id,
@@ -512,7 +518,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
                             # Handle start of content blocks (text or tool_use)
                             content_block = data.get("content_block", {})
 
-                            usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}))
+                            usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}), token_usage)
                             if content_block.get("type") == "tool_use":
                                 # Start of a tool call
                                 openai_chunk = {
@@ -551,7 +557,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
elif event_type == "content_block_delta":
delta = data.get("delta", {})

usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}))
usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}), token_usage)
if delta.get("type") == "text_delta":
# Text content delta
delta_content = delta.get("text", "")
@@ -602,7 +608,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
elif event_type == "message_delta":
delta_data = data.get("delta", {})

usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}))
usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}), token_usage)

anthropic_stop_reason = delta_data.get("stop_reason")
if anthropic_stop_reason:
@@ -634,7 +640,7 @@ async def stream_response() -> AsyncGenerator[bytes, None]:
anthropic_stop_reason, "stop"
)

usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}))
usage_data = AnthropicAdapter.format_anthropic_usage(data.get("usage", {}), token_usage)

# --- Yielding Logic ---
if openai_chunk:
@@ -735,7 +741,12 @@ async def process_regular_response(
         )
 
         # https://docs.anthropic.com/en/api/messages#response-usage
-        usage_data = AnthropicAdapter.format_anthropic_usage(anthropic_response.get("usage", {}))
+        token_usage = {
+            'input_tokens': 0,
+            'output_tokens': 0,
+            'cached_tokens': 0,
+        }
+        usage_data = AnthropicAdapter.format_anthropic_usage(anthropic_response.get("usage", {}), token_usage)
         return {
             "id": completion_id,
             "object": "chat.completion",
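
The reason format_anthropic_usage now threads a mutable token_usage dict through every call: Anthropic streams usage piecemeal, with input and cache tokens arriving on message_start and output tokens on later message_delta events, so a stateless formatter would only ever reflect the latest event. A self-contained sketch of the accumulation, with abbreviated event payloads (the helper mirrors the static method above):

from typing import Any

def accumulate_usage(usage_data: dict[str, Any], token_usage: dict[str, int]) -> dict[str, Any] | None:
    # Each stream event contributes partial counts to the running totals.
    if not usage_data:
        return None
    token_usage["input_tokens"] += usage_data.get("input_tokens", 0) or 0
    token_usage["output_tokens"] += usage_data.get("output_tokens", 0) or 0
    # Cache writes and cache reads are folded into one cached-token figure.
    token_usage["cached_tokens"] += usage_data.get("cache_creation_input_tokens", 0) or 0
    token_usage["cached_tokens"] += usage_data.get("cache_read_input_tokens", 0) or 0
    return {
        "prompt_tokens": token_usage["input_tokens"],
        "completion_tokens": token_usage["output_tokens"],
        "total_tokens": token_usage["input_tokens"] + token_usage["output_tokens"],
        "prompt_tokens_details": {"cached_tokens": token_usage["cached_tokens"]},
    }

totals = {"input_tokens": 0, "output_tokens": 0, "cached_tokens": 0}
accumulate_usage({"input_tokens": 12, "cache_read_input_tokens": 4}, totals)  # message_start
usage = accumulate_usage({"output_tokens": 57}, totals)                       # message_delta
assert usage == {"prompt_tokens": 12, "completion_tokens": 57, "total_tokens": 69,
                 "prompt_tokens_details": {"cached_tokens": 4}}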
7 changes: 0 additions & 7 deletions app/services/providers/google_adapter.py
@@ -344,9 +344,6 @@ async def _stream_google_response(

         # Process response in chunks
         buffer = ""
-        # According to https://ai.google.dev/api/generate-content#UsageMetadata
-        # thoughtsTokenCount is output only. We should only record it once
-        logged_thoughts_tokens = False
         async for chunk in response.content.iter_chunks():
             if not chunk[0]:  # Empty chunk
                 continue
@@ -379,10 +376,6 @@ async def _stream_google_response(
                         usage_data = self.format_google_usage(
                             json_obj["usageMetadata"]
                         )
-                        if logged_thoughts_tokens:
-                            del usage_data['completion_tokens_details']
-                        elif usage_data.get('completion_tokens_details', {}).get('reasoning_tokens'):
-                            logged_thoughts_tokens = True
 
 
                         if "candidates" in json_obj:
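
The once-per-stream flag is removable because usageMetadata in a Google streaming response reports cumulative totals, so thoughtsTokenCount can be forwarded on every chunk as long as the consumer overwrites rather than sums (which the helper change further down now does). A sketch of the field mapping, using a stand-in for format_google_usage since its body is not shown in this diff:

from typing import Any

def to_openai_usage(usage_metadata: dict[str, Any]) -> dict[str, Any]:
    # Stand-in converter: maps Google's usageMetadata field names onto the
    # OpenAI-style usage dict the rest of the service expects.
    return {
        "prompt_tokens": usage_metadata.get("promptTokenCount", 0),
        "completion_tokens": usage_metadata.get("candidatesTokenCount", 0),
        "total_tokens": usage_metadata.get("totalTokenCount", 0),
        "completion_tokens_details": {
            "reasoning_tokens": usage_metadata.get("thoughtsTokenCount", 0),
        },
    }

# Illustrative cumulative snapshot: prompt and thought counts repeat verbatim
# in every chunk, so downstream code must not add them together.
to_openai_usage({"promptTokenCount": 6, "candidatesTokenCount": 16,
                 "totalTokenCount": 27, "thoughtsTokenCount": 5})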
3 changes: 2 additions & 1 deletion app/services/providers/usage_tracker_service.py
@@ -75,6 +75,7 @@ async def update_usage_tracker(
         usage_tracker.input_tokens = input_tokens
         usage_tracker.output_tokens = output_tokens
         usage_tracker.cached_tokens = cached_tokens
+        usage_tracker.reasoning_tokens = reasoning_tokens
         usage_tracker.updated_at = now
         usage_tracker.cost = price_info['total_cost']
         usage_tracker.currency = price_info['currency']
@@ -100,4 +101,4 @@ async def delete_usage_tracker_record(
             await db.commit()
         except Exception as e:
             await db.rollback()
-            logger.error(f"Failed to delete usage tracker record: {e}")
\ No newline at end of file
+            logger.error(f"Failed to delete usage tracker record: {e}")
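
With the new assignment, a tracker row now persists reasoning tokens alongside the other counts. A sketch of the persisted shape (only the field names that appear in the assignments above are taken from the code; the dataclass itself and the defaults are assumed, the real model is likely an ORM class):

from dataclasses import dataclass

@dataclass
class UsageTrackerRow:
    input_tokens: int = 0
    output_tokens: int = 0
    cached_tokens: int = 0
    reasoning_tokens: int = 0  # newly persisted by this PR
    cost: float = 0.0
    currency: str = ""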
2 changes: 1 addition & 1 deletion tests/unit_tests/test_google_provider.py
@@ -121,7 +121,7 @@ async def test_chat_completion_streaming(self):
             result,
             expected_model="models/gemini-1.5-pro-latest",
             expected_message=GOOGLE_STANDARD_CHAT_COMPLETION_RESPONSE,
-            expected_usage={"prompt_tokens": 12, "completion_tokens": 16},
+            expected_usage={"prompt_tokens": 6, "completion_tokens": 16},
         )
         assert mock_session.posted_json[0] == {
             "generationConfig": {
10 changes: 5 additions & 5 deletions tests/unit_tests/utils/helpers.py
@@ -156,11 +156,11 @@ def process_openai_streaming_response(response: str, result: dict):
             total_tokens = usage.get("total_tokens", 0) or (prompt_tokens + completion_tokens)
             cached_tokens = usage.get("prompt_tokens_details", {}).get("cached_tokens", 0)
             reasoning_tokens = usage.get("completion_tokens_details", {}).get("reasoning_tokens", 0)
-            result_usage["prompt_tokens"] += prompt_tokens
-            result_usage["completion_tokens"] += completion_tokens
-            result_usage["total_tokens"] += total_tokens
-            result_usage["prompt_tokens_details"]["cached_tokens"] += cached_tokens
-            result_usage["completion_tokens_details"]["reasoning_tokens"] += reasoning_tokens
+            result_usage["prompt_tokens"] = prompt_tokens
+            result_usage["completion_tokens"] = completion_tokens
+            result_usage["total_tokens"] = total_tokens
+            result_usage["prompt_tokens_details"]["cached_tokens"] = cached_tokens
+            result_usage["completion_tokens_details"]["reasoning_tokens"] = reasoning_tokens
             result["usage"] = result_usage


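
The helper switch from += to = is the other half of the Google test fix above: when a stream repeats cumulative usage snapshots, summing double-counts, while assignment keeps the final snapshot a client is actually billed for. That is consistent with the expectation change from prompt_tokens 12 to 6. A compressed before/after with illustrative values:

chunk_usages = [{"prompt_tokens": 6}, {"prompt_tokens": 6}]

old_style = sum(u["prompt_tokens"] for u in chunk_usages)  # 12: cumulative snapshots double-counted
new_style = chunk_usages[-1]["prompt_tokens"]              # 6: last snapshot wins
assert (old_style, new_style) == (12, 6)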