Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions app/services/providers/google_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ async def process_completion(
async def _stream_google_response(
self, api_key: str, model: str, google_payload: dict[str, Any]
) -> AsyncGenerator[bytes, None]:
# https://ai.google.dev/api/generate-content#method:-models.streamgeneratecontent
model_path = model if model.startswith("models/") else f"models/{model}"
url = f"{self._base_url}/{model_path}:streamGenerateContent"

Expand All @@ -325,9 +326,6 @@ async def _stream_google_response(
error=ValueError(error_text)
)
headers = {"Content-Type": "application/json", "Accept": "application/json"}
logger.debug(
f"Google API request - URL: {url}, Payload sample: {str(google_payload)[:200]}..."
)

async with (
aiohttp.ClientSession() as session,
Expand All @@ -345,8 +343,10 @@ async def _stream_google_response(
)

# Process response in chunks
# https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse
buffer = ""
# According to https://ai.google.dev/api/generate-content#UsageMetadata
# thoughtsTokenCount is output only. We should only record it once
logged_thoughts_tokens = False
async for chunk in response.content.iter_chunks():
if not chunk[0]: # Empty chunk
continue
Expand Down Expand Up @@ -379,6 +379,11 @@ async def _stream_google_response(
usage_data = self.format_google_usage(
json_obj["usageMetadata"]
)
if logged_thoughts_tokens:
del usage_data['completion_tokens_details']
elif usage_data.get('completion_tokens_details', {}).get('reasoning_tokens'):
logged_thoughts_tokens = True


if "candidates" in json_obj:
choices = []
Expand Down Expand Up @@ -431,6 +436,7 @@ async def _stream_google_response(

@staticmethod
def format_google_usage(metadata: dict) -> dict:
# https://ai.google.dev/api/generate-content#UsageMetadata
"""Format Google usage metadata to OpenAI format"""
if not metadata:
return None
Expand All @@ -457,7 +463,7 @@ async def convert_openai_completion_payload_to_google(
"stopSequences": payload.get("stop", []),
"temperature": payload.get("temperature", 0.7),
"topP": payload.get("top_p", 0.95),
"maxOutputTokens": payload.get("max_completion_tokens", payload.get("max_tokens", 2048)),
"maxOutputTokens": payload.get("max_completion_tokens", payload.get("max_tokens")),
},
}

Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/test_google_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def test_chat_completion(self):
"generationConfig": {
"temperature": 0.7,
"topP": 0.95,
"maxOutputTokens": 2048,
"maxOutputTokens": None,
"stopSequences": [],
},
"contents": [
Expand Down Expand Up @@ -127,7 +127,7 @@ async def test_chat_completion_streaming(self):
"generationConfig": {
"temperature": 0.7,
"topP": 0.95,
"maxOutputTokens": 2048,
"maxOutputTokens": None,
"stopSequences": [],
},
"contents": [
Expand Down
Loading