From e3036ad2599afdba0247f78a79ff375e3dc9c4cf Mon Sep 17 00:00:00 2001 From: Wenjing Yu Date: Mon, 4 Aug 2025 14:40:10 -0700 Subject: [PATCH] support sse for vertex --- app/services/providers/vertex_adapter.py | 62 +++++++++++++++++++++--- 1 file changed, 55 insertions(+), 7 deletions(-) diff --git a/app/services/providers/vertex_adapter.py b/app/services/providers/vertex_adapter.py index 61c2bff..cb598c8 100644 --- a/app/services/providers/vertex_adapter.py +++ b/app/services/providers/vertex_adapter.py @@ -1,6 +1,7 @@ import asyncio import json import time +import uuid from collections.abc import AsyncGenerator from datetime import datetime, timezone from typing import Any @@ -236,15 +237,62 @@ def error_handler(error_text: str, http_status: int): # vertex doesn't do actual streaming, it just returns a stream of json objects url = f"{self._base_url}/v1/projects/{self.project_id}/locations/{self.location}/publishers/{self.publisher}/models/{model_name}:streamRawPredict" async def custom_stream_response(url, headers, anthropic_payload, model_name): + """Call Vertex streamRawPredict and convert the *single* SSE frame into OpenAI chunk format.""" + async def stream_response() -> AsyncGenerator[bytes, None]: - resp = await AnthropicAdapter.process_regular_response(url, headers, anthropic_payload, model_name, error_handler) - resp['object'] = 'chat.completion.chunk' - for choice in resp['choices']: - choice['delta'] = choice['message'] - del choice['message'] - yield f"data: {json.dumps(resp)}\n\n".encode() - yield b"data: [DONE]\n\n" + async with aiohttp.ClientSession() as session: + async with session.post(url, headers=headers, json=anthropic_payload) as response: + if response.status != 200: + error_text = await response.text() + error_handler(error_text, response.status) + + # Read the entire event-stream; Vertex currently responds with a few data: lines + body = await response.text() + + # Extract JSON payload(s) from SSE lines that start with "data: " + payloads: list[dict[str, Any]] = [] + for line in body.splitlines(): + line = line.strip() + if not line.startswith("data:"): + continue + data_part = line[len("data:"):].strip() + if data_part == "[DONE]": + continue + try: + payloads.append(json.loads(data_part)) + except json.JSONDecodeError: + continue + + if not payloads: + raise ProviderAPIException("Vertex", response.status, "Empty response from Vertex streamRawPredict") + + # Vertex typically returns a single JSON object – use the first + vertex_resp = payloads[0] + + # Convert to OpenAI chunk structure expected by Forge callers + openai_chunk = { + "id": vertex_resp.get("responseId", f"chatcmpl-{uuid.uuid4().hex}"), + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "index": 0, + "delta": { + "role": "assistant", + "content": vertex_resp.get("candidates", [{}])[0].get("content", "") + }, + "finish_reason": None, + } + ], + } + + # Yield the chunk then DONE, mimicking OpenAI stream format + yield f"data: {json.dumps(openai_chunk)}\n\n".encode() + yield b"data: [DONE]\n\n" + return stream_response() + return await custom_stream_response(url, headers, anthropic_payload, model_name) else: url = f"{self._base_url}/v1/projects/{self.project_id}/locations/{self.location}/publishers/{self.publisher}/models/{model_name}:rawPredict"