diff --git a/app/core/openai/requests.py b/app/core/openai/requests.py index b36dfd2c..179a02b7 100644 --- a/app/core/openai/requests.py +++ b/app/core/openai/requests.py @@ -126,6 +126,10 @@ def _sanitize_input_items(input_items: list[JsonValue]) -> list[JsonValue]: return sanitized_input +def sanitize_input_items(input_items: list[JsonValue]) -> list[JsonValue]: + return _sanitize_input_items(input_items) + + def _sanitize_interleaved_reasoning_input_item(item: JsonValue) -> JsonValue | None: item_mapping = _json_mapping_or_none(item) if item_mapping is None: diff --git a/app/db/alembic/versions/20260327_000000_add_response_snapshots.py b/app/db/alembic/versions/20260327_000000_add_response_snapshots.py new file mode 100644 index 00000000..6d3a202c --- /dev/null +++ b/app/db/alembic/versions/20260327_000000_add_response_snapshots.py @@ -0,0 +1,75 @@ +"""add durable response snapshots + +Revision ID: 20260327_000000_add_response_snapshots +Revises: 20260321_210000_merge_request_log_tiers_and_dashboard_index_heads +Create Date: 2026-03-27 00:00:00.000000 +""" + +from __future__ import annotations + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.engine import Connection + +# revision identifiers, used by Alembic. 
+revision = "20260327_000000_add_response_snapshots" +down_revision = "20260321_210000_merge_request_log_tiers_and_dashboard_index_heads" +branch_labels = None +depends_on = None + + +def _table_exists(connection: Connection, table_name: str) -> bool: + inspector = sa.inspect(connection) + return inspector.has_table(table_name) + + +def _columns(connection: Connection, table_name: str) -> set[str]: + inspector = sa.inspect(connection) + if not inspector.has_table(table_name): + return set() + return {str(column["name"]) for column in inspector.get_columns(table_name) if column.get("name") is not None} + + +def _indexes(connection: Connection, table_name: str) -> set[str]: + inspector = sa.inspect(connection) + if not inspector.has_table(table_name): + return set() + return {str(index["name"]) for index in inspector.get_indexes(table_name) if index.get("name") is not None} + + +def upgrade() -> None: + bind = op.get_bind() + if not _table_exists(bind, "response_snapshots"): + op.create_table( + "response_snapshots", + sa.Column("response_id", sa.String(), nullable=False), + sa.Column("parent_response_id", sa.String(), nullable=True), + sa.Column("account_id", sa.String(), nullable=True), + sa.Column("api_key_id", sa.String(), nullable=True), + sa.Column("model", sa.String(), nullable=False), + sa.Column("input_items_json", sa.Text(), nullable=False), + sa.Column("response_json", sa.Text(), nullable=False), + sa.Column("created_at", sa.DateTime(), nullable=False, server_default=sa.func.now()), + sa.PrimaryKeyConstraint("response_id"), + ) + existing_columns = _columns(bind, "response_snapshots") + if "api_key_id" not in existing_columns: + op.add_column("response_snapshots", sa.Column("api_key_id", sa.String(), nullable=True)) + existing_indexes = _indexes(bind, "response_snapshots") + if "idx_response_snapshots_parent_created_at" not in existing_indexes: + op.create_index( + "idx_response_snapshots_parent_created_at", + "response_snapshots", + ["parent_response_id", 
"created_at"], + unique=False, + ) + + +def downgrade() -> None: + bind = op.get_bind() + if not _table_exists(bind, "response_snapshots"): + return + existing_indexes = _indexes(bind, "response_snapshots") + if "idx_response_snapshots_parent_created_at" in existing_indexes: + op.drop_index("idx_response_snapshots_parent_created_at", table_name="response_snapshots") + op.drop_table("response_snapshots") diff --git a/app/db/models.py b/app/db/models.py index 085a18f1..a5c9e2f5 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -131,6 +131,22 @@ class RequestLog(Base): error_message: Mapped[str | None] = mapped_column(Text, nullable=True) +class ResponseSnapshot(Base): + __tablename__ = "response_snapshots" + __table_args__ = ( + Index("idx_response_snapshots_parent_created_at", "parent_response_id", "created_at"), + ) + + response_id: Mapped[str] = mapped_column(String, primary_key=True) + parent_response_id: Mapped[str | None] = mapped_column(String, nullable=True) + account_id: Mapped[str | None] = mapped_column(String, nullable=True) + api_key_id: Mapped[str | None] = mapped_column(String, nullable=True) + model: Mapped[str] = mapped_column(String, nullable=False) + input_items_json: Mapped[str] = mapped_column(Text, nullable=False) + response_json: Mapped[str] = mapped_column(Text, nullable=False) + created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.now(), nullable=False) + + class StickySession(Base): __tablename__ = "sticky_sessions" diff --git a/app/dependencies.py b/app/dependencies.py index c25ef49c..8da59b45 100644 --- a/app/dependencies.py +++ b/app/dependencies.py @@ -20,6 +20,7 @@ from app.modules.firewall.service import FirewallService from app.modules.oauth.service import OauthService from app.modules.proxy.repo_bundle import ProxyRepositories +from app.modules.proxy.response_snapshots_repository import ResponseSnapshotsRepository from app.modules.proxy.service import ProxyService from app.modules.proxy.sticky_repository 
import StickySessionsRepository from app.modules.request_logs.repository import RequestLogsRepository @@ -153,6 +154,7 @@ async def _proxy_repo_context() -> AsyncIterator[ProxyRepositories]: sticky_sessions=StickySessionsRepository(session), api_keys=ApiKeysRepository(session), additional_usage=AdditionalUsageRepository(session), + response_snapshots=ResponseSnapshotsRepository(session), ) diff --git a/app/modules/proxy/load_balancer.py b/app/modules/proxy/load_balancer.py index 14eb158a..5696fd89 100644 --- a/app/modules/proxy/load_balancer.py +++ b/app/modules/proxy/load_balancer.py @@ -90,6 +90,7 @@ async def select_account( routing_strategy: RoutingStrategy = "usage_weighted", model: str | None = None, additional_limit_name: str | None = None, + preferred_account_id: str | None = None, exclude_account_ids: Collection[str] | None = None, ) -> AccountSelection: selection_inputs = await self._load_selection_inputs( @@ -134,6 +135,7 @@ async def select_account( sticky_max_age_seconds=sticky_max_age_seconds, prefer_earlier_reset_accounts=prefer_earlier_reset_accounts, routing_strategy=routing_strategy, + preferred_account_id=preferred_account_id, sticky_repo=repos.sticky_sessions, ) if result.account is not None: @@ -347,8 +349,20 @@ async def _select_with_stickiness( sticky_max_age_seconds: int | None, prefer_earlier_reset_accounts: bool, routing_strategy: RoutingStrategy, + preferred_account_id: str | None, sticky_repo: StickySessionsRepository | None, ) -> SelectionResult: + if preferred_account_id: + preferred_state = next((state for state in states if state.account_id == preferred_account_id), None) + if preferred_state is not None: + preferred_result = select_account( + [preferred_state], + prefer_earlier_reset=prefer_earlier_reset_accounts, + routing_strategy=routing_strategy, + allow_backoff_fallback=False, + ) + if preferred_result.account is not None: + return preferred_result if not sticky_key or not sticky_repo: return select_account( states, diff --git 
a/app/modules/proxy/repo_bundle.py b/app/modules/proxy/repo_bundle.py index afa6508f..4d102429 100644 --- a/app/modules/proxy/repo_bundle.py +++ b/app/modules/proxy/repo_bundle.py @@ -6,6 +6,7 @@ from app.modules.accounts.repository import AccountsRepository from app.modules.api_keys.repository import ApiKeysRepository +from app.modules.proxy.response_snapshots_repository import ResponseSnapshotsRepository from app.modules.proxy.sticky_repository import StickySessionsRepository from app.modules.request_logs.repository import RequestLogsRepository from app.modules.usage.repository import AdditionalUsageRepository, UsageRepository @@ -19,6 +20,7 @@ class ProxyRepositories: sticky_sessions: StickySessionsRepository api_keys: ApiKeysRepository additional_usage: AdditionalUsageRepository + response_snapshots: ResponseSnapshotsRepository | None = None ProxyRepoFactory = Callable[[], AsyncContextManager[ProxyRepositories]] diff --git a/app/modules/proxy/response_snapshots_repository.py b/app/modules/proxy/response_snapshots_repository.py new file mode 100644 index 00000000..49947c20 --- /dev/null +++ b/app/modules/proxy/response_snapshots_repository.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import json + +from sqlalchemy import select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.dialects.sqlite import insert as sqlite_insert +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql import Insert + +from app.core.types import JsonValue +from app.db.models import ResponseSnapshot + + +class ResponseSnapshotsRepository: + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def get(self, response_id: str, *, api_key_id: str | None) -> ResponseSnapshot | None: + if not response_id: + return None + statement = select(ResponseSnapshot).where(ResponseSnapshot.response_id == response_id) + if api_key_id is None: + statement = statement.where(ResponseSnapshot.api_key_id.is_(None)) + 
else: + statement = statement.where(ResponseSnapshot.api_key_id == api_key_id) + result = await self._session.execute(statement) + return result.scalar_one_or_none() + + async def upsert( + self, + *, + response_id: str, + parent_response_id: str | None, + account_id: str | None, + api_key_id: str | None, + model: str, + input_items: list[JsonValue], + response_payload: dict[str, JsonValue], + ) -> ResponseSnapshot: + statement = self._build_upsert_statement( + response_id=response_id, + parent_response_id=parent_response_id, + account_id=account_id, + api_key_id=api_key_id, + model=model, + input_items_json=json.dumps(input_items, ensure_ascii=False, separators=(",", ":")), + response_json=json.dumps(response_payload, ensure_ascii=False, separators=(",", ":")), + ) + await self._session.execute(statement) + await self._session.commit() + snapshot = await self.get(response_id, api_key_id=api_key_id) + if snapshot is None: + raise RuntimeError(f"ResponseSnapshot upsert failed for response_id={response_id!r}") + await self._session.refresh(snapshot) + return snapshot + + def _build_upsert_statement( + self, + *, + response_id: str, + parent_response_id: str | None, + account_id: str | None, + api_key_id: str | None, + model: str, + input_items_json: str, + response_json: str, + ) -> Insert: + dialect = self._session.get_bind().dialect.name + if dialect == "postgresql": + insert_fn = pg_insert + elif dialect == "sqlite": + insert_fn = sqlite_insert + else: + raise RuntimeError(f"ResponseSnapshot upsert unsupported for dialect={dialect!r}") + statement = insert_fn(ResponseSnapshot).values( + response_id=response_id, + parent_response_id=parent_response_id, + account_id=account_id, + api_key_id=api_key_id, + model=model, + input_items_json=input_items_json, + response_json=response_json, + ) + return statement.on_conflict_do_update( + index_elements=[ResponseSnapshot.response_id], + set_={ + "parent_response_id": parent_response_id, + "account_id": account_id, + 
"api_key_id": api_key_id, + "model": model, + "input_items_json": input_items_json, + "response_json": response_json, + }, + ) diff --git a/app/modules/proxy/service.py b/app/modules/proxy/service.py index 860ad70c..72fa276e 100644 --- a/app/modules/proxy/service.py +++ b/app/modules/proxy/service.py @@ -6,7 +6,7 @@ import logging import time from collections import deque -from collections.abc import Sequence +from collections.abc import Collection, Sequence from dataclasses import dataclass, field from hashlib import sha256 from typing import AsyncIterator, Mapping, NoReturn, cast @@ -51,8 +51,8 @@ from app.core.exceptions import AppError, ProxyAuthError, ProxyRateLimitError from app.core.openai.exceptions import ClientPayloadError from app.core.openai.models import CompactResponsePayload, OpenAIEvent, OpenAIResponsePayload -from app.core.openai.parsing import parse_sse_event -from app.core.openai.requests import ResponsesCompactRequest, ResponsesRequest +from app.core.openai.parsing import parse_response_payload, parse_sse_event +from app.core.openai.requests import ResponsesCompactRequest, ResponsesRequest, sanitize_input_items from app.core.types import JsonValue from app.core.usage.types import UsageWindowRow from app.core.utils.json_guards import is_json_mapping @@ -132,6 +132,14 @@ class _AffinityPolicy: max_age_seconds: int | None = None +@dataclass(frozen=True, slots=True) +class _ResolvedResponsesRequest: + payload: ResponsesRequest + current_input_items: list[JsonValue] + parent_response_id: str | None = None + preferred_account_id: str | None = None + + def _resolve_upstream_stream_transport(upstream_stream_transport: str) -> str | None: if upstream_stream_transport == "default": return None @@ -230,7 +238,42 @@ async def _stream_http_bridge_or_retry( yield line return - async for line in self._stream_via_http_bridge( + fallback_to_retry = False + bridge_yielded = False + try: + async for line in self._stream_via_http_bridge( + payload, + headers, + 
codex_session_affinity=codex_session_affinity, + propagate_http_errors=propagate_http_errors, + openai_cache_affinity=openai_cache_affinity, + api_key=api_key, + api_key_reservation=api_key_reservation, + suppress_text_done_events=suppress_text_done_events, + idle_ttl_seconds=getattr(settings, "http_responses_session_bridge_idle_ttl_seconds", 120.0), + codex_idle_ttl_seconds=getattr(settings, "http_responses_session_bridge_codex_idle_ttl_seconds", 900.0), + max_sessions=getattr(settings, "http_responses_session_bridge_max_sessions", 256), + queue_limit=getattr(settings, "http_responses_session_bridge_queue_limit", 8), + downstream_turn_state=downstream_turn_state, + ): + if ( + payload.previous_response_id is not None + and not bridge_yielded + and _is_previous_response_not_found_event_block(line) + ): + fallback_to_retry = True + break + bridge_yielded = True + yield line + if not fallback_to_retry: + return + except ProxyResponseError as exc: + error = _parse_openai_error(exc.payload) + error_code = _normalize_error_code(error.code if error else None, error.type if error else None) + if error_code != "previous_response_not_found" or payload.previous_response_id is None: + raise + + async for line in self._stream_with_retry( payload, headers, codex_session_affinity=codex_session_affinity, @@ -239,11 +282,7 @@ async def _stream_http_bridge_or_retry( api_key=api_key, api_key_reservation=api_key_reservation, suppress_text_done_events=suppress_text_done_events, - idle_ttl_seconds=getattr(settings, "http_responses_session_bridge_idle_ttl_seconds", 120.0), - codex_idle_ttl_seconds=getattr(settings, "http_responses_session_bridge_codex_idle_ttl_seconds", 900.0), - max_sessions=getattr(settings, "http_responses_session_bridge_max_sessions", 256), - queue_limit=getattr(settings, "http_responses_session_bridge_queue_limit", 8), - downstream_turn_state=downstream_turn_state, + request_transport=_REQUEST_TRANSPORT_HTTP, ): yield line @@ -862,6 +901,11 @@ async def 
proxy_responses_websocket( request_state = prepared_request.request_state request_affinity = prepared_request.affinity_policy text_data = prepared_request.text_data + request_state.upstream_text_data = text_data + request_state.affinity_key = request_affinity.key + request_state.affinity_kind = request_affinity.kind + request_state.affinity_reallocate_sticky = request_affinity.reallocate_sticky + request_state.affinity_max_age_seconds = request_affinity.max_age_seconds except AppError as exc: async with client_send_lock: await websocket.send_text( @@ -884,6 +928,14 @@ async def proxy_responses_websocket( ) ) continue + except ProxyResponseError as exc: + async with client_send_lock: + await websocket.send_text( + _serialize_websocket_error_event( + _wrapped_websocket_error_event(exc.status_code, exc.payload) + ) + ) + continue if ( request_state is not None @@ -933,19 +985,24 @@ async def proxy_responses_websocket( ) continue connect_headers = _headers_with_turn_state(filtered_headers, upstream_turn_state) + connect_kwargs = { + "sticky_key": request_affinity.key, + "sticky_kind": request_affinity.kind, + "reallocate_sticky": request_affinity.reallocate_sticky, + "sticky_max_age_seconds": request_affinity.max_age_seconds, + "prefer_earlier_reset": prefer_earlier_reset, + "routing_strategy": routing_strategy, + "model": request_state.model, + "request_state": request_state, + "api_key": api_key, + "client_send_lock": client_send_lock, + "websocket": websocket, + } + if request_state.preferred_account_id is not None: + connect_kwargs["preferred_account_id"] = request_state.preferred_account_id account, upstream = await self._connect_proxy_websocket( connect_headers, - sticky_key=request_affinity.key, - sticky_kind=request_affinity.kind, - reallocate_sticky=request_affinity.reallocate_sticky, - sticky_max_age_seconds=request_affinity.max_age_seconds, - prefer_earlier_reset=prefer_earlier_reset, - routing_strategy=routing_strategy, - model=request_state.model, - 
request_state=request_state, - api_key=api_key, - client_send_lock=client_send_lock, - websocket=websocket, + **connect_kwargs, ) if upstream is None or account is None: if request_state_registered: @@ -968,6 +1025,9 @@ async def proxy_responses_websocket( api_key=api_key, upstream_control=upstream_control, response_create_gate=response_create_gate, + filtered_headers=filtered_headers, + prefer_earlier_reset=prefer_earlier_reset, + routing_strategy=routing_strategy, proxy_request_budget_seconds=runtime_settings.proxy_request_budget_seconds, stream_idle_timeout_seconds=runtime_settings.stream_idle_timeout_seconds, ) @@ -1046,6 +1106,11 @@ async def _prepare_websocket_response_create_request( responses_payload = normalize_responses_request_payload(payload, openai_compat=openai_cache_affinity) apply_api_key_enforcement(responses_payload, refreshed_api_key) validate_model_access(refreshed_api_key, responses_payload.model) + resolved_request = await self._resolve_previous_response_request( + responses_payload, + api_key_id=refreshed_api_key.id if refreshed_api_key else None, + ) + responses_payload = resolved_request.payload reservation = await self._reserve_websocket_api_key_usage( refreshed_api_key, request_model=responses_payload.model, @@ -1060,6 +1125,10 @@ async def _prepare_websocket_response_create_request( include_type_field=True, attach_event_queue=False, client_metadata=client_metadata, + parent_response_id=resolved_request.parent_response_id, + current_input_items=resolved_request.current_input_items, + preferred_account_id=resolved_request.preferred_account_id, + api_key_id=refreshed_api_key.id if refreshed_api_key else None, ) had_prompt_cache_key = _prompt_cache_key_from_request_model(responses_payload) is not None affinity_policy = _sticky_key_for_responses_request( @@ -1119,6 +1188,10 @@ def _prepare_response_bridge_request_state( include_type_field: bool, attach_event_queue: bool, client_metadata: Mapping[str, JsonValue] | None, + parent_response_id: 
str | None = None, + current_input_items: list[JsonValue] | None = None, + preferred_account_id: str | None = None, + api_key_id: str | None = None, ) -> tuple[_WebSocketRequestState, str]: upstream_payload = dict(payload.to_payload()) upstream_payload.pop("stream", None) @@ -1136,9 +1209,15 @@ def _prepare_response_bridge_request_state( api_key_reservation=api_key_reservation, started_at=time.monotonic(), requested_service_tier=forwarded_service_tier, + parent_response_id=parent_response_id if parent_response_id is not None else payload.previous_response_id, + preferred_account_id=preferred_account_id, + current_input_items=( + current_input_items if current_input_items is not None else _clone_json_list(payload.input) + ), awaiting_response_created=True, event_queue=asyncio.Queue() if attach_event_queue else None, api_key=api_key, + api_key_id=api_key_id if api_key_id is not None else (api_key.id if api_key is not None else None), previous_response_id=payload.previous_response_id, ) text_data = json.dumps(upstream_payload, ensure_ascii=True, separators=(",", ":")) @@ -1160,10 +1239,11 @@ async def _connect_proxy_websocket( websocket: WebSocket, reallocate_sticky: bool = False, sticky_max_age_seconds: int | None = None, + preferred_account_id: str | None = None, ) -> tuple[Account | None, UpstreamResponsesWebSocket | None]: deadline = _websocket_connect_deadline(request_state, get_settings().proxy_request_budget_seconds) try: - selection = await self._select_account_with_budget( + selection = await self._select_account_with_budget_compat( deadline, request_id=request_state.request_id, kind="websocket", @@ -1174,6 +1254,7 @@ async def _connect_proxy_websocket( prefer_earlier_reset_accounts=prefer_earlier_reset, routing_strategy=routing_strategy, model=model, + preferred_account_id=preferred_account_id, ) except ProxyResponseError as exc: if _is_proxy_budget_exhausted_error(exc): @@ -2199,6 +2280,7 @@ async def _process_http_bridge_upstream_text( if 
actual_service_tier is not None: matched_request_state.actual_service_tier = actual_service_tier matched_request_state.service_tier = actual_service_tier + _collect_output_item_event(payload, matched_request_state.output_items) terminal_request_state = None if event_type in {"response.completed", "response.failed", "response.incomplete", "error"}: @@ -2287,6 +2369,9 @@ async def _relay_upstream_websocket_messages( api_key: ApiKeyData | None, upstream_control: _WebSocketUpstreamControl, response_create_gate: asyncio.Semaphore, + filtered_headers: dict[str, str], + prefer_earlier_reset: bool, + routing_strategy: RoutingStrategy, proxy_request_budget_seconds: float, stream_idle_timeout_seconds: float, ) -> None: @@ -2369,12 +2454,36 @@ async def _relay_upstream_websocket_messages( async with client_send_lock: await websocket.send_bytes(message.data) continue + disconnect_message = _upstream_websocket_disconnect_message(message) + logger.warning( + "Upstream websocket disconnected before terminal event account_id=%s request_id=%s kind=%s close_code=%s error=%s", + account_id_value, + get_request_id(), + message.kind, + message.close_code, + message.error, + ) + await self._handle_stream_error(account, {"message": disconnect_message}, "stream_incomplete") + retried = await self._retry_websocket_request_after_disconnect( + disconnected_account=account, + pending_requests=pending_requests, + pending_lock=pending_lock, + filtered_headers=filtered_headers, + prefer_earlier_reset=prefer_earlier_reset, + routing_strategy=routing_strategy, + proxy_request_budget_seconds=proxy_request_budget_seconds, + disconnect_event=message, + ) + if retried is not None: + account, upstream = retried + account_id_value = account.id + continue await self._fail_pending_websocket_requests( account_id_value=account_id_value, pending_requests=pending_requests, pending_lock=pending_lock, error_code="stream_incomplete", - error_message=_upstream_websocket_disconnect_message(message), + 
error_message=disconnect_message, api_key=api_key, websocket=websocket, client_send_lock=client_send_lock, @@ -2390,6 +2499,117 @@ async def _relay_upstream_websocket_messages( except Exception: logger.debug("Failed to close downstream websocket", exc_info=True) + async def _retry_websocket_request_after_disconnect( + self, + *, + disconnected_account: Account, + pending_requests: deque[_WebSocketRequestState], + pending_lock: anyio.Lock, + filtered_headers: dict[str, str], + prefer_earlier_reset: bool, + routing_strategy: RoutingStrategy, + proxy_request_budget_seconds: float, + disconnect_event: UpstreamWebSocketMessage, + ) -> tuple[Account, UpstreamResponsesWebSocket] | None: + async with pending_lock: + retry_candidates = [ + request_state for request_state in pending_requests if _is_retryable_websocket_request_state(request_state) + ] + if len(retry_candidates) != 1 or len(pending_requests) != 1: + return None + request_state = retry_candidates[0] + request_state.websocket_retry_count += 1 + + logger.warning( + "Retrying websocket request after upstream disconnect request_id=%s failed_account_id=%s kind=%s close_code=%s error=%s retry_count=%s", + request_state.request_id, + disconnected_account.id, + disconnect_event.kind, + disconnect_event.close_code, + disconnect_event.error, + request_state.websocket_retry_count, + ) + + deadline = request_state.started_at + proxy_request_budget_seconds + try: + selection = await self._select_account_with_budget_compat( + deadline, + request_id=request_state.request_id, + kind="websocket", + sticky_key=request_state.affinity_key, + sticky_kind=request_state.affinity_kind, + reallocate_sticky=request_state.affinity_reallocate_sticky, + sticky_max_age_seconds=request_state.affinity_max_age_seconds, + prefer_earlier_reset_accounts=prefer_earlier_reset, + routing_strategy=routing_strategy, + model=request_state.model, + preferred_account_id=request_state.preferred_account_id, + exclude_account_ids={disconnected_account.id}, 
+ ) + except Exception: + logger.warning( + "Failed to select retry account after websocket disconnect request_id=%s failed_account_id=%s", + request_state.request_id, + disconnected_account.id, + exc_info=True, + ) + return None + + retry_account = selection.account + if retry_account is None: + logger.warning( + "No retry account available after websocket disconnect request_id=%s failed_account_id=%s error=%s", + request_state.request_id, + disconnected_account.id, + selection.error_message, + ) + return None + + retry_upstream: UpstreamResponsesWebSocket | None = None + try: + remaining_budget = _remaining_budget_seconds(deadline) + if remaining_budget <= 0: + return None + retry_account = await self._ensure_fresh_with_budget(retry_account, timeout_seconds=remaining_budget) + remaining_budget = _remaining_budget_seconds(deadline) + if remaining_budget <= 0: + return None + retry_upstream = await self._open_upstream_websocket_with_budget( + retry_account, + filtered_headers, + timeout_seconds=remaining_budget, + ) + if request_state.upstream_text_data is not None: + await retry_upstream.send_text(request_state.upstream_text_data) + elif request_state.upstream_bytes_data is not None: + await retry_upstream.send_bytes(request_state.upstream_bytes_data) + else: + await retry_upstream.close() + return None + except Exception: + logger.warning( + "Websocket retry failed request_id=%s failed_account_id=%s retry_account_id=%s", + request_state.request_id, + disconnected_account.id, + retry_account.id, + exc_info=True, + ) + await self._load_balancer.record_error(retry_account) + if retry_upstream is not None: + try: + await retry_upstream.close() + except Exception: + logger.debug("Failed to close retried upstream websocket", exc_info=True) + return None + + logger.info( + "Retried websocket request request_id=%s failed_account_id=%s retry_account_id=%s", + request_state.request_id, + disconnected_account.id, + retry_account.id, + ) + return retry_account, 
retry_upstream + async def _process_upstream_websocket_text( self, text: str, @@ -2428,6 +2648,7 @@ async def _process_upstream_websocket_text( if actual_service_tier is not None: request_state.actual_service_tier = actual_service_tier request_state.service_tier = actual_service_tier + _collect_output_item_event(payload, request_state.output_items) if ( event_type in {"response.completed", "response.failed", "response.incomplete", "error"} and pending_requests @@ -2592,6 +2813,17 @@ async def _finalize_websocket_request_state( elif settlement.record_success: await self._load_balancer.record_success(account) + if event_type == "response.completed": + await self._persist_response_snapshot( + response_id=response_id, + parent_response_id=request_state.parent_response_id, + account_id=account_id_value, + api_key_id=request_state.api_key_id, + model=request_state.model or "", + input_items=request_state.current_input_items, + response_payload=_terminal_response_payload(payload, request_state.output_items), + ) + latency_ms = int((time.monotonic() - request_state.started_at) * 1000) cached_input_tokens = usage.input_tokens_details.cached_tokens if usage and usage.input_tokens_details else None reasoning_tokens = ( @@ -2981,6 +3213,120 @@ async def get_rate_limit_payload(self) -> RateLimitStatusPayloadData: additional_rate_limits=additional_rate_limits, ) + async def _resolve_previous_response_request( + self, + payload: ResponsesRequest, + *, + api_key_id: str | None, + ) -> _ResolvedResponsesRequest: + current_input_items = _clone_json_list(payload.input) + previous_response_id = payload.previous_response_id + if not previous_response_id: + return _ResolvedResponsesRequest(payload=payload, current_input_items=current_input_items) + + replay_items, preferred_account_id = await self._resolve_previous_response_chain( + previous_response_id, + api_key_id=api_key_id, + ) + resolved_payload = payload.model_copy( + deep=True, + update={ + "input": [*replay_items, 
*current_input_items], + "previous_response_id": None, + }, + ) + return _ResolvedResponsesRequest( + payload=resolved_payload, + current_input_items=current_input_items, + parent_response_id=previous_response_id, + preferred_account_id=preferred_account_id, + ) + + async def _resolve_previous_response_chain( + self, + response_id: str, + *, + api_key_id: str | None, + ) -> tuple[list[JsonValue], str | None]: + if not response_id: + _raise_unknown_previous_response_id() + + async with self._repo_factory() as repos: + if repos.response_snapshots is None: + _raise_unknown_previous_response_id() + chain: list[dict[str, str | None]] = [] + preferred_account_id: str | None = None + current_id = response_id + seen_response_ids: set[str] = set() + + while current_id: + if current_id in seen_response_ids: + _raise_unknown_previous_response_id() + seen_response_ids.add(current_id) + snapshot = await repos.response_snapshots.get(current_id, api_key_id=api_key_id) + if snapshot is None: + _raise_unknown_previous_response_id() + chain.append( + { + "response_id": snapshot.response_id, + "parent_response_id": snapshot.parent_response_id, + "account_id": snapshot.account_id, + "input_items_json": snapshot.input_items_json, + "response_json": snapshot.response_json, + } + ) + if preferred_account_id is None and snapshot.account_id: + preferred_account_id = snapshot.account_id + current_id = snapshot.parent_response_id or "" + + replay_items: list[JsonValue] = [] + for snapshot in reversed(chain): + snapshot_input = _decode_snapshot_json_list(snapshot["input_items_json"] or "[]") + snapshot_response = _decode_snapshot_json_mapping(snapshot["response_json"] or "{}") + replay_items.extend(snapshot_input) + replay_items.extend(_replayable_response_output_items(snapshot_response)) + return replay_items, preferred_account_id + + async def _persist_response_snapshot( + self, + *, + response_id: str | None, + parent_response_id: str | None, + account_id: str | None, + api_key_id: str | 
None, + model: str, + input_items: list[JsonValue], + response_payload: dict[str, JsonValue] | None, + ) -> None: + if not response_id or response_payload is None: + return + response_payload = dict(response_payload) + response_payload.setdefault("id", response_id) + if parse_response_payload(response_payload) is None: + return + + with anyio.CancelScope(shield=True): + try: + async with self._repo_factory() as repos: + if repos.response_snapshots is None: + return + await repos.response_snapshots.upsert( + response_id=response_id, + parent_response_id=parent_response_id, + account_id=account_id, + api_key_id=api_key_id, + model=model, + input_items=input_items, + response_payload=response_payload, + ) + except Exception: + logger.warning( + "Failed to persist response snapshot response_id=%s parent_response_id=%s", + response_id, + parent_response_id, + exc_info=True, + ) + async def _stream_with_retry( self, payload: ResponsesRequest, @@ -2999,6 +3345,11 @@ async def _stream_with_retry( base_settings = get_settings() settings = await get_settings_cache().get() deadline = start + base_settings.proxy_request_budget_seconds + resolved_request = await self._resolve_previous_response_request( + payload, + api_key_id=api_key.id if api_key else None, + ) + payload = resolved_request.payload prefer_earlier_reset = settings.prefer_earlier_reset_accounts upstream_stream_transport = _resolve_upstream_stream_transport(settings.upstream_stream_transport) had_prompt_cache_key = _prompt_cache_key_from_request_model(payload) is not None @@ -3054,7 +3405,7 @@ async def _stream_with_retry( yield format_sse_event(_proxy_request_timeout_event(request_id)) return try: - selection = await self._select_account_with_budget( + selection = await self._select_account_with_budget_compat( deadline, request_id=request_id, kind="stream", @@ -3065,6 +3416,7 @@ async def _stream_with_retry( prefer_earlier_reset_accounts=prefer_earlier_reset, routing_strategy=routing_strategy, 
model=payload.model, + preferred_account_id=resolved_request.preferred_account_id, ) except ProxyResponseError as exc: error = _parse_openai_error(exc.payload) @@ -3224,6 +3576,8 @@ async def _stream_with_retry( ), api_key=api_key, settlement=settlement, + parent_response_id=resolved_request.parent_response_id, + snapshot_input_items=resolved_request.current_input_items, suppress_text_done_events=suppress_text_done_events, upstream_stream_transport=upstream_stream_transport, request_transport=request_transport, @@ -3396,6 +3750,8 @@ async def _stream_with_retry( False, api_key=api_key, settlement=settlement, + parent_response_id=resolved_request.parent_response_id, + snapshot_input_items=resolved_request.current_input_items, suppress_text_done_events=suppress_text_done_events, upstream_stream_transport=upstream_stream_transport, request_transport=request_transport, @@ -3513,6 +3869,8 @@ async def _stream_once( allow_transient_retry: bool = False, api_key: ApiKeyData | None, settlement: _StreamSettlement, + parent_response_id: str | None, + snapshot_input_items: list[JsonValue], suppress_text_done_events: bool, upstream_stream_transport: str | None, request_transport: str, @@ -3531,6 +3889,9 @@ async def _stream_once( error_message = None usage = None saw_text_delta = False + output_items: dict[int, dict[str, JsonValue]] = {} + completed_response_payload: dict[str, JsonValue] | None = None + response_id: str | None = None try: if upstream_stream_transport is not None: @@ -3558,6 +3919,8 @@ async def _stream_once( first_payload = parse_sse_data_json(first) event = parse_sse_event(first) event_type = _event_type_from_payload(event, first_payload) + _collect_output_item_event(first_payload, output_items) + response_id = _websocket_response_id(event, first_payload) or response_id event_service_tier = _service_tier_from_event_payload(first_payload) if event_service_tier is not None: actual_service_tier = event_service_tier @@ -3597,6 +3960,7 @@ async def _stream_once( if 
event and event.type in ("response.completed", "response.incomplete"): usage = event.response.usage if event.response else None + completed_response_payload = _terminal_response_payload(first_payload, output_items) if event.type == "response.incomplete": status = "error" @@ -3616,6 +3980,8 @@ async def _stream_once( event_payload = parse_sse_data_json(line) event = parse_sse_event(line) event_type = _event_type_from_payload(event, event_payload) + _collect_output_item_event(event_payload, output_items) + response_id = _websocket_response_id(event, event_payload) or response_id event_service_tier = _service_tier_from_event_payload(event_payload) if event_service_tier is not None: actual_service_tier = event_service_tier @@ -3647,6 +4013,7 @@ async def _stream_once( settlement.account_health_error = _should_penalize_stream_error(error_code) if event_type in ("response.completed", "response.incomplete"): usage = event.response.usage if event.response else None + completed_response_payload = _terminal_response_payload(event_payload, output_items) if event_type == "response.incomplete": status = "error" yield line @@ -3702,6 +4069,16 @@ async def _stream_once( requested_service_tier=requested_service_tier, actual_service_tier=actual_service_tier, ) + if status == "success": + await self._persist_response_snapshot( + response_id=response_id, + parent_response_id=parent_response_id, + account_id=account_id_value, + api_key_id=api_key.id if api_key else None, + model=model, + input_items=snapshot_input_items, + response_payload=completed_response_payload, + ) async def _write_request_log( self, @@ -3972,6 +4349,19 @@ async def _ensure_fresh_with_budget( return await self._ensure_fresh(account, force=force, timeout_seconds=timeout_seconds) return await self._ensure_fresh(account, force=force) + async def _select_account_with_budget_compat( + self, + deadline: float, + **kwargs: object, + ) -> AccountSelection: + parameters = 
inspect.signature(self._select_account_with_budget).parameters + compatible_kwargs = dict(kwargs) + if "preferred_account_id" not in parameters: + compatible_kwargs.pop("preferred_account_id", None) + if "exclude_account_ids" not in parameters: + compatible_kwargs.pop("exclude_account_ids", None) + return await self._select_account_with_budget(deadline, **compatible_kwargs) + async def _select_account_with_budget( self, deadline: float, @@ -3986,6 +4376,8 @@ async def _select_account_with_budget( routing_strategy: RoutingStrategy = "usage_weighted", model: str | None = None, additional_limit_name: str | None = None, + preferred_account_id: str | None = None, + exclude_account_ids: Collection[str] | None = None, ) -> AccountSelection: remaining_budget = _remaining_budget_seconds(deadline) if remaining_budget <= 0: @@ -4004,6 +4396,8 @@ async def _select_account_with_budget( routing_strategy=routing_strategy, model=model, additional_limit_name=additional_limit_name, + preferred_account_id=preferred_account_id, + exclude_account_ids=exclude_account_ids, ) except TimeoutError: logger.warning("%s account selection exceeded request budget request_id=%s", kind.title(), request_id) @@ -4113,12 +4507,24 @@ class _WebSocketRequestState: requested_service_tier: str | None = None actual_service_tier: str | None = None response_id: str | None = None + parent_response_id: str | None = None + preferred_account_id: str | None = None + current_input_items: list[JsonValue] = field(default_factory=list) + output_items: dict[int, dict[str, JsonValue]] = field(default_factory=dict) awaiting_response_created: bool = False event_queue: asyncio.Queue[str | None] | None = None transport: str = _REQUEST_TRANSPORT_WEBSOCKET api_key: ApiKeyData | None = None + api_key_id: str | None = None request_text: str | None = None + upstream_text_data: str | None = None + upstream_bytes_data: bytes | None = None + affinity_key: str | None = None + affinity_kind: StickySessionKind | None = None + 
affinity_reallocate_sticky: bool = False + affinity_max_age_seconds: int | None = None replay_count: int = 0 + websocket_retry_count: int = 0 skip_request_log: bool = False previous_response_id: str | None = None error_code_override: str | None = None @@ -4205,6 +4611,102 @@ def _websocket_response_id(event: OpenAIEvent | None, payload: dict[str, JsonVal return stripped or None +def _collect_output_item_event( + payload: dict[str, JsonValue] | None, + output_items: dict[int, dict[str, JsonValue]], +) -> None: + if not isinstance(payload, dict): + return + event_type = payload.get("type") + if event_type not in ("response.output_item.added", "response.output_item.done"): + return + output_index = payload.get("output_index") + item = payload.get("item") + if not isinstance(output_index, int) or not isinstance(item, dict): + return + output_items[output_index] = dict(item) + + +def _merge_response_output_items( + response: dict[str, JsonValue], + output_items: dict[int, dict[str, JsonValue]], +) -> dict[str, JsonValue]: + merged = dict(response) + existing_output = response.get("output") + if isinstance(existing_output, list) and existing_output: + return merged + if output_items: + merged["output"] = [item for _, item in sorted(output_items.items())] + return merged + + +def _terminal_response_payload( + payload: dict[str, JsonValue] | None, + output_items: dict[int, dict[str, JsonValue]], +) -> dict[str, JsonValue] | None: + if not isinstance(payload, dict): + return None + response = payload.get("response") + if not isinstance(response, dict): + return None + return _merge_response_output_items(response, output_items) + + +def _replayable_response_output_items(response_payload: dict[str, JsonValue]) -> list[JsonValue]: + output_value = response_payload.get("output") + if not isinstance(output_value, list): + _raise_unknown_previous_response_id() + filtered_output = [ + item + for item in output_value + if not (isinstance(item, dict) and item.get("type") == 
"reasoning") + ] + normalized_output = sanitize_input_items(filtered_output) + replay_items: list[JsonValue] = [] + for item in normalized_output: + if isinstance(item, dict) and item.get("role") == "assistant": + replay_items.append({"role": "assistant", "content": item.get("content")}) + continue + replay_items.append(item) + return replay_items + + +def _clone_json_list(value: JsonValue) -> list[JsonValue]: + if not isinstance(value, list): + raise TypeError("expected list input items") + return json.loads(json.dumps(value, ensure_ascii=False)) + + +def _decode_snapshot_json_list(value: str) -> list[JsonValue]: + try: + decoded = json.loads(value) + except json.JSONDecodeError as exc: + payload = openai_error("invalid_request_error", "Unknown previous_response_id", "invalid_request_error") + payload["error"]["param"] = "previous_response_id" + raise ProxyResponseError(400, payload) from exc + if not isinstance(decoded, list): + _raise_unknown_previous_response_id() + return decoded + + +def _decode_snapshot_json_mapping(value: str) -> dict[str, JsonValue]: + try: + decoded = json.loads(value) + except json.JSONDecodeError as exc: + payload = openai_error("invalid_request_error", "Unknown previous_response_id", "invalid_request_error") + payload["error"]["param"] = "previous_response_id" + raise ProxyResponseError(400, payload) from exc + if not isinstance(decoded, dict): + _raise_unknown_previous_response_id() + return decoded + + +def _raise_unknown_previous_response_id() -> NoReturn: + payload = openai_error("invalid_request_error", "Unknown previous_response_id", "invalid_request_error") + payload["error"]["param"] = "previous_response_id" + raise ProxyResponseError(400, payload) + + def _find_websocket_request_state_by_response_id( pending_requests: deque[_WebSocketRequestState], response_id: str, @@ -4265,6 +4767,28 @@ def _pop_terminal_websocket_request_state( return None +def _is_retryable_websocket_request_state(request_state: _WebSocketRequestState) -> 
bool: + if request_state.response_id is not None: + return False + if request_state.websocket_retry_count >= 1: + return False + return request_state.upstream_text_data is not None or request_state.upstream_bytes_data is not None + + +def _is_previous_response_not_found_event_block(line: str) -> bool: + payload = parse_sse_data_json(line) + event = parse_sse_event(line) + event_type = _event_type_from_payload(event, payload) + if event_type == "response.failed": + error = event.response.error if event and event.response else None + elif event_type == "error": + error = event.error if event else None + else: + return False + error_code = _normalize_error_code(error.code if error else None, error.type if error else None) + return error_code == "previous_response_not_found" + + def _upstream_websocket_disconnect_message(message: UpstreamWebSocketMessage) -> str: if message.kind == "error" and message.error: return f"Upstream websocket closed before response.completed: {message.error}" diff --git a/openspec/changes/support-responses-previous-response-id/proposal.md b/openspec/changes/support-responses-previous-response-id/proposal.md index a18d7d9a..b6eb4f25 100644 --- a/openspec/changes/support-responses-previous-response-id/proposal.md +++ b/openspec/changes/support-responses-previous-response-id/proposal.md @@ -1,12 +1,14 @@ ## Why -Codex CLI websocket/resume flows now send previous_response_id for incremental Responses requests. codex-lb still rejects that field, causing websocket-enabled exec/resume failures and leaving users stuck on websocket-off fallback paths with degraded cache behavior. +The original `previous_response_id` work in `#211` mixed durable continuity improvements with an unwanted PostgreSQL-first backend rewrite. `codex-lb` still needs the continuity gains, but SQLite remains the default runtime and must stay first-class. We need to restore the remaining wins on the existing project-native database primitives. 
## What Changes -- Allow and forward previous_response_id on Responses requests where upstream accepts it. -- Preserve conflict validation between conversation and previous_response_id. -- Add regression coverage for websocket and HTTP forwarding paths. -- Verify real cache behavior locally with docker compose across websocket on/off and /v1/responses variants. +- Persist terminal Responses snapshots in the default database so `previous_response_id` can survive process restart and HTTP bridge loss. +- Resolve `previous_response_id` from caller-scoped continuity state when live upstream continuity is unavailable, while preserving the conflict validation against `conversation`. +- Prefer the originating upstream account for replay when that account is still eligible. +- Retry one websocket request on early upstream disconnect before `response.created`. +- Extend migration, HTTP bridge, websocket, and API-key scoping regression coverage and sync the specs. ## Impact -- Restores compatibility with newer Codex CLI websocket flows. -- Enables direct measurement of whether cache behavior returns to expected levels. +- Restores durable `previous_response_id` compatibility for newer Codex CLI and OpenAI-style Responses flows without making PostgreSQL mandatory. +- Preserves SQLite as the default runtime while keeping PostgreSQL optional. +- Improves resilience for native websocket clients during early upstream disconnects. 
diff --git a/openspec/changes/support-responses-previous-response-id/specs/database-migrations/spec.md b/openspec/changes/support-responses-previous-response-id/specs/database-migrations/spec.md new file mode 100644 index 00000000..145e9760 --- /dev/null +++ b/openspec/changes/support-responses-previous-response-id/specs/database-migrations/spec.md @@ -0,0 +1,11 @@ +### ADDED Requirement: Durable response snapshot continuity storage +Startup migrations SHALL create and preserve the durable storage needed to replay `previous_response_id` across bridge loss and restart. The continuity schema SHALL include caller scoping so one API key cannot replay another caller's stored response chain. + +#### Scenario: startup migration creates response snapshot storage +- **WHEN** startup migrations upgrade a database without response continuity storage +- **THEN** the schema includes `response_snapshots` +- **AND** that table includes `api_key_id` alongside the serialized continuity payload columns + +#### Scenario: startup migration repairs partial response snapshot storage +- **WHEN** startup migrations encounter an existing `response_snapshots` table missing `api_key_id` or the parent/created-at continuity index +- **THEN** the migration adds the missing column and index without requiring operator intervention diff --git a/openspec/changes/support-responses-previous-response-id/specs/responses-api-compat/spec.md b/openspec/changes/support-responses-previous-response-id/specs/responses-api-compat/spec.md index 8aa46956..42a3a60d 100644 --- a/openspec/changes/support-responses-previous-response-id/specs/responses-api-compat/spec.md +++ b/openspec/changes/support-responses-previous-response-id/specs/responses-api-compat/spec.md @@ -1,5 +1,5 @@ -### MODIFIED Requirement: Validate request structure and unsupported fields -The service MUST accept `input` as either a string or an array of input items. 
When `input` is a string, the service MUST normalize it into a single user input item with `input_text` content before forwarding upstream. The service MUST continue to reject requests that include both `conversation` and `previous_response_id`. +### MODIFIED Requirement: Support Responses input types and conversation constraints +The service MUST accept `input` as either a string or an array of input items. When `input` is a string, the service MUST normalize it into a single user input item with `input_text` content before forwarding upstream. The service MUST accept `previous_response_id` when `conversation` is absent, MUST prefer live upstream continuity when available, and otherwise MUST resolve that response id from caller-scoped persisted continuity state before forwarding upstream. The service MUST continue to reject requests that include both `conversation` and `previous_response_id`. #### Scenario: conversation and previous_response_id conflict - **WHEN** the client provides both `conversation` and `previous_response_id` @@ -7,11 +7,27 @@ The service MUST accept `input` as either a string or an array of input items. W #### Scenario: previous_response_id provided - **WHEN** the client provides `previous_response_id` without `conversation` -- **THEN** the service accepts the request and forwards `previous_response_id` upstream unchanged +- **THEN** the service accepts the request +- **AND** it either preserves live upstream continuity or rebuilds the request locally from persisted continuity state before forwarding upstream -### MODIFIED Requirement: WebSocket Responses proxy preserves request shape -When proxying websocket `response.create` requests, the service MUST preserve supported incremental request fields required by native Codex clients. The service MUST forward `previous_response_id` unchanged when present and MUST continue to omit only HTTP-only transport fields such as `stream` and `background` from the upstream websocket payload. 
+#### Scenario: previous_response_id is outside caller scope +- **WHEN** the client provides `previous_response_id` +- **AND** the stored continuity belongs to another API key scope or does not exist +- **THEN** the service returns `invalid_request_error` on `previous_response_id` -#### Scenario: websocket response.create includes previous_response_id -- **WHEN** a websocket `response.create` payload includes a non-empty `previous_response_id` -- **THEN** the upstream websocket payload includes the same `previous_response_id` +### MODIFIED Requirement: HTTP Responses routes preserve upstream websocket session continuity +When serving HTTP `/v1/responses` or HTTP `/backend-api/codex/responses`, the service MUST preserve upstream Responses websocket session continuity on a stable per-session bridge key when that live session is available. If that live session is unavailable and caller-scoped persisted continuity exists for `previous_response_id`, the service MUST rebuild the request locally and complete it through a fresh upstream request without opening a replacement bridge session that forwards the old `previous_response_id` upstream unchanged. + +#### Scenario: HTTP bridge loss falls back to persisted replay +- **WHEN** a client sends HTTP `/v1/responses` or `/backend-api/codex/responses` with `previous_response_id` +- **AND** there is no matching live bridged upstream session +- **AND** caller-scoped persisted continuity exists +- **THEN** the service rebuilds the request locally from persisted continuity state and completes it successfully + +### ADDED Requirement: Websocket Responses retry one request after early upstream disconnect +When an upstream Responses websocket disconnects before `response.created`, the service MUST retry at most one pending request on another eligible account when exactly one request is in flight and that request has not yet been acknowledged upstream. 
+ +#### Scenario: upstream disconnects before response.created +- **WHEN** exactly one websocket `response.create` request is pending +- **AND** the upstream disconnects before emitting `response.created` +- **THEN** the service retries that request once on another eligible account diff --git a/openspec/changes/support-responses-previous-response-id/tasks.md b/openspec/changes/support-responses-previous-response-id/tasks.md index 5ce62e63..eb020d03 100644 --- a/openspec/changes/support-responses-previous-response-id/tasks.md +++ b/openspec/changes/support-responses-previous-response-id/tasks.md @@ -1,5 +1,6 @@ -- [x] Update Responses request validation/serialization to support previous_response_id -- [x] Add unit and integration regression coverage for websocket and HTTP forwarding -- [x] Bring up local docker compose with preserved codex-lb-data auth state -- [x] Measure cache behavior matrix across websocket on/off and /v1/responses variants -- [x] Sync relevant spec updates and verify implementation +- [x] Persist response snapshots in the default database with caller/API-key scoping. +- [x] Resolve `previous_response_id` from persisted continuity state when live bridge continuity is unavailable. +- [x] Prefer the originating account for replay when that account is still eligible. +- [x] Retry one websocket request after an upstream disconnect before `response.created`. +- [x] Add migration, HTTP, websocket, and API-key scoping regression coverage. +- [x] Sync the Responses and database-migrations specs with the SQLite-first continuity design. 
diff --git a/openspec/specs/database-migrations/spec.md b/openspec/specs/database-migrations/spec.md index 9e410ec6..b0e0cdb3 100644 --- a/openspec/specs/database-migrations/spec.md +++ b/openspec/specs/database-migrations/spec.md @@ -107,6 +107,21 @@ The system SHALL create a SQLite backup before applying startup migrations when - **THEN** the system creates a pre-migration backup file - **AND** enforces configured retention on backup files +### Requirement: Durable response snapshot continuity storage + +Startup migrations SHALL create and preserve the durable storage required for replaying `previous_response_id` across bridge loss and restart. This storage SHALL include a `response_snapshots` table with caller-scoping metadata and serialized request/response continuity data. + +#### Scenario: Startup creates response snapshot storage + +- **WHEN** startup migrations upgrade a database that does not yet include response continuity storage +- **THEN** the migrated schema includes `response_snapshots` +- **AND** that table includes `response_id`, `parent_response_id`, `account_id`, `api_key_id`, `model`, `input_items_json`, `response_json`, and `created_at` + +#### Scenario: Startup repairs partial response snapshot storage + +- **WHEN** startup migrations encounter an existing `response_snapshots` table missing the caller-scoping column or continuity index +- **THEN** the migration adds the missing `api_key_id` column and required continuity index without requiring manual intervention + ### Requirement: Migration policy and drift guard in CI The project SHALL fail CI when migration policy is violated or ORM metadata and migrated schema diverge. 
diff --git a/openspec/specs/responses-api-compat/context.md b/openspec/specs/responses-api-compat/context.md index 58ad7e3f..bcfb2b4d 100644 --- a/openspec/specs/responses-api-compat/context.md +++ b/openspec/specs/responses-api-compat/context.md @@ -19,7 +19,9 @@ See `openspec/specs/responses-api-compat/spec.md` for normative requirements. - `store=true` is rejected; responses are not persisted. - `include` values must be on the documented allowlist. - `truncation` is rejected. -- `previous_response_id` is forwarded when `conversation` is absent, but the `conversation + previous_response_id` conflict remains rejected. +- `previous_response_id` is accepted when `conversation` is absent, but the `conversation + previous_response_id` conflict remains rejected. +- Live bridge/session continuity is preferred, but the service now persists caller-scoped response snapshots in the default database so `previous_response_id` can survive bridge loss and process restart without requiring PostgreSQL. +- Persisted continuity is scoped by `api_key_id` when API key auth is enabled; requests from another API key fail closed instead of replaying the wrong caller's context. - HTTP `/v1/responses` and HTTP `/backend-api/codex/responses` now use a server-side upstream websocket session bridge by default so repeated compatible requests can keep upstream response/session continuity without forcing clients onto the public websocket route. - Codex-affinity HTTP bridge sessions can optionally use a conservative first-request prewarm (`generate=false`), but that behavior now stays behind an explicit flag so production defaults do not pay an extra upstream request unless operators opt in. - When operators configure a multi-instance bridge ring, each stable bridge key now has a deterministic owner replica; non-owner replicas fail closed with `bridge_instance_mismatch` instead of silently creating fragmented continuity on the wrong host. 
Unstable per-request bridge keys remain local and are allowed on any replica because there is no continuity to preserve. @@ -46,9 +48,10 @@ See `openspec/specs/responses-api-compat/spec.md` for normative requirements. - **Stream ends without terminal event:** Emit `response.failed` with `stream_incomplete`. - **Upstream error / no accounts:** Non-streaming responses return an OpenAI error envelope with 5xx status. - **Compact upstream transport/client failure:** Retry only inside `/codex/responses/compact` when the failure is safely retryable; otherwise return an explicit upstream error without surrogate fallback. -- **HTTP bridge session closes or expires:** The next compatible HTTP `/v1/responses` or `/backend-api/codex/responses` request recreates a fresh upstream websocket bridge session; continuity is guaranteed only within the lifetime of one active bridged session. +- **HTTP bridge session closes or expires:** The next compatible HTTP `/v1/responses` or `/backend-api/codex/responses` request can replay from persisted caller-scoped snapshots when `previous_response_id` is known; otherwise continuity still fails closed with `previous_response_not_found`. - **Multi-instance routing without bridge owner policy:** if operators do not configure a bridge ring or front-door affinity, continuity can still fragment across replicas; with a configured bridge ring, wrong-replica requests now fail closed instead of silently forking bridge state. - **Codex websocket reconnects:** Reconnect continuity now depends on the client replaying the accepted `x-codex-turn-state`; generated turn-state is emitted on accept for backend Codex routes and echoed back when the client already supplies one. +- **Early upstream websocket disconnects:** when the upstream drops before `response.created` and only one request is pending, the proxy retries that request once on another eligible account; later disconnects still surface as stream failures. 
- **Websocket handshake forbidden/not-found:** Auto transport now fails loud on `403` / `404` instead of silently hiding the websocket regression behind HTTP fallback. - **Invalid request payloads:** Return 4xx with `invalid_request_error`. diff --git a/openspec/specs/responses-api-compat/spec.md b/openspec/specs/responses-api-compat/spec.md index 8d4750de..acb524e4 100644 --- a/openspec/specs/responses-api-compat/spec.md +++ b/openspec/specs/responses-api-compat/spec.md @@ -16,7 +16,7 @@ The service MUST accept POST requests to `/v1/responses` with a JSON body and MU - **THEN** the service returns a 4xx response with an OpenAI error envelope describing the invalid parameter ### Requirement: Support Responses input types and conversation constraints -The service MUST accept `input` as either a string or an array of input items. When `input` is a string, the service MUST normalize it into a single user input item with `input_text` content before forwarding upstream. The service MUST accept `previous_response_id` when `conversation` is absent and MUST continue to reject requests that include both `conversation` and `previous_response_id`. +The service MUST accept `input` as either a string or an array of input items. When `input` is a string, the service MUST normalize it into a single user input item with `input_text` content before forwarding upstream. The service MUST accept `previous_response_id` when `conversation` is absent, MUST prefer live upstream continuity when available, and otherwise MUST attempt caller-scoped replay from persisted response snapshots before forwarding upstream. The service MUST continue to reject requests that include both `conversation` and `previous_response_id`. #### Scenario: String input - **WHEN** the client sends `input` as a string @@ -32,7 +32,13 @@ The service MUST accept `input` as either a string or an array of input items. 
W #### Scenario: previous_response_id provided - **WHEN** the client provides `previous_response_id` without `conversation` -- **THEN** the service accepts the request and forwards `previous_response_id` upstream unchanged +- **THEN** the service accepts the request +- **AND** it either preserves live upstream continuity for that response id or rebuilds the request locally from persisted continuity state before forwarding upstream + +#### Scenario: previous_response_id is unknown for the caller scope +- **WHEN** the client provides `previous_response_id` without `conversation` +- **AND** the service cannot find matching continuity state for that same caller scope +- **THEN** the service returns a 4xx OpenAI error envelope on `previous_response_id` ### Requirement: Reject input_file file_id in Responses The service MUST reject `input_file.file_id` in Responses input items and return a 4xx OpenAI invalid_request_error with message "Invalid request payload". @@ -212,10 +218,18 @@ When serving HTTP `/v1/responses` or HTTP `/backend-api/codex/responses`, the se - **WHEN** a client sends a later HTTP `/backend-api/codex/responses` request with `previous_response_id` that references a response created earlier on the same bridged session - **THEN** the service forwards that request through the same upstream websocket session so upstream can resolve the referenced prior response -#### Scenario: HTTP previous_response_id fails closed when bridged continuity is unavailable +#### Scenario: HTTP previous_response_id falls back to persisted replay when bridged continuity is unavailable +- **WHEN** a client sends HTTP `/v1/responses` or `/backend-api/codex/responses` with `previous_response_id` +- **AND** there is no matching live bridged upstream websocket session for that continuity key +- **AND** caller-scoped persisted continuity state exists for that response id +- **THEN** the service rebuilds the request locally from persisted snapshots and completes it through a fresh 
upstream request +- **AND** it MUST NOT open a replacement bridge session solely to forward the old `previous_response_id` upstream unchanged + +#### Scenario: HTTP previous_response_id still fails closed when persisted continuity is unavailable - **WHEN** a client sends HTTP `/v1/responses` or `/backend-api/codex/responses` with `previous_response_id` - **AND** there is no matching live bridged upstream websocket session for that continuity key -- **THEN** the service MUST fail the request without opening a fresh upstream session +- **AND** caller-scoped persisted continuity state does not exist for that response id +- **THEN** the service MUST fail the request without opening a fresh upstream session that forwards `previous_response_id` upstream - **AND** it MUST return `previous_response_not_found` on `previous_response_id` #### Scenario: bridged HTTP requests keep external HTTP transport logging @@ -274,6 +288,19 @@ When serving websocket Responses endpoints, the service MUST advertise an `x-cod - **THEN** the websocket accept response echoes that same turn-state - **AND** the proxy uses that same turn-state as the Codex session affinity key +### Requirement: Websocket responses retry one early-disconnect request +When an upstream Responses websocket disconnects before `response.created`, the service MUST retry at most one in-flight request on another eligible account when exactly one request is pending and that request has not yet been acknowledged upstream. The retry MUST preserve the downstream websocket contract and SHOULD prefer the account associated with persisted continuity when one is known. 
+ +#### Scenario: upstream disconnects before response.created +- **WHEN** exactly one websocket `response.create` request is pending +- **AND** the upstream disconnects before emitting `response.created` +- **THEN** the service retries that request once on another eligible account +- **AND** the downstream websocket still receives the eventual terminal response events for the retried request + +#### Scenario: upstream disconnect after response.created does not trigger replay retry +- **WHEN** the upstream disconnects after emitting `response.created` +- **THEN** the service does not silently replay the same request again on another account + ### Requirement: Auto websocket fallback remains narrow and explicit When automatic upstream transport selection prefers websocket, the service MUST only downgrade to HTTP automatically on `426 Upgrade Required`. Handshake failures such as `403 Forbidden` or `404 Not Found` MUST surface as upstream errors instead of silently falling back to HTTP. diff --git a/tests/integration/test_http_responses_bridge.py b/tests/integration/test_http_responses_bridge.py index a34bddef..28e7ca0d 100644 --- a/tests/integration/test_http_responses_bridge.py +++ b/tests/integration/test_http_responses_bridge.py @@ -98,6 +98,16 @@ async def _get_account(account_id: str) -> Account: return account +def _completed_response_event(response_id: str, *, text: str = "OK") -> str: + return ( + 'data: {"type":"response.completed","response":{"id":"' + + response_id + + '","object":"response","status":"completed","output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"' + + text + + '"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2,"input_tokens_details":{"cached_tokens":0},"output_tokens_details":{"reasoning_tokens":0}}}}\n\n' + ) + + class _SettingsCache: def __init__(self, settings: object) -> None: self._settings = settings @@ -2301,7 +2311,7 @@ async def fake_connect_responses_websocket( 
@pytest.mark.asyncio -async def test_v1_responses_http_bridge_requires_live_session_for_previous_response_id(async_client, monkeypatch): +async def test_v1_responses_http_bridge_replays_previous_response_after_bridge_loss(async_client, monkeypatch): _install_bridge_settings(monkeypatch, enabled=True) account_id = await _import_account( async_client, @@ -2311,6 +2321,7 @@ async def test_v1_responses_http_bridge_requires_live_session_for_previous_respo account = await _get_account(account_id) fake_upstream = _FakeBridgeUpstreamWebSocket() connect_count = 0 + seen_inputs: list[object] = [] async def fake_select_account_with_budget( self, @@ -2362,9 +2373,15 @@ async def fake_connect_responses_websocket( connect_count += 1 return fake_upstream + async def fake_legacy_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **_kw): + del headers, access_token, account_id, base_url, raise_for_status, _kw + seen_inputs.append(payload.input) + yield _completed_response_event("resp_bridge_replayed") + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + monkeypatch.setattr(proxy_module, "core_stream_responses", fake_legacy_stream) first = await async_client.post( "/v1/responses", @@ -2389,19 +2406,16 @@ async def fake_connect_responses_websocket( }, ) - assert second.status_code == 400 - assert second.json() == { - "error": { - "message": ( - f"Previous response with id '{first_body['id']}' not found. " - "HTTP bridge continuity was lost. Replay x-codex-turn-state or retry with a stable prompt_cache_key." 
- ), - "type": "invalid_request_error", - "code": "previous_response_not_found", - "param": "previous_response_id", - } - } + assert second.status_code == 200 + assert second.json()["id"] == "resp_bridge_replayed" assert connect_count == 1 + assert seen_inputs == [ + [ + {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "OK"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "hello-again"}]}, + ] + ] @pytest.mark.asyncio @@ -2893,7 +2907,10 @@ async def fail_legacy_stream(*args, **kwargs): @pytest.mark.asyncio -async def test_v1_responses_http_bridge_does_not_open_fresh_session_for_previous_response_id(async_client, monkeypatch): +async def test_v1_responses_http_bridge_replays_previous_response_without_opening_fresh_session( + async_client, + monkeypatch, +): _install_bridge_settings(monkeypatch, enabled=True) account_id = await _import_account( async_client, @@ -2905,6 +2922,7 @@ async def test_v1_responses_http_bridge_does_not_open_fresh_session_for_previous second_upstream = _FakeBridgeUpstreamWebSocket() upstreams = [first_upstream, second_upstream] connect_count = 0 + seen_inputs: list[object] = [] async def fake_select_account_with_budget( self, @@ -2957,9 +2975,15 @@ async def fake_connect_responses_websocket( connect_count += 1 return upstream + async def fake_legacy_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **_kw): + del headers, access_token, account_id, base_url, raise_for_status, _kw + seen_inputs.append(payload.input) + yield _completed_response_event("resp_bridge_replayed") + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + 
monkeypatch.setattr(proxy_module, "core_stream_responses", fake_legacy_stream) first = await async_client.post( "/v1/responses", @@ -2984,19 +3008,16 @@ async def fake_connect_responses_websocket( }, ) - assert second.status_code == 400 - assert second.json() == { - "error": { - "message": ( - f"Previous response with id '{first_body['id']}' not found. " - "HTTP bridge continuity was lost. Replay x-codex-turn-state or retry with a stable prompt_cache_key." - ), - "type": "invalid_request_error", - "code": "previous_response_not_found", - "param": "previous_response_id", - } - } + assert second.status_code == 200 + assert second.json()["id"] == "resp_bridge_replayed" assert connect_count == 1 + assert seen_inputs == [ + [ + {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "OK"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "hello-again"}]}, + ] + ] @pytest.mark.asyncio @@ -3985,7 +4006,7 @@ async def fake_connect_responses_websocket( @pytest.mark.asyncio -async def test_v1_responses_http_bridge_send_failure_returns_previous_response_not_found( +async def test_v1_responses_http_bridge_send_failure_replays_previous_response_from_snapshot( async_client, app_instance, monkeypatch, @@ -4000,6 +4021,7 @@ async def test_v1_responses_http_bridge_send_failure_returns_previous_response_n fake_upstream = _FakeBridgeUpstreamWebSocket() failing_upstream = _FailingSendThenCloseUpstreamWebSocket() connect_count = 0 + seen_inputs: list[object] = [] async def fake_select_account_with_budget( self, @@ -4051,9 +4073,15 @@ async def fake_connect_responses_websocket( connect_count += 1 return fake_upstream if connect_count == 1 else failing_upstream + async def fake_legacy_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **_kw): + del headers, access_token, account_id, base_url, raise_for_status, _kw + seen_inputs.append(payload.input) + 
yield _completed_response_event("resp_bridge_replayed") + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + monkeypatch.setattr(proxy_module, "core_stream_responses", fake_legacy_stream) first = await async_client.post( "/v1/responses", @@ -4083,24 +4111,20 @@ async def fake_connect_responses_websocket( }, ) - assert second.status_code == 400 - assert second.json() == { - "error": { - "message": ( - f"Previous response with id '{first_body['id']}' not found. " - "HTTP bridge continuity was lost before the request reached upstream. " - "Replay x-codex-turn-state or retry with a stable prompt_cache_key." - ), - "type": "invalid_request_error", - "code": "previous_response_not_found", - "param": "previous_response_id", - } - } + assert second.status_code == 200 + assert second.json()["id"] == "resp_bridge_replayed" assert connect_count == 1 + assert seen_inputs == [ + [ + {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "OK"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "hello-again"}]}, + ] + ] @pytest.mark.asyncio -async def test_v1_responses_http_bridge_precreated_disconnect_returns_previous_response_not_found( +async def test_v1_responses_http_bridge_precreated_disconnect_replays_previous_response_from_snapshot( async_client, app_instance, monkeypatch, @@ -4115,6 +4139,7 @@ async def test_v1_responses_http_bridge_precreated_disconnect_returns_previous_r fake_upstream = _FakeBridgeUpstreamWebSocket() precreated_close_upstream = _PrecreatedCloseUpstreamWebSocket() connect_count = 0 + seen_inputs: list[object] = [] async def fake_select_account_with_budget( self, @@ -4166,9 +4191,15 @@ async 
def fake_connect_responses_websocket( connect_count += 1 return fake_upstream if connect_count == 1 else precreated_close_upstream + async def fake_legacy_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **_kw): + del headers, access_token, account_id, base_url, raise_for_status, _kw + seen_inputs.append(payload.input) + yield _completed_response_event("resp_bridge_replayed") + monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget) monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget) monkeypatch.setattr(proxy_module, "connect_responses_websocket", fake_connect_responses_websocket) + monkeypatch.setattr(proxy_module, "core_stream_responses", fake_legacy_stream) first = await async_client.post( "/v1/responses", @@ -4198,20 +4229,16 @@ async def fake_connect_responses_websocket( }, ) - assert second.status_code == 400 - assert second.json() == { - "error": { - "message": ( - f"Previous response with id '{first_body['id']}' not found. " - "HTTP bridge continuity was lost before upstream created the next response. " - "Replay x-codex-turn-state or retry with a stable prompt_cache_key." 
- ), - "type": "invalid_request_error", - "code": "previous_response_not_found", - "param": "previous_response_id", - } - } + assert second.status_code == 200 + assert second.json()["id"] == "resp_bridge_replayed" assert connect_count == 1 + assert seen_inputs == [ + [ + {"role": "user", "content": [{"type": "input_text", "text": "hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "OK"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "hello-again"}]}, + ] + ] @pytest.mark.asyncio diff --git a/tests/integration/test_load_balancer_integration.py b/tests/integration/test_load_balancer_integration.py index 19bb5c07..1b2353dc 100644 --- a/tests/integration/test_load_balancer_integration.py +++ b/tests/integration/test_load_balancer_integration.py @@ -14,6 +14,7 @@ from app.modules.api_keys.repository import ApiKeysRepository from app.modules.proxy.load_balancer import LoadBalancer from app.modules.proxy.repo_bundle import ProxyRepositories +from app.modules.proxy.response_snapshots_repository import ResponseSnapshotsRepository from app.modules.proxy.sticky_repository import StickySessionsRepository from app.modules.request_logs.repository import RequestLogsRepository from app.modules.usage.repository import AdditionalUsageRepository, UsageRepository @@ -31,6 +32,7 @@ async def _repo_factory() -> AsyncIterator[ProxyRepositories]: sticky_sessions=StickySessionsRepository(session), api_keys=ApiKeysRepository(session), additional_usage=AdditionalUsageRepository(session), + response_snapshots=ResponseSnapshotsRepository(session), ) @@ -427,3 +429,118 @@ async def test_load_balancer_filters_accounts_by_persisted_additional_usage(db_s assert selection.account is not None assert selection.account.id == eligible_account.id + + +@pytest.mark.asyncio +async def test_load_balancer_prefers_preferred_account_when_eligible(db_setup): + encryptor = TokenEncryptor() + now = utcnow() + now_epoch = 
int(now.replace(tzinfo=timezone.utc).timestamp()) + reset_at = now_epoch + 3600 + + preferred = Account( + id="acc_preferred_hit", + email="preferred-hit@example.com", + plan_type="plus", + access_token_encrypted=encryptor.encrypt("preferred-hit-access"), + refresh_token_encrypted=encryptor.encrypt("preferred-hit-refresh"), + id_token_encrypted=encryptor.encrypt("preferred-hit-id"), + last_refresh=now, + status=AccountStatus.ACTIVE, + deactivation_reason=None, + ) + other = Account( + id="acc_preferred_other", + email="preferred-other@example.com", + plan_type="plus", + access_token_encrypted=encryptor.encrypt("other-access"), + refresh_token_encrypted=encryptor.encrypt("other-refresh"), + id_token_encrypted=encryptor.encrypt("other-id"), + last_refresh=now, + status=AccountStatus.ACTIVE, + deactivation_reason=None, + ) + + async with SessionLocal() as session: + accounts_repo = AccountsRepository(session) + usage_repo = UsageRepository(session) + await accounts_repo.upsert(preferred) + await accounts_repo.upsert(other) + await usage_repo.add_entry( + account_id=preferred.id, + used_percent=85.0, + window="primary", + reset_at=reset_at, + window_minutes=300, + ) + await usage_repo.add_entry( + account_id=other.id, + used_percent=5.0, + window="primary", + reset_at=reset_at, + window_minutes=300, + ) + + balancer = LoadBalancer(_repo_factory) + selection = await balancer.select_account(preferred_account_id=preferred.id) + + assert selection.account is not None + assert selection.account.id == preferred.id + + +@pytest.mark.asyncio +async def test_load_balancer_falls_back_when_preferred_account_unavailable(db_setup): + encryptor = TokenEncryptor() + now = utcnow() + now_epoch = int(now.replace(tzinfo=timezone.utc).timestamp()) + reset_at = now_epoch + 3600 + + preferred = Account( + id="acc_preferred_blocked", + email="preferred-blocked@example.com", + plan_type="plus", + access_token_encrypted=encryptor.encrypt("preferred-blocked-access"), + 
refresh_token_encrypted=encryptor.encrypt("preferred-blocked-refresh"), + id_token_encrypted=encryptor.encrypt("preferred-blocked-id"), + last_refresh=now, + status=AccountStatus.RATE_LIMITED, + deactivation_reason=None, + reset_at=reset_at, + ) + fallback = Account( + id="acc_preferred_fallback", + email="preferred-fallback@example.com", + plan_type="plus", + access_token_encrypted=encryptor.encrypt("fallback-access"), + refresh_token_encrypted=encryptor.encrypt("fallback-refresh"), + id_token_encrypted=encryptor.encrypt("fallback-id"), + last_refresh=now, + status=AccountStatus.ACTIVE, + deactivation_reason=None, + ) + + async with SessionLocal() as session: + accounts_repo = AccountsRepository(session) + usage_repo = UsageRepository(session) + await accounts_repo.upsert(preferred) + await accounts_repo.upsert(fallback) + await usage_repo.add_entry( + account_id=preferred.id, + used_percent=10.0, + window="primary", + reset_at=reset_at, + window_minutes=300, + ) + await usage_repo.add_entry( + account_id=fallback.id, + used_percent=40.0, + window="primary", + reset_at=reset_at, + window_minutes=300, + ) + + balancer = LoadBalancer(_repo_factory) + selection = await balancer.select_account(preferred_account_id=preferred.id) + + assert selection.account is not None + assert selection.account.id == fallback.id diff --git a/tests/integration/test_migrations.py b/tests/integration/test_migrations.py index 6dd6443f..e141c82a 100644 --- a/tests/integration/test_migrations.py +++ b/tests/integration/test_migrations.py @@ -1,6 +1,7 @@ from __future__ import annotations import pytest +import sqlalchemy as sa from sqlalchemy import text from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine @@ -29,6 +30,7 @@ from app.db.models import Account, AccountStatus from app.db.session import SessionLocal from app.modules.accounts.repository import AccountsRepository +from app.modules.proxy.response_snapshots_repository import ResponseSnapshotsRepository try: from 
app.db.migrate import check_migration_policy @@ -87,6 +89,70 @@ async def test_run_startup_migrations_preserves_unknown_plan_types(db_setup): assert rerun.current_revision == _HEAD_REVISION +@pytest.mark.asyncio +async def test_run_startup_migrations_creates_response_snapshots_table(db_setup): + result = await run_startup_migrations(_DATABASE_URL) + assert result.current_revision == _HEAD_REVISION + + async_engine = create_async_engine(_DATABASE_URL) + try: + async with async_engine.begin() as conn: + columns = await conn.run_sync(lambda sync_conn: sa.inspect(sync_conn).get_columns("response_snapshots")) + finally: + await async_engine.dispose() + + assert [column["name"] for column in columns] == [ + "response_id", + "parent_response_id", + "account_id", + "api_key_id", + "model", + "input_items_json", + "response_json", + "created_at", + ] + + +@pytest.mark.asyncio +async def test_response_snapshots_repository_scopes_by_api_key_and_preserves_created_at(db_setup): + async with SessionLocal() as session: + repo = ResponseSnapshotsRepository(session) + inserted = await repo.upsert( + response_id="resp_scoped", + parent_response_id=None, + account_id="acc_a", + api_key_id="key_a", + model="gpt-5.2", + input_items=[{"role": "user", "content": [{"type": "input_text", "text": "hello"}]}], + response_payload={ + "id": "resp_scoped", + "status": "completed", + "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "hi"}]}], + }, + ) + created_at = inserted.created_at + + updated = await repo.upsert( + response_id="resp_scoped", + parent_response_id="resp_parent", + account_id="acc_a", + api_key_id="key_a", + model="gpt-5.2", + input_items=[{"role": "user", "content": [{"type": "input_text", "text": "hello again"}]}], + response_payload={ + "id": "resp_scoped", + "status": "completed", + "output": [{"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "hi again"}]}], + }, + ) + + assert 
updated.created_at == created_at + assert updated.parent_response_id == "resp_parent" + assert await repo.get("resp_scoped", api_key_id="key_a") is not None + assert await repo.get("resp_scoped", api_key_id="key_b") is None + assert await repo.get("resp_scoped", api_key_id=None) is None + + @pytest.mark.asyncio async def test_run_startup_migrations_bootstraps_legacy_history(db_setup): async with SessionLocal() as session: diff --git a/tests/integration/test_openai_compat_features.py b/tests/integration/test_openai_compat_features.py index 6ceace26..ffa30900 100644 --- a/tests/integration/test_openai_compat_features.py +++ b/tests/integration/test_openai_compat_features.py @@ -95,31 +95,144 @@ async def test_v1_responses_rejects_input_file_id(async_client): @pytest.mark.asyncio -async def test_v1_responses_accepts_previous_response_id(async_client, monkeypatch): - await _import_account(async_client, "acc_prev_response_id", "prev-response-id@example.com") - seen_previous_response_ids: list[str | None] = [] +async def test_v1_responses_replays_previous_response_after_restart(async_client, app_instance, monkeypatch): + await _import_account(async_client, "acc_prev_response", "prev-response@example.com") + + seen_inputs: list[object] = [] + + async def fake_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **_kwargs): + del headers, access_token, account_id, base_url, raise_for_status, _kwargs + seen_inputs.append(payload.input) + if len(seen_inputs) == 1: + yield ( + 'data: {"type":"response.output_item.done","output_index":0,' + '"item":{"id":"msg_prev","type":"message","role":"assistant",' + '"content":[{"type":"output_text","text":"Prior answer"}]}}\n\n' + ) + yield _completed_event("resp_prev") + return + yield _completed_event("resp_followup") - async def fake_stream(payload, headers, access_token, account_id, base_url=None, raise_for_status=False, **_kw): - del headers, access_token, account_id, base_url, raise_for_status, 
_kw - seen_previous_response_ids.append(getattr(payload, "previous_response_id", None)) - yield 'data: {"type":"response.completed","response":{"id":"resp_abc123"}}\n\n' + monkeypatch.setattr(proxy_module, "core_stream_responses", fake_stream) + + first = await async_client.post("/v1/responses", json={"model": "gpt-5.2", "input": "Hello"}) + assert first.status_code == 200 + + if hasattr(app_instance.state, "proxy_service"): + delattr(app_instance.state, "proxy_service") + + second = await async_client.post( + "/v1/responses", + json={ + "model": "gpt-5.2", + "previous_response_id": "resp_prev", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + }, + ) + assert second.status_code == 200 + assert seen_inputs == [ + [{"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}], + [ + {"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "Prior answer"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}, + ], + ] + + +@pytest.mark.asyncio +async def test_v1_responses_previous_response_id_is_scoped_to_api_key(async_client, app_instance, monkeypatch): + await _import_account(async_client, "acc_prev_scoped", "prev-scoped@example.com") + + enable = await async_client.put( + "/api/settings", + json={ + "stickyThreadsEnabled": False, + "preferEarlierResetAccounts": False, + "totpRequiredOnLogin": False, + "apiKeyAuthEnabled": True, + }, + ) + assert enable.status_code == 200 + + created_a = await async_client.post( + "/api/api-keys/", + json={"name": "response-key-a"}, + ) + assert created_a.status_code == 200 + key_a = created_a.json()["key"] + + created_b = await async_client.post( + "/api/api-keys/", + json={"name": "response-key-b"}, + ) + assert created_b.status_code == 200 + key_b = created_b.json()["key"] + + seen_inputs: list[object] = [] + + async def fake_stream(payload, headers, access_token, 
account_id, base_url=None, raise_for_status=False, **_kwargs): + del headers, access_token, account_id, base_url, raise_for_status, _kwargs + seen_inputs.append(payload.input) + if len(seen_inputs) == 1: + yield ( + 'data: {"type":"response.output_item.done","output_index":0,' + '"item":{"id":"msg_prev","type":"message","role":"assistant",' + '"content":[{"type":"output_text","text":"Prior answer"}]}}\n\n' + ) + yield _completed_event("resp_prev_scoped") + return + yield _completed_event("resp_followup_scoped") monkeypatch.setattr(proxy_module, "core_stream_responses", fake_stream) - payload = { - "model": "gpt-5.2", - "previous_response_id": "resp_abc123", - "input": [ - { - "role": "user", - "content": [{"type": "input_text", "text": "Continue."}], - } + first = await async_client.post( + "/v1/responses", + headers={"Authorization": f"Bearer {key_a}"}, + json={"model": "gpt-5.2", "input": "Hello"}, + ) + assert first.status_code == 200 + + if hasattr(app_instance.state, "proxy_service"): + delattr(app_instance.state, "proxy_service") + + second = await async_client.post( + "/v1/responses", + headers={"Authorization": f"Bearer {key_a}"}, + json={ + "model": "gpt-5.2", + "previous_response_id": "resp_prev_scoped", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + }, + ) + assert second.status_code == 200 + + if hasattr(app_instance.state, "proxy_service"): + delattr(app_instance.state, "proxy_service") + + blocked = await async_client.post( + "/v1/responses", + headers={"Authorization": f"Bearer {key_b}"}, + json={ + "model": "gpt-5.2", + "previous_response_id": "resp_prev_scoped", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + }, + ) + assert blocked.status_code == 400 + error = blocked.json()["error"] + assert error["type"] == "invalid_request_error" + assert error["param"] == "previous_response_id" + assert error["message"] == "Unknown previous_response_id" + + assert seen_inputs 
== [ + [{"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}], + [ + {"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "Prior answer"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}, ], - "stream": True, - } - resp = await async_client.post("/v1/responses", json=payload) - assert resp.status_code == 200 - assert seen_previous_response_ids == ["resp_abc123"] + ] @pytest.mark.asyncio diff --git a/tests/integration/test_proxy_websocket_responses.py b/tests/integration/test_proxy_websocket_responses.py index 39c329fa..4687c81d 100644 --- a/tests/integration/test_proxy_websocket_responses.py +++ b/tests/integration/test_proxy_websocket_responses.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import base64 import json from collections import deque from types import SimpleNamespace @@ -94,6 +95,460 @@ def _websocket_settings(**overrides): return SimpleNamespace(**values) +def test_v1_responses_websocket_replays_previous_response_after_restart(app_instance, monkeypatch): + first_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_prev", "object": "response", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.output_item.done", + "output_index": 0, + "item": { + "id": "msg_ws_prev", + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Prior answer"}], + }, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_ws_prev", + "object": "response", + "status": "completed", + "usage": {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}, + }, + }, + separators=(",", ":"), + ), + ), 
+ ] + ) + second_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_next", "object": "response", "status": "in_progress"}, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": { + "id": "resp_ws_next", + "object": "response", + "status": "completed", + "usage": {"input_tokens": 2, "output_tokens": 1, "total_tokens": 3}, + }, + }, + separators=(",", ":"), + ), + ), + ] + ) + upstreams = deque([first_upstream, second_upstream]) + connect_calls: list[dict[str, object]] = [] + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + async def allow_proxy_api_key(_authorization: str | None): + return None + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + reallocate_sticky, + sticky_max_age_seconds, + prefer_earlier_reset, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + preferred_account_id=None, + ): + del self, headers, sticky_key, sticky_kind, reallocate_sticky, sticky_max_age_seconds + del prefer_earlier_reset, routing_strategy, model, api_key, client_send_lock, websocket + connect_calls.append( + { + "request_id": request_state.request_id, + "preferred_account_id": preferred_account_id, + } + ) + return SimpleNamespace(id="acct_ws_prev"), upstreams.popleft() + + async def fake_write_request_log(self, **kwargs): + del self, kwargs + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", 
fake_connect_proxy_websocket) + monkeypatch.setattr(proxy_module.ProxyService, "_write_request_log", fake_write_request_log) + + with TestClient(app_instance) as client: + with client.websocket_connect("/v1/responses") as websocket: + websocket.send_text(json.dumps({"type": "response.create", "model": "gpt-5.2", "input": "Hello"})) + assert json.loads(websocket.receive_text())["type"] == "response.created" + assert json.loads(websocket.receive_text())["type"] == "response.output_item.done" + assert json.loads(websocket.receive_text())["type"] == "response.completed" + + if hasattr(app_instance.state, "proxy_service"): + delattr(app_instance.state, "proxy_service") + + with client.websocket_connect("/v1/responses") as websocket: + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.2", + "previous_response_id": "resp_ws_prev", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + } + ) + ) + assert json.loads(websocket.receive_text())["type"] == "response.created" + assert json.loads(websocket.receive_text())["type"] == "response.completed" + + assert connect_calls[1]["preferred_account_id"] == "acct_ws_prev" + assert [json.loads(message) for message in second_upstream.sent_text] == [ + { + "type": "response.create", + "model": "gpt-5.2", + "instructions": "", + "input": [ + {"role": "user", "content": [{"type": "input_text", "text": "Hello"}]}, + {"role": "assistant", "content": [{"type": "output_text", "text": "Prior answer"}]}, + {"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}, + ], + "tools": [], + "store": False, + "include": [], + } + ] + + +def test_v1_responses_websocket_previous_response_id_is_scoped_to_api_key(app_instance, monkeypatch): + first_upstream = _FakeUpstreamWebSocket( + [ + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.created", + "response": {"id": "resp_ws_scoped", "object": "response", "status": "in_progress"}, + }, + 
separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.output_item.done", + "output_index": 0, + "item": { + "id": "msg_ws_scoped", + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": "Prior answer"}], + }, + }, + separators=(",", ":"), + ), + ), + _FakeUpstreamMessage( + "text", + text=json.dumps( + { + "type": "response.completed", + "response": {"id": "resp_ws_scoped", "object": "response", "status": "completed"}, + }, + separators=(",", ":"), + ), + ), + ] + ) + + class _FakeSettingsCache: + async def get(self): + return _websocket_settings() + + async def allow_firewall(_websocket): + return None + + upstreams = deque([first_upstream]) + + async def fake_connect_proxy_websocket( + self, + headers, + *, + sticky_key, + sticky_kind, + prefer_earlier_reset, + reallocate_sticky, + sticky_max_age_seconds, + routing_strategy, + model, + request_state, + api_key, + client_send_lock, + websocket, + preferred_account_id=None, + ): + del self, headers, sticky_key, sticky_kind, prefer_earlier_reset + del reallocate_sticky, sticky_max_age_seconds, routing_strategy, model + del api_key, client_send_lock, websocket, preferred_account_id + return SimpleNamespace(id="acct_ws_scoped"), upstreams.popleft() + + async def fake_write_request_log(self, **kwargs): + del self, kwargs + + monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall) + monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache()) + monkeypatch.setattr(proxy_module.ProxyService, "_connect_proxy_websocket", fake_connect_proxy_websocket) + monkeypatch.setattr(proxy_module.ProxyService, "_write_request_log", fake_write_request_log) + + with TestClient(app_instance) as client: + enable = client.put( + "/api/settings", + json={ + "stickyThreadsEnabled": False, + "preferEarlierResetAccounts": False, + "totpRequiredOnLogin": False, + "apiKeyAuthEnabled": True, + }, 
+ ) + assert enable.status_code == 200 + + imported = client.post( + "/api/accounts/import", + files={ + "auth_json": ( + "auth.json", + json.dumps( + { + "tokens": { + "idToken": "header." + + base64.urlsafe_b64encode( + json.dumps( + { + "email": "ws-scoped@example.com", + "chatgpt_account_id": "acc_ws_scoped", + "https://api.openai.com/auth": {"chatgpt_plan_type": "plus"}, + }, + separators=(",", ":"), + ).encode("utf-8") + ).rstrip(b"=").decode("ascii") + + ".sig", + "accessToken": "access-token", + "refreshToken": "refresh-token", + "accountId": "acc_ws_scoped", + } + }, + separators=(",", ":"), + ), + "application/json", + ) + }, + ) + assert imported.status_code == 200 + + created_a = client.post("/api/api-keys/", json={"name": "ws-key-a"}) + assert created_a.status_code == 200 + key_a = created_a.json()["key"] + + created_b = client.post("/api/api-keys/", json={"name": "ws-key-b"}) + assert created_b.status_code == 200 + key_b = created_b.json()["key"] + + with client.websocket_connect("/v1/responses", headers={"Authorization": f"Bearer {key_a}"}) as websocket: + websocket.send_text(json.dumps({"type": "response.create", "model": "gpt-5.2", "input": "Hello"})) + assert json.loads(websocket.receive_text())["type"] == "response.created" + assert json.loads(websocket.receive_text())["type"] == "response.output_item.done" + assert json.loads(websocket.receive_text())["type"] == "response.completed" + + if hasattr(app_instance.state, "proxy_service"): + delattr(app_instance.state, "proxy_service") + + with client.websocket_connect("/v1/responses", headers={"Authorization": f"Bearer {key_b}"}) as websocket: + websocket.send_text( + json.dumps( + { + "type": "response.create", + "model": "gpt-5.2", + "previous_response_id": "resp_ws_scoped", + "input": [{"role": "user", "content": [{"type": "input_text", "text": "Continue"}]}], + } + ) + ) + event = json.loads(websocket.receive_text()) + + assert event["type"] == "error" + assert event["status"] == 400 + assert 
event["error"]["type"] == "invalid_request_error"
    assert event["error"]["param"] == "previous_response_id"
    assert event["error"]["message"] == "Unknown previous_response_id"


def test_backend_responses_websocket_retries_once_after_upstream_eof_before_response_created(app_instance, monkeypatch):
    """The proxy retries exactly once, on a different account, when the upstream
    websocket closes before emitting any ``response.created`` event.

    The first fake upstream immediately closes with code 1011 (no events); the
    second delivers a normal created/completed pair.  The test asserts that the
    failed attempt is reported via ``_handle_stream_error`` as
    ``stream_incomplete``, that the retry excludes the first account, that the
    original ``response.create`` payload is re-sent verbatim to the second
    upstream, and that exactly one success request log is written.
    """
    # First attempt: upstream dies (close 1011) before any event arrives.
    first_upstream = _FakeUpstreamWebSocket([_FakeUpstreamMessage("close", close_code=1011)])
    # Second attempt: a minimal successful stream (created -> completed).
    second_upstream = _FakeUpstreamWebSocket(
        [
            _FakeUpstreamMessage(
                "text",
                text=json.dumps(
                    {"type": "response.created", "response": {"id": "resp_ws_retry", "status": "in_progress"}},
                    separators=(",", ":"),
                ),
            ),
            _FakeUpstreamMessage(
                "text",
                text=json.dumps(
                    {"type": "response.completed", "response": {"id": "resp_ws_retry", "status": "completed"}},
                    separators=(",", ":"),
                ),
            ),
        ]
    )
    # Recorders populated by the fakes below; asserted on at the end.
    connect_accounts: list[str] = []
    stream_errors: list[tuple[str, str]] = []
    log_calls: list[dict[str, object]] = []

    class _FakeSettingsCache:
        # Stand-in settings cache so the proxy path needs no real settings store.
        async def get(self):
            return _websocket_settings()

    async def allow_firewall(_websocket):
        # No firewall denial: let every websocket through.
        return None

    async def allow_proxy_api_key(_authorization: str | None):
        # Accept any (or no) Authorization header.
        return None

    async def fake_select_account_with_budget(
        self,
        deadline,
        *,
        request_id,
        kind,
        sticky_key=None,
        sticky_kind=None,
        reallocate_sticky=False,
        sticky_max_age_seconds=None,
        prefer_earlier_reset_accounts=False,
        routing_strategy="usage_weighted",
        model=None,
        additional_limit_name=None,
        preferred_account_id=None,
        exclude_account_ids=None,
    ):
        # Only `exclude_account_ids` matters here: its presence marks the retry.
        del (
            self,
            deadline,
            request_id,
            kind,
            sticky_key,
            sticky_kind,
            reallocate_sticky,
            sticky_max_age_seconds,
            prefer_earlier_reset_accounts,
            routing_strategy,
            model,
            additional_limit_name,
            preferred_account_id,
        )
        if exclude_account_ids:
            # The retry must exclude exactly the account that failed first.
            assert exclude_account_ids == {"acct_ws_first"}
            return SimpleNamespace(account=SimpleNamespace(id="acct_ws_retry"), error_message=None, error_code=None)
        return SimpleNamespace(account=SimpleNamespace(id="acct_ws_first"), error_message=None, error_code=None)

    async def fake_ensure_fresh_with_budget(self, account, *, force=False, timeout_seconds=None):
        # Token refresh is a no-op: hand the account straight back.
        del self, force, timeout_seconds
        return account

    async def fake_open_upstream_websocket_with_budget(self, account, headers, *, timeout_seconds):
        # Route each account to its dedicated fake upstream and record the order.
        del self, headers, timeout_seconds
        connect_accounts.append(account.id)
        if account.id == "acct_ws_first":
            return first_upstream
        return second_upstream

    async def fake_write_request_log(self, **kwargs):
        # Capture log writes instead of touching the database.
        del self
        log_calls.append(kwargs)

    async def fake_handle_stream_error(self, account, error, code):
        # Capture (account, code) so the failure classification can be asserted.
        del error
        stream_errors.append((account.id, code))

    monkeypatch.setattr(proxy_api_module, "_websocket_firewall_denial_response", allow_firewall)
    monkeypatch.setattr(proxy_api_module, "validate_proxy_api_key_authorization", allow_proxy_api_key)
    monkeypatch.setattr(proxy_module, "get_settings_cache", lambda: _FakeSettingsCache())
    monkeypatch.setattr(proxy_module.ProxyService, "_select_account_with_budget", fake_select_account_with_budget)
    monkeypatch.setattr(proxy_module.ProxyService, "_ensure_fresh_with_budget", fake_ensure_fresh_with_budget)
    monkeypatch.setattr(proxy_module.ProxyService, "_open_upstream_websocket_with_budget", fake_open_upstream_websocket_with_budget)
    monkeypatch.setattr(proxy_module.ProxyService, "_write_request_log", fake_write_request_log)
    monkeypatch.setattr(proxy_module.ProxyService, "_handle_stream_error", fake_handle_stream_error)

    request_payload = {
        "type": "response.create",
        "model": "gpt-5.4",
        "instructions": "",
        "input": [{"role": "user", "content": [{"type": "input_text", "text": "hi"}]}],
        "stream": True,
    }

    with TestClient(app_instance) as client:
        with client.websocket_connect("/backend-api/codex/responses") as websocket:
            websocket.send_text(json.dumps(request_payload))
            # The client only ever sees events from the successful second attempt.
            created_event = json.loads(websocket.receive_text())
            completed_event = json.loads(websocket.receive_text())

    assert [created_event["type"], completed_event["type"]] == ["response.created", "response.completed"]
    # Exactly one retry, in order: failed account first, then the replacement.
    assert connect_accounts == ["acct_ws_first", "acct_ws_retry"]
    # The pre-created EOF is classified as an incomplete stream on the first account.
    assert stream_errors == [("acct_ws_first", "stream_incomplete")]
    # The request payload is (re)sent once to each upstream, unchanged in type.
    assert len(first_upstream.sent_text) == 1
    assert len(second_upstream.sent_text) == 1
    assert json.loads(first_upstream.sent_text[0])["type"] == "response.create"
    assert json.loads(second_upstream.sent_text[0])["type"] == "response.create"
    # Only the successful attempt produces a request log entry.
    assert len(log_calls) == 1
    assert log_calls[0]["request_id"] == "resp_ws_retry"
    assert log_calls[0]["status"] == "success"


def test_backend_responses_websocket_proxies_upstream_and_persists_log(app_instance, monkeypatch):
    upstream_messages = [
        _FakeUpstreamMessage(