64 changes: 6 additions & 58 deletions app/services/providers/vertex_adapter.py
@@ -223,6 +223,8 @@ async def process_completion(self, endpoint: str, payload: dict[str, Any], api_k
         anthropic_payload["anthropic_version"] = "vertex-2023-10-16"
         del anthropic_payload["model"]
 
+        logger.debug(f"Vertex API request - model: {model_name}, streaming: {streaming}, publisher: {self.publisher}, location: {self.location}")
+
         def error_handler(error_text: str, http_status: int):
             try:
                 error_json = json.loads(error_text)
@@ -236,66 +238,12 @@ def error_handler(error_text: str, http_status: int):
             # https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.endpoints/streamRawPredict
             # vertex doesn't do actual streaming, it just returns a stream of json objects
             url = f"{self._base_url}/v1/projects/{self.project_id}/locations/{self.location}/publishers/{self.publisher}/models/{model_name}:streamRawPredict"
-            async def custom_stream_response(url, headers, anthropic_payload, model_name):
-                """Call Vertex streamRawPredict and convert the *single* SSE frame into OpenAI chunk format."""
-
-                async def stream_response() -> AsyncGenerator[bytes, None]:
-                    async with aiohttp.ClientSession() as session:
-                        async with session.post(url, headers=headers, json=anthropic_payload) as response:
-                            if response.status != 200:
-                                error_text = await response.text()
-                                error_handler(error_text, response.status)
-
-                            # Read the entire event-stream; Vertex currently responds with a few data: lines
-                            body = await response.text()
-
-                            # Extract JSON payload(s) from SSE lines that start with "data: "
-                            payloads: list[dict[str, Any]] = []
-                            for line in body.splitlines():
-                                line = line.strip()
-                                if not line.startswith("data:"):
-                                    continue
-                                data_part = line[len("data:"):].strip()
-                                if data_part == "[DONE]":
-                                    continue
-                                try:
-                                    payloads.append(json.loads(data_part))
-                                except json.JSONDecodeError:
-                                    continue
-
-                            if not payloads:
-                                raise ProviderAPIException("Vertex", response.status, "Empty response from Vertex streamRawPredict")
-
-                            # Vertex typically returns a single JSON object – use the first
-                            vertex_resp = payloads[0]
-
-                            # Convert to OpenAI chunk structure expected by Forge callers
-                            openai_chunk = {
-                                "id": vertex_resp.get("responseId", f"chatcmpl-{uuid.uuid4().hex}"),
-                                "object": "chat.completion.chunk",
-                                "created": int(time.time()),
-                                "model": model_name,
-                                "choices": [
-                                    {
-                                        "index": 0,
-                                        "delta": {
-                                            "role": "assistant",
-                                            "content": vertex_resp.get("candidates", [{}])[0].get("content", "")
-                                        },
-                                        "finish_reason": None,
-                                    }
-                                ],
-                            }
-
-                            # Yield the chunk then DONE, mimicking OpenAI stream format
-                            yield f"data: {json.dumps(openai_chunk)}\n\n".encode()
-                            yield b"data: [DONE]\n\n"
-
-                return stream_response()
-
-            return await custom_stream_response(url, headers, anthropic_payload, model_name)
+            logger.debug(f"Vertex streaming URL: {url}")
+            # Use the same streaming response handling as Anthropic adapter
+            return await AnthropicAdapter.stream_anthropic_response(url, headers, anthropic_payload, model_name, error_handler)
         else:
             url = f"{self._base_url}/v1/projects/{self.project_id}/locations/{self.location}/publishers/{self.publisher}/models/{model_name}:rawPredict"
+            logger.debug(f"Vertex non-streaming URL: {url}")
             return await AnthropicAdapter.process_regular_response(url, headers, anthropic_payload, model_name, error_handler)
 
     async def process_embeddings(self, payload: dict[str, Any]) -> Any:
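Note: AnthropicAdapter.stream_anthropic_response itself is outside this diff. As a rough sketch of why one handler can serve both providers: Vertex's streamRawPredict for Anthropic models emits the same Anthropic-style "data: {...}" SSE frames as the direct Anthropic API, so a generic relay along the following lines could cover both call sites. The signature is inferred from the calls above; the helper name stream_anthropic_sse and its body are illustrative assumptions, not the adapter's actual code.

# Illustrative sketch only. Assumes aiohttp and the call-site signature above;
# the real AnthropicAdapter.stream_anthropic_response may differ.
from collections.abc import AsyncGenerator, Callable
from typing import Any

import aiohttp


async def stream_anthropic_sse(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    model_name: str,  # unused in this sketch; the real adapter may tag chunks with it
    error_handler: Callable[[str, int], None],
) -> AsyncGenerator[bytes, None]:
    """Relay Anthropic-format SSE frames from `url` unchanged to the caller."""
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=payload) as response:
            if response.status != 200:
                # Let the provider-specific callback translate the error body.
                error_handler(await response.text(), response.status)
            # Pass the server-sent event bytes straight through; both the direct
            # Anthropic API and Vertex streamRawPredict emit "data: {...}" frames.
            async for chunk in response.content.iter_any():
                yield chunk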